Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| import trimesh | |
| import plotly.graph_objects as go | |
| from torch_geometric.utils import to_dense_batch | |
| from model import HierarchicalFPSCliffordNet | |
# --- Pure-PyTorch fallbacks (stand-ins for the C++-extension graph ops) ---
def knn_pure(x, y, k=3):
    """For every query point in ``y``, find its ``k`` nearest points in ``x``.

    Args:
        x: 2-D (N, D) tensor of candidate points.
        y: 2-D (M, D) tensor of query points.
        k: neighbours per query. Clamped to N when ``x`` has fewer than ``k``
            points (``torch.topk`` would otherwise raise).

    Returns:
        (2, M * k') long tensor ``[query_idx, neighbour_idx]`` with
        ``k' = min(k, N)``. Row 0 indexes ``y`` and row 1 indexes ``x``. The k'
        neighbours of each query are contiguous and ordered by increasing distance.
    """
    k = min(k, x.size(0))
    dist = torch.cdist(y, x)  # (M, N): distance from each query to each candidate
    _, topk_idx = torch.topk(dist, k, dim=1, largest=False)
    # Repeat each query index k times so it lines up with its flattened neighbour list.
    idx_y = torch.arange(y.size(0), device=y.device).view(-1, 1).expand(-1, k).reshape(-1)
    idx_x = topk_idx.reshape(-1)
    return torch.stack([idx_y, idx_x], dim=0)
def radius_graph_pure(pos, r, max_n=16):
    """Link every point to its nearest neighbours that lie within distance ``r``.

    Only the ``max_n`` nearest points of each point are considered, and of those
    only pairs with distance strictly below ``r`` are kept.

    Args:
        pos: 2-D (N, D) point tensor.
        r: radius threshold (exclusive).
        max_n: maximum number of neighbours examined per point.

    Returns:
        (2, E) long tensor ``[center_idx, neighbour_idx]``, grouped by center and
        ordered by increasing distance within each group. A point's distance to
        itself is 0, so self-loops are normally included.
    """
    n_neighbors = min(max_n, pos.size(0))
    # topk already returns the sorted distances, so no need to gather them back out.
    near_dist, target = torch.topk(torch.cdist(pos, pos), n_neighbors, dim=1, largest=False)
    source = torch.arange(pos.size(0), device=pos.device).view(-1, 1).expand(-1, n_neighbors)
    keep = near_dist < r
    return torch.stack([source[keep], target[keep]], dim=0)
def fps_pure(pos, ratio=0.5):
    """Farthest point sampling over a point cloud.

    Starts deterministically from index 0, then repeatedly adds the point whose
    distance to the already-selected set is largest. Ties go to the lowest index,
    because ``torch.argmax`` returns the first maximum.

    Args:
        pos: 2-D (N, D) point tensor (any D).
        ratio: fraction of points to keep. At least one point is kept for
            non-empty input, and never more than N.

    Returns:
        1-D long tensor of selected indices into ``pos``, in selection order.
        Empty when ``pos`` is empty.
    """
    n = pos.size(0)
    if n == 0:
        return torch.zeros(0, dtype=torch.long, device=pos.device)
    k = min(n, max(1, int(n * ratio)))
    idx = torch.zeros(k, dtype=torch.long, device=pos.device)
    # Distance from each point to its nearest selected point; starts effectively infinite.
    min_dist = torch.full((n,), 1e10, dtype=pos.dtype, device=pos.device)
    farthest = 0
    for i in range(k):
        idx[i] = farthest
        # Slice (not .view(1, 3)) so any point dimensionality works.
        d = torch.cdist(pos, pos[farthest:farthest + 1]).squeeze(1)
        min_dist = torch.minimum(min_dist, d)
        farthest = torch.argmax(min_dist).item()
    return idx
