shanaym committed on
Commit
c46e66e
·
verified ·
1 Parent(s): dcc820d

Upload modeling_zeranker.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_zeranker.py +214 -0
modeling_zeranker.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import CrossEncoder as _CE
2
+
3
+ import math
4
+ from typing import cast, Any
5
+ import types
6
+
7
+ import torch
8
+ from transformers.configuration_utils import PretrainedConfig
9
+ from transformers.models.auto.configuration_auto import AutoConfig
10
+ from transformers.models.auto.modeling_auto import AutoModelForCausalLM
11
+ from transformers.models.auto.tokenization_auto import AutoTokenizer
12
+ from transformers.models.gemma3.modeling_gemma3 import (
13
+ Gemma3ForCausalLM,
14
+ Gemma3ForConditionalGeneration,
15
+ )
16
+ from transformers.models.llama.modeling_llama import LlamaForCausalLM
17
+ from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM
18
+ from transformers.tokenization_utils_base import BatchEncoding
19
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
20
+
21
+ # pyright: reportUnknownMemberType=false
22
+ # pyright: reportUnknownVariableType=false
23
+
24
# Hugging Face repository id that load_model() reads the config, weights and
# tokenizer from.
MODEL_PATH = "zeroentropy/zerank-1-small"

# Budget used by predict() when grouping pairs into batches. predict() compares
# it against character counts of the query/document strings.
PER_DEVICE_BATCH_SIZE_TOKENS = 15_000

# Default device: CUDA when available, otherwise CPU. Rebound by to_device().
global_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
+
30
+
31
def format_pointwise_datapoints(
    tokenizer: PreTrainedTokenizerFast,
    query_documents: list[tuple[str, str]],
) -> BatchEncoding:
    """Render (query, document) pairs with the chat template and tokenize them.

    Each pair becomes a two-message conversation: the whitespace-stripped
    query is the ``system`` message and the whitespace-stripped document is
    the ``user`` message. The conversation is rendered to text with the
    generation prompt appended, and all rendered texts are tokenized together
    with padding into PyTorch tensors.

    Args:
        tokenizer: Tokenizer providing ``apply_chat_template`` and ``__call__``.
        query_documents: Pairs of ``(query, document)`` strings.

    Returns:
        The padded ``BatchEncoding`` for every pair, in input order.
    """
    rendered_prompts: list[str] = []
    for pair_query, pair_document in query_documents:
        rendered = tokenizer.apply_chat_template(
            [
                {"role": "system", "content": f"\n{pair_query}\n".strip()},
                {"role": "user", "content": f"\n{pair_document}\n".strip()},
            ],
            tokenize=False,
            add_generation_prompt=True,
        )
        # With tokenize=False the template is expected to return plain text.
        assert isinstance(rendered, str)
        rendered_prompts.append(rendered)

    return tokenizer(rendered_prompts, padding=True, return_tensors="pt")
61
+
62
+
63
def load_model(
    device: torch.device | None = None,
) -> tuple[
    PreTrainedTokenizerFast,
    LlamaForCausalLM
    | Gemma3ForConditionalGeneration
    | Gemma3ForCausalLM
    | Qwen3ForCausalLM,
]:
    """Load the reranker's causal LM and tokenizer from ``MODEL_PATH``.

    Args:
        device: Device to place the whole model on. Falls back to the
            module-level ``global_device`` when ``None``.

    Returns:
        A ``(tokenizer, model)`` tuple. The tokenizer pads on the right and
        is guaranteed to have a pad token (the EOS token is reused when the
        tokenizer defines none).
    """
    target_device = global_device if device is None else device

    model_config = AutoConfig.from_pretrained(MODEL_PATH)
    assert isinstance(model_config, PretrainedConfig)

    causal_lm = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype="auto",
        quantization_config=None,
        # The empty-string key maps the entire model onto one device.
        device_map={"": target_device},
    )
    if model_config.model_type == "llama":
        # NOTE(review): this sets the attribute after the weights are loaded;
        # confirm it actually switches the attention kernel.
        causal_lm.config.attn_implementation = "flash_attention_2"
    assert isinstance(
        causal_lm,
        (
            LlamaForCausalLM,
            Gemma3ForConditionalGeneration,
            Gemma3ForCausalLM,
            Qwen3ForCausalLM,
        ),
    )

    loaded_tokenizer = AutoTokenizer.from_pretrained(
        MODEL_PATH,
        padding_side="right",
    )
    tokenizer = cast(AutoTokenizer, loaded_tokenizer)
    assert isinstance(tokenizer, PreTrainedTokenizerFast)

    # Batched tokenization with padding needs a pad token; reuse EOS if absent.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return tokenizer, causal_lm
