"""
dataset.py — Build training datasets from execution traces.
Converts successful agent traces into structured train/eval datasets:
- Filter by success (only learn from good trajectories)
- Immune scan (remove poisoned examples)
- Deduplicate
- Split train/validation/test
- Keep tool-use trajectories (not just final answers)
"""
from __future__ import annotations
import json
import hashlib
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from purpose_agent.trace import Trace
from purpose_agent.immune import scan_memory
from purpose_agent.memory import MemoryCard, MemoryKind
@dataclass
class DatasetExample:
    """A single training example extracted from a trace.

    One example corresponds to one action/progress event taken from a
    successful trace; built by TraceDatasetBuilder.
    """
    # Unique id, formatted "<trace_id>_<step>" by the builder.
    id: str
    # Model input: the owning trace's purpose/goal text.
    input_text: str
    # Model target: the agent's "thought" or "action" text (builder truncates to 1000 chars).
    output_text: str
    # Tool invocations recorded for the event: [{"name": ..., "args": ...}].
    tool_calls: list[dict[str, Any]] = field(default_factory=list)
    # Provenance info (e.g. trace_id and step of the source event).
    metadata: dict[str, Any] = field(default_factory=dict)
    # Dataset split assignment, set during build.
    split: str = "train"  # train/validation/test
@dataclass
class TraceDataset:
    """A filtered, deduplicated dataset ready for training.

    Holds DatasetExample records plus dataset-level metadata; the
    ``train``/``validation``/``test`` properties view the examples by
    their ``split`` tag, and ``save`` serializes everything to JSON.
    """
    examples: list[DatasetExample] = field(default_factory=list)
    metadata: dict[str, Any] = field(default_factory=dict)

    @property
    def train(self) -> list[DatasetExample]:
        """Examples assigned to the training split."""
        return [e for e in self.examples if e.split == "train"]

    @property
    def validation(self) -> list[DatasetExample]:
        """Examples assigned to the validation split."""
        return [e for e in self.examples if e.split == "validation"]

    @property
    def test(self) -> list[DatasetExample]:
        """Examples assigned to the test split."""
        return [e for e in self.examples if e.split == "test"]

    def save(self, path: str) -> None:
        """Write the dataset as JSON to *path*, creating parent directories.

        Fixes vs. the previous version: an explicit UTF-8 encoding with
        ``ensure_ascii=False`` (platform-independent, keeps non-ASCII trace
        text readable) and each example's ``metadata`` is now included so
        provenance (trace_id, step) survives serialization.
        """
        out = Path(path)
        out.parent.mkdir(parents=True, exist_ok=True)
        data = {
            "examples": [
                {
                    "id": e.id,
                    "input": e.input_text,
                    "output": e.output_text,
                    "tool_calls": e.tool_calls,
                    "metadata": e.metadata,
                    "split": e.split,
                }
                for e in self.examples
            ],
            "metadata": self.metadata,
        }
        with open(out, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, ensure_ascii=False)

    @property
    def size(self) -> int:
        """Total number of examples across all splits."""
        return len(self.examples)
class TraceDatasetBuilder:
    """
    Builds filtered datasets from execution traces.

    Pipeline: success filter (final phi score) -> dedup -> immune scan ->
    contiguous train/validation/test split.

    Usage:
        builder = TraceDatasetBuilder(min_phi=7.0, train_ratio=0.7)
        dataset = builder.build(traces)
        dataset.save("./datasets/training.json")
    """
    def __init__(self, min_phi: float = 6.0, train_ratio: float = 0.7, val_ratio: float = 0.15):
        # Minimum final phi score a trace must reach to contribute examples.
        self.min_phi = min_phi
        # Fractions assigned to train / validation; the remainder becomes test.
        self.train_ratio = train_ratio
        self.val_ratio = val_ratio
        # Examples rejected by the immune scan during the most recent build().
        self._rejected = 0

    def build(self, traces: list[Trace]) -> TraceDataset:
        """Build a deduplicated, immune-scanned, split dataset from traces."""
        # Reset per build: previously the counter accumulated across calls,
        # so a second build() reported stale rejection counts in metadata.
        self._rejected = 0
        examples: list[DatasetExample] = []
        seen_hashes: set[str] = set()
        for trace in traces:
            for ex in self._extract_from_trace(trace):
                # Dedup by content hash. The NUL separator prevents aliasing:
                # without it, ("ab", "c") and ("a", "bc") hash identically.
                # sha256 instead of md5 keeps hash-quality linters quiet;
                # the digest is internal-only, so this changes no interface.
                h = hashlib.sha256(
                    (ex.input_text + "\x00" + ex.output_text).encode()
                ).hexdigest()[:12]
                if h in seen_hashes:
                    continue
                seen_hashes.add(h)
                # Immune scan: drop adversarial/poisoned examples.
                if not self._passes_immune(ex):
                    self._rejected += 1
                    continue
                examples.append(ex)
        # Contiguous split (no shuffle): keeps a trace's examples together,
        # at the cost of ordering bias between splits.
        n = len(examples)
        n_train = int(n * self.train_ratio)
        n_val = int(n * self.val_ratio)
        for i, ex in enumerate(examples):
            if i < n_train:
                ex.split = "train"
            elif i < n_train + n_val:
                ex.split = "validation"
            else:
                ex.split = "test"
        return TraceDataset(
            examples=examples,
            metadata={"total_traces": len(traces), "total_examples": n,
                      "rejected": self._rejected, "min_phi": self.min_phi},
        )

    def _extract_from_trace(self, trace: Trace) -> list[DatasetExample]:
        """Extract input/output examples from a single trace.

        Returns [] when the trace's final score event falls below min_phi
        (only successful trajectories are learned from). Traces with no
        score events are kept — assumed unscored rather than failed;
        TODO(review): confirm that is the intended policy.
        """
        examples: list[DatasetExample] = []
        # Only use successful traces: gate on the last recorded score.
        score_events = [e for e in trace.events if e.kind == "score"]
        if score_events:
            # Prefer "phi_after", fall back to "phi" — event schemas vary.
            final_phi = score_events[-1].data.get("phi_after", score_events[-1].data.get("phi", 0))
            if final_phi < self.min_phi:
                return []
        # Extract action -> output pairs as examples.
        action_events = [e for e in trace.events if e.kind in ("action", "agent.progress")]
        for event in action_events:
            input_text = trace.purpose
            # Prefer the agent's thought text; fall back to the raw action.
            output_text = event.data.get("thought", "") or event.data.get("action", "")
            if not output_text:
                continue
            tool_calls = []
            # Tool name may appear under "name" or "tool" depending on event kind.
            tool_name = event.data.get("name", event.data.get("tool", ""))
            if tool_name:
                tool_calls.append({"name": tool_name, "args": event.data.get("params", {})})
            examples.append(DatasetExample(
                id=f"{trace.trace_id}_{event.step}",
                input_text=input_text,
                # Cap targets at 1000 chars to bound example size.
                output_text=output_text[:1000],
                tool_calls=tool_calls,
                metadata={"trace_id": trace.trace_id, "step": event.step},
            ))
        return examples

    def _passes_immune(self, ex: DatasetExample) -> bool:
        """Check the example doesn't contain adversarial content.

        Wraps the output text in a MemoryCard so the shared immune
        scanner can evaluate it; only the first 500 chars are scanned.
        """
        card = MemoryCard(kind=MemoryKind.SKILL_CARD, content=ex.output_text[:500])
        return scan_memory(card).passed

    @property
    def rejected_count(self) -> int:
        """Number of examples rejected by the immune scan in the last build()."""
        return self._rejected