ASTERIZER committed on
Commit
7c7c7ac
Β·
verified Β·
1 Parent(s): 2be87ed

Upload chat.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. chat.py +251 -0
chat.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LUNA SFT — Interactive Chat
3
+ Loads the SFT fine-tuned model once, then lets you chat continuously.
4
+
5
+ Usage:
6
+ python chat.py
7
+ python chat.py --ckpt "D:\\ASTERIZER 2026\\LUNA\\Base\\out\\sft\\model.pth"
8
+ python chat.py --max_new 300 --temp 0.7
9
+ """
10
+
11
+ import sys, argparse, torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from pathlib import Path
15
+
16
+
17
+ # ─── Model (must match train.py) ─────────────────────────────────────────────
18
+
19
class RotaryEmbedding(nn.Module):
    """Precomputed rotary position-embedding (RoPE) tables.

    Cos/sin tables for every position up to ``max_seq_len`` are built once
    at construction time, so the forward pass is just a cheap slice.
    """

    def __init__(self, dim, max_seq_len=1024):
        super().__init__()
        # Standard RoPE frequency schedule: inv_freq[j] = 10000^(-2j/dim).
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        positions = torch.arange(max_seq_len).float()
        # Outer product position x frequency -> per-position rotation angles.
        angles = torch.outer(positions, inv_freq)
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos())
        self.register_buffer("sin_cached", table.sin())

    def forward(self, seq_len):
        """Return the (cos, sin) tables for the first ``seq_len`` positions."""
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
32
+
33
+
34
def rotate_half(x):
    """Map the halves (a, b) of the last dimension to (-b, a), as RoPE requires."""
    front, back = x.chunk(2, dim=-1)
    return torch.cat((back.neg(), front), dim=-1)
37
+
38
+
39
def apply_rotary(x, cos, sin):
    """Rotate features of ``x`` by the cached (cos, sin) tables.

    ``cos``/``sin`` have shape (T, rot_dim); two leading singleton dims are
    inserted so they broadcast over (B, H, T, rot_dim) inputs.
    """
    cos_b = cos[None, None, :, :]
    sin_b = sin[None, None, :, :]
    return x * cos_b + rotate_half(x) * sin_b
43
+
44
+
45
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with partial rotary embeddings.

    Only the leading ``rotary_pct`` fraction of each head's channels receives
    RoPE; the remaining channels pass through unrotated (GPT-NeoX style).
    """

    def __init__(self, n_embd, n_head, block_size, rotary_pct=0.25):
        super().__init__()
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.rot_dim = int(self.head_dim * rotary_pct)
        # Fused qkv projection; names must match the training checkpoint.
        self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=True)
        self.c_proj = nn.Linear(n_embd, n_embd, bias=True)
        self.rotary = RotaryEmbedding(self.rot_dim, block_size)

    def forward(self, x):
        B, T, C = x.size()
        # One matmul for q, k, v, then split per head: each is (B, H, T, hd).
        qkv = self.c_attn(x).reshape(B, T, 3, self.n_head, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        cos, sin = self.rotary(T)
        rd = self.rot_dim
        # Rotate only the first rd channels of each head; keep the rest as-is.
        q = torch.cat([apply_rotary(q[..., :rd], cos, sin), q[..., rd:]], dim=-1)
        k = torch.cat([apply_rotary(k[..., :rd], cos, sin), k[..., rd:]], dim=-1)

        # Fused attention kernel; the causal mask is applied internally.
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(out)
64
+
65
+
66
class MLP(nn.Module):
    """Position-wise feed-forward block: linear -> GELU -> linear, 4x expansion."""

    def __init__(self, n_embd):
        super().__init__()
        hidden = 4 * n_embd
        self.fc = nn.Linear(n_embd, hidden, bias=True)
        self.gelu = nn.GELU()
        self.proj = nn.Linear(hidden, n_embd, bias=True)

    def forward(self, x):
        h = self.gelu(self.fc(x))
        return self.proj(h)
75
+
76
+
77
class Block(nn.Module):
    """Pre-norm transformer block: LN -> attention and LN -> MLP, each with a residual."""

    def __init__(self, n_embd, n_head, block_size):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = MLP(n_embd)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))
