diff --git a/calculator/calculator.py b/calculator/calculator.py index ef0da93e221160bb9ecee49d1d917dc1bca739b2..a45531eb9cb9f91282b16c38b3ce0f7007c95065 100644 --- a/calculator/calculator.py +++ b/calculator/calculator.py @@ -1,7 +1,7 @@ """ This Calculator holds the logic for the calculator. """ - +import pytest from calculator.operators import Operator, STANDARD_OPERATORS from calculator.expression import Token, Term, Expression, TermExpression, OperatorExpression @@ -13,6 +13,7 @@ class Calculator: operator ::= + | - | * | / with the usual precedence rules. """ + def __init__(self, operators=None): if operators is None: operators = STANDARD_OPERATORS @@ -64,3 +65,55 @@ class Calculator: def __call__(self, expression: str) -> Term: return self.parse(self.tokenize(expression))() + + +@pytest.fixture(scope="module", name="setup") +def fixture_setup(): + """ + Setup the test suite, by instantiating the calculator and the operators. + """ + plus = Operator('+', 1, lambda a, b: a + b) + minus = Operator('-', 1, lambda a, b: a - b) + times = Operator('*', 2, lambda a, b: a * b) + divide = Operator('/', 2, lambda a, b: a / b) + calculator = Calculator( + operators={'+': plus, '-': minus, '*': times, '/': divide}) + yield plus, minus, times, divide, calculator + + +def test_tokenizer(setup): + """ + Test the tokenizer. + """ + plus, minus, times, divide, calc = setup + assert calc.tokenize("1 + 2") == [1.0, plus, 2.0] + assert calc.tokenize("1 + 2 * 3") == [1.0, plus, 2.0, times, 3.0] + assert calc.tokenize( + "1 + 2 * 3 / 4") == [1.0, plus, 2.0, times, 3.0, divide, 4.0] + assert calc.tokenize( + "1 + 2 * 3 / 4 - 5") == [1.0, plus, 2.0, times, 3.0, divide, 4.0, minus, 5.0] + + +def test_parser(setup): + """ + Test the parser. + """ + _, _, _, _, calc = setup + assert repr(calc.parse(calc.tokenize("1 + 2"))) == '(1.0 + 2.0)' + assert repr(calc.parse(calc.tokenize("1 + 2 * 3")) + ) == '(1.0 + (2.0 * 3.0))' + assert repr(calc.parse(calc.tokenize( + "1 + 2 * 3 / 4"))) == '(1.0 + ((2.0 * 3.0) / 4.0))' + assert repr(calc.parse(calc.tokenize( + "1 + 2 * 3 / 4 - 5"))) == '((1.0 + ((2.0 * 3.0) / 4.0)) - 5.0)' + + +def test_evaluation(setup): + """ + Test the evaluation. + """ + _, _, _, _, calc = setup + assert calc("1 + 2") == 3 + assert calc("1 + 2 * 3") == 7 + assert calc("1 + 2 * 3 / 4") == 2.5 + assert calc("1 + 2 * 3 / 4 - 5") == -2.5 diff --git a/calculator/expression.py b/calculator/expression.py index c461c95f688b652fe5fe2d3fe357c10ac1ab2350..a4106d02598ee724f931d166a077d688154fbec0 100644 --- a/calculator/expression.py +++ b/calculator/expression.py @@ -4,6 +4,7 @@ Expression module defines the structure of an expression. from typing import Union from calculator.operators import Operator + Term: type = float Token: type = Union[Operator, Term] @@ -12,6 +13,7 @@ class OperatorExpression: """ OperatorExpression class is an expression that contains an operator and two sub-expressions. """ + def __init__(self, operator: Operator, left, right): self.operator = operator self.left = left @@ -28,6 +30,7 @@ class TermExpression: """ TermExpression class is an expression that contains a single term. """ + def __init__(self, value: Term): self.value = value @@ -39,3 +42,37 @@ class TermExpression: Expression: type = Union[OperatorExpression, TermExpression] + + +def test_single_term(): + """ + Test the TermExpression class. + """ + expression = TermExpression(42) + assert repr(expression) == '42' + assert expression() == 42 + + +def test_single_operator(): + """ + Test the OperatorExpression class. + """ + add = Operator('+', 1, lambda a, b: a + b) + expression = OperatorExpression(add, TermExpression(1), TermExpression(2)) + assert repr(expression) == '(1 + 2)' + assert expression() == 3 + + +def test_complex_expression(): + """ + Test a complex expression. + """ + add = Operator('+', 1, lambda a, b: a + b) + multiply = Operator('*', 2, lambda a, b: a * b) + expression = OperatorExpression( + multiply, + OperatorExpression(add, TermExpression(1), TermExpression(2)), + TermExpression(3) + ) + assert repr(expression) == '((1 + 2) * 3)' + assert expression() == 9