# --- Settings and dataset lookup tables ---
# Object category -> class index passed to the model. The index is simply the
# category's position in the tuple below.
CATEGORIES = {
    category: class_index
    for class_index, category in enumerate((
        'Airplane', 'Bag', 'Cap', 'Car', 'Chair', 'Earphone', 'Guitar', 'Knife',
        'Lamp', 'Laptop', 'Motorbike', 'Mug', 'Pistol', 'Rocket', 'Skateboard', 'Table',
    ))
}
# Valid part ids per category (complete table, used for logit masking).
# Part ids are global and contiguous: each category owns the next run of ids,
# so the table is built from per-category part counts.
SEG_CLASSES = {}
_first_part_id = 0
for _category, _num_parts in (
    ('Airplane', 4), ('Bag', 2), ('Cap', 2), ('Car', 4),
    ('Chair', 4), ('Earphone', 3), ('Guitar', 3), ('Knife', 2),
    ('Lamp', 4), ('Laptop', 2), ('Motorbike', 6), ('Mug', 2),
    ('Pistol', 3), ('Rocket', 3), ('Skateboard', 3), ('Table', 3),
):
    SEG_CLASSES[_category] = list(range(_first_part_id, _first_part_id + _num_parts))
    _first_part_id += _num_parts
# Drop the loop temporaries so only SEG_CLASSES remains at module level.
del _category, _num_parts, _first_part_id
# Part names shown when hovering over a point, keyed by global part id
# (complete table; the id is the name's position in the list).
PART_NAMES = dict(enumerate([
    'Airplane Body', 'Airplane Wing', 'Airplane Tail', 'Airplane Engine',    # 0-3
    'Bag Handle', 'Bag Body',                                                # 4-5
    'Cap Peak', 'Cap Panels',                                                # 6-7
    'Car Roof', 'Car Hood', 'Car Bumper', 'Car Wheels',                      # 8-11
    'Chair Back', 'Chair Seat', 'Chair Leg', 'Chair Armrest',                # 12-15
    'Earphone Earphone', 'Earphone Headband', 'Earphone Wire',               # 16-18
    'Guitar Head', 'Guitar Body', 'Guitar Neck',                             # 19-21
    'Knife Blade', 'Knife Handle',                                           # 22-23
    'Lamp Base', 'Lamp Shade', 'Lamp Tube', 'Lamp Wire',                     # 24-27
    'Laptop Keyboard', 'Laptop Screen',                                      # 28-29
    'Motorbike Gas Tank', 'Motorbike Seat', 'Motorbike Wheels',              # 30-32
    'Motorbike Handles', 'Motorbike Light', 'Motorbike Engine',              # 33-35
    'Mug Handle', 'Mug Cup',                                                 # 36-37
    'Pistol Barrel', 'Pistol Handle', 'Pistol Trigger',                      # 38-40
    'Rocket Body', 'Rocket Fin', 'Rocket Nose',                              # 41-43
    'Skateboard Wheels', 'Skateboard Deck', 'Skateboard Steering',           # 44-46
    'Table Top', 'Table Leg', 'Table Base',                                  # 47-49
]))
device = torch.device("cpu")
model = HierarchicalFPSCliffordNet(base_channels=12).to(device)
# Load pretrained weights on a best-effort basis. The demo still starts if the
# checkpoint is missing or unusable, but the console now says why.
try:
    # weights_only=True restricts unpickling to tensors and plain containers.
    checkpoint = torch.load("best_all_categories_clifford.pt", map_location=device, weights_only=True)
    # Strip the "_orig_mod." key prefix (presumably from saving a torch.compile-wrapped
    # model) so the keys match the plain module.
    load_result = model.load_state_dict(
        {k.replace("_orig_mod.", ""): v for k, v in checkpoint.items()}, strict=False
    )
except FileNotFoundError:
    print("Ağırlıklar bulunamadı.")
except Exception as exc:  # startup boundary: report the cause instead of swallowing it
    print(f"Ağırlıklar yüklenemedi: {exc!r}")
else:
    # strict=False tolerates key mismatches silently; surface them so a stale
    # or mismatched checkpoint does not go unnoticed.
    if load_result.missing_keys or load_result.unexpected_keys:
        print(
            f"Eksik anahtarlar: {load_result.missing_keys}; "
            f"beklenmeyen anahtarlar: {load_result.unexpected_keys}"
        )
