yusuf-tiryaki commited on
Commit
99c5bcb
·
1 Parent(s): fd1bd17

glb support

Browse files
Files changed (1) hide show
  1. app.py +42 -116
app.py CHANGED
@@ -6,146 +6,72 @@ import plotly.graph_objects as go
6
  from torch_geometric.utils import to_dense_batch
7
  from model import HierarchicalFPSCliffordNet
8
 
9
- # --- SAF PYTORCH C++ YEDEKLERİ (Hugging Face için Özel Çözüm) ---
10
- def knn_pure(x, y, k):
11
  dist = torch.cdist(x, y)
12
  _, topk_idx = torch.topk(dist, k, dim=1, largest=False)
13
- col = torch.arange(x.size(0)).view(-1, 1).expand(-1, k).reshape(-1)
14
- row = topk_idx.reshape(-1)
15
- return torch.stack([row, col], dim=0)
16
 
17
- def radius_graph_pure(pos, r, max_num_neighbors=32):
18
  dist = torch.cdist(pos, pos)
19
- n = pos.size(0)
20
- k = min(max_num_neighbors, n)
21
- _, target = torch.topk(dist, k, dim=1, largest=False)
22
- source = torch.arange(n).view(-1, 1).expand(-1, k)
23
- valid = dist[source, target] < r
24
- return torch.stack([source[valid], target[valid]], dim=0)
25
 
26
  def fps_pure(pos, ratio=0.5):
27
- n = pos.size(0)
28
- k = max(1, int(n * ratio))
29
- idx = torch.zeros(k, dtype=torch.long, device=pos.device)
30
- dist = torch.full((n,), 1e10, device=pos.device)
31
  farthest = 0
32
  for i in range(k):
33
  idx[i] = farthest
34
- farthest_point = pos[farthest].view(1, 3)
35
- dist_to_farthest = torch.cdist(pos, farthest_point).squeeze(-1)
36
- dist = torch.min(dist, dist_to_farthest)
37
  farthest = torch.argmax(dist).item()
38
  return idx
39
 
40
- # --- Kategori Sözlükleri ---
41
- CATEGORIES = {
42
- 'Airplane': 0, 'Bag': 1, 'Cap': 2, 'Car': 3, 'Chair': 4, 'Earphone': 5, 'Guitar': 6,
43
- 'Knife': 7, 'Lamp': 8, 'Laptop': 9, 'Motorbike': 10, 'Mug': 11, 'Pistol': 12,
44
- 'Rocket': 13, 'Skateboard': 14, 'Table': 15
45
- }
46
- SEG_CLASSES = {
47
- 'Airplane': [0, 1, 2, 3], 'Bag': [4, 5], 'Cap': [6, 7], 'Car': [8, 9, 10, 11],
48
- 'Chair': [12, 13, 14, 15], 'Earphone': [16, 17, 18], 'Guitar': [19, 20, 21],
49
- 'Knife': [22, 23], 'Lamp': [24, 25, 26, 27], 'Laptop': [28, 29],
50
- 'Motorbike': [30, 31, 32, 33, 34, 35], 'Mug': [36, 37], 'Pistol': [38, 39, 40],
51
- 'Rocket': [41, 42, 43], 'Skateboard': [44, 45, 46], 'Table': [47, 48, 49]
52
- }
53
 
54
- # --- Modeli Yükleme ---
55
  device = torch.device("cpu")
56
  model = HierarchicalFPSCliffordNet(base_channels=12).to(device)
57
-
58
  try:
59
  checkpoint = torch.load("best_all_categories_clifford.pt", map_location=device, weights_only=True)
60
- clean_sd = {k.replace("_orig_mod.", ""): v for k, v in checkpoint.items()}
61
- model.load_state_dict(clean_sd, strict=False)
62
  model.eval()
