| import cv2
|
| import gradio as gr
|
| import os
|
| from edit_func import *
|
| from TransUnet import Trans_UNet
|
| import TransUnet_Config as config2
|
| from huggingface_hub import hf_hub_download
|
| from googletrans import Translator
|
| import random
|
| import torch.nn as nn
|
| import spaces
|
|
|
# NOTE(review): spaces.GPU is documented as a decorator for functions; confirm
# that decorating the class itself behaves as intended on ZeroGPU hardware.
@spaces.GPU
class DTM(nn.Module):
    """Text-detection ensemble: one Trans_UNet evaluated with several checkpoints.

    The constructor downloads eight checkpoint files from the Hugging Face Hub
    repository ``SS3M/detect-text-model``. ``forward`` loads each checkpoint
    into the single network in turn and returns the mean of the predictions.
    """

    def __init__(self):
        """Build the network on the available device and fetch the checkpoints."""
        super().__init__()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.detect_text_model = Trans_UNet(
            config2.in_channels, config2.adp_channels, config2.out_channels,
            config2.trans_num_layers, config2.trans_num_attn_heads, config2.trans_ffw_channels,
            config2.dropout
        ).to(self.device)
        self.repo_name = 'SS3M/detect-text-model'
        files = [f'detect-text-v3-{i}.pt' for i in range(8)]
        # hf_hub_download returns the local path of each downloaded file.
        self.files = [
            hf_hub_download(repo_id=self.repo_name, filename=file)
            for file in files
        ]

    def forward(self, X):
        """Return the checkpoint-averaged prediction for a batch.

        Args:
            X: 4-D tensor unpacked as ``(N, C, H, W)``.

        Returns:
            Tensor of shape ``(N, 1, H, W)`` allocated by ``torch.zeros`` (so it
            lives on torch's default device), holding the mean model output.
            Assumes the network output broadcasts to that shape — TODO confirm.
        """
        X = X.to(self.device)
        N, C, H, W = X.shape
        result = torch.zeros((N, 1, H, W))
        # NOTE(review): weights are re-read from disk on every call, and this
        # block never switches the network to eval(); confirm callers do so,
        # otherwise the configured dropout stays active during inference.
        for model_path in self.files:
            best_model_state = torch.load(
                model_path,
                weights_only=True,
                map_location=self.device
            )
            self.detect_text_model.load_state_dict(best_model_state)
            # The accumulator is not on self.device when CUDA is used; move the
            # prediction over so the in-place add never mixes devices.
            result += self.detect_text_model(X).to(result.device)
        result /= len(self.files)
        return result
|
|
|
# NOTE(review): spaces.GPU is documented as a decorator for functions; confirm
# that decorating the class itself behaves as intended on ZeroGPU hardware.
@spaces.GPU
class DWBM(nn.Module):
    """Word-balloon detection ensemble: one Trans_UNet, several checkpoints.

    The constructor downloads eight checkpoint files from the Hugging Face Hub
    repository ``SS3M/detect-wordball-model``. ``forward`` loads each
    checkpoint into the single network in turn and averages the predictions.
    """

    def __init__(self):
        """Build the network on the available device and fetch the checkpoints."""
        super().__init__()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.detect_wordball_model = Trans_UNet(
            config2.in_channels, config2.adp_channels, config2.out_channels,
            config2.trans_num_layers, config2.trans_num_attn_heads, config2.trans_ffw_channels,
            config2.dropout
        ).to(self.device)
        self.repo_name = 'SS3M/detect-wordball-model'
        # NOTE(review): the files in the wordball repo use the "detect-text"
        # prefix; confirm this naming is intended.
        files = [f'detect-text-v3-{i}.pt' for i in range(8)]
        # hf_hub_download returns the local path of each downloaded file.
        self.files = [
            hf_hub_download(repo_id=self.repo_name, filename=file)
            for file in files
        ]

    def forward(self, X):
        """Return the checkpoint-averaged prediction for a batch.

        Args:
            X: 4-D tensor unpacked as ``(N, C, H, W)``.

        Returns:
            Tensor of shape ``(N, 1, H, W)`` allocated by ``torch.zeros`` (so it
            lives on torch's default device), holding the mean model output.
            Assumes the network output broadcasts to that shape — TODO confirm.
        """
        X = X.to(self.device)
        N, C, H, W = X.shape
        result = torch.zeros((N, 1, H, W))
        # NOTE(review): weights are re-read from disk on every call, and this
        # block never switches the network to eval(); confirm callers do so,
        # otherwise the configured dropout stays active during inference.
        for model_path in self.files:
            best_model_state = torch.load(
                model_path,
                weights_only=True,
                map_location=self.device
            )
            self.detect_wordball_model.load_state_dict(best_model_state)
            # The accumulator is not on self.device when CUDA is used; move the
            # prediction over so the in-place add never mixes devices.
            result += self.detect_wordball_model(X).to(result.device)
        result /= len(self.files)
        return result
