File size: 14,061 Bytes
494c9e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
"""
运行时配置管理模块

负责管理不同模型在不同平台下的运行时参数配置,包括:
- max_token_length: 文本分析的最大 token 数限制(信息密度分析)
- chunk_size: 推理时的分块大小
- 语义分析有独立的 SEMANTIC_RUNTIME_CONFIGS,仅含 max_token_length

平台 ID 说明:
- local_mps: 本地 Apple Silicon(M1/M2/M3)
- cloud_cuda: 云端 CUDA GPU
- cloud_cpu_16g: 云端大内存 CPU(如 HF Space 免费层,16G RAM)
- cloud_cpu_32g: 云端大内存 CPU(如 HF Space CPU upgrade,32G RAM)
- default_cpu_machine: 默认 CPU 机器(未知或未识别的 CPU 环境)
- 未来可扩展: cloud_cuda_a100, cloud_cuda_24g 等
"""

import os
import torch
import sys
from typing import Dict, Optional


# ============= 平台级常量 =============

# 分析接口的 pred_topk 默认数量(候选词数量)
# 前端 ToolTip 显示数量与此保持一致
DEFAULT_TOPK = 10

# MPS 单次 TopK 操作的安全序列长度上限(避免 MPS bug)
# chunk_size 必须小于此值以确保每个 chunk 的 TopK 计算安全
MPS_TOPK_BUG_THRESHOLD = 2048


# ============= 运行时参数配置表 (Model × Platform) =============
# 
# 二维表结构:每个模型针对每个平台配置 max_token_length 和 chunk_size
# 
# 四层覆盖优先级(从高到低):
#   1. (model_name, platform)        - 模型在该平台的专用配置(最精确)
#   2. (model_name, "default_cpu_machine") - 模型的通用配置(跨平台)
#   3. ("default_model", platform)   - 平台的通用配置(跨模型)
#   4. ("default_model", "default_cpu_machine") - 全局兜底配置
#
# 每层支持部分覆盖:只填 max_token_length 或 chunk_size 均可

RUNTIME_CONFIGS = {
    # 全局默认模型配置
    "default_model": {
        # 默认 CPU 机器配置(最保守,用于未识别的 CPU 环境)
        "default_cpu_machine": {
            "max_token_length": 2000,
            "chunk_size": 256
        },
        # 云端 CPU(16G),如 HF Spaces CPU basic
        "cloud_cpu_16g": {
            "max_token_length": 2000,
            "chunk_size": 256
        },
        # 云端 CPU(32G),如 HF Spaces CPU upgrade
        "cloud_cpu_32g": {
            "max_token_length": 5000,
            "chunk_size": 512
        },
        # 云端 GPU 显存充足
        "cloud_cuda": {
            # "max_token_length": 10000,
            "max_token_length": 5000,
            "chunk_size": 1024
        },
        # 本地 Apple Silicon
        "local_mps": {
            "max_token_length": 2000,
            "chunk_size": 512
        }
    },
    # # Qwen3-1.7B
    # "qwen3-1.7b": {
    #     "local_mps": {
    #         "max_token_length": 2000,
    #         "chunk_size": 128
    #     }
    # }
}


# ============= 语义分析运行时配置(仅 max_token_length) =============
# 按平台配置,语义分析独立于信息密度模型

SEMANTIC_RUNTIME_CONFIGS = {
    "default_cpu_machine": {"max_token_length": 300},
    "cloud_cpu_16g": {"max_token_length": 300},
    "cloud_cpu_32g": {"max_token_length": 1000},
    "cloud_cuda": {"max_token_length": 1000},
    "local_mps": {"max_token_length": 300},
}


# ============= 平台检测与配置解析 =============

def detect_platform(verbose: bool = True) -> str:
    """
    自动检测当前运行平台
    
    优先级:
      1. 环境变量 FORCE_CPU(显式强制 CPU 模式)
      2. 自动探测硬件(cuda/mps/cpu)
      3. 细分 CPU 类型(如 cloud_cpu_16g)
    
    Args:
        verbose: 是否打印检测信息
    
    Returns:
        平台 ID 字符串(如 'local_mps', 'cloud_cuda', 'cloud_cpu_16g', 'cloud_cpu_32g', 'default_cpu_machine')
    """
    # 1. 显式强制 CPU(可通过环境变量 FORCE_CPU=1 启用)
    if os.environ.get("FORCE_CPU") == "1":
        print(f"🔧 强制 CPU 模式")
        return _detect_cpu_variant()
    
    # 2. 自动探测 GPU/MPS
    if torch.cuda.is_available():
        platform = "cloud_cuda"
    elif torch.backends.mps.is_available():
        platform = "local_mps"
    else:
        # 3. 细分 CPU 类型
        platform = _detect_cpu_variant()
    
    print(f"🔍 自动检测平台配置: {platform}")
    return platform


