| """ |
| Calculator Tool - Safe mathematical expression evaluation |
| Author: @mangubee |
| Date: 2026-01-02 |
| |
| Provides safe evaluation of mathematical expressions with: |
| - Whitelisted operations and functions |
| - Timeout protection |
| - Complexity limits |
| - No access to dangerous built-ins |
| |
| Security is prioritized over functionality. |
| """ |
|
|
| import ast |
| import math |
| import operator |
| import logging |
| from typing import Any, Dict |
| import signal |
| from contextlib import contextmanager |
|
|
| |
| |
| |
# --- Resource limits --------------------------------------------------------

# Longest expression (in characters, after stripping) that safe_eval() parses.
MAX_EXPRESSION_LENGTH = 500
# Seconds granted to one evaluation when the SIGALRM-based timeout is usable.
MAX_EVAL_TIME_SECONDS = 2
# Largest absolute value accepted for an int/float literal in an expression.
MAX_NUMBER_SIZE = 10 ** 100

# --- Whitelisted operators --------------------------------------------------
# AST operator node class -> callable implementing it.

SAFE_OPERATORS = {
    # binary arithmetic
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,
    ast.Mod: operator.mod,
    ast.Pow: operator.pow,
    # unary sign
    ast.USub: operator.neg,
    ast.UAdd: operator.pos,
}

# --- Whitelisted names ------------------------------------------------------
# Callable entries may be called by name; the non-callable entries at the end
# (numeric constants) may be referenced as bare names.

SAFE_FUNCTIONS = {
    # selected Python built-ins, keyed by their own __name__
    **{fn.__name__: fn for fn in (abs, round, min, max, sum)},
    # functions from the math module, looked up by name
    **{
        name: getattr(math, name)
        for name in (
            'sqrt', 'ceil', 'floor', 'log', 'log10', 'exp',
            'sin', 'cos', 'tan', 'asin', 'acos', 'atan',
            'degrees', 'radians', 'factorial',
        )
    },
    # numeric constants
    'pi': math.pi,
    'e': math.e,
}

# Module logger; no handlers or levels are configured in this module.
logger = logging.getLogger(__name__)
|
|
|
|
| |
| |
| |
|
|
class TimeoutError(Exception):
    """Signal that an expression evaluation ran past its time budget.

    Note: this deliberately shadows the built-in ``TimeoutError`` inside this
    module. It derives directly from ``Exception``, not ``OSError``.
    """
|
|
|
|
@contextmanager
def timeout(seconds: int):
    """
    Context manager for timeout protection.

    Installs a SIGALRM handler and arms ``signal.alarm(seconds)``. If the alarm
    fires while the guarded block runs, TimeoutError is raised inside it. On
    exit (normal or via exception), the alarm is cancelled and the previous
    SIGALRM handler is restored.

    Args:
        seconds: Maximum execution time (whole seconds, as taken by signal.alarm)

    Raises:
        TimeoutError: If execution exceeds timeout

    Note:
        signal.alarm() only works in the main thread. In threaded contexts
        (Gradio, ThreadPoolExecutor), timeout protection is disabled.
        Python signal handlers run only between bytecode instructions, so a
        single long-running C call is not interrupted until it returns.
        Any alarm scheduled before entry is replaced, and it is not re-armed
        on exit.
    """
    def timeout_handler(signum, frame):
        raise TimeoutError(f"Evaluation exceeded {seconds} second timeout")

    try:
        # ValueError: signal.signal() was called outside the main thread.
        # AttributeError: signal.SIGALRM does not exist (e.g. Windows).
        old_handler = signal.signal(signal.SIGALRM, timeout_handler)
    except (ValueError, AttributeError):
        logger.warning("Timeout protection disabled (threading/Windows limitation)")
        armed = False
    else:
        armed = True
        # signal.signal() returns None when the previous handler was not
        # installed from Python. Fall back to the default disposition so that
        # cleanup still cancels the alarm and removes our handler. Otherwise
        # both would stay in place, and a stray TimeoutError could later fire
        # in unrelated code.
        if old_handler is None:
            old_handler = signal.SIG_DFL
        signal.alarm(seconds)

    # Yield outside the except-clause so the setup error is never chained
    # onto exceptions raised by the guarded block.
    try:
        yield
    finally:
        if armed:
            # Cancel first, so a late alarm is never delivered to old_handler.
            # With SIG_DFL, SIGALRM would terminate the process.
            signal.alarm(0)
            signal.signal(signal.SIGALRM, old_handler)
