# Source: Hugging Face file page for app.py, uploaded by alexander00001
# Commit 9e9fcb3 ("Update app.py", verified), file size 25.6 kB
# The `spaces` package is optional. SPACES_AVAILABLE records whether the
# import succeeded so the rest of the module can choose between the
# `spaces.GPU` decorator (ZeroGPU mode) and a plain pass-through.
try:
    import spaces
    SPACES_AVAILABLE = True
    print("✅ Spaces available - ZeroGPU mode")
except ImportError:
    SPACES_AVAILABLE = False
    print("⚠️ Spaces not available - running in regular mode")
import gradio as gr
import torch
from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
from PIL import Image
import datetime
import io
import json
import os
import re
from typing import Optional, List, Dict
import numpy as np
# ======================
# Configuration Section - flexible model configuration
# ======================
# 1. Model configuration dictionary - supports multiple model types
#
# Each entry is keyed by the identifier used by the UI dropdown and by
# load_pipeline(). Fields:
#   repo_id                 - repository passed to from_pretrained()
#   type                    - architecture tag; load_pipeline() only handles "sdxl"
#   requires_safety_checker - read by load_pipeline() when building the
#                             safety_checker argument
#   default_negative        - negative prompt used when the user leaves theirs blank
#   optimal_settings        - "steps"/"cfg" are pushed into the UI sliders when the
#                             model changes; "sampler" is not read anywhere in the
#                             visible code
#   description             - human-readable label (dropdown, logs, metadata)
MODEL_CONFIGS = {
    "wai_nsfw_illustrious_v80": {
        "repo_id": "John6666/wai-nsfw-illustrious-v80-sdxl",
        "type": "sdxl", # SDXL architecture
        "requires_safety_checker": False,
        "default_negative": "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry",
        "optimal_settings": {
            "steps": 28,
            "cfg": 7.0,
            "sampler": "DPM++ 2M Karras"
        },
        "description": "WAI NSFW Illustrious v8.0 - 高质量插画风格模型"
    },
    "wai_nsfw_illustrious_v90": {
        "repo_id": "John6666/wai-nsfw-illustrious-v90-sdxl",
        "type": "sdxl",
        "requires_safety_checker": False,
        "default_negative": "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry",
        "optimal_settings": {
            "steps": 28,
            "cfg": 7.0,
            "sampler": "DPM++ 2M Karras"
        },
        "description": "WAI NSFW Illustrious v9.0 - 最新版本"
    },
    "wai_nsfw_illustrious_v110": {
        "repo_id": "John6666/wai-nsfw-illustrious-v110-sdxl",
        "type": "sdxl",
        "requires_safety_checker": False,
        "default_negative": "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry",
        "optimal_settings": {
            "steps": 30,
            "cfg": 7.5,
            "sampler": "DPM++ 2M Karras"
        },
        "description": "WAI NSFW Illustrious v11.0 - 增强版本"
    },
    # Also the fallback model load_pipeline() retries with when another
    # model fails to load.
    "sdxl_base": {
        "repo_id": "stabilityai/stable-diffusion-xl-base-1.0",
        "type": "sdxl",
        "requires_safety_checker": True,
        "default_negative": "blurry, low quality, deformed, cartoon, anime, text, watermark, signature, username, worst quality, low res, bad anatomy, bad hands",
        "optimal_settings": {
            "steps": 30,
            "cfg": 7.5,
            "sampler": "Default"
        },
        "description": "Stable Diffusion XL Base 1.0 - 官方基础模型"
    },
    "realistic_vision": {
        "repo_id": "SG161222/RealVisXL_V4.0",
        "type": "sdxl",
        "requires_safety_checker": False,
        "default_negative": "blurry, low quality, deformed, text, watermark, signature, worst quality, bad anatomy",
        "optimal_settings": {
            "steps": 30,
            "cfg": 7.5,
            "sampler": "Default"
        },
        "description": "RealVisXL V4.0 - 高质量写实风格"
    },
    "anime_xl": {
        "repo_id": "Linaqruf/animagine-xl-3.1",
        "type": "sdxl",
        "requires_safety_checker": False,
        "default_negative": "lowres, bad anatomy, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated",
        "optimal_settings": {
            "steps": 28,
            "cfg": 7.0,
            "sampler": "Default"
        },
        "description": "Animagine XL 3.1 - 动漫风格"
    },
    "juggernaut_xl": {
        "repo_id": "RunDiffusion/Juggernaut-XL-v9",
        "type": "sdxl",
        "requires_safety_checker": False,
        "default_negative": "blurry, low quality, text, watermark, signature, worst quality",
        "optimal_settings": {
            "steps": 30,
            "cfg": 7.5,
            "sampler": "Default"
        },
        "description": "Juggernaut XL v9 - 通用高质量模型"
    }
}
# Model used by default - can be switched from the UI.
# Must be a key of MODEL_CONFIGS; load_pipeline() falls back to it when
# called without a model key.
DEFAULT_MODEL_KEY = "wai_nsfw_illustrious_v80"
# 2. Fixed LoRA configuration - loaded automatically
#
# generate_image() adds every entry with a repo_id and enabled=True when the
# "use fixed LoRAs" option is on, and appends its trigger_words to the prompt.
# The "filename" field is not read anywhere in the visible code.
FIXED_LORAS = {
    # NOTE(review): the repo name suggests an IKEA-instructions style LoRA
    # rather than a detail enhancer - confirm this is the intended repository.
    "detail_enhancer": {
        "repo_id": "ostris/ikea-instructions-lora-sdxl",
        "filename": None,
        "weight": 0.5, # lowered weight to avoid over-influencing the result
        "trigger_words": "high quality, detailed",
        "enabled": True # can be disabled
    },
    # Shares its repo_id with the optional "Offset Noise" LoRA below.
    "quality_boost": {
        "repo_id": "stabilityai/stable-diffusion-xl-offset-example-lora",
        "filename": None,
        "weight": 0.4,
        "trigger_words": "masterpiece, best quality",
        "enabled": True
    }
}
# 3. Style templates - tuned for the different models
#
# Maps the style name shown in the UI to a prefix that generate_image()
# concatenates directly in front of the user's prompt, which is why every
# non-empty value ends with ", ".
STYLE_PROMPTS = {
    "None": "",
    "Realistic Photo": "photorealistic, ultra-detailed, natural lighting, 8k uhd, professional photography, DSLR, high quality, masterpiece, ",
    "Anime/Illustration": "anime style, high quality illustration, vibrant colors, detailed, masterpiece, best quality, ",
    "Artistic Illustration": "artistic illustration, painterly, detailed artwork, high quality, professional illustration, ",
    "Comic Book": "comic book style, bold lines, dynamic composition, pop art, high quality, ",
    "Watercolor": "watercolor painting, soft brush strokes, artistic, traditional art, masterpiece, ",
    "Cinematic": "cinematic lighting, dramatic atmosphere, film grain, professional color grading, high quality, ",
}
# 4. Optional LoRA configuration - user-selectable
#
# Keys are the choices offered by the LoRA dropdown. For each selected entry
# generate_image() multiplies "weight" by the UI LoRA-strength slider and
# appends "trigger_words" to the prompt. "description" is not read anywhere
# in the visible code.
OPTIONAL_LORAS = {
    # Placeholder choice: repo_id is None, so it is skipped when building the
    # list of LoRAs to load.
    "None": {
        "repo_id": None,
        "weight": 0.0,
        "trigger_words": "",
        "description": "不使用额外LoRA"
    },
    # Same repo_id as FIXED_LORAS["quality_boost"].
    "Offset Noise": {
        "repo_id": "stabilityai/stable-diffusion-xl-offset-example-lora",
        "weight": 0.7,
        "trigger_words": "high contrast, dramatic lighting",
        "description": "增强对比度和光照效果"
    },
    "LCM LoRA": {
        "repo_id": "latent-consistency/lcm-lora-sdxl",
        "weight": 0.8,
        "trigger_words": "high quality",
        "description": "快速生成模式"
    },
    "Pixel Art": {
        "repo_id": "nerijs/pixel-art-xl",
        "weight": 0.9,
        "trigger_words": "pixel art style, 8bit, retro",
        "description": "像素艺术风格"
    },
    "Watercolor": {
        "repo_id": "ostris/watercolor-style-lora-sdxl",
        "weight": 0.8,
        "trigger_words": "watercolor painting, soft colors",
        "description": "水彩画风格"
    },
    "Sketch": {
        "repo_id": "ostris/crayon-style-lora-sdxl",
        "weight": 0.7,
        "trigger_words": "sketch style, pencil drawing",
        "description": "素描风格"
    },
    "Portrait": {
        "repo_id": "ostris/face-helper-sdxl-lora",
        "weight": 0.8,
        "trigger_words": "portrait, beautiful face, detailed eyes",
        "description": "肖像和面部增强"
    }
}
# Default parameters - initial values of the corresponding UI controls.
DEFAULT_SEED = -1  # -1 makes generate_image() draw a random seed
DEFAULT_WIDTH = 1024
DEFAULT_HEIGHT = 1024
DEFAULT_LORA_SCALE = 0.8  # multiplier applied to optional LoRA weights
DEFAULT_STEPS = 28
DEFAULT_CFG = 7.0
# Supported languages: language code -> display name.
# Not referenced elsewhere in the visible code.
SUPPORTED_LANGUAGES = {
    "en": "English",
    "zh": "中文",
    "ja": "日本語",
    "ko": "한국어"
}
# ======================
# Global state: lazily loaded
# ======================
# The currently loaded pipeline; None until load_pipeline() succeeds.
pipe = None
# Key into MODEL_CONFIGS for the model held in `pipe`, or None.
current_model_key = None
# LoRAs loaded into `pipe`: repo_id -> adapter name.
current_loras = {}
# Torch device string chosen once at import time.
device = "cuda" if torch.cuda.is_available() else "cpu"
def load_pipeline(model_key: str = None):
    """Load the pipeline for ``model_key``, reusing the cached one when possible.

    Args:
        model_key: Key into ``MODEL_CONFIGS``. ``None`` selects
            ``DEFAULT_MODEL_KEY``.

    Returns:
        The loaded pipeline, which is also stored in the module-level ``pipe``.
        When the requested model cannot be loaded this is the ``"sdxl_base"``
        fallback instead; ``current_model_key`` names the model that was
        actually loaded.

    Raises:
        ValueError: ``model_key`` is not present in ``MODEL_CONFIGS``.
        RuntimeError: The ``"sdxl_base"`` model could not be loaded either.
    """
    global pipe, current_model_key

    if model_key is None:
        model_key = DEFAULT_MODEL_KEY

    # Reuse the cached pipeline when the same model is requested again.
    if pipe is not None and current_model_key == model_key:
        return pipe

    # Validate first so that an unknown key cannot evict a working pipeline.
    model_config = MODEL_CONFIGS.get(model_key)
    if not model_config:
        raise ValueError(f"未知的模型配置: {model_key}")

    # Free the previous model before allocating the next one.
    if pipe is not None:
        unload_pipeline()

    print(f"🚀 加载模型: {model_config['description']} ({model_config['repo_id']})")
    try:
        if model_config["type"] != "sdxl":
            raise ValueError(f"不支持的模型类型: {model_config['type']}")

        on_gpu = device == "cuda"
        # Half precision only makes sense on the GPU.
        # The SDXL pipeline has no safety_checker component, so that argument
        # is intentionally not passed.
        load_kwargs = {
            "torch_dtype": torch.float16 if on_gpu else torch.float32,
            "use_safetensors": True,
        }
        try:
            new_pipe = StableDiffusionXLPipeline.from_pretrained(
                model_config["repo_id"],
                variant="fp16",
                **load_kwargs,
            )
        except (OSError, ValueError) as variant_error:
            # Not every repository publishes weights under the fp16 variant
            # name; retry with the default files before giving up.
            print(f"⚠️ 未找到fp16变体权重,改用默认权重重试: {variant_error}")
            new_pipe = StableDiffusionXLPipeline.from_pretrained(
                model_config["repo_id"],
                **load_kwargs,
            )

        # Memory optimisations.
        new_pipe.enable_attention_slicing()
        new_pipe.enable_vae_slicing()
        if on_gpu:
            # CPU offload manages device placement itself, so the pipeline is
            # only moved manually when offloading cannot be enabled.
            try:
                new_pipe.enable_model_cpu_offload()
            except Exception as offload_error:
                print(f"⚠️ 无法启用CPU offload,直接使用 {device}: {offload_error}")
                new_pipe = new_pipe.to(device)
            try:
                new_pipe.enable_xformers_memory_efficient_attention()
            except Exception:
                print("⚠️ xformers不可用,跳过")
        else:
            new_pipe = new_pipe.to(device)

        # Publish the globals only once the pipeline is fully configured.
        pipe = new_pipe
        current_model_key = model_key
        print(f"✅ 成功加载模型: {model_config['description']}")
        return pipe
    except Exception as e:
        print(f"❌ 加载模型失败: {e}")
        # Try the fallback model exactly once.
        if model_key != "sdxl_base":
            print("🔄 尝试加载备用模型...")
            return load_pipeline("sdxl_base")
        raise RuntimeError("无法加载任何模型") from e
