Spaces:
Sleeping
Sleeping
github-actions commited on
Commit ·
200b382
0
Parent(s):
Deploy render worker from GitHub Actions
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- Dockerfile +47 -0
- Dockerfile.worker.ocr +42 -0
- README.md +24 -0
- README_HF_WORKER_OCR.md +27 -0
- agents/geometry_agent.py +120 -0
- agents/knowledge_agent.py +135 -0
- agents/ocr_agent.py +307 -0
- agents/orchestrator.py +219 -0
- agents/parser_agent.py +106 -0
- agents/renderer_agent.py +265 -0
- agents/solver_agent.py +107 -0
- agents/torch_ultralytics_compat.py +33 -0
- app/chat_image_upload.py +206 -0
- app/dependencies.py +69 -0
- app/errors.py +59 -0
- app/llm_client.py +100 -0
- app/logging_setup.py +112 -0
- app/logutil.py +67 -0
- app/main.py +140 -0
- app/models/schemas.py +80 -0
- app/ocr_celery.py +54 -0
- app/ocr_local_file.py +43 -0
- app/ocr_text_merge.py +14 -0
- app/routers/__init__.py +1 -0
- app/routers/auth.py +23 -0
- app/routers/sessions.py +165 -0
- app/routers/solve.py +410 -0
- app/runtime_env.py +12 -0
- app/session_cache.py +48 -0
- app/supabase_client.py +37 -0
- app/url_utils.py +23 -0
- app/websocket_manager.py +40 -0
- clean_ports.sh +22 -0
- dump.rdb +0 -0
- migrations/add_image_bucket_storage.sql +35 -0
- migrations/fix_rls_assets.sql +96 -0
- migrations/v4_migration.sql +131 -0
- pytest.ini +18 -0
- requirements.txt +38 -0
- requirements.worker-ocr.txt +23 -0
- run_api_test.sh +69 -0
- run_full_api_test.sh +56 -0
- scripts/benchmark_openrouter.py +77 -0
- scripts/generate_report.py +115 -0
- scripts/prepare_api_test.py +38 -0
- scripts/prewarm_models.py +42 -0
- scripts/prewarm_ocr_worker.py +37 -0
- scripts/run_real_integration.sh +134 -0
- scripts/test_LLM.py +142 -0
- scripts/test_engine_direct.py +36 -0
Dockerfile
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Same runtime as API; runs health endpoint + Celery worker (see worker_health.py)
|
| 2 |
+
FROM python:3.11-slim-bookworm
|
| 3 |
+
|
| 4 |
+
ENV PYTHONUNBUFFERED=1 \
|
| 5 |
+
PYTHONDONTWRITEBYTECODE=1 \
|
| 6 |
+
PIP_NO_CACHE_DIR=1 \
|
| 7 |
+
PIP_ROOT_USER_ACTION=ignore \
|
| 8 |
+
NO_ALBUMENTATIONS_UPDATE=1 \
|
| 9 |
+
OMP_NUM_THREADS=1 \
|
| 10 |
+
MKL_NUM_THREADS=1 \
|
| 11 |
+
OPENBLAS_NUM_THREADS=1
|
| 12 |
+
|
| 13 |
+
WORKDIR /app
|
| 14 |
+
ENV PYTHONPATH=/app
|
| 15 |
+
|
| 16 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 17 |
+
ffmpeg \
|
| 18 |
+
pkg-config \
|
| 19 |
+
cmake \
|
| 20 |
+
libcairo2 \
|
| 21 |
+
libcairo2-dev \
|
| 22 |
+
libpango-1.0-0 \
|
| 23 |
+
libpango1.0-dev \
|
| 24 |
+
libpangocairo-1.0-0 \
|
| 25 |
+
libgdk-pixbuf-2.0-0 \
|
| 26 |
+
libffi-dev \
|
| 27 |
+
python3-dev \
|
| 28 |
+
texlive-latex-base \
|
| 29 |
+
texlive-fonts-recommended \
|
| 30 |
+
texlive-latex-extra \
|
| 31 |
+
build-essential \
|
| 32 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 33 |
+
|
| 34 |
+
COPY requirements.txt .
|
| 35 |
+
RUN pip install --upgrade pip setuptools wheel \
|
| 36 |
+
&& pip install -r requirements.txt
|
| 37 |
+
|
| 38 |
+
COPY . .
|
| 39 |
+
|
| 40 |
+
RUN python scripts/prewarm_models.py
|
| 41 |
+
|
| 42 |
+
ENV PORT=7860 \
|
| 43 |
+
CELERY_WORKER_QUEUES=render
|
| 44 |
+
EXPOSE 7860
|
| 45 |
+
|
| 46 |
+
ENTRYPOINT []
|
| 47 |
+
CMD ["sh", "-c", "exec python3 -u worker_health.py"]
|
Dockerfile.worker.ocr
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Celery worker: OCR queue only (no Manim / LaTeX / Cairo stack).
|
| 2 |
+
FROM python:3.11-slim-bookworm
|
| 3 |
+
|
| 4 |
+
ENV PYTHONUNBUFFERED=1 \
|
| 5 |
+
PYTHONDONTWRITEBYTECODE=1 \
|
| 6 |
+
PIP_NO_CACHE_DIR=1 \
|
| 7 |
+
PIP_ROOT_USER_ACTION=ignore \
|
| 8 |
+
NO_ALBUMENTATIONS_UPDATE=1 \
|
| 9 |
+
OMP_NUM_THREADS=1 \
|
| 10 |
+
MKL_NUM_THREADS=1 \
|
| 11 |
+
OPENBLAS_NUM_THREADS=1
|
| 12 |
+
|
| 13 |
+
WORKDIR /app
|
| 14 |
+
ENV PYTHONPATH=/app
|
| 15 |
+
|
| 16 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 17 |
+
build-essential \
|
| 18 |
+
cmake \
|
| 19 |
+
pkg-config \
|
| 20 |
+
python3-dev \
|
| 21 |
+
libglib2.0-0 \
|
| 22 |
+
libgomp1 \
|
| 23 |
+
libgl1 \
|
| 24 |
+
libsm6 \
|
| 25 |
+
libxext6 \
|
| 26 |
+
libxrender1 \
|
| 27 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 28 |
+
|
| 29 |
+
COPY requirements.worker-ocr.txt .
|
| 30 |
+
RUN pip install --upgrade pip setuptools wheel \
|
| 31 |
+
&& pip install -r requirements.worker-ocr.txt
|
| 32 |
+
|
| 33 |
+
COPY . .
|
| 34 |
+
|
| 35 |
+
RUN python scripts/prewarm_ocr_worker.py
|
| 36 |
+
|
| 37 |
+
ENV PORT=7860 \
|
| 38 |
+
CELERY_WORKER_QUEUES=ocr
|
| 39 |
+
EXPOSE 7860
|
| 40 |
+
|
| 41 |
+
ENTRYPOINT []
|
| 42 |
+
CMD ["sh", "-c", "exec python3 -u worker_health.py"]
|
README.md
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Math Solver Render Worker
|
| 3 |
+
emoji: 👷
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: indigo
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
# Math Solver — Render worker (Manim)
|
| 11 |
+
|
| 12 |
+
This Space runs **Celery** via `worker_health.py` and consumes **only** queue **`render`** (`render_geometry_video`). Image sets `CELERY_WORKER_QUEUES=render` by default (`Dockerfile.worker`).
|
| 13 |
+
|
| 14 |
+
**Solve** (orchestrator, agents, OCR-in-request when `OCR_USE_CELERY` is off) runs on the **API** Space, not on this worker.
|
| 15 |
+
|
| 16 |
+
## OCR offload (separate Space)
|
| 17 |
+
|
| 18 |
+
Queue **`ocr`** is handled by a **dedicated OCR worker** (`Dockerfile.worker.ocr`, `README_HF_WORKER_OCR.md`, workflow `deploy-worker-ocr.yml`). On the API, set `OCR_USE_CELERY=true` and deploy an OCR Space that listens on `ocr`.
|
| 19 |
+
|
| 20 |
+
## Secrets
|
| 21 |
+
|
| 22 |
+
Same broker as the API: `REDIS_URL` / `CELERY_BROKER_URL`, Supabase, OpenRouter (renderer may use LLM paths), etc.
|
| 23 |
+
|
| 24 |
+
**GitHub Actions:** repository secrets `HF_TOKEN` and `HF_WORKER_REPO` (`owner/space-name`) for this workflow (`deploy-worker.yml`).
|
README_HF_WORKER_OCR.md
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Math Solver OCR Worker
|
| 3 |
+
emoji: 👁️
|
| 4 |
+
colorFrom: gray
|
| 5 |
+
colorTo: blue
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
# Math Solver — OCR-only worker
|
| 11 |
+
|
| 12 |
+
This Space runs **Celery** (`worker_health.py`) consuming **only** the `ocr` queue.
|
| 13 |
+
|
| 14 |
+
Set environment:
|
| 15 |
+
|
| 16 |
+
- `CELERY_WORKER_QUEUES=ocr` (default in `Dockerfile.worker.ocr`)
|
| 17 |
+
- Same `REDIS_URL` / `CELERY_BROKER_URL` / `CELERY_RESULT_BACKEND` as the API
|
| 18 |
+
|
| 19 |
+
This Space runs **raw OCR only** (YOLO, PaddleOCR, Pix2Tex). **OpenRouter / LLM tinh chỉnh** không chạy ở đây; API Space gọi `refine_with_llm` sau khi nhận kết quả từ queue `ocr`.
|
| 20 |
+
|
| 21 |
+
On the **API** Space, set `OCR_USE_CELERY=true` so `run_ocr_from_url` tasks are sent to this worker instead of running Paddle/Pix2Tex on the API process.
|
| 22 |
+
|
| 23 |
+
Optional: `OCR_CELERY_TIMEOUT_SEC` (default `180`).
|
| 24 |
+
|
| 25 |
+
**Manim / video** uses a different Celery queue (`render`) and Space — see `README_HF_WORKER.md` and workflow `deploy-worker.yml`.
|
| 26 |
+
|
| 27 |
+
GitHub Actions: repository secrets `HF_TOKEN` and `HF_OCR_WORKER_REPO` (`owner/space-name`) enable workflow `deploy-worker-ocr.yml`.
|
agents/geometry_agent.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
from openai import AsyncOpenAI
|
| 5 |
+
from typing import Dict, Any
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
|
| 8 |
+
load_dotenv()
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
from app.url_utils import openai_compatible_api_key, sanitize_env
|
| 12 |
+
from app.llm_client import get_llm_client
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class GeometryAgent:
|
| 16 |
+
def __init__(self):
|
| 17 |
+
self.llm = get_llm_client()
|
| 18 |
+
|
| 19 |
+
async def generate_dsl(self, semantic_data: Dict[str, Any], previous_dsl: str = None) -> str:
|
| 20 |
+
logger.info("==[GeometryAgent] Generating DSL from semantic data==")
|
| 21 |
+
if previous_dsl:
|
| 22 |
+
logger.info(f"[GeometryAgent] Using previous DSL context (len={len(previous_dsl)})")
|
| 23 |
+
|
| 24 |
+
system_prompt = """
|
| 25 |
+
You are a Geometry DSL Generator. Convert semantic geometry data into a precise Geometry DSL program.
|
| 26 |
+
|
| 27 |
+
=== MULTI-TURN CONTEXT ===
|
| 28 |
+
If a PREVIOUS DSL is provided, your job is to UPDATE or EXTEND it.
|
| 29 |
+
1. DO NOT remove existing points unless the user explicitly asks to "redefine" or "move" them.
|
| 30 |
+
2. Ensure new segments/points connect correctly to existing ones.
|
| 31 |
+
3. Your output should be the ENTIRE updated DSL, not just the changes.
|
| 32 |
+
|
| 33 |
+
=== DSL COMMANDS ===
|
| 34 |
+
POINT(A) — declare a point
|
| 35 |
+
POINT(A, x, y, z) — declare a point with explicit coordinates
|
| 36 |
+
LENGTH(AB, 5) — distance between A and B is 5 (2D/3D)
|
| 37 |
+
ANGLE(A, 90) — interior angle at vertex A is 90° (2D/3D)
|
| 38 |
+
PARALLEL(AB, CD) — segment AB is parallel to CD (2D/3D)
|
| 39 |
+
PERPENDICULAR(AB, CD) — segment AB is perpendicular to CD (2D/3D)
|
| 40 |
+
MIDPOINT(M, AB) — M is the midpoint of segment AB
|
| 41 |
+
SECTION(E, A, C, k) — E satisfies vector AE = k * vector AC (k is decimal)
|
| 42 |
+
LINE(A, B) — infinite line passing through A and B
|
| 43 |
+
RAY(A, B) — ray starting at A and passing through B
|
| 44 |
+
CIRCLE(O, 5) — circle with center O and radius 5 (2D)
|
| 45 |
+
SPHERE(O, 5) — sphere with center O and radius 5 (3D)
|
| 46 |
+
SEGMENT(M, N) — auxiliary segment MN to be drawn
|
| 47 |
+
POLYGON_ORDER(A, B, C, D) — the order in which vertices form the polygon boundary
|
| 48 |
+
TRIANGLE(ABC) — equilateral/arbitrary triangle
|
| 49 |
+
PYRAMID(S_ABCD) — pyramid with apex S and base ABCD
|
| 50 |
+
PRISM(ABC_DEF) — triangular prism
|
| 51 |
+
|
| 52 |
+
=== RULES ===
|
| 53 |
+
1. 3D Coordinates: Use POINT(A, x, y, z) if specific coordinates are given in the problem.
|
| 54 |
+
2. Space Geometry: For pyramids/prisms, use the specialized commands.
|
| 55 |
+
3. Primary Vertices: Always declare the main vertices of the shape (e.g., A, B, C, D) using POINT(X).
|
| 56 |
+
4. POLYGON_ORDER: Always emit POLYGON_ORDER(...) for the main shape using ONLY these primary vertices.
|
| 57 |
+
5. All Points: EVERY point mentioned (A, B, C, H, M, etc.) MUST be declared with POINT(Name) first.
|
| 58 |
+
6. Altitudes/Perpendiculars: For an altitude AH to BC, use POINT(H) + PERPENDICULAR(AH, BC).
|
| 59 |
+
7. Format: Output ONLY DSL lines — NO explanation, NO markdown, NO code blocks.
|
| 60 |
+
|
| 61 |
+
=== SHAPE EXAMPLES ===
|
| 62 |
+
|
| 63 |
+
--- Case: Square Pyramid S.ABCD with side 10, height 15 ---
|
| 64 |
+
PYRAMID(S_ABCD)
|
| 65 |
+
POINT(A, 0, 0, 0)
|
| 66 |
+
POINT(B, 10, 0, 0)
|
| 67 |
+
POINT(C, 10, 10, 0)
|
| 68 |
+
POINT(D, 0, 10, 0)
|
| 69 |
+
POINT(S)
|
| 70 |
+
POINT(O)
|
| 71 |
+
SECTION(O, A, C, 0.5)
|
| 72 |
+
LENGTH(SO, 15)
|
| 73 |
+
PERPENDICULAR(SO, AC)
|
| 74 |
+
PERPENDICULAR(SO, AB)
|
| 75 |
+
POLYGON_ORDER(A, B, C, D)
|
| 76 |
+
|
| 77 |
+
--- Case: Right Triangle ABC at A, AB=3, AC=4, altitude AH ---
|
| 78 |
+
POLYGON_ORDER(A, B, C)
|
| 79 |
+
POINT(A)
|
| 80 |
+
POINT(B)
|
| 81 |
+
POINT(C)
|
| 82 |
+
POINT(H)
|
| 83 |
+
LENGTH(AB, 3)
|
| 84 |
+
LENGTH(AC, 4)
|
| 85 |
+
ANGLE(A, 90)
|
| 86 |
+
PERPENDICULAR(AH, BC)
|
| 87 |
+
SEGMENT(A, H)
|
| 88 |
+
|
| 89 |
+
--- Case: Rectangle ABCD with AB=5, AD=10 ---
|
| 90 |
+
POLYGON_ORDER(A, B, C, D)
|
| 91 |
+
POINT(A)
|
| 92 |
+
POINT(B)
|
| 93 |
+
POINT(C)
|
| 94 |
+
POINT(D)
|
| 95 |
+
LENGTH(AB, 5)
|
| 96 |
+
LENGTH(AD, 10)
|
| 97 |
+
PERPENDICULAR(AB, AD)
|
| 98 |
+
PARALLEL(AB, CD)
|
| 99 |
+
PARALLEL(AD, BC)
|
| 100 |
+
|
| 101 |
+
[Circle with center O radius 7]
|
| 102 |
+
POINT(O)
|
| 103 |
+
CIRCLE(O, 7)
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
user_content = f"Semantic Data: {json.dumps(semantic_data, ensure_ascii=False)}"
|
| 107 |
+
if previous_dsl:
|
| 108 |
+
user_content = f"PREVIOUS DSL:\n{previous_dsl}\n\nUPDATE WITH NEW DATA: {json.dumps(semantic_data, ensure_ascii=False)}"
|
| 109 |
+
|
| 110 |
+
logger.debug("[GeometryAgent] Calling LLM (Multi-Layer)...")
|
| 111 |
+
content = await self.llm.chat_completions_create(
|
| 112 |
+
messages=[
|
| 113 |
+
{"role": "system", "content": system_prompt},
|
| 114 |
+
{"role": "user", "content": user_content}
|
| 115 |
+
]
|
| 116 |
+
)
|
| 117 |
+
dsl = content.strip() if content else ""
|
| 118 |
+
logger.info(f"[GeometryAgent] DSL generated ({len(dsl.splitlines())} lines).")
|
| 119 |
+
logger.debug(f"[GeometryAgent] DSL output:\n{dsl}")
|
| 120 |
+
return dsl
|
agents/knowledge_agent.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import Dict, Any
|
| 3 |
+
|
| 4 |
+
logger = logging.getLogger(__name__)
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# ─── Shape rule registry ────────────────────────────────────────────────────
|
| 8 |
+
# Each entry: keyword list → augmentation function
|
| 9 |
+
# Augmentation receives (values: dict, text: str) and returns updated values dict.
|
| 10 |
+
|
| 11 |
+
class KnowledgeAgent:
|
| 12 |
+
"""Knowledge Agent: Stores geometric theorems and common patterns to augment Parser output."""
|
| 13 |
+
|
| 14 |
+
def augment_semantic_data(self, semantic_data: Dict[str, Any]) -> Dict[str, Any]:
|
| 15 |
+
logger.info("==[KnowledgeAgent] Augmenting semantic data==")
|
| 16 |
+
text = str(semantic_data.get("input_text", "")).lower()
|
| 17 |
+
logger.debug(f"[KnowledgeAgent] Input text for matching: '{text[:200]}'")
|
| 18 |
+
|
| 19 |
+
shape_type = self._detect_shape(text, semantic_data.get("type", ""))
|
| 20 |
+
if shape_type:
|
| 21 |
+
semantic_data["type"] = shape_type
|
| 22 |
+
values = semantic_data.get("values", {})
|
| 23 |
+
values = self._augment_values(shape_type, values, text)
|
| 24 |
+
semantic_data["values"] = values
|
| 25 |
+
else:
|
| 26 |
+
logger.info("[KnowledgeAgent] No special rule matched. Returning data unchanged.")
|
| 27 |
+
|
| 28 |
+
logger.debug(f"[KnowledgeAgent] Output semantic data: {semantic_data}")
|
| 29 |
+
return semantic_data
|
| 30 |
+
|
| 31 |
+
# ─── Shape detection ────────────────────────────────────────────────────
|
| 32 |
+
def _detect_shape(self, text: str, llm_type: str) -> str | None:
|
| 33 |
+
"""Detect shape from text keywords. LLM type provides a hint."""
|
| 34 |
+
checks = [
|
| 35 |
+
(["hình vuông", "square"], "square"),
|
| 36 |
+
(["hình chữ nhật", "rectangle"], "rectangle"),
|
| 37 |
+
(["hình thoi", "rhombus"], "rhombus"),
|
| 38 |
+
(["hình bình hành", "parallelogram"], "parallelogram"),
|
| 39 |
+
(["hình thang vuông"], "right_trapezoid"),
|
| 40 |
+
(["hình thang", "trapezoid", "trapezium"], "trapezoid"),
|
| 41 |
+
(["tam giác vuông", "right triangle"], "right_triangle"),
|
| 42 |
+
(["tam giác đều", "equilateral triangle", "equilateral"], "equilateral_triangle"),
|
| 43 |
+
(["tam giác cân", "isosceles"], "isosceles_triangle"),
|
| 44 |
+
(["tam giác", "triangle"], "triangle"),
|
| 45 |
+
(["đường tròn", "circle"], "circle"),
|
| 46 |
+
]
|
| 47 |
+
for keywords, shape in checks:
|
| 48 |
+
if any(kw in text for kw in keywords):
|
| 49 |
+
logger.info(f"[KnowledgeAgent] Rule MATCH: '{shape}' detected (keyword match).")
|
| 50 |
+
return shape
|
| 51 |
+
|
| 52 |
+
# Fallback: trust LLM-detected type if it's a known type
|
| 53 |
+
known = {
|
| 54 |
+
"rectangle", "square", "rhombus", "parallelogram",
|
| 55 |
+
"trapezoid", "right_trapezoid", "triangle", "right_triangle",
|
| 56 |
+
"equilateral_triangle", "isosceles_triangle", "circle",
|
| 57 |
+
}
|
| 58 |
+
if llm_type in known:
|
| 59 |
+
logger.info(f"[KnowledgeAgent] Using LLM-detected type '{llm_type}'.")
|
| 60 |
+
return llm_type
|
| 61 |
+
|
| 62 |
+
return None
|
| 63 |
+
|
| 64 |
+
# ─── Value augmentation ──────────────────────────────────────────────────
|
| 65 |
+
def _augment_values(self, shape: str, values: dict, text: str) -> dict:
|
| 66 |
+
ab = values.get("AB")
|
| 67 |
+
ad = values.get("AD")
|
| 68 |
+
bc = values.get("BC")
|
| 69 |
+
cd = values.get("CD")
|
| 70 |
+
|
| 71 |
+
if shape == "rectangle":
|
| 72 |
+
if ab and ad:
|
| 73 |
+
values.setdefault("CD", ab)
|
| 74 |
+
values.setdefault("BC", ad)
|
| 75 |
+
values.setdefault("angle_A", 90)
|
| 76 |
+
logger.info(f"[KnowledgeAgent] Rectangle: AB=CD={ab}, AD=BC={ad}, angle_A=90°")
|
| 77 |
+
else:
|
| 78 |
+
values.setdefault("angle_A", 90)
|
| 79 |
+
|
| 80 |
+
elif shape == "square":
|
| 81 |
+
side = ab or ad or bc or cd or values.get("side")
|
| 82 |
+
if side:
|
| 83 |
+
values.update({"AB": side, "AD": side, "angle_A": 90})
|
| 84 |
+
logger.info(f"[KnowledgeAgent] Square: side={side}, angle_A=90°")
|
| 85 |
+
else:
|
| 86 |
+
values.setdefault("angle_A", 90)
|
| 87 |
+
|
| 88 |
+
elif shape == "rhombus":
|
| 89 |
+
side = ab or values.get("side")
|
| 90 |
+
if side:
|
| 91 |
+
values.update({"AB": side, "BC": side, "CD": side, "DA": side})
|
| 92 |
+
logger.info(f"[KnowledgeAgent] Rhombus: all sides={side}")
|
| 93 |
+
|
| 94 |
+
elif shape == "parallelogram":
|
| 95 |
+
if ab:
|
| 96 |
+
values.setdefault("CD", ab)
|
| 97 |
+
if ad:
|
| 98 |
+
values.setdefault("BC", ad)
|
| 99 |
+
logger.info(f"[KnowledgeAgent] Parallelogram: AB||CD, AD||BC")
|
| 100 |
+
|
| 101 |
+
elif shape == "trapezoid":
|
| 102 |
+
logger.info("[KnowledgeAgent] Trapezoid: AB||CD (bottom||top)")
|
| 103 |
+
|
| 104 |
+
elif shape == "right_trapezoid":
|
| 105 |
+
logger.info("[KnowledgeAgent] Right trapezoid: AB||CD, AD⊥AB")
|
| 106 |
+
values.setdefault("angle_A", 90)
|
| 107 |
+
|
| 108 |
+
elif shape == "equilateral_triangle":
|
| 109 |
+
side = ab or values.get("side")
|
| 110 |
+
if side:
|
| 111 |
+
values.update({"AB": side, "BC": side, "CA": side, "angle_A": 60})
|
| 112 |
+
logger.info(f"[KnowledgeAgent] Equilateral triangle: all sides={side}, angle_A=60°")
|
| 113 |
+
|
| 114 |
+
elif shape == "right_triangle":
|
| 115 |
+
# Try to infer which vertex is the right angle
|
| 116 |
+
rt_vertex = _detect_right_angle_vertex(text)
|
| 117 |
+
values.setdefault(f"angle_{rt_vertex}", 90)
|
| 118 |
+
logger.info(f"[KnowledgeAgent] Right triangle: angle_{rt_vertex}=90°")
|
| 119 |
+
|
| 120 |
+
elif shape == "isosceles_triangle":
|
| 121 |
+
logger.info("[KnowledgeAgent] Isosceles triangle: AB=AC (default, LLM may override)")
|
| 122 |
+
|
| 123 |
+
elif shape == "circle":
|
| 124 |
+
logger.info("[KnowledgeAgent] Circle detected — no side augmentation needed.")
|
| 125 |
+
|
| 126 |
+
return values
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def _detect_right_angle_vertex(text: str) -> str:
|
| 130 |
+
"""Heuristic: detect which vertex is right angle from text."""
|
| 131 |
+
for vertex in ["A", "B", "C", "D"]:
|
| 132 |
+
patterns = [f"vuông tại {vertex}", f"góc {vertex} vuông", f"right angle at {vertex}"]
|
| 133 |
+
if any(p.lower() in text for p in patterns):
|
| 134 |
+
return vertex
|
| 135 |
+
return "A" # default
|
agents/ocr_agent.py
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import uuid
|
| 3 |
+
import logging
|
| 4 |
+
import asyncio
|
| 5 |
+
from typing import List, Dict, Any, Optional, Tuple
|
| 6 |
+
|
| 7 |
+
import cv2
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
_OCR_MAX_EDGE = 2000
|
| 13 |
+
_CROP_PAD = 4
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class ImprovedOCRAgent:
|
| 17 |
+
"""
|
| 18 |
+
Advanced OCR Agent using a hybrid pipeline:
|
| 19 |
+
1. YOLO for layout analysis (text vs formula).
|
| 20 |
+
2. PaddleOCR for Vietnamese text extraction.
|
| 21 |
+
3. Pix2Tex for LaTeX formula extraction.
|
| 22 |
+
4. Optional MegaLLM for semantic correction and formatting (skipped when ``skip_llm_refinement`` is True,
|
| 23 |
+
e.g. on the dedicated OCR Celery worker; the API Space runs ``refine_with_llm`` on the raw text).
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(self, skip_llm_refinement: bool = False):
|
| 27 |
+
self._skip_llm_refinement = bool(skip_llm_refinement)
|
| 28 |
+
logger.info("[ImprovedOCRAgent] Initializing engines (skip_llm_refinement=%s)...", self._skip_llm_refinement)
|
| 29 |
+
|
| 30 |
+
if self._skip_llm_refinement:
|
| 31 |
+
self.llm = None
|
| 32 |
+
logger.info("[ImprovedOCRAgent] LLM client skipped (raw OCR only).")
|
| 33 |
+
else:
|
| 34 |
+
from app.llm_client import get_llm_client
|
| 35 |
+
|
| 36 |
+
self.llm = get_llm_client()
|
| 37 |
+
logger.info("[ImprovedOCRAgent] Multi-Layer LLM Client initialized.")
|
| 38 |
+
|
| 39 |
+
try:
|
| 40 |
+
from agents.torch_ultralytics_compat import allow_ultralytics_weights
|
| 41 |
+
from ultralytics import YOLO
|
| 42 |
+
|
| 43 |
+
allow_ultralytics_weights()
|
| 44 |
+
logger.info("[ImprovedOCRAgent] Loading YOLO...")
|
| 45 |
+
self.layout_model = YOLO("yolov8n.pt")
|
| 46 |
+
logger.info("[ImprovedOCRAgent] YOLO initialized.")
|
| 47 |
+
except Exception as e:
|
| 48 |
+
logger.error("[ImprovedOCRAgent] YOLO init failed: %s", e)
|
| 49 |
+
self.layout_model = None
|
| 50 |
+
|
| 51 |
+
try:
|
| 52 |
+
from paddleocr import PaddleOCR
|
| 53 |
+
|
| 54 |
+
logger.info("[ImprovedOCRAgent] Loading PaddleOCR...")
|
| 55 |
+
self.text_model = PaddleOCR(use_angle_cls=True, lang="vi")
|
| 56 |
+
logger.info("[ImprovedOCRAgent] PaddleOCR (vi) initialized.")
|
| 57 |
+
except Exception as e:
|
| 58 |
+
logger.error("[ImprovedOCRAgent] PaddleOCR init failed: %s", e)
|
| 59 |
+
self.text_model = None
|
| 60 |
+
|
| 61 |
+
try:
|
| 62 |
+
from pix2tex.cli import LatexOCR
|
| 63 |
+
|
| 64 |
+
logger.info("[ImprovedOCRAgent] Loading Pix2Tex...")
|
| 65 |
+
self.math_model = LatexOCR()
|
| 66 |
+
logger.info("[ImprovedOCRAgent] Pix2Tex initialized.")
|
| 67 |
+
except Exception as e:
|
| 68 |
+
logger.error("[ImprovedOCRAgent] Pix2Tex init failed: %s", e)
|
| 69 |
+
self.math_model = None
|
| 70 |
+
|
| 71 |
+
def _preprocess_image_for_ocr(self, src_path: str) -> Tuple[str, bool]:
|
| 72 |
+
"""Resize large images, CLAHE on luminance; returns path (may be new temp file)."""
|
| 73 |
+
img = cv2.imread(src_path, cv2.IMREAD_COLOR)
|
| 74 |
+
if img is None:
|
| 75 |
+
g = cv2.imread(src_path, cv2.IMREAD_GRAYSCALE)
|
| 76 |
+
if g is None:
|
| 77 |
+
logger.warning("[ImprovedOCRAgent] OpenCV could not read %s; using original.", src_path)
|
| 78 |
+
return src_path, False
|
| 79 |
+
img = cv2.cvtColor(g, cv2.COLOR_GRAY2BGR)
|
| 80 |
+
h, w = img.shape[:2]
|
| 81 |
+
max_dim = max(h, w)
|
| 82 |
+
if max_dim > _OCR_MAX_EDGE:
|
| 83 |
+
scale = _OCR_MAX_EDGE / max_dim
|
| 84 |
+
img = cv2.resize(
|
| 85 |
+
img, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_AREA
|
| 86 |
+
)
|
| 87 |
+
logger.info("[ImprovedOCRAgent] Resized for OCR to max edge %s", _OCR_MAX_EDGE)
|
| 88 |
+
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
| 89 |
+
gray = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(gray)
|
| 90 |
+
den = cv2.fastNlMeansDenoising(gray, None, 8, 7, 21)
|
| 91 |
+
out = f"temp_ocr_prep_{uuid.uuid4().hex}.png"
|
| 92 |
+
cv2.imwrite(out, den)
|
| 93 |
+
return out, True
|
| 94 |
+
|
| 95 |
+
def _load_bgr_for_crops(self, path: str) -> Optional[np.ndarray]:
|
| 96 |
+
im = cv2.imread(path, cv2.IMREAD_COLOR)
|
| 97 |
+
if im is None:
|
| 98 |
+
g = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
|
| 99 |
+
if g is None:
|
| 100 |
+
return None
|
| 101 |
+
im = cv2.cvtColor(g, cv2.COLOR_GRAY2BGR)
|
| 102 |
+
return im
|
| 103 |
+
|
| 104 |
+
def _crop_from_quad(self, img_bgr: np.ndarray, bbox) -> Optional[np.ndarray]:
|
| 105 |
+
try:
|
| 106 |
+
pts = np.array(bbox, dtype=np.float32)
|
| 107 |
+
xs = pts[:, 0]
|
| 108 |
+
ys = pts[:, 1]
|
| 109 |
+
H, W = img_bgr.shape[:2]
|
| 110 |
+
x1 = max(0, int(xs.min()) - _CROP_PAD)
|
| 111 |
+
y1 = max(0, int(ys.min()) - _CROP_PAD)
|
| 112 |
+
x2 = min(W, int(xs.max()) + _CROP_PAD)
|
| 113 |
+
y2 = min(H, int(ys.max()) + _CROP_PAD)
|
| 114 |
+
if x2 <= x1 or y2 <= y1:
|
| 115 |
+
return None
|
| 116 |
+
return img_bgr[y1:y2, x1:x2].copy()
|
| 117 |
+
except Exception as e:
|
| 118 |
+
logger.debug("[ImprovedOCRAgent] crop failed: %s", e)
|
| 119 |
+
return None
|
| 120 |
+
|
| 121 |
+
def _latex_from_crop_bgr(self, crop_bgr: np.ndarray) -> Optional[str]:
|
| 122 |
+
if self.math_model is None or crop_bgr is None or crop_bgr.size == 0:
|
| 123 |
+
return None
|
| 124 |
+
ch, cw = crop_bgr.shape[:2]
|
| 125 |
+
if ch < 10 or cw < 10:
|
| 126 |
+
return None
|
| 127 |
+
try:
|
| 128 |
+
from PIL import Image
|
| 129 |
+
|
| 130 |
+
rgb = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB)
|
| 131 |
+
pil = Image.fromarray(rgb)
|
| 132 |
+
out = self.math_model(pil)
|
| 133 |
+
if isinstance(out, str) and out.strip():
|
| 134 |
+
return out.strip()
|
| 135 |
+
except Exception as e:
|
| 136 |
+
logger.debug("[ImprovedOCRAgent] Pix2Tex on crop failed: %s", e)
|
| 137 |
+
return None
|
| 138 |
+
|
| 139 |
+
def _maybe_math_from_crop(self, img_bgr: Optional[np.ndarray], bbox, text: str) -> str:
|
| 140 |
+
if img_bgr is None or not self.math_model:
|
| 141 |
+
return text
|
| 142 |
+
is_math_hint = any(
|
| 143 |
+
c in text for c in ["\\", "^", "_", "{", "}", "=", "+", "-", "*", "/"]
|
| 144 |
+
)
|
| 145 |
+
if not is_math_hint:
|
| 146 |
+
return text
|
| 147 |
+
crop = self._crop_from_quad(img_bgr, bbox)
|
| 148 |
+
latex = self._latex_from_crop_bgr(crop) if crop is not None else None
|
| 149 |
+
if latex:
|
| 150 |
+
logger.info("[ImprovedOCRAgent] Pix2Tex replaced line fragment (len=%s)", len(latex))
|
| 151 |
+
return f"${latex}$"
|
| 152 |
+
return text
|
| 153 |
+
|
| 154 |
+
async def process_image(self, image_path: str) -> str:
|
| 155 |
+
logger.info("==[ImprovedOCRAgent] Processing: %s==", image_path)
|
| 156 |
+
|
| 157 |
+
if not os.path.exists(image_path):
|
| 158 |
+
return f"Error: File {image_path} not found."
|
| 159 |
+
|
| 160 |
+
prep_path, prep_cleanup = self._preprocess_image_for_ocr(image_path)
|
| 161 |
+
paddle_path = prep_path if prep_cleanup else image_path
|
| 162 |
+
img_bgr = self._load_bgr_for_crops(prep_path if prep_cleanup else image_path)
|
| 163 |
+
|
| 164 |
+
raw_fragments: List[Dict[str, Any]] = []
|
| 165 |
+
|
| 166 |
+
try:
|
| 167 |
+
if self.text_model:
|
| 168 |
+
logger.info("[ImprovedOCRAgent] Running PaddleOCR on %s...", paddle_path)
|
| 169 |
+
result = self.text_model.ocr(paddle_path)
|
| 170 |
+
logger.info("[ImprovedOCRAgent] PaddleOCR raw result: %s", result)
|
| 171 |
+
|
| 172 |
+
if not result:
|
| 173 |
+
logger.warning("[ImprovedOCRAgent] PaddleOCR returned no results.")
|
| 174 |
+
return ""
|
| 175 |
+
|
| 176 |
+
if isinstance(result[0], dict):
|
| 177 |
+
res_dict = result[0]
|
| 178 |
+
rec_texts = res_dict.get("rec_texts", [])
|
| 179 |
+
rec_scores = res_dict.get("rec_scores", [])
|
| 180 |
+
rec_polys = res_dict.get("rec_polys", [])
|
| 181 |
+
|
| 182 |
+
for i in range(len(rec_texts)):
|
| 183 |
+
text = rec_texts[i]
|
| 184 |
+
bbox = rec_polys[i]
|
| 185 |
+
score = rec_scores[i] if i < len(rec_scores) else None
|
| 186 |
+
if score is not None and float(score) < 0.45:
|
| 187 |
+
logger.debug(
|
| 188 |
+
"[ImprovedOCRAgent] Low-confidence line (score=%s): %s",
|
| 189 |
+
score,
|
| 190 |
+
text[:80],
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
y_top = int(min(p[1] for p in bbox)) if hasattr(bbox, "__iter__") else 0
|
| 194 |
+
content = self._maybe_math_from_crop(img_bgr, bbox, text)
|
| 195 |
+
raw_fragments.append({"y": y_top, "content": content, "type": "text"})
|
| 196 |
+
elif isinstance(result[0], list):
|
| 197 |
+
for line in result[0]:
|
| 198 |
+
bbox = line[0]
|
| 199 |
+
text = line[1][0]
|
| 200 |
+
score = line[1][1] if len(line[1]) > 1 else None
|
| 201 |
+
if score is not None and float(score) < 0.45:
|
| 202 |
+
logger.debug(
|
| 203 |
+
"[ImprovedOCRAgent] Low-confidence line (score=%s): %s",
|
| 204 |
+
score,
|
| 205 |
+
text[:80],
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
y_top = int(bbox[0][1])
|
| 209 |
+
content = self._maybe_math_from_crop(img_bgr, bbox, text)
|
| 210 |
+
raw_fragments.append({"y": y_top, "content": content, "type": "text"})
|
| 211 |
+
finally:
|
| 212 |
+
if prep_cleanup and os.path.exists(prep_path):
|
| 213 |
+
try:
|
| 214 |
+
os.remove(prep_path)
|
| 215 |
+
except OSError:
|
| 216 |
+
pass
|
| 217 |
+
|
| 218 |
+
raw_fragments.sort(key=lambda x: x["y"])
|
| 219 |
+
combined_text = "\n".join([f["content"] for f in raw_fragments])
|
| 220 |
+
|
| 221 |
+
logger.info(
|
| 222 |
+
"[ImprovedOCRAgent] Raw OCR output assembled:\n---\n%s\n---", combined_text
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
if not combined_text.strip():
|
| 226 |
+
logger.warning("[ImprovedOCRAgent] No text detected to refine.")
|
| 227 |
+
return ""
|
| 228 |
+
|
| 229 |
+
if self._skip_llm_refinement or self.llm is None:
|
| 230 |
+
logger.info("[ImprovedOCRAgent] Skipping MegaLLM refinement (raw OCR output).")
|
| 231 |
+
return combined_text
|
| 232 |
+
|
| 233 |
+
try:
|
| 234 |
+
logger.info("[ImprovedOCRAgent] Sending to MegaLLM for refinement...")
|
| 235 |
+
refined_text = await asyncio.wait_for(
|
| 236 |
+
self.refine_with_llm(combined_text), timeout=30.0
|
| 237 |
+
)
|
| 238 |
+
return refined_text
|
| 239 |
+
except asyncio.TimeoutError:
|
| 240 |
+
logger.error("[ImprovedOCRAgent] MegaLLM refinement timed out.")
|
| 241 |
+
return combined_text
|
| 242 |
+
except Exception as e:
|
| 243 |
+
logger.error("[ImprovedOCRAgent] MegaLLM refinement failed: %s", e)
|
| 244 |
+
return combined_text
|
| 245 |
+
|
| 246 |
+
async def refine_with_llm(self, text: str) -> str:
|
| 247 |
+
if not text.strip():
|
| 248 |
+
return ""
|
| 249 |
+
if self.llm is None:
|
| 250 |
+
logger.warning("[ImprovedOCRAgent] refine_with_llm: no LLM client; returning raw text.")
|
| 251 |
+
return text
|
| 252 |
+
|
| 253 |
+
prompt = f"""Bạn là một chuyên gia số hóa tài liệu toán học.
|
| 254 |
+
Dưới đây là kết quả OCR thô từ một trang sách toán Tiếng Việt.
|
| 255 |
+
Kết quả này có thể chứa lỗi chính tả, lỗi định dạng mã LaTeX, hoặc bị ngắt quãng không logic.
|
| 256 |
+
|
| 257 |
+
Nhiệm vụ của bạn:
|
| 258 |
+
1. Sửa lỗi chính tả tiếng Việt.
|
| 259 |
+
2. Đảm bảo các công thức toán học được viết đúng định dạng LaTeX và nằm trong cặp dấu $...$.
|
| 260 |
+
3. Giữ nguyên cấu trúc logic của bài toán.
|
| 261 |
+
4. Trả về nội dung đã được làm sạch dưới dạng Markdown.
|
| 262 |
+
|
| 263 |
+
Nội dung OCR thô:
|
| 264 |
+
---
|
| 265 |
+
{text}
|
| 266 |
+
---
|
| 267 |
+
|
| 268 |
+
Kết quả làm sạch:"""
|
| 269 |
+
|
| 270 |
+
try:
|
| 271 |
+
refined = await self.llm.chat_completions_create(
|
| 272 |
+
messages=[{"role": "user", "content": prompt}],
|
| 273 |
+
temperature=0.1,
|
| 274 |
+
)
|
| 275 |
+
logger.info("[ImprovedOCRAgent] LLM refinement complete.")
|
| 276 |
+
return refined
|
| 277 |
+
except Exception as e:
|
| 278 |
+
logger.error("[ImprovedOCRAgent] LLM refinement failed: %s", e)
|
| 279 |
+
return text
|
| 280 |
+
|
| 281 |
+
async def process_url(self, url: str) -> str:
|
| 282 |
+
import httpx
|
| 283 |
+
|
| 284 |
+
from app.url_utils import sanitize_url
|
| 285 |
+
|
| 286 |
+
url = sanitize_url(url)
|
| 287 |
+
if not url:
|
| 288 |
+
return "Error: Empty image URL after cleanup."
|
| 289 |
+
|
| 290 |
+
async with httpx.AsyncClient() as client:
|
| 291 |
+
resp = await client.get(url)
|
| 292 |
+
if resp.status_code == 200:
|
| 293 |
+
temp_path = "temp_url_image.png"
|
| 294 |
+
with open(temp_path, "wb") as f:
|
| 295 |
+
f.write(resp.content)
|
| 296 |
+
try:
|
| 297 |
+
return await self.process_image(temp_path)
|
| 298 |
+
finally:
|
| 299 |
+
if os.path.exists(temp_path):
|
| 300 |
+
os.remove(temp_path)
|
| 301 |
+
return f"Error: Failed to fetch image from URL {url}"
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
class OCRAgent(ImprovedOCRAgent):
    """Backward-compatible alias kept so existing imports of ``OCRAgent`` keep working."""
|
agents/orchestrator.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
from typing import Any, Dict
|
| 4 |
+
|
| 5 |
+
from agents.geometry_agent import GeometryAgent
|
| 6 |
+
from agents.knowledge_agent import KnowledgeAgent
|
| 7 |
+
from agents.ocr_agent import OCRAgent
|
| 8 |
+
from agents.parser_agent import ParserAgent
|
| 9 |
+
from agents.renderer_agent import RendererAgent
|
| 10 |
+
from agents.solver_agent import SolverAgent
|
| 11 |
+
from app.logutil import log_step
|
| 12 |
+
from app.ocr_celery import ocr_from_image_url
|
| 13 |
+
from solver.dsl_parser import DSLParser
|
| 14 |
+
from solver.engine import GeometryEngine
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
_CLIP = 2000
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _clip(val: Any, n: int = _CLIP) -> str | None:
    """Render *val* as text and truncate it to at most *n* chars (appending '…').

    Strings pass through as-is; anything else is JSON-serialized
    (non-serializable objects fall back to ``str``). ``None`` stays ``None``.
    """
    if val is None:
        return None
    text = val if isinstance(val, str) else json.dumps(val, ensure_ascii=False, default=str)
    if len(text) > n:
        return text[:n] + "…"
    return text
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _step_io(step: str, input_val: Any = None, output_val: Any = None) -> None:
    """Log only a step's (clipped) input/output, avoiding unnecessarily long dumps."""
    # Both values go through _clip so huge payloads never flood the structured log.
    log_step(step, input=_clip(input_val), output=_clip(output_val))
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class Orchestrator:
    """Coordinates the full solve pipeline: OCR → parse → knowledge → DSL →
    geometric solve (with retries) → math solution → step description."""

    def __init__(self):
        # One instance of every pipeline agent; the orchestrator owns them all.
        self.parser_agent = ParserAgent()
        self.geometry_agent = GeometryAgent()
        self.ocr_agent = OCRAgent()
        self.knowledge_agent = KnowledgeAgent()
        self.renderer_agent = RendererAgent()
        self.solver_agent = SolverAgent()
        self.solver_engine = GeometryEngine()
        self.dsl_parser = DSLParser()

    def _generate_step_description(self, semantic_json: Dict[str, Any], engine_result: Dict[str, Any]) -> str:
        """Build a Markdown construction-steps description from the engine result.

        (Original docstring: "Tạo mô tả từng bước vẽ dựa trên kết quả của engine.")
        Output text itself is Vietnamese, matching the product's user language.
        """
        analysis = semantic_json.get("analysis", "")
        if not analysis:
            analysis = f"Giải bài toán về {semantic_json.get('type', 'hình học')}."

        steps = ["\n\n**Các bước dựng hình:**"]
        drawing_phases = engine_result.get("drawing_phases", [])

        for phase in drawing_phases:
            # NOTE(review): phase["phase"] is assumed present when "label" is
            # missing — an unlabeled phase without that key would raise KeyError.
            label = phase.get("label", f"Giai đoạn {phase['phase']}")
            points = ", ".join(phase.get("points", []))
            segments = ", ".join([f"{s[0]}{s[1]}" for s in phase.get("segments", [])])

            step_text = f"- **{label}**:"
            if points:
                step_text += f" Xác định các điểm {points}."
            if segments:
                step_text += f" Vẽ các đoạn thẳng {segments}."
            steps.append(step_text)

        circles = engine_result.get("circles", [])
        for c in circles:
            steps.append(f"- **Đường tròn**: Vẽ đường tròn tâm {c['center']} bán kính {c['radius']}.")

        return analysis + "\n".join(steps)

    async def run(
        self,
        text: str,
        image_url: str | None = None,
        job_id: str | None = None,
        session_id: str | None = None,
        status_callback=None,
        history: list | None = None,
    ) -> Dict[str, Any]:
        """
        Run the full pipeline. Optional history allows context-aware solving.

        Returns either a success payload (coordinates, DSL, solution, ...) or
        an ``{"error": ..., "last_dsl": ...}`` dict when the solver exhausts
        its retries. ``status_callback``, when given, is awaited with coarse
        status strings ("processing", "solving") for WebSocket progress updates.
        """
        _step_io(
            "orchestrate_start",
            input_val={
                "job_id": job_id,
                "text_len": len(text or ""),
                "image_url": image_url,
                "history_len": len(history or []),
            },
            output_val=None,
        )

        if status_callback:
            await status_callback("processing")

        # 1. Extract context from history (if any)
        previous_context = None
        if history:
            # Look for the last assistant message with geometry data
            for msg in reversed(history):
                if msg.get("role") == "assistant" and msg.get("metadata", {}).get("geometry_dsl"):
                    previous_context = {
                        "geometry_dsl": msg["metadata"]["geometry_dsl"],
                        "coordinates": msg["metadata"].get("coordinates", {}),
                        "analysis": msg.get("content", ""),
                    }
                    break

        if previous_context:
            _step_io("context_found", input_val=None, output_val={"dsl_len": len(previous_context["geometry_dsl"])})

        # 2. Gather input text (OCR or direct)
        input_text = text
        if image_url:
            # Image input takes precedence: OCR output replaces the raw text.
            input_text = await ocr_from_image_url(image_url, self.ocr_agent)
            _step_io("step1_ocr", input_val=image_url, output_val=input_text)
        else:
            _step_io("step1_ocr", input_val="(no image)", output_val=text)

        feedback = None
        # Up to MAX_RETRIES+1 parse→DSL→solve attempts; solver failures feed
        # back into the parser prompt on the next iteration.
        MAX_RETRIES = 2

        for attempt in range(MAX_RETRIES + 1):
            _step_io(
                "attempt",
                input_val=f"{attempt + 1}/{MAX_RETRIES + 1}",
                output_val=None,
            )
            if status_callback:
                await status_callback("solving")

            # Parser with context
            _step_io("step2_parse", input_val=f"{input_text[:50]}...", output_val=None)
            semantic_json = await self.parser_agent.process(input_text, feedback=feedback, context=previous_context)
            semantic_json["input_text"] = input_text
            _step_io("step2_parse", input_val=None, output_val=semantic_json)

            # Knowledge augmentation
            _step_io("step3_knowledge", input_val=semantic_json, output_val=None)
            semantic_json = self.knowledge_agent.augment_semantic_data(semantic_json)
            _step_io("step3_knowledge", input_val=None, output_val=semantic_json)

            # Geometry DSL with context (passing previous DSL to guide generation)
            _step_io("step4_geometry_dsl", input_val=semantic_json, output_val=None)
            dsl_code = await self.geometry_agent.generate_dsl(
                semantic_json,
                previous_dsl=previous_context["geometry_dsl"] if previous_context else None
            )
            _step_io("step4_geometry_dsl", input_val=None, output_val=dsl_code)

            _step_io("step5_dsl_parse", input_val=dsl_code, output_val=None)
            points, constraints, is_3d = self.dsl_parser.parse(dsl_code)
            _step_io(
                "step5_dsl_parse",
                input_val=None,
                output_val={
                    "points": len(points),
                    "constraints": len(constraints),
                    "is_3d": is_3d,
                },
            )

            _step_io("step6_solve", input_val=f"{len(points)} pts / {len(constraints)} cons (is_3d={is_3d})", output_val=None)
            # The numeric solve is CPU-bound; run it off the event loop.
            import anyio
            engine_result = await anyio.to_thread.run_sync(self.solver_engine.solve, points, constraints, is_3d)

            if engine_result:
                coordinates = engine_result.get("coordinates")
                _step_io("step6_solve", input_val=None, output_val=coordinates)
                break

            # Solver failed: craft feedback for the parser and retry (or abort).
            feedback = "Geometry solver failed to find a valid solution for the given constraints. Parallelism or lengths might be inconsistent."
            _step_io(
                "step6_solve",
                input_val=f"attempt {attempt + 1}",
                output_val=feedback,
            )
            if attempt == MAX_RETRIES:
                _step_io(
                    "orchestrate_abort",
                    input_val=None,
                    output_val="solver_exhausted_retries",
                )
                return {
                    "error": "Solver failed after multiple attempts.",
                    "last_dsl": dsl_code,
                }

        _step_io("orchestrate_done", input_val=job_id, output_val="success")

        # 8. Solution calculation (New in v5.1)
        solution = None
        if engine_result:
            _step_io("step8_solve_math", input_val=semantic_json.get("target_question"), output_val=None)
            solution = await self.solver_agent.solve(semantic_json, engine_result)
            _step_io("step8_solve_math", input_val=None, output_val=solution.get("answer"))

        final_analysis = self._generate_step_description(semantic_json, engine_result)

        status = "success"
        return {
            "status": status,
            "job_id": job_id,
            "geometry_dsl": dsl_code,
            "coordinates": coordinates,
            "polygon_order": engine_result.get("polygon_order", []),
            "circles": engine_result.get("circles", []),
            "lines": engine_result.get("lines", []),
            "rays": engine_result.get("rays", []),
            "drawing_phases": engine_result.get("drawing_phases", []),
            "semantic": semantic_json,
            "semantic_analysis": final_analysis,
            "solution": solution,
            "is_3d": is_3d,
        }
|
agents/parser_agent.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
from openai import AsyncOpenAI
|
| 5 |
+
from typing import Dict, Any
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
|
| 8 |
+
load_dotenv()
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
from app.url_utils import openai_compatible_api_key, sanitize_env
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
from app.llm_client import get_llm_client
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ParserAgent:
    """LLM-backed agent that extracts a semantic JSON description
    (entities, type, values, target question, analysis) from raw
    Vietnamese/LaTeX geometry-problem text."""

    def __init__(self):
        # Shared multi-provider LLM client.
        self.llm = get_llm_client()

    @staticmethod
    def _extract_json(raw: str):
        """Best-effort extraction of a JSON object from an LLM response.

        Handles markdown code fences and single-quoted pseudo-JSON.
        Returns the parsed dict, or None when nothing parseable is found.
        """
        import re

        clean_raw = raw.strip()
        # Strip markdown code fences (```json ... ```), a common LLM wrapper.
        if clean_raw.startswith("```"):
            match = re.search(r"```(?:json)?\s*(.*?)\s*```", clean_raw, re.DOTALL)
            if match:
                clean_raw = match.group(1).strip()

        try:
            return json.loads(clean_raw)
        except json.JSONDecodeError as e:
            logger.error(f"[ParserAgent] JSON Parse Error: {e}. Attempting regex fallback...")

        # Fallback: grab the outermost {...} span from surrounding chatter.
        json_match = re.search(r'(\{.*\})', clean_raw, re.DOTALL)
        if not json_match:
            return None
        json_str = json_match.group(1)
        # Handle single quotes if present (common LLM failure)
        if "'" in json_str and '"' not in json_str:
            json_str = json_str.replace("'", '"')
        try:
            return json.loads(json_str)
        except json.JSONDecodeError:
            return None

    async def process(self, text: str, feedback: str = None, context: Dict[str, Any] = None) -> Dict[str, Any]:
        """Parse *text* into the semantic dict used by the rest of the pipeline.

        feedback: error message from a failed previous attempt, fed back so the
            LLM can correct its constraints.
        context: previous analysis for follow-up requests (context-aware mode).
        Never raises on malformed LLM output — falls back to a minimal dict
        whose "analysis" is the raw input text.
        """
        logger.info(f"==[ParserAgent] Processing input (len={len(text)})==")
        if feedback:
            logger.warning(f"[ParserAgent] Feedback from previous attempt: {feedback}")
        if context:
            logger.info(f"[ParserAgent] Using previous context (dsl_len={len(context.get('geometry_dsl', ''))})")

        system_prompt = """
You are a Geometry Parser Agent. Extract geometric entities and constraints from Vietnamese/LaTeX math problem text.

=== CONTEXT AWARENESS ===
If previous context is provided, it means this is a follow-up request.
- Combine old entities with new ones.
- Update 'analysis' to reflect the entire problem state.

Output ONLY a JSON object with this EXACT structure (no extra keys, no markdown):
{
"entities": ["Point A", "Point B", ...],
"type": "pyramid|prism|sphere|rectangle|triangle|circle|parallelogram|trapezoid|square|rhombus|general",
"values": {"AB": 5, "SO": 15, "radius": 3},
"target_question": "Câu hỏi cụ thể cần giải (ví dụ: 'Tính diện tích tam giác ABC'). NẾU KHÔNG CÓ CÂU HỎI THÌ ĐỂ null.",
"analysis": "Tóm tắt ngắn gọn toàn bộ bài toán sau khi đã cập nhật các yêu cầu mới bằng tiếng Việt."
}
Rules:
- "analysis" MUST be a meaningful and UP-TO-DATE summary of the problem in Vietnamese.
- "target_question" must be concise.
- Include midpoints, auxiliary points in "entities" if mentioned.
- If feedback is provided, correct your previous output accordingly.
"""

        user_content = f"Text: {text}"
        if context:
            user_content = f"PREVIOUS ANALYSIS: {context.get('analysis')}\nNEW REQUEST: {text}"

        if feedback:
            user_content += f"\nFeedback from previous attempt: {feedback}. Please correct the constraints."

        logger.debug("[ParserAgent] Calling LLM (Multi-Layer)...")
        raw = await self.llm.chat_completions_create(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_content}
            ],
            response_format={"type": "json_object"}
        )

        result = self._extract_json(raw)

        if not result:
            # Fallback for critical failure: keep the pipeline alive with a
            # minimal, well-formed semantic dict.
            result = {
                "entities": [],
                "type": "general",
                "values": {},
                "target_question": None,
                "analysis": text
            }
        logger.info(f"[ParserAgent] LLM response received.")
        logger.debug(f"[ParserAgent] Parsed JSON: {json.dumps(result, ensure_ascii=False, indent=2)}")
        return result
|
agents/renderer_agent.py
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import subprocess
|
| 3 |
+
import glob
|
| 4 |
+
import string
|
| 5 |
+
import logging
|
| 6 |
+
from typing import Dict, Any, List
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class RendererAgent:
    """
    Renderer Agent — generates Manim scripts from geometry data.

    Drawing happens in phases:
        Phase 1: Main polygon (base shape with correct vertex order)
        Phase 2: Auxiliary points and segments (midpoints, derived segments)
        Phase 3: Labels for all points
    """

    def generate_manim_script(self, data: Dict[str, Any]) -> str:
        """Build the source of a Manim scene (class ``GeometryScene``) from the
        orchestrator's result dict (coordinates, polygon_order, circles, ...).

        Returns the script as a single string; never touches the filesystem.
        """
        coords: Dict[str, List[float]] = data.get("coordinates", {})
        polygon_order: List[str] = data.get("polygon_order", [])
        circles_meta: List[Dict] = data.get("circles", [])
        lines_meta: List[List[str]] = data.get("lines", [])
        rays_meta: List[List[str]] = data.get("rays", [])
        drawing_phases: List[Dict] = data.get("drawing_phases", [])
        semantic: Dict[str, Any] = data.get("semantic", {})
        shape_type = semantic.get("type", "").lower()

        # ── Detect 3D Context ────────────────────────────────────────────────
        # 3D if any point has a non-negligible z, or the semantic type is a solid.
        is_3d = False
        for pos in coords.values():
            if len(pos) >= 3 and abs(pos[2]) > 0.001:
                is_3d = True
                break
        if shape_type in ["pyramid", "prism", "sphere"]:
            is_3d = True

        # ── Fallback: infer polygon_order from coords keys (alphabetical uppercase) ──
        if not polygon_order:
            polygon_order = sorted(
                [pid for pid in coords if pid in string.ascii_uppercase],
                key=lambda p: string.ascii_uppercase.index(p)
            )

        # Separate base points from derived (multi-char or lowercase)
        base_ids = [pid for pid in polygon_order if pid in coords]
        derived_ids = [pid for pid in coords if pid not in polygon_order]

        scene_base = "ThreeDScene" if is_3d else "MovingCameraScene"
        lines = [
            "from manim import *",
            "",
            f"class GeometryScene({scene_base}):",
            "    def construct(self):",
        ]

        if is_3d:
            lines.append("        # 3D Setup")
            lines.append("        self.set_camera_orientation(phi=75*DEGREES, theta=-45*DEGREES)")
            lines.append("        axes = ThreeDAxes(axis_config={'stroke_width': 1})")
            lines.append("        axes.set_opacity(0.3)")
            lines.append("        self.add(axes)")
            lines.append("        self.begin_ambient_camera_rotation(rate=0.1)")
            lines.append("")

        # ── Declare all dots and labels ───────────────────────────────────────
        for pid, pos in coords.items():
            x, y, z = 0, 0, 0
            if len(pos) >= 1:
                x = round(pos[0], 4)
            if len(pos) >= 2:
                y = round(pos[1], 4)
            if len(pos) >= 3:
                z = round(pos[2], 4)

            dot_class = "Dot3D" if is_3d else "Dot"
            lines.append(f"        p_{pid} = {dot_class}(point=[{x}, {y}, {z}], color=WHITE, radius=0.08)")

            if is_3d:
                lines.append(
                    f"        l_{pid} = Text('{pid}', font_size=20, color=WHITE)"
                    f".move_to(p_{pid}.get_center() + [0.2, 0.2, 0.2])"
                )
                # Ensure labels follow camera in 3D (fixed orientation)
                lines.append(f"        self.add_fixed_orientation_mobjects(l_{pid})")
            else:
                lines.append(
                    f"        l_{pid} = Text('{pid}', font_size=22, color=WHITE)"
                    f".next_to(p_{pid}, UR, buff=0.15)"
                )

        # ── 3D Shape Special: Pyramid/Prism Faces ────────────────────────────
        if is_3d and shape_type == "pyramid" and len(base_ids) >= 3:
            # Find apex (usually 'S')
            apex_id = "S" if "S" in coords else derived_ids[0] if derived_ids else None
            if apex_id:
                # Draw base face
                base_pts = ", ".join([f"p_{pid}.get_center()" for pid in base_ids])
                lines.append(f"        base_face = Polygon({base_pts}, color=BLUE, fill_opacity=0.1)")
                lines.append("        self.play(Create(base_face), run_time=1.0)")

                # Draw side faces (apex + each consecutive base edge)
                for i in range(len(base_ids)):
                    p1 = base_ids[i]
                    p2 = base_ids[(i + 1) % len(base_ids)]
                    face_pts = f"p_{apex_id}.get_center(), p_{p1}.get_center(), p_{p2}.get_center()"
                    lines.append(f"        side_{i} = Polygon({face_pts}, color=BLUE, stroke_width=1, fill_opacity=0.05)")
                    lines.append(f"        self.play(Create(side_{i}), run_time=0.5)")

        # ── Circles ──────────────────────────────────────────────────────────
        for i, c in enumerate(circles_meta):
            center = c["center"]
            r = c["radius"]
            if center in coords:
                cx, cy, cz = 0, 0, 0
                pos = coords[center]
                if len(pos) >= 1:
                    cx = round(pos[0], 4)
                if len(pos) >= 2:
                    cy = round(pos[1], 4)
                if len(pos) >= 3:
                    cz = round(pos[2], 4)
                lines.append(
                    f"        circle_{i} = Circle(radius={r}, color=BLUE)"
                    f".move_to([{cx}, {cy}, {cz}])"
                )

        # ── Infinite Lines & Rays ────────────────────────────────────────────
        # (Standard Line works for 3D coordinates in Manim)
        for i, (p1, p2) in enumerate(lines_meta):
            if p1 in coords and p2 in coords:
                lines.append(
                    f"        line_ext_{i} = Line(p_{p1}.get_center(), p_{p2}.get_center(), color=GRAY_D, stroke_width=2)"
                    f".scale(20)"
                )

        for i, (p1, p2) in enumerate(rays_meta):
            if p1 in coords and p2 in coords:
                lines.append(
                    f"        ray_{i} = Line(p_{p1}.get_center(), p_{p1}.get_center() + 15 * (p_{p2}.get_center() - p_{p1}.get_center()),"
                    f" color=GRAY_C, stroke_width=2)"
                )

        # ── Camera auto-fit group (Only for 2D) ──────────────────────────────
        if not is_3d:
            all_names_str = ", ".join([f"p_{pid}" for pid in coords])
            lines.append(f"        _all = VGroup({all_names_str})")
            lines.append("        self.camera.frame.set_width(max(_all.width * 2.0, 8))")
            lines.append("        self.camera.frame.move_to(_all)")
            lines.append("")

        # ── Phase 1: Base polygon ─────────────────────────────────────────────
        if len(base_ids) >= 3:
            pts_str = ", ".join([f"p_{pid}.get_center()" for pid in base_ids])
            lines.append(f"        poly = Polygon({pts_str}, color=BLUE, fill_color=BLUE, fill_opacity=0.15)")
            lines.append("        self.play(Create(poly), run_time=1.5)")
        elif len(base_ids) == 2:
            p1, p2 = base_ids
            lines.append(f"        base_line = Line(p_{p1}.get_center(), p_{p2}.get_center(), color=BLUE)")
            lines.append("        self.play(Create(base_line), run_time=1.0)")

        # Draw base points
        if base_ids:
            base_dots_str = ", ".join([f"p_{pid}" for pid in base_ids])
            lines.append(f"        self.play(FadeIn(VGroup({base_dots_str})), run_time=0.5)")
            lines.append("        self.wait(0.5)")

        # ── Phase 2: Auxiliary points and segments ────────────────────────────
        if derived_ids:
            derived_dots_str = ", ".join([f"p_{pid}" for pid in derived_ids])
            lines.append(f"        self.play(FadeIn(VGroup({derived_dots_str})), run_time=0.8)")

        # Segments from drawing_phases
        segment_lines = []
        for phase in drawing_phases:
            if phase.get("phase") == 2:
                for seg in phase.get("segments", []):
                    if len(seg) == 2 and seg[0] in coords and seg[1] in coords:
                        p1, p2 = seg[0], seg[1]
                        seg_var = f"seg_{p1}_{p2}"
                        lines.append(
                            f"        {seg_var} = Line(p_{p1}.get_center(), p_{p2}.get_center(),"
                            f" color=YELLOW)"
                        )
                        segment_lines.append(seg_var)

        if segment_lines:
            segs_str = ", ".join([f"Create({sv})" for sv in segment_lines])
            lines.append(f"        self.play({segs_str}, run_time=1.2)")

        if derived_ids or segment_lines:
            lines.append("        self.wait(0.5)")

        # ── Phase 3: All labels ───────────────────────────────────────────────
        all_labels_str = ", ".join([f"l_{pid}" for pid in coords])
        lines.append(f"        self.play(FadeIn(VGroup({all_labels_str})), run_time=0.8)")

        # ── Circles phase ─────────────────────────────────────────────────────
        for i in range(len(circles_meta)):
            lines.append(f"        self.play(Create(circle_{i}), run_time=1.5)")

        # ── Lines & Rays phase ────────────────────────────────────────────────
        if lines_meta or rays_meta:
            lr_anims = [f"Create(line_ext_{i})" for i in range(len(lines_meta))]
            lr_anims += [f"Create(ray_{i})" for i in range(len(rays_meta))]
            lines.append(f"        self.play({', '.join(lr_anims)}, run_time=1.5)")

        lines.append("        self.wait(2)")

        return "\n".join(lines)

    def run_manim(self, script_content: str, job_id: str) -> str:
        """Write *script_content* to a temp .py file, render it with Manim and
        return the path of the produced MP4 ("" on failure).

        Honors MOCK_VIDEO=true (CI/testing): writes a dummy file instead of
        rendering. The script file is always removed afterwards.
        """
        # NOTE: subprocess/os/glob are imported at module level; the previous
        # redundant local re-imports were removed.
        script_file = f"{job_id}.py"
        # Explicit UTF-8 so unicode point labels can't break on locale encodings.
        with open(script_file, "w", encoding="utf-8") as f:
            f.write(script_content)

        try:
            if os.getenv("MOCK_VIDEO") == "true":
                logger.info(f"MOCK_VIDEO is true. Skipping Manim for job {job_id}")
                # Create a dummy file if needed, or just return a path that exists
                dummy_path = f"videos/{job_id}.mp4"
                os.makedirs("videos", exist_ok=True)
                with open(dummy_path, "wb") as f:
                    f.write(b"dummy video content")
                return dummy_path

            # Determine manim executable path (prefer the project venv's binary)
            manim_exe = "manim"
            venv_manim = os.path.join(os.getcwd(), "venv", "bin", "manim")
            if os.path.exists(venv_manim):
                manim_exe = venv_manim

            # Prepare environment with homebrew paths (ffmpeg/latex on macOS)
            custom_env = os.environ.copy()
            brew_path = "/opt/homebrew/bin:/usr/local/bin"
            custom_env["PATH"] = f"{brew_path}:{custom_env.get('PATH', '')}"

            logger.info(f"Running {manim_exe} for job {job_id}...")
            result = subprocess.run(
                [manim_exe, "-ql", "--media_dir", ".", "-o", f"{job_id}.mp4", script_file, "GeometryScene"],
                capture_output=True,
                text=True,
                env=custom_env
            )
            logger.info(f"Manim STDOUT: {result.stdout}")
            if result.returncode != 0:
                logger.error(f"Manim STDERR: {result.stderr}")

            # Manim's output location varies by version/config — search for it.
            for pattern in [f"**/videos/**/{job_id}.mp4", f"**/{job_id}*.mp4"]:
                found = glob.glob(pattern, recursive=True)
                if found:
                    logger.info(f"Manim Success: Found {found[0]}")
                    return found[0]

            logger.error(f"Manim file not found for job {job_id}. Return code: {result.returncode}")
            return ""
        except Exception as e:
            logger.exception(f"Manim Execution Error: {e}")
            return ""
        finally:
            if os.path.exists(script_file):
                os.remove(script_file)
|
agents/solver_agent.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
import sympy as sp
|
| 4 |
+
from typing import Dict, Any, List
|
| 5 |
+
from app.llm_client import get_llm_client
|
| 6 |
+
|
| 7 |
+
logger = logging.getLogger(__name__)
|
| 8 |
+
|
| 9 |
+
class SolverAgent:
    """Produces a step-by-step textual solution for a geometry question.

    Consumes the parsed semantic data plus the geometry engine's computed
    coordinates and delegates the reasoning to the LLM, which is asked to
    reply with a strict JSON solution object.
    """

    def __init__(self):
        # Shared fallback-capable LLM client (see app.llm_client).
        self.llm = get_llm_client()

    @staticmethod
    def _parse_llm_json(raw: str) -> Dict[str, Any]:
        """Extract and parse a JSON object from an LLM reply.

        Strips markdown code fences when present; as a last resort grabs the
        widest ``{...}`` span. Raises json.JSONDecodeError when nothing parses.
        (Consolidates the two duplicated function-local ``import re`` blocks
        the original inlined in ``solve``.)
        """
        import re

        clean_raw = raw.strip()
        # Handle potential markdown code blocks if the response_format wasn't strictly honored
        if clean_raw.startswith("```"):
            fenced = re.search(r"```(?:json)?\s*(.*?)\s*```", clean_raw, re.DOTALL)
            if fenced:
                clean_raw = fenced.group(1).strip()
        try:
            return json.loads(clean_raw)
        except json.JSONDecodeError:
            # Last resort: try to find anything between { and }
            braced = re.search(r'(\{.*\})', clean_raw, re.DOTALL)
            if braced:
                return json.loads(braced.group(1))
            raise

    async def solve(self, semantic_data: Dict[str, Any], engine_result: Dict[str, Any]) -> Dict[str, Any]:
        """
        Solves the geometric problem based on coordinates and the target question.
        Returns a 'solution' dictionary with answer, steps, and symbolic_expression.
        """
        target_question = semantic_data.get("target_question")
        if not target_question:
            # If no question, just return an empty solution structure
            return {
                "answer": None,
                "steps": [],
                "symbolic_expression": None
            }

        logger.info(f"==[SolverAgent] Solving for: '{target_question}'==")

        input_text = semantic_data.get("input_text", "")
        coordinates = engine_result.get("coordinates", {})

        # We provide the coordinates and semantic context to the LLM to help it reason.
        # The LLM is tasked with generating the solution structure directly.
        system_prompt = """
You are a Geometry Solver Agent. Your goal is to provide a step-by-step solution for a specific geometric question.

=== DATA PROVIDED ===
1. Target Question: The specific question to answer.
2. Geometry Data: Entities and values extracted from the problem.
3. Coordinates: Calculated coordinates for all points.

=== REQUIREMENTS ===
- Provide the solution in the SAME LANGUAGE as the user's input.
- Use SymPy concepts if appropriate.
- Steps should be clear, concise, and logical.
- The final answer should be numerically or symbolically accurate based on the coordinates and geometric properties.
- For geometric proofs (e.g., "Is AB perpendicular to AC?"), explain the reasoning based on the data.

Output ONLY a JSON object with this structure:
{
    "answer": "Chuỗi văn bản kết quả cuối cùng (kèm đơn vị nếu có)",
    "steps": [
        "Bước 1: ...",
        "Bước 2: ...",
        ...
    ],
    "symbolic_expression": "Biểu thức toán học rút gọn (LaTeX format optional)"
}
"""

        user_content = f"""
INPUT_TEXT: {input_text}
TARGET_QUESTION: {target_question}
SEMANTIC_DATA: {json.dumps(semantic_data, ensure_ascii=False)}
COORDINATES: {json.dumps(coordinates)}
"""

        logger.debug("[SolverAgent] Requesting solution from LLM...")
        raw = None  # Fix: bound up-front so the error path never probes locals()
        try:
            raw = await self.llm.chat_completions_create(
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_content}
                ],
                response_format={"type": "json_object"}
            )

            solution = self._parse_llm_json(raw)

            logger.info("[SolverAgent] Solution generated successfully.")
            return solution
        except Exception as e:
            logger.error(f"[SolverAgent] Error generating solution: {e}")
            logger.debug(f"[SolverAgent] Raw LLM output was: \n{raw if raw is not None else 'N/A'}")
            # Fallback solution keeps the pipeline alive on LLM/parse failure.
            return {
                "answer": "Không thể tính toán lời giải tại thời điểm này.",
                "steps": ["Đã xảy ra lỗi trong quá trình xử lý lời giải."],
                "symbolic_expression": None
            }
