# File: ARBS/testing/attention/test_lstm_removal_clean.py
# Uploaded by CLIWorks via huggingface_hub ("Upload folder using huggingface_hub"),
# commit d8bc908.
"""Verify LSTM removal — structural checks only, no ARBModel construction."""
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from arbitor.components import LossWeights, LossComponents
def test_loss_cleanup():
    """Check that the removed LSTM / ConvVQ loss terms are gone.

    Neither a default ``LossWeights`` nor a ``LossComponents`` built from it
    may still expose the ``lstm_hidden_reg`` or ``conv_vq_commitment``
    attributes.
    """
    removed = ("lstm_hidden_reg", "conv_vq_commitment")

    lw = LossWeights()
    for attr in removed:
        assert not hasattr(lw, attr), f"LossWeights.{attr} should be removed"

    # Built only after the weights pass, matching the original check order.
    lc = LossComponents(weights=lw)
    for attr in removed:
        assert not hasattr(lc, attr), f"LossComponents.{attr} should be removed"

    print(" PASS test_loss_cleanup")
def test_moe_clean():
    """Build a small MoEGraph and confirm the LSTM router state is gone."""
    from arbitor.components import MoEGraph

    # Small sizes; only attribute presence is checked below, so these values
    # matter only for getting the module constructed.
    graph = MoEGraph(
        cb_dim=64,
        trigram_dim=128,
        num_experts=4,
        core_rank=32,
        shared_inter=64,
        max_iters=1,
        codebook_size=128,
        active_graph_max_nodes=0,
    )

    for attr in ("router_h", "lstm_enabled"):
        assert not hasattr(graph, attr), f"{attr} should be removed"

    print(" PASS test_moe_clean")
def test_classes_removed():
    """Confirm the LSTM-era classes are no longer attributes of arbitor.components."""
    import arbitor.components as components

    retired = (
        "ConvVQCodebook",
        "FocusGate",
        "ConversationStack",
        "ConversationLSTM",
    )
    for cls_name in retired:
        assert not hasattr(components, cls_name), f"{cls_name} should be removed"

    print(" PASS test_classes_removed")
def test_attention_imports():
    """Smoke-test that the attention package still exposes its public names.

    The imports are the whole check: a missing name raises ImportError.
    """
    from arbitor.attention import (  # noqa: F401 - importing is the check
        GPURingBuffer,
        KVLedger,
        KQCache,
        MultiHeadLatentAttention,
        ContextAttentionScheduler,
    )
    from arbitor.attention.mla import (  # noqa: F401 - importing is the check
        apply_rotary_emb,
        precompute_freqs_cis,
    )

    print(" PASS test_attention_imports")
def test_video_talker_imports():
    """Smoke-test that the video and talker decoder heads remain importable."""
    from arbitor.decoders import (  # noqa: F401 - importing is the check
        VideoHead,
        TalkerHead,
    )

    print(" PASS test_video_talker_imports")
def test_memory_budget():
    """Check that the KV ledger plus KQ cache fit within a 100 MB budget.

    The footprint is estimated as each configured size times 4 bytes.
    NOTE(review): this assumes KV_LEDGER_SIZE / KQ_CACHE_SIZE are element
    counts of a 4-byte dtype (e.g. float32); confirm against arbitor.config.
    """
    from arbitor.config import KV_LEDGER_SIZE, KQ_CACHE_SIZE

    bytes_per_element = 4  # assumed 4-byte elements (see docstring)
    budget_mb = 100

    ledger_bytes = KV_LEDGER_SIZE * bytes_per_element
    kq_bytes = KQ_CACHE_SIZE * bytes_per_element
    total_mb = (ledger_bytes + kq_bytes) / (1024 * 1024)

    # Report the measured size on failure so a regression is easy to diagnose.
    assert total_mb < budget_mb, (
        f"KV ledger + KQ cache estimate {total_mb:.2f} MB exceeds "
        f"the {budget_mb} MB budget"
    )
    print(f" PASS test_memory_budget ({total_mb:.2f} MB)")
if __name__ == "__main__":
    # Run every check in declaration order. The first failing assertion
    # aborts with a traceback, so the closing banner prints only when
    # all checks pass.
    for check in (
        test_loss_cleanup,
        test_moe_clean,
        test_classes_removed,
        test_attention_imports,
        test_video_talker_imports,
        test_memory_budget,
    ):
        check()
    print("\nAll LSTM removal tests PASS")