|
|
|
|
|
|
| from typing import List, Union |
| import numpy as np |
| import axengine |
| import torch |
| from PIL import Image |
| from transformers import CLIPTokenizer, PreTrainedTokenizer |
| import time |
| import argparse |
| import uuid |
| import os |
| import traceback |
| from diffusers import DPMSolverMultistepScheduler |
|
|
| |
# Master switch for debug_log(): when False, debug_log() prints nothing.
DEBUG_MODE = True
# When True, debug_log() prefixes every line with a "[YYYY-MM-DD HH:MM:SS] " stamp.
LOG_TIMESTAMP = True
|
|
def debug_log(msg):
    """Print *msg* as a ``[DEBUG]`` line on stdout.

    Does nothing when the module-level ``DEBUG_MODE`` flag is false. When
    ``LOG_TIMESTAMP`` is true the line is prefixed with the current local
    time formatted as ``[YYYY-MM-DD HH:MM:SS] ``.

    Args:
        msg: The message to print; it is interpolated into an f-string.
    """
    if not DEBUG_MODE:
        return

    prefix = ""
    if LOG_TIMESTAMP:
        prefix = f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] "
    print(f"{prefix}[DEBUG] {msg}")
|
|
def get_args(argv=None):
    """Parse the command-line options of the text-to-image script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what a
            plain ``get_args()`` call always did.

    Returns:
        argparse.Namespace with the attributes ``prompt``, ``text_model_dir``,
        ``unet_model``, ``vae_decoder_model``, ``time_input``, ``save_dir``,
        ``num_inference_steps``, ``guidance_scale`` and ``seed`` (``seed`` is
        ``None`` when the option is not given).

    Raises:
        SystemExit: raised by argparse itself for invalid arguments or
            ``--help``; raised with status 1 here if building or running the
            parser fails with any other exception.
    """
    try:
        parser = argparse.ArgumentParser(
            prog="StableDiffusion",
            description="Generate picture with the input prompt using DPM++ sampler"
        )
        parser.add_argument(
            "--prompt", type=str,
            default="masterpiece, best quality, 1girl, (colorful),(delicate eyes and face), volumatic light, ray tracing, bust shot ,extremely detailed CG unity 8k wallpaper,solo,smile,intricate skirt,((flying petal)),(Flowery meadow) sky, cloudy_sky, moonlight, moon, night, (dark theme:1.3), light, fantasy, windy, magic sparks, dark castle,white hair",
            help="the input text prompt")
        parser.add_argument("--text_model_dir", type=str, default="./models/",
                            help="Path to text encoder and tokenizer files")
        parser.add_argument("--unet_model", type=str, default="./models/unet.axmodel",
                            help="Path to unet axmodel model")
        parser.add_argument("--vae_decoder_model", type=str,
                            default="./models/vae_decoder.axmodel",
                            help="Path to vae decoder axmodel model")
        parser.add_argument("--time_input", type=str,
                            default="./models/time_input_dpmpp_20steps.npy",
                            help="Path to time input file")
        # The value is used as a directory (the image file name is generated
        # by the caller), so the help text says so.
        parser.add_argument("--save_dir", type=str, default="./txt2img_output_axe",
                            help="Directory in which the output image file is saved")
        parser.add_argument("--num_inference_steps", type=int, default=20,
                            help="Number of inference steps for DPM++ sampler")
        parser.add_argument("--guidance_scale", type=float, default=7.5,
                            help="Guidance scale for CFG")
        parser.add_argument("--seed", type=int, default=None, help="Random seed")
        return parser.parse_args(argv)
    except Exception as e:
        # argparse reports bad user input via SystemExit, which is not caught
        # here; this branch only covers unexpected internal failures.
        print(f"参数解析失败: {str(e)}")
        traceback.print_exc()
        # SystemExit instead of the site-provided exit() helper, which is not
        # guaranteed to exist (e.g. under `python -S`).
        raise SystemExit(1) from e
