"""
Generation Agent with Chain-of-Thought (CoT) Reasoning and Self-Correction.

UPGRADES vs previous version:
- Retry logic via tenacity: all LLM calls automatically retry up to
  MAX_API_RETRIES times with exponential backoff. Prevents silent failures
  under Groq rate limits or transient network errors.
- Explicit request_timeout on ChatGroq to surface hung calls instead of
  waiting indefinitely.
- All prompts, logic, and dual-LLM (generator + critic) architecture
  are fully preserved from the original.
- Full type annotations on all public methods.
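
Illustrative sketch of the retry-with-backoff behaviour described above,
written against the standard library only so it runs without tenacity. It is
not tenacity's implementation: call_with_backoff and its parameters are
hypothetical names, and the default waits are placeholders rather than the
values defined in src.config.

```python
import time
from typing import Callable, TypeVar

T = TypeVar("T")


def call_with_backoff(
    fn: Callable[[], T],
    max_attempts: int = 3,
    min_wait: float = 2.0,
    max_wait: float = 10.0,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    # Call fn up to max_attempts times, doubling the wait after each failure.
    wait = min_wait
    for attempt in range(1, max_attempts + 1):
        try:
            return fn()
        except Exception:
            if attempt == max_attempts:
                raise  # out of attempts: surface the last error to the caller
            sleep(wait)
            wait = min(wait * 2, max_wait)  # never wait longer than max_wait
    raise ValueError("max_attempts must be at least 1")
```

Passing a recording function as sleep, for example
call_with_backoff(fn, sleep=waits.append), lets a test inspect the wait
schedule without actually sleeping.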

NEW: Query Decomposition (generate_answer_decomposed):
- Decomposes complex multi-part questions into 2–4 focused sub-queries.
- Retrieves context independently for each sub-query.
- Deduplicates and merges all retrieved chunks by chunk_id.
- Synthesises sub-results into a single coherent cited answer via a
  dedicated synthesis prompt, then audits with the existing compliance
  auditor for hallucination removal.
- Returns (final_answer, all_source_docs, sub_queries) so callers can
  surface the decomposition reasoning chain in a UI.
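
Illustrative sketch of the first three decomposition steps listed above:
parsing the model's JSON output, retrieving per sub-query, and merging by
chunk_id. It is a simplified stand-in, not the code in this module: chunks
are plain dicts instead of Document objects, retrieve is a hypothetical
callable, and falling back on an empty array is a choice made only in this
sketch.

```python
import json
from typing import Callable

FENCE = "`" * 3  # built at runtime so the example contains no literal fence


def parse_sub_queries(raw: str, question: str) -> list[str]:
    # Parse a JSON array of strings; fall back to [question] on bad output.
    text = raw.strip()
    if text.startswith(FENCE):
        # Drop a surrounding code fence and an optional "json" language tag.
        text = text.split(FENCE)[1]
        if text.startswith("json"):
            text = text[4:]
        text = text.strip()
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        return [question]
    if isinstance(parsed, list) and parsed and all(
        isinstance(q, str) for q in parsed
    ):
        return parsed
    return [question]


def retrieve_and_merge(
    sub_queries: list[str],
    retrieve: Callable[[str], list[dict]],
) -> list[dict]:
    # Keep one chunk per chunk_id, in first-seen order.
    merged: dict[str, dict] = {}
    for sub_query in sub_queries:
        for chunk in retrieve(sub_query):
            merged.setdefault(chunk["chunk_id"], chunk)
    return list(merged.values())
```

A chunk returned by several sub-queries appears only once in the merged
list, so the synthesis step sees each passage a single time.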
"""

import json

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.documents import Document
from langchain_groq import ChatGroq
from tenacity import retry, stop_after_attempt, wait_exponential

from src.config import (
    logger,
    GENERATOR_MODEL,
    MAX_API_RETRIES,
    API_RETRY_MIN_WAIT,
    API_RETRY_MAX_WAIT,
)


# ── Retry Decorator ───────────────────────────────────────────────────────────
# Applied to every LLM call. Retries on any Exception (covers rate limits,
# timeouts, and transient network errors from Groq).
# wait_exponential: 2s → 4s → 8s → cap 10s between attempts.
_llm_retry = retry(
    stop=stop_after_attempt(MAX_API_RETRIES),
    wait=wait_exponential(
        multiplier=1,
        min=API_RETRY_MIN_WAIT,
        max=API_RETRY_MAX_WAIT,
    ),
    reraise=True,  # if all retries fail, re-raise the original exception
)


# ── Query Decomposition Prompt ────────────────────────────────────────────────
# Asks the LLM to split a complex question into 2–4 atomic sub-queries, each
# independently answerable via a single targeted retrieval pass.
# Simple / single-focus questions are returned unchanged as a 1-element array.
_DECOMPOSE_PROMPT = ChatPromptTemplate.from_messages([
    ("system", """You are a financial research query analyst specialising in SEC 10-K filings.

Your task: analyse the user's question and decide whether it requires decomposition.

Rules:
1. If the question asks about multiple companies, multiple metrics, or multiple time periods, decompose it into 2–4 focused sub-queries, one per distinct retrieval need.
2. If the question is already atomic and focused (single fact, single company), return it unchanged as a 1-element array.
3. Every sub-query must be fully self-contained (include the company name, metric name, and fiscal year if inferable from the original).
4. You MUST respond with ONLY a valid JSON array of strings. No explanation, no preamble, no markdown code fences.

Example (complex):
Input:  "Compare Google and Meta's R&D spending and headcount trends"
Output: ["What were Google's total R&D expenses and year-over-year change?", "What were Meta's total R&D expenses and year-over-year change?", "What were Google's total headcount and hiring trends?", "What were Meta's total headcount and hiring trends?"]

Example (simple):
Input:  "What was Google's net income in FY2025?"
Output: ["What was Google's net income in FY2025?"]"""),
    ("human", "{question}"),
])


# ── Synthesis Prompt ──────────────────────────────────────────────────────────
# Takes the merged context from all sub-query retrieval passes and writes a
# single coherent comparative answer covering every decomposed sub-aspect.
_SYNTHESIS_PROMPT = ChatPromptTemplate.from_messages([
    ("system", """You are a Lead Financial Data Scientist. Multiple sub-queries were retrieved independently to answer the user's original question. Synthesise all findings into a single, coherent, professional comparative analysis.

Rules:
- Use structured bullet points with clear per-company or per-topic section headers.
- Cite the source immediately after every claim (e.g., [Source: Meta 10-K]).
- If a sub-query returned no relevant data, explicitly state the filing does not provide that information.
- Do NOT invent or infer any numbers absent from the context.

Sub-queries that were researched:
{sub_queries_formatted}

Combined Retrieved Context (all sub-queries merged and deduplicated):
{context}"""),
    ("human", "Original question: {question}\n\nSynthesize a complete, cited answer covering every sub-aspect above."),
])


class FinancialGenerationAgent:
    """
    Two-stage LLM pipeline: Chain-of-Thought generator → Compliance auditor.

    Stage 1 (Generator): Llama-3.3-70B reasons step-by-step using retrieved
                         context, then writes a structured comparative analysis.
    Stage 2 (Auditor):   The same model reviews the draft as an SEC Compliance
                         Auditor, removing any claim not grounded in the context.

    Both stages use the same GENERATOR_MODEL intentionally: the auditor prompt
    role-plays a distinct persona (compliance reviewer vs. data scientist),
    creating adversarial tension within a single model family. The evaluator
    in evaluation.py uses a DIFFERENT model family to avoid circular bias.
    """

    def __init__(self, retriever, api_key: str) -> None:
        self.retriever = retriever
        self.llm = ChatGroq(
            model=GENERATOR_MODEL,
            temperature=0,
            api_key=api_key,
            request_timeout=60,   # surface hung calls; don't wait indefinitely
        )

        # ── Generator Prompt (CoT Analysis) ──────────────────────────────────
        self.qa_prompt = ChatPromptTemplate.from_messages([
            ("system", """You are a Lead Financial Data Scientist. Your objective is to answer the user's query using ONLY the provided SEC 10-K context.

You MUST structure your response strictly in two parts:

<thought_process>
1. Extract the raw facts from the context for each company.
2. Identify the strategic overlaps and stark differences.
3. Note any data requested by the prompt that is explicitly missing from the context.
</thought_process>

<final_answer>
Synthesize your findings into a professional, comparative analysis. 
- Use structured bullet points.
- You MUST cite the source immediately after every claim (e.g., [Source: Meta 10-K]).
- If data is missing (e.g., specific 2025 budgets), explicitly state that the filings do not provide forward-looking numbers for that year.
</final_answer>

Context:
{context}"""),
            ("human", "{question}"),
        ])

        # ── Auditor Prompt (Compliance Review) ───────────────────────────────
        self.critic_prompt = ChatPromptTemplate.from_messages([
            ("system", """You are an SEC Compliance Auditor. Review the <final_answer> section of the Draft Answer against the Source Context.
If the draft contains ANY numbers, metrics, or claims not explicitly present in the context, rewrite it to remove them.
Ensure every claim has a citation.

Original User Question:
{question}

Source Context:
{context}

Draft Answer:
{draft}"""),
            ("human", "Audit this draft. Output ONLY the finalized, hallucination-free <final_answer> text. Do not include the <thought_process>."),
        ])

    # ── Internal: Retryable LLM Calls ────────────────────────────────────────
    @_llm_retry
    def _invoke_generator(self, context: str, question: str) -> str:
        """
        Stage 1: Generate a CoT draft answer. Retried on failure.

        Args:
            context:  Concatenated retrieved document chunks.
            question: The user's original query.

        Returns:
            Raw LLM response string including <thought_process> and <final_answer>.
        """
        chain = self.qa_prompt | self.llm
        return chain.invoke({"context": context, "question": question}).content

    @_llm_retry
    def _invoke_auditor(self, context: str, question: str, draft: str) -> str:
        """
        Stage 2: Audit the draft for hallucinations. Retried on failure.

        Args:
            context:  Same retrieved context passed to the generator.
            question: The user's original query.
            draft:    The raw generator output to be audited.

        Returns:
            Final hallucination-free <final_answer> text.
        """
        chain = self.critic_prompt | self.llm
        return chain.invoke({
            "context":  context,
            "draft":    draft,
            "question": question,
        }).content

    # ── Public: Full Answer Generation ───────────────────────────────────────
    def generate_answer(self, query: str) -> tuple[str, list[Document]]:
        """
        Execute the full two-stage RAG generation pipeline.

        Args:
            query: Natural language question from the user.

        Returns:
            Tuple of (final_answer: str, source_docs: list[Document]).
            final_answer is the audited, hallucination-checked response.
            source_docs are the exact chunks used to construct the answer.

        Raises:
            Exception: Propagates after MAX_API_RETRIES exhausted on LLM calls.
        """
        logger.info("Retrieving documents for query: '%s'", query)
        docs: list[Document] = self.retriever.invoke(query)

        if not docs:
            logger.warning("Retriever returned zero documents for query: '%s'", query)

        # Build context string with explicit company attribution per chunk.
        context: str = "\n\n".join([
            f"[Source: {d.metadata.get('company', 'Unknown')} 10-K]: {d.page_content}"
            for d in docs
        ])

        logger.info("Step 1: Executing Chain-of-Thought Analysis (%s)...", GENERATOR_MODEL)
        draft_response: str = self._invoke_generator(context, query)

        logger.info("Step 2: Running Strict Compliance Audit (%s)...", GENERATOR_MODEL)
        final_response: str = self._invoke_auditor(context, query, draft_response)

        return final_response, docs

    # ── Internal: Query Decomposition Helpers ─────────────────────────────────

    @_llm_retry
    def _invoke_decomposer(self, question: str) -> list[str]:
        """
        Decompose a complex question into focused sub-queries via LLM.

        The LLM is instructed to return ONLY a JSON array of strings.
        Falls back to the original question as a single-element list on any
        parse failure, so the caller always receives a valid list[str].

        Args:
            question: The original user question.

        Returns:
            List of 1–4 sub-query strings.
        """
        chain = _DECOMPOSE_PROMPT | self.llm
        raw: str = chain.invoke({"question": question}).content.strip()

        # Strip accidental markdown code fences if the model adds them.
        if raw.startswith("```"):
            raw = raw.split("```")[1]
            if raw.startswith("json"):
                raw = raw[4:]
            raw = raw.strip()

        try:
            sub_queries: list[str] = json.loads(raw)
            if not isinstance(sub_queries, list) or not all(
                isinstance(q, str) for q in sub_queries
            ):
                raise ValueError("Parsed value is not a list of strings.")
            logger.info(
                "Query decomposed into %d sub-queries: %s",
                len(sub_queries), sub_queries,
            )
            return sub_queries
        except (json.JSONDecodeError, ValueError) as exc:
            logger.warning(
                "Decomposer returned unparseable output (%s). "
                "Falling back to original question as single sub-query.", exc
            )
            return [question]

    @_llm_retry
    def _invoke_synthesizer(
        self, sub_queries: list[str], context: str, question: str
    ) -> str:
        """
        Synthesise answers from multiple sub-query retrievals into one response.

        Args:
            sub_queries: The list of decomposed sub-queries (for prompt context).
            context:     Merged, deduplicated context from all sub-query retrievals.
            question:    The original user question for the synthesis prompt.

        Returns:
            Raw synthesised answer string (before compliance audit).
        """
        sub_queries_formatted: str = "\n".join(
            f"  {i}. {q}" for i, q in enumerate(sub_queries, 1)
        )
        chain = _SYNTHESIS_PROMPT | self.llm
        return chain.invoke({
            "sub_queries_formatted": sub_queries_formatted,
            "context":               context,
            "question":              question,
        }).content

    # ── Public: Decomposed Answer Generation ─────────────────────────────────
    def generate_answer_decomposed(
        self, query: str
    ) -> tuple[str, list[Document], list[str]]:
        """
        Execute the full query-decomposition RAG pipeline.

        Pipeline:
          1. Decompose the query into 1–4 focused sub-queries (LLM).
          2. Retrieve documents independently for each sub-query.
          3. Merge and deduplicate all retrieved chunks by chunk_id.
          4. Synthesise the merged context into a single comparative answer (LLM).
          5. Audit the synthesised answer for hallucinations (existing auditor).

        For simple questions the decomposer returns a 1-element list, so the
        retrieval, synthesis, and audit steps each run once. The only extra
        cost over generate_answer is the decomposition call itself.

        Args:
            query: Natural language question from the user.

        Returns:
            Tuple of:
              - final_answer (str):        Audited, hallucination-free response.
              - all_source_docs (list):    Deduplicated chunks used across all sub-queries.
              - sub_queries (list[str]):   The decomposed sub-queries for UI display.

        Raises:
            Exception: Propagates after MAX_API_RETRIES exhausted on any LLM call.
        """
        logger.info(
            "[Decomposition] Decomposing query: '%s'", query
        )
        sub_queries: list[str] = self._invoke_decomposer(query)

        # Retrieve context for each sub-query independently and merge.
        doc_map: dict[str, Document] = {}   # chunk_id → Document (deduplication)

        for i, sub_query in enumerate(sub_queries, 1):
            logger.info(
                "[Decomposition] Sub-query %d/%d retrieval: '%s'",
                i, len(sub_queries), sub_query,
            )
            docs: list[Document] = self.retriever.invoke(sub_query)
            for doc in docs:
                chunk_id: str = doc.metadata.get(
                    "chunk_id",
                    # Fallback key if chunk_id is absent (should not occur in
                    # production). It depends only on the content, so the same
                    # chunk returned by two sub-queries still deduplicates.
                    f"fallback_{hash(doc.page_content)}",
                )
                doc_map[chunk_id] = doc   # later sub-queries overwrite on collision; content is identical

        all_docs: list[Document] = list(doc_map.values())

        if not all_docs:
            logger.warning(
                "[Decomposition] All sub-query retrievals returned zero documents."
            )

        # Build merged context string with explicit company attribution.
        merged_context: str = "\n\n".join([
            f"[Source: {d.metadata.get('company', 'Unknown')} 10-K]: {d.page_content}"
            for d in all_docs
        ])

        logger.info(
            "[Decomposition] Synthesis over %d deduplicated chunks (%d sub-queries).",
            len(all_docs), len(sub_queries),
        )
        synthesised_draft: str = self._invoke_synthesizer(
            sub_queries, merged_context, query
        )

        logger.info(
            "[Decomposition] Running compliance audit on synthesised answer..."
        )
        final_response: str = self._invoke_auditor(
            merged_context, query, synthesised_draft
        )

        return final_response, all_docs, sub_queries