63
- print("Model ağırlıkları başarıyla yüklendi!")
64
- except Exception as e:
65
- print("Ağırlık yükleme hatası (Demo rastgele ağırlıklarla çalışacak):", e)
66
-
67
- # --- Veri Hazırlama Fonksiyonu ---
68
- def process_point_cloud(file_path, category_name):
69
- mesh = trimesh.load(file_path, force='mesh')
70
- points = np.array(mesh.vertices)
71
-
72
- if len(points) > 1024:
73
- idx = np.random.choice(len(points), 1024, replace=False)
74
- points = points[idx]
75
- elif len(points) < 1024:
76
- idx = np.random.choice(len(points), 1024, replace=True)
77
- points = points[idx]
78
-
79
- pos = torch.tensor(points, dtype=torch.float32)
80
- pos = pos - pos.mean(dim=0, keepdim=True)
81
- scale = pos.norm(dim=1).max().clamp(min=1e-8)
82
- pos = pos / scale
83
-
84
- batch = torch.zeros(pos.size(0), dtype=torch.long)
85
- cat_idx = torch.tensor([CATEGORIES[category_name]], dtype=torch.long)
86
-
87
- # Saf PyTorch ile Hızlı Hesaplamalar (C++ Yok)
88
- fps_idx_1 = fps_pure(pos, ratio=0.5)
89
- pos2 = pos[fps_idx_1]
90
- batch2 = torch.zeros(pos2.size(0), dtype=torch.long)
91
-
92
- fps_idx_2 = fps_pure(pos2, ratio=0.25)
93
- pos3 = pos2[fps_idx_2]
94
- batch3 = torch.zeros(pos3.size(0), dtype=torch.long)
95
-
96
- edge_index_1 = radius_graph_pure(pos, r=0.15, max_num_neighbors=16)
97
- edge_index_2 = radius_graph_pure(pos2, r=0.30, max_num_neighbors=32)
98
-
99
- assign_index_32 = knn_pure(pos3, pos2, k=3)
100
- assign_index_21 = knn_pure(pos2, pos, k=3)
101
 
102
- dummy_x = torch.zeros(pos3.size(0), 1)
103
- _, x_dense_mask = to_dense_batch(dummy_x, batch3, max_num_nodes=int(torch.bincount(batch3).max().item()))
104
-
105
- return pos, batch, cat_idx, fps_idx_1, fps_idx_2, pos2, batch2, pos3, batch3, edge_index_1, edge_index_2, assign_index_32, assign_index_21, x_dense_mask
106
-
107
- # --- Çıkarım Fonksiyonu ---
108
  def predict(file, category_name):
109
- if file is None:
110
- return None
111
-
112
- (pos, batch, cat_idx, fps_idx_1, fps_idx_2, pos2, batch2, pos3, batch3,
113
- edge_index_1, edge_index_2, assign_index_32, assign_index_21, x_dense_mask) = process_point_cloud(file.name, category_name)
 
114
 
 
 
 
 
 
 
 
115
  with torch.no_grad():
116
- logits = model(pos, batch, cat_idx, fps_idx_1, fps_idx_2, pos2, batch2, pos3, batch3,
117
- edge_index_1, edge_index_2, assign_index_32, assign_index_21, x_dense_mask)
118
-
119
- valid_parts = SEG_CLASSES[category_name]
120
- mask = torch.zeros_like(logits, dtype=torch.bool)
121
- mask[:, valid_parts] = True
122
- logits = logits.masked_fill(~mask, -10000.0)
123
- preds = logits.argmax(dim=-1).numpy()
124
-
125
- pos_np = pos.numpy()
126
- fig = go.Figure(data=[go.Scatter3d(
127
- x=pos_np[:, 0], y=pos_np[:, 1], z=pos_np[:, 2],
128
- mode='markers',
129
- marker=dict(size=4, color=preds, colorscale='Viridis', opacity=0.9, line=dict(width=0.5, color='black'))
130
- )])
131
-
132
- fig.update_layout(margin=dict(l=0, r=0, b=0, t=0), scene=dict(xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False)))
133
  return fig
134
 
135
- # --- Gradio Arayüzü ---
136
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
137
- gr.Markdown("# 🚀 Clifford-Dirac 3D Part Segmentation (0.15M Params)")
138
- gr.Markdown("Bu model, Geometrik Cebir kullanarak 3D objeleri yüksek hassasiyetle parçalara ayırır. Bir .obj veya .ply dosyası yükleyin ve sonucu keşfedin!")
139
-
140
  with gr.Row():
141
- with gr.Column(scale=1):
142
- input_file = gr.File(label="3D Model Yükle (.obj veya .ply)")
143
- cat_dropdown = gr.Dropdown(choices=list(CATEGORIES.keys()), label="Obje Kategorisi", value="Chair")
144
- run_btn = gr.Button("Segmentasyonu Başlat", variant="primary")
145
-
146
- with gr.Column(scale=2):
147
- output_plot = gr.Plot(label="İnteraktif Segmentasyon Sonucu")
148
-
149
- run_btn.click(predict, inputs=[input_file, cat_dropdown], outputs=output_plot)
150
-
151
  demo.launch()
 
