| from flask import Flask, request, jsonify, render_template |
| from PIL import Image |
| import base64 |
| from io import BytesIO |
| from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation |
| import torch |
| import numpy as np |
| import matplotlib.pyplot as plt |
| import cv2 |
|
|
# Hugging Face checkpoint shared by the processor and the segmentation model.
_CLIPSEG_CHECKPOINT = "CIDAS/clipseg-rd64-refined"

# Flask application object that the route decorators below attach to.
app = Flask(__name__)

# Both objects are created once at import time and reused by every request.
processor = CLIPSegProcessor.from_pretrained(_CLIPSEG_CHECKPOINT)
model = CLIPSegForImageSegmentation.from_pretrained(_CLIPSEG_CHECKPOINT)
|
|
def _png_base64(pil_image):
    """Encode a PIL image as PNG and return the bytes as a base64 text string."""
    buffer = BytesIO()
    pil_image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode('utf-8')


def process_image(image, prompt, threshold, alpha_value, draw_rectangles):
    """Segment ``image`` according to the text ``prompt`` using CLIPSeg.

    Args:
        image: PIL image to segment.
        prompt: Text describing what should be segmented.
        threshold: Cut-off applied to the min-max normalised mask (values in
            the range 0..1); pixels strictly above it are kept.
        alpha_value: Opacity of the heat-map overlay drawn on the figure.
        draw_rectangles: When truthy, draw a yellow bounding box around every
            contour found in the binary mask.

    Returns:
        A 3-tuple ``(fig, result_mask, result_output)``:
        ``fig`` is a matplotlib ``Figure`` showing the image with the mask
        overlaid, ``result_mask`` is the binary mask as a base64-encoded PNG,
        and ``result_output`` is the masked cut-out of ``image`` (transparent
        outside the mask) as a base64-encoded PNG.
    """
    # Imported locally so this function does not rely on pyplot's global
    # figure registry: figures created through pyplot stay referenced until
    # they are closed explicitly, which leaks memory in a long-running server.
    from matplotlib.figure import Figure
    from matplotlib.patches import Rectangle

    inputs = processor(
        text=prompt, images=image, padding="max_length", return_tensors="pt"
    )

    # Inference only; no gradients are needed.
    with torch.no_grad():
        outputs = model(**inputs)
    probabilities = torch.sigmoid(outputs.logits).cpu().numpy()

    # Drop the leading axis, then scale the low-resolution prediction up to
    # the input image size. The 2-D uint8 array is loaded as a single-channel
    # image, so no round trip through RGB is required.
    probabilities = np.squeeze(probabilities, axis=0)
    low_res_mask = Image.fromarray(np.uint8(probabilities * 255))
    mask = np.array(low_res_mask.resize(image.size), dtype=np.float64)

    # Min-max normalise to 0..1. A completely uniform mask has no range to
    # stretch; treat it as "nothing detected" instead of dividing by zero.
    lowest = mask.min()
    highest = mask.max()
    if highest > lowest:
        mask = (mask - lowest) / (highest - lowest)
    else:
        mask = np.zeros_like(mask)

    # Binary mask for the cut-out / contours; soft mask for the overlay.
    bmask = mask > threshold
    mask[mask < threshold] = 0

    fig = Figure()
    ax = fig.subplots()
    ax.imshow(image)
    ax.imshow(mask, alpha=alpha_value, cmap="jet")

    if draw_rectangles:
        contours, _ = cv2.findContours(
            bmask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
        )
        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)
            ax.add_patch(
                Rectangle(
                    (x, y), w, h, fill=False, edgecolor="yellow", linewidth=2
                )
            )

    ax.axis("off")
    fig.tight_layout()

    # Paste the original pixels onto a fully transparent canvas, using the
    # binary mask to select which pixels are copied.
    bmask_image = Image.fromarray(bmask.astype(np.uint8) * 255)
    output_image = Image.new("RGBA", image.size, (0, 0, 0, 0))
    output_image.paste(image, mask=bmask_image)

    return fig, _png_base64(bmask_image), _png_base64(output_image)
|
|
|
|
| |
| |
|
|
@app.route('/')
def index():
    """Serve the front-end page rendered from the ``index.html`` template."""
    page = render_template('index.html')
    return page
|
|
@app.route('/api/mask_image', methods=['POST'])
def mask_image_api():
    """Segment an uploaded image and return the mask and cut-out as base64 PNGs.

    Expects a JSON object with these keys:
        base64_image: The image as base64 text, either bare or as a data URL
            (``data:image/...;base64,<payload>``). Required.
        prompt: Text describing what to segment (default ``''``).
        threshold: Mask cut-off, a number (default ``0.4``).
        alpha_value: Overlay opacity, a number between 0 and 1 (default ``0.5``).
        draw_rectangles: Whether to draw bounding boxes (default ``False``).

    Returns:
        JSON ``{'result_mask': ..., 'result_output': ...}`` on success, or
        JSON ``{'error': ...}`` with HTTP status 400 when the request is
        malformed.
    """
    # silent=True yields None instead of raising on a missing or invalid body.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    base64_image = data.get('base64_image', '')
    prompt = data.get('prompt', '')
    draw_rectangles = bool(data.get('draw_rectangles', False))

    try:
        threshold = float(data.get('threshold', 0.4))
        alpha_value = float(data.get('alpha_value', 0.5))
    except (TypeError, ValueError):
        return jsonify({'error': 'threshold and alpha_value must be numbers'}), 400
    if not 0.0 <= alpha_value <= 1.0:
        return jsonify({'error': 'alpha_value must be between 0 and 1'}), 400

    if not isinstance(base64_image, str) or not base64_image:
        return jsonify({'error': 'base64_image is required'}), 400

    # Accept both a data URL ("<header>,<payload>") and a bare base64 string.
    _, separator, payload = base64_image.partition(',')
    if not separator:
        payload = base64_image

    try:
        image_data = base64.b64decode(payload)
        image = Image.open(BytesIO(image_data))
        image.load()  # Force decoding now so corrupt data is reported here.
    except (ValueError, OSError):
        return jsonify({'error': 'base64_image could not be decoded as an image'}), 400

    # Normalise to three channels so palette, grayscale and alpha images are
    # handled the same way as plain RGB uploads.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    fig, result_mask, result_output = process_image(
        image, prompt, threshold, alpha_value, draw_rectangles
    )
    # The figure is not part of the response; release it so that figures do
    # not accumulate across requests.
    plt.close(fig)

    return jsonify({'result_mask': result_mask, 'result_output': result_output})
|
|
if __name__ == '__main__':
    import os

    # The server listens on all interfaces, and Flask's debug mode enables an
    # interactive debugger that can execute arbitrary code. Debug mode is
    # therefore off unless explicitly requested, e.g. FLASK_DEBUG=1.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').strip().lower() in (
        '1', 'true', 'yes', 'on',
    )
    app.run(host='0.0.0.0', port=7860, debug=debug_enabled)