Instructions to use zeyuren2002/EvalMDE with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use zeyuren2002/EvalMDE with Diffusers:
```bash
pip install -U diffusers transformers accelerate
```
```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("zeyuren2002/EvalMDE", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
| import torch, os, json | |
| from torch.utils.data import ConcatDataset | |
| import numpy as np | |
| from models.utils import load_state_dict,DiffusionTrainingModule, ModelLogger, launch_training_task, flux_parser, parse_flux_model_configs, find_latest_checkpoint | |
| from pipelines.flux_image_new import FluxImagePipeline, ModelConfig | |
| from lora.flux_lora import FluxLoRAConverter | |
| from models.unified_dataset import UnifiedDataset | |
| from utils.mixed_sampler import MixedBatchSampler | |
| from utils.visualize import visualize_sample,prepare_image | |
# Turn off HuggingFace tokenizers' internal thread pool for this process.
# Presumably done to avoid the tokenizers fork/parallelism warning once
# DataLoader workers are spawned — NOTE(review): confirm intent.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})
# Debugging aid for NaN/Inf gradients (slow; keep disabled in normal runs):
# torch.autograd.set_detect_anomaly(True)
class FluxTrainingModule(DiffusionTrainingModule):
    """Training wrapper around a ``FluxImagePipeline``.

    Loads the pipeline in bfloat16 on CUDA, switches the requested sub-models
    to training mode (optionally via LoRA), and converts dataset samples into
    the keyword arguments consumed by ``pipe.training_loss``.
    """

    def __init__(
        self,
        model_paths=None, model_id_with_origin_paths=None,
        trainable_models=None,
        lora_base_model=None, lora_target_modules="a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp", lora_rank=32, lora_checkpoint=None,
        use_gradient_checkpointing=True,
        use_gradient_checkpointing_offload=False,
        extra_inputs=None,
        multi_res_noise=False,
        deterministic_flow=False,
        extra_loss=None,
        depth_normalization="log",
        matting_prompt=None,
    ):
        """Build the pipeline and store the training options.

        Args:
            model_paths: Root path handed to ``parse_flux_model_configs``.
            model_id_with_origin_paths: Accepted for interface compatibility;
                not consulted (model configs come from ``model_paths``).
            trainable_models, lora_base_model, lora_target_modules, lora_rank,
            lora_checkpoint: Forwarded to ``switch_pipe_to_training_mode``.
            use_gradient_checkpointing, use_gradient_checkpointing_offload,
            multi_res_noise, deterministic_flow, extra_loss,
            depth_normalization: Stored and forwarded to the pipeline units
                on every step.
            extra_inputs: Comma-separated names of additional sample keys to
                forward to the pipeline. Whitespace and empty entries are
                ignored. Names starting with ``"controlnet_"`` are not
                forwarded, and a warning is emitted for them.
            matting_prompt: If given and the DiT has a ``coord_encoder``,
                overrides ``coord_encoder.matting_prompt``.
        """
        import warnings

        super().__init__()
        # Model configs are resolved from model_paths only.
        model_configs = parse_flux_model_configs(root_path=model_paths)
        self.pipe = FluxImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cuda", model_configs=model_configs)
        # Update matting_prompt if provided.
        if matting_prompt is not None and hasattr(self.pipe.dit, 'coord_encoder'):
            self.pipe.dit.coord_encoder.matting_prompt = matting_prompt
            print(f"Updated coord_encoder.matting_prompt to: {matting_prompt}")
        # Training mode.
        self.switch_pipe_to_training_mode(
            self.pipe, trainable_models,
            lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
            enable_fp8_training=False,
        )
        # Options forwarded to the pipeline units on every step.
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
        # Tolerate "a, b" and stray commas / an empty string instead of
        # producing keys like " b" or "" that fail later with a KeyError.
        self.extra_inputs = (
            [name.strip() for name in extra_inputs.split(",") if name.strip()]
            if extra_inputs is not None else []
        )
        ignored = [name for name in self.extra_inputs if name.startswith("controlnet_")]
        if ignored:
            # Previously these were collected into an unused dict and silently
            # dropped. Surface that once instead of per step.
            warnings.warn(
                f"extra_inputs {ignored} start with 'controlnet_' and are not "
                "forwarded to the pipeline by FluxTrainingModule; they will be ignored.",
                stacklevel=2,
            )
        self.multi_res_noise = multi_res_noise
        self.deterministic_flow = deterministic_flow
        self.extra_loss = extra_loss
        self.depth_normalization = depth_normalization

    def forward_preprocess(self, data):
        """Run the pipeline units on one training sample.

        Args:
            data: Sample dict with at least ``"image"`` and ``"prompt"``.
                ``"mask"`` is optional. Every non-``controlnet_`` name in
                ``self.extra_inputs`` must be present. A 3-D ``"image"`` gets a
                leading batch dimension, written back into ``data`` in place.

        Returns:
            Dict merging the shared and positive-prompt inputs produced by the
            pipeline units.
        """
        # CFG-sensitive parameters.
        inputs_posi = {"prompt": data["prompt"]}
        inputs_nega = {"negative_prompt": ""}
        # CFG-insensitive parameters. Height/width are read from dims 2/3
        # below, so the image is expected batched; add the batch dim if missing.
        if data["image"].ndim == 3:
            data["image"] = data["image"].unsqueeze(0)
        image = data["image"]
        batch_size = image.shape[0]
        scheduler = self.pipe.scheduler
        if self.deterministic_flow:
            # Always train at the first scheduler timestep.
            timestep = scheduler.timesteps[0].repeat(batch_size)
        else:
            # One uniformly drawn scheduler timestep per sample.
            timestep_ids = torch.randint(0, scheduler.num_train_timesteps, size=(batch_size,))
            timestep = scheduler.timesteps[timestep_ids]
        timestep = timestep.to(self.pipe.device, dtype=self.pipe.torch_dtype)
        inputs_shared = {
            # Inputs describing the sample.
            "input_image": image,
            "mask": data.get("mask", None),
            "height": image.shape[2],
            "width": image.shape[3],
            # Fixed pipeline settings; change only if you know the consequences.
            "seed": 42,
            "cfg_scale": 1,
            "embedded_guidance": 1,
            "t5_sequence_length": 512,
            "tiled": False,
            "rand_device": self.pipe.device,
            "use_gradient_checkpointing": self.use_gradient_checkpointing,
            "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
            "multi_res_noise": self.multi_res_noise,
            "timestep": timestep,
            "deterministic_flow": self.deterministic_flow,
            "extra_loss": self.extra_loss,
            "depth_normalization": self.depth_normalization
        }
        # Extra inputs.
        for name in self.extra_inputs:
            if name.startswith("controlnet_"):
                # Not wired into this pipeline call (warned about in __init__).
                continue
            inputs_shared[name] = data[name]
        # Pipeline units process and augment the inputs in order.
        for unit in self.pipe.units:
            inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
        return {**inputs_shared, **inputs_posi}

    def forward(self, data, inputs=None):
        """Return the training loss for ``data``.

        Args:
            data: Sample dict (see ``forward_preprocess``).
            inputs: Optional already-preprocessed inputs. When ``None`` they
                are computed from ``data``.
        """
        if inputs is None:
            inputs = self.forward_preprocess(data)
        models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
        loss = self.pipe.training_loss(**models, **inputs)
        return loss
def _infer_task(metadata_path):
    """Infer the training task from a metadata path.

    The keywords "depth", "normal" and "matting" are checked, in that order,
    as substrings of ``metadata_path``; the first match wins.

    Raises:
        ValueError: if none of the keywords occurs in the path.
    """
    for task in ("depth", "normal", "matting"):
        if task in metadata_path:
            return task
    raise ValueError(
        "Cannot infer task from metadata path; please include 'depth', 'normal' or 'matting' in the path."
    )


# Per-dataset sampling probabilities for MixedBatchSampler,
# keyed by (task, number of datasets).
_MIXING_PROBS = {
    ("depth", 2): [0.9, 0.1],
    ("normal", 3): [0.5, 0.45, 0.05],
    ("matting", 2): [0.5, 0.5],
    ("matting", 3): [0.02, 0.65, 0.33],
    ("matting", 4): [0.22, 0.22, 0.3, 0.26],
}


def _mixing_probs(task, num_datasets):
    """Return one sampling probability per dataset for ``task``.

    Uses the preset for ``(task, num_datasets)`` when one exists. Otherwise it
    falls back to uniform weights, so the list always matches the number of
    datasets.
    """
    preset = _MIXING_PROBS.get((task, num_datasets))
    if preset is not None:
        return list(preset)
    print(f"No preset mixing probabilities for task '{task}' with {num_datasets} datasets; sampling them uniformly.")
    return [1.0 / num_datasets] * num_datasets


def _build_dataset(args, base_path, metadata_path, height, width, pdf):
    """Construct one UnifiedDataset from the command-line options."""
    is_matting = args.task == "matting"
    return UnifiedDataset(
        base_path=base_path,
        metadata_path=metadata_path,
        repeat=args.dataset_repeat,
        data_file_keys=args.data_file_keys.split(","),
        main_data_operator=UnifiedDataset.default_image_operator(
            base_path=base_path,
            max_pixels=args.max_pixels,
            height=height,
            width=width,
            height_division_factor=32,
            width_division_factor=32,
            using_log=args.using_log,
            using_sqrt=args.using_sqrt,
            using_sqrt_disp=args.using_sqrt_disp,
            using_pdf=args.using_pdf,
            pdf=pdf,
            with_mask=args.with_mask,
        ),
        special_operator_map=["mask", "prompt"] if args.with_mask else ["prompt"],
        default_caption=args.default_caption,
        matting_prompt=args.matting_prompt if is_matting else None,
        use_coor_input=args.use_coor_input if is_matting else False,
        use_camera_intrinsics=args.use_camera_intrinsics,
    )


def _log_example_item(dataset):
    """Print the shapes of ``dataset[0]``; errors are reported, not raised."""
    try:
        example = dataset[0]
        img_shape = example["image"].shape if "image" in example else None
        kontext = example.get("kontext_images", None)
        if isinstance(kontext, list):
            ktx_shape = [ktx.shape for ktx in kontext]
        else:
            ktx_shape = kontext.shape if kontext is not None else None
        mask_obj = example.get("mask", None)
        mask_shape = mask_obj.shape if mask_obj is not None else None
        # A present-but-None prompt should not break the slice below.
        prompt_val = example.get("prompt") or ""
        print(f"Example data item: image={img_shape}, kontext_images={ktx_shape}, mask={mask_shape}, prompt={prompt_val[:80]}")
    except Exception as e:  # best-effort diagnostics only
        print(f"Failed to print example data item: {e}")


def _load_eval_file_list(path, task):
    """Read an evaluation file list and resolve its entries.

    Only the first whitespace-separated field of each non-blank line is used.
    Relative entries are joined onto a task-specific dataset root; absolute
    entries are kept as-is.

    Raises:
        ValueError: for a task other than depth/normal/matting.
    """
    if task in ("depth", "normal"):
        base_dir = f"/mnt/nfs/workspace/syq/dataset/Eval/{task}/nyuv2"
    elif task == "matting":
        base_dir = "/mnt/nfs/workspace/syq/dataset/matting/P3M-10k"
    else:
        raise ValueError(f"Unknown task {task}")
    with open(path, "r", encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing newline) instead of raising IndexError.
        entries = [line.split()[0] for line in f if line.strip()]
    print(f"Loaded {len(entries)} evaluation files.")
    return [x if os.path.isabs(x) else os.path.join(base_dir, x) for x in entries]


def _latest_resume_checkpoint(args):
    """Return the newest checkpoint in ``args.output_path`` when resuming, else None."""
    if args.resume and os.path.isdir(args.output_path):
        return find_latest_checkpoint(args.output_path)
    return None


def main():
    """Parse CLI arguments, build datasets and model, and launch training."""
    parser = flux_parser()
    args = parser.parse_args()

    metadata_paths = args.dataset_metadata_path.split(",")
    dataset_base_paths = args.dataset_base_path.split(",")
    # A single base path is shared by all metadata files. Any other mismatch
    # used to be silently truncated by zip(), dropping datasets.
    if len(dataset_base_paths) == 1 and len(metadata_paths) > 1:
        dataset_base_paths = dataset_base_paths * len(metadata_paths)
    if len(dataset_base_paths) != len(metadata_paths):
        raise ValueError(
            f"Got {len(dataset_base_paths)} dataset base paths for {len(metadata_paths)} metadata paths; "
            "pass one base path per metadata path (or a single shared one)."
        )

    args.task = _infer_task(metadata_paths[0])
    print(f"!!! Doing {args.task} task !!!")

    if args.using_pdf:
        print("!!! Using PDF operator for image loading !!!")
        pdf = np.load("depth_mapping_lookup_table.npz")
    else:
        pdf = None

    # Every dataset uses the same target resolution (previously this was
    # hard-coded per count, which silently dropped datasets beyond four).
    height, width = args.height, args.width
    datasets_ls = []
    for base_path, metadata_path in zip(dataset_base_paths, metadata_paths):
        dataset = _build_dataset(args, base_path, metadata_path, height, width, pdf)
        print(f"Loading {metadata_path} with {len(dataset)} items of size Height {height} x Width {width}")
        _log_example_item(dataset)
        datasets_ls.append(dataset)
    print(f"Total datasets loaded: {len(datasets_ls)}")

    if len(datasets_ls) > 1:
        dataset = ConcatDataset(datasets_ls)
        prob = _mixing_probs(args.task, len(datasets_ls))
        mixed_sampler = MixedBatchSampler(datasets_ls, shuffle=True, batch_size=args.batch_size, drop_last=True, prob=prob)
        print(f"using {len(datasets_ls)} datasets, total length: {len(dataset)} with PROB:{prob}")
    else:
        dataset = datasets_ls[0]
        mixed_sampler = None

    print(args.eval_file_list)
    if args.eval_file_list:
        eval_file_list = _load_eval_file_list(args.eval_file_list, args.task)
        args.eval_file_list = eval_file_list
        print(f"top 5 evaluation files: {eval_file_list[:5]}")

    # Resolve the resume checkpoint once and use it for both LoRA and full fine-tuning.
    latest_ckpt = _latest_resume_checkpoint(args)
    use_lora = args.lora_base_model is not None
    model = FluxTrainingModule(
        model_paths=args.model_paths,
        model_id_with_origin_paths=args.model_id_with_origin_paths,
        trainable_models=args.trainable_models,
        lora_base_model=args.lora_base_model,
        lora_target_modules=args.lora_target_modules,
        lora_rank=args.lora_rank,
        lora_checkpoint=latest_ckpt if use_lora else None,
        use_gradient_checkpointing=args.use_gradient_checkpointing,
        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
        extra_inputs=args.extra_inputs,
        multi_res_noise=args.multi_res_noise,
        deterministic_flow=args.deterministic_flow,
        extra_loss=args.extra_loss,
        depth_normalization=args.depth_normalization,
        matting_prompt=args.matting_prompt if args.task == "matting" else None,
    )

    if latest_ckpt is not None:
        if not use_lora:
            # Without LoRA, restore the DiT weights directly. With LoRA, the
            # checkpoint was passed as lora_checkpoint above (presumably loaded
            # by switch_pipe_to_training_mode — NOTE(review): confirm).
            state_dict = load_state_dict(latest_ckpt)
            model.pipe.dit.load_state_dict(state_dict, strict=False)
            del state_dict
        args.resume_steps = int(latest_ckpt.split("step-")[-1].split(".")[0])
        print(f"Resumed training from step {args.resume_steps}")
        torch.cuda.empty_cache()
    else:
        if args.resume and os.path.isdir(args.output_path):
            print(f"No checkpoint found in {args.output_path}, starting fresh training.")
        args.resume_steps = 0

    model_logger = ModelLogger(
        args.output_path,
        remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
        state_dict_converter=FluxLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else lambda x: x,
        args=args,
    )
    launch_training_task(dataset, model, model_logger, dataset_sampler=mixed_sampler, args=args)


if __name__ == "__main__":
    main()