"""
Expression module defines the structure of an expression.
"""
from typing import Union
from calculator.operators import Operator


# Type aliases used throughout the calculator.
# Plain assignments (rather than `Term: type = float`) so that static type
# checkers recognise these as type aliases: an explicitly annotated
# module-level variable is treated as an ordinary variable and cannot be used
# in annotations. `Union[...]` is also not an instance of `type`, so the old
# annotation was inaccurate at runtime as well.
Term = float                    # a numeric value in an expression
Token = Union[Operator, Term]   # a single lexical unit: an operator or a term


class OperatorExpression:
    """
    Binary expression node: an operator applied to a left and a right sub-expression.

    Calling the node evaluates the left operand first, then the right operand,
    and passes both results to the operator.
    """

    def __init__(self, operator: Operator, left, right):
        # left/right are only required to be callable (see __call__);
        # presumably OperatorExpression or TermExpression nodes.
        self.operator, self.left, self.right = operator, left, right

    def __repr__(self):
        # Fully parenthesised infix rendering, e.g. "(1 + 2)".
        return "({} {} {})".format(self.left, self.operator, self.right)

    def __call__(self) -> Term:
        lhs = self.left()
        rhs = self.right()
        return self.operator(lhs, rhs)


class TermExpression:
    """
    Leaf expression node wrapping a single term value.

    Calling the node returns the stored value unchanged.
    """

    def __init__(self, value: Term):
        self.value = value

    def __repr__(self):
        # Rendered as the bare value (via str), e.g. "42".
        return "%s" % (self.value,)

    def __call__(self) -> Term:
        stored = self.value
        return stored


# Any expression-tree node. Written as a plain assignment so static type
# checkers treat it as a type alias; an explicit `: type` annotation would make
# it an ordinary variable that cannot be used in annotations.
Expression = Union[OperatorExpression, TermExpression]


def test_single_term():
    """
    Check that a TermExpression renders as its value and evaluates to it.
    """
    leaf = TermExpression(42)
    assert repr(leaf) == '42'
    assert leaf() == 42


def test_single_operator():
    """
    Check rendering and evaluation of a single addition node.
    """
    plus = Operator('+', 1, lambda x, y: x + y)
    one, two = TermExpression(1), TermExpression(2)
    node = OperatorExpression(plus, one, two)
    assert repr(node) == '(1 + 2)'
    assert node() == 3


def test_complex_expression():
    """
    Check a nested tree: the sum (1 + 2) used as the left operand of a product.
    """
    plus = Operator('+', 1, lambda x, y: x + y)
    times = Operator('*', 2, lambda x, y: x * y)
    inner_sum = OperatorExpression(plus, TermExpression(1), TermExpression(2))
    product = OperatorExpression(times, inner_sum, TermExpression(3))
    assert repr(product) == '((1 + 2) * 3)'
    assert product() == 9