Spaces:
Sleeping
Sleeping
Commit ·
fd1bd17
1
Parent(s): 3e83674
fix torch cluster
Browse files- app.py +42 -34
- model.py +0 -6
- requirements.txt +0 -3
app.py
CHANGED
|
@@ -3,12 +3,40 @@ import torch
|
|
| 3 |
import numpy as np
|
| 4 |
import trimesh
|
| 5 |
import plotly.graph_objects as go
|
| 6 |
-
from torch_geometric.data import Data
|
| 7 |
-
from torch_cluster import fps, radius_graph, knn
|
| 8 |
from torch_geometric.utils import to_dense_batch
|
| 9 |
-
|
| 10 |
from model import HierarchicalFPSCliffordNet
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
# --- Kategori Sözlükleri ---
|
| 13 |
CATEGORIES = {
|
| 14 |
'Airplane': 0, 'Bag': 1, 'Cap': 2, 'Car': 3, 'Chair': 4, 'Earphone': 5, 'Guitar': 6,
|
|
@@ -24,7 +52,7 @@ SEG_CLASSES = {
|
|
| 24 |
}
|
| 25 |
|
| 26 |
# --- Modeli Yükleme ---
|
| 27 |
-
device = torch.device("cpu")
|
| 28 |
model = HierarchicalFPSCliffordNet(base_channels=12).to(device)
|
| 29 |
|
| 30 |
try:
|
|
@@ -38,11 +66,9 @@ except Exception as e:
|
|
| 38 |
|
| 39 |
# --- Veri Hazırlama Fonksiyonu ---
|
| 40 |
def process_point_cloud(file_path, category_name):
|
| 41 |
-
# Dosyayı oku ve noktalara çevir
|
| 42 |
mesh = trimesh.load(file_path, force='mesh')
|
| 43 |
points = np.array(mesh.vertices)
|
| 44 |
|
| 45 |
-
# 1024 noktaya örnekle (Eğitimdeki gibi)
|
| 46 |
if len(points) > 1024:
|
| 47 |
idx = np.random.choice(len(points), 1024, replace=False)
|
| 48 |
points = points[idx]
|
|
@@ -51,8 +77,6 @@ def process_point_cloud(file_path, category_name):
|
|
| 51 |
points = points[idx]
|
| 52 |
|
| 53 |
pos = torch.tensor(points, dtype=torch.float32)
|
| 54 |
-
|
| 55 |
-
# Birim Küreye Normalize Et
|
| 56 |
pos = pos - pos.mean(dim=0, keepdim=True)
|
| 57 |
scale = pos.norm(dim=1).max().clamp(min=1e-8)
|
| 58 |
pos = pos / scale
|
|
@@ -60,68 +84,52 @@ def process_point_cloud(file_path, category_name):
|
|
| 60 |
batch = torch.zeros(pos.size(0), dtype=torch.long)
|
| 61 |
cat_idx = torch.tensor([CATEGORIES[category_name]], dtype=torch.long)
|
| 62 |
|
| 63 |
-
#
|
| 64 |
-
fps_idx_1 =
|
| 65 |
pos2 = pos[fps_idx_1]
|
| 66 |
batch2 = torch.zeros(pos2.size(0), dtype=torch.long)
|
| 67 |
|
| 68 |
-
fps_idx_2 =
|
| 69 |
pos3 = pos2[fps_idx_2]
|
| 70 |
batch3 = torch.zeros(pos3.size(0), dtype=torch.long)
|
| 71 |
|
| 72 |
-
edge_index_1 =
|
| 73 |
-
edge_index_2 =
|
| 74 |
|
| 75 |
-
assign_index_32 =
|
| 76 |
-
assign_index_21 =
|
| 77 |
|
| 78 |
dummy_x = torch.zeros(pos3.size(0), 1)
|
| 79 |
_, x_dense_mask = to_dense_batch(dummy_x, batch3, max_num_nodes=int(torch.bincount(batch3).max().item()))
|
| 80 |
|
| 81 |
return pos, batch, cat_idx, fps_idx_1, fps_idx_2, pos2, batch2, pos3, batch3, edge_index_1, edge_index_2, assign_index_32, assign_index_21, x_dense_mask
|
| 82 |
|
| 83 |
-
# --- Çıkarım
|
| 84 |
def predict(file, category_name):
|
| 85 |
if file is None:
|
| 86 |
return None
|
| 87 |
|
| 88 |
-
# Veriyi modele uygun formata getir
|
| 89 |
(pos, batch, cat_idx, fps_idx_1, fps_idx_2, pos2, batch2, pos3, batch3,
|
| 90 |
edge_index_1, edge_index_2, assign_index_32, assign_index_21, x_dense_mask) = process_point_cloud(file.name, category_name)
|
| 91 |
|
| 92 |
-
# Modelden Geçir
|
| 93 |
with torch.no_grad():
|
| 94 |
logits = model(pos, batch, cat_idx, fps_idx_1, fps_idx_2, pos2, batch2, pos3, batch3,
|
| 95 |
edge_index_1, edge_index_2, assign_index_32, assign_index_21, x_dense_mask)
|
| 96 |
|
| 97 |
-
# O kategoriye ait olmayan parçaları maskele
|
| 98 |
valid_parts = SEG_CLASSES[category_name]
|
| 99 |
mask = torch.zeros_like(logits, dtype=torch.bool)
|
| 100 |
mask[:, valid_parts] = True
|
| 101 |
logits = logits.masked_fill(~mask, -10000.0)
|
| 102 |
-
|
| 103 |
-
# Sınıfı Tahmin Et
|
| 104 |
preds = logits.argmax(dim=-1).numpy()
|
| 105 |
|
| 106 |
-
# Plotly ile Görselleştir
|
| 107 |
pos_np = pos.numpy()
|
| 108 |
fig = go.Figure(data=[go.Scatter3d(
|
| 109 |
x=pos_np[:, 0], y=pos_np[:, 1], z=pos_np[:, 2],
|
| 110 |
mode='markers',
|
| 111 |
-
marker=dict(
|
| 112 |
-
size=4,
|
| 113 |
-
color=preds, # Her parçaya farklı renk atar
|
| 114 |
-
colorscale='Viridis',
|
| 115 |
-
opacity=0.9,
|
| 116 |
-
line=dict(width=0.5, color='black') # Noktalara derinlik katar
|
| 117 |
-
)
|
| 118 |
)])
|
| 119 |
|
| 120 |
-
fig.update_layout(
|
| 121 |
-
margin=dict(l=0, r=0, b=0, t=0),
|
| 122 |
-
scene=dict(xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False))
|
| 123 |
-
)
|
| 124 |
-
|
| 125 |
return fig
|
| 126 |
|
| 127 |
# --- Gradio Arayüzü ---
|
|
|
|
| 3 |
import numpy as np
|
| 4 |
import trimesh
|
| 5 |
import plotly.graph_objects as go
|
|
|
|
|
|
|
| 6 |
from torch_geometric.utils import to_dense_batch
|
|
|
|
| 7 |
from model import HierarchicalFPSCliffordNet
|
| 8 |
|
| 9 |
+
# --- PURE-PYTORCH C++ FALLBACKS (custom workaround for Hugging Face Spaces) ---
def knn_pure(x, y, k):
    """Pure-PyTorch replacement for ``torch_cluster.knn(x, y, k)``.

    For every point in *y* find its *k* nearest neighbours in *x*, matching the
    torch_cluster convention: returns a LongTensor of shape [2, num_y * k] with
    row 0 = indices into y (queries) and row 1 = indices into x (neighbours).

    Fix: the previous version searched the opposite direction (k nearest y for
    each x), which can leave some y points with no assigned neighbour — wrong
    for the interpolation/assignment indices this function feeds the model.
    """
    # dist[i, j] = distance from query point y_i to candidate point x_j.
    dist = torch.cdist(y, x)
    _, nn_idx = torch.topk(dist, k, dim=1, largest=False)
    # Row 0: each y index repeated k times; row 1: its k nearest x indices.
    row = torch.arange(y.size(0), device=y.device).view(-1, 1).expand(-1, k).reshape(-1)
    col = nn_idx.reshape(-1)
    return torch.stack([row, col], dim=0)
