| |
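"""
Zero123++ engineering six-view generator.

Builds a front / right / back / left / top / bottom sheet from a single
image. The four side views are picked from the views Zero123++ generates for
the unrotated input; the top and bottom views are approximated by rotating
the input image before running inference again.
"""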
|
|
| import gradio as gr |
| import torch |
| from PIL import Image |
| import numpy as np |
| from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler |
|
|
| |
# Shared Zero123++ pipeline. Stays None until load_model() assigns it;
# generate_multiview() triggers the load on first use.
pipeline = None
|
|
| |
| |
| |
| |
# Recipe for each of the six engineering views, keyed by the view's display
# name (front, right, back, left, top, bottom):
#   rotate_input - angle passed to rotate_image() on the input before inference
#   select_index - index of the tile to take from generate_multiview()'s list
#   position     - (column, row) cell of the view in the 3x2 output sheet
# The values correspond to the inference and compositing steps of
# process_image().
ENGINEERING_VIEWS = {
    "主视图": {"rotate_input": 0, "select_index": 0, "position": (0, 0)},
    "右视图": {"rotate_input": 0, "select_index": 1, "position": (1, 0)},
    "后视图": {"rotate_input": 0, "select_index": 3, "position": (2, 0)},
    "左视图": {"rotate_input": 0, "select_index": 4, "position": (0, 1)},
    "俯视图": {"rotate_input": -90, "select_index": 0, "position": (1, 1)},
    "底视图": {"rotate_input": 90, "select_index": 0, "position": (2, 1)},
}
|
|
def load_model():
    """Load the Zero123++ pipeline into the module-level ``pipeline`` global.

    Does nothing when a pipeline is already loaded. The pipeline is created
    in float16 on CUDA and float32 on CPU, gets an Euler-ancestral scheduler
    with trailing timestep spacing, and is moved to the selected device.

    The global is only assigned once every setup step has succeeded, so a
    failed load leaves ``pipeline`` as ``None`` and a later call retries
    instead of reusing a half-configured object.

    Raises:
        Exception: whatever loading or configuring the pipeline raised; the
            error is printed and then re-raised unchanged.
    """
    global pipeline

    if pipeline is not None:
        return

    print("正在加载 Zero123++ 模型...")

    use_cuda = torch.cuda.is_available()
    device = 'cuda' if use_cuda else 'cpu'
    dtype = torch.float16 if use_cuda else torch.float32

    try:
        # Configure a local object first; see the docstring for why the
        # global must not be touched until everything below has worked.
        loaded = DiffusionPipeline.from_pretrained(
            "sudo-ai/zero123plus-v1.2",
            custom_pipeline="sudo-ai/zero123plus-pipeline",
            torch_dtype=dtype,
        )

        loaded.scheduler = EulerAncestralDiscreteScheduler.from_config(
            loaded.scheduler.config,
            timestep_spacing='trailing',
        )

        loaded.to(device)

        if use_cuda:
            # Memory optimisations are best effort: try each one on its own
            # and report a failure instead of silently discarding it.
            for optimisation in ("enable_attention_slicing", "enable_vae_slicing"):
                try:
                    getattr(loaded, optimisation)()
                except Exception as e:
                    print(f"警告: 无法启用 {optimisation} - {e}")

        pipeline = loaded
        print(f"✓ 模型加载完成 (设备: {device})")
    except Exception as e:
        print(f"错误: 无法加载 Zero123++ - {e}")
        raise
|
|
def rotate_image(image, angle):
    """Rotate *image* by one of the supported right-angle steps.

    ``angle`` values of 90, -90 and 180 are forwarded to ``image.rotate`` as
    -90, 90 and 180 respectively, always with ``expand=True``. Every other
    value, including 0, returns the image object unchanged.
    """
    # (angle accepted here, angle handed to image.rotate). The sign is flipped
    # for +/-90: Pillow documents Image.rotate angles as counter-clockwise, so
    # - assuming a PIL image - a positive angle here turns the image clockwise.
    supported = ((90, -90), (-90, 90), (180, 180))
    for requested, rotate_arg in supported:
        if angle == requested:
            return image.rotate(rotate_arg, expand=True)
    return image
|
|
def generate_multiview(input_image):
    """Run one Zero123++ inference and split its output into six tiles.

    Args:
        input_image: PIL image used as the conditioning view. It is resized
            to 320x320 before being passed to the pipeline.

    Returns:
        A list of six PIL images cropped from the generated sheet in
        row-major order (left to right, then top to bottom).
    """
    if pipeline is None:
        load_model()

    img = input_image.resize((320, 320), Image.LANCZOS)

    sheet = pipeline(
        img,
        num_inference_steps=75
    ).images[0]

    sheet_w, sheet_h = sheet.size

    # Zero123++ lays its six views out as 2 columns x 3 rows on a portrait
    # sheet (see the Zero123++ README). Splitting such a sheet as 3 x 2 cuts
    # straight through the views, so the grid is chosen from the sheet's
    # orientation; a landscape or square sheet keeps the 3-column split.
    cols, rows = (2, 3) if sheet_h > sheet_w else (3, 2)
    tile_w = sheet_w // cols
    tile_h = sheet_h // rows

    return [
        sheet.crop((
            col * tile_w,
            row * tile_h,
            (col + 1) * tile_w,
            (row + 1) * tile_h,
        ))
        for row in range(rows)
        for col in range(cols)
    ]
|
|
def process_image(input_image, progress=gr.Progress()):
    """Turn a single front-view image into a six-view engineering sheet.

    The input is centre-cropped to a square, then Zero123++ is run three
    times: on the unrotated image (side views), on the image rotated by -90
    (top view) and on the image rotated by +90 (bottom view). The tiles named
    in ``ENGINEERING_VIEWS`` are pasted into a 3x2 sheet of 320-pixel cells.

    Args:
        input_image: PIL image, or ``None`` when nothing was uploaded.
        progress: Gradio progress tracker, called with a fraction and a
            description before each stage.

    Returns:
        A 960x640 RGB PIL image, or ``None`` when ``input_image`` is ``None``.

    Raises:
        gr.Error: if any step fails; the original exception is printed with
            its traceback and attached as the cause.
    """
    if input_image is None:
        return None

    try:
        load_model()

        # Centre-crop non-square inputs to the largest centred square.
        img = input_image
        width, height = img.size
        if width != height:
            side = min(width, height)
            left = (width - side) // 2
            top = (height - side) // 2
            img = img.crop((left, top, left + side, top + side))

        # One inference per input rotation used by ENGINEERING_VIEWS:
        # (rotation, progress fraction, progress description).
        stages = (
            (0, 0.1, "生成水平视图..."),
            (-90, 0.5, "生成俯视图..."),
            (90, 0.8, "生成底视图..."),
        )
        views_by_rotation = {}
        for rotation, fraction, description in stages:
            progress(fraction, desc=description)
            views_by_rotation[rotation] = generate_multiview(
                rotate_image(img, rotation)
            )

        view_size = 320
        combined = Image.new('RGB', (view_size * 3, view_size * 2))

        # Compose from the shared table instead of repeating its indices and
        # cell positions here.
        for spec in ENGINEERING_VIEWS.values():
            tile = views_by_rotation[spec["rotate_input"]][spec["select_index"]]
            # Keep every tile aligned to its cell even if the generated
            # tiles are not exactly view_size pixels square.
            if tile.size != (view_size, view_size):
                tile = tile.resize((view_size, view_size), Image.LANCZOS)
            col, row = spec["position"]
            combined.paste(tile, (col * view_size, row * view_size))

        progress(1.0, desc="完成!")
        print("✓ 所有视图生成完成")
        return combined

    except Exception as e:
        print(f"错误: {e}")
        import traceback
        traceback.print_exc()
        # Chain the cause so the original failure stays visible in logs.
        raise gr.Error(f"处理失败: {str(e)}") from e
|
|
| |
def create_demo():
    """Build the Gradio Blocks interface and return it without launching.

    The page consists of an introduction, an input column (image upload,
    generate button, timing note), an output column (resulting sheet) and a
    reference section. The button is wired to ``process_image``.

    Returns:
        The constructed ``gr.Blocks`` instance.
    """
    with gr.Blocks(title="Zero123++ 工程六视图生成器") as demo:
        # Introduction: input requirements, output layout and how it works.
        gr.Markdown("""
        # Zero123++ 工程六视图生成器

        将单张主视图转换为工程六视图

        **输入要求:**
        - 建议使用正方形图片
        - 推荐分辨率 >= 320x320
        - 脚本会自动裁剪非正方形图片

        **输出说明:**
        生成工程六视图,排列为 2 行 3 列:

        | 主视图 | 右视图 | 后视图 |
        |-------|-------|-------|
        | 左视图 | 俯视图 | 底视图 |

        **技术原理:**
        - 使用 Zero123++ v1.2 模型
        - 通过 3 次推理 + 旋转输入实现工程六视图
        - 每次推理约 30-60 秒,总计约 2-3 分钟
        """)

        with gr.Row():
            # Left column: input image, trigger button and a waiting-time note.
            with gr.Column():
                input_image = gr.Image(
                    label="输入主视图",
                    type="pil",
                    height=400
                )

                generate_btn = gr.Button("生成工程六视图", variant="primary", size="lg")

                gr.Markdown("""
                **注意:**
                - 需要运行 3 次推理(水平+俯视+底视)
                - 总耗时约 2-3 分钟
                - 请耐心等待
                """)

            # Right column: the composed six-view sheet.
            with gr.Column():
                output_image = gr.Image(
                    label="工程六视图输出",
                    type="pil",
                    height=400
                )

        # Reference section: per-view method table, technical notes, citation.
        gr.Markdown("""
        ### 视角说明

        | 视图 | 方法 | 说明 |
        |-----|------|------|
        | 主视图 | Zero123++ 30° 视角 | 正面 |
        | 右视图 | Zero123++ 90° 视角 | 右侧 |
        | 后视图 | Zero123++ 210° 视角 | 背面 |
        | 左视图 | Zero123++ 270° 视角 | 左侧 |
        | 俯视图 | 输入旋转-90° → Zero123++ | 从上往下 |
        | 底视图 | 输入旋转+90° → Zero123++ | 从下往上 |

        ### 技术说明
        - 模型: [Zero123++ v1.2](https://huggingface.co/sudo-ai/zero123plus-v1.2)
        - Zero123++ 固定输出 6 个环绕视图
        - 通过选择合适的视角 + 旋转输入实现工程视图
        - v1.2 改进: 更稳定的视角,FOV 30°

        ### 引用
        ```bibtex
        @misc{shi2023zero123plus,
        title={Zero123++: a Single Image to Consistent Multi-view Diffusion Base Model},
        author={Ruoxi Shi and Hansheng Chen and others},
        year={2023},
        eprint={2310.15110},
        archivePrefix={arXiv},
        primaryClass={cs.CV}
        }
        ```
        """)

        # Clicking the button runs process_image on the uploaded image and
        # shows its return value in the output image component.
        generate_btn.click(
            fn=process_image,
            inputs=[input_image],
            outputs=output_image
        )

    return demo
|
|
if __name__ == "__main__":
    # Startup banner.
    print("=" * 50)
    print("Zero123++ 工程六视图生成器")
    print("=" * 50)

    # Load the model before the server starts so a load failure aborts startup.
    load_model()

    # Build the UI, cap the request queue at five entries, then serve on
    # all interfaces at port 7860 without a public share link.
    demo = create_demo()
    demo.queue(max_size=5)
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo.launch(**launch_options)
|
|