defford committed on
Commit
0175efa
·
verified ·
1 Parent(s): acbbc4b

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +10 -29
handler.py CHANGED
@@ -2,54 +2,35 @@ import torch
2
  from PIL import Image
3
  import io
4
  import base64
5
- from transformers import AutoProcessor, AutoModelForImageTextToText
6
 
7
class EndpointHandler():
    """Custom inference handler: OCR a base64-encoded receipt image.

    Loads the repo's own (remote) processor/model code and answers each
    request with ``[{"generated_text": <decoded model output>}]``.
    """

    def __init__(self, path=""):
        # trust_remote_code=True tells transformers to load the custom
        # modeling/processing code shipped inside this model repo.
        self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
        self.model = AutoModelForImageTextToText.from_pretrained(
            path,
            trust_remote_code=True,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )
        # Inference only — switch off training-time behavior (dropout etc.).
        self.model.eval()

    def __call__(self, data):
        # Pull the base64 payload out of the Google Apps Script request body.
        payload = data.pop("inputs", data)
        raw = base64.b64decode(payload)
        receipt = Image.open(io.BytesIO(raw)).convert("RGB")

        # Bookkeeping prompt
        prompt = "Identify Date, Vendor, and list every Item with description, qty, and price. Return as a JSON array."

        # One user turn carrying the image plus the instruction, in the
        # chat-template shape GLM-OCR expects.
        user_turn = {
            "role": "user",
            "content": [
                {"type": "image", "image": receipt},
                {"type": "text", "text": prompt},
            ],
        }

        # Tokenize via the model's chat template and move tensors to the
        # device the model was dispatched to.
        model_inputs = self.processor.apply_chat_template(
            [user_turn],
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device)

        # Generate the reading without tracking gradients.
        with torch.no_grad():
            generated_ids = self.model.generate(**model_inputs, max_new_tokens=1024)

        # Decode the first (only) sequence back to text.
        text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return [{"generated_text": text}]
 
2
  from PIL import Image
3
  import io
4
  import base64
5
+ from transformers import GlmOcrProcessor, GlmOcrForConditionalGeneration
6
 
7
class EndpointHandler():
    """Inference handler for a GLM-OCR receipt-reading model.

    Expects a request payload whose ``"inputs"`` field holds a
    base64-encoded image; returns ``[{"generated_text": str}]``.
    """

    def __init__(self, path=""):
        # Native transformers classes for GLM-OCR — no trust_remote_code needed.
        self.processor = GlmOcrProcessor.from_pretrained(path)
        self.model = GlmOcrForConditionalGeneration.from_pretrained(
            path,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )
        # Inference only — disable dropout etc.
        self.model.eval()

    def __call__(self, data):
        """Run OCR on one base64-encoded receipt image.

        Parameters
        ----------
        data : dict
            Request payload; the image is read from ``data["inputs"]``.

        Returns
        -------
        list[dict]
            ``[{"generated_text": str}]``.

        Raises
        ------
        ValueError
            If the payload carries no base64 string under ``"inputs"``.
        """
        # Extract base64 image from the 'inputs' field sent by Google Sheets.
        inputs_data = data.pop("inputs", data)
        # FIX: the old fallback handed the whole dict to b64decode when
        # "inputs" was missing, raising an opaque TypeError. Fail clearly.
        if not isinstance(inputs_data, (str, bytes)):
            raise ValueError("Request payload must contain a base64-encoded image under 'inputs'")
        # Tolerate data-URL payloads ("data:image/png;base64,....") — a
        # common shape for images exported from Apps Script / browsers.
        if isinstance(inputs_data, str) and inputs_data.lstrip().startswith("data:"):
            inputs_data = inputs_data.split(",", 1)[-1]
        image_bytes = base64.b64decode(inputs_data)
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        # Bookkeeping prompt - Native formatting
        prompt = "Extract receipt items into JSON: [{date, vendor, description, qty, price, total}]"

        # Processor builds the joint image+text tensors; move them to the
        # device the model was dispatched to.
        inputs = self.processor(images=image, text=prompt, return_tensors="pt").to(self.model.device)

        with torch.no_grad():
            generated_ids = self.model.generate(**inputs, max_new_tokens=1024)

        # FIX: decode only the newly generated tokens; decoding the full
        # sequence would echo the prompt back into the response.
        prompt_len = inputs["input_ids"].shape[-1]
        result = self.processor.batch_decode(
            generated_ids[:, prompt_len:], skip_special_tokens=True
        )[0]

        return [{"generated_text": result}]