"""Serve LLaVA 1.5 behind a Gradio test UI and an OpenAI-style JSON API."""

import time
from io import BytesIO

import gradio as gr
import requests
import torch
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

# The processor bundles LLaVA's tokenizer and image preprocessor.
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

# Load the weights in half precision and let Accelerate place them on the
# available device(s).
model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    torch_dtype=torch.float16,
    device_map="auto",
)
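
# Note: the fp16 weights alone occupy roughly 14 GB (7B params x 2 bytes); on
# a GPU with less memory, device_map="auto" offloads layers to CPU RAM at a
# significant speed cost.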


def generate_response(user_message, system_prompt=None, image=None, max_tokens=1024, temperature=0.7):
    # LLaVA 1.5 was trained on prompts of the form
    # "USER: <image>\n{question} ASSISTANT:"; the <image> placeholder is only
    # valid when an image is actually supplied.
    image_tag = "<image>\n" if image is not None else ""
    if system_prompt:
        prompt = f"{system_prompt} USER: {image_tag}{user_message} ASSISTANT:"
    else:
        prompt = f"USER: {image_tag}{user_message} ASSISTANT:"

    # Keyword arguments avoid ambiguity in the processor signature; floating-
    # point tensors (the pixel values) are cast to fp16 to match the weights.
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device, torch.float16)

    with torch.inference_mode():
        output = model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),
            do_sample=True,
            temperature=temperature,
        )

    # The decoded sequence echoes the prompt, so keep only the assistant reply.
    response_text = processor.decode(output[0], skip_special_tokens=True)
    return response_text.split("ASSISTANT:")[-1].strip()
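
# Quick sanity check of the helper (the image URL is a placeholder, not a
# real asset):
#   img_bytes = requests.get("https://example.com/cat.jpg", timeout=30).content
#   img = Image.open(BytesIO(img_bytes)).convert("RGB")
#   print(generate_response("What is in this picture?", image=img, max_tokens=64))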


# Gradio Blocks has no hook for registering arbitrary JSON routes, so the
# endpoint lives on a plain FastAPI app; the Gradio UI is mounted onto the
# same app at the bottom of the file.
app = FastAPI()


@app.post("/api")
async def api_endpoint(request: Request):
    try:
        data = await request.json()
        user_message = data.get("user_message", "")
        system_prompt = data.get("system_prompt", None)
        image_url = data.get("image_url", None)
        max_tokens = data.get("max_tokens", 1024)
        temperature = data.get("temperature", 0.7)

        # Fetch the image, if one was supplied, and normalize it to RGB.
        image_data = None
        if image_url:
            image_response = requests.get(image_url, timeout=30)
            image_response.raise_for_status()
            image_data = Image.open(BytesIO(image_response.content)).convert("RGB")

        response_text = generate_response(
            user_message=user_message,
            system_prompt=system_prompt,
            image=image_data,
            max_tokens=max_tokens,
            temperature=temperature,
        )

        # Mirror the OpenAI chat-completion response shape.
        return JSONResponse({
            "id": f"chatcmpl-{int(time.time())}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": "llava-1.5-7b",
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": response_text,
                },
                "index": 0,
                "finish_reason": "stop",
            }],
        })

    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
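
# Example request once the server is running:
#   curl -X POST http://localhost:7860/api \
#     -H "Content-Type: application/json" \
#     -d '{"user_message": "Describe this image.", "image_url": "https://example.com/cat.jpg"}'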


# Browser UI for interactively exercising the same generation helper.
with gr.Blocks() as demo:
    gr.Markdown("# LLaVA API Demo")

    with gr.Tab("Test UI"):
        with gr.Row():
            with gr.Column():
                user_message = gr.Textbox(label="User Message", lines=3)
                system_prompt = gr.Textbox(label="System Prompt (Optional)", lines=2)
                image_input = gr.Image(label="Image (Optional)", type="pil")
                max_tokens = gr.Slider(label="Max Tokens", minimum=1, maximum=2048, value=1024, step=1)
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, value=0.7, step=0.1)
                submit_btn = gr.Button("Generate Response")
            with gr.Column():
                output = gr.Textbox(label="Response", lines=10)

    def on_submit(message, system, image, tokens, temp):
        return generate_response(message, system, image, tokens, temp)

    submit_btn.click(
        fn=on_submit,
        inputs=[user_message, system_prompt, image_input, max_tokens, temperature],
        outputs=output,
    )


# Mount the Gradio UI onto the FastAPI app so a single server exposes both
# the UI (at "/") and the JSON endpoint (at "/api"); /api is matched first
# because it was registered before the mount.
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
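
# To run: python app.py, or equivalently
#   uvicorn app:app --host 0.0.0.0 --port 7860
# (both assume this file is saved as app.py).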