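"""Training entry point for fine-tuning a FLUX image pipeline, optionally with LoRA,
on depth estimation, surface-normal estimation, or image matting datasets. Supports
mixing several datasets with a probability-weighted batch sampler and resuming
training from the latest checkpoint in the output directory."""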
import torch, os
from torch.utils.data import ConcatDataset
import numpy as np
from models.utils import load_state_dict, DiffusionTrainingModule, ModelLogger, launch_training_task, flux_parser, parse_flux_model_configs, find_latest_checkpoint
from pipelines.flux_image_new import FluxImagePipeline, ModelConfig
from lora.flux_lora import FluxLoRAConverter
from models.unified_dataset import UnifiedDataset
from utils.mixed_sampler import MixedBatchSampler
from utils.visualize import visualize_sample, prepare_image
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# torch.autograd.set_detect_anomaly(True)
class FluxTrainingModule(DiffusionTrainingModule):
def __init__(
self,
model_paths=None, model_id_with_origin_paths=None,
trainable_models=None,
lora_base_model=None, lora_target_modules="a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp", lora_rank=32, lora_checkpoint=None,
use_gradient_checkpointing=True,
use_gradient_checkpointing_offload=False,
extra_inputs=None,
multi_res_noise=False,
deterministic_flow=False,
extra_loss=None,
depth_normalization="log",
matting_prompt=None,
):
super().__init__()
# Load models
# model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=False)
model_configs = parse_flux_model_configs(root_path=model_paths)
self.pipe = FluxImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cuda", model_configs=model_configs)
# Update matting_prompt if provided
if matting_prompt is not None and hasattr(self.pipe.dit, 'coord_encoder'):
self.pipe.dit.coord_encoder.matting_prompt = matting_prompt
print(f"Updated coord_encoder.matting_prompt to: {matting_prompt}")
# Training mode
self.switch_pipe_to_training_mode(
self.pipe, trainable_models,
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
enable_fp8_training=False,
)
# Store other configs
self.use_gradient_checkpointing = use_gradient_checkpointing
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
self.multi_res_noise = multi_res_noise
self.deterministic_flow = deterministic_flow
self.extra_loss = extra_loss
self.depth_normalization = depth_normalization
def forward_preprocess(self, data):
# CFG-sensitive parameters
inputs_posi = {"prompt": data["prompt"]}
inputs_nega = {"negative_prompt": ""}
        # CFG-insensitive parameters
if data["image"].ndim == 3:
data["image"] = data["image"].unsqueeze(0)
        if self.deterministic_flow:
            # Deterministic flow: always train at the first scheduler timestep.
            timestep = self.pipe.scheduler.timesteps[0].repeat(data["image"].shape[0]).to(self.pipe.device, dtype=self.pipe.torch_dtype)
        else:
            # Sample a training timestep uniformly at random for each item in the batch.
            timestep = self.pipe.scheduler.timesteps[torch.randint(0, self.pipe.scheduler.num_train_timesteps, size=(data["image"].shape[0],))].to(self.pipe.device, dtype=self.pipe.torch_dtype)
inputs_shared = {
            # Inference-style inputs for the pipeline, derived from the data sample.
"input_image": data["image"],
"mask": data.get("mask", None),
"height": data["image"].shape[2],
"width": data["image"].shape[3],
            # Do not modify the following parameters
            # unless you know exactly what they control.
"seed": 42,
"cfg_scale": 1,
"embedded_guidance": 1,
"t5_sequence_length": 512,
"tiled": False,
"rand_device": self.pipe.device,
"use_gradient_checkpointing": self.use_gradient_checkpointing,
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
"multi_res_noise": self.multi_res_noise,
"timestep": timestep,
"deterministic_flow": self.deterministic_flow,
"extra_loss": self.extra_loss,
"depth_normalization": self.depth_normalization
}
# Extra inputs
controlnet_input = {}
for extra_input in self.extra_inputs:
if extra_input.startswith("controlnet_"):
controlnet_input[extra_input.replace("controlnet_", "")] = data[extra_input]
else:
inputs_shared[extra_input] = data[extra_input]
# Pipeline units will automatically process the input parameters.
for unit in self.pipe.units:
inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
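        # cfg_scale is fixed to 1 above, so the negative-prompt inputs are dropped from the returned dict.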
return {**inputs_shared, **inputs_posi}
def forward(self, data, inputs=None):
if inputs is None: inputs = self.forward_preprocess(data)
models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
loss = self.pipe.training_loss(**models, **inputs)
return loss
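# Minimal usage sketch (illustrative only; paths and argument values are assumptions,
# not the repository's actual configuration):
#
#     module = FluxTrainingModule(
#         model_paths="/path/to/flux/models",   # hypothetical root path
#         lora_base_model="dit",
#         lora_rank=32,
#     )
#     batch = {"prompt": ["a photo"], "image": torch.randn(1, 3, 768, 1024)}
#     loss = module(batch)  # invokes forward(): preprocess the batch, then compute training_loss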
if __name__ == "__main__":
parser = flux_parser()
args = parser.parse_args()
datasets_ls = []
metadata_paths = args.dataset_metadata_path.split(",")
dataset_base_paths = args.dataset_base_path.split(",")
    # One (height, width) pair per dataset; all datasets currently share the CLI-provided size.
    heights = [args.height] * len(metadata_paths)
    widths = [args.width] * len(metadata_paths)
if "depth" in metadata_paths[0]:
args.task = "depth"
print("!!! Doing depth task !!!")
elif "normal" in metadata_paths[0]:
args.task = "normal"
print("!!! Doing normal task !!!")
elif "matting" in metadata_paths[0]:
args.task = "matting"
print("!!! Doing matting task !!!")
else:
raise ValueError("Cannot infer task from metadata path; please include 'depth' or 'normal' in the path.")
if args.using_pdf:
print("!!! Using PDF operator for image loading !!!")
pdf = np.load("depth_mapping_lookup_table.npz")
else:
pdf = None
    for dataset_base_path, metadata_path, height, width in zip(dataset_base_paths, metadata_paths, heights, widths):
dataset = UnifiedDataset(
base_path=dataset_base_path,
metadata_path=metadata_path,
repeat=args.dataset_repeat,
data_file_keys=args.data_file_keys.split(","),
main_data_operator=UnifiedDataset.default_image_operator(
base_path=dataset_base_path,
max_pixels=args.max_pixels,
height=height,
width=width,
height_division_factor=32,
width_division_factor=32,
using_log=args.using_log,
using_sqrt=args.using_sqrt,
using_sqrt_disp=args.using_sqrt_disp,
using_pdf=args.using_pdf,
pdf=pdf,
with_mask=args.with_mask,
),
special_operator_map=["mask","prompt"] if args.with_mask else ["prompt"],
default_caption = args.default_caption,
matting_prompt=args.matting_prompt if args.task=="matting" else None,
use_coor_input=args.use_coor_input if args.task=="matting" else False,
use_camera_intrinsics=args.use_camera_intrinsics,
# use_attn_mask=args.use_attn_mask if args.task=="matting" else False,
)
print(f"Loading {metadata_path} with {len(dataset)} items of size Height {height} x Width {width}")
# Example data item logging (handle optional keys safely)
try:
example = dataset[0]
img_shape = example['image'].shape if 'image' in example else None
            if isinstance(example.get('kontext_images', None), list):
                ktx_shape = [ktx.shape for ktx in example['kontext_images']]
            else:
                ktx = example.get('kontext_images', None)
                ktx_shape = ktx.shape if ktx is not None else None
mask_obj = example.get('mask', None)
mask_shape = mask_obj.shape if mask_obj is not None else None
prompt_val = example.get('prompt', '')
print(f"Example data item: image={img_shape}, kontext_images={ktx_shape}, mask={mask_shape}, prompt={prompt_val[:80]}")
except Exception as e:
print(f"Failed to print example data item: {e}")
datasets_ls.append(dataset)
print(f"Total datasets loaded: {len(datasets_ls)}")
if len(datasets_ls) > 1:
dataset = ConcatDataset(datasets_ls)
if args.task=="depth":
prob=[0.9, 0.1]
elif args.task=="normal":
prob=[0.5,0.45,0.05]
elif args.task=="matting":
if len(datasets_ls)==2:
prob=[0.5,0.5]
elif len(datasets_ls)==3:
prob=[0.02,0.65,0.33]
elif len(datasets_ls)==4:
prob=[0.22,0.22,0.3,0.26]
mixed_sampler = MixedBatchSampler(datasets_ls, shuffle=True, batch_size=args.batch_size, drop_last=True, prob=prob)
print(f"using {len(datasets_ls)} datasets, total length: {len(dataset)} with PROB:{prob}")
else:
dataset = datasets_ls[0]
mixed_sampler = None
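# MixedBatchSampler draws each batch from a single dataset, choosing the dataset with the
# given probabilities. A minimal sketch of the idea (not the actual implementation in
# utils.mixed_sampler; the index layout into the ConcatDataset is an assumption):
#
#     offsets = np.cumsum([0] + [len(d) for d in datasets_ls])
#     def sample_batch(batch_size):
#         i = np.random.choice(len(datasets_ls), p=prob)               # pick a dataset
#         idx = np.random.randint(0, len(datasets_ls[i]), batch_size)  # sample within it
#         return (offsets[i] + idx).tolist()                           # global indices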
    print(f"Evaluation file list: {args.eval_file_list}")
if args.eval_file_list:
        with open(args.eval_file_list, "r") as f:
            eval_file_list = [line.strip().split()[0] for line in f]
        if args.task in ("depth", "normal"):
            base_dir = f"/mnt/nfs/workspace/syq/dataset/Eval/{args.task}/nyuv2"
        elif args.task == "matting":
            base_dir = "/mnt/nfs/workspace/syq/dataset/matting/P3M-10k"
        else:
            raise ValueError(f"Unknown task {args.task}")
        print(f"Loaded {len(eval_file_list)} evaluation files.")
        eval_file_list = [os.path.join(base_dir, x) if not os.path.isabs(x) else x for x in eval_file_list]
        args.eval_file_list = eval_file_list
        print(f"Top 5 evaluation files: {eval_file_list[:5]}")
model = FluxTrainingModule(
model_paths=args.model_paths,
model_id_with_origin_paths=args.model_id_with_origin_paths,
trainable_models=args.trainable_models,
lora_base_model=args.lora_base_model,
lora_target_modules=args.lora_target_modules,
lora_rank=args.lora_rank,
lora_checkpoint=find_latest_checkpoint(args.output_path) if (args.resume and args.lora_base_model is not None) else None,
use_gradient_checkpointing=args.use_gradient_checkpointing,
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
extra_inputs=args.extra_inputs,
multi_res_noise=args.multi_res_noise,
deterministic_flow=args.deterministic_flow,
extra_loss=args.extra_loss,
depth_normalization=args.depth_normalization,
matting_prompt=args.matting_prompt if args.task == "matting" else None,
)
if args.resume and os.path.isdir(args.output_path):
latest_ckpt = find_latest_checkpoint(args.output_path)
if latest_ckpt is not None:
            if args.lora_base_model is None:
                # Full-model resume; LoRA weights are resumed via lora_checkpoint at construction time.
                state_dict = load_state_dict(latest_ckpt)
                model.pipe.dit.load_state_dict(state_dict, strict=False)
                del state_dict
args.resume_steps = int(latest_ckpt.split("step-")[-1].split(".")[0])
print(f"Resumed training from step {args.resume_steps}")
torch.cuda.empty_cache()
else:
print(f"No checkpoint found in {args.output_path}, starting fresh training.")
args.resume_steps = 0
else:
args.resume_steps = 0
model_logger = ModelLogger(
args.output_path,
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
        state_dict_converter=FluxLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else (lambda x: x),
args=args
)
launch_training_task(dataset, model, model_logger, dataset_sampler=mixed_sampler, args=args)
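# Example launch (flag names inferred from the args used above; the script name and
# values are illustrative assumptions, not a verified command):
#   python train.py --dataset_metadata_path /data/depth_meta.csv --dataset_base_path /data \
#       --height 768 --width 1024 --batch_size 4 --lora_base_model dit --lora_rank 32 \
#       --output_path ./checkpoints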