"""Fast simulation executor for development.
Static-analysis-based execution simulator. Sub-100ms per call. No Docker
required. The success probability of a simulated run depends on whether
the script contains expected HF training markers (model imports, training
calls, save calls). When the simulation succeeds, a synthetic decreasing
loss curve is emitted; when it fails, a representative HF error is raised.
"""
from __future__ import annotations

import random
import re
import time
from typing import Optional

from forgeenv.sandbox.ast_validator import validate_script
from forgeenv.tasks.models import ExecutionResult, Task
class SimulationExecutor:
    """Simulates script execution via static analysis.

    Use this throughout development phases. Real Docker execution is added
    later for grounded final-stage verification.

    Success probability is scored from HF training markers found in the
    script text; a success emits a synthetic decreasing loss curve, a
    failure emits a representative HF error message.
    """

    # Substrings marking obviously broken API usage (deprecated modules,
    # nonexistent methods/kwargs). Any hit forces the run to fail even
    # though the script compiles — the simulator pretends to be a static
    # linter catching ImportError/AttributeError signatures early.
    _BROKEN_MARKERS = (
        "_DEPRECATED(",
        "transformers.legacy",
        "from transformers.training import",
        ".start_training(",
        "load_from_hub(",
        "save_to_hub(",
        "pad_to_max_length=",
        "evaluation_loop(",
    )

    # Drift patterns, compiled once at class definition instead of on every
    # execute() call: dataset columns that don't exist in real HF datasets
    # ('input_text', 'words') and the invalid tokenizer kwarg `truncate=`
    # (the real kwarg is `truncation=`). A hit caps success at 5%.
    _DRIFT_PATTERNS = (
        re.compile(r"['\"]input_text['\"]\s*[]:),]"),
        re.compile(r"['\"]words['\"]\s*[]:),]"),
        re.compile(r"\btruncate\s*="),
    )

    # Representative HF runtime errors; one is chosen at random on failure.
    # Order matters for seeded reproducibility — do not reorder.
    _ERROR_TYPES = [
        "ImportError: cannot import name 'OldTrainer' from 'transformers'",
        "AttributeError: 'Trainer' object has no attribute 'evaluate_model'",
        "KeyError: 'text' column not found in dataset",
        "TypeError: __init__() got an unexpected keyword argument 'num_epochs'",
        "RuntimeError: Expected input batch_size (16) to match target batch_size (32)",
        "ModuleNotFoundError: No module named 'transformers.legacy'",
    ]

    def __init__(self, seed: Optional[int] = None) -> None:
        """Create an executor; pass *seed* for reproducible simulations.

        With no seed the module-level RNG is used (shared global state),
        matching the original behavior.
        """
        self._rng = random.Random(seed) if seed is not None else random

    def execute(
        self, script_content: str, task: Optional[Task] = None
    ) -> ExecutionResult:
        """Simulate executing *script_content* and return an ExecutionResult.

        *task* is accepted for interface parity with real executors but is
        not consulted by the simulation.
        """
        start = time.time()

        # Hard static failures (validator violations, syntax errors) short-
        # circuit before any probabilistic scoring.
        failure = self._static_failure(script_content, start)
        if failure is not None:
            return failure

        success_prob = self._success_probability(script_content)
        if self._rng.random() < success_prob:
            return self._successful_result(script_content, start)
        return self._failed_result(script_content, start)

    def _static_failure(
        self, script_content: str, start: float
    ) -> Optional[ExecutionResult]:
        """Return a failing result for validation/syntax errors, else None."""
        validation = validate_script(script_content)
        if not validation.is_valid:
            return ExecutionResult(
                exit_code=1,
                stdout="",
                stderr=f"Validation failed: {'; '.join(validation.violations)}",
                wall_time_ms=int((time.time() - start) * 1000),
                script_content=script_content,
            )
        try:
            compile(script_content, "<forge_script>", "exec")
        except SyntaxError as e:
            return ExecutionResult(
                exit_code=1,
                stdout="",
                stderr=f"SyntaxError: {e}",
                wall_time_ms=int((time.time() - start) * 1000),
                script_content=script_content,
            )
        return None

    def _success_probability(self, script_content: str) -> float:
        """Score [0, 1] success probability from HF training markers."""
        has_model_import = any(
            kw in script_content
            for kw in ("from transformers", "import torch", "from datasets")
        )
        has_training_call = any(
            kw in script_content
            for kw in ("trainer.train()", ".fit(", "train_loop", "for epoch")
        )
        has_save = any(
            kw in script_content
            for kw in ("save_pretrained", "save_model", "torch.save")
        )

        # Base 30%, plus credit for each expected marker class.
        prob = 0.3
        if has_model_import:
            prob += 0.3
        if has_training_call:
            prob += 0.2
        if has_save:
            prob += 0.1

        if any(marker in script_content for marker in self._BROKEN_MARKERS):
            prob = 0.0
        for pattern in self._DRIFT_PATTERNS:
            if pattern.search(script_content):
                prob = min(prob, 0.05)
        return prob

    def _successful_result(
        self, script_content: str, start: float
    ) -> ExecutionResult:
        """Build a passing result with a synthetic decreasing loss curve."""
        steps = self._rng.randint(20, 50)
        log_lines: list[str] = []
        loss = self._rng.uniform(2.0, 4.0)
        for step in range(1, steps + 1):
            # Multiplicative decay keeps the curve monotonically decreasing.
            loss *= self._rng.uniform(0.92, 0.99)
            log_lines.append(f"step={step} loss={loss:.4f}")
        log_lines.append("eval_accuracy=0.78")
        log_lines.append("TRAINING_COMPLETE")
        return ExecutionResult(
            exit_code=0,
            stdout="\n".join(log_lines),
            stderr="",
            # Pad wall time so a simulated run resembles a real one.
            wall_time_ms=int((time.time() - start) * 1000)
            + self._rng.randint(1000, 5000),
            checkpoint_exists=True,
            peak_memory_mb=self._rng.uniform(500, 2000),
            script_content=script_content,
        )

    def _failed_result(
        self, script_content: str, start: float
    ) -> ExecutionResult:
        """Build a failing result with a representative HF error message."""
        return ExecutionResult(
            exit_code=1,
            stdout="",
            stderr=self._rng.choice(self._ERROR_TYPES),
            wall_time_ms=int((time.time() - start) * 1000)
            + self._rng.randint(100, 500),
            script_content=script_content,
        )