Spaces:
Running on Zero
Running on Zero
Integrate PSHuman locally — multi-view diffusion in-process, no remote calls
- pipeline/pshuman_local.py: clones pengHTYX/PSHuman, downloads
pengHTYX/PSHuman_Unclip_768_6views, runs StableUnCLIPImg2ImgPipeline
directly (bypasses inference.py to avoid pytorch3d/kaolin imports).
Uses our own preprocessing (no mediapipe): insightface face-crop or
top-centre heuristic fallback. Returns 6 colour + 6 normal PIL views.
- app.py: gradio_pshuman_face now calls pshuman_local.run_pshuman_diffusion
with @spaces.GPU(duration=180). PSHuman tab shows colour/normal galleries.
Removes remote service URL input entirely.
- requirements.txt: add rembg (bg removal for outputs), icecream (PSHuman dep)
- app.py +41 -85
- pipeline/pshuman_local.py +316 -0
- requirements.txt +2 -0
app.py
CHANGED
|
@@ -1525,69 +1525,40 @@ def gradio_animate(
|
|
| 1525 |
return None, f"Error:\n{traceback.format_exc()}", None
|
| 1526 |
|
| 1527 |
|
| 1528 |
-
# ── PSHuman
|
| 1529 |
|
|
|
|
| 1530 |
def gradio_pshuman_face(
|
| 1531 |
input_image,
|
| 1532 |
-
rigged_glb_path,
|
| 1533 |
-
weight_threshold: float,
|
| 1534 |
-
retract_mm: float,
|
| 1535 |
-
pshuman_url: str,
|
| 1536 |
progress=gr.Progress(),
|
| 1537 |
):
|
| 1538 |
"""
|
| 1539 |
-
|
| 1540 |
-
|
| 1541 |
-
|
| 1542 |
-
|
| 1543 |
-
|
| 1544 |
-
|
| 1545 |
-
must point to an externally-deployed PSHuman endpoint (PSHUMAN_URL env var
|
| 1546 |
-
or user-provided URL in the UI). Local localhost will not work on ZeroGPU.
|
| 1547 |
"""
|
| 1548 |
try:
|
| 1549 |
if input_image is None:
|
| 1550 |
-
return None, "Upload a portrait image first."
|
| 1551 |
-
rigged = rigged_glb_path
|
| 1552 |
-
if not rigged or not os.path.exists(str(rigged)):
|
| 1553 |
-
return None, "No rigged GLB found — run the Rig step first.", None
|
| 1554 |
-
|
| 1555 |
-
work_dir = tempfile.mkdtemp(prefix="pshuman_transplant_")
|
| 1556 |
-
img_path = os.path.join(work_dir, "portrait.png")
|
| 1557 |
-
if isinstance(input_image, np.ndarray):
|
| 1558 |
-
Image.fromarray(input_image).save(img_path)
|
| 1559 |
-
else:
|
| 1560 |
-
input_image.save(img_path)
|
| 1561 |
-
|
| 1562 |
-
# pipeline/ is already in sys.path via PIPELINE_DIR insertion at startup
|
| 1563 |
-
# ── Step 1: PSHuman inference ──────────────────────────────────────────
|
| 1564 |
-
progress(0.05, desc="Step 1/2: Running PSHuman (generates multi-view face)...")
|
| 1565 |
-
from pipeline.pshuman_client import generate_pshuman_mesh
|
| 1566 |
-
face_obj = os.path.join(work_dir, "pshuman_face.obj")
|
| 1567 |
-
generate_pshuman_mesh(
|
| 1568 |
-
image_path = img_path,
|
| 1569 |
-
output_path = face_obj,
|
| 1570 |
-
service_url = pshuman_url.strip() or "http://localhost:7862",
|
| 1571 |
-
)
|
| 1572 |
|
| 1573 |
-
|
| 1574 |
-
|
| 1575 |
-
|
| 1576 |
-
|
| 1577 |
-
|
| 1578 |
-
|
| 1579 |
-
|
| 1580 |
-
|
| 1581 |
-
|
| 1582 |
-
weight_threshold = float(weight_threshold),
|
| 1583 |
-
retract_amount = float(retract_mm) / 1000.0, # mm → metres
|
| 1584 |
-
)
|
| 1585 |
|
| 1586 |
progress(1.0, desc="Done!")
|
| 1587 |
-
return
|
| 1588 |
|
| 1589 |
except Exception:
|
| 1590 |
-
return None, f"Error:\n{traceback.format_exc()}"
|
| 1591 |
|
| 1592 |
|
| 1593 |
# ── Full pipeline ─────────────────────────────────────────────────────────────
|
|
@@ -1867,53 +1838,38 @@ with gr.Blocks(title="Image2Model", theme=gr.themes.Soft()) as demo:
|
|
| 1867 |
# ════════════════════════════════════════════════════════════════════
|
| 1868 |
with gr.Tab("PSHuman Face"):
|
| 1869 |
gr.Markdown(
|
| 1870 |
-
"### PSHuman
|
| 1871 |
-
"Generates
|
| 1872 |
-
"
|
| 1873 |
-
"
|
| 1874 |
-
"**
|
| 1875 |
-
"
|
|
|
|
| 1876 |
)
|
| 1877 |
with gr.Row():
|
| 1878 |
with gr.Column(scale=1):
|
| 1879 |
pshuman_img_input = gr.Image(
|
| 1880 |
-
label="Portrait image
|
| 1881 |
type="pil",
|
| 1882 |
)
|
| 1883 |
-
|
| 1884 |
-
pshuman_weight_thresh = gr.Slider(
|
| 1885 |
-
minimum=0.1, maximum=0.9, value=0.35, step=0.05,
|
| 1886 |
-
label="Head bone weight threshold",
|
| 1887 |
-
info="Vertices with head-bone weight above this get replaced",
|
| 1888 |
-
)
|
| 1889 |
-
pshuman_retract_mm = gr.Slider(
|
| 1890 |
-
minimum=0.0, maximum=20.0, value=4.0, step=0.5,
|
| 1891 |
-
label="Face retract (mm)",
|
| 1892 |
-
info="How far to push original face verts inward to avoid z-fighting",
|
| 1893 |
-
)
|
| 1894 |
-
pshuman_service_url = gr.Textbox(
|
| 1895 |
-
label="PSHuman service URL",
|
| 1896 |
-
value=os.environ.get("PSHUMAN_URL", "http://localhost:7862"),
|
| 1897 |
-
info="pshuman_app.py Gradio endpoint (deployed separately)",
|
| 1898 |
-
)
|
| 1899 |
-
pshuman_btn = gr.Button("Generate HD Face", variant="primary")
|
| 1900 |
|
| 1901 |
with gr.Column(scale=2):
|
| 1902 |
-
pshuman_status
|
| 1903 |
-
|
| 1904 |
-
|
| 1905 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1906 |
|
| 1907 |
pshuman_btn.click(
|
| 1908 |
fn=gradio_pshuman_face,
|
| 1909 |
-
inputs=[
|
| 1910 |
-
|
| 1911 |
-
rigged_glb_state,
|
| 1912 |
-
pshuman_weight_thresh,
|
| 1913 |
-
pshuman_retract_mm,
|
| 1914 |
-
pshuman_service_url,
|
| 1915 |
-
],
|
| 1916 |
-
outputs=[pshuman_glb_dl, pshuman_status, pshuman_model_3d],
|
| 1917 |
)
|
| 1918 |
|
| 1919 |
# ════════════════════════════════════════════════════════════════════
|
|
|
|
| 1525 |
return None, f"Error:\n{traceback.format_exc()}", None
|
| 1526 |
|
| 1527 |
|
| 1528 |
+
# ── PSHuman Multi-View ────────────────────────────────────────────────────────
|
| 1529 |
|
| 1530 |
+
@spaces.GPU(duration=180)
def gradio_pshuman_face(
    input_image,
    progress=gr.Progress(),
):
    """
    Run PSHuman multi-view diffusion locally (in-process).
    Returns 6 colour views + 6 normal-map views of the person.

    Full mesh reconstruction (pytorch3d / kaolin / torch_scatter) is skipped —
    those packages have no Python 3.13 wheels. The generated views can be used
    directly for inspection or fed into the face-swap step.

    Returns a 3-tuple (colour_views, normal_views, status_message); on any
    failure the first two slots are None and the status carries the traceback,
    matching the Gallery/Gallery/Textbox outputs wired up in the UI.
    """
    try:
        if input_image is None:
            return None, None, "Upload a portrait image first."

        # Normalise the Gradio input to a PIL image. NOTE(review): the
        # ndarray branch skips the RGBA conversion the PIL branch applies —
        # presumably fine because _preprocess_subject handles both modes,
        # but confirm against pipeline.pshuman_local.
        img = (Image.fromarray(input_image) if isinstance(input_image, np.ndarray)
               else input_image.convert("RGBA") if input_image.mode != "RGBA"
               else input_image)

        progress(0.1, desc="Loading PSHuman pipeline…")
        # Imported lazily so the heavy PSHuman setup (repo clone, weight
        # download) only happens inside the GPU-allocated call.
        from pipeline.pshuman_local import run_pshuman_diffusion

        progress(0.2, desc="Running multi-view diffusion (40 steps × 7 views)…")
        colour_views, normal_views = run_pshuman_diffusion(img, device="cuda")

        progress(1.0, desc="Done!")
        return colour_views, normal_views, "Multi-view generation complete."

    except Exception:
        # Surface the full traceback in the status textbox instead of crashing.
        return None, None, f"Error:\n{traceback.format_exc()}"
