"""
TinyFlux-Deep Weight Converter: v3 → v4

Converts v3 checkpoints to the v4.1 architecture without destroying pretrained weights.

Changes from v3 → v4:
    - expert_predictor → lune_predictor (rename)
    - expert_gate: raw value → logit space (sigmoid(0)=0.5 preserved)
    - NEW: sol_prior (attention statistics predictor, 70% geometric prior)
    - NEW: t5_pool + text_balance (T5 vec pathway, 50/50 init)
    - NEW: spatial_to_mod per attention layer (zero-init = identity)

All new modules initialize to zero-effect, so the converted model behaves
identically to v3 on the first forward pass.

Colab:
    from convert_v3_to_v4 import run
    run(401434)

API:
    from convert_v3_to_v4 import convert_checkpoint, load_config
    config = load_config("path/to/config.json")
    result = convert_checkpoint(step=401434, config=config)

CLI:
    python convert_v3_to_v4.py --step 401434
    python convert_v3_to_v4.py --step 401434 --config my_config.json
"""
|
|
# Converter release string; run() records it as "converter_version" in the
# conversion info saved next to the converted weights.
__version__: str = "4.1.0"
|
|
import json
import math
import os
import re
import sys
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
|
|
|
|
| |
| |
| |
|
|
@dataclass
class TinyFluxConfig:
    """
    TinyFlux-Deep v4.1 architecture configuration.

    Determines the shapes of the tensors the converter initializes and is
    used to sanity-check a checkpoint (block counts, hidden size). Round-trips
    through JSON via ``to_dict`` / ``from_dict``. Dimension constraints are
    enforced in ``__post_init__``, which raises ``ValueError`` on violation.
    """

    # Transformer geometry; hidden_size must equal num_attention_heads * attention_head_dim.
    hidden_size: int = 512
    num_attention_heads: int = 4
    attention_head_dim: int = 128
    in_channels: int = 16
    patch_size: int = 1
    joint_attention_dim: int = 768
    pooled_projection_dim: int = 768
    num_double_layers: int = 15
    num_single_layers: int = 25
    mlp_ratio: float = 4.0
    # Per-axis RoPE split; must sum to attention_head_dim.
    axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)

    # Lune expert predictor (called expert_predictor in v3 checkpoints).
    use_lune_expert: bool = True
    lune_expert_dim: int = 1280
    lune_hidden_dim: int = 512
    lune_dropout: float = 0.1

    # Sol attention prior; sol_geometric_weight must lie in [0, 1].
    use_sol_prior: bool = True
    sol_spatial_size: int = 8
    sol_hidden_dim: int = 256
    sol_geometric_weight: float = 0.7

    # T5 pooled-vector pathway.
    use_t5_vec: bool = True
    t5_pool_mode: str = "attention"

    # Distillation / loss settings.
    lune_distill_mode: str = "cosine"
    use_huber_loss: bool = True
    huber_delta: float = 0.1

    guidance_embeds: bool = False

    def __post_init__(self):
        """Reject inconsistent dimensions and store ``axes_dims_rope`` as a tuple."""
        heads_times_dim = self.num_attention_heads * self.attention_head_dim
        if heads_times_dim != self.hidden_size:
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must equal "
                f"num_attention_heads * attention_head_dim ({heads_times_dim})"
            )

        # A JSON round-trip turns the tuple into a list; normalise it back.
        if isinstance(self.axes_dims_rope, list):
            self.axes_dims_rope = tuple(self.axes_dims_rope)

        total_rope = sum(self.axes_dims_rope)
        if total_rope != self.attention_head_dim:
            raise ValueError(
                f"sum(axes_dims_rope) ({total_rope}) must equal "
                f"attention_head_dim ({self.attention_head_dim})"
            )

        weight = self.sol_geometric_weight
        # Chained comparison (not `w < 0 or w > 1`) so NaN is rejected too.
        if not 0.0 <= weight <= 1.0:
            raise ValueError(f"sol_geometric_weight must be in [0, 1], got {weight}")

    # ---- Aliases used by the tensor-initialisation helpers ----

    @property
    def time_dim(self) -> int:
        """Timestep-embedding width (same as hidden_size)."""
        return self.hidden_size

    @property
    def clip_dim(self) -> int:
        """Pooled-text width (same as pooled_projection_dim)."""
        return self.pooled_projection_dim

    @property
    def num_heads(self) -> int:
        """Alias of num_attention_heads."""
        return self.num_attention_heads

    @property
    def num_double_blocks(self) -> int:
        """Alias of num_double_layers."""
        return self.num_double_layers

    @property
    def num_single_blocks(self) -> int:
        """Alias of num_single_layers."""
        return self.num_single_layers

    # ---- Serialisation ----

    def to_dict(self) -> Dict:
        """Return a JSON-friendly dict (``axes_dims_rope`` emitted as a list)."""
        payload = asdict(self)
        payload["axes_dims_rope"] = list(payload["axes_dims_rope"])
        return payload

    @classmethod
    def from_dict(cls, d: Dict) -> "TinyFluxConfig":
        """Build a config from ``d``, dropping unknown and ``_``-prefixed keys."""
        accepted = set(cls.__dataclass_fields__)
        kwargs = {}
        for key, value in d.items():
            # Membership first so non-string keys never reach .startswith().
            if key in accepted and not key.startswith("_"):
                kwargs[key] = value
        return cls(**kwargs)

    # ---- Checkpoint compatibility ----

    def validate_checkpoint(self, state_dict: Dict[str, torch.Tensor]) -> List[str]:
        """
        Compare a checkpoint's layout against this config.

        Checks the number of double/single blocks (highest index + 1) and,
        when present, the output width of ``img_embed.proj.weight``.

        Returns:
            Human-readable mismatch messages; empty when everything matches.
        """

        def blocks_spanned(prefix: str) -> int:
            # Highest block index seen under `prefix`, plus one; 0 if none.
            return max([0] + [int(k.split(".")[1]) + 1 for k in state_dict if k.startswith(prefix)])

        problems: List[str] = []

        found_double = blocks_spanned("double_blocks.")
        if found_double != self.num_double_layers:
            problems.append(
                f"double_blocks: checkpoint has {found_double}, config expects {self.num_double_layers}"
            )

        found_single = blocks_spanned("single_blocks.")
        if found_single != self.num_single_layers:
            problems.append(
                f"single_blocks: checkpoint has {found_single}, config expects {self.num_single_layers}"
            )

        if "img_embed.proj.weight" in state_dict:
            width = state_dict["img_embed.proj.weight"].shape[0]
            if width != self.hidden_size:
                problems.append(f"hidden_size: checkpoint has {width}, config expects {self.hidden_size}")

        return problems