|
|
def get_embeds(prompt, negative_prompt, tokenizer_dir, text_encoder_dir):
    """Encode the negative and positive prompts with the text encoder.

    Both prompts are tokenized to a fixed length of 77 tokens (padded or
    truncated) and run through the ``sd15_text_encoder_sim.axmodel`` model
    found in ``text_encoder_dir``. The model is loaded once and reused for
    both prompts.

    Args:
        prompt: Positive prompt text.
        negative_prompt: Negative prompt text.
        tokenizer_dir: Directory passed to ``CLIPTokenizer.from_pretrained``.
        text_encoder_dir: Directory containing the text-encoder axmodel file.

    Returns:
        Tuple ``(neg_embeds, pos_embeds)``; each is the first output of the
        text encoder and has been checked to have shape ``(1, 77, 768)``.

    Raises:
        SystemExit: with status 1 after printing the error and traceback when
            anything fails (missing model file, unexpected embedding shape,
            tokenizer or inference errors).
    """
    try:
        debug_log(f"开始处理提示词: {prompt[:50]}...")
        start_time = time.time()

        tokenizer = CLIPTokenizer.from_pretrained(tokenizer_dir)

        # Locate and load the text encoder once; creating a new inference
        # session per prompt would load the same model twice.
        model_path = os.path.join(text_encoder_dir, "sd15_text_encoder_sim.axmodel")
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"文本编码器模型不存在: {model_path}")
        session = axengine.InferenceSession(model_path)

        def process_prompt(prompt_text):
            """Tokenize one prompt and return the encoder's first output."""
            inputs = tokenizer(
                prompt_text,
                padding="max_length",
                max_length=77,
                truncation=True,
                return_tensors="pt"
            )
            debug_log(f"Tokenizer输出形状: {inputs.input_ids.shape}")

            # The token ids are handed to the encoder as an int32 numpy array.
            input_ids = inputs.input_ids.numpy().astype(np.int32)
            outputs = session.run(None, {"input_ids": input_ids})[0]
            debug_log(f"文本编码器输出形状: {outputs.shape} | dtype: {outputs.dtype}")
            return outputs

        neg_embeds = process_prompt(negative_prompt)
        pos_embeds = process_prompt(prompt)
        debug_log(f"文本编码完成 | 耗时: {(time.time()-start_time):.2f}s")

        # Fail early if either embedding does not have the expected shape.
        expected_shape = (1, 77, 768)
        if neg_embeds.shape != expected_shape or pos_embeds.shape != expected_shape:
            raise ValueError(f"嵌入形状异常: 负面{neg_embeds.shape}, 正面{pos_embeds.shape}")

        return neg_embeds, pos_embeds
    except Exception as e:
        print(f"获取嵌入失败: {str(e)}")
        traceback.print_exc()
        # SystemExit instead of the site-provided exit() helper.
        raise SystemExit(1) from e
|
|
def _run_unet(session, latent, time_emb, text_embeds):
    """Run one UNet forward pass and return the first output of the session."""
    return session.run(None, {
        "sample": latent,
        "/down_blocks.0/resnets.0/act_1/Mul_output_0": time_emb,
        "encoder_hidden_states": text_embeds,
    })[0]