|
|
|
# Module-level singletons shared by every request handler below.
# Constructing the two ensembles downloads their checkpoints at import time.
detect_text_model = DTM()
detect_wordball_model = DWBM()

# googletrans client handed to translate() in down2.
translator = Translator()
|
|
|
def down1(src_img):
    """Run both detectors on the uploaded image and serialize the found boxes.

    Args:
        src_img: RGB image array as delivered by the Gradio image input.

    Returns:
        A 7-tuple: text mask * 255, wordball mask * 255, then four
        '; '-separated strings (box ids, box coordinates, box areas, one random
        RGB colour per box; inner values are joined with ', '), and finally the
        status text 'Xong'.
    """
    # The project helpers receive a BGR image.
    bgr_img = cv2.cvtColor(src_img, cv2.COLOR_RGB2BGR)
    text_msk = create_text_mask(bgr_img, detect_text_model)
    wordball_msk = create_wordball_mask(bgr_img, detect_wordball_model)

    text_positions, areas = get_text_positions(text_msk, text_value=0)
    box_count = len(areas)
    # One random display colour per detected box.
    rgbs = [
        (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
        for _ in range(box_count)
    ]

    idx = '; '.join(str(i) for i in range(box_count))
    text_positions = '; '.join(', '.join(str(i) for i in pos) for pos in text_positions)
    areas = '; '.join(str(i) for i in areas)
    rgbs = '; '.join(', '.join(str(i) for i in rgb) for rgb in rgbs)
    # The former trailing BGR->RGB conversion was removed: its result was never used.
    return text_msk*255, wordball_msk*255, idx, text_positions, areas, rgbs, 'Xong'
|
|
|
def idx_txt_change(src_img, idx_txt, pos_txt, rgb_txt):
    """Draw the selected text boxes on a copy of the source image.

    Args:
        src_img: RGB image array.
        idx_txt: '; '-separated indices of the boxes to draw.
        pos_txt: '; '-separated boxes, each 'min_x, min_y, max_x, max_y'.
        rgb_txt: '; '-separated colours, each 'r, g, b', aligned with pos_txt.

    Returns:
        A new RGB image with a rectangle per selected box. If anything fails
        (missing image, empty or malformed state strings, ...) the original
        ``src_img`` is returned unchanged.
    """
    try:
        canvas = cv2.cvtColor(src_img, cv2.COLOR_RGB2BGR)
        boxes = [
            tuple(int(i) for i in chunk.split(', '))
            for chunk in pos_txt.split('; ')
        ]
        colors = [
            tuple(int(i) for i in chunk.split(', '))
            for chunk in rgb_txt.split('; ')
        ]
        selected = {int(idx) for idx in idx_txt.split('; ')}

        for idx, ((min_x, min_y, max_x, max_y), (r, g, b)) in enumerate(zip(boxes, colors)):
            if idx in selected:
                # The canvas is BGR, so the colour channels are reversed.
                cv2.rectangle(canvas, (min_x, min_y), (max_x, max_y), (b, g, r), thickness=4)
        return cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)
    except Exception:
        # Deliberate best effort: show the plain image rather than an error.
        # Unlike a bare except, KeyboardInterrupt/SystemExit still propagate.
        return src_img
