hutlim
/

hutlim commited on
Commit
a4af8e6
·
verified ·
1 Parent(s): 4be06c8

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +77 -0
handler.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+
5
+
6
+ class EndpointHandler:
7
+ def __init__(self, model_dir: str, **kwargs: Any) -> None:
8
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+
10
+ self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
11
+ self.model = AutoModelForSequenceClassification.from_pretrained(model_dir)
12
+ self.model.to(self.device)
13
+ self.model.eval()
14
+
15
+ # Fix batching when the tokenizer has no pad token
16
+ if self.tokenizer.pad_token is None:
17
+ if self.tokenizer.eos_token is not None:
18
+ self.tokenizer.pad_token = self.tokenizer.eos_token
19
+ else:
20
+ self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
21
+ self.model.resize_token_embeddings(len(self.tokenizer))
22
+
23
+ self.model.config.pad_token_id = self.tokenizer.pad_token_id
24
+
25
+ @torch.inference_mode()
26
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
27
+ """
28
+ Expected request body:
29
+ {
30
+ "query": "What is the capital of China?",
31
+ "documents": [
32
+ "The capital of China is Beijing.",
33
+ "Gravity is a force..."
34
+ ],
35
+ "instruction": "Given a web search query, retrieve relevant passages that answer the query"
36
+ }
37
+ """
38
+
39
+ query = data["query"]
40
+ documents = data["documents"]
41
+ instruction = data.get("instruction")
42
+
43
+ if instruction:
44
+ query_text = f"Instruct: {instruction}\nQuery: {query}"
45
+ else:
46
+ query_text = query
47
+
48
+ pairs = [[query_text, doc] for doc in documents]
49
+
50
+ inputs = self.tokenizer(
51
+ pairs,
52
+ padding=True,
53
+ truncation=True,
54
+ return_tensors="pt"
55
+ ).to(self.device)
56
+
57
+ outputs = self.model(**inputs)
58
+
59
+ # Common reranker head shape: [batch, 1] or [batch, 2]
60
+ logits = outputs.logits
61
+
62
+ if logits.shape[-1] == 1:
63
+ scores = logits.squeeze(-1).float().cpu().tolist()
64
+ else:
65
+ # If binary classification style, use the positive class
66
+ scores = logits[:, -1].float().cpu().tolist()
67
+
68
+ ranked = sorted(
69
+ [{"index": i, "score": s, "document": documents[i]} for i, s in enumerate(scores)],
70
+ key=lambda x: x["score"],
71
+ reverse=True,
72
+ )
73
+
74
+ return {
75
+ "scores": scores,
76
+ "ranked": ranked,
77
+ }