Spaces:
Sleeping
Sleeping
| import copy | |
| import json | |
| import os | |
| import re | |
| import sys | |
| import tempfile | |
| import time | |
| from pathlib import Path | |
| from typing import Any | |
| import gradio as gr | |
| from google import genai | |
| try: | |
| import numpy as np | |
| except ImportError: | |
| np = None # type: ignore | |
| from PIL import Image | |
| # Avoid Ultralytics warning on HF Spaces / read-only home (before first YOLO use). | |
| os.environ.setdefault("YOLO_CONFIG_DIR", str(Path(tempfile.gettempdir()) / "Ultralytics")) | |
| gemini_api_key = os.environ.get("Gemini_API") | |
| VERIFY_TOKEN = os.environ.get("VERIFY_TOKEN") | |
| gemini_client = genai.Client(api_key=gemini_api_key) | |
| from ultralytics import YOLO | |
| from gradio_bbox_annotator import BBoxAnnotator | |
| import gradio_bbox_annotator.bbox_annotator as _bbox_annotator_mod | |
| from gradio.data_classes import FileData, GradioModel | |
| from pydantic import field_validator | |
def analyze_student_work_gemini_img(image_path, prompt=None, analysis_prompt=None, knowledge_base_files=None, model_name="gemini-3-pro-preview", debug_mode=False):
    """
    Analyze a student's work image with Gemini and return the response text.

    Parameters
    ----------
    image_path : str
        Path to the student-work image on disk.
    prompt : str | None
        Base instruction; defaults to the module-level ``basic_prompt``.
    analysis_prompt : str | None
        Analysis-format instruction; defaults to ``with_analysis_prompt``.
    knowledge_base_files : object | list | None
        Optional Gemini file object(s) appended to the request contents.
    model_name : str
        Gemini model identifier.
    debug_mode : bool
        When True, prints the full instruction and extra diagnostics.

    Returns
    -------
    str | None
        The model's text output, or None on unrecoverable failure.
    """
    if prompt is None:
        prompt = basic_prompt
    if analysis_prompt is None:
        analysis_prompt = with_analysis_prompt
    if not os.path.exists(image_path):
        print(f"❌ Image file does not exist: {image_path}")
        return None
    print(f"📤 Loading image: {image_path}")
    try:
        img = Image.open(image_path)
        # BUGFIX: Image.open is lazy and keeps the file handle open until GC.
        # Force-decode now so the handle is released and decode errors are
        # caught here rather than deep inside the API call.
        img.load()
    except Exception as e:
        print(f"❌ Failed to load image: {e}")
        return None
    instruction = prompt + "\n" + analysis_prompt
    if debug_mode:
        print(f"instruction: {instruction}")
    print("⏳ Analyzing with Gemini (Image)...")
    # Request contents: text instruction, the image, and optional KB file(s).
    content = [instruction, img]
    if knowledge_base_files:
        if isinstance(knowledge_base_files, list):
            content.extend(knowledge_base_files)
        else:
            content.append(knowledge_base_files)

    def _extract_text(response):
        """Return (text, finish_reasons); text is None when no usable output."""
        text = None
        try:
            text = response.text
        except Exception as e:
            # Some blocked/empty responses may raise when accessing .text.
            if debug_mode:
                print(f"⚠️ response.text access error: {e}")
        if text and text.strip():
            return text, []
        # Fallback: attempt to reconstruct text from response candidates.
        fallback_parts = []
        finish_reasons = []
        for cand in getattr(response, "candidates", None) or []:
            finish_reason = getattr(cand, "finish_reason", None)
            if finish_reason is not None:
                finish_reasons.append(str(finish_reason))
            cand_content = getattr(cand, "content", None)
            parts = getattr(cand_content, "parts", None) if cand_content else None
            for part in parts or []:
                part_text = getattr(part, "text", None)
                if part_text and part_text.strip():
                    fallback_parts.append(part_text)
        if fallback_parts:
            merged = "\n".join(fallback_parts).strip()
            if merged:
                print("⚠️ response.text is empty; using candidate parts fallback.")
                return merged, finish_reasons
        return None, finish_reasons

    max_retries = 3
    # Substrings that mark a transient API failure worth retrying.
    retryable_error_markers = (
        "429",
        "RESOURCE_EXHAUSTED",
        "UNAVAILABLE",
        "DEADLINE_EXCEEDED",
        "INTERNAL",
        "500",
        "503",
        "TIMEOUT",
    )
    for attempt in range(1, max_retries + 1):
        try:
            response = gemini_client.models.generate_content(
                model=model_name,
                contents=content,
            )
            text, finish_reasons = _extract_text(response)
            if text is not None:
                return text
            print(
                f"❌ No text output from Gemini (attempt {attempt}/{max_retries}). "
                f"finish_reasons={finish_reasons if finish_reasons else 'N/A'}"
            )
            if attempt < max_retries:
                # Exponential backoff: 1s, 2s, 4s...
                time.sleep(2 ** (attempt - 1))
                continue
            return None
        except Exception as e:
            is_retryable = any(marker in str(e).upper() for marker in retryable_error_markers)
            print(f"❌ Gemini API call failed (attempt {attempt}/{max_retries}): {e}")
            if is_retryable and attempt < max_retries:
                sleep_s = 2 ** (attempt - 1)
                print(f"⏳ Retrying after {sleep_s}s...")
                time.sleep(sleep_s)
                continue
            return None
def analyze_student_work_wout_image_gemini(json_data, prompt=None, analysis_prompt=None, knowledge_base_files=None, model_name="gemini-3-pro-preview", debug_mode=False):
    """
    Analyze student work from ground-truth JSON (no image) with Gemini.

    Parameters
    ----------
    json_data : dict | list
        Ground-truth annotation data; serialized into the prompt as JSON.
    prompt : str | None
        Base instruction; defaults to ``basic_prompt_wout_image``.
    analysis_prompt : str | None
        Analysis-format instruction; defaults to ``with_analysis_prompt``.
    knowledge_base_files : object | list | None
        Optional Gemini file object(s) appended to the request contents.
    model_name : str
        Gemini model identifier.
    debug_mode : bool
        When True, prints the full instruction and extra diagnostics.

    Returns
    -------
    str | None
        The model's text output, or None on unrecoverable failure.
    """
    if prompt is None:
        prompt = basic_prompt_wout_image
    if analysis_prompt is None:
        analysis_prompt = with_analysis_prompt
    if not json_data:
        print("❌ json_data is empty")
        return None
    json_ret_str = json.dumps(json_data, indent=2, ensure_ascii=False)
    instruction = f"""{prompt}
{analysis_prompt}
INPUT_JSON:
{json_ret_str}"""
    if debug_mode:
        print(f"instruction: {instruction}")
    print("⏳ Analyzing with Gemini (JSON)...")
    content = [instruction]
    if knowledge_base_files:
        if isinstance(knowledge_base_files, list):
            content.extend(knowledge_base_files)
        else:
            content.append(knowledge_base_files)

    def _extract_text(response):
        """Return (text, finish_reasons); text is None when no usable output."""
        text = None
        try:
            text = response.text
        except Exception as e:
            # Some blocked/empty responses may raise when accessing .text.
            if debug_mode:
                print(f"⚠️ response.text access error: {e}")
        if text and text.strip():
            return text, []
        # Fallback: attempt to reconstruct text from response candidates.
        fallback_parts = []
        finish_reasons = []
        for cand in getattr(response, "candidates", None) or []:
            finish_reason = getattr(cand, "finish_reason", None)
            if finish_reason is not None:
                finish_reasons.append(str(finish_reason))
            cand_content = getattr(cand, "content", None)
            parts = getattr(cand_content, "parts", None) if cand_content else None
            for part in parts or []:
                part_text = getattr(part, "text", None)
                if part_text and part_text.strip():
                    fallback_parts.append(part_text)
        if fallback_parts:
            merged = "\n".join(fallback_parts).strip()
            if merged:
                print("⚠️ response.text is empty; using candidate parts fallback.")
                return merged, finish_reasons
        return None, finish_reasons

    max_retries = 3
    # Substrings that mark a transient API failure worth retrying.
    retryable_error_markers = (
        "429",
        "RESOURCE_EXHAUSTED",
        "UNAVAILABLE",
        "DEADLINE_EXCEEDED",
        "INTERNAL",
        "500",
        "503",
        "TIMEOUT",
    )
    for attempt in range(1, max_retries + 1):
        try:
            response = gemini_client.models.generate_content(
                model=model_name,
                contents=content,
            )
            text, finish_reasons = _extract_text(response)
            if text is not None:
                return text
            print(
                f"❌ No text output from Gemini (attempt {attempt}/{max_retries}). "
                f"finish_reasons={finish_reasons if finish_reasons else 'N/A'}"
            )
            if attempt < max_retries:
                # Exponential backoff: 1s, 2s, 4s...
                time.sleep(2 ** (attempt - 1))
                continue
            return None
        except Exception as e:
            is_retryable = any(marker in str(e).upper() for marker in retryable_error_markers)
            print(f"❌ Gemini API call failed (attempt {attempt}/{max_retries}): {e}")
            if is_retryable and attempt < max_retries:
                sleep_s = 2 ** (attempt - 1)
                print(f"⏳ Retrying after {sleep_s}s...")
                time.sleep(sleep_s)
                continue
            return None
def _patch_bbox_annotator_coerce_float_coords() -> None:
    """
    Gradio 5 + pydantic v2: BBoxAnnotator's data model expects int bbox coords.
    The annotator frontend sends floats → ValidationError before our callbacks run.
    Replace Annotation/AnnotatedImage with versions that round to int on validate.
    """
    try:
        class Annotation(GradioModel):
            left: int
            top: int
            right: int
            bottom: int
            label: str | None

            # BUGFIX: this hook existed but was never registered as a pydantic
            # validator, so float coords still raised ValidationError and the
            # patch was a no-op. mode="before" coerces the raw payload value
            # before the int field check runs.
            @field_validator("left", "top", "right", "bottom", mode="before")
            @classmethod
            def _round_to_int(cls, v: Any) -> int:
                if v is None:
                    return 0
                return int(round(float(v)))

            def width(self) -> int:
                return self.right - self.left

            def height(self) -> int:
                return self.bottom - self.top

        class AnnotatedImage(GradioModel):
            image: FileData
            annotations: list[Annotation]

        # Swap the component's models so preprocess uses the tolerant versions.
        _bbox_annotator_mod.Annotation = Annotation
        _bbox_annotator_mod.AnnotatedImage = AnnotatedImage
        BBoxAnnotator.data_model = AnnotatedImage
    except Exception as exc:
        print(f"[MathNet] BBoxAnnotator float→int patch skipped: {exc}")


