File size: 4,090 Bytes
5551585
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import json
import random
import os

def load_predictions(file_path):
    """Load JSON predictions."""
    if not os.path.exists(file_path):
        print(f"[ERROR] Không tìm thấy file: {file_path}")
        return []
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)

def manual_review(samples, preds_b2, preds_dpo, num_samples=20):
    """
    So sánh SFT (B2) vs DPO. Lưu lại sở thích dựa trên tính chính xác y khoa.
    """
    results = {"B2_wins": 0, "DPO_wins": 0, "Tie": 0}
    
    # Lấy các index ngẫu nhiên
    indices = list(range(len(samples)))
    random.shuffle(indices)
    review_indices = indices[:min(num_samples, len(samples))]
    
    print("\n" + "="*50)
    print(f"BẮT ĐẦU PHIÊN ĐÁNH GIÁ THỦ CÔNG ({len(review_indices)} câu hỏi)")
    print("Mục tiêu: Đánh giá xem DPO có sinh ra câu trả lời tốt hơn B2 không.")
    print("="*50)
    
    for i, idx in enumerate(review_indices):
        sample = samples[idx]
        b2_ans = preds_b2[idx].get("predicted", "") if idx < len(preds_b2) else "N/A"
        dpo_ans = preds_dpo[idx].get("predicted", "") if idx < len(preds_dpo) else "N/A"
        
        # Ground Truth
        q_en = sample.get("question", sample.get("raw_questions", ""))
        gt_en = sample.get("answer", sample.get("raw_answers", ""))
        gt_vi = sample.get("answer_vi", "")
        
        print(f"\n[Câu {i+1}/{len(review_indices)}]")
        print(f"Câu hỏi (En): {q_en}")
        print(f"Đáp án chuẩn (Vi): {gt_vi}")
        print("-" * 30)
        
        # Randomize order to prevent bias (Blind Test)
        is_b2_first = random.choice([True, False])
        
        if is_b2_first:
            print(f"Mô hình 1: {b2_ans}")
            print(f"Mô hình 2: {dpo_ans}")
        else:
            print(f"Mô hình 1: {dpo_ans}")
            print(f"Mô hình 2: {b2_ans}")
            
        print("-" * 30)
        choice = ""
        while choice not in ['1', '2', '3']:
            choice = input("Mô hình nào tốt hơn? (1: Mô hình 1 | 2: Mô hình 2 | 3: Hòa): ").strip()
            
        if choice == '3':
            results["Tie"] += 1
        elif (choice == '1' and is_b2_first) or (choice == '2' and not is_b2_first):
            results["B2_wins"] += 1
        else:
            results["DPO_wins"] += 1
            
    print("\n" + "="*50)
    print("KẾT QUẢ ĐÁNH GIÁ THỦ CÔNG (BLIND TEST)")
    print("="*50)
    print(f"B2 thắng:  {results['B2_wins']}")
    print(f"DPO thắng: {results['DPO_wins']}")
    print(f"Hòa:       {results['Tie']}")
    print("="*50)
    
    if results['DPO_wins'] > results['B2_wins']:
        print("=> Kết luận: DPO ĐÃ CẢI THIỆN ĐƯỢC CHẤT LƯỢNG SINH VĂN BẢN (RLHF hoạt động tốt!)")
    elif results['DPO_wins'] < results['B2_wins']:
        print("=> Kết luận: DPO sinh ra kết quả kém hơn B2 (Cần chỉnh lại tham số Beta hoặc dữ liệu Preference).")
    else:
        print("=> Kết luận: B2 và DPO không có sự chênh lệch rõ rệt.")
        
    return results

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", type=str, default="data/raw/vqa_rad.json", help="Path to ground truth dataset")
    parser.add_argument("--b2", type=str, default="results/predictions/B2_predictions.json")
    parser.add_argument("--dpo", type=str, default="results/predictions/DPO_predictions.json")
    parser.add_argument("--n", type=int, default=20, help="Số lượng câu cần đánh giá")
    args = parser.parse_args()
    
    # Load data
    samples = load_predictions(args.data)
    preds_b2 = load_predictions(args.b2)
    preds_dpo = load_predictions(args.dpo)
    
    if samples and preds_b2 and preds_dpo:
        manual_review(samples, preds_b2, preds_dpo, num_samples=args.n)
    else:
        print("Vui lòng chạy đánh giá và lưu kết quả predict của B2 và DPO ra file JSON trước khi dùng script này.")