|
|
|
|
| |
| |
| |
|
|
class SafeEvaluator(ast.NodeVisitor):
    """
    AST visitor that evaluates mathematical expressions safely.

    Only allows whitelisted operations and functions.
    Prevents code execution, attribute access, and other dangerous operations.

    Cost limits: the SIGALRM-based timeout cannot interrupt a single
    long-running C call, because Python signal handlers run only between
    bytecode instructions. It is also disabled outside the main thread.
    Operations that could run unboundedly in one call are therefore rejected
    *before* they execute:
    - oversized integer powers;
    - large ``factorial`` arguments;
    - ``round`` with a huge ``ndigits``;
    - list/tuple repetition.
    Integer results wider than MAX_INT_BITS are also rejected.
    """

    # Widest integer result (in bits) any operation may produce. This is about
    # 4200 decimal digits, just under CPython 3.11+'s default 4300-digit
    # int->str limit, so every accepted result can still be printed/logged.
    MAX_INT_BITS = 14_000

    # Largest absolute exponent accepted by ``**``.
    MAX_EXPONENT = 1000

    # Largest argument accepted by ``factorial`` (1000! is about 8,530 bits).
    MAX_FACTORIAL_ARG = 1000

    # Largest |ndigits| accepted by ``round``. CPython's int.__round__ with a
    # negative ndigits computes 10 ** -ndigits internally.
    MAX_ROUND_DIGITS = 1000

    # Operand types allowed for binary arithmetic (bool is a subclass of int).
    _NUMERIC_TYPES = (int, float, complex)

    def visit_Expression(self, node):
        """Visit Expression node (root of parse tree)"""
        return self.visit(node.body)

    def visit_Constant(self, node):
        """Visit Constant node; only numeric literals (int, float, complex, bool) are accepted."""
        value = node.value

        # Strings, bytes, None, Ellipsis, ... are rejected
        if not isinstance(value, (int, float, complex)):
            raise ValueError(f"Unsupported constant type: {type(value).__name__}")

        # Bound literal magnitude (complex literals are not checked)
        if isinstance(value, (int, float)) and abs(value) > MAX_NUMBER_SIZE:
            raise ValueError(f"Number too large: {value}")

        return value

    def visit_BinOp(self, node):
        """Visit binary operation node (+, -, *, /, etc.)"""
        op_type = type(node.op)

        # Reject non-whitelisted operators before evaluating any operand
        if op_type not in SAFE_OPERATORS:
            raise ValueError(f"Unsupported operation: {op_type.__name__}")

        left = self.visit(node.left)
        right = self.visit(node.right)

        # Numbers only: list/tuple operands would allow sequence repetition
        # such as "[0] * 10**9", which allocates unbounded memory in one C call.
        if not (isinstance(left, self._NUMERIC_TYPES) and isinstance(right, self._NUMERIC_TYPES)):
            raise ValueError(
                f"Unsupported operand types for {op_type.__name__}: "
                f"{type(left).__name__} and {type(right).__name__}"
            )

        # Explicit check gives a uniform message for /, // and %
        if op_type in (ast.Div, ast.FloorDiv, ast.Mod) and right == 0:
            raise ZeroDivisionError("Division by zero")

        if op_type is ast.Pow:
            self._check_pow_cost(left, right)

        result = SAFE_OPERATORS[op_type](left, right)
        return self._check_int_size(result)

    def visit_UnaryOp(self, node):
        """Visit unary operation node (-, +)"""
        op_type = type(node.op)

        # Rejects e.g. "not" and "~", which are not in SAFE_OPERATORS
        if op_type not in SAFE_OPERATORS:
            raise ValueError(f"Unsupported unary operation: {op_type.__name__}")

        operand = self.visit(node.operand)
        op_func = SAFE_OPERATORS[op_type]

        return op_func(operand)

    def visit_Call(self, node):
        """Visit function call node; only whitelisted plain-name calls are allowed."""
        # Rejects attribute calls ("x.f()") and calls on expressions ("f()()")
        if not isinstance(node.func, ast.Name):
            raise ValueError("Only direct function calls are allowed")

        func_name = node.func.id

        if func_name not in SAFE_FUNCTIONS:
            raise ValueError(f"Unsupported function: {func_name}")

        # Positional arguments are evaluated recursively under the same rules
        args = [self.visit(arg) for arg in node.args]

        if node.keywords:
            raise ValueError("Keyword arguments not allowed")

        self._check_call_cost(func_name, args)

        func = SAFE_FUNCTIONS[func_name]

        # Only wrap the errors these functions raise for bad arguments. Anything
        # else (notably TimeoutError from the alarm handler) must propagate
        # unchanged so callers see the real cause.
        try:
            result = func(*args)
        except (ArithmeticError, TypeError, ValueError) as e:
            raise ValueError(f"Error calling {func_name}: {str(e)}") from e

        return self._check_int_size(result)

    def visit_Name(self, node):
        """Visit name node; only non-callable SAFE_FUNCTIONS entries (e.g. pi, e) resolve."""
        if node.id in SAFE_FUNCTIONS:
            value = SAFE_FUNCTIONS[node.id]
            # Bare function names (e.g. "sqrt" without a call) are not values
            if not callable(value):
                return value

        raise ValueError(f"Undefined name: {node.id}")

    def visit_List(self, node):
        """Visit list node"""
        return [self.visit(element) for element in node.elts]

    def visit_Tuple(self, node):
        """Visit tuple node"""
        return tuple(self.visit(element) for element in node.elts)

    def generic_visit(self, node):
        """Catch-all for unsupported node types"""
        raise ValueError(f"Unsupported expression type: {type(node).__name__}")

    def _check_pow_cost(self, base, exponent):
        """Raise ValueError if ``base ** exponent`` would be too costly to compute."""
        if abs(exponent) > self.MAX_EXPONENT:
            raise ValueError(f"Exponent too large: {exponent}")
        # |base| >= 2 ** (bit_length - 1), so (bit_length - 1) * exponent is a
        # cheap lower bound on log2 of the result. Exceeding the limit here
        # means the exact result would too.
        if isinstance(base, int) and isinstance(exponent, int) and exponent > 0:
            if (abs(base).bit_length() - 1) * exponent > self.MAX_INT_BITS:
                raise ValueError(f"Result of ** would exceed {self.MAX_INT_BITS} bits")

    def _check_call_cost(self, func_name, args):
        """Raise ValueError for whitelisted calls whose arguments would make them run unboundedly."""
        if func_name == 'factorial' and args:
            n = args[0]
            if isinstance(n, (int, float)) and n > self.MAX_FACTORIAL_ARG:
                raise ValueError(
                    f"factorial argument too large: {n} (maximum {self.MAX_FACTORIAL_ARG})"
                )
        elif func_name == 'round' and len(args) >= 2:
            ndigits = args[1]
            if isinstance(ndigits, int) and abs(ndigits) > self.MAX_ROUND_DIGITS:
                raise ValueError(
                    f"round ndigits too large: {ndigits} (maximum {self.MAX_ROUND_DIGITS})"
                )

    def _check_int_size(self, value):
        """Return *value* unchanged unless it is an int wider than MAX_INT_BITS."""
        # The value itself is not formatted into the message: it may be too
        # long for int->str conversion.
        if isinstance(value, int) and value.bit_length() > self.MAX_INT_BITS:
            raise ValueError(
                f"Result too large: {value.bit_length()} bits (maximum {self.MAX_INT_BITS})"
            )
        return value
