Saravutw commited on
Commit
e0ed766
·
verified ·
1 Parent(s): fd73d22

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -31
app.py CHANGED
@@ -1,45 +1,71 @@
1
  import torch
2
- from diffusers import AutoPipelineForText2Image
 
 
3
 
4
- # การั้งคบืต้
5
- model_id = "stabilityai/sdxl-turbo"
 
 
 
6
 
7
- # 1. โหลดโมเดล
8
- # ใช้ torch.float32 สำหรับ CPU เพื่อความเสถียรสูงสุด (CPU บางตัวไม่รองรับ half-precision ได้ดี)
9
- # low_cpu_mem_usage=True จะช่วยลดการกระชากของ RAM ตอนโหลดโมเดล (ต้องมี accelerate library)
10
  pipe = AutoPipelineForText2Image.from_pretrained(
11
- model_id,
12
- torch_dtype=torch.float32,
13
- variant="fp16", # โหลด weight แบบ fp16 เพื่อลดขนาดไฟล์ดาวน์โหลด แต่รันจริงบน float32
14
- use_safetensors=True
15
  )
16
 
17
- # 2. ย้ายไปที่ CPU (ระบุชัดเจน)
18
  pipe.to("cpu")
19
 
20
- # การับแต่ง Memory เพิ่มเติม (ถ้ RAM 18GB ต็มจริงๆ อาจต้องเปิใช้ attention slicing แต่ันจะทำให้เจนภพช้าลง)
21
- # pipe.enable_attention_slicing()
 
 
 
 
 
 
 
 
 
22
 
23
- def generate_image(prompt_text):
24
- # SDXL Turbo ปกติเทรนมาที่ 512x512 การดันไป 1024x1024 บน CPU จะใช้เวลานานขึ้นและกิน RAM สูง
25
- # แต่สามารถทำได้โดยการกำหนด height/width
26
 
27
- image = pipe(
28
- prompt=prompt_text,
29
- num_inference_steps=2, # SDXL Turbo ต้องการแค่ 1-4 step (2 คือจุดสมดุลที่ดีสำหรับ 1024px)
30
- guidance_scale=0.0, # สำคัญ: Turbo ไม่ใช้ CFG (Guidance Scale) ต้องตั้งเป็น 0.0 เพื่อให้ภาพไม่เละและอิสระตามโมเดล
31
- width=1024,
32
- height=1024,
33
- ).images[0]
34
 
 
 
 
 
 
 
 
 
 
35
  return image
36
 
37
- # ส่วนของการทดสอบรัน (ตัวอย่าง)
38
- if __name__ == "__main__":
39
- # ใส่ Prompt ที่ต้องการทดสอบตรงนี้
40
- user_prompt = "A cinematic shot of a cyberpunk street in rain, neon lights, highly detailed, 8k"
41
 
42
- print("Starting generation... (CPU may take time)")
43
- result = generate_image(user_prompt)
44
- result.save("output_1024.png")
45
- print("Image saved as output_1024.png")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
+ import gradio as gr
3
+ import os
4
+ from diffusers import AutoPipelineForText2Image, DPMSolverMultistepScheduler
5
 
6
+ # เปลี่ยนเป็นโมเดลระกูล XL ที่เนความสมจริและรัไว (Lightning/Turbo)
7
+ # ตัวนี้ยังอยู่ในตระกูลเดียวกับที่ทำไว้ แต่ให้งานผิวและแสงที่ต่างออกไป
8
+ MODEL_ID = "SG_161222/RealVisXL_V4.0_Lightning"
9
+
10
+ print(f"Loading {MODEL_ID} using existing CPU-optimized structure...")
11
 
 
 
 
12
  pipe = AutoPipelineForText2Image.from_pretrained(
13
+ MODEL_ID,
14
+ torch_dtype=torch.float32,
15
+ low_cpu_mem_usage=True
 
16
  )
17
 
 
18
  pipe.to("cpu")
19
 
20
+ # ใช้บรรดฐเดมที่ทำให้เ Never OOM
21
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
22
+ pipe.enable_attention_slicing("max")
23
+ pipe.enable_vae_tiling()
24
+ torch.set_num_threads(os.cpu_count())
25
+
26
+ STYLE_MAP = {
27
+ "สมจริง (Photo)": "cinematic photo, highly detailed, 8k wallpaper, realistic skin texture",
28
+ "ศิลปะ (Artistic)": "digital art, masterpiece, intricate details, vibrant",
29
+ "ไม่เน้นสไตล์": ""
30
+ }
31
 
32
+ def gen(prompt, style_name, negative_prompt, steps, cfg, width, height):
33
+ if not prompt: return None
 
34
 
35
+ style_prompt = STYLE_MAP.get(style_name, "")
36
+ full_prompt = f"{prompt}, {style_prompt}"
 
 
 
 
 
37
 
38
+ with torch.no_grad():
39
+ image = pipe(
40
+ prompt=full_prompt,
41
+ negative_prompt=negative_prompt,
42
+ num_inference_steps=int(steps),
43
+ guidance_scale=float(cfg),
44
+ width=int(width),
45
+ height=int(height)
46
+ ).images[0]
47
  return image
48
 
49
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
50
+ gr.Markdown(f"### 🚀 CPU Optimized: {MODEL_ID}")
 
 
51
 
52
+ with gr.Row():
53
+ with gr.Column():
54
+ prompt = gr.Textbox(label="Prompt", lines=2)
55
+ style_name = gr.Radio(choices=list(STYLE_MAP.keys()), value="สมจริง (Photo)", label="Style")
56
+
57
+ with gr.Accordion("Advanced Settings", open=False):
58
+ negative = gr.Textbox(label="Negative", value="low quality, blurry, deformed")
59
+ steps = gr.Slider(1, 10, 4, step=1, label="Steps (Lightning/Turbo use 4-8)")
60
+ cfg = gr.Slider(0.0, 2.0, 1.0, step=0.1, label="CFG (Lightning use 1.0-2.0)")
61
+ width = gr.Slider(256, 512, 384, step=64, label="Width")
62
+ height = gr.Slider(256, 512, 512, step=64, label="Height")
63
+
64
+ btn = gr.Button("Generate", variant="primary")
65
+
66
+ with gr.Column():
67
+ output_img = gr.Image(label="Result")
68
+
69
+ btn.click(fn=gen, inputs=[prompt, style_name, negative, steps, cfg, width, height], outputs=[output_img])
70
+
71
+ demo.launch()