ysharma HF Staff commited on
Commit
17d51c2
·
verified ·
1 Parent(s): 291586f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +112 -33
app.py CHANGED
@@ -1,42 +1,121 @@
1
  import gradio as gr
2
- import spaces
 
 
 
 
 
 
3
  import torch
4
- from threading import Thread
5
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
6
  from fastapi.responses import HTMLResponse
7
- from pathlib import Path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- app = gr.Server()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  HOME = Path(__file__).parent
11
 
12
- MODEL_ID = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
13
- tok = AutoTokenizer.from_pretrained(MODEL_ID)
14
- model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16).to("cuda")
15
-
16
- @spaces.GPU
17
- def _generate(text: str):
18
- inputs = tok.apply_chat_template(
19
- [{"role": "user", "content": f"Summarize in 3 bullets:\n\n{text}"}],
20
- return_tensors="pt", return_dict=True, add_generation_prompt=True,
21
- ).to("cuda")
22
- streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
23
- Thread(target=model.generate, kwargs=dict(
24
- **inputs, streamer=streamer, max_new_tokens=300, do_sample=False,
25
- )).start()
26
- return streamer
27
-
28
- @app.mcp.tool(name="summarize")
29
- @app.api(name="summarize", concurrency_limit=1, stream_every=0.2)
30
- def summarize(text: str) -> str:
31
- """Summarize the input text into 3 bullet points."""
32
- out = ""
33
- for chunk in _generate(text):
34
- out += chunk
35
- yield out
36
-
37
- @app.get("/", response_class=HTMLResponse)
38
- async def index():
39
  return (HOME / "index.html").read_text(encoding="utf-8")
40
 
 
41
  if __name__ == "__main__":
42
- app.launch(mcp_server=True)
 
1
  import gradio as gr
2
+ import gc
3
+ import os
4
+ import random
5
+ import tempfile
6
+ from pathlib import Path
7
+
8
+ import numpy as np
9
  import torch
10
+ from PIL import Image
 
11
  from fastapi.responses import HTMLResponse
12
+ from gradio.data_classes import FileData
13
+
14
+ # ZeroGPU. Degrade gracefully off-Spaces so `python app.py` works locally.
15
+ try:
16
+ import spaces
17
+ _HAS_SPACES = True
18
+ except ImportError:
19
+ _HAS_SPACES = False
20
+
21
+ # --- Model load ---------------------------------------------------------------
22
+ # Heavy startup is wrapped in `gr.NO_RELOAD` so `gradio app.py` hot reload
23
+ # does not redownload weights every time you save the HTML.
24
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
25
+ DTYPE = torch.bfloat16
26
+
27
+ if gr.NO_RELOAD:
28
+ from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
29
+ from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
30
+ from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
31
+
32
+ PIPE = QwenImageEditPlusPipeline.from_pretrained(
33
+ "FireRedTeam/FireRed-Image-Edit-1.1",
34
+ transformer=QwenImageTransformer2DModel.from_pretrained(
35
+ "prithivMLmods/Qwen-Image-Edit-Rapid-AIO-V19",
36
+ torch_dtype=DTYPE,
37
+ device_map="cuda",
38
+ ),
39
+ torch_dtype=DTYPE,
40
+ ).to(DEVICE)
41
+
42
+ try:
43
+ PIPE.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
44
+ print("Flash Attention 3 processor set.")
45
+ except Exception as e:
46
+ print(f"FA3 processor not set: {e}")
47
+
48
+ NEGATIVE_PROMPT = (
49
+ "worst quality, low quality, bad anatomy, bad hands, text, error, "
50
+ "missing fingers, extra digit, fewer digits, cropped, jpeg artifacts, "
51
+ "signature, watermark, username, blurry"
52
+ )
53
+ MAX_SEED = np.iinfo(np.int32).max
54
+
55
+
56
+ def _round_dims(image: Image.Image) -> tuple[int, int]:
57
+ w, h = image.size
58
+ if w > h:
59
+ new_w, new_h = 1024, int(1024 * h / w)
60
+ else:
61
+ new_h, new_w = 1024, int(1024 * w / h)
62
+ return (new_w // 8) * 8, (new_h // 8) * 8
63
 
64
+
65
+ # --- Inner GPU function -------------------------------------------------------
66
+ # Per the reference: @spaces.GPU goes on the *inner* function that runs the
67
+ # model. The outer @server.api route just plugs it into the queue.
68
+ if _HAS_SPACES:
69
+ @spaces.GPU
70
+ def _edit(image: Image.Image, prompt: str, seed: int, steps: int) -> Image.Image:
71
+ return _run_pipe(image, prompt, seed, steps)
72
+ else:
73
+ def _edit(image, prompt, seed, steps):
74
+ return _run_pipe(image, prompt, seed, steps)
75
+
76
+
77
+ def _run_pipe(image, prompt, seed, steps):
78
+ gc.collect()
79
+ if torch.cuda.is_available():
80
+ torch.cuda.empty_cache()
81
+ width, height = _round_dims(image)
82
+ generator = torch.Generator(device=DEVICE).manual_seed(seed)
83
+ return PIPE(
84
+ image=[image],
85
+ prompt=prompt,
86
+ negative_prompt=NEGATIVE_PROMPT,
87
+ width=width,
88
+ height=height,
89
+ num_inference_steps=steps,
90
+ true_cfg_scale=1.0,
91
+ generator=generator,
92
+ ).images[0]
93
+
94
+
95
+ # --- Server -------------------------------------------------------------------
96
+ server = gr.Server()
97
  HOME = Path(__file__).parent
98
 
99
+
100
+ @server.api(name="edit_image", concurrency_limit=1)
101
+ def edit_image(image: FileData, prompt: str) -> dict:
102
+ """Edit an image guided by a text prompt using FireRed-Image-Edit 1.1."""
103
+ if not prompt or not prompt.strip():
104
+ return {"error": "Please enter an edit prompt."}
105
+ src = Image.open(image["path"]).convert("RGB")
106
+ seed = random.randint(0, MAX_SEED)
107
+ result = _edit(src, prompt.strip(), seed, steps=4)
108
+
109
+ fd, out_path = tempfile.mkstemp(suffix=".png")
110
+ os.close(fd)
111
+ result.save(out_path)
112
+ return {"image": FileData(path=out_path), "seed": seed}
113
+
114
+
115
+ @server.get("/", response_class=HTMLResponse)
116
+ async def homepage():
 
 
 
 
 
 
 
 
 
117
  return (HOME / "index.html").read_text(encoding="utf-8")
118
 
119
+
120
  if __name__ == "__main__":
121
+ server.launch(mcp_server=True, show_error=True)