mariatutor / app.py
digifreely's picture
Update app.py
9e816bb verified
"""
Maria AI Tutor — FastAPI Backend
Hugging Face Spaces (Docker)
"""
import asyncio
import base64
import hashlib
import hmac
import io
import ipaddress
import json
import logging
import os
import re
import shutil
import struct
import tempfile
import threading
import urllib.parse
import urllib.request
import wave
from contextlib import asynccontextmanager
from typing import Optional, List, Dict, Any

import torch
import numpy as np
import pandas as pd
import faiss
import httpx
from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, ConfigDict
from transformers import AutoModelForCausalLM, AutoTokenizer
# datasets import removed — HF dataset not used at runtime
from sentence_transformers import SentenceTransformer
# ── Logging ────────────────────────────────────────────────────────────────────
# Configure the root logger once at import time: INFO level, timestamped records.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
# Module-level logger used by every function in this file.
log = logging.getLogger(__name__)
# ── Config ─────────────────────────────────────────────────────────────────────
# Directory containing this file; bundled assets are resolved relative to it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Piper voice model; load_piper() disables audio when this file is missing.
PIPER_MODEL_PATH = os.path.join(BASE_DIR, "models", "en_US-lessac-medium.onnx")
# Hugging Face model ids: chat/tutoring model, HTML-visual generator, RAG query embedder.
INSTRUCT_MODEL = "Qwen/Qwen2.5-1.5B-Instruct"
CODER_MODEL = "Qwen/Qwen2.5-Coder-1.5B-Instruct"
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
# NOTE: DATASET_NAME is not referenced anywhere else in the visible code.
DATASET_NAME = "digifreely/Maria"
# Base URL under which _ensure_faiss() fetches per-subject index and metadata files.
DATASET_BASE_URL = "https://huggingface.co/datasets/digifreely/Maria/resolve/main"
# Number of chat-history / scratchpad entries kept after each turn (see process()).
MAX_HISTORY = 20
# Default new-token budget for generate().
GEN_MAX_TOKENS = 220
# Secrets and deployment settings, read once from the environment at import time.
# HF_TOKEN: optional bearer token for model and dataset downloads.
HF_TOKEN = os.environ.get("HF_TOKEN", "")
# EXPECTED_HASH: SHA-256 hex digest that check_auth_code() compares against.
EXPECTED_HASH = os.environ.get("EXPECTED_HASH", "")
# CF_TURNSTILE_SECRET: when empty, verify_turnstile() accepts every token.
CF_TURNSTILE_SECRET = os.environ.get("CF_TURNSTILE_SECRET", "")
# CF_API_TOKEN / CF_ZONE_ID: both must be set for block_ip() to do anything.
CF_API_TOKEN = os.environ.get("CF_API_TOKEN", "")
CF_ZONE_ID = os.environ.get("CF_ZONE_ID", "")
# ALLOWED_DOMAIN: the single https origin allowed by the CORS middleware.
ALLOWED_DOMAIN = os.environ.get("ALLOWED_DOMAIN", "buildwithsupratim.github.io")
# ── Global State ───────────────────────────────────────────────────────────────
# Process-wide caches, filled at startup or lazily on first use.
# Causal LMs and their tokenizers, keyed by "instruct" / "coder".
_models: Dict[str, Any] = {}
_tokenizers: Dict[str, Any] = {}
# SentenceTransformer instance; None until load_all_models() has run.
_embed: Any = None
# FAISS indexes and their metadata frames, both keyed by "board/class/subject".
_faiss: Dict[str, Any] = {}
_meta: Dict[str, Any] = {}
# Piper voice; stays None when TTS is unavailable.
_piper: Any = None
# ══════════════════════════════════════════════════════════════════════════════
# Pydantic Schemas
# ══════════════════════════════════════════════════════════════════════════════
class ScratchpadItem(BaseModel):
    """One agent-trace entry; ``process`` appends one per handled message."""

    chat_id: int  # turn id; process() uses the same id for the chat-history entry
    thought: str = ""  # short note returned by the handler that ran
    action: str = ""  # process() writes "block", "question", "curriculum" or "chitchat"
    action_input: str = ""  # process() stores the student's message here
    observation: str = ""  # brief outcome text written by process()
class ChatHistoryItem(BaseModel):
    """One completed exchange between the student and the tutor."""

    chat_id: int  # turn id; process() assigns max(existing ids) + 1
    user_input: str  # the student's message
    system_output: str  # the tutor's reply text
