File size: 1,842 Bytes
859d9ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
"""
train.py
========

Train LrgNet on staged H5 data generated from labeled point clouds.

Example:
    python train.py --data_dir staged_h5/ --epochs 50 --batch_size 16 \
                    --lr 1e-3 --save_dir checkpoints/ --device cuda
"""

import argparse
import glob
from pathlib import Path
from learn_region_grow.train import train_lrgnet


def _split_train_val(h5_files, val_split):
    """Split an ordered file list into training and validation subsets.

    The split is deterministic (no shuffling): the trailing ``val_split``
    fraction of ``h5_files``, in the order given, is held out for validation.

    Args:
        h5_files: File paths, already in the desired (e.g. sorted) order.
        val_split: Fraction of files to hold out; the caller validates that it
            lies in ``[0, 1)``.

    Returns:
        ``(train_files, val_files)``. ``val_files`` is ``None`` when no files
        are held out.
    """
    if val_split <= 0:
        return list(h5_files), None

    # Same truncating split point as before, but clamped to >= 1 so that a
    # small file count (e.g. one file with val_split=0.1, where
    # int(1 * 0.9) == 0) never leaves the training set empty.
    split = max(1, int(len(h5_files) * (1 - val_split)))
    train_files = h5_files[:split]
    val_files = h5_files[split:]
    # An empty hold-out is reported as "no validation" rather than [].
    return train_files, (val_files or None)


def main(argv=None):
    """Parse CLI arguments, split the staged H5 files, and run ``train_lrgnet``.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps normal command-line behaviour.

    Raises:
        SystemExit: If the command-line arguments are invalid (via argparse).
        FileNotFoundError: If ``--data_dir`` contains no ``*.h5`` files.
    """
    parser = argparse.ArgumentParser(description="Train LrgNet on staged H5 data")
    parser.add_argument("--data_dir", required=True, help="Directory with *.h5 staged training files")
    parser.add_argument("--val_split", type=float, default=0.1, help="Fraction of files for validation")
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--lite", type=int, default=0, choices=[0,1,2])
    parser.add_argument("--save_dir", default="checkpoints")
    parser.add_argument("--resume", default=None)
    args = parser.parse_args(argv)

    if not 0.0 <= args.val_split < 1.0:
        # 1.0 (or more) would leave no training files; negatives are meaningless.
        parser.error("--val_split must be in the range [0, 1)")

    # Escape the directory so characters like '[', '*' or '?' in its name are
    # matched literally; only the "*.h5" part is meant as a wildcard.
    pattern = str(Path(glob.escape(args.data_dir)) / "*.h5")
    h5_files = sorted(glob.glob(pattern))
    if not h5_files:
        raise FileNotFoundError(f"No H5 files found in {args.data_dir}")

    train_files, val_files = _split_train_val(h5_files, args.val_split)
    if args.val_split > 0 and val_files is None:
        print("Warning: too few H5 files to hold out a validation set; "
              "training without validation.")

    print(f"Train files: {len(train_files)}, Val files: {len(val_files) if val_files else 0}")

    # train_lrgnet's return value is not used by this script.
    train_lrgnet(
        train_files=train_files,
        val_files=val_files,
        epochs=args.epochs,
        batch_size=args.batch_size,
        lr=args.lr,
        device=args.device,
        lite=args.lite,
        save_dir=args.save_dir,
        resume=args.resume,
    )


if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not on import.
    main()