File size: 6,119 Bytes
3e83674
 
 
 
 
 
 
 
99c5bcb
25ba4d0
99c5bcb
25ba4d0
fd1bd17
25ba4d0
 
 
 
 
fd1bd17
99c5bcb
fd1bd17
99c5bcb
 
 
 
fd1bd17
 
99c5bcb
 
 
fd1bd17
 
 
99c5bcb
fd1bd17
 
 
794800b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e83674
fd1bd17
3e83674
 
 
99c5bcb
3e83674
794800b
 
3e83674
 
99c5bcb
f7fc68e
 
 
 
99c5bcb
 
 
 
3e83674
99c5bcb
 
 
 
 
 
 
3e83674
99c5bcb
794800b
 
 
 
 
 
 
99c5bcb
 
794800b
 
 
 
 
 
 
 
 
 
 
 
99c5bcb
3e83674
 
f7fc68e
99c5bcb
794800b
3e83674
99c5bcb
 
 
 
 
794800b
99c5bcb
794800b
 
 
 
b7d5f4a
 
794800b
f7fc68e
3e83674
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import gradio as gr
import torch
import numpy as np
import trimesh
import plotly.graph_objects as go
from torch_geometric.utils import to_dense_batch
from model import HierarchicalFPSCliffordNet

# --- PURE-PYTORCH FALLBACKS FOR THE C++ OPS ---

def knn_pure(x, y, k=3):
    """Pair every row of ``y`` with its ``k`` closest rows of ``x``.

    Distances are computed with ``torch.cdist``.

    Args:
        x: Candidate points, one per row.
        y: Query points, one per row.
        k: Number of neighbours taken per query point.

    Returns:
        A ``(2, y.size(0) * k)`` index tensor. Row 0 repeats each query index
        ``k`` times; row 1 holds the matching indices into ``x`` in the order
        produced by ``topk``.
    """
    num_queries = y.size(0)
    pairwise = torch.cdist(y, x)
    nearest = pairwise.topk(k, dim=1, largest=False).indices

    # [0, 0, ..., 1, 1, ...]: one query id per selected neighbour.
    query_ids = torch.arange(num_queries, device=y.device).repeat_interleave(k)
    return torch.stack((query_ids, nearest.flatten()), dim=0)

def radius_graph_pure(pos, r, max_n=16):
    """Connect each point to its nearby points within radius ``r``.

    For every point, the ``max_n`` nearest points (capped at the number of
    points) are taken as candidates and those strictly closer than ``r`` are
    kept. A point's distance to itself also competes for a candidate slot, so
    self-edges are typically part of the result.

    Args:
        pos: Point coordinates, one point per row.
        r: Distance threshold; an edge is kept when the distance is ``< r``.
        max_n: Maximum number of candidate neighbours examined per point.

    Returns:
        A ``(2, E)`` long tensor on ``pos.device``. Row 0 is the index of the
        centre point, row 1 the index of the neighbour.
    """
    num_points = pos.size(0)
    k = min(max_n, num_points)
    # topk already hands back the distances of the selected neighbours, so
    # there is no need to gather them from the full matrix again.
    nearest_dist, target = torch.topk(torch.cdist(pos, pos), k, dim=1, largest=False)
    # Built on pos.device (like knn_pure does) so the boolean masking below
    # also works when the input lives on a non-CPU device.
    source = torch.arange(num_points, device=pos.device).view(-1, 1).expand(-1, k)
    mask = nearest_dist < r
    return torch.stack([source[mask], target[mask]], dim=0)

def fps_pure(pos, ratio=0.5):
    """Pick a subset of points by farthest point sampling.

    Sampling always starts at index 0. Each following pick is the point whose
    distance to the already selected set is largest.

    Args:
        pos: Point coordinates, one point per row (any number of columns).
        ratio: Fraction of points to keep; at least one point is returned for
            non-empty input.

    Returns:
        A 1-D long tensor on ``pos.device`` with ``max(1, int(N * ratio))``
        indices into ``pos``, or an empty tensor when ``pos`` has no rows.
    """
    num_points = pos.size(0)
    if num_points == 0:
        # There is no index 0 to start from.
        return torch.zeros(0, dtype=torch.long, device=pos.device)

    k = max(1, int(num_points * ratio))
    idx = torch.zeros(k, dtype=torch.long, device=pos.device)
    # Distance from every point to the selected set; nothing is selected yet.
    dist = torch.full((num_points,), float("inf"), dtype=pos.dtype, device=pos.device)
    farthest = 0
    for i in range(k):
        idx[i] = farthest
        # view(1, -1) keeps this independent of the coordinate dimension, and
        # squeeze(1) only drops the column axis (safe for a single point).
        to_latest = torch.cdist(pos, pos[farthest].view(1, -1)).squeeze(1)
        dist = torch.min(dist, to_latest)
        farthest = torch.argmax(dist).item()
    return idx

