hutlim
/

hutlim committed on
Commit
e2f2cfb
·
verified ·
1 Parent(s): 336007e

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +12 -30
handler.py CHANGED
@@ -2,17 +2,14 @@ from typing import Any, Dict, List
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
 
5
-
6
  class EndpointHandler:
7
- def __init__(self, model_dir: str, **kwargs: Any) -> None:
8
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
 
10
- self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
11
- self.model = AutoModelForSequenceClassification.from_pretrained(model_dir)
12
- self.model.to(self.device)
13
  self.model.eval()
14
 
15
- # Fix batching when the tokenizer has no pad token
16
  if self.tokenizer.pad_token is None:
17
  if self.tokenizer.eos_token is not None:
18
  self.tokenizer.pad_token = self.tokenizer.eos_token
@@ -24,21 +21,12 @@ class EndpointHandler:
24
 
25
  @torch.inference_mode()
26
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
27
- """
28
- Expected request body:
29
- {
30
- "query": "What is the capital of China?",
31
- "documents": [
32
- "The capital of China is Beijing.",
33
- "Gravity is a force..."
34
- ],
35
- "instruction": "Given a web search query, retrieve relevant passages that answer the query"
36
- }
37
- """
38
 
39
- query = data["query"]
40
- documents = data["documents"]
41
- instruction = data.get("instruction")
42
 
43
  if instruction:
44
  query_text = f"Instruct: {instruction}\nQuery: {query}"
@@ -47,22 +35,19 @@ class EndpointHandler:
47
 
48
  pairs = [[query_text, doc] for doc in documents]
49
 
50
- inputs = self.tokenizer(
51
  pairs,
52
  padding=True,
53
  truncation=True,
54
  return_tensors="pt"
55
  ).to(self.device)
56
 
57
- outputs = self.model(**inputs)
58
-
59
- # Common reranker head shape: [batch, 1] or [batch, 2]
60
  logits = outputs.logits
61
 
62
  if logits.shape[-1] == 1:
63
  scores = logits.squeeze(-1).float().cpu().tolist()
64
  else:
65
- # If binary classification style, use the positive class
66
  scores = logits[:, -1].float().cpu().tolist()
67
 
68
  ranked = sorted(
@@ -71,7 +56,4 @@ class EndpointHandler:
71
  reverse=True,
72
  )
73
 
74
- return {
75
- "scores": scores,
76
- "ranked": ranked,
77
- }
 
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
 
 
5
  class EndpointHandler:
6
+ def __init__(self, path: str = ""):
7
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
8
 
9
+ self.tokenizer = AutoTokenizer.from_pretrained(path)
10
+ self.model = AutoModelForSequenceClassification.from_pretrained(path).to(self.device)
 
11
  self.model.eval()
12
 
 
13
  if self.tokenizer.pad_token is None:
14
  if self.tokenizer.eos_token is not None:
15
  self.tokenizer.pad_token = self.tokenizer.eos_token
 
21
 
22
  @torch.inference_mode()
23
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
24
+ payload = data["inputs"]
25
+ params = data.get("parameters", {})
 
 
 
 
 
 
 
 
 
26
 
27
+ query = payload["query"]
28
+ documents = payload["documents"]
29
+ instruction = params.get("instruction")
30
 
31
  if instruction:
32
  query_text = f"Instruct: {instruction}\nQuery: {query}"
 
35
 
36
  pairs = [[query_text, doc] for doc in documents]
37
 
38
+ enc = self.tokenizer(
39
  pairs,
40
  padding=True,
41
  truncation=True,
42
  return_tensors="pt"
43
  ).to(self.device)
44
 
45
+ outputs = self.model(**enc)
 
 
46
  logits = outputs.logits
47
 
48
  if logits.shape[-1] == 1:
49
  scores = logits.squeeze(-1).float().cpu().tolist()
50
  else:
 
51
  scores = logits[:, -1].float().cpu().tolist()
52
 
53
  ranked = sorted(
 
56
  reverse=True,
57
  )
58
 
59
+ return {"scores": scores, "ranked": ranked}