class LearningObjectiveStatus(BaseModel):
    """Progress of one learning goal through the four teaching steps.

    fn_curriculum treats a step as pending while its value is exactly
    "Not_Complete" and overwrites it with "complete" when it auto-advances.
    """

    goal: str  # objective text, used in prompts and RAG queries
    teach: str = "Not_Complete"
    re_teach: str = "Not_Complete"
    show_and_tell: str = "Not_Complete"
    assess: str = "Not_Complete"
class CurrentLearningItem(BaseModel):
    """A topic currently being taught, with per-objective progress."""

    topic: str  # topic name; inserted into prompts and RAG queries
    content: str  # lesson content passed to the model as "Content:"
    learning_objectives: List[LearningObjectiveStatus] = []
class AssessmentStages(BaseModel):
    """Container for the topics in progress; an empty list means nothing is left to teach."""

    current_learning: List[CurrentLearningItem] = []
class CurriculumObjective(BaseModel):
    """Curriculum entry carried in the learning path.

    NOTE: none of the handlers visible in this file read these fields; they
    are validated and echoed back to the client unchanged.
    """

    topics: str
    content: str
    learning_objectives: List[str] = []
class LearningPath(BaseModel):
    """Full per-student session state, sent by the client and returned updated."""

    # Allow population by field name ("class_name") as well as by alias ("class").
    model_config = ConfigDict(populate_by_name=True)
    board: str  # education board; first path segment of the knowledge-base lookup
    # "class" is a Python keyword, so the JSON key is mapped through an alias.
    class_name: str = Field(alias="class")
    subject: str
    student_name: str  # interpolated into the tutor prompts
    teacher_persona: str  # interpolated into the tutor prompts as the "Style"
    curriculum_objectives: List[CurriculumObjective] = []
    chat_history: List[ChatHistoryItem] = []  # process() trims this to MAX_HISTORY entries
    scratchpad: List[ScratchpadItem] = []  # process() trims this to MAX_HISTORY entries
    assessment_stages: Optional[AssessmentStages] = None
class QueryIn(BaseModel):
    """Incoming query wrapper holding the student's message."""

    request_message: str
class ChatRequest(BaseModel):
    """Request body of POST /chatmessenger: session state plus the new message."""

    learning_path: LearningPath
    query: QueryIn
class ResponseMessage(BaseModel):
    """The tutor's reply in its text, visual and audio forms."""

    text: str
    visual: str = "No"  # process() sets "Yes" when visual_content is non-empty
    visual_content: str = ""  # HTML produced by fn_visualconstruct
    audio_output: str = ""  # base64 WAV from tts_base64; empty when TTS is unavailable
class QueryOut(BaseModel):
    """Outgoing query wrapper holding the tutor's reply."""

    response_message: ResponseMessage
class ChatResponse(BaseModel):
    """Result of process(): updated session state plus the reply.

    The /chatmessenger route dumps it with ``by_alias=True``.
    """

    model_config = ConfigDict(populate_by_name=True)
    learning_path: LearningPath
    query: QueryOut
# ══════════════════════════════════════════════════════════════════════════════
# Model / Dataset Loading (called once at startup)
# ══════════════════════════════════════════════════════════════════════════════
def _load_transformer(name: str, key: str) -> None:
    """Load a causal LM and its tokenizer and register them under ``key``.

    On CUDA the weights are loaded 4-bit (NF4, double quantisation) through
    bitsandbytes. On CPU the model is loaded in float32 and then quantised to
    int4 with optimum-quanto on a best-effort basis; if that import or
    quantisation fails, the float32 model is kept.

    Args:
        name: Hugging Face model id.
        key: Slot name used in ``_tokenizers`` and ``_models``.
    """
    log.info(f"Loading {key}: {name}")
    auth_token = HF_TOKEN or None  # empty string means anonymous access
    _tokenizers[key] = AutoTokenizer.from_pretrained(name, token=auth_token)
    common_kw: Dict[str, Any] = {"token": auth_token}

    if torch.cuda.is_available():
        from transformers import BitsAndBytesConfig

        quant_cfg = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
        model = AutoModelForCausalLM.from_pretrained(
            name, quantization_config=quant_cfg, device_map="auto", **common_kw
        )
        log.info(f"{key} → GPU int4 (bitsandbytes)")
    else:
        model = AutoModelForCausalLM.from_pretrained(
            name, torch_dtype=torch.float32, **common_kw
        )
        try:
            from optimum.quanto import quantize, qint4, freeze

            quantize(model, weights=qint4)
            freeze(model)
            log.info(f"{key} → CPU int4 (quanto)")
        except Exception as exc:
            # Deliberate best effort: quanto may be missing or may reject the model.
            log.warning(f"{key} → CPU float32 fallback ({exc})")

    model.eval()
    _models[key] = model
    log.info(f"{key} ready")
