SynLayers commited on
Commit
0e78a0f
·
verified ·
1 Parent(s): 5f5cc3f

Upload demo/hf_repo_assets.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. demo/hf_repo_assets.py +16 -5
demo/hf_repo_assets.py CHANGED
@@ -15,6 +15,14 @@ def get_cache_dir() -> str | None:
15
  return os.environ.get("SYNLAYERS_HF_CACHE")
16
 
17
 
 
 
 
 
 
 
 
 
18
  @lru_cache(maxsize=4)
19
  def ensure_repo_assets(repo_id: str | None = None) -> Path | None:
20
  """Download required runtime assets from the uploaded model repo when configured."""
@@ -23,13 +31,14 @@ def ensure_repo_assets(repo_id: str | None = None) -> Path | None:
23
  return None
24
 
25
  allow_patterns = [
26
- "SynLayers_checkpoints/FLUX.1-dev/**",
27
  "SynLayers_checkpoints/FLUX.1-dev-Controlnet-Inpainting-Alpha/**",
28
  "SynLayers_ckpt/step_120000/**",
29
  "ckpt/trans_vae/0008000.pt",
30
  "ckpt/pre_trained_LoRA/**",
31
  "ckpt/prism_ft_LoRA/**",
32
  ]
 
 
33
 
34
  local_root = snapshot_download(
35
  repo_id=resolved_repo_id,
@@ -46,12 +55,9 @@ def build_repo_asset_overrides(repo_id: str | None = None) -> dict[str, str]:
46
  if local_root is None:
47
  return {}
48
 
49
- return {
50
  "repo_root": str(local_root),
51
  "decomp_ckpt_root": str(local_root / "SynLayers_ckpt" / "step_120000"),
52
- "pretrained_model_name_or_path": str(
53
- local_root / "SynLayers_checkpoints" / "FLUX.1-dev"
54
- ),
55
  "pretrained_adapter_path": str(
56
  local_root
57
  / "SynLayers_checkpoints"
@@ -61,3 +67,8 @@ def build_repo_asset_overrides(repo_id: str | None = None) -> dict[str, str]:
61
  "pretrained_lora_dir": str(local_root / "ckpt" / "pre_trained_LoRA"),
62
  "artplus_lora_dir": str(local_root / "ckpt" / "prism_ft_LoRA"),
63
  }
 
 
 
 
 
 
15
  return os.environ.get("SYNLAYERS_HF_CACHE")
16
 
17
 
18
+ def should_download_repo_flux() -> bool:
19
+ """Download FLUX from the model repo only when no external base model is configured."""
20
+ base_model = os.environ.get("SYNLAYERS_BASE_MODEL", "").strip()
21
+ if not base_model:
22
+ return True
23
+ return base_model.startswith(get_model_repo_id() or "")
24
+
25
+
26
  @lru_cache(maxsize=4)
27
  def ensure_repo_assets(repo_id: str | None = None) -> Path | None:
28
  """Download required runtime assets from the uploaded model repo when configured."""
 
31
  return None
32
 
33
  allow_patterns = [
 
34
  "SynLayers_checkpoints/FLUX.1-dev-Controlnet-Inpainting-Alpha/**",
35
  "SynLayers_ckpt/step_120000/**",
36
  "ckpt/trans_vae/0008000.pt",
37
  "ckpt/pre_trained_LoRA/**",
38
  "ckpt/prism_ft_LoRA/**",
39
  ]
40
+ if should_download_repo_flux():
41
+ allow_patterns.insert(0, "SynLayers_checkpoints/FLUX.1-dev/**")
42
 
43
  local_root = snapshot_download(
44
  repo_id=resolved_repo_id,
 
55
  if local_root is None:
56
  return {}
57
 
58
+ overrides = {
59
  "repo_root": str(local_root),
60
  "decomp_ckpt_root": str(local_root / "SynLayers_ckpt" / "step_120000"),
 
 
 
61
  "pretrained_adapter_path": str(
62
  local_root
63
  / "SynLayers_checkpoints"
 
67
  "pretrained_lora_dir": str(local_root / "ckpt" / "pre_trained_LoRA"),
68
  "artplus_lora_dir": str(local_root / "ckpt" / "prism_ft_LoRA"),
69
  }
70
+ if should_download_repo_flux():
71
+ overrides["pretrained_model_name_or_path"] = str(
72
+ local_root / "SynLayers_checkpoints" / "FLUX.1-dev"
73
+ )
74
+ return overrides