_patch_bbox_annotator_coerce_float_coords()
# Canonical annotation categories; any other label is stored under
# "customize_label" with a free-form custom name.
STANDARD_LABELS = ["tick", "fraction", "zero", "one", "customize_label"]
# Human-readable error taxonomy for number-line fraction work.
ERROR_OPTIONS = [
    "Error 1: Unequal segmentation",
    "Error 2: Wrong number of segments (denominator)",
    "Error 3: Wrong number of segments chosen (numerator)",
    "Error 4: Incorrect tick labels (confused with whole numbers)",
    "Error 5: Incorrect tick labels (careless)",
    "Error 6: Incorrect 0 or 1s",
    "Error 7: Unit of 1 and subunits",
]
# Local YOLO weight checkpoints, keyed by detector nickname.
MODEL_PATHS = {
    "yolo_accurate": "./vt_dataset_yolov12_v7_weights.pt",
    "yolo_extra_large": "./VT_dataset_2_Yolov12_Extra_large.pt",
}
| # --------------------------------------------------------------------------- | |
| # Import prompts from variables.py (tempScript/) – fall back to minimal stubs | |
| # --------------------------------------------------------------------------- | |
| def _resolve_temp_script_dir() -> Path | None: | |
| here = Path(__file__).resolve() | |
| candidates = [ | |
| here.parent / "tempScript", | |
| here.parent.parent / "tempScript", | |
| here.parent.parent.parent / "tempScript", | |
| Path.cwd() / "tempScript", | |
| ] | |
| for candidate in candidates: | |
| if candidate.exists() and candidate.is_dir(): | |
| return candidate | |
| return None | |
# Make tempScript/ importable so the real prompt texts in variables.py are used;
# the except branch installs minimal stand-ins when that import fails.
TEMP_SCRIPT_DIR = _resolve_temp_script_dir()
if TEMP_SCRIPT_DIR and str(TEMP_SCRIPT_DIR) not in sys.path:
    sys.path.append(str(TEMP_SCRIPT_DIR))
try:
    from variables import (
        basic_prompt,
        basic_prompt_wout_image,
        advanced_prompt_wout_image,
        with_analysis_prompt,
        without_analysis_prompt,
        gemini_with_analysis_multiple_error_prompt,
        gemini_without_analysis_multiple_error_prompt,
        get_gemini_one_error_prompt_with_analysis,
        get_gemini_one_error_prompt_without_analysis,
    )
except Exception as _imp_err:
    print(f"[MathNet] Could not import prompts from variables.py: {_imp_err}")
    # Fallback stub prompts so the app still boots without tempScript/.
    basic_prompt = "Please analyze the error type of the student's work based on the definition in the knowledge base."
    basic_prompt_wout_image = "You are analyzing a student's number-line fraction homework."
    advanced_prompt_wout_image = basic_prompt_wout_image
    with_analysis_prompt = "Output ONLY a valid JSON object with error_types array."
    without_analysis_prompt = with_analysis_prompt
    gemini_with_analysis_multiple_error_prompt = with_analysis_prompt
    gemini_without_analysis_multiple_error_prompt = without_analysis_prompt
    def get_gemini_one_error_prompt_with_analysis(t: str) -> str:
        """Stub: single-error-type prompt (with analysis) for error *t*."""
        return f"Check ONLY for {t}. Output JSON with analysis and error_type array."
    def get_gemini_one_error_prompt_without_analysis(t: str) -> str:
        """Stub: single-error-type prompt (without analysis) for error *t*."""
        return f"Check ONLY for {t}. Output JSON with error_type array."
| # --------------------------------------------------------------------------- | |
| # Load knowledge base markdown files from Definitions/ folder | |
| # --------------------------------------------------------------------------- | |
| def _load_knowledge_base() -> str: | |
| """Read all Error *.md files from the Definitions folder and concatenate into a single KB string.""" | |
| here = Path(__file__).resolve() | |
| candidates = [ | |
| here.parent / "Definitions", | |
| here.parent.parent / "Definitions", | |
| here.parent.parent.parent / "Definitions", | |
| Path.cwd() / "Definitions", | |
| ] | |
| kb_dir: Path | None = None | |
| for c in candidates: | |
| if c.exists() and c.is_dir(): | |
| kb_dir = c | |
| break | |
| if kb_dir is None: | |
| print("[MathNet] Definitions folder not found – knowledge base will be empty.") | |
| return "" | |
| md_files = sorted(kb_dir.glob("*.md")) | |
| if not md_files: | |
| print(f"[MathNet] No .md files in {kb_dir}") | |
| return "" | |
| parts: list[str] = [] | |
| for f in md_files: | |
| parts.append(f.read_text(encoding="utf-8").strip()) | |
| kb_text = "\n\n---\n\n".join(parts) | |
| print(f"[MathNet] Loaded knowledge base: {len(md_files)} files, {len(kb_text)} chars") | |
| return kb_text | |
# Concatenated Definitions/*.md text, loaded once at import time.
KNOWLEDGE_BASE_TEXT = _load_knowledge_base()
# Gemini File API handles for the KB; None here — presumably uploaded/populated
# elsewhere in the app (not visible in this chunk).
GEMINI_KB_FILES: list | None = None
def empty_schema() -> dict[str, Any]:
    """Return a fresh, empty ground-truth annotation schema."""
    gt_info: dict[str, Any] = {
        "fraction": {"bbox": [], "values": []},
        "one": {"bbox": []},
        "tick": {"bbox": []},
        "zero": {"bbox": []},
        "relationship": "",
    }
    return {"gt_info": gt_info}
def empty_ui_state() -> dict[str, Any]:
    """Fresh per-session UI state: no image, no boxes, nothing selected."""
    return {
        "image": None,
        "boxes": [],
        "relationship": "",
        "next_id": 0,
        "_last_json_ui_fp": None,
        "_selected_box_idx": None,
    }
def ensure_placeholder_image_path() -> str:
    """Create a tiny valid image file for BBoxAnnotator initial value."""
    target = Path(tempfile.gettempdir()) / "mathnet_placeholder.png"
    if not target.exists():
        # 8x8 near-white PNG, written once and reused across sessions.
        placeholder_img = Image.new("RGB", (8, 8), color=(245, 245, 245))
        placeholder_img.save(target)
    return str(target)
| def coerce_image_path(value: Any) -> str | None: | |
| """ | |
| Gradio 5 / Spaces may pass filepath as str, Path, or FileData-like dict. | |
| BBoxAnnotator postprocess expects a real path string. | |
| """ | |
| if value is None: | |
| return None | |
| if isinstance(value, Path): | |
| p = str(value) | |
| return p if p.strip() else None | |
| if isinstance(value, str): | |
| return value if value.strip() else None | |
| if isinstance(value, dict): | |
| for key in ("path", "name", "url"): | |
| v = value.get(key) | |
| if isinstance(v, str) and v.strip(): | |
| return v | |
| return None | |
| if hasattr(value, "path") and isinstance(getattr(value, "path", None), str): | |
| p = getattr(value, "path") | |
| return p if p.strip() else None | |
| return None | |
def to_python_scalar(x: Any) -> Any:
    """Ensure gr.State / JSON never sees numpy scalars."""
    if np is None:
        # numpy unavailable: nothing to convert.
        return x
    if isinstance(x, np.ndarray):
        return x.tolist()
    if isinstance(x, np.generic):
        return x.item()
    return x
def ensure_list_boxes(raw: Any) -> list[Any]:
    """Coerce an annotator box payload into a plain list ([] on any failure)."""
    if isinstance(raw, list):
        return raw
    if isinstance(raw, tuple):
        return list(raw)
    if isinstance(raw, str):
        # Some component versions may stringify annotations.
        try:
            decoded = json.loads(raw)
        except Exception:
            return []
        return decoded if isinstance(decoded, list) else []
    # None and anything unrecognized collapse to an empty list.
    return []
| def _indices_left_to_right(boxes: list[dict[str, Any]], label_key: str) -> dict[int, int]: | |
| """Map box list index -> 0-based order index for boxes with given label, sorted by xmin (left to right).""" | |
| pairs = [(i, b) for i, b in enumerate(boxes) if b.get("label") == label_key] | |
| pairs.sort(key=lambda ib: float(ib[1]["xmin"])) | |
| return {i: j for j, (i, _) in enumerate(pairs)} | |
| def _center_y(box: dict[str, Any]) -> float: | |
| ymin = float(box.get("ymin", 0.0)) | |
| ymax = float(box.get("ymax", 0.0)) | |
| return (ymin + ymax) / 2.0 | |
def annotation_display_text(
    box: dict[str, Any],
    box_index: int,
    tick_map: dict[int, int],
    frac_map: dict[int, int],
) -> str:
    """
    Text on each bbox outline in the annotator (T0,T1,… / F0-value / zero / …).
    box['label'] in state remains the canonical type (tick, fraction, …); this string is display-only.
    """
    label = box.get("label", "tick")
    if label == "tick":
        return f"T{tick_map.get(box_index, 0)}"
    if label == "fraction":
        value = (box.get("fraction_value") or "").strip()
        return f"F{frac_map.get(box_index, 0)}-{value}"
    if label in ("zero", "one"):
        return label
    if label == "customize_label":
        custom = (box.get("custom_label") or "").strip()
        return custom or "customize_label"
    # Unknown labels (e.g. custom keys loaded from JSON) display verbatim.
    return str(label)
