Spaces:
Sleeping
Sleeping
Commit ·
99c5bcb
1
Parent(s): fd1bd17
glb support
Browse files
app.py
CHANGED
|
@@ -6,146 +6,72 @@ import plotly.graph_objects as go
|
|
| 6 |
from torch_geometric.utils import to_dense_batch
|
| 7 |
from model import HierarchicalFPSCliffordNet
|
| 8 |
|
| 9 |
-
# --- SAF PYTORCH C++ YEDEKLERİ
|
| 10 |
-
def knn_pure(x, y, k):
|
| 11 |
dist = torch.cdist(x, y)
|
| 12 |
_, topk_idx = torch.topk(dist, k, dim=1, largest=False)
|
| 13 |
-
|
| 14 |
-
row = topk_idx.reshape(-1)
|
| 15 |
-
return torch.stack([row, col], dim=0)
|
| 16 |
|
| 17 |
-
def radius_graph_pure(pos, r,
|
| 18 |
dist = torch.cdist(pos, pos)
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
valid = dist[source, target] < r
|
| 24 |
-
return torch.stack([source[valid], target[valid]], dim=0)
|
| 25 |
|
| 26 |
def fps_pure(pos, ratio=0.5):
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
dist = torch.full((n,), 1e10, device=pos.device)
|
| 31 |
farthest = 0
|
| 32 |
for i in range(k):
|
| 33 |
idx[i] = farthest
|
| 34 |
-
|
| 35 |
-
dist_to_farthest = torch.cdist(pos, farthest_point).squeeze(-1)
|
| 36 |
-
dist = torch.min(dist, dist_to_farthest)
|
| 37 |
farthest = torch.argmax(dist).item()
|
| 38 |
return idx
|
| 39 |
|
| 40 |
-
# ---
|
| 41 |
-
CATEGORIES = {
|
| 42 |
-
|
| 43 |
-
'Knife': 7, 'Lamp': 8, 'Laptop': 9, 'Motorbike': 10, 'Mug': 11, 'Pistol': 12,
|
| 44 |
-
'Rocket': 13, 'Skateboard': 14, 'Table': 15
|
| 45 |
-
}
|
| 46 |
-
SEG_CLASSES = {
|
| 47 |
-
'Airplane': [0, 1, 2, 3], 'Bag': [4, 5], 'Cap': [6, 7], 'Car': [8, 9, 10, 11],
|
| 48 |
-
'Chair': [12, 13, 14, 15], 'Earphone': [16, 17, 18], 'Guitar': [19, 20, 21],
|
| 49 |
-
'Knife': [22, 23], 'Lamp': [24, 25, 26, 27], 'Laptop': [28, 29],
|
| 50 |
-
'Motorbike': [30, 31, 32, 33, 34, 35], 'Mug': [36, 37], 'Pistol': [38, 39, 40],
|
| 51 |
-
'Rocket': [41, 42, 43], 'Skateboard': [44, 45, 46], 'Table': [47, 48, 49]
|
| 52 |
-
}
|
| 53 |
|
| 54 |
-
# --- Modeli Yükleme ---
|
| 55 |
device = torch.device("cpu")
|
| 56 |
model = HierarchicalFPSCliffordNet(base_channels=12).to(device)
|
| 57 |
-
|
| 58 |
try:
|
| 59 |
checkpoint = torch.load("best_all_categories_clifford.pt", map_location=device, weights_only=True)
|
| 60 |
-
|
| 61 |
-
model.load_state_dict(clean_sd, strict=False)
|
| 62 |
model.eval()
|
| 63 |
-
|
| 64 |
-
except Exception as e:
|
| 65 |
-
print("Ağırlık yükleme hatası (Demo rastgele ağırlıklarla çalışacak):", e)
|
| 66 |
-
|
| 67 |
-
# --- Veri Hazırlama Fonksiyonu ---
|
| 68 |
-
def process_point_cloud(file_path, category_name):
|
| 69 |
-
mesh = trimesh.load(file_path, force='mesh')
|
| 70 |
-
points = np.array(mesh.vertices)
|
| 71 |
-
|
| 72 |
-
if len(points) > 1024:
|
| 73 |
-
idx = np.random.choice(len(points), 1024, replace=False)
|
| 74 |
-
points = points[idx]
|
| 75 |
-
elif len(points) < 1024:
|
| 76 |
-
idx = np.random.choice(len(points), 1024, replace=True)
|
| 77 |
-
points = points[idx]
|
| 78 |
-
|
| 79 |
-
pos = torch.tensor(points, dtype=torch.float32)
|
| 80 |
-
pos = pos - pos.mean(dim=0, keepdim=True)
|
| 81 |
-
scale = pos.norm(dim=1).max().clamp(min=1e-8)
|
| 82 |
-
pos = pos / scale
|
| 83 |
-
|
| 84 |
-
batch = torch.zeros(pos.size(0), dtype=torch.long)
|
| 85 |
-
cat_idx = torch.tensor([CATEGORIES[category_name]], dtype=torch.long)
|
| 86 |
-
|
| 87 |
-
# Saf PyTorch ile Hızlı Hesaplamalar (C++ Yok)
|
| 88 |
-
fps_idx_1 = fps_pure(pos, ratio=0.5)
|
| 89 |
-
pos2 = pos[fps_idx_1]
|
| 90 |
-
batch2 = torch.zeros(pos2.size(0), dtype=torch.long)
|
| 91 |
-
|
| 92 |
-
fps_idx_2 = fps_pure(pos2, ratio=0.25)
|
| 93 |
-
pos3 = pos2[fps_idx_2]
|
| 94 |
-
batch3 = torch.zeros(pos3.size(0), dtype=torch.long)
|
| 95 |
-
|
| 96 |
-
edge_index_1 = radius_graph_pure(pos, r=0.15, max_num_neighbors=16)
|
| 97 |
-
edge_index_2 = radius_graph_pure(pos2, r=0.30, max_num_neighbors=32)
|
| 98 |
-
|
| 99 |
-
assign_index_32 = knn_pure(pos3, pos2, k=3)
|
| 100 |
-
assign_index_21 = knn_pure(pos2, pos, k=3)
|
| 101 |
|
| 102 |
-
dummy_x = torch.zeros(pos3.size(0), 1)
|
| 103 |
-
_, x_dense_mask = to_dense_batch(dummy_x, batch3, max_num_nodes=int(torch.bincount(batch3).max().item()))
|
| 104 |
-
|
| 105 |
-
return pos, batch, cat_idx, fps_idx_1, fps_idx_2, pos2, batch2, pos3, batch3, edge_index_1, edge_index_2, assign_index_32, assign_index_21, x_dense_mask
|
| 106 |
-
|
| 107 |
-
# --- Çıkarım Fonksiyonu ---
|
| 108 |
def predict(file, category_name):
|
| 109 |
-
if file
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
|
|
|
| 114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
with torch.no_grad():
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
mask[:, valid_parts] = True
|
| 122 |
-
logits = logits.masked_fill(~mask, -10000.0)
|
| 123 |
-
preds = logits.argmax(dim=-1).numpy()
|
| 124 |
-
|
| 125 |
-
pos_np = pos.numpy()
|
| 126 |
-
fig = go.Figure(data=[go.Scatter3d(
|
| 127 |
-
x=pos_np[:, 0], y=pos_np[:, 1], z=pos_np[:, 2],
|
| 128 |
-
mode='markers',
|
| 129 |
-
marker=dict(size=4, color=preds, colorscale='Viridis', opacity=0.9, line=dict(width=0.5, color='black'))
|
| 130 |
-
)])
|
| 131 |
-
|
| 132 |
-
fig.update_layout(margin=dict(l=0, r=0, b=0, t=0), scene=dict(xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False)))
|
| 133 |
return fig
|
| 134 |
|
| 135 |
-
# --- Gradio Arayüzü ---
|
| 136 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
| 137 |
-
gr.Markdown("# 🚀 Clifford-Dirac 3D
|
| 138 |
-
gr.Markdown("Bu model, Geometrik Cebir kullanarak 3D objeleri yüksek hassasiyetle parçalara ayırır. Bir .obj veya .ply dosyası yükleyin ve sonucu keşfedin!")
|
| 139 |
-
|
| 140 |
with gr.Row():
|
| 141 |
-
with gr.Column(
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
output_plot = gr.Plot(label="İnteraktif Segmentasyon Sonucu")
|
| 148 |
-
|
| 149 |
-
run_btn.click(predict, inputs=[input_file, cat_dropdown], outputs=output_plot)
|
| 150 |
-
|
| 151 |
demo.launch()
|
|
|
|
| 6 |
from torch_geometric.utils import to_dense_batch
|
| 7 |
from model import HierarchicalFPSCliffordNet
|
| 8 |
|
| 9 |
+
# --- SAF PYTORCH C++ YEDEKLERİ ---
|
| 10 |
+
def knn_pure(x, y, k=3):
|
| 11 |
dist = torch.cdist(x, y)
|
| 12 |
_, topk_idx = torch.topk(dist, k, dim=1, largest=False)
|
| 13 |
+
return torch.stack([torch.arange(x.size(0)).view(-1, 1).expand(-1, k).reshape(-1), topk_idx.reshape(-1)], dim=0)
|
|
|
|
|
|
|
| 14 |
|
| 15 |
+
def radius_graph_pure(pos, r, max_n=16):
|
| 16 |
dist = torch.cdist(pos, pos)
|
| 17 |
+
_, target = torch.topk(dist, min(max_n, pos.size(0)), dim=1, largest=False)
|
| 18 |
+
source = torch.arange(pos.size(0)).view(-1, 1).expand(-1, min(max_n, pos.size(0)))
|
| 19 |
+
mask = dist[source, target] < r
|
| 20 |
+
return torch.stack([source[mask], target[mask]], dim=0)
|
|
|
|
|
|
|
| 21 |
|
| 22 |
def fps_pure(pos, ratio=0.5):
|
| 23 |
+
k = max(1, int(pos.size(0) * ratio))
|
| 24 |
+
idx = torch.zeros(k, dtype=torch.long)
|
| 25 |
+
dist = torch.full((pos.size(0),), 1e10)
|
|
|
|
| 26 |
farthest = 0
|
| 27 |
for i in range(k):
|
| 28 |
idx[i] = farthest
|
| 29 |
+
dist = torch.min(dist, torch.cdist(pos, pos[farthest].view(1, 3)).squeeze())
|
|
|
|
|
|
|
| 30 |
farthest = torch.argmax(dist).item()
|
| 31 |
return idx
|
| 32 |
|
| 33 |
+
# --- Ayarlar ---
|
| 34 |
+
CATEGORIES = {'Airplane': 0, 'Bag': 1, 'Cap': 2, 'Car': 3, 'Chair': 4, 'Earphone': 5, 'Guitar': 6, 'Knife': 7, 'Lamp': 8, 'Laptop': 9, 'Motorbike': 10, 'Mug': 11, 'Pistol': 12, 'Rocket': 13, 'Skateboard': 14, 'Table': 15}
|
| 35 |
+
SEG_CLASSES = {'Airplane': list(range(0,4)), 'Chair': list(range(12,16)), 'Guitar': list(range(19,22)), 'Laptop': list(range(28,30))} # Örnek kısıtlı liste
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
|
|
|
| 37 |
device = torch.device("cpu")
|
| 38 |
model = HierarchicalFPSCliffordNet(base_channels=12).to(device)
|
|
|
|
| 39 |
try:
|
| 40 |
checkpoint = torch.load("best_all_categories_clifford.pt", map_location=device, weights_only=True)
|
| 41 |
+
model.load_state_dict({k.replace("_orig_mod.", ""): v for k, v in checkpoint.items()}, strict=False)
|
|
|
|
| 42 |
model.eval()
|
| 43 |
+
except: print("Ağırlıklar bulunamadı.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
def predict(file, category_name):
|
| 46 |
+
if not file: return None
|
| 47 |
+
mesh = trimesh.load(file.name, force='mesh')
|
| 48 |
+
p = np.array(mesh.vertices)
|
| 49 |
+
p = p[np.random.choice(len(p), 1024, replace=len(p)<1024)]
|
| 50 |
+
pos = torch.tensor(p, dtype=torch.float32)
|
| 51 |
+
pos = (pos - pos.mean(0)) / pos.norm(dim=1).max().clamp(1e-8)
|
| 52 |
|
| 53 |
+
# İşleme
|
| 54 |
+
f1 = fps_pure(pos, 0.5); p2 = pos[f1]
|
| 55 |
+
f2 = fps_pure(p2, 0.25); p3 = p2[f2]
|
| 56 |
+
e1 = radius_graph_pure(pos, 0.15); e2 = radius_graph_pure(p2, 0.30)
|
| 57 |
+
a32 = knn_pure(p3, p2); a21 = knn_pure(p2, pos)
|
| 58 |
+
_, m = to_dense_batch(torch.zeros(p3.size(0), 1), torch.zeros(p3.size(0), dtype=torch.long))
|
| 59 |
+
|
| 60 |
with torch.no_grad():
|
| 61 |
+
out = model(pos, torch.zeros(1024, dtype=torch.long), torch.tensor([CATEGORIES[category_name]]), f1, f2, p2, torch.zeros(p2.size(0), dtype=torch.long), p3, torch.zeros(p3.size(0), dtype=torch.long), e1, e2, a32, a21, m)
|
| 62 |
+
res = out.argmax(-1).numpy()
|
| 63 |
+
|
| 64 |
+
fig = go.Figure(data=[go.Scatter3d(x=p[:,0], y=p[:,1], z=p[:,2], mode='markers', marker=dict(size=4, color=res, colorscale='Viridis', opacity=0.8))])
|
| 65 |
+
fig.update_layout(margin=dict(l=0,r=0,b=0,t=0), scene=dict(xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
return fig
|
| 67 |
|
|
|
|
| 68 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
| 69 |
+
gr.Markdown("# 🚀 Clifford-Dirac 3D Segmentation (0.15M Params)")
|
|
|
|
|
|
|
| 70 |
with gr.Row():
|
| 71 |
+
with gr.Column():
|
| 72 |
+
inp = gr.Model3D(label="Model Yükle (.glb, .obj)")
|
| 73 |
+
cat = gr.Dropdown(choices=list(CATEGORIES.keys()), label="Kategori", value="Chair")
|
| 74 |
+
btn = gr.Button("Tahmin Et", variant="primary")
|
| 75 |
+
out = gr.Plot(label="Sonuç")
|
| 76 |
+
btn.click(predict, [inp, cat], out)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
demo.launch()
|