# MathNet_Dia / app.py
# wzzanthony7's picture
# Update app.py
# 6b3c892 verified
import copy
import json
import os
import re
import sys
import tempfile
import time
from pathlib import Path
from typing import Any
import gradio as gr
from google import genai
try:
import numpy as np
except ImportError:
np = None # type: ignore
from PIL import Image
# Avoid Ultralytics warning on HF Spaces / read-only home (before first YOLO use).
os.environ.setdefault("YOLO_CONFIG_DIR", str(Path(tempfile.gettempdir()) / "Ultralytics"))
# Credentials are read from the environment once, at import time.
gemini_api_key = os.environ.get("Gemini_API")
VERIFY_TOKEN = os.environ.get("VERIFY_TOKEN")
# Module-wide Gemini client, created at import time.
# NOTE(review): api_key is None when "Gemini_API" is unset — confirm how genai.Client handles that.
gemini_client = genai.Client(api_key=gemini_api_key)
from ultralytics import YOLO
from gradio_bbox_annotator import BBoxAnnotator
import gradio_bbox_annotator.bbox_annotator as _bbox_annotator_mod
from gradio.data_classes import FileData, GradioModel
from pydantic import field_validator
# Retry policy shared by both Gemini entry points below.
_GEMINI_MAX_RETRIES = 3
# Substrings (matched against the upper-cased error text) that mark an API failure as transient.
_GEMINI_RETRYABLE_ERROR_MARKERS = (
    "429",
    "RESOURCE_EXHAUSTED",
    "UNAVAILABLE",
    "DEADLINE_EXCEEDED",
    "INTERNAL",
    "500",
    "503",
    "TIMEOUT",
)


def _with_knowledge_base(content, knowledge_base_files):
    """Append the optional knowledge-base file object(s) to *content* in place and return it."""
    if knowledge_base_files:
        if isinstance(knowledge_base_files, list):
            content.extend(knowledge_base_files)
        else:
            content.append(knowledge_base_files)
    return content


def _collect_candidate_text(response):
    """Gather non-blank text parts and stringified finish reasons from ``response.candidates``."""
    text_parts = []
    finish_reasons = []
    for cand in getattr(response, "candidates", None) or []:
        finish_reason = getattr(cand, "finish_reason", None)
        if finish_reason is not None:
            finish_reasons.append(str(finish_reason))
        cand_content = getattr(cand, "content", None)
        parts = getattr(cand_content, "parts", None) if cand_content else None
        if not parts:
            continue
        for part in parts:
            part_text = getattr(part, "text", None)
            if part_text and part_text.strip():
                text_parts.append(part_text)
    return text_parts, finish_reasons


def _generate_gemini_text(content, model_name, debug_mode=False):
    """
    Call ``gemini_client.models.generate_content`` with retries and return the text.

    An empty response is retried on every attempt; an exception is retried only
    when its message contains one of the retryable markers. Waits 1s, 2s, ...
    (``2 ** (attempt - 1)``) between attempts. Returns ``None`` when no text
    could be obtained.
    """
    for attempt in range(1, _GEMINI_MAX_RETRIES + 1):
        can_retry = attempt < _GEMINI_MAX_RETRIES
        backoff_s = 2 ** (attempt - 1)
        try:
            response = gemini_client.models.generate_content(
                model=model_name,
                contents=content,
            )
            # Primary path
            text = None
            try:
                text = response.text
            except Exception as e:
                # Some blocked/empty responses may raise when accessing .text
                if debug_mode:
                    print(f"⚠️ response.text access error: {e}")
            if text and text.strip():
                return text
            # Fallback: attempt to reconstruct text from response candidates
            fallback_parts, finish_reasons = _collect_candidate_text(response)
            if fallback_parts:
                merged = "\n".join(fallback_parts).strip()
                if merged:
                    print("⚠️ response.text is empty; using candidate parts fallback.")
                    return merged
            print(
                f"❌ No text output from Gemini (attempt {attempt}/{_GEMINI_MAX_RETRIES}). "
                f"finish_reasons={finish_reasons if finish_reasons else 'N/A'}"
            )
            if can_retry:
                time.sleep(backoff_s)
                continue
            return None
        except Exception as e:
            err = str(e).upper()
            is_retryable = any(marker in err for marker in _GEMINI_RETRYABLE_ERROR_MARKERS)
            print(f"❌ Gemini API call failed (attempt {attempt}/{_GEMINI_MAX_RETRIES}): {e}")
            if is_retryable and can_retry:
                print(f"⏳ Retrying after {backoff_s}s...")
                time.sleep(backoff_s)
                continue
            return None
    return None


def analyze_student_work_gemini_img(image_path, prompt=None, analysis_prompt=None, knowledge_base_files=None, model_name="gemini-3-pro-preview", debug_mode=False):
    """
    Analyzes student work using an image with Gemini.

    Optionally accepts knowledge_base_files (single Gemini file object or list of objects).
    ``prompt`` / ``analysis_prompt`` default to the module-level ``basic_prompt`` /
    ``with_analysis_prompt``.

    Returns the model's text, or ``None`` when the image is missing or unreadable,
    or when Gemini produced no text after the retries.
    """
    if prompt is None:
        prompt = basic_prompt
    if analysis_prompt is None:
        analysis_prompt = with_analysis_prompt
    if not os.path.exists(image_path):
        print(f"❌ Image file does not exist: {image_path}")
        return None
    print(f"📤 Loading image: {image_path}")
    try:
        img = Image.open(image_path)
        # Image.open() is lazy: force decoding here so a corrupt/truncated file is
        # reported as an image error (not as a Gemini failure inside the retry loop)
        # and so Pillow can release the underlying file handle.
        img.load()
    except Exception as e:
        print(f"❌ Failed to load image: {e}")
        return None
    instruction = prompt + "\n" + analysis_prompt
    if debug_mode:
        print(f"instruction: {instruction}")
    print("⏳ Analyzing with Gemini (Image)...")
    # Text instruction, image, and optional KB file(s) are sent together.
    content = _with_knowledge_base([instruction, img], knowledge_base_files)
    return _generate_gemini_text(content, model_name, debug_mode)


def analyze_student_work_wout_image_gemini(json_data, prompt=None, analysis_prompt=None, knowledge_base_files=None, model_name="gemini-3-pro-preview", debug_mode=False):
    """
    Analyzes student work using ground truth JSON data with Gemini.

    Optionally accepts knowledge_base_files (single Gemini file object or list of objects).
    ``prompt`` / ``analysis_prompt`` default to the module-level
    ``basic_prompt_wout_image`` / ``with_analysis_prompt``.

    Returns the model's text, or ``None`` when ``json_data`` is empty or Gemini
    produced no text after the retries.
    """
    if prompt is None:
        prompt = basic_prompt_wout_image
    if analysis_prompt is None:
        analysis_prompt = with_analysis_prompt
    if not json_data:
        print("❌ json_data is empty")
        return None
    json_ret_str = json.dumps(json_data, indent=2, ensure_ascii=False)
    instruction = f"{prompt}\n{analysis_prompt}\nINPUT_JSON:\n{json_ret_str}"
    if debug_mode:
        print(f"instruction: {instruction}")
    print("⏳ Analyzing with Gemini (JSON)...")
    content = _with_knowledge_base([instruction], knowledge_base_files)
    return _generate_gemini_text(content, model_name, debug_mode)
def _patch_bbox_annotator_coerce_float_coords() -> None:
    """
    Swap BBoxAnnotator's payload models for float-tolerant versions.

    Gradio 5 + pydantic v2: the annotator's data model expects int bbox coords,
    but the frontend sends floats, which fails validation before our callbacks
    run. The replacement models round each coordinate to int while validating.
    Any failure while patching is logged and otherwise ignored.
    """
    try:

        class Annotation(GradioModel):
            left: int
            top: int
            right: int
            bottom: int
            label: str | None

            @field_validator("left", "top", "right", "bottom", mode="before")
            @classmethod
            def _round_to_int(cls, v: Any) -> int:
                # A missing coordinate becomes 0; anything else is rounded to the nearest int.
                return 0 if v is None else int(round(float(v)))

            @property
            def width(self) -> int:
                return self.right - self.left

            @property
            def height(self) -> int:
                return self.bottom - self.top

        class AnnotatedImage(GradioModel):
            image: FileData
            annotations: list[Annotation]

        # Install the replacements on the component module and on the component class.
        _bbox_annotator_mod.Annotation = Annotation
        _bbox_annotator_mod.AnnotatedImage = AnnotatedImage
        BBoxAnnotator.data_model = AnnotatedImage
    except Exception as patch_error:
        print(f"[MathNet] BBoxAnnotator float→int patch skipped: {patch_error}")
# Apply the patch at import time; the function logs and carries on if patching fails.
_patch_bbox_annotator_coerce_float_coords()
# Box categories recognised by this module; "customize_label" marks a box that carries a free-form name.
STANDARD_LABELS = ["tick", "fraction", "zero", "one", "customize_label"]
# Error categories as display strings.
# NOTE(review): not referenced in the visible part of the file — presumably UI choices; confirm.
ERROR_OPTIONS = [
"Error 1: Unequal segmentation",
"Error 2: Wrong number of segments (denominator)",
"Error 3: Wrong number of segments chosen (numerator)",
"Error 4: Incorrect tick labels (confused with whole numbers)",
"Error 5: Incorrect tick labels (careless)",
"Error 6: Incorrect 0 or 1s",
"Error 7: Unit of 1 and subunits",
]
# Model selector key -> YOLO weights file (relative paths).
MODEL_PATHS = {
"yolo_accurate": "./vt_dataset_yolov12_v7_weights.pt",
"yolo_extra_large": "./VT_dataset_2_Yolov12_Extra_large.pt",
}
# ---------------------------------------------------------------------------
# Import prompts from variables.py (tempScript/) – fall back to minimal stubs
# ---------------------------------------------------------------------------
def _resolve_temp_script_dir() -> Path | None:
here = Path(__file__).resolve()
candidates = [
here.parent / "tempScript",
here.parent.parent / "tempScript",
here.parent.parent.parent / "tempScript",
Path.cwd() / "tempScript",
]
for candidate in candidates:
if candidate.exists() and candidate.is_dir():
return candidate
return None
TEMP_SCRIPT_DIR = _resolve_temp_script_dir()
# Make variables.py importable when the tempScript folder was found.
if TEMP_SCRIPT_DIR is not None and str(TEMP_SCRIPT_DIR) not in sys.path:
    sys.path.append(str(TEMP_SCRIPT_DIR))

try:
    from variables import (
        basic_prompt,
        basic_prompt_wout_image,
        advanced_prompt_wout_image,
        with_analysis_prompt,
        without_analysis_prompt,
        gemini_with_analysis_multiple_error_prompt,
        gemini_without_analysis_multiple_error_prompt,
        get_gemini_one_error_prompt_with_analysis,
        get_gemini_one_error_prompt_without_analysis,
    )