def unload_pipeline():
    """Drop the cached pipeline and release the memory it held.

    Resets the module-level ``pipe``, ``current_loras`` and
    ``current_model_key``. Does nothing when no pipeline is loaded.
    """
    import gc

    global pipe, current_loras, current_model_key
    if pipe is None:
        return

    try:
        pipe.unload_lora_weights()
    except Exception as e:
        # Best effort only: the pipeline is discarded right below anyway.
        print(f"⚠️ 卸载LoRA失败(已忽略): {e}")

    pipe = None
    current_loras = {}
    current_model_key = None

    # Collect reference cycles first so the allocator can actually return
    # the freed blocks when the CUDA cache is emptied.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print("🗑️ Pipeline已卸载")
def load_lora_weights(lora_configs: List[Dict]):
    """Make the pipeline's active LoRAs match ``lora_configs``.

    Args:
        lora_configs: Dicts with ``'name'``, ``'repo_id'`` and ``'weight'``.
            Entries without a repo_id are ignored. When the same repo_id
            appears more than once it is loaded a single time and the last
            entry's weight is used. An empty list removes any LoRAs that
            were loaded by a previous call.

    Failures while loading or activating an adapter are printed and skipped;
    nothing is raised to the caller.
    """
    global pipe, current_loras

    if pipe is None:
        print("⚠️ Pipeline未加载,跳过LoRA加载")
        return

    # De-duplicate by repository. Activating one adapter twice with two
    # different weights is ill-defined, so the later request wins.
    wanted = {}
    for config in lora_configs or []:
        repo_id = config.get('repo_id')
        if repo_id:
            wanted[repo_id] = config

    # Unload the existing LoRAs whenever the requested set changed. This
    # includes the "nothing requested" case, so adapters from a previous
    # generation do not silently stay active.
    if set(current_loras) != set(wanted):
        try:
            pipe.unload_lora_weights()
            current_loras = {}
        except Exception as e:
            print(f"⚠️ 卸载LoRA失败: {e}")

    # Load whatever is missing and collect the adapters to activate.
    adapter_names = []
    adapter_weights = []
    for repo_id, config in wanted.items():
        if repo_id not in current_loras:
            adapter_name = config['name'].replace(' ', '_').lower()
            try:
                pipe.load_lora_weights(repo_id, adapter_name=adapter_name)
            except Exception as e:
                print(f"⚠️ LoRA加载失败 {config['name']}: {e}")
                continue
            current_loras[repo_id] = adapter_name
            print(f"✅ 加载LoRA: {config['name']}")
        adapter_names.append(current_loras[repo_id])
        adapter_weights.append(config['weight'])

    if not adapter_names:
        return

    # Activate the adapters with their weights; fall back to default weights
    # if the weighted call is rejected.
    try:
        pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
        print(f"✅ 激活了 {len(adapter_names)} 个LoRA")
    except Exception as e:
        print(f"⚠️ 设置adapter权重警告: {e}")
        try:
            pipe.set_adapters(adapter_names)
        except Exception:
            print("❌ 无法设置任何adapter")
