Instructions to use zeyuren2002/EvalMDE with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use zeyuren2002/EvalMDE with Diffusers:
pip install -U diffusers transformers accelerate
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("zeyuren2002/EvalMDE", dtype=torch.bfloat16, device_map="cuda")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
- Notebooks
- Google Colab
- Kaggle
File size: 13,041 Bytes
7f921f4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 | import torch, os, json
from torch.utils.data import ConcatDataset
import numpy as np
from models.utils import load_state_dict,DiffusionTrainingModule, ModelLogger, launch_training_task, flux_parser, parse_flux_model_configs, find_latest_checkpoint
from pipelines.flux_image_new import FluxImagePipeline, ModelConfig
from lora.flux_lora import FluxLoRAConverter
from models.unified_dataset import UnifiedDataset
from utils.mixed_sampler import MixedBatchSampler
from utils.visualize import visualize_sample,prepare_image
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# torch.autograd.set_detect_anomaly(True)
class FluxTrainingModule(DiffusionTrainingModule):
    """Training wrapper around a ``FluxImagePipeline`` loaded in bfloat16 on CUDA.

    The pipeline is built from ``model_paths`` and then switched to training
    mode via ``switch_pipe_to_training_mode`` (the ``trainable_models`` and
    LoRA settings are forwarded as-is). ``forward`` returns the pipeline's
    training loss for one batch.
    """

    def __init__(
        self,
        model_paths=None, model_id_with_origin_paths=None,
        trainable_models=None,
        lora_base_model=None, lora_target_modules="a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp", lora_rank=32, lora_checkpoint=None,
        use_gradient_checkpointing=True,
        use_gradient_checkpointing_offload=False,
        extra_inputs=None,
        multi_res_noise=False,
        deterministic_flow=False,
        extra_loss=None,
        depth_normalization="log",
        matting_prompt=None,
    ):
        """Load the pipeline and prepare it for training.

        Args:
            model_paths: Root path handed to ``parse_flux_model_configs``.
            model_id_with_origin_paths: Accepted for interface compatibility;
                currently unused (model configs come from ``model_paths`` only).
            trainable_models, lora_base_model, lora_target_modules, lora_rank,
                lora_checkpoint: Forwarded to ``switch_pipe_to_training_mode``.
            use_gradient_checkpointing, use_gradient_checkpointing_offload,
                multi_res_noise, deterministic_flow, extra_loss,
                depth_normalization: Stored and passed to the pipeline units
                on every ``forward_preprocess`` call.
            extra_inputs: Comma-separated batch keys copied into the pipeline
                inputs. Empty names (e.g. from ``""`` or a trailing comma)
                are ignored.
            matting_prompt: If given and the DiT has a ``coord_encoder``,
                overrides ``coord_encoder.matting_prompt``.
        """
        super().__init__()
        # Load models
        # model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=False)
        model_configs = parse_flux_model_configs(root_path=model_paths)
        self.pipe = FluxImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cuda", model_configs=model_configs)

        # Update matting_prompt if provided
        if matting_prompt is not None and hasattr(self.pipe.dit, 'coord_encoder'):
            self.pipe.dit.coord_encoder.matting_prompt = matting_prompt
            print(f"Updated coord_encoder.matting_prompt to: {matting_prompt}")

        # Training mode
        self.switch_pipe_to_training_mode(
            self.pipe, trainable_models,
            lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
            enable_fp8_training=False,
        )

        # Store other configs
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
        # Drop empty names: "".split(",") == [""] would otherwise make
        # forward_preprocess look up data[""] and fail with KeyError.
        self.extra_inputs = (
            [name.strip() for name in extra_inputs.split(",") if name.strip()]
            if extra_inputs is not None else []
        )
        self.multi_res_noise = multi_res_noise
        self.deterministic_flow = deterministic_flow
        self.extra_loss = extra_loss
        self.depth_normalization = depth_normalization

    def forward_preprocess(self, data):
        """Turn a batch dict into keyword inputs for ``pipe.training_loss``.

        Reads ``data["prompt"]``, ``data["image"]``, optionally ``data["mask"]``,
        and every key in ``self.extra_inputs``. A 3-D ``data["image"]`` is
        unsqueezed to 4-D and written back into ``data`` (the caller's dict is
        mutated). Height and width are read from dims 2 and 3 of the image.
        """
        # CFG-sensitive parameters
        inputs_posi = {"prompt": data["prompt"]}
        inputs_nega = {"negative_prompt": ""}

        # CFG-unsensitive parameters
        if data["image"].ndim == 3:
            data["image"] = data["image"].unsqueeze(0)
        batch_size = data["image"].shape[0]
        if self.deterministic_flow:
            # Always train at the scheduler's first timestep.
            timestep = self.pipe.scheduler.timesteps[0].repeat(batch_size,).to(self.pipe.device, dtype=self.pipe.torch_dtype)
        else:
            # One random timestep per sample.
            # NOTE(review): indices are drawn from [0, num_train_timesteps); this
            # assumes len(scheduler.timesteps) >= num_train_timesteps — confirm.
            indices = torch.randint(0, self.pipe.scheduler.num_train_timesteps, size=(batch_size,))
            timestep = self.pipe.scheduler.timesteps[indices].to(self.pipe.device, dtype=self.pipe.torch_dtype)

        inputs_shared = {
            # Assume you are using this pipeline for inference,
            # please fill in the input parameters.
            "input_image": data["image"],
            "mask": data.get("mask", None),
            "height": data["image"].shape[2],
            "width": data["image"].shape[3],
            # Please do not modify the following parameters
            # unless you clearly know what this will cause.
            "seed": 42,
            "cfg_scale": 1,
            "embedded_guidance": 1,
            "t5_sequence_length": 512,
            "tiled": False,
            "rand_device": self.pipe.device,
            "use_gradient_checkpointing": self.use_gradient_checkpointing,
            "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
            "multi_res_noise": self.multi_res_noise,
            "timestep": timestep,
            "deterministic_flow": self.deterministic_flow,
            "extra_loss": self.extra_loss,
            "depth_normalization": self.depth_normalization
        }

        # Extra inputs
        controlnet_input = {}
        for extra_input in self.extra_inputs:
            if extra_input.startswith("controlnet_"):
                controlnet_input[extra_input.replace("controlnet_", "")] = data[extra_input]
            else:
                inputs_shared[extra_input] = data[extra_input]
        # NOTE(review): controlnet_input is collected but never added to
        # inputs_shared, so "controlnet_*" extra inputs currently have no
        # effect on training — confirm whether that is intended.

        # Pipeline units will automatically process the input parameters.
        for unit in self.pipe.units:
            inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
        # Only the positive branch is used for the loss (cfg_scale is 1).
        return {**inputs_shared, **inputs_posi}

    def forward(self, data, inputs=None):
        """Return the training loss for one batch.

        ``inputs`` may be a precomputed ``forward_preprocess`` result; when
        omitted it is computed from ``data``.
        """
        if inputs is None:
            inputs = self.forward_preprocess(data)
        models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
        return self.pipe.training_loss(**models, **inputs)