def canonicalize_incoming_annotation_label(raw: Any, prev: dict[str, Any]) -> dict[str, Any]:
    """
    Map annotator label string back to internal label + fields.
    Accepts category names from the tool (tick, fraction, …) or display strings (T0, F1-3/4, …).
    """
    prev = prev or {}
    s = (raw if raw is not None else "").strip()
    if not s:
        # Empty label: keep whatever the previous box carried.
        return {
            "label": str(prev.get("label", "tick")),
            "fraction_value": str(prev.get("fraction_value", "") or ""),
            "custom_label": str(prev.get("custom_label", "") or ""),
        }
    if s in STANDARD_LABELS:
        frac = (prev.get("fraction_value") or "").strip() if s == "fraction" else ""
        custom = (prev.get("custom_label") or "").strip() if s == "customize_label" else ""
        return {"label": s, "fraction_value": frac, "custom_label": custom}
    if re.match(r"^T(\d+)$", s):
        return {"label": "tick", "fraction_value": "", "custom_label": ""}
    frac_match = re.match(r"^F(\d+)-(.*)$", s, re.DOTALL)
    if frac_match:
        return {"label": "fraction", "fraction_value": frac_match.group(2).strip(), "custom_label": ""}
    if s in ("zero", "one"):
        return {"label": s, "fraction_value": "", "custom_label": ""}
    # Anything else is treated as a free-form custom label name.
    return {"label": "customize_label", "fraction_value": "", "custom_label": s}
# CSS injected into the app: floats a small name chip above each .box-preview
# outline (text taken from the data-display attribute set by _BBOX_OVERLAY_JS),
# coloured per annotation category via data-label.
_BBOX_OVERLAY_CSS = """\
.box-preview[data-display] { overflow: visible !important; }
.box-preview[data-display]::before {
    content: attr(data-display);
    position: absolute;
    top: -16px;
    left: 0;
    font-size: 11px;
    font-weight: 700;
    color: #fff;
    padding: 1px 4px;
    border-radius: 3px;
    white-space: nowrap;
    pointer-events: none;
    z-index: 10;
    line-height: 14px;
}
.box-preview[data-label="tick"]::before { background: hsl(180,100%,35%); }
.box-preview[data-label="fraction"]::before { background: hsl(225,100%,35%); }
.box-preview[data-label="zero"]::before { background: hsl(270,100%,35%); }
.box-preview[data-label="one"]::before { background: hsl(315,100%,35%); }
.box-preview[data-label="customize_label"]::before { background: hsl(0,100%,35%); }
"""
# JS run once on app load: (1) periodically renumbers box chips (T0/T1, F0/F1…)
# by sorting same-label boxes left-to-right and writing data-display; (2) wires
# a capture-phase pointerdown handler that pushes the clicked box index into the
# hidden #clicked-box-idx textbox so the Python side can react to selections.
_BBOX_OVERLAY_JS = """\
() => {
    function refresh() {
        const boxes = document.querySelectorAll('.box-preview[data-label]');
        if (!boxes.length) return;
        const byLabel = {};
        boxes.forEach(el => {
            const lbl = el.getAttribute('data-label') || '';
            if (!byLabel[lbl]) byLabel[lbl] = [];
            byLabel[lbl].push(el);
        });
        for (const lbl of Object.keys(byLabel)) {
            const group = byLabel[lbl];
            group.sort((a, b) => parseFloat(a.style.left) - parseFloat(b.style.left));
            group.forEach((el, idx) => {
                let txt;
                if (lbl === 'tick') txt = 'T' + idx;
                else if (lbl === 'fraction') txt = 'F' + idx;
                else txt = lbl;
                if (el.getAttribute('data-display') !== txt)
                    el.setAttribute('data-display', txt);
            });
        }
    }
    setInterval(refresh, 300);
    const tryObs = setInterval(() => {
        const target = document.querySelector('.image-frame');
        if (target) {
            clearInterval(tryObs);
            new MutationObserver(refresh).observe(target,
                { childList: true, subtree: true, attributes: true,
                  attributeFilter: ['style', 'data-label'] });
            refresh();
        }
    }, 200);
    /* Click-to-select: write clicked box index into a hidden Gradio textbox (debounced) */
    let _lastClickTime = 0;
    document.addEventListener('pointerdown', function(e) {
        const now = Date.now();
        if (now - _lastClickTime < 300) return;
        const box = e.target.closest('.box-preview');
        if (!box) return;
        _lastClickTime = now;
        const allBoxes = Array.from(document.querySelectorAll('.box-preview'));
        const idx = allBoxes.indexOf(box);
        if (idx < 0) return;
        const ta = document.querySelector('#clicked-box-idx textarea');
        if (!ta) return;
        const setter = Object.getOwnPropertyDescriptor(
            window.HTMLTextAreaElement.prototype, 'value'
        ).set;
        setter.call(ta, String(idx) + '_' + now);
        ta.dispatchEvent(new Event('input', {bubbles: true}));
    }, true);
}
"""
def compute_auto_relationship(state: dict[str, Any]) -> str:
    """
    For each tick (left-to-right by xmin), pair with the fraction whose vertical center is closest.
    Output format: T0-F1,T1-F0,...
    """
    boxes = state.get("boxes", [])
    tick_map = _indices_left_to_right(boxes, "tick")
    frac_map = _indices_left_to_right(boxes, "fraction")
    if not tick_map or not frac_map:
        return ""
    # (box index, vertical center, fraction order) for every fraction box.
    fraction_entries = [(i, _center_y(boxes[i]), order) for i, order in frac_map.items()]
    pairs: list[str] = []
    for ti in sorted(tick_map, key=lambda i: float(boxes[i]["xmin"])):
        tick_cy = _center_y(boxes[ti])
        # Closest vertical center wins; ties broken by leftmost fraction.
        best = min(
            fraction_entries,
            key=lambda entry: (abs(entry[1] - tick_cy), float(boxes[entry[0]]["xmin"])),
        )
        pairs.append(f"T{tick_map[ti]}-F{best[2]}")
    return ",".join(pairs)
def clamp_box(box: dict[str, Any]) -> dict[str, Any]:
    """Return float {xmin, ymin, xmax, ymax} with each min/max pair ordered."""
    x1 = float(box.get("xmin", 0.0))
    x2 = float(box.get("xmax", 0.0))
    y1 = float(box.get("ymin", 0.0))
    y2 = float(box.get("ymax", 0.0))
    return {
        "xmin": min(x1, x2),
        "ymin": min(y1, y2),
        "xmax": max(x1, x2),
        "ymax": max(y1, y2),
    }
def normalize_annotator_boxes(raw_boxes: list[dict[str, Any]], current_state: dict[str, Any]) -> list[dict[str, Any]]:
    """
    Reconcile boxes returned by the annotator with boxes already in state.

    Re-attaches stable ids (by exact id when usable, otherwise by matching the
    closest unclaimed previous box with the same canonical label), carries
    forward fraction/custom metadata the annotator payload does not round-trip,
    and assigns fresh ids to genuinely new boxes. Mutates
    ``current_state["next_id"]`` and returns the normalized box list, ordered
    to match the previous state where possible.
    """
    by_id = {b["id"]: b for b in current_state.get("boxes", []) if "id" in b}
    existing_boxes = current_state.get("boxes", [])
    out: list[dict[str, Any]] = []
    next_id = int(current_state.get("next_id", 0))
    used_existing_ids: set[int] = set()
    def _coord_distance(a: dict[str, Any], b: dict[str, Any]) -> float:
        # L1 (Manhattan) distance over the four edges — cheap proximity metric.
        return (
            abs(a["xmin"] - b["xmin"])
            + abs(a["ymin"] - b["ymin"])
            + abs(a["xmax"] - b["xmax"])
            + abs(a["ymax"] - b["ymax"])
        )
    for i, box in enumerate(raw_boxes or []):
        normalized = clamp_box(box)
        box_id = box.get("id")
        incoming_label = box.get("label")
        incoming_canon_for_match = canonicalize_incoming_annotation_label(incoming_label, {})
        incoming_canon_label = incoming_canon_for_match.get("label", "tick")
        if box_id is None or box_id not in by_id or box_id in used_existing_ids:
            # No usable id from the annotator: search for the closest
            # unclaimed previous box with the same canonical label.
            best_id = None
            best_dist = float("inf")
            for prev_box in existing_boxes:
                pid = prev_box.get("id")
                if pid is None or pid in used_existing_ids:
                    continue
                prev_label = prev_box.get("label", "tick")
                # Match on canonical label (T0 / tick / fraction / …).
                if incoming_label is not None and str(incoming_canon_label) != str(prev_label):
                    continue
                dist = _coord_distance(normalized, prev_box)
                if dist < best_dist:
                    best_dist = dist
                    best_id = int(pid)
            # High tolerance for dragging; if counts match exactly, force match to prevent ID churn.
            if best_id is not None and (best_dist <= 200.0 or (raw_boxes and len(raw_boxes) == len(existing_boxes))):
                box_id = best_id
                prev = by_id.get(box_id, {})
            else:
                # Genuinely new box: allocate a fresh id.
                box_id = next_id
                next_id += 1
                prev = {}
        else:
            prev = by_id[box_id]
        used_existing_ids.add(int(box_id))
        incoming_canon = canonicalize_incoming_annotation_label(incoming_label, prev)
        label = incoming_canon.get("label", prev.get("label", "tick"))
        fraction_value = (incoming_canon.get("fraction_value") or "").strip()
        custom_label = (incoming_canon.get("custom_label") or "").strip()
        # Carry forward previously-entered metadata when the incoming payload
        # carries none (the annotator only round-trips the label string).
        if label == "fraction" and not fraction_value:
            fraction_value = (prev.get("fraction_value") or "").strip()
        if label == "customize_label" and not custom_label:
            custom_label = (prev.get("custom_label") or "").strip()
        out.append(
            {
                "id": box_id,
                "xmin": normalized["xmin"],
                "ymin": normalized["ymin"],
                "xmax": normalized["xmax"],
                "ymax": normalized["ymax"],
                "label": label,
                "fraction_value": fraction_value,
                "custom_label": custom_label,
            }
        )
    # Preserve the order of boxes as they were in existing_boxes to prevent UI flickering on hover reorders
    existing_ids = [b.get("id") for b in existing_boxes if "id" in b]
    def _sort_key(b):
        # Unknown (new) ids sort after all previously-known boxes.
        bid = b.get("id")
        try:
            return existing_ids.index(bid)
        except ValueError:
            return len(existing_ids) + 1
    out.sort(key=_sort_key)
    current_state["next_id"] = next_id
    return out