def process_long_prompt(prompt: str, max_length: int = 77) -> str:
    """Shorten ``prompt`` to at most ``max_length`` whitespace-separated words.

    The limit counts words, not tokenizer tokens. Prompts within the limit
    are returned unchanged. Longer prompts are cut sentence by sentence:
    whole sentences are kept while they fit, and the first sentence that
    does not fit contributes only its longer words (more than 3 characters)
    up to the remaining budget.

    Args:
        prompt: The prompt text.
        max_length: Maximum number of words to keep.

    Returns:
        The original prompt, or a shortened version with at most
        ``max_length`` words.
    """
    all_words = prompt.split()
    if len(all_words) <= max_length:
        return prompt
    if max_length <= 0:
        return ""

    # Split only where punctuation is followed by whitespace or the end of
    # the text, so decimals such as the "1.2" in "(tag:1.2)" stay intact.
    sentences = [
        part.strip()
        for part in re.split(r'[.!?]+(?=\s|$)', prompt)
        if part.strip()
    ]
    if not sentences:
        return " ".join(all_words[:max_length])

    first_words = sentences[0].split()
    if len(first_words) > max_length:
        # Typical comma-separated tag prompts have no sentence breaks at all;
        # the first "sentence" must still respect the limit.
        return " ".join(first_words[:max_length])

    kept = [sentences[0]]
    remaining = max_length - len(first_words)
    for sentence in sentences[1:]:
        sentence_words = sentence.split()
        if len(sentence_words) <= remaining:
            kept.append(sentence)
            remaining -= len(sentence_words)
            continue
        # Partial fit: keep only the longer words, then stop.
        important_words = [w for w in sentence_words if len(w) > 3][:remaining]
        if important_words:
            kept.append(" ".join(important_words))
        break
    return ". ".join(kept)
