Spaces:
Running
Running
# ─────────────────────────────────────────────────────────────────────────────
# Maria Learning Service | app.py
# FastAPI + CPU (Qwen3-0.6B, float32) + FAISS RAG + gTTS
# ─────────────────────────────────────────────────────────────────────────────
| import asyncio | |
| import os | |
| import gc | |
| import json | |
| import base64 | |
| import hashlib | |
| import logging | |
| import copy | |
| from io import BytesIO | |
| from typing import List, Any, Optional | |
| import httpx | |
| import numpy as np | |
| import pandas as pd | |
| import faiss | |
| import gradio as gr | |
| from fastapi import FastAPI, HTTPException, Request | |
| from fastapi.responses import JSONResponse | |
| from pydantic import BaseModel | |
| from huggingface_hub import hf_hub_download | |
| from gtts import gTTS | |
# Root logging setup: INFO level, timestamp + padded level name + message.
_LOG_FORMAT = "%(asctime)s %(levelname)-8s %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used by the helpers and handlers in this file.
log = logging.getLogger(__name__)
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Config / Secrets | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# Secrets and deployment settings are read from the environment; each one
# falls back to an empty string when the variable is not set.
HASH_VALUE = os.getenv("HASH_VALUE", "")
CF_SECRET_KEY = os.getenv("CF_SECRET_KEY", "")
ALLOWED_DOMAIN = os.getenv("ALLOWED_DOMAIN", "")

# Fixed Hugging Face identifiers used by this service.
HF_REPO_ID = "digifreely/Maria"
LLM_MODEL_ID = "Qwen/Qwen3-0.6B"  # Qwen3 0.6B, run on CPU
# ─────────────────────────────────────────────────────────────────────────────
# Preload: model + tokenizer loaded once into CPU RAM at container start.
#
# bitsandbytes 4-bit is CUDA-only; for CPU we load the 0.6B model in float32,
# which is lightweight enough to reside in memory for the container lifetime.
# No per-request load/unload cycle — the model object is reused directly.
# ─────────────────────────────────────────────────────────────────────────────
# Filled in once by _preload_model() and then reused by every request.
_llm_tok = None    # tokenizer
_llm_model = None  # causal-LM model, requested on the CPU


def _preload_model():
    """Populate the module-level tokenizer and model for ``LLM_MODEL_ID``.

    Both objects come from ``from_pretrained``.  The model is requested in
    float32 with ``device_map="cpu"`` and is switched to eval mode.  ``torch``
    and ``transformers`` are imported lazily, inside the function.
    """
    global _llm_tok, _llm_model

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    log.info("Loading %s to CPU RAMβ¦", LLM_MODEL_ID)

    _llm_tok = AutoTokenizer.from_pretrained(LLM_MODEL_ID, trust_remote_code=True)

    # NOTE(review): trust_remote_code=True lets repository code run at load
    # time -- confirm it is still required for this model.
    load_options = {
        "torch_dtype": torch.float32,  # float32 for CPU compatibility
        "device_map": "cpu",
        "trust_remote_code": True,
    }
    _llm_model = AutoModelForCausalLM.from_pretrained(LLM_MODEL_ID, **load_options)
    _llm_model.eval()

    log.info("Model loaded on CPU β ready for inference.")


# Runs as soon as the module is imported, so the load happens before any
# request can be served.
_preload_model()
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Embedding model (CPU, loaded once per container lifetime) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# Most recently returned embedding model (kept so existing references to
# ``_emb_model`` keep working).
_emb_model = None
# Every embedding model loaded so far, keyed by model name.
_emb_models: dict = {}


