Saravutw commited on
Commit
fc71c2c
·
verified ·
1 Parent(s): 254bb68

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -90
app.py CHANGED
@@ -1,94 +1,46 @@
1
- import torch, random, gc, numpy as np
2
- import gradio as gr
3
- from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
4
- from PIL import Image
5
- import os
6
 
7
- # --- 1. Load Model (Free Tier HF) ---
8
- device = "cuda" if torch.cuda.is_available() else "cpu"
9
- model_id = "stablediffusionapi/pony-diffusion-v6-xl"
10
 
11
- pipe = StableDiffusionXLPipeline.from_pretrained(
 
 
 
12
  model_id,
13
- torch_dtype=torch.float16 if device == "cuda" else torch.float32,
14
- variant="fp16",
15
  use_safetensors=True
16
- ).to(device)
17
-
18
- pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
19
- pipe.enable_vae_tiling()
20
- pipe.enable_vae_slicing()
21
-
22
- # --- 2. Generate Function ---
23
- def generate_single(prompt, neg_prompt, steps, cfg, seed, size_preset, style_preset, file_format):
24
- torch.cuda.empty_cache()
25
- gc.collect()
26
-
27
- style_map = {
28
- "None (ค่าดั้งเดิม)": "raw photo.cute.high resolution",
29
- "Realistic Photo (ดิบๆ)": "raw photo, (photorealistic:1.3), high fidelity, skin pores, film grain, Fujifilm, ",
30
- "Cinematic Real (แสงสวย)": "cinematic film still, shallow depth of field, dramatic lighting, highly detailed skin, 8k uhd, ",
31
- "3D Realistic (เนียนกริบ)": "super-realistic.realistic render,real human, octane render, soft global illumination, ",
32
- "Semi-Real (นวลตา)": "semi-realistic, digital concept art, smooth skin, detailed lighting, ",
33
- "Portrait Real (เน้นหน้า)": "close up portrait, highly detailed eyes, skin texture, dslr, soft natural light, ",
34
- "สไตล์บ้านๆ (Homemade)": "amateur phone photography, (casual lighting:1.2), flash photography, messy room background, (real life:1.3), grainy, ",
35
- "แนวบิวตี้ (ฟรุ้งฟริ้ง)": "soft focus, (ethereal lighting:1.3), dreamy atmosphere, glowing skin, (pastel tones:1.1), high key lighting, beauty filter, "
36
- }
37
-
38
- selected_style = style_map.get(style_preset, "")
39
- full_prompt = f"score_9, score_8_up, masterpiece, best quality, {selected_style}{prompt}"
40
-
41
- size_map = {
42
- "Square (512x512)": (512, 512),
43
- "Square (1024x1024)": (1024, 1024),
44
- "Portrait (832x1216)": (720, 1260),
45
- "Landscape (1216x832)": (1280, 720)
46
- }
47
-
48
- width, height = size_map.get(size_preset, (512, 512))
49
- base_seed = int(seed) if int(seed) != -1 else random.randint(0, 10_000_000)
50
-
51
- generator = torch.manual_seed(base_seed)
52
- with torch.inference_mode():
53
- image = pipe(
54
- prompt=full_prompt,
55
- negative_prompt=neg_prompt,
56
- num_inference_steps=int(steps),
57
- guidance_scale=cfg,
58
- generator=generator,
59
- width=width,
60
- height=height
61
- ).images[0]
62
-
63
- filename = f"output_{base_seed}.{file_format.lower()}"
64
- image.save(filename, quality=95 if file_format == "JPG" else 98)
65
- return image, base_seed, filename
66
-
67
- # --- 3. Gradio UI ---
68
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
69
- gr.Markdown("## 🐴 Pony V6 XL (Free Tier HF)")
70
-
71
- with gr.Row():
72
- with gr.Column(scale=1):
73
- btn = gr.Button("🚀 GENERATE NOW", variant="primary")
74
- p_in = gr.Textbox(label="Prompt", value="a beautiful cinematic portrait of a woman", lines=4)
75
- n_in = gr.Textbox(label="Negative", value="low quality, blurry, bad anatomy, watermark", lines=2)
76
-
77
- style_drop = gr.Dropdown(list(style_map.keys()), value="Realistic Photo (ดิบๆ)", label="เลือกสไตล์ภาพ")
78
- size_drop = gr.Dropdown(list(size_map.keys()), value="Square (1024x1024)", label="ขนาดภาพ")
79
- format_drop = gr.Dropdown(["PNG", "JPG", "WebP"], value="PNG", label="นามสกุลไฟล์")
80
-
81
- c_sld = gr.Slider(1, 15, value=3.0, step=0.5, label="ความคมชัดคำสั่ง (CFG)")
82
- s_sld = gr.Slider(2, 100, value=30, step=1, label="รอบการวาด (Steps)")
83
- sd_in = gr.Number(value=-1, label="Seed")
84
-
85
- with gr.Column(scale=1):
86
- img_out = gr.Image(label="Result")
87
- file_out = gr.File(label="Download File")
88
- seed_out = gr.Number(label="Seed ที่ใช้")
89
-
90
- btn.click(fn=generate_single,
91
- inputs=[p_in, n_in, s_sld, c_sld, sd_in, size_drop, style_drop, format_drop],
92
- outputs=[img_out, seed_out, file_out])
93
-
94
- demo.launch(share=True)
 
