"""
Tests for domainTokenizer Phase 2D: Fine-tuning Pipeline.

Covers dataset construction, batching, forward/backward passes, a Trainer
smoke run, and a multiclass classification head.

Run: pytest tests/test_finetune.py -v
"""

import logging
import random
from datetime import datetime, timedelta

import numpy as np
import pytest
import torch
from torch.utils.data import DataLoader

from domain_tokenizer.schemas.predefined import FINANCE_SCHEMA
from domain_tokenizer.tokenizers.domain_tokenizer import DomainTokenizerBuilder
from domain_tokenizer.models.configuration import DomainTransformerConfig
from domain_tokenizer.models.modeling import DomainTransformerForCausalLM
from domain_tokenizer.models.joint_fusion import JointFusionModel
from domain_tokenizer.training.finetune_data import DomainFinetuneDataset, prepare_finetune_dataset
from domain_tokenizer.training.finetune import finetune_domain_model

logging.basicConfig(level=logging.INFO)


def make_events(n=10, seed=42):
    """Generate *n* synthetic finance events, deterministically for a given *seed*.

    Each event is a dict with keys ``amount_sign`` and ``amount`` (the same
    signed amount, magnitude in [5, 5000)), ``timestamp`` (a datetime within
    2025) and ``description`` (a merchant name).
    """
    rng = random.Random(seed)
    merchants = ["AMAZON", "UBER", "SALARY", "GROCERY", "NETFLIX", "GAS"]
    base = datetime(2025, 1, 1)
    events = []
    for _ in range(n):
        # RNG calls stay in this exact order so the fixtures remain stable for a seed.
        amount = rng.uniform(5, 5000) * rng.choice([1, -1])
        offset = timedelta(days=rng.randint(0, 365), hours=rng.randint(0, 23))
        events.append({
            "amount_sign": amount,
            "amount": amount,
            "timestamp": base + offset,
            "description": rng.choice(merchants),
        })
    return events


def make_labeled_data(n_users=20, n_tabular=10, seed=42):
    """Return ``(seqs, tab, labels)`` for *n_users* synthetic users.

    ``seqs`` is a list of event lists (3-15 events each). ``tab`` is a float32
    array of shape ``(n_users, n_tabular)``. ``labels`` is an array of 0.0/1.0
    binary labels.
    """
    rng = random.Random(seed)
    seqs = [make_events(rng.randint(3, 15), seed + i) for i in range(n_users)]
    tab = np.random.RandomState(seed).randn(n_users, n_tabular).astype(np.float32)
    labels = np.array([rng.choice([0.0, 1.0]) for _ in range(n_users)])
    return seqs, tab, labels


def build_tok(seqs):
    """Fit a finance-schema builder on every event in *seqs* and build a tokenizer.

    Returns ``(builder, tokenizer)``. The BPE text corpus is the set of unique
    descriptions repeated 20 times. The set is sorted so the corpus order does
    not depend on string-hash randomization (PYTHONHASHSEED), which would
    otherwise make tokenizer training order vary between runs.
    """
    flat = [event for seq in seqs for event in seq]
    builder = DomainTokenizerBuilder(FINANCE_SCHEMA)
    builder.fit(flat)
    corpus = sorted({event["description"] for event in flat}) * 20
    return builder, builder.build(text_corpus=corpus, bpe_vocab_size=300)


def tiny_cfg(v=128):
    """Return a very small transformer config (vocab size *v*) for fast tests."""
    return DomainTransformerConfig(
        vocab_size=v,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=128,
    )


def make_fusion(v, nt=10, nc=1):
    """Build a tiny JointFusionModel.

    The model uses vocab size *v*, *nt* tabular features and *nc* output
    classes.
    """
    return JointFusionModel(
        DomainTransformerForCausalLM(tiny_cfg(v)),
        nt,
        nc,
        plr_frequencies=4,
        plr_embedding_dim=8,
        dcn_cross_layers=2,
        dcn_deep_layers=1,
        dcn_deep_dim=32,
        head_hidden_dim=32,
    )


def build_dataset(n_users, n_tabular, max_len):
    """Build a labeled DomainFinetuneDataset and return ``(dataset, tokenizer)``."""
    seqs, tab, labels = make_labeled_data(n_users, n_tabular)
    builder, tok = build_tok(seqs)
    return DomainFinetuneDataset(seqs, tab, labels, builder, tok, max_len), tok


def first_batch(ds, batch_size=4):
    """Return the first batch from an unshuffled DataLoader over *ds*."""
    return next(iter(DataLoader(ds, batch_size=batch_size)))


