How to use zeroentropy/zerank-2-reranker with sentence-transformers:

    from sentence_transformers import CrossEncoder

    model = CrossEncoder("zeroentropy/zerank-2-reranker")

    query = "Which planet is known as the Red Planet?"
    passages = [
        "Venus is often called Earth's twin because of its similar size and proximity.",
        "Mars, known for its reddish appearance, is often referred to as the Red Planet.",
        "Jupiter, the largest planet in our solar system, has a prominent red spot.",
        "Saturn, famous for its rings, is sometimes mistaken for the Red Planet.",
    ]

    scores = model.predict([(query, passage) for passage in passages])
    print(scores)

Commit f899c80 · 1 parent: 456ffeb
Load model during to_device call for eager loading

Move model loading to to_device() since __init__ patching doesn't work
due to timing (the CrossEncoder instance is created before this module
is loaded from HuggingFace). to_device() is called during CrossEncoder
initialization, making this effectively eager loading.
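
The timing problem described in the commit message can be reproduced in isolation. The sketch below is a minimal illustration, not the actual sentence-transformers loading path: `Encoder`, `apply_patches`, `patched_init`, and `patched_to_device` are hypothetical names standing in for CrossEncoder and the patched functions. It assumes, as the message states, that the patch module is loaded while the original `__init__` is already running and that `to_device()` is called afterwards.

```python
class Encoder:
    """Hypothetical stand-in for CrossEncoder."""

    def __init__(self, device: str = "cpu") -> None:
        # The original __init__ is already executing when the custom
        # module gets loaded, so replacing Encoder.__init__ here cannot
        # affect the instance that is currently being constructed.
        apply_patches()
        # Method lookup happens at call time, so this call resolves to
        # the patched to_device that was installed a moment ago.
        self.to_device(device)

    def to_device(self, device: str) -> None:
        self.device = device


def patched_init(self, device: str = "cpu") -> None:
    # Never runs for an instance whose construction started before the patch.
    self.init_patched = True


def patched_to_device(self, device: str) -> None:
    self.device = device
    # Placeholder for the real model load (an assumption for illustration).
    self.inner_model = f"model-on-{device}"


def apply_patches() -> None:
    """Simulates the custom module being loaded and patching the class."""
    Encoder.__init__ = patched_init
    Encoder.to_device = patched_to_device


first = Encoder("cpu")
print(hasattr(first, "init_patched"))  # the patched __init__ was skipped
print(first.inner_model)               # the patched to_device did run
```

In this sketch the first instance never sees the patched `__init__`, but it does go through the patched `to_device()`, which is why moving the load there behaves like eager loading.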
- modeling_zeranker.py +12 -0
modeling_zeranker.py
CHANGED
@@ -234,6 +234,18 @@ def to_device(self: _CE, new_device: torch.device) -> None:
     logger.info(f"Changing device from {global_device} to {new_device}")
     global_device = new_device
 
+    # Load the model now since __init__ patching doesn't work due to timing
+    # (CrossEncoder instance is created before this module is loaded)
+    if not hasattr(self, "inner_model"):
+        logger.info("Loading model during device setup (eager loading)")
+        self.inner_tokenizer, self.inner_model = load_model(global_device)
+        self.inner_model.eval()
+        self.inner_model.gradient_checkpointing_disable()
+        self.inner_yes_token_id = self.inner_tokenizer.encode(
+            "Yes", add_special_tokens=False
+        )[0]
+        logger.info(f"Model loaded successfully. Yes token ID: {self.inner_yes_token_id}")
+
 
 _CE.__init__ = __init__
 _CE.predict = predict
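
The `hasattr(self, "inner_model")` guard in the hunk above means the load runs only on the first `to_device()` call for a given instance. A minimal sketch of that behavior follows; `Reranker`, the stub `load_model`, and `load_calls` are hypothetical stand-ins, and the stub only records calls instead of loading weights.

```python
load_calls = []


def load_model(device):
    """Hypothetical loader: records the device instead of loading weights."""
    load_calls.append(device)
    return "tokenizer", "model"


class Reranker:
    """Hypothetical stand-in for the patched CrossEncoder."""

    def to_device(self, new_device):
        self.device = new_device
        # Same guard as in the diff: load only if nothing is loaded yet.
        if not hasattr(self, "inner_model"):
            self.inner_tokenizer, self.inner_model = load_model(new_device)


reranker = Reranker()
reranker.to_device("cpu")   # first call triggers the load
reranker.to_device("cuda")  # later calls skip the load
print(load_calls)
```

In this sketch the second call updates the recorded device but does not reload or move the model; whether the real module moves an already loaded model elsewhere is not shown in this diff.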