Spaces:
Sleeping
Sleeping
Cuong2004 commited on
Commit ·
395651c
0
Parent(s):
Deploy API from GitHub Actions
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- Dockerfile +48 -0
- Dockerfile.worker +47 -0
- Dockerfile.worker.ocr +42 -0
- README.md +33 -0
- README_HF_WORKER.md +24 -0
- README_HF_WORKER_OCR.md +27 -0
- agents/geometry_agent.py +120 -0
- agents/knowledge_agent.py +135 -0
- agents/ocr_agent.py +112 -0
- agents/orchestrator.py +223 -0
- agents/parser_agent.py +106 -0
- agents/renderer_agent.py +5 -0
- agents/solver_agent.py +107 -0
- agents/torch_ultralytics_compat.py +5 -0
- app/chat_image_upload.py +206 -0
- app/dependencies.py +69 -0
- app/errors.py +59 -0
- app/job_poll.py +47 -0
- app/llm_client.py +100 -0
- app/logging_setup.py +112 -0
- app/logutil.py +67 -0
- app/main.py +142 -0
- app/models/schemas.py +80 -0
- app/ocr_celery.py +54 -0
- app/ocr_local_file.py +43 -0
- app/ocr_text_merge.py +14 -0
- app/routers/__init__.py +1 -0
- app/routers/auth.py +23 -0
- app/routers/sessions.py +184 -0
- app/routers/solve.py +410 -0
- app/runtime_env.py +12 -0
- app/session_cache.py +48 -0
- app/supabase_client.py +37 -0
- app/url_utils.py +23 -0
- app/websocket_manager.py +40 -0
- clean_ports.sh +22 -0
- dump.rdb +0 -0
- geometry_render/__init__.py +5 -0
- geometry_render/renderer.py +265 -0
- migrations/add_image_bucket_storage.sql +35 -0
- migrations/fix_rls_assets.sql +96 -0
- migrations/v4_migration.sql +131 -0
- pytest.ini +18 -0
- requirements.txt +38 -0
- requirements.worker-ocr.txt +23 -0
- requirements.worker-render.txt +21 -0
- run_api_test.sh +69 -0
- run_full_api_test.sh +56 -0
- scripts/benchmark_openrouter.py +77 -0
- scripts/generate_report.py +115 -0
Dockerfile
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Visual Math Solver — API container (Python 3.11 + Manim + OCR stack)
|
| 2 |
+
FROM python:3.11-slim-bookworm
|
| 3 |
+
|
| 4 |
+
ENV PYTHONUNBUFFERED=1 \
|
| 5 |
+
PYTHONDONTWRITEBYTECODE=1 \
|
| 6 |
+
PIP_NO_CACHE_DIR=1 \
|
| 7 |
+
PIP_ROOT_USER_ACTION=ignore \
|
| 8 |
+
NO_ALBUMENTATIONS_UPDATE=1 \
|
| 9 |
+
OMP_NUM_THREADS=1 \
|
| 10 |
+
MKL_NUM_THREADS=1 \
|
| 11 |
+
OPENBLAS_NUM_THREADS=1
|
| 12 |
+
|
| 13 |
+
WORKDIR /app
|
| 14 |
+
ENV PYTHONPATH=/app
|
| 15 |
+
|
| 16 |
+
# Runtime + *-dev: Manim/pycairo need pkg-config + cairo headers; libpango1.0-dev covers PangoCairo on Bookworm (no libpangocairo-*-dev package).
|
| 17 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 18 |
+
ffmpeg \
|
| 19 |
+
pkg-config \
|
| 20 |
+
cmake \
|
| 21 |
+
libcairo2 \
|
| 22 |
+
libcairo2-dev \
|
| 23 |
+
libpango-1.0-0 \
|
| 24 |
+
libpango1.0-dev \
|
| 25 |
+
libpangocairo-1.0-0 \
|
| 26 |
+
libgdk-pixbuf-2.0-0 \
|
| 27 |
+
libffi-dev \
|
| 28 |
+
python3-dev \
|
| 29 |
+
texlive-latex-base \
|
| 30 |
+
texlive-fonts-recommended \
|
| 31 |
+
texlive-latex-extra \
|
| 32 |
+
build-essential \
|
| 33 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 34 |
+
|
| 35 |
+
COPY requirements.txt .
|
| 36 |
+
RUN pip install --upgrade pip setuptools wheel \
|
| 37 |
+
&& pip install -r requirements.txt
|
| 38 |
+
|
| 39 |
+
COPY . .
|
| 40 |
+
|
| 41 |
+
# Bake model weights and agent init into the image (YOLO, PaddleOCR, Pix2Tex, etc.)
|
| 42 |
+
RUN python scripts/prewarm_models.py
|
| 43 |
+
|
| 44 |
+
# Hugging Face Spaces defaults to 7860; docker-compose can set PORT=8000
|
| 45 |
+
ENV PORT=7860
|
| 46 |
+
EXPOSE 7860
|
| 47 |
+
|
| 48 |
+
CMD ["sh", "-c", "exec uvicorn app.main:app --host 0.0.0.0 --port ${PORT}"]
|
Dockerfile.worker
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Same runtime as API; runs health endpoint + Celery worker (see worker_health.py)
|
| 2 |
+
FROM python:3.11-slim-bookworm
|
| 3 |
+
|
| 4 |
+
ENV PYTHONUNBUFFERED=1 \
|
| 5 |
+
PYTHONDONTWRITEBYTECODE=1 \
|
| 6 |
+
PIP_NO_CACHE_DIR=1 \
|
| 7 |
+
PIP_ROOT_USER_ACTION=ignore \
|
| 8 |
+
NO_ALBUMENTATIONS_UPDATE=1 \
|
| 9 |
+
OMP_NUM_THREADS=1 \
|
| 10 |
+
MKL_NUM_THREADS=1 \
|
| 11 |
+
OPENBLAS_NUM_THREADS=1
|
| 12 |
+
|
| 13 |
+
WORKDIR /app
|
| 14 |
+
ENV PYTHONPATH=/app
|
| 15 |
+
|
| 16 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 17 |
+
ffmpeg \
|
| 18 |
+
pkg-config \
|
| 19 |
+
cmake \
|
| 20 |
+
libcairo2 \
|
| 21 |
+
libcairo2-dev \
|
| 22 |
+
libpango-1.0-0 \
|
| 23 |
+
libpango1.0-dev \
|
| 24 |
+
libpangocairo-1.0-0 \
|
| 25 |
+
libgdk-pixbuf-2.0-0 \
|
| 26 |
+
libffi-dev \
|
| 27 |
+
python3-dev \
|
| 28 |
+
texlive-latex-base \
|
| 29 |
+
texlive-fonts-recommended \
|
| 30 |
+
texlive-latex-extra \
|
| 31 |
+
build-essential \
|
| 32 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 33 |
+
|
| 34 |
+
COPY requirements.worker-render.txt .
|
| 35 |
+
RUN pip install --upgrade pip setuptools wheel \
|
| 36 |
+
&& pip install -r requirements.worker-render.txt
|
| 37 |
+
|
| 38 |
+
COPY . .
|
| 39 |
+
|
| 40 |
+
RUN python scripts/prewarm_render_worker.py
|
| 41 |
+
|
| 42 |
+
ENV PORT=7860 \
|
| 43 |
+
CELERY_WORKER_QUEUES=render
|
| 44 |
+
EXPOSE 7860
|
| 45 |
+
|
| 46 |
+
ENTRYPOINT []
|
| 47 |
+
CMD ["sh", "-c", "exec python3 -u worker_health.py"]
|
Dockerfile.worker.ocr
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Celery worker: OCR queue only (no Manim / LaTeX / Cairo stack).
|
| 2 |
+
FROM python:3.11-slim-bookworm
|
| 3 |
+
|
| 4 |
+
ENV PYTHONUNBUFFERED=1 \
|
| 5 |
+
PYTHONDONTWRITEBYTECODE=1 \
|
| 6 |
+
PIP_NO_CACHE_DIR=1 \
|
| 7 |
+
PIP_ROOT_USER_ACTION=ignore \
|
| 8 |
+
NO_ALBUMENTATIONS_UPDATE=1 \
|
| 9 |
+
OMP_NUM_THREADS=1 \
|
| 10 |
+
MKL_NUM_THREADS=1 \
|
| 11 |
+
OPENBLAS_NUM_THREADS=1
|
| 12 |
+
|
| 13 |
+
WORKDIR /app
|
| 14 |
+
ENV PYTHONPATH=/app
|
| 15 |
+
|
| 16 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 17 |
+
build-essential \
|
| 18 |
+
cmake \
|
| 19 |
+
pkg-config \
|
| 20 |
+
python3-dev \
|
| 21 |
+
libglib2.0-0 \
|
| 22 |
+
libgomp1 \
|
| 23 |
+
libgl1 \
|
| 24 |
+
libsm6 \
|
| 25 |
+
libxext6 \
|
| 26 |
+
libxrender1 \
|
| 27 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 28 |
+
|
| 29 |
+
COPY requirements.worker-ocr.txt .
|
| 30 |
+
RUN pip install --upgrade pip setuptools wheel \
|
| 31 |
+
&& pip install -r requirements.worker-ocr.txt
|
| 32 |
+
|
| 33 |
+
COPY . .
|
| 34 |
+
|
| 35 |
+
RUN python scripts/prewarm_ocr_worker.py
|
| 36 |
+
|
| 37 |
+
ENV PORT=7860 \
|
| 38 |
+
CELERY_WORKER_QUEUES=ocr
|
| 39 |
+
EXPOSE 7860
|
| 40 |
+
|
| 41 |
+
ENTRYPOINT []
|
| 42 |
+
CMD ["sh", "-c", "exec python3 -u worker_health.py"]
|
README.md
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Math Solver Backend
|
| 3 |
+
emoji: 📐
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: indigo
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
pinned: false
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
# Visual Math Solver - Backend (Hugging Face Space)
|
| 12 |
+
|
| 13 |
+
Hệ thống AI giải toán hình học sử dụng Multi-Agent và Manim. Hiện đã nâng cấp lên phiên bản **v5.1**.
|
| 14 |
+
|
| 15 |
+
## Tính năng mới (v5.1)
|
| 16 |
+
- **Symbolic Solver**: Tích hợp SymPy để tự động hóa việc tính toán các giá trị hình học (diện tích, chu vi, độ dài) với độ chính xác tuyệt đối và trình bày các bước giải chi tiết.
|
| 17 |
+
- **3D Video Support**: Nâng cấp công cụ Manim để hỗ trợ hiển thị và xoay camera cho các bài toán hình học không gian (Hình chóp, Hình lăng trụ, v.v.).
|
| 18 |
+
|
| 19 |
+
## Kiến trúc Pipeline (Agentic Flow)
|
| 20 |
+
1. **OCR Agent**: Nhận diện văn bản từ hình ảnh câu hỏi.
|
| 21 |
+
2. **Parser Agent**: Chuyển đổi ngôn ngữ tự nhiên thành Geometry DSL.
|
| 22 |
+
3. **Knowledge Agent**: Bổ sung kiến thức chuyên sâu về hình học.
|
| 23 |
+
4. **Geometry Engine**: Giải hệ phương trình tọa độ để dựng hình.
|
| 24 |
+
5. **Solver Agent (New)**: Thực hiện các phép tính toán học hình thức (Symbolic Math).
|
| 25 |
+
6. **Renderer Agent**: Sinh mã Manim và render video (hỗ trợ cả 2D và 3D).
|
| 26 |
+
|
| 27 |
+
## Triển khai
|
| 28 |
+
Space này chạy Docker container chứa FastAPI và môi trường Manim. Để chạy cục bộ, tham khảo `setup.sh` và `.env.example`.
|
| 29 |
+
|
| 30 |
+
## Kiểm thử (pytest)
|
| 31 |
+
- **Nhanh (mặc định):** từ thư mục `backend`, cài `pip install -r requirements.txt`, rồi `PYTHONPATH=. python -m pytest tests/`. Các marker `real_api`, `real_agents`, `slow`, v.v. bị loại theo `pytest.ini` để không cần server hay API key.
|
| 32 |
+
- **CI API (mock video + eager Celery):** `chmod +x scripts/run_real_integration.sh && ./scripts/run_real_integration.sh ci` — khởi động API, chạy smoke + full suite, ghi `integration_report.md` và `temp_suite_results.json`.
|
| 33 |
+
- **Tích hợp thật (worker / Manim / OpenRouter):** `./scripts/run_real_integration.sh real` với backend + worker đang chạy (Redis, `.env` đầy đủ). Bật từng phần bằng `RUN_REAL_WORKER_OCR=1`, `RUN_REAL_WORKER_MANIM=1` (cần `MOCK_VIDEO=false`). Đặt `TEST_SUPABASE_USER_ID` trong `.env` cho user Supabase hợp lệ (xem `.env.example`).
|
README_HF_WORKER.md
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Math Solver Render Worker
|
| 3 |
+
emoji: 👷
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: indigo
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
# Math Solver — Render worker (Manim)
|
| 11 |
+
|
| 12 |
+
This Space runs **Celery** via `worker_health.py` and consumes **only** queue **`render`** (`render_geometry_video`). Image sets `CELERY_WORKER_QUEUES=render` by default (`Dockerfile.worker`).
|
| 13 |
+
|
| 14 |
+
**Solve** (orchestrator, agents, OCR-in-request when `OCR_USE_CELERY` is off) runs on the **API** Space, not on this worker.
|
| 15 |
+
|
| 16 |
+
## OCR offload (separate Space)
|
| 17 |
+
|
| 18 |
+
Queue **`ocr`** is handled by a **dedicated OCR worker** (`Dockerfile.worker.ocr`, `README_HF_WORKER_OCR.md`, workflow `deploy-worker-ocr.yml`). On the API, set `OCR_USE_CELERY=true` and deploy an OCR Space that listens on `ocr`.
|
| 19 |
+
|
| 20 |
+
## Secrets
|
| 21 |
+
|
| 22 |
+
Same broker as the API: `REDIS_URL` / `CELERY_BROKER_URL`, Supabase, OpenRouter (renderer may use LLM paths), etc.
|
| 23 |
+
|
| 24 |
+
**GitHub Actions:** repository secrets `HF_TOKEN` and `HF_WORKER_REPO` (`owner/space-name`) for this workflow (`deploy-worker.yml`).
|
README_HF_WORKER_OCR.md
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Math Solver OCR Worker
|
| 3 |
+
emoji: 👁️
|
| 4 |
+
colorFrom: gray
|
| 5 |
+
colorTo: blue
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
# Math Solver — OCR-only worker
|
| 11 |
+
|
| 12 |
+
This Space runs **Celery** (`worker_health.py`) consuming **only** the `ocr` queue.
|
| 13 |
+
|
| 14 |
+
Set environment:
|
| 15 |
+
|
| 16 |
+
- `CELERY_WORKER_QUEUES=ocr` (default in `Dockerfile.worker.ocr`)
|
| 17 |
+
- Same `REDIS_URL` / `CELERY_BROKER_URL` / `CELERY_RESULT_BACKEND` as the API
|
| 18 |
+
|
| 19 |
+
This Space runs **raw OCR only** (YOLO, PaddleOCR, Pix2Tex). **OpenRouter / LLM tinh chỉnh** không chạy ở đây; API Space gọi `refine_with_llm` sau khi nhận kết quả từ queue `ocr`.
|
| 20 |
+
|
| 21 |
+
On the **API** Space, set `OCR_USE_CELERY=true` so `run_ocr_from_url` tasks are sent to this worker instead of running Paddle/Pix2Tex on the API process.
|
| 22 |
+
|
| 23 |
+
Optional: `OCR_CELERY_TIMEOUT_SEC` (default `180`).
|
| 24 |
+
|
| 25 |
+
**Manim / video** uses a different Celery queue (`render`) and Space — see `README_HF_WORKER.md` and workflow `deploy-worker.yml`.
|
| 26 |
+
|
| 27 |
+
GitHub Actions: repository secrets `HF_TOKEN` and `HF_OCR_WORKER_REPO` (`owner/space-name`) enable workflow `deploy-worker-ocr.yml`.
|
agents/geometry_agent.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
from openai import AsyncOpenAI
|
| 5 |
+
from typing import Dict, Any
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
|
| 8 |
+
load_dotenv()
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
from app.url_utils import openai_compatible_api_key, sanitize_env
|
| 12 |
+
from app.llm_client import get_llm_client
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class GeometryAgent:
|
| 16 |
+
def __init__(self):
|
| 17 |
+
self.llm = get_llm_client()
|
| 18 |
+
|
| 19 |
+
async def generate_dsl(self, semantic_data: Dict[str, Any], previous_dsl: str = None) -> str:
|
| 20 |
+
logger.info("==[GeometryAgent] Generating DSL from semantic data==")
|
| 21 |
+
if previous_dsl:
|
| 22 |
+
logger.info(f"[GeometryAgent] Using previous DSL context (len={len(previous_dsl)})")
|
| 23 |
+
|
| 24 |
+
system_prompt = """
|
| 25 |
+
You are a Geometry DSL Generator. Convert semantic geometry data into a precise Geometry DSL program.
|
| 26 |
+
|
| 27 |
+
=== MULTI-TURN CONTEXT ===
|
| 28 |
+
If a PREVIOUS DSL is provided, your job is to UPDATE or EXTEND it.
|
| 29 |
+
1. DO NOT remove existing points unless the user explicitly asks to "redefine" or "move" them.
|
| 30 |
+
2. Ensure new segments/points connect correctly to existing ones.
|
| 31 |
+
3. Your output should be the ENTIRE updated DSL, not just the changes.
|
| 32 |
+
|
| 33 |
+
=== DSL COMMANDS ===
|
| 34 |
+
POINT(A) — declare a point
|
| 35 |
+
POINT(A, x, y, z) — declare a point with explicit coordinates
|
| 36 |
+
LENGTH(AB, 5) — distance between A and B is 5 (2D/3D)
|
| 37 |
+
ANGLE(A, 90) — interior angle at vertex A is 90° (2D/3D)
|
| 38 |
+
PARALLEL(AB, CD) — segment AB is parallel to CD (2D/3D)
|
| 39 |
+
PERPENDICULAR(AB, CD) — segment AB is perpendicular to CD (2D/3D)
|
| 40 |
+
MIDPOINT(M, AB) — M is the midpoint of segment AB
|
| 41 |
+
SECTION(E, A, C, k) — E satisfies vector AE = k * vector AC (k is decimal)
|
| 42 |
+
LINE(A, B) — infinite line passing through A and B
|
| 43 |
+
RAY(A, B) — ray starting at A and passing through B
|
| 44 |
+
CIRCLE(O, 5) — circle with center O and radius 5 (2D)
|
| 45 |
+
SPHERE(O, 5) — sphere with center O and radius 5 (3D)
|
| 46 |
+
SEGMENT(M, N) — auxiliary segment MN to be drawn
|
| 47 |
+
POLYGON_ORDER(A, B, C, D) — the order in which vertices form the polygon boundary
|
| 48 |
+
TRIANGLE(ABC) — equilateral/arbitrary triangle
|
| 49 |
+
PYRAMID(S_ABCD) — pyramid with apex S and base ABCD
|
| 50 |
+
PRISM(ABC_DEF) — triangular prism
|
| 51 |
+
|
| 52 |
+
=== RULES ===
|
| 53 |
+
1. 3D Coordinates: Use POINT(A, x, y, z) if specific coordinates are given in the problem.
|
| 54 |
+
2. Space Geometry: For pyramids/prisms, use the specialized commands.
|
| 55 |
+
3. Primary Vertices: Always declare the main vertices of the shape (e.g., A, B, C, D) using POINT(X).
|
| 56 |
+
4. POLYGON_ORDER: Always emit POLYGON_ORDER(...) for the main shape using ONLY these primary vertices.
|
| 57 |
+
5. All Points: EVERY point mentioned (A, B, C, H, M, etc.) MUST be declared with POINT(Name) first.
|
| 58 |
+
6. Altitudes/Perpendiculars: For an altitude AH to BC, use POINT(H) + PERPENDICULAR(AH, BC).
|
| 59 |
+
7. Format: Output ONLY DSL lines — NO explanation, NO markdown, NO code blocks.
|
| 60 |
+
|
| 61 |
+
=== SHAPE EXAMPLES ===
|
| 62 |
+
|
| 63 |
+
--- Case: Square Pyramid S.ABCD with side 10, height 15 ---
|
| 64 |
+
PYRAMID(S_ABCD)
|
| 65 |
+
POINT(A, 0, 0, 0)
|
| 66 |
+
POINT(B, 10, 0, 0)
|
| 67 |
+
POINT(C, 10, 10, 0)
|
| 68 |
+
POINT(D, 0, 10, 0)
|
| 69 |
+
POINT(S)
|
| 70 |
+
POINT(O)
|
| 71 |
+
SECTION(O, A, C, 0.5)
|
| 72 |
+
LENGTH(SO, 15)
|
| 73 |
+
PERPENDICULAR(SO, AC)
|
| 74 |
+
PERPENDICULAR(SO, AB)
|
| 75 |
+
POLYGON_ORDER(A, B, C, D)
|
| 76 |
+
|
| 77 |
+
--- Case: Right Triangle ABC at A, AB=3, AC=4, altitude AH ---
|
| 78 |
+
POLYGON_ORDER(A, B, C)
|
| 79 |
+
POINT(A)
|
| 80 |
+
POINT(B)
|
| 81 |
+
POINT(C)
|
| 82 |
+
POINT(H)
|
| 83 |
+
LENGTH(AB, 3)
|
| 84 |
+
LENGTH(AC, 4)
|
| 85 |
+
ANGLE(A, 90)
|
| 86 |
+
PERPENDICULAR(AH, BC)
|
| 87 |
+
SEGMENT(A, H)
|
| 88 |
+
|
| 89 |
+
--- Case: Rectangle ABCD with AB=5, AD=10 ---
|
| 90 |
+
POLYGON_ORDER(A, B, C, D)
|
| 91 |
+
POINT(A)
|
| 92 |
+
POINT(B)
|
| 93 |
+
POINT(C)
|
| 94 |
+
POINT(D)
|
| 95 |
+
LENGTH(AB, 5)
|
| 96 |
+
LENGTH(AD, 10)
|
| 97 |
+
PERPENDICULAR(AB, AD)
|
| 98 |
+
PARALLEL(AB, CD)
|
| 99 |
+
PARALLEL(AD, BC)
|
| 100 |
+
|
| 101 |
+
[Circle with center O radius 7]
|
| 102 |
+
POINT(O)
|
| 103 |
+
CIRCLE(O, 7)
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
user_content = f"Semantic Data: {json.dumps(semantic_data, ensure_ascii=False)}"
|
| 107 |
+
if previous_dsl:
|
| 108 |
+
user_content = f"PREVIOUS DSL:\n{previous_dsl}\n\nUPDATE WITH NEW DATA: {json.dumps(semantic_data, ensure_ascii=False)}"
|
| 109 |
+
|
| 110 |
+
logger.debug("[GeometryAgent] Calling LLM (Multi-Layer)...")
|
| 111 |
+
content = await self.llm.chat_completions_create(
|
| 112 |
+
messages=[
|
| 113 |
+
{"role": "system", "content": system_prompt},
|
| 114 |
+
{"role": "user", "content": user_content}
|
| 115 |
+
]
|
| 116 |
+
)
|
| 117 |
+
dsl = content.strip() if content else ""
|
| 118 |
+
logger.info(f"[GeometryAgent] DSL generated ({len(dsl.splitlines())} lines).")
|
| 119 |
+
logger.debug(f"[GeometryAgent] DSL output:\n{dsl}")
|
| 120 |
+
return dsl
|
agents/knowledge_agent.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import Dict, Any
|
| 3 |
+
|
| 4 |
+
logger = logging.getLogger(__name__)
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# ─── Shape rule registry ────────────────────────────────────────────────────
|
| 8 |
+
# Each entry: keyword list → augmentation function
|
| 9 |
+
# Augmentation receives (values: dict, text: str) and returns updated values dict.
|
| 10 |
+
|
| 11 |
+
class KnowledgeAgent:
|
| 12 |
+
"""Knowledge Agent: Stores geometric theorems and common patterns to augment Parser output."""
|
| 13 |
+
|
| 14 |
+
def augment_semantic_data(self, semantic_data: Dict[str, Any]) -> Dict[str, Any]:
|
| 15 |
+
logger.info("==[KnowledgeAgent] Augmenting semantic data==")
|
| 16 |
+
text = str(semantic_data.get("input_text", "")).lower()
|
| 17 |
+
logger.debug(f"[KnowledgeAgent] Input text for matching: '{text[:200]}'")
|
| 18 |
+
|
| 19 |
+
shape_type = self._detect_shape(text, semantic_data.get("type", ""))
|
| 20 |
+
if shape_type:
|
| 21 |
+
semantic_data["type"] = shape_type
|
| 22 |
+
values = semantic_data.get("values", {})
|
| 23 |
+
values = self._augment_values(shape_type, values, text)
|
| 24 |
+
semantic_data["values"] = values
|
| 25 |
+
else:
|
| 26 |
+
logger.info("[KnowledgeAgent] No special rule matched. Returning data unchanged.")
|
| 27 |
+
|
| 28 |
+
logger.debug(f"[KnowledgeAgent] Output semantic data: {semantic_data}")
|
| 29 |
+
return semantic_data
|
| 30 |
+
|
| 31 |
+
# ─── Shape detection ────────────────────────────────────────────────────
|
| 32 |
+
def _detect_shape(self, text: str, llm_type: str) -> str | None:
|
| 33 |
+
"""Detect shape from text keywords. LLM type provides a hint."""
|
| 34 |
+
checks = [
|
| 35 |
+
(["hình vuông", "square"], "square"),
|
| 36 |
+
(["hình chữ nhật", "rectangle"], "rectangle"),
|
| 37 |
+
(["hình thoi", "rhombus"], "rhombus"),
|
| 38 |
+
(["hình bình hành", "parallelogram"], "parallelogram"),
|
| 39 |
+
(["hình thang vuông"], "right_trapezoid"),
|
| 40 |
+
(["hình thang", "trapezoid", "trapezium"], "trapezoid"),
|
| 41 |
+
(["tam giác vuông", "right triangle"], "right_triangle"),
|
| 42 |
+
(["tam giác đều", "equilateral triangle", "equilateral"], "equilateral_triangle"),
|
| 43 |
+
(["tam giác cân", "isosceles"], "isosceles_triangle"),
|
| 44 |
+
(["tam giác", "triangle"], "triangle"),
|
| 45 |
+
(["đường tròn", "circle"], "circle"),
|
| 46 |
+
]
|
| 47 |
+
for keywords, shape in checks:
|
| 48 |
+
if any(kw in text for kw in keywords):
|
| 49 |
+
logger.info(f"[KnowledgeAgent] Rule MATCH: '{shape}' detected (keyword match).")
|
| 50 |
+
return shape
|
| 51 |
+
|
| 52 |
+
# Fallback: trust LLM-detected type if it's a known type
|
| 53 |
+
known = {
|
| 54 |
+
"rectangle", "square", "rhombus", "parallelogram",
|
| 55 |
+
"trapezoid", "right_trapezoid", "triangle", "right_triangle",
|
| 56 |
+
"equilateral_triangle", "isosceles_triangle", "circle",
|
| 57 |
+
}
|
| 58 |
+
if llm_type in known:
|
| 59 |
+
logger.info(f"[KnowledgeAgent] Using LLM-detected type '{llm_type}'.")
|
| 60 |
+
return llm_type
|
| 61 |
+
|
| 62 |
+
return None
|
| 63 |
+
|
| 64 |
+
# ─── Value augmentation ──────────────────────────────────────────────────
|
| 65 |
+
def _augment_values(self, shape: str, values: dict, text: str) -> dict:
|
| 66 |
+
ab = values.get("AB")
|
| 67 |
+
ad = values.get("AD")
|
| 68 |
+
bc = values.get("BC")
|
| 69 |
+
cd = values.get("CD")
|
| 70 |
+
|
| 71 |
+
if shape == "rectangle":
|
| 72 |
+
if ab and ad:
|
| 73 |
+
values.setdefault("CD", ab)
|
| 74 |
+
values.setdefault("BC", ad)
|
| 75 |
+
values.setdefault("angle_A", 90)
|
| 76 |
+
logger.info(f"[KnowledgeAgent] Rectangle: AB=CD={ab}, AD=BC={ad}, angle_A=90°")
|
| 77 |
+
else:
|
| 78 |
+
values.setdefault("angle_A", 90)
|
| 79 |
+
|
| 80 |
+
elif shape == "square":
|
| 81 |
+
side = ab or ad or bc or cd or values.get("side")
|
| 82 |
+
if side:
|
| 83 |
+
values.update({"AB": side, "AD": side, "angle_A": 90})
|
| 84 |
+
logger.info(f"[KnowledgeAgent] Square: side={side}, angle_A=90°")
|
| 85 |
+
else:
|
| 86 |
+
values.setdefault("angle_A", 90)
|
| 87 |
+
|
| 88 |
+
elif shape == "rhombus":
|
| 89 |
+
side = ab or values.get("side")
|
| 90 |
+
if side:
|
| 91 |
+
values.update({"AB": side, "BC": side, "CD": side, "DA": side})
|
| 92 |
+
logger.info(f"[KnowledgeAgent] Rhombus: all sides={side}")
|
| 93 |
+
|
| 94 |
+
elif shape == "parallelogram":
|
| 95 |
+
if ab:
|
| 96 |
+
values.setdefault("CD", ab)
|
| 97 |
+
if ad:
|
| 98 |
+
values.setdefault("BC", ad)
|
| 99 |
+
logger.info(f"[KnowledgeAgent] Parallelogram: AB||CD, AD||BC")
|
| 100 |
+
|
| 101 |
+
elif shape == "trapezoid":
|
| 102 |
+
logger.info("[KnowledgeAgent] Trapezoid: AB||CD (bottom||top)")
|
| 103 |
+
|
| 104 |
+
elif shape == "right_trapezoid":
|
| 105 |
+
logger.info("[KnowledgeAgent] Right trapezoid: AB||CD, AD⊥AB")
|
| 106 |
+
values.setdefault("angle_A", 90)
|
| 107 |
+
|
| 108 |
+
elif shape == "equilateral_triangle":
|
| 109 |
+
side = ab or values.get("side")
|
| 110 |
+
if side:
|
| 111 |
+
values.update({"AB": side, "BC": side, "CA": side, "angle_A": 60})
|
| 112 |
+
logger.info(f"[KnowledgeAgent] Equilateral triangle: all sides={side}, angle_A=60°")
|
| 113 |
+
|
| 114 |
+
elif shape == "right_triangle":
|
| 115 |
+
# Try to infer which vertex is the right angle
|
| 116 |
+
rt_vertex = _detect_right_angle_vertex(text)
|
| 117 |
+
values.setdefault(f"angle_{rt_vertex}", 90)
|
| 118 |
+
logger.info(f"[KnowledgeAgent] Right triangle: angle_{rt_vertex}=90°")
|
| 119 |
+
|
| 120 |
+
elif shape == "isosceles_triangle":
|
| 121 |
+
logger.info("[KnowledgeAgent] Isosceles triangle: AB=AC (default, LLM may override)")
|
| 122 |
+
|
| 123 |
+
elif shape == "circle":
|
| 124 |
+
logger.info("[KnowledgeAgent] Circle detected — no side augmentation needed.")
|
| 125 |
+
|
| 126 |
+
return values
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def _detect_right_angle_vertex(text: str) -> str:
|
| 130 |
+
"""Heuristic: detect which vertex is right angle from text."""
|
| 131 |
+
for vertex in ["A", "B", "C", "D"]:
|
| 132 |
+
patterns = [f"vuông tại {vertex}", f"góc {vertex} vuông", f"right angle at {vertex}"]
|
| 133 |
+
if any(p.lower() in text for p in patterns):
|
| 134 |
+
return vertex
|
| 135 |
+
return "A" # default
|
agents/ocr_agent.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import logging
|
| 3 |
+
|
| 4 |
+
from vision_ocr.pipeline import OcrVisionPipeline
|
| 5 |
+
|
| 6 |
+
logger = logging.getLogger(__name__)
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class ImprovedOCRAgent:
|
| 10 |
+
"""
|
| 11 |
+
API-facing OCR: composes ``OcrVisionPipeline`` (vision only) with optional LLM refinement.
|
| 12 |
+
Celery OCR workers should import ``OcrVisionPipeline`` directly from ``vision_ocr``.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
def __init__(self, skip_llm_refinement: bool = False):
|
| 16 |
+
self._skip_llm_refinement = bool(skip_llm_refinement)
|
| 17 |
+
self._vision = OcrVisionPipeline()
|
| 18 |
+
logger.info(
|
| 19 |
+
"[ImprovedOCRAgent] Vision pipeline ready (skip_llm_refinement=%s)...",
|
| 20 |
+
self._skip_llm_refinement,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
if self._skip_llm_refinement:
|
| 24 |
+
self.llm = None
|
| 25 |
+
logger.info("[ImprovedOCRAgent] LLM client skipped (raw OCR only).")
|
| 26 |
+
else:
|
| 27 |
+
from app.llm_client import get_llm_client
|
| 28 |
+
|
| 29 |
+
self.llm = get_llm_client()
|
| 30 |
+
logger.info("[ImprovedOCRAgent] Multi-Layer LLM Client initialized.")
|
| 31 |
+
|
| 32 |
+
async def process_image(self, image_path: str) -> str:
|
| 33 |
+
combined_text = await self._vision.process_image(image_path)
|
| 34 |
+
|
| 35 |
+
if not combined_text.strip():
|
| 36 |
+
return combined_text
|
| 37 |
+
|
| 38 |
+
if self._skip_llm_refinement or self.llm is None:
|
| 39 |
+
logger.info("[ImprovedOCRAgent] Skipping MegaLLM refinement (raw OCR output).")
|
| 40 |
+
return combined_text
|
| 41 |
+
|
| 42 |
+
try:
|
| 43 |
+
logger.info("[ImprovedOCRAgent] Sending to MegaLLM for refinement...")
|
| 44 |
+
refined_text = await asyncio.wait_for(
|
| 45 |
+
self.refine_with_llm(combined_text), timeout=30.0
|
| 46 |
+
)
|
| 47 |
+
return refined_text
|
| 48 |
+
except asyncio.TimeoutError:
|
| 49 |
+
logger.error("[ImprovedOCRAgent] MegaLLM refinement timed out.")
|
| 50 |
+
return combined_text
|
| 51 |
+
except Exception as e:
|
| 52 |
+
logger.error("[ImprovedOCRAgent] MegaLLM refinement failed: %s", e)
|
| 53 |
+
return combined_text
|
| 54 |
+
|
| 55 |
+
async def refine_with_llm(self, text: str) -> str:
|
| 56 |
+
if not text.strip():
|
| 57 |
+
return ""
|
| 58 |
+
if self.llm is None:
|
| 59 |
+
logger.warning("[ImprovedOCRAgent] refine_with_llm: no LLM client; returning raw text.")
|
| 60 |
+
return text
|
| 61 |
+
|
| 62 |
+
prompt = f"""Bạn là một chuyên gia số hóa tài liệu toán học.
|
| 63 |
+
Dưới đây là kết quả OCR thô từ một trang sách toán Tiếng Việt.
|
| 64 |
+
Kết quả này có thể chứa lỗi chính tả, lỗi định dạng mã LaTeX, hoặc bị ngắt quãng không logic.
|
| 65 |
+
|
| 66 |
+
Nhiệm vụ của bạn:
|
| 67 |
+
1. Sửa lỗi chính tả tiếng Việt.
|
| 68 |
+
2. Đảm bảo các công thức toán học được viết đúng định dạng LaTeX và nằm trong cặp dấu $...$.
|
| 69 |
+
3. Giữ nguyên cấu trúc logic của bài toán.
|
| 70 |
+
4. Trả về nội dung đã được làm sạch dưới dạng Markdown.
|
| 71 |
+
|
| 72 |
+
Nội dung OCR thô:
|
| 73 |
+
---
|
| 74 |
+
{text}
|
| 75 |
+
---
|
| 76 |
+
|
| 77 |
+
Kết quả làm sạch:"""
|
| 78 |
+
|
| 79 |
+
try:
|
| 80 |
+
refined = await self.llm.chat_completions_create(
|
| 81 |
+
messages=[{"role": "user", "content": prompt}],
|
| 82 |
+
temperature=0.1,
|
| 83 |
+
)
|
| 84 |
+
logger.info("[ImprovedOCRAgent] LLM refinement complete.")
|
| 85 |
+
return refined
|
| 86 |
+
except Exception as e:
|
| 87 |
+
logger.error("[ImprovedOCRAgent] LLM refinement failed: %s", e)
|
| 88 |
+
return text
|
| 89 |
+
|
| 90 |
+
async def process_url(self, url: str) -> str:
|
| 91 |
+
combined_text = await self._vision.process_url(url)
|
| 92 |
+
|
| 93 |
+
if not combined_text.strip() or combined_text.lstrip().startswith("Error:"):
|
| 94 |
+
return combined_text
|
| 95 |
+
|
| 96 |
+
if self._skip_llm_refinement or self.llm is None:
|
| 97 |
+
return combined_text
|
| 98 |
+
|
| 99 |
+
try:
|
| 100 |
+
return await asyncio.wait_for(self.refine_with_llm(combined_text), timeout=30.0)
|
| 101 |
+
except asyncio.TimeoutError:
|
| 102 |
+
logger.error("[ImprovedOCRAgent] MegaLLM refinement timed out.")
|
| 103 |
+
return combined_text
|
| 104 |
+
except Exception as e:
|
| 105 |
+
logger.error("[ImprovedOCRAgent] MegaLLM refinement failed: %s", e)
|
| 106 |
+
return combined_text
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class OCRAgent(ImprovedOCRAgent):
|
| 110 |
+
"""Alias for compatibility with existing code."""
|
| 111 |
+
|
| 112 |
+
pass
|
agents/orchestrator.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
from typing import Any, Dict
|
| 4 |
+
|
| 5 |
+
from agents.geometry_agent import GeometryAgent
|
| 6 |
+
from agents.knowledge_agent import KnowledgeAgent
|
| 7 |
+
from agents.ocr_agent import OCRAgent
|
| 8 |
+
from agents.parser_agent import ParserAgent
|
| 9 |
+
from agents.solver_agent import SolverAgent
|
| 10 |
+
from app.logutil import log_step
|
| 11 |
+
from app.ocr_celery import ocr_from_image_url
|
| 12 |
+
from solver.dsl_parser import DSLParser
|
| 13 |
+
from solver.engine import GeometryEngine
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
_CLIP = 2000
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _clip(val: Any, n: int = _CLIP) -> str | None:
|
| 21 |
+
if val is None:
|
| 22 |
+
return None
|
| 23 |
+
if isinstance(val, str):
|
| 24 |
+
s = val
|
| 25 |
+
else:
|
| 26 |
+
s = json.dumps(val, ensure_ascii=False, default=str)
|
| 27 |
+
return s if len(s) <= n else s[:n] + "…"
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _step_io(step: str, input_val: Any = None, output_val: Any = None) -> None:
|
| 31 |
+
"""Debug: chỉ input/output (đã cắt), tránh dump dài dòng không cần thiết."""
|
| 32 |
+
log_step(step, input=_clip(input_val), output=_clip(output_val))
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class Orchestrator:
|
| 36 |
+
def __init__(self):
|
| 37 |
+
self.parser_agent = ParserAgent()
|
| 38 |
+
self.geometry_agent = GeometryAgent()
|
| 39 |
+
self.ocr_agent = OCRAgent()
|
| 40 |
+
self.knowledge_agent = KnowledgeAgent()
|
| 41 |
+
self.solver_agent = SolverAgent()
|
| 42 |
+
self.solver_engine = GeometryEngine()
|
| 43 |
+
self.dsl_parser = DSLParser()
|
| 44 |
+
|
| 45 |
+
def _generate_step_description(self, semantic_json: Dict[str, Any], engine_result: Dict[str, Any]) -> str:
|
| 46 |
+
"""Tạo mô tả từng bước vẽ dựa trên kết quả của engine."""
|
| 47 |
+
analysis = semantic_json.get("analysis", "")
|
| 48 |
+
if not analysis:
|
| 49 |
+
analysis = f"Giải bài toán về {semantic_json.get('type', 'hình học')}."
|
| 50 |
+
|
| 51 |
+
steps = ["\n\n**Các bước dựng hình:**"]
|
| 52 |
+
drawing_phases = engine_result.get("drawing_phases", [])
|
| 53 |
+
|
| 54 |
+
for phase in drawing_phases:
|
| 55 |
+
label = phase.get("label", f"Giai đoạn {phase['phase']}")
|
| 56 |
+
points = ", ".join(phase.get("points", []))
|
| 57 |
+
segments = ", ".join([f"{s[0]}{s[1]}" for s in phase.get("segments", [])])
|
| 58 |
+
|
| 59 |
+
step_text = f"- **{label}**:"
|
| 60 |
+
if points:
|
| 61 |
+
step_text += f" Xác định các điểm {points}."
|
| 62 |
+
if segments:
|
| 63 |
+
step_text += f" Vẽ các đoạn thẳng {segments}."
|
| 64 |
+
steps.append(step_text)
|
| 65 |
+
|
| 66 |
+
circles = engine_result.get("circles", [])
|
| 67 |
+
for c in circles:
|
| 68 |
+
steps.append(f"- **Đường tròn**: Vẽ đường tròn tâm {c['center']} bán kính {c['radius']}.")
|
| 69 |
+
|
| 70 |
+
return analysis + "\n".join(steps)
|
| 71 |
+
|
| 72 |
+
async def run(
|
| 73 |
+
self,
|
| 74 |
+
text: str,
|
| 75 |
+
image_url: str = None,
|
| 76 |
+
job_id: str = None,
|
| 77 |
+
session_id: str = None,
|
| 78 |
+
status_callback=None,
|
| 79 |
+
history: list = None,
|
| 80 |
+
) -> Dict[str, Any]:
|
| 81 |
+
"""
|
| 82 |
+
Run the full pipeline. Optional history allows context-aware solving.
|
| 83 |
+
"""
|
| 84 |
+
_step_io(
|
| 85 |
+
"orchestrate_start",
|
| 86 |
+
input_val={
|
| 87 |
+
"job_id": job_id,
|
| 88 |
+
"text_len": len(text or ""),
|
| 89 |
+
"image_url": image_url,
|
| 90 |
+
"history_len": len(history or []),
|
| 91 |
+
},
|
| 92 |
+
output_val=None,
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
if status_callback:
|
| 96 |
+
await status_callback("processing")
|
| 97 |
+
|
| 98 |
+
# 1. Extract context from history (if any)
|
| 99 |
+
previous_context = None
|
| 100 |
+
if history:
|
| 101 |
+
# Look for the last assistant message with geometry data
|
| 102 |
+
for msg in reversed(history):
|
| 103 |
+
if msg.get("role") == "assistant" and msg.get("metadata", {}).get("geometry_dsl"):
|
| 104 |
+
previous_context = {
|
| 105 |
+
"geometry_dsl": msg["metadata"]["geometry_dsl"],
|
| 106 |
+
"coordinates": msg["metadata"].get("coordinates", {}),
|
| 107 |
+
"analysis": msg.get("content", ""),
|
| 108 |
+
}
|
| 109 |
+
break
|
| 110 |
+
|
| 111 |
+
if previous_context:
|
| 112 |
+
_step_io("context_found", input_val=None, output_val={"dsl_len": len(previous_context["geometry_dsl"])})
|
| 113 |
+
|
| 114 |
+
# 2. Gather input text (OCR or direct)
|
| 115 |
+
input_text = text
|
| 116 |
+
if image_url:
|
| 117 |
+
input_text = await ocr_from_image_url(image_url, self.ocr_agent)
|
| 118 |
+
_step_io("step1_ocr", input_val=image_url, output_val=input_text)
|
| 119 |
+
else:
|
| 120 |
+
_step_io("step1_ocr", input_val="(no image)", output_val=text)
|
| 121 |
+
|
| 122 |
+
feedback = None
|
| 123 |
+
MAX_RETRIES = 2
|
| 124 |
+
|
| 125 |
+
for attempt in range(MAX_RETRIES + 1):
|
| 126 |
+
_step_io(
|
| 127 |
+
"attempt",
|
| 128 |
+
input_val=f"{attempt + 1}/{MAX_RETRIES + 1}",
|
| 129 |
+
output_val=None,
|
| 130 |
+
)
|
| 131 |
+
if status_callback:
|
| 132 |
+
await status_callback("solving")
|
| 133 |
+
|
| 134 |
+
# Parser with context
|
| 135 |
+
_step_io("step2_parse", input_val=f"{input_text[:50]}...", output_val=None)
|
| 136 |
+
semantic_json = await self.parser_agent.process(input_text, feedback=feedback, context=previous_context)
|
| 137 |
+
semantic_json["input_text"] = input_text
|
| 138 |
+
_step_io("step2_parse", input_val=None, output_val=semantic_json)
|
| 139 |
+
|
| 140 |
+
# Knowledge augmentation
|
| 141 |
+
_step_io("step3_knowledge", input_val=semantic_json, output_val=None)
|
| 142 |
+
semantic_json = self.knowledge_agent.augment_semantic_data(semantic_json)
|
| 143 |
+
_step_io("step3_knowledge", input_val=None, output_val=semantic_json)
|
| 144 |
+
|
| 145 |
+
# Geometry DSL with context (passing previous DSL to guide generation)
|
| 146 |
+
_step_io("step4_geometry_dsl", input_val=semantic_json, output_val=None)
|
| 147 |
+
dsl_code = await self.geometry_agent.generate_dsl(
|
| 148 |
+
semantic_json,
|
| 149 |
+
previous_dsl=previous_context["geometry_dsl"] if previous_context else None
|
| 150 |
+
)
|
| 151 |
+
_step_io("step4_geometry_dsl", input_val=None, output_val=dsl_code)
|
| 152 |
+
|
| 153 |
+
_step_io("step5_dsl_parse", input_val=dsl_code, output_val=None)
|
| 154 |
+
points, constraints, is_3d = self.dsl_parser.parse(dsl_code)
|
| 155 |
+
_step_io(
|
| 156 |
+
"step5_dsl_parse",
|
| 157 |
+
input_val=None,
|
| 158 |
+
output_val={
|
| 159 |
+
"points": len(points),
|
| 160 |
+
"constraints": len(constraints),
|
| 161 |
+
"is_3d": is_3d,
|
| 162 |
+
},
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
_step_io("step6_solve", input_val=f"{len(points)} pts / {len(constraints)} cons (is_3d={is_3d})", output_val=None)
|
| 166 |
+
import anyio
|
| 167 |
+
engine_result = await anyio.to_thread.run_sync(self.solver_engine.solve, points, constraints, is_3d)
|
| 168 |
+
|
| 169 |
+
if engine_result:
|
| 170 |
+
coordinates = engine_result.get("coordinates")
|
| 171 |
+
_step_io("step6_solve", input_val=None, output_val=coordinates)
|
| 172 |
+
logger.info(
|
| 173 |
+
"[Orchestrator] geometry solved job_id=%s is_3d=%s n_coords=%d",
|
| 174 |
+
job_id,
|
| 175 |
+
is_3d,
|
| 176 |
+
len(coordinates) if isinstance(coordinates, dict) else 0,
|
| 177 |
+
)
|
| 178 |
+
break
|
| 179 |
+
|
| 180 |
+
feedback = "Geometry solver failed to find a valid solution for the given constraints. Parallelism or lengths might be inconsistent."
|
| 181 |
+
_step_io(
|
| 182 |
+
"step6_solve",
|
| 183 |
+
input_val=f"attempt {attempt + 1}",
|
| 184 |
+
output_val=feedback,
|
| 185 |
+
)
|
| 186 |
+
if attempt == MAX_RETRIES:
|
| 187 |
+
_step_io(
|
| 188 |
+
"orchestrate_abort",
|
| 189 |
+
input_val=None,
|
| 190 |
+
output_val="solver_exhausted_retries",
|
| 191 |
+
)
|
| 192 |
+
return {
|
| 193 |
+
"error": "Solver failed after multiple attempts.",
|
| 194 |
+
"last_dsl": dsl_code,
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
_step_io("orchestrate_done", input_val=job_id, output_val="success")
|
| 198 |
+
|
| 199 |
+
# 8. Solution calculation (New in v5.1)
|
| 200 |
+
solution = None
|
| 201 |
+
if engine_result:
|
| 202 |
+
_step_io("step8_solve_math", input_val=semantic_json.get("target_question"), output_val=None)
|
| 203 |
+
solution = await self.solver_agent.solve(semantic_json, engine_result)
|
| 204 |
+
_step_io("step8_solve_math", input_val=None, output_val=solution.get("answer"))
|
| 205 |
+
|
| 206 |
+
final_analysis = self._generate_step_description(semantic_json, engine_result)
|
| 207 |
+
|
| 208 |
+
status = "success"
|
| 209 |
+
return {
|
| 210 |
+
"status": status,
|
| 211 |
+
"job_id": job_id,
|
| 212 |
+
"geometry_dsl": dsl_code,
|
| 213 |
+
"coordinates": coordinates,
|
| 214 |
+
"polygon_order": engine_result.get("polygon_order", []),
|
| 215 |
+
"circles": engine_result.get("circles", []),
|
| 216 |
+
"lines": engine_result.get("lines", []),
|
| 217 |
+
"rays": engine_result.get("rays", []),
|
| 218 |
+
"drawing_phases": engine_result.get("drawing_phases", []),
|
| 219 |
+
"semantic": semantic_json,
|
| 220 |
+
"semantic_analysis": final_analysis,
|
| 221 |
+
"solution": solution,
|
| 222 |
+
"is_3d": is_3d,
|
| 223 |
+
}
|
agents/parser_agent.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
from openai import AsyncOpenAI
|
| 5 |
+
from typing import Dict, Any
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
|
| 8 |
+
load_dotenv()
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
from app.url_utils import openai_compatible_api_key, sanitize_env
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
from app.llm_client import get_llm_client
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ParserAgent:
|
| 18 |
+
def __init__(self):
|
| 19 |
+
self.llm = get_llm_client()
|
| 20 |
+
|
| 21 |
+
async def process(self, text: str, feedback: str = None, context: Dict[str, Any] = None) -> Dict[str, Any]:
|
| 22 |
+
logger.info(f"==[ParserAgent] Processing input (len={len(text)})==")
|
| 23 |
+
if feedback:
|
| 24 |
+
logger.warning(f"[ParserAgent] Feedback from previous attempt: {feedback}")
|
| 25 |
+
if context:
|
| 26 |
+
logger.info(f"[ParserAgent] Using previous context (dsl_len={len(context.get('geometry_dsl', ''))})")
|
| 27 |
+
|
| 28 |
+
system_prompt = """
|
| 29 |
+
You are a Geometry Parser Agent. Extract geometric entities and constraints from Vietnamese/LaTeX math problem text.
|
| 30 |
+
|
| 31 |
+
=== CONTEXT AWARENESS ===
|
| 32 |
+
If previous context is provided, it means this is a follow-up request.
|
| 33 |
+
- Combine old entities with new ones.
|
| 34 |
+
- Update 'analysis' to reflect the entire problem state.
|
| 35 |
+
|
| 36 |
+
Output ONLY a JSON object with this EXACT structure (no extra keys, no markdown):
|
| 37 |
+
{
|
| 38 |
+
"entities": ["Point A", "Point B", ...],
|
| 39 |
+
"type": "pyramid|prism|sphere|rectangle|triangle|circle|parallelogram|trapezoid|square|rhombus|general",
|
| 40 |
+
"values": {"AB": 5, "SO": 15, "radius": 3},
|
| 41 |
+
"target_question": "Câu hỏi cụ thể cần giải (ví dụ: 'Tính diện tích tam giác ABC'). NẾU KHÔNG CÓ CÂU HỎI THÌ ĐỂ null.",
|
| 42 |
+
"analysis": "Tóm tắt ngắn gọn toàn bộ bài toán sau khi đã cập nhật các yêu cầu mới bằng tiếng Việt."
|
| 43 |
+
}
|
| 44 |
+
Rules:
|
| 45 |
+
- "analysis" MUST be a meaningful and UP-TO-DATE summary of the problem in Vietnamese.
|
| 46 |
+
- "target_question" must be concise.
|
| 47 |
+
- Include midpoints, auxiliary points in "entities" if mentioned.
|
| 48 |
+
- If feedback is provided, correct your previous output accordingly.
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
user_content = f"Text: {text}"
|
| 52 |
+
if context:
|
| 53 |
+
user_content = f"PREVIOUS ANALYSIS: {context.get('analysis')}\nNEW REQUEST: {text}"
|
| 54 |
+
|
| 55 |
+
if feedback:
|
| 56 |
+
user_content += f"\nFeedback from previous attempt: {feedback}. Please correct the constraints."
|
| 57 |
+
|
| 58 |
+
logger.debug("[ParserAgent] Calling LLM (Multi-Layer)...")
|
| 59 |
+
raw = await self.llm.chat_completions_create(
|
| 60 |
+
messages=[
|
| 61 |
+
{"role": "system", "content": system_prompt},
|
| 62 |
+
{"role": "user", "content": user_content}
|
| 63 |
+
],
|
| 64 |
+
response_format={"type": "json_object"}
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
# Pre-process raw string: extract the JSON block if present
|
| 68 |
+
import re
|
| 69 |
+
clean_raw = raw.strip()
|
| 70 |
+
# Handle potential markdown code blocks
|
| 71 |
+
if clean_raw.startswith("```"):
|
| 72 |
+
import re
|
| 73 |
+
match = re.search(r"```(?:json)?\s*(.*?)\s*```", clean_raw, re.DOTALL)
|
| 74 |
+
if match:
|
| 75 |
+
clean_raw = match.group(1).strip()
|
| 76 |
+
|
| 77 |
+
try:
|
| 78 |
+
result = json.loads(clean_raw)
|
| 79 |
+
except json.JSONDecodeError as e:
|
| 80 |
+
logger.error(f"[ParserAgent] JSON Parse Error: {e}. Attempting regex fallback...")
|
| 81 |
+
import re
|
| 82 |
+
json_match = re.search(r'(\{.*\})', clean_raw, re.DOTALL)
|
| 83 |
+
if json_match:
|
| 84 |
+
try:
|
| 85 |
+
# Handle single quotes if present (common LLM failure)
|
| 86 |
+
json_str = json_match.group(1)
|
| 87 |
+
if "'" in json_str and '"' not in json_str:
|
| 88 |
+
json_str = json_str.replace("'", '"')
|
| 89 |
+
result = json.loads(json_str)
|
| 90 |
+
except:
|
| 91 |
+
result = None
|
| 92 |
+
else:
|
| 93 |
+
result = None
|
| 94 |
+
|
| 95 |
+
if not result:
|
| 96 |
+
# Fallback for critical failure
|
| 97 |
+
result = {
|
| 98 |
+
"entities": [],
|
| 99 |
+
"type": "general",
|
| 100 |
+
"values": {},
|
| 101 |
+
"target_question": None,
|
| 102 |
+
"analysis": text
|
| 103 |
+
}
|
| 104 |
+
logger.info(f"[ParserAgent] LLM response received.")
|
| 105 |
+
logger.debug(f"[ParserAgent] Parsed JSON: {json.dumps(result, ensure_ascii=False, indent=2)}")
|
| 106 |
+
return result
|
agents/renderer_agent.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shim: geometry rendering lives in ``geometry_render`` (worker-safe package)."""
|
| 2 |
+
|
| 3 |
+
from geometry_render.renderer import RendererAgent
|
| 4 |
+
|
| 5 |
+
__all__ = ["RendererAgent"]
|
agents/solver_agent.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
import sympy as sp
|
| 4 |
+
from typing import Dict, Any, List
|
| 5 |
+
from app.llm_client import get_llm_client
|
| 6 |
+
|
| 7 |
+
logger = logging.getLogger(__name__)
|
| 8 |
+
|
| 9 |
+
class SolverAgent:
|
| 10 |
+
def __init__(self):
|
| 11 |
+
self.llm = get_llm_client()
|
| 12 |
+
|
| 13 |
+
async def solve(self, semantic_data: Dict[str, Any], engine_result: Dict[str, Any]) -> Dict[str, Any]:
|
| 14 |
+
"""
|
| 15 |
+
Solves the geometric problem based on coordinates and the target question.
|
| 16 |
+
Returns a 'solution' dictionary with answer, steps, and symbolic_expression.
|
| 17 |
+
"""
|
| 18 |
+
target_question = semantic_data.get("target_question")
|
| 19 |
+
if not target_question:
|
| 20 |
+
# If no question, just return an empty solution structure
|
| 21 |
+
return {
|
| 22 |
+
"answer": None,
|
| 23 |
+
"steps": [],
|
| 24 |
+
"symbolic_expression": None
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
logger.info(f"==[SolverAgent] Solving for: '{target_question}'==")
|
| 28 |
+
|
| 29 |
+
input_text = semantic_data.get("input_text", "")
|
| 30 |
+
coordinates = engine_result.get("coordinates", {})
|
| 31 |
+
|
| 32 |
+
# We provide the coordinates and semantic context to the LLM to help it reason.
|
| 33 |
+
# The LLM is tasked with generating the solution structure directly.
|
| 34 |
+
|
| 35 |
+
system_prompt = """
|
| 36 |
+
You are a Geometry Solver Agent. Your goal is to provide a step-by-step solution for a specific geometric question.
|
| 37 |
+
|
| 38 |
+
=== DATA PROVIDED ===
|
| 39 |
+
1. Target Question: The specific question to answer.
|
| 40 |
+
2. Geometry Data: Entities and values extracted from the problem.
|
| 41 |
+
3. Coordinates: Calculated coordinates for all points.
|
| 42 |
+
|
| 43 |
+
=== REQUIREMENTS ===
|
| 44 |
+
- Provide the solution in the SAME LANGUAGE as the user's input.
|
| 45 |
+
- Use SymPy concepts if appropriate.
|
| 46 |
+
- Steps should be clear, concise, and logical.
|
| 47 |
+
- The final answer should be numerically or symbolically accurate based on the coordinates and geometric properties.
|
| 48 |
+
- For geometric proofs (e.g., "Is AB perpendicular to AC?"), explain the reasoning based on the data.
|
| 49 |
+
|
| 50 |
+
Output ONLY a JSON object with this structure:
|
| 51 |
+
{
|
| 52 |
+
"answer": "Chuỗi văn bản kết quả cuối cùng (kèm đơn vị nếu có)",
|
| 53 |
+
"steps": [
|
| 54 |
+
"Bước 1: ...",
|
| 55 |
+
"Bước 2: ...",
|
| 56 |
+
...
|
| 57 |
+
],
|
| 58 |
+
"symbolic_expression": "Biểu thức toán học rút gọn (LaTeX format optional)"
|
| 59 |
+
}
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
user_content = f"""
|
| 63 |
+
INPUT_TEXT: {input_text}
|
| 64 |
+
TARGET_QUESTION: {target_question}
|
| 65 |
+
SEMANTIC_DATA: {json.dumps(semantic_data, ensure_ascii=False)}
|
| 66 |
+
COORDINATES: {json.dumps(coordinates)}
|
| 67 |
+
"""
|
| 68 |
+
|
| 69 |
+
logger.debug("[SolverAgent] Requesting solution from LLM...")
|
| 70 |
+
try:
|
| 71 |
+
raw = await self.llm.chat_completions_create(
|
| 72 |
+
messages=[
|
| 73 |
+
{"role": "system", "content": system_prompt},
|
| 74 |
+
{"role": "user", "content": user_content}
|
| 75 |
+
],
|
| 76 |
+
response_format={"type": "json_object"}
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
clean_raw = raw.strip()
|
| 80 |
+
# Handle potential markdown code blocks if the response_format wasn't strictly honored
|
| 81 |
+
if clean_raw.startswith("```"):
|
| 82 |
+
import re
|
| 83 |
+
match = re.search(r"```(?:json)?\s*(.*?)\s*```", clean_raw, re.DOTALL)
|
| 84 |
+
if match:
|
| 85 |
+
clean_raw = match.group(1).strip()
|
| 86 |
+
|
| 87 |
+
try:
|
| 88 |
+
solution = json.loads(clean_raw)
|
| 89 |
+
except json.JSONDecodeError:
|
| 90 |
+
# Last resort: try to find anything between { and }
|
| 91 |
+
import re
|
| 92 |
+
json_match = re.search(r'(\{.*\})', clean_raw, re.DOTALL)
|
| 93 |
+
if json_match:
|
| 94 |
+
solution = json.loads(json_match.group(1))
|
| 95 |
+
else:
|
| 96 |
+
raise
|
| 97 |
+
|
| 98 |
+
logger.info("[SolverAgent] Solution generated successfully.")
|
| 99 |
+
return solution
|
| 100 |
+
except Exception as e:
|
| 101 |
+
logger.error(f"[SolverAgent] Error generating solution: {e}")
|
| 102 |
+
logger.debug(f"[SolverAgent] Raw LLM output was: \n{raw if 'raw' in locals() else 'N/A'}")
|
| 103 |
+
return {
|
| 104 |
+
"answer": "Không thể tính toán lời giải tại thời điểm này.",
|
| 105 |
+
"steps": ["Đã xảy ra lỗi trong quá trình xử lý lời giải."],
|
| 106 |
+
"symbolic_expression": None
|
| 107 |
+
}
|
agents/torch_ultralytics_compat.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shim: moved to ``vision_ocr.compat`` for OCR worker isolation."""
|
| 2 |
+
|
| 3 |
+
from vision_ocr.compat import allow_ultralytics_weights
|
| 4 |
+
|
| 5 |
+
__all__ = ["allow_ultralytics_weights"]
|
app/chat_image_upload.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Validate and upload chat/solve attachment images to Supabase Storage (image bucket)."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
import uuid
|
| 8 |
+
from typing import Any, Dict, Tuple
|
| 9 |
+
|
| 10 |
+
from fastapi import HTTPException
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _get_next_image_version(session_id: str) -> int:
|
| 16 |
+
"""Same logic as worker.asset_manager.get_next_version for asset_type image."""
|
| 17 |
+
from app.supabase_client import get_supabase
|
| 18 |
+
|
| 19 |
+
supabase = get_supabase()
|
| 20 |
+
try:
|
| 21 |
+
res = (
|
| 22 |
+
supabase.table("session_assets")
|
| 23 |
+
.select("version")
|
| 24 |
+
.eq("session_id", session_id)
|
| 25 |
+
.eq("asset_type", "image")
|
| 26 |
+
.order("version", desc=True)
|
| 27 |
+
.limit(1)
|
| 28 |
+
.execute()
|
| 29 |
+
)
|
| 30 |
+
if res.data:
|
| 31 |
+
return res.data[0]["version"] + 1
|
| 32 |
+
return 1
|
| 33 |
+
except Exception as e:
|
| 34 |
+
logger.error("Error fetching image version: %s", e)
|
| 35 |
+
return 1
|
| 36 |
+
|
| 37 |
+
_MAX_BYTES_DEFAULT = 10 * 1024 * 1024
|
| 38 |
+
|
| 39 |
+
_EXT_TO_MIME: dict[str, str] = {
|
| 40 |
+
".png": "image/png",
|
| 41 |
+
".jpg": "image/jpeg",
|
| 42 |
+
".jpeg": "image/jpeg",
|
| 43 |
+
".webp": "image/webp",
|
| 44 |
+
".gif": "image/gif",
|
| 45 |
+
".bmp": "image/bmp",
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _max_bytes() -> int:
|
| 50 |
+
raw = os.getenv("CHAT_IMAGE_MAX_BYTES")
|
| 51 |
+
if raw and raw.isdigit():
|
| 52 |
+
return min(int(raw), 50 * 1024 * 1024)
|
| 53 |
+
return _MAX_BYTES_DEFAULT
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _magic_ok(ext: str, body: bytes) -> bool:
|
| 57 |
+
if len(body) < 12:
|
| 58 |
+
return False
|
| 59 |
+
if ext == ".png":
|
| 60 |
+
return body.startswith(b"\x89PNG\r\n\x1a\n")
|
| 61 |
+
if ext in (".jpg", ".jpeg"):
|
| 62 |
+
return body.startswith(b"\xff\xd8\xff")
|
| 63 |
+
if ext == ".webp":
|
| 64 |
+
return body.startswith(b"RIFF") and body[8:12] == b"WEBP"
|
| 65 |
+
if ext == ".gif":
|
| 66 |
+
return body.startswith(b"GIF87a") or body.startswith(b"GIF89a")
|
| 67 |
+
if ext == ".bmp":
|
| 68 |
+
return body.startswith(b"BM")
|
| 69 |
+
return False
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def validate_chat_image_bytes(
|
| 73 |
+
filename: str | None,
|
| 74 |
+
body: bytes,
|
| 75 |
+
declared_content_type: str | None,
|
| 76 |
+
) -> Tuple[str, str]:
|
| 77 |
+
"""
|
| 78 |
+
Validate size, extension, and magic bytes.
|
| 79 |
+
Returns (extension_with_dot, content_type).
|
| 80 |
+
"""
|
| 81 |
+
max_b = _max_bytes()
|
| 82 |
+
if not body:
|
| 83 |
+
raise HTTPException(status_code=400, detail="Empty file.")
|
| 84 |
+
if len(body) > max_b:
|
| 85 |
+
raise HTTPException(
|
| 86 |
+
status_code=413,
|
| 87 |
+
detail=f"Image too large (max {max_b // (1024 * 1024)} MB).",
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
ext = os.path.splitext(filename or "")[1].lower()
|
| 91 |
+
if not ext:
|
| 92 |
+
ext = ".png"
|
| 93 |
+
if ext not in _EXT_TO_MIME:
|
| 94 |
+
raise HTTPException(
|
| 95 |
+
status_code=400,
|
| 96 |
+
detail=f"Unsupported image type: {ext}. Allowed: {', '.join(sorted(_EXT_TO_MIME))}",
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
if not _magic_ok(ext, body):
|
| 100 |
+
raise HTTPException(
|
| 101 |
+
status_code=400,
|
| 102 |
+
detail="File content does not match declared image type.",
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
mime = _EXT_TO_MIME[ext]
|
| 106 |
+
if declared_content_type:
|
| 107 |
+
decl = declared_content_type.split(";")[0].strip().lower()
|
| 108 |
+
if decl and decl not in ("application/octet-stream", mime) and decl != mime:
|
| 109 |
+
logger.warning(
|
| 110 |
+
"Content-Type mismatch (declared=%s, inferred=%s); using inferred.",
|
| 111 |
+
declared_content_type,
|
| 112 |
+
mime,
|
| 113 |
+
)
|
| 114 |
+
return ext, mime
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def upload_session_chat_image(
|
| 118 |
+
session_id: str,
|
| 119 |
+
job_id: str,
|
| 120 |
+
file_bytes: bytes,
|
| 121 |
+
ext_with_dot: str,
|
| 122 |
+
content_type: str,
|
| 123 |
+
) -> Dict[str, Any]:
|
| 124 |
+
"""
|
| 125 |
+
Upload to SUPABASE_IMAGE_BUCKET (default: image), insert session_assets row.
|
| 126 |
+
Returns dict with public_url, storage_path, version, session_asset_id (if returned).
|
| 127 |
+
"""
|
| 128 |
+
from app.supabase_client import get_supabase
|
| 129 |
+
|
| 130 |
+
supabase = get_supabase()
|
| 131 |
+
bucket_name = os.getenv("SUPABASE_IMAGE_BUCKET", "image")
|
| 132 |
+
raw_ext = ext_with_dot.lstrip(".").lower()
|
| 133 |
+
version = _get_next_image_version(session_id)
|
| 134 |
+
file_name = f"image_v{version}_{job_id}.{raw_ext}"
|
| 135 |
+
storage_path = f"sessions/{session_id}/{file_name}"
|
| 136 |
+
|
| 137 |
+
supabase.storage.from_(bucket_name).upload(
|
| 138 |
+
path=storage_path,
|
| 139 |
+
file=file_bytes,
|
| 140 |
+
file_options={"content-type": content_type},
|
| 141 |
+
)
|
| 142 |
+
public_url = supabase.storage.from_(bucket_name).get_public_url(storage_path)
|
| 143 |
+
if isinstance(public_url, dict):
|
| 144 |
+
public_url = public_url.get("publicUrl") or public_url.get("public_url") or str(public_url)
|
| 145 |
+
|
| 146 |
+
row = {
|
| 147 |
+
"session_id": session_id,
|
| 148 |
+
"job_id": job_id,
|
| 149 |
+
"asset_type": "image",
|
| 150 |
+
"storage_path": storage_path,
|
| 151 |
+
"public_url": public_url,
|
| 152 |
+
"version": version,
|
| 153 |
+
}
|
| 154 |
+
ins = supabase.table("session_assets").insert(row).select("id").execute()
|
| 155 |
+
asset_id = None
|
| 156 |
+
if ins.data and len(ins.data) > 0:
|
| 157 |
+
asset_id = ins.data[0].get("id")
|
| 158 |
+
|
| 159 |
+
log_data = {
|
| 160 |
+
"public_url": public_url,
|
| 161 |
+
"storage_path": storage_path,
|
| 162 |
+
"version": version,
|
| 163 |
+
"session_asset_id": str(asset_id) if asset_id else None,
|
| 164 |
+
}
|
| 165 |
+
logger.info("Uploaded chat image: %s", log_data)
|
| 166 |
+
return {
|
| 167 |
+
"public_url": public_url,
|
| 168 |
+
"storage_path": storage_path,
|
| 169 |
+
"version": version,
|
| 170 |
+
"session_asset_id": str(asset_id) if asset_id else None,
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def upload_ephemeral_ocr_blob(
|
| 175 |
+
file_bytes: bytes,
|
| 176 |
+
ext_with_dot: str,
|
| 177 |
+
content_type: str,
|
| 178 |
+
) -> Tuple[str, str]:
|
| 179 |
+
"""
|
| 180 |
+
Upload bytes to image bucket under _ocr_temp/ for worker-only OCR (no session_assets row).
|
| 181 |
+
Returns (storage_path, public_url). Caller must delete_storage_object when done.
|
| 182 |
+
"""
|
| 183 |
+
from app.supabase_client import get_supabase
|
| 184 |
+
|
| 185 |
+
bucket_name = os.getenv("SUPABASE_IMAGE_BUCKET", "image")
|
| 186 |
+
raw_ext = ext_with_dot.lstrip(".").lower() or "png"
|
| 187 |
+
name = f"_ocr_temp/{uuid.uuid4().hex}.{raw_ext}"
|
| 188 |
+
supabase = get_supabase()
|
| 189 |
+
supabase.storage.from_(bucket_name).upload(
|
| 190 |
+
path=name,
|
| 191 |
+
file=file_bytes,
|
| 192 |
+
file_options={"content-type": content_type},
|
| 193 |
+
)
|
| 194 |
+
public_url = supabase.storage.from_(bucket_name).get_public_url(name)
|
| 195 |
+
if isinstance(public_url, dict):
|
| 196 |
+
public_url = public_url.get("publicUrl") or public_url.get("public_url") or str(public_url)
|
| 197 |
+
return name, public_url
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def delete_storage_object(bucket_name: str, storage_path: str) -> None:
|
| 201 |
+
try:
|
| 202 |
+
from app.supabase_client import get_supabase
|
| 203 |
+
|
| 204 |
+
get_supabase().storage.from_(bucket_name).remove([storage_path])
|
| 205 |
+
except Exception as e:
|
| 206 |
+
logger.warning("delete_storage_object failed path=%s: %s", storage_path, e)
|
app/dependencies.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import HTTPException, Header
|
| 2 |
+
|
| 3 |
+
from app.supabase_client import get_supabase, get_supabase_for_user_jwt
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
async def get_current_user_id(authorization: str | None = Header(None)):
|
| 7 |
+
"""
|
| 8 |
+
Authenticate user using Supabase JWT.
|
| 9 |
+
Expected Header: Authorization: Bearer <token>
|
| 10 |
+
"""
|
| 11 |
+
import os
|
| 12 |
+
|
| 13 |
+
if not authorization:
|
| 14 |
+
raise HTTPException(
|
| 15 |
+
status_code=401,
|
| 16 |
+
detail="Authorization header missing or invalid. Use 'Bearer <token>'",
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
if os.getenv("ALLOW_TEST_BYPASS") == "true" and authorization.startswith("Test "):
|
| 20 |
+
return authorization.split(" ")[1]
|
| 21 |
+
|
| 22 |
+
if not authorization.startswith("Bearer "):
|
| 23 |
+
raise HTTPException(
|
| 24 |
+
status_code=401,
|
| 25 |
+
detail="Authorization header missing or invalid. Use 'Bearer <token>'",
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
token = authorization.split(" ")[1]
|
| 29 |
+
supabase = get_supabase()
|
| 30 |
+
|
| 31 |
+
try:
|
| 32 |
+
user_response = supabase.auth.get_user(token)
|
| 33 |
+
if not user_response or not user_response.user:
|
| 34 |
+
raise HTTPException(status_code=401, detail="Invalid session or token.")
|
| 35 |
+
|
| 36 |
+
return user_response.user.id
|
| 37 |
+
except HTTPException:
|
| 38 |
+
raise
|
| 39 |
+
except Exception as e:
|
| 40 |
+
raise HTTPException(status_code=401, detail=f"Authentication failed: {str(e)}")
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
async def get_authenticated_supabase(authorization: str = Header(...)):
|
| 44 |
+
"""
|
| 45 |
+
Supabase client that carries the user's JWT (anon key + Authorization header).
|
| 46 |
+
Use for routes that should respect Row Level Security; pair with app logic as needed.
|
| 47 |
+
"""
|
| 48 |
+
if not authorization or not authorization.startswith("Bearer "):
|
| 49 |
+
raise HTTPException(
|
| 50 |
+
status_code=401,
|
| 51 |
+
detail="Authorization header missing or invalid. Use 'Bearer <token>'",
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
token = authorization.split(" ")[1]
|
| 55 |
+
supabase = get_supabase()
|
| 56 |
+
|
| 57 |
+
try:
|
| 58 |
+
user_response = supabase.auth.get_user(token)
|
| 59 |
+
if not user_response or not user_response.user:
|
| 60 |
+
raise HTTPException(status_code=401, detail="Invalid session or token.")
|
| 61 |
+
except HTTPException:
|
| 62 |
+
raise
|
| 63 |
+
except Exception as e:
|
| 64 |
+
raise HTTPException(status_code=401, detail=f"Authentication failed: {str(e)}")
|
| 65 |
+
|
| 66 |
+
try:
|
| 67 |
+
return get_supabase_for_user_jwt(token)
|
| 68 |
+
except RuntimeError as e:
|
| 69 |
+
raise HTTPException(status_code=503, detail=str(e))
|
app/errors.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Map exceptions to short, user-visible messages (avoid leaking HTML bodies from 404 proxies)."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
|
| 7 |
+
logger = logging.getLogger(__name__)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _looks_like_html(text: str) -> bool:
|
| 11 |
+
t = text.lstrip()[:500].lower()
|
| 12 |
+
return t.startswith("<!doctype") or t.startswith("<html") or "<html" in t[:200]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def format_error_for_user(exc: BaseException) -> str:
|
| 16 |
+
"""
|
| 17 |
+
Produce a safe message for chat/UI. Full detail stays in server logs via logger.exception.
|
| 18 |
+
"""
|
| 19 |
+
# httpx: wrong URL often returns 404 HTML; don't show body
|
| 20 |
+
try:
|
| 21 |
+
import httpx
|
| 22 |
+
|
| 23 |
+
if isinstance(exc, httpx.HTTPStatusError):
|
| 24 |
+
req = exc.request
|
| 25 |
+
code = exc.response.status_code
|
| 26 |
+
url_hint = ""
|
| 27 |
+
try:
|
| 28 |
+
url_hint = str(req.url.host) if req and req.url else ""
|
| 29 |
+
except Exception:
|
| 30 |
+
pass
|
| 31 |
+
logger.warning(
|
| 32 |
+
"HTTPStatusError %s for %s (response not shown to user)",
|
| 33 |
+
code,
|
| 34 |
+
url_hint or "?",
|
| 35 |
+
)
|
| 36 |
+
return (
|
| 37 |
+
"Kiểm tra URL API, khóa bí mật và biến môi trường (OpenRouter/Supabase/Redis)."
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
if isinstance(exc, httpx.RequestError):
|
| 41 |
+
return "Không kết nối được tới dịch vụ ngoài (mạng hoặc URL sai)."
|
| 42 |
+
except ImportError:
|
| 43 |
+
pass
|
| 44 |
+
|
| 45 |
+
raw = str(exc).strip()
|
| 46 |
+
if not raw:
|
| 47 |
+
return "Đã xảy ra lỗi không xác định."
|
| 48 |
+
|
| 49 |
+
if _looks_like_html(raw):
|
| 50 |
+
logger.warning("Suppressed HTML error body from user-facing message")
|
| 51 |
+
return (
|
| 52 |
+
"Dịch vụ trả về trang lỗi (thường là URL API sai hoặc endpoint không tồn tại — HTTP 404). "
|
| 53 |
+
"Kiểm tra OPENROUTER_MODEL và khóa API trên server."
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
if len(raw) > 800:
|
| 57 |
+
return raw[:800] + "…"
|
| 58 |
+
|
| 59 |
+
return raw
|
app/job_poll.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Normalize Supabase `jobs` rows for polling / WebSocket clients (stable `job_id` + JSON `result`)."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import logging
|
| 7 |
+
from typing import Any
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _coerce_result(value: Any) -> Any:
|
| 13 |
+
if value is None:
|
| 14 |
+
return None
|
| 15 |
+
if isinstance(value, (dict, list)):
|
| 16 |
+
return value
|
| 17 |
+
if isinstance(value, str):
|
| 18 |
+
try:
|
| 19 |
+
return json.loads(value)
|
| 20 |
+
except json.JSONDecodeError:
|
| 21 |
+
logger.warning("job_poll: result is non-JSON string, returning raw")
|
| 22 |
+
return {"raw": value}
|
| 23 |
+
return value
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def normalize_job_row_for_client(row: dict[str, Any]) -> dict[str, Any]:
|
| 27 |
+
"""
|
| 28 |
+
Build a JSON-serializable dict that always includes:
|
| 29 |
+
- ``job_id`` (alias of DB ``id``) for clients that expect it on poll bodies
|
| 30 |
+
- ``status`` as str
|
| 31 |
+
- ``result`` as object/array when stored as JSON string
|
| 32 |
+
All other columns are passed through (UUID/datetime become JSON-safe via FastAPI encoder).
|
| 33 |
+
"""
|
| 34 |
+
out = dict(row)
|
| 35 |
+
jid = out.get("id")
|
| 36 |
+
if jid is not None:
|
| 37 |
+
out["job_id"] = str(jid)
|
| 38 |
+
st = out.get("status")
|
| 39 |
+
if st is not None:
|
| 40 |
+
out["status"] = str(st)
|
| 41 |
+
if "result" in out:
|
| 42 |
+
out["result"] = _coerce_result(out.get("result"))
|
| 43 |
+
if out.get("user_id") is not None:
|
| 44 |
+
out["user_id"] = str(out["user_id"])
|
| 45 |
+
if out.get("session_id") is not None:
|
| 46 |
+
out["session_id"] = str(out["session_id"])
|
| 47 |
+
return out
|
app/llm_client.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import asyncio
|
| 4 |
+
import logging
|
| 5 |
+
from openai import AsyncOpenAI
|
| 6 |
+
from typing import List, Dict, Any, Optional
|
| 7 |
+
from app.url_utils import openai_compatible_api_key, sanitize_env
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
class MultiLayerLLMClient:
|
| 12 |
+
def __init__(self):
|
| 13 |
+
# 1. Models sequence loading
|
| 14 |
+
self.models = []
|
| 15 |
+
for i in range(1, 4):
|
| 16 |
+
model = os.getenv(f"OPENROUTER_MODEL_{i}")
|
| 17 |
+
if model:
|
| 18 |
+
self.models.append(model)
|
| 19 |
+
|
| 20 |
+
# Fallback to legacy OPENROUTER_MODEL if no numbered models found
|
| 21 |
+
if not self.models:
|
| 22 |
+
legacy_model = os.getenv("OPENROUTER_MODEL", "google/gemini-2.0-flash-001")
|
| 23 |
+
self.models = [legacy_model]
|
| 24 |
+
|
| 25 |
+
# 2. Key selection (No rotation, always use the first available key)
|
| 26 |
+
api_key = os.getenv("OPENROUTER_API_KEY_1") or os.getenv("OPENROUTER_API_KEY")
|
| 27 |
+
|
| 28 |
+
if not api_key:
|
| 29 |
+
logger.error("[LLM] No OpenRouter API key found.")
|
| 30 |
+
self.client = None
|
| 31 |
+
else:
|
| 32 |
+
self.client = AsyncOpenAI(
|
| 33 |
+
api_key=openai_compatible_api_key(api_key),
|
| 34 |
+
base_url="https://openrouter.ai/api/v1",
|
| 35 |
+
timeout=60.0,
|
| 36 |
+
default_headers={
|
| 37 |
+
"HTTP-Referer": "https://mathsolver.ai",
|
| 38 |
+
"X-Title": "MathSolver Backend",
|
| 39 |
+
}
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
async def chat_completions_create(
|
| 43 |
+
self,
|
| 44 |
+
messages: List[Dict[str, str]],
|
| 45 |
+
response_format: Optional[Dict[str, str]] = None,
|
| 46 |
+
**kwargs
|
| 47 |
+
) -> str:
|
| 48 |
+
"""
|
| 49 |
+
Implements Model Fallback Sequence: Model 1 -> Model 2 -> Model 3.
|
| 50 |
+
Always starts from Model 1 for every new call.
|
| 51 |
+
"""
|
| 52 |
+
if not self.client:
|
| 53 |
+
raise ValueError("No API client configured. Check your API keys.")
|
| 54 |
+
|
| 55 |
+
MAX_ATTEMPTS = len(self.models)
|
| 56 |
+
RETRY_DELAY = 1.0 # second
|
| 57 |
+
|
| 58 |
+
for attempt_idx in range(MAX_ATTEMPTS):
|
| 59 |
+
current_model = self.models[attempt_idx]
|
| 60 |
+
attempt_num = attempt_idx + 1
|
| 61 |
+
|
| 62 |
+
try:
|
| 63 |
+
logger.info(f"[LLM] Attempt {attempt_num}/{MAX_ATTEMPTS} using Model: {current_model}...")
|
| 64 |
+
|
| 65 |
+
response = await self.client.chat.completions.create(
|
| 66 |
+
model=current_model,
|
| 67 |
+
messages=messages,
|
| 68 |
+
response_format=response_format,
|
| 69 |
+
**kwargs
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
if not response or not getattr(response, "choices", None):
|
| 73 |
+
raise ValueError(f"Invalid response structure from model {current_model}")
|
| 74 |
+
|
| 75 |
+
content = response.choices[0].message.content
|
| 76 |
+
if content:
|
| 77 |
+
logger.info(f"[LLM] SUCCESS on attempt {attempt_num} ({current_model}).")
|
| 78 |
+
return content
|
| 79 |
+
|
| 80 |
+
raise ValueError(f"Empty content from model {current_model}")
|
| 81 |
+
|
| 82 |
+
except Exception as e:
|
| 83 |
+
err_msg = f"{type(e).__name__}: {str(e)}"
|
| 84 |
+
logger.warning(f"[LLM] FAILED on attempt {attempt_num} ({current_model}): {err_msg}")
|
| 85 |
+
|
| 86 |
+
if attempt_num < MAX_ATTEMPTS:
|
| 87 |
+
logger.info(f"[LLM] Retrying next model in {RETRY_DELAY}s...")
|
| 88 |
+
await asyncio.sleep(RETRY_DELAY)
|
| 89 |
+
else:
|
| 90 |
+
logger.error(f"[LLM] FINAL FAILURE after {attempt_num} models.")
|
| 91 |
+
raise e
|
| 92 |
+
|
| 93 |
+
# Global instance for easy reuse (singleton-ish)
|
| 94 |
+
_llm_client = None
|
| 95 |
+
|
| 96 |
+
def get_llm_client() -> MultiLayerLLMClient:
|
| 97 |
+
global _llm_client
|
| 98 |
+
if _llm_client is None:
|
| 99 |
+
_llm_client = MultiLayerLLMClient()
|
| 100 |
+
return _llm_client
|
app/logging_setup.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Logging theo một biến LOG_LEVEL: debug | info | warning | error."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
from typing import Final
|
| 8 |
+
|
| 9 |
+
_SETUP_DONE = False
|
| 10 |
+
|
| 11 |
+
PIPELINE_LOGGER_NAME: Final = "app.pipeline"
|
| 12 |
+
CACHE_LOGGER_NAME: Final = "app.cache"
|
| 13 |
+
STEPS_LOGGER_NAME: Final = "app.steps"
|
| 14 |
+
ACCESS_LOGGER_NAME: Final = "app.access"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _normalize_level() -> str:
|
| 18 |
+
raw = os.getenv("LOG_LEVEL", "info").strip().lower()
|
| 19 |
+
if raw in ("debug", "info", "warning", "error"):
|
| 20 |
+
return raw
|
| 21 |
+
return "info"
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def setup_application_logging() -> None:
|
| 25 |
+
"""Idempotent; gọi khi khởi động process (uvicorn, celery, worker_health)."""
|
| 26 |
+
global _SETUP_DONE
|
| 27 |
+
if _SETUP_DONE:
|
| 28 |
+
return
|
| 29 |
+
_SETUP_DONE = True
|
| 30 |
+
|
| 31 |
+
mode = _normalize_level()
|
| 32 |
+
|
| 33 |
+
level_map = {
|
| 34 |
+
"debug": logging.DEBUG,
|
| 35 |
+
"info": logging.INFO,
|
| 36 |
+
"warning": logging.WARNING,
|
| 37 |
+
"error": logging.ERROR,
|
| 38 |
+
}
|
| 39 |
+
root_level = level_map[mode]
|
| 40 |
+
|
| 41 |
+
fmt_named = "%(asctime)s | %(levelname)-8s | %(name)s | %(message)s"
|
| 42 |
+
fmt_short = "%(asctime)s | %(levelname)-8s | %(message)s"
|
| 43 |
+
|
| 44 |
+
logging.basicConfig(
|
| 45 |
+
level=root_level,
|
| 46 |
+
format=fmt_named if mode == "debug" else fmt_short,
|
| 47 |
+
datefmt="%H:%M:%S",
|
| 48 |
+
force=True,
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
logging.getLogger("httpx").setLevel(logging.WARNING)
|
| 52 |
+
logging.getLogger("httpcore").setLevel(logging.WARNING)
|
| 53 |
+
logging.getLogger("openai").setLevel(logging.WARNING)
|
| 54 |
+
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
|
| 55 |
+
logging.getLogger("uvicorn.error").setLevel(logging.INFO)
|
| 56 |
+
# HTTP/2 stack (httpx/httpcore) — khi LOG_LEVEL=debug root=DEBUG sẽ tràn log hpack; không cần cho debug app
|
| 57 |
+
for _name in ("hpack", "h2", "hyperframe", "urllib3"):
|
| 58 |
+
logging.getLogger(_name).setLevel(logging.WARNING)
|
| 59 |
+
|
| 60 |
+
if mode == "debug":
|
| 61 |
+
logging.getLogger("agents").setLevel(logging.DEBUG)
|
| 62 |
+
logging.getLogger("solver").setLevel(logging.DEBUG)
|
| 63 |
+
logging.getLogger("app").setLevel(logging.DEBUG)
|
| 64 |
+
logging.getLogger(CACHE_LOGGER_NAME).setLevel(logging.DEBUG)
|
| 65 |
+
logging.getLogger(STEPS_LOGGER_NAME).setLevel(logging.DEBUG)
|
| 66 |
+
logging.getLogger(PIPELINE_LOGGER_NAME).setLevel(logging.INFO)
|
| 67 |
+
logging.getLogger(ACCESS_LOGGER_NAME).setLevel(logging.INFO)
|
| 68 |
+
logging.getLogger("app.main").setLevel(logging.INFO)
|
| 69 |
+
logging.getLogger("worker").setLevel(logging.INFO)
|
| 70 |
+
elif mode == "info":
|
| 71 |
+
# Chỉ HTTP access (app.access) + startup; ẩn chi tiết agents/orchestrator/pipeline SUCCESS
|
| 72 |
+
logging.getLogger("agents").setLevel(logging.INFO)
|
| 73 |
+
logging.getLogger("solver").setLevel(logging.WARNING)
|
| 74 |
+
logging.getLogger("app").setLevel(logging.INFO)
|
| 75 |
+
logging.getLogger(CACHE_LOGGER_NAME).setLevel(logging.WARNING)
|
| 76 |
+
logging.getLogger(STEPS_LOGGER_NAME).setLevel(logging.WARNING)
|
| 77 |
+
logging.getLogger(PIPELINE_LOGGER_NAME).setLevel(logging.WARNING)
|
| 78 |
+
logging.getLogger(ACCESS_LOGGER_NAME).setLevel(logging.INFO)
|
| 79 |
+
logging.getLogger("app.main").setLevel(logging.INFO)
|
| 80 |
+
logging.getLogger("worker").setLevel(logging.WARNING)
|
| 81 |
+
elif mode == "warning":
|
| 82 |
+
logging.getLogger("agents").setLevel(logging.WARNING)
|
| 83 |
+
logging.getLogger("solver").setLevel(logging.WARNING)
|
| 84 |
+
logging.getLogger("app.routers").setLevel(logging.WARNING)
|
| 85 |
+
logging.getLogger(CACHE_LOGGER_NAME).setLevel(logging.WARNING)
|
| 86 |
+
logging.getLogger(STEPS_LOGGER_NAME).setLevel(logging.WARNING)
|
| 87 |
+
logging.getLogger(PIPELINE_LOGGER_NAME).setLevel(logging.WARNING)
|
| 88 |
+
logging.getLogger(ACCESS_LOGGER_NAME).setLevel(logging.WARNING)
|
| 89 |
+
logging.getLogger("app.main").setLevel(logging.WARNING)
|
| 90 |
+
logging.getLogger("worker").setLevel(logging.WARNING)
|
| 91 |
+
else: # error
|
| 92 |
+
logging.getLogger("agents").setLevel(logging.ERROR)
|
| 93 |
+
logging.getLogger("solver").setLevel(logging.ERROR)
|
| 94 |
+
logging.getLogger("app.routers").setLevel(logging.ERROR)
|
| 95 |
+
logging.getLogger(CACHE_LOGGER_NAME).setLevel(logging.ERROR)
|
| 96 |
+
logging.getLogger(STEPS_LOGGER_NAME).setLevel(logging.ERROR)
|
| 97 |
+
logging.getLogger(PIPELINE_LOGGER_NAME).setLevel(logging.ERROR)
|
| 98 |
+
logging.getLogger(ACCESS_LOGGER_NAME).setLevel(logging.ERROR)
|
| 99 |
+
logging.getLogger("app.main").setLevel(logging.ERROR)
|
| 100 |
+
logging.getLogger("worker").setLevel(logging.ERROR)
|
| 101 |
+
|
| 102 |
+
logging.getLogger(__name__).debug(
|
| 103 |
+
"LOG_LEVEL=%s root=%s", mode, logging.getLevelName(root_level)
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def get_log_level() -> str:
|
| 108 |
+
return _normalize_level()
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def is_debug_level() -> bool:
|
| 112 |
+
return _normalize_level() == "debug"
|
app/logutil.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""log_step (debug), pipeline (debug), access log ở middleware."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
from typing import Any
|
| 9 |
+
|
| 10 |
+
from app.logging_setup import PIPELINE_LOGGER_NAME, STEPS_LOGGER_NAME
|
| 11 |
+
|
| 12 |
+
_pipeline = logging.getLogger(PIPELINE_LOGGER_NAME)
|
| 13 |
+
_steps = logging.getLogger(STEPS_LOGGER_NAME)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def is_debug_mode() -> bool:
|
| 17 |
+
"""Chi tiết từng bước chỉ khi LOG_LEVEL=debug."""
|
| 18 |
+
return os.getenv("LOG_LEVEL", "info").strip().lower() == "debug"
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _truncate(val: Any, max_len: int = 2000) -> Any:
|
| 22 |
+
if val is None:
|
| 23 |
+
return None
|
| 24 |
+
if isinstance(val, (int, float, bool)):
|
| 25 |
+
return val
|
| 26 |
+
s = str(val)
|
| 27 |
+
if len(s) > max_len:
|
| 28 |
+
return s[:max_len] + f"... (+{len(s) - max_len} chars)"
|
| 29 |
+
return s
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def log_step(step: str, **fields: Any) -> None:
|
| 33 |
+
"""Chỉ khi LOG_LEVEL=debug: DB / cache / orchestrator."""
|
| 34 |
+
if not is_debug_mode():
|
| 35 |
+
return
|
| 36 |
+
safe = {k: _truncate(v) for k, v in fields.items()}
|
| 37 |
+
try:
|
| 38 |
+
payload = json.dumps(safe, ensure_ascii=False, default=str)
|
| 39 |
+
except Exception:
|
| 40 |
+
payload = str(safe)
|
| 41 |
+
_steps.debug("[step:%s] %s", step, payload)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def log_pipeline_success(operation: str, **fields: Any) -> None:
|
| 45 |
+
"""Chỉ hiện khi debug (pipeline SUCCESS không dùng ở info — đã có app.access)."""
|
| 46 |
+
if not is_debug_mode():
|
| 47 |
+
return
|
| 48 |
+
safe = {k: _truncate(v, 500) for k, v in fields.items()}
|
| 49 |
+
_pipeline.info(
|
| 50 |
+
"SUCCESS %s %s",
|
| 51 |
+
operation,
|
| 52 |
+
json.dumps(safe, ensure_ascii=False, default=str),
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def log_pipeline_failure(operation: str, error: str | None = None, **fields: Any) -> None:
|
| 57 |
+
"""Thất bại pipeline: luôn dùng WARNING để vẫn thấy khi LOG_LEVEL=warning."""
|
| 58 |
+
if is_debug_mode():
|
| 59 |
+
safe = {k: _truncate(v, 500) for k, v in fields.items()}
|
| 60 |
+
_pipeline.warning(
|
| 61 |
+
"FAIL %s err=%s %s",
|
| 62 |
+
operation,
|
| 63 |
+
_truncate(error, 300),
|
| 64 |
+
json.dumps(safe, ensure_ascii=False, default=str),
|
| 65 |
+
)
|
| 66 |
+
else:
|
| 67 |
+
_pipeline.warning("FAIL %s", operation)
|
app/main.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import time
|
| 6 |
+
import uuid
|
| 7 |
+
import warnings
|
| 8 |
+
|
| 9 |
+
from dotenv import load_dotenv
|
| 10 |
+
from fastapi import Depends, FastAPI, File, HTTPException, UploadFile
|
| 11 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 12 |
+
from starlette.requests import Request
|
| 13 |
+
|
| 14 |
+
load_dotenv()
|
| 15 |
+
|
| 16 |
+
from app.runtime_env import apply_runtime_env_defaults
|
| 17 |
+
|
| 18 |
+
apply_runtime_env_defaults()
|
| 19 |
+
|
| 20 |
+
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"
|
| 21 |
+
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")
|
| 22 |
+
warnings.filterwarnings("ignore", category=UserWarning, module="albumentations")
|
| 23 |
+
|
| 24 |
+
from app.logging_setup import ACCESS_LOGGER_NAME, get_log_level, setup_application_logging
|
| 25 |
+
|
| 26 |
+
setup_application_logging()
|
| 27 |
+
|
| 28 |
+
# Routers (after logging)
|
| 29 |
+
from app.dependencies import get_current_user_id
|
| 30 |
+
from app.ocr_local_file import ocr_from_local_image_path
|
| 31 |
+
from app.routers import auth, sessions, solve
|
| 32 |
+
from agents.ocr_agent import OCRAgent
|
| 33 |
+
from app.routers.solve import get_orchestrator
|
| 34 |
+
from app.job_poll import normalize_job_row_for_client
|
| 35 |
+
from app.supabase_client import get_supabase
|
| 36 |
+
from app.websocket_manager import register_websocket_routes
|
| 37 |
+
|
| 38 |
+
logger = logging.getLogger("app.main")
|
| 39 |
+
_access = logging.getLogger(ACCESS_LOGGER_NAME)
|
| 40 |
+
|
| 41 |
+
app = FastAPI(title="Visual Math Solver API v5.1")
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@app.middleware("http")
|
| 45 |
+
async def access_log_middleware(request: Request, call_next):
|
| 46 |
+
"""LOG_LEVEL=info/debug: mọi request; warning: chỉ 4xx/5xx; error: chỉ 4xx/5xx ở mức error."""
|
| 47 |
+
start = time.perf_counter()
|
| 48 |
+
response = await call_next(request)
|
| 49 |
+
ms = (time.perf_counter() - start) * 1000
|
| 50 |
+
mode = get_log_level()
|
| 51 |
+
method = request.method
|
| 52 |
+
path = request.url.path
|
| 53 |
+
status = response.status_code
|
| 54 |
+
|
| 55 |
+
if mode in ("debug", "info"):
|
| 56 |
+
_access.info("%s %s -> %s (%.0fms)", method, path, status, ms)
|
| 57 |
+
elif mode == "warning":
|
| 58 |
+
if status >= 500:
|
| 59 |
+
_access.error("%s %s -> %s (%.0fms)", method, path, status, ms)
|
| 60 |
+
elif status >= 400:
|
| 61 |
+
_access.warning("%s %s -> %s (%.0fms)", method, path, status, ms)
|
| 62 |
+
elif mode == "error":
|
| 63 |
+
if status >= 400:
|
| 64 |
+
_access.error("%s %s -> %s", method, path, status)
|
| 65 |
+
|
| 66 |
+
return response
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
from worker.celery_app import BROKER_URL
|
| 70 |
+
|
| 71 |
+
_broker_tail = BROKER_URL.split("@")[-1] if "@" in BROKER_URL else BROKER_URL
|
| 72 |
+
if get_log_level() in ("debug", "info"):
|
| 73 |
+
logger.info("App starting LOG_LEVEL=%s | Redis: %s", get_log_level(), _broker_tail)
|
| 74 |
+
else:
|
| 75 |
+
logger.warning(
|
| 76 |
+
"App starting LOG_LEVEL=%s | Redis: %s", get_log_level(), _broker_tail
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
app.add_middleware(
|
| 80 |
+
CORSMiddleware,
|
| 81 |
+
allow_origins=[
|
| 82 |
+
"http://localhost:3000",
|
| 83 |
+
"http://127.0.0.1:3000",
|
| 84 |
+
"http://localhost:3005",
|
| 85 |
+
],
|
| 86 |
+
allow_credentials=True,
|
| 87 |
+
allow_methods=["*"],
|
| 88 |
+
allow_headers=["*"],
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
app.include_router(auth.router)
|
| 92 |
+
app.include_router(sessions.router)
|
| 93 |
+
app.include_router(solve.router)
|
| 94 |
+
|
| 95 |
+
register_websocket_routes(app)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def get_ocr_agent() -> OCRAgent:
|
| 99 |
+
"""Same OCR instance as the solve pipeline (no duplicate model load)."""
|
| 100 |
+
return get_orchestrator().ocr_agent
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
supabase_client = get_supabase()
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
@app.get("/")
|
| 107 |
+
def read_root():
|
| 108 |
+
return {"message": "Visual Math Solver API v5.1 is running", "version": "5.1"}
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
@app.post("/api/v1/ocr")
|
| 112 |
+
async def upload_ocr(
|
| 113 |
+
file: UploadFile = File(...),
|
| 114 |
+
_user_id=Depends(get_current_user_id),
|
| 115 |
+
):
|
| 116 |
+
"""OCR upload: requires authenticated user."""
|
| 117 |
+
temp_path = f"temp_{uuid.uuid4()}.png"
|
| 118 |
+
with open(temp_path, "wb") as buffer:
|
| 119 |
+
buffer.write(await file.read())
|
| 120 |
+
|
| 121 |
+
try:
|
| 122 |
+
text = await ocr_from_local_image_path(temp_path, file.filename, get_ocr_agent())
|
| 123 |
+
return {"text": text}
|
| 124 |
+
finally:
|
| 125 |
+
if os.path.exists(temp_path):
|
| 126 |
+
os.remove(temp_path)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
@app.get("/api/v1/solve/{job_id}")
|
| 130 |
+
async def get_job_status(
|
| 131 |
+
job_id: str,
|
| 132 |
+
user_id=Depends(get_current_user_id),
|
| 133 |
+
):
|
| 134 |
+
"""Retrieve job status (can be used for polling if WS fails). Owner-only."""
|
| 135 |
+
response = supabase_client.table("jobs").select("*").eq("id", job_id).execute()
|
| 136 |
+
if not response.data:
|
| 137 |
+
raise HTTPException(status_code=404, detail="Job not found")
|
| 138 |
+
job = response.data[0]
|
| 139 |
+
if job.get("user_id") is not None and str(job["user_id"]) != str(user_id):
|
| 140 |
+
raise HTTPException(status_code=403, detail="Forbidden: You do not own this job.")
|
| 141 |
+
# Stable contract for FE poll (job_id alias, parsed result JSON, string UUIDs)
|
| 142 |
+
return normalize_job_row_for_client(job)
|
app/models/schemas.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel, EmailStr, field_validator
|
| 2 |
+
from typing import Optional, List, Any, Dict
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
import uuid
|
| 5 |
+
|
| 6 |
+
from app.url_utils import sanitize_url
|
| 7 |
+
|
| 8 |
+
# --- Auth Schemas ---
|
| 9 |
+
class UserProfile(BaseModel):
|
| 10 |
+
id: uuid.UUID
|
| 11 |
+
display_name: Optional[str] = None
|
| 12 |
+
avatar_url: Optional[str] = None
|
| 13 |
+
created_at: datetime
|
| 14 |
+
|
| 15 |
+
class User(BaseModel):
|
| 16 |
+
id: uuid.UUID
|
| 17 |
+
email: EmailStr
|
| 18 |
+
|
| 19 |
+
# --- Session Schemas ---
|
| 20 |
+
class SessionBase(BaseModel):
|
| 21 |
+
title: str = "Bài toán mới"
|
| 22 |
+
|
| 23 |
+
class SessionCreate(SessionBase):
|
| 24 |
+
pass
|
| 25 |
+
|
| 26 |
+
class Session(SessionBase):
|
| 27 |
+
id: uuid.UUID
|
| 28 |
+
user_id: uuid.UUID
|
| 29 |
+
created_at: datetime
|
| 30 |
+
updated_at: datetime
|
| 31 |
+
|
| 32 |
+
class Config:
|
| 33 |
+
from_attributes = True
|
| 34 |
+
|
| 35 |
+
# --- Message Schemas ---
|
| 36 |
+
class MessageBase(BaseModel):
|
| 37 |
+
role: str
|
| 38 |
+
type: str = "text"
|
| 39 |
+
content: str
|
| 40 |
+
metadata: Dict[str, Any] = {}
|
| 41 |
+
|
| 42 |
+
class MessageCreate(MessageBase):
|
| 43 |
+
session_id: uuid.UUID
|
| 44 |
+
|
| 45 |
+
class Message(MessageBase):
|
| 46 |
+
id: uuid.UUID
|
| 47 |
+
session_id: uuid.UUID
|
| 48 |
+
created_at: datetime
|
| 49 |
+
|
| 50 |
+
class Config:
|
| 51 |
+
from_attributes = True
|
| 52 |
+
|
| 53 |
+
# --- Solve Job Schemas ---
|
| 54 |
+
class SolveRequest(BaseModel):
|
| 55 |
+
text: str
|
| 56 |
+
image_url: Optional[str] = None
|
| 57 |
+
|
| 58 |
+
@field_validator("image_url", mode="before")
|
| 59 |
+
@classmethod
|
| 60 |
+
def _clean_image_url(cls, v):
|
| 61 |
+
return sanitize_url(v) if v is not None else None
|
| 62 |
+
|
| 63 |
+
class SolveResponse(BaseModel):
|
| 64 |
+
job_id: str
|
| 65 |
+
status: str
|
| 66 |
+
|
| 67 |
+
class RenderVideoRequest(BaseModel):
|
| 68 |
+
job_id: Optional[str] = None
|
| 69 |
+
|
| 70 |
+
class RenderVideoResponse(BaseModel):
|
| 71 |
+
job_id: str
|
| 72 |
+
status: str
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class OcrPreviewResponse(BaseModel):
|
| 76 |
+
"""Stateless OCR preview before POST .../solve (no DB writes, no job)."""
|
| 77 |
+
|
| 78 |
+
ocr_text: str
|
| 79 |
+
user_message: str = ""
|
| 80 |
+
combined_draft: str
|
app/ocr_celery.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Run OCR on a remote worker via Celery (queue `ocr`) when OCR_USE_CELERY is enabled."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
from typing import TYPE_CHECKING
|
| 8 |
+
|
| 9 |
+
import anyio
|
| 10 |
+
|
| 11 |
+
if TYPE_CHECKING:
|
| 12 |
+
from agents.ocr_agent import OCRAgent
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def ocr_celery_enabled() -> bool:
|
| 18 |
+
return os.getenv("OCR_USE_CELERY", "").strip().lower() in ("1", "true", "yes", "on")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _ocr_timeout_sec() -> float:
|
| 22 |
+
raw = os.getenv("OCR_CELERY_TIMEOUT_SEC", "180")
|
| 23 |
+
try:
|
| 24 |
+
return max(30.0, float(raw))
|
| 25 |
+
except ValueError:
|
| 26 |
+
return 180.0
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _run_ocr_celery_sync(image_url: str) -> str:
|
| 30 |
+
from worker.ocr_tasks import run_ocr_from_url
|
| 31 |
+
|
| 32 |
+
async_result = run_ocr_from_url.apply_async(args=[image_url])
|
| 33 |
+
return async_result.get(timeout=_ocr_timeout_sec())
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _is_ocr_error_response(text: str) -> bool:
|
| 37 |
+
s = (text or "").lstrip()
|
| 38 |
+
return s.startswith("Error:")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
async def ocr_from_image_url(image_url: str, fallback_agent: "OCRAgent") -> str:
|
| 42 |
+
"""
|
| 43 |
+
If OCR_USE_CELERY: delegate to Celery task `run_ocr_from_url` (worker queue `ocr`, raw OCR only),
|
| 44 |
+
then run ``refine_with_llm`` on the API process.
|
| 45 |
+
Else: use fallback_agent.process_url (in-process full pipeline).
|
| 46 |
+
"""
|
| 47 |
+
if not ocr_celery_enabled():
|
| 48 |
+
return await fallback_agent.process_url(image_url)
|
| 49 |
+
logger.info("OCR_USE_CELERY: delegating OCR to Celery queue=ocr (LLM refine on API)")
|
| 50 |
+
raw = await anyio.to_thread.run_sync(_run_ocr_celery_sync, image_url)
|
| 51 |
+
raw = raw if raw is not None else ""
|
| 52 |
+
if not raw.strip() or _is_ocr_error_response(raw):
|
| 53 |
+
return raw
|
| 54 |
+
return await fallback_agent.refine_with_llm(raw)
|
app/ocr_local_file.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""OCR from a local file path, optionally via Celery worker (upload temp blob first)."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
from typing import TYPE_CHECKING
|
| 8 |
+
|
| 9 |
+
from app.chat_image_upload import (
|
| 10 |
+
delete_storage_object,
|
| 11 |
+
upload_ephemeral_ocr_blob,
|
| 12 |
+
validate_chat_image_bytes,
|
| 13 |
+
)
|
| 14 |
+
from app.ocr_celery import ocr_celery_enabled, ocr_from_image_url
|
| 15 |
+
|
| 16 |
+
if TYPE_CHECKING:
|
| 17 |
+
from agents.ocr_agent import OCRAgent
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
async def ocr_from_local_image_path(
|
| 23 |
+
local_path: str,
|
| 24 |
+
original_filename: str | None,
|
| 25 |
+
fallback_agent: "OCRAgent",
|
| 26 |
+
) -> str:
|
| 27 |
+
"""
|
| 28 |
+
Run OCR on a file on local disk. If OCR_USE_Celery, upload to ephemeral storage URL
|
| 29 |
+
then delegate to worker; otherwise process_image in-process.
|
| 30 |
+
"""
|
| 31 |
+
if not ocr_celery_enabled():
|
| 32 |
+
return await fallback_agent.process_image(local_path)
|
| 33 |
+
|
| 34 |
+
with open(local_path, "rb") as f:
|
| 35 |
+
body = f.read()
|
| 36 |
+
ext = os.path.splitext(original_filename or local_path)[1].lower() or ".png"
|
| 37 |
+
_, content_type = validate_chat_image_bytes(original_filename or local_path, body, None)
|
| 38 |
+
bucket = os.getenv("SUPABASE_IMAGE_BUCKET", "image")
|
| 39 |
+
path, url = upload_ephemeral_ocr_blob(body, ext, content_type)
|
| 40 |
+
try:
|
| 41 |
+
return await ocr_from_image_url(url, fallback_agent)
|
| 42 |
+
finally:
|
| 43 |
+
delete_storage_object(bucket, path)
|
app/ocr_text_merge.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Helpers for OCR preview combined draft (no Pydantic email deps)."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def build_combined_ocr_preview_draft(user_message: Optional[str], ocr_text: str) -> str:
|
| 9 |
+
"""Merge user caption and OCR text for confirm step (user message first, then OCR)."""
|
| 10 |
+
u = (user_message or "").strip()
|
| 11 |
+
o = (ocr_text or "").strip()
|
| 12 |
+
if u and o:
|
| 13 |
+
return f"{u}\n\n{o}"
|
| 14 |
+
return u or o
|
app/routers/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from . import auth, sessions, solve
|
app/routers/auth.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, Depends, HTTPException
|
| 2 |
+
from app.dependencies import get_current_user_id
|
| 3 |
+
from app.supabase_client import get_supabase
|
| 4 |
+
from app.models.schemas import UserProfile
|
| 5 |
+
import uuid
|
| 6 |
+
|
| 7 |
+
router = APIRouter(prefix="/api/v1/auth", tags=["Auth"])
|
| 8 |
+
|
| 9 |
+
@router.get("/me")
|
| 10 |
+
async def get_me(user_id=Depends(get_current_user_id)):
|
| 11 |
+
"""获取当前登录用户的信息 (Retrieve current user profile)"""
|
| 12 |
+
supabase = get_supabase()
|
| 13 |
+
res = supabase.table("profiles").select("*").eq("id", user_id).execute()
|
| 14 |
+
if not res.data:
|
| 15 |
+
raise HTTPException(status_code=404, detail="Profile not found.")
|
| 16 |
+
return res.data[0]
|
| 17 |
+
|
| 18 |
+
@router.patch("/me")
|
| 19 |
+
async def update_me(data: dict, user_id=Depends(get_current_user_id)):
|
| 20 |
+
"""Cập nhật profile hiện tại (Update current profile)"""
|
| 21 |
+
supabase = get_supabase()
|
| 22 |
+
res = supabase.table("profiles").update(data).eq("id", user_id).execute()
|
| 23 |
+
return res.data[0]
|
app/routers/sessions.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import time
|
| 5 |
+
from typing import List
|
| 6 |
+
|
| 7 |
+
from fastapi import APIRouter, Depends, HTTPException
|
| 8 |
+
|
| 9 |
+
from app.dependencies import get_current_user_id
|
| 10 |
+
from app.logutil import log_step
|
| 11 |
+
from app.session_cache import (
|
| 12 |
+
get_sessions_list_cached,
|
| 13 |
+
invalidate_for_user,
|
| 14 |
+
invalidate_session_owner,
|
| 15 |
+
session_owned_by_user,
|
| 16 |
+
)
|
| 17 |
+
from app.supabase_client import get_supabase
|
| 18 |
+
|
| 19 |
+
router = APIRouter(prefix="/api/v1/sessions", tags=["Sessions"])
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@router.get("", response_model=List[dict])
|
| 24 |
+
async def list_sessions(user_id=Depends(get_current_user_id)):
|
| 25 |
+
"""Danh sách các phiên chat của người dùng (List user's chat sessions)"""
|
| 26 |
+
supabase = get_supabase()
|
| 27 |
+
t0 = time.perf_counter()
|
| 28 |
+
|
| 29 |
+
def fetch() -> list:
|
| 30 |
+
res = (
|
| 31 |
+
supabase.table("sessions")
|
| 32 |
+
.select("id, user_id, title, created_at, updated_at")
|
| 33 |
+
.eq("user_id", user_id)
|
| 34 |
+
.order("updated_at", desc=True)
|
| 35 |
+
.execute()
|
| 36 |
+
)
|
| 37 |
+
log_step("db_select", table="sessions", op="list", user_id=str(user_id))
|
| 38 |
+
return res.data
|
| 39 |
+
|
| 40 |
+
out = get_sessions_list_cached(str(user_id), fetch)
|
| 41 |
+
logger.info(
|
| 42 |
+
"sessions.list user=%s count=%d %.1fms",
|
| 43 |
+
user_id,
|
| 44 |
+
len(out),
|
| 45 |
+
(time.perf_counter() - t0) * 1000,
|
| 46 |
+
)
|
| 47 |
+
return out
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@router.post("", response_model=dict)
|
| 51 |
+
async def create_session(user_id=Depends(get_current_user_id)):
|
| 52 |
+
"""Tạo một phiên chat mới (Create a new chat session)"""
|
| 53 |
+
supabase = get_supabase()
|
| 54 |
+
t0 = time.perf_counter()
|
| 55 |
+
res = supabase.table("sessions").insert(
|
| 56 |
+
{"user_id": user_id, "title": "Bài toán mới"}
|
| 57 |
+
).execute()
|
| 58 |
+
log_step("db_insert", table="sessions", op="create")
|
| 59 |
+
invalidate_for_user(str(user_id))
|
| 60 |
+
row = res.data[0]
|
| 61 |
+
logger.info(
|
| 62 |
+
"sessions.create user=%s id=%s %.1fms",
|
| 63 |
+
user_id,
|
| 64 |
+
row.get("id"),
|
| 65 |
+
(time.perf_counter() - t0) * 1000,
|
| 66 |
+
)
|
| 67 |
+
return row
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
@router.get("/{session_id}/messages", response_model=List[dict])
|
| 71 |
+
async def get_session_messages(session_id: str, user_id=Depends(get_current_user_id)):
|
| 72 |
+
"""Lấy toàn bộ lịch sử tin nhắn của một phiên (Get chat history for a session)"""
|
| 73 |
+
supabase = get_supabase()
|
| 74 |
+
|
| 75 |
+
def owns() -> bool:
|
| 76 |
+
res = (
|
| 77 |
+
supabase.table("sessions")
|
| 78 |
+
.select("id")
|
| 79 |
+
.eq("id", session_id)
|
| 80 |
+
.eq("user_id", user_id)
|
| 81 |
+
.execute()
|
| 82 |
+
)
|
| 83 |
+
log_step("db_select", table="sessions", op="owner_check", session_id=session_id)
|
| 84 |
+
return bool(res.data)
|
| 85 |
+
|
| 86 |
+
if not session_owned_by_user(session_id, str(user_id), owns):
|
| 87 |
+
raise HTTPException(
|
| 88 |
+
status_code=403, detail="Forbidden: You do not own this session."
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
res = (
|
| 92 |
+
supabase.table("messages")
|
| 93 |
+
.select("*")
|
| 94 |
+
.eq("session_id", session_id)
|
| 95 |
+
.order("created_at", desc=False)
|
| 96 |
+
.execute()
|
| 97 |
+
)
|
| 98 |
+
log_step("db_select", table="messages", op="list", session_id=session_id)
|
| 99 |
+
return res.data
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
@router.delete("/{session_id}")
|
| 103 |
+
async def delete_session(session_id: str, user_id=Depends(get_current_user_id)):
|
| 104 |
+
"""Xóa một phiên chat (Delete a chat session)"""
|
| 105 |
+
supabase = get_supabase()
|
| 106 |
+
|
| 107 |
+
def owns() -> bool:
|
| 108 |
+
res = (
|
| 109 |
+
supabase.table("sessions")
|
| 110 |
+
.select("id")
|
| 111 |
+
.eq("id", session_id)
|
| 112 |
+
.eq("user_id", user_id)
|
| 113 |
+
.execute()
|
| 114 |
+
)
|
| 115 |
+
return bool(res.data)
|
| 116 |
+
|
| 117 |
+
if not session_owned_by_user(session_id, str(user_id), owns):
|
| 118 |
+
raise HTTPException(
|
| 119 |
+
status_code=403, detail="Forbidden: You do not own this session."
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
# jobs.session_id FK must be cleared before sessions row
|
| 123 |
+
supabase.table("jobs").delete().eq("session_id", session_id).eq("user_id", user_id).execute()
|
| 124 |
+
log_step("db_delete", table="jobs", op="by_session", session_id=session_id)
|
| 125 |
+
supabase.table("messages").delete().eq("session_id", session_id).execute()
|
| 126 |
+
log_step("db_delete", table="messages", op="by_session", session_id=session_id)
|
| 127 |
+
res = (
|
| 128 |
+
supabase.table("sessions")
|
| 129 |
+
.delete()
|
| 130 |
+
.eq("id", session_id)
|
| 131 |
+
.eq("user_id", user_id)
|
| 132 |
+
.execute()
|
| 133 |
+
)
|
| 134 |
+
log_step("db_delete", table="sessions", session_id=session_id)
|
| 135 |
+
invalidate_for_user(str(user_id))
|
| 136 |
+
invalidate_session_owner(session_id, str(user_id))
|
| 137 |
+
return {"status": "ok", "deleted_id": session_id}
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
@router.patch("/{session_id}/title")
|
| 141 |
+
async def update_session_title(title: str, session_id: str, user_id=Depends(get_current_user_id)):
|
| 142 |
+
"""Cập nhật tiêu đề phiên chat (Rename a chat session)"""
|
| 143 |
+
supabase = get_supabase()
|
| 144 |
+
res = (
|
| 145 |
+
supabase.table("sessions")
|
| 146 |
+
.update({"title": title})
|
| 147 |
+
.eq("id", session_id)
|
| 148 |
+
.eq("user_id", user_id)
|
| 149 |
+
.execute()
|
| 150 |
+
)
|
| 151 |
+
log_step("db_update", table="sessions", op="title", session_id=session_id)
|
| 152 |
+
invalidate_for_user(str(user_id))
|
| 153 |
+
return res.data[0]
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
@router.get("/{session_id}/assets", response_model=List[dict])
|
| 157 |
+
async def get_session_assets(session_id: str, user_id=Depends(get_current_user_id)):
|
| 158 |
+
"""Lấy danh sách video đã render trong session (Get versioned assets for a session)"""
|
| 159 |
+
supabase = get_supabase()
|
| 160 |
+
|
| 161 |
+
def owns() -> bool:
|
| 162 |
+
res = (
|
| 163 |
+
supabase.table("sessions")
|
| 164 |
+
.select("id")
|
| 165 |
+
.eq("id", session_id)
|
| 166 |
+
.eq("user_id", user_id)
|
| 167 |
+
.execute()
|
| 168 |
+
)
|
| 169 |
+
return bool(res.data)
|
| 170 |
+
|
| 171 |
+
if not session_owned_by_user(session_id, str(user_id), owns):
|
| 172 |
+
raise HTTPException(
|
| 173 |
+
status_code=403, detail="Forbidden: You do not own this session."
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
res = (
|
| 177 |
+
supabase.table("session_assets")
|
| 178 |
+
.select("*")
|
| 179 |
+
.eq("session_id", session_id)
|
| 180 |
+
.order("version", desc=True)
|
| 181 |
+
.execute()
|
| 182 |
+
)
|
| 183 |
+
log_step("db_select", table="session_assets", op="list", session_id=session_id)
|
| 184 |
+
return res.data
|
app/routers/solve.py
ADDED
|
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import uuid
|
| 6 |
+
|
| 7 |
+
from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, HTTPException, UploadFile
|
| 8 |
+
|
| 9 |
+
from agents.orchestrator import Orchestrator
|
| 10 |
+
from app.chat_image_upload import upload_session_chat_image, validate_chat_image_bytes
|
| 11 |
+
from app.ocr_celery import ocr_celery_enabled
|
| 12 |
+
from app.ocr_local_file import ocr_from_local_image_path
|
| 13 |
+
from app.dependencies import get_current_user_id
|
| 14 |
+
from app.errors import format_error_for_user
|
| 15 |
+
from app.logutil import log_pipeline_failure, log_pipeline_success, log_step
|
| 16 |
+
from app.models.schemas import (
|
| 17 |
+
OcrPreviewResponse,
|
| 18 |
+
RenderVideoRequest,
|
| 19 |
+
RenderVideoResponse,
|
| 20 |
+
SolveRequest,
|
| 21 |
+
SolveResponse,
|
| 22 |
+
)
|
| 23 |
+
from app.ocr_text_merge import build_combined_ocr_preview_draft
|
| 24 |
+
from app.session_cache import invalidate_for_user, session_owned_by_user
|
| 25 |
+
from app.supabase_client import get_supabase
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
router = APIRouter(prefix="/api/v1/sessions", tags=["Solve"])
|
| 29 |
+
|
| 30 |
+
# Eager init: all agents and models load at import time (also run in Docker build via scripts/prewarm_models.py).
|
| 31 |
+
ORCHESTRATOR = Orchestrator()
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_orchestrator() -> Orchestrator:
|
| 35 |
+
return ORCHESTRATOR
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
_OCR_PREVIEW_MAX_BYTES = 10 * 1024 * 1024
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _assert_session_owner(supabase, session_id: str, user_id, uid: str, op: str) -> None:
|
| 42 |
+
def owns() -> bool:
|
| 43 |
+
res = (
|
| 44 |
+
supabase.table("sessions")
|
| 45 |
+
.select("id")
|
| 46 |
+
.eq("id", session_id)
|
| 47 |
+
.eq("user_id", user_id)
|
| 48 |
+
.execute()
|
| 49 |
+
)
|
| 50 |
+
log_step("db_select", table="sessions", op=op, session_id=session_id)
|
| 51 |
+
return bool(res.data)
|
| 52 |
+
|
| 53 |
+
if not session_owned_by_user(session_id, uid, owns):
|
| 54 |
+
log_pipeline_failure("solve_request", error="forbidden", session_id=session_id)
|
| 55 |
+
raise HTTPException(
|
| 56 |
+
status_code=403, detail="Forbidden: You do not own this session."
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _enqueue_solve_common(
|
| 61 |
+
supabase,
|
| 62 |
+
background_tasks: BackgroundTasks,
|
| 63 |
+
session_id: str,
|
| 64 |
+
user_id,
|
| 65 |
+
uid: str,
|
| 66 |
+
request: SolveRequest,
|
| 67 |
+
message_metadata: dict,
|
| 68 |
+
job_id: str,
|
| 69 |
+
) -> SolveResponse:
|
| 70 |
+
"""Insert user message, job row, enqueue pipeline; update title when first message."""
|
| 71 |
+
supabase.table("messages").insert(
|
| 72 |
+
{
|
| 73 |
+
"session_id": session_id,
|
| 74 |
+
"role": "user",
|
| 75 |
+
"type": "text",
|
| 76 |
+
"content": request.text,
|
| 77 |
+
"metadata": message_metadata,
|
| 78 |
+
}
|
| 79 |
+
).execute()
|
| 80 |
+
log_step("db_insert", table="messages", op="user_message", session_id=session_id)
|
| 81 |
+
|
| 82 |
+
supabase.table("jobs").insert(
|
| 83 |
+
{
|
| 84 |
+
"id": job_id,
|
| 85 |
+
"user_id": user_id,
|
| 86 |
+
"session_id": session_id,
|
| 87 |
+
"status": "processing",
|
| 88 |
+
"input_text": request.text,
|
| 89 |
+
}
|
| 90 |
+
).execute()
|
| 91 |
+
log_step("db_insert", table="jobs", job_id=job_id)
|
| 92 |
+
|
| 93 |
+
background_tasks.add_task(process_session_job, job_id, session_id, request, str(user_id))
|
| 94 |
+
|
| 95 |
+
title_check = supabase.table("sessions").select("title").eq("id", session_id).execute()
|
| 96 |
+
if title_check.data and title_check.data[0]["title"] == "Bài toán mới":
|
| 97 |
+
new_title = request.text[:50] + ("..." if len(request.text) > 50 else "")
|
| 98 |
+
supabase.table("sessions").update({"title": new_title}).eq("id", session_id).execute()
|
| 99 |
+
log_step("db_update", table="sessions", op="title_from_first_message")
|
| 100 |
+
invalidate_for_user(uid)
|
| 101 |
+
|
| 102 |
+
log_pipeline_success("solve_accepted", job_id=job_id, session_id=session_id)
|
| 103 |
+
return SolveResponse(job_id=job_id, status="processing")
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
@router.post("/{session_id}/ocr_preview", response_model=OcrPreviewResponse)
|
| 107 |
+
async def ocr_preview(
|
| 108 |
+
session_id: str,
|
| 109 |
+
user_id=Depends(get_current_user_id),
|
| 110 |
+
file: UploadFile = File(...),
|
| 111 |
+
user_message: str | None = Form(None),
|
| 112 |
+
):
|
| 113 |
+
"""
|
| 114 |
+
Run OCR on an uploaded image and merge with optional user_message into combined_draft.
|
| 115 |
+
Does not insert messages or start a solve job. After user confirms, call POST .../solve
|
| 116 |
+
with text=combined_draft (edited) and omit image_url to avoid double OCR.
|
| 117 |
+
"""
|
| 118 |
+
supabase = get_supabase()
|
| 119 |
+
uid = str(user_id)
|
| 120 |
+
_assert_session_owner(supabase, session_id, user_id, uid, "owner_check_ocr_preview")
|
| 121 |
+
|
| 122 |
+
body = await file.read()
|
| 123 |
+
if len(body) > _OCR_PREVIEW_MAX_BYTES:
|
| 124 |
+
raise HTTPException(
|
| 125 |
+
status_code=413,
|
| 126 |
+
detail=f"Image too large (max {_OCR_PREVIEW_MAX_BYTES // (1024 * 1024)} MB).",
|
| 127 |
+
)
|
| 128 |
+
if not body:
|
| 129 |
+
raise HTTPException(status_code=400, detail="Empty file.")
|
| 130 |
+
|
| 131 |
+
if ocr_celery_enabled():
|
| 132 |
+
validate_chat_image_bytes(file.filename, body, file.content_type)
|
| 133 |
+
|
| 134 |
+
suffix = os.path.splitext(file.filename or "")[1].lower()
|
| 135 |
+
if suffix not in (".png", ".jpg", ".jpeg", ".webp", ".gif", ".bmp", ""):
|
| 136 |
+
suffix = ".png"
|
| 137 |
+
temp_path = f"temp_ocr_preview_{uuid.uuid4()}{suffix or '.png'}"
|
| 138 |
+
try:
|
| 139 |
+
with open(temp_path, "wb") as f:
|
| 140 |
+
f.write(body)
|
| 141 |
+
ocr_text = await ocr_from_local_image_path(
|
| 142 |
+
temp_path, file.filename, get_orchestrator().ocr_agent
|
| 143 |
+
)
|
| 144 |
+
if ocr_text is None:
|
| 145 |
+
ocr_text = ""
|
| 146 |
+
finally:
|
| 147 |
+
if os.path.exists(temp_path):
|
| 148 |
+
os.remove(temp_path)
|
| 149 |
+
|
| 150 |
+
um = (user_message or "").strip()
|
| 151 |
+
combined = build_combined_ocr_preview_draft(user_message, ocr_text)
|
| 152 |
+
log_step("ocr_preview_done", session_id=session_id, ocr_len=len(ocr_text), user_len=len(um))
|
| 153 |
+
return OcrPreviewResponse(
|
| 154 |
+
ocr_text=ocr_text,
|
| 155 |
+
user_message=um,
|
| 156 |
+
combined_draft=combined,
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
@router.post("/{session_id}/solve", response_model=SolveResponse)
|
| 161 |
+
async def solve_problem(
|
| 162 |
+
session_id: str,
|
| 163 |
+
request: SolveRequest,
|
| 164 |
+
background_tasks: BackgroundTasks,
|
| 165 |
+
user_id=Depends(get_current_user_id),
|
| 166 |
+
):
|
| 167 |
+
"""
|
| 168 |
+
Gửi câu hỏi giải toán trong một session (Submit geometry problem in a session).
|
| 169 |
+
Lưu câu hỏi vào history và bắt đầu tiến trình giải (chỉ giải toán và tạo hình tĩnh).
|
| 170 |
+
"""
|
| 171 |
+
supabase = get_supabase()
|
| 172 |
+
uid = str(user_id)
|
| 173 |
+
_assert_session_owner(supabase, session_id, user_id, uid, "owner_check")
|
| 174 |
+
|
| 175 |
+
message_metadata = {"image_url": request.image_url} if request.image_url else {}
|
| 176 |
+
job_id = str(uuid.uuid4())
|
| 177 |
+
return _enqueue_solve_common(
|
| 178 |
+
supabase,
|
| 179 |
+
background_tasks,
|
| 180 |
+
session_id,
|
| 181 |
+
user_id,
|
| 182 |
+
uid,
|
| 183 |
+
request,
|
| 184 |
+
message_metadata,
|
| 185 |
+
job_id,
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
@router.post("/{session_id}/solve_multipart", response_model=SolveResponse)
|
| 190 |
+
async def solve_multipart(
|
| 191 |
+
session_id: str,
|
| 192 |
+
background_tasks: BackgroundTasks,
|
| 193 |
+
user_id=Depends(get_current_user_id),
|
| 194 |
+
text: str = Form(...),
|
| 195 |
+
file: UploadFile = File(...),
|
| 196 |
+
):
|
| 197 |
+
"""
|
| 198 |
+
Gửi text + file ảnh trong một request multipart: validate, upload bucket `image`,
|
| 199 |
+
ghi session_assets, lưu message kèm metadata (URL, size, type), rồi enqueue solve
|
| 200 |
+
(image_url trỏ public URL để orchestrator OCR).
|
| 201 |
+
"""
|
| 202 |
+
supabase = get_supabase()
|
| 203 |
+
uid = str(user_id)
|
| 204 |
+
_assert_session_owner(supabase, session_id, user_id, uid, "owner_check_solve_multipart")
|
| 205 |
+
|
| 206 |
+
t = (text or "").strip()
|
| 207 |
+
if not t:
|
| 208 |
+
raise HTTPException(status_code=400, detail="text must not be empty.")
|
| 209 |
+
|
| 210 |
+
body = await file.read()
|
| 211 |
+
ext, content_type = validate_chat_image_bytes(file.filename, body, file.content_type)
|
| 212 |
+
|
| 213 |
+
job_id = str(uuid.uuid4())
|
| 214 |
+
up = upload_session_chat_image(session_id, job_id, body, ext, content_type)
|
| 215 |
+
public_url = up["public_url"]
|
| 216 |
+
|
| 217 |
+
message_metadata = {
|
| 218 |
+
"image_url": public_url,
|
| 219 |
+
"attachment": {
|
| 220 |
+
"public_url": public_url,
|
| 221 |
+
"storage_path": up["storage_path"],
|
| 222 |
+
"size_bytes": len(body),
|
| 223 |
+
"content_type": content_type,
|
| 224 |
+
"original_filename": file.filename or "",
|
| 225 |
+
"session_asset_id": up.get("session_asset_id"),
|
| 226 |
+
},
|
| 227 |
+
}
|
| 228 |
+
request = SolveRequest(text=t, image_url=public_url)
|
| 229 |
+
return _enqueue_solve_common(
|
| 230 |
+
supabase,
|
| 231 |
+
background_tasks,
|
| 232 |
+
session_id,
|
| 233 |
+
user_id,
|
| 234 |
+
uid,
|
| 235 |
+
request,
|
| 236 |
+
message_metadata,
|
| 237 |
+
job_id,
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
@router.post("/{session_id}/render_video", response_model=RenderVideoResponse)
|
| 242 |
+
async def render_video(
|
| 243 |
+
session_id: str,
|
| 244 |
+
request: RenderVideoRequest,
|
| 245 |
+
background_tasks: BackgroundTasks,
|
| 246 |
+
user_id=Depends(get_current_user_id),
|
| 247 |
+
):
|
| 248 |
+
"""
|
| 249 |
+
Yêu cầu tạo video Manim từ trạng thái hình ảnh mới nhất của session.
|
| 250 |
+
"""
|
| 251 |
+
supabase = get_supabase()
|
| 252 |
+
|
| 253 |
+
# 1. Kiểm tra quyền sở hữu
|
| 254 |
+
res = supabase.table("sessions").select("id").eq("id", session_id).eq("user_id", user_id).execute()
|
| 255 |
+
if not res.data:
|
| 256 |
+
raise HTTPException(status_code=403, detail="Forbidden: You do not own this session.")
|
| 257 |
+
|
| 258 |
+
# 2. Tìm tin nhắn assistant có metadata hình học (cụ thể job_id hoặc mới nhất trong 10 tin nhắn gần nhất)
|
| 259 |
+
msg_res = (
|
| 260 |
+
supabase.table("messages")
|
| 261 |
+
.select("metadata")
|
| 262 |
+
.eq("session_id", session_id)
|
| 263 |
+
.eq("role", "assistant")
|
| 264 |
+
.order("created_at", desc=True)
|
| 265 |
+
.limit(10)
|
| 266 |
+
.execute()
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
latest_geometry = None
|
| 270 |
+
if msg_res.data:
|
| 271 |
+
for msg in msg_res.data:
|
| 272 |
+
meta = msg.get("metadata", {})
|
| 273 |
+
# Nếu có yêu cầu job_id cụ thể, phải khớp job_id
|
| 274 |
+
if request.job_id and meta.get("job_id") != request.job_id:
|
| 275 |
+
continue
|
| 276 |
+
|
| 277 |
+
# Phải có dữ liệu hình học
|
| 278 |
+
if meta.get("geometry_dsl") and meta.get("coordinates"):
|
| 279 |
+
latest_geometry = meta
|
| 280 |
+
break
|
| 281 |
+
|
| 282 |
+
if not latest_geometry:
|
| 283 |
+
raise HTTPException(status_code=404, detail="Không tìm thấy dữ liệu hình học để render video.")
|
| 284 |
+
|
| 285 |
+
# 3. Tạo Job rendering
|
| 286 |
+
job_id = str(uuid.uuid4())
|
| 287 |
+
supabase.table("jobs").insert({
|
| 288 |
+
"id": job_id,
|
| 289 |
+
"user_id": user_id,
|
| 290 |
+
"session_id": session_id,
|
| 291 |
+
"status": "rendering_queued",
|
| 292 |
+
"input_text": f"Render video requested at {job_id}",
|
| 293 |
+
}).execute()
|
| 294 |
+
|
| 295 |
+
# 4. Dispatch background task
|
| 296 |
+
background_tasks.add_task(process_render_job, job_id, session_id, latest_geometry)
|
| 297 |
+
|
| 298 |
+
return RenderVideoResponse(job_id=job_id, status="rendering_queued")
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
async def process_session_job(
|
| 302 |
+
job_id: str, session_id: str, request: SolveRequest, user_id: str
|
| 303 |
+
):
|
| 304 |
+
"""Tiến trình giải toán ngầm, tạo hình ảnh tĩnh."""
|
| 305 |
+
from app.websocket_manager import notify_status
|
| 306 |
+
|
| 307 |
+
async def status_update(status: str):
|
| 308 |
+
await notify_status(job_id, {"status": status, "job_id": job_id})
|
| 309 |
+
|
| 310 |
+
supabase = get_supabase()
|
| 311 |
+
try:
|
| 312 |
+
history_res = (
|
| 313 |
+
supabase.table("messages")
|
| 314 |
+
.select("*")
|
| 315 |
+
.eq("session_id", session_id)
|
| 316 |
+
.order("created_at", desc=False)
|
| 317 |
+
.execute()
|
| 318 |
+
)
|
| 319 |
+
history = history_res.data if history_res.data else []
|
| 320 |
+
|
| 321 |
+
result = await get_orchestrator().run(
|
| 322 |
+
request.text,
|
| 323 |
+
request.image_url,
|
| 324 |
+
job_id=job_id,
|
| 325 |
+
session_id=session_id,
|
| 326 |
+
status_callback=status_update,
|
| 327 |
+
history=history,
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
status = result.get("status", "error") if "error" not in result else "error"
|
| 331 |
+
|
| 332 |
+
supabase.table("jobs").update({"status": status, "result": result}).eq(
|
| 333 |
+
"id", job_id
|
| 334 |
+
).execute()
|
| 335 |
+
|
| 336 |
+
supabase.table("messages").insert(
|
| 337 |
+
{
|
| 338 |
+
"session_id": session_id,
|
| 339 |
+
"role": "assistant",
|
| 340 |
+
"type": "analysis" if "error" not in result else "error",
|
| 341 |
+
"content": (
|
| 342 |
+
result.get("semantic_analysis", "Đã có lỗi xảy ra.")
|
| 343 |
+
if "error" not in result
|
| 344 |
+
else result["error"]
|
| 345 |
+
),
|
| 346 |
+
"metadata": {
|
| 347 |
+
"job_id": job_id,
|
| 348 |
+
"coordinates": result.get("coordinates"),
|
| 349 |
+
"geometry_dsl": result.get("geometry_dsl"),
|
| 350 |
+
"polygon_order": result.get("polygon_order", []),
|
| 351 |
+
"drawing_phases": result.get("drawing_phases", []),
|
| 352 |
+
"circles": result.get("circles", []),
|
| 353 |
+
"lines": result.get("lines", []),
|
| 354 |
+
"rays": result.get("rays", []),
|
| 355 |
+
"solution": result.get("solution"),
|
| 356 |
+
"is_3d": result.get("is_3d", False),
|
| 357 |
+
},
|
| 358 |
+
}
|
| 359 |
+
).execute()
|
| 360 |
+
|
| 361 |
+
await notify_status(job_id, {"status": status, "job_id": job_id, "result": result})
|
| 362 |
+
|
| 363 |
+
except Exception as e:
|
| 364 |
+
logger.exception("Error processing session job %s", job_id)
|
| 365 |
+
error_msg = format_error_for_user(e)
|
| 366 |
+
supabase = get_supabase()
|
| 367 |
+
supabase.table("jobs").update(
|
| 368 |
+
{"status": "error", "result": {"error": str(e)}}
|
| 369 |
+
).eq("id", job_id).execute()
|
| 370 |
+
supabase.table("messages").insert(
|
| 371 |
+
{
|
| 372 |
+
"session_id": session_id,
|
| 373 |
+
"role": "assistant",
|
| 374 |
+
"type": "error",
|
| 375 |
+
"content": error_msg,
|
| 376 |
+
"metadata": {"job_id": job_id},
|
| 377 |
+
}
|
| 378 |
+
).execute()
|
| 379 |
+
await notify_status(job_id, {"status": "error", "job_id": job_id, "error": error_msg})
|
| 380 |
+
|
| 381 |
+
async def process_render_job(job_id: str, session_id: str, geometry_data: dict):
|
| 382 |
+
"""Tiến trình render video từ metadata có sẵn."""
|
| 383 |
+
from app.websocket_manager import notify_status
|
| 384 |
+
from worker.tasks import render_geometry_video
|
| 385 |
+
|
| 386 |
+
await notify_status(job_id, {"status": "rendering_queued", "job_id": job_id})
|
| 387 |
+
|
| 388 |
+
# Prepare payload for Celery (similar to what orchestrator used to do)
|
| 389 |
+
result_payload = {
|
| 390 |
+
"geometry_dsl": geometry_data.get("geometry_dsl"),
|
| 391 |
+
"coordinates": geometry_data.get("coordinates"),
|
| 392 |
+
"polygon_order": geometry_data.get("polygon_order", []),
|
| 393 |
+
"drawing_phases": geometry_data.get("drawing_phases", []),
|
| 394 |
+
"circles": geometry_data.get("circles", []),
|
| 395 |
+
"lines": geometry_data.get("lines", []),
|
| 396 |
+
"rays": geometry_data.get("rays", []),
|
| 397 |
+
"semantic": geometry_data.get("semantic", {}),
|
| 398 |
+
"semantic_analysis": geometry_data.get("semantic_analysis", "🎬 Video minh họa dựng từ trạng thái gần nhất."),
|
| 399 |
+
"session_id": session_id,
|
| 400 |
+
}
|
| 401 |
+
|
| 402 |
+
try:
|
| 403 |
+
logger.info(f"[RenderJob] Attempting to dispatch Celery task for job {job_id}...")
|
| 404 |
+
render_geometry_video.delay(job_id, result_payload)
|
| 405 |
+
logger.info(f"[RenderJob] SUCCESS: Dispatched Celery task for job {job_id}")
|
| 406 |
+
except Exception as e:
|
| 407 |
+
logger.exception(f"[RenderJob] FAILED to dispatch Celery task: {e}")
|
| 408 |
+
supabase = get_supabase()
|
| 409 |
+
supabase.table("jobs").update({"status": "error", "result": {"error": f"Task dispatch failed: {str(e)}"}}).eq("id", job_id).execute()
|
| 410 |
+
await notify_status(job_id, {"status": "error", "job_id": job_id, "error": str(e)})
|
app/runtime_env.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Default process env vars (Paddle/OpenMP). Call as early as possible after load_dotenv."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def apply_runtime_env_defaults() -> None:
|
| 9 |
+
# Paddle respects OMP_NUM_THREADS at import; setdefault loses if platform already set 2+
|
| 10 |
+
os.environ["OMP_NUM_THREADS"] = "1"
|
| 11 |
+
os.environ["MKL_NUM_THREADS"] = "1"
|
| 12 |
+
os.environ["OPENBLAS_NUM_THREADS"] = "1"
|
app/session_cache.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""TTL in-memory cache để giảm truy vấn Supabase lặp lại (list session, quyền sở hữu session)."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any, Callable
|
| 6 |
+
|
| 7 |
+
from cachetools import TTLCache
|
| 8 |
+
|
| 9 |
+
from app.logutil import log_step
|
| 10 |
+
|
| 11 |
+
_session_list: TTLCache[str, list[Any]] = TTLCache(maxsize=512, ttl=45)
|
| 12 |
+
_session_owner: TTLCache[tuple[str, str], bool] = TTLCache(maxsize=4096, ttl=45)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def invalidate_for_user(user_id: str) -> None:
|
| 16 |
+
"""Xoá cache list session của user (sau create / delete / rename / solve đổi title)."""
|
| 17 |
+
_session_list.pop(user_id, None)
|
| 18 |
+
log_step("cache_invalidate", target="session_list", user_id=user_id)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def invalidate_session_owner(session_id: str, user_id: str) -> None:
|
| 22 |
+
_session_owner.pop((session_id, user_id), None)
|
| 23 |
+
log_step("cache_invalidate", target="session_owner", session_id=session_id, user_id=user_id)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_sessions_list_cached(user_id: str, fetch: Callable[[], list[Any]]) -> list[Any]:
|
| 27 |
+
if user_id in _session_list:
|
| 28 |
+
log_step("cache_hit", kind="session_list", user_id=user_id)
|
| 29 |
+
return _session_list[user_id]
|
| 30 |
+
log_step("cache_miss", kind="session_list", user_id=user_id)
|
| 31 |
+
data = fetch()
|
| 32 |
+
_session_list[user_id] = data
|
| 33 |
+
return data
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def session_owned_by_user(
|
| 37 |
+
session_id: str,
|
| 38 |
+
user_id: str,
|
| 39 |
+
fetch: Callable[[], bool],
|
| 40 |
+
) -> bool:
|
| 41 |
+
key = (session_id, user_id)
|
| 42 |
+
if key in _session_owner:
|
| 43 |
+
log_step("cache_hit", kind="session_owner", session_id=session_id)
|
| 44 |
+
return _session_owner[key]
|
| 45 |
+
log_step("cache_miss", kind="session_owner", session_id=session_id)
|
| 46 |
+
ok = fetch()
|
| 47 |
+
_session_owner[key] = ok
|
| 48 |
+
return ok
|
app/supabase_client.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from supabase import Client, ClientOptions, create_client
|
| 3 |
+
from supabase_auth import SyncMemoryStorage
|
| 4 |
+
from dotenv import load_dotenv
|
| 5 |
+
|
| 6 |
+
load_dotenv()
|
| 7 |
+
|
| 8 |
+
from app.url_utils import sanitize_env
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_supabase() -> Client:
|
| 12 |
+
"""Service-role client for server-side operations (bypasses RLS when policies expect service role)."""
|
| 13 |
+
url = sanitize_env(os.getenv("SUPABASE_URL"))
|
| 14 |
+
key = sanitize_env(os.getenv("SUPABASE_SERVICE_ROLE_KEY") or os.getenv("SUPABASE_KEY"))
|
| 15 |
+
if not url or not key:
|
| 16 |
+
raise RuntimeError(
|
| 17 |
+
"SUPABASE_URL and SUPABASE_SERVICE_ROLE_KEY (or SUPABASE_KEY) must be set"
|
| 18 |
+
)
|
| 19 |
+
return create_client(url, key)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def get_supabase_for_user_jwt(access_token: str) -> Client:
|
| 23 |
+
"""
|
| 24 |
+
Client scoped to the logged-in user: PostgREST sends the user's JWT so RLS applies.
|
| 25 |
+
Use SUPABASE_ANON_KEY (publishable), not the service role key.
|
| 26 |
+
"""
|
| 27 |
+
url = sanitize_env(os.getenv("SUPABASE_URL"))
|
| 28 |
+
anon = sanitize_env(os.getenv("SUPABASE_ANON_KEY") or os.getenv("NEXT_PUBLIC_SUPABASE_ANON_KEY"))
|
| 29 |
+
if not url or not anon:
|
| 30 |
+
raise RuntimeError(
|
| 31 |
+
"SUPABASE_URL and SUPABASE_ANON_KEY (or NEXT_PUBLIC_SUPABASE_ANON_KEY) must be set "
|
| 32 |
+
"for user-scoped Supabase access"
|
| 33 |
+
)
|
| 34 |
+
base_opts = ClientOptions(storage=SyncMemoryStorage())
|
| 35 |
+
merged_headers = {**dict(base_opts.headers), "Authorization": f"Bearer {access_token}"}
|
| 36 |
+
opts = ClientOptions(storage=SyncMemoryStorage(), headers=merged_headers)
|
| 37 |
+
return create_client(url, anon, opts)
|
app/url_utils.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Normalize URLs / env strings (HF secrets and copy-paste often include trailing newlines)."""
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def sanitize_url(value: str | None) -> str | None:
|
| 5 |
+
if value is None:
|
| 6 |
+
return None
|
| 7 |
+
s = value.strip().replace("\r", "").replace("\n", "").replace("\t", "")
|
| 8 |
+
return s or None
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def sanitize_env(value: str | None) -> str | None:
|
| 12 |
+
"""Strip whitespace and line breaks from environment-backed strings."""
|
| 13 |
+
return sanitize_url(value)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# OpenAI SDK (>=1.x) requires a non-empty api_key at client construction (Docker build / prewarm has no secrets).
|
| 17 |
+
_OPENAI_API_KEY_BUILD_PLACEHOLDER = "build-placeholder-openrouter-not-for-production"
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def openai_compatible_api_key(raw: str | None) -> str:
|
| 21 |
+
"""Return sanitized API key, or a placeholder so AsyncOpenAI() can be constructed without env at build time."""
|
| 22 |
+
k = sanitize_env(raw)
|
| 23 |
+
return k if k else _OPENAI_API_KEY_BUILD_PLACEHOLDER
|
app/websocket_manager.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""WebSocket connection registry and job status notifications (avoid circular imports with main)."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
from typing import Dict, List
|
| 7 |
+
|
| 8 |
+
from fastapi import WebSocket, WebSocketDisconnect
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
active_connections: Dict[str, List[WebSocket]] = {}
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
async def notify_status(job_id: str, data: dict) -> None:
|
| 16 |
+
if job_id not in active_connections:
|
| 17 |
+
return
|
| 18 |
+
for connection in list(active_connections[job_id]):
|
| 19 |
+
try:
|
| 20 |
+
await connection.send_json(data)
|
| 21 |
+
except Exception as e:
|
| 22 |
+
logger.error("WS error sending to %s: %s", job_id, e)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def register_websocket_routes(app) -> None:
|
| 26 |
+
"""Attach websocket endpoint to the FastAPI app."""
|
| 27 |
+
|
| 28 |
+
@app.websocket("/ws/{job_id}")
|
| 29 |
+
async def websocket_endpoint(websocket: WebSocket, job_id: str) -> None:
|
| 30 |
+
await websocket.accept()
|
| 31 |
+
if job_id not in active_connections:
|
| 32 |
+
active_connections[job_id] = []
|
| 33 |
+
active_connections[job_id].append(websocket)
|
| 34 |
+
try:
|
| 35 |
+
while True:
|
| 36 |
+
await websocket.receive_text()
|
| 37 |
+
except WebSocketDisconnect:
|
| 38 |
+
active_connections[job_id].remove(websocket)
|
| 39 |
+
if not active_connections[job_id]:
|
| 40 |
+
del active_connections[job_id]
|
clean_ports.sh
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Script to kill all project-related processes for a clean restart
|
| 3 |
+
|
| 4 |
+
echo "🧹 Cleaning up project processes..."
|
| 5 |
+
|
| 6 |
+
# Kill things on ports 8000 (Backend) and 3000 (Frontend)
|
| 7 |
+
PORTS="8000 3000 11020"
|
| 8 |
+
for PORT in $PORTS; do
|
| 9 |
+
PIDS=$(lsof -ti :$PORT)
|
| 10 |
+
if [ ! -z "$PIDS" ]; then
|
| 11 |
+
echo "Killing processes on port $PORT: $PIDS"
|
| 12 |
+
kill -9 $PIDS 2>/dev/null
|
| 13 |
+
fi
|
| 14 |
+
done
|
| 15 |
+
|
| 16 |
+
# Kill by process name
|
| 17 |
+
echo "Killing any remaining Celery, Uvicorn, or Manim processes..."
|
| 18 |
+
pkill -9 -f "celery" 2>/dev/null
|
| 19 |
+
pkill -9 -f "uvicorn" 2>/dev/null
|
| 20 |
+
pkill -9 -f "manim" 2>/dev/null
|
| 21 |
+
|
| 22 |
+
echo "✅ Done. You can now restart your Backend, Worker, and Frontend."
|
dump.rdb
ADDED
|
Binary file (5.44 kB). View file
|
|
|
geometry_render/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Manim geometry script generation and rendering (worker-safe, no LLM agents)."""
|
| 2 |
+
|
| 3 |
+
from .renderer import RendererAgent
|
| 4 |
+
|
| 5 |
+
__all__ = ["RendererAgent"]
|
geometry_render/renderer.py
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import subprocess
|
| 3 |
+
import glob
|
| 4 |
+
import string
|
| 5 |
+
import logging
|
| 6 |
+
from typing import Dict, Any, List
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class RendererAgent:
|
| 12 |
+
"""
|
| 13 |
+
Renderer — generates Manim scripts from geometry data.
|
| 14 |
+
|
| 15 |
+
Drawing happens in phases:
|
| 16 |
+
Phase 1: Main polygon (base shape with correct vertex order)
|
| 17 |
+
Phase 2: Auxiliary points and segments (midpoints, derived segments)
|
| 18 |
+
Phase 3: Labels for all points
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def generate_manim_script(self, data: Dict[str, Any]) -> str:
|
| 22 |
+
coords: Dict[str, List[float]] = data.get("coordinates", {})
|
| 23 |
+
polygon_order: List[str] = data.get("polygon_order", [])
|
| 24 |
+
circles_meta: List[Dict] = data.get("circles", [])
|
| 25 |
+
lines_meta: List[List[str]] = data.get("lines", [])
|
| 26 |
+
rays_meta: List[List[str]] = data.get("rays", [])
|
| 27 |
+
drawing_phases: List[Dict] = data.get("drawing_phases", [])
|
| 28 |
+
semantic: Dict[str, Any] = data.get("semantic", {})
|
| 29 |
+
shape_type = semantic.get("type", "").lower()
|
| 30 |
+
|
| 31 |
+
# ── Detect 3D Context ────────────────────────────────────────────────
|
| 32 |
+
is_3d = False
|
| 33 |
+
for pos in coords.values():
|
| 34 |
+
if len(pos) >= 3 and abs(pos[2]) > 0.001:
|
| 35 |
+
is_3d = True
|
| 36 |
+
break
|
| 37 |
+
if shape_type in ["pyramid", "prism", "sphere"]:
|
| 38 |
+
is_3d = True
|
| 39 |
+
|
| 40 |
+
# ── Fallback: infer polygon_order from coords keys (alphabetical uppercase) ──
|
| 41 |
+
if not polygon_order:
|
| 42 |
+
base = sorted(
|
| 43 |
+
[pid for pid in coords if pid in string.ascii_uppercase],
|
| 44 |
+
key=lambda p: string.ascii_uppercase.index(p)
|
| 45 |
+
)
|
| 46 |
+
polygon_order = base
|
| 47 |
+
|
| 48 |
+
# Separate base points from derived (multi-char or lowercase)
|
| 49 |
+
base_ids = [pid for pid in polygon_order if pid in coords]
|
| 50 |
+
derived_ids = [pid for pid in coords if pid not in polygon_order]
|
| 51 |
+
|
| 52 |
+
scene_base = "ThreeDScene" if is_3d else "MovingCameraScene"
|
| 53 |
+
lines = [
|
| 54 |
+
"from manim import *",
|
| 55 |
+
"",
|
| 56 |
+
f"class GeometryScene({scene_base}):",
|
| 57 |
+
" def construct(self):",
|
| 58 |
+
]
|
| 59 |
+
|
| 60 |
+
if is_3d:
|
| 61 |
+
lines.append(" # 3D Setup")
|
| 62 |
+
lines.append(" self.set_camera_orientation(phi=75*DEGREES, theta=-45*DEGREES)")
|
| 63 |
+
lines.append(" axes = ThreeDAxes(axis_config={'stroke_width': 1})")
|
| 64 |
+
lines.append(" axes.set_opacity(0.3)")
|
| 65 |
+
lines.append(" self.add(axes)")
|
| 66 |
+
lines.append(" self.begin_ambient_camera_rotation(rate=0.1)")
|
| 67 |
+
lines.append("")
|
| 68 |
+
|
| 69 |
+
# ── Declare all dots and labels ───────────────────────────────────────
|
| 70 |
+
for pid, pos in coords.items():
|
| 71 |
+
x, y, z = 0, 0, 0
|
| 72 |
+
if len(pos) >= 1: x = round(pos[0], 4)
|
| 73 |
+
if len(pos) >= 2: y = round(pos[1], 4)
|
| 74 |
+
if len(pos) >= 3: z = round(pos[2], 4)
|
| 75 |
+
|
| 76 |
+
dot_class = "Dot3D" if is_3d else "Dot"
|
| 77 |
+
lines.append(f" p_{pid} = {dot_class}(point=[{x}, {y}, {z}], color=WHITE, radius=0.08)")
|
| 78 |
+
|
| 79 |
+
if is_3d:
|
| 80 |
+
lines.append(
|
| 81 |
+
f" l_{pid} = Text('{pid}', font_size=20, color=WHITE)"
|
| 82 |
+
f".move_to(p_{pid}.get_center() + [0.2, 0.2, 0.2])"
|
| 83 |
+
)
|
| 84 |
+
# Ensure labels follow camera in 3D (fixed orientation)
|
| 85 |
+
lines.append(f" self.add_fixed_orientation_mobjects(l_{pid})")
|
| 86 |
+
else:
|
| 87 |
+
lines.append(
|
| 88 |
+
f" l_{pid} = Text('{pid}', font_size=22, color=WHITE)"
|
| 89 |
+
f".next_to(p_{pid}, UR, buff=0.15)"
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
# ── 3D Shape Special: Pyramid/Prism Faces ────────────────────────────
|
| 93 |
+
if is_3d and shape_type == "pyramid" and len(base_ids) >= 3:
|
| 94 |
+
# Find apex (usually 'S')
|
| 95 |
+
apex_id = "S" if "S" in coords else derived_ids[0] if derived_ids else None
|
| 96 |
+
if apex_id:
|
| 97 |
+
# Draw base face
|
| 98 |
+
base_pts = ", ".join([f"p_{pid}.get_center()" for pid in base_ids])
|
| 99 |
+
lines.append(f" base_face = Polygon({base_pts}, color=BLUE, fill_opacity=0.1)")
|
| 100 |
+
lines.append(" self.play(Create(base_face), run_time=1.0)")
|
| 101 |
+
|
| 102 |
+
# Draw side faces
|
| 103 |
+
for i in range(len(base_ids)):
|
| 104 |
+
p1 = base_ids[i]
|
| 105 |
+
p2 = base_ids[(i + 1) % len(base_ids)]
|
| 106 |
+
face_pts = f"p_{apex_id}.get_center(), p_{p1}.get_center(), p_{p2}.get_center()"
|
| 107 |
+
lines.append(
|
| 108 |
+
f" side_{i} = Polygon({face_pts}, color=BLUE, stroke_width=1, fill_opacity=0.05)"
|
| 109 |
+
)
|
| 110 |
+
lines.append(f" self.play(Create(side_{i}), run_time=0.5)")
|
| 111 |
+
|
| 112 |
+
# ── Circles ──────────────────────────────────────────────────────────
|
| 113 |
+
for i, c in enumerate(circles_meta):
|
| 114 |
+
center = c["center"]
|
| 115 |
+
r = c["radius"]
|
| 116 |
+
if center in coords:
|
| 117 |
+
cx, cy, cz = 0, 0, 0
|
| 118 |
+
pos = coords[center]
|
| 119 |
+
if len(pos) >= 1: cx = round(pos[0], 4)
|
| 120 |
+
if len(pos) >= 2: cy = round(pos[1], 4)
|
| 121 |
+
if len(pos) >= 3: cz = round(pos[2], 4)
|
| 122 |
+
lines.append(
|
| 123 |
+
f" circle_{i} = Circle(radius={r}, color=BLUE)"
|
| 124 |
+
f".move_to([{cx}, {cy}, {cz}])"
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
# ── Infinite Lines & Rays ────────────────────────────────────────────
|
| 128 |
+
# (Standard Line works for 3D coordinates in Manim)
|
| 129 |
+
for i, (p1, p2) in enumerate(lines_meta):
|
| 130 |
+
if p1 in coords and p2 in coords:
|
| 131 |
+
lines.append(
|
| 132 |
+
f" line_ext_{i} = Line(p_{p1}.get_center(), p_{p2}.get_center(), color=GRAY_D, stroke_width=2)"
|
| 133 |
+
f".scale(20)"
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
for i, (p1, p2) in enumerate(rays_meta):
|
| 137 |
+
if p1 in coords and p2 in coords:
|
| 138 |
+
lines.append(
|
| 139 |
+
f" ray_{i} = Line(p_{p1}.get_center(), p_{p1}.get_center() + 15 * (p_{p2}.get_center() - p_{p1}.get_center()),"
|
| 140 |
+
f" color=GRAY_C, stroke_width=2)"
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
# ── Camera auto-fit group (Only for 2D) ──────────────────────────────
|
| 144 |
+
if not is_3d:
|
| 145 |
+
all_dot_names = [f"p_{pid}" for pid in coords]
|
| 146 |
+
all_names_str = ", ".join(all_dot_names)
|
| 147 |
+
lines.append(f" _all = VGroup({all_names_str})")
|
| 148 |
+
lines.append(" self.camera.frame.set_width(max(_all.width * 2.0, 8))")
|
| 149 |
+
lines.append(" self.camera.frame.move_to(_all)")
|
| 150 |
+
lines.append("")
|
| 151 |
+
|
| 152 |
+
# ── Phase 1: Base polygon ─────────────────────────────────────────────
|
| 153 |
+
if len(base_ids) >= 3:
|
| 154 |
+
pts_str = ", ".join([f"p_{pid}.get_center()" for pid in base_ids])
|
| 155 |
+
lines.append(f" poly = Polygon({pts_str}, color=BLUE, fill_color=BLUE, fill_opacity=0.15)")
|
| 156 |
+
lines.append(" self.play(Create(poly), run_time=1.5)")
|
| 157 |
+
elif len(base_ids) == 2:
|
| 158 |
+
p1, p2 = base_ids
|
| 159 |
+
lines.append(f" base_line = Line(p_{p1}.get_center(), p_{p2}.get_center(), color=BLUE)")
|
| 160 |
+
lines.append(" self.play(Create(base_line), run_time=1.0)")
|
| 161 |
+
|
| 162 |
+
# Draw base points
|
| 163 |
+
if base_ids:
|
| 164 |
+
base_dots_str = ", ".join([f"p_{pid}" for pid in base_ids])
|
| 165 |
+
lines.append(f" self.play(FadeIn(VGroup({base_dots_str})), run_time=0.5)")
|
| 166 |
+
lines.append(" self.wait(0.5)")
|
| 167 |
+
|
| 168 |
+
# ── Phase 2: Auxiliary points and segments ────────────────────────────
|
| 169 |
+
if derived_ids:
|
| 170 |
+
derived_dots_str = ", ".join([f"p_{pid}" for pid in derived_ids])
|
| 171 |
+
lines.append(f" self.play(FadeIn(VGroup({derived_dots_str})), run_time=0.8)")
|
| 172 |
+
|
| 173 |
+
# Segments from drawing_phases
|
| 174 |
+
segment_lines = []
|
| 175 |
+
for phase in drawing_phases:
|
| 176 |
+
if phase.get("phase") == 2:
|
| 177 |
+
for seg in phase.get("segments", []):
|
| 178 |
+
if len(seg) == 2 and seg[0] in coords and seg[1] in coords:
|
| 179 |
+
p1, p2 = seg[0], seg[1]
|
| 180 |
+
seg_var = f"seg_{p1}_{p2}"
|
| 181 |
+
lines.append(
|
| 182 |
+
f" {seg_var} = Line(p_{p1}.get_center(), p_{p2}.get_center(),"
|
| 183 |
+
f" color=YELLOW)"
|
| 184 |
+
)
|
| 185 |
+
segment_lines.append(seg_var)
|
| 186 |
+
|
| 187 |
+
if segment_lines:
|
| 188 |
+
segs_str = ", ".join([f"Create({sv})" for sv in segment_lines])
|
| 189 |
+
lines.append(f" self.play({segs_str}, run_time=1.2)")
|
| 190 |
+
|
| 191 |
+
if derived_ids or segment_lines:
|
| 192 |
+
lines.append(" self.wait(0.5)")
|
| 193 |
+
|
| 194 |
+
# ── Phase 3: All labels ───────────────────────────────────────────────
|
| 195 |
+
all_labels_str = ", ".join([f"l_{pid}" for pid in coords])
|
| 196 |
+
lines.append(f" self.play(FadeIn(VGroup({all_labels_str})), run_time=0.8)")
|
| 197 |
+
|
| 198 |
+
# ── Circles phase ─────────────────────────────────────────────────────
|
| 199 |
+
for i in range(len(circles_meta)):
|
| 200 |
+
lines.append(f" self.play(Create(circle_{i}), run_time=1.5)")
|
| 201 |
+
|
| 202 |
+
# ── Lines & Rays phase ────────────────────────────────────────────────
|
| 203 |
+
if lines_meta or rays_meta:
|
| 204 |
+
lr_anims = []
|
| 205 |
+
for i in range(len(lines_meta)):
|
| 206 |
+
lr_anims.append(f"Create(line_ext_{i})")
|
| 207 |
+
for i in range(len(rays_meta)):
|
| 208 |
+
lr_anims.append(f"Create(ray_{i})")
|
| 209 |
+
lines.append(f" self.play({', '.join(lr_anims)}, run_time=1.5)")
|
| 210 |
+
|
| 211 |
+
lines.append(" self.wait(2)")
|
| 212 |
+
|
| 213 |
+
return "\n".join(lines)
|
| 214 |
+
|
| 215 |
+
def run_manim(self, script_content: str, job_id: str) -> str:
|
| 216 |
+
script_file = f"{job_id}.py"
|
| 217 |
+
with open(script_file, "w") as f:
|
| 218 |
+
f.write(script_content)
|
| 219 |
+
|
| 220 |
+
try:
|
| 221 |
+
if os.getenv("MOCK_VIDEO") == "true":
|
| 222 |
+
logger.info(f"MOCK_VIDEO is true. Skipping Manim for job {job_id}")
|
| 223 |
+
# Create a dummy file if needed, or just return a path that exists
|
| 224 |
+
dummy_path = f"videos/{job_id}.mp4"
|
| 225 |
+
os.makedirs("videos", exist_ok=True)
|
| 226 |
+
with open(dummy_path, "wb") as f:
|
| 227 |
+
f.write(b"dummy video content")
|
| 228 |
+
return dummy_path
|
| 229 |
+
|
| 230 |
+
# Determine manim executable path
|
| 231 |
+
manim_exe = "manim"
|
| 232 |
+
venv_manim = os.path.join(os.getcwd(), "venv", "bin", "manim")
|
| 233 |
+
if os.path.exists(venv_manim):
|
| 234 |
+
manim_exe = venv_manim
|
| 235 |
+
|
| 236 |
+
# Prepare environment with homebrew paths
|
| 237 |
+
custom_env = os.environ.copy()
|
| 238 |
+
brew_path = "/opt/homebrew/bin:/usr/local/bin"
|
| 239 |
+
custom_env["PATH"] = f"{brew_path}:{custom_env.get('PATH', '')}"
|
| 240 |
+
|
| 241 |
+
logger.info(f"Running {manim_exe} for job {job_id}...")
|
| 242 |
+
result = subprocess.run(
|
| 243 |
+
[manim_exe, "-ql", "--media_dir", ".", "-o", f"{job_id}.mp4", script_file, "GeometryScene"],
|
| 244 |
+
capture_output=True,
|
| 245 |
+
text=True,
|
| 246 |
+
env=custom_env,
|
| 247 |
+
)
|
| 248 |
+
logger.info(f"Manim STDOUT: {result.stdout}")
|
| 249 |
+
if result.returncode != 0:
|
| 250 |
+
logger.error(f"Manim STDERR: {result.stderr}")
|
| 251 |
+
|
| 252 |
+
for pattern in [f"**/videos/**/{job_id}.mp4", f"**/{job_id}*.mp4"]:
|
| 253 |
+
found = glob.glob(pattern, recursive=True)
|
| 254 |
+
if found:
|
| 255 |
+
logger.info(f"Manim Success: Found {found[0]}")
|
| 256 |
+
return found[0]
|
| 257 |
+
|
| 258 |
+
logger.error(f"Manim file not found for job {job_id}. Return code: {result.returncode}")
|
| 259 |
+
return ""
|
| 260 |
+
except Exception as e:
|
| 261 |
+
logger.exception(f"Manim Execution Error: {e}")
|
| 262 |
+
return ""
|
| 263 |
+
finally:
|
| 264 |
+
if os.path.exists(script_file):
|
| 265 |
+
os.remove(script_file)
|
migrations/add_image_bucket_storage.sql
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-- ============================================================
|
| 2 |
+
-- MathSolver: Supabase Storage bucket `image` (chat / OCR attachments)
|
| 3 |
+
-- Run after session_assets and storage.video policies exist.
|
| 4 |
+
-- ============================================================
|
| 5 |
+
|
| 6 |
+
INSERT INTO storage.buckets (id, name, public)
|
| 7 |
+
VALUES ('image', 'image', true)
|
| 8 |
+
ON CONFLICT (id) DO UPDATE SET public = true;
|
| 9 |
+
|
| 10 |
+
-- Service role: upload/delete/list for API + workers
|
| 11 |
+
DROP POLICY IF EXISTS "Service Role manage images" ON storage.objects;
|
| 12 |
+
CREATE POLICY "Service Role manage images" ON storage.objects
|
| 13 |
+
FOR ALL
|
| 14 |
+
TO service_role
|
| 15 |
+
USING (bucket_id = 'image')
|
| 16 |
+
WITH CHECK (bucket_id = 'image');
|
| 17 |
+
|
| 18 |
+
-- Authenticated: read only objects under sessions they own (path sessions/{session_id}/...)
|
| 19 |
+
DROP POLICY IF EXISTS "Users view session images" ON storage.objects;
|
| 20 |
+
CREATE POLICY "Users view session images" ON storage.objects
|
| 21 |
+
FOR SELECT
|
| 22 |
+
TO authenticated
|
| 23 |
+
USING (
|
| 24 |
+
bucket_id = 'image'
|
| 25 |
+
AND (storage.foldername(name))[2] IN (
|
| 26 |
+
SELECT id::text FROM public.sessions WHERE user_id = auth.uid()
|
| 27 |
+
)
|
| 28 |
+
);
|
| 29 |
+
|
| 30 |
+
-- Public read for get_public_url / FE img tags (same model as video bucket)
|
| 31 |
+
DROP POLICY IF EXISTS "Public read images" ON storage.objects;
|
| 32 |
+
CREATE POLICY "Public read images" ON storage.objects
|
| 33 |
+
FOR SELECT
|
| 34 |
+
TO public
|
| 35 |
+
USING (bucket_id = 'image');
|
migrations/fix_rls_assets.sql
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-- ============================================================
|
| 2 |
+
-- FIX RLS & SESSION ASSETS (MathSolver v5.1 Worker Fix)
|
| 3 |
+
-- ============================================================
|
| 4 |
+
|
| 5 |
+
-- 1. Ensure session_assets table exists
|
| 6 |
+
CREATE TABLE IF NOT EXISTS public.session_assets (
|
| 7 |
+
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
| 8 |
+
session_id UUID NOT NULL REFERENCES public.sessions(id) ON DELETE CASCADE,
|
| 9 |
+
job_id UUID NOT NULL,
|
| 10 |
+
asset_type TEXT NOT NULL CHECK (asset_type IN ('video', 'image')),
|
| 11 |
+
storage_path TEXT NOT NULL,
|
| 12 |
+
public_url TEXT NOT NULL,
|
| 13 |
+
version INTEGER NOT NULL DEFAULT 1,
|
| 14 |
+
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
|
| 15 |
+
);
|
| 16 |
+
|
| 17 |
+
-- Index for session_assets
|
| 18 |
+
CREATE INDEX IF NOT EXISTS idx_session_assets_session_id ON public.session_assets(session_id);
|
| 19 |
+
CREATE INDEX IF NOT EXISTS idx_session_assets_type ON public.session_assets(session_id, asset_type);
|
| 20 |
+
|
| 21 |
+
-- 2. Enable RLS for all tables
|
| 22 |
+
ALTER TABLE public.session_assets ENABLE ROW LEVEL SECURITY;
|
| 23 |
+
ALTER TABLE public.profiles ENABLE ROW LEVEL SECURITY;
|
| 24 |
+
ALTER TABLE public.sessions ENABLE ROW LEVEL SECURITY;
|
| 25 |
+
ALTER TABLE public.messages ENABLE ROW LEVEL SECURITY;
|
| 26 |
+
ALTER TABLE public.jobs ENABLE ROW LEVEL SECURITY;
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
-- 3. Fix Table Policies to allow SERVICE ROLE
|
| 30 |
+
-- In Supabase, service_role usually bypasses RLS, but we add explicit policies for safety
|
| 31 |
+
-- especially for path-based checks or when SECURITY DEFINER functions are used.
|
| 32 |
+
|
| 33 |
+
-- [Session Assets]
|
| 34 |
+
DROP POLICY IF EXISTS "Users view own assets" ON public.session_assets;
|
| 35 |
+
CREATE POLICY "Users view own assets" ON public.session_assets
|
| 36 |
+
FOR SELECT USING (
|
| 37 |
+
session_id IN (SELECT id FROM public.sessions WHERE user_id = auth.uid())
|
| 38 |
+
);
|
| 39 |
+
|
| 40 |
+
DROP POLICY IF EXISTS "Service role manages assets" ON public.session_assets;
|
| 41 |
+
CREATE POLICY "Service role manages assets" ON public.session_assets
|
| 42 |
+
FOR ALL USING (true)
|
| 43 |
+
WITH CHECK (true);
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
-- [Messages] - Allow Worker to insert assistant messages
|
| 47 |
+
DROP POLICY IF EXISTS "Users manage own messages" ON public.messages;
|
| 48 |
+
CREATE POLICY "Users manage own messages" ON public.messages
|
| 49 |
+
FOR ALL USING (
|
| 50 |
+
session_id IN (SELECT id FROM public.sessions WHERE user_id = auth.uid())
|
| 51 |
+
OR
|
| 52 |
+
(auth.jwt() ->> 'role' = 'service_role')
|
| 53 |
+
);
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
-- [Jobs] - Allow Worker to update job status
|
| 57 |
+
DROP POLICY IF EXISTS "Users manage own jobs" ON public.jobs;
|
| 58 |
+
CREATE POLICY "Users manage own jobs" ON public.jobs
|
| 59 |
+
FOR ALL USING (
|
| 60 |
+
auth.uid() = user_id
|
| 61 |
+
OR user_id IS NULL
|
| 62 |
+
OR (auth.jwt() ->> 'role' = 'service_role')
|
| 63 |
+
);
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
-- 4. Storage Policies (Bucket: video)
|
| 67 |
+
-- Ensure 'video' bucket exists
|
| 68 |
+
INSERT INTO storage.buckets (id, name, public)
|
| 69 |
+
VALUES ('video', 'video', true)
|
| 70 |
+
ON CONFLICT (id) DO UPDATE SET public = true;
|
| 71 |
+
|
| 72 |
+
-- [Storage: Worker / Service Role] - Allow all in video bucket
|
| 73 |
+
DROP POLICY IF EXISTS "Service Role manage videos" ON storage.objects;
|
| 74 |
+
CREATE POLICY "Service Role manage videos" ON storage.objects
|
| 75 |
+
FOR ALL
|
| 76 |
+
TO service_role
|
| 77 |
+
USING (bucket_id = 'video');
|
| 78 |
+
|
| 79 |
+
-- [Storage: Users] - Allow users to view their session videos
|
| 80 |
+
DROP POLICY IF EXISTS "Users view session videos" ON storage.objects;
|
| 81 |
+
CREATE POLICY "Users view session videos" ON storage.objects
|
| 82 |
+
FOR SELECT
|
| 83 |
+
TO authenticated
|
| 84 |
+
USING (
|
| 85 |
+
bucket_id = 'video'
|
| 86 |
+
AND (storage.foldername(name))[2] IN (
|
| 87 |
+
SELECT id::text FROM public.sessions WHERE user_id = auth.uid()
|
| 88 |
+
)
|
| 89 |
+
);
|
| 90 |
+
|
| 91 |
+
-- [Storage: Public] - Allow public read access to videos
|
| 92 |
+
DROP POLICY IF EXISTS "Public read videos" ON storage.objects;
|
| 93 |
+
CREATE POLICY "Public read videos" ON storage.objects
|
| 94 |
+
FOR SELECT
|
| 95 |
+
TO public
|
| 96 |
+
USING (bucket_id = 'video');
|
migrations/v4_migration.sql
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-- ============================================================
|
| 2 |
+
-- MATHSOLVER v4.0 - Migration Script (Multi-Session & History)
|
| 3 |
+
-- ============================================================
|
| 4 |
+
|
| 5 |
+
-- 1. Profiles Table (Extends Supabase Auth)
|
| 6 |
+
CREATE TABLE IF NOT EXISTS public.profiles (
|
| 7 |
+
id UUID PRIMARY KEY REFERENCES auth.users(id) ON DELETE CASCADE,
|
| 8 |
+
display_name TEXT,
|
| 9 |
+
avatar_url TEXT,
|
| 10 |
+
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
|
| 11 |
+
updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
|
| 12 |
+
);
|
| 13 |
+
|
| 14 |
+
-- Function to handle new user signup and auto-create profile
|
| 15 |
+
CREATE OR REPLACE FUNCTION public.handle_new_user()
|
| 16 |
+
RETURNS TRIGGER AS $$
|
| 17 |
+
BEGIN
|
| 18 |
+
INSERT INTO public.profiles (id, display_name, avatar_url)
|
| 19 |
+
VALUES (
|
| 20 |
+
NEW.id,
|
| 21 |
+
COALESCE(NEW.raw_user_meta_data->>'full_name', NEW.email),
|
| 22 |
+
NEW.raw_user_meta_data->>'avatar_url'
|
| 23 |
+
);
|
| 24 |
+
RETURN NEW;
|
| 25 |
+
END;
|
| 26 |
+
$$ LANGUAGE plpgsql SECURITY DEFINER;
|
| 27 |
+
|
| 28 |
+
-- Trigger for profile creation
|
| 29 |
+
DROP TRIGGER IF EXISTS on_auth_user_created ON auth.users;
|
| 30 |
+
CREATE TRIGGER on_auth_user_created
|
| 31 |
+
AFTER INSERT ON auth.users
|
| 32 |
+
FOR EACH ROW EXECUTE FUNCTION public.handle_new_user();
|
| 33 |
+
|
| 34 |
+
-- 2. Sessions Table
|
| 35 |
+
CREATE TABLE IF NOT EXISTS public.sessions (
|
| 36 |
+
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
| 37 |
+
user_id UUID NOT NULL REFERENCES auth.users(id) ON DELETE CASCADE,
|
| 38 |
+
title TEXT DEFAULT 'Bài toán mới',
|
| 39 |
+
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
|
| 40 |
+
updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
|
| 41 |
+
);
|
| 42 |
+
|
| 43 |
+
-- Index for sessions
|
| 44 |
+
CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON public.sessions(user_id);
|
| 45 |
+
CREATE INDEX IF NOT EXISTS idx_sessions_updated_at ON public.sessions(updated_at DESC);
|
| 46 |
+
|
| 47 |
+
-- 3. Messages Table
|
| 48 |
+
CREATE TABLE IF NOT EXISTS public.messages (
|
| 49 |
+
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
| 50 |
+
session_id UUID NOT NULL REFERENCES public.sessions(id) ON DELETE CASCADE,
|
| 51 |
+
role TEXT NOT NULL CHECK (role IN ('user', 'assistant', 'system')),
|
| 52 |
+
type TEXT NOT NULL DEFAULT 'text',
|
| 53 |
+
content TEXT NOT NULL,
|
| 54 |
+
metadata JSONB DEFAULT '{}'::jsonb,
|
| 55 |
+
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
|
| 56 |
+
);
|
| 57 |
+
|
| 58 |
+
-- Index for messages
|
| 59 |
+
CREATE INDEX IF NOT EXISTS idx_messages_session_id ON public.messages(session_id);
|
| 60 |
+
CREATE INDEX IF NOT EXISTS idx_messages_created_at ON public.messages(session_id, created_at);
|
| 61 |
+
|
| 62 |
+
-- 4. Session Assets Table (v5.1 Versioning)
|
| 63 |
+
CREATE TABLE IF NOT EXISTS public.session_assets (
|
| 64 |
+
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
| 65 |
+
session_id UUID NOT NULL REFERENCES public.sessions(id) ON DELETE CASCADE,
|
| 66 |
+
job_id UUID NOT NULL,
|
| 67 |
+
asset_type TEXT NOT NULL CHECK (asset_type IN ('video', 'image')),
|
| 68 |
+
storage_path TEXT NOT NULL,
|
| 69 |
+
public_url TEXT NOT NULL,
|
| 70 |
+
version INTEGER NOT NULL DEFAULT 1,
|
| 71 |
+
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
|
| 72 |
+
);
|
| 73 |
+
|
| 74 |
+
-- Index for session_assets
|
| 75 |
+
CREATE INDEX IF NOT EXISTS idx_session_assets_session_id ON public.session_assets(session_id);
|
| 76 |
+
|
| 77 |
+
-- 5. Update Jobs Table
|
| 78 |
+
ALTER TABLE public.jobs ADD COLUMN IF NOT EXISTS user_id UUID REFERENCES auth.users(id);
|
| 79 |
+
ALTER TABLE public.jobs ADD COLUMN IF NOT EXISTS session_id UUID REFERENCES public.sessions(id);
|
| 80 |
+
|
| 81 |
+
-- 6. Row Level Security (RLS)
|
| 82 |
+
ALTER TABLE public.profiles ENABLE ROW LEVEL SECURITY;
|
| 83 |
+
ALTER TABLE public.sessions ENABLE ROW LEVEL SECURITY;
|
| 84 |
+
ALTER TABLE public.messages ENABLE ROW LEVEL SECURITY;
|
| 85 |
+
ALTER TABLE public.jobs ENABLE ROW LEVEL SECURITY;
|
| 86 |
+
ALTER TABLE public.session_assets ENABLE ROW LEVEL SECURITY;
|
| 87 |
+
|
| 88 |
+
-- Polices for public.profiles
|
| 89 |
+
DROP POLICY IF EXISTS "Users view own profile" ON public.profiles;
|
| 90 |
+
CREATE POLICY "Users view own profile" ON public.profiles FOR SELECT USING (auth.uid() = id);
|
| 91 |
+
DROP POLICY IF EXISTS "Users update own profile" ON public.profiles;
|
| 92 |
+
CREATE POLICY "Users update own profile" ON public.profiles FOR UPDATE USING (auth.uid() = id);
|
| 93 |
+
|
| 94 |
+
-- Policies for public.sessions
|
| 95 |
+
DROP POLICY IF EXISTS "Users manage own sessions" ON public.sessions;
|
| 96 |
+
CREATE POLICY "Users manage own sessions" ON public.sessions FOR ALL USING (auth.uid() = user_id);
|
| 97 |
+
|
| 98 |
+
-- Policies for public.messages
|
| 99 |
+
DROP POLICY IF EXISTS "Users manage own messages" ON public.messages;
|
| 100 |
+
CREATE POLICY "Users manage own messages" ON public.messages FOR ALL USING (
|
| 101 |
+
session_id IN (SELECT id FROM public.sessions WHERE user_id = auth.uid())
|
| 102 |
+
OR (auth.jwt() ->> 'role' = 'service_role')
|
| 103 |
+
);
|
| 104 |
+
|
| 105 |
+
-- Policies for public.session_assets
|
| 106 |
+
DROP POLICY IF EXISTS "Users view own assets" ON public.session_assets;
|
| 107 |
+
CREATE POLICY "Users view own assets" ON public.session_assets FOR SELECT USING (
|
| 108 |
+
session_id IN (SELECT id FROM public.sessions WHERE user_id = auth.uid())
|
| 109 |
+
);
|
| 110 |
+
DROP POLICY IF EXISTS "Service role manages assets" ON public.session_assets;
|
| 111 |
+
CREATE POLICY "Service role manages assets" ON public.session_assets FOR ALL USING (true);
|
| 112 |
+
|
| 113 |
+
-- Policies for public.jobs
|
| 114 |
+
DROP POLICY IF EXISTS "Users manage own jobs" ON public.jobs;
|
| 115 |
+
CREATE POLICY "Users manage own jobs" ON public.jobs FOR ALL USING (
|
| 116 |
+
auth.uid() = user_id OR user_id IS NULL OR (auth.jwt() ->> 'role' = 'service_role')
|
| 117 |
+
);
|
| 118 |
+
|
| 119 |
+
-- 7. Storage Policies (Bucket: video)
|
| 120 |
+
-- (Run this in Supabase Dashboard if not allowed in migration)
|
| 121 |
+
-- INSERT INTO storage.buckets (id, name, public) VALUES ('video', 'video', true) ON CONFLICT (id) DO NOTHING;
|
| 122 |
+
-- CREATE POLICY "Service Role manage videos" ON storage.objects FOR ALL TO service_role USING (bucket_id = 'video');
|
| 123 |
+
-- CREATE POLICY "Public read videos" ON storage.objects FOR SELECT TO public USING (bucket_id = 'video');
|
| 124 |
+
|
| 125 |
+
-- Grant permissions to public/authenticated
|
| 126 |
+
GRANT ALL ON public.profiles TO authenticated;
|
| 127 |
+
GRANT ALL ON public.sessions TO authenticated;
|
| 128 |
+
GRANT ALL ON public.messages TO authenticated;
|
| 129 |
+
GRANT ALL ON public.jobs TO authenticated;
|
| 130 |
+
GRANT ALL ON public.session_assets TO authenticated;
|
| 131 |
+
GRANT ALL ON public.session_assets TO service_role;
|
pytest.ini
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[pytest]
|
| 2 |
+
asyncio_mode = auto
|
| 3 |
+
testpaths = tests
|
| 4 |
+
pythonpath = .
|
| 5 |
+
filterwarnings =
|
| 6 |
+
ignore::DeprecationWarning
|
| 7 |
+
|
| 8 |
+
markers =
|
| 9 |
+
real_api: HTTP tests need running backend and TEST_USER_ID / TEST_SESSION_ID.
|
| 10 |
+
real_worker_ocr: OCR Celery task or full OCR stack (heavy).
|
| 11 |
+
real_worker_manim: Real Manim render and Supabase video upload.
|
| 12 |
+
real_agents: Live LLM / orchestrator agent calls.
|
| 13 |
+
slow: Large suite or long polling timeouts.
|
| 14 |
+
smoke: Fast API health + one solve job.
|
| 15 |
+
orchestrator_local: In-process Orchestrator without HTTP server.
|
| 16 |
+
|
| 17 |
+
# Default: skip integration tests that need services, keys, or long runs.
|
| 18 |
+
addopts = -m "not real_api and not real_worker_ocr and not real_worker_manim and not real_agents and not slow and not orchestrator_local"
|
requirements.txt
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Target: Python 3.11 (see Dockerfile). Used by: FastAPI API, Celery worker, Manim render, OCR/vision stack.
|
| 2 |
+
# Install: pip install -r requirements.txt
|
| 3 |
+
|
| 4 |
+
# --- Dev / test ---
|
| 5 |
+
pytest>=8.0
|
| 6 |
+
pytest-asyncio>=0.24
|
| 7 |
+
|
| 8 |
+
# --- HTTP API ---
|
| 9 |
+
cachetools>=5.3
|
| 10 |
+
fastapi>=0.115,<1
|
| 11 |
+
uvicorn[standard]>=0.30
|
| 12 |
+
python-multipart>=0.0.9
|
| 13 |
+
python-dotenv>=1.0
|
| 14 |
+
pydantic[email]>=2.4
|
| 15 |
+
email-validator>=2
|
| 16 |
+
|
| 17 |
+
# --- Auth / data / queue ---
|
| 18 |
+
openai>=1.40
|
| 19 |
+
supabase>=2.0
|
| 20 |
+
celery>=5.3
|
| 21 |
+
redis>=5
|
| 22 |
+
httpx>=0.27
|
| 23 |
+
websockets>=12
|
| 24 |
+
|
| 25 |
+
# --- Math & symbolic solver ---
|
| 26 |
+
sympy>=1.12
|
| 27 |
+
numpy>=1.26,<2
|
| 28 |
+
scipy>=1.11
|
| 29 |
+
opencv-python-headless>=4.8,<4.10
|
| 30 |
+
|
| 31 |
+
# --- Video (GeometryScene via CLI) ---
|
| 32 |
+
manim>=0.18,<0.20
|
| 33 |
+
|
| 34 |
+
# --- OCR & vision (orchestrator / legacy /ocr) ---
|
| 35 |
+
pix2tex>=0.1.4
|
| 36 |
+
paddleocr==2.7.3
|
| 37 |
+
paddlepaddle==2.6.2
|
| 38 |
+
ultralytics==8.2.2
|
requirements.worker-ocr.txt
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# OCR-only Celery worker: YOLO + PaddleOCR + Pix2Tex (no OpenRouter / no Manim).
|
| 2 |
+
# Install: pip install -r requirements.worker-ocr.txt
|
| 3 |
+
|
| 4 |
+
cachetools>=5.3
|
| 5 |
+
fastapi>=0.115,<1
|
| 6 |
+
uvicorn[standard]>=0.30
|
| 7 |
+
python-multipart>=0.0.9
|
| 8 |
+
python-dotenv>=1.0
|
| 9 |
+
pydantic[email]>=2.4
|
| 10 |
+
email-validator>=2
|
| 11 |
+
|
| 12 |
+
celery>=5.3
|
| 13 |
+
redis>=5
|
| 14 |
+
httpx>=0.27
|
| 15 |
+
websockets>=12
|
| 16 |
+
|
| 17 |
+
numpy>=1.26,<2
|
| 18 |
+
opencv-python-headless>=4.8,<4.10
|
| 19 |
+
|
| 20 |
+
pix2tex>=0.1.4
|
| 21 |
+
paddleocr==2.7.3
|
| 22 |
+
paddlepaddle==2.6.2
|
| 23 |
+
ultralytics==8.2.2
|
requirements.worker-render.txt
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Celery render worker: Manim + Supabase (no OpenAI / SymPy / OCR vision stack).
|
| 2 |
+
# Install: pip install -r requirements.worker-render.txt
|
| 3 |
+
# Includes FastAPI/uvicorn for worker_health.py (HF Spaces).
|
| 4 |
+
|
| 5 |
+
cachetools>=5.3
|
| 6 |
+
fastapi>=0.115,<1
|
| 7 |
+
uvicorn[standard]>=0.30
|
| 8 |
+
python-multipart>=0.0.9
|
| 9 |
+
python-dotenv>=1.0
|
| 10 |
+
pydantic[email]>=2.4
|
| 11 |
+
email-validator>=2
|
| 12 |
+
|
| 13 |
+
celery>=5.3
|
| 14 |
+
redis>=5
|
| 15 |
+
httpx>=0.27
|
| 16 |
+
websockets>=12
|
| 17 |
+
|
| 18 |
+
supabase>=2.0
|
| 19 |
+
|
| 20 |
+
numpy>=1.26,<2
|
| 21 |
+
manim>=0.18,<0.20
|
run_api_test.sh
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
LOG_FILE="api_test_results.log"
|
| 4 |
+
echo "=== Starting API E2E Test Suite ($(date)) ===" > $LOG_FILE
|
| 5 |
+
|
| 6 |
+
# 1. Start BE Server in background
|
| 7 |
+
echo "[INFO] Starting Backend Server..." | tee -a $LOG_FILE
|
| 8 |
+
export ALLOW_TEST_BYPASS=true
|
| 9 |
+
export LOG_LEVEL=info
|
| 10 |
+
export CELERY_TASK_ALWAYS_EAGER=true
|
| 11 |
+
export CELERY_RESULT_BACKEND=rpc://
|
| 12 |
+
export MOCK_VIDEO=true
|
| 13 |
+
PYTHONPATH=. venv/bin/python -m uvicorn app.main:app --port 8000 > server_debug.log 2>&1 &
|
| 14 |
+
SERVER_PID=$!
|
| 15 |
+
|
| 16 |
+
# 2. Wait for server to be ready
|
| 17 |
+
echo "[INFO] Waiting for server (PID: $SERVER_PID) on port 8000..." | tee -a $LOG_FILE
|
| 18 |
+
MAX_RETRIES=15
|
| 19 |
+
READY=0
|
| 20 |
+
for i in $(seq 1 $MAX_RETRIES); do
|
| 21 |
+
if curl -s http://localhost:8000/ > /dev/null; then
|
| 22 |
+
READY=1
|
| 23 |
+
break
|
| 24 |
+
fi
|
| 25 |
+
sleep 2
|
| 26 |
+
done
|
| 27 |
+
|
| 28 |
+
if [ $READY -eq 0 ]; then
|
| 29 |
+
echo "[ERROR] Server failed to start in time. Check server_debug.log" | tee -a $LOG_FILE
|
| 30 |
+
kill $SERVER_PID
|
| 31 |
+
exit 1
|
| 32 |
+
fi
|
| 33 |
+
echo "[INFO] Server is READY." | tee -a $LOG_FILE
|
| 34 |
+
|
| 35 |
+
# 3. Prepare Test Data
|
| 36 |
+
echo "[INFO] Preparing fresh test data..." | tee -a $LOG_FILE
|
| 37 |
+
PREP_OUTPUT=$(PYTHONPATH=. venv/bin/python scripts/prepare_api_test.py)
|
| 38 |
+
echo "$PREP_OUTPUT" >> $LOG_FILE
|
| 39 |
+
|
| 40 |
+
export TEST_USER_ID=$(echo "$PREP_OUTPUT" | grep "RESULT:USER_ID=" | cut -d'=' -f2)
|
| 41 |
+
export TEST_SESSION_ID=$(echo "$PREP_OUTPUT" | grep "RESULT:SESSION_ID=" | cut -d'=' -f2)
|
| 42 |
+
|
| 43 |
+
if [ -z "$TEST_USER_ID" ] || [ -z "$TEST_SESSION_ID" ]; then
|
| 44 |
+
echo "[ERROR] Failed to prepare test data." | tee -a $LOG_FILE
|
| 45 |
+
kill $SERVER_PID
|
| 46 |
+
exit 1
|
| 47 |
+
fi
|
| 48 |
+
|
| 49 |
+
echo "[INFO] Test Data: User=$TEST_USER_ID, Session=$TEST_SESSION_ID" | tee -a $LOG_FILE
|
| 50 |
+
|
| 51 |
+
# 4. Run Pytest
|
| 52 |
+
echo "[INFO] Running API E2E Tests..." | tee -a $LOG_FILE
|
| 53 |
+
PYTHONPATH=. venv/bin/python -m pytest tests/test_api_real_e2e.py -m "smoke and real_api" -s \
|
| 54 |
+
--junitxml=pytest_smoke.xml >> $LOG_FILE 2>&1
|
| 55 |
+
TEST_EXIT_CODE=$?
|
| 56 |
+
|
| 57 |
+
# 5. Cleanup
|
| 58 |
+
echo "[INFO] Shutting down Server..." | tee -a $LOG_FILE
|
| 59 |
+
kill $SERVER_PID
|
| 60 |
+
|
| 61 |
+
echo "==========================================" | tee -a $LOG_FILE
|
| 62 |
+
if [ $TEST_EXIT_CODE -eq 0 ]; then
|
| 63 |
+
echo "FINAL RESULT: ✅ ALL API TESTS PASSED" | tee -a $LOG_FILE
|
| 64 |
+
else
|
| 65 |
+
echo "FINAL RESULT: ❌ SOME API TESTS FAILED (Code: $TEST_EXIT_CODE)" | tee -a $LOG_FILE
|
| 66 |
+
fi
|
| 67 |
+
echo "==========================================" | tee -a $LOG_FILE
|
| 68 |
+
|
| 69 |
+
exit $TEST_EXIT_CODE
|
run_full_api_test.sh
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Full API integration (CI-style): eager Celery + mock video + full HTTP suite.
|
| 3 |
+
LOG_FILE="full_api_suite.log"
|
| 4 |
+
REPORT_FILE="full_api_test_report.md"
|
| 5 |
+
JSON_RESULTS="temp_suite_results.json"
|
| 6 |
+
JUNIT="pytest_api_suite.xml"
|
| 7 |
+
|
| 8 |
+
echo "=== Starting Full API Suite Test ($(date)) ===" >"$LOG_FILE"
|
| 9 |
+
|
| 10 |
+
trap 'echo "[INFO] Cleaning up processes..."; kill $SERVER_PID 2>/dev/null; sleep 1' EXIT
|
| 11 |
+
|
| 12 |
+
echo "[INFO] Starting Backend Server (EAGER + MOCK_VIDEO)..." | tee -a "$LOG_FILE"
|
| 13 |
+
export ALLOW_TEST_BYPASS=true
|
| 14 |
+
export LOG_LEVEL=info
|
| 15 |
+
export CELERY_TASK_ALWAYS_EAGER=true
|
| 16 |
+
export CELERY_RESULT_BACKEND=rpc://
|
| 17 |
+
export MOCK_VIDEO=true
|
| 18 |
+
PYTHONPATH=. venv/bin/python -m uvicorn app.main:app --port 8000 >server_debug.log 2>&1 &
|
| 19 |
+
SERVER_PID=$!
|
| 20 |
+
|
| 21 |
+
echo "[INFO] Waiting for server (PID: $SERVER_PID)..." | tee -a "$LOG_FILE"
|
| 22 |
+
for i in {1..20}; do
|
| 23 |
+
if curl -s http://localhost:8000/ >/dev/null; then
|
| 24 |
+
echo "[INFO] Server is READY." | tee -a "$LOG_FILE"
|
| 25 |
+
break
|
| 26 |
+
fi
|
| 27 |
+
sleep 2
|
| 28 |
+
done
|
| 29 |
+
|
| 30 |
+
echo "[INFO] Preparing fresh test data..." | tee -a "$LOG_FILE"
|
| 31 |
+
PREP_OUTPUT=$(PYTHONPATH=. venv/bin/python scripts/prepare_api_test.py)
|
| 32 |
+
export TEST_USER_ID=$(echo "$PREP_OUTPUT" | grep "RESULT:USER_ID=" | cut -d'=' -f2)
|
| 33 |
+
export TEST_SESSION_ID=$(echo "$PREP_OUTPUT" | grep "RESULT:SESSION_ID=" | cut -d'=' -f2)
|
| 34 |
+
|
| 35 |
+
if [ -z "$TEST_USER_ID" ]; then
|
| 36 |
+
echo "[ERROR] Failed to prepare test data." | tee -a "$LOG_FILE"
|
| 37 |
+
exit 1
|
| 38 |
+
fi
|
| 39 |
+
|
| 40 |
+
echo "[INFO] Executing API tests (smoke + full suite)..." | tee -a "$LOG_FILE"
|
| 41 |
+
PYTHONPATH=. venv/bin/python -m pytest tests/test_api_real_e2e.py tests/test_api_full_suite.py \
|
| 42 |
+
-m "real_api" -s --tb=short --junitxml="$JUNIT" >>"$LOG_FILE" 2>&1
|
| 43 |
+
TEST_EXIT_CODE=$?
|
| 44 |
+
|
| 45 |
+
echo "[INFO] Generating Markdown Report..." | tee -a "$LOG_FILE"
|
| 46 |
+
if [ -f "$JSON_RESULTS" ]; then
|
| 47 |
+
PYTHONPATH=. venv/bin/python scripts/generate_report.py "$JSON_RESULTS" "$REPORT_FILE" "$JUNIT"
|
| 48 |
+
else
|
| 49 |
+
echo "[WARN] $JSON_RESULTS not found" | tee -a "$LOG_FILE"
|
| 50 |
+
fi
|
| 51 |
+
|
| 52 |
+
echo "==========================================" | tee -a "$LOG_FILE"
|
| 53 |
+
echo "DONE. Check $REPORT_FILE for results." | tee -a "$LOG_FILE"
|
| 54 |
+
echo "==========================================" | tee -a "$LOG_FILE"
|
| 55 |
+
|
| 56 |
+
exit $TEST_EXIT_CODE
|
scripts/benchmark_openrouter.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Benchmark several OpenRouter models (manual tool; not part of pytest)."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
import time
|
| 8 |
+
|
| 9 |
+
import httpx
|
| 10 |
+
from dotenv import load_dotenv
|
| 11 |
+
|
| 12 |
+
_BACKEND_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 13 |
+
load_dotenv(os.path.join(_BACKEND_ROOT, ".env"))
|
| 14 |
+
|
| 15 |
+
MODELS = [
|
| 16 |
+
"nvidia/nemotron-3-super-120b-a12b:free",
|
| 17 |
+
"meta-llama/llama-3.3-70b-instruct:free",
|
| 18 |
+
"openai/gpt-oss-120b:free",
|
| 19 |
+
"z-ai/glm-4.5-air:free",
|
| 20 |
+
"minimax/minimax-m2.5:free",
|
| 21 |
+
"google/gemma-4-26b-a4b-it:free",
|
| 22 |
+
"google/gemma-4-31b-it:free",
|
| 23 |
+
]
|
| 24 |
+
|
| 25 |
+
PROMPT = (
|
| 26 |
+
"Cho hình chữ nhật ABCD có AB bằng 5 và AD bằng 10. Gọi E là điểm nằm trong đoạn CD sao cho CE = 2ED. "
|
| 27 |
+
"Vẽ đoạn thẳng AE. Vẽ thêm P là điểm nằm trên đường thẳng BC sao cho BP = 2PC, tính chu vi tam giác PEA"
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def main() -> None:
|
| 32 |
+
api_key = os.getenv("OPENROUTER_API_KEY_1") or os.getenv("OPENROUTER_API_KEY")
|
| 33 |
+
base_url = "https://openrouter.ai/api/v1/chat/completions"
|
| 34 |
+
|
| 35 |
+
if not api_key:
|
| 36 |
+
print("Missing OPENROUTER_API_KEY_1 or OPENROUTER_API_KEY in .env")
|
| 37 |
+
return
|
| 38 |
+
|
| 39 |
+
print("Benchmark OpenRouter models\nPrompt:", PROMPT, "\n")
|
| 40 |
+
results = []
|
| 41 |
+
|
| 42 |
+
for model in MODELS:
|
| 43 |
+
print(f"Calling {model}...", end="", flush=True)
|
| 44 |
+
headers = {
|
| 45 |
+
"Authorization": f"Bearer {api_key}",
|
| 46 |
+
"Content-Type": "application/json",
|
| 47 |
+
"HTTP-Referer": "https://mathsolver.io",
|
| 48 |
+
"X-Title": "MathSolver Benchmark Tool",
|
| 49 |
+
}
|
| 50 |
+
payload = {"model": model, "messages": [{"role": "user", "content": PROMPT}]}
|
| 51 |
+
start = time.time()
|
| 52 |
+
try:
|
| 53 |
+
with httpx.Client(timeout=120.0) as client:
|
| 54 |
+
r = client.post(base_url, headers=headers, json=payload)
|
| 55 |
+
r.raise_for_status()
|
| 56 |
+
data = r.json()
|
| 57 |
+
answer = data["choices"][0]["message"]["content"]
|
| 58 |
+
duration = time.time() - start
|
| 59 |
+
results.append(
|
| 60 |
+
{"model": model, "duration": duration, "answer": answer, "status": "success"}
|
| 61 |
+
)
|
| 62 |
+
print(f" OK ({duration:.2f}s)")
|
| 63 |
+
except Exception as e:
|
| 64 |
+
duration = time.time() - start
|
| 65 |
+
results.append(
|
| 66 |
+
{"model": model, "duration": duration, "error": str(e), "status": "error"}
|
| 67 |
+
)
|
| 68 |
+
print(f" FAIL ({duration:.2f}s) {e}")
|
| 69 |
+
|
| 70 |
+
print("\n" + "=" * 80)
|
| 71 |
+
for res in results:
|
| 72 |
+
print(json.dumps(res, ensure_ascii=False, indent=2)[:2000])
|
| 73 |
+
print("-" * 40)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
if __name__ == "__main__":
|
| 77 |
+
main()
|
scripts/generate_report.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
import xml.etree.ElementTree as ET
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def _parse_junit_xml(path: str) -> dict:
|
| 9 |
+
"""Summarize pytest junitxml (JUnit) file."""
|
| 10 |
+
out = {"tests": 0, "failures": 0, "errors": 0, "skipped": 0, "time": 0.0}
|
| 11 |
+
try:
|
| 12 |
+
tree = ET.parse(path)
|
| 13 |
+
root = tree.getroot()
|
| 14 |
+
nodes = [root] if root.tag == "testsuite" else list(root.iter("testsuite"))
|
| 15 |
+
for ts in nodes:
|
| 16 |
+
if ts.tag != "testsuite":
|
| 17 |
+
continue
|
| 18 |
+
out["tests"] += int(ts.attrib.get("tests", 0) or 0)
|
| 19 |
+
out["failures"] += int(ts.attrib.get("failures", 0) or 0)
|
| 20 |
+
out["errors"] += int(ts.attrib.get("errors", 0) or 0)
|
| 21 |
+
out["skipped"] += int(ts.attrib.get("skipped", 0) or 0)
|
| 22 |
+
out["time"] += float(ts.attrib.get("time", 0) or 0)
|
| 23 |
+
except Exception as e:
|
| 24 |
+
out["parse_error"] = str(e)
|
| 25 |
+
return out
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def generate_report(json_path: str, report_path: str, junit_path: str | None = None) -> None:
|
| 29 |
+
try:
|
| 30 |
+
with open(json_path, "r", encoding="utf-8") as f:
|
| 31 |
+
data = json.load(f)
|
| 32 |
+
|
| 33 |
+
junit_summary = None
|
| 34 |
+
if junit_path and os.path.isfile(junit_path):
|
| 35 |
+
junit_summary = _parse_junit_xml(junit_path)
|
| 36 |
+
|
| 37 |
+
with open(report_path, "w", encoding="utf-8") as f:
|
| 38 |
+
f.write("# Báo cáo Kiểm thử tích hợp Backend (Integration Report)\n\n")
|
| 39 |
+
f.write(f"**Thời gian chạy:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
|
| 40 |
+
suite_ok = all(r.get("success", False) for r in data) if isinstance(data, list) else False
|
| 41 |
+
f.write(f"**API suite (JSON):** {'PASS' if suite_ok else 'FAIL'}\n")
|
| 42 |
+
|
| 43 |
+
if junit_summary and "parse_error" not in junit_summary:
|
| 44 |
+
j_ok = junit_summary["failures"] == 0 and junit_summary["errors"] == 0
|
| 45 |
+
f.write(
|
| 46 |
+
f"**Pytest (JUnit):** {'PASS' if j_ok else 'FAIL'} — "
|
| 47 |
+
f"tests={junit_summary['tests']}, failures={junit_summary['failures']}, "
|
| 48 |
+
f"errors={junit_summary['errors']}, skipped={junit_summary['skipped']}, "
|
| 49 |
+
f"time_s={junit_summary['time']:.2f}\n"
|
| 50 |
+
)
|
| 51 |
+
elif junit_summary and "parse_error" in junit_summary:
|
| 52 |
+
f.write(f"**Pytest (JUnit):** (could not parse: {junit_summary['parse_error']})\n")
|
| 53 |
+
|
| 54 |
+
f.write("\n")
|
| 55 |
+
|
| 56 |
+
f.write("| ID | Câu hỏi (Query) | Trạng thái | Thời gian (s) | Kết quả / Lỗi |\n")
|
| 57 |
+
f.write("| :--- | :--- | :--- | :--- | :--- |\n")
|
| 58 |
+
for r in data:
|
| 59 |
+
status = "PASS" if r.get("success") else "FAIL"
|
| 60 |
+
elapsed = f"{float(r.get('elapsed', 0) or 0):.2f}"
|
| 61 |
+
query = r.get("query", "-")
|
| 62 |
+
|
| 63 |
+
res = r.get("result", {})
|
| 64 |
+
if not isinstance(res, dict):
|
| 65 |
+
res = {}
|
| 66 |
+
|
| 67 |
+
analysis = res.get("semantic_analysis", "-")
|
| 68 |
+
if not r.get("success"):
|
| 69 |
+
analysis = f"**Lỗi:** {r.get('error', '-')}"
|
| 70 |
+
|
| 71 |
+
short_analysis = (analysis[:100] + "...") if len(str(analysis)) > 100 else analysis
|
| 72 |
+
|
| 73 |
+
f.write(f"| {r['id']} | {query} | {status} | {elapsed} | {short_analysis} |\n")
|
| 74 |
+
|
| 75 |
+
f.write("\n---\n**Chi tiết Output (DSL & Analysis):**\n")
|
| 76 |
+
for r in data:
|
| 77 |
+
if not r.get("success"):
|
| 78 |
+
continue
|
| 79 |
+
res = r.get("result", {})
|
| 80 |
+
if not isinstance(res, dict):
|
| 81 |
+
continue
|
| 82 |
+
|
| 83 |
+
f.write(f"\n### Case {r['id']}: {r.get('query')}\n")
|
| 84 |
+
f.write(f"**Semantic Analysis:**\n{res.get('semantic_analysis', '-')}\n\n")
|
| 85 |
+
f.write(f"**Geometry DSL:**\n```\n{res.get('geometry_dsl', '-')}\n```\n")
|
| 86 |
+
|
| 87 |
+
sol = res.get("solution")
|
| 88 |
+
if sol and isinstance(sol, dict):
|
| 89 |
+
f.write("**Solution (v5.1):**\n")
|
| 90 |
+
f.write(f"- **Answer:** {sol.get('answer', 'N/A')}\n")
|
| 91 |
+
f.write("- **Steps:**\n")
|
| 92 |
+
steps = sol.get("steps", [])
|
| 93 |
+
if steps:
|
| 94 |
+
for step in steps:
|
| 95 |
+
f.write(f" - {step}\n")
|
| 96 |
+
else:
|
| 97 |
+
f.write(" - (Không có bước giải cụ thể)\n")
|
| 98 |
+
|
| 99 |
+
if sol.get("symbolic_expression"):
|
| 100 |
+
f.write(f"- **Symbolic:** `{sol.get('symbolic_expression')}`\n")
|
| 101 |
+
f.write("\n")
|
| 102 |
+
|
| 103 |
+
print(f"Report generated: {report_path}")
|
| 104 |
+
except Exception as e:
|
| 105 |
+
print(f"Error generating report: {e}")
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
if __name__ == "__main__":
|
| 109 |
+
if len(sys.argv) < 3:
|
| 110 |
+
print(
|
| 111 |
+
"Usage: python generate_report.py <json_results> <report_output> [junit_xml_optional]"
|
| 112 |
+
)
|
| 113 |
+
sys.exit(1)
|
| 114 |
+
junit = sys.argv[3] if len(sys.argv) > 3 else None
|
| 115 |
+
generate_report(sys.argv[1], sys.argv[2], junit)
|