# Source: jsflow / REG / train.py
# Uploaded by xiangzai using the upload-large-folder tool (commit b65e56d, verified).
import argparse
import copy
from copy import deepcopy
import logging
import os
from pathlib import Path
from collections import OrderedDict
import json
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from models.sit import SiT_models
from loss import SILoss
from utils import load_encoders
from dataset import CustomDataset
from diffusers.models import AutoencoderKL
# import wandb_utils
import wandb
import math
from torchvision.utils import make_grid
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
from PIL import Image
# Module-level logger from accelerate's logging utilities. `main` binds its own
# local `logger` (a file-backed one on the main process).
logger = get_logger(name=__name__)
def semantic_dim_from_enc_type(enc_type):
    """Infer the class-token width from an encoder-type string (e.g. DINOv2).

    The result is meant to match the dimensionality of the preprocessed
    semantic features. Matching is a case-insensitive substring test on the
    ViT size tag; ``None`` or an unrecognized tag falls back to 768.
    """
    if enc_type is None:
        return 768
    lowered = str(enc_type).lower()
    # Checked in this order; the first matching size tag wins.
    size_table = (
        (("vit-g", "vitg"), 1536),
        (("vit-l", "vitl"), 1024),
        (("vit-s", "vits"), 384),
    )
    for tags, width in size_table:
        if any(tag in lowered for tag in tags):
            return width
    return 768
# Per-channel RGB normalization statistics applied to CLIP encoder inputs
# (used by preprocess_raw_image).
CLIP_DEFAULT_MEAN = (
    0.48145466,
    0.4578275,
    0.40821073,
)
CLIP_DEFAULT_STD = (
    0.26862954,
    0.26130258,
    0.27577711,
)
def preprocess_raw_image(x, enc_type):
    """Turn raw 0-255 pixel tensors into the input expected by an encoder.

    The branch is chosen by substring match on ``enc_type``; the first match
    wins, in the order clip, mocov3/mae, dinov2, dinov1, jepa. Every matched
    branch divides by 255 and normalizes. clip, dinov2 and jepa additionally
    resize bicubically to ``224 * (resolution // 256)``, where ``resolution``
    is the last dimension of ``x``. clip resizes before normalizing; the
    others resize after. An unrecognized ``enc_type`` returns ``x`` unchanged.
    """
    target_size = 224 * (x.shape[-1] // 256)

    def _resize(img):
        return torch.nn.functional.interpolate(img, size=target_size, mode='bicubic')

    if 'clip' in enc_type:
        clip_norm = Normalize(CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD)
        return clip_norm(_resize(x / 255.))

    # (substrings to look for, resize after normalization?) -- order matters.
    imagenet_rules = (
        (('mocov3', 'mae'), False),
        (('dinov2',), True),
        (('dinov1',), False),
        (('jepa',), True),
    )
    for keys, needs_resize in imagenet_rules:
        if any(key in enc_type for key in keys):
            normalized = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x / 255.)
            return _resize(normalized) if needs_resize else normalized
    return x
def array2grid(x):
    """Tile a batch of images with values in [0, 1] into one uint8 HWC array.

    The grid is roughly square: ``round(sqrt(batch))`` images per row.
    """
    per_row = round(math.sqrt(x.size(0)))
    grid = make_grid(x.clamp(0, 1), nrow=per_row, value_range=(0, 1))
    # Scale to [0, 255], add 0.5 so the uint8 cast rounds, then go CHW -> HWC.
    grid = grid.mul(255).add_(0.5).clamp_(0, 255)
    return grid.permute(1, 2, 0).to('cpu', torch.uint8).numpy()
@torch.no_grad()
def sample_posterior(moments, latents_scale=1., latents_bias=0.):
    """Draw a latent sample from stored VAE posterior moments and rescale it.

    ``moments`` is split in half along dim 1. The first half is the mean and
    the second half is used directly as the standard deviation (not a
    log-variance). The sample ``mean + std * eps`` (``eps ~ N(0, 1)``) is then
    mapped to ``z * latents_scale + latents_bias``.

    Args:
        moments: tensor whose dim 1 holds [mean, std] concatenated.
        latents_scale: multiplicative scale (scalar or broadcastable tensor).
        latents_bias: additive bias (scalar or broadcastable tensor).

    Returns:
        The scaled latent sample, with the shape of one half of ``moments``.
    """
    mean, std = torch.chunk(moments, 2, dim=1)
    z = mean + std * torch.randn_like(mean)
    return z * latents_scale + latents_bias
