File size: 6,257 Bytes
abab711
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
"""
Tests for domainTokenizer Phase 2D: Fine-tuning Pipeline.
15 tests covering dataset, batching, forward/backward, Trainer smoke, multiclass.

Run: pytest tests/test_finetune.py -v
"""

import logging
import random
from datetime import datetime, timedelta

import numpy as np
import torch
import pytest

from domain_tokenizer.schemas.predefined import FINANCE_SCHEMA
from domain_tokenizer.tokenizers.domain_tokenizer import DomainTokenizerBuilder
from domain_tokenizer.models.configuration import DomainTransformerConfig
from domain_tokenizer.models.modeling import DomainTransformerForCausalLM
from domain_tokenizer.models.joint_fusion import JointFusionModel
from domain_tokenizer.training.finetune_data import DomainFinetuneDataset, prepare_finetune_dataset
from domain_tokenizer.training.finetune import finetune_domain_model

logging.basicConfig(level=logging.INFO)


def make_events(n=10, seed=42):
    rng = random.Random(seed)
    merchants = ["AMAZON", "UBER", "SALARY", "GROCERY", "NETFLIX", "GAS"]
    base = datetime(2025, 1, 1)
    return [
        {"amount_sign": (a := rng.uniform(5, 5000) * rng.choice([1, -1])),
         "amount": a, "timestamp": base + timedelta(days=rng.randint(0, 365), hours=rng.randint(0, 23)),
         "description": rng.choice(merchants)}
        for _ in range(n)
    ]


def make_labeled_data(n_users=20, n_tabular=10, seed=42):
    """Build (event sequences, tabular float32 matrix, binary labels) for *n_users*.

    Sequence lengths and labels come from a seeded ``random.Random``; the
    tabular matrix from a seeded NumPy ``RandomState`` — fully deterministic.
    """
    rng = random.Random(seed)
    # Draw order matters for reproducibility: all sequence lengths first,
    # then all labels.
    seqs = []
    for i in range(n_users):
        seqs.append(make_events(rng.randint(3, 15), seed + i))
    tab = np.random.RandomState(seed).randn(n_users, n_tabular).astype(np.float32)
    labels = np.array([rng.choice([0.0, 1.0]) for _ in range(n_users)])
    return seqs, tab, labels


def build_tok(seqs):
    """Fit a builder on all events of *seqs* and build a tokenizer with a tiny BPE vocab.

    Returns (builder, tokenizer). The text corpus is the set of unique merchant
    descriptions repeated 20x so BPE has enough occurrences to merge on.
    """
    all_events = [ev for seq in seqs for ev in seq]
    builder = DomainTokenizerBuilder(FINANCE_SCHEMA)
    builder.fit(all_events)
    corpus = list({ev["description"] for ev in all_events}) * 20
    tokenizer = builder.build(text_corpus=corpus, bpe_vocab_size=300)
    return builder, tokenizer


def tiny_cfg(v=128):
    """Return a minimal transformer config (vocab size *v*) sized for fast tests."""
    return DomainTransformerConfig(
        vocab_size=v,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=128,
    )


def make_fusion(v, nt=10, nc=1):
    """Wrap a tiny causal LM in a JointFusionModel.

    *v* is the vocab size, *nt* the number of tabular features, *nc* the
    number of output classes. All fusion hyperparameters are kept small.
    """
    backbone = DomainTransformerForCausalLM(tiny_cfg(v))
    return JointFusionModel(
        backbone, nt, nc,
        plr_frequencies=4,
        plr_embedding_dim=8,
        dcn_cross_layers=2,
        dcn_deep_layers=1,
        dcn_deep_dim=32,
        head_hidden_dim=32,
    )


