| import spaces |
| import gradio as gr |
| import numpy as np |
| import random |
| import time |
| import json |
| import os |
| from loguru import logger |
| from decouple import config |
| import io |
| import torch |
| import numpy as np |
| import torch |
| import cv2 |
| from PIL import Image |
|
|
| from segment_anything import sam_model_registry, SamPredictor |
|
|
|
|
|
|
# Startup diagnostics: report what PyTorch can see of the CUDA runtime.
# Device queries are made only when CUDA is reported as available, because
# torch.cuda.current_device() / get_device_name() raise on a machine without
# a usable GPU and would abort the import before the app can start.
cuda_available = torch.cuda.is_available()
print(f"Is CUDA available: {cuda_available}")

# Queried once and kept under its own name, so the module-level `device`
# setting is never temporarily bound to a human-readable GPU name.
cuda_device_name = (
    torch.cuda.get_device_name(torch.cuda.current_device())
    if cuda_available
    else None
)
if cuda_device_name is not None:
    print(f"CUDA device: {cuda_device_name}")

# CUDA version reported by the installed PyTorch build.
print(torch.version.cuda)

if cuda_device_name is not None:
    print(cuda_device_name)
|
|
# Model configuration.
# `model_type` selects the builder stored in `sam_model_registry`, and
# `sam_checkpoint` (a relative path) is handed to that builder as the
# checkpoint to load.
# NOTE(review): the file name and the `hq_token_only` argument used in
# `generate_image` suggest the SAM-HQ fork of segment_anything; confirm.
model_type = "vit_h"
sam_checkpoint = "sam_hq_vit_h.pth"

# Device string passed to `sam.to(...)` inside `generate_image`.
device = "cuda"

# The model object is created once at import time; `generate_image` reuses
# this same instance on every call.
_build_sam = sam_model_registry[model_type]
sam = _build_sam(checkpoint=sam_checkpoint)
|
|
|
|
def _parse_point_prompt(prompt):
    """Decode the JSON point prompt into ``(points, labels)`` NumPy arrays.

    Args:
        prompt: JSON text of an object with the keys ``'input_points'``
            (a list of ``[x, y]`` pairs) and ``'input_labels'`` (one label
            per point).

    Returns:
        A tuple ``(points, labels)`` where ``points`` has shape ``(N, 2)``
        and ``labels`` has shape ``(N,)``.

    Raises:
        json.JSONDecodeError: If ``prompt`` is not valid JSON (this is a
            subclass of ``ValueError``).
        KeyError: If one of the two required keys is missing.
        ValueError: If the points are not a list of pairs, or the number of
            labels does not match the number of points.
    """
    payload = json.loads(prompt)
    points = np.array(payload['input_points'])
    labels = np.array(payload['input_labels'])

    # SamPredictor.predict documents point_coords as an Nx2 array and
    # point_labels as a length-N array; reject anything else here with a
    # readable message instead of failing deep inside the model code.
    if points.ndim != 2 or points.shape[1] != 2:
        raise ValueError(
            "'input_points' must be a non-empty list of [x, y] pairs, "
            f"got an array of shape {points.shape}"
        )
    if labels.shape != (points.shape[0],):
        raise ValueError(
            "'input_labels' must contain exactly one label per input point, "
            f"got shape {labels.shape} for {points.shape[0]} point(s)"
        )
    return points, labels


# NOTE(review): the model is moved to `device` inside the decorated call
# rather than at import time -- presumably because the GPU is only attached
# while a `spaces.GPU` function runs; confirm against the `spaces` docs.
@spaces.GPU(duration=10)
def generate_image(prompt, image):
    """Segment ``image`` from point prompts and return the mask as an image.

    Args:
        prompt: JSON text with the keys ``'input_points'`` (list of
            ``[x, y]`` pairs) and ``'input_labels'`` (one label per point).
            Per the SamPredictor.predict documentation, points are pixel
            coordinates and a label of 1 marks a foreground point while 0
            marks a background point.
        image: The picture to segment, as the NumPy array supplied by the
            Gradio image input.

    Returns:
        A ``PIL.Image.Image`` in RGB with the same height and width as the
        predicted mask: white (255) where the mask is set, black elsewhere.

    Raises:
        ValueError: If no image was supplied or the prompt is malformed
            (see ``_parse_point_prompt``).
        KeyError: If the prompt lacks a required key.
    """
    if image is None:
        raise ValueError("an input image is required")

    # Validate the prompt before touching the GPU so that a bad request
    # fails fast instead of after the expensive image embedding.
    input_points, input_labels = _parse_point_prompt(prompt)

    sam.to(device=device)
    predictor = SamPredictor(sam)
    predictor.set_image(image)

    mask, _, _ = predictor.predict(
        point_coords=input_points,
        point_labels=input_labels,
        box=None,
        multimask_output=False,
        hq_token_only=True,
    )

    # `mask` is indexed as (mask index, height, width); only mask[0] is
    # used, painted white onto a black RGB canvas.
    rgb_array = np.zeros((mask.shape[1], mask.shape[2], 3), dtype=np.uint8)
    rgb_array[mask[0]] = 255
    return Image.fromarray(rgb_array)
|
|
|
|
if __name__ == "__main__":
    # Script entry point: wire `generate_image` into a Gradio interface with
    # a text box for the JSON prompt and an RGB image input, and show the
    # resulting mask as an RGB image.
    prompt_input = "text"
    image_input = gr.Image(image_mode='RGB', type="numpy")
    mask_output = gr.Image(type="numpy", image_mode='RGB')

    demo = gr.Interface(
        fn=generate_image,
        inputs=[prompt_input, image_input],
        outputs=[mask_output],
    )
    demo.launch(debug=True)
    # This line executes only once launch() has returned.
    logger.debug('demo.launch()')
|
|