# --- Settings and dataset dictionaries ---
# Category name -> integer id, numbered in the order listed below.
CATEGORIES = {
    name: index
    for index, name in enumerate((
        'Airplane', 'Bag', 'Cap', 'Car', 'Chair', 'Earphone', 'Guitar',
        'Knife', 'Lamp', 'Laptop', 'Motorbike', 'Mug', 'Pistol',
        'Rocket', 'Skateboard', 'Table',
    ))
}

# Valid part ids for each category (complete list, used for logit masking).
# Every category owns one contiguous block of ids.
SEG_CLASSES = {
    'Airplane': list(range(0, 4)),
    'Bag': list(range(4, 6)),
    'Cap': list(range(6, 8)),
    'Car': list(range(8, 12)),
    'Chair': list(range(12, 16)),
    'Earphone': list(range(16, 19)),
    'Guitar': list(range(19, 22)),
    'Knife': list(range(22, 24)),
    'Lamp': list(range(24, 28)),
    'Laptop': list(range(28, 30)),
    'Motorbike': list(range(30, 36)),
    'Mug': list(range(36, 38)),
    'Pistol': list(range(38, 41)),
    'Rocket': list(range(41, 44)),
    'Skateboard': list(range(44, 47)),
    'Table': list(range(47, 50)),
}

# Names shown on hover (when the mouse is over a point), complete list.
# Part ids run consecutively from 0 in the order the labels appear below;
# each name is "<category> <part>".
PART_NAMES = dict(enumerate(
    f"{category} {part}"
    for category, parts in (
        ('Airplane', ('Body', 'Wing', 'Tail', 'Engine')),           # 0-3
        ('Bag', ('Handle', 'Body')),                                # 4-5
        ('Cap', ('Peak', 'Panels')),                                # 6-7
        ('Car', ('Roof', 'Hood', 'Bumper', 'Wheels')),              # 8-11
        ('Chair', ('Back', 'Seat', 'Leg', 'Armrest')),              # 12-15
        ('Earphone', ('Earphone', 'Headband', 'Wire')),             # 16-18
        ('Guitar', ('Head', 'Body', 'Neck')),                       # 19-21
        ('Knife', ('Blade', 'Handle')),                             # 22-23
        ('Lamp', ('Base', 'Shade', 'Tube', 'Wire')),                # 24-27
        ('Laptop', ('Keyboard', 'Screen')),                         # 28-29
        ('Motorbike', ('Gas Tank', 'Seat', 'Wheels', 'Handles', 'Light', 'Engine')),  # 30-35
        ('Mug', ('Handle', 'Cup')),                                 # 36-37
        ('Pistol', ('Barrel', 'Handle', 'Trigger')),                # 38-40
        ('Rocket', ('Body', 'Fin', 'Nose')),                        # 41-43
        ('Skateboard', ('Wheels', 'Deck', 'Steering')),             # 44-46
        ('Table', ('Top', 'Leg', 'Base')),                          # 47-49
    )
    for part in parts
))

# Build the network on the CPU and try to load the trained weights.
device = torch.device("cpu")
model = HierarchicalFPSCliffordNet(base_channels=12).to(device)
try:
    checkpoint = torch.load("best_all_categories_clifford.pt", map_location=device, weights_only=True)
    # Drop the "_orig_mod." substring from checkpoint keys before loading
    # (presumably a prefix left by torch.compile -- confirm against training).
    model.load_state_dict({k.replace("_orig_mod.", ""): v for k, v in checkpoint.items()}, strict=False)
except FileNotFoundError:
    print("Ağırlıklar bulunamadı.")
except Exception as exc:
    # Start-up stays best-effort, but the real cause is reported instead of
    # being mislabelled as a missing file.
    print(f"Ağırlıklar yüklenemedi: {exc}")
# Inference mode must be set even when loading failed; otherwise the model
# would be called in training mode.
model.eval()

