# Openenv / src / pytorch_debug_env / scenario_generator.py
# Priyansh Saxena
# feat: expand scenarios and investigation actions
# 1435892
# src/pytorch_debug_env/scenario_generator.py
from __future__ import annotations
import random
import uuid
from dataclasses import dataclass
from typing import Dict, List
import numpy as np
from .bug_library import BugTemplate
@dataclass
class Scenario:
    """One generated debugging episode: a synthetic repo plus its artifacts.

    Instances are built by ``ScenarioGenerator.generate``. The field order is
    part of the positional constructor signature, so it must not be changed.
    """

    # --- identity ---------------------------------------------------------
    # Short identifier; the generator fills it with 8 hex chars of a UUID4.
    scenario_id: str
    # The generator stores the requested difficulty string here.
    task_id: str

    # --- repository snapshot ----------------------------------------------
    # Maps a relative file path to that file's full text.
    repo_files: Dict[str, str]

    # --- artifacts returned by the chosen template's artifact_generator ----
    # NOTE(review): the element schema of the two list artifacts is not
    # visible in this file -- confirm against the bug library.
    loss_curve: List[Dict]
    gpu_profile: List[Dict]
    training_log: str
    diagnostic_report: str

    # --- answer key ---------------------------------------------------------
    # Holds bug_type, category, primary_bug_file, related_files,
    # red_herring_file, fix_strategy and line_range copied from the template.
    ground_truth: Dict
class ScenarioGenerator:
    """Assemble debugging scenarios from a pool of bug templates.

    Each scenario pairs a small synthetic repository (mutated by the chosen
    template) with the artifacts that template generates and a ground-truth
    record copied from the template's attributes.
    """

    # Artifact kinds requested from the template, in call order. The order is
    # kept fixed because each call receives the shared RNG; reordering may
    # change seeded output.
    _ARTIFACT_KINDS = (
        "loss_curve",
        "gpu_profile",
        "training_log",
        "diagnostic_report",
    )

    # Template attributes copied verbatim into ``ground_truth``; every key in
    # the resulting dict has the same name as the attribute it was read from.
    _GROUND_TRUTH_FIELDS = (
        "bug_type",
        "category",
        "primary_bug_file",
        "related_files",
        "red_herring_file",
        "fix_strategy",
        "line_range",
    )

    def __init__(self, bug_templates: List[BugTemplate]):
        """Remember the templates that ``generate`` will sample from."""
        self.bug_templates = bug_templates

    def generate(self, difficulty: str, seed: int | None = None) -> Scenario:
        """Pick a template of the given difficulty and build a scenario.

        Args:
            difficulty: Compared for equality with each template's
                ``difficulty`` attribute; also stored as ``task_id``.
            seed: Seed for the ``random.Random`` instance that drives template
                choice, repo mutation and artifact generation.

        Returns:
            A populated ``Scenario``. ``scenario_id`` comes from ``uuid4`` and
            is therefore not controlled by ``seed``.

        Raises:
            ValueError: If no template has the requested difficulty.
        """
        rng = random.Random(seed)

        matching = [t for t in self.bug_templates if t.difficulty == difficulty]
        if not matching:
            raise ValueError(f"Unknown difficulty: {difficulty}")
        chosen = rng.choice(matching)

        # The template mutates (or replaces) the pristine repo to plant the bug.
        files = chosen.repo_mutator(self._base_repo(rng), rng)

        artifacts = {
            kind: chosen.artifact_generator(kind, rng)
            for kind in self._ARTIFACT_KINDS
        }
        truth = self._ground_truth(chosen)

        return Scenario(
            scenario_id=uuid.uuid4().hex[:8],
            task_id=difficulty,
            repo_files=files,
            ground_truth=truth,
            **artifacts,
        )

    @classmethod
    def _ground_truth(cls, template) -> Dict:
        """Copy the answer-key attributes of ``template`` into a fresh dict."""
        return {name: getattr(template, name) for name in cls._GROUND_TRUTH_FIELDS}

    def _base_repo(self, rng: random.Random) -> Dict[str, str]:
        """Return the unmodified repository as a path -> contents mapping.

        ``rng`` is accepted for interface compatibility but is not used here.
        """
        builders = (
            ("train.py", self._train_py),
            ("model/architecture.py", self._model_py),
            ("model/attention.py", self._attention_py),
            ("data/dataset.py", self._dataset_py),
            ("data/preprocessing.py", self._preprocess_py),
            ("config/training_config.yaml", self._config_yaml),
        )
        return {path: build() for path, build in builders}

    def _train_py(self) -> str:
        """Contents of the placeholder training script."""
        return (
            "import torch\n"
            "from model.architecture import Net\n"
            "\n"
            "# training loop placeholder\n"
        )

    def _model_py(self) -> str:
        """Contents of the placeholder model definition."""
        # NOTE(review): the snippet's indentation is a single space on both
        # inner lines, reproduced exactly as found -- confirm whether the
        # upstream file used deeper indentation before relying on it parsing.
        return (
            "import torch.nn as nn\n"
            "\n"
            "class Net(nn.Module):\n"
            " def __init__(self):\n"
            " super().__init__()\n"
        )

    def _attention_py(self) -> str:
        """Contents of the placeholder attention module."""
        return "# custom attention layer\n"

    def _dataset_py(self) -> str:
        """Contents of the placeholder dataset module."""
        return (
            "from torch.utils.data import Dataset\n"
            "\n"
            "class ImageDataset(Dataset):\n"
            " pass\n"
        )

    def _preprocess_py(self) -> str:
        """Contents of the placeholder preprocessing module."""
        return "def normalize(x):\n" " return x\n"

    def _config_yaml(self) -> str:
        """Contents of the placeholder training config."""
        return "lr: 0.001\n" "batch_size: 32\n"