except Exception as _imp_err:
    # Any import-time failure (missing file, error inside variables.py) falls back
    # to minimal stand-in prompts so the rest of the module can still load.
    print(f"[MathNet] Could not import prompts from variables.py: {_imp_err}")

    basic_prompt = "Please analyze the error type of the student's work based on the definition in the knowledge base."
    basic_prompt_wout_image = "You are analyzing a student's number-line fraction homework."
    advanced_prompt_wout_image = basic_prompt_wout_image

    with_analysis_prompt = "Output ONLY a valid JSON object with error_types array."
    without_analysis_prompt = with_analysis_prompt
    gemini_with_analysis_multiple_error_prompt = with_analysis_prompt
    gemini_without_analysis_multiple_error_prompt = without_analysis_prompt

    def get_gemini_one_error_prompt_with_analysis(t: str) -> str:
        """Stand-in single-error prompt that asks for an analysis plus the error_type array."""
        return f"Check ONLY for {t}. Output JSON with analysis and error_type array."

    def get_gemini_one_error_prompt_without_analysis(t: str) -> str:
        """Stand-in single-error prompt that asks for the error_type array only."""
        return f"Check ONLY for {t}. Output JSON with error_type array."
# ---------------------------------------------------------------------------
# Load knowledge base markdown files from Definitions/ folder
# ---------------------------------------------------------------------------
def _load_knowledge_base() -> str:
    """
    Build the knowledge-base text from the ``Definitions`` folder.

    The folder is searched next to this file, one and two levels above it, and
    in the current working directory; the first existing directory wins. Every
    ``*.md`` file in it is read as UTF-8 (sorted by path), stripped, and joined
    with a ``---`` separator line. Returns ``""`` when the folder or the files
    are missing.
    """
    module_file = Path(__file__).resolve()
    search_roots = (
        module_file.parent,
        module_file.parent.parent,
        module_file.parent.parent.parent,
        Path.cwd(),
    )
    kb_dir: Path | None = next(
        (
            folder
            for folder in (root / "Definitions" for root in search_roots)
            if folder.exists() and folder.is_dir()
        ),
        None,
    )
    if kb_dir is None:
        print("[MathNet] Definitions folder not found – knowledge base will be empty.")
        return ""

    md_files = sorted(kb_dir.glob("*.md"))
    if not md_files:
        print(f"[MathNet] No .md files in {kb_dir}")
        return ""

    sections: list[str] = [md.read_text(encoding="utf-8").strip() for md in md_files]
    kb_text = "\n\n---\n\n".join(sections)
    print(f"[MathNet] Loaded knowledge base: {len(md_files)} files, {len(kb_text)} chars")
    return kb_text
# Knowledge-base text, loaded once at import time ("" when nothing was found).
KNOWLEDGE_BASE_TEXT = _load_knowledge_base()
# NOTE(review): never assigned in the visible part of the file; presumably holds Gemini file objects for the knowledge base — confirm.
GEMINI_KB_FILES: list | None = None
def empty_schema() -> dict[str, Any]:
    """Return a fresh, empty ground-truth schema (new containers on every call)."""
    gt_info: dict[str, Any] = {"fraction": {"bbox": [], "values": []}}
    # Key order is part of the emitted JSON layout: fraction, one, tick, zero, relationship.
    for plain_label in ("one", "tick", "zero"):
        gt_info[plain_label] = {"bbox": []}
    gt_info["relationship"] = ""
    return {"gt_info": gt_info}
def empty_ui_state() -> dict[str, Any]:
    """Return a fresh UI state: no image, no boxes, empty relationship, ids starting at 0."""
    state: dict[str, Any] = dict(image=None, boxes=[], relationship="", next_id=0)
    # Bookkeeping slots: last JSON fingerprint written to the UI and the selected box index.
    state["_last_json_ui_fp"] = None
    state["_selected_box_idx"] = None
    return state
def ensure_placeholder_image_path() -> str:
    """
    Return the path of a tiny placeholder PNG in the system temp directory.

    The 8x8 light-grey image is written on first use and reused afterwards; it
    serves as the BBoxAnnotator's initial value when no real image is set.
    """
    target = Path(tempfile.gettempdir()) / "mathnet_placeholder.png"
    if target.exists():
        return str(target)
    Image.new("RGB", (8, 8), color=(245, 245, 245)).save(target)
    return str(target)
def coerce_image_path(value: Any) -> str | None:
"""
Gradio 5 / Spaces may pass filepath as str, Path, or FileData-like dict.
BBoxAnnotator postprocess expects a real path string.
"""
if value is None:
return None
if isinstance(value, Path):
p = str(value)
return p if p.strip() else None
if isinstance(value, str):
return value if value.strip() else None
if isinstance(value, dict):
for key in ("path", "name", "url"):
v = value.get(key)
if isinstance(v, str) and v.strip():
return v
return None
if hasattr(value, "path") and isinstance(getattr(value, "path", None), str):
p = getattr(value, "path")
return p if p.strip() else None
return None
def to_python_scalar(x: Any) -> Any:
    """
    Convert numpy scalars/arrays to built-in Python values.

    Keeps gr.State / JSON free of numpy types. Anything else — and everything
    when numpy is not installed — is returned unchanged.
    """
    if np is None:
        return x
    if isinstance(x, np.generic):
        return x.item()
    if isinstance(x, np.ndarray):
        return x.tolist()
    return x
def ensure_list_boxes(raw: Any) -> list[Any]:
    """
    Coerce an annotation payload into a list.

    Lists pass through untouched, tuples are copied into a list, and strings
    are parsed as JSON (some component versions stringify annotations) and
    kept only when they decode to a list. Everything else yields ``[]``.
    """
    if isinstance(raw, list):
        return raw
    if isinstance(raw, tuple):
        return list(raw)
    if isinstance(raw, str):
        try:
            decoded = json.loads(raw)
        except Exception:
            return []
        return decoded if isinstance(decoded, list) else []
    return []
def _indices_left_to_right(boxes: list[dict[str, Any]], label_key: str) -> dict[int, int]:
    """
    Rank the boxes carrying ``label_key`` from left to right.

    Returns ``{index in boxes: 0-based rank}`` ordered by ``xmin``; boxes with
    equal ``xmin`` keep their list order.
    """
    matching = [idx for idx, box in enumerate(boxes) if box.get("label") == label_key]
    matching.sort(key=lambda idx: float(boxes[idx]["xmin"]))
    return {idx: rank for rank, idx in enumerate(matching)}
def _center_y(box: dict[str, Any]) -> float:
    """Return the vertical midpoint of ``box``; a missing ``ymin``/``ymax`` counts as 0."""
    top, bottom = (float(box.get(edge, 0.0)) for edge in ("ymin", "ymax"))
    return (top + bottom) / 2.0
def annotation_display_text(
    box: dict[str, Any],
    box_index: int,
    tick_map: dict[int, int],
    frac_map: dict[int, int],
) -> str:
    """
    Build the display-only caption for one box in the annotator.

    Ticks read ``T<rank>``, fractions ``F<rank>-<value>``, zero/one their own
    name, and custom boxes their custom name (or ``customize_label`` when that
    is blank). ``box['label']`` in state stays the canonical type; the ranks
    come from ``tick_map`` / ``frac_map`` (0 when the index is absent).
    """
    kind = box.get("label", "tick")
    if kind == "tick":
        rank = tick_map.get(box_index, 0)
        return f"T{rank}"
    if kind == "fraction":
        rank = frac_map.get(box_index, 0)
        value = (box.get("fraction_value") or "").strip()
        return f"F{rank}-{value}"
    if kind == "zero":
        return "zero"
    if kind == "one":
        return "one"
    if kind == "customize_label":
        custom_name = (box.get("custom_label") or "").strip()
        return custom_name or "customize_label"
    return str(kind)
def canonicalize_incoming_annotation_label(raw: Any, prev: dict[str, Any]) -> dict[str, Any]:
    """
    Translate a label string coming from the annotator into internal fields.

    ``raw`` may be a category name (tick, fraction, …) or a display string
    (T0, F1-3/4, …). ``prev`` is the box's previous record and supplies values
    the incoming label does not carry. The result always has the keys
    ``label``, ``fraction_value`` and ``custom_label``; an unrecognised string
    becomes a custom label.
    """
    previous = prev or {}
    text = ("" if raw is None else raw).strip()

    def _fields(label: str, fraction_value: str = "", custom_label: str = "") -> dict[str, Any]:
        return {"label": label, "fraction_value": fraction_value, "custom_label": custom_label}

    # No incoming label: keep whatever the box had before.
    if not text:
        return _fields(
            str(previous.get("label", "tick")),
            str(previous.get("fraction_value", "") or ""),
            str(previous.get("custom_label", "") or ""),
        )

    # Plain category name: carry over only the extra field that belongs to that category.
    if text in STANDARD_LABELS:
        kept_fraction = (previous.get("fraction_value") or "").strip() if text == "fraction" else ""
        kept_custom = (previous.get("custom_label") or "").strip() if text == "customize_label" else ""
        return _fields(text, kept_fraction, kept_custom)

    # Display strings: "T<n>" is a tick, "F<n>-<value>" a fraction with its value.
    if re.match(r"^T(\d+)$", text):
        return _fields("tick")
    fraction_match = re.match(r"^F(\d+)-(.*)$", text, re.DOTALL)
    if fraction_match:
        return _fields("fraction", fraction_match.group(2).strip())
    if text == "zero":
        return _fields("zero")
    if text == "one":
        return _fields("one")
    return _fields("customize_label", custom_label=text)
