| from flask import Flask, request, jsonify |
| from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation |
| from PIL import Image |
| import torch |
| import numpy as np |
| import io |
| import base64 |
| import threading |
| import time |
|
|
# Hugging Face checkpoint used for both the pre-processor and the model.
_CLIPSEG_CHECKPOINT = "CIDAS/clipseg-rd64-refined"

# Flask application object; the route decorators below register on it.
app = Flask(__name__)

# Loaded once at import time and shared as module globals.
processor = CLIPSegProcessor.from_pretrained(_CLIPSEG_CHECKPOINT)
model = CLIPSegForImageSegmentation.from_pretrained(_CLIPSEG_CHECKPOINT)
|
|
@app.route('/')
def hello_world():
    """Return a fixed greeting string for the root URL."""
    greeting = 'Hello, World!'
    return greeting
|
|
| |
def process_image(image, prompt):
    """Score every pixel of ``image`` against one text ``prompt`` with CLIPSeg.

    Args:
        image: PIL image to segment. The result is resized to ``image.size``.
        prompt: Text describing the region to look for.

    Returns:
        A 2-D float array of shape ``(image.height, image.width)`` whose
        values are min-max normalised into ``[0, 1]``. If the prediction is
        completely flat (nothing to normalise) an all-zero array is returned
        instead of dividing by zero.
    """
    inputs = processor(
        text=prompt, images=image, padding="max_length", return_tensors="pt"
    )
    with torch.no_grad():
        outputs = model(**inputs)

    probabilities = torch.sigmoid(outputs.logits).cpu().numpy()
    # NOTE(review): for a single image/prompt pair the logits may carry a
    # leading length-1 batch axis depending on the installed transformers
    # version. Keep only the two spatial axes so Pillow always gets a 2-D
    # array.
    probabilities = probabilities.reshape(probabilities.shape[-2:])

    # Quantise to 8 bits and scale back up to the input resolution. The mode
    # is inferred as 8-bit greyscale from the 2-D uint8 array, and resizing
    # the single channel directly replaces the previous RGB round-trip.
    heatmap = Image.fromarray(np.uint8(probabilities * 255))
    heatmap = heatmap.resize(image.size)
    mask = np.asarray(heatmap, dtype=np.float64)

    lowest = mask.min()
    spread = mask.max() - lowest
    if spread == 0:
        # Flat prediction: there is no contrast to stretch, and dividing
        # would fill the mask with NaN.
        return np.zeros_like(mask)

    return (mask - lowest) / spread
|
|
| |
def get_masks(prompts, img, threshold):
    """Build one thresholded mask per comma-separated prompt.

    Args:
        prompts: Prompt string; it is split on "," and each piece is used
            exactly as written (no trimming, empty pieces included).
        img: Image handed unchanged to ``process_image`` for every prompt.
        threshold: Cut-off; a pixel is True where its score is strictly
            greater than this value.

    Returns:
        A list with one boolean mask per prompt, in prompt order.
    """
    return [process_image(img, text) > threshold for text in prompts.split(",")]
|
|
| |
def _split_prompts(raw):
    """Return the trimmed, non-blank prompts of a comma-separated string."""
    return [part.strip() for part in str(raw or "").split(",") if part.strip()]


def _combined_mask(prompts, img, threshold):
    """OR together the masks of ``prompts``; no prompts selects no pixels."""
    if not prompts:
        # PIL sizes are (width, height); arrays are (rows, columns).
        return np.zeros((img.height, img.width), dtype=bool)
    return np.any(np.stack(get_masks(",".join(prompts), img, threshold)), axis=0)


def _bad_request(message):
    """Build a JSON error body with HTTP status 400."""
    return jsonify({'error': message}), 400


@app.route('/api', methods=['POST'])
def process_request():
    """Segment an uploaded image from positive and negative text prompts.

    Expects a JSON object with:
        image: Base64 image data, either bare or as a ``data:...;base64,``
            URL.
        positive_prompts: Comma-separated prompts whose regions are kept.
        negative_prompts: Comma-separated prompts whose regions are removed.
            Optional; when blank nothing is removed.
        threshold: Optional score cut-off applied to every prompt mask
            (default 0.4).

    Returns:
        JSON ``{'final_mask_base64': <PNG as base64>}`` where the PNG is an
        8-bit mask (255 = selected, 0 = not selected) the size of the input
        image, or ``{'error': ...}`` with status 400 when the request is
        malformed.
    """
    # silent=True turns an unparseable or non-JSON body into None, so the
    # request is rejected below with a uniform 400.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return _bad_request('request body must be a JSON object')

    base64_image = data.get('image')
    if not isinstance(base64_image, str) or not base64_image:
        return _bad_request("'image' must be a base64 string or data URL")

    # Base64 never contains a comma, so whatever follows the last comma is
    # the payload; a bare payload without a data-URL header passes through.
    payload = base64_image.rpartition(',')[2]
    try:
        image_data = base64.b64decode(payload)
        # Normalise RGBA / palette / greyscale uploads to three channels
        # before they reach the model.
        img = Image.open(io.BytesIO(image_data)).convert("RGB")
    except (ValueError, OSError):
        return _bad_request("'image' could not be decoded")

    try:
        threshold = float(data.get('threshold', 0.4))
    except (TypeError, ValueError):
        return _bad_request("'threshold' must be a number")

    # Blank prompts are dropped so that an omitted negative_prompts field
    # does not segment the empty string and erase part of the result.
    pos_mask = _combined_mask(
        _split_prompts(data.get('positive_prompts', '')), img, threshold
    )
    neg_mask = _combined_mask(
        _split_prompts(data.get('negative_prompts', '')), img, threshold
    )
    final_mask = pos_mask & ~neg_mask

    # Encode the boolean mask as an 8-bit greyscale PNG.
    mask_image = Image.fromarray(final_mask.astype(np.uint8) * 255)
    buffered = io.BytesIO()
    mask_image.save(buffered, format="PNG")
    final_mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

    return jsonify({'final_mask_base64': final_mask_base64})
|
|
| |
def keep_alive(url='http://127.0.0.1:7860/', interval=300):
    """Request ``url`` forever, once every ``interval`` seconds.

    Never returns, so it is meant to be run in its own thread. A failed
    ping is reported and the loop carries on.

    Args:
        url: Address to request on each cycle.
        interval: Seconds to sleep before each request.
    """
    # Imported locally: the module's import block does not provide an HTTP
    # client, and the standard library is enough for a bare GET.
    from http.client import HTTPException
    from urllib.request import urlopen

    while True:
        time.sleep(interval)
        try:
            # The context manager closes the response; the body is not needed.
            with urlopen(url, timeout=10):
                pass
        except (OSError, HTTPException) as exc:
            print(f"keep_alive: request to {url} failed: {exc}")
|
|
if __name__ == '__main__':
    import os

    print("Server starting. Verify it is running by visiting http://0.0.0.0:7860/")

    # The server listens on every interface, and the interactive debugger
    # lets anyone who can reach it run code, so debug mode is opt-in.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").lower() in ("1", "true")

    # Daemon thread: its endless loop must not keep the process alive after
    # the server has stopped.
    keep_alive_thread = threading.Thread(target=keep_alive, daemon=True)
    keep_alive_thread.start()

    app.run(host='0.0.0.0', port=7860, debug=debug_enabled)
|
|