Spaces:
Sleeping
Sleeping
| """Parser for agent-emitted equations of motion. | |
| RHS is parsed via Python's ``ast`` module, then walked by a whitelist visitor | |
| that only permits Constant / Name / UnaryOp (+/-) / BinOp (+ - * / **) / | |
| Call (bare allowed-function name, no kwargs). Anything else — Attribute, | |
| Subscript, Lambda, IfExp, keyword args, etc. — raises ParseError by | |
| construction. We never call sympify on raw text, so there is no eval stage | |
| that can crash the trainer with an AttributeError. | |
| Pre-transforms before AST parse: | |
| - ``^`` → ``**`` (physics power notation) | |
| - ``dx/dt`` / bare ``dx`` → ``vx`` when the system pairs x with vx | |
| """ | |
from __future__ import annotations

import ast
import operator
import re
from collections.abc import Callable

import sympy as sp
from pydantic import BaseModel, ConfigDict
class ParseError(ValueError):
    """Signal that an agent's equation text falls outside the accepted grammar.

    Derives from ``ValueError``; the parsing helpers in this module raise it
    for every rejected payload.
    """
# Whitelist of function names permitted on the RHS, mapped to the SymPy
# callable that builds the node. The values mix SymPy function classes
# (``sp.sin``, ``sp.Abs``) with plain helper functions (``sp.sqrt``), so the
# accurate value type is "callable returning an expression", not an
# ``sp.Function`` instance. ``abs`` and ``Abs`` are both accepted spellings.
ALLOWED_FUNCTIONS: dict[str, Callable[..., sp.Expr]] = {
    "sin": sp.sin,
    "cos": sp.cos,
    "tan": sp.tan,
    "exp": sp.exp,
    "log": sp.log,
    "sqrt": sp.sqrt,
    "abs": sp.Abs,
    "Abs": sp.Abs,
}
def _build_grammar_hint() -> str:
    """Compose the plain-text grammar description for the equation field.

    The function list is derived from ``ALLOWED_FUNCTIONS`` (lower-cased,
    de-duplicated, sorted) so the hint stays in sync with the whitelist.
    """
    function_names = " ".join(sorted({key.lower() for key in ALLOWED_FUNCTIONS}))
    pieces = (
        "The 'equation' field is an infix ODE in plain ASCII. ",
        "LHS form: 'dN<var>/dtN' where N is 1 or 2 (omit N for first ",
        "order, e.g. 'dy/dt' or 'd2y/dt2'). ",
        "RHS uses operators + - * / ** (or ^ for power), parentheses, ",
        "the state variables listed under STATE_VARIABLES, and any ",
        "names you declare in 'params'. ",
        "Allowed functions: " + function_names + ". ",
        "Velocity convention: when STATE_VARIABLES lists both 'x' and 'vx' ",
        "(or 'y'/'vy', etc.), use the 'vx' name on the RHS to refer to the ",
        "first time-derivative of x. The aliases 'dx/dt' and bare 'dx' are ",
        "also accepted for that case. The system is autonomous: time 't' is ",
        "not a valid RHS symbol. ",
        "No LaTeX, no \\frac, no array indexing, no library prefixes ",
        "(write 'sqrt(x)', not 'np.sqrt(x)'), no keyword arguments. ",
        "Working examples appear in the HISTORY block of each subsequent turn.",
    )
    return "".join(pieces)


# Rendered once, at import time.
GRAMMAR_HINT: str = _build_grammar_hint()
| _LHS_PATTERN = re.compile( | |
| r""" | |
| ^\s* | |
| d(?P<order>\d*) | |
| (?P<var>[A-Za-z_][A-Za-z0-9_]*) | |
| / | |
| d t | |
| (?P<order2>\d*) | |
| \s*$ | |
| """, | |
| re.VERBOSE, | |
| ) | |
| _BIN_OP_TO_SYMPY: dict[type, "callable"] = { | |
| ast.Add: lambda a, b: a + b, | |
| ast.Sub: lambda a, b: a - b, | |
| ast.Mult: lambda a, b: a * b, | |
| ast.Div: lambda a, b: a / b, | |
| ast.Pow: lambda a, b: a**b, | |
| } | |
class Equation(BaseModel):
    """One parsed ODE line: the LHS derivative (variable and order) and its RHS.

    Immutable (``frozen=True``). ``arbitrary_types_allowed`` is required
    because ``rhs`` holds a SymPy expression, which is not a pydantic-native
    type.
    """

    model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
    var: str  # variable named on the LHS, e.g. "x" for "d2x/dt2"
    order: int  # derivative order from the LHS; the parser only produces 1 or 2
    rhs: sp.Expr  # right-hand side built by the whitelist AST walker
