Instructions to use zeyuren2002/EvalMDE with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use zeyuren2002/EvalMDE with Diffusers:
pip install -U diffusers transformers accelerate
import torch from diffusers import DiffusionPipeline # switch to "mps" for apple devices pipe = DiffusionPipeline.from_pretrained("zeyuren2002/EvalMDE", dtype=torch.bfloat16, device_map="cuda") prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" image = pipe(prompt).images[0] - Notebooks
- Google Colab
- Kaggle
| import hydra | |
| import pytorch_lightning as pl | |
| import os | |
| from omegaconf import DictConfig | |
| from ppd.entrys.utils import get_data, get_model, get_callbacks, print_cfg, find_last_ckpt_path | |
| from typing import Tuple | |
| from tqdm.auto import tqdm | |
| from ppd.utils.logger import Log | |
| def setup_trainer(cfg: DictConfig) -> Tuple[pl.Trainer, pl.LightningModule, pl.LightningDataModule]: | |
| """ | |
| Set up the PyTorch Lightning trainer, model, and data module. | |
| """ | |
| if cfg.print_cfg: print_cfg(cfg, use_rich=True) | |
| pl.seed_everything(cfg.seed) | |
| # preparation | |
| datamodule = get_data(cfg, wo_train=True) | |
| model = get_model(cfg) | |
| if cfg.pretrained_model: | |
| model.load_pretrained_model_eval(cfg.pretrained_model) | |
| else: | |
| raise FileNotFoundError("Pretrained model not found. Please specify 'pretrained_model' path in config.") | |
| # PL callbacks and logger | |
| callbacks = get_callbacks(cfg) | |
| cfg_logger = DictConfig.copy(cfg.logger) | |
| cfg_logger.update({'version': 'val_metrics'}) | |
| logger = hydra.utils.instantiate(cfg_logger, _recursive_=False) | |
| # PL-Trainer | |
| trainer = pl.Trainer( | |
| accelerator="gpu", | |
| logger=logger if logger is not None else False, | |
| callbacks=callbacks, | |
| **cfg.pl_trainer, | |
| ) | |
| return trainer, model, datamodule | |
| def val(cfg: DictConfig) -> None: | |
| """ | |
| Validate the model. | |
| """ | |
| trainer, model, datamodule = setup_trainer(cfg) | |
| trainer.validate(model, datamodule.val_dataloader()) | |
| def predict(cfg: DictConfig) -> None: | |
| """ | |
| Predict using the model. | |
| """ | |
| trainer, model, datamodule = setup_trainer(cfg) | |
| trainer.predict(model, datamodule.val_dataloader()) | |