import gradio as gr
import numpy as np
import plotly.graph_objects as go
import torch
import trimesh
from torch_geometric.utils import to_dense_batch

from model import HierarchicalFPSCliffordNet


# --- Pure-PyTorch stand-ins for the compiled (C++) graph operators ---

def knn_pure(x, y, k=3):
    """Find, for every point in ``y``, its ``k`` nearest points in ``x``.

    Args:
        x: Candidate points, one point per row.
        y: Query points, one point per row (same width as ``x``).
        k: Number of neighbours per query point. Clamped to the number of
            rows in ``x`` so that small inputs do not make ``topk`` raise.

    Returns:
        A ``(2, len(y) * k)`` long tensor. Row 0 holds indices into ``y``,
        row 1 the matching neighbour indices into ``x``.
    """
    k = min(k, x.size(0))
    dist = torch.cdist(y, x)
    _, topk_idx = torch.topk(dist, k, dim=1, largest=False)
    idx_y = (
        torch.arange(y.size(0), device=y.device)
        .view(-1, 1)
        .expand(-1, k)
        .reshape(-1)
    )
    idx_x = topk_idx.reshape(-1)
    return torch.stack([idx_y, idx_x], dim=0)


def radius_graph_pure(pos, r, max_n=16):
    """Connect every point to its nearest neighbours that lie closer than ``r``.

    At most ``max_n`` candidates (the closest ones) are considered per point.
    Self-pairs are not filtered out, so a point normally lists itself as well.

    Args:
        pos: Points, one per row.
        r: Distance threshold (strict ``<`` comparison).
        max_n: Upper bound on the number of candidates per point.

    Returns:
        A ``(2, num_edges)`` long tensor. Row 0 is the centre point index,
        row 1 the neighbour index.
    """
    num_points = pos.size(0)
    k = min(max_n, num_points)
    dist = torch.cdist(pos, pos)
    # topk already returns the distances of the selected neighbours, so there
    # is no need to gather them from ``dist`` a second time.
    near_dist, target = torch.topk(dist, k, dim=1, largest=False)
    # Built on pos.device so the boolean indexing below also works off-CPU.
    source = torch.arange(num_points, device=pos.device).view(-1, 1).expand(-1, k)
    mask = near_dist < r
    return torch.stack([source[mask], target[mask]], dim=0)


def fps_pure(pos, ratio=0.5):
    """Select a subset of points by greedy farthest point sampling.

    Sampling is deterministic: it always starts from point 0 and repeatedly
    picks the point that is farthest from everything selected so far.

    Args:
        pos: Points, one per row.
        ratio: Fraction of points to keep; at least one point is returned
            for non-empty input.

    Returns:
        A long tensor with the indices of the selected points, in selection
        order. Empty input yields an empty index tensor.
    """
    num_points = pos.size(0)
    if num_points == 0:
        return torch.zeros(0, dtype=torch.long, device=pos.device)

    k = max(1, int(num_points * ratio))
    idx = torch.zeros(k, dtype=torch.long, device=pos.device)
    # Running distance of every point to its closest already-selected point.
    dist = torch.full((num_points,), 1e10, dtype=pos.dtype, device=pos.device)
    farthest = 0
    for i in range(k):
        idx[i] = farthest
        # view(1, -1) instead of a hard-coded width; squeeze(1) keeps the
        # result 1-D even when there is only a single point.
        to_selected = torch.cdist(pos, pos[farthest].view(1, -1)).squeeze(1)
        dist = torch.min(dist, to_selected)
        farthest = torch.argmax(dist).item()
    return idx


# --- Settings and dataset dictionaries ---

# Number of vertices sampled from every uploaded mesh.
NUM_POINTS = 1024

CATEGORIES = {
    'Airplane': 0,
    'Bag': 1,
    'Cap': 2,
    'Car': 3,
    'Chair': 4,
    'Earphone': 5,
    'Guitar': 6,
    'Knife': 7,
    'Lamp': 8,
    'Laptop': 9,
    'Motorbike': 10,
    'Mug': 11,
    'Pistol': 12,
    'Rocket': 13,
    'Skateboard': 14,
    'Table': 15,
}

