import gc
import os
import threading
from pathlib import Path
from typing import Any, Dict, List
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
class EndpointHandler:
    """Rerank documents against a query with a causal LM used as a yes/no judge.

    Each (instruction, query, document) triple is rendered into a chat-style
    prompt. The model's next-token logits for the "yes" and "no" tokens are
    read at the final position and converted into a relevance probability.

    Environment variables read at construction time:
        HF_MODEL_DIR            model directory, used when ``path`` is empty
        HANDLER_MAX_LENGTH      total token budget per prompt (default 8192)
        HANDLER_BATCH_SIZE      prompts per forward pass (default 8)
        HANDLER_MAX_DOCUMENTS   maximum documents per request (default 64)
        HANDLER_MAX_CONCURRENT  concurrent scoring requests (default 5)

    Read on every request:
        HANDLER_QUEUE_TIMEOUT   seconds to wait for a scoring slot (default 60)
    """

    # Instruction used when the request does not supply one.
    _DEFAULT_INSTRUCTION = (
        "Given a web search query, retrieve relevant passages that answer the query"
    )

    # Exceptions reported to the caller as "out of memory". A failed CUDA
    # allocation raises torch's own OutOfMemoryError (a RuntimeError
    # subclass), not the builtin MemoryError, so both must be listed. The
    # getattr fallback keeps the class importable on torch builds that do not
    # expose that exception type.
    _OOM_ERRORS = (MemoryError, getattr(torch.cuda, "OutOfMemoryError", MemoryError))

    def __init__(self, path: str = ""):
        """Load tokenizer and model and precompute the prompt scaffolding.

        Args:
            path: Model directory. When empty, ``HF_MODEL_DIR`` is used; when
                that is also empty the current working directory is used.

        Raises:
            FileNotFoundError: The directory, ``config.json`` or any weight
                file is missing.
            ValueError: ``" yes"``/``" no"`` are not single tokens for this
                tokenizer, or a numeric environment setting is unusable.
        """
        model_dir = Path(path or os.getenv("HF_MODEL_DIR", "")).resolve()
        self._check_model_dir(model_dir)

        # Read and validate the cheap settings first so that a bad value
        # fails before the (slow) model load.
        self.max_length = int(os.getenv("HANDLER_MAX_LENGTH", "8192"))
        self.batch_size = int(os.getenv("HANDLER_BATCH_SIZE", "8"))
        self.max_documents = int(os.getenv("HANDLER_MAX_DOCUMENTS", "64"))
        if self.batch_size < 1:
            # A zero step crashes range() at request time and a negative step
            # silently yields no batches (and therefore no results).
            raise ValueError(
                f"HANDLER_BATCH_SIZE must be >= 1, got {self.batch_size}."
            )
        self._semaphore = threading.Semaphore(
            int(os.getenv("HANDLER_MAX_CONCURRENT", "5"))
        )

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.torch_dtype = torch.float16 if self.device == "cuda" else torch.float32

        # Left padding keeps the last real token of every prompt at index -1,
        # which is where `_score` reads the logits.
        self.tokenizer = AutoTokenizer.from_pretrained(
            str(model_dir),
            padding_side="left",
            trust_remote_code=True,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            str(model_dir),
            dtype=self.torch_dtype,
            trust_remote_code=True,
        ).to(self.device).eval()

        # NOTE(review): the upstream usage example for this prompt template
        # looks up the bare "yes"/"no" tokens (no leading space). Confirm that
        # the leading-space variants are what the deployed model emits here.
        yes_ids = self.tokenizer.encode(" yes", add_special_tokens=False)
        no_ids = self.tokenizer.encode(" no", add_special_tokens=False)
        if len(yes_ids) != 1 or len(no_ids) != 1:
            raise ValueError(
                f'Expected single-token " yes"/" no", got yes={yes_ids}, no={no_ids}. '
                "You may need a different scoring method for this tokenizer."
            )
        self.token_true_id = yes_ids[0]
        self.token_false_id = no_ids[0]

        self.prefix = (
            "<|im_start|>system\n"
            'Judge whether the Document meets the requirements based on the Query '
            'and the Instruct provided. Note that the answer can only be "yes" or "no".'
            "<|im_end|>\n"
            "<|im_start|>user\n"
        )
        # The assistant turn is pre-filled with an empty reasoning block
        # (opening tag, blank line, closing tag, blank line) so that the very
        # next token is the yes/no verdict. The source previously contained
        # only the newlines, with both tags missing. The tag name is
        # interpolated so the opening and closing tags cannot drift apart.
        reasoning_tag = "think"
        self.suffix = (
            "<|im_end|>\n"
            "<|im_start|>assistant\n"
            f"<{reasoning_tag}>\n\n</{reasoning_tag}>\n\n"
        )
        self.prefix_tokens = self.tokenizer.encode(self.prefix, add_special_tokens=False)
        self.suffix_tokens = self.tokenizer.encode(self.suffix, add_special_tokens=False)

        scaffold_len = len(self.prefix_tokens) + len(self.suffix_tokens)
        if self.max_length <= scaffold_len:
            # `_process_inputs` truncates the body to max_length - scaffold_len,
            # which must stay positive.
            raise ValueError(
                f"HANDLER_MAX_LENGTH ({self.max_length}) must exceed the "
                f"{scaffold_len} tokens used by the prompt prefix and suffix."
            )

    @staticmethod
    def _check_model_dir(model_dir: Path) -> None:
        """Raise FileNotFoundError unless ``model_dir`` has config and weights."""
        if not model_dir.exists():
            raise FileNotFoundError(f"Model directory does not exist: {model_dir}")

        # Helpful debug info in endpoint logs
        print(f"[handler] loading model from: {model_dir}")
        print(f"[handler] files: {[p.name for p in model_dir.iterdir()]}")

        required_any = [
            "config.json",
        ]
        missing_required = [f for f in required_any if not (model_dir / f).exists()]
        if missing_required:
            raise FileNotFoundError(
                f"Missing required model files in {model_dir}: {missing_required}"
            )

        # Accept single-file or sharded weights in either serialization format.
        has_weights = (
            (model_dir / "model.safetensors").exists()
            or (model_dir / "pytorch_model.bin").exists()
            or any(model_dir.glob("model-*.safetensors"))
            or any(model_dir.glob("pytorch_model-*.bin"))
        )
        if not has_weights:
            raise FileNotFoundError(
                f"No model weight file found in {model_dir}. "
                f"Expected model.safetensors, pytorch_model.bin, or sharded weights."
            )

    def _format_one(self, instruction: str, query: str, document: str) -> str:
        """Render the user-turn body for one query/document pair.

        The three labels match the field names the system prompt refers to
        ("Instruct", "Query", "Document").
        """
        return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {document}"

    def _process_inputs(self, pairs: List[str]) -> Dict[str, torch.Tensor]:
        """Tokenize prompt bodies, wrap them in prefix/suffix, pad and move to device.

        Args:
            pairs: Prompt bodies as produced by `_format_one`.

        Returns:
            The padded batch returned by ``tokenizer.pad`` with every tensor
            moved to ``self.device``.
        """
        # Truncate only the body so the fixed prefix and suffix always fit
        # inside the overall max_length budget.
        body_budget = self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens)
        inputs = self.tokenizer(
            pairs,
            padding=False,
            truncation=True,
            return_attention_mask=False,
            max_length=body_budget,
        )
        for i, ids in enumerate(inputs["input_ids"]):
            inputs["input_ids"][i] = self.prefix_tokens + ids + self.suffix_tokens
        padded = self.tokenizer.pad(
            inputs,
            padding=True,
            return_tensors="pt",
        )
        for k in padded:
            padded[k] = padded[k].to(self.device)
        return padded

    @torch.no_grad()
    def _score(self, model_inputs: Dict[str, torch.Tensor]) -> List[float]:
        """Return P(yes) for each row, normalised over the yes/no token pair."""
        # Logits at the final sequence position only.
        logits = self.model(**model_inputs).logits[:, -1, :]
        false_scores = logits[:, self.token_false_id]
        true_scores = logits[:, self.token_true_id]
        # Cast to float32 before the softmax: on GPU the model runs in fp16
        # and a half-precision softmax loses resolution near 0 and 1.
        pair_scores = torch.stack([false_scores, true_scores], dim=1).float()
        probs = torch.nn.functional.softmax(pair_scores, dim=1)[:, 1]
        return probs.tolist()

    def _score_in_batches(self, pairs: List[str]) -> List[float]:
        """Score ``pairs`` in chunks of ``self.batch_size``, preserving order."""
        all_scores: List[float] = []
        for start in range(0, len(pairs), self.batch_size):
            batch = pairs[start : start + self.batch_size]
            model_inputs = self._process_inputs(batch)
            all_scores.extend(self._score(model_inputs))
            # Drop the batch tensors before building the next batch.
            del model_inputs
            gc.collect()
        return all_scores

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Score and rank the request's documents.

        Args:
            data: Request body. The payload is ``data["inputs"]`` when that
                key is present, otherwise ``data`` itself. Payload keys:
                ``query`` (required), ``documents`` (required, non-empty
                list), ``instruction`` (optional) and ``return_documents``
                (optional, default True).

        Returns:
            ``{"results": [...], "meta": {...}}`` where ``results`` is sorted
            by descending ``relevance_score`` and each item carries the
            document's original ``index``.

        Raises:
            KeyError: ``query`` or ``documents`` is missing.
            ValueError: The payload is not a dict, or ``documents`` is not a
                non-empty list within the configured limit.
            RuntimeError: No scoring slot became free within the queue
                timeout, or scoring ran out of memory.
        """
        payload = data.get("inputs", data)
        if not isinstance(payload, dict):
            raise ValueError(
                "`inputs` must be an object with `query` and `documents` fields."
            )

        # An explicit null instruction falls back to the default, instead of
        # being formatted into the prompt as the literal text "None".
        instruction = payload.get("instruction")
        if instruction is None:
            instruction = self._DEFAULT_INSTRUCTION
        query = payload["query"]
        documents = payload["documents"]
        return_documents = payload.get("return_documents", True)

        if not isinstance(documents, list) or len(documents) == 0:
            raise ValueError("`documents` must be a non-empty list of strings.")
        if len(documents) > self.max_documents:
            raise ValueError(
                f"`documents` exceeds max allowed ({self.max_documents}). "
                f"Got {len(documents)}."
            )

        pairs = [self._format_one(instruction, query, doc) for doc in documents]

        acquired = self._semaphore.acquire(
            timeout=int(os.getenv("HANDLER_QUEUE_TIMEOUT", "60"))
        )
        if not acquired:
            raise RuntimeError(
                "Server is busy. Another request is being processed. Please retry."
            )
        try:
            scores = self._score_in_batches(pairs)
        except self._OOM_ERRORS as err:
            gc.collect()
            if torch.cuda.is_available():
                # Return cached blocks to the driver so later requests can run.
                torch.cuda.empty_cache()
            raise RuntimeError(
                "Out of memory while scoring. Try sending fewer or shorter documents."
            ) from err
        finally:
            self._semaphore.release()

        results = []
        for i, (doc, score) in enumerate(zip(documents, scores)):
            item = {
                "index": i,
                "relevance_score": float(score),
            }
            if return_documents:
                item["document"] = doc
            results.append(item)
        results.sort(key=lambda x: x["relevance_score"], reverse=True)

        return {
            "results": results,
            "meta": {
                "instruction": instruction,
                "query": query,
                "count": len(documents),
            },
        }