rtferraz committed
Commit 65ecf7e · verified · 1 Parent(s): 2c3ddfa

Fix notebook: total_mem → total_memory, add hub_model_id push, add wandb logging support
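For context on the three changes: `total_memory` is the actual attribute on PyTorch's CUDA device properties (`total_mem` does not exist), Weights & Biases reporting goes through the standard `report_to` mechanism, and the checkpoint is pushed under a Hub model id. A minimal standalone sketch of the same pattern using plain `transformers.TrainingArguments` — the notebook routes these options through its own training helper, so the exact argument mapping below is an assumption, not the helper's signature:

import os
import torch
from transformers import TrainingArguments

# Correct VRAM query: device properties expose total_memory (bytes), not total_mem
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}, VRAM: {props.total_memory / 1e9:.1f}GB")

# wandb groups runs by project; the HF Trainer's wandb callback reads this env var
os.environ["WANDB_PROJECT"] = "domainTokenizer"

args = TrainingArguments(
    output_dir="./finance_pretrain_checkpoints",
    report_to="wandb",                           # stream loss/metrics to wandb
    run_name="finance-pretrain-24m-3ep",         # run name matching the diff below
    push_to_hub=True,                            # assumption: the helper enables push when hub_model_id is set
    hub_model_id="rtferraz/finance-domain-24m",
)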

Files changed (1)
  1. notebooks/01_finance_pretrain.ipynb +75 -138
notebooks/01_finance_pretrain.ipynb CHANGED
@@ -38,7 +38,7 @@
38
  "outputs": [],
39
  "source": [
40
  "# Uncomment and run once to install dependencies:\n",
41
- "# !pip install datasets transformers torch accelerate tokenizers numpy pandas matplotlib scikit-learn"
42
  ]
43
  },
44
  {
@@ -75,7 +75,23 @@
75
  "logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')\n",
76
  "print(f'torch: {torch.__version__}, CUDA: {torch.cuda.is_available()}')\n",
77
  "if torch.cuda.is_available():\n",
78
- " print(f'GPU: {torch.cuda.get_device_name(0)}, VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f}GB')"
79
  ]
80
  },
81
  {
@@ -143,7 +159,6 @@
143
  "metadata": {},
144
  "outputs": [],
145
  "source": [
146
- "# Events per user distribution\n",
147
  "events_per_user = df.groupby('sender_account').size()\n",
148
  "print(f\"Events per user: min={events_per_user.min()}, max={events_per_user.max()}, \"\n",
149
  " f\"mean={events_per_user.mean():.1f}, median={events_per_user.median():.1f}\")\n",
@@ -151,23 +166,13 @@
151
  "print(f\"Users with 10+ events: {(events_per_user >= 10).sum():,}\")\n",
152
  "\n",
153
  "fig, axes = plt.subplots(1, 3, figsize=(15, 4))\n",
154
- "\n",
155
  "axes[0].hist(np.log10(df['amount_ngn'].clip(lower=1)), bins=50, edgecolor='black', alpha=0.7)\n",
156
- "axes[0].set_xlabel('log10(Amount NGN)')\n",
157
- "axes[0].set_ylabel('Count')\n",
158
- "axes[0].set_title('Amount Distribution (log scale)')\n",
159
- "\n",
160
  "axes[1].hist(events_per_user.clip(upper=50), bins=50, edgecolor='black', alpha=0.7)\n",
161
- "axes[1].set_xlabel('Events per User')\n",
162
- "axes[1].set_ylabel('Count')\n",
163
- "axes[1].set_title('Events per User')\n",
164
- "\n",
165
  "df['transaction_type'].value_counts().head(10).plot(kind='barh', ax=axes[2])\n",
166
- "axes[2].set_xlabel('Count')\n",
167
- "axes[2].set_title('Transaction Types')\n",
168
- "\n",
169
- "plt.tight_layout()\n",
170
- "plt.show()"
171
  ]
172
  },
173
  {
@@ -190,21 +195,14 @@
190
  "outputs": [],
191
  "source": [
192
  "def row_to_event(row):\n",
193
- " \"\"\"Convert a DataFrame row to a FINANCE_SCHEMA event dict.\"\"\"\n",
194
  " dt = datetime.strptime(row['timestamp'][:19], '%Y-%m-%d %H:%M:%S')\n",
195
  " desc = f\"{row['merchant_category']} {row['transaction_type']}\"\n",
196
  " amt = row['amount_ngn']\n",
197
  " if row['transaction_type'] == 'withdrawal':\n",
198
  " amt = -abs(amt)\n",
199
- " return {\n",
200
- " 'amount_sign': amt,\n",
201
- " 'amount': amt,\n",
202
- " 'timestamp': dt,\n",
203
- " 'description': desc,\n",
204
- " }\n",
205
- "\n",
206
- "sample = row_to_event(df.iloc[0])\n",
207
- "print(f'Sample event: {sample}')"
208
  ]
209
  },
210
  {
@@ -215,12 +213,9 @@
215
  "source": [
216
  "%%time\n",
217
  "MIN_EVENTS = 5\n",
218
- "MAX_EVENTS = 500 # cap to prevent very long sequences from dominating\n",
219
- "\n",
220
- "user_sequences = []\n",
221
- "user_ids = []\n",
222
- "user_fraud_labels = []\n",
223
  "\n",
 
224
  "for sender, group in df.sort_values('timestamp').groupby('sender_account'):\n",
225
  " if len(group) < MIN_EVENTS:\n",
226
  " continue\n",
@@ -231,9 +226,7 @@
231
  "\n",
232
  "print(f'Users with {MIN_EVENTS}+ events: {len(user_sequences):,}')\n",
233
  "print(f'Total events: {sum(len(s) for s in user_sequences):,}')\n",
234
- "print(f'Events per user: min={min(len(s) for s in user_sequences)}, '\n",
235
- " f'max={max(len(s) for s in user_sequences)}, '\n",
236
- " f'mean={np.mean([len(s) for s in user_sequences]):.1f}')\n",
237
  "print(f'Fraud rate (user-level): {np.mean(user_fraud_labels)*100:.2f}%')"
238
  ]
239
  },
@@ -243,8 +236,7 @@
243
  "source": [
244
  "## Step 4 — Build Domain Tokenizer\n",
245
  "\n",
246
- "Hybrid vocabulary: 97 special tokens (sign + amount bins + calendar) + BPE for descriptions.\n",
247
- "Following Nubank nuFormer's tokenization approach."
248
  ]
249
  },
250
  {
@@ -262,16 +254,10 @@
262
  "text_corpus = [e['description'] for e in all_events]\n",
263
  "unique_descs = sorted(set(text_corpus))\n",
264
  "print(f'Unique descriptions: {len(unique_descs)}')\n",
265
- "for d in unique_descs[:10]:\n",
266
- " print(f\" '{d}'\")\n",
267
- "if len(unique_descs) > 10:\n",
268
- " print(f' ... and {len(unique_descs) - 10} more')\n",
269
- "\n",
270
- "hf_tokenizer = builder.build(\n",
271
- " text_corpus=text_corpus,\n",
272
- " bpe_vocab_size=2000,\n",
273
- ")\n",
274
  "\n",
 
275
  "print(f'\\nVocab size: {hf_tokenizer.vocab_size}')\n",
276
  "print(f'Stats: {builder.get_stats()}')"
277
  ]
@@ -282,16 +268,12 @@
282
  "metadata": {},
283
  "outputs": [],
284
  "source": [
285
- "# Inspect tokenized output\n",
286
  "print('--- Sample event tokenized ---')\n",
287
- "sample_tokens = builder.tokenize_event(user_sequences[0][0])\n",
288
- "for i, t in enumerate(sample_tokens):\n",
289
- " print(f' [{i}] {t}')\n",
290
  "\n",
291
  "print(f'\\n--- First user, first 3 events ---')\n",
292
  "seq_tokens = builder.tokenize_sequence(user_sequences[0][:3])\n",
293
- "for i, t in enumerate(seq_tokens):\n",
294
- " print(f' [{i:3d}] {t}')\n",
295
  "\n",
296
  "seq_ids = hf_tokenizer(' '.join(seq_tokens), add_special_tokens=False)['input_ids']\n",
297
  "unk_id = hf_tokenizer.unk_token_id\n",
@@ -305,8 +287,7 @@
305
  "source": [
306
  "## Step 5 — Pack into CLM Training Dataset\n",
307
  "\n",
308
- "Sequence packing (run_clm.py pattern): concatenate all user sequences, split into fixed-length blocks.\n",
309
- "100% token utilization, zero padding waste."
310
  ]
311
  },
312
  {
@@ -316,13 +297,8 @@
316
  "outputs": [],
317
  "source": [
318
  "%%time\n",
319
- "BLOCK_SIZE = 512 # Nubank uses 2048; 512 for faster iteration\n",
320
- "\n",
321
- "dataset = prepare_clm_dataset(\n",
322
- " user_sequences, builder, hf_tokenizer,\n",
323
- " block_size=BLOCK_SIZE,\n",
324
- ")\n",
325
- "\n",
326
  "print(f'Packed: {len(dataset):,} blocks x {BLOCK_SIZE} = {len(dataset)*BLOCK_SIZE:,} training tokens')"
327
  ]
328
  },
@@ -332,25 +308,15 @@
332
  "metadata": {},
333
  "outputs": [],
334
  "source": [
335
- "# Decode a sample block to verify it looks right\n",
336
- "sample_block = dataset[0]['input_ids']\n",
337
  "print(f'Sample block decoded (first 60 tokens):')\n",
338
- "print(hf_tokenizer.decode(sample_block[:60]))\n",
339
  "\n",
340
- "# Token frequency analysis\n",
341
  "all_ids = [i for row in dataset for i in row['input_ids']]\n",
342
  "counts = Counter(all_ids)\n",
343
- "unk_pct = counts.get(unk_id, 0) / len(all_ids) * 100\n",
344
- "\n",
345
- "print(f'\\nTotal tokens: {len(all_ids):,}')\n",
346
- "print(f'Unique token IDs used: {len(counts)}/{hf_tokenizer.vocab_size}')\n",
347
- "print(f'UNK tokens: {counts.get(unk_id, 0):,} ({unk_pct:.2f}%)')\n",
348
- "\n",
349
  "print(f'\\nTop 20 tokens:')\n",
350
  "for tid, count in counts.most_common(20):\n",
351
- " tok_str = hf_tokenizer.decode([tid]).strip() or '(space/control)'\n",
352
- " pct = count / len(all_ids) * 100\n",
353
- " print(f' {tid:5d} {count:8,} ({pct:5.1f}%) {tok_str}')"
354
  ]
355
  },
356
  {
@@ -359,11 +325,12 @@
359
  "source": [
360
  "## Step 6 — Pre-Train 24M DomainTransformer\n",
361
  "\n",
362
- "Architecture (Nubank nuFormer):\n",
363
  "- GPT-style causal decoder, NoPE (no positional encoding)\n",
364
  "- 24M preset: d=512, 6 layers, 8 heads, FFN=2048\n",
365
- "- Cosine LR schedule with warmup, AdamW optimizer\n",
366
- "- CLM objective (next token prediction on transaction sequences)"
 
367
  ]
368
  },
369
  {
@@ -374,11 +341,8 @@
374
  "source": [
375
  "config = DomainTransformerConfig.from_preset('24m', vocab_size=hf_tokenizer.vocab_size)\n",
376
  "model = DomainTransformerForCausalLM(config)\n",
377
- "\n",
378
  "n_params = sum(p.numel() for p in model.parameters())\n",
379
- "print(f'Model: {n_params:,} parameters')\n",
380
- "print(f'Config: d={config.hidden_size}, L={config.num_hidden_layers}, H={config.num_attention_heads}')\n",
381
- "print(f'VRAM estimate: ~{n_params * 2 / 1e9:.1f}GB (bf16 training with optimizer states ~3x)')"
382
  ]
383
  },
384
  {
@@ -395,7 +359,7 @@
395
  " tokenizer=hf_tokenizer,\n",
396
  " train_dataset=dataset,\n",
397
  " output_dir='./finance_pretrain_checkpoints',\n",
398
- " hub_model_id=None, # set to 'your-username/finance-domain-24m' to auto-push\n",
399
  " num_epochs=3 if USE_GPU else 1,\n",
400
  " per_device_batch_size=32 if USE_GPU else 4,\n",
401
  " gradient_accumulation_steps=4 if USE_GPU else 1,\n",
@@ -404,7 +368,8 @@
404
  " logging_steps=50 if USE_GPU else 10,\n",
405
  " save_steps=1000 if USE_GPU else 999999,\n",
406
  " bf16=USE_GPU,\n",
407
- " report_to='none',\n",
 
408
  " seed=42,\n",
409
  ")"
410
  ]
@@ -422,9 +387,7 @@
422
  "metadata": {},
423
  "outputs": [],
424
  "source": [
425
- "# Loss curve\n",
426
  "losses = [h['loss'] for h in trainer.state.log_history if 'loss' in h]\n",
427
- "\n",
428
  "print(f'Steps: {trainer.state.global_step:,}')\n",
429
  "print(f'Loss: {losses[0]:.4f} -> {losses[-1]:.4f} ({(1-losses[-1]/losses[0])*100:.1f}% reduction)')\n",
430
  "print(f'Min loss: {min(losses):.4f}')\n",
@@ -433,15 +396,9 @@
433
  "ax.plot(losses, linewidth=0.5, alpha=0.5, label='Per-step')\n",
434
  "window = max(len(losses) // 50, 1)\n",
435
  "if len(losses) > window:\n",
436
- " smoothed = pd.Series(losses).rolling(window=window, min_periods=1).mean()\n",
437
- " ax.plot(smoothed, linewidth=2, color='red', label=f'Smoothed (w={window})')\n",
438
- "ax.set_xlabel('Step')\n",
439
- "ax.set_ylabel('Loss')\n",
440
- "ax.set_title('Pre-Training Loss Curve')\n",
441
- "ax.legend()\n",
442
- "ax.grid(True, alpha=0.3)\n",
443
- "plt.tight_layout()\n",
444
- "plt.show()"
445
  ]
446
  },
447
  {
@@ -450,25 +407,18 @@
450
  "metadata": {},
451
  "outputs": [],
452
  "source": [
453
- "# Next-token prediction test\n",
454
  "model.eval()\n",
455
  "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
456
  "model = model.to(device)\n",
457
  "\n",
458
- "test_tokens = builder.tokenize_sequence(user_sequences[0][:3])\n",
459
- "test_ids = hf_tokenizer(' '.join(test_tokens), return_tensors='pt', add_special_tokens=False)['input_ids'].to(device)\n",
460
- "\n",
461
  "with torch.no_grad():\n",
462
- " logits = model(input_ids=test_ids).logits\n",
463
- " top5 = torch.topk(logits[0, -1, :], 5)\n",
464
  "\n",
465
  "print('Last 5 input tokens:')\n",
466
- "for tid in test_ids[0, -5:]:\n",
467
- " print(f\" {tid.item():5d} -> '{hf_tokenizer.decode([tid.item()])}'\")\n",
468
- "\n",
469
  "print('\\nTop-5 next token predictions:')\n",
470
- "for score, tid in zip(top5.values, top5.indices):\n",
471
- " print(f\" {tid.item():5d} -> '{hf_tokenizer.decode([tid.item()])}' (score={score.item():.3f})\")"
472
  ]
473
  },
474
  {
@@ -477,45 +427,35 @@
477
  "metadata": {},
478
  "outputs": [],
479
  "source": [
480
- "# User embedding visualization (t-SNE)\n",
481
  "n_sample = min(200, len(user_sequences))\n",
482
- "embeddings = []\n",
483
- "labels_sample = []\n",
484
- "\n",
485
  "for i in range(n_sample):\n",
486
- " tokens = builder.tokenize_sequence(user_sequences[i][:50])\n",
487
- " enc = hf_tokenizer(' '.join(tokens), return_tensors='pt', add_special_tokens=False,\n",
488
- " max_length=256, truncation=True, padding='max_length')\n",
489
  " with torch.no_grad():\n",
490
- " emb = model.get_user_embedding(enc['input_ids'].to(device), enc['attention_mask'].to(device))\n",
491
- " embeddings.append(emb.cpu().numpy().flatten())\n",
492
  " labels_sample.append(user_fraud_labels[i])\n",
493
  "\n",
494
- "embeddings = np.array(embeddings)\n",
495
- "labels_sample = np.array(labels_sample)\n",
496
  "print(f'Embeddings: {embeddings.shape}, Fraud: {labels_sample.sum()}/{len(labels_sample)}')\n",
497
  "\n",
498
  "if len(embeddings) >= 20:\n",
499
  " from sklearn.manifold import TSNE\n",
500
  " coords = TSNE(n_components=2, random_state=42, perplexity=min(30, len(embeddings)-1)).fit_transform(embeddings)\n",
501
- " \n",
502
  " fig, ax = plt.subplots(figsize=(8, 6))\n",
503
  " for label, color, name in [(0, 'tab:green', 'Normal'), (1, 'tab:red', 'Fraud')]:\n",
504
  " mask = labels_sample == label\n",
505
  " ax.scatter(coords[mask, 0], coords[mask, 1], c=color, label=name, alpha=0.6, edgecolors='black', linewidth=0.3, s=30)\n",
506
- " ax.set_title('User Embeddings (t-SNE) — Pre-trained DomainTransformer')\n",
507
- " ax.legend()\n",
508
- " plt.tight_layout()\n",
509
- " plt.show()"
510
  ]
511
  },
512
  {
513
  "cell_type": "markdown",
514
  "metadata": {},
515
  "source": [
516
- "## Save Artifacts for Fine-Tuning Notebook\n",
517
- "\n",
518
- "Saves the pre-trained model, tokenizer, and user data so `02_finance_finetune.ipynb` can pick up where we left off."
519
  ]
520
  },
521
  {
@@ -524,26 +464,23 @@
524
  "metadata": {},
525
  "outputs": [],
526
  "source": [
527
- "# Save tokenizer\n",
528
  "hf_tokenizer.save_pretrained('./finance_tokenizer')\n",
529
  "builder.save('./finance_tokenizer')\n",
530
- "\n",
531
- "# Save model\n",
532
  "model.save_pretrained('./finance_pretrain_checkpoints/final')\n",
533
  "\n",
534
- "# Save user data\n",
535
- "artifacts = {\n",
536
- " 'user_sequences': user_sequences,\n",
537
- " 'user_ids': user_ids,\n",
538
- " 'user_fraud_labels': user_fraud_labels,\n",
539
- "}\n",
540
  "with open('./finance_artifacts.pkl', 'wb') as f:\n",
541
- " pickle.dump(artifacts, f)\n",
542
  "\n",
543
- "print('Saved: tokenizer, model, user data')\n",
544
- "print(f' ./finance_tokenizer/')\n",
545
- "print(f' ./finance_pretrain_checkpoints/final/')\n",
546
- "print(f' ./finance_artifacts.pkl')"
547
  ]
548
  },
549
  {
 
38
  "outputs": [],
39
  "source": [
40
  "# Uncomment and run once to install dependencies:\n",
41
+ "# !pip install datasets transformers torch accelerate tokenizers numpy pandas matplotlib scikit-learn wandb"
42
  ]
43
  },
44
  {
 
75
  "logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')\n",
76
  "print(f'torch: {torch.__version__}, CUDA: {torch.cuda.is_available()}')\n",
77
  "if torch.cuda.is_available():\n",
78
+ " print(f'GPU: {torch.cuda.get_device_name(0)}, VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB')"
79
+ ]
80
+ },
81
+ {
82
+ "cell_type": "code",
83
+ "execution_count": null,
84
+ "metadata": {},
85
+ "outputs": [],
86
+ "source": [
87
+ "# wandb setup — logs persist even if notebook kernel disconnects\n",
88
+ "# Run `wandb login` in terminal first, or set WANDB_API_KEY env var\n",
89
+ "import wandb\n",
90
+ "wandb.login()\n",
91
+ "\n",
92
+ "WANDB_PROJECT = 'domainTokenizer' # all runs grouped under this project\n",
93
+ "os.environ['WANDB_PROJECT'] = WANDB_PROJECT\n",
94
+ "print(f'wandb project: {WANDB_PROJECT}')"
95
  ]
96
  },
97
  {
 
159
  "metadata": {},
160
  "outputs": [],
161
  "source": [
 
162
  "events_per_user = df.groupby('sender_account').size()\n",
163
  "print(f\"Events per user: min={events_per_user.min()}, max={events_per_user.max()}, \"\n",
164
  " f\"mean={events_per_user.mean():.1f}, median={events_per_user.median():.1f}\")\n",
 
166
  "print(f\"Users with 10+ events: {(events_per_user >= 10).sum():,}\")\n",
167
  "\n",
168
  "fig, axes = plt.subplots(1, 3, figsize=(15, 4))\n",
 
169
  "axes[0].hist(np.log10(df['amount_ngn'].clip(lower=1)), bins=50, edgecolor='black', alpha=0.7)\n",
170
+ "axes[0].set_xlabel('log10(Amount NGN)'); axes[0].set_ylabel('Count'); axes[0].set_title('Amount Distribution (log scale)')\n",
171
  "axes[1].hist(events_per_user.clip(upper=50), bins=50, edgecolor='black', alpha=0.7)\n",
172
+ "axes[1].set_xlabel('Events per User'); axes[1].set_ylabel('Count'); axes[1].set_title('Events per User')\n",
173
  "df['transaction_type'].value_counts().head(10).plot(kind='barh', ax=axes[2])\n",
174
+ "axes[2].set_xlabel('Count'); axes[2].set_title('Transaction Types')\n",
175
+ "plt.tight_layout(); plt.show()"
176
  ]
177
  },
178
  {
 
195
  "outputs": [],
196
  "source": [
197
  "def row_to_event(row):\n",
 
198
  " dt = datetime.strptime(row['timestamp'][:19], '%Y-%m-%d %H:%M:%S')\n",
199
  " desc = f\"{row['merchant_category']} {row['transaction_type']}\"\n",
200
  " amt = row['amount_ngn']\n",
201
  " if row['transaction_type'] == 'withdrawal':\n",
202
  " amt = -abs(amt)\n",
203
+ " return {'amount_sign': amt, 'amount': amt, 'timestamp': dt, 'description': desc}\n",
204
+ "\n",
205
+ "print(f'Sample event: {row_to_event(df.iloc[0])}')"
206
  ]
207
  },
208
  {
 
213
  "source": [
214
  "%%time\n",
215
  "MIN_EVENTS = 5\n",
216
+ "MAX_EVENTS = 500\n",
217
  "\n",
218
+ "user_sequences, user_ids, user_fraud_labels = [], [], []\n",
219
  "for sender, group in df.sort_values('timestamp').groupby('sender_account'):\n",
220
  " if len(group) < MIN_EVENTS:\n",
221
  " continue\n",
 
226
  "\n",
227
  "print(f'Users with {MIN_EVENTS}+ events: {len(user_sequences):,}')\n",
228
  "print(f'Total events: {sum(len(s) for s in user_sequences):,}')\n",
229
+ "print(f'Events/user: min={min(len(s) for s in user_sequences)}, max={max(len(s) for s in user_sequences)}, mean={np.mean([len(s) for s in user_sequences]):.1f}')\n",
 
 
230
  "print(f'Fraud rate (user-level): {np.mean(user_fraud_labels)*100:.2f}%')"
231
  ]
232
  },
 
236
  "source": [
237
  "## Step 4 — Build Domain Tokenizer\n",
238
  "\n",
239
+ "Hybrid vocabulary: 97 special tokens (sign + amount bins + calendar) + BPE for descriptions."
 
240
  ]
241
  },
242
  {
 
254
  "text_corpus = [e['description'] for e in all_events]\n",
255
  "unique_descs = sorted(set(text_corpus))\n",
256
  "print(f'Unique descriptions: {len(unique_descs)}')\n",
257
+ "for d in unique_descs[:10]: print(f\" '{d}'\")\n",
258
+ "if len(unique_descs) > 10: print(f' ... and {len(unique_descs) - 10} more')\n",
259
  "\n",
260
+ "hf_tokenizer = builder.build(text_corpus=text_corpus, bpe_vocab_size=2000)\n",
261
  "print(f'\\nVocab size: {hf_tokenizer.vocab_size}')\n",
262
  "print(f'Stats: {builder.get_stats()}')"
263
  ]
 
268
  "metadata": {},
269
  "outputs": [],
270
  "source": [
 
271
  "print('--- Sample event tokenized ---')\n",
272
+ "for i, t in enumerate(builder.tokenize_event(user_sequences[0][0])): print(f' [{i}] {t}')\n",
 
 
273
  "\n",
274
  "print(f'\\n--- First user, first 3 events ---')\n",
275
  "seq_tokens = builder.tokenize_sequence(user_sequences[0][:3])\n",
276
+ "for i, t in enumerate(seq_tokens): print(f' [{i:3d}] {t}')\n",
 
277
  "\n",
278
  "seq_ids = hf_tokenizer(' '.join(seq_tokens), add_special_tokens=False)['input_ids']\n",
279
  "unk_id = hf_tokenizer.unk_token_id\n",
 
287
  "source": [
288
  "## Step 5 — Pack into CLM Training Dataset\n",
289
  "\n",
290
+ "Sequence packing: concatenate all user sequences, split into fixed-length blocks. 100% token utilization."
 
291
  ]
292
  },
293
  {
 
297
  "outputs": [],
298
  "source": [
299
  "%%time\n",
300
+ "BLOCK_SIZE = 512\n",
301
+ "dataset = prepare_clm_dataset(user_sequences, builder, hf_tokenizer, block_size=BLOCK_SIZE)\n",
302
  "print(f'Packed: {len(dataset):,} blocks x {BLOCK_SIZE} = {len(dataset)*BLOCK_SIZE:,} training tokens')"
303
  ]
304
  },
 
308
  "metadata": {},
309
  "outputs": [],
310
  "source": [
 
 
311
  "print(f'Sample block decoded (first 60 tokens):')\n",
312
+ "print(hf_tokenizer.decode(dataset[0]['input_ids'][:60]))\n",
313
  "\n",
 
314
  "all_ids = [i for row in dataset for i in row['input_ids']]\n",
315
  "counts = Counter(all_ids)\n",
316
+ "print(f'\\nTotal tokens: {len(all_ids):,}, Unique: {len(counts)}/{hf_tokenizer.vocab_size}, UNK: {counts.get(unk_id,0)} ({counts.get(unk_id,0)/len(all_ids)*100:.2f}%)')\n",
317
  "print(f'\\nTop 20 tokens:')\n",
318
  "for tid, count in counts.most_common(20):\n",
319
+ " print(f' {tid:5d} {count:8,} ({count/len(all_ids)*100:5.1f}%) {hf_tokenizer.decode([tid]).strip() or \"(space)\"}')"
 
 
320
  ]
321
  },
322
  {
 
325
  "source": [
326
  "## Step 6 — Pre-Train 24M DomainTransformer\n",
327
  "\n",
328
+ "Architecture:\n",
329
  "- GPT-style causal decoder, NoPE (no positional encoding)\n",
330
  "- 24M preset: d=512, 6 layers, 8 heads, FFN=2048\n",
331
+ "- Cosine LR schedule with warmup, AdamW\n",
332
+ "- CLM objective (next token prediction)\n",
333
+ "- wandb logging for persistent monitoring"
334
  ]
335
  },
336
  {
 
341
  "source": [
342
  "config = DomainTransformerConfig.from_preset('24m', vocab_size=hf_tokenizer.vocab_size)\n",
343
  "model = DomainTransformerForCausalLM(config)\n",
 
344
  "n_params = sum(p.numel() for p in model.parameters())\n",
345
+ "print(f'Model: {n_params:,} params | d={config.hidden_size}, L={config.num_hidden_layers}, H={config.num_attention_heads}')"
 
 
346
  ]
347
  },
348
  {
 
359
  " tokenizer=hf_tokenizer,\n",
360
  " train_dataset=dataset,\n",
361
  " output_dir='./finance_pretrain_checkpoints',\n",
362
+ " hub_model_id='rtferraz/finance-domain-24m',\n",
363
  " num_epochs=3 if USE_GPU else 1,\n",
364
  " per_device_batch_size=32 if USE_GPU else 4,\n",
365
  " gradient_accumulation_steps=4 if USE_GPU else 1,\n",
 
368
  " logging_steps=50 if USE_GPU else 10,\n",
369
  " save_steps=1000 if USE_GPU else 999999,\n",
370
  " bf16=USE_GPU,\n",
371
+ " report_to='wandb',\n",
372
+ " run_name='finance-pretrain-24m-3ep',\n",
373
  " seed=42,\n",
374
  ")"
375
  ]
 
387
  "metadata": {},
388
  "outputs": [],
389
  "source": [
 
390
  "losses = [h['loss'] for h in trainer.state.log_history if 'loss' in h]\n",
 
391
  "print(f'Steps: {trainer.state.global_step:,}')\n",
392
  "print(f'Loss: {losses[0]:.4f} -> {losses[-1]:.4f} ({(1-losses[-1]/losses[0])*100:.1f}% reduction)')\n",
393
  "print(f'Min loss: {min(losses):.4f}')\n",
 
396
  "ax.plot(losses, linewidth=0.5, alpha=0.5, label='Per-step')\n",
397
  "window = max(len(losses) // 50, 1)\n",
398
  "if len(losses) > window:\n",
399
+ " ax.plot(pd.Series(losses).rolling(window=window, min_periods=1).mean(), linewidth=2, color='red', label=f'Smoothed (w={window})')\n",
400
+ "ax.set_xlabel('Step'); ax.set_ylabel('Loss'); ax.set_title('Pre-Training Loss Curve')\n",
401
+ "ax.legend(); ax.grid(True, alpha=0.3); plt.tight_layout(); plt.show()"
402
  ]
403
  },
404
  {
 
407
  "metadata": {},
408
  "outputs": [],
409
  "source": [
 
410
  "model.eval()\n",
411
  "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
412
  "model = model.to(device)\n",
413
  "\n",
414
+ "test_ids = hf_tokenizer(' '.join(builder.tokenize_sequence(user_sequences[0][:3])), return_tensors='pt', add_special_tokens=False)['input_ids'].to(device)\n",
 
 
415
  "with torch.no_grad():\n",
416
+ " top5 = torch.topk(model(input_ids=test_ids).logits[0, -1, :], 5)\n",
 
417
  "\n",
418
  "print('Last 5 input tokens:')\n",
419
+ "for tid in test_ids[0, -5:]: print(f\" {tid.item():5d} -> '{hf_tokenizer.decode([tid.item()])}'\")\n",
 
 
420
  "print('\\nTop-5 next token predictions:')\n",
421
+ "for score, tid in zip(top5.values, top5.indices): print(f\" {tid.item():5d} -> '{hf_tokenizer.decode([tid.item()])}' (score={score.item():.3f})\")"
 
422
  ]
423
  },
424
  {
 
427
  "metadata": {},
428
  "outputs": [],
429
  "source": [
430
+ "# t-SNE user embeddings colored by fraud label\n",
431
  "n_sample = min(200, len(user_sequences))\n",
432
+ "embeddings, labels_sample = [], []\n",
 
 
433
  "for i in range(n_sample):\n",
434
+ " enc = hf_tokenizer(' '.join(builder.tokenize_sequence(user_sequences[i][:50])),\n",
435
+ " return_tensors='pt', add_special_tokens=False, max_length=256, truncation=True, padding='max_length')\n",
 
436
  " with torch.no_grad():\n",
437
+ " embeddings.append(model.get_user_embedding(enc['input_ids'].to(device), enc['attention_mask'].to(device)).cpu().numpy().flatten())\n",
 
438
  " labels_sample.append(user_fraud_labels[i])\n",
439
  "\n",
440
+ "embeddings = np.array(embeddings); labels_sample = np.array(labels_sample)\n",
 
441
  "print(f'Embeddings: {embeddings.shape}, Fraud: {labels_sample.sum()}/{len(labels_sample)}')\n",
442
  "\n",
443
  "if len(embeddings) >= 20:\n",
444
  " from sklearn.manifold import TSNE\n",
445
  " coords = TSNE(n_components=2, random_state=42, perplexity=min(30, len(embeddings)-1)).fit_transform(embeddings)\n",
 
446
  " fig, ax = plt.subplots(figsize=(8, 6))\n",
447
  " for label, color, name in [(0, 'tab:green', 'Normal'), (1, 'tab:red', 'Fraud')]:\n",
448
  " mask = labels_sample == label\n",
449
  " ax.scatter(coords[mask, 0], coords[mask, 1], c=color, label=name, alpha=0.6, edgecolors='black', linewidth=0.3, s=30)\n",
450
+ " ax.set_title('User Embeddings (t-SNE) — Pre-trained DomainTransformer'); ax.legend()\n",
451
+ " plt.tight_layout(); plt.show()"
 
 
452
  ]
453
  },
454
  {
455
  "cell_type": "markdown",
456
  "metadata": {},
457
  "source": [
458
+ "## Save Artifacts"
 
 
459
  ]
460
  },
461
  {
 
464
  "metadata": {},
465
  "outputs": [],
466
  "source": [
 
467
  "hf_tokenizer.save_pretrained('./finance_tokenizer')\n",
468
  "builder.save('./finance_tokenizer')\n",
 
 
469
  "model.save_pretrained('./finance_pretrain_checkpoints/final')\n",
470
  "\n",
 
471
  "with open('./finance_artifacts.pkl', 'wb') as f:\n",
472
+ " pickle.dump({'user_sequences': user_sequences, 'user_ids': user_ids, 'user_fraud_labels': user_fraud_labels}, f)\n",
473
  "\n",
474
+ "print('Saved: ./finance_tokenizer/, ./finance_pretrain_checkpoints/final/, ./finance_artifacts.pkl')"
475
+ ]
476
+ },
477
+ {
478
+ "cell_type": "code",
479
+ "execution_count": null,
480
+ "metadata": {},
481
+ "outputs": [],
482
+ "source": [
483
+ "wandb.finish() # close wandb run cleanly"
484
  ]
485
  },
486
  {