class ParsedEquation(BaseModel):
    """Result of ``parse_equation``: all parsed equations plus aggregate stats.

    Immutable (``frozen=True``); ``arbitrary_types_allowed`` mirrors
    ``Equation`` since the equations carry SymPy expressions.
    """

    model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
    equations: tuple[Equation, ...]  # in the order they appear in the payload
    free_symbols: frozenset[str]  # names of every symbol used across all RHSs
    operator_count: int  # sum of _count_operators() over all RHSs
def parse_equation(
    text: str,
    state_variables: tuple[str, ...],
    parameter_names: frozenset[str],
) -> ParsedEquation:
    """Parse and validate the agent's equation payload.

    Args:
        text: One or more ``<lhs> = <rhs>`` equations separated by ``;`` or
            newlines.
        state_variables: State-variable names allowed on the RHS.
        parameter_names: Declared parameter names allowed on the RHS.

    Returns:
        A ParsedEquation holding one Equation per input equation, the union
        of all RHS free-symbol names, and the summed operator count.

    Raises:
        ParseError: for any malformed payload. Only ParseError ever escapes —
            callers convert it to r_format=0. Low-level failures triggered by
            hostile input (e.g. a SymPy TypeError for a wrong argument count,
            a ValueError from ``ast``, or a RecursionError on extremely deep
            nesting) are re-raised as ParseError so that one bad sample
            cannot crash the caller.
    """
    if not text or not text.strip():
        raise ParseError("Empty equation payload.")
    raw_equations = _split_equations(text)
    if not raw_equations:
        raise ParseError("No equations found in payload.")

    allowed_symbols = frozenset(state_variables) | parameter_names
    parsed: list[Equation] = []
    free_symbol_names: set[str] = set()
    operator_count = 0
    for raw in raw_equations:
        try:
            eq = _parse_one(raw, allowed_symbols, state_variables)
            free_symbol_names.update(s.name for s in eq.rhs.free_symbols)
            operator_count += _count_operators(eq.rhs)
        except ParseError:
            raise
        except (ArithmeticError, RecursionError, TypeError, ValueError) as exc:
            # Safety net for the "only ParseError escapes" contract; ParseError
            # itself (a ValueError subclass) is re-raised untouched above.
            raise ParseError(
                f"Could not process equation {raw!r}: "
                f"{type(exc).__name__}: {exc}"
            ) from exc
        parsed.append(eq)

    return ParsedEquation(
        equations=tuple(parsed),
        free_symbols=frozenset(free_symbol_names),
        operator_count=operator_count,
    )
| def _split_equations(text: str) -> list[str]: | |
| parts = re.split(r"[;\n]+", text) | |
| return [p.strip() for p in parts if p.strip()] | |
def _parse_one(
    raw: str,
    allowed_symbols: frozenset[str],
    state_variables: tuple[str, ...],
) -> Equation:
    """Turn a single ``<lhs> = <rhs>`` string into an Equation.

    Only the first ``=`` separates the two sides; any later ``=`` stays in
    the RHS text, where it fails RHS parsing.

    Raises:
        ParseError: if there is no ``=``, or either side is invalid.
    """
    lhs_text, separator, rhs_text = raw.partition("=")
    if not separator:
        raise ParseError(f"Equation has no '=' sign: {raw!r}")
    var, order = _parse_lhs(lhs_text)
    return Equation(
        var=var,
        order=order,
        rhs=_parse_rhs(rhs_text, allowed_symbols, state_variables),
    )
def _parse_lhs(lhs: str) -> tuple[str, int]:
    """Extract ``(variable, order)`` from an LHS such as ``dx/dt`` or ``d2x/dt2``.

    An empty order digit means first order.

    Raises:
        ParseError: if the text does not match ``_LHS_PATTERN``, the two
            order digits disagree, or the order is not 1 or 2.
    """
    found = _LHS_PATTERN.match(lhs)
    if found is None:
        raise ParseError(
            f"Cannot parse LHS {lhs!r}. Expected 'dN<var>/dtN' where N is "
            "1 or 2 (or empty for first order)."
        )
    top, bottom, name = found.group("order", "order2", "var")
    if top != bottom:
        raise ParseError(
            f"LHS order mismatch in {lhs!r}: top order {top!r} vs "
            f"bottom order {bottom!r}."
        )
    supported_orders = {"": 1, "1": 1, "2": 2}
    if top not in supported_orders:
        raise ParseError(f"Only orders 1 and 2 are supported. Got {top!r}.")
    return name, supported_orders[top]
