adeshboudh16
fix: use .get() for all state access to prevent KeyError in graph nodes
ee2a25c
import json
import logging
from backend.graph.state import InterviewState
# Shared logger for the graph nodes and the routing function in this module.
log = logging.getLogger("app.graph")
from backend.llm import call_llm
from backend.prompts import (
build_ask_question_prompt,
build_counter_prompt,
build_evaluate_prompt,
build_report_prompt,
build_summarize_prompt,
)
from backend.db import queries
def _msg_content(msg) -> str:
return msg.get("content", "") if isinstance(msg, dict) else getattr(msg, "content", "")
def _msg_role(msg) -> str:
if isinstance(msg, dict):
return msg.get("role", "")
name = type(msg).__name__.lower()
if "human" in name:
return "human"
return "assistant"
# Maximum number of main questions per interview: ask_question stops asking,
# and route_after_evaluation routes to "end", once turn_count reaches this.
MAX_TURNS = 5
async def ask_question(state: InterviewState) -> dict:
    """Ask the next queued interview question.

    Pops the first entry of ``questions_remaining`` and has the LLM phrase it
    in the context of the topic, the running conversation summary and the
    questions already asked.

    Args:
        state: Current interview state. Every state key is read with
            ``.get()`` and a default, so a partially initialised state does
            not raise KeyError.

    Returns:
        ``{"messages": []}`` when the queue is empty or ``MAX_TURNS`` has been
        reached. Otherwise a state delta containing:

        - the shortened question queue,
        - ``questions_asked`` extended with the new question,
        - the assistant message,
        - ``turn_count`` incremented by one,
        - ``awaiting_counter_response`` reset to False.
    """
    remaining = list(state.get("questions_remaining", []))
    turn_count = state.get("turn_count", 0)
    log.info("ask_question: turn=%d/%d, remaining=%d", turn_count, MAX_TURNS, len(remaining))
    if not remaining or turn_count >= MAX_TURNS:
        log.info("ask_question: skipping (limit reached or no questions)")
        return {"messages": []}

    question = remaining.pop(0)
    questions_asked = list(state.get("questions_asked", []))
    # NOTE(review): assumes each queued question is a dict with a
    # "question_text" key — confirm against whatever populates the queue.
    question_text = question["question_text"]
    prompt = build_ask_question_prompt(
        state.get("topic_name", ""),
        state.get("conversation_summary", ""),
        questions_asked,
        question_text=question_text,
    )
    response = await call_llm(prompt, max_tokens=200)
    response = " ".join(response.split())  # collapse any mid-sentence newlines
    return {
        "questions_remaining": remaining,
        "questions_asked": questions_asked + [question_text],
        "messages": [{"role": "assistant", "content": response}],
        "turn_count": turn_count + 1,
        "awaiting_counter_response": False,
    }
async def evaluate_answer(state: InterviewState) -> dict:
    """Grade the student's most recent answer with the LLM.

    Finds the latest human message and asks the LLM to judge it against the
    last question asked. The LLM is expected to reply with a JSON object
    holding a "verdict" and an optional "weak_area". Output that is not
    valid JSON, or is JSON but not an object, is treated as a "wrong"
    verdict and logged.

    Args:
        state: Current interview state. Every state key is read with
            ``.get()`` and a default, so a partially initialised state does
            not raise KeyError.

    Returns:
        ``{"last_verdict": "wrong"}`` when no student message exists.
        Otherwise ``last_verdict`` plus ``student_weak_areas``. The weak-area
        list gains the reported weak area only for a "shallow" verdict.
    """
    messages = state.get("messages", [])
    log.info("evaluate_answer: turn=%d, messages=%d", state.get("turn_count", 0), len(messages))
    last_student = next(
        (m for m in reversed(messages) if _msg_role(m) == "human"),
        None,
    )
    if not last_student:
        log.warning("evaluate_answer: no student message found")
        return {"last_verdict": "wrong"}

    questions_asked = state.get("questions_asked", [])
    last_question = questions_asked[-1] if questions_asked else ""
    prompt = build_evaluate_prompt(
        last_question,
        _msg_content(last_student),
        state.get("conversation_summary", ""),
    )
    raw = await call_llm(prompt, max_tokens=100)
    try:
        result = json.loads(raw)
        verdict = result.get("verdict", "wrong")
        weak_area = result.get("weak_area")
    except (json.JSONDecodeError, TypeError, AttributeError):
        # TypeError: raw is not str/bytes (e.g. None).
        # AttributeError: valid JSON that is not an object (list, string, number, null).
        log.warning("evaluate_answer: unparseable verdict, defaulting to 'wrong': %.200r", raw)
        verdict = "wrong"
        weak_area = None

    weak_areas = list(state.get("student_weak_areas", []))
    if verdict == "shallow" and weak_area:
        weak_areas.append(str(weak_area))
    log.info("evaluate_answer: verdict=%s", verdict)
    return {
        "last_verdict": verdict,
        "student_weak_areas": weak_areas,
    }
