Spaces:
Paused
Paused
Commit ·
bfbf130
1
Parent(s): 0a2bc32
Load checkpoints from Hugging Face Hub
Browse files
- web/main.py +79 -7
web/main.py
CHANGED
|
@@ -13,6 +13,7 @@ import torch
|
|
| 13 |
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
|
| 14 |
from fastapi.responses import FileResponse, JSONResponse
|
| 15 |
from fastapi.staticfiles import StaticFiles
|
|
|
|
| 16 |
from PIL import Image
|
| 17 |
from peft import PeftModel
|
| 18 |
from transformers import AutoTokenizer, LlavaForConditionalGeneration, LlavaProcessor
|
|
@@ -106,6 +107,17 @@ class VQAServerState:
|
|
| 106 |
self.model_b_cfg = CFG.get("model_b", {})
|
| 107 |
self.eval_cfg = CFG.get("eval", {})
|
| 108 |
self.models_dir = ROOT_DIR / "checkpoints"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
self.qa_tokenizer = None
|
| 110 |
self.translator = MedicalTranslator(device="cpu")
|
| 111 |
self.answer_rewriter = MedicalAnswerRewriter()
|
|
@@ -134,6 +146,19 @@ def _artifact_exists(path: Path) -> bool:
|
|
| 134 |
return path.exists()
|
| 135 |
|
| 136 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
def _as_bool(value: Any) -> bool:
|
| 138 |
if isinstance(value, bool):
|
| 139 |
return value
|
|
@@ -352,7 +377,20 @@ def _resolve_variant_artifact(variant: str) -> dict[str, Any]:
|
|
| 352 |
ckpt_path = ROOT_DIR / "checkpoints" / f"medical_vqa_{variant}_best.pth"
|
| 353 |
if not ckpt_path.exists():
|
| 354 |
resume_path = ROOT_DIR / "checkpoints" / f"medical_vqa_{variant}_resume.pth"
|
| 355 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 356 |
return {"type": "direction_a", "path": ckpt_path}
|
| 357 |
|
| 358 |
if variant == "B1":
|
|
@@ -360,15 +398,49 @@ def _resolve_variant_artifact(variant: str) -> dict[str, Any]:
|
|
| 360 |
|
| 361 |
if variant == "B2":
|
| 362 |
ckpt_dir = _select_best_b2_checkpoint(ROOT_DIR / "checkpoints" / "B2")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 363 |
return {"type": "llava_adapter", "path": ckpt_dir}
|
| 364 |
|
| 365 |
if variant == "DPO":
|
| 366 |
final_adapter = ROOT_DIR / "checkpoints" / "DPO" / "final_adapter"
|
| 367 |
fallback = ROOT_DIR / "checkpoints" / "DPO" / "checkpoint-25"
|
| 368 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
|
| 370 |
if variant == "PPO":
|
| 371 |
final_adapter = ROOT_DIR / "checkpoints" / "PPO" / "final_adapter"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 372 |
return {"type": "llava_adapter", "path": final_adapter}
|
| 373 |
|
| 374 |
raise ValueError(f"Unknown variant: {variant}")
|
|
@@ -857,12 +929,12 @@ def _variant_availability() -> dict[str, dict[str, Any]]:
|
|
| 857 |
b2_checkpoint = _select_best_b2_checkpoint(ROOT_DIR / "checkpoints" / "B2")
|
| 858 |
cuda_ready = torch.cuda.is_available()
|
| 859 |
return {
|
| 860 |
-
"A1": {"available": (_artifact_exists(ROOT_DIR / "checkpoints" / "medical_vqa_A1_best.pth")), "artifact": "checkpoints/medical_vqa_A1_best.pth"},
|
| 861 |
-
"A2": {"available": (_artifact_exists(ROOT_DIR / "checkpoints" / "medical_vqa_A2_best.pth")), "artifact": "checkpoints/medical_vqa_A2_best.pth"},
|
| 862 |
"B1": {"available": cuda_ready, "artifact": state.llava_model_id},
|
| 863 |
-
"B2": {"available": cuda_ready and b2_checkpoint is not None, "artifact": str(b2_checkpoint) if b2_checkpoint else ""},
|
| 864 |
-
"DPO": {"available": cuda_ready and (_artifact_exists(ROOT_DIR / "checkpoints" / "DPO" / "final_adapter") or _artifact_exists(ROOT_DIR / "checkpoints" / "DPO" / "checkpoint-25")), "artifact": "checkpoints/DPO/final_adapter"},
|
| 865 |
-
"PPO": {"available": cuda_ready and _artifact_exists(ROOT_DIR / "checkpoints" / "PPO" / "final_adapter"), "artifact": "checkpoints/PPO/final_adapter"},
|
| 866 |
}
|
| 867 |
|
| 868 |
|
|
|
|
| 13 |
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
|
| 14 |
from fastapi.responses import FileResponse, JSONResponse
|
| 15 |
from fastapi.staticfiles import StaticFiles
|
| 16 |
+
from huggingface_hub import snapshot_download
|
| 17 |
from PIL import Image
|
| 18 |
from peft import PeftModel
|
| 19 |
from transformers import AutoTokenizer, LlavaForConditionalGeneration, LlavaProcessor
|
|
|
|
| 107 |
self.model_b_cfg = CFG.get("model_b", {})
|
| 108 |
self.eval_cfg = CFG.get("eval", {})
|
| 109 |
self.models_dir = ROOT_DIR / "checkpoints"
|
| 110 |
+
self.artifact_cache_dir = Path(
|
| 111 |
+
os.getenv("MEDVQA_ARTIFACT_CACHE", str(ROOT_DIR / ".cache" / "hub_artifacts"))
|
| 112 |
+
)
|
| 113 |
+
self.artifact_cache_dir.mkdir(parents=True, exist_ok=True)
|
| 114 |
+
self.hub_model_ids = {
|
| 115 |
+
"A1": os.getenv("MEDVQA_A1_MODEL_ID", "SpringWang08/medical-vqa-a1"),
|
| 116 |
+
"A2": os.getenv("MEDVQA_A2_MODEL_ID", "SpringWang08/medical-vqa-a2"),
|
| 117 |
+
"B2": os.getenv("MEDVQA_B2_MODEL_ID", "SpringWang08/medical-vqa-b2"),
|
| 118 |
+
"DPO": os.getenv("MEDVQA_DPO_MODEL_ID", "SpringWang08/medical-vqa-dpo"),
|
| 119 |
+
"PPO": os.getenv("MEDVQA_PPO_MODEL_ID", "SpringWang08/medical-vqa-ppo"),
|
| 120 |
+
}
|
| 121 |
self.qa_tokenizer = None
|
| 122 |
self.translator = MedicalTranslator(device="cpu")
|
| 123 |
self.answer_rewriter = MedicalAnswerRewriter()
|
|
|
|
| 146 |
return path.exists()
|
| 147 |
|
| 148 |
|
| 149 |
+
def _download_hub_snapshot(repo_id: str, cache_subdir: str, allow_patterns: Optional[list[str]] = None) -> Path:
    """Fetch a Hub model repo into the local artifact cache and return its directory.

    Args:
        repo_id: Hub repository identifier handed to ``snapshot_download``.
        cache_subdir: Folder name under ``state.artifact_cache_dir`` that
            receives the downloaded files.
        allow_patterns: Optional file patterns forwarded unchanged to
            ``snapshot_download``; ``None`` forwards no filter.

    Returns:
        The destination directory (created if missing). The value returned
        by ``snapshot_download`` itself is ignored.
    """
    destination = state.artifact_cache_dir / cache_subdir
    destination.mkdir(parents=True, exist_ok=True)

    # NOTE(review): newer huggingface_hub releases reportedly deprecate
    # ``local_dir_use_symlinks`` -- confirm against the pinned version before
    # dropping it; it is kept here so behavior stays unchanged.
    download_options = {
        "repo_id": repo_id,
        "repo_type": "model",
        "local_dir": str(destination),
        "local_dir_use_symlinks": False,
        "allow_patterns": allow_patterns,
    }
    snapshot_download(**download_options)
    return destination