def _detect_cpu_variant() -> str:
    """
    检测具体的 CPU 环境变体(内部函数)
    根据内存大小识别不同的 CPU 环境:
    - >= 30GB: cloud_cpu_32g(32G 内存环境)
    - >= 15GB: cloud_cpu_16g(16G 内存环境)
    - 其他: default_cpu_machine(默认配置)
    
    优先检测容器内存限制(cgroup),如果不可用则回退到系统内存检测。
    """
    total_memory = 0
    
    try:
        # 优先检测容器内存限制(cgroup)
        # cgroup v2: /sys/fs/cgroup/memory.max
        # cgroup v1: /sys/fs/cgroup/memory/memory.limit_in_bytes
        cgroup_paths = [
            "/sys/fs/cgroup/memory.max",  # cgroup v2
            "/sys/fs/cgroup/memory/memory.limit_in_bytes",  # cgroup v1
        ]
        
        for cgroup_path in cgroup_paths:
            try:
                if os.path.exists(cgroup_path):
                    with open(cgroup_path, 'r') as f:
                        limit_str = f.read().strip()
                        # cgroup v2 可能返回 "max" 表示无限制
                        if limit_str == "max":
                            break
                        limit_bytes = int(limit_str)
                        if limit_bytes > 0 and limit_bytes < (2 ** 63):  # 合理范围
                            total_memory = limit_bytes
                            print(f"🔍 从 cgroup 检测到容器内存限制: {total_memory / (1024 ** 3):.2f} GB")
                            break
            except (ValueError, IOError, OSError):
                continue
        
        # 如果 cgroup 检测失败,回退到系统内存检测
        if total_memory == 0 and sys.platform != "win32":
            try:
                page_size = os.sysconf('SC_PAGE_SIZE')
                phys_pages = os.sysconf('SC_PHYS_PAGES')
                total_memory = page_size * phys_pages
                print(f"🔍 从系统配置检测到内存: {total_memory / (1024 ** 3):.2f} GB")
            except (ValueError, AttributeError):
                pass
        
        # 转换为 GB
        total_memory_gb = total_memory / (1024 ** 3)
        
        # 判断标准:
        # - >= 30GB: cloud_cpu_32g(HF Spaces CPU upgrade 通常会有 30.x GB 可见)
        # - >= 15GB: cloud_cpu_16g(HF Spaces CPU basic 通常会有 15.x GB 可见)
        if total_memory_gb >= 30.0:
            return "cloud_cpu_32g"
        elif total_memory_gb >= 15.0:
            return "cloud_cpu_16g"
            
    except Exception as e:
        print(f"⚠️  CPU 环境检测失败,回退到默认配置: {e}")
    
    return "default_cpu_machine"


def merge_runtime_config(model_name: str, platform: str, verbose: bool = True) -> Dict[str, int]:
    """
    四层配置合并:支持部分覆盖,并追踪配置来源
    
    优先级(从高到低):
      1. (model_name, platform)        - 模型在该平台的专用配置
      2. (model_name, "default_cpu_machine") - 模型通用配置
      3. ("default_model", platform)   - 平台通用配置
      4. ("default_model", "default_cpu_machine") - 全局兜底
    
    Args:
        model_name: 模型名称(如 'qwen3-1.7b')
        platform: 平台 ID(如 'local_mps')
        verbose: 是否打印配置来源提示
    
    Returns:
        合并后的配置字典 {"max_token_length": int, "chunk_size": int}
    
    Raises:
        ValueError: 配置不完整时抛出
    """
    # 准备四层配置(从低优先级到高优先级)
    layers = [
        {
            "name": "default_model.default_cpu_machine",
            "config": RUNTIME_CONFIGS.get("default_model", {}).get("default_cpu_machine", {})
        },
        {
            "name": f"default_model.{platform}",
            "config": RUNTIME_CONFIGS.get("default_model", {}).get(platform, {})
        },
        {
            "name": f"{model_name}.default_cpu_machine",
            "config": RUNTIME_CONFIGS.get(model_name, {}).get("default_cpu_machine", {})
        },
        {
            "name": f"{model_name}.{platform}",
            "config": RUNTIME_CONFIGS.get(model_name, {}).get(platform, {})
        }
    ]
    
    # 追踪每个配置项的来源
    config_sources = {}  # {"max_token_length": "层级名称", "chunk_size": "层级名称"}
    merged = {}
    
    # 依次合并(后面的覆盖前面的)
    for layer in layers:
        layer_config = layer["config"]
        for key, value in layer_config.items():
            merged[key] = value
            config_sources[key] = layer["name"]
    
    # 确保必需字段存在
    if "max_token_length" not in merged or "chunk_size" not in merged:
        raise ValueError(
            f"配置不完整: model={model_name}, platform={platform}, "
            f"merged={merged}. 缺少必需字段!"
        )
    
    # 打印当前使用的配置项的配置来源
    for key, source in config_sources.items():
        actual_value = merged[key]
        print(f"\t{key}={actual_value} ( {source})")
    
    return merged