# Relative entries in --eval_file_list are resolved against these per-task
# roots. NOTE(review): cluster-specific absolute paths; consider exposing
# them as a CLI option.
_EVAL_BASE_DIRS = {
    "depth": "/mnt/nfs/workspace/syq/dataset/Eval/depth/nyuv2",
    "normal": "/mnt/nfs/workspace/syq/dataset/Eval/normal/nyuv2",
    "matting": "/mnt/nfs/workspace/syq/dataset/matting/P3M-10k",
}

# Hand-tuned dataset mixture weights keyed by task, then by number of
# datasets; weights follow the order of --dataset_metadata_path.
_MIXTURE_PRESETS = {
    "depth": {2: [0.9, 0.1]},
    "normal": {3: [0.5, 0.45, 0.05]},
    "matting": {
        2: [0.5, 0.5],
        3: [0.02, 0.65, 0.33],
        4: [0.22, 0.22, 0.3, 0.26],
    },
}


def _infer_task(metadata_path):
    """Return the task name ("depth", "normal" or "matting") found in *metadata_path*.

    Names are checked in that order, so a path containing several of them
    resolves to the first match.

    Raises:
        ValueError: if none of the task names occurs in the path.
    """
    for task in ("depth", "normal", "matting"):
        if task in metadata_path:
            return task
    raise ValueError(
        "Cannot infer task from metadata path; "
        "please include 'depth', 'normal' or 'matting' in the path."
    )


