File size: 5,442 Bytes
a15535e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""Fast simulation executor for development.

Static-analysis-based execution simulator. Sub-100ms per call. No Docker
required. The success probability of a simulated run depends on whether
the script contains expected HF training markers (model imports, training
calls, save calls). When the simulation succeeds, a synthetic decreasing
loss curve is emitted; when it fails, a representative HF error is raised.
"""
from __future__ import annotations

import random
import re
import time
from typing import Optional

from forgeenv.sandbox.ast_validator import validate_script
from forgeenv.tasks.models import ExecutionResult, Task


class SimulationExecutor:
    """Simulates script execution via static analysis.

    Use this throughout development phases. Real Docker execution is added
    later for grounded final-stage verification.

    The simulated success probability is derived from substring/regex
    heuristics over the script text; a successful run emits a synthetic
    decreasing loss curve, a failed run emits a representative HF error.
    """

    # Marker groups that RAISE the simulated success probability when any
    # member appears in the script: model/data imports, a training call,
    # and a checkpoint-save call.
    _MODEL_IMPORT_MARKERS = ("from transformers", "import torch", "from datasets")
    _TRAINING_CALL_MARKERS = ("trainer.train()", ".fit(", "train_loop", "for epoch")
    _SAVE_MARKERS = ("save_pretrained", "save_model", "torch.save")

    # Obviously broken patterns that pass syntactic compilation but would
    # raise AttributeError / ImportError at runtime. The simulator acts as
    # a static linter here: any hit forces a definite failure.
    _BROKEN_MARKERS = (
        "_DEPRECATED(",
        "transformers.legacy",
        "from transformers.training import",
        ".start_training(",
        "load_from_hub(",
        "save_to_hub(",
        "pad_to_max_length=",
        "evaluation_loop(",
    )

    # Compiled once at class-definition time (the original re-imported `re`
    # and recompiled these on every execute() call).
    # Dataset column drift: quoted column names that don't appear in real
    # HF datasets, followed by a closing bracket, colon, paren, or comma.
    _COLUMN_DRIFT_PATTERNS = (
        re.compile(r"['\"]input_text['\"]\s*[]:),]"),
        re.compile(r"['\"]words['\"]\s*[]:),]"),
    )
    # Tokenizer kwarg drift: `truncate=` is not a valid kwarg (`truncation=` is).
    _TRUNCATE_DRIFT_PATTERN = re.compile(r"\btruncate\s*=")

    # Representative HF runtime errors emitted on a simulated failure.
    _ERROR_TYPES = (
        "ImportError: cannot import name 'OldTrainer' from 'transformers'",
        "AttributeError: 'Trainer' object has no attribute 'evaluate_model'",
        "KeyError: 'text' column not found in dataset",
        "TypeError: __init__() got an unexpected keyword argument 'num_epochs'",
        "RuntimeError: Expected input batch_size (16) to match target batch_size (32)",
        "ModuleNotFoundError: No module named 'transformers.legacy'",
    )

    def __init__(self, seed: Optional[int] = None) -> None:
        """Create an executor.

        Args:
            seed: When given, runs are deterministic via a private
                ``random.Random``; otherwise the module-level RNG is shared.
        """
        self._rng = random.Random(seed) if seed is not None else random

    def execute(
        self, script_content: str, task: Optional[Task] = None
    ) -> ExecutionResult:
        """Simulate running *script_content* and return an ExecutionResult.

        Pipeline: AST-policy validation -> syntax compile check ->
        heuristic success-probability estimate -> random success/failure
        outcome. Never raises for bad scripts; every path returns an
        ExecutionResult (exit_code 0 on simulated success, 1 otherwise).

        Args:
            script_content: Python source of the candidate training script.
            task: Currently unused; accepted for interface compatibility.
        """
        start = time.time()

        validation = validate_script(script_content)
        if not validation.is_valid:
            return self._failure_result(
                start,
                script_content,
                f"Validation failed: {'; '.join(validation.violations)}",
            )

        try:
            compile(script_content, "<forge_script>", "exec")
        except SyntaxError as e:
            return self._failure_result(start, script_content, f"SyntaxError: {e}")
        except ValueError as e:
            # compile() raises ValueError (not SyntaxError) for e.g. null
            # bytes in the source on some Python versions; the original
            # let this escape and crash the simulator.
            return self._failure_result(start, script_content, f"ValueError: {e}")

        success_prob = self._estimate_success_prob(script_content)

        if self._rng.random() < success_prob:
            return self._success_result(start, script_content)

        # NOTE: choice() is drawn before randint() inside _failure_result's
        # argument list, matching the original RNG call order so seeded
        # runs reproduce the original outcomes.
        return self._failure_result(
            start,
            script_content,
            self._rng.choice(list(self._ERROR_TYPES)),
            extra_ms=self._rng.randint(100, 500),
        )

    def _estimate_success_prob(self, script_content: str) -> float:
        """Score the script's simulated chance of success in [0, 1].

        Base 0.3, plus bonuses for HF training markers; known-broken API
        usage zeroes the score, and API/column drift caps it at 0.05.
        """
        prob = 0.3
        if any(m in script_content for m in self._MODEL_IMPORT_MARKERS):
            prob += 0.3
        if any(m in script_content for m in self._TRAINING_CALL_MARKERS):
            prob += 0.2
        if any(m in script_content for m in self._SAVE_MARKERS):
            prob += 0.1

        if any(m in script_content for m in self._BROKEN_MARKERS):
            prob = 0.0
        # min() (not assignment) so a broken-marker zero is never raised
        # back up by a drift hit.
        if any(p.search(script_content) for p in self._COLUMN_DRIFT_PATTERNS):
            prob = min(prob, 0.05)
        if self._TRUNCATE_DRIFT_PATTERN.search(script_content):
            prob = min(prob, 0.05)
        return prob

    def _success_result(self, start: float, script_content: str) -> ExecutionResult:
        """Build a success ExecutionResult with a synthetic decreasing loss curve."""
        steps = self._rng.randint(20, 50)
        log_lines: list[str] = []
        loss = self._rng.uniform(2.0, 4.0)
        for step in range(1, steps + 1):
            # Multiplicative decay keeps the curve monotonically decreasing.
            loss *= self._rng.uniform(0.92, 0.99)
            log_lines.append(f"step={step} loss={loss:.4f}")
        log_lines.append("eval_accuracy=0.78")
        log_lines.append("TRAINING_COMPLETE")

        return ExecutionResult(
            exit_code=0,
            stdout="\n".join(log_lines),
            stderr="",
            # Pad wall time so simulated runs look like real (slow) training.
            wall_time_ms=int((time.time() - start) * 1000)
            + self._rng.randint(1000, 5000),
            checkpoint_exists=True,
            peak_memory_mb=self._rng.uniform(500, 2000),
            script_content=script_content,
        )

    def _failure_result(
        self,
        start: float,
        script_content: str,
        stderr: str,
        extra_ms: int = 0,
    ) -> ExecutionResult:
        """Build a failure ExecutionResult (exit_code 1, empty stdout).

        Args:
            start: ``time.time()`` captured at the top of execute().
            script_content: Echoed back into the result for the caller.
            stderr: The error text to report.
            extra_ms: Synthetic padding added to the measured wall time.
        """
        return ExecutionResult(
            exit_code=1,
            stdout="",
            stderr=stderr,
            wall_time_ms=int((time.time() - start) * 1000) + extra_ms,
            script_content=script_content,
        )