"""
Semantic analysis:基于 instruct 模型提取原文 token 与 query 的相关度

使用 logits_gradient 梯度归因策略(与预测更一致),子策略由 --logits_gradient_submode 指定:
- count:top-10 logits 梯度(排除 0),prompt 引导「数量」。0.6b下只适合用于判断文章整体是否有关联,1.7b下全能
- match_score:目标 token logit 梯度,prompt 引导「相关度打分」。0.6b/1.7b下都不太有竞争力。【已废弃】
- fill_blank:填空式,top-10 logits 梯度(排除 无),prompt 引导「最相关的一个词」。0.6b下只适合用于给token打分,1.7b下全能

count/fill_blank 按概率加权(Σ pᵢ·zᵢ)。

模型由 --semantic_model 参数指定,默认 qwen3-0.6b-instruct
"""

import gc
import math
from typing import Callable, Dict, List, Optional

import torch

from .api.utils import round_to_sig_figs
from .device import DeviceManager
from .model_manager import ensure_semantic_slot_ready, get_semantic_model_display_name
from .next_token_topk import decode_topk_ids_to_strings_and_rounded_probs, DEFAULT_NEXT_TOKEN_TOPK
from .runtime_config import get_semantic_max_token_length


def _get_logits_gradient_submode() -> str:
    """logits_gradient 子策略:count / match_score(已废弃) / fill_blank"""
    try:
        from backend.app_context import get_args
        return getattr(get_args(), "logits_gradient_submode", "fill_blank")
    except RuntimeError:
        return "fill_blank"


def _truncate_text_by_tokens(tokenizer, text: str, max_tokens: int) -> str:
    """将 text 截断至最多 max_tokens 个 token;超长时打印提示。"""
    text_ids = tokenizer.encode(text, add_special_tokens=False)
    if len(text_ids) > max_tokens:
        print(f"⚠️  原文过长,已截断至前 {max_tokens} token")
        return tokenizer.decode(text_ids[:max_tokens])
    return text


def _get_gradient_checkpointing() -> bool:
    """默认 True(run.py);``--no-gradient-checkpointing`` 关闭。"""
    try:
        from backend.app_context import get_args
        return getattr(get_args(), "gradient_checkpointing", True)
    except RuntimeError:
        return True


def _get_verbose() -> bool:
    """是否输出详细调试信息(由 --verbose 控制)"""
    from backend.app_context import get_verbose
    return get_verbose()


def _analyze_logits_gradient(
    query: str,
    text: str,
    tokenizer,
    model,
    device,
    submode_override: Optional[str] = None,
    progress_callback: Optional[Callable[[int, int, str, Optional[int]], None]] = None,
    debug_info: bool = False,
    full_match_degree_only: bool = False,
) -> Dict:
    """
    梯度归因:logits 对输入 embedding 的梯度。
    子策略:count / match_score(已废弃) / fill_blank,由 --logits_gradient_submode 指定。
    submode_override: 评估时可选覆盖,用于同一进程内测试不同子模式。
    """
    TOTAL_STEPS = 4

    submode = submode_override if submode_override is not None else _get_logits_gradient_submode()
    max_length = get_semantic_max_token_length()

    if progress_callback:
        progress_callback(1, TOTAL_STEPS, "encoding", None)
    # Choose the instruction according to the submode. Prompts are kept in Chinese
    # to match the default qwen3 instruct models; English glosses in the comments below.
    # The document is prefixed with \n\n so the tokenizer does not merge its first
    # character with a preceding space, which would break the offset_mapping math.
    if submode == "count":
        # "How many words in the text below are related to the query topic ({query})? Text:"
        instruction = f"请问下面文字中有多少个词与查询主题({query})相关?文字内容:\n\n"
    elif submode == "match_score":  # deprecated
        # "How related is the text below to the query topic ({query})? Answer 0/1/2 (2 = most related). Text:"
        instruction = f"请问下面文字与查询主题({query})的相关程度是多少?请回答0/1/2(2为最高相关)。文字内容:\n\n"
    elif submode == "fill_blank":
        # "Which word in the text below is most related to the query topic ({query})? Answer '无' if none. Text:"
        instruction = f"请问下面文字中哪个词与查询主题({query})最相关?如无相关词则回答“无”。文字内容:\n\n"
    else:
        raise ValueError(f"Unknown submode: {submode}")

    # Truncate text to max_length tokens before assembling the prompt
    truncated_text = _truncate_text_by_tokens(tokenizer, text, max_length)
    
    messages = [{"role": "user", "content": instruction + truncated_text}]
    formatted = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True,
        enable_thinking=False
    )
    # Generation guide: the chat template only accepts complete messages, so the
    # guide text is appended to `formatted` directly.
    if submode == "count":
        # "The number of words in the source text related to the query topic ({query}) = **"
        generation_guide = f"原文中与查询主题({query})相关的词的数量 = **"
    elif submode == "match_score":  # deprecated
        # "The relevance score (0-2) between the article and the query topic ({query}) is: **"
        generation_guide = f"文章和查询主题({query})的相关程度(0-2)打分为:**"
    elif submode == "fill_blank":
        # "The single word in the source text most related to the query topic ({query}) is: **“"
        # The opening quote “ is deliberate: it stops the model from emitting the quote itself.
        generation_guide = f"原文中与查询主题({query})最相关的一个词是:**“"
    else:
        raise ValueError(f"Unknown submode: {submode}")
    formatted += generation_guide

    # Top-k for the logits_gradient count/fill_blank targets; controls how many
    # candidate tokens the gradient objective covers
    LOGITS_GRADIENT_TOPK = DEFAULT_NEXT_TOKEN_TOPK

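    # Locate the document span inside the formatted prompt by character offsets;
    # abbrev is a shortened rendering (first/last line only) used for verbose
    # logging and debug output.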
    idx = formatted.find(instruction)
    instruction_start_char = idx if idx >= 0 else 0
    text_start_char = instruction_start_char + len(instruction)
    text_end_char = text_start_char + len(truncated_text)
    lines = truncated_text.splitlines()
    abbrev_text = truncated_text if len(lines) <= 2 else f"{lines[0]}\n...\n{lines[-1]}"
    abbrev = formatted[:text_start_char] + abbrev_text + formatted[text_end_char:]

    enc = tokenizer(
        formatted,
        return_tensors="pt",
        return_offsets_mapping=True,
    )

    input_ids = enc["input_ids"].to(device)
    offset_mapping = enc["offset_mapping"][0].tolist()

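    # prompt_end = index of the first token whose character span starts inside the
    # document; everything before it is chat-template/instruction prefix.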
    prompt_end = len(offset_mapping)
    for i, (s, _) in enumerate(offset_mapping):
        if s >= text_start_char:
            prompt_end = i
            break

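    # Integer input_ids cannot receive gradients, so attribution is taken w.r.t.
    # the embedding vectors: detach them from the embedding table and re-enable grad.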
    embed_layer = model.get_input_embeddings()
    embeds = embed_layer(input_ids).detach().clone().requires_grad_(True)

    use_gc = _get_gradient_checkpointing()
    if _get_verbose():
        print(f"📌 logits_gradient: 推理原文 (tokens={len(offset_mapping)}):\n{abbrev}")
    if progress_callback:
        progress_callback(2, TOTAL_STEPS, "inference", None)
    model.eval()
    if use_gc:
        model.gradient_checkpointing_enable()
    try:
        with torch.set_grad_enabled(not full_match_degree_only):
            outputs = model(
                inputs_embeds=embeds,
                output_attentions=False,
            )
        # Explicit sync to ensure the forward pass has finished, so the progress_callback timing is accurate
        if device.type == "cuda":
            torch.cuda.synchronize(device)
        elif device.type == "mps":
            torch.mps.synchronize()

        logits = outputs.logits[:, -1, :]
        topk_vals, topk_ids = torch.topk(logits, LOGITS_GRADIENT_TOPK, dim=-1)
        probs = torch.softmax(logits, dim=-1)
        topk_tokens, topk_probs = decode_topk_ids_to_strings_and_rounded_probs(
            probs[0], tokenizer, topk_ids[0]
        )
        if _get_verbose():
            print(f"top{LOGITS_GRADIENT_TOPK}: {[f'{t}({p*100:.1f}%)' for t, p in zip(topk_tokens, topk_probs)]}")

        neg_token = "无" if submode == "fill_blank" else "0"
        neg_id = tokenizer.encode(neg_token, add_special_tokens=False)[0]
        # Full-text match degree: count/match_score (deprecated) use 1 - P("0"); fill_blank uses 1 - P("无")
        p_neg = probs[0, neg_id].item()
        full_match_degree = round(1.0 - p_neg, 4)
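        # e.g. if P(neg_token) = 0.07 then full_match_degree = 0.93: the more mass
        # the model puts on "no related word" / "0 related words", the lower the match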

        if full_match_degree_only:
            return {
                "model": get_semantic_model_display_name(),
                "token_attention": [],
                "full_match_degree": full_match_degree,
            }

        if progress_callback:
            progress_callback(3, TOTAL_STEPS, "backward", None)
        # Attribution target: raw logits (no backward pass through softmax), avoiding
        # saturation and cross-token competition effects.
        if submode == "count" or submode == "fill_blank":
            # count/fill_blank both use the top-10 logits weighted by probability
            # (Σ pᵢ·zᵢ), excluding neg_token ("0"/"无") so the gradient direction
            # stays aligned with "relevant".
            vals = topk_vals[0]
            w = probs[0, topk_ids[0]].detach().clone()
            # zero out the weight of neg_token
            w[topk_ids[0] == neg_id] = 0

            target_logit = (w * vals).sum()
        elif submode == "match_score":  # deprecated
            target_ids = tokenizer.encode("2", add_special_tokens=False)
            if not target_ids:
                raise ValueError("tokenizer cannot encode '2'")
            target_logit = logits[0, target_ids[0]]
        else:
            raise ValueError(f"Unknown submode: {submode}")
        target_logit.backward()
        grad = embeds.grad
        if grad is None:
            raise RuntimeError("logits_gradient: no gradient was back-propagated; the model may not support it (e.g. int8 quantization)")

        # Explicit sync to ensure the backward pass has finished, so the progress_callback timing is accurate
        if device.type == "cuda":
            torch.cuda.synchronize(device)
        elif device.type == "mps":
            torch.mps.synchronize()
        if progress_callback:
            progress_callback(4, TOTAL_STEPS, "processing", None)
            
        text_token_end = len(offset_mapping)
        # Compute ‖∇f‖ for all tokens on the GPU in one pass; calling .item() inside
        # the loop would trigger ~500 GPU→CPU synchronizations
        grad_slice = grad[0, prompt_end:text_token_end].float()
        norms = grad_slice.norm(dim=-1).cpu().tolist()
        token_attention: List[Dict] = []
        nan_count = 0
        for i in range(prompt_end, text_token_end):
            s, e = offset_mapping[i]
            if s >= text_start_char and e <= text_end_char:
                s_rel, e_rel = s - text_start_char, e - text_start_char
                score = norms[i - prompt_end]
                if not math.isfinite(score):
                    score = 0.0
                    nan_count += 1
                else:
                    score = round_to_sig_figs(score)
                token_attention.append({"offset": [s_rel, e_rel], "raw": truncated_text[s_rel:e_rel], "score": score})
        if nan_count > 0:
            print(f"⚠️ token_attention 中有 {nan_count} 个 score 为 NaN/Inf,已替换为 0。")

        out = {
            "model": get_semantic_model_display_name(),
            "token_attention": token_attention,
            "full_match_degree": full_match_degree,
        }
        if debug_info:
            out["debug_info"] = {"abbrev": abbrev, "topk_tokens": topk_tokens, "topk_probs": topk_probs}
        return out
    finally:
        if use_gc:
            model.gradient_checkpointing_disable()
        # Clean up after every inference: otherwise MPS/CUDA memory accumulates across consecutive calls and can hang the process
        DeviceManager.clear_cache(device)


def analyze_semantic(
    query: str,
    text: str,
    submode_override: Optional[str] = None,
    progress_callback: Optional[Callable[[int, int, str, Optional[int]], None]] = None,
    debug_info: bool = False,
    full_match_degree_only: bool = False,
) -> Dict:
    """
    分析原文各 token 与 query 的相关度(使用 logits_gradient 梯度归因)。

    Args:
        query: 查询主题
        text: 原文
        submode_override: 评估时可选覆盖子模式(count/match_score已废弃/fill_blank)
        progress_callback: 可选进度回调 (step, total_steps, stage, percentage)
        debug_info: 为 True 时返回 debug_abbrev(推理原文缩写);topk_tokens、topk_probs 始终在结果中

    Returns:
        {"model", "token_attention", "full_match_degree"};debug_info=True 时包含 debug_info 对象
    """
    tokenizer, model, device = ensure_semantic_slot_ready()
    return _analyze_logits_gradient(
        query, text, tokenizer, model, device,
        submode_override=submode_override,
        progress_callback=progress_callback,
        debug_info=debug_info,
        full_match_degree_only=full_match_degree_only,
    )
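

# Example usage (a minimal sketch; the query/text values are hypothetical and the
# semantic model slot must already be initialized, e.g. via --semantic_model):
#
#     result = analyze_semantic("机器学习", "深度学习是机器学习的一个分支。")
#     print(result["full_match_degree"])     # overall relevance in [0, 1]
#     for tok in result["token_attention"]:
#         print(tok["raw"], tok["score"])    # per-token gradient-norm score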