import io
import math
import os
from datetime import datetime

import torch
from deep_translator import GoogleTranslator
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
from flask import Flask, request, jsonify, send_file, render_template_string
from PIL import Image, ImageFilter
print("STARTING IMAGE PRO AI")
# Directory that generated/processed images are written to. Override with the
# OUTPUT_DIR environment variable (same pattern as MODEL_ID); the default is
# "outputs", resolved relative to the current working directory.
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "outputs")
os.makedirs(OUTPUT_DIR, exist_ok=True)
# Percentage (0-100) of the current/last text-to-image run. One process-wide
# value: generate_image() writes it and the /progress route reads it.
progress_value = 0
# ======================
# TRANSLATE & PROMPT ENHANCE
# ======================
def translate(text):
    """Translate *text* to English via Google Translate (source auto-detected).

    Translation is best-effort: if the translator raises, or returns an empty
    result, the original *text* is returned unchanged. Falsy input (``""``,
    ``None``) is returned as-is without calling the translator.
    """
    if not text:
        # Nothing to translate; avoid a pointless network round-trip.
        return text
    try:
        translated = GoogleTranslator(source="auto", target="en").translate(text)
    except Exception:
        # Best-effort fallback (e.g. network failure). `Exception` rather than a
        # bare `except:` so Ctrl-C / SystemExit still propagate.
        return text
    # Never hand callers an empty/None prompt when we had real input.
    return translated or text
STYLE_PRESETS = {
"cinematic": "cinematic lighting, ultra realistic, 8k, film still, depth of field",
"anime": "anime style, clean lines, vibrant colors, detailed illustration",
"realistic": "photo realistic, natural lighting, high detail, 8k",
"neon": "neon lights, cyberpunk, glowing colors, night city",
"cartoon": "cartoon style, bold outlines, flat colors, playful",
}
def enhance(text, style: str | None = None):
base = "high detail, professional, sharp focus"
extra = base
if style and style in STYLE_PRESETS:
extra = STYLE_PRESETS[style] + ", " + base
return text + ", " + extra
# ======================
# MODEL (FASTER: SD-TURBO)
# ======================
print("Loading Stable Diffusion Turbo...")
MODEL_ID = os.getenv("MODEL_ID", "stabilityai/sd-turbo")
# Target device. Defaults to CPU; override with e.g. DEVICE=cuda.
DEVICE = os.getenv("DEVICE", "cpu")

# Text-to-image pipeline (used by generate_image()).
pipe = StableDiffusionPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,
)
pipe = pipe.to(DEVICE)
pipe.enable_attention_slicing()

# Image-to-image pipeline (used by the AI blend). It is built from the modules
# already loaded above instead of calling from_pretrained() a second time, which
# would load a second full copy of the same weights. The scheduler is cloned
# rather than shared because it holds per-run state (its timesteps), and the two
# pipelines may run at the same time if requests are served from parallel threads.
_shared_components = dict(pipe.components)
_shared_components["scheduler"] = type(pipe.scheduler).from_config(pipe.scheduler.config)
img2img_pipe = StableDiffusionImg2ImgPipeline(
    **_shared_components,
    requires_safety_checker=pipe.config.get("requires_safety_checker", True),
)
img2img_pipe = img2img_pipe.to(DEVICE)
img2img_pipe.enable_attention_slicing()
print("MODEL READY")
# ======================
# API
# ======================
# Flask application object; every @api.route handler in this module registers on it.
api = Flask(__name__)
# ======================
# HTML UI (REFRESHED)
# ======================
# Page template rendered by the "/" route through render_template_string.
# NOTE(review): as written this is only a title line with no markup or scripts;
# the real front-end appears to be missing here — confirm before shipping.
HTML = """
IMAGE PRO AI STUDIO
"""
@api.route("/")
def home():
    """Serve the landing page built from the module-level HTML template."""
    page = render_template_string(HTML)
    return page
# ======================
# PROGRESS
# ======================
@api.route("/progress")
def progress():
    """Return the shared generation progress as ``{"progress": <0-100>}``."""
    # Only reads the module global, so no `global` declaration is required.
    return jsonify(progress=progress_value)
