""" Demo API routes — streaming SSE endpoints matching the original TypeScript API. Routes: GET /api/init POST /api/execute-query (SSE) POST /api/benchmark (SSE) GET /api/rl-state GET /api/schema-graph POST /api/feedback """ from __future__ import annotations import asyncio import json import logging import os import time from typing import AsyncIterator, Optional logger = logging.getLogger(__name__) from fastapi import APIRouter from pydantic import BaseModel from sse_starlette.sse import EventSourceResponse from env.database import ( ensure_seeded, get_table_stats, get_schema_info, get_schema_graph, execute_query, connect_external_db, get_active_db_label, ) # Map frontend difficulty names → backend task IDs _DIFFICULTY_MAP = { "easy": "simple_queries", "medium": "join_queries", "hard": "complex_queries", } from env.tasks import TASKS, get_task from env.sql_env import SQLAgentEnv, Action, get_env, BASE_SYSTEM_PROMPT, get_system_prompt, _clean_sql, _clamp_score from rl.environment import get_bandit_state from rl.types import RepairAction, REPAIR_ACTION_NAMES, REPAIR_ACTION_BY_NAME from rl.error_classifier import classify_error, extract_offending_token from rl.grader import GraderInput, compute_reward, compute_episode_reward from rl.types import RLState, EpisodeStep, featurize, ERROR_CLASS_NAMES from gepa.optimizer import get_gepa, QueryResult, GEPA_OPTIMIZE_EVERY router = APIRouter() # ─── /api/test-llm ─────────────────────────────────────────────── @router.get("/test-llm") async def test_llm(): """Diagnostic: test LLM connectivity and return result.""" from env.sql_env import _make_client, _MODEL token = os.environ.get("HF_TOKEN", "") api_base = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1") token_preview = f"{token[:8]}..." 
if len(token) > 8 else ("(empty)" if not token else token) try: client = _make_client() resp = await client.chat.completions.create( model=_MODEL, messages=[{"role": "user", "content": "Reply with just: OK"}], temperature=0, max_tokens=5, ) result = resp.choices[0].message.content return { "ok": True, "model": _MODEL, "api_base": api_base, "token_set": bool(token), "token_preview": token_preview, "response": result, } except Exception as e: err = str(e) if len(err) > 400 or ' AsyncIterator[dict]: env = get_env() # Accept difficulty names ('easy'/'medium'/'hard') or direct task IDs task_id = _DIFFICULTY_MAP.get(req.task_id, req.task_id) obs = env.reset(task_id) # Pick first question of task matching question text, or default task = get_task(task_id) question_obj = task.questions[0] # Override question text env._episode.question = req.question # type: ignore[union-attr] max_attempts = env.MAX_ATTEMPTS done = False all_step_rewards: list[float] = [] success = False # Initial generate action action = Action(repair_action="generate") from env.sql_env import _make_client, _MODEL from rl.repair_strategies import RepairContext, get_repair_system_suffix, build_repair_user_message # Build initial user message (includes previous-wrong-SQL context if retrying) prev_context = "" if req.previousSql: prev_context = ( f"\nNOTE: A previous session generated the following SQL which was marked INCORRECT:\n" f"```sql\n{req.previousSql}\n```\n" f"You MUST try a completely different approach.\n" ) initial_user_msg = ( f"Schema:\n{obs.schema_info}\n\nQuestion: {req.question}\n" f"{prev_context}\n" "Write a SQL query to answer this question." ) # Multi-turn conversation — grows with each failed attempt so the LLM # sees its own history and doesn't repeat the same mistake. 
conversation: list[dict] = [ {"role": "system", "content": get_system_prompt()}, {"role": "user", "content": initial_user_msg}, ] for attempt in range(1, max_attempts + 1): yield {"data": json.dumps({"type": "attempt_start", "attempt": attempt})} ep = env._episode # type: ignore[union-attr] ep.attempt_number = attempt # On repair attempts, update system prompt with RL-selected repair suffix if attempt > 1 and ep.current_features is not None: repair_enum, scores = env._bandit.select_action(ep.current_features) ucb_scores = { REPAIR_ACTION_NAMES[RepairAction(i)]: round(scores[i], 4) for i in range(len(scores)) } action = Action(repair_action=REPAIR_ACTION_NAMES[repair_enum]) yield {"data": json.dumps({ "type": "rl_action", "action": action.repair_action, "ucb_scores": ucb_scores, })} # Update system prompt with repair-specific guidance conversation[0] = { "role": "system", "content": get_system_prompt() + get_repair_system_suffix(repair_enum), } elif attempt > 1: repair_enum = RepairAction.REWRITE_FULL action = Action(repair_action="rewrite_full") conversation[0] = { "role": "system", "content": get_system_prompt() + get_repair_system_suffix(repair_enum), } # Stream SQL generation using the full conversation history client = _make_client() chunks: list[str] = [] try: stream = await client.chat.completions.create( model=_MODEL, messages=conversation, stream=True, temperature=0.1, ) async for chunk in stream: if not chunk.choices: continue # HF Router sends empty-choices chunks (ping/final) delta = chunk.choices[0].delta.content if delta: chunks.append(delta) yield {"data": json.dumps({"type": "sql_chunk", "chunk": delta})} except Exception as e: # Format LLM exception concisely (avoid dumping full HTML 401 pages) err_str = str(e) logger.error("LLM call failed attempt=%d: %s: %s", attempt, type(e).__name__, err_str[:200]) print(f"[execute-query] LLM error attempt={attempt}: {type(e).__name__}: {err_str[:200]}", flush=True) if len(err_str) > 300 or ' 
AsyncIterator[dict]: task_id = _DIFFICULTY_MAP.get(req.task_id, req.task_id) task = get_task(task_id) scores: list[float] = [] questions = task.questions if req.queryIds: questions = [q for q in questions if q.id in req.queryIds] for question_obj in questions: yield {"data": json.dumps({ "type": "query_start", "id": question_obj.id, "question": question_obj.question, })} # Run the question through the env env = SQLAgentEnv() obs = env.reset_with_question(task_id, question_obj.id) attempt = 0 sql = "" success = False task_score = _clamp_score(0.0) max_attempts = env.MAX_ATTEMPTS ep = env._episode # type: ignore[union-attr] gepa = get_gepa() system_prompt = gepa.get_current_prompt() or get_system_prompt() from env.sql_env import _make_client, _MODEL for attempt in range(1, max_attempts + 1): ep.attempt_number = attempt if attempt == 1 or ep.current_sql is None: user_msg = ( f"Schema:\n{obs.schema_info}\n\n" f"Question: {question_obj.question}\n\n" "Write a SQL query to answer this question." 
) sys_prompt = system_prompt else: from rl.repair_strategies import RepairContext, get_repair_system_suffix, build_repair_user_message if ep.current_features is not None: repair_enum, _ = env._bandit.select_action(ep.current_features) else: repair_enum = RepairAction.REWRITE_FULL suffix = get_repair_system_suffix(repair_enum) offending = extract_offending_token(ep.error_message or "") ctx = RepairContext( schema=obs.schema_info, question=question_obj.question, failing_sql=ep.current_sql or "", error_message=ep.error_message or "", offending_token=offending, ) sys_prompt = system_prompt + suffix user_msg = build_repair_user_message(repair_enum, ctx) client = _make_client() try: resp = await client.chat.completions.create( model=_MODEL, messages=[ {"role": "system", "content": sys_prompt}, {"role": "user", "content": user_msg}, ], temperature=0.1, ) sql = _clean_sql(resp.choices[0].message.content or "") except Exception as e: break rows, error = execute_query(sql) from env.tasks import grade_response task_score = grade_response( task_id, question_obj.id, sql, rows, error, attempt ) success = task_score >= 0.8 current_ec = None if error: ec = classify_error(error) current_ec = ec error_changed = ep.previous_error_class is not None and ep.previous_error_class != ec if ep.previous_error_class == ec: ep.consecutive_same_error += 1 else: ep.consecutive_same_error = 1 rl_state = RLState( error_class=ec, attempt_number=attempt, previous_action=ep.last_action, error_changed=error_changed, consecutive_same_error=ep.consecutive_same_error, ) ep.current_rl_state = rl_state ep.current_features = featurize(rl_state) from rl.grader import GraderInput, compute_reward grader_in = GraderInput( success=success, attempt_number=attempt, current_error_class=current_ec, previous_error_class=ep.previous_error_class, ) grader_out = compute_reward(grader_in) ep.current_sql = sql ep.error_message = error ep.error_class = ERROR_CLASS_NAMES[current_ec] if current_ec else None 
# NOTE(review): this physical line holds the TAIL of the /api/benchmark SSE
# generator (its `async def event_generator()` header sits on an earlier,
# garbled line) plus three complete definitions. The nesting of the generator
# tail below is reconstructed from context — confirm against version control.
                ep.previous_error_class = current_ec
                if success:
                    break
            # One attempt loop finished (success or exhausted) — record the
            # grade for this question and emit its SSE verdict event.
            scores.append(task_score)
            yield {"data": json.dumps({
                "type": "query_result",
                "id": question_obj.id,
                "pass": success,
                "score": task_score,
                "sql": sql,
                "attempts": attempt,
                "reason": "Correct" if success else "Agent exhausted all repair attempts",
            })}
        # Mean score across all graded questions; 0.0 when no question ran.
        overall_score = sum(scores) / len(scores) if scores else 0.0
        yield {"data": json.dumps({
            "type": "done",
            "overallScore": overall_score,
            "task_id": task_id,
        })}

    return EventSourceResponse(event_generator())


