grayphite committed on
Commit
9f92e1f
·
verified Β·
1 Parent(s): 253d1e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -155
app.py CHANGED
@@ -2,57 +2,64 @@ import gradio as gr
2
  import torch
3
  from PIL import Image
4
  from transformers import AutoProcessor, LlavaForConditionalGeneration
5
- import time
6
- import os
7
  import requests
8
  import json
 
 
 
 
9
 
10
- # Load model
11
- processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf") # Using 7B to fit in free tier
12
  model = LlavaForConditionalGeneration.from_pretrained(
13
  "llava-hf/llava-1.5-7b-hf",
14
  torch_dtype=torch.float16,
15
- low_cpu_mem_usage=True
16
- ).to("cuda")
17
 
18
- # API function
19
- def api_endpoint(request):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  try:
21
- # Parse request
22
- data = json.loads(request)
23
  user_message = data.get("user_message", "")
24
  system_prompt = data.get("system_prompt", None)
25
  image_url = data.get("image_url", None)
26
  max_tokens = data.get("max_tokens", 1024)
27
  temperature = data.get("temperature", 0.7)
28
-
29
- # Process image if provided
30
  image_data = None
31
  if image_url:
32
- response = requests.get(image_url)
33
- image_data = Image.open(BytesIO(response.content)).convert("RGB")
34
-
35
- # Prepare prompt
36
- if system_prompt:
37
- prompt = f"<s>[INST] {system_prompt} [/INST]\n{user_message}"
38
- else:
39
- prompt = user_message
40
-
41
- # Generate response
42
- inputs = processor(prompt, image_data, return_tensors="pt").to(model.device)
43
-
44
- with torch.inference_mode():
45
- output = model.generate(
46
- **inputs,
47
- max_new_tokens=max_tokens,
48
- do_sample=True,
49
- temperature=temperature,
50
- )
51
-
52
- response_text = processor.decode(output[0], skip_special_tokens=True)
53
-
54
- # Return response
55
- return json.dumps({
56
  "id": f"chatcmpl-{int(time.time())}",
57
  "object": "chat.completion",
58
  "created": int(time.time()),
@@ -65,16 +72,17 @@ def api_endpoint(request):
65
  "index": 0,
66
  "finish_reason": "stop"
67
  }]
68
- })
69
-
70
  except Exception as e:
71
- return json.dumps({"error": str(e)})
 
72
 
73
- # Gradio Interface
74
  with gr.Blocks() as demo:
75
- gr.Markdown("# LLaVA API Demo")
76
-
77
- with gr.Tab("API Test UI"):
78
  with gr.Row():
79
  with gr.Column():
80
  user_message = gr.Textbox(label="User Message", lines=3)
@@ -83,123 +91,20 @@ with gr.Blocks() as demo:
83
  max_tokens = gr.Slider(label="Max Tokens", minimum=1, maximum=2048, value=1024, step=1)
84
  temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, value=0.7, step=0.1)
85
  submit_btn = gr.Button("Generate Response")
86
-
87
  with gr.Column():
88
  output = gr.Textbox(label="Response", lines=10)
89
-
90
- with gr.Tab("API Documentation"):
91
- gr.Markdown("""
92
- ## API Endpoint Documentation
93
-
94
- **URL**: `https://YOUR-USERNAME-llava-api.hf.space/api/`
95
-
96
- **Method**: POST
97
-
98
- **Request Body**:
99
- ```json
100
- {
101
- "user_message": "Describe this image",
102
- "system_prompt": "You are a helpful assistant",
103
- "image_url": "https://example.com/image.jpg",
104
- "max_tokens": 1024,
105
- "temperature": 0.7
106
- }
107
- ```
108
-
109
- **Response**:
110
- ```json
111
- {
112
- "id": "chatcmpl-1234567890",
113
- "object": "chat.completion",
114
- "created": 1683123456,
115
- "model": "llava-1.5-7b",
116
- "choices": [
117
- {
118
- "message": {
119
- "role": "assistant",
120
- "content": "Response text here"
121
- },
122
- "index": 0,
123
- "finish_reason": "stop"
124
- }
125
- ]
126
- }
127
- ```
128
-
129
- **Example Python Client**:
130
- ```python
131
- import requests
132
- import json
133
-
134
- def query_llava_api(api_url, user_message, system_prompt=None, image_url=None):
135
- payload = {
136
- "user_message": user_message,
137
- "max_tokens": 1024
138
- }
139
-
140
- if system_prompt:
141
- payload["system_prompt"] = system_prompt
142
-
143
- if image_url:
144
- payload["image_url"] = image_url
145
-
146
- response = requests.post(api_url, json=payload)
147
- return response.json()
148
-
149
- # Example usage
150
- result = query_llava_api(
151
- "https://YOUR-USERNAME-llava-api.hf.space/api/",
152
- "What's in this image?",
153
- image_url="https://example.com/image.jpg"
154
- )
155
- print(result["choices"][0]["message"]["content"])
156
- ```
157
- """)
158
-
159
- # API endpoint
160
- gr.Interface(
161
- fn=api_endpoint,
162
- inputs=gr.Textbox(),
163
- outputs=gr.Textbox(),
164
- api_name="api"
165
- )
166
-
167
- # Connect UI to function
168
- def process_inputs(message, system, img, tokens, temp):
169
- # Create payload
170
- payload = {
171
- "user_message": message,
172
- "max_tokens": tokens,
173
- "temperature": temp
174
- }
175
-
176
- if system:
177
- payload["system_prompt"] = system
178
-
179
- # Process image
180
- if img is not None:
181
- # For demo purposes, we use the image directly
182
- # In a real API, you'd need to handle image URLs
183
- inputs = processor(message, img, return_tensors="pt").to(model.device)
184
-
185
- with torch.inference_mode():
186
- output = model.generate(
187
- **inputs,
188
- max_new_tokens=tokens,
189
- do_sample=True,
190
- temperature=temp,
191
- )
192
-
193
- response_text = processor.decode(output[0], skip_special_tokens=True)
194
- return response_text
195
-
196
- # If no image, process text only
197
- return api_endpoint(json.dumps(payload))
198
-
199
  submit_btn.click(
200
- process_inputs,
201
  inputs=[user_message, system_prompt, image_input, max_tokens, temperature],
202
  outputs=output
203
  )