|
|
|
def scale_area_change(min_area, max_area, area_txt):
    """Return the indices of the boxes whose area lies in [min_area, max_area].

    Args:
        min_area: inclusive lower bound.
        max_area: inclusive upper bound.
        area_txt: '; '-separated integer areas.

    Returns:
        '; '-separated indices of the matching areas; '' when nothing matches
        or when ``area_txt`` is empty.
    """
    # An empty state string means there are no boxes; int('') would raise.
    if not area_txt or not area_txt.strip():
        return ''
    areas = [int(area) for area in area_txt.split('; ')]
    return '; '.join(
        str(idx) for idx, area in enumerate(areas)
        if min_area <= area <= max_area
    )
|
|
|
def position_block_change(X, Y, W, H, ID, pos_txt_value):
    """Replace one box in the serialized position list.

    Args:
        X, Y: new top-left corner of the box.
        W, H: new width and height of the box.
        ID: index of the box to replace.
        pos_txt_value: '; '-separated boxes, each 'min_x, min_y, max_x, max_y'.

    Returns:
        The list in the same format with box ``ID`` set to
        ``(X, Y, X + W, Y + H)``; '' when ``pos_txt_value`` is empty.

    Raises:
        ValueError: if a stored box is not four integers.
    """
    # Empty state means there is nothing to edit; int('') would raise.
    if not pos_txt_value or not pos_txt_value.strip():
        return ''

    # Number widgets may deliver floats (10.0). Writing '10.0' into the list
    # would break every later int() parse, so coerce to int here.
    left, top = int(round(X)), int(round(Y))
    new_box = (left, top, left + int(round(W)), top + int(round(H)))

    boxes = [
        tuple(int(i) for i in chunk.split(', '))
        for chunk in pos_txt_value.split('; ')
    ]
    updated = []
    for idx, (min_x, min_y, max_x, max_y) in enumerate(boxes):
        updated.append(new_box if idx == ID else (min_x, min_y, max_x, max_y))
    return '; '.join(', '.join(str(i) for i in pos) for pos in updated)
|
|
|
def ID_block_change(ID_value, checkbox_value, ID_txt_value):
    """Add or remove one box id from the serialized selection.

    Args:
        ID_value: id of the box whose checkbox changed.
        checkbox_value: True to select the box, False to deselect it.
        ID_txt_value: '; '-separated selected ids (may be empty).

    Returns:
        The updated selection, sorted ascending, in the same format.
    """
    # Number widgets may deliver floats; storing '0.0' would break later int() parses.
    ID_value = int(ID_value)
    # Skip empty chunks so an empty selection ('') can be extended again.
    selected = {
        int(i) for i in (ID_txt_value or '').split('; ') if i.strip()
    }
    if checkbox_value:
        selected.add(ID_value)
    else:
        selected.discard(ID_value)
    return '; '.join(str(i) for i in sorted(selected))
|
|
|
def down2(src_img_value, txt_mask_value, wordball_mask_value, idx_txt_value, pos_txt_value):
    """Erase the selected text from the image and translate it.

    Args:
        src_img_value: RGB source image array.
        txt_mask_value: text mask image; indexed as [:, :, 0], so it needs a
            channel axis. It is modified in place.
        wordball_mask_value: word-balloon mask image.
        idx_txt_value: '; '-separated indices of the boxes to process.
        pos_txt_value: '; '-separated boxes, each 'min_x, min_y, max_x, max_y'.

    Returns:
        An 8-tuple: the image returned by clear_text, the '; '-joined
        translated texts, default per-text font / size / stroke / pad lists in
        the same '; ' format, a random string stored in the switch textbox,
        and the status text 'Xong'.
    """
    src_img_value = cv2.cvtColor(src_img_value, cv2.COLOR_RGB2BGR)
    idxes = [int(i) for i in idx_txt_value.split('; ')]
    boxes = [tuple(map(int, chunk.split(', '))) for chunk in pos_txt_value.split('; ')]

    # Boxes that were deselected are overwritten with 255 in the text mask.
    for idx, (min_x, min_y, max_x, max_y) in enumerate(boxes):
        if idx in idxes:
            continue
        txt_mask_value[min_y:max_y+1, min_x:max_x+1] = 255
    txt_mask_value = txt_mask_value[:, :, 0].astype(np.uint8)
    non_text_src_img = clear_text(src_img_value, txt_mask_value, wordball_mask_value, text_value=0, non_text_value=255, r=5)

    selected_boxes = [box for idx, box in enumerate(boxes) if idx in idxes]
    list_texts = get_list_texts(src_img_value, selected_boxes)
    list_translated_texts = translate(list_texts, translator)

    # Every translated text starts with the same default rendering settings.
    count = len(list_translated_texts)
    list_fonts = '; '.join(['MTO Astro City.ttf'] * count)
    list_sizes = '; '.join(['20'] * count)
    list_strokes = '; '.join(['3'] * count)
    list_pads = '; '.join(['5'] * count)
    # NOTE(review): a translation containing '; ' would be split apart when
    # this string is parsed again — confirm translate() cannot return one.
    list_translated_texts = '; '.join(list_translated_texts)
    # A fresh random value written to the switch textbox output.
    switch = str(random.random())

    return non_text_src_img, list_translated_texts, list_fonts, list_sizes, list_strokes, list_pads, switch, 'Xong'