1
+ import torch
2
+ from diffusers import AutoPipelineForText2Image
 
 
 
3
 
4
+ # การตั้งค่าเบื้องต้น
5
+ model_id = "stabilityai/sdxl-turbo"
 
6
 
7
+ # 1. โหลดโมเดล
8
+ # ใช้ torch.float32 สำหรับ CPU เพื่อความเสถียรสูงสุด (CPU บางตัวไม่รองรับ half-precision ได้ดี)
9
+ # low_cpu_mem_usage=True จะช่วยลดการกระชากของ RAM ตอนโหลดโมเดล (ต้องมี accelerate library)
10
+ pipe = AutoPipelineForText2Image.from_pretrained(
11
  model_id,
12
+ torch_dtype=torch.float32,
13
+ variant="fp16", # โหลด weight แบบ fp16 เพื่อลดขนาดไฟล์ดาวน์โหลด แต่รันจริงบน float32
14
  use_safetensors=True
15
+ )
16
+
17
+ # 2. ย้ายไปที่ CPU (ระบุชัดเจน)
18
+ pipe.to("cpu")
19
+
20
+ # การปรับแต่ง Memory เพิ่มเติม (ถ้า RAM 18GB เต็มจริงๆ อาจต้องเปิดใช้ attention slicing แต่มันจะทำให้เจนภาพช้าลง)
21
+ # pipe.enable_attention_slicing()
22
+
23
+ def generate_image(prompt_text):
24
+ # SDXL Turbo ปกติเทรนมาที่ 512x512 การดันไป 1024x1024 บน CPU จะใช้เวลานานขึ้นและกิน RAM สูง
25
+ # แต่สามารถทำได้โดยการกำหนด height/width
26
+
27
+ image = pipe(
28
+ prompt=prompt_text,
29
+ num_inference_steps=2, # SDXL Turbo ต้องการแค่ 1-4 step (2 คือจุดสมดุลที่ดีสำหรับ 1024px)
30
+ guidance_scale=0.0, # สำคัญ: Turbo ไม่ใช้ CFG (Guidance Scale) ต้องตั้งเป็น 0.0 เพื่อให้ภาพไม่เละและอิสระตามโมเดล
31
+ width=1024,
32
+ height=1024,
33
+ ).images[0]
34
+
35
+ return image
36
+
37
+ # ส่วนของการทดสอบรัน (ตัวอย่าง)
38
+ if __name__ == "__main__":
39
+ # ใส่ Prompt ที่ต้องการทดสอบตรงนี้
40
+ user_prompt = "A cinematic shot of a cyberpunk street in rain, neon lights, highly detailed, 8k"
41
+
42
+ print("Starting generation... (CPU may take time)")
43
+ result = generate_image(user_prompt)
44
+ result.save("output_1024.png")
45
+ print("Image saved as output_1024.png")
46
+ (share=True)