"""
Generation Agent with Chain-of-Thought (CoT) Reasoning and Self-Correction.

UPGRADES vs previous version:
- Retry logic via tenacity: all LLM calls automatically retry up to
  MAX_API_RETRIES times with exponential backoff. Prevents silent failures
  under Groq rate limits or transient network errors.
- Explicit request_timeout on ChatGroq to surface hung calls instead of
  waiting indefinitely.
- All prompts, logic, and dual-LLM (generator + critic) architecture
  are fully preserved from the original.
- Full type annotations on all public methods.

NEW – Query Decomposition (generate_answer_decomposed):
- Decomposes complex multi-part questions into 2–4 focused sub-queries.
- Retrieves context independently for each sub-query.
- Deduplicates and merges all retrieved chunks by chunk_id.
- Synthesises sub-results into a single coherent cited answer via a
  dedicated synthesis prompt, then audits with the existing compliance
  auditor for hallucination removal.
- Returns (final_answer, all_source_docs, sub_queries) so callers can
  surface the decomposition reasoning chain in a UI.
"""

import json

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.documents import Document
from langchain_groq import ChatGroq
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type

from src.config import (
    logger,
    GENERATOR_MODEL,
    MAX_API_RETRIES,
    API_RETRY_MIN_WAIT,
    API_RETRY_MAX_WAIT,
)

# ── Retry Decorator ────────────────────────────────────────────────────────────
# Applied to every LLM call. Retries on any Exception (covers rate limits,
# timeouts, and transient network errors from Groq).
# wait_exponential: 2s → 4s → 8s → cap 10s between attempts.
_llm_retry = retry(
    stop=stop_after_attempt(MAX_API_RETRIES),
    wait=wait_exponential(
        multiplier=1,
        min=API_RETRY_MIN_WAIT,
        max=API_RETRY_MAX_WAIT,
    ),
    reraise=True,  # if all retries fail, re-raise the original exception
)
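
# For reference, the retry-related values imported from src.config are assumed
# to look roughly like the illustrative defaults below (the actual file may
# differ); with multiplier=1, min=2, and max=10, tenacity's wait_exponential
# produces the 2s → 4s → 8s schedule capped at 10s described above.
#
#     MAX_API_RETRIES = 5      # attempts per LLM call before re-raising
#     API_RETRY_MIN_WAIT = 2   # seconds, lower bound for exponential backoff
#     API_RETRY_MAX_WAIT = 10  # seconds, upper bound (the "cap 10s" above)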

# ── Query Decomposition Prompt ─────────────────────────────────────────────────
# Asks the LLM to split a complex question into 2–4 atomic sub-queries, each
# independently answerable via a single targeted retrieval pass.
# Simple / single-focus questions are returned unchanged as a 1-element array.
_DECOMPOSE_PROMPT = ChatPromptTemplate.from_messages([
    ("system", """You are a financial research query analyst specialising in SEC 10-K filings.
Your task: analyse the user's question and decide whether it requires decomposition.
Rules:
1. If the question asks about multiple companies, multiple metrics, or multiple time periods, decompose it into 2–4 focused sub-queries – one per distinct retrieval need.
2. If the question is already atomic and focused (single fact, single company), return it unchanged as a 1-element array.
3. Every sub-query must be fully self-contained (include the company name, metric name, and fiscal year if inferable from the original).
4. You MUST respond with ONLY a valid JSON array of strings. No explanation, no preamble, no markdown code fences.
Example – complex:
Input: "Compare Google and Meta's R&D spending and headcount trends"
Output: ["What were Google's total R&D expenses and year-over-year change?", "What were Meta's total R&D expenses and year-over-year change?", "What were Google's total headcount and hiring trends?", "What were Meta's total headcount and hiring trends?"]
Example – simple:
Input: "What was Google's net income in FY2025?"
Output: ["What was Google's net income in FY2025?"]"""),
    ("human", "{question}"),
])

# ── Synthesis Prompt ───────────────────────────────────────────────────────────
# Takes the merged context from all sub-query retrieval passes and writes a
# single coherent comparative answer covering every decomposed sub-aspect.
_SYNTHESIS_PROMPT = ChatPromptTemplate.from_messages([
    ("system", """You are a Lead Financial Data Scientist. Multiple sub-queries were retrieved independently to answer the user's original question. Synthesise all findings into a single, coherent, professional comparative analysis.
Rules:
- Use structured bullet points with clear per-company or per-topic section headers.
- Cite the source immediately after every claim (e.g., [Source: Meta 10-K]).
- If a sub-query returned no relevant data, explicitly state the filing does not provide that information.
- Do NOT invent or infer any numbers absent from the context.
Sub-queries that were researched:
{sub_queries_formatted}
Combined Retrieved Context (all sub-queries merged and deduplicated):
{context}"""),
    ("human", "Original question: {question}\n\nSynthesize a complete, cited answer covering every sub-aspect above."),
])


class FinancialGenerationAgent:
    """
    Two-stage LLM pipeline: Chain-of-Thought generator → Compliance auditor.

    Stage 1 (Generator): Llama-3.3-70B reasons step-by-step using retrieved
    context, then writes a structured comparative analysis.
    Stage 2 (Auditor): The same model reviews the draft as an SEC Compliance
    Auditor – removing any claim not grounded in the context.

    Both stages use the same GENERATOR_MODEL intentionally: the auditor prompt
    role-plays a distinct persona (compliance reviewer vs. data scientist),
    creating adversarial tension within a single model family. The evaluator
    in evaluation.py uses a DIFFERENT model family to avoid circular bias.
    """

    def __init__(self, retriever, api_key: str) -> None:
        self.retriever = retriever
        self.llm = ChatGroq(
            model=GENERATOR_MODEL,
            temperature=0,
            api_key=api_key,
            request_timeout=60,  # surface hung calls; don't wait indefinitely
        )

        # ── Generator Prompt (CoT Analysis) ────────────────────────────────────
        self.qa_prompt = ChatPromptTemplate.from_messages([
            ("system", """You are a Lead Financial Data Scientist. Your objective is to answer the user's query using ONLY the provided SEC 10-K context.
You MUST structure your response strictly in two parts:
<thought_process>
1. Extract the raw facts from the context for each company.
2. Identify the strategic overlaps and stark differences.
3. Note any data requested by the prompt that is explicitly missing from the context.
</thought_process>
<final_answer>
Synthesize your findings into a professional, comparative analysis.
- Use structured bullet points.
- You MUST cite the source immediately after every claim (e.g., [Source: Meta 10-K]).
- If data is missing (e.g., specific 2025 budgets), explicitly state that the filings do not provide forward-looking numbers for that year.
</final_answer>
Context:
{context}"""),
            ("human", "{question}"),
        ])

        # ── Auditor Prompt (Compliance Review) ──────────────────────────────────
        self.critic_prompt = ChatPromptTemplate.from_messages([
            ("system", """You are an SEC Compliance Auditor. Review the <final_answer> section of the Draft Answer against the Source Context.
If the draft contains ANY numbers, metrics, or claims not explicitly present in the context, rewrite it to remove them.
Ensure every claim has a citation.
Original User Question:
{question}
Source Context:
{context}
Draft Answer:
{draft}"""),
            ("human", "Audit this draft. Output ONLY the finalized, hallucination-free <final_answer> text. Do not include the <thought_process>."),
        ])

    # ── Internal: Retryable LLM Calls ───────────────────────────────────────────
    @_llm_retry
    def _invoke_generator(self, context: str, question: str) -> str:
        """
        Stage 1: Generate a CoT draft answer. Retried on failure.

        Args:
            context: Concatenated retrieved document chunks.
            question: The user's original query.

        Returns:
            Raw LLM response string including <thought_process> and <final_answer>.
        """
        chain = self.qa_prompt | self.llm
        return chain.invoke({"context": context, "question": question}).content

    @_llm_retry
    def _invoke_auditor(self, context: str, question: str, draft: str) -> str:
        """
        Stage 2: Audit the draft for hallucinations. Retried on failure.

        Args:
            context: Same retrieved context passed to the generator.
            question: The user's original query.
            draft: The raw generator output to be audited.

        Returns:
            Final hallucination-free <final_answer> text.
        """
        chain = self.critic_prompt | self.llm
        return chain.invoke({
            "context": context,
            "draft": draft,
            "question": question,
        }).content

    # ── Public: Full Answer Generation ──────────────────────────────────────────
    def generate_answer(self, query: str) -> tuple[str, list[Document]]:
        """
        Execute the full two-stage RAG generation pipeline.

        Args:
            query: Natural language question from the user.

        Returns:
            Tuple of (final_answer: str, source_docs: list[Document]).
            final_answer is the audited, hallucination-checked response.
            source_docs are the exact chunks used to construct the answer.

        Raises:
            Exception: Propagates after MAX_API_RETRIES exhausted on LLM calls.
        """
        logger.info("Retrieving documents for query: '%s'", query)
        docs: list[Document] = self.retriever.invoke(query)
        if not docs:
            logger.warning("Retriever returned zero documents for query: '%s'", query)

        # Build context string with explicit company attribution per chunk.
        context: str = "\n\n".join([
            f"[Source: {d.metadata.get('company', 'Unknown')} 10-K]: {d.page_content}"
            for d in docs
        ])

        logger.info("Step 1: Executing Chain-of-Thought Analysis (%s)...", GENERATOR_MODEL)
        draft_response: str = self._invoke_generator(context, query)

        logger.info("Step 2: Running Strict Compliance Audit (%s)...", GENERATOR_MODEL)
        final_response: str = self._invoke_auditor(context, query, draft_response)

        return final_response, docs

    # ── Internal: Query Decomposition Helpers ───────────────────────────────────
    @_llm_retry
    def _invoke_decomposer(self, question: str) -> list[str]:
        """
        Decompose a complex question into focused sub-queries via LLM.

        The LLM is instructed to return ONLY a JSON array of strings.
        Falls back to the original question as a single-element list on any
        parse failure, so the caller always receives a valid list[str].

        Args:
            question: The original user question.

        Returns:
            List of 1–4 sub-query strings.
        """
        chain = _DECOMPOSE_PROMPT | self.llm
        raw: str = chain.invoke({"question": question}).content.strip()

        # Strip accidental markdown code fences if the model adds them.
        if raw.startswith("```"):
            raw = raw.split("```")[1]
            if raw.startswith("json"):
                raw = raw[4:]
            raw = raw.strip()

        try:
            sub_queries: list[str] = json.loads(raw)
            if not isinstance(sub_queries, list) or not all(
                isinstance(q, str) for q in sub_queries
            ):
                raise ValueError("Parsed value is not a list of strings.")
            logger.info(
                "Query decomposed into %d sub-queries: %s",
                len(sub_queries), sub_queries,
            )
            return sub_queries
        except (json.JSONDecodeError, ValueError) as exc:
            logger.warning(
                "Decomposer returned unparseable output (%s). "
                "Falling back to original question as single sub-query.", exc
            )
            return [question]

    @_llm_retry
    def _invoke_synthesizer(
        self, sub_queries: list[str], context: str, question: str
    ) -> str:
        """
        Synthesise answers from multiple sub-query retrievals into one response.

        Args:
            sub_queries: The list of decomposed sub-queries (for prompt context).
            context: Merged, deduplicated context from all sub-query retrievals.
            question: The original user question for the synthesis prompt.

        Returns:
            Raw synthesised answer string (before compliance audit).
        """
        sub_queries_formatted: str = "\n".join(
            f" {i}. {q}" for i, q in enumerate(sub_queries, 1)
        )
        chain = _SYNTHESIS_PROMPT | self.llm
        return chain.invoke({
            "sub_queries_formatted": sub_queries_formatted,
            "context": context,
            "question": question,
        }).content

    # ── Public: Decomposed Answer Generation ────────────────────────────────────
    def generate_answer_decomposed(
        self, query: str
    ) -> tuple[str, list[Document], list[str]]:
        """
        Execute the full query-decomposition RAG pipeline.

        Pipeline:
            1. Decompose the query into 1–4 focused sub-queries (LLM).
            2. Retrieve documents independently for each sub-query.
            3. Merge and deduplicate all retrieved chunks by chunk_id.
            4. Synthesise the merged context into a single comparative answer (LLM).
            5. Audit the synthesised answer for hallucinations (existing auditor).

        For simple questions the decomposer returns a 1-element list, making
        this method functionally equivalent to generate_answer – no wasted calls.

        Args:
            query: Natural language question from the user.

        Returns:
            Tuple of:
            - final_answer (str): Audited, hallucination-free response.
            - all_source_docs (list): Deduplicated chunks used across all sub-queries.
            - sub_queries (list[str]): The decomposed sub-queries for UI display.

        Raises:
            Exception: Propagates after MAX_API_RETRIES exhausted on any LLM call.
        """
        logger.info("[Decomposition] Decomposing query: '%s'", query)
        sub_queries: list[str] = self._invoke_decomposer(query)

        # Retrieve context for each sub-query independently and merge.
        doc_map: dict[str, Document] = {}  # chunk_id → Document (deduplication)
        for i, sub_query in enumerate(sub_queries, 1):
            logger.info(
                "[Decomposition] Sub-query %d/%d retrieval: '%s'",
                i, len(sub_queries), sub_query,
            )
            docs: list[Document] = self.retriever.invoke(sub_query)
            for doc in docs:
                chunk_id: str = doc.metadata.get(
                    "chunk_id",
                    # Fallback hash if chunk_id absent (should not occur in production).
                    f"fallback_{i}_{hash(doc.page_content)}",
                )
                # Later sub-queries overwrite on collision; content is identical.
                doc_map[chunk_id] = doc

        all_docs: list[Document] = list(doc_map.values())
        if not all_docs:
            logger.warning(
                "[Decomposition] All sub-query retrievals returned zero documents."
            )

        # Build merged context string with explicit company attribution.
        merged_context: str = "\n\n".join([
            f"[Source: {d.metadata.get('company', 'Unknown')} 10-K]: {d.page_content}"
            for d in all_docs
        ])

        logger.info(
            "[Decomposition] Synthesis over %d deduplicated chunks (%d sub-queries).",
            len(all_docs), len(sub_queries),
        )
        synthesised_draft: str = self._invoke_synthesizer(
            sub_queries, merged_context, query
        )

        logger.info("[Decomposition] Running compliance audit on synthesised answer...")
        final_response: str = self._invoke_auditor(
            merged_context, query, synthesised_draft
        )
        return final_response, all_docs, sub_queries
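

# ── Usage Sketch (illustrative only) ─────────────────────────────────────────────
# A minimal, hedged sketch of how the agent is wired up, assuming a Groq API key
# in a GROQ_API_KEY environment variable (assumed name) and any retriever object
# exposing .invoke(query) -> list[Document]. The stub retriever below is purely
# hypothetical and exists only to make the call pattern visible end-to-end.
if __name__ == "__main__":
    import os

    class _StubRetriever:
        """Duck-typed stand-in: anything with .invoke(query) -> list[Document] works."""

        def invoke(self, query: str) -> list[Document]:
            # Placeholder chunk; a real deployment returns embedded 10-K chunks.
            return [
                Document(
                    page_content="(placeholder 10-K excerpt)",
                    metadata={"company": "Meta", "chunk_id": "stub-0001"},
                )
            ]

    agent = FinancialGenerationAgent(
        retriever=_StubRetriever(),
        api_key=os.environ["GROQ_API_KEY"],
    )
    answer, sources, sub_queries = agent.generate_answer_decomposed(
        "Compare Google and Meta's R&D spending and headcount trends"
    )
    print("Sub-queries:", sub_queries)
    print(answer)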