File size: 5,552 Bytes
1fa0c29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
"""
compiler.py — Prompt compiler with token budget enforcement and credit assignment.

The PromptCompiler selects which memories to include in the prompt based on:
  1. Relevance to the current task
  2. Trust score (immune-scanned, tested)
  3. Utility score (has this memory actually helped before?)
  4. Scope match (right agent, right tools, right task category)
  5. Diversity (avoid redundant memories)
  6. Token cost (fit within budget)

Returns included_memory_ids so the orchestrator can do credit assignment
after the step: update utility scores for memories that were in context.
"""
from __future__ import annotations

import logging
from dataclasses import dataclass, field
from typing import Any

from purpose_agent.memory import MemoryCard, MemoryKind, MemoryStatus, MemoryStore
from purpose_agent.v2_types import MemoryScope

logger = logging.getLogger(__name__)


@dataclass
class CompiledPrompt:
    """Outcome of a single prompt compilation pass.

    Holds the ordered prompt sections together with the bookkeeping numbers
    produced while filling the prompt (token estimate, leftover budget, and
    how many memories were looked at versus actually included).
    """

    # Ordered text sections; ``system_prompt`` joins them into one string.
    system_sections: list[str] = field(default_factory=list)
    # Ids of the memories whose text made it into the prompt.
    included_memory_ids: list[str] = field(default_factory=list)
    # Estimated token count of the compiled prompt.
    total_tokens_estimated: int = 0
    # Estimated tokens still unused out of the budget.
    budget_remaining: int = 0
    # Number of memories retrieved as candidates.
    memories_considered: int = 0
    # Number of memories that ended up in the prompt.
    memories_included: int = 0

    @property
    def system_prompt(self) -> str:
        """Return every section concatenated, separated by one blank line."""
        section_break = "\n\n"
        return section_break.join(self.system_sections)