|
| 1562 |
|
| 1563 |
|
| 1564 |
# ── Full pipeline ─────────────────────────────────────────────────────────────
|
|
|
|
| 1838 |
# ════════════════════════════════════════════════════════════════════
|
| 1839 |
with gr.Tab("PSHuman Face"):
|
| 1840 |
gr.Markdown(
|
| 1841 |
+
"### PSHuman Multi-View (local)\n"
|
| 1842 |
+
"Generates 6 colour + 6 normal-map views of a person using "
|
| 1843 |
+
"[PSHuman](https://github.com/pengHTYX/PSHuman) "
|
| 1844 |
+
"(StableUnCLIP fine-tuned on multi-view human images).\n\n"
|
| 1845 |
+
"**Pipeline:** portrait → multi-view diffusion (in-process) → "
|
| 1846 |
+
"6 × colour + 6 × normal views\n\n"
|
| 1847 |
+
"**Views:** front · front-right · right · back · left · front-left"
|
| 1848 |
)
|
| 1849 |
with gr.Row():
|
| 1850 |
with gr.Column(scale=1):
|
| 1851 |
pshuman_img_input = gr.Image(
|
| 1852 |
+
label="Portrait image",
|
| 1853 |
type="pil",
|
| 1854 |
)
|
| 1855 |
+
pshuman_btn = gr.Button("Generate Views", variant="primary")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1856 |
|
| 1857 |
with gr.Column(scale=2):
|
| 1858 |
+
pshuman_status = gr.Textbox(
|
| 1859 |
+
label="Status", lines=2, interactive=False)
|
| 1860 |
+
pshuman_colour_gallery = gr.Gallery(
|
| 1861 |
+
label="Colour views (front → front-right → right → back → left → front-left)",
|
| 1862 |
+
columns=3, rows=2, height=420,
|
| 1863 |
+
)
|
| 1864 |
+
pshuman_normal_gallery = gr.Gallery(
|
| 1865 |
+
label="Normal maps",
|
| 1866 |
+
columns=3, rows=2, height=420,
|
| 1867 |
+
)
|
| 1868 |
|
| 1869 |
pshuman_btn.click(
|
| 1870 |
fn=gradio_pshuman_face,
|
| 1871 |
+
inputs=[pshuman_img_input],
|
| 1872 |
+
outputs=[pshuman_colour_gallery, pshuman_normal_gallery, pshuman_status],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1873 |
)
|
| 1874 |
|
| 1875 |
# ════════════════════════════════════════════════════════════════════
|
pipeline/pshuman_local.py
ADDED
|
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
pshuman_local.py
|
| 3 |
+
================
|
| 4 |
+
Run PSHuman multi-view diffusion in-process on ZeroGPU.
|
| 5 |
+
Generates 6 colour views + 6 normal views from a single portrait image.
|
| 6 |
+
|
| 7 |
+
Full mesh reconstruction (SMPL fitting, pytorch3d/kaolin texture projection)
|
| 8 |
+
is intentionally skipped — those deps have no Python 3.13 wheels.
|
| 9 |
+
|
| 10 |
+
Model : pengHTYX/PSHuman_Unclip_768_6views
|
| 11 |
+
Repo : https://github.com/pengHTYX/PSHuman (cloned to /tmp/pshuman-src)
|
| 12 |
+
|
| 13 |
+
The preprocessing replaces PSHuman's mediapipe-based face-crop helper with a
|
| 14 |
+
self-contained approach (insightface if available, else top-centre crop) so no
|
| 15 |
+
additional exotic deps are needed.
|
| 16 |
+
"""
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import os
|
| 20 |
+
import sys
|
| 21 |
+
import subprocess
|
| 22 |
+
import tempfile
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
from typing import List, Optional, Tuple
|
| 25 |
+
|
| 26 |
+
import numpy as np
|
| 27 |
+
import torch
|
| 28 |
+
import torch.nn.functional as F
|
| 29 |
+
from PIL import Image
|
| 30 |
+
|
| 31 |
+
# ── Constants ─────────────────────────────────────────────────────────────────
|
| 32 |
+
_PSHUMAN_SRC = Path("/tmp/pshuman-src")
|
| 33 |
+
_PSHUMAN_CKPT = Path("/tmp/pshuman-ckpts")
|
| 34 |
+
_PSHUMAN_REPO = "https://github.com/pengHTYX/PSHuman.git"
|
| 35 |
+
_PSHUMAN_HF_ID = "pengHTYX/PSHuman_Unclip_768_6views"
|
| 36 |
+
|
| 37 |
+
_CANVAS_SIZE = 768 # model input resolution
|
| 38 |
+
_CROP_SIZE = 740 # subject bounding-box target size within canvas
|
| 39 |
+
_NUM_VIEWS = 7 # 6 body views + 1 face-crop view (model expectation)
|
| 40 |
+
|
| 41 |
+
# Cached pipeline (survives across ZeroGPU calls when the Space is warm)
|
| 42 |
+
_pipeline = None
|
| 43 |
+
|
| 44 |
+
# ── Setup helpers ─────────────────────────────────────────────────────────────
|
| 45 |
+
|
| 46 |
+
def _ensure_repo() -> None:
    """Clone the PSHuman source repo into /tmp (idempotent).

    Clones into a temporary sibling directory and renames it into place only
    on success: previously an interrupted clone left a partial
    ``/tmp/pshuman-src`` that the ``exists()`` check treated as complete on
    every subsequent call.
    """
    import shutil  # local: only needed on the cold-start path

    if _PSHUMAN_SRC.exists():
        return
    print("[pshuman] Cloning PSHuman repo…")
    staging = _PSHUMAN_SRC.with_suffix(".partial")
    if staging.exists():
        # Leftover from a previously interrupted clone — start fresh.
        shutil.rmtree(staging)
    subprocess.run(
        ["git", "clone", "--depth=1", _PSHUMAN_REPO, str(staging)],
        check=True,
    )
    # Atomic rename: the canonical path only ever holds a complete checkout.
    staging.rename(_PSHUMAN_SRC)
    print("[pshuman] Repo cloned.")
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _ensure_sys_path() -> None:
    """Make the cloned PSHuman repo importable (cloning it first if needed)."""
    _ensure_repo()
    repo_root = str(_PSHUMAN_SRC)
    if repo_root in sys.path:
        return
    sys.path.insert(0, repo_root)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _ensure_ckpt() -> None:
    """Download the PSHuman diffusion weights (idempotent).

    A ``.complete`` marker file distinguishes a finished snapshot from a
    directory left behind by an interrupted download: previously a partial
    download passed the ``exists()`` check and was silently reused forever.
    If the directory exists without the marker, ``snapshot_download`` is
    called again — it resumes/verifies rather than re-fetching everything.
    """
    marker = _PSHUMAN_CKPT / ".complete"
    if marker.exists():
        return
    from huggingface_hub import snapshot_download
    print("[pshuman] Downloading model weights…")
    snapshot_download(repo_id=_PSHUMAN_HF_ID, local_dir=str(_PSHUMAN_CKPT))
    marker.touch()
    print("[pshuman] Weights downloaded.")
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def load_pipeline(device: str = "cuda"):
    """Load (and cache) the PSHuman StableUnCLIPImg2ImgPipeline.

    The pipeline is kept in the module-level ``_pipeline`` global so warm
    ZeroGPU calls skip the slow ``from_pretrained`` load; on reuse the cached
    object is simply moved to *device*.

    Parameters
    ----------
    device : torch device string ('cuda' or 'cpu')

    Returns
    -------
    The ready-to-call StableUnCLIPImg2ImgPipeline on *device*.
    """
    global _pipeline
    if _pipeline is not None:
        # Cache hit: just make sure it lives on the requested device.
        _pipeline.to(device)
        return _pipeline

    _ensure_sys_path()
    _ensure_ckpt()

    # Imported lazily: the mvdiffusion package only exists after the repo
    # clone performed by _ensure_sys_path().
    from mvdiffusion.pipelines.pipeline_mvdiffusion_unclip import (
        StableUnCLIPImg2ImgPipeline,
    )

    print("[pshuman] Loading pipeline…")
    pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
        str(_PSHUMAN_CKPT),
        torch_dtype=torch.float16,  # fp16 halves VRAM; matches released weights
    )
    try:
        # Optional speed-up — xformers may be absent on this image.
        pipe.unet.enable_xformers_memory_efficient_attention()
    except Exception:
        pass
    pipe.to(device)
    _pipeline = pipe
    print("[pshuman] Pipeline ready.")
    return pipe
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# ── Image preprocessing ───────────────────────────────────────────────────────
|
| 103 |
+
|
| 104 |
+
def _preprocess_subject(image: Image.Image) -> Image.Image:
    """
    Replicate PSHuman's load_image preprocessing:
      1. Ensure RGBA (background = white if no alpha).
      2. Crop to the bounding box of the alpha channel.
      3. Scale the longest side to _CROP_SIZE.
      4. Paste centred on a white _CANVAS_SIZE × _CANVAS_SIZE canvas.

    Returns an RGB image (768 × 768).
    """
    if image.mode != "RGBA":
        # No alpha → treat entire image as foreground
        rgba = image.convert("RGBA")
        alpha = np.ones((rgba.size[1], rgba.size[0]), dtype=np.uint8) * 255
        rgba.putalpha(Image.fromarray(alpha))
        image = rgba
    else:
        image = image.copy()

    arr = np.array(image)  # H W 4
    a = arr[:, :, 3]

    # Bounding box of non-transparent pixels
    rows = np.any(a > 0, axis=1)
    cols = np.any(a > 0, axis=0)
    if rows.any():
        r0, r1 = np.where(rows)[0][[0, -1]]
        c0, c1 = np.where(cols)[0][[0, -1]]
        crop = image.crop((c0, r0, c1 + 1, r1 + 1))
    else:
        # Fully transparent input: keep the whole image rather than crash
        crop = image

    # Scale longest side → _CROP_SIZE. Clamp to ≥1 px: an extremely thin
    # alpha crop could otherwise round to a zero dimension, which
    # Image.resize rejects with a ValueError.
    cw, ch = crop.size
    scale = _CROP_SIZE / max(cw, ch)
    new_w = max(1, round(cw * scale))
    new_h = max(1, round(ch * scale))
    crop = crop.resize((new_w, new_h), Image.LANCZOS)

    # Paste on white canvas, using the alpha channel as mask when present
    canvas = Image.new("RGB", (_CANVAS_SIZE, _CANVAS_SIZE), (255, 255, 255))
    x_off = (_CANVAS_SIZE - new_w) // 2
    y_off = (_CANVAS_SIZE - new_h) // 2
    if crop.mode == "RGBA":
        canvas.paste(crop.convert("RGB"), (x_off, y_off), crop.split()[3])
    else:
        canvas.paste(crop.convert("RGB"), (x_off, y_off))

    return canvas
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def _get_face_crop(image: Image.Image) -> Image.Image:
    """
    Return a 256×256 face crop from *image* (768×768 preprocessed subject).
    Tries insightface first; falls back to a top-centre heuristic.

    The insightface detector is built once and cached on the function object:
    the original constructed and prepare()d a fresh FaceAnalysis on every
    call, reloading the detection model from disk each time.
    """
    size = 256
    # ── insightface ────────────────────────────────────────────────────────
    try:
        fa = getattr(_get_face_crop, "_analyzer", None)
        if fa is None:
            from insightface.app import FaceAnalysis
            fa = FaceAnalysis(allowed_modules=["detection"])
            fa.prepare(ctx_id=0 if torch.cuda.is_available() else -1, det_size=(320, 320))
            _get_face_crop._analyzer = fa  # cache for subsequent calls
        faces = fa.get(np.array(image))
        if faces:
            b = faces[0].bbox.astype(int)
            # Clamp the detected bbox to the image bounds before cropping
            x0, y0 = max(0, b[0]), max(0, b[1])
            x1, y1 = min(image.width, b[2]), min(image.height, b[3])
            return image.crop((x0, y0, x1, y1)).resize((size, size), Image.LANCZOS)
    except Exception:
        pass  # insightface missing or detection failed → heuristic below

    # ── heuristic: top-centre 40 % of image ────────────────────────────────
    w, h = image.size
    face_h = int(h * 0.40)
    margin = int(w * 0.15)
    return image.crop((margin, 0, w - margin, face_h)).resize((size, size), Image.LANCZOS)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def _to_tensor(img: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a float32 tensor of shape (3, H, W) in [0, 1]."""
    rgb = np.asarray(img.convert("RGB"), dtype=np.float32) / 255.0
    chw = np.ascontiguousarray(np.transpose(rgb, (2, 0, 1)))
    return torch.from_numpy(chw)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def _to_pil(t: torch.Tensor) -> Image.Image:
    """Convert a float tensor (3, H, W) in [0, 1] back to a PIL RGB image."""
    hwc = t.float().clamp(0.0, 1.0).permute(1, 2, 0).cpu().numpy()
    as_bytes = (hwc * 255).astype(np.uint8)
    return Image.fromarray(as_bytes)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
