rishi38 committed on
Commit 51876e1 · verified · 1 Parent(s): b777c1a

Remove handler.py

Files changed (1)
  1. handler.py +0 -102
handler.py DELETED
@@ -1,102 +0,0 @@
- import json
- from typing import Any
-
- import torch
- from peft import PeftModel
- from transformers import AutoModelForCausalLM, AutoTokenizer
-
-
- class EndpointHandler:
-     """
-     Custom Inference Endpoint handler for adapter-only LoRA repos.
-     """
-
-     def __init__(self, model_dir: str, **kwargs: Any):
-         adapter_cfg_path = f"{model_dir}/adapter_config.json"
-         with open(adapter_cfg_path, "r", encoding="utf-8") as f:
-             adapter_cfg = json.load(f)
-
-         base_model_id = adapter_cfg.get("base_model_name_or_path", "Qwen/Qwen3-4B")
-
-         # Endpoints are usually more stable with the canonical base model id.
-         if "unsloth" in base_model_id and "bnb-4bit" in base_model_id:
-             # Try to infer the base model if it's an unsloth bnb-4bit one
-             # For Qwen3-4B-unsloth-bnb-4bit, the base is likely Qwen/Qwen3-4B
-             if "Qwen3-4B" in base_model_id:
-                 base_model_id = "Qwen/Qwen3-4B"
-             elif "Qwen2.5" in base_model_id:
-                 base_model_id = "Qwen/Qwen2.5-7B"  # Or whatever the base is
-
-         self.tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
-         dtype = torch.float16 if torch.cuda.is_available() else torch.float32
-
-         base_model = AutoModelForCausalLM.from_pretrained(
-             base_model_id,
-             torch_dtype=dtype,
-             device_map="auto" if torch.cuda.is_available() else None,
-             low_cpu_mem_usage=True,
-         )
-         self.model = PeftModel.from_pretrained(base_model, model_dir)
-         self.model.eval()
-         if not torch.cuda.is_available():
-             self.model.to("cpu")
-
-     def _format_prompt(self, inputs: Any) -> str:
-         if isinstance(inputs, str):
-             return inputs
-
-         # Support chat-style inputs:
-         # [{"role":"system","content":"..."},{"role":"user","content":"..."}]
-         if isinstance(inputs, list) and inputs and isinstance(inputs[0], dict) and "role" in inputs[0]:
-             try:
-                 return self.tokenizer.apply_chat_template(
-                     inputs,
-                     add_generation_prompt=True,
-                     tokenize=False,
-                     enable_thinking=False,
-                 )
-             except TypeError:
-                 return self.tokenizer.apply_chat_template(
-                     inputs,
-                     add_generation_prompt=True,
-                     tokenize=False,
-                 )
-
-         if isinstance(inputs, dict):
-             return inputs.get("prompt") or inputs.get("text") or json.dumps(inputs)
-
-         return str(inputs)
-
-     def __call__(self, data: Any) -> dict[str, str]:
-         payload = data if isinstance(data, dict) else {"inputs": data}
-         params = payload.get("parameters", {}) or {}
-
-         prompt = self._format_prompt(payload.get("inputs", ""))
-         max_new_tokens = int(params.get("max_new_tokens", 128))
-         temperature = float(params.get("temperature", 0.2))
-         top_p = float(params.get("top_p", 0.95))
-         top_k = int(params.get("top_k", 0))
-         if top_k < 0:
-             top_k = 0
-
-         enc = self.tokenizer([prompt], return_tensors="pt")
-         device = next(self.model.parameters()).device
-         enc = {k: v.to(device) for k, v in enc.items()}
-
-         with torch.no_grad():
-             out = self.model.generate(
-                 **enc,
-                 max_new_tokens=max_new_tokens,
-                 do_sample=temperature > 0,
-                 temperature=max(temperature, 1e-5),
-                 top_p=top_p,
-                 top_k=top_k,
-                 eos_token_id=self.tokenizer.eos_token_id,
-                 pad_token_id=self.tokenizer.pad_token_id,
-             )
-
-         generated_text = self.tokenizer.decode(
-             out[0][enc["input_ids"].shape[1]:],
-             skip_special_tokens=True,
-         )
-         return {"generated_text": generated_text}