# ======================
# Main generation function
# ======================
# Resolve the decorator once: `spaces.GPU` when the optional package was
# imported, otherwise a pass-through. A conditional expression written
# directly after "@" is only valid syntax on Python 3.9 and later.
_gpu_task = spaces.GPU(duration=60) if SPACES_AVAILABLE else (lambda fn: fn)


def _collect_lora_configs(selected_loras, lora_scale, use_fixed_loras):
    """Return ``(lora_configs, trigger_words)`` for the requested LoRAs."""
    lora_configs = []
    trigger_words = []

    # Fixed LoRAs (when enabled).
    if use_fixed_loras:
        for name, config in FIXED_LORAS.items():
            if config["repo_id"] and config["enabled"]:
                lora_configs.append({
                    'name': name,
                    'repo_id': config["repo_id"],
                    'weight': config["weight"],
                })
                if config["trigger_words"]:
                    trigger_words.append(config["trigger_words"])

    # LoRAs picked by the user; their weight is scaled by the UI slider.
    for lora_name in selected_loras or []:
        config = OPTIONAL_LORAS.get(lora_name)
        if lora_name == "None" or config is None or not config["repo_id"]:
            continue
        lora_configs.append({
            'name': lora_name,
            'repo_id': config["repo_id"],
            'weight': config["weight"] * lora_scale,
        })
        if config["trigger_words"]:
            trigger_words.append(config["trigger_words"])

    return lora_configs, trigger_words


