# image_cut_rect / app.py
# JH-BK's picture
# fix: build failure
# 36b7e0f
import gradio as gr
import numpy as np
import math
import os
import shutil
import torch
from PIL import Image, ImageDraw
from rect_main import docscanner_rec, load_docscanner_model
from data_utils.image_utils import unwarp, mask2point, get_corner, _rotate_90_degrees
from config import Config
# Project configuration; supplies the segmentation / rectification weight paths used below.
config = Config()
# Torch device used for inference: CUDA when available, otherwise CPU (despite the name).
cuda = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Lazily created DocScanner model; populated on the first call to get_docscanner().
_docscanner = None
def _is_git_lfs_pointer_file(path: str) -> bool:
try:
with open(path, "rb") as f:
head = f.read(200)
return b"version https://git-lfs.github.com/spec/v1" in head
except FileNotFoundError:
return False
def _try_download_from_hf_hub(repo_id: str, filename: str, local_path: str) -> bool:
try:
from huggingface_hub import hf_hub_download # type: ignore
downloaded = hf_hub_download(repo_id=repo_id, filename=filename)
os.makedirs(os.path.dirname(local_path), exist_ok=True)
shutil.copyfile(downloaded, local_path)
return True
except Exception:
return False
def _ensure_weights_present() -> None:
    """Try to fetch any model weight file that is missing or only an LFS pointer.

    Both weight paths come from ``config``. When at least one needs fetching,
    the Hub repo id is taken from the first non-empty environment variable
    among ``SPACE_ID``, ``HF_SPACE_ID``, ``HF_REPO_ID`` and ``REPO_ID``; with
    no repo id nothing is downloaded. Download results are not checked here.
    """
    seg_path = config.get_seg_model_path
    rec_path = config.get_rec_model_path

    # Decide up front which files need fetching (segmentation first, then rectification).
    pending = [
        weight_path
        for weight_path in (seg_path, rec_path)
        if not os.path.exists(weight_path) or _is_git_lfs_pointer_file(weight_path)
    ]
    if not pending:
        return

    env_names = ("SPACE_ID", "HF_SPACE_ID", "HF_REPO_ID", "REPO_ID")
    repo_id = next((value for value in map(os.getenv, env_names) if value), None)
    if not repo_id:
        return

    for weight_path in pending:
        _try_download_from_hf_hub(repo_id, os.path.basename(weight_path), weight_path)
def get_docscanner():
    """Return the shared DocScanner model, loading it on first use.

    On the first call the weight files are checked (and fetched if possible),
    then the model is loaded onto the module-level device and cached in
    ``_docscanner``; later calls return the cached object.

    Raises:
        RuntimeError: if either weight file is still a Git LFS pointer stub.
    """
    global _docscanner
    if _docscanner is None:
        _ensure_weights_present()
        seg_path = config.get_seg_model_path
        rec_path = config.get_rec_model_path
        # A pointer stub cannot be loaded as weights; fail with an actionable message.
        if any(_is_git_lfs_pointer_file(p) for p in (seg_path, rec_path)):
            raise RuntimeError(
                "Model weight files look like Git LFS pointers. "
                "Make sure LFS objects are downloaded (e.g. `git lfs pull`) "
                "or allow the Space to download them from the Hub at runtime."
            )
        _docscanner = load_docscanner_model(cuda, path_l=rec_path, path_m=seg_path)
    return _docscanner
# Function that resets the selected coordinates.
def reset_points(image, state):
    """Clear the selected points.

    Returns *image* unchanged together with a fresh empty list; the incoming
    ``state`` list is not modified.
    """
    return image, []
def cutting_image(image, state):
    """Crop *image* to the axis-aligned bounding box of the selected points.

    Args:
        image: Array indexed as ``image[row, col]`` (rows are y, columns are x),
            or ``None`` when no image is loaded.
        state: Sequence of ``(x, y)`` points.

    Returns:
        ``(cropped, cropped, [])``: the crop twice (one per output component)
        and an emptied point list. When no crop is possible (no image, fewer
        than two points, or a box with zero width or height) the image is
        returned uncropped instead of raising or producing an empty array.
    """
    if image is None or len(state) < 2:
        return image, image, []

    xs = [point[0] for point in state]
    ys = [point[1] for point in state]
    min_x, max_x = min(xs), max(xs)
    min_y, max_y = min(ys), max(ys)

    # Points on one vertical or horizontal line would give an empty crop.
    if min_x == max_x or min_y == max_y:
        return image, image, []

    cropped = image[min_y:max_y, min_x:max_x]
    return cropped, cropped, []
