rtferraz committed
Commit 2410b7e · verified · 1 Parent(s): a9c4a62

Update 02_ecommerce notebook: add HF login, memory-free cell, subsample option for <64GB RAM machines

Files changed (1)
  1. notebooks/02_ecommerce_pretrain.ipynb +99 -141
notebooks/02_ecommerce_pretrain.ipynb CHANGED
@@ -8,16 +8,11 @@
8
  "\n",
9
  "**Goal:** Pre-train a 24M DomainTransformer on real e-commerce behavioral sequences (view → cart → purchase funnels) where sequential patterns actually exist.\n",
10
  "\n",
11
- "**Dataset:** [REES46 Multi-Category Store](https://huggingface.co/datasets/kevykibbz/ecommerce-behavior-data-from-multi-category-store_oct-nov_2019) — ~42M events, real user behavior, Nov 2019.\n",
12
  "\n",
13
- "**Why this dataset after the finance experiment:**\n",
14
- "- The Nigerian finance dataset had only 84 unique descriptions and no sequential dependencies → loss plateaued at 6.9\n",
15
- "- REES46 has millions of products, rich category hierarchies, view→cart→purchase funnels, and diverse browsing patterns\n",
16
- "- This is where domain tokenization should genuinely prove itself\n",
17
  "\n",
18
- "**Lesson applied from finance report:** We run a **sequential entropy check** before training to verify there's learnable structure.\n",
19
- "\n",
20
- "**Hardware:** L4 GPU (24GB VRAM)"
21
  ]
22
  },
23
  {
@@ -33,7 +28,7 @@
33
  "metadata": {},
34
  "outputs": [],
35
  "source": [
36
- "# !pip install datasets transformers torch accelerate tokenizers numpy pandas matplotlib scikit-learn wandb"
37
  ]
38
  },
39
  {
@@ -42,7 +37,7 @@
42
  "metadata": {},
43
  "outputs": [],
44
  "source": [
45
- "import logging, time, pickle, os, sys\n",
46
  "from datetime import datetime\n",
47
  "from collections import Counter\n",
48
  "\n",
@@ -73,6 +68,18 @@
73
  "metadata": {},
74
  "outputs": [],
75
  "source": [
76
  "import wandb\n",
77
  "wandb.login()\n",
78
  "os.environ['WANDB_PROJECT'] = 'domainTokenizer'\n",
@@ -85,7 +92,7 @@
85
  "source": [
86
  "## Step 1 — Load Dataset\n",
87
  "\n",
88
- "42M events, 2GB. We load it all and then subsample users for manageable training."
89
  ]
90
  },
91
  {
@@ -99,18 +106,17 @@
99
  " 'kevykibbz/ecommerce-behavior-data-from-multi-category-store_oct-nov_2019',\n",
100
  " split='train',\n",
101
  ")\n",
102
- "print(f'Loaded: {len(ds):,} events, columns: {ds.column_names}')"
103
- ]
104
- },
105
- {
106
- "cell_type": "code",
107
- "execution_count": null,
108
- "metadata": {},
109
- "outputs": [],
110
- "source": [
111
  "df = ds.to_pandas()\n",
112
- "print(f'Shape: {df.shape}')\n",
113
- "df.head(3)"
114
  ]
115
  },
116
  {
@@ -133,7 +139,6 @@
133
  "print(f\"Price range: {df['price'].min():.2f} to {df['price'].max():.2f}\")\n",
134
  "print(f\"\\nEvent types:\\n{df['event_type'].value_counts().to_string()}\")\n",
135
  "print(f\"\\nCategory codes (top 15):\\n{df['category_code'].value_counts().head(15).to_string()}\")\n",
136
- "print(f\"\\nBrands (top 10):\\n{df['brand'].value_counts().head(10).to_string()}\")\n",
137
  "print(f\"\\nNull rates:\")\n",
138
  "for col in df.columns:\n",
139
  " null_pct = df[col].isnull().mean() * 100\n",
@@ -151,7 +156,6 @@
151
  " f\"mean={events_per_user.mean():.1f}, median={events_per_user.median():.0f}\")\n",
152
  "print(f\"Users with 10+ events: {(events_per_user >= 10).sum():,}\")\n",
153
  "print(f\"Users with 20+ events: {(events_per_user >= 20).sum():,}\")\n",
154
- "print(f\"Users with 50+ events: {(events_per_user >= 50).sum():,}\")\n",
155
  "\n",
156
  "fig, axes = plt.subplots(1, 3, figsize=(15, 4))\n",
157
  "axes[0].hist(np.log10(df['price'].clip(lower=0.01)), bins=50, edgecolor='black', alpha=0.7)\n",
@@ -169,9 +173,7 @@
169
  "source": [
170
  "## Step 3 — Sequential Entropy Check\n",
171
  "\n",
172
- "**Lesson from finance experiment:** Before committing GPU time, verify the data has learnable sequential patterns.\n",
173
- "\n",
174
- "We check: is `P(event_type_t | event_type_{t-1})` different from `P(event_type_t)`? If yes → sequential structure exists."
175
  ]
176
  },
177
  {
@@ -180,25 +182,19 @@
180
  "metadata": {},
181
  "outputs": [],
182
  "source": [
183
- "# Compute bigram transition probabilities for event_type\n",
184
- "# Sample 50K users for speed\n",
185
  "sample_users = df['user_id'].drop_duplicates().sample(min(50000, df['user_id'].nunique()), random_state=42)\n",
186
  "sample_df = df[df['user_id'].isin(sample_users)].sort_values(['user_id', 'event_time'])\n",
187
  "\n",
188
- "# Unigram distribution\n",
189
  "unigram = sample_df['event_type'].value_counts(normalize=True)\n",
190
  "H_unigram = -(unigram * np.log2(unigram)).sum()\n",
191
  "\n",
192
- "# Bigram transitions within each user\n",
193
- "bigrams = Counter()\n",
194
- "prev_context = Counter()\n",
195
  "for uid, group in sample_df.groupby('user_id'):\n",
196
  " events = group['event_type'].tolist()\n",
197
  " for i in range(1, len(events)):\n",
198
  " bigrams[(events[i-1], events[i])] += 1\n",
199
  " prev_context[events[i-1]] += 1\n",
200
  "\n",
201
- "# Conditional entropy H(event_t | event_{t-1})\n",
202
  "H_conditional = 0\n",
203
  "total_bigrams = sum(bigrams.values())\n",
204
  "for (prev, curr), count in bigrams.items():\n",
@@ -207,21 +203,17 @@
207
  " H_conditional -= p_joint * np.log2(p_cond)\n",
208
  "\n",
209
  "mutual_info = H_unigram - H_conditional\n",
210
- "print(f'Marginal entropy H(event_type): {H_unigram:.3f} bits')\n",
211
- "print(f'Conditional entropy H(event_type | prev): {H_conditional:.3f} bits')\n",
212
- "print(f'Mutual information I(t; t-1): {mutual_info:.3f} bits')\n",
213
- "print(f'Predictability gain: {mutual_info/H_unigram*100:.1f}%')\n",
214
- "print(f'\\n{\"✅ Sequential structure detected\" if mutual_info > 0.1 else \"⚠️ Weak sequential structure\"}')\n",
215
- "\n",
216
- "# Show transition matrix\n",
217
- "print(f'\\nTransition probabilities (P(next | current)):')\n",
218
- "event_types = sorted(unigram.index)\n",
219
- "for prev in event_types:\n",
220
- " transitions = {}\n",
221
- " for curr in event_types:\n",
222
- " if prev_context[prev] > 0:\n",
223
- " transitions[curr] = bigrams.get((prev, curr), 0) / prev_context[prev]\n",
224
- " row = ' | '.join(f'{curr}: {p:.2f}' for curr, p in sorted(transitions.items(), key=lambda x: -x[1]) if p > 0.01)\n",
225
  " print(f' After {prev:20s} → {row}')"
226
  ]
227
  },
@@ -229,9 +221,7 @@
229
  "cell_type": "markdown",
230
  "metadata": {},
231
  "source": [
232
- "## Step 4 — Build E-Commerce Schema and Convert Events\n",
233
- "\n",
234
- "Custom schema for this dataset — maps directly to the available columns."
235
  ]
236
  },
237
  {
@@ -240,7 +230,6 @@
240
  "metadata": {},
241
  "outputs": [],
242
  "source": [
243
- "# Custom schema matching REES46 columns\n",
244
  "ECOMMERCE_REES46_SCHEMA = DomainSchema(\n",
245
  " name='ecommerce_rees46',\n",
246
  " description='REES46 e-commerce behavioral event schema',\n",
@@ -248,12 +237,10 @@
248
  " FieldSpec(name='event_type', field_type=FieldType.CATEGORICAL_FIXED, prefix='EVT',\n",
249
  " categories=['view', 'cart', 'remove_from_cart', 'purchase']),\n",
250
  " FieldSpec(name='price', field_type=FieldType.NUMERICAL_CONTINUOUS, prefix='PRICE', n_bins=21),\n",
251
- " FieldSpec(name='category', field_type=FieldType.TEXT, prefix='CAT'), # hierarchical text like 'electronics.smartphone'\n",
252
- " FieldSpec(name='timestamp', field_type=FieldType.TEMPORAL,\n",
253
- " calendar_fields=['dow', 'hour']), # dow + hour capture shopping patterns\n",
254
  " ],\n",
255
  ")\n",
256
- "\n",
257
  "print(ECOMMERCE_REES46_SCHEMA.summary())"
258
  ]
259
  },
@@ -265,16 +252,10 @@
265
  "source": [
266
  "def row_to_event(row):\n",
267
  " dt = datetime.strptime(row['event_time'][:19], '%Y-%m-%dT%H:%M:%S')\n",
268
- " # Use category_code if available, else brand, else 'unknown'\n",
269
  " cat = row['category_code'] if pd.notna(row['category_code']) else (row['brand'] if pd.notna(row['brand']) else 'unknown')\n",
270
- " return {\n",
271
- " 'event_type': row['event_type'],\n",
272
- " 'price': row['price'],\n",
273
- " 'category': cat,\n",
274
- " 'timestamp': dt,\n",
275
- " }\n",
276
  "\n",
277
- "print(f'Sample event: {row_to_event(df.iloc[0])}')"
278
  ]
279
  },
280
  {
@@ -284,42 +265,48 @@
284
  "outputs": [],
285
  "source": [
286
  "%%time\n",
287
- "# Subsample: take users with 10+ events, cap at 100K users for training speed\n",
288
  "MIN_EVENTS = 10\n",
289
  "MAX_EVENTS = 200\n",
290
  "MAX_USERS = 100_000\n",
291
  "\n",
292
- "# Pre-filter to users with enough events\n",
293
  "user_counts = df.groupby('user_id').size()\n",
294
  "eligible_users = user_counts[user_counts >= MIN_EVENTS].index\n",
295
  "print(f'Users with {MIN_EVENTS}+ events: {len(eligible_users):,}')\n",
296
  "\n",
297
- "# Subsample if needed\n",
298
  "if len(eligible_users) > MAX_USERS:\n",
299
  " eligible_users = pd.Series(eligible_users).sample(MAX_USERS, random_state=42).values\n",
300
  " print(f'Subsampled to {MAX_USERS:,} users')\n",
301
  "\n",
302
- "# Build user sequences\n",
303
- "user_sequences = []\n",
304
- "user_ids = []\n",
305
  "filtered_df = df[df['user_id'].isin(eligible_users)].sort_values(['user_id', 'event_time'])\n",
306
  "\n",
 
307
  "for uid, group in filtered_df.groupby('user_id'):\n",
308
- " events = [row_to_event(row) for _, row in group.head(MAX_EVENTS).iterrows()]\n",
309
- " user_sequences.append(events)\n",
310
  " user_ids.append(uid)\n",
311
  "\n",
312
- "total_events = sum(len(s) for s in user_sequences)\n",
313
- "print(f'Users: {len(user_sequences):,}')\n",
314
- "print(f'Total events: {total_events:,}')\n",
315
  "print(f'Events/user: min={min(len(s) for s in user_sequences)}, max={max(len(s) for s in user_sequences)}, mean={np.mean([len(s) for s in user_sequences]):.1f}')"
316
  ]
317
  },
318
  {
319
  "cell_type": "markdown",
320
  "metadata": {},
321
  "source": [
322
- "## Step 5 — Build Domain Tokenizer"
323
  ]
324
  },
325
  {
@@ -334,16 +321,16 @@
334
  "builder = DomainTokenizerBuilder(ECOMMERCE_REES46_SCHEMA)\n",
335
  "builder.fit(all_events)\n",
336
  "\n",
337
- "# Text corpus = category codes (hierarchical) — rich BPE vocabulary\n",
338
  "text_corpus = [e['category'] for e in all_events]\n",
339
  "unique_cats = sorted(set(text_corpus))\n",
340
  "print(f'Unique category strings: {len(unique_cats):,}')\n",
341
- "for c in unique_cats[:15]: print(f\" '{c}'\")\n",
342
- "if len(unique_cats) > 15: print(f' ... and {len(unique_cats)-15} more')\n",
343
  "\n",
344
  "hf_tokenizer = builder.build(text_corpus=text_corpus, bpe_vocab_size=4000)\n",
345
  "print(f'\\nVocab size: {hf_tokenizer.vocab_size}')\n",
346
- "print(f'Stats: {builder.get_stats()}')"
 
347
  ]
348
  },
349
  {
@@ -352,17 +339,17 @@
352
  "metadata": {},
353
  "outputs": [],
354
  "source": [
355
- "# Inspect tokenization\n",
356
  "print('--- Sample event ---')\n",
357
  "for i, t in enumerate(builder.tokenize_event(user_sequences[0][0])): print(f' [{i}] {t}')\n",
358
  "\n",
359
- "print(f'\\n--- First user, first 5 events ---')\n",
360
- "seq_tokens = builder.tokenize_sequence(user_sequences[0][:5])\n",
361
  "for i, t in enumerate(seq_tokens): print(f' [{i:3d}] {t}')\n",
362
  "\n",
363
  "seq_ids = hf_tokenizer(' '.join(seq_tokens), add_special_tokens=False)['input_ids']\n",
364
  "unk_count = sum(1 for i in seq_ids if i == hf_tokenizer.unk_token_id)\n",
365
- "print(f'\\nUNK rate: {unk_count}/{len(seq_ids)} ({unk_count/max(len(seq_ids),1)*100:.1f}%)')"
366
  ]
367
  },
368
  {
@@ -390,12 +377,13 @@
390
  "metadata": {},
391
  "outputs": [],
392
  "source": [
393
- "print(f'Sample block: {hf_tokenizer.decode(dataset[0][\"input_ids\"][:50])}')\n",
 
394
  "\n",
395
- "all_ids = [i for row in dataset for i in row['input_ids']]\n",
396
- "counts = Counter(all_ids)\n",
397
- "unk_id = hf_tokenizer.unk_token_id\n",
398
- "print(f'\\nTokens: {len(all_ids):,}, Unique: {len(counts)}/{hf_tokenizer.vocab_size}, UNK: {counts.get(unk_id,0)} ({counts.get(unk_id,0)/len(all_ids)*100:.2f}%)')"
399
  ]
400
  },
401
  {
@@ -417,6 +405,12 @@
417
  "source": [
418
  "%%time\n",
419
  "USE_GPU = torch.cuda.is_available()\n",
420
  "\n",
421
  "trainer = pretrain_domain_model(\n",
422
  " model=model,\n",
@@ -431,7 +425,8 @@
431
  " warmup_steps=200 if USE_GPU else 10,\n",
432
  " logging_steps=50 if USE_GPU else 10,\n",
433
  " save_steps=2000 if USE_GPU else 999999,\n",
434
- " bf16=USE_GPU,\n",
 
435
  " report_to='wandb',\n",
436
  " run_name='ecommerce-pretrain-24m-3ep',\n",
437
  " seed=42,\n",
@@ -456,7 +451,6 @@
456
  "print(f'Loss: {losses[0]:.4f} -> {losses[-1]:.4f} ({(1-losses[-1]/losses[0])*100:.1f}% reduction)')\n",
457
  "print(f'Min loss: {min(losses):.4f}')\n",
458
  "\n",
459
- "# Compare to random chance: -ln(1/vocab_size)\n",
460
  "random_loss = np.log(hf_tokenizer.vocab_size)\n",
461
  "print(f'Random chance loss: {random_loss:.4f}')\n",
462
  "print(f'Model vs random: {\"✅ Better\" if losses[-1] < random_loss else \"❌ Worse\"} ({losses[-1]:.2f} vs {random_loss:.2f})')\n",
@@ -465,8 +459,8 @@
465
  "ax.plot(losses, linewidth=0.5, alpha=0.5, label='Per-step')\n",
466
  "window = max(len(losses) // 50, 1)\n",
467
  "if len(losses) > window:\n",
468
- " ax.plot(pd.Series(losses).rolling(window=window, min_periods=1).mean(), linewidth=2, color='red', label=f'Smoothed')\n",
469
- "ax.axhline(y=random_loss, color='gray', linestyle='--', label=f'Random chance ({random_loss:.2f})')\n",
470
  "ax.set_xlabel('Step'); ax.set_ylabel('Loss'); ax.set_title('E-Commerce Pre-Training Loss')\n",
471
  "ax.legend(); ax.grid(True, alpha=0.3); plt.tight_layout(); plt.show()"
472
  ]
@@ -477,7 +471,6 @@
477
  "metadata": {},
478
  "outputs": [],
479
  "source": [
480
- "# Next-token predictions\n",
481
  "model.eval()\n",
482
  "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
483
  "model = model.to(device)\n",
@@ -499,9 +492,9 @@
499
  "metadata": {},
500
  "outputs": [],
501
  "source": [
502
- "# User embeddings — can we see behavioral clusters?\n",
503
  "n_sample = min(500, len(user_sequences))\n",
504
- "embeddings, event_counts, purchase_rates = [], [], []\n",
505
  "\n",
506
  "for i in range(n_sample):\n",
507
  " enc = hf_tokenizer(' '.join(builder.tokenize_sequence(user_sequences[i][:50])),\n",
@@ -509,30 +502,24 @@
509
  " with torch.no_grad():\n",
510
  " embeddings.append(model.get_user_embedding(enc['input_ids'].to(device), enc['attention_mask'].to(device)).cpu().numpy().flatten())\n",
511
  " events = user_sequences[i]\n",
512
- " event_counts.append(len(events))\n",
513
  " purchase_rates.append(sum(1 for e in events if e['event_type'] == 'purchase') / len(events))\n",
514
  "\n",
515
- "embeddings = np.array(embeddings)\n",
516
- "purchase_rates = np.array(purchase_rates)\n",
517
  "print(f'Embeddings: {embeddings.shape}')\n",
518
  "\n",
519
- "if len(embeddings) >= 20:\n",
520
- " from sklearn.manifold import TSNE\n",
521
- " coords = TSNE(n_components=2, random_state=42, perplexity=30).fit_transform(embeddings)\n",
522
- " \n",
523
- " fig, axes = plt.subplots(1, 2, figsize=(14, 6))\n",
524
- " sc1 = axes[0].scatter(coords[:, 0], coords[:, 1], c=purchase_rates, cmap='RdYlGn', alpha=0.6, s=20)\n",
525
- " axes[0].set_title('User Embeddings — Colored by Purchase Rate'); plt.colorbar(sc1, ax=axes[0], label='Purchase Rate')\n",
526
- " sc2 = axes[1].scatter(coords[:, 0], coords[:, 1], c=np.log10(np.array(event_counts[:n_sample])), cmap='viridis', alpha=0.6, s=20)\n",
527
- " axes[1].set_title('User Embeddings — Colored by Activity Level'); plt.colorbar(sc2, ax=axes[1], label='log10(Events)')\n",
528
- " plt.tight_layout(); plt.show()"
529
  ]
530
  },
531
  {
532
  "cell_type": "markdown",
533
  "metadata": {},
534
  "source": [
535
- "## Save Artifacts"
536
  ]
537
  },
538
  {
@@ -544,40 +531,11 @@
544
  "hf_tokenizer.save_pretrained('./ecommerce_tokenizer')\n",
545
  "builder.save('./ecommerce_tokenizer')\n",
546
  "model.save_pretrained('./ecommerce_pretrain_checkpoints/final')\n",
547
- "\n",
548
  "with open('./ecommerce_artifacts.pkl', 'wb') as f:\n",
549
  " pickle.dump({'user_sequences': user_sequences, 'user_ids': user_ids}, f)\n",
550
- "\n",
551
- "print('Saved: ./ecommerce_tokenizer/, ./ecommerce_pretrain_checkpoints/final/, ./ecommerce_artifacts.pkl')"
552
- ]
553
- },
554
- {
555
- "cell_type": "code",
556
- "execution_count": null,
557
- "metadata": {},
558
- "outputs": [],
559
- "source": [
560
  "wandb.finish()"
561
  ]
562
- },
563
- {
564
- "cell_type": "markdown",
565
- "metadata": {},
566
- "source": [
567
- "## Summary\n",
568
- "\n",
569
- "| Metric | Value |\n",
570
- "|--------|-------|\n",
571
- "| Dataset | REES46 E-Commerce (42M events) |\n",
572
- "| Users (sampled) | *see above* |\n",
573
- "| Training tokens | *see above* |\n",
574
- "| Sequential entropy gain | *see Step 3* |\n",
575
- "| Model | DomainTransformer 24M (NoPE) |\n",
576
- "| Final loss | *see above* |\n",
577
- "| Loss vs random chance | *see above* |\n",
578
- "\n",
579
- "**Next:** `03_ecommerce_finetune.ipynb` — Fine-tune for next-purchase prediction."
580
- ]
581
  }
582
  ],
583
  "metadata": {
 
8
  "\n",
9
  "**Goal:** Pre-train a 24M DomainTransformer on real e-commerce behavioral sequences (view → cart → purchase funnels) where sequential patterns actually exist.\n",
10
  "\n",
11
+ "**Dataset:** [REES46 Multi-Category Store](https://huggingface.co/datasets/kevykibbz/ecommerce-behavior-data-from-multi-category-store_oct-nov_2019) — ~110M events, real user behavior.\n",
12
  "\n",
13
+ "**Critical fix applied:** Uses Whitespace pre-tokenizer (not ByteLevel) to avoid the 42% UNK bug from the first run.\n",
14
  "\n",
15
+ "**Hardware:** L4/T4 GPU. For machines with <64GB RAM, we subsample to 10M events."
16
  ]
17
  },
18
  {
 
28
  "metadata": {},
29
  "outputs": [],
30
  "source": [
31
+ "# !pip install datasets transformers torch accelerate tokenizers numpy pandas matplotlib scikit-learn wandb huggingface_hub"
32
  ]
33
  },
34
  {
 
37
  "metadata": {},
38
  "outputs": [],
39
  "source": [
40
+ "import logging, time, pickle, os, sys, gc\n",
41
  "from datetime import datetime\n",
42
  "from collections import Counter\n",
43
  "\n",
 
68
  "metadata": {},
69
  "outputs": [],
70
  "source": [
71
+ "# HuggingFace login — needed for push_to_hub\n",
72
+ "from huggingface_hub import login\n",
73
+ "login() # prompts for token or reads from ~/.cache/huggingface/token"
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "code",
78
+ "execution_count": null,
79
+ "metadata": {},
80
+ "outputs": [],
81
+ "source": [
82
+ "# wandb login — logs persist even if notebook disconnects\n",
83
  "import wandb\n",
84
  "wandb.login()\n",
85
  "os.environ['WANDB_PROJECT'] = 'domainTokenizer'\n",
 
92
  "source": [
93
  "## Step 1 — Load Dataset\n",
94
  "\n",
95
+ "**IMPORTANT:** Full dataset is 110M rows (~25GB in RAM). If your machine has <64GB RAM, subsample at load time."
96
  ]
97
  },
98
  {
 
106
  " 'kevykibbz/ecommerce-behavior-data-from-multi-category-store_oct-nov_2019',\n",
107
  " split='train',\n",
108
  ")\n",
109
+ "print(f'Full dataset: {len(ds):,} events')\n",
110
+ "\n",
111
+ "# Subsample to 10M events if RAM < 64GB (most machines)\n",
112
+ "MAX_EVENTS_LOAD = 10_000_000\n",
113
+ "if len(ds) > MAX_EVENTS_LOAD:\n",
114
+ " ds = ds.shuffle(seed=42).select(range(MAX_EVENTS_LOAD))\n",
115
+ " print(f'Subsampled to {len(ds):,} events (RAM-safe)')\n",
116
+ "\n",
 
117
  "df = ds.to_pandas()\n",
118
+ "del ds; gc.collect() # free the Arrow dataset\n",
119
+ "print(f'DataFrame: {df.shape}, ~{df.memory_usage(deep=True).sum()/1e9:.1f}GB RAM')"
120
  ]
121
  },
122
  {
 
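If even the 10M-row subsample strains RAM, streaming is a lower-footprint alternative. A minimal sketch, assuming the Hub repo supports streaming (parquet/CSV-backed datasets do); `shuffle(buffer_size=...)` and `take()` are lazy on an `IterableDataset`, so only the sampled rows are ever materialized:

```python
from datasets import load_dataset
import pandas as pd

# Stream instead of downloading the full Arrow table.
ds_stream = load_dataset(
    'kevykibbz/ecommerce-behavior-data-from-multi-category-store_oct-nov_2019',
    split='train',
    streaming=True,
)

# Buffered shuffle + lazy take: only the 1M sampled rows are held in memory,
# not the ~25GB full table. The buffered shuffle is approximate, which is
# usually acceptable for a pre-training subsample.
rows = list(ds_stream.shuffle(seed=42, buffer_size=100_000).take(1_000_000))
df_small = pd.DataFrame(rows)
print(df_small.shape)
```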
139
  "print(f\"Price range: {df['price'].min():.2f} to {df['price'].max():.2f}\")\n",
140
  "print(f\"\\nEvent types:\\n{df['event_type'].value_counts().to_string()}\")\n",
141
  "print(f\"\\nCategory codes (top 15):\\n{df['category_code'].value_counts().head(15).to_string()}\")\n",
 
142
  "print(f\"\\nNull rates:\")\n",
143
  "for col in df.columns:\n",
144
  " null_pct = df[col].isnull().mean() * 100\n",
 
156
  " f\"mean={events_per_user.mean():.1f}, median={events_per_user.median():.0f}\")\n",
157
  "print(f\"Users with 10+ events: {(events_per_user >= 10).sum():,}\")\n",
158
  "print(f\"Users with 20+ events: {(events_per_user >= 20).sum():,}\")\n",
 
159
  "\n",
160
  "fig, axes = plt.subplots(1, 3, figsize=(15, 4))\n",
161
  "axes[0].hist(np.log10(df['price'].clip(lower=0.01)), bins=50, edgecolor='black', alpha=0.7)\n",
 
173
  "source": [
174
  "## Step 3 — Sequential Entropy Check\n",
175
  "\n",
176
+ "Verify learnable sequential patterns exist before committing GPU time."
177
  ]
178
  },
179
  {
 
182
  "metadata": {},
183
  "outputs": [],
184
  "source": [
185
  "sample_users = df['user_id'].drop_duplicates().sample(min(50000, df['user_id'].nunique()), random_state=42)\n",
186
  "sample_df = df[df['user_id'].isin(sample_users)].sort_values(['user_id', 'event_time'])\n",
187
  "\n",
 
188
  "unigram = sample_df['event_type'].value_counts(normalize=True)\n",
189
  "H_unigram = -(unigram * np.log2(unigram)).sum()\n",
190
  "\n",
191
+ "bigrams, prev_context = Counter(), Counter()\n",
192
  "for uid, group in sample_df.groupby('user_id'):\n",
193
  " events = group['event_type'].tolist()\n",
194
  " for i in range(1, len(events)):\n",
195
  " bigrams[(events[i-1], events[i])] += 1\n",
196
  " prev_context[events[i-1]] += 1\n",
197
  "\n",
 
198
  "H_conditional = 0\n",
199
  "total_bigrams = sum(bigrams.values())\n",
200
  "for (prev, curr), count in bigrams.items():\n",
 
203
  " H_conditional -= p_joint * np.log2(p_cond)\n",
204
  "\n",
205
  "mutual_info = H_unigram - H_conditional\n",
206
+ "print(f'H(event_type): {H_unigram:.3f} bits')\n",
207
+ "print(f'H(event_type | prev): {H_conditional:.3f} bits')\n",
208
+ "print(f'Mutual info: {mutual_info:.3f} bits ({mutual_info/H_unigram*100:.1f}% predictability gain)')\n",
209
+ "print(f'\\nNote: This only measures event_type transitions. The model also learns')\n",
210
+ "print(f'category, price, and temporal patterns much richer sequential structure.')\n",
211
+ "\n",
212
+ "print(f'\\nTransition probabilities:')\n",
213
+ "for prev in sorted(unigram.index):\n",
214
+ " trans = {curr: bigrams.get((prev, curr), 0) / prev_context[prev] \n",
215
+ " for curr in sorted(unigram.index) if bigrams.get((prev, curr), 0) > 0}\n",
216
+ " row = ' | '.join(f'{c}: {p:.2f}' for c, p in sorted(trans.items(), key=lambda x: -x[1]))\n",
217
  " print(f' After {prev:20s} → {row}')"
218
  ]
219
  },
 
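To make the entropy check concrete, here is a self-contained toy version of the same computation, independent of the notebook's variables. A perfectly periodic view → cart → purchase funnel drives the conditional entropy to 0, so the mutual information equals the marginal entropy log2(3) ≈ 1.585 bits:

```python
import numpy as np
from collections import Counter

# Deterministic funnel: the next event is fully determined by the previous one.
events = ['view', 'cart', 'purchase'] * 1000

counts = Counter(events)
total = sum(counts.values())
H_unigram = -sum((c / total) * np.log2(c / total) for c in counts.values())

bigrams, prev_context = Counter(), Counter()
for prev, curr in zip(events, events[1:]):
    bigrams[(prev, curr)] += 1
    prev_context[prev] += 1

total_bigrams = sum(bigrams.values())
H_cond = 0.0
for (prev, curr), count in bigrams.items():
    p_joint = count / total_bigrams
    p_cond = count / prev_context[prev]   # always 1.0 here
    H_cond -= p_joint * np.log2(p_cond)

print(f'H = {H_unigram:.3f} bits, H(next|prev) = {H_cond:.3f} bits, '
      f'I = {H_unigram - H_cond:.3f} bits')  # ~1.585, 0.000, ~1.585
```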
221
  "cell_type": "markdown",
222
  "metadata": {},
223
  "source": [
224
+ "## Step 4 — Schema & Event Conversion"
225
  ]
226
  },
227
  {
 
230
  "metadata": {},
231
  "outputs": [],
232
  "source": [
 
233
  "ECOMMERCE_REES46_SCHEMA = DomainSchema(\n",
234
  " name='ecommerce_rees46',\n",
235
  " description='REES46 e-commerce behavioral event schema',\n",
 
237
  " FieldSpec(name='event_type', field_type=FieldType.CATEGORICAL_FIXED, prefix='EVT',\n",
238
  " categories=['view', 'cart', 'remove_from_cart', 'purchase']),\n",
239
  " FieldSpec(name='price', field_type=FieldType.NUMERICAL_CONTINUOUS, prefix='PRICE', n_bins=21),\n",
240
+ " FieldSpec(name='category', field_type=FieldType.TEXT, prefix='CAT'),\n",
241
+ " FieldSpec(name='timestamp', field_type=FieldType.TEMPORAL, calendar_fields=['dow', 'hour']),\n",
 
242
  " ],\n",
243
  ")\n",
 
244
  "print(ECOMMERCE_REES46_SCHEMA.summary())"
245
  ]
246
  },
 
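`n_bins=21` on the price field implies a discretization step inside `DomainTokenizerBuilder`. The library's exact scheme is not shown in this diff; a minimal quantile-binning sketch of the idea (function and token names below are illustrative, not the library's API):

```python
import numpy as np

def fit_price_bins(prices: np.ndarray, n_bins: int = 21) -> np.ndarray:
    # Quantile edges give roughly equal-frequency bins, which is robust to
    # the heavy right tail typical of e-commerce prices.
    return np.quantile(prices, np.linspace(0, 1, n_bins + 1))

def price_token(price: float, edges: np.ndarray, prefix: str = 'PRICE') -> str:
    # searchsorted maps a value to its bin; clip keeps outliers in range.
    idx = int(np.clip(np.searchsorted(edges, price, side='right') - 1,
                      0, len(edges) - 2))
    return f'{prefix}_{idx:02d}'

edges = fit_price_bins(np.random.lognormal(3, 1, 100_000))
print(price_token(19.99, edges), price_token(1500.0, edges))
```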
252
  "source": [
253
  "def row_to_event(row):\n",
254
  " dt = datetime.strptime(row['event_time'][:19], '%Y-%m-%dT%H:%M:%S')\n",
 
255
  " cat = row['category_code'] if pd.notna(row['category_code']) else (row['brand'] if pd.notna(row['brand']) else 'unknown')\n",
256
+ " return {'event_type': row['event_type'], 'price': row['price'], 'category': cat, 'timestamp': dt}\n",
257
  "\n",
258
+ "print(f'Sample: {row_to_event(df.iloc[0])}')"
259
  ]
260
  },
261
  {
 
265
  "outputs": [],
266
  "source": [
267
  "%%time\n",
 
268
  "MIN_EVENTS = 10\n",
269
  "MAX_EVENTS = 200\n",
270
  "MAX_USERS = 100_000\n",
271
  "\n",
 
272
  "user_counts = df.groupby('user_id').size()\n",
273
  "eligible_users = user_counts[user_counts >= MIN_EVENTS].index\n",
274
  "print(f'Users with {MIN_EVENTS}+ events: {len(eligible_users):,}')\n",
275
  "\n",
 
276
  "if len(eligible_users) > MAX_USERS:\n",
277
  " eligible_users = pd.Series(eligible_users).sample(MAX_USERS, random_state=42).values\n",
278
  " print(f'Subsampled to {MAX_USERS:,} users')\n",
279
  "\n",
280
  "filtered_df = df[df['user_id'].isin(eligible_users)].sort_values(['user_id', 'event_time'])\n",
281
  "\n",
282
+ "user_sequences, user_ids = [], []\n",
283
  "for uid, group in filtered_df.groupby('user_id'):\n",
284
+ " user_sequences.append([row_to_event(row) for _, row in group.head(MAX_EVENTS).iterrows()])\n",
 
285
  " user_ids.append(uid)\n",
286
  "\n",
287
+ "print(f'Users: {len(user_sequences):,}, Events: {sum(len(s) for s in user_sequences):,}')\n",
288
  "print(f'Events/user: min={min(len(s) for s in user_sequences)}, max={max(len(s) for s in user_sequences)}, mean={np.mean([len(s) for s in user_sequences]):.1f}')"
289
  ]
290
  },
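`iterrows()` constructs a pandas Series per row and dominates this cell's runtime at 100K users. A possible drop-in speed-up, assuming the notebook's `filtered_df`, `MAX_EVENTS`, and column names, is `itertuples()`, which yields cheap namedtuples instead:

```python
from datetime import datetime
import pandas as pd

def row_to_event_fast(t) -> dict:
    # t is a namedtuple from itertuples(); attribute access avoids the
    # per-row Series construction that makes iterrows() slow.
    dt = datetime.strptime(t.event_time[:19], '%Y-%m-%dT%H:%M:%S')
    cat = t.category_code if pd.notna(t.category_code) else \
          (t.brand if pd.notna(t.brand) else 'unknown')
    return {'event_type': t.event_type, 'price': t.price,
            'category': cat, 'timestamp': dt}

user_sequences, user_ids = [], []
for uid, group in filtered_df.groupby('user_id'):
    user_sequences.append([row_to_event_fast(t)
                           for t in group.head(MAX_EVENTS).itertuples()])
    user_ids.append(uid)
```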
291
+ {
292
+ "cell_type": "code",
293
+ "execution_count": null,
294
+ "metadata": {},
295
+ "outputs": [],
296
+ "source": [
297
+ "# Free the big DataFrame — we only need user_sequences from here\n",
298
+ "del df, filtered_df, user_counts, eligible_users\n",
299
+ "gc.collect()\n",
300
+ "print('DataFrame freed from memory')"
301
+ ]
302
+ },
303
  {
304
  "cell_type": "markdown",
305
  "metadata": {},
306
  "source": [
307
+ "## Step 5 — Build Domain Tokenizer\n",
308
+ "\n",
309
+ "**Uses Whitespace pre-tokenizer** (fix for the 42% UNK bug from ByteLevel)."
310
  ]
311
  },
312
  {
 
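The pre-tokenizer choice matters because ByteLevel rewrites words into byte-level pieces, so whole-word schema tokens like `EVT_view` never match the vocabulary entries added for them and come back as UNK. A minimal sketch of the whitespace-based setup with the `tokenizers` library (illustrative; the real construction lives inside `DomainTokenizerBuilder.build`):

```python
from tokenizers import Tokenizer, models, pre_tokenizers, trainers

tok = Tokenizer(models.BPE(unk_token='[UNK]'))
# Whitespace pre-tokenization keeps underscore-joined schema tokens such as
# EVT_view or PRICE_07 intact as single words (ByteLevel would not).
tok.pre_tokenizer = pre_tokenizers.Whitespace()

trainer = trainers.BpeTrainer(vocab_size=4000,
                              special_tokens=['[UNK]', '[PAD]'])
tok.train_from_iterator(['EVT_view PRICE_07 CAT_electronics.smartphone'],
                        trainer)
print(tok.encode('EVT_view PRICE_07').tokens)
```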
321
  "builder = DomainTokenizerBuilder(ECOMMERCE_REES46_SCHEMA)\n",
322
  "builder.fit(all_events)\n",
323
  "\n",
 
324
  "text_corpus = [e['category'] for e in all_events]\n",
325
  "unique_cats = sorted(set(text_corpus))\n",
326
  "print(f'Unique category strings: {len(unique_cats):,}')\n",
327
+ "for c in unique_cats[:10]: print(f\" '{c}'\")\n",
328
+ "if len(unique_cats) > 10: print(f' ... and {len(unique_cats)-10} more')\n",
329
  "\n",
330
  "hf_tokenizer = builder.build(text_corpus=text_corpus, bpe_vocab_size=4000)\n",
331
  "print(f'\\nVocab size: {hf_tokenizer.vocab_size}')\n",
332
+ "\n",
333
+ "del all_events, text_corpus; gc.collect() # free fitting data"
334
  ]
335
  },
336
  {
 
339
  "metadata": {},
340
  "outputs": [],
341
  "source": [
342
+ "# Verify UNK rate is now ~0% (was 42% with ByteLevel bug)\n",
343
  "print('--- Sample event ---')\n",
344
  "for i, t in enumerate(builder.tokenize_event(user_sequences[0][0])): print(f' [{i}] {t}')\n",
345
  "\n",
346
+ "seq_tokens = builder.tokenize_sequence(user_sequences[0][:3])\n",
347
+ "print(f'\\n--- First 3 events ({len(seq_tokens)} token strings) ---')\n",
348
  "for i, t in enumerate(seq_tokens): print(f' [{i:3d}] {t}')\n",
349
  "\n",
350
  "seq_ids = hf_tokenizer(' '.join(seq_tokens), add_special_tokens=False)['input_ids']\n",
351
  "unk_count = sum(1 for i in seq_ids if i == hf_tokenizer.unk_token_id)\n",
352
+ "print(f'\\n✅ UNK rate: {unk_count}/{len(seq_ids)} ({unk_count/max(len(seq_ids),1)*100:.1f}%) — should be ~0%')"
353
  ]
354
  },
355
  {
 
377
  "metadata": {},
378
  "outputs": [],
379
  "source": [
380
+ "print(f'Sample block decoded:')\n",
381
+ "print(hf_tokenizer.decode(dataset[0]['input_ids'][:40]))\n",
382
  "\n",
383
+ "# Quick UNK check on full dataset\n",
384
+ "sample_ids = dataset[0]['input_ids'] + dataset[len(dataset)//2]['input_ids'] + dataset[-1]['input_ids']\n",
385
+ "unk_in_sample = sum(1 for i in sample_ids if i == hf_tokenizer.unk_token_id)\n",
386
+ "print(f'\\nUNK in 3 sample blocks: {unk_in_sample}/{len(sample_ids)} ({unk_in_sample/len(sample_ids)*100:.2f}%)')"
387
  ]
388
  },
389
  {
 
405
  "source": [
406
  "%%time\n",
407
  "USE_GPU = torch.cuda.is_available()\n",
408
+ "# T4 doesn't support bf16 — use fp16 instead\n",
409
+ "GPU_NAME = torch.cuda.get_device_name(0) if USE_GPU else ''\n",
410
+ "USE_BF16 = USE_GPU and 'T4' not in GPU_NAME # L4, A100 support bf16\n",
411
+ "USE_FP16 = USE_GPU and not USE_BF16 # T4 uses fp16\n",
412
+ "\n",
413
+ "print(f'Precision: {\"bf16\" if USE_BF16 else \"fp16\" if USE_FP16 else \"fp32\"}')\n",
414
  "\n",
415
  "trainer = pretrain_domain_model(\n",
416
  " model=model,\n",
 
425
  " warmup_steps=200 if USE_GPU else 10,\n",
426
  " logging_steps=50 if USE_GPU else 10,\n",
427
  " save_steps=2000 if USE_GPU else 999999,\n",
428
+ " bf16=USE_BF16,\n",
429
+ " fp16=USE_FP16,\n",
430
  " report_to='wandb',\n",
431
  " run_name='ecommerce-pretrain-24m-3ep',\n",
432
  " seed=42,\n",
 
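The device-name substring check works for T4s, but PyTorch also exposes a direct capability probe that generalizes to other pre-Ampere GPUs. A sketch of equivalent logic using `torch.cuda.is_bf16_supported()`:

```python
import torch

USE_GPU = torch.cuda.is_available()
# bf16 needs compute capability >= 8.0 (Ampere and newer); T4 is Turing (7.5).
USE_BF16 = USE_GPU and torch.cuda.is_bf16_supported()
USE_FP16 = USE_GPU and not USE_BF16
print(f'Precision: {"bf16" if USE_BF16 else "fp16" if USE_FP16 else "fp32"}')
```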
451
  "print(f'Loss: {losses[0]:.4f} -> {losses[-1]:.4f} ({(1-losses[-1]/losses[0])*100:.1f}% reduction)')\n",
452
  "print(f'Min loss: {min(losses):.4f}')\n",
453
  "\n",
 
454
  "random_loss = np.log(hf_tokenizer.vocab_size)\n",
455
  "print(f'Random chance loss: {random_loss:.4f}')\n",
456
  "print(f'Model vs random: {\"✅ Better\" if losses[-1] < random_loss else \"❌ Worse\"} ({losses[-1]:.2f} vs {random_loss:.2f})')\n",
 
459
  "ax.plot(losses, linewidth=0.5, alpha=0.5, label='Per-step')\n",
460
  "window = max(len(losses) // 50, 1)\n",
461
  "if len(losses) > window:\n",
462
+ " ax.plot(pd.Series(losses).rolling(window=window, min_periods=1).mean(), linewidth=2, color='red', label='Smoothed')\n",
463
+ "ax.axhline(y=random_loss, color='gray', linestyle='--', label=f'Random ({random_loss:.2f})')\n",
464
  "ax.set_xlabel('Step'); ax.set_ylabel('Loss'); ax.set_title('E-Commerce Pre-Training Loss')\n",
465
  "ax.legend(); ax.grid(True, alpha=0.3); plt.tight_layout(); plt.show()"
466
  ]
 
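For a sense of scale: assuming the built vocabulary lands near the requested `bpe_vocab_size=4000`, the uniform-random baseline is ln(4000) ≈ 8.29 nats, so any stable training loss well below 8.3 already beats chance:

```python
import numpy as np
print(np.log(4000))  # 8.294: cross-entropy of uniform guessing over 4,000 tokens
```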
471
  "metadata": {},
472
  "outputs": [],
473
  "source": [
 
474
  "model.eval()\n",
475
  "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
476
  "model = model.to(device)\n",
 
492
  "metadata": {},
493
  "outputs": [],
494
  "source": [
495
+ "# User embeddings — behavioral clusters\n",
496
  "n_sample = min(500, len(user_sequences))\n",
497
+ "embeddings, purchase_rates = [], []\n",
498
  "\n",
499
  "for i in range(n_sample):\n",
500
  " enc = hf_tokenizer(' '.join(builder.tokenize_sequence(user_sequences[i][:50])),\n",
 
502
  " with torch.no_grad():\n",
503
  " embeddings.append(model.get_user_embedding(enc['input_ids'].to(device), enc['attention_mask'].to(device)).cpu().numpy().flatten())\n",
504
  " events = user_sequences[i]\n",
 
505
  " purchase_rates.append(sum(1 for e in events if e['event_type'] == 'purchase') / len(events))\n",
506
  "\n",
507
+ "embeddings = np.array(embeddings); purchase_rates = np.array(purchase_rates)\n",
 
508
  "print(f'Embeddings: {embeddings.shape}')\n",
509
  "\n",
510
+ "from sklearn.manifold import TSNE\n",
511
+ "coords = TSNE(n_components=2, random_state=42, perplexity=30).fit_transform(embeddings)\n",
512
+ "fig, ax = plt.subplots(figsize=(8, 6))\n",
513
+ "sc = ax.scatter(coords[:, 0], coords[:, 1], c=purchase_rates, cmap='RdYlGn', alpha=0.6, s=20, edgecolors='black', linewidth=0.2)\n",
514
+ "ax.set_title('User Embeddings (t-SNE) — Colored by Purchase Rate')\n",
515
+ "plt.colorbar(sc, label='Purchase Rate'); plt.tight_layout(); plt.show()"
516
  ]
517
  },
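t-SNE is qualitative; a quick quantitative companion, assuming the `embeddings` and `purchase_rates` arrays from the cell above, is a silhouette score over a binary buyer/non-buyer split:

```python
import numpy as np
from sklearn.metrics import silhouette_score

labels = (purchase_rates > 0).astype(int)  # did the user ever purchase?
if len(np.unique(labels)) == 2:
    # Silhouette in [-1, 1]; values above 0 mean buyers and non-buyers
    # occupy separable regions of the embedding space.
    print(f'Silhouette (buyer vs non-buyer): '
          f'{silhouette_score(embeddings, labels):.3f}')
```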
518
  {
519
  "cell_type": "markdown",
520
  "metadata": {},
521
  "source": [
522
+ "## Save"
523
  ]
524
  },
525
  {
 
531
  "hf_tokenizer.save_pretrained('./ecommerce_tokenizer')\n",
532
  "builder.save('./ecommerce_tokenizer')\n",
533
  "model.save_pretrained('./ecommerce_pretrain_checkpoints/final')\n",
 
534
  "with open('./ecommerce_artifacts.pkl', 'wb') as f:\n",
535
  " pickle.dump({'user_sequences': user_sequences, 'user_ids': user_ids}, f)\n",
536
+ "print('Saved all artifacts')\n",
537
  "wandb.finish()"
538
  ]
539
  }
540
  ],
541
  "metadata": {