"""AST-based script validator.

Catches forbidden imports and dangerous patterns BEFORE any execution
happens. This is a critical defense against reward hacking via system
calls, network access, or process manipulation.
"""
from __future__ import annotations

import ast

from forgeenv.tasks.models import ValidationResult


# Top-level module names that a validated script may not import.
# validate_script() compares only the first dotted component of an import,
# so listing "urllib" also rejects "urllib.request", "os" rejects "os.path".
# NOTE(review): this is a denylist -- any module not named here (for example
# ``importlib`` or ``sys``) passes validation. Confirm that is intended.
FORBIDDEN_MODULES = {
    "os",
    "subprocess",
    "socket",
    "urllib",
    "requests",
    "ctypes",
    "shutil",
    "signal",
    "multiprocessing",
    "threading",
}


# Callable names that may not be invoked, either as a bare name (``eval(...)``)
# or as the final attribute of a call (``obj.eval(...)``).
FORBIDDEN_FUNCTIONS = {"eval", "exec", "compile", "__import__"}


def validate_script(script_content: str) -> ValidationResult:
    """Parse a script as AST and reject forbidden patterns.

    The script is only parsed with ``ast.parse``; it is never executed here.
    Every node of the tree is inspected for:

    * ``import X`` / ``from X import ...`` where the first dotted component
      of ``X`` is in ``FORBIDDEN_MODULES``;
    * calls whose callee is a bare name or a trailing attribute listed in
      ``FORBIDDEN_FUNCTIONS``;
    * plain assignments whose target is the name ``__builtins__``.

    Args:
        script_content: Python source text to check.

    Returns:
        A ValidationResult with `is_valid` and a list of `violations`.
        `is_valid` is True only when no violation was recorded. If the source
        cannot be parsed, the result is invalid and carries a single
        violation describing the parse failure.
    """
    try:
        tree = ast.parse(script_content)
    except SyntaxError as e:
        return ValidationResult(is_valid=False, violations=[f"SyntaxError: {e}"])
    except (ValueError, RecursionError, MemoryError) as e:
        # The input is untrusted: NUL bytes (ValueError on some Python
        # versions) or pathologically nested source (RecursionError /
        # MemoryError) must produce a rejection, not crash the validator.
        return ValidationResult(
            is_valid=False,
            violations=[f"{type(e).__name__}: {e}"],
        )

    violations: list[str] = []

    # The node types below are mutually exclusive, hence a single elif chain.
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                # Only the top-level package decides: "os.path" -> "os".
                module_root = alias.name.split(".")[0]
                if module_root in FORBIDDEN_MODULES:
                    violations.append(f"Forbidden import: {alias.name}")

        elif isinstance(node, ast.ImportFrom):
            # node.module is None for purely relative imports
            # ("from . import x"); those are not checked.
            if node.module:
                module_root = node.module.split(".")[0]
                if module_root in FORBIDDEN_MODULES:
                    violations.append(f"Forbidden import from: {node.module}")

        elif isinstance(node, ast.Call):
            func = node.func
            if isinstance(func, ast.Name):
                if func.id in FORBIDDEN_FUNCTIONS:
                    violations.append(f"Forbidden call: {func.id}()")
            elif isinstance(func, ast.Attribute):
                # NOTE(review): the receiver is not inspected, so any method
                # with a forbidden name (e.g. ``re.compile(...)`` or
                # ``model.eval()``) is rejected as well.
                if func.attr in FORBIDDEN_FUNCTIONS:
                    violations.append(f"Forbidden call: .{func.attr}()")

        elif isinstance(node, ast.Assign):
            for target in node.targets:
                if isinstance(target, ast.Name) and target.id == "__builtins__":
                    violations.append("Forbidden: __builtins__ assignment")

    return ValidationResult(
        is_valid=not violations,
        violations=violations,
    )