|
|
|
|
def load_config(path: Union[str, Path]) -> TinyFluxConfig:
    """
    Load a TinyFluxConfig from a JSON file.

    Unknown keys and keys starting with "_" (such as a "_conversion_info"
    block) are ignored by ``TinyFluxConfig.from_dict``.

    Args:
        path: Path to config JSON file.

    Returns:
        TinyFluxConfig instance.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the file is not valid JSON, its top level is not a
            JSON object, or the resulting config fails TinyFluxConfig's
            dimension checks.
    """
    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    with open(path, encoding="utf-8") as f:
        data = json.load(f)

    if not isinstance(data, dict):
        raise ValueError(
            f"Config file {path} must contain a JSON object, got {type(data).__name__}"
        )
    return TinyFluxConfig.from_dict(data)
|
|
|
|
def save_config(config: TinyFluxConfig, path: Union[str, Path], conversion_info: Optional[Dict] = None):
    """
    Write ``config`` to ``path`` as indented JSON.

    Args:
        config: TinyFluxConfig instance.
        path: Output file path. The parent directory must already exist.
        conversion_info: Optional metadata. When non-empty it is stored under
            the "_conversion_info" key, which ``TinyFluxConfig.from_dict``
            skips when the file is loaded again.
    """
    payload = config.to_dict()
    if conversion_info:
        payload["_conversion_info"] = conversion_info

    # JSON text is UTF-8; write it explicitly rather than in the locale encoding.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)
        f.write("\n")  # end the file with a newline like other text files
|
|
|
|
| |
# Shared fallback used wherever a caller passes no config (`config or DEFAULT_CONFIG`).
# It is an ordinary mutable dataclass instance, so don't modify it in place.
DEFAULT_CONFIG: TinyFluxConfig = TinyFluxConfig()
|
|
|
|
| |
| |
| |
|
|
@dataclass
class CheckpointInfo:
    """Summary that ``analyze_checkpoint`` derives from one state dict's keys."""

    # Version label inferred from which module prefixes are present.
    version: str = "unknown"

    # Presence flags for module prefixes / key substrings.
    has_expert_predictor: bool = False  # v3 name of the predictor
    has_lune_predictor: bool = False    # v4 name of the predictor
    has_sol_prior: bool = False
    has_t5_pool: bool = False
    has_spatial_to_mod: bool = False

    # Block counts, computed as the highest block index seen + 1.
    num_double_blocks: int = 0
    num_single_blocks: int = 0

    # Sum of numel() over every tensor.
    total_params: int = 0

    # dtype name (without the "torch." prefix) of the first tensor.
    dtype: str = "float32"