107
+
108
+
109
+ def predict(
110
+ self,
111
+ query_documents: list[tuple[str, str]] | None = None,
112
+ *,
113
+ sentences: Any = None,
114
+ batch_size: Any = None,
115
+ show_progress_bar: Any = None,
116
+ activation_fn: Any = None,
117
+ apply_softmax: Any = None,
118
+ convert_to_numpy: Any = None,
119
+ convert_to_tensor: Any = None,
120
+ ) -> list[float]:
121
+ if query_documents is None:
122
+ if sentences is None:
123
+ raise ValueError("query_documents or sentences must be provided")
124
+ query_documents = [[sentence[0], sentence[1]] for sentence in sentences]
125
+
126
+ if not hasattr(self, "inner_model"):
127
+ self.inner_tokenizer, self.inner_model = load_model(global_device)
128
+ self.inner_model.gradient_checkpointing_enable()
129
+ self.inner_model.eval()
130
+ self.inner_yes_token_id = self.inner_tokenizer.encode(
131
+ "Yes", add_special_tokens=False
132
+ )[0]
133
+
134
+ model = self.inner_model
135
+ tokenizer = self.inner_tokenizer
136
+
137
+ query_documents = [
138
+ (query[:2_000], document[:10_000]) for query, document in query_documents
139
+ ]
140
+ # Sort
141
+ permutation = list(range(len(query_documents)))
142
+ permutation.sort(
143
+ key=lambda i: -len(query_documents[i][0]) - len(query_documents[i][1])
144
+ )
145
+ query_documents = [query_documents[i] for i in permutation]
146
+
147
+ # Extract document batches from this line of datapoints
148
+ max_length = 0
149
+ batches: list[list[tuple[str, str]]] = []
150
+ for query, document in query_documents:
151
+ if (
152
+ len(batches) == 0
153
+ or (len(batches[-1]) + 1) * max(max_length, len(query) + len(document))
154
+ > PER_DEVICE_BATCH_SIZE_TOKENS
155
+ ):
156
+ batches.append([])
157
+ max_length = 0
158
+
159
+ batches[-1].append((query, document))
160
+ max_length = max(max_length, 20 + len(query) + len(document))
161
+
162
+ # Inference all of the document batches
163
+ all_logits: list[float] = []
164
+ for batch in batches:
165
+ batch_inputs = format_pointwise_datapoints(
166
+ tokenizer,
167
+ batch,
168
+ )
169
+
170
+ batch_inputs = batch_inputs.to(global_device)
171
+
172
+ try:
173
+ outputs = model(**batch_inputs, use_cache=False)
174
+ except torch.OutOfMemoryError:
175
+ print(f"GPU OOM! {torch.cuda.memory_reserved()}")
176
+ torch.cuda.empty_cache()
177
+ print(f"GPU After OOM Cache Clear: {torch.cuda.memory_reserved()}")
178
+ outputs = model(**batch_inputs, use_cache=False)
179
+
180
+ # Extract the logits
181
+ logits = cast(torch.Tensor, outputs.logits)
182
+ attention_mask = cast(torch.Tensor, batch_inputs.attention_mask)
183
+ last_positions = attention_mask.sum(dim=1) - 1
184
+
185
+ batch_size = logits.shape[0]
186
+ batch_indices = torch.arange(batch_size, device=global_device)
187
+ last_logits = logits[batch_indices, last_positions]
188
+
189
+ yes_logits = last_logits[:, self.inner_yes_token_id]
190
+ all_logits.extend([float(logit) / 5.0 for logit in yes_logits])
191
+
192
+ def sigmoid(x: float) -> float:
193
+ return 1 / (1 + math.exp(-x))
194
+
195
+ scores = [sigmoid(logit) for logit in all_logits]
196
+
197
+ # Unsort by indices
198
+ scores = [score for _, score in sorted(zip(permutation, scores, strict=True))]
199
+
200
+ return scores
201
+
202
+
203
def to_device(self: "_CE", new_device: torch.device) -> None:
    """Select the device used for scoring and move a loaded model onto it.

    Rebinds the module-level ``global_device`` so a model loaded later is
    placed on ``new_device``. If ``predict`` has already cached a model on
    ``self`` as ``inner_model``, that model is moved as well; otherwise the
    next ``predict`` call would feed inputs to weights on another device.

    Args:
        new_device: Device that subsequent loading and inference should use.
    """
    global global_device
    global_device = new_device

    # The model is loaded lazily by predict(); only move it if it exists.
    loaded_model = getattr(self, "inner_model", None)
    if loaded_model is not None:
        loaded_model.to(new_device)
206
+
207
+
208
from transformers import Qwen3Config

# Expose the Qwen3 configuration class under this module's own name.
# NOTE(review): presumably referenced by name from the model repo's config --
# confirm before renaming.
ZEConfig = Qwen3Config

# Replace CrossEncoder's scoring and device-placement methods with the
# implementations defined in this module.
_CE.predict = predict
_CE.to = to_device