| import torch |
| import torch.nn as nn |
|
|
| import torch |
| import torch.nn.functional as F |
| import torch.nn as nn |
| from torchvision.transforms.functional import to_pil_image |
| from torchvision.transforms import Resize |
| import cv2 |
| import numpy as np |
| from torchcam.utils import overlay_mask |
| import gradio as gr |
| import os |
|
|
class GeoPrior(nn.Module):
    """Geometry prior mixing 2D spatial distance and depth difference.

    For an ``h x w`` token grid it builds pairwise ``(hw, hw)`` masks: the
    Manhattan distance between grid positions and the absolute difference of
    the (resized) depth values. It then mixes them with one decay rate per head.

    Args:
        embed_dim: total embedding size; ``embed_dim // num_heads // 2``
            frequencies are generated for the sin/cos tables.
        num_heads: number of heads, one decay rate each.
        initial_value: offset of the decay exponent (see ``decay`` below).
        heads_range: spread of the decay exponent across heads.
    """

    def __init__(self, embed_dim=128, num_heads=4, initial_value=2, heads_range=6):
        super().__init__()
        # Frequencies 1 / 10000^t for t in [0, 1], each repeated twice.
        angle = 1.0 / (10000 ** torch.linspace(0, 1, embed_dim // num_heads // 2))
        angle = angle.unsqueeze(-1).repeat(1, 2).flatten()
        self.initial_value = initial_value
        self.heads_range = heads_range
        self.num_heads = num_heads
        # Per-head log(1 - 2^-(initial_value + heads_range * h / num_heads));
        # negative whenever the exponent is positive (true for the defaults).
        decay = torch.log(
            1 - 2 ** (-initial_value - heads_range * torch.arange(num_heads, dtype=torch.float) / num_heads)
        )
        self.register_buffer('angle', angle)
        self.register_buffer('decay', decay)

    def generate_pos_decay(self, H: int, W: int):
        """Return the ``(H*W, H*W)`` Manhattan distance between grid positions."""
        index_h = torch.arange(H).to(self.decay)
        index_w = torch.arange(W).to(self.decay)
        # 'ij' is torch's historical default; passing it explicitly only
        # silences the deprecation warning.
        grid = torch.meshgrid(index_h, index_w, indexing='ij')
        grid = torch.stack(grid, dim=-1).reshape(H * W, 2)
        diff = grid[:, None, :] - grid[None, :, :]
        return diff.abs().sum(dim=-1)

    def generate_2d_depth_decay(self, H: int, W: int, depth_grid):
        """Return pairwise absolute depth differences of shape ``(B, 1, HW, HW)``.

        ``H`` and ``W`` are accepted for interface compatibility only; the grid
        size is taken from ``depth_grid``, which must be ``(B, 1, H, W)``
        (the reshape below requires a single channel).
        """
        if not depth_grid.is_floating_point():
            # Integer (e.g. uint8) subtraction would wrap around.
            depth_grid = depth_grid.float()
        B, _, H, W = depth_grid.shape
        grid_d = depth_grid.reshape(B, H * W, 1)
        mask_d = grid_d[:, :, None, :] - grid_d[:, None, :, :]
        mask_d = mask_d.abs().sum(dim=-1)
        return mask_d.unsqueeze(1)

    def forward(self, slen, depth_map, activate_recurrent=False, chunkwise_recurrent=False):
        """Compute sin/cos tables and distance masks for an ``slen`` token grid.

        Args:
            slen: ``(h, w)`` grid size (``h * w`` tokens).
            depth_map: ``(B, 1, H', W')`` depth tensor, resized to ``slen``.
            activate_recurrent, chunkwise_recurrent: unused; recurrent modes
                are not implemented.

        Returns:
            ``((sin, cos), mask_d, mask_1, mask_sum)``:
            ``sin``/``cos`` of shape ``(h, w, len(angle))``;
            ``mask_d`` the ``(B, 1, hw, hw)`` depth difference;
            ``mask_1`` the ``(hw, hw)`` Manhattan distance;
            ``mask_sum`` = ``(0.85 * mask_1 + 0.15 * mask_d) * decay``,
            broadcast to ``(B, num_heads, hw, hw)``.
        """
        # Cast BEFORE resizing: bilinear interpolation of an integer tensor
        # rounds the result (and has no uint8 kernel on some torch versions).
        depth_map = F.interpolate(depth_map.float(), size=slen, mode='bilinear', align_corners=False)

        index = torch.arange(slen[0] * slen[1]).to(self.decay)
        phase = index[:, None] * self.angle[None, :]
        sin = torch.sin(phase).reshape(slen[0], slen[1], -1)
        cos = torch.cos(phase).reshape(slen[0], slen[1], -1)

        mask_1 = self.generate_pos_decay(slen[0], slen[1])
        mask_d = self.generate_2d_depth_decay(slen[0], slen[1], depth_map)

        # Combine on the depth mask's device instead of forcing the CPU, which
        # raised a device-mismatch error whenever depth_map lived on a GPU.
        device = mask_d.device
        mask_sum = (0.85 * mask_1.to(device) + 0.15 * mask_d) * self.decay[:, None, None].to(device)
        return ((sin, cos), mask_d, mask_1, mask_sum)
|
|
def fangda(mask, in_size=(480//20, 640//20), out_size=(480, 640)):
    """Nearest-neighbour upscale of a coarse ``in_size`` mask to ``out_size``.

    ("fangda" is pinyin for "enlarge".) Each input cell is replicated into an
    ``(out_h // in_h) x (out_w // in_w)`` block. If ``out_size`` is not an exact
    multiple of ``in_size``, the leftover bottom rows / right columns stay zero.

    Args:
        mask: 2-D tensor (or array-like) with at least ``in_size`` rows and
            columns; only its top-left ``in_size`` region is used.
        in_size: ``(rows, cols)`` of the coarse grid.
        out_size: ``(rows, cols)`` of the result.

    Returns:
        A float32 tensor of shape ``out_size``.
    """
    in_h, in_w = in_size
    ratio_h, ratio_w = out_size[0] // in_h, out_size[1] // in_w
    new_mask = torch.zeros(out_size)
    src = torch.as_tensor(mask)[:in_h, :in_w]
    # Replicate every cell along both axes in one vectorised step instead of a
    # Python loop over each of the in_h * in_w cells.
    upscaled = src.repeat_interleave(ratio_h, dim=0).repeat_interleave(ratio_w, dim=1)
    new_mask[:in_h * ratio_h, :in_w * ratio_w] = upscaled
    return new_mask
|
|
def put_mask(image, mask, color_rgb=None, border_mask=False, color_temp='jet', num_c='', beta=2, fixed_num=None):
    """Overlay *mask* as a colour heat-map on *image* and return the blend.

    Both inputs are resized to 640x480. The mask is normalised and inverted
    (``1 - mask / scale``), so its smallest values become the largest overlay
    values. The blend is produced by torchcam's ``overlay_mask`` with
    ``alpha=0.4``.

    Args:
        image: HxWx3 array as read by ``cv2.imread``; cast to uint8 for PIL.
        mask: 2-D torch tensor. Assumed float32 so that ``to_pil_image``
            accepts the result (mode "F") — TODO confirm for other dtypes.
        color_temp: colormap name forwarded to ``overlay_mask``.
        fixed_num: if not ``None``, normalise by the fixed value 255 instead
            of by the mask maximum (the value of ``fixed_num`` itself is unused).
        color_rgb, border_mask, num_c, beta: unused; kept so existing callers
            keep working.

    Returns:
        The blended image as a numpy array.
    """
    # detach/cpu so GPU or grad-tracking tensors are accepted too.
    mask = mask.detach().cpu().numpy()
    image = cv2.resize(image, dsize=(640, 480), interpolation=cv2.INTER_LINEAR)
    mask = cv2.resize(mask, dsize=(640, 480), interpolation=cv2.INTER_LINEAR)
    if fixed_num is not None:
        mask = 1 - mask / 255
    else:
        peak = np.max(mask)
        # An all-zero mask would divide 0/0 and feed NaNs into the overlay.
        mask = 1 - mask / peak if peak != 0 else np.ones_like(mask)

    result = overlay_mask(
        to_pil_image(image.astype(np.uint8)),
        to_pil_image(mask),
        colormap=color_temp,
        alpha=0.4,
    )
    return np.array(result)
|
|
|
|
def _minmax_to_255(t):
    """Linearly rescale tensor *t* to [0, 255]; a constant tensor maps to zeros."""
    lo, hi = torch.min(t), torch.max(t)
    span = hi - lo
    if span == 0:
        # Avoid 0/0 -> NaN for a flat mask (e.g. a depth map with one value).
        return torch.zeros_like(t)
    return 255 * (t - lo) / span


def _grid_cell_index(x, y, H, W, cell=20):
    """Return the row-major index of the ``cell``-pixel grid cell containing (x, y).

    Coordinates outside the ``H x W`` grid are clamped to the nearest border cell.
    """
    col = min(max(int(x // cell), 0), W - 1)
    row = min(max(int(y // cell), 0), H - 1)
    return row * W + col


def _imread_or_raise(path, flags):
    """``cv2.imread`` that raises instead of silently returning ``None``."""
    img = cv2.imread(path, flags)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    return img


def visualize_geometry_prior(RGB_path, Depth_path, index_list=None, cmap_list=('jet_r',), x=0, y=0):
    """Render the geometry prior of one query token over the RGB image.

    The images are handled at 640x480, and the token grid is 24x32 (20-pixel
    cells). The query token is the cell containing pixel ``(x, y)``. Its rows of
    the spatial-distance mask and of the depth-difference mask (from
    ``GeoPrior``) are each rescaled to [0, 255]. They are then blended as
    ``0.55 * spatial + 0.45 * depth``, upscaled, and overlaid on the RGB image.

    Args:
        RGB_path: path of the colour image.
        Depth_path: path of the depth image (read as grayscale).
        index_list: ignored; the query token is always derived from ``(x, y)``
            (the original overwrote it too). Kept for backward compatibility.
        cmap_list: colormap names; only the last one is used for the result.
        x, y: query pixel in 640x480 coordinates.

    Returns:
        The overlay as a uint8 array (channel order inherited from ``cv2``, i.e. BGR input).

    Raises:
        FileNotFoundError: if either image cannot be read.
        ValueError: if ``cmap_list`` is empty.
    """
    H = 480 // 20
    W = 640 // 20
    if not cmap_list:
        raise ValueError("cmap_list must contain at least one colormap name")
    color_temp = cmap_list[-1]
    # Same cell the UI reports (row * W + col). The old "(y//20 + 1) * 32"
    # picked the row below and ran past the grid for y >= 460.
    index = _grid_cell_index(x, y, H, W)

    depth = _imread_or_raise(Depth_path, cv2.IMREAD_GRAYSCALE)
    depth = cv2.resize(depth, dsize=(W, H), interpolation=cv2.INTER_LINEAR)
    grid_d = torch.from_numpy(depth).float().reshape(1, 1, H, W)

    img = _imread_or_raise(RGB_path, cv2.IMREAD_COLOR)
    img = cv2.resize(img, dsize=(640, 480), interpolation=cv2.INTER_LINEAR)

    _, depth_mask, spatial_mask, _ = GeoPrior()((H, W), grid_d)

    # Row `index` of each (HW, HW) mask = relation of the query to every token.
    temp_mask_d = depth_mask[0, 0, index, :].reshape(H, W).cpu()
    temp_mask = spatial_mask[index, :].reshape(H, W).cpu()
    # Row-wise L2 normalisation of the depth term before rescaling.
    temp_mask_d = F.normalize(temp_mask_d, p=2.0, dim=1, eps=1e-12)
    temp_mask_d = _minmax_to_255(temp_mask_d)
    temp_mask = _minmax_to_255(temp_mask)

    gama = 0.55
    fused = gama * temp_mask + (1 - gama) * temp_mask_d
    result = put_mask(img, fangda(fused), color_temp=color_temp)
    return result.astype(np.uint8)
|
|
| |
def process_images(rgb_image, depth_image, x=0, y=0):
    """Process uploaded images and return the visualisation.

    Args:
        rgb_image: RGB image array uploaded through gradio.
        depth_image: depth image array uploaded through gradio.
        x, y: query pixel in 640x480 coordinates. These were previously
            referenced without being defined, so every call failed with
            NameError.

    Returns:
        The visualisation as an RGB array, or ``None`` if processing failed.
    """
    import tempfile

    # Private temp directory per call: fixed names in the working directory
    # collide when several queued requests run at once, and the directory is
    # removed automatically even on errors.
    with tempfile.TemporaryDirectory() as tmp_dir:
        temp_rgb_path = os.path.join(tmp_dir, "temp_rgb.jpg")
        temp_depth_path = os.path.join(tmp_dir, "temp_depth.png")
        try:
            cv2.imwrite(temp_rgb_path, cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR))
            cv2.imwrite(temp_depth_path, depth_image)
            result = visualize_geometry_prior(temp_rgb_path, temp_depth_path, x=x, y=y)
            return cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
        except Exception as e:
            # UI boundary: report and return None rather than crash the app.
            print(f"Error during processing: {str(e)}")
            return None
|
|
def draw_star(image, x, y, size=20, color=(255, 0, 0), thickness=2):
    """Draw the outline of a five-pointed star centred at (x, y).

    The five outer vertices are visited in pentagram order (every second
    vertex) and joined into one closed polyline, so the edges cross to form
    the star shape. ``image`` is drawn on in place and also returned.
    """
    # (dx, dy) of each vertex in units of `size`; image y grows downwards.
    unit_offsets = (
        (0.0, -1.0),       # top
        (0.588, 0.809),    # bottom right
        (-0.951, -0.309),  # upper left
        (0.951, -0.309),   # upper right
        (-0.588, 0.809),   # bottom left
    )
    vertices = np.array(
        [[x + size * dx, y + size * dy] for dx, dy in unit_offsets],
        np.int32,
    )
    cv2.polylines(image, [vertices], True, color, thickness)
    return image
|
|
| |
def create_demo():
    """Build the gradio UI for the geometry-prior visualisation.

    Returns:
        The ``gr.Blocks`` app (not yet queued or launched).
    """
    import tempfile

    # Token grid used by visualize_geometry_prior (20-pixel cells on 640x480).
    H, W = 480 // 20, 640 // 20

    def _mark_image(image, x, y, color):
        """Return an RGB copy of *image* with a star at (x, y), or None if no image.

        (x, y) are given in 640x480 space and rescaled to the image's own size,
        because uploads are not necessarily 640x480.
        """
        if image is None:
            return None
        marked = image.copy()
        if marked.ndim == 2:
            marked = cv2.cvtColor(marked, cv2.COLOR_GRAY2RGB)
        elif marked.shape[2] == 4:
            # gradio arrays are RGB(A): drop alpha but keep RGB order (the old
            # RGBA2BGR swapped red and blue in the displayed image).
            marked = cv2.cvtColor(marked, cv2.COLOR_RGBA2RGB)
        img_h, img_w = marked.shape[:2]
        return draw_star(marked, int(x * img_w / 640), int(y * img_h / 480), size=20, color=color)

    def _to_bgr(image):
        """Convert a gradio RGB/RGBA/grayscale array into cv2.imwrite's layout."""
        if image.ndim == 2:
            return image
        if image.shape[2] == 4:
            return cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
        return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    with gr.Blocks() as demo:
        gr.Markdown("# Geometry Prior Visualization Demo")
        gr.Markdown("""
### Instructions:
1. Upload RGB and Depth images
2. Enter X (0-640) and Y (0-480) coordinates
3. A star marker will be shown on the images at the selected position
4. Click "Generate Visualization" to create the visualization
""")

        with gr.Row():
            with gr.Column():
                rgb_input = gr.Image(label="Upload RGB Image")
                depth_input = gr.Image(label="Upload Depth Image", image_mode="L")
                with gr.Row():
                    x_coord = gr.Number(label="X (0-640)", value=160, minimum=0, maximum=640)
                    y_coord = gr.Number(label="Y (0-480)", value=270, minimum=0, maximum=480)
                coordinates_text = gr.Textbox(label="Grid Position and Index", interactive=False)

            with gr.Column():
                marked_rgb = gr.Image(label="Marked RGB Image")
                marked_depth = gr.Image(label="Marked Depth Image")
                output_image = gr.Image(label="Visualization Result")
                status_text = gr.Textbox(label="Status", interactive=False)

        def update_coordinates_and_images(rgb_image, depth_image, x, y):
            """Clamp (x, y), report its grid cell and index, and mark both images."""
            if x is None or y is None:
                return "Please enter both X and Y coordinates.", None, None
            x = max(0, min(640, float(x)))
            y = max(0, min(480, float(y)))

            # Clamp to the last cell so x == 640 / y == 480 neither wrap to the
            # next row nor index past the last token.
            scaled_x = min(int(x * W / 640), W - 1)
            scaled_y = min(int(y * H / 480), H - 1)
            grid_index = scaled_y * W + scaled_x

            # Change events fire before any upload, so missing images are allowed.
            rgb_marked = _mark_image(rgb_image, x, y, color=(255, 0, 0))
            depth_marked = _mark_image(depth_image, x, y, color=(0, 255, 0))

            return (f"Grid position: ({scaled_x}, {scaled_y}), Index: {grid_index}",
                    rgb_marked,
                    depth_marked)

        coord_update_btn = gr.Button("Update Coordinates")
        coord_update_btn.click(
            fn=update_coordinates_and_images,
            inputs=[rgb_input, depth_input, x_coord, y_coord],
            outputs=[coordinates_text, marked_rgb, marked_depth]
        )

        def process_with_status(rgb_image, depth_image, coords_text, x, y):
            """Run the visualisation; return (RGB result or None, status message)."""
            if rgb_image is None or depth_image is None:
                return None, "Error: please upload both an RGB image and a depth image."
            try:
                # Per-request temp directory: fixed file names in the working
                # directory race between queued requests; cleanup is automatic.
                with tempfile.TemporaryDirectory() as tmp_dir:
                    temp_rgb_path = os.path.join(tmp_dir, "temp_rgb.jpg")
                    temp_depth_path = os.path.join(tmp_dir, "temp_depth.png")
                    cv2.imwrite(temp_rgb_path, _to_bgr(rgb_image))
                    cv2.imwrite(temp_depth_path, depth_image)

                    marker = "Index: "
                    if coords_text and marker in coords_text:
                        index_list = [[int(coords_text.rsplit(marker, 1)[-1])]]
                    else:
                        index_list = [[584]]

                    result = visualize_geometry_prior(temp_rgb_path, temp_depth_path,
                                                      index_list=index_list, x=x, y=y)

                result = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
                return result, "Processing completed successfully!"
            except Exception as e:
                # UI boundary: surface the error in the status box.
                return None, f"Error: {str(e)}"

        process_btn = gr.Button("Generate Visualization")
        process_btn.click(
            fn=process_with_status,
            inputs=[rgb_input, depth_input, coordinates_text, x_coord, y_coord],
            outputs=[output_image, status_text]
        )

        x_coord.change(
            fn=update_coordinates_and_images,
            inputs=[rgb_input, depth_input, x_coord, y_coord],
            outputs=[coordinates_text, marked_rgb, marked_depth]
        )
        y_coord.change(
            fn=update_coordinates_and_images,
            inputs=[rgb_input, depth_input, x_coord, y_coord],
            outputs=[coordinates_text, marked_rgb, marked_depth]
        )

        gr.Examples(
            examples=[
                ["assets/example_rgb.jpg", "assets/example_depth.png"]
            ],
            inputs=[rgb_input, depth_input]
        )

    return demo
|
|
| |
if __name__ == "__main__":
    # Build the UI, enable request queuing, then serve on all interfaces.
    # NOTE(review): share=True also publishes a public gradio link.
    demo = create_demo()
    demo.queue()
    launch_options = {"server_name": "0.0.0.0", "share": True, "debug": True}
    demo.launch(**launch_options)
|
|