File size: 2,564 Bytes
3eae4cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""Tests for FeatureBuilder using the real ObservationModel schema."""

from __future__ import annotations

import numpy as np
import pytest

from app.models import (
    ObservationModel,
    OfficerPool,
    PriorityMode,
    QueueSnapshot,
    ServiceType,
    StageType,
)
from rl.feature_builder import FeatureBuilder, OBS_DIM


def _make_obs() -> ObservationModel:
    """Return a fixed ObservationModel that populates every ServiceType.

    Per-service queue numbers are derived from the service's position in
    ``ServiceType`` iteration order, so different services get different
    values while the whole observation stays fully deterministic.
    """

    def _snapshot_for(position: int, service: ServiceType) -> QueueSnapshot:
        # Stage counts are keyed by the StageType *value*, not the enum member.
        stage_counts = {
            StageType.SUBMISSION.value: 2 + position % 2,
            StageType.DOCUMENT_VERIFICATION.value: 1,
            StageType.FIELD_VERIFICATION.value: 1,
            StageType.APPROVAL.value: 0,
            StageType.ISSUANCE.value: 0,
        }
        return QueueSnapshot(
            service_type=service,
            public_stage_counts=stage_counts,
            total_pending=6 + position,
            blocked_missing_docs=position % 3,
            urgent_pending=1 if position % 3 else 2,
            total_sla_breached=0,
            avg_waiting_days=3.0 + position,
        )

    snapshots = {
        service: _snapshot_for(position, service)
        for position, service in enumerate(ServiceType)
    }

    # Two officers beyond one-per-service, and all of them are available.
    headcount = len(ServiceType) + 2
    pool = OfficerPool(
        total_officers=headcount,
        available_officers=headcount,
        allocated={service: 1 for service in ServiceType},
    )
    # Keep the aggregate backlog consistent with the per-service snapshots.
    backlog = sum(snap.total_pending for snap in snapshots.values())

    return ObservationModel(
        task_id="district_backlog_easy",
        episode_id="ep-test",
        day=8,
        max_days=20,
        officer_pool=pool,
        queue_snapshots=snapshots,
        total_backlog=backlog,
        total_completed=15,
        total_sla_breaches=3,
        fairness_index=1.0 - 0.12,
        escalation_budget_remaining=4,
        last_action_valid=True,
        last_action_message="ok",
    )


@pytest.fixture
def builder() -> FeatureBuilder:
    """Provide a newly constructed FeatureBuilder to each requesting test."""
    feature_builder = FeatureBuilder()
    return feature_builder


def test_output_shape(builder: FeatureBuilder) -> None:
    """The built feature vector is 1-D with exactly OBS_DIM entries."""
    features = builder.build(_make_obs())
    expected_shape = (OBS_DIM,)
    assert features.shape == expected_shape


def test_output_dtype(builder: FeatureBuilder) -> None:
    """The built feature vector uses single-precision floats."""
    features = builder.build(_make_obs())
    expected_dtype = np.float32
    assert features.dtype == expected_dtype


def test_deterministic(builder: FeatureBuilder) -> None:
    """Two builds from identical inputs yield element-wise identical vectors."""
    observation = _make_obs()
    call_args = (observation, "urgent_first", "advance_time")
    first = builder.build(*call_args)
    second = builder.build(*call_args)
    np.testing.assert_array_equal(first, second)


def test_no_nan_or_inf(builder: FeatureBuilder) -> None:
    """Every entry of the built feature vector is a finite number.

    ``np.isfinite`` is False for NaN, +inf and -inf alike, so a single check
    covers both non-finite cases.
    """
    features = builder.build(_make_obs())
    assert np.isfinite(features).all()


def test_values_in_reasonable_range(builder: FeatureBuilder) -> None:
    """Every feature lies in [0, 1], allowing a tiny float rounding error.

    The same tolerance is applied to both bounds, so ``-1e-9`` from rounding is
    treated like ``1.0 + 1e-9``. Failure messages list the offending indices
    and values so a regression can be traced to a specific feature slot.
    """
    tol = 1e-6
    vec = builder.build(_make_obs())
    # Negated ">=" / "<=" so NaN entries are still reported as out of range,
    # matching the original np.all(vec >= 0.0) / np.all(vec <= ...) checks.
    below = np.flatnonzero(~(vec >= -tol))
    above = np.flatnonzero(~(vec <= 1.0 + tol))
    assert below.size == 0, (
        f"features below 0 at indices {below.tolist()}: {vec[below].tolist()}"
    )
    assert above.size == 0, (
        f"features above 1 at indices {above.tolist()}: {vec[above].tolist()}"
    )