def main():
    """Run the full text-to-image pipeline and save the result as a PNG.

    Steps, in order:
      1. Parse the command line and seed torch / numpy.
      2. Check that the UNet, VAE decoder, tokenizer and text-encoder paths exist.
      3. Build a DPM-Solver++ scheduler (Karras sigmas) and load the UNet and
         VAE decoder inference sessions.
      4. Encode the prompt and a fixed negative prompt.
      5. Draw an initial latent of shape ``[1, 4, 60, 40]`` and load the
         precomputed time inputs from ``args.time_input``.
      6. For every scheduler timestep, run the UNet on the negative and the
         positive embeddings, combine them with classifier-free guidance and
         let the scheduler update the latent.
      7. Decode the latent with the VAE decoder, convert to uint8 RGB and save
         it under ``args.save_dir`` with a random UUID file name.

    Raises:
        SystemExit: with status 1 after printing the error and traceback when
            any stage fails.
    """
    try:
        debug_log("程序启动")
        args = get_args()
        debug_log(f"参数解析完成 | 随机种子: {args.seed} | 推理步数: {args.num_inference_steps}")

        # Compare against None so that an explicit `--seed 0` is honoured
        # instead of being replaced by a time-based seed.
        seed = args.seed if args.seed is not None else int(time.time())
        torch.manual_seed(seed)
        np.random.seed(seed)
        debug_log(f"随机种子设置完成: {seed}")

        # Verify the model files / directories before doing any heavy work.
        model_paths = [
            args.unet_model,
            args.vae_decoder_model,
            os.path.join(args.text_model_dir, 'tokenizer'),
            os.path.join(args.text_model_dir, 'text_encoder')
        ]
        for path in model_paths:
            if not os.path.exists(path):
                raise FileNotFoundError(f"模型路径不存在: {path}")

        debug_log("初始化调度器...")
        scheduler_start = time.time()
        scheduler = DPMSolverMultistepScheduler(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            algorithm_type="dpmsolver++",
            use_karras_sigmas=True
        )
        scheduler.set_timesteps(args.num_inference_steps)
        debug_log(f"调度器初始化完成 | 耗时: {(time.time()-scheduler_start):.2f}s")

        debug_log("加载NPU模型...")
        model_load_start = time.time()
        unet_session_main = axengine.InferenceSession(args.unet_model)
        vae_decoder = axengine.InferenceSession(args.vae_decoder_model)
        debug_log(f"模型加载完成 | 总耗时: {(time.time()-model_load_start):.2f}s")
        debug_log(f"UNET输入信息: {[str(inp) for inp in unet_session_main.get_inputs()]}")
        debug_log(f"VAE输入信息: {[str(inp) for inp in vae_decoder.get_inputs()]}")

        embed_start = time.time()
        neg_embeds, pos_embeds = get_embeds(
            args.prompt,
            "sketch, duplicate, ugly...",
            os.path.join(args.text_model_dir, 'tokenizer'),
            os.path.join(args.text_model_dir, 'text_encoder')
        )
        debug_log(f"提示词处理完成 | 总耗时: {(time.time()-embed_start):.2f}s")

        # Initial latent: seeded Gaussian noise scaled by the scheduler's
        # initial noise sigma, then handed to the NPU as float32 numpy.
        latents_shape = [1, 4, 60, 40]
        generator = torch.Generator().manual_seed(seed)
        latent = torch.randn(latents_shape, generator=generator)
        init_scale = scheduler.init_noise_sigma
        latent = latent * init_scale
        debug_log(f"潜在变量初始化 | 形状: {latent.shape} | 缩放系数: {init_scale}")
        latent = latent.numpy().astype(np.float32)
        debug_log(f"潜在变量转换完成 | dtype: {latent.dtype}")

        # NOTE(review): entry i of the time-input file is assumed to belong to
        # scheduler timestep i; truncating a longer file only makes sense if
        # that holds for the requested step count — confirm.
        debug_log(f"加载时间嵌入: {args.time_input}")
        time_data = np.load(args.time_input)
        if len(time_data) < args.num_inference_steps:
            raise ValueError(f"时间嵌入不足: 需要{args.num_inference_steps}, 实际{len(time_data)}")
        time_data = time_data[:args.num_inference_steps]
        debug_log(f"时间嵌入验证通过 | 形状: {time_data.shape}")

        debug_log("开始采样循环...")
        total_unet_time = 0
        for step_idx, timestep in enumerate(scheduler.timesteps.numpy().astype(np.int64)):
            step_start = time.time()
            debug_log(f"\n--- 步骤 {step_idx+1}/{args.num_inference_steps} [ts={timestep}] ---")

            try:
                # Abort as soon as the latent is corrupted.
                if np.isnan(latent).any():
                    raise ValueError("潜在变量包含NaN值!")

                # Add a leading batch axis to this step's time input.
                time_emb = np.expand_dims(time_data[step_idx], axis=0)
                debug_log(f"时间嵌入形状: {time_emb.shape}")

                debug_log("运行UNET(负面提示)...")
                unet_neg_start = time.time()
                noise_pred_neg = _run_unet(unet_session_main, latent, time_emb, neg_embeds)
                debug_log(f"UNET(负面)完成 | 形状: {noise_pred_neg.shape} | 耗时: {(time.time()-unet_neg_start):.2f}s")

                debug_log("运行UNET(正面提示)...")
                unet_pos_start = time.time()
                noise_pred_pos = _run_unet(unet_session_main, latent, time_emb, pos_embeds)
                debug_log(f"UNET(正面)完成 | 耗时: {(time.time()-unet_pos_start):.2f}s")

                # Classifier-free guidance: move from the negative prediction
                # towards the positive one by guidance_scale.
                debug_log(f"应用CFG指导(scale={args.guidance_scale})...")
                noise_pred = noise_pred_neg + args.guidance_scale * (noise_pred_pos - noise_pred_neg)
                debug_log(f"噪声预测范围: [{noise_pred.min():.3f}, {noise_pred.max():.3f}]")

                # The scheduler works on torch tensors.
                latent_tensor = torch.from_numpy(latent)
                noise_pred_tensor = torch.from_numpy(noise_pred)

                debug_log("更新潜在变量...")
                scheduler_start = time.time()
                latent_tensor = scheduler.step(
                    model_output=noise_pred_tensor,
                    timestep=timestep,
                    sample=latent_tensor
                ).prev_sample
                debug_log(f"调度器更新完成 | 耗时: {(time.time()-scheduler_start):.2f}s")

                # Back to float32 numpy for the next UNet call.
                latent = latent_tensor.numpy().astype(np.float32)
                debug_log(f"更新后潜在变量范围: [{latent.min():.3f}, {latent.max():.3f}]")

                step_time = time.time() - step_start
                total_unet_time += step_time
                debug_log(f"步骤完成 | 单步耗时: {step_time:.2f}s | 累计耗时: {total_unet_time:.2f}s")

            except Exception as e:
                print(f"步骤 {step_idx+1} 执行失败: {str(e)}")
                traceback.print_exc()
                raise SystemExit(1) from e

        debug_log("\n开始VAE解码...")
        vae_start = time.time()
        try:
            # presumably 0.18215 is the VAE latent scaling factor — TODO confirm
            latent = latent / 0.18215
            debug_log(f"VAE输入范围: [{latent.min():.3f}, {latent.max():.3f}]")
            image = vae_decoder.run(None, {"latent": latent})[0]
            debug_log(f"VAE输出形状: {image.shape} | 耗时: {(time.time()-vae_start):.2f}s")
        except Exception as e:
            print(f"VAE解码失败: {str(e)}")
            traceback.print_exc()
            raise SystemExit(1) from e

        debug_log("保存结果...")
        try:
            # Move the channel axis last, drop the batch axis, then map the
            # decoder output through x / 2 + 0.5 (clipped to [0, 1]) to uint8.
            image = np.transpose(image, (0, 2, 3, 1)).squeeze(axis=0)
            image_denorm = np.clip(image / 2 + 0.5, 0, 1)
            image = (image_denorm * 255).round().astype("uint8")
            debug_log(f"图像形状: {image.shape} | dtype: {image.dtype}")

            pil_image = Image.fromarray(image[:, :, :3])
            # Create the output directory on demand; saving into a missing
            # directory would otherwise fail after all the work is done.
            os.makedirs(args.save_dir, exist_ok=True)
            save_path = os.path.join(args.save_dir, f"{uuid.uuid4()}.png")
            pil_image.save(save_path)
            debug_log(f"图像保存成功: {save_path}")
        except Exception as e:
            print(f"保存失败: {str(e)}")
            traceback.print_exc()
            raise SystemExit(1) from e

    except Exception as e:
        print(f"主流程执行失败: {str(e)}")
        traceback.print_exc()
        # SystemExit instead of the site-provided exit() helper.
        raise SystemExit(1) from e
|
|
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|