_semantic_max_token_length_cache: Optional[int] = None


def get_semantic_max_token_length(verbose: bool = False) -> int:
    """
    获取语义分析的 max_token_length(从 SEMANTIC_RUNTIME_CONFIGS 按平台读取)
    平台检测结果会缓存,避免每次分析重复检测。
    """
    global _semantic_max_token_length_cache
    if _semantic_max_token_length_cache is not None:
        return _semantic_max_token_length_cache
    platform = detect_platform(verbose=verbose)
    config = SEMANTIC_RUNTIME_CONFIGS.get(platform, SEMANTIC_RUNTIME_CONFIGS["default_cpu_machine"])
    _semantic_max_token_length_cache = config["max_token_length"]
    return _semantic_max_token_length_cache


def validate_platform_config(platform: str, chunk_size: int, verbose: bool = True) -> None:
    """
    平台级安全校验(前置到初始化阶段)
    
    Args:
        platform: 平台 ID
        chunk_size: 配置的 chunk_size
        verbose: 是否打印校验信息
    
    Raises:
        ValueError: 配置不符合平台限制时抛出
    """
    # MPS 平台的特殊限制
    if "mps" in platform.lower():
        if chunk_size > MPS_TOPK_BUG_THRESHOLD:
            raise ValueError(
                f"❌ MPS 平台配置错误: chunk_size ({chunk_size}) "
                f"超过安全上限 ({MPS_TOPK_BUG_THRESHOLD})\n"
                f"   平台: {platform}\n"
                f"   建议: 调整 RUNTIME_CONFIGS 中 {platform} 的 chunk_size"
            )
        if verbose:
            print(f"✓ MPS 平台安全检查通过: chunk_size={chunk_size} (上限={MPS_TOPK_BUG_THRESHOLD})")


def _get_cpu_info() -> Optional[str]:
    """
    读取 CPU 型号信息(仅用于显示)
    
    Returns:
        model_name, if None, return "未知"
    """
    model_name = None
    
    try:
        if sys.platform == 'linux':
            with open('/proc/cpuinfo', 'r') as f:
                for line in f:
                    # 读取 model name
                    if model_name is None and 'model name' in line.lower():
                        model_name = line.split(':', 1)[1].strip()
                    
                    # 如果已经读取到所需信息,可以提前退出
                    if model_name:
                        break
    except Exception:
        pass
    
    return model_name


def _print_cpu_info() -> None:
    """
    打印 CPU 型号信息(所有平台都打印)
    """
    try:
        cpu_model = _get_cpu_info()
        model = cpu_model or "未知"
        
        print(f"💻 CPU 型号: {model}")
    except Exception as e:
        print(f"⚠️  CPU 信息获取失败: {e}")


def _print_cpu_thread_info() -> None:
    """打印 CPU 线程配置信息(PyTorch 默认配置)"""
    try:
        intra_threads = torch.get_num_threads()
        inter_threads = torch.get_num_interop_threads()
        print(f"🧵 PyTorch 线程配置: intra-op={intra_threads}, inter-op={inter_threads}")
    except Exception as e:
        print(f"⚠️  CPU 线程信息获取失败: {e}")


def load_runtime_config(model_name: str, verbose: bool = False) -> tuple[str, int, int]:
    """
    加载运行时配置的完整流程:检测平台 -> 合并配置 -> 校验 -> CPU调试信息
    
    这是配置加载的主入口函数,封装了完整的配置加载逻辑。
    
    Args:
        model_name: 模型标识符(如 'qwen3-1.7b')
        verbose: 是否打印详细的配置信息
    
    Returns:
        tuple[platform, max_token_length, chunk_size]
    
    Raises:
        ValueError: 配置不完整或不符合平台限制时抛出
    """
    # 1. 检测平台
    platform = detect_platform(verbose=verbose)
    
    # 2. 四层配置合并(支持部分覆盖,并追踪配置来源)
    config = merge_runtime_config(
        model_name=model_name or "default_model",
        platform=platform,
        verbose=verbose
    )
    
    # 3. 提取配置
    max_token_length = config["max_token_length"]
    chunk_size = config["chunk_size"]
    
    # 4. 平台级安全校验(MPS 限制等)
    validate_platform_config(platform, chunk_size, verbose=verbose)
    
    # 5. 打印 CPU 信息(所有平台都打印)
    _print_cpu_info()
    
    # 6. CPU 线程配置信息打印(仅针对 CPU 平台)
    if "cpu" in platform.lower():
        _print_cpu_thread_info()  # 打印调试信息
    
    # 7. 打印配置摘要
    print(
        f"⚙️  运行时配置已加载 [model={model_name}, platform={platform}]: "
        f"max_token_length={max_token_length}, chunk_size={chunk_size}"
    )
    
    return platform, max_token_length, chunk_size