|
| 160 |
+
|
| 161 |
+
|
| 162 |
def _as_bool(value: Any) -> bool:
|
| 163 |
if isinstance(value, bool):
|
| 164 |
return value
|
|
|
|
| 377 |
ckpt_path = ROOT_DIR / "checkpoints" / f"medical_vqa_{variant}_best.pth"
|
| 378 |
if not ckpt_path.exists():
|
| 379 |
resume_path = ROOT_DIR / "checkpoints" / f"medical_vqa_{variant}_resume.pth"
|
| 380 |
+
if resume_path.exists():
|
| 381 |
+
ckpt_path = resume_path
|
| 382 |
+
else:
|
| 383 |
+
repo_id = state.hub_model_ids.get(variant, "")
|
| 384 |
+
if repo_id:
|
| 385 |
+
downloaded_dir = _download_hub_snapshot(
|
| 386 |
+
repo_id=repo_id,
|
| 387 |
+
cache_subdir=variant.lower(),
|
| 388 |
+
allow_patterns=["README.md", "*.pth"],
|
| 389 |
+
)
|
| 390 |
+
downloaded_ckpt = downloaded_dir / f"medical_vqa_{variant}_best.pth"
|
| 391 |
+
if not downloaded_ckpt.exists():
|
| 392 |
+
downloaded_ckpt = downloaded_dir / f"medical_vqa_{variant}_resume.pth"
|
| 393 |
+
ckpt_path = downloaded_ckpt
|
| 394 |
return {"type": "direction_a", "path": ckpt_path}
|
| 395 |
|
| 396 |
if variant == "B1":
|
|
|
|
| 398 |
|
| 399 |
if variant == "B2":
|
| 400 |
ckpt_dir = _select_best_b2_checkpoint(ROOT_DIR / "checkpoints" / "B2")
|
| 401 |
+
if ckpt_dir is None:
|
| 402 |
+
repo_id = state.hub_model_ids.get("B2", "")
|
| 403 |
+
if repo_id:
|
| 404 |
+
ckpt_dir = _download_hub_snapshot(
|
| 405 |
+
repo_id=repo_id,
|
| 406 |
+
cache_subdir="b2",
|
| 407 |
+
allow_patterns=["README.md", "adapter_model.safetensors", "adapter_config.json", "tokenizer.json", "tokenizer_config.json", "processor_config.json", "chat_template.jinja"],
|
| 408 |
+
)
|
| 409 |
return {"type": "llava_adapter", "path": ckpt_dir}
|
| 410 |
|
| 411 |
if variant == "DPO":
|
| 412 |
final_adapter = ROOT_DIR / "checkpoints" / "DPO" / "final_adapter"
|
| 413 |
fallback = ROOT_DIR / "checkpoints" / "DPO" / "checkpoint-25"
|
| 414 |
+
if final_adapter.exists():
|
| 415 |
+
return {"type": "llava_adapter", "path": final_adapter}
|
| 416 |
+
if fallback.exists():
|
| 417 |
+
return {"type": "llava_adapter", "path": fallback}
|
| 418 |
+
repo_id = state.hub_model_ids.get("DPO", "")
|
| 419 |
+
if repo_id:
|
| 420 |
+
return {
|
| 421 |
+
"type": "llava_adapter",
|
| 422 |
+
"path": _download_hub_snapshot(
|
| 423 |
+
repo_id=repo_id,
|
| 424 |
+
cache_subdir="dpo",
|
| 425 |
+
allow_patterns=["README.md", "adapter_model.safetensors", "adapter_config.json", "tokenizer.json", "tokenizer_config.json", "processor_config.json", "chat_template.jinja"],
|
| 426 |
+
),
|
| 427 |
+
}
|
| 428 |
+
return {"type": "llava_adapter", "path": final_adapter}
|
| 429 |
|
| 430 |
if variant == "PPO":
|
| 431 |
final_adapter = ROOT_DIR / "checkpoints" / "PPO" / "final_adapter"
|
| 432 |
+
if final_adapter.exists():
|
| 433 |
+
return {"type": "llava_adapter", "path": final_adapter}
|
| 434 |
+
repo_id = state.hub_model_ids.get("PPO", "")
|
| 435 |
+
if repo_id:
|
| 436 |
+
return {
|
| 437 |
+
"type": "llava_adapter",
|
| 438 |
+
"path": _download_hub_snapshot(
|
| 439 |
+
repo_id=repo_id,
|
| 440 |
+
cache_subdir="ppo",
|
| 441 |
+
allow_patterns=["README.md", "adapter_model.safetensors", "adapter_config.json", "tokenizer.json", "tokenizer_config.json", "processor_config.json", "chat_template.jinja"],
|
| 442 |
+
),
|
| 443 |
+
}
|
| 444 |
return {"type": "llava_adapter", "path": final_adapter}
|
| 445 |
|
| 446 |
raise ValueError(f"Unknown variant: {variant}")
|
|
|
|
| 929 |
b2_checkpoint = _select_best_b2_checkpoint(ROOT_DIR / "checkpoints" / "B2")
|
| 930 |
cuda_ready = torch.cuda.is_available()
|
| 931 |
return {
|
| 932 |
+
"A1": {"available": (_artifact_exists(ROOT_DIR / "checkpoints" / "medical_vqa_A1_best.pth") or bool(state.hub_model_ids.get("A1"))), "artifact": str(ROOT_DIR / "checkpoints" / "medical_vqa_A1_best.pth") if _artifact_exists(ROOT_DIR / "checkpoints" / "medical_vqa_A1_best.pth") else state.hub_model_ids.get("A1", "")},
|
| 933 |
+
"A2": {"available": (_artifact_exists(ROOT_DIR / "checkpoints" / "medical_vqa_A2_best.pth") or bool(state.hub_model_ids.get("A2"))), "artifact": str(ROOT_DIR / "checkpoints" / "medical_vqa_A2_best.pth") if _artifact_exists(ROOT_DIR / "checkpoints" / "medical_vqa_A2_best.pth") else state.hub_model_ids.get("A2", "")},
|
| 934 |
"B1": {"available": cuda_ready, "artifact": state.llava_model_id},
|
| 935 |
+
"B2": {"available": cuda_ready and (b2_checkpoint is not None or bool(state.hub_model_ids.get("B2"))), "artifact": str(b2_checkpoint) if b2_checkpoint else state.hub_model_ids.get("B2", "")},
|
| 936 |
+
"DPO": {"available": cuda_ready and (_artifact_exists(ROOT_DIR / "checkpoints" / "DPO" / "final_adapter") or _artifact_exists(ROOT_DIR / "checkpoints" / "DPO" / "checkpoint-25") or bool(state.hub_model_ids.get("DPO"))), "artifact": "checkpoints/DPO/final_adapter" if _artifact_exists(ROOT_DIR / "checkpoints" / "DPO" / "final_adapter") else state.hub_model_ids.get("DPO", "")},
|
| 937 |
+
"PPO": {"available": cuda_ready and (_artifact_exists(ROOT_DIR / "checkpoints" / "PPO" / "final_adapter") or bool(state.hub_model_ids.get("PPO"))), "artifact": "checkpoints/PPO/final_adapter" if _artifact_exists(ROOT_DIR / "checkpoints" / "PPO" / "final_adapter") else state.hub_model_ids.get("PPO", "")},
|
| 938 |
}
|
| 939 |
|
| 940 |
|