rtferraz committed on
Commit 8efa945 · verified · 1 parent: c00ac2c

Add comprehensive test suite: 72 passing tests covering all components

Files changed (1)
  1. tests/test_tokenizer.py +353 -0
tests/test_tokenizer.py ADDED
@@ -0,0 +1,353 @@
+ """
+ Comprehensive tests for the domainTokenizer core library.
+ 72 tests covering: schemas, field tokenizers, predefined schemas,
+ DomainTokenizerBuilder pipeline, and end-to-end HF encoding.
+
+ Run: pytest tests/test_tokenizer.py -v
+ """
+
+ import json
+ import math
+ import sys
+ from datetime import datetime
+
+ import numpy as np
+ import pytest
+
+ from domain_tokenizer.schema import DomainSchema, FieldSpec, FieldType, CALENDAR_FIELD_SIZES
+ from domain_tokenizer.tokenizers.field_tokenizers import (
+     SignTokenizer, MagnitudeBucketTokenizer, DiscreteNumericalTokenizer,
+     CalendarTokenizer, CategoricalTokenizer, create_field_tokenizer,
+ )
+ from domain_tokenizer.tokenizers.domain_tokenizer import DomainTokenizerBuilder
+ from domain_tokenizer.schemas.predefined import FINANCE_SCHEMA, ECOMMERCE_SCHEMA, HEALTHCARE_SCHEMA
+
+
+ class TestFieldSpec:
+     def test_sign_field(self):
+         spec = FieldSpec("amount_sign", FieldType.SIGN)
+         assert spec.token_count == 2
+         assert spec.tokens_per_event == 1
+
+     def test_numerical_continuous_field(self):
+         spec = FieldSpec("amount", FieldType.NUMERICAL_CONTINUOUS, n_bins=21)
+         assert spec.token_count == 21
+
+     def test_numerical_discrete_field(self):
+         spec = FieldSpec("quantity", FieldType.NUMERICAL_DISCRETE, max_value=10)
+         assert spec.token_count == 12
+
+     def test_categorical_field(self):
+         spec = FieldSpec("event_type", FieldType.CATEGORICAL_FIXED, categories=["a", "b", "c"])
+         assert spec.token_count == 4
+
+     def test_temporal_field(self):
+         spec = FieldSpec("ts", FieldType.TEMPORAL, calendar_fields=["month", "dow", "dom", "hour"])
+         assert spec.token_count == 74
+
+     def test_text_field(self):
+         spec = FieldSpec("desc", FieldType.TEXT)
+         assert spec.token_count == 0
+
+     def test_custom_prefix(self):
+         spec = FieldSpec("amount", FieldType.NUMERICAL_CONTINUOUS, prefix="PRICE")
+         assert spec.prefix == "PRICE"
+
+     def test_categorical_requires_categories(self):
+         with pytest.raises(ValueError):
+             FieldSpec("event", FieldType.CATEGORICAL_FIXED)
+
+     def test_discrete_requires_max_value(self):
+         with pytest.raises(ValueError):
+             FieldSpec("qty", FieldType.NUMERICAL_DISCRETE)
+
+
+ class TestDomainSchema:
+     def test_finance_token_count(self):
+         expected = 8 + 2 + 21 + 74  # 8 base specials + sign (2) + amount bins (21) + calendar (74)
+         assert FINANCE_SCHEMA.special_token_count == expected
+
+     def test_finance_fixed_tokens(self):
+         assert FINANCE_SCHEMA.fixed_tokens_per_event == 7
+
+     def test_has_text_fields(self):
+         assert FINANCE_SCHEMA.has_text_fields is True
+
+     def test_text_field_names(self):
+         assert FINANCE_SCHEMA.text_field_names == ["description"]
+
+     def test_fittable_fields(self):
+         assert FINANCE_SCHEMA.fittable_field_names == ["amount"]
+
+     def test_get_field(self):
+         assert FINANCE_SCHEMA.get_field("amount").field_type == FieldType.NUMERICAL_CONTINUOUS
+
+     def test_get_field_missing(self):
+         assert FINANCE_SCHEMA.get_field("nonexistent") is None
+
+     def test_summary(self):
+         assert "finance" in FINANCE_SCHEMA.summary()
+
+
+ class TestSignTokenizer:
+     def test_positive(self):
+         assert SignTokenizer("S")(79.99) == "[S_POS]"
+
+     def test_negative(self):
+         assert SignTokenizer("S")(-50.0) == "[S_NEG]"
+
+     def test_zero(self):
+         assert SignTokenizer("S")(0.0) == "[S_POS]"
+
+     def test_none(self):
+         assert SignTokenizer("S")(None) == "[S_POS]"
+
+     def test_nan(self):
+         assert SignTokenizer("S")(float("nan")) == "[S_POS]"
+
+     def test_vocab_size(self):
+         assert SignTokenizer("S").vocab_size == 2
+
+     def test_custom_labels(self):
+         tok = SignTokenizer("D", pos_label="CREDIT", neg_label="DEBIT")
+         assert tok(100) == "[D_CREDIT]"
+         assert tok(-100) == "[D_DEBIT]"
+
+
+ class TestMagnitudeBucketTokenizer:
+     def setup_method(self):
+         self.tok = MagnitudeBucketTokenizer("A", n_bins=5)
+         self.tok.fit(np.array([1, 2, 5, 10, 20, 50, 100, 200, 500, 1000]))
+
+     def test_low(self):
+         assert self.tok(1.0) == "[A_00]"
+
+     def test_high(self):
+         assert self.tok(1000.0) == "[A_04]"
+
+     def test_negative_abs(self):
+         assert self.tok(50.0) == self.tok(-50.0)
+
+     def test_none(self):
+         assert self.tok(None) == "[A_00]"
+
+     def test_nan(self):
+         assert self.tok(float("nan")) == "[A_00]"
+
+     def test_vocab(self):
+         assert self.tok.vocab_size == 5
+
+     def test_not_fitted(self):
+         with pytest.raises(RuntimeError):
+             MagnitudeBucketTokenizer("X")(50.0)
+
+     def test_empty_fit(self):
+         with pytest.raises(ValueError):
+             MagnitudeBucketTokenizer("X").fit(np.array([]))
+
+     def test_nubank_21(self):
+         tok = MagnitudeBucketTokenizer("A", n_bins=21)
+         tok.fit(np.random.lognormal(3, 1, 10000))
+         assert tok.vocab_size == 21
+         for v in [0.01, 1.0, 100.0, 10000.0]:
+             assert tok(v) in tok.vocab
+
+     def test_serialization(self):
+         d = self.tok.to_dict()
+         tok2 = MagnitudeBucketTokenizer.from_dict(d)
+         assert tok2(50.0) == self.tok(50.0)
+
+
+ class TestDiscreteNumericalTokenizer:
+     def test_normal(self):
+         assert DiscreteNumericalTokenizer("Q", max_value=10)(3) == "[Q_03]"
+
+     def test_zero(self):
+         assert DiscreteNumericalTokenizer("Q", max_value=10)(0) == "[Q_00]"
+
+     def test_max(self):
+         assert DiscreteNumericalTokenizer("Q", max_value=10)(10) == "[Q_10]"
+
+     def test_overflow(self):
+         assert DiscreteNumericalTokenizer("Q", max_value=10)(15) == "[Q_OVER]"
+
+     def test_negative(self):
+         assert DiscreteNumericalTokenizer("Q", max_value=10)(-5) == "[Q_00]"
+
+     def test_none(self):
+         assert DiscreteNumericalTokenizer("Q", max_value=10)(None) == "[Q_00]"
+
+     def test_vocab(self):
+         assert DiscreteNumericalTokenizer("Q", max_value=10).vocab_size == 12
+
+
+ class TestCalendarTokenizer:
+     def test_full(self):
+         tok = CalendarTokenizer("T", fields=["month", "dow", "dom", "hour"])
+         tokens = tok(datetime(2025, 3, 15, 14, 30))
+         assert len(tokens) == 4
+         assert tokens[0] == "[T_MON_03]"
+         assert tokens[3] == "[T_HOUR_14]"
+
+     def test_string_input(self):
+         assert CalendarTokenizer("T", ["month"])("2025-03-15T14:30:00") == ["[T_MON_03]"]
+
+     def test_date_only(self):
+         tokens = CalendarTokenizer("T", ["month", "dow"])("2025-03-15")
+         assert tokens[0] == "[T_MON_03]"
+
+     def test_vocab_standard(self):
+         assert CalendarTokenizer("T", ["month", "dow", "dom", "hour"]).vocab_size == 74
+
+     def test_subset(self):
+         assert CalendarTokenizer("T", ["month", "dow"]).vocab_size == 19
+
+     def test_invalid(self):
+         with pytest.raises(ValueError):
+             CalendarTokenizer("T", ["invalid"])
+
+     def test_quarter(self):
+         tok = CalendarTokenizer("T", ["quarter"])
+         assert tok(datetime(2025, 1, 1)) == ["[T_Q1]"]
+         assert tok(datetime(2025, 10, 1)) == ["[T_Q4]"]
+
+
+ class TestCategoricalTokenizer:
+     def test_known(self):
+         assert CategoricalTokenizer("E", ["view", "buy"])("buy") == "[E_001]"
+
+     def test_unknown(self):
+         assert CategoricalTokenizer("E", ["view", "buy"])("refund") == "[E_UNK]"
+
+     def test_none(self):
+         assert CategoricalTokenizer("E", ["view"])(None) == "[E_UNK]"
+
+     def test_vocab_unk(self):
+         tok = CategoricalTokenizer("E", ["a", "b"])
+         assert "[E_UNK]" in tok.vocab
+         assert tok.vocab_size == 3
+
+     def test_decode(self):
+         tok = CategoricalTokenizer("E", ["view", "buy"])
+         assert tok.decode_token("[E_000]") == "view"
+
+
+ class TestFactory:
+     def test_sign(self):
+         assert isinstance(create_field_tokenizer(FieldSpec("s", FieldType.SIGN)), SignTokenizer)
+
+     def test_magnitude(self):
+         assert isinstance(create_field_tokenizer(FieldSpec("a", FieldType.NUMERICAL_CONTINUOUS)), MagnitudeBucketTokenizer)
+
+     def test_discrete(self):
+         assert isinstance(create_field_tokenizer(FieldSpec("q", FieldType.NUMERICAL_DISCRETE, max_value=10)), DiscreteNumericalTokenizer)
+
+     def test_calendar(self):
+         assert isinstance(create_field_tokenizer(FieldSpec("t", FieldType.TEMPORAL)), CalendarTokenizer)
+
+     def test_categorical(self):
+         assert isinstance(create_field_tokenizer(FieldSpec("c", FieldType.CATEGORICAL_FIXED, categories=["a"])), CategoricalTokenizer)
+
+     def test_text_none(self):
+         assert create_field_tokenizer(FieldSpec("d", FieldType.TEXT)) is None
+
+
+ class TestPredefinedSchemas:
+     def test_finance(self):
+         assert FINANCE_SCHEMA.name == "finance"
+         assert len(FINANCE_SCHEMA.fields) == 4
+
+     def test_ecommerce(self):
+         assert ECOMMERCE_SCHEMA.name == "ecommerce"
+         assert len(ECOMMERCE_SCHEMA.fields) == 6
+
+     def test_healthcare(self):
+         assert HEALTHCARE_SCHEMA.name == "healthcare"
+         assert len(HEALTHCARE_SCHEMA.fields) == 6
+
+     def test_nubank_97(self):
+         domain_tokens = sum(f.token_count for f in FINANCE_SCHEMA.fields)
+         assert domain_tokens == 97
+
+
+ class TestDomainTokenizerBuilder:
+     @pytest.fixture
+     def events(self):
+         return [
+             {"amount_sign": 79.99, "amount": 79.99,
+              "timestamp": datetime(2025, 3, 15, 14, 30), "description": "AMAZON"},
+             {"amount_sign": -200.0, "amount": -200.0,
+              "timestamp": datetime(2025, 3, 16, 9, 15), "description": "SALARY"},
+             {"amount_sign": 12.50, "amount": 12.50,
+              "timestamp": datetime(2025, 3, 17, 18, 45), "description": "UBER"},
+         ]
+
+     @pytest.fixture
+     def corpus(self):
+         return ["AMAZON", "SALARY", "UBER", "GROCERY", "NETFLIX"] * 20
+
+     def test_create(self):
+         assert not DomainTokenizerBuilder(FINANCE_SCHEMA).is_fitted
+
+     def test_fit(self, events):
+         b = DomainTokenizerBuilder(FINANCE_SCHEMA)
+         b.fit(events)
+         assert b.is_fitted
+
+     def test_tokenize_event(self, events):
+         b = DomainTokenizerBuilder(FINANCE_SCHEMA)
+         b.fit(events)
+         tokens = b.tokenize_event(events[0])
+         assert len(tokens) >= 7
+         assert tokens[0].startswith("[AMT_SIGN_")
+
+     def test_tokenize_sequence(self, events):
+         b = DomainTokenizerBuilder(FINANCE_SCHEMA)
+         b.fit(events)
+         tokens = b.tokenize_sequence(events)
+         assert tokens[0] == "[BOS]"
+         assert tokens[-1] == "[EOS]"
+         assert tokens.count(FINANCE_SCHEMA.event_separator) == 2
+
+     def test_build(self, events, corpus):
+         b = DomainTokenizerBuilder(FINANCE_SCHEMA)
+         b.fit(events)
+         hf = b.build(text_corpus=corpus, bpe_vocab_size=300)
+         assert hf.pad_token == "[PAD]"
+         assert hf.convert_tokens_to_ids("[AMT_SIGN_POS]") != hf.unk_token_id
+
+     def test_end_to_end(self, events, corpus):
+         b = DomainTokenizerBuilder(FINANCE_SCHEMA)
+         b.fit(events)
+         hf = b.build(text_corpus=corpus, bpe_vocab_size=300)
+         enc = b.encode_sequence(events, hf, max_length=128)
+         assert len(enc["input_ids"]) == 128
+         assert sum(1 for m in enc["attention_mask"] if m == 1) > 10
+
+     def test_stats(self, events):
+         b = DomainTokenizerBuilder(FINANCE_SCHEMA)
+         b.fit(events)
+         s = b.get_stats()
+         assert s["schema_name"] == "finance"
+         assert s["is_fitted"]
+
+     def test_unfitted_raises(self):
+         with pytest.raises(RuntimeError):
+             DomainTokenizerBuilder(FINANCE_SCHEMA).build()
+
+
+ class TestEcommerceBuilder:
+     def test_full(self):
+         events = [
+             {"event_type": "view", "price": 29.99, "quantity": 1,
+              "category": "electronics", "timestamp": datetime(2025, 3, 15, 10, 0),
+              "product_title": "Mouse"},
+             {"event_type": "purchase", "price": 29.99, "quantity": 2,
+              "category": "electronics", "timestamp": datetime(2025, 3, 15, 10, 10),
+              "product_title": "Mouse"},
+         ]
+         b = DomainTokenizerBuilder(ECOMMERCE_SCHEMA)
+         b.fit(events)
+         hf = b.build(text_corpus=["Mouse", "Keyboard"] * 20, bpe_vocab_size=200)
+         enc = b.encode_sequence(events, hf, max_length=256)
+         assert sum(1 for m in enc["attention_mask"] if m == 1) > 10
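
For reference, the pipeline this suite exercises end to end is: fit the builder on raw events, build a Hugging Face tokenizer (domain special tokens plus BPE over a free-text corpus), then encode event sequences. Below is a minimal usage sketch assembled only from the calls and fixture data the tests above use; it makes no claims beyond what the suite itself asserts.

from datetime import datetime

from domain_tokenizer.schemas.predefined import FINANCE_SCHEMA
from domain_tokenizer.tokenizers.domain_tokenizer import DomainTokenizerBuilder

# Raw events shaped like the FINANCE_SCHEMA fixtures in the tests.
events = [
    {"amount_sign": 79.99, "amount": 79.99,
     "timestamp": datetime(2025, 3, 15, 14, 30), "description": "AMAZON"},
    {"amount_sign": -200.0, "amount": -200.0,
     "timestamp": datetime(2025, 3, 16, 9, 15), "description": "SALARY"},
]

builder = DomainTokenizerBuilder(FINANCE_SCHEMA)
builder.fit(events)  # learns magnitude buckets for the fittable "amount" field

# Structured tokens only: starts with [BOS], ends with [EOS],
# events joined by the schema's event separator.
tokens = builder.tokenize_sequence(events)

# Full HF tokenizer: domain tokens plus a small BPE vocab for free text.
hf = builder.build(text_corpus=["AMAZON", "SALARY", "UBER"] * 20, bpe_vocab_size=300)
enc = builder.encode_sequence(events, hf, max_length=128)
print(len(enc["input_ids"]), sum(enc["attention_mask"]))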