# ── Main inference ─────────────────────────────────────────────────────────────
|
| 197 |
+
|
| 198 |
+
def run_pshuman_diffusion(
    image: Image.Image,
    device: str = "cuda",
    seed: int = 42,
    guidance_scale: float = 3.0,
    num_inference_steps: int = 40,
    remove_bg_output: bool = True,
) -> Tuple[List[Image.Image], List[Image.Image]]:
    """
    Run PSHuman multi-view diffusion on a single person image.

    Parameters
    ----------
    image : Input PIL image (RGBA with bg removed, or RGB)
    device : 'cuda' or 'cpu'
    seed : RNG seed for reproducibility
    guidance_scale : CFG scale (default 3.0, per PSHuman paper)
    num_inference_steps: Diffusion steps (default 40)
    remove_bg_output : Strip background from output colour views via rembg

    Returns
    -------
    colour_views : List[PIL.Image] — 6 colour images
                   [front, front-right, right, back, left, front-left]
    normal_views : List[PIL.Image] — 6 matching normal maps
    """
    from einops import rearrange

    _ensure_sys_path()
    _ensure_ckpt()

    # ── 1. Preprocessing ──────────────────────────────────────────────────────
    # PSHuman was trained on right-facing subjects → flip input left-right
    flipped = image.transpose(Image.FLIP_LEFT_RIGHT)
    subject = _preprocess_subject(flipped)   # 768×768 RGB
    face = _get_face_crop(subject)           # 256×256 RGB

    body_t = _to_tensor(subject)             # (3, 768, 768)
    # The face crop is upscaled to full canvas size so all 7 slots share shape
    face_t = _to_tensor(face.resize((_CANVAS_SIZE, _CANVAS_SIZE), Image.LANCZOS))  # (3, 768, 768)

    # Stack: first (Nv-1) slots = body image, last slot = face crop
    imgs_in = torch.stack(
        [body_t] * (_NUM_VIEWS - 1) + [face_t], dim=0
    ).float().unsqueeze(0)                   # (1, 7, 3, 768, 768)

    # ── 2. Prompt embeddings ─────────────────────────────────────────────────
    # Fixed per-view text embeddings shipped with the PSHuman repo — one set
    # conditions the normal branch, one the colour branch.
    embeds_dir = _PSHUMAN_SRC / "mvdiffusion" / "data" / "fixed_prompt_embeds_7view"
    normal_embeds = torch.load(embeds_dir / "normal_embeds.pt", map_location="cpu")
    color_embeds = torch.load(embeds_dir / "clr_embeds.pt", map_location="cpu")

    # Shapes from repo: (1, Nv, N, C) or (Nv, N, C) — normalise to (1, Nv, N, C)
    if normal_embeds.dim() == 3:
        normal_embeds = normal_embeds.unsqueeze(0)
    if color_embeds.dim() == 3:
        color_embeds = color_embeds.unsqueeze(0)

    # ── 3. Batch construction ─────────────────────────────────────────────────
    # The image batch is duplicated so half runs with normal embeddings and
    # half with colour embeddings in a single forward pass.
    # imgs_in: (2, 7, 3, 768, 768) → flat (14, 3, 768, 768)
    imgs_2x = torch.cat([imgs_in, imgs_in], dim=0).to(device, dtype=torch.float16)
    imgs_flat = rearrange(imgs_2x, "B Nv C H W -> (B Nv) C H W")

    # prompt_embeds: (2, 7, N, C) → flat (14, N, C); order matches imgs_flat
    # (normals first, colours second)
    p_emb = torch.cat([normal_embeds, color_embeds], dim=0).to(device, dtype=torch.float16)
    p_emb = rearrange(p_emb, "B Nv N C -> (B Nv) N C")

    # ── 4. Diffusion ──────────────────────────────────────────────────────────
    pipe = load_pipeline(device)
    gen = torch.Generator(device=device).manual_seed(seed)

    with torch.autocast("cuda"):
        out_images = pipe(
            imgs_flat,
            None,  # image_embeds slot (positional in the PSHuman pipeline)
            prompt_embeds=p_emb,
            generator=gen,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            eta=1.0,
            output_type="pt",
            num_images_per_prompt=1,
        ).images  # (14, 3, H, W) float [0,1]

    # ── 5. Split normals / colours ────────────────────────────────────────────
    # First half of the batch ran under normal embeddings, second under colour.
    bsz = out_images.shape[0] // 2           # = 7
    normals_pt = out_images[:bsz].clone()    # (7, 3, H, W)
    colors_pt = out_images[bsz:].clone()     # (7, 3, H, W)

    # View 0 colour = input (PSHuman convention — not generated)
    colors_pt[0] = imgs_flat[0].to(out_images.device)

    # Views 3 & 4 are generated horizontally mirrored → flip back
    for j in (3, 4):
        normals_pt[j] = torch.flip(normals_pt[j], dims=[2])
        colors_pt[j] = torch.flip(colors_pt[j], dims=[2])

    # Paste the face-normal crop (view 6, downsampled to 256px) into the
    # top-centre of normals[0] (rows 0–255, cols 256–511 of 768).
    face_nrm = F.interpolate(
        normals_pt[6].unsqueeze(0), size=(256, 256),
        mode="bilinear", align_corners=False,
    ).squeeze(0)
    normals_pt[0][:, :256, 256:512] = face_nrm

    # ── 6. Convert to PIL (body views 0–5 only, skip face-crop view 6) ───────
    colour_views = [_to_pil(colors_pt[j]) for j in range(6)]
    normal_views = [_to_pil(normals_pt[j]) for j in range(6)]

    # All outputs are in PSHuman's flipped (right-facing) space.
    # Flip back so they match the user's original image orientation.
    colour_views = [v.transpose(Image.FLIP_LEFT_RIGHT) for v in colour_views]
    normal_views = [v.transpose(Image.FLIP_LEFT_RIGHT) for v in normal_views]

    if remove_bg_output:
        try:
            from rembg import remove as _rembg
            colour_views = [_rembg(v) for v in colour_views]
        except Exception:
            pass  # rembg optional — return raw outputs if unavailable

    return colour_views, normal_views
|
requirements.txt
CHANGED
|
@@ -68,6 +68,8 @@ scikit-learn
|
|
| 68 |
pandas
|
| 69 |
|
| 70 |
# Utils
|
|
|
|
|
|
|
| 71 |
easydict
|
| 72 |
omegaconf
|
| 73 |
yacs
|
|
|
|
| 68 |
pandas
|
| 69 |
|
| 70 |
# Utils
|
| 71 |
+
rembg
|
| 72 |
+
icecream
|
| 73 |
easydict
|
| 74 |
omegaconf
|
| 75 |
yacs
|