rtferraz committed
Commit ab8a8b6 · verified · 1 Parent(s): b86b1ee

Add model test suite — 33 tests covering config, model, PLR, DCNv2, joint fusion, integration
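The whole suite runs with the command given in the module docstring; individual groups can also be selected with pytest's standard -k filter using the test class names added below, for example:

pytest tests/test_model.py -v
pytest tests/test_model.py -v -k "TestJointFusion or TestIntegration"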

Files changed (1)
  1. tests/test_model.py +219 -0
tests/test_model.py ADDED
@@ -0,0 +1,219 @@
"""
Tests for domainTokenizer Phase 2B: Model Architecture.
33 tests covering config, model, PLR, DCNv2, joint fusion, and end-to-end integration.

Run: pytest tests/test_model.py -v
"""

import math
import json
from datetime import datetime

import numpy as np
import torch
import pytest

from domain_tokenizer.models.configuration import DomainTransformerConfig
from domain_tokenizer.models.modeling import (
    DomainTransformerForCausalLM, DomainTransformerModel, DomainTransformerAttention, DomainTransformerBlock,
)
from domain_tokenizer.models.plr_embeddings import PeriodicLinearReLU
from domain_tokenizer.models.joint_fusion import DCNv2CrossLayer, DCNv2, JointFusionModel
from domain_tokenizer.tokenizers.domain_tokenizer import DomainTokenizerBuilder
from domain_tokenizer.schemas.predefined import FINANCE_SCHEMA


def tiny_config(vocab_size=128):
    return DomainTransformerConfig(
        vocab_size=vocab_size, hidden_size=64, num_hidden_layers=2, num_attention_heads=4,
        intermediate_size=128, hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0,
        max_position_embeddings=64,
    )


class TestDomainTransformerConfig:
    def test_default(self):
        c = DomainTransformerConfig()
        assert c.vocab_size == 32000 and c.hidden_size == 512 and c.model_type == "domain_transformer"

    def test_preset_24m(self):
        c = DomainTransformerConfig.from_preset("24m")
        assert c.hidden_size == 512 and c.num_hidden_layers == 6

    def test_preset_85m(self):
        assert DomainTransformerConfig.from_preset("85m").hidden_size == 768

    def test_preset_330m(self):
        c = DomainTransformerConfig.from_preset("330m")
        assert c.hidden_size == 1024 and c.num_hidden_layers == 24

    def test_preset_override(self):
        c = DomainTransformerConfig.from_preset("24m", vocab_size=500)
        assert c.vocab_size == 500 and c.hidden_size == 512

    def test_invalid_preset(self):
        with pytest.raises(ValueError):
            DomainTransformerConfig.from_preset("999m")

    def test_serialization(self):
        c = DomainTransformerConfig(vocab_size=1000, hidden_size=128, num_hidden_layers=2, num_attention_heads=4)
        c2 = DomainTransformerConfig(**c.to_dict())
        assert c2.vocab_size == 1000

    def test_head_dim(self):
        with pytest.raises(AssertionError):
            DomainTransformerConfig(hidden_size=100, num_attention_heads=7)

    def test_intermediate_default(self):
        assert DomainTransformerConfig(hidden_size=256).intermediate_size == 1024


class TestDomainTransformerModel:
    def test_forward(self):
        m = DomainTransformerModel(tiny_config())
        assert m(input_ids=torch.randint(0, 128, (2, 16))).last_hidden_state.shape == (2, 16, 64)

    def test_embeds(self):
        m = DomainTransformerModel(tiny_config())
        assert m(inputs_embeds=torch.randn(2, 16, 64)).last_hidden_state.shape == (2, 16, 64)


class TestDomainTransformerForCausalLM:
    def test_no_labels(self):
        m = DomainTransformerForCausalLM(tiny_config())
        m.eval()
        with torch.no_grad():
            o = m(input_ids=torch.randint(0, 128, (2, 16)))
        assert o.logits.shape == (2, 16, 128) and o.loss is None

    def test_with_labels(self):
        m = DomainTransformerForCausalLM(tiny_config())
        ids = torch.randint(0, 128, (2, 16))
        o = m(input_ids=ids, labels=ids)
        assert o.loss is not None and o.loss.item() > 0

    def test_backward(self):
        m = DomainTransformerForCausalLM(tiny_config())
        ids = torch.randint(0, 128, (2, 16))
        m(input_ids=ids, labels=ids).loss.backward()
        assert any(p.grad is not None for p in m.parameters() if p.requires_grad)

    def test_weight_tying(self):
        m = DomainTransformerForCausalLM(tiny_config())
        assert m.lm_head.weight is m.model.embed_tokens.weight

    def test_user_embedding(self):
        m = DomainTransformerForCausalLM(tiny_config())
        m.eval()
        with torch.no_grad():
            assert m.get_user_embedding(torch.randint(0, 128, (3, 16))).shape == (3, 64)

    def test_user_embedding_mask(self):
        m = DomainTransformerForCausalLM(tiny_config())
        m.eval()
        mask = torch.ones(2, 16, dtype=torch.long)
        mask[0, 10:] = 0
        with torch.no_grad():
            assert m.get_user_embedding(torch.randint(0, 128, (2, 16)), attention_mask=mask).shape == (2, 64)

    def test_params_tiny(self):
        n = sum(p.numel() for p in DomainTransformerForCausalLM(tiny_config()).parameters())
        assert n < 1_000_000

    def test_params_24m(self):
        n = sum(p.numel() for p in DomainTransformerForCausalLM(DomainTransformerConfig.from_preset("24m")).parameters())
        assert 15_000_000 < n < 40_000_000

    def test_grad_checkpoint(self):
        m = DomainTransformerForCausalLM(tiny_config())
        m.gradient_checkpointing_enable()
        m(input_ids=torch.randint(0, 128, (2, 16)), labels=torch.randint(0, 128, (2, 16))).loss.backward()


class TestAttention:
    def test_shape(self):
        assert DomainTransformerAttention(tiny_config())(torch.randn(2, 16, 64)).shape == (2, 16, 64)

    def test_causal(self):
        c = tiny_config()
        c.attention_probs_dropout_prob = 0.0
        a = DomainTransformerAttention(c)
        a.eval()
        x = torch.zeros(1, 8, 64)
        x[0, 4:, :] = 100.0
        with torch.no_grad():
            o = a(x)
        assert o[0, 7].norm() > o[0, 0].norm() * 2


class TestPLR:
    def test_shape(self):
        assert PeriodicLinearReLU(10, 32, 64)(torch.randn(4, 10)).shape == (4, 10, 64)

    def test_different(self):
        p = PeriodicLinearReLU(5, 16, 32)
        assert not torch.allclose(p(torch.ones(1, 5)), p(torch.ones(1, 5) * 10))

    def test_grad(self):
        p = PeriodicLinearReLU(5, 16, 32)
        x = torch.randn(2, 5, requires_grad=True)
        p(x).sum().backward()
        assert x.grad is not None and p.frequencies.grad is not None

    def test_single(self):
        assert PeriodicLinearReLU(1, 8, 16)(torch.tensor([[3.14]])).shape == (1, 1, 16)


class TestDCNv2:
    def test_cross(self):
        assert DCNv2CrossLayer(64)(torch.randn(4, 64), torch.randn(4, 64)).shape == (4, 64)

    def test_dcn(self):
        d = DCNv2(128, 3, 2, 64)
        assert d(torch.randn(4, 128)).shape == (4, 64) and d.output_dim == 64


class TestJointFusion:
    @pytest.fixture
    def model(self):
        return JointFusionModel(
            DomainTransformerForCausalLM(tiny_config(128)), 10, 1, 8, 16, 2, 2, 32, 32,
        )

    def test_forward(self, model):
        o = model(torch.randint(0, 128, (2, 16)), torch.ones(2, 16, dtype=torch.long), torch.randn(2, 10))
        assert o["logits"].shape == (2, 1) and o["loss"] is None

    def test_loss(self, model):
        o = model(torch.randint(0, 128, (2, 16)), torch.ones(2, 16, dtype=torch.long), torch.randn(2, 10), torch.tensor([1.0, 0.0]))
        assert o["loss"] is not None and o["loss"].dim() == 0

    def test_backward(self, model):
        o = model(torch.randint(0, 128, (2, 16)), torch.ones(2, 16, dtype=torch.long), torch.randn(2, 10), torch.tensor([1.0, 0.0]))
        o["loss"].backward()
        assert model.transformer.model.embed_tokens.weight.grad is not None
        assert model.plr.frequencies.grad is not None

    def test_multiclass(self):
        m = JointFusionModel(DomainTransformerForCausalLM(tiny_config(128)), 5, 3, 4, 8, 2, 2, 16, 16)
        o = m(torch.randint(0, 128, (2, 8)), tabular_features=torch.randn(2, 5), labels=torch.tensor([0, 2]))
        assert o["logits"].shape == (2, 3) and o["loss"] is not None


class TestIntegration:
    def test_finance(self):
        events = [
            {"amount_sign": 79.99, "amount": 79.99, "timestamp": datetime(2025, 3, 15, 14, 30), "description": "AMAZON"},
            {"amount_sign": -200.0, "amount": -200.0, "timestamp": datetime(2025, 3, 16, 9, 0), "description": "SALARY"},
        ]
        builder = DomainTokenizerBuilder(FINANCE_SCHEMA)
        builder.fit(events)
        hf_tok = builder.build(text_corpus=["AMAZON", "SALARY", "UBER", "GROCERY"] * 20, bpe_vocab_size=300)
        enc = builder.encode_sequence(events, hf_tok, max_length=64)
        ids = torch.tensor([enc["input_ids"]])
        mask = torch.tensor([enc["attention_mask"]])
        model = DomainTransformerForCausalLM(tiny_config(hf_tok.vocab_size))
        out = model(input_ids=ids, attention_mask=mask, labels=ids)
        assert out.loss.item() > 0
        out.loss.backward()
        assert sum(p.grad.norm().item() for p in model.parameters() if p.grad is not None) > 0