SpringWang08 commited on
Commit
cb6aa4c
·
verified ·
1 Parent(s): 1e8f431

Simplify result table labels

Browse files
Files changed (1) hide show
  1. app.py +15 -9
app.py CHANGED
@@ -36,6 +36,14 @@ os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
36
  ROOT_DIR = Path(__file__).resolve().parent
37
  CONFIG_PATH = ROOT_DIR / "configs" / "medical_vqa.yaml"
38
  VARIANT_ORDER = ["A1", "A2", "B1", "B2", "DPO", "PPO"]
 
 
 
 
 
 
 
 
39
  HF_MODEL_REPOS = {
40
  "A1": "SpringWang08/medical-vqa-a1",
41
  "A2": "SpringWang08/medical-vqa-a2",
@@ -323,7 +331,9 @@ async def _predict_variant(variant: str, question: str, image: Image.Image) -> d
323
  )
324
  return {
325
  "model": variant,
 
326
  "prediction": rewritten,
 
327
  "prediction_before_rewrite": out["prediction"],
328
  "raw": out["prediction_raw"],
329
  "answer_used_for_rewrite": answer_for_rewrite,
@@ -334,7 +344,9 @@ async def _predict_variant(variant: str, question: str, image: Image.Image) -> d
334
  except Exception as exc:
335
  return {
336
  "model": variant,
 
337
  "prediction": "",
 
338
  "prediction_before_rewrite": "",
339
  "raw": "",
340
  "answer_used_for_rewrite": "",
@@ -362,7 +374,7 @@ def predict_all(image: Image.Image, question: str, selected_models: list[str]) -
362
  return rows
363
 
364
  rows = asyncio.run(_run())
365
- return pd.DataFrame(rows)
366
 
367
 
368
  CSS = """
@@ -390,14 +402,8 @@ with gr.Blocks(css=CSS, title="Medical VQA Compare") as demo:
390
  output_table = gr.Dataframe(
391
  label="Kết quả",
392
  headers=[
393
- "model",
394
- "prediction",
395
- "prediction_before_rewrite",
396
- "raw",
397
- "answer_used_for_rewrite",
398
- "checkpoint",
399
- "latency_ms",
400
- "status",
401
  ],
402
  wrap=True,
403
  )
 
36
  ROOT_DIR = Path(__file__).resolve().parent
37
  CONFIG_PATH = ROOT_DIR / "configs" / "medical_vqa.yaml"
38
  VARIANT_ORDER = ["A1", "A2", "B1", "B2", "DPO", "PPO"]
39
+ MODEL_DISPLAY_NAMES = {
40
+ "A1": "A1 LSTM",
41
+ "A2": "A2 Transformer",
42
+ "B1": "B1 Zero-shot",
43
+ "B2": "B2 Fine-tuned",
44
+ "DPO": "DPO Alignment",
45
+ "PPO": "PPO RL refinement",
46
+ }
47
  HF_MODEL_REPOS = {
48
  "A1": "SpringWang08/medical-vqa-a1",
49
  "A2": "SpringWang08/medical-vqa-a2",
 
331
  )
332
  return {
333
  "model": variant,
334
+ "Model": MODEL_DISPLAY_NAMES.get(variant, variant),
335
  "prediction": rewritten,
336
+ "Prediction": rewritten,
337
  "prediction_before_rewrite": out["prediction"],
338
  "raw": out["prediction_raw"],
339
  "answer_used_for_rewrite": answer_for_rewrite,
 
344
  except Exception as exc:
345
  return {
346
  "model": variant,
347
+ "Model": MODEL_DISPLAY_NAMES.get(variant, variant),
348
  "prediction": "",
349
+ "Prediction": "",
350
  "prediction_before_rewrite": "",
351
  "raw": "",
352
  "answer_used_for_rewrite": "",
 
374
  return rows
375
 
376
  rows = asyncio.run(_run())
377
+ return pd.DataFrame(rows)[["Model", "Prediction"]]
378
 
379
 
380
  CSS = """
 
402
  output_table = gr.Dataframe(
403
  label="Kết quả",
404
  headers=[
405
+ "Model",
406
+ "Prediction",
 
 
 
 
 
 
407
  ],
408
  wrap=True,
409
  )