def rotate_image(image):
    """Rotate *image* with ``_rotate_90_degrees`` and discard the selected points.

    Returns the rotated image and an empty point list.
    """
    return _rotate_90_degrees(image), []
def reset_image(image, state):
    """Re-run automatic corner detection on *image*, discarding *state*.

    The incoming ``state`` is ignored (it is kept only so the signature stays
    usable as a Gradio callback); the result is exactly what
    ``auto_point_detect(image)`` returns: the annotated preview image and the
    list of detected corner points.
    """
    # Detection and drawing live in auto_point_detect; delegate instead of duplicating them.
    return auto_point_detect(image)
def auto_point_detect(image):
    """Detect document corners automatically and draw them on a preview image.

    The DocScanner model produces a mask, which ``mask2point`` and
    ``get_corner`` turn into corner points. Each point is drawn as a red dot;
    two points are joined by a line and three or more by a polygon outline.

    Args:
        image: Input image array, or ``None`` when nothing is loaded.

    Returns:
        ``(preview, points)``: the annotated PIL image and the list of detected
        points. ``(None, [])`` when *image* is ``None``.
    """
    if image is None:
        return None, []

    docscanner = get_docscanner()
    _, msk_np = docscanner_rec(image, docscanner, cuda)
    state = list(get_corner(mask2point(mask=msk_np)))

    img = Image.fromarray(image)
    # Scale the marker size with the image so it stays visible on large inputs.
    area = image.shape[0] * image.shape[1]
    radius = max(5, round(area ** 0.5 / 120))
    line_width = round(radius / 2)

    draw = ImageDraw.Draw(img)
    for pt in state:
        left_up_point = (pt[0] - radius, pt[1] - radius)
        right_down_point = (pt[0] + radius, pt[1] + radius)
        draw.ellipse([left_up_point, right_down_point], outline="black", fill="red")

    # Draw the polygon only when there are at least 3 points; with no points
    # the centroid below would divide by zero.
    if len(state) >= 3:
        center = (
            sum(p[0] for p in state) / len(state),
            sum(p[1] for p in state) / len(state),
        )
        # Sort the points by angle around the centroid so the outline does not cross itself.
        sorted_points = sorted(state, key=lambda p: calculate_angle(p, center))
        draw.polygon(sorted_points, outline="red", fill=None, width=line_width)
    elif len(state) == 2:
        draw.line([state[0], state[1]], fill="red", width=line_width)
    return img, state
def calculate_angle(point, center):
    """Return the angle in radians of *point* as seen from *center*.

    Both arguments are ``(x, y)`` pairs; the result is ``atan2(dy, dx)`` and
    therefore lies in ``[-pi, pi]``.
    """
    dy = point[1] - center[1]
    dx = point[0] - center[0]
    return math.atan2(dy, dx)
# Function that takes the clicked coordinates and draws the polygon.
def draw_polygon_on_image(image, evt: gr.SelectData, state):
    """Record the clicked point and draw every selected point on a preview.

    The click position from ``evt.index`` is appended to ``state`` in place.
    All points are drawn as red dots; two points are joined by a line and
    three or more by a polygon outline whose vertices are ordered by angle
    around their centroid.

    Returns:
        ``(preview, state)``: the annotated PIL image and the updated list.
    """
    canvas = Image.fromarray(image)

    # Save the clicked coordinate.
    state.append((evt.index[0], evt.index[1]))

    # Marker size grows with the image so it stays visible on large inputs.
    pixel_count = image.shape[0] * image.shape[1]
    radius = max(5, round(pixel_count ** 0.5 / 120))
    stroke = round(radius / 2)

    pen = ImageDraw.Draw(canvas)
    for marker in state:
        top_left = (marker[0] - radius, marker[1] - radius)
        bottom_right = (marker[0] + radius, marker[1] + radius)
        pen.ellipse([top_left, bottom_right], outline="black", fill="red")

    count = len(state)
    if count >= 3:
        # Draw the polygon only when there are at least 3 points.
        centroid = (
            sum(p[0] for p in state) / count,
            sum(p[1] for p in state) / count,
        )
        # Sort the points by angle around the centroid.
        ordered = sorted(state, key=lambda p: calculate_angle(p, centroid))
        pen.polygon(ordered, outline="red", fill=None, width=stroke)
    elif count == 2:
        pen.line([state[0], state[1]], fill="red", width=stroke)
    return canvas, state