|
|
|
|
| |
| |
| |
|
|
def safe_eval(expression: str) -> Dict[str, Any]:
    """
    Safely evaluate a mathematical expression.

    Args:
        expression: Mathematical expression string

    Returns:
        Dict with structure: {
            "result": ...,        # Evaluation result (None on graceful error)
            "expression": str,    # Stripped expression (first 100 chars + "..." if too long)
            "success": bool,      # True if evaluation succeeded
            "error": str,         # Present only when success is False
        }

        Empty, non-string, whitespace-only and over-long inputs do NOT raise.
        They return ``success=False`` together with an ``"error"`` message.

    Raises:
        ValueError: For invalid or unsafe expressions
        ZeroDivisionError: For division by zero
        TimeoutError: If evaluation exceeds timeout
        SyntaxError: For malformed expressions (including statements such as
            ``import os``, which are not expressions)

    Examples:
        >>> safe_eval("2 + 2")
        {'result': 4, 'expression': '2 + 2', 'success': True}

        >>> safe_eval("sqrt(16) + 3")
        {'result': 7.0, 'expression': 'sqrt(16) + 3', 'success': True}

        >>> safe_eval("__import__('os')")
        Traceback (most recent call last):
            ...
        ValueError: Unsupported function: __import__
    """
    # Graceful (non-raising) rejection of missing / non-string input
    if not expression or not isinstance(expression, str):
        logger.warning("Calculator received empty or non-string expression - returning graceful error")
        return {
            "result": None,
            "expression": str(expression) if expression else "",
            "success": False,
            "error": "Empty expression provided. Calculator requires a mathematical expression string.",
        }

    expression = expression.strip()

    # Input that was only whitespace is empty after stripping
    if not expression:
        logger.warning("Calculator expression was only whitespace - returning graceful error")
        return {
            "result": None,
            "expression": "",
            "success": False,
            "error": "Expression was only whitespace. Provide a valid mathematical expression.",
        }

    if len(expression) > MAX_EXPRESSION_LENGTH:
        logger.warning("Expression too long (%d chars) - returning graceful error", len(expression))
        return {
            "result": None,
            "expression": expression[:100] + "...",
            "success": False,
            "error": f"Expression too long ({len(expression)} chars). Maximum: {MAX_EXPRESSION_LENGTH} chars",
        }

    logger.info("Evaluating expression: %s", expression)

    try:
        # mode='eval' accepts a single expression only; statements are SyntaxError
        tree = ast.parse(expression, mode='eval')

        with timeout(MAX_EVAL_TIME_SECONDS):
            result = SafeEvaluator().visit(tree)
    except SyntaxError as e:
        logger.error("Syntax error in expression: %s", e)
        raise SyntaxError(f"Invalid expression syntax: {str(e)}") from e
    except ZeroDivisionError:
        logger.error("Division by zero: %s", expression)
        raise
    except TimeoutError:
        logger.error("Evaluation timeout: %s", expression)
        raise
    except ValueError as e:
        logger.error("Invalid expression: %s", e)
        raise
    except Exception as e:
        logger.exception("Unexpected error evaluating expression: %s", e)
        raise ValueError(f"Evaluation error: {str(e)}") from e

    # Lazy %-formatting, kept outside the try. If str(result) fails (e.g.
    # CPython 3.11+'s int->str digit limit), logging reports it through
    # Handler.handleError instead of turning a computed result into an error.
    logger.info("Evaluation successful: %s", result)

    return {
        "result": result,
        "expression": expression,
        "success": True,
    }
|
|