File size: 2,754 Bytes
e6b3389
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import asyncio

import pytest

from server.registry import Registry


# Module-level marks: apply the (pytest-asyncio) asyncio marker to every test
# in this file so the `async def` tests below are executed on an event loop.
pytestmark = [pytest.mark.asyncio]


async def test_get_or_load_loads_first_time(fake_classes):
    """A fresh registry loads the requested adapter and reports it as active."""
    registry = Registry(adapter_classes=fake_classes, device="cpu")
    adapter = await registry.get_or_load("fake")
    assert adapter.loaded is True
    # status() is queried once per checked field.
    assert registry.status()["status"] == "loaded"
    assert registry.status()["id"] == "fake"


async def test_get_or_load_reuses_active(fake_classes):
    """Requesting the already-active id hands back the very same adapter object."""
    registry = Registry(adapter_classes=fake_classes, device="cpu")
    first = await registry.get_or_load("fake")
    second = await registry.get_or_load("fake")
    assert first is second


async def test_get_or_load_swaps_to_different(fake_classes):
    """Switching to another id loads the new adapter and unloads the previous one."""
    registry = Registry(adapter_classes=fake_classes, device="cpu")
    previous = await registry.get_or_load("fake")
    current = await registry.get_or_load("fake-b")
    assert current.loaded is True
    assert previous.unload_called is True
    assert registry.status()["id"] == "fake-b"


async def test_get_or_load_unknown_id_raises(fake_classes):
    """Asking for an id the registry has no adapter class for raises ``KeyError``."""
    registry = Registry(adapter_classes=fake_classes, device="cpu")
    unknown_id = "nope"
    with pytest.raises(KeyError):
        await registry.get_or_load(unknown_id)


async def test_load_failure_sets_error_status(fake_classes, monkeypatch):
    """A failing adapter load propagates and leaves the registry in "error" status.

    Every "fake" adapter constructed during this test has ``load_should_fail``
    set right after ``__init__``. This flag presumably makes the fake's load
    raise; the test expects a ``RuntimeError`` whose message contains
    "simulated" to be recorded in ``status()["last_error"]``.
    """
    reg = Registry(adapter_classes=fake_classes, device="cpu")
    fake_cls = fake_classes["fake"]
    fake_cls.instances.clear()
    orig_init = fake_cls.__init__

    def patched_init(self, *args, **kwargs):
        # Forward whatever the registry passes (positional or keyword) unchanged.
        orig_init(self, *args, **kwargs)
        self.load_should_fail = True

    # monkeypatch undoes this at test teardown, even if an assertion fails,
    # so the shared fake class is not left patched for later tests.
    monkeypatch.setattr(fake_cls, "__init__", patched_init)

    with pytest.raises(RuntimeError):
        await reg.get_or_load("fake")
    s = reg.status()
    assert s["status"] == "error"
    assert "simulated" in s["last_error"]


async def test_concurrent_activations_serialize(fake_classes):
    """Overlapping get_or_load calls all succeed and the last request ends up active."""
    registry = Registry(adapter_classes=fake_classes, device="cpu")
    requested = ("fake", "fake-b", "fake")
    # Coroutines are created in request order, then run concurrently by gather.
    await asyncio.gather(*(registry.get_or_load(model_id) for model_id in requested))
    final = registry.status()
    assert final["status"] == "loaded"
    assert final["id"] == requested[-1]


async def test_emits_loading_then_loaded_events(fake_classes):
    """Loading an adapter publishes a "loading" event followed by a "loaded" event.

    A background consumer collects events from ``reg.stream_events()`` until
    the first "loaded" event arrives. The consumer is always cleaned up, even
    if the load itself raises.
    """
    reg = Registry(adapter_classes=fake_classes, device="cpu")
    seen: list[dict] = []

    async def collect():
        async for evt in reg.stream_events():
            seen.append(evt)
            if evt["status"] == "loaded":
                return

    consumer = asyncio.create_task(collect())
    try:
        # Yield once so the consumer starts iterating the stream before the
        # load begins (assumes subscription happens before its first await).
        await asyncio.sleep(0)
        await reg.get_or_load("fake")
        await asyncio.wait_for(consumer, timeout=2)
    finally:
        # If get_or_load raised, don't leave the consumer task pending.
        if not consumer.done():
            consumer.cancel()
            # Wait for the cancellation to finish; return_exceptions absorbs it.
            await asyncio.gather(consumer, return_exceptions=True)

    statuses = [e["status"] for e in seen]
    assert "loading" in statuses
    assert "loaded" in statuses
    assert statuses.index("loading") < statuses.index("loaded")