File size: 5,553 Bytes
631c715
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
"""
dataset.py — Build training datasets from execution traces.

Converts successful agent traces into structured train/eval datasets:
  - Filter by success (only learn from good trajectories)
  - Immune scan (remove poisoned examples)
  - Deduplicate
  - Split train/validation/test
  - Keep tool-use trajectories (not just final answers)
"""
from __future__ import annotations
import json
import hashlib
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from purpose_agent.trace import Trace
from purpose_agent.immune import scan_memory
from purpose_agent.memory import MemoryCard, MemoryKind


@dataclass
class DatasetExample:
    """A single training example extracted from a trace.

    One example pairs the purpose prompt (input) with an agent
    thought/action (output), plus any tool call the action made.
    """
    id: str  # unique per example; built as "<trace_id>_<step>" by the builder
    input_text: str  # the prompt side (the trace's purpose)
    output_text: str  # the completion side (agent thought or action, truncated)
    tool_calls: list[dict[str, Any]] = field(default_factory=list)  # [{"name": ..., "args": ...}]
    metadata: dict[str, Any] = field(default_factory=dict)  # provenance, e.g. trace_id/step
    split: str = "train"  # train/validation/test


@dataclass
class TraceDataset:
    """A filtered, deduplicated dataset ready for training.

    Examples carry a ``split`` tag ("train"/"validation"/"test"); the
    ``train``/``validation``/``test`` properties filter on that tag.
    """
    examples: list[DatasetExample] = field(default_factory=list)
    metadata: dict[str, Any] = field(default_factory=dict)

    @property
    def train(self) -> list[DatasetExample]:
        """Examples tagged with the "train" split."""
        return [e for e in self.examples if e.split == "train"]

    @property
    def validation(self) -> list[DatasetExample]:
        """Examples tagged with the "validation" split."""
        return [e for e in self.examples if e.split == "validation"]

    @property
    def test(self) -> list[DatasetExample]:
        """Examples tagged with the "test" split."""
        return [e for e in self.examples if e.split == "test"]

    def save(self, path: str) -> None:
        """Serialize the dataset to ``path`` as indented JSON.

        Parent directories are created as needed.  Each serialized example
        now includes its ``metadata`` dict (previously dropped on save),
        and the file is written with an explicit UTF-8 encoding instead of
        the platform default.
        """
        out = Path(path)
        out.parent.mkdir(parents=True, exist_ok=True)
        data = {
            "examples": [
                {
                    "id": e.id,
                    "input": e.input_text,
                    "output": e.output_text,
                    "tool_calls": e.tool_calls,
                    "metadata": e.metadata,  # was silently omitted before
                    "split": e.split,
                }
                for e in self.examples
            ],
            "metadata": self.metadata,
        }
        with out.open("w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)

    @property
    def size(self) -> int:
        """Total number of examples across all splits."""
        return len(self.examples)


class TraceDatasetBuilder:
    """
    Builds filtered datasets from execution traces.

    Pipeline: success filter (final phi score) -> content-hash dedup ->
    immune scan -> sequential train/validation/test split assignment.

    Usage:
        builder = TraceDatasetBuilder(min_phi=7.0, train_ratio=0.7)
        dataset = builder.build(traces)
        dataset.save("./datasets/training.json")
    """

    def __init__(self, min_phi: float = 6.0, train_ratio: float = 0.7, val_ratio: float = 0.15):
        """
        Args:
            min_phi: minimum final phi score a trace must reach to contribute.
            train_ratio: fraction of examples assigned to the train split.
            val_ratio: fraction assigned to validation; the remainder is test.
        """
        self.min_phi = min_phi
        self.train_ratio = train_ratio
        self.val_ratio = val_ratio
        self._rejected = 0  # immune-scan rejections from the most recent build()

    def build(self, traces: list[Trace]) -> TraceDataset:
        """Build a deduplicated, immune-scanned dataset from traces.

        The returned dataset's metadata records the rejection count for
        *this* build only.  (Bug fix: the counter previously accumulated
        across repeated build() calls, so a second build reported a
        cumulative rejection total in its metadata.)
        """
        self._rejected = 0  # reset per build — see docstring
        examples: list[DatasetExample] = []
        seen_hashes: set[str] = set()

        for trace in traces:
            for ex in self._extract_from_trace(trace):
                # Dedup by content hash (md5 as a fingerprint, not for security).
                h = hashlib.md5((ex.input_text + ex.output_text).encode()).hexdigest()[:12]
                if h in seen_hashes:
                    continue
                seen_hashes.add(h)

                # Drop examples whose output fails the adversarial-content scan.
                if not self._passes_immune(ex):
                    self._rejected += 1
                    continue

                examples.append(ex)

        # Assign splits sequentially (no shuffle): examples from one trace
        # tend to land in the same split, and results are deterministic.
        n = len(examples)
        n_train = int(n * self.train_ratio)
        n_val = int(n * self.val_ratio)
        for i, ex in enumerate(examples):
            if i < n_train:
                ex.split = "train"
            elif i < n_train + n_val:
                ex.split = "validation"
            else:
                ex.split = "test"

        return TraceDataset(
            examples=examples,
            metadata={"total_traces": len(traces), "total_examples": n,
                      "rejected": self._rejected, "min_phi": self.min_phi},
        )

    def _extract_from_trace(self, trace: Trace) -> list[DatasetExample]:
        """Extract purpose -> thought/action examples from a single trace.

        Traces whose final "score" event falls below ``min_phi`` yield
        nothing.  NOTE(review): traces with no score events at all are
        kept (treated as successful) — confirm that is intended.
        """
        # Only use successful traces.
        score_events = [e for e in trace.events if e.kind == "score"]
        if score_events:
            last = score_events[-1].data
            # Prefer the post-step score; fall back to "phi", then 0.
            final_phi = last.get("phi_after", last.get("phi", 0))
            if final_phi < self.min_phi:
                return []

        examples: list[DatasetExample] = []
        # Each action/progress event becomes one (purpose -> output) pair.
        action_events = [e for e in trace.events if e.kind in ("action", "agent.progress")]
        for event in action_events:
            output_text = event.data.get("thought", "") or event.data.get("action", "")
            if not output_text:
                continue

            tool_calls = []
            tool_name = event.data.get("name", event.data.get("tool", ""))
            if tool_name:
                tool_calls.append({"name": tool_name, "args": event.data.get("params", {})})

            examples.append(DatasetExample(
                id=f"{trace.trace_id}_{event.step}",
                input_text=trace.purpose,
                output_text=output_text[:1000],  # cap stored output length
                tool_calls=tool_calls,
                metadata={"trace_id": trace.trace_id, "step": event.step},
            ))

        return examples

    def _passes_immune(self, ex: DatasetExample) -> bool:
        """True if the example's output passes the adversarial-content scan.

        Only the first 500 chars are scanned, wrapped as a SKILL_CARD memory.
        """
        card = MemoryCard(kind=MemoryKind.SKILL_CARD, content=ex.output_text[:500])
        return scan_memory(card).passed

    @property
    def rejected_count(self) -> int:
        """Immune-scan rejections from the most recent build()."""
        return self._rejected