# ======================
# GENERATE IMAGE
# ======================
def generate_image(prompt, style: str | None = None):
    """Render *prompt* with the text-to-image pipeline and save it as a PNG.

    The prompt is first translated to English and extended with quality/style
    keywords via enhance(). While the pipeline runs, the module-level
    ``progress_value`` rises from 0 to 100 for the /progress route. It is a
    single shared value, so concurrent generations overwrite each other's progress.

    Args:
        prompt: user prompt, in any language.
        style: optional style name forwarded to enhance().

    Returns:
        Path of the saved PNG inside OUTPUT_DIR.
    """
    global progress_value
    steps = 8  # fewer steps = faster
    # Reset first; otherwise /progress keeps reporting the previous run's 100%
    # until this run's first step completes.
    progress_value = 0
    prompt = translate(prompt)
    prompt = enhance(prompt, style)

    def callback(step, timestep, latents):
        global progress_value
        # diffusers passes the 0-based index of the step that just finished, so
        # step + 1 steps are done. Using `step` alone topped out at 87% for 8 steps.
        progress_value = min(100, int((step + 1) * 100 / steps))

    # NOTE(review): recent diffusers versions deprecate callback/callback_steps
    # in favour of callback_on_step_end; migrate when pinning a newer version.
    image = pipe(
        prompt,
        num_inference_steps=steps,
        guidance_scale=1.5,
        callback=callback,
        callback_steps=1,
    ).images[0]
    progress_value = 100

    # Microsecond resolution: with whole seconds, two images generated within
    # the same second got the same name and the later one overwrote the earlier.
    filename = datetime.now().strftime("%Y%m%d%H%M%S%f") + ".png"
    path = os.path.join(OUTPUT_DIR, filename)
    image.save(path)
    return path
# ======================
# GENERATE
# ======================
@api.route("/generate", methods=["POST"])
def generate():
    """Generate an image from a JSON body ``{"prompt": str, "style": str}``.

    Responses:
        * the PNG on success;
        * 400 with ``{"error": ...}`` when the body is not a JSON object;
        * 500 with ``{"error": ...}`` when generation fails.
    """
    # silent=True turns malformed JSON into None, so it gets the JSON 400 below
    # instead of an HTML error page.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "expected a JSON object"}), 400
    try:
        prompt = data.get("prompt", "")
        style = data.get("style")
        path = generate_image(prompt, style)
        return send_file(path, mimetype="image/png")
    except Exception as e:
        # Top-level request boundary: keep the traceback in the server log and
        # return an error status so clients can detect failure from the status code.
        api.logger.exception("image generation failed")
        return jsonify({"error": str(e)}), 500
# ======================
# PRODUCT
# ======================
def _generate_variant(suffix, default_style):
    """Shared handler for the themed generation endpoints below.

    Reads ``prompt`` and ``style`` from the JSON request body, translates the
    prompt, appends the endpoint-specific English *suffix* and renders the
    result with generate_image(). A missing or empty ``style`` falls back to
    *default_style*; a missing or null ``prompt`` is treated as "".

    Responses: the PNG on success; 400 JSON error for a non-object body;
    500 JSON error if generation fails.
    """
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "expected a JSON object"}), 400
    style = data.get("style") or default_style
    try:
        # Only the user's text is translated; the suffix is already English.
        # NOTE(review): generate_image() translates the whole prompt again;
        # presumably a no-op on English text, but it costs an extra round-trip.
        prompt = translate(data.get("prompt") or "") + suffix
        path = generate_image(prompt, style)
        return send_file(path, mimetype="image/png")
    except Exception as e:
        api.logger.exception("themed image generation failed")
        return jsonify({"error": str(e)}), 500


@api.route("/product", methods=["POST"])
def product():
    """Product shot on a white studio background (default style: realistic)."""
    return _generate_variant(
        ", product photography, studio lighting, white background", "realistic"
    )


# ======================
# LOGO
# ======================
@api.route("/logo", methods=["POST"])
def logo():
    """Minimalist vector logo (default style: cartoon)."""
    return _generate_variant(", minimalist vector logo", "cartoon")


# ======================
# BANNER
# ======================
@api.route("/banner", methods=["POST"])
def banner():
    """Website banner design (default style: cinematic)."""
    return _generate_variant(", modern website banner design", "cinematic")