@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Step the EMA model towards the current model.

    Each EMA parameter is updated in place as
    ``ema = decay * ema + (1 - decay) * param``; ``decay=0`` copies ``model``'s
    weights into ``ema_model``. ``model`` may be DDP-wrapped: a leading
    ``"module."`` prefix on its parameter names is stripped so they match the
    unwrapped EMA copy. Only parameters are updated, not buffers.
    """
    ema_params = dict(ema_model.named_parameters())
    for name, param in model.named_parameters():
        # Strip only the wrapper prefix. str.replace would also mangle names that
        # merely contain "module." (e.g. "submodule.weight" -> "subweight").
        name = name.removeprefix("module.")
        # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
        ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
def create_logger(logging_dir):
    """
    Configure root logging to stdout plus ``<logging_dir>/log.txt``.

    Returns the logger named after this module. The file handler is created
    (and the file opened) eagerly. ``logging.basicConfig`` only installs the
    handlers if the root logger has none yet.
    """
    log_format = '[\033[34m%(asctime)s\033[0m] %(message)s'
    handlers = [
        logging.StreamHandler(),
        logging.FileHandler(f"{logging_dir}/log.txt"),
    ]
    logging.basicConfig(
        handlers=handlers,
        format=log_format,
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO,
    )
    return logging.getLogger(__name__)
def requires_grad(model, flag=True):
    """
    Enable or disable gradient tracking for every parameter of ``model``.

    Modifies the parameters in place (e.g. ``flag=False`` freezes an EMA copy).
    """
    for _name, tensor in model.named_parameters():
        tensor.requires_grad = flag
#################################################################################
# Training Loop #
#################################################################################
def _load_trusted_checkpoint(path):
    """Load a checkpoint written by this script onto the CPU.

    Checkpoints saved by `main` store ``args`` as an ``argparse.Namespace``.
    PyTorch >= 2.6 defaults ``torch.load`` to ``weights_only=True``, which
    rejects that object, so full unpickling is requested explicitly.
    Only use this on checkpoints you trust.
    """
    try:
        return torch.load(path, map_location="cpu", weights_only=False)
    except TypeError:
        # Older PyTorch releases (< 1.13) have no `weights_only` argument.
        return torch.load(path, map_location="cpu")


def _load_sd_vae(device, accelerator, logger):
    """Load the 'stabilityai/sd-vae-ft-mse' VAE from local files only.

    First tries the diffusers cache directory from
    ``preprocessing.dnnlib.make_cache_dir_path("diffusers")``. If that fails,
    it walks a few common cache roots looking for a directory whose path
    contains ``sd-vae-ft-mse`` and holds a ``config.json``. Failures are
    logged on the main process and yield ``None``; intermediate samples are
    then not decoded to images.

    Returns:
        The VAE on ``device`` in eval mode, or ``None``.
    """
    vae = None
    try:
        from preprocessing import dnnlib
        cache_dir = dnnlib.make_cache_dir_path("diffusers")
        os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
        os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
        # NOTE(review): huggingface_hub may read HF_HOME at import time, so this
        # assignment might not affect the already-imported library -- verify.
        os.environ["HF_HOME"] = cache_dir
        try:
            vae = AutoencoderKL.from_pretrained(
                "stabilityai/sd-vae-ft-mse",
                cache_dir=cache_dir,
                local_files_only=True,
            ).to(device)
            vae.eval()
            if accelerator.is_main_process:
                logger.info(
                    "Loaded VAE 'stabilityai/sd-vae-ft-mse' from local diffusers cache "
                    f"at '{cache_dir}' for intermediate sampling."
                )
        except Exception as e_main:
            vae = None
            candidate_dir = None
            possible_roots = [
                cache_dir,
                os.path.join(os.path.expanduser("~"), ".cache", "dnnlib", "diffusers"),
                os.path.join(os.path.expanduser("~"), ".cache", "diffusers"),
                os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub"),
            ]
            checked_roots = []
            for root_dir in possible_roots:
                if not os.path.isdir(root_dir):
                    continue
                checked_roots.append(root_dir)
                for root, _dirs, files in os.walk(root_dir):
                    if "config.json" in files and "sd-vae-ft-mse" in root.replace("\\", "/"):
                        candidate_dir = root
                        break
                if candidate_dir is not None:
                    break
            if candidate_dir is not None:
                try:
                    vae = AutoencoderKL.from_pretrained(
                        candidate_dir,
                        local_files_only=True,
                    ).to(device)
                    vae.eval()
                    if accelerator.is_main_process:
                        logger.info(
                            "Loaded VAE 'stabilityai/sd-vae-ft-mse' from discovered local path "
                            f"'{candidate_dir}'. Searched roots: {checked_roots}"
                        )
                except Exception as e_fallback:
                    if accelerator.is_main_process:
                        logger.warning(
                            "Tried to load VAE from discovered local path "
                            f"'{candidate_dir}' but failed: {e_fallback}"
                        )
            if vae is None and accelerator.is_main_process:
                logger.warning(
                    "Could not load VAE 'stabilityai/sd-vae-ft-mse' via repo name or local search. "
                    f"Last repo-level error: {e_main}"
                )
    except Exception as e:
        vae = None
        if accelerator.is_main_process:
            logger.warning(
                f"Failed to initialize VAE loading logic (will skip image decoding): {e}"
            )
    return vae


def _decode_latents_to_grid(vae, latents, latents_scale, latents_bias):
    """Undo the latent scaling, decode with the VAE and tile into a uint8 grid."""
    z = latents.to(dtype=torch.float32)
    images = vae.decode((z - latents_bias) / latents_scale).sample
    # Decoder output is mapped from [-1, 1] to [0, 1].
    images = ((images + 1) / 2.0).clamp(0, 1)
    return array2grid(images)


def _save_ema_samples(args, model, ema, vae, latents_scale, latents_bias, global_step, device):
    """Sample a small batch with the EMA model and save decoded image grids.

    Runs ``samplers.euler_maruyama_sampler`` for 50 steps (cfg_scale=1.0) from
    fresh noise and asks it to also return the state at ``t = args.t_c``.
    Output directories ``<output_dir>/<exp_name>/samples/t0`` and
    ``.../t0_<tc_tag>`` are always created. PNG grids are written only when
    ``vae`` is not None.
    """
    t_mid_vis = float(args.t_c)
    # e.g. 0.5 -> "0_5"; used in the directory / file names.
    tc_tag = f"{t_mid_vis:.4f}".rstrip("0").rstrip(".").replace(".", "_")
    logging.info(
        f"Generating EMA samples (Euler-Maruyama; t≈{t_mid_vis:g} → t=0)..."
    )
    ema.eval()
    with torch.no_grad():
        latent_size = args.resolution // 8
        n_samples = min(16, args.batch_size)
        base_model = model.module if hasattr(model, "module") else model
        cls_dim = base_model.z_dims[0]
        # NOTE(review): the global RNG is reseeded with the same value before both
        # draws, so z_init and cls_init are correlated, and the reseed also affects
        # later training randomness on this rank. Confirm this is intended.
        shared_seed = torch.randint(0, 2**32, (1,), device=device).item()
        torch.manual_seed(shared_seed)
        z_init = torch.randn(n_samples, base_model.in_channels, latent_size, latent_size, device=device)
        torch.manual_seed(shared_seed)
        cls_init = torch.randn(n_samples, cls_dim, device=device)
        y_samples = torch.randint(0, args.num_classes, (n_samples,), device=device)

        from samplers import euler_maruyama_sampler
        z_0, z_mid, _ = euler_maruyama_sampler(
            ema,
            z_init,
            y_samples,
            num_steps=50,
            cfg_scale=1.0,
            guidance_low=0.0,
            guidance_high=1.0,
            path_type=args.path_type,
            cls_latents=cls_init,
            args=args,
            return_mid_state=True,
            t_mid=t_mid_vis,
        )

        samples_root = os.path.join(args.output_dir, args.exp_name, "samples")
        t0_dir = os.path.join(samples_root, "t0")
        t_mid_dir = os.path.join(samples_root, f"t0_{tc_tag}")
        os.makedirs(t0_dir, exist_ok=True)
        os.makedirs(t_mid_dir, exist_ok=True)
        if vae is not None:
            grid_final = _decode_latents_to_grid(vae, z_0, latents_scale, latents_bias)
            Image.fromarray(grid_final).save(
                os.path.join(t0_dir, f"step_{global_step:07d}_t0.png")
            )
            if z_mid is not None:
                grid_mid = _decode_latents_to_grid(vae, z_mid, latents_scale, latents_bias)
                Image.fromarray(grid_mid).save(
                    os.path.join(t_mid_dir, f"step_{global_step:07d}_t0_{tc_tag}.png")
                )
            else:
                logging.warning(
                    f"Sampling time grid did not bracket t_mid={t_mid_vis:g}; "
                    f"skip t0_{tc_tag} image this step."
                )


def main(args):
    """Train a SiT model (project "REG") with Hugging Face Accelerate.

    Sets up the accelerator, dataset, model and EMA copy, an optional VAE
    (used only to decode intermediate samples), the SILoss objective, the
    AdamW optimizer and trackers. It optionally resumes from a checkpoint,
    then runs the training loop. Checkpointing and EMA sampling run on the
    main process once per optimizer update.
    """
    # set accelerator
    logging_dir = Path(args.output_dir, args.logging_dir)
    accelerator_project_config = ProjectConfiguration(
        project_dir=args.output_dir, logging_dir=logging_dir
    )
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
        kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)]
    )

    # Bind `logger` on every rank so it can be passed to helpers. The main rank
    # replaces it with a file-backed logger below.
    logger = get_logger(__name__)
    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)  # Make results folder (holds all experiment subfolders)
        save_dir = os.path.join(args.output_dir, args.exp_name)
        os.makedirs(save_dir, exist_ok=True)
        args_dict = vars(args)
        # Save to a JSON file
        json_dir = os.path.join(save_dir, "args.json")
        with open(json_dir, 'w') as f:
            json.dump(args_dict, f, indent=4)
        checkpoint_dir = f"{save_dir}/checkpoints"  # Stores saved model checkpoints
        os.makedirs(checkpoint_dir, exist_ok=True)
        logger = create_logger(save_dir)
        logger.info(f"Experiment directory created at {save_dir}")
    device = accelerator.device
    if torch.backends.mps.is_available():
        accelerator.native_amp = False
    if args.seed is not None:
        set_seed(args.seed + accelerator.process_index)

    # Create model:
    assert args.resolution % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
    latent_size = args.resolution // 8

    train_dataset = CustomDataset(
        args.data_dir, semantic_features_dir=args.semantic_features_dir
    )
    use_preprocessed_semantic = train_dataset.use_preprocessed_semantic
    if use_preprocessed_semantic:
        # Class tokens are read from disk, so no online encoder is loaded.
        encoders, encoder_types, architectures = [], [], []
        z_dims = [semantic_dim_from_enc_type(args.enc_type)]
        if accelerator.is_main_process:
            logger.info(
                f"Preprocessed semantic features: skip loading online encoder, z_dims={z_dims}"
            )
    elif args.enc_type is not None:
        encoders, encoder_types, architectures = load_encoders(
            args.enc_type, device, args.resolution
        )
        z_dims = [encoder.embed_dim for encoder in encoders]
    else:
        raise NotImplementedError()
    block_kwargs = {"fused_attn": args.fused_attn, "qk_norm": args.qk_norm}
    model = SiT_models[args.model](
        input_size=latent_size,
        num_classes=args.num_classes,
        use_cfg=(args.cfg_prob > 0),
        z_dims=z_dims,
        encoder_depth=args.encoder_depth,
        **block_kwargs
    )

    model = model.to(device)
    ema = deepcopy(model).to(device)  # Create an EMA of the model for use after training
    requires_grad(ema, False)

    latents_scale = torch.tensor(
        [0.18215, 0.18215, 0.18215, 0.18215]
    ).view(1, 4, 1, 1).to(device)
    latents_bias = torch.tensor(
        [0., 0., 0., 0.]
    ).view(1, 4, 1, 1).to(device)

    # VAE decoder: used only at sampling time to decode latents into images
    # (sd-vae-ft-mse, matching the root train.py / preprocessing).
    vae = _load_sd_vae(device, accelerator, logger)

    # create loss function
    loss_fn = SILoss(
        prediction=args.prediction,
        path_type=args.path_type,
        encoders=encoders,
        accelerator=accelerator,
        latents_scale=latents_scale,
        latents_bias=latents_bias,
        weighting=args.weighting,
        t_c=args.t_c,
        ot_cls=args.ot_cls,
    )
    if accelerator.is_main_process:
        logger.info(f"SiT Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper):
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Setup data (train_dataset was created above). args.batch_size is the global batch.
    local_batch_size = int(args.batch_size // accelerator.num_processes)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=local_batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True
    )
    if accelerator.is_main_process:
        logger.info(f"Dataset contains {len(train_dataset):,} images ({args.data_dir})")

    # Prepare models for training:
    update_ema(ema, model, decay=0)  # Ensure EMA is initialized with synced weights
    model.train()  # important! This enables embedding dropout for classifier-free guidance
    ema.eval()  # EMA model should always be in eval mode

    # resume:
    global_step = 0
    if args.resume_from_ckpt is not None:
        ckpt = _load_trusted_checkpoint(args.resume_from_ckpt)
        model.load_state_dict(ckpt["model"])
        ema.load_state_dict(ckpt["ema"])
        if "opt" in ckpt:
            optimizer.load_state_dict(ckpt["opt"])
        global_step = int(ckpt.get("steps", 0))
        if accelerator.is_main_process:
            logger.info(
                f"Resumed from ckpt: {args.resume_from_ckpt} (global_step={global_step})"
            )
    elif args.resume_step > 0:
        ckpt_name = str(args.resume_step).zfill(7) + '.pt'
        ckpt = _load_trusted_checkpoint(
            f'{os.path.join(args.output_dir, args.exp_name)}/checkpoints/{ckpt_name}'
        )
        model.load_state_dict(ckpt['model'])
        ema.load_state_dict(ckpt['ema'])
        optimizer.load_state_dict(ckpt['opt'])
        global_step = ckpt['steps']

    model, optimizer, train_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader
    )

    if accelerator.is_main_process:
        tracker_config = vars(copy.deepcopy(args))
        accelerator.init_trackers(
            project_name="REG",
            config=tracker_config,
            init_kwargs={
                "wandb": {"name": f"{args.exp_name}"}
            },
        )

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )

    # NOTE(review): gt_xs, ys and xT below are not read anywhere later in this
    # function. The code is kept so that RNG and dataloader consumption match
    # earlier runs; only the resolution assert is a live check.
    sample_batch_size = 64 // accelerator.num_processes
    first_batch = next(iter(train_dataloader))
    if len(first_batch) == 4:
        gt_raw_images, gt_xs, _, _ = first_batch
    else:
        gt_raw_images, gt_xs, _ = first_batch
    assert gt_raw_images.shape[-1] == args.resolution
    gt_xs = gt_xs[:sample_batch_size]
    gt_xs = sample_posterior(
        gt_xs.to(device), latents_scale=latents_scale, latents_bias=latents_bias
    )
    ys = torch.randint(1000, size=(sample_batch_size,), device=device)
    ys = ys.to(device)
    # Create sampling noise:
    n = ys.size(0)
    xT = torch.randn((n, 4, latent_size, latent_size), device=device)

    # grad_norm is only recomputed on steps that synchronize gradients. Bind it
    # up front so the per-step log also works on gradient-accumulation micro-steps.
    grad_norm = torch.tensor(0.0, device=device)

    for epoch in range(args.epochs):
        model.train()
        for batch in train_dataloader:
            # 4-tuple batches carry preprocessed semantic features from disk.
            if len(batch) == 4:
                raw_image, x, r_preprocessed, y = batch
                use_sem_file = True
            else:
                raw_image, x, y = batch
                r_preprocessed = None
                use_sem_file = False
            raw_image = raw_image.to(device)
            x = x.squeeze(dim=1).to(device).float()
            y = y.to(device)
            if args.legacy:
                # In our early experiments, we accidentally apply label dropping twice:
                # once in train.py and once in sit.py.
                # We keep this option for exact reproducibility with previous runs.
                drop_ids = torch.rand(y.shape[0], device=y.device) < args.cfg_prob
                labels = torch.where(drop_ids, args.num_classes, y)
            else:
                labels = y
            with torch.no_grad():
                x = sample_posterior(x, latents_scale=latents_scale, latents_bias=latents_bias)
                zs = []
                if use_sem_file and r_preprocessed is not None:
                    cls_token = r_preprocessed.to(device).float()
                    if cls_token.dim() == 1:
                        cls_token = cls_token.unsqueeze(0)
                    while cls_token.dim() > 2:
                        # squeeze(1) is a no-op on a non-singleton dim; fail loudly
                        # instead of looping forever.
                        if cls_token.size(1) != 1:
                            raise ValueError(
                                "Unexpected preprocessed cls token shape "
                                f"{tuple(cls_token.shape)}; expected (batch, dim)."
                            )
                        cls_token = cls_token.squeeze(1)
                    base_m = model.module if hasattr(model, "module") else model
                    n_pad = base_m.x_embedder.num_patches
                    # Alignment target: the cls token followed by n_pad copies of itself.
                    zs = [
                        torch.cat(
                            [
                                cls_token.unsqueeze(1),
                                cls_token.unsqueeze(1).expand(-1, n_pad, -1),
                            ],
                            dim=1,
                        )
                    ]
                else:
                    with accelerator.autocast():
                        for encoder, encoder_type, arch in zip(
                            encoders, encoder_types, architectures
                        ):
                            raw_image_ = preprocess_raw_image(raw_image, encoder_type)
                            z = encoder.forward_features(raw_image_)
                            if 'dinov2' in encoder_type:
                                dense_z = z['x_norm_patchtokens']
                                cls_token = z['x_norm_clstoken']
                                dense_z = torch.cat([cls_token.unsqueeze(1), dense_z], dim=1)
                            else:
                                # exit() would end only this rank, with status 0.
                                raise NotImplementedError(
                                    "Online encoder features are only implemented for "
                                    f"dinov2 encoders, got '{encoder_type}'."
                                )
                            zs.append(dense_z)

            with accelerator.accumulate(model):
                model_kwargs = dict(y=labels)
                loss1, proj_loss1, time_input, noises, loss2 = loss_fn(
                    model, x, model_kwargs, zs=zs,
                    cls_token=cls_token,
                    time_input=None, noises=None,
                )
                loss_mean = loss1.mean()
                loss_mean_cls = loss2.mean() * args.cls
                proj_loss_mean = proj_loss1.mean() * args.proj_coeff
                tc_vel_loss = torch.tensor(0.0, device=device)
                if args.tc_velocity_loss_coeff > 0:
                    tc_vel_loss = loss_fn.tc_velocity_loss(
                        model,
                        x,
                        model_kwargs=model_kwargs,
                        cls_token=cls_token,
                        noises=noises,
                    ).mean()
                loss = (
                    loss_mean
                    + proj_loss_mean
                    + loss_mean_cls
                    + args.tc_velocity_loss_coeff * tc_vel_loss
                )

                ## optimization
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = model.parameters()
                    grad_norm = accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

                if accelerator.sync_gradients:
                    update_ema(ema, model)  # change ema function

            # Step bookkeeping, checkpointing and sampling happen once per optimizer
            # update, not on every gradient-accumulation micro-step.
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if global_step % args.checkpointing_steps == 0 and accelerator.is_main_process:
                    unwrapped = model.module if hasattr(model, "module") else model
                    checkpoint = {
                        "model": unwrapped.state_dict(),
                        "ema": ema.state_dict(),
                        "opt": optimizer.state_dict(),
                        "args": args,
                        "steps": global_step,
                    }
                    checkpoint_path = f"{checkpoint_dir}/{global_step:07d}.pt"
                    torch.save(checkpoint, checkpoint_path)
                    logger.info(f"Saved checkpoint to {checkpoint_path}")

                # Sampling uses only the (unwrapped) EMA model, so one rank suffices and
                # ranks no longer race to write the same image files.
                if accelerator.is_main_process and (
                    global_step == 1 or global_step % args.sampling_steps == 0
                ):
                    _save_ema_samples(
                        args, model, ema, vae, latents_scale, latents_bias, global_step, device
                    )
                    torch.cuda.empty_cache()

            logs = {
                "loss_final": accelerator.gather(loss).mean().detach().item(),
                "loss_mean": accelerator.gather(loss_mean).mean().detach().item(),
                "proj_loss": accelerator.gather(proj_loss_mean).mean().detach().item(),
                "loss_mean_cls": accelerator.gather(loss_mean_cls).mean().detach().item(),
                "loss_tc_vel": accelerator.gather(tc_vel_loss).mean().detach().item(),
                "grad_norm": accelerator.gather(grad_norm).mean().detach().item()
            }
            log_message = ", ".join(f"{key}: {value:.6f}" for key, value in logs.items())
            logging.info(f"Step: {global_step}, Training Logs: {log_message}")
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break
        if global_step >= args.max_train_steps:
            break

    model.eval()  # important! This disables randomized embedding dropout
    # do any sampling/FID calculation/etc. with ema (or model) in eval mode ...

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        logger.info("Done!")
    accelerator.end_training()
def parse_args(input_args=None):
    """Build the training CLI and parse it.

    Args:
        input_args: optional list of argument strings. When None,
            ``sys.argv`` is parsed.

    Returns:
        argparse.Namespace holding every training option.
    """
    parser = argparse.ArgumentParser(description="Training")

    # logging:
    parser.add_argument("--output-dir", type=str, default="exps")
    parser.add_argument("--exp-name", type=str, required=True)
    parser.add_argument("--logging-dir", type=str, default="logs")
    parser.add_argument("--report-to", type=str, default="wandb")
    parser.add_argument("--sampling-steps", type=int, default=2000)
    parser.add_argument("--resume-step", type=int, default=0)
    parser.add_argument(
        "--resume-from-ckpt",
        type=str,
        default=None,
        help="直接从指定 checkpoint 路径续训(优先于 --resume-step)。",
    )

    # model
    parser.add_argument("--model", type=str)
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--encoder-depth", type=int, default=8)
    parser.add_argument("--fused-attn", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--qk-norm", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--ops-head", type=int, default=16)

    # dataset
    parser.add_argument("--data-dir", type=str, default="../data/imagenet256")
    parser.add_argument(
        "--semantic-features-dir",
        type=str,
        default=None,
        help="预处理 DINOv2 class token 等特征目录(含 dataset.json)。"
        "默认 None 时若存在 data-dir/imagenet_256_features/dinov2-vit-b_tmp/gpu0 则自动使用。",
    )
    parser.add_argument("--resolution", type=int, choices=[256, 512], default=256)
    # Global batch size; main() divides it by the number of processes.
    parser.add_argument("--batch-size", type=int, default=256)

    # precision
    parser.add_argument("--allow-tf32", action="store_true")
    parser.add_argument("--mixed-precision", type=str, default="fp16", choices=["no", "fp16", "bf16"])

    # optimization
    parser.add_argument("--epochs", type=int, default=14000)
    parser.add_argument("--max-train-steps", type=int, default=10000000)
    parser.add_argument("--checkpointing-steps", type=int, default=10000)
    parser.add_argument("--gradient-accumulation-steps", type=int, default=1)
    parser.add_argument("--learning-rate", type=float, default=1e-4)
    parser.add_argument("--adam-beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam-beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam-weight-decay", type=float, default=0., help="Weight decay to use.")
    parser.add_argument("--adam-epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max-grad-norm", default=1.0, type=float, help="Max gradient norm.")

    # seed
    parser.add_argument("--seed", type=int, default=0)

    # cpu
    parser.add_argument("--num-workers", type=int, default=4)

    # loss
    parser.add_argument("--path-type", type=str, default="linear", choices=["linear", "cosine"])
    parser.add_argument("--prediction", type=str, default="v", choices=["v"])  # currently we only support v-prediction
    parser.add_argument("--cfg-prob", type=float, default=0.1)
    parser.add_argument("--enc-type", type=str, default='dinov2-vit-b')
    parser.add_argument("--proj-coeff", type=float, default=0.5)
    parser.add_argument(
        "--weighting",
        default="uniform",
        type=str,
        help="Timestep weighting scheme for the loss (forwarded to SILoss).",
    )
    parser.add_argument("--legacy", action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument("--cls", type=float, default=0.03)
    parser.add_argument(
        "--t-c",
        type=float,
        default=0.5,
        help="语义分界时刻(与脚本内 t 约定一致:t=1 噪声→t=0 数据)。"
        "t∈(t_c,1]:cls 沿 OT 配对后的路径插值(CFM/OT-CFM 式 minibatch OT);"
        "t∈[0,t_c]:cls 固定为真实 encoder cls,目标 cls 速度为 0。",
    )
    parser.add_argument(
        "--ot-cls",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="在 t>t_c 段对 cls 噪声与 batch 内 cls_gt 做 minibatch 最优传输配对(需 scipy);关闭则退化为独立高斯噪声配对。",
    )
    parser.add_argument(
        "--tc-velocity-loss-coeff",
        type=float,
        default=0.0,
        help="额外 t=t_c 图像速度场监督项权重(>0 启用,用于增强单步性)。",
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()
    return args
if __name__ == "__main__":
    # Script entry point: parse the command line and launch training.
    args = parse_args(input_args=None)
    main(args)