def _parse_rhs(
    rhs: str,
    allowed_symbols: frozenset[str],
    state_variables: tuple[str, ...],
) -> sp.Expr:
    """Convert RHS text into a SymPy expression through the whitelist walker.

    ``^`` is rewritten to ``**`` and derivative spellings are rewritten to
    velocity names before parsing with ``ast`` (never ``eval``/``sympify``).

    Raises:
        ParseError: on empty text; on syntax errors; on source ``ast.parse``
            refuses for other reasons (e.g. null bytes, which older Python
            versions report as ValueError); on input nested too deeply for
            the parser or the recursive walker; or on any construct the
            walker rejects.
    """
    rhs = rhs.strip()
    if not rhs:
        raise ParseError("Empty RHS.")
    rhs = rhs.replace("^", "**")
    rhs = _apply_velocity_alias(rhs, state_variables)
    try:
        tree = ast.parse(rhs, mode="eval")
    except SyntaxError as exc:
        raise ParseError(
            f"Syntax error in RHS {rhs!r}: {exc.msg}. "
            "Expected an infix expression like '-k*x + c*vx'."
        ) from exc
    except (ValueError, RecursionError, MemoryError) as exc:
        # Non-SyntaxError parser failures must still surface as ParseError.
        raise ParseError(
            f"RHS could not be parsed ({type(exc).__name__}). "
            "Expected an infix expression like '-k*x + c*vx'."
        ) from exc
    try:
        return _ast_to_sympy(tree.body, allowed_symbols, state_variables)
    except RecursionError as exc:
        # The walker recurses once per AST level; very long operator chains
        # can exceed the interpreter's recursion limit.
        raise ParseError(
            "RHS expression is nested too deeply to process; "
            "simplify or flatten it."
        ) from exc
# Upper bound on the bit length of an exactly evaluated numeric power. SymPy
# evaluates powers of exact rationals exactly, so untrusted input such as
# ``10**10**10`` would otherwise try to build an integer with ~10**10 digits
# and stall the calling process. Real ODE coefficients are far below this.
_MAX_EXACT_POWER_BITS = 100_000


def _check_power_size(base: sp.Expr, exponent: sp.Expr) -> None:
    """Raise ParseError if ``base ** exponent`` would be a huge exact number.

    Only exact numeric operands (``sp.Rational``, which includes
    ``sp.Integer``) are checked; other operand kinds are left to SymPy.
    """
    if not (isinstance(base, sp.Rational) and isinstance(exponent, sp.Rational)):
        return
    if base.q == 1 and abs(base.p) <= 1:
        return  # 0, 1 and -1 raised to any power stay tiny.
    base_bits = max(abs(int(base.p)).bit_length(), int(base.q).bit_length())
    whole_exponent = abs(int(exponent.p)) // int(exponent.q)
    if base_bits * whole_exponent > _MAX_EXACT_POWER_BITS:
        # Deliberately do not format the operands: str() of a huge integer
        # can itself fail on Python's int-to-str digit limit.
        raise ParseError(
            "Numeric power is too large to evaluate exactly; use a smaller "
            "literal exponent or move the constant into a named parameter."
        )


