rtferraz committed on
Commit d60868a · verified · 1 Parent(s): 709a7e2

Add 02_ecommerce_pretrain.ipynb — REES46 e-commerce pre-training with sequential entropy check, wandb, push to hub
Files changed (1)
  1. notebooks/02_ecommerce_pretrain.ipynb +589 -0
notebooks/02_ecommerce_pretrain.ipynb ADDED
@@ -0,0 +1,589 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# 02 — E-Commerce Pre-Training: Domain Tokenizer on Real User Behavior Data\n",
+ "\n",
+ "**Goal:** Pre-train a 24M-parameter DomainTransformer on real e-commerce behavioral sequences (view → cart → purchase funnels) where sequential patterns actually exist.\n",
+ "\n",
+ "**Dataset:** [REES46 Multi-Category Store](https://huggingface.co/datasets/kevykibbz/ecommerce-behavior-data-from-multi-category-store_oct-nov_2019) — ~42M events, real user behavior, Nov 2019.\n",
+ "\n",
+ "**Why this dataset after the finance experiment:**\n",
+ "- The Nigerian finance dataset had only 84 unique descriptions and no sequential dependencies → loss plateaued at 6.9\n",
+ "- REES46 has millions of products, rich category hierarchies, view→cart→purchase funnels, and diverse browsing patterns\n",
+ "- This is where domain tokenization should genuinely prove itself\n",
+ "\n",
+ "**Lesson applied from finance report:** We run a **sequential entropy check** before training to verify there's learnable structure.\n",
+ "\n",
+ "**Hardware:** L4 GPU (24GB VRAM)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Setup"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# !pip install datasets transformers torch accelerate tokenizers numpy pandas matplotlib scikit-learn wandb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import logging, time, pickle, os, sys\n",
+ "from datetime import datetime\n",
+ "from collections import Counter\n",
+ "\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import matplotlib.pyplot as plt\n",
+ "import torch\n",
+ "from datasets import load_dataset\n",
+ "\n",
+ "if os.path.exists('../src'): sys.path.insert(0, '../src')\n",
+ "elif os.path.exists('src'): sys.path.insert(0, 'src')\n",
+ "\n",
+ "from domain_tokenizer import (\n",
+ "    DomainTokenizerBuilder, DomainTransformerConfig,\n",
+ "    DomainTransformerForCausalLM, prepare_clm_dataset, pretrain_domain_model,\n",
+ ")\n",
+ "from domain_tokenizer.schema import DomainSchema, FieldSpec, FieldType\n",
+ "\n",
+ "logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')\n",
+ "print(f'torch: {torch.__version__}, CUDA: {torch.cuda.is_available()}')\n",
+ "if torch.cuda.is_available():\n",
+ "    print(f'GPU: {torch.cuda.get_device_name(0)}, VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import wandb\n",
+ "wandb.login()\n",
+ "os.environ['WANDB_PROJECT'] = 'domainTokenizer'\n",
+ "print('wandb ready')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 1 — Load Dataset\n",
+ "\n",
+ "42M events (~2 GB). We load the full split, then subsample users to keep training manageable."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%time\n",
+ "ds = load_dataset(\n",
+ "    'kevykibbz/ecommerce-behavior-data-from-multi-category-store_oct-nov_2019',\n",
+ "    split='train',\n",
+ ")\n",
+ "print(f'Loaded: {len(ds):,} events, columns: {ds.column_names}')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df = ds.to_pandas()\n",
+ "print(f'Shape: {df.shape}')\n",
+ "df.head(3)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 2 — Data Profiling"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(f\"Unique users: {df['user_id'].nunique():,}\")\n",
+ "print(f\"Unique products: {df['product_id'].nunique():,}\")\n",
+ "print(f\"Unique categories: {df['category_code'].nunique():,}\")\n",
+ "print(f\"Unique brands: {df['brand'].nunique():,}\")\n",
+ "print(f\"Price range: {df['price'].min():.2f} to {df['price'].max():.2f}\")\n",
+ "print(f\"\\nEvent types:\\n{df['event_type'].value_counts().to_string()}\")\n",
+ "print(f\"\\nCategory codes (top 15):\\n{df['category_code'].value_counts().head(15).to_string()}\")\n",
+ "print(f\"\\nBrands (top 10):\\n{df['brand'].value_counts().head(10).to_string()}\")\n",
+ "print(f\"\\nNull rates:\")\n",
+ "for col in df.columns:\n",
+ "    null_pct = df[col].isnull().mean() * 100\n",
+ "    if null_pct > 0: print(f\" {col}: {null_pct:.1f}%\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "events_per_user = df.groupby('user_id').size()\n",
+ "print(f\"Events/user: min={events_per_user.min()}, max={events_per_user.max()}, \"\n",
+ "      f\"mean={events_per_user.mean():.1f}, median={events_per_user.median():.0f}\")\n",
+ "print(f\"Users with 10+ events: {(events_per_user >= 10).sum():,}\")\n",
+ "print(f\"Users with 20+ events: {(events_per_user >= 20).sum():,}\")\n",
+ "print(f\"Users with 50+ events: {(events_per_user >= 50).sum():,}\")\n",
+ "\n",
+ "fig, axes = plt.subplots(1, 3, figsize=(15, 4))\n",
+ "axes[0].hist(np.log10(df['price'].clip(lower=0.01)), bins=50, edgecolor='black', alpha=0.7)\n",
+ "axes[0].set_xlabel('log10(Price)'); axes[0].set_title('Price Distribution')\n",
+ "axes[1].hist(events_per_user.clip(upper=100), bins=50, edgecolor='black', alpha=0.7)\n",
+ "axes[1].set_xlabel('Events/User'); axes[1].set_title('Events per User')\n",
+ "df['event_type'].value_counts().plot(kind='barh', ax=axes[2])\n",
+ "axes[2].set_xlabel('Count'); axes[2].set_title('Event Types')\n",
+ "plt.tight_layout(); plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 3 — Sequential Entropy Check\n",
+ "\n",
+ "**Lesson from finance experiment:** Before committing GPU time, verify the data has learnable sequential patterns.\n",
+ "\n",
+ "We check: is `P(event_type_t | event_type_{t-1})` different from `P(event_type_t)`? If yes → sequential structure exists.\n",
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Compute bigram transition probabilities for event_type\n",
+ "# Sample 50K users for speed\n",
+ "sample_users = df['user_id'].drop_duplicates().sample(min(50000, df['user_id'].nunique()), random_state=42)\n",
+ "sample_df = df[df['user_id'].isin(sample_users)].sort_values(['user_id', 'event_time'])\n",
+ "\n",
+ "# Unigram distribution\n",
+ "unigram = sample_df['event_type'].value_counts(normalize=True)\n",
+ "H_unigram = -(unigram * np.log2(unigram)).sum()\n",
+ "\n",
+ "# Bigram transitions within each user\n",
+ "bigrams = Counter()\n",
+ "prev_context = Counter()\n",
+ "for uid, group in sample_df.groupby('user_id'):\n",
+ "    events = group['event_type'].tolist()\n",
+ "    for i in range(1, len(events)):\n",
+ "        bigrams[(events[i-1], events[i])] += 1\n",
+ "        prev_context[events[i-1]] += 1\n",
+ "\n",
+ "# Conditional entropy H(event_t | event_{t-1})\n",
+ "H_conditional = 0\n",
+ "total_bigrams = sum(bigrams.values())\n",
+ "for (prev, curr), count in bigrams.items():\n",
+ "    p_joint = count / total_bigrams\n",
+ "    p_cond = count / prev_context[prev]\n",
+ "    H_conditional -= p_joint * np.log2(p_cond)\n",
+ "\n",
+ "mutual_info = H_unigram - H_conditional\n",
+ "print(f'Marginal entropy H(event_type): {H_unigram:.3f} bits')\n",
+ "print(f'Conditional entropy H(event_type | prev): {H_conditional:.3f} bits')\n",
+ "print(f'Mutual information I(t; t-1): {mutual_info:.3f} bits')\n",
+ "print(f'Predictability gain: {mutual_info/H_unigram*100:.1f}%')\n",
+ "print(f'\\n{\"✅ Sequential structure detected\" if mutual_info > 0.1 else \"⚠️ Weak sequential structure\"}')\n",
+ "\n",
+ "# Show transition matrix\n",
+ "print(f'\\nTransition probabilities (P(next | current)):')\n",
+ "event_types = sorted(unigram.index)\n",
+ "for prev in event_types:\n",
+ "    transitions = {}\n",
+ "    for curr in event_types:\n",
+ "        if prev_context[prev] > 0:\n",
+ "            transitions[curr] = bigrams.get((prev, curr), 0) / prev_context[prev]\n",
+ "    row = ' | '.join(f'{curr}: {p:.2f}' for curr, p in sorted(transitions.items(), key=lambda x: -x[1]) if p > 0.01)\n",
+ "    print(f' After {prev:20s} → {row}')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 4 — Build E-Commerce Schema and Convert Events\n",
+ "\n",
+ "Custom schema for this dataset — maps directly to the available columns."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Custom schema matching REES46 columns\n",
+ "ECOMMERCE_REES46_SCHEMA = DomainSchema(\n",
+ "    name='ecommerce_rees46',\n",
+ "    description='REES46 e-commerce behavioral event schema',\n",
+ "    fields=[\n",
+ "        FieldSpec(name='event_type', field_type=FieldType.CATEGORICAL_FIXED, prefix='EVT',\n",
+ "                  categories=['view', 'cart', 'remove_from_cart', 'purchase']),\n",
+ "        FieldSpec(name='price', field_type=FieldType.NUMERICAL_CONTINUOUS, prefix='PRICE', n_bins=21),\n",
+ "        FieldSpec(name='category', field_type=FieldType.TEXT, prefix='CAT'),  # hierarchical text like 'electronics.smartphone'\n",
+ "        FieldSpec(name='timestamp', field_type=FieldType.TEMPORAL,\n",
+ "                  calendar_fields=['dow', 'hour']),  # dow + hour capture shopping patterns\n",
+ "    ],\n",
+ ")\n",
+ "\n",
+ "print(ECOMMERCE_REES46_SCHEMA.summary())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def row_to_event(row):\n",
+ "    dt = datetime.strptime(row['event_time'][:19], '%Y-%m-%dT%H:%M:%S')\n",
+ "    # Use category_code if available, else brand, else 'unknown'\n",
+ "    cat = row['category_code'] if pd.notna(row['category_code']) else (row['brand'] if pd.notna(row['brand']) else 'unknown')\n",
+ "    return {\n",
+ "        'event_type': row['event_type'],\n",
+ "        'price': row['price'],\n",
+ "        'category': cat,\n",
+ "        'timestamp': dt,\n",
+ "    }\n",
+ "\n",
+ "print(f'Sample event: {row_to_event(df.iloc[0])}')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%time\n",
+ "# Subsample: take users with 10+ events, cap at 100K users for training speed\n",
+ "MIN_EVENTS = 10\n",
+ "MAX_EVENTS = 200\n",
+ "MAX_USERS = 100_000\n",
+ "\n",
+ "# Pre-filter to users with enough events\n",
+ "user_counts = df.groupby('user_id').size()\n",
+ "eligible_users = user_counts[user_counts >= MIN_EVENTS].index\n",
+ "print(f'Users with {MIN_EVENTS}+ events: {len(eligible_users):,}')\n",
+ "\n",
+ "# Subsample if needed\n",
+ "if len(eligible_users) > MAX_USERS:\n",
+ "    eligible_users = pd.Series(eligible_users).sample(MAX_USERS, random_state=42).values\n",
+ "    print(f'Subsampled to {MAX_USERS:,} users')\n",
+ "\n",
+ "# Build user sequences\n",
+ "user_sequences = []\n",
+ "user_ids = []\n",
+ "filtered_df = df[df['user_id'].isin(eligible_users)].sort_values(['user_id', 'event_time'])\n",
+ "\n",
+ "for uid, group in filtered_df.groupby('user_id'):\n",
+ "    events = [row_to_event(row) for _, row in group.head(MAX_EVENTS).iterrows()]\n",
+ "    user_sequences.append(events)\n",
+ "    user_ids.append(uid)\n",
+ "\n",
+ "total_events = sum(len(s) for s in user_sequences)\n",
+ "print(f'Users: {len(user_sequences):,}')\n",
+ "print(f'Total events: {total_events:,}')\n",
+ "print(f'Events/user: min={min(len(s) for s in user_sequences)}, max={max(len(s) for s in user_sequences)}, mean={np.mean([len(s) for s in user_sequences]):.1f}')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 5 — Build Domain Tokenizer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "all_events = [e for seq in user_sequences for e in seq]\n",
+ "print(f'Total events for fitting: {len(all_events):,}')\n",
+ "\n",
+ "builder = DomainTokenizerBuilder(ECOMMERCE_REES46_SCHEMA)\n",
+ "builder.fit(all_events)\n",
+ "\n",
+ "# Text corpus = category codes (hierarchical) — rich BPE vocabulary\n",
+ "text_corpus = [e['category'] for e in all_events]\n",
+ "unique_cats = sorted(set(text_corpus))\n",
+ "print(f'Unique category strings: {len(unique_cats):,}')\n",
+ "for c in unique_cats[:15]: print(f\" '{c}'\")\n",
+ "if len(unique_cats) > 15: print(f' ... and {len(unique_cats)-15} more')\n",
+ "\n",
+ "hf_tokenizer = builder.build(text_corpus=text_corpus, bpe_vocab_size=4000)\n",
+ "print(f'\\nVocab size: {hf_tokenizer.vocab_size}')\n",
+ "print(f'Stats: {builder.get_stats()}')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Inspect tokenization\n",
+ "print('--- Sample event ---')\n",
+ "for i, t in enumerate(builder.tokenize_event(user_sequences[0][0])): print(f' [{i}] {t}')\n",
+ "\n",
+ "print(f'\\n--- First user, first 5 events ---')\n",
+ "seq_tokens = builder.tokenize_sequence(user_sequences[0][:5])\n",
+ "for i, t in enumerate(seq_tokens): print(f' [{i:3d}] {t}')\n",
+ "\n",
+ "seq_ids = hf_tokenizer(' '.join(seq_tokens), add_special_tokens=False)['input_ids']\n",
+ "unk_count = sum(1 for i in seq_ids if i == hf_tokenizer.unk_token_id)\n",
+ "print(f'\\nUNK rate: {unk_count}/{len(seq_ids)} ({unk_count/max(len(seq_ids),1)*100:.1f}%)')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 6 — Pack and Train\n",
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%time\n",
+ "BLOCK_SIZE = 512\n",
+ "dataset = prepare_clm_dataset(user_sequences, builder, hf_tokenizer, block_size=BLOCK_SIZE)\n",
+ "print(f'Packed: {len(dataset):,} blocks x {BLOCK_SIZE} = {len(dataset)*BLOCK_SIZE:,} training tokens')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(f'Sample block: {hf_tokenizer.decode(dataset[0][\"input_ids\"][:50])}')\n",
+ "\n",
+ "all_ids = [i for row in dataset for i in row['input_ids']]\n",
+ "counts = Counter(all_ids)\n",
+ "unk_id = hf_tokenizer.unk_token_id\n",
+ "print(f'\\nTokens: {len(all_ids):,}, Unique: {len(counts)}/{hf_tokenizer.vocab_size}, UNK: {counts.get(unk_id,0)} ({counts.get(unk_id,0)/len(all_ids)*100:.2f}%)')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "config = DomainTransformerConfig.from_preset('24m', vocab_size=hf_tokenizer.vocab_size)\n",
+ "model = DomainTransformerForCausalLM(config)\n",
+ "print(f'Model: {sum(p.numel() for p in model.parameters()):,} params | d={config.hidden_size}, L={config.num_hidden_layers}, H={config.num_attention_heads}')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%time\n",
+ "USE_GPU = torch.cuda.is_available()\n",
+ "\n",
+ "trainer = pretrain_domain_model(\n",
+ "    model=model,\n",
+ "    tokenizer=hf_tokenizer,\n",
+ "    train_dataset=dataset,\n",
+ "    output_dir='./ecommerce_pretrain_checkpoints',\n",
+ "    hub_model_id='rtferraz/ecommerce-domain-24m',\n",
+ "    num_epochs=3 if USE_GPU else 1,\n",
+ "    per_device_batch_size=32 if USE_GPU else 4,\n",
+ "    gradient_accumulation_steps=4 if USE_GPU else 1,\n",
+ "    learning_rate=3e-4,\n",
+ "    warmup_steps=200 if USE_GPU else 10,\n",
+ "    logging_steps=50 if USE_GPU else 10,\n",
+ "    save_steps=2000 if USE_GPU else 999999,\n",
+ "    bf16=USE_GPU,\n",
+ "    report_to='wandb',\n",
+ "    run_name='ecommerce-pretrain-24m-3ep',\n",
+ "    seed=42,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Step 7 — Results"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "losses = [h['loss'] for h in trainer.state.log_history if 'loss' in h]\n",
+ "print(f'Steps: {trainer.state.global_step:,}')\n",
+ "print(f'Loss: {losses[0]:.4f} -> {losses[-1]:.4f} ({(1-losses[-1]/losses[0])*100:.1f}% reduction)')\n",
+ "print(f'Min loss: {min(losses):.4f}')\n",
+ "\n",
+ "# Compare to random chance: -ln(1/vocab_size)\n",
+ "random_loss = np.log(hf_tokenizer.vocab_size)\n",
+ "print(f'Random chance loss: {random_loss:.4f}')\n",
+ "print(f'Model vs random: {\"✅ Better\" if losses[-1] < random_loss else \"❌ Worse\"} ({losses[-1]:.2f} vs {random_loss:.2f})')\n",
+ "\n",
+ "fig, ax = plt.subplots(figsize=(10, 5))\n",
+ "ax.plot(losses, linewidth=0.5, alpha=0.5, label='Per-step')\n",
+ "window = max(len(losses) // 50, 1)\n",
+ "if len(losses) > window:\n",
+ "    ax.plot(pd.Series(losses).rolling(window=window, min_periods=1).mean(), linewidth=2, color='red', label='Smoothed')\n",
+ "ax.axhline(y=random_loss, color='gray', linestyle='--', label=f'Random chance ({random_loss:.2f})')\n",
+ "ax.set_xlabel('Step'); ax.set_ylabel('Loss'); ax.set_title('E-Commerce Pre-Training Loss')\n",
+ "ax.legend(); ax.grid(True, alpha=0.3); plt.tight_layout(); plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Next-token predictions\n",
+ "model.eval()\n",
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
+ "model = model.to(device)\n",
+ "\n",
+ "test_ids = hf_tokenizer(' '.join(builder.tokenize_sequence(user_sequences[0][:5])),\n",
+ "                        return_tensors='pt', add_special_tokens=False)['input_ids'].to(device)\n",
+ "with torch.no_grad():\n",
+ "    top5 = torch.topk(model(input_ids=test_ids).logits[0, -1, :], 5)\n",
+ "\n",
+ "print('Last 5 input tokens:')\n",
+ "for tid in test_ids[0, -5:]: print(f\" {tid.item():5d} -> '{hf_tokenizer.decode([tid.item()])}'\")\n",
+ "print('\\nTop-5 next token predictions:')\n",
+ "for score, tid in zip(top5.values, top5.indices): print(f\" {tid.item():5d} -> '{hf_tokenizer.decode([tid.item()])}' (score={score.item():.3f})\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# User embeddings — can we see behavioral clusters?\n",
+ "n_sample = min(500, len(user_sequences))\n",
+ "embeddings, event_counts, purchase_rates = [], [], []\n",
+ "\n",
+ "for i in range(n_sample):\n",
+ "    enc = hf_tokenizer(' '.join(builder.tokenize_sequence(user_sequences[i][:50])),\n",
+ "                       return_tensors='pt', add_special_tokens=False, max_length=256, truncation=True, padding='max_length')\n",
+ "    with torch.no_grad():\n",
+ "        embeddings.append(model.get_user_embedding(enc['input_ids'].to(device), enc['attention_mask'].to(device)).cpu().numpy().flatten())\n",
+ "    events = user_sequences[i]\n",
+ "    event_counts.append(len(events))\n",
+ "    purchase_rates.append(sum(1 for e in events if e['event_type'] == 'purchase') / len(events))\n",
+ "\n",
+ "embeddings = np.array(embeddings)\n",
+ "purchase_rates = np.array(purchase_rates)\n",
+ "print(f'Embeddings: {embeddings.shape}')\n",
+ "\n",
+ "if len(embeddings) >= 20:\n",
+ "    from sklearn.manifold import TSNE\n",
+ "    coords = TSNE(n_components=2, random_state=42, perplexity=min(30, len(embeddings) - 1)).fit_transform(embeddings)  # perplexity must be < n_samples\n",
+ "\n",
+ "    fig, axes = plt.subplots(1, 2, figsize=(14, 6))\n",
+ "    sc1 = axes[0].scatter(coords[:, 0], coords[:, 1], c=purchase_rates, cmap='RdYlGn', alpha=0.6, s=20)\n",
+ "    axes[0].set_title('User Embeddings — Colored by Purchase Rate'); plt.colorbar(sc1, ax=axes[0], label='Purchase Rate')\n",
+ "    sc2 = axes[1].scatter(coords[:, 0], coords[:, 1], c=np.log10(np.array(event_counts[:n_sample])), cmap='viridis', alpha=0.6, s=20)\n",
+ "    axes[1].set_title('User Embeddings — Colored by Activity Level'); plt.colorbar(sc2, ax=axes[1], label='log10(Events)')\n",
+ "    plt.tight_layout(); plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Save Artifacts"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "hf_tokenizer.save_pretrained('./ecommerce_tokenizer')\n",
+ "builder.save('./ecommerce_tokenizer')\n",
+ "model.save_pretrained('./ecommerce_pretrain_checkpoints/final')\n",
+ "\n",
+ "with open('./ecommerce_artifacts.pkl', 'wb') as f:\n",
+ "    pickle.dump({'user_sequences': user_sequences, 'user_ids': user_ids}, f)\n",
+ "\n",
+ "print('Saved: ./ecommerce_tokenizer/, ./ecommerce_pretrain_checkpoints/final/, ./ecommerce_artifacts.pkl')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "wandb.finish()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Summary\n",
+ "\n",
+ "| Metric | Value |\n",
+ "|--------|-------|\n",
+ "| Dataset | REES46 E-Commerce (42M events) |\n",
+ "| Users (sampled) | *see above* |\n",
+ "| Training tokens | *see above* |\n",
+ "| Sequential entropy gain | *see Step 3* |\n",
+ "| Model | DomainTransformer 24M (NoPE) |\n",
+ "| Final loss | *see above* |\n",
+ "| Loss vs random chance | *see above* |\n",
+ "\n",
+ "**Next:** `03_ecommerce_finetune.ipynb` — Fine-tune for next-purchase prediction."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" },
+ "language_info": { "name": "python", "version": "3.12.0" }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+ }