| import torch |
| import torch.nn as nn |
| import argparse |
| from safetensors.torch import load_file, save_file |
| from model import LocalSongModel |
| from pathlib import Path |
|
|
class LoRALinear(nn.Module):
    """A frozen nn.Linear augmented with a trainable low-rank (LoRA) residual.

    Computes ``original_linear(x) + (x @ lora_A @ lora_B) * (alpha / rank)``.
    ``lora_B`` starts at zero, so right after construction the wrapper is an
    exact pass-through of the wrapped layer.
    """

    def __init__(self, original_linear: nn.Linear, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.original_linear = original_linear
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank

        # Low-rank factors shaped (in, r) and (r, out) so that
        # x @ A @ B has the same output shape as the wrapped linear.
        in_features = original_linear.in_features
        out_features = original_linear.out_features
        self.lora_A = nn.Parameter(torch.zeros(in_features, rank))
        self.lora_B = nn.Parameter(torch.zeros(rank, out_features))

        nn.init.kaiming_uniform_(self.lora_A, a=5 ** 0.5)
        nn.init.zeros_(self.lora_B)

        # Freeze the wrapped layer; only the LoRA factors remain trainable.
        for frozen in self.original_linear.parameters():
            frozen.requires_grad_(False)

    def forward(self, x):
        base = self.original_linear(x)
        delta = (x @ self.lora_A @ self.lora_B) * self.scaling
        return base + delta
|
|
def inject_lora(model, rank=8, alpha=16.0, target_modules=None, device=None):
    """Replace matching nn.Linear submodules of *model* with LoRALinear wrappers.

    Args:
        model: module to modify in place.
        rank: LoRA rank for every injected layer.
        alpha: LoRA alpha (scaling = alpha / rank).
        target_modules: iterable of substrings; a Linear layer is wrapped when
            any of them occurs in its dotted module name. Defaults to the
            usual attention/MLP projection names.
        device: device for the new LoRA parameters; inferred from the model's
            first parameter when None.

    Returns:
        The same model instance, with LoRA layers injected.
    """
    if target_modules is None:
        # Immutable tuple default (the previous version used a mutable
        # list default argument, which is shared across calls).
        target_modules = ('qkv', 'proj', 'w1', 'w2', 'w3', 'q_proj', 'kv_proj')
    if device is None:
        device = next(model.parameters()).device

    # Snapshot the matches first: we mutate the module tree below, and
    # replacing children while lazily iterating named_modules() is fragile.
    matches = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear)
        and any(target in name for target in target_modules)
    ]

    for name, module in matches:
        # Resolve the parent module that owns attribute `attr_name`.
        *parent_path, attr_name = name.split('.')
        parent = model
        for p in parent_path:
            parent = getattr(parent, p)

        lora_layer = LoRALinear(module, rank=rank, alpha=alpha)
        lora_layer.lora_A.data = lora_layer.lora_A.data.to(device)
        lora_layer.lora_B.data = lora_layer.lora_B.data.to(device)
        setattr(parent, attr_name, lora_layer)

    return model
|
|
def load_lora_weights(model, lora_path, device):
    """Load LoRA factors from a safetensors file into injected LoRALinear layers.

    Keys are expected as "<module_name>.lora_A" / "<module_name>.lora_B".
    A layer whose keys are absent keeps its initialized values (lora_B is
    zero, so it contributes nothing to a later merge) — that is reported
    instead of being silently skipped.

    Args:
        model: model that already had LoRALinear layers injected.
        lora_path: path to the LoRA safetensors checkpoint.
        device: device to place the loaded tensors on.
    """
    print(f"Loading LoRA from {lora_path}")
    lora_state_dict = load_file(lora_path, device=str(device))

    loaded_count = 0
    for name, module in model.named_modules():
        if isinstance(module, LoRALinear):
            lora_a_key = f"{name}.lora_A"
            lora_b_key = f"{name}.lora_B"
            if lora_a_key in lora_state_dict and lora_b_key in lora_state_dict:
                module.lora_A.data = lora_state_dict[lora_a_key].to(device)
                module.lora_B.data = lora_state_dict[lora_b_key].to(device)
                loaded_count += 2
            else:
                # Surface the mismatch: an unloaded LoRA layer silently
                # merges as a no-op otherwise.
                print(f"Warning: no LoRA weights found for '{name}'")

    print(f"Loaded {loaded_count} LoRA parameters")
|
|
def merge_lora_into_model(model):
    """
    Merge LoRA weights into the base model weights.
    For each LoRALinear layer: W_merged = W_original + (lora_A @ lora_B) * scaling
    """
    print("\nMerging LoRA weights into base model...")
    merged_count = 0

    for name, module in model.named_modules():
        if isinstance(module, LoRALinear):
            # Compute the delta under no_grad: lora_A/lora_B require grad,
            # and the previous version built a useless autograd graph here.
            with torch.no_grad():
                lora_delta = (module.lora_A @ module.lora_B) * module.scaling
                # lora_A @ lora_B is (in, out) but nn.Linear stores its
                # weight as (out, in), hence the transpose.
                module.original_linear.weight.data += lora_delta.T

            merged_count += 1

    print(f"Merged {merged_count} LoRA layers into base weights")