# Switch to inference mode even when loading failed, so any dropout or
# batch-norm layers never run in training mode.
model.eval()
def predict(file, category_name, num_points=1024):
    """Segment an uploaded 3D model into parts and return a Plotly 3D scatter.

    Args:
        file: Value from ``gr.Model3D``. Either a path string or an object whose
            ``.name`` is the uploaded file's path. A falsy value means "no file".
        category_name: Key of ``CATEGORIES``. It selects the category index fed to
            the model and restricts predictions to ``SEG_CLASSES[category_name]``.
        num_points: Number of mesh vertices sampled as the input point cloud.
            Sampling is with replacement only if the mesh has fewer vertices.

    Returns:
        A ``plotly.graph_objects.Figure`` with one marker per sampled vertex,
        coloured by predicted part id, or ``None`` when no file was given.

    Raises:
        gr.Error: if the category is unknown or the mesh has no vertices.
    """
    if not file:
        return None
    if category_name not in CATEGORIES:
        raise gr.Error("Lütfen geçerli bir kategori seçin.")
    file_path = file if isinstance(file, str) else file.name
    mesh = trimesh.load(file_path, force='mesh')
    p = np.array(mesh.vertices)
    if len(p) == 0:
        raise gr.Error("Yüklenen modelde hiç köşe noktası (vertex) bulunamadı.")
    p = p[np.random.choice(len(p), num_points, replace=len(p) < num_points)]

    # Normalise into the unit sphere: centre first, THEN divide by the largest
    # norm of the *centred* cloud. Taking the norm of the raw coordinates
    # under-scales any mesh that is not already centred at the origin.
    pos = torch.tensor(p, dtype=torch.float32)
    pos = pos - pos.mean(0)
    pos = pos / pos.norm(dim=1).max().clamp(1e-8)

    # Two-level farthest-point-sampling hierarchy: pos -> p2 (1/2) -> p3 (1/4 of p2).
    f1 = fps_pure(pos, 0.5)
    p2 = pos[f1]
    f2 = fps_pure(p2, 0.25)
    p3 = p2[f2]
    # Radius neighbourhoods on the two finer levels (wider radius on the coarser one).
    e1 = radius_graph_pure(pos, 0.15)
    e2 = radius_graph_pure(p2, 0.30)
    # Cross-level kNN links: each p2 point -> nearest p3 points, each pos point -> nearest p2 points.
    a32 = knn_pure(p3, p2)
    a21 = knn_pure(p2, pos)

    # Single-sample batch: every point at every level belongs to batch 0.
    batch1 = torch.zeros(pos.size(0), dtype=torch.long)
    batch2 = torch.zeros(p2.size(0), dtype=torch.long)
    batch3 = torch.zeros(p3.size(0), dtype=torch.long)
    _, m = to_dense_batch(torch.zeros(p3.size(0), 1), batch3)
    category = torch.tensor([CATEGORIES[category_name]])
    with torch.no_grad():
        out = model(pos, batch1, category, f1, f2, p2, batch2, p3, batch3, e1, e2, a32, a21, m)

    # Logit masking: only part ids belonging to the chosen category may win the argmax.
    valid_parts = SEG_CLASSES[category_name]
    allowed = torch.zeros_like(out, dtype=torch.bool)
    allowed[:, valid_parts] = True
    out = out.masked_fill(~allowed, float('-inf'))
    res = out.argmax(-1).numpy()

    # Hover text for each point.
    hover_texts = [f"Parça: {PART_NAMES.get(int(pr), f'Bilinmeyen_{pr}')}" for pr in res]
    # Plot the sampled (un-normalised) vertices, coloured by predicted part id.
    fig = go.Figure(data=[go.Scatter3d(
        x=p[:, 0], y=p[:, 1], z=p[:, 2],
        mode='markers',
        text=hover_texts,
        hoverinfo='text',
        marker=dict(size=4, color=res, colorscale='Viridis', opacity=0.8),
    )])
    fig.update_layout(
        margin=dict(l=0, r=0, b=0, t=0),
        scene=dict(xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False)),
    )
    return fig
with gr.Blocks() as demo:
    gr.Markdown("# 🚀 Clifford-Dirac 3D Segmentation (0.15M Params)")
    with gr.Row():
        # Left column: inputs and the run button; right: the segmentation plot.
        with gr.Column():
            inp = gr.Model3D(label="Model Yükle (.glb, .obj)")
            cat = gr.Dropdown(choices=list(CATEGORIES), label="Kategori", value="Chair")
            btn = gr.Button("Tahmin Et", variant="primary")
        out = gr.Plot(label="Sonuç")
    btn.click(predict, [inp, cat], out)
    # Optional footer at the bottom of the page linking to the training code.
    gr.Markdown("---")
    gr.Markdown("""
### 📊 Train GitHub
https://github.com/yusuftiryaki/clifford-shapenet-segmentation.git
""")

# Launch only when run as a script, so importing this module (e.g. by Gradio's
# reload mode or tests) does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch()