rtferraz committed on
Commit 2c3ddfa · verified · 1 parent: 6e5b80d

Add 01_finance_pretrain.ipynb — Phase 3.1 notebook for pre-training on 5M Nigerian financial transactions


29 cells (18 code + 11 markdown):
- Loads electricsheepafrica/Nigerian-Financial-Transactions dataset from HF Hub
- Data profiling with matplotlib visualizations
- Converts to FINANCE_SCHEMA, groups by sender_account
- Builds hybrid domain tokenizer (97 special + BPE)
- Packs sequences, trains 24M DomainTransformer
- Loss curves, next-token predictions, t-SNE user embeddings
- Saves artifacts for fine-tuning notebook

Auto-detects GPU (L4/CPU) and adjusts batch size/epochs accordingly.
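
In outline, the notebook drives the repo's `domain_tokenizer` API end to end. Below is a condensed sketch of that flow, assembled only from the calls that appear in the diff further down (signatures are taken from the notebook itself; it assumes the trainer's remaining arguments, such as epochs and batch size, have sensible defaults):

```python
# Condensed sketch of the notebook's pipeline -- same calls as in the diff below.
from datetime import datetime
from datasets import load_dataset
from domain_tokenizer import (
    DomainTokenizerBuilder, DomainTransformerConfig,
    DomainTransformerForCausalLM, prepare_clm_dataset, pretrain_domain_model,
)
from domain_tokenizer.schemas import FINANCE_SCHEMA

df = load_dataset(
    'electricsheepafrica/Nigerian-Financial-Transactions-and-Fraud-Detection-Dataset',
    split='train',
).to_pandas()

def row_to_event(row):
    # One transaction -> one FINANCE_SCHEMA event (sign/amount, calendar, text).
    amt = -abs(row['amount_ngn']) if row['transaction_type'] == 'withdrawal' else row['amount_ngn']
    return {
        'amount_sign': amt,
        'amount': amt,
        'timestamp': datetime.strptime(row['timestamp'][:19], '%Y-%m-%d %H:%M:%S'),
        'description': f"{row['merchant_category']} {row['transaction_type']}",
    }

# Group by sender, keep users with 5+ events, cap each at 500 events.
user_sequences = [
    [row_to_event(r) for _, r in g.head(500).iterrows()]
    for _, g in df.sort_values('timestamp').groupby('sender_account')
    if len(g) >= 5
]

events = [e for seq in user_sequences for e in seq]
builder = DomainTokenizerBuilder(FINANCE_SCHEMA)
builder.fit(events)  # fits amount bins and calendar vocabulary
tokenizer = builder.build(text_corpus=[e['description'] for e in events], bpe_vocab_size=2000)

dataset = prepare_clm_dataset(user_sequences, builder, tokenizer, block_size=512)
config = DomainTransformerConfig.from_preset('24m', vocab_size=len(tokenizer))
trainer = pretrain_domain_model(
    model=DomainTransformerForCausalLM(config),
    tokenizer=tokenizer,
    train_dataset=dataset,
    output_dir='./finance_pretrain_checkpoints',
)
```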

Files changed (1)
  1. notebooks/01_finance_pretrain.ipynb +581 -0
notebooks/01_finance_pretrain.ipynb ADDED
@@ -0,0 +1,581 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# 01 — Finance Pre-Training: Domain Tokenizer on Real Financial Transactions\n",
+ "\n",
+ "**Goal:** Pre-train a 24M-parameter DomainTransformer on 5M synthetic Nigerian financial transactions, demonstrating that the domain_tokenizer pipeline works at scale on a realistic dataset.\n",
+ "\n",
+ "**Dataset:** [electricsheepafrica/Nigerian-Financial-Transactions-and-Fraud-Detection-Dataset](https://huggingface.co/datasets/electricsheepafrica/Nigerian-Financial-Transactions-and-Fraud-Detection-Dataset) — 5M transactions, 45 features, fraud labels.\n",
+ "\n",
+ "**Pipeline:**\n",
+ "1. Load data from the Hugging Face Hub\n",
+ "2. Explore and profile the dataset\n",
+ "3. Convert to FINANCE_SCHEMA events, group by user\n",
+ "4. Build domain tokenizer (special tokens + BPE)\n",
+ "5. Pack into CLM training dataset\n",
+ "6. Pre-train 24M DomainTransformer (NoPE, GPT-style)\n",
+ "7. Inspect learned representations\n",
+ "\n",
+ "**Hardware:** L4 GPU (24GB VRAM) — the 24M model fits comfortably.\n",
+ "\n",
+ "**Reference:** Nubank nuFormer ([arXiv:2507.23267](https://arxiv.org/abs/2507.23267)) — same architecture pattern."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Setup"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Uncomment and run once to install dependencies:\n",
+ "# !pip install datasets transformers torch accelerate tokenizers numpy pandas matplotlib scikit-learn"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import logging\n",
+ "import time\n",
+ "import pickle\n",
+ "from datetime import datetime\n",
+ "from collections import Counter\n",
+ "\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import matplotlib.pyplot as plt\n",
+ "import torch\n",
+ "from datasets import load_dataset\n",
+ "\n",
+ "# If running from cloned repo, add src/ to path\n",
+ "import sys, os\n",
+ "if os.path.exists('../src'):\n",
+ " sys.path.insert(0, '../src')\n",
+ "elif os.path.exists('src'):\n",
+ " sys.path.insert(0, 'src')\n",
+ "\n",
+ "from domain_tokenizer import (\n",
+ " DomainTokenizerBuilder, DomainTransformerConfig,\n",
+ " DomainTransformerForCausalLM, prepare_clm_dataset, pretrain_domain_model,\n",
+ ")\n",
+ "from domain_tokenizer.schemas import FINANCE_SCHEMA\n",
+ "\n",
+ "logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')\n",
+ "print(f'torch: {torch.__version__}, CUDA: {torch.cuda.is_available()}')\n",
+ "if torch.cuda.is_available():\n",
+ " print(f'GPU: {torch.cuda.get_device_name(0)}, VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 1 — Load Dataset from the Hugging Face Hub\n",
+ "\n",
+ "5M synthetic Nigerian fintech transactions with 45 features including merchant categories, device info, risk scores, and fraud labels."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%time\n",
+ "ds = load_dataset(\n",
+ " 'electricsheepafrica/Nigerian-Financial-Transactions-and-Fraud-Detection-Dataset',\n",
+ " split='train',\n",
+ ")\n",
+ "print(f'Loaded: {len(ds):,} transactions, {len(ds.column_names)} columns')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df = ds.to_pandas()\n",
+ "print(f'Shape: {df.shape}')\n",
+ "df.head(3)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 2 — Data Profiling\n",
+ "\n",
+ "Understanding what we're tokenizing: user counts, amount distributions, transaction types, merchant categories."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(f\"Unique senders (users): {df['sender_account'].nunique():,}\")\n",
+ "print(f\"Timestamp range: {df['timestamp'].min()} to {df['timestamp'].max()}\")\n",
+ "print(f\"Amount range: {df['amount_ngn'].min():,.2f} to {df['amount_ngn'].max():,.2f} NGN\")\n",
+ "print(f\"Amount mean: {df['amount_ngn'].mean():,.2f}, median: {df['amount_ngn'].median():,.2f}\")\n",
+ "print(f\"\\nTransaction types:\\n{df['transaction_type'].value_counts().to_string()}\")\n",
+ "print(f\"\\nMerchant categories (top 15):\\n{df['merchant_category'].value_counts().head(15).to_string()}\")\n",
+ "print(f\"\\nFraud rate: {df['is_fraud'].mean()*100:.2f}%\")\n",
+ "print(f\"\\nPayment channels:\\n{df['payment_channel'].value_counts().to_string()}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Events per user distribution\n",
+ "events_per_user = df.groupby('sender_account').size()\n",
+ "print(f\"Events per user: min={events_per_user.min()}, max={events_per_user.max()}, \"\n",
+ " f\"mean={events_per_user.mean():.1f}, median={events_per_user.median():.1f}\")\n",
+ "print(f\"Users with 5+ events: {(events_per_user >= 5).sum():,}\")\n",
+ "print(f\"Users with 10+ events: {(events_per_user >= 10).sum():,}\")\n",
+ "\n",
+ "fig, axes = plt.subplots(1, 3, figsize=(15, 4))\n",
+ "\n",
+ "axes[0].hist(np.log10(df['amount_ngn'].clip(lower=1)), bins=50, edgecolor='black', alpha=0.7)\n",
+ "axes[0].set_xlabel('log10(Amount NGN)')\n",
+ "axes[0].set_ylabel('Count')\n",
+ "axes[0].set_title('Amount Distribution (log scale)')\n",
+ "\n",
+ "axes[1].hist(events_per_user.clip(upper=50), bins=50, edgecolor='black', alpha=0.7)\n",
+ "axes[1].set_xlabel('Events per User')\n",
+ "axes[1].set_ylabel('Count')\n",
+ "axes[1].set_title('Events per User')\n",
+ "\n",
+ "df['transaction_type'].value_counts().head(10).plot(kind='barh', ax=axes[2])\n",
+ "axes[2].set_xlabel('Count')\n",
+ "axes[2].set_title('Transaction Types')\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 3 — Convert to FINANCE_SCHEMA Events\n",
+ "\n",
+ "Mapping:\n",
+ "- `timestamp` → CalendarTokenizer (month, day-of-week, day-of-month, hour)\n",
+ "- `amount_ngn` → SignTokenizer (credit/debit) + MagnitudeBucketTokenizer (21 quantile bins)\n",
+ "- `merchant_category` + `transaction_type` → BPE text description\n",
+ "- `sender_account` → user grouping key"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def row_to_event(row):\n",
+ " \"\"\"Convert a DataFrame row to a FINANCE_SCHEMA event dict.\"\"\"\n",
+ " # Timestamps are strings; parse the 'YYYY-MM-DD HH:MM:SS' prefix\n",
+ " dt = datetime.strptime(row['timestamp'][:19], '%Y-%m-%d %H:%M:%S')\n",
+ " desc = f\"{row['merchant_category']} {row['transaction_type']}\"\n",
+ " amt = row['amount_ngn']\n",
+ " # Withdrawals are outflows: flip the sign so SignTokenizer sees a debit\n",
+ " if row['transaction_type'] == 'withdrawal':\n",
+ " amt = -abs(amt)\n",
+ " return {\n",
+ " 'amount_sign': amt,\n",
+ " 'amount': amt,\n",
+ " 'timestamp': dt,\n",
+ " 'description': desc,\n",
+ " }\n",
+ "\n",
+ "sample = row_to_event(df.iloc[0])\n",
+ "print(f'Sample event: {sample}')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%time\n",
+ "MIN_EVENTS = 5\n",
+ "MAX_EVENTS = 500 # cap to prevent very long sequences from dominating\n",
+ "\n",
+ "user_sequences = []\n",
+ "user_ids = []\n",
+ "user_fraud_labels = []\n",
+ "\n",
+ "for sender, group in df.sort_values('timestamp').groupby('sender_account'):\n",
+ " if len(group) < MIN_EVENTS:\n",
+ " continue\n",
+ " events = [row_to_event(row) for _, row in group.head(MAX_EVENTS).iterrows()]\n",
+ " user_sequences.append(events)\n",
+ " user_ids.append(sender)\n",
+ " user_fraud_labels.append(int(group['is_fraud'].any()))\n",
+ "\n",
+ "print(f'Users with {MIN_EVENTS}+ events: {len(user_sequences):,}')\n",
+ "print(f'Total events: {sum(len(s) for s in user_sequences):,}')\n",
+ "print(f'Events per user: min={min(len(s) for s in user_sequences)}, '\n",
+ " f'max={max(len(s) for s in user_sequences)}, '\n",
+ " f'mean={np.mean([len(s) for s in user_sequences]):.1f}')\n",
+ "print(f'Fraud rate (user-level): {np.mean(user_fraud_labels)*100:.2f}%')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 4 — Build Domain Tokenizer\n",
+ "\n",
+ "Hybrid vocabulary: 97 special tokens (sign + amount bins + calendar) + BPE for descriptions.\n",
+ "Following Nubank nuFormer's tokenization approach."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "all_events = [e for seq in user_sequences for e in seq]\n",
+ "print(f'Total events for fitting: {len(all_events):,}')\n",
+ "\n",
+ "builder = DomainTokenizerBuilder(FINANCE_SCHEMA)\n",
+ "builder.fit(all_events)\n",
+ "\n",
+ "text_corpus = [e['description'] for e in all_events]\n",
+ "unique_descs = sorted(set(text_corpus))\n",
+ "print(f'Unique descriptions: {len(unique_descs)}')\n",
+ "for d in unique_descs[:10]:\n",
+ " print(f\" '{d}'\")\n",
+ "if len(unique_descs) > 10:\n",
+ " print(f' ... and {len(unique_descs) - 10} more')\n",
+ "\n",
+ "hf_tokenizer = builder.build(\n",
+ " text_corpus=text_corpus,\n",
+ " bpe_vocab_size=2000,\n",
+ ")\n",
+ "\n",
+ "print(f'\\nVocab size: {hf_tokenizer.vocab_size}')\n",
+ "print(f'Stats: {builder.get_stats()}')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Inspect tokenized output\n",
+ "print('--- Sample event tokenized ---')\n",
+ "sample_tokens = builder.tokenize_event(user_sequences[0][0])\n",
+ "for i, t in enumerate(sample_tokens):\n",
+ " print(f' [{i}] {t}')\n",
+ "\n",
+ "print(f'\\n--- First user, first 3 events ---')\n",
+ "seq_tokens = builder.tokenize_sequence(user_sequences[0][:3])\n",
+ "for i, t in enumerate(seq_tokens):\n",
+ " print(f' [{i:3d}] {t}')\n",
+ "\n",
+ "seq_ids = hf_tokenizer(' '.join(seq_tokens), add_special_tokens=False)['input_ids']\n",
+ "unk_id = hf_tokenizer.unk_token_id\n",
+ "unk_count = sum(1 for i in seq_ids if i == unk_id)\n",
+ "print(f'\\nUNK rate: {unk_count}/{len(seq_ids)} ({unk_count/max(len(seq_ids),1)*100:.1f}%)')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 5 — Pack into CLM Training Dataset\n",
+ "\n",
+ "Sequence packing (run_clm.py pattern): concatenate all user sequences, split into fixed-length blocks.\n",
+ "Full token utilization with no padding; as in run_clm.py, any final partial block is dropped."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%time\n",
+ "BLOCK_SIZE = 512 # Nubank uses 2048; 512 for faster iteration\n",
+ "\n",
+ "dataset = prepare_clm_dataset(\n",
+ " user_sequences, builder, hf_tokenizer,\n",
+ " block_size=BLOCK_SIZE,\n",
+ ")\n",
+ "\n",
+ "print(f'Packed: {len(dataset):,} blocks x {BLOCK_SIZE} = {len(dataset)*BLOCK_SIZE:,} training tokens')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Decode a sample block to verify it looks right\n",
+ "sample_block = dataset[0]['input_ids']\n",
+ "print(f'Sample block decoded (first 60 tokens):')\n",
+ "print(hf_tokenizer.decode(sample_block[:60]))\n",
+ "\n",
+ "# Token frequency analysis\n",
+ "all_ids = [i for row in dataset for i in row['input_ids']]\n",
+ "counts = Counter(all_ids)\n",
+ "unk_pct = counts.get(unk_id, 0) / len(all_ids) * 100\n",
+ "\n",
+ "print(f'\\nTotal tokens: {len(all_ids):,}')\n",
+ "print(f'Unique token IDs used: {len(counts)}/{hf_tokenizer.vocab_size}')\n",
+ "print(f'UNK tokens: {counts.get(unk_id, 0):,} ({unk_pct:.2f}%)')\n",
+ "\n",
+ "print(f'\\nTop 20 tokens:')\n",
+ "for tid, count in counts.most_common(20):\n",
+ " tok_str = hf_tokenizer.decode([tid]).strip() or '(space/control)'\n",
+ " pct = count / len(all_ids) * 100\n",
+ " print(f' {tid:5d} {count:8,} ({pct:5.1f}%) {tok_str}')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 6 — Pre-Train 24M DomainTransformer\n",
+ "\n",
+ "Architecture (following Nubank nuFormer):\n",
+ "- GPT-style causal decoder, NoPE (no positional encoding)\n",
+ "- 24M preset: d=512, 6 layers, 8 heads, FFN=2048\n",
+ "- Cosine LR schedule with warmup, AdamW optimizer\n",
+ "- CLM objective (next-token prediction on transaction sequences)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# len(tokenizer) counts added special tokens, which .vocab_size can omit\n",
+ "config = DomainTransformerConfig.from_preset('24m', vocab_size=len(hf_tokenizer))\n",
+ "model = DomainTransformerForCausalLM(config)\n",
+ "\n",
+ "n_params = sum(p.numel() for p in model.parameters())\n",
+ "print(f'Model: {n_params:,} parameters')\n",
+ "print(f'Config: d={config.hidden_size}, L={config.num_hidden_layers}, H={config.num_attention_heads}')\n",
+ "print(f'Weights (bf16): ~{n_params * 2 / 1e6:.0f}MB; budget roughly 8x that for gradients and fp32 AdamW states')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%time\n",
+ "USE_GPU = torch.cuda.is_available()\n",
+ "\n",
+ "trainer = pretrain_domain_model(\n",
+ " model=model,\n",
+ " tokenizer=hf_tokenizer,\n",
+ " train_dataset=dataset,\n",
+ " output_dir='./finance_pretrain_checkpoints',\n",
+ " hub_model_id=None, # set to 'your-username/finance-domain-24m' to auto-push\n",
+ " num_epochs=3 if USE_GPU else 1,\n",
+ " per_device_batch_size=32 if USE_GPU else 4,\n",
+ " gradient_accumulation_steps=4 if USE_GPU else 1,\n",
+ " learning_rate=3e-4,\n",
+ " warmup_steps=200 if USE_GPU else 10,\n",
+ " logging_steps=50 if USE_GPU else 10,\n",
+ " save_steps=1000 if USE_GPU else 999999,\n",
+ " bf16=USE_GPU,\n",
+ " report_to='none',\n",
+ " seed=42,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 7 — Inspect Training Results"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Loss curve\n",
+ "losses = [h['loss'] for h in trainer.state.log_history if 'loss' in h]\n",
+ "\n",
+ "print(f'Steps: {trainer.state.global_step:,}')\n",
+ "print(f'Loss: {losses[0]:.4f} -> {losses[-1]:.4f} ({(1-losses[-1]/losses[0])*100:.1f}% reduction)')\n",
+ "print(f'Min loss: {min(losses):.4f}')\n",
+ "\n",
+ "fig, ax = plt.subplots(figsize=(10, 5))\n",
+ "ax.plot(losses, linewidth=0.5, alpha=0.5, label='Per-step')\n",
+ "window = max(len(losses) // 50, 1)\n",
+ "if len(losses) > window:\n",
+ " smoothed = pd.Series(losses).rolling(window=window, min_periods=1).mean()\n",
+ " ax.plot(smoothed, linewidth=2, color='red', label=f'Smoothed (w={window})')\n",
+ "ax.set_xlabel('Step')\n",
+ "ax.set_ylabel('Loss')\n",
+ "ax.set_title('Pre-Training Loss Curve')\n",
+ "ax.legend()\n",
+ "ax.grid(True, alpha=0.3)\n",
+ "plt.tight_layout()\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Next-token prediction test\n",
+ "model.eval()\n",
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
+ "model = model.to(device)\n",
+ "\n",
+ "test_tokens = builder.tokenize_sequence(user_sequences[0][:3])\n",
+ "test_ids = hf_tokenizer(' '.join(test_tokens), return_tensors='pt', add_special_tokens=False)['input_ids'].to(device)\n",
+ "\n",
+ "with torch.no_grad():\n",
+ " logits = model(input_ids=test_ids).logits\n",
+ " top5 = torch.topk(logits[0, -1, :], 5)\n",
+ "\n",
+ "print('Last 5 input tokens:')\n",
+ "for tid in test_ids[0, -5:]:\n",
+ " print(f\" {tid.item():5d} -> '{hf_tokenizer.decode([tid.item()])}'\")\n",
+ "\n",
+ "print('\\nTop-5 next token predictions:')\n",
+ "for score, tid in zip(top5.values, top5.indices):\n",
+ " print(f\" {tid.item():5d} -> '{hf_tokenizer.decode([tid.item()])}' (score={score.item():.3f})\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# User embedding visualization (t-SNE)\n",
+ "n_sample = min(200, len(user_sequences))\n",
+ "embeddings = []\n",
+ "labels_sample = []\n",
+ "\n",
+ "for i in range(n_sample):\n",
+ " tokens = builder.tokenize_sequence(user_sequences[i][:50])\n",
+ " enc = hf_tokenizer(' '.join(tokens), return_tensors='pt', add_special_tokens=False,\n",
+ " max_length=256, truncation=True, padding='max_length')\n",
+ " with torch.no_grad():\n",
+ " emb = model.get_user_embedding(enc['input_ids'].to(device), enc['attention_mask'].to(device))\n",
+ " embeddings.append(emb.cpu().numpy().flatten())\n",
+ " labels_sample.append(user_fraud_labels[i])\n",
+ "\n",
+ "embeddings = np.array(embeddings)\n",
+ "labels_sample = np.array(labels_sample)\n",
+ "print(f'Embeddings: {embeddings.shape}, Fraud: {labels_sample.sum()}/{len(labels_sample)}')\n",
+ "\n",
+ "if len(embeddings) >= 20:\n",
+ " from sklearn.manifold import TSNE\n",
+ " coords = TSNE(n_components=2, random_state=42, perplexity=min(30, len(embeddings)-1)).fit_transform(embeddings)\n",
+ " \n",
+ " fig, ax = plt.subplots(figsize=(8, 6))\n",
+ " for label, color, name in [(0, 'tab:green', 'Normal'), (1, 'tab:red', 'Fraud')]:\n",
+ " mask = labels_sample == label\n",
+ " ax.scatter(coords[mask, 0], coords[mask, 1], c=color, label=name, alpha=0.6, edgecolors='black', linewidth=0.3, s=30)\n",
+ " ax.set_title('User Embeddings (t-SNE) — Pre-trained DomainTransformer')\n",
+ " ax.legend()\n",
+ " plt.tight_layout()\n",
+ " plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Save Artifacts for Fine-Tuning Notebook\n",
+ "\n",
+ "Saves the pre-trained model, tokenizer, and user data so `02_finance_finetune.ipynb` can pick up where we left off."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Save tokenizer\n",
+ "hf_tokenizer.save_pretrained('./finance_tokenizer')\n",
+ "builder.save('./finance_tokenizer')\n",
+ "\n",
+ "# Save model\n",
+ "model.save_pretrained('./finance_pretrain_checkpoints/final')\n",
+ "\n",
+ "# Save user data\n",
+ "artifacts = {\n",
+ " 'user_sequences': user_sequences,\n",
+ " 'user_ids': user_ids,\n",
+ " 'user_fraud_labels': user_fraud_labels,\n",
+ "}\n",
+ "with open('./finance_artifacts.pkl', 'wb') as f:\n",
+ " pickle.dump(artifacts, f)\n",
+ "\n",
+ "print('Saved: tokenizer, model, user data')\n",
+ "print(f' ./finance_tokenizer/')\n",
+ "print(f' ./finance_pretrain_checkpoints/final/')\n",
+ "print(f' ./finance_artifacts.pkl')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Summary\n",
+ "\n",
+ "| Metric | Value |\n",
+ "|--------|-------|\n",
+ "| Dataset | Nigerian Financial Transactions (5M) |\n",
+ "| Users (5+ events) | *see output above* |\n",
+ "| Training tokens | *see output above* |\n",
+ "| Model | DomainTransformer 24M (NoPE, GPT-style) |\n",
+ "| Final loss | *see output above* |\n",
+ "| UNK rate | *see output above* |\n",
+ "\n",
+ "**Next:** `02_finance_finetune.ipynb` — Fine-tune for fraud detection with JointFusionModel, compare vs LightGBM."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.12.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+ }
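
For reference, the follow-up notebook `02_finance_finetune.ipynb` can reload everything the save cell writes with something like the sketch below. The paths come from the save cell above; `from_pretrained` and `AutoTokenizer` loading are assumed counterparts of the `save_pretrained` calls, not verified against the repo's code.

```python
# Hypothetical reload sketch for the fine-tuning notebook; paths match the save cell.
import pickle

from transformers import AutoTokenizer
from domain_tokenizer import DomainTransformerForCausalLM

# Assumed: from_pretrained mirrors the save_pretrained calls in the notebook above.
hf_tokenizer = AutoTokenizer.from_pretrained('./finance_tokenizer')
model = DomainTransformerForCausalLM.from_pretrained('./finance_pretrain_checkpoints/final')

# Pickled user sequences and labels, exactly as written by the save cell.
with open('./finance_artifacts.pkl', 'rb') as f:
    artifacts = pickle.load(f)
user_sequences = artifacts['user_sequences']
user_fraud_labels = artifacts['user_fraud_labels']
```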