hutlim
/

hutlim commited on
Commit
43d8db7
·
verified ·
1 Parent(s): f775139

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +39 -5
handler.py CHANGED
@@ -1,4 +1,6 @@
 
1
  import os
 
2
  from pathlib import Path
3
  from typing import Any, Dict, List
4
 
@@ -8,7 +10,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
8
 
9
  class EndpointHandler:
10
  def __init__(self, path: str = ""):
11
- model_dir = Path(path or os.getenv("HF_MODEL_DIR", ".")).resolve()
12
 
13
  if not model_dir.exists():
14
  raise FileNotFoundError(f"Model directory does not exist: {model_dir}")
@@ -49,7 +51,7 @@ class EndpointHandler:
49
 
50
  self.model = AutoModelForCausalLM.from_pretrained(
51
  str(model_dir),
52
- torch_dtype=self.torch_dtype,
53
  trust_remote_code=True,
54
  ).to(self.device).eval()
55
 
@@ -65,7 +67,10 @@ class EndpointHandler:
65
  self.token_true_id = yes_ids[0]
66
  self.token_false_id = no_ids[0]
67
 
68
- self.max_length = 8192
 
 
 
69
 
70
  self.prefix = (
71
  "<|im_start|>system\n"
@@ -114,6 +119,17 @@ class EndpointHandler:
114
  probs = torch.nn.functional.softmax(pair_scores, dim=1)[:, 1]
115
  return probs.tolist()
116
 
 
 
 
 
 
 
 
 
 
 
 
117
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
118
  payload = data.get("inputs", data)
119
 
@@ -128,9 +144,27 @@ class EndpointHandler:
128
  if not isinstance(documents, list) or len(documents) == 0:
129
  raise ValueError("`documents` must be a non-empty list of strings.")
130
 
 
 
 
 
 
 
131
  pairs = [self._format_one(instruction, query, doc) for doc in documents]
132
- model_inputs = self._process_inputs(pairs)
133
- scores = self._score(model_inputs)
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
  results = []
136
  for i, (doc, score) in enumerate(zip(documents, scores)):
 
1
+ import gc
2
  import os
3
+ import threading
4
  from pathlib import Path
5
  from typing import Any, Dict, List
6
 
 
10
 
11
  class EndpointHandler:
12
  def __init__(self, path: str = ""):
13
+ model_dir = Path(path or os.getenv("HF_MODEL_DIR", "")).resolve()
14
 
15
  if not model_dir.exists():
16
  raise FileNotFoundError(f"Model directory does not exist: {model_dir}")
 
51
 
52
  self.model = AutoModelForCausalLM.from_pretrained(
53
  str(model_dir),
54
+ dtype=self.torch_dtype,
55
  trust_remote_code=True,
56
  ).to(self.device).eval()
57
 
 
67
  self.token_true_id = yes_ids[0]
68
  self.token_false_id = no_ids[0]
69
 
70
+ self.max_length = int(os.getenv("HANDLER_MAX_LENGTH", "8192"))
71
+ self.batch_size = int(os.getenv("HANDLER_BATCH_SIZE", "8"))
72
+ self.max_documents = int(os.getenv("HANDLER_MAX_DOCUMENTS", "64"))
73
+ self._semaphore = threading.Semaphore(int(os.getenv("HANDLER_MAX_CONCURRENT", "5")))
74
 
75
  self.prefix = (
76
  "<|im_start|>system\n"
 
119
  probs = torch.nn.functional.softmax(pair_scores, dim=1)[:, 1]
120
  return probs.tolist()
121
 
122
+ def _score_in_batches(self, pairs: List[str]) -> List[float]:
123
+ all_scores = []
124
+ for i in range(0, len(pairs), self.batch_size):
125
+ batch = pairs[i : i + self.batch_size]
126
+ model_inputs = self._process_inputs(batch)
127
+ scores = self._score(model_inputs)
128
+ all_scores.extend(scores)
129
+ del model_inputs
130
+ gc.collect()
131
+ return all_scores
132
+
133
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
134
  payload = data.get("inputs", data)
135
 
 
144
  if not isinstance(documents, list) or len(documents) == 0:
145
  raise ValueError("`documents` must be a non-empty list of strings.")
146
 
147
+ if len(documents) > self.max_documents:
148
+ raise ValueError(
149
+ f"`documents` exceeds max allowed ({self.max_documents}). "
150
+ f"Got {len(documents)}."
151
+ )
152
+
153
  pairs = [self._format_one(instruction, query, doc) for doc in documents]
154
+ acquired = self._semaphore.acquire(timeout=int(os.getenv("HANDLER_QUEUE_TIMEOUT", "60")))
155
+ if not acquired:
156
+ raise RuntimeError(
157
+ "Server is busy. Another request is being processed. Please retry."
158
+ )
159
+ try:
160
+ scores = self._score_in_batches(pairs)
161
+ except MemoryError:
162
+ gc.collect()
163
+ raise RuntimeError(
164
+ "Out of memory while scoring. Try sending fewer or shorter documents."
165
+ )
166
+ finally:
167
+ self._semaphore.release()
168
 
169
  results = []
170
  for i, (doc, score) in enumerate(zip(documents, scores)):