Simo76 committed on
Commit
1fda0d1
·
1 Parent(s): ec0d4eb

Implement Nested LoRA architecture for dynamic rank control

Browse files

This module implements a Nested LoRA architecture for dynamic rank control in linear layers, allowing for efficient training with frozen original weights and adaptive rank changes.

Files changed (1) hide show
  1. nested_lora.py +130 -0
nested_lora.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Nested LoRA — One Particle, Multiple Orbitals
3
+ ===============================================
4
+
5
+ Single LoRA adapter pair with dynamic rank via slicing.
6
+ r4 ⊂ r8 ⊂ r16 — descending pauses dimensions, ascending resumes them.
7
+ Zero cold start on transitions.
8
+
9
+ This module is the "engine" — pure architecture, no control logic.
10
+ Pair with OrbitalController for adaptive rank decisions.
11
+
12
+ Author: Simona Vargiu
13
+ License: Apache 2.0
14
+ """
15
+
16
+ import math
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ from typing import List
21
+
22
+
23
class NestedLoRALinear(nn.Module):
    """
    Single LoRA adapter with dynamic rank via slicing.

    A single pair of matrices A(max_rank, in) and B(out, max_rank) is shared
    across all rank levels. The active rank is controlled by slicing:

        r=4  → A[:4, :],  B[:, :4]
        r=8  → A[:8, :],  B[:, :8]
        r=16 → A[:16, :], B[:, :16]

    When descending from r=16 to r=4, dimensions 0-3 retain all learned
    weights. Dimensions 4-15 are paused (no gradient), not destroyed.
    When ascending back, they resume exactly where they left off.

    Output is scaled by max_rank/active_rank to maintain consistent
    magnitude across rank changes (analogous to alpha/r in standard LoRA).

    Args:
        linear: Original nn.Linear layer to wrap (its weights are frozen).
        max_rank: Maximum LoRA rank (default: 16).

    Example:
        >>> layer = NestedLoRALinear(original_linear, max_rank=16)
        >>> layer.set_rank(4)   # use 4 dimensions
        >>> out = layer(x)      # forward with r=4
        >>> layer.set_rank(16)  # expand to full rank
        >>> out = layer(x)      # forward with r=16, dimensions 0-3 unchanged
    """

    def __init__(self, linear: nn.Linear, max_rank: int = 16):
        super().__init__()
        if max_rank < 1:
            raise ValueError(f"max_rank must be >= 1, got {max_rank}")
        self.linear = linear
        self.max_rank = max_rank
        self.active_rank = max_rank

        # Freeze original weights — only the LoRA matrices train.
        for p in self.linear.parameters():
            p.requires_grad = False

        # One particle: a single A/B pair shared by every rank level.
        self.lora_A = nn.Parameter(torch.empty(max_rank, linear.in_features))
        self.lora_B = nn.Parameter(torch.zeros(linear.out_features, max_rank))

        # Standard LoRA init: A = kaiming, B = zeros → initial delta = 0,
        # so wrapping a layer does not change its output before training.
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))

    def set_rank(self, r: int):
        """Set the active orbital. Values above max_rank are clamped.

        Raises:
            ValueError: If r < 1. A non-positive rank would otherwise
                produce an empty/invalid slice and a division by zero
                in forward() (scale = max_rank / r).
        """
        if r < 1:
            raise ValueError(f"rank must be >= 1, got {r}")
        self.active_rank = min(r, self.max_rank)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Frozen base output plus the rank-sliced LoRA delta."""
        base = self.linear(x)
        r = self.active_rank

        # Slice the shared matrices down to the active rank.
        h = F.linear(x, self.lora_A[:r, :])
        delta = F.linear(h, self.lora_B[:, :r])

        # Rescale so delta magnitude stays comparable across rank changes
        # (analogous to alpha/r in standard LoRA).
        scale = self.max_rank / r
        return base + delta * scale
83
+
84
+
85
def inject_nested_lora(model: nn.Module, max_rank: int = 16) -> nn.Module:
    """
    Replace attention Linear layers with NestedLoRALinear.

    Targets any nn.Linear whose full name contains "attention".
    Original weights are frozen; only LoRA parameters are trainable.
    Idempotent: layers already wrapped by a previous call are skipped
    (the inner frozen Linear of a wrapper is itself named
    "...attention....linear" and would otherwise be wrapped again).

    Args:
        model: PyTorch model to modify in place.
        max_rank: Maximum LoRA rank.

    Returns:
        The same model, with NestedLoRA injected.
    """
    # Snapshot the module list first — we mutate the tree while iterating.
    for name, module in list(model.named_modules()):
        if not (isinstance(module, nn.Linear) and "attention" in name):
            continue
        # Walk from the root to the direct parent of the target layer.
        parent = model
        *path, last = name.split(".")
        for p in path:
            parent = getattr(parent, p)
        # Guard against double-wrapping the inner Linear of an existing
        # NestedLoRALinear on repeated injection.
        if isinstance(parent, NestedLoRALinear):
            continue
        setattr(parent, last, NestedLoRALinear(module, max_rank))
    return model
107
+
108
+
109
def set_rank(model: nn.Module, r: int):
    """Broadcast an active-rank change to every NestedLoRALinear in the model."""
    nested_layers = (m for m in model.modules() if isinstance(m, NestedLoRALinear))
    for layer in nested_layers:
        layer.set_rank(r)
114
+
115
+
116
def get_lora_params(model: nn.Module) -> List[nn.Parameter]:
    """Collect every LoRA A/B parameter pair (for optimizer setup)."""
    collected: List[nn.Parameter] = []
    for module in model.modules():
        if isinstance(module, NestedLoRALinear):
            collected.append(module.lora_A)
            collected.append(module.lora_B)
    return collected
123
+
124
+
125
def count_params(model: nn.Module) -> dict:
    """Count total, trainable, and LoRA parameters of *model*.

    Returns:
        dict with keys "total", "trainable", and "lora" mapping to
        the corresponding parameter counts.
    """
    total = 0
    trainable = 0
    # Single pass over all parameters for both totals.
    for p in model.parameters():
        n = p.numel()
        total += n
        if p.requires_grad:
            trainable += n
    lora = sum(p.numel() for p in get_lora_params(model))
    return {"total": total, "trainable": trainable, "lora": lora}