# Provenance (copied from the hosting page's file header; not part of the module):
#   author: Rohan03
#   commit: 631c715 (verified) — "Sprint 9B: dataset.py — trace → filtered training dataset builder"
#   size:   5.55 kB
"""
dataset.py — Build training datasets from execution traces.
Converts successful agent traces into structured train/eval datasets:
- Filter by success (only learn from good trajectories)
- Immune scan (remove poisoned examples)
- Deduplicate
- Split train/validation/test
- Keep tool-use trajectories (not just final answers)
"""
from __future__ import annotations
import json
import hashlib
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from purpose_agent.trace import Trace
from purpose_agent.immune import scan_memory
from purpose_agent.memory import MemoryCard, MemoryKind
@dataclass
class DatasetExample:
    """A single training example extracted from a trace.

    In this module, ``TraceDatasetBuilder`` creates one example per qualifying
    trace event: ``input_text`` is the trace's purpose and ``output_text`` is
    the event's thought (or, failing that, its action text).
    """
    # Example identifier; the builder formats it as "<trace_id>_<step>".
    id: str
    # Prompt side of the example.
    input_text: str
    # Target side of the example.
    output_text: str
    # Tool invocations for this step; the builder emits {"name": ..., "args": ...} dicts.
    tool_calls: list[dict[str, Any]] = field(default_factory=list)
    # Free-form provenance; the builder stores "trace_id" and "step" here.
    metadata: dict[str, Any] = field(default_factory=dict)
    split: str = "train"  # one of "train" / "validation" / "test"; reassigned by TraceDatasetBuilder.build
