""" Expression module defines the structure of an expression. """ from typing import Union from calculator.operators import Operator Term: type = float Token: type = Union[Operator, Term] class OperatorExpression: """ OperatorExpression class is an expression that contains an operator and two sub-expressions. """ def __init__(self, operator: Operator, left, right): self.operator = operator self.left = left self.right = right def __repr__(self): return f"({self.left} {self.operator} {self.right})" def __call__(self) -> Term: return self.operator(self.left(), self.right()) class TermExpression: """ TermExpression class is an expression that contains a single term. """ def __init__(self, value: Term): self.value = value def __repr__(self): return str(self.value) def __call__(self) -> Term: return self.value Expression: type = Union[OperatorExpression, TermExpression] def test_single_term(): """ Test the TermExpression class. """ expression = TermExpression(42) assert repr(expression) == '42' assert expression() == 42 def test_single_operator(): """ Test the OperatorExpression class. """ add = Operator('+', 1, lambda a, b: a + b) expression = OperatorExpression(add, TermExpression(1), TermExpression(2)) assert repr(expression) == '(1 + 2)' assert expression() == 3 def test_complex_expression(): """ Test a complex expression. """ add = Operator('+', 1, lambda a, b: a + b) multiply = Operator('*', 2, lambda a, b: a * b) expression = OperatorExpression( multiply, OperatorExpression(add, TermExpression(1), TermExpression(2)), TermExpression(3) ) assert repr(expression) == '((1 + 2) * 3)' assert expression() == 9