# Hugging Face Spaces status: Sleeping
# File size: 6,119 Bytes
import gradio as gr
import numpy as np
import plotly.graph_objects as go
import torch
import trimesh
from torch_geometric.utils import to_dense_batch

from model import HierarchicalFPSCliffordNet
# --- Pure-PyTorch fallbacks for C++ extension ops (knn, radius graph, FPS) ---
def knn_pure(x, y, k=3):
    """For every point in ``y``, find its ``k`` nearest points in ``x``.

    Args:
        x: candidate points, a 2-D tensor ``(Nx, D)``.
        y: query points, a 2-D tensor ``(Ny, D)``.
        k: neighbours per query. Clamped to ``Nx`` so small inputs do not make
            ``torch.topk`` raise (mirrors the ``min(max_n, N)`` clamp used by
            ``radius_graph_pure``).

    Returns:
        LongTensor of shape ``(2, Ny * k)``. Row 0 holds query indices into ``y``.
        Row 1 holds the matching neighbour indices into ``x``, nearest first for each query.
    """
    k = min(k, x.size(0))
    dist = torch.cdist(y, x)  # (Ny, Nx) pairwise Euclidean distances
    _, nbr = torch.topk(dist, k, dim=1, largest=False)
    # Each query index is repeated once per neighbour it owns.
    query = torch.arange(y.size(0), device=y.device).repeat_interleave(k)
    return torch.stack([query, nbr.reshape(-1)], dim=0)
def radius_graph_pure(pos, r, max_n=16):
    """Connect every point to those of its ``max_n`` nearest points that lie closer than ``r``.

    Args:
        pos: points, a 2-D tensor ``(N, D)``.
        r: strict distance threshold (``dist < r``).
        max_n: maximum number of candidates examined per point (clamped to ``N``).

    Returns:
        LongTensor ``(2, E)``. Row 0 holds the centre index; row 1 holds the neighbour index.

    Note: each point is its own nearest candidate (distance 0), so self-loops are
    kept whenever ``r > 0``.
    NOTE(review): confirm this matches the graph the model was trained with.
    """
    n = pos.size(0)
    k = min(max_n, n)
    dist = torch.cdist(pos, pos)
    # topk already returns the distances, so there is no need to gather them again.
    nbr_dist, target = torch.topk(dist, k, dim=1, largest=False)
    source = torch.arange(n, device=pos.device).view(-1, 1).expand(-1, k)
    mask = nbr_dist < r
    return torch.stack([source[mask], target[mask]], dim=0)
def fps_pure(pos, ratio=0.5):
    """Farthest-point sampling: pick ``max(1, int(N * ratio))`` well-spread point indices.

    Sampling starts from index 0. Each later pick is the point whose distance to the
    already-chosen set is largest.

    Args:
        pos: points, a 2-D tensor ``(N, D)`` with any dimensionality ``D``
            (previously hard-coded to 3).
        ratio: fraction of points to keep.

    Returns:
        LongTensor of selected indices on the same device as ``pos``.
    """
    n = pos.size(0)
    k = max(1, int(n * ratio))
    idx = torch.zeros(k, dtype=torch.long, device=pos.device)
    # Distance from each point to its nearest already-selected point.
    min_dist = torch.full((n,), 1e10, dtype=pos.dtype, device=pos.device)
    farthest = 0
    for i in range(k):
        idx[i] = farthest
        # Slicing keeps a (1, D) row for any D and works on non-contiguous tensors.
        d = torch.cdist(pos, pos[farthest:farthest + 1]).view(-1)
        min_dist = torch.min(min_dist, d)
        farthest = torch.argmax(min_dist).item()
    return idx
