chipling committed on
Commit
c6e6873
·
verified ·
1 Parent(s): 830a02f

Upload main.ipynb with huggingface_hub

Browse files
Files changed (1) hide show
  1. main.ipynb +476 -0
main.ipynb CHANGED
@@ -2161,6 +2161,482 @@
2161
  "\n",
2162
  "visualize_diffusion(model, tokenizer)"
2163
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2164
  }
2165
  ],
2166
  "metadata": {
 
2161
  "\n",
2162
  "visualize_diffusion(model, tokenizer)"
2163
  ]
2164
+ },
2165
+ {
2166
+ "cell_type": "markdown",
2167
+ "id": "ft_header",
2168
+ "metadata": {},
2170
+ "source": [
2171
+ "---\n",
2172
+ "# Part 2: Fine-tuning for Chat\n",
2173
+ "\n",
2174
+ "Now we turn the pretrained MDLM into a **chatbot** using supervised fine-tuning on dialogue data.\n",
2175
+ "\n",
2176
+ "## How diffusion chat works\n",
2177
+ "1. Format: `<|user|> message <|assistant|> response <|end|>`\n",
2178
+ "2. **Training**: Mask only the response tokens \u2014 the user message stays visible as context\n",
2179
+ "3. **Inference**: User types a message \u2192 freeze those tokens \u2192 diffusion unmasks only the response\n",
2180
+ "4. **The cool part**: The response materializes all at once, not left-to-right"
2181
+ ]
2182
+ },
2183
+ {
2184
+ "cell_type": "code",
2185
+ "id": "ft_config",
2186
+ "metadata": {},
2187
+ "outputs": [],
2188
+ "source": [
2189
+ "# ============================================================\n",
2190
+ "# FINE-TUNING CONFIG\n",
2191
+ "# ============================================================\n",
2192
+ "\n",
2193
+ "@dataclass\n",
2194
+ "class FinetuneConfig:\n",
2195
+ " # Training\n",
2196
+ " ft_steps: int = 5000\n",
2197
+ " ft_batch_size: int = 16\n",
2198
+ " ft_lr: float = 5e-5 # Lower LR for fine-tuning\n",
2199
+ " ft_warmup: int = 200\n",
2200
+ " max_response_len: int = 128 # Max response length\n",
2201
+ " max_prompt_len: int = 64 # Max prompt length\n",
2202
+ " log_every: int = 50\n",
2203
+ " sample_every: int = 500\n",
2204
+ "\n",
2205
+ "ft_config = FinetuneConfig()\n",
2206
+ "\n",
2207
+ "# Add special tokens to tokenizer\n",
2208
+ "SPECIAL_TOKENS = {\n",
2209
+ " 'additional_special_tokens': ['<|user|>', '<|assistant|>', '<|end|>']\n",
2210
+ "}\n",
2211
+ "tokenizer.add_special_tokens(SPECIAL_TOKENS)\n",
2212
+ "\n",
2213
+ "USER_TOKEN = tokenizer.convert_tokens_to_ids('<|user|>')\n",
2214
+ "ASST_TOKEN = tokenizer.convert_tokens_to_ids('<|assistant|>')\n",
2215
+ "END_TOKEN = tokenizer.convert_tokens_to_ids('<|end|>')\n",
2216
+ "\n",
2217
+ "print(f'Special token IDs: USER={USER_TOKEN}, ASST={ASST_TOKEN}, END={END_TOKEN}')\n",
2218
+ "\n",
2219
+ "# Resize model embeddings to accommodate new tokens\n",
2220
+ "old_vocab = config.vocab_size\n",
2221
+ "new_vocab = len(tokenizer)\n",
2222
+ "if new_vocab > old_vocab:\n",
2223
+ " # Expand embedding and output projection\n",
2224
+ " old_emb = model_unwrapped.token_emb.weight.data\n",
2225
+ " model_unwrapped.token_emb = nn.Embedding(new_vocab, config.hidden_dim).to(device)\n",
2226
+ " model_unwrapped.token_emb.weight.data[:old_vocab] = old_emb\n",
2227
+ " # Re-tie output projection\n",
2228
+ " model_unwrapped.output_proj = nn.Linear(config.hidden_dim, new_vocab, bias=False).to(device)\n",
2229
+ " model_unwrapped.output_proj.weight = model_unwrapped.token_emb.weight\n",
2230
+ " # Update config\n",
2231
+ " config.vocab_size = new_vocab\n",
2232
+ " model_unwrapped.config.vocab_size = new_vocab\n",
2233
+ " print(f'Resized embeddings: {old_vocab} -> {new_vocab}')\n",
2234
+ "\n",
2235
+ "print(f'Fine-tune config ready')\n"
2236
+ ]
2237
+ },
2238
+ {
2239
+ "cell_type": "code",
2240
+ "id": "ft_dataset",
2241
+ "metadata": {},
2242
+ "outputs": [],
2243
+ "source": [
2244
+ "# ============================================================\n",
2245
+ "# DIALOGUE DATASET\n",
2246
+ "# ============================================================\n",
2247
+ "\n",
2248
+ "from datasets import load_dataset\n",
2249
+ "\n",
2250
+ "# Using Alpaca-cleaned: simple instruction-response pairs\n",
2251
+ "print('Loading Alpaca dataset...')\n",
2252
+ "alpaca = load_dataset('yahma/alpaca-cleaned', split='train')\n",
2253
+ "print(f'Loaded {len(alpaca)} examples')\n",
2254
+ "\n",
2255
+ "class ChatDataset(torch.utils.data.Dataset):\n",
2256
+ " \"\"\"Format dialogue as: <|user|> instruction <|assistant|> response <|end|>\n",
2257
+ " \n",
2258
+ " Returns:\n",
2259
+ " input_ids: full sequence token ids\n",
2260
+ " response_mask: bool mask, True for response tokens (what we train on)\n",
2261
+ " \"\"\"\n",
2262
+ " def __init__(self, dataset, tokenizer, max_prompt_len, max_response_len):\n",
2263
+ " self.data = dataset\n",
2264
+ " self.tokenizer = tokenizer\n",
2265
+ " self.max_prompt_len = max_prompt_len\n",
2266
+ " self.max_response_len = max_response_len\n",
2267
+ " self.total_len = max_prompt_len + max_response_len\n",
2268
+ " \n",
2269
+ " def __len__(self):\n",
2270
+ " return len(self.data)\n",
2271
+ " \n",
2272
+ " def __getitem__(self, idx):\n",
2273
+ " item = self.data[idx]\n",
2274
+ " \n",
2275
+ " # Build prompt\n",
2276
+ " instruction = item['instruction']\n",
2277
+ " if item.get('input', ''):\n",
2278
+ " instruction = instruction + ' ' + item['input']\n",
2279
+ " response = item['output']\n",
2280
+ " \n",
2281
+ " # Tokenize separately\n",
2282
+ " prompt_tokens = [USER_TOKEN] + self.tokenizer.encode(instruction)[:self.max_prompt_len - 2] + [ASST_TOKEN]\n",
2283
+ " response_tokens = self.tokenizer.encode(response)[:self.max_response_len - 1] + [END_TOKEN]\n",
2284
+ " \n",
2285
+ " # Combine\n",
2286
+ " input_ids = prompt_tokens + response_tokens\n",
2287
+ " prompt_len = len(prompt_tokens)\n",
2288
+ " \n",
2289
+ " # Pad or truncate to fixed length\n",
2290
+ " if len(input_ids) < self.total_len:\n",
2291
+ " pad_len = self.total_len - len(input_ids)\n",
2292
+ " input_ids = input_ids + [self.tokenizer.eos_token_id] * pad_len\n",
2293
+ " else:\n",
2294
+ " input_ids = input_ids[:self.total_len]\n",
2295
+ " \n",
2296
+ " input_ids = torch.tensor(input_ids, dtype=torch.long)\n",
2297
+ " \n",
2298
+ " # Response mask: True for response positions only\n",
2299
+ " response_mask = torch.zeros(self.total_len, dtype=torch.bool)\n",
2300
+ " response_mask[prompt_len:prompt_len + len(response_tokens)] = True\n",
2301
+ " \n",
2302
+ " return input_ids, response_mask\n",
2303
+ "\n",
2304
+ "chat_dataset = ChatDataset(alpaca, tokenizer, ft_config.max_prompt_len, ft_config.max_response_len)\n",
2305
+ "chat_loader = DataLoader(chat_dataset, batch_size=ft_config.ft_batch_size, shuffle=True, num_workers=2, pin_memory=True)\n",
2306
+ "\n",
2307
+ "# Test\n",
2308
+ "test_ids, test_mask = chat_dataset[0]\n",
2309
+ "print(f'\\nExample:')\n",
2310
+ "print(f'Full sequence: {tokenizer.decode(test_ids[:40])}...')\n",
2311
+ "print(f'Response tokens: {test_mask.sum().item()} response positions out of {len(test_ids)}')\n",
2312
+ "print(f'\\nPrompt part: {tokenizer.decode(test_ids[~test_mask][:30])}')\n",
2313
+ "print(f'Response part: {tokenizer.decode(test_ids[test_mask][:30])}')\n"
2314
+ ]
2315
+ },
2316
+ {
2317
+ "cell_type": "code",
2318
+ "id": "ft_train",
2319
+ "metadata": {},
2320
+ "outputs": [],
2321
+ "source": [
2322
+ "# ============================================================\n",
2323
+ "# FINE-TUNING LOOP\n",
2324
+ "# ============================================================\n",
2325
+ "\n",
2326
+ "# Fresh optimizer with lower LR\n",
2327
+ "ft_optimizer = torch.optim.AdamW(\n",
2328
+ " model_unwrapped.parameters(),\n",
2329
+ " lr=ft_config.ft_lr,\n",
2330
+ " betas=(0.9, 0.98),\n",
2331
+ " weight_decay=0.01,\n",
2332
+ ")\n",
2333
+ "ft_scaler = GradScaler('cuda')\n",
2334
+ "ft_ema = EMA(model_unwrapped, decay=0.999) # Faster EMA for fine-tuning\n",
2335
+ "\n",
2336
+ "model_unwrapped.train()\n",
2337
+ "ft_losses = []\n",
2338
+ "ft_accuracies = []\n",
2339
+ "ft_start = time.time()\n",
2340
+ "chat_iter = iter(chat_loader)\n",
2341
+ "\n",
2342
+ "print(f'Fine-tuning for {ft_config.ft_steps} steps...')\n",
2343
+ "print(f'Batch size: {ft_config.ft_batch_size}')\n",
2344
+ "print('=' * 60)\n",
2345
+ "\n",
2346
+ "for step in range(1, ft_config.ft_steps + 1):\n",
2347
+ " # LR schedule: linear warmup + cosine decay\n",
2348
+ " lr = get_lr(step, ft_config.ft_warmup, ft_config.ft_steps, ft_config.ft_lr)\n",
2349
+ " for pg in ft_optimizer.param_groups:\n",
2350
+ " pg['lr'] = lr\n",
2351
+ "\n",
2352
+ " try:\n",
2353
+ " input_ids, response_mask = next(chat_iter)\n",
2354
+ " except StopIteration:\n",
2355
+ " chat_iter = iter(chat_loader)\n",
2356
+ " input_ids, response_mask = next(chat_iter)\n",
2357
+ "\n",
2358
+ " input_ids = input_ids.to(device)\n",
2359
+ " response_mask = response_mask.to(device)\n",
2360
+ "\n",
2361
+ " ft_optimizer.zero_grad()\n",
2362
+ "\n",
2363
+ " with autocast('cuda', dtype=torch.float16):\n",
2364
+ " B, L = input_ids.shape\n",
2365
+ "\n",
2366
+ " # Sample timestep\n",
2367
+ " t = model_unwrapped.noise_schedule.sample_t(B, device)\n",
2368
+ "\n",
2369
+ " # Forward process: mask ONLY response tokens\n",
2370
+ " # Prompt tokens stay unmasked (model can always see them)\n",
2371
+ " alpha_t = model_unwrapped.noise_schedule.alpha(t)[:, None] # [B, 1]\n",
2372
+ " mask_prob = 1.0 - alpha_t\n",
2373
+ " noise_mask = (torch.rand_like(input_ids.float()) < mask_prob) & response_mask\n",
2374
+ " z_t = torch.where(noise_mask, config.mask_token_id, input_ids)\n",
2375
+ "\n",
2376
+ " # Forward pass\n",
2377
+ " hidden = model_unwrapped.forward_hidden(z_t, t)\n",
2378
+ "\n",
2379
+ " # Loss only at masked response positions\n",
2380
+ " masked_hidden = hidden[noise_mask]\n",
2381
+ " masked_targets = input_ids[noise_mask]\n",
2382
+ "\n",
2383
+ " if masked_hidden.shape[0] > 0:\n",
2384
+ " masked_logits = F.linear(masked_hidden, model_unwrapped.output_proj.weight)\n",
2385
+ " masked_logits[:, config.mask_token_id] = -1e9\n",
2386
+ " ce_loss = F.cross_entropy(masked_logits, masked_targets, reduction='none')\n",
2387
+ " weight = model_unwrapped.noise_schedule.loss_weight(t)\n",
2388
+ " weight_expanded = weight[:, None].expand(B, L)[noise_mask]\n",
2389
+ " loss = (ce_loss * weight_expanded).mean()\n",
2390
+ "\n",
2391
+ " with torch.no_grad():\n",
2392
+ " acc = (masked_logits.argmax(-1) == masked_targets).float().mean().item()\n",
2393
+ " else:\n",
2394
+ " loss = hidden.sum() * 0.0 # zero loss attached to the graph so backward()/GradScaler still work\n",
2395
+ " acc = 1.0\n",
2396
+ "\n",
2397
+ " ft_scaler.scale(loss).backward()\n",
2398
+ " ft_scaler.unscale_(ft_optimizer)\n",
2399
+ " grad_norm = nn.utils.clip_grad_norm_(model_unwrapped.parameters(), 1.0)\n",
2400
+ " ft_scaler.step(ft_optimizer)\n",
2401
+ " ft_scaler.update()\n",
2402
+ " ft_ema.update(model_unwrapped)\n",
2403
+ "\n",
2404
+ " ft_losses.append(loss.item())\n",
2405
+ " ft_accuracies.append(acc)\n",
2406
+ "\n",
2407
+ " if step % ft_config.log_every == 0:\n",
2408
+ " elapsed = time.time() - ft_start\n",
2409
+ " avg_loss = np.mean(ft_losses[-ft_config.log_every:])\n",
2410
+ " avg_acc = np.mean(ft_accuracies[-ft_config.log_every:])\n",
2411
+ " eta = (ft_config.ft_steps - step) / (step / elapsed) / 60\n",
2412
+ " print(f'Step {step:>5d}/{ft_config.ft_steps} | Loss: {avg_loss:.4f} | Acc: {avg_acc:.3f} | LR: {lr:.2e} | Grad: {grad_norm:.2f} | ETA: {eta:.1f}m')\n",
2413
+ "\n",
2414
+ " # Generate chat samples\n",
2415
+ " if step % ft_config.sample_every == 0:\n",
2416
+ " print(f\"\\n{'='*60}\")\n",
2417
+ " print(f'Chat samples at step {step}:')\n",
2418
+ " ft_ema.apply_shadow(model_unwrapped)\n",
2419
+ " model_unwrapped.eval()\n",
2420
+ "\n",
2421
+ " test_prompts = [\n",
2422
+ " 'What is the moon?',\n",
2423
+ " 'Write a short poem about the ocean.',\n",
2424
+ " 'Explain what a computer is.',\n",
2425
+ " 'What is the meaning of life?',\n",
2426
+ " ]\n",
2427
+ "\n",
2428
+ " for prompt in test_prompts:\n",
2429
+ " # Tokenize prompt\n",
2430
+ " prompt_tokens = [USER_TOKEN] + tokenizer.encode(prompt)[:ft_config.max_prompt_len - 2] + [ASST_TOKEN]\n",
2431
+ " prompt_len = len(prompt_tokens)\n",
2432
+ " total_len = prompt_len + ft_config.max_response_len\n",
2433
+ "\n",
2434
+ " # Start with prompt + all masks for response\n",
2435
+ " x = torch.full((1, total_len), config.mask_token_id, dtype=torch.long, device=device)\n",
2436
+ " x[0, :prompt_len] = torch.tensor(prompt_tokens, dtype=torch.long, device=device)\n",
2437
+ "\n",
2438
+ " # Diffusion sampling \u2014 only unmask response positions\n",
2439
+ " timesteps = torch.linspace(1.0 - 1e-5, 1e-5, 128 + 1, device=device)\n",
2440
+ " for i in range(128):\n",
2441
+ " t_now = timesteps[i]\n",
2442
+ " t_next = timesteps[i + 1]\n",
2443
+ " alpha_now = model_unwrapped.noise_schedule.alpha(t_now)\n",
2444
+ " alpha_next = model_unwrapped.noise_schedule.alpha(t_next)\n",
2445
+ "\n",
2446
+ " t_batch = torch.full((1,), t_now.item(), device=device)\n",
2447
+ " logits = model_unwrapped.forward_full(x, t_batch)\n",
2448
+ " probs = F.softmax(logits / 0.7, dim=-1)\n",
2449
+ "\n",
2450
+ " unmask_prob = ((alpha_next - alpha_now) / (1.0 - alpha_now + 1e-8)).clamp(0, 1)\n",
2451
+ " is_masked = (x == config.mask_token_id)\n",
2452
+ " unmask = is_masked & (torch.rand_like(x.float()) < unmask_prob)\n",
2453
+ "\n",
2454
+ " if unmask.any():\n",
2455
+ " flat_probs = probs.reshape(-1, config.vocab_size)\n",
2456
+ " sampled = torch.multinomial(flat_probs, 1).reshape(1, total_len)\n",
2457
+ " x = torch.where(unmask, sampled, x)\n",
2458
+ "\n",
2459
+ " # Final cleanup\n",
2460
+ " is_masked = (x == config.mask_token_id)\n",
2461
+ " if is_masked.any():\n",
2462
+ " t_batch = torch.full((1,), 1e-5, device=device)\n",
2463
+ " logits = model_unwrapped.forward_full(x, t_batch)\n",
2464
+ " probs = F.softmax(logits / 0.7, dim=-1)\n",
2465
+ " flat_probs = probs.reshape(-1, config.vocab_size)\n",
2466
+ " sampled = torch.multinomial(flat_probs, 1).reshape(1, total_len)\n",
2467
+ " x = torch.where(is_masked, sampled, x)\n",
2468
+ "\n",
2469
+ " # Decode response only\n",
2470
+ " response_tokens = x[0, prompt_len:].cpu().tolist()\n",
2471
+ " # Cut at END token\n",
2472
+ " if END_TOKEN in response_tokens:\n",
2473
+ " response_tokens = response_tokens[:response_tokens.index(END_TOKEN)]\n",
2474
+ " response = tokenizer.decode(response_tokens, skip_special_tokens=True)\n",
2475
+ " print(f'\\n User: {prompt}')\n",
2476
+ " print(f' Bot: {response}')\n",
2477
+ "\n",
2478
+ " model_unwrapped.train()\n",
2479
+ " ft_ema.restore(model_unwrapped)\n",
2480
+ " print(f\"{'='*60}\\n\")\n",
2481
+ "\n",
2482
+ "# Save fine-tuned model\n",
2483
+ "torch.save({\n",
2484
+ " 'step': step,\n",
2485
+ " 'model_state_dict': model_unwrapped.state_dict(),\n",
2486
+ " 'ema_shadow': ft_ema.shadow,\n",
2487
+ " 'config': config,\n",
2488
+ "}, 'checkpoint_chat.pt')\n",
2489
+ "print('Fine-tuning complete! Saved checkpoint_chat.pt')\n"
2490
+ ]
2491
+ },
2492
+ {
2493
+ "cell_type": "markdown",
2494
+ "id": "chat_header",
2495
+ "metadata": {},
2497
+ "source": [
2498
+ "## Chat with your Diffusion LM\n",
2499
+ "\n",
2500
+ "Type a message and watch the response **materialize from noise** via the diffusion process."
2501
+ ]
2502
+ },
2503
+ {
2504
+ "cell_type": "code",
2505
+ "id": "chat_interface",
2506
+ "metadata": {},
2507
+ "outputs": [],
2508
+ "source": [
2509
+ "# ============================================================\n",
2510
+ "# CHAT INTERFACE WITH DIFFUSION VISUALIZATION\n",
2511
+ "# ============================================================\n",
2512
+ "\n",
2513
+ "from IPython.display import clear_output, display\n",
2514
+ "import time as _time\n",
2515
+ "\n",
2516
+ "# Load EMA weights\n",
2517
+ "ft_ema.apply_shadow(model_unwrapped)\n",
2518
+ "model_unwrapped.eval()\n",
2519
+ "\n",
2520
+ "@torch.no_grad()\n",
2521
+ "def chat(prompt: str, steps: int = 64, temperature: float = 0.7, show_diffusion: bool = True):\n",
2522
+ " \"\"\"Chat with the diffusion model.\n",
2523
+ " \n",
2524
+ " Args:\n",
2525
+ " prompt: Your message\n",
2526
+ " steps: Denoising steps (more = better quality, slower)\n",
2527
+ " temperature: Sampling temperature (lower = more focused)\n",
2528
+ " show_diffusion: Show the step-by-step unmasking process\n",
2529
+ " \"\"\"\n",
2530
+ " # Tokenize prompt\n",
2531
+ " prompt_tokens = [USER_TOKEN] + tokenizer.encode(prompt)[:ft_config.max_prompt_len - 2] + [ASST_TOKEN]\n",
2532
+ " prompt_len = len(prompt_tokens)\n",
2533
+ " total_len = prompt_len + ft_config.max_response_len\n",
2534
+ "\n",
2535
+ " # Initialize: prompt (visible) + all masks (response)\n",
2536
+ " x = torch.full((1, total_len), config.mask_token_id, dtype=torch.long, device=device)\n",
2537
+ " x[0, :prompt_len] = torch.tensor(prompt_tokens, dtype=torch.long, device=device)\n",
2538
+ "\n",
2539
+ " timesteps_sched = torch.linspace(1.0 - 1e-5, 1e-5, steps + 1, device=device)\n",
2540
+ " snapshot_steps = set([int(steps * p) for p in [0, 0.1, 0.2, 0.35, 0.5, 0.7, 0.85, 1.0]])\n",
2541
+ "\n",
2542
+ " if show_diffusion:\n",
2543
+ " print(f'User: {prompt}')\n",
2544
+ " print(f'\\n--- Diffusion Process ({steps} steps) ---\\n')\n",
2545
+ "\n",
2546
+ " for i in range(steps):\n",
2547
+ " t_now = timesteps_sched[i]\n",
2548
+ " t_next = timesteps_sched[i + 1]\n",
2549
+ " alpha_now = model_unwrapped.noise_schedule.alpha(t_now)\n",
2550
+ " alpha_next = model_unwrapped.noise_schedule.alpha(t_next)\n",
2551
+ "\n",
2552
+ " t_batch = torch.full((1,), t_now.item(), device=device)\n",
2553
+ " logits = model_unwrapped.forward_full(x, t_batch)\n",
2554
+ " probs = F.softmax(logits / temperature, dim=-1)\n",
2555
+ "\n",
2556
+ " unmask_prob = ((alpha_next - alpha_now) / (1.0 - alpha_now + 1e-8)).clamp(0, 1)\n",
2557
+ " is_masked = (x == config.mask_token_id)\n",
2558
+ " unmask = is_masked & (torch.rand_like(x.float()) < unmask_prob)\n",
2559
+ "\n",
2560
+ " if unmask.any():\n",
2561
+ " flat_probs = probs.reshape(-1, config.vocab_size)\n",
2562
+ " sampled = torch.multinomial(flat_probs, 1).reshape(1, total_len)\n",
2563
+ " x = torch.where(unmask, sampled, x)\n",
2564
+ "\n",
2565
+ " # Show snapshot\n",
2566
+ " if show_diffusion and i in snapshot_steps:\n",
2567
+ " resp_tokens = x[0, prompt_len:].cpu().tolist()\n",
2568
+ " text = ''\n",
2569
+ " for tok in resp_tokens:\n",
2570
+ " if tok == config.mask_token_id:\n",
2571
+ " text += ' \\u2588'\n",
2572
+ " elif tok == END_TOKEN:\n",
2573
+ " break\n",
2574
+ " else:\n",
2575
+ " text += tokenizer.decode([tok])\n",
2576
+ " pct = (1 - is_masked[:, prompt_len:].float().mean()).item() * 100\n",
2577
+ " print(f' [{pct:5.1f}% revealed] {text[:200]}')\n",
2578
+ "\n",
2579
+ " # Final cleanup\n",
2580
+ " is_masked = (x == config.mask_token_id)\n",
2581
+ " if is_masked.any():\n",
2582
+ " t_batch = torch.full((1,), 1e-5, device=device)\n",
2583
+ " logits = model_unwrapped.forward_full(x, t_batch)\n",
2584
+ " probs = F.softmax(logits / temperature, dim=-1)\n",
2585
+ " flat_probs = probs.reshape(-1, config.vocab_size)\n",
2586
+ " sampled = torch.multinomial(flat_probs, 1).reshape(1, total_len)\n",
2587
+ " x = torch.where(is_masked, sampled, x)\n",
2588
+ "\n",
2589
+ " # Decode final response\n",
2590
+ " response_tokens = x[0, prompt_len:].cpu().tolist()\n",
2591
+ " if END_TOKEN in response_tokens:\n",
2592
+ " response_tokens = response_tokens[:response_tokens.index(END_TOKEN)]\n",
2593
+ " response = tokenizer.decode(response_tokens, skip_special_tokens=True)\n",
2594
+ "\n",
2595
+ " if show_diffusion:\n",
2596
+ " print(f'\\n--- Final ---')\n",
2597
+ " print(f'\\nUser: {prompt}')\n",
2598
+ " print(f'Bot: {response}')\n",
2599
+ " return response\n",
2600
+ "\n",
2601
+ "print('Chat function ready! Usage: chat(\"your message here\")')\n"
2602
+ ]
2603
+ },
2604
+ {
2605
+ "cell_type": "code",
2606
+ "id": "chat_examples",
2607
+ "metadata": {},
2608
+ "outputs": [],
2609
+ "source": [
2610
+ "# Try it out!\n",
2611
+ "chat('What is the moon?')\n",
2612
+ "print('\\n' + '='*60 + '\\n')\n",
2613
+ "chat('Write a short poem about the ocean.')\n",
2614
+ "print('\\n' + '='*60 + '\\n')\n",
2615
+ "chat('Explain what a computer is to a child.')\n",
2616
+ "print('\\n' + '='*60 + '\\n')\n",
2617
+ "chat('What are three things that make people happy?')\n"
2618
+ ]
2619
+ },
2620
+ {
2621
+ "cell_type": "code",
2622
+ "id": "ft_upload",
2623
+ "metadata": {},
2624
+ "outputs": [],
2625
+ "source": [
2626
+ "# Upload fine-tuned model to HuggingFace\n",
2627
+ "from huggingface_hub import HfApi\n",
2628
+ "TOKEN = 'YOUR_HF_TOKEN_HERE'\n",
2629
+ "api = HfApi(token=TOKEN)\n",
2630
+ "\n",
2631
+ "api.upload_file(\n",
2632
+ " path_or_fileobj='checkpoint_chat.pt',\n",
2633
+ " path_in_repo='checkpoint_chat.pt',\n",
2634
+ " repo_id='chipling/opium-mdlm',\n",
2635
+ " repo_type='model',\n",
2636
+ " token=TOKEN,\n",
2637
+ ")\n",
2638
+ "print('Chat model uploaded to HuggingFace!')\n"
2639
+ ]
2640
  }
2641
  ],
2642
  "metadata": {