"""Custom endpoint handler that scores documents against a query.

The handler loads a causal language model, wraps every (instruction, query,
document) triple in a chat-style prompt, and turns the next-token logits of
the " yes" / " no" tokens into a relevance probability.
"""

import gc
import os
import threading
from pathlib import Path
from typing import Any, Dict, List

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# CUDA out-of-memory is reported as ``torch.cuda.OutOfMemoryError`` (a
# ``RuntimeError`` subclass), not as the builtin ``MemoryError``. Older torch
# builds may not expose the class, hence the ``getattr``.
_CUDA_OOM_ERROR = getattr(torch.cuda, "OutOfMemoryError", None)
_OOM_ERRORS = (
    (MemoryError,) if _CUDA_OOM_ERROR is None else (MemoryError, _CUDA_OOM_ERROR)
)


def _env_int(name: str, default: int, minimum: int = 1) -> int:
    """Read the integer environment variable *name*, enforcing a lower bound.

    Raises:
        ValueError: if the value is not an integer or is below *minimum*.
    """
    raw = os.getenv(name, str(default))
    try:
        value = int(raw)
    except ValueError as exc:
        raise ValueError(f"{name} must be an integer, got {raw!r}") from exc
    if value < minimum:
        raise ValueError(f"{name} must be >= {minimum}, got {value}")
    return value