# CSS for the caption drawn above each annotator box: a ::before pseudo-element on
# `.box-preview` shows the element's `data-display` attribute, with a background
# colour chosen by its `data-label` category.
_BBOX_OVERLAY_CSS = """\
.box-preview[data-display] { overflow: visible !important; }
.box-preview[data-display]::before {
content: attr(data-display);
position: absolute;
top: -16px;
left: 0;
font-size: 11px;
font-weight: 700;
color: #fff;
padding: 1px 4px;
border-radius: 3px;
white-space: nowrap;
pointer-events: none;
z-index: 10;
line-height: 14px;
}
.box-preview[data-label="tick"]::before { background: hsl(180,100%,35%); }
.box-preview[data-label="fraction"]::before { background: hsl(225,100%,35%); }
.box-preview[data-label="zero"]::before { background: hsl(270,100%,35%); }
.box-preview[data-label="one"]::before { background: hsl(315,100%,35%); }
.box-preview[data-label="customize_label"]::before { background: hsl(0,100%,35%); }
"""
# Browser-side script (a JS arrow function kept as a string):
#   * refresh() groups `.box-preview[data-label]` elements by label, orders each group
#     by `style.left`, and sets `data-display` to "T<i>" for ticks, "F<i>" for
#     fractions (no fraction value here) and the label itself otherwise.
#   * refresh() runs every 300 ms and on mutations below `.image-frame`.
#   * A capture-phase `pointerdown` listener (ignored within 300 ms of the previous
#     accepted click) writes "<box index>_<timestamp>" into the textarea under
#     `#clicked-box-idx` and dispatches an `input` event.
_BBOX_OVERLAY_JS = """\
() => {
function refresh() {
const boxes = document.querySelectorAll('.box-preview[data-label]');
if (!boxes.length) return;
const byLabel = {};
boxes.forEach(el => {
const lbl = el.getAttribute('data-label') || '';
if (!byLabel[lbl]) byLabel[lbl] = [];
byLabel[lbl].push(el);
});
for (const lbl of Object.keys(byLabel)) {
const group = byLabel[lbl];
group.sort((a, b) => parseFloat(a.style.left) - parseFloat(b.style.left));
group.forEach((el, idx) => {
let txt;
if (lbl === 'tick') txt = 'T' + idx;
else if (lbl === 'fraction') txt = 'F' + idx;
else txt = lbl;
if (el.getAttribute('data-display') !== txt)
el.setAttribute('data-display', txt);
});
}
}
setInterval(refresh, 300);
const tryObs = setInterval(() => {
const target = document.querySelector('.image-frame');
if (target) {
clearInterval(tryObs);
new MutationObserver(refresh).observe(target,
{ childList: true, subtree: true, attributes: true,
attributeFilter: ['style', 'data-label'] });
refresh();
}
}, 200);
/* Click-to-select: write clicked box index into a hidden Gradio textbox (debounced) */
let _lastClickTime = 0;
document.addEventListener('pointerdown', function(e) {
const now = Date.now();
if (now - _lastClickTime < 300) return;
const box = e.target.closest('.box-preview');
if (!box) return;
_lastClickTime = now;
const allBoxes = Array.from(document.querySelectorAll('.box-preview'));
const idx = allBoxes.indexOf(box);
if (idx < 0) return;
const ta = document.querySelector('#clicked-box-idx textarea');
if (!ta) return;
const setter = Object.getOwnPropertyDescriptor(
window.HTMLTextAreaElement.prototype, 'value'
).set;
setter.call(ta, String(idx) + '_' + now);
ta.dispatchEvent(new Event('input', {bubbles: true}));
}, true);
}
"""
def compute_auto_relationship(state: dict[str, Any]) -> str:
    """
    Pair every tick with a fraction and return the pairs as ``T0-F1,T1-F0,...``.

    Ticks are visited left to right (by ``xmin``). Each one is paired with the
    fraction whose vertical centre is closest to the tick's; ties go to the
    fraction with the smaller ``xmin``. Returns ``""`` when there are no ticks
    or no fractions.
    """
    boxes = state.get("boxes", [])
    tick_rank = _indices_left_to_right(boxes, "tick")
    fraction_rank = _indices_left_to_right(boxes, "fraction")
    if not (tick_rank and fraction_rank):
        return ""

    # (box index, vertical centre, left-to-right rank) for every fraction box.
    fractions: list[tuple[int, float, int]] = [
        (idx, _center_y(boxes[idx]), rank) for idx, rank in fraction_rank.items()
    ]

    pairs: list[str] = []
    for tick_idx in sorted(tick_rank.keys(), key=lambda i: float(boxes[i]["xmin"])):
        tick_cy = _center_y(boxes[tick_idx])
        nearest = min(
            fractions,
            key=lambda entry: (abs(entry[1] - tick_cy), float(boxes[entry[0]]["xmin"])),
        )
        pairs.append(f"T{tick_rank[tick_idx]}-F{nearest[2]}")
    return ",".join(pairs)
def clamp_box(box: dict[str, Any]) -> dict[str, Any]:
    """
    Return the box's four coordinates as floats with min <= max on both axes.

    Missing coordinates default to 0.0; an inverted pair is swapped. Only the
    coordinate keys are returned.
    """

    def _ordered(low: float, high: float) -> tuple[float, float]:
        return (high, low) if high < low else (low, high)

    x_a, y_a, x_b, y_b = (float(box.get(key, 0.0)) for key in ("xmin", "ymin", "xmax", "ymax"))
    xmin, xmax = _ordered(x_a, x_b)
    ymin, ymax = _ordered(y_a, y_b)
    return {"xmin": xmin, "ymin": ymin, "xmax": xmax, "ymax": ymax}
def normalize_annotator_boxes(raw_boxes: list[dict[str, Any]], current_state: dict[str, Any]) -> list[dict[str, Any]]:
    """
    Reconcile boxes coming from the annotator with the boxes already in state.

    Each incoming box is clamped and given a stable id:

    * an incoming ``id`` that exists in state and was not claimed yet is kept;
    * otherwise the nearest unclaimed existing box of the same canonical label
      is reused when it is within the drag tolerance, or unconditionally when
      the incoming and existing box counts are equal;
    * otherwise a fresh id is allocated from ``next_id``.

    Label, fraction value and custom label are taken from the incoming label and
    completed from the matched previous box. The result keeps the order of the
    existing boxes, with new boxes last.

    Side effect: ``current_state["next_id"]`` is updated with the next free id.
    """
    existing_boxes = current_state.get("boxes", [])
    by_id = {b["id"]: b for b in existing_boxes if "id" in b}
    out: list[dict[str, Any]] = []
    next_id = int(current_state.get("next_id", 0))
    used_existing_ids: set[int] = set()
    # Loop-invariant: equal counts mean no box was added or removed, so a match is forced.
    same_box_count = bool(raw_boxes) and len(raw_boxes) == len(existing_boxes)

    def _coord_distance(a: dict[str, Any], b: dict[str, Any]) -> float:
        return (
            abs(a["xmin"] - b["xmin"])
            + abs(a["ymin"] - b["ymin"])
            + abs(a["xmax"] - b["xmax"])
            + abs(a["ymax"] - b["ymax"])
        )

    for box in raw_boxes or []:
        normalized = clamp_box(box)
        box_id = box.get("id")
        incoming_label = box.get("label")
        incoming_canon_label = canonicalize_incoming_annotation_label(incoming_label, {}).get("label", "tick")

        if box_id is None or box_id not in by_id or box_id in used_existing_ids:
            best_id = None
            best_dist = float("inf")
            for prev_box in existing_boxes:
                pid = prev_box.get("id")
                if pid is None or pid in used_existing_ids:
                    continue
                prev_label = prev_box.get("label", "tick")
                # Match on canonical label (T0 / tick / fraction / …).
                if incoming_label is not None and str(incoming_canon_label) != str(prev_label):
                    continue
                dist = _coord_distance(normalized, prev_box)
                if dist < best_dist:
                    best_dist = dist
                    best_id = int(pid)
            # High tolerance for dragging; if counts match exactly, force match to prevent ID churn.
            if best_id is not None and (best_dist <= 200.0 or same_box_count):
                box_id = best_id
                prev = by_id.get(box_id, {})
            else:
                box_id = next_id
                next_id += 1
                prev = {}
        else:
            prev = by_id[box_id]
        used_existing_ids.add(int(box_id))

        incoming_canon = canonicalize_incoming_annotation_label(incoming_label, prev)
        label = incoming_canon.get("label", prev.get("label", "tick"))
        fraction_value = (incoming_canon.get("fraction_value") or "").strip()
        custom_label = (incoming_canon.get("custom_label") or "").strip()
        # The incoming label may not carry these values; fall back to the previous record.
        if label == "fraction" and not fraction_value:
            fraction_value = (prev.get("fraction_value") or "").strip()
        if label == "customize_label" and not custom_label:
            custom_label = (prev.get("custom_label") or "").strip()
        out.append(
            {
                "id": box_id,
                "xmin": normalized["xmin"],
                "ymin": normalized["ymin"],
                "xmax": normalized["xmax"],
                "ymax": normalized["ymax"],
                "label": label,
                "fraction_value": fraction_value,
                "custom_label": custom_label,
            }
        )

    # Preserve the order of boxes as they were in existing_boxes to prevent UI flickering on hover reorders.
    existing_ids = [b.get("id") for b in existing_boxes if "id" in b]
    position_of: dict[Any, int] = {}
    for position, existing_id in enumerate(existing_ids):
        position_of.setdefault(existing_id, position)  # first occurrence wins
    after_existing = len(existing_ids) + 1
    out.sort(key=lambda b: position_of.get(b.get("id"), after_existing))

    current_state["next_id"] = next_id
    return out
def bbox_tuple_to_internal(raw_boxes: list[tuple[Any, Any, Any, Any, Any]], current_state: dict[str, Any]) -> list[dict[str, Any]]:
    """
    Convert raw annotator boxes to internal box dicts and reconcile them with state.

    Each entry may be a dict (``xmin/ymin/xmax/ymax`` or ``left/top/right/bottom``,
    label under ``label`` or ``category``) or a sequence
    ``(xmin, ymin, xmax, ymax[, label])``. Entries with missing or non-numeric
    coordinates, or of any other shape, are skipped; a missing label defaults
    to ``"tick"``. The converted boxes go through ``normalize_annotator_boxes``.
    """

    def _as_box(xmin: Any, ymin: Any, xmax: Any, ymax: Any, label: Any) -> dict[str, Any] | None:
        try:
            return {
                "xmin": float(to_python_scalar(xmin)),
                "ymin": float(to_python_scalar(ymin)),
                "xmax": float(to_python_scalar(xmax)),
                "ymax": float(to_python_scalar(ymax)),
                "label": str(label),
            }
        except (TypeError, ValueError):
            return None

    converted: list[dict[str, Any]] = []
    for raw in raw_boxes or []:
        if isinstance(raw, dict):
            # Accept common dict forms from custom components.
            label = raw.get("label") or raw.get("category") or "tick"
            corners = (
                raw.get("xmin", raw.get("left")),
                raw.get("ymin", raw.get("top")),
                raw.get("xmax", raw.get("right")),
                raw.get("ymax", raw.get("bottom")),
            )
            if None in corners:
                continue
            entry = _as_box(*corners, label)
        elif isinstance(raw, (tuple, list)) and len(raw) >= 4:
            label = raw[4] if len(raw) >= 5 and raw[4] else "tick"
            entry = _as_box(raw[0], raw[1], raw[2], raw[3], label)
        else:
            continue
        if entry is not None:
            converted.append(entry)
    return normalize_annotator_boxes(converted, current_state)
def parse_annotator_payload(
    annot_data: Any,
    fallback_image: str | None,
) -> tuple[str | None, list[Any]]:
    """
    Normalize BBoxAnnotator payloads from different event paths.

    Accepts an ``(image, boxes)`` tuple or list, or a dict with the image under
    ``image``/``path`` and the boxes under ``boxes``/``annotations``. Returns
    ``(image path, boxes list)``; the image falls back to ``fallback_image``
    and unrecognised payloads yield no boxes.
    """
    raw_img: Any = None
    raw_boxes: Any = []
    if isinstance(annot_data, (tuple, list)):
        # Sequences shorter than two items carry nothing usable.
        if len(annot_data) >= 2:
            raw_img, raw_boxes = annot_data[0], annot_data[1]
    elif isinstance(annot_data, dict):
        raw_img = annot_data.get("image") or annot_data.get("path")
        raw_boxes = annot_data.get("boxes")
        if raw_boxes is None:
            raw_boxes = annot_data.get("annotations", [])

    image_path = coerce_image_path(raw_img) or coerce_image_path(fallback_image)
    return image_path, ensure_list_boxes(raw_boxes)
