| |
| |
|
|
| from cog import BasePredictor, Input, Path |
| import os |
| import time |
| import torch |
| import numpy as np |
| from typing import List |
| from transformers import CLIPImageProcessor |
| from diffusers import ( |
| StableDiffusionXLPipeline, |
| DPMSolverMultistepScheduler, |
| DDIMScheduler, |
| HeunDiscreteScheduler, |
| EulerAncestralDiscreteScheduler, |
| EulerDiscreteScheduler, |
| PNDMScheduler |
| ) |
| from diffusers.pipelines.stable_diffusion.safety_checker import ( |
| StableDiffusionSafetyChecker, |
| ) |
|
|
class KarrasDPM:
    """Scheduler factory: DPMSolverMultistepScheduler with Karras sigmas enabled.

    Exposes a ``from_config`` constructor like the diffusers scheduler classes so
    it can be stored in ``SCHEDULERS`` and used interchangeably with them.
    """

    @staticmethod
    def from_config(config, **kwargs):
        """Build a DPMSolverMultistepScheduler from ``config`` using Karras sigmas.

        Args:
            config: Scheduler config, e.g. ``pipe.scheduler.config``.
            **kwargs: Forwarded to ``DPMSolverMultistepScheduler.from_config``.
                ``use_karras_sigmas`` defaults to True but may be overridden.

        Returns:
            The newly constructed scheduler instance.
        """
        # @staticmethod: the original had no `self`, so calling through an
        # instance would have bound the instance to `config`.
        kwargs.setdefault("use_karras_sigmas", True)
        return DPMSolverMultistepScheduler.from_config(config, **kwargs)
|
|
# Lookup table from the user-facing `scheduler` choice in Predictor.predict to
# a class exposing `from_config(config)`. Insertion order is the order in which
# the choices are offered, since predict() derives its choices from these keys.
SCHEDULERS = dict(
    DDIM=DDIMScheduler,
    DPMSolverMultistep=DPMSolverMultistepScheduler,
    HeunDiscrete=HeunDiscreteScheduler,
    KarrasDPM=KarrasDPM,
    K_EULER_ANCESTRAL=EulerAncestralDiscreteScheduler,
    K_EULER=EulerDiscreteScheduler,
    PNDM=PNDMScheduler,
)
|
|
# Model id handed to StableDiffusionXLPipeline.from_pretrained in setup();
# downloaded files are cached under MODEL_CACHE.
MODEL_NAME = "artificialguybr/NebulRedmond"

# Local directories read by setup(): the pipeline download cache, then the
# (presumably pre-populated) safety-checker and feature-extractor weights.
MODEL_CACHE, SAFETY_CACHE, FEATURE_EXTRACTOR = (
    "model-cache",
    "safety-cache",
    "feature-extractor",
)
|
|
class Predictor(BasePredictor):
    """Cog predictor wrapping an SDXL text-to-image pipeline plus a safety checker."""

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        start = time.time()
        print("Loading safety checker...")
        self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
            SAFETY_CACHE, torch_dtype=torch.float16
        ).to("cuda")
        self.feature_extractor = CLIPImageProcessor.from_pretrained(FEATURE_EXTRACTOR)
        print("Loading txt2img model")
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float16,
            use_safetensors=True,
            cache_dir=MODEL_CACHE,
        ).to("cuda")
        print("setup took: ", time.time() - start)

    def run_safety_checker(self, image):
        """Run the safety checker over a batch of generated images.

        Args:
            image: Iterable of images; presumably the PIL images from the
                pipeline's ``.images`` output (TODO confirm).

        Returns:
            ``(images, has_nsfw_concept)`` exactly as returned by the safety
            checker; ``has_nsfw_concept`` is indexed per image by the caller.
        """
        safety_checker_input = self.feature_extractor(image, return_tensors="pt").to(
            "cuda"
        )
        np_image = [np.array(val) for val in image]
        image, has_nsfw_concept = self.safety_checker(
            images=np_image,
            # Match the float16 weights loaded in setup().
            clip_input=safety_checker_input.pixel_values.to(torch.float16),
        )
        return image, has_nsfw_concept

    @torch.inference_mode()
    def predict(
        self,
        prompt: str = Input(
            description="Input prompt",
            default="An astronaut riding a rainbow unicorn",
        ),
        negative_prompt: str = Input(
            description="Input Negative Prompt",
            default="",
        ),
        width: int = Input(
            description="Width of output image",
            default=1024,
        ),
        height: int = Input(
            description="Height of output image",
            default=1024,
        ),
        num_outputs: int = Input(
            description="Number of images to output.",
            ge=1,
            le=4,
            default=1,
        ),
        scheduler: str = Input(
            description="scheduler",
            # Cog documents `choices` as a list; pass one rather than a dict_keys view.
            choices=list(SCHEDULERS.keys()),
            default="K_EULER",
        ),
        num_inference_steps: int = Input(
            description="Number of denoising steps", ge=1, le=100, default=40
        ),
        guidance_scale: float = Input(
            description="Scale for classifier-free guidance", ge=1, le=20, default=7.5
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed", default=None
        ),
        apply_watermark: bool = Input(
            description="Applies a watermark to enable determining if an image is generated in downstream applications. If you have other provisions for generating or deploying images safely, you can use this to disable watermarking.",
            default=True,
        ),
        disable_safety_checker: bool = Input(
            description="Disable safety checker for generated images. This feature is only available through the API. See [https://replicate.com/docs/how-does-replicate-work#safety](https://replicate.com/docs/how-does-replicate-work#safety)",
            default=False,
        ),
    ) -> List[Path]:
        """Run a single prediction on the model.

        Generates ``num_outputs`` images and saves each one that passes the
        safety checker (unless disabled) to ``/tmp/out-{i}.png``.

        Returns:
            Paths of the saved images.

        Raises:
            Exception: if every generated image was flagged as NSFW.
        """
        if seed is None:
            # 3 random bytes -> seed in [0, 2**24).
            seed = int.from_bytes(os.urandom(3), "big")
        print(f"Using seed: {seed}")
        generator = torch.Generator("cuda").manual_seed(seed)

        pipe = self.pipe
        # The pipeline is shared across predictions, so the scheduler is
        # rebuilt from the current config on every call.
        pipe.scheduler = SCHEDULERS[scheduler].from_config(pipe.scheduler.config)

        common_args = {
            "prompt": [prompt] * num_outputs,
            "negative_prompt": [negative_prompt] * num_outputs,
            "guidance_scale": guidance_scale,
            "generator": generator,
            "num_inference_steps": num_inference_steps,
        }
        sdxl_kwargs = {"width": width, "height": height}

        # Temporarily detach the watermarker. It is restored in `finally` so a
        # failed generation cannot leave watermarking disabled on the shared
        # pipeline for every later prediction.
        watermark_cache = None
        if not apply_watermark:
            watermark_cache = pipe.watermark
            pipe.watermark = None
        try:
            output = pipe(**common_args, **sdxl_kwargs)
        finally:
            if not apply_watermark:
                pipe.watermark = watermark_cache

        if disable_safety_checker:
            has_nsfw_content = [False] * len(output.images)
        else:
            _, has_nsfw_content = self.run_safety_checker(output.images)

        output_paths = []
        for i, image in enumerate(output.images):
            if has_nsfw_content[i]:
                print(f"NSFW content detected in image {i}")
                continue
            output_path = f"/tmp/out-{i}.png"
            image.save(output_path)
            output_paths.append(Path(output_path))

        if not output_paths:
            raise Exception(
                "NSFW content detected. Try running it again, or try a different prompt."
            )

        return output_paths