ASTERIZER committed on
Commit
d426a2d
Β·
verified Β·
1 Parent(s): 828e3ce

Upload lora_chat.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. lora_chat.py +310 -0
lora_chat.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LUNA 100M β€” LoRA Adapter Chat
3
+ Loads the base SFT model, injects LoRA, and applies an adapter checkpoint.
4
+
5
+ Usage:
6
+ python lora_chat.py --adapter Base/out/sft/rag_mcp_lora/final/adapter_model.pt
7
+ python lora_chat.py --adapter Base/out/sft/rag_mcp_lora/step-001554/adapter_model.pt
8
+ python lora_chat.py --adapter /path/to/adapter_model.pt --max_new 300 --temp 0.8
9
+
10
+ # Use the full bundle (has rank/alpha/targets embedded):
11
+ python lora_chat.py --adapter Base/out/sft/rag_mcp_lora/final/adapter_bundle.pt --bundle
12
+ """
13
+
14
+ import argparse
15
+ import math
16
+ import os
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ from pathlib import Path
21
+
22
+
23
+ # ─── Model (matches sft_train.py exactly) ─────────────────────────────────────
24
+
25
class RotaryEmbedding(nn.Module):
    """Precomputed rotary position embeddings (RoPE) for a fixed max length.

    The cos/sin tables are built once at construction and registered as
    buffers, so the forward pass is a pure slice with no computation.
    """

    def __init__(self, dim, max_seq_len=1024):
        super().__init__()
        # Standard RoPE inverse frequencies: theta_i = 10000^(-2i/dim).
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        positions = torch.arange(max_seq_len).float()
        # Outer product positions x frequencies -> (max_seq_len, dim/2).
        angles = positions[:, None] * inv_freq[None, :]
        # Duplicate the half-table so it spans the full rotary dimension.
        table = torch.cat([angles, angles], dim=-1)
        self.register_buffer("cos_cached", table.cos())
        self.register_buffer("sin_cached", table.sin())

    def forward(self, seq_len):
        """Return the (cos, sin) tables for the first `seq_len` positions."""
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
38
+
39
+
40
def rotate_half(x):
    """Map the two halves (a, b) of the last dimension to (-b, a)."""
    first, second = x.chunk(2, dim=-1)
    return torch.cat((second.neg(), first), dim=-1)
43
+
44
+
45
def apply_rotary(x, cos, sin):
    """Apply rotary embedding to `x`, broadcasting (T, D) tables over (B, H, T, D)."""
    cos_b = cos[None, None, :, :]
    sin_b = sin[None, None, :, :]
    return x * cos_b + rotate_half(x) * sin_b
49
+
50
+
51
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with partial rotary embeddings.

    Only the first `rotary_pct` fraction of each head's dimensions receives
    RoPE; the remainder passes through unrotated (GPT-NeoX style partial
    rotary).
    """

    def __init__(self, n_embd, n_head, block_size, rotary_pct=0.25):
        super().__init__()
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.rot_dim = int(self.head_dim * rotary_pct)
        # Fused QKV projection, then an output projection back to n_embd.
        self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=True)
        self.c_proj = nn.Linear(n_embd, n_embd, bias=True)
        self.rotary = RotaryEmbedding(self.rot_dim, block_size)

    def forward(self, x):
        batch, seq, width = x.size()
        # (B, T, 3*C) -> (B, T, 3, H, Dh) -> (3, B, H, T, Dh), then split.
        projected = self.c_attn(x)
        qkv = projected.reshape(batch, seq, 3, self.n_head, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        cos, sin = self.rotary(seq)
        r = self.rot_dim
        # Rotate only the leading `rot_dim` channels of q and k.
        q = torch.cat([apply_rotary(q[..., :r], cos, sin), q[..., r:]], dim=-1)
        k = torch.cat([apply_rotary(k[..., :r], cos, sin), k[..., r:]], dim=-1)
        # SDPA applies the causal mask internally.
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        out = out.transpose(1, 2).contiguous().view(batch, seq, width)
        return self.c_proj(out)
70
+
71
+
72
class MLP(nn.Module):
    """Transformer feed-forward block: Linear -> GELU -> Linear, 4x expansion."""

    def __init__(self, n_embd):
        super().__init__()
        self.fc = nn.Linear(n_embd, 4 * n_embd, bias=True)
        self.gelu = nn.GELU()
        self.proj = nn.Linear(4 * n_embd, n_embd, bias=True)

    def forward(self, x):
        hidden = self.fc(x)
        hidden = self.gelu(hidden)
        return self.proj(hidden)
81
+
82
+
83
class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each behind a residual."""

    def __init__(self, n_embd, n_head, block_size):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = MLP(n_embd)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))
95
+
96
+
97
class LUNAModel(nn.Module):
    """Decoder-only GPT-style language model (architecture matches sft_train.py).

    No learned absolute positions: position information comes from the rotary
    embeddings inside each attention layer.
    """

    def __init__(self, vocab_size=50304, block_size=1024,
                 n_layer=10, n_embd=768, n_head=12):
        super().__init__()
        self.block_size = block_size
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.ModuleList(
            Block(n_embd, n_head, block_size) for _ in range(n_layer)
        )
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        # Weight tying: output projection shares the input embedding matrix.
        self.lm_head.weight = self.wte.weight

    def forward(self, idx):
        """Map token ids (B, T) to logits (B, T, vocab_size)."""
        hidden = self.wte(idx)
        for blk in self.blocks:
            hidden = blk(hidden)
        return self.lm_head(self.ln_f(hidden))
113
+
114
+
115
+ # ─── LoRA ─────────────────────────────────────────────────────────────────────
116
+
117
class LoRALinear(nn.Module):
    """Wrap a frozen nn.Linear with a trainable low-rank (LoRA) delta.

    output = base(x) + (alpha / rank) * B(A(dropout(x)))

    A projects down to `rank`; B (zero-initialized) projects back up, so the
    wrapper is an exact no-op at initialization and only the LoRA factors
    receive gradients.
    """

    def __init__(self, base_layer, rank=16, alpha=32, dropout=0.0):
        super().__init__()
        self.base = base_layer
        self.scale = alpha / max(rank, 1)
        self.dropout = nn.Dropout(dropout)
        self.lora_a = nn.Linear(base_layer.in_features, rank, bias=False)
        self.lora_b = nn.Linear(rank, base_layer.out_features, bias=False)
        # A gets the usual Kaiming init; B starts at zero so the delta is 0.
        nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_b.weight)
        # Freeze the wrapped layer — only the LoRA factors train.
        for param in self.base.parameters():
            param.requires_grad = False

    def forward(self, x):
        delta = self.lora_b(self.lora_a(self.dropout(x)))
        return self.base(x) + delta * self.scale
132
+
133
+
134
def inject_lora(model, target_modules, rank, alpha):
    """Replace, in place, every nn.Linear whose qualified name ends with one
    of `target_modules` by a LoRALinear wrapper.

    Iterates over a snapshot of named_modules() so the setattr replacements
    do not disturb the traversal.
    """
    for qual_name, layer in list(model.named_modules()):
        is_target = isinstance(layer, nn.Linear) and any(
            qual_name.endswith(suffix) for suffix in target_modules
        )
        if not is_target:
            continue
        owner_path, _, attr = qual_name.rpartition(".")
        owner = model.get_submodule(owner_path) if owner_path else model
        # Match the wrapped layer's placement/precision before swapping in.
        replacement = LoRALinear(layer, rank=rank, alpha=alpha).to(
            device=layer.weight.device, dtype=layer.weight.dtype
        )
        setattr(owner, attr, replacement)
145
+
146
+
147
+ # ─── Generation ───────────────────────────────────────────────────────────────
148
+
149
@torch.no_grad()
def generate(model, input_ids, max_new=200, temperature=0.7,
             top_p=0.9, top_k=50, rep_pen=1.1, device="cpu"):
    """Autoregressively sample up to `max_new` tokens after `input_ids`.

    Per step: apply a repetition penalty to already-seen token ids, then
    either decode greedily (temperature ~ 0) or sample with temperature,
    top-k, and nucleus (top-p) filtering. Generation stops early when token
    id 0 is produced (treated as end-of-sequence here — presumably the
    tokenizer's pad/eos id; confirm against training setup).

    Returns the list of newly generated token ids (prompt excluded).
    """
    ids = input_ids.to(device)
    prompt_len = input_ids.shape[1]
    for _ in range(max_new):
        # Score only the trailing window that fits the model's context.
        logits = model(ids[:, -model.block_size:])[:, -1, :]

        # Repetition penalty: shrink positive / amplify negative logits of
        # tokens already present in the sequence.
        if rep_pen != 1.0:
            for seen in set(ids[0].tolist()):
                val = logits[0, seen]
                logits[0, seen] = val / rep_pen if val > 0 else val * rep_pen

        if temperature < 1e-6:
            # Effectively greedy decoding.
            next_tok = logits.argmax(dim=-1, keepdim=True)
        else:
            probs = F.softmax(logits / temperature, dim=-1)
            if top_k > 0:
                # Zero out everything below the k-th largest probability.
                kth = torch.topk(probs, min(top_k, probs.size(-1)))[0][:, [-1]]
                probs[probs < kth] = 0.0
                probs /= probs.sum()
            if top_p < 1.0:
                sorted_p, sorted_idx = torch.sort(probs, descending=True)
                cumulative = torch.cumsum(sorted_p, dim=-1)
                # Keep tokens whose cumulative mass (excluding themselves)
                # is within the nucleus; zero the rest and renormalize.
                sorted_p[cumulative - sorted_p > top_p] = 0.0
                sorted_p /= sorted_p.sum()
                next_tok = sorted_idx[0, torch.multinomial(sorted_p[0], 1)]
            else:
                next_tok = torch.multinomial(probs[0], 1)

        ids = torch.cat([ids, next_tok.view(1, 1)], dim=1)
        if next_tok.item() == 0:
            break
    return ids[0, prompt_len:].tolist()
179
+
180
+
181
def format_prompt(instruction):
    """Wrap a user instruction in the Alpaca-style instruction/response template."""
    body = instruction.strip()
    return f"### Instruction:\n{body}\n\n### Response:\n"
183
+
184
+
185
+ # ─── Main ─────────────────────────────────────────────────────────────────────
186
+
187
def main():
    """CLI entry point: load the base LUNA model, inject LoRA layers, apply an
    adapter checkpoint, and run an interactive chat REPL on stdin/stdout.

    Side effects: may download the base checkpoint from the Hugging Face Hub
    when no local copy exists (uses HF_TOKEN from the environment if set).
    """
    parser = argparse.ArgumentParser(description="LUNA 100M — LoRA Adapter Chat")
    parser.add_argument("--adapter", required=True,
                        help="Path to adapter_model.pt or adapter_bundle.pt")
    parser.add_argument("--bundle", action="store_true",
                        help="Adapter file is an adapter_bundle.pt (has config embedded)")
    parser.add_argument("--base_ckpt", default=None,
                        help="Path to base model .pth (auto-downloads from HF if not set)")
    parser.add_argument("--tok_dir", default="Base/checkpoints/EleutherAI/pythia-160m")
    parser.add_argument("--rank", type=int, default=16)
    parser.add_argument("--alpha", type=float, default=32.0)
    parser.add_argument("--targets", nargs="+",
                        default=["attn.c_attn", "attn.c_proj", "mlp.fc", "mlp.proj"])
    parser.add_argument("--max_new", type=int, default=200)
    parser.add_argument("--temp", type=float, default=0.7)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--top_k", type=int, default=50)
    parser.add_argument("--rep_pen", type=float, default=1.1)
    parser.add_argument("--device", default="auto")
    args = parser.parse_args()

    # ── device ──
    if args.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device

    # ── load adapter ──
    adapter_path = Path(args.adapter)
    if not adapter_path.exists():
        raise FileNotFoundError(f"Adapter not found: {adapter_path}")

    # weights_only=True guards against pickle payloads; requires the adapter
    # to have been saved as plain tensors/containers.
    bundle = torch.load(adapter_path, map_location="cpu", weights_only=True)

    if args.bundle and isinstance(bundle, dict) and "lora_rank" in bundle:
        # Bundle files embed the LoRA hyperparameters used at training time.
        rank = bundle["lora_rank"]
        alpha = bundle["lora_alpha"]
        targets = bundle["target_modules"]
        adapter_state = bundle["adapter"]
        print(f" Bundle config: rank={rank}, alpha={alpha}, targets={targets}")
    else:
        # Plain adapter_model.pt: hyperparameters come from the CLI flags and
        # must match what the adapter was trained with.
        rank = args.rank
        alpha = args.alpha
        targets = args.targets
        adapter_state = bundle

    # ── resolve base checkpoint ──
    base_ckpt = args.base_ckpt
    if base_ckpt is None:
        default = Path("Base/out/input_models/luna_sft_v1/sft_v1/final/model.pth")
        if default.exists():
            base_ckpt = str(default)
        else:
            print(" Base checkpoint not found locally — downloading from HF...")
            from huggingface_hub import hf_hub_download
            default.parent.mkdir(parents=True, exist_ok=True)
            # local_dir points three levels up so the repo-relative filename
            # ("sft_v1/final/model.pth") lands exactly at `default`.
            hf_hub_download(
                repo_id="ASTERIZER/LUNA-100M",
                filename="sft_v1/final/model.pth",
                local_dir=str(default.parent.parent.parent),
                token=os.environ.get("HF_TOKEN"),
            )
            base_ckpt = str(default)

    # ── build and load base model ──
    print(f" Loading base: {base_ckpt}")
    base_state = torch.load(base_ckpt, map_location="cpu", weights_only=True)
    if isinstance(base_state, dict) and "model" in base_state:
        # Training checkpoints wrap the weights under a "model" key.
        base_state = base_state["model"]

    model = LUNAModel()
    model.load_state_dict(base_state, strict=True)
    model = model.to(device)

    # ── inject LoRA and load adapter weights ──
    inject_lora(model, target_modules=targets, rank=rank, alpha=alpha)
    # strict=False: the adapter state only covers the lora_a/lora_b tensors,
    # so every base weight is reported as "missing" — that is expected.
    missing, unexpected = model.load_state_dict(adapter_state, strict=False)
    if unexpected:
        print(f" Warning: unexpected keys in adapter: {unexpected[:5]}")
    lora_keys = [k for k in adapter_state if "lora" in k]
    print(f" Loaded {len(lora_keys)} LoRA weight tensors from {adapter_path.name}")

    model.eval()

    # ── tokenizer ──
    # Imported lazily so the heavy transformers dependency loads only at runtime.
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.tok_dir)

    # ── info ──
    print(f"\n{'='*60}")
    print(f" LUNA 100M + LoRA Adapter")
    print(f" Adapter : {adapter_path}")
    print(f" Device : {device}")
    print(f" max_new : {args.max_new} temp: {args.temp} top_p: {args.top_p}")
    print(f"{'='*60}")
    print(" Type your instruction and press Enter. Ctrl+C to quit.\n")

    # ── REPL ──
    while True:
        try:
            user_input = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nBye.")
            break
        if not user_input:
            continue

        prompt = format_prompt(user_input)
        input_ids = tokenizer.encode(prompt, return_tensors="pt")
        tokens = generate(
            model, input_ids,
            max_new=args.max_new,
            temperature=args.temp,
            top_p=args.top_p,
            top_k=args.top_k,
            rep_pen=args.rep_pen,
            device=device,
        )
        response = tokenizer.decode(tokens, skip_special_tokens=True)
        print(f"\nLUNA: {response.strip()}\n")
307
+
308
+
309
# Script entry point guard: run the chat REPL only when executed directly.
if __name__ == "__main__":
    main()