Spaces:
Paused
Paused
import json

import torch

from model_loader import model, processor, device
from processor_utils import load_input
from prompt import get_prompt
def _extract_json_span(text: str) -> str:
    """Return the substring of *text* from the first '{' to the last '}' inclusive.

    If no well-formed span exists (no '{', or no '}' after it), *text* is
    returned unchanged.  This guards against the original off-by-one bug where
    ``rfind("}") + 1`` made the not-found sentinel ``0`` instead of ``-1``,
    silently truncating the response to an empty string.
    """
    start = text.find("{")
    end = text.rfind("}")
    if start != -1 and end > start:
        return text[start:end + 1]
    return text


def process_document(image) -> str:
    """Run the vision-language model over *image* and return its text response.

    The image and the prompt from ``get_prompt()`` are packed into a chat
    message, rendered with the processor's chat template, and passed to
    ``model.generate``.  The decoded output is trimmed to the outermost JSON
    object (if one is present) and returned as a string.

    Parameters
    ----------
    image : PIL.Image or compatible input accepted by the processor.
        # assumes a single page/image — TODO confirm against callers

    Returns
    -------
    str
        The (hopefully JSON) model response.  Returned as-is even when it is
        not valid JSON, matching the original behavior.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": get_prompt()},
            ],
        }
    ]

    # tokenize=False -> return the rendered template as a string;
    # add_generation_prompt=True -> append the assistant-turn marker.
    text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    inputs = processor(
        text=[text],
        images=[image],
        return_tensors="pt",
    ).to(device)

    # inference_mode avoids autograd bookkeeping during generation.
    with torch.inference_mode():
        output = model.generate(
            **inputs,
            max_new_tokens=1500,
            do_sample=False,  # greedy decoding: deterministic, no sampling noise
        )

    # Strip the prompt tokens; keep only the newly generated ids.
    generated_ids = output[0][inputs.input_ids.shape[-1]:]
    response = processor.decode(
        generated_ids,
        skip_special_tokens=True,
    ).strip()

    # Trim any model chatter around the JSON object.
    response = _extract_json_span(response)

    try:
        # Validation only: the original computed a parsed dict but always
        # returned the string, so the return type is kept as str.
        json.loads(response)
    except json.JSONDecodeError:
        # NOTE(review): the original built an {"error": [response]} dict here
        # but never returned it; preserved the string return for callers.
        pass

    return response