class PromptCompiler:
    """
    Compiles a prompt by selecting the best memories under a token budget.

    The key invariant: only promoted memories are included.
    Candidate/quarantined/rejected memories are never exposed to the LLM.

    Token counts are rough estimates computed as
    ``len(text) // chars_per_token``.
    """

    # If no more than this many estimated tokens are left after the base
    # prompt, memory retrieval is skipped entirely.
    MIN_MEMORY_BUDGET_TOKENS = 100

    # Number of leading characters compared by the near-duplicate filter.
    DEDUP_PREFIX_CHARS = 50

    def __init__(
        self,
        memory_store: MemoryStore,
        token_budget: int = 4096,
        chars_per_token: int = 4,
    ):
        """
        Args:
            memory_store: Store queried for candidate memories via ``retrieve``.
            token_budget: Estimated-token ceiling for the whole prompt.
            chars_per_token: Divisor used to turn character counts into
                token estimates.

        Raises:
            ValueError: If ``chars_per_token`` is not positive (it is used
                as a divisor, so zero would fail later inside ``compile``).
        """
        if chars_per_token <= 0:
            raise ValueError(
                f"chars_per_token must be a positive number, got {chars_per_token!r}"
            )
        self.store = memory_store
        self.token_budget = token_budget
        self.chars_per_token = chars_per_token

    def compile(
        self,
        task: str,
        base_prompt: str,
        scope: MemoryScope | None = None,
        max_memories: int = 15,
    ) -> CompiledPrompt:
        """
        Compile a prompt: base_prompt + best memories under token budget.

        Args:
            task: Text passed to the store as the retrieval query.
            base_prompt: Always included as the first prompt section.
            scope: Optional scope forwarded to the store's ``retrieve``.
            max_memories: Upper bound on the number of memories included.

        Returns:
            CompiledPrompt with included_memory_ids for credit assignment.

        Side effects:
            Increments ``times_retrieved`` on every card that is included.
            NOTE(review): whether that increment is persisted depends on the
            store implementation, which is not visible here.
        """
        result = CompiledPrompt()
        result.system_sections.append(base_prompt)

        base_tokens = self._estimate_tokens(base_prompt)
        remaining = self.token_budget - base_tokens
        result.budget_remaining = remaining

        # Too little room left for memories to be worthwhile: return the
        # base prompt alone.
        if remaining <= self.MIN_MEMORY_BUDGET_TOKENS:
            result.total_tokens_estimated = base_tokens
            return result

        # Retrieve candidate memories (only PROMOTED)
        candidates = self.store.retrieve(
            query_text=task,
            scope=scope,
            statuses=[MemoryStatus.PROMOTED],
            top_k=max_memories * 2,  # over-fetch for diversity filtering
        )
        result.memories_considered = len(candidates)

        # Deduplicate by content similarity
        selected = self._diverse_select(candidates, max_memories)

        # Fill prompt under budget. A card that does not fit is skipped
        # rather than ending the loop, so a later, shorter card can still
        # use the leftover budget.
        # NOTE: only the memory lines themselves are charged against the
        # budget; the section header and newline separators are not.
        memory_sections = []
        for card in selected:
            text = self._format_memory(card)
            token_cost = self._estimate_tokens(text)

            if token_cost > remaining:
                continue

            memory_sections.append(text)
            result.included_memory_ids.append(card.id)
            remaining -= token_cost
            card.times_retrieved += 1

        if memory_sections:
            result.system_sections.append(
                "## Learned Knowledge\n" + "\n".join(memory_sections)
            )

        result.memories_included = len(result.included_memory_ids)
        result.budget_remaining = remaining
        result.total_tokens_estimated = self.token_budget - remaining

        return result

    def _estimate_tokens(self, text: str) -> int:
        """Estimate the token count of *text* from its character length."""
        return len(text) // self.chars_per_token

    def _format_memory(self, card: MemoryCard) -> str:
        """Format a single memory card as one bullet line for the prompt.

        The wording depends on ``card.kind``; unknown kinds fall back to a
        generic ``- [kind] text`` line.
        """
        if card.kind == MemoryKind.SKILL_CARD:
            text = f"- Skill: When {card.pattern}, do: {card.strategy}"
            if card.steps:
                # Only the first three steps are shown to limit token cost.
                text += " Steps: " + "; ".join(card.steps[:3])
        elif card.kind == MemoryKind.USER_PREFERENCE:
            text = f"- User preference: {card.content or card.strategy}"
        elif card.kind == MemoryKind.FAILURE_PATTERN:
            # Keep pattern and strategy visibly separated; concatenating
            # them directly made the two texts run into each other.
            parts = [str(part) for part in (card.pattern, card.strategy) if part]
            text = "- Avoid: " + " — ".join(parts)
        elif card.kind == MemoryKind.TOOL_POLICY:
            text = f"- Tool tip ({', '.join(card.scope.tool_names)}): {card.strategy}"
        elif card.kind == MemoryKind.PURPOSE_CONTRACT:
            text = f"- Goal constraint: {card.content or card.strategy}"
        elif card.kind == MemoryKind.CRITIC_CALIBRATION:
            text = f"- Scoring note: {card.content or card.strategy}"
        else:
            text = f"- [{card.kind.value}] {card.content or card.strategy}"
        return text

    def _dedup_key(self, card: MemoryCard) -> str:
        """Return the normalized text prefix used to spot near-duplicates.

        Uses the first non-empty of pattern, content and strategy. An empty
        string means the card has no text to compare.
        """
        text = card.pattern or card.content or card.strategy or ""
        return text[: self.DEDUP_PREFIX_CHARS].lower().strip()

    def _diverse_select(
        self, candidates: list[MemoryCard], max_n: int
    ) -> list[MemoryCard]:
        """Select diverse memories — avoid near-duplicates.

        Args:
            candidates: Cards in the order returned by the store.
            max_n: Maximum number of cards to return.

        Returns:
            At most ``max_n`` cards in their original order. When the
            candidates already fit within ``max_n`` they are returned
            unfiltered (the same list object).
        """
        if max_n <= 0:
            return []
        if len(candidates) <= max_n:
            return candidates

        selected: list[MemoryCard] = []
        seen_keys: set[str] = set()

        for card in candidates:
            # Rough dedup by text prefix. Cards with no comparable text are
            # never treated as duplicates of one another.
            key = self._dedup_key(card)
            if key:
                if key in seen_keys:
                    continue
                seen_keys.add(key)
            selected.append(card)
            if len(selected) >= max_n:
                break

        return selected