|
| 16 |
+
|
| 17 |
+
def radius_graph_pure(pos, r, max_num_neighbors=32):
    """Pure-PyTorch replacement for ``torch_cluster.radius_graph``.

    Connects each point to up to ``max_num_neighbors`` other points lying
    strictly within radius ``r``.  Returns a [2, num_edges] LongTensor stacked
    as [source, target].

    Fix: ``torch_cluster.radius_graph`` excludes self-loops by default, but the
    previous version always emitted an (i, i) edge for every point, because a
    point's zero distance to itself is its own nearest neighbour.  The diagonal
    is now masked out so behaviour matches the API being replaced.
    """
    n = pos.size(0)
    dist = torch.cdist(pos, pos)
    # Exclude self-loops: a point must never be its own neighbour.
    dist.fill_diagonal_(float('inf'))
    # At most n - 1 real neighbours exist once self is excluded.
    k = max(0, min(max_num_neighbors, n - 1))
    _, target = torch.topk(dist, k, dim=1, largest=False)
    source = torch.arange(n, device=pos.device).view(-1, 1).expand(-1, k)
    valid = dist[source, target] < r
    return torch.stack([source[valid], target[valid]], dim=0)
|
| 25 |
+
|
| 26 |
+
def fps_pure(pos, ratio=0.5):
    """Pure-PyTorch farthest point sampling (replacement for ``torch_cluster.fps``).

    Iteratively picks the point farthest from the already-selected set,
    starting deterministically from index 0 (torch_cluster defaults to a
    random start; deterministic is preferable for inference).  Returns a
    LongTensor of ``max(1, int(n * ratio))`` selected indices.

    Generalized: works for points of any dimensionality D — the previous
    version hard-coded 3-D via ``.view(1, 3)``.
    """
    n = pos.size(0)
    k = max(1, int(n * ratio))
    idx = torch.zeros(k, dtype=torch.long, device=pos.device)
    # Running minimum distance from every point to the selected set.
    dist = torch.full((n,), 1e10, device=pos.device)
    farthest = 0
    for i in range(k):
        idx[i] = farthest
        farthest_point = pos[farthest].view(1, -1)  # (1, D) for any D
        dist_to_farthest = torch.cdist(pos, farthest_point).squeeze(-1)
        dist = torch.min(dist, dist_to_farthest)
        farthest = torch.argmax(dist).item()
    return idx