class TestDataset:
    """Item structure, padding, dtypes, and input validation of DomainFinetuneDataset."""

    @pytest.fixture
    def setup(self):
        # 10 users, 5 tabular features, max sequence length 64.
        seqs, tab, labels = make_labeled_data(10, 5)
        builder, tok = build_tok(seqs)
        return DomainFinetuneDataset(seqs, tab, labels, builder, tok, 64), tok

    def test_len(self, setup):
        ds, _ = setup
        assert len(ds) == 10

    def test_keys(self, setup):
        ds, _ = setup
        expected = {"input_ids", "attention_mask", "tabular_features", "labels"}
        assert set(ds[0].keys()) == expected

    def test_shapes(self, setup):
        ds, _ = setup
        item = ds[0]
        assert item["input_ids"].shape == (64,)
        assert item["tabular_features"].shape == (5,)

    def test_padding(self, setup):
        ds, tok = setup
        # At least one real (non-pad) token must be present in the sequence.
        assert (ds[0]["input_ids"] != tok.pad_token_id).any()

    def test_mask_matches_pad(self, setup):
        ds, tok = setup
        item = ds[0]
        # Positions holding the pad token are exactly the zeroed mask positions.
        pad_positions = item["input_ids"] == tok.pad_token_id
        assert torch.equal(pad_positions, item["attention_mask"] == 0)

    def test_dtypes(self, setup):
        ds, _ = setup
        item = ds[0]
        assert item["labels"].dtype == torch.float32
        assert item["tabular_features"].dtype == torch.float32

    def test_mismatch(self):
        seqs, tab, labels = make_labeled_data(10)
        builder, tok = build_tok(seqs)
        # Truncating the sequence list must trip the length-consistency check.
        with pytest.raises(AssertionError):
            DomainFinetuneDataset(seqs[:5], tab, labels, builder, tok)

    def test_stats(self, setup):
        ds, _ = setup
        assert ds.get_stats()["n_samples"] == 10


class TestBatching:
    """DataLoader collation of dataset items into batched tensors."""

    def test_loader(self):
        seqs, tab, labels = make_labeled_data(8, 5)
        builder, tok = build_tok(seqs)
        dataset = DomainFinetuneDataset(seqs, tab, labels, builder, tok, 32)
        loader = torch.utils.data.DataLoader(dataset, batch_size=4)
        batch = next(iter(loader))
        assert batch["input_ids"].shape == (4, 32)
        assert batch["tabular_features"].shape == (4, 5)


class TestForwardBackward:
    """Forward pass, gradient flow, and multiclass behavior of the fusion model."""

    @staticmethod
    def _one_batch(n_users):
        # Shared helper: build a small dataset and pull one batch of 4.
        seqs, tab, labels = make_labeled_data(n_users, 5)
        builder, tok = build_tok(seqs)
        dataset = DomainFinetuneDataset(seqs, tab, labels, builder, tok, 32)
        batch = next(iter(torch.utils.data.DataLoader(dataset, batch_size=4)))
        return batch, tok

    def test_forward(self):
        batch, tok = self._one_batch(8)
        output = make_fusion(tok.vocab_size, 5)(**batch)
        assert output["loss"].item() > 0

    def test_backward(self):
        batch, tok = self._one_batch(4)
        model = make_fusion(tok.vocab_size, 5)
        model(**batch)["loss"].backward()
        # Gradients must reach both the transformer embeddings and the PLR encoder.
        assert model.transformer.model.embed_tokens.weight.grad is not None
        assert model.plr.frequencies.grad is not None

    def test_multiclass(self):
        seqs, tab, _ = make_labeled_data(8, 5)
        # NOTE(review): unseeded module-level `random` — labels vary per run; confirm intended.
        labels = np.array([random.randint(0, 2) for _ in range(8)])
        builder, tok = build_tok(seqs)
        dataset = DomainFinetuneDataset(seqs, tab, labels, builder, tok, 32)
        batch = next(iter(torch.utils.data.DataLoader(dataset, batch_size=4)))
        # Multiclass loss expects integer class indices.
        batch["labels"] = batch["labels"].long()
        output = make_fusion(tok.vocab_size, 5, 3)(**batch)
        assert output["logits"].shape == (4, 3)
        assert output["loss"] is not None


class TestTrainer:
    """Smoke test: one epoch via finetune_domain_model completes and logs losses."""

    def test_smoke(self, tmp_path):
        seqs, tab, labels = make_labeled_data(20, 5)
        builder, tok = build_tok(seqs)
        dataset = DomainFinetuneDataset(seqs, tab, labels, builder, tok, 32)
        model = make_fusion(tok.vocab_size, 5)
        trainer = finetune_domain_model(
            model, dataset,
            output_dir=str(tmp_path),
            num_epochs=1,
            per_device_batch_size=4,
            learning_rate=1e-3,
            warmup_steps=0,
            logging_steps=1,
            save_strategy="no",
            report_to="none",
            seed=42,
        )
        # Training must have taken at least one optimizer step...
        assert trainer.state.global_step > 0
        # ...and every-step logging must have recorded at least one loss.
        logged = [entry["loss"] for entry in trainer.state.log_history if "loss" in entry]
        assert logged


class TestPrepare:
    """prepare_finetune_dataset convenience wrapper yields a full-length dataset."""

    def test_prepare(self):
        seqs, tab, labels = make_labeled_data(10, 5)
        builder, tok = build_tok(seqs)
        dataset = prepare_finetune_dataset(seqs, tab, labels, builder, tok, 32)
        assert len(dataset) == 10