# Valid part IDs per category (complete list, used for logit masking).
SEG_CLASSES = {
    'Airplane': [0, 1, 2, 3],
    'Bag': [4, 5],
    'Cap': [6, 7],
    'Car': [8, 9, 10, 11],
    'Chair': [12, 13, 14, 15],
    'Earphone': [16, 17, 18],
    'Guitar': [19, 20, 21],
    'Knife': [22, 23],
    'Lamp': [24, 25, 26, 27],
    'Laptop': [28, 29],
    'Motorbike': [30, 31, 32, 33, 34, 35],
    'Mug': [36, 37],
    'Pistol': [38, 39, 40],
    'Rocket': [41, 42, 43],
    'Skateboard': [44, 45, 46],
    'Table': [47, 48, 49],
}

# Names shown when hovering over a point (complete list).
PART_NAMES = {
    0: 'Airplane Body',
    1: 'Airplane Wing',
    2: 'Airplane Tail',
    3: 'Airplane Engine',
    4: 'Bag Handle',
    5: 'Bag Body',
    6: 'Cap Peak',
    7: 'Cap Panels',
    8: 'Car Roof',
    9: 'Car Hood',
    10: 'Car Bumper',
    11: 'Car Wheels',
    12: 'Chair Back',
    13: 'Chair Seat',
    14: 'Chair Leg',
    15: 'Chair Armrest',
    16: 'Earphone Earphone',
    17: 'Earphone Headband',
    18: 'Earphone Wire',
    19: 'Guitar Head',
    20: 'Guitar Body',
    21: 'Guitar Neck',
    22: 'Knife Blade',
    23: 'Knife Handle',
    24: 'Lamp Base',
    25: 'Lamp Shade',
    26: 'Lamp Tube',
    27: 'Lamp Wire',
    28: 'Laptop Keyboard',
    29: 'Laptop Screen',
    30: 'Motorbike Gas Tank',
    31: 'Motorbike Seat',
    32: 'Motorbike Wheels',
    33: 'Motorbike Handles',
    34: 'Motorbike Light',
    35: 'Motorbike Engine',
    36: 'Mug Handle',
    37: 'Mug Cup',
    38: 'Pistol Barrel',
    39: 'Pistol Handle',
    40: 'Pistol Trigger',
    41: 'Rocket Body',
    42: 'Rocket Fin',
    43: 'Rocket Nose',
    44: 'Skateboard Wheels',
    45: 'Skateboard Deck',
    46: 'Skateboard Steering',
    47: 'Table Top',
    48: 'Table Leg',
    49: 'Table Base',
}

device = torch.device("cpu")
model = HierarchicalFPSCliffordNet(base_channels=12).to(device)
try:
    checkpoint = torch.load(
        "best_all_categories_clifford.pt",
        map_location=device,
        weights_only=True,
    )
    # Strip the "_orig_mod." key prefix (presumably left by torch.compile)
    # so the keys match the plain module.
    model.load_state_dict(
        {k.replace("_orig_mod.", ""): v for k, v in checkpoint.items()},
        strict=False,
    )
except Exception as exc:
    # Top-level boundary: keep the demo running with untrained weights, but
    # report the real reason instead of swallowing it.
    print("Ağırlıklar bulunamadı.")
    print(f"{type(exc).__name__}: {exc}")
# Always switch to inference mode, even when no checkpoint could be loaded.
model.eval()