|
| 39 |
+
|
| 40 |
# --- Kategori Sözlükleri ---
|
| 41 |
CATEGORIES = {
|
| 42 |
'Airplane': 0, 'Bag': 1, 'Cap': 2, 'Car': 3, 'Chair': 4, 'Earphone': 5, 'Guitar': 6,
|
|
|
|
| 52 |
}
|
| 53 |
|
| 54 |
# --- Modeli Yükleme ---
|
| 55 |
+
device = torch.device("cpu")
|
| 56 |
model = HierarchicalFPSCliffordNet(base_channels=12).to(device)
|
| 57 |
|
| 58 |
try:
|
|
|
|
| 66 |
|
| 67 |
# --- Veri Hazırlama Fonksiyonu ---
|
| 68 |
def process_point_cloud(file_path, category_name):
|
|
|
|
| 69 |
mesh = trimesh.load(file_path, force='mesh')
|
| 70 |
points = np.array(mesh.vertices)
|
| 71 |
|
|
|
|
| 72 |
if len(points) > 1024:
|
| 73 |
idx = np.random.choice(len(points), 1024, replace=False)
|
| 74 |
points = points[idx]
|
|
|
|
| 77 |
points = points[idx]
|
| 78 |
|
| 79 |
pos = torch.tensor(points, dtype=torch.float32)
|
|
|
|
|
|
|
| 80 |
pos = pos - pos.mean(dim=0, keepdim=True)
|
| 81 |
scale = pos.norm(dim=1).max().clamp(min=1e-8)
|
| 82 |
pos = pos / scale
|
|
|
|
| 84 |
batch = torch.zeros(pos.size(0), dtype=torch.long)
|
| 85 |
cat_idx = torch.tensor([CATEGORIES[category_name]], dtype=torch.long)
|
| 86 |
|
| 87 |
+
# Saf PyTorch ile Hızlı Hesaplamalar (C++ Yok)
|
| 88 |
+
fps_idx_1 = fps_pure(pos, ratio=0.5)
|
| 89 |
pos2 = pos[fps_idx_1]
|
| 90 |
batch2 = torch.zeros(pos2.size(0), dtype=torch.long)
|
| 91 |
|
| 92 |
+
fps_idx_2 = fps_pure(pos2, ratio=0.25)
|
| 93 |
pos3 = pos2[fps_idx_2]
|
| 94 |
batch3 = torch.zeros(pos3.size(0), dtype=torch.long)
|
| 95 |
|
| 96 |
+
edge_index_1 = radius_graph_pure(pos, r=0.15, max_num_neighbors=16)
|
| 97 |
+
edge_index_2 = radius_graph_pure(pos2, r=0.30, max_num_neighbors=32)
|
| 98 |
|
| 99 |
+
assign_index_32 = knn_pure(pos3, pos2, k=3)
|
| 100 |
+
assign_index_21 = knn_pure(pos2, pos, k=3)
|
| 101 |
|
| 102 |
dummy_x = torch.zeros(pos3.size(0), 1)
|
| 103 |
_, x_dense_mask = to_dense_batch(dummy_x, batch3, max_num_nodes=int(torch.bincount(batch3).max().item()))
|
| 104 |
|
| 105 |
return pos, batch, cat_idx, fps_idx_1, fps_idx_2, pos2, batch2, pos3, batch3, edge_index_1, edge_index_2, assign_index_32, assign_index_21, x_dense_mask
|
| 106 |
|
| 107 |
+
# --- Inference Function ---
def predict(file, category_name):
    """Segment an uploaded mesh into parts and return a Plotly 3D scatter figure.

    Returns None when no file was uploaded.
    """
    if file is None:
        return None

    # Build every tensor of the FPS hierarchy from the uploaded file.
    # The tuple is ordered exactly as the model's forward signature expects.
    inputs = process_point_cloud(file.name, category_name)
    pos = inputs[0]

    # Forward pass — inference only, no gradients needed.
    with torch.no_grad():
        logits = model(*inputs)

    # Suppress part classes that do not belong to the chosen category.
    allowed = SEG_CLASSES[category_name]
    keep = torch.zeros_like(logits, dtype=torch.bool)
    keep[:, allowed] = True
    logits = logits.masked_fill(~keep, -10000.0)

    # Per-point predicted part label.
    preds = logits.argmax(dim=-1).numpy()

    # Visualize with Plotly: one color per predicted part.
    xyz = pos.numpy()
    scatter = go.Scatter3d(
        x=xyz[:, 0], y=xyz[:, 1], z=xyz[:, 2],
        mode='markers',
        marker=dict(size=4, color=preds, colorscale='Viridis', opacity=0.9,
                    line=dict(width=0.5, color='black')),
    )
    fig = go.Figure(data=[scatter])
    fig.update_layout(
        margin=dict(l=0, r=0, b=0, t=0),
        scene=dict(xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False)),
    )
    return fig
