| """AST-based script validator. | |
| Catches forbidden imports and dangerous patterns BEFORE any execution | |
| happens. This is a critical defense against reward hacking via system | |
| calls, network access, or process manipulation. | |
| """ | |
| from __future__ import annotations | |
| import ast | |
| from forgeenv.tasks.models import ValidationResult | |
# Top-level module names a validated script may not import.
# validate_script() compares only the root package (the text before the first
# dot), so "os" also covers "os.path" and "urllib" covers "urllib.request".
FORBIDDEN_MODULES = {
    # System / process control
    "os",
    "subprocess",
    "pty",  # pty.spawn() launches arbitrary processes
    "ctypes",
    "shutil",
    "signal",
    "multiprocessing",
    "threading",
    # Network access
    "socket",
    "urllib",
    "http",  # http.client / http.server are stdlib network entry points
    "requests",
    # Dynamic-import escape hatches: without these, a script can load any
    # module above via importlib.import_module("os") or reach eval/exec/
    # __import__ as attributes of the builtins module.
    "importlib",
    "builtins",
}
# Callable names rejected by validate_script(), whether invoked as a bare
# name (``eval(...)``) or as an attribute (``obj.eval(...)``).
FORBIDDEN_FUNCTIONS = {
    "__import__",  # dynamic import by name
    "compile",  # builds code objects from strings
    "eval",  # evaluates an expression from a string
    "exec",  # executes statements from a string
}
def validate_script(script_content: str) -> ValidationResult:
    """Statically check a script's AST for forbidden patterns.

    The script is only parsed, never executed. The following are reported:

    * ``import X`` / ``from X import ...`` where the root package of ``X``
      is in ``FORBIDDEN_MODULES``;
    * calls whose callee is a bare name or an attribute named in
      ``FORBIDDEN_FUNCTIONS``;
    * any binding of the name ``__builtins__`` (plain, tuple-unpacking,
      annotated or augmented assignment, walrus, ``for``/``with`` targets).

    Args:
        script_content: Source text of the script to validate.

    Returns:
        A ``ValidationResult`` whose ``is_valid`` is True only when no
        violation was found, and whose ``violations`` lists one message per
        offending node. A script that cannot be parsed yields a single
        violation describing the parse failure.
    """
    try:
        tree = ast.parse(script_content)
    except SyntaxError as e:
        return ValidationResult(is_valid=False, violations=[f"SyntaxError: {e}"])
    except (ValueError, RecursionError) as e:
        # ast.parse can reject hostile input without a SyntaxError (e.g. null
        # bytes on some Python versions, or pathologically nested
        # expressions). Treat that as an invalid script instead of letting
        # the validator itself crash.
        return ValidationResult(
            is_valid=False,
            violations=[f"Unparseable script ({type(e).__name__}): {e}"],
        )

    violations: list[str] = []
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                # Only the top-level package decides: "os.path" -> "os".
                module_root = alias.name.split(".")[0]
                if module_root in FORBIDDEN_MODULES:
                    violations.append(f"Forbidden import: {alias.name}")
        elif isinstance(node, ast.ImportFrom):
            # node.module is None for "from . import x"; nothing to check.
            if node.module:
                module_root = node.module.split(".")[0]
                if module_root in FORBIDDEN_MODULES:
                    violations.append(f"Forbidden import from: {node.module}")
        elif isinstance(node, ast.Call):
            callee = node.func
            if isinstance(callee, ast.Name):
                if callee.id in FORBIDDEN_FUNCTIONS:
                    violations.append(f"Forbidden call: {callee.id}()")
            elif isinstance(callee, ast.Attribute):
                # NOTE(review): this matches on the attribute name alone, so
                # harmless methods such as re.compile() or model.eval() are
                # rejected as well. Kept as-is (strict); confirm it is intended.
                if callee.attr in FORBIDDEN_FUNCTIONS:
                    violations.append(f"Forbidden call: .{callee.attr}()")
        elif isinstance(node, ast.Name):
            # A Store context covers every way of rebinding the name, not just
            # a bare ``__builtins__ = ...`` statement, so tuple unpacking,
            # annotated/augmented assignment and walrus cannot slip through.
            if node.id == "__builtins__" and isinstance(node.ctx, ast.Store):
                violations.append("Forbidden: __builtins__ assignment")

    return ValidationResult(
        is_valid=not violations,
        violations=violations,
    )