| """Verification Node - Final quality control and output formatting""" |
import re
from typing import Dict, Any

from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_groq import ChatGroq

from src.tracing import get_langfuse_callback_handler
|
|
|
|
def load_verification_prompt() -> str:
    """Return the system prompt used by the verification agent.

    The prompt is read from ``./prompts/verification_prompt.txt``. That path is
    relative to the current working directory, not to this module. The file
    contents are returned with surrounding whitespace stripped. If the file does
    not exist, a short built-in default prompt is returned instead.
    """
    default_prompt = "You are a verification agent. Ensure responses meet quality standards and format requirements."
    prompt_path = "./prompts/verification_prompt.txt"

    try:
        prompt_file = open(prompt_path, "r", encoding="utf-8")
    except FileNotFoundError:
        return default_prompt

    with prompt_file:
        contents = prompt_file.read()
    return contents.strip()
|
|
|
|
# Complete <think>...</think> / <thinking>...</thinking> reasoning blocks emitted by reasoning models.
_REASONING_BLOCK_RE = re.compile(r"<(think|thinking)>.*?</\1>", re.IGNORECASE | re.DOTALL)
# A closing reasoning tag left without its opener (e.g. when the opening tag was stripped upstream).
_REASONING_CLOSE_RE = re.compile(r"</think(?:ing)?>", re.IGNORECASE)
# Markdown bold/italic spans such as **x** or *x*. The guards leave arithmetic
# like "2*3*4" and bullet markers like "* item" untouched.
_MD_EMPHASIS_RE = re.compile(r"(?<![\w*])(\*{1,2})(?=[^\s*])(.+?)(?<=[^\s*])\1(?![\w*])")
# A list-bullet marker. "-" and "*" must be followed by whitespace so "-5" is not a bullet.
_BULLET_RE = re.compile(r"^(?:[-*]\s+|•\s*)")

_ANSWER_PREFIXES = (
    "Final Answer:", "Answer:", "The answer is:", "The final answer is:",
    "Result:", "Solution:", "Response:", "Output:", "Conclusion:",
)
# Lines that are bare reasoning markers rather than answer text.
_MARKER_LINES = frozenset({"<think>", "think", "[thinking]", "<thinking>", "[think]"})
# Opening character -> closing character for wrappers that may enclose a whole answer.
_WRAPPER_PAIRS = {'"': '"', "'": "'", "(": ")", "[": "]", "{": "}"}


def _remove_reasoning(text: str) -> str:
    """Return *text* without reasoning blocks; keep *text* unchanged if nothing else remains."""
    without_blocks = _REASONING_BLOCK_RE.sub("", text)
    # Anything before a leftover closing tag is reasoning whose opening tag is missing.
    remainder = _REASONING_CLOSE_RE.split(without_blocks)[-1]
    return remainder if remainder.strip() else text


def _strip_answer_prefixes(text: str) -> str:
    """Remove known answer labels (case-insensitive), checking each label once in order."""
    for prefix in _ANSWER_PREFIXES:
        if text[:len(prefix)].lower() == prefix.lower():
            text = text[len(prefix):].strip()
    return text


def _is_marker_line(line: str) -> bool:
    """Return True for lines that are bare reasoning markers or a single ``<tag>``."""
    lowered = line.lower()
    return lowered in _MARKER_LINES or (lowered.startswith("<") and lowered.endswith(">"))


def _pair_encloses(text: str) -> bool:
    """Return True if ``text[0]`` is closed by ``text[-1]`` rather than by an earlier character."""
    opener, closer = text[0], text[-1]
    if opener == closer:  # quotes: no nesting to track
        return True
    depth = 0
    last_index = len(text) - 1
    for index, char in enumerate(text):
        if char == opener:
            depth += 1
        elif char == closer:
            depth -= 1
            if depth == 0:
                return index == last_index
    return False


def _strip_wrappers(text: str) -> str:
    """Repeatedly remove a quote/bracket pair that encloses the whole of *text*."""
    text = text.strip()
    while (
        len(text) >= 2
        and _WRAPPER_PAIRS.get(text[0]) == text[-1]
        and _pair_encloses(text)
    ):
        text = text[1:-1].strip()
    return text


def extract_final_answer(response_content: str) -> str:
    """Reduce a model response to the bare final-answer string.

    Steps, in order:

    1. Drop ``<think>``/``<thinking>`` reasoning blocks, plus anything before a
       stray closing tag. If nothing would remain, the raw text is kept.
    2. Remove markdown bold/italic markers around words. Arithmetic such as
       ``2*3`` is preserved.
    3. Strip a leading label such as ``"Final Answer:"``.
    4. If every non-empty line is a bullet, join the items as ``"a, b, c"``.
       Otherwise, for multi-line text, keep the first line that is not a bare
       tag or marker.
    5. Strip a label from that line again. Then remove quotes or brackets that
       enclose the whole answer, and any stray asterisks at the edges.

    Args:
        response_content: Raw model output. Falsy values yield ``""``.

    Returns:
        The cleaned answer string.
    """
    if not response_content:
        return ""

    answer = _remove_reasoning(response_content).strip()
    answer = _MD_EMPHASIS_RE.sub(r"\2", answer)
    answer = _strip_answer_prefixes(answer)

    lines = [line.strip() for line in answer.split("\n") if line.strip()]
    if lines and all(_BULLET_RE.match(line) for line in lines):
        answer = ", ".join(_BULLET_RE.sub("", line, count=1).strip() for line in lines)
    elif len(lines) > 1:
        candidate = next((line for line in lines if not _is_marker_line(line)), None)
        if candidate is not None:
            answer = candidate

    # The chosen line may still carry its own label or enclosing quotes/brackets.
    answer = _strip_answer_prefixes(answer.strip())
    answer = _strip_wrappers(answer)
    # Unbalanced markdown asterisks left at the edges are never part of the answer.
    return answer.strip("*").strip()