|
|
|
|
def analyze_checkpoint(state_dict: Dict[str, torch.Tensor]) -> CheckpointInfo:
    """
    Infer a checkpoint's format version from its key names.

    Args:
        state_dict: Mapping of parameter name -> tensor.

    Returns:
        CheckpointInfo with presence flags, block counts (highest index + 1),
        total element count, the first tensor's dtype name, and a version label:
        "v4.1" (lune + sol + t5), "v4.0" (lune + sol), "v3" (expert_predictor),
        "v3.5" (lune only), else "v2_or_earlier".
    """
    info = CheckpointInfo()

    tensors = list(state_dict.values())
    info.total_params = sum(t.numel() for t in tensors)
    if tensors:
        # Only the first tensor's dtype is reported.
        info.dtype = str(tensors[0].dtype).replace("torch.", "")

    prefix_flags = (
        ("expert_predictor.", "has_expert_predictor"),
        ("lune_predictor.", "has_lune_predictor"),
        ("sol_prior.", "has_sol_prior"),
        ("t5_pool.", "has_t5_pool"),
    )

    for key in state_dict:
        for prefix, flag_name in prefix_flags:
            if key.startswith(prefix):
                setattr(info, flag_name, True)
        if "spatial_to_mod" in key:
            info.has_spatial_to_mod = True

        if key.startswith("double_blocks."):
            info.num_double_blocks = max(info.num_double_blocks, int(key.split(".")[1]) + 1)
        elif key.startswith("single_blocks."):
            info.num_single_blocks = max(info.num_single_blocks, int(key.split(".")[1]) + 1)

    lune, sol, t5 = info.has_lune_predictor, info.has_sol_prior, info.has_t5_pool
    # NOTE(review): a checkpoint with lune_predictor + t5_pool but no sol_prior
    # (e.g. converted with use_sol_prior=False) falls through to "v3.5", so
    # convert_state_dict would treat it as unconverted again. Confirm intended.
    if lune and sol:
        info.version = "v4.1" if t5 else "v4.0"
    elif info.has_expert_predictor:
        info.version = "v3"
    elif lune:
        info.version = "v3.5"
    else:
        info.version = "v2_or_earlier"

    return info
|
|
|
|
| |
| |
| |
|
|
@dataclass
class ConversionResult:
    """Outcome of ``convert_checkpoint``: output paths, parameter counts and any error."""

    # True once convert_checkpoint finishes its main work without an uncaught error.
    success: bool

    # Output files; None when the corresponding file was not produced.
    model_path: Optional[str] = None
    ema_path: Optional[str] = None
    ema_secondary_path: Optional[str] = None
    config_path: Optional[str] = None  # filled in by run() after saving the config

    # Version labels and parameter counts.
    source_version: str = "unknown"
    target_version: str = "v4.1"
    source_params: int = 0
    target_params: int = 0
    params_added: int = 0

    # Error message when the conversion failed.
    error: Optional[str] = None
|
|
|
|
| |
| |
| |
|
|
def _resolve_config(config: Optional[Union[TinyFluxConfig, Dict, str, Path]]) -> TinyFluxConfig:
    """Normalise the ``config`` argument accepted by ``run`` into a TinyFluxConfig."""
    if config is None:
        return DEFAULT_CONFIG
    if isinstance(config, TinyFluxConfig):
        return config
    if isinstance(config, dict):
        return TinyFluxConfig.from_dict(config)
    if isinstance(config, (str, Path)):
        return load_config(config)
    raise TypeError(f"config must be TinyFluxConfig, dict, path, or None, got {type(config)}")


def _upload_outputs(api, paths: List[Optional[str]], upload_repo: str, upload_subdir: str) -> None:
    """Upload every existing local file in ``paths`` to ``upload_repo`` under ``upload_subdir``."""
    # Strip slashes so "a/b/" or "" don't produce "a/b//file" or "/file".
    subdir = upload_subdir.strip("/")
    for local_path in paths:
        if not local_path or not os.path.exists(local_path):
            continue
        filename = os.path.basename(local_path)
        remote_path = f"{subdir}/{filename}" if subdir else filename
        api.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=remote_path,
            repo_id=upload_repo,
        )
        print(f"  ✓ {remote_path}")