def dropdown_update_for_boxes(
    state: dict[str, Any],
    prefer_last: bool = False,
    selected_index: int | None = None,
):
    """
    Build the ``gr.update`` for the box dropdown from the current state.

    The selected entry is ``selected_index`` when it is in range, otherwise the
    last entry if ``prefer_last`` is set, otherwise the first. With no boxes the
    dropdown is emptied.
    """
    choices = to_box_choices(state)
    if not choices:
        # Keep dropdown value as empty string to avoid stale-value preprocess errors.
        return gr.update(choices=[], value="")

    if selected_index is not None and 0 <= selected_index < len(choices):
        chosen = choices[selected_index]
    elif prefer_last:
        chosen = choices[-1]
    else:
        chosen = choices[0]
    return gr.update(choices=choices, value=chosen)
def sync_label_controls_for_box(state: dict[str, Any], selected_index: int | None) -> tuple[Any, Any, Any]:
    """
    Build updates for the 'Assign label' control and the fraction/custom text fields.

    Reflects the selected box (the last box when ``selected_index`` is None; the
    index is clamped into range). Returns ``(label update, fraction field
    update, custom field update)``; each text field is visible only for its own
    category. With no boxes the label resets to ``tick`` and both fields hide.
    """
    boxes = state.get("boxes", [])
    if not boxes:
        return (
            gr.update(value="tick"),
            gr.update(value="", visible=False),
            gr.update(value="", visible=False),
        )

    last = len(boxes) - 1
    position = last if selected_index is None else selected_index
    position = max(0, min(int(position), last))
    box = boxes[position]

    stored = str(box.get("label", "tick"))
    # Any label outside the standard set is presented as a custom label.
    category = stored if stored in STANDARD_LABELS else "customize_label"
    is_fraction = category == "fraction"
    is_custom = category == "customize_label"
    fraction_text = (box.get("fraction_value", "") or "") if is_fraction else ""
    custom_text = (box.get("custom_label", "") or "") if is_custom else ""
    return (
        gr.update(value=category),
        gr.update(visible=is_fraction, value=fraction_text),
        gr.update(visible=is_custom, value=custom_text),
    )
def _wire_label(box: dict[str, Any]) -> str:
    """
    Return the label sent to BBoxAnnotator (controls colour + category buttons).

    This is the box's stored ``label`` (``"tick"`` when absent). Standard and
    non-standard labels are passed through alike — the former
    ``STANDARD_LABELS`` check returned the same value on both branches.
    """
    return box.get("label", "tick")
def state_to_annotator_value(state: dict[str, Any]) -> tuple[str | None, list[tuple[int, int, int, int, str]]]:
    """
    Convert UI state into the ``(image path, annotations)`` value for BBoxAnnotator.

    Falls back to the placeholder image when the state has no usable image.
    Each annotation is ``(xmin, ymin, xmax, ymax, label)`` with coordinates
    rounded to int.
    """
    image_path = coerce_image_path(state.get("image")) or ensure_placeholder_image_path()
    annotations = []
    for box in state.get("boxes", []):
        left, top, right, bottom = (int(round(box[key])) for key in ("xmin", "ymin", "xmax", "ymax"))
        annotations.append((left, top, right, bottom, _wire_label(box)))
    return image_path, annotations
def serialize_to_schema(state: dict[str, Any]) -> dict[str, Any]:
    """
    Serialize UI state into the ground-truth schema (``{"gt_info": {...}}``).

    Coordinates are rounded to 2 decimals. Fractions go to ``fraction.bbox`` /
    ``fraction.values``; tick/zero/one to their own ``bbox`` lists; custom boxes
    to a ``bbox`` list keyed by their custom name (``customize_label`` when the
    name is blank); any other label to a ``bbox`` list keyed by that label.
    """
    schema = empty_schema()
    gt_info = schema["gt_info"]
    gt_info["relationship"] = state.get("relationship", "")

    def _append_named(key: str, coords: list[float]) -> None:
        # "relationship" holds a string, not a bbox container, so a box with that
        # name is filed under the generic custom key instead of crashing.
        if key == "relationship":
            key = "customize_label"
        entry = gt_info.setdefault(key, {"bbox": []})
        entry["bbox"].append(coords)
        # A custom box that happens to be named "fraction" lands in the fraction
        # entry; add a blank value so bbox and values stay index-aligned.
        if key == "fraction":
            entry["values"].append("")

    for box in state.get("boxes", []):
        coords = [round(float(box[key]), 2) for key in ("xmin", "ymin", "xmax", "ymax")]
        label = box.get("label", "tick")
        if label == "fraction":
            gt_info["fraction"]["bbox"].append(coords)
            gt_info["fraction"]["values"].append(box.get("fraction_value", ""))
        elif label in ("tick", "zero", "one"):
            gt_info[label]["bbox"].append(coords)
        elif label == "customize_label":
            # Strip before falling back, so a whitespace-only name does not become the key "".
            custom_name = (box.get("custom_label") or "").strip() or "customize_label"
            _append_named(custom_name, coords)
        else:
            # Keep output shape consistent if labels were loaded from JSON custom keys.
            _append_named(label, coords)
    return schema
def parse_schema_to_state(schema_like: dict[str, Any], image: Any, current_state: dict[str, Any]) -> dict[str, Any]:
    """
    Build a UI state from a ground-truth schema (``{"gt_info": {...}}``).

    Boxes are created in this order: fractions (paired with ``values`` by
    position), then one, tick, zero, then every other key as a custom label
    named after the key. Ids continue from ``current_state["next_id"]``.

    Malformed input is tolerated rather than raised: a non-dict schema or
    ``gt_info`` yields no boxes, and a bbox that is not a list of four numbers
    is skipped. Fraction values are stored as strings (``None`` becomes ``""``)
    because the rest of the module treats them as text.

    Returns a dict with ``image``, ``boxes``, ``relationship`` and ``next_id``.
    """
    gt_info = schema_like.get("gt_info") if isinstance(schema_like, dict) else None
    if not isinstance(gt_info, dict):
        gt_info = {}
    boxes: list[dict[str, Any]] = []
    next_id = int((current_state or {}).get("next_id", 0))

    def add_box(label: str, bbox: Any, fraction_value: str = "", custom_label: str = "") -> None:
        nonlocal next_id
        if not isinstance(bbox, list) or len(bbox) != 4:
            return
        try:
            xmin, ymin, xmax, ymax = (float(v) for v in bbox)
        except (TypeError, ValueError):
            # Non-numeric coordinates: skip this box like any other malformed bbox.
            return
        boxes.append(
            {
                "id": next_id,
                "xmin": xmin,
                "ymin": ymin,
                "xmax": xmax,
                "ymax": ymax,
                "label": label,
                "fraction_value": fraction_value,
                "custom_label": custom_label,
            }
        )
        next_id += 1

    def bbox_list(entry: Any) -> list[Any]:
        """Return ``entry["bbox"]`` when it is a list inside a dict, else an empty list."""
        if not isinstance(entry, dict):
            return []
        found = entry.get("bbox", [])
        return found if isinstance(found, list) else []

    fractions = gt_info.get("fraction", {})
    fraction_values = fractions.get("values", []) if isinstance(fractions, dict) else []
    if not isinstance(fraction_values, list):
        fraction_values = []
    for i, bbox in enumerate(bbox_list(fractions)):
        value = fraction_values[i] if i < len(fraction_values) else ""
        add_box("fraction", bbox, "" if value is None else str(value))

    for core_label in ("one", "tick", "zero"):
        for bbox in bbox_list(gt_info.get(core_label, {})):
            add_box(core_label, bbox)

    for key, value in gt_info.items():
        if key in ("fraction", "one", "tick", "zero", "relationship"):
            continue
        for bbox in bbox_list(value):
            add_box("customize_label", bbox, custom_label=str(key))

    relationship = gt_info.get("relationship", "")
    return {
        "image": image,
        "boxes": boxes,
        # A JSON null must not turn into the literal text "None".
        "relationship": "" if relationship is None else str(relationship),
        "next_id": next_id,
    }
def state_to_json_text(state: dict[str, Any]) -> str:
    """Render the state as schema JSON text (4-space indent, non-ASCII characters kept as-is)."""
    schema = serialize_to_schema(state)
    return json.dumps(schema, indent=4, ensure_ascii=False)
def json_semantically_equal(a: str, b: str) -> bool:
    """
    Tell whether two JSON texts decode to the same data, ignoring object key order.

    Falls back to plain string comparison when either text cannot be processed.
    """

    def _canonical(text: str) -> str:
        return json.dumps(json.loads(text), sort_keys=True, ensure_ascii=False)

    try:
        return _canonical(a) == _canonical(b)
    except Exception:
        return a == b
def _round_floats_in_obj(obj: Any, ndigits: int = 2) -> Any:
    """
    Return a copy of a decoded-JSON structure with every float rounded to ``ndigits``.

    Recurses through lists and dicts only; ints, bools, strings and any other
    value are returned unchanged.
    """
    if isinstance(obj, float):
        return round(obj, ndigits)
    if isinstance(obj, list):
        return [_round_floats_in_obj(item, ndigits) for item in obj]
    if isinstance(obj, dict):
        return {key: _round_floats_in_obj(value, ndigits) for key, value in obj.items()}
    return obj
def _sort_gt_bbox_lists(obj: Any) -> Any:
    """
    Return a deep copy of ``obj`` with every ``gt_info[*]["bbox"]`` list sorted.

    Sorting makes echo round-trips from the annotator compare equal regardless
    of box order. For ``fraction`` the ``values`` list is reordered together
    with ``bbox`` when both have the same length; with mismatched lengths only
    ``bbox`` is sorted and ``values`` is left untouched, so no box is dropped.
    Entries that are not four-item lists sort as ``(0, 0, 0, 0)``. Input that
    has no dict ``gt_info`` is returned as an unmodified copy.
    """

    def _bbox_key(bbox: Any) -> tuple[float, ...]:
        if isinstance(bbox, list) and len(bbox) == 4:
            return tuple(float(x) for x in bbox)
        return (0.0, 0.0, 0.0, 0.0)

    obj = copy.deepcopy(obj)
    gt_info = obj.get("gt_info") if isinstance(obj, dict) else None
    if not isinstance(gt_info, dict):
        return obj

    for key, entry in gt_info.items():
        if key == "relationship" or not isinstance(entry, dict):
            continue
        bboxes = entry.get("bbox")
        if not isinstance(bboxes, list) or not bboxes:
            continue
        values = entry.get("values")
        if key == "fraction" and isinstance(values, list) and len(values) == len(bboxes):
            # Sort positions (stable, by bbox only) so each value travels with its bbox.
            order = sorted(range(len(bboxes)), key=lambda i: _bbox_key(bboxes[i]))
            entry["bbox"] = [bboxes[i] for i in order]
            entry["values"] = [values[i] for i in order]
        else:
            entry["bbox"] = sorted(bboxes, key=_bbox_key)
    return obj