# ─── /api/rl-state ────────────────────────────────────────────────
@router.get("/rl-state")
async def get_rl_state():
    """Return bandit state and episode metrics shaped for the frontend."""
    from rl.experience import get_metrics
    state = get_bandit_state()
    metrics = get_metrics()
    # Eight repair actions — index order must match the RepairAction enum.
    action_names = [REPAIR_ACTION_NAMES[RepairAction(i)] for i in range(8)]
    # Build actionDistribution as array [{action, count}] expected by frontend
    action_distribution = [
        {"action": name, "count": state["action_counts"][i]}
        for i, name in enumerate(action_names)
    ]
    # Build episodes array [{episode, totalReward, successRate}] from reward_history
    reward_history: list[float] = metrics.reward_history or []
    total_eps = max(metrics.total_episodes, len(reward_history))
    episodes = [
        {
            "episode": i + 1,
            "totalReward": round(r, 3),
            # NOTE(review): this is the CURRENT aggregate success rate repeated
            # for every episode — per-episode rates are not tracked here.
            "successRate": round(metrics.success_rate, 3),
        }
        for i, r in enumerate(reward_history)
    ]
    from gepa.optimizer import get_gepa
    gepa = get_gepa()
    return {
        "totalEpisodes": total_eps,
        "successRate": round(metrics.success_rate, 3),
        "currentAlpha": round(state["alpha"], 4),
        "episodes": episodes,
        "actionDistribution": action_distribution,
        "currentGeneration": gepa.current_generation,
    }