def _ast_to_sympy(
    node: ast.AST,
    allowed_symbols: frozenset[str],
    state_variables: tuple[str, ...],
) -> sp.Expr:
    """Recursively convert a whitelisted AST node into a SymPy expression.

    Accepted nodes: numeric Constant, Name (via ``_name_to_sympy``), unary
    ``+``/``-``, binary ``+ - * / **`` (via ``_BIN_OP_TO_SYMPY``) and Call
    (via ``_call_to_sympy``). Everything else raises ParseError, with a
    tailored message for constructs agents commonly emit.

    Raises:
        ParseError: for a non-whitelisted construct, a non-numeric or boolean
            literal, or an exact numeric power too large to evaluate.
    """
    if isinstance(node, ast.Constant):
        # bool is an int subclass, so True/False must be excluded explicitly.
        if isinstance(node.value, bool) or not isinstance(node.value, (int, float)):
            raise ParseError(
                f"Only numeric literals allowed on RHS; got "
                f"{node.value!r} ({type(node.value).__name__})."
            )
        return sp.Number(node.value)
    if isinstance(node, ast.Name):
        return _name_to_sympy(node.id, allowed_symbols, state_variables)
    if isinstance(node, ast.UnaryOp):
        operand = _ast_to_sympy(node.operand, allowed_symbols, state_variables)
        if isinstance(node.op, ast.UAdd):
            return +operand
        if isinstance(node.op, ast.USub):
            return -operand
        raise ParseError(
            f"Unsupported unary operator {type(node.op).__name__}. "
            "Allowed: + (positive), - (negation)."
        )
    if isinstance(node, ast.BinOp):
        op_fn = _BIN_OP_TO_SYMPY.get(type(node.op))
        if op_fn is None:
            raise ParseError(
                f"Unsupported binary operator {type(node.op).__name__}. "
                "Allowed: + - * / ** (also '^' as a power synonym)."
            )
        left = _ast_to_sympy(node.left, allowed_symbols, state_variables)
        right = _ast_to_sympy(node.right, allowed_symbols, state_variables)
        if isinstance(node.op, ast.Pow):
            _check_power_size(left, right)
        return op_fn(left, right)
    if isinstance(node, ast.Call):
        return _call_to_sympy(node, allowed_symbols, state_variables)
    if isinstance(node, ast.Attribute):
        raise ParseError(
            "Attribute access is not allowed in equation RHS "
            f"(saw '.{node.attr}'). Use bare function names like "
            "'sqrt(x)' or 'sin(theta)', not 'np.sqrt(x)'."
        )
    if isinstance(node, ast.Subscript):
        raise ParseError(
            "Array indexing is not allowed in equation RHS. "
            "Use named scalars declared in 'params'."
        )
    if isinstance(node, ast.Compare):
        raise ParseError(
            "Comparisons (==, <, >, etc.) are not allowed in equation RHS."
        )
    if isinstance(node, ast.BoolOp):
        raise ParseError(
            "Boolean operators ('and', 'or') are not allowed in equation RHS."
        )
    if isinstance(node, ast.IfExp):
        raise ParseError(
            "Conditional expressions ('a if cond else b') are not allowed in "
            "equation RHS."
        )
    if isinstance(node, ast.Lambda):
        raise ParseError("Lambda expressions are not allowed in equation RHS.")
    if isinstance(node, (ast.Tuple, ast.List, ast.Set, ast.Dict)):
        raise ParseError(
            f"Collection literal ({type(node).__name__}) is not allowed in "
            "equation RHS."
        )
    raise ParseError(
        f"Unsupported expression construct {type(node).__name__}. "
        "The grammar accepts: numeric literals, allowed identifiers, "
        f"+ - * / **, parentheses, and {sorted(ALLOWED_FUNCTIONS)}."
    )
def _name_to_sympy(
    name: str,
    allowed_symbols: frozenset[str],
    state_variables: tuple[str, ...],
) -> sp.Symbol:
    """Map a bare identifier to a SymPy Symbol after whitelist checks.

    Raises:
        ParseError: if the identifier is a whitelisted function used without
            a call, or is not an allowed symbol (with a targeted hint
            appended when one applies).
    """
    if name in ALLOWED_FUNCTIONS:
        raise ParseError(
            f"{name!r} is a function and must be called with parentheses, "
            f"e.g. {name}(x)."
        )
    if name in allowed_symbols:
        return sp.Symbol(name)
    explanation = _explain_unknown_symbol(name, state_variables)
    raise ParseError(
        f"Unknown symbol {name!r}; allowed {sorted(allowed_symbols)!r}."
        + (f" {explanation}" if explanation else "")
    )
# Accepted positional-argument counts (inclusive min, max) per function.
# Functions not listed take exactly one argument; ``log`` also accepts an
# explicit base, ``log(x, b)``.
_FUNCTION_ARITY: dict[str, tuple[int, int]] = {"log": (1, 2)}