def stable_gt_json_fingerprint(json_text: str, ndigits: int = 2) -> str:
    """
    Return a canonical form of ``json_text`` for deciding whether the UI needs an update.

    Floats are rounded to ``ndigits``, bbox lists are sorted and object keys are
    sorted. If anything fails, the input text is returned unchanged.
    """
    try:
        rounded = _round_floats_in_obj(json.loads(json_text), ndigits)
        normalized = _sort_gt_bbox_lists(rounded)
        return json.dumps(normalized, sort_keys=True, ensure_ascii=False)
    except Exception:
        return json_text
def json_outputs_equivalent_for_ui(a: str, b: str) -> bool:
    """
    Tell whether two JSON texts are close enough that the UI need not be repainted.

    True when they are semantically equal, or when their fingerprints match at
    2 decimals, or — as a last resort — at 1 decimal.
    """
    if json_semantically_equal(a, b):
        return True
    # 2 decimals first; YOLO ↔ bbox annotator echo can differ slightly below 0.1px after rounding.
    return any(
        stable_gt_json_fingerprint(a, precision) == stable_gt_json_fingerprint(b, precision)
        for precision in (2, 1)
    )
def _seed_json_ui_fp(state: dict[str, Any]) -> None:
    """
    Store the fingerprint of the state's current JSON text in ``state["_last_json_ui_fp"]``.

    Meant to be called after any direct write to the JSON UI so the Timer won't
    repaint the same text. If the state cannot be serialized, the fingerprint of
    the empty schema is stored instead.
    """
    try:
        json_text = state_to_json_text(state)
    except Exception:
        json_text = json.dumps(empty_schema(), indent=4, ensure_ascii=False)
    fingerprint = stable_gt_json_fingerprint(json_text, 2)
    state["_last_json_ui_fp"] = fingerprint
def _structural_fingerprint(state: dict[str, Any]) -> tuple:
    """
    Summarize the state's structure while ignoring coordinates.

    Returns ``(box count, per-box (label, fraction_value, custom_label) tuples,
    selected box index)``.
    """
    boxes = state.get("boxes", [])
    per_box = tuple(
        (box.get("label", ""), box.get("fraction_value", ""), box.get("custom_label", ""))
        for box in boxes
    )
    return (len(boxes), per_box, state.get("_selected_box_idx"))
def to_box_choices(state: dict[str, Any]) -> list[str]:
    """Return one dropdown entry per box, formatted ``"<position>: id=<id> [<label>]"``."""
    choices: list[str] = []
    for position, box in enumerate(state.get("boxes", [])):
        label = box.get("label", "tick")
        choices.append(f"{position}: id={box['id']} [{label}]")
    return choices
def detect_with_yolo(image_path: str | None, selected_model: str) -> list[dict[str, Any]]:
if not image_path:
return []
model_path = MODEL_PATHS.get(selected_model, MODEL_PATHS["yolo_accurate"])
if not os.path.exists(model_path):
return []
model = YOLO(model_path)
pred = model(image_path, verbose=False)[0]
names = model.names
boxes = []
for bb in pred.boxes:
cls_name = names[int(bb.cls)]
if cls_name not in ("tick", "fraction", "zero", "one"):
continue
x, y, w, h = bb.xywh[0].tolist()
boxes.append(
{
"xmin": x - w / 2,
"ymin": y - h / 2,
"xmax": x + w / 2,
"ymax": y + h / 2,
"label": cls_name,
}
)
return boxes
def initialize_from_upload(image_path: str | None):
    """Reset the UI to a fresh state for a newly uploaded (or cleared) image.

    Returns the 8 values wired to the upload handlers: annotator value, state,
    JSON text, an empty relationship string, the box dropdown update and the
    three label-control updates. If preparing the state with the image fails,
    the same outputs are produced from a fresh state without an image.
    """

    def _outputs_for(fresh: dict[str, Any]):
        json_text = state_to_json_text(fresh)
        _seed_json_ui_fp(fresh)
        label_upd, fraction_upd, custom_upd = sync_label_controls_for_box(fresh, None)
        return (
            state_to_annotator_value(fresh),
            fresh,
            json_text,
            "",
            dropdown_update_for_boxes(fresh),
            label_upd,
            fraction_upd,
            custom_upd,
        )

    try:
        fresh_state = empty_ui_state()
        fresh_state["image"] = coerce_image_path(image_path)
        return _outputs_for(fresh_state)
    except Exception:
        return _outputs_for(empty_ui_state())
def run_detection(image_path: str | None, selected_model: str, state: dict[str, Any]):
    """Replace the state's boxes with YOLO detections and refresh the UI.

    Args:
        image_path: Path of the uploaded image (coerced via ``coerce_image_path``).
        selected_model: Key into ``MODEL_PATHS`` chosen in the model dropdown.
        state: Current UI state; a fresh one is created when falsy.

    Returns:
        An 8-tuple: annotator value, state, JSON text, box dropdown update,
        label / fraction / custom-label control updates, relationship text.
        The last box (if any) becomes the selected one.

    If detection or the follow-up processing raises, the previous boxes are
    restored, the error is logged and the outputs are rebuilt from the
    restored state, so the handler itself does not propagate the failure.
    """

    def _outputs(current: dict[str, Any]):
        json_text = state_to_json_text(current)
        _seed_json_ui_fp(current)
        count = len(current.get("boxes", []))
        selected = (count - 1) if count else None
        if count:
            dropdown = dropdown_update_for_boxes(current, selected_index=selected)
        else:
            dropdown = dropdown_update_for_boxes(current)
        label_upd, fraction_upd, custom_upd = sync_label_controls_for_box(current, selected)
        relationship = current.get("relationship", "")
        return (
            state_to_annotator_value(current),
            current,
            json_text,
            dropdown,
            label_upd,
            fraction_upd,
            custom_upd,
            relationship,
        )

    state = state or empty_ui_state()
    path = coerce_image_path(image_path)
    state["image"] = path
    previous_boxes = list(state.get("boxes", []))
    try:
        state["boxes"] = normalize_annotator_boxes(detect_with_yolo(path, selected_model), state)
        state["relationship"] = compute_auto_relationship(state)
        return _outputs(state)
    except Exception as exc:
        # Same logging convention as the other suppressed-error handlers in
        # this file; without it a failed detection leaves no trace at all.
        print(f"[MathNet] run_detection error (suppressed): {exc}")
        state["boxes"] = previous_boxes
        state["relationship"] = compute_auto_relationship(state)
        return _outputs(state)
def _detect_changed_box_index(old_boxes: list[dict], new_boxes: list[dict]) -> int | None:
"""Find the index of the box that was most likely just interacted with."""
if not new_boxes:
return None
if len(new_boxes) > len(old_boxes):
return len(new_boxes) - 1
if len(new_boxes) < len(old_boxes):
return max(0, len(new_boxes) - 1) if new_boxes else None
old_by_id = {b.get("id"): b for b in old_boxes if "id" in b}
best_idx, best_dist = None, 0.0
for i, nb in enumerate(new_boxes):
ob = old_by_id.get(nb.get("id"))
if ob is None:
return i
dist = (
abs(float(nb.get("xmin", 0)) - float(ob.get("xmin", 0)))
+ abs(float(nb.get("ymin", 0)) - float(ob.get("ymin", 0)))
+ abs(float(nb.get("xmax", 0)) - float(ob.get("xmax", 0)))
+ abs(float(nb.get("ymax", 0)) - float(ob.get("ymax", 0)))
)
if dist > best_dist:
best_dist = dist
best_idx = i
# Increase threshold to 5.0 to prevent hover-jitter / rounding from constantly triggering selection change.
# Actual clicks are handled robustly by the JS pointerdown listener anyway.
if best_dist > 5.0:
return best_idx
return None
def sync_canvas_to_state(annot_data: Any, state: dict[str, Any]):
    """Mirror the annotator canvas into the app state.

    Fires on every annotator change (draw, drag, resize, delete).
    Structural changes (add / remove / label / selection) refresh the JSON
    text, the box dropdown and the label controls. Coordinate-only changes
    just return the updated state with no-op updates for everything else, so
    the UI does not flicker.

    Returns a 7-tuple: state, JSON text, dropdown, label, fraction,
    custom-label and relationship outputs (the relationship slot is always a
    no-op update). Any exception is logged and turned into an all-no-op
    result carrying the incoming state.
    """

    def _untouched(count: int) -> tuple:
        return tuple(gr.update() for _ in range(count))

    try:
        state = state or empty_ui_state()
        struct_before = _structural_fingerprint(state)
        payload_image, payload_boxes = parse_annotator_payload(annot_data, state.get("image"))
        resolved_image = coerce_image_path(payload_image) or coerce_image_path(state.get("image"))

        # Work on a copy: Gradio needs a new object to reliably detect State changes.
        work = copy.deepcopy(state)
        boxes_before = list(state.get("boxes", []))
        boxes_after = bbox_tuple_to_internal(payload_boxes, work)
        work["image"] = resolved_image
        work["boxes"] = boxes_after

        touched = _detect_changed_box_index(boxes_before, boxes_after)
        if touched is not None:
            work["_selected_box_idx"] = touched

        if _structural_fingerprint(work) == struct_before:
            return (work, *_untouched(6))

        _seed_json_ui_fp(work)
        box_count = len(work.get("boxes", []))
        selected = work.get("_selected_box_idx")
        try:
            selected = None if selected is None else int(selected)
        except (TypeError, ValueError):
            selected = None
        if selected is None or not (0 <= selected < box_count):
            # Missing or out-of-range selection falls back to the last box.
            selected = (box_count - 1) if box_count else None

        json_text = state_to_json_text(work)
        if box_count:
            dropdown = dropdown_update_for_boxes(work, selected_index=selected)
        else:
            dropdown = dropdown_update_for_boxes(work)
        label_upd, fraction_upd, custom_upd = sync_label_controls_for_box(work, selected)
        return work, json_text, dropdown, label_upd, fraction_upd, custom_upd, gr.update()
    except Exception as exc:
        print(f"[MathNet] sync_canvas_to_state error (suppressed): {exc}")
        return (state or empty_ui_state(), *_untouched(6))
def on_label_change(label_name: str) -> tuple[gr.Textbox, gr.Textbox]:
    """Show the fraction textbox only for ``fraction`` and the custom-label
    textbox only for ``customize_label``; returns the two visibility updates."""
    show_fraction = label_name == "fraction"
    show_custom = label_name == "customize_label"
    return gr.update(visible=show_fraction), gr.update(visible=show_custom)
