# File size: 3,423 Bytes
# 3a8ebe3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import argparse
import logging
from pathlib import Path

from huggingface_hub import snapshot_download


# Module-level logger used by every function in this script.
logger = logging.getLogger(__name__)
# Default value of the --synlayers-repo option; the concatenation evaluates
# to "thuteam/CLD".
# NOTE(review): the literal is presumably split on purpose (e.g. to survive a
# project-wide search/replace) — confirm before merging the pieces.
DEFAULT_SYNLAYERS_REPO = "thuteam/" + "C" + "LD"


def parse_args(argv=None):
    """Parse the command-line options of the checkpoint download script.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is what the
            script entry point relies on. Passing an explicit list lets other
            code (or tests) reuse the parser without touching ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``project_root``,
        ``flux_dir``, ``adapter_dir``, ``download_synlayers``,
        ``synlayers_repo``, ``skip_flux`` and ``skip_adapter``.
    """
    parser = argparse.ArgumentParser(description="Download SynLayers pretrained checkpoints.")
    parser.add_argument(
        "--project-root",
        # Defaults to the directory one level above the folder holding this file.
        default=str(Path(__file__).resolve().parents[1]),
        help="SynLayers project root directory.",
    )
    parser.add_argument(
        "--flux-dir",
        # None means "derive the folder from the project root" (done by the caller).
        default=None,
        help="Output directory for FLUX.1-dev weights.",
    )
    parser.add_argument(
        "--adapter-dir",
        # None means "derive the folder from the project root" (done by the caller).
        default=None,
        help="Output directory for FLUX.1-dev ControlNet adapter weights.",
    )
    parser.add_argument(
        "--download-synlayers",
        dest="download_synlayers",
        action="store_true",
        help="Download SynLayers ckpt folder into project root.",
    )
    parser.add_argument(
        "--synlayers-repo",
        dest="synlayers_repo",
        default=DEFAULT_SYNLAYERS_REPO,
        help="Hugging Face repo for SynLayers-compatible checkpoints.",
    )
    parser.add_argument(
        "--skip-flux",
        action="store_true",
        help="Skip downloading FLUX.1-dev weights.",
    )
    parser.add_argument(
        "--skip-adapter",
        action="store_true",
        help="Skip downloading FLUX.1-dev ControlNet adapter weights.",
    )
    return parser.parse_args(argv)


def download_flux(target_dir):
    """Download a snapshot of the FLUX.1-dev repository into ``target_dir``.

    Args:
        target_dir: Destination folder; it is converted with ``str()`` before
            being passed to ``snapshot_download`` as ``local_dir``.
    """
    flux_repo = "black-forest-labs/FLUX.1-dev"
    logger.info("Downloading FLUX.1-dev to %s", target_dir)
    # NOTE(review): newer huggingface_hub releases may treat
    # local_dir_use_symlinks as deprecated — confirm against the pinned version.
    snapshot_download(repo_id=flux_repo, local_dir=str(target_dir), local_dir_use_symlinks=False)


def download_adapter(target_dir):
    """Download a snapshot of the ControlNet inpainting adapter into ``target_dir``.

    Args:
        target_dir: Destination folder; it is converted with ``str()`` before
            being passed to ``snapshot_download`` as ``local_dir``.
    """
    adapter_repo = "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha"
    logger.info("Downloading FLUX.1-dev-Controlnet-Inpainting-Alpha to %s", target_dir)
    # NOTE(review): newer huggingface_hub releases may treat
    # local_dir_use_symlinks as deprecated — confirm against the pinned version.
    snapshot_download(repo_id=adapter_repo, local_dir=str(target_dir), local_dir_use_symlinks=False)


def download_synlayers_ckpt(project_root, repo_id):
    """Download the SynLayers checkpoint folders into ``<project_root>/ckpt``.

    Only the files matching the allow-list below are requested from the repo.

    Args:
        project_root: Project root; joined with ``"ckpt"`` using the ``/``
            operator, so it is expected to be a ``pathlib.Path``-like object.
        repo_id: Hugging Face repo id to download from.
    """
    wanted_patterns = [
        "decouple_LoRA/**",
        "pre_trained_LoRA/**",
        "prism_ft_LoRA/**",
        "trans_vae/**",
        "README.md",
    ]
    ckpt_dir = project_root / "ckpt"
    logger.info("Downloading SynLayers ckpt files into %s", ckpt_dir)
    # NOTE(review): newer huggingface_hub releases may treat
    # local_dir_use_symlinks as deprecated — confirm against the pinned version.
    snapshot_download(
        repo_id=repo_id,
        local_dir=str(ckpt_dir),
        allow_patterns=wanted_patterns,
        local_dir_use_symlinks=False,
    )


def main():
    """Script entry point: parse options, resolve target folders, run downloads.

    FLUX.1-dev and the ControlNet adapter are fetched unless the matching
    ``--skip-*`` flag is set; the SynLayers checkpoints are fetched only when
    ``--download-synlayers`` is set.
    """
    logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
    options = parse_args()
    logger.info("Note: dataset JSONL sizes should be multiples of 8 for inference.")

    root = Path(options.project_root).resolve()
    ckpt_root = root / "ckpt"

    # An explicit --flux-dir / --adapter-dir wins; otherwise use <root>/ckpt/<name>.
    if options.flux_dir:
        flux_target = Path(options.flux_dir)
    else:
        flux_target = ckpt_root / "FLUX.1-dev"

    if options.adapter_dir:
        adapter_target = Path(options.adapter_dir)
    else:
        adapter_target = ckpt_root / "FLUX.1-dev-Controlnet-Inpainting-Alpha"

    # Same order as before: FLUX first, then the adapter.
    steps = (
        (options.skip_flux, flux_target, download_flux),
        (options.skip_adapter, adapter_target, download_adapter),
    )
    for skipped, folder, fetch in steps:
        if skipped:
            continue
        # Create the destination folder before starting the download.
        folder.mkdir(parents=True, exist_ok=True)
        fetch(folder)

    if options.download_synlayers:
        download_synlayers_ckpt(root, options.synlayers_repo)

# Start the downloads only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()