# Author: yusuf-tiryaki — last commit b7d5f4a ("github url")
import gradio as gr
import torch
import numpy as np
import trimesh
import plotly.graph_objects as go
from torch_geometric.utils import to_dense_batch
from model import HierarchicalFPSCliffordNet
# --- Pure-PyTorch fallbacks for the C++ graph ops (knn / radius_graph / fps) ---
def knn_pure(x, y, k=3):
    """Return k-nearest-neighbour edges from every point of ``y`` into ``x``.

    Args:
        x: Reference point set, one point per row.
        y: Query point set, one point per row (same number of columns as ``x``).
        k: Number of neighbours taken from ``x`` for each query row.

    Returns:
        A ``(2, len(y) * k)`` long tensor. Row 0 holds the query index into
        ``y`` (each repeated ``k`` times), and row 1 holds the matching
        neighbour index into ``x``.

    ``torch.topk`` raises if ``k`` exceeds the number of rows in ``x``.
    """
    # Indices of the k smallest Euclidean distances per query row.
    neighbours = torch.cdist(y, x).topk(k, dim=1, largest=False).indices
    # 0,0,...,1,1,... — one copy of each query id per neighbour.
    queries = torch.arange(y.size(0), device=y.device).repeat_interleave(k)
    return torch.stack((queries, neighbours.flatten()))
def radius_graph_pure(pos, r, max_n=16):
    """Build radius-graph edges among the rows of ``pos``.

    Each point is linked to its (at most) ``max_n`` nearest points, and only
    links shorter than ``r`` are kept. The point itself is one of its own
    nearest neighbours (distance 0), so self-loops are included whenever
    ``r > 0``.

    Args:
        pos: Point set, one point per row.
        r: Strict upper bound on the edge length.
        max_n: Maximum number of candidate neighbours examined per point.

    Returns:
        A ``(2, E)`` long tensor on ``pos.device``. Row 0 holds the source
        point index and row 1 the target point index.
    """
    num_points = pos.size(0)
    k = min(max_n, num_points)
    # topk already returns the gathered distances, so they are not re-indexed from the full matrix.
    nearest_dist, target = torch.topk(torch.cdist(pos, pos), k, dim=1, largest=False)
    # Create on pos.device so masking/stacking also works for non-CPU inputs.
    source = torch.arange(num_points, device=pos.device).view(-1, 1).expand(-1, k)
    mask = nearest_dist < r
    return torch.stack([source[mask], target[mask]], dim=0)
def fps_pure(pos, ratio=0.5):
    """Deterministic farthest point sampling over the rows of ``pos``.

    Sampling starts from row 0. Each subsequent pick is the point whose
    distance to the already-chosen set is largest (ties go to the lowest
    index).

    Args:
        pos: Point set, one point per row. Any number of columns is allowed.
        ratio: Fraction of points to keep. At least one point is kept for
            non-empty input.

    Returns:
        A 1-D long tensor of selected row indices, in selection order. It lives
        on ``pos.device`` and is empty when ``pos`` has no rows.
    """
    num_points = pos.size(0)
    if num_points == 0:
        return torch.empty(0, dtype=torch.long, device=pos.device)

    k = max(1, int(num_points * ratio))
    idx = torch.zeros(k, dtype=torch.long, device=pos.device)
    # Distance of every point to the nearest already-selected point.
    # Same device/dtype as pos, so torch.minimum does not mix devices.
    min_dist = torch.full((num_points,), 1e10, dtype=pos.dtype, device=pos.device)
    farthest = 0
    for i in range(k):
        idx[i] = farthest
        # Slice (not .view(1, 3)) keeps the row 2-D for any point dimensionality.
        step = torch.cdist(pos, pos[farthest:farthest + 1]).squeeze(1)
        min_dist = torch.minimum(min_dist, step)
        farthest = int(torch.argmax(min_dist))
    return idx