def _get_emb_model(name: str = "sentence-transformers/all-MiniLM-L6-v2"):
    """Return the SentenceTransformer for *name*, loading it on first use.

    Models are cached per name.  The previous single-slot cache returned
    whichever model happened to be loaded first, even when a different
    *name* was requested, which would embed queries with the wrong encoder
    for any dataset whose config names another model.

    Args:
        name: SentenceTransformer model identifier.

    Returns:
        The cached (or freshly loaded) SentenceTransformer instance.
    """
    global _emb_model
    model = _emb_models.get(name)
    if model is None:
        # Imported lazily so the dependency is only paid for on first use.
        from sentence_transformers import SentenceTransformer

        log.info("Loading embedding model: %s", name)
        model = SentenceTransformer(name)
        _emb_models[name] = model
    _emb_model = model
    return model
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Security helpers | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def _check_auth_code(code: str) -> bool:
    """Check *code* against the configured SHA-256 hash.

    Args:
        code: Plain-text auth code supplied by the caller.

    Returns:
        True when the SHA-256 hex digest of *code* equals ``HASH_VALUE``.
        Always False when ``HASH_VALUE`` is empty.
    """
    import hmac

    if not HASH_VALUE:
        return False
    digest = hashlib.sha256(code.encode()).hexdigest()
    # Constant-time comparison so response timing does not leak how many
    # leading characters matched.  Bytes are used because compare_digest
    # rejects non-ASCII str arguments.
    return hmac.compare_digest(digest.encode(), HASH_VALUE.encode())
async def _check_turnstile(token: str) -> bool:
    """Verify a Cloudflare Turnstile *token* against the siteverify endpoint.

    Returns False without making a request when ``CF_SECRET_KEY`` is empty,
    and False (after logging) when the request or response parsing raises.
    Otherwise returns the ``success`` field of the JSON reply, defaulting
    to False when the field is absent.
    """
    if not CF_SECRET_KEY:
        return False

    verify_url = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
    form = {"secret": CF_SECRET_KEY, "response": token}
    try:
        async with httpx.AsyncClient(timeout=8.0) as http:
            reply = await http.post(verify_url, data=form)
            return reply.json().get("success", False)
    except Exception as exc:
        # Best effort: any failure is treated as "not verified".
        log.error("Turnstile verification error: %s", exc)
        return False
async def _authenticate(request: Request) -> bool:
    """Decide whether *request* is allowed to use the API.

    Checks are tried in this order and the first applicable one decides:

    1. ``auth_code`` header      -> ``_check_auth_code``
    2. ``cf-turnstile-token``    -> ``_check_turnstile``
    3. ``referer`` / ``origin``  -> host must be ``ALLOWED_DOMAIN`` or one
       of its sub-domains.

    The domain fallback is refused outright when ``ALLOWED_DOMAIN`` is not
    configured.  Previously the test was ``ALLOWED_DOMAIN in referer``; with
    the default empty string that is always True, so every request without
    credentials was accepted.  The substring test also matched the domain
    anywhere in the URL (path or query string); the host is now parsed and
    compared instead.

    NOTE(review): referer/origin are client-controlled headers, so this
    fallback is only a soft check.  If ``ALLOWED_DOMAIN`` is configured with
    a path, only its host part is used -- confirm against the deployment.
    """
    from urllib.parse import urlsplit

    def _hostname(value: str) -> str:
        """Return the lower-cased host of a URL or bare domain, or ""."""
        value = value.strip().lower()
        if not value:
            return ""
        if "//" not in value:
            # Bare "example.com[/path]" -- give urlsplit a netloc to parse.
            value = "//" + value
        try:
            return urlsplit(value).hostname or ""
        except ValueError:
            return ""

    auth_code = request.headers.get("auth_code")
    if auth_code:
        return _check_auth_code(auth_code)

    cf_token = request.headers.get("cf-turnstile-token")
    if cf_token:
        return await _check_turnstile(cf_token)

    # Fallback: domain check on referer/origin.
    allowed_host = _hostname(ALLOWED_DOMAIN).lstrip(".")
    if not allowed_host:
        return False
    for header in ("referer", "origin"):
        host = _hostname(request.headers.get(header, ""))
        if host == allowed_host or host.endswith("." + allowed_host):
            return True
    return False
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Change 3: Dataset cache β populated by /dataset, consumed by /chat | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Key: (board, cls, subject) β (config, faiss_index, metadata) | |
# Maps a (board, class, subject) key to its (config, faiss_index, metadata).
_dataset_cache: dict = {}


def _dataset_key(board: str, cls: str, subject: str) -> tuple:
    """Build the cache key: the three identifiers with outer whitespace removed."""
    return tuple(part.strip() for part in (board, cls, subject))