@dataclass
class TraceDataset:
    """A filtered, deduplicated dataset ready for training.

    ``examples`` holds every example; the ``train`` / ``validation`` / ``test``
    properties select examples by their ``split`` label, preserving order.
    """

    examples: list[DatasetExample] = field(default_factory=list)
    # Dataset-level provenance (TraceDatasetBuilder records trace/example
    # counts, immune rejections and the phi threshold here).
    metadata: dict[str, Any] = field(default_factory=dict)

    def _by_split(self, split: str) -> list[DatasetExample]:
        """Return the examples labelled *split*, in dataset order."""
        return [e for e in self.examples if e.split == split]

    @property
    def train(self) -> list[DatasetExample]:
        """Examples labelled "train"."""
        return self._by_split("train")

    @property
    def validation(self) -> list[DatasetExample]:
        """Examples labelled "validation"."""
        return self._by_split("validation")

    @property
    def test(self) -> list[DatasetExample]:
        """Examples labelled "test"."""
        return self._by_split("test")

    def save(self, path: str | Path) -> None:
        """Write the dataset to *path* as indented UTF-8 JSON.

        Parent directories are created as needed. Each example is written with
        its id, input/output text, tool calls, split label and per-example
        metadata; the dataset-level ``metadata`` is written alongside.

        Args:
            path: Destination file (``str`` or ``Path``).
        """
        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        data = {
            "examples": [
                {
                    "id": e.id,
                    "input": e.input_text,
                    "output": e.output_text,
                    "tool_calls": e.tool_calls,
                    "split": e.split,
                    # Per-example provenance (e.g. trace_id/step) so saved
                    # examples can be traced back to their source trace.
                    "metadata": e.metadata,
                }
                for e in self.examples
            ],
            "metadata": self.metadata,
        }
        # Explicit encoding: don't depend on the platform's locale default.
        with target.open("w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)

    @property
    def size(self) -> int:
        """Total number of examples across all splits."""
        return len(self.examples)
class TraceDatasetBuilder:
    """
    Builds filtered datasets from execution traces.

    Each ``build`` call:
      1. extracts one example per ``action`` / ``agent.progress`` event of
         every trace whose last score event meets ``min_phi``;
      2. drops examples whose (input, output) text pair was already seen;
      3. drops examples whose output fails the immune scan;
      4. labels the survivors train/validation/test in input order. There is
         no shuffling, and split sizes are floored, so very small datasets
         push the remainder into "test".

    Usage:
        builder = TraceDatasetBuilder(min_phi=7.0, train_ratio=0.7)
        dataset = builder.build(traces)
        dataset.save("./datasets/training.json")
    """

    # Event kinds that become training examples.
    _ACTION_KINDS = ("action", "agent.progress")
    # Cap on stored output text per example.
    _MAX_OUTPUT_CHARS = 1000

    def __init__(
        self,
        min_phi: float = 6.0,
        train_ratio: float = 0.7,
        val_ratio: float = 0.15,
        *,
        require_score: bool = False,
    ):
        """
        Args:
            min_phi: Minimum final phi a scored trace needs to contribute examples.
            train_ratio: Fraction of examples labelled "train".
            val_ratio: Fraction labelled "validation"; the remainder is "test".
            require_score: If True, traces without any "score" event are
                skipped. The default (False) keeps them, unfiltered.

        Raises:
            ValueError: If a ratio is outside [0, 1] or the ratios sum to more
                than 1, which would silently starve the later splits.
        """
        if not (0.0 <= train_ratio <= 1.0 and 0.0 <= val_ratio <= 1.0):
            raise ValueError(
                f"ratios must be within [0, 1]; got train_ratio={train_ratio}, "
                f"val_ratio={val_ratio}"
            )
        # Small tolerance so e.g. 0.85 + 0.15 is not rejected by float rounding.
        if train_ratio + val_ratio > 1.0 + 1e-9:
            raise ValueError(
                f"train_ratio + val_ratio must not exceed 1; got "
                f"{train_ratio} + {val_ratio}"
            )
        self.min_phi = min_phi
        self.train_ratio = train_ratio
        self.val_ratio = val_ratio
        self.require_score = require_score
        self._rejected = 0

    def build(self, traces: list[Trace]) -> TraceDataset:
        """Build a dataset from *traces*.

        Args:
            traces: Execution traces (any iterable works). Their order fixes
                example order, and therefore which examples land in which split.

        Returns:
            A TraceDataset whose metadata records the number of input traces,
            the number of kept examples, the number of examples the immune
            scan rejected in this call, and ``min_phi``.
        """
        # Reset per build so metadata and rejected_count describe this call only.
        self._rejected = 0
        examples: list[DatasetExample] = []
        seen: set[str] = set()
        n_traces = 0
        for trace in traces:
            n_traces += 1
            for ex in self._extract_from_trace(trace):
                key = self._dedup_key(ex.input_text, ex.output_text)
                if key in seen:
                    continue
                seen.add(key)
                if not self._passes_immune(ex):
                    self._rejected += 1
                    continue
                examples.append(ex)

        self._assign_splits(examples)
        return TraceDataset(
            examples=examples,
            metadata={"total_traces": n_traces, "total_examples": len(examples),
                      "rejected": self._rejected, "min_phi": self.min_phi},
        )

    @staticmethod
    def _dedup_key(input_text: str, output_text: str) -> str:
        """Return a content hash that identifies an (input, output) text pair."""
        # JSON-encoding the pair keeps ("ab", "c") and ("a", "bc") distinct
        # (plain concatenation does not) and escapes lone surrogates that
        # would make str.encode() fail. The full SHA-256 digest avoids the
        # collision risk of a short truncated hash.
        payload = json.dumps([input_text, output_text]).encode("utf-8")
        return hashlib.sha256(payload).hexdigest()

    def _split_counts(self, n: int) -> tuple[int, int]:
        """Return (n_train, n_val) for *n* examples; the remainder is test."""
        return int(n * self.train_ratio), int(n * self.val_ratio)

    def _assign_splits(self, examples: list[DatasetExample]) -> None:
        """Label *examples* in place: first train, then validation, then test."""
        n_train, n_val = self._split_counts(len(examples))
        for i, ex in enumerate(examples):
            if i < n_train:
                ex.split = "train"
            elif i < n_train + n_val:
                ex.split = "validation"
            else:
                ex.split = "test"

    def _trace_passes_score(self, trace: Trace) -> bool:
        """Return True if *trace*'s last "score" event meets ``min_phi``.

        The final phi is read from "phi_after", falling back to "phi", then 0.
        Unscored traces pass unless ``require_score`` is set.
        """
        last_score = None
        for event in trace.events:
            if event.kind == "score":
                last_score = event
        if last_score is None:
            return not self.require_score
        data = last_score.data
        final_phi = data.get("phi_after", data.get("phi", 0))
        if final_phi is None:
            # A score event without a value is not evidence of success; comparing
            # None to a float would raise TypeError.
            return False
        # `>=` (rather than `not <`) also rejects NaN scores.
        return final_phi >= self.min_phi

    def _extract_from_trace(self, trace: Trace) -> list[DatasetExample]:
        """Extract one example per action/progress event of a passing trace."""
        if not self._trace_passes_score(trace):
            return []
        purpose = trace.purpose
        examples: list[DatasetExample] = []
        for event in trace.events:
            if event.kind not in self._ACTION_KINDS:
                continue
            data = event.data
            # Prefer the model's reasoning; fall back to the raw action text.
            output_text = data.get("thought", "") or data.get("action", "")
            if not output_text:
                continue
            tool_name = data.get("name", data.get("tool", ""))
            tool_calls = (
                [{"name": tool_name, "args": data.get("params", {})}] if tool_name else []
            )
            examples.append(DatasetExample(
                id=f"{trace.trace_id}_{event.step}",
                input_text=purpose,
                output_text=output_text[:self._MAX_OUTPUT_CHARS],
                tool_calls=tool_calls,
                metadata={"trace_id": trace.trace_id, "step": event.step},
            ))
        return examples

    def _passes_immune(self, ex: DatasetExample) -> bool:
        """Return True if the example's output passes the immune scan."""
        # Scan the entire output (extraction already caps it at
        # _MAX_OUTPUT_CHARS). Scanning only a shorter prefix would let
        # adversarial text placed after that prefix into the training set.
        # NOTE(review): input_text and tool_calls args are not scanned.
        card = MemoryCard(kind=MemoryKind.SKILL_CARD, content=ex.output_text)
        return scan_memory(card).passed

    @property
    def rejected_count(self) -> int:
        """Number of examples the immune scan rejected in the most recent build()."""
        return self._rejected