# --- Settings and dataset lookup tables ---
# One table lists each category (in class-id order) with its part names (in
# part-id order). CATEGORIES, SEG_CLASSES and PART_NAMES are all derived from
# it, so the three lookups cannot drift out of sync.
_CATEGORY_PARTS = {
    'Airplane': ['Body', 'Wing', 'Tail', 'Engine'],
    'Bag': ['Handle', 'Body'],
    'Cap': ['Peak', 'Panels'],
    'Car': ['Roof', 'Hood', 'Bumper', 'Wheels'],
    'Chair': ['Back', 'Seat', 'Leg', 'Armrest'],
    'Earphone': ['Earphone', 'Headband', 'Wire'],
    'Guitar': ['Head', 'Body', 'Neck'],
    'Knife': ['Blade', 'Handle'],
    'Lamp': ['Base', 'Shade', 'Tube', 'Wire'],
    'Laptop': ['Keyboard', 'Screen'],
    'Motorbike': ['Gas Tank', 'Seat', 'Wheels', 'Handles', 'Light', 'Engine'],
    'Mug': ['Handle', 'Cup'],
    'Pistol': ['Barrel', 'Handle', 'Trigger'],
    'Rocket': ['Body', 'Fin', 'Nose'],
    'Skateboard': ['Wheels', 'Deck', 'Steering'],
    'Table': ['Top', 'Leg', 'Base'],
}


def _build_part_tables(category_parts):
    """Assign consecutive global part ids across categories.

    Returns ``(seg_classes, part_names)``:
    - ``seg_classes`` maps each category to the list of its part ids.
    - ``part_names`` maps each part id to the display name "<Category> <Part>".
    """
    seg_classes, part_names = {}, {}
    next_id = 0
    for category, parts in category_parts.items():
        ids = list(range(next_id, next_id + len(parts)))
        seg_classes[category] = ids
        part_names.update({pid: f"{category} {part}" for pid, part in zip(ids, parts)})
        next_id += len(parts)
    return seg_classes, part_names


# Category name -> class id fed to the model.
CATEGORIES = {name: class_id for class_id, name in enumerate(_CATEGORY_PARTS)}
# SEG_CLASSES: category name -> valid part ids (used for logit masking).
# PART_NAMES: part id -> name shown on hover.
SEG_CLASSES, PART_NAMES = _build_part_tables(_CATEGORY_PARTS)
device = torch.device("cpu")
model = HierarchicalFPSCliffordNet(base_channels=12).to(device)
# Best-effort weight loading: a missing or incompatible checkpoint must not stop
# the demo from starting, but the reason is reported instead of silently swallowed.
try:
    checkpoint = torch.load("best_all_categories_clifford.pt", map_location=device, weights_only=True)
    # Keys saved from a torch.compile()'d model carry an "_orig_mod." prefix; strip it.
    load_result = model.load_state_dict(
        {k.replace("_orig_mod.", ""): v for k, v in checkpoint.items()}, strict=False
    )
except Exception as err:  # top-level startup boundary: log and keep serving
    print("Ağırlıklar bulunamadı.")
    print(f"  -> {type(err).__name__}: {err}")
else:
    # strict=False tolerates key mismatches; surface them so a stale checkpoint gets noticed.
    if load_result.missing_keys or load_result.unexpected_keys:
        print(
            "Warning: checkpoint/model key mismatch "
            f"(missing={len(load_result.missing_keys)}, "
            f"unexpected={len(load_result.unexpected_keys)})"
        )
# Switch to inference mode unconditionally. Previously a failed load left the
# model in training mode.
model.eval()
def _sample_vertices(file_path, num_points=1024):
    """Load ``file_path`` as one mesh and sample exactly ``num_points`` of its vertices.

    Sampling is without replacement when the mesh has at least ``num_points``
    vertices and with replacement otherwise. It uses the global NumPy RNG,
    so repeated calls on the same file may return different subsets.

    Raises:
        gr.Error: if the mesh has no vertices (np.random.choice would fail).
    """
    mesh = trimesh.load(file_path, force='mesh')
    vertices = np.asarray(mesh.vertices)
    if len(vertices) == 0:
        raise gr.Error("Yüklenen model hiç köşe noktası içermiyor.")
    chosen = np.random.choice(len(vertices), num_points, replace=len(vertices) < num_points)
    return vertices[chosen]


def _normalize_unit_sphere(pos):
    """Center ``pos`` on its centroid and scale so the farthest point has norm 1."""
    centered = pos - pos.mean(0)
    # The radius must be measured on the *centered* cloud. Using the raw
    # coordinates shrinks off-origin meshes far below unit size. The clamp
    # guards against division by zero for a degenerate cloud.
    return centered / centered.norm(dim=1).max().clamp(min=1e-8)


