github-actions commited on
Commit
200b382
·
0 Parent(s):

Deploy render worker from GitHub Actions

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. Dockerfile +47 -0
  2. Dockerfile.worker.ocr +42 -0
  3. README.md +24 -0
  4. README_HF_WORKER_OCR.md +27 -0
  5. agents/geometry_agent.py +120 -0
  6. agents/knowledge_agent.py +135 -0
  7. agents/ocr_agent.py +307 -0
  8. agents/orchestrator.py +219 -0
  9. agents/parser_agent.py +106 -0
  10. agents/renderer_agent.py +265 -0
  11. agents/solver_agent.py +107 -0
  12. agents/torch_ultralytics_compat.py +33 -0
  13. app/chat_image_upload.py +206 -0
  14. app/dependencies.py +69 -0
  15. app/errors.py +59 -0
  16. app/llm_client.py +100 -0
  17. app/logging_setup.py +112 -0
  18. app/logutil.py +67 -0
  19. app/main.py +140 -0
  20. app/models/schemas.py +80 -0
  21. app/ocr_celery.py +54 -0
  22. app/ocr_local_file.py +43 -0
  23. app/ocr_text_merge.py +14 -0
  24. app/routers/__init__.py +1 -0
  25. app/routers/auth.py +23 -0
  26. app/routers/sessions.py +165 -0
  27. app/routers/solve.py +410 -0
  28. app/runtime_env.py +12 -0
  29. app/session_cache.py +48 -0
  30. app/supabase_client.py +37 -0
  31. app/url_utils.py +23 -0
  32. app/websocket_manager.py +40 -0
  33. clean_ports.sh +22 -0
  34. dump.rdb +0 -0
  35. migrations/add_image_bucket_storage.sql +35 -0
  36. migrations/fix_rls_assets.sql +96 -0
  37. migrations/v4_migration.sql +131 -0
  38. pytest.ini +18 -0
  39. requirements.txt +38 -0
  40. requirements.worker-ocr.txt +23 -0
  41. run_api_test.sh +69 -0
  42. run_full_api_test.sh +56 -0
  43. scripts/benchmark_openrouter.py +77 -0
  44. scripts/generate_report.py +115 -0
  45. scripts/prepare_api_test.py +38 -0
  46. scripts/prewarm_models.py +42 -0
  47. scripts/prewarm_ocr_worker.py +37 -0
  48. scripts/run_real_integration.sh +134 -0
  49. scripts/test_LLM.py +142 -0
  50. scripts/test_engine_direct.py +36 -0
Dockerfile ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Same runtime as API; runs health endpoint + Celery worker (see worker_health.py)
FROM python:3.11-slim-bookworm

# Quiet pip and pin BLAS/OpenMP thread pools to 1 so Celery controls
# concurrency instead of nested native thread pools.
ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1 \
    PIP_NO_CACHE_DIR=1 \
    PIP_ROOT_USER_ACTION=ignore \
    NO_ALBUMENTATIONS_UPDATE=1 \
    OMP_NUM_THREADS=1 \
    MKL_NUM_THREADS=1 \
    OPENBLAS_NUM_THREADS=1

WORKDIR /app
ENV PYTHONPATH=/app

# Native dependencies for the render stack: ffmpeg (video), Cairo/Pango
# (Manim drawing), and TeX Live (LaTeX rendering).
RUN apt-get update && apt-get install -y --no-install-recommends \
    ffmpeg \
    pkg-config \
    cmake \
    libcairo2 \
    libcairo2-dev \
    libpango-1.0-0 \
    libpango1.0-dev \
    libpangocairo-1.0-0 \
    libgdk-pixbuf-2.0-0 \
    libffi-dev \
    python3-dev \
    texlive-latex-base \
    texlive-fonts-recommended \
    texlive-latex-extra \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# Install Python deps before copying sources so layer caching survives
# code-only changes.
COPY requirements.txt .
RUN pip install --upgrade pip setuptools wheel \
    && pip install -r requirements.txt

COPY . .

# Warm model weights at build time to keep container cold starts fast.
RUN python scripts/prewarm_models.py

# HF Spaces routes traffic to $PORT (7860); this image consumes only the
# "render" Celery queue by default.
ENV PORT=7860 \
    CELERY_WORKER_QUEUES=render
EXPOSE 7860

ENTRYPOINT []
CMD ["sh", "-c", "exec python3 -u worker_health.py"]
Dockerfile.worker.ocr ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Celery worker: OCR queue only (no Manim / LaTeX / Cairo stack).
FROM python:3.11-slim-bookworm

# Quiet pip and pin BLAS/OpenMP thread pools to 1 so Celery controls
# concurrency instead of nested native thread pools.
ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1 \
    PIP_NO_CACHE_DIR=1 \
    PIP_ROOT_USER_ACTION=ignore \
    NO_ALBUMENTATIONS_UPDATE=1 \
    OMP_NUM_THREADS=1 \
    MKL_NUM_THREADS=1 \
    OPENBLAS_NUM_THREADS=1

WORKDIR /app
ENV PYTHONPATH=/app

