# src/utils/discriminative_lr.py
"""
Discriminative learning rates for different model layers.
Earlier layers (pretrained) get lower LR to preserve learned features.
Later layers get higher LR for task-specific adaptation.
"""
import torch
from torch.optim import AdamW
from transformers import get_cosine_schedule_with_warmup


def create_discriminative_optimizer(model, config):
"""
Create optimizer with discriminative learning rates.
Layer groups and their learning rates:
- Image Encoder (pretrained XRV): 1e-5 (preserve medical features)
- Text Encoder (PhoBERT): 1e-5 (preserve language understanding)
- Fusion layer (co-attention): 1e-4 (moderate adaptation)
- Decoder (task-specific): 1e-3 (heavy adaptation)
Args:
model: Model with parameter groups
config: Config dict with learning rates
Returns:
Optimizer with layer-specific learning rates
"""
# Define parameter groups with different learning rates
param_groups = []
base_lr = float(config['train'].get('learning_rate', 3e-4))
vision_lr = float(config['train'].get('vision_lr', 1e-5))
phobert_lr = float(config['train'].get('phobert_lr', 1e-5))
# Group 1: Image Encoder (lowest LR)
if hasattr(model, 'image_encoder'):
param_groups.append({
            'params': list(model.image_encoder.parameters()),  # list() so the generator is not exhausted before AdamW sees it
'lr': vision_lr,
'name': 'image_encoder'
})
# Group 2: Text Encoder (low LR)
if hasattr(model, 'text_encoder'):
param_groups.append({
            'params': list(model.text_encoder.parameters()),  # list() so the generator is not exhausted before AdamW sees it
'lr': phobert_lr,
'name': 'text_encoder'
})
# Group 3: Fusion/Attention layers (medium LR)
fusion_params = []
if hasattr(model, 'fusion'):
fusion_params.extend(model.fusion.parameters())
if hasattr(model, 'co_attention'):
fusion_params.extend(model.co_attention.parameters())
if hasattr(model, 'spatial_attention'):
fusion_params.extend(model.spatial_attention.parameters())
if fusion_params:
param_groups.append({
'params': fusion_params,
'lr': base_lr * 0.5, # 50% of base LR
'name': 'fusion'
})
# Group 4: Decoder (highest LR)
decoder_params = []
if hasattr(model, 'decoder'):
decoder_params.extend(model.decoder.parameters())
if hasattr(model, 'open_head'):
decoder_params.extend(model.open_head.parameters())
if hasattr(model, 'closed_head'):
decoder_params.extend(model.closed_head.parameters())
if decoder_params:
param_groups.append({
'params': decoder_params,
'lr': base_lr, # Full base LR
'name': 'decoder'
})
# Group 5: Any remaining parameters
# Collect all params that aren't in above groups
all_params = set(model.parameters())
grouped_params = set()
for group in param_groups:
grouped_params.update(group['params'])
remaining_params = [p for p in all_params if p not in grouped_params]
if remaining_params:
param_groups.append({
'params': remaining_params,
'lr': base_lr * 0.1, # 10% of base LR for safety
'name': 'remaining'
})
# Create optimizer
optimizer = AdamW(
param_groups,
betas=(0.9, 0.999),
        weight_decay=float(config['train'].get('weight_decay', 0.01))  # cast for consistency with the other config reads
)
# Log layer learning rates
print("[INFO] Discriminative Learning Rates Setup:")
for group in param_groups:
param_count = sum(p.numel() for p in group['params'])
print(f" {group['name']:15s}: LR={group['lr']:.2e}, Params={param_count:,}")
return optimizer


def create_scheduler_with_warmup(optimizer, num_training_steps, config):
"""
Create cosine scheduler with warmup.
Args:
optimizer: Optimizer instance
num_training_steps: Total training steps
config: Config dict
Returns:
LambdaLR scheduler with warmup
"""
    warmup_steps = int(num_training_steps * float(config['train'].get('warmup_steps_ratio', 0.1)))
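    # Example: with 10,000 total training steps and the default ratio of 0.1,
    # the LR warms up linearly over the first 1,000 steps and then follows a
    # cosine decay over the remaining 9,000.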
scheduler = get_cosine_schedule_with_warmup(
optimizer,
num_warmup_steps=warmup_steps,
num_training_steps=num_training_steps,
num_cycles=0.5, # 0.5 = cosine goes from 1 to 0
last_epoch=-1
)
print(f"[INFO] Scheduler: Cosine with warmup")
print(f" Warmup steps: {warmup_steps} ({warmup_steps/num_training_steps*100:.1f}%)")
print(f" Total steps: {num_training_steps}")
return scheduler


def get_current_learning_rates(optimizer):
"""Get current learning rate for each parameter group."""
lrs = {}
for i, param_group in enumerate(optimizer.param_groups):
name = param_group.get('name', f'group_{i}')
lrs[name] = param_group['lr']
return lrs
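

# Minimal smoke-test sketch (an assumption, not part of the original training
# pipeline): the toy sub-modules below merely stand in for the real XRV image
# encoder, PhoBERT text encoder, fusion block and decoder so the parameter
# grouping and scheduler wiring can be exercised without pretrained weights.
if __name__ == "__main__":
    import torch.nn as nn

    class _ToyVQAModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.image_encoder = nn.Linear(16, 8)
            self.text_encoder = nn.Linear(16, 8)
            self.fusion = nn.Linear(16, 8)
            self.decoder = nn.Linear(8, 4)

    config = {
        "train": {
            "learning_rate": 3e-4,
            "vision_lr": 1e-5,
            "phobert_lr": 1e-5,
            "weight_decay": 0.01,
            "warmup_steps_ratio": 0.1,
        }
    }
    model = _ToyVQAModel()
    optimizer = create_discriminative_optimizer(model, config)
    scheduler = create_scheduler_with_warmup(optimizer, num_training_steps=1_000, config=config)

    # In a real training loop, step the scheduler once per optimizer step.
    optimizer.step()
    scheduler.step()
    print(get_current_learning_rates(optimizer))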