samtuckervegan committed on
Commit
d942f82
·
verified ·
1 Parent(s): 6d20ad5

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +25 -15
handler.py CHANGED
@@ -25,24 +25,34 @@ class EndpointHandler:
25
  prompt = data.get("inputs", "")
26
  if not prompt:
27
  return {"error": "Missing 'inputs' field."}
28
-
29
  defaults = {
30
  "max_new_tokens": 100,
31
  "do_sample": True,
32
  "temperature": 0.7,
33
- "top_p": 0.9
 
34
  }
35
-
36
  generation_args = {**defaults, **data.get("parameters", {})}
37
-
38
- outputs = self.generator(prompt, **generation_args)
39
-
40
- return {
41
- "choices": [{
42
- "message": {
43
- "role": "assistant",
44
- "content": outputs[0]["generated_text"].strip()
45
- },
46
- "finish_reason": "stop"
47
- }]
48
- }
 
 
 
 
 
 
 
 
 
 
25
  prompt = data.get("inputs", "")
26
  if not prompt:
27
  return {"error": "Missing 'inputs' field."}
28
+
29
  defaults = {
30
  "max_new_tokens": 100,
31
  "do_sample": True,
32
  "temperature": 0.7,
33
+ "top_p": 0.9,
34
+ "eos_token_id": self.tokenizer.eos_token_id # ✅ Stop at <|eot_id|>
35
  }
36
+
37
  generation_args = {**defaults, **data.get("parameters", {})}
38
+
39
+ try:
40
+ outputs = self.generator(prompt, **generation_args)
41
+ output_text = outputs[0]["generated_text"].strip()
42
+
43
+ finish_reason = "stop"
44
+ if len(self.tokenizer.encode(output_text)) >= generation_args["max_new_tokens"]:
45
+ finish_reason = "length"
46
+
47
+ return {
48
+ "choices": [{
49
+ "message": {
50
+ "role": "assistant",
51
+ "content": output_text
52
+ },
53
+ "finish_reason": finish_reason
54
+ }]
55
+ }
56
+
57
+ except Exception as e:
58
+ return {"error": str(e)}