Instructions to use zeroentropy/zerank-2-reranker with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- sentence-transformers
How to use zeroentropy/zerank-2-reranker with sentence-transformers:
```python
from sentence_transformers import CrossEncoder

model = CrossEncoder("zeroentropy/zerank-2-reranker")

query = "Which planet is known as the Red Planet?"
passages = [
    "Venus is often called Earth's twin because of its similar size and proximity.",
    "Mars, known for its reddish appearance, is often referred to as the Red Planet.",
    "Jupiter, the largest planet in our solar system, has a prominent red spot.",
    "Saturn, famous for its rings, is sometimes mistaken for the Red Planet.",
]

scores = model.predict([(query, passage) for passage in passages])
print(scores)
```
- Notebooks
- Google Colab
- Kaggle
| from sentence_transformers import CrossEncoder as _CE | |
| import math | |
| import logging | |
| from typing import cast, Any | |
| import types | |
| import torch | |
| from transformers.configuration_utils import PretrainedConfig | |
| from transformers.models.auto.configuration_auto import AutoConfig | |
| from transformers.models.auto.modeling_auto import AutoModelForCausalLM | |
| from transformers.models.auto.tokenization_auto import AutoTokenizer | |
| from transformers.models.gemma3.modeling_gemma3 import ( | |
| Gemma3ForCausalLM, | |
| Gemma3ForConditionalGeneration, | |
| ) | |
| from transformers.models.llama.modeling_llama import LlamaForCausalLM | |
| from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM | |
| from transformers.tokenization_utils_base import BatchEncoding | |
| from transformers.tokenization_utils_fast import PreTrainedTokenizerFast | |
| # pyright: reportUnknownMemberType=false | |
| # pyright: reportUnknownVariableType=false | |
logger = logging.getLogger(__name__)

# Hub repository id handed to every `from_pretrained` call in `load_model`.
MODEL_PATH = "zeroentropy/zerank-2"

# Per-batch budget consulted by `predict` when grouping pairs. It is compared
# against character counts of the query/document strings, not token counts.
PER_DEVICE_BATCH_SIZE_TOKENS = 10_000

# Process-wide default device: CUDA when available, otherwise CPU.
global_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def format_pointwise_datapoints(
    tokenizer: PreTrainedTokenizerFast,
    query_documents: list[tuple[str, str]],
) -> BatchEncoding:
    """Render (query, document) pairs as chat prompts and tokenize them.

    Each pair becomes a two-message conversation: the whitespace-stripped
    query is the ``system`` message and the whitespace-stripped document is
    the ``user`` message. The conversation is rendered to text with the
    tokenizer's chat template (generation prompt appended), and all rendered
    prompts are then tokenized together with padding.

    Args:
        tokenizer: Tokenizer providing ``apply_chat_template`` and ``__call__``.
        query_documents: The (query, document) pairs to format.

    Returns:
        The padded batch produced by ``tokenizer(..., return_tensors="pt")``.
    """

    def render_prompt(query: str, document: str) -> str:
        conversation = [
            {"role": "system", "content": f"{query}".strip()},
            {"role": "user", "content": f"{document}".strip()},
        ]
        rendered = tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True,
        )
        # With tokenize=False the template is expected to yield plain text.
        assert isinstance(rendered, str)
        return rendered

    prompts: list[str] = [
        render_prompt(query, document) for query, document in query_documents
    ]
    return tokenizer(prompts, padding=True, return_tensors="pt")
