Sprint 9B: dataset.py — trace → filtered training dataset builder
Browse files
purpose_agent/optimization/dataset.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
dataset.py — Build training datasets from execution traces.
|
| 3 |
+
|
| 4 |
+
Converts successful agent traces into structured train/eval datasets:
|
| 5 |
+
- Filter by success (only learn from good trajectories)
|
| 6 |
+
- Immune scan (remove poisoned examples)
|
| 7 |
+
- Deduplicate
|
| 8 |
+
- Split train/validation/test
|
| 9 |
+
- Keep tool-use trajectories (not just final answers)
|
| 10 |
+
"""
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
import json
|
| 13 |
+
import hashlib
|
| 14 |
+
from dataclasses import dataclass, field
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from typing import Any
|
| 17 |
+
from purpose_agent.trace import Trace
|
| 18 |
+
from purpose_agent.immune import scan_memory
|
| 19 |
+
from purpose_agent.memory import MemoryCard, MemoryKind
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
class DatasetExample:
    """A single training example extracted from a trace.

    Produced by TraceDatasetBuilder._extract_from_trace and serialized
    by TraceDataset.save.
    """
    # Unique example id; the builder formats it as "<trace_id>_<step>".
    id: str
    # Prompt side of the pair (the builder uses the trace's purpose).
    input_text: str
    # Target side (thought/action text; truncated to 1000 chars by the builder).
    output_text: str
    # Tool invocations made at this step: [{"name": ..., "args": {...}}, ...].
    tool_calls: list[dict[str, Any]] = field(default_factory=list)
    # Provenance info (e.g. trace_id and step number).
    metadata: dict[str, Any] = field(default_factory=dict)
    split: str = "train"  # train/validation/test
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@dataclass
class TraceDataset:
    """A filtered, deduplicated dataset ready for training.

    Each example carries its own ``split`` tag; the ``train`` /
    ``validation`` / ``test`` properties are filtered views over the
    full example list.
    """
    examples: list[DatasetExample] = field(default_factory=list)
    # Build provenance (counts, thresholds) written by TraceDatasetBuilder.
    metadata: dict[str, Any] = field(default_factory=dict)

    @property
    def train(self) -> list[DatasetExample]:
        """Examples tagged for training."""
        return [e for e in self.examples if e.split == "train"]

    @property
    def validation(self) -> list[DatasetExample]:
        """Examples tagged for validation."""
        return [e for e in self.examples if e.split == "validation"]

    @property
    def test(self) -> list[DatasetExample]:
        """Examples held out for final evaluation."""
        return [e for e in self.examples if e.split == "test"]

    def save(self, path: str) -> None:
        """Serialize the dataset to *path* as JSON, creating parent dirs.

        Fix: the file is now opened with an explicit UTF-8 encoding
        (previously the platform default locale decided the encoding),
        and non-ASCII text is written as-is (``ensure_ascii=False``).
        """
        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        data = {
            "examples": [
                {
                    "id": e.id,
                    "input": e.input_text,
                    "output": e.output_text,
                    "tool_calls": e.tool_calls,
                    "split": e.split,
                }
                for e in self.examples
            ],
            "metadata": self.metadata,
        }
        with target.open("w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, ensure_ascii=False)

    @property
    def size(self) -> int:
        """Total number of examples across all splits."""
        return len(self.examples)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class TraceDatasetBuilder:
    """
    Builds filtered datasets from execution traces.

    Pipeline: extract examples from successful traces, deduplicate by
    content hash, drop examples failing the immune scan, then assign
    contiguous train/validation/test splits.

    Usage:
        builder = TraceDatasetBuilder(min_phi=7.0, train_ratio=0.7)
        dataset = builder.build(traces)
        dataset.save("./datasets/training.json")
    """

    def __init__(self, min_phi: float = 6.0, train_ratio: float = 0.7, val_ratio: float = 0.15):
        # Quality gate: traces whose final phi score is below this are skipped.
        self.min_phi = min_phi
        self.train_ratio = train_ratio
        self.val_ratio = val_ratio
        # Cumulative count of immune-rejected examples across ALL builds
        # (exposed via rejected_count).
        self._rejected = 0

    def build(self, traces: list[Trace]) -> TraceDataset:
        """Build a deduplicated, immune-scanned dataset from *traces*.

        Bug fix: metadata["rejected"] previously reported the cumulative
        self._rejected counter, so reusing one builder instance for a
        second build() overstated that build's rejections. It now reports
        the per-build count; the rejected_count property stays cumulative.
        """
        examples: list[DatasetExample] = []
        seen_hashes: set[str] = set()
        rejected_this_build = 0

        for trace in traces:
            for ex in self._extract_from_trace(trace):
                # Dedup by content hash (md5 is acceptable here: used only
                # for dedup, not security).
                h = hashlib.md5((ex.input_text + ex.output_text).encode()).hexdigest()[:12]
                if h in seen_hashes:
                    continue
                seen_hashes.add(h)

                # Immune scan: drop poisoned/adversarial examples.
                if not self._passes_immune(ex):
                    rejected_this_build += 1
                    continue

                examples.append(ex)

        self._rejected += rejected_this_build

        # Assign contiguous splits: [train | validation | test].
        n = len(examples)
        n_train = int(n * self.train_ratio)
        n_val = int(n * self.val_ratio)
        for i, ex in enumerate(examples):
            if i < n_train:
                ex.split = "train"
            elif i < n_train + n_val:
                ex.split = "validation"
            else:
                ex.split = "test"

        return TraceDataset(
            examples=examples,
            metadata={"total_traces": len(traces), "total_examples": n,
                      "rejected": rejected_this_build, "min_phi": self.min_phi},
        )

    def _extract_from_trace(self, trace: Trace) -> list[DatasetExample]:
        """Extract (purpose -> thought/action) examples from one trace.

        Returns [] when the trace's final phi score is below min_phi;
        traces with no score events are accepted as-is.
        """
        examples: list[DatasetExample] = []

        # Only use successful traces.
        score_events = [e for e in trace.events if e.kind == "score"]
        if score_events:
            # Prefer "phi_after", fall back to "phi", default 0 (reject).
            final_phi = score_events[-1].data.get("phi_after", score_events[-1].data.get("phi", 0))
            if final_phi < self.min_phi:
                return []

        # Extract action-observation pairs as examples.
        action_events = [e for e in trace.events if e.kind in ("action", "agent.progress")]
        for event in action_events:
            input_text = trace.purpose
            output_text = event.data.get("thought", "") or event.data.get("action", "")
            if not output_text:
                continue

            tool_calls = []
            tool_name = event.data.get("name", event.data.get("tool", ""))
            if tool_name:
                tool_calls.append({"name": tool_name, "args": event.data.get("params", {})})

            examples.append(DatasetExample(
                id=f"{trace.trace_id}_{event.step}",
                input_text=input_text,
                output_text=output_text[:1000],  # cap target length
                tool_calls=tool_calls,
                metadata={"trace_id": trace.trace_id, "step": event.step},
            ))

        return examples

    def _passes_immune(self, ex: DatasetExample) -> bool:
        """Check example doesn't contain adversarial content."""
        # Only the first 500 chars are scanned — assumes the scanner's
        # signal concentrates early; longer payloads pass unscanned tails.
        card = MemoryCard(kind=MemoryKind.SKILL_CARD, content=ex.output_text[:500])
        return scan_memory(card).passed

    @property
    def rejected_count(self) -> int:
        """Cumulative number of immune-rejected examples across builds."""
        return self._rejected
|