@_gpu_task
def generate_image(
    model_key: str,
    prompt: str,
    negative_prompt: str,
    style: str,
    seed: int,
    width: int,
    height: int,
    selected_loras: List[str],
    lora_scale: float,
    steps: int,
    cfg_scale: float,
    use_fixed_loras: bool,
    language: str = "en"
):
    """Generate one image and a JSON description of how it was produced.

    Args:
        model_key: Key into ``MODEL_CONFIGS`` of the model to use.
        prompt: Positive prompt entered by the user.
        negative_prompt: Negative prompt; when blank, the loaded model's
            ``default_negative`` is used.
        style: Key into ``STYLE_PROMPTS``; its text is prepended to the prompt.
        seed: Seed for the random generator, ``-1`` for a random seed.
        width: Image width in pixels.
        height: Image height in pixels.
        selected_loras: Names from ``OPTIONAL_LORAS`` chosen by the user.
        lora_scale: Multiplier applied to the optional LoRA weights.
        steps: Number of inference steps.
        cfg_scale: Guidance scale.
        use_fixed_loras: Whether the enabled ``FIXED_LORAS`` are added.
        language: Recorded in the metadata only.

    Returns:
        ``(image, metadata_json, status_message)`` on success, or
        ``(None, error_message, error_message)`` when anything fails.
        Errors are reported through the return value, never raised.
    """
    global pipe
    try:
        # Load the requested model.
        pipe = load_pipeline(model_key)
        # load_pipeline may have fallen back to another model; describe the
        # model that is actually loaded instead of the one requested.
        loaded_key = current_model_key if current_model_key in MODEL_CONFIGS else model_key
        model_config = MODEL_CONFIGS[loaded_key]

        # UI sliders may deliver floats; the pipeline and generator need ints.
        seed = int(seed)
        width = int(width)
        height = int(height)
        steps = int(steps)
        if seed == -1:
            seed = torch.randint(0, 2**32, (1,)).item()
        generator = torch.Generator(device=device).manual_seed(seed)

        # Prompts.
        prompt = prompt or ""
        negative_prompt = negative_prompt or ""
        style_prefix = STYLE_PROMPTS.get(style, "")
        processed_prompt = process_long_prompt(style_prefix + prompt, max_length=150)
        # Use the model's default negative prompt when none was supplied.
        if not negative_prompt.strip():
            negative_prompt = model_config["default_negative"]
        processed_negative = process_long_prompt(negative_prompt, max_length=100)

        # LoRAs.
        lora_configs, active_trigger_words = _collect_lora_configs(
            selected_loras, lora_scale, use_fixed_loras
        )
        load_lora_weights(lora_configs)

        # Append the trigger words; empty parts are skipped so a blank
        # prompt does not produce a dangling leading comma.
        final_prompt = ", ".join(
            part for part in [processed_prompt, *active_trigger_words] if part
        )

        # Generate the image.
        with torch.autocast(device):
            image = pipe(
                prompt=final_prompt,
                negative_prompt=processed_negative,
                num_inference_steps=steps,
                guidance_scale=cfg_scale,
                width=width,
                height=height,
                generator=generator,
            ).images[0]

        # Metadata describing this generation.
        timestamp = datetime.datetime.now()
        metadata = {
            "model": model_config["description"],
            "model_repo": model_config["repo_id"],
            "prompt": final_prompt,
            "original_prompt": prompt,
            "negative_prompt": processed_negative,
            "style": style,
            "fixed_loras_enabled": use_fixed_loras,
            "fixed_loras": [name for name, config in FIXED_LORAS.items() if config["enabled"]] if use_fixed_loras else [],
            "selected_loras": [name for name in (selected_loras or []) if name != "None"],
            "lora_scale": lora_scale,
            "seed": seed,
            "steps": steps,
            "cfg_scale": cfg_scale,
            "width": width,
            "height": height,
            "language": language,
            "timestamp": timestamp.isoformat(),
            "trigger_words": active_trigger_words
        }
        metadata_str = json.dumps(metadata, indent=2, ensure_ascii=False)
        return (
            image,
            metadata_str,
            f"✅ 生成成功! 种子: {seed}"
        )
    except Exception as e:
        # UI boundary: surface the failure in the status/metadata boxes
        # instead of crashing the request.
        error_msg = f"生成失败: {str(e)}"
        print(f"❌ {error_msg}")
        return None, error_msg, error_msg
