"""Fast simulation executor for development.

Static-analysis-based execution simulator. Sub-100ms per call. No Docker
required.

The success probability of a simulated run depends on whether the script
contains expected HF training markers (model imports, training calls, save
calls). When the simulation succeeds, a synthetic decreasing loss curve is
emitted; when it fails, a representative HF error message is reported via a
non-zero exit code and ``stderr`` (no exception is actually raised).
"""

from __future__ import annotations

import random
import re
import time
from typing import Optional

from forgeenv.sandbox.ast_validator import validate_script
from forgeenv.tasks.models import ExecutionResult, Task

# Substrings suggesting the script imports an ML / HF library.
_MODEL_IMPORT_MARKERS = ("from transformers", "import torch", "from datasets")
# Substrings suggesting the script actually runs a training loop.
_TRAINING_CALL_MARKERS = ("trainer.train()", ".fit(", "train_loop", "for epoch")
# Substrings suggesting the script persists the trained model.
_SAVE_MARKERS = ("save_pretrained", "save_model", "torch.save")

# Obviously broken patterns are definite failures even when they pass
# syntactic compilation. The simulator pretends to be a static linter that
# catches AttributeError / ImportError signatures before they would fire at
# runtime.
_BROKEN_MARKERS = (
    "_DEPRECATED(",
    "transformers.legacy",
    "from transformers.training import",
    ".start_training(",
    "load_from_hub(",
    "save_to_hub(",
    "pad_to_max_length=",
    "evaluation_loop(",
)

# Patterns that cap the success probability at _DRIFT_SUCCESS_CAP:
#   * dataset column drift -- a quoted 'input_text' / 'words' column name
#     followed by ], :, ) or , (renamed columns that don't appear in real HF
#     datasets). A ']' placed first inside a character class is a literal.
#   * tokenizer kwarg drift -- ``truncate=`` is not valid; ``truncation=`` is.
_DRIFT_PATTERNS = (
    re.compile(r"['\"]input_text['\"]\s*[]:),]"),
    re.compile(r"['\"]words['\"]\s*[]:),]"),
    re.compile(r"\btruncate\s*="),
)
_DRIFT_SUCCESS_CAP = 0.05

# Representative HF failure messages; one is drawn at random for a failed run.
# NOTE: the drawn message is not tied to whichever pattern lowered the success
# probability.
_ERROR_MESSAGES = (
    "ImportError: cannot import name 'OldTrainer' from 'transformers'",
    "AttributeError: 'Trainer' object has no attribute 'evaluate_model'",
    "KeyError: 'text' column not found in dataset",
    "TypeError: __init__() got an unexpected keyword argument 'num_epochs'",
    "RuntimeError: Expected input batch_size (16) to match target batch_size (32)",
    "ModuleNotFoundError: No module named 'transformers.legacy'",
)


def _elapsed_ms(start: float) -> int:
    """Return whole milliseconds elapsed since *start* (a ``perf_counter`` value)."""
    return int((time.perf_counter() - start) * 1000)


