Rohan03 commited on
Commit
631c715
·
verified ·
1 Parent(s): 130c63c

Sprint 9B: dataset.py — trace → filtered training dataset builder

Browse files
Files changed (1) hide show
  1. purpose_agent/optimization/dataset.py +160 -0
purpose_agent/optimization/dataset.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ dataset.py — Build training datasets from execution traces.
3
+
4
+ Converts successful agent traces into structured train/eval datasets:
5
+ - Filter by success (only learn from good trajectories)
6
+ - Immune scan (remove poisoned examples)
7
+ - Deduplicate
8
+ - Split train/validation/test
9
+ - Keep tool-use trajectories (not just final answers)
10
+ """
11
+ from __future__ import annotations
12
+ import json
13
+ import hashlib
14
+ from dataclasses import dataclass, field
15
+ from pathlib import Path
16
+ from typing import Any
17
+ from purpose_agent.trace import Trace
18
+ from purpose_agent.immune import scan_memory
19
+ from purpose_agent.memory import MemoryCard, MemoryKind
20
+
21
+
22
@dataclass
class DatasetExample:
    """A single training example extracted from a trace.

    Pairs a trace's purpose (the input side) with one agent
    thought/action at a given step (the output side), optionally
    carrying the tool call made at that step.
    """
    id: str  # unique per example: "<trace_id>_<step>" (set by the builder)
    input_text: str  # prompt side — the trace's purpose string
    output_text: str  # completion side — event "thought" or "action" text
    tool_calls: list[dict[str, Any]] = field(default_factory=list)  # each: {"name": ..., "args": {...}}
    metadata: dict[str, Any] = field(default_factory=dict)  # provenance, e.g. {"trace_id": ..., "step": ...}
    split: str = "train"  # train/validation/test
31
+
32
+
33
@dataclass
class TraceDataset:
    """A filtered, deduplicated dataset ready for training.

    Each example carries its own ``split`` tag; the ``train`` /
    ``validation`` / ``test`` properties are simple filtered views
    over ``examples`` (computed on each access, never cached).
    """
    examples: list[DatasetExample] = field(default_factory=list)
    metadata: dict[str, Any] = field(default_factory=dict)  # build stats (totals, thresholds, ...)

    @property
    def train(self) -> list[DatasetExample]:
        """Examples tagged ``split == "train"``."""
        return [e for e in self.examples if e.split == "train"]

    @property
    def validation(self) -> list[DatasetExample]:
        """Examples tagged ``split == "validation"``."""
        return [e for e in self.examples if e.split == "validation"]

    @property
    def test(self) -> list[DatasetExample]:
        """Examples tagged ``split == "test"``."""
        return [e for e in self.examples if e.split == "test"]

    def save(self, path: str) -> None:
        """Serialize the dataset to *path* as JSON, creating parent dirs.

        Fix: open with an explicit ``encoding="utf-8"`` so non-ASCII
        trace content round-trips regardless of the platform's default
        text encoding (previously relied on the locale default).
        """
        out = Path(path)
        out.parent.mkdir(parents=True, exist_ok=True)
        data = {
            "examples": [
                {
                    "id": e.id,
                    "input": e.input_text,
                    "output": e.output_text,
                    "tool_calls": e.tool_calls,
                    "split": e.split,
                }
                for e in self.examples
            ],
            "metadata": self.metadata,
        }
        with open(out, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)

    @property
    def size(self) -> int:
        """Total number of examples across all splits."""
        return len(self.examples)
62
+
63
+
64
class TraceDatasetBuilder:
    """
    Builds filtered datasets from execution traces.

    Pipeline per build():
      1. discard traces whose final phi score is below ``min_phi``
      2. extract one example per action/progress event
      3. deduplicate by content hash
      4. reject examples that fail the immune scan
      5. assign train/validation/test splits by ratio

    Usage:
        builder = TraceDatasetBuilder(min_phi=7.0, train_ratio=0.7)
        dataset = builder.build(traces)
        dataset.save("./datasets/training.json")
    """

    def __init__(self, min_phi: float = 6.0, train_ratio: float = 0.7, val_ratio: float = 0.15):
        # min_phi: minimum final score a trace needs before any of its
        # events are harvested.
        # Splits: first train_ratio of examples -> train, next val_ratio
        # -> validation, remainder -> test.
        self.min_phi = min_phi
        self.train_ratio = train_ratio
        self.val_ratio = val_ratio
        self._rejected = 0  # cumulative across builds; see rejected_count

    def build(self, traces: list[Trace]) -> TraceDataset:
        """Build a TraceDataset from *traces*.

        Returns a dataset whose ``metadata`` reports stats for THIS call
        only. Fix: the rejected count in metadata was previously read
        from the cumulative ``self._rejected`` counter, so calling
        build() a second time reported rejections from earlier builds
        too; it is now tracked in a per-build local and the cumulative
        counter is updated separately.
        """
        examples: list[DatasetExample] = []
        seen_hashes: set[str] = set()
        rejected = 0  # rejections for this build only

        for trace in traces:
            for ex in self._extract_from_trace(trace):
                # Dedup by content hash. md5 is acceptable here: it is a
                # dedup key, not a security boundary. NOTE(review): the
                # hash covers the full output_text while the stored text
                # is truncated to 1000 chars (see _extract_from_trace),
                # so two stored-identical examples can both survive.
                h = hashlib.md5((ex.input_text + ex.output_text).encode()).hexdigest()[:12]
                if h in seen_hashes:
                    continue
                seen_hashes.add(h)

                # Immune scan: drop poisoned/adversarial examples.
                if not self._passes_immune(ex):
                    rejected += 1
                    continue

                examples.append(ex)

        self._rejected += rejected

        # Assign contiguous splits: first n_train examples -> train,
        # next n_val -> validation, remainder -> test.
        n = len(examples)
        n_train = int(n * self.train_ratio)
        n_val = int(n * self.val_ratio)
        for i, ex in enumerate(examples):
            if i < n_train:
                ex.split = "train"
            elif i < n_train + n_val:
                ex.split = "validation"
            else:
                ex.split = "test"

        return TraceDataset(
            examples=examples,
            metadata={"total_traces": len(traces), "total_examples": n,
                      "rejected": rejected, "min_phi": self.min_phi},
        )

    def _extract_from_trace(self, trace: Trace) -> list[DatasetExample]:
        """Extract examples from a single trace.

        Returns [] when the trace's final "score" event reports a phi
        below ``min_phi``. Traces with no score events are not filtered
        (kept as-is). One example is produced per "action" /
        "agent.progress" event that carries non-empty thought/action
        text.
        """
        examples: list[DatasetExample] = []

        # Only learn from successful traces: check the LAST score event.
        # The phi value may live under "phi_after" or "phi".
        score_events = [e for e in trace.events if e.kind == "score"]
        if score_events:
            final_phi = score_events[-1].data.get("phi_after", score_events[-1].data.get("phi", 0))
            if final_phi < self.min_phi:
                return []

        # Each qualifying event becomes a (purpose -> thought/action) pair.
        action_events = [e for e in trace.events if e.kind in ("action", "agent.progress")]
        for event in action_events:
            input_text = trace.purpose
            output_text = event.data.get("thought", "") or event.data.get("action", "")
            if not output_text:
                continue

            # Tool name may be stored under "name" or "tool" depending
            # on the event source.
            tool_calls = []
            tool_name = event.data.get("name", event.data.get("tool", ""))
            if tool_name:
                tool_calls.append({"name": tool_name, "args": event.data.get("params", {})})

            examples.append(DatasetExample(
                id=f"{trace.trace_id}_{event.step}",
                input_text=input_text,
                output_text=output_text[:1000],  # cap stored completion length
                tool_calls=tool_calls,
                metadata={"trace_id": trace.trace_id, "step": event.step},
            ))

        return examples

    def _passes_immune(self, ex: DatasetExample) -> bool:
        """Check the example doesn't contain adversarial content.

        Wraps the output text in a throwaway MemoryCard so the shared
        immune scanner can be reused; only the first 500 chars are
        scanned.
        """
        card = MemoryCard(kind=MemoryKind.SKILL_CARD, content=ex.output_text[:500])
        return scan_memory(card).passed

    @property
    def rejected_count(self) -> int:
        """Total examples rejected by the immune scan, cumulative
        across every build() performed by this builder instance."""
        return self._rejected