def _mixture_probs(task, num_datasets):
    """Return one sampling probability per dataset for MixedBatchSampler.

    Uses the hand-tuned preset when one exists for this task and dataset
    count. Otherwise it falls back to a uniform mixture and prints a warning,
    so the list always has exactly ``num_datasets`` entries.
    """
    preset = _MIXTURE_PRESETS.get(task, {}).get(num_datasets)
    if preset is not None:
        return list(preset)
    print(f"!!! No mixture preset for task '{task}' with {num_datasets} datasets; using uniform sampling !!!")
    return [1.0 / num_datasets] * num_datasets


def _read_eval_file_list(path, task):
    """Read an evaluation list file and return the resolved sample paths.

    Each non-blank line contributes its first whitespace-separated field.
    Relative entries are joined onto the task's evaluation root, and absolute
    entries are kept as-is.

    Raises:
        ValueError: if *task* has no known evaluation root.
    """
    if task not in _EVAL_BASE_DIRS:
        raise ValueError(f"Unknown task {task}")
    base_dir = _EVAL_BASE_DIRS[task]
    with open(path, "r", encoding="utf-8") as f:
        # Skip blank lines: "".split()[0] would raise IndexError.
        entries = [line.split()[0] for line in f if line.strip()]
    return [x if os.path.isabs(x) else os.path.join(base_dir, x) for x in entries]


