# Source: hutlim/Qwen3-Reranker-0.6B — handler.py
# Last commit: "Update handler.py" (e2f2cfb, verified) · 1.98 kB
from typing import Any, Dict, List
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
class EndpointHandler:
    """Custom inference-endpoint handler that scores documents against a query.

    Expected request (read by ``__call__``)::

        {
            "inputs": {"query": str, "documents": [str, ...]},
            "parameters": {            # optional, may also be null
                "instruction": str,    # optional; prepended to the query when truthy
                "batch_size": int,     # optional; default scores all documents in one pass
                "max_length": int,     # optional; tokenizer truncation length
            },
        }

    Response: ``{"scores": [...], "ranked": [...]}``. ``scores[i]`` belongs to
    ``documents[i]``. ``ranked`` holds ``{"index", "score", "document"}`` entries
    sorted by descending score.
    """

    def __init__(self, path: str = ""):
        """Load the tokenizer and sequence-classification model from ``path``.

        Args:
            path: Directory (or hub id) of the checkpoint to load.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # NOTE(review): Qwen3-Reranker checkpoints are commonly published as causal LMs
        # that are scored via "yes"/"no" token logits. Loading one with
        # AutoModelForSequenceClassification is only meaningful if this repo ships a
        # converted classification head. Confirm; otherwise the score head is untrained.
        self.model = AutoModelForSequenceClassification.from_pretrained(path).to(self.device)
        self.model.eval()

        # Batched pair scoring needs padding, and some tokenizers ship without a pad token.
        if self.tokenizer.pad_token is None:
            if self.tokenizer.eos_token is not None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            else:
                self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
                # A brand-new token needs its own embedding row.
                self.model.resize_token_embeddings(len(self.tokenizer))
        # Decoder-style *ForSequenceClassification heads in transformers read
        # config.pad_token_id to handle padded batches. Keep it in sync in every case.
        self.model.config.pad_token_id = self.tokenizer.pad_token_id

    @torch.inference_mode()
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Score every document against the query and return scores plus a ranking.

        Args:
            data: Request dict; see the class docstring for the expected keys.

        Returns:
            ``{"scores": list[float], "ranked": list[dict]}``. Both lists are empty
            when no documents are supplied.

        Raises:
            KeyError: if ``inputs``, ``query`` or ``documents`` is missing.
            ValueError: if ``parameters.batch_size`` is not a positive integer.
        """
        payload = data["inputs"]
        # An explicit JSON null for "parameters" behaves like an absent key.
        params = data.get("parameters") or {}
        query = payload["query"]
        documents = payload["documents"]
        # A bare string would otherwise be scored character by character.
        if isinstance(documents, str):
            documents = [documents]
        if not documents:
            # Nothing to score; tokenizing and forwarding an empty batch would fail.
            return {"scores": [], "ranked": []}

        query_text = self._format_query(query, params.get("instruction"))
        pairs = [[query_text, doc] for doc in documents]

        batch_size = params.get("batch_size")
        if batch_size is None:
            batch_size = len(pairs)  # Original behaviour: one forward pass.
        elif isinstance(batch_size, bool) or not isinstance(batch_size, int) or batch_size < 1:
            raise ValueError("parameters.batch_size must be a positive integer")
        max_length = params.get("max_length")

        scores: List[float] = []
        for start in range(0, len(pairs), batch_size):
            scores.extend(self._score_pairs(pairs[start:start + batch_size], max_length))

        ranked = sorted(
            ({"index": i, "score": s, "document": documents[i]} for i, s in enumerate(scores)),
            key=lambda item: item["score"],
            reverse=True,
        )
        return {"scores": scores, "ranked": ranked}

    @staticmethod
    def _format_query(query: str, instruction: Any = None) -> str:
        """Prefix ``query`` with ``instruction`` when it is truthy; otherwise return it unchanged."""
        # NOTE(review): confirm this "Instruct:/Query:" layout matches the prompt format
        # the checkpoint was trained or converted with.
        if instruction:
            return f"Instruct: {instruction}\nQuery: {query}"
        return query

    def _score_pairs(self, pairs: List[List[str]], max_length: Any = None) -> List[float]:
        """Tokenize one batch of (query, document) pairs and return one score per pair."""
        enc = self.tokenizer(
            pairs,
            padding=True,
            truncation=True,
            max_length=max_length,  # None keeps the tokenizer's default limit.
            return_tensors="pt",
        ).to(self.device)
        return self._scores_from_logits(self.model(**enc).logits)

    @staticmethod
    def _scores_from_logits(logits: torch.Tensor) -> List[float]:
        """Reduce 2-D ``(batch, num_labels)`` logits to one Python float per row.

        Takes the last column. For a single-logit head that is the only column; for a
        multi-label head the last label is treated as "relevant". Values are upcast to
        float32 and moved to CPU before conversion.
        """
        return logits[:, -1].float().cpu().tolist()