| """ |
| dataset.py — Build training datasets from execution traces. |
| |
| Converts successful agent traces into structured train/eval datasets: |
| - Filter by success (only learn from good trajectories) |
| - Immune scan (remove poisoned examples) |
| - Deduplicate |
| - Split train/validation/test |
| - Keep tool-use trajectories (not just final answers) |
| """ |
| from __future__ import annotations |
| import json |
| import hashlib |
| from dataclasses import dataclass, field |
| from pathlib import Path |
| from typing import Any |
| from purpose_agent.trace import Trace |
| from purpose_agent.immune import scan_memory |
| from purpose_agent.memory import MemoryCard, MemoryKind |
|
|
|
|
@dataclass
class DatasetExample:
    """A single training example extracted from a trace.

    One example corresponds to one action/progress event from a successful
    trace: the agent's purpose as input, its thought/action text as output.
    """
    # Unique id, formatted as "<trace_id>_<step>" by the builder.
    id: str
    # Prompt side of the example (the trace's purpose string).
    input_text: str
    # Target side: the agent's thought or action text (truncated by the builder).
    output_text: str
    # Structured tool invocations, each {"name": ..., "args": ...}.
    tool_calls: list[dict[str, Any]] = field(default_factory=list)
    # Provenance info (e.g. trace_id, step) — not used for training itself.
    metadata: dict[str, Any] = field(default_factory=dict)
    # One of "train", "validation", "test"; assigned after extraction.
    split: str = "train"
|
|
|
|
@dataclass
class TraceDataset:
    """A filtered, deduplicated dataset ready for training.

    Attributes:
        examples: All examples, each tagged with a ``split`` of
            "train", "validation", or "test".
        metadata: Build statistics (trace counts, rejections, thresholds).
    """
    examples: list[DatasetExample] = field(default_factory=list)
    metadata: dict[str, Any] = field(default_factory=dict)

    @property
    def train(self) -> list[DatasetExample]:
        """Examples assigned to the training split."""
        return [e for e in self.examples if e.split == "train"]

    @property
    def validation(self) -> list[DatasetExample]:
        """Examples assigned to the validation split."""
        return [e for e in self.examples if e.split == "validation"]

    @property
    def test(self) -> list[DatasetExample]:
        """Examples assigned to the held-out test split."""
        return [e for e in self.examples if e.split == "test"]

    @property
    def size(self) -> int:
        """Total number of examples across all splits."""
        return len(self.examples)

    def save(self, path: str) -> None:
        """Serialize the dataset to *path* as pretty-printed JSON.

        Parent directories are created as needed. The on-disk schema is
        ``{"examples": [...], "metadata": {...}}``; example metadata is
        intentionally not persisted (matches the original format).
        """
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        data = {
            "examples": [
                {"id": e.id, "input": e.input_text, "output": e.output_text,
                 "tool_calls": e.tool_calls, "split": e.split}
                for e in self.examples
            ],
            "metadata": self.metadata,
        }
        # Explicit encoding: the default for open() is platform-dependent,
        # which could produce non-UTF-8 files on some systems.
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)
|
|
|
|
class TraceDatasetBuilder:
    """
    Builds filtered datasets from execution traces.

    Pipeline per build:
      1. Drop traces whose final phi score is below ``min_phi``.
      2. Extract one example per action/progress event.
      3. Deduplicate on an (input, output) content hash.
      4. Reject examples that fail the immune scan.
      5. Split sequentially into train/validation/test.

    Usage:
        builder = TraceDatasetBuilder(min_phi=7.0, train_ratio=0.7)
        dataset = builder.build(traces)
        dataset.save("./datasets/training.json")
    """

    def __init__(self, min_phi: float = 6.0, train_ratio: float = 0.7, val_ratio: float = 0.15):
        """
        Args:
            min_phi: Minimum final phi score a trace must reach to contribute.
            train_ratio: Fraction of examples assigned to the train split.
            val_ratio: Fraction assigned to validation; the remainder is test.
        """
        self.min_phi = min_phi
        self.train_ratio = train_ratio
        self.val_ratio = val_ratio
        self._rejected = 0

    def build(self, traces: list[Trace]) -> TraceDataset:
        """Build a deduplicated, immune-scanned dataset from *traces*."""
        # BUGFIX: reset the rejection counter per build. Previously it
        # accumulated across calls, so a second build() reported stale
        # counts in its metadata and via rejected_count.
        self._rejected = 0
        examples: list[DatasetExample] = []
        seen_hashes: set[str] = set()

        for trace in traces:
            for ex in self._extract_from_trace(trace):
                # Content-hash dedup; md5 is acceptable here because this is
                # not security-sensitive, just a compact fingerprint.
                h = hashlib.md5((ex.input_text + ex.output_text).encode()).hexdigest()[:12]
                if h in seen_hashes:
                    continue
                seen_hashes.add(h)

                # Immune scan runs after dedup so each unique example is
                # scanned (and counted) at most once.
                if not self._passes_immune(ex):
                    self._rejected += 1
                    continue

                examples.append(ex)

        # Sequential (non-shuffled) split: ordering follows trace order, so
        # validation/test come from the latest traces.
        n = len(examples)
        n_train = int(n * self.train_ratio)
        n_val = int(n * self.val_ratio)
        for i, ex in enumerate(examples):
            if i < n_train:
                ex.split = "train"
            elif i < n_train + n_val:
                ex.split = "validation"
            else:
                ex.split = "test"

        return TraceDataset(
            examples=examples,
            metadata={"total_traces": len(traces), "total_examples": n,
                      "rejected": self._rejected, "min_phi": self.min_phi},
        )

    def _extract_from_trace(self, trace: Trace) -> list[DatasetExample]:
        """Extract one example per action/progress event from a single trace.

        Returns an empty list when the trace's final phi score is below
        ``min_phi`` — only successful trajectories are learned from. Traces
        with no score events are accepted as-is.
        """
        score_events = [e for e in trace.events if e.kind == "score"]
        if score_events:
            last = score_events[-1].data
            final_phi = last.get("phi_after", last.get("phi", 0))
            if final_phi < self.min_phi:
                return []

        examples = []
        action_events = [e for e in trace.events if e.kind in ("action", "agent.progress")]
        for event in action_events:
            input_text = trace.purpose
            # Prefer the agent's thought text; fall back to the raw action.
            output_text = event.data.get("thought", "") or event.data.get("action", "")
            if not output_text:
                continue

            tool_calls = []
            tool_name = event.data.get("name", event.data.get("tool", ""))
            if tool_name:
                tool_calls.append({"name": tool_name, "args": event.data.get("params", {})})

            examples.append(DatasetExample(
                id=f"{trace.trace_id}_{event.step}",
                input_text=input_text,
                output_text=output_text[:1000],  # cap target length
                tool_calls=tool_calls,
                metadata={"trace_id": trace.trace_id, "step": event.step},
            ))

        return examples

    def _passes_immune(self, ex: DatasetExample) -> bool:
        """Check the example doesn't contain adversarial content.

        NOTE(review): only the first 500 chars are scanned — assumed
        representative of the whole output; confirm against scan_memory docs.
        """
        card = MemoryCard(kind=MemoryKind.SKILL_CARD, content=ex.output_text[:500])
        return scan_memory(card).passed

    @property
    def rejected_count(self) -> int:
        """Number of examples rejected by the immune scan in the last build."""
        return self._rejected
|
|