def run(
    step: int = 401434,
    name: str = "lailah",
    output_dir: str = "checkpoint_runs/v4_init",
    repo_id: str = "AbstractPhil/tiny-flux-deep",
    upload_repo: str = "AbstractPhil/tiny-flux-deep",
    upload_subdir: str = "checkpoint_runs/v4_init",
    config: Optional[Union[TinyFluxConfig, Dict, str]] = None,
):
    """
    One-liner for Colab: download, convert, save locally, then upload to HF.

    Args:
        step: Checkpoint step number to download.
        name: Model name prefix for output files.
        output_dir: Local output directory.
        repo_id: HuggingFace repo to download from.
        upload_repo: HuggingFace repo to upload to.
        upload_subdir: Subdirectory in the upload repo (leading/trailing
            slashes are ignored; empty uploads to the repo root).
        config: Model config - can be:
            - None (use DEFAULT_CONFIG)
            - TinyFluxConfig instance
            - Dict with config values
            - Path (str or pathlib.Path) to a config JSON file

    Returns:
        ConversionResult. On conversion failure it is returned before
        anything is written or uploaded.

    Raises:
        TypeError: If ``config`` has an unsupported type.
        Exceptions from the Hugging Face upload are not caught.

    Usage:
        from convert_v3_to_v4 import run
        run(401434)

        # With custom config (hidden_size must equal
        # num_attention_heads * attention_head_dim, default head dim is 128)
        run(401434, config={"hidden_size": 768, "num_attention_heads": 6})
        run(401434, config="path/to/config.json")
    """
    cfg = _resolve_config(config)

    print("TinyFlux-Deep v3 → v4.1 Converter")
    print("=" * 50)
    print(f"Config: hidden_size={cfg.hidden_size}, heads={cfg.num_attention_heads}")
    print(f"        double_layers={cfg.num_double_layers}, single_layers={cfg.num_single_layers}")

    result = convert_checkpoint(
        step=step,
        model_name=name,
        output_dir=output_dir,
        repo_id=repo_id,
        checkpoint_dir="checkpoints",
        config=cfg,
        verbose=True,
    )

    if not result.success:
        print(f"\n❌ Conversion failed: {result.error}")
        return result

    print("\n✅ Conversion complete!")
    print(f"  Source: {result.source_version} ({result.source_params:,} params)")
    print(f"  Target: {result.target_version} ({result.target_params:,} params)")
    print(f"  Added:  {result.params_added:,} params")

    # Save the config alongside the weights, recording how they were produced.
    config_path = os.path.join(output_dir, f"{name}_{step}_v4_config.json")
    conversion_info = {
        "source_step": step,
        "source_repo": repo_id,
        "source_version": result.source_version,
        "target_version": result.target_version,
        "source_params": result.source_params,
        "target_params": result.target_params,
        "params_added": result.params_added,
        "converter_version": __version__,
        "files": {
            "model": os.path.basename(result.model_path) if result.model_path else None,
            "ema": os.path.basename(result.ema_path) if result.ema_path else None,
            "ema_secondary": os.path.basename(result.ema_secondary_path) if result.ema_secondary_path else None,
        },
    }
    save_config(cfg, config_path, conversion_info)
    result.config_path = config_path
    print(f"💾 Config: {config_path}")

    from huggingface_hub import HfApi

    print(f"\n📤 Uploading to {upload_repo}/{upload_subdir}/...")
    _upload_outputs(
        HfApi(),
        [result.model_path, result.ema_path, result.ema_secondary_path, config_path],
        upload_repo,
        upload_subdir,
    )
    print(f"\n✅ Uploaded to {upload_repo}/{upload_subdir}/")

    return result
|
|
|
|
| |
| |
| |
|
|
def to_logit(p: float, eps: float = 1e-4) -> float:
    """
    Return the logit (inverse sigmoid) of probability ``p``.

    ``p`` is first clamped to ``[eps, 1 - eps]`` so that 0 and 1 map to
    large but finite values instead of failing on log(0) / division by zero.

    Args:
        p: Probability to convert.
        eps: Clamp margin, must satisfy 0 < eps < 0.5 (default 1e-4).

    Returns:
        log(p / (1 - p)) for the clamped ``p``, so sigmoid(result) ≈ p.

    Raises:
        ValueError: If ``eps`` is outside (0, 0.5).
    """
    if not 0.0 < eps < 0.5:
        raise ValueError(f"eps must be in (0, 0.5), got {eps}")
    p = max(eps, min(p, 1.0 - eps))
    return math.log(p / (1.0 - p))