def predict(file, category_name):
    """Segment an uploaded mesh into parts and return a coloured 3D scatter plot.

    Args:
        file: Path of the uploaded model, or an object exposing the path as
            ``.name``. A falsy value means nothing was uploaded.
        category_name: A key of ``CATEGORIES`` that selects which part
            labels are allowed.

    Returns:
        A Plotly figure with one marker per sampled vertex, coloured by the
        predicted part ID, or ``None`` when no file was given.

    Raises:
        gr.Error: If the category is unknown or the mesh has no vertices.
    """
    if not file:
        return None
    if category_name not in CATEGORIES:
        raise gr.Error(f"Geçersiz kategori: {category_name}")

    file_path = file if isinstance(file, str) else file.name
    mesh = trimesh.load(file_path, force='mesh')
    p = np.array(mesh.vertices)
    if len(p) == 0:
        raise gr.Error("Modelde köşe noktası bulunamadı.")

    # Sample a fixed number of vertices; draw with replacement only when the
    # mesh has fewer vertices than requested.
    p = p[np.random.choice(len(p), NUM_POINTS, replace=len(p) < NUM_POINTS)]

    # Centre first, THEN scale by the largest centred norm. Measuring the
    # norm before centring shrinks meshes that sit far from the origin,
    # which defeats the fixed radii used below.
    pos = torch.tensor(p, dtype=torch.float32, device=device)
    pos = pos - pos.mean(0)
    pos = pos / pos.norm(dim=1).max().clamp(min=1e-8)

    # Processing: two levels of farthest point sampling plus the graphs
    # that link the levels.
    f1 = fps_pure(pos, 0.5)
    p2 = pos[f1]
    f2 = fps_pure(p2, 0.25)
    p3 = p2[f2]
    e1 = radius_graph_pure(pos, 0.15)
    e2 = radius_graph_pure(p2, 0.30)
    a32 = knn_pure(p3, p2)
    a21 = knn_pure(p2, pos)

    # A single shape is processed, so every batch vector is all zeros.
    batch1 = torch.zeros(pos.size(0), dtype=torch.long, device=device)
    batch2 = torch.zeros(p2.size(0), dtype=torch.long, device=device)
    batch3 = torch.zeros(p3.size(0), dtype=torch.long, device=device)
    _, m = to_dense_batch(torch.zeros(p3.size(0), 1, device=device), batch3)
    category = torch.tensor([CATEGORIES[category_name]], device=device)

    with torch.no_grad():
        out = model(
            pos, batch1, category,
            f1, f2,
            p2, batch2,
            p3, batch3,
            e1, e2, a32, a21, m,
        )

    # Category filtering (logit masking): only parts that belong to the
    # selected category may win the argmax.
    # NOTE(review): assumes ``out`` is (num_points, num_parts) - confirm
    # against the model definition.
    valid_parts = SEG_CLASSES[category_name]
    mask = torch.zeros_like(out, dtype=torch.bool)
    mask[:, valid_parts] = True
    out = out.masked_fill(~mask, -10000.0)
    res = out.argmax(-1).cpu().numpy()

    # Text shown when hovering over a point.
    hover_texts = [
        f"Parça: {PART_NAMES.get(pr, f'Bilinmeyen_{pr}')}" for pr in res.tolist()
    ]

    # The plot uses the sampled, un-normalised coordinates.
    fig = go.Figure(data=[go.Scatter3d(
        x=p[:, 0], y=p[:, 1], z=p[:, 2],
        mode='markers',
        text=hover_texts,
        hoverinfo='text',
        marker=dict(size=4, color=res, colorscale='Viridis', opacity=0.8),
    )])
    fig.update_layout(
        margin=dict(l=0, r=0, b=0, t=0),
        scene=dict(
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(visible=False),
        ),
    )
    return fig


with gr.Blocks() as demo:
    gr.Markdown("# 🚀 Clifford-Dirac 3D Segmentation (0.15M Params)")
    with gr.Row():
        with gr.Column():
            inp = gr.Model3D(label="Model Yükle (.glb, .obj)")
            cat = gr.Dropdown(choices=list(CATEGORIES.keys()), label="Kategori", value="Chair")
            btn = gr.Button("Tahmin Et", variant="primary")
        out = gr.Plot(label="Sonuç")
    btn.click(predict, [inp, cat], out)

    # Optional: a note at the very bottom of the page.
    gr.Markdown("---")
    gr.Markdown("""
    ### 📊 Train GitHub
    https://github.com/yusuftiryaki/clifford-shapenet-segmentation.git
    """)

demo.launch()