|
|
|
|
def _latest_message(messages, msg_type):
    """Return the newest message in *messages* whose ``type`` equals *msg_type*, else None."""
    for msg in reversed(messages):
        if getattr(msg, "type", None) == msg_type:
            return msg
    return None


def _message_text(message) -> str:
    """Return *message*'s ``content`` as text, or the value itself when it has no ``content``."""
    content = getattr(message, "content", message)
    return content if isinstance(content, str) else str(content)


def verification_node(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Verification node that performs final quality control and formatting.

    Reads ``messages``, ``agent_response``, ``quality_pass`` (default True),
    ``quality_score`` (default 7), ``critic_assessment`` and ``attempt_count``
    (default 1) from *state*. It returns a copy of *state* updated as follows:

    - No response to verify: ``verification_status="failed"`` and
      ``current_step="complete"``.
    - Quality check failed (``quality_pass`` false or score below 4):
      - if ``attempt_count >= 3``, a fallback answer with
        ``verification_status="failed_max_attempts"``;
      - otherwise ``verification_status="failed"``, ``attempt_count`` increased
        by one, and ``current_step="routing"``.
    - Quality check passed: the LLM formats the answer; the result is stored in
      ``final_answer`` and the LLM reply is appended to ``messages``.
    - Any exception: ``final_answer`` falls back to the cleaned agent response
      (or an error text) with ``verification_status="error"``.
    """
    print("Verification Node: Performing final quality control")

    # Bound before the try so the except handler can always consult it, even if the
    # failure happens before a response has been located.
    agent_response = None

    try:
        messages = state.get("messages") or []
        quality_pass = state.get("quality_pass", True)
        quality_score = state.get("quality_score", 7)
        critic_assessment = state.get("critic_assessment", "")

        # Prefer the explicitly stored response; otherwise use the latest AI message.
        agent_response = state.get("agent_response") or _latest_message(messages, "ai")
        if not agent_response:
            print("Verification Node: No response to verify")
            return {
                **state,
                "final_answer": "No response found to verify",
                "verification_status": "failed",
                "current_step": "complete"
            }

        latest_human = _latest_message(messages, "human")
        user_query = latest_human.content if latest_human is not None else None

        failure_threshold = 4
        # Number of the attempt that just produced agent_response (1-based).
        attempt_count = state.get("attempt_count", 1)

        if not quality_pass or quality_score < failure_threshold:
            if attempt_count >= 3:
                print("Verification Node: Maximum attempts reached, proceeding with fallback")
                return {
                    **state,
                    "final_answer": "Unable to provide a satisfactory answer after multiple attempts",
                    "verification_status": "failed_max_attempts",
                    "current_step": "fallback"
                }
            print(f"Verification Node: Quality check failed (score: {quality_score}), retrying")
            return {
                **state,
                "verification_status": "failed",
                "attempt_count": attempt_count + 1,
                "current_step": "routing"
            }

        print("Verification Node: Quality check passed, formatting final answer")

        # Built only now: the early-return paths above never call the model.
        verification_prompt = load_verification_prompt()
        llm = ChatGroq(model="qwen-qwq-32b", temperature=0.0)
        callback_handler = get_langfuse_callback_handler()
        callbacks = [callback_handler] if callback_handler else []

        verification_request = f"""
Please verify and format the following response according to the exact-match output rules:

Original Query: {user_query or "Unknown query"}

Response to Verify:
{_message_text(agent_response)}

Quality Assessment: {critic_assessment}

Ensure the final output strictly adheres to the format requirements specified in the system prompt.
"""

        verification_messages = [
            SystemMessage(content=verification_prompt),
            HumanMessage(content=verification_request),
        ]
        verification_response = llm.invoke(verification_messages, config={"callbacks": callbacks})

        final_answer = extract_final_answer(verification_response.content)

        return {
            **state,
            "messages": messages + [verification_response],
            "final_answer": final_answer,
            "verification_status": "passed",
            "current_step": "complete"
        }

    except Exception as e:
        # Node boundary: never let a verification failure abort the graph; degrade to
        # the unverified agent response when one was found.
        print(f"Verification Node Error: {e}")

        if agent_response:
            fallback_answer = extract_final_answer(_message_text(agent_response))
        else:
            fallback_answer = f"Error during verification: {e}"

        return {
            **state,
            "final_answer": fallback_answer,
            "verification_status": "error",
            "current_step": "complete"
        }
|
|
|
|
def should_retry(state: Dict[str, Any]) -> bool:
    """Return True when the verification step asked for another attempt.

    ``verification_node`` marks a retry with ``verification_status == "failed"``
    and stores the number of the attempt about to run in ``attempt_count``
    (starting from 1). It gives up, with status ``"failed_max_attempts"``, only
    when an attempt numbered 3 or higher fails.

    A retry is therefore allowed while ``attempt_count`` is at most 3. A strict
    ``< 3`` would stop after two attempts, with no ``final_answer`` set, and the
    node's fallback branch would never be reached.
    """
    max_attempts = 3
    status = state.get("verification_status", "")
    return status == "failed" and state.get("attempt_count", 1) <= max_attempts