# ======================
# SOCIAL
# ======================
@api.route("/social", methods=["POST"])
def social():
    """Instagram-style social media post (default style: neon)."""
    return _generate_variant(", instagram social media post", "neon")
# ======================
# RESTORE
# ======================
@api.route("/restore", methods=["POST"])
def restore():
    """Sharpen the uploaded ``image`` file and return it as a PNG.

    The result is also written to ``OUTPUT_DIR/restore.png`` (overwritten on
    every call). Responds 400 ``"no image"`` when the upload is missing.
    """
    if "image" not in request.files:
        return "no image", 400
    img = Image.open(request.files["image"].stream)
    if img.mode in ("P", "PA"):
        # Pillow refuses to filter palette images; expand to true colour first,
        # keeping transparency when the palette carries it.
        has_alpha = img.mode == "PA" or "transparency" in img.info
        img = img.convert("RGBA" if has_alpha else "RGB")
    elif img.mode in ("CMYK", "YCbCr", "LAB", "HSV"):
        # PNG cannot store these modes, so saving below would fail.
        img = img.convert("RGB")
    img = img.filter(ImageFilter.SHARPEN)

    # Encode once in memory and answer from that buffer. A concurrent request
    # that overwrites the shared restore.png then cannot change or truncate what
    # this client receives. The file on disk is still written.
    buffer = io.BytesIO()
    img.save(buffer, format="PNG")
    path = os.path.join(OUTPUT_DIR, "restore.png")
    with open(path, "wb") as fh:
        fh.write(buffer.getvalue())
    buffer.seek(0)
    return send_file(buffer, mimetype="image/png")
# ======================
# UPSCALE
# ======================
@api.route("/upscale", methods=["POST"])
def upscale():
    """Enlarge the uploaded ``image`` 2x with Lanczos resampling; return a PNG.

    The result is also written to ``OUTPUT_DIR/upscale.png`` (overwritten on
    every call). Responds 400 ``"no image"`` when the upload is missing.
    """
    if "image" not in request.files:
        return "no image", 400
    img = Image.open(request.files["image"].stream)
    if img.mode in ("P", "PA"):
        # Pillow resizes palette images with nearest-neighbour whatever filter is
        # requested; convert so LANCZOS actually applies (keeping transparency).
        has_alpha = img.mode == "PA" or "transparency" in img.info
        img = img.convert("RGBA" if has_alpha else "RGB")
    elif img.mode in ("CMYK", "YCbCr", "LAB", "HSV"):
        # PNG cannot store these modes, so saving below would fail.
        img = img.convert("RGB")
    width, height = img.size
    img = img.resize((width * 2, height * 2), Image.LANCZOS)

    # Serve from memory so a concurrent request overwriting the shared
    # upscale.png cannot alter this response; the file is still written.
    buffer = io.BytesIO()
    img.save(buffer, format="PNG")
    path = os.path.join(OUTPUT_DIR, "upscale.png")
    with open(path, "wb") as fh:
        fh.write(buffer.getvalue())
    buffer.seek(0)
    return send_file(buffer, mimetype="image/png")
# ======================
# COLORIZE
# ======================
@api.route("/colorize", methods=["POST"])
def colorize():
    """Convert the uploaded ``image`` to RGB mode and return it as a PNG.

    NOTE(review): despite the name this performs no colourisation. It only
    changes the pixel mode, so a grayscale photo comes back still gray.

    The result is also written to ``OUTPUT_DIR/color.png`` (overwritten on
    every call). Responds 400 ``"no image"`` when the upload is missing.
    """
    if "image" not in request.files:
        return "no image", 400
    img = Image.open(request.files["image"].stream).convert("RGB")

    # Serve from memory so a concurrent request overwriting the shared
    # color.png cannot alter this response; the file is still written.
    buffer = io.BytesIO()
    img.save(buffer, format="PNG")
    path = os.path.join(OUTPUT_DIR, "color.png")
    with open(path, "wb") as fh:
        fh.write(buffer.getvalue())
    buffer.seek(0)
    return send_file(buffer, mimetype="image/png")
