defford commited on
Commit
638efb3
·
verified ·
1 Parent(s): b66e41a

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +33 -0
handler.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ from transformers import AutoProcessor, AutoModelForVision2Seq
4
+ import io
5
+ import base64
6
+
7
class EndpointHandler():
    """Inference-endpoint handler for a GLM-OCR style vision-to-text model.

    Loads the processor and model once at startup, then turns a
    base64-encoded receipt image into structured text (date, vendor,
    line items) via a fixed extraction prompt.
    """

    def __init__(self, path=""):
        # trust_remote_code is required: GLM-OCR ships a custom
        # architecture not bundled with transformers.
        self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
        self.model = AutoModelForVision2Seq.from_pretrained(
            path,
            trust_remote_code=True,
            device_map="auto",           # place weights on available GPU(s)/CPU
            torch_dtype=torch.bfloat16,  # halve memory vs fp32
        )
        self.model.eval()

    def __call__(self, data):
        """Run OCR extraction on a single base64-encoded image.

        Args:
            data: request payload. The image is read from data["inputs"]
                (a dict without that key, or a bare string payload, is
                treated as the image itself). Both raw base64 and
                "data:image/...;base64,..." URIs are accepted.

        Returns:
            [{"generated_text": str}] -- the decoded model output.

        Raises:
            ValueError: if the payload is not a decodable base64 image.
        """
        payload = data.pop("inputs", data) if isinstance(data, dict) else data
        if not isinstance(payload, str):
            # Previously a missing "inputs" key fed the whole dict to
            # b64decode, raising an opaque TypeError. Fail clearly instead.
            raise ValueError("Expected a base64-encoded image string under 'inputs'.")

        # Web clients (e.g. Google Sheets scripts, browsers) often send a
        # data URI; strip the "data:image/...;base64," prefix before decoding.
        if payload.lstrip().startswith("data:") and "," in payload:
            payload = payload.split(",", 1)[1]

        try:
            image_bytes = base64.b64decode(payload)
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except Exception as exc:
            raise ValueError("Payload is not a valid base64-encoded image.") from exc

        # Fixed extraction prompt for receipt parsing.
        prompt = "Identify Date, Vendor, and list every Item with description, qty, and price. Return as JSON."
        model_inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(self.model.device)

        with torch.no_grad():
            generated_ids = self.model.generate(**model_inputs, max_new_tokens=1024)

        # NOTE(review): some vision2seq checkpoints echo the prompt in the
        # decoded output — confirm whether trimming is needed for this model.
        result = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return [{"generated_text": result}]