| """Unit tests for MegaLLM API Key Rotation. |
| |
| Run with: |
| cd /Volumes/WorkSpace/Project/LocalMate/localmate-danang-backend-v2 |
| python -m pytest tests/test_key_rotation.py -v |
| """ |
|
|
| import os |
| import threading |
| from concurrent.futures import ThreadPoolExecutor |
| from unittest.mock import patch |
|
|
| import pytest |
|
|
|
|
class TestKeyRotator:
    """Tests for the KeyRotator class.

    ``KeyRotator`` is imported inside each test rather than at module level.
    NOTE(review): presumably this is so each test sees the current module
    object even after other tests ``reload()`` it — confirm.
    """

    def test_rotation_cycles_through_keys(self):
        """Verify round-robin cycles through all keys in order."""
        from app.shared.integrations.key_rotator import KeyRotator

        keys = ["key_1", "key_2", "key_3"]
        rotator = KeyRotator(keys, name="test")

        # First pass hands out the keys in list order...
        assert rotator.get_next_key() == "key_1"
        assert rotator.get_next_key() == "key_2"
        assert rotator.get_next_key() == "key_3"

        # ...and the second pass wraps back around to the first key.
        assert rotator.get_next_key() == "key_1"
        assert rotator.get_next_key() == "key_2"
        assert rotator.get_next_key() == "key_3"

        # Every get_next_key() call is counted as one request.
        assert rotator.request_count == 6

    def test_single_key_always_returns_same(self):
        """Verify single key mode works correctly."""
        from app.shared.integrations.key_rotator import KeyRotator

        keys = ["only_key"]
        rotator = KeyRotator(keys, name="single")

        for _ in range(5):
            assert rotator.get_next_key() == "only_key"

        assert rotator.request_count == 5

    def test_empty_keys_raises_error(self):
        """Verify an empty keys list raises ValueError."""
        from app.shared.integrations.key_rotator import KeyRotator

        with pytest.raises(ValueError, match="At least one API key is required"):
            KeyRotator([], name="empty")

    def test_rotation_thread_safety(self):
        """Verify rotation is thread-safe under concurrent access.

        ``test_rotation_cycles_through_keys`` shows rotation starts at the
        first key and advances by one per call. A race-free rotator therefore
        hands out slots 0..N-1 across N calls. The per-key counts are then
        exact (34/33/33 for 100 calls over 3 keys), however the threads
        interleave. A race that skips or repeats a slot can shift that split;
        a wide tolerance window would hide such drift.
        """
        from collections import Counter

        from app.shared.integrations.key_rotator import KeyRotator

        keys = ["key_1", "key_2", "key_3"]
        total_calls = 100
        rotator = KeyRotator(keys, name="threaded")

        results = []
        lock = threading.Lock()

        def get_key():
            key = rotator.get_next_key()
            # Guard the shared list so result collection is never the race.
            with lock:
                results.append(key)

        with ThreadPoolExecutor(max_workers=10) as executor:
            futures = [executor.submit(get_key) for _ in range(total_calls)]
            for future in futures:
                # Re-raises any exception that occurred in a worker thread.
                future.result()

        assert len(results) == total_calls
        assert rotator.request_count == total_calls

        # Slot i maps to keys[i % len(keys)], so the first `extra` keys each
        # receive one more call than the rest.
        base, extra = divmod(total_calls, len(keys))
        expected = {
            key: base + (1 if idx < extra else 0) for idx, key in enumerate(keys)
        }
        distribution = Counter(results)
        assert distribution == expected, (
            f"Uneven key distribution {dict(distribution)} (expected {expected})"
        )

    def test_get_stats(self):
        """Verify stats reporting works."""
        from app.shared.integrations.key_rotator import KeyRotator

        keys = ["key_1", "key_2"]
        rotator = KeyRotator(keys, name="stats_test")

        rotator.get_next_key()
        rotator.get_next_key()
        rotator.get_next_key()

        stats = rotator.get_stats()
        assert stats["name"] == "stats_test"
        assert stats["total_keys"] == 2
        assert stats["total_requests"] == 3
        # Three calls over two keys leave the rotation position at 3 % 2 == 1.
        assert stats["current_index"] == 1
|
|
|
|
| class TestLoadMegaLLMKeys: |
| """Tests for environment-based key loading.""" |
| |
| def test_load_numbered_keys(self): |
| """Verify loading MEGALLM_API_KEY_1, _2, _3 format.""" |
| env_vars = { |
| "MEGALLM_API_KEY_1": "first_key", |
| "MEGALLM_API_KEY_2": "second_key", |
| "MEGALLM_API_KEY_3": "third_key", |
| } |
| |
| with patch.dict(os.environ, env_vars, clear=False): |
| from importlib import reload |
| from app.shared.integrations import key_rotator |
| reload(key_rotator) |
| |
| keys = key_rotator.load_megallm_keys() |
| assert keys == ["first_key", "second_key", "third_key"] |
| |
| def test_load_fallback_single_key(self): |
| """Verify fallback to MEGALLM_API_KEY (legacy format).""" |
| |
| env_vars = { |
| "MEGALLM_API_KEY": "legacy_key", |
| } |
| |
| |
| for i in range(1, 10): |
| env_vars[f"MEGALLM_API_KEY_{i}"] = "" |
| |
| with patch.dict(os.environ, env_vars, clear=False): |
| from importlib import reload |
| from app.shared.integrations import key_rotator |
| reload(key_rotator) |
| |
| |
| keys = key_rotator.load_megallm_keys() |
| |
| assert len(keys) >= 1 |
|
|
|
|
if __name__ == "__main__":
    # Propagate pytest's exit status so running this file directly exits
    # non-zero when any test fails (pytest.main returns the exit code).
    raise SystemExit(pytest.main([__file__, "-v"]))
|
|