89
+
90
+
91
class LUNAModel(nn.Module):
    """Decoder-only transformer language model matching the SFT training config.

    Defaults (vocab 50304, 10 layers, 768 embed, 12 heads, 1024 context) must
    match train.py so the checkpoint's state_dict loads strictly.
    """

    def __init__(self, vocab_size=50304, block_size=1024,
                 n_layer=10, n_embd=768, n_head=12):
        super().__init__()
        self.block_size = block_size
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.ModuleList(
            Block(n_embd, n_head, block_size) for _ in range(n_layer)
        )
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        # Weight tying: the output projection shares the token-embedding matrix.
        self.lm_head.weight = self.wte.weight

    def forward(self, idx):
        """Map token ids (B, T) to logits (B, T, vocab_size)."""
        h = self.wte(idx)
        for blk in self.blocks:
            h = blk(h)
        return self.lm_head(self.ln_f(h))
107
+
108
+
109
+ # ─── Generation ───────────────────────────────────────────────────────────────
110
+
111
@torch.no_grad()
def generate(model, input_ids, max_new=200, temperature=0.7,
             top_p=0.9, top_k=50, repetition_penalty=1.1, device="cpu",
             eos_id=0):
    """Autoregressively sample up to ``max_new`` tokens from ``model``.

    Args:
        model: callable returning (B, T, vocab) logits; must expose ``.block_size``.
        input_ids: (1, T) prompt token ids.
        max_new: maximum number of new tokens to generate.
        temperature: softmax temperature; below 1e-6 means greedy argmax.
        top_p: nucleus-sampling cumulative-mass cutoff (active when < 1.0).
        top_k: keep only the k most probable tokens (active when > 0).
        repetition_penalty: > 1.0 discourages tokens already in the context.
        device: device to run generation on.
        eos_id: token id that ends generation (0 = pythia <|endoftext|>).

    Returns:
        List of generated token ids (prompt excluded); includes eos_id if hit.
    """
    ids = input_ids.to(device)
    generated = []

    for _ in range(max_new):
        # Crop the context to the model window; keep logits of the last position.
        logits = model(ids[:, -model.block_size:])[:, -1, :]

        # Repetition penalty (CTRL-style): shrink positive / grow negative
        # logits of every token already present in the context.
        if repetition_penalty != 1.0:
            for tok_id in set(ids[0].tolist()):
                if logits[0, tok_id] > 0:
                    logits[0, tok_id] /= repetition_penalty
                else:
                    logits[0, tok_id] *= repetition_penalty

        if temperature < 1e-6:
            # Greedy decoding.
            next_token = logits.argmax(dim=-1, keepdim=True)
        else:
            logits = logits / temperature
            probs = F.softmax(logits, dim=-1)

            # Top-k: zero everything below the k-th largest probability.
            if top_k > 0:
                kval = min(top_k, probs.size(-1))
                topk_vals, _ = torch.topk(probs, kval)
                probs[probs < topk_vals[:, [-1]]] = 0.0
                # Renormalize per row (dim=-1), not over the whole tensor,
                # so the code stays correct if the batch dim ever exceeds 1.
                probs /= probs.sum(dim=-1, keepdim=True)

            # Top-p (nucleus): keep the smallest prefix of sorted tokens whose
            # cumulative mass reaches top_p.
            if top_p < 1.0:
                sorted_probs, sorted_idx = torch.sort(probs, descending=True)
                cumsum = torch.cumsum(sorted_probs, dim=-1)
                mask = cumsum - sorted_probs > top_p
                sorted_probs[mask] = 0.0
                sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
                next_token = sorted_idx[0, torch.multinomial(sorted_probs[0], 1)]
            else:
                next_token = torch.multinomial(probs[0], 1)

        ids = torch.cat([ids, next_token.view(1, 1)], dim=1)
        generated.append(next_token.item())

        if next_token.item() == eos_id:
            break

    return generated
