import json
import time
from io import BytesIO

import gradio as gr
import requests
import torch
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

# Load processor and model (fp16, placed automatically across available devices)
MODEL_ID = "llava-hf/llava-1.5-7b-hf"
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto",
)


# Core inference function
def generate_response(user_message, system_prompt=None, image=None, max_tokens=1024, temperature=0.7):
    # LLaVA 1.5 uses a "USER: <image>\n... ASSISTANT:" prompt template;
    # the <image> placeholder is only included when an image is supplied.
    image_token = "<image>\n" if image is not None else ""
    if system_prompt:
        prompt = f"{system_prompt}\nUSER: {image_token}{user_message} ASSISTANT:"
    else:
        prompt = f"USER: {image_token}{user_message} ASSISTANT:"

    # Cast floating-point inputs (pixel values) to fp16 to match the model weights.
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device, torch.float16)

    with torch.inference_mode():
        output = model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),
            do_sample=True,
            temperature=temperature,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    generated_tokens = output[0][inputs["input_ids"].shape[1]:]
    response_text = processor.decode(generated_tokens, skip_special_tokens=True)
    return response_text.strip()


# FastAPI route for programmatic access, returning an OpenAI-style chat completion payload
app = FastAPI()


@app.post("/api")
async def api_endpoint(request: Request):
    try:
        data = await request.json()
        user_message = data.get("user_message", "")
        system_prompt = data.get("system_prompt", None)
        image_url = data.get("image_url", None)
        max_tokens = data.get("max_tokens", 1024)
        temperature = data.get("temperature", 0.7)

        image_data = None
        if image_url:
            image_response = requests.get(image_url, timeout=30)
            image_data = Image.open(BytesIO(image_response.content)).convert("RGB")

        response_text = generate_response(
            user_message=user_message,
            system_prompt=system_prompt,
            image=image_data,
            max_tokens=max_tokens,
            temperature=temperature,
        )

        return JSONResponse({
            "id": f"chatcmpl-{int(time.time())}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": "llava-1.5-7b",
            "choices": [{
                "message": {"role": "assistant", "content": response_text},
                "index": 0,
                "finish_reason": "stop",
            }],
        })
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)


# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# 🔍 LLaVA API Demo")

    with gr.Tab("Test UI"):
        with gr.Row():
            with gr.Column():
                user_message = gr.Textbox(label="User Message", lines=3)
                system_prompt = gr.Textbox(label="System Prompt (Optional)", lines=2)
                image_input = gr.Image(label="Image (Optional)", type="pil")
                max_tokens = gr.Slider(label="Max Tokens", minimum=1, maximum=2048, value=1024, step=1)
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, value=0.7, step=0.1)
                submit_btn = gr.Button("Generate Response")
            with gr.Column():
                output = gr.Textbox(label="Response", lines=10)

        def on_submit(message, system, image, tokens, temp):
            return generate_response(message, system, image, tokens, temp)

        submit_btn.click(
            fn=on_submit,
            inputs=[user_message, system_prompt, image_input, max_tokens, temperature],
            outputs=output,
        )

# Mount the Gradio UI onto the FastAPI app so the custom /api route and the UI
# are served by the same server, then launch with uvicorn.
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
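
# --- Example client call (illustrative sketch) ---
# The commented-out snippet below shows one way to exercise the /api route. It
# assumes the server above is running locally on port 7860 and that the image
# URL (hypothetical here) points to a publicly downloadable image; adjust the
# host, port, and payload for your own setup.
#
#   import requests
#
#   payload = {
#       "user_message": "What is shown in this image?",
#       "image_url": "https://example.com/cat.jpg",  # hypothetical URL
#       "max_tokens": 256,
#       "temperature": 0.7,
#   }
#   resp = requests.post("http://localhost:7860/api", json=payload, timeout=120)
#   print(resp.json()["choices"][0]["message"]["content"])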