"""
train.py
========
Train LrgNet on staged H5 data generated from labeled point clouds.

Example:
    python train.py --data_dir staged_h5/ --epochs 50 --batch_size 16 \\
        --lr 1e-3 --save_dir checkpoints/ --device cuda
"""
import argparse
import glob
from pathlib import Path


def split_train_val(files, val_split):
    """Split an ordered list of files into training and validation subsets.

    The trailing ``val_split`` fraction of ``files`` is held out for
    validation (no shuffling, so the split is deterministic for a given
    ordering). At least one file is always kept for training when ``files``
    is non-empty.

    Args:
        files: Ordered list of file paths.
        val_split: Fraction of files to hold out, in the range [0, 1).

    Returns:
        ``(train_files, val_files)``. ``val_files`` is ``None`` when no file
        is held out.

    Raises:
        ValueError: If ``val_split`` is outside [0, 1).
    """
    if not 0.0 <= val_split < 1.0:
        raise ValueError(f"val_split must be in [0, 1), got {val_split}")
    split = int(len(files) * (1 - val_split))
    # With very few files the floor above can leave nothing to train on
    # (e.g. 1 file with val_split=0.1 -> split == 0); keep one training file.
    split = max(split, min(1, len(files)))
    train_files = files[:split]
    val_files = files[split:]
    # Normalise an empty hold-out to None so callers see a single "no
    # validation" value.
    return train_files, (val_files or None)


def main():
    """Parse CLI arguments, split the staged H5 files and run training."""
    parser = argparse.ArgumentParser(description="Train LrgNet on staged H5 data")
    parser.add_argument("--data_dir", required=True,
                        help="Directory with *.h5 staged training files")
    parser.add_argument("--val_split", type=float, default=0.1,
                        help="Fraction of files for validation, in [0, 1)")
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--lite", type=int, default=0, choices=[0, 1, 2])
    parser.add_argument("--save_dir", default="checkpoints")
    parser.add_argument("--resume", default=None)
    args = parser.parse_args()

    # Sorted so the (trailing) validation subset is reproducible across runs.
    h5_files = sorted(glob.glob(str(Path(args.data_dir) / "*.h5")))
    if not h5_files:
        raise FileNotFoundError(f"No H5 files found in {args.data_dir}")

    try:
        train_files, val_files = split_train_val(h5_files, args.val_split)
    except ValueError as err:
        parser.error(str(err))  # prints usage and exits with status 2

    print(f"Train files: {len(train_files)}, Val files: {len(val_files) if val_files else 0}")

    # Imported lazily so `--help` and argument errors don't pay the cost of
    # loading the training package.
    from learn_region_grow.train import train_lrgnet

    train_lrgnet(
        train_files=train_files,
        val_files=val_files,
        epochs=args.epochs,
        batch_size=args.batch_size,
        lr=args.lr,
        device=args.device,
        lite=args.lite,
        save_dir=args.save_dir,
        resume=args.resume,
    )


if __name__ == "__main__":
    main()