def _make_figure(points, labels):
    """Build a 3D scatter of ``points`` coloured by part id, with part names on hover."""
    hover_texts = [f"Parça: {PART_NAMES.get(pid, f'Bilinmeyen_{pid}')}" for pid in labels.tolist()]
    fig = go.Figure(data=[go.Scatter3d(
        x=points[:, 0], y=points[:, 1], z=points[:, 2],
        mode='markers',
        text=hover_texts,
        hoverinfo='text',
        marker=dict(size=4, color=labels, colorscale='Viridis', opacity=0.8),
    )])
    fig.update_layout(
        margin=dict(l=0, r=0, b=0, t=0),
        scene=dict(xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False)),
    )
    return fig


def predict(file, category_name):
    """Segment an uploaded mesh into parts and return an interactive 3D plot.

    Args:
        file: Value of the Gradio ``Model3D`` input. This is either a path
            string or an object exposing ``.name``. A falsy value means
            nothing was uploaded.
        category_name: A key of ``CATEGORIES``; it restricts predictions to
            that category's parts.

    Returns:
        A Plotly figure with one marker per sampled vertex, coloured by the
        predicted part id. Returns ``None`` when no file was given.

    Raises:
        gr.Error: if the category is unknown/empty or the mesh has no vertices.
    """
    if not file:
        return None
    if category_name not in CATEGORIES:
        raise gr.Error(f"Bilinmeyen kategori: {category_name!r}")
    file_path = file if isinstance(file, str) else file.name

    points = _sample_vertices(file_path)
    pos = _normalize_unit_sphere(torch.tensor(points, dtype=torch.float32))

    # Point hierarchy via farthest point sampling: N -> N/2 -> N/8 (1024 -> 512 -> 128).
    f1 = fps_pure(pos, 0.5)
    p2 = pos[f1]
    f2 = fps_pure(p2, 0.25)
    p3 = p2[f2]
    # Radius graphs on the two finest levels (larger radius on the sparser level).
    e1 = radius_graph_pure(pos, 0.15)
    e2 = radius_graph_pure(p2, 0.30)
    # For every level-2 point, its 3 nearest level-3 points; same for level 1 -> level 2.
    a32 = knn_pure(p3, p2)
    a21 = knn_pure(p2, pos)
    # A single sample, so every batch-assignment vector is all zeros.
    batch1 = torch.zeros(pos.size(0), dtype=torch.long)
    batch2 = torch.zeros(p2.size(0), dtype=torch.long)
    batch3 = torch.zeros(p3.size(0), dtype=torch.long)
    _, m = to_dense_batch(torch.zeros(p3.size(0), 1), batch3)
    category = torch.tensor([CATEGORIES[category_name]])

    with torch.no_grad():
        out = model(pos, batch1, category, f1, f2, p2, batch2, p3, batch3, e1, e2, a32, a21, m)

    # Logit masking: take the argmax only over this category's part ids. This
    # is exact, whereas filling other logits with a finite constant is not.
    valid_parts = torch.tensor(SEG_CLASSES[category_name], dtype=torch.long, device=out.device)
    res = valid_parts[out[:, valid_parts].argmax(-1)].cpu().numpy()

    return _make_figure(points, res)
# --- Gradio UI: inputs on the left, segmentation plot on the right ---
with gr.Blocks() as demo:
    gr.Markdown("# 🚀 Clifford-Dirac 3D Segmentation (0.15M Params)")
    with gr.Row():
        with gr.Column():
            mesh_input = gr.Model3D(label="Model Yükle (.glb, .obj)")
            category_input = gr.Dropdown(
                choices=list(CATEGORIES),
                label="Kategori",
                value="Chair",
            )
            run_button = gr.Button("Tahmin Et", variant="primary")
        result_plot = gr.Plot(label="Sonuç")
    run_button.click(fn=predict, inputs=[mesh_input, category_input], outputs=result_plot)

    # Footer: separator plus a link to the training repository.
    gr.Markdown("---")
    gr.Markdown(
        "\n"
        "### 📊 Train GitHub\n"
        "https://github.com/yusuftiryaki/clifford-shapenet-segmentation.git\n"
    )

demo.launch()