# --- Settings and dataset lookup tables ---
# Object categories mapped to integer class ids; the id is the position in this list.
CATEGORIES = {
    name: class_id
    for class_id, name in enumerate([
        'Airplane', 'Bag', 'Cap', 'Car', 'Chair', 'Earphone', 'Guitar', 'Knife',
        'Lamp', 'Laptop', 'Motorbike', 'Mug', 'Pistol', 'Rocket', 'Skateboard', 'Table',
    ])
}
# Valid global part-label ids for each category (full list, used for logit masking).
# Every category owns one contiguous id range.
SEG_CLASSES = {
    'Airplane': list(range(0, 4)),
    'Bag': list(range(4, 6)),
    'Cap': list(range(6, 8)),
    'Car': list(range(8, 12)),
    'Chair': list(range(12, 16)),
    'Earphone': list(range(16, 19)),
    'Guitar': list(range(19, 22)),
    'Knife': list(range(22, 24)),
    'Lamp': list(range(24, 28)),
    'Laptop': list(range(28, 30)),
    'Motorbike': list(range(30, 36)),
    'Mug': list(range(36, 38)),
    'Pistol': list(range(38, 41)),
    'Rocket': list(range(41, 44)),
    'Skateboard': list(range(44, 47)),
    'Table': list(range(47, 50)),
}
# Names shown when hovering over a point in the plot, keyed by global part-label id (full list).
# Ids are assigned consecutively (0..49) in the order the parts are listed below.
_PART_NAME_GROUPS = (
    ('Airplane', ('Body', 'Wing', 'Tail', 'Engine')),
    ('Bag', ('Handle', 'Body')),
    ('Cap', ('Peak', 'Panels')),
    ('Car', ('Roof', 'Hood', 'Bumper', 'Wheels')),
    ('Chair', ('Back', 'Seat', 'Leg', 'Armrest')),
    ('Earphone', ('Earphone', 'Headband', 'Wire')),
    ('Guitar', ('Head', 'Body', 'Neck')),
    ('Knife', ('Blade', 'Handle')),
    ('Lamp', ('Base', 'Shade', 'Tube', 'Wire')),
    ('Laptop', ('Keyboard', 'Screen')),
    ('Motorbike', ('Gas Tank', 'Seat', 'Wheels', 'Handles', 'Light', 'Engine')),
    ('Mug', ('Handle', 'Cup')),
    ('Pistol', ('Barrel', 'Handle', 'Trigger')),
    ('Rocket', ('Body', 'Fin', 'Nose')),
    ('Skateboard', ('Wheels', 'Deck', 'Steering')),
    ('Table', ('Top', 'Leg', 'Base')),
)
PART_NAMES = dict(enumerate(
    f"{category} {part}"
    for category, parts in _PART_NAME_GROUPS
    for part in parts
))
# Inference runs on CPU.
device = torch.device("cpu")
model = HierarchicalFPSCliffordNet(base_channels=12).to(device)
try:
    checkpoint = torch.load("best_all_categories_clifford.pt", map_location=device, weights_only=True)
    # Strip the "_orig_mod." key prefix (presumably added by torch.compile when the
    # checkpoint was saved) so the names match the plain module.
    load_result = model.load_state_dict(
        {k.replace("_orig_mod.", ""): v for k, v in checkpoint.items()}, strict=False
    )
except FileNotFoundError:
    print("Ağırlıklar bulunamadı.")
except Exception as exc:  # best-effort: keep the demo up even if the checkpoint is unusable
    print(f"Ağırlıklar yüklenemedi: {exc!r}")
else:
    # strict=False silently tolerates mismatched parameter names; at least report them.
    if load_result.missing_keys or load_result.unexpected_keys:
        print(
            f"Eksik anahtarlar: {load_result.missing_keys}; "
            f"beklenmeyen anahtarlar: {load_result.unexpected_keys}"
        )
# Always switch to inference mode. Previously eval() was skipped whenever loading failed,
# which left any training-mode layers active.
model.eval()
def _sample_vertices(file, num_points=1024):
    """Load the uploaded mesh and randomly sample ``num_points`` of its vertices.

    Sampling is with replacement only when the mesh has fewer than ``num_points`` vertices.
    Raises ``gr.Error`` (shown in the UI) for a mesh with no vertices, instead of the
    opaque ``ValueError`` that ``np.random.choice`` would raise.
    """
    # The upload arrives either as a path string or as an object carrying the path in ``.name``.
    file_path = file if isinstance(file, str) else file.name
    mesh = trimesh.load(file_path, force='mesh')
    vertices = np.asarray(mesh.vertices)
    if len(vertices) == 0:
        raise gr.Error("Yüklenen modelde köşe noktası (vertex) bulunamadı.")
    chosen = np.random.choice(len(vertices), num_points, replace=len(vertices) < num_points)
    return vertices[chosen]


def _normalize_unit_sphere(points):
    """Return ``points`` as a float32 tensor, centred on its centroid and scaled into the unit ball."""
    pos = torch.tensor(points, dtype=torch.float32)
    centered = pos - pos.mean(0)
    # Scale by the largest *centred* radius. The previous code divided by the radius of
    # the un-centred points, which made the normalised size depend on where the mesh
    # happened to sit in world space.
    return centered / centered.norm(dim=1).max().clamp(min=1e-8)


