# forgeenv source snapshot for training job — commit b0fbec3 (verified), uploaded by akhiilll
"""Fast simulation executor for development.
Static-analysis-based execution simulator. Sub-100ms per call. No Docker
required. The success probability of a simulated run depends on whether
the script contains expected HF training markers (model imports, training
calls, save calls). When the simulation succeeds, a synthetic decreasing
loss curve is emitted; when it fails, a representative HF error is raised.
"""
from __future__ import annotations

import random
import re
import time
from typing import Optional

from forgeenv.sandbox.ast_validator import validate_script
from forgeenv.tasks.models import ExecutionResult, Task
class SimulationExecutor:
"""Simulates script execution via static analysis.
Use this throughout development phases. Real Docker execution is added
later for grounded final-stage verification.
"""
def __init__(self, seed: Optional[int] = None) -> None:
self._rng = random.Random(seed) if seed is not None else random
def execute(
self, script_content: str, task: Optional[Task] = None
) -> ExecutionResult:
start = time.time()
validation = validate_script(script_content)
if not validation.is_valid:
return ExecutionResult(
exit_code=1,
stdout="",
stderr=f"Validation failed: {'; '.join(validation.violations)}",
wall_time_ms=int((time.time() - start) * 1000),
script_content=script_content,
)
try:
compile(script_content, "<forge_script>", "exec")
except SyntaxError as e:
return ExecutionResult(
exit_code=1,
stdout="",
stderr=f"SyntaxError: {e}",
wall_time_ms=int((time.time() - start) * 1000),
script_content=script_content,
)
has_model_import = any(
kw in script_content
for kw in ("from transformers", "import torch", "from datasets")
)
has_training_call = any(
kw in script_content
for kw in ("trainer.train()", ".fit(", "train_loop", "for epoch")
)
has_save = any(
kw in script_content
for kw in ("save_pretrained", "save_model", "torch.save")
)
success_prob = 0.3
if has_model_import:
success_prob += 0.3
if has_training_call:
success_prob += 0.2
if has_save:
success_prob += 0.1
# Mark obviously broken patterns as definite failures even when
# they pass syntactic compilation. The simulator pretends to be a
# static linter that catches AttributeError / ImportError signatures
# before they would fire at runtime.
broken_markers = (
"_DEPRECATED(",
"transformers.legacy",
"from transformers.training import",
".start_training(",
"load_from_hub(",
"save_to_hub(",
"pad_to_max_length=",
"evaluation_loop(",
)
if any(marker in script_content for marker in broken_markers):
success_prob = 0.0
# Patterns that look like dataset column drift: a renamed column
# that doesn't appear in real HF datasets.
import re as _re
if _re.search(r"['\"]input_text['\"]\s*[]:),]", script_content):
success_prob = min(success_prob, 0.05)
if _re.search(r"['\"]words['\"]\s*[]:),]", script_content):
success_prob = min(success_prob, 0.05)
# Tokenizer kwarg drift (truncate is not valid; truncation is).
if _re.search(r"\btruncate\s*=", script_content):
success_prob = min(success_prob, 0.05)
succeeded = self._rng.random() < success_prob
if succeeded:
steps = self._rng.randint(20, 50)
log_lines: list[str] = []
loss = self._rng.uniform(2.0, 4.0)
for step in range(1, steps + 1):
loss *= self._rng.uniform(0.92, 0.99)
log_lines.append(f"step={step} loss={loss:.4f}")
log_lines.append("eval_accuracy=0.78")
log_lines.append("TRAINING_COMPLETE")
return ExecutionResult(
exit_code=0,
stdout="\n".join(log_lines),
stderr="",
wall_time_ms=int((time.time() - start) * 1000)
+ self._rng.randint(1000, 5000),
checkpoint_exists=True,
peak_memory_mb=self._rng.uniform(500, 2000),
script_content=script_content,
)
error_types = [
"ImportError: cannot import name 'OldTrainer' from 'transformers'",
"AttributeError: 'Trainer' object has no attribute 'evaluate_model'",
"KeyError: 'text' column not found in dataset",
"TypeError: __init__() got an unexpected keyword argument 'num_epochs'",
"RuntimeError: Expected input batch_size (16) to match target batch_size (32)",
"ModuleNotFoundError: No module named 'transformers.legacy'",
]
return ExecutionResult(
exit_code=1,
stdout="",
stderr=self._rng.choice(error_types),
wall_time_ms=int((time.time() - start) * 1000)
+ self._rng.randint(100, 500),
script_content=script_content,
)