def on_selected_box_change(selected_box: str | None, state: dict[str, Any]):
    """Mirror the chosen box's label into the "Assign label" controls.

    ``selected_box`` is the dropdown text; the integer before the first ``:``
    is the box index. With no selection the last box (if any) is used. On any
    error the label dropdown is left alone and both extra textboxes are hidden.
    """
    fallback = (gr.update(), gr.update(visible=False), gr.update(visible=False))
    try:
        state = state or empty_ui_state()
        if selected_box:
            target = int(str(selected_box).split(":")[0])
        else:
            count = len(state.get("boxes", []))
            target = (count - 1) if count else None
        return sync_label_controls_for_box(state, target)
    except Exception as exc:
        print(f"[MathNet] on_selected_box_change error (suppressed): {exc}")
        return fallback
def on_box_clicked(clicked_value: str, state: dict[str, Any]):
    """Select the box the user clicked on in the annotator canvas (JS-triggered).

    The integer before the first ``_`` in ``clicked_value`` is taken as the
    box index. A valid index is stored in ``state["_selected_box_idx"]`` and
    the dropdown and label controls are updated to match. Empty input, an
    out-of-range index or any error yields no-op updates plus the state.

    Returns ``(dropdown, label, fraction, custom_label, state)``.
    """
    unchanged = (gr.update(), gr.update(), gr.update(), gr.update(), state or empty_ui_state())
    try:
        state = state or empty_ui_state()
        if not clicked_value:
            return unchanged
        clicked = int(str(clicked_value).split("_")[0])
        if not (0 <= clicked < len(state.get("boxes", []))):
            return unchanged
        state["_selected_box_idx"] = clicked
        dropdown = dropdown_update_for_boxes(state, selected_index=clicked)
        label_upd, fraction_upd, custom_upd = sync_label_controls_for_box(state, clicked)
        return dropdown, label_upd, fraction_upd, custom_upd, state
    except Exception as exc:
        print(f"[MathNet] on_box_clicked error (suppressed): {exc}")
        return unchanged
def apply_label_to_selected(
    selected_box: str,
    label_name: str,
    fraction_value: str,
    custom_label_name: str,
    state: dict[str, Any],
) -> tuple[Any, Any, str, Any, Any, Any, Any]:
    """Write the chosen label onto the currently selected box.

    Args:
        selected_box: Dropdown text; the integer before the first ``:`` is the
            box index. The dropdown accepts custom values, so text that does
            not start with an integer is treated like "nothing selected".
        label_name: Label to assign.
        fraction_value: Stored (stripped) only when ``label_name`` is
            ``"fraction"``; otherwise the box's fraction value is cleared.
        custom_label_name: Stored (stripped) only when ``label_name`` is
            ``"customize_label"``; otherwise cleared.
        state: Current UI state; a fresh one is created when falsy.

    Returns:
        ``(annotator value, state, JSON text, label update, fraction update,
        custom-label update, relationship text)``. When there is no valid
        selection nothing is modified and the label controls are synced to
        the last box (if any).
    """
    state = state or empty_ui_state()
    boxes = state.get("boxes", [])

    try:
        idx = int(str(selected_box).split(":")[0]) if selected_box else None
    except (TypeError, ValueError):
        # Free-text dropdown input that does not begin with an index.
        idx = None

    if idx is None or not (0 <= idx < len(boxes)):
        fallback = (len(boxes) - 1) if boxes else None
        lab, fr, cu = sync_label_controls_for_box(state, fallback)
        rel = state.get("relationship", "")
        return state_to_annotator_value(state), state, state_to_json_text(state), lab, fr, cu, rel

    box = boxes[idx]
    box["label"] = label_name
    # Textbox values may arrive as None; treat that as empty text.
    box["fraction_value"] = (fraction_value or "").strip() if label_name == "fraction" else ""
    box["custom_label"] = (custom_label_name or "").strip() if label_name == "customize_label" else ""

    json_text = state_to_json_text(state)
    _seed_json_ui_fp(state)
    lab, fr, cu = sync_label_controls_for_box(state, idx)
    rel = state.get("relationship", "")
    return state_to_annotator_value(state), state, json_text, lab, fr, cu, rel
def update_relationship(relationship_text: str, state: dict[str, Any]) -> tuple[dict[str, Any], str]:
    """Store the (stripped) relationship text in ``state`` and return the
    state together with its refreshed JSON text."""
    cleaned = relationship_text.strip() if relationship_text else ""
    state["relationship"] = cleaned
    refreshed_json = state_to_json_text(state)
    _seed_json_ui_fp(state)
    return state, refreshed_json
def load_json_text(json_text: str, state: dict[str, Any]):
    """Rebuild the UI state from the text in the JSON editor.

    Returns ``(annotator value, state, JSON text, dropdown update,
    relationship text, label update, fraction update, custom-label update)``.

    If the text is not valid JSON the current state is kept, the JSON slot
    carries a ``"JSON parse error: ..."`` message (whose fingerprint is
    recorded in the state) and the dropdown is reset to its first entry.
    """

    def _last_index(current: dict[str, Any]):
        count = len(current.get("boxes", []))
        return (count - 1) if count else None

    state = state or empty_ui_state()
    try:
        parsed = json.loads(json_text)
    except Exception as exc:
        choices = to_box_choices(state)
        err_msg = f"JSON parse error: {exc}"
        state["_last_json_ui_fp"] = stable_gt_json_fingerprint(err_msg, 2)
        lab, fr, cu = sync_label_controls_for_box(state, _last_index(state))
        return (
            state_to_annotator_value(state),
            state,
            err_msg,
            gr.update(choices=choices, value=choices[0] if choices else None),
            state.get("relationship", ""),
            lab,
            fr,
            cu,
        )

    rebuilt = parse_schema_to_state(parsed, state.get("image"), state)
    normalized_json = state_to_json_text(rebuilt)
    _seed_json_ui_fp(rebuilt)
    selected = _last_index(rebuilt)
    if selected is not None:
        dropdown = dropdown_update_for_boxes(rebuilt, selected_index=selected)
    else:
        dropdown = dropdown_update_for_boxes(rebuilt)
    lab, fr, cu = sync_label_controls_for_box(rebuilt, selected)
    return (
        state_to_annotator_value(rebuilt),
        rebuilt,
        normalized_json,
        dropdown,
        rebuilt.get("relationship", ""),
        lab,
        fr,
        cu,
    )
def create_json_download(json_text: str):
    """Write ``json_text`` to a new ``.json`` temp file and return an update
    pointing the download button at it; blank text yields a no-op update.

    The temp file is created with ``delete=False`` and is not removed here.
    """
    has_content = bool(json_text) and bool(str(json_text).strip())
    if not has_content:
        return gr.update()
    with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False, encoding="utf-8") as handle:
        handle.write(json_text)
        saved_path = handle.name
    return gr.update(value=saved_path)
def packDict(dict_data: dict) -> str:
    """Render a dict as text, one ``"key: value"`` line per item, each line
    terminated by a newline (an empty dict gives an empty string)."""
    return "".join(f"{key}: {value}\n" for key, value in dict_data.items())
def run_diagnosis(
    input_selector: str,
    mode_selector: str,
    analysis_mode: str,
    error_choice: str,
    annot_data: Any,
    state: dict[str, Any],
) -> str:
    """Send the current image or JSON annotation to Gemini and format the answer.

    * ``mode_selector == "Binary error only"`` builds a single-error prompt
      for ``error_choice``; any other mode uses the multiple-error prompt.
      ``analysis_mode == "With analysis"`` picks the "with analysis" variant.
    * ``input_selector == "img-only"`` analyses the image (taken from the
      annotator payload, else from the state); any other value analyses the
      schema serialised from ``state``.

    Returns the parsed response rendered by ``packDict``, or a short message
    when there is nothing to analyse or nothing came back.
    """
    # Prefer the uploaded Gemini files; fall back to plain-text knowledge base.
    kb = GEMINI_KB_FILES or KNOWLEDGE_BASE_TEXT or None

    with_analysis = analysis_mode == "With analysis"
    if mode_selector == "Binary error only":
        if with_analysis:
            analysis_prompt = get_gemini_one_error_prompt_with_analysis(error_choice)
        else:
            analysis_prompt = get_gemini_one_error_prompt_without_analysis(error_choice)
    elif with_analysis:
        analysis_prompt = gemini_with_analysis_multiple_error_prompt
    else:
        analysis_prompt = gemini_without_analysis_multiple_error_prompt

    if input_selector == "img-only":
        img_from_annot, _ = parse_annotator_payload(annot_data, state.get("image"))
        image_path = coerce_image_path(img_from_annot) or coerce_image_path(state.get("image"))
        if not image_path:
            return "No image available for diagnosis."
        raw_response = analyze_student_work_gemini_img(
            image_path=image_path,
            prompt=basic_prompt,
            analysis_prompt=analysis_prompt,
            knowledge_base_files=kb,
        )
    else:
        schema = serialize_to_schema(state)
        if not schema:
            return "No JSON data available for diagnosis. Run YOLO detection or annotate first."
        raw_response = analyze_student_work_wout_image_gemini(
            json_data=schema,
            prompt=basic_prompt_wout_image,
            analysis_prompt=analysis_prompt,
            knowledge_base_files=kb,
        )

    report = packDict(parse_json_response(raw_response, kw="error_type"))
    return report or "Diagnosis returned no text."
def check_token(token):
    """Validate the access token typed into the gate.

    Returns two updates: visibility of the main application group and the
    status message. The token is accepted only when ``VERIFY_TOKEN`` is set
    (non-empty) and equal to the supplied value.
    """
    import hmac  # local import: hmac is not imported at file level

    expected = VERIFY_TOKEN or ""
    supplied = token if isinstance(token, str) else ""
    # Constant-time comparison so response timing does not leak how much of
    # the secret matched. Bytes are compared because compare_digest only
    # accepts ASCII-only str arguments.
    accepted = bool(expected) and hmac.compare_digest(
        supplied.encode("utf-8"), expected.encode("utf-8")
    )
    # NOTE(review): this only toggles visibility of the UI group; confirm
    # whether the event endpoints need server-side protection as well.
    if accepted:
        return gr.update(visible=True), gr.update(value="Token accepted. You can now use the application.")
    return gr.update(visible=False), gr.update(value="Invalid token. Please try again.")
