# inference.py — single-page document inference for the Experiments pipeline.
# (Header reconstructed from repository-page residue: "Update inference.py",
# commit f0f5725.)
import torch
from model_loader import model, processor, device
from processor_utils import load_input
from prompt import get_prompt
import json
def process_document(image):
    """Run the vision-language model over one page image and return JSON text.

    Parameters
    ----------
    image : a single page image in whatever form ``processor`` accepts
        (presumably a PIL image — confirm against ``load_input``, which the
        commented-out original used to split PDFs into pages).

    Returns
    -------
    str
        The model's reply trimmed to the outermost ``{...}`` span.  If that
        span is not parseable JSON, a serialized ``{"error": [<raw text>]}``
        object is returned instead, so callers always receive valid JSON text.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": get_prompt()},
            ],
        }
    ]
    # tokenize=False -> get the rendered prompt back as a string;
    # add_generation_prompt=True -> append the assistant-turn marker so the
    # model continues as the assistant rather than echoing the user turn.
    text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = processor(
        text=[text],
        images=[image],
        return_tensors="pt",
    ).to(device)
    # Inference only: disable autograd so generate() doesn't build a graph.
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=1500,
            do_sample=False,  # deterministic decoding; sampling produced stray text
        )
    # Slice off the prompt tokens so only newly generated ones are decoded.
    generated_ids = output[0][inputs.input_ids.shape[-1]:]
    response = processor.decode(
        generated_ids,
        skip_special_tokens=True,
    ).strip()
    # Trim to the outermost JSON object, if one is present.
    start = response.find("{")
    end = response.rfind("}") + 1
    # BUG FIX: the original tested `end != -1`, which is always true because
    # end is rfind()+1 (missing "}" gives end == 0 and response[start:0] == "").
    if start != -1 and end > start:
        response = response[start:end]
    try:
        json.loads(response)  # validate only; callers expect the string form
        return response
    except json.JSONDecodeError:
        # BUG FIX: the original built this error payload, discarded it, and
        # returned the raw invalid text anyway; now callers always get JSON.
        return json.dumps({"error": [response]})