Instructions to use zeyuren2002/EvalMDE with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use zeyuren2002/EvalMDE with Diffusers:
pip install -U diffusers transformers accelerate
import torch from diffusers import DiffusionPipeline # switch to "mps" for apple devices pipe = DiffusionPipeline.from_pretrained("zeyuren2002/EvalMDE", dtype=torch.bfloat16, device_map="cuda") prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" image = pipe(prompt).images[0] - Notebooks
- Google Colab
- Kaggle
| # copy from EasyVolumetricVideo | |
| # Author: Zhen Xu https://github.com/dendenxu | |
| from typing import Callable, List, Dict | |
| from multiprocessing.pool import ThreadPool | |
| from tqdm import tqdm | |
| from threading import Thread | |
| import torch | |
| from torch import Tensor | |
| from typing import List | |
| import imageio | |
| import os | |
| import asyncio | |
| from functools import wraps | |
| def async_call_func(func): | |
| async def wrapper(*args, **kwargs): | |
| loop = asyncio.get_event_loop() | |
| # Use run_in_executor to run the blocking function in a separate thread | |
| return await loop.run_in_executor(None, func, *args, **kwargs) | |
| return wrapper | |
| def _save_image_impl(save_img, save_path): | |
| """Common implementation for saving images synchronously or asynchronously""" | |
| os.makedirs(os.path.dirname(save_path), exist_ok=True) | |
| imageio.imwrite(save_path, save_img) | |
| def cat_dict(dict_list, dim=0, reserved_keys=['select_indice', 'dens_volume_fine', 'z_vals_fine', 'z_vals_mid_fine', 'points_fine', 'dens_volume', 'prob_volume', 't_vals', 't2z_func', 'img_feat', 't_mean_std_min_max']): | |
| return {k: torch.cat([item[k] for item in dict_list], dim) for k in dict_list[0] if isinstance(dict_list[0][k], Tensor) and k not in reserved_keys} | |
| def cat_list(list_list: List[List[Tensor]], dim: int = 0): | |
| return [torch.cat([item[i] for item in list_list], dim=dim) for i in range(len(list_list[0]))] | |
| def cat_tensor(tensor_list: List[Tensor], dim=0): | |
| return torch.cat(tensor_list, dim=dim) | |
| slice_func = lambda chunk_index, chunk_dim, chunk_size: [slice(None)] * chunk_dim + [slice(chunk_index, chunk_index+chunk_size)] | |
| def chunkify(func, cat_func, chunk_tensors: List[Tensor], chunk_dim: int, chunk_size: int, **kwargs): | |
| ''' | |
| func: function to be chunkified | |
| cat: function to concatenate the results | |
| chunk_tensors: list of tensors to be chunkified | |
| chunk_dim: dimension to be chunkified | |
| chunk_size: size of each chunk | |
| ''' | |
| total_chunk_size = chunk_tensors[0].shape[chunk_dim] | |
| assert all([total_chunk_size == chunk_tensors[i].shape[chunk_dim] for i in range(1, len(chunk_tensors))]) | |
| return cat_func([func(*[chunk_tensor[slice_func(i, chunk_dim, chunk_size)] for chunk_tensor in chunk_tensors], **kwargs) for i in range(0, total_chunk_size, chunk_size)], chunk_dim) | |
| all_ret = {} | |
| for i in range(0, chunk_size, chunk_num): | |
| ret = func(*[chunk_tensor[:, i:i + chunk_num] for chunk_tensor in chunk_tensors], **kwargs) | |
| if isinstance(ret, dict): | |
| for k in ret: | |
| if ret[k] is None: | |
| continue | |
| if k not in all_ret: | |
| all_ret[k] = [] | |
| all_ret[k].append(ret[k]) | |
| elif isinstance(ret, list) or isinstance(ret, tuple): | |
| for k in range(len(ret)): | |
| if ret[k] is None: | |
| continue | |
| if k not in all_ret: | |
| all_ret[k] = [] | |
| all_ret[k].append(ret[k]) | |
| elif isinstance(ret, torch.Tensor): | |
| if 0 not in all_ret: | |
| all_ret[0] = [] | |
| all_ret[0].append(ret) | |
| if isinstance(ret, dict): | |
| return {k: torch.cat(all_ret[k], dim=chunk_dim) for k in all_ret} | |
| elif isinstance(ret, list) or isinstance(ret, tuple): | |
| return [torch.cat(all_ret[k], dim=chunk_dim) for k in all_ret] | |
| elif isinstance(ret, torch.Tensor): | |
| return torch.cat(all_ret[0], dim=chunk_dim) | |
| def async_call(fn): | |
| def wrapper(*args, **kwargs): | |
| Thread(target=fn, args=args, kwargs=kwargs).start() | |
| return wrapper | |
| def save_image_async(save_img, save_path): | |
| """Save image asynchronously""" | |
| _save_image_impl(save_img, save_path) | |
| def save_image(save_img, save_path): | |
| """Save image synchronously""" | |
| _save_image_impl(save_img, save_path) | |
| def parallel_execution(*args, action: Callable, num_processes=32, print_progress=False, sequential=False, async_return=False, desc=None, **kwargs): | |
| # NOTE: we expect first arg / or kwargs to be distributed | |
| # NOTE: print_progress arg is reserved | |
| def get_length(args: List, kwargs: Dict): | |
| for a in args: | |
| if isinstance(a, list): | |
| return len(a) | |
| for v in kwargs.values(): | |
| if isinstance(v, list): | |
| return len(v) | |
| raise NotImplementedError | |
| def get_action_args(length: int, args: List, kwargs: Dict, i: int): | |
| action_args = [(arg[i] if isinstance(arg, list) and len(arg) == length else arg) for arg in args] | |
| # TODO: Support all types of iterable | |
| action_kwargs = {key: (kwargs[key][i] if isinstance(kwargs[key], list) and len(kwargs[key]) == length else kwargs[key]) for key in kwargs} | |
| return action_args, action_kwargs | |
| if not sequential: | |
| # Create ThreadPool | |
| pool = ThreadPool(processes=num_processes) | |
| # Spawn threads | |
| results = [] | |
| asyncs = [] | |
| length = get_length(args, kwargs) | |
| for i in range(length): | |
| action_args, action_kwargs = get_action_args(length, args, kwargs, i) | |
| async_result = pool.apply_async(action, action_args, action_kwargs) | |
| asyncs.append(async_result) | |
| # Join threads and get return values | |
| if not async_return: | |
| for async_result in tqdm(asyncs, desc=desc, disable=not print_progress): | |
| results.append(async_result.get()) # will sync the corresponding thread | |
| pool.close() | |
| pool.join() | |
| return results | |
| else: | |
| return pool | |
| else: | |
| results = [] | |
| length = get_length(args, kwargs) | |
| for i in tqdm(range(length), desc=desc, disable=not print_progress): | |
| action_args, action_kwargs = get_action_args(length, args, kwargs, i) | |
| async_result = action(*action_args, **action_kwargs) | |
| results.append(async_result) | |
| return results | |