# ======================
# Gradio interface
# ======================
# Stylesheet handed to gr.Blocks; the two classes are attached to components
# through elem_classes in create_interface().
_APP_CSS = """
.model-card {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
padding: 20px;
border-radius: 12px;
color: white;
margin-bottom: 20px;
}
.control-section {
background: rgba(255,255,255,0.05);
border-radius: 12px;
padding: 15px;
margin: 10px 0;
}
"""

# Header shown at the top of the page.
_HEADER_MARKDOWN = """
# 🎨 AI图像生成器 - Illustrious XL多模型版
### 支持多种SDXL模型自由切换 | 灵活的LoRA组合 | 优化的参数配置
"""


def _model_info_markdown(model_key):
    """Return the Markdown summary shown under the model dropdown."""
    config = MODEL_CONFIGS[model_key]
    settings = config['optimal_settings']
    return (
        "\n"
        f"**模型:** {config['description']}\n"
        f"**仓库:** `{config['repo_id']}`\n"
        f"**推荐设置:** 步数={settings['steps']}, CFG={settings['cfg']}\n"
    )


def create_interface():
    """Build the Gradio Blocks app and wire up its event handlers.

    Returns:
        The ``gr.Blocks`` instance; the caller is responsible for queueing
        and launching it.
    """
    with gr.Blocks(
        theme=gr.themes.Soft(
            primary_hue="blue",
            secondary_hue="purple",
            neutral_hue="slate",
        ),
        css=_APP_CSS,
        title="AI图像生成器 - Illustrious XL多模型版"
    ) as demo:
        gr.Markdown(_HEADER_MARKDOWN, elem_classes=["model-card"])

        with gr.Row():
            # Left side - control panel
            with gr.Column(scale=3):
                # Model selection
                with gr.Group(elem_classes=["control-section"]):
                    gr.Markdown("### 📦 模型选择")
                    model_dropdown = gr.Dropdown(
                        choices=[(config["description"], key) for key, config in MODEL_CONFIGS.items()],
                        value=DEFAULT_MODEL_KEY,
                        label="基础模型",
                        info="选择不同的模型以获得不同的风格"
                    )
                    # Start with the same summary the change handler
                    # produces, so the card does not switch format after the
                    # first model change.
                    model_info = gr.Markdown(_model_info_markdown(DEFAULT_MODEL_KEY))

                # Prompt input
                with gr.Group(elem_classes=["control-section"]):
                    gr.Markdown("### ✍️ 提示词")
                    prompt_input = gr.Textbox(
                        label="正面提示词",
                        placeholder="描述你想要生成的图像...",
                        lines=4,
                        max_lines=20
                    )
                    negative_prompt_input = gr.Textbox(
                        label="负面提示词(留空使用模型默认)",
                        placeholder="将自动使用所选模型的推荐负面提示词...",
                        lines=3,
                        max_lines=15
                    )
                    style_radio = gr.Radio(
                        choices=list(STYLE_PROMPTS.keys()),
                        label="风格模板",
                        value="None",
                        info="将自动添加到提示词前"
                    )

                # Basic parameters
                with gr.Group(elem_classes=["control-section"]):
                    gr.Markdown("### ⚙️ 基础参数")
                    with gr.Row():
                        seed_input = gr.Slider(
                            minimum=-1,
                            maximum=99999999,
                            step=1,
                            value=DEFAULT_SEED,
                            label="种子 (-1=随机)"
                        )
                    with gr.Row():
                        width_input = gr.Slider(
                            minimum=512,
                            maximum=1536,
                            step=64,
                            value=DEFAULT_WIDTH,
                            label="宽度"
                        )
                        height_input = gr.Slider(
                            minimum=512,
                            maximum=1536,
                            step=64,
                            value=DEFAULT_HEIGHT,
                            label="高度"
                        )
                    with gr.Row():
                        steps_slider = gr.Slider(
                            minimum=10,
                            maximum=100,
                            step=1,
                            value=DEFAULT_STEPS,
                            label="采样步数"
                        )
                        cfg_slider = gr.Slider(
                            minimum=1.0,
                            maximum=20.0,
                            step=0.5,
                            value=DEFAULT_CFG,
                            label="CFG Scale"
                        )

                # LoRA configuration
                with gr.Group(elem_classes=["control-section"]):
                    gr.Markdown("### 🎭 LoRA配置")
                    use_fixed_loras = gr.Checkbox(
                        label="启用固定LoRA增强(质量+细节)",
                        value=True,
                        info="自动加载质量和细节增强LoRA"
                    )
                    lora_dropdown = gr.Dropdown(
                        choices=list(OPTIONAL_LORAS.keys()),
                        label="额外LoRA(可多选)",
                        value=["None"],
                        multiselect=True,
                        info="选择额外的风格LoRA"
                    )
                    lora_scale_slider = gr.Slider(
                        minimum=0.0,
                        maximum=1.5,
                        step=0.05,
                        value=DEFAULT_LORA_SCALE,
                        label="LoRA强度"
                    )

                # Generate button and status line
                generate_btn = gr.Button(
                    "✨ 生成图像",
                    variant="primary",
                    size="lg"
                )
                status_text = gr.Textbox(
                    label="状态",
                    value="准备就绪",
                    interactive=False
                )

            # Right side - output
            with gr.Column(scale=2):
                image_output = gr.Image(
                    label="生成的图像",
                    height=600,
                    format="webp"
                )
                gr.Markdown("**右键点击图像下载**")
                metadata_output = gr.Textbox(
                    label="生成元数据 (JSON)",
                    lines=15,
                    max_lines=25
                )

        # Hidden input that feeds the fixed UI language to generate_image().
        language_input = gr.Textbox(value="zh", visible=False)

        # ======================
        # Event handling
        # ======================
        def update_model_info(model_key):
            """Refresh the info card and recommended settings for a model."""
            config = MODEL_CONFIGS[model_key]
            return (
                _model_info_markdown(model_key),
                config['optimal_settings']['steps'],
                config['optimal_settings']['cfg'],
                config['default_negative']
            )

        # Update the info card, sliders and negative prompt on model change.
        model_dropdown.change(
            fn=update_model_info,
            inputs=[model_dropdown],
            outputs=[model_info, steps_slider, cfg_slider, negative_prompt_input]
        )

        # The input order must match generate_image()'s positional parameters.
        generate_btn.click(
            fn=generate_image,
            inputs=[
                model_dropdown, prompt_input, negative_prompt_input, style_radio,
                seed_input, width_input, height_input,
                lora_dropdown, lora_scale_slider,
                steps_slider, cfg_slider, use_fixed_loras,
                language_input
            ],
            outputs=[
                image_output, metadata_output, status_text
            ]
        )
    return demo
# ======================
# Launch the application
# ======================
if __name__ == "__main__":
    # Build the UI, enable the request queue (max_size=20), then start the
    # server with the options below.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    demo = create_interface()
    demo.queue(max_size=20)
    demo.launch(**launch_options)