def _load_dataset(board: str, cls: str, subject: str):
    """Download and open the knowledge-base files for one board/class/subject.

    Three files are fetched from the ``HF_REPO_ID`` dataset repository under
    ``knowledgebase/<board>/<cls>/<subject>/``: ``config.json``,
    ``faiss_index.bin`` and ``metadata.parquet``.

    Args:
        board, cls, subject: Path components selecting the dataset.

    Returns:
        ``(config, index, metadata)``: the parsed JSON config, the FAISS
        index and the metadata DataFrame.

    Raises:
        Whatever ``hf_hub_download``, ``json.load``, ``faiss.read_index`` or
        ``pandas.read_parquet`` raise; nothing is caught here.
    """
    # NOTE(review): the path components come straight from the request body;
    # consider rejecting ".." segments before building the path.
    prefix = f"knowledgebase/{board}/{cls}/{subject}"
    log.info("Fetching dataset: %s", prefix)

    def _fetch(filename: str) -> str:
        """Download one file of this dataset and return its local path."""
        return hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=f"{prefix}/{filename}",
            repo_type="dataset",
        )

    config_path = _fetch("config.json")
    faiss_path = _fetch("faiss_index.bin")
    meta_path = _fetch("metadata.parquet")

    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(config_path, encoding="utf-8") as fh:
        config = json.load(fh)
    index = faiss.read_index(faiss_path)
    metadata = pd.read_parquet(meta_path)
    return config, index, metadata
def _rag_search(
    query: str,
    config: dict,
    index,
    metadata: pd.DataFrame,
    k: int = 3,
) -> List[str]:
    """Return up to *k* text chunks from *metadata* that best match *query*.

    The query is encoded with the embedding model named by
    ``config["embedding_model"]`` (MiniLM by default) and searched in
    *index*.  For each hit whose row number lies inside *metadata*, the first
    non-null column among text/content/chunk/passage/answer/description
    supplies the chunk, truncated to 600 characters.
    """
    model_name = config.get(
        "embedding_model", "sentence-transformers/all-MiniLM-L6-v2"
    )
    encoder = _get_emb_model(model_name)
    query_vec = encoder.encode([query], normalize_embeddings=True).astype(np.float32)
    _, neighbours = index.search(query_vec, k)

    candidate_cols = ("text", "content", "chunk", "passage", "answer", "description")
    row_count = len(metadata)
    found: List[str] = []
    for row_id in neighbours[0]:
        # Skip ids that do not address a metadata row.
        if row_id < 0 or row_id >= row_count:
            continue
        record = metadata.iloc[row_id]
        chosen = next(
            (
                name
                for name in candidate_cols
                if name in metadata.columns and pd.notna(record[name])
            ),
            None,
        )
        if chosen is not None:
            found.append(str(record[chosen])[:600])
    return found
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # LLM inference β uses the preloaded CPU model; no per-call load/unload. | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def _model_generate(system_prompt: str, user_prompt: str) -> str:
    """Generate one reply from the preloaded model.

    The two prompts are rendered through the tokenizer's chat template with
    ``enable_thinking=False``, then sampled for at most 180 new tokens.  Only
    the newly generated tokens are decoded; the stripped text is returned.
    """
    import torch

    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    prompt_text = _llm_tok.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,  # Qwen3: suppress the <think> block
    )
    encoded = _llm_tok([prompt_text], return_tensors="pt").to(_llm_model.device)

    sampling = {
        "max_new_tokens": 180,
        "temperature": 0.7,
        "top_p": 0.9,
        "do_sample": True,
        "repetition_penalty": 1.1,
        "pad_token_id": _llm_tok.eos_token_id,
    }
    with torch.no_grad():
        generated = _llm_model.generate(**encoded, **sampling)

    # Drop the echoed prompt: keep only what follows the input tokens.
    prompt_len = encoded.input_ids.shape[1]
    reply = _llm_tok.decode(generated[0][prompt_len:], skip_special_tokens=True).strip()
    log.info("Inference complete. Output length: %d chars", len(reply))
    return reply


# CPU inference: plain alias of the generator function.
run_inference = _model_generate
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Text-to-Speech | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def _tts_to_b64(text: str) -> str:
    """Synthesise *text* with gTTS and return the audio as a base64 string.

    Only the first 3000 characters are spoken (English, ``co.uk`` TLD).
    Any failure is logged and reported as an empty string.
    """
    try:
        speech = gTTS(text=text[:3000], lang="en", tld="co.uk", slow=False)
        audio = BytesIO()
        speech.write_to_fp(audio)
        return base64.b64encode(audio.getvalue()).decode("utf-8")
    except Exception as exc:
        # Best effort: the caller still gets a text reply without audio.
        log.error("TTS error: %s", exc)
        return ""
