"""Verify LSTM removal at the 3 wiring points, VideoHead/TalkerHead preserved."""
import torch  # NOTE(review): not referenced directly below; kept in case loading torch first is relied on.
import sys
import os

# Put the directory two levels above this file's directory on the import path
# so ``arbitor`` resolves when this file is run directly as a script
# (presumably the repository root — confirm against the project layout).
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

from arbitor.main import ARBModel
from arbitor.components import LossWeights, LossComponents

# Upper bound, in MB, on the combined KV ledger + KQ cache footprint.
_MEMORY_BUDGET_MB = 100
# Bytes counted per ledger/cache entry. NOTE(review): assumes 4-byte entries
# (e.g. float32) — confirm against the dtype actually used for storage.
_BYTES_PER_ENTRY = 4


def _get_model():
    """Construct a fresh default ``ARBModel`` for a single test."""
    return ARBModel()


def test_lstm_removed():
    """The model must not expose the LSTM module, its flag, or its conversation hooks."""
    model = _get_model()
    assert not hasattr(model, "lstm"), "lstm should be removed"
    assert not hasattr(model, "lstm_enabled"), "lstm_enabled should be removed"
    assert not hasattr(model, "switch_conversation"), "switch_conversation removed"
    assert not hasattr(model, "reset_conversation"), "reset_conversation removed"
    print(" PASS test_lstm_removed")


def test_conv_vq_removed():
    """The model must not expose ``conv_vq``, ``conv_vq_enabled`` or ``_conv_vq_ready``."""
    model = _get_model()
    assert not hasattr(model, "conv_vq"), "conv_vq should be removed"
    assert not hasattr(model, "conv_vq_enabled"), "conv_vq_enabled removed"
    assert not hasattr(model, "_conv_vq_ready"), "_conv_vq_ready removed"
    print(" PASS test_conv_vq_removed")


def test_video_head_preserved():
    """The ``video_head`` attribute must survive the LSTM removal."""
    model = _get_model()
    assert hasattr(model, "video_head"), "video_head should exist"
    print(" PASS test_video_head_preserved")


def test_talker_head_preserved():
    """The ``talker_head`` attribute must survive the LSTM removal."""
    model = _get_model()
    assert hasattr(model, "talker_head"), "talker_head should exist"
    print(" PASS test_talker_head_preserved")


def test_attention_wired():
    """The model must expose an ``attention`` module and its ``attention_enabled`` flag."""
    model = _get_model()
    assert hasattr(model, "attention"), "attention module missing"
    assert hasattr(model, "attention_enabled"), "attention_enabled flag missing"
    print(" PASS test_attention_wired")


def test_kv_ledger_exists():
    """The model must expose both ``kv_ledger`` and ``kq_cache``."""
    model = _get_model()
    assert hasattr(model, "kv_ledger"), "kv_ledger missing"
    assert hasattr(model, "kq_cache"), "kq_cache missing"
    print(" PASS test_kv_ledger_exists")


def test_loss_cleanup():
    """Neither ``LossWeights`` nor ``LossComponents`` may carry LSTM / conv-VQ loss terms."""
    lw = LossWeights()
    assert not hasattr(lw, "lstm_hidden_reg"), "LossWeights.lstm_hidden_reg should be removed"
    assert not hasattr(lw, "conv_vq_commitment"), "LossWeights.conv_vq_commitment should be removed"
    lc = LossComponents(weights=lw)
    assert not hasattr(lc, "lstm_hidden_reg"), "LossComponents.lstm_hidden_reg should be removed"
    assert not hasattr(lc, "conv_vq_commitment"), "LossComponents.conv_vq_commitment should be removed"
    print(" PASS test_loss_cleanup")


def test_moe_clean():
    """If an MoE module is present, it must not keep ``router_h`` or ``lstm_enabled``."""
    model = _get_model()
    # ``moe`` may legitimately be None (e.g. MoE disabled); only inspect it when set.
    if model.moe is not None:
        assert not hasattr(model.moe, "router_h"), "router_h should be removed"
        assert not hasattr(model.moe, "lstm_enabled"), "lstm_enabled should be removed"
    print(" PASS test_moe_clean")


def test_memory_budget():
    """The configured KV ledger + KQ cache sizes must fit under the memory budget.

    The footprint is estimated as ``(KV_LEDGER_SIZE + KQ_CACHE_SIZE) * _BYTES_PER_ENTRY``
    bytes, converted to MB.
    """
    from arbitor.config import KV_LEDGER_SIZE, KQ_CACHE_SIZE

    ledger_bytes = KV_LEDGER_SIZE * _BYTES_PER_ENTRY
    kq_bytes = KQ_CACHE_SIZE * _BYTES_PER_ENTRY
    total_mb = (ledger_bytes + kq_bytes) / (1024 * 1024)
    assert total_mb < _MEMORY_BUDGET_MB, (
        f"KV ledger + KQ cache use {total_mb:.2f} MB, "
        f"exceeding the {_MEMORY_BUDGET_MB} MB budget"
    )
    print(f" PASS test_memory_budget ({total_mb:.2f} MB)")


if __name__ == "__main__":
    # Allow running the suite without pytest; any failing assert aborts the run.
    test_lstm_removed()
    test_conv_vq_removed()
    test_video_head_preserved()
    test_talker_head_preserved()
    test_attention_wired()
    test_kv_ledger_exists()
    test_loss_cleanup()
    test_moe_clean()
    test_memory_budget()
    print("\nAll LSTM removal tests PASS")