def bbox_tuple_to_internal(raw_boxes: list[tuple[Any, Any, Any, Any, Any]], current_state: dict[str, Any]) -> list[dict[str, Any]]:
    """
    Convert raw annotator boxes (dicts or 4/5-tuples) into internal box dicts,
    then reconcile them against the current state via normalize_annotator_boxes.
    """
    converted: list[dict[str, Any]] = []
    for raw in raw_boxes or []:
        if isinstance(raw, dict):
            # Accept common dict forms from custom components.
            label = raw.get("label") or raw.get("category") or "tick"
            coords = (
                raw.get("xmin", raw.get("left")),
                raw.get("ymin", raw.get("top")),
                raw.get("xmax", raw.get("right")),
                raw.get("ymax", raw.get("bottom")),
            )
            if None in coords:
                continue
        elif isinstance(raw, (tuple, list)) and len(raw) >= 4:
            label = raw[4] if len(raw) >= 5 and raw[4] else "tick"
            coords = (raw[0], raw[1], raw[2], raw[3])
        else:
            continue
        try:
            xmin, ymin, xmax, ymax = (float(to_python_scalar(c)) for c in coords)
        except (TypeError, ValueError):
            continue
        converted.append(
            {
                "xmin": xmin,
                "ymin": ymin,
                "xmax": xmax,
                "ymax": ymax,
                "label": str(label),
            }
        )
    return normalize_annotator_boxes(converted, current_state)
def parse_annotator_payload(
    annot_data: Any,
    fallback_image: str | None,
) -> tuple[str | None, list[Any]]:
    """Normalize BBoxAnnotator payloads from different event paths."""
    raw_img: Any = None
    raw_boxes: Any = []
    if isinstance(annot_data, (tuple, list)) and len(annot_data) >= 2:
        # Sequence form: (image, boxes, ...).
        raw_img, raw_boxes = annot_data[0], annot_data[1]
    elif isinstance(annot_data, dict):
        # Dict form: image under "image"/"path", boxes under "boxes"/"annotations".
        raw_img = annot_data.get("image") or annot_data.get("path")
        raw_boxes = annot_data.get("boxes")
        if raw_boxes is None:
            raw_boxes = annot_data.get("annotations", [])
    image_path = coerce_image_path(raw_img) or coerce_image_path(fallback_image)
    return image_path, ensure_list_boxes(raw_boxes)
def dropdown_update_for_boxes(
    state: dict[str, Any],
    prefer_last: bool = False,
    selected_index: int | None = None,
):
    """
    Build a gr.update for the box-selection dropdown from current state.

    Selects ``selected_index`` when it is a valid position, otherwise the last
    choice when ``prefer_last`` is set, otherwise the first choice.
    NOTE(review): relies on ``to_box_choices``, defined elsewhere in this file.
    """
    choices = to_box_choices(state)
    if not choices:
        # Keep dropdown value as empty string to avoid stale-value preprocess errors.
        return gr.update(choices=[], value="")
    if selected_index is not None and 0 <= selected_index < len(choices):
        value = choices[selected_index]
    else:
        value = choices[-1] if prefer_last else choices[0]
    return gr.update(choices=choices, value=value)
def sync_label_controls_for_box(state: dict[str, Any], selected_index: int | None) -> tuple[Any, Any, Any]:
    """
    Keep 'Assign label' + fraction/custom fields aligned with the selected (or last) box.
    Called after canvas sync so the category used in BBoxAnnotator is reflected in the side panel.
    """
    boxes = state.get("boxes", [])
    if not boxes:
        # Nothing selected: reset to defaults and hide both value fields.
        return (
            gr.update(value="tick"),
            gr.update(value="", visible=False),
            gr.update(value="", visible=False),
        )
    idx = len(boxes) - 1 if selected_index is None else int(selected_index)
    idx = max(0, min(idx, len(boxes) - 1))
    box = boxes[idx]
    raw_label = str(box.get("label", "tick"))
    label = raw_label if raw_label in STANDARD_LABELS else "customize_label"
    is_fraction = label == "fraction"
    is_custom = label == "customize_label"
    return (
        gr.update(value=label),
        gr.update(
            visible=is_fraction,
            value=(box.get("fraction_value", "") or "") if is_fraction else "",
        ),
        gr.update(
            visible=is_custom,
            value=(box.get("custom_label", "") or "") if is_custom else "",
        ),
    )
| def _wire_label(box: dict[str, Any]) -> str: | |
| """Canonical label sent to BBoxAnnotator (controls colour + category buttons).""" | |
| lbl = box.get("label", "tick") | |
| if lbl in STANDARD_LABELS: | |
| return lbl | |
| return lbl | |
def state_to_annotator_value(state: dict[str, Any]) -> tuple[str | None, list[tuple[int, int, int, int, str]]]:
    """Convert UI state into the (image_path, annotations) pair BBoxAnnotator expects."""
    image_path = coerce_image_path(state.get("image")) or ensure_placeholder_image_path()
    annotations: list[tuple[int, int, int, int, str]] = []
    for box in state.get("boxes", []):
        # The annotator wire format wants int pixel coords + a label string.
        annotations.append(
            (
                int(round(box["xmin"])),
                int(round(box["ymin"])),
                int(round(box["xmax"])),
                int(round(box["ymax"])),
                _wire_label(box),
            )
        )
    return image_path, annotations
def serialize_to_schema(state: dict[str, Any]) -> dict[str, Any]:
    """Serialize UI state into the gt_info schema; custom labels become extra top-level keys."""
    schema = empty_schema()
    gt_info = schema["gt_info"]
    gt_info["relationship"] = state.get("relationship", "")
    for box in state.get("boxes", []):
        coords = [
            round(float(box["xmin"]), 2),
            round(float(box["ymin"]), 2),
            round(float(box["xmax"]), 2),
            round(float(box["ymax"]), 2),
        ]
        label = box.get("label", "tick")
        if label == "fraction":
            gt_info["fraction"]["bbox"].append(coords)
            gt_info["fraction"]["values"].append(box.get("fraction_value", ""))
            continue
        if label in ("tick", "zero", "one"):
            gt_info[label]["bbox"].append(coords)
            continue
        # customize_label boxes use their custom name as the schema key; any
        # other label loaded from JSON custom keys is kept as-is.
        key = label
        if label == "customize_label":
            key = (box.get("custom_label", "") or "customize_label").strip()
        gt_info.setdefault(key, {"bbox": []})["bbox"].append(coords)
    return schema
def parse_schema_to_state(schema_like: dict[str, Any], image: Any, current_state: dict[str, Any]) -> dict[str, Any]:
    """Build UI state from a gt_info schema; unknown keys become customize_label boxes."""
    gt_info = (schema_like or {}).get("gt_info", {})
    boxes: list[dict[str, Any]] = []
    next_id = int(current_state.get("next_id", 0))

    def _append(label: str, bbox: list[float], fraction_value: str = "", custom_label: str = "") -> None:
        # Skip anything that is not an exact [xmin, ymin, xmax, ymax] list.
        nonlocal next_id
        if not (isinstance(bbox, list) and len(bbox) == 4):
            return
        boxes.append(
            {
                "id": next_id,
                "xmin": float(bbox[0]),
                "ymin": float(bbox[1]),
                "xmax": float(bbox[2]),
                "ymax": float(bbox[3]),
                "label": label,
                "fraction_value": fraction_value,
                "custom_label": custom_label,
            }
        )
        next_id += 1

    def _bbox_list(entry: Any) -> list[Any]:
        # Defensive read of entry["bbox"]; anything malformed yields [].
        raw = entry.get("bbox", []) if isinstance(entry, dict) else []
        return raw if isinstance(raw, list) else []

    fraction_entry = gt_info.get("fraction", {})
    if not isinstance(fraction_entry, dict):
        fraction_entry = {}
    values = fraction_entry.get("values", [])
    if not isinstance(values, list):
        values = []
    for i, bbox in enumerate(_bbox_list(fraction_entry)):
        _append("fraction", bbox, values[i] if i < len(values) else "")
    for core_label in ("one", "tick", "zero"):
        for bbox in _bbox_list(gt_info.get(core_label, {})):
            _append(core_label, bbox)
    for key, entry in gt_info.items():
        if key in ("fraction", "one", "tick", "zero", "relationship"):
            continue
        if isinstance(entry, dict):
            for bbox in _bbox_list(entry):
                _append("customize_label", bbox, custom_label=key)
    return {
        "image": image,
        "boxes": boxes,
        "relationship": str(gt_info.get("relationship", "")),
        "next_id": next_id,
    }
def state_to_json_text(state: dict[str, Any]) -> str:
    """Render the state's gt_info schema as pretty-printed JSON text."""
    schema = serialize_to_schema(state)
    return json.dumps(schema, indent=4, ensure_ascii=False)
def json_semantically_equal(a: str, b: str) -> bool:
    """True when both strings decode to the same JSON value (key order ignored)."""
    try:
        canon_a = json.dumps(json.loads(a), sort_keys=True, ensure_ascii=False)
        canon_b = json.dumps(json.loads(b), sort_keys=True, ensure_ascii=False)
        return canon_a == canon_b
    except Exception:
        # Non-JSON input: fall back to exact string comparison.
        return a == b