|
|
|
def text_info_change(non_txt_img_value, translated_txt_value, pos_txt_value, idx_txt_value, font_txt_value, size_txt_value, stroke_txt_value, pad_txt_value):
    """Render the translated texts onto a copy of the text-free image.

    All ``*_txt_value`` arguments are '; '-separated lists; each box in
    ``pos_txt_value`` is 'min_x, min_y, max_x, max_y'. Only the boxes whose
    index appears in ``idx_txt_value`` are passed on.

    Returns:
        Whatever insert_text returns for the copied image.
    """
    def as_ints(serialized):
        # Parse a '; '-separated list of integers.
        return [int(i) for i in serialized.split('; ')]

    canvas = non_txt_img_value.copy()
    idxes = as_ints(idx_txt_value)

    texts = translated_txt_value.split('; ')
    boxes = [
        tuple(map(int, pos.split(', ')))
        for idx, pos in enumerate(pos_txt_value.split('; '))
        if idx in idxes
    ]
    fonts = font_txt_value.split('; ')
    sizes = as_ints(size_txt_value)
    pads = as_ints(pad_txt_value)
    strokes = as_ints(stroke_txt_value)

    return insert_text(canvas, texts, boxes, font=fonts, font_size=sizes, pad=pads, stroke=strokes)
|
|
|
def value2_change(value, ID2_value, txt_value):
    """Replace one entry of a '; '-separated list.

    Args:
        value: new value for the entry; converted with ``str``. Floats are
            rounded to the nearest int first.
        ID2_value: index of the entry to replace.
        txt_value: '; '-separated list.

    Returns:
        The list in the same format with entry ``ID2_value`` replaced; the
        list is returned unchanged when the index matches no entry.
    """
    # Number widgets may deliver floats (25.0). The numeric lists are parsed
    # with int() elsewhere in this file, and int('25.0') raises, so store ints.
    if isinstance(value, float):
        value = int(round(value))
    replacement = str(value)
    # NOTE: a value containing '; ' would be read back as several entries.
    return '; '.join(
        replacement if idx == ID2_value else str(text)
        for idx, text in enumerate(txt_value.split('; '))
    )
