| """Verify LSTM removal at the 3 wiring points, VideoHead/TalkerHead preserved.""" |
import functools
import os
import sys

import torch
|
|
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) |
|
|
| from arbitor.main import ARBModel |
| from arbitor.components import LossWeights, LossComponents |
|
|
|
|
@functools.lru_cache(maxsize=1)
def _get_model():
    """Return a shared ``ARBModel`` instance for the attribute checks.

    The tests in this module only inspect the model (``hasattr`` checks and
    attribute reads), so a single instance is constructed on first use and
    reused afterwards instead of building a fresh model for every test.
    Callers must treat the returned model as read-only.
    """
    return ARBModel()
|
|
|
|
def test_lstm_removed():
    """Assert the LSTM module, its flag, and the conversation hooks are gone."""
    model = _get_model()
    # (attribute that must be absent, message reported if it is still present)
    forbidden = (
        ("lstm", "lstm should be removed"),
        ("lstm_enabled", "lstm_enabled should be removed"),
        ("switch_conversation", "switch_conversation removed"),
        ("reset_conversation", "reset_conversation removed"),
    )
    for attr_name, failure_msg in forbidden:
        assert not hasattr(model, attr_name), failure_msg
    print(" PASS test_lstm_removed")
|
|
|
|
def test_conv_vq_removed():
    """Assert the conv-VQ module, its flag, and its readiness marker are gone."""
    model = _get_model()
    forbidden = (
        ("conv_vq", "conv_vq should be removed"),
        ("conv_vq_enabled", "conv_vq_enabled removed"),
        ("_conv_vq_ready", "_conv_vq_ready removed"),
    )
    for attr_name, failure_msg in forbidden:
        assert not hasattr(model, attr_name), failure_msg
    print(" PASS test_conv_vq_removed")
|
|
|
|
def test_video_head_preserved():
    """Assert the model still exposes ``video_head`` after the LSTM removal."""
    head_present = hasattr(_get_model(), "video_head")
    assert head_present, "video_head should exist"
    print(" PASS test_video_head_preserved")
|
|
|
|
def test_talker_head_preserved():
    """Assert the model still exposes ``talker_head`` after the LSTM removal."""
    head_present = hasattr(_get_model(), "talker_head")
    assert head_present, "talker_head should exist"
    print(" PASS test_talker_head_preserved")
|
|
|
|
def test_attention_wired():
    """Assert the attention module and its ``attention_enabled`` flag exist.

    Only presence is checked; the flag's value is not inspected.
    """
    model = _get_model()
    required = (
        ("attention", "attention module missing"),
        ("attention_enabled", "attention_enabled flag missing"),
    )
    for attr_name, failure_msg in required:
        assert hasattr(model, attr_name), failure_msg
    print(" PASS test_attention_wired")
|
|
|
|
def test_kv_ledger_exists():
    """Assert the model exposes both ``kv_ledger`` and ``kq_cache``."""
    model = _get_model()
    required = (
        ("kv_ledger", "kv_ledger missing"),
        ("kq_cache", "kq_cache missing"),
    )
    for attr_name, failure_msg in required:
        assert hasattr(model, attr_name), failure_msg
    print(" PASS test_kv_ledger_exists")
|
|
|
|
def test_loss_cleanup():
    """Assert the LSTM and conv-VQ loss terms are gone from the loss objects.

    A default ``LossWeights`` and a ``LossComponents`` built from it must no
    longer expose ``lstm_hidden_reg`` or ``conv_vq_commitment``. Each failure
    message names the offending class and term.
    """
    removed_terms = ("lstm_hidden_reg", "conv_vq_commitment")

    weights = LossWeights()
    for term in removed_terms:
        assert not hasattr(weights, term), (
            f"LossWeights.{term} should be removed"
        )

    # Built only after the weights pass, matching the original check order.
    components = LossComponents(weights=weights)
    for term in removed_terms:
        assert not hasattr(components, term), (
            f"LossComponents.{term} should be removed"
        )

    print(" PASS test_loss_cleanup")
|
|
|
|
def test_moe_clean():
    """Assert the MoE block (when present) has no LSTM-era routing state.

    When ``model.moe`` is ``None`` the attribute checks are skipped.
    """
    moe = _get_model().moe
    if moe is not None:
        forbidden = (
            ("router_h", "router_h should be removed"),
            ("lstm_enabled", "lstm_enabled should be removed"),
        )
        for attr_name, failure_msg in forbidden:
            assert not hasattr(moe, attr_name), failure_msg
    print(" PASS test_moe_clean")
|
|
|
|
def test_memory_budget():
    """Assert the KV ledger plus KQ cache fit inside a 100 MB budget.

    ``KV_LEDGER_SIZE`` and ``KQ_CACHE_SIZE`` are converted to bytes at 4 bytes
    per entry (presumably 32-bit storage; TODO confirm against the config).
    On failure the message reports the measured size.
    """
    from arbitor.config import KV_LEDGER_SIZE, KQ_CACHE_SIZE

    bytes_per_entry = 4  # assumed 32-bit entries
    budget_mb = 100

    ledger_bytes = KV_LEDGER_SIZE * bytes_per_entry
    kq_bytes = KQ_CACHE_SIZE * bytes_per_entry
    total_mb = (ledger_bytes + kq_bytes) / (1024 * 1024)

    assert total_mb < budget_mb, (
        f"KV ledger + KQ cache use {total_mb:.2f} MB, "
        f"exceeding the {budget_mb} MB budget"
    )
    print(f" PASS test_memory_budget ({total_mb:.2f} MB)")
|
|
|
|
if __name__ == "__main__":
    # Run every check in declaration order when executed as a script;
    # the first failing assertion aborts the run.
    for _check in (
        test_lstm_removed,
        test_conv_vq_removed,
        test_video_head_preserved,
        test_talker_head_preserved,
        test_attention_wired,
        test_kv_ledger_exists,
        test_loss_cleanup,
        test_moe_clean,
        test_memory_budget,
    ):
        _check()
    print("\nAll LSTM removal tests PASS")
|
|