rtferraz committed on
Commit 345d9e3 · verified · 1 Parent(s): af3b720

Add training test suite: 19 tests covering the data pipeline, packing, collation, integration, and a Trainer smoke test

Files changed (1)
  1. tests/test_training.py +207 -0
tests/test_training.py ADDED
@@ -0,0 +1,207 @@
+ """
+ Tests for domainTokenizer Phase 2C: Pre-training Pipeline.
+ 19 tests covering tokenization, packing, collation, integration, and a Trainer smoke test.
+
+ Run: pytest tests/test_training.py -v
+ """
+
+ import logging
+ import random
+ from datetime import datetime, timedelta
+ from typing import Any, Dict, List
+
+ import numpy as np
+ import torch
+ import pytest
+
+ from datasets import Dataset as HFDataset
+ from transformers import DataCollatorForLanguageModeling
+
+ from domain_tokenizer.schemas.predefined import FINANCE_SCHEMA
+ from domain_tokenizer.tokenizers.domain_tokenizer import DomainTokenizerBuilder
+ from domain_tokenizer.models.configuration import DomainTransformerConfig
+ from domain_tokenizer.models.modeling import DomainTransformerForCausalLM
+ from domain_tokenizer.training.data_pipeline import (
+     tokenize_user_sequences, pack_sequences, prepare_clm_dataset,
+ )
+ from domain_tokenizer.training.pretrain import pretrain_domain_model
+
+ logging.basicConfig(level=logging.INFO)
+
+
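+ # Synthetic finance events: each event carries a signed amount (the walrus
+ # binds one draw to both "amount_sign" and "amount"), a timestamp within
+ # 2025, and a merchant description from a fixed pool.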
+ def make_finance_events(n_events=10, seed=42):
+     rng = random.Random(seed)
+     merchants = ["AMAZON", "UBER", "SALARY", "GROCERY", "NETFLIX", "GAS", "RESTAURANT", "PHARMACY"]
+     base_date = datetime(2025, 1, 1)
+     return [
+         {"amount_sign": (amt := rng.uniform(5, 5000) * rng.choice([1, -1])),
+          "amount": amt,
+          "timestamp": base_date + timedelta(days=rng.randint(0, 365), hours=rng.randint(0, 23)),
+          "description": rng.choice(merchants)}
+         for _ in range(n_events)
+     ]
+
+
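+ # One event list per synthetic user; per-user lengths vary within
+ # [min_events, max_events] so downstream tests see variable-length sequences.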
+ def make_user_sequences(n_users=20, min_events=5, max_events=30, seed=42):
+     rng = random.Random(seed)
+     return [make_finance_events(rng.randint(min_events, max_events), seed + i) for i in range(n_users)]
+
+
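+ # Fit the builder on the flattened events, then build with a small BPE vocab;
+ # the merchant strings are repeated 20x so BPE training has a usable corpus.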
+ def build_finance_tokenizer(events_flat):
+     builder = DomainTokenizerBuilder(FINANCE_SCHEMA)
+     builder.fit(events_flat)
+     text_corpus = list(set(e["description"] for e in events_flat)) * 20
+     return builder, builder.build(text_corpus=text_corpus, bpe_vocab_size=500)
+
+
+ class TestTokenizeUserSequences:
+     @pytest.fixture
+     def setup(self):
+         seqs = make_user_sequences(5, 3, 10)
+         flat = [e for s in seqs for e in s]
+         b, t = build_finance_tokenizer(flat)
+         return seqs, b, t
+
+     def test_returns_list_of_lists(self, setup):
+         seqs, b, t = setup
+         r = tokenize_user_sequences(seqs, b, t)
+         assert len(r) == 5 and all(isinstance(s, list) for s in r)
+
+     def test_variable_lengths(self, setup):
+         seqs, b, t = setup
+         r = tokenize_user_sequences(seqs, b, t)
+         assert len(set(len(s) for s in r)) > 1
+
+     def test_bos_eos_present(self, setup):
+         seqs, b, t = setup
+         r = tokenize_user_sequences(seqs, b, t, add_bos=True, add_eos=True)
+         bos = t.convert_tokens_to_ids("[BOS]")
+         assert all(bos in s[:5] for s in r)
+
+     def test_no_bos_eos(self, setup):
+         seqs, b, t = setup
+         with_ = tokenize_user_sequences(seqs[:1], b, t, add_bos=True, add_eos=True)
+         without = tokenize_user_sequences(seqs[:1], b, t, add_bos=False, add_eos=False)
+         assert len(without[0]) < len(with_[0])
+
+
+ class TestPackSequences:
+     def test_fixed_length(self):
+         ds = pack_sequences([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10], [11, 12, 13, 14, 15]], block_size=5)
+         assert len(ds) == 3 and all(len(r["input_ids"]) == 5 for r in ds)
+
+     def test_concat(self):
+         ds = pack_sequences([[1, 2, 3], [4, 5, 6]], block_size=3)
+         assert ds[0]["input_ids"] == [1, 2, 3] and ds[1]["input_ids"] == [4, 5, 6]
+
+     def test_drops_remainder(self):
+         ds = pack_sequences([[1, 2, 3, 4, 5, 6, 7]], block_size=3)
+         assert len(ds) == 2
+
+     def test_too_few(self):
+         with pytest.raises(ValueError):
+             pack_sequences([[1, 2]], block_size=10)
+
+     def test_hf_dataset(self):
+         assert isinstance(pack_sequences([[i for i in range(100)]], block_size=10), HFDataset)
+
+     def test_no_padding(self):
+         ds = pack_sequences([[i for i in range(50)] for _ in range(10)], block_size=25)
+         assert all(len(r["input_ids"]) == 25 for r in ds)
+
+
+ class TestPrepareCLMDataset:
+     def test_full(self):
+         seqs = make_user_sequences(10, 5, 15)
+         flat = [e for s in seqs for e in s]
+         b, t = build_finance_tokenizer(flat)
+         ds = prepare_clm_dataset(seqs, b, t, block_size=64)
+         assert len(ds) > 0 and all(len(r["input_ids"]) == 64 for r in ds)
+
+     def test_block_sizes(self):
+         seqs = make_user_sequences(10, 10, 20)
+         flat = [e for s in seqs for e in s]
+         b, t = build_finance_tokenizer(flat)
+         ds32 = prepare_clm_dataset(seqs, b, t, block_size=32)
+         ds64 = prepare_clm_dataset(seqs, b, t, block_size=64)
+         assert len(ds32) > len(ds64)
+
+
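+ # With mlm=False the collator does causal-LM collation: labels are a copy of
+ # input_ids and, since packed blocks need no padding, the mask is all ones.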
+ class TestDataCollator:
+     @pytest.fixture
+     def setup(self):
+         seqs = make_user_sequences(5, 5, 15)
+         flat = [e for s in seqs for e in s]
+         b, t = build_finance_tokenizer(flat)
+         ds = prepare_clm_dataset(seqs, b, t, block_size=32)
+         return ds, DataCollatorForLanguageModeling(tokenizer=t, mlm=False), t
+
+     def test_adds_labels(self, setup):
+         ds, c, _ = setup
+         batch = c([ds[i] for i in range(min(4, len(ds)))])
+         assert all(k in batch for k in ["input_ids", "labels", "attention_mask"])
+
+     def test_labels_eq_ids(self, setup):
+         ds, c, _ = setup
+         batch = c([ds[0]])
+         assert torch.equal(batch["input_ids"], batch["labels"])
+
+     def test_shapes(self, setup):
+         ds, c, _ = setup
+         n = min(4, len(ds))
+         batch = c([ds[i] for i in range(n)])
+         assert batch["input_ids"].shape == (n, 32)
+
+     def test_all_ones_mask(self, setup):
+         ds, c, _ = setup
+         batch = c([ds[0]])
+         assert batch["attention_mask"].sum() == 32
+
+
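+ # End-to-end check: a tiny DomainTransformer should produce a positive loss
+ # on a collated batch and propagate non-zero gradients on backward().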
+ class TestTrainingIntegration:
+     def test_forward(self):
+         seqs = make_user_sequences(10, 5, 15)
+         flat = [e for s in seqs for e in s]
+         b, t = build_finance_tokenizer(flat)
+         ds = prepare_clm_dataset(seqs, b, t, block_size=32)
+         c = DataCollatorForLanguageModeling(tokenizer=t, mlm=False)
+         batch = c([ds[i] for i in range(min(4, len(ds)))])
+         config = DomainTransformerConfig(vocab_size=t.vocab_size, hidden_size=64,
+                                          num_hidden_layers=2, num_attention_heads=4, intermediate_size=128)
+         model = DomainTransformerForCausalLM(config)
+         out = model(**batch)
+         assert out.loss.item() > 0
+         out.loss.backward()
+         assert sum(p.grad.norm().item() for p in model.parameters() if p.grad is not None) > 0
+
+
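+ # Smoke test: one epoch of pretrain_domain_model on tiny data must advance
+ # global_step; a tokenizer without a pad token must raise ValueError.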
+ class TestPretrainDomainModel:
+     def test_smoke(self, tmp_path):
+         seqs = make_user_sequences(20, 5, 15)
+         flat = [e for s in seqs for e in s]
+         b, t = build_finance_tokenizer(flat)
+         ds = prepare_clm_dataset(seqs, b, t, block_size=32)
+         config = DomainTransformerConfig(vocab_size=t.vocab_size, hidden_size=64,
+                                          num_hidden_layers=2, num_attention_heads=4, intermediate_size=128)
+         model = DomainTransformerForCausalLM(config)
+         trainer = pretrain_domain_model(
+             model=model, tokenizer=t, train_dataset=ds,
+             output_dir=str(tmp_path / "ck"), hub_model_id=None,
+             num_epochs=1, per_device_batch_size=4, gradient_accumulation_steps=1,
+             learning_rate=1e-3, warmup_steps=0, logging_steps=1,
+             save_steps=999999, report_to="none", seed=42,
+         )
+         assert trainer.state.global_step > 0
+
+     def test_no_pad_raises(self, tmp_path):
+         from transformers import PreTrainedTokenizerFast
+         from tokenizers import Tokenizer
+         from tokenizers.models import BPE
+         hf = PreTrainedTokenizerFast(tokenizer_object=Tokenizer(BPE(unk_token="[UNK]")), unk_token="[UNK]")
+         config = DomainTransformerConfig(vocab_size=100, hidden_size=32, num_hidden_layers=1, num_attention_heads=2)
+         with pytest.raises(ValueError, match="pad_token"):
+             pretrain_domain_model(
+                 model=DomainTransformerForCausalLM(config), tokenizer=hf,
+                 train_dataset=HFDataset.from_dict({"input_ids": [[1, 2, 3]]}),
+                 output_dir=str(tmp_path),
+             )