class EndpointHandler:
    """Relevance scorer built on a causal LM's "yes"/"no" next-token logits.

    Configuration is read from environment variables at construction time:

    * ``HF_MODEL_DIR``           - model directory when ``path`` is empty.
    * ``HANDLER_MAX_LENGTH``     - max tokens per prompt, template included
      (default 8192).
    * ``HANDLER_BATCH_SIZE``     - prompts per forward pass (default 8).
    * ``HANDLER_MAX_DOCUMENTS``  - max documents per request (default 64).
    * ``HANDLER_MAX_CONCURRENT`` - max requests scored at once (default 5).

    ``HANDLER_QUEUE_TIMEOUT`` (seconds, default 60) is read on every call.
    """

    def __init__(self, path: str = ""):
        """Validate the model directory, load tokenizer/model, build the prompt.

        Args:
            path: Model directory. Falls back to ``HF_MODEL_DIR`` when empty.

        Raises:
            FileNotFoundError: if the directory, ``config.json`` or any weight
                file is missing.
            ValueError: if " yes"/" no" are not single tokens, or if an
                environment-derived limit is invalid.
        """
        model_dir = Path(path or os.getenv("HF_MODEL_DIR", "")).resolve()
        if not model_dir.exists():
            raise FileNotFoundError(f"Model directory does not exist: {model_dir}")

        # Helpful debug info in endpoint logs
        print(f"[handler] loading model from: {model_dir}")
        print(f"[handler] files: {[p.name for p in model_dir.iterdir()]}")

        self._check_model_files(model_dir)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.torch_dtype = torch.float16 if self.device == "cuda" else torch.float32

        # NOTE(security): trust_remote_code=True executes Python shipped with
        # the checkpoint; only point this handler at trusted model directories.
        self.tokenizer = AutoTokenizer.from_pretrained(
            str(model_dir),
            padding_side="left",  # keeps the scored position at index -1
            trust_remote_code=True,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            str(model_dir),
            dtype=self.torch_dtype,
            trust_remote_code=True,
        ).to(self.device).eval()

        # Safer token lookup for decoder LMs: include leading space variants if needed
        # NOTE(review): the Qwen3-Reranker reference code scores the bare
        # "yes"/"no" tokens (no leading space) - confirm that the
        # leading-space variants are what this checkpoint was trained on.
        yes_ids = self.tokenizer.encode(" yes", add_special_tokens=False)
        no_ids = self.tokenizer.encode(" no", add_special_tokens=False)
        if len(yes_ids) != 1 or len(no_ids) != 1:
            raise ValueError(
                f'Expected single-token " yes"/" no", got yes={yes_ids}, no={no_ids}. '
                "You may need a different scoring method for this tokenizer."
            )
        self.token_true_id = yes_ids[0]
        self.token_false_id = no_ids[0]

        # Reject nonsensical limits now instead of failing on the first request
        # (e.g. a batch size of 0 would make ``range`` raise while scoring).
        self.max_length = _env_int("HANDLER_MAX_LENGTH", 8192)
        self.batch_size = _env_int("HANDLER_BATCH_SIZE", 8)
        self.max_documents = _env_int("HANDLER_MAX_DOCUMENTS", 64)
        self._semaphore = threading.Semaphore(_env_int("HANDLER_MAX_CONCURRENT", 5))

        self.prefix = (
            "<|im_start|>system\n"
            'Judge whether the Document meets the requirements based on the Query '
            'and the Instruct provided. Note that the answer can only be "yes" or "no".'
            "<|im_end|>\n"
            "<|im_start|>user\n"
        )
        # NOTE(review): the empty <think></think> block follows the
        # Qwen3-Reranker prompt template - confirm against the deployed model.
        self.suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
        self.prefix_tokens = self.tokenizer.encode(self.prefix, add_special_tokens=False)
        self.suffix_tokens = self.tokenizer.encode(self.suffix, add_special_tokens=False)

        template_len = len(self.prefix_tokens) + len(self.suffix_tokens)
        if self.max_length <= template_len:
            raise ValueError(
                f"HANDLER_MAX_LENGTH ({self.max_length}) must exceed the prompt "
                f"template length ({template_len} tokens)."
            )

    @staticmethod
    def _check_model_files(model_dir: Path) -> None:
        """Raise ``FileNotFoundError`` unless config and weights are present."""
        required_any = [
            "config.json",
        ]
        missing_required = [f for f in required_any if not (model_dir / f).exists()]
        if missing_required:
            raise FileNotFoundError(
                f"Missing required model files in {model_dir}: {missing_required}"
            )

        # Single-file or sharded weights, safetensors or legacy .bin.
        has_weights = (
            (model_dir / "model.safetensors").exists()
            or (model_dir / "pytorch_model.bin").exists()
            or any(model_dir.glob("model-*.safetensors"))
            or any(model_dir.glob("pytorch_model-*.bin"))
        )
        if not has_weights:
            raise FileNotFoundError(
                f"No model weight file found in {model_dir}. "
                f"Expected model.safetensors, pytorch_model.bin, or sharded weights."
            )

    def _format_one(self, instruction: str, query: str, document: str) -> str:
        """Return the user-turn text for one (instruction, query, document).

        The labels match the terms the system prompt refers to ("Instruct",
        "Query", "Document").
        NOTE(review): tag spelling follows the Qwen3-Reranker prompt template -
        confirm against the deployed model.
        """
        return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {document}"

    def _process_inputs(self, pairs: List[str]) -> Dict[str, torch.Tensor]:
        """Tokenize *pairs*, wrap them in the prompt template, pad, move to device.

        Only the user-turn text is truncated, so the template prefix and suffix
        always survive intact and the full prompt stays within ``max_length``.
        """
        budget = self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens)
        inputs = self.tokenizer(
            pairs,
            padding=False,
            truncation=True,
            return_attention_mask=False,
            max_length=budget,
        )
        for i, ids in enumerate(inputs["input_ids"]):
            inputs["input_ids"][i] = self.prefix_tokens + ids + self.suffix_tokens

        # ``pad`` builds the attention mask that was skipped above.
        padded = self.tokenizer.pad(
            inputs,
            padding=True,
            return_tensors="pt",
        )
        return {key: tensor.to(self.device) for key, tensor in padded.items()}

    @torch.no_grad()
    def _score(self, model_inputs: Dict[str, torch.Tensor]) -> List[float]:
        """Return P("yes") for every row, normalised over the yes/no logits only."""
        # Left padding puts the next-token prediction of every row at index -1.
        logits = self.model(**model_inputs).logits[:, -1, :]
        pair_scores = torch.stack(
            [logits[:, self.token_false_id], logits[:, self.token_true_id]],
            dim=1,
        )
        # Upcast so the softmax is not computed in float16 on GPU.
        probs = torch.nn.functional.softmax(pair_scores.float(), dim=1)[:, 1]
        return probs.tolist()

    def _score_in_batches(self, pairs: List[str]) -> List[float]:
        """Score *pairs* in chunks of ``batch_size``, preserving input order."""
        all_scores: List[float] = []
        for start in range(0, len(pairs), self.batch_size):
            batch = pairs[start : start + self.batch_size]
            model_inputs = self._process_inputs(batch)
            all_scores.extend(self._score(model_inputs))
            # Drop the batch tensors before building the next batch.
            del model_inputs
            gc.collect()
        return all_scores

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Score ``documents`` against ``query`` and return them ranked.

        Args:
            data: Request body. Fields are read from ``data["inputs"]`` when
                that key exists, otherwise from ``data`` itself:
                ``query`` (required), ``documents`` (required, non-empty
                list), ``instruction`` (optional), ``return_documents``
                (optional, default ``True``).

        Returns:
            ``{"results": [...], "meta": {...}}`` where each result holds the
            original ``index``, its ``relevance_score`` and, if requested, the
            ``document``; results are sorted by descending score.

        Raises:
            KeyError: if ``query`` or ``documents`` is missing.
            ValueError: if ``documents`` is not a non-empty list or exceeds
                ``HANDLER_MAX_DOCUMENTS``.
            RuntimeError: if no scoring slot frees up within
                ``HANDLER_QUEUE_TIMEOUT`` seconds, or scoring runs out of
                host or GPU memory.
        """
        payload = data.get("inputs", data)

        instruction = payload.get(
            "instruction",
            "Given a web search query, retrieve relevant passages that answer the query",
        )
        query = payload["query"]
        documents = payload["documents"]
        return_documents = payload.get("return_documents", True)

        if not isinstance(documents, list) or len(documents) == 0:
            raise ValueError("`documents` must be a non-empty list of strings.")

        if len(documents) > self.max_documents:
            raise ValueError(
                f"`documents` exceeds max allowed ({self.max_documents}). "
                f"Got {len(documents)}."
            )

        pairs = [self._format_one(instruction, query, doc) for doc in documents]

        # Validation happens before queueing so bad requests never hold a slot.
        acquired = self._semaphore.acquire(
            timeout=int(os.getenv("HANDLER_QUEUE_TIMEOUT", "60"))
        )
        if not acquired:
            raise RuntimeError(
                "Server is busy. Another request is being processed. Please retry."
            )

        try:
            scores = self._score_in_batches(pairs)
        except _OOM_ERRORS as exc:
            # Covers both host OOM and CUDA OOM; hand cached GPU blocks back so
            # the next request starts from a clean allocator state.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            raise RuntimeError(
                "Out of memory while scoring. Try sending fewer or shorter documents."
            ) from exc
        finally:
            self._semaphore.release()

        results = []
        for i, (doc, score) in enumerate(zip(documents, scores)):
            item = {
                "index": i,
                "relevance_score": float(score),
            }
            if return_documents:
                item["document"] = doc
            results.append(item)

        results.sort(key=lambda x: x["relevance_score"], reverse=True)

        return {
            "results": results,
            "meta": {
                "instruction": instruction,
                "query": query,
                "count": len(documents),
            },
        }