crbo committed on
Commit
11c2f83
·
verified ·
1 Parent(s): b5c1098

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. handler.py +93 -0
  2. requirements.txt +5 -0
handler.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom handler for LightOnOCR-2-1B on HuggingFace Inference Endpoints.
3
+ Requires transformers >= 5.0.0
4
+
5
+ Deployment options:
6
+ A) Fork lightonai/LightOnOCR-2-1B and add this file → uses model_dir
7
+ B) New repo with just handler.py + requirements.txt → loads from Hub
8
+ """
9
+ import base64
10
+ import io
11
+ import os
12
+ from typing import Any, Dict
13
+
14
+ import torch
15
+ from PIL import Image
16
+ from transformers import LightOnOcrForConditionalGeneration, LightOnOcrProcessor
17
+
18
+ MODEL_ID = "lightonai/LightOnOCR-2-1B"
19
+
20
+
class EndpointHandler:
    """Custom Inference Endpoints handler for LightOnOCR-2-1B.

    Accepts one image (base64 string or http(s) URL) plus an optional text
    prompt, runs the OCR model, and returns the generated text. Errors are
    reported as an ``{"error": ...}`` payload rather than raised, matching
    the Inference Endpoints convention of returning JSON to the caller.
    """

    def __init__(self, model_dir: str, **kwargs: Any):
        """Load the model and processor once at endpoint startup.

        Args:
            model_dir: Directory mounted by the endpoint. If it contains
                model weights (detected via ``config.json`` — deployment
                option A, a fork of the model repo), load from there;
                otherwise load ``MODEL_ID`` from the Hub (option B).
            **kwargs: Ignored; accepted for forward compatibility with the
                Inference Endpoints handler contract.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        dtype = torch.bfloat16 if device == "cuda" else torch.float32

        self.device = device
        self.dtype = dtype

        # Use model_dir if it contains model weights (fork), otherwise load from Hub.
        config_path = os.path.join(model_dir, "config.json")
        source = model_dir if os.path.exists(config_path) else MODEL_ID

        # NOTE: transformers >= 5.0 (which this handler pins) renamed the
        # ``torch_dtype`` keyword to ``dtype``; the old name was removed, so
        # passing ``torch_dtype=`` would raise a TypeError here.
        self.model = LightOnOcrForConditionalGeneration.from_pretrained(
            source, dtype=dtype
        ).to(device)
        self.processor = LightOnOcrProcessor.from_pretrained(source)

    @staticmethod
    def _decode_base64_image(payload: str) -> "Image.Image":
        """Decode a base64-encoded image into an RGB PIL image.

        Raises on malformed base64 or unreadable image bytes; the caller
        translates the exception into an error payload.
        """
        return Image.open(io.BytesIO(base64.b64decode(payload))).convert("RGB")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run OCR on a single image.

        Accepted payload shapes (either at the top level or under ``inputs``):
          * a bare base64 string, or
          * a dict with ``image`` (base64 or http(s) URL) or ``url``,
            plus optional ``prompt`` and ``max_new_tokens``.

        Returns:
            ``{"generated_text": <str>}`` on success,
            ``{"error": <str>}`` when no image was provided or it could
            not be decoded.
        """
        inputs_data = data.get("inputs", data)

        # --- Handle image input ---
        image = None
        image_url = None

        try:
            if isinstance(inputs_data, str):
                # Direct base64 string.
                image = self._decode_base64_image(inputs_data)
            elif isinstance(inputs_data, dict):
                if "image" in inputs_data:
                    img_input = inputs_data["image"]
                    # Only str payloads can be URLs; guard so that bytes
                    # input doesn't crash on .startswith().
                    if isinstance(img_input, str) and img_input.startswith(
                        ("http://", "https://")
                    ):
                        image_url = img_input
                    else:
                        image = self._decode_base64_image(img_input)
                elif "url" in inputs_data:
                    image_url = inputs_data["url"]
        except Exception as exc:
            # Malformed base64 or bytes PIL cannot parse: report instead of
            # letting the endpoint return an opaque 500.
            return {"error": f"Could not decode image: {exc}"}

        if image is None and image_url is None:
            return {"error": "No image provided. Send 'image' (base64 or URL) or 'url' in inputs."}

        # --- Build conversation ---
        prompt = inputs_data.get("prompt", None) if isinstance(inputs_data, dict) else None
        content = []
        if image_url:
            content.append({"type": "image", "url": image_url})
        elif image:
            content.append({"type": "image", "image": image})

        if prompt:
            content.append({"type": "text", "text": prompt})

        conversation = [{"role": "user", "content": content}]

        # --- Process & generate ---
        max_tokens = int(inputs_data.get("max_new_tokens", 4096)) if isinstance(inputs_data, dict) else 4096

        inputs = self.processor.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        )
        # Cast floating-point tensors (pixel values) to the model dtype;
        # integer tensors (token ids) only move device.
        inputs = {
            k: v.to(device=self.device, dtype=self.dtype) if v.is_floating_point() else v.to(self.device)
            for k, v in inputs.items()
        }

        output_ids = self.model.generate(**inputs, max_new_tokens=max_tokens)
        # Strip the prompt tokens; keep only the newly generated ones.
        generated_ids = output_ids[0, inputs["input_ids"].shape[1]:]
        output_text = self.processor.decode(generated_ids, skip_special_tokens=True)

        return {"generated_text": output_text}
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ transformers>=5.0.0
2
+ pillow
3
+ pypdfium2
4
+ torch
5
+ accelerate