# ─── /api/schema-graph ──────────────────────────────────────────────
@router.get("/schema-graph")
async def schema_graph():
    """Return the schema graph of the active database (see env.database)."""
    return get_schema_graph()


# ─── /api/feedback ──────────────────────────────────────────────
class FeedbackRequest(BaseModel):
    """Payload for POST /api/feedback: a user's verdict on a generated query."""
    # Natural-language question the SQL was generated for.
    question: str
    # The generated SQL being judged.
    sql: str
    # True when the user accepted the query as correct.
    correct: bool
    remark: Optional[str] = None  # user's free-text explanation of what was wrong
@router.post("/feedback")
async def submit_feedback(req: FeedbackRequest):
    """Record user feedback on a generated query; maybe trigger GEPA.

    Always records the outcome with the GEPA optimizer. When the user marked
    the query incorrect and the optimizer is due (``should_optimize``), runs
    one optimization cycle with the user's feedback as extra context.

    Returns a dict with ``received``, ``gepa_triggered`` and, when a cycle
    ran, its ``reflection`` text.
    """
    gepa = get_gepa()

    # Human-readable error strings for the optimizer's result history.
    # NOTE(review): the remark is only recorded for incorrect results —
    # assumed intent, since the field documents "what was wrong".
    errors: list[str] = []
    if not req.correct:
        errors.append("User marked as incorrect")
        if req.remark:
            errors.append(f"User remark: {req.remark}")

    gepa.record_result(QueryResult(
        question=req.question,
        final_sql=req.sql,
        attempts=1,
        success=req.correct,
        errors=errors,
        timestamp=time.time(),
    ))

    result = None
    if not req.correct and gepa.should_optimize():
        feedback_ctx = (
            f"User marked query as incorrect.\nQuestion: {req.question}\nSQL: {req.sql}"
        )
        if req.remark:
            feedback_ctx += f"\nUser explanation: {req.remark}"
        try:
            result = await gepa.run_optimization_cycle(user_feedback_context=feedback_ctx)
        except Exception:
            # Best-effort: an optimizer failure must not fail the feedback
            # request, but it must not vanish silently either (the original
            # `except Exception: pass` hid real problems). Log and carry on
            # with result=None.
            logging.getLogger(__name__).exception("GEPA optimization cycle failed")

    return {
        "received": True,
        "gepa_triggered": result is not None,
        "reflection": result.get("reflection") if result else None,
    }