def _call_to_sympy(
    node: ast.Call,
    allowed_symbols: frozenset[str],
    state_variables: tuple[str, ...],
) -> sp.Expr:
    """Convert a call to a whitelisted function into a SymPy expression.

    The argument count is checked up front. Without that check SymPy would
    raise a TypeError for ``sin(x, y)`` or ``sin()``, which is not a
    ParseError. It would also silently misread the extra argument in
    ``sqrt(x, y)``, because ``sqrt``'s second positional parameter is
    SymPy's ``evaluate`` flag.

    Raises:
        ParseError: for keyword or star arguments, attribute or computed
            callees, functions outside ALLOWED_FUNCTIONS, or a wrong number of
            arguments.
    """
    if node.keywords:
        raise ParseError(
            "Keyword arguments are not allowed in function calls "
            "(e.g. 'sin(theta=0.1)'). Pass positional arguments only."
        )
    for arg in node.args:
        if isinstance(arg, ast.Starred):
            raise ParseError("Star-arg / unpacking ('*args') is not allowed.")
    if isinstance(node.func, ast.Attribute):
        raise ParseError(
            "Attribute access is not allowed in equation RHS "
            f"(saw '.{node.func.attr}'). Use bare function names like "
            "'sqrt(x)' or 'sin(theta)', not 'np.sqrt(x)'."
        )
    if not isinstance(node.func, ast.Name):
        raise ParseError(
            "Only direct calls to named functions are allowed. "
            f"Use one of {sorted(ALLOWED_FUNCTIONS)}, not a computed-name call."
        )
    func_name = node.func.id
    if func_name not in ALLOWED_FUNCTIONS:
        raise ParseError(
            f"Unknown function {func_name!r}; "
            f"allowed: {sorted(ALLOWED_FUNCTIONS)}."
        )
    min_args, max_args = _FUNCTION_ARITY.get(func_name, (1, 1))
    n_args = len(node.args)
    if not min_args <= n_args <= max_args:
        expected = (
            str(min_args) if min_args == max_args else f"{min_args} or {max_args}"
        )
        raise ParseError(
            f"Function {func_name!r} takes {expected} argument(s); got {n_args}."
        )
    args = [_ast_to_sympy(a, allowed_symbols, state_variables) for a in node.args]
    return ALLOWED_FUNCTIONS[func_name](*args)
def _apply_velocity_alias(rhs: str, state_variables: tuple[str, ...]) -> str:
    """Rewrite derivative spellings of a position variable to its velocity name.

    For each ``(x, vx)`` pair reported by ``_velocity_aliases``, first
    ``dx/dt`` (optional whitespace around ``/``) and then any remaining bare
    ``dx`` token are replaced with ``vx``. Without pairs the text is returned
    unchanged.
    """
    rewritten = rhs
    for position, velocity in _velocity_aliases(state_variables):
        escaped = re.escape(position)
        for pattern in (rf"\bd{escaped}\s*/\s*dt\b", rf"\bd{escaped}\b"):
            rewritten = re.sub(pattern, velocity, rewritten)
    return rewritten
| def _velocity_aliases(state_variables: tuple[str, ...]) -> list[tuple[str, str]]: | |
| state_set = set(state_variables) | |
| out: list[tuple[str, str]] = [] | |
| for var in state_variables: | |
| if not var or var.startswith(("d", "v")): | |
| continue | |
| velocity = f"v{var}" | |
| if velocity in state_set: | |
| out.append((var, velocity)) | |
| return out | |
| def _explain_unknown_symbol(name: str, state_variables: tuple[str, ...]) -> str: | |
| state_set = set(state_variables) | |
| if name == "t": | |
| return ( | |
| "'t' is not allowed — the equation must be autonomous " | |
| "(express forces via state variables only, no explicit time)." | |
| ) | |
| if name.startswith("d") and len(name) > 1: | |
| base = name[1:] | |
| velocity = f"v{base}" | |
| if velocity in state_set: | |
| return ( | |
| f"Did you mean '{velocity}'? " | |
| f"Use '{velocity}' for the velocity of '{base}'." | |
| ) | |
| if base in state_set: | |
| return ( | |
| f"'{name}' looks like a derivative; this system has no " | |
| f"separate velocity name, write '{base}' on the RHS." | |
| ) | |
| return "" | |
def _count_operators(expr: sp.Expr) -> int:
    """Count the nodes of ``expr`` that are neither Symbols nor Numbers.

    Each node visited by SymPy's pre-order traversal (including the root)
    contributes one unless it is an ``sp.Symbol`` or ``sp.Number``.
    """
    return sum(
        1
        for node in sp.preorder_traversal(expr)
        if not isinstance(node, (sp.Symbol, sp.Number))
    )