import argparse import logging from pathlib import Path from huggingface_hub import snapshot_download logger = logging.getLogger(__name__) DEFAULT_SYNLAYERS_REPO = "thuteam/" + "C" + "LD" def parse_args(): parser = argparse.ArgumentParser(description="Download SynLayers pretrained checkpoints.") parser.add_argument( "--project-root", default=str(Path(__file__).resolve().parents[1]), help="SynLayers project root directory.", ) parser.add_argument( "--flux-dir", default=None, help="Output directory for FLUX.1-dev weights.", ) parser.add_argument( "--adapter-dir", default=None, help="Output directory for FLUX.1-dev ControlNet adapter weights.", ) parser.add_argument( "--download-synlayers", dest="download_synlayers", action="store_true", help="Download SynLayers ckpt folder into project root.", ) parser.add_argument( "--synlayers-repo", dest="synlayers_repo", default=DEFAULT_SYNLAYERS_REPO, help="Hugging Face repo for SynLayers-compatible checkpoints.", ) parser.add_argument( "--skip-flux", action="store_true", help="Skip downloading FLUX.1-dev weights.", ) parser.add_argument( "--skip-adapter", action="store_true", help="Skip downloading FLUX.1-dev ControlNet adapter weights.", ) return parser.parse_args() def download_flux(target_dir): logger.info("Downloading FLUX.1-dev to %s", target_dir) snapshot_download( repo_id="black-forest-labs/FLUX.1-dev", local_dir=str(target_dir), local_dir_use_symlinks=False, ) def download_adapter(target_dir): logger.info("Downloading FLUX.1-dev-Controlnet-Inpainting-Alpha to %s", target_dir) snapshot_download( repo_id="alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", local_dir=str(target_dir), local_dir_use_symlinks=False, ) def download_synlayers_ckpt(project_root, repo_id): ckpt_dir = project_root / "ckpt" logger.info("Downloading SynLayers ckpt files into %s", ckpt_dir) snapshot_download( repo_id=repo_id, local_dir=str(ckpt_dir), allow_patterns=[ "decouple_LoRA/**", "pre_trained_LoRA/**", "prism_ft_LoRA/**", "trans_vae/**", "README.md", ], local_dir_use_symlinks=False, ) def main(): logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s") args = parse_args() logger.info("Note: dataset JSONL sizes should be multiples of 8 for inference.") project_root = Path(args.project_root).resolve() default_ckpt_root = project_root / "ckpt" flux_dir = Path(args.flux_dir) if args.flux_dir else default_ckpt_root / "FLUX.1-dev" adapter_dir = ( Path(args.adapter_dir) if args.adapter_dir else default_ckpt_root / "FLUX.1-dev-Controlnet-Inpainting-Alpha" ) if not args.skip_flux: flux_dir.mkdir(parents=True, exist_ok=True) download_flux(flux_dir) if not args.skip_adapter: adapter_dir.mkdir(parents=True, exist_ok=True) download_adapter(adapter_dir) if args.download_synlayers: download_synlayers_ckpt(project_root, args.synlayers_repo) if __name__ == "__main__": main()