# ─────────────────────────────────────────────────────────────────────────────
# Prompt builder — trimmed for 180-token output budget (Qwen3-0.6B, CPU)
#
# Key design: only the ACTIVE topic/goal is passed to stage_updates context.
# Showing all topics caused the model to update every entry, blowing the
# token budget and truncating the JSON.
# ─────────────────────────────────────────────────────────────────────────────
# Stage names in the order they are checked for completion.
_STAGES = ("teach", "re_teach", "show_and_tell", "assess")


def _find_active_topic(current_learning: list) -> tuple:
    """Locate the first objective that still has an unfinished stage.

    Topics and their objectives are scanned in order.  A stage counts as
    finished only when its value is exactly ``"complete"``; a missing stage
    counts as unfinished.

    Returns:
        ``(topic, goal, stage)`` for the first unfinished stage found, or
        ``("", "", "teach")`` when nothing is pending.
    """
    for entry in current_learning:
        for objective in entry.get("learning_objectives", []):
            pending = next(
                (
                    name
                    for name in _STAGES
                    if objective.get(name, "Not_Complete") != "complete"
                ),
                None,
            )
            if pending is not None:
                return entry.get("topic", ""), objective.get("goal", ""), pending
    # Everything is complete: nothing is active.
    return "", "", "teach"
