bdck commited on
Commit
859d9ba
·
verified ·
1 Parent(s): 191e774

Upload scripts/train.py

Browse files
Files changed (1) hide show
  1. scripts/train.py +55 -0
scripts/train.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ train.py
3
+ ========
4
+
5
+ Train LrgNet on staged H5 data generated from labeled point clouds.
6
+
7
+ Example:
8
+ python train.py --data_dir staged_h5/ --epochs 50 --batch_size 16 \
9
+ --lr 1e-3 --save_dir checkpoints/ --device cuda
10
+ """
11
+
12
+ import argparse
13
+ import glob
14
+ from pathlib import Path
15
+ from learn_region_grow.train import train_lrgnet
16
+
17
+
18
+ def main():
19
+ parser = argparse.ArgumentParser(description="Train LrgNet on staged H5 data")
20
+ parser.add_argument("--data_dir", required=True, help="Directory with *.h5 staged training files")
21
+ parser.add_argument("--val_split", type=float, default=0.1, help="Fraction of files for validation")
22
+ parser.add_argument("--epochs", type=int, default=50)
23
+ parser.add_argument("--batch_size", type=int, default=16)
24
+ parser.add_argument("--lr", type=float, default=1e-3)
25
+ parser.add_argument("--device", default="cuda")
26
+ parser.add_argument("--lite", type=int, default=0, choices=[0,1,2])
27
+ parser.add_argument("--save_dir", default="checkpoints")
28
+ parser.add_argument("--resume", default=None)
29
+ args = parser.parse_args()
30
+
31
+ h5_files = sorted(glob.glob(str(Path(args.data_dir) / "*.h5")))
32
+ if not h5_files:
33
+ raise FileNotFoundError(f"No H5 files found in {args.data_dir}")
34
+
35
+ split = int(len(h5_files) * (1 - args.val_split))
36
+ train_files = h5_files[:split]
37
+ val_files = h5_files[split:] if args.val_split > 0 else None
38
+
39
+ print(f"Train files: {len(train_files)}, Val files: {len(val_files) if val_files else 0}")
40
+
41
+ model = train_lrgnet(
42
+ train_files=train_files,
43
+ val_files=val_files,
44
+ epochs=args.epochs,
45
+ batch_size=args.batch_size,
46
+ lr=args.lr,
47
+ device=args.device,
48
+ lite=args.lite,
49
+ save_dir=args.save_dir,
50
+ resume=args.resume,
51
+ )
52
+
53
+
54
+ if __name__ == "__main__":
55
+ main()