| """ |
| Tests for domainTokenizer Phase 2D: Fine-tuning Pipeline. |
| 15 tests covering dataset, batching, forward/backward, Trainer smoke, multiclass. |
| |
| Run: pytest tests/test_finetune.py -v |
| """ |
|
|
import logging
import random
from datetime import datetime, timedelta

import numpy as np
import pytest
import torch

from domain_tokenizer.schemas.predefined import FINANCE_SCHEMA
from domain_tokenizer.tokenizers.domain_tokenizer import DomainTokenizerBuilder
from domain_tokenizer.models.configuration import DomainTransformerConfig
from domain_tokenizer.models.modeling import DomainTransformerForCausalLM
from domain_tokenizer.models.joint_fusion import JointFusionModel
from domain_tokenizer.training.finetune_data import DomainFinetuneDataset, prepare_finetune_dataset
from domain_tokenizer.training.finetune import finetune_domain_model

logging.basicConfig(level=logging.INFO)


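# ---------------------------------------------------------------------------
# Synthetic-data helpers. The event fields below (signed amount, timestamp,
# merchant description) are assumed to match what FINANCE_SCHEMA expects.
# ---------------------------------------------------------------------------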
def make_events(n=10, seed=42):
    """Generate n synthetic transaction events with signed amounts."""
    rng = random.Random(seed)
    merchants = ["AMAZON", "UBER", "SALARY", "GROCERY", "NETFLIX", "GAS"]
    base = datetime(2025, 1, 1)
    events = []
    for _ in range(n):
        # Signed amount: positive for income, negative for spend.
        amount = rng.uniform(5, 5000) * rng.choice([1, -1])
        events.append({
            "amount_sign": amount,
            "amount": amount,
            "timestamp": base + timedelta(days=rng.randint(0, 365), hours=rng.randint(0, 23)),
            "description": rng.choice(merchants),
        })
    return events


def make_labeled_data(n_users=20, n_tabular=10, seed=42):
    """Per-user event sequences, random tabular features, and binary labels."""
    rng = random.Random(seed)
    seqs = [make_events(rng.randint(3, 15), seed + i) for i in range(n_users)]
    tab = np.random.RandomState(seed).randn(n_users, n_tabular).astype(np.float32)
    labels = np.array([rng.choice([0.0, 1.0]) for _ in range(n_users)])
    return seqs, tab, labels


def build_tok(seqs):
    """Fit a DomainTokenizerBuilder on the flattened events and build the tokenizer."""
    flat = [e for s in seqs for e in s]
    b = DomainTokenizerBuilder(FINANCE_SCHEMA)
    b.fit(flat)
    # Repeat the tiny merchant corpus so BPE training has enough text.
    corpus = list(set(e["description"] for e in flat)) * 20
    return b, b.build(text_corpus=corpus, bpe_vocab_size=300)


def tiny_cfg(v=128):
    """Minimal transformer config so tests run quickly on CPU."""
    return DomainTransformerConfig(vocab_size=v, hidden_size=64, num_hidden_layers=2,
                                   num_attention_heads=4, intermediate_size=128)


def make_fusion(v, nt=10, nc=1):
    """JointFusionModel with nt tabular features and nc output classes."""
    return JointFusionModel(
        DomainTransformerForCausalLM(tiny_cfg(v)), nt, nc,
        plr_frequencies=4, plr_embedding_dim=8, dcn_cross_layers=2,
        dcn_deep_layers=1, dcn_deep_dim=32, head_hidden_dim=32,
    )


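# Dataset construction: item keys, shapes, padding/mask consistency, dtypes,
# and length validation. The setup fixture yields a (dataset, tokenizer) pair.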
class TestDataset:
    @pytest.fixture
    def setup(self):
        seqs, tab, labels = make_labeled_data(10, 5)
        b, t = build_tok(seqs)
        return DomainFinetuneDataset(seqs, tab, labels, b, t, 64), t

    def test_len(self, setup):
        ds, _ = setup
        assert len(ds) == 10

    def test_keys(self, setup):
        ds, _ = setup
        assert set(ds[0].keys()) == {"input_ids", "attention_mask", "tabular_features", "labels"}

    def test_shapes(self, setup):
        ds, _ = setup
        item = ds[0]
        assert item["input_ids"].shape == (64,) and item["tabular_features"].shape == (5,)

    def test_padding(self, setup):
        ds, tok = setup
        # Padded to max_length, but at least some tokens must be real content.
        assert (ds[0]["input_ids"] != tok.pad_token_id).any()

    def test_mask_matches_pad(self, setup):
        ds, tok = setup
        item = ds[0]
        assert torch.equal(item["input_ids"] == tok.pad_token_id, item["attention_mask"] == 0)

    def test_dtypes(self, setup):
        ds, _ = setup
        item = ds[0]
        assert item["labels"].dtype == torch.float32 and item["tabular_features"].dtype == torch.float32

    def test_mismatch(self):
        # Mismatched sequence / feature / label counts must be rejected.
        seqs, tab, labels = make_labeled_data(10)
        b, t = build_tok(seqs)
        with pytest.raises(AssertionError):
            DomainFinetuneDataset(seqs[:5], tab, labels, b, t)

    def test_stats(self, setup):
        ds, _ = setup
        assert ds.get_stats()["n_samples"] == 10


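# Every item is a dict of fixed-size tensors, so the default DataLoader
# collate_fn can stack a batch without a custom collator.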
class TestBatching:
    def test_loader(self):
        seqs, tab, labels = make_labeled_data(8, 5)
        b, t = build_tok(seqs)
        ds = DomainFinetuneDataset(seqs, tab, labels, b, t, 32)
        batch = next(iter(torch.utils.data.DataLoader(ds, batch_size=4)))
        assert batch["input_ids"].shape == (4, 32) and batch["tabular_features"].shape == (4, 5)


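# Forward/backward: the fusion model should return a positive loss and push
# gradients into both the transformer embeddings and the PLR frequencies.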
class TestForwardBackward:
    def test_forward(self):
        seqs, tab, labels = make_labeled_data(8, 5)
        b, t = build_tok(seqs)
        ds = DomainFinetuneDataset(seqs, tab, labels, b, t, 32)
        batch = next(iter(torch.utils.data.DataLoader(ds, batch_size=4)))
        out = make_fusion(t.vocab_size, 5)(**batch)
        assert out["loss"].item() > 0

    def test_backward(self):
        seqs, tab, labels = make_labeled_data(4, 5)
        b, t = build_tok(seqs)
        ds = DomainFinetuneDataset(seqs, tab, labels, b, t, 32)
        batch = next(iter(torch.utils.data.DataLoader(ds, batch_size=4)))
        model = make_fusion(t.vocab_size, 5)
        model(**batch)["loss"].backward()
        assert model.transformer.model.embed_tokens.weight.grad is not None
        assert model.plr.frequencies.grad is not None

    def test_multiclass(self):
        seqs, tab, _ = make_labeled_data(8, 5)
        rng = random.Random(0)  # seeded so the test is deterministic
        labels = np.array([rng.randint(0, 2) for _ in range(8)])
        b, t = build_tok(seqs)
        ds = DomainFinetuneDataset(seqs, tab, labels, b, t, 32)
        batch = next(iter(torch.utils.data.DataLoader(ds, batch_size=4)))
        batch["labels"] = batch["labels"].long()  # multiclass loss expects integer class indices
        out = make_fusion(t.vocab_size, 5, 3)(**batch)
        assert out["logits"].shape == (4, 3) and out["loss"] is not None


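# One-epoch smoke test through the trainer that finetune_domain_model returns;
# state.global_step and log_history follow the Hugging Face Trainer API.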
class TestTrainer:
    def test_smoke(self, tmp_path):
        seqs, tab, labels = make_labeled_data(20, 5)
        b, t = build_tok(seqs)
        ds = DomainFinetuneDataset(seqs, tab, labels, b, t, 32)
        trainer = finetune_domain_model(
            make_fusion(t.vocab_size, 5), ds,
            output_dir=str(tmp_path), num_epochs=1, per_device_batch_size=4,
            learning_rate=1e-3, warmup_steps=0, logging_steps=1,
            save_strategy="no", report_to="none", seed=42,
        )
        assert trainer.state.global_step > 0
        losses = [h["loss"] for h in trainer.state.log_history if "loss" in h]
        assert len(losses) > 0


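# prepare_finetune_dataset is the convenience wrapper around the dataset
# constructor; it should produce an equivalent dataset.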
class TestPrepare:
    def test_prepare(self):
        seqs, tab, labels = make_labeled_data(10, 5)
        b, t = build_tok(seqs)
        ds = prepare_finetune_dataset(seqs, tab, labels, b, t, 32)
        assert len(ds) == 10