File size: 6,239 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
"""Ternary LoRA adapters for memory-efficient fine-tuning.

Freezes base ternary weights, adds small float low-rank adapters.
Only adapters receive gradients — base state stays at 1.71 GB, adapters < 100 MB.

Usage:
    from finetuning.lora import TernaryLoRALayer, apply_lora_to_model
    
    model = ARBModel(...)
    apply_lora_to_model(model, rank=16, target_modules=['moe', 'byte_head'])
    # Only LoRA params are trainable — base is frozen
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from arbitor.kernel.ternary_scale import TernaryScaleTensor


class TernaryLoRALayer(nn.Module):
    """LoRA adapter wrapping a single frozen base layer.

    Every parameter of the wrapped layer is frozen. Two small float matrices
    are trained instead, and the forward pass computes::

        output = base(x) + (x @ lora_A) @ lora_B * scaling

    with ``scaling = alpha / rank``. ``lora_B`` starts at zero, so a freshly
    wrapped layer reproduces the base layer's output exactly.

    Args:
        base_layer: Module to wrap. It must expose a ``_T_shape`` tensor
            holding ``[out_dim, in_dim]``.
        rank: Inner dimension of the low-rank update; a positive integer.
        alpha: Numerator of the scaling factor applied to the update.

    Raises:
        ValueError: If ``rank`` is not a positive integer value.
    """
    def __init__(self, base_layer, rank=8, alpha=16.0):
        super().__init__()
        # Fail early with a clear message instead of a ZeroDivisionError
        # (rank == 0) or an obscure torch error (negative / fractional rank).
        if rank != int(rank) or rank <= 0:
            raise ValueError(f"rank must be a positive integer, got {rank!r}")
        rank = int(rank)

        self.base = base_layer
        self.scaling = alpha / rank

        # Freeze every parameter of the wrapped layer (recursively).
        for p in base_layer.parameters(True):
            p.requires_grad = False

        # _T_shape is [out_dim, in_dim]
        out_dim, in_dim = base_layer._T_shape.tolist()
        out_dim, in_dim = int(out_dim), int(in_dim)

        # Create the adapters on the same device as the base layer's shape
        # tensor so the wrapper is usable without a later model-wide `.to()`.
        device = base_layer._T_shape.device
        self.lora_A = nn.Parameter(torch.randn(in_dim, rank, device=device) * 0.02)
        self.lora_B = nn.Parameter(torch.zeros(rank, out_dim, device=device))

        # Not read or updated anywhere in this class; kept only so external
        # code that may reference the attribute keeps working.
        self.forward_count = 0

    def forward(self, x):
        """Return the base output plus the scaled low-rank update for ``x``."""
        # NOTE(review): running the base under no_grad means no gradient
        # reaches `x` through the base path, only through the LoRA branch.
        # Presumably intentional to save memory; confirm for stacked layers.
        with torch.no_grad():
            base_out = self.base(x)
        lora_out = (x @ self.lora_A) @ self.lora_B * self.scaling
        return base_out + lora_out

    def extra_repr(self):
        """Describe the base shape, rank and alpha for ``repr(module)``."""
        # alpha is recovered from scaling * rank because only scaling is stored.
        return (f"base={tuple(self.base._T_shape.tolist())}, "
                f"rank={self.lora_A.shape[1]}, alpha={self.scaling * self.lora_A.shape[1]:.0f}")


class LoRAEmbedding(nn.Module):
    """LoRA adapter for an embedding-style lookup layer.

    The wrapped layer's parameters are frozen and a learned low-rank delta is
    added to its output::

        output = base(x) + (lora_A[x] @ lora_B) * DELTA_SCALE

    Args:
        base_embed: Module to wrap. It must expose a ``_T_shape`` tensor
            holding ``[num_embeddings, embed_dim]``.
        rank: Inner dimension of the low-rank delta.
    """

    # Fixed multiplier applied to the low-rank delta before it is added.
    DELTA_SCALE = 0.1

    def __init__(self, base_embed, rank=16):
        super().__init__()
        self.base = base_embed
        # Freeze every parameter of the wrapped layer (recursively).
        for p in base_embed.parameters(True):
            p.requires_grad = False

        # Cast explicitly: `_T_shape` may be stored as a float tensor, and
        # torch.randn rejects float sizes (mirrors TernaryLoRALayer).
        num_embeddings, embed_dim = (int(d) for d in base_embed._T_shape.tolist())

        # Create the adapters on the same device as the base layer's shape
        # tensor so the wrapper is usable without a later model-wide `.to()`.
        device = base_embed._T_shape.device
        self.lora_A = nn.Parameter(torch.randn(num_embeddings, rank, device=device) * 0.02)
        self.lora_B = nn.Parameter(torch.randn(rank, embed_dim, device=device) * 0.02)

    def forward(self, x):
        """Return the base lookup for indices ``x`` plus the scaled delta."""
        with torch.no_grad():
            base_out = self.base(x)
        # Gather the needed rows of lora_A first, then project. This equals
        # F.embedding(x, lora_A @ lora_B) without building the full
        # (num_embeddings, embed_dim) product on every call.
        delta = F.embedding(x, self.lora_A) @ self.lora_B
        return base_out + delta * self.DELTA_SCALE


def _should_lora(name, target_modules):
    """Return True if ``name`` contains any target pattern, ignoring case.

    Args:
        name: Dotted module name to test.
        target_modules: Iterable of substring patterns.

    Returns:
        True when at least one pattern occurs in ``name``; False otherwise
        (including when ``target_modules`` is empty).
    """
    # Lowercase both sides: lowering only the name made mixed-case patterns
    # such as 'W_gate' impossible to match.
    lowered = name.lower()
    return any(pattern.lower() in lowered for pattern in target_modules)


def apply_lora_to_model(model, rank=16, alpha=32.0, target_modules=None):
    """Wrap targeted TernaryScaleTensor modules of ``model`` in LoRA adapters.

    The model is modified in place: each matching child is replaced by a
    ``TernaryLoRALayer``. Afterwards the whole model is moved to CUDA when
    available (otherwise CPU), and every parameter whose name does not
    contain ``'lora_'`` is frozen.

    Args:
        model: Root module to modify (e.g. an ARBModel instance).
        rank: LoRA rank passed to each adapter.
        alpha: LoRA scaling alpha passed to each adapter.
        target_modules: List of name patterns, matched case-insensitively as
            substrings of the dotted module name.
            Default: ['W_gate', 'W_transform', 'router', 'byte_head', 'head',
            'output_router', 'shared_up', 'shared_expert_gate',
            'shared_expert_up']

    Returns:
        Dict mapping dotted module names to their ``TernaryLoRALayer``
        (for saving/loading adapters). Layers that were already wrapped
        before this call are included unchanged.
    """
    if target_modules is None:
        target_modules = ['W_gate', 'W_transform', 'router', 'byte_head', 'head',
                          'output_router', 'shared_up', 'shared_expert_gate',
                          'shared_expert_up']
    # Module names are compared in lowercase, so normalise the patterns too;
    # otherwise mixed-case patterns such as 'W_gate' could never match.
    target_modules = [pattern.lower() for pattern in target_modules]

    lora_layers = {}

    def _apply(module, name=''):
        # Snapshot the children: setattr below mutates the module's child dict.
        for child_name, child in list(module.named_children()):
            full_name = f"{name}.{child_name}" if name else child_name

            # Already-wrapped layers: report them, but never descend into
            # their `.base`, which would wrap the same tensor a second time.
            if isinstance(child, TernaryLoRALayer):
                lora_layers[full_name] = child
                continue
            if isinstance(child, LoRAEmbedding):
                continue

            targeted = _should_lora(full_name, target_modules)

            if isinstance(child, TernaryScaleTensor) and targeted:
                has_shape = hasattr(child, '_T_shape')
                is_square = has_shape and child._T_shape[1].item() == child._T_shape[0].item()
                # Square projections are treated as linear layers and always
                # wrapped; non-square layers named like embeddings are
                # skipped because they use lookup, not matmul.
                if not is_square and has_shape and 'embed' in full_name.lower():
                    continue
                lora = TernaryLoRALayer(child, rank=rank, alpha=alpha)
                setattr(module, child_name, lora)
                lora_layers[full_name] = lora

            elif hasattr(child, '_T_shape') and hasattr(child, 'T_packed') and targeted:
                continue  # Non-TernaryScaleTensor with T_packed is an embedding

            else:
                _apply(child, full_name)

    _apply(model)

    # Move the whole model (including the new adapters) to CUDA if available.
    # NOTE(review): this also relocates a model deliberately kept on CPU.
    model.to('cuda' if torch.cuda.is_available() else 'cpu')

    # Freeze all non-LoRA params, only LoRA A/B are trainable
    for name, p in model.named_parameters():
        if 'lora_' not in name:
            p.requires_grad = False

    return lora_layers


def count_lora_params(model):
    """Return ``(trainable, total)`` parameter element counts for ``model``.

    ``trainable`` sums ``numel()`` over parameters with ``requires_grad``
    set; ``total`` sums it over every parameter.
    """
    trainable = 0
    total = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return trainable, total


def save_lora(lora_layers, path):
    """Save only the LoRA adapter weights to ``path``.

    Args:
        lora_layers: Dict mapping module names to objects exposing
            ``lora_A`` and ``lora_B`` tensors (as returned by
            ``apply_lora_to_model``).
        path: Destination file; missing parent directories are created.

    Returns:
        ``path``, unchanged.

    The file holds a flat dict with keys ``lora.<name>.A`` and
    ``lora.<name>.B``.
    """
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    state = {}
    for name, layer in lora_layers.items():
        # Store plain detached CPU tensors rather than live Parameters, so
        # the checkpoint carries no autograd state and loads on machines
        # without the GPU it was trained on.
        state[f"lora.{name}.A"] = layer.lora_A.detach().cpu()
        state[f"lora.{name}.B"] = layer.lora_B.detach().cpu()
    torch.save(state, path)
    return path


def load_lora(model, path):
    """Load LoRA adapter weights from ``path`` into ``model`` in place.

    Keys are expected in the form written by ``save_lora``:
    ``lora.<dotted.module.name>.A`` / ``.B``. The dotted name is resolved by
    attribute lookup from ``model``, and the tensor is copied into that
    module's ``lora_A`` / ``lora_B``. Keys that cannot be resolved are
    skipped and reported with a single warning.

    Args:
        model: Module that already contains the LoRA layers.
        path: Checkpoint file to read.

    Returns:
        ``model``, for chaining.
    """
    import warnings

    # Load on CPU so a checkpoint written on a GPU machine still opens here;
    # copy_ below moves the data to each parameter's own device.
    state = torch.load(path, map_location="cpu", weights_only=True)
    skipped = []

    with torch.no_grad():
        for full_name, tensor in state.items():
            parts = full_name.split('.')
            if len(parts) < 3:
                skipped.append(full_name)
                continue

            # Walk the dotted module path, ignoring the leading prefix.
            target = model
            for attr in parts[1:-1]:
                target = getattr(target, attr, None)
                if target is None:
                    break
            if target is None:
                skipped.append(full_name)
                continue

            # Saved suffixes are 'A'/'B' but the attributes are
            # 'lora_A'/'lora_B'; looking up the bare suffix never matched,
            # so nothing was loaded. Try the prefixed name first and keep
            # the bare suffix as a fallback.
            suffix = parts[-1]
            dest = getattr(target, f"lora_{suffix}", None)
            if not isinstance(dest, torch.Tensor):
                dest = getattr(target, suffix, None)
            if not isinstance(dest, torch.Tensor):
                skipped.append(full_name)
                continue

            dest.copy_(tensor)

    if skipped:
        warnings.warn(
            f"load_lora: {len(skipped)} checkpoint key(s) did not match the model "
            f"and were skipped: {skipped}"
        )
    return model