File size: 3,215 Bytes
d63774a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch
from transformers import LlavaProcessor, LlavaForConditionalGeneration, BitsAndBytesConfig
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training

class MultimodalVQA:
    """
    Wrapper cho LLaVA-Med-7B tích hợp QLoRA 4-bit để huấn luyện trên Kaggle.
    Sử dụng kiến trúc LLaVA-1.5 (microsoft/llava-med-v1.5-7b).
    """
    def __init__(
        self,
        model_id="chaoyinshe/llava-med-v1.5-mistral-7b-hf",
        lora_r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        lora_target_modules=None,
    ):
        self.model_id = model_id
        
        # 1. Cấu hình Quantization 4-bit (Tiết kiệm VRAM)
        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16
        )
        
        # 2. Cấu hình LoRA (Chỉ huấn luyện một phần nhỏ tham số)
        self.peft_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            target_modules=lora_target_modules or ["q_proj", "v_proj", "k_proj", "o_proj"],
            lora_dropout=lora_dropout,
            bias="none",
            task_type="CAUSAL_LM"
        )

    def load_model(self, adapter_path=None, is_trainable=True):
        print(f"[INFO] Đang tải LLaVA-Med-v1.5-7B với chế độ 4-bit...")
        processor = LlavaProcessor.from_pretrained(self.model_id)
        processor.tokenizer.padding_side = "left" # Bắt buộc cho decoder-only models
        model = LlavaForConditionalGeneration.from_pretrained(
            self.model_id,
            quantization_config=self.bnb_config,
            device_map="auto"
        )

        model.config.use_cache = False

        # Chuẩn bị mô hình cho PEFT
        model = prepare_model_for_kbit_training(model)
        if adapter_path:
            print(f"[INFO] Đang nạp adapter LoRA từ: {adapter_path}")
            model = PeftModel.from_pretrained(model, adapter_path, is_trainable=is_trainable)
        else:
            model = get_peft_model(model, self.peft_config)
        model.gradient_checkpointing_enable()
        model.enable_input_require_grads()

        model.print_trainable_parameters()
        return model, processor

    def generate_prompt_vi(self, question_en):
        """
        Hàm hỗ trợ tạo prompt cho LLaVA-Med (EN). 
        Nhớ dùng Translation Layer trước khi gọi hàm này.
        """
        return self.build_instruction_prompt(question_en, language="en", include_answer=False)

    def build_instruction_prompt(self, question, language="vi", include_answer=False):
        """
        Prompt thống nhất cho zero-shot, SFT và demo.
        """
        if language == "vi":
            instruction = "Chi tra loi bang tieng Viet, khong dung tieng Anh, thuat ngu y khoa chuan, ngan gon, toi da 10 tu."
        else:
            instruction = "Answer with standard medical terminology, concise, at most 10 words."
        suffix = " ASSISTANT:" if not include_answer else ""
        return f"USER: <image>\n{question}\n{instruction}{suffix}"