# Text-to-Image
# English
# PAINE / paine_inference.py
# joonghk's picture
# first commit
# 03de09d
import argparse
from pathlib import Path
import torch
from diffusers import StableDiffusionXLPipeline, HunyuanDiTPipeline
from predictor.inference.loader import load_predictor, denormalize_prediction
from predictor.inference.noise_selection import generate_noise_candidates, select_top_k_noise
from predictor.configs.model_dims import get_dims
# Per-pipeline sampling defaults; used by main() to fill in --inference-step
# and --cfg when the user leaves them unset on the command line.
GENERATION_DEFAULTS = {
    'sdxl': {'steps': 50, 'guidance_scale': 5.5},
    'hunyuan_dit': {'steps': 50, 'guidance_scale': 5.0},
}
def encode_prompt(pipe, prompt, model_type, device):
    """Encode *prompt* for both the PAINE predictor and the diffusion pipeline.

    Returns a ``(pred_embeds, pred_mask, gen_kwargs)`` triple:
    ``pred_embeds``/``pred_mask`` feed the noise-scoring predictor, while
    ``gen_kwargs`` holds any precomputed embedding kwargs the pipeline call
    should reuse (empty when the pipeline should re-encode the raw prompt).

    Raises ValueError for an unrecognized *model_type*.
    """
    if model_type == 'sdxl':
        # SDXL returns positive/negative token embeddings plus pooled variants
        # (CFG enabled), all of which are forwarded to the pipeline call.
        text_embeds, uncond_embeds, pooled_embeds, uncond_pooled = pipe.encode_prompt(
            prompt=prompt,
            prompt_2=prompt,
            device=device,
            num_images_per_prompt=1,
            do_classifier_free_guidance=True,
        )
        # No tokenizer mask available here, so treat every position as valid.
        attention_mask = torch.ones(
            text_embeds.shape[:2], device=device, dtype=torch.long)
        return text_embeds, attention_mask, {
            'prompt_embeds': text_embeds,
            'negative_prompt_embeds': uncond_embeds,
            'pooled_prompt_embeds': pooled_embeds,
            'negative_pooled_prompt_embeds': uncond_pooled,
        }
    if model_type == 'hunyuan_dit':
        # Tokenize with the secondary (T5) tokenizer, padded to the predictor's
        # expected sequence length.
        tokenized = pipe.tokenizer_2(
            prompt,
            max_length=get_dims(model_type)['seq_len'],
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        ).to(device)
        with torch.no_grad():
            encoder_out = pipe.text_encoder_2(
                tokenized.input_ids,
                attention_mask=tokenized.attention_mask,
            )
        # Empty gen_kwargs: the pipeline re-encodes the raw prompt itself.
        return encoder_out[0].to(dtype=torch.float16), tokenized.attention_mask, {}
    raise ValueError(f"Unsupported model_type: {model_type}")
def get_args():
    """Parse command-line options for PAINE-guided image generation."""
    p = argparse.ArgumentParser()
    p.add_argument('--pipeline', type=str, default='sdxl',
                   choices=['sdxl', 'hunyuan_dit'])
    p.add_argument('--prompt', type=str, required=True,
                   help='Text prompt for image generation')
    p.add_argument('--pretrained-path', type=str, required=True)
    p.add_argument('--inference-step', type=int, default=None)
    p.add_argument('--cfg', type=float, default=None)
    p.add_argument('--N', type=int, default=100)
    p.add_argument('--B', type=int, default=1)
    p.add_argument('--seed', type=int, default=0)
    p.add_argument('--output-dir', type=str, default='output')
    return p.parse_args()
def run_single_prompt(pipe, predictor, norm_info, prompt, args, dims, device, output_dir):
    """Generate args.B images for one prompt from PAINE-selected initial noise.

    Scores args.N random latent candidates with the predictor, keeps the top
    args.B, runs the diffusion pipeline from those latents, and saves the
    resulting images as JPEGs under *output_dir*.
    """
    pred_embeds, pred_mask, gen_kwargs = encode_prompt(
        pipe, prompt, args.pipeline, device,
    )
    # Seeded generator so the candidate noises are reproducible per --seed.
    generator = torch.Generator(device=device).manual_seed(args.seed)
    # SDXL exposes a UNet; DiT-style pipelines expose a transformer instead.
    denoiser = pipe.unet if hasattr(pipe, 'unet') else pipe.transformer
    noises = generate_noise_candidates(
        num_candidates=args.N,
        latent_shape=dims['latent_shape'],
        device=device,
        dtype=denoiser.dtype,
        generator=generator,
    )
    selected_noises, scores = select_top_k_noise(
        predictor=predictor,
        noises=noises,
        prompt_embeds=pred_embeds,
        prompt_mask=pred_mask,
        num_select=args.B,
    )
    # Map normalized predictor outputs back to the metric's original scale.
    raw_scores = denormalize_prediction(scores, norm_info)
    target_name = norm_info.get('target', 'score')
    print(f'Prompt: {prompt}')
    print(f' PAINE predicted {target_name}: {[f"{s:.4f}" for s in raw_scores.tolist()]}')
    B = selected_noises.shape[0]
    # Broadcast each cached embedding tensor (batch dim 1) to the selected
    # batch size so the pipeline sees one copy per latent; anything that is
    # not a >=2-D tensor passes through untouched.
    expanded_kwargs = {}
    for k, v in gen_kwargs.items():
        if isinstance(v, torch.Tensor) and v.dim() >= 2:
            expanded_kwargs[k] = v.expand(B, *[-1] * (v.dim() - 1))
        else:
            expanded_kwargs[k] = v
    # When gen_kwargs already carries prompt embeddings (SDXL branch), the
    # raw prompt must be None; otherwise (HunyuanDiT) pass the text prompt
    # and let the pipeline encode it itself.
    paine_result = pipe(
        prompt=None if gen_kwargs else prompt,
        **expanded_kwargs,
        latents=selected_noises,
        height=1024,
        width=1024,
        num_images_per_prompt=1,
        num_inference_steps=args.inference_step,
        guidance_scale=args.cfg,
    )
    for i, img in enumerate(paine_result.images):
        path = output_dir / f'paine_{i:02d}.jpg'
        img.save(path)
def main(args):
    """Build the requested diffusion pipeline, load the PAINE predictor, and
    generate images for ``args.prompt``."""
    # Fill in per-pipeline sampling defaults for options the user left unset.
    defaults = GENERATION_DEFAULTS.get(
        args.pipeline, {'steps': 50, 'guidance_scale': 5.5})
    if args.inference_step is None:
        args.inference_step = defaults['steps']
    if args.cfg is None:
        args.cfg = defaults['guidance_scale']

    device = torch.device('cuda')
    dtype = torch.float16
    if args.pipeline == 'sdxl':
        pipe = StableDiffusionXLPipeline.from_pretrained(
            'stabilityai/stable-diffusion-xl-base-1.0',
            variant='fp16',
            use_safetensors=True,
            torch_dtype=dtype,
        ).to(device)
    elif args.pipeline == 'hunyuan_dit':
        pipe = HunyuanDiTPipeline.from_pretrained(
            'Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers',
            torch_dtype=dtype,
        ).to(device)

    predictor, norm_info = load_predictor(args.pretrained_path, device=device)
    dims = get_dims(args.pipeline)

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    run_single_prompt(pipe, predictor, norm_info, args.prompt,
                      args, dims, device, out_dir)
    print(f'Done. Output: {out_dir}/')
if __name__ == '__main__':
    # Script entry point: parse CLI options and run generation.
    main(get_args())