6
  from torch_geometric.utils import to_dense_batch
7
  from model import HierarchicalFPSCliffordNet
8
 
9
# --- Pure-PyTorch fallbacks for the C++ extensions ---
def knn_pure(x, y, k=3):
    """Pure-PyTorch k-nearest-neighbour assignment.

    For every point in `x`, finds the indices of its `k` nearest points in
    `y`. Returns a (2, x.size(0) * k) LongTensor where row 0 is the query
    (x) index repeated k times and row 1 is the matched neighbour (y) index.
    """
    pairwise = torch.cdist(x, y)  # (|x|, |y|) Euclidean distances
    _, neighbour_idx = torch.topk(pairwise, k, dim=1, largest=False)
    query_idx = torch.arange(x.size(0)).repeat_interleave(k)
    return torch.stack([query_idx, neighbour_idx.reshape(-1)], dim=0)
 
 
14
 
15
def radius_graph_pure(pos, r, max_n=16):
    """Pure-PyTorch radius graph (fallback for torch_cluster.radius_graph).

    Connects each point to its nearest neighbours — at most `max_n`, capped
    at the number of points — that lie strictly within radius `r`.

    NOTE(review): self-loops are kept (each point's distance to itself is
    0 < r); this matches the original fallback's behaviour, though
    torch_cluster.radius_graph excludes them by default — confirm the model
    expects self-loops.

    Returns:
        (2, E) LongTensor of [source, target] index pairs.
    """
    n = pos.size(0)
    k = min(max_n, n)  # hoisted: the original computed this twice
    dist = torch.cdist(pos, pos)  # (n, n) pairwise distances
    _, target = torch.topk(dist, k, dim=1, largest=False)
    source = torch.arange(n).view(-1, 1).expand(-1, k)
    within = dist[source, target] < r
    return torch.stack([source[within], target[within]], dim=0)
 
 
21
 
22
def fps_pure(pos, ratio=0.5):
    """Pure-PyTorch farthest-point sampling (fallback for torch_cluster.fps).

    Iteratively selects the point farthest from the already-selected set,
    always starting from index 0. Generalized to points of any
    dimensionality (the previous version hard-coded 3-D via .view(1, 3))
    and keeps all intermediate tensors on `pos`'s device (the previous
    version silently allocated on CPU).

    Args:
        pos: (n, d) float tensor of point coordinates.
        ratio: fraction of points to keep; at least one point is returned.

    Returns:
        (k,) LongTensor of selected indices, k = max(1, int(n * ratio)).
    """
    n = pos.size(0)
    k = max(1, int(n * ratio))
    idx = torch.zeros(k, dtype=torch.long, device=pos.device)
    # Running minimum distance from each point to the selected set.
    min_dist = torch.full((n,), 1e10, device=pos.device)
    farthest = 0
    for i in range(k):
        idx[i] = farthest
        to_new = torch.cdist(pos, pos[farthest].unsqueeze(0)).squeeze(-1)
        min_dist = torch.minimum(min_dist, to_new)
        farthest = torch.argmax(min_dist).item()
    return idx
32
 
33
# --- Settings: ShapeNetPart category ids and their part-label ranges ---
CATEGORIES = {
    'Airplane': 0, 'Bag': 1, 'Cap': 2, 'Car': 3, 'Chair': 4, 'Earphone': 5,
    'Guitar': 6, 'Knife': 7, 'Lamp': 8, 'Laptop': 9, 'Motorbike': 10,
    'Mug': 11, 'Pistol': 12, 'Rocket': 13, 'Skateboard': 14, 'Table': 15,
}
# Part-label ids valid for each category (full 50-class split). Restored to
# cover every category offered in CATEGORIES — the previous "example
# restricted list" covered only 4 of the 16 dropdown choices, so any
# per-category lookup would KeyError for the remaining 12.
SEG_CLASSES = {
    'Airplane': [0, 1, 2, 3], 'Bag': [4, 5], 'Cap': [6, 7],
    'Car': [8, 9, 10, 11], 'Chair': [12, 13, 14, 15],
    'Earphone': [16, 17, 18], 'Guitar': [19, 20, 21], 'Knife': [22, 23],
    'Lamp': [24, 25, 26, 27], 'Laptop': [28, 29],
    'Motorbike': [30, 31, 32, 33, 34, 35], 'Mug': [36, 37],
    'Pistol': [38, 39, 40], 'Rocket': [41, 42, 43],
    'Skateboard': [44, 45, 46], 'Table': [47, 48, 49],
}
 
 
 
 
 
 
 
 
 
 
36
 
 
37
# --- Model loading (CPU-only demo) ---
device = torch.device("cpu")
model = HierarchicalFPSCliffordNet(base_channels=12).to(device)
try:
    checkpoint = torch.load("best_all_categories_clifford.pt", map_location=device, weights_only=True)
    # torch.compile prefixes state-dict keys with "_orig_mod."; strip them
    # so the keys match the un-compiled module.
    model.load_state_dict({k.replace("_orig_mod.", ""): v for k, v in checkpoint.items()}, strict=False)