def load_all_models() -> None:
    """Load both language models, then the sentence-embedding model."""
    global _embed
    for model_id, slot in ((INSTRUCT_MODEL, "instruct"), (CODER_MODEL, "coder")):
        _load_transformer(model_id, slot)
    log.info("Loading embedding model…")
    _embed = SentenceTransformer(EMBED_MODEL)
    log.info("Embedding model ready")
def load_piper() -> None:
    """Load the Piper TTS voice into ``_piper`` if it is available.

    Failures are never fatal: a missing package, a missing model file or a
    load error leaves ``_piper`` untouched and logs a warning, which makes
    ``tts_base64`` return an empty string.
    """
    global _piper
    try:
        from piper.voice import PiperVoice

        if not os.path.exists(PIPER_MODEL_PATH):
            log.warning("Piper .onnx not found — audio disabled")
            return
        _piper = PiperVoice.load(PIPER_MODEL_PATH)
        log.info("Piper TTS ready")
    except Exception as exc:
        # Covers both the optional import and any error while loading the voice.
        log.warning(f"Piper unavailable: {exc}")
# ══════════════════════════════════════════════════════════════════════════════
# FastAPI lifespan
# ══════════════════════════════════════════════════════════════════════════════
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load all models and the Piper voice before the app starts serving.

    Both loaders are synchronous, so startup blocks until they return.
    Nothing follows the ``yield``: no shutdown cleanup is performed.
    """
    load_all_models()
    load_piper()
    yield
# Application instance; models are loaded through the lifespan hook above.
app = FastAPI(title="Maria AI Tutor", version="1.0.0", lifespan=lifespan)
# CORS: only the configured front-end origin (https) may call the API from a
# browser, with GET/POST and any request header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[f"https://{ALLOWED_DOMAIN}"],
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)
# ══════════════════════════════════════════════════════════════════════════════
# Inference helper
# ══════════════════════════════════════════════════════════════════════════════
def generate(key: str, system: str, user: str, max_tokens: int = GEN_MAX_TOKENS) -> str:
    """Run one sampled chat completion and return only the newly generated text.

    Args:
        key: Which loaded model to use (a key of ``_models`` / ``_tokenizers``).
        system: System prompt.
        user: User prompt.
        max_tokens: Upper bound on newly generated tokens.

    Returns:
        The decoded completion with special tokens removed and whitespace stripped.
    """
    tokenizer, model = _tokenizers[key], _models[key]
    chat = [
        {"role": "system", "content": system},
        {"role": "user", "content": user},
    ]
    prompt_text = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )
    encoded = tokenizer(prompt_text, return_tensors="pt")
    # Best effort: put the inputs on the same device as the model's first parameter.
    try:
        target = next(model.parameters()).device
        encoded = {field: tensor.to(target) for field, tensor in encoded.items()}
    except Exception:
        pass  # device_map="auto" handles it
    prompt_len = encoded["input_ids"].shape[1]
    sampling = {
        "max_new_tokens": max_tokens,
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.9,
        "pad_token_id": tokenizer.eos_token_id,
    }
    with torch.no_grad():
        sequences = model.generate(**encoded, **sampling)
    # Drop the echoed prompt; keep only what the model added.
    completion = sequences[0][prompt_len:]
    return tokenizer.decode(completion, skip_special_tokens=True).strip()
# ══════════════════════════════════════════════════════════════════════════════
# RAG / FAISS
# ══════════════════════════════════════════════════════════════════════════════
def _ensure_faiss(board: str, cls: str, subject: str):
    """Return the cached ``(index, metadata)`` pair for a subject, downloading on first use.

    The two files are fetched from
    ``{DATASET_BASE_URL}/knowledgebase/<board>/<class>/<subject>/`` with each
    path segment URL-quoted. Local temp-file names are derived from a hash of
    the cache key, so request-supplied values (which may contain ``/`` or
    ``..``) never become part of a filesystem path.

    Args:
        board: Education board segment.
        cls: Class/grade segment.
        subject: Subject segment.

    Returns:
        ``(faiss_index, metadata_frame)``, or ``(None, None)`` when download or
        parsing fails. Failures are logged and not cached, so a later call retries.
    """
    key = f"{board}/{cls}/{subject}"
    if key in _faiss:
        return _faiss[key], _meta[key]
    try:
        segments = [urllib.parse.quote(part, safe="") for part in (board, cls, subject)]
        base = "/".join([DATASET_BASE_URL, "knowledgebase", *segments])
        headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}

        def _download(url: str, dest: str) -> None:
            # Stream to disk instead of holding the whole file in memory.
            req = urllib.request.Request(url, headers=headers)
            with urllib.request.urlopen(req, timeout=30) as resp, open(dest, "wb") as out:
                shutil.copyfileobj(resp, out)

        digest = hashlib.sha256(key.encode("utf-8")).hexdigest()[:32]
        tmp_dir = tempfile.gettempdir()
        index_path = os.path.join(tmp_dir, f"fi_{digest}.bin")
        meta_path = os.path.join(tmp_dir, f"mt_{digest}.parquet")
        _download(f"{base}/faiss_index.bin", index_path)
        _download(f"{base}/metadata.parquet", meta_path)
        index = faiss.read_index(index_path)
        meta = pd.read_parquet(meta_path)
        # Fill _meta first: the cache-hit check above only looks at _faiss.
        _meta[key] = meta
        _faiss[key] = index
        log.info(f"FAISS loaded for {key}")
        return index, meta
    except Exception as e:
        log.error(f"FAISS load error [{key}]: {e}")
        return None, None
def rag_search(board: str, cls: str, subject: str, query: str, top_k: int = 3) -> str:
    """Return the ``top_k`` nearest knowledge-base passages for ``query``.

    Passages are joined with ``"\\n---\\n"``. An empty string is returned when
    the embedder or the subject's index is unavailable.
    """
    if _embed is None:
        return ""
    index, meta = _ensure_faiss(board, cls, subject)
    if index is None:
        return ""

    query_vec = _embed.encode([query])[0].astype("float32").reshape(1, -1)
    faiss.normalize_L2(query_vec)
    _, hits = index.search(query_vec, top_k)

    # The text column is the same for every row: first known name that exists,
    # otherwise fall back to the frame's first column.
    known_names = ("text", "content", "chunk", "passage")
    chosen = next((name for name in known_names if name in meta.columns), None)

    passages = []
    for pos in hits[0]:
        if not 0 <= pos < len(meta):
            continue  # skip positions outside the metadata frame
        record = meta.iloc[pos]
        value = record.iloc[0] if chosen is None else record[chosen]
        passages.append(str(value))
    return "\n---\n".join(passages)
# ══════════════════════════════════════════════════════════════════════════════
# TTS
# ══════════════════════════════════════════════════════════════════════════════
def tts_base64(text: str) -> str:
    """Synthesize ``text`` with Piper and return the WAV file base64-encoded.

    Returns an empty string when no voice is loaded, when ``text`` is empty,
    or when synthesis raises (the error is logged).
    """
    if not (_piper is not None and text):
        return ""
    try:
        wav_bytes = io.BytesIO()
        writer = wave.open(wav_bytes, "wb")
        with writer:
            _piper.synthesize(text, writer)
        encoded = base64.b64encode(wav_bytes.getvalue())
        return encoded.decode()
    except Exception as e:
        log.error(f"TTS error: {e}")
        return ""
# ══════════════════════════════════════════════════════════════════════════════
# Security
# ══════════════════════════════════════════════════════════════════════════════
def check_auth_code(code: str) -> bool:
    """Return True when the SHA-256 hex digest of ``code`` equals ``EXPECTED_HASH``.

    Always False when no expected hash is configured or ``code`` is empty.
    The comparison uses ``hmac.compare_digest`` so its duration does not
    reveal how many leading characters matched.
    """
    if not EXPECTED_HASH or not code:
        return False
    digest = hashlib.sha256(code.encode()).hexdigest()
    # Compare as bytes: compare_digest rejects non-ASCII str arguments.
    return hmac.compare_digest(digest.encode("utf-8"), EXPECTED_HASH.encode("utf-8"))
async def verify_turnstile(token: str, ip: str) -> bool:
    """Validate a Cloudflare Turnstile token against the siteverify endpoint.

    Args:
        token: Token supplied by the client.
        ip: Client address, forwarded as ``remoteip``.

    Returns:
        True when Cloudflare reports ``"success": true``. Also True when no
        secret is configured (verification is then disabled by design).
        False on any network, HTTP-parsing or payload error; the error is logged.
    """
    if not CF_TURNSTILE_SECRET:
        return True  # secret not configured → allow
    try:
        async with httpx.AsyncClient(timeout=10) as client:
            resp = await client.post(
                "https://challenges.cloudflare.com/turnstile/v0/siteverify",
                data={"secret": CF_TURNSTILE_SECRET, "response": token, "remoteip": ip},
            )
        # Only a literal JSON true counts; anything else is a failed check.
        return resp.json().get("success") is True
    except Exception as exc:
        log.warning(f"Turnstile verification error: {exc}")
        return False
async def block_ip(ip: str) -> None:
    """Ask Cloudflare to add a zone-level "block" access rule for ``ip``.

    Does nothing when the API token or zone id is not configured. Values that
    are not a valid IP address, and the unspecified address (``0.0.0.0`` /
    ``::``, which the chat route uses when the client is unknown), are skipped
    instead of being sent to Cloudflare. All failures, including non-2xx
    responses, are logged and never raised.

    NOTE(review): behind a reverse proxy the caller may pass the proxy's
    address rather than the end user's — confirm how the client IP is derived.
    """
    if not CF_API_TOKEN or not CF_ZONE_ID:
        return
    try:
        parsed = ipaddress.ip_address(ip)
    except ValueError:
        log.warning(f"IP block skipped: {ip!r} is not a valid IP address")
        return
    if parsed.is_unspecified:
        log.warning("IP block skipped: client address is unknown")
        return
    try:
        async with httpx.AsyncClient(timeout=10) as client:
            resp = await client.post(
                f"https://api.cloudflare.com/client/v4/zones/{CF_ZONE_ID}/firewall/access_rules/rules",
                headers={"Authorization": f"Bearer {CF_API_TOKEN}", "Content-Type": "application/json"},
                json={
                    "mode": "block",
                    "configuration": {"target": "ip", "value": ip},
                    "notes": "Auto-blocked by Maria AI",
                },
            )
            # Surface API rejections (bad token, duplicate rule, ...) in the log.
            resp.raise_for_status()
    except Exception as e:
        log.error(f"IP block failed: {e}")
# ══════════════════════════════════════════════════════════════════════════════
# Agent helpers
# ══════════════════════════════════════════════════════════════════════════════
def _context(lp: LearningPath) -> str:
    """Render the last 8 chat turns and last 4 scratchpad actions as prompt text."""
    turns = []
    for entry in lp.chat_history[-8:]:
        turns.append(f"Student: {entry.user_input}\nTeacher: {entry.system_output}")

    notes = []
    for item in lp.scratchpad[-4:]:
        if not item.action:
            continue  # entries without an action carry nothing to show
        notes.append(f"[{item.action}] {item.observation}")

    hist = "\n".join(turns)
    pad = "\n".join(notes)
    return f"Chat History:\n{hist}\n\nScratchpad:\n{pad}"
def _current_topic(lp: LearningPath) -> str:
    """Return the first in-progress topic, or a generic phrase when there is none."""
    stages = lp.assessment_stages
    if not stages or not stages.current_learning:
        return "the current lesson"
    return stages.current_learning[0].topic
# ── fn_brain ──────────────────────────────────────────────────────────────────
def fn_brain(lp: LearningPath, msg: str) -> str:
    """Classify the student's message into one of the four routes.

    The instruct model is asked for a single word; the first route name found
    in its lower-cased answer wins, checked in the order block, question,
    curriculum, chitchat.

    Returns:
        "block", "question", "curriculum" or "chitchat". Falls back to
        "curriculum" when the answer contains none of them.
    """
    system = (
        "You are a routing decision maker for a children's educational AI tutor. "
        "Choose EXACTLY ONE word from: Block, Question, Curriculum, ChitChat\n\n"
        "Block — disrespectful, inappropriate, adult, abusive content\n"
        "Question — curiosity question clearly outside the current curriculum\n"
        "Curriculum — the student is engaging with lesson content, asking about it, or ready to learn\n"
        "ChitChat — casual talk, greetings, sharing feelings, general conversation\n\n"
        "Reply with ONLY one word."
    )
    ctx = _context(lp)
    user = f"Current topic: {_current_topic(lp)}\n{ctx}\n\nStudent: {msg}"
    decision = generate("instruct", system, user, max_tokens=5).strip().lower()
    routes = ("block", "question", "curriculum", "chitchat")
    return next((route for route in routes if route in decision), "curriculum")  # safe default
# ── fn_block ──────────────────────────────────────────────────────────────────
# Fixed reply once a student has been blocked repeatedly.
_ESCALATION_MESSAGE = "Please show this chat to your parent or teacher."


def fn_block(lp: LearningPath, msg: str) -> tuple[str, str]:
    """Respond to an inappropriate message.

    Counts the scratchpad entries whose action is "block" (the scratchpad is
    trimmed to MAX_HISTORY elsewhere, so this covers the retained window).
    With three or more prior blocks the fixed escalation sentence is returned
    directly, without asking the model: a sampled generation cannot be relied
    on to say exactly that and nothing else. Otherwise the instruct model
    produces a short, gentle redirect back to the current topic.

    Returns:
        ``(reply_text, thought)``.
    """
    thought = "inappropriate content — blocking"
    block_count = sum(1 for s in lp.scratchpad if s.action == "block")
    if block_count >= 3:
        return _ESCALATION_MESSAGE, thought

    system = (
        f"You are a teacher for children aged 6-12. Student: {lp.student_name}. "
        f"Style: {lp.teacher_persona}. "
        "The student said something inappropriate or disrespectful. "
        "Respond with gentle humor/sarcasm to discourage the behavior — never be rude or arrogant. "
        f"Try to redirect back to: {_current_topic(lp)}. Keep to 2-3 lines max."
    )
    response = generate("instruct", system, f"{_context(lp)}\nStudent: {msg}", max_tokens=120)
    return response, thought
# ── fn_chitchat ───────────────────────────────────────────────────────────────
def fn_chitchat(lp: LearningPath, msg: str) -> tuple[str, str]:
    """Reply briefly to casual talk and steer the student back to the lesson.

    Returns:
        ``(reply_text, thought)``.
    """
    system = (
        f"You are a friendly teacher for children aged 6-12. "
        f"Student: {lp.student_name}. Style: {lp.teacher_persona}. "
        "Engage warmly and briefly, then gently guide back to the lesson. "
        f"Current topic: {_current_topic(lp)}. Keep to 2-3 lines. Age-appropriate only."
    )
    response = generate("instruct", system, f"{_context(lp)}\nStudent: {msg}", max_tokens=150)
    return response, "chitchat — engaging and redirecting"
# ── fn_question ───────────────────────────────────────────────────────────────
def fn_question(lp: LearningPath, msg: str) -> tuple[str, str]:
    """Answer an off-curriculum question, using retrieved passages when available.

    The student's message is used as the retrieval query; any passages found
    are appended to the prompt under "Reference material:".

    Returns:
        ``(reply_text, thought)``; the thought records whether retrieval
        returned anything.
    """
    rag = rag_search(lp.board, lp.class_name, lp.subject, msg)
    ref = f"\n\nReference material:\n{rag}" if rag else ""
    system = (
        f"You are a teacher for children aged 6-12. "
        f"Student: {lp.student_name}. Style: {lp.teacher_persona}. "
        "Answer the student's question using the reference material if available. "
        "If you don't have an answer, say so kindly. "
        "After answering, gently nudge the student back to the current lesson. "
        "Keep to 2-3 lines. Age-appropriate only."
    )
    user = f"{_context(lp)}{ref}\n\nStudent: {msg}"
    response = generate("instruct", system, user, max_tokens=200)
    thought = f"question answered — RAG {'found' if rag else 'empty'}"
    return response, thought
# ── fn_visualconstruct ────────────────────────────────────────────────────────
# Greedy on purpose: spans from the first <div ...> to the LAST </div>, so
# nested <div> elements inside the visual are kept intact.
_DIV_BLOCK_RE = re.compile(r"<div\b[^>]*>.*</div>", re.DOTALL | re.IGNORECASE)


def fn_visualconstruct(instruction: str) -> str:
    """Ask the coder model for a small HTML visual and return it as a ``<div>``.

    If the output contains a ``<div>...</div>`` span, everything from the
    first opening tag to the last closing tag is returned. Otherwise markdown
    fences are stripped and the remaining text is wrapped in a ``<div>``.

    NOTE(review): the returned markup is model-generated HTML/JS; whatever
    renders it should treat it as untrusted content.
    """
    system = (
        "You are an HTML/CSS/JavaScript developer making simple educational visuals for children aged 6-12. "
        "Generate ONLY executable code inside a single <div> tag. "
        "No markdown fences. No explanations. Use only inline styles and vanilla JS. "
        "Format: <div> ... your HTML/CSS/JS ... </div>"
    )
    user = f"Create a very simple visual for: {instruction}"
    result = generate("coder", system, user, max_tokens=600)
    match = _DIV_BLOCK_RE.search(result)
    if match:
        return match.group(0)
    # No complete div: wrap whatever came back.
    cleaned = re.sub(r"```[a-z]*", "", result).strip("`").strip()
    return f"<div>{cleaned}</div>"
# ── fn_curriculum ─────────────────────────────────────────────────────────────
# Words and phrases that, when present in a reply, mark the current step as done.
_POSITIVE = {"understand", "great", "excellent", "well done", "correct", "move on",
             "next", "sure", "yes", "okay", "got it", "right", "nice"}
_WORD_RE = re.compile(r"[a-z']+")


def _step_complete(response: str) -> bool:
    """Return True when ``response`` contains any ``_POSITIVE`` word or phrase.

    Matching is case-insensitive, ignores punctuation and only matches whole
    words, so "Great!" and "Well done, Sam" match while "greatness" does not.
    Multi-word phrases must appear as consecutive words.
    """
    tokens = _WORD_RE.findall(response.lower())
    if not tokens:
        return False
    # Pad with spaces so phrase checks cannot match inside a longer word.
    normalized = " " + " ".join(tokens) + " "
    return any(f" {phrase} " in normalized for phrase in _POSITIVE)
# Teaching steps in the order they are worked through for each objective.
_STEP_ORDER = ("teach", "re_teach", "show_and_tell", "assess")

# Per step: (label used in the thought, instruction appended to the persona prompt).
_STEP_PROMPTS = {
    "teach": (
        "teaching",
        "Teach the learning objective simply and clearly using the reference material. "
        "If the student already seems to understand from the chat history, say so and offer to move forward.",
    ),
    "re_teach": (
        "re-teaching",
        "Briefly summarize what was taught, then ask the student one simple question to check understanding.",
    ),
    "show_and_tell": (
        "show and tell",
        "Give a very simple show-and-tell explanation in 2-3 lines to make the concept visual and fun.",
    ),
    "assess": (
        "assessing",
        "Ask if the student has any questions. If not, ask one simple question to check understanding. "
        "If the chat history shows they understood, praise them and say you'll move to the next objective.",
    ),
}


def _all_done_result(student_name: str) -> tuple[str, str, str, str]:
    """Build the result returned when no teaching step is left."""
    text = (
        f"Congratulations {student_name}! You are done for the day! 🎉 "
        "Have a great time with your friends, parents, and teacher. "
        "You're the best kid I ever taught!"
    )
    return text, "all objectives done", "", "done"


def _first_incomplete_step(current_learning):
    """Return ``(topic_item, objective, step)`` for the first "Not_Complete" step, else None."""
    for item in current_learning:
        for objective in item.learning_objectives:
            for step in _STEP_ORDER:
                if getattr(objective, step) == "Not_Complete":
                    return item, objective, step
    return None


def fn_curriculum(lp: LearningPath, msg: str) -> tuple[str, str, str, str]:
    """Advance the lesson by one turn on the first incomplete teaching step.

    Topics, objectives and steps are scanned in order; the first step still
    marked "Not_Complete" is taught with a step-specific prompt plus any
    retrieved reference material. The show_and_tell step also produces an HTML
    visual. If the generated reply passes ``_step_complete``, that step is set
    to "complete" on ``lp`` (the learning path is mutated in place).

    Returns:
        ``(text, thought, visual_content, step_name)``. When there is nothing
        left to teach, a congratulation message is returned with step "done".
    """
    stages = lp.assessment_stages
    if not stages or not stages.current_learning:
        return _all_done_result(lp.student_name)

    pending = _first_incomplete_step(stages.current_learning)
    if pending is None:
        return _all_done_result(lp.student_name)
    active_cl, active_obj, active_step = pending

    rag = rag_search(lp.board, lp.class_name, lp.subject,
                     f"{active_cl.topic} {active_obj.goal}")
    ref = f"\n\nReference material:\n{rag}" if rag else ""
    ctx = _context(lp)
    base_info = (
        f"Topic: {active_cl.topic}\n"
        f"Objective: {active_obj.goal}\n"
        f"Content: {active_cl.content}{ref}\n{ctx}\nStudent: {msg}"
    )
    persona = (
        f"You are a teacher for children aged 6-12. "
        f"Student: {lp.student_name}. Style: {lp.teacher_persona}. "
        "Keep responses to 2-3 lines max. Age-appropriate, simple, friendly."
    )

    label, instruction = _STEP_PROMPTS[active_step]
    response = generate("instruct", f"{persona} {instruction}", base_info, max_tokens=220)
    thought = f"{label}: {active_obj.goal}"

    visual_content = ""
    if active_step == "show_and_tell":
        vis_prompt = (
            f"Very simple educational visual for a 6-12 year old child about: "
            f"{active_obj.goal} (topic: {active_cl.topic}). Keep it extremely simple."
        )
        visual_content = fn_visualconstruct(vis_prompt)

    # Auto-advance when the tutor's own reply signals completion.
    if _step_complete(response):
        setattr(active_obj, active_step, "complete")
    return response, thought, visual_content, active_step
# ══════════════════════════════════════════════════════════════════════════════
# Core request processor
# ══════════════════════════════════════════════════════════════════════════════
def process(req: ChatRequest) -> ChatResponse:
    """Handle one chat turn: route the message, reply, and update session state.

    The learning path inside ``req`` is mutated in place: the new exchange and
    a scratchpad entry are appended, then both lists are trimmed to the last
    ``MAX_HISTORY`` entries. The reply text is also synthesized to audio.
    """
    lp = req.learning_path
    msg = req.query.request_message
    next_id = 1 + max((h.chat_id for h in lp.chat_history), default=0)

    route = fn_brain(lp, msg)
    visual_content = ""
    if route == "curriculum":
        text, thought, visual_content, step = fn_curriculum(lp, msg)
        action, observation = "curriculum", f"curriculum step: {step}"
    else:
        # Handlers that return just (text, thought); anything unknown is chitchat.
        two_part = {
            "block": (fn_block, "blocked inappropriate content"),
            "question": (fn_question, "answered off-curriculum question"),
        }
        if route in two_part:
            action = route
            handler, observation = two_part[route]
        else:
            action = "chitchat"
            handler, observation = fn_chitchat, "engaged in casual conversation"
        text, thought = handler(lp, msg)
    visual = "Yes" if visual_content else "No"

    # Record the turn, then keep only the most recent MAX_HISTORY entries.
    turn = ChatHistoryItem(chat_id=next_id, user_input=msg, system_output=text)
    lp.chat_history.append(turn)
    trace = ScratchpadItem(
        chat_id=next_id,
        thought=thought,
        action=action,
        action_input=msg,
        observation=observation,
    )
    lp.scratchpad.append(trace)
    lp.chat_history = lp.chat_history[-MAX_HISTORY:]
    lp.scratchpad = lp.scratchpad[-MAX_HISTORY:]

    audio = tts_base64(text)
    message = ResponseMessage(
        text=text,
        visual=visual,
        visual_content=visual_content,
        audio_output=audio,
    )
    return ChatResponse(learning_path=lp, query=QueryOut(response_message=message))
# ══════════════════════════════════════════════════════════════════════════════
# Routes
# ══════════════════════════════════════════════════════════════════════════════
@app.get("/health")
async def health():
    """Unauthenticated liveness probe that also lists the loaded model slots."""
    loaded = list(_models)
    return {"status": "ok", "models_loaded": loaded}
@app.get("/PING")
async def ping(request: Request):
    """Authenticated liveness check; responds 401 unless the ``auth_code`` header is valid."""
    supplied = request.headers.get("auth_code", "")
    if check_auth_code(supplied):
        return {"message": "alive"}
    raise HTTPException(status_code=401, detail="Unauthorized")
# Only one inference pipeline runs at a time, exactly as when process() ran
# inline on the event loop; the lock keeps that guarantee now that it runs in
# worker threads.
_INFERENCE_LOCK = threading.Lock()


def _process_serialized(body: ChatRequest) -> ChatResponse:
    """Run ``process`` while holding the global inference lock."""
    with _INFERENCE_LOCK:
        return process(body)


@app.post("/chatmessenger", response_model=None)
async def chatmessenger(request: Request, body: ChatRequest):
    """Authenticate the caller, then run one tutoring turn.

    Authentication uses the ``auth_code`` header if present, otherwise the
    ``cf-turnstile-token`` header. A failed or missing credential triggers a
    Cloudflare IP block request and a 401/403 response.

    The model pipeline is synchronous and slow, so it runs in a worker thread;
    the event loop stays free to answer /health and /PING during inference.

    NOTE(review): ``request.client.host`` may be a proxy address when the app
    is deployed behind one — confirm before relying on IP blocking.
    """
    client_ip = request.client.host if request.client else "0.0.0.0"
    auth_code = request.headers.get("auth_code", "")
    ts_token = request.headers.get("cf-turnstile-token", "")

    if auth_code:
        if not check_auth_code(auth_code):
            await block_ip(client_ip)
            raise HTTPException(status_code=401, detail="Unauthorized")
    elif ts_token:
        if not await verify_turnstile(ts_token, client_ip):
            await block_ip(client_ip)
            raise HTTPException(status_code=403, detail="Turnstile verification failed")
    else:
        await block_ip(client_ip)
        raise HTTPException(status_code=403, detail="Authentication required")

    result = await asyncio.to_thread(_process_serialized, body)
    return result.model_dump(by_alias=True)