# ======================
# BLEND (CLASSIC)
# ======================
def _blend_simple(file_a, file_b, mix: float, style: str | None = None):
    """Cross-fade two uploaded images and save the result as ``blend.png``.

    Both uploads are opened as RGBA, and image B is resized to image A's size.
    ``mix`` is clamped to [0, 1] and used as Image.blend's weight for image B.
    ``style`` is accepted for signature parity with _blend_ai() but is unused.

    Returns:
        Path of the saved file inside OUTPUT_DIR.
    """
    first = Image.open(file_a.stream).convert("RGBA")
    second = Image.open(file_b.stream).convert("RGBA").resize(first.size)
    weight = min(1.0, mix)
    weight = max(0.0, weight)
    out_path = os.path.join(OUTPUT_DIR, "blend.png")
    Image.blend(first, second, weight).save(out_path)
    return out_path
# ======================
# BLEND PRO (AI IMG2IMG)
# ======================
def _blend_ai(file_a, file_b, mix: float, style: str | None = None):
    """Cross-fade two uploaded images, then refine the mix with img2img.

    The two images are blended exactly like the classic blend: B is resized to
    A and ``mix`` is clamped to [0, 1]. The blend is then passed through the
    img2img pipeline (strength 0.6, 8 steps, guidance 1.5) with a generic
    prompt, plus the STYLE_PRESETS entry for ``style`` when it names one.

    Returns:
        Path of the saved ``blend_pro.png`` inside OUTPUT_DIR.
    """
    first = Image.open(file_a.stream).convert("RGBA")
    second = Image.open(file_b.stream).convert("RGBA").resize(first.size)
    weight = max(0.0, min(1.0, mix))
    init_image = Image.blend(first, second, weight).convert("RGB")

    prompt_parts = ["high quality artistic blend of two images"]
    if style and style in STYLE_PRESETS:
        prompt_parts.append(STYLE_PRESETS[style])
    ai_prompt = ", ".join(prompt_parts)

    output = img2img_pipe(
        prompt=ai_prompt,
        image=init_image,
        strength=0.6,
        num_inference_steps=8,
        guidance_scale=1.5,
    )
    refined = output.images[0]
    out_path = os.path.join(OUTPUT_DIR, "blend_pro.png")
    refined.save(out_path)
    return out_path
@api.route("/blend", methods=["POST"])
def blend():
    """Classic (non-AI) blend of two uploaded images; returns a PNG.

    Form fields:
        image_a, image_b: required files (400 if either is missing).
        mix: float weight of image B; unparsable or NaN values fall back to 0.5.
        style: optional, passed through to the blend helper.
    """
    if "image_a" not in request.files or "image_b" not in request.files:
        return "need image_a and image_b", 400
    file_a = request.files["image_a"]
    file_b = request.files["image_b"]
    try:
        mix = float(request.form.get("mix", 0.5))
    except ValueError:
        # Malformed client input: use an even blend instead of failing.
        mix = 0.5
    if math.isnan(mix):
        # float() accepts "nan", and min/max clamping would silently turn it into 1.0.
        mix = 0.5
    style = request.form.get("style")
    path = _blend_simple(file_a, file_b, mix, style)
    return send_file(path, mimetype="image/png")
@api.route("/blend-pro", methods=["POST"])
def blend_pro():
    """AI-refined (img2img) blend of two uploaded images; returns a PNG.

    Form fields:
        image_a, image_b: required files (400 if either is missing).
        mix: float weight of image B; unparsable or NaN values fall back to 0.5.
        style: optional style name passed to the AI blend helper.
    """
    if "image_a" not in request.files or "image_b" not in request.files:
        return "need image_a and image_b", 400
    file_a = request.files["image_a"]
    file_b = request.files["image_b"]
    try:
        mix = float(request.form.get("mix", 0.5))
    except ValueError:
        # Malformed client input: use an even blend instead of failing.
        mix = 0.5
    if math.isnan(mix):
        # float() accepts "nan", and min/max clamping would silently turn it into 1.0.
        mix = 0.5
    style = request.form.get("style")
    path = _blend_ai(file_a, file_b, mix, style)
    return send_file(path, mimetype="image/png")
# ======================
if __name__ == "__main__":
    # Listens on all interfaces. The port defaults to 7860 and can be
    # overridden with the PORT environment variable (as many hosts provide).
    api.run(host="0.0.0.0", port=int(os.getenv("PORT", "7860")))