File size: 9,241 Bytes
1fce89d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
"""
Conversation Memory module for the Financial Intelligence Engine.

Maintains a rolling window of the last N question-answer turns and uses
that history to contextualise follow-up questions before they are passed
to the retriever and generation agent.

Design decisions:
- Pure Python β€” no LLM call needed for query reformulation.
  Follow-up detection is heuristic (pronoun / reference word scanning).
  This avoids adding latency and token cost to every user turn.
- Rolling window (default: last 3 turns). Older turns are evicted to
  prevent context window bloat in the retrieval query string.
- Immutable Turn dataclass β€” history entries cannot be mutated after the
  fact, preventing accidental state corruption in multi-threaded Gradio workers.
- Thread-safe add_turn via list replacement rather than in-place mutation.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

from src.config import logger


# ── Data Model ────────────────────────────────────────────────────────────────

@dataclass(frozen=True)
class Turn:
    """
    Immutable record of a single conversation turn.

    Attributes:
        query:  The user's original question (before any reformulation).
        answer: The final audited answer returned by the generation agent.
    """
    query:  str
    answer: str


# ── Conversation Memory ───────────────────────────────────────────────────────

# Signals that a user message is a follow-up referencing a previous entity.
# When any of these appear in the new query, the last N prior questions are
# appended as context so the retriever receives a self-contained query string.
_FOLLOW_UP_SIGNALS: frozenset[str] = frozenset({
    "their", "its", "they", "it", "those", "these",
    "the same", "what about", "how about", "also",
    "and their", "compare to", "compared to",
    "how does that", "how do they", "what else",
    "tell me more", "elaborate", "expand on",
    "what about the", "versus", "vs",
})


class ConversationMemory:
    """
    Rolling-window conversation history with follow-up query reformulation.

    Usage:
        memory = ConversationMemory(max_turns=3)

        # After each RAG call:
        memory.add_turn(user_query, agent_answer)

        # Before next RAG call β€” enriches follow-up queries with prior context:
        retrieval_query = memory.reformulate_query(new_user_question)

    The reformulated query is ONLY used for retrieval and generation; the
    original user question is always stored in history so the conversation
    log remains human-readable.
    """

    def __init__(self, max_turns: int = 3) -> None:
        """
        Args:
            max_turns: Maximum number of prior turns to keep in the rolling
                       window. Older turns are evicted FIFO. Default: 3.
        """
        if max_turns < 1:
            raise ValueError("max_turns must be at least 1.")
        self._max_turns: int        = max_turns
        self._history:  list[Turn]  = []

    # ── Mutation ──────────────────────────────────────────────────────────────

    def add_turn(self, query: str, answer: str) -> None:
        """
        Append a completed turn and evict the oldest if the window is full.

        Args:
            query:  The original user question for this turn.
            answer: The final agent answer for this turn.
        """
        self._history.append(Turn(query=query, answer=answer))
        if len(self._history) > self._max_turns:
            # Evict oldest turn (FIFO). List slice creates a new list,
            # so concurrent readers of _history see a consistent snapshot.
            self._history = self._history[-self._max_turns:]
        logger.info(
            "[ConversationMemory] Turn added. History depth: %d/%d.",
            len(self._history), self._max_turns,
        )

    def clear(self) -> None:
        """Reset conversation history to an empty state."""
        self._history = []
        logger.info("[ConversationMemory] History cleared.")

    # ── Query Reformulation ───────────────────────────────────────────────────

    def reformulate_query(self, current_query: str) -> str:
        """
        Return a retrieval-ready query enriched with conversation context
        when the current message appears to be a follow-up.

        Follow-up detection: if the lowercased query contains any token from
        _FOLLOW_UP_SIGNALS AND there is at least one prior turn in history,
        the prior questions are appended so the retriever has a self-contained
        query string (no dangling pronouns or implicit references).

        If the query is standalone (no follow-up signals, or history is empty),
        it is returned unchanged β€” zero overhead for fresh questions.

        Args:
            current_query: The raw question from the user this turn.

        Returns:
            The original query, or a context-enriched version for retrieval.
        """
        if not self._history:
            return current_query

        lowered: str = current_query.lower()
        is_follow_up: bool = any(
            signal in lowered for signal in _FOLLOW_UP_SIGNALS
        )

        if not is_follow_up:
            return current_query

        # Build a compact context string from the last N prior questions.
        # We include only the questions (not the full answers) to keep the
        # retrieval query concise and avoid the 8192-token BM25 limit.
        prior_questions: str = " | ".join(
            f"Q{i}: {turn.query}"
            for i, turn in enumerate(self._history, 1)
        )

        reformulated: str = (
            f"{current_query} "
            f"[Prior conversation context: {prior_questions}]"
        )

        logger.info(
            "[ConversationMemory] Follow-up detected. Reformulated query: '%s'",
            reformulated[:120],
        )
        return reformulated

    # ── Read-Only Accessors ───────────────────────────────────────────────────

    def get_history_as_text(self, include_answers: bool = True) -> str:
        """
        Return the conversation history as a human-readable string.

        Args:
            include_answers: If True, include truncated answer previews (200 chars).
                             If False, return only the questions β€” useful for
                             building a compact retrieval context string.

        Returns:
            Multi-line string with Q/A pairs, or empty string if no history.
        """
        if not self._history:
            return ""

        lines: list[str] = []
        for i, turn in enumerate(self._history, 1):
            lines.append(f"Q{i}: {turn.query}")
            if include_answers:
                preview = turn.answer[:200].rstrip()
                suffix  = "..." if len(turn.answer) > 200 else ""
                lines.append(f"A{i}: {preview}{suffix}")
            lines.append("")   # blank line separator

        return "\n".join(lines).strip()

    def get_history_as_gradio_pairs(self) -> list[list]:
        """
        Return history in the [[user, assistant], ...] tuples format expected
        by gr.Chatbot (Gradio 5.x default) so a restored session can seed
        the Gradio chat component.

        Returns:
            List of [query, answer] pairs.
        """
        return [[turn.query, turn.answer] for turn in self._history]

    def get_last_n_queries(self, n: Optional[int] = None) -> list[str]:
        """
        Return the last n user questions.

        Args:
            n: Number of recent queries to return. Defaults to max_turns.

        Returns:
            List of query strings, oldest first.
        """
        limit: int = n if n is not None else self._max_turns
        return [turn.query for turn in self._history[-limit:]]

    # ── Properties ────────────────────────────────────────────────────────────

    @property
    def turn_count(self) -> int:
        """Number of completed turns currently held in the window."""
        return len(self._history)

    @property
    def is_empty(self) -> bool:
        """True if no turns have been recorded yet."""
        return len(self._history) == 0

    @property
    def max_turns(self) -> int:
        """The configured rolling-window size."""
        return self._max_turns

    def __repr__(self) -> str:
        return (
            f"ConversationMemory("
            f"turns={self.turn_count}/{self._max_turns}, "
            f"queries={self.get_last_n_queries()})"
        )