def load_model(
    device: torch.device | None = None,
) -> tuple[
    PreTrainedTokenizerFast,
    LlamaForCausalLM
    | Gemma3ForConditionalGeneration
    | Gemma3ForCausalLM
    | Qwen3ForCausalLM,
]:
    """Load the causal LM and its tokenizer from ``MODEL_PATH``.

    Args:
        device: Device that receives the whole model. Falls back to the
            module-level ``global_device`` when omitted.

    Returns:
        A ``(tokenizer, model)`` tuple. The tokenizer pads on the right and is
        guaranteed to have a pad token (the EOS token is reused if none is
        configured).
    """
    target_device = global_device if device is None else device
    logger.info(f"Loading model from {MODEL_PATH} on device: {target_device}")

    config = AutoConfig.from_pretrained(MODEL_PATH)
    assert isinstance(config, PretrainedConfig)
    logger.info(f"Loading model with config type: {config.model_type}")

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype="auto",
        quantization_config=None,
        # The empty-string key maps the entire model onto one device.
        device_map={"": target_device},
    )
    if config.model_type == "llama":
        # NOTE(review): this is assigned after the weights are loaded; confirm
        # it actually switches the attention kernel rather than only the config.
        model.config.attn_implementation = "flash_attention_2"

    supported_architectures = (
        LlamaForCausalLM,
        Gemma3ForConditionalGeneration,
        Gemma3ForCausalLM,
        Qwen3ForCausalLM,
    )
    assert isinstance(model, supported_architectures)

    logger.info("Loading tokenizer")
    loaded_tokenizer = AutoTokenizer.from_pretrained(
        MODEL_PATH,
        padding_side="right",
    )
    tokenizer = cast(AutoTokenizer, loaded_tokenizer)
    assert isinstance(tokenizer, PreTrainedTokenizerFast)
    if tokenizer.pad_token is None:
        # Batched tokenization with padding=True needs some pad token.
        tokenizer.pad_token = tokenizer.eos_token

    logger.info("Model and tokenizer loaded successfully")
    return tokenizer, model
# Keep a handle on the stock constructor so the replacement can delegate to it.
_original_init = _CE.__init__


def __init__(self, *args: Any, **kwargs: Any) -> None:
    """Run the stock CrossEncoder constructor, then eagerly load the reranker.

    After delegating to the original ``__init__`` with the caller's arguments,
    this loads the tokenizer/model pair via ``load_model(global_device)``,
    puts the model in eval mode with gradient checkpointing disabled, and
    records the id of the first token that ``"Yes"`` encodes to.

    Attributes set on ``self``:
        inner_tokenizer: Tokenizer returned by ``load_model``.
        inner_model: Model returned by ``load_model``.
        inner_yes_token_id: First token id of ``"Yes"`` (no special tokens).
    """
    logger.info("Initializing CrossEncoder with eager model loading")
    _original_init(self, *args, **kwargs)

    # Load right away; nothing is deferred to the first predict() call.
    logger.info("Loading model on instantiation (no lazy loading)")
    tokenizer, model = load_model(global_device)
    self.inner_tokenizer = tokenizer
    self.inner_model = model

    model.eval()
    model.gradient_checkpointing_disable()

    yes_token_ids = tokenizer.encode("Yes", add_special_tokens=False)
    self.inner_yes_token_id = yes_token_ids[0]
    logger.info(f"CrossEncoder initialization complete. Yes token ID: {self.inner_yes_token_id}")
def _pack_batches(
    pairs: list[tuple[str, str]],
    budget: int,
) -> list[list[tuple[str, str]]]:
    """Greedily group pairs so ``batch size * longest pair`` stays within budget.

    Lengths are measured in characters (``len`` of the strings). A pair that
    exceeds the budget on its own still gets a batch to itself.
    """
    batches: list[list[tuple[str, str]]] = []
    longest = 0
    for query, document in pairs:
        pair_length = len(query) + len(document)
        if (
            not batches
            or (len(batches[-1]) + 1) * max(longest, pair_length) > budget
        ):
            batches.append([])
            longest = 0
        batches[-1].append((query, document))
        # 20 extra characters are charged per pair on top of the raw text.
        longest = max(longest, 20 + pair_length)
    return batches


def _sigmoid(x: float) -> float:
    """Logistic function that saturates to 0.0 instead of overflowing."""
    try:
        return 1 / (1 + math.exp(-x))
    except OverflowError:
        # exp(-x) is too large for a float, so the true value rounds to 0.
        return 0.0