def sort_corners(corners):
    """Order four ``(x, y)`` corners as top-left, top-right, bottom-right, bottom-left.

    The two points with the smallest y form the top pair and the other two the
    bottom pair; within each pair the smaller x is the left corner.

    Raises:
        ValueError: if *corners* does not contain exactly four points.
    """
    if len(corners) != 4:
        raise ValueError("Input should contain exactly four coordinates.")

    # Sort by y to split the corners into a top pair and a bottom pair.
    by_row = sorted(corners, key=lambda pt: pt[1])
    upper, lower = by_row[:2], by_row[2:]

    # Within each pair, the smaller x is the left corner.
    upper.sort(key=lambda pt: pt[0])
    lower.sort(key=lambda pt: pt[0])
    top_left, top_right = upper
    bottom_left, bottom_right = lower
    return top_left, top_right, bottom_right, bottom_left
def convert(image, state):
    """Produce the rectified output image.

    With exactly four selected points the quadrilateral they describe is
    warped onto the full image frame with ``unwarp``. With any other number
    of points (including more than four) the DocScanner model rectifies the
    image automatically.

    Args:
        image: Input image array, or ``None`` when nothing is loaded.
        state: List of selected ``(x, y)`` points.

    Returns:
        The rectified image, or ``None`` when *image* is ``None``.
    """
    if image is None:
        return None

    if len(state) == 4:
        h, w = image.shape[:2]
        corners = list(sort_corners(state))
        src = np.array(corners).astype(np.float32)
        # Target corners in the same order as sort_corners: TL, TR, BR, BL.
        dst = np.float32([
            (0, 0),
            (w - 1, 0),
            (w - 1, h - 1),
            (0, h - 1),
        ])
        out_image, _ = unwarp(image, src, dst)
        return out_image

    # Any other point count falls back to automatic rectification.
    docscanner = get_docscanner()
    out_image, _ = docscanner_rec(image, docscanner, cuda)
    # Reverse the channel axis (presumably BGR -> RGB; confirm against docscanner_rec).
    return out_image[:, :, ::-1]
# Custom CSS handed to gr.Blocks below; styles components tagged with the
# "image-container" element class (the input image).
css = """
.image-container {
padding: 20px;
background-color: #f0f0f0;
}
"""
# Build the interface inside a Gradio Blocks context.
with gr.Blocks(css=css) as demo:
    # List of selected (x, y) corner points shared between the callbacks.
    state = gr.State([])
    with gr.Row():
        # Left column: input image and the action buttons.
        with gr.Column():
            text = gr.Textbox("입력 이미지(코너를 클릭하세요)", show_label=False)
            image_input = gr.Image(
                show_label=False,
                interactive=True,
                elem_classes="image-container",
                type="numpy",
            )
            clear_button = gr.Button("Clear Points")
            cutting_button = gr.Button("Cutting Image(need more than 2 points)")
            rotating_button = gr.Button("Rotate Image(clock wise 90 degree)")
            auto_button = gr.Button("Auto Points detection")
            convert_button = gr.Button("Convert Image")
        # Middle column: preview of the region that will be converted.
        with gr.Column():
            text = gr.Textbox("변환될 영역", show_label=False)
            image_output = gr.Image(show_label=False, type="pil")
            # state_display = gr.Textbox(label="Current State")
            # coordinates_text = gr.Textbox(label="Coordinates", placeholder="Enter coordinates (x, y) for each point")
            # update_coords_button = gr.Button("Update Coordinates")
        # Right column: rectified result.
        with gr.Column():
            text = gr.Textbox("결과 이미지", show_label=False)
            result_image = gr.Image(show_label=False, format="png", type="numpy")

    # Handle click events on the input image.
    image_input.select(
        draw_polygon_on_image,
        inputs=[image_input, state],
        outputs=[image_output, state],
    )
    # Reset the points when the clear button is clicked.
    clear_button.click(
        fn=reset_points,
        inputs=[image_input, state],
        outputs=[image_output, state],
    )
    # Crop the image to the selected points.
    cutting_button.click(
        fn=cutting_image,
        inputs=[image_input, state],
        outputs=[image_input, image_output, state],
    )
    # Rotate the image.
    rotating_button.click(
        fn=rotate_image,
        inputs=[image_input],
        outputs=[image_input, state],
    )
    # Automatic corner detection button.
    auto_button.click(
        fn=auto_point_detect,
        inputs=image_input,
        outputs=[image_output, state],
    )
    # Convert button.
    convert_button.click(
        fn=convert,
        inputs=[image_input, state],
        outputs=result_image,
    )
# True when the environment variables indicate a Hugging Face Spaces runtime.
is_spaces = (
    any(os.getenv(name) for name in ("SPACE_ID", "HF_SPACE_ID"))
    or os.getenv("SYSTEM") == "spaces"
)
# A share link is requested only outside Spaces, and only when GRADIO_SHARE is set.
demo.launch(share=bool(os.getenv("GRADIO_SHARE")) and not is_spaces)