class TestDataset:
    @pytest.fixture
    def ds_and_tok(self):
        """10 users, 5 tabular features, sequences padded/truncated to length 64."""
        return build_dataset(10, 5, 64)

    def test_len(self, ds_and_tok):
        ds, _ = ds_and_tok
        assert len(ds) == 10

    def test_keys(self, ds_and_tok):
        ds, _ = ds_and_tok
        assert set(ds[0].keys()) == {"input_ids", "attention_mask", "tabular_features", "labels"}

    def test_shapes(self, ds_and_tok):
        ds, _ = ds_and_tok
        item = ds[0]
        assert item["input_ids"].shape == (64,)
        assert item["tabular_features"].shape == (5,)

    def test_padding(self, ds_and_tok):
        """The encoded sequence must contain at least one non-padding token."""
        ds, tok = ds_and_tok
        assert (ds[0]["input_ids"] != tok.pad_token_id).any()

    def test_mask_matches_pad(self, ds_and_tok):
        """attention_mask is 0 exactly where input_ids holds the pad token."""
        ds, tok = ds_and_tok
        item = ds[0]
        assert torch.equal(item["input_ids"] == tok.pad_token_id, item["attention_mask"] == 0)

    def test_dtypes(self, ds_and_tok):
        ds, _ = ds_and_tok
        item = ds[0]
        assert item["labels"].dtype == torch.float32
        assert item["tabular_features"].dtype == torch.float32

    def test_mismatch(self):
        """Mismatched sequence / feature / label counts must be rejected."""
        seqs, tab, labels = make_labeled_data(10)
        builder, tok = build_tok(seqs)
        with pytest.raises(AssertionError):
            DomainFinetuneDataset(seqs[:5], tab, labels, builder, tok)

    def test_stats(self, ds_and_tok):
        ds, _ = ds_and_tok
        assert ds.get_stats()["n_samples"] == 10


class TestBatching:
    def test_loader(self):
        ds, _ = build_dataset(8, 5, 32)
        batch = first_batch(ds)
        assert batch["input_ids"].shape == (4, 32)
        assert batch["tabular_features"].shape == (4, 5)


class TestForwardBackward:
    def test_forward(self):
        ds, tok = build_dataset(8, 5, 32)
        out = make_fusion(tok.vocab_size, 5)(**first_batch(ds))
        assert out["loss"].item() > 0

    def test_backward(self):
        """Gradients must reach both the token embeddings and the PLR frequencies."""
        ds, tok = build_dataset(4, 5, 32)
        model = make_fusion(tok.vocab_size, 5)
        model(**first_batch(ds))["loss"].backward()
        assert model.transformer.model.embed_tokens.weight.grad is not None
        assert model.plr.frequencies.grad is not None

    def test_multiclass(self):
        seqs, tab, _ = make_labeled_data(8, 5)
        # Seeded so the test is reproducible; the global `random` module is unseeded.
        labels = np.random.RandomState(0).randint(0, 3, size=8)
        builder, tok = build_tok(seqs)
        ds = DomainFinetuneDataset(seqs, tab, labels, builder, tok, 32)
        batch = first_batch(ds)
        # The dataset emits float32 labels; presumably the multiclass loss
        # expects integer class indices, hence the cast.
        batch["labels"] = batch["labels"].long()
        out = make_fusion(tok.vocab_size, 5, 3)(**batch)
        assert out["logits"].shape == (4, 3)
        assert out["loss"] is not None


class TestTrainer:
    def test_smoke(self, tmp_path):
        ds, tok = build_dataset(20, 5, 32)
        trainer = finetune_domain_model(
            make_fusion(tok.vocab_size, 5),
            ds,
            output_dir=str(tmp_path),
            num_epochs=1,
            per_device_batch_size=4,
            learning_rate=1e-3,
            warmup_steps=0,
            logging_steps=1,
            save_strategy="no",
            report_to="none",
            seed=42,
        )
        assert trainer.state.global_step > 0
        losses = [entry["loss"] for entry in trainer.state.log_history if "loss" in entry]
        assert len(losses) > 0
        assert all(np.isfinite(loss) for loss in losses)


class TestPrepare:
    def test_prepare(self):
        seqs, tab, labels = make_labeled_data(10, 5)
        builder, tok = build_tok(seqs)
        ds = prepare_finetune_dataset(seqs, tab, labels, builder, tok, 32)
        assert len(ds) == 10