| import gradio as gr |
| import torch |
| from PIL import Image |
| import requests |
| from io import BytesIO |
| import json |
| import time |
| import os |
| from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig |
| from transformers import CLIPVisionModel, CLIPImageProcessor |
| import warnings |
| warnings.filterwarnings("ignore") |
|
|
| print("π Starting LLaVA deployment...") |
|
|
| |
| device = "cuda" if torch.cuda.is_available() else "cpu" |
| print(f"π» Using device: {device}") |
|
|
| |
| tokenizer = None |
| model = None |
| image_processor = None |
| vision_tower = None |
|
|
| def load_model(): |
| """Load LLaVA model components""" |
| global tokenizer, model, image_processor, vision_tower |
| |
| try: |
| print("π¦ Loading tokenizer...") |
| |
| model_path = "liuhaotian/llava-v1.5-7b" |
| |
| tokenizer = AutoTokenizer.from_pretrained(model_path) |
| |
| print("π§ Loading language model...") |
| model = AutoModelForCausalLM.from_pretrained( |
| model_path, |
| torch_dtype=torch.float16 if device == "cuda" else torch.float32, |
| low_cpu_mem_usage=True, |
| device_map="auto" if device == "cuda" else None |
| ) |
| |
| print("ποΈ Loading vision components...") |
| |
| vision_tower = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14-336") |
| image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336") |
| |
| if device == "cuda": |
| vision_tower = vision_tower.to(device) |
| |
| print("β
Model loaded successfully!") |
| return True |
| |
| except Exception as e: |
| print(f"β Error loading model: {str(e)}") |
| return False |
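
# Sidebar (assumption): the liuhaotian/llava-v1.5-7b repo uses a custom
# "llava" architecture that plain AutoModelForCausalLM may not resolve on
# some transformers versions. A minimal alternative sketch, assuming a
# transformers release (>= 4.35) that ships LlavaForConditionalGeneration
# and the community-converted llava-hf/llava-1.5-7b-hf checkpoint:
def load_model_hf_alternative():
    """Sketch only: load the HF-converted LLaVA checkpoint and its processor."""
    from transformers import AutoProcessor, LlavaForConditionalGeneration

    alt_path = "llava-hf/llava-1.5-7b-hf"  # converted checkpoint, not the original repo
    processor = AutoProcessor.from_pretrained(alt_path)
    alt_model = LlavaForConditionalGeneration.from_pretrained(
        alt_path,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        low_cpu_mem_usage=True,
        device_map="auto" if device == "cuda" else None,
    )
    return processor, alt_model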
|
|
| def process_image(image): |
| """Process image for the model""" |
| if image is None: |
| return None |
| |
| try: |
| |
| if image.mode != 'RGB': |
| image = image.convert('RGB') |
| |
| |
| image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'] |
| |
| if device == "cuda": |
| image_tensor = image_tensor.to(device) |
| |
| |
| with torch.no_grad(): |
| image_features = vision_tower(image_tensor).last_hidden_state |
| |
| return image_features |
| |
| except Exception as e: |
| print(f"Error processing image: {str(e)}") |
| return None |
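
# Shape note: for CLIP ViT-L/14 at 336px, the features returned above are
# (batch, 577, 1024): a 24x24 grid of image patches (336 / 14 = 24, so 576
# patch tokens) plus one CLS token, each in the encoder's 1024-dim space.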
|
|
| def generate_response(message, image=None, system_prompt="", max_tokens=1024, temperature=0.7): |
| """Generate response using LLaVA""" |
| global tokenizer, model, image_processor, vision_tower |
| |
| if model is None: |
| return "β Model not loaded. Please wait for initialization." |
| |
| try: |
| |
| image_features = None |
| if image is not None: |
| image_features = process_image(image) |
| if image_features is None: |
| return "β Error processing image." |
| |
| |
        # Build the prompt; keep the <image> placeholder whenever an image is
        # attached, even when a system prompt is supplied.
        image_tag = "<image>\n" if image is not None else ""
        if system_prompt:
            full_prompt = f"System: {system_prompt}\n\nUSER: {image_tag}{message}\nASSISTANT:"
        else:
            full_prompt = f"USER: {image_tag}{message}\nASSISTANT:"
| |
| |
| inputs = tokenizer(full_prompt, return_tensors="pt") |
| |
| if device == "cuda": |
| inputs = {k: v.to(device) for k, v in inputs.items()} |
| |
| |
        # NOTE: this simplified path never feeds image_features into the
        # language model; without LLaVA's multimodal projector, generation is
        # effectively text-only even when an image was uploaded.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )
| |
| |
        # Decode only the newly generated tokens instead of slicing the echoed
        # prompt off by character count, which breaks when tokenization and
        # decoding don't round-trip exactly.
        prompt_len = inputs["input_ids"].shape[1]
        response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

        return response
| |
| except Exception as e: |
| return f"β Error generating response: {str(e)}" |
|
|
| def api_endpoint(request_json): |
| """API endpoint for programmatic access""" |
| try: |
| data = json.loads(request_json) |
| |
| message = data.get("message", "") |
| system_prompt = data.get("system_prompt", "") |
| image_url = data.get("image_url", None) |
| max_tokens = int(data.get("max_tokens", 1024)) |
| temperature = float(data.get("temperature", 0.7)) |
| |
| |
| image = None |
| if image_url: |
| try: |
| response = requests.get(image_url, timeout=10) |
| if response.status_code == 200: |
| image = Image.open(BytesIO(response.content)) |
| except Exception as e: |
| return json.dumps({"error": f"Failed to load image: {str(e)}"}) |
| |
| |
| response_text = generate_response( |
| message=message, |
| image=image, |
| system_prompt=system_prompt, |
| max_tokens=max_tokens, |
| temperature=temperature |
| ) |
| |
| |
| return json.dumps({ |
| "id": f"chatcmpl-{int(time.time())}", |
| "object": "chat.completion", |
| "created": int(time.time()), |
| "model": "llava-v1.5-7b", |
| "choices": [{ |
| "message": { |
| "role": "assistant", |
| "content": response_text |
| }, |
| "index": 0, |
| "finish_reason": "stop" |
| }], |
| "usage": { |
| "prompt_tokens": 0, |
| "completion_tokens": 0, |
| "total_tokens": 0 |
| } |
| }) |
| |
| except Exception as e: |
| return json.dumps({"error": str(e)}) |
|
|
| |
| print("π Initializing model...") |
| model_loaded = load_model() |
|
|
| |
| with gr.Blocks(title="LLaVA - Large Language and Vision Assistant", theme=gr.themes.Soft()) as demo: |
| gr.Markdown(""" |
    # 🦙 LLaVA - Large Language and Vision Assistant

    An open-source chatbot trained by fine-tuning LLaMA/Vicuna on GPT-generated multimodal instruction-following data.

    **Features:**
    - 💬 Text-based conversation
    - 🖼️ Image understanding and description
    - 🔧 API endpoint for integration
| """) |
| |
| with gr.Tab("π¬ Chat Interface"): |
| with gr.Row(): |
| with gr.Column(scale=1): |
| image_input = gr.Image( |
| type="pil", |
| label="πΈ Upload Image (Optional)", |
| height=300 |
| ) |
| system_prompt = gr.Textbox( |
| label="π― System Prompt (Optional)", |
| placeholder="You are a helpful assistant that can analyze images...", |
| lines=2 |
| ) |
| |
| with gr.Column(scale=2): |
| chatbot = gr.Chatbot( |
| label="π Conversation", |
| height=400 |
| ) |
| |
| msg = gr.Textbox( |
| label="βοΈ Your Message", |
| placeholder="Type your message here... You can ask about the uploaded image!", |
| lines=2 |
| ) |
| |
| with gr.Row(): |
| submit_btn = gr.Button("π Send", variant="primary") |
| clear_btn = gr.Button("ποΈ Clear", variant="secondary") |
| |
| with gr.Accordion("βοΈ Advanced Settings", open=False): |
| max_tokens = gr.Slider( |
| minimum=1, |
| maximum=2048, |
| value=1024, |
| step=1, |
| label="π Max Tokens" |
| ) |
| temperature = gr.Slider( |
| minimum=0.1, |
| maximum=2.0, |
| value=0.7, |
| step=0.1, |
| label="π‘οΈ Temperature" |
| ) |
| |
| with gr.Tab("π API Documentation"): |
| gr.Markdown(""" |
| ## API Endpoint Usage |
| |
| **Endpoint**: `https://your-space-name.hf.space/api/predict` |
| |
| **Method**: POST |
| |
| ### Request Format: |
| ```json |
| { |
| "data": [ |
| "{ |
| \"message\": \"Describe this image in detail\", |
| \"system_prompt\": \"You are a helpful assistant\", |
| \"image_url\": \"https://example.com/image.jpg\", |
| \"max_tokens\": 1024, |
| \"temperature\": 0.7 |
| }" |
| ] |
| } |
| ``` |
| |
| ### Response Format: |
| ```json |
| { |
| "data": [ |
| "{ |
| \"id\": \"chatcmpl-123456789\", |
| \"object\": \"chat.completion\", |
| \"created\": 1683123456, |
| \"model\": \"llava-v1.5-7b\", |
| \"choices\": [ |
| { |
| \"message\": { |
| \"role\": \"assistant\", |
| \"content\": \"This image shows...\" |
| }, |
| \"index\": 0, |
| \"finish_reason\": \"stop\" |
| } |
| ] |
| }" |
| ] |
| } |
| ``` |
| |
| ### Python Client Example: |
| ```python |
| import requests |
| import json |
| |
| def query_llava(message, image_url=None, system_prompt=""): |
| payload = { |
| "data": [json.dumps({ |
| "message": message, |
| "image_url": image_url, |
| "system_prompt": system_prompt, |
| "max_tokens": 1024, |
| "temperature": 0.7 |
| })] |
| } |
| |
| response = requests.post( |
| "https://your-space-name.hf.space/api/predict", |
| json=payload |
| ) |
| |
| if response.status_code == 200: |
| result = response.json() |
| api_response = json.loads(result["data"][0]) |
| return api_response["choices"][0]["message"]["content"] |
| else: |
| return f"Error: {response.status_code}" |
| |
| # Example usage |
| result = query_llava( |
| "What do you see in this image?", |
| image_url="https://example.com/image.jpg" |
| ) |
| print(result) |
| ``` |
| """) |
| |
| |
| gr.Markdown("### π§ͺ Test API") |
| api_input = gr.Textbox( |
| label="π API Request (JSON)", |
| placeholder='{"message": "Hello!", "max_tokens": 1024}', |
| lines=4 |
| ) |
| api_output = gr.Textbox( |
| label="π€ API Response", |
| lines=8 |
| ) |
| api_test_btn = gr.Button("π§ͺ Test API", variant="primary") |
| |
| with gr.Tab("βΉοΈ About"): |
| gr.Markdown(""" |
| ## About LLaVA |
| |
| **LLaVA (Large Language and Vision Assistant)** is an open-source multimodal AI assistant that combines: |
| |
        - 🧠 **Language Understanding**: Based on Vicuna/LLaMA architecture
        - 👁️ **Vision Capabilities**: Uses CLIP vision encoder
        - 🔗 **Multimodal Integration**: Connects vision and language seamlessly
| |
| ### Key Features: |
| - **Visual Question Answering**: Ask questions about images |
| - **Image Description**: Get detailed descriptions of uploaded images |
| - **General Conversation**: Chat about any topic |
| - **API Integration**: Easy integration with your applications |
| |
| ### Model Information: |
| - **Base Model**: LLaVA-v1.5-7B |
| - **Vision Encoder**: CLIP ViT-L/14@336px |
| - **Language Model**: Vicuna-7B |
| - **Training Data**: LLaVA-Instruct-150K |
| |
| ### Citation: |
| ``` |
| @misc{liu2023llava, |
| title={Visual Instruction Tuning}, |
| author={Haotian Liu and Chunyuan Li and Qingyang Wu and Yong Jae Lee}, |
| year={2023}, |
| eprint={2304.08485}, |
| archivePrefix={arXiv}, |
| primaryClass={cs.CV} |
| } |
| ``` |
| |
| **GitHub**: [https://github.com/haotian-liu/LLaVA](https://github.com/haotian-liu/LLaVA) |
| """) |
| |
| |
| def respond(message, chat_history, image, system_prompt, max_tokens, temperature): |
| if not message.strip(): |
| return "", chat_history |
| |
| |
| chat_history.append([message, None]) |
| |
| |
| response = generate_response( |
| message=message, |
| image=image, |
| system_prompt=system_prompt if system_prompt.strip() else "", |
| max_tokens=int(max_tokens), |
| temperature=temperature |
| ) |
| |
| |
| chat_history[-1][1] = response |
| |
| return "", chat_history |
| |
| def clear_chat(): |
        return [], ""
| |
| |
| submit_btn.click( |
| respond, |
| [msg, chatbot, image_input, system_prompt, max_tokens, temperature], |
| [msg, chatbot] |
| ) |
| |
| msg.submit( |
| respond, |
| [msg, chatbot, image_input, system_prompt, max_tokens, temperature], |
| [msg, chatbot] |
| ) |
| |
| clear_btn.click(clear_chat, outputs=[chatbot, msg]) |
| |
    api_test_btn.click(api_endpoint, inputs=api_input, outputs=api_output, api_name="predict")
| |
| |
# Exposing the API: the api_test_btn.click handler above is registered with
# api_name="predict", which serves the function at /api/predict on this
# Blocks app without needing a separate, never-launched gr.Interface.
|
|
| |
| if __name__ == "__main__": |
| demo.launch( |
| server_name="0.0.0.0", |
| server_port=7860, |
| share=False |
| ) |