def predict(
    self,
    query_documents: list[tuple[str, str]] | None = None,
    *,
    sentences: Any = None,
    batch_size: Any = None,
    show_progress_bar: Any = None,
    activation_fn: Any = None,
    apply_softmax: Any = None,
    convert_to_numpy: Any = None,
    convert_to_tensor: Any = None,
) -> list[float]:
    """Score (query, document) pairs with the inner causal LM.

    For every pair the model is run on the chat-formatted prompt, the logit of
    the "Yes" token at the last non-padded position is taken, divided by 5.0
    and passed through a sigmoid.

    Args:
        query_documents: Pairs to score. Queries are truncated to 2,000
            characters and documents to 10,000 characters.
        sentences: Alternative way to pass the pairs; used only when
            ``query_documents`` is None. Each item is indexed with ``[0]``
            and ``[1]``.
        batch_size, show_progress_bar, activation_fn, apply_softmax,
        convert_to_numpy, convert_to_tensor: Accepted and ignored
            (presumably for keyword compatibility with the upstream
            ``CrossEncoder.predict`` -- confirm). Batching is driven by
            ``PER_DEVICE_BATCH_SIZE_TOKENS`` instead.

    Returns:
        One score in ``[0, 1]`` per input pair, in the caller's order.

    Raises:
        ValueError: If neither ``query_documents`` nor ``sentences`` is given.
    """
    if query_documents is None:
        if sentences is None:
            raise ValueError("query_documents or sentences must be provided")
        query_documents = [(sentence[0], sentence[1]) for sentence in sentences]
    logger.info(f"Starting prediction for {len(query_documents)} query-document pairs")

    model = self.inner_model
    tokenizer = self.inner_tokenizer
    # Follow the model rather than the module-level default so inputs always
    # land on the device that actually holds the weights.
    device = model.device

    query_documents = [
        (query[:2_000], document[:10_000]) for query, document in query_documents
    ]

    # Longest pairs first, so each batch holds similarly sized prompts.
    # `permutation[k]` is the caller-side index of the k-th sorted pair.
    permutation = sorted(
        range(len(query_documents)),
        key=lambda i: -len(query_documents[i][0]) - len(query_documents[i][1]),
    )
    batches = _pack_batches(
        [query_documents[i] for i in permutation],
        PER_DEVICE_BATCH_SIZE_TOKENS,
    )
    logger.info(f"Created {len(batches)} batches for inference")

    all_logits: list[float] = []
    for batch_idx, batch in enumerate(batches):
        logger.debug(f"Processing batch {batch_idx + 1}/{len(batches)} with {len(batch)} pairs")
        batch_inputs = format_pointwise_datapoints(tokenizer, batch).to(device)

        try:
            with torch.inference_mode():
                outputs = model(**batch_inputs, use_cache=False)
        except torch.OutOfMemoryError:
            logger.warning(f"GPU OOM! Memory reserved: {torch.cuda.memory_reserved()}")
            torch.cuda.empty_cache()
            logger.info(f"GPU cache cleared. Memory reserved: {torch.cuda.memory_reserved()}")
            # Retry under inference mode as well: building an autograd graph
            # here would need more memory than the attempt that just failed.
            with torch.inference_mode():
                outputs = model(**batch_inputs, use_cache=False)

        logits = cast(torch.Tensor, outputs.logits)
        attention_mask = cast(torch.Tensor, batch_inputs.attention_mask)
        # The tokenizer pads on the right, so the number of attended tokens
        # minus one is the index of each row's final real token.
        last_positions = attention_mask.sum(dim=1) - 1
        row_indices = torch.arange(logits.shape[0], device=logits.device)
        last_logits = logits[row_indices, last_positions]
        yes_logits = last_logits[:, self.inner_yes_token_id]
        # A single bulk transfer instead of one device sync per element.
        all_logits.extend(value / 5.0 for value in yes_logits.float().tolist())

    # Write every score back to the slot its pair occupied before sorting.
    scores = [0.0] * len(permutation)
    for original_index, logit in zip(permutation, all_logits, strict=True):
        scores[original_index] = _sigmoid(logit)

    logger.info(f"Prediction complete. Generated {len(scores)} scores")
    return scores
def to_device(self: _CE, new_device: torch.device) -> _CE:
    """Switch the reranker to ``new_device``.

    Rebinds the module-level ``global_device`` and moves the already loaded
    inner model to the same device, so that inputs and weights stay together
    on the next ``predict`` call.

    Args:
        new_device: Target device.

    Returns:
        ``self``, so that ``model = model.to(device)`` keeps working.
    """
    global global_device
    logger.info(f"Changing device from {global_device} to {new_device}")
    global_device = new_device

    # `inner_model` is only attached at the end of the patched constructor;
    # tolerate calls that arrive before that point.
    inner_model = getattr(self, "inner_model", None)
    if inner_model is not None:
        inner_model.to(new_device)
    return self
# Install the replacements on the CrossEncoder class itself, so every instance
# created after this module is imported uses them.
_CE.to = to_device
_CE.predict = predict
_CE.__init__ = __init__

from transformers import Qwen3Config

# Expose Qwen3Config under the ZEConfig name.
ZEConfig = Qwen3Config