|
| 134 |
|
| 135 |
# --- Gradio Arayüzü ---
|
model.py
CHANGED
|
@@ -3,10 +3,8 @@ import torch.nn as nn
|
|
| 3 |
import torch.nn.functional as F
|
| 4 |
from torch_geometric.nn import MessagePassing
|
| 5 |
from torch_geometric.utils import to_dense_batch
|
| 6 |
-
from torch_cluster import fps, radius_graph, knn
|
| 7 |
import math
|
| 8 |
|
| 9 |
-
# --- TRITON & CPU FALLBACK KONTROLÜ ---
|
| 10 |
try:
|
| 11 |
import triton
|
| 12 |
import triton.language as tl
|
|
@@ -15,7 +13,6 @@ except ImportError:
|
|
| 15 |
HAS_TRITON = False
|
| 16 |
|
| 17 |
def pytorch_clifford_product(a, b):
|
| 18 |
-
"""Hugging Face Free Tier (CPU) için saf PyTorch 8D Geometrik Çarpım"""
|
| 19 |
res = torch.zeros_like(a)
|
| 20 |
res[..., 0] = a[...,0]*b[...,0] + a[...,1]*b[...,1] + a[...,2]*b[...,2] + a[...,3]*b[...,3] - a[...,4]*b[...,4] - a[...,5]*b[...,5] - a[...,6]*b[...,6] - a[...,7]*b[...,7]
|
| 21 |
res[..., 1] = a[...,0]*b[...,1] + a[...,1]*b[...,0] - a[...,2]*b[...,4] + a[...,3]*b[...,6] + a[...,4]*b[...,2] - a[...,5]*b[...,7] - a[...,6]*b[...,3] - a[...,7]*b[...,5]
|
|
@@ -28,13 +25,10 @@ def pytorch_clifford_product(a, b):
|
|
| 28 |
return res
|
| 29 |
|
| 30 |
def smart_clifford_product(a, b):
|
| 31 |
-
# Eğer ortamda GPU ve Triton yoksa, CPU versiyonunu kullan.
|
| 32 |
if HAS_TRITON:
|
| 33 |
-
# Eğitilmiş Triton kodlarınız buraya gelir (Çıkarım için genellikle CPU yeterlidir)
|
| 34 |
pass
|
| 35 |
return pytorch_clifford_product(a, b)
|
| 36 |
|
| 37 |
-
|
| 38 |
class CliffordDiracLayer(MessagePassing):
|
| 39 |
def __init__(self, channels: int):
|
| 40 |
super().__init__(aggr="add", node_dim=0)
|
|
|
|
| 3 |
import torch.nn.functional as F
|
| 4 |
from torch_geometric.nn import MessagePassing
|
| 5 |
from torch_geometric.utils import to_dense_batch
|
|
|
|
| 6 |
import math
|
| 7 |
|
|
|
|
| 8 |
try:
|
| 9 |
import triton
|
| 10 |
import triton.language as tl
|
|
|
|
| 13 |
HAS_TRITON = False
|
| 14 |
|
| 15 |
def pytorch_clifford_product(a, b):
|
|
|
|
| 16 |
res = torch.zeros_like(a)
|
| 17 |
res[..., 0] = a[...,0]*b[...,0] + a[...,1]*b[...,1] + a[...,2]*b[...,2] + a[...,3]*b[...,3] - a[...,4]*b[...,4] - a[...,5]*b[...,5] - a[...,6]*b[...,6] - a[...,7]*b[...,7]
|
| 18 |
res[..., 1] = a[...,0]*b[...,1] + a[...,1]*b[...,0] - a[...,2]*b[...,4] + a[...,3]*b[...,6] + a[...,4]*b[...,2] - a[...,5]*b[...,7] - a[...,6]*b[...,3] - a[...,7]*b[...,5]
|
|
|
|
| 25 |
return res
|
| 26 |
|
| 27 |
def smart_clifford_product(a, b):
    """Dispatch the 8-D Clifford (geometric) product.

    NOTE(review): the Triton branch is an empty placeholder, so the pure
    PyTorch implementation is used in every case — CPU inference is the
    intended deployment target here.
    """
    if HAS_TRITON:
        # Trained Triton kernels would be invoked here; not needed for inference.
        pass
    return pytorch_clifford_product(a, b)
|
| 31 |
|
|
|
|
| 32 |
class CliffordDiracLayer(MessagePassing):
|
| 33 |
def __init__(self, channels: int):
|
| 34 |
super().__init__(aggr="add", node_dim=0)
|
requirements.txt
CHANGED
|
@@ -1,8 +1,5 @@
|
|
| 1 |
torch
|
| 2 |
torch-geometric
|
| 3 |
-
torch-scatter
|
| 4 |
-
torch-sparse
|
| 5 |
-
torch-cluster
|
| 6 |
gradio
|
| 7 |
plotly
|
| 8 |
numpy
|
|
|
|
| 1 |
torch
|
| 2 |
torch-geometric
|
|
|
|
|
|
|
|
|
|
| 3 |
gradio
|
| 4 |
plotly
|
| 5 |
numpy
|