|
|
def extract_base_weights(model):
    """
    Extract the merged weights from LoRALinear modules back into a regular state dict.

    Each LoRALinear's merged weight/bias is emitted under the original
    "<name>.weight" / "<name>.bias" keys (as if LoRA was never injected);
    all other parameters and buffers are passed through unchanged.

    Returns:
        dict mapping parameter/buffer names to tensors, suitable for
        save_file() and a strict=True load into a fresh base model.
    """
    print("\nExtracting merged weights...")
    new_state_dict = {}

    for name, module in model.named_modules():
        if isinstance(module, LoRALinear):
            # Re-key the wrapped layer's tensors to their pre-injection names.
            new_state_dict[f"{name}.weight"] = module.original_linear.weight.data
            if module.original_linear.bias is not None:
                new_state_dict[f"{name}.bias"] = module.original_linear.bias.data

    def _is_base_tensor(name):
        # Skip LoRA factors and the wrapped-layer tensors already re-keyed above.
        return ('lora_A' not in name
                and 'lora_B' not in name
                and 'original_linear' not in name)

    for name, param in model.named_parameters():
        if _is_base_tensor(name):
            new_state_dict[name] = param.data

    # Buffers are not covered by named_parameters(); dropping them would make
    # the strict=True verification reload fail for models that register any.
    # This is a no-op for buffer-free models.
    for name, buf in model.named_buffers():
        if _is_base_tensor(name):
            new_state_dict[name] = buf.data

    print(f"Extracted {len(new_state_dict)} parameters")
    return new_state_dict
|
|
def _build_model(device):
    """Construct the LocalSongModel architecture this script targets, on *device*."""
    return LocalSongModel(
        in_channels=8,
        num_groups=16,
        hidden_size=1024,
        decoder_hidden_size=2048,
        num_blocks=36,
        patch_size=(16, 1),
        num_classes=2304,
        max_tags=8,
    ).to(device)


def main():
    """Merge a LoRA checkpoint into a base checkpoint and save a standalone model.

    Steps: load the base model, inject LoRA layers, load the LoRA weights,
    merge them into the base weights, save the merged state dict, then verify
    it round-trips into a fresh model with strict=True.
    """
    parser = argparse.ArgumentParser(description="Merge LoRA weights into a base model checkpoint")
    parser.add_argument(
        "--base-checkpoint",
        type=str,
        default="checkpoints/checkpoint_461260.safetensors",
        help="Path to the base model checkpoint"
    )
    parser.add_argument(
        "--lora-checkpoint",
        type=str,
        default="lora.safetensors",
        help="Path to the LoRA checkpoint"
    )
    parser.add_argument(
        "--output-checkpoint",
        type=str,
        default="checkpoints/checkpoint_461260_merged_lora.safetensors",
        help="Path to save the merged checkpoint"
    )
    # Previously hard-coded; these must match the values the LoRA was
    # trained with, so expose them on the CLI (defaults unchanged).
    parser.add_argument(
        "--lora-rank",
        type=int,
        default=16,
        help="Rank the LoRA checkpoint was trained with"
    )
    parser.add_argument(
        "--lora-alpha",
        type=float,
        default=16.0,
        help="Alpha the LoRA checkpoint was trained with"
    )
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    base_checkpoint = args.base_checkpoint
    lora_checkpoint = args.lora_checkpoint
    output_checkpoint = args.output_checkpoint

    lora_rank = args.lora_rank
    lora_alpha = args.lora_alpha

    print(f"\nBase checkpoint: {base_checkpoint}")
    print(f"LoRA checkpoint: {lora_checkpoint}")
    print(f"Output checkpoint: {output_checkpoint}")
    print(f"LoRA rank: {lora_rank}, alpha: {lora_alpha}")

    print("\nLoading base model...")
    model = _build_model(device)

    state_dict = load_file(base_checkpoint, device=str(device))
    model.load_state_dict(state_dict, strict=True)
    print("Base model loaded")

    print("\nInjecting LoRA layers...")
    model = inject_lora(model, rank=lora_rank, alpha=lora_alpha, device=device)

    load_lora_weights(model, lora_checkpoint, device)

    merge_lora_into_model(model)

    merged_state_dict = extract_base_weights(model)

    print(f"\nSaving merged checkpoint to {output_checkpoint}...")
    save_file(merged_state_dict, output_checkpoint)
    print("✓ Merged checkpoint saved successfully!")

    # Sanity check: the merged checkpoint must load into an unmodified
    # (LoRA-free) model with strict=True.
    print("\nVerifying merged checkpoint...")
    test_model = _build_model(device)

    merged_loaded = load_file(output_checkpoint, device=str(device))
    test_model.load_state_dict(merged_loaded, strict=True)
    print("✓ Merged checkpoint verified successfully!")

    print(f"\nDone! You can now use '{output_checkpoint}' as a standalone checkpoint without needing LoRA.")
|
|
# Script entry point: run the LoRA merge pipeline when executed directly.
if __name__ == '__main__':
    main()
|
|