Upload main.ipynb with huggingface_hub

main.ipynb  CHANGED  (+476 -0)
@@ -2161,6 +2161,482 @@
     "\n",
     "visualize_diffusion(model, tokenizer)"
    ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "ft_header",
+   "metadata": {},
+   "source": [
+    "---\n",
+    "# Part 2: Fine-tuning for Chat\n",
+    "\n",
+    "Now we turn the pretrained MDLM into a **chatbot** using supervised fine-tuning on dialogue data.\n",
+    "\n",
+    "## How diffusion chat works\n",
+    "1. Format: `<|user|> message <|assistant|> response <|end|>`\n",
+    "2. **Training**: Mask only the response tokens \u2014 the user message stays visible as context\n",
+    "3. **Inference**: The user types a message \u2192 those tokens are frozen \u2192 diffusion unmasks only the response\n",
+    "4. **The cool part**: The response materializes all at once, not left-to-right (a minimal sketch of the masking follows below)"
+   ]
+  },
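A minimal, self-contained sketch of step 2 above (prompt-conditioned masking), independent of the notebook's classes. The token ids and the fixed `mask_prob` here are illustrative assumptions; in the training cell below the mask rate comes from the noise schedule as `1 - alpha(t)`:

```python
import torch

MASK_ID = 50260  # hypothetical [MASK] id, standing in for config.mask_token_id
prompt = torch.tensor([50257, 11, 12, 50258])     # <|user|> ... <|assistant|>
response = torch.tensor([21, 22, 23, 24, 50259])  # response tokens + <|end|>

seq = torch.cat([prompt, response])
response_mask = torch.zeros_like(seq, dtype=torch.bool)
response_mask[len(prompt):] = True  # True only at response positions

mask_prob = 0.5  # in training: 1 - alpha(t), drawn from the noise schedule
corrupt = (torch.rand(seq.shape) < mask_prob) & response_mask
z_t = torch.where(corrupt, MASK_ID, seq)  # prompt tokens are never corrupted
print(z_t)  # e.g. tensor([50257, 11, 12, 50258, 50260, 22, 50260, 24, 50259])
```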
+  {
+   "cell_type": "code",
+   "id": "ft_config",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# ============================================================\n",
+    "# FINE-TUNING CONFIG\n",
+    "# ============================================================\n",
+    "\n",
+    "@dataclass\n",
+    "class FinetuneConfig:\n",
+    "    # Training\n",
+    "    ft_steps: int = 5000\n",
+    "    ft_batch_size: int = 16\n",
+    "    ft_lr: float = 5e-5           # Lower LR for fine-tuning\n",
+    "    ft_warmup: int = 200\n",
+    "    max_response_len: int = 128   # Max response length\n",
+    "    max_prompt_len: int = 64      # Max prompt length\n",
+    "    log_every: int = 50\n",
+    "    sample_every: int = 500\n",
+    "\n",
+    "ft_config = FinetuneConfig()\n",
+    "\n",
+    "# Add special tokens to tokenizer\n",
+    "SPECIAL_TOKENS = {\n",
+    "    'additional_special_tokens': ['<|user|>', '<|assistant|>', '<|end|>']\n",
+    "}\n",
+    "tokenizer.add_special_tokens(SPECIAL_TOKENS)\n",
+    "\n",
+    "USER_TOKEN = tokenizer.convert_tokens_to_ids('<|user|>')\n",
+    "ASST_TOKEN = tokenizer.convert_tokens_to_ids('<|assistant|>')\n",
+    "END_TOKEN = tokenizer.convert_tokens_to_ids('<|end|>')\n",
+    "\n",
+    "print(f'Special token IDs: USER={USER_TOKEN}, ASST={ASST_TOKEN}, END={END_TOKEN}')\n",
+    "\n",
+    "# Resize model embeddings to accommodate the new tokens\n",
+    "old_vocab = config.vocab_size\n",
+    "new_vocab = len(tokenizer)\n",
+    "if new_vocab > old_vocab:\n",
+    "    # Expand embedding and output projection\n",
+    "    old_emb = model_unwrapped.token_emb.weight.data\n",
+    "    model_unwrapped.token_emb = nn.Embedding(new_vocab, config.hidden_dim).to(device)\n",
+    "    model_unwrapped.token_emb.weight.data[:old_vocab] = old_emb\n",
+    "    # Re-tie output projection to the embedding\n",
+    "    model_unwrapped.output_proj = nn.Linear(config.hidden_dim, new_vocab, bias=False).to(device)\n",
+    "    model_unwrapped.output_proj.weight = model_unwrapped.token_emb.weight\n",
+    "    # Update config\n",
+    "    config.vocab_size = new_vocab\n",
+    "    model_unwrapped.config.vocab_size = new_vocab\n",
+    "    print(f'Resized embeddings: {old_vocab} -> {new_vocab}')\n",
+    "\n",
+    "print('Fine-tune config ready')\n"
+   ]
+  },
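Because the cell above re-ties the output projection by assigning the embedding's `Parameter` directly, a one-line sanity check can confirm the two layers really share storage (an optional sketch, assuming the `token_emb`/`output_proj` attribute names used above):

```python
# Tied weights share one storage: a gradient step on one updates both.
assert (model_unwrapped.output_proj.weight.data_ptr()
        == model_unwrapped.token_emb.weight.data_ptr()), 'weights are not tied'
```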
+  {
+   "cell_type": "code",
+   "id": "ft_dataset",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# ============================================================\n",
+    "# DIALOGUE DATASET\n",
+    "# ============================================================\n",
+    "\n",
+    "from datasets import load_dataset\n",
+    "\n",
+    "# Using Alpaca-cleaned: simple instruction-response pairs\n",
+    "print('Loading Alpaca dataset...')\n",
+    "alpaca = load_dataset('yahma/alpaca-cleaned', split='train')\n",
+    "print(f'Loaded {len(alpaca)} examples')\n",
+    "\n",
+    "class ChatDataset(torch.utils.data.Dataset):\n",
+    "    \"\"\"Format dialogue as: <|user|> instruction <|assistant|> response <|end|>\n",
+    "\n",
+    "    Returns:\n",
+    "        input_ids: full sequence token ids\n",
+    "        response_mask: bool mask, True for response tokens (what we train on)\n",
+    "    \"\"\"\n",
+    "    def __init__(self, dataset, tokenizer, max_prompt_len, max_response_len):\n",
+    "        self.data = dataset\n",
+    "        self.tokenizer = tokenizer\n",
+    "        self.max_prompt_len = max_prompt_len\n",
+    "        self.max_response_len = max_response_len\n",
+    "        self.total_len = max_prompt_len + max_response_len\n",
+    "\n",
+    "    def __len__(self):\n",
+    "        return len(self.data)\n",
+    "\n",
+    "    def __getitem__(self, idx):\n",
+    "        item = self.data[idx]\n",
+    "\n",
+    "        # Build prompt\n",
+    "        instruction = item['instruction']\n",
+    "        if item.get('input', ''):\n",
+    "            instruction = instruction + ' ' + item['input']\n",
+    "        response = item['output']\n",
+    "\n",
+    "        # Tokenize prompt and response separately\n",
+    "        prompt_tokens = [USER_TOKEN] + self.tokenizer.encode(instruction)[:self.max_prompt_len - 2] + [ASST_TOKEN]\n",
+    "        response_tokens = self.tokenizer.encode(response)[:self.max_response_len - 1] + [END_TOKEN]\n",
+    "\n",
+    "        # Combine\n",
+    "        input_ids = prompt_tokens + response_tokens\n",
+    "        prompt_len = len(prompt_tokens)\n",
+    "\n",
+    "        # Pad or truncate to fixed length (use self.tokenizer, not the global)\n",
+    "        if len(input_ids) < self.total_len:\n",
+    "            pad_len = self.total_len - len(input_ids)\n",
+    "            input_ids = input_ids + [self.tokenizer.eos_token_id] * pad_len\n",
+    "        else:\n",
+    "            input_ids = input_ids[:self.total_len]\n",
+    "\n",
+    "        input_ids = torch.tensor(input_ids, dtype=torch.long)\n",
+    "\n",
+    "        # Response mask: True for response positions only\n",
+    "        response_mask = torch.zeros(self.total_len, dtype=torch.bool)\n",
+    "        response_mask[prompt_len:prompt_len + len(response_tokens)] = True\n",
+    "\n",
+    "        return input_ids, response_mask\n",
+    "\n",
+    "chat_dataset = ChatDataset(alpaca, tokenizer, ft_config.max_prompt_len, ft_config.max_response_len)\n",
+    "chat_loader = DataLoader(chat_dataset, batch_size=ft_config.ft_batch_size, shuffle=True, num_workers=2, pin_memory=True)\n",
+    "\n",
+    "# Test\n",
+    "test_ids, test_mask = chat_dataset[0]\n",
+    "print('\\nExample:')\n",
+    "print(f'Full sequence: {tokenizer.decode(test_ids[:40])}...')\n",
+    "print(f'Response positions: {test_mask.sum().item()} of {len(test_ids)} total')\n",
+    "print(f'\\nPrompt part: {tokenizer.decode(test_ids[~test_mask][:30])}')\n",
+    "print(f'Response part: {tokenizer.decode(test_ids[test_mask][:30])}')\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "id": "ft_train",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# ============================================================\n",
+    "# FINE-TUNING LOOP\n",
+    "# ============================================================\n",
+    "\n",
+    "# Fresh optimizer with lower LR\n",
+    "ft_optimizer = torch.optim.AdamW(\n",
+    "    model_unwrapped.parameters(),\n",
+    "    lr=ft_config.ft_lr,\n",
+    "    betas=(0.9, 0.98),\n",
+    "    weight_decay=0.01,\n",
+    ")\n",
+    "ft_scaler = GradScaler('cuda')\n",
+    "ft_ema = EMA(model_unwrapped, decay=0.999)  # Faster EMA for fine-tuning\n",
+    "\n",
+    "model_unwrapped.train()\n",
+    "ft_losses = []\n",
+    "ft_accuracies = []\n",
+    "ft_start = time.time()\n",
+    "chat_iter = iter(chat_loader)\n",
+    "\n",
+    "print(f'Fine-tuning for {ft_config.ft_steps} steps...')\n",
+    "print(f'Batch size: {ft_config.ft_batch_size}')\n",
+    "print('=' * 60)\n",
+    "\n",
+    "for step in range(1, ft_config.ft_steps + 1):\n",
+    "    # LR schedule: linear warmup + cosine decay\n",
+    "    lr = get_lr(step, ft_config.ft_warmup, ft_config.ft_steps, ft_config.ft_lr)\n",
+    "    for pg in ft_optimizer.param_groups:\n",
+    "        pg['lr'] = lr\n",
+    "\n",
+    "    try:\n",
+    "        input_ids, response_mask = next(chat_iter)\n",
+    "    except StopIteration:\n",
+    "        chat_iter = iter(chat_loader)\n",
+    "        input_ids, response_mask = next(chat_iter)\n",
+    "\n",
+    "    input_ids = input_ids.to(device)\n",
+    "    response_mask = response_mask.to(device)\n",
+    "\n",
+    "    ft_optimizer.zero_grad()\n",
+    "\n",
+    "    with autocast('cuda', dtype=torch.float16):\n",
+    "        B, L = input_ids.shape\n",
+    "\n",
+    "        # Sample timestep\n",
+    "        t = model_unwrapped.noise_schedule.sample_t(B, device)\n",
+    "\n",
+    "        # Forward process: mask ONLY response tokens\n",
+    "        # Prompt tokens stay unmasked (model can always see them)\n",
+    "        alpha_t = model_unwrapped.noise_schedule.alpha(t)[:, None]  # [B, 1]\n",
+    "        mask_prob = 1.0 - alpha_t\n",
+    "        noise_mask = (torch.rand_like(input_ids.float()) < mask_prob) & response_mask\n",
+    "        z_t = torch.where(noise_mask, config.mask_token_id, input_ids)\n",
+    "\n",
+    "        # Forward pass\n",
+    "        hidden = model_unwrapped.forward_hidden(z_t, t)\n",
+    "\n",
+    "        # Loss only at masked response positions\n",
+    "        masked_hidden = hidden[noise_mask]\n",
+    "        masked_targets = input_ids[noise_mask]\n",
+    "\n",
+    "        if masked_hidden.shape[0] > 0:\n",
+    "            masked_logits = F.linear(masked_hidden, model_unwrapped.output_proj.weight)\n",
+    "            masked_logits[:, config.mask_token_id] = -1e4  # large negative, finite in fp16\n",
+    "            ce_loss = F.cross_entropy(masked_logits, masked_targets, reduction='none')\n",
+    "            weight = model_unwrapped.noise_schedule.loss_weight(t)\n",
+    "            weight_expanded = weight[:, None].expand(B, L)[noise_mask]\n",
+    "            loss = (ce_loss * weight_expanded).mean()\n",
+    "\n",
+    "            with torch.no_grad():\n",
+    "                acc = (masked_logits.argmax(-1) == masked_targets).float().mean().item()\n",
+    "        else:\n",
+    "            # Rare: nothing masked this step; zero loss that still has a grad path\n",
+    "            loss = hidden.sum() * 0.0\n",
+    "            acc = 1.0\n",
+    "\n",
+    "    ft_scaler.scale(loss).backward()\n",
+    "    ft_scaler.unscale_(ft_optimizer)\n",
+    "    grad_norm = nn.utils.clip_grad_norm_(model_unwrapped.parameters(), 1.0)\n",
+    "    ft_scaler.step(ft_optimizer)\n",
+    "    ft_scaler.update()\n",
+    "    ft_ema.update(model_unwrapped)\n",
+    "\n",
+    "    ft_losses.append(loss.item())\n",
+    "    ft_accuracies.append(acc)\n",
+    "\n",
+    "    if step % ft_config.log_every == 0:\n",
+    "        elapsed = time.time() - ft_start\n",
+    "        avg_loss = np.mean(ft_losses[-ft_config.log_every:])\n",
+    "        avg_acc = np.mean(ft_accuracies[-ft_config.log_every:])\n",
+    "        eta = (ft_config.ft_steps - step) / (step / elapsed) / 60\n",
+    "        print(f'Step {step:>5d}/{ft_config.ft_steps} | Loss: {avg_loss:.4f} | Acc: {avg_acc:.3f} | LR: {lr:.2e} | Grad: {grad_norm:.2f} | ETA: {eta:.1f}m')\n",
+    "\n",
+    "    # Generate chat samples\n",
+    "    if step % ft_config.sample_every == 0:\n",
+    "        print(f\"\\n{'='*60}\")\n",
+    "        print(f'Chat samples at step {step}:')\n",
+    "        ft_ema.apply_shadow(model_unwrapped)\n",
+    "        model_unwrapped.eval()\n",
+    "\n",
+    "        test_prompts = [\n",
+    "            'What is the moon?',\n",
+    "            'Write a short poem about the ocean.',\n",
+    "            'Explain what a computer is.',\n",
+    "            'What is the meaning of life?',\n",
+    "        ]\n",
+    "\n",
+    "        for prompt in test_prompts:\n",
+    "            # Tokenize prompt\n",
+    "            prompt_tokens = [USER_TOKEN] + tokenizer.encode(prompt)[:ft_config.max_prompt_len - 2] + [ASST_TOKEN]\n",
+    "            prompt_len = len(prompt_tokens)\n",
+    "            total_len = prompt_len + ft_config.max_response_len\n",
+    "\n",
+    "            # Start with prompt + all masks for response\n",
+    "            x = torch.full((1, total_len), config.mask_token_id, dtype=torch.long, device=device)\n",
+    "            x[0, :prompt_len] = torch.tensor(prompt_tokens, dtype=torch.long, device=device)\n",
+    "\n",
+    "            # Diffusion sampling \u2014 only unmask response positions\n",
+    "            timesteps = torch.linspace(1.0 - 1e-5, 1e-5, 128 + 1, device=device)\n",
+    "            for i in range(128):\n",
+    "                t_now = timesteps[i]\n",
+    "                t_next = timesteps[i + 1]\n",
+    "                alpha_now = model_unwrapped.noise_schedule.alpha(t_now)\n",
+    "                alpha_next = model_unwrapped.noise_schedule.alpha(t_next)\n",
+    "\n",
+    "                t_batch = torch.full((1,), t_now.item(), device=device)\n",
+    "                logits = model_unwrapped.forward_full(x, t_batch)\n",
+    "                probs = F.softmax(logits / 0.7, dim=-1)\n",
+    "\n",
+    "                unmask_prob = ((alpha_next - alpha_now) / (1.0 - alpha_now + 1e-8)).clamp(0, 1)\n",
+    "                is_masked = (x == config.mask_token_id)\n",
+    "                unmask = is_masked & (torch.rand_like(x.float()) < unmask_prob)\n",
+    "\n",
+    "                if unmask.any():\n",
+    "                    flat_probs = probs.reshape(-1, config.vocab_size)\n",
+    "                    sampled = torch.multinomial(flat_probs, 1).reshape(1, total_len)\n",
+    "                    x = torch.where(unmask, sampled, x)\n",
+    "\n",
+    "            # Final cleanup: fill any tokens still masked\n",
+    "            is_masked = (x == config.mask_token_id)\n",
+    "            if is_masked.any():\n",
+    "                t_batch = torch.full((1,), 1e-5, device=device)\n",
+    "                logits = model_unwrapped.forward_full(x, t_batch)\n",
+    "                probs = F.softmax(logits / 0.7, dim=-1)\n",
+    "                flat_probs = probs.reshape(-1, config.vocab_size)\n",
+    "                sampled = torch.multinomial(flat_probs, 1).reshape(1, total_len)\n",
+    "                x = torch.where(is_masked, sampled, x)\n",
+    "\n",
+    "            # Decode response only, cutting at the END token\n",
+    "            response_tokens = x[0, prompt_len:].cpu().tolist()\n",
+    "            if END_TOKEN in response_tokens:\n",
+    "                response_tokens = response_tokens[:response_tokens.index(END_TOKEN)]\n",
+    "            response = tokenizer.decode(response_tokens, skip_special_tokens=True)\n",
+    "            print(f'\\n  User: {prompt}')\n",
+    "            print(f'   Bot: {response}')\n",
+    "\n",
+    "        model_unwrapped.train()\n",
+    "        ft_ema.restore(model_unwrapped)\n",
+    "        print(f\"{'='*60}\\n\")\n",
+    "\n",
+    "# Save fine-tuned model\n",
+    "torch.save({\n",
+    "    'step': step,\n",
+    "    'model_state_dict': model_unwrapped.state_dict(),\n",
+    "    'ema_shadow': ft_ema.shadow,\n",
+    "    'config': config,\n",
+    "}, 'checkpoint_chat.pt')\n",
+    "print('Fine-tuning complete! Saved checkpoint_chat.pt')\n"
+   ]
+  },
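The loop calls `get_lr`, which is defined in Part 1 of the notebook and not shown in this diff. For readers starting at Part 2, a schedule matching the comment in the loop, linear warmup followed by cosine decay, would look roughly like this sketch; the signature is inferred from the call site and may differ from the Part 1 definition:

```python
import math

def get_lr(step: int, warmup: int, total_steps: int, peak_lr: float) -> float:
    """Linear warmup to peak_lr, then cosine decay to zero (illustrative)."""
    if step < warmup:
        return peak_lr * step / max(warmup, 1)
    progress = (step - warmup) / max(total_steps - warmup, 1)
    return 0.5 * peak_lr * (1.0 + math.cos(math.pi * progress))
```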
+  {
+   "cell_type": "markdown",
+   "id": "chat_header",
+   "metadata": {},
+   "source": [
+    "## Chat with your Diffusion LM\n",
+    "\n",
+    "Type a message and watch the response **materialize from noise** via the diffusion process."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "id": "chat_interface",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# ============================================================\n",
+    "# CHAT INTERFACE WITH DIFFUSION VISUALIZATION\n",
+    "# ============================================================\n",
+    "\n",
+    "from IPython.display import clear_output, display\n",
+    "import time as _time\n",
+    "\n",
+    "# Load EMA weights\n",
+    "ft_ema.apply_shadow(model_unwrapped)\n",
+    "model_unwrapped.eval()\n",
+    "\n",
+    "@torch.no_grad()\n",
+    "def chat(prompt: str, steps: int = 64, temperature: float = 0.7, show_diffusion: bool = True):\n",
+    "    \"\"\"Chat with the diffusion model.\n",
+    "\n",
+    "    Args:\n",
+    "        prompt: Your message\n",
+    "        steps: Denoising steps (more = better quality, slower)\n",
+    "        temperature: Sampling temperature (lower = more focused)\n",
+    "        show_diffusion: Show the step-by-step unmasking process\n",
+    "    \"\"\"\n",
+    "    # Tokenize prompt\n",
+    "    prompt_tokens = [USER_TOKEN] + tokenizer.encode(prompt)[:ft_config.max_prompt_len - 2] + [ASST_TOKEN]\n",
+    "    prompt_len = len(prompt_tokens)\n",
+    "    total_len = prompt_len + ft_config.max_response_len\n",
+    "\n",
+    "    # Initialize: prompt (visible) + all masks (response)\n",
+    "    x = torch.full((1, total_len), config.mask_token_id, dtype=torch.long, device=device)\n",
+    "    x[0, :prompt_len] = torch.tensor(prompt_tokens, dtype=torch.long, device=device)\n",
+    "\n",
+    "    timesteps_sched = torch.linspace(1.0 - 1e-5, 1e-5, steps + 1, device=device)\n",
+    "    # Clamp to steps - 1 so the final (p = 1.0) snapshot falls inside the loop\n",
+    "    snapshot_steps = set(min(int(steps * p), steps - 1) for p in [0, 0.1, 0.2, 0.35, 0.5, 0.7, 0.85, 1.0])\n",
+    "\n",
+    "    if show_diffusion:\n",
+    "        print(f'User: {prompt}')\n",
+    "        print(f'\\n--- Diffusion Process ({steps} steps) ---\\n')\n",
+    "\n",
+    "    for i in range(steps):\n",
+    "        t_now = timesteps_sched[i]\n",
+    "        t_next = timesteps_sched[i + 1]\n",
+    "        alpha_now = model_unwrapped.noise_schedule.alpha(t_now)\n",
+    "        alpha_next = model_unwrapped.noise_schedule.alpha(t_next)\n",
+    "\n",
+    "        t_batch = torch.full((1,), t_now.item(), device=device)\n",
+    "        logits = model_unwrapped.forward_full(x, t_batch)\n",
+    "        probs = F.softmax(logits / temperature, dim=-1)\n",
+    "\n",
+    "        unmask_prob = ((alpha_next - alpha_now) / (1.0 - alpha_now + 1e-8)).clamp(0, 1)\n",
+    "        is_masked = (x == config.mask_token_id)\n",
+    "        unmask = is_masked & (torch.rand_like(x.float()) < unmask_prob)\n",
+    "\n",
+    "        if unmask.any():\n",
+    "            flat_probs = probs.reshape(-1, config.vocab_size)\n",
+    "            sampled = torch.multinomial(flat_probs, 1).reshape(1, total_len)\n",
+    "            x = torch.where(unmask, sampled, x)\n",
+    "\n",
+    "        # Show snapshot\n",
+    "        if show_diffusion and i in snapshot_steps:\n",
+    "            resp_tokens = x[0, prompt_len:].cpu().tolist()\n",
+    "            text = ''\n",
+    "            for tok in resp_tokens:\n",
+    "                if tok == config.mask_token_id:\n",
+    "                    text += ' \\u2588'\n",
+    "                elif tok == END_TOKEN:\n",
+    "                    break\n",
+    "                else:\n",
+    "                    text += tokenizer.decode([tok])\n",
+    "            # Recompute from the updated x so the percentage reflects this step\n",
+    "            pct = (1 - (x[:, prompt_len:] == config.mask_token_id).float().mean()).item() * 100\n",
+    "            print(f'  [{pct:5.1f}% revealed] {text[:200]}')\n",
+    "\n",
+    "    # Final cleanup\n",
+    "    is_masked = (x == config.mask_token_id)\n",
+    "    if is_masked.any():\n",
+    "        t_batch = torch.full((1,), 1e-5, device=device)\n",
+    "        logits = model_unwrapped.forward_full(x, t_batch)\n",
+    "        probs = F.softmax(logits / temperature, dim=-1)\n",
+    "        flat_probs = probs.reshape(-1, config.vocab_size)\n",
+    "        sampled = torch.multinomial(flat_probs, 1).reshape(1, total_len)\n",
+    "        x = torch.where(is_masked, sampled, x)\n",
+    "\n",
+    "    # Decode final response\n",
+    "    response_tokens = x[0, prompt_len:].cpu().tolist()\n",
+    "    if END_TOKEN in response_tokens:\n",
+    "        response_tokens = response_tokens[:response_tokens.index(END_TOKEN)]\n",
+    "    response = tokenizer.decode(response_tokens, skip_special_tokens=True)\n",
+    "\n",
+    "    if show_diffusion:\n",
+    "        print('\\n--- Final ---')\n",
+    "        print(f'\\nUser: {prompt}')\n",
+    "        print(f'Bot: {response}')\n",
+    "    return response\n",
+    "\n",
+    "print('Chat function ready! Usage: chat(\"your message here\")')\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "id": "chat_examples",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Try it out!\n",
+    "chat('What is the moon?')\n",
+    "print('\\n' + '='*60 + '\\n')\n",
+    "chat('Write a short poem about the ocean.')\n",
+    "print('\\n' + '='*60 + '\\n')\n",
+    "chat('Explain what a computer is to a child.')\n",
+    "print('\\n' + '='*60 + '\\n')\n",
+    "chat('What are three things that make people happy?')\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "id": "ft_upload",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Upload fine-tuned model to HuggingFace\n",
+    "from huggingface_hub import HfApi\n",
+    "TOKEN = 'YOUR_HF_TOKEN_HERE'\n",
+    "api = HfApi(token=TOKEN)\n",
+    "\n",
+    "api.upload_file(\n",
+    "    path_or_fileobj='checkpoint_chat.pt',\n",
+    "    path_in_repo='checkpoint_chat.pt',\n",
+    "    repo_id='chipling/opium-mdlm',\n",
+    "    repo_type='model',\n",
+    "    token=TOKEN,\n",
+    ")\n",
+    "print('Chat model uploaded to HuggingFace!')\n"
+   ]
   }
  ],
  "metadata": {
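One practical note on the upload cell: `upload_file` assumes the target repo already exists. If it might not, creating it first is idempotent with `exist_ok=True` (a sketch reusing `TOKEN` and the repo id from the cell above):

```python
from huggingface_hub import HfApi

api = HfApi(token=TOKEN)  # TOKEN as set in the upload cell
api.create_repo(repo_id='chipling/opium-mdlm', repo_type='model', exist_ok=True)
```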