204
 
 
 
 
 
205
  demo.launch()
 
2
  import torch
3
  from PIL import Image
4
  from transformers import AutoProcessor, LlavaForConditionalGeneration
5
+ from io import BytesIO
 
6
  import requests
7
  import json
8
+ import time
9
+
10
# Hugging Face repo id shared by the processor and the model weights.
_MODEL_ID = "llava-hf/llava-1.5-7b-hf"

# Load the LLaVA processor and model once at import time; device_map="auto"
# lets accelerate place the fp16 weights on whatever hardware is available.
processor = AutoProcessor.from_pretrained(_MODEL_ID)

model = LlavaForConditionalGeneration.from_pretrained(
    _MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto",
)
18
 
19
+ # Core inference function
20
def generate_response(user_message, system_prompt=None, image=None, max_tokens=1024, temperature=0.7):
    """Run one LLaVA generation for ``user_message``, optionally with an image.

    Args:
        user_message: The user's text prompt.
        system_prompt: Optional system instruction prepended to the prompt.
        image: Optional PIL image (or None for text-only requests).
        max_tokens: Upper bound on newly generated tokens.
        temperature: Sampling temperature (sampling is always enabled).

    Returns:
        The decoded model output as a string. NOTE(review): the decode
        includes the prompt text as well as the completion — callers that
        want only the completion must strip the prompt themselves.
    """
    # Only emit the <image> placeholder when an image is actually supplied;
    # an orphan <image> token with image=None makes the processor's image
    # token count disagree with the model inputs and breaks generation.
    if image is not None:
        prompt = f"<image>\n{system_prompt}\n{user_message}" if system_prompt else f"<image>\n{user_message}"
    else:
        prompt = f"{system_prompt}\n{user_message}" if system_prompt else user_message

    # Use keyword arguments: the positional (text, images) order changed
    # across transformers releases, so positional calls are version-fragile.
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)

    with torch.inference_mode():
        output = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=temperature,
        )

    return processor.decode(output[0], skip_special_tokens=True)
38
+
39
+ # API-style function for programmatic access
40
+ def api_endpoint(request: gr.Request):
41
  try:
42
+ data = request.json
 
43
  user_message = data.get("user_message", "")
44
  system_prompt = data.get("system_prompt", None)
45
  image_url = data.get("image_url", None)
46
  max_tokens = data.get("max_tokens", 1024)
47
  temperature = data.get("temperature", 0.7)
48
+
 
49
  image_data = None
50
  if image_url:
51
+ image_response = requests.get(image_url)
52
+ image_data = Image.open(BytesIO(image_response.content)).convert("RGB")
53
+
54
+ response_text = generate_response(
55
+ user_message=user_message,
56
+ system_prompt=system_prompt,
57
+ image=image_data,
58
+ max_tokens=max_tokens,
59
+ temperature=temperature
60
+ )
61
+
62
+ return gr.Response(json.dumps({
 
 
 
 
 
 
 
 
 
 
 
 
63
  "id": f"chatcmpl-{int(time.time())}",
64
  "object": "chat.completion",
65
  "created": int(time.time()),
 
72
  "index": 0,
73
  "finish_reason": "stop"
74
  }]
75
+ }), media_type="application/json")
76
+
77
  except Exception as e:
78
+ return gr.Response(json.dumps({"error": str(e)}), media_type="application/json")
79
+
80
 
81
+ # Gradio UI
82
  with gr.Blocks() as demo:
83
+ gr.Markdown("# πŸ” LLaVA API Demo")
84
+
85
+ with gr.Tab("Test UI"):
86
  with gr.Row():
87
  with gr.Column():
88
  user_message = gr.Textbox(label="User Message", lines=3)
 
91
  max_tokens = gr.Slider(label="Max Tokens", minimum=1, maximum=2048, value=1024, step=1)
92
  temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, value=0.7, step=0.1)
93
  submit_btn = gr.Button("Generate Response")
 
94
  with gr.Column():
95
  output = gr.Textbox(label="Response", lines=10)
96
+
97
+ def on_submit(message, system, image, tokens, temp):
98
+ return generate_response(message, system, image, tokens, temp)
99
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  submit_btn.click(
101
+ fn=on_submit,
102
  inputs=[user_message, system_prompt, image_input, max_tokens, temperature],
103
  outputs=output
104
  )
105
 
106
+ # API endpoint
107
+ demo.api("/api")(api_endpoint)
108
+
109
+ # Launch
110
  demo.launch()