| """ | |
| Discriminative learning rates for different model layers. | |
| Earlier layers (pretrained) get lower LR to preserve learned features. | |
| Later layers get higher LR for task-specific adaptation. | |
| """ | |
| import torch | |
| from torch.optim import AdamW | |
| from transformers import get_cosine_schedule_with_warmup | |

def create_discriminative_optimizer(model, config):
    """
    Create an optimizer with discriminative learning rates.

    Layer groups and their learning rates:
        - Image encoder (pretrained XRV): vision_lr (default 1e-5) to preserve medical features
        - Text encoder (PhoBERT): phobert_lr (default 1e-5) to preserve language understanding
        - Fusion layers (co-attention): 0.5 x base LR for moderate adaptation
        - Decoder / answer heads (task-specific): full base LR for heavy adaptation

    Args:
        model: Model exposing the layer attributes listed above
        config: Config dict with learning rates under config['train']

    Returns:
        AdamW optimizer with layer-specific learning rates
    """
    # Define parameter groups with different learning rates
    param_groups = []
    base_lr = float(config['train'].get('learning_rate', 3e-4))
    vision_lr = float(config['train'].get('vision_lr', 1e-5))
    phobert_lr = float(config['train'].get('phobert_lr', 1e-5))

    # Group 1: Image encoder (lowest LR).
    # Materialize .parameters() into a list so the membership check and the
    # parameter-count logging below do not exhaust a one-shot generator before
    # AdamW ever sees these parameters.
    if hasattr(model, 'image_encoder'):
        param_groups.append({
            'params': list(model.image_encoder.parameters()),
            'lr': vision_lr,
            'name': 'image_encoder'
        })

    # Group 2: Text encoder (low LR)
    if hasattr(model, 'text_encoder'):
        param_groups.append({
            'params': list(model.text_encoder.parameters()),
            'lr': phobert_lr,
            'name': 'text_encoder'
        })
    # Group 3: Fusion / attention layers (medium LR)
    fusion_params = []
    if hasattr(model, 'fusion'):
        fusion_params.extend(model.fusion.parameters())
    if hasattr(model, 'co_attention'):
        fusion_params.extend(model.co_attention.parameters())
    if hasattr(model, 'spatial_attention'):
        fusion_params.extend(model.spatial_attention.parameters())
    if fusion_params:
        param_groups.append({
            'params': fusion_params,
            'lr': base_lr * 0.5,  # 50% of base LR
            'name': 'fusion'
        })
    # Group 4: Decoder and answer heads (highest LR)
    decoder_params = []
    if hasattr(model, 'decoder'):
        decoder_params.extend(model.decoder.parameters())
    if hasattr(model, 'open_head'):
        decoder_params.extend(model.open_head.parameters())
    if hasattr(model, 'closed_head'):
        decoder_params.extend(model.closed_head.parameters())
    if decoder_params:
        param_groups.append({
            'params': decoder_params,
            'lr': base_lr,  # full base LR
            'name': 'decoder'
        })
    # Group 5: any remaining parameters not covered by the groups above
    all_params = set(model.parameters())
    grouped_params = set()
    for group in param_groups:
        grouped_params.update(group['params'])
    remaining_params = [p for p in all_params if p not in grouped_params]
    if remaining_params:
        param_groups.append({
            'params': remaining_params,
            'lr': base_lr * 0.1,  # 10% of base LR for safety
            'name': 'remaining'
        })

    # Create optimizer
    optimizer = AdamW(
        param_groups,
        betas=(0.9, 0.999),
        weight_decay=config['train'].get('weight_decay', 0.01)
    )

    # Log per-group learning rates
    print("[INFO] Discriminative Learning Rates Setup:")
    for group in param_groups:
        param_count = sum(p.numel() for p in group['params'])
        print(f"  {group['name']:15s}: LR={group['lr']:.2e}, Params={param_count:,}")

    return optimizer
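
# Example usage (illustrative sketch only; the real config is loaded from the
# training configuration and `model` is the VQA model exposing `image_encoder`,
# `text_encoder`, `fusion`, and `decoder`):
#
#   config = {'train': {'learning_rate': 3e-4, 'vision_lr': 1e-5,
#                       'phobert_lr': 1e-5, 'weight_decay': 0.01}}
#   optimizer = create_discriminative_optimizer(model, config)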

def create_scheduler_with_warmup(optimizer, num_training_steps, config):
    """
    Create a cosine LR scheduler with linear warmup.

    Args:
        optimizer: Optimizer instance
        num_training_steps: Total number of training steps
        config: Config dict

    Returns:
        LambdaLR scheduler with warmup
    """
    warmup_steps = int(num_training_steps * config['train'].get('warmup_steps_ratio', 0.1))

    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=num_training_steps,
        num_cycles=0.5,  # 0.5 cycles: cosine decays from the peak LR down to 0
        last_epoch=-1
    )

    print("[INFO] Scheduler: cosine with warmup")
    print(f"  Warmup steps: {warmup_steps} ({warmup_steps / num_training_steps * 100:.1f}%)")
    print(f"  Total steps:  {num_training_steps}")

    return scheduler
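
# Usage note (sketch): get_cosine_schedule_with_warmup returns a step-based
# LambdaLR, so step it once per optimizer update rather than once per epoch:
#
#   loss.backward()
#   optimizer.step()
#   scheduler.step()
#   optimizer.zero_grad()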

def get_current_learning_rates(optimizer):
    """Return the current learning rate for each parameter group."""
    lrs = {}
    for i, param_group in enumerate(optimizer.param_groups):
        name = param_group.get('name', f'group_{i}')
        lrs[name] = param_group['lr']
    return lrs
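

# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative sketch only). DummyModel and the config
# values below are assumptions for demonstration; the real model and config
# come from the training pipeline.
if __name__ == "__main__":
    import torch.nn as nn

    class DummyModel(nn.Module):
        """Tiny stand-in exposing the attribute names the helpers look for."""
        def __init__(self):
            super().__init__()
            self.image_encoder = nn.Linear(8, 8)
            self.text_encoder = nn.Linear(8, 8)
            self.fusion = nn.Linear(8, 8)
            self.decoder = nn.Linear(8, 4)

    config = {'train': {'learning_rate': 3e-4, 'vision_lr': 1e-5,
                        'phobert_lr': 1e-5, 'weight_decay': 0.01,
                        'warmup_steps_ratio': 0.1}}

    model = DummyModel()
    optimizer = create_discriminative_optimizer(model, config)
    scheduler = create_scheduler_with_warmup(optimizer, num_training_steps=100, config=config)

    # One fake update to confirm the warmup scheduler advances the LRs.
    loss = model.decoder(model.fusion(torch.randn(2, 8))).sum()
    loss.backward()
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()
    print("[INFO] LRs after one step:", get_current_learning_rates(optimizer))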