except Exception as e:  # was bare `except:`, which also swallowed SystemExit/KeyboardInterrupt
    print("Ağırlıklar bulunamadı.", e)
# Always switch to inference mode — even when loading failed and the demo
# falls back to random weights (the original only called eval() on success).
model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
 
 
 
 
 
 
45
def predict(file, category_name):
    """Segment an uploaded 3-D model and return an interactive Plotly figure.

    Args:
        file: value from the gr.Model3D input. Depending on the Gradio
            version this is either a filepath string or a tempfile-like
            object exposing .name — both are handled below. The previous
            code accessed file.name unconditionally and crashed on str input.
        category_name: key into CATEGORIES chosen in the dropdown.

    Returns:
        A plotly Figure of the point cloud coloured by predicted part id,
        or None when no file was supplied.
    """
    if not file:
        return None
    path = file if isinstance(file, str) else file.name
    mesh = trimesh.load(path, force='mesh')  # force='mesh' flattens glb Scenes

    # Resample vertices to exactly 1024 points (with replacement when fewer).
    p = np.array(mesh.vertices)
    p = p[np.random.choice(len(p), 1024, replace=len(p) < 1024)]

    # Centre and scale into the unit sphere — assumed to match the
    # normalisation used at training time (TODO confirm against trainer).
    pos = torch.tensor(p, dtype=torch.float32)
    pos = (pos - pos.mean(0)) / pos.norm(dim=1).max().clamp(1e-8)

    # Two-level FPS hierarchy + graphs via the pure-PyTorch fallbacks.
    f1 = fps_pure(pos, 0.5)
    p2 = pos[f1]
    f2 = fps_pure(p2, 0.25)
    p3 = p2[f2]
    e1 = radius_graph_pure(pos, 0.15)
    e2 = radius_graph_pure(p2, 0.30)
    a32 = knn_pure(p3, p2)
    a21 = knn_pure(p2, pos)
    _, m = to_dense_batch(torch.zeros(p3.size(0), 1), torch.zeros(p3.size(0), dtype=torch.long))

    with torch.no_grad():
        out = model(pos, torch.zeros(1024, dtype=torch.long),
                    torch.tensor([CATEGORIES[category_name]]),
                    f1, f2, p2, torch.zeros(p2.size(0), dtype=torch.long),
                    p3, torch.zeros(p3.size(0), dtype=torch.long),
                    e1, e2, a32, a21, m)
        # Restrict the argmax to the parts belonging to the chosen category
        # when its part list is known; a global argmax over all part classes
        # can otherwise select parts from an unrelated category.
        valid_parts = SEG_CLASSES.get(category_name)
        if valid_parts:
            keep = torch.zeros_like(out, dtype=torch.bool)
            keep[:, valid_parts] = True
            out = out.masked_fill(~keep, -10000.0)
        res = out.argmax(-1).numpy()

    fig = go.Figure(data=[go.Scatter3d(
        x=p[:, 0], y=p[:, 1], z=p[:, 2], mode='markers',
        marker=dict(size=4, color=res, colorscale='Viridis', opacity=0.8))])
    fig.update_layout(margin=dict(l=0, r=0, b=0, t=0),
                      scene=dict(xaxis=dict(visible=False),
                                 yaxis=dict(visible=False),
                                 zaxis=dict(visible=False)))
    return fig
67
 
 
68
# --- Gradio interface: 3-D model input, category dropdown, plot output ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 Clifford-Dirac 3D Segmentation (0.15M Params)")
    with gr.Row():
        with gr.Column():
            inp = gr.Model3D(label="Model Yükle (.glb, .obj)")
            cat = gr.Dropdown(choices=list(CATEGORIES.keys()), label="Kategori", value="Chair")
            btn = gr.Button("Tahmin Et", variant="primary")
            out = gr.Plot(label="Sonuç")
    btn.click(predict, [inp, cat], out)

demo.launch()