# Minimal native deps: build toolchain plus the shared libraries OpenCV-style
# imaging stacks load at runtime (libgl1, libglib2.0-0, X11 render libs).
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    cmake \
    pkg-config \
    python3-dev \
    libglib2.0-0 \
    libgomp1 \
    libgl1 \
    libsm6 \
    libxext6 \
    libxrender1 \
    && rm -rf /var/lib/apt/lists/*

# OCR-only requirements file keeps this image slimmer than the API image.
COPY requirements.worker-ocr.txt .
RUN pip install --upgrade pip setuptools wheel \
    && pip install -r requirements.worker-ocr.txt

COPY . .

# Warm OCR model weights at build time to keep container cold starts fast.
RUN python scripts/prewarm_ocr_worker.py

# HF Spaces routes traffic to $PORT (7860); this image consumes only the
# "ocr" Celery queue by default.
ENV PORT=7860 \
    CELERY_WORKER_QUEUES=ocr
EXPOSE 7860

ENTRYPOINT []
CMD ["sh", "-c", "exec python3 -u worker_health.py"]
README.md ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Math Solver Render Worker
3
+ emoji: 👷
4
+ colorFrom: blue
5
+ colorTo: indigo
6
+ sdk: docker
7
+ app_port: 7860
8
+ ---
9
+
10
+ # Math Solver — Render worker (Manim)
11
+
12
+ This Space runs **Celery** via `worker_health.py` and consumes **only** queue **`render`** (`render_geometry_video`). Image sets `CELERY_WORKER_QUEUES=render` by default (`Dockerfile.worker`).
13
+
14
+ **Solve** (orchestrator, agents, OCR-in-request when `OCR_USE_CELERY` is off) runs on the **API** Space, not on this worker.
15
+
16
+ ## OCR offload (separate Space)
17
+
18
+ Queue **`ocr`** is handled by a **dedicated OCR worker** (`Dockerfile.worker.ocr`, `README_HF_WORKER_OCR.md`, workflow `deploy-worker-ocr.yml`). On the API, set `OCR_USE_CELERY=true` and deploy an OCR Space that listens on `ocr`.
19
+
20
+ ## Secrets
21
+
22
+ Same broker as the API: `REDIS_URL` / `CELERY_BROKER_URL`, Supabase, OpenRouter (renderer may use LLM paths), etc.
23
+
24
+ **GitHub Actions:** repository secrets `HF_TOKEN` and `HF_WORKER_REPO` (`owner/space-name`) for this workflow (`deploy-worker.yml`).
README_HF_WORKER_OCR.md ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Math Solver OCR Worker
3
+ emoji: 👁️
4
+ colorFrom: gray
5
+ colorTo: blue
6
+ sdk: docker
7
+ app_port: 7860
8
+ ---
9
+
10
+ # Math Solver — OCR-only worker
11
+
12
+ This Space runs **Celery** (`worker_health.py`) consuming **only** the `ocr` queue.
13
+
14
+ Set environment:
15
+
16
+ - `CELERY_WORKER_QUEUES=ocr` (default in `Dockerfile.worker.ocr`)
17
+ - Same `REDIS_URL` / `CELERY_BROKER_URL` / `CELERY_RESULT_BACKEND` as the API
18
+
19
+ This Space runs **raw OCR only** (YOLO, PaddleOCR, Pix2Tex). **OpenRouter / LLM refinement** does not run here; the API Space calls `refine_with_llm` after it receives the result from the `ocr` queue.
20
+
21
+ On the **API** Space, set `OCR_USE_CELERY=true` so `run_ocr_from_url` tasks are sent to this worker instead of running Paddle/Pix2Tex on the API process.
22
+
23
+ Optional: `OCR_CELERY_TIMEOUT_SEC` (default `180`).
24
+
25
+ **Manim / video** uses a different Celery queue (`render`) and Space — see `README_HF_WORKER.md` and workflow `deploy-worker.yml`.
26
+
27
+ GitHub Actions: repository secrets `HF_TOKEN` and `HF_OCR_WORKER_REPO` (`owner/space-name`) enable workflow `deploy-worker-ocr.yml`.
agents/geometry_agent.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import logging
4
+ from openai import AsyncOpenAI
5
+ from typing import Dict, Any
6
+ from dotenv import load_dotenv
7
+
8
+ load_dotenv()
9
+ logger = logging.getLogger(__name__)
10
+
11
+ from app.url_utils import openai_compatible_api_key, sanitize_env
12
+ from app.llm_client import get_llm_client
13
+
14
+
15
class GeometryAgent:
    """Converts parsed semantic geometry data into a Geometry-DSL program via the shared LLM client."""

    def __init__(self):
        # Shared multi-layer LLM client (see app.llm_client.get_llm_client).
        self.llm = get_llm_client()

    async def generate_dsl(self, semantic_data: Dict[str, Any], previous_dsl: str | None = None) -> str:
        """Generate (or, in multi-turn mode, update) a Geometry DSL program.

        Args:
            semantic_data: Parser/Knowledge-agent output describing the figure;
                serialized as JSON into the user prompt.
            previous_dsl: Full DSL text from an earlier turn. When provided the
                LLM is instructed to extend/update it rather than start over.

        Returns:
            The DSL program as plain text ("" when the LLM returns nothing).
        """
        logger.info("==[GeometryAgent] Generating DSL from semantic data==")
        if previous_dsl:
            logger.info(f"[GeometryAgent] Using previous DSL context (len={len(previous_dsl)})")

        # NOTE: this prompt is runtime behavior — the DSL grammar it defines
        # must match what solver.dsl_parser.DSLParser accepts.
        system_prompt = """
You are a Geometry DSL Generator. Convert semantic geometry data into a precise Geometry DSL program.

=== MULTI-TURN CONTEXT ===
If a PREVIOUS DSL is provided, your job is to UPDATE or EXTEND it.
1. DO NOT remove existing points unless the user explicitly asks to "redefine" or "move" them.
2. Ensure new segments/points connect correctly to existing ones.
3. Your output should be the ENTIRE updated DSL, not just the changes.

=== DSL COMMANDS ===
POINT(A) — declare a point
POINT(A, x, y, z) — declare a point with explicit coordinates
LENGTH(AB, 5) — distance between A and B is 5 (2D/3D)
ANGLE(A, 90) — interior angle at vertex A is 90° (2D/3D)
PARALLEL(AB, CD) — segment AB is parallel to CD (2D/3D)
PERPENDICULAR(AB, CD) — segment AB is perpendicular to CD (2D/3D)
MIDPOINT(M, AB) — M is the midpoint of segment AB
SECTION(E, A, C, k) — E satisfies vector AE = k * vector AC (k is decimal)
LINE(A, B) — infinite line passing through A and B
RAY(A, B) — ray starting at A and passing through B
CIRCLE(O, 5) — circle with center O and radius 5 (2D)
SPHERE(O, 5) — sphere with center O and radius 5 (3D)
SEGMENT(M, N) — auxiliary segment MN to be drawn
POLYGON_ORDER(A, B, C, D) — the order in which vertices form the polygon boundary
TRIANGLE(ABC) — equilateral/arbitrary triangle
PYRAMID(S_ABCD) — pyramid with apex S and base ABCD
PRISM(ABC_DEF) — triangular prism

=== RULES ===
1. 3D Coordinates: Use POINT(A, x, y, z) if specific coordinates are given in the problem.
2. Space Geometry: For pyramids/prisms, use the specialized commands.
3. Primary Vertices: Always declare the main vertices of the shape (e.g., A, B, C, D) using POINT(X).
4. POLYGON_ORDER: Always emit POLYGON_ORDER(...) for the main shape using ONLY these primary vertices.
5. All Points: EVERY point mentioned (A, B, C, H, M, etc.) MUST be declared with POINT(Name) first.
6. Altitudes/Perpendiculars: For an altitude AH to BC, use POINT(H) + PERPENDICULAR(AH, BC).
7. Format: Output ONLY DSL lines — NO explanation, NO markdown, NO code blocks.

=== SHAPE EXAMPLES ===

--- Case: Square Pyramid S.ABCD with side 10, height 15 ---
PYRAMID(S_ABCD)
POINT(A, 0, 0, 0)
POINT(B, 10, 0, 0)
POINT(C, 10, 10, 0)
POINT(D, 0, 10, 0)
POINT(S)
POINT(O)
SECTION(O, A, C, 0.5)
LENGTH(SO, 15)
PERPENDICULAR(SO, AC)
PERPENDICULAR(SO, AB)
POLYGON_ORDER(A, B, C, D)

--- Case: Right Triangle ABC at A, AB=3, AC=4, altitude AH ---
POLYGON_ORDER(A, B, C)
POINT(A)
POINT(B)
POINT(C)
POINT(H)
LENGTH(AB, 3)
LENGTH(AC, 4)
ANGLE(A, 90)
PERPENDICULAR(AH, BC)
SEGMENT(A, H)

--- Case: Rectangle ABCD with AB=5, AD=10 ---
POLYGON_ORDER(A, B, C, D)
POINT(A)
POINT(B)
POINT(C)
POINT(D)
LENGTH(AB, 5)
LENGTH(AD, 10)
PERPENDICULAR(AB, AD)
PARALLEL(AB, CD)
PARALLEL(AD, BC)

[Circle with center O radius 7]
POINT(O)
CIRCLE(O, 7)
"""

        user_content = f"Semantic Data: {json.dumps(semantic_data, ensure_ascii=False)}"
        if previous_dsl:
            # Multi-turn: hand the LLM the prior program so it extends it.
            user_content = f"PREVIOUS DSL:\n{previous_dsl}\n\nUPDATE WITH NEW DATA: {json.dumps(semantic_data, ensure_ascii=False)}"

        logger.debug("[GeometryAgent] Calling LLM (Multi-Layer)...")
        content = await self.llm.chat_completions_create(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_content}
            ]
        )
        # Guard against a falsy/None LLM response.
        dsl = content.strip() if content else ""
        logger.info(f"[GeometryAgent] DSL generated ({len(dsl.splitlines())} lines).")
        logger.debug(f"[GeometryAgent] DSL output:\n{dsl}")
        return dsl
agents/knowledge_agent.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Dict, Any
3
+
4
+ logger = logging.getLogger(__name__)
5
+
6
+
7
+ # ─── Shape rule registry ────────────────────────────────────────────────────
8
+ # Each entry: keyword list → augmentation function
9
+ # Augmentation receives (values: dict, text: str) and returns updated values dict.
10
+
11
+ class KnowledgeAgent:
12
+ """Knowledge Agent: Stores geometric theorems and common patterns to augment Parser output."""
13
+
14
+ def augment_semantic_data(self, semantic_data: Dict[str, Any]) -> Dict[str, Any]:
15
+ logger.info("==[KnowledgeAgent] Augmenting semantic data==")
16
+ text = str(semantic_data.get("input_text", "")).lower()
17
+ logger.debug(f"[KnowledgeAgent] Input text for matching: '{text[:200]}'")
18
+
19
+ shape_type = self._detect_shape(text, semantic_data.get("type", ""))
20
+ if shape_type:
21
+ semantic_data["type"] = shape_type
22
+ values = semantic_data.get("values", {})
23
+ values = self._augment_values(shape_type, values, text)
24
+ semantic_data["values"] = values
25
+ else:
26
+ logger.info("[KnowledgeAgent] No special rule matched. Returning data unchanged.")
27
+
28
+ logger.debug(f"[KnowledgeAgent] Output semantic data: {semantic_data}")
29
+ return semantic_data
30
+
31
+ # ─── Shape detection ────────────────────────────────────────────────────
32
+ def _detect_shape(self, text: str, llm_type: str) -> str | None:
33
+ """Detect shape from text keywords. LLM type provides a hint."""
34
+ checks = [
35
+ (["hình vuông", "square"], "square"),
36
+ (["hình chữ nhật", "rectangle"], "rectangle"),
37
+ (["hình thoi", "rhombus"], "rhombus"),
38
+ (["hình bình hành", "parallelogram"], "parallelogram"),
39
+ (["hình thang vuông"], "right_trapezoid"),
40
+ (["hình thang", "trapezoid", "trapezium"], "trapezoid"),
41
+ (["tam giác vuông", "right triangle"], "right_triangle"),
42
+ (["tam giác đều", "equilateral triangle", "equilateral"], "equilateral_triangle"),
43
+ (["tam giác cân", "isosceles"], "isosceles_triangle"),
44
+ (["tam giác", "triangle"], "triangle"),
45
+ (["đường tròn", "circle"], "circle"),
46
+ ]
47
+ for keywords, shape in checks:
48
+ if any(kw in text for kw in keywords):
49
+ logger.info(f"[KnowledgeAgent] Rule MATCH: '{shape}' detected (keyword match).")
50
+ return shape
51
+
52
+ # Fallback: trust LLM-detected type if it's a known type
53
+ known = {
54
+ "rectangle", "square", "rhombus", "parallelogram",
55
+ "trapezoid", "right_trapezoid", "triangle", "right_triangle",
56
+ "equilateral_triangle", "isosceles_triangle", "circle",
57
+ }
58
+ if llm_type in known:
59
+ logger.info(f"[KnowledgeAgent] Using LLM-detected type '{llm_type}'.")
60
+ return llm_type
61
+
62
+ return None
63
+
64
+ # ─── Value augmentation ──────────────────────────────────────────────────
65
+ def _augment_values(self, shape: str, values: dict, text: str) -> dict:
66
+ ab = values.get("AB")
67
+ ad = values.get("AD")
68
+ bc = values.get("BC")
69
+ cd = values.get("CD")
70
+
71
+ if shape == "rectangle":
72
+ if ab and ad:
73
+ values.setdefault("CD", ab)
74
+ values.setdefault("BC", ad)
75
+ values.setdefault("angle_A", 90)
76
+ logger.info(f"[KnowledgeAgent] Rectangle: AB=CD={ab}, AD=BC={ad}, angle_A=90°")
77
+ else:
78
+ values.setdefault("angle_A", 90)
79
+
80
+ elif shape == "square":
81
+ side = ab or ad or bc or cd or values.get("side")
82
+ if side:
83
+ values.update({"AB": side, "AD": side, "angle_A": 90})
84
+ logger.info(f"[KnowledgeAgent] Square: side={side}, angle_A=90°")
85
+ else:
86
+ values.setdefault("angle_A", 90)
87
+
88
+ elif shape == "rhombus":
89
+ side = ab or values.get("side")
90
+ if side:
91
+ values.update({"AB": side, "BC": side, "CD": side, "DA": side})
92
+ logger.info(f"[KnowledgeAgent] Rhombus: all sides={side}")
93
+
94
+ elif shape == "parallelogram":
95
+ if ab:
96
+ values.setdefault("CD", ab)
97
+ if ad:
98
+ values.setdefault("BC", ad)
99
+ logger.info(f"[KnowledgeAgent] Parallelogram: AB||CD, AD||BC")
100
+
101
+ elif shape == "trapezoid":
102
+ logger.info("[KnowledgeAgent] Trapezoid: AB||CD (bottom||top)")
103
+
104
+ elif shape == "right_trapezoid":
105
+ logger.info("[KnowledgeAgent] Right trapezoid: AB||CD, AD⊥AB")
106
+ values.setdefault("angle_A", 90)
107
+
108
+ elif shape == "equilateral_triangle":
109
+ side = ab or values.get("side")
110
+ if side:
111
+ values.update({"AB": side, "BC": side, "CA": side, "angle_A": 60})
112
+ logger.info(f"[KnowledgeAgent] Equilateral triangle: all sides={side}, angle_A=60°")
113
+
114
+ elif shape == "right_triangle":
115
+ # Try to infer which vertex is the right angle
116
+ rt_vertex = _detect_right_angle_vertex(text)
117
+ values.setdefault(f"angle_{rt_vertex}", 90)
118
+ logger.info(f"[KnowledgeAgent] Right triangle: angle_{rt_vertex}=90°")
119
+
120
+ elif shape == "isosceles_triangle":
121
+ logger.info("[KnowledgeAgent] Isosceles triangle: AB=AC (default, LLM may override)")
122
+
123
+ elif shape == "circle":
124
+ logger.info("[KnowledgeAgent] Circle detected — no side augmentation needed.")
125
+
126
+ return values
127
+
128
+
129
+ def _detect_right_angle_vertex(text: str) -> str:
130
+ """Heuristic: detect which vertex is right angle from text."""
131
+ for vertex in ["A", "B", "C", "D"]:
132
+ patterns = [f"vuông tại {vertex}", f"góc {vertex} vuông", f"right angle at {vertex}"]
133
+ if any(p.lower() in text for p in patterns):
134
+ return vertex
135
+ return "A" # default
agents/ocr_agent.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import uuid
3
+ import logging
4
+ import asyncio
5
+ from typing import List, Dict, Any, Optional, Tuple
6
+
7
+ import cv2
8
+ import numpy as np
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ _OCR_MAX_EDGE = 2000
13
+ _CROP_PAD = 4
14
+
15
+
16
class ImprovedOCRAgent:
    """
    Advanced OCR Agent using a hybrid pipeline:
    1. YOLO for layout analysis (text vs formula).
    2. PaddleOCR for Vietnamese text extraction.
    3. Pix2Tex for LaTeX formula extraction.
    4. Optional MegaLLM for semantic correction and formatting (skipped when ``skip_llm_refinement`` is True,
       e.g. on the dedicated OCR Celery worker; the API Space runs ``refine_with_llm`` on the raw text).

    Each engine is optional: a failed import/initialization leaves the
    corresponding attribute ``None`` and the pipeline degrades gracefully.
    """

    def __init__(self, skip_llm_refinement: bool = False):
        # When True, no LLM client is created at all (raw-OCR worker mode).
        self._skip_llm_refinement = bool(skip_llm_refinement)
        logger.info("[ImprovedOCRAgent] Initializing engines (skip_llm_refinement=%s)...", self._skip_llm_refinement)

        if self._skip_llm_refinement:
            self.llm = None
            logger.info("[ImprovedOCRAgent] LLM client skipped (raw OCR only).")
        else:
            # Imported lazily so the OCR-only worker never needs LLM config.
            from app.llm_client import get_llm_client

            self.llm = get_llm_client()
            logger.info("[ImprovedOCRAgent] Multi-Layer LLM Client initialized.")

        try:
            from agents.torch_ultralytics_compat import allow_ultralytics_weights
            from ultralytics import YOLO

            allow_ultralytics_weights()
            logger.info("[ImprovedOCRAgent] Loading YOLO...")
            self.layout_model = YOLO("yolov8n.pt")
            logger.info("[ImprovedOCRAgent] YOLO initialized.")
        except Exception as e:
            logger.error("[ImprovedOCRAgent] YOLO init failed: %s", e)
            self.layout_model = None

        try:
            from paddleocr import PaddleOCR

            logger.info("[ImprovedOCRAgent] Loading PaddleOCR...")
            self.text_model = PaddleOCR(use_angle_cls=True, lang="vi")
            logger.info("[ImprovedOCRAgent] PaddleOCR (vi) initialized.")
        except Exception as e:
            logger.error("[ImprovedOCRAgent] PaddleOCR init failed: %s", e)
            self.text_model = None

        try:
            from pix2tex.cli import LatexOCR

            logger.info("[ImprovedOCRAgent] Loading Pix2Tex...")
            self.math_model = LatexOCR()
            logger.info("[ImprovedOCRAgent] Pix2Tex initialized.")
        except Exception as e:
            logger.error("[ImprovedOCRAgent] Pix2Tex init failed: %s", e)
            self.math_model = None

    def _preprocess_image_for_ocr(self, src_path: str) -> Tuple[str, bool]:
        """Resize oversized images and enhance contrast for OCR.

        Returns:
            (path, created): ``path`` is either ``src_path`` (input unreadable)
            or a new uuid-named temp PNG; ``created`` tells the caller whether
            ``path`` is a temp file it must delete.
        """
        img = cv2.imread(src_path, cv2.IMREAD_COLOR)
        if img is None:
            # Some inputs only load as grayscale; try that before giving up.
            g = cv2.imread(src_path, cv2.IMREAD_GRAYSCALE)
            if g is None:
                logger.warning("[ImprovedOCRAgent] OpenCV could not read %s; using original.", src_path)
                return src_path, False
            img = cv2.cvtColor(g, cv2.COLOR_GRAY2BGR)
        h, w = img.shape[:2]
        max_dim = max(h, w)
        if max_dim > _OCR_MAX_EDGE:
            scale = _OCR_MAX_EDGE / max_dim
            img = cv2.resize(
                img, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_AREA
            )
            logger.info("[ImprovedOCRAgent] Resized for OCR to max edge %s", _OCR_MAX_EDGE)
        # CLAHE boosts local contrast; non-local-means removes scan speckle.
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        gray = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(gray)
        den = cv2.fastNlMeansDenoising(gray, None, 8, 7, 21)
        out = f"temp_ocr_prep_{uuid.uuid4().hex}.png"
        cv2.imwrite(out, den)
        return out, True

    def _load_bgr_for_crops(self, path: str) -> Optional[np.ndarray]:
        """Load *path* as a BGR array for formula cropping; None if unreadable."""
        im = cv2.imread(path, cv2.IMREAD_COLOR)
        if im is None:
            g = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
            if g is None:
                return None
            im = cv2.cvtColor(g, cv2.COLOR_GRAY2BGR)
        return im

    def _crop_from_quad(self, img_bgr: np.ndarray, bbox) -> Optional[np.ndarray]:
        """Crop the axis-aligned, padded bounding box of a detection quad.

        Returns a copy of the crop, or None when the quad is degenerate or
        cannot be interpreted.
        """
        try:
            pts = np.array(bbox, dtype=np.float32)
            xs = pts[:, 0]
            ys = pts[:, 1]
            H, W = img_bgr.shape[:2]
            # Pad a few pixels, clamped to the image bounds.
            x1 = max(0, int(xs.min()) - _CROP_PAD)
            y1 = max(0, int(ys.min()) - _CROP_PAD)
            x2 = min(W, int(xs.max()) + _CROP_PAD)
            y2 = min(H, int(ys.max()) + _CROP_PAD)
            if x2 <= x1 or y2 <= y1:
                return None
            return img_bgr[y1:y2, x1:x2].copy()
        except Exception as e:
            logger.debug("[ImprovedOCRAgent] crop failed: %s", e)
            return None

    def _latex_from_crop_bgr(self, crop_bgr: np.ndarray) -> Optional[str]:
        """Run Pix2Tex on a crop; None when unavailable, too small, or failed."""
        if self.math_model is None or crop_bgr is None or crop_bgr.size == 0:
            return None
        ch, cw = crop_bgr.shape[:2]
        # Tiny crops carry no legible formula; skip them.
        if ch < 10 or cw < 10:
            return None
        try:
            from PIL import Image

            rgb = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB)
            pil = Image.fromarray(rgb)
            out = self.math_model(pil)
            if isinstance(out, str) and out.strip():
                return out.strip()
        except Exception as e:
            logger.debug("[ImprovedOCRAgent] Pix2Tex on crop failed: %s", e)
        return None

    def _maybe_math_from_crop(self, img_bgr: Optional[np.ndarray], bbox, text: str) -> str:
        """Replace a recognized line with ``$LaTeX$`` when it looks mathematical."""
        if img_bgr is None or not self.math_model:
            return text
        # Cheap heuristic: only send lines containing math-ish characters.
        is_math_hint = any(
            c in text for c in ["\\", "^", "_", "{", "}", "=", "+", "-", "*", "/"]
        )
        if not is_math_hint:
            return text
        crop = self._crop_from_quad(img_bgr, bbox)
        latex = self._latex_from_crop_bgr(crop) if crop is not None else None
        if latex:
            logger.info("[ImprovedOCRAgent] Pix2Tex replaced line fragment (len=%s)", len(latex))
            return f"${latex}$"
        return text

    def _collect_fragment(
        self,
        fragments: List[Dict[str, Any]],
        img_bgr: Optional[np.ndarray],
        bbox,
        text: str,
        score,
        y_top: int,
    ) -> None:
        """Shared per-line handling for both PaddleOCR result formats.

        Logs low-confidence lines, optionally swaps the text for LaTeX, and
        appends a {y, content, type} fragment for later top-to-bottom sorting.
        """
        if score is not None and float(score) < 0.45:
            logger.debug(
                "[ImprovedOCRAgent] Low-confidence line (score=%s): %s",
                score,
                text[:80],
            )
        content = self._maybe_math_from_crop(img_bgr, bbox, text)
        fragments.append({"y": y_top, "content": content, "type": "text"})

    async def process_image(self, image_path: str) -> str:
        """OCR *image_path* and (unless disabled) refine the text with the LLM.

        Returns the refined (or raw) text, "" when nothing was detected, or an
        ``Error: ...`` string when the file does not exist.
        """
        logger.info("==[ImprovedOCRAgent] Processing: %s==", image_path)

        if not os.path.exists(image_path):
            return f"Error: File {image_path} not found."

        prep_path, prep_cleanup = self._preprocess_image_for_ocr(image_path)
        paddle_path = prep_path if prep_cleanup else image_path
        img_bgr = self._load_bgr_for_crops(paddle_path)

        raw_fragments: List[Dict[str, Any]] = []

        try:
            if self.text_model:
                logger.info("[ImprovedOCRAgent] Running PaddleOCR on %s...", paddle_path)
                result = self.text_model.ocr(paddle_path)
                logger.info("[ImprovedOCRAgent] PaddleOCR raw result: %s", result)

                if not result:
                    logger.warning("[ImprovedOCRAgent] PaddleOCR returned no results.")
                    return ""

                if isinstance(result[0], dict):
                    # Newer PaddleOCR versions return one dict per page.
                    res_dict = result[0]
                    rec_texts = res_dict.get("rec_texts", [])
                    rec_scores = res_dict.get("rec_scores", [])
                    rec_polys = res_dict.get("rec_polys", [])

                    for i in range(len(rec_texts)):
                        text = rec_texts[i]
                        bbox = rec_polys[i]
                        score = rec_scores[i] if i < len(rec_scores) else None
                        y_top = int(min(p[1] for p in bbox)) if hasattr(bbox, "__iter__") else 0
                        self._collect_fragment(raw_fragments, img_bgr, bbox, text, score, y_top)
                elif isinstance(result[0], list):
                    # Legacy list-of-[bbox, (text, score)] format.
                    for line in result[0]:
                        bbox = line[0]
                        text = line[1][0]
                        score = line[1][1] if len(line[1]) > 1 else None
                        y_top = int(bbox[0][1])
                        self._collect_fragment(raw_fragments, img_bgr, bbox, text, score, y_top)
        finally:
            # Always remove the temp preprocessing file, even on early return.
            if prep_cleanup and os.path.exists(prep_path):
                try:
                    os.remove(prep_path)
                except OSError:
                    pass

        # Reassemble fragments in top-to-bottom reading order.
        raw_fragments.sort(key=lambda x: x["y"])
        combined_text = "\n".join([f["content"] for f in raw_fragments])

        logger.info(
            "[ImprovedOCRAgent] Raw OCR output assembled:\n---\n%s\n---", combined_text
        )

        if not combined_text.strip():
            logger.warning("[ImprovedOCRAgent] No text detected to refine.")
            return ""

        if self._skip_llm_refinement or self.llm is None:
            logger.info("[ImprovedOCRAgent] Skipping MegaLLM refinement (raw OCR output).")
            return combined_text

        try:
            logger.info("[ImprovedOCRAgent] Sending to MegaLLM for refinement...")
            refined_text = await asyncio.wait_for(
                self.refine_with_llm(combined_text), timeout=30.0
            )
            return refined_text
        except asyncio.TimeoutError:
            # Refinement is best-effort: fall back to the raw OCR text.
            logger.error("[ImprovedOCRAgent] MegaLLM refinement timed out.")
            return combined_text
        except Exception as e:
            logger.error("[ImprovedOCRAgent] MegaLLM refinement failed: %s", e)
            return combined_text

    async def refine_with_llm(self, text: str) -> str:
        """Clean up raw OCR *text* with the LLM; returns *text* on any failure."""
        if not text.strip():
            return ""
        if self.llm is None:
            logger.warning("[ImprovedOCRAgent] refine_with_llm: no LLM client; returning raw text.")
            return text

        prompt = f"""Bạn là một chuyên gia số hóa tài liệu toán học.
Dưới đây là kết quả OCR thô từ một trang sách toán Tiếng Việt.
Kết quả này có thể chứa lỗi chính tả, lỗi định dạng mã LaTeX, hoặc bị ngắt quãng không logic.

Nhiệm vụ của bạn:
1. Sửa lỗi chính tả tiếng Việt.
2. Đảm bảo các công thức toán học được viết đúng định dạng LaTeX và nằm trong cặp dấu $...$.
3. Giữ nguyên cấu trúc logic của bài toán.
4. Trả về nội dung đã được làm sạch dưới dạng Markdown.

Nội dung OCR thô:
---
{text}
---

Kết quả làm sạch:"""

        try:
            refined = await self.llm.chat_completions_create(
                messages=[{"role": "user", "content": prompt}],
                temperature=0.1,
            )
            logger.info("[ImprovedOCRAgent] LLM refinement complete.")
            return refined
        except Exception as e:
            logger.error("[ImprovedOCRAgent] LLM refinement failed: %s", e)
            return text

    async def process_url(self, url: str) -> str:
        """Download an image URL to a unique temp file, OCR it, then delete it.

        Fix: the temp filename is uuid-based. The previous fixed name
        ("temp_url_image.png") let concurrent requests overwrite and delete
        each other's downloads.
        """
        import httpx

        from app.url_utils import sanitize_url

        url = sanitize_url(url)
        if not url:
            return "Error: Empty image URL after cleanup."

        async with httpx.AsyncClient() as client:
            resp = await client.get(url)
            if resp.status_code == 200:
                temp_path = f"temp_url_image_{uuid.uuid4().hex}.png"
                with open(temp_path, "wb") as f:
                    f.write(resp.content)
                try:
                    return await self.process_image(temp_path)
                finally:
                    if os.path.exists(temp_path):
                        os.remove(temp_path)
        return f"Error: Failed to fetch image from URL {url}"
302
+
303
+
304
class OCRAgent(ImprovedOCRAgent):
    """Backward-compatible alias so existing code importing ``OCRAgent`` keeps working."""
agents/orchestrator.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ from typing import Any, Dict
4
+
5
+ from agents.geometry_agent import GeometryAgent
6
+ from agents.knowledge_agent import KnowledgeAgent
7
+ from agents.ocr_agent import OCRAgent
8
+ from agents.parser_agent import ParserAgent
9
+ from agents.renderer_agent import RendererAgent
10
+ from agents.solver_agent import SolverAgent
11
+ from app.logutil import log_step
12
+ from app.ocr_celery import ocr_from_image_url
13
+ from solver.dsl_parser import DSLParser
14
+ from solver.engine import GeometryEngine
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ _CLIP = 2000
19
+
20
+
21
+ def _clip(val: Any, n: int = _CLIP) -> str | None:
22
+ if val is None:
23
+ return None
24
+ if isinstance(val, str):
25
+ s = val
26
+ else:
27
+ s = json.dumps(val, ensure_ascii=False, default=str)
28
+ return s if len(s) <= n else s[:n] + "…"
29
+
30
+
31
def _step_io(step: str, input_val: Any = None, output_val: Any = None) -> None:
    """Debug helper: log only the (clipped) input/output of a pipeline step, avoiding noisy full dumps."""
    log_step(step, input=_clip(input_val), output=_clip(output_val))
34
+
35
+
36
+ class Orchestrator:
37
    def __init__(self):
        """Wire up every pipeline stage once; the orchestrator owns their lifecycles."""
        self.parser_agent = ParserAgent()        # text -> semantic JSON
        self.geometry_agent = GeometryAgent()    # semantic JSON -> Geometry DSL
        self.ocr_agent = OCRAgent()              # image -> text
        self.knowledge_agent = KnowledgeAgent()  # semantic JSON enrichment
        self.renderer_agent = RendererAgent()    # figure/video rendering
        self.solver_agent = SolverAgent()        # solution narrative
        self.solver_engine = GeometryEngine()    # coordinate solving
        self.dsl_parser = DSLParser()            # DSL text -> engine commands
46
+
47
+ def _generate_step_description(self, semantic_json: Dict[str, Any], engine_result: Dict[str, Any]) -> str:
48
+ """Tạo mô tả từng bước vẽ dựa trên kết quả của engine."""
49
+ analysis = semantic_json.get("analysis", "")
50
+ if not analysis:
51
+ analysis = f"Giải bài toán về {semantic_json.get('type', 'hình học')}."
52
+
53
+ steps = ["\n\n**Các bước dựng hình:**"]
54
+ drawing_phases = engine_result.get("drawing_phases", [])
55
+
56
+ for phase in drawing_phases:
57
+ label = phase.get("label", f"Giai đoạn {phase['phase']}")
58
+ points = ", ".join(phase.get("points", []))
59
+ segments = ", ".join([f"{s[0]}{s[1]}" for s in phase.get("segments", [])])
60
+
61
+ step_text = f"- **{label}**:"
62
+ if points:
63
+ step_text += f" Xác định các điểm {points}."
64
+ if segments:
65
+ step_text += f" Vẽ các đoạn thẳng {segments}."
66
+ steps.append(step_text)
67
+
68
+ circles = engine_result.get("circles", [])
69
+ for c in circles:
70
+ steps.append(f"- **Đường tròn**: Vẽ đường tròn tâm {c['center']} bán kính {c['radius']}.")
71
+
72
+ return analysis + "\n".join(steps)
73
+
74
+ async def run(
75
+ self,
76
+ text: str,
77
+ image_url: str = None,
78
+ job_id: str = None,
79
+ session_id: str = None,
80
+ status_callback=None,
81
+ history: list = None,
82
+ ) -> Dict[str, Any]:
83
+ """
84
+ Run the full pipeline. Optional history allows context-aware solving.
85
+ """
86
+ _step_io(
87
+ "orchestrate_start",
88
+ input_val={
89
+ "job_id": job_id,
90
+ "text_len": len(text or ""),
91
+ "image_url": image_url,
92
+ "history_len": len(history or []),
93
+ },
94
+ output_val=None,
95
+ )
96
+
97
+ if status_callback:
98
+ await status_callback("processing")
99
+
100
+ # 1. Extract context from history (if any)
101
+ previous_context = None
102
+ if history:
103
+ # Look for the last assistant message with geometry data
104
+ for msg in reversed(history):
105
+ if msg.get("role") == "assistant" and msg.get("metadata", {}).get("geometry_dsl"):
106
+ previous_context = {
107
+ "geometry_dsl": msg["metadata"]["geometry_dsl"],
108
+ "coordinates": msg["metadata"].get("coordinates", {}),
109
+ "analysis": msg.get("content", ""),
110
+ }
111
+ break
112
+
113
+ if previous_context:
114
+ _step_io("context_found", input_val=None, output_val={"dsl_len": len(previous_context["geometry_dsl"])})
115
+
116
+ # 2. Gather input text (OCR or direct)
117
+ input_text = text
118
+ if image_url:
119
+ input_text = await ocr_from_image_url(image_url, self.ocr_agent)
120
+ _step_io("step1_ocr", input_val=image_url, output_val=input_text)
121
+ else:
122
+ _step_io("step1_ocr", input_val="(no image)", output_val=text)
123
+
124
+ feedback = None
125
+ MAX_RETRIES = 2
126
+
127
+ for attempt in range(MAX_RETRIES + 1):
128
+ _step_io(
129
+ "attempt",
130
+ input_val=f"{attempt + 1}/{MAX_RETRIES + 1}",
131
+ output_val=None,
132
+ )
133
+ if status_callback:
134
+ await status_callback("solving")
135
+
136
+ # Parser with context
137
+ _step_io("step2_parse", input_val=f"{input_text[:50]}...", output_val=None)
138
+ semantic_json = await self.parser_agent.process(input_text, feedback=feedback, context=previous_context)
139
+ semantic_json["input_text"] = input_text
140
+ _step_io("step2_parse", input_val=None, output_val=semantic_json)
141
+
142
+ # Knowledge augmentation
143
+ _step_io("step3_knowledge", input_val=semantic_json, output_val=None)
144
+ semantic_json = self.knowledge_agent.augment_semantic_data(semantic_json)
145
+ _step_io("step3_knowledge", input_val=None, output_val=semantic_json)
146
+
147
+ # Geometry DSL with context (passing previous DSL to guide generation)
148
+ _step_io("step4_geometry_dsl", input_val=semantic_json, output_val=None)
149
+ dsl_code = await self.geometry_agent.generate_dsl(
150
+ semantic_json,
151
+ previous_dsl=previous_context["geometry_dsl"] if previous_context else None
152
+ )
153
+ _step_io("step4_geometry_dsl", input_val=None, output_val=dsl_code)
154
+
155
+ _step_io("step5_dsl_parse", input_val=dsl_code, output_val=None)
156
+ points, constraints, is_3d = self.dsl_parser.parse(dsl_code)
157
+ _step_io(
158
+ "step5_dsl_parse",
159
+ input_val=None,
160
+ output_val={
161
+ "points": len(points),
162
+ "constraints": len(constraints),
163
+ "is_3d": is_3d,
164
+ },
165
+ )
166
+
167
+ _step_io("step6_solve", input_val=f"{len(points)} pts / {len(constraints)} cons (is_3d={is_3d})", output_val=None)
168
+ import anyio
169
+ engine_result = await anyio.to_thread.run_sync(self.solver_engine.solve, points, constraints, is_3d)
170
+
171
+ if engine_result:
172
+ coordinates = engine_result.get("coordinates")
173
+ _step_io("step6_solve", input_val=None, output_val=coordinates)
174
+ break
175
+
176
+ feedback = "Geometry solver failed to find a valid solution for the given constraints. Parallelism or lengths might be inconsistent."
177
+ _step_io(
178
+ "step6_solve",
179
+ input_val=f"attempt {attempt + 1}",
180
+ output_val=feedback,
181
+ )
182
+ if attempt == MAX_RETRIES:
183
+ _step_io(
184
+ "orchestrate_abort",
185
+ input_val=None,
186
+ output_val="solver_exhausted_retries",
187
+ )
188
+ return {
189
+ "error": "Solver failed after multiple attempts.",
190
+ "last_dsl": dsl_code,
191
+ }
192
+
193
+ _step_io("orchestrate_done", input_val=job_id, output_val="success")
194
+
195
+ # 8. Solution calculation (New in v5.1)
196
+ solution = None
197
+ if engine_result:
198
+ _step_io("step8_solve_math", input_val=semantic_json.get("target_question"), output_val=None)
199
+ solution = await self.solver_agent.solve(semantic_json, engine_result)
200
+ _step_io("step8_solve_math", input_val=None, output_val=solution.get("answer"))
201
+
202
+ final_analysis = self._generate_step_description(semantic_json, engine_result)
203
+
204
+ status = "success"
205
+ return {
206
+ "status": status,
207
+ "job_id": job_id,
208
+ "geometry_dsl": dsl_code,
209
+ "coordinates": coordinates,
210
+ "polygon_order": engine_result.get("polygon_order", []),
211
+ "circles": engine_result.get("circles", []),
212
+ "lines": engine_result.get("lines", []),
213
+ "rays": engine_result.get("rays", []),
214
+ "drawing_phases": engine_result.get("drawing_phases", []),
215
+ "semantic": semantic_json,
216
+ "semantic_analysis": final_analysis,
217
+ "solution": solution,
218
+ "is_3d": is_3d,
219
+ }
agents/parser_agent.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import logging
4
+ from openai import AsyncOpenAI
5
+ from typing import Dict, Any
6
+ from dotenv import load_dotenv
7
+
8
+ load_dotenv()
9
+ logger = logging.getLogger(__name__)
10
+
11
+ from app.url_utils import openai_compatible_api_key, sanitize_env
12
+
13
+
14
+ from app.llm_client import get_llm_client
15
+
16
+
17
class ParserAgent:
    """LLM-backed agent that extracts geometric entities, values, and the
    target question from Vietnamese/LaTeX problem text.

    Fixes vs. previous revision: the bare ``except:`` in the JSON fallback is
    narrowed to ``json.JSONDecodeError``, the triple-duplicated local
    ``import re`` is collapsed to one, and the response-parsing logic is
    extracted into the testable static helper :meth:`_extract_json`.
    """

    def __init__(self):
        self.llm = get_llm_client()

    @staticmethod
    def _extract_json(raw: str, fallback_text: str) -> Dict[str, Any]:
        """Parse an LLM response into a semantic dict.

        Handles, in order: markdown code fences, plain JSON, a regex-located
        ``{...}`` block (with single->double quote repair, a common LLM
        failure), and finally a neutral fallback structure whose "analysis"
        is *fallback_text* so the pipeline can continue.

        Args:
            raw: Raw LLM output.
            fallback_text: Text used as "analysis" when nothing parses.

        Returns:
            A dict; on total failure the minimal semantic structure.
        """
        import re

        clean_raw = raw.strip()
        if clean_raw.startswith("```"):
            fence = re.search(r"```(?:json)?\s*(.*?)\s*```", clean_raw, re.DOTALL)
            if fence:
                clean_raw = fence.group(1).strip()

        result = None
        try:
            result = json.loads(clean_raw)
        except json.JSONDecodeError as e:
            logging.getLogger(__name__).error(
                f"[ParserAgent] JSON Parse Error: {e}. Attempting regex fallback..."
            )
            json_match = re.search(r'(\{.*\})', clean_raw, re.DOTALL)
            if json_match:
                json_str = json_match.group(1)
                # Single quotes instead of double quotes: repairable only when
                # no double quotes are present at all.
                if "'" in json_str and '"' not in json_str:
                    json_str = json_str.replace("'", '"')
                try:
                    result = json.loads(json_str)
                except json.JSONDecodeError:
                    result = None

        if not result:
            # Critical failure: return a neutral structure instead of raising.
            result = {
                "entities": [],
                "type": "general",
                "values": {},
                "target_question": None,
                "analysis": fallback_text,
            }
        return result

    async def process(self, text: str, feedback: str = None, context: Dict[str, Any] = None) -> Dict[str, Any]:
        """Run one parse pass over *text*.

        Args:
            text: Problem statement (possibly OCR output).
            feedback: Error hint from a failed previous solve attempt.
            context: Previous-turn context ({"geometry_dsl", "coordinates",
                "analysis"}) for follow-up requests.

        Returns:
            Semantic dict with keys: entities, type, values, target_question,
            analysis (guaranteed by the fallback in _extract_json).
        """
        logger.info(f"==[ParserAgent] Processing input (len={len(text)})==")
        if feedback:
            logger.warning(f"[ParserAgent] Feedback from previous attempt: {feedback}")
        if context:
            logger.info(f"[ParserAgent] Using previous context (dsl_len={len(context.get('geometry_dsl', ''))})")

        system_prompt = """
You are a Geometry Parser Agent. Extract geometric entities and constraints from Vietnamese/LaTeX math problem text.

=== CONTEXT AWARENESS ===
If previous context is provided, it means this is a follow-up request.
- Combine old entities with new ones.
- Update 'analysis' to reflect the entire problem state.

Output ONLY a JSON object with this EXACT structure (no extra keys, no markdown):
{
  "entities": ["Point A", "Point B", ...],
  "type": "pyramid|prism|sphere|rectangle|triangle|circle|parallelogram|trapezoid|square|rhombus|general",
  "values": {"AB": 5, "SO": 15, "radius": 3},
  "target_question": "Câu hỏi cụ thể cần giải (ví dụ: 'Tính diện tích tam giác ABC'). NẾU KHÔNG CÓ CÂU HỎI THÌ ĐỂ null.",
  "analysis": "Tóm tắt ngắn gọn toàn bộ bài toán sau khi đã cập nhật các yêu cầu mới bằng tiếng Việt."
}
Rules:
- "analysis" MUST be a meaningful and UP-TO-DATE summary of the problem in Vietnamese.
- "target_question" must be concise.
- Include midpoints, auxiliary points in "entities" if mentioned.
- If feedback is provided, correct your previous output accordingly.
"""

        user_content = f"Text: {text}"
        if context:
            # Follow-up turn: give the model the prior analysis plus the delta.
            user_content = f"PREVIOUS ANALYSIS: {context.get('analysis')}\nNEW REQUEST: {text}"

        if feedback:
            user_content += f"\nFeedback from previous attempt: {feedback}. Please correct the constraints."

        logger.debug("[ParserAgent] Calling LLM (Multi-Layer)...")
        raw = await self.llm.chat_completions_create(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_content}
            ],
            response_format={"type": "json_object"}
        )

        result = self._extract_json(raw, fallback_text=text)
        logger.info("[ParserAgent] LLM response received.")
        logger.debug(f"[ParserAgent] Parsed JSON: {json.dumps(result, ensure_ascii=False, indent=2)}")
        return result
agents/renderer_agent.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ import glob
4
+ import string
5
+ import logging
6
+ from typing import Dict, Any, List
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
class RendererAgent:
    """
    Renderer Agent — generates Manim scripts from geometry data.

    Drawing happens in phases:
    Phase 1: Main polygon (base shape with correct vertex order)
    Phase 2: Auxiliary points and segments (midpoints, derived segments)
    Phase 3: Labels for all points
    """

    def generate_manim_script(self, data: Dict[str, Any]) -> str:
        """Build the full source of a Manim scene ("GeometryScene") as a string.

        *data* is the Orchestrator result payload: coordinates (point id ->
        [x, y] or [x, y, z]), polygon_order, circles, lines, rays,
        drawing_phases, and the semantic dict (its "type" switches 2D/3D
        rendering). Returns Python source; nothing is executed here.
        """
        coords: Dict[str, List[float]] = data.get("coordinates", {})
        polygon_order: List[str] = data.get("polygon_order", [])
        circles_meta: List[Dict] = data.get("circles", [])
        lines_meta: List[List[str]] = data.get("lines", [])
        rays_meta: List[List[str]] = data.get("rays", [])
        drawing_phases: List[Dict] = data.get("drawing_phases", [])
        semantic: Dict[str, Any] = data.get("semantic", {})
        shape_type = semantic.get("type", "").lower()

        # ── Detect 3D Context ────────────────────────────────────────────────
        # 3D if any point has a non-negligible z component, or the semantic
        # type is an inherently 3D solid.
        is_3d = False
        for pos in coords.values():
            if len(pos) >= 3 and abs(pos[2]) > 0.001:
                is_3d = True
                break
        if shape_type in ["pyramid", "prism", "sphere"]:
            is_3d = True

        # ── Fallback: infer polygon_order from coords keys (alphabetical uppercase) ──
        if not polygon_order:
            base = sorted(
                [pid for pid in coords if pid in string.ascii_uppercase],
                key=lambda p: string.ascii_uppercase.index(p)
            )
            polygon_order = base

        # Separate base points from derived (multi-char or lowercase)
        base_ids = [pid for pid in polygon_order if pid in coords]
        derived_ids = [pid for pid in coords if pid not in polygon_order]

        scene_base = "ThreeDScene" if is_3d else "MovingCameraScene"
        lines = [
            "from manim import *",
            "",
            f"class GeometryScene({scene_base}):",
            "    def construct(self):",
        ]

        if is_3d:
            lines.append("        # 3D Setup")
            lines.append("        self.set_camera_orientation(phi=75*DEGREES, theta=-45*DEGREES)")
            lines.append("        axes = ThreeDAxes(axis_config={'stroke_width': 1})")
            lines.append("        axes.set_opacity(0.3)")
            lines.append("        self.add(axes)")
            lines.append("        self.begin_ambient_camera_rotation(rate=0.1)")
            lines.append("")

        # ── Declare all dots and labels ───────────────────────────────────────
        for pid, pos in coords.items():
            # Missing components default to 0 so 2D points render at z=0.
            x, y, z = 0, 0, 0
            if len(pos) >= 1: x = round(pos[0], 4)
            if len(pos) >= 2: y = round(pos[1], 4)
            if len(pos) >= 3: z = round(pos[2], 4)

            dot_class = "Dot3D" if is_3d else "Dot"
            lines.append(f"        p_{pid} = {dot_class}(point=[{x}, {y}, {z}], color=WHITE, radius=0.08)")

            if is_3d:
                lines.append(
                    f"        l_{pid} = Text('{pid}', font_size=20, color=WHITE)"
                    f".move_to(p_{pid}.get_center() + [0.2, 0.2, 0.2])"
                )
                # Ensure labels follow camera in 3D (fixed orientation)
                lines.append(f"        self.add_fixed_orientation_mobjects(l_{pid})")
            else:
                lines.append(
                    f"        l_{pid} = Text('{pid}', font_size=22, color=WHITE)"
                    f".next_to(p_{pid}, UR, buff=0.15)"
                )

        # ── 3D Shape Special: Pyramid/Prism Faces ────────────────────────────
        if is_3d and shape_type == "pyramid" and len(base_ids) >= 3:
            # Find apex (usually 'S')
            apex_id = "S" if "S" in coords else derived_ids[0] if derived_ids else None
            if apex_id:
                # Draw base face
                base_pts = ", ".join([f"p_{pid}.get_center()" for pid in base_ids])
                lines.append(f"        base_face = Polygon({base_pts}, color=BLUE, fill_opacity=0.1)")
                lines.append("        self.play(Create(base_face), run_time=1.0)")

                # Draw side faces
                for i in range(len(base_ids)):
                    p1 = base_ids[i]
                    p2 = base_ids[(i+1)%len(base_ids)]
                    face_pts = f"p_{apex_id}.get_center(), p_{p1}.get_center(), p_{p2}.get_center()"
                    lines.append(f"        side_{i} = Polygon({face_pts}, color=BLUE, stroke_width=1, fill_opacity=0.05)")
                    lines.append(f"        self.play(Create(side_{i}), run_time=0.5)")

        # ── Circles ──────────────────────────────────────────────────────────
        for i, c in enumerate(circles_meta):
            center = c["center"]
            r = c["radius"]
            if center in coords:
                cx, cy, cz = 0, 0, 0
                pos = coords[center]
                if len(pos) >= 1: cx = round(pos[0], 4)
                if len(pos) >= 2: cy = round(pos[1], 4)
                if len(pos) >= 3: cz = round(pos[2], 4)
                lines.append(
                    f"        circle_{i} = Circle(radius={r}, color=BLUE)"
                    f".move_to([{cx}, {cy}, {cz}])"
                )

        # ── Infinite Lines & Rays ────────────────────────────────────────────
        # (Standard Line works for 3D coordinates in Manim)
        # "Infinite" is approximated by scaling the segment by 20x.
        for i, (p1, p2) in enumerate(lines_meta):
            if p1 in coords and p2 in coords:
                lines.append(
                    f"        line_ext_{i} = Line(p_{p1}.get_center(), p_{p2}.get_center(), color=GRAY_D, stroke_width=2)"
                    f".scale(20)"
                )

        for i, (p1, p2) in enumerate(rays_meta):
            if p1 in coords and p2 in coords:
                # Ray approximated by extending 15x from p1 towards p2.
                lines.append(
                    f"        ray_{i} = Line(p_{p1}.get_center(), p_{p1}.get_center() + 15 * (p_{p2}.get_center() - p_{p1}.get_center()),"
                    f" color=GRAY_C, stroke_width=2)"
                )

        # ── Camera auto-fit group (Only for 2D) ──────────────────────────────
        if not is_3d:
            all_dot_names = [f"p_{pid}" for pid in coords]
            all_names_str = ", ".join(all_dot_names)
            lines.append(f"        _all = VGroup({all_names_str})")
            lines.append("        self.camera.frame.set_width(max(_all.width * 2.0, 8))")
            lines.append("        self.camera.frame.move_to(_all)")
            lines.append("")

        # ── Phase 1: Base polygon ─────────────────────────────────────────────
        if len(base_ids) >= 3:
            pts_str = ", ".join([f"p_{pid}.get_center()" for pid in base_ids])
            lines.append(f"        poly = Polygon({pts_str}, color=BLUE, fill_color=BLUE, fill_opacity=0.15)")
            lines.append("        self.play(Create(poly), run_time=1.5)")
        elif len(base_ids) == 2:
            # Degenerate "polygon": just a segment.
            p1, p2 = base_ids
            lines.append(f"        base_line = Line(p_{p1}.get_center(), p_{p2}.get_center(), color=BLUE)")
            lines.append("        self.play(Create(base_line), run_time=1.0)")

        # Draw base points
        if base_ids:
            base_dots_str = ", ".join([f"p_{pid}" for pid in base_ids])
            lines.append(f"        self.play(FadeIn(VGroup({base_dots_str})), run_time=0.5)")
            lines.append("        self.wait(0.5)")

        # ── Phase 2: Auxiliary points and segments ────────────────────────────
        if derived_ids:
            derived_dots_str = ", ".join([f"p_{pid}" for pid in derived_ids])
            lines.append(f"        self.play(FadeIn(VGroup({derived_dots_str})), run_time=0.8)")

        # Segments from drawing_phases
        segment_lines = []
        for phase in drawing_phases:
            if phase.get("phase") == 2:
                for seg in phase.get("segments", []):
                    if len(seg) == 2 and seg[0] in coords and seg[1] in coords:
                        p1, p2 = seg[0], seg[1]
                        seg_var = f"seg_{p1}_{p2}"
                        lines.append(
                            f"        {seg_var} = Line(p_{p1}.get_center(), p_{p2}.get_center(),"
                            f" color=YELLOW)"
                        )
                        segment_lines.append(seg_var)

        if segment_lines:
            segs_str = ", ".join([f"Create({sv})" for sv in segment_lines])
            lines.append(f"        self.play({segs_str}, run_time=1.2)")

        if derived_ids or segment_lines:
            lines.append("        self.wait(0.5)")

        # ── Phase 3: All labels ───────────────────────────────────────────────
        all_labels_str = ", ".join([f"l_{pid}" for pid in coords])
        lines.append(f"        self.play(FadeIn(VGroup({all_labels_str})), run_time=0.8)")

        # ── Circles phase ─────────────────────────────────────────────────────
        for i in range(len(circles_meta)):
            lines.append(f"        self.play(Create(circle_{i}), run_time=1.5)")

        # ── Lines & Rays phase ────────────────────────────────────────────────
        if lines_meta or rays_meta:
            lr_anims = []
            for i in range(len(lines_meta)): lr_anims.append(f"Create(line_ext_{i})")
            for i in range(len(rays_meta)): lr_anims.append(f"Create(ray_{i})")
            lines.append(f"        self.play({', '.join(lr_anims)}, run_time=1.5)")

        lines.append("        self.wait(2)")

        return "\n".join(lines)

    def run_manim(self, script_content: str, job_id: str) -> str:
        """Write *script_content* to <job_id>.py, run Manim on it, and return
        the rendered mp4 path ("" on failure).

        MOCK_VIDEO=true short-circuits rendering with a dummy file for tests.
        The temp script is always removed in the finally block.
        """
        # Local imports are redundant with the module-level ones; kept as-is
        # (doc-only change).
        import subprocess
        import os
        import glob

        script_file = f"{job_id}.py"
        with open(script_file, "w") as f:
            f.write(script_content)

        try:
            if os.getenv("MOCK_VIDEO") == "true":
                logger.info(f"MOCK_VIDEO is true. Skipping Manim for job {job_id}")
                # Create a dummy file if needed, or just return a path that exists
                dummy_path = f"videos/{job_id}.mp4"
                os.makedirs("videos", exist_ok=True)
                with open(dummy_path, "wb") as f:
                    f.write(b"dummy video content")
                return dummy_path

            # Determine manim executable path
            # Prefer the project venv's manim binary when present.
            manim_exe = "manim"
            venv_manim = os.path.join(os.getcwd(), "venv", "bin", "manim")
            if os.path.exists(venv_manim):
                manim_exe = venv_manim

            # Prepare environment with homebrew paths
            # (macOS: ffmpeg/latex often live under homebrew prefixes)
            custom_env = os.environ.copy()
            brew_path = "/opt/homebrew/bin:/usr/local/bin"
            custom_env["PATH"] = f"{brew_path}:{custom_env.get('PATH', '')}"

            logger.info(f"Running {manim_exe} for job {job_id}...")
            result = subprocess.run(
                [manim_exe, "-ql", "--media_dir", ".", "-o", f"{job_id}.mp4", script_file, "GeometryScene"],
                capture_output=True,
                text=True,
                env=custom_env
            )
            logger.info(f"Manim STDOUT: {result.stdout}")
            if result.returncode != 0:
                logger.error(f"Manim STDERR: {result.stderr}")

            # Manim's output directory layout varies by version; glob for the
            # file rather than computing the exact path.
            for pattern in [f"**/videos/**/{job_id}.mp4", f"**/{job_id}*.mp4"]:
                found = glob.glob(pattern, recursive=True)
                if found:
                    logger.info(f"Manim Success: Found {found[0]}")
                    return found[0]

            logger.error(f"Manim file not found for job {job_id}. Return code: {result.returncode}")
            return ""
        except Exception as e:
            logger.exception(f"Manim Execution Error: {e}")
            return ""
        finally:
            if os.path.exists(script_file):
                os.remove(script_file)