if __name__ == "__main__":
    parser = flux_parser()
    args = parser.parse_args()

    # ---- Datasets ----------------------------------------------------------
    metadata_paths = args.dataset_metadata_path.split(",")
    dataset_base_paths = args.dataset_base_path.split(",")
    if len(dataset_base_paths) != len(metadata_paths):
        # zip() below would otherwise silently drop the unmatched datasets.
        raise ValueError(
            f"--dataset_base_path has {len(dataset_base_paths)} entries but "
            f"--dataset_metadata_path has {len(metadata_paths)}; they must match."
        )
    # Every dataset uses the CLI resolution. There is one entry per metadata
    # file, so any number of datasets is supported; a fixed-length list would
    # make zip() truncate the dataset list.
    heights = [args.height] * len(metadata_paths)
    widths = [args.width] * len(metadata_paths)

    # The task is inferred from the first metadata path only.
    args.task = _infer_task(metadata_paths[0])
    print(f"!!! Doing {args.task} task !!!")

    if args.using_pdf:
        print("!!! Using PDF operator for image loading !!!")
        pdf = np.load("depth_mapping_lookup_table.npz")
    else:
        pdf = None

    datasets_ls = []
    for dataset_base_path, metadata_path, height, width in zip(dataset_base_paths, metadata_paths, heights, widths):
        dataset = UnifiedDataset(
            base_path=dataset_base_path,
            metadata_path=metadata_path,
            repeat=args.dataset_repeat,
            data_file_keys=args.data_file_keys.split(","),
            main_data_operator=UnifiedDataset.default_image_operator(
                base_path=dataset_base_path,
                max_pixels=args.max_pixels,
                height=height,
                width=width,
                height_division_factor=32,
                width_division_factor=32,
                using_log=args.using_log,
                using_sqrt=args.using_sqrt,
                using_sqrt_disp=args.using_sqrt_disp,
                using_pdf=args.using_pdf,
                pdf=pdf,
                with_mask=args.with_mask,
            ),
            special_operator_map=["mask", "prompt"] if args.with_mask else ["prompt"],
            default_caption=args.default_caption,
            matting_prompt=args.matting_prompt if args.task == "matting" else None,
            use_coor_input=args.use_coor_input if args.task == "matting" else False,
            use_camera_intrinsics=args.use_camera_intrinsics,
            # use_attn_mask=args.use_attn_mask if args.task=="matting" else False,
        )
        print(f"Loading {metadata_path} with {len(dataset)} items of size Height {height} x Width {width}")
        # Best-effort logging of one sample; a failure here must not abort training.
        try:
            example = dataset[0]
            img_shape = example['image'].shape if 'image' in example else None
            ktx = example.get('kontext_images', None)
            if isinstance(ktx, list):
                ktx_shape = [k.shape for k in ktx]
            else:
                ktx_shape = ktx.shape if ktx is not None else None
            mask_obj = example.get('mask', None)
            mask_shape = mask_obj.shape if mask_obj is not None else None
            prompt_val = example.get('prompt', '')
            print(f"Example data item: image={img_shape}, kontext_images={ktx_shape}, mask={mask_shape}, prompt={prompt_val[:80]}")
        except Exception as e:
            print(f"Failed to print example data item: {e}")
        datasets_ls.append(dataset)

    print(f"Total datasets loaded: {len(datasets_ls)}")
    if len(datasets_ls) > 1:
        dataset = ConcatDataset(datasets_ls)
        prob = _mixture_probs(args.task, len(datasets_ls))
        mixed_sampler = MixedBatchSampler(datasets_ls, shuffle=True, batch_size=args.batch_size, drop_last=True, prob=prob)
        print(f"using {len(datasets_ls)} datasets, total length: {len(dataset)} with PROB:{prob}")
    else:
        dataset = datasets_ls[0]
        mixed_sampler = None

    # ---- Evaluation file list ------------------------------------------------
    print(args.eval_file_list)
    if args.eval_file_list:
        eval_file_list = _read_eval_file_list(args.eval_file_list, args.task)
        print(f"Loaded {len(eval_file_list)} evaluation files.")
        args.eval_file_list = eval_file_list
        print(f"top 5 evaluation files: {eval_file_list[:5]}")

    # ---- Model & resume --------------------------------------------------------
    # Look up the newest checkpoint once, and only if the output directory exists.
    # LoRA runs pass it to the constructor; non-LoRA runs load it into the DiT below.
    can_resume = args.resume and os.path.isdir(args.output_path)
    latest_ckpt = find_latest_checkpoint(args.output_path) if can_resume else None

    model = FluxTrainingModule(
        model_paths=args.model_paths,
        model_id_with_origin_paths=args.model_id_with_origin_paths,
        trainable_models=args.trainable_models,
        lora_base_model=args.lora_base_model,
        lora_target_modules=args.lora_target_modules,
        lora_rank=args.lora_rank,
        lora_checkpoint=latest_ckpt if args.lora_base_model is not None else None,
        use_gradient_checkpointing=args.use_gradient_checkpointing,
        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
        extra_inputs=args.extra_inputs,
        multi_res_noise=args.multi_res_noise,
        deterministic_flow=args.deterministic_flow,
        extra_loss=args.extra_loss,
        depth_normalization=args.depth_normalization,
        matting_prompt=args.matting_prompt if args.task == "matting" else None,
    )

    if can_resume:
        if latest_ckpt is not None:
            if args.lora_base_model is None:
                state_dict = load_state_dict(latest_ckpt)
                model.pipe.dit.load_state_dict(state_dict, strict=False)
                del state_dict
            # Step count is parsed from a checkpoint name of the form "...step-<N>.<ext>".
            args.resume_steps = int(latest_ckpt.split("step-")[-1].split(".")[0])
            print(f"Resumed training from step {args.resume_steps}")
            torch.cuda.empty_cache()
        else:
            print(f"No checkpoint found in {args.output_path}, starting fresh training.")
            args.resume_steps = 0
    else:
        args.resume_steps = 0

    model_logger = ModelLogger(
        args.output_path,
        remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
        state_dict_converter=FluxLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else lambda x: x,
        args=args
    )
    launch_training_task(dataset, model, model_logger, dataset_sampler=mixed_sampler, args=args)
|