| import argparse |
| from pathlib import Path |
|
|
| import torch |
| from diffusers import StableDiffusionXLPipeline, HunyuanDiTPipeline |
|
|
| from predictor.inference.loader import load_predictor, denormalize_prediction |
| from predictor.inference.noise_selection import generate_noise_candidates, select_top_k_noise |
| from predictor.configs.model_dims import get_dims |
|
|
# Per-pipeline sampling defaults; used by main() to fill in --inference-step
# and --cfg when the user does not pass them on the command line.
GENERATION_DEFAULTS = {
    'sdxl': {'steps': 50, 'guidance_scale': 5.5},
    'hunyuan_dit': {'steps': 50, 'guidance_scale': 5.0},
}
|
|
|
|
def encode_prompt(pipe, prompt, model_type, device):
    """Encode *prompt* for both the noise predictor and the diffusion pipeline.

    Returns ``(pred_embeds, pred_mask, gen_kwargs)``: the first two feed the
    predictor, and ``gen_kwargs`` carries any precomputed embedding arguments
    for the pipeline call (empty when the pipeline should re-encode the raw
    prompt itself).

    Raises:
        ValueError: if *model_type* is not a supported pipeline name.
    """
    if model_type == 'sdxl':
        embeds, neg_embeds, pooled, neg_pooled = pipe.encode_prompt(
            prompt=prompt,
            prompt_2=prompt,
            device=device,
            num_images_per_prompt=1,
            do_classifier_free_guidance=True,
        )
        # The predictor mask covers the full (padded) sequence dimension.
        mask = torch.ones(embeds.shape[:2], device=device, dtype=torch.long)
        return embeds, mask, {
            'prompt_embeds': embeds,
            'negative_prompt_embeds': neg_embeds,
            'pooled_prompt_embeds': pooled,
            'negative_pooled_prompt_embeds': neg_pooled,
        }

    if model_type == 'hunyuan_dit':
        # Predictor inputs come from the secondary (T5) tokenizer/encoder.
        batch = pipe.tokenizer_2(
            prompt,
            max_length=get_dims(model_type)['seq_len'],
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        ).to(device)
        with torch.no_grad():
            encoded = pipe.text_encoder_2(
                batch.input_ids,
                attention_mask=batch.attention_mask,
            )
        # No gen kwargs: the HunyuanDiT pipeline re-encodes the prompt itself.
        return encoded[0].to(dtype=torch.float16), batch.attention_mask, {}

    raise ValueError(f"Unsupported model_type: {model_type}")
|
|
|
|
def get_args():
    """Parse command-line options for the noise-selection generation script."""
    parser = argparse.ArgumentParser()

    parser.add_argument('--pipeline', type=str, default='sdxl',
                        choices=['sdxl', 'hunyuan_dit'])
    parser.add_argument('--prompt', type=str, required=True,
                        help='Text prompt for image generation')
    parser.add_argument('--pretrained-path', type=str, required=True)
    parser.add_argument('--inference-step', type=int, default=None)
    parser.add_argument('--cfg', type=float, default=None)
    parser.add_argument('--N', type=int, default=100)
    parser.add_argument('--B', type=int, default=1)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--output-dir', type=str, default='output')

    return parser.parse_args()
|
|
|
|
def run_single_prompt(pipe, predictor, norm_info, prompt, args, dims, device, output_dir):
    """Generate images for one prompt using predictor-selected initial noise.

    Encodes the prompt, draws ``args.N`` candidate noise latents, keeps the
    ``args.B`` best according to *predictor*, runs the diffusion pipeline on
    those latents, and saves the resulting images as JPEGs in *output_dir*.
    """
    embeds, mask, gen_kwargs = encode_prompt(pipe, prompt, args.pipeline, device)

    rng = torch.Generator(device=device).manual_seed(args.seed)
    denoiser = pipe.unet if hasattr(pipe, 'unet') else pipe.transformer
    candidates = generate_noise_candidates(
        num_candidates=args.N,
        latent_shape=dims['latent_shape'],
        device=device,
        dtype=denoiser.dtype,
        generator=rng,
    )

    chosen, scores = select_top_k_noise(
        predictor=predictor,
        noises=candidates,
        prompt_embeds=embeds,
        prompt_mask=mask,
        num_select=args.B,
    )

    raw_scores = denormalize_prediction(scores, norm_info)
    target_name = norm_info.get('target', 'score')
    print(f'Prompt: {prompt}')
    print(f' PAINE predicted {target_name}: {[f"{s:.4f}" for s in raw_scores.tolist()]}')

    batch = chosen.shape[0]
    # Broadcast each per-prompt embedding tensor along the batch dimension so
    # it lines up with the selected latents.
    expanded = {
        key: (val.expand(batch, *[-1] * (val.dim() - 1))
              if isinstance(val, torch.Tensor) and val.dim() >= 2 else val)
        for key, val in gen_kwargs.items()
    }

    # NOTE(review): when gen_kwargs is empty (hunyuan_dit) a single raw prompt
    # is passed alongside B latents — presumably fine for B == 1; confirm the
    # pipeline broadcasts correctly for B > 1.
    result = pipe(
        prompt=None if gen_kwargs else prompt,
        **expanded,
        latents=chosen,
        height=1024,
        width=1024,
        num_images_per_prompt=1,
        num_inference_steps=args.inference_step,
        guidance_scale=args.cfg,
    )

    for idx, image in enumerate(result.images):
        image.save(output_dir / f'paine_{idx:02d}.jpg')
|
|
|
|
def main(args):
    """Build the requested diffusion pipeline, load the PAINE predictor, and
    run noise-selected generation for the single prompt in *args*.

    Args:
        args: parsed CLI namespace from ``get_args()``. ``inference_step`` and
            ``cfg`` are filled in from ``GENERATION_DEFAULTS`` when ``None``.

    Raises:
        ValueError: if ``args.pipeline`` is not a supported pipeline name.
    """
    # Fill in pipeline-specific sampling defaults for flags the user omitted.
    defaults = GENERATION_DEFAULTS.get(args.pipeline, {'steps': 50, 'guidance_scale': 5.5})
    if args.inference_step is None:
        args.inference_step = defaults['steps']
    if args.cfg is None:
        args.cfg = defaults['guidance_scale']

    dtype = torch.float16
    device = torch.device('cuda')

    if args.pipeline == 'sdxl':
        pipe = StableDiffusionXLPipeline.from_pretrained(
            'stabilityai/stable-diffusion-xl-base-1.0',
            variant='fp16', use_safetensors=True,
            torch_dtype=dtype).to(device)
    elif args.pipeline == 'hunyuan_dit':
        pipe = HunyuanDiTPipeline.from_pretrained(
            'Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers',
            torch_dtype=dtype).to(device)
    else:
        # Fail fast and consistently with encode_prompt(); without this,
        # an unknown pipeline left `pipe` unbound and surfaced later as a
        # confusing UnboundLocalError (argparse `choices` guards the CLI,
        # but main() may be called programmatically).
        raise ValueError(f"Unsupported model_type: {args.pipeline}")

    dims = get_dims(args.pipeline)
    predictor, norm_info = load_predictor(args.pretrained_path, device=device)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    run_single_prompt(pipe, predictor, norm_info, args.prompt, args, dims, device, output_dir)

    print(f'Done. Output: {output_dir}/')
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse CLI flags and run generation.
    main(get_args())
|
|