# Gradio UI definition: a token gate, then the annotation workspace (image
# upload, YOLO detection, editable bounding boxes, synchronized JSON) and the
# Gemini diagnosis panel, followed by the event wiring for all of them.
# NOTE(review): container nesting below follows the most natural reading of
# the layout; confirm against the rendered page.
with gr.Blocks(title="MathNet Annotation + Diagnosis", css=_BBOX_OVERLAY_CSS, js=_BBOX_OVERLAY_JS) as demo:
    gr.Markdown("## MathNet Interactive Annotation and Diagnosis")
    # ── Token verification gate ──
    with gr.Group() as auth_group:
        gr.Markdown("### Please enter your access token to continue")
        token_input = gr.Textbox(label="Access Token", type="password", placeholder="Enter token...")
        token_btn = gr.Button("Verify")
        token_msg = gr.Textbox(label="Status", interactive=False, value="")
    # ── Main application (hidden until verified) ──
    main_app = gr.Group(visible=False)
    with main_app:
        # Initial state carries the fingerprint of its own JSON text, matching
        # the text preloaded into the JSON editor below.
        _init_state = empty_ui_state()
        _init_json_text = state_to_json_text(_init_state)
        _init_state["_last_json_ui_fp"] = stable_gt_json_fingerprint(_init_json_text, 2)
        app_state = gr.State(_init_state)
        with gr.Row():
            image_input = gr.Image(type="filepath", label="Upload student work image")
            with gr.Column():
                model_choice = gr.Dropdown(
                    choices=list(MODEL_PATHS.keys()),
                    value="yolo_accurate",
                    label="Detection model",
                )
                detect_btn = gr.Button("Run YOLO Detection")
                download_json_btn = gr.DownloadButton("Download JSON")
        # Starts on a placeholder image with no boxes.
        annotator = BBoxAnnotator(
            value=(ensure_placeholder_image_path(), []),
            label="Editable Bounding Boxes",
            categories=STANDARD_LABELS,
            show_download_button=False,
        )
        # Hidden textbox whose change event drives on_box_clicked; presumably
        # written by the overlay JS when a box is clicked (verify in _BBOX_OVERLAY_JS).
        clicked_box_idx = gr.Textbox(elem_id="clicked-box-idx", visible=False)
        with gr.Row():
            with gr.Column(scale=2):
                # allow_custom_value=True: handlers must tolerate arbitrary text here.
                selected_box_dd = gr.Dropdown(choices=[], label="Current box", allow_custom_value=True)
                label_dd = gr.Dropdown(choices=STANDARD_LABELS, value="tick", label="Assign label")
                # Both textboxes start hidden; on_label_change toggles them.
                fraction_value_tb = gr.Textbox(label="Fraction value (e.g., 3/4)", visible=False)
                custom_label_tb = gr.Textbox(label="Custom label name", visible=False)
                apply_label_btn = gr.Button("Apply Label To Current Box")
                relationship_tb = gr.Textbox(
                    label="Fraction–Tick relationship (auto + editable)",
                    placeholder="Auto: each tick (left→right) pairs with nearest fraction by vertical center; e.g. T0-F1,T1-F0",
                )
                save_relationship_btn = gr.Button("Save Relationship to JSON")
            with gr.Column(scale=3):
                json_editor = gr.Textbox(
                    label="Synchronized JSON",
                    lines=18,
                    value=_init_json_text,
                )
                load_json_btn = gr.Button("Load JSON Into UI")
        with gr.Group():
            gr.Markdown("### Diagnosis Control Panel")
            input_selector = gr.Radio(choices=["img-only", "json data"], value="img-only", label="Input Selector")
            mode_selector = gr.Radio(choices=["Binary error only", "Auto-diagnosis"], value="Auto-diagnosis", label="Mode Selector")
            analysis_selector = gr.Radio(choices=["With analysis", "Without analysis"], value="With analysis", label="Analysis Mode")
            # Only shown in "Binary error only" mode (see mode_selector.change below).
            error_dropdown = gr.Dropdown(choices=ERROR_OPTIONS, value=ERROR_OPTIONS[0], label="Error Dropdown", visible=False)
            run_diagnosis_btn = gr.Button("Run Diagnosis")
            diagnosis_output = gr.Textbox(label="Diagnosis Output", lines=10)
    # Uploading or clearing the image resets the whole workspace; the output
    # order matches the tuple returned by initialize_from_upload.
    image_input.change(
        initialize_from_upload,
        inputs=[image_input],
        outputs=[
            annotator,
            app_state,
            json_editor,
            relationship_tb,
            selected_box_dd,
            label_dd,
            fraction_value_tb,
            custom_label_tb,
        ],
        queue=False,
    )
    image_input.clear(
        initialize_from_upload,
        inputs=[image_input],
        outputs=[
            annotator,
            app_state,
            json_editor,
            relationship_tb,
            selected_box_dd,
            label_dd,
            fraction_value_tb,
            custom_label_tb,
        ],
        queue=False,
    )
    # YOLO detection replaces the boxes; output order matches run_detection.
    detect_btn.click(
        run_detection,
        inputs=[image_input, model_choice, app_state],
        outputs=[
            annotator,
            app_state,
            json_editor,
            selected_box_dd,
            label_dd,
            fraction_value_tb,
            custom_label_tb,
            relationship_tb,
        ],
    )
    # Shared by both annotator events; order matches sync_canvas_to_state's 7-tuple.
    _sync_outputs = [
        app_state, json_editor, selected_box_dd,
        label_dd, fraction_value_tb, custom_label_tb, relationship_tb,
    ]
    annotator.change(
        sync_canvas_to_state,
        inputs=[annotator, app_state],
        outputs=_sync_outputs,
        queue=False,
        show_progress="hidden",
    )
    annotator.clear(
        sync_canvas_to_state,
        inputs=[annotator, app_state],
        outputs=_sync_outputs,
        queue=False,
        show_progress="hidden",
    )
    # Canvas click → select that box in the dropdown and label controls.
    clicked_box_idx.change(
        on_box_clicked,
        inputs=[clicked_box_idx, app_state],
        outputs=[selected_box_dd, label_dd, fraction_value_tb, custom_label_tb, app_state],
        queue=False,
        show_progress="hidden",
    )
    # ── Token verification wiring ──
    # check_token returns the visibility update for main_app and the status text.
    token_btn.click(
        check_token,
        inputs=[token_input],
        outputs=[main_app, token_msg],
        queue=False,
    )
    # Dropdown selection → mirror that box's label into the label controls.
    selected_box_dd.change(
        on_selected_box_change,
        inputs=[selected_box_dd, app_state],
        outputs=[label_dd, fraction_value_tb, custom_label_tb],
        queue=False,
    )
    # Label choice → show/hide the fraction and custom-label textboxes.
    label_dd.change(
        on_label_change,
        inputs=[label_dd],
        outputs=[fraction_value_tb, custom_label_tb],
        queue=False,
    )
    # Output order matches apply_label_to_selected's 7-tuple.
    apply_label_btn.click(
        apply_label_to_selected,
        inputs=[selected_box_dd, label_dd, fraction_value_tb, custom_label_tb, app_state],
        outputs=[
            annotator,
            app_state,
            json_editor,
            label_dd,
            fraction_value_tb,
            custom_label_tb,
            relationship_tb,
        ],
        queue=False,
    )
    save_relationship_btn.click(
        update_relationship,
        inputs=[relationship_tb, app_state],
        outputs=[app_state, json_editor],
        queue=False,
    )
    # JSON editor → rebuild state and canvas; order matches load_json_text.
    load_json_btn.click(
        load_json_text,
        inputs=[json_editor, app_state],
        outputs=[
            annotator,
            app_state,
            json_editor,
            selected_box_dd,
            relationship_tb,
            label_dd,
            fraction_value_tb,
            custom_label_tb,
        ],
        queue=False,
    )
    # The error dropdown is only relevant when diagnosing a single error type.
    mode_selector.change(
        lambda mode: gr.update(visible=(mode == "Binary error only")),
        inputs=[mode_selector],
        outputs=[error_dropdown],
        queue=False,
    )
    # No queue=False here, unlike the lightweight handlers above.
    run_diagnosis_btn.click(
        run_diagnosis,
        inputs=[input_selector, mode_selector, analysis_selector, error_dropdown, annotator, app_state],
        outputs=[diagnosis_output],
    )
    # The handler's update sets the button's own value to the generated temp file.
    download_json_btn.click(create_json_download, inputs=[json_editor], outputs=[download_json_btn])
def load_all_files_in_folder(folder_path, extension=None):
    """
    Collect the paths of all files under ``folder_path`` (walked recursively).
    When ``extension`` is given (e.g., '.md') only file names ending with it
    are kept. A missing folder is reported and yields an empty list.
    """
    if not os.path.exists(folder_path):
        print(f"❌ Folder not found: {folder_path}")
        return []
    return [
        os.path.join(root, name)
        for root, _subdirs, names in os.walk(folder_path)
        for name in names
        if not extension or name.endswith(extension)
    ]
# JSON file (relative to the working directory) read by load_cache and
# written by save_cache; it remembers which local files were already
# uploaded to Gemini.
CACHE_FILE = "gemini_file_cache.json"
def load_cache():
    """Load the Gemini upload cache from ``CACHE_FILE``.

    Returns:
        The cached mapping, or an empty dict when the file is missing,
        unreadable, not valid JSON, or does not contain a JSON object.
    """
    try:
        with open(CACHE_FILE, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):
        # OSError: missing / unreadable file. ValueError: malformed JSON or
        # undecodable bytes. Either way the cache is simply treated as empty.
        return {}
    # Callers index the cache by path, so anything but a dict is unusable.
    return data if isinstance(data, dict) else {}
def save_cache(cache):
    """Overwrite ``CACHE_FILE`` with ``cache`` serialised as 2-space-indented JSON."""
    with Path(CACHE_FILE).open("w") as handle:
        json.dump(cache, handle, indent=2)
