File size: 8,566 Bytes
494c9e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
"""模型管理模块:主槽位与语义槽位对称配置,权重缓存共用。"""
from enum import Enum
import threading

from backend import REGISTERED_MODELS
from backend.project_registry import ModelRegistry
from backend.device import DeviceManager
from backend.model_loader import attn_implementation_for_device, load_causal_lm, load_tokenizer

from model_paths import DEFAULT_MODEL, DEFAULT_SEMANTIC_MODEL, resolve_hf_path

project_registry = ModelRegistry(REGISTERED_MODELS)
_init_lock = threading.Lock()

# 统一推理锁:信息密度分析与 Semantic 分析共用,确保模型推理串行执行
_inference_lock = threading.Lock()

# 按 HuggingFace 路径去重的已加载模型缓存(主分析 / 语义 / 续写共用)
_hf_load_lock = threading.Lock()
_hf_loaded: dict[str, tuple] = {}


class ModelSlot(str, Enum):
    """The two peer slots, mirroring the CLI --model / --semantic_model flags."""

    MAIN = "main"
    SEMANTIC = "semantic"


# Slot order used by startup preloading and "all weights" enumeration
# (the slots are peers; the order implies no hierarchy).
CONFIGURED_SLOTS: tuple[ModelSlot, ...] = (ModelSlot.MAIN, ModelSlot.SEMANTIC)


def _resolved_hf_path_for_slot(slot: ModelSlot) -> str:
    """Resolve the HuggingFace path (or local path string) configured for *slot*.

    MAIN reads the model name from the app context; SEMANTIC reads the CLI
    ``--semantic_model`` argument. Each falls back to its module-level default
    when the app context is not available yet (RuntimeError).

    Raises:
        ValueError: for an unrecognized slot value.
    """
    if slot == ModelSlot.MAIN:
        try:
            from backend.app_context import get_app_context

            name = get_app_context(prefer_module_context=True).model_name or DEFAULT_MODEL
        except RuntimeError:
            # No app context yet (e.g. early startup) — use the default.
            name = DEFAULT_MODEL
        return resolve_hf_path(name)

    if slot == ModelSlot.SEMANTIC:
        try:
            from backend.app_context import get_args

            name = getattr(get_args(), "semantic_model", DEFAULT_SEMANTIC_MODEL)
        except RuntimeError:
            name = DEFAULT_SEMANTIC_MODEL
        return resolve_hf_path(name)

    raise ValueError(f"unknown ModelSlot: {slot!r}")


def ensure_slot_weights_loaded(slot: ModelSlot):
    """Load the weights for *slot* unless already cached.

    The entry point is identical for MAIN and SEMANTIC.
    Returns (tokenizer, model, device).
    """
    resolved = _resolved_hf_path_for_slot(slot)
    return ensure_model_loaded(resolved)


def ensure_model_loaded(resolved_hf_path: str):
    """Single low-level loading entry point: guarantee the weights for
    *resolved_hf_path* are loaded.

    Returns (tokenizer, model, device), where *device* is the device the
    model's parameters actually live on after loading.
    """
    with _hf_load_lock:
        cached = _hf_loaded.get(resolved_hf_path)
        if cached is not None:
            return cached

        device = DeviceManager.get_device()
        # split("/")[-1] yields the path itself when there is no "/" separator.
        display = resolved_hf_path.split("/")[-1]
        print(f"📦 正在加载模型权重: {display}")
        tokenizer = load_tokenizer(resolved_hf_path)
        model = load_causal_lm(
            resolved_hf_path,
            device,
            attn_implementation=attn_implementation_for_device(device),
        )
        # Inference-only: freeze every parameter.
        for param in model.parameters():
            param.requires_grad_(False)
        actual_device = next(model.parameters()).device
        device_name = DeviceManager.get_device_name(device)
        print(f"✓ {display} 已加载 ({device_name})")
        entry = (tokenizer, model, actual_device)
        _hf_loaded[resolved_hf_path] = entry
        return entry


def ensure_project_loaded(project_name: str):
    """Ensure the named project is loaded, loading it on demand.

    Returns whatever ``project_registry.ensure_loaded`` returns.

    Raises:
        ValueError: when *project_name* is empty.
        KeyError: when the name is not registered, or when the registry
            itself reports the project missing during loading.
        RuntimeError: when loading fails for any other reason (original
            exception chained).
    """
    if not project_name:
        raise ValueError("model name is required")
    if not project_registry.is_available(project_name):
        raise KeyError(project_name)
    try:
        return project_registry.ensure_loaded(project_name)
    except KeyError:
        # Re-raise to allow caller to format message uniformly; this also keeps
        # KeyError from being wrapped by the broad handler below.
        raise
    except Exception as exc:  # noqa: BLE001 - propagate detailed message
        raise RuntimeError(f"模型 '{project_name}' 加载失败: {exc}") from exc