|
|
|
|
|
with gr.Blocks() as demo:
    # NOTE(review): the nesting of rows/columns below was reconstructed from a
    # listing whose indentation was lost — confirm against the intended layout.

    src_img = gr.Image(type="numpy", label="Upload Image")

    down_bttn_1 = gr.Button("↓", elem_classes="arrow-button")

    with gr.Row():
        txt_mask = gr.Image(type="numpy", label="Upload Image", visible=True)
        wordball_mask = gr.Image(type="numpy", label="Upload Image", visible=True)
    complete = gr.Textbox()
    # Hidden textboxes hold the detection results as '; '-separated strings.
    with gr.Row():
        idx_txt = gr.Textbox(label='ID', interactive=False, visible=False)
        pos_txt = gr.Textbox(label='Pos', interactive=False, visible=False)
        area_txt = gr.Textbox(label='Area', interactive=False, visible=False)
        rgb_txt = gr.Textbox(label='rgb', interactive=False, visible=False)
    with gr.Row():
        boxed_txt_img = gr.Image(type="numpy", label="Upload Image")
        with gr.Column() as down_1_column:
            @gr.render(inputs=[pos_txt, rgb_txt], triggers=[rgb_txt.change])
            def create_box(pos_txt_value, rgb_txt_value):
                """Build one editor group (checkbox + X/Y/W/H) per detected box."""
                # Nothing detected yet: render no editors instead of failing on int('').
                if not pos_txt_value or not rgb_txt_value:
                    return
                text_positions = [
                    tuple(int(i) for i in chunk.split(', '))
                    for chunk in pos_txt_value.split('; ')
                ]
                rgbs = [
                    tuple(int(i) for i in chunk.split(', '))
                    for chunk in rgb_txt_value.split('; ')
                ]
                redraw_inputs = [src_img, idx_txt, pos_txt, rgb_txt]

                for idx, (min_x, min_y, max_x, max_y) in enumerate(text_positions):
                    r, g, b = rgbs[idx]
                    with gr.Group():
                        with gr.Row():
                            gr.Markdown(
                                "\n"
                                '<div style="margin-left: 20px; display: flex; align-items: center;">\n'
                                f'<div style="width: 10px; height: 10px; background-color: rgb({r}, {g}, {b}); margin-right: 5px;"></div>\n'
                                f'<span style="font-size: 20px;">Textbox {idx+1}</span>\n'
                                '</div>\n'
                            )
                            checkbox = gr.Checkbox(value=True, label='', min_width=50, interactive=True)
                        with gr.Row():
                            # precision=0 makes the widgets deliver ints; the
                            # serialized lists are parsed with int() elsewhere.
                            X = gr.Number(label="X", value=min_x, interactive=True, precision=0)
                            Y = gr.Number(label="Y", value=min_y, interactive=True, precision=0)
                            W = gr.Number(label="W", value=max_x-min_x, interactive=True, precision=0)
                            H = gr.Number(label="H", value=max_y-min_y, interactive=True, precision=0)
                            ID = gr.Number(label="ID", value=idx, interactive=True, visible=False, precision=0)

                    # Toggling a box updates the selection, then redraws the preview.
                    checkbox.change(
                        fn=ID_block_change,
                        inputs=[ID, checkbox, idx_txt],
                        outputs=idx_txt,
                        show_progress=False
                    ).then(
                        fn=idx_txt_change,
                        inputs=redraw_inputs,
                        outputs=boxed_txt_img,
                    )
                    # Editing any coordinate rewrites this box, then redraws the preview.
                    for coordinate in (X, Y, W, H):
                        coordinate.change(
                            fn=position_block_change,
                            inputs=[X, Y, W, H, ID, pos_txt],
                            outputs=pos_txt,
                            show_progress=False
                        ).then(
                            fn=idx_txt_change,
                            inputs=redraw_inputs,
                            outputs=boxed_txt_img,
                            show_progress=False
                        )
    down_bttn_2 = gr.Button("↓", elem_classes="arrow-button")

    non_txt_img = gr.Image(type="numpy", label="Upload Image", visible=False)
    complete2 = gr.Textbox()
    # Hidden textboxes hold the per-text rendering settings as '; '-separated strings.
    with gr.Row():
        translated_txt = gr.Textbox(label='translated', interactive=False, visible=False)
        font_txt = gr.Textbox(label='font', interactive=False, visible=False)
        size_txt = gr.Textbox(label='size', interactive=False, visible=False)
        stroke_txt = gr.Textbox(label='stroke', interactive=False, visible=False)
        pad_txt = gr.Textbox(label='pad', interactive=False, visible=False)
        # Re-rendering of the editors below is triggered by changes of this value.
        switch_txt = gr.Textbox(label='switch', value='1', interactive=False, visible=False)
    with gr.Row():
        boxed_inserted_non_txt_img = gr.Image(type="numpy", label="Upload Image")
        with gr.Column():
            @gr.render(inputs=[translated_txt, font_txt, size_txt, stroke_txt, pad_txt], triggers=[switch_txt.change])
            def create_box2(translated_txt_value, font_txt_value, size_txt_value, stroke_txt_value, pad_txt_value):
                """Build one editor group (text, font, size, stroke, pad) per translated text."""
                # No texts to edit: render nothing instead of failing on int('').
                if not font_txt_value:
                    return
                translated_txt_value = translated_txt_value.split('; ')
                font_txt_value = font_txt_value.split('; ')
                size_txt_value = size_txt_value.split('; ')
                stroke_txt_value = stroke_txt_value.split('; ')
                pad_txt_value = pad_txt_value.split('; ')

                # Read the font directory once per render rather than once per box.
                font_choices = os.listdir('MTO Font')
                reinsert_inputs = [non_txt_img, translated_txt, pos_txt, idx_txt, font_txt, size_txt, stroke_txt, pad_txt]

                for idx in range(len(font_txt_value)):
                    with gr.Group():
                        gr.Markdown(
                            "\n"
                            '<div style="margin-left: 20px; display: flex; align-items: center;">\n'
                            '<div style="width: 10px; height: 10px; background-color: rgb(255, 255, 255); margin-right: 5px;"></div>\n'
                            f'<span style="font-size: 20px;">Text box {idx}</span>\n'
                            '</div>\n'
                        )
                        translated_text_box = gr.Textbox(label="Translate", value=translated_txt_value[idx], interactive=True)
                        with gr.Row():
                            font = gr.Dropdown(choices=font_choices, label="Phông chữ", value=font_txt_value[idx], interactive=True, scale=7)
                            # precision=0 makes the widgets deliver ints; the
                            # serialized lists are parsed with int() elsewhere.
                            size = gr.Number(label="Size", value=int(size_txt_value[idx]), interactive=True, minimum=1, precision=0)
                            stroke = gr.Number(label="Stroke", value=int(stroke_txt_value[idx]), interactive=True, minimum=0, maximum=5, precision=0)
                            pad = gr.Number(label="Pad", value=int(pad_txt_value[idx]), interactive=True, minimum=1, maximum=10, precision=0)
                            ID2 = gr.Number(label="ID", value=int(idx), interactive=True, visible=False, precision=0)

                    # (listener, edited component, hidden list it updates)
                    bindings = (
                        (translated_text_box.submit, translated_text_box, translated_txt),
                        (font.change, font, font_txt),
                        (size.change, size, size_txt),
                        (stroke.change, stroke, stroke_txt),
                        (pad.change, pad, pad_txt),
                    )
                    # Each edit updates its hidden list, then re-renders the image.
                    for listen, component, store in bindings:
                        listen(
                            fn=value2_change,
                            inputs=[component, ID2, store],
                            outputs=store,
                            show_progress=False
                        ).then(
                            fn=text_info_change,
                            inputs=reinsert_inputs,
                            outputs=boxed_inserted_non_txt_img,
                        )

    demo.css = """
    .arrow-button {
        font-size: 40px; /* Kích thước font */
    }
    .group-elem {
        height: 70px;
    }
    """

    # Step 1: detect text and word balloons, fill the hidden state.
    down_bttn_1.click(
        fn=down1,
        inputs=src_img,
        outputs=[txt_mask, wordball_mask, idx_txt, pos_txt, area_txt, rgb_txt, complete],
    )
    # Step 2: erase + translate the selected boxes, then draw the translations.
    down_bttn_2.click(
        fn=down2,
        inputs=[src_img, txt_mask, wordball_mask, idx_txt, pos_txt],
        outputs=[non_txt_img, translated_txt, font_txt, size_txt, stroke_txt, pad_txt, switch_txt, complete2],
    ).then(
        fn=text_info_change,
        inputs=[non_txt_img, translated_txt, pos_txt, idx_txt, font_txt, size_txt, stroke_txt, pad_txt],
        outputs=boxed_inserted_non_txt_img,
    )

demo.launch()
|
|
|