agents/solver_agent.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import sympy as sp
4
+ from typing import Dict, Any, List
5
+ from app.llm_client import get_llm_client
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+ class SolverAgent:
10
+ def __init__(self):
11
+ self.llm = get_llm_client()
12
+
13
+ async def solve(self, semantic_data: Dict[str, Any], engine_result: Dict[str, Any]) -> Dict[str, Any]:
14
+ """
15
+ Solves the geometric problem based on coordinates and the target question.
16
+ Returns a 'solution' dictionary with answer, steps, and symbolic_expression.
17
+ """
18
+ target_question = semantic_data.get("target_question")
19
+ if not target_question:
20
+ # If no question, just return an empty solution structure
21
+ return {
22
+ "answer": None,
23
+ "steps": [],
24
+ "symbolic_expression": None
25
+ }
26
+
27
+ logger.info(f"==[SolverAgent] Solving for: '{target_question}'==")
28
+
29
+ input_text = semantic_data.get("input_text", "")
30
+ coordinates = engine_result.get("coordinates", {})
31
+
32
+ # We provide the coordinates and semantic context to the LLM to help it reason.
33
+ # The LLM is tasked with generating the solution structure directly.
34
+
35
+ system_prompt = """
36
+ You are a Geometry Solver Agent. Your goal is to provide a step-by-step solution for a specific geometric question.
37
+
38
+ === DATA PROVIDED ===
39
+ 1. Target Question: The specific question to answer.
40
+ 2. Geometry Data: Entities and values extracted from the problem.
41
+ 3. Coordinates: Calculated coordinates for all points.
42
+
43
+ === REQUIREMENTS ===
44
+ - Provide the solution in the SAME LANGUAGE as the user's input.
45
+ - Use SymPy concepts if appropriate.
46
+ - Steps should be clear, concise, and logical.
47
+ - The final answer should be numerically or symbolically accurate based on the coordinates and geometric properties.
48
+ - For geometric proofs (e.g., "Is AB perpendicular to AC?"), explain the reasoning based on the data.
49
+
50
+ Output ONLY a JSON object with this structure:
51
+ {
52
+ "answer": "Chuỗi văn bản kết quả cuối cùng (kèm đơn vị nếu có)",
53
+ "steps": [
54
+ "Bước 1: ...",
55
+ "Bước 2: ...",
56
+ ...
57
+ ],
58
+ "symbolic_expression": "Biểu thức toán học rút gọn (LaTeX format optional)"
59
+ }
60
+ """
61
+
62
+ user_content = f"""
63
+ INPUT_TEXT: {input_text}
64
+ TARGET_QUESTION: {target_question}
65
+ SEMANTIC_DATA: {json.dumps(semantic_data, ensure_ascii=False)}
66
+ COORDINATES: {json.dumps(coordinates)}
67
+ """
68
+
69
+ logger.debug("[SolverAgent] Requesting solution from LLM...")
70
+ try:
71
+ raw = await self.llm.chat_completions_create(
72
+ messages=[
73
+ {"role": "system", "content": system_prompt},
74
+ {"role": "user", "content": user_content}
75
+ ],
76
+ response_format={"type": "json_object"}
77
+ )
78
+
79
+ clean_raw = raw.strip()
80
+ # Handle potential markdown code blocks if the response_format wasn't strictly honored
81
+ if clean_raw.startswith("```"):
82
+ import re
83
+ match = re.search(r"```(?:json)?\s*(.*?)\s*```", clean_raw, re.DOTALL)
84
+ if match:
85
+ clean_raw = match.group(1).strip()
86
+
87
+ try:
88
+ solution = json.loads(clean_raw)
89
+ except json.JSONDecodeError:
90
+ # Last resort: try to find anything between { and }
91
+ import re
92
+ json_match = re.search(r'(\{.*\})', clean_raw, re.DOTALL)
93
+ if json_match:
94
+ solution = json.loads(json_match.group(1))
95
+ else:
96
+ raise
97
+
98
+ logger.info("[SolverAgent] Solution generated successfully.")
99
+ return solution
100
+ except Exception as e:
101
+ logger.error(f"[SolverAgent] Error generating solution: {e}")
102
+ logger.debug(f"[SolverAgent] Raw LLM output was: \n{raw if 'raw' in locals() else 'N/A'}")
103
+ return {
104
+ "answer": "Không thể tính toán lời giải tại thời điểm này.",
105
+ "steps": ["Đã xảy ra lỗi trong quá trình xử lý lời giải."],
106
+ "symbolic_expression": None
107
+ }
agents/torch_ultralytics_compat.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """PyTorch 2.6+ defaults weights_only=True; Ultralytics YOLO .pt checkpoints unpickle full nn graphs (trusted official weights)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import functools
6
+
7
# Guard so the monkey-patch is applied at most once per process.
_torch_load_patched = False


def allow_ultralytics_weights() -> None:
    """
    Official yolov8n.pt is a trusted checkpoint. PyTorch 2.6+ safe unpickling would require
    allowlisting many torch.nn globals; loading with weights_only=False matches Ultralytics
    upstream behavior for local .pt files.
    """
    global _torch_load_patched
    if _torch_load_patched:
        return
    try:
        import torch

        original_load = torch.load

        @functools.wraps(original_load)
        def _patched_load(*args, **kwargs):
            # Only supply a default; explicit weights_only from callers wins.
            kwargs.setdefault("weights_only", False)
            return original_load(*args, **kwargs)

        torch.load = _patched_load
        _torch_load_patched = True
    except Exception:
        # torch missing or patching failed: silently leave torch.load untouched.
        pass
app/chat_image_upload.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Validate and upload chat/solve attachment images to Supabase Storage (image bucket)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import os
7
+ import uuid
8
+ from typing import Any, Dict, Tuple
9
+
10
+ from fastapi import HTTPException
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
def _get_next_image_version(session_id: str) -> int:
    """Same logic as worker.asset_manager.get_next_version for asset_type image.

    Queries the highest existing image version for the session and returns
    that + 1; returns 1 on no rows or any query error.
    """
    from app.supabase_client import get_supabase

    client = get_supabase()
    try:
        latest = (
            client.table("session_assets")
            .select("version")
            .eq("session_id", session_id)
            .eq("asset_type", "image")
            .order("version", desc=True)
            .limit(1)
            .execute()
        )
        rows = latest.data
        if rows:
            return rows[0]["version"] + 1
        return 1
    except Exception as e:
        logger.error("Error fetching image version: %s", e)
        return 1
36
+
37
+ _MAX_BYTES_DEFAULT = 10 * 1024 * 1024
38
+
39
+ _EXT_TO_MIME: dict[str, str] = {
40
+ ".png": "image/png",
41
+ ".jpg": "image/jpeg",
42
+ ".jpeg": "image/jpeg",
43
+ ".webp": "image/webp",
44
+ ".gif": "image/gif",
45
+ ".bmp": "image/bmp",
46
+ }
47
+
48
+
49
+ def _max_bytes() -> int:
50
+ raw = os.getenv("CHAT_IMAGE_MAX_BYTES")
51
+ if raw and raw.isdigit():
52
+ return min(int(raw), 50 * 1024 * 1024)
53
+ return _MAX_BYTES_DEFAULT
54
+
55
+
56
+ def _magic_ok(ext: str, body: bytes) -> bool:
57
+ if len(body) < 12:
58
+ return False
59
+ if ext == ".png":
60
+ return body.startswith(b"\x89PNG\r\n\x1a\n")
61
+ if ext in (".jpg", ".jpeg"):
62
+ return body.startswith(b"\xff\xd8\xff")
63
+ if ext == ".webp":
64
+ return body.startswith(b"RIFF") and body[8:12] == b"WEBP"
65
+ if ext == ".gif":
66
+ return body.startswith(b"GIF87a") or body.startswith(b"GIF89a")
67
+ if ext == ".bmp":
68
+ return body.startswith(b"BM")
69
+ return False
70
+
71
+
72
def validate_chat_image_bytes(
    filename: str | None,
    body: bytes,
    declared_content_type: str | None,
) -> Tuple[str, str]:
    """
    Validate size, extension, and magic bytes.
    Returns (extension_with_dot, content_type).

    Raises HTTPException 400 for empty/unknown/mismatched content and 413
    when the body exceeds the configured size cap. A declared Content-Type
    that disagrees with the inferred one is logged, not rejected — the
    magic-byte check is authoritative.
    """
    limit = _max_bytes()
    if not body:
        raise HTTPException(status_code=400, detail="Empty file.")
    if len(body) > limit:
        raise HTTPException(
            status_code=413,
            detail=f"Image too large (max {limit // (1024 * 1024)} MB).",
        )

    # Missing extension defaults to .png.
    ext = os.path.splitext(filename or "")[1].lower() or ".png"
    if ext not in _EXT_TO_MIME:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported image type: {ext}. Allowed: {', '.join(sorted(_EXT_TO_MIME))}",
        )

    if not _magic_ok(ext, body):
        raise HTTPException(
            status_code=400,
            detail="File content does not match declared image type.",
        )

    mime = _EXT_TO_MIME[ext]
    if declared_content_type:
        declared = declared_content_type.split(";")[0].strip().lower()
        if declared and declared not in ("application/octet-stream", mime):
            logger.warning(
                "Content-Type mismatch (declared=%s, inferred=%s); using inferred.",
                declared_content_type,
                mime,
            )
    return ext, mime
115
+
116
+
117
def upload_session_chat_image(
    session_id: str,
    job_id: str,
    file_bytes: bytes,
    ext_with_dot: str,
    content_type: str,
) -> Dict[str, Any]:
    """
    Upload to SUPABASE_IMAGE_BUCKET (default: image), insert session_assets row.
    Returns dict with public_url, storage_path, version, session_asset_id (if returned).

    Versioning is per-session via _get_next_image_version; the storage path is
    sessions/<session_id>/image_v<version>_<job_id>.<ext>.
    """
    from app.supabase_client import get_supabase

    supabase = get_supabase()
    bucket_name = os.getenv("SUPABASE_IMAGE_BUCKET", "image")
    raw_ext = ext_with_dot.lstrip(".").lower()
    version = _get_next_image_version(session_id)
    file_name = f"image_v{version}_{job_id}.{raw_ext}"
    storage_path = f"sessions/{session_id}/{file_name}"

    supabase.storage.from_(bucket_name).upload(
        path=storage_path,
        file=file_bytes,
        file_options={"content-type": content_type},
    )
    public_url = supabase.storage.from_(bucket_name).get_public_url(storage_path)
    # Some client versions return a dict payload instead of a bare URL string.
    if isinstance(public_url, dict):
        public_url = public_url.get("publicUrl") or public_url.get("public_url") or str(public_url)

    row = {
        "session_id": session_id,
        "job_id": job_id,
        "asset_type": "image",
        "storage_path": storage_path,
        "public_url": public_url,
        "version": version,
    }
    # NOTE(review): chaining .select("id") after .insert(...) — verify this is
    # supported by the installed supabase-py/postgrest version; some versions
    # return the inserted row from .insert(...).execute() directly.
    ins = supabase.table("session_assets").insert(row).select("id").execute()
    asset_id = None
    if ins.data and len(ins.data) > 0:
        asset_id = ins.data[0].get("id")

    log_data = {
        "public_url": public_url,
        "storage_path": storage_path,
        "version": version,
        "session_asset_id": str(asset_id) if asset_id else None,
    }
    logger.info("Uploaded chat image: %s", log_data)
    return {
        "public_url": public_url,
        "storage_path": storage_path,
        "version": version,
        "session_asset_id": str(asset_id) if asset_id else None,
    }
172
+
173
+
174
def upload_ephemeral_ocr_blob(
    file_bytes: bytes,
    ext_with_dot: str,
    content_type: str,
) -> Tuple[str, str]:
    """
    Upload bytes to image bucket under _ocr_temp/ for worker-only OCR (no session_assets row).
    Returns (storage_path, public_url). Caller must delete_storage_object when done.
    """
    from app.supabase_client import get_supabase

    bucket_name = os.getenv("SUPABASE_IMAGE_BUCKET", "image")
    ext = ext_with_dot.lstrip(".").lower() or "png"
    storage_path = f"_ocr_temp/{uuid.uuid4().hex}.{ext}"
    client = get_supabase()
    client.storage.from_(bucket_name).upload(
        path=storage_path,
        file=file_bytes,
        file_options={"content-type": content_type},
    )
    url = client.storage.from_(bucket_name).get_public_url(storage_path)
    # Some client versions return a dict payload instead of a bare URL string.
    if isinstance(url, dict):
        url = url.get("publicUrl") or url.get("public_url") or str(url)
    return storage_path, url
198
+
199
+
200
def delete_storage_object(bucket_name: str, storage_path: str) -> None:
    """Best-effort removal of one storage object; failures are logged, never raised."""
    try:
        from app.supabase_client import get_supabase

        client = get_supabase()
        client.storage.from_(bucket_name).remove([storage_path])
    except Exception as e:
        logger.warning("delete_storage_object failed path=%s: %s", storage_path, e)
app/dependencies.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import HTTPException, Header
2
+
3
+ from app.supabase_client import get_supabase, get_supabase_for_user_jwt
4
+
5
+
6
+ async def get_current_user_id(authorization: str | None = Header(None)):
7
+ """
8
+ Authenticate user using Supabase JWT.
9
+ Expected Header: Authorization: Bearer <token>
10
+ """
11
+ import os
12
+
13
+ if not authorization:
14
+ raise HTTPException(
15
+ status_code=401,
16
+ detail="Authorization header missing or invalid. Use 'Bearer <token>'",
17
+ )
18
+
19
+ if os.getenv("ALLOW_TEST_BYPASS") == "true" and authorization.startswith("Test "):
20
+ return authorization.split(" ")[1]
21
+
22
+ if not authorization.startswith("Bearer "):
23
+ raise HTTPException(
24
+ status_code=401,
25
+ detail="Authorization header missing or invalid. Use 'Bearer <token>'",
26
+ )
27
+
28
+ token = authorization.split(" ")[1]
29
+ supabase = get_supabase()
30
+
31
+ try:
32
+ user_response = supabase.auth.get_user(token)
33
+ if not user_response or not user_response.user:
34
+ raise HTTPException(status_code=401, detail="Invalid session or token.")
35
+
36
+ return user_response.user.id
37
+ except HTTPException:
38
+ raise
39
+ except Exception as e:
40
+ raise HTTPException(status_code=401, detail=f"Authentication failed: {str(e)}")
41
+
42
+
43
+ async def get_authenticated_supabase(authorization: str = Header(...)):
44
+ """
45
+ Supabase client that carries the user's JWT (anon key + Authorization header).
46
+ Use for routes that should respect Row Level Security; pair with app logic as needed.
47
+ """
48
+ if not authorization or not authorization.startswith("Bearer "):
49
+ raise HTTPException(
50
+ status_code=401,
51
+ detail="Authorization header missing or invalid. Use 'Bearer <token>'",
52
+ )
53
+
54
+ token = authorization.split(" ")[1]
55
+ supabase = get_supabase()
56
+
57
+ try:
58
+ user_response = supabase.auth.get_user(token)
59
+ if not user_response or not user_response.user:
60
+ raise HTTPException(status_code=401, detail="Invalid session or token.")
61
+ except HTTPException:
62
+ raise
63
+ except Exception as e:
64
+ raise HTTPException(status_code=401, detail=f"Authentication failed: {str(e)}")
65
+
66
+ try:
67
+ return get_supabase_for_user_jwt(token)
68
+ except RuntimeError as e:
69
+ raise HTTPException(status_code=503, detail=str(e))
app/errors.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Map exceptions to short, user-visible messages (avoid leaking HTML bodies from 404 proxies)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+
10
+ def _looks_like_html(text: str) -> bool:
11
+ t = text.lstrip()[:500].lower()
12
+ return t.startswith("<!doctype") or t.startswith("<html") or "<html" in t[:200]
13
+
14
+
15
+ def format_error_for_user(exc: BaseException) -> str:
16
+ """
17
+ Produce a safe message for chat/UI. Full detail stays in server logs via logger.exception.
18
+ """
19
+ # httpx: wrong URL often returns 404 HTML; don't show body
20
+ try:
21
+ import httpx
22
+
23
+ if isinstance(exc, httpx.HTTPStatusError):
24
+ req = exc.request
25
+ code = exc.response.status_code
26
+ url_hint = ""
27
+ try:
28
+ url_hint = str(req.url.host) if req and req.url else ""
29
+ except Exception:
30
+ pass
31
+ logger.warning(
32
+ "HTTPStatusError %s for %s (response not shown to user)",
33
+ code,
34
+ url_hint or "?",
35
+ )
36
+ return (
37
+ "Kiểm tra URL API, khóa bí mật và biến môi trường (OpenRouter/Supabase/Redis)."
38
+ )
39
+
40
+ if isinstance(exc, httpx.RequestError):
41
+ return "Không kết nối được tới dịch vụ ngoài (mạng hoặc URL sai)."
42
+ except ImportError:
43
+ pass
44
+
45
+ raw = str(exc).strip()
46
+ if not raw:
47
+ return "Đã xảy ra lỗi không xác định."
48
+
49
+ if _looks_like_html(raw):
50
+ logger.warning("Suppressed HTML error body from user-facing message")
51
+ return (
52
+ "Dịch vụ trả về trang lỗi (thường là URL API sai hoặc endpoint không tồn tại — HTTP 404). "
53
+ "Kiểm tra OPENROUTER_MODEL và khóa API trên server."
54
+ )
55
+
56
+ if len(raw) > 800:
57
+ return raw[:800] + "…"
58
+
59
+ return raw
app/llm_client.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import asyncio
4
+ import logging
5
+ from openai import AsyncOpenAI
6
+ from typing import List, Dict, Any, Optional
7
+ from app.url_utils import openai_compatible_api_key, sanitize_env
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ class MultiLayerLLMClient:
12
+ def __init__(self):
13
+ # 1. Models sequence loading
14
+ self.models = []
15
+ for i in range(1, 4):
16
+ model = os.getenv(f"OPENROUTER_MODEL_{i}")
17
+ if model:
18
+ self.models.append(model)
19
+
20
+ # Fallback to legacy OPENROUTER_MODEL if no numbered models found
21
+ if not self.models:
22
+ legacy_model = os.getenv("OPENROUTER_MODEL", "google/gemini-2.0-flash-001")
23
+ self.models = [legacy_model]
24
+
25
+ # 2. Key selection (No rotation, always use the first available key)
26
+ api_key = os.getenv("OPENROUTER_API_KEY_1") or os.getenv("OPENROUTER_API_KEY")
27
+
28
+ if not api_key:
29
+ logger.error("[LLM] No OpenRouter API key found.")
30
+ self.client = None
31
+ else:
32
+ self.client = AsyncOpenAI(
33
+ api_key=openai_compatible_api_key(api_key),
34
+ base_url="https://openrouter.ai/api/v1",
35
+ timeout=60.0,
36
+ default_headers={
37
+ "HTTP-Referer": "https://mathsolver.ai",
38
+ "X-Title": "MathSolver Backend",
39
+ }
40
+ )
41
+
42
+ async def chat_completions_create(
43
+ self,
44
+ messages: List[Dict[str, str]],
45
+ response_format: Optional[Dict[str, str]] = None,
46
+ **kwargs
47
+ ) -> str:
48
+ """
49
+ Implements Model Fallback Sequence: Model 1 -> Model 2 -> Model 3.
50
+ Always starts from Model 1 for every new call.
51
+ """
52
+ if not self.client:
53
+ raise ValueError("No API client configured. Check your API keys.")
54
+
55
+ MAX_ATTEMPTS = len(self.models)
56
+ RETRY_DELAY = 1.0 # second
57
+
58
+ for attempt_idx in range(MAX_ATTEMPTS):
59
+ current_model = self.models[attempt_idx]
60
+ attempt_num = attempt_idx + 1
61
+
62
+ try:
63
+ logger.info(f"[LLM] Attempt {attempt_num}/{MAX_ATTEMPTS} using Model: {current_model}...")
64
+
65
+ response = await self.client.chat.completions.create(
66
+ model=current_model,
67
+ messages=messages,
68
+ response_format=response_format,
69
+ **kwargs
70
+ )
71
+
72
+ if not response or not getattr(response, "choices", None):
73
+ raise ValueError(f"Invalid response structure from model {current_model}")
74
+
75
+ content = response.choices[0].message.content
76
+ if content:
77
+ logger.info(f"[LLM] SUCCESS on attempt {attempt_num} ({current_model}).")
78
+ return content
79
+
80
+ raise ValueError(f"Empty content from model {current_model}")
81
+
82
+ except Exception as e:
83
+ err_msg = f"{type(e).__name__}: {str(e)}"
84
+ logger.warning(f"[LLM] FAILED on attempt {attempt_num} ({current_model}): {err_msg}")
85
+
86
+ if attempt_num < MAX_ATTEMPTS:
87
+ logger.info(f"[LLM] Retrying next model in {RETRY_DELAY}s...")
88
+ await asyncio.sleep(RETRY_DELAY)
89
+ else:
90
+ logger.error(f"[LLM] FINAL FAILURE after {attempt_num} models.")
91
+ raise e
92
+
93
+ # Global instance for easy reuse (singleton-ish)
94
+ _llm_client = None
95
+
96
+ def get_llm_client() -> MultiLayerLLMClient:
97
+ global _llm_client
98
+ if _llm_client is None:
99
+ _llm_client = MultiLayerLLMClient()
100
+ return _llm_client
app/logging_setup.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Logging theo một biến LOG_LEVEL: debug | info | warning | error."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import os
7
+ from typing import Final
8
+
9
+ _SETUP_DONE = False
10
+
11
+ PIPELINE_LOGGER_NAME: Final = "app.pipeline"
12
+ CACHE_LOGGER_NAME: Final = "app.cache"
13
+ STEPS_LOGGER_NAME: Final = "app.steps"
14
+ ACCESS_LOGGER_NAME: Final = "app.access"
15
+
16
+
17
+ def _normalize_level() -> str:
18
+ raw = os.getenv("LOG_LEVEL", "info").strip().lower()
19
+ if raw in ("debug", "info", "warning", "error"):
20
+ return raw
21
+ return "info"
22
+
23
+
24
+ def setup_application_logging() -> None:
25
+ """Idempotent; gọi khi khởi động process (uvicorn, celery, worker_health)."""
26
+ global _SETUP_DONE
27
+ if _SETUP_DONE:
28
+ return
29
+ _SETUP_DONE = True
30
+
31
+ mode = _normalize_level()
32
+
33
+ level_map = {
34
+ "debug": logging.DEBUG,
35
+ "info": logging.INFO,
36
+ "warning": logging.WARNING,
37
+ "error": logging.ERROR,
38
+ }
39
+ root_level = level_map[mode]
40
+
41
+ fmt_named = "%(asctime)s | %(levelname)-8s | %(name)s | %(message)s"
42
+ fmt_short = "%(asctime)s | %(levelname)-8s | %(message)s"
43
+
44
+ logging.basicConfig(
45
+ level=root_level,
46
+ format=fmt_named if mode == "debug" else fmt_short,
47
+ datefmt="%H:%M:%S",
48
+ force=True,
49
+ )
50
+
51
+ logging.getLogger("httpx").setLevel(logging.WARNING)
52
+ logging.getLogger("httpcore").setLevel(logging.WARNING)
53
+ logging.getLogger("openai").setLevel(logging.WARNING)
54
+ logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
55
+ logging.getLogger("uvicorn.error").setLevel(logging.INFO)
56
+ # HTTP/2 stack (httpx/httpcore) — khi LOG_LEVEL=debug root=DEBUG sẽ tràn log hpack; không cần cho debug app
57
+ for _name in ("hpack", "h2", "hyperframe", "urllib3"):
58
+ logging.getLogger(_name).setLevel(logging.WARNING)
59
+
60
+ if mode == "debug":
61
+ logging.getLogger("agents").setLevel(logging.DEBUG)
62
+ logging.getLogger("solver").setLevel(logging.DEBUG)
63
+ logging.getLogger("app").setLevel(logging.DEBUG)
64
+ logging.getLogger(CACHE_LOGGER_NAME).setLevel(logging.DEBUG)
65
+ logging.getLogger(STEPS_LOGGER_NAME).setLevel(logging.DEBUG)
66
+ logging.getLogger(PIPELINE_LOGGER_NAME).setLevel(logging.INFO)
67
+ logging.getLogger(ACCESS_LOGGER_NAME).setLevel(logging.INFO)
68
+ logging.getLogger("app.main").setLevel(logging.INFO)
69
+ logging.getLogger("worker").setLevel(logging.INFO)
70
+ elif mode == "info":
71
+ # Chỉ HTTP access (app.access) + startup; ẩn chi tiết agents/orchestrator/pipeline SUCCESS
72
+ logging.getLogger("agents").setLevel(logging.INFO)
73
+ logging.getLogger("solver").setLevel(logging.WARNING)
74
+ logging.getLogger("app").setLevel(logging.INFO)
75
+ logging.getLogger(CACHE_LOGGER_NAME).setLevel(logging.WARNING)
76
+ logging.getLogger(STEPS_LOGGER_NAME).setLevel(logging.WARNING)
77
+ logging.getLogger(PIPELINE_LOGGER_NAME).setLevel(logging.WARNING)
78
+ logging.getLogger(ACCESS_LOGGER_NAME).setLevel(logging.INFO)
79
+ logging.getLogger("app.main").setLevel(logging.INFO)
80
+ logging.getLogger("worker").setLevel(logging.WARNING)
81
+ elif mode == "warning":
82
+ logging.getLogger("agents").setLevel(logging.WARNING)
83
+ logging.getLogger("solver").setLevel(logging.WARNING)
84
+ logging.getLogger("app.routers").setLevel(logging.WARNING)
85
+ logging.getLogger(CACHE_LOGGER_NAME).setLevel(logging.WARNING)
86
+ logging.getLogger(STEPS_LOGGER_NAME).setLevel(logging.WARNING)
87
+ logging.getLogger(PIPELINE_LOGGER_NAME).setLevel(logging.WARNING)
88
+ logging.getLogger(ACCESS_LOGGER_NAME).setLevel(logging.WARNING)
89
+ logging.getLogger("app.main").setLevel(logging.WARNING)
90
+ logging.getLogger("worker").setLevel(logging.WARNING)
91
+ else: # error
92
+ logging.getLogger("agents").setLevel(logging.ERROR)
93
+ logging.getLogger("solver").setLevel(logging.ERROR)
94
+ logging.getLogger("app.routers").setLevel(logging.ERROR)
95
+ logging.getLogger(CACHE_LOGGER_NAME).setLevel(logging.ERROR)
96
+ logging.getLogger(STEPS_LOGGER_NAME).setLevel(logging.ERROR)
97
+ logging.getLogger(PIPELINE_LOGGER_NAME).setLevel(logging.ERROR)
98
+ logging.getLogger(ACCESS_LOGGER_NAME).setLevel(logging.ERROR)
99
+ logging.getLogger("app.main").setLevel(logging.ERROR)
100
+ logging.getLogger("worker").setLevel(logging.ERROR)
101
+
102
+ logging.getLogger(__name__).debug(
103
+ "LOG_LEVEL=%s root=%s", mode, logging.getLevelName(root_level)
104
+ )
105
+
106
+
107
+ def get_log_level() -> str:
108
+ return _normalize_level()
109
+
110
+
111
+ def is_debug_level() -> bool:
112
+ return _normalize_level() == "debug"
app/logutil.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """log_step (debug), pipeline (debug), access log ở middleware."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import logging
7
+ import os
8
+ from typing import Any
9
+
10
+ from app.logging_setup import PIPELINE_LOGGER_NAME, STEPS_LOGGER_NAME
11
+
12
+ _pipeline = logging.getLogger(PIPELINE_LOGGER_NAME)
13
+ _steps = logging.getLogger(STEPS_LOGGER_NAME)
14
+
15
+
16
+ def is_debug_mode() -> bool:
17
+ """Chi tiết từng bước chỉ khi LOG_LEVEL=debug."""
18
+ return os.getenv("LOG_LEVEL", "info").strip().lower() == "debug"
19
+
20
+
21
+ def _truncate(val: Any, max_len: int = 2000) -> Any:
22
+ if val is None:
23
+ return None
24
+ if isinstance(val, (int, float, bool)):
25
+ return val
26
+ s = str(val)
27
+ if len(s) > max_len:
28
+ return s[:max_len] + f"... (+{len(s) - max_len} chars)"
29
+ return s
30
+
31
+
32
+ def log_step(step: str, **fields: Any) -> None:
33
+ """Chỉ khi LOG_LEVEL=debug: DB / cache / orchestrator."""
34
+ if not is_debug_mode():
35
+ return
36
+ safe = {k: _truncate(v) for k, v in fields.items()}
37
+ try:
38
+ payload = json.dumps(safe, ensure_ascii=False, default=str)
39
+ except Exception:
40
+ payload = str(safe)
41
+ _steps.debug("[step:%s] %s", step, payload)
42
+
43
+
44
+ def log_pipeline_success(operation: str, **fields: Any) -> None:
45
+ """Chỉ hiện khi debug (pipeline SUCCESS không dùng ở info — đã có app.access)."""
46
+ if not is_debug_mode():
47
+ return
48
+ safe = {k: _truncate(v, 500) for k, v in fields.items()}
49
+ _pipeline.info(
50
+ "SUCCESS %s %s",
51
+ operation,
52
+ json.dumps(safe, ensure_ascii=False, default=str),
53
+ )
54
+
55
+
56
+ def log_pipeline_failure(operation: str, error: str | None = None, **fields: Any) -> None:
57
+ """Thất bại pipeline: luôn dùng WARNING để vẫn thấy khi LOG_LEVEL=warning."""
58
+ if is_debug_mode():
59
+ safe = {k: _truncate(v, 500) for k, v in fields.items()}
60
+ _pipeline.warning(
61
+ "FAIL %s err=%s %s",
62
+ operation,
63
+ _truncate(error, 300),
64
+ json.dumps(safe, ensure_ascii=False, default=str),
65
+ )
66
+ else:
67
+ _pipeline.warning("FAIL %s", operation)
app/main.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import os
5
+ import time
6
+ import uuid
7
+ import warnings
8
+
9
+ from dotenv import load_dotenv
10
+ from fastapi import Depends, FastAPI, File, HTTPException, UploadFile
11
+ from fastapi.middleware.cors import CORSMiddleware
12
+ from starlette.requests import Request
13
+
14
+ load_dotenv()
15
+
16
+ from app.runtime_env import apply_runtime_env_defaults
17
+
18
+ apply_runtime_env_defaults()
19
+
20
+ os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"
21
+ warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")
22
+ warnings.filterwarnings("ignore", category=UserWarning, module="albumentations")
23
+
24
+ from app.logging_setup import ACCESS_LOGGER_NAME, get_log_level, setup_application_logging
25
+
26
+ setup_application_logging()
27
+
28
+ # Routers (after logging)
29
+ from app.dependencies import get_current_user_id
30
+ from app.ocr_local_file import ocr_from_local_image_path
31
+ from app.routers import auth, sessions, solve
32
+ from agents.ocr_agent import OCRAgent
33
+ from app.routers.solve import get_orchestrator
34
+ from app.supabase_client import get_supabase
35
+ from app.websocket_manager import register_websocket_routes
36
+
37
+ logger = logging.getLogger("app.main")
38
+ _access = logging.getLogger(ACCESS_LOGGER_NAME)
39
+
40
+ app = FastAPI(title="Visual Math Solver API v5.1")
41
+
42
+
43
+ @app.middleware("http")
44
+ async def access_log_middleware(request: Request, call_next):
45
+ """LOG_LEVEL=info/debug: mọi request; warning: chỉ 4xx/5xx; error: chỉ 4xx/5xx ở mức error."""
46
+ start = time.perf_counter()
47
+ response = await call_next(request)
48
+ ms = (time.perf_counter() - start) * 1000
49
+ mode = get_log_level()
50
+ method = request.method
51
+ path = request.url.path
52
+ status = response.status_code
53
+
54
+ if mode in ("debug", "info"):
55
+ _access.info("%s %s -> %s (%.0fms)", method, path, status, ms)
56
+ elif mode == "warning":
57
+ if status >= 500:
58
+ _access.error("%s %s -> %s (%.0fms)", method, path, status, ms)
59
+ elif status >= 400:
60
+ _access.warning("%s %s -> %s (%.0fms)", method, path, status, ms)
61
+ elif mode == "error":
62
+ if status >= 400:
63
+ _access.error("%s %s -> %s", method, path, status)
64
+
65
+ return response
66
+
67
+
68
+ from worker.celery_app import BROKER_URL
69
+
70
+ _broker_tail = BROKER_URL.split("@")[-1] if "@" in BROKER_URL else BROKER_URL
71
+ if get_log_level() in ("debug", "info"):
72
+ logger.info("App starting LOG_LEVEL=%s | Redis: %s", get_log_level(), _broker_tail)
73
+ else:
74
+ logger.warning(
75
+ "App starting LOG_LEVEL=%s | Redis: %s", get_log_level(), _broker_tail
76
+ )
77
+
78
+ app.add_middleware(
79
+ CORSMiddleware,
80
+ allow_origins=[
81
+ "http://localhost:3000",
82
+ "http://127.0.0.1:3000",
83
+ "http://localhost:3005",
84
+ ],
85
+ allow_credentials=True,
86
+ allow_methods=["*"],
87
+ allow_headers=["*"],
88
+ )
89
+
90
+ app.include_router(auth.router)
91
+ app.include_router(sessions.router)
92
+ app.include_router(solve.router)
93
+
94
+ register_websocket_routes(app)
95
+
96
+
97
+ def get_ocr_agent() -> OCRAgent:
98
+ """Same OCR instance as the solve pipeline (no duplicate model load)."""
99
+ return get_orchestrator().ocr_agent
100
+
101
+
102
+ supabase_client = get_supabase()
103
+
104
+
105
+ @app.get("/")
106
+ def read_root():
107
+ return {"message": "Visual Math Solver API v5.1 is running", "version": "5.1"}
108
+
109
+
110
+ @app.post("/api/v1/ocr")
111
+ async def upload_ocr(
112
+ file: UploadFile = File(...),
113
+ _user_id=Depends(get_current_user_id),
114
+ ):
115
+ """OCR upload: requires authenticated user."""
116
+ temp_path = f"temp_{uuid.uuid4()}.png"
117
+ with open(temp_path, "wb") as buffer:
118
+ buffer.write(await file.read())
119
+
120
+ try:
121
+ text = await ocr_from_local_image_path(temp_path, file.filename, get_ocr_agent())
122
+ return {"text": text}
123
+ finally:
124
+ if os.path.exists(temp_path):
125
+ os.remove(temp_path)
126
+
127
+
128
+ @app.get("/api/v1/solve/{job_id}")
129
+ async def get_job_status(
130
+ job_id: str,
131
+ user_id=Depends(get_current_user_id),
132
+ ):
133
+ """Retrieve job status (can be used for polling if WS fails). Owner-only."""
134
+ response = supabase_client.table("jobs").select("*").eq("id", job_id).execute()
135
+ if not response.data:
136
+ raise HTTPException(status_code=404, detail="Job not found")
137
+ job = response.data[0]
138
+ if job.get("user_id") is not None and str(job["user_id"]) != str(user_id):
139
+ raise HTTPException(status_code=403, detail="Forbidden: You do not own this job.")
140
+ return job
app/models/schemas.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, EmailStr, field_validator
2
+ from typing import Optional, List, Any, Dict
3
+ from datetime import datetime
4
+ import uuid
5
+
6
+ from app.url_utils import sanitize_url
7
+
8
+ # --- Auth Schemas ---
9
+ class UserProfile(BaseModel):
10
+ id: uuid.UUID
11
+ display_name: Optional[str] = None
12
+ avatar_url: Optional[str] = None
13
+ created_at: datetime
14
+
15
+ class User(BaseModel):
16
+ id: uuid.UUID
17
+ email: EmailStr
18
+
19
+ # --- Session Schemas ---
20
+ class SessionBase(BaseModel):
21
+ title: str = "Bài toán mới"
22
+
23
+ class SessionCreate(SessionBase):
24
+ pass
25
+
26
+ class Session(SessionBase):
27
+ id: uuid.UUID
28
+ user_id: uuid.UUID
29
+ created_at: datetime
30
+ updated_at: datetime
31
+
32
+ class Config:
33
+ from_attributes = True
34
+
35
+ # --- Message Schemas ---
36
+ class MessageBase(BaseModel):
37
+ role: str
38
+ type: str = "text"
39
+ content: str
40
+ metadata: Dict[str, Any] = {}
41
+
42
+ class MessageCreate(MessageBase):
43
+ session_id: uuid.UUID
44
+
45
+ class Message(MessageBase):
46
+ id: uuid.UUID
47
+ session_id: uuid.UUID
48
+ created_at: datetime
49
+
50
+ class Config:
51
+ from_attributes = True
52
+
53
+ # --- Solve Job Schemas ---
54
+ class SolveRequest(BaseModel):
55
+ text: str
56
+ image_url: Optional[str] = None
57
+
58
+ @field_validator("image_url", mode="before")
59
+ @classmethod
60
+ def _clean_image_url(cls, v):
61
+ return sanitize_url(v) if v is not None else None
62
+
63
+ class SolveResponse(BaseModel):
64
+ job_id: str
65
+ status: str
66
+
67
+ class RenderVideoRequest(BaseModel):
68
+ job_id: Optional[str] = None
69
+
70
+ class RenderVideoResponse(BaseModel):
71
+ job_id: str
72
+ status: str
73
+
74
+
75
+ class OcrPreviewResponse(BaseModel):
76
+ """Stateless OCR preview before POST .../solve (no DB writes, no job)."""
77
+
78
+ ocr_text: str
79
+ user_message: str = ""
80
+ combined_draft: str
app/ocr_celery.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Run OCR on a remote worker via Celery (queue `ocr`) when OCR_USE_CELERY is enabled."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import os
7
+ from typing import TYPE_CHECKING
8
+
9
+ import anyio
10
+
11
+ if TYPE_CHECKING:
12
+ from agents.ocr_agent import OCRAgent
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ def ocr_celery_enabled() -> bool:
18
+ return os.getenv("OCR_USE_CELERY", "").strip().lower() in ("1", "true", "yes", "on")
19
+
20
+
21
+ def _ocr_timeout_sec() -> float:
22
+ raw = os.getenv("OCR_CELERY_TIMEOUT_SEC", "180")
23
+ try:
24
+ return max(30.0, float(raw))
25
+ except ValueError:
26
+ return 180.0
27
+
28
+
29
+ def _run_ocr_celery_sync(image_url: str) -> str:
30
+ from worker.ocr_tasks import run_ocr_from_url
31
+
32
+ async_result = run_ocr_from_url.apply_async(args=[image_url])
33
+ return async_result.get(timeout=_ocr_timeout_sec())
34
+
35
+
36
+ def _is_ocr_error_response(text: str) -> bool:
37
+ s = (text or "").lstrip()
38
+ return s.startswith("Error:")
39
+
40
+
41
+ async def ocr_from_image_url(image_url: str, fallback_agent: "OCRAgent") -> str:
42
+ """
43
+ If OCR_USE_CELERY: delegate to Celery task `run_ocr_from_url` (worker queue `ocr`, raw OCR only),
44
+ then run ``refine_with_llm`` on the API process.
45
+ Else: use fallback_agent.process_url (in-process full pipeline).
46
+ """
47
+ if not ocr_celery_enabled():
48
+ return await fallback_agent.process_url(image_url)
49
+ logger.info("OCR_USE_CELERY: delegating OCR to Celery queue=ocr (LLM refine on API)")
50
+ raw = await anyio.to_thread.run_sync(_run_ocr_celery_sync, image_url)
51
+ raw = raw if raw is not None else ""
52
+ if not raw.strip() or _is_ocr_error_response(raw):
53
+ return raw
54
+ return await fallback_agent.refine_with_llm(raw)
app/ocr_local_file.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """OCR from a local file path, optionally via Celery worker (upload temp blob first)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import os
7
+ from typing import TYPE_CHECKING
8
+
9
+ from app.chat_image_upload import (
10
+ delete_storage_object,
11
+ upload_ephemeral_ocr_blob,
12
+ validate_chat_image_bytes,
13
+ )
14
+ from app.ocr_celery import ocr_celery_enabled, ocr_from_image_url
15
+
16
+ if TYPE_CHECKING:
17
+ from agents.ocr_agent import OCRAgent
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ async def ocr_from_local_image_path(
23
+ local_path: str,
24
+ original_filename: str | None,
25
+ fallback_agent: "OCRAgent",
26
+ ) -> str:
27
+ """
28
+ Run OCR on a file on local disk. If OCR_USE_Celery, upload to ephemeral storage URL
29
+ then delegate to worker; otherwise process_image in-process.
30
+ """
31
+ if not ocr_celery_enabled():
32
+ return await fallback_agent.process_image(local_path)
33
+
34
+ with open(local_path, "rb") as f:
35
+ body = f.read()
36
+ ext = os.path.splitext(original_filename or local_path)[1].lower() or ".png"
37
+ _, content_type = validate_chat_image_bytes(original_filename or local_path, body, None)
38
+ bucket = os.getenv("SUPABASE_IMAGE_BUCKET", "image")
39
+ path, url = upload_ephemeral_ocr_blob(body, ext, content_type)
40
+ try:
41
+ return await ocr_from_image_url(url, fallback_agent)
42
+ finally:
43
+ delete_storage_object(bucket, path)
app/ocr_text_merge.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Helpers for OCR preview combined draft (no Pydantic email deps)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Optional
6
+
7
+
8
+ def build_combined_ocr_preview_draft(user_message: Optional[str], ocr_text: str) -> str:
9
+ """Merge user caption and OCR text for confirm step (user message first, then OCR)."""
10
+ u = (user_message or "").strip()
11
+ o = (ocr_text or "").strip()
12
+ if u and o:
13
+ return f"{u}\n\n{o}"
14
+ return u or o
app/routers/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from . import auth, sessions, solve
app/routers/auth.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, Depends, HTTPException
2
+ from app.dependencies import get_current_user_id
3
+ from app.supabase_client import get_supabase
4
+ from app.models.schemas import UserProfile
5
+ import uuid
6
+
7
+ router = APIRouter(prefix="/api/v1/auth", tags=["Auth"])
8
+
9
+ @router.get("/me")
10
+ async def get_me(user_id=Depends(get_current_user_id)):
11
+ """获取当前登录用户的信息 (Retrieve current user profile)"""
12
+ supabase = get_supabase()
13
+ res = supabase.table("profiles").select("*").eq("id", user_id).execute()
14
+ if not res.data:
15
+ raise HTTPException(status_code=404, detail="Profile not found.")
16
+ return res.data[0]
17
+
18
+ @router.patch("/me")
19
+ async def update_me(data: dict, user_id=Depends(get_current_user_id)):
20
+ """Cập nhật profile hiện tại (Update current profile)"""
21
+ supabase = get_supabase()
22
+ res = supabase.table("profiles").update(data).eq("id", user_id).execute()
23
+ return res.data[0]
app/routers/sessions.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import List
4
+
5
+ from fastapi import APIRouter, Depends, HTTPException
6
+
7
+ from app.dependencies import get_current_user_id
8
+ from app.logutil import log_step
9
+ from app.session_cache import (
10
+ get_sessions_list_cached,
11
+ invalidate_for_user,
12
+ invalidate_session_owner,
13
+ session_owned_by_user,
14
+ )
15
+ from app.supabase_client import get_supabase
16
+
17
+ router = APIRouter(prefix="/api/v1/sessions", tags=["Sessions"])
18
+
19
+
20
+ @router.get("", response_model=List[dict])
21
+ async def list_sessions(user_id=Depends(get_current_user_id)):
22
+ """Danh sách các phiên chat của người dùng (List user's chat sessions)"""
23
+ supabase = get_supabase()
24
+
25
+ def fetch() -> list:
26
+ res = (
27
+ supabase.table("sessions")
28
+ .select("*")
29
+ .eq("user_id", user_id)
30
+ .order("updated_at", desc=True)
31
+ .execute()
32
+ )
33
+ log_step("db_select", table="sessions", op="list", user_id=str(user_id))
34
+ return res.data
35
+
36
+ return get_sessions_list_cached(str(user_id), fetch)
37
+
38
+
39
+ @router.post("", response_model=dict)
40
+ async def create_session(user_id=Depends(get_current_user_id)):
41
+ """Tạo một phiên chat mới (Create a new chat session)"""
42
+ supabase = get_supabase()
43
+ res = supabase.table("sessions").insert(
44
+ {"user_id": user_id, "title": "Bài toán mới"}
45
+ ).execute()
46
+ log_step("db_insert", table="sessions", op="create")
47
+ invalidate_for_user(str(user_id))
48
+ return res.data[0]
49
+
50
+
51
+ @router.get("/{session_id}/messages", response_model=List[dict])
52
+ async def get_session_messages(session_id: str, user_id=Depends(get_current_user_id)):
53
+ """Lấy toàn bộ lịch sử tin nhắn của một phiên (Get chat history for a session)"""
54
+ supabase = get_supabase()
55
+
56
+ def owns() -> bool:
57
+ res = (
58
+ supabase.table("sessions")
59
+ .select("id")
60
+ .eq("id", session_id)
61
+ .eq("user_id", user_id)
62
+ .execute()
63
+ )
64
+ log_step("db_select", table="sessions", op="owner_check", session_id=session_id)
65
+ return bool(res.data)
66
+
67
+ if not session_owned_by_user(session_id, str(user_id), owns):
68
+ raise HTTPException(
69
+ status_code=403, detail="Forbidden: You do not own this session."
70
+ )
71
+
72
+ res = (
73
+ supabase.table("messages")
74
+ .select("*")
75
+ .eq("session_id", session_id)
76
+ .order("created_at", desc=False)
77
+ .execute()
78
+ )
79
+ log_step("db_select", table="messages", op="list", session_id=session_id)
80
+ return res.data
81
+
82
+
83
+ @router.delete("/{session_id}")
84
+ async def delete_session(session_id: str, user_id=Depends(get_current_user_id)):
85
+ """Xóa một phiên chat (Delete a chat session)"""
86
+ supabase = get_supabase()
87
+
88
+ def owns() -> bool:
89
+ res = (
90
+ supabase.table("sessions")
91
+ .select("id")
92
+ .eq("id", session_id)
93
+ .eq("user_id", user_id)
94
+ .execute()
95
+ )
96
+ return bool(res.data)
97
+
98
+ if not session_owned_by_user(session_id, str(user_id), owns):
99
+ raise HTTPException(
100
+ status_code=403, detail="Forbidden: You do not own this session."
101
+ )
102
+
103
+ # jobs.session_id FK must be cleared before sessions row
104
+ supabase.table("jobs").delete().eq("session_id", session_id).eq("user_id", user_id).execute()
105
+ log_step("db_delete", table="jobs", op="by_session", session_id=session_id)
106
+ supabase.table("messages").delete().eq("session_id", session_id).execute()
107
+ log_step("db_delete", table="messages", op="by_session", session_id=session_id)
108
+ res = (
109
+ supabase.table("sessions")
110
+ .delete()
111
+ .eq("id", session_id)
112
+ .eq("user_id", user_id)
113
+ .execute()
114
+ )
115
+ log_step("db_delete", table="sessions", session_id=session_id)
116
+ invalidate_for_user(str(user_id))
117
+ invalidate_session_owner(session_id, str(user_id))
118
+ return {"status": "ok", "deleted_id": session_id}
119
+
120
+
121
+ @router.patch("/{session_id}/title")
122
+ async def update_session_title(title: str, session_id: str, user_id=Depends(get_current_user_id)):
123
+ """Cập nhật tiêu đề phiên chat (Rename a chat session)"""
124
+ supabase = get_supabase()
125
+ res = (
126
+ supabase.table("sessions")
127
+ .update({"title": title})
128
+ .eq("id", session_id)
129
+ .eq("user_id", user_id)
130
+ .execute()
131
+ )
132
+ log_step("db_update", table="sessions", op="title", session_id=session_id)
133
+ invalidate_for_user(str(user_id))
134
+ return res.data[0]
135
+
136
+
137
+ @router.get("/{session_id}/assets", response_model=List[dict])
138
+ async def get_session_assets(session_id: str, user_id=Depends(get_current_user_id)):
139
+ """Lấy danh sách video đã render trong session (Get versioned assets for a session)"""
140
+ supabase = get_supabase()
141
+
142
+ def owns() -> bool:
143
+ res = (
144
+ supabase.table("sessions")
145
+ .select("id")
146
+ .eq("id", session_id)
147
+ .eq("user_id", user_id)
148
+ .execute()
149
+ )
150
+ return bool(res.data)
151
+
152
+ if not session_owned_by_user(session_id, str(user_id), owns):
153
+ raise HTTPException(
154
+ status_code=403, detail="Forbidden: You do not own this session."
155
+ )
156
+
157
+ res = (
158
+ supabase.table("session_assets")
159
+ .select("*")
160
+ .eq("session_id", session_id)
161
+ .order("version", desc=True)
162
+ .execute()
163
+ )
164
+ log_step("db_select", table="session_assets", op="list", session_id=session_id)
165
+ return res.data
app/routers/solve.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import os
5
+ import uuid
6
+
7
+ from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, HTTPException, UploadFile
8
+
9
+ from agents.orchestrator import Orchestrator
10
+ from app.chat_image_upload import upload_session_chat_image, validate_chat_image_bytes
11
+ from app.ocr_celery import ocr_celery_enabled
12
+ from app.ocr_local_file import ocr_from_local_image_path
13
+ from app.dependencies import get_current_user_id
14
+ from app.errors import format_error_for_user
15
+ from app.logutil import log_pipeline_failure, log_pipeline_success, log_step
16
+ from app.models.schemas import (
17
+ OcrPreviewResponse,
18
+ RenderVideoRequest,
19
+ RenderVideoResponse,
20
+ SolveRequest,
21
+ SolveResponse,
22
+ )
23
+ from app.ocr_text_merge import build_combined_ocr_preview_draft
24
+ from app.session_cache import invalidate_for_user, session_owned_by_user
25
+ from app.supabase_client import get_supabase
26
+
27
+ logger = logging.getLogger(__name__)
28
+ router = APIRouter(prefix="/api/v1/sessions", tags=["Solve"])
29
+
30
+ # Eager init: all agents and models load at import time (also run in Docker build via scripts/prewarm_models.py).
31
+ ORCHESTRATOR = Orchestrator()
32
+
33
+
34
def get_orchestrator() -> Orchestrator:
    """Return the module-level Orchestrator singleton (built eagerly at import time)."""
    return ORCHESTRATOR
36
+
37
+
38
+ _OCR_PREVIEW_MAX_BYTES = 10 * 1024 * 1024
39
+
40
+
41
def _assert_session_owner(supabase, session_id: str, user_id, uid: str, op: str) -> None:
    """Raise HTTP 403 unless `session_id` belongs to `user_id` (TTL-cached check)."""

    def owns() -> bool:
        # Only executed on a cache miss; result is cached per (session, user).
        result = (
            supabase.table("sessions")
            .select("id")
            .eq("id", session_id)
            .eq("user_id", user_id)
            .execute()
        )
        log_step("db_select", table="sessions", op=op, session_id=session_id)
        return bool(result.data)

    if session_owned_by_user(session_id, uid, owns):
        return
    log_pipeline_failure("solve_request", error="forbidden", session_id=session_id)
    raise HTTPException(
        status_code=403, detail="Forbidden: You do not own this session."
    )
58
+
59
+
60
def _enqueue_solve_common(
    supabase,
    background_tasks: BackgroundTasks,
    session_id: str,
    user_id,
    uid: str,
    request: SolveRequest,
    message_metadata: dict,
    job_id: str,
) -> SolveResponse:
    """Insert user message, job row, enqueue pipeline; update title when first message.

    Args:
        supabase: Service-role Supabase client.
        background_tasks: FastAPI background task queue for the solve pipeline.
        session_id: Target chat session.
        user_id: Owner id as produced by the auth dependency.
        uid: String form of user_id, used as the cache-invalidation key.
        request: Solve payload (text and optional image_url).
        message_metadata: Extra metadata stored on the user message row.
        job_id: Pre-generated UUID shared by the job row and the pipeline.

    Returns:
        SolveResponse carrying the job id and "processing" status.
    """
    # 1) Persist the user's message in the chat history.
    supabase.table("messages").insert(
        {
            "session_id": session_id,
            "role": "user",
            "type": "text",
            "content": request.text,
            "metadata": message_metadata,
        }
    ).execute()
    log_step("db_insert", table="messages", op="user_message", session_id=session_id)

    # 2) Create the job row up-front so clients can poll/subscribe immediately.
    supabase.table("jobs").insert(
        {
            "id": job_id,
            "user_id": user_id,
            "session_id": session_id,
            "status": "processing",
            "input_text": request.text,
        }
    ).execute()
    log_step("db_insert", table="jobs", job_id=job_id)

    # 3) Run the solve pipeline outside the request/response cycle.
    background_tasks.add_task(process_session_job, job_id, session_id, request, str(user_id))

    # 4) First message in a fresh session: derive the title from the prompt.
    #    "Bài toán mới" is the DB default title (see migrations), used as the
    #    "still untitled" sentinel here — keep the two in sync.
    title_check = supabase.table("sessions").select("title").eq("id", session_id).execute()
    if title_check.data and title_check.data[0]["title"] == "Bài toán mới":
        new_title = request.text[:50] + ("..." if len(request.text) > 50 else "")
        supabase.table("sessions").update({"title": new_title}).eq("id", session_id).execute()
        log_step("db_update", table="sessions", op="title_from_first_message")
        invalidate_for_user(uid)

    log_pipeline_success("solve_accepted", job_id=job_id, session_id=session_id)
    return SolveResponse(job_id=job_id, status="processing")
104
+
105
+
106
@router.post("/{session_id}/ocr_preview", response_model=OcrPreviewResponse)
async def ocr_preview(
    session_id: str,
    user_id=Depends(get_current_user_id),
    file: UploadFile = File(...),
    user_message: str | None = Form(None),
):
    """
    Run OCR on an uploaded image and merge with optional user_message into combined_draft.
    Does not insert messages or start a solve job. After user confirms, call POST .../solve
    with text=combined_draft (edited) and omit image_url to avoid double OCR.

    Raises:
        HTTPException: 403 (not owner), 413 (image too large), 400 (empty file).
    """
    import tempfile  # local import: avoids touching the module import block

    supabase = get_supabase()
    uid = str(user_id)
    _assert_session_owner(supabase, session_id, user_id, uid, "owner_check_ocr_preview")

    body = await file.read()
    if len(body) > _OCR_PREVIEW_MAX_BYTES:
        raise HTTPException(
            status_code=413,
            detail=f"Image too large (max {_OCR_PREVIEW_MAX_BYTES // (1024 * 1024)} MB).",
        )
    if not body:
        raise HTTPException(status_code=400, detail="Empty file.")

    # NOTE(review): validation only runs on the Celery path in the original;
    # behavior preserved, but confirm the local-OCR path is meant to skip it.
    if ocr_celery_enabled():
        validate_chat_image_bytes(file.filename, body, file.content_type)

    suffix = os.path.splitext(file.filename or "")[1].lower()
    if suffix not in (".png", ".jpg", ".jpeg", ".webp", ".gif", ".bmp", ""):
        suffix = ".png"
    # Fix: write to the OS temp dir with a race-free unique name instead of a
    # hand-built filename in the process CWD (CWD may be read-only in
    # containers, and manual naming can collide under concurrency).
    tmp = tempfile.NamedTemporaryFile(
        prefix="temp_ocr_preview_", suffix=suffix or ".png", delete=False
    )
    temp_path = tmp.name
    try:
        with tmp:
            tmp.write(body)
        ocr_text = await ocr_from_local_image_path(
            temp_path, file.filename, get_orchestrator().ocr_agent
        )
        if ocr_text is None:
            ocr_text = ""
    finally:
        # Always clean up the scratch file, even when OCR raises.
        if os.path.exists(temp_path):
            os.remove(temp_path)

    um = (user_message or "").strip()
    combined = build_combined_ocr_preview_draft(user_message, ocr_text)
    log_step("ocr_preview_done", session_id=session_id, ocr_len=len(ocr_text), user_len=len(um))
    return OcrPreviewResponse(
        ocr_text=ocr_text,
        user_message=um,
        combined_draft=combined,
    )
158
+
159
+
160
@router.post("/{session_id}/solve", response_model=SolveResponse)
async def solve_problem(
    session_id: str,
    request: SolveRequest,
    background_tasks: BackgroundTasks,
    user_id=Depends(get_current_user_id),
):
    """Submit a geometry problem in a session: the question is saved to history
    and a background solve job is started (solve + static figure only).
    """
    supabase = get_supabase()
    uid = str(user_id)
    _assert_session_owner(supabase, session_id, user_id, uid, "owner_check")

    metadata: dict = {}
    if request.image_url:
        metadata = {"image_url": request.image_url}

    return _enqueue_solve_common(
        supabase,
        background_tasks,
        session_id,
        user_id,
        uid,
        request,
        metadata,
        str(uuid.uuid4()),
    )
187
+
188
+
189
@router.post("/{session_id}/solve_multipart", response_model=SolveResponse)
async def solve_multipart(
    session_id: str,
    background_tasks: BackgroundTasks,
    user_id=Depends(get_current_user_id),
    text: str = Form(...),
    file: UploadFile = File(...),
):
    """Submit text plus an image in one multipart request: validate the image,
    upload it to the `image` bucket, record a session asset, store the message
    with attachment metadata (URL, size, type), then enqueue the solve with
    image_url pointing at the public URL so the orchestrator can OCR it.
    """
    supabase = get_supabase()
    uid = str(user_id)
    _assert_session_owner(supabase, session_id, user_id, uid, "owner_check_solve_multipart")

    prompt = (text or "").strip()
    if not prompt:
        raise HTTPException(status_code=400, detail="text must not be empty.")

    body = await file.read()
    ext, content_type = validate_chat_image_bytes(file.filename, body, file.content_type)

    job_id = str(uuid.uuid4())
    uploaded = upload_session_chat_image(session_id, job_id, body, ext, content_type)
    public_url = uploaded["public_url"]

    attachment = {
        "public_url": public_url,
        "storage_path": uploaded["storage_path"],
        "size_bytes": len(body),
        "content_type": content_type,
        "original_filename": file.filename or "",
        "session_asset_id": uploaded.get("session_asset_id"),
    }
    return _enqueue_solve_common(
        supabase,
        background_tasks,
        session_id,
        user_id,
        uid,
        SolveRequest(text=prompt, image_url=public_url),
        {"image_url": public_url, "attachment": attachment},
        job_id,
    )
239
+
240
+
241
@router.post("/{session_id}/render_video", response_model=RenderVideoResponse)
async def render_video(
    session_id: str,
    request: RenderVideoRequest,
    background_tasks: BackgroundTasks,
    user_id=Depends(get_current_user_id),
):
    """Request a Manim video render from the session's latest geometry state.

    Scans the 10 most recent assistant messages for geometry metadata
    (optionally pinned to request.job_id), creates a rendering job row, and
    dispatches the render in the background.

    Raises:
        HTTPException: 403 when the caller does not own the session;
            404 when no renderable geometry metadata is found.
    """
    supabase = get_supabase()

    # 1. Ownership check.
    res = supabase.table("sessions").select("id").eq("id", session_id).eq("user_id", user_id).execute()
    if not res.data:
        raise HTTPException(status_code=403, detail="Forbidden: You do not own this session.")

    # 2. Find an assistant message carrying geometry metadata (the requested
    #    job_id if given, else the newest of the last 10 messages).
    msg_res = (
        supabase.table("messages")
        .select("metadata")
        .eq("session_id", session_id)
        .eq("role", "assistant")
        .order("created_at", desc=True)
        .limit(10)
        .execute()
    )

    latest_geometry = None
    if msg_res.data:
        for msg in msg_res.data:
            # Bug fix: metadata may be stored as NULL; .get("metadata", {})
            # then returns None and the following .get() calls crash. `or {}`
            # normalizes both missing and null metadata.
            meta = msg.get("metadata") or {}
            # When a specific job_id is requested, only that job's message counts.
            if request.job_id and meta.get("job_id") != request.job_id:
                continue

            # The message must carry renderable geometry data.
            if meta.get("geometry_dsl") and meta.get("coordinates"):
                latest_geometry = meta
                break

    if not latest_geometry:
        raise HTTPException(status_code=404, detail="Không tìm thấy dữ liệu hình học để render video.")

    # 3. Create the rendering job row.
    job_id = str(uuid.uuid4())
    supabase.table("jobs").insert({
        "id": job_id,
        "user_id": user_id,
        "session_id": session_id,
        "status": "rendering_queued",
        "input_text": f"Render video requested at {job_id}",
    }).execute()

    # 4. Dispatch the background render task.
    background_tasks.add_task(process_render_job, job_id, session_id, latest_geometry)

    return RenderVideoResponse(job_id=job_id, status="rendering_queued")
299
+
300
+
301
async def process_session_job(
    job_id: str, session_id: str, request: SolveRequest, user_id: str
):
    """Background solve pipeline: run the orchestrator and produce a static figure.

    Loads the session's full message history, runs the orchestrator with a
    WebSocket status callback, persists the job result plus an assistant
    message, and notifies subscribers. Any exception is converted into an
    error job status / error message and pushed over WebSocket as well.
    """
    # Imported lazily to avoid a circular import with the app entrypoint.
    from app.websocket_manager import notify_status

    async def status_update(status: str):
        # Relay intermediate orchestrator statuses to subscribed clients.
        await notify_status(job_id, {"status": status})

    supabase = get_supabase()
    try:
        # Full chronological history gives the orchestrator conversational context.
        history_res = (
            supabase.table("messages")
            .select("*")
            .eq("session_id", session_id)
            .order("created_at", desc=False)
            .execute()
        )
        history = history_res.data if history_res.data else []

        result = await get_orchestrator().run(
            request.text,
            request.image_url,
            job_id=job_id,
            session_id=session_id,
            status_callback=status_update,
            history=history,
        )

        # An "error" key in the result always wins over any reported status.
        status = result.get("status", "error") if "error" not in result else "error"

        supabase.table("jobs").update({"status": status, "result": result}).eq(
            "id", job_id
        ).execute()

        # Assistant message: analysis text on success, error text otherwise;
        # geometry metadata is embedded for later render_video requests.
        supabase.table("messages").insert(
            {
                "session_id": session_id,
                "role": "assistant",
                "type": "analysis" if "error" not in result else "error",
                "content": (
                    result.get("semantic_analysis", "Đã có lỗi xảy ra.")
                    if "error" not in result
                    else result["error"]
                ),
                "metadata": {
                    "job_id": job_id,
                    "coordinates": result.get("coordinates"),
                    "geometry_dsl": result.get("geometry_dsl"),
                    "polygon_order": result.get("polygon_order", []),
                    "drawing_phases": result.get("drawing_phases", []),
                    "circles": result.get("circles", []),
                    "lines": result.get("lines", []),
                    "rays": result.get("rays", []),
                    "solution": result.get("solution"),
                    "is_3d": result.get("is_3d", False),
                },
            }
        ).execute()

        await notify_status(job_id, {"status": status, "result": result})

    except Exception as e:
        logger.exception("Error processing session job %s", job_id)
        error_msg = format_error_for_user(e)
        # Fresh client: the earlier handle may be implicated in the failure.
        supabase = get_supabase()
        supabase.table("jobs").update(
            {"status": "error", "result": {"error": str(e)}}
        ).eq("id", job_id).execute()
        supabase.table("messages").insert(
            {
                "session_id": session_id,
                "role": "assistant",
                "type": "error",
                "content": error_msg,
                "metadata": {"job_id": job_id},
            }
        ).execute()
        await notify_status(job_id, {"status": "error", "error": error_msg})
380
+
381
async def process_render_job(job_id: str, session_id: str, geometry_data: dict):
    """Background video render: dispatch a Celery task built from saved geometry metadata."""
    # Lazy imports: websocket_manager avoids a circular import with main;
    # worker.tasks keeps Celery out of the API's import path until needed.
    from app.websocket_manager import notify_status
    from worker.tasks import render_geometry_video

    await notify_status(job_id, {"status": "rendering_queued"})

    # Prepare payload for Celery (similar to what orchestrator used to do)
    result_payload = {
        "geometry_dsl": geometry_data.get("geometry_dsl"),
        "coordinates": geometry_data.get("coordinates"),
        "polygon_order": geometry_data.get("polygon_order", []),
        "drawing_phases": geometry_data.get("drawing_phases", []),
        "circles": geometry_data.get("circles", []),
        "lines": geometry_data.get("lines", []),
        "rays": geometry_data.get("rays", []),
        "semantic": geometry_data.get("semantic", {}),
        "semantic_analysis": geometry_data.get("semantic_analysis", "🎬 Video minh họa dựng từ trạng thái gần nhất."),
        "session_id": session_id,
    }

    try:
        logger.info(f"[RenderJob] Attempting to dispatch Celery task for job {job_id}...")
        # .delay() only enqueues; actual rendering happens in the worker.
        render_geometry_video.delay(job_id, result_payload)
        logger.info(f"[RenderJob] SUCCESS: Dispatched Celery task for job {job_id}")
    except Exception as e:
        # Dispatch failure (e.g. broker down): mark the job failed and notify.
        logger.exception(f"[RenderJob] FAILED to dispatch Celery task: {e}")
        supabase = get_supabase()
        supabase.table("jobs").update({"status": "error", "result": {"error": f"Task dispatch failed: {str(e)}"}}).eq("id", job_id).execute()
        await notify_status(job_id, {"status": "error", "error": str(e)})
app/runtime_env.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Default process env vars (Paddle/OpenMP). Call as early as possible after load_dotenv."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+
7
+
8
def apply_runtime_env_defaults() -> None:
    """Force single-threaded math runtimes (Paddle reads OMP_NUM_THREADS at import)."""
    # Hard overwrite rather than setdefault: the platform may have pre-set 2+.
    for var in ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS"):
        os.environ[var] = "1"
app/session_cache.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """TTL in-memory cache để giảm truy vấn Supabase lặp lại (list session, quyền sở hữu session)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Callable
6
+
7
+ from cachetools import TTLCache
8
+
9
+ from app.logutil import log_step
10
+
11
+ _session_list: TTLCache[str, list[Any]] = TTLCache(maxsize=512, ttl=45)
12
+ _session_owner: TTLCache[tuple[str, str], bool] = TTLCache(maxsize=4096, ttl=45)
13
+
14
+
15
def invalidate_for_user(user_id: str) -> None:
    """Drop the cached session list for a user (after create / delete / rename, or when a solve changes the title)."""
    _session_list.pop(user_id, None)  # pop with default: no-op when not cached
    log_step("cache_invalidate", target="session_list", user_id=user_id)
19
+
20
+
21
def invalidate_session_owner(session_id: str, user_id: str) -> None:
    """Drop the cached ownership flag for (session_id, user_id), e.g. after session deletion."""
    _session_owner.pop((session_id, user_id), None)
    log_step("cache_invalidate", target="session_owner", session_id=session_id, user_id=user_id)
24
+
25
+
26
def get_sessions_list_cached(user_id: str, fetch: Callable[[], list[Any]]) -> list[Any]:
    """Return the cached session list for `user_id`, calling `fetch` on miss.

    EAFP fix: with a TTLCache, an entry can expire between the `in` membership
    test and the indexed read, raising an unexpected KeyError; a single read
    inside try/except closes that window.
    """
    try:
        data = _session_list[user_id]
    except KeyError:
        log_step("cache_miss", kind="session_list", user_id=user_id)
        data = fetch()
        _session_list[user_id] = data
    else:
        log_step("cache_hit", kind="session_list", user_id=user_id)
    return data
34
+
35
+
36
def session_owned_by_user(
    session_id: str,
    user_id: str,
    fetch: Callable[[], bool],
) -> bool:
    """Cached check that `session_id` belongs to `user_id`; `fetch` runs on miss.

    EAFP fix: with a TTLCache, the entry can expire between the `in` test and
    the indexed read (KeyError under load); read once inside try/except.
    """
    key = (session_id, user_id)
    try:
        ok = _session_owner[key]
    except KeyError:
        log_step("cache_miss", kind="session_owner", session_id=session_id)
        ok = fetch()
        _session_owner[key] = ok
    else:
        log_step("cache_hit", kind="session_owner", session_id=session_id)
    return ok
app/supabase_client.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from supabase import Client, ClientOptions, create_client
3
+ from supabase_auth import SyncMemoryStorage
4
+ from dotenv import load_dotenv
5
+
6
+ load_dotenv()
7
+
8
+ from app.url_utils import sanitize_env
9
+
10
+
11
def get_supabase() -> Client:
    """Service-role client for server-side operations (bypasses RLS when policies expect service role)."""
    url = sanitize_env(os.getenv("SUPABASE_URL"))
    key = sanitize_env(os.getenv("SUPABASE_SERVICE_ROLE_KEY") or os.getenv("SUPABASE_KEY"))
    if url and key:
        return create_client(url, key)
    raise RuntimeError(
        "SUPABASE_URL and SUPABASE_SERVICE_ROLE_KEY (or SUPABASE_KEY) must be set"
    )
20
+
21
+
22
def get_supabase_for_user_jwt(access_token: str) -> Client:
    """
    Client scoped to the logged-in user: PostgREST sends the user's JWT so RLS applies.
    Use SUPABASE_ANON_KEY (publishable), not the service role key.
    """
    url = sanitize_env(os.getenv("SUPABASE_URL"))
    anon = sanitize_env(os.getenv("SUPABASE_ANON_KEY") or os.getenv("NEXT_PUBLIC_SUPABASE_ANON_KEY"))
    if not (url and anon):
        raise RuntimeError(
            "SUPABASE_URL and SUPABASE_ANON_KEY (or NEXT_PUBLIC_SUPABASE_ANON_KEY) must be set "
            "for user-scoped Supabase access"
        )
    # Start from the default header set, then attach the user's bearer token.
    headers = dict(ClientOptions(storage=SyncMemoryStorage()).headers)
    headers["Authorization"] = f"Bearer {access_token}"
    opts = ClientOptions(storage=SyncMemoryStorage(), headers=headers)
    return create_client(url, anon, opts)
app/url_utils.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Normalize URLs / env strings (HF secrets and copy-paste often include trailing newlines)."""
2
+
3
+
4
+ def sanitize_url(value: str | None) -> str | None:
5
+ if value is None:
6
+ return None
7
+ s = value.strip().replace("\r", "").replace("\n", "").replace("\t", "")
8
+ return s or None
9
+
10
+
11
+ def sanitize_env(value: str | None) -> str | None:
12
+ """Strip whitespace and line breaks from environment-backed strings."""
13
+ return sanitize_url(value)
14
+
15
+
16
+ # OpenAI SDK (>=1.x) requires a non-empty api_key at client construction (Docker build / prewarm has no secrets).
17
+ _OPENAI_API_KEY_BUILD_PLACEHOLDER = "build-placeholder-openrouter-not-for-production"
18
+
19
+
20
+ def openai_compatible_api_key(raw: str | None) -> str:
21
+ """Return sanitized API key, or a placeholder so AsyncOpenAI() can be constructed without env at build time."""
22
+ k = sanitize_env(raw)
23
+ return k if k else _OPENAI_API_KEY_BUILD_PLACEHOLDER
app/websocket_manager.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """WebSocket connection registry and job status notifications (avoid circular imports with main)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from typing import Dict, List
7
+
8
+ from fastapi import WebSocket, WebSocketDisconnect
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ active_connections: Dict[str, List[WebSocket]] = {}
13
+
14
+
15
async def notify_status(job_id: str, data: dict) -> None:
    """Push `data` to every live WebSocket subscribed to `job_id` (best effort)."""
    connections = active_connections.get(job_id)
    if not connections:
        return
    # Iterate a snapshot: the registry may mutate while we await sends.
    for ws in list(connections):
        try:
            await ws.send_json(data)
        except Exception as e:
            # One broken socket must not block notifications to the others.
            logger.error("WS error sending to %s: %s", job_id, e)
+
24
+
25
def register_websocket_routes(app) -> None:
    """Attach the per-job websocket endpoint to the FastAPI app."""

    @app.websocket("/ws/{job_id}")
    async def websocket_endpoint(websocket: WebSocket, job_id: str) -> None:
        await websocket.accept()
        # Register under the job id; setdefault creates the list on first use.
        active_connections.setdefault(job_id, []).append(websocket)
        try:
            while True:
                # Keep the connection alive; inbound messages are ignored.
                await websocket.receive_text()
        except WebSocketDisconnect:
            subscribers = active_connections[job_id]
            subscribers.remove(websocket)
            if not subscribers:
                del active_connections[job_id]
clean_ports.sh ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Script to kill all project-related processes for a clean restart

echo "🧹 Cleaning up project processes..."

# Kill things on ports 8000 (Backend), 3000 (Frontend), 11020
PORTS="8000 3000 11020"
for PORT in $PORTS; do
    PIDS=$(lsof -ti :"$PORT")
    # Idiom fix: [ -n ... ] instead of [ ! -z ... ]; quote the port variable.
    if [ -n "$PIDS" ]; then
        echo "Killing processes on port $PORT: $PIDS"
        # $PIDS intentionally unquoted: word-splitting passes each PID separately.
        kill -9 $PIDS 2>/dev/null
    fi
done

# Kill by process name
echo "Killing any remaining Celery, Uvicorn, or Manim processes..."
pkill -9 -f "celery" 2>/dev/null
pkill -9 -f "uvicorn" 2>/dev/null
pkill -9 -f "manim" 2>/dev/null

echo "✅ Done. You can now restart your Backend, Worker, and Frontend."
dump.rdb ADDED
Binary file (5.44 kB). View file
 
migrations/add_image_bucket_storage.sql ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -- ============================================================
2
+ -- MathSolver: Supabase Storage bucket `image` (chat / OCR attachments)
3
+ -- Run after session_assets and storage.video policies exist.
4
+ -- ============================================================
5
+
6
+ INSERT INTO storage.buckets (id, name, public)
7
+ VALUES ('image', 'image', true)
8
+ ON CONFLICT (id) DO UPDATE SET public = true;
9
+
10
+ -- Service role: upload/delete/list for API + workers
11
+ DROP POLICY IF EXISTS "Service Role manage images" ON storage.objects;
12
+ CREATE POLICY "Service Role manage images" ON storage.objects
13
+ FOR ALL
14
+ TO service_role
15
+ USING (bucket_id = 'image')
16
+ WITH CHECK (bucket_id = 'image');
17
+
18
+ -- Authenticated: read only objects under sessions they own (path sessions/{session_id}/...)
19
+ DROP POLICY IF EXISTS "Users view session images" ON storage.objects;
20
+ CREATE POLICY "Users view session images" ON storage.objects
21
+ FOR SELECT
22
+ TO authenticated
23
+ USING (
24
+ bucket_id = 'image'
25
+ AND (storage.foldername(name))[2] IN (
26
+ SELECT id::text FROM public.sessions WHERE user_id = auth.uid()
27
+ )
28
+ );
29
+
30
+ -- Public read for get_public_url / FE img tags (same model as video bucket)
31
+ DROP POLICY IF EXISTS "Public read images" ON storage.objects;
32
+ CREATE POLICY "Public read images" ON storage.objects
33
+ FOR SELECT
34
+ TO public
35
+ USING (bucket_id = 'image');
migrations/fix_rls_assets.sql ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -- ============================================================
2
+ -- FIX RLS & SESSION ASSETS (MathSolver v5.1 Worker Fix)
3
+ -- ============================================================
4
+
5
+ -- 1. Ensure session_assets table exists
6
+ CREATE TABLE IF NOT EXISTS public.session_assets (
7
+ id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
8
+ session_id UUID NOT NULL REFERENCES public.sessions(id) ON DELETE CASCADE,
9
+ job_id UUID NOT NULL,
10
+ asset_type TEXT NOT NULL CHECK (asset_type IN ('video', 'image')),
11
+ storage_path TEXT NOT NULL,
12
+ public_url TEXT NOT NULL,
13
+ version INTEGER NOT NULL DEFAULT 1,
14
+ created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
15
+ );
16
+
17
+ -- Index for session_assets
18
+ CREATE INDEX IF NOT EXISTS idx_session_assets_session_id ON public.session_assets(session_id);
19
+ CREATE INDEX IF NOT EXISTS idx_session_assets_type ON public.session_assets(session_id, asset_type);
20
+
21
+ -- 2. Enable RLS for all tables
22
+ ALTER TABLE public.session_assets ENABLE ROW LEVEL SECURITY;
23
+ ALTER TABLE public.profiles ENABLE ROW LEVEL SECURITY;
24
+ ALTER TABLE public.sessions ENABLE ROW LEVEL SECURITY;
25
+ ALTER TABLE public.messages ENABLE ROW LEVEL SECURITY;
26
+ ALTER TABLE public.jobs ENABLE ROW LEVEL SECURITY;
27
+
28
+
29
+ -- 3. Fix Table Policies to allow SERVICE ROLE
30
+ -- In Supabase, service_role usually bypasses RLS, but we add explicit policies for safety
31
+ -- especially for path-based checks or when SECURITY DEFINER functions are used.
32
+
33
+ -- [Session Assets]
34
+ DROP POLICY IF EXISTS "Users view own assets" ON public.session_assets;
35
+ CREATE POLICY "Users view own assets" ON public.session_assets
36
+ FOR SELECT USING (
37
+ session_id IN (SELECT id FROM public.sessions WHERE user_id = auth.uid())
38
+ );
39
+
40
+ DROP POLICY IF EXISTS "Service role manages assets" ON public.session_assets;
41
+ CREATE POLICY "Service role manages assets" ON public.session_assets
42
+ FOR ALL USING (true)
43
+ WITH CHECK (true);
44
+
45
+
46
+ -- [Messages] - Allow Worker to insert assistant messages
47
+ DROP POLICY IF EXISTS "Users manage own messages" ON public.messages;
48
+ CREATE POLICY "Users manage own messages" ON public.messages
49
+ FOR ALL USING (
50
+ session_id IN (SELECT id FROM public.sessions WHERE user_id = auth.uid())
51
+ OR
52
+ (auth.jwt() ->> 'role' = 'service_role')
53
+ );
54
+
55
+
56
+ -- [Jobs] - Allow Worker to update job status
57
+ DROP POLICY IF EXISTS "Users manage own jobs" ON public.jobs;
58
+ CREATE POLICY "Users manage own jobs" ON public.jobs
59
+ FOR ALL USING (
60
+ auth.uid() = user_id
61
+ OR user_id IS NULL
62
+ OR (auth.jwt() ->> 'role' = 'service_role')
63
+ );
64
+
65
+
66
+ -- 4. Storage Policies (Bucket: video)
67
+ -- Ensure 'video' bucket exists
68
+ INSERT INTO storage.buckets (id, name, public)
69
+ VALUES ('video', 'video', true)
70
+ ON CONFLICT (id) DO UPDATE SET public = true;
71
+
72
+ -- [Storage: Worker / Service Role] - Allow all in video bucket
73
+ DROP POLICY IF EXISTS "Service Role manage videos" ON storage.objects;
74
+ CREATE POLICY "Service Role manage videos" ON storage.objects
75
+ FOR ALL
76
+ TO service_role
77
+ USING (bucket_id = 'video');
78
+
79
+ -- [Storage: Users] - Allow users to view their session videos
80
+ DROP POLICY IF EXISTS "Users view session videos" ON storage.objects;
81
+ CREATE POLICY "Users view session videos" ON storage.objects
82
+ FOR SELECT
83
+ TO authenticated
84
+ USING (
85
+ bucket_id = 'video'
86
+ AND (storage.foldername(name))[2] IN (
87
+ SELECT id::text FROM public.sessions WHERE user_id = auth.uid()
88
+ )
89
+ );
90
+
91
+ -- [Storage: Public] - Allow public read access to videos
92
+ DROP POLICY IF EXISTS "Public read videos" ON storage.objects;
93
+ CREATE POLICY "Public read videos" ON storage.objects
94
+ FOR SELECT
95
+ TO public
96
+ USING (bucket_id = 'video');
migrations/v4_migration.sql ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -- ============================================================
2
+ -- MATHSOLVER v4.0 - Migration Script (Multi-Session & History)
3
+ -- ============================================================
4
+
5
+ -- 1. Profiles Table (Extends Supabase Auth)
6
+ CREATE TABLE IF NOT EXISTS public.profiles (
7
+ id UUID PRIMARY KEY REFERENCES auth.users(id) ON DELETE CASCADE,
8
+ display_name TEXT,
9
+ avatar_url TEXT,
10
+ created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
11
+ updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
12
+ );
13
+
14
+ -- Function to handle new user signup and auto-create profile
15
+ CREATE OR REPLACE FUNCTION public.handle_new_user()
16
+ RETURNS TRIGGER AS $$
17
+ BEGIN
18
+ INSERT INTO public.profiles (id, display_name, avatar_url)
19
+ VALUES (
20
+ NEW.id,
21
+ COALESCE(NEW.raw_user_meta_data->>'full_name', NEW.email),
22
+ NEW.raw_user_meta_data->>'avatar_url'
23
+ );
24
+ RETURN NEW;
25
+ END;
26
+ $$ LANGUAGE plpgsql SECURITY DEFINER;
27
+
28
+ -- Trigger for profile creation
29
+ DROP TRIGGER IF EXISTS on_auth_user_created ON auth.users;
30
+ CREATE TRIGGER on_auth_user_created
31
+ AFTER INSERT ON auth.users
32
+ FOR EACH ROW EXECUTE FUNCTION public.handle_new_user();
33
+
34
+ -- 2. Sessions Table
35
+ CREATE TABLE IF NOT EXISTS public.sessions (
36
+ id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
37
+ user_id UUID NOT NULL REFERENCES auth.users(id) ON DELETE CASCADE,
38
+ title TEXT DEFAULT 'Bài toán mới',
39
+ created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
40
+ updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
41
+ );
42
+
43
+ -- Index for sessions
44
+ CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON public.sessions(user_id);
45
+ CREATE INDEX IF NOT EXISTS idx_sessions_updated_at ON public.sessions(updated_at DESC);
46
+
47
+ -- 3. Messages Table
48
+ CREATE TABLE IF NOT EXISTS public.messages (
49
+ id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
50
+ session_id UUID NOT NULL REFERENCES public.sessions(id) ON DELETE CASCADE,
51
+ role TEXT NOT NULL CHECK (role IN ('user', 'assistant', 'system')),
52
+ type TEXT NOT NULL DEFAULT 'text',
53
+ content TEXT NOT NULL,
54
+ metadata JSONB DEFAULT '{}'::jsonb,
55
+ created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
56
+ );
57
+
58
+ -- Index for messages
59
+ CREATE INDEX IF NOT EXISTS idx_messages_session_id ON public.messages(session_id);
60
+ CREATE INDEX IF NOT EXISTS idx_messages_created_at ON public.messages(session_id, created_at);
61
+
62
+ -- 4. Session Assets Table (v5.1 Versioning)
63
+ CREATE TABLE IF NOT EXISTS public.session_assets (
64
+ id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
65
+ session_id UUID NOT NULL REFERENCES public.sessions(id) ON DELETE CASCADE,
66
+ job_id UUID NOT NULL,
67
+ asset_type TEXT NOT NULL CHECK (asset_type IN ('video', 'image')),
68
+ storage_path TEXT NOT NULL,
69
+ public_url TEXT NOT NULL,
70
+ version INTEGER NOT NULL DEFAULT 1,
71
+ created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
72
+ );
73
+
74
+ -- Index for session_assets
75
+ CREATE INDEX IF NOT EXISTS idx_session_assets_session_id ON public.session_assets(session_id);
76
+
77
+ -- 5. Update Jobs Table
78
+ ALTER TABLE public.jobs ADD COLUMN IF NOT EXISTS user_id UUID REFERENCES auth.users(id);
79
+ ALTER TABLE public.jobs ADD COLUMN IF NOT EXISTS session_id UUID REFERENCES public.sessions(id);
80
+
81
+ -- 6. Row Level Security (RLS)
82
+ ALTER TABLE public.profiles ENABLE ROW LEVEL SECURITY;
83
+ ALTER TABLE public.sessions ENABLE ROW LEVEL SECURITY;
84
+ ALTER TABLE public.messages ENABLE ROW LEVEL SECURITY;
85
+ ALTER TABLE public.jobs ENABLE ROW LEVEL SECURITY;
86
+ ALTER TABLE public.session_assets ENABLE ROW LEVEL SECURITY;
87
+
88
+ -- Polices for public.profiles
89
+ DROP POLICY IF EXISTS "Users view own profile" ON public.profiles;
90
+ CREATE POLICY "Users view own profile" ON public.profiles FOR SELECT USING (auth.uid() = id);
91
+ DROP POLICY IF EXISTS "Users update own profile" ON public.profiles;
92
+ CREATE POLICY "Users update own profile" ON public.profiles FOR UPDATE USING (auth.uid() = id);
93
+
94
+ -- Policies for public.sessions
95
+ DROP POLICY IF EXISTS "Users manage own sessions" ON public.sessions;
96
+ CREATE POLICY "Users manage own sessions" ON public.sessions FOR ALL USING (auth.uid() = user_id);
97
+
98
+ -- Policies for public.messages
99
+ DROP POLICY IF EXISTS "Users manage own messages" ON public.messages;
100
+ CREATE POLICY "Users manage own messages" ON public.messages FOR ALL USING (
101
+ session_id IN (SELECT id FROM public.sessions WHERE user_id = auth.uid())
102
+ OR (auth.jwt() ->> 'role' = 'service_role')
103
+ );
104
+
105
+ -- Policies for public.session_assets
106
+ DROP POLICY IF EXISTS "Users view own assets" ON public.session_assets;
107
+ CREATE POLICY "Users view own assets" ON public.session_assets FOR SELECT USING (
108
+ session_id IN (SELECT id FROM public.sessions WHERE user_id = auth.uid())
109
+ );
110
+ DROP POLICY IF EXISTS "Service role manages assets" ON public.session_assets;
111
+ CREATE POLICY "Service role manages assets" ON public.session_assets FOR ALL USING (true);
112
+
113
+ -- Policies for public.jobs
114
+ DROP POLICY IF EXISTS "Users manage own jobs" ON public.jobs;
115
+ CREATE POLICY "Users manage own jobs" ON public.jobs FOR ALL USING (
116
+ auth.uid() = user_id OR user_id IS NULL OR (auth.jwt() ->> 'role' = 'service_role')
117
+ );
118
+
119
+ -- 7. Storage Policies (Bucket: video)
120
+ -- (Run this in Supabase Dashboard if not allowed in migration)
121
+ -- INSERT INTO storage.buckets (id, name, public) VALUES ('video', 'video', true) ON CONFLICT (id) DO NOTHING;
122
+ -- CREATE POLICY "Service Role manage videos" ON storage.objects FOR ALL TO service_role USING (bucket_id = 'video');
123
+ -- CREATE POLICY "Public read videos" ON storage.objects FOR SELECT TO public USING (bucket_id = 'video');
124
+
125
+ -- Grant permissions to public/authenticated
126
+ GRANT ALL ON public.profiles TO authenticated;
127
+ GRANT ALL ON public.sessions TO authenticated;
128
+ GRANT ALL ON public.messages TO authenticated;
129
+ GRANT ALL ON public.jobs TO authenticated;
130
+ GRANT ALL ON public.session_assets TO authenticated;
131
+ GRANT ALL ON public.session_assets TO service_role;
pytest.ini ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [pytest]
2
+ asyncio_mode = auto
3
+ testpaths = tests
4
+ pythonpath = .
5
+ filterwarnings =
6
+ ignore::DeprecationWarning
7
+
8
+ markers =
9
+ real_api: HTTP tests need running backend and TEST_USER_ID / TEST_SESSION_ID.
10
+ real_worker_ocr: OCR Celery task or full OCR stack (heavy).
11
+ real_worker_manim: Real Manim render and Supabase video upload.
12
+ real_agents: Live LLM / orchestrator agent calls.
13
+ slow: Large suite or long polling timeouts.
14
+ smoke: Fast API health + one solve job.
15
+ orchestrator_local: In-process Orchestrator without HTTP server.
16
+
17
+ # Default: skip integration tests that need services, keys, or long runs.
18
+ addopts = -m "not real_api and not real_worker_ocr and not real_worker_manim and not real_agents and not slow and not orchestrator_local"
requirements.txt ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Target: Python 3.11 (see Dockerfile). Used by: FastAPI API, Celery worker, Manim render, OCR/vision stack.
2
+ # Install: pip install -r requirements.txt
3
+
4
+ # --- Dev / test ---
5
+ pytest>=8.0
6
+ pytest-asyncio>=0.24
7
+
8
+ # --- HTTP API ---
9
+ cachetools>=5.3
10
+ fastapi>=0.115,<1
11
+ uvicorn[standard]>=0.30
12
+ python-multipart>=0.0.9
13
+ python-dotenv>=1.0
14
+ pydantic[email]>=2.4
15
+ email-validator>=2
16
+
17
+ # --- Auth / data / queue ---
18
+ openai>=1.40
19
+ supabase>=2.0
20
+ celery>=5.3
21
+ redis>=5
22
+ httpx>=0.27
23
+ websockets>=12
24
+
25
+ # --- Math & symbolic solver ---
26
+ sympy>=1.12
27
+ numpy>=1.26,<2
28
+ scipy>=1.11
29
+ opencv-python-headless>=4.8,<4.10
30
+
31
+ # --- Video (GeometryScene via CLI) ---
32
+ manim>=0.18,<0.20
33
+
34
+ # --- OCR & vision (orchestrator / legacy /ocr) ---
35
+ pix2tex>=0.1.4
36
+ paddleocr==2.7.3
37
+ paddlepaddle==2.6.2
38
+ ultralytics==8.2.2
requirements.worker-ocr.txt ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # OCR-only Celery worker: YOLO + PaddleOCR + Pix2Tex (no OpenRouter / no Manim).
2
+ # Install: pip install -r requirements.worker-ocr.txt
3
+
4
+ cachetools>=5.3
5
+ fastapi>=0.115,<1
6
+ uvicorn[standard]>=0.30
7
+ python-multipart>=0.0.9
8
+ python-dotenv>=1.0
9
+ pydantic[email]>=2.4
10
+ email-validator>=2
11
+
12
+ celery>=5.3
13
+ redis>=5
14
+ httpx>=0.27
15
+ websockets>=12
16
+
17
+ numpy>=1.26,<2
18
+ opencv-python-headless>=4.8,<4.10
19
+
20
+ pix2tex>=0.1.4
21
+ paddleocr==2.7.3
22
+ paddlepaddle==2.6.2
23
+ ultralytics==8.2.2
run_api_test.sh ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# API E2E smoke runner: boots the backend with eager Celery + mocked video,
# seeds fresh test data, runs the smoke suite, then shuts the server down.

LOG_FILE="api_test_results.log"
echo "=== Starting API E2E Test Suite ($(date)) ===" > "$LOG_FILE"

# 1. Start BE Server in background
echo "[INFO] Starting Backend Server..." | tee -a "$LOG_FILE"
export ALLOW_TEST_BYPASS=true
export LOG_LEVEL=info
export CELERY_TASK_ALWAYS_EAGER=true
export CELERY_RESULT_BACKEND=rpc://
export MOCK_VIDEO=true
PYTHONPATH=. venv/bin/python -m uvicorn app.main:app --port 8000 > server_debug.log 2>&1 &
SERVER_PID=$!

# Ensure the background server is killed even if this script is interrupted or
# fails early (mirrors run_full_api_test.sh; previously error paths before the
# explicit kill could leak the uvicorn process).
trap 'kill "$SERVER_PID" 2>/dev/null' EXIT

# 2. Wait for server to be ready
echo "[INFO] Waiting for server (PID: $SERVER_PID) on port 8000..." | tee -a "$LOG_FILE"
MAX_RETRIES=15
READY=0
for i in $(seq 1 "$MAX_RETRIES"); do
    if curl -s http://localhost:8000/ > /dev/null; then
        READY=1
        break
    fi
    sleep 2
done

if [ "$READY" -eq 0 ]; then
    echo "[ERROR] Server failed to start in time. Check server_debug.log" | tee -a "$LOG_FILE"
    exit 1
fi
echo "[INFO] Server is READY." | tee -a "$LOG_FILE"

# 3. Prepare Test Data (prepare_api_test.py emits RESULT:USER_ID / RESULT:SESSION_ID)
echo "[INFO] Preparing fresh test data..." | tee -a "$LOG_FILE"
PREP_OUTPUT=$(PYTHONPATH=. venv/bin/python scripts/prepare_api_test.py)
echo "$PREP_OUTPUT" >> "$LOG_FILE"

export TEST_USER_ID=$(echo "$PREP_OUTPUT" | grep "RESULT:USER_ID=" | cut -d'=' -f2)
export TEST_SESSION_ID=$(echo "$PREP_OUTPUT" | grep "RESULT:SESSION_ID=" | cut -d'=' -f2)

if [ -z "$TEST_USER_ID" ] || [ -z "$TEST_SESSION_ID" ]; then
    echo "[ERROR] Failed to prepare test data." | tee -a "$LOG_FILE"
    exit 1
fi

echo "[INFO] Test Data: User=$TEST_USER_ID, Session=$TEST_SESSION_ID" | tee -a "$LOG_FILE"

# 4. Run Pytest (smoke markers only; the full suite lives in run_full_api_test.sh)
echo "[INFO] Running API E2E Tests..." | tee -a "$LOG_FILE"
PYTHONPATH=. venv/bin/python -m pytest tests/test_api_real_e2e.py -m "smoke and real_api" -s \
    --junitxml=pytest_smoke.xml >> "$LOG_FILE" 2>&1
TEST_EXIT_CODE=$?

# 5. Cleanup (the EXIT trap also fires; explicit kill kept for the log message)
echo "[INFO] Shutting down Server..." | tee -a "$LOG_FILE"
kill "$SERVER_PID" 2>/dev/null

echo "==========================================" | tee -a "$LOG_FILE"
if [ "$TEST_EXIT_CODE" -eq 0 ]; then
    echo "FINAL RESULT: ✅ ALL API TESTS PASSED" | tee -a "$LOG_FILE"
else
    echo "FINAL RESULT: ❌ SOME API TESTS FAILED (Code: $TEST_EXIT_CODE)" | tee -a "$LOG_FILE"
fi
echo "==========================================" | tee -a "$LOG_FILE"

exit $TEST_EXIT_CODE
run_full_api_test.sh ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Full API integration (CI-style): eager Celery + mock video + full HTTP suite.
LOG_FILE="full_api_suite.log"
REPORT_FILE="full_api_test_report.md"
JSON_RESULTS="temp_suite_results.json"
JUNIT="pytest_api_suite.xml"

echo "=== Starting Full API Suite Test ($(date)) ===" >"$LOG_FILE"

# Reap the background uvicorn process on any exit path.
trap 'echo "[INFO] Cleaning up processes..."; kill $SERVER_PID 2>/dev/null; sleep 1' EXIT

echo "[INFO] Starting Backend Server (EAGER + MOCK_VIDEO)..." | tee -a "$LOG_FILE"
export ALLOW_TEST_BYPASS=true
export LOG_LEVEL=info
export CELERY_TASK_ALWAYS_EAGER=true
export CELERY_RESULT_BACKEND=rpc://
export MOCK_VIDEO=true
PYTHONPATH=. venv/bin/python -m uvicorn app.main:app --port 8000 >server_debug.log 2>&1 &
SERVER_PID=$!

echo "[INFO] Waiting for server (PID: $SERVER_PID)..." | tee -a "$LOG_FILE"
READY=0
for i in {1..20}; do
  if curl -s http://localhost:8000/ >/dev/null; then
    echo "[INFO] Server is READY." | tee -a "$LOG_FILE"
    READY=1
    break
  fi
  sleep 2
done

# Fail fast instead of running the whole suite against a server that never
# came up (previously the script fell through and every test failed noisily).
if [ "$READY" -eq 0 ]; then
  echo "[ERROR] Server failed to start in time. Check server_debug.log" | tee -a "$LOG_FILE"
  exit 1
fi

echo "[INFO] Preparing fresh test data..." | tee -a "$LOG_FILE"
PREP_OUTPUT=$(PYTHONPATH=. venv/bin/python scripts/prepare_api_test.py)
export TEST_USER_ID=$(echo "$PREP_OUTPUT" | grep "RESULT:USER_ID=" | cut -d'=' -f2)
export TEST_SESSION_ID=$(echo "$PREP_OUTPUT" | grep "RESULT:SESSION_ID=" | cut -d'=' -f2)

if [ -z "$TEST_USER_ID" ]; then
  echo "[ERROR] Failed to prepare test data." | tee -a "$LOG_FILE"
  exit 1
fi

echo "[INFO] Executing API tests (smoke + full suite)..." | tee -a "$LOG_FILE"
PYTHONPATH=. venv/bin/python -m pytest tests/test_api_real_e2e.py tests/test_api_full_suite.py \
  -m "real_api" -s --tb=short --junitxml="$JUNIT" >>"$LOG_FILE" 2>&1
TEST_EXIT_CODE=$?

echo "[INFO] Generating Markdown Report..." | tee -a "$LOG_FILE"
if [ -f "$JSON_RESULTS" ]; then
  PYTHONPATH=. venv/bin/python scripts/generate_report.py "$JSON_RESULTS" "$REPORT_FILE" "$JUNIT"
else
  echo "[WARN] $JSON_RESULTS not found" | tee -a "$LOG_FILE"
fi

echo "==========================================" | tee -a "$LOG_FILE"
echo "DONE. Check $REPORT_FILE for results." | tee -a "$LOG_FILE"
echo "==========================================" | tee -a "$LOG_FILE"

exit $TEST_EXIT_CODE
scripts/benchmark_openrouter.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Benchmark several OpenRouter models (manual tool; not part of pytest)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import os
7
+ import time
8
+
9
+ import httpx
10
+ from dotenv import load_dotenv
11
+
12
+ _BACKEND_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
13
+ load_dotenv(os.path.join(_BACKEND_ROOT, ".env"))
14
+
15
+ MODELS = [
16
+ "nvidia/nemotron-3-super-120b-a12b:free",
17
+ "meta-llama/llama-3.3-70b-instruct:free",
18
+ "openai/gpt-oss-120b:free",
19
+ "z-ai/glm-4.5-air:free",
20
+ "minimax/minimax-m2.5:free",
21
+ "google/gemma-4-26b-a4b-it:free",
22
+ "google/gemma-4-31b-it:free",
23
+ ]
24
+
25
+ PROMPT = (
26
+ "Cho hình chữ nhật ABCD có AB bằng 5 và AD bằng 10. Gọi E là điểm nằm trong đoạn CD sao cho CE = 2ED. "
27
+ "Vẽ đoạn thẳng AE. Vẽ thêm P là điểm nằm trên đường thẳng BC sao cho BP = 2PC, tính chu vi tam giác PEA"
28
+ )
29
+
30
+
31
def main() -> None:
    """Call every model in MODELS once with PROMPT and print per-model timings."""
    api_key = os.getenv("OPENROUTER_API_KEY_1") or os.getenv("OPENROUTER_API_KEY")
    base_url = "https://openrouter.ai/api/v1/chat/completions"

    if not api_key:
        print("Missing OPENROUTER_API_KEY_1 or OPENROUTER_API_KEY in .env")
        return

    print("Benchmark OpenRouter models\nPrompt:", PROMPT, "\n")

    # Headers are identical for every request, so build them once.
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
        "HTTP-Referer": "https://mathsolver.io",
        "X-Title": "MathSolver Benchmark Tool",
    }

    outcomes = []
    for model in MODELS:
        print(f"Calling {model}...", end="", flush=True)
        payload = {"model": model, "messages": [{"role": "user", "content": PROMPT}]}
        started = time.time()
        try:
            with httpx.Client(timeout=120.0) as client:
                resp = client.post(base_url, headers=headers, json=payload)
                resp.raise_for_status()
                body = resp.json()
                reply = body["choices"][0]["message"]["content"]
        except Exception as e:
            elapsed = time.time() - started
            outcomes.append(
                {"model": model, "duration": elapsed, "error": str(e), "status": "error"}
            )
            print(f" FAIL ({elapsed:.2f}s) {e}")
        else:
            elapsed = time.time() - started
            outcomes.append(
                {"model": model, "duration": elapsed, "answer": reply, "status": "success"}
            )
            print(f" OK ({elapsed:.2f}s)")

    # Dump each result (truncated) for manual inspection.
    print("\n" + "=" * 80)
    for entry in outcomes:
        print(json.dumps(entry, ensure_ascii=False, indent=2)[:2000])
        print("-" * 40)


if __name__ == "__main__":
    main()
scripts/generate_report.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import sys
4
+ import xml.etree.ElementTree as ET
5
+ from datetime import datetime
6
+
7
+
8
def _parse_junit_xml(path: str) -> dict:
    """Summarize pytest junitxml (JUnit) file."""
    summary = {"tests": 0, "failures": 0, "errors": 0, "skipped": 0, "time": 0.0}
    try:
        root = ET.parse(path).getroot()
        # pytest may emit a bare <testsuite> or a <testsuites> wrapper.
        suites = [root] if root.tag == "testsuite" else list(root.iter("testsuite"))
        for suite in suites:
            if suite.tag != "testsuite":
                continue
            attrs = suite.attrib
            for key in ("tests", "failures", "errors", "skipped"):
                summary[key] += int(attrs.get(key, 0) or 0)
            summary["time"] += float(attrs.get("time", 0) or 0)
    except Exception as exc:
        # Record the problem instead of raising; callers render it in the report.
        summary["parse_error"] = str(exc)
    return summary
26
+
27
+
28
def generate_report(json_path: str, report_path: str, junit_path: str | None = None) -> None:
    """Render a Markdown integration-test report from suite results.

    Args:
        json_path: JSON file with per-case API suite results — a list of dicts
            with ``id``, ``query``, ``success``, ``elapsed`` and optional
            ``result`` / ``error`` entries.
        report_path: Destination Markdown file (overwritten).
        junit_path: Optional pytest junitxml file; when present and parseable,
            a pass/fail summary line is added via ``_parse_junit_xml``.

    Any exception is printed rather than raised, so report generation never
    fails the calling CI wrapper.
    """
    try:
        with open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        junit_summary = None
        if junit_path and os.path.isfile(junit_path):
            junit_summary = _parse_junit_xml(junit_path)

        with open(report_path, "w", encoding="utf-8") as f:
            f.write("# Báo cáo Kiểm thử tích hợp Backend (Integration Report)\n\n")
            f.write(f"**Thời gian chạy:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
            # Suite passes only when every case succeeded; non-list payloads count as FAIL.
            suite_ok = all(r.get("success", False) for r in data) if isinstance(data, list) else False
            f.write(f"**API suite (JSON):** {'PASS' if suite_ok else 'FAIL'}\n")

            if junit_summary and "parse_error" not in junit_summary:
                j_ok = junit_summary["failures"] == 0 and junit_summary["errors"] == 0
                f.write(
                    f"**Pytest (JUnit):** {'PASS' if j_ok else 'FAIL'} — "
                    f"tests={junit_summary['tests']}, failures={junit_summary['failures']}, "
                    f"errors={junit_summary['errors']}, skipped={junit_summary['skipped']}, "
                    f"time_s={junit_summary['time']:.2f}\n"
                )
            elif junit_summary and "parse_error" in junit_summary:
                f.write(f"**Pytest (JUnit):** (could not parse: {junit_summary['parse_error']})\n")

            f.write("\n")

            # Summary table: one row per test case.
            f.write("| ID | Câu hỏi (Query) | Trạng thái | Thời gian (s) | Kết quả / Lỗi |\n")
            f.write("| :--- | :--- | :--- | :--- | :--- |\n")
            for r in data:
                status = "PASS" if r.get("success") else "FAIL"
                elapsed = f"{float(r.get('elapsed', 0) or 0):.2f}"
                query = r.get("query", "-")

                res = r.get("result", {})
                if not isinstance(res, dict):
                    res = {}

                # On failure the cell shows the error instead of the analysis.
                analysis = res.get("semantic_analysis", "-")
                if not r.get("success"):
                    analysis = f"**Lỗi:** {r.get('error', '-')}"

                # Keep table cells readable: truncate long analysis text.
                short_analysis = (analysis[:100] + "...") if len(str(analysis)) > 100 else analysis

                f.write(f"| {r['id']} | {query} | {status} | {elapsed} | {short_analysis} |\n")

            # Detailed section: only successful cases whose result is a dict.
            f.write("\n---\n**Chi tiết Output (DSL & Analysis):**\n")
            for r in data:
                if not r.get("success"):
                    continue
                res = r.get("result", {})
                if not isinstance(res, dict):
                    continue

                f.write(f"\n### Case {r['id']}: {r.get('query')}\n")
                f.write(f"**Semantic Analysis:**\n{res.get('semantic_analysis', '-')}\n\n")
                f.write(f"**Geometry DSL:**\n```\n{res.get('geometry_dsl', '-')}\n```\n")

                sol = res.get("solution")
                if sol and isinstance(sol, dict):
                    f.write("**Solution (v5.1):**\n")
                    f.write(f"- **Answer:** {sol.get('answer', 'N/A')}\n")
                    f.write("- **Steps:**\n")
                    steps = sol.get("steps", [])
                    if steps:
                        for step in steps:
                            f.write(f"  - {step}\n")
                    else:
                        f.write("  - (Không có bước giải cụ thể)\n")

                    if sol.get("symbolic_expression"):
                        f.write(f"- **Symbolic:** `{sol.get('symbolic_expression')}`\n")
                    f.write("\n")

        print(f"Report generated: {report_path}")
    except Exception as e:
        # Best-effort tool: never break the CI wrapper over a reporting problem.
        print(f"Error generating report: {e}")
106
+
107
+
108
if __name__ == "__main__":
    # CLI: <json_results> <report_output> [junit_xml_optional]
    argv = sys.argv
    if len(argv) < 3:
        print(
            "Usage: python generate_report.py <json_results> <report_output> [junit_xml_optional]"
        )
        sys.exit(1)
    generate_report(argv[1], argv[2], argv[3] if len(argv) > 3 else None)
scripts/prepare_api_test.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import uuid
4
+
5
+ from dotenv import load_dotenv
6
+
7
+ # Add parent dir to path to import app modules
8
+ _BACKEND_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
9
+ sys.path.append(_BACKEND_ROOT)
10
+ load_dotenv(os.path.join(_BACKEND_ROOT, ".env"))
11
+
12
+ from app.supabase_client import get_supabase
13
+
14
+ # Default UUID matches historical dev DB; override with TEST_SUPABASE_USER_ID in .env
15
+ _DEFAULT_TEST_USER = "8cd3adb0-7964-4575-949c-d0cadcd8b679"
16
+
17
+
18
+ def prepare():
19
+ supabase = get_supabase()
20
+ user_id = os.environ.get("TEST_SUPABASE_USER_ID", _DEFAULT_TEST_USER).strip()
21
+ session_id = str(uuid.uuid4())
22
+
23
+ print(f"Using test user (TEST_SUPABASE_USER_ID or default): {user_id}")
24
+
25
+ print(f"Creating fresh test session: {session_id}")
26
+ # Insert session
27
+ supabase.table("sessions").insert({
28
+ "id": session_id,
29
+ "user_id": user_id,
30
+ "title": f"Fresh API Test {session_id[:8]}"
31
+ }).execute()
32
+
33
+ # Return IDs for the test script
34
+ print(f"RESULT:USER_ID={user_id}")
35
+ print(f"RESULT:SESSION_ID={session_id}")
36
+
37
+ if __name__ == "__main__":
38
+ prepare()
scripts/prewarm_models.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Download and load all heavy models during Docker build (YOLO, PaddleOCR, Pix2Tex, agents).
4
+ Fails the image build if initialization fails.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import logging
10
+ import os
11
+ import sys
12
+
13
+ # Ensure imports work when run as `python scripts/prewarm_models.py` from WORKDIR
14
+ _APP_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
15
+ if _APP_ROOT not in sys.path:
16
+ sys.path.insert(0, _APP_ROOT)
17
+
18
+ os.chdir(_APP_ROOT)
19
+
20
+ from dotenv import load_dotenv
21
+
22
+ load_dotenv()
23
+
24
+ from app.runtime_env import apply_runtime_env_defaults
25
+
26
+ apply_runtime_env_defaults()
27
+
28
+ logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s | %(message)s")
29
+
30
+ logger = logging.getLogger("prewarm")
31
+
32
+
33
def main() -> None:
    """Construct the full Orchestrator so every agent/model downloads at build time."""
    # Imported here so the module can be loaded without the heavy agent stack.
    from agents.orchestrator import Orchestrator

    logger.info("Constructing Orchestrator (full agent + model load)...")
    _ = Orchestrator()
    logger.info("Prewarm finished successfully.")


if __name__ == "__main__":
    main()
scripts/prewarm_ocr_worker.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Docker build: load OCR models only (no Orchestrator / no LLM). Used by Dockerfile.worker.ocr."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import logging
7
+ import os
8
+ import sys
9
+
10
+ _APP_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
11
+ if _APP_ROOT not in sys.path:
12
+ sys.path.insert(0, _APP_ROOT)
13
+
14
+ os.chdir(_APP_ROOT)
15
+
16
+ from dotenv import load_dotenv
17
+
18
+ load_dotenv()
19
+
20
+ from app.runtime_env import apply_runtime_env_defaults
21
+
22
+ apply_runtime_env_defaults()
23
+
24
+ logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s | %(message)s")
25
+ logger = logging.getLogger("prewarm_ocr_worker")
26
+
27
+
28
def main() -> None:
    """Load only the OCR model stack (no Orchestrator / no LLM) for the OCR worker image."""
    # Imported lazily so importing this script stays cheap.
    from agents.ocr_agent import OCRAgent

    logger.info("Loading OCRAgent(skip_llm_refinement=True)...")
    _ = OCRAgent(skip_llm_refinement=True)
    logger.info("OCR worker prewarm finished successfully.")


if __name__ == "__main__":
    main()
scripts/run_real_integration.sh ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Run backend integration tests. Usage:
#   ./scripts/run_real_integration.sh        # profile ci (default)
#   ./scripts/run_real_integration.sh ci
#   ./scripts/run_real_integration.sh real   # heavy: workers, manim, OCR, full API suite
set -euo pipefail

PROFILE="${1:-ci}"
# Resolve the backend root relative to this script and run everything from there.
ROOT="$(cd "$(dirname "$0")/.." && pwd)"
cd "$ROOT"
PY="${ROOT}/venv/bin/python"
if [[ ! -x "$PY" ]]; then
  # Fall back to system python when no project venv exists.
  PY="python3"
fi

export PYTHONPATH="$ROOT"
# Output paths are env-overridable so CI can redirect artifacts.
LOG_FILE="${LOG_FILE:-integration_run.log}"
JUNIT="${JUNIT:-pytest_integration.xml}"
REPORT_MD="${REPORT_MD:-integration_report.md}"
JSON_RESULTS="${JSON_RESULTS:-temp_suite_results.json}"

# Timestamped logger mirrored to stdout and the log file.
log() { echo "[$(date '+%H:%M:%S')] $*" | tee -a "$LOG_FILE"; }

log "Profile=$PROFILE working_dir=$ROOT"

if [[ "$PROFILE" == "ci" ]]; then
  # CI profile: in-process everything (eager Celery, mocked video, auth bypass).
  export ALLOW_TEST_BYPASS="${ALLOW_TEST_BYPASS:-true}"
  export LOG_LEVEL="${LOG_LEVEL:-info}"
  export CELERY_TASK_ALWAYS_EAGER="${CELERY_TASK_ALWAYS_EAGER:-true}"
  export CELERY_RESULT_BACKEND="${CELERY_RESULT_BACKEND:-rpc://}"
  export MOCK_VIDEO="${MOCK_VIDEO:-true}"

  # Capture pytest's exit code via PIPESTATUS instead of aborting under set -e.
  set +e
  log "Phase A: default pytest (unit / mocked; excludes real_* markers per pytest.ini)"
  "$PY" -m pytest tests/ -q --tb=short -p no:cacheprovider 2>&1 | tee -a "$LOG_FILE"
  P1=${PIPESTATUS[0]}
  set -e

  log "Starting uvicorn for API phase..."
  "$PY" -m uvicorn app.main:app --port 8000 >>uvicorn_integration.log 2>&1 &
  SERVER_PID=$!
  # Always reap the background server, even on early exit.
  trap 'kill "$SERVER_PID" 2>/dev/null || true' EXIT

  for i in $(seq 1 25); do
    if curl -sf "http://localhost:8000/" >/dev/null; then
      log "API ready"
      break
    fi
    sleep 2
    if [[ "$i" -eq 25 ]]; then
      log "ERROR: API did not start"
      exit 1
    fi
  done

  # prepare_api_test.py prints RESULT:USER_ID=... / RESULT:SESSION_ID=... lines.
  PREP="$("$PY" scripts/prepare_api_test.py)"
  echo "$PREP" | tee -a "$LOG_FILE"
  export TEST_USER_ID="$(echo "$PREP" | grep "RESULT:USER_ID=" | cut -d'=' -f2)"
  export TEST_SESSION_ID="$(echo "$PREP" | grep "RESULT:SESSION_ID=" | cut -d'=' -f2)"
  if [[ -z "${TEST_USER_ID:-}" || -z "${TEST_SESSION_ID:-}" ]]; then
    log "ERROR: prepare_api_test did not emit USER_ID / SESSION_ID"
    exit 1
  fi

  set +e
  log "Phase B: API smoke + full suite (real_api)"
  "$PY" -m pytest tests/test_api_real_e2e.py tests/test_api_full_suite.py \
    -m "real_api" -s --tb=short --junitxml="$JUNIT" -p no:cacheprovider 2>&1 | tee -a "$LOG_FILE"
  P2=${PIPESTATUS[0]}
  set -e

  if [[ -f "$JSON_RESULTS" ]]; then
    log "Generating Markdown report"
    "$PY" scripts/generate_report.py "$JSON_RESULTS" "$REPORT_MD" "$JUNIT"
  else
    log "WARN: $JSON_RESULTS missing (suite may have failed before write)"
  fi

  if [[ "$P1" -ne 0 || "$P2" -ne 0 ]]; then
    log "FAIL: phase A exit=$P1 phase B exit=$P2"
    exit 1
  fi

  log "Done CI profile. See $REPORT_MD and $LOG_FILE"
  exit 0
fi

if [[ "$PROFILE" == "real" ]]; then
  # NOTE(review): the unset below discards any caller-provided
  # CELERY_TASK_ALWAYS_EAGER, so the ${VAR:-false} default always wins and the
  # export is effectively a hard "false". Looks intentional (force real
  # workers) — confirm before changing.
  unset CELERY_TASK_ALWAYS_EAGER || true
  export CELERY_TASK_ALWAYS_EAGER="${CELERY_TASK_ALWAYS_EAGER:-false}"
  export MOCK_VIDEO="${MOCK_VIDEO:-false}"
  export RUN_REAL_WORKER_OCR="${RUN_REAL_WORKER_OCR:-0}"
  export RUN_REAL_WORKER_MANIM="${RUN_REAL_WORKER_MANIM:-0}"

  log "Phase A: default pytest (fast)"
  "$PY" -m pytest tests/ -q --tb=short -p no:cacheprovider 2>&1 | tee -a "$LOG_FILE"

  # Heavy phases below are best-effort: `|| true` keeps later phases running.
  log "Phase B: real agents + orchestrator smoke (requires OpenRouter keys)"
  "$PY" -m pytest tests/integration/test_agents_real.py tests/integration/test_orchestrator_smoke.py \
    -m "real_agents" -q --tb=short --junitxml="$JUNIT" -p no:cacheprovider 2>&1 | tee -a "$LOG_FILE" || true

  if [[ "${RUN_REAL_WORKER_OCR:-0}" == "1" ]] || [[ "${RUN_REAL_WORKER_OCR:-0}" =~ ^(true|yes)$ ]]; then
    log "Phase C: OCR worker task (RUN_REAL_WORKER_OCR enabled)"
    "$PY" -m pytest tests/integration/test_worker_ocr_real.py \
      -m "real_worker_ocr" -q --tb=short -p no:cacheprovider 2>&1 | tee -a "$LOG_FILE" || true
  else
    log "Skipping OCR worker (set RUN_REAL_WORKER_OCR=1 to enable)"
  fi

  if [[ "${RUN_REAL_WORKER_MANIM:-0}" == "1" ]]; then
    log "Phase D: Manim + storage (RUN_REAL_WORKER_MANIM=1, MOCK_VIDEO=false)"
    "$PY" -m pytest tests/integration/test_worker_manim_real.py -m "real_worker_manim" -s --tb=short \
      -p no:cacheprovider 2>&1 | tee -a "$LOG_FILE" || true
  else
    log "Skipping Manim integration (set RUN_REAL_WORKER_MANIM=1 to enable)"
  fi

  log "Phase E: API real (expects TEST_BASE_URL or localhost:8000 with server already up)"
  if curl -sf "http://localhost:8000/" >/dev/null 2>&1; then
    PREP="$("$PY" scripts/prepare_api_test.py)"
    export TEST_USER_ID="$(echo "$PREP" | grep "RESULT:USER_ID=" | cut -d'=' -f2)"
    export TEST_SESSION_ID="$(echo "$PREP" | grep "RESULT:SESSION_ID=" | cut -d'=' -f2)"
    "$PY" -m pytest tests/test_api_real_e2e.py tests/test_api_full_suite.py tests/test_api_metadata_real.py \
      -m "real_api" -q --tb=short -p no:cacheprovider 2>&1 | tee -a "$LOG_FILE" || true
  else
    log "WARN: No server on :8000 — skip API real phase (start backend first)"
  fi

  log "Done REAL profile. Review $LOG_FILE"
  exit 0
fi

echo "Unknown profile: $PROFILE (use ci or real)"
exit 1
scripts/test_LLM.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import time
4
+ import asyncio
5
+ import logging
6
+ from typing import List, Dict, Any
7
+ from dotenv import load_dotenv
8
+
9
+ # Add the parent directory to sys.path to allow importing from 'app'
10
+ # This assumes the script is inside 'backend/scripts' and we want to import from 'backend/app'
11
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
12
+
13
+ from app.url_utils import openai_compatible_api_key
14
+ from openai import AsyncOpenAI
15
+
16
+ # Set up logger
17
+ logging.basicConfig(level=logging.INFO, format="%(message)s")
18
+ logger = logging.getLogger(__name__)
19
+
20
+ # List of models to benchmark
21
+ MODELS_TO_TEST = [
22
+ "nvidia/nemotron-3-super-120b-a12b:free",
23
+ "meta-llama/llama-3.3-70b-instruct:free",
24
+ "openai/gpt-oss-120b:free",
25
+ "z-ai/glm-4.5-air:free",
26
+ "minimax/minimax-m2.5:free",
27
+ "google/gemma-4-26b-a4b-it:free",
28
+ "google/gemma-4-31b-it:free",
29
+ "arcee-ai/trinity-large-preview:free",
30
+ "openai/gpt-oss-20b:free",
31
+ "nvidia/nemotron-3-nano-30b-a3b:free",
32
+ "nvidia/nemotron-nano-9b-v2:free",
33
+ ]
34
+
35
+ DEFAULT_QUERY = "Giải hệ phương trình sau: x + y = 10, 2x - y = 2. Trả về kết quả cuối cùng x và y."
36
+
37
+ async def test_model(client: AsyncOpenAI, model: str, query: str) -> Dict[str, Any]:
38
+ """Test a single model and return performance metrics."""
39
+ start_time = time.time()
40
+ result = {
41
+ "model": model,
42
+ "status": "success",
43
+ "duration": 0,
44
+ "content": "",
45
+ "error": None
46
+ }
47
+
48
+ try:
49
+ response = await client.chat.completions.create(
50
+ model=model,
51
+ messages=[{"role": "user", "content": query}],
52
+ timeout=60.0
53
+ )
54
+ result["duration"] = time.time() - start_time
55
+ result["content"] = response.choices[0].message.content.strip()
56
+ except Exception as e:
57
+ result["status"] = "failed"
58
+ result["duration"] = time.time() - start_time
59
+ result["error"] = str(e)
60
+
61
+ return result
62
+
63
async def main():
    """Benchmark every model in MODELS_TO_TEST sequentially and log a summary.

    Reads an OpenRouter key from OPENROUTER_API_KEY_1 / OPENROUTER_API_KEY,
    waits 30s between models (free-tier rate limits), then logs a summary
    table and each full response.
    """
    # Load configuration from .env file inside backend directory
    # If starting from root, backend/.env might be needed. If starting from backend/, .env is enough.
    load_dotenv()

    # Try multiple common env keys for api key
    api_key = os.getenv("OPENROUTER_API_KEY_1") or os.getenv("OPENROUTER_API_KEY")

    if not api_key:
        logger.error("❌ Error: NO OPENROUTER_API_KEY found in environment variables.")
        logger.info("Check your .env file in the backend directory.")
        return

    # Using the project's url_utils to maintain consistency with the main app
    sanitized_key = openai_compatible_api_key(api_key)

    client = AsyncOpenAI(
        api_key=sanitized_key,
        base_url="https://openrouter.ai/api/v1",
        default_headers={
            "HTTP-Referer": "https://mathsolver.ai",
            "X-Title": "MathSolver LLM Benchmarker",
        }
    )

    query = DEFAULT_QUERY
    logger.info("=" * 80)
    logger.info(f"🚀 LLM PERFORMANCE BENCHMARK")
    logger.info(f"Query: {query}")
    logger.info("=" * 80)
    logger.info(f"Testing {len(MODELS_TO_TEST)} models sequentially with 30s delay...\n")

    results = []
    for i, model in enumerate(MODELS_TO_TEST):
        if i > 0:
            # Throttle between calls; free OpenRouter models rate-limit aggressively.
            logger.info(f"⏳ Waiting 30s before testing next model...")
            await asyncio.sleep(30)

        logger.info(f"[{i+1}/{len(MODELS_TO_TEST)}] Testing: {model}...")
        res = await test_model(client, model, query)
        results.append(res)

        # Immediate feedback
        status_str = "✅ SUCCESS" if res["status"] == "success" else "❌ FAILED"
        logger.info(f" Status: {status_str} | Time: {res['duration']:.2f}s")

    # Report Summary Table
    logger.info("\n" + "=" * 80)
    logger.info("📊 FINAL BENCHMARK SUMMARY")
    logger.info("=" * 80)
    header = f"{'MODEL':<45} | {'STATUS':<10} | {'TIME (s)':<10}"
    logger.info(header)
    logger.info("-" * len(header))

    for res in results:
        status_str = "✅ SUCCESS" if res["status"] == "success" else "❌ FAILED"
        duration_str = f"{res['duration']:.2f}s"
        logger.info(f"{res['model']:<45} | {status_str:<10} | {duration_str:<10}")

    logger.info("-" * len(header))

    # Detailed report for successful ones
    logger.info("\n📝 FULL RESPONSES:")
    for res in results:
        logger.info(f"\n{'='*20} [{res['model']}] {'='*20}")
        if res["status"] == "success":
            logger.info(res["content"])
        else:
            logger.info(f"❌ Error: {res['error']}")

    logger.info("\n" + "=" * 80)
    logger.info(f"Benchmark finished.")


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info("\nBenchmark cancelled by user.")
    except Exception as e:
        # Manual tool: surface unexpected failures without a traceback dump.
        logger.error(f"Unexpected error: {e}")
scripts/test_engine_direct.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import asyncio
import os
import json
import logging
import sys

# Add backend root (parent of scripts/) to sys.path so `app` and `agents`
# import regardless of checkout location (was a hard-coded
# /Volumes/WorkSpace/Project/MathSolver/backend path that only worked on one
# developer machine).
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Configure logging to stdout
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

from agents.orchestrator import Orchestrator
15
+
16
async def main():
    """Drive the Orchestrator in-process (no HTTP) on one geometry prompt."""
    orch = Orchestrator()
    text = "Vẽ tam giác đều cạnh 5."
    job_id = "test_direct_equilateral"

    print(f"\n--- Testing Orchestrator Direct: {text} ---")

    async def status_cb(status):
        # Echo each pipeline status update as it arrives.
        print(f" [STATUS] {status}")

    try:
        result = await orch.run(
            text, job_id=job_id, status_callback=status_cb, request_video=False
        )
        print("\n--- Final Result ---")
        print(json.dumps(result, indent=2))
    except Exception:
        print(f"\n--- ERROR ---")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    asyncio.run(main())