def _build_system_prompt(lp: dict, rag_chunks: List[str]) -> str:
    """Assemble the system prompt for one tutoring turn.

    Args:
        lp: Learning-path state. Read keys: ``teacher_persona``,
            ``student_name``, ``chat_history``, ``scratchpad`` and
            ``assessment_stages.current_learning``; each has a default.
        rag_chunks: Knowledge-base excerpts to include; may be empty.

    Returns:
        The prompt text: persona and student, the single active objective,
        the knowledge-base excerpts, the last two chat turns, the last
        scratchpad note, and the JSON output contract for the model.
    """
    persona = lp.get("teacher_persona", "A friendly and patient teacher")
    student = lp.get("student_name", "Student")
    chat_history = lp.get("chat_history", [])[-2:]  # last 2 turns only
    scratchpad = lp.get("scratchpad", [])[-1:]  # last 1 entry only
    current_learning = lp.get("assessment_stages", {}).get("current_learning", [])
    # Find the single active topic/goal to teach right now
    active_topic, active_goal, active_stage = _find_active_topic(current_learning)
    # Each history turn becomes an "S:" (student) / "T:" (teacher) pair.
    history_block = "\n".join(
        f'S: {h.get("user_input","")}\nT: {h.get("system_output","")}'
        for h in chat_history
    ) or "None."
    scratch_block = "\n".join(
        f'[{s.get("chat_id","")}] {s.get("thought","")} | {s.get("action","")}'
        for s in scratchpad
    ) or "Empty."
    rag_block = "\n---\n".join(rag_chunks) if rag_chunks else "No relevant content found."
    # Pass only the active topic/goal (not the whole list) to keep output short.
    # An empty topic name is reported to the model as "all objectives complete".
    active_block = (
        f'Topic: "{active_topic}"\nGoal: "{active_goal}"\nCurrent stage: {active_stage}'
        if active_topic else "All objectives complete."
    )
    # NOTE(review): topic and goal are interpolated into the JSON example
    # below without escaping; a quote character in either would make the
    # example invalid JSON -- confirm upstream values cannot contain one.
    return f"""You are {persona} teaching {student}, aged 6β12. Use simple English. Be warm and brief.
STUDENT: {student}
ACTIVE OBJECTIVE (teach this now):
{active_block}
KNOWLEDGE BASE:
{rag_block}
RECENT CHAT:
{history_block}
NOTES:
{scratch_block}
TASK: Classify intent, respond to the student, return ONLY valid JSON. Keep "response" under 50 words.
INTENT RULES:
"block" β rude/inappropriate. Redirect kindly (first time) or end gently (repeat).
"questions" β off-topic. Answer briefly from KB, then redirect.
"curriculum" β on-topic. Follow: teach β re_teach β show_and_tell β assess.
"chitchat" β casual. Respond warmly, bring up active topic.
STAGE VALUE RULES β CRITICAL:
The fields teach / re_teach / show_and_tell / assess must ONLY ever be the exact string "complete" or "Not_Complete".
NEVER put a sentence, description, or any other text in these fields.
Set the current active stage ("{active_stage}") to "complete" if the student has completed it, else "Not_Complete".
All other stages keep their previous value β use "Not_Complete" if unknown.
OUTPUT β return ONLY this JSON (stage_updates: EXACTLY 1 entry):
{{
"intent": "<block|questions|curriculum|chitchat>",
"response": "<reply, max 50 words>",
"stage_updates": [{{"topic":"{active_topic}","goal":"{active_goal}","teach":"Not_Complete","re_teach":"Not_Complete","show_and_tell":"Not_Complete","assess":"Not_Complete"}}],
"thought": "<one sentence>",
"action": "<teach|re_teach|show_and_tell|assess|answer|redirect|discourage|end|chitchat>",
"observation": "<one sentence>"
}}\
"""
# ─────────────────────────────────────────────────────────────────────────────
# JSON parser — layered extraction, regex-anchored on "intent" key.
#
# Layer 0 : strip any <think>…</think> block (Qwen3 safety fallback).
# Layer 1 : strip markdown ```json … ``` fences.
# Layer 2 : direct json.loads on the cleaned text.
# Layer 3 : regex — walk every '{' left-to-right; skip those that don't
#           contain "intent":; try every '}' right-to-left until a valid
#           JSON object with "intent" key parses successfully.
# Layer 4 : broad regex — outermost { … } regardless of content.
# Layer 5 : fallback dict with raw text as the response field.
# ─────────────────────────────────────────────────────────────────────────────
| import re as _re | |
def _parse_llm_output(raw: str) -> dict:
    """Extract the JSON object the model was asked to produce from *raw*.

    Extraction is attempted in layers, returning at the first success:

    0. Remove any ``<think>...</think>`` block.
    1. A JSON object inside a Markdown code fence.
    2. The whole cleaned text as JSON.
    3. Any ``{...}`` span that parses to an object containing ``"intent"``,
       trying opening braces left-to-right and closing braces right-to-left.
    4. The outermost ``{...}`` span.
    5. A fallback dict carrying the text as ``response``.

    Only JSON *objects* are accepted.  Previously layer 2 returned whatever
    ``json.loads`` produced, so output such as ``[...]`` or ``"text"`` came
    back as a list or string despite the ``dict`` contract.

    Args:
        raw: Raw text generated by the model.

    Returns:
        A dict; never a list, string or number.
    """

    def _load_object(candidate: str):
        """Parse *candidate*; return it only when it is a JSON object."""
        try:
            value = json.loads(candidate)
        except json.JSONDecodeError:
            return None
        return value if isinstance(value, dict) else None

    # Layer 0: strip the Qwen3 <think> block.
    text = _re.sub(r"<think>.*?</think>", "", raw, flags=_re.DOTALL).strip()

    # Layer 1: fenced JSON.
    fence_match = _re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, _re.DOTALL)
    if fence_match:
        parsed = _load_object(fence_match.group(1))
        if parsed is not None:
            return parsed

    # Layer 2: direct parse.
    parsed = _load_object(text)
    if parsed is not None:
        return parsed

    # Layer 3: intent-anchored brace scan.
    intent_pat = _re.compile(r'"intent"\s*:')
    brace_opens = [m.start() for m in _re.finditer(r'\{', text)]
    brace_closes = [m.end() for m in _re.finditer(r'\}', text)]
    for open_pos in brace_opens:
        if not intent_pat.search(text, open_pos):
            # No "intent": at or after this brace, hence none after any
            # later brace either.
            break
        for close_pos in reversed(brace_closes):
            if close_pos <= open_pos:
                break
            parsed = _load_object(text[open_pos:close_pos])
            if parsed is not None and "intent" in parsed:
                log.info("JSON extracted via intent-anchored regex.")
                return parsed

    # Layer 4: outermost { ... } fallback.
    broad = _re.search(r'\{.*\}', text, _re.DOTALL)
    if broad:
        parsed = _load_object(broad.group())
        if parsed is not None:
            return parsed

    # Layer 5: give up and hand the text back as the reply.
    log.warning("Could not parse JSON from model output. Raw: %.200s", raw)
    return {
        "intent": "questions",
        "response": text or raw,
        "stage_updates": [],
        "thought": "",
        "action": "answer",
        "observation": "json_parse_failed",
    }
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # State updater | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def _apply_state_updates(
    lp: dict,
    parsed: dict,
    user_msg: str,
    ai_msg: str,
) -> dict:
    """Return a copy of *lp* with this turn recorded and stage progress applied.

    The input *lp* and *parsed* are left untouched.  On the copy:

    * a ``chat_history`` entry is appended with the next ``chat_id``;
    * a ``scratchpad`` entry with the model's thought/action/observation is
      appended under the same id;
    * every entry of ``parsed["stage_updates"]`` whose topic and goal match
      an objective updates that objective's four stage fields.

    Stage values other than the exact string ``"complete"`` are treated as
    ``"Not_Complete"``.  A stage that is already ``"complete"`` is never
    downgraded: the model is only shown the active stage and is told to send
    ``"Not_Complete"`` for stages it does not know, so honouring those values
    would erase earlier progress on every turn.

    Args:
        lp: Current learning-path state.
        parsed: Parsed model output (see the prompt's JSON contract).
        user_msg: The student's message for this turn.
        ai_msg: The reply sent back to the student.

    Returns:
        The updated deep copy of *lp*.
    """
    lp = copy.deepcopy(lp)

    history = lp.setdefault("chat_history", [])
    last_id = None
    if history and isinstance(history[-1], dict):
        last_id = history[-1].get("chat_id")
    # Fall back to a position-based id when the last entry has no usable one.
    new_id = last_id + 1 if isinstance(last_id, int) else len(history) + 1
    history.append({
        "chat_id": new_id,
        "user_input": user_msg,
        "system_output": ai_msg,
    })

    lp.setdefault("scratchpad", []).append({
        "chat_id": new_id,
        "thought": parsed.get("thought", ""),
        "action": parsed.get("action", ""),
        "action_input": user_msg,
        "observation": parsed.get("observation", ""),
    })

    current_learning = lp.get("assessment_stages", {}).get("current_learning", [])
    stages = ("teach", "re_teach", "show_and_tell", "assess")

    updates = parsed.get("stage_updates")
    if not isinstance(updates, list):
        updates = []  # model produced something other than a list
    for upd in updates:
        if not isinstance(upd, dict):
            continue
        for item in current_learning:
            if item.get("topic") != upd.get("topic"):
                continue
            for obj in item.get("learning_objectives", []):
                if obj.get("goal") != upd.get("goal"):
                    continue
                for stage in stages:
                    if obj.get(stage) == "complete":
                        continue  # progress is monotonic
                    reached = upd.get(stage) == "complete"
                    obj[stage] = "complete" if reached else "Not_Complete"

    lp.setdefault("assessment_stages", {})["current_learning"] = current_learning
    return lp
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # FastAPI application | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# Metadata for the FastAPI application object, including the docs URLs.
_api_metadata = {
    "title": "Maria Learning Service",
    "description": "AI tutoring API powered by Qwen3-0.6B on CPU.",
    "version": "1.1.0",
    "docs_url": "/docs",
    "redoc_url": "/redoc",
}
_fastapi = FastAPI(**_api_metadata)
class ChatRequest(BaseModel):
    """Request body taken by the ``chat`` handler."""
    # Learner state; ``chat`` reads "board", "class" and "subject" from it
    # and passes the whole dict on to the prompt builder and state updater.
    learning_path: dict[str, Any]
    # Per-turn payload; ``chat`` reads "request_message" from it.
    query: dict[str, Any]
