hutlim
/

hutlim committed on
Commit
6157eb9
·
verified ·
1 Parent(s): e2f2cfb

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +97 -41
handler.py CHANGED
@@ -1,59 +1,115 @@
 
1
  from typing import Any, Dict, List
 
2
  import torch
3
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
4
 
5
  class EndpointHandler:
6
  def __init__(self, path: str = ""):
 
7
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- self.tokenizer = AutoTokenizer.from_pretrained(path)
10
- self.model = AutoModelForSequenceClassification.from_pretrained(path).to(self.device)
11
- self.model.eval()
12
 
13
- if self.tokenizer.pad_token is None:
14
- if self.tokenizer.eos_token is not None:
15
- self.tokenizer.pad_token = self.tokenizer.eos_token
16
- else:
17
- self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
18
- self.model.resize_token_embeddings(len(self.tokenizer))
19
 
20
- self.model.config.pad_token_id = self.tokenizer.pad_token_id
 
 
 
 
 
 
 
21
 
22
- @torch.inference_mode()
23
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
24
- payload = data["inputs"]
25
- params = data.get("parameters", {})
26
 
 
 
 
 
27
  query = payload["query"]
28
  documents = payload["documents"]
29
- instruction = params.get("instruction")
30
 
31
- if instruction:
32
- query_text = f"Instruct: {instruction}\nQuery: {query}"
33
- else:
34
- query_text = query
35
 
36
- pairs = [[query_text, doc] for doc in documents]
 
 
37
 
38
- enc = self.tokenizer(
39
- pairs,
40
- padding=True,
41
- truncation=True,
42
- return_tensors="pt"
43
- ).to(self.device)
44
-
45
- outputs = self.model(**enc)
46
- logits = outputs.logits
47
-
48
- if logits.shape[-1] == 1:
49
- scores = logits.squeeze(-1).float().cpu().tolist()
50
- else:
51
- scores = logits[:, -1].float().cpu().tolist()
52
-
53
- ranked = sorted(
54
- [{"index": i, "score": s, "document": documents[i]} for i, s in enumerate(scores)],
55
- key=lambda x: x["score"],
56
- reverse=True,
57
- )
58
 
59
- return {"scores": scores, "ranked": ranked}
 
 
 
 
 
 
 
 
1
+ import os
2
  from typing import Any, Dict, List
3
+
4
  import torch
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM
6
+
7
 
8
  class EndpointHandler:
9
  def __init__(self, path: str = ""):
10
+ model_dir = path if path else os.getenv("HF_MODEL_DIR", ".")
11
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
12
+ self.dtype = torch.float16 if self.device == "cuda" else torch.float32
13
+
14
+ self.tokenizer = AutoTokenizer.from_pretrained(
15
+ model_dir,
16
+ padding_side="left",
17
+ trust_remote_code=True,
18
+ )
19
+ self.model = AutoModelForCausalLM.from_pretrained(
20
+ model_dir,
21
+ torch_dtype=self.dtype,
22
+ trust_remote_code=True,
23
+ ).to(self.device).eval()
24
+
25
+ self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
26
+ self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
27
+
28
+ self.max_length = 8192
29
+
30
+ self.prefix = (
31
+ '<|im_start|>system\n'
32
+ 'Judge whether the Document meets the requirements based on the Query '
33
+ 'and the Instruct provided. Note that the answer can only be "yes" or "no".'
34
+ '<|im_end|>\n'
35
+ '<|im_start|>user\n'
36
+ )
37
+ self.suffix = '<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n'
38
+
39
+ self.prefix_tokens = self.tokenizer.encode(self.prefix, add_special_tokens=False)
40
+ self.suffix_tokens = self.tokenizer.encode(self.suffix, add_special_tokens=False)
41
+
42
+ def _format_one(self, instruction: str, query: str, document: str) -> str:
43
+ return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {document}"
44
+
45
+ def _process_inputs(self, pairs: List[str]) -> Dict[str, torch.Tensor]:
46
+ inputs = self.tokenizer(
47
+ pairs,
48
+ padding=False,
49
+ truncation="longest_first",
50
+ return_attention_mask=False,
51
+ max_length=self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens),
52
+ )
53
+
54
+ for i, ids in enumerate(inputs["input_ids"]):
55
+ inputs["input_ids"][i] = self.prefix_tokens + ids + self.suffix_tokens
56
+
57
+ padded = self.tokenizer.pad(
58
+ inputs,
59
+ padding=True,
60
+ return_tensors="pt",
61
+ max_length=self.max_length,
62
+ )
63
 
64
+ for k in padded:
65
+ padded[k] = padded[k].to(self.device)
 
66
 
67
+ return padded
 
 
 
 
 
68
 
69
+ @torch.no_grad()
70
+ def _score(self, model_inputs: Dict[str, torch.Tensor]) -> List[float]:
71
+ logits = self.model(**model_inputs).logits[:, -1, :]
72
+ false_scores = logits[:, self.token_false_id]
73
+ true_scores = logits[:, self.token_true_id]
74
+ pair_scores = torch.stack([false_scores, true_scores], dim=1)
75
+ probs = torch.nn.functional.softmax(pair_scores, dim=1)[:, 1]
76
+ return probs.tolist()
77
 
 
78
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
79
+ payload = data.get("inputs", data)
 
80
 
81
+ instruction = payload.get(
82
+ "instruction",
83
+ "Given a web search query, retrieve relevant passages that answer the query",
84
+ )
85
  query = payload["query"]
86
  documents = payload["documents"]
87
+ return_documents = payload.get("return_documents", True)
88
 
89
+ if not isinstance(documents, list) or len(documents) == 0:
90
+ raise ValueError("`documents` must be a non-empty list of strings.")
 
 
91
 
92
+ pairs = [self._format_one(instruction, query, doc) for doc in documents]
93
+ model_inputs = self._process_inputs(pairs)
94
+ scores = self._score(model_inputs)
95
 
96
+ results = []
97
+ for i, (doc, score) in enumerate(zip(documents, scores)):
98
+ item = {
99
+ "index": i,
100
+ "relevance_score": float(score),
101
+ }
102
+ if return_documents:
103
+ item["document"] = doc
104
+ results.append(item)
105
+
106
+ results.sort(key=lambda x: x["relevance_score"], reverse=True)
 
 
 
 
 
 
 
 
 
107
 
108
+ return {
109
+ "results": results,
110
+ "meta": {
111
+ "instruction": instruction,
112
+ "query": query,
113
+ "count": len(documents),
114
+ },
115
+ }