def setup_knowledge_base_gemini(file_path, max_wait_seconds=600.0, poll_interval=2.0):
    """
    Uploads a file to Gemini to be used as a knowledge base, reusing a
    previous upload when the local cache says it is still usable.

    Args:
        file_path: Local path of the file to upload.
        max_wait_seconds: Upper bound on how long to wait for an uploaded
            file to leave the PROCESSING state before giving up.
        poll_interval: Seconds to sleep between state checks.

    Returns:
        The Gemini file object, or None when the local file is missing, the
        upload fails, processing fails, or processing does not finish within
        ``max_wait_seconds``.
    """
    if not os.path.exists(file_path):
        print(f"❌ File not found: {file_path}")
        return None

    def _state_name(remote):
        # The state may be an enum-like object (with .name) or a plain value.
        return getattr(remote.state, "name", str(remote.state))

    cache = load_cache()
    abs_path = os.path.abspath(file_path)

    # Check cache for existing file; entries that are not dicts are ignored.
    cached_info = cache.get(abs_path)
    if isinstance(cached_info, dict):
        file_name = cached_info.get("name")
        upload_time = cached_info.get("time", 0)
        # Check if within 47 hours (leaving 1 hour buffer)
        if time.time() - upload_time < 47 * 3600:
            try:
                # Verify if file still exists on Gemini
                file = gemini_client.files.get(name=file_name)
                state_name = _state_name(file)
                if state_name == "ACTIVE":
                    print(f"✅ Using cached file: {file.name}")
                    return file
                print(f"⚠️ Cached file found but state is {state_name}, re-uploading...")
            except Exception as e:
                err_text = str(e)
                if "403" in err_text or "404" in err_text or "PERMISSION_DENIED" in err_text:
                    print("⚠️ Cached file expired or not accessible (403/404), re-uploading...")
                else:
                    print(f"⚠️ Error verifying cached file: {e}, re-uploading...")
        else:
            print("⚠️ Cached file expired (>47h), re-uploading...")

    print(f"Uploading file to Gemini: {file_path}...")
    try:
        file = gemini_client.files.upload(file=file_path, config={'display_name': os.path.basename(file_path)})
        print(f"✅ File uploaded: {file.name}")

        # Wait for the file to become usable, but never forever: a file stuck
        # in PROCESSING would otherwise hang the caller indefinitely.
        deadline = time.monotonic() + max_wait_seconds
        while _state_name(file) == "PROCESSING":
            if time.monotonic() >= deadline:
                print(f"❌ File still processing after {max_wait_seconds:.0f}s, giving up: {file.name}")
                return None
            print("Processing file...", end='\r')
            time.sleep(poll_interval)
            file = gemini_client.files.get(name=file.name)

        if _state_name(file) == "FAILED":
            print("❌ File processing failed.")
            return None
        print(f"File ready: {file.uri}")
    except Exception as e:
        print(f"❌ Failed to upload file: {file_path}")
        print(f" ↳ {_extract_gemini_error_details(e)}")
        return None

    # Update cache. A cache-write problem must not discard a good upload, so
    # it is reported separately instead of being treated as an upload failure.
    cache[abs_path] = {
        "name": file.name,
        "uri": file.uri,
        "time": time.time()
    }
    try:
        save_cache(cache)
    except (OSError, TypeError, ValueError) as e:
        print(f"⚠️ Could not save Gemini file cache: {e}")
    return file
def _extract_gemini_error_details(exception_obj):
    """
    Build a readable, actionable summary from Gemini API exceptions.

    The summary always starts with ``raw_error=<text>``. If the exception
    text embeds an error payload (``{"error": {...}}``), its code, status,
    message and per-detail reason / consumer / service / activation URL are
    appended, followed by hints derived from well-known markers in the text.
    Parts are de-duplicated (order preserved) and joined with " | ".
    """
    import ast  # local import: ast is not imported at file level

    err_text = str(exception_obj)
    details = [f"raw_error={err_text}"]

    # Try to parse the API error payload embedded in the exception text.
    payload = None
    start = err_text.find("{")
    if start != -1 and "}" in err_text:
        try:
            # raw_decode stops at the end of the object, so text that follows
            # the payload does not break parsing.
            payload, _end = json.JSONDecoder().raw_decode(err_text, start)
        except ValueError:
            # The payload may be rendered as a Python literal (single quotes)
            # rather than JSON; literal_eval only evaluates literals.
            try:
                payload = ast.literal_eval(err_text[start:err_text.rfind("}") + 1])
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                payload = None  # Keep fallback raw error text only.

    err = payload.get("error", {}) if isinstance(payload, dict) else {}
    if isinstance(err, dict):
        code = err.get("code")
        status = err.get("status")
        message = err.get("message")
        if code is not None:
            details.append(f"code={code}")
        if status:
            details.append(f"status={status}")
        if message:
            details.append(f"message={message}")
        raw_items = err.get("details", [])
        for item in raw_items if isinstance(raw_items, list) else []:
            if not isinstance(item, dict):
                continue
            metadata = item.get("metadata", {})
            if not isinstance(metadata, dict):
                metadata = {}
            reason = item.get("reason")
            if reason:
                details.append(f"reason={reason}")
            if metadata.get("consumer"):
                details.append(f"consumer={metadata.get('consumer')}")
            if metadata.get("service"):
                details.append(f"service={metadata.get('service')}")
            if metadata.get("activationUrl"):
                details.append(f"activation_url={metadata.get('activationUrl')}")

    upper_err = err_text.upper()
    if "SERVICE_DISABLED" in upper_err or "GENERATIVELANGUAGE.GOOGLEAPIS.COM" in upper_err:
        details.append(
            "hint=Generative Language API may be disabled for this Google Cloud project. "
            "Enable it in Google Cloud Console and retry after propagation."
        )
    if "PERMISSION_DENIED" in upper_err:
        details.append(
            "hint=API key may not belong to the same project as cached files, or key lacks required permissions."
        )
    if "403" in upper_err:
        details.append("hint=403 usually means API disabled, wrong project, or insufficient permission.")
    if "404" in upper_err or "NOT_FOUND" in upper_err:
        details.append("hint=Cached remote file ID no longer exists; re-upload is required.")

    # Deduplicate while preserving order.
    return " | ".join(dict.fromkeys(details))
def _resolve_definitions_dir() -> str | None:
here = Path(__file__).resolve()
candidates = [
here.parent / "Definitions",
here.parent.parent / "Definitions",
here.parent.parent.parent / "Definitions",
Path.cwd() / "Definitions",
]
for c in candidates:
if c.exists() and c.is_dir():
return str(c)
return None
def getKnowledgeBaseFiles(kw_file):
    """Return Gemini file objects for every knowledge-base markdown file.

    Only ``kw_file == "def"`` is supported: it selects the ``.md`` files under
    the Definitions directory. For each file the cached upload is reused when
    it can be fetched and is ACTIVE; otherwise the file is uploaded again.

    Returns the list of file objects, or ``None`` when the keyword is
    invalid, the Definitions directory is missing, or any file cannot be
    uploaded (loading is aborted at the first such failure).
    """
    gemini_cache = load_cache()

    if kw_file != "def":
        print(f"❌ Invalid keyword file: {kw_file}")
        return None
    def_dir = _resolve_definitions_dir()
    if def_dir is None:
        print("❌ Definitions folder not found")
        return None
    md_files = load_all_files_in_folder(def_dir, extension=".md")

    def _report_failure(path):
        print(f"❌ Failed to upload file: {path}")
        print("❌ Knowledge base loading aborted: one or more files could not be uploaded to Gemini.")

    knowledge_base = []
    for item in md_files:
        abs_path = os.path.abspath(item)
        remote = None
        if abs_path in gemini_cache:
            cached_name = gemini_cache[abs_path].get("name")
            try:
                remote = gemini_client.files.get(name=cached_name)
            except Exception as e:
                # Common when switching API key/account: cached file id is not accessible.
                print(f"⚠️ Cached Gemini file is not accessible for {item}")
                print(f" ↳ {_extract_gemini_error_details(e)}")
                remote = None
        else:
            print(f"⚠️ File not found in gemini cache, uploading: {item}")

        if remote is None:
            # Cache miss or the cached id could not be fetched.
            print(f"Reuploading file: {item}")
            remote = setup_knowledge_base_gemini(item)
            if remote is None:
                _report_failure(item)
                return None
        else:
            state_name = getattr(remote.state, "name", str(remote.state))
            if state_name != "ACTIVE":
                print(f"⚠️ Cached file state is {state_name}, reuploading: {item}")
                remote = setup_knowledge_base_gemini(item)
                if not remote:
                    _report_failure(item)
                    return None
        knowledge_base.append(remote)
    return knowledge_base
def parse_json_response(text: str, kw = "error_types") -> dict:
    """
    Extract the JSON answer from an LLM response.

    Candidates are collected from fenced code blocks and from every position
    in the text where a valid JSON object starts (using a JSON decoder, so
    nested structures and stray braces such as LaTeX are handled).

    Selection order:
      1. the last dict containing ``kw``;
      2. else the last dict containing the singular ``"error_type"``,
         re-keyed as ``{kw: value}``;
      3. else the last dict found (a warning is printed), or the last list
         found wrapped as ``{kw: list}``;
      4. else the whole text parsed as JSON.

    Returns ``{kw: []}`` for empty input and ``{kw: ["Not Able to parse"]}``
    when nothing usable is found.
    """
    if not text:
        return {kw: []}

    decoder = json.JSONDecoder()
    candidates = []

    # 1. Content inside ```json ... ``` or ``` ... ``` fences.
    code_block_pattern = r"```(?:json)?\s*(.*?)\s*```"
    for match in re.finditer(code_block_pattern, text, re.DOTALL):
        try:
            candidates.append(json.loads(match.group(1).strip()))
        except json.JSONDecodeError:
            pass

    # 2. Any valid JSON object starting at a '{' anywhere in the text.
    idx = 0
    while True:
        next_brace = text.find('{', idx)
        if next_brace == -1:
            break
        try:
            obj, idx = decoder.raw_decode(text, next_brace)
        except json.JSONDecodeError:
            # This '{' did not start valid JSON; resume right after it.
            idx = next_brace + 1
        else:
            candidates.append(obj)

    # 3. Select the best candidate. An exact key match always beats the
    #    singular-variant fallback, regardless of where each one appears.
    dict_candidates = [c for c in candidates if isinstance(c, dict)]
    exact_matches = [c for c in dict_candidates if kw in c]
    singular_matches = [c for c in dict_candidates if "error_type" in c]

    best_candidate = None
    if exact_matches:
        best_candidate = exact_matches[-1]
    elif singular_matches:
        best_candidate = {kw: singular_matches[-1]["error_type"]}

    if not best_candidate and candidates:
        if dict_candidates:
            best_candidate = dict_candidates[-1]
            if kw not in best_candidate:
                print(f"Warning: Selected JSON candidate missing '{kw}'. Keys: {list(best_candidate.keys())}")
        else:
            list_candidates = [c for c in candidates if isinstance(c, list)]
            if list_candidates:
                best_candidate = {kw: list_candidates[-1]}

    # 4. Fallback: parse the whole text (covers e.g. a bare top-level list).
    if not best_candidate:
        clean_text = text.strip()
        if clean_text.startswith(('{', '[')):
            try:
                data = json.loads(clean_text)
            except ValueError:
                data = None
            if isinstance(data, dict):
                if "error_type" in data and kw not in data:
                    best_candidate = {kw: data["error_type"]}
                else:
                    best_candidate = data
            elif isinstance(data, list):
                best_candidate = {kw: data}

    if best_candidate:
        # May still lack ``kw`` (see step 3); callers check for the key.
        return best_candidate

    print("❌ Could not find valid JSON in response.")
    return {kw: ["Not Able to parse"]}
if __name__ == "__main__":
    # The knowledge base must be available before the UI is served;
    # run_diagnosis reads GEMINI_KB_FILES as a module-level global.
    GEMINI_KB_FILES = getKnowledgeBaseFiles("def")
    if GEMINI_KB_FILES is None:
        print("❌ Failed to get the knowledge base for the definitions")
        # Exit with a non-zero status so the failure is visible to whatever
        # launched the process; sys.exit is used instead of the site-provided
        # exit(), which is meant for interactive sessions.
        sys.exit(1)
    print(f"[MathNet] All {len(GEMINI_KB_FILES)} KB files loaded and ready to analyze")
    demo.launch()