def predict(file, category_name):
    """Segment an uploaded 3D model into parts and plot the result.

    The mesh vertices are sampled to a fixed-size point cloud, normalised,
    turned into the sampling/graph inputs the model takes, and the per-point
    prediction is restricted to the part ids valid for the chosen category.

    Args:
        file: Path string of the uploaded model, or an object whose ``name``
            attribute is that path. A falsy value means nothing was uploaded.
        category_name: Key of ``CATEGORIES`` / ``SEG_CLASSES`` naming the
            object category.

    Returns:
        A Plotly figure with one marker per sampled vertex, coloured by the
        predicted part id and labelled on hover, or ``None`` when no file was
        given or the loaded mesh has no vertices.

    Raises:
        KeyError: If ``category_name`` is not a known category.
    """
    if not file:
        return None

    num_points = 1024  # number of points sampled from the mesh

    file_path = file if isinstance(file, str) else file.name

    mesh = trimesh.load(file_path, force='mesh')
    vertices = np.array(mesh.vertices)
    if len(vertices) == 0:
        # Nothing to sample from; np.random.choice would raise here.
        return None
    # Sample with replacement only when the mesh has too few vertices.
    p = vertices[np.random.choice(len(vertices), num_points, replace=len(vertices) < num_points)]
    pos = torch.tensor(p, dtype=torch.float32)
    # Centre first, then scale by the largest radius of the *centred* cloud.
    # Measuring the radius before centring made the scale depend on how far
    # the mesh sits from the origin.
    # NOTE(review): confirm this matches the normalisation used in training.
    pos = pos - pos.mean(0)
    pos = pos / pos.norm(dim=1).max().clamp(1e-8)

    # Preprocessing: two levels of farthest point sampling, radius graphs on
    # the two finer levels, and nearest-neighbour links between levels.
    f1 = fps_pure(pos, 0.5)
    p2 = pos[f1]
    f2 = fps_pure(p2, 0.25)
    p3 = p2[f2]
    e1 = radius_graph_pure(pos, 0.15)
    e2 = radius_graph_pure(p2, 0.30)
    a32 = knn_pure(p3, p2)
    a21 = knn_pure(p2, pos)
    _, m = to_dense_batch(torch.zeros(p3.size(0), 1), torch.zeros(p3.size(0), dtype=torch.long))

    with torch.no_grad():
        # All batch vectors are zeros: a single shape is processed per call.
        out = model(
            pos, torch.zeros(num_points, dtype=torch.long), torch.tensor([CATEGORIES[category_name]]),
            f1, f2,
            p2, torch.zeros(p2.size(0), dtype=torch.long),
            p3, torch.zeros(p3.size(0), dtype=torch.long),
            e1, e2, a32, a21, m,
        )

        # Category filtering (logit masking): push every part id that does
        # not belong to the chosen category far below the valid ones.
        valid_parts = SEG_CLASSES[category_name]
        mask = torch.zeros_like(out, dtype=torch.bool)
        mask[:, valid_parts] = True
        out = out.masked_fill(~mask, -10000.0)

        res = out.argmax(-1).numpy()

    # Text shown when hovering over a point.
    hover_texts = [f"Parça: {PART_NAMES.get(pr, f'Bilinmeyen_{pr}')}" for pr in res]

    # The plot uses the raw sampled coordinates, not the normalised ones.
    fig = go.Figure(data=[go.Scatter3d(
        x=p[:, 0], y=p[:, 1], z=p[:, 2],
        mode='markers',
        text=hover_texts,
        hoverinfo='text',
        marker=dict(size=4, color=res, colorscale='Viridis', opacity=0.8)
    )])

    fig.update_layout(margin=dict(l=0, r=0, b=0, t=0), scene=dict(xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False)))
    return fig

# Gradio UI: upload widget, category selector and run button on the left,
# the resulting plot on the right.
with gr.Blocks() as demo:
    gr.Markdown("# 🚀 Clifford-Dirac 3D Segmentation (0.15M Params)")
    
    with gr.Row():
        with gr.Column():
            inp = gr.Model3D(label="Model Yükle (.glb, .obj)")
            cat = gr.Dropdown(choices=list(CATEGORIES.keys()), label="Kategori", value="Chair")
            btn = gr.Button("Tahmin Et", variant="primary")
        out = gr.Plot(label="Sonuç")
        
    btn.click(predict, [inp, cat], out)
    
    # Optional: a note with the training repository at the bottom of the page.
    gr.Markdown("---")
    gr.Markdown("""
    ### 📊 Train GitHub
    https://github.com/yusuftiryaki/clifford-shapenet-segmentation.git
    """)

# Start the server only when run as a script, so importing this module
# (e.g. to reuse the helpers) does not launch the web app.
if __name__ == "__main__":
    demo.launch()