|
agents/torch_ultralytics_compat.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""PyTorch 2.6+ defaults weights_only=True; Ultralytics YOLO .pt checkpoints unpickle full nn graphs (trusted official weights)."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import functools
|
| 6 |
+
|
| 7 |
+
# Guard so torch.load is wrapped at most once per process.
_torch_load_patched = False


def allow_ultralytics_weights() -> None:
    """
    Official yolov8n.pt is a trusted checkpoint. PyTorch 2.6+ safe unpickling would require
    allowlisting many torch.nn globals; loading with weights_only=False matches Ultralytics
    upstream behavior for local .pt files.
    """
    global _torch_load_patched
    if _torch_load_patched:
        return
    try:
        import torch

        _original_load = torch.load

        @functools.wraps(_original_load)
        def _permissive_load(*args, **kwargs):
            # Only supply the default when the caller did not choose explicitly.
            kwargs.setdefault("weights_only", False)
            return _original_load(*args, **kwargs)

        torch.load = _permissive_load
        _torch_load_patched = True
    except Exception:
        # torch missing or patching failed: silently keep the stock loader.
        pass
|
app/chat_image_upload.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Validate and upload chat/solve attachment images to Supabase Storage (image bucket)."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
import uuid
|
| 8 |
+
from typing import Any, Dict, Tuple
|
| 9 |
+
|
| 10 |
+
from fastapi import HTTPException
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _get_next_image_version(session_id: str) -> int:
    """Same logic as worker.asset_manager.get_next_version for asset_type image.

    Returns the highest stored image-asset version for this session plus one,
    or 1 when the session has no image assets yet or the lookup fails.
    """
    # Imported lazily so importing this module never requires Supabase config.
    from app.supabase_client import get_supabase

    supabase = get_supabase()
    try:
        # Fetch only the single highest version (order desc + limit 1).
        res = (
            supabase.table("session_assets")
            .select("version")
            .eq("session_id", session_id)
            .eq("asset_type", "image")
            .order("version", desc=True)
            .limit(1)
            .execute()
        )
        if res.data:
            return res.data[0]["version"] + 1
        return 1
    except Exception as e:
        # Best-effort: a failed lookup must not block the upload, default to 1.
        # NOTE(review): this can reuse version 1 if the query intermittently
        # fails for a session that already has assets — confirm acceptable.
        logger.error("Error fetching image version: %s", e)
        return 1
