File size: 1,928 Bytes
8097081
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1435892
8097081
 
 
 
 
 
 
 
 
 
 
 
 
 
1435892
8097081
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# src/pytorch_debug_env/models.py
from __future__ import annotations

from typing import Dict, List, Literal, Optional
from pydantic import BaseModel, Field


class Hypothesis(BaseModel):
    """A suspected bug type and file, together with a confidence value.

    All three fields are required (each uses ``Field(...)`` with no default).
    """

    # Free-form string; this model does not restrict it to a fixed vocabulary.
    bug_type: str = Field(..., description="Current suspected bug type")
    # Free-form string; not checked here against any list of known files.
    affected_file: str = Field(..., description="Current suspected file")
    # Validated to lie in the closed interval [0.0, 1.0].
    confidence: float = Field(..., ge=0.0, le=1.0)


class InvestigationAction(BaseModel):
    """A single investigation request: an action kind plus an optional target."""

    # Required; must be exactly one of the five literal strings below.
    action: Literal[
        "reveal_file",
        "extend_loss_curve",
        "extend_gpu_profile",
        "reveal_log_chunk",
        "run_diagnostic",
    ]
    # Optional argument for the action; defaults to None.
    # NOTE(review): presumably a file name for "reveal_file" — how each action
    # interprets it is decided by the consumer, not by this model; confirm there.
    target: Optional[str] = None


class FinalDiagnosis(BaseModel):
    """A diagnosis payload: bug type, file, line range, fix strategy, confidence.

    Every field is required; none declares a default.
    """

    bug_type: str
    affected_file: str
    # NOTE(review): presumably [start, end] line numbers — this model only
    # requires a list of ints and enforces neither length nor ordering.
    line_range: List[int]
    # Free-form string describing the proposed fix.
    fix_strategy: str
    # Validated to lie in the closed interval [0.0, 1.0].
    confidence: float = Field(..., ge=0.0, le=1.0)


class PyTorchDebugAction(BaseModel):
    """Action payload bundling a hypothesis with an optional investigation
    action and an optional final diagnosis.

    Only ``current_hypothesis`` is required. This model defines no
    cross-field validation: ``commit_diagnosis`` may be True while
    ``final_diagnosis`` is None, so consumers must check that themselves.
    """

    # Required on every action.
    current_hypothesis: Hypothesis
    # Optional; defaults to None.
    investigation_action: Optional[InvestigationAction] = None
    # Defaults to False.
    # NOTE(review): presumably signals that final_diagnosis should be scored —
    # confirm against the environment's step logic.
    commit_diagnosis: bool = False
    # Optional; defaults to None.
    final_diagnosis: Optional[FinalDiagnosis] = None


class HypothesisRecord(BaseModel):
    """A hypothesis paired with a step number and a quality value.

    All fields are required.
    """

    # NOTE(review): presumably the step at which the hypothesis was submitted;
    # whether it is 0- or 1-based is not determined here.
    step: int
    hypothesis: Hypothesis
    # Unconstrained float; range and meaning are defined by whoever fills it in.
    quality: float


class PyTorchDebugObservation(BaseModel):
    """Observation payload: files, metric windows, log text, counters,
    hypothesis history and feedback.

    Every field except ``diagnostic_report`` is required.
    """

    scenario_id: str
    task_id: str
    # NOTE(review): presumably maps file name -> file contents; confirm with
    # the code that builds observations.
    revealed_files: Dict[str, str]
    available_files: List[str]
    # Lists of untyped dicts; the per-entry keys are not specified by this model.
    loss_curve_window: List[Dict]
    gpu_profile_window: List[Dict]
    training_log_tail: str
    # Optional; defaults to None.
    diagnostic_report: Optional[str] = None
    step_num: int
    steps_remaining: int
    # NOTE(review): presumably the number of investigation actions still
    # allowed — confirm against the environment.
    investigation_budget: int
    hypothesis_history: List[HypothesisRecord]
    last_feedback: str


class PyTorchDebugState(BaseModel):
    """State summary: identifiers, step counters, file lists, completion
    flags and a final score.

    ``diagnostic_revealed`` and ``final_score`` have defaults; every other
    field, including ``done``, is required.
    """

    scenario_id: str
    task_id: str
    max_steps: int
    current_step: int
    # A list of strings here, unlike PyTorchDebugObservation.revealed_files,
    # which is a dict.
    revealed_files: List[str]
    remaining_files: List[str]
    # Defaults to False.
    diagnostic_revealed: bool = False
    # Required: no default is declared.
    done: bool
    # Defaults to 0.0; no range constraint is applied here, unlike
    # PyTorchDebugReward.value.
    final_score: float = 0.0


class PyTorchDebugReward(BaseModel):
    """A reward value with a dict of named float components.

    Both fields are required.
    """

    # Validated to lie in the closed interval [0.0, 1.0].
    value: float = Field(..., ge=0.0, le=1.0)
    # Maps string keys to floats. This model does not constrain the keys, the
    # individual values, or how they relate to `value`.
    components: Dict[str, float]