File size: 1,865 Bytes
d309047
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import threading
import time
from typing import Any, Dict

from .embeddings import load_embedding_model
from .topic_model import load_topic_model
from .classifier import load_classifier
from .drift_detector import load_drift_model
from .rewriter import load_rewriter
from .stt import load_stt_model

class ModelCache:
    """Keyed cache of loaded models with idle-time eviction.

    The "stt", "embedder", "classifier" and "drift_model" entries are loaded
    when the cache is constructed; "topic_model" and "rewriter" are loaded
    the first time they are requested. An entry that has not been requested
    for longer than the TTL is dropped the next time ``get`` runs, and is
    loaded again on its next request.
    """

    def __init__(self, ttl: float = 3600):
        """Load the initial models and start their idle timers.

        Args:
            ttl: Seconds an entry may go unrequested before it becomes
                eligible for eviction. Defaults to one hour.
        """
        self._lock = threading.Lock()
        self._cache: Dict[str, Any] = {
            "stt": load_stt_model(),
            "embedder": load_embedding_model(),
            "classifier": load_classifier(),
            "drift_model": load_drift_model(),
        }
        # Idle time is an interval, so it is measured on the monotonic clock:
        # wall-clock adjustments must not evict (or pin) every entry at once.
        started = time.monotonic()
        self._last_access: Dict[str, float] = {k: started for k in self._cache}
        self._ttl = ttl  # unused-retention window, in seconds

    def _check_expire(self):
        """Drop every entry whose last access is older than the TTL.

        Does not take the lock itself; ``get`` calls it while holding it.
        """
        now = time.monotonic()
        expired = [k for k, t in self._last_access.items() if now - t > self._ttl]
        for key in expired:
            self._cache.pop(key, None)
            self._last_access.pop(key, None)

    def get(self, key: str):
        """Return the model stored under ``key``, loading it if absent.

        The whole lookup, including any load, runs while holding the
        instance lock, so concurrent callers wait for an in-progress load.

        Args:
            key: One of the model names accepted by ``_load_model``.

        Returns:
            The cached model object, or a newly loaded one.

        Raises:
            ValueError: If ``key`` is not cached and is not a known model name.
        """
        with self._lock:
            if key in self._cache:
                # Refresh the requested entry *before* sweeping; otherwise a
                # model that is still resident but idle past the TTL would be
                # evicted here and then immediately loaded again below.
                self._last_access[key] = time.monotonic()
            self._check_expire()
            if key in self._cache:
                return self._cache[key]

            model = self._load_model(key)
            self._cache[key] = model
            self._last_access[key] = time.monotonic()
            return model

    def _load_model(self, key: str):
        """Call the loader function that corresponds to ``key``.

        Raises:
            ValueError: If ``key`` does not name a known model.
        """
        if key == "embedder":
            return load_embedding_model()
        if key == "topic_model":
            return load_topic_model()
        if key == "classifier":
            return load_classifier()
        if key == "drift_model":
            return load_drift_model()
        if key == "rewriter":
            return load_rewriter()
        if key == "stt":
            return load_stt_model()
        raise ValueError(f"Unknown model key: {key}")