"""
Discriminative learning rates for different model layers.
Earlier layers (pretrained) get lower LR to preserve learned features.
Later layers get higher LR for task-specific adaptation.
"""

import torch
from torch.optim import AdamW
from transformers import get_cosine_schedule_with_warmup


def create_discriminative_optimizer(model, config):
    """
    Create optimizer with discriminative learning rates.
    
    Layer groups and their learning rates (defaults shown):
    - Image Encoder (pretrained XRV): vision_lr, default 1e-5 (preserve medical features)
    - Text Encoder (PhoBERT): phobert_lr, default 1e-5 (preserve language understanding)
    - Fusion layers (co-attention): 50% of the base LR, default 1.5e-4 (moderate adaptation)
    - Decoder and answer heads (task-specific): full base LR, default 3e-4 (heavy adaptation)
    
    Args:
        model: Model whose named submodules (image_encoder, text_encoder, fusion, decoder, ...) define the parameter groups
        config: Config dict with learning rates
    
    Returns:
        Optimizer with layer-specific learning rates
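    
    Example (illustrative only; assumes the model exposes image_encoder,
    text_encoder, fusion and decoder submodules, and that the caller
    loaded a config with the keys read below):
        config = {'train': {'learning_rate': 3e-4, 'vision_lr': 1e-5,
                            'phobert_lr': 1e-5, 'weight_decay': 0.01}}
        optimizer = create_discriminative_optimizer(model, config)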
    """
    
    # Define parameter groups with different learning rates
    param_groups = []
    
    base_lr = float(config['train'].get('learning_rate', 3e-4))
    vision_lr = float(config['train'].get('vision_lr', 1e-5))
    phobert_lr = float(config['train'].get('phobert_lr', 1e-5))
    
    # Group 1: Image Encoder (lowest LR)
    # Note: .parameters() is materialized into a list so that the duplicate
    # check for Group 5 below does not exhaust a generator before AdamW sees it.
    if hasattr(model, 'image_encoder'):
        param_groups.append({
            'params': list(model.image_encoder.parameters()),
            'lr': vision_lr,
            'name': 'image_encoder'
        })
    
    # Group 2: Text Encoder (low LR)
    if hasattr(model, 'text_encoder'):
        param_groups.append({
            'params': list(model.text_encoder.parameters()),
            'lr': phobert_lr,
            'name': 'text_encoder'
        })
    
    # Group 3: Fusion/Attention layers (medium LR)
    fusion_params = []
    if hasattr(model, 'fusion'):
        fusion_params.extend(model.fusion.parameters())
    if hasattr(model, 'co_attention'):
        fusion_params.extend(model.co_attention.parameters())
    if hasattr(model, 'spatial_attention'):
        fusion_params.extend(model.spatial_attention.parameters())
    
    if fusion_params:
        param_groups.append({
            'params': fusion_params,
            'lr': base_lr * 0.5,  # 50% of base LR
            'name': 'fusion'
        })
    
    # Group 4: Decoder (highest LR)
    decoder_params = []
    if hasattr(model, 'decoder'):
        decoder_params.extend(model.decoder.parameters())
    if hasattr(model, 'open_head'):
        decoder_params.extend(model.open_head.parameters())
    if hasattr(model, 'closed_head'):
        decoder_params.extend(model.closed_head.parameters())
    
    if decoder_params:
        param_groups.append({
            'params': decoder_params,
            'lr': base_lr,  # Full base LR
            'name': 'decoder'
        })
    
    # Group 5: Any remaining parameters
    # Collect all params that aren't in the groups above; iterating the model
    # directly keeps the ordering of this group deterministic.
    grouped_params = set()
    for group in param_groups:
        grouped_params.update(group['params'])
    
    remaining_params = [p for p in model.parameters() if p not in grouped_params]
    if remaining_params:
        param_groups.append({
            'params': remaining_params,
            'lr': base_lr * 0.1,  # 10% of base LR for safety
            'name': 'remaining'
        })
    
    # Create optimizer
    optimizer = AdamW(
        param_groups,
        betas=(0.9, 0.999),
        weight_decay=float(config['train'].get('weight_decay', 0.01))
    )
    
    # Log layer learning rates
    print("[INFO] Discriminative Learning Rates Setup:")
    for group in param_groups:
        param_count = sum(p.numel() for p in group['params'])
        print(f"  {group['name']:15s}: LR={group['lr']:.2e}, Params={param_count:,}")
    
    return optimizer


def create_scheduler_with_warmup(optimizer, num_training_steps, config):
    """
    Create cosine scheduler with warmup.
    
    Args:
        optimizer: Optimizer instance
        num_training_steps: Total training steps
        config: Config dict
    
    Returns:
        LambdaLR scheduler with warmup
    """
    
    warmup_steps = int(num_training_steps * config['train'].get('warmup_steps_ratio', 0.1))
    
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=num_training_steps,
        num_cycles=0.5,  # 0.5 = cosine goes from 1 to 0
        last_epoch=-1
    )
    
    print(f"[INFO] Scheduler: Cosine with warmup")
    print(f"  Warmup steps: {warmup_steps} ({warmup_steps/num_training_steps*100:.1f}%)")
    print(f"  Total steps: {num_training_steps}")
    
    return scheduler


def get_current_learning_rates(optimizer):
    """Get current learning rate for each parameter group."""
    lrs = {}
    for i, param_group in enumerate(optimizer.param_groups):
        name = param_group.get('name', f'group_{i}')
        lrs[name] = param_group['lr']
    return lrs
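

# ---------------------------------------------------------------------------
# Illustrative smoke test (not part of the training pipeline). The model and
# config below are placeholders: the nn.Linear submodules merely stand in for
# the real XRV image encoder, PhoBERT text encoder, co-attention fusion and
# decoder, and the config keys mirror the ones read above. Run this file
# directly to check that the parameter groups, scheduler and LR logging wire
# together as expected.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn

    class _DummyVQAModel(nn.Module):
        """Stand-in exposing the attribute names the optimizer factory looks for."""

        def __init__(self):
            super().__init__()
            self.image_encoder = nn.Linear(8, 8)
            self.text_encoder = nn.Linear(8, 8)
            self.fusion = nn.Linear(16, 8)
            self.decoder = nn.Linear(8, 4)

    config = {
        'train': {
            'learning_rate': 3e-4,
            'vision_lr': 1e-5,
            'phobert_lr': 1e-5,
            'weight_decay': 0.01,
            'warmup_steps_ratio': 0.1,
        }
    }

    model = _DummyVQAModel()
    optimizer = create_discriminative_optimizer(model, config)

    # e.g. 10 epochs x 100 batches per epoch = 1,000 optimizer steps
    num_training_steps = 10 * 100
    scheduler = create_scheduler_with_warmup(optimizer, num_training_steps, config)

    # Take a few empty steps just to show the per-group LRs moving through warmup.
    for step in range(3):
        optimizer.step()      # no gradients here, so parameters stay untouched
        scheduler.step()
        print(f"step {step + 1}: {get_current_learning_rates(optimizer)}")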