class DatasetRequest(BaseModel):
    """Schema for a dataset request: board, subject and class.

    ``class`` is a Python keyword, so the value is stored as ``class_name``.
    No pydantic alias is declared; use ``model_validate_with_class`` to
    build an instance from a payload that carries the ``"class"`` key.
    """
    board: str
    subject: str
    # Holds the value of the wire-format "class" key.
    class_name: str = ""

    class Config:
        # Allow fields to be populated by their Python attribute names.
        populate_by_name = True

    @classmethod
    def model_validate_with_class(cls, data: dict):
        """Build an instance from *data*, mapping ``"class"`` to ``class_name``.

        Declared as a classmethod so it can be called on the class itself,
        which is what its ``cls`` parameter and ``cls(**data)`` call require.
        *data* is copied first, so the caller's dict is not modified.  When
        both keys are present, ``"class"`` wins.
        """
        data = dict(data)
        if "class" in data:
            data["class_name"] = data.pop("class")
        return cls(**data)
# Registered explicitly: the handler was defined but never attached to the
# application, although the service's route table documents GET /health.
@_fastapi.get("/health")
async def health():
    """Unauthenticated liveness probe; reports the configured model id."""
    return {"status": "ok", "model": LLM_MODEL_ID}
# NOTE(review): no route decorator is visible for this handler and no path
# for it appears in the documented route table -- confirm how it is
# registered on the application.
async def ping(request: Request):
    """Authenticated liveness check, intended to wake the Space if sleeping.

    Responds with ``{"status": "alive"}``, or 403 when authentication fails.
    """
    authorised = await _authenticate(request)
    if not authorised:
        raise HTTPException(status_code=403, detail="Forbidden")
    return JSONResponse(content={"status": "alive"})
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Change 3: /dataset endpoint | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# Registered explicitly: the handler was defined but never attached to the
# application, although the service documents POST /dataset.
@_fastapi.post("/dataset")
async def dataset(request: Request):
    """
    Pre-load the FAISS index, config, and metadata for a given board/class/subject.
    Must be called before /chat. Subsequent calls with the same key are no-ops (cached).

    Request body:
        { "board": "NCERT", "class": "Class 1", "subject": "English" }
    Response:
        { "status": "ready", "message": "Dataset Loaded" }

    Errors: 403 when authentication fails; 422 for a body that is not a JSON
    object or lacks board/class/subject; 500 when the dataset cannot be loaded.
    """
    # Authentication
    if not await _authenticate(request):
        raise HTTPException(status_code=403, detail="Forbidden")

    # The body is parsed by hand because "class" is a reserved word.
    try:
        body = await request.json()
    except Exception:
        raise HTTPException(status_code=422, detail="Invalid JSON body")
    if not isinstance(body, dict):
        # Valid JSON that is not an object (list, string, number) would
        # otherwise fail on .get() below and surface as a 500.
        raise HTTPException(status_code=422, detail="Request body must be a JSON object")

    board = str(body.get("board", "")).strip()
    cls = str(body.get("class", "")).strip()
    subject = str(body.get("subject", "")).strip()
    if not all([board, cls, subject]):
        raise HTTPException(
            status_code=422,
            detail="Request body must contain board, class, and subject",
        )

    key = _dataset_key(board, cls, subject)

    # Return immediately if already cached.
    if key in _dataset_cache:
        log.info("Dataset cache hit: %s", key)
        return JSONResponse({"status": "ready", "message": "Dataset Loaded"})

    # Load and cache. The blocking Hub I/O runs in a worker thread so the
    # event loop stays responsive; the response still waits for completion.
    try:
        config, faiss_index, metadata = await asyncio.to_thread(
            _load_dataset, board, cls, subject
        )
        _dataset_cache[key] = (config, faiss_index, metadata)
        log.info("Dataset cached for key: %s", key)
    except Exception as exc:
        log.error("Dataset load error: %s", exc)
        raise HTTPException(
            status_code=500,
            detail=f"Could not load dataset for {board}/{cls}/{subject}: {exc}",
        ) from exc
    return JSONResponse({"status": "ready", "message": "Dataset Loaded"})
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # /chat endpoint β Change 4: uses dataset preloaded via /dataset | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
import threading as _threading

