x444 committed on
Commit
6483239
·
1 Parent(s): 4bdd346

yiyang 722

Browse files
Files changed (1) hide show
  1. data_real_world/segment.py +193 -0
data_real_world/segment.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
import re
from collections import defaultdict

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image, ImageDraw
from segment_anything import sam_model_registry, SamPredictor

11
def concat_image_variations_with_base(
    base_folder: str,
    variation_folder: str,
    output_folder: str,
    image_size: int = 512,
    stroke_width: int = 6
):
    """
    Concatenate each base image with its variation images in a single row.

    Variation files must be named ``<baseID>_<idx>_<suffix>.png``; the base
    image's file name must start with ``<baseID>``. Each tile gets a colored
    outline so the source of every column is visible at a glance:

    - base image: black
    - suffix '0' -> green, '1' -> blue, '2' -> red (anything else: black)

    Parameters:
        base_folder: folder containing the original (base) images.
        variation_folder: folder containing the variation images.
        output_folder: destination for the ``<baseID>_concat.png`` outputs
            (created if missing).
        image_size: side length in pixels each tile is resized to.
        stroke_width: width of the outline drawn around each tile.
    """
    os.makedirs(output_folder, exist_ok=True)

    suffix_to_color = {
        '0': 'green',
        '1': 'blue',
        '2': 'red'
    }

    # Group variation images by their base ID (the first numeric field).
    # Requires `import re` at module level (was missing in the original file).
    groups = defaultdict(list)
    for fname in sorted(os.listdir(variation_folder)):
        if fname.endswith('.png'):
            match = re.match(r"(\d+)_\d+_(\d)\.png", fname)
            if match:
                groups[match.group(1)].append(fname)

    for base_id, variations in groups.items():
        images = []

        # Load the base image; skip this group entirely if it is missing.
        base_candidates = [f for f in os.listdir(base_folder) if f.startswith(base_id)]
        if not base_candidates:
            print(f"Base image not found for ID {base_id}")
            continue
        base_img_path = os.path.join(base_folder, base_candidates[0])
        base_img = Image.open(base_img_path).convert("RGBA").resize((image_size, image_size))
        draw = ImageDraw.Draw(base_img)
        draw.rectangle([0, 0, image_size - 1, image_size - 1], outline="black", width=stroke_width)
        images.append(base_img)

        # Add variation images, ordered by their middle (index) field.
        for var in sorted(variations, key=lambda x: int(x.split('_')[1])):
            path = os.path.join(variation_folder, var)
            img = Image.open(path).convert("RGBA").resize((image_size, image_size))
            draw = ImageDraw.Draw(img)
            suffix = var.split('_')[-1].split('.')[0]
            color = suffix_to_color.get(suffix, "black")
            draw.rectangle([0, 0, image_size - 1, image_size - 1], outline=color, width=stroke_width)
            images.append(img)

        # Paste all tiles side by side: base first, then variations.
        total_width = image_size * len(images)
        concat_img = Image.new("RGBA", (total_width, image_size))
        for i, img in enumerate(images):
            concat_img.paste(img, (i * image_size, 0))

        output_path = os.path.join(output_folder, f"{base_id}_concat.png")
        concat_img.save(output_path)
        print(f"Saved: {output_path}")