class SimulationExecutor:
    """Simulates script execution via static analysis.

    Use this throughout development phases. Real Docker execution is added
    later for grounded final-stage verification.
    """

    def __init__(self, seed: Optional[int] = None) -> None:
        # With a seed, use a private reproducible generator. Without one, fall
        # back to the module-level ``random`` functions, which share global
        # state (so ``random.seed`` elsewhere also affects this executor).
        self._rng = random.Random(seed) if seed is not None else random

    def execute(
        self, script_content: str, task: Optional[Task] = None
    ) -> ExecutionResult:
        """Simulate running *script_content* and return a synthetic result.

        Args:
            script_content: Source code of the script to "run".
            task: Accepted for interface compatibility; not used here.

        Returns:
            An ``ExecutionResult``. If ``validate_script`` rejects the script
            or it fails to compile, ``exit_code`` is 1 and ``stderr`` holds
            the reason. Otherwise a random draw against a heuristic success
            probability picks either a synthetic success (``exit_code`` 0,
            loss log ending in ``TRAINING_COMPLETE``, ``checkpoint_exists``)
            or a synthetic failure (``exit_code`` 1, one representative error
            message in ``stderr``).
        """
        # perf_counter is monotonic: the measured duration cannot go negative
        # if the system clock is adjusted mid-call (time.time can).
        start = time.perf_counter()

        rejection = self._precheck(script_content)
        if rejection is not None:
            return ExecutionResult(
                exit_code=1,
                stdout="",
                stderr=rejection,
                wall_time_ms=_elapsed_ms(start),
                script_content=script_content,
            )

        success_prob = self._success_probability(script_content)
        if self._rng.random() < success_prob:
            return self._simulate_success(script_content, start)
        return self._simulate_failure(script_content, start)

    @staticmethod
    def _precheck(script_content: str) -> Optional[str]:
        """Return an ``stderr`` message if the script is rejected up front, else None."""
        validation = validate_script(script_content)
        if not validation.is_valid:
            return f"Validation failed: {'; '.join(validation.violations)}"
        try:
            compile(script_content, "", "exec")
        except SyntaxError as exc:
            return f"SyntaxError: {exc}"
        except ValueError as exc:
            # On Python < 3.12, compile() reports null bytes in the source
            # with ValueError instead of SyntaxError; report it as a rejected
            # run rather than letting it escape execute().
            return f"ValueError: {exc}"
        return None

    @staticmethod
    def _success_probability(script_content: str) -> float:
        """Score *script_content* by its HF training markers (range 0.0-0.9)."""
        success_prob = 0.3
        if any(kw in script_content for kw in _MODEL_IMPORT_MARKERS):
            success_prob += 0.3
        if any(kw in script_content for kw in _TRAINING_CALL_MARKERS):
            success_prob += 0.2
        if any(kw in script_content for kw in _SAVE_MARKERS):
            success_prob += 0.1

        if any(marker in script_content for marker in _BROKEN_MARKERS):
            success_prob = 0.0
        if any(pattern.search(script_content) for pattern in _DRIFT_PATTERNS):
            success_prob = min(success_prob, _DRIFT_SUCCESS_CAP)
        return success_prob

    def _simulate_success(self, script_content: str, start: float) -> ExecutionResult:
        """Build a successful result with a synthetic, decreasing loss curve."""
        steps = self._rng.randint(20, 50)
        loss = self._rng.uniform(2.0, 4.0)
        log_lines: list[str] = []
        for step in range(1, steps + 1):
            # Each step shrinks the loss by 1-8%, giving a monotone decrease.
            loss *= self._rng.uniform(0.92, 0.99)
            log_lines.append(f"step={step} loss={loss:.4f}")
        log_lines.append("eval_accuracy=0.78")
        log_lines.append("TRAINING_COMPLETE")

        # Keyword arguments are evaluated in order; keeping wall_time_ms
        # before peak_memory_mb preserves the RNG draw sequence.
        return ExecutionResult(
            exit_code=0,
            stdout="\n".join(log_lines),
            stderr="",
            # Add 1-5 s of pretend training time on top of the real elapsed time.
            wall_time_ms=_elapsed_ms(start) + self._rng.randint(1000, 5000),
            checkpoint_exists=True,
            peak_memory_mb=self._rng.uniform(500, 2000),
            script_content=script_content,
        )

    def _simulate_failure(self, script_content: str, start: float) -> ExecutionResult:
        """Build a failed result carrying one representative HF error message."""
        # Draw the message before the wall time to keep the RNG sequence stable.
        message = self._rng.choice(_ERROR_MESSAGES)
        return ExecutionResult(
            exit_code=1,
            stdout="",
            stderr=message,
            # Add 100-500 ms of pretend runtime before the simulated crash.
            wall_time_ms=_elapsed_ms(start) + self._rng.randint(100, 500),
            script_content=script_content,
        )