defford committed on
Commit
3877d00
·
verified ·
1 Parent(s): 57bc9be

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +11 -17
handler.py CHANGED
@@ -1,15 +1,15 @@
1
  import torch
2
  from PIL import Image
3
- from transformers import AutoProcessor, AutoModelForVision2Seq
4
  import io
5
  import base64
 
 
6
 
7
  class EndpointHandler():
8
  def __init__(self, path=""):
9
- # We MUST use trust_remote_code=True because the architecture
10
- # is defined in the files you downloaded, not in the library itself.
11
- self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
12
- self.model = AutoModelForVision2Seq.from_pretrained(
13
  path,
14
  trust_remote_code=True,
15
  device_map="auto",
@@ -18,27 +18,21 @@ class EndpointHandler():
18
  self.model.eval()
19
 
20
  def __call__(self, data):
21
- # Handle the input format from Google Sheets
22
  inputs_data = data.pop("inputs", data)
23
-
24
- # If the data comes in as a string (base64)
25
- if isinstance(inputs_data, str):
26
- image_bytes = base64.b64decode(inputs_data)
27
- else:
28
- # Handle direct bytes if necessary
29
- image_bytes = inputs_data
30
-
31
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
32
 
33
- # The prompt that tells the AI what to look for
34
- prompt = "Extract all line items. Return a JSON array of objects with: date, vendor, description, qty, price, total."
35
 
36
- # Process the image and text
37
  model_inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(self.model.device)
38
 
39
  with torch.no_grad():
40
  generated_ids = self.model.generate(**model_inputs, max_new_tokens=1024)
41
 
 
42
  result = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
43
 
44
  return [{"generated_text": result}]
 
1
  import torch
2
  from PIL import Image
 
3
  import io
4
  import base64
5
+ # We use the explicit classes to avoid the 'Auto' detection errors
6
+ from transformers import GlmOcrProcessor, GlmOcrForConditionalGeneration
7
 
8
  class EndpointHandler():
9
  def __init__(self, path=""):
10
+ # Explicitly load the processor and model
11
+ self.processor = GlmOcrProcessor.from_pretrained(path, trust_remote_code=True)
12
+ self.model = GlmOcrForConditionalGeneration.from_pretrained(
 
13
  path,
14
  trust_remote_code=True,
15
  device_map="auto",
 
18
  self.model.eval()
19
 
20
  def __call__(self, data):
21
+ # Extract base64 from Google Apps Script
22
  inputs_data = data.pop("inputs", data)
23
+ image_bytes = base64.b64decode(inputs_data)
 
 
 
 
 
 
 
24
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
25
 
26
+ # Specific prompt for structured bookkeeping
27
+ prompt = "Identify Date, Vendor, and list every Item with description, qty, and price. Return as a JSON array."
28
 
29
+ # Process the input
30
  model_inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(self.model.device)
31
 
32
  with torch.no_grad():
33
  generated_ids = self.model.generate(**model_inputs, max_new_tokens=1024)
34
 
35
+ # Decode output
36
  result = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
37
 
38
  return [{"generated_text": result}]