async def counter_question(state: InterviewState) -> dict:
    """Ask a follow-up that probes the student's latest answer.

    Builds the prompt from the last question asked and the most recent human
    message. An empty answer is used when there is none.

    Args:
        state: Current interview state. Every state key is read with
            ``.get()`` and a default, so a partially initialised state does
            not raise KeyError.

    Returns:
        A state delta containing:

        - the assistant follow-up message,
        - ``awaiting_counter_response`` set to True,
        - ``counter_questions_asked`` incremented by one.
    """
    counter_asked = state.get("counter_questions_asked", 0)
    log.info("counter_question: counter_asked=%d", counter_asked)
    messages = state.get("messages", [])
    last_student = next(
        (m for m in reversed(messages) if _msg_role(m) == "human"),
        None,
    )
    questions_asked = state.get("questions_asked", [])
    last_question = questions_asked[-1] if questions_asked else ""
    prompt = build_counter_prompt(
        last_question,
        _msg_content(last_student) if last_student else "",
    )
    response = await call_llm(prompt, max_tokens=150)
    return {
        "messages": [{"role": "assistant", "content": response}],
        "awaiting_counter_response": True,
        "counter_questions_asked": counter_asked + 1,
    }
async def summarize(state: InterviewState) -> dict:
    """Condense the conversation so far into a running summary via the LLM.

    Args:
        state: Current interview state. ``messages`` is read with ``.get()``
            so a missing key yields an empty history instead of KeyError.

    Returns:
        A state delta containing:

        - the new ``conversation_summary``,
        - an empty ``messages`` list,
        - ``awaiting_counter_response`` reset to False.
    """
    prompt = build_summarize_prompt(state.get("messages", []))
    summary = await call_llm(prompt, max_tokens=200)
    return {
        "conversation_summary": summary,
        # NOTE(review): if "messages" uses an append-style reducer (e.g.
        # LangGraph's add_messages), returning [] adds nothing and does NOT
        # clear history — confirm the intended effect against InterviewState.
        "messages": [],
        "awaiting_counter_response": False,
    }
def _parse_report_feedback(raw) -> tuple:
    """Parse the report LLM output into ``(score, feedback)``.

    Accepts only a JSON object whose "score" (default 0) converts with
    ``int()``. Anything else produces a zero-score fallback: invalid JSON,
    non-string input, a JSON value that is not an object, or an unconvertible
    score. The fallback feedback dict keeps the raw text as its summary.
    """
    try:
        feedback = json.loads(raw)
        if not isinstance(feedback, dict):
            # e.g. a JSON list/string/number — .get() below would otherwise
            # raise an uncaught AttributeError.
            raise TypeError("report JSON is not an object")
        score = int(feedback.get("score", 0))
    except (json.JSONDecodeError, ValueError, TypeError):
        log.warning("generate_report: unparseable report, using fallback: %.200r", raw)
        return 0, {
            "score": 0,
            "summary": raw,
            "concept_score": 0,
            "depth_score": 0,
            "mistakes": [],
            "tips": [],
        }
    return score, feedback


async def generate_report(state: InterviewState) -> dict:
    """Produce the final interview report, persist it, and mark the session complete.

    Args:
        state: Final interview state. Optional keys are read with ``.get()``.
            ``session_id`` is required (see the comment at its use).

    Returns:
        A state delta containing:

        - ``status`` set to "complete",
        - the integer ``score``,
        - the ``feedback`` dict,
        - an assistant message carrying the feedback summary.
    """
    questions_asked = state.get("questions_asked", [])
    log.info("generate_report: questions_asked=%d, turn_count=%d", len(questions_asked), state.get("turn_count", 0))
    prompt = build_report_prompt(
        state.get("topic_name", ""),
        questions_asked,
        state.get("student_weak_areas", []),
        state.get("conversation_summary", ""),
        state.get("past_best_score"),
        messages=state.get("messages", []),
    )
    raw = await call_llm(prompt, max_tokens=400)
    score, feedback = _parse_report_feedback(raw)
    # Deliberately indexed rather than .get(): without a session id there is
    # nothing to persist against, so a missing key should fail loudly instead
    # of writing the result against None.
    await queries.update_session_complete(state["session_id"], score, feedback)
    return {
        "status": "complete",
        "score": score,
        "feedback": feedback,
        "messages": [{"role": "assistant", "content": feedback.get("summary", "Interview complete.")}],
    }
def route_after_evaluation(state: InterviewState) -> str:
    """Pick the next edge to follow after evaluate_answer has run.

    Returns one of the following, checked in this priority order:

    - "end" when ``MAX_TURNS`` is reached or no questions remain.
    - "counter" for a "shallow" verdict, when no counter question is already
      pending and fewer than two have been asked.
    - "summarize" on every fourth turn (turn_count > 0).
    - "next_question" otherwise.

    The decision is logged before it is returned.
    """
    turn_count = state.get("turn_count", 0)
    questions_remaining = state.get("questions_remaining", [])
    verdict = state.get("last_verdict", "wrong")

    out_of_budget = turn_count >= MAX_TURNS or not questions_remaining
    wants_counter = (
        verdict == "shallow"
        and not state.get("awaiting_counter_response", False)
        and state.get("counter_questions_asked", 0) < 2
    )
    at_summary_checkpoint = turn_count > 0 and turn_count % 4 == 0

    # End conditions always win — no extra turns after the limit
    if out_of_budget:
        decision = "end"
    elif wants_counter:
        decision = "counter"
    elif at_summary_checkpoint:
        decision = "summarize"
    else:
        decision = "next_question"

    log.info("route: turn=%d, verdict=%s, remaining=%d, counter_asked=%d → %s",
             turn_count, verdict, len(questions_remaining),
             state.get("counter_questions_asked", 0), decision)
    return decision