infgrad committed on
Commit 5dc8079 · verified · 1 Parent(s): ed248a1

Add CrossEncoder integration
1_LogitScore/config.json ADDED
@@ -0,0 +1,4 @@
+ {
+   "true_token_id": 9405,
+   "false_token_id": 2083
+ }
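These two ids select the `yes`/`no` logits that the LogitScore head compares. As a hedged, dependency-free sketch (the flat logit list and the vocabulary size are stand-ins; the real head reads the logits at the final prompt position of the causal LM):

```python
import math

TRUE_TOKEN_ID = 9405    # "yes" token id, from this config
FALSE_TOKEN_ID = 2083   # "no" token id, from this config

def logit_score(last_token_logits: list[float]) -> float:
    """Relevance score s(q, d) = sigmoid(logit_yes - logit_no)."""
    diff = last_token_logits[TRUE_TOKEN_ID] - last_token_logits[FALSE_TOKEN_ID]
    return 1.0 / (1.0 + math.exp(-diff))

# Toy check: equal logits give a perfectly uncertain 0.5;
# a dominant "yes" logit pushes the score toward 1.
logits = [0.0] * 151_936            # Qwen-sized vocabulary (assumption)
print(logit_score(logits))          # 0.5
logits[TRUE_TOKEN_ID] = 5.0
logits[FALSE_TOKEN_ID] = -5.0
print(round(logit_score(logits), 5))  # 0.99995
```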
README.md CHANGED
@@ -58,6 +58,22 @@ If the document is not relevant, the model outputs `no` and stops. No contributi
 
 ## Quickstart
 
+ Two ways to call the model. Both produce the **same** relevance score `s(q, d) = σ(ℓ_yes − ℓ_no)`. Use **A** when you also want `<contribution>` / `<evidence>`. Use **B** when you only need a score and want a drop-in replacement for any other CrossEncoder reranker.
+
+ We use one shared example throughout so you can compare the outputs side by side:
+
+ ```python
+ QUERY = "What is the boiling point of water at sea level?"
+ DOCUMENTS = [
+     "Water boils at 100 C (212 F) at standard atmospheric pressure (1 atm), "
+     "which corresponds to sea-level conditions.",
+     "Mount Everest is the highest mountain on Earth, with a peak elevation "
+     "of 8,848 meters above sea level.",
+ ]
+ ```
+
+ ### A. Transformers (full output: score + contribution + evidence)
+
 ```python
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -80,7 +96,7 @@ INSTRUCTION = (
 "- Concise: drop query-irrelevant background.\n"
 "- Verbatim (no translation): proper nouns, terms, abbreviations, "
 "numbers, dates, code, URLs.\n"
- "- Output language: multilingual doc -> query's language; else doc's language."
+ "- Output language: multilingual doc → query's language; else doc's language."
 "</evidence>"
 )
 
@@ -138,27 +154,62 @@ def rerank(query: str, doc: str, max_new_tokens: int = 512):
 return {"score": score, "text": text}
 
 
- example = rerank(
-     query="What is the boiling point of water at sea level?",
-     doc=(
-         "Water boils at 100 C (212 F) at standard atmospheric pressure (1 atm), "
-         "which corresponds to sea-level conditions."
-     ),
- )
- print(example)
+ for doc in DOCUMENTS:
+     print(rerank(QUERY, doc))
 ```
 
- Expected shape of the output:
+ Expected output (one dict per document):
 
 ```text
- {
-     "score": 0.98,
-     "text": "yes\n<contribution>...</contribution>\n<evidence>...</evidence>"
- }
+ {"score": 0.98, "text": "yes\n<contribution>...</contribution>\n<evidence>...</evidence>"}
+ {"score": 0.01, "text": "no"}
 ```
 
 For irrelevant pairs the score is close to 0 and `text` is just `"no"`.
 
+ ### B. Sentence Transformers CrossEncoder (score only)
+
+ If you only need the score and want a drop-in CrossEncoder, the same model works directly with `sentence-transformers >= 5.4.0`. **Note:** in this mode `<contribution>` and `<evidence>` are not produced — only the calibrated relevance score.
+
+ The system prompt and instruction are baked into the model's `chat_template.jinja` and are **not configurable** — the model was trained with one fixed prompt, and only that prompt produces calibrated scores. You pass only `(query, document)`; the rest is hardcoded.
+
+ ```python
+ import torch
+ from sentence_transformers import CrossEncoder
+
+ MODEL_PATH = "infgrad/Prism-Qwen3.5-Reranker-4B"  # or any sibling repo above
+
+ ce = CrossEncoder(MODEL_PATH, model_kwargs={"torch_dtype": torch.bfloat16})
+
+ # 1) Score (q, d) pairs. The default activation is Sigmoid, so scores are in (0, 1)
+ #    and equal to s(q, d) = sigmoid(logit_yes - logit_no) — identical to path A above.
+ pairs = [(QUERY, doc) for doc in DOCUMENTS]
+ scores = ce.predict(pairs)
+ print(scores)
+ # array([0.98, 0.01], dtype=float32)
+
+ # 2) Rank documents directly.
+ ranked = ce.rank(QUERY, DOCUMENTS, return_documents=True)
+ for r in ranked:
+     print(f"{r['score']:.3f}\t{r['corpus_id']}\t{r['text'][:80]}")
+ ```
+
+ To get raw logit differences instead of (0, 1) probabilities, pass `activation_fn=torch.nn.Identity()` to `ce.predict(...)`.
+
+ #### A note on numerical parity with path A
+
+ In **fp32**, paths A and B produce the same score to within ~1e-6 (verified across all five checkpoints).
+
+ In **bf16** with the default batched call (`batch_size > 1`), CE scores can drift from path A by **~1–3%** for individual pairs. The cause is bf16 SDPA: when CrossEncoder pads shorter sequences to the longest in the batch, the bf16 attention numerics differ by a few ULPs versus running each pair alone, and the difference accumulates across layers before the final sigmoid. **Ranking order is unaffected.** If you need bit-for-bit parity with path A:
+
+ ```python
+ # Option 1: keep bf16, disable batching
+ ce.predict(pairs, batch_size=1)
+
+ # Option 2: use fp32 (slower, larger memory)
+ ce = CrossEncoder(MODEL_PATH, model_kwargs={"torch_dtype": torch.float32})
+ ```
+
 
 ## Notes on usage
 
@@ -170,4 +221,4 @@ For irrelevant pairs the score is close to 0 and `text` is just `"no"`.
 
 ## Contact
 
- Dun Zhang — `dunnzhang0@gmail.com` (independent researcher).
+ Dun Zhang — `dunnzhang0@gmail.com` (independent researcher).
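Since the only difference between path B's default Sigmoid activation and `Identity` is the final logistic squash, calibrated scores and raw logit differences can be converted offline. A small sketch (the raw values below are made-up stand-ins for what `ce.predict(..., activation_fn=torch.nn.Identity())` would return):

```python
import math

def sigmoid(x: float) -> float:
    """Logistic function: maps a raw logit difference to a (0, 1) score."""
    return 1.0 / (1.0 + math.exp(-x))

def logit(p: float) -> float:
    """Inverse of sigmoid: recover the raw yes/no logit difference."""
    return math.log(p / (1.0 - p))

# Hypothetical raw logit differences for the two example documents.
raw = [3.89, -4.6]
probs = [sigmoid(x) for x in raw]
print([round(p, 2) for p in probs])   # [0.98, 0.01]

# Round-trip back to logit space is lossless up to float precision.
recovered = [logit(p) for p in probs]
assert all(abs(a - b) < 1e-9 for a, b in zip(raw, recovered))
```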
chat_template.jinja CHANGED
@@ -1,154 +1,22 @@
- {%- set image_count = namespace(value=0) %}
- {%- set video_count = namespace(value=0) %}
- {%- macro render_content(content, do_vision_count, is_system_content=false) %}
- {%- if content is string %}
- {{- content }}
- {%- elif content is iterable and content is not mapping %}
- {%- for item in content %}
- {%- if 'image' in item or 'image_url' in item or item.type == 'image' %}
- {%- if is_system_content %}
- {{- raise_exception('System message cannot contain images.') }}
- {%- endif %}
- {%- if do_vision_count %}
- {%- set image_count.value = image_count.value + 1 %}
- {%- endif %}
- {%- if add_vision_id %}
- {{- 'Picture ' ~ image_count.value ~ ': ' }}
- {%- endif %}
- {{- '<|vision_start|><|image_pad|><|vision_end|>' }}
- {%- elif 'video' in item or item.type == 'video' %}
- {%- if is_system_content %}
- {{- raise_exception('System message cannot contain videos.') }}
- {%- endif %}
- {%- if do_vision_count %}
- {%- set video_count.value = video_count.value + 1 %}
- {%- endif %}
- {%- if add_vision_id %}
- {{- 'Video ' ~ video_count.value ~ ': ' }}
- {%- endif %}
- {{- '<|vision_start|><|video_pad|><|vision_end|>' }}
- {%- elif 'text' in item %}
- {{- item.text }}
- {%- else %}
- {{- raise_exception('Unexpected item type in content.') }}
- {%- endif %}
- {%- endfor %}
- {%- elif content is none or content is undefined %}
- {{- '' }}
- {%- else %}
- {{- raise_exception('Unexpected content type.') }}
- {%- endif %}
- {%- endmacro %}
- {%- if not messages %}
- {{- raise_exception('No messages provided.') }}
- {%- endif %}
- {%- if tools and tools is iterable and tools is not mapping %}
- {{- '<|im_start|>system\n' }}
- {{- "# Tools\n\nYou have access to the following functions:\n\n<tools>" }}
- {%- for tool in tools %}
- {{- "\n" }}
- {{- tool | tojson }}
- {%- endfor %}
- {{- "\n</tools>" }}
- {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }}
- {%- if messages[0].role == 'system' %}
- {%- set content = render_content(messages[0].content, false, true)|trim %}
- {%- if content %}
- {{- '\n\n' + content }}
- {%- endif %}
- {%- endif %}
- {{- '<|im_end|>\n' }}
- {%- else %}
- {%- if messages[0].role == 'system' %}
- {%- set content = render_content(messages[0].content, false, true)|trim %}
- {{- '<|im_start|>system\n' + content + '<|im_end|>\n' }}
- {%- endif %}
- {%- endif %}
- {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
- {%- for message in messages[::-1] %}
- {%- set index = (messages|length - 1) - loop.index0 %}
- {%- if ns.multi_step_tool and message.role == "user" %}
- {%- set content = render_content(message.content, false)|trim %}
- {%- if not(content.startswith('<tool_response>') and content.endswith('</tool_response>')) %}
- {%- set ns.multi_step_tool = false %}
- {%- set ns.last_query_index = index %}
- {%- endif %}
- {%- endif %}
- {%- endfor %}
- {%- if ns.multi_step_tool %}
- {{- raise_exception('No user query found in messages.') }}
- {%- endif %}
- {%- for message in messages %}
- {%- set content = render_content(message.content, true)|trim %}
- {%- if message.role == "system" %}
- {%- if not loop.first %}
- {{- raise_exception('System message must be at the beginning.') }}
- {%- endif %}
- {%- elif message.role == "user" %}
- {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
- {%- elif message.role == "assistant" %}
- {%- set reasoning_content = '' %}
- {%- if message.reasoning_content is string %}
- {%- set reasoning_content = message.reasoning_content %}
- {%- else %}
- {%- if '</think>' in content %}
- {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
- {%- set content = content.split('</think>')[-1].lstrip('\n') %}
- {%- endif %}
- {%- endif %}
- {%- set reasoning_content = reasoning_content|trim %}
- {%- if loop.index0 > ns.last_query_index %}
- {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content + '\n</think>\n\n' + content }}
- {%- else %}
- {{- '<|im_start|>' + message.role + '\n' + content }}
- {%- endif %}
- {%- if message.tool_calls and message.tool_calls is iterable and message.tool_calls is not mapping %}
- {%- for tool_call in message.tool_calls %}
- {%- if tool_call.function is defined %}
- {%- set tool_call = tool_call.function %}
- {%- endif %}
- {%- if loop.first %}
- {%- if content|trim %}
- {{- '\n\n<tool_call>\n<function=' + tool_call.name + '>\n' }}
- {%- else %}
- {{- '<tool_call>\n<function=' + tool_call.name + '>\n' }}
- {%- endif %}
- {%- else %}
- {{- '\n<tool_call>\n<function=' + tool_call.name + '>\n' }}
- {%- endif %}
- {%- if tool_call.arguments is defined %}
- {%- for args_name, args_value in tool_call.arguments|items %}
- {{- '<parameter=' + args_name + '>\n' }}
- {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}
- {{- args_value }}
- {{- '\n</parameter>\n' }}
- {%- endfor %}
- {%- endif %}
- {{- '</function>\n</tool_call>' }}
- {%- endfor %}
- {%- endif %}
- {{- '<|im_end|>\n' }}
- {%- elif message.role == "tool" %}
- {%- if loop.previtem and loop.previtem.role != "tool" %}
- {{- '<|im_start|>user' }}
- {%- endif %}
- {{- '\n<tool_response>\n' }}
- {{- content }}
- {{- '\n</tool_response>' }}
- {%- if not loop.last and loop.nextitem.role != "tool" %}
- {{- '<|im_end|>\n' }}
- {%- elif loop.last %}
- {{- '<|im_end|>\n' }}
- {%- endif %}
- {%- else %}
- {{- raise_exception('Unexpected message role.') }}
- {%- endif %}
- {%- endfor %}
- {%- if add_generation_prompt %}
- {{- '<|im_start|>assistant\n' }}
- {%- if enable_thinking is defined and enable_thinking is false %}
- {{- '<think>\n\n</think>\n\n' }}
- {%- else %}
- {{- '<think>\n' }}
- {%- endif %}
- {%- endif %}
+ {%- set query_text = messages | selectattr("role", "eq", "query") | map(attribute="content") | first -%}
+ {%- set document_text = messages | selectattr("role", "eq", "document") | map(attribute="content") | first -%}
+ <|im_start|>system
+ Judge whether the Document meets the requirements based on the Query and the Instruct provided. <|im_end|>
+ <|im_start|>user
+ <Instruct>: Judge if the document is relevant to the query. Reply "yes" or "no".
+ On "yes", also emit:
+ <contribution>One sentence covering every core point the document contributes to the query, without elaboration.</contribution>
+ <evidence>Self-contained rewrite of the query-relevant content. Rules:
+ - Faithful: rephrase only; add or infer nothing.
+ - Self-contained: evidence alone must fully answer the query.
+ - Concise: drop query-irrelevant background.
+ - Verbatim (no translation): proper nouns, terms, abbreviations, numbers, dates, code, URLs.
+ - Output language: multilingual doc → query's language; else doc's language.</evidence>
+ <Query>: {{ query_text }}
+ <Document>: {{ document_text }}<|im_end|>
+ <|im_start|>assistant
+ <think>
+
+ </think>
+
+
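The new template drops the general chat flow entirely: it pulls the first `query`-role and first `document`-role message out of `messages` and splices them into one fixed prompt. A hedged plain-Python mirror of that selection logic (the elided `<Instruct>` body and the exact whitespace of the real template are not reproduced; this only illustrates the role handling, not what `tokenizer.apply_chat_template` emits byte-for-byte):

```python
SYSTEM = ("Judge whether the Document meets the requirements "
          "based on the Query and the Instruct provided.")

def render_prompt(messages: list[dict]) -> str:
    """Mirror the template's selectattr-by-role logic: first 'query' and
    first 'document' message are spliced in; everything else is fixed."""
    query = next(m["content"] for m in messages if m["role"] == "query")
    document = next(m["content"] for m in messages if m["role"] == "document")
    return (
        f"<|im_start|>system\n{SYSTEM}<|im_end|>\n"
        f"<|im_start|>user\n<Query>: {query}\n<Document>: {document}<|im_end|>\n"
        f"<|im_start|>assistant\n<think>\n\n</think>\n\n"
    )

msgs = [
    {"role": "query", "content": "What is the boiling point of water at sea level?"},
    {"role": "document", "content": "Water boils at 100 C at 1 atm."},
]
print(render_prompt(msgs))
```

Note the assistant turn ends with an empty, pre-closed `<think>` block, so the model's very next token is the `yes`/`no` decision the LogitScore head reads.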
config_sentence_transformers.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "__version__": {
+     "sentence_transformers": "5.4.1"
+   },
+   "activation_fn": "torch.nn.modules.activation.Sigmoid",
+   "model_type": "CrossEncoder"
+ }
modules.json ADDED
@@ -0,0 +1,14 @@
+ [
+   {
+     "idx": 0,
+     "name": "0",
+     "path": "",
+     "type": "sentence_transformers.base.modules.transformer.Transformer"
+   },
+   {
+     "idx": 1,
+     "name": "1",
+     "path": "1_LogitScore",
+     "type": "sentence_transformers.cross_encoder.modules.logit_score.LogitScore"
+   }
+ ]
sentence_bert_config.json ADDED
@@ -0,0 +1,15 @@
+ {
+   "transformer_task": "text-generation",
+   "modality_config": {
+     "text": {
+       "method": "forward",
+       "method_output_name": "logits"
+     },
+     "message": {
+       "method": "forward",
+       "method_output_name": "logits",
+       "format": "flat"
+     }
+   },
+   "module_output_name": "causal_logits"
+ }
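Taken together, the JSON files in this commit wire the two-module CrossEncoder pipeline: module 0 runs the causal LM and exposes its logits, and module 1 (LogitScore, in `1_LogitScore/`) reduces them to a score via the yes/no token ids. A quick offline consistency check over these configs (values copied from the files above; no model download needed):

```python
import json

# Contents of modules.json and 1_LogitScore/config.json from this commit.
modules = json.loads("""[
  {"idx": 0, "name": "0", "path": "",
   "type": "sentence_transformers.base.modules.transformer.Transformer"},
  {"idx": 1, "name": "1", "path": "1_LogitScore",
   "type": "sentence_transformers.cross_encoder.modules.logit_score.LogitScore"}
]""")
logit_cfg = {"true_token_id": 9405, "false_token_id": 2083}

# The scoring head must come second, point at the folder holding the
# token-id config, and compare two distinct vocabulary entries.
assert modules[0]["type"].endswith("Transformer")
assert modules[1]["type"].endswith("LogitScore")
assert modules[1]["path"] == "1_LogitScore"
assert logit_cfg["true_token_id"] != logit_cfg["false_token_id"]
print("configs consistent")
```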