# Generation is serialised: at most one request uses the embedding model and
# the LLM at a time, exactly as when the handler ran them on the event loop.
_inference_lock = _threading.Lock()


def _generate_reply(lp: dict, msg: str, config: dict, faiss_index, metadata) -> str:
    """Blocking part of /chat: RAG lookup, prompt build and LLM generation."""
    with _inference_lock:
        try:
            rag_chunks = _rag_search(msg, config, faiss_index, metadata)
        except Exception as exc:
            # Retrieval is best effort; the model can still answer without it.
            log.warning("RAG search failed (%s) β continuing without context", exc)
            rag_chunks = []
        system_prompt = _build_system_prompt(lp, rag_chunks)
        return run_inference(system_prompt, f"Student: {msg}")


# Registered explicitly: the handler was defined but never attached to the
# application, although the service documents POST /chat.
@_fastapi.post("/chat")
async def chat(request: Request, body: ChatRequest):
    """Main tutoring endpoint: answer one student message.

    Steps: authenticate, validate the body, fetch the dataset cached by
    POST /dataset, run retrieval + generation, parse the model's JSON,
    synthesise audio, and return the updated learning path with the reply.

    Retrieval, generation and text-to-speech are blocking calls, so they run
    in worker threads (as the dataset handler already does for its I/O);
    run inline they froze the event loop, and with it every other endpoint,
    for the whole duration of a CPU generation.

    Errors: 403 when authentication fails; 422 for an empty message or a
    learning path without board/class/subject; 412 when the dataset has not
    been loaded; 500 when generation fails.
    """
    # 1. Authentication
    if not await _authenticate(request):
        raise HTTPException(status_code=403, detail="Forbidden")

    # 2. Validate request body. Values are arbitrary JSON, so nothing is
    #    assumed to be a string before .strip() is called on it.
    lp = body.learning_path
    message = body.query.get("request_message")
    msg = message.strip() if isinstance(message, str) else ""
    if not msg:
        raise HTTPException(status_code=422, detail="request_message must not be empty")
    board, cls, subject = (
        str(lp.get(field) or "").strip() for field in ("board", "class", "subject")
    )
    if not all([board, cls, subject]):
        raise HTTPException(
            status_code=422,
            detail="learning_path must contain board, class, and subject",
        )

    # 3. Retrieve dataset from cache (POST /dataset must have been called).
    key = _dataset_key(board, cls, subject)
    if key not in _dataset_cache:
        raise HTTPException(
            status_code=412,
            detail=(
                f"Dataset for {board}/{cls}/{subject} is not loaded. "
                "Please call POST /dataset first."
            ),
        )
    config, faiss_index, metadata = _dataset_cache[key]

    # 4-5. RAG retrieval, prompt build and LLM generation, off the event loop.
    try:
        raw_output = await asyncio.to_thread(
            _generate_reply, lp, msg, config, faiss_index, metadata
        )
    except Exception as exc:
        log.error("Inference error: %s", exc)
        raise HTTPException(status_code=500, detail=f"Inference failed: {exc}") from exc

    # 6. Parse structured output. The model may emit a non-object or a
    #    non-string "response"; neither may crash the request.
    parsed = _parse_llm_output(raw_output)
    if not isinstance(parsed, dict):
        parsed = {"response": raw_output, "stage_updates": []}
    response = parsed.get("response")
    ai_text = str(raw_output if response is None else response).strip()

    # 7. Text-to-speech (network call) in a worker thread.
    audio_b64 = await asyncio.to_thread(_tts_to_b64, ai_text)

    # 8. Update learning path state.
    updated_lp = _apply_state_updates(lp, parsed, msg, ai_text)

    # 9. Return response.
    return JSONResponse({
        "learning_path": updated_lp,
        "query": {
            "response_message": {
                "text": ai_text,
                "visual": "No",
                "visual_content": "",
                "audio_output": audio_b64,
            }
        },
    })
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Gradio shim | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# Minimal Gradio page: a single Markdown panel describing the REST API.
with gr.Blocks(title="Maria Learning Service") as _gradio_ui:
    gr.Markdown(
        """
## Maria Learning Service
This Space exposes a **REST API** β it is not a chat UI.
| Endpoint | Method | Description |
|-----------|--------|------------------------------------|
| `/dataset`| POST | Pre-load dataset (call before chat)|
| `/chat` | POST | Main tutoring endpoint |
| `/health` | GET | Health check |
| `/docs` | GET | Swagger UI |
Authenticate via `auth_code` header or `cf-turnstile-token` header.
"""
    )
# Mount the Gradio UI at /ui; the FastAPI routes stay at root level.
# ``app`` is the object the entry point serves ("app:app").
app = gr.mount_gradio_app(_fastapi, _gradio_ui, path="/ui")
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Entry point | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
if __name__ == "__main__":
    import uvicorn

    # One worker only: the model object lives in this process's memory.
    server_options = {
        "host": "0.0.0.0",
        "port": 7860,
        "log_level": "info",
        "workers": 1,
    }
    uvicorn.run("app:app", **server_options)