File size: 3,318 Bytes
8097081
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1435892
8097081
 
 
 
 
1435892
8097081
 
 
1435892
8097081
72a7241
 
 
 
8097081
 
 
 
 
 
 
1435892
8097081
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1435892
8097081
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# src/pytorch_debug_env/scenario_generator.py
from __future__ import annotations

import random
import uuid
from dataclasses import dataclass
from typing import Dict, List

import numpy as np

from .bug_library import BugTemplate


@dataclass
class Scenario:
    """Container for one generated debugging scenario.

    ``ScenarioGenerator.generate`` fills every field; the four artifact
    fields hold whatever the chosen bug template's ``artifact_generator``
    returns for the artifact name matching the field.
    """

    # Short identifier: the first 8 characters of a UUID4 string.
    scenario_id: str
    # The difficulty string the scenario was generated for.
    task_id: str
    # Repo snapshot as relative file path -> file contents, as returned by
    # the template's ``repo_mutator``.
    repo_files: Dict[str, str]
    # Result of artifact_generator("loss_curve", rng); the per-record schema
    # is decided by the template and is not visible here.
    loss_curve: List[Dict]
    # Result of artifact_generator("gpu_profile", rng); schema likewise
    # template-defined.
    gpu_profile: List[Dict]
    # Result of artifact_generator("training_log", rng).
    training_log: str
    # Result of artifact_generator("diagnostic_report", rng).
    diagnostic_report: str
    # Answer key taken from the template, with keys: bug_type, category,
    # primary_bug_file, related_files, red_herring_file, fix_strategy,
    # line_range.
    ground_truth: Dict


class ScenarioGenerator:
    """Assemble debugging scenarios from a pool of bug templates.

    A scenario is a small baseline repository that one template has mutated,
    together with that template's artifacts and its ground-truth record.
    """

    def __init__(self, bug_templates: List[BugTemplate]):
        """Remember the templates that ``generate`` picks from."""
        self.bug_templates = bug_templates

    def generate(self, difficulty: str, seed: int | None = None) -> Scenario:
        """Produce one scenario for the requested difficulty.

        A private ``random.Random(seed)`` drives the template choice and is
        handed to the repo mutator and to every artifact generator call. The
        scenario id comes from ``uuid.uuid4()`` and does not depend on the
        seed.

        Args:
            difficulty: Compared for equality with each template's
                ``difficulty`` attribute; also stored as ``task_id``.
            seed: Seed for the private random generator; ``None`` leaves it
                unseeded.

        Returns:
            The assembled ``Scenario``.

        Raises:
            ValueError: If no template has the requested difficulty.
        """
        rng = random.Random(seed)

        matching = [tpl for tpl in self.bug_templates if tpl.difficulty == difficulty]
        if not matching:
            raise ValueError(f"Unknown difficulty: {difficulty}")
        chosen = rng.choice(matching)

        # The template mutates a freshly built baseline repo.
        mutated_repo = chosen.repo_mutator(self._base_repo(rng), rng)

        # Every call receives the same rng, so the request order is kept fixed.
        artifact_names = (
            "loss_curve",
            "gpu_profile",
            "training_log",
            "diagnostic_report",
        )
        artifacts = {
            name: chosen.artifact_generator(name, rng) for name in artifact_names
        }

        # Each ground-truth key mirrors the template attribute of the same name.
        truth_attrs = (
            "bug_type",
            "category",
            "primary_bug_file",
            "related_files",
            "red_herring_file",
            "fix_strategy",
            "line_range",
        )
        answer_key = {attr: getattr(chosen, attr) for attr in truth_attrs}

        short_id = str(uuid.uuid4())[:8]
        return Scenario(
            scenario_id=short_id,
            task_id=difficulty,
            repo_files=mutated_repo,
            ground_truth=answer_key,
            **artifacts,
        )

    def _base_repo(self, rng: random.Random) -> Dict[str, str]:
        """Return the unmutated repo as a path -> file-contents mapping.

        ``rng`` is accepted but not used when building the baseline files.
        """
        builders = (
            ("train.py", self._train_py),
            ("model/architecture.py", self._model_py),
            ("model/attention.py", self._attention_py),
            ("data/dataset.py", self._dataset_py),
            ("data/preprocessing.py", self._preprocess_py),
            ("config/training_config.yaml", self._config_yaml),
        )
        return {path: build() for path, build in builders}

    def _train_py(self) -> str:
        """Return the baseline contents of ``train.py``."""
        return (
            "import torch\n"
            "from model.architecture import Net\n"
            "\n"
            "# training loop placeholder\n"
        )

    def _model_py(self) -> str:
        """Return the baseline contents of ``model/architecture.py``."""
        return (
            "import torch.nn as nn\n"
            "\n"
            "class Net(nn.Module):\n"
            "    def __init__(self):\n"
            "        super().__init__()\n"
        )

    def _attention_py(self) -> str:
        """Return the baseline contents of ``model/attention.py``."""
        return "# custom attention layer\n"

    def _dataset_py(self) -> str:
        """Return the baseline contents of ``data/dataset.py``."""
        return (
            "from torch.utils.data import Dataset\n"
            "\n"
            "class ImageDataset(Dataset):\n"
            "    pass\n"
        )

    def _preprocess_py(self) -> str:
        """Return the baseline contents of ``data/preprocessing.py``."""
        return (
            "def normalize(x):\n"
            "    return x\n"
        )

    def _config_yaml(self) -> str:
        """Return the baseline contents of ``config/training_config.yaml``."""
        return (
            "lr: 0.001\n"
            "batch_size: 32\n"
        )