github-actions committed
Commit 06ca3b1 · 0 parent(s)

Deploy render worker from GitHub Actions

Note: this view is limited to 50 files because the commit contains too many changes.
Dockerfile ADDED
```dockerfile
# Same runtime as API; runs health endpoint + Celery worker (see worker_health.py)
FROM python:3.11-slim-bookworm

ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1 \
    PIP_NO_CACHE_DIR=1 \
    PIP_ROOT_USER_ACTION=ignore \
    NO_ALBUMENTATIONS_UPDATE=1 \
    OMP_NUM_THREADS=1 \
    MKL_NUM_THREADS=1 \
    OPENBLAS_NUM_THREADS=1

WORKDIR /app
ENV PYTHONPATH=/app

RUN apt-get update && apt-get install -y --no-install-recommends \
    ffmpeg \
    pkg-config \
    cmake \
    libcairo2 \
    libcairo2-dev \
    libpango-1.0-0 \
    libpango1.0-dev \
    libpangocairo-1.0-0 \
    libgdk-pixbuf-2.0-0 \
    libffi-dev \
    python3-dev \
    texlive-latex-base \
    texlive-fonts-recommended \
    texlive-latex-extra \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

COPY requirements.worker-render.txt .
RUN pip install --upgrade pip setuptools wheel \
    && pip install -r requirements.worker-render.txt

COPY . .

RUN python scripts/prewarm_render_worker.py

ENV PORT=7860 \
    CELERY_WORKER_QUEUES=render
EXPOSE 7860

ENTRYPOINT []
CMD ["sh", "-c", "exec python3 -u worker_health.py"]
```
Dockerfile.worker.ocr ADDED
```dockerfile
# Celery worker: OCR queue only (no Manim / LaTeX / Cairo stack).
FROM python:3.11-slim-bookworm

ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1 \
    PIP_NO_CACHE_DIR=1 \
    PIP_ROOT_USER_ACTION=ignore \
    NO_ALBUMENTATIONS_UPDATE=1 \
    OMP_NUM_THREADS=1 \
    MKL_NUM_THREADS=1 \
    OPENBLAS_NUM_THREADS=1

WORKDIR /app
ENV PYTHONPATH=/app

RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    cmake \
    pkg-config \
    python3-dev \
    libglib2.0-0 \
    libgomp1 \
    libgl1 \
    libsm6 \
    libxext6 \
    libxrender1 \
    && rm -rf /var/lib/apt/lists/*

COPY requirements.worker-ocr.txt .
RUN pip install --upgrade pip setuptools wheel \
    && pip install -r requirements.worker-ocr.txt

COPY . .

RUN python scripts/prewarm_ocr_worker.py

ENV PORT=7860 \
    CELERY_WORKER_QUEUES=ocr
EXPOSE 7860

ENTRYPOINT []
CMD ["sh", "-c", "exec python3 -u worker_health.py"]
```
README.md ADDED
```markdown
---
title: Math Solver Render Worker
emoji: 👷
colorFrom: blue
colorTo: indigo
sdk: docker
app_port: 7860
---

# Math Solver — Render worker (Manim)

This Space runs **Celery** via `worker_health.py` and consumes **only** queue **`render`** (`render_geometry_video`). The image sets `CELERY_WORKER_QUEUES=render` by default (`Dockerfile.worker`).

**Solve** (orchestrator, agents, OCR-in-request when `OCR_USE_CELERY` is off) runs on the **API** Space, not on this worker.

## OCR offload (separate Space)

Queue **`ocr`** is handled by a **dedicated OCR worker** (`Dockerfile.worker.ocr`, `README_HF_WORKER_OCR.md`, workflow `deploy-worker-ocr.yml`). On the API, set `OCR_USE_CELERY=true` and deploy an OCR Space that listens on `ocr`.

## Secrets

Same broker as the API: `REDIS_URL` / `CELERY_BROKER_URL`, Supabase, OpenRouter (the renderer may use LLM paths), etc.

**GitHub Actions:** repository secrets `HF_TOKEN` and `HF_WORKER_REPO` (`owner/space-name`) for this workflow (`deploy-worker.yml`).
```
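`worker_health.py`, referenced by both the README and the Dockerfile `CMD`, is not among the 50 files shown in this view. A minimal sketch of the queue-selection logic it implies (the `app.celery_app` module path and the exact flags are assumptions, not the repo's actual code):

```python
import os


def celery_worker_argv(env: dict) -> list:
    """Build a Celery worker command line from environment variables.

    Sketch only: the real worker_health.py is not included in this diff.
    """
    queues = env.get("CELERY_WORKER_QUEUES", "render")
    return [
        "celery", "-A", "app.celery_app", "worker",
        "--loglevel=INFO",
        "-Q", queues,  # consume only the configured queue(s)
    ]


if __name__ == "__main__":
    print(celery_worker_argv(dict(os.environ)))
```

This is why the same image can serve both Spaces: only the `CELERY_WORKER_QUEUES` default differs between `Dockerfile` and `Dockerfile.worker.ocr`.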
README_HF_WORKER_OCR.md ADDED
```markdown
---
title: Math Solver OCR Worker
emoji: 👁️
colorFrom: gray
colorTo: blue
sdk: docker
app_port: 7860
---

# Math Solver — OCR-only worker

This Space runs **Celery** (`worker_health.py`) consuming **only** the `ocr` queue.

Set environment:

- `CELERY_WORKER_QUEUES=ocr` (default in `Dockerfile.worker.ocr`)
- Same `REDIS_URL` / `CELERY_BROKER_URL` / `CELERY_RESULT_BACKEND` as the API

This Space runs **raw OCR only** (YOLO, PaddleOCR, Pix2Tex). **OpenRouter / LLM refinement** does not run here; the API Space calls `refine_with_llm` after receiving results from the `ocr` queue.

On the **API** Space, set `OCR_USE_CELERY=true` so `run_ocr_from_url` tasks are sent to this worker instead of running Paddle/Pix2Tex in the API process.

Optional: `OCR_CELERY_TIMEOUT_SEC` (default `180`).

**Manim / video** uses a different Celery queue (`render`) and Space — see `README_HF_WORKER.md` and workflow `deploy-worker.yml`.

GitHub Actions: repository secrets `HF_TOKEN` and `HF_OCR_WORKER_REPO` (`owner/space-name`) enable workflow `deploy-worker-ocr.yml`.
```
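The API-side switch the two READMEs describe reduces to plain environment parsing. A sketch (function names here are illustrative; the repo's actual dispatch lives in `app/ocr_celery.py`, which is imported by the orchestrator below but not shown in this view):

```python
def should_offload_ocr(env: dict) -> bool:
    """True when the API should send OCR jobs to the `ocr` Celery queue."""
    return env.get("OCR_USE_CELERY", "").strip().lower() in {"1", "true", "yes"}


def ocr_timeout_sec(env: dict) -> float:
    """How long the API waits for the Celery OCR result (README default: 180 s)."""
    return float(env.get("OCR_CELERY_TIMEOUT_SEC", "180"))
```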
agents/geometry_agent.py ADDED
```python
import os
import json
import logging
from openai import AsyncOpenAI
from typing import Dict, Any
from dotenv import load_dotenv

load_dotenv()
logger = logging.getLogger(__name__)

from app.url_utils import openai_compatible_api_key, sanitize_env
from app.llm_client import get_llm_client


class GeometryAgent:
    def __init__(self):
        self.llm = get_llm_client()

    async def generate_dsl(self, semantic_data: Dict[str, Any], previous_dsl: str = None) -> str:
        logger.info("==[GeometryAgent] Generating DSL from semantic data==")
        if previous_dsl:
            logger.info(f"[GeometryAgent] Using previous DSL context (len={len(previous_dsl)})")

        system_prompt = """
You are a Geometry DSL Generator. Convert semantic geometry data into a precise Geometry DSL program.

=== MULTI-TURN CONTEXT ===
If a PREVIOUS DSL is provided, your job is to UPDATE or EXTEND it.
1. DO NOT remove existing points unless the user explicitly asks to "redefine" or "move" them.
2. Ensure new segments/points connect correctly to existing ones.
3. Your output should be the ENTIRE updated DSL, not just the changes.

=== DSL COMMANDS ===
POINT(A) — declare a point
POINT(A, x, y, z) — declare a point with explicit coordinates
LENGTH(AB, 5) — distance between A and B is 5 (2D/3D)
ANGLE(A, 90) — interior angle at vertex A is 90° (2D/3D)
PARALLEL(AB, CD) — segment AB is parallel to CD (2D/3D)
PERPENDICULAR(AB, CD) — segment AB is perpendicular to CD (2D/3D)
MIDPOINT(M, AB) — M is the midpoint of segment AB
SECTION(E, A, C, k) — E satisfies vector AE = k * vector AC (k is decimal)
LINE(A, B) — infinite line passing through A and B
RAY(A, B) — ray starting at A and passing through B
CIRCLE(O, 5) — circle with center O and radius 5 (2D)
SPHERE(O, 5) — sphere with center O and radius 5 (3D)
SEGMENT(M, N) — auxiliary segment MN to be drawn
POLYGON_ORDER(A, B, C, D) — the order in which vertices form the polygon boundary
TRIANGLE(ABC) — equilateral/arbitrary triangle
PYRAMID(S_ABCD) — pyramid with apex S and base ABCD
PRISM(ABC_DEF) — triangular prism

=== RULES ===
1. 3D Coordinates: Use POINT(A, x, y, z) if specific coordinates are given in the problem.
2. Space Geometry: For pyramids/prisms, use the specialized commands.
3. Primary Vertices: Always declare the main vertices of the shape (e.g., A, B, C, D) using POINT(X).
4. POLYGON_ORDER: Always emit POLYGON_ORDER(...) for the main shape using ONLY these primary vertices.
5. All Points: EVERY point mentioned (A, B, C, H, M, etc.) MUST be declared with POINT(Name) first.
6. Altitudes/Perpendiculars: For an altitude AH to BC, use POINT(H) + PERPENDICULAR(AH, BC).
7. Format: Output ONLY DSL lines — NO explanation, NO markdown, NO code blocks.

=== SHAPE EXAMPLES ===

--- Case: Square Pyramid S.ABCD with side 10, height 15 ---
PYRAMID(S_ABCD)
POINT(A, 0, 0, 0)
POINT(B, 10, 0, 0)
POINT(C, 10, 10, 0)
POINT(D, 0, 10, 0)
POINT(S)
POINT(O)
SECTION(O, A, C, 0.5)
LENGTH(SO, 15)
PERPENDICULAR(SO, AC)
PERPENDICULAR(SO, AB)
POLYGON_ORDER(A, B, C, D)

--- Case: Right Triangle ABC at A, AB=3, AC=4, altitude AH ---
POLYGON_ORDER(A, B, C)
POINT(A)
POINT(B)
POINT(C)
POINT(H)
LENGTH(AB, 3)
LENGTH(AC, 4)
ANGLE(A, 90)
PERPENDICULAR(AH, BC)
SEGMENT(A, H)

--- Case: Rectangle ABCD with AB=5, AD=10 ---
POLYGON_ORDER(A, B, C, D)
POINT(A)
POINT(B)
POINT(C)
POINT(D)
LENGTH(AB, 5)
LENGTH(AD, 10)
PERPENDICULAR(AB, AD)
PARALLEL(AB, CD)
PARALLEL(AD, BC)

--- Case: Circle with center O radius 7 ---
POINT(O)
CIRCLE(O, 7)
"""

        user_content = f"Semantic Data: {json.dumps(semantic_data, ensure_ascii=False)}"
        if previous_dsl:
            user_content = f"PREVIOUS DSL:\n{previous_dsl}\n\nUPDATE WITH NEW DATA: {json.dumps(semantic_data, ensure_ascii=False)}"

        logger.debug("[GeometryAgent] Calling LLM (Multi-Layer)...")
        content = await self.llm.chat_completions_create(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_content},
            ]
        )
        dsl = content.strip() if content else ""
        logger.info(f"[GeometryAgent] DSL generated ({len(dsl.splitlines())} lines).")
        logger.debug(f"[GeometryAgent] DSL output:\n{dsl}")
        return dsl
```
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Dict, Any
3
+
4
+ logger = logging.getLogger(__name__)
5
+
6
+
7
+ # ─── Shape rule registry ────────────────────────────────────────────────────
8
+ # Each entry: keyword list → augmentation function
9
+ # Augmentation receives (values: dict, text: str) and returns updated values dict.
10
+
11
+ class KnowledgeAgent:
12
+ """Knowledge Agent: Stores geometric theorems and common patterns to augment Parser output."""
13
+
14
+ def augment_semantic_data(self, semantic_data: Dict[str, Any]) -> Dict[str, Any]:
15
+ logger.info("==[KnowledgeAgent] Augmenting semantic data==")
16
+ text = str(semantic_data.get("input_text", "")).lower()
17
+ logger.debug(f"[KnowledgeAgent] Input text for matching: '{text[:200]}'")
18
+
19
+ shape_type = self._detect_shape(text, semantic_data.get("type", ""))
20
+ if shape_type:
21
+ semantic_data["type"] = shape_type
22
+ values = semantic_data.get("values", {})
23
+ values = self._augment_values(shape_type, values, text)
24
+ semantic_data["values"] = values
25
+ else:
26
+ logger.info("[KnowledgeAgent] No special rule matched. Returning data unchanged.")
27
+
28
+ logger.debug(f"[KnowledgeAgent] Output semantic data: {semantic_data}")
29
+ return semantic_data
30
+
31
+ # ─── Shape detection ────────────────────────────────────────────────────
32
+ def _detect_shape(self, text: str, llm_type: str) -> str | None:
33
+ """Detect shape from text keywords. LLM type provides a hint."""
34
+ checks = [
35
+ (["hình vuông", "square"], "square"),
36
+ (["hình chữ nhật", "rectangle"], "rectangle"),
37
+ (["hình thoi", "rhombus"], "rhombus"),
38
+ (["hình bình hành", "parallelogram"], "parallelogram"),
39
+ (["hình thang vuông"], "right_trapezoid"),
40
+ (["hình thang", "trapezoid", "trapezium"], "trapezoid"),
41
+ (["tam giác vuông", "right triangle"], "right_triangle"),
42
+ (["tam giác đều", "equilateral triangle", "equilateral"], "equilateral_triangle"),
43
+ (["tam giác cân", "isosceles"], "isosceles_triangle"),
44
+ (["tam giác", "triangle"], "triangle"),
45
+ (["đường tròn", "circle"], "circle"),
46
+ ]
47
+ for keywords, shape in checks:
48
+ if any(kw in text for kw in keywords):
49
+ logger.info(f"[KnowledgeAgent] Rule MATCH: '{shape}' detected (keyword match).")
50
+ return shape
51
+
52
+ # Fallback: trust LLM-detected type if it's a known type
53
+ known = {
54
+ "rectangle", "square", "rhombus", "parallelogram",
55
+ "trapezoid", "right_trapezoid", "triangle", "right_triangle",
56
+ "equilateral_triangle", "isosceles_triangle", "circle",
57
+ }
58
+ if llm_type in known:
59
+ logger.info(f"[KnowledgeAgent] Using LLM-detected type '{llm_type}'.")
60
+ return llm_type
61
+
62
+ return None
63
+
64
+ # ─── Value augmentation ──────────────────────────────────────────────────
65
+ def _augment_values(self, shape: str, values: dict, text: str) -> dict:
66
+ ab = values.get("AB")
67
+ ad = values.get("AD")
68
+ bc = values.get("BC")
69
+ cd = values.get("CD")
70
+
71
+ if shape == "rectangle":
72
+ if ab and ad:
73
+ values.setdefault("CD", ab)
74
+ values.setdefault("BC", ad)
75
+ values.setdefault("angle_A", 90)
76
+ logger.info(f"[KnowledgeAgent] Rectangle: AB=CD={ab}, AD=BC={ad}, angle_A=90°")
77
+ else:
78
+ values.setdefault("angle_A", 90)
79
+
80
+ elif shape == "square":
81
+ side = ab or ad or bc or cd or values.get("side")
82
+ if side:
83
+ values.update({"AB": side, "AD": side, "angle_A": 90})
84
+ logger.info(f"[KnowledgeAgent] Square: side={side}, angle_A=90°")
85
+ else:
86
+ values.setdefault("angle_A", 90)
87
+
88
+ elif shape == "rhombus":
89
+ side = ab or values.get("side")
90
+ if side:
91
+ values.update({"AB": side, "BC": side, "CD": side, "DA": side})
92
+ logger.info(f"[KnowledgeAgent] Rhombus: all sides={side}")
93
+
94
+ elif shape == "parallelogram":
95
+ if ab:
96
+ values.setdefault("CD", ab)
97
+ if ad:
98
+ values.setdefault("BC", ad)
99
+ logger.info(f"[KnowledgeAgent] Parallelogram: AB||CD, AD||BC")
100
+
101
+ elif shape == "trapezoid":
102
+ logger.info("[KnowledgeAgent] Trapezoid: AB||CD (bottom||top)")
103
+
104
+ elif shape == "right_trapezoid":
105
+ logger.info("[KnowledgeAgent] Right trapezoid: AB||CD, AD⊥AB")
106
+ values.setdefault("angle_A", 90)
107
+
108
+ elif shape == "equilateral_triangle":
109
+ side = ab or values.get("side")
110
+ if side:
111
+ values.update({"AB": side, "BC": side, "CA": side, "angle_A": 60})
112
+ logger.info(f"[KnowledgeAgent] Equilateral triangle: all sides={side}, angle_A=60°")
113
+
114
+ elif shape == "right_triangle":
115
+ # Try to infer which vertex is the right angle
116
+ rt_vertex = _detect_right_angle_vertex(text)
117
+ values.setdefault(f"angle_{rt_vertex}", 90)
118
+ logger.info(f"[KnowledgeAgent] Right triangle: angle_{rt_vertex}=90°")
119
+
120
+ elif shape == "isosceles_triangle":
121
+ logger.info("[KnowledgeAgent] Isosceles triangle: AB=AC (default, LLM may override)")
122
+
123
+ elif shape == "circle":
124
+ logger.info("[KnowledgeAgent] Circle detected — no side augmentation needed.")
125
+
126
+ return values
127
+
128
+
129
+ def _detect_right_angle_vertex(text: str) -> str:
130
+ """Heuristic: detect which vertex is right angle from text."""
131
+ for vertex in ["A", "B", "C", "D"]:
132
+ patterns = [f"vuông tại {vertex}", f"góc {vertex} vuông", f"right angle at {vertex}"]
133
+ if any(p.lower() in text for p in patterns):
134
+ return vertex
135
+ return "A" # default
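One subtlety in `_detect_right_angle_vertex`: the patterns embed uppercase vertex names and are lowercased at match time, so the helper only works because `augment_semantic_data` lowercases the input text first. A standalone copy to illustrate that contract (the phrases "vuông tại X" and "góc X vuông" are Vietnamese for "right angle at X"):

```python
def detect_right_angle_vertex(text: str) -> str:
    """Standalone copy of the heuristic above; `text` is expected to be lowercased."""
    for vertex in ["A", "B", "C", "D"]:
        patterns = [f"vuông tại {vertex}", f"góc {vertex} vuông", f"right angle at {vertex}"]
        # p.lower() turns e.g. "vuông tại B" into "vuông tại b" to match the text
        if any(p.lower() in text for p in patterns):
            return vertex
    return "A"  # default when no vertex is named
```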
agents/ocr_agent.py ADDED
```python
import asyncio
import logging

from vision_ocr.pipeline import OcrVisionPipeline

logger = logging.getLogger(__name__)


class ImprovedOCRAgent:
    """
    API-facing OCR: composes ``OcrVisionPipeline`` (vision only) with optional LLM refinement.
    Celery OCR workers should import ``OcrVisionPipeline`` directly from ``vision_ocr``.
    """

    def __init__(self, skip_llm_refinement: bool = False):
        self._skip_llm_refinement = bool(skip_llm_refinement)
        self._vision = OcrVisionPipeline()
        logger.info(
            "[ImprovedOCRAgent] Vision pipeline ready (skip_llm_refinement=%s)...",
            self._skip_llm_refinement,
        )

        if self._skip_llm_refinement:
            self.llm = None
            logger.info("[ImprovedOCRAgent] LLM client skipped (raw OCR only).")
        else:
            from app.llm_client import get_llm_client

            self.llm = get_llm_client()
            logger.info("[ImprovedOCRAgent] Multi-Layer LLM Client initialized.")

    async def process_image(self, image_path: str) -> str:
        combined_text = await self._vision.process_image(image_path)

        if not combined_text.strip():
            return combined_text

        if self._skip_llm_refinement or self.llm is None:
            logger.info("[ImprovedOCRAgent] Skipping MegaLLM refinement (raw OCR output).")
            return combined_text

        try:
            logger.info("[ImprovedOCRAgent] Sending to MegaLLM for refinement...")
            refined_text = await asyncio.wait_for(
                self.refine_with_llm(combined_text), timeout=30.0
            )
            return refined_text
        except asyncio.TimeoutError:
            logger.error("[ImprovedOCRAgent] MegaLLM refinement timed out.")
            return combined_text
        except Exception as e:
            logger.error("[ImprovedOCRAgent] MegaLLM refinement failed: %s", e)
            return combined_text

    async def refine_with_llm(self, text: str) -> str:
        if not text.strip():
            return ""
        if self.llm is None:
            logger.warning("[ImprovedOCRAgent] refine_with_llm: no LLM client; returning raw text.")
            return text

        # Vietnamese prompt, kept verbatim since the pipeline targets Vietnamese
        # math pages. It asks the LLM to fix Vietnamese spelling, normalize math
        # into $...$ LaTeX, preserve the problem's logical structure, and return
        # the cleaned content as Markdown.
        prompt = f"""Bạn là một chuyên gia số hóa tài liệu toán học.
Dưới đây là kết quả OCR thô từ một trang sách toán Tiếng Việt.
Kết quả này có thể chứa lỗi chính tả, lỗi định dạng mã LaTeX, hoặc bị ngắt quãng không logic.

Nhiệm vụ của bạn:
1. Sửa lỗi chính tả tiếng Việt.
2. Đảm bảo các công thức toán học được viết đúng định dạng LaTeX và nằm trong cặp dấu $...$.
3. Giữ nguyên cấu trúc logic của bài toán.
4. Trả về nội dung đã được làm sạch dưới dạng Markdown.

Nội dung OCR thô:
---
{text}
---

Kết quả làm sạch:"""

        try:
            refined = await self.llm.chat_completions_create(
                messages=[{"role": "user", "content": prompt}],
                temperature=0.1,
            )
            logger.info("[ImprovedOCRAgent] LLM refinement complete.")
            return refined
        except Exception as e:
            logger.error("[ImprovedOCRAgent] LLM refinement failed: %s", e)
            return text

    async def process_url(self, url: str) -> str:
        combined_text = await self._vision.process_url(url)

        if not combined_text.strip() or combined_text.lstrip().startswith("Error:"):
            return combined_text

        if self._skip_llm_refinement or self.llm is None:
            return combined_text

        try:
            return await asyncio.wait_for(self.refine_with_llm(combined_text), timeout=30.0)
        except asyncio.TimeoutError:
            logger.error("[ImprovedOCRAgent] MegaLLM refinement timed out.")
            return combined_text
        except Exception as e:
            logger.error("[ImprovedOCRAgent] MegaLLM refinement failed: %s", e)
            return combined_text


class OCRAgent(ImprovedOCRAgent):
    """Alias for compatibility with existing code."""
```
agents/orchestrator.py ADDED
```python
import json
import logging
from typing import Any, Dict

import anyio

from agents.geometry_agent import GeometryAgent
from agents.knowledge_agent import KnowledgeAgent
from agents.ocr_agent import OCRAgent
from agents.parser_agent import ParserAgent
from agents.solver_agent import SolverAgent
from app.logutil import log_step
from app.ocr_celery import ocr_from_image_url
from solver.dsl_parser import DSLParser
from solver.engine import GeometryEngine

logger = logging.getLogger(__name__)

_CLIP = 2000


def _clip(val: Any, n: int = _CLIP) -> str | None:
    if val is None:
        return None
    if isinstance(val, str):
        s = val
    else:
        s = json.dumps(val, ensure_ascii=False, default=str)
    return s if len(s) <= n else s[:n] + "…"


def _step_io(step: str, input_val: Any = None, output_val: Any = None) -> None:
    """Debug helper: log only the (clipped) input/output, avoiding needlessly long dumps."""
    log_step(step, input=_clip(input_val), output=_clip(output_val))


class Orchestrator:
    def __init__(self):
        self.parser_agent = ParserAgent()
        self.geometry_agent = GeometryAgent()
        self.ocr_agent = OCRAgent()
        self.knowledge_agent = KnowledgeAgent()
        self.solver_agent = SolverAgent()
        self.solver_engine = GeometryEngine()
        self.dsl_parser = DSLParser()

    def _generate_step_description(self, semantic_json: Dict[str, Any], engine_result: Dict[str, Any]) -> str:
        """Build a step-by-step construction description (user-facing, in Vietnamese) from the engine result."""
        analysis = semantic_json.get("analysis", "")
        if not analysis:
            analysis = f"Giải bài toán về {semantic_json.get('type', 'hình học')}."  # "Solve the problem about <shape>."

        steps = ["\n\n**Các bước dựng hình:**"]  # "Construction steps:"
        drawing_phases = engine_result.get("drawing_phases", [])

        for phase in drawing_phases:
            label = phase.get("label", f"Giai đoạn {phase['phase']}")  # "Phase <n>"
            points = ", ".join(phase.get("points", []))
            segments = ", ".join(f"{s[0]}{s[1]}" for s in phase.get("segments", []))

            step_text = f"- **{label}**:"
            if points:
                step_text += f" Xác định các điểm {points}."  # "Locate the points ..."
            if segments:
                step_text += f" Vẽ các đoạn thẳng {segments}."  # "Draw the segments ..."
            steps.append(step_text)

        circles = engine_result.get("circles", [])
        for c in circles:
            # "Circle: draw the circle with center <c> and radius <r>."
            steps.append(f"- **Đường tròn**: Vẽ đường tròn tâm {c['center']} bán kính {c['radius']}.")

        return analysis + "\n".join(steps)

    async def run(
        self,
        text: str,
        image_url: str = None,
        job_id: str = None,
        session_id: str = None,
        status_callback=None,
        history: list = None,
    ) -> Dict[str, Any]:
        """Run the full pipeline. Optional history allows context-aware solving."""
        _step_io(
            "orchestrate_start",
            input_val={
                "job_id": job_id,
                "text_len": len(text or ""),
                "image_url": image_url,
                "history_len": len(history or []),
            },
            output_val=None,
        )

        if status_callback:
            await status_callback("processing")

        # 1. Extract context from history (if any)
        previous_context = None
        if history:
            # Look for the last assistant message with geometry data
            for msg in reversed(history):
                if msg.get("role") == "assistant" and msg.get("metadata", {}).get("geometry_dsl"):
                    previous_context = {
                        "geometry_dsl": msg["metadata"]["geometry_dsl"],
                        "coordinates": msg["metadata"].get("coordinates", {}),
                        "analysis": msg.get("content", ""),
                    }
                    break

        if previous_context:
            _step_io("context_found", input_val=None, output_val={"dsl_len": len(previous_context["geometry_dsl"])})

        # 2. Gather input text (OCR or direct)
        input_text = text
        if image_url:
            input_text = await ocr_from_image_url(image_url, self.ocr_agent)
            _step_io("step1_ocr", input_val=image_url, output_val=input_text)
        else:
            _step_io("step1_ocr", input_val="(no image)", output_val=text)

        feedback = None
        MAX_RETRIES = 2

        for attempt in range(MAX_RETRIES + 1):
            _step_io("attempt", input_val=f"{attempt + 1}/{MAX_RETRIES + 1}", output_val=None)
            if status_callback:
                await status_callback("solving")

            # Parser with context
            _step_io("step2_parse", input_val=f"{input_text[:50]}...", output_val=None)
            semantic_json = await self.parser_agent.process(input_text, feedback=feedback, context=previous_context)
            semantic_json["input_text"] = input_text
            _step_io("step2_parse", input_val=None, output_val=semantic_json)

            # Knowledge augmentation
            _step_io("step3_knowledge", input_val=semantic_json, output_val=None)
            semantic_json = self.knowledge_agent.augment_semantic_data(semantic_json)
            _step_io("step3_knowledge", input_val=None, output_val=semantic_json)

            # Geometry DSL with context (passing previous DSL to guide generation)
            _step_io("step4_geometry_dsl", input_val=semantic_json, output_val=None)
            dsl_code = await self.geometry_agent.generate_dsl(
                semantic_json,
                previous_dsl=previous_context["geometry_dsl"] if previous_context else None,
            )
            _step_io("step4_geometry_dsl", input_val=None, output_val=dsl_code)

            _step_io("step5_dsl_parse", input_val=dsl_code, output_val=None)
            points, constraints, is_3d = self.dsl_parser.parse(dsl_code)
            _step_io(
                "step5_dsl_parse",
                input_val=None,
                output_val={
                    "points": len(points),
                    "constraints": len(constraints),
                    "is_3d": is_3d,
                },
            )

            _step_io("step6_solve", input_val=f"{len(points)} pts / {len(constraints)} cons (is_3d={is_3d})", output_val=None)
            engine_result = await anyio.to_thread.run_sync(self.solver_engine.solve, points, constraints, is_3d)

            if engine_result:
                coordinates = engine_result.get("coordinates")
                _step_io("step6_solve", input_val=None, output_val=coordinates)
                logger.info(
                    "[Orchestrator] geometry solved job_id=%s is_3d=%s n_coords=%d",
                    job_id,
                    is_3d,
                    len(coordinates) if isinstance(coordinates, dict) else 0,
                )
                break

            feedback = (
                "Geometry solver failed to find a valid solution for the given constraints. "
                "Parallelism or lengths might be inconsistent."
            )
            _step_io("step6_solve", input_val=f"attempt {attempt + 1}", output_val=feedback)
            if attempt == MAX_RETRIES:
                _step_io("orchestrate_abort", input_val=None, output_val="solver_exhausted_retries")
                return {
                    "error": "Solver failed after multiple attempts.",
                    "last_dsl": dsl_code,
                }

        _step_io("orchestrate_done", input_val=job_id, output_val="success")

        # 8. Solution calculation (new in v5.1)
        solution = None
        if engine_result:
            _step_io("step8_solve_math", input_val=semantic_json.get("target_question"), output_val=None)
            solution = await self.solver_agent.solve(semantic_json, engine_result)
            _step_io("step8_solve_math", input_val=None, output_val=solution.get("answer"))

        final_analysis = self._generate_step_description(semantic_json, engine_result)

        return {
            "status": "success",
            "job_id": job_id,
            "geometry_dsl": dsl_code,
            "coordinates": coordinates,
            "polygon_order": engine_result.get("polygon_order", []),
            "circles": engine_result.get("circles", []),
            "lines": engine_result.get("lines", []),
            "rays": engine_result.get("rays", []),
            "drawing_phases": engine_result.get("drawing_phases", []),
            "semantic": semantic_json,
            "semantic_analysis": final_analysis,
            "solution": solution,
            "is_3d": is_3d,
        }
```
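The retry loop in `Orchestrator.run` reduces to a small pattern: re-run the parser with textual feedback until the solver succeeds or retries are exhausted. A stripped-down sketch of that control flow (names are illustrative, not the repo's API):

```python
import asyncio


async def solve_with_retries(parse, solve, text: str, max_retries: int = 2):
    """Re-run `parse` with feedback until `solve` succeeds; None when retries run out."""
    feedback = None
    for attempt in range(max_retries + 1):
        parsed = await parse(text, feedback)  # feedback is None on the first attempt
        result = solve(parsed)
        if result:
            return result
        # The failure message becomes the next attempt's parser hint
        feedback = "Solver failed; constraints may be inconsistent."
    return None
```

In the real orchestrator the "parse" step is the ParserAgent-to-GeometryAgent chain and "solve" is `GeometryEngine.solve` run in a worker thread.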
agents/parser_agent.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import logging
4
+ from openai import AsyncOpenAI
5
+ from typing import Dict, Any
6
+ from dotenv import load_dotenv
7
+
8
+ load_dotenv()
9
+ logger = logging.getLogger(__name__)
10
+
11
+ from app.url_utils import openai_compatible_api_key, sanitize_env
12
+
13
+
14
+ from app.llm_client import get_llm_client
15
+
16
+
17
+ class ParserAgent:
18
+ def __init__(self):
19
+ self.llm = get_llm_client()
20
+
21
+ async def process(self, text: str, feedback: str = None, context: Dict[str, Any] = None) -> Dict[str, Any]:
22
+ logger.info(f"==[ParserAgent] Processing input (len={len(text)})==")
23
+ if feedback:
24
+ logger.warning(f"[ParserAgent] Feedback from previous attempt: {feedback}")
25
+ if context:
26
+ logger.info(f"[ParserAgent] Using previous context (dsl_len={len(context.get('geometry_dsl', ''))})")
27
+
28
+ system_prompt = """
29
+ You are a Geometry Parser Agent. Extract geometric entities and constraints from Vietnamese/LaTeX math problem text.
30
+
31
+ === CONTEXT AWARENESS ===
32
+ If previous context is provided, it means this is a follow-up request.
33
+ - Combine old entities with new ones.
34
+ - Update 'analysis' to reflect the entire problem state.
35
+
36
+ Output ONLY a JSON object with this EXACT structure (no extra keys, no markdown):
37
+ {
38
+ "entities": ["Point A", "Point B", ...],
39
+ "type": "pyramid|prism|sphere|rectangle|triangle|circle|parallelogram|trapezoid|square|rhombus|general",
40
+ "values": {"AB": 5, "SO": 15, "radius": 3},
41
+ "target_question": "Câu hỏi cụ thể cần giải (ví dụ: 'Tính diện tích tam giác ABC'). NẾU KHÔNG CÓ CÂU HỎI THÌ ĐỂ null.",
42
+ "analysis": "Tóm tắt ngắn gọn toàn bộ bài toán sau khi đã cập nhật các yêu cầu mới bằng tiếng Việt."
43
+ }
44
+ Rules:
45
+ - "analysis" MUST be a meaningful and UP-TO-DATE summary of the problem in Vietnamese.
46
+ - "target_question" must be concise.
47
+ - Include midpoints, auxiliary points in "entities" if mentioned.
48
+ - If feedback is provided, correct your previous output accordingly.
49
+ """
50
+
51
+ user_content = f"Text: {text}"
52
+ if context:
53
+ user_content = f"PREVIOUS ANALYSIS: {context.get('analysis')}\nNEW REQUEST: {text}"
54
+
55
+ if feedback:
56
+ user_content += f"\nFeedback from previous attempt: {feedback}. Please correct the constraints."
57
+
58
+ logger.debug("[ParserAgent] Calling LLM (Multi-Layer)...")
59
+ raw = await self.llm.chat_completions_create(
60
+ messages=[
61
+ {"role": "system", "content": system_prompt},
62
+ {"role": "user", "content": user_content}
63
+ ],
64
+ response_format={"type": "json_object"}
65
+ )
66
+
67
+ # Pre-process raw string: extract the JSON block if present
68
+ import re
69
+ clean_raw = raw.strip()
70
+ # Handle potential markdown code blocks
71
+ if clean_raw.startswith("```"):
72
+ import re
73
+ match = re.search(r"```(?:json)?\s*(.*?)\s*```", clean_raw, re.DOTALL)
74
+ if match:
75
+ clean_raw = match.group(1).strip()
76
+
77
+ try:
78
+ result = json.loads(clean_raw)
79
+ except json.JSONDecodeError as e:
80
+ logger.error(f"[ParserAgent] JSON Parse Error: {e}. Attempting regex fallback...")
81
+ import re
82
+ json_match = re.search(r'(\{.*\})', clean_raw, re.DOTALL)
83
+ if json_match:
84
+ try:
85
+ # Handle single quotes if present (common LLM failure)
86
+ json_str = json_match.group(1)
87
+ if "'" in json_str and '"' not in json_str:
88
+ json_str = json_str.replace("'", '"')
89
+ result = json.loads(json_str)
90
+ except:
91
+ result = None
92
+ else:
93
+ result = None
94
+
95
+ if not result:
96
+ # Fallback for critical failure
97
+ result = {
98
+ "entities": [],
99
+ "type": "general",
100
+ "values": {},
101
+ "target_question": None,
102
+ "analysis": text
103
+ }
104
+ logger.info(f"[ParserAgent] LLM response received.")
105
+ logger.debug(f"[ParserAgent] Parsed JSON: {json.dumps(result, ensure_ascii=False, indent=2)}")
106
+ return result
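The fence-stripping and regex-fallback chain above can be exercised in isolation. A minimal sketch (the helper name `extract_json_block` is illustrative, not part of the codebase):

```python
import json
import re


def extract_json_block(raw: str):
    """Mimic the parser's fallback chain: strip a ```json fence, then
    fall back to grabbing the outermost {...} and normalizing quotes."""
    clean = raw.strip()
    if clean.startswith("```"):
        m = re.search(r"```(?:json)?\s*(.*?)\s*```", clean, re.DOTALL)
        if m:
            clean = m.group(1).strip()
    try:
        return json.loads(clean)
    except json.JSONDecodeError:
        m = re.search(r"(\{.*\})", clean, re.DOTALL)
        if not m:
            return None
        candidate = m.group(1)
        # Single-quoted pseudo-JSON is a common LLM failure mode
        if "'" in candidate and '"' not in candidate:
            candidate = candidate.replace("'", '"')
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            return None


fenced = "```json\n{\"type\": \"triangle\"}\n```"
assert extract_json_block(fenced) == {"type": "triangle"}
assert extract_json_block("noise {'type': 'circle'} noise") == {"type": "circle"}
```

Note the quote-replacement heuristic is deliberately conservative: it only fires when the candidate contains no double quotes at all, so valid JSON with apostrophes inside string values is never corrupted.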
agents/renderer_agent.py ADDED
@@ -0,0 +1,5 @@
+ """Shim: geometry rendering lives in ``geometry_render`` (worker-safe package)."""
+
+ from geometry_render.renderer import RendererAgent
+
+ __all__ = ["RendererAgent"]
agents/solver_agent.py ADDED
@@ -0,0 +1,107 @@
+ import json
+ import logging
+ import sympy as sp
+ from typing import Dict, Any, List
+ from app.llm_client import get_llm_client
+
+ logger = logging.getLogger(__name__)
+
+ class SolverAgent:
+     def __init__(self):
+         self.llm = get_llm_client()
+
+     async def solve(self, semantic_data: Dict[str, Any], engine_result: Dict[str, Any]) -> Dict[str, Any]:
+         """
+         Solves the geometric problem based on coordinates and the target question.
+         Returns a 'solution' dictionary with answer, steps, and symbolic_expression.
+         """
+         target_question = semantic_data.get("target_question")
+         if not target_question:
+             # If there is no question, return an empty solution structure
+             return {
+                 "answer": None,
+                 "steps": [],
+                 "symbolic_expression": None,
+             }
+
+         logger.info(f"==[SolverAgent] Solving for: '{target_question}'==")
+
+         input_text = semantic_data.get("input_text", "")
+         coordinates = engine_result.get("coordinates", {})
+
+         # Provide the coordinates and semantic context to the LLM to help it reason.
+         # The LLM is tasked with generating the solution structure directly.
+
+         system_prompt = """
+ You are a Geometry Solver Agent. Your goal is to provide a step-by-step solution for a specific geometric question.
+
+ === DATA PROVIDED ===
+ 1. Target Question: The specific question to answer.
+ 2. Geometry Data: Entities and values extracted from the problem.
+ 3. Coordinates: Calculated coordinates for all points.
+
+ === REQUIREMENTS ===
+ - Provide the solution in the SAME LANGUAGE as the user's input.
+ - Use SymPy concepts if appropriate.
+ - Steps should be clear, concise, and logical.
+ - The final answer should be numerically or symbolically accurate based on the coordinates and geometric properties.
+ - For geometric proofs (e.g., "Is AB perpendicular to AC?"), explain the reasoning based on the data.
+
+ Output ONLY a JSON object with this structure:
+ {
+     "answer": "Chuỗi văn bản kết quả cuối cùng (kèm đơn vị nếu có)",
+     "steps": [
+         "Bước 1: ...",
+         "Bước 2: ...",
+         ...
+     ],
+     "symbolic_expression": "Biểu thức toán học rút gọn (LaTeX format optional)"
+ }
+ """
+
+         user_content = f"""
+ INPUT_TEXT: {input_text}
+ TARGET_QUESTION: {target_question}
+ SEMANTIC_DATA: {json.dumps(semantic_data, ensure_ascii=False)}
+ COORDINATES: {json.dumps(coordinates)}
+ """
+
+         logger.debug("[SolverAgent] Requesting solution from LLM...")
+         try:
+             raw = await self.llm.chat_completions_create(
+                 messages=[
+                     {"role": "system", "content": system_prompt},
+                     {"role": "user", "content": user_content},
+                 ],
+                 response_format={"type": "json_object"},
+             )
+
+             clean_raw = raw.strip()
+             # Handle markdown code fences if the response_format wasn't strictly honored
+             if clean_raw.startswith("```"):
+                 import re
+                 match = re.search(r"```(?:json)?\s*(.*?)\s*```", clean_raw, re.DOTALL)
+                 if match:
+                     clean_raw = match.group(1).strip()
+
+             try:
+                 solution = json.loads(clean_raw)
+             except json.JSONDecodeError:
+                 # Last resort: try to find anything between { and }
+                 import re
+                 json_match = re.search(r"(\{.*\})", clean_raw, re.DOTALL)
+                 if json_match:
+                     solution = json.loads(json_match.group(1))
+                 else:
+                     raise
+
+             logger.info("[SolverAgent] Solution generated successfully.")
+             return solution
+         except Exception as e:
+             logger.error(f"[SolverAgent] Error generating solution: {e}")
+             logger.debug(f"[SolverAgent] Raw LLM output was: \n{raw if 'raw' in locals() else 'N/A'}")
+             return {
+                 "answer": "Không thể tính toán lời giải tại thời điểm này.",
+                 "steps": ["Đã xảy ra lỗi trong quá trình xử lý lời giải."],
+                 "symbolic_expression": None,
+             }
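The solver trusts the LLM to compute from the coordinates, but answers of the form "Tính diện tích tam giác ABC" can also be cross-checked deterministically. A dependency-free sketch of such a check using the shoelace formula (the helper name and the validation idea are illustrative; the pipeline itself does not include this step):

```python
def shoelace_area(points):
    """Unsigned polygon area via the shoelace formula.

    points: sequence of (x, y) tuples in vertex order.
    """
    n = len(points)
    s = 0.0
    for i in range(n):
        x1, y1 = points[i]
        x2, y2 = points[(i + 1) % n]  # wrap around to close the polygon
        s += x1 * y2 - x2 * y1
    return abs(s) / 2.0


# Coordinates shaped like engine_result["coordinates"]
coords = {"A": (0, 0), "B": (4, 0), "C": (0, 3)}
area = shoelace_area([coords[k] for k in "ABC"])
assert area == 6.0  # 3-4-5 right triangle
```

An LLM-produced numeric answer could then be compared against this value within a tolerance before being returned to the user.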
agents/torch_ultralytics_compat.py ADDED
@@ -0,0 +1,5 @@
+ """Shim: moved to ``vision_ocr.compat`` for OCR worker isolation."""
+
+ from vision_ocr.compat import allow_ultralytics_weights
+
+ __all__ = ["allow_ultralytics_weights"]
app/chat_image_upload.py ADDED
@@ -0,0 +1,206 @@
+ """Validate and upload chat/solve attachment images to Supabase Storage (image bucket)."""
+
+ from __future__ import annotations
+
+ import logging
+ import os
+ import uuid
+ from typing import Any, Dict, Tuple
+
+ from fastapi import HTTPException
+
+ logger = logging.getLogger(__name__)
+
+
+ def _get_next_image_version(session_id: str) -> int:
+     """Same logic as worker.asset_manager.get_next_version for asset_type image."""
+     from app.supabase_client import get_supabase
+
+     supabase = get_supabase()
+     try:
+         res = (
+             supabase.table("session_assets")
+             .select("version")
+             .eq("session_id", session_id)
+             .eq("asset_type", "image")
+             .order("version", desc=True)
+             .limit(1)
+             .execute()
+         )
+         if res.data:
+             return res.data[0]["version"] + 1
+         return 1
+     except Exception as e:
+         logger.error("Error fetching image version: %s", e)
+         return 1
+
+ _MAX_BYTES_DEFAULT = 10 * 1024 * 1024
+
+ _EXT_TO_MIME: dict[str, str] = {
+     ".png": "image/png",
+     ".jpg": "image/jpeg",
+     ".jpeg": "image/jpeg",
+     ".webp": "image/webp",
+     ".gif": "image/gif",
+     ".bmp": "image/bmp",
+ }
+
+
+ def _max_bytes() -> int:
+     raw = os.getenv("CHAT_IMAGE_MAX_BYTES")
+     if raw and raw.isdigit():
+         return min(int(raw), 50 * 1024 * 1024)
+     return _MAX_BYTES_DEFAULT
+
+
+ def _magic_ok(ext: str, body: bytes) -> bool:
+     if len(body) < 12:
+         return False
+     if ext == ".png":
+         return body.startswith(b"\x89PNG\r\n\x1a\n")
+     if ext in (".jpg", ".jpeg"):
+         return body.startswith(b"\xff\xd8\xff")
+     if ext == ".webp":
+         return body.startswith(b"RIFF") and body[8:12] == b"WEBP"
+     if ext == ".gif":
+         return body.startswith(b"GIF87a") or body.startswith(b"GIF89a")
+     if ext == ".bmp":
+         return body.startswith(b"BM")
+     return False
+
+
+ def validate_chat_image_bytes(
+     filename: str | None,
+     body: bytes,
+     declared_content_type: str | None,
+ ) -> Tuple[str, str]:
+     """
+     Validate size, extension, and magic bytes.
+     Returns (extension_with_dot, content_type).
+     """
+     max_b = _max_bytes()
+     if not body:
+         raise HTTPException(status_code=400, detail="Empty file.")
+     if len(body) > max_b:
+         raise HTTPException(
+             status_code=413,
+             detail=f"Image too large (max {max_b // (1024 * 1024)} MB).",
+         )
+
+     ext = os.path.splitext(filename or "")[1].lower()
+     if not ext:
+         ext = ".png"
+     if ext not in _EXT_TO_MIME:
+         raise HTTPException(
+             status_code=400,
+             detail=f"Unsupported image type: {ext}. Allowed: {', '.join(sorted(_EXT_TO_MIME))}",
+         )
+
+     if not _magic_ok(ext, body):
+         raise HTTPException(
+             status_code=400,
+             detail="File content does not match declared image type.",
+         )
+
+     mime = _EXT_TO_MIME[ext]
+     if declared_content_type:
+         decl = declared_content_type.split(";")[0].strip().lower()
+         if decl and decl not in ("application/octet-stream", mime):
+             logger.warning(
+                 "Content-Type mismatch (declared=%s, inferred=%s); using inferred.",
+                 declared_content_type,
+                 mime,
+             )
+     return ext, mime
+
+
+ def upload_session_chat_image(
+     session_id: str,
+     job_id: str,
+     file_bytes: bytes,
+     ext_with_dot: str,
+     content_type: str,
+ ) -> Dict[str, Any]:
+     """
+     Upload to SUPABASE_IMAGE_BUCKET (default: image), insert session_assets row.
+     Returns dict with public_url, storage_path, version, session_asset_id (if returned).
+     """
+     from app.supabase_client import get_supabase
+
+     supabase = get_supabase()
+     bucket_name = os.getenv("SUPABASE_IMAGE_BUCKET", "image")
+     raw_ext = ext_with_dot.lstrip(".").lower()
+     version = _get_next_image_version(session_id)
+     file_name = f"image_v{version}_{job_id}.{raw_ext}"
+     storage_path = f"sessions/{session_id}/{file_name}"
+
+     supabase.storage.from_(bucket_name).upload(
+         path=storage_path,
+         file=file_bytes,
+         file_options={"content-type": content_type},
+     )
+     public_url = supabase.storage.from_(bucket_name).get_public_url(storage_path)
+     if isinstance(public_url, dict):
+         public_url = public_url.get("publicUrl") or public_url.get("public_url") or str(public_url)
+
+     row = {
+         "session_id": session_id,
+         "job_id": job_id,
+         "asset_type": "image",
+         "storage_path": storage_path,
+         "public_url": public_url,
+         "version": version,
+     }
+     ins = supabase.table("session_assets").insert(row).select("id").execute()
+     asset_id = None
+     if ins.data and len(ins.data) > 0:
+         asset_id = ins.data[0].get("id")
+
+     log_data = {
+         "public_url": public_url,
+         "storage_path": storage_path,
+         "version": version,
+         "session_asset_id": str(asset_id) if asset_id else None,
+     }
+     logger.info("Uploaded chat image: %s", log_data)
+     return {
+         "public_url": public_url,
+         "storage_path": storage_path,
+         "version": version,
+         "session_asset_id": str(asset_id) if asset_id else None,
+     }
+
+
+ def upload_ephemeral_ocr_blob(
+     file_bytes: bytes,
+     ext_with_dot: str,
+     content_type: str,
+ ) -> Tuple[str, str]:
+     """
+     Upload bytes to image bucket under _ocr_temp/ for worker-only OCR (no session_assets row).
+     Returns (storage_path, public_url). Caller must delete_storage_object when done.
+     """
+     from app.supabase_client import get_supabase
+
+     bucket_name = os.getenv("SUPABASE_IMAGE_BUCKET", "image")
+     raw_ext = ext_with_dot.lstrip(".").lower() or "png"
+     name = f"_ocr_temp/{uuid.uuid4().hex}.{raw_ext}"
+     supabase = get_supabase()
+     supabase.storage.from_(bucket_name).upload(
+         path=name,
+         file=file_bytes,
+         file_options={"content-type": content_type},
+     )
+     public_url = supabase.storage.from_(bucket_name).get_public_url(name)
+     if isinstance(public_url, dict):
+         public_url = public_url.get("publicUrl") or public_url.get("public_url") or str(public_url)
+     return name, public_url
+
+
+ def delete_storage_object(bucket_name: str, storage_path: str) -> None:
+     try:
+         from app.supabase_client import get_supabase
+
+         get_supabase().storage.from_(bucket_name).remove([storage_path])
+     except Exception as e:
+         logger.warning("delete_storage_object failed path=%s: %s", storage_path, e)
app/dependencies.py ADDED
@@ -0,0 +1,69 @@
+ from fastapi import HTTPException, Header
+
+ from app.supabase_client import get_supabase, get_supabase_for_user_jwt
+
+
+ async def get_current_user_id(authorization: str | None = Header(None)):
+     """
+     Authenticate user using Supabase JWT.
+     Expected Header: Authorization: Bearer <token>
+     """
+     import os
+
+     if not authorization:
+         raise HTTPException(
+             status_code=401,
+             detail="Authorization header missing or invalid. Use 'Bearer <token>'",
+         )
+
+     if os.getenv("ALLOW_TEST_BYPASS") == "true" and authorization.startswith("Test "):
+         return authorization.split(" ")[1]
+
+     if not authorization.startswith("Bearer "):
+         raise HTTPException(
+             status_code=401,
+             detail="Authorization header missing or invalid. Use 'Bearer <token>'",
+         )
+
+     token = authorization.split(" ")[1]
+     supabase = get_supabase()
+
+     try:
+         user_response = supabase.auth.get_user(token)
+         if not user_response or not user_response.user:
+             raise HTTPException(status_code=401, detail="Invalid session or token.")
+
+         return user_response.user.id
+     except HTTPException:
+         raise
+     except Exception as e:
+         raise HTTPException(status_code=401, detail=f"Authentication failed: {str(e)}")
+
+
+ async def get_authenticated_supabase(authorization: str = Header(...)):
+     """
+     Supabase client that carries the user's JWT (anon key + Authorization header).
+     Use for routes that should respect Row Level Security; pair with app logic as needed.
+     """
+     if not authorization or not authorization.startswith("Bearer "):
+         raise HTTPException(
+             status_code=401,
+             detail="Authorization header missing or invalid. Use 'Bearer <token>'",
+         )
+
+     token = authorization.split(" ")[1]
+     supabase = get_supabase()
+
+     try:
+         user_response = supabase.auth.get_user(token)
+         if not user_response or not user_response.user:
+             raise HTTPException(status_code=401, detail="Invalid session or token.")
+     except HTTPException:
+         raise
+     except Exception as e:
+         raise HTTPException(status_code=401, detail=f"Authentication failed: {str(e)}")
+
+     try:
+         return get_supabase_for_user_jwt(token)
+     except RuntimeError as e:
+         raise HTTPException(status_code=503, detail=str(e))
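Both dependencies share the same header-parsing step before they ever talk to Supabase. A standalone sketch of just that step (the function name is illustrative; it returns `None` instead of raising `HTTPException` so it runs without FastAPI):

```python
def extract_bearer_token(authorization):
    """Return the token from an 'Authorization: Bearer <token>' header, else None."""
    if not authorization or not authorization.startswith("Bearer "):
        return None
    parts = authorization.split(" ")
    # 'Bearer ' with nothing after it yields an empty second part -> reject
    return parts[1] if len(parts) > 1 and parts[1] else None


assert extract_bearer_token("Bearer abc123") == "abc123"
assert extract_bearer_token("Token abc123") is None
assert extract_bearer_token(None) is None
```

In the real dependencies each `None` branch maps to a 401 response, and the extracted token is then verified server-side via `supabase.auth.get_user(token)` rather than decoded locally.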
app/errors.py ADDED
@@ -0,0 +1,59 @@
+ """Map exceptions to short, user-visible messages (avoid leaking HTML bodies from 404 proxies)."""
+
+ from __future__ import annotations
+
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+
+ def _looks_like_html(text: str) -> bool:
+     t = text.lstrip()[:500].lower()
+     return t.startswith("<!doctype") or t.startswith("<html") or "<html" in t[:200]
+
+
+ def format_error_for_user(exc: BaseException) -> str:
+     """
+     Produce a safe message for chat/UI. Full detail stays in server logs via logger.exception.
+     """
+     # httpx: a wrong URL often returns 404 HTML; don't show the body
+     try:
+         import httpx
+
+         if isinstance(exc, httpx.HTTPStatusError):
+             req = exc.request
+             code = exc.response.status_code
+             url_hint = ""
+             try:
+                 url_hint = str(req.url.host) if req and req.url else ""
+             except Exception:
+                 pass
+             logger.warning(
+                 "HTTPStatusError %s for %s (response not shown to user)",
+                 code,
+                 url_hint or "?",
+             )
+             return (
+                 "Kiểm tra URL API, khóa bí mật và biến môi trường (OpenRouter/Supabase/Redis)."
+             )
+
+         if isinstance(exc, httpx.RequestError):
+             return "Không kết nối được tới dịch vụ ngoài (mạng hoặc URL sai)."
+     except ImportError:
+         pass
+
+     raw = str(exc).strip()
+     if not raw:
+         return "Đã xảy ra lỗi không xác định."
+
+     if _looks_like_html(raw):
+         logger.warning("Suppressed HTML error body from user-facing message")
+         return (
+             "Dịch vụ trả về trang lỗi (thường là URL API sai hoặc endpoint không tồn tại — HTTP 404). "
+             "Kiểm tra OPENROUTER_MODEL và khóa API trên server."
+         )
+
+     if len(raw) > 800:
+         return raw[:800] + "…"
+
+     return raw
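The HTML sniff only inspects a bounded prefix of the message, so arbitrarily large error bodies are cheap to classify. The heuristic in isolation (same logic as `_looks_like_html` above):

```python
def looks_like_html(text: str) -> bool:
    """True if the string looks like an HTML error page rather than a plain message."""
    t = text.lstrip()[:500].lower()  # bound the work: never scan more than 500 chars
    return t.startswith("<!doctype") or t.startswith("<html") or "<html" in t[:200]


assert looks_like_html("<!DOCTYPE html><html><body>404 Not Found</body></html>")
assert looks_like_html("  <HTML><head></head>")
assert not looks_like_html('{"error": "bad request"}')
```

The `"<html" in t[:200]` clause catches pages that lead with a comment or a `<head>` fragment before the `<html>` tag, while the 200-character bound avoids false positives on long plain-text messages that merely mention HTML.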
app/job_poll.py ADDED
@@ -0,0 +1,47 @@
+ """Normalize Supabase `jobs` rows for polling / WebSocket clients (stable `job_id` + JSON `result`)."""
+
+ from __future__ import annotations
+
+ import json
+ import logging
+ from typing import Any
+
+ logger = logging.getLogger(__name__)
+
+
+ def _coerce_result(value: Any) -> Any:
+     if value is None:
+         return None
+     if isinstance(value, (dict, list)):
+         return value
+     if isinstance(value, str):
+         try:
+             return json.loads(value)
+         except json.JSONDecodeError:
+             logger.warning("job_poll: result is non-JSON string, returning raw")
+             return {"raw": value}
+     return value
+
+
+ def normalize_job_row_for_client(row: dict[str, Any]) -> dict[str, Any]:
+     """
+     Build a JSON-serializable dict that always includes:
+     - ``job_id`` (alias of DB ``id``) for clients that expect it on poll bodies
+     - ``status`` as str
+     - ``result`` as object/array when stored as JSON string
+     All other columns are passed through (UUID/datetime become JSON-safe via FastAPI encoder).
+     """
+     out = dict(row)
+     jid = out.get("id")
+     if jid is not None:
+         out["job_id"] = str(jid)
+     st = out.get("status")
+     if st is not None:
+         out["status"] = str(st)
+     if "result" in out:
+         out["result"] = _coerce_result(out.get("result"))
+     if out.get("user_id") is not None:
+         out["user_id"] = str(out["user_id"])
+     if out.get("session_id") is not None:
+         out["session_id"] = str(out["session_id"])
+     return out
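The coercion rule is the subtle part: a `result` column may arrive as `None`, a parsed object, or a JSON string, and the client must always see an object. The same logic as `_coerce_result`, demonstrated standalone:

```python
import json


def coerce_result(value):
    """JSON strings become objects; non-JSON strings are wrapped so clients still get a dict."""
    if value is None:
        return None
    if isinstance(value, (dict, list)):
        return value  # already parsed, pass through
    if isinstance(value, str):
        try:
            return json.loads(value)
        except json.JSONDecodeError:
            return {"raw": value}  # never hand the client a bare string
    return value


assert coerce_result('{"answer": 42}') == {"answer": 42}
assert coerce_result("not json") == {"raw": "not json"}
assert coerce_result([1, 2]) == [1, 2]
assert coerce_result(None) is None
```

Wrapping unparseable strings as `{"raw": ...}` keeps the client-side type contract stable: `result` is always `null`, an object, or an array, regardless of what the worker stored.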
app/llm_client.py ADDED
@@ -0,0 +1,100 @@
+ import os
+ import json
+ import asyncio
+ import logging
+ from openai import AsyncOpenAI
+ from typing import List, Dict, Any, Optional
+ from app.url_utils import openai_compatible_api_key, sanitize_env
+
+ logger = logging.getLogger(__name__)
+
+ class MultiLayerLLMClient:
+     def __init__(self):
+         # 1. Load the model fallback sequence
+         self.models = []
+         for i in range(1, 4):
+             model = os.getenv(f"OPENROUTER_MODEL_{i}")
+             if model:
+                 self.models.append(model)
+
+         # Fall back to legacy OPENROUTER_MODEL if no numbered models are found
+         if not self.models:
+             legacy_model = os.getenv("OPENROUTER_MODEL", "google/gemini-2.0-flash-001")
+             self.models = [legacy_model]
+
+         # 2. Key selection (no rotation; always use the first available key)
+         api_key = os.getenv("OPENROUTER_API_KEY_1") or os.getenv("OPENROUTER_API_KEY")
+
+         if not api_key:
+             logger.error("[LLM] No OpenRouter API key found.")
+             self.client = None
+         else:
+             self.client = AsyncOpenAI(
+                 api_key=openai_compatible_api_key(api_key),
+                 base_url="https://openrouter.ai/api/v1",
+                 timeout=60.0,
+                 default_headers={
+                     "HTTP-Referer": "https://mathsolver.ai",
+                     "X-Title": "MathSolver Backend",
+                 },
+             )
+
+     async def chat_completions_create(
+         self,
+         messages: List[Dict[str, str]],
+         response_format: Optional[Dict[str, str]] = None,
+         **kwargs
+     ) -> str:
+         """
+         Implements the model fallback sequence: Model 1 -> Model 2 -> Model 3.
+         Always starts from Model 1 for every new call.
+         """
+         if not self.client:
+             raise ValueError("No API client configured. Check your API keys.")
+
+         MAX_ATTEMPTS = len(self.models)
+         RETRY_DELAY = 1.0  # seconds
+
+         for attempt_idx in range(MAX_ATTEMPTS):
+             current_model = self.models[attempt_idx]
+             attempt_num = attempt_idx + 1
+
+             try:
+                 logger.info(f"[LLM] Attempt {attempt_num}/{MAX_ATTEMPTS} using Model: {current_model}...")
+
+                 response = await self.client.chat.completions.create(
+                     model=current_model,
+                     messages=messages,
+                     response_format=response_format,
+                     **kwargs
+                 )
+
+                 if not response or not getattr(response, "choices", None):
+                     raise ValueError(f"Invalid response structure from model {current_model}")
+
+                 content = response.choices[0].message.content
+                 if content:
+                     logger.info(f"[LLM] SUCCESS on attempt {attempt_num} ({current_model}).")
+                     return content
+
+                 raise ValueError(f"Empty content from model {current_model}")
+
+             except Exception as e:
+                 err_msg = f"{type(e).__name__}: {str(e)}"
+                 logger.warning(f"[LLM] FAILED on attempt {attempt_num} ({current_model}): {err_msg}")
+
+                 if attempt_num < MAX_ATTEMPTS:
+                     logger.info(f"[LLM] Retrying next model in {RETRY_DELAY}s...")
+                     await asyncio.sleep(RETRY_DELAY)
+                 else:
+                     logger.error(f"[LLM] FINAL FAILURE after {attempt_num} models.")
+                     raise
+
+ # Global instance for easy reuse (singleton-ish)
+ _llm_client = None
+
+ def get_llm_client() -> MultiLayerLLMClient:
+     global _llm_client
+     if _llm_client is None:
+         _llm_client = MultiLayerLLMClient()
+     return _llm_client
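The fallback loop can be isolated from the OpenAI client for testing: try each model in order, treat an empty reply as a failure, sleep between attempts, and re-raise the last error only when every model is exhausted. A minimal sketch with a fake call (names are illustrative, not part of the codebase):

```python
import asyncio


async def call_with_fallback(models, call, retry_delay=0.0):
    """Try each model in order; return the first non-empty reply, else re-raise the last error."""
    last_err = None
    for i, model in enumerate(models):
        try:
            content = await call(model)
            if content:
                return content
            raise ValueError(f"Empty content from {model}")
        except Exception as e:
            last_err = e
            if i + 1 < len(models):
                await asyncio.sleep(retry_delay)  # brief pause before the next model
    raise last_err


async def fake_call(model):
    # Simulate the first model returning an empty body and the second succeeding.
    return "" if model == "m1" else f"answer from {model}"


result = asyncio.run(call_with_fallback(["m1", "m2"], fake_call))
assert result == "answer from m2"
```

Because the sequence restarts at model 1 on every call, a transient failure of the primary model never permanently demotes it, unlike a rotating-cursor scheme.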
app/logging_setup.py ADDED
@@ -0,0 +1,112 @@
+ """Logging driven by a single LOG_LEVEL variable: debug | info | warning | error."""
+
+ from __future__ import annotations
+
+ import logging
+ import os
+ from typing import Final
+
+ _SETUP_DONE = False
+
+ PIPELINE_LOGGER_NAME: Final = "app.pipeline"
+ CACHE_LOGGER_NAME: Final = "app.cache"
+ STEPS_LOGGER_NAME: Final = "app.steps"
+ ACCESS_LOGGER_NAME: Final = "app.access"
+
+
+ def _normalize_level() -> str:
+     raw = os.getenv("LOG_LEVEL", "info").strip().lower()
+     if raw in ("debug", "info", "warning", "error"):
+         return raw
+     return "info"
+
+
+ def setup_application_logging() -> None:
+     """Idempotent; call at process startup (uvicorn, celery, worker_health)."""
+     global _SETUP_DONE
+     if _SETUP_DONE:
+         return
+     _SETUP_DONE = True
+
+     mode = _normalize_level()
+
+     level_map = {
+         "debug": logging.DEBUG,
+         "info": logging.INFO,
+         "warning": logging.WARNING,
+         "error": logging.ERROR,
+     }
+     root_level = level_map[mode]
+
+     fmt_named = "%(asctime)s | %(levelname)-8s | %(name)s | %(message)s"
+     fmt_short = "%(asctime)s | %(levelname)-8s | %(message)s"
+
+     logging.basicConfig(
+         level=root_level,
+         format=fmt_named if mode == "debug" else fmt_short,
+         datefmt="%H:%M:%S",
+         force=True,
+     )
+
+     logging.getLogger("httpx").setLevel(logging.WARNING)
+     logging.getLogger("httpcore").setLevel(logging.WARNING)
+     logging.getLogger("openai").setLevel(logging.WARNING)
+     logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
+     logging.getLogger("uvicorn.error").setLevel(logging.INFO)
+     # HTTP/2 stack (httpx/httpcore): with LOG_LEVEL=debug a DEBUG root would flood the
+     # output with hpack frames, which is not needed for app-level debugging.
+     for _name in ("hpack", "h2", "hyperframe", "urllib3"):
+         logging.getLogger(_name).setLevel(logging.WARNING)
+
+     if mode == "debug":
+         logging.getLogger("agents").setLevel(logging.DEBUG)
+         logging.getLogger("solver").setLevel(logging.DEBUG)
+         logging.getLogger("app").setLevel(logging.DEBUG)
+         logging.getLogger(CACHE_LOGGER_NAME).setLevel(logging.DEBUG)
+         logging.getLogger(STEPS_LOGGER_NAME).setLevel(logging.DEBUG)
+         logging.getLogger(PIPELINE_LOGGER_NAME).setLevel(logging.INFO)
+         logging.getLogger(ACCESS_LOGGER_NAME).setLevel(logging.INFO)
+         logging.getLogger("app.main").setLevel(logging.INFO)
+         logging.getLogger("worker").setLevel(logging.INFO)
+     elif mode == "info":
+         # Only HTTP access (app.access) + startup; hide agents/orchestrator/pipeline SUCCESS detail
+         logging.getLogger("agents").setLevel(logging.INFO)
+         logging.getLogger("solver").setLevel(logging.WARNING)
+         logging.getLogger("app").setLevel(logging.INFO)
+         logging.getLogger(CACHE_LOGGER_NAME).setLevel(logging.WARNING)
+         logging.getLogger(STEPS_LOGGER_NAME).setLevel(logging.WARNING)
+         logging.getLogger(PIPELINE_LOGGER_NAME).setLevel(logging.WARNING)
+         logging.getLogger(ACCESS_LOGGER_NAME).setLevel(logging.INFO)
+         logging.getLogger("app.main").setLevel(logging.INFO)
+         logging.getLogger("worker").setLevel(logging.WARNING)
+     elif mode == "warning":
+         logging.getLogger("agents").setLevel(logging.WARNING)
+         logging.getLogger("solver").setLevel(logging.WARNING)
+         logging.getLogger("app.routers").setLevel(logging.WARNING)
+         logging.getLogger(CACHE_LOGGER_NAME).setLevel(logging.WARNING)
+         logging.getLogger(STEPS_LOGGER_NAME).setLevel(logging.WARNING)
+         logging.getLogger(PIPELINE_LOGGER_NAME).setLevel(logging.WARNING)
+         logging.getLogger(ACCESS_LOGGER_NAME).setLevel(logging.WARNING)
+         logging.getLogger("app.main").setLevel(logging.WARNING)
+         logging.getLogger("worker").setLevel(logging.WARNING)
+     else:  # error
+         logging.getLogger("agents").setLevel(logging.ERROR)
+         logging.getLogger("solver").setLevel(logging.ERROR)
+         logging.getLogger("app.routers").setLevel(logging.ERROR)
+         logging.getLogger(CACHE_LOGGER_NAME).setLevel(logging.ERROR)
+         logging.getLogger(STEPS_LOGGER_NAME).setLevel(logging.ERROR)
+         logging.getLogger(PIPELINE_LOGGER_NAME).setLevel(logging.ERROR)
+         logging.getLogger(ACCESS_LOGGER_NAME).setLevel(logging.ERROR)
+         logging.getLogger("app.main").setLevel(logging.ERROR)
+         logging.getLogger("worker").setLevel(logging.ERROR)
+
+     logging.getLogger(__name__).debug(
+         "LOG_LEVEL=%s root=%s", mode, logging.getLevelName(root_level)
+     )
+
+
+ def get_log_level() -> str:
+     return _normalize_level()
+
+
+ def is_debug_level() -> bool:
+     return _normalize_level() == "debug"
app/logutil.py ADDED
@@ -0,0 +1,67 @@
+ """log_step (debug), pipeline (debug); HTTP access logging lives in the middleware."""
+
+ from __future__ import annotations
+
+ import json
+ import logging
+ import os
+ from typing import Any
+
+ from app.logging_setup import PIPELINE_LOGGER_NAME, STEPS_LOGGER_NAME
+
+ _pipeline = logging.getLogger(PIPELINE_LOGGER_NAME)
+ _steps = logging.getLogger(STEPS_LOGGER_NAME)
+
+
+ def is_debug_mode() -> bool:
+     """Per-step detail only when LOG_LEVEL=debug."""
+     return os.getenv("LOG_LEVEL", "info").strip().lower() == "debug"
+
+
+ def _truncate(val: Any, max_len: int = 2000) -> Any:
+     if val is None:
+         return None
+     if isinstance(val, (int, float, bool)):
+         return val
+     s = str(val)
+     if len(s) > max_len:
+         return s[:max_len] + f"... (+{len(s) - max_len} chars)"
+     return s
+
+
+ def log_step(step: str, **fields: Any) -> None:
+     """Only when LOG_LEVEL=debug: DB / cache / orchestrator steps."""
+     if not is_debug_mode():
+         return
+     safe = {k: _truncate(v) for k, v in fields.items()}
+     try:
+         payload = json.dumps(safe, ensure_ascii=False, default=str)
+     except Exception:
+         payload = str(safe)
+     _steps.debug("[step:%s] %s", step, payload)
+
+
+ def log_pipeline_success(operation: str, **fields: Any) -> None:
+     """Shown only at debug (pipeline SUCCESS is not used at info; app.access already covers it)."""
+     if not is_debug_mode():
+         return
+     safe = {k: _truncate(v, 500) for k, v in fields.items()}
+     _pipeline.info(
+         "SUCCESS %s %s",
+         operation,
+         json.dumps(safe, ensure_ascii=False, default=str),
+     )
+
+
+ def log_pipeline_failure(operation: str, error: str | None = None, **fields: Any) -> None:
+     """Pipeline failure: always WARNING so it stays visible at LOG_LEVEL=warning."""
+     if is_debug_mode():
+         safe = {k: _truncate(v, 500) for k, v in fields.items()}
+         _pipeline.warning(
+             "FAIL %s err=%s %s",
+             operation,
+             _truncate(error, 300),
+             json.dumps(safe, ensure_ascii=False, default=str),
+         )
+     else:
+         _pipeline.warning("FAIL %s", operation)
app/main.py ADDED
@@ -0,0 +1,142 @@
+ from __future__ import annotations
+
+ import logging
+ import os
+ import time
+ import uuid
+ import warnings
+
+ from dotenv import load_dotenv
+ from fastapi import Depends, FastAPI, File, HTTPException, UploadFile
+ from fastapi.middleware.cors import CORSMiddleware
+ from starlette.requests import Request
+
+ load_dotenv()
+
+ from app.runtime_env import apply_runtime_env_defaults
+
+ apply_runtime_env_defaults()
+
+ os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"
+ warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")
+ warnings.filterwarnings("ignore", category=UserWarning, module="albumentations")
+
+ from app.logging_setup import ACCESS_LOGGER_NAME, get_log_level, setup_application_logging
+
+ setup_application_logging()
+
+ # Routers (after logging)
+ from app.dependencies import get_current_user_id
+ from app.ocr_local_file import ocr_from_local_image_path
+ from app.routers import auth, sessions, solve
+ from agents.ocr_agent import OCRAgent
+ from app.routers.solve import get_orchestrator
+ from app.job_poll import normalize_job_row_for_client
+ from app.supabase_client import get_supabase
+ from app.websocket_manager import register_websocket_routes
+
+ logger = logging.getLogger("app.main")
+ _access = logging.getLogger(ACCESS_LOGGER_NAME)
+
+ app = FastAPI(title="Visual Math Solver API v5.1")
+
+
+ @app.middleware("http")
+ async def access_log_middleware(request: Request, call_next):
+     """LOG_LEVEL=info/debug: every request; warning: only 4xx/5xx; error: only 4xx/5xx at error level."""
+     start = time.perf_counter()
+     response = await call_next(request)
+     ms = (time.perf_counter() - start) * 1000
+     mode = get_log_level()
+     method = request.method
+     path = request.url.path
+     status = response.status_code
+
+     if mode in ("debug", "info"):
+         _access.info("%s %s -> %s (%.0fms)", method, path, status, ms)
+     elif mode == "warning":
+         if status >= 500:
+             _access.error("%s %s -> %s (%.0fms)", method, path, status, ms)
+         elif status >= 400:
+             _access.warning("%s %s -> %s (%.0fms)", method, path, status, ms)
+     elif mode == "error":
+         if status >= 400:
+             _access.error("%s %s -> %s", method, path, status)
+
+     return response
+
+
+ from worker.celery_app import BROKER_URL
+
+ _broker_tail = BROKER_URL.split("@")[-1] if "@" in BROKER_URL else BROKER_URL
+ if get_log_level() in ("debug", "info"):
+     logger.info("App starting LOG_LEVEL=%s | Redis: %s", get_log_level(), _broker_tail)
+ else:
+     logger.warning(
+         "App starting LOG_LEVEL=%s | Redis: %s", get_log_level(), _broker_tail
+     )
+
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=[
+         "http://localhost:3000",
+         "http://127.0.0.1:3000",
+         "http://localhost:3005",
+     ],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ app.include_router(auth.router)
+ app.include_router(sessions.router)
+ app.include_router(solve.router)
+
+ register_websocket_routes(app)
+
+
+ def get_ocr_agent() -> OCRAgent:
+     """Same OCR instance as the solve pipeline (no duplicate model load)."""
+     return get_orchestrator().ocr_agent
+
+
+ supabase_client = get_supabase()
+
+
+ @app.get("/")
+ def read_root():
+     return {"message": "Visual Math Solver API v5.1 is running", "version": "5.1"}
+
+
+ @app.post("/api/v1/ocr")
+ async def upload_ocr(
+     file: UploadFile = File(...),
+     _user_id=Depends(get_current_user_id),
+ ):
+     """OCR upload: requires an authenticated user."""
+     temp_path = f"temp_{uuid.uuid4()}.png"
+     with open(temp_path, "wb") as buffer:
+         buffer.write(await file.read())
+
+     try:
+         text = await ocr_from_local_image_path(temp_path, file.filename, get_ocr_agent())
+         return {"text": text}
+     finally:
+         if os.path.exists(temp_path):
+             os.remove(temp_path)
+
+
+ @app.get("/api/v1/solve/{job_id}")
+ async def get_job_status(
+     job_id: str,
+     user_id=Depends(get_current_user_id),
+ ):
+     """Retrieve job status (can be used for polling if WS fails). Owner-only."""
+     response = supabase_client.table("jobs").select("*").eq("id", job_id).execute()
+     if not response.data:
+         raise HTTPException(status_code=404, detail="Job not found")
+     job = response.data[0]
+     if job.get("user_id") is not None and str(job["user_id"]) != str(user_id):
+         raise HTTPException(status_code=403, detail="Forbidden: You do not own this job.")
+     # Stable contract for FE poll (job_id alias, parsed result JSON, string UUIDs)
142
+ return normalize_job_row_for_client(job)
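The access-log middleware above maps the `(LOG_LEVEL, status)` pair to a log method. The same policy as a standalone pure function (a hypothetical helper, not part of the codebase, just to make the branching testable):

```python
from typing import Optional


def access_log_level(mode: str, status: int) -> Optional[str]:
    """Mirror of the middleware's policy: which level (if any) to log a request at."""
    if mode in ("debug", "info"):
        return "info"  # every request is logged
    if mode == "warning":
        if status >= 500:
            return "error"
        if status >= 400:
            return "warning"
        return None  # 2xx/3xx stay silent
    if mode == "error":
        return "error" if status >= 400 else None
    return None
```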
app/models/schemas.py ADDED
@@ -0,0 +1,80 @@
+ from pydantic import BaseModel, EmailStr, field_validator
+ from typing import Optional, List, Any, Dict
+ from datetime import datetime
+ import uuid
+
+ from app.url_utils import sanitize_url
+
+ # --- Auth Schemas ---
+ class UserProfile(BaseModel):
+     id: uuid.UUID
+     display_name: Optional[str] = None
+     avatar_url: Optional[str] = None
+     created_at: datetime
+
+ class User(BaseModel):
+     id: uuid.UUID
+     email: EmailStr
+
+ # --- Session Schemas ---
+ class SessionBase(BaseModel):
+     title: str = "Bài toán mới"
+
+ class SessionCreate(SessionBase):
+     pass
+
+ class Session(SessionBase):
+     id: uuid.UUID
+     user_id: uuid.UUID
+     created_at: datetime
+     updated_at: datetime
+
+     class Config:
+         from_attributes = True
+
+ # --- Message Schemas ---
+ class MessageBase(BaseModel):
+     role: str
+     type: str = "text"
+     content: str
+     metadata: Dict[str, Any] = {}
+
+ class MessageCreate(MessageBase):
+     session_id: uuid.UUID
+
+ class Message(MessageBase):
+     id: uuid.UUID
+     session_id: uuid.UUID
+     created_at: datetime
+
+     class Config:
+         from_attributes = True
+
+ # --- Solve Job Schemas ---
+ class SolveRequest(BaseModel):
+     text: str
+     image_url: Optional[str] = None
+
+     @field_validator("image_url", mode="before")
+     @classmethod
+     def _clean_image_url(cls, v):
+         return sanitize_url(v) if v is not None else None
+
+ class SolveResponse(BaseModel):
+     job_id: str
+     status: str
+
+ class RenderVideoRequest(BaseModel):
+     job_id: Optional[str] = None
+
+ class RenderVideoResponse(BaseModel):
+     job_id: str
+     status: str
+
+
+ class OcrPreviewResponse(BaseModel):
+     """Stateless OCR preview before POST .../solve (no DB writes, no job)."""
+
+     ocr_text: str
+     user_message: str = ""
+     combined_draft: str
app/ocr_celery.py ADDED
@@ -0,0 +1,54 @@
+ """Run OCR on a remote worker via Celery (queue `ocr`) when OCR_USE_CELERY is enabled."""
+
+ from __future__ import annotations
+
+ import logging
+ import os
+ from typing import TYPE_CHECKING
+
+ import anyio
+
+ if TYPE_CHECKING:
+     from agents.ocr_agent import OCRAgent
+
+ logger = logging.getLogger(__name__)
+
+
+ def ocr_celery_enabled() -> bool:
+     return os.getenv("OCR_USE_CELERY", "").strip().lower() in ("1", "true", "yes", "on")
+
+
+ def _ocr_timeout_sec() -> float:
+     raw = os.getenv("OCR_CELERY_TIMEOUT_SEC", "180")
+     try:
+         return max(30.0, float(raw))
+     except ValueError:
+         return 180.0
+
+
+ def _run_ocr_celery_sync(image_url: str) -> str:
+     from worker.ocr_tasks import run_ocr_from_url
+
+     async_result = run_ocr_from_url.apply_async(args=[image_url])
+     return async_result.get(timeout=_ocr_timeout_sec())
+
+
+ def _is_ocr_error_response(text: str) -> bool:
+     s = (text or "").lstrip()
+     return s.startswith("Error:")
+
+
+ async def ocr_from_image_url(image_url: str, fallback_agent: "OCRAgent") -> str:
+     """
+     If OCR_USE_CELERY: delegate to Celery task `run_ocr_from_url` (worker queue `ocr`, raw OCR only),
+     then run ``refine_with_llm`` on the API process.
+     Else: use fallback_agent.process_url (in-process full pipeline).
+     """
+     if not ocr_celery_enabled():
+         return await fallback_agent.process_url(image_url)
+     logger.info("OCR_USE_CELERY: delegating OCR to Celery queue=ocr (LLM refine on API)")
+     raw = await anyio.to_thread.run_sync(_run_ocr_celery_sync, image_url)
+     raw = raw if raw is not None else ""
+     if not raw.strip() or _is_ocr_error_response(raw):
+         return raw
+     return await fallback_agent.refine_with_llm(raw)
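The `OCR_CELERY_TIMEOUT_SEC` parsing above clamps to a 30-second floor and falls back to 180 seconds on unparsable input. The same logic as a pure function (the name is illustrative):

```python
def ocr_timeout_sec(raw: str) -> float:
    """Parse a timeout string: at least 30.0, defaulting to 180.0 when unparsable."""
    try:
        return max(30.0, float(raw))
    except ValueError:
        return 180.0
```

The floor guards against an env var like `OCR_CELERY_TIMEOUT_SEC=5` killing every OCR round-trip before the worker can answer.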
app/ocr_local_file.py ADDED
@@ -0,0 +1,43 @@
+ """OCR from a local file path, optionally via Celery worker (upload temp blob first)."""
+
+ from __future__ import annotations
+
+ import logging
+ import os
+ from typing import TYPE_CHECKING
+
+ from app.chat_image_upload import (
+     delete_storage_object,
+     upload_ephemeral_ocr_blob,
+     validate_chat_image_bytes,
+ )
+ from app.ocr_celery import ocr_celery_enabled, ocr_from_image_url
+
+ if TYPE_CHECKING:
+     from agents.ocr_agent import OCRAgent
+
+ logger = logging.getLogger(__name__)
+
+
+ async def ocr_from_local_image_path(
+     local_path: str,
+     original_filename: str | None,
+     fallback_agent: "OCRAgent",
+ ) -> str:
+     """
+     Run OCR on a file on local disk. If OCR_USE_CELERY is enabled, upload the bytes to an
+     ephemeral storage URL and delegate to the worker; otherwise run process_image in-process.
+     """
+     if not ocr_celery_enabled():
+         return await fallback_agent.process_image(local_path)
+
+     with open(local_path, "rb") as f:
+         body = f.read()
+     ext = os.path.splitext(original_filename or local_path)[1].lower() or ".png"
+     _, content_type = validate_chat_image_bytes(original_filename or local_path, body, None)
+     bucket = os.getenv("SUPABASE_IMAGE_BUCKET", "image")
+     path, url = upload_ephemeral_ocr_blob(body, ext, content_type)
+     try:
+         return await ocr_from_image_url(url, fallback_agent)
+     finally:
+         delete_storage_object(bucket, path)
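The extension fallback above (original filename first, else the local path, else `.png`) behaves as follows; a standalone sketch with an illustrative name:

```python
import os


def guess_ext(original_filename, local_path):
    """Pick the upload extension: prefer the original filename, default to .png."""
    return os.path.splitext(original_filename or local_path)[1].lower() or ".png"
```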
app/ocr_text_merge.py ADDED
@@ -0,0 +1,14 @@
+ """Helpers for OCR preview combined draft (no Pydantic email deps)."""
+
+ from __future__ import annotations
+
+ from typing import Optional
+
+
+ def build_combined_ocr_preview_draft(user_message: Optional[str], ocr_text: str) -> str:
+     """Merge user caption and OCR text for the confirm step (user message first, then OCR)."""
+     u = (user_message or "").strip()
+     o = (ocr_text or "").strip()
+     if u and o:
+         return f"{u}\n\n{o}"
+     return u or o
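Expected behaviour of the merge helper (function body copied from above so the example is runnable on its own):

```python
def build_combined_ocr_preview_draft(user_message, ocr_text):
    u = (user_message or "").strip()
    o = (ocr_text or "").strip()
    if u and o:
        return f"{u}\n\n{o}"
    return u or o


# user caption first, OCR text second, separated by a blank line
print(build_combined_ocr_preview_draft("Solve for x:", " x + 1 = 3 "))
```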
app/routers/__init__.py ADDED
@@ -0,0 +1 @@
+ from . import auth, sessions, solve
app/routers/auth.py ADDED
@@ -0,0 +1,23 @@
+ from fastapi import APIRouter, Depends, HTTPException
+ from app.dependencies import get_current_user_id
+ from app.supabase_client import get_supabase
+ from app.models.schemas import UserProfile
+ import uuid
+
+ router = APIRouter(prefix="/api/v1/auth", tags=["Auth"])
+
+ @router.get("/me")
+ async def get_me(user_id=Depends(get_current_user_id)):
+     """Retrieve the current user's profile."""
+     supabase = get_supabase()
+     res = supabase.table("profiles").select("*").eq("id", user_id).execute()
+     if not res.data:
+         raise HTTPException(status_code=404, detail="Profile not found.")
+     return res.data[0]
+
+ @router.patch("/me")
+ async def update_me(data: dict, user_id=Depends(get_current_user_id)):
+     """Update the current user's profile."""
+     supabase = get_supabase()
+     res = supabase.table("profiles").update(data).eq("id", user_id).execute()
+     return res.data[0]
app/routers/sessions.py ADDED
@@ -0,0 +1,184 @@
+ from __future__ import annotations
+
+ import logging
+ import time
+ from typing import List
+
+ from fastapi import APIRouter, Depends, HTTPException
+
+ from app.dependencies import get_current_user_id
+ from app.logutil import log_step
+ from app.session_cache import (
+     get_sessions_list_cached,
+     invalidate_for_user,
+     invalidate_session_owner,
+     session_owned_by_user,
+ )
+ from app.supabase_client import get_supabase
+
+ router = APIRouter(prefix="/api/v1/sessions", tags=["Sessions"])
+ logger = logging.getLogger(__name__)
+
+
+ @router.get("", response_model=List[dict])
+ async def list_sessions(user_id=Depends(get_current_user_id)):
+     """List the user's chat sessions."""
+     supabase = get_supabase()
+     t0 = time.perf_counter()
+
+     def fetch() -> list:
+         res = (
+             supabase.table("sessions")
+             .select("id, user_id, title, created_at, updated_at")
+             .eq("user_id", user_id)
+             .order("updated_at", desc=True)
+             .execute()
+         )
+         log_step("db_select", table="sessions", op="list", user_id=str(user_id))
+         return res.data
+
+     out = get_sessions_list_cached(str(user_id), fetch)
+     logger.info(
+         "sessions.list user=%s count=%d %.1fms",
+         user_id,
+         len(out),
+         (time.perf_counter() - t0) * 1000,
+     )
+     return out
+
+
+ @router.post("", response_model=dict)
+ async def create_session(user_id=Depends(get_current_user_id)):
+     """Create a new chat session."""
+     supabase = get_supabase()
+     t0 = time.perf_counter()
+     res = supabase.table("sessions").insert(
+         {"user_id": user_id, "title": "Bài toán mới"}
+     ).execute()
+     log_step("db_insert", table="sessions", op="create")
+     invalidate_for_user(str(user_id))
+     row = res.data[0]
+     logger.info(
+         "sessions.create user=%s id=%s %.1fms",
+         user_id,
+         row.get("id"),
+         (time.perf_counter() - t0) * 1000,
+     )
+     return row
+
+
+ @router.get("/{session_id}/messages", response_model=List[dict])
+ async def get_session_messages(session_id: str, user_id=Depends(get_current_user_id)):
+     """Get the full message history for a session."""
+     supabase = get_supabase()
+
+     def owns() -> bool:
+         res = (
+             supabase.table("sessions")
+             .select("id")
+             .eq("id", session_id)
+             .eq("user_id", user_id)
+             .execute()
+         )
+         log_step("db_select", table="sessions", op="owner_check", session_id=session_id)
+         return bool(res.data)
+
+     if not session_owned_by_user(session_id, str(user_id), owns):
+         raise HTTPException(
+             status_code=403, detail="Forbidden: You do not own this session."
+         )
+
+     res = (
+         supabase.table("messages")
+         .select("*")
+         .eq("session_id", session_id)
+         .order("created_at", desc=False)
+         .execute()
+     )
+     log_step("db_select", table="messages", op="list", session_id=session_id)
+     return res.data
+
+
+ @router.delete("/{session_id}")
+ async def delete_session(session_id: str, user_id=Depends(get_current_user_id)):
+     """Delete a chat session."""
+     supabase = get_supabase()
+
+     def owns() -> bool:
+         res = (
+             supabase.table("sessions")
+             .select("id")
+             .eq("id", session_id)
+             .eq("user_id", user_id)
+             .execute()
+         )
+         return bool(res.data)
+
+     if not session_owned_by_user(session_id, str(user_id), owns):
+         raise HTTPException(
+             status_code=403, detail="Forbidden: You do not own this session."
+         )
+
+     # jobs.session_id FK must be cleared before the sessions row
+     supabase.table("jobs").delete().eq("session_id", session_id).eq("user_id", user_id).execute()
+     log_step("db_delete", table="jobs", op="by_session", session_id=session_id)
+     supabase.table("messages").delete().eq("session_id", session_id).execute()
+     log_step("db_delete", table="messages", op="by_session", session_id=session_id)
+     res = (
+         supabase.table("sessions")
+         .delete()
+         .eq("id", session_id)
+         .eq("user_id", user_id)
+         .execute()
+     )
+     log_step("db_delete", table="sessions", session_id=session_id)
+     invalidate_for_user(str(user_id))
+     invalidate_session_owner(session_id, str(user_id))
+     return {"status": "ok", "deleted_id": session_id}
+
+
+ @router.patch("/{session_id}/title")
+ async def update_session_title(title: str, session_id: str, user_id=Depends(get_current_user_id)):
+     """Rename a chat session."""
+     supabase = get_supabase()
+     res = (
+         supabase.table("sessions")
+         .update({"title": title})
+         .eq("id", session_id)
+         .eq("user_id", user_id)
+         .execute()
+     )
+     log_step("db_update", table="sessions", op="title", session_id=session_id)
+     invalidate_for_user(str(user_id))
+     return res.data[0]
+
+
+ @router.get("/{session_id}/assets", response_model=List[dict])
+ async def get_session_assets(session_id: str, user_id=Depends(get_current_user_id)):
+     """Get versioned rendered assets (videos) for a session."""
+     supabase = get_supabase()
+
+     def owns() -> bool:
+         res = (
+             supabase.table("sessions")
+             .select("id")
+             .eq("id", session_id)
+             .eq("user_id", user_id)
+             .execute()
+         )
+         return bool(res.data)
+
+     if not session_owned_by_user(session_id, str(user_id), owns):
+         raise HTTPException(
+             status_code=403, detail="Forbidden: You do not own this session."
+         )
+
+     res = (
+         supabase.table("session_assets")
+         .select("*")
+         .eq("session_id", session_id)
+         .order("version", desc=True)
+         .execute()
+     )
+     log_step("db_select", table="session_assets", op="list", session_id=session_id)
+     return res.data
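The owner checks above all follow one pattern: a closure performing the Supabase lookup, memoized through `session_owned_by_user`. The caching pattern in isolation (a toy dict standing in for the TTL cache):

```python
def cached_predicate(cache: dict, key, fetch):
    """Return cache[key] if present; otherwise compute via fetch() and store it."""
    if key in cache:
        return cache[key]
    cache[key] = fetch()
    return cache[key]


calls = []


def owns():
    calls.append(1)  # stands in for the sessions-table ownership query
    return True


cache = {}
cached_predicate(cache, ("sess-1", "user-1"), owns)
cached_predicate(cache, ("sess-1", "user-1"), owns)
# the second call is served from the cache, so owns() runs only once
```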
app/routers/solve.py ADDED
@@ -0,0 +1,410 @@
+ from __future__ import annotations
+
+ import logging
+ import os
+ import uuid
+
+ from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, HTTPException, UploadFile
+
+ from agents.orchestrator import Orchestrator
+ from app.chat_image_upload import upload_session_chat_image, validate_chat_image_bytes
+ from app.ocr_celery import ocr_celery_enabled
+ from app.ocr_local_file import ocr_from_local_image_path
+ from app.dependencies import get_current_user_id
+ from app.errors import format_error_for_user
+ from app.logutil import log_pipeline_failure, log_pipeline_success, log_step
+ from app.models.schemas import (
+     OcrPreviewResponse,
+     RenderVideoRequest,
+     RenderVideoResponse,
+     SolveRequest,
+     SolveResponse,
+ )
+ from app.ocr_text_merge import build_combined_ocr_preview_draft
+ from app.session_cache import invalidate_for_user, session_owned_by_user
+ from app.supabase_client import get_supabase
+
+ logger = logging.getLogger(__name__)
+ router = APIRouter(prefix="/api/v1/sessions", tags=["Solve"])
+
+ # Eager init: all agents and models load at import time (also run in Docker build via scripts/prewarm_models.py).
+ ORCHESTRATOR = Orchestrator()
+
+
+ def get_orchestrator() -> Orchestrator:
+     return ORCHESTRATOR
+
+
+ _OCR_PREVIEW_MAX_BYTES = 10 * 1024 * 1024
+
+
+ def _assert_session_owner(supabase, session_id: str, user_id, uid: str, op: str) -> None:
+     def owns() -> bool:
+         res = (
+             supabase.table("sessions")
+             .select("id")
+             .eq("id", session_id)
+             .eq("user_id", user_id)
+             .execute()
+         )
+         log_step("db_select", table="sessions", op=op, session_id=session_id)
+         return bool(res.data)
+
+     if not session_owned_by_user(session_id, uid, owns):
+         log_pipeline_failure("solve_request", error="forbidden", session_id=session_id)
+         raise HTTPException(
+             status_code=403, detail="Forbidden: You do not own this session."
+         )
+
+
+ def _enqueue_solve_common(
+     supabase,
+     background_tasks: BackgroundTasks,
+     session_id: str,
+     user_id,
+     uid: str,
+     request: SolveRequest,
+     message_metadata: dict,
+     job_id: str,
+ ) -> SolveResponse:
+     """Insert user message and job row, enqueue pipeline; update title on first message."""
+     supabase.table("messages").insert(
+         {
+             "session_id": session_id,
+             "role": "user",
+             "type": "text",
+             "content": request.text,
+             "metadata": message_metadata,
+         }
+     ).execute()
+     log_step("db_insert", table="messages", op="user_message", session_id=session_id)
+
+     supabase.table("jobs").insert(
+         {
+             "id": job_id,
+             "user_id": user_id,
+             "session_id": session_id,
+             "status": "processing",
+             "input_text": request.text,
+         }
+     ).execute()
+     log_step("db_insert", table="jobs", job_id=job_id)
+
+     background_tasks.add_task(process_session_job, job_id, session_id, request, str(user_id))
+
+     title_check = supabase.table("sessions").select("title").eq("id", session_id).execute()
+     if title_check.data and title_check.data[0]["title"] == "Bài toán mới":
+         new_title = request.text[:50] + ("..." if len(request.text) > 50 else "")
+         supabase.table("sessions").update({"title": new_title}).eq("id", session_id).execute()
+         log_step("db_update", table="sessions", op="title_from_first_message")
+         invalidate_for_user(uid)
+
+     log_pipeline_success("solve_accepted", job_id=job_id, session_id=session_id)
+     return SolveResponse(job_id=job_id, status="processing")
+
+
+ @router.post("/{session_id}/ocr_preview", response_model=OcrPreviewResponse)
+ async def ocr_preview(
+     session_id: str,
+     user_id=Depends(get_current_user_id),
+     file: UploadFile = File(...),
+     user_message: str | None = Form(None),
+ ):
+     """
+     Run OCR on an uploaded image and merge with optional user_message into combined_draft.
+     Does not insert messages or start a solve job. After the user confirms, call POST .../solve
+     with text=combined_draft (edited) and omit image_url to avoid double OCR.
+     """
+     supabase = get_supabase()
+     uid = str(user_id)
+     _assert_session_owner(supabase, session_id, user_id, uid, "owner_check_ocr_preview")
+
+     body = await file.read()
+     if len(body) > _OCR_PREVIEW_MAX_BYTES:
+         raise HTTPException(
+             status_code=413,
+             detail=f"Image too large (max {_OCR_PREVIEW_MAX_BYTES // (1024 * 1024)} MB).",
+         )
+     if not body:
+         raise HTTPException(status_code=400, detail="Empty file.")
+
+     if ocr_celery_enabled():
+         validate_chat_image_bytes(file.filename, body, file.content_type)
+
+     suffix = os.path.splitext(file.filename or "")[1].lower()
+     if suffix not in (".png", ".jpg", ".jpeg", ".webp", ".gif", ".bmp", ""):
+         suffix = ".png"
+     temp_path = f"temp_ocr_preview_{uuid.uuid4()}{suffix or '.png'}"
+     try:
+         with open(temp_path, "wb") as f:
+             f.write(body)
+         ocr_text = await ocr_from_local_image_path(
+             temp_path, file.filename, get_orchestrator().ocr_agent
+         )
+         if ocr_text is None:
+             ocr_text = ""
+     finally:
+         if os.path.exists(temp_path):
+             os.remove(temp_path)
+
+     um = (user_message or "").strip()
+     combined = build_combined_ocr_preview_draft(user_message, ocr_text)
+     log_step("ocr_preview_done", session_id=session_id, ocr_len=len(ocr_text), user_len=len(um))
+     return OcrPreviewResponse(
+         ocr_text=ocr_text,
+         user_message=um,
+         combined_draft=combined,
+     )
+
+
+ @router.post("/{session_id}/solve", response_model=SolveResponse)
+ async def solve_problem(
+     session_id: str,
+     request: SolveRequest,
+     background_tasks: BackgroundTasks,
+     user_id=Depends(get_current_user_id),
+ ):
+     """
+     Submit a geometry problem in a session: store the question in history and start
+     the solve pipeline (solving plus static figure only; video is a separate request).
+     """
+     supabase = get_supabase()
+     uid = str(user_id)
+     _assert_session_owner(supabase, session_id, user_id, uid, "owner_check")
+
+     message_metadata = {"image_url": request.image_url} if request.image_url else {}
+     job_id = str(uuid.uuid4())
+     return _enqueue_solve_common(
+         supabase,
+         background_tasks,
+         session_id,
+         user_id,
+         uid,
+         request,
+         message_metadata,
+         job_id,
+     )
+
+
+ @router.post("/{session_id}/solve_multipart", response_model=SolveResponse)
+ async def solve_multipart(
+     session_id: str,
+     background_tasks: BackgroundTasks,
+     user_id=Depends(get_current_user_id),
+     text: str = Form(...),
+     file: UploadFile = File(...),
+ ):
+     """
+     Submit text + an image file in one multipart request: validate, upload to the `image`
+     bucket, record session_assets, store the message with metadata (URL, size, type),
+     then enqueue solve (image_url points at the public URL for the orchestrator's OCR).
+     """
+     supabase = get_supabase()
+     uid = str(user_id)
+     _assert_session_owner(supabase, session_id, user_id, uid, "owner_check_solve_multipart")
+
+     t = (text or "").strip()
+     if not t:
+         raise HTTPException(status_code=400, detail="text must not be empty.")
+
+     body = await file.read()
+     ext, content_type = validate_chat_image_bytes(file.filename, body, file.content_type)
+
+     job_id = str(uuid.uuid4())
+     up = upload_session_chat_image(session_id, job_id, body, ext, content_type)
+     public_url = up["public_url"]
+
+     message_metadata = {
+         "image_url": public_url,
+         "attachment": {
+             "public_url": public_url,
+             "storage_path": up["storage_path"],
+             "size_bytes": len(body),
+             "content_type": content_type,
+             "original_filename": file.filename or "",
+             "session_asset_id": up.get("session_asset_id"),
+         },
+     }
+     request = SolveRequest(text=t, image_url=public_url)
+     return _enqueue_solve_common(
+         supabase,
+         background_tasks,
+         session_id,
+         user_id,
+         uid,
+         request,
+         message_metadata,
+         job_id,
+     )
+
+
+ @router.post("/{session_id}/render_video", response_model=RenderVideoResponse)
+ async def render_video(
+     session_id: str,
+     request: RenderVideoRequest,
+     background_tasks: BackgroundTasks,
+     user_id=Depends(get_current_user_id),
+ ):
+     """
+     Request a Manim video rendered from the session's latest geometry state.
+     """
+     supabase = get_supabase()
+
+     # 1. Verify ownership
+     res = supabase.table("sessions").select("id").eq("id", session_id).eq("user_id", user_id).execute()
+     if not res.data:
+         raise HTTPException(status_code=403, detail="Forbidden: You do not own this session.")
+
+     # 2. Find an assistant message with geometry metadata (matching job_id if given, else the newest of the last 10)
+     msg_res = (
+         supabase.table("messages")
+         .select("metadata")
+         .eq("session_id", session_id)
+         .eq("role", "assistant")
+         .order("created_at", desc=True)
+         .limit(10)
+         .execute()
+     )
+
+     latest_geometry = None
+     if msg_res.data:
+         for msg in msg_res.data:
+             meta = msg.get("metadata", {})
+             # If a specific job_id was requested, it must match
+             if request.job_id and meta.get("job_id") != request.job_id:
+                 continue
+
+             # Geometry data is required
+             if meta.get("geometry_dsl") and meta.get("coordinates"):
+                 latest_geometry = meta
+                 break
+
+     if not latest_geometry:
+         raise HTTPException(status_code=404, detail="Không tìm thấy dữ liệu hình học để render video.")
+
+     # 3. Create the rendering job
+     job_id = str(uuid.uuid4())
+     supabase.table("jobs").insert({
+         "id": job_id,
+         "user_id": user_id,
+         "session_id": session_id,
+         "status": "rendering_queued",
+         "input_text": f"Render video requested at {job_id}",
+     }).execute()
+
+     # 4. Dispatch the background task
+     background_tasks.add_task(process_render_job, job_id, session_id, latest_geometry)
+
+     return RenderVideoResponse(job_id=job_id, status="rendering_queued")
+
+
+ async def process_session_job(
+     job_id: str, session_id: str, request: SolveRequest, user_id: str
+ ):
+     """Background solve pipeline; produces the static figure."""
+     from app.websocket_manager import notify_status
+
+     async def status_update(status: str):
+         await notify_status(job_id, {"status": status, "job_id": job_id})
+
+     supabase = get_supabase()
+     try:
+         history_res = (
+             supabase.table("messages")
+             .select("*")
+             .eq("session_id", session_id)
+             .order("created_at", desc=False)
+             .execute()
+         )
+         history = history_res.data if history_res.data else []
+
+         result = await get_orchestrator().run(
+             request.text,
+             request.image_url,
+             job_id=job_id,
+             session_id=session_id,
+             status_callback=status_update,
+             history=history,
+         )
+
+         status = result.get("status", "error") if "error" not in result else "error"
+
+         supabase.table("jobs").update({"status": status, "result": result}).eq(
+             "id", job_id
+         ).execute()
+
+         supabase.table("messages").insert(
+             {
+                 "session_id": session_id,
+                 "role": "assistant",
+                 "type": "analysis" if "error" not in result else "error",
+                 "content": (
+                     result.get("semantic_analysis", "Đã có lỗi xảy ra.")
+                     if "error" not in result
+                     else result["error"]
+                 ),
+                 "metadata": {
+                     "job_id": job_id,
+                     "coordinates": result.get("coordinates"),
+                     "geometry_dsl": result.get("geometry_dsl"),
+                     "polygon_order": result.get("polygon_order", []),
+                     "drawing_phases": result.get("drawing_phases", []),
+                     "circles": result.get("circles", []),
+                     "lines": result.get("lines", []),
+                     "rays": result.get("rays", []),
+                     "solution": result.get("solution"),
+                     "is_3d": result.get("is_3d", False),
+                 },
+             }
+         ).execute()
+
+         await notify_status(job_id, {"status": status, "job_id": job_id, "result": result})
+
+     except Exception as e:
+         logger.exception("Error processing session job %s", job_id)
+         error_msg = format_error_for_user(e)
+         supabase = get_supabase()
+         supabase.table("jobs").update(
+             {"status": "error", "result": {"error": str(e)}}
+         ).eq("id", job_id).execute()
+         supabase.table("messages").insert(
+             {
+                 "session_id": session_id,
+                 "role": "assistant",
+                 "type": "error",
+                 "content": error_msg,
+                 "metadata": {"job_id": job_id},
+             }
+         ).execute()
+         await notify_status(job_id, {"status": "error", "job_id": job_id, "error": error_msg})
+
+
+ async def process_render_job(job_id: str, session_id: str, geometry_data: dict):
+     """Render a video from previously stored geometry metadata."""
+     from app.websocket_manager import notify_status
+     from worker.tasks import render_geometry_video
+
+     await notify_status(job_id, {"status": "rendering_queued", "job_id": job_id})
+
+     # Prepare the payload for Celery (similar to what the orchestrator used to do)
+     result_payload = {
+         "geometry_dsl": geometry_data.get("geometry_dsl"),
+         "coordinates": geometry_data.get("coordinates"),
+         "polygon_order": geometry_data.get("polygon_order", []),
+         "drawing_phases": geometry_data.get("drawing_phases", []),
+         "circles": geometry_data.get("circles", []),
+         "lines": geometry_data.get("lines", []),
+         "rays": geometry_data.get("rays", []),
+         "semantic": geometry_data.get("semantic", {}),
+         "semantic_analysis": geometry_data.get("semantic_analysis", "🎬 Video minh họa dựng từ trạng thái gần nhất."),
+         "session_id": session_id,
+     }
+
+     try:
+         logger.info("[RenderJob] Attempting to dispatch Celery task for job %s...", job_id)
+         render_geometry_video.delay(job_id, result_payload)
+         logger.info("[RenderJob] SUCCESS: dispatched Celery task for job %s", job_id)
+     except Exception as e:
+         logger.exception("[RenderJob] FAILED to dispatch Celery task: %s", e)
+         supabase = get_supabase()
+         supabase.table("jobs").update({"status": "error", "result": {"error": f"Task dispatch failed: {str(e)}"}}).eq("id", job_id).execute()
+         await notify_status(job_id, {"status": "error", "job_id": job_id, "error": str(e)})
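The first-message title update in `_enqueue_solve_common` truncates to 50 characters and appends an ellipsis only when something was cut. As a standalone function (the name is hypothetical):

```python
def session_title_from_first_message(text: str) -> str:
    """Truncate the first user message into a session title (50 chars + '...')."""
    return text[:50] + ("..." if len(text) > 50 else "")
```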
app/runtime_env.py ADDED
@@ -0,0 +1,12 @@
+ """Default process env vars (Paddle/OpenMP). Call as early as possible after load_dotenv."""
+
+ from __future__ import annotations
+
+ import os
+
+
+ def apply_runtime_env_defaults() -> None:
+     # Paddle reads OMP_NUM_THREADS at import; assign (not setdefault) to beat platform presets of 2+
+     os.environ["OMP_NUM_THREADS"] = "1"
+     os.environ["MKL_NUM_THREADS"] = "1"
+     os.environ["OPENBLAS_NUM_THREADS"] = "1"
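Why hard assignment instead of `setdefault`: if the platform has already exported `OMP_NUM_THREADS`, `setdefault` is a no-op and the preset would leak into Paddle. A quick demonstration:

```python
import os

os.environ["OMP_NUM_THREADS"] = "4"              # simulate a platform preset
os.environ.setdefault("OMP_NUM_THREADS", "1")    # no-op: the key already exists
preset_survives = os.environ["OMP_NUM_THREADS"]

os.environ["OMP_NUM_THREADS"] = "1"              # hard override, as the helper does
after_override = os.environ["OMP_NUM_THREADS"]
```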
app/session_cache.py ADDED
@@ -0,0 +1,48 @@
+ """In-memory TTL cache to cut repeated Supabase queries (session lists, session ownership)."""
+ 
+ from __future__ import annotations
+ 
+ from typing import Any, Callable
+ 
+ from cachetools import TTLCache
+ 
+ from app.logutil import log_step
+ 
+ _session_list: TTLCache[str, list[Any]] = TTLCache(maxsize=512, ttl=45)
+ _session_owner: TTLCache[tuple[str, str], bool] = TTLCache(maxsize=4096, ttl=45)
+ 
+ 
+ def invalidate_for_user(user_id: str) -> None:
+     """Drop the cached session list for a user (after create / delete / rename, or a solve that changes the title)."""
+     _session_list.pop(user_id, None)
+     log_step("cache_invalidate", target="session_list", user_id=user_id)
+ 
+ 
+ def invalidate_session_owner(session_id: str, user_id: str) -> None:
+     _session_owner.pop((session_id, user_id), None)
+     log_step("cache_invalidate", target="session_owner", session_id=session_id, user_id=user_id)
+ 
+ 
+ def get_sessions_list_cached(user_id: str, fetch: Callable[[], list[Any]]) -> list[Any]:
+     if user_id in _session_list:
+         log_step("cache_hit", kind="session_list", user_id=user_id)
+         return _session_list[user_id]
+     log_step("cache_miss", kind="session_list", user_id=user_id)
+     data = fetch()
+     _session_list[user_id] = data
+     return data
+ 
+ 
+ def session_owned_by_user(
+     session_id: str,
+     user_id: str,
+     fetch: Callable[[], bool],
+ ) -> bool:
+     key = (session_id, user_id)
+     if key in _session_owner:
+         log_step("cache_hit", kind="session_owner", session_id=session_id)
+         return _session_owner[key]
+     log_step("cache_miss", kind="session_owner", session_id=session_id)
+     ok = fetch()
+     _session_owner[key] = ok
+     return ok
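The module above is a plain read-through cache. A stdlib-only sketch of the same idea (the `TinyTTLCache` class below is a hypothetical stand-in for `cachetools.TTLCache`, not part of the repo):

```python
import time


class TinyTTLCache:
    """Hypothetical stand-in for cachetools.TTLCache: entries expire after ttl seconds."""

    def __init__(self, ttl: float) -> None:
        self.ttl = ttl
        self._data = {}  # key -> (expires_at, value)

    def __contains__(self, key) -> bool:
        entry = self._data.get(key)
        return entry is not None and entry[0] > time.monotonic()

    def __getitem__(self, key):
        expires_at, value = self._data[key]
        if expires_at <= time.monotonic():
            del self._data[key]
            raise KeyError(key)
        return value

    def __setitem__(self, key, value) -> None:
        self._data[key] = (time.monotonic() + self.ttl, value)


def get_cached(cache, key, fetch):
    """Read-through helper mirroring get_sessions_list_cached: hit -> return, miss -> fetch and store."""
    if key in cache:
        return cache[key]
    value = fetch()
    cache[key] = value
    return value


cache = TinyTTLCache(ttl=45)
calls = []


def fetch():
    calls.append(1)
    return ["session-1"]


assert get_cached(cache, "user-1", fetch) == ["session-1"]  # miss: fetch runs
assert get_cached(cache, "user-1", fetch) == ["session-1"]  # hit: served from cache
assert len(calls) == 1
```

The explicit invalidation helpers exist because a 45-second TTL alone would let a freshly renamed or deleted session linger in the list until expiry.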
app/supabase_client.py ADDED
@@ -0,0 +1,37 @@
+ import os
+ from supabase import Client, ClientOptions, create_client
+ from supabase_auth import SyncMemoryStorage
+ from dotenv import load_dotenv
+ 
+ load_dotenv()
+ 
+ from app.url_utils import sanitize_env
+ 
+ 
+ def get_supabase() -> Client:
+     """Service-role client for server-side operations (bypasses RLS when policies expect service role)."""
+     url = sanitize_env(os.getenv("SUPABASE_URL"))
+     key = sanitize_env(os.getenv("SUPABASE_SERVICE_ROLE_KEY") or os.getenv("SUPABASE_KEY"))
+     if not url or not key:
+         raise RuntimeError(
+             "SUPABASE_URL and SUPABASE_SERVICE_ROLE_KEY (or SUPABASE_KEY) must be set"
+         )
+     return create_client(url, key)
+ 
+ 
+ def get_supabase_for_user_jwt(access_token: str) -> Client:
+     """
+     Client scoped to the logged-in user: PostgREST sends the user's JWT so RLS applies.
+     Use SUPABASE_ANON_KEY (publishable), not the service role key.
+     """
+     url = sanitize_env(os.getenv("SUPABASE_URL"))
+     anon = sanitize_env(os.getenv("SUPABASE_ANON_KEY") or os.getenv("NEXT_PUBLIC_SUPABASE_ANON_KEY"))
+     if not url or not anon:
+         raise RuntimeError(
+             "SUPABASE_URL and SUPABASE_ANON_KEY (or NEXT_PUBLIC_SUPABASE_ANON_KEY) must be set "
+             "for user-scoped Supabase access"
+         )
+     base_opts = ClientOptions(storage=SyncMemoryStorage())
+     merged_headers = {**dict(base_opts.headers), "Authorization": f"Bearer {access_token}"}
+     opts = ClientOptions(storage=SyncMemoryStorage(), headers=merged_headers)
+     return create_client(url, anon, opts)
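The `merged_headers` step above only layers the user's JWT over the client's default headers. As a plain-dict sketch (the header values here are illustrative placeholders, not the real supabase client defaults):

```python
# Illustrative defaults; the real values come from ClientOptions in the supabase SDK.
base_headers = {"apikey": "anon-key-placeholder", "X-Client-Info": "supabase-py"}
access_token = "user-jwt-placeholder"

# Same shape as merged_headers above: dict-unpack the defaults, then add/override Authorization.
merged = {**base_headers, "Authorization": f"Bearer {access_token}"}

print(merged["Authorization"])  # Bearer user-jwt-placeholder
```

Because `**base_headers` comes first, a later `Authorization` key always wins, so every PostgREST request carries the user's token and RLS evaluates as that user.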
app/url_utils.py ADDED
@@ -0,0 +1,23 @@
+ """Normalize URLs / env strings (HF secrets and copy-paste often include trailing newlines)."""
+ 
+ 
+ def sanitize_url(value: str | None) -> str | None:
+     if value is None:
+         return None
+     s = value.strip().replace("\r", "").replace("\n", "").replace("\t", "")
+     return s or None
+ 
+ 
+ def sanitize_env(value: str | None) -> str | None:
+     """Strip whitespace and line breaks from environment-backed strings."""
+     return sanitize_url(value)
+ 
+ 
+ # OpenAI SDK (>=1.x) requires a non-empty api_key at client construction (Docker build / prewarm has no secrets).
+ _OPENAI_API_KEY_BUILD_PLACEHOLDER = "build-placeholder-openrouter-not-for-production"
+ 
+ 
+ def openai_compatible_api_key(raw: str | None) -> str:
+     """Return the sanitized API key, or a placeholder so AsyncOpenAI() can be constructed without env at build time."""
+     k = sanitize_env(raw)
+     return k if k else _OPENAI_API_KEY_BUILD_PLACEHOLDER
app/websocket_manager.py ADDED
@@ -0,0 +1,40 @@
+ """WebSocket connection registry and job status notifications (avoid circular imports with main)."""
+ 
+ from __future__ import annotations
+ 
+ import logging
+ from typing import Dict, List
+ 
+ from fastapi import WebSocket, WebSocketDisconnect
+ 
+ logger = logging.getLogger(__name__)
+ 
+ active_connections: Dict[str, List[WebSocket]] = {}
+ 
+ 
+ async def notify_status(job_id: str, data: dict) -> None:
+     if job_id not in active_connections:
+         return
+     for connection in list(active_connections[job_id]):
+         try:
+             await connection.send_json(data)
+         except Exception as e:
+             logger.error("WS error sending to %s: %s", job_id, e)
+ 
+ 
+ def register_websocket_routes(app) -> None:
+     """Attach the websocket endpoint to the FastAPI app."""
+ 
+     @app.websocket("/ws/{job_id}")
+     async def websocket_endpoint(websocket: WebSocket, job_id: str) -> None:
+         await websocket.accept()
+         if job_id not in active_connections:
+             active_connections[job_id] = []
+         active_connections[job_id].append(websocket)
+         try:
+             while True:
+                 await websocket.receive_text()
+         except WebSocketDisconnect:
+             active_connections[job_id].remove(websocket)
+             if not active_connections[job_id]:
+                 del active_connections[job_id]
clean_ports.sh ADDED
@@ -0,0 +1,22 @@
+ #!/bin/bash
+ # Kill all project-related processes for a clean restart
+ 
+ echo "🧹 Cleaning up project processes..."
+ 
+ # Kill anything on ports 8000 (backend), 3000 (frontend) and 11020
+ PORTS="8000 3000 11020"
+ for PORT in $PORTS; do
+     PIDS=$(lsof -ti :$PORT)
+     if [ -n "$PIDS" ]; then
+         echo "Killing processes on port $PORT: $PIDS"
+         kill -9 $PIDS 2>/dev/null
+     fi
+ done
+ 
+ # Kill by process name
+ echo "Killing any remaining Celery, Uvicorn, or Manim processes..."
+ pkill -9 -f "celery" 2>/dev/null
+ pkill -9 -f "uvicorn" 2>/dev/null
+ pkill -9 -f "manim" 2>/dev/null
+ 
+ echo "✅ Done. You can now restart your Backend, Worker, and Frontend."
dump.rdb ADDED
Binary file (5.44 kB). View file
 
geometry_render/__init__.py ADDED
@@ -0,0 +1,5 @@
+ """Manim geometry script generation and rendering (worker-safe, no LLM agents)."""
+ 
+ from .renderer import RendererAgent
+ 
+ __all__ = ["RendererAgent"]
geometry_render/renderer.py ADDED
@@ -0,0 +1,265 @@
+ import os
+ import subprocess
+ import glob
+ import string
+ import logging
+ from typing import Dict, Any, List
+ 
+ logger = logging.getLogger(__name__)
+ 
+ 
+ class RendererAgent:
+     """
+     Renderer — generates Manim scripts from geometry data.
+ 
+     Drawing happens in phases:
+         Phase 1: Main polygon (base shape with correct vertex order)
+         Phase 2: Auxiliary points and segments (midpoints, derived segments)
+         Phase 3: Labels for all points
+     """
+ 
+     def generate_manim_script(self, data: Dict[str, Any]) -> str:
+         coords: Dict[str, List[float]] = data.get("coordinates", {})
+         polygon_order: List[str] = data.get("polygon_order", [])
+         circles_meta: List[Dict] = data.get("circles", [])
+         lines_meta: List[List[str]] = data.get("lines", [])
+         rays_meta: List[List[str]] = data.get("rays", [])
+         drawing_phases: List[Dict] = data.get("drawing_phases", [])
+         semantic: Dict[str, Any] = data.get("semantic", {})
+         shape_type = semantic.get("type", "").lower()
+ 
+         # ── Detect 3D context ────────────────────────────────────────────────
+         is_3d = False
+         for pos in coords.values():
+             if len(pos) >= 3 and abs(pos[2]) > 0.001:
+                 is_3d = True
+                 break
+         if shape_type in ["pyramid", "prism", "sphere"]:
+             is_3d = True
+ 
+         # ── Fallback: infer polygon_order from coords keys (alphabetical uppercase) ──
+         if not polygon_order:
+             polygon_order = sorted(
+                 [pid for pid in coords if pid in string.ascii_uppercase],
+                 key=lambda p: string.ascii_uppercase.index(p),
+             )
+ 
+         # Separate base points from derived ones (multi-char or lowercase ids)
+         base_ids = [pid for pid in polygon_order if pid in coords]
+         derived_ids = [pid for pid in coords if pid not in polygon_order]
+ 
+         scene_base = "ThreeDScene" if is_3d else "MovingCameraScene"
+         lines = [
+             "from manim import *",
+             "",
+             f"class GeometryScene({scene_base}):",
+             "    def construct(self):",
+         ]
+ 
+         if is_3d:
+             lines.append("        # 3D setup")
+             lines.append("        self.set_camera_orientation(phi=75*DEGREES, theta=-45*DEGREES)")
+             lines.append("        axes = ThreeDAxes(axis_config={'stroke_width': 1})")
+             lines.append("        axes.set_opacity(0.3)")
+             lines.append("        self.add(axes)")
+             lines.append("        self.begin_ambient_camera_rotation(rate=0.1)")
+             lines.append("")
+ 
+         # ── Declare all dots and labels ──────────────────────────────────────
+         for pid, pos in coords.items():
+             x, y, z = 0, 0, 0
+             if len(pos) >= 1: x = round(pos[0], 4)
+             if len(pos) >= 2: y = round(pos[1], 4)
+             if len(pos) >= 3: z = round(pos[2], 4)
+ 
+             dot_class = "Dot3D" if is_3d else "Dot"
+             lines.append(f"        p_{pid} = {dot_class}(point=[{x}, {y}, {z}], color=WHITE, radius=0.08)")
+ 
+             if is_3d:
+                 lines.append(
+                     f"        l_{pid} = Text('{pid}', font_size=20, color=WHITE)"
+                     f".move_to(p_{pid}.get_center() + [0.2, 0.2, 0.2])"
+                 )
+                 # Ensure labels keep facing the camera in 3D (fixed orientation)
+                 lines.append(f"        self.add_fixed_orientation_mobjects(l_{pid})")
+             else:
+                 lines.append(
+                     f"        l_{pid} = Text('{pid}', font_size=22, color=WHITE)"
+                     f".next_to(p_{pid}, UR, buff=0.15)"
+                 )
+ 
+         # ── 3D shape special case: pyramid faces ─────────────────────────────
+         if is_3d and shape_type == "pyramid" and len(base_ids) >= 3:
+             # Find the apex (usually 'S')
+             apex_id = "S" if "S" in coords else (derived_ids[0] if derived_ids else None)
+             if apex_id:
+                 # Draw the base face
+                 base_pts = ", ".join([f"p_{pid}.get_center()" for pid in base_ids])
+                 lines.append(f"        base_face = Polygon({base_pts}, color=BLUE, fill_opacity=0.1)")
+                 lines.append("        self.play(Create(base_face), run_time=1.0)")
+ 
+                 # Draw the side faces
+                 for i in range(len(base_ids)):
+                     p1 = base_ids[i]
+                     p2 = base_ids[(i + 1) % len(base_ids)]
+                     face_pts = f"p_{apex_id}.get_center(), p_{p1}.get_center(), p_{p2}.get_center()"
+                     lines.append(
+                         f"        side_{i} = Polygon({face_pts}, color=BLUE, stroke_width=1, fill_opacity=0.05)"
+                     )
+                     lines.append(f"        self.play(Create(side_{i}), run_time=0.5)")
+ 
+         # ── Circles ──────────────────────────────────────────────────────────
+         # Track which indices actually produced a mobject, so the later phases
+         # only animate circles/lines/rays that exist in the generated script.
+         created_circles: List[int] = []
+         for i, c in enumerate(circles_meta):
+             center = c["center"]
+             r = c["radius"]
+             if center in coords:
+                 cx, cy, cz = 0, 0, 0
+                 pos = coords[center]
+                 if len(pos) >= 1: cx = round(pos[0], 4)
+                 if len(pos) >= 2: cy = round(pos[1], 4)
+                 if len(pos) >= 3: cz = round(pos[2], 4)
+                 lines.append(
+                     f"        circle_{i} = Circle(radius={r}, color=BLUE)"
+                     f".move_to([{cx}, {cy}, {cz}])"
+                 )
+                 created_circles.append(i)
+ 
+         # ── Infinite lines & rays ────────────────────────────────────────────
+         # (A standard Line works for 3D coordinates in Manim)
+         created_lines: List[int] = []
+         for i, (p1, p2) in enumerate(lines_meta):
+             if p1 in coords and p2 in coords:
+                 lines.append(
+                     f"        line_ext_{i} = Line(p_{p1}.get_center(), p_{p2}.get_center(), color=GRAY_D, stroke_width=2)"
+                     f".scale(20)"
+                 )
+                 created_lines.append(i)
+ 
+         created_rays: List[int] = []
+         for i, (p1, p2) in enumerate(rays_meta):
+             if p1 in coords and p2 in coords:
+                 lines.append(
+                     f"        ray_{i} = Line(p_{p1}.get_center(), p_{p1}.get_center() + 15 * (p_{p2}.get_center() - p_{p1}.get_center()),"
+                     f" color=GRAY_C, stroke_width=2)"
+                 )
+                 created_rays.append(i)
+ 
+         # ── Camera auto-fit group (2D only) ──────────────────────────────────
+         if not is_3d:
+             all_dot_names = [f"p_{pid}" for pid in coords]
+             all_names_str = ", ".join(all_dot_names)
+             lines.append(f"        _all = VGroup({all_names_str})")
+             lines.append("        self.camera.frame.set_width(max(_all.width * 2.0, 8))")
+             lines.append("        self.camera.frame.move_to(_all)")
+             lines.append("")
+ 
+         # ── Phase 1: Base polygon ────────────────────────────────────────────
+         if len(base_ids) >= 3:
+             pts_str = ", ".join([f"p_{pid}.get_center()" for pid in base_ids])
+             lines.append(f"        poly = Polygon({pts_str}, color=BLUE, fill_color=BLUE, fill_opacity=0.15)")
+             lines.append("        self.play(Create(poly), run_time=1.5)")
+         elif len(base_ids) == 2:
+             p1, p2 = base_ids
+             lines.append(f"        base_line = Line(p_{p1}.get_center(), p_{p2}.get_center(), color=BLUE)")
+             lines.append("        self.play(Create(base_line), run_time=1.0)")
+ 
+         # Draw the base points
+         if base_ids:
+             base_dots_str = ", ".join([f"p_{pid}" for pid in base_ids])
+             lines.append(f"        self.play(FadeIn(VGroup({base_dots_str})), run_time=0.5)")
+             lines.append("        self.wait(0.5)")
+ 
+         # ── Phase 2: Auxiliary points and segments ───────────────────────────
+         if derived_ids:
+             derived_dots_str = ", ".join([f"p_{pid}" for pid in derived_ids])
+             lines.append(f"        self.play(FadeIn(VGroup({derived_dots_str})), run_time=0.8)")
+ 
+         # Segments from drawing_phases
+         segment_lines = []
+         for phase in drawing_phases:
+             if phase.get("phase") == 2:
+                 for seg in phase.get("segments", []):
+                     if len(seg) == 2 and seg[0] in coords and seg[1] in coords:
+                         p1, p2 = seg[0], seg[1]
+                         seg_var = f"seg_{p1}_{p2}"
+                         lines.append(
+                             f"        {seg_var} = Line(p_{p1}.get_center(), p_{p2}.get_center(),"
+                             f" color=YELLOW)"
+                         )
+                         segment_lines.append(seg_var)
+ 
+         if segment_lines:
+             segs_str = ", ".join([f"Create({sv})" for sv in segment_lines])
+             lines.append(f"        self.play({segs_str}, run_time=1.2)")
+ 
+         if derived_ids or segment_lines:
+             lines.append("        self.wait(0.5)")
+ 
+         # ── Phase 3: All labels ──────────────────────────────────────────────
+         all_labels_str = ", ".join([f"l_{pid}" for pid in coords])
+         lines.append(f"        self.play(FadeIn(VGroup({all_labels_str})), run_time=0.8)")
+ 
+         # ── Circles phase ────────────────────────────────────────────────────
+         for i in created_circles:
+             lines.append(f"        self.play(Create(circle_{i}), run_time=1.5)")
+ 
+         # ── Lines & rays phase ───────────────────────────────────────────────
+         if created_lines or created_rays:
+             lr_anims = [f"Create(line_ext_{i})" for i in created_lines]
+             lr_anims += [f"Create(ray_{i})" for i in created_rays]
+             lines.append(f"        self.play({', '.join(lr_anims)}, run_time=1.5)")
+ 
+         lines.append("        self.wait(2)")
+ 
+         return "\n".join(lines)
+ 
+     def run_manim(self, script_content: str, job_id: str) -> str:
+         script_file = f"{job_id}.py"
+         with open(script_file, "w") as f:
+             f.write(script_content)
+ 
+         try:
+             if os.getenv("MOCK_VIDEO") == "true":
+                 logger.info(f"MOCK_VIDEO is true. Skipping Manim for job {job_id}")
+                 # Create a dummy file so callers get a path that exists
+                 dummy_path = f"videos/{job_id}.mp4"
+                 os.makedirs("videos", exist_ok=True)
+                 with open(dummy_path, "wb") as f:
+                     f.write(b"dummy video content")
+                 return dummy_path
+ 
+             # Determine the manim executable path
+             manim_exe = "manim"
+             venv_manim = os.path.join(os.getcwd(), "venv", "bin", "manim")
+             if os.path.exists(venv_manim):
+                 manim_exe = venv_manim
+ 
+             # Prepare the environment with Homebrew paths (macOS dev machines)
+             custom_env = os.environ.copy()
+             brew_path = "/opt/homebrew/bin:/usr/local/bin"
+             custom_env["PATH"] = f"{brew_path}:{custom_env.get('PATH', '')}"
+ 
+             logger.info(f"Running {manim_exe} for job {job_id}...")
+             result = subprocess.run(
+                 [manim_exe, "-ql", "--media_dir", ".", "-o", f"{job_id}.mp4", script_file, "GeometryScene"],
+                 capture_output=True,
+                 text=True,
+                 env=custom_env,
+             )
+             logger.info(f"Manim STDOUT: {result.stdout}")
+             if result.returncode != 0:
+                 logger.error(f"Manim STDERR: {result.stderr}")
+ 
+             for pattern in [f"**/videos/**/{job_id}.mp4", f"**/{job_id}*.mp4"]:
+                 found = glob.glob(pattern, recursive=True)
+                 if found:
+                     logger.info(f"Manim Success: Found {found[0]}")
+                     return found[0]
+ 
+             logger.error(f"Manim file not found for job {job_id}. Return code: {result.returncode}")
+             return ""
+         except Exception as e:
+             logger.exception(f"Manim Execution Error: {e}")
+             return ""
+         finally:
+             if os.path.exists(script_file):
+                 os.remove(script_file)
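The `polygon_order` fallback in `generate_manim_script` keys off single-letter uppercase ids. In isolation, that step behaves like this (the sample coordinates are made up for illustration):

```python
import string

# Hypothetical sample input in the same shape as data["coordinates"].
coords = {"B": [1.0, 0.0], "M": [0.5, 0.5], "A": [0.0, 0.0], "C": [0.0, 1.0], "M1": [2.0, 2.0]}

# Single uppercase letters become base vertices, sorted alphabetically.
polygon_order = sorted(
    [pid for pid in coords if pid in string.ascii_uppercase],
    key=lambda p: string.ascii_uppercase.index(p),
)
# Everything else (multi-character ids like "M1") is treated as a derived point.
derived_ids = [pid for pid in coords if pid not in polygon_order]

print(polygon_order)  # ['A', 'B', 'C', 'M']
print(derived_ids)    # ['M1']
```

Note that `"M1" in string.ascii_uppercase` is a substring test, which is why multi-character ids fall through to `derived_ids`.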
migrations/add_image_bucket_storage.sql ADDED
@@ -0,0 +1,35 @@
+ -- ============================================================
+ -- MathSolver: Supabase Storage bucket `image` (chat / OCR attachments)
+ -- Run after session_assets and storage.video policies exist.
+ -- ============================================================
+ 
+ INSERT INTO storage.buckets (id, name, public)
+ VALUES ('image', 'image', true)
+ ON CONFLICT (id) DO UPDATE SET public = true;
+ 
+ -- Service role: upload/delete/list for API + workers
+ DROP POLICY IF EXISTS "Service Role manage images" ON storage.objects;
+ CREATE POLICY "Service Role manage images" ON storage.objects
+     FOR ALL
+     TO service_role
+     USING (bucket_id = 'image')
+     WITH CHECK (bucket_id = 'image');
+ 
+ -- Authenticated: read only objects under sessions they own (path sessions/{session_id}/...)
+ DROP POLICY IF EXISTS "Users view session images" ON storage.objects;
+ CREATE POLICY "Users view session images" ON storage.objects
+     FOR SELECT
+     TO authenticated
+     USING (
+         bucket_id = 'image'
+         AND (storage.foldername(name))[2] IN (
+             SELECT id::text FROM public.sessions WHERE user_id = auth.uid()
+         )
+     );
+ 
+ -- Public read for get_public_url / FE img tags (same model as video bucket)
+ DROP POLICY IF EXISTS "Public read images" ON storage.objects;
+ CREATE POLICY "Public read images" ON storage.objects
+     FOR SELECT
+     TO public
+     USING (bucket_id = 'image');
migrations/fix_rls_assets.sql ADDED
@@ -0,0 +1,96 @@
+ -- ============================================================
+ -- FIX RLS & SESSION ASSETS (MathSolver v5.1 Worker Fix)
+ -- ============================================================
+ 
+ -- 1. Ensure session_assets table exists
+ CREATE TABLE IF NOT EXISTS public.session_assets (
+     id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
+     session_id UUID NOT NULL REFERENCES public.sessions(id) ON DELETE CASCADE,
+     job_id UUID NOT NULL,
+     asset_type TEXT NOT NULL CHECK (asset_type IN ('video', 'image')),
+     storage_path TEXT NOT NULL,
+     public_url TEXT NOT NULL,
+     version INTEGER NOT NULL DEFAULT 1,
+     created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
+ );
+ 
+ -- Indexes for session_assets
+ CREATE INDEX IF NOT EXISTS idx_session_assets_session_id ON public.session_assets(session_id);
+ CREATE INDEX IF NOT EXISTS idx_session_assets_type ON public.session_assets(session_id, asset_type);
+ 
+ -- 2. Enable RLS for all tables
+ ALTER TABLE public.session_assets ENABLE ROW LEVEL SECURITY;
+ ALTER TABLE public.profiles ENABLE ROW LEVEL SECURITY;
+ ALTER TABLE public.sessions ENABLE ROW LEVEL SECURITY;
+ ALTER TABLE public.messages ENABLE ROW LEVEL SECURITY;
+ ALTER TABLE public.jobs ENABLE ROW LEVEL SECURITY;
+ 
+ -- 3. Fix table policies to allow SERVICE ROLE
+ -- In Supabase, service_role usually bypasses RLS, but we add explicit policies for safety,
+ -- especially for path-based checks or when SECURITY DEFINER functions are used.
+ 
+ -- [Session Assets]
+ DROP POLICY IF EXISTS "Users view own assets" ON public.session_assets;
+ CREATE POLICY "Users view own assets" ON public.session_assets
+     FOR SELECT USING (
+         session_id IN (SELECT id FROM public.sessions WHERE user_id = auth.uid())
+     );
+ 
+ DROP POLICY IF EXISTS "Service role manages assets" ON public.session_assets;
+ CREATE POLICY "Service role manages assets" ON public.session_assets
+     FOR ALL USING (true)
+     WITH CHECK (true);
+ 
+ -- [Messages] - Allow the worker to insert assistant messages
+ DROP POLICY IF EXISTS "Users manage own messages" ON public.messages;
+ CREATE POLICY "Users manage own messages" ON public.messages
+     FOR ALL USING (
+         session_id IN (SELECT id FROM public.sessions WHERE user_id = auth.uid())
+         OR
+         (auth.jwt() ->> 'role' = 'service_role')
+     );
+ 
+ -- [Jobs] - Allow the worker to update job status
+ DROP POLICY IF EXISTS "Users manage own jobs" ON public.jobs;
+ CREATE POLICY "Users manage own jobs" ON public.jobs
+     FOR ALL USING (
+         auth.uid() = user_id
+         OR user_id IS NULL
+         OR (auth.jwt() ->> 'role' = 'service_role')
+     );
+ 
+ -- 4. Storage policies (bucket: video)
+ -- Ensure the 'video' bucket exists
+ INSERT INTO storage.buckets (id, name, public)
+ VALUES ('video', 'video', true)
+ ON CONFLICT (id) DO UPDATE SET public = true;
+ 
+ -- [Storage: Worker / Service Role] - Allow all in the video bucket
+ DROP POLICY IF EXISTS "Service Role manage videos" ON storage.objects;
+ CREATE POLICY "Service Role manage videos" ON storage.objects
+     FOR ALL
+     TO service_role
+     USING (bucket_id = 'video');
+ 
+ -- [Storage: Users] - Allow users to view their session videos
+ DROP POLICY IF EXISTS "Users view session videos" ON storage.objects;
+ CREATE POLICY "Users view session videos" ON storage.objects
+     FOR SELECT
+     TO authenticated
+     USING (
+         bucket_id = 'video'
+         AND (storage.foldername(name))[2] IN (
+             SELECT id::text FROM public.sessions WHERE user_id = auth.uid()
+         )
+     );
+ 
+ -- [Storage: Public] - Allow public read access to videos
+ DROP POLICY IF EXISTS "Public read videos" ON storage.objects;
+ CREATE POLICY "Public read videos" ON storage.objects
+     FOR SELECT
+     TO public
+     USING (bucket_id = 'video');
migrations/v4_migration.sql ADDED
@@ -0,0 +1,131 @@
+ -- ============================================================
+ -- MATHSOLVER v4.0 - Migration Script (Multi-Session & History)
+ -- ============================================================
+ 
+ -- 1. Profiles Table (Extends Supabase Auth)
+ CREATE TABLE IF NOT EXISTS public.profiles (
+     id UUID PRIMARY KEY REFERENCES auth.users(id) ON DELETE CASCADE,
+     display_name TEXT,
+     avatar_url TEXT,
+     created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
+     updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
+ );
+ 
+ -- Function to handle new user signup and auto-create a profile
+ CREATE OR REPLACE FUNCTION public.handle_new_user()
+ RETURNS TRIGGER AS $$
+ BEGIN
+     INSERT INTO public.profiles (id, display_name, avatar_url)
+     VALUES (
+         NEW.id,
+         COALESCE(NEW.raw_user_meta_data->>'full_name', NEW.email),
+         NEW.raw_user_meta_data->>'avatar_url'
+     );
+     RETURN NEW;
+ END;
+ $$ LANGUAGE plpgsql SECURITY DEFINER;
+ 
+ -- Trigger for profile creation
+ DROP TRIGGER IF EXISTS on_auth_user_created ON auth.users;
+ CREATE TRIGGER on_auth_user_created
+     AFTER INSERT ON auth.users
+     FOR EACH ROW EXECUTE FUNCTION public.handle_new_user();
+ 
+ -- 2. Sessions Table
+ CREATE TABLE IF NOT EXISTS public.sessions (
+     id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
+     user_id UUID NOT NULL REFERENCES auth.users(id) ON DELETE CASCADE,
+     title TEXT DEFAULT 'Bài toán mới',
+     created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
+     updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
+ );
+ 
+ -- Indexes for sessions
+ CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON public.sessions(user_id);
+ CREATE INDEX IF NOT EXISTS idx_sessions_updated_at ON public.sessions(updated_at DESC);
+ 
+ -- 3. Messages Table
+ CREATE TABLE IF NOT EXISTS public.messages (
+     id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
+     session_id UUID NOT NULL REFERENCES public.sessions(id) ON DELETE CASCADE,
+     role TEXT NOT NULL CHECK (role IN ('user', 'assistant', 'system')),
+     type TEXT NOT NULL DEFAULT 'text',
+     content TEXT NOT NULL,
+     metadata JSONB DEFAULT '{}'::jsonb,
+     created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
+ );
+ 
+ -- Indexes for messages
+ CREATE INDEX IF NOT EXISTS idx_messages_session_id ON public.messages(session_id);
+ CREATE INDEX IF NOT EXISTS idx_messages_created_at ON public.messages(session_id, created_at);
+ 
+ -- 4. Session Assets Table (v5.1 Versioning)
+ CREATE TABLE IF NOT EXISTS public.session_assets (
+     id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
+     session_id UUID NOT NULL REFERENCES public.sessions(id) ON DELETE CASCADE,
+     job_id UUID NOT NULL,
+     asset_type TEXT NOT NULL CHECK (asset_type IN ('video', 'image')),
+     storage_path TEXT NOT NULL,
+     public_url TEXT NOT NULL,
+     version INTEGER NOT NULL DEFAULT 1,
+     created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
+ );
+ 
+ -- Index for session_assets
+ CREATE INDEX IF NOT EXISTS idx_session_assets_session_id ON public.session_assets(session_id);
+ 
+ -- 5. Update Jobs Table
+ ALTER TABLE public.jobs ADD COLUMN IF NOT EXISTS user_id UUID REFERENCES auth.users(id);
+ ALTER TABLE public.jobs ADD COLUMN IF NOT EXISTS session_id UUID REFERENCES public.sessions(id);
+ 
+ -- 6. Row Level Security (RLS)
+ ALTER TABLE public.profiles ENABLE ROW LEVEL SECURITY;
+ ALTER TABLE public.sessions ENABLE ROW LEVEL SECURITY;
+ ALTER TABLE public.messages ENABLE ROW LEVEL SECURITY;
+ ALTER TABLE public.jobs ENABLE ROW LEVEL SECURITY;
+ ALTER TABLE public.session_assets ENABLE ROW LEVEL SECURITY;
+ 
+ -- Policies for public.profiles
+ DROP POLICY IF EXISTS "Users view own profile" ON public.profiles;
+ CREATE POLICY "Users view own profile" ON public.profiles FOR SELECT USING (auth.uid() = id);
+ DROP POLICY IF EXISTS "Users update own profile" ON public.profiles;
+ CREATE POLICY "Users update own profile" ON public.profiles FOR UPDATE USING (auth.uid() = id);
+ 
+ -- Policies for public.sessions
+ DROP POLICY IF EXISTS "Users manage own sessions" ON public.sessions;
+ CREATE POLICY "Users manage own sessions" ON public.sessions FOR ALL USING (auth.uid() = user_id);
+ 
+ -- Policies for public.messages
+ DROP POLICY IF EXISTS "Users manage own messages" ON public.messages;
+ CREATE POLICY "Users manage own messages" ON public.messages FOR ALL USING (
+     session_id IN (SELECT id FROM public.sessions WHERE user_id = auth.uid())
+     OR (auth.jwt() ->> 'role' = 'service_role')
+ );
+ 
+ -- Policies for public.session_assets
+ DROP POLICY IF EXISTS "Users view own assets" ON public.session_assets;
+ CREATE POLICY "Users view own assets" ON public.session_assets FOR SELECT USING (
+     session_id IN (SELECT id FROM public.sessions WHERE user_id = auth.uid())
+ );
+ DROP POLICY IF EXISTS "Service role manages assets" ON public.session_assets;
+ CREATE POLICY "Service role manages assets" ON public.session_assets FOR ALL USING (true);
+ 
+ -- Policies for public.jobs
+ DROP POLICY IF EXISTS "Users manage own jobs" ON public.jobs;
+ CREATE POLICY "Users manage own jobs" ON public.jobs FOR ALL USING (
+     auth.uid() = user_id OR user_id IS NULL OR (auth.jwt() ->> 'role' = 'service_role')
+ );
+ 
+ -- 7. Storage Policies (Bucket: video)
+ -- (Run this in the Supabase Dashboard if not allowed in a migration)
+ -- INSERT INTO storage.buckets (id, name, public) VALUES ('video', 'video', true) ON CONFLICT (id) DO NOTHING;
+ -- CREATE POLICY "Service Role manage videos" ON storage.objects FOR ALL TO service_role USING (bucket_id = 'video');
+ -- CREATE POLICY "Public read videos" ON storage.objects FOR SELECT TO public USING (bucket_id = 'video');
+ 
+ -- Grant permissions to public/authenticated
+ GRANT ALL ON public.profiles TO authenticated;
+ GRANT ALL ON public.sessions TO authenticated;
+ GRANT ALL ON public.messages TO authenticated;
+ GRANT ALL ON public.jobs TO authenticated;
+ GRANT ALL ON public.session_assets TO authenticated;
+ GRANT ALL ON public.session_assets TO service_role;
pytest.ini ADDED
@@ -0,0 +1,18 @@
+ [pytest]
+ asyncio_mode = auto
+ testpaths = tests
+ pythonpath = .
+ filterwarnings =
+     ignore::DeprecationWarning
+ 
+ markers =
+     real_api: HTTP tests need running backend and TEST_USER_ID / TEST_SESSION_ID.
+     real_worker_ocr: OCR Celery task or full OCR stack (heavy).
+     real_worker_manim: Real Manim render and Supabase video upload.
+     real_agents: Live LLM / orchestrator agent calls.
+     slow: Large suite or long polling timeouts.
+     smoke: Fast API health + one solve job.
+     orchestrator_local: In-process Orchestrator without HTTP server.
+ 
+ # Default: skip integration tests that need services, keys, or long runs.
+ addopts = -m "not real_api and not real_worker_ocr and not real_worker_manim and not real_agents and not slow and not orchestrator_local"
requirements.txt ADDED
@@ -0,0 +1,38 @@
+ # Target: Python 3.11 (see Dockerfile). Used by: FastAPI API, Celery worker, Manim render, OCR/vision stack.
+ # Install: pip install -r requirements.txt
+
+ # --- Dev / test ---
+ pytest>=8.0
+ pytest-asyncio>=0.24
+
+ # --- HTTP API ---
+ cachetools>=5.3
+ fastapi>=0.115,<1
+ uvicorn[standard]>=0.30
+ python-multipart>=0.0.9
+ python-dotenv>=1.0
+ pydantic[email]>=2.4
+ email-validator>=2
+
+ # --- Auth / data / queue ---
+ openai>=1.40
+ supabase>=2.0
+ celery>=5.3
+ redis>=5
+ httpx>=0.27
+ websockets>=12
+
+ # --- Math & symbolic solver ---
+ sympy>=1.12
+ numpy>=1.26,<2
+ scipy>=1.11
+ opencv-python-headless>=4.8,<4.10
+
+ # --- Video (GeometryScene via CLI) ---
+ manim>=0.18,<0.20
+
+ # --- OCR & vision (orchestrator / legacy /ocr) ---
+ pix2tex>=0.1.4
+ paddleocr==2.7.3
+ paddlepaddle==2.6.2
+ ultralytics==8.2.2
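The file mixes range pins (`>=1.26,<2`) with exact pins (`==2.7.3`). A minimal sketch of how those two pin styles resolve, assuming the `packaging` library (bundled with modern pip-based environments) is available:

```python
# Sketch of the two pin styles used in requirements.txt; assumes `packaging`.
from packaging.specifiers import SpecifierSet
from packaging.version import Version

numpy_pin = SpecifierSet(">=1.26,<2")   # range pin: any 1.26+ release below NumPy 2
paddle_pin = SpecifierSet("==2.7.3")    # exact pin: reproducible OCR stack

print(Version("1.26.4") in numpy_pin)
print(Version("2.0.0") in numpy_pin)    # NumPy 2 is deliberately excluded
print(Version("2.7.3") in paddle_pin)
```

The `<2` cap matters here because NumPy 2 broke ABI compatibility for several of the compiled dependencies in this stack.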
requirements.worker-ocr.txt ADDED
@@ -0,0 +1,23 @@
+ # OCR-only Celery worker: YOLO + PaddleOCR + Pix2Tex (no OpenRouter / no Manim).
+ # Install: pip install -r requirements.worker-ocr.txt
+
+ cachetools>=5.3
+ fastapi>=0.115,<1
+ uvicorn[standard]>=0.30
+ python-multipart>=0.0.9
+ python-dotenv>=1.0
+ pydantic[email]>=2.4
+ email-validator>=2
+
+ celery>=5.3
+ redis>=5
+ httpx>=0.27
+ websockets>=12
+
+ numpy>=1.26,<2
+ opencv-python-headless>=4.8,<4.10
+
+ pix2tex>=0.1.4
+ paddleocr==2.7.3
+ paddlepaddle==2.6.2
+ ultralytics==8.2.2
requirements.worker-render.txt ADDED
@@ -0,0 +1,21 @@
+ # Celery render worker: Manim + Supabase (no OpenAI / SymPy / OCR vision stack).
+ # Install: pip install -r requirements.worker-render.txt
+ # Includes FastAPI/uvicorn for worker_health.py (HF Spaces).
+
+ cachetools>=5.3
+ fastapi>=0.115,<1
+ uvicorn[standard]>=0.30
+ python-multipart>=0.0.9
+ python-dotenv>=1.0
+ pydantic[email]>=2.4
+ email-validator>=2
+
+ celery>=5.3
+ redis>=5
+ httpx>=0.27
+ websockets>=12
+
+ supabase>=2.0
+
+ numpy>=1.26,<2
+ manim>=0.18,<0.20
run_api_test.sh ADDED
@@ -0,0 +1,69 @@
+ #!/bin/bash
+
+ LOG_FILE="api_test_results.log"
+ echo "=== Starting API E2E Test Suite ($(date)) ===" > "$LOG_FILE"
+
+ # 1. Start backend server in the background
+ echo "[INFO] Starting Backend Server..." | tee -a "$LOG_FILE"
+ export ALLOW_TEST_BYPASS=true
+ export LOG_LEVEL=info
+ export CELERY_TASK_ALWAYS_EAGER=true
+ export CELERY_RESULT_BACKEND=rpc://
+ export MOCK_VIDEO=true
+ PYTHONPATH=. venv/bin/python -m uvicorn app.main:app --port 8000 > server_debug.log 2>&1 &
+ SERVER_PID=$!
+
+ # 2. Wait for the server to become ready
+ echo "[INFO] Waiting for server (PID: $SERVER_PID) on port 8000..." | tee -a "$LOG_FILE"
+ MAX_RETRIES=15
+ READY=0
+ for i in $(seq 1 "$MAX_RETRIES"); do
+     if curl -s http://localhost:8000/ > /dev/null; then
+         READY=1
+         break
+     fi
+     sleep 2
+ done
+
+ if [ "$READY" -eq 0 ]; then
+     echo "[ERROR] Server failed to start in time. Check server_debug.log" | tee -a "$LOG_FILE"
+     kill "$SERVER_PID" 2>/dev/null
+     exit 1
+ fi
+ echo "[INFO] Server is READY." | tee -a "$LOG_FILE"
+
+ # 3. Prepare test data
+ echo "[INFO] Preparing fresh test data..." | tee -a "$LOG_FILE"
+ PREP_OUTPUT=$(PYTHONPATH=. venv/bin/python scripts/prepare_api_test.py)
+ echo "$PREP_OUTPUT" >> "$LOG_FILE"
+
+ export TEST_USER_ID=$(echo "$PREP_OUTPUT" | grep "RESULT:USER_ID=" | cut -d'=' -f2)
+ export TEST_SESSION_ID=$(echo "$PREP_OUTPUT" | grep "RESULT:SESSION_ID=" | cut -d'=' -f2)
+
+ if [ -z "$TEST_USER_ID" ] || [ -z "$TEST_SESSION_ID" ]; then
+     echo "[ERROR] Failed to prepare test data." | tee -a "$LOG_FILE"
+     kill "$SERVER_PID" 2>/dev/null
+     exit 1
+ fi
+
+ echo "[INFO] Test Data: User=$TEST_USER_ID, Session=$TEST_SESSION_ID" | tee -a "$LOG_FILE"
+
+ # 4. Run pytest
+ echo "[INFO] Running API E2E Tests..." | tee -a "$LOG_FILE"
+ PYTHONPATH=. venv/bin/python -m pytest tests/test_api_real_e2e.py -m "smoke and real_api" -s \
+     --junitxml=pytest_smoke.xml >> "$LOG_FILE" 2>&1
+ TEST_EXIT_CODE=$?
+
+ # 5. Cleanup
+ echo "[INFO] Shutting down Server..." | tee -a "$LOG_FILE"
+ kill "$SERVER_PID" 2>/dev/null
+
+ echo "==========================================" | tee -a "$LOG_FILE"
+ if [ "$TEST_EXIT_CODE" -eq 0 ]; then
+     echo "FINAL RESULT: ✅ ALL API TESTS PASSED" | tee -a "$LOG_FILE"
+ else
+     echo "FINAL RESULT: ❌ SOME API TESTS FAILED (Code: $TEST_EXIT_CODE)" | tee -a "$LOG_FILE"
+ fi
+ echo "==========================================" | tee -a "$LOG_FILE"
+
+ exit "$TEST_EXIT_CODE"
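The `RESULT:` marker hand-off between `prepare_api_test.py` and this script can be exercised in isolation. A stand-alone sketch of the same `grep | cut` extraction, run against dummy prep output (the `u-123` / `s-456` values are placeholders, not real IDs):

```shell
# Stand-alone sketch of the RESULT: extraction used by the runner scripts.
prep_output="$(printf 'Using test user: demo\nRESULT:USER_ID=u-123\nRESULT:SESSION_ID=s-456\n')"

# cut -d'=' -f2 keeps everything after the first '=' only if the value
# itself has no '=' in it; UUIDs satisfy that.
user_id=$(echo "$prep_output" | grep "RESULT:USER_ID=" | cut -d'=' -f2)
session_id=$(echo "$prep_output" | grep "RESULT:SESSION_ID=" | cut -d'=' -f2)

echo "$user_id"
echo "$session_id"
```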
run_full_api_test.sh ADDED
@@ -0,0 +1,63 @@
+ #!/bin/bash
+ # Full API integration (CI-style): eager Celery + mock video + full HTTP suite.
+ LOG_FILE="full_api_suite.log"
+ REPORT_FILE="full_api_test_report.md"
+ JSON_RESULTS="temp_suite_results.json"
+ JUNIT="pytest_api_suite.xml"
+
+ echo "=== Starting Full API Suite Test ($(date)) ===" >"$LOG_FILE"
+
+ trap 'echo "[INFO] Cleaning up processes..."; kill $SERVER_PID 2>/dev/null; sleep 1' EXIT
+
+ echo "[INFO] Starting Backend Server (EAGER + MOCK_VIDEO)..." | tee -a "$LOG_FILE"
+ export ALLOW_TEST_BYPASS=true
+ export LOG_LEVEL=info
+ export CELERY_TASK_ALWAYS_EAGER=true
+ export CELERY_RESULT_BACKEND=rpc://
+ export MOCK_VIDEO=true
+ PYTHONPATH=. venv/bin/python -m uvicorn app.main:app --port 8000 >server_debug.log 2>&1 &
+ SERVER_PID=$!
+
+ echo "[INFO] Waiting for server (PID: $SERVER_PID)..." | tee -a "$LOG_FILE"
+ READY=0
+ for i in {1..20}; do
+     if curl -s http://localhost:8000/ >/dev/null; then
+         READY=1
+         echo "[INFO] Server is READY." | tee -a "$LOG_FILE"
+         break
+     fi
+     sleep 2
+ done
+
+ if [ "$READY" -eq 0 ]; then
+     echo "[ERROR] Server failed to start in time. Check server_debug.log" | tee -a "$LOG_FILE"
+     exit 1
+ fi
+
+ echo "[INFO] Preparing fresh test data..." | tee -a "$LOG_FILE"
+ PREP_OUTPUT=$(PYTHONPATH=. venv/bin/python scripts/prepare_api_test.py)
+ export TEST_USER_ID=$(echo "$PREP_OUTPUT" | grep "RESULT:USER_ID=" | cut -d'=' -f2)
+ export TEST_SESSION_ID=$(echo "$PREP_OUTPUT" | grep "RESULT:SESSION_ID=" | cut -d'=' -f2)
+
+ if [ -z "$TEST_USER_ID" ]; then
+     echo "[ERROR] Failed to prepare test data." | tee -a "$LOG_FILE"
+     exit 1
+ fi
+
+ echo "[INFO] Executing API tests (smoke + full suite)..." | tee -a "$LOG_FILE"
+ PYTHONPATH=. venv/bin/python -m pytest tests/test_api_real_e2e.py tests/test_api_full_suite.py \
+     -m "real_api" -s --tb=short --junitxml="$JUNIT" >>"$LOG_FILE" 2>&1
+ TEST_EXIT_CODE=$?
+
+ echo "[INFO] Generating Markdown Report..." | tee -a "$LOG_FILE"
+ if [ -f "$JSON_RESULTS" ]; then
+     PYTHONPATH=. venv/bin/python scripts/generate_report.py "$JSON_RESULTS" "$REPORT_FILE" "$JUNIT"
+ else
+     echo "[WARN] $JSON_RESULTS not found" | tee -a "$LOG_FILE"
+ fi
+
+ echo "==========================================" | tee -a "$LOG_FILE"
+ echo "DONE. Check $REPORT_FILE for results." | tee -a "$LOG_FILE"
+ echo "==========================================" | tee -a "$LOG_FILE"
+
+ exit "$TEST_EXIT_CODE"
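The `trap ... EXIT` line is what guarantees the background server is killed on every exit path, including early failures. A stand-alone sketch of that pattern (the temp-file marker is just a way to observe the trap firing):

```shell
# Stand-alone sketch of the EXIT-trap cleanup pattern used above:
# the trap body runs no matter how the (sub)shell terminates.
marker="$(mktemp)"
(
    trap 'echo cleaned >"$marker"' EXIT
    exit 7   # simulate an early failure path; the trap still fires
) || true

result="$(cat "$marker")"
rm -f "$marker"
echo "$result"
```

Because the trap body is single-quoted, `$marker` is expanded when the trap fires, not when it is installed.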
scripts/benchmark_openrouter.py ADDED
@@ -0,0 +1,79 @@
+ """Benchmark several OpenRouter models (manual tool; not part of pytest)."""
+
+ from __future__ import annotations
+
+ import json
+ import os
+ import time
+
+ import httpx
+ from dotenv import load_dotenv
+
+ _BACKEND_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ load_dotenv(os.path.join(_BACKEND_ROOT, ".env"))
+
+ MODELS = [
+     "nvidia/nemotron-3-super-120b-a12b:free",
+     "meta-llama/llama-3.3-70b-instruct:free",
+     "openai/gpt-oss-120b:free",
+     "z-ai/glm-4.5-air:free",
+     "minimax/minimax-m2.5:free",
+     "google/gemma-4-26b-a4b-it:free",
+     "google/gemma-4-31b-it:free",
+ ]
+
+ # PROMPT (Vietnamese): rectangle ABCD with AB = 5, AD = 10; E on segment CD with
+ # CE = 2*ED; P on line BC with BP = 2*PC; find the perimeter of triangle PEA.
+ PROMPT = (
+     "Cho hình chữ nhật ABCD có AB bằng 5 và AD bằng 10. Gọi E là điểm nằm trong đoạn CD sao cho CE = 2ED. "
+     "Vẽ đoạn thẳng AE. Vẽ thêm P là điểm nằm trên đường thẳng BC sao cho BP = 2PC, tính chu vi tam giác PEA"
+ )
+
+
+ def main() -> None:
+     api_key = os.getenv("OPENROUTER_API_KEY_1") or os.getenv("OPENROUTER_API_KEY")
+     base_url = "https://openrouter.ai/api/v1/chat/completions"
+
+     if not api_key:
+         print("Missing OPENROUTER_API_KEY_1 or OPENROUTER_API_KEY in .env")
+         return
+
+     print("Benchmark OpenRouter models\nPrompt:", PROMPT, "\n")
+     results = []
+
+     for model in MODELS:
+         print(f"Calling {model}...", end="", flush=True)
+         headers = {
+             "Authorization": f"Bearer {api_key}",
+             "Content-Type": "application/json",
+             "HTTP-Referer": "https://mathsolver.io",
+             "X-Title": "MathSolver Benchmark Tool",
+         }
+         payload = {"model": model, "messages": [{"role": "user", "content": PROMPT}]}
+         start = time.time()
+         try:
+             with httpx.Client(timeout=120.0) as client:
+                 r = client.post(base_url, headers=headers, json=payload)
+                 r.raise_for_status()
+                 data = r.json()
+             answer = data["choices"][0]["message"]["content"]
+             duration = time.time() - start
+             results.append(
+                 {"model": model, "duration": duration, "answer": answer, "status": "success"}
+             )
+             print(f" OK ({duration:.2f}s)")
+         except Exception as e:
+             duration = time.time() - start
+             results.append(
+                 {"model": model, "duration": duration, "error": str(e), "status": "error"}
+             )
+             print(f" FAIL ({duration:.2f}s) {e}")
+
+     print("\n" + "=" * 80)
+     for res in results:
+         print(json.dumps(res, ensure_ascii=False, indent=2)[:2000])
+         print("-" * 40)
+
+
+ if __name__ == "__main__":
+     main()
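The benchmark only prints raw per-model records; a small sketch of summarizing that `results` list (the entries below are made-up placeholders, not real measurements):

```python
# Sketch of summarizing the per-model `results` list the benchmark builds;
# model names and durations here are illustrative placeholders.
results = [
    {"model": "model-a", "duration": 4.10, "status": "success"},
    {"model": "model-b", "duration": 1.75, "status": "success"},
    {"model": "model-c", "duration": 9.00, "status": "error"},
]

ok = [r for r in results if r["status"] == "success"]
fastest = min(ok, key=lambda r: r["duration"])
print(f"{len(ok)}/{len(results)} succeeded; fastest: {fastest['model']} ({fastest['duration']:.2f}s)")
```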
scripts/generate_report.py ADDED
@@ -0,0 +1,115 @@
+ import json
+ import os
+ import sys
+ import xml.etree.ElementTree as ET
+ from datetime import datetime
+
+
+ def _parse_junit_xml(path: str) -> dict:
+     """Summarize a pytest junitxml (JUnit) file."""
+     out = {"tests": 0, "failures": 0, "errors": 0, "skipped": 0, "time": 0.0}
+     try:
+         tree = ET.parse(path)
+         root = tree.getroot()
+         nodes = [root] if root.tag == "testsuite" else list(root.iter("testsuite"))
+         for ts in nodes:
+             if ts.tag != "testsuite":
+                 continue
+             out["tests"] += int(ts.attrib.get("tests", 0) or 0)
+             out["failures"] += int(ts.attrib.get("failures", 0) or 0)
+             out["errors"] += int(ts.attrib.get("errors", 0) or 0)
+             out["skipped"] += int(ts.attrib.get("skipped", 0) or 0)
+             out["time"] += float(ts.attrib.get("time", 0) or 0)
+     except Exception as e:
+         out["parse_error"] = str(e)
+     return out
+
+
+ def generate_report(json_path: str, report_path: str, junit_path: str | None = None) -> None:
+     try:
+         with open(json_path, "r", encoding="utf-8") as f:
+             data = json.load(f)
+
+         junit_summary = None
+         if junit_path and os.path.isfile(junit_path):
+             junit_summary = _parse_junit_xml(junit_path)
+
+         with open(report_path, "w", encoding="utf-8") as f:
+             f.write("# Backend Integration Test Report\n\n")
+             f.write(f"**Run time:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
+             suite_ok = all(r.get("success", False) for r in data) if isinstance(data, list) else False
+             f.write(f"**API suite (JSON):** {'PASS' if suite_ok else 'FAIL'}\n")
+
+             if junit_summary and "parse_error" not in junit_summary:
+                 j_ok = junit_summary["failures"] == 0 and junit_summary["errors"] == 0
+                 f.write(
+                     f"**Pytest (JUnit):** {'PASS' if j_ok else 'FAIL'} — "
+                     f"tests={junit_summary['tests']}, failures={junit_summary['failures']}, "
+                     f"errors={junit_summary['errors']}, skipped={junit_summary['skipped']}, "
+                     f"time_s={junit_summary['time']:.2f}\n"
+                 )
+             elif junit_summary and "parse_error" in junit_summary:
+                 f.write(f"**Pytest (JUnit):** (could not parse: {junit_summary['parse_error']})\n")
+
+             f.write("\n")
+
+             f.write("| ID | Query | Status | Time (s) | Result / Error |\n")
+             f.write("| :--- | :--- | :--- | :--- | :--- |\n")
+             for r in data:
+                 status = "PASS" if r.get("success") else "FAIL"
+                 elapsed = f"{float(r.get('elapsed', 0) or 0):.2f}"
+                 query = r.get("query", "-")
+
+                 res = r.get("result", {})
+                 if not isinstance(res, dict):
+                     res = {}
+
+                 analysis = str(res.get("semantic_analysis", "-"))
+                 if not r.get("success"):
+                     analysis = f"**Error:** {r.get('error', '-')}"
+
+                 short_analysis = (analysis[:100] + "...") if len(analysis) > 100 else analysis
+
+                 f.write(f"| {r.get('id', '-')} | {query} | {status} | {elapsed} | {short_analysis} |\n")
+
+             f.write("\n---\n**Output details (DSL & Analysis):**\n")
+             for r in data:
+                 if not r.get("success"):
+                     continue
+                 res = r.get("result", {})
+                 if not isinstance(res, dict):
+                     continue
+
+                 f.write(f"\n### Case {r.get('id', '-')}: {r.get('query')}\n")
+                 f.write(f"**Semantic Analysis:**\n{res.get('semantic_analysis', '-')}\n\n")
+                 f.write(f"**Geometry DSL:**\n```\n{res.get('geometry_dsl', '-')}\n```\n")
+
+                 sol = res.get("solution")
+                 if sol and isinstance(sol, dict):
+                     f.write("**Solution (v5.1):**\n")
+                     f.write(f"- **Answer:** {sol.get('answer', 'N/A')}\n")
+                     f.write("- **Steps:**\n")
+                     steps = sol.get("steps", [])
+                     if steps:
+                         for step in steps:
+                             f.write(f"  - {step}\n")
+                     else:
+                         f.write("  - (No explicit solution steps)\n")
+
+                     if sol.get("symbolic_expression"):
+                         f.write(f"- **Symbolic:** `{sol.get('symbolic_expression')}`\n")
+                     f.write("\n")
+
+         print(f"Report generated: {report_path}")
+     except Exception as e:
+         print(f"Error generating report: {e}")
+
+
+ if __name__ == "__main__":
+     if len(sys.argv) < 3:
+         print(
+             "Usage: python generate_report.py <json_results> <report_output> [junit_xml_optional]"
+         )
+         sys.exit(1)
+     junit = sys.argv[3] if len(sys.argv) > 3 else None
+     generate_report(sys.argv[1], sys.argv[2], junit)
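`_parse_junit_xml` aggregates the count attributes across every `<testsuite>` node, whether the root is a single suite or a `<testsuites>` wrapper. A self-contained sketch of the same aggregation over an in-memory document (re-implemented here, since importing the script requires the repo layout):

```python
# Self-contained sketch of the junitxml aggregation done by _parse_junit_xml,
# run against an in-memory XML string instead of a file on disk.
import xml.etree.ElementTree as ET

JUNIT = """<testsuites>
  <testsuite tests="3" failures="1" errors="0" skipped="0" time="2.5"/>
  <testsuite tests="2" failures="0" errors="0" skipped="1" time="0.5"/>
</testsuites>"""

root = ET.fromstring(JUNIT)
# Handle both a bare <testsuite> root and a <testsuites> wrapper.
nodes = [root] if root.tag == "testsuite" else list(root.iter("testsuite"))

summary = {"tests": 0, "failures": 0, "errors": 0, "skipped": 0, "time": 0.0}
for ts in nodes:
    summary["tests"] += int(ts.attrib.get("tests", 0) or 0)
    summary["failures"] += int(ts.attrib.get("failures", 0) or 0)
    summary["errors"] += int(ts.attrib.get("errors", 0) or 0)
    summary["skipped"] += int(ts.attrib.get("skipped", 0) or 0)
    summary["time"] += float(ts.attrib.get("time", 0) or 0)

print(summary)
```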
scripts/prepare_api_test.py ADDED
@@ -0,0 +1,38 @@
+ import os
+ import sys
+ import uuid
+
+ from dotenv import load_dotenv
+
+ # Add parent dir to path to import app modules
+ _BACKEND_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ sys.path.append(_BACKEND_ROOT)
+ load_dotenv(os.path.join(_BACKEND_ROOT, ".env"))
+
+ from app.supabase_client import get_supabase
+
+ # Default UUID matches historical dev DB; override with TEST_SUPABASE_USER_ID in .env
+ _DEFAULT_TEST_USER = "8cd3adb0-7964-4575-949c-d0cadcd8b679"
+
+
+ def prepare():
+     supabase = get_supabase()
+     user_id = os.environ.get("TEST_SUPABASE_USER_ID", _DEFAULT_TEST_USER).strip()
+     session_id = str(uuid.uuid4())
+
+     print(f"Using test user (TEST_SUPABASE_USER_ID or default): {user_id}")
+
+     print(f"Creating fresh test session: {session_id}")
+     # Insert session
+     supabase.table("sessions").insert({
+         "id": session_id,
+         "user_id": user_id,
+         "title": f"Fresh API Test {session_id[:8]}"
+     }).execute()
+
+     # Return IDs for the test script
+     print(f"RESULT:USER_ID={user_id}")
+     print(f"RESULT:SESSION_ID={session_id}")
+
+ if __name__ == "__main__":
+     prepare()
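The stdout contract between this script and the shell runners is just greppable `RESULT:<KEY>=<value>` lines. A sketch of producing and parsing that contract in pure Python (the `emit_result_lines` / `parse_result` helpers are illustrative, not part of the repo):

```python
# Sketch of the RESULT: stdout contract; the helper names here are
# hypothetical, introduced only to illustrate the protocol.
import uuid

def emit_result_lines(user_id: str) -> list[str]:
    """Produce the lines prepare() prints for the shell runners to grep."""
    session_id = str(uuid.uuid4())
    return [f"RESULT:USER_ID={user_id}", f"RESULT:SESSION_ID={session_id}"]

def parse_result(lines: list[str]) -> dict[str, str]:
    """Invert the contract: RESULT:<KEY>=<value> lines into a dict."""
    out = {}
    for line in lines:
        if line.startswith("RESULT:"):
            key, _, value = line[len("RESULT:"):].partition("=")
            out[key] = value
    return out

lines = emit_result_lines("8cd3adb0-7964-4575-949c-d0cadcd8b679")
parsed = parse_result(lines)
print(parsed["USER_ID"])
```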
scripts/prewarm_models.py ADDED
@@ -0,0 +1,42 @@
+ #!/usr/bin/env python3
+ """
+ Download and load all heavy models during Docker build (YOLO, PaddleOCR, Pix2Tex, agents).
+ Fails the image build if initialization fails.
+ """
+
+ from __future__ import annotations
+
+ import logging
+ import os
+ import sys
+
+ # Ensure imports work when run as `python scripts/prewarm_models.py` from WORKDIR
+ _APP_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ if _APP_ROOT not in sys.path:
+     sys.path.insert(0, _APP_ROOT)
+
+ os.chdir(_APP_ROOT)
+
+ from dotenv import load_dotenv
+
+ load_dotenv()
+
+ from app.runtime_env import apply_runtime_env_defaults
+
+ apply_runtime_env_defaults()
+
+ logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s | %(message)s")
+
+ logger = logging.getLogger("prewarm")
+
+
+ def main() -> None:
+     from agents.orchestrator import Orchestrator
+
+     logger.info("Constructing Orchestrator (full agent + model load)...")
+     Orchestrator()
+     logger.info("Prewarm finished successfully.")
+
+
+ if __name__ == "__main__":
+     main()
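Both prewarm scripts share the same bootstrap: resolve the repo root as the parent of `scripts/` and push it onto `sys.path` before any `app.` import. A sketch of that path arithmetic, using a hypothetical `/repo/...` path for illustration:

```python
# Sketch of the sys.path bootstrap shared by the prewarm scripts;
# "/repo/scripts/prewarm_models.py" is a hypothetical POSIX path.
import os
import sys

def ensure_repo_root_on_path(script_path: str) -> str:
    """Return the repo root (parent of scripts/) and prepend it to sys.path."""
    root = os.path.dirname(os.path.dirname(os.path.abspath(script_path)))
    if root not in sys.path:
        sys.path.insert(0, root)
    return root

root = ensure_repo_root_on_path("/repo/scripts/prewarm_models.py")
print(root)
```

Prepending (rather than appending) the root makes the repo's own `app` package win over any same-named installed distribution.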