| from transformers import AutoProcessor, AutoModelForCausalLM |
| import gradio as gr |
| import torch |
|
|
| |
# Image preprocessor is loaded from the "microsoft/git-base" checkpoint.
processor = AutoProcessor.from_pretrained("microsoft/git-base")

# Captioning weights are loaded from the current working directory ("./").
model = AutoModelForCausalLM.from_pretrained("./")
|
|
def predict(image):
    """Generate a caption for ``image`` using the module-level model.

    Args:
        image: The picture to caption, in a form the module-level
            ``processor`` accepts via ``images=``. ``None`` is treated as
            "nothing was supplied".

    Returns:
        The decoded caption string on success. On any failure (including a
        missing image) a string starting with ``"Error: "`` is returned
        instead; exceptions are never propagated to the caller.
    """
    # Nothing to caption: answer directly rather than letting the processor
    # fail with a less readable library error.
    if image is None:
        return "Error: no image provided"

    try:
        # Turn the image into PyTorch tensors the model can consume.
        inputs = processor(images=image, return_tensors="pt")

        # Run on the GPU when one is available; inputs and model must live
        # on the same device.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        inputs = {key: value.to(device) for key, value in inputs.items()}
        model.to(device)

        # Generate caption token ids, then decode the first (only) sequence.
        outputs = model.generate(**inputs)
        caption = processor.batch_decode(outputs, skip_special_tokens=True)[0]

        return caption

    except Exception as e:
        # Boundary with the UI: report the failure as text instead of raising.
        print("Error during prediction:", str(e))
        return "Error: " + str(e)
|
|
| |
# Minimal Blocks layout: an image input, a trigger button and a caption display.
with gr.Blocks() as demo:
    image = gr.Image(type="pil")
    predict_btn = gr.Button("Predict", variant="primary")
    output = gr.Label(label="Generated Caption")

    # Clicking the button passes the image to predict() and shows the
    # returned text in the output label.
    predict_btn.click(fn=predict, inputs=[image], outputs=[output])


# Start the web app only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()
| |
| |
|
|