def _build_hierarchy(pos):
    """Precompute the sampling levels and graphs the model takes as input.

    Level 1 is ``pos``. Level 2 keeps half of level 1 and level 3 keeps a quarter of
    level 2, both chosen by farthest-point sampling.
    Returns ``(f1, f2, p2, p3, e1, e2, a32, a21, m)`` in the order the model expects them.
    """
    f1 = fps_pure(pos, 0.5)
    p2 = pos[f1]
    f2 = fps_pure(p2, 0.25)
    p3 = p2[f2]
    e1 = radius_graph_pure(pos, 0.15)
    e2 = radius_graph_pure(p2, 0.30)
    a32 = knn_pure(p3, p2)  # each level-2 point -> its 3 nearest level-3 points
    a21 = knn_pure(p2, pos)  # each level-1 point -> its 3 nearest level-2 points
    # One sample only: every level-3 point goes in batch 0; only the dense mask is kept.
    _, m = to_dense_batch(torch.zeros(p3.size(0), 1), torch.zeros(p3.size(0), dtype=torch.long))
    return f1, f2, p2, p3, e1, e2, a32, a21, m


def _make_figure(points, labels):
    """Build a Plotly 3-D scatter of ``points`` coloured by part label, with part names on hover."""
    # Text shown when the cursor hovers over a point.
    hover_texts = [f"Parça: {PART_NAMES.get(pr, f'Bilinmeyen_{pr}')}" for pr in labels]
    fig = go.Figure(data=[go.Scatter3d(
        x=points[:, 0], y=points[:, 1], z=points[:, 2],
        mode='markers',
        text=hover_texts,
        hoverinfo='text',
        marker=dict(size=4, color=labels, colorscale='Viridis', opacity=0.8),
    )])
    fig.update_layout(
        margin=dict(l=0, r=0, b=0, t=0),
        scene=dict(xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False)),
    )
    return fig


def predict(file, category_name):
    """Segment an uploaded 3-D model into parts and return a Plotly figure.

    Args:
        file: value of the ``gr.Model3D`` input. It is either a path string or an object
            with a ``.name`` path, and is falsy when nothing was uploaded.
        category_name: a key of ``CATEGORIES`` selecting the object category.

    Returns:
        A Plotly figure of the sampled (un-normalised) vertices coloured by predicted part,
        or ``None`` when no file was given.

    Raises:
        gr.Error: if the category is unknown or the mesh has no vertices.
    """
    if not file:
        return None
    if category_name not in CATEGORIES:
        raise gr.Error("Lütfen geçerli bir kategori seçin.")
    points = _sample_vertices(file)
    pos = _normalize_unit_sphere(points)
    f1, f2, p2, p3, e1, e2, a32, a21, m = _build_hierarchy(pos)
    with torch.no_grad():
        # The all-zero long tensors presumably act as per-level batch vectors for a
        # single sample. TODO: confirm against HierarchicalFPSCliffordNet.forward.
        out = model(
            pos, torch.zeros(pos.size(0), dtype=torch.long),
            torch.tensor([CATEGORIES[category_name]]),
            f1, f2, p2, torch.zeros(p2.size(0), dtype=torch.long),
            p3, torch.zeros(p3.size(0), dtype=torch.long),
            e1, e2, a32, a21, m,
        )
    # Logit masking: only parts that belong to the chosen category can win the argmax.
    # The mask assumes out is (num_points, num_part_labels). TODO: confirm.
    mask = torch.zeros_like(out, dtype=torch.bool)
    mask[:, SEG_CLASSES[category_name]] = True
    res = out.masked_fill(~mask, -10000.0).argmax(-1).numpy()
    return _make_figure(points, res)
# --- Gradio interface ---
# NOTE(review): this section's original indentation was lost in extraction. The layout
# below is the most plausible reconstruction and should be confirmed against the
# deployed Space: an input column next to the plot, with the click handler and the
# footer inside the Blocks context.
with gr.Blocks() as demo:
    gr.Markdown(value="# 🚀 Clifford-Dirac 3D Segmentation (0.15M Params)")
    with gr.Row():
        with gr.Column():
            inp = gr.Model3D(label="Model Yükle (.glb, .obj)")
            cat = gr.Dropdown(
                choices=[name for name in CATEGORIES],
                label="Kategori",
                value="Chair",
            )
            btn = gr.Button(value="Tahmin Et", variant="primary")
        out = gr.Plot(label="Sonuç")
    btn.click(fn=predict, inputs=[inp, cat], outputs=out)
    # Optional footer note at the very bottom of the page.
    gr.Markdown(value="---")
    gr.Markdown(
        value="\n### 📊 Train GitHub\nhttps://github.com/yusuftiryaki/clifford-shapenet-segmentation.git\n"
    )
demo.launch()