bdck commited on
Commit
3fe5c47
·
verified ·
1 Parent(s): e3423d1

Upload learn_region_grow/train.py

Browse files
Files changed (1) hide show
  1. learn_region_grow/train.py +175 -0
learn_region_grow/train.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Training script for LrgNet."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.optim as optim
6
+ from torch.utils.data import Dataset, DataLoader
7
+ import numpy as np
8
+ import h5py
9
+ from pathlib import Path
10
+ from typing import List, Optional
11
+ from .lrg_net import LrgNet
12
+
13
+
14
+ class StagedDataset(Dataset):
15
+ """PyTorch Dataset wrapping H5 staged training files."""
16
+
17
+ def __init__(self, h5_paths: List[str]):
18
+ self.h5_paths = h5_paths
19
+ self.offsets = []
20
+ self.total = 0
21
+ for p in h5_paths:
22
+ with h5py.File(p, 'r') as f:
23
+ n = f['inliers'].shape[0]
24
+ self.offsets.append((self.total, self.total + n, p))
25
+ self.total += n
26
+
27
+ def __len__(self):
28
+ return self.total
29
+
30
+ def __getitem__(self, idx):
31
+ # Find which file
32
+ for start, end, path in self.offsets:
33
+ if start <= idx < end:
34
+ local_idx = idx - start
35
+ break
36
+ else:
37
+ raise IndexError(idx)
38
+
39
+ with h5py.File(path, 'r') as f:
40
+ inlier = f['inliers'][local_idx] # (Ni, 13)
41
+ neighbor = f['neighbors'][local_idx] # (Nn, 13)
42
+ add_lbl = f['add_labels'][local_idx] # (Nn,)
43
+ rmv_lbl = f['remove_labels'][local_idx] # (Ni,)
44
+
45
+ # Transpose to (C, N) for Conv1d
46
+ inlier = torch.from_numpy(inlier.T).float() # (13, Ni)
47
+ neighbor = torch.from_numpy(neighbor.T).float() # (13, Nn)
48
+ add_lbl = torch.from_numpy(add_lbl).long() # (Nn,)
49
+ rmv_lbl = torch.from_numpy(rmv_lbl).long() # (Ni,)
50
+
51
+ return inlier, neighbor, add_lbl, rmv_lbl
52
+
53
+
54
+ class AddRemoveLoss(nn.Module):
55
+ """Joint cross-entropy over add + remove logits."""
56
+
57
+ def __init__(self, add_weight: float = 1.0, remove_weight: float = 1.0):
58
+ super().__init__()
59
+ self.add_weight = add_weight
60
+ self.remove_weight = remove_weight
61
+ self.ce = nn.CrossEntropyLoss(reduction='none')
62
+
63
+ def forward(self, add_logits, add_targets, remove_logits, remove_targets):
64
+ # add_logits: (B, 1, Nn) -> treat as binary classification
65
+ # PyTorch cross_entropy expects (B, C, ...); here C=1 is tricky for sigmoid
66
+ # Simpler: use BCEWithLogitsLoss
67
+ pass # placeholder -- we use BCE in trainer below
68
+
69
+
70
+ def train_lrgnet(train_files: List[str],
71
+ val_files: Optional[List[str]] = None,
72
+ epochs: int = 50,
73
+ batch_size: int = 16,
74
+ lr: float = 1e-3,
75
+ device: str = 'cuda',
76
+ lite: int = 0,
77
+ save_dir: str = './checkpoints',
78
+ resume: Optional[str] = None):
79
+ """
80
+ Train LrgNet on staged H5 files.
81
+
82
+ Parameters
83
+ ----------
84
+ train_files : list of str
85
+ Paths to staged H5 files.
86
+ val_files : list of str, optional
87
+ Validation H5 files.
88
+ epochs : int
89
+ batch_size : int
90
+ lr : float
91
+ Adam learning rate (default 1e-3, matching the paper).
92
+ device : str
93
+ lite : int
94
+ 0 = full, 1 = half channels, 2 = quarter channels.
95
+ save_dir : str
96
+ Where to write checkpoints.
97
+ resume : str, optional
98
+ Path to a checkpoint to resume from.
99
+ """
100
+ device = torch.device(device if torch.cuda.is_available() else 'cpu')
101
+
102
+ train_ds = StagedDataset(train_files)
103
+ train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
104
+ num_workers=4, pin_memory=True, drop_last=True)
105
+
106
+ val_loader = None
107
+ if val_files:
108
+ val_ds = StagedDataset(val_files)
109
+ val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
110
+ num_workers=4, pin_memory=True)
111
+
112
+ model = LrgNet(in_channels=13, lite=lite).to(device)
113
+ if resume:
114
+ model.load_state_dict(torch.load(resume, map_location=device))
115
+
116
+ optimizer = optim.Adam(model.parameters(), lr=lr)
117
+ scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
118
+
119
+ bce_add = nn.BCEWithLogitsLoss()
120
+ bce_remove = nn.BCEWithLogitsLoss()
121
+
122
+ save_dir = Path(save_dir)
123
+ save_dir.mkdir(parents=True, exist_ok=True)
124
+
125
+ best_val_loss = float('inf')
126
+
127
+ for epoch in range(1, epochs + 1):
128
+ model.train()
129
+ total_loss = 0.0
130
+ n_batches = 0
131
+
132
+ for inliers, neighbors, add_lbl, rmv_lbl in train_loader:
133
+ inliers = inliers.to(device)
134
+ neighbors = neighbors.to(device)
135
+ add_lbl = add_lbl.to(device).float().unsqueeze(1) # (B, 1, Nn)
136
+ rmv_lbl = rmv_lbl.to(device).float().unsqueeze(1) # (B, 1, Ni)
137
+
138
+ optimizer.zero_grad()
139
+ add_logits, rmv_logits = model(inliers, neighbors)
140
+
141
+ # For binary BCE, logits shape is (B, 1, N). Targets same shape.
142
+ loss = bce_add(add_logits, add_lbl) + bce_remove(rmv_logits, rmv_lbl)
143
+ loss.backward()
144
+ optimizer.step()
145
+
146
+ total_loss += loss.item()
147
+ n_batches += 1
148
+
149
+ avg_train = total_loss / max(n_batches, 1)
150
+
151
+ val_str = ""
152
+ if val_loader:
153
+ model.eval()
154
+ val_loss = 0.0
155
+ with torch.no_grad():
156
+ for inliers, neighbors, add_lbl, rmv_lbl in val_loader:
157
+ inliers = inliers.to(device)
158
+ neighbors = neighbors.to(device)
159
+ add_lbl = add_lbl.to(device).float().unsqueeze(1)
160
+ rmv_lbl = rmv_lbl.to(device).float().unsqueeze(1)
161
+ add_logits, rmv_logits = model(inliers, neighbors)
162
+ vloss = bce_add(add_logits, add_lbl) + bce_remove(rmv_logits, rmv_lbl)
163
+ val_loss += vloss.item()
164
+ avg_val = val_loss / len(val_loader)
165
+ val_str = f" | val_loss={avg_val:.4f}"
166
+ if avg_val < best_val_loss:
167
+ best_val_loss = avg_val
168
+ torch.save(model.state_dict(), save_dir / 'best_model.pt')
169
+
170
+ scheduler.step()
171
+ print(f"Epoch {epoch}/{epochs} train_loss={avg_train:.4f}{val_str}")
172
+ torch.save(model.state_dict(), save_dir / f'epoch_{epoch:03d}.pt')
173
+
174
+ print(f"Training complete. Checkpoints saved to {save_dir}")
175
+ return model