Daankular commited on
Commit
c1af5fa
·
1 Parent(s): 0a1745e

Integrate PSHuman locally — multi-view diffusion in-process, no remote calls

Browse files

- pipeline/pshuman_local.py: clones pengHTYX/PSHuman, downloads
pengHTYX/PSHuman_Unclip_768_6views, runs StableUnCLIPImg2ImgPipeline
directly (bypasses inference.py to avoid pytorch3d/kaolin imports).
Uses our own preprocessing (no mediapipe): insightface face-crop or
top-centre heuristic fallback. Returns 6 colour + 6 normal PIL views.
- app.py: gradio_pshuman_face now calls pshuman_local.run_pshuman_diffusion
with @spaces.GPU(duration=180). PSHuman tab shows colour/normal galleries.
Removes remote service URL input entirely.
- requirements.txt: add rembg (bg removal for outputs), icecream (PSHuman dep)

Files changed (3) hide show
  1. app.py +41 -85
  2. pipeline/pshuman_local.py +316 -0
  3. requirements.txt +2 -0
app.py CHANGED
@@ -1525,69 +1525,40 @@ def gradio_animate(
1525
  return None, f"Error:\n{traceback.format_exc()}", None
1526
 
1527
 
1528
- # ── PSHuman Face Transplant ────────────────────────────────────────────────────
1529
 
 
1530
  def gradio_pshuman_face(
1531
  input_image,
1532
- rigged_glb_path,
1533
- weight_threshold: float,
1534
- retract_mm: float,
1535
- pshuman_url: str,
1536
  progress=gr.Progress(),
1537
  ):
1538
  """
1539
- Full PSHuman face transplant pipeline:
1540
- 1. Run PSHuman on input_image colored OBJ face mesh
1541
- 2. Run face_transplant.py → stitch face into rigged GLB
1542
- 3. Return the combined GLB
1543
-
1544
- PSHuman runs as a remote service (pshuman_url). On ZeroGPU the service_url
1545
- must point to an externally-deployed PSHuman endpoint (PSHUMAN_URL env var
1546
- or user-provided URL in the UI). Local localhost will not work on ZeroGPU.
1547
  """
1548
  try:
1549
  if input_image is None:
1550
- return None, "Upload a portrait image first.", None
1551
- rigged = rigged_glb_path
1552
- if not rigged or not os.path.exists(str(rigged)):
1553
- return None, "No rigged GLB found — run the Rig step first.", None
1554
-
1555
- work_dir = tempfile.mkdtemp(prefix="pshuman_transplant_")
1556
- img_path = os.path.join(work_dir, "portrait.png")
1557
- if isinstance(input_image, np.ndarray):
1558
- Image.fromarray(input_image).save(img_path)
1559
- else:
1560
- input_image.save(img_path)
1561
-
1562
- # pipeline/ is already in sys.path via PIPELINE_DIR insertion at startup
1563
- # ── Step 1: PSHuman inference ──────────────────────────────────────────
1564
- progress(0.05, desc="Step 1/2: Running PSHuman (generates multi-view face)...")
1565
- from pipeline.pshuman_client import generate_pshuman_mesh
1566
- face_obj = os.path.join(work_dir, "pshuman_face.obj")
1567
- generate_pshuman_mesh(
1568
- image_path = img_path,
1569
- output_path = face_obj,
1570
- service_url = pshuman_url.strip() or "http://localhost:7862",
1571
- )
1572
 
1573
- # ── Step 2: Face transplant ────────────────────────────────────────────
1574
- progress(0.7, desc="Step 2/2: Stitching PSHuman face into rigged GLB...")
1575
- out_glb = os.path.join(work_dir, "rigged_pshuman_face.glb")
1576
-
1577
- from pipeline.face_transplant import transplant_face
1578
- transplant_face(
1579
- body_glb_path = str(rigged),
1580
- pshuman_mesh_path = face_obj,
1581
- output_path = out_glb,
1582
- weight_threshold = float(weight_threshold),
1583
- retract_amount = float(retract_mm) / 1000.0, # mm → metres
1584
- )
1585
 
1586
  progress(1.0, desc="Done!")
1587
- return out_glb, "PSHuman face transplant complete.", out_glb
1588
 
1589
  except Exception:
1590
- return None, f"Error:\n{traceback.format_exc()}", None
1591
 
1592
 
1593
  # ── Full pipeline ─────────────────────────────────────────────────────────────
@@ -1867,53 +1838,38 @@ with gr.Blocks(title="Image2Model", theme=gr.themes.Soft()) as demo:
1867
  # ════════════════════════════════════════════════════════════════════
1868
  with gr.Tab("PSHuman Face"):
1869
  gr.Markdown(
1870
- "### PSHuman Face Transplant\n"
1871
- "Generates a high-detail face mesh via PSHuman (multi-view diffusion), "
1872
- "then transplants it into the rigged GLB.\n\n"
1873
- "**Pipeline:** portrait → PSHuman (remote service) colored OBJ → face_transplant → rigged GLB with HD face\n\n"
1874
- "**Note:** On ZeroGPU, PSHuman must run as a remote service. "
1875
- "Set `PSHUMAN_URL` environment variable or enter the URL below."
 
1876
  )
1877
  with gr.Row():
1878
  with gr.Column(scale=1):
1879
  pshuman_img_input = gr.Image(
1880
- label="Portrait image (same as used for Generate)",
1881
  type="pil",
1882
  )
1883
- with gr.Accordion("Advanced settings", open=False):
1884
- pshuman_weight_thresh = gr.Slider(
1885
- minimum=0.1, maximum=0.9, value=0.35, step=0.05,
1886
- label="Head bone weight threshold",
1887
- info="Vertices with head-bone weight above this get replaced",
1888
- )
1889
- pshuman_retract_mm = gr.Slider(
1890
- minimum=0.0, maximum=20.0, value=4.0, step=0.5,
1891
- label="Face retract (mm)",
1892
- info="How far to push original face verts inward to avoid z-fighting",
1893
- )
1894
- pshuman_service_url = gr.Textbox(
1895
- label="PSHuman service URL",
1896
- value=os.environ.get("PSHUMAN_URL", "http://localhost:7862"),
1897
- info="pshuman_app.py Gradio endpoint (deployed separately)",
1898
- )
1899
- pshuman_btn = gr.Button("Generate HD Face", variant="primary")
1900
 
1901
  with gr.Column(scale=2):
1902
- pshuman_status = gr.Textbox(label="Status", lines=4, interactive=False)
1903
- pshuman_model_3d = gr.Model3D(
1904
- label="Preview", clear_color=[0.9, 0.9, 0.9, 1.0])
1905
- pshuman_glb_dl = gr.File(label="Download GLB (with PSHuman face)")
 
 
 
 
 
 
1906
 
1907
  pshuman_btn.click(
1908
  fn=gradio_pshuman_face,
1909
- inputs=[
1910
- pshuman_img_input,
1911
- rigged_glb_state,
1912
- pshuman_weight_thresh,
1913
- pshuman_retract_mm,
1914
- pshuman_service_url,
1915
- ],
1916
- outputs=[pshuman_glb_dl, pshuman_status, pshuman_model_3d],
1917
  )
1918
 
1919
  # ════════════════════════════════════════════════════════════════════
 
1525
  return None, f"Error:\n{traceback.format_exc()}", None
1526
 
1527
 
1528
+ # ── PSHuman Multi-View ────────────────────────────────────────────────────────
1529
 
1530
@spaces.GPU(duration=180)
def gradio_pshuman_face(
    input_image,
    progress=gr.Progress(),
):
    """
    Run PSHuman multi-view diffusion locally (in-process).

    Returns a 3-tuple for the Gradio outputs:
      (colour_views, normal_views, status) — 6 colour views + 6 normal-map
    views of the person, or (None, None, error_text) on failure.

    Full mesh reconstruction (pytorch3d / kaolin / torch_scatter) is skipped —
    those packages have no Python 3.13 wheels. The generated views can be used
    directly for inspection or fed into the face-swap step.
    """
    try:
        if input_image is None:
            return None, None, "Upload a portrait image first."

        # Normalise the input to a PIL RGBA image regardless of how Gradio
        # delivered it. (Previously ndarray inputs skipped the RGBA
        # conversion that PIL inputs received — inconsistent downstream.)
        img = Image.fromarray(input_image) if isinstance(input_image, np.ndarray) else input_image
        if img.mode != "RGBA":
            img = img.convert("RGBA")

        progress(0.1, desc="Loading PSHuman pipeline…")
        # Lazy import: keeps app start-up fast and avoids loading torch/diffusers
        # machinery until the user actually runs this tab.
        from pipeline.pshuman_local import run_pshuman_diffusion

        progress(0.2, desc="Running multi-view diffusion (40 steps × 7 views)…")
        colour_views, normal_views = run_pshuman_diffusion(img, device="cuda")

        progress(1.0, desc="Done!")
        return colour_views, normal_views, "Multi-view generation complete."

    except Exception:
        # Surface the full traceback in the status box rather than crashing the UI.
        return None, None, f"Error:\n{traceback.format_exc()}"
1562
 
1563
 
1564
  # ── Full pipeline ─────────────────────────────────────────────────────────────
 
1838
  # ════════════════════════════════════════════════════════════════════
1839
  with gr.Tab("PSHuman Face"):
1840
  gr.Markdown(
1841
+ "### PSHuman Multi-View (local)\n"
1842
+ "Generates 6 colour + 6 normal-map views of a person using "
1843
+ "[PSHuman](https://github.com/pengHTYX/PSHuman) "
1844
+ "(StableUnCLIP fine-tuned on multi-view human images).\n\n"
1845
+ "**Pipeline:** portrait multi-view diffusion (in-process) "
1846
+ "6 × colour + 6 × normal views\n\n"
1847
+ "**Views:** front · front-right · right · back · left · front-left"
1848
  )
1849
  with gr.Row():
1850
  with gr.Column(scale=1):
1851
  pshuman_img_input = gr.Image(
1852
+ label="Portrait image",
1853
  type="pil",
1854
  )
1855
+ pshuman_btn = gr.Button("Generate Views", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1856
 
1857
  with gr.Column(scale=2):
1858
+ pshuman_status = gr.Textbox(
1859
+ label="Status", lines=2, interactive=False)
1860
+ pshuman_colour_gallery = gr.Gallery(
1861
+ label="Colour views (front front-right → right → back → left → front-left)",
1862
+ columns=3, rows=2, height=420,
1863
+ )
1864
+ pshuman_normal_gallery = gr.Gallery(
1865
+ label="Normal maps",
1866
+ columns=3, rows=2, height=420,
1867
+ )
1868
 
1869
  pshuman_btn.click(
1870
  fn=gradio_pshuman_face,
1871
+ inputs=[pshuman_img_input],
1872
+ outputs=[pshuman_colour_gallery, pshuman_normal_gallery, pshuman_status],
 
 
 
 
 
 
1873
  )
1874
 
1875
  # ════════════════════════════════════════════════════════════════════
pipeline/pshuman_local.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ pshuman_local.py
3
+ ================
4
+ Run PSHuman multi-view diffusion in-process on ZeroGPU.
5
+ Generates 6 colour views + 6 normal views from a single portrait image.
6
+
7
+ Full mesh reconstruction (SMPL fitting, pytorch3d/kaolin texture projection)
8
+ is intentionally skipped — those deps have no Python 3.13 wheels.
9
+
10
+ Model : pengHTYX/PSHuman_Unclip_768_6views
11
+ Repo : https://github.com/pengHTYX/PSHuman (cloned to /tmp/pshuman-src)
12
+
13
+ The preprocessing replaces PSHuman's mediapipe-based face-crop helper with a
14
+ self-contained approach (insightface if available, else top-centre crop) so no
15
+ additional exotic deps are needed.
16
+ """
17
+ from __future__ import annotations
18
+
19
+ import os
20
+ import sys
21
+ import subprocess
22
+ import tempfile
23
+ from pathlib import Path
24
+ from typing import List, Optional, Tuple
25
+
26
+ import numpy as np
27
+ import torch
28
+ import torch.nn.functional as F
29
+ from PIL import Image
30
+
31
+ # ── Constants ─────────────────────────────────────────────────────────────────
32
+ _PSHUMAN_SRC = Path("/tmp/pshuman-src")
33
+ _PSHUMAN_CKPT = Path("/tmp/pshuman-ckpts")
34
+ _PSHUMAN_REPO = "https://github.com/pengHTYX/PSHuman.git"
35
+ _PSHUMAN_HF_ID = "pengHTYX/PSHuman_Unclip_768_6views"
36
+
37
+ _CANVAS_SIZE = 768 # model input resolution
38
+ _CROP_SIZE = 740 # subject bounding-box target size within canvas
39
+ _NUM_VIEWS = 7 # 6 body views + 1 face-crop view (model expectation)
40
+
41
+ # Cached pipeline (survives across ZeroGPU calls when the Space is warm)
42
+ _pipeline = None
43
+
44
+ # ── Setup helpers ─────────────────────────────────────────────────────────────
45
+
46
def _ensure_repo() -> None:
    """Clone the PSHuman GitHub repo into /tmp unless it is already there."""
    if _PSHUMAN_SRC.exists():
        return
    print("[pshuman] Cloning PSHuman repo…")
    clone_cmd = ["git", "clone", "--depth=1", _PSHUMAN_REPO, str(_PSHUMAN_SRC)]
    subprocess.run(clone_cmd, check=True)
    print("[pshuman] Repo cloned.")
55
+
56
+
57
def _ensure_sys_path() -> None:
    """Make the cloned PSHuman repo importable (prepended to sys.path)."""
    _ensure_repo()
    repo_root = str(_PSHUMAN_SRC)
    if repo_root not in sys.path:
        sys.path.insert(0, repo_root)
62
+
63
+
64
def _ensure_ckpt() -> None:
    """Download the PSHuman checkpoint snapshot once; later calls are no-ops."""
    if not _PSHUMAN_CKPT.exists():
        from huggingface_hub import snapshot_download
        print("[pshuman] Downloading model weights…")
        snapshot_download(repo_id=_PSHUMAN_HF_ID, local_dir=str(_PSHUMAN_CKPT))
        print("[pshuman] Weights downloaded.")
71
+
72
+
73
def load_pipeline(device: str = "cuda"):
    """Load (and cache) the PSHuman StableUnCLIPImg2ImgPipeline.

    The pipeline is kept in the module-level ``_pipeline`` cache so warm
    ZeroGPU invocations skip the slow ``from_pretrained`` load; cache hits
    only move the pipeline to *device*.
    """
    global _pipeline
    if _pipeline is None:
        _ensure_sys_path()
        _ensure_ckpt()

        # Import only after the cloned repo is on sys.path.
        from mvdiffusion.pipelines.pipeline_mvdiffusion_unclip import (
            StableUnCLIPImg2ImgPipeline,
        )

        print("[pshuman] Loading pipeline…")
        loaded = StableUnCLIPImg2ImgPipeline.from_pretrained(
            str(_PSHUMAN_CKPT),
            torch_dtype=torch.float16,
        )
        try:
            # xformers is optional — fall back to vanilla attention silently.
            loaded.unet.enable_xformers_memory_efficient_attention()
        except Exception:
            pass
        loaded.to(device)
        _pipeline = loaded
        print("[pshuman] Pipeline ready.")
        return loaded
    _pipeline.to(device)
    return _pipeline
100
+
101
+
102
+ # ── Image preprocessing ───────────────────────────────────────────────────────
103
+
104
def _preprocess_subject(
    image: Image.Image,
    canvas_size: int = 768,
    crop_size: int = 740,
) -> Image.Image:
    """
    Replicate PSHuman's load_image preprocessing:
      1. Ensure RGBA (no alpha → fully opaque foreground).
      2. Crop to the bounding box of the alpha channel.
      3. Scale the longest side to *crop_size*.
      4. Paste centred on a white *crop_size*-padded canvas.

    Parameters
    ----------
    image       : input PIL image, any mode
    canvas_size : output canvas edge length (default 768 — model input res)
    crop_size   : target size of the subject bbox inside the canvas (740)

    Returns an RGB image (canvas_size × canvas_size).
    """
    # PIL's convert("RGBA") already produces a fully opaque alpha channel,
    # so a plain convert covers the "no alpha" case (the previous explicit
    # putalpha of an all-255 mask was redundant).
    image = image.convert("RGBA") if image.mode != "RGBA" else image.copy()

    alpha = np.array(image)[:, :, 3]

    # Bounding box of non-transparent pixels
    rows = np.any(alpha > 0, axis=1)
    cols = np.any(alpha > 0, axis=0)
    if rows.any():
        r0, r1 = np.where(rows)[0][[0, -1]]
        c0, c1 = np.where(cols)[0][[0, -1]]
        crop = image.crop((c0, r0, c1 + 1, r1 + 1))
    else:
        # Fully transparent image — keep it whole rather than crash.
        crop = image

    # Scale longest side → crop_size; clamp to ≥ 1 px so an extremely
    # elongated bbox can never round the short side down to a 0-px resize.
    cw, ch = crop.size
    scale = crop_size / max(cw, ch)
    new_w = max(1, round(cw * scale))
    new_h = max(1, round(ch * scale))
    crop = crop.resize((new_w, new_h), Image.LANCZOS)

    # Paste centred on a white canvas
    canvas = Image.new("RGB", (canvas_size, canvas_size), (255, 255, 255))
    x_off = (canvas_size - new_w) // 2
    y_off = (canvas_size - new_h) // 2
    if crop.mode == "RGBA":
        # Use the alpha channel as paste mask so transparent pixels stay white.
        canvas.paste(crop.convert("RGB"), (x_off, y_off), crop.split()[3])
    else:
        canvas.paste(crop.convert("RGB"), (x_off, y_off))

    return canvas
153
+
154
+
155
def _get_face_crop(image: Image.Image) -> Image.Image:
    """
    Return a 256×256 face crop from *image* (768×768 preprocessed subject).

    Tries insightface detection first; falls back to a top-centre heuristic
    (portrait heads sit in the upper part of the frame) when insightface is
    unavailable, finds no face, or returns a degenerate box.
    """
    size = 256
    # ── insightface (detector cached on the function — model load is slow) ─
    try:
        if not hasattr(_get_face_crop, "_fa"):
            from insightface.app import FaceAnalysis
            fa = FaceAnalysis(allowed_modules=["detection"])
            fa.prepare(ctx_id=0 if torch.cuda.is_available() else -1, det_size=(320, 320))
            _get_face_crop._fa = fa
        faces = _get_face_crop._fa.get(np.array(image))
        if faces:
            b = faces[0].bbox.astype(int)
            x0, y0 = max(0, b[0]), max(0, b[1])
            x1, y1 = min(image.width, b[2]), min(image.height, b[3])
            if x1 > x0 and y1 > y0:  # guard against zero-area detections
                return image.crop((x0, y0, x1, y1)).resize((size, size), Image.LANCZOS)
    except Exception:
        pass  # best-effort: any detector failure falls through to the heuristic

    # ── heuristic: top-centre 40 % of image ────────────────────────────────
    w, h = image.size
    face_h = int(h * 0.40)
    margin = int(w * 0.15)
    return image.crop((margin, 0, w - margin, face_h)).resize((size, size), Image.LANCZOS)
182
+
183
+
184
def _to_tensor(img: Image.Image) -> torch.Tensor:
    """PIL RGB image → float32 tensor of shape (3, H, W), values in [0, 1]."""
    hwc = np.array(img.convert("RGB"), dtype=np.float32) / 255.0
    chw = torch.from_numpy(hwc).permute(2, 0, 1)
    return chw
188
+
189
+
190
def _to_pil(t: torch.Tensor) -> Image.Image:
    """Float tensor (3, H, W) with values in [0, 1] → PIL RGB image."""
    hwc = t.float().clamp(0.0, 1.0).permute(1, 2, 0)
    arr = (hwc.cpu().numpy() * 255).astype(np.uint8)
    return Image.fromarray(arr)
194
+
195
+
196
+ # ── Main inference ─────────────────────────────────────────────────────────────
197
+
198
def run_pshuman_diffusion(
    image: Image.Image,
    device: str = "cuda",
    seed: int = 42,
    guidance_scale: float = 3.0,
    num_inference_steps: int = 40,
    remove_bg_output: bool = True,
) -> Tuple[List[Image.Image], List[Image.Image]]:
    """
    Run PSHuman multi-view diffusion on a single person image.

    Parameters
    ----------
    image : Input PIL image (RGBA with bg removed, or RGB)
    device : 'cuda' or 'cpu'
    seed : RNG seed for reproducibility
    guidance_scale : CFG scale (default 3.0, per PSHuman paper)
    num_inference_steps: Diffusion steps (default 40)
    remove_bg_output : Strip background from output colour views via rembg

    Returns
    -------
    colour_views : List[PIL.Image] — 6 colour images
                   [front, front-right, right, back, left, front-left]
    normal_views : List[PIL.Image] — 6 matching normal maps
    """
    from einops import rearrange

    _ensure_sys_path()
    _ensure_ckpt()

    # ── 1. Preprocessing ──────────────────────────────────────────────────────
    # PSHuman was trained on right-facing subjects → flip input left-right
    # (outputs are flipped back at step 6 to match the user's orientation).
    flipped = image.transpose(Image.FLIP_LEFT_RIGHT)
    subject = _preprocess_subject(flipped)  # 768×768 RGB
    face = _get_face_crop(subject)          # 256×256 RGB

    body_t = _to_tensor(subject)  # (3, 768, 768)
    face_t = _to_tensor(face.resize((_CANVAS_SIZE, _CANVAS_SIZE), Image.LANCZOS))  # (3, 768, 768)

    # Stack: first (Nv-1) = 6 slots all hold the body image, last slot holds
    # the face crop — the per-view conditioning layout the model expects.
    imgs_in = torch.stack(
        [body_t] * (_NUM_VIEWS - 1) + [face_t], dim=0
    ).float().unsqueeze(0)  # (1, 7, 3, 768, 768)

    # ── 2. Prompt embeddings ─────────────────────────────────────────────────
    # Fixed per-view text embeddings shipped with the PSHuman repo clone.
    embeds_dir = _PSHUMAN_SRC / "mvdiffusion" / "data" / "fixed_prompt_embeds_7view"
    normal_embeds = torch.load(embeds_dir / "normal_embeds.pt", map_location="cpu")
    color_embeds = torch.load(embeds_dir / "clr_embeds.pt", map_location="cpu")

    # Shapes from repo: (1, Nv, N, C) or (Nv, N, C) — normalise to (1, Nv, N, C)
    if normal_embeds.dim() == 3:
        normal_embeds = normal_embeds.unsqueeze(0)
    if color_embeds.dim() == 3:
        color_embeds = color_embeds.unsqueeze(0)

    # ── 3. Batch construction (normal + colour domains) ───────────────────────
    # The image batch is duplicated so the first 7 items pair with the normal
    # prompt embeds and the last 7 with the colour embeds (CFG duplication, if
    # any, happens inside the pipeline). imgs_2x: (2, 7, 3, 768, 768) →
    # flat (14, 3, 768, 768).
    imgs_2x = torch.cat([imgs_in, imgs_in], dim=0).to(device, dtype=torch.float16)
    imgs_flat = rearrange(imgs_2x, "B Nv C H W -> (B Nv) C H W")

    # prompt_embeds: (2, 7, N, C) → flat (14, N, C) — normals first, matching
    # the image-batch ordering and the split at step 5.
    p_emb = torch.cat([normal_embeds, color_embeds], dim=0).to(device, dtype=torch.float16)
    p_emb = rearrange(p_emb, "B Nv N C -> (B Nv) N C")

    # ── 4. Diffusion ──────────────────────────────────────────────────────────
    pipe = load_pipeline(device)
    gen = torch.Generator(device=device).manual_seed(seed)

    with torch.autocast("cuda"):
        out_images = pipe(
            imgs_flat,
            None,  # image_embeds slot (unused positional arg of the repo pipeline)
            prompt_embeds=p_emb,
            generator=gen,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            eta=1.0,
            output_type="pt",
            num_images_per_prompt=1,
        ).images  # (14, 3, H, W) float [0,1]

    # ── 5. Split normals / colours ────────────────────────────────────────────
    # First half of the batch was conditioned on normal embeds, second on colour.
    bsz = out_images.shape[0] // 2  # = 7
    normals_pt = out_images[:bsz].clone()  # (7, 3, H, W)
    colors_pt = out_images[bsz:].clone()   # (7, 3, H, W)

    # View 0 colour = input (PSHuman convention — not generated)
    colors_pt[0] = imgs_flat[0].to(out_images.device)

    # Views 3 & 4 are generated horizontally mirrored → flip back.
    # NOTE(review): index choice follows the upstream PSHuman inference code —
    # confirm against the repo if view orientation ever looks wrong.
    for j in (3, 4):
        normals_pt[j] = torch.flip(normals_pt[j], dims=[2])
        colors_pt[j] = torch.flip(colors_pt[j], dims=[2])

    # Paste the face-normal crop (view 6, resized to 256px) into the top-centre
    # of normals[0] (rows 0:256, columns 256:512 of the 768-wide canvas).
    face_nrm = F.interpolate(
        normals_pt[6].unsqueeze(0), size=(256, 256),
        mode="bilinear", align_corners=False,
    ).squeeze(0)
    normals_pt[0][:, :256, 256:512] = face_nrm

    # ── 6. Convert to PIL (body views 0–5 only, skip face-crop view 6) ───────
    colour_views = [_to_pil(colors_pt[j]) for j in range(6)]
    normal_views = [_to_pil(normals_pt[j]) for j in range(6)]

    # All outputs are in PSHuman's flipped (right-facing) space.
    # Flip back so they match the user's original image orientation.
    colour_views = [v.transpose(Image.FLIP_LEFT_RIGHT) for v in colour_views]
    normal_views = [v.transpose(Image.FLIP_LEFT_RIGHT) for v in normal_views]

    if remove_bg_output:
        try:
            from rembg import remove as _rembg
            colour_views = [_rembg(v) for v in colour_views]
        except Exception:
            pass  # rembg optional — return raw outputs if unavailable

    return colour_views, normal_views
requirements.txt CHANGED
@@ -68,6 +68,8 @@ scikit-learn
68
  pandas
69
 
70
  # Utils
 
 
71
  easydict
72
  omegaconf
73
  yacs
 
68
  pandas
69
 
70
  # Utils
71
+ rembg
72
+ icecream
73
  easydict
74
  omegaconf
75
  yacs