76
+
77
+
78
+
79
# Load the SAM model
def load_sam_model(model_type="vit_h", checkpoint_path="sam_vit_h_4b8939.pth"):
    """Build a SAM predictor from a registered model type and checkpoint.

    Parameters:
        model_type: key into ``sam_model_registry`` (e.g. "vit_h").
        checkpoint_path: path to the SAM checkpoint weights file.

    Returns:
        A ``SamPredictor`` wrapping the model, placed on GPU when available.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = sam_model_registry[model_type](checkpoint=checkpoint_path)
    model.to(device)
    return SamPredictor(model)
85
+
86
# Draw bounding box and label
def draw_box(img, box, label=None, color="green", output_path=None):
    """Annotate *img* in place with a rectangle and an optional text label.

    Parameters:
        img: PIL image to draw on (modified in place).
        box: (x1, y1, x2, y2) pixel coordinates of the rectangle.
        label: optional text placed just inside the box's top-left corner.
        color: outline and label color.
        output_path: when truthy, the annotated image is also saved there.

    Returns:
        The annotated image (the same object that was passed in).
    """
    canvas = ImageDraw.Draw(img)
    canvas.rectangle(box, outline=color, width=3)
    if label:
        canvas.text((box[0] + 5, box[1] + 5), label, fill=color)
    if output_path:
        img.save(output_path)
    return img
95
+
96
def yolo_to_xyxy(boxes, image_width, image_height):
    """
    Convert YOLO-format boxes to absolute corner coordinates.

    Parameters:
        boxes (list of list): each item is [label, cx, cy, w, h] with the
            center and size expressed relative to the image dimensions.
        image_width (int): image width in pixels.
        image_height (int): image height in pixels.

    Returns:
        List of [label, x1, y1, x2, y2] with integer pixel coordinates
        (truncated toward zero, matching ``int()``).
    """
    converted = []
    for label, cx, cy, w, h in boxes:
        # Scale relative coords to pixels, then expand around the center.
        center_x = cx * image_width
        center_y = cy * image_height
        half_w = w * image_width / 2
        half_h = h * image_height / 2
        converted.append([
            int(label),
            int(center_x - half_w),
            int(center_y - half_h),
            int(center_x + half_w),
            int(center_y + half_h),
        ])
    return converted
121
+
122
+
123
# Main logic
def segment(image_np, box_coords, predictor):
    """Cut out the object inside *box_coords* using SAM.

    Parameters:
        image_np: numpy image array; must already be registered on
            *predictor* via ``predictor.set_image`` by the caller.
        box_coords: (x1, y1, x2, y2) pixel box prompting the mask.
        predictor: a ``SamPredictor`` with the image embedding computed.

    Returns:
        PIL image where every pixel outside the predicted mask is white.
    """
    # SAM expects the box prompt as a (1, 4) array in [x1, y1, x2, y2] order.
    box_array = np.array([box_coords])

    # Single best mask for the box prompt; scores/logits are not needed.
    masks, _, _ = predictor.predict(box=box_array, multimask_output=False)
    best_mask = masks[0]

    # White out everything the mask does not cover.
    cutout = image_np.copy()
    cutout[~best_mask] = [255, 255, 255]  # white background where mask is off

    return Image.fromarray(cutout)
143
+
144
# ============================ single image ============================
# image_path = "A_images_resized/0010.png"        # Replace with your image
# checkpoint_path = "sam_vit_h_4b8939.pth"        # Replace with your checkpoint
# box_coords = (100, 150, 300, 350)               # Target box (x1, y1, x2, y2)
# predictor = load_sam_model(checkpoint_path=checkpoint_path)
# image_np = np.array(Image.open(image_path).convert("RGB"))
# predictor.set_image(image_np)
# result_img = segment(image_np, box_coords, predictor)


# ============================ multiple image ============================
image_folder_path = "A_images_resized"
checkpoint_path = "sam_vit_h_4b8939.pth"  # Replace with your model checkpoint
output_dir = "layer_image"
# Fix: PIL's save() raises FileNotFoundError if the target folder is missing.
os.makedirs(output_dir, exist_ok=True)

predictor = load_sam_model(checkpoint_path=checkpoint_path)

for img_name in os.listdir(image_folder_path):
    # Fix: skip non-PNG entries — the label-file lookup below assumes a
    # ".png" suffix and would silently produce a wrong path otherwise.
    if not img_name.endswith(".png"):
        continue

    # Register the image with SAM once; segment() reuses the embedding per box.
    image_np = np.array(Image.open(os.path.join(image_folder_path, img_name)).convert("RGB"))
    predictor.set_image(image_np)

    stem = img_name.removesuffix(".png")
    # Matching YOLO label file: one "label cx cy w h" line per box.
    with open(f"A_labels_resized/{stem}.txt", "r") as f:
        boxes = [list(map(float, line.strip().split())) for line in f]

    # NOTE(review): assumes all images are 1024x1024 — confirm against the
    # resize step that produced A_images_resized.
    for idx, (label, x1, y1, x2, y2) in enumerate(yolo_to_xyxy(boxes, 1024, 1024)):
        result_img = segment(image_np, (x1, y1, x2, y2), predictor)
        result_img.save(os.path.join(output_dir, f"{stem}_{idx}_{label}.png"))


# === view both original and layered data ===
# concat_image_variations_with_base(
#     base_folder="A_images_resized",
#     variation_folder="layer_image",
#     output_folder="view_image",
#     image_size=512,
#     stroke_width=6,
# )