159
+
160
+
161
+ # ─── Alpaca prompt template ───────────────────────────────────────────────────
162
+
163
+ # Prompt format matching sft_train.py exactly (no preamble)
164
def format_prompt(instruction, context=""):
    """Build the Alpaca-style prompt used by sft_train.py (no preamble).

    Sections emitted depend on which of instruction/context are non-empty
    after stripping; the "### Response:" tail is always present.
    """
    inst = instruction.strip()
    ctx = context.strip()
    tail = "### Response:\n"
    if inst and ctx:
        return f"### Instruction:\n{inst}\n\n### Input:\n{ctx}\n\n{tail}"
    if inst:
        return f"### Instruction:\n{inst}\n\n{tail}"
    # No instruction: fall back to the Input-only form (even if ctx is empty,
    # matching the original branch structure exactly).
    return f"### Input:\n{ctx}\n\n{tail}"
173
+
174
+
175
+ # ─── Main ─────────────────────────────────────────────────────────────────────
176
+
177
def main():
    """Parse CLI args, load the SFT checkpoint and tokenizer, run the chat loop."""
    parser = argparse.ArgumentParser(description="LUNA SFT — Interactive Chat")
    parser.add_argument("--ckpt", default=r"D:\ASTERIZER 2026\LUNA\Base\out\sft\model.pth")
    parser.add_argument("--tok_dir", default="Base/checkpoints/EleutherAI/pythia-160m")
    parser.add_argument("--max_new", type=int, default=150)
    parser.add_argument("--temp", type=float, default=0.7)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--top_k", type=int, default=40)
    parser.add_argument("--rep_pen", type=float, default=1.0)
    parser.add_argument("--device", default="auto")
    args = parser.parse_args()

    # Resolve "auto" to cuda when available, otherwise cpu.
    if args.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device
    print(f"\nDevice: {device}")

    # Load weights; weights_only=True avoids arbitrary pickle execution.
    print(f"Loading: {args.ckpt}")
    ckpt = torch.load(args.ckpt, map_location="cpu", weights_only=True)
    # Training checkpoints wrap the state_dict under "model"; raw dumps don't.
    state = ckpt["model"] if "model" in ckpt else ckpt
    model = LUNAModel()
    model.load_state_dict(state, strict=True)
    model = model.to(device).eval()
    n_params = sum(p.numel() for p in model.parameters())
    print(f" Model loaded: {n_params:,} parameters")

    # Local import keeps transformers out of the picture until actually needed.
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.tok_dir)
    print(f" Tokenizer: {args.tok_dir} (vocab {tokenizer.vocab_size})")

    banner = "=" * 60
    print(f"\n{banner}")
    print(" LUNA — Interactive Chat")
    print(f" max_new={args.max_new} temp={args.temp} top_p={args.top_p} top_k={args.top_k}")
    print(" Type your message and press Enter. Type 'quit' to exit.")
    print(f"{banner}\n")

    while True:
        try:
            user_input = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-C / Ctrl-D exit cleanly instead of a traceback.
            print("\nBye!")
            return

        if not user_input:
            continue
        if user_input.lower() in ("quit", "exit", "q"):
            print("Bye!")
            return

        ids = tokenizer.encode(format_prompt(user_input), return_tensors="pt")
        tokens = generate(
            model, ids,
            max_new=args.max_new,
            temperature=args.temp,
            top_p=args.top_p,
            top_k=args.top_k,
            repetition_penalty=args.rep_pen,
            device=device,
        )

        response = tokenizer.decode(tokens, skip_special_tokens=True).strip()
        # Truncate if the model starts emitting the next template section.
        if "### " in response:
            response = response.split("### ")[0].strip()

        print(f"\nLUNA: {response}\n")


if __name__ == "__main__":
    main()