| def _round_floats_in_obj(obj: Any, ndigits: int = 2) -> Any: | |
| if isinstance(obj, bool): | |
| return obj | |
| if isinstance(obj, float): | |
| return round(obj, ndigits) | |
| if isinstance(obj, int): | |
| return obj | |
| if isinstance(obj, list): | |
| return [_round_floats_in_obj(x, ndigits) for x in obj] | |
| if isinstance(obj, dict): | |
| return {k: _round_floats_in_obj(v, ndigits) for k, v in obj.items()} | |
| return obj | |
| def _sort_gt_bbox_lists(obj: Any) -> Any: | |
| """Sort bbox arrays so echo round-trips from the annotator compare equal.""" | |
| obj = copy.deepcopy(obj) | |
| gi = obj.get("gt_info") if isinstance(obj, dict) else None | |
| if not isinstance(gi, dict): | |
| return obj | |
| for k, v in list(gi.items()): | |
| if k == "relationship" or not isinstance(v, dict): | |
| continue | |
| bbs = v.get("bbox") | |
| if not isinstance(bbs, list) or not bbs: | |
| continue | |
| if k == "fraction" and isinstance(v.get("values"), list): | |
| vals = v["values"] | |
| pairs = list(zip(bbs, vals)) | |
| pairs.sort( | |
| key=lambda p: tuple(float(x) for x in p[0]) | |
| if isinstance(p[0], list) and len(p[0]) == 4 | |
| else (0.0, 0.0, 0.0, 0.0) | |
| ) | |
| v["bbox"] = [p[0] for p in pairs] | |
| v["values"] = [p[1] for p in pairs] | |
| else: | |
| def _bbox_key(b: Any) -> tuple[float, float, float, float]: | |
| if isinstance(b, list) and len(b) == 4: | |
| return tuple(float(x) for x in b) | |
| return (0.0, 0.0, 0.0, 0.0) | |
| v["bbox"] = sorted(bbs, key=_bbox_key) | |
| gi[k] = v | |
| return obj | |
def stable_gt_json_fingerprint(json_text: str, ndigits: int = 2) -> str:
    """Canonical string for change-detection: rounded floats, sorted bboxes, sorted keys.

    Non-JSON input is returned unchanged so it still compares by raw text.
    """
    try:
        normalized = _sort_gt_bbox_lists(_round_floats_in_obj(json.loads(json_text), ndigits))
        return json.dumps(normalized, sort_keys=True, ensure_ascii=False)
    except Exception:
        return json_text
def json_outputs_equivalent_for_ui(a: str, b: str) -> bool:
    """True when two JSON texts are close enough that repainting the UI is pointless."""
    if json_semantically_equal(a, b):
        return True
    # Compare at 2 decimals, then 1 decimal: YOLO <-> bbox annotator echoes can
    # differ slightly below 0.1px after rounding.
    return any(
        stable_gt_json_fingerprint(a, nd) == stable_gt_json_fingerprint(b, nd)
        for nd in (2, 1)
    )
def _seed_json_ui_fp(state: dict[str, Any]) -> None:
    """Record the current JSON fingerprint in-place so the Timer skips repainting identical text."""
    try:
        json_text = state_to_json_text(state)
    except Exception:
        # Malformed state: fingerprint the empty schema instead.
        json_text = json.dumps(empty_schema(), indent=4, ensure_ascii=False)
    state["_last_json_ui_fp"] = stable_gt_json_fingerprint(json_text, 2)
| def _structural_fingerprint(state: dict[str, Any]) -> tuple: | |
| """Captures box count, labels, values, and selected index — ignores coordinates.""" | |
| boxes = state.get("boxes", []) | |
| return ( | |
| len(boxes), | |
| tuple((b.get("label", ""), b.get("fraction_value", ""), b.get("custom_label", "")) for b in boxes), | |
| state.get("_selected_box_idx"), | |
| ) | |
def to_box_choices(state: dict[str, Any]) -> list[str]:
    """Dropdown choice strings, one per box: '<index>: id=<id> [<label>]'."""
    choices: list[str] = []
    for idx, box in enumerate(state.get("boxes", [])):
        choices.append(f"{idx}: id={box['id']} [{box.get('label', 'tick')}]")
    return choices
def detect_with_yolo(image_path: str | None, selected_model: str) -> list[dict[str, Any]]:
    """Run the selected YOLO model on an image and return detected boxes.

    Returns a list of dicts with xmin/ymin/xmax/ymax/label for the known
    labels (tick, fraction, zero, one); empty when the image or the model
    weights are missing. Model instances are cached per weights path so the
    weights are loaded from disk only once instead of on every click.
    """
    if not image_path:
        return []
    # Fall back to the accurate model when an unknown selector value arrives.
    model_path = MODEL_PATHS.get(selected_model, MODEL_PATHS["yolo_accurate"])
    if not os.path.exists(model_path):
        return []
    # Cache loaded models on the function object (keyed by weights path).
    cache = getattr(detect_with_yolo, "_model_cache", None)
    if cache is None:
        cache = {}
        detect_with_yolo._model_cache = cache
    model = cache.get(model_path)
    if model is None:
        model = YOLO(model_path)
        cache[model_path] = model
    pred = model(image_path, verbose=False)[0]
    names = model.names
    boxes: list[dict[str, Any]] = []
    for bb in pred.boxes:
        cls_name = names[int(bb.cls)]
        if cls_name not in ("tick", "fraction", "zero", "one"):
            continue
        # xywh is center-based; convert to corner coordinates.
        x, y, w, h = bb.xywh[0].tolist()
        boxes.append(
            {
                "xmin": x - w / 2,
                "ymin": y - h / 2,
                "xmax": x + w / 2,
                "ymax": y + h / 2,
                "label": cls_name,
            }
        )
    return boxes
def initialize_from_upload(image_path: str | None):
    """Reset the whole UI from a freshly uploaded (or cleared) image.

    Returns: annotator value, state, JSON text, relationship text (""), box
    dropdown update, and the three label controls. On any failure a pristine
    empty state (no image) is returned instead. The try/except branches in
    the original duplicated the whole output tail; it is shared here.
    """
    def _fresh_outputs(state):
        # Shared tail: serialize, seed the repaint fingerprint, reset controls.
        jt = state_to_json_text(state)
        _seed_json_ui_fp(state)
        lab, fr, cu = sync_label_controls_for_box(state, None)
        return state_to_annotator_value(state), state, jt, "", dropdown_update_for_boxes(state), lab, fr, cu

    try:
        state = empty_ui_state()
        state["image"] = coerce_image_path(image_path)
        return _fresh_outputs(state)
    except Exception:
        # Bad path or serialization failure: fall back to an image-less state.
        return _fresh_outputs(empty_ui_state())
def run_detection(image_path: str | None, selected_model: str, state: dict[str, Any]):
    """Run YOLO detection and refresh every synchronized widget.

    On detection failure the previous boxes are restored so the user never
    loses work. Returns: annotator value, state, JSON text, box dropdown,
    the three label controls, and the relationship text. The original
    duplicated the whole post-detection tail in its except branch; the try
    is narrowed to the detection itself and the tail is shared.
    """
    state = state or empty_ui_state()
    path = coerce_image_path(image_path)
    state["image"] = path
    old_boxes = list(state.get("boxes", []))
    try:
        state["boxes"] = normalize_annotator_boxes(detect_with_yolo(path, selected_model), state)
    except Exception:
        # Detection failed: roll back to the boxes we had before.
        state["boxes"] = old_boxes
    state["relationship"] = compute_auto_relationship(state)
    json_text = state_to_json_text(state)
    _seed_json_ui_fp(state)
    n = len(state.get("boxes", []))
    # Select the last box after a fresh detection (None when empty).
    sel = (n - 1) if n else None
    dd = dropdown_update_for_boxes(state, selected_index=sel) if n else dropdown_update_for_boxes(state)
    lab, fr, cu = sync_label_controls_for_box(state, sel)
    rel = state.get("relationship", "")
    return state_to_annotator_value(state), state, json_text, dd, lab, fr, cu, rel
