zakerytclarke committed on
Commit
21235f2
·
verified ·
1 Parent(s): 2dca1ac

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +106 -0
handler.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # handler.py
2
+ from typing import Any, Dict, List, Union
3
+ import torch
4
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
+
6
+
7
+ MAX_INPUT_TOKENS = 512
8
+
9
+
10
class EndpointHandler:
    """
    HF Inference Endpoints custom handler mirroring the original Colab code:
    - slow tokenizer (use_fast=False)
    - Seq2Seq model
    - deterministic generation by default (do_sample=False)
    - decode with skip_special_tokens=True
    - if an input exceeds MAX_INPUT_TOKENS, keep only the MOST RECENT tokens
      (left-truncation, applied per sequence BEFORE padding — see __call__)
    """

    def __init__(self, path: str = ""):
        """Load tokenizer + model from *path* and pin the model to CPU eval mode.

        The slow tokenizer (use_fast=False) matches the original working code
        and avoids fast-tokenizer initialization issues on HF endpoints.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path, use_fast=False)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)

        self.model.eval()
        self.device = torch.device("cpu")
        self.model.to(self.device)

    @torch.inference_mode()
    def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, str], List[Dict[str, str]]]:
        """
        Request schema:
            {
                "inputs": "<full prompt string>" OR ["<prompt1>", "<prompt2>", ...],
                "parameters": { ... optional generate kwargs ... }
            }

        Response schema (kept simple):
            - single input -> {"generated_text": "..."}
            - list inputs  -> [{"generated_text": "..."}, ...]

        Raises:
            ValueError: if the required "inputs" field is missing.
        """
        if "inputs" not in data:
            raise ValueError("Missing required field 'inputs'.")

        inputs = data["inputs"]
        # Copy so the .pop() calls below never mutate the caller's request dict.
        params = dict(data.get("parameters") or {})

        # Normalize to a batch of prompts.
        if isinstance(inputs, str):
            prompts = [inputs]
            single = True
        else:
            prompts = list(inputs)
            single = False

        # --- Tokenize WITHOUT padding, left-truncate EACH sequence, then pad.
        # BUGFIX: the previous version padded the batch first (right padding)
        # and then left-truncated the padded matrix. For a batch mixing short
        # and long prompts, that sliced real leading tokens off the short
        # sequences while keeping their trailing PAD tokens. Truncating each
        # sequence before padding keeps the most recent real tokens for every
        # prompt, as intended.
        token_lists = self.tokenizer(prompts, truncation=False)["input_ids"]
        token_lists = [ids[-MAX_INPUT_TOKENS:] for ids in token_lists]

        batch = self.tokenizer.pad(
            {"input_ids": token_lists},
            padding=True,
            return_tensors="pt",
        )
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)

        # Defaults match the original code: generate(..., do_sample=False),
        # overridable via "parameters". Every remaining key in params is
        # forwarded to generate() unchanged (max_new_tokens, num_beams,
        # temperature, top_p, top_k, ...).
        gen_kwargs: Dict[str, Any] = {"do_sample": params.pop("do_sample", False)}
        gen_kwargs.update(params)

        outputs = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **gen_kwargs,
        )

        texts = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

        result = [{"generated_text": t} for t in texts]
        return result[0] if single else result