| """ |
| train.py |
| ======== |
| |
| Train LrgNet on staged H5 data generated from labeled point clouds. |
| |
| Example: |
| python train.py --data_dir staged_h5/ --epochs 50 --batch_size 16 \ |
| --lr 1e-3 --save_dir checkpoints/ --device cuda |
| """ |
|
|
| import argparse |
| import glob |
| from pathlib import Path |
| from learn_region_grow.train import train_lrgnet |
|
|
|
|
def main():
    """Parse CLI args, split staged H5 files into train/val sets, and train LrgNet.

    Raises:
        ValueError: if --val_split is outside [0, 1) or the split leaves no
            training files.
        FileNotFoundError: if no *.h5 files exist in --data_dir.
    """
    parser = argparse.ArgumentParser(description="Train LrgNet on staged H5 data")
    parser.add_argument("--data_dir", required=True, help="Directory with *.h5 staged training files")
    parser.add_argument("--val_split", type=float, default=0.1, help="Fraction of files for validation")
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--lite", type=int, default=0, choices=[0, 1, 2])
    parser.add_argument("--save_dir", default="checkpoints")
    parser.add_argument("--resume", default=None)
    args = parser.parse_args()

    # Fail early on a nonsensical fraction instead of silently slicing garbage.
    if not 0.0 <= args.val_split < 1.0:
        raise ValueError(f"--val_split must be in [0, 1), got {args.val_split}")

    # sorted() keeps the train/val split deterministic across runs and machines.
    h5_files = sorted(glob.glob(str(Path(args.data_dir) / "*.h5")))
    if not h5_files:
        raise FileNotFoundError(f"No H5 files found in {args.data_dir}")

    split = int(len(h5_files) * (1 - args.val_split))
    # With few files and a large val_split the train side can round down to zero.
    if split == 0:
        raise ValueError(
            f"--val_split={args.val_split} leaves no training files "
            f"({len(h5_files)} file(s) found in {args.data_dir})"
        )
    train_files = h5_files[:split]
    # Normalize an empty validation slice to None so downstream code sees a
    # consistent "no validation" signal rather than an empty list.
    val_files = (h5_files[split:] or None) if args.val_split > 0 else None

    print(f"Train files: {len(train_files)}, Val files: {len(val_files) if val_files else 0}")

    # Return value intentionally discarded: this is a CLI entry point; the
    # trained model is persisted to --save_dir by train_lrgnet itself.
    train_lrgnet(
        train_files=train_files,
        val_files=val_files,
        epochs=args.epochs,
        batch_size=args.batch_size,
        lr=args.lr,
        device=args.device,
        lite=args.lite,
        save_dir=args.save_dir,
        resume=args.resume,
    )
|
|
|
|
# Run the CLI entry point only when executed as a script (not on import).
if __name__ == "__main__":
    main()
|
|