| def _detect_changed_box_index(old_boxes: list[dict], new_boxes: list[dict]) -> int | None: | |
| """Find the index of the box that was most likely just interacted with.""" | |
| if not new_boxes: | |
| return None | |
| if len(new_boxes) > len(old_boxes): | |
| return len(new_boxes) - 1 | |
| if len(new_boxes) < len(old_boxes): | |
| return max(0, len(new_boxes) - 1) if new_boxes else None | |
| old_by_id = {b.get("id"): b for b in old_boxes if "id" in b} | |
| best_idx, best_dist = None, 0.0 | |
| for i, nb in enumerate(new_boxes): | |
| ob = old_by_id.get(nb.get("id")) | |
| if ob is None: | |
| return i | |
| dist = ( | |
| abs(float(nb.get("xmin", 0)) - float(ob.get("xmin", 0))) | |
| + abs(float(nb.get("ymin", 0)) - float(ob.get("ymin", 0))) | |
| + abs(float(nb.get("xmax", 0)) - float(ob.get("xmax", 0))) | |
| + abs(float(nb.get("ymax", 0)) - float(ob.get("ymax", 0))) | |
| ) | |
| if dist > best_dist: | |
| best_dist = dist | |
| best_idx = i | |
| # Increase threshold to 5.0 to prevent hover-jitter / rounding from constantly triggering selection change. | |
| # Actual clicks are handled robustly by the JS pointerdown listener anyway. | |
| if best_dist > 5.0: | |
| return best_idx | |
| return None | |
def sync_canvas_to_state(annot_data: Any, state: dict[str, Any]):
    """
    Fires on every annotator change (draw, drag, resize, delete).
    Structural changes (add/remove/label) refresh all UI controls.
    Coordinate-only changes (drag/resize) update state silently — no flicker.

    Returns a 7-tuple: (state, json_text, box dropdown, label dropdown,
    fraction textbox, custom-label textbox, relationship textbox); all but
    the state are gr.update() no-ops when only coordinates moved.
    """
    try:
        state = state or empty_ui_state()
        # Fingerprint before applying the payload so we can tell a structural
        # change (count/labels/selection) from a pure coordinate drag.
        prev_struct = _structural_fingerprint(state)
        image_path, raw_boxes = parse_annotator_payload(annot_data, state.get("image"))
        coerced = coerce_image_path(image_path) or coerce_image_path(state.get("image"))
        work = copy.deepcopy(state)
        old_boxes = list(state.get("boxes", []))
        new_boxes = bbox_tuple_to_internal(raw_boxes, work)
        work["image"] = coerced
        work["boxes"] = new_boxes
        # Heuristic: which box did the user most likely touch?
        changed = _detect_changed_box_index(old_boxes, new_boxes)
        if changed is not None:
            work["_selected_box_idx"] = changed
        new_struct = _structural_fingerprint(work)
        # Gradio needs a new object to reliably detect State changes
        if new_struct == prev_struct:
            # Coordinate-only change: update state, leave every widget alone.
            return (work, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update())
        _seed_json_ui_fp(work)
        n = len(work.get("boxes", []))
        sel = work.get("_selected_box_idx")
        try:
            sel = int(sel) if sel is not None else None
        except (TypeError, ValueError):
            sel = None
        # Clamp a stale selection index to the last box (or None when empty).
        if sel is None or sel < 0 or sel >= n:
            sel = (n - 1) if n else None
        jt = state_to_json_text(work)
        dd = dropdown_update_for_boxes(work, selected_index=sel) if n else dropdown_update_for_boxes(work)
        lab, fr, cu = sync_label_controls_for_box(work, sel)
        return work, jt, dd, lab, fr, cu, gr.update()
    except Exception as exc:
        # Never let a sync error break the UI loop; log and echo state back.
        print(f"[MathNet] sync_canvas_to_state error (suppressed): {exc}")
        return (state or empty_ui_state(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update())
def on_label_change(label_name: str) -> tuple[gr.Textbox, gr.Textbox]:
    """Show the fraction-value box only for 'fraction' and the custom-name box only for 'customize_label'."""
    show_fraction = label_name == "fraction"
    show_custom = label_name == "customize_label"
    return gr.update(visible=show_fraction), gr.update(visible=show_custom)
def on_selected_box_change(selected_box: str | None, state: dict[str, Any]):
    """When user changes the Current-box dropdown, mirror that box's label into Assign label."""
    fallback = (gr.update(), gr.update(visible=False), gr.update(visible=False))
    try:
        state = state or empty_ui_state()
        if selected_box:
            # Choice strings look like "<idx>: id=<id> [label]"; parse the index.
            return sync_label_controls_for_box(state, int(str(selected_box).split(":")[0]))
        count = len(state.get("boxes", []))
        return sync_label_controls_for_box(state, (count - 1) if count else None)
    except Exception as exc:
        print(f"[MathNet] on_selected_box_change error (suppressed): {exc}")
        return fallback
def on_box_clicked(clicked_value: str, state: dict[str, Any]):
    """Triggered by JS when the user clicks a box outline: select it and sync controls."""
    noop = (gr.update(), gr.update(), gr.update(), gr.update(), state or empty_ui_state())
    try:
        state = state or empty_ui_state()
        if not clicked_value:
            return noop
        # Only the text before the first "_" is parsed as the box index.
        idx = int(str(clicked_value).split("_")[0])
        if not (0 <= idx < len(state.get("boxes", []))):
            return noop
        state["_selected_box_idx"] = idx
        dd = dropdown_update_for_boxes(state, selected_index=idx)
        lab, fr, cu = sync_label_controls_for_box(state, idx)
        return dd, lab, fr, cu, state
    except Exception as exc:
        print(f"[MathNet] on_box_clicked error (suppressed): {exc}")
        return noop
def apply_label_to_selected(
    selected_box: str,
    label_name: str,
    fraction_value: str,
    custom_label_name: str,
    state: dict[str, Any],
) -> tuple[Any, Any, str, Any, Any, Any, Any]:
    """Write the chosen label (and its value fields) onto the selected box.

    Returns: annotator value, state, JSON text, the three label controls, and
    the relationship text. With no valid selection the state is echoed back
    unchanged. The original repeated the same 6-line echo block twice; it is
    merged into a single guard here.
    """
    def _echo(json_text: str, sel_idx):
        # Shared output tail for both the no-op and the applied case.
        lab, fr, cu = sync_label_controls_for_box(state, sel_idx)
        rel = state.get("relationship", "")
        return state_to_annotator_value(state), state, json_text, lab, fr, cu, rel

    boxes = state.get("boxes", [])
    # Choice strings look like "<idx>: id=<id> [label]"; -1 marks "no selection".
    idx = int(selected_box.split(":")[0]) if selected_box else -1
    if idx < 0 or idx >= len(boxes):
        n = len(boxes)
        return _echo(state_to_json_text(state), (n - 1) if n else None)
    box = state["boxes"][idx]
    box["label"] = label_name
    # Value fields only apply to their own label kind; clear them otherwise.
    box["fraction_value"] = fraction_value.strip() if label_name == "fraction" else ""
    box["custom_label"] = custom_label_name.strip() if label_name == "customize_label" else ""
    json_text = state_to_json_text(state)
    _seed_json_ui_fp(state)
    return _echo(json_text, idx)
def update_relationship(relationship_text: str, state: dict[str, Any]) -> tuple[dict[str, Any], str]:
    """Store the (trimmed) relationship string and return refreshed JSON text."""
    cleaned = (relationship_text or "").strip()
    state["relationship"] = cleaned
    refreshed = state_to_json_text(state)
    _seed_json_ui_fp(state)
    return state, refreshed
def load_json_text(json_text: str, state: dict[str, Any]):
    """Parse the JSON editor contents and rebuild the whole UI from them.

    On a parse error the editor shows the error message instead of JSON and
    the existing state is kept. Returns: annotator value, state, editor text,
    box dropdown, relationship text, and the three label controls.
    """
    state = state or empty_ui_state()
    try:
        parsed = json.loads(json_text)
    except Exception as exc:
        choices = to_box_choices(state)
        err_msg = f"JSON parse error: {exc}"
        # Seed the fingerprint with the error text so the Timer does not
        # immediately overwrite the message with regenerated JSON.
        state["_last_json_ui_fp"] = stable_gt_json_fingerprint(err_msg, 2)
        n = len(state.get("boxes", []))
        sel = (n - 1) if n else None
        lab, fr, cu = sync_label_controls_for_box(state, sel)
        return (
            state_to_annotator_value(state),
            state,
            err_msg,
            gr.update(choices=choices, value=choices[0] if choices else None),
            state.get("relationship", ""),
            lab,
            fr,
            cu,
        )
    # Valid JSON: rebuild state, re-serialize (normalizes formatting), select last box.
    rebuilt = parse_schema_to_state(parsed, state.get("image"), state)
    jt_ok = state_to_json_text(rebuilt)
    _seed_json_ui_fp(rebuilt)
    n = len(rebuilt.get("boxes", []))
    sel = (n - 1) if n else None
    dd = dropdown_update_for_boxes(rebuilt, selected_index=sel) if n else dropdown_update_for_boxes(rebuilt)
    lab, fr, cu = sync_label_controls_for_box(rebuilt, sel)
    return (
        state_to_annotator_value(rebuilt),
        rebuilt,
        jt_ok,
        dd,
        rebuilt.get("relationship", ""),
        lab,
        fr,
        cu,
    )
def create_json_download(json_text: str):
    """Materialize the JSON editor text as a temp file for the DownloadButton."""
    if not json_text or not str(json_text).strip():
        return gr.update()
    # delete=False: the file must outlive this handler so the browser can fetch it.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False, encoding="utf-8") as tmp:
        tmp.write(json_text)
    return gr.update(value=tmp.name)
def packDict(dict_data: dict) -> str:
    """Flatten a dict into 'key: value' lines, one newline-terminated line per entry."""
    # str.join avoids the quadratic cost of repeated string concatenation.
    return "".join(f"{key}: {value}\n" for key, value in dict_data.items())
def run_diagnosis(
    input_selector: str,
    mode_selector: str,
    analysis_mode: str,
    error_choice: str,
    annot_data: Any,
    state: dict[str, Any],
) -> str:
    """Run the Gemini diagnosis pipeline using either the raw image or the JSON schema.

    input_selector: "img-only" sends the image; anything else sends the schema.
    mode_selector:  "Binary error only" checks a single error type (error_choice);
                    otherwise the multi-error auto-diagnosis prompts are used.
    analysis_mode:  "With analysis" selects the prompts that include reasoning.
    Returns the packed diagnosis text (one "key: value" line per field).
    """
    # Prefer uploaded Gemini KB files; fall back to raw knowledge-base text.
    kb = GEMINI_KB_FILES if GEMINI_KB_FILES else (KNOWLEDGE_BASE_TEXT if KNOWLEDGE_BASE_TEXT else None)
    want_analysis = (analysis_mode == "With analysis")
    # Pick the analysis prompt matching the mode x analysis combination.
    if mode_selector == "Binary error only":
        if want_analysis:
            analysis_prompt = get_gemini_one_error_prompt_with_analysis(error_choice)
        else:
            analysis_prompt = get_gemini_one_error_prompt_without_analysis(error_choice)
    else:
        if want_analysis:
            analysis_prompt = gemini_with_analysis_multiple_error_prompt
        else:
            analysis_prompt = gemini_without_analysis_multiple_error_prompt
    if input_selector == "img-only":
        # Image route: prefer what the annotator currently shows, else stored state.
        img_from_annot, _ = parse_annotator_payload(annot_data, state.get("image"))
        image_path = coerce_image_path(img_from_annot) or coerce_image_path(state.get("image"))
        if not image_path:
            return "No image available for diagnosis."
        result = analyze_student_work_gemini_img(
            image_path=image_path,
            prompt=basic_prompt,
            analysis_prompt=analysis_prompt,
            knowledge_base_files=kb,
        )
        parsed_result = parse_json_response(result, kw = "error_type")
        parsed_result_str = packDict(parsed_result)
        return parsed_result_str or "Diagnosis returned no text."
    # JSON route: serialize the annotated state and diagnose without the image.
    schema = serialize_to_schema(state)
    if not schema:
        return "No JSON data available for diagnosis. Run YOLO detection or annotate first."
    result = analyze_student_work_wout_image_gemini(
        json_data=schema,
        prompt=basic_prompt_wout_image,
        analysis_prompt=analysis_prompt,
        knowledge_base_files=kb,
    )
    parsed_result = parse_json_response(result, kw = "error_type")
    parsed_result_str = packDict(parsed_result)
    return parsed_result_str or "Diagnosis returned no text."
def check_token(token):
    """Gate the main UI behind VERIFY_TOKEN; returns (group visibility update, status message update)."""
    accepted = bool(VERIFY_TOKEN) and token == VERIFY_TOKEN
    if accepted:
        return gr.update(visible=True), gr.update(value="Token accepted. You can now use the application.")
    return gr.update(visible=False), gr.update(value="Invalid token. Please try again.")
with gr.Blocks(title="MathNet Annotation + Diagnosis", css=_BBOX_OVERLAY_CSS, js=_BBOX_OVERLAY_JS) as demo:
    gr.Markdown("## MathNet Interactive Annotation and Diagnosis")
    # ── Token verification gate ──
    with gr.Group() as auth_group:
        gr.Markdown("### Please enter your access token to continue")
        token_input = gr.Textbox(label="Access Token", type="password", placeholder="Enter token...")
        token_btn = gr.Button("Verify")
        token_msg = gr.Textbox(label="Status", interactive=False, value="")
    # ── Main application (hidden until verified) ──
    main_app = gr.Group(visible=False)
    with main_app:
        # Seed state + JSON editor together so their fingerprints match at startup.
        _init_state = empty_ui_state()
        _init_json_text = state_to_json_text(_init_state)
        _init_state["_last_json_ui_fp"] = stable_gt_json_fingerprint(_init_json_text, 2)
        app_state = gr.State(_init_state)
        with gr.Row():
            image_input = gr.Image(type="filepath", label="Upload student work image")
            with gr.Column():
                model_choice = gr.Dropdown(
                    choices=list(MODEL_PATHS.keys()),
                    value="yolo_accurate",
                    label="Detection model",
                )
                detect_btn = gr.Button("Run YOLO Detection")
                download_json_btn = gr.DownloadButton("Download JSON")
        annotator = BBoxAnnotator(
            value=(ensure_placeholder_image_path(), []),
            label="Editable Bounding Boxes",
            categories=STANDARD_LABELS,
            show_download_button=False,
        )
        # Hidden textbox: the JS overlay reports the clicked box here (see on_box_clicked).
        clicked_box_idx = gr.Textbox(elem_id="clicked-box-idx", visible=False)
        with gr.Row():
            with gr.Column(scale=2):
                selected_box_dd = gr.Dropdown(choices=[], label="Current box", allow_custom_value=True)
                label_dd = gr.Dropdown(choices=STANDARD_LABELS, value="tick", label="Assign label")
                # Conditional value fields, toggled by on_label_change.
                fraction_value_tb = gr.Textbox(label="Fraction value (e.g., 3/4)", visible=False)
                custom_label_tb = gr.Textbox(label="Custom label name", visible=False)
                apply_label_btn = gr.Button("Apply Label To Current Box")
                relationship_tb = gr.Textbox(
                    label="Fraction–Tick relationship (auto + editable)",
                    placeholder="Auto: each tick (left→right) pairs with nearest fraction by vertical center; e.g. T0-F1,T1-F0",
                )
                save_relationship_btn = gr.Button("Save Relationship to JSON")
            with gr.Column(scale=3):
                json_editor = gr.Textbox(
                    label="Synchronized JSON",
                    lines=18,
                    value=_init_json_text,
                )
                load_json_btn = gr.Button("Load JSON Into UI")
        with gr.Group():
            gr.Markdown("### Diagnosis Control Panel")
            input_selector = gr.Radio(choices=["img-only", "json data"], value="img-only", label="Input Selector")
            mode_selector = gr.Radio(choices=["Binary error only", "Auto-diagnosis"], value="Auto-diagnosis", label="Mode Selector")
            analysis_selector = gr.Radio(choices=["With analysis", "Without analysis"], value="With analysis", label="Analysis Mode")
            error_dropdown = gr.Dropdown(choices=ERROR_OPTIONS, value=ERROR_OPTIONS[0], label="Error Dropdown", visible=False)
            run_diagnosis_btn = gr.Button("Run Diagnosis")
            diagnosis_output = gr.Textbox(label="Diagnosis Output", lines=10)
    # ── Event wiring ──
    # Upload / clear: reset everything from the new image.
    image_input.change(
        initialize_from_upload,
        inputs=[image_input],
        outputs=[
            annotator,
            app_state,
            json_editor,
            relationship_tb,
            selected_box_dd,
            label_dd,
            fraction_value_tb,
            custom_label_tb,
        ],
        queue=False,
    )
    image_input.clear(
        initialize_from_upload,
        inputs=[image_input],
        outputs=[
            annotator,
            app_state,
            json_editor,
            relationship_tb,
            selected_box_dd,
            label_dd,
            fraction_value_tb,
            custom_label_tb,
        ],
        queue=False,
    )
    detect_btn.click(
        run_detection,
        inputs=[image_input, model_choice, app_state],
        outputs=[
            annotator,
            app_state,
            json_editor,
            selected_box_dd,
            label_dd,
            fraction_value_tb,
            custom_label_tb,
            relationship_tb,
        ],
    )
    # Canvas edits (draw/drag/resize/delete) funnel through one handler.
    _sync_outputs = [
        app_state, json_editor, selected_box_dd,
        label_dd, fraction_value_tb, custom_label_tb, relationship_tb,
    ]
    annotator.change(
        sync_canvas_to_state,
        inputs=[annotator, app_state],
        outputs=_sync_outputs,
        queue=False,
        show_progress="hidden",
    )
    annotator.clear(
        sync_canvas_to_state,
        inputs=[annotator, app_state],
        outputs=_sync_outputs,
        queue=False,
        show_progress="hidden",
    )
    clicked_box_idx.change(
        on_box_clicked,
        inputs=[clicked_box_idx, app_state],
        outputs=[selected_box_dd, label_dd, fraction_value_tb, custom_label_tb, app_state],
        queue=False,
        show_progress="hidden",
    )
    # ── Token verification wiring ──
    token_btn.click(
        check_token,
        inputs=[token_input],
        outputs=[main_app, token_msg],
        queue=False,
    )
    selected_box_dd.change(
        on_selected_box_change,
        inputs=[selected_box_dd, app_state],
        outputs=[label_dd, fraction_value_tb, custom_label_tb],
        queue=False,
    )
    label_dd.change(
        on_label_change,
        inputs=[label_dd],
        outputs=[fraction_value_tb, custom_label_tb],
        queue=False,
    )
    apply_label_btn.click(
        apply_label_to_selected,
        inputs=[selected_box_dd, label_dd, fraction_value_tb, custom_label_tb, app_state],
        outputs=[
            annotator,
            app_state,
            json_editor,
            label_dd,
            fraction_value_tb,
            custom_label_tb,
            relationship_tb,
        ],
        queue=False,
    )
    save_relationship_btn.click(
        update_relationship,
        inputs=[relationship_tb, app_state],
        outputs=[app_state, json_editor],
        queue=False,
    )
    load_json_btn.click(
        load_json_text,
        inputs=[json_editor, app_state],
        outputs=[
            annotator,
            app_state,
            json_editor,
            selected_box_dd,
            relationship_tb,
            label_dd,
            fraction_value_tb,
            custom_label_tb,
        ],
        queue=False,
    )
    # The error dropdown is only meaningful in binary mode.
    mode_selector.change(
        lambda mode: gr.update(visible=(mode == "Binary error only")),
        inputs=[mode_selector],
        outputs=[error_dropdown],
        queue=False,
    )
    run_diagnosis_btn.click(
        run_diagnosis,
        inputs=[input_selector, mode_selector, analysis_selector, error_dropdown, annotator, app_state],
        outputs=[diagnosis_output],
    )
    download_json_btn.click(create_json_download, inputs=[json_editor], outputs=[download_json_btn])
def load_all_files_in_folder(folder_path, extension=None):
    """
    Collect every file path under folder_path, recursively.
    When extension is given (e.g. '.md'), keep only matching files.
    Returns an empty list (after printing a warning) when the folder is missing.
    """
    if not os.path.exists(folder_path):
        print(f"❌ Folder not found: {folder_path}")
        return []
    collected = []
    for root, _dirs, files in os.walk(folder_path):
        for name in files:
            if not extension or name.endswith(extension):
                collected.append(os.path.join(root, name))
    return collected
# Local JSON cache mapping absolute file paths -> uploaded Gemini file metadata
# ({"name", "uri", "time"}); see load_cache/save_cache/setup_knowledge_base_gemini.
CACHE_FILE = "gemini_file_cache.json"
def load_cache():
    """Load the Gemini upload cache from CACHE_FILE; returns {} when missing or corrupt."""
    if os.path.exists(CACHE_FILE):
        try:
            with open(CACHE_FILE, "r") as f:
                return json.load(f)
        # Narrowed from a bare `except:` — only swallow I/O and JSON-decode
        # failures (json.JSONDecodeError subclasses ValueError), not e.g.
        # KeyboardInterrupt.
        except (OSError, ValueError):
            return {}
    return {}
def save_cache(cache):
    """Persist the Gemini upload cache to CACHE_FILE as pretty-printed JSON."""
    with open(CACHE_FILE, "w") as fh:
        json.dump(cache, fh, indent=2)
def setup_knowledge_base_gemini(file_path):
    """
    Uploads a file to Gemini to be used as a knowledge base.
    Returns the uploaded file object.

    A local JSON cache (CACHE_FILE) maps absolute paths to previously uploaded
    Gemini file names; a cached entry is reused while still ACTIVE remotely and
    younger than 47 hours. Returns None when the file is missing, processing
    fails, or the upload errors out.
    """
    if not os.path.exists(file_path):
        print(f"❌ File not found: {file_path}")
        return None
    cache = load_cache()
    abs_path = os.path.abspath(file_path)
    # Check cache for existing file
    if abs_path in cache:
        cached_info = cache[abs_path]
        file_name = cached_info.get("name")
        upload_time = cached_info.get("time", 0)
        # Check if within 47 hours (leaving 1 hour buffer)
        if time.time() - upload_time < 47 * 3600:
            try:
                # Verify if file still exists on Gemini
                try:
                    file = gemini_client.files.get(name=file_name)
                    # file.state may be an enum (with .name) or a plain string.
                    state_name = getattr(file.state, "name", str(file.state))
                    if state_name == "ACTIVE":
                        print(f"✅ Using cached file: {file.name}")
                        return file
                    else:
                        print(f"⚠️ Cached file found but state is {state_name}, re-uploading...")
                except Exception as e:
                    # Specific error handling for permission/not found
                    if "403" in str(e) or "404" in str(e) or "PERMISSION_DENIED" in str(e):
                        print("⚠️ Cached file expired or not accessible (403/404), re-uploading...")
                    else:
                        raise e # Re-raise other unexpected errors
            except Exception as e:
                # Outer guard: any re-raised verification error falls through to re-upload.
                print(f"⚠️ Error verifying cached file: {e}, re-uploading...")
        else:
            print("⚠️ Cached file expired (>47h), re-uploading...")
    print(f"Uploading file to Gemini: {file_path}...")
    try:
        # Upload the file to Gemini
        # Use gemini_client.files.upload which is standard for google-genai SDK
        # The argument name for the file path or object is typically 'file'
        file = gemini_client.files.upload(file=file_path, config={'display_name': os.path.basename(file_path)})
        print(f"✅ File uploaded: {file.name}")
        # Wait for file to be active
        while getattr(file.state, "name", str(file.state)) == "PROCESSING":
            print("Processing file...", end='\r')
            time.sleep(2)
            file = gemini_client.files.get(name=file.name)
        if getattr(file.state, "name", str(file.state)) == "FAILED":
            print("❌ File processing failed.")
            return None
        print(f"File ready: {file.uri}")
        # Update cache
        cache[abs_path] = {
            "name": file.name,
            "uri": file.uri,
            "time": time.time()
        }
        save_cache(cache)
        return file
    except Exception as e:
        print(f"❌ Failed to upload file: {file_path}")
        print(f" ↳ {_extract_gemini_error_details(e)}")
        return None
| def _extract_gemini_error_details(exception_obj): | |
| """ | |
| Build a readable, actionable summary from Gemini API exceptions. | |
| """ | |
| err_text = str(exception_obj) | |
| details = [f"raw_error={err_text}"] | |
| # Try to parse API error payload embedded in exception text. | |
| if "{" in err_text and "}" in err_text: | |
| try: | |
| payload = json.loads(err_text[err_text.find("{"):]) | |
| err = payload.get("error", {}) | |
| code = err.get("code") | |
| status = err.get("status") | |
| message = err.get("message") | |
| if code is not None: | |
| details.append(f"code={code}") | |
| if status: | |
| details.append(f"status={status}") | |
| if message: | |
| details.append(f"message={message}") | |
| for item in err.get("details", []): | |
| metadata = item.get("metadata", {}) | |
| reason = item.get("reason") | |
| if reason: | |
| details.append(f"reason={reason}") | |
| if metadata.get("consumer"): | |
| details.append(f"consumer={metadata.get('consumer')}") | |
| if metadata.get("service"): | |
| details.append(f"service={metadata.get('service')}") | |
| if metadata.get("activationUrl"): | |
| details.append(f"activation_url={metadata.get('activationUrl')}") | |
| except Exception: | |
| # Keep fallback raw error text only. | |
| pass | |
| upper_err = err_text.upper() | |
| if "SERVICE_DISABLED" in upper_err or "GENERATIVELANGUAGE.GOOGLEAPIS.COM" in upper_err: | |
| details.append( | |
| "hint=Generative Language API may be disabled for this Google Cloud project. " | |
| "Enable it in Google Cloud Console and retry after propagation." | |
| ) | |
| if "PERMISSION_DENIED" in upper_err: | |
| details.append( | |
| "hint=API key may not belong to the same project as cached files, or key lacks required permissions." | |
| ) | |
| if "403" in upper_err: | |
| details.append("hint=403 usually means API disabled, wrong project, or insufficient permission.") | |
| if "404" in upper_err or "NOT_FOUND" in upper_err: | |
| details.append("hint=Cached remote file ID no longer exists; re-upload is required.") | |
| # Deduplicate while preserving order. | |
| seen = set() | |
| unique_details = [] | |
| for item in details: | |
| if item in seen: | |
| continue | |
| seen.add(item) | |
| unique_details.append(item) | |
| return " | ".join(unique_details) | |
| def _resolve_definitions_dir() -> str | None: | |
| here = Path(__file__).resolve() | |
| candidates = [ | |
| here.parent / "Definitions", | |
| here.parent.parent / "Definitions", | |
| here.parent.parent.parent / "Definitions", | |
| Path.cwd() / "Definitions", | |
| ] | |
| for c in candidates: | |
| if c.exists() and c.is_dir(): | |
| return str(c) | |
| return None | |
def getKnowledgeBaseFiles(kw_file):
    """Resolve and return the Gemini knowledge-base file objects for a keyword.

    kw_file: currently only "def" is supported (markdown files under the
    Definitions folder). Cached Gemini uploads are reused when still ACTIVE;
    stale or inaccessible entries are re-uploaded. Returns a list of Gemini
    file objects, or None when the folder is missing, the keyword is unknown,
    or any upload fails.
    """
    gemini_cache = load_cache()
    md_files = []
    if kw_file == "def":
        def_dir = _resolve_definitions_dir()
        if def_dir is None:
            print("❌ Definitions folder not found")
            return None
        md_files = load_all_files_in_folder(def_dir, extension=".md")
    else:
        print(f"❌ Invalid keyword file: {kw_file}")
        return None
    knowledge_base = []
    for item in md_files:
        abs_path = os.path.abspath(item)
        file = None
        if abs_path in gemini_cache:
            cached_name = gemini_cache[abs_path].get("name")
            try:
                file = gemini_client.files.get(name=cached_name)
            except Exception as e:
                # Common when switching API key/account: cached file id is not accessible.
                print(f"⚠️ Cached Gemini file is not accessible for {item}")
                print(f" ↳ {_extract_gemini_error_details(e)}")
                file = None
        else:
            print(f"⚠️ File not found in gemini cache, uploading: {item}")
        # Re-upload if cache miss, get failed, or file is not active.
        if file is None:
            print(f"Reuploading file: {item}")
            file = setup_knowledge_base_gemini(item)
            if file is None:
                # All-or-nothing: a single failed upload aborts the whole KB load.
                print(f"❌ Failed to upload file: {item}")
                print("❌ Knowledge base loading aborted: one or more files could not be uploaded to Gemini.")
                return None
            knowledge_base.append(file)
            continue
        # file.state may be an enum (with .name) or a plain string.
        state_name = getattr(file.state, "name", str(file.state))
        if state_name == "ACTIVE":
            knowledge_base.append(file)
        else:
            print(f"⚠️ Cached file state is {state_name}, reuploading: {item}")
            file = setup_knowledge_base_gemini(item)
            if file:
                knowledge_base.append(file)
            else:
                print(f"❌ Failed to upload file: {item}")
                print("❌ Knowledge base loading aborted: one or more files could not be uploaded to Gemini.")
                return None
    return knowledge_base
def parse_json_response(text: str, kw: str = "error_types") -> dict:
    """Extract a JSON payload from an LLM response.

    Handles both the 'with_analysis' (list of dicts) and 'without_analysis'
    (list of strings) formats, fenced or unfenced, with surrounding noise.

    Parameters
    ----------
    text : str
        Raw model output.
    kw : str
        Expected top-level key of the result dict.

    Returns
    -------
    dict
        The best JSON candidate found. Falls back to ``{kw: []}`` for empty
        input and ``{kw: ["Not Able to parse"]}`` when no JSON is found.
    """
    if not text:
        return {kw: []}
    decoder = json.JSONDecoder()
    candidates: list[Any] = []
    # 1. Prefer content fenced in ```json ... ``` (or bare ``` ... ```) blocks.
    code_block_pattern = r"```(?:json)?\s*(.*?)\s*```"
    for match in re.finditer(code_block_pattern, text, re.DOTALL):
        try:
            candidates.append(json.loads(match.group(1).strip()))
        except json.JSONDecodeError:
            pass
    # 2. Scan for any valid JSON object starting at a '{'. raw_decode copes
    #    with nested braces and trailing noise, unlike a regex approach that
    #    trips over LaTeX braces.
    idx = 0
    while idx < len(text):
        next_brace = text.find('{', idx)
        if next_brace == -1:
            break
        try:
            obj, end_idx = decoder.raw_decode(text, next_brace)
            candidates.append(obj)
            idx = end_idx
        except json.JSONDecodeError:
            # This '{' did not start valid JSON; resume after it.
            idx = next_brace + 1
    # 3. Select the best candidate: the LAST dict carrying the expected key
    #    (or the singular 'error_type' variant, which gets re-keyed) wins —
    #    models usually place the final answer last.
    best_candidate = None
    for cand in candidates:
        if isinstance(cand, dict):
            if kw in cand:
                best_candidate = cand
            elif "error_type" in cand:
                best_candidate = {kw: cand["error_type"]}
    if not best_candidate and candidates:
        # No keyed dict found: fall back to the last dict, then the last list.
        dict_candidates = [c for c in candidates if isinstance(c, dict)]
        if dict_candidates:
            best_candidate = dict_candidates[-1]
            if kw not in best_candidate:
                print(f"Warning: Selected JSON candidate missing '{kw}'. Keys: {list(best_candidate.keys())}")
        else:
            list_candidates = [c for c in candidates if isinstance(c, list)]
            if list_candidates:
                best_candidate = {kw: list_candidates[-1]}
    # 4. Last resort: the whole stripped text may itself be JSON (e.g. a bare
    #    top-level list, which the brace scan above cannot find).
    if not best_candidate:
        clean_text = text.strip()
        if clean_text.startswith('{') or clean_text.startswith('['):
            try:
                data = json.loads(clean_text)
            except json.JSONDecodeError:
                # Was a bare `except:` — never swallow SystemExit/KeyboardInterrupt.
                data = None
            if isinstance(data, dict):
                best_candidate = data
                if "error_type" in data and kw not in data:
                    best_candidate = {kw: data["error_type"]}
            elif isinstance(data, list):
                best_candidate = {kw: data}
    if best_candidate:
        # May still lack `kw`; callers are expected to check for its presence.
        return best_candidate
    print("❌ Could not find valid JSON in response.")
    return {kw: ["Not Able to parse"]}
if __name__ == "__main__":
    # Warm the Gemini knowledge base before exposing the UI; abort with a
    # non-zero status if the definition files cannot be prepared.
    GEMINI_KB_FILES = getKnowledgeBaseFiles("def")
    if GEMINI_KB_FILES is None:
        print("❌ Failed to get the knowledge base for the definitions")
        # Bug fix: bare exit() is the interactive `site` helper and returned
        # status 0 on this failure path; sys.exit(1) signals the failure.
        sys.exit(1)
    print(f"[MathNet] All {len(GEMINI_KB_FILES)} KB files loaded and ready to analyze")
    demo.launch()