"""
Tests for domainTokenizer Phase 2D: Fine-tuning Pipeline.
Covers dataset construction, batching, forward/backward passes, a Trainer smoke test, and multiclass labels.
Run: pytest tests/test_finetune.py -v
"""
import logging
import random
from datetime import datetime, timedelta

import numpy as np
import pytest
import torch

from domain_tokenizer.schemas.predefined import FINANCE_SCHEMA
from domain_tokenizer.tokenizers.domain_tokenizer import DomainTokenizerBuilder
from domain_tokenizer.models.configuration import DomainTransformerConfig
from domain_tokenizer.models.modeling import DomainTransformerForCausalLM
from domain_tokenizer.models.joint_fusion import JointFusionModel
from domain_tokenizer.training.finetune_data import DomainFinetuneDataset, prepare_finetune_dataset
from domain_tokenizer.training.finetune import finetune_domain_model

logging.basicConfig(level=logging.INFO)


def make_events(n=10, seed=42):
    """Generate n synthetic finance events: signed amount, timestamp, merchant."""
    rng = random.Random(seed)
    merchants = ["AMAZON", "UBER", "SALARY", "GROCERY", "NETFLIX", "GAS"]
    base = datetime(2025, 1, 1)
    events = []
    for _ in range(n):
        amount = rng.uniform(5, 5000) * rng.choice([1, -1])
        events.append({
            "amount_sign": amount,
            "amount": amount,
            "timestamp": base + timedelta(days=rng.randint(0, 365), hours=rng.randint(0, 23)),
            "description": rng.choice(merchants),
        })
    return events


def make_labeled_data(n_users=20, n_tabular=10, seed=42):
    rng = random.Random(seed)
    seqs = [make_events(rng.randint(3, 15), seed + i) for i in range(n_users)]
    tab = np.random.RandomState(seed).randn(n_users, n_tabular).astype(np.float32)
    labels = np.array([rng.choice([0.0, 1.0]) for _ in range(n_users)])
    return seqs, tab, labels


def build_tok(seqs):
    flat = [e for s in seqs for e in s]
    b = DomainTokenizerBuilder(FINANCE_SCHEMA)
    b.fit(flat)
    # Repeat the tiny merchant corpus so BPE training has enough text.
    corpus = list(set(e["description"] for e in flat)) * 20
    return b, b.build(text_corpus=corpus, bpe_vocab_size=300)


def tiny_cfg(v=128):
    return DomainTransformerConfig(
        vocab_size=v, hidden_size=64, num_hidden_layers=2,
        num_attention_heads=4, intermediate_size=128,
    )


def make_fusion(v, nt=10, nc=1):
    return JointFusionModel(
        DomainTransformerForCausalLM(tiny_cfg(v)), nt, nc,
        plr_frequencies=4, plr_embedding_dim=8, dcn_cross_layers=2,
        dcn_deep_layers=1, dcn_deep_dim=32, head_hidden_dim=32,
    )


class TestDataset:
    @pytest.fixture
    def setup(self):
        seqs, tab, labels = make_labeled_data(10, 5)
        b, t = build_tok(seqs)
        return DomainFinetuneDataset(seqs, tab, labels, b, t, 64), t

    def test_len(self, setup):
        assert len(setup[0]) == 10

    def test_keys(self, setup):
        assert set(setup[0][0].keys()) == {"input_ids", "attention_mask", "tabular_features", "labels"}

    def test_shapes(self, setup):
        it = setup[0][0]
        assert it["input_ids"].shape == (64,) and it["tabular_features"].shape == (5,)

    def test_padding(self, setup):
        # At least one position holds a real (non-pad) token.
        it = setup[0][0]
        assert (it["input_ids"] != setup[1].pad_token_id).any()

    def test_mask_matches_pad(self, setup):
        # attention_mask must be 0 exactly where input_ids is the pad token.
        it = setup[0][0]
        assert torch.equal(it["input_ids"] == setup[1].pad_token_id, it["attention_mask"] == 0)

    def test_dtypes(self, setup):
        it = setup[0][0]
        assert it["labels"].dtype == torch.float32 and it["tabular_features"].dtype == torch.float32

    def test_mismatch(self):
        # Mismatched sequence and feature counts must be rejected.
        seqs, tab, labels = make_labeled_data(10)
        b, t = build_tok(seqs)
        with pytest.raises(AssertionError):
            DomainFinetuneDataset(seqs[:5], tab, labels, b, t)

    def test_stats(self, setup):
        assert setup[0].get_stats()["n_samples"] == 10
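
    def test_getitem_deterministic(self, setup):
        # Sanity sketch, assuming __getitem__ applies no random augmentation
        # (tokenization and padding are deterministic): repeated indexing of the
        # same sample must return identical tensors.
        first, second = setup[0][0], setup[0][0]
        assert all(torch.equal(first[k], second[k]) for k in first)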


class TestBatching:
    def test_loader(self):
        seqs, tab, labels = make_labeled_data(8, 5)
        b, t = build_tok(seqs)
        ds = DomainFinetuneDataset(seqs, tab, labels, b, t, 32)
        batch = next(iter(torch.utils.data.DataLoader(ds, batch_size=4)))
        assert batch["input_ids"].shape == (4, 32) and batch["tabular_features"].shape == (4, 5)
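
    def test_partial_last_batch(self):
        # DataLoader keeps the final short batch by default (drop_last=False),
        # so 8 samples at batch_size=3 should yield batches of 3, 3, and 2.
        seqs, tab, labels = make_labeled_data(8, 5)
        b, t = build_tok(seqs)
        ds = DomainFinetuneDataset(seqs, tab, labels, b, t, 32)
        sizes = [bt["input_ids"].shape[0] for bt in torch.utils.data.DataLoader(ds, batch_size=3)]
        assert sizes == [3, 3, 2]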


class TestForwardBackward:
    def test_forward(self):
        seqs, tab, labels = make_labeled_data(8, 5)
        b, t = build_tok(seqs)
        ds = DomainFinetuneDataset(seqs, tab, labels, b, t, 32)
        batch = next(iter(torch.utils.data.DataLoader(ds, batch_size=4)))
        out = make_fusion(t.vocab_size, 5)(**batch)
        assert out["loss"].item() > 0

    def test_backward(self):
        seqs, tab, labels = make_labeled_data(4, 5)
        b, t = build_tok(seqs)
        ds = DomainFinetuneDataset(seqs, tab, labels, b, t, 32)
        batch = next(iter(torch.utils.data.DataLoader(ds, batch_size=4)))
        model = make_fusion(t.vocab_size, 5)
        model(**batch)["loss"].backward()
        # Gradients must reach both towers: token embeddings and PLR frequencies.
        assert model.transformer.model.embed_tokens.weight.grad is not None
        assert model.plr.frequencies.grad is not None

    def test_multiclass(self):
        seqs, tab, _ = make_labeled_data(8, 5)
        rng = random.Random(42)  # seeded so the test is reproducible
        labels = np.array([rng.randint(0, 2) for _ in range(8)])
        b, t = build_tok(seqs)
        ds = DomainFinetuneDataset(seqs, tab, labels, b, t, 32)
        batch = next(iter(torch.utils.data.DataLoader(ds, batch_size=4)))
        batch["labels"] = batch["labels"].long()  # integer class ids for cross-entropy
        out = make_fusion(t.vocab_size, 5, 3)(**batch)
        assert out["logits"].shape == (4, 3) and out["loss"] is not None


class TestTrainer:
    def test_smoke(self, tmp_path):
        seqs, tab, labels = make_labeled_data(20, 5)
        b, t = build_tok(seqs)
        ds = DomainFinetuneDataset(seqs, tab, labels, b, t, 32)
        trainer = finetune_domain_model(
            make_fusion(t.vocab_size, 5), ds,
            output_dir=str(tmp_path), num_epochs=1, per_device_batch_size=4,
            learning_rate=1e-3, warmup_steps=0, logging_steps=1,
            save_strategy="no", report_to="none", seed=42,
        )
        assert trainer.state.global_step > 0
        losses = [h["loss"] for h in trainer.state.log_history if "loss" in h]
        assert len(losses) > 0
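

class TestOptimStep:
    def test_param_update(self):
        # One manual SGD step as a sanity sketch; assumes the loss depends on the
        # token embeddings (test_backward shows a gradient reaching them), so a
        # step with a large lr should change the embedding weights.
        seqs, tab, labels = make_labeled_data(4, 5)
        b, t = build_tok(seqs)
        ds = DomainFinetuneDataset(seqs, tab, labels, b, t, 32)
        batch = next(iter(torch.utils.data.DataLoader(ds, batch_size=4)))
        model = make_fusion(t.vocab_size, 5)
        opt = torch.optim.SGD(model.parameters(), lr=0.1)
        before = model.transformer.model.embed_tokens.weight.detach().clone()
        model(**batch)["loss"].backward()
        opt.step()
        assert not torch.equal(before, model.transformer.model.embed_tokens.weight.detach())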


class TestPrepare:
    def test_prepare(self):
        seqs, tab, labels = make_labeled_data(10, 5)
        b, t = build_tok(seqs)
        ds = prepare_finetune_dataset(seqs, tab, labels, b, t, 32)
        assert len(ds) == 10
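

if __name__ == "__main__":
    # Convenience entry point so the file can be run directly as well as via pytest.
    pytest.main([__file__, "-v"])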