import re

import torch
from tqdm import tqdm

from src.utils.metrics import batch_metrics, compute_bertscore, compute_f1, compute_semantic_score
from src.utils.text_utils import is_medical_term_compliant, normalize_answer, postprocess_answer

def normalize_for_metric(text: str) -> str:
    return text.strip().lower()


def _normalize_closed_answer(question_vi: str, question_en: str, pred_vi: str, pred_en: str = "") -> str:
    """Map descriptive yes/no-style outputs to closed-form labels."""
    question_vi_norm = normalize_answer(question_vi)
    question_en_norm = normalize_answer(question_en)
    pred_vi_norm = normalize_answer(pred_vi)
    pred_en_norm = normalize_answer(pred_en)
    combined = " ".join(part for part in [pred_vi_norm, pred_en_norm] if part).strip()

    is_normality_question = any(
        pattern in " ".join([question_vi_norm, question_en_norm])
        for pattern in ["bình thường", "normal", "abnormal", "bat thuong"]
    )

    if is_normality_question:
        explicit_negative_patterns = [
            "không bình thường",
            "not normal",
        ]
        explicit_positive_patterns = [
            "có",
            "yes",
        ]
        positive_patterns = [
            "bình thường",
            "normal",
            "no significant abnormalities",
            "no abnormality",
            "unremarkable",
            "appears to be normal",
            "without significant abnormalities",
            "không phát hiện bất thường",
        ]
        negative_patterns = [
            "bất thường",
            "abnormal",
            "abnormality detected",
            "fracture",
            "lesion",
            "mass",
            "effusion",
            "pneumothorax",
        ]
        if any(pattern in combined for pattern in explicit_negative_patterns):
            return "không"
        if any(pattern in combined.split() for pattern in explicit_positive_patterns):
            return "có"
        if any(pattern in combined for pattern in positive_patterns):
            return "có"
        if any(pattern in combined for pattern in negative_patterns):
            return "không"
    else:
        positive_patterns = [
            "có",
            "yes",
            "present",
            "detected",
            "positive",
        ]
        negative_patterns = [
            "không",
            "no",
            "absent",
            "not seen",
            "negative",
            "none",
        ]
        # For presence/absence questions, "không có ..." contains "có" but
        # semantically means no. Check negation before positive cues.
        if any(pattern in combined for pattern in negative_patterns):
            return "không"
        if any(pattern in combined for pattern in positive_patterns):
            return "có"

    fallback_positive_patterns = [
        "bình thường",
        "normal",
        "no significant abnormalities",
        "no abnormality",
        "unremarkable",
        "appears to be normal",
        "without significant abnormalities",
        "không phát hiện bất thường",
    ]
    fallback_negative_patterns = [
        "bất thường",
        "abnormal",
        "abnormality detected",
        "fracture",
        "lesion",
        "mass",
        "effusion",
        "pneumothorax",
    ]

    if any(pattern in combined for pattern in fallback_positive_patterns):
        return "có"
    if any(pattern in combined for pattern in fallback_negative_patterns):
        return "không"
    return pred_vi_norm or pred_en_norm
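
# Illustrative behavior of _normalize_closed_answer (hedged: assumes normalize_answer
# lowercases while preserving the diacritics that the pattern lists above rely on):
#   ("Phổi có bình thường không?", "Is the lung normal?", "the lung appears to be normal")
#       -> "có"     (normality question, positive pattern matched)
#   ("Có tràn dịch màng phổi không?", "Is there pleural effusion?", "không có tràn dịch")
#       -> "không"  (presence question; negation is checked before positive cues)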


def _compute_format_stats(preds: list[str], max_words: int) -> dict[str, float]:
    if not preds:
        return {
            "max_10_word_compliance_rate": 0.0,
            "medical_term_compliance_rate": 0.0,
            "avg_answer_length": 0.0,
        }
    word_counts = [len(p.split()) for p in preds]
    return {
        "max_10_word_compliance_rate": sum(1 for count in word_counts if count <= max_words) / len(word_counts),
        "medical_term_compliance_rate": sum(1 for pred in preds if is_medical_term_compliant(pred)) / len(preds),
        "avg_answer_length": sum(word_counts) / len(word_counts),
    }
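
# Worked example: for preds = ["tràn dịch màng phổi", "có"] and max_words=10,
# word_counts == [4, 1], so max_10_word_compliance_rate == 1.0 and
# avg_answer_length == 2.5. medical_term_compliance_rate depends entirely on
# is_medical_term_compliant from src.utils.text_utils.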


def _build_bad_words_ids(processor, variant: str) -> list[list[int]] | None:
    if variant not in {"B1", "B2", "DPO", "PPO"}:
        return None

    tokenizer = getattr(processor, "tokenizer", None)
    if tokenizer is None:
        return None

    banned_phrases = [
        "yes",
        "no",
        "the answer is",
        "the image is",
        "this image is",
        "the image shows",
        "the scan shows",
        "there is",
        "there are",
        "it appears",
        "the finding is",
    ]

    bad_words_ids = []
    for phrase in banned_phrases:
        token_ids = tokenizer.encode(phrase, add_special_tokens=False)
        if token_ids:
            bad_words_ids.append(token_ids)
    return bad_words_ids or None
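
# Caveat (a known generate() subtlety, not something this file handles): bad_words_ids
# bans exact token-id sequences, and BPE/SentencePiece tokenizers often encode a word
# differently with and without a leading space, so a phrase banned as encoded here may
# still appear mid-sentence. Encoding each phrase both bare and as " " + phrase would
# tighten the ban if that matters.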


def _attach_metric_views(metrics: dict[str, float]) -> dict[str, float]:
    """Add explicit metric names while preserving backward-compatible aliases."""
    for key in ("accuracy", "em", "f1", "bleu1", "bleu2", "bleu3", "bleu4", "rouge_l", "meteor"):
        if key in metrics:
            metrics[f"{key}_normalized"] = metrics[key]
    for key in ("bert_score", "semantic"):
        if key in metrics:
            metrics[f"{key}_raw"] = metrics[key]
    return metrics
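
# Example: _attach_metric_views({"f1": 0.42, "semantic": 0.7}) mutates its argument in
# place and returns {"f1": 0.42, "semantic": 0.7, "f1_normalized": 0.42, "semantic_raw": 0.7}.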

class MedicalVQAEvaluator:
    """
    Hệ thống đánh giá hợp nhất cho cả Hướng A và Hướng B.
    """
    def __init__(self, device, tokenizer=None, processor=None):
        self.device = device
        self.tokenizer = tokenizer
        self.processor = processor

    def evaluate(self, model, dataloader, variant_type='A', beam_width=1):
        """
        Giao diện chung để đánh giá bất kỳ variant nào.
        """
        if variant_type == 'A':
            return evaluate_vqa(model, dataloader, self.device, self.tokenizer, beam_width)
        else:
            return evaluate_multimodal_vqa(model, dataloader, self.device, self.processor, beam_width, variant=variant_type)

def evaluate_vqa(model, dataloader, device, tokenizer, beam_width=1, max_len=32, max_words=10):
    model.eval()
    all_preds = []
    all_preds_raw = []
    all_preds_display = []
    all_refs = []
    all_refs_full = []
    all_is_closed = []
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            images = batch['image'].to(device)
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label_closed']
            
            # [FIX] Call inference() to get BOTH head outputs, passing max_len from the config
            logits_closed, pred_ids = model.inference(images, input_ids, attention_mask, beam_width=beam_width, max_len=max_len)
            
            # Decode the generative head and clean up subword artifacts
            preds_text_raw = [
                postprocess_answer(t, max_words=max_words)
                for t in tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
            ]
            preds_text = list(preds_text_raw)
            
            # [CRITICAL FIX] For closed (Yes/No) questions, use the classifier head instead of the generator
            closed_map = {0: "không", 1: "có"}
            closed_preds_idx = torch.argmax(logits_closed, dim=-1)  # [B]
            for i in range(len(preds_text)):
                if labels[i].item() != -1:  # closed question
                    preds_text[i] = closed_map[closed_preds_idx[i].item()]
                preds_text[i] = postprocess_answer(preds_text[i], max_words=max_words)
            
            # Debug: print both closed and open samples to check output diversity
            if len(all_preds) == 0:
                print("\n--- DEBUG PREDICTIONS ---")
                shown_closed, shown_open = 0, 0
                for i in range(len(preds_text)):
                    is_closed = labels[i].item() != -1
                    if (is_closed and shown_closed < 2) or (not is_closed and shown_open < 2):
                        q_type = "CLOSED" if is_closed else "OPEN"
                        print(f"[{q_type}] Q: {batch['raw_questions'][i]}")
                        print(f"  Pred raw: '{preds_text_raw[i]}'")
                        print(f"  Pred normalized: '{preds_text[i]}'")
                        print(f"  GT  : '{batch['raw_answer'][i]}'")
                        if is_closed: shown_closed += 1
                        else: shown_open += 1
                    if shown_closed >= 2 and shown_open >= 2:
                        break
                print("--------------------------\n")
            
            all_preds.extend([normalize_for_metric(p) for p in preds_text])
            all_preds_raw.extend([normalize_for_metric(p) for p in preds_text_raw])
            all_preds_display.extend([normalize_for_metric(p) for p in preds_text_raw])
            # [CRITICAL FIX] Use the Vietnamese answers for scoring
            all_refs.extend([normalize_for_metric(postprocess_answer(r, max_words=max_words)) for r in batch['raw_answer']])
            all_refs_full.extend([normalize_for_metric(postprocess_answer(r, max_words=100)) for r in batch.get('raw_answer_full', batch['raw_answer'])])
            is_closed = (batch['label_closed'] != -1).tolist()
            all_is_closed.extend(is_closed)

    metrics = batch_metrics(all_preds, all_refs)
    metrics["semantic"] = compute_semantic_score(all_preds_raw, all_refs)
    metrics["bert_score"] = compute_bertscore(all_preds_raw, all_refs)
    metrics = _attach_metric_views(metrics)
    metrics.update(_compute_format_stats(all_preds, max_words=max_words))
    metrics['predictions'] = all_preds
    metrics['predictions_raw'] = all_preds_raw
    metrics['predictions_display'] = all_preds_display
    metrics['ground_truths'] = all_refs
    
    closed_preds = [p for p, c in zip(all_preds, all_is_closed) if c]
    closed_refs = [r for r, c in zip(all_refs, all_is_closed) if c]
    closed_preds_raw = [p for p, c in zip(all_preds_raw, all_is_closed) if c]
    if closed_preds:
        metrics['closed'] = batch_metrics(closed_preds, closed_refs)
        metrics['closed']["semantic"] = compute_semantic_score(closed_preds_raw, closed_refs)
        metrics['closed']["bert_score"] = compute_bertscore(closed_preds_raw, closed_refs)
        metrics['closed'] = _attach_metric_views(metrics['closed'])
        metrics['closed'].update(_compute_format_stats(closed_preds, max_words=max_words))
        metrics['closed_eval'] = {
            "accuracy": metrics['closed'].get("accuracy_normalized", 0.0),
            "em": metrics['closed'].get("em_normalized", 0.0),
            "f1": metrics['closed'].get("f1_normalized", 0.0),
            "count": len(closed_preds),
        }
    open_preds = [p for p, c in zip(all_preds, all_is_closed) if not c]
    open_refs = [r for r, c in zip(all_refs, all_is_closed) if not c]
    open_preds_raw = [p for p, c in zip(all_preds_raw, all_is_closed) if not c]
    if open_preds:
        metrics['open'] = batch_metrics(open_preds, open_refs)
        metrics['open']["semantic"] = compute_semantic_score(open_preds_raw, open_refs)
        metrics['open']["bert_score"] = compute_bertscore(open_preds_raw, open_refs)
        metrics['open'] = _attach_metric_views(metrics['open'])
        metrics['open'].update(_compute_format_stats(open_preds, max_words=max_words))
        metrics['open_eval'] = {
            "semantic": metrics['open'].get("semantic_raw", 0.0),
            "bert_score": metrics['open'].get("bert_score_raw", 0.0),
            "f1": metrics['open'].get("f1_normalized", 0.0),
            "rouge_l": metrics['open'].get("rouge_l_normalized", 0.0),
            "count": len(open_preds),
        }
        
    # Compute once and attach the *_normalized views; batch_metrics alone does not
    # produce "accuracy_normalized" etc., so the previous .get(..., 0) always returned 0.
    long_metrics = _attach_metric_views(batch_metrics(all_preds, all_refs_full))
    metrics['long_answers_eval'] = {
        "accuracy": long_metrics.get("accuracy_normalized", 0.0),
        "f1": long_metrics.get("f1_normalized", 0.0),
        "bleu4": long_metrics.get("bleu4_normalized", 0.0),
        "semantic": compute_semantic_score(all_preds_raw, all_refs_full),
        "bert_score": compute_bertscore(all_preds_raw, all_refs_full),
    }
    return metrics
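
# Minimal usage sketch for Track A (hedged: assumes a model whose inference() returns
# (closed-head logits, generated token ids) and a HF-style tokenizer):
#   evaluator = MedicalVQAEvaluator(device=torch.device("cuda"), tokenizer=tokenizer)
#   metrics = evaluator.evaluate(model, val_loader, variant_type='A', beam_width=3)
#   print(metrics['closed_eval'], metrics['open_eval'])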

# ─────────────────────────────────────────────────────────────────────────────
# B1 HELPERS
# ─────────────────────────────────────────────────────────────────────────────

_B1_FEW_SHOT = (
    "Q: Is there cardiomegaly? A: yes\n"
    "Q: What organ is shown? A: lung\n"
    "Q: Is the aorta normal? A: no\n"
    "Q: What abnormality is present? A: pleural effusion\n"
)


def _build_b1_prompt(question_en: str, max_words: int) -> str:
    """
    Few-shot prompt ép LLaVA trả lời ngắn (≤max_words từ y tế), không sinh câu dài.
    Đặt 4 ví dụ in-context trước câu hỏi thực để suppress verbose prefix.
    """
    return (
        f"USER: <image>\n"
        f"Answer each question with medical terminology only, "
        f"no more than {max_words} words, no full sentences.\n"
        f"{_B1_FEW_SHOT}"
        f"Q: {question_en} A: ASSISTANT:"
    )
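
# Rendered example, _build_b1_prompt("Is there cardiomegaly?", 10):
#   USER: <image>
#   Answer each question with medical terminology only, no more than 10 words, no full sentences.
#   Q: Is there cardiomegaly? A: yes
#   ... (three more few-shot pairs) ...
#   Q: Is there cardiomegaly? A: ASSISTANT: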


# En → Vi fast lookup (50+ medical terms common in SLAKE + VQA-RAD)
_EN_VI_DIRECT: dict[str, str] = {
    # binary
    "yes": "có", "no": "không",
    "present": "có", "absent": "không",
    "normal": "bình thường", "abnormal": "bất thường",
    "true": "có", "false": "không",
    "positive": "có", "negative": "không",
    # anatomy
    "lung": "phổi", "lungs": "phổi",
    "heart": "tim", "liver": "gan", "spleen": "lách",
    "kidney": "thận", "brain": "não", "bladder": "bàng quang",
    "chest": "ngực", "abdomen": "bụng", "pelvis": "xương chậu",
    "spine": "cột sống", "rib": "xương sườn", "ribs": "xương sườn",
    "trachea": "khí quản", "aorta": "động mạch chủ",
    "diaphragm": "cơ hoành", "mediastinum": "trung thất",
    # modality
    "chest x-ray": "x-quang ngực", "x-ray": "x-quang", "xray": "x-quang",
    "mri": "mri", "ct": "ct", "ultrasound": "siêu âm",
    "ct scan": "ct", "mri scan": "mri",
    # planes
    "axial": "mặt phẳng ngang",
    "coronal": "mặt phẳng vành",
    "sagittal": "mặt phẳng dọc",
    "transverse": "mặt phẳng ngang",
    # pathologies
    "cardiomegaly": "tim to",
    "pneumonia": "viêm phổi",
    "pleural effusion": "tràn dịch màng phổi",
    "pneumothorax": "tràn khí màng phổi",
    "fracture": "gãy xương",
    "edema": "phù nề",
    "pulmonary edema": "phù phổi",
    "consolidation": "đông đặc",
    "atelectasis": "xẹp phổi",
    "opacity": "mờ đục",
    "mass": "khối u",
    "nodule": "nốt",
    "lesion": "tổn thương",
    "tumor": "khối u",
    "effusion": "tràn dịch",
    "infiltrate": "thâm nhiễm",
    "fibrosis": "xơ hóa",
    "calcification": "vôi hóa",
    "carcinoma": "ung thư",
    "metastasis": "di căn",
    "bilateral": "hai bên",
    "unilateral": "một bên",
    "left": "trái", "right": "phải",
    "upper": "trên", "lower": "dưới",
    "right upper quadrant": "phía trên bên phải",
    "left upper quadrant": "phía trên bên trái",
    "right lower quadrant": "phía dưới bên phải",
    "left lower quadrant": "phía dưới bên trái",
    "right upper": "phía trên bên phải",
    "left upper": "phía trên bên trái",
    "upper left": "phía trên bên trái",
    "upper right": "phía trên bên phải",
    "lower left": "phía dưới bên trái",
    "lower right": "phía dưới bên phải",
}
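
# _en_to_vi_direct below is exact-match only. A hedged sketch of a longest-phrase-first
# variant (not wired in anywhere) for predictions that embed a dictionary phrase in a
# longer answer; note that naive substring matching can misfire on short keys like "no":
_EN_VI_KEYS_BY_LEN = sorted(_EN_VI_DIRECT, key=len, reverse=True)

def _en_to_vi_longest_match(en_text: str) -> str | None:
    """Translate via the longest dictionary phrase contained in en_text, if any."""
    norm = en_text.strip().lower()
    for key in _EN_VI_KEYS_BY_LEN:  # longest first, so "pleural effusion" beats "effusion"
        if key in norm:
            return _EN_VI_DIRECT[key]
    return None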


def _extract_key_medical_term(raw_en: str, max_words: int) -> str:
    """
    Loại bỏ verbose prefix LLaVA hay sinh ("The image shows a chest X-ray with..."),
    chỉ giữ lại thuật ngữ y tế chính.
    """
    import re
    text = raw_en.strip().lower()

    # Common verbose prefixes to remove
    prefixes = [
        r"^the (image|scan|x-ray|xray|mri|ct|picture|photo|radiograph) (shows?|depicts?|demonstrates?|reveals?|indicates?|presents?)\s+",
        r"^based on the (image|scan|x-ray|mri|ct)\s*,?\s*",
        r"^in (this|the) (image|scan|x-ray|mri|ct)\s*,?\s*",
        r"^i (can see|observe|notice|see)\s+",
        r"^there (is|are)\s+(a |an |some )?",
        r"^(it |this )(shows?|is|appears?|looks?)\s+(like\s+)?",
        r"^the (patient|subject)\s+(has|shows?|presents?)\s+",
        r"^(a|an|the)\s+",
        r"^[a-z\s]+ is (located|seen|found|present)( in| at| on)?\s+(the\s+)?",
    ]
    for pat in prefixes:
        text = re.sub(pat, "", text)

    text = re.sub(r"[.!?,;:]+$", "", text).strip()
    text = re.sub(r"\s+", " ", text).strip()

    words = text.split()
    return " ".join(words[:max_words]) if words else raw_en.strip()


def _en_to_vi_direct(en_text: str) -> str | None:
    """
    Tra từ điển nhanh. Sắp xếp theo độ dài giảm dần để phrase dài match trước.
    Trả về None nếu không match → caller dùng Translation Model.
    """
    norm = en_text.strip().lower()
    if norm in _EN_VI_DIRECT:
        return _EN_VI_DIRECT[norm]
    return None
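
# Example (exact match only): _en_to_vi_direct("pleural effusion") -> "tràn dịch màng phổi",
# while _en_to_vi_direct("small pleural effusion") -> None, so the caller batches it
# through the translation model instead.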


def _dual_score_open(
    preds_vi: list,
    preds_en: list,
    refs_vi: list,
    refs_en: list,
) -> list:
    """
    Với mỗi câu hỏi mở, so sánh F1 Vi vs F1 En rồi chọn prediction tốt hơn.
    Giải quyết 0% open-ended do dịch thuật mất nghĩa.
    """
    from src.utils.metrics import compute_f1
    from src.utils.text_utils import normalize_answer
    result = []
    for pv, pe, rv, re_ in zip(preds_vi, preds_en, refs_vi, refs_en):
        f1_vi = compute_f1(pv, rv)
        f1_en = compute_f1(normalize_answer(pe), normalize_answer(re_)) if re_ else 0.0
        result.append(pv if f1_vi >= f1_en else normalize_answer(pe))
    return result
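
# Example (hedged: assumes compute_f1 is token-overlap F1): with pv="phổi", pe="lung",
# rv="gan", re_="liver", both F1 scores are 0.0 and the tie (>=) keeps the Vietnamese
# prediction "phổi". When translation garbles pv but pe matches re_, the English side wins.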


def evaluate_multimodal_vqa(
    model,
    dataloader,
    device,
    processor,
    beam_width=1,
    max_words=10,
    variant='B1',
    beam_width_closed=None,
    beam_width_open=None,
    max_new_tokens_closed=None,
    max_new_tokens_open=None,
    generation_batch_size=None,
):
    """
    B1 Zero-Shot evaluation & B2/DPO/PPO Fine-Tuned evaluation.
    """
    model.eval()
    all_preds     = []
    all_preds_raw = []
    all_preds_display = []
    all_preds_en  = []
    all_refs      = []
    all_refs_full = []
    all_refs_en   = []
    all_is_closed = []

    from src.utils.translator import MedicalTranslator
    translator = MedicalTranslator(device=device.type)
    
    from src.models.multimodal_vqa import MultimodalVQA
    wrapper = MultimodalVQA()

    beam_width_closed = beam_width if beam_width_closed is None else beam_width_closed
    beam_width_open = beam_width if beam_width_open is None else beam_width_open
    max_new_tokens_closed = 4 if max_new_tokens_closed is None else max_new_tokens_closed
    max_new_tokens_open = (max_words + 6) if max_new_tokens_open is None else max_new_tokens_open
    generation_batch_size = 1 if generation_batch_size is None else max(1, int(generation_batch_size))
    bad_words_ids = _build_bad_words_ids(processor, variant)

    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(dataloader, desc=f"Evaluating {variant}")):
            raw_images   = batch.get('raw_image')
            questions_vi = batch.get('raw_questions', [])
            questions_en = batch.get('raw_questions_en', [])
            refs_vi_raw  = batch.get('raw_answer', [])
            refs_en_raw  = batch.get('raw_answer_en', [])
            labels       = batch['label_closed']

            if variant == 'B1':
                # B1 (Zero-shot) needs English translation & English few-shot prompt
                if not questions_en or any(not str(q).strip() for q in questions_en):
                    questions_en = translator.translate_vi2en(questions_vi)
                prompts = [_build_b1_prompt(q, max_words) for q in questions_en]
            else:
                # B2 / DPO / PPO (Fine-tuned) expect Vietnamese instruction directly
                prompts = [wrapper.build_instruction_prompt(q, language="vi", include_answer=False) for q in questions_vi]
            preds_raw = [""] * len(prompts)
            closed_idx = [i for i, lbl in enumerate(labels.tolist()) if lbl != -1]
            open_idx = [i for i, lbl in enumerate(labels.tolist()) if lbl == -1]

            def _run_generation(sample_indices, num_beams, max_new_tokens):
                if not sample_indices:
                    return []
                decoded_outputs = []
                chunk_size = generation_batch_size if num_beams > 1 else max(generation_batch_size, 2)

                for start in range(0, len(sample_indices), chunk_size):
                    chunk_indices = sample_indices[start:start + chunk_size]
                    text_subset = [prompts[i] for i in chunk_indices]
                    image_subset = [raw_images[i] for i in chunk_indices] if raw_images is not None else None
                    if image_subset is not None:
                        inputs = processor(
                            text=text_subset,
                            images=image_subset,
                            return_tensors="pt",
                            padding=True,
                        ).to(device)
                        if "pixel_values" in inputs:
                            inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
                    else:
                        inputs = processor(text=text_subset, return_tensors="pt", padding=True).to(device)

                    output_ids = model.generate(
                        **inputs,
                        max_new_tokens=max_new_tokens,
                        do_sample=False,
                        num_beams=num_beams,
                        early_stopping=num_beams > 1,
                        bad_words_ids=bad_words_ids,
                    )
                    input_token_len = inputs.input_ids.shape[1]
                    decoded_outputs.extend(
                        processor.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)
                    )

                    del inputs, output_ids
                    if device.type == "cuda":
                        torch.cuda.empty_cache()

                return decoded_outputs

            if variant == 'B1':
                generated = _run_generation(list(range(len(prompts))), beam_width_open, max_new_tokens_open)
                preds_raw = generated
            else:
                for idx, pred in zip(closed_idx, _run_generation(closed_idx, beam_width_closed, max_new_tokens_closed)):
                    preds_raw[idx] = pred
                for idx, pred in zip(open_idx, _run_generation(open_idx, beam_width_open, max_new_tokens_open)):
                    preds_raw[idx] = pred

            preds_vi = []
            preds_vi_display = []
            preds_en_clean = []
            
            if variant == 'B1':
                # [FIX 2] Strip the verbose prefix → keep the key medical term, without
                # chopping the English sentence, so the translation model still parses it correctly.
                preds_en_clean = [_extract_key_medical_term(p, 50) for p in preds_raw]
    
                # [FIX 3 + 5] Per-sample: closed → normalize the En answer directly; open → dict lookup first, then the translation model
                needs_translate_idx = []   # indices that still need translation
                needs_translate_txt = []
    
                for i, pred_en in enumerate(preds_en_clean):
                    if labels[i].item() != -1:
                        # Closed: use _normalize_closed_answer on the En pred (more accurate)
                        preds_vi.append(
                            _normalize_closed_answer(
                                questions_vi[i], questions_en[i], pred_en, pred_en
                            )
                        )
                    else:
                        # Open: try the fast dictionary first
                        vi_direct = _en_to_vi_direct(pred_en)
                        if vi_direct is not None:
                            preds_vi.append(postprocess_answer(vi_direct, max_words=max_words))
                        else:
                            preds_vi.append(None)           # placeholder
                            needs_translate_idx.append(i)
                            needs_translate_txt.append(pred_en)
    
                # Batch-translate whatever still needs the translation model
                if needs_translate_txt:
                    translated = translator.translate_en2vi(needs_translate_txt)
                    if isinstance(translated, str):
                        translated = [translated]
                    for idx, vi in zip(needs_translate_idx, translated):
                        preds_vi[idx] = postprocess_answer(vi, max_words=max_words)
                preds_vi_display = list(preds_vi)
            else:
                # B2 / DPO / PPO directly outputs Vietnamese, no translation needed
                preds_vi_display = [postprocess_answer(p, max_words=max_words) if p else "" for p in preds_raw]
                for i, pred_vi in enumerate(preds_raw):
                    if labels[i].item() != -1:
                        preds_vi.append(
                            _normalize_closed_answer(
                                questions_vi[i], questions_en[i] if i < len(questions_en) else "", pred_vi
                            )
                        )
                    else:
                        preds_vi.append(pred_vi)
                preds_en_clean = [""] * len(preds_raw)

            # Ensure no None placeholders remain
            preds_vi = [postprocess_answer(p, max_words=max_words) if p else "" for p in preds_vi]
            preds_vi_display = [postprocess_answer(p, max_words=max_words) if p else "" for p in preds_vi_display]
            preds_vi_raw = list(preds_vi_display)

            # Refs
            refs_vi  = [postprocess_answer(r, max_words=max_words) for r in refs_vi_raw]
            refs_en  = [postprocess_answer(r, max_words=max_words) if r else "" for r in refs_en_raw]

            # Debug output for the first batch
            if batch_idx == 0:
                print(f"\n--- DEBUG {variant} (Evaluation) ---")
                for i in range(min(4, len(preds_vi))):
                    q_type = "CLOSED" if labels[i].item() != -1 else "OPEN"
                    if variant == 'B1':
                        print(f"[{q_type}] Q (En): {questions_en[i]}")
                        print(f"  Pred (En raw):   '{preds_raw[i]}'")
                        print(f"  Pred (En clean): '{preds_en_clean[i]}'")
                    else:
                        print(f"[{q_type}] Q (Vi): {questions_vi[i]}")
                        print(f"  Pred (Vi raw):   '{preds_raw[i]}'")
                    print(f"  Pred display:    '{preds_vi_display[i]}'")
                    print(f"  Pred (Vi):       '{preds_vi[i]}'")
                    print(f"  GT (Vi): '{refs_vi[i]}'  |  GT (En): '{refs_en[i]}'")
                print("-----------------------------------------\n")

            all_preds.extend([normalize_for_metric(p) for p in preds_vi])
            all_preds_raw.extend([normalize_for_metric(p) for p in preds_vi_raw])
            all_preds_display.extend([normalize_for_metric(p) for p in preds_vi_display])
            all_preds_en.extend([normalize_for_metric(p) for p in preds_en_clean])
            all_refs.extend([normalize_for_metric(r) for r in refs_vi])
            all_refs_full.extend([normalize_for_metric(postprocess_answer(r, max_words=100)) for r in batch.get('raw_answer_full', batch['raw_answer'])])
            all_refs_en.extend([normalize_for_metric(r) for r in refs_en])
            all_is_closed.extend((labels != -1).tolist())

    # [FIX 4] Dual-language scoring for open-ended questions (B1 only)
    if variant == 'B1':
        open_idx = [i for i, c in enumerate(all_is_closed) if not c]
        if open_idx:
            best_open = _dual_score_open(
                [all_preds[i]    for i in open_idx],
                [all_preds_en[i] for i in open_idx],
                [all_refs[i]     for i in open_idx],
                [all_refs_en[i]  for i in open_idx],
            )
            for k, i in enumerate(open_idx):
                all_preds[i] = best_open[k]

    # ── Compute metrics ──────────────────────────────────────────────────────
    metrics = batch_metrics(all_preds, all_refs)
    metrics["semantic"]   = compute_semantic_score(all_preds_raw, all_refs)
    metrics["bert_score"] = compute_bertscore(all_preds_raw, all_refs)
    metrics = _attach_metric_views(metrics)
    metrics.update(_compute_format_stats(all_preds, max_words=max_words))
    metrics['predictions']      = all_preds
    metrics['predictions_raw']  = all_preds_raw
    metrics['predictions_display'] = all_preds_display
    metrics['predictions_en']   = all_preds_en
    metrics['ground_truths']    = all_refs
    metrics['ground_truths_en'] = all_refs_en

    def _subset(pred_list, ref_list, pred_raw_list):
        m = batch_metrics(pred_list, ref_list)
        m["semantic"]   = compute_semantic_score(pred_raw_list, ref_list)
        m["bert_score"] = compute_bertscore(pred_raw_list, ref_list)
        m = _attach_metric_views(m)
        m.update(_compute_format_stats(pred_list, max_words=max_words))
        return m

    closed_idx = [i for i, c in enumerate(all_is_closed) if c]
    open_idx   = [i for i, c in enumerate(all_is_closed) if not c]

    if closed_idx:
        metrics['closed'] = _subset(
            [all_preds[i]     for i in closed_idx],
            [all_refs[i]      for i in closed_idx],
            [all_preds_raw[i] for i in closed_idx],
        )
        metrics['closed_eval'] = {
            "accuracy": metrics['closed'].get("accuracy_normalized", 0.0),
            "em": metrics['closed'].get("em_normalized", 0.0),
            "f1": metrics['closed'].get("f1_normalized", 0.0),
            "count": len(closed_idx),
        }
    if open_idx:
        metrics['open'] = _subset(
            [all_preds[i]     for i in open_idx],
            [all_refs[i]      for i in open_idx],
            [all_preds_raw[i] for i in open_idx],
        )
        metrics['open_eval'] = {
            "semantic": metrics['open'].get("semantic_raw", 0.0),
            "bert_score": metrics['open'].get("bert_score_raw", 0.0),
            "f1": metrics['open'].get("f1_normalized", 0.0),
            "rouge_l": metrics['open'].get("rouge_l_normalized", 0.0),
            "count": len(open_idx),
        }
        
    # Same fix as in evaluate_vqa: attach *_normalized views once instead of calling
    # batch_metrics three times and always hitting the 0 default.
    long_metrics = _attach_metric_views(batch_metrics(all_preds, all_refs_full))
    metrics['long_answers_eval'] = {
        "accuracy": long_metrics.get("accuracy_normalized", 0.0),
        "f1": long_metrics.get("f1_normalized", 0.0),
        "bleu4": long_metrics.get("bleu4_normalized", 0.0),
        "semantic": compute_semantic_score(all_preds_raw, all_refs_full),
        "bert_score": compute_bertscore(all_preds_raw, all_refs_full),
    }

    return metrics
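

# Minimal usage sketch for Track B (hedged: assumes a HF LLaVA-style model exposing
# generate(), its matching processor, and a dataloader that yields the raw_* fields
# consumed above):
#   evaluator = MedicalVQAEvaluator(device=torch.device("cuda"), processor=processor)
#   metrics = evaluator.evaluate(model, test_loader, variant_type='B1', beam_width=1)
#   print(metrics['closed_eval'], metrics['open_eval'], metrics['long_answers_eval'])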