|
| 36 |
+
|
| 37 |
+
_MAX_BYTES_DEFAULT = 10 * 1024 * 1024
|
| 38 |
+
|
| 39 |
+
_EXT_TO_MIME: dict[str, str] = {
|
| 40 |
+
".png": "image/png",
|
| 41 |
+
".jpg": "image/jpeg",
|
| 42 |
+
".jpeg": "image/jpeg",
|
| 43 |
+
".webp": "image/webp",
|
| 44 |
+
".gif": "image/gif",
|
| 45 |
+
".bmp": "image/bmp",
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _max_bytes() -> int:
|
| 50 |
+
raw = os.getenv("CHAT_IMAGE_MAX_BYTES")
|
| 51 |
+
if raw and raw.isdigit():
|
| 52 |
+
return min(int(raw), 50 * 1024 * 1024)
|
| 53 |
+
return _MAX_BYTES_DEFAULT
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _magic_ok(ext: str, body: bytes) -> bool:
|
| 57 |
+
if len(body) < 12:
|
| 58 |
+
return False
|
| 59 |
+
if ext == ".png":
|
| 60 |
+
return body.startswith(b"\x89PNG\r\n\x1a\n")
|
| 61 |
+
if ext in (".jpg", ".jpeg"):
|
| 62 |
+
return body.startswith(b"\xff\xd8\xff")
|
| 63 |
+
if ext == ".webp":
|
| 64 |
+
return body.startswith(b"RIFF") and body[8:12] == b"WEBP"
|
| 65 |
+
if ext == ".gif":
|
| 66 |
+
return body.startswith(b"GIF87a") or body.startswith(b"GIF89a")
|
| 67 |
+
if ext == ".bmp":
|
| 68 |
+
return body.startswith(b"BM")
|
| 69 |
+
return False
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def validate_chat_image_bytes(
|
| 73 |
+
filename: str | None,
|
| 74 |
+
body: bytes,
|
| 75 |
+
declared_content_type: str | None,
|
| 76 |
+
) -> Tuple[str, str]:
|
| 77 |
+
"""
|
| 78 |
+
Validate size, extension, and magic bytes.
|
| 79 |
+
Returns (extension_with_dot, content_type).
|
| 80 |
+
"""
|
| 81 |
+
max_b = _max_bytes()
|
| 82 |
+
if not body:
|
| 83 |
+
raise HTTPException(status_code=400, detail="Empty file.")
|
| 84 |
+
if len(body) > max_b:
|
| 85 |
+
raise HTTPException(
|
| 86 |
+
status_code=413,
|
| 87 |
+
detail=f"Image too large (max {max_b // (1024 * 1024)} MB).",
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
ext = os.path.splitext(filename or "")[1].lower()
|
| 91 |
+
if not ext:
|
| 92 |
+
ext = ".png"
|
| 93 |
+
if ext not in _EXT_TO_MIME:
|
| 94 |
+
raise HTTPException(
|
| 95 |
+
status_code=400,
|
| 96 |
+
detail=f"Unsupported image type: {ext}. Allowed: {', '.join(sorted(_EXT_TO_MIME))}",
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
if not _magic_ok(ext, body):
|
| 100 |
+
raise HTTPException(
|
| 101 |
+
status_code=400,
|
| 102 |
+
detail="File content does not match declared image type.",
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
mime = _EXT_TO_MIME[ext]
|
| 106 |
+
if declared_content_type:
|
| 107 |
+
decl = declared_content_type.split(";")[0].strip().lower()
|
| 108 |
+
if decl and decl not in ("application/octet-stream", mime) and decl != mime:
|
| 109 |
+
logger.warning(
|
| 110 |
+
"Content-Type mismatch (declared=%s, inferred=%s); using inferred.",
|
| 111 |
+
declared_content_type,
|
| 112 |
+
mime,
|
| 113 |
+
)
|
| 114 |
+
return ext, mime
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def upload_session_chat_image(
    session_id: str,
    job_id: str,
    file_bytes: bytes,
    ext_with_dot: str,
    content_type: str,
) -> Dict[str, Any]:
    """
    Upload to SUPABASE_IMAGE_BUCKET (default: image), insert session_assets row.
    Returns dict with public_url, storage_path, version, session_asset_id (if returned).
    """
    # Lazy import keeps module import free of Supabase configuration.
    from app.supabase_client import get_supabase

    supabase = get_supabase()
    bucket_name = os.getenv("SUPABASE_IMAGE_BUCKET", "image")
    raw_ext = ext_with_dot.lstrip(".").lower()
    # Monotonic per-session version drives a unique, sortable object name.
    version = _get_next_image_version(session_id)
    file_name = f"image_v{version}_{job_id}.{raw_ext}"
    storage_path = f"sessions/{session_id}/{file_name}"

    supabase.storage.from_(bucket_name).upload(
        path=storage_path,
        file=file_bytes,
        file_options={"content-type": content_type},
    )
    # Older client versions return a dict instead of a plain URL string.
    public_url = supabase.storage.from_(bucket_name).get_public_url(storage_path)
    if isinstance(public_url, dict):
        public_url = public_url.get("publicUrl") or public_url.get("public_url") or str(public_url)

    row = {
        "session_id": session_id,
        "job_id": job_id,
        "asset_type": "image",
        "storage_path": storage_path,
        "public_url": public_url,
        "version": version,
    }
    # NOTE(review): chaining .select("id") after .insert() is not supported by
    # every supabase-py version — confirm against the pinned client release.
    ins = supabase.table("session_assets").insert(row).select("id").execute()
    asset_id = None
    if ins.data and len(ins.data) > 0:
        asset_id = ins.data[0].get("id")

    log_data = {
        "public_url": public_url,
        "storage_path": storage_path,
        "version": version,
        "session_asset_id": str(asset_id) if asset_id else None,
    }
    logger.info("Uploaded chat image: %s", log_data)
    return {
        "public_url": public_url,
        "storage_path": storage_path,
        "version": version,
        "session_asset_id": str(asset_id) if asset_id else None,
    }
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def upload_ephemeral_ocr_blob(
    file_bytes: bytes,
    ext_with_dot: str,
    content_type: str,
) -> Tuple[str, str]:
    """
    Upload bytes to image bucket under _ocr_temp/ for worker-only OCR (no session_assets row).
    Returns (storage_path, public_url). Caller must delete_storage_object when done.
    """
    # Lazy import keeps module import free of Supabase configuration.
    from app.supabase_client import get_supabase

    bucket_name = os.getenv("SUPABASE_IMAGE_BUCKET", "image")
    # Fall back to png when the caller passes an empty extension.
    raw_ext = ext_with_dot.lstrip(".").lower() or "png"
    # Random hex name avoids collisions between concurrent OCR jobs.
    name = f"_ocr_temp/{uuid.uuid4().hex}.{raw_ext}"
    supabase = get_supabase()
    supabase.storage.from_(bucket_name).upload(
        path=name,
        file=file_bytes,
        file_options={"content-type": content_type},
    )
    # Older client versions return a dict instead of a plain URL string.
    public_url = supabase.storage.from_(bucket_name).get_public_url(name)
    if isinstance(public_url, dict):
        public_url = public_url.get("publicUrl") or public_url.get("public_url") or str(public_url)
    return name, public_url
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def delete_storage_object(bucket_name: str, storage_path: str) -> None:
    """Best-effort removal of one object from a Supabase Storage bucket.

    Failures are logged at warning level and never raised to the caller.
    """
    try:
        from app.supabase_client import get_supabase

        client = get_supabase()
        client.storage.from_(bucket_name).remove([storage_path])
    except Exception as exc:
        logger.warning("delete_storage_object failed path=%s: %s", storage_path, exc)
|
app/dependencies.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import HTTPException, Header
|
| 2 |
+
|
| 3 |
+
from app.supabase_client import get_supabase, get_supabase_for_user_jwt
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
async def get_current_user_id(authorization: str | None = Header(None)):
    """
    Authenticate user using Supabase JWT.
    Expected Header: Authorization: Bearer <token>

    Returns the Supabase user id; raises HTTPException 401 on any failure.
    """
    import os

    if not authorization:
        raise HTTPException(
            status_code=401,
            detail="Authorization header missing or invalid. Use 'Bearer <token>'",
        )

    # Test-only escape hatch: "Test <user_id>" headers are accepted verbatim,
    # but only when ALLOW_TEST_BYPASS is explicitly enabled.
    if os.getenv("ALLOW_TEST_BYPASS") == "true" and authorization.startswith("Test "):
        return authorization.split(" ")[1]

    if not authorization.startswith("Bearer "):
        raise HTTPException(
            status_code=401,
            detail="Authorization header missing or invalid. Use 'Bearer <token>'",
        )

    token = authorization.split(" ")[1]
    supabase = get_supabase()

    try:
        # Validate the JWT against Supabase Auth.
        user_response = supabase.auth.get_user(token)
        if not user_response or not user_response.user:
            raise HTTPException(status_code=401, detail="Invalid session or token.")

        return user_response.user.id
    except HTTPException:
        # Re-raise our own 401s untouched instead of wrapping them below.
        raise
    except Exception as e:
        raise HTTPException(status_code=401, detail=f"Authentication failed: {str(e)}")
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
async def get_authenticated_supabase(authorization: str = Header(...)):
    """
    Supabase client that carries the user's JWT (anon key + Authorization header).
    Use for routes that should respect Row Level Security; pair with app logic as needed.

    Raises HTTPException 401 when the token is missing/invalid and 503 when a
    user-scoped client cannot be constructed.
    """
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(
            status_code=401,
            detail="Authorization header missing or invalid. Use 'Bearer <token>'",
        )

    token = authorization.split(" ")[1]
    supabase = get_supabase()

    try:
        # Verify the token before handing out a client bound to it.
        user_response = supabase.auth.get_user(token)
        if not user_response or not user_response.user:
            raise HTTPException(status_code=401, detail="Invalid session or token.")
    except HTTPException:
        # Keep our own 401s intact rather than wrapping them below.
        raise
    except Exception as e:
        raise HTTPException(status_code=401, detail=f"Authentication failed: {str(e)}")

    try:
        return get_supabase_for_user_jwt(token)
    except RuntimeError as e:
        # Missing anon key / misconfiguration surfaces as service-unavailable.
        raise HTTPException(status_code=503, detail=str(e))
|
app/errors.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Map exceptions to short, user-visible messages (avoid leaking HTML bodies from 404 proxies)."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
|
| 7 |
+
logger = logging.getLogger(__name__)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _looks_like_html(text: str) -> bool:
|
| 11 |
+
t = text.lstrip()[:500].lower()
|
| 12 |
+
return t.startswith("<!doctype") or t.startswith("<html") or "<html" in t[:200]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def format_error_for_user(exc: BaseException) -> str:
|
| 16 |
+
"""
|
| 17 |
+
Produce a safe message for chat/UI. Full detail stays in server logs via logger.exception.
|
| 18 |
+
"""
|
| 19 |
+
# httpx: wrong URL often returns 404 HTML; don't show body
|
| 20 |
+
try:
|
| 21 |
+
import httpx
|
| 22 |
+
|
| 23 |
+
if isinstance(exc, httpx.HTTPStatusError):
|
| 24 |
+
req = exc.request
|
| 25 |
+
code = exc.response.status_code
|
| 26 |
+
url_hint = ""
|
| 27 |
+
try:
|
| 28 |
+
url_hint = str(req.url.host) if req and req.url else ""
|
| 29 |
+
except Exception:
|
| 30 |
+
pass
|
| 31 |
+
logger.warning(
|
| 32 |
+
"HTTPStatusError %s for %s (response not shown to user)",
|
| 33 |
+
code,
|
| 34 |
+
url_hint or "?",
|
| 35 |
+
)
|
| 36 |
+
return (
|
| 37 |
+
"Kiểm tra URL API, khóa bí mật và biến môi trường (OpenRouter/Supabase/Redis)."
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
if isinstance(exc, httpx.RequestError):
|
| 41 |
+
return "Không kết nối được tới dịch vụ ngoài (mạng hoặc URL sai)."
|
| 42 |
+
except ImportError:
|
| 43 |
+
pass
|
| 44 |
+
|
| 45 |
+
raw = str(exc).strip()
|
| 46 |
+
if not raw:
|
| 47 |
+
return "Đã xảy ra lỗi không xác định."
|
| 48 |
+
|
| 49 |
+
if _looks_like_html(raw):
|
| 50 |
+
logger.warning("Suppressed HTML error body from user-facing message")
|
| 51 |
+
return (
|
| 52 |
+
"Dịch vụ trả về trang lỗi (thường là URL API sai hoặc endpoint không tồn tại — HTTP 404). "
|
| 53 |
+
"Kiểm tra OPENROUTER_MODEL và khóa API trên server."
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
if len(raw) > 800:
|
| 57 |
+
return raw[:800] + "…"
|
| 58 |
+
|
| 59 |
+
return raw
|
app/llm_client.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import asyncio
|
| 4 |
+
import logging
|
| 5 |
+
from openai import AsyncOpenAI
|
| 6 |
+
from typing import List, Dict, Any, Optional
|
| 7 |
+
from app.url_utils import openai_compatible_api_key, sanitize_env
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
class MultiLayerLLMClient:
    """OpenRouter chat client with a fixed model-fallback sequence.

    Models come from OPENROUTER_MODEL_1..3 (or legacy OPENROUTER_MODEL); the
    API key from OPENROUTER_API_KEY_1 (or OPENROUTER_API_KEY). When no key is
    configured, ``client`` is None and calls raise ValueError.
    """

    def __init__(self):
        # 1. Models sequence loading
        self.models = []
        for i in range(1, 4):
            model = os.getenv(f"OPENROUTER_MODEL_{i}")
            if model:
                self.models.append(model)

        # Fallback to legacy OPENROUTER_MODEL if no numbered models found
        if not self.models:
            legacy_model = os.getenv("OPENROUTER_MODEL", "google/gemini-2.0-flash-001")
            self.models = [legacy_model]

        # 2. Key selection (No rotation, always use the first available key)
        api_key = os.getenv("OPENROUTER_API_KEY_1") or os.getenv("OPENROUTER_API_KEY")

        if not api_key:
            logger.error("[LLM] No OpenRouter API key found.")
            self.client = None
        else:
            self.client = AsyncOpenAI(
                api_key=openai_compatible_api_key(api_key),
                base_url="https://openrouter.ai/api/v1",
                timeout=60.0,
                # Attribution headers recommended by OpenRouter.
                default_headers={
                    "HTTP-Referer": "https://mathsolver.ai",
                    "X-Title": "MathSolver Backend",
                }
            )

    async def chat_completions_create(
        self,
        messages: List[Dict[str, str]],
        response_format: Optional[Dict[str, str]] = None,
        **kwargs
    ) -> str:
        """
        Implements Model Fallback Sequence: Model 1 -> Model 2 -> Model 3.
        Always starts from Model 1 for every new call.

        Returns the first non-empty message content; raises the last error
        once every configured model has failed.
        """
        if not self.client:
            raise ValueError("No API client configured. Check your API keys.")

        MAX_ATTEMPTS = len(self.models)
        RETRY_DELAY = 1.0  # second

        for attempt_idx in range(MAX_ATTEMPTS):
            current_model = self.models[attempt_idx]
            attempt_num = attempt_idx + 1

            try:
                logger.info(f"[LLM] Attempt {attempt_num}/{MAX_ATTEMPTS} using Model: {current_model}...")

                # NOTE(review): response_format=None is forwarded verbatim;
                # some providers may prefer the argument omitted — confirm
                # against the pinned openai SDK behavior.
                response = await self.client.chat.completions.create(
                    model=current_model,
                    messages=messages,
                    response_format=response_format,
                    **kwargs
                )

                # Malformed/empty responses are treated like failures so the
                # next model in the sequence gets a chance.
                if not response or not getattr(response, "choices", None):
                    raise ValueError(f"Invalid response structure from model {current_model}")

                content = response.choices[0].message.content
                if content:
                    logger.info(f"[LLM] SUCCESS on attempt {attempt_num} ({current_model}).")
                    return content

                raise ValueError(f"Empty content from model {current_model}")

            except Exception as e:
                err_msg = f"{type(e).__name__}: {str(e)}"
                logger.warning(f"[LLM] FAILED on attempt {attempt_num} ({current_model}): {err_msg}")

                if attempt_num < MAX_ATTEMPTS:
                    logger.info(f"[LLM] Retrying next model in {RETRY_DELAY}s...")
                    await asyncio.sleep(RETRY_DELAY)
                else:
                    # Out of models: surface the last failure to the caller.
                    logger.error(f"[LLM] FINAL FAILURE after {attempt_num} models.")
                    raise e
| 92 |
+
|
| 93 |
+
# Lazily created process-wide client (singleton-ish).
_llm_client = None


def get_llm_client() -> MultiLayerLLMClient:
    """Return the shared MultiLayerLLMClient, constructing it on first use."""
    global _llm_client
    if _llm_client is None:
        _llm_client = MultiLayerLLMClient()
    return _llm_client
|
app/logging_setup.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Logging theo một biến LOG_LEVEL: debug | info | warning | error."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
from typing import Final
|
| 8 |
+
|
| 9 |
+
_SETUP_DONE = False
|
| 10 |
+
|
| 11 |
+
PIPELINE_LOGGER_NAME: Final = "app.pipeline"
|
| 12 |
+
CACHE_LOGGER_NAME: Final = "app.cache"
|
| 13 |
+
STEPS_LOGGER_NAME: Final = "app.steps"
|
| 14 |
+
ACCESS_LOGGER_NAME: Final = "app.access"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _normalize_level() -> str:
|
| 18 |
+
raw = os.getenv("LOG_LEVEL", "info").strip().lower()
|
| 19 |
+
if raw in ("debug", "info", "warning", "error"):
|
| 20 |
+
return raw
|
| 21 |
+
return "info"
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def setup_application_logging() -> None:
|
| 25 |
+
"""Idempotent; gọi khi khởi động process (uvicorn, celery, worker_health)."""
|
| 26 |
+
global _SETUP_DONE
|
| 27 |
+
if _SETUP_DONE:
|
| 28 |
+
return
|
| 29 |
+
_SETUP_DONE = True
|
| 30 |
+
|
| 31 |
+
mode = _normalize_level()
|
| 32 |
+
|
| 33 |
+
level_map = {
|
| 34 |
+
"debug": logging.DEBUG,
|
| 35 |
+
"info": logging.INFO,
|
| 36 |
+
"warning": logging.WARNING,
|
| 37 |
+
"error": logging.ERROR,
|
| 38 |
+
}
|
| 39 |
+
root_level = level_map[mode]
|
| 40 |
+
|
| 41 |
+
fmt_named = "%(asctime)s | %(levelname)-8s | %(name)s | %(message)s"
|
| 42 |
+
fmt_short = "%(asctime)s | %(levelname)-8s | %(message)s"
|
| 43 |
+
|
| 44 |
+
logging.basicConfig(
|
| 45 |
+
level=root_level,
|
| 46 |
+
format=fmt_named if mode == "debug" else fmt_short,
|
| 47 |
+
datefmt="%H:%M:%S",
|
| 48 |
+
force=True,
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
logging.getLogger("httpx").setLevel(logging.WARNING)
|
| 52 |
+
logging.getLogger("httpcore").setLevel(logging.WARNING)
|
| 53 |
+
logging.getLogger("openai").setLevel(logging.WARNING)
|
| 54 |
+
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
|
| 55 |
+
logging.getLogger("uvicorn.error").setLevel(logging.INFO)
|
| 56 |
+
# HTTP/2 stack (httpx/httpcore) — khi LOG_LEVEL=debug root=DEBUG sẽ tràn log hpack; không cần cho debug app
|
| 57 |
+
for _name in ("hpack", "h2", "hyperframe", "urllib3"):
|
| 58 |
+
logging.getLogger(_name).setLevel(logging.WARNING)
|
| 59 |
+
|
| 60 |
+
if mode == "debug":
|
| 61 |
+
logging.getLogger("agents").setLevel(logging.DEBUG)
|
| 62 |
+
logging.getLogger("solver").setLevel(logging.DEBUG)
|
| 63 |
+
logging.getLogger("app").setLevel(logging.DEBUG)
|
| 64 |
+
logging.getLogger(CACHE_LOGGER_NAME).setLevel(logging.DEBUG)
|
| 65 |
+
logging.getLogger(STEPS_LOGGER_NAME).setLevel(logging.DEBUG)
|
| 66 |
+
logging.getLogger(PIPELINE_LOGGER_NAME).setLevel(logging.INFO)
|
| 67 |
+
logging.getLogger(ACCESS_LOGGER_NAME).setLevel(logging.INFO)
|
| 68 |
+
logging.getLogger("app.main").setLevel(logging.INFO)
|
| 69 |
+
logging.getLogger("worker").setLevel(logging.INFO)
|
| 70 |
+
elif mode == "info":
|
| 71 |
+
# Chỉ HTTP access (app.access) + startup; ẩn chi tiết agents/orchestrator/pipeline SUCCESS
|
| 72 |
+
logging.getLogger("agents").setLevel(logging.INFO)
|
| 73 |
+
logging.getLogger("solver").setLevel(logging.WARNING)
|
| 74 |
+
logging.getLogger("app").setLevel(logging.INFO)
|
| 75 |
+
logging.getLogger(CACHE_LOGGER_NAME).setLevel(logging.WARNING)
|
| 76 |
+
logging.getLogger(STEPS_LOGGER_NAME).setLevel(logging.WARNING)
|
| 77 |
+
logging.getLogger(PIPELINE_LOGGER_NAME).setLevel(logging.WARNING)
|
| 78 |
+
logging.getLogger(ACCESS_LOGGER_NAME).setLevel(logging.INFO)
|
| 79 |
+
logging.getLogger("app.main").setLevel(logging.INFO)
|
| 80 |
+
logging.getLogger("worker").setLevel(logging.WARNING)
|
| 81 |
+
elif mode == "warning":
|
| 82 |
+
logging.getLogger("agents").setLevel(logging.WARNING)
|
| 83 |
+
logging.getLogger("solver").setLevel(logging.WARNING)
|
| 84 |
+
logging.getLogger("app.routers").setLevel(logging.WARNING)
|
| 85 |
+
logging.getLogger(CACHE_LOGGER_NAME).setLevel(logging.WARNING)
|
| 86 |
+
logging.getLogger(STEPS_LOGGER_NAME).setLevel(logging.WARNING)
|
| 87 |
+
logging.getLogger(PIPELINE_LOGGER_NAME).setLevel(logging.WARNING)
|
| 88 |
+
logging.getLogger(ACCESS_LOGGER_NAME).setLevel(logging.WARNING)
|
| 89 |
+
logging.getLogger("app.main").setLevel(logging.WARNING)
|
| 90 |
+
logging.getLogger("worker").setLevel(logging.WARNING)
|
| 91 |
+
else: # error
|
| 92 |
+
logging.getLogger("agents").setLevel(logging.ERROR)
|
| 93 |
+
logging.getLogger("solver").setLevel(logging.ERROR)
|
| 94 |
+
logging.getLogger("app.routers").setLevel(logging.ERROR)
|
| 95 |
+
logging.getLogger(CACHE_LOGGER_NAME).setLevel(logging.ERROR)
|
| 96 |
+
logging.getLogger(STEPS_LOGGER_NAME).setLevel(logging.ERROR)
|
| 97 |
+
logging.getLogger(PIPELINE_LOGGER_NAME).setLevel(logging.ERROR)
|
| 98 |
+
logging.getLogger(ACCESS_LOGGER_NAME).setLevel(logging.ERROR)
|
| 99 |
+
logging.getLogger("app.main").setLevel(logging.ERROR)
|
| 100 |
+
logging.getLogger("worker").setLevel(logging.ERROR)
|
| 101 |
+
|
| 102 |
+
logging.getLogger(__name__).debug(
|
| 103 |
+
"LOG_LEVEL=%s root=%s", mode, logging.getLevelName(root_level)
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def get_log_level() -> str:
|
| 108 |
+
return _normalize_level()
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def is_debug_level() -> bool:
|
| 112 |
+
return _normalize_level() == "debug"
|
app/logutil.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""log_step (debug), pipeline (debug), access log ở middleware."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
from typing import Any
|
| 9 |
+
|
| 10 |
+
from app.logging_setup import PIPELINE_LOGGER_NAME, STEPS_LOGGER_NAME
|
| 11 |
+
|
| 12 |
+
_pipeline = logging.getLogger(PIPELINE_LOGGER_NAME)
|
| 13 |
+
_steps = logging.getLogger(STEPS_LOGGER_NAME)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def is_debug_mode() -> bool:
    """Per-step detail logging is enabled only when LOG_LEVEL=debug."""
    configured = os.getenv("LOG_LEVEL", "info")
    return configured.strip().lower() == "debug"
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _truncate(val: Any, max_len: int = 2000) -> Any:
|
| 22 |
+
if val is None:
|
| 23 |
+
return None
|
| 24 |
+
if isinstance(val, (int, float, bool)):
|
| 25 |
+
return val
|
| 26 |
+
s = str(val)
|
| 27 |
+
if len(s) > max_len:
|
| 28 |
+
return s[:max_len] + f"... (+{len(s) - max_len} chars)"
|
| 29 |
+
return s
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def log_step(step: str, **fields: Any) -> None:
    """Emit a structured step record (DB / cache / orchestrator) only when LOG_LEVEL=debug."""
    if not is_debug_mode():
        return
    trimmed = {name: _truncate(value) for name, value in fields.items()}
    try:
        encoded = json.dumps(trimmed, ensure_ascii=False, default=str)
    except Exception:
        # Fall back to repr-style output if the payload still defeats JSON encoding.
        encoded = str(trimmed)
    _steps.debug("[step:%s] %s", step, encoded)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def log_pipeline_success(operation: str, **fields: Any) -> None:
    """Debug-only pipeline SUCCESS record (not used at info level — app.access covers that)."""
    if not is_debug_mode():
        return
    compact = {key: _truncate(val, 500) for key, val in fields.items()}
    _pipeline.info(
        "SUCCESS %s %s",
        operation,
        json.dumps(compact, ensure_ascii=False, default=str),
    )
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def log_pipeline_failure(operation: str, error: str | None = None, **fields: Any) -> None:
    """Pipeline failure: always WARNING so it stays visible at LOG_LEVEL=warning."""
    if not is_debug_mode():
        # Terse form outside debug mode: operation name only.
        _pipeline.warning("FAIL %s", operation)
        return
    details = {key: _truncate(val, 500) for key, val in fields.items()}
    _pipeline.warning(
        "FAIL %s err=%s %s",
        operation,
        _truncate(error, 300),
        json.dumps(details, ensure_ascii=False, default=str),
    )
|
app/main.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import time
|
| 6 |
+
import uuid
|
| 7 |
+
import warnings
|
| 8 |
+
|
| 9 |
+
from dotenv import load_dotenv
|
| 10 |
+
from fastapi import Depends, FastAPI, File, HTTPException, UploadFile
|
| 11 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 12 |
+
from starlette.requests import Request
|
| 13 |
+
|
| 14 |
+
load_dotenv()
|
| 15 |
+
|
| 16 |
+
from app.runtime_env import apply_runtime_env_defaults
|
| 17 |
+
|
| 18 |
+
apply_runtime_env_defaults()
|
| 19 |
+
|
| 20 |
+
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"
|
| 21 |
+
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")
|
| 22 |
+
warnings.filterwarnings("ignore", category=UserWarning, module="albumentations")
|
| 23 |
+
|
| 24 |
+
from app.logging_setup import ACCESS_LOGGER_NAME, get_log_level, setup_application_logging
|
| 25 |
+
|
| 26 |
+
setup_application_logging()
|
| 27 |
+
|
| 28 |
+
# Routers (after logging)
|
| 29 |
+
from app.dependencies import get_current_user_id
|
| 30 |
+
from app.ocr_local_file import ocr_from_local_image_path
|
| 31 |
+
from app.routers import auth, sessions, solve
|
| 32 |
+
from agents.ocr_agent import OCRAgent
|
| 33 |
+
from app.routers.solve import get_orchestrator
|
| 34 |
+
from app.supabase_client import get_supabase
|
| 35 |
+
from app.websocket_manager import register_websocket_routes
|
| 36 |
+
|
| 37 |
+
logger = logging.getLogger("app.main")
|
| 38 |
+
_access = logging.getLogger(ACCESS_LOGGER_NAME)
|
| 39 |
+
|
| 40 |
+
app = FastAPI(title="Visual Math Solver API v5.1")
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@app.middleware("http")
async def access_log_middleware(request: Request, call_next):
    """LOG_LEVEL=info/debug: every request; warning: only 4xx/5xx; error: only 4xx/5xx at error level."""
    started = time.perf_counter()
    response = await call_next(request)
    elapsed_ms = (time.perf_counter() - started) * 1000
    mode = get_log_level()
    method = request.method
    path = request.url.path
    status = response.status_code

    if mode in ("debug", "info"):
        _access.info("%s %s -> %s (%.0fms)", method, path, status, elapsed_ms)
    elif mode == "warning":
        if status >= 500:
            _access.error("%s %s -> %s (%.0fms)", method, path, status, elapsed_ms)
        elif status >= 400:
            _access.warning("%s %s -> %s (%.0fms)", method, path, status, elapsed_ms)
    elif mode == "error" and status >= 400:
        # Duration is deliberately omitted at the terse "error" level.
        _access.error("%s %s -> %s", method, path, status)

    return response
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
from worker.celery_app import BROKER_URL
|
| 69 |
+
|
| 70 |
+
_broker_tail = BROKER_URL.split("@")[-1] if "@" in BROKER_URL else BROKER_URL
|
| 71 |
+
if get_log_level() in ("debug", "info"):
|
| 72 |
+
logger.info("App starting LOG_LEVEL=%s | Redis: %s", get_log_level(), _broker_tail)
|
| 73 |
+
else:
|
| 74 |
+
logger.warning(
|
| 75 |
+
"App starting LOG_LEVEL=%s | Redis: %s", get_log_level(), _broker_tail
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
app.add_middleware(
|
| 79 |
+
CORSMiddleware,
|
| 80 |
+
allow_origins=[
|
| 81 |
+
"http://localhost:3000",
|
| 82 |
+
"http://127.0.0.1:3000",
|
| 83 |
+
"http://localhost:3005",
|
| 84 |
+
],
|
| 85 |
+
allow_credentials=True,
|
| 86 |
+
allow_methods=["*"],
|
| 87 |
+
allow_headers=["*"],
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
app.include_router(auth.router)
|
| 91 |
+
app.include_router(sessions.router)
|
| 92 |
+
app.include_router(solve.router)
|
| 93 |
+
|
| 94 |
+
register_websocket_routes(app)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def get_ocr_agent() -> OCRAgent:
    """Reuse the solve pipeline's OCR instance so the model is loaded only once."""
    orchestrator = get_orchestrator()
    return orchestrator.ocr_agent
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
supabase_client = get_supabase()
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
@app.get("/")
def read_root():
    """Liveness/info endpoint for the API root."""
    payload = {"message": "Visual Math Solver API v5.1 is running", "version": "5.1"}
    return payload
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
@app.post("/api/v1/ocr")
async def upload_ocr(
    file: UploadFile = File(...),
    _user_id=Depends(get_current_user_id),
):
    """OCR upload endpoint: requires an authenticated user."""
    # Spool the upload to a uniquely named local file for the OCR helper.
    temp_path = f"temp_{uuid.uuid4()}.png"
    payload = await file.read()
    with open(temp_path, "wb") as out:
        out.write(payload)

    try:
        text = await ocr_from_local_image_path(temp_path, file.filename, get_ocr_agent())
        return {"text": text}
    finally:
        # Best-effort cleanup of the spooled file, even on OCR failure.
        if os.path.exists(temp_path):
            os.remove(temp_path)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
@app.get("/api/v1/solve/{job_id}")
async def get_job_status(
    job_id: str,
    user_id=Depends(get_current_user_id),
):
    """Fetch a job row (polling fallback when the WebSocket fails). Owner-only."""
    rows = supabase_client.table("jobs").select("*").eq("id", job_id).execute()
    if not rows.data:
        raise HTTPException(status_code=404, detail="Job not found")
    job = rows.data[0]
    owner = job.get("user_id")
    # Jobs without an owner are readable; owned jobs only by their owner.
    if owner is not None and str(owner) != str(user_id):
        raise HTTPException(status_code=403, detail="Forbidden: You do not own this job.")
    return job
|
app/models/schemas.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel, EmailStr, field_validator
from typing import Optional, List, Any, Dict
from datetime import datetime
import uuid

from app.url_utils import sanitize_url

# --- Auth Schemas ---
class UserProfile(BaseModel):
    """Public profile row (mirrors the `profiles` table)."""
    id: uuid.UUID
    display_name: Optional[str] = None
    avatar_url: Optional[str] = None
    created_at: datetime

class User(BaseModel):
    """Minimal authenticated-user identity."""
    id: uuid.UUID
    email: EmailStr

# --- Session Schemas ---
class SessionBase(BaseModel):
    """Shared session fields; default title is Vietnamese for "New problem"."""
    title: str = "Bài toán mới"

class SessionCreate(SessionBase):
    """Payload for creating a session (no extra fields beyond the base)."""
    pass

class Session(SessionBase):
    """Full session row returned to clients."""
    id: uuid.UUID
    user_id: uuid.UUID
    created_at: datetime
    updated_at: datetime

    class Config:
        # Allow construction from attribute-bearing (ORM-style) objects.
        from_attributes = True

# --- Message Schemas ---
class MessageBase(BaseModel):
    """Shared chat-message fields (role/type/content plus free-form metadata)."""
    role: str
    type: str = "text"
    content: str
    metadata: Dict[str, Any] = {}

class MessageCreate(MessageBase):
    """Payload for inserting a message into a session."""
    session_id: uuid.UUID

class Message(MessageBase):
    """Full message row returned to clients."""
    id: uuid.UUID
    session_id: uuid.UUID
    created_at: datetime

    class Config:
        # Allow construction from attribute-bearing (ORM-style) objects.
        from_attributes = True

# --- Solve Job Schemas ---
class SolveRequest(BaseModel):
    """Input for starting a solve job; image_url is optional and sanitized."""
    text: str
    image_url: Optional[str] = None

    @field_validator("image_url", mode="before")
    @classmethod
    def _clean_image_url(cls, v):
        # Normalize the URL before validation; None passes through unchanged.
        return sanitize_url(v) if v is not None else None

class SolveResponse(BaseModel):
    """Accepted solve-job handle returned to the client."""
    job_id: str
    status: str

class RenderVideoRequest(BaseModel):
    """Optional job reference for (re-)rendering a video."""
    job_id: Optional[str] = None

class RenderVideoResponse(BaseModel):
    """Accepted render-job handle returned to the client."""
    job_id: str
    status: str


class OcrPreviewResponse(BaseModel):
    """Stateless OCR preview before POST .../solve (no DB writes, no job)."""

    ocr_text: str
    user_message: str = ""
    combined_draft: str
|
app/ocr_celery.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Run OCR on a remote worker via Celery (queue `ocr`) when OCR_USE_CELERY is enabled."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
from typing import TYPE_CHECKING
|
| 8 |
+
|
| 9 |
+
import anyio
|
| 10 |
+
|
| 11 |
+
if TYPE_CHECKING:
|
| 12 |
+
from agents.ocr_agent import OCRAgent
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def ocr_celery_enabled() -> bool:
    """True when the OCR_USE_CELERY env flag is set to a truthy value."""
    flag = os.getenv("OCR_USE_CELERY", "")
    return flag.strip().lower() in ("1", "true", "yes", "on")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _ocr_timeout_sec() -> float:
|
| 22 |
+
raw = os.getenv("OCR_CELERY_TIMEOUT_SEC", "180")
|
| 23 |
+
try:
|
| 24 |
+
return max(30.0, float(raw))
|
| 25 |
+
except ValueError:
|
| 26 |
+
return 180.0
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _run_ocr_celery_sync(image_url: str) -> str:
    """Blocking call: enqueue the `run_ocr_from_url` Celery task and wait for its result.

    The worker package is imported lazily so it is only required when Celery OCR
    is actually enabled.
    """
    from worker.ocr_tasks import run_ocr_from_url

    async_result = run_ocr_from_url.apply_async(args=[image_url])
    # Blocks the calling thread up to OCR_CELERY_TIMEOUT_SEC (run via to_thread by the caller).
    return async_result.get(timeout=_ocr_timeout_sec())
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _is_ocr_error_response(text: str) -> bool:
|
| 37 |
+
s = (text or "").lstrip()
|
| 38 |
+
return s.startswith("Error:")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
async def ocr_from_image_url(image_url: str, fallback_agent: "OCRAgent") -> str:
    """
    With OCR_USE_CELERY set: run raw OCR on the Celery worker (queue `ocr`), then
    apply ``refine_with_llm`` in this API process.
    Otherwise: run the full in-process pipeline via ``fallback_agent.process_url``.
    """
    if not ocr_celery_enabled():
        return await fallback_agent.process_url(image_url)

    logger.info("OCR_USE_CELERY: delegating OCR to Celery queue=ocr (LLM refine on API)")
    raw = await anyio.to_thread.run_sync(_run_ocr_celery_sync, image_url)
    if raw is None:
        raw = ""
    # Empty or error responses are passed through untouched (no LLM refinement).
    if not raw.strip() or _is_ocr_error_response(raw):
        return raw
    return await fallback_agent.refine_with_llm(raw)
|
app/ocr_local_file.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""OCR from a local file path, optionally via Celery worker (upload temp blob first)."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
from typing import TYPE_CHECKING
|
| 8 |
+
|
| 9 |
+
from app.chat_image_upload import (
|
| 10 |
+
delete_storage_object,
|
| 11 |
+
upload_ephemeral_ocr_blob,
|
| 12 |
+
validate_chat_image_bytes,
|
| 13 |
+
)
|
| 14 |
+
from app.ocr_celery import ocr_celery_enabled, ocr_from_image_url
|
| 15 |
+
|
| 16 |
+
if TYPE_CHECKING:
|
| 17 |
+
from agents.ocr_agent import OCRAgent
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
async def ocr_from_local_image_path(
    local_path: str,
    original_filename: str | None,
    fallback_agent: "OCRAgent",
) -> str:
    """
    Run OCR on a file on local disk. If OCR_USE_CELERY is enabled, upload the bytes
    to an ephemeral storage URL and delegate to the worker; otherwise run
    ``process_image`` in-process.
    """
    if not ocr_celery_enabled():
        return await fallback_agent.process_image(local_path)

    with open(local_path, "rb") as f:
        body = f.read()
    # Prefer the client-supplied filename for the extension; default to PNG.
    ext = os.path.splitext(original_filename or local_path)[1].lower() or ".png"
    _, content_type = validate_chat_image_bytes(original_filename or local_path, body, None)
    bucket = os.getenv("SUPABASE_IMAGE_BUCKET", "image")
    path, url = upload_ephemeral_ocr_blob(body, ext, content_type)
    try:
        return await ocr_from_image_url(url, fallback_agent)
    finally:
        # Always remove the ephemeral blob, even when OCR fails.
        delete_storage_object(bucket, path)
|
app/ocr_text_merge.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Helpers for OCR preview combined draft (no Pydantic email deps)."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def build_combined_ocr_preview_draft(user_message: Optional[str], ocr_text: str) -> str:
    """Merge the user's caption with the OCR text (caption first) for the confirm step."""
    caption = (user_message or "").strip()
    extracted = (ocr_text or "").strip()
    if caption and extracted:
        return f"{caption}\n\n{extracted}"
    return caption or extracted
|
app/routers/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from . import auth, sessions, solve
|
app/routers/auth.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, Depends, HTTPException
|
| 2 |
+
from app.dependencies import get_current_user_id
|
| 3 |
+
from app.supabase_client import get_supabase
|
| 4 |
+
from app.models.schemas import UserProfile
|
| 5 |
+
import uuid
|
| 6 |
+
|
| 7 |
+
router = APIRouter(prefix="/api/v1/auth", tags=["Auth"])
|
| 8 |
+
|
| 9 |
+
@router.get("/me")
async def get_me(user_id=Depends(get_current_user_id)):
    """Retrieve the current authenticated user's profile row."""
    client = get_supabase()
    rows = client.table("profiles").select("*").eq("id", user_id).execute().data
    if not rows:
        raise HTTPException(status_code=404, detail="Profile not found.")
    return rows[0]
|
| 17 |
+
|
| 18 |
+
@router.patch("/me")
async def update_me(data: dict, user_id=Depends(get_current_user_id)):
    """Update the current user's profile and return the updated row.

    Raises 404 when no profile row matched the current user. Previously an
    empty result made `res.data[0]` raise an unhandled IndexError (HTTP 500);
    this mirrors GET /me, which returns 404 for a missing profile.
    """
    supabase = get_supabase()
    res = supabase.table("profiles").update(data).eq("id", user_id).execute()
    if not res.data:
        raise HTTPException(status_code=404, detail="Profile not found.")
    return res.data[0]
|
app/routers/sessions.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import List
|
| 4 |
+
|
| 5 |
+
from fastapi import APIRouter, Depends, HTTPException
|
| 6 |
+
|
| 7 |
+
from app.dependencies import get_current_user_id
|
| 8 |
+
from app.logutil import log_step
|
| 9 |
+
from app.session_cache import (
|
| 10 |
+
get_sessions_list_cached,
|
| 11 |
+
invalidate_for_user,
|
| 12 |
+
invalidate_session_owner,
|
| 13 |
+
session_owned_by_user,
|
| 14 |
+
)
|
| 15 |
+
from app.supabase_client import get_supabase
|
| 16 |
+
|
| 17 |
+
router = APIRouter(prefix="/api/v1/sessions", tags=["Sessions"])
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@router.get("", response_model=List[dict])
async def list_sessions(user_id=Depends(get_current_user_id)):
    """List the current user's chat sessions, most recently updated first (cached)."""
    supabase = get_supabase()

    def fetch() -> list:
        # Cache-miss path: read this user's sessions ordered by updated_at desc.
        res = (
            supabase.table("sessions")
            .select("*")
            .eq("user_id", user_id)
            .order("updated_at", desc=True)
            .execute()
        )
        log_step("db_select", table="sessions", op="list", user_id=str(user_id))
        return res.data

    # Per-user cache wrapper; `fetch` runs only on a miss.
    return get_sessions_list_cached(str(user_id), fetch)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@router.post("", response_model=dict)
async def create_session(user_id=Depends(get_current_user_id)):
    """Create a new chat session with the default title."""
    client = get_supabase()
    row = {"user_id": user_id, "title": "Bài toán mới"}
    inserted = client.table("sessions").insert(row).execute()
    log_step("db_insert", table="sessions", op="create")
    # The cached per-user session list is now stale.
    invalidate_for_user(str(user_id))
    return inserted.data[0]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@router.get("/{session_id}/messages", response_model=List[dict])
async def get_session_messages(session_id: str, user_id=Depends(get_current_user_id)):
    """Return the full message history for a session (owner-only, ascending by time)."""
    supabase = get_supabase()

    def owns() -> bool:
        # Ownership probe, executed only when the ownership cache misses.
        res = (
            supabase.table("sessions")
            .select("id")
            .eq("id", session_id)
            .eq("user_id", user_id)
            .execute()
        )
        log_step("db_select", table="sessions", op="owner_check", session_id=session_id)
        return bool(res.data)

    if not session_owned_by_user(session_id, str(user_id), owns):
        raise HTTPException(
            status_code=403, detail="Forbidden: You do not own this session."
        )

    res = (
        supabase.table("messages")
        .select("*")
        .eq("session_id", session_id)
        .order("created_at", desc=False)
        .execute()
    )
    log_step("db_select", table="messages", op="list", session_id=session_id)
    return res.data
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@router.delete("/{session_id}")
async def delete_session(session_id: str, user_id=Depends(get_current_user_id)):
    """Delete a chat session along with its dependent jobs and messages (owner-only)."""
    supabase = get_supabase()

    def owns() -> bool:
        # Ownership probe, executed only when the ownership cache misses.
        res = (
            supabase.table("sessions")
            .select("id")
            .eq("id", session_id)
            .eq("user_id", user_id)
            .execute()
        )
        return bool(res.data)

    if not session_owned_by_user(session_id, str(user_id), owns):
        raise HTTPException(
            status_code=403, detail="Forbidden: You do not own this session."
        )

    # jobs.session_id FK must be cleared before sessions row
    supabase.table("jobs").delete().eq("session_id", session_id).eq("user_id", user_id).execute()
    log_step("db_delete", table="jobs", op="by_session", session_id=session_id)
    supabase.table("messages").delete().eq("session_id", session_id).execute()
    log_step("db_delete", table="messages", op="by_session", session_id=session_id)
    res = (
        supabase.table("sessions")
        .delete()
        .eq("id", session_id)
        .eq("user_id", user_id)
        .execute()
    )
    log_step("db_delete", table="sessions", session_id=session_id)
    # Drop both the per-user list cache and the cached ownership entry.
    invalidate_for_user(str(user_id))
    invalidate_session_owner(session_id, str(user_id))
    return {"status": "ok", "deleted_id": session_id}
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
@router.patch("/{session_id}/title")
async def update_session_title(title: str, session_id: str, user_id=Depends(get_current_user_id)):
    """Rename a chat session (the user_id filter scopes the update to the owner).

    Returns the updated row, or 404 when no owned session matched. Previously an
    empty update result made `res.data[0]` raise an unhandled IndexError
    (HTTP 500) for unknown or foreign session ids.
    """
    supabase = get_supabase()
    res = (
        supabase.table("sessions")
        .update({"title": title})
        .eq("id", session_id)
        .eq("user_id", user_id)
        .execute()
    )
    log_step("db_update", table="sessions", op="title", session_id=session_id)
    if not res.data:
        # Nothing matched: no cache invalidation needed, nothing changed.
        raise HTTPException(status_code=404, detail="Session not found.")
    invalidate_for_user(str(user_id))
    return res.data[0]
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
@router.get("/{session_id}/assets", response_model=List[dict])
async def get_session_assets(session_id: str, user_id=Depends(get_current_user_id)):
    """List rendered assets for a session, newest version first (owner-only)."""
    supabase = get_supabase()

    def owns() -> bool:
        # Ownership probe, executed only when the ownership cache misses.
        res = (
            supabase.table("sessions")
            .select("id")
            .eq("id", session_id)
            .eq("user_id", user_id)
            .execute()
        )
        return bool(res.data)

    if not session_owned_by_user(session_id, str(user_id), owns):
        raise HTTPException(
            status_code=403, detail="Forbidden: You do not own this session."
        )

    res = (
        supabase.table("session_assets")
        .select("*")
        .eq("session_id", session_id)
        .order("version", desc=True)
        .execute()
    )
    log_step("db_select", table="session_assets", op="list", session_id=session_id)
    return res.data
|
app/routers/solve.py
ADDED
|
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import uuid
|
| 6 |
+
|
| 7 |
+
from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, HTTPException, UploadFile
|
| 8 |
+
|
| 9 |
+
from agents.orchestrator import Orchestrator
|
| 10 |
+
from app.chat_image_upload import upload_session_chat_image, validate_chat_image_bytes
|
| 11 |
+
from app.ocr_celery import ocr_celery_enabled
|
| 12 |
+
from app.ocr_local_file import ocr_from_local_image_path
|
| 13 |
+
from app.dependencies import get_current_user_id
|
| 14 |
+
from app.errors import format_error_for_user
|
| 15 |
+
from app.logutil import log_pipeline_failure, log_pipeline_success, log_step
|
| 16 |
+
from app.models.schemas import (
|
| 17 |
+
OcrPreviewResponse,
|
| 18 |
+
RenderVideoRequest,
|
| 19 |
+
RenderVideoResponse,
|
| 20 |
+
SolveRequest,
|
| 21 |
+
SolveResponse,
|
| 22 |
+
)
|
| 23 |
+
from app.ocr_text_merge import build_combined_ocr_preview_draft
|
| 24 |
+
from app.session_cache import invalidate_for_user, session_owned_by_user
|
| 25 |
+
from app.supabase_client import get_supabase
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
router = APIRouter(prefix="/api/v1/sessions", tags=["Solve"])
|
| 29 |
+
|
| 30 |
+
# Eager init: all agents and models load at import time (also run in Docker build via scripts/prewarm_models.py).
|
| 31 |
+
ORCHESTRATOR = Orchestrator()
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_orchestrator() -> Orchestrator:
    """Return the module-level Orchestrator singleton created eagerly at import time."""
    return ORCHESTRATOR
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
_OCR_PREVIEW_MAX_BYTES = 10 * 1024 * 1024
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _assert_session_owner(supabase, session_id: str, user_id, uid: str, op: str) -> None:
    """Raise 403 unless *user_id* owns *session_id* (ownership result is cached).

    `op` labels the DB probe in the debug step log.
    """
    def owns() -> bool:
        # Cache-miss path: single SELECT scoped to both session id and owner.
        res = (
            supabase.table("sessions")
            .select("id")
            .eq("id", session_id)
            .eq("user_id", user_id)
            .execute()
        )
        log_step("db_select", table="sessions", op=op, session_id=session_id)
        return bool(res.data)

    if not session_owned_by_user(session_id, uid, owns):
        log_pipeline_failure("solve_request", error="forbidden", session_id=session_id)
        raise HTTPException(
            status_code=403, detail="Forbidden: You do not own this session."
        )
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _enqueue_solve_common(
    supabase,
    background_tasks: BackgroundTasks,
    session_id: str,
    user_id,
    uid: str,
    request: SolveRequest,
    message_metadata: dict,
    job_id: str,
) -> SolveResponse:
    """Insert user message, job row, enqueue pipeline; update title when first message.

    Shared tail of the /solve and /solve_multipart endpoints.

    Args:
        supabase: Service-role Supabase client.
        background_tasks: FastAPI background-task queue for the solve pipeline.
        session_id: Target session (ownership already verified by the caller).
        user_id: Raw user id from auth; uid is its string form for cache keys.
        request: Validated solve payload (text, optional image_url).
        message_metadata: Attachment/image metadata to persist on the message.
        job_id: Pre-generated job UUID (also used as the response handle).
    """
    # 1) Persist the user's question (with any attachment metadata) to history.
    supabase.table("messages").insert(
        {
            "session_id": session_id,
            "role": "user",
            "type": "text",
            "content": request.text,
            "metadata": message_metadata,
        }
    ).execute()
    log_step("db_insert", table="messages", op="user_message", session_id=session_id)

    # 2) Create the job row in "processing" state before dispatching the task.
    supabase.table("jobs").insert(
        {
            "id": job_id,
            "user_id": user_id,
            "session_id": session_id,
            "status": "processing",
            "input_text": request.text,
        }
    ).execute()
    log_step("db_insert", table="jobs", job_id=job_id)

    # 3) Run the solve pipeline after the HTTP response is sent.
    background_tasks.add_task(process_session_job, job_id, session_id, request, str(user_id))

    # 4) First message in a fresh session: derive the title from the question text.
    title_check = supabase.table("sessions").select("title").eq("id", session_id).execute()
    if title_check.data and title_check.data[0]["title"] == "Bài toán mới":
        new_title = request.text[:50] + ("..." if len(request.text) > 50 else "")
        supabase.table("sessions").update({"title": new_title}).eq("id", session_id).execute()
        log_step("db_update", table="sessions", op="title_from_first_message")
        # Cached session list now shows a stale title; drop it.
        invalidate_for_user(uid)

    log_pipeline_success("solve_accepted", job_id=job_id, session_id=session_id)
    return SolveResponse(job_id=job_id, status="processing")
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
@router.post("/{session_id}/ocr_preview", response_model=OcrPreviewResponse)
async def ocr_preview(
    session_id: str,
    user_id=Depends(get_current_user_id),
    file: UploadFile = File(...),
    user_message: str | None = Form(None),
):
    """
    Run OCR on an uploaded image and merge with optional user_message into combined_draft.
    Does not insert messages or start a solve job. After user confirms, call POST .../solve
    with text=combined_draft (edited) and omit image_url to avoid double OCR.
    """
    supabase = get_supabase()
    uid = str(user_id)
    _assert_session_owner(supabase, session_id, user_id, uid, "owner_check_ocr_preview")

    body = await file.read()
    # Size/emptiness guards before any validation or disk write.
    if len(body) > _OCR_PREVIEW_MAX_BYTES:
        raise HTTPException(
            status_code=413,
            detail=f"Image too large (max {_OCR_PREVIEW_MAX_BYTES // (1024 * 1024)} MB).",
        )
    if not body:
        raise HTTPException(status_code=400, detail="Empty file.")

    # NOTE(review): image-type validation only runs when the Celery OCR worker is
    # enabled; the local-file path below accepts arbitrary bytes — confirm intended.
    if ocr_celery_enabled():
        validate_chat_image_bytes(file.filename, body, file.content_type)

    # Keep a recognizable extension so OCR tooling can sniff the format.
    suffix = os.path.splitext(file.filename or "")[1].lower()
    if suffix not in (".png", ".jpg", ".jpeg", ".webp", ".gif", ".bmp", ""):
        suffix = ".png"
    # Temp file in the working directory; UUID avoids concurrent-request collisions.
    temp_path = f"temp_ocr_preview_{uuid.uuid4()}{suffix or '.png'}"
    try:
        with open(temp_path, "wb") as f:
            f.write(body)
        ocr_text = await ocr_from_local_image_path(
            temp_path, file.filename, get_orchestrator().ocr_agent
        )
        if ocr_text is None:
            ocr_text = ""
    finally:
        # Always clean up the temp file, even if OCR raised.
        if os.path.exists(temp_path):
            os.remove(temp_path)

    um = (user_message or "").strip()
    # Merge raw user_message (not the stripped copy) with the OCR text.
    combined = build_combined_ocr_preview_draft(user_message, ocr_text)
    log_step("ocr_preview_done", session_id=session_id, ocr_len=len(ocr_text), user_len=len(um))
    return OcrPreviewResponse(
        ocr_text=ocr_text,
        user_message=um,
        combined_draft=combined,
    )
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
@router.post("/{session_id}/solve", response_model=SolveResponse)
async def solve_problem(
    session_id: str,
    request: SolveRequest,
    background_tasks: BackgroundTasks,
    user_id=Depends(get_current_user_id),
):
    """
    Submit a geometry question within a session.

    Stores the question in history and starts the background solve pipeline
    (solution + static figure only; video rendering is a separate endpoint).
    """
    supabase = get_supabase()
    uid = str(user_id)
    _assert_session_owner(supabase, session_id, user_id, uid, "owner_check")

    # Optional image by URL: recorded on the message; the orchestrator OCRs it.
    message_metadata = {"image_url": request.image_url} if request.image_url else {}
    job_id = str(uuid.uuid4())
    return _enqueue_solve_common(
        supabase,
        background_tasks,
        session_id,
        user_id,
        uid,
        request,
        message_metadata,
        job_id,
    )
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
@router.post("/{session_id}/solve_multipart", response_model=SolveResponse)
async def solve_multipart(
    session_id: str,
    background_tasks: BackgroundTasks,
    user_id=Depends(get_current_user_id),
    text: str = Form(...),
    file: UploadFile = File(...),
):
    """
    Submit text + an image file in one multipart request: validate, upload to the
    `image` bucket, record in session_assets, persist the message with attachment
    metadata (URL, size, type), then enqueue the solve (image_url points at the
    public URL so the orchestrator can OCR it).
    """
    supabase = get_supabase()
    uid = str(user_id)
    _assert_session_owner(supabase, session_id, user_id, uid, "owner_check_solve_multipart")

    t = (text or "").strip()
    if not t:
        raise HTTPException(status_code=400, detail="text must not be empty.")

    body = await file.read()
    # Validates extension / content-type / bytes; raises on invalid images.
    ext, content_type = validate_chat_image_bytes(file.filename, body, file.content_type)

    # job_id doubles as the storage-key component for the uploaded image.
    job_id = str(uuid.uuid4())
    up = upload_session_chat_image(session_id, job_id, body, ext, content_type)
    public_url = up["public_url"]

    # Full attachment record is kept on the message for later display/audit.
    message_metadata = {
        "image_url": public_url,
        "attachment": {
            "public_url": public_url,
            "storage_path": up["storage_path"],
            "size_bytes": len(body),
            "content_type": content_type,
            "original_filename": file.filename or "",
            "session_asset_id": up.get("session_asset_id"),
        },
    }
    request = SolveRequest(text=t, image_url=public_url)
    return _enqueue_solve_common(
        supabase,
        background_tasks,
        session_id,
        user_id,
        uid,
        request,
        message_metadata,
        job_id,
    )
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
@router.post("/{session_id}/render_video", response_model=RenderVideoResponse)
async def render_video(
    session_id: str,
    request: RenderVideoRequest,
    background_tasks: BackgroundTasks,
    user_id=Depends(get_current_user_id),
):
    """
    Request a Manim video render from the session's latest geometry state.

    Flow: verify ownership, locate the newest assistant message carrying
    geometry metadata (optionally pinned to request.job_id) among the 10 most
    recent, create a job row, then dispatch the render as a background task.
    """
    supabase = get_supabase()

    # 1. Ownership check (direct query; no cache on this endpoint).
    res = supabase.table("sessions").select("id").eq("id", session_id).eq("user_id", user_id).execute()
    if not res.data:
        raise HTTPException(status_code=403, detail="Forbidden: You do not own this session.")

    # 2. Find an assistant message with geometry metadata (specific job_id or the
    #    newest among the last 10 assistant messages).
    msg_res = (
        supabase.table("messages")
        .select("metadata")
        .eq("session_id", session_id)
        .eq("role", "assistant")
        .order("created_at", desc=True)
        .limit(10)
        .execute()
    )

    latest_geometry = None
    for msg in (msg_res.data or []):
        # FIX: a NULL jsonb column comes back as None, and `msg.get("metadata", {})`
        # returns that None (the default only covers a *missing* key), crashing the
        # .get() calls below. `or {}` normalizes both missing and NULL.
        meta = msg.get("metadata") or {}
        # When a specific job_id is requested, only that job's message qualifies.
        if request.job_id and meta.get("job_id") != request.job_id:
            continue
        # The message must actually carry renderable geometry.
        if meta.get("geometry_dsl") and meta.get("coordinates"):
            latest_geometry = meta
            break

    if not latest_geometry:
        raise HTTPException(status_code=404, detail="Không tìm thấy dữ liệu hình học để render video.")

    # 3. Create the render job row.
    job_id = str(uuid.uuid4())
    supabase.table("jobs").insert({
        "id": job_id,
        "user_id": user_id,
        "session_id": session_id,
        "status": "rendering_queued",
        "input_text": f"Render video requested at {job_id}",
    }).execute()

    # 4. Dispatch the background render task.
    background_tasks.add_task(process_render_job, job_id, session_id, latest_geometry)

    return RenderVideoResponse(job_id=job_id, status="rendering_queued")
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
async def process_session_job(
    job_id: str, session_id: str, request: SolveRequest, user_id: str
):
    """Background solve pipeline: run the orchestrator, persist the result, and
    push status updates over the job's websocket.

    On failure, the job row is marked "error", a user-friendly error message is
    stored as an assistant message, and websocket listeners are still notified.
    """
    # Local import avoids a circular import with the app entry module.
    from app.websocket_manager import notify_status

    async def status_update(status: str):
        # Forward intermediate pipeline statuses to websocket subscribers.
        await notify_status(job_id, {"status": status})

    supabase = get_supabase()
    try:
        # Full chat history (oldest first) gives the orchestrator conversation context.
        history_res = (
            supabase.table("messages")
            .select("*")
            .eq("session_id", session_id)
            .order("created_at", desc=False)
            .execute()
        )
        history = history_res.data if history_res.data else []

        result = await get_orchestrator().run(
            request.text,
            request.image_url,
            job_id=job_id,
            session_id=session_id,
            status_callback=status_update,
            history=history,
        )

        # Any "error" key forces status to "error" regardless of the reported status.
        status = result.get("status", "error") if "error" not in result else "error"

        supabase.table("jobs").update({"status": status, "result": result}).eq(
            "id", job_id
        ).execute()

        # Persist the assistant reply; the geometry payload goes into metadata so
        # render_video can later rebuild the figure without re-solving.
        supabase.table("messages").insert(
            {
                "session_id": session_id,
                "role": "assistant",
                "type": "analysis" if "error" not in result else "error",
                "content": (
                    result.get("semantic_analysis", "Đã có lỗi xảy ra.")
                    if "error" not in result
                    else result["error"]
                ),
                "metadata": {
                    "job_id": job_id,
                    "coordinates": result.get("coordinates"),
                    "geometry_dsl": result.get("geometry_dsl"),
                    "polygon_order": result.get("polygon_order", []),
                    "drawing_phases": result.get("drawing_phases", []),
                    "circles": result.get("circles", []),
                    "lines": result.get("lines", []),
                    "rays": result.get("rays", []),
                    "solution": result.get("solution"),
                    "is_3d": result.get("is_3d", False),
                },
            }
        ).execute()

        await notify_status(job_id, {"status": status, "result": result})

    except Exception as e:
        logger.exception("Error processing session job %s", job_id)
        error_msg = format_error_for_user(e)
        # Fresh client in case the failure above was connection-related.
        supabase = get_supabase()
        supabase.table("jobs").update(
            {"status": "error", "result": {"error": str(e)}}
        ).eq("id", job_id).execute()
        supabase.table("messages").insert(
            {
                "session_id": session_id,
                "role": "assistant",
                "type": "error",
                "content": error_msg,
                "metadata": {"job_id": job_id},
            }
        ).execute()
        await notify_status(job_id, {"status": "error", "error": error_msg})
|
| 380 |
+
|
| 381 |
+
async def process_render_job(job_id: str, session_id: str, geometry_data: dict):
    """Dispatch a Manim video render (Celery task) from stored geometry metadata.

    geometry_data is the assistant-message metadata selected by render_video.
    Dispatch failures mark the job as "error" and notify websocket listeners.
    """
    # Local imports avoid circular dependencies and keep worker deps lazy.
    from app.websocket_manager import notify_status
    from worker.tasks import render_geometry_video

    await notify_status(job_id, {"status": "rendering_queued"})

    # Prepare payload for Celery (similar to what orchestrator used to do).
    result_payload = {
        "geometry_dsl": geometry_data.get("geometry_dsl"),
        "coordinates": geometry_data.get("coordinates"),
        "polygon_order": geometry_data.get("polygon_order", []),
        "drawing_phases": geometry_data.get("drawing_phases", []),
        "circles": geometry_data.get("circles", []),
        "lines": geometry_data.get("lines", []),
        "rays": geometry_data.get("rays", []),
        "semantic": geometry_data.get("semantic", {}),
        "semantic_analysis": geometry_data.get("semantic_analysis", "🎬 Video minh họa dựng từ trạng thái gần nhất."),
        "session_id": session_id,
    }

    try:
        logger.info(f"[RenderJob] Attempting to dispatch Celery task for job {job_id}...")
        render_geometry_video.delay(job_id, result_payload)
        logger.info(f"[RenderJob] SUCCESS: Dispatched Celery task for job {job_id}")
    except Exception as e:
        # Dispatch failure (e.g. broker down): mark the job failed and notify.
        logger.exception(f"[RenderJob] FAILED to dispatch Celery task: {e}")
        supabase = get_supabase()
        supabase.table("jobs").update({"status": "error", "result": {"error": f"Task dispatch failed: {str(e)}"}}).eq("id", job_id).execute()
        await notify_status(job_id, {"status": "error", "error": str(e)})
|
app/runtime_env.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Default process env vars (Paddle/OpenMP). Call as early as possible after load_dotenv."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def apply_runtime_env_defaults() -> None:
    """Force single-threaded math libraries for the whole process.

    Paddle reads OMP_NUM_THREADS at import time, and setdefault would lose if
    the platform already exported a value >= 2, so we assign unconditionally.
    """
    for var_name in ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS"):
        os.environ[var_name] = "1"
|
app/session_cache.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""TTL in-memory cache để giảm truy vấn Supabase lặp lại (list session, quyền sở hữu session)."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any, Callable
|
| 6 |
+
|
| 7 |
+
from cachetools import TTLCache
|
| 8 |
+
|
| 9 |
+
from app.logutil import log_step
|
| 10 |
+
|
| 11 |
+
_session_list: TTLCache[str, list[Any]] = TTLCache(maxsize=512, ttl=45)
|
| 12 |
+
_session_owner: TTLCache[tuple[str, str], bool] = TTLCache(maxsize=4096, ttl=45)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def invalidate_for_user(user_id: str) -> None:
    """Drop the cached session list for a user (after create / delete / rename,
    or a solve that changes the session title)."""
    _session_list.pop(user_id, None)
    log_step("cache_invalidate", target="session_list", user_id=user_id)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def invalidate_session_owner(session_id: str, user_id: str) -> None:
    """Drop the cached ownership flag for one (session_id, user_id) pair."""
    _session_owner.pop((session_id, user_id), None)
    log_step("cache_invalidate", target="session_owner", session_id=session_id, user_id=user_id)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_sessions_list_cached(user_id: str, fetch: Callable[[], list[Any]]) -> list[Any]:
    """Return the cached session list for *user_id*, calling *fetch* on a miss.

    FIX: uses EAFP (try/KeyError) instead of `key in cache` followed by
    `cache[key]` — with a TTLCache the entry can expire between the membership
    test and the read, which made the old check-then-read pattern raise KeyError.
    """
    try:
        data = _session_list[user_id]
    except KeyError:
        log_step("cache_miss", kind="session_list", user_id=user_id)
        data = fetch()
        _session_list[user_id] = data
        return data
    log_step("cache_hit", kind="session_list", user_id=user_id)
    return data
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def session_owned_by_user(
    session_id: str,
    user_id: str,
    fetch: Callable[[], bool],
) -> bool:
    """Return whether *user_id* owns *session_id*, consulting the TTL cache first.

    Args:
        session_id: Session UUID (string form).
        user_id: User UUID (string form).
        fetch: Callback performing the real ownership query on a cache miss.

    FIX: uses EAFP (try/KeyError) instead of `key in cache` + `cache[key]` —
    a TTLCache entry may expire between the membership test and the read,
    which made the old check-then-read pattern raise KeyError.
    """
    key = (session_id, user_id)
    try:
        owned = _session_owner[key]
    except KeyError:
        log_step("cache_miss", kind="session_owner", session_id=session_id)
        owned = fetch()
        _session_owner[key] = owned
        return owned
    log_step("cache_hit", kind="session_owner", session_id=session_id)
    return owned
|
app/supabase_client.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from supabase import Client, ClientOptions, create_client
|
| 3 |
+
from supabase_auth import SyncMemoryStorage
|
| 4 |
+
from dotenv import load_dotenv
|
| 5 |
+
|
| 6 |
+
load_dotenv()
|
| 7 |
+
|
| 8 |
+
from app.url_utils import sanitize_env
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_supabase() -> Client:
    """Service-role client for server-side operations (bypasses RLS when policies expect service role)."""
    # sanitize_env strips stray whitespace/newlines that HF secrets often carry.
    url = sanitize_env(os.getenv("SUPABASE_URL"))
    # Prefer the service-role key; fall back to the legacy SUPABASE_KEY name.
    key = sanitize_env(os.getenv("SUPABASE_SERVICE_ROLE_KEY") or os.getenv("SUPABASE_KEY"))
    if not url or not key:
        raise RuntimeError(
            "SUPABASE_URL and SUPABASE_SERVICE_ROLE_KEY (or SUPABASE_KEY) must be set"
        )
    return create_client(url, key)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def get_supabase_for_user_jwt(access_token: str) -> Client:
    """
    Client scoped to the logged-in user: PostgREST sends the user's JWT so RLS applies.
    Use SUPABASE_ANON_KEY (publishable), not the service role key.
    """
    url = sanitize_env(os.getenv("SUPABASE_URL"))
    anon = sanitize_env(os.getenv("SUPABASE_ANON_KEY") or os.getenv("NEXT_PUBLIC_SUPABASE_ANON_KEY"))
    if not url or not anon:
        raise RuntimeError(
            "SUPABASE_URL and SUPABASE_ANON_KEY (or NEXT_PUBLIC_SUPABASE_ANON_KEY) must be set "
            "for user-scoped Supabase access"
        )
    # Start from the library's default headers, then overlay the user's bearer token.
    defaults = ClientOptions(storage=SyncMemoryStorage())
    headers = dict(defaults.headers)
    headers["Authorization"] = f"Bearer {access_token}"
    user_opts = ClientOptions(storage=SyncMemoryStorage(), headers=headers)
    return create_client(url, anon, user_opts)
|
app/url_utils.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Normalize URLs / env strings (HF secrets and copy-paste often include trailing newlines)."""
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def sanitize_url(value: str | None) -> str | None:
|
| 5 |
+
if value is None:
|
| 6 |
+
return None
|
| 7 |
+
s = value.strip().replace("\r", "").replace("\n", "").replace("\t", "")
|
| 8 |
+
return s or None
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def sanitize_env(value: str | None) -> str | None:
    """Strip whitespace and line breaks from environment-backed strings."""
    # Same cleanup rules as URLs; kept as a named alias for readability at call sites.
    return sanitize_url(value)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# OpenAI SDK (>=1.x) requires a non-empty api_key at client construction (Docker build / prewarm has no secrets).
|
| 17 |
+
_OPENAI_API_KEY_BUILD_PLACEHOLDER = "build-placeholder-openrouter-not-for-production"
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def openai_compatible_api_key(raw: str | None) -> str:
    """Return the sanitized API key, or a build-time placeholder when unset.

    The placeholder lets AsyncOpenAI() be constructed without env secrets
    (e.g. during the Docker build / prewarm step); it is not a usable key.
    """
    return sanitize_env(raw) or _OPENAI_API_KEY_BUILD_PLACEHOLDER
|
app/websocket_manager.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""WebSocket connection registry and job status notifications (avoid circular imports with main)."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
from typing import Dict, List
|
| 7 |
+
|
| 8 |
+
from fastapi import WebSocket, WebSocketDisconnect
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
active_connections: Dict[str, List[WebSocket]] = {}
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
async def notify_status(job_id: str, data: dict) -> None:
    """Best-effort push of *data* to every websocket registered for *job_id*."""
    sockets = active_connections.get(job_id)
    if not sockets:
        return
    # Snapshot the list: a send failure elsewhere may mutate the registry.
    for ws in list(sockets):
        try:
            await ws.send_json(data)
        except Exception as exc:
            logger.error("WS error sending to %s: %s", job_id, exc)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def register_websocket_routes(app) -> None:
    """Attach the /ws/{job_id} websocket endpoint to the FastAPI app."""

    @app.websocket("/ws/{job_id}")
    async def websocket_endpoint(websocket: WebSocket, job_id: str) -> None:
        # Accept and register this socket under the job it watches.
        await websocket.accept()
        if job_id not in active_connections:
            active_connections[job_id] = []
        active_connections[job_id].append(websocket)
        try:
            # Keep the connection open; inbound messages are read and ignored.
            while True:
                await websocket.receive_text()
        except WebSocketDisconnect:
            # Deregister; drop the job entry once its last socket closes.
            active_connections[job_id].remove(websocket)
            if not active_connections[job_id]:
                del active_connections[job_id]
|
clean_ports.sh
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Script to kill all project-related processes for a clean restart

echo "🧹 Cleaning up project processes..."

# Kill things on ports 8000 (Backend) and 3000 (Frontend)
# (11020 is also swept — presumably a local worker/dev port; confirm.)
PORTS="8000 3000 11020"
for PORT in $PORTS; do
    PIDS=$(lsof -ti :$PORT)
    if [ ! -z "$PIDS" ]; then
        echo "Killing processes on port $PORT: $PIDS"
        # SIGKILL: no cleanup chance for the targets; errors are ignored.
        kill -9 $PIDS 2>/dev/null
    fi
done

# Kill by process name
echo "Killing any remaining Celery, Uvicorn, or Manim processes..."
pkill -9 -f "celery" 2>/dev/null
pkill -9 -f "uvicorn" 2>/dev/null
pkill -9 -f "manim" 2>/dev/null

echo "✅ Done. You can now restart your Backend, Worker, and Frontend."
|
dump.rdb
ADDED
|
Binary file (5.44 kB). View file
|
|
|
migrations/add_image_bucket_storage.sql
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-- ============================================================
-- MathSolver: Supabase Storage bucket `image` (chat / OCR attachments)
-- Run after session_assets and storage.video policies exist.
-- ============================================================

-- Idempotent bucket creation; re-running forces the bucket public.
INSERT INTO storage.buckets (id, name, public)
VALUES ('image', 'image', true)
ON CONFLICT (id) DO UPDATE SET public = true;

-- Service role: upload/delete/list for API + workers
DROP POLICY IF EXISTS "Service Role manage images" ON storage.objects;
CREATE POLICY "Service Role manage images" ON storage.objects
    FOR ALL
    TO service_role
    USING (bucket_id = 'image')
    WITH CHECK (bucket_id = 'image');

-- Authenticated: read only objects under sessions they own (path sessions/{session_id}/...)
-- (storage.foldername(name))[2] extracts the session_id path segment.
DROP POLICY IF EXISTS "Users view session images" ON storage.objects;
CREATE POLICY "Users view session images" ON storage.objects
    FOR SELECT
    TO authenticated
    USING (
        bucket_id = 'image'
        AND (storage.foldername(name))[2] IN (
            SELECT id::text FROM public.sessions WHERE user_id = auth.uid()
        )
    );

-- Public read for get_public_url / FE img tags (same model as video bucket)
DROP POLICY IF EXISTS "Public read images" ON storage.objects;
CREATE POLICY "Public read images" ON storage.objects
    FOR SELECT
    TO public
    USING (bucket_id = 'image');
|
migrations/fix_rls_assets.sql
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-- ============================================================
-- FIX RLS & SESSION ASSETS (MathSolver v5.1 Worker Fix)
-- ============================================================

-- 1. Ensure session_assets table exists
CREATE TABLE IF NOT EXISTS public.session_assets (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    session_id UUID NOT NULL REFERENCES public.sessions(id) ON DELETE CASCADE,
    job_id UUID NOT NULL,
    asset_type TEXT NOT NULL CHECK (asset_type IN ('video', 'image')),
    storage_path TEXT NOT NULL,
    public_url TEXT NOT NULL,
    version INTEGER NOT NULL DEFAULT 1,
    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
);

-- Index for session_assets
CREATE INDEX IF NOT EXISTS idx_session_assets_session_id ON public.session_assets(session_id);
CREATE INDEX IF NOT EXISTS idx_session_assets_type ON public.session_assets(session_id, asset_type);

-- 2. Enable RLS for all tables
ALTER TABLE public.session_assets ENABLE ROW LEVEL SECURITY;
ALTER TABLE public.profiles ENABLE ROW LEVEL SECURITY;
ALTER TABLE public.sessions ENABLE ROW LEVEL SECURITY;
ALTER TABLE public.messages ENABLE ROW LEVEL SECURITY;
ALTER TABLE public.jobs ENABLE ROW LEVEL SECURITY;


-- 3. Fix Table Policies to allow SERVICE ROLE
-- In Supabase, service_role usually bypasses RLS, but we add explicit policies for safety
-- especially for path-based checks or when SECURITY DEFINER functions are used.

-- [Session Assets]
DROP POLICY IF EXISTS "Users view own assets" ON public.session_assets;
CREATE POLICY "Users view own assets" ON public.session_assets
    FOR SELECT USING (
        session_id IN (SELECT id FROM public.sessions WHERE user_id = auth.uid())
    );

-- FIX: the previous version omitted the TO clause, so `USING (true)` applied to
-- PUBLIC and effectively disabled RLS on session_assets for every role. Scope
-- the blanket grant to service_role only, matching the policy's stated intent.
DROP POLICY IF EXISTS "Service role manages assets" ON public.session_assets;
CREATE POLICY "Service role manages assets" ON public.session_assets
    FOR ALL
    TO service_role
    USING (true)
    WITH CHECK (true);


-- [Messages] - Allow Worker to insert assistant messages
DROP POLICY IF EXISTS "Users manage own messages" ON public.messages;
CREATE POLICY "Users manage own messages" ON public.messages
    FOR ALL USING (
        session_id IN (SELECT id FROM public.sessions WHERE user_id = auth.uid())
        OR
        (auth.jwt() ->> 'role' = 'service_role')
    );


-- [Jobs] - Allow Worker to update job status
DROP POLICY IF EXISTS "Users manage own jobs" ON public.jobs;
CREATE POLICY "Users manage own jobs" ON public.jobs
    FOR ALL USING (
        auth.uid() = user_id
        OR user_id IS NULL
        OR (auth.jwt() ->> 'role' = 'service_role')
    );


-- 4. Storage Policies (Bucket: video)
-- Ensure 'video' bucket exists
INSERT INTO storage.buckets (id, name, public)
VALUES ('video', 'video', true)
ON CONFLICT (id) DO UPDATE SET public = true;

-- [Storage: Worker / Service Role] - Allow all in video bucket
-- Explicit WITH CHECK added for consistency with the `image` bucket migration
-- (Postgres would otherwise reuse USING for the check; behavior is identical).
DROP POLICY IF EXISTS "Service Role manage videos" ON storage.objects;
CREATE POLICY "Service Role manage videos" ON storage.objects
    FOR ALL
    TO service_role
    USING (bucket_id = 'video')
    WITH CHECK (bucket_id = 'video');

-- [Storage: Users] - Allow users to view their session videos
DROP POLICY IF EXISTS "Users view session videos" ON storage.objects;
CREATE POLICY "Users view session videos" ON storage.objects
    FOR SELECT
    TO authenticated
    USING (
        bucket_id = 'video'
        AND (storage.foldername(name))[2] IN (
            SELECT id::text FROM public.sessions WHERE user_id = auth.uid()
        )
    );

-- [Storage: Public] - Allow public read access to videos
DROP POLICY IF EXISTS "Public read videos" ON storage.objects;
CREATE POLICY "Public read videos" ON storage.objects
    FOR SELECT
    TO public
    USING (bucket_id = 'video');
|
migrations/v4_migration.sql
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-- ============================================================
|
| 2 |
+
-- MATHSOLVER v4.0 - Migration Script (Multi-Session & History)
|
| 3 |
+
-- ============================================================
|
| 4 |
+
|
| 5 |
+
-- 1. Profiles Table (Extends Supabase Auth)
|
| 6 |
+
CREATE TABLE IF NOT EXISTS public.profiles (
|
| 7 |
+
id UUID PRIMARY KEY REFERENCES auth.users(id) ON DELETE CASCADE,
|
| 8 |
+
display_name TEXT,
|
| 9 |
+
avatar_url TEXT,
|
| 10 |
+
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
|
| 11 |
+
updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
|
| 12 |
+
);
|
| 13 |
+
|
| 14 |
+
-- Function to handle new user signup and auto-create profile
|
| 15 |
+
CREATE OR REPLACE FUNCTION public.handle_new_user()
|
| 16 |
+
RETURNS TRIGGER AS $$
|
| 17 |
+
BEGIN
|
| 18 |
+
INSERT INTO public.profiles (id, display_name, avatar_url)
|
| 19 |
+
VALUES (
|
| 20 |
+
NEW.id,
|
| 21 |
+
COALESCE(NEW.raw_user_meta_data->>'full_name', NEW.email),
|
| 22 |
+
NEW.raw_user_meta_data->>'avatar_url'
|
| 23 |
+
);
|
| 24 |
+
RETURN NEW;
|
| 25 |
+
END;
|
| 26 |
+
$$ LANGUAGE plpgsql SECURITY DEFINER;
|
| 27 |
+
|
| 28 |
+
-- Trigger for profile creation
|
| 29 |
+
DROP TRIGGER IF EXISTS on_auth_user_created ON auth.users;
|
| 30 |
+
CREATE TRIGGER on_auth_user_created
|
| 31 |
+
AFTER INSERT ON auth.users
|
| 32 |
+
FOR EACH ROW EXECUTE FUNCTION public.handle_new_user();
|
| 33 |
+
|
| 34 |
+
-- 2. Sessions Table
|
| 35 |
+
CREATE TABLE IF NOT EXISTS public.sessions (
|
| 36 |
+
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
| 37 |
+
user_id UUID NOT NULL REFERENCES auth.users(id) ON DELETE CASCADE,
|
| 38 |
+
title TEXT DEFAULT 'Bài toán mới',
|
| 39 |
+
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
|
| 40 |
+
updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
|
| 41 |
+
);
|
| 42 |
+
|
| 43 |
+
-- Index for sessions
|
| 44 |
+
CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON public.sessions(user_id);
|
| 45 |
+
CREATE INDEX IF NOT EXISTS idx_sessions_updated_at ON public.sessions(updated_at DESC);
|
| 46 |
+
|
| 47 |
+
-- 3. Messages Table
|
| 48 |
+
CREATE TABLE IF NOT EXISTS public.messages (
|
| 49 |
+
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
| 50 |
+
session_id UUID NOT NULL REFERENCES public.sessions(id) ON DELETE CASCADE,
|
| 51 |
+
role TEXT NOT NULL CHECK (role IN ('user', 'assistant', 'system')),
|
| 52 |
+
type TEXT NOT NULL DEFAULT 'text',
|
| 53 |
+
content TEXT NOT NULL,
|
| 54 |
+
metadata JSONB DEFAULT '{}'::jsonb,
|
| 55 |
+
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
|
| 56 |
+
);
|
| 57 |
+
|
| 58 |
+
-- Index for messages
|
| 59 |
+
CREATE INDEX IF NOT EXISTS idx_messages_session_id ON public.messages(session_id);
|
| 60 |
+
CREATE INDEX IF NOT EXISTS idx_messages_created_at ON public.messages(session_id, created_at);
|
| 61 |
+
|
| 62 |
+
-- 4. Session Assets Table (v5.1 Versioning)
|
| 63 |
+
CREATE TABLE IF NOT EXISTS public.session_assets (
|
| 64 |
+
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
| 65 |
+
session_id UUID NOT NULL REFERENCES public.sessions(id) ON DELETE CASCADE,
|
| 66 |
+
job_id UUID NOT NULL,
|
| 67 |
+
asset_type TEXT NOT NULL CHECK (asset_type IN ('video', 'image')),
|
| 68 |
+
storage_path TEXT NOT NULL,
|
| 69 |
+
public_url TEXT NOT NULL,
|
| 70 |
+
version INTEGER NOT NULL DEFAULT 1,
|
| 71 |
+
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
|
| 72 |
+
);
|
| 73 |
+
|
| 74 |
+
-- Index for session_assets
|
| 75 |
+
CREATE INDEX IF NOT EXISTS idx_session_assets_session_id ON public.session_assets(session_id);
|
| 76 |
+
|
| 77 |
+
-- 5. Update Jobs Table
|
| 78 |
+
ALTER TABLE public.jobs ADD COLUMN IF NOT EXISTS user_id UUID REFERENCES auth.users(id);
|
| 79 |
+
ALTER TABLE public.jobs ADD COLUMN IF NOT EXISTS session_id UUID REFERENCES public.sessions(id);
|
| 80 |
+
|
| 81 |
+
-- 6. Row Level Security (RLS)
|
| 82 |
+
ALTER TABLE public.profiles ENABLE ROW LEVEL SECURITY;
|
| 83 |
+
ALTER TABLE public.sessions ENABLE ROW LEVEL SECURITY;
|
| 84 |
+
ALTER TABLE public.messages ENABLE ROW LEVEL SECURITY;
|
| 85 |
+
ALTER TABLE public.jobs ENABLE ROW LEVEL SECURITY;
|
| 86 |
+
ALTER TABLE public.session_assets ENABLE ROW LEVEL SECURITY;
|
| 87 |
+
|
| 88 |
+
-- Polices for public.profiles
|
| 89 |
+
DROP POLICY IF EXISTS "Users view own profile" ON public.profiles;
|
| 90 |
+
CREATE POLICY "Users view own profile" ON public.profiles FOR SELECT USING (auth.uid() = id);
|
| 91 |
+
DROP POLICY IF EXISTS "Users update own profile" ON public.profiles;
|
| 92 |
+
CREATE POLICY "Users update own profile" ON public.profiles FOR UPDATE USING (auth.uid() = id);
|
| 93 |
+
|
| 94 |
+
-- Policies for public.sessions
|
| 95 |
+
DROP POLICY IF EXISTS "Users manage own sessions" ON public.sessions;
|
| 96 |
+
CREATE POLICY "Users manage own sessions" ON public.sessions FOR ALL USING (auth.uid() = user_id);
|
| 97 |
+
|
| 98 |
+
-- Policies for public.messages
|
| 99 |
+
DROP POLICY IF EXISTS "Users manage own messages" ON public.messages;
|
| 100 |
+
CREATE POLICY "Users manage own messages" ON public.messages FOR ALL USING (
|
| 101 |
+
session_id IN (SELECT id FROM public.sessions WHERE user_id = auth.uid())
|
| 102 |
+
OR (auth.jwt() ->> 'role' = 'service_role')
|
| 103 |
+
);
|
| 104 |
+
|
| 105 |
+
-- Policies for public.session_assets
|
| 106 |
+
DROP POLICY IF EXISTS "Users view own assets" ON public.session_assets;
|
| 107 |
+
CREATE POLICY "Users view own assets" ON public.session_assets FOR SELECT USING (
|
| 108 |
+
session_id IN (SELECT id FROM public.sessions WHERE user_id = auth.uid())
|
| 109 |
+
);
|
| 110 |
+
DROP POLICY IF EXISTS "Service role manages assets" ON public.session_assets;
|
| 111 |
+
CREATE POLICY "Service role manages assets" ON public.session_assets FOR ALL USING (true);
|
| 112 |
+
|
| 113 |
+
-- Policies for public.jobs
|
| 114 |
+
DROP POLICY IF EXISTS "Users manage own jobs" ON public.jobs;
|
| 115 |
+
CREATE POLICY "Users manage own jobs" ON public.jobs FOR ALL USING (
|
| 116 |
+
auth.uid() = user_id OR user_id IS NULL OR (auth.jwt() ->> 'role' = 'service_role')
|
| 117 |
+
);
|
| 118 |
+
|
| 119 |
+
-- 7. Storage Policies (Bucket: video)
|
| 120 |
+
-- (Run this in Supabase Dashboard if not allowed in migration)
|
| 121 |
+
-- INSERT INTO storage.buckets (id, name, public) VALUES ('video', 'video', true) ON CONFLICT (id) DO NOTHING;
|
| 122 |
+
-- CREATE POLICY "Service Role manage videos" ON storage.objects FOR ALL TO service_role USING (bucket_id = 'video');
|
| 123 |
+
-- CREATE POLICY "Public read videos" ON storage.objects FOR SELECT TO public USING (bucket_id = 'video');
|
| 124 |
+
|
| 125 |
+
-- Grant permissions to public/authenticated
|
| 126 |
+
GRANT ALL ON public.profiles TO authenticated;
|
| 127 |
+
GRANT ALL ON public.sessions TO authenticated;
|
| 128 |
+
GRANT ALL ON public.messages TO authenticated;
|
| 129 |
+
GRANT ALL ON public.jobs TO authenticated;
|
| 130 |
+
GRANT ALL ON public.session_assets TO authenticated;
|
| 131 |
+
GRANT ALL ON public.session_assets TO service_role;
|
pytest.ini
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[pytest]
|
| 2 |
+
asyncio_mode = auto
|
| 3 |
+
testpaths = tests
|
| 4 |
+
pythonpath = .
|
| 5 |
+
filterwarnings =
|
| 6 |
+
ignore::DeprecationWarning
|
| 7 |
+
|
| 8 |
+
markers =
|
| 9 |
+
real_api: HTTP tests need running backend and TEST_USER_ID / TEST_SESSION_ID.
|
| 10 |
+
real_worker_ocr: OCR Celery task or full OCR stack (heavy).
|
| 11 |
+
real_worker_manim: Real Manim render and Supabase video upload.
|
| 12 |
+
real_agents: Live LLM / orchestrator agent calls.
|
| 13 |
+
slow: Large suite or long polling timeouts.
|
| 14 |
+
smoke: Fast API health + one solve job.
|
| 15 |
+
orchestrator_local: In-process Orchestrator without HTTP server.
|
| 16 |
+
|
| 17 |
+
# Default: skip integration tests that need services, keys, or long runs.
|
| 18 |
+
addopts = -m "not real_api and not real_worker_ocr and not real_worker_manim and not real_agents and not slow and not orchestrator_local"
|
requirements.txt
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Target: Python 3.11 (see Dockerfile). Used by: FastAPI API, Celery worker, Manim render, OCR/vision stack.
|
| 2 |
+
# Install: pip install -r requirements.txt
|
| 3 |
+
|
| 4 |
+
# --- Dev / test ---
|
| 5 |
+
pytest>=8.0
|
| 6 |
+
pytest-asyncio>=0.24
|
| 7 |
+
|
| 8 |
+
# --- HTTP API ---
|
| 9 |
+
cachetools>=5.3
|
| 10 |
+
fastapi>=0.115,<1
|
| 11 |
+
uvicorn[standard]>=0.30
|
| 12 |
+
python-multipart>=0.0.9
|
| 13 |
+
python-dotenv>=1.0
|
| 14 |
+
pydantic[email]>=2.4
|
| 15 |
+
email-validator>=2
|
| 16 |
+
|
| 17 |
+
# --- Auth / data / queue ---
|
| 18 |
+
openai>=1.40
|
| 19 |
+
supabase>=2.0
|
| 20 |
+
celery>=5.3
|
| 21 |
+
redis>=5
|
| 22 |
+
httpx>=0.27
|
| 23 |
+
websockets>=12
|
| 24 |
+
|
| 25 |
+
# --- Math & symbolic solver ---
|
| 26 |
+
sympy>=1.12
|
| 27 |
+
numpy>=1.26,<2
|
| 28 |
+
scipy>=1.11
|
| 29 |
+
opencv-python-headless>=4.8,<4.10
|
| 30 |
+
|
| 31 |
+
# --- Video (GeometryScene via CLI) ---
|
| 32 |
+
manim>=0.18,<0.20
|
| 33 |
+
|
| 34 |
+
# --- OCR & vision (orchestrator / legacy /ocr) ---
|
| 35 |
+
pix2tex>=0.1.4
|
| 36 |
+
paddleocr==2.7.3
|
| 37 |
+
paddlepaddle==2.6.2
|
| 38 |
+
ultralytics==8.2.2
|
requirements.worker-ocr.txt
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# OCR-only Celery worker: YOLO + PaddleOCR + Pix2Tex (no OpenRouter / no Manim).
|
| 2 |
+
# Install: pip install -r requirements.worker-ocr.txt
|
| 3 |
+
|
| 4 |
+
cachetools>=5.3
|
| 5 |
+
fastapi>=0.115,<1
|
| 6 |
+
uvicorn[standard]>=0.30
|
| 7 |
+
python-multipart>=0.0.9
|
| 8 |
+
python-dotenv>=1.0
|
| 9 |
+
pydantic[email]>=2.4
|
| 10 |
+
email-validator>=2
|
| 11 |
+
|
| 12 |
+
celery>=5.3
|
| 13 |
+
redis>=5
|
| 14 |
+
httpx>=0.27
|
| 15 |
+
websockets>=12
|
| 16 |
+
|
| 17 |
+
numpy>=1.26,<2
|
| 18 |
+
opencv-python-headless>=4.8,<4.10
|
| 19 |
+
|
| 20 |
+
pix2tex>=0.1.4
|
| 21 |
+
paddleocr==2.7.3
|
| 22 |
+
paddlepaddle==2.6.2
|
| 23 |
+
ultralytics==8.2.2
|
run_api_test.sh
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
LOG_FILE="api_test_results.log"
|
| 4 |
+
echo "=== Starting API E2E Test Suite ($(date)) ===" > $LOG_FILE
|
| 5 |
+
|
| 6 |
+
# 1. Start BE Server in background
|
| 7 |
+
echo "[INFO] Starting Backend Server..." | tee -a $LOG_FILE
|
| 8 |
+
export ALLOW_TEST_BYPASS=true
|
| 9 |
+
export LOG_LEVEL=info
|
| 10 |
+
export CELERY_TASK_ALWAYS_EAGER=true
|
| 11 |
+
export CELERY_RESULT_BACKEND=rpc://
|
| 12 |
+
export MOCK_VIDEO=true
|
| 13 |
+
PYTHONPATH=. venv/bin/python -m uvicorn app.main:app --port 8000 > server_debug.log 2>&1 &
|
| 14 |
+
SERVER_PID=$!
|
| 15 |
+
|
| 16 |
+
# 2. Wait for server to be ready
|
| 17 |
+
echo "[INFO] Waiting for server (PID: $SERVER_PID) on port 8000..." | tee -a $LOG_FILE
|
| 18 |
+
MAX_RETRIES=15
|
| 19 |
+
READY=0
|
| 20 |
+
for i in $(seq 1 $MAX_RETRIES); do
|
| 21 |
+
if curl -s http://localhost:8000/ > /dev/null; then
|
| 22 |
+
READY=1
|
| 23 |
+
break
|
| 24 |
+
fi
|
| 25 |
+
sleep 2
|
| 26 |
+
done
|
| 27 |
+
|
| 28 |
+
if [ $READY -eq 0 ]; then
|
| 29 |
+
echo "[ERROR] Server failed to start in time. Check server_debug.log" | tee -a $LOG_FILE
|
| 30 |
+
kill $SERVER_PID
|
| 31 |
+
exit 1
|
| 32 |
+
fi
|
| 33 |
+
echo "[INFO] Server is READY." | tee -a $LOG_FILE
|
| 34 |
+
|
| 35 |
+
# 3. Prepare Test Data
|
| 36 |
+
echo "[INFO] Preparing fresh test data..." | tee -a $LOG_FILE
|
| 37 |
+
PREP_OUTPUT=$(PYTHONPATH=. venv/bin/python scripts/prepare_api_test.py)
|
| 38 |
+
echo "$PREP_OUTPUT" >> $LOG_FILE
|
| 39 |
+
|
| 40 |
+
export TEST_USER_ID=$(echo "$PREP_OUTPUT" | grep "RESULT:USER_ID=" | cut -d'=' -f2)
|
| 41 |
+
export TEST_SESSION_ID=$(echo "$PREP_OUTPUT" | grep "RESULT:SESSION_ID=" | cut -d'=' -f2)
|
| 42 |
+
|
| 43 |
+
if [ -z "$TEST_USER_ID" ] || [ -z "$TEST_SESSION_ID" ]; then
|
| 44 |
+
echo "[ERROR] Failed to prepare test data." | tee -a $LOG_FILE
|
| 45 |
+
kill $SERVER_PID
|
| 46 |
+
exit 1
|
| 47 |
+
fi
|
| 48 |
+
|
| 49 |
+
echo "[INFO] Test Data: User=$TEST_USER_ID, Session=$TEST_SESSION_ID" | tee -a $LOG_FILE
|
| 50 |
+
|
| 51 |
+
# 4. Run Pytest
|
| 52 |
+
echo "[INFO] Running API E2E Tests..." | tee -a $LOG_FILE
|
| 53 |
+
PYTHONPATH=. venv/bin/python -m pytest tests/test_api_real_e2e.py -m "smoke and real_api" -s \
|
| 54 |
+
--junitxml=pytest_smoke.xml >> $LOG_FILE 2>&1
|
| 55 |
+
TEST_EXIT_CODE=$?
|
| 56 |
+
|
| 57 |
+
# 5. Cleanup
|
| 58 |
+
echo "[INFO] Shutting down Server..." | tee -a $LOG_FILE
|
| 59 |
+
kill $SERVER_PID
|
| 60 |
+
|
| 61 |
+
echo "==========================================" | tee -a $LOG_FILE
|
| 62 |
+
if [ $TEST_EXIT_CODE -eq 0 ]; then
|
| 63 |
+
echo "FINAL RESULT: ✅ ALL API TESTS PASSED" | tee -a $LOG_FILE
|
| 64 |
+
else
|
| 65 |
+
echo "FINAL RESULT: ❌ SOME API TESTS FAILED (Code: $TEST_EXIT_CODE)" | tee -a $LOG_FILE
|
| 66 |
+
fi
|
| 67 |
+
echo "==========================================" | tee -a $LOG_FILE
|
| 68 |
+
|
| 69 |
+
exit $TEST_EXIT_CODE
|
run_full_api_test.sh
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Full API integration (CI-style): eager Celery + mock video + full HTTP suite.
|
| 3 |
+
LOG_FILE="full_api_suite.log"
|
| 4 |
+
REPORT_FILE="full_api_test_report.md"
|
| 5 |
+
JSON_RESULTS="temp_suite_results.json"
|
| 6 |
+
JUNIT="pytest_api_suite.xml"
|
| 7 |
+
|
| 8 |
+
echo "=== Starting Full API Suite Test ($(date)) ===" >"$LOG_FILE"
|
| 9 |
+
|
| 10 |
+
trap 'echo "[INFO] Cleaning up processes..."; kill $SERVER_PID 2>/dev/null; sleep 1' EXIT
|
| 11 |
+
|
| 12 |
+
echo "[INFO] Starting Backend Server (EAGER + MOCK_VIDEO)..." | tee -a "$LOG_FILE"
|
| 13 |
+
export ALLOW_TEST_BYPASS=true
|
| 14 |
+
export LOG_LEVEL=info
|
| 15 |
+
export CELERY_TASK_ALWAYS_EAGER=true
|
| 16 |
+
export CELERY_RESULT_BACKEND=rpc://
|
| 17 |
+
export MOCK_VIDEO=true
|
| 18 |
+
PYTHONPATH=. venv/bin/python -m uvicorn app.main:app --port 8000 >server_debug.log 2>&1 &
|
| 19 |
+
SERVER_PID=$!
|
| 20 |
+
|
| 21 |
+
echo "[INFO] Waiting for server (PID: $SERVER_PID)..." | tee -a "$LOG_FILE"
|
| 22 |
+
for i in {1..20}; do
|
| 23 |
+
if curl -s http://localhost:8000/ >/dev/null; then
|
| 24 |
+
echo "[INFO] Server is READY." | tee -a "$LOG_FILE"
|
| 25 |
+
break
|
| 26 |
+
fi
|
| 27 |
+
sleep 2
|
| 28 |
+
done
|
| 29 |
+
|
| 30 |
+
echo "[INFO] Preparing fresh test data..." | tee -a "$LOG_FILE"
|
| 31 |
+
PREP_OUTPUT=$(PYTHONPATH=. venv/bin/python scripts/prepare_api_test.py)
|
| 32 |
+
export TEST_USER_ID=$(echo "$PREP_OUTPUT" | grep "RESULT:USER_ID=" | cut -d'=' -f2)
|
| 33 |
+
export TEST_SESSION_ID=$(echo "$PREP_OUTPUT" | grep "RESULT:SESSION_ID=" | cut -d'=' -f2)
|
| 34 |
+
|
| 35 |
+
if [ -z "$TEST_USER_ID" ]; then
|
| 36 |
+
echo "[ERROR] Failed to prepare test data." | tee -a "$LOG_FILE"
|
| 37 |
+
exit 1
|
| 38 |
+
fi
|
| 39 |
+
|
| 40 |
+
echo "[INFO] Executing API tests (smoke + full suite)..." | tee -a "$LOG_FILE"
|
| 41 |
+
PYTHONPATH=. venv/bin/python -m pytest tests/test_api_real_e2e.py tests/test_api_full_suite.py \
|
| 42 |
+
-m "real_api" -s --tb=short --junitxml="$JUNIT" >>"$LOG_FILE" 2>&1
|
| 43 |
+
TEST_EXIT_CODE=$?
|
| 44 |
+
|
| 45 |
+
echo "[INFO] Generating Markdown Report..." | tee -a "$LOG_FILE"
|
| 46 |
+
if [ -f "$JSON_RESULTS" ]; then
|
| 47 |
+
PYTHONPATH=. venv/bin/python scripts/generate_report.py "$JSON_RESULTS" "$REPORT_FILE" "$JUNIT"
|
| 48 |
+
else
|
| 49 |
+
echo "[WARN] $JSON_RESULTS not found" | tee -a "$LOG_FILE"
|
| 50 |
+
fi
|
| 51 |
+
|
| 52 |
+
echo "==========================================" | tee -a "$LOG_FILE"
|
| 53 |
+
echo "DONE. Check $REPORT_FILE for results." | tee -a "$LOG_FILE"
|
| 54 |
+
echo "==========================================" | tee -a "$LOG_FILE"
|
| 55 |
+
|
| 56 |
+
exit $TEST_EXIT_CODE
|
scripts/benchmark_openrouter.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Benchmark several OpenRouter models (manual tool; not part of pytest)."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
import time
|
| 8 |
+
|
| 9 |
+
import httpx
|
| 10 |
+
from dotenv import load_dotenv
|
| 11 |
+
|
| 12 |
+
_BACKEND_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 13 |
+
load_dotenv(os.path.join(_BACKEND_ROOT, ".env"))
|
| 14 |
+
|
| 15 |
+
MODELS = [
|
| 16 |
+
"nvidia/nemotron-3-super-120b-a12b:free",
|
| 17 |
+
"meta-llama/llama-3.3-70b-instruct:free",
|
| 18 |
+
"openai/gpt-oss-120b:free",
|
| 19 |
+
"z-ai/glm-4.5-air:free",
|
| 20 |
+
"minimax/minimax-m2.5:free",
|
| 21 |
+
"google/gemma-4-26b-a4b-it:free",
|
| 22 |
+
"google/gemma-4-31b-it:free",
|
| 23 |
+
]
|
| 24 |
+
|
| 25 |
+
PROMPT = (
|
| 26 |
+
"Cho hình chữ nhật ABCD có AB bằng 5 và AD bằng 10. Gọi E là điểm nằm trong đoạn CD sao cho CE = 2ED. "
|
| 27 |
+
"Vẽ đoạn thẳng AE. Vẽ thêm P là điểm nằm trên đường thẳng BC sao cho BP = 2PC, tính chu vi tam giác PEA"
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def main() -> None:
|
| 32 |
+
api_key = os.getenv("OPENROUTER_API_KEY_1") or os.getenv("OPENROUTER_API_KEY")
|
| 33 |
+
base_url = "https://openrouter.ai/api/v1/chat/completions"
|
| 34 |
+
|
| 35 |
+
if not api_key:
|
| 36 |
+
print("Missing OPENROUTER_API_KEY_1 or OPENROUTER_API_KEY in .env")
|
| 37 |
+
return
|
| 38 |
+
|
| 39 |
+
print("Benchmark OpenRouter models\nPrompt:", PROMPT, "\n")
|
| 40 |
+
results = []
|
| 41 |
+
|
| 42 |
+
for model in MODELS:
|
| 43 |
+
print(f"Calling {model}...", end="", flush=True)
|
| 44 |
+
headers = {
|
| 45 |
+
"Authorization": f"Bearer {api_key}",
|
| 46 |
+
"Content-Type": "application/json",
|
| 47 |
+
"HTTP-Referer": "https://mathsolver.io",
|
| 48 |
+
"X-Title": "MathSolver Benchmark Tool",
|
| 49 |
+
}
|
| 50 |
+
payload = {"model": model, "messages": [{"role": "user", "content": PROMPT}]}
|
| 51 |
+
start = time.time()
|
| 52 |
+
try:
|
| 53 |
+
with httpx.Client(timeout=120.0) as client:
|
| 54 |
+
r = client.post(base_url, headers=headers, json=payload)
|
| 55 |
+
r.raise_for_status()
|
| 56 |
+
data = r.json()
|
| 57 |
+
answer = data["choices"][0]["message"]["content"]
|
| 58 |
+
duration = time.time() - start
|
| 59 |
+
results.append(
|
| 60 |
+
{"model": model, "duration": duration, "answer": answer, "status": "success"}
|
| 61 |
+
)
|
| 62 |
+
print(f" OK ({duration:.2f}s)")
|
| 63 |
+
except Exception as e:
|
| 64 |
+
duration = time.time() - start
|
| 65 |
+
results.append(
|
| 66 |
+
{"model": model, "duration": duration, "error": str(e), "status": "error"}
|
| 67 |
+
)
|
| 68 |
+
print(f" FAIL ({duration:.2f}s) {e}")
|
| 69 |
+
|
| 70 |
+
print("\n" + "=" * 80)
|
| 71 |
+
for res in results:
|
| 72 |
+
print(json.dumps(res, ensure_ascii=False, indent=2)[:2000])
|
| 73 |
+
print("-" * 40)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
if __name__ == "__main__":
|
| 77 |
+
main()
|
scripts/generate_report.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
import xml.etree.ElementTree as ET
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def _parse_junit_xml(path: str) -> dict:
|
| 9 |
+
"""Summarize pytest junitxml (JUnit) file."""
|
| 10 |
+
out = {"tests": 0, "failures": 0, "errors": 0, "skipped": 0, "time": 0.0}
|
| 11 |
+
try:
|
| 12 |
+
tree = ET.parse(path)
|
| 13 |
+
root = tree.getroot()
|
| 14 |
+
nodes = [root] if root.tag == "testsuite" else list(root.iter("testsuite"))
|
| 15 |
+
for ts in nodes:
|
| 16 |
+
if ts.tag != "testsuite":
|
| 17 |
+
continue
|
| 18 |
+
out["tests"] += int(ts.attrib.get("tests", 0) or 0)
|
| 19 |
+
out["failures"] += int(ts.attrib.get("failures", 0) or 0)
|
| 20 |
+
out["errors"] += int(ts.attrib.get("errors", 0) or 0)
|
| 21 |
+
out["skipped"] += int(ts.attrib.get("skipped", 0) or 0)
|
| 22 |
+
out["time"] += float(ts.attrib.get("time", 0) or 0)
|
| 23 |
+
except Exception as e:
|
| 24 |
+
out["parse_error"] = str(e)
|
| 25 |
+
return out
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def generate_report(json_path: str, report_path: str, junit_path: str | None = None) -> None:
|
| 29 |
+
try:
|
| 30 |
+
with open(json_path, "r", encoding="utf-8") as f:
|
| 31 |
+
data = json.load(f)
|
| 32 |
+
|
| 33 |
+
junit_summary = None
|
| 34 |
+
if junit_path and os.path.isfile(junit_path):
|
| 35 |
+
junit_summary = _parse_junit_xml(junit_path)
|
| 36 |
+
|
| 37 |
+
with open(report_path, "w", encoding="utf-8") as f:
|
| 38 |
+
f.write("# Báo cáo Kiểm thử tích hợp Backend (Integration Report)\n\n")
|
| 39 |
+
f.write(f"**Thời gian chạy:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
|
| 40 |
+
suite_ok = all(r.get("success", False) for r in data) if isinstance(data, list) else False
|
| 41 |
+
f.write(f"**API suite (JSON):** {'PASS' if suite_ok else 'FAIL'}\n")
|
| 42 |
+
|
| 43 |
+
if junit_summary and "parse_error" not in junit_summary:
|
| 44 |
+
j_ok = junit_summary["failures"] == 0 and junit_summary["errors"] == 0
|
| 45 |
+
f.write(
|
| 46 |
+
f"**Pytest (JUnit):** {'PASS' if j_ok else 'FAIL'} — "
|
| 47 |
+
f"tests={junit_summary['tests']}, failures={junit_summary['failures']}, "
|
| 48 |
+
f"errors={junit_summary['errors']}, skipped={junit_summary['skipped']}, "
|
| 49 |
+
f"time_s={junit_summary['time']:.2f}\n"
|
| 50 |
+
)
|
| 51 |
+
elif junit_summary and "parse_error" in junit_summary:
|
| 52 |
+
f.write(f"**Pytest (JUnit):** (could not parse: {junit_summary['parse_error']})\n")
|
| 53 |
+
|
| 54 |
+
f.write("\n")
|
| 55 |
+
|
| 56 |
+
f.write("| ID | Câu hỏi (Query) | Trạng thái | Thời gian (s) | Kết quả / Lỗi |\n")
|
| 57 |
+
f.write("| :--- | :--- | :--- | :--- | :--- |\n")
|
| 58 |
+
for r in data:
|
| 59 |
+
status = "PASS" if r.get("success") else "FAIL"
|
| 60 |
+
elapsed = f"{float(r.get('elapsed', 0) or 0):.2f}"
|
| 61 |
+
query = r.get("query", "-")
|
| 62 |
+
|
| 63 |
+
res = r.get("result", {})
|
| 64 |
+
if not isinstance(res, dict):
|
| 65 |
+
res = {}
|
| 66 |
+
|
| 67 |
+
analysis = res.get("semantic_analysis", "-")
|
| 68 |
+
if not r.get("success"):
|
| 69 |
+
analysis = f"**Lỗi:** {r.get('error', '-')}"
|
| 70 |
+
|
| 71 |
+
short_analysis = (analysis[:100] + "...") if len(str(analysis)) > 100 else analysis
|
| 72 |
+
|
| 73 |
+
f.write(f"| {r['id']} | {query} | {status} | {elapsed} | {short_analysis} |\n")
|
| 74 |
+
|
| 75 |
+
f.write("\n---\n**Chi tiết Output (DSL & Analysis):**\n")
|
| 76 |
+
for r in data:
|
| 77 |
+
if not r.get("success"):
|
| 78 |
+
continue
|
| 79 |
+
res = r.get("result", {})
|
| 80 |
+
if not isinstance(res, dict):
|
| 81 |
+
continue
|
| 82 |
+
|
| 83 |
+
f.write(f"\n### Case {r['id']}: {r.get('query')}\n")
|
| 84 |
+
f.write(f"**Semantic Analysis:**\n{res.get('semantic_analysis', '-')}\n\n")
|
| 85 |
+
f.write(f"**Geometry DSL:**\n```\n{res.get('geometry_dsl', '-')}\n```\n")
|
| 86 |
+
|
| 87 |
+
sol = res.get("solution")
|
| 88 |
+
if sol and isinstance(sol, dict):
|
| 89 |
+
f.write("**Solution (v5.1):**\n")
|
| 90 |
+
f.write(f"- **Answer:** {sol.get('answer', 'N/A')}\n")
|
| 91 |
+
f.write("- **Steps:**\n")
|
| 92 |
+
steps = sol.get("steps", [])
|
| 93 |
+
if steps:
|
| 94 |
+
for step in steps:
|
| 95 |
+
f.write(f" - {step}\n")
|
| 96 |
+
else:
|
| 97 |
+
f.write(" - (Không có bước giải cụ thể)\n")
|
| 98 |
+
|
| 99 |
+
if sol.get("symbolic_expression"):
|
| 100 |
+
f.write(f"- **Symbolic:** `{sol.get('symbolic_expression')}`\n")
|
| 101 |
+
f.write("\n")
|
| 102 |
+
|
| 103 |
+
print(f"Report generated: {report_path}")
|
| 104 |
+
except Exception as e:
|
| 105 |
+
print(f"Error generating report: {e}")
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
if __name__ == "__main__":
|
| 109 |
+
if len(sys.argv) < 3:
|
| 110 |
+
print(
|
| 111 |
+
"Usage: python generate_report.py <json_results> <report_output> [junit_xml_optional]"
|
| 112 |
+
)
|
| 113 |
+
sys.exit(1)
|
| 114 |
+
junit = sys.argv[3] if len(sys.argv) > 3 else None
|
| 115 |
+
generate_report(sys.argv[1], sys.argv[2], junit)
|
scripts/prepare_api_test.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import uuid
|
| 4 |
+
|
| 5 |
+
from dotenv import load_dotenv
|
| 6 |
+
|
| 7 |
+
# Add parent dir to path to import app modules
|
| 8 |
+
_BACKEND_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 9 |
+
sys.path.append(_BACKEND_ROOT)
|
| 10 |
+
load_dotenv(os.path.join(_BACKEND_ROOT, ".env"))
|
| 11 |
+
|
| 12 |
+
from app.supabase_client import get_supabase
|
| 13 |
+
|
| 14 |
+
# Default UUID matches historical dev DB; override with TEST_SUPABASE_USER_ID in .env
|
| 15 |
+
_DEFAULT_TEST_USER = "8cd3adb0-7964-4575-949c-d0cadcd8b679"
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def prepare():
|
| 19 |
+
supabase = get_supabase()
|
| 20 |
+
user_id = os.environ.get("TEST_SUPABASE_USER_ID", _DEFAULT_TEST_USER).strip()
|
| 21 |
+
session_id = str(uuid.uuid4())
|
| 22 |
+
|
| 23 |
+
print(f"Using test user (TEST_SUPABASE_USER_ID or default): {user_id}")
|
| 24 |
+
|
| 25 |
+
print(f"Creating fresh test session: {session_id}")
|
| 26 |
+
# Insert session
|
| 27 |
+
supabase.table("sessions").insert({
|
| 28 |
+
"id": session_id,
|
| 29 |
+
"user_id": user_id,
|
| 30 |
+
"title": f"Fresh API Test {session_id[:8]}"
|
| 31 |
+
}).execute()
|
| 32 |
+
|
| 33 |
+
# Return IDs for the test script
|
| 34 |
+
print(f"RESULT:USER_ID={user_id}")
|
| 35 |
+
print(f"RESULT:SESSION_ID={session_id}")
|
| 36 |
+
|
| 37 |
+
if __name__ == "__main__":
|
| 38 |
+
prepare()
|
scripts/prewarm_models.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Download and load all heavy models during Docker build (YOLO, PaddleOCR, Pix2Tex, agents).
|
| 4 |
+
Fails the image build if initialization fails.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
import os
|
| 11 |
+
import sys
|
| 12 |
+
|
| 13 |
+
# Ensure imports work when run as `python scripts/prewarm_models.py` from WORKDIR
|
| 14 |
+
_APP_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 15 |
+
if _APP_ROOT not in sys.path:
|
| 16 |
+
sys.path.insert(0, _APP_ROOT)
|
| 17 |
+
|
| 18 |
+
os.chdir(_APP_ROOT)
|
| 19 |
+
|
| 20 |
+
from dotenv import load_dotenv
|
| 21 |
+
|
| 22 |
+
load_dotenv()
|
| 23 |
+
|
| 24 |
+
from app.runtime_env import apply_runtime_env_defaults
|
| 25 |
+
|
| 26 |
+
apply_runtime_env_defaults()
|
| 27 |
+
|
| 28 |
+
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s | %(message)s")
|
| 29 |
+
|
| 30 |
+
logger = logging.getLogger("prewarm")
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def main() -> None:
|
| 34 |
+
from agents.orchestrator import Orchestrator
|
| 35 |
+
|
| 36 |
+
logger.info("Constructing Orchestrator (full agent + model load)...")
|
| 37 |
+
Orchestrator()
|
| 38 |
+
logger.info("Prewarm finished successfully.")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
if __name__ == "__main__":
|
| 42 |
+
main()
|
scripts/prewarm_ocr_worker.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Docker build: load OCR models only (no Orchestrator / no LLM). Used by Dockerfile.worker.ocr."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
|
| 10 |
+
_APP_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 11 |
+
if _APP_ROOT not in sys.path:
|
| 12 |
+
sys.path.insert(0, _APP_ROOT)
|
| 13 |
+
|
| 14 |
+
os.chdir(_APP_ROOT)
|
| 15 |
+
|
| 16 |
+
from dotenv import load_dotenv
|
| 17 |
+
|
| 18 |
+
load_dotenv()
|
| 19 |
+
|
| 20 |
+
from app.runtime_env import apply_runtime_env_defaults
|
| 21 |
+
|
| 22 |
+
apply_runtime_env_defaults()
|
| 23 |
+
|
| 24 |
+
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s | %(message)s")
|
| 25 |
+
logger = logging.getLogger("prewarm_ocr_worker")
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def main() -> None:
|
| 29 |
+
from agents.ocr_agent import OCRAgent
|
| 30 |
+
|
| 31 |
+
logger.info("Loading OCRAgent(skip_llm_refinement=True)...")
|
| 32 |
+
OCRAgent(skip_llm_refinement=True)
|
| 33 |
+
logger.info("OCR worker prewarm finished successfully.")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
if __name__ == "__main__":
|
| 37 |
+
main()
|
scripts/run_real_integration.sh
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Run backend integration tests. Usage:
|
| 3 |
+
# ./scripts/run_real_integration.sh # profile ci (default)
|
| 4 |
+
# ./scripts/run_real_integration.sh ci
|
| 5 |
+
# ./scripts/run_real_integration.sh real # heavy: workers, manim, OCR, full API suite
|
| 6 |
+
set -euo pipefail

# Usage: run_real_integration.sh [ci|real]
PROFILE="${1:-ci}"
ROOT="$(cd "$(dirname "$0")/.." && pwd)"
cd "$ROOT"
# Prefer the project virtualenv interpreter; fall back to system python3.
PY="${ROOT}/venv/bin/python"
if [[ ! -x "$PY" ]]; then
    PY="python3"
fi

export PYTHONPATH="$ROOT"
# Output locations, all overridable from the environment.
LOG_FILE="${LOG_FILE:-integration_run.log}"
JUNIT="${JUNIT:-pytest_integration.xml}"
REPORT_MD="${REPORT_MD:-integration_report.md}"
JSON_RESULTS="${JSON_RESULTS:-temp_suite_results.json}"

# Timestamped logger that mirrors everything into $LOG_FILE.
log() { echo "[$(date '+%H:%M:%S')] $*" | tee -a "$LOG_FILE"; }

log "Profile=$PROFILE working_dir=$ROOT"

if [[ "$PROFILE" == "ci" ]]; then
    # CI profile: everything mocked/eager so no external services are required.
    export ALLOW_TEST_BYPASS="${ALLOW_TEST_BYPASS:-true}"
    export LOG_LEVEL="${LOG_LEVEL:-info}"
    export CELERY_TASK_ALWAYS_EAGER="${CELERY_TASK_ALWAYS_EAGER:-true}"
    export CELERY_RESULT_BACKEND="${CELERY_RESULT_BACKEND:-rpc://}"
    export MOCK_VIDEO="${MOCK_VIDEO:-true}"

    # Capture pytest's exit code through the tee pipe (pipefail would otherwise
    # abort the script before the summary/report phases run).
    set +e
    log "Phase A: default pytest (unit / mocked; excludes real_* markers per pytest.ini)"
    "$PY" -m pytest tests/ -q --tb=short -p no:cacheprovider 2>&1 | tee -a "$LOG_FILE"
    P1=${PIPESTATUS[0]}
    set -e

    log "Starting uvicorn for API phase..."
    "$PY" -m uvicorn app.main:app --port 8000 >>uvicorn_integration.log 2>&1 &
    SERVER_PID=$!
    trap 'kill "$SERVER_PID" 2>/dev/null || true' EXIT

    # Poll for up to ~50s (25 x 2s) until the API answers.
    for i in $(seq 1 25); do
        if curl -sf "http://localhost:8000/" >/dev/null; then
            log "API ready"
            break
        fi
        sleep 2
        if [[ "$i" -eq 25 ]]; then
            log "ERROR: API did not start"
            exit 1
        fi
    done

    # prepare_api_test.py emits RESULT:USER_ID=... / RESULT:SESSION_ID=... lines.
    PREP="$("$PY" scripts/prepare_api_test.py)"
    echo "$PREP" | tee -a "$LOG_FILE"
    export TEST_USER_ID="$(echo "$PREP" | grep "RESULT:USER_ID=" | cut -d'=' -f2)"
    export TEST_SESSION_ID="$(echo "$PREP" | grep "RESULT:SESSION_ID=" | cut -d'=' -f2)"
    if [[ -z "${TEST_USER_ID:-}" || -z "${TEST_SESSION_ID:-}" ]]; then
        log "ERROR: prepare_api_test did not emit USER_ID / SESSION_ID"
        exit 1
    fi

    set +e
    log "Phase B: API smoke + full suite (real_api)"
    "$PY" -m pytest tests/test_api_real_e2e.py tests/test_api_full_suite.py \
        -m "real_api" -s --tb=short --junitxml="$JUNIT" -p no:cacheprovider 2>&1 | tee -a "$LOG_FILE"
    P2=${PIPESTATUS[0]}
    set -e

    if [[ -f "$JSON_RESULTS" ]]; then
        log "Generating Markdown report"
        "$PY" scripts/generate_report.py "$JSON_RESULTS" "$REPORT_MD" "$JUNIT"
    else
        log "WARN: $JSON_RESULTS missing (suite may have failed before write)"
    fi

    if [[ "$P1" -ne 0 || "$P2" -ne 0 ]]; then
        log "FAIL: phase A exit=$P1 phase B exit=$P2"
        exit 1
    fi

    log "Done CI profile. See $REPORT_MD and $LOG_FILE"
    exit 0
fi

if [[ "$PROFILE" == "real" ]]; then
    # Real profile: eager Celery off and real video by default, but honor caller
    # overrides. (Previously the variable was unconditionally `unset` first, which
    # made the ${VAR:-false} fallback dead code and silently discarded an
    # explicitly exported CELERY_TASK_ALWAYS_EAGER.)
    export CELERY_TASK_ALWAYS_EAGER="${CELERY_TASK_ALWAYS_EAGER:-false}"
    export MOCK_VIDEO="${MOCK_VIDEO:-false}"
    export RUN_REAL_WORKER_OCR="${RUN_REAL_WORKER_OCR:-0}"
    export RUN_REAL_WORKER_MANIM="${RUN_REAL_WORKER_MANIM:-0}"

    log "Phase A: default pytest (fast)"
    "$PY" -m pytest tests/ -q --tb=short -p no:cacheprovider 2>&1 | tee -a "$LOG_FILE"

    log "Phase B: real agents + orchestrator smoke (requires OpenRouter keys)"
    "$PY" -m pytest tests/integration/test_agents_real.py tests/integration/test_orchestrator_smoke.py \
        -m "real_agents" -q --tb=short --junitxml="$JUNIT" -p no:cacheprovider 2>&1 | tee -a "$LOG_FILE" || true

    # Accept 1 / true / yes for the OCR opt-in flag.
    if [[ "${RUN_REAL_WORKER_OCR:-0}" == "1" ]] || [[ "${RUN_REAL_WORKER_OCR:-0}" =~ ^(true|yes)$ ]]; then
        log "Phase C: OCR worker task (RUN_REAL_WORKER_OCR enabled)"
        "$PY" -m pytest tests/integration/test_worker_ocr_real.py \
            -m "real_worker_ocr" -q --tb=short -p no:cacheprovider 2>&1 | tee -a "$LOG_FILE" || true
    else
        log "Skipping OCR worker (set RUN_REAL_WORKER_OCR=1 to enable)"
    fi

    if [[ "${RUN_REAL_WORKER_MANIM:-0}" == "1" ]]; then
        log "Phase D: Manim + storage (RUN_REAL_WORKER_MANIM=1, MOCK_VIDEO=false)"
        "$PY" -m pytest tests/integration/test_worker_manim_real.py -m "real_worker_manim" -s --tb=short \
            -p no:cacheprovider 2>&1 | tee -a "$LOG_FILE" || true
    else
        log "Skipping Manim integration (set RUN_REAL_WORKER_MANIM=1 to enable)"
    fi

    log "Phase E: API real (expects TEST_BASE_URL or localhost:8000 with server already up)"
    if curl -sf "http://localhost:8000/" >/dev/null 2>&1; then
        PREP="$("$PY" scripts/prepare_api_test.py)"
        export TEST_USER_ID="$(echo "$PREP" | grep "RESULT:USER_ID=" | cut -d'=' -f2)"
        export TEST_SESSION_ID="$(echo "$PREP" | grep "RESULT:SESSION_ID=" | cut -d'=' -f2)"
        "$PY" -m pytest tests/test_api_real_e2e.py tests/test_api_full_suite.py tests/test_api_metadata_real.py \
            -m "real_api" -q --tb=short -p no:cacheprovider 2>&1 | tee -a "$LOG_FILE" || true
    else
        log "WARN: No server on :8000 — skip API real phase (start backend first)"
    fi

    log "Done REAL profile. Review $LOG_FILE"
    exit 0
fi

echo "Unknown profile: $PROFILE (use ci or real)"
exit 1
|
scripts/test_LLM.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
import time
|
| 4 |
+
import asyncio
|
| 5 |
+
import logging
|
| 6 |
+
from typing import List, Dict, Any
|
| 7 |
+
from dotenv import load_dotenv
|
| 8 |
+
|
| 9 |
+
# Add the parent directory to sys.path to allow importing from 'app'
|
| 10 |
+
# This assumes the script is inside 'backend/scripts' and we want to import from 'backend/app'
|
| 11 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 12 |
+
|
| 13 |
+
from app.url_utils import openai_compatible_api_key
|
| 14 |
+
from openai import AsyncOpenAI
|
| 15 |
+
|
| 16 |
+
# Set up logger: bare-message format so the benchmark output reads like a console report.
logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger(__name__)

# OpenRouter free-tier model slugs to benchmark, queried sequentially in this order.
MODELS_TO_TEST = [
    "nvidia/nemotron-3-super-120b-a12b:free",
    "meta-llama/llama-3.3-70b-instruct:free",
    "openai/gpt-oss-120b:free",
    "z-ai/glm-4.5-air:free",
    "minimax/minimax-m2.5:free",
    "google/gemma-4-26b-a4b-it:free",
    "google/gemma-4-31b-it:free",
    "arcee-ai/trinity-large-preview:free",
    "openai/gpt-oss-20b:free",
    "nvidia/nemotron-3-nano-30b-a3b:free",
    "nvidia/nemotron-nano-9b-v2:free",
]

# Default prompt (Vietnamese): "Solve the system x + y = 10, 2x - y = 2 and
# return the final values of x and y" — a small task used for every model.
DEFAULT_QUERY = "Giải hệ phương trình sau: x + y = 10, 2x - y = 2. Trả về kết quả cuối cùng x và y."
|
| 36 |
+
|
| 37 |
+
async def test_model(client: AsyncOpenAI, model: str, query: str) -> Dict[str, Any]:
|
| 38 |
+
"""Test a single model and return performance metrics."""
|
| 39 |
+
start_time = time.time()
|
| 40 |
+
result = {
|
| 41 |
+
"model": model,
|
| 42 |
+
"status": "success",
|
| 43 |
+
"duration": 0,
|
| 44 |
+
"content": "",
|
| 45 |
+
"error": None
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
try:
|
| 49 |
+
response = await client.chat.completions.create(
|
| 50 |
+
model=model,
|
| 51 |
+
messages=[{"role": "user", "content": query}],
|
| 52 |
+
timeout=60.0
|
| 53 |
+
)
|
| 54 |
+
result["duration"] = time.time() - start_time
|
| 55 |
+
result["content"] = response.choices[0].message.content.strip()
|
| 56 |
+
except Exception as e:
|
| 57 |
+
result["status"] = "failed"
|
| 58 |
+
result["duration"] = time.time() - start_time
|
| 59 |
+
result["error"] = str(e)
|
| 60 |
+
|
| 61 |
+
return result
|
| 62 |
+
|
| 63 |
+
async def main():
    """Load credentials, benchmark every model in MODELS_TO_TEST, print a report.

    Models are queried one at a time with a 30-second pause between calls.
    """
    # Pick up OPENROUTER_* settings from a .env file (backend/.env when run from backend/).
    load_dotenv()

    # Accept either of the key names the project uses.
    api_key = os.getenv("OPENROUTER_API_KEY_1") or os.getenv("OPENROUTER_API_KEY")
    if not api_key:
        logger.error("❌ Error: NO OPENROUTER_API_KEY found in environment variables.")
        logger.info("Check your .env file in the backend directory.")
        return

    # Sanitize the key the same way the main app does (app.url_utils).
    client = AsyncOpenAI(
        api_key=openai_compatible_api_key(api_key),
        base_url="https://openrouter.ai/api/v1",
        default_headers={
            "HTTP-Referer": "https://mathsolver.ai",
            "X-Title": "MathSolver LLM Benchmarker",
        },
    )

    query = DEFAULT_QUERY
    banner = "=" * 80
    logger.info(banner)
    logger.info("🚀 LLM PERFORMANCE BENCHMARK")
    logger.info(f"Query: {query}")
    logger.info(banner)
    logger.info(f"Testing {len(MODELS_TO_TEST)} models sequentially with 30s delay...\n")

    outcomes: List[Dict[str, Any]] = []
    total = len(MODELS_TO_TEST)
    for idx, model_name in enumerate(MODELS_TO_TEST):
        if idx:
            # Pause between models (skipped before the first one).
            logger.info("⏳ Waiting 30s before testing next model...")
            await asyncio.sleep(30)

        logger.info(f"[{idx + 1}/{total}] Testing: {model_name}...")
        outcome = await test_model(client, model_name, query)
        outcomes.append(outcome)

        # Immediate per-model feedback.
        mark = "✅ SUCCESS" if outcome["status"] == "success" else "❌ FAILED"
        logger.info(f" Status: {mark} | Time: {outcome['duration']:.2f}s")

    # Summary table.
    logger.info("\n" + banner)
    logger.info("📊 FINAL BENCHMARK SUMMARY")
    logger.info(banner)
    header = f"{'MODEL':<45} | {'STATUS':<10} | {'TIME (s)':<10}"
    logger.info(header)
    logger.info("-" * len(header))

    for outcome in outcomes:
        mark = "✅ SUCCESS" if outcome["status"] == "success" else "❌ FAILED"
        elapsed = f"{outcome['duration']:.2f}s"
        logger.info(f"{outcome['model']:<45} | {mark:<10} | {elapsed:<10}")

    logger.info("-" * len(header))

    # Full model outputs (or the recorded error for failed ones).
    logger.info("\n📝 FULL RESPONSES:")
    for outcome in outcomes:
        logger.info(f"\n{'=' * 20} [{outcome['model']}] {'=' * 20}")
        if outcome["status"] == "success":
            logger.info(outcome["content"])
        else:
            logger.info(f"❌ Error: {outcome['error']}")

    logger.info("\n" + banner)
    logger.info("Benchmark finished.")
|
| 135 |
+
|
| 136 |
+
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info("\nBenchmark cancelled by user.")
    except Exception as e:
        # logger.exception keeps the traceback; logger.error with an f-string lost it.
        logger.exception(f"Unexpected error: {e}")
|
scripts/test_engine_direct.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import os
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
import sys
|
| 6 |
+
|
| 7 |
+
# Make the backend root (parent of scripts/) importable so 'app' and 'agents'
# resolve regardless of checkout location. (Previously this appended a
# machine-specific absolute path, /Volumes/..., which broke the script anywhere else.)
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Configure logging to stdout; DEBUG so agent internals are visible during the smoke run.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
from agents.orchestrator import Orchestrator
|
| 15 |
+
|
| 16 |
+
async def main():
    """Smoke-test the Orchestrator directly (no API layer) on a simple geometry prompt."""
    orch = Orchestrator()
    text = "Vẽ tam giác đều cạnh 5."  # "Draw an equilateral triangle with side 5."
    job_id = "test_direct_equilateral"

    print(f"\n--- Testing Orchestrator Direct: {text} ---")

    async def status_cb(status):
        # Echo every status update the orchestrator reports during the run.
        print(f" [STATUS] {status}")

    try:
        result = await orch.run(text, job_id=job_id, status_callback=status_cb, request_video=False)
        print("\n--- Final Result ---")
        print(json.dumps(result, indent=2))
    except Exception:  # was `as e` with the binding never used
        # Best-effort smoke script: print the full traceback instead of re-raising.
        print("\n--- ERROR ---")
        import traceback
        traceback.print_exc()
|
| 34 |
+
|
| 35 |
+
if __name__ == "__main__":
    # Script entry point: run the async smoke test to completion.
    asyncio.run(main())
|