def _register_main_qwenlm_if_needed():
    """
    Information-density path: after the MAIN slot weights are ready, register
    the QwenLM instance in ``project_registry``. The semantic slot has no
    registry wrapper, so only this slot needs the step.
    """
    from backend.app_context import get_app_context

    selected = get_app_context(prefer_module_context=True).model_name
    if not selected:
        raise ValueError("未指定模型名称")

    # Already registered: just make sure the default project is warm.
    if selected in project_registry:
        _ensure_default_project_ready(selected)
        return

    if not project_registry.is_available(selected):
        raise KeyError(f"模型 '{selected}' 未找到,可用模型: {list(REGISTERED_MODELS.keys())}")

    try:
        project_registry.load(selected)
        _ensure_default_project_ready(selected)
    except Exception as err:  # noqa: BLE001
        raise RuntimeError(f"模型 '{selected}' 加载失败: {err}") from err


def preload_all_slots():
    """
    Startup preload (when --no_auto_load is absent): resolve the HF path for
    each slot in CONFIGURED_SLOTS, deduplicate, load every weight set, then
    register the MAIN slot's QwenLM project. Both slots are full peers at the
    "load weights first" stage.
    """
    from backend.app_context import get_app_context

    get_app_context(prefer_module_context=True)

    unique_paths = {_resolved_hf_path_for_slot(slot) for slot in CONFIGURED_SLOTS}

    with _init_lock:
        for hf_path in unique_paths:
            ensure_model_loaded(hf_path)
        _register_main_qwenlm_if_needed()


def ensure_slot_ready(slot: ModelSlot):
    """
    Business-level readiness for a slot (symmetric API): guarantee the state
    needed for subsequent inference on that slot.

    - Both slots first ensure the HF weights are loaded and return
      (tokenizer, model, device).
    - MAIN additionally registers the QwenLM project into project_registry
      (information-density pipeline); SEMANTIC has no registry step.

    Lazy loading: information-density calls ensure_main_slot_ready();
    semantic/continuation calls ensure_semantic_slot_ready().

    Raises:
        ValueError: for an unrecognized slot value.
    """
    from backend.app_context import get_app_context

    get_app_context(prefer_module_context=True)

    if slot == ModelSlot.SEMANTIC:
        return ensure_slot_weights_loaded(ModelSlot.SEMANTIC)
    if slot == ModelSlot.MAIN:
        with _init_lock:
            loaded = ensure_slot_weights_loaded(ModelSlot.MAIN)
            _register_main_qwenlm_if_needed()
            return loaded
    raise ValueError(f"unknown ModelSlot: {slot!r}")


def ensure_main_slot_ready():
    """Lazy first information-density request: same as ensure_slot_ready(ModelSlot.MAIN)."""
    return ensure_slot_ready(ModelSlot.MAIN)


def ensure_semantic_slot_ready():
    """Lazy first semantic-type request: same as ensure_slot_ready(ModelSlot.SEMANTIC)."""
    return ensure_slot_ready(ModelSlot.SEMANTIC)


def get_current_model_max_token_length() -> int:
    """
    Query the max_token_length of the currently effective model.

    Prefers the already-loaded model instance; when nothing is loaded, falls
    back to the default_model.default_cpu_machine runtime configuration.
    """
    from backend.app_context import get_app_context
    from backend.runtime_config import RUNTIME_CONFIGS

    try:
        context = get_app_context(prefer_module_context=True)
        model_name = context.model_name or DEFAULT_MODEL
    except RuntimeError:
        # NOTE(review): this fallback uses the literal config key "default_model",
        # while the success path above falls back to DEFAULT_MODEL — confirm the
        # two are intentionally different values.
        model_name = "default_model"

    project = project_registry.get(model_name)
    if project is not None and hasattr(project.lm, "max_length"):
        return project.lm.max_length
    return RUNTIME_CONFIGS["default_model"]["default_cpu_machine"]["max_token_length"]


def _ensure_default_project_ready(selected_name: str):
    """Preload the named default project unless it is empty or already cached."""
    if not selected_name or selected_name in project_registry:
        return
    print(f"⚠️ 默认模型未缓存,正在预加载: {selected_name}")
    project_registry.ensure_loaded(selected_name)


# Legacy names kept for backward compatibility (equivalent to the slot-ready API).
ensure_semantic_loaded = ensure_semantic_slot_ready
ensure_main_project_ready = ensure_main_slot_ready

def get_semantic_model_display_name() -> str:
    """Return the SEMANTIC slot's HuggingFace path (used as the `model` field in results)."""
    return _resolved_hf_path_for_slot(ModelSlot.SEMANTIC)


def ensure_main_model_loaded():
    """
    MAIN slot weights only — for callers that need the main model's forward
    pass without going through project_registry (e.g. attribution).
    """
    return ensure_slot_weights_loaded(ModelSlot.MAIN)


def get_main_model_display_name() -> str:
    """Return the MAIN slot's HuggingFace path (used as the `model` field in results)."""
    return _resolved_hf_path_for_slot(ModelSlot.MAIN)