|
|
|
|
def create_sol_prior_init(
    config: TinyFluxConfig,
    dtype: torch.dtype = torch.float32,
) -> Dict[str, torch.Tensor]:
    """
    Build initial tensors for the ``sol_prior`` module of a v4.1 model.

    Linear weights are Xavier-uniform with gain 0.1 (small but non-zero) and
    biases are zero, except:
      * ``stat_to_temperature.2.bias`` is filled with 0.54,
      * ``spatial_to_qk_scale`` has zero weight and unit bias,
      * ``blend_gate`` is ``to_logit(config.sol_geometric_weight)``.

    Weights are drawn from torch's global RNG, so values depend on the
    current seed.

    Args:
        config: Supplies sol_hidden_dim, time_dim, clip_dim, num_heads and
            sol_spatial_size.
        dtype: dtype of every returned tensor.

    Returns:
        Dict mapping ``sol_prior.*`` state-dict keys to tensors.
    """
    hidden = config.sol_hidden_dim
    cond_dim = config.time_dim + config.clip_dim
    heads = config.num_heads
    n_cells = config.sol_spatial_size * config.sol_spatial_size

    tensors: Dict[str, torch.Tensor] = {}

    def add_linear(name: str, out_features: int, in_features: int, bias_value: float = 0.0) -> None:
        # Weight before bias, layer by layer, so RNG draws and key order are fixed.
        weight = torch.empty(out_features, in_features, dtype=dtype)
        nn.init.xavier_uniform_(weight, gain=0.1)
        tensors[f"sol_prior.{name}.weight"] = weight
        tensors[f"sol_prior.{name}.bias"] = torch.full((out_features,), bias_value, dtype=dtype)

    # Two 3-layer heads over the concatenated time + clip conditioning.
    # Keys use indices 0/2/4 (odd indices presumably hold activations — confirm in the model).
    for head_name, final_out in (("stat_predictor", 3), ("spatial_predictor", n_cells)):
        add_linear(f"{head_name}.0", hidden, cond_dim)
        add_linear(f"{head_name}.2", hidden, hidden)
        add_linear(f"{head_name}.4", final_out, hidden)

    add_linear("stat_to_temperature.0", hidden // 2, 3)
    # NOTE(review): 0.54 ≈ ln(e - 1), the softplus pre-image of 1.0; presumably a
    # unit-temperature start if the model applies softplus here — confirm.
    add_linear("stat_to_temperature.2", heads, hidden // 2, bias_value=0.54)

    # Zero weight + unit bias: the scale starts at 1 regardless of input.
    tensors["sol_prior.spatial_to_qk_scale.weight"] = torch.zeros(heads, 1, dtype=dtype)
    tensors["sol_prior.spatial_to_qk_scale.bias"] = torch.ones(heads, dtype=dtype)

    tensors["sol_prior.blend_gate"] = torch.tensor(to_logit(config.sol_geometric_weight), dtype=dtype)

    return tensors
|
|
|
|
def create_t5_pool_init(
    config: TinyFluxConfig,
    dtype: torch.dtype = torch.float32,
) -> Dict[str, torch.Tensor]:
    """
    Build initial tensors for the T5 pooled-vector pathway.

    ``t5_pool.0`` maps joint_attention_dim -> hidden_size and ``t5_pool.2``
    maps hidden_size -> hidden_size. Both use Xavier-uniform weights (default
    gain, global RNG) and zero biases. ``text_balance`` starts at 0.0
    (presumably a logit, giving the 50/50 blend the module docstring
    mentions — confirm against the model).

    Args:
        config: Supplies hidden_size and joint_attention_dim.
        dtype: dtype of every returned tensor.

    Returns:
        Dict mapping state-dict keys to tensors.
    """
    width = config.hidden_size
    t5_dim = config.joint_attention_dim

    def xavier(rows: int, cols: int) -> torch.Tensor:
        weight = torch.empty(rows, cols, dtype=dtype)
        nn.init.xavier_uniform_(weight)
        return weight

    # Dict literals evaluate left to right, so the first layer's weight is drawn first.
    return {
        "t5_pool.0.weight": xavier(width, t5_dim),
        "t5_pool.0.bias": torch.zeros(width, dtype=dtype),
        "t5_pool.2.weight": xavier(width, width),
        "t5_pool.2.bias": torch.zeros(width, dtype=dtype),
        "text_balance": torch.tensor(0.0, dtype=dtype),
    }
|
|
|
|
def create_spatial_to_mod_init(
    num_heads: int = 4,
    dtype: torch.dtype = torch.float32,
) -> Dict[str, torch.Tensor]:
    """
    Return all-zero parameters for a single ``spatial_to_mod`` layer.

    Shapes are weight ``(num_heads, 1, 1, 1)`` and bias ``(num_heads,)``,
    consistent with a 1x1 Conv2d with one input channel. Callers that reuse
    the result for several layers should clone the tensors.
    """
    weight_shape = (num_heads, 1, 1, 1)
    bias_shape = (num_heads,)
    return {
        "weight": torch.zeros(weight_shape, dtype=dtype),
        "bias": torch.zeros(bias_shape, dtype=dtype),
    }
|
|
|
|
def convert_state_dict(
    v3_state: Dict[str, torch.Tensor],
    config: Optional[TinyFluxConfig] = None,
) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
    """
    Convert a v3 state dict to the v4.1 layout.

    Steps:
      1. Rename ``expert_predictor.*`` keys to ``lune_predictor.*``.
      2. Re-express a scalar ``lune_predictor.expert_gate`` in logit space
         (only for values within 0.3 of 0.5; see the note in the body).
      3. Add ``sol_prior.*`` tensors if absent and ``config.use_sol_prior``.
      4. Add ``t5_pool.*`` / ``text_balance`` if absent and ``config.use_t5_vec``.
      5. Add zero ``spatial_to_mod`` tensors to every double/single block if
         absent and ``config.use_sol_prior``.

    Existing tensors are reused, not copied. A checkpoint already detected as
    v4.0 / v4.1 is returned unchanged. Config-validation warnings are printed.

    Args:
        v3_state: v3 state dictionary (must be non-empty).
        config: TinyFluxConfig (uses DEFAULT_CONFIG if None).

    Returns:
        Tuple of (v4_state_dict, report_dict). The report always contains
        'status', 'source_version', 'source_params', 'target_params' and
        'params_added'. Converted checkpoints also list 'renamed',
        'initialized', 'modified' and 'warnings'.

    Raises:
        ValueError: If ``v3_state`` is empty.
    """
    if not v3_state:
        raise ValueError("Cannot convert an empty state dict")

    cfg = config or DEFAULT_CONFIG
    v3_info = analyze_checkpoint(v3_state)

    if v3_info.version in ("v4.0", "v4.1"):
        # Nothing to convert, but report real parameter counts rather than none.
        return v3_state, {
            'status': 'already_v4',
            'source_version': v3_info.version,
            'source_params': v3_info.total_params,
            'target_params': v3_info.total_params,
            'params_added': 0,
        }

    config_warnings = cfg.validate_checkpoint(v3_state)
    if config_warnings:
        print("⚠️ Config validation warnings:")
        for message in config_warnings:
            print(f"  - {message}")

    # Newly created tensors use the dtype of the checkpoint's first tensor.
    dtype = next(iter(v3_state.values())).dtype

    report: Dict[str, Any] = {
        'status': 'converted',
        'source_version': v3_info.version,
        'source_params': v3_info.total_params,
        'renamed': [],
        'initialized': [],
        'modified': [],
        'warnings': config_warnings,
    }

    # 1. Rename expert_predictor.* -> lune_predictor.*.
    v4_state: Dict[str, torch.Tensor] = {}
    for key, value in v3_state.items():
        out_key = key
        if key.startswith('expert_predictor.'):
            out_key = key.replace('expert_predictor.', 'lune_predictor.')
            report['renamed'].append((key, out_key))
        v4_state[out_key] = value

    # 2. Move the expert gate into logit space.
    gate_key = 'lune_predictor.expert_gate'
    gate = v4_state.get(gate_key)
    if gate is not None:
        if gate.numel() != 1:
            report['warnings'].append(
                f"{gate_key}: expected a single value, got shape {tuple(gate.shape)}; left unchanged"
            )
        else:
            old_val = gate.item()
            # NOTE(review): only values in (0.2, 0.8) are treated as raw v3 gate
            # values and converted; anything else is kept as-is. Confirm this
            # heuristic matches how v3 stored the gate.
            if abs(old_val - 0.5) < 0.3:
                new_val = to_logit(old_val)
                # full_like keeps the original shape/dtype so load_state_dict still matches.
                v4_state[gate_key] = torch.full_like(gate, new_val)
                report['modified'].append((gate_key, f'{old_val:.4f} → {new_val:.4f}'))

    # 3. Sol attention prior.
    if not v3_info.has_sol_prior and cfg.use_sol_prior:
        sol_init = create_sol_prior_init(cfg, dtype)
        v4_state.update(sol_init)
        report['initialized'].extend(sol_init)

    # 4. T5 pooled-vector pathway.
    if not v3_info.has_t5_pool and cfg.use_t5_vec:
        t5_init = create_t5_pool_init(cfg, dtype)
        v4_state.update(t5_init)
        report['initialized'].extend(t5_init)

    # 5. Zero spatial_to_mod in every attention layer the config describes.
    if not v3_info.has_spatial_to_mod and cfg.use_sol_prior:
        template = create_spatial_to_mod_init(cfg.num_heads, dtype)
        block_prefixes = [f'double_blocks.{i}' for i in range(cfg.num_double_blocks)]
        block_prefixes += [f'single_blocks.{i}' for i in range(cfg.num_single_blocks)]
        for block in block_prefixes:
            for param_name, tensor in template.items():
                key = f'{block}.attn.spatial_to_mod.{param_name}'
                # Clone per layer so no two entries share storage
                # (safetensors' save_file rejects shared tensors).
                v4_state[key] = tensor.clone()
                report['initialized'].append(key)

    report['target_params'] = sum(t.numel() for t in v4_state.values())
    report['params_added'] = report['target_params'] - report['source_params']

    return v4_state, report
|
|
|
|
| |
| |
| |
|
|
def download_from_hf(
    step: int,
    repo_id: str = "AbstractPhil/tiny-flux-deep",
    checkpoint_dir: str = "checkpoints",
    local_dir: str = "./downloads",
    include_ema: bool = True,
) -> Tuple[str, Optional[str]]:
    """
    Download a checkpoint (and optionally its EMA) from HuggingFace.

    Files are expected at ``<checkpoint_dir>/step_<step>.safetensors`` and
    ``<checkpoint_dir>/step_<step>_ema.safetensors`` in ``repo_id``.

    Args:
        step: Step number to download.
        repo_id: HuggingFace repository ID.
        checkpoint_dir: Subdirectory in repo containing checkpoints.
        local_dir: Local directory to download to.
        include_ema: Whether to also try downloading EMA weights.

    Returns:
        Tuple of (model_path, ema_path). ema_path is None when EMA was not
        requested or could not be downloaded.

    Raises:
        Any error from ``hf_hub_download`` for the main model file. EMA
        download failures are not raised; they emit a ``UserWarning``.
    """
    import warnings

    from huggingface_hub import hf_hub_download

    model_path = hf_hub_download(
        repo_id=repo_id,
        filename=f"{checkpoint_dir}/step_{step}.safetensors",
        local_dir=local_dir,
    )

    ema_path = None
    if include_ema:
        ema_filename = f"{checkpoint_dir}/step_{step}_ema.safetensors"
        try:
            ema_path = hf_hub_download(
                repo_id=repo_id,
                filename=ema_filename,
                local_dir=local_dir,
            )
        except Exception as exc:
            # EMA is optional: keep going, but don't hide why it's missing
            # (missing file vs. auth/network error).
            warnings.warn(
                f"Could not download EMA weights {repo_id}/{ema_filename}: {exc}",
                stacklevel=2,
            )

    return model_path, ema_path
|
|
|
|
def convert_checkpoint(
    step: Optional[int] = None,
    input_path: Optional[str] = None,
    ema_input_path: Optional[str] = None,
    output_dir: str = "checkpoint_runs/v4_init",
    model_name: str = "lailah",
    repo_id: str = "AbstractPhil/tiny-flux-deep",
    checkpoint_dir: str = "checkpoints",
    create_fresh_ema: bool = True,
    preserve_secondary_ema: bool = True,
    config: Optional[TinyFluxConfig] = None,
    verbose: bool = True,
) -> ConversionResult:
    """
    Convert a v3 checkpoint to v4.1 format.

    Either `step` (download from HF) or `input_path` (local file) must be
    provided. If both are given, `step` is used and `input_path` is ignored.
    For local files the step used in output names is parsed from a
    "step_<N>" substring of the path, or 0 if there is none.

    Files written to `output_dir`:
      * <name>_<step>_v4_init.safetensors: converted weights.
      * <name>_<step>_v4_init_ema.safetensors: a copy of the converted
        weights (if create_fresh_ema).
      * <name>_<step>_v4_init_ema_secondary.safetensors: the converted old
        EMA (if preserve_secondary_ema and an EMA file exists).

    Args:
        step: Step number to download from HuggingFace.
        input_path: Path to local v3 checkpoint.
        ema_input_path: Path to local v3 EMA checkpoint.
        output_dir: Directory to save converted checkpoints (created if needed).
        model_name: Prefix for output filenames.
        repo_id: HuggingFace repository ID (if using step).
        checkpoint_dir: Subdirectory in repo (if using step).
        create_fresh_ema: Save the converted weights again as a fresh EMA.
        preserve_secondary_ema: Convert and keep the old EMA as a secondary file.
        config: TinyFluxConfig for model architecture (DEFAULT_CONFIG if None).
        verbose: Print progress messages.

    Returns:
        ConversionResult with paths and statistics. Errors raised while
        downloading, converting or saving are caught and reported through
        ``result.error`` with ``result.success`` False. A failure to convert
        the old EMA is printed but does not fail the conversion.
    """
    from safetensors.torch import load_file, save_file

    cfg = config or DEFAULT_CONFIG
    result = ConversionResult(success=False)

    try:
        # --- Resolve input files ---
        if step is not None:
            if verbose:
                print(f"📥 Downloading step_{step} from {repo_id}...")
            model_path, ema_path = download_from_hf(
                step=step,
                repo_id=repo_id,
                checkpoint_dir=checkpoint_dir,
            )
            if verbose:
                print(f"  ✓ Model: {model_path}")
                if ema_path:
                    print(f"  ✓ EMA: {ema_path}")
        elif input_path is not None:
            model_path = input_path
            ema_path = ema_input_path
            match = re.search(r'step_(\d+)', model_path)
            step = int(match.group(1)) if match else 0
        else:
            result.error = "Must provide either step or input_path"
            return result

        # --- Convert the main weights ---
        if verbose:
            print("\n🔄 Converting to v4.1...")

        v3_state = load_file(model_path)
        v4_state, report = convert_state_dict(v3_state, cfg)

        result.source_version = report['source_version']
        result.target_version = "v4.1"
        result.source_params = report.get('source_params', 0)
        result.target_params = report.get('target_params', 0)
        result.params_added = report.get('params_added', 0)

        if verbose:
            print(f"  Source: {result.source_version} ({result.source_params:,} params)")
            print(f"  Target: {result.target_version} ({result.target_params:,} params)")
            print(f"  Added:  {result.params_added:,} params")

        os.makedirs(output_dir, exist_ok=True)

        model_out = os.path.join(output_dir, f"{model_name}_{step}_v4_init.safetensors")
        save_file(v4_state, model_out)
        result.model_path = model_out
        if verbose:
            print(f"\n💾 Model: {model_out}")

        # A fresh EMA starts as an exact copy of the converted weights.
        if create_fresh_ema:
            ema_out = os.path.join(output_dir, f"{model_name}_{step}_v4_init_ema.safetensors")
            save_file(v4_state, ema_out)
            result.ema_path = ema_out
            if verbose:
                print(f"💾 EMA (fresh): {ema_out}")

        # --- Optionally carry the old EMA over as a secondary file ---
        if preserve_secondary_ema and ema_path:
            if not os.path.exists(ema_path):
                if verbose:
                    print(f"⚠️ EMA file not found, skipping secondary EMA: {ema_path}")
            else:
                if verbose:
                    print("\n🔄 Converting old EMA...")
                try:
                    old_ema_state = load_file(ema_path)
                    old_ema_v4, _ = convert_state_dict(old_ema_state, cfg)
                    ema_secondary_out = os.path.join(
                        output_dir, f"{model_name}_{step}_v4_init_ema_secondary.safetensors"
                    )
                    save_file(old_ema_v4, ema_secondary_out)
                    result.ema_secondary_path = ema_secondary_out
                    if verbose:
                        print(f"💾 EMA (secondary): {ema_secondary_out}")
                except Exception as e:
                    # Best-effort: the main conversion already succeeded.
                    if verbose:
                        print(f"❌ Failed to convert old EMA: {e}")

        result.success = True

    except Exception as e:
        result.error = str(e)
        if verbose:
            print(f"❌ Error: {e}")

    return result
|
|
|
|
| |
| |
| |
|
|
def create_parser():
    """Create the argument parser for the command-line interface."""
    import argparse

    parser = argparse.ArgumentParser(
        description='Convert TinyFlux-Deep v3 checkpoints to v4 format',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python convert_v3_to_v4.py --step 401434
  python convert_v3_to_v4.py --step 401434 --config my_config.json
  python convert_v3_to_v4.py --input model_v3.safetensors
  python convert_v3_to_v4.py --step 401434 --analyze-only
  python convert_v3_to_v4.py --step 401434 --output-dir my_converted --name mymodel
"""
    )

    # Input: --step or --input is required (checked in cli_main).
    input_group = parser.add_argument_group('Input (one required)')
    input_group.add_argument('--step', type=int, help='Step number to download from HuggingFace')
    input_group.add_argument('--input', '-i', dest='input_path', help='Path to local v3 checkpoint')
    input_group.add_argument('--ema-input', dest='ema_input_path', help='Path to local v3 EMA checkpoint')

    hf_group = parser.add_argument_group('HuggingFace options')
    hf_group.add_argument('--repo', default='AbstractPhil/tiny-flux-deep', help='HuggingFace repo ID')
    hf_group.add_argument('--checkpoint-dir', default='checkpoints', help='Subdirectory in repo')

    output_group = parser.add_argument_group('Output options')
    output_group.add_argument('--output-dir', '-o', default='checkpoint_runs/v4_init', help='Output directory')
    output_group.add_argument('--name', default='lailah', help='Model name prefix')

    conv_group = parser.add_argument_group('Conversion options')
    conv_group.add_argument('--config', help='Path to a TinyFluxConfig JSON file (default: built-in config)')
    conv_group.add_argument('--no-fresh-ema', action='store_true', help='Do not create fresh EMA')
    conv_group.add_argument('--no-secondary-ema', action='store_true', help='Do not preserve old EMA')
    conv_group.add_argument('--analyze-only', action='store_true', help='Only analyze, do not convert')
    conv_group.add_argument('--quiet', '-q', action='store_true', help='Suppress progress messages')

    return parser


def cli_main():
    """
    CLI entry point: analyze or convert a checkpoint as selected by argv.

    Exits with status 2 on argument errors (via argparse) and status 1 when
    the conversion fails.
    """
    parser = create_parser()
    args = parser.parse_args()

    # Compare against None so "--step 0" is accepted as a real step.
    if args.step is None and args.input_path is None:
        parser.error("Must specify either --step or --input")

    cfg = None
    if args.config:
        try:
            cfg = load_config(args.config)
        except (OSError, ValueError, TypeError) as exc:
            parser.error(f"Could not load --config {args.config}: {exc}")

    if args.analyze_only:
        from safetensors.torch import load_file

        if args.step is not None:
            # Only the main weights are inspected, so skip the EMA download.
            model_path, _ = download_from_hf(
                step=args.step,
                repo_id=args.repo,
                checkpoint_dir=args.checkpoint_dir,
                include_ema=False,
            )
        else:
            model_path = args.input_path

        state = load_file(model_path)
        info = analyze_checkpoint(state)

        print(f"\nCheckpoint: {model_path}")
        print(f"  Version: {info.version}")
        print(f"  Total params: {info.total_params:,}")
        print(f"  Double blocks: {info.num_double_blocks}")
        print(f"  Single blocks: {info.num_single_blocks}")
        print(f"  Has expert_predictor: {info.has_expert_predictor}")
        print(f"  Has lune_predictor: {info.has_lune_predictor}")
        print(f"  Has sol_prior: {info.has_sol_prior}")
        print(f"  Has t5_pool: {info.has_t5_pool}")
        print(f"  Has spatial_to_mod: {info.has_spatial_to_mod}")
        if cfg is not None:
            for message in cfg.validate_checkpoint(state):
                print(f"  Config warning: {message}")
        return

    result = convert_checkpoint(
        step=args.step,
        input_path=args.input_path,
        ema_input_path=args.ema_input_path,
        output_dir=args.output_dir,
        model_name=args.name,
        repo_id=args.repo,
        checkpoint_dir=args.checkpoint_dir,
        create_fresh_ema=not args.no_fresh_ema,
        preserve_secondary_ema=not args.no_secondary_ema,
        config=cfg,
        verbose=not args.quiet,
    )

    if not result.success:
        print(f"\n❌ Conversion failed: {result.error}")
        # sys.exit, not the site-provided exit(), which is absent under `python -S`.
        sys.exit(1)

    if not args.quiet:
        print("\n" + "=" * 60)
        print("✅ Conversion complete!")
        print("=" * 60)
        print("\nOutput files:")
        if result.model_path:
            print(f"  Model: {result.model_path}")
        if result.ema_path:
            print(f"  EMA: {result.ema_path}")
        if result.ema_secondary_path:
            print(f"  EMA (secondary): {result.ema_secondary_path}")


if __name__ == '__main__':
    cli_main()