Commit ·
873b6ec
0
Parent(s):
init
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +1 -0
- .gitignore +46 -0
- .pre-commit-config.yaml +53 -0
- README.md +191 -0
- example/assets/image.png +3 -0
- example/assets/prompt.txt +7 -0
- example/base/config.json +16 -0
- example/base/run.sh +29 -0
- example/distill/config.json +17 -0
- example/distill/run.sh +29 -0
- example/sr_1080p/config.json +20 -0
- example/sr_1080p/run.sh +33 -0
- example/sr_540p/config.json +20 -0
- example/sr_540p/run.sh +32 -0
- inference/common/__init__.py +39 -0
- inference/common/arch.py +35 -0
- inference/common/config.py +283 -0
- inference/common/cpu_offload_wrapper.py +186 -0
- inference/common/sequence_schema.py +33 -0
- inference/infra/__init__.py +37 -0
- inference/infra/checkpoint/__init__.py +20 -0
- inference/infra/checkpoint/load_model_checkpoint.py +99 -0
- inference/infra/distributed/__init__.py +28 -0
- inference/infra/distributed/init_dist_env.py +62 -0
- inference/infra/distributed/parallel_state.py +659 -0
- inference/infra/distributed/utils.py +47 -0
- inference/infra/parallelism/__init__.py +20 -0
- inference/infra/parallelism/all_to_all_primitive.py +142 -0
- inference/infra/parallelism/gather_scatter_primitive.py +217 -0
- inference/infra/parallelism/ulysses_scheduler.py +143 -0
- inference/model/dit/__init__.py +18 -0
- inference/model/dit/dit_model.py +42 -0
- inference/model/dit/dit_module.py +950 -0
- inference/model/sa_audio/__init__.py +25 -0
- inference/model/sa_audio/sa_audio_model.py +116 -0
- inference/model/sa_audio/sa_audio_module.py +478 -0
- inference/model/t5_gemma/__init__.py +3 -0
- inference/model/t5_gemma/t5_gemma_model.py +43 -0
- inference/model/turbo_vaed/__init__.py +4 -0
- inference/model/turbo_vaed/turbo_vaed_model.py +33 -0
- inference/model/turbo_vaed/turbo_vaed_module.py +1039 -0
- inference/model/vae2_2/__init__.py +3 -0
- inference/model/vae2_2/vae2_2_model.py +17 -0
- inference/model/vae2_2/vae2_2_module.py +1086 -0
- inference/pipeline/__init__.py +20 -0
- inference/pipeline/data_proxy.py +390 -0
- inference/pipeline/entry.py +96 -0
- inference/pipeline/pipeline.py +108 -0
- inference/pipeline/prompt_process.py +60 -0
- inference/pipeline/scheduler_unipc.py +832 -0
.gitattributes
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
tmp*
|
| 2 |
+
depyf
|
| 3 |
+
torch_compile_cache
|
| 4 |
+
|
| 5 |
+
__pycache__
|
| 6 |
+
*.so
|
| 7 |
+
build
|
| 8 |
+
.coverage_*
|
| 9 |
+
*.egg-info
|
| 10 |
+
*~
|
| 11 |
+
slurm*
|
| 12 |
+
logs
|
| 13 |
+
.vscode
|
| 14 |
+
nsys*
|
| 15 |
+
tmp/*
|
| 16 |
+
.mypy_cache
|
| 17 |
+
output
|
| 18 |
+
*.pyc
|
| 19 |
+
*.log
|
| 20 |
+
.idea
|
| 21 |
+
*.pt
|
| 22 |
+
*.png
|
| 23 |
+
*.jpg
|
| 24 |
+
*.jpeg
|
| 25 |
+
*.gif
|
| 26 |
+
*.mp3
|
| 27 |
+
*.mp4
|
| 28 |
+
*.pickle
|
| 29 |
+
*.nsys-rep
|
| 30 |
+
*.html
|
| 31 |
+
*.mov
|
| 32 |
+
*.safetensors
|
| 33 |
+
*.json
|
| 34 |
+
|
| 35 |
+
# Keep example media assets tracked.
|
| 36 |
+
!example/assets/*.png
|
| 37 |
+
!example/assets/*.mp4
|
| 38 |
+
!example/**/*.json
|
| 39 |
+
|
| 40 |
+
proj*
|
| 41 |
+
.venv
|
| 42 |
+
var
|
| 43 |
+
tags
|
| 44 |
+
fx_graph*.pdf
|
| 45 |
+
/clean_repo.py
|
| 46 |
+
/rm_caches.sh
|
.pre-commit-config.yaml
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
exclude: \.patch$
|
| 2 |
+
repos:
|
| 3 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
| 4 |
+
rev: v4.4.0
|
| 5 |
+
hooks:
|
| 6 |
+
- id: check-added-large-files
|
| 7 |
+
args:
|
| 8 |
+
- --maxkb=30720
|
| 9 |
+
- id: check-merge-conflict
|
| 10 |
+
- id: check-symlinks
|
| 11 |
+
- id: detect-private-key
|
| 12 |
+
files: (?!.*third_party)^.*$ | (?!.*book)^.*$
|
| 13 |
+
- id: end-of-file-fixer
|
| 14 |
+
- id: trailing-whitespace
|
| 15 |
+
- id: requirements-txt-fixer
|
| 16 |
+
- id: sort-simple-yaml
|
| 17 |
+
- repo: https://github.com/Lucas-C/pre-commit-hooks.git
|
| 18 |
+
rev: v1.5.1
|
| 19 |
+
hooks:
|
| 20 |
+
- id: remove-crlf
|
| 21 |
+
files: (?!.*third_party)^.*$ | (?!.*book)^.*$
|
| 22 |
+
- id: remove-tabs
|
| 23 |
+
name: Tabs remover (C++)
|
| 24 |
+
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|xpu|kps)$
|
| 25 |
+
args: [--whitespaces-count, '2']
|
| 26 |
+
- id: remove-tabs
|
| 27 |
+
name: Tabs remover (Python)
|
| 28 |
+
files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
|
| 29 |
+
args: [--whitespaces-count, '4']
|
| 30 |
+
- repo: https://github.com/psf/black.git
|
| 31 |
+
rev: 23.3.0
|
| 32 |
+
hooks:
|
| 33 |
+
- id: black
|
| 34 |
+
args: [--line-length=127, --skip-string-normalization, --skip-magic-trailing-comma]
|
| 35 |
+
files: (.*\.(py|pyi|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
|
| 36 |
+
- repo: https://github.com/pre-commit/mirrors-isort
|
| 37 |
+
rev: v5.10.1
|
| 38 |
+
hooks:
|
| 39 |
+
- id: isort
|
| 40 |
+
args: [--profile=black, --line-length=127, --multi-line=3, --force-grid-wrap=0, --src-path=infra, --src-path=pipeline, --src-path=model]
|
| 41 |
+
files: \.py$
|
| 42 |
+
- repo: https://github.com/PyCQA/autoflake
|
| 43 |
+
rev: v2.3.1
|
| 44 |
+
hooks:
|
| 45 |
+
- id: autoflake
|
| 46 |
+
args: [--remove-all-unused-imports, --remove-unused-variables, --in-place, --ignore-init-module-imports, --ignore-pass-after-docstring]
|
| 47 |
+
files: \.py$
|
| 48 |
+
- repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks.git
|
| 49 |
+
rev: v2.9.0
|
| 50 |
+
hooks:
|
| 51 |
+
- id: pretty-format-yaml
|
| 52 |
+
args: [--autofix, --indent, '4']
|
| 53 |
+
additional_dependencies: [setuptools]
|
README.md
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<div align="center">
|
| 2 |
+
|
| 3 |
+
# daVinci-MagiHuman
|
| 4 |
+
|
| 5 |
+
### Speed by Simplicity: A Single-Stream Architecture for Fast Audio-Video Generative Foundation Model
|
| 6 |
+
|
| 7 |
+
<p align="center">
|
| 8 |
+
<a href="https://www.sjtu.edu.cn/">SII-GAIR</a> & <a href="https://sand.ai">Sand.ai</a>
|
| 9 |
+
</p>
|
| 10 |
+
|
| 11 |
+
[](https://github.com/GAIR-NLP/daVinci-MagiHuman/blob/main/assets/daVinci_MagiHuman.pdf)
|
| 12 |
+
[](https://huggingface.co/spaces/SII-GAIR/daVinci-MagiHuman)
|
| 13 |
+
[](https://huggingface.co/GAIR-NLP/daVinci-MagiHuman)
|
| 14 |
+
[](https://opensource.org/licenses/Apache-2.0)
|
| 15 |
+
[](https://www.python.org/)
|
| 16 |
+
[](https://pytorch.org/)
|
| 17 |
+
|
| 18 |
+
</div>
|
| 19 |
+
|
| 20 |
+
## Highlights
|
| 21 |
+
|
| 22 |
+
- **Single-Stream Transformer** — A unified 15B-parameter, 40-layer Transformer that jointly processes text, video, and audio via self-attention only. No cross-attention, no multi-stream complexity.
|
| 23 |
+
- **Exceptional Human-Centric Quality** — Expressive facial performance, natural speech-expression coordination, realistic body motion, and accurate audio-video synchronization.
|
| 24 |
+
- **Multilingual** — Supports Chinese (Mandarin & Cantonese), English, Japanese, Korean, German, and French.
|
| 25 |
+
- **Blazing Fast Inference** — Generates a 5-second 256p video in **2 seconds** and a 5-second 1080p video in **38 seconds** on a single H100 GPU.
|
| 26 |
+
- **State-of-the-Art Results** — Achieves **80.0%** win rate vs Ovi 1.1 and **60.9%** vs LTX 2.3 in pairwise human evaluation over 2,000 comparisons.
|
| 27 |
+
- **Fully Open Source** — We release the complete model stack: base model, distilled model, super-resolution model, and inference code.
|
| 28 |
+
|
| 29 |
+
## Demo
|
| 30 |
+
|
| 31 |
+
<!--
|
| 32 |
+
To add demo videos:
|
| 33 |
+
1. Open a GitHub issue on this repo
|
| 34 |
+
2. Drag & drop your .mp4 files into the issue comment box
|
| 35 |
+
3. Copy the generated URLs and paste them below
|
| 36 |
+
|
| 37 |
+
Example:
|
| 38 |
+
https://github.com/user-attachments/assets/xxxx-xxxx
|
| 39 |
+
-->
|
| 40 |
+
|
| 41 |
+
https://github.com/user-attachments/assets/PLACEHOLDER_VIDEO_1
|
| 42 |
+
|
| 43 |
+
https://github.com/user-attachments/assets/PLACEHOLDER_VIDEO_2
|
| 44 |
+
|
| 45 |
+
https://github.com/user-attachments/assets/PLACEHOLDER_VIDEO_3
|
| 46 |
+
|
| 47 |
+
## Architecture
|
| 48 |
+
|
| 49 |
+
<div align="center">
|
| 50 |
+
<img src="assets/architecture.png" width="90%">
|
| 51 |
+
</div>
|
| 52 |
+
|
| 53 |
+
daVinci-MagiHuman uses a single-stream Transformer that takes text tokens, a reference image latent, and noisy video and audio tokens as input, and jointly denoises the video and audio within a unified token sequence.
|
| 54 |
+
|
| 55 |
+
Key design choices:
|
| 56 |
+
|
| 57 |
+
| Component | Description |
|
| 58 |
+
|---|---|
|
| 59 |
+
| **Sandwich Architecture** | First and last 4 layers use modality-specific projections; middle 32 layers share parameters across modalities |
|
| 60 |
+
| **Timestep-Free Denoising** | No explicit timestep embeddings — the model infers the denoising state directly from input latents |
|
| 61 |
+
| **Per-Head Gating** | Learned scalar gates with sigmoid activation on each attention head for training stability |
|
| 62 |
+
| **Unified Conditioning** | Denoising and reference signals handled through a minimal unified interface — no dedicated conditioning branches |
|
| 63 |
+
|
| 64 |
+
## Performance
|
| 65 |
+
|
| 66 |
+
### Quantitative Quality Benchmark
|
| 67 |
+
|
| 68 |
+
| Model | Visual Quality ↑ | Text Alignment ↑ | Physical Consistency ↑ | WER ↓ |
|
| 69 |
+
|---|:---:|:---:|:---:|:---:|
|
| 70 |
+
| OVI 1.1 | 4.73 | 4.10 | 4.41 | 40.45% |
|
| 71 |
+
| LTX 2.3 | 4.76 | 4.12 | **4.56** | 19.23% |
|
| 72 |
+
| **daVinci-MagiHuman** | **4.80** | **4.18** | 4.52 | **14.60%** |
|
| 73 |
+
|
| 74 |
+
### Human Evaluation (2,000 Pairwise Comparisons)
|
| 75 |
+
|
| 76 |
+
| Matchup | daVinci-MagiHuman Win | Tie | Opponent Win |
|
| 77 |
+
|---|:---:|:---:|:---:|
|
| 78 |
+
| vs Ovi 1.1 | **80.0%** | 8.2% | 11.8% |
|
| 79 |
+
| vs LTX 2.3 | **60.9%** | 17.2% | 21.9% |
|
| 80 |
+
|
| 81 |
+
### Inference Speed (Single H100 GPU, 5-second video)
|
| 82 |
+
|
| 83 |
+
| Resolution | Base (s) | Super-Res (s) | Decode (s) | **Total (s)** |
|
| 84 |
+
|---|:---:|:---:|:---:|:---:|
|
| 85 |
+
| 256p | 1.6 | — | 0.4 | **2.0** |
|
| 86 |
+
| 540p | 1.6 | 5.1 | 1.3 | **8.0** |
|
| 87 |
+
| 1080p | 1.6 | 31.0 | 5.8 | **38.4** |
|
| 88 |
+
|
| 89 |
+
## Efficient Inference Techniques
|
| 90 |
+
|
| 91 |
+
- **Latent-Space Super-Resolution** — Two-stage pipeline: generate at low resolution, then refine in latent space (not pixel space), avoiding an extra VAE decode-encode round trip.
|
| 92 |
+
- **Turbo VAE Decoder** — A lightweight re-trained decoder that substantially reduces decoding overhead.
|
| 93 |
+
- **Full-Graph Compilation** — [MagiCompiler](https://github.com/sandai/MagiCompiler) fuses operators across Transformer layers for ~1.2x speedup.
|
| 94 |
+
- **Distillation** — DMD-2 distillation enables generation with only 8 denoising steps (no CFG), without sacrificing quality.
|
| 95 |
+
|
| 96 |
+
## Getting Started
|
| 97 |
+
|
| 98 |
+
### Option 1: Docker (Recommended)
|
| 99 |
+
|
| 100 |
+
```bash
|
| 101 |
+
# Pull the MagiCompiler Docker image
|
| 102 |
+
docker pull sandai/magi-compiler:latest
|
| 103 |
+
|
| 104 |
+
# Launch container
|
| 105 |
+
docker run -it --gpus all \
|
| 106 |
+
-v /path/to/models:/models \
|
| 107 |
+
sandai/magi-compiler:latest bash
|
| 108 |
+
|
| 109 |
+
# Install MagiCompiler
|
| 110 |
+
git clone https://github.com/sandai/MagiCompiler
|
| 111 |
+
cd MagiCompiler
|
| 112 |
+
pip install -e . --no-build-isolation --config-settings editable_mode=compat
|
| 113 |
+
cd ..
|
| 114 |
+
|
| 115 |
+
# Clone daVinci-MagiHuman
|
| 116 |
+
git clone https://github.com/GAIR-NLP/daVinci-MagiHuman
|
| 117 |
+
cd daVinci-MagiHuman
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
### Option 2: Conda
|
| 121 |
+
|
| 122 |
+
```bash
|
| 123 |
+
# Create environment
|
| 124 |
+
conda create -n davinci python=3.12
|
| 125 |
+
conda activate davinci
|
| 126 |
+
|
| 127 |
+
# Install PyTorch
|
| 128 |
+
pip install torch==2.9.0 torchvision==0.24.0 torchaudio==2.9.0
|
| 129 |
+
|
| 130 |
+
# Install Flash Attention (Hopper)
|
| 131 |
+
git clone https://github.com/Dao-AILab/flash-attention
|
| 132 |
+
cd flash-attention/hopper && python setup.py install && cd ../..
|
| 133 |
+
|
| 134 |
+
# Install MagiCompiler
|
| 135 |
+
git clone https://github.com/sandai/MagiCompiler
|
| 136 |
+
cd MagiCompiler
|
| 137 |
+
pip install -e . --no-build-isolation --config-settings editable_mode=compat
|
| 138 |
+
cd ..
|
| 139 |
+
|
| 140 |
+
# Clone and install daVinci-MagiHuman
|
| 141 |
+
git clone https://github.com/GAIR-NLP/daVinci-MagiHuman
|
| 142 |
+
cd daVinci-MagiHuman
|
| 143 |
+
pip install -r requirements.txt
|
| 144 |
+
```
|
| 145 |
+
|
| 146 |
+
### Download Model Checkpoints
|
| 147 |
+
|
| 148 |
+
Download the complete model stack from [HuggingFace](https://huggingface.co/GAIR-NLP/daVinci-MagiHuman) and update the paths in the config files under `example/`.
|
| 149 |
+
|
| 150 |
+
## Usage
|
| 151 |
+
|
| 152 |
+
Before running, update the checkpoint paths in the config files (`example/*/config.json`) to point to your local model directory.
|
| 153 |
+
|
| 154 |
+
**Base Model (256p)**
|
| 155 |
+
```bash
|
| 156 |
+
bash example/base/run.sh
|
| 157 |
+
```
|
| 158 |
+
|
| 159 |
+
**Distilled Model (256p, 8 steps, no CFG)**
|
| 160 |
+
```bash
|
| 161 |
+
bash example/distill/run.sh
|
| 162 |
+
```
|
| 163 |
+
|
| 164 |
+
**Super-Resolution to 540p**
|
| 165 |
+
```bash
|
| 166 |
+
bash example/sr_540p/run.sh
|
| 167 |
+
```
|
| 168 |
+
|
| 169 |
+
**Super-Resolution to 1080p**
|
| 170 |
+
```bash
|
| 171 |
+
bash example/sr_1080p/run.sh
|
| 172 |
+
```
|
| 173 |
+
|
| 174 |
+
## Citation
|
| 175 |
+
|
| 176 |
+
```bibtex
|
| 177 |
+
@misc{davinci-magihuman-2025,
|
| 178 |
+
title = {Speed by Simplicity: A Single-Stream Architecture for Fast Audio-Video Generative Foundation Model},
|
| 179 |
+
author = {SII-GAIR and Sand.ai},
|
| 180 |
+
year = {2025},
|
| 181 |
+
url = {https://github.com/GAIR-NLP/daVinci-MagiHuman}
|
| 182 |
+
}
|
| 183 |
+
```
|
| 184 |
+
|
| 185 |
+
## Acknowledgements
|
| 186 |
+
|
| 187 |
+
daVinci-MagiHuman builds upon several outstanding open-source projects, including [Wan2.2](https://github.com/Wan-Video/Wan2.2), [Flash Attention](https://github.com/Dao-AILab/flash-attention), and [Turbo-VAED](https://github.com/zou-group/turbo-vaed). We thank the broader open-source community for making this work possible.
|
| 188 |
+
|
| 189 |
+
## License
|
| 190 |
+
|
| 191 |
+
This project is released under the [Apache License 2.0](https://opensource.org/licenses/Apache-2.0).
|
example/assets/image.png
ADDED
|
Git LFS Details
|
example/assets/prompt.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"A man with dark hair and glasses, wearing a green button-up shirt and black gloves, stands behind a counter with a pan, gesturing with his left hand, while a blonde woman with her hair in a bun, dressed in a white shirt, holds a microphone to her mouth, looking intently at the pan. The scene is set outdoors under a bright, overcast sky, with decorative palm tree cutouts and a light green vintage Volkswagen bus visible in the background, suggesting a relaxed, possibly tropical, cooking demonstration. The overall emotional disposition is one of focused engagement and professional presentation. The camera maintains a static medium shot, capturing both individuals from the waist up, with a shallow depth of field that keeps them sharp while blurring the background elements. The lighting is bright and even, typical of outdoor daylight, with soft shadows. The color grading is natural and vibrant, reflecting the outdoor setting. The man, with a slight smile, explains in a clear, steady, and informative tone, ""Pulver mit dran gemacht, gibt's ja auch als Paste, aber als Pulver ist das hier ein bisschen..."" as he gestures towards the pan with his left hand, his right hand resting on the counter. The woman listens attentively, her eyebrows slightly raised, her mouth slightly open in an expression of curiosity and concentration, her gaze fixed on the pan.
|
| 2 |
+
|
| 3 |
+
Dialogue:
|
| 4 |
+
<Man in green shirt, German>: ""Pulver mit dran gemacht, gibt's ja auch als Paste, aber als Pulver ist das hier ein bisschen...""
|
| 5 |
+
|
| 6 |
+
Background Sound:
|
| 7 |
+
<No prominent background sound effects>"
|
example/base/config.json
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"engine_config": {
|
| 3 |
+
"load": "/home/niubility2/hongyu/open_source_ckpt/base",
|
| 4 |
+
"cp_size": 1
|
| 5 |
+
},
|
| 6 |
+
"evaluation_config": {
|
| 7 |
+
"cfg_number": 2,
|
| 8 |
+
"num_inference_steps": 32,
|
| 9 |
+
"audio_model_path": "/home/niubility2/hongyu/open_source_ckpt/audio",
|
| 10 |
+
"txt_model_path": "/home/niubility2/hongyu/open_source_ckpt/t5/t5gemma-9b-9b-ul2",
|
| 11 |
+
"vae_model_path": "/home/niubility2/hongyu/open_source_ckpt/wan_vae/Wan2.2-TI2V-5B",
|
| 12 |
+
"use_turbo_vae": true,
|
| 13 |
+
"student_config_path": "/home/niubility2/hongyu/open_source_ckpt/turbo_vae/TurboV3-Wan22-TinyShallow_7_7.json",
|
| 14 |
+
"student_ckpt_path": "/home/niubility2/hongyu/open_source_ckpt/turbo_vae/checkpoint-340000.ckpt"
|
| 15 |
+
}
|
| 16 |
+
}
|
example/base/run.sh
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
|
| 3 |
+
set -euo pipefail
|
| 4 |
+
|
| 5 |
+
PROJECT_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"
|
| 6 |
+
cd "$PROJECT_ROOT"
|
| 7 |
+
|
| 8 |
+
export MASTER_ADDR="${MASTER_ADDR:-localhost}"
|
| 9 |
+
export MASTER_PORT="${MASTER_PORT:-6009}"
|
| 10 |
+
export NNODES="${NNODES:-1}"
|
| 11 |
+
export NODE_RANK="${NODE_RANK:-0}"
|
| 12 |
+
export GPUS_PER_NODE="${GPUS_PER_NODE:-1}"
|
| 13 |
+
export WORLD_SIZE="$((GPUS_PER_NODE * NNODES))"
|
| 14 |
+
|
| 15 |
+
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
|
| 16 |
+
export NCCL_ALGO="${NCCL_ALGO:-^NVLS}"
|
| 17 |
+
export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}"
|
| 18 |
+
|
| 19 |
+
DISTRIBUTED_ARGS="--nnodes=${NNODES} --node_rank=${NODE_RANK} --nproc_per_node=${GPUS_PER_NODE} --rdzv-backend=c10d --rdzv-endpoint=${MASTER_ADDR}:${MASTER_PORT}"
|
| 20 |
+
|
| 21 |
+
torchrun ${DISTRIBUTED_ARGS} inference/pipeline/entry.py \
|
| 22 |
+
--config-load-path example/base/config.json \
|
| 23 |
+
--prompt "$(<example/assets/prompt.txt)" \
|
| 24 |
+
--image_path example/assets/image.png \
|
| 25 |
+
--seconds 10 \
|
| 26 |
+
--br_width 448 \
|
| 27 |
+
--br_height 256 \
|
| 28 |
+
--output_path "output_example_base_$(date '+%Y%m%d_%H%M%S')" \
|
| 29 |
+
2>&1 | tee "log_example_base_$(date '+%Y%m%d_%H%M%S').log"
|
example/distill/config.json
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"engine_config": {
|
| 3 |
+
"load": "/home/niubility2/hongyu/open_source_ckpt/distill",
|
| 4 |
+
"distill": true,
|
| 5 |
+
"cp_size": 1
|
| 6 |
+
},
|
| 7 |
+
"evaluation_config": {
|
| 8 |
+
"cfg_number": 1,
|
| 9 |
+
"num_inference_steps": 8,
|
| 10 |
+
"audio_model_path": "/home/niubility2/hongyu/open_source_ckpt/audio",
|
| 11 |
+
"txt_model_path": "/home/niubility2/hongyu/open_source_ckpt/t5/t5gemma-9b-9b-ul2",
|
| 12 |
+
"vae_model_path": "/home/niubility2/hongyu/open_source_ckpt/wan_vae/Wan2.2-TI2V-5B",
|
| 13 |
+
"use_turbo_vae": true,
|
| 14 |
+
"student_config_path": "/home/niubility2/hongyu/open_source_ckpt/turbo_vae/TurboV3-Wan22-TinyShallow_7_7.json",
|
| 15 |
+
"student_ckpt_path": "/home/niubility2/hongyu/open_source_ckpt/turbo_vae/checkpoint-340000.ckpt"
|
| 16 |
+
}
|
| 17 |
+
}
|
example/distill/run.sh
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
|
| 3 |
+
set -euo pipefail
|
| 4 |
+
|
| 5 |
+
PROJECT_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"
|
| 6 |
+
cd "$PROJECT_ROOT"
|
| 7 |
+
|
| 8 |
+
export MASTER_ADDR="${MASTER_ADDR:-localhost}"
|
| 9 |
+
export MASTER_PORT="${MASTER_PORT:-6010}"
|
| 10 |
+
export NNODES="${NNODES:-1}"
|
| 11 |
+
export NODE_RANK="${NODE_RANK:-0}"
|
| 12 |
+
export GPUS_PER_NODE="${GPUS_PER_NODE:-1}"
|
| 13 |
+
export WORLD_SIZE="$((GPUS_PER_NODE * NNODES))"
|
| 14 |
+
|
| 15 |
+
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
|
| 16 |
+
export NCCL_ALGO="${NCCL_ALGO:-^NVLS}"
|
| 17 |
+
export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}"
|
| 18 |
+
|
| 19 |
+
DISTRIBUTED_ARGS="--nnodes=${NNODES} --node_rank=${NODE_RANK} --nproc_per_node=${GPUS_PER_NODE} --rdzv-backend=c10d --rdzv-endpoint=${MASTER_ADDR}:${MASTER_PORT}"
|
| 20 |
+
|
| 21 |
+
torchrun ${DISTRIBUTED_ARGS} inference/pipeline/entry.py \
|
| 22 |
+
--config-load-path example/distill/config.json \
|
| 23 |
+
--prompt "$(<example/assets/prompt.txt)" \
|
| 24 |
+
--image_path example/assets/image.png \
|
| 25 |
+
--seconds 10 \
|
| 26 |
+
--br_width 448 \
|
| 27 |
+
--br_height 256 \
|
| 28 |
+
--output_path "output_example_distill_$(date '+%Y%m%d_%H%M%S')" \
|
| 29 |
+
2>&1 | tee "log_example_distill_$(date '+%Y%m%d_%H%M%S').log"
|
example/sr_1080p/config.json
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"engine_config": {
|
| 3 |
+
"load": "/home/niubility2/hongyu/open_source_ckpt/base",
|
| 4 |
+
"cp_size": 1
|
| 5 |
+
},
|
| 6 |
+
"evaluation_config": {
|
| 7 |
+
"cfg_number": 2,
|
| 8 |
+
"num_inference_steps": 32,
|
| 9 |
+
"audio_model_path": "/home/niubility2/hongyu/open_source_ckpt/audio",
|
| 10 |
+
"txt_model_path": "/home/niubility2/hongyu/open_source_ckpt/t5/t5gemma-9b-9b-ul2",
|
| 11 |
+
"vae_model_path": "/home/niubility2/hongyu/open_source_ckpt/wan_vae/Wan2.2-TI2V-5B",
|
| 12 |
+
"use_sr_model": true,
|
| 13 |
+
"sr_model_path": "/home/niubility2/hongyu/open_source_ckpt/1080p_sr",
|
| 14 |
+
"sr_num_inference_steps": 5,
|
| 15 |
+
"sr_cfg_number": 1,
|
| 16 |
+
"use_turbo_vae": true,
|
| 17 |
+
"student_config_path": "/home/niubility2/hongyu/open_source_ckpt/turbo_vae/TurboV3-Wan22-TinyShallow_7_7.json",
|
| 18 |
+
"student_ckpt_path": "/home/niubility2/hongyu/open_source_ckpt/turbo_vae/checkpoint-340000.ckpt"
|
| 19 |
+
}
|
| 20 |
+
}
|
example/sr_1080p/run.sh
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
|
| 3 |
+
set -euo pipefail
|
| 4 |
+
|
| 5 |
+
PROJECT_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"
|
| 6 |
+
cd "$PROJECT_ROOT"
|
| 7 |
+
|
| 8 |
+
export MASTER_ADDR="${MASTER_ADDR:-localhost}"
|
| 9 |
+
export MASTER_PORT="${MASTER_PORT:-6012}"
|
| 10 |
+
export NNODES="${NNODES:-1}"
|
| 11 |
+
export NODE_RANK="${NODE_RANK:-0}"
|
| 12 |
+
export GPUS_PER_NODE="${GPUS_PER_NODE:-1}"
|
| 13 |
+
export WORLD_SIZE="$((GPUS_PER_NODE * NNODES))"
|
| 14 |
+
|
| 15 |
+
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
|
| 16 |
+
export NCCL_ALGO="${NCCL_ALGO:-^NVLS}"
|
| 17 |
+
export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}"
|
| 18 |
+
export SR2_1080="${SR2_1080:-true}"
|
| 19 |
+
export CPU_OFFLOAD="${CPU_OFFLOAD:-true}"
|
| 20 |
+
|
| 21 |
+
DISTRIBUTED_ARGS="--nnodes=${NNODES} --node_rank=${NODE_RANK} --nproc_per_node=${GPUS_PER_NODE} --rdzv-backend=c10d --rdzv-endpoint=${MASTER_ADDR}:${MASTER_PORT}"
|
| 22 |
+
|
| 23 |
+
torchrun ${DISTRIBUTED_ARGS} inference/pipeline/entry.py \
|
| 24 |
+
--config-load-path example/sr_1080p/config.json \
|
| 25 |
+
--prompt "$(<example/assets/prompt.txt)" \
|
| 26 |
+
--image_path example/assets/image.png \
|
| 27 |
+
--seconds 10 \
|
| 28 |
+
--br_width 448 \
|
| 29 |
+
--br_height 256 \
|
| 30 |
+
--output_path "output_example_sr_1080p_$(date '+%Y%m%d_%H%M%S')" \
|
| 31 |
+
--sr_width 1920 \
|
| 32 |
+
--sr_height 1088 \
|
| 33 |
+
2>&1 | tee "log_example_sr_1080p_$(date '+%Y%m%d_%H%M%S').log"
|
example/sr_540p/config.json
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"engine_config": {
|
| 3 |
+
"load": "/home/niubility2/hongyu/open_source_ckpt/base",
|
| 4 |
+
"cp_size": 1
|
| 5 |
+
},
|
| 6 |
+
"evaluation_config": {
|
| 7 |
+
"cfg_number": 2,
|
| 8 |
+
"num_inference_steps": 32,
|
| 9 |
+
"audio_model_path": "/home/niubility2/hongyu/open_source_ckpt/audio",
|
| 10 |
+
"txt_model_path": "/home/niubility2/hongyu/open_source_ckpt/t5/t5gemma-9b-9b-ul2",
|
| 11 |
+
"vae_model_path": "/home/niubility2/hongyu/open_source_ckpt/wan_vae/Wan2.2-TI2V-5B",
|
| 12 |
+
"use_sr_model": true,
|
| 13 |
+
"sr_model_path": "/home/niubility2/hongyu/open_source_ckpt/540p_sr",
|
| 14 |
+
"sr_num_inference_steps": 5,
|
| 15 |
+
"sr_cfg_number": 1,
|
| 16 |
+
"use_turbo_vae": true,
|
| 17 |
+
"student_config_path": "/home/niubility2/hongyu/open_source_ckpt/turbo_vae/TurboV3-Wan22-TinyShallow_7_7.json",
|
| 18 |
+
"student_ckpt_path": "/home/niubility2/hongyu/open_source_ckpt/turbo_vae/checkpoint-340000.ckpt"
|
| 19 |
+
}
|
| 20 |
+
}
|
example/sr_540p/run.sh
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
|
| 3 |
+
set -euo pipefail
|
| 4 |
+
|
| 5 |
+
PROJECT_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"
|
| 6 |
+
cd "$PROJECT_ROOT"
|
| 7 |
+
|
| 8 |
+
export MASTER_ADDR="${MASTER_ADDR:-localhost}"
|
| 9 |
+
export MASTER_PORT="${MASTER_PORT:-6011}"
|
| 10 |
+
export NNODES="${NNODES:-1}"
|
| 11 |
+
export NODE_RANK="${NODE_RANK:-0}"
|
| 12 |
+
export GPUS_PER_NODE="${GPUS_PER_NODE:-1}"
|
| 13 |
+
export WORLD_SIZE="$((GPUS_PER_NODE * NNODES))"
|
| 14 |
+
|
| 15 |
+
export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
|
| 16 |
+
export NCCL_ALGO="${NCCL_ALGO:-^NVLS}"
|
| 17 |
+
export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}"
|
| 18 |
+
export CPU_OFFLOAD=true
|
| 19 |
+
|
| 20 |
+
DISTRIBUTED_ARGS="--nnodes=${NNODES} --node_rank=${NODE_RANK} --nproc_per_node=${GPUS_PER_NODE} --rdzv-backend=c10d --rdzv-endpoint=${MASTER_ADDR}:${MASTER_PORT}"
|
| 21 |
+
|
| 22 |
+
torchrun ${DISTRIBUTED_ARGS} inference/pipeline/entry.py \
|
| 23 |
+
--config-load-path example/sr_540p/config.json \
|
| 24 |
+
--prompt "$(<example/assets/prompt.txt)" \
|
| 25 |
+
--image_path example/assets/image.png \
|
| 26 |
+
--seconds 10 \
|
| 27 |
+
--br_width 448 \
|
| 28 |
+
--br_height 256 \
|
| 29 |
+
--output_path "output_example_sr_540p_$(date '+%Y%m%d_%H%M%S')" \
|
| 30 |
+
--sr_width 896 \
|
| 31 |
+
--sr_height 512 \
|
| 32 |
+
2>&1 | tee "log_example_sr_540p_$(date '+%Y%m%d_%H%M%S').log"
|
inference/common/__init__.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from .arch import get_arch_memory, is_hopper_arch
|
| 16 |
+
from .config import (
|
| 17 |
+
DataProxyConfig,
|
| 18 |
+
EngineConfig,
|
| 19 |
+
EvaluationConfig,
|
| 20 |
+
parse_config,
|
| 21 |
+
)
|
| 22 |
+
from .cpu_offload_wrapper import CPUOffloadWrapper
|
| 23 |
+
from .sequence_schema import Modality, VarlenHandler
|
| 24 |
+
|
| 25 |
+
__all__ = [
|
| 26 |
+
# arch
|
| 27 |
+
"get_arch_memory",
|
| 28 |
+
"is_hopper_arch",
|
| 29 |
+
# config
|
| 30 |
+
"EngineConfig",
|
| 31 |
+
"DataProxyConfig",
|
| 32 |
+
"EvaluationConfig",
|
| 33 |
+
"parse_config",
|
| 34 |
+
# cpu offload wrapper
|
| 35 |
+
"CPUOffloadWrapper",
|
| 36 |
+
# sequence schema
|
| 37 |
+
"Modality",
|
| 38 |
+
"VarlenHandler",
|
| 39 |
+
]
|
inference/common/arch.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def is_hopper_arch():
    """Return True when the current CUDA device is a Hopper GPU (compute capability 9.x)."""
    major_cc, _minor_cc = torch.cuda.get_device_capability()
    return major_cc == 9
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def get_arch_memory(unit: str = "GB"):
    """Return the total memory of the current CUDA device in the requested unit.

    Args:
        unit: One of "B", "KB", "MB", "GB" (binary, i.e. 1024-based).

    Returns:
        Total device memory as a float, or 0 when CUDA is unavailable
        (in that case the unit is not validated).

    Raises:
        ValueError: If ``unit`` is not one of the supported strings.
    """
    if not torch.cuda.is_available():
        return 0
    total_bytes = torch.cuda.get_device_properties(torch.cuda.current_device()).total_memory
    divisors = {"B": 1, "KB": 1024, "MB": 1024**2, "GB": 1024**3}
    try:
        return float(total_bytes) / divisors[unit]
    except KeyError:
        raise ValueError(f"Invalid unit: {unit}") from None
|
inference/common/config.py
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
import copy
|
| 17 |
+
import json
|
| 18 |
+
import os
|
| 19 |
+
import sys
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
from typing import Literal, Tuple
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
from inference.utils import env_is_true, print_rank_0
|
| 25 |
+
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator, model_validator
|
| 26 |
+
from pydantic_settings import (
|
| 27 |
+
BaseSettings,
|
| 28 |
+
CliSettingsSource,
|
| 29 |
+
JsonConfigSettingsSource,
|
| 30 |
+
PydanticBaseSettingsSource,
|
| 31 |
+
SettingsConfigDict,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class EngineConfig(BaseModel):
    """Engine-level runtime settings: seeding, checkpoint location, and the
    parallelism layout (TP/PP/CP/DP) used by the distributed runtime.

    NOTE: ``dp_size`` is recomputed from WORLD_SIZE by the pipeline-level
    validator, so a user-supplied value is effectively advisory.
    """

    # Basic settings
    seed: int = Field(1234, description="Random seed used for python, numpy, pytorch, and cuda.")
    load: str | None = Field(None, description="Directory containing a model checkpoint.")

    # Parallelism strategy
    distributed_backend: Literal["nccl", "gloo"] = Field("nccl", description="Distributed backend. Choices: ['nccl', 'gloo'].")
    distributed_timeout_minutes: int = Field(10, description="Timeout minutes for torch.distributed.")
    sequence_parallel: bool = Field(False, description="Enable sequence parallel optimization.")
    tp_size: int = Field(1, description="Degree of tensor model parallelism.")
    pp_size: int = Field(1, description="Degree of pipeline model parallelism.")
    cp_size: int = Field(1, description="Degree of context parallelism.")
    dp_size: int = Field(1, description="Degree of data parallelism.")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class ModelConfig(BaseModel):
    """Model configuration class defining various parameters for video generation model"""

    # arbitrary_types_allowed is needed so torch.dtype can be a field type;
    # protected_namespaces=() silences pydantic's "model_" prefix warnings.
    model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())

    num_layers: int = Field(default=40, description="Number of Transformer layers")
    hidden_size: int = Field(default=5120, description="Hidden size of the Transformer model")
    head_dim: int = Field(default=128, description="Dimension per attention head")
    num_query_groups: int = Field(default=8, description="Number of query groups for grouped-query attention")
    video_in_channels: int = Field(default=48 * 4, description="Number of video input channels after patch embedding")
    audio_in_channels: int = Field(default=64, description="Number of audio input channels")
    text_in_channels: int = Field(default=3584, description="Number of text input channels")
    checkpoint_qk_layernorm_rope: bool = Field(default=False, description="Enable checkpointing for QK layernorm + RoPE")
    params_dtype: torch.dtype | str = Field(default=torch.float32, description="Parameter dtype")
    tread_config: dict = Field(
        default=dict(
            selection_rate=0.5, start_layer_idx=2, end_layer_idx=25  # after forward of 0, 1 # before forward of 26 27 28 29
        ),
        description="TReAD (Token Routing and Early Drop) configuration",
    )
    mm_layers: list[int] = Field(default=[0, 1, 2, 3, 36, 37, 38, 39], description="Indices of multimodal fusion layers")
    local_attn_layers: list[int] = Field(default=[], description="Indices of local attention layers")
    enable_attn_gating: bool = Field(default=True, description="Enable attention gating")
    activation_type: str = Field(default="swiglu7", description="Activation type")
    gelu7_layers: list[int] = Field(default=[0, 1, 2, 3], description="Indices of gelu7 layers")

    # Add computed fields
    # These defaults (0 / empty) are placeholders; num_heads_q and num_heads_kv
    # are overwritten by MagiPipelineConfig.post_override_config after parsing.
    num_heads_q: int = Field(default=0, description="Number of query heads (calculated from hidden_size // head_dim)")
    num_heads_kv: int = Field(default=0, description="Number of key-value heads (calculated from num_query_groups)")
    post_norm_layers: list[int] = Field(default=[], description="Indices of post norm layers")

    @field_serializer("params_dtype")
    def serialize_dtype(self, value: torch.dtype | str) -> str:
        """Serialize the dtype as its string form (e.g. "torch.float32") for JSON dumps."""
        return str(value)

    @field_validator("params_dtype", mode="before")
    @classmethod
    def validate_dtype(cls, value):
        """Accept either a torch.dtype or its string name; anything else raises.

        Both the "torch.float32" and the bare "float32" spellings are accepted
        so round-tripping through serialize_dtype (and hand-written configs) works.
        """
        if isinstance(value, torch.dtype):
            return value
        if isinstance(value, str):
            if value == "torch.float32" or value == "float32":
                return torch.float32
            elif value == "torch.float16" or value == "float16":
                return torch.float16
            elif value == "torch.bfloat16" or value == "bfloat16":
                return torch.bfloat16
        raise ValueError(f"Unknown torch.dtype string: '{value}'")
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class DataProxyConfig(BaseModel):
    """Configuration for sequence/data layout: patching, RoPE interpolation,
    and per-modality position offsets fed to the model's data proxy.
    """

    t_patch_size: int = Field(default=1, description="Patch size for time dimension")
    patch_size: int = Field(default=2, description="Patch size for spatial dimensions")
    frame_receptive_field: int = Field(default=11, description="Frame receptive field")
    spatial_rope_interpolation: Literal["inter", "extra"] = Field(
        default="extra", description="Spatial rope interpolation method."
    )
    ref_audio_offset: int = Field(default=1000, description="Offset for reference audio.")
    text_offset: int = Field(default=0, description="Offset for text.")
    coords_style: Literal["v1", "v2"] = Field(default="v2", description="Coords style.")
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class EvaluationConfig(BaseModel):
    """Evaluation configuration class defining parameters for model evaluation and inference"""

    model_config = ConfigDict(protected_namespaces=())

    data_proxy_config: DataProxyConfig = Field(default=DataProxyConfig(), description="Data proxy configuration")

    # Sampling / guidance settings for the base generation pass
    fps: int = Field(default=25, description="Frames per second for video generation")
    num_inference_steps: int = Field(default=32, description="Number of denoising steps during inference")
    video_txt_guidance_scale: float = Field(default=5.0, description="Video text guidance scale for text conditioning")
    audio_txt_guidance_scale: float = Field(default=5.0, description="Audio text guidance scale for text conditioning")
    txt_encoder_type: Literal["t5_gemma"] = Field(default="t5_gemma", description="Text encoder type.")
    t5_gemma_target_length: int = Field(default=640, description="Target length for T5-Gemma encoder.")
    support_ref_audio: bool = Field(default=True, description="Whether to support the ref_audio feature")
    shift: float = Field(default=5.0, description="Temporal shift parameter for video generation")
    exp_name: str = Field(default="exp_debug", description="Experiment name with evaluation suffix")
    # Checkpoint paths for the auxiliary models (empty string means "unset")
    audio_model_path: str = Field(default="", description="Path to the pretrained audio model")
    txt_model_path: str = Field(default="", description="Path to the pretrained txt model")
    vae_model_path: str = Field(default="", description="Path to the pretrained vae model")
    vae_stride: Tuple[int, int, int] = Field(default=(4, 16, 16), description="VAE stride in format (time, height, width)")
    z_dim: int = Field(default=48, description="Dimension of z space.")
    patch_size: Tuple[int, int, int] = Field(default=(1, 2, 2), description="Patch size in format (time, height, width)")
    cfg_number: int = Field(default=2, description="Classifier-free guidance number")
    sr_cfg_number: int = Field(default=2, description="SR Classifier-free guidance number")

    # flops recording
    enable_flops_recording: bool = Field(default=False, description="Whether to enable flops recording")

    # super resolution model configuration
    use_sr_model: bool = Field(default=False, description="Whether to use the super resolution model")
    sr_model_path: str = Field(default="", description="Path to the pretrained super resolution model")
    sr_num_inference_steps: int = Field(default=5, description="Number of denoising steps during super resolution inference")
    noise_value: int = Field(default=220, description="Noise value for the super resolution model")
    sr_video_txt_guidance_scale: float = Field(
        default=3.5, description="Super resolution video text guidance scale for text conditioning"
    )
    use_cfg_trick: bool = Field(default=True, description="Whether to use the cfg trick")
    cfg_trick_start_frame: int = Field(default=13, description="Start frame for the cfg trick")
    cfg_trick_value: float = Field(default=2.0, description="Value for the cfg trick")
    using_sde_flag: bool = Field(default=False, description="Whether to use the sde flag")
    sr_audio_noise_scale: float = Field(default=0.7, description="Noise scale for the super resolution audio")

    # turbo-vae config
    use_turbo_vae: bool = Field(default=True, description="Whether to use the turbo-vae")
    student_config_path: str = Field(default="", description="Path to the student config")
    student_ckpt_path: str = Field(default="", description="Path to the student checkpoint")
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class MagiPipelineConfig(BaseSettings):
    """Top-level pipeline settings aggregating engine, model, evaluation, and
    super-resolution model configs; values are resolved from env vars, the CLI,
    and an optional JSON file (see settings_customise_sources for precedence).
    """

    engine_config: EngineConfig = Field(description="Engine configuration.", default_factory=EngineConfig)
    arch_config: ModelConfig = Field(default=ModelConfig(), description="Model configuration.")
    evaluation_config: EvaluationConfig = Field(default=EvaluationConfig(), description="Evaluation configuration.")
    sr_arch_config: ModelConfig = Field(default=ModelConfig(), description="Super resolution model configuration.")
    model_config = SettingsConfigDict(cli_parse_args=True, cli_ignore_unknown_args=True, cli_implicit_flags=True)

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        """Define the value-resolution order: env vars, then CLI flags, then the
        optional ``--config-load-path`` JSON file, then the remaining defaults
        (earlier sources win in pydantic-settings).
        """
        # A private argparse pass extracts --config-load-path without
        # disturbing the CLI args that pydantic-settings itself will parse.
        parser = argparse.ArgumentParser(allow_abbrev=False)
        parser.add_argument("--config-load-path", type=str, default=None, help="Path to load the config.json from")
        args, _ = parser.parse_known_args()
        config_load_path = args.config_load_path
        sources = [env_settings, CliSettingsSource(settings_cls, cli_parse_args=True, cli_ignore_unknown_args=True)]
        if config_load_path:
            sources.append(JsonConfigSettingsSource(settings_cls, json_file=config_load_path))

        sources.extend([init_settings, dotenv_settings, file_secret_settings])
        return tuple(sources)

    def save_to_json(self, json_path: str, indent: int = 4):
        """Write the pretty-printed config to ``json_path``, creating parent dirs.

        NOTE(review): the written text is __str__'s quote-stripped form, not
        strict JSON — it cannot be re-read via --config-load-path; confirm intent.
        """
        path = Path(json_path)
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(self.__str__(indent=indent))

    def __str__(self, indent: int = 4):
        """Human-readable dump: class name header plus the config with quotes stripped."""
        data = self.model_dump(mode="json")
        formatted = json.dumps(data, indent=indent, ensure_ascii=False, sort_keys=False)
        class_name = self.__class__.__name__
        return f"{class_name}:\n{formatted}".replace('"', "")

    def __repr__(self, indent: int = 4):
        return self.__str__(indent=indent)

    @model_validator(mode="after")
    def validate_engine_config(self):
        """Derive dp_size from WORLD_SIZE and check the TP*PP*CP*DP factorization."""
        world_size = int(os.getenv("WORLD_SIZE", "1"))
        # dp_size is always recomputed; any user-supplied value is overwritten.
        self.engine_config.dp_size = world_size // (
            self.engine_config.tp_size * self.engine_config.pp_size * self.engine_config.cp_size
        )

        assert world_size % self.engine_config.tp_size == 0
        tp_pp_size = self.engine_config.tp_size * self.engine_config.pp_size
        assert world_size % tp_pp_size == 0
        tp_pp_cp_size = tp_pp_size * self.engine_config.cp_size
        assert world_size % tp_pp_cp_size == 0
        assert world_size == self.engine_config.dp_size * tp_pp_cp_size

        # Sequence parallelism only makes sense together with tensor parallelism.
        if self.engine_config.tp_size == 1:
            self.engine_config.sequence_parallel = False

        return self

    @model_validator(mode="after")
    def post_override_config(self):
        """Fill in derived head counts and clone arch_config into sr_arch_config,
        applying the SR2_1080 environment override when requested.
        """
        self.arch_config.num_heads_q = self.arch_config.hidden_size // self.arch_config.head_dim
        self.arch_config.num_heads_kv = self.arch_config.num_query_groups

        # The SR model reuses the base architecture by default.
        self.sr_arch_config = copy.deepcopy(self.arch_config)
        if env_is_true("SR2_1080"):
            # NOTE: this deepcopy repeats the one above; harmless but redundant.
            self.sr_arch_config = copy.deepcopy(self.arch_config)
            # fmt: off
            self.sr_arch_config.local_attn_layers = [
                0, 1, 2,
                4, 5, 6,
                8, 9, 10,
                12, 13, 14,
                16, 17, 18,
                20, 21, 22,
                24, 25, 26,
                28, 29, 30,
                32, 33, 34,
                35, 36, 37,
                38, 39,
            ]
            # fmt: on
            self.evaluation_config.sr_video_txt_guidance_scale = 3.5

        return self
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def prevent_unsupported_list_syntax():
    """
    Check sys.argv before Pydantic parsing to prevent using unsupported list syntax.

    Detects the pattern ``--arg v1 v2`` (an option flag followed by two or more
    consecutive non-option tokens), which pydantic-settings cannot parse as a list.

    Raises:
        ValueError: with a help message describing the supported list syntaxes.
    """
    args = sys.argv[1:]
    for i, arg in enumerate(args):
        # Only option-like tokens ("-f" / "--foo") can introduce a list value.
        # Without this guard, three consecutive positional tokens (e.g. a
        # subcommand with arguments) would be reported as a false positive.
        if not arg.startswith("-"):
            continue
        if i + 2 < len(args):
            value1, value2 = args[i + 1], args[i + 2]
            if not value1.startswith("-") and not value2.startswith("-"):
                error_msg = (
                    f"\n\nError: Detected list parameter '{arg}' using unsupported command line syntax.\n"
                    f"Error pattern: '{arg} {value1} {value2} ...'\n\n"
                    "Pydantic (or related libraries) do not support passing lists with space-separated multiple values.\n"
                    "Please use one of the following supported formats:\n\n"
                    f"1. JSON style: {arg} '[{value1},{value2},...]'\n"
                    f"2. Argparse style: {arg} {value1} {arg} {value2}\n"
                    f"3. Lazy style: {arg} {value1},{value2}\n"
                )
                raise ValueError(error_msg)
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def parse_config(verbose: bool = False) -> MagiPipelineConfig:
    """Build the pipeline config from env/CLI/JSON sources.

    Optionally dumps the resolved config to ``--config-save-path`` and, when
    ``verbose`` is set, prints it on rank 0.
    """
    cli_parser = argparse.ArgumentParser(description="Load and optionally save config", allow_abbrev=False)
    cli_parser.add_argument("--config-save-path", type=str, default=None, help="Path to save the config.json to")
    cli_args, _unknown = cli_parser.parse_known_args()

    # Reject space-separated list values before pydantic-settings parses argv.
    prevent_unsupported_list_syntax()
    pipeline_config = MagiPipelineConfig()

    save_path = cli_args.config_save_path
    if save_path is not None:
        pipeline_config.save_to_json(save_path)

    if verbose:
        print_rank_0(pipeline_config)

    return pipeline_config
|
inference/common/cpu_offload_wrapper.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from typing import Any, Callable, Dict, Tuple
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class CPUOffloadWrapper:
    """Transparent proxy around a model that can park its weights on CPU and
    move them to GPU only for the duration of each (compute) method call.

    The wrapper is immutable after construction: attribute writes raise, and all
    internal state is installed via ``object.__setattr__`` so that it lives in
    the instance ``__dict__`` and never triggers the ``__getattr__`` fallback.
    """

    def __init__(self, model: Any, is_cpu_offload: bool = False, is_running_on_gpu: bool = True):
        """Wrap ``model``; when ``is_cpu_offload`` is True the model starts on CPU.

        Args:
            model: An ``nn.Module`` or a plain object holding tensors/modules
                as attributes (it must provide a ``.to(device)`` method).
            is_cpu_offload: Keep weights on CPU between calls.
            is_running_on_gpu: Whether compute calls should run on the GPU.
        """
        object.__setattr__(self, "model", model)
        object.__setattr__(self, "is_cpu_offload", is_cpu_offload)
        object.__setattr__(self, "is_running_on_gpu", is_running_on_gpu)

        cpu_device = torch.device("cpu")
        cuda_device = torch.device("cuda")
        object.__setattr__(self, "cpu_device", cpu_device)
        object.__setattr__(self, "cuda_device", cuda_device)

        # Initialize placement location
        if is_cpu_offload:
            self.model.to(cpu_device)
        else:
            self.model.to(cuda_device)

        # Whitelist non-compute methods that shouldn't trigger device hops (pass-through only; no device switch)
        object.__setattr__(
            self,
            "_non_compute_methods",
            {
                "to",
                "cpu",
                "cuda",
                "eval",
                "train",
                "state_dict",
                "load_state_dict",
                "parameters",
                "named_parameters",
                "buffers",
                "named_buffers",
                "modules",
                "named_modules",
                "children",
                "named_children",
                "register_forward_hook",
                "register_forward_pre_hook",
                "register_full_backward_hook",
                "zero_grad",
                "share_memory",
                "half",
                "float",
                "bfloat16",
            },
        )

    # Get current primary device (for external reads)
    @property
    def device(self) -> torch.device:
        """Best-effort device of the wrapped model's parameters/tensors.

        For non-Module models, probes the object's attributes for the first
        tensor or module; defaults to CUDA if none is found.
        """
        if isinstance(self.model, torch.nn.Module):
            return next(self.model.parameters()).device
        else:
            for k, v in self.model.__dict__.items():
                if isinstance(v, torch.Tensor):
                    return v.device
                elif isinstance(v, torch.nn.Module):
                    return next(v.parameters()).device
            return self.cuda_device

    def _backup_cpu_state(self) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Dict[str, Any]]:
        """Record references to the current (CPU) parameter/buffer tensors.

        Returns a (params, buffers, other-tensors) triple keyed by dotted name,
        so `_restore_cpu_state` can rebind them after a temporary GPU move and
        let the GPU copies be freed.
        """
        # Backup module parameters and buffers
        module_param_backup = {}
        module_buffer_backup = {}
        other_backup = {}

        def save_module_state(mod: torch.nn.Module, prefix: str):
            # Stores tensor references (not clones): `.data` of every param/buffer.
            for name, param in mod.named_parameters():
                if param is not None:
                    full_key = prefix + name
                    module_param_backup[full_key] = param.data
            for name, buffer in mod.named_buffers():
                if buffer is not None:
                    full_key = prefix + name
                    module_buffer_backup[full_key] = buffer.data

        if isinstance(self.model, torch.nn.Module):
            save_module_state(self.model, "")
        else:
            # Non-Module models: back up sub-modules plus any loose tensor attributes.
            for name, attr_val in self.model.__dict__.items():
                if isinstance(attr_val, torch.nn.Module):
                    save_module_state(attr_val, name + ".")
                elif isinstance(attr_val, torch.Tensor):
                    other_backup[name] = attr_val

        return module_param_backup, module_buffer_backup, other_backup

    def _restore_cpu_state(self, backups: Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Dict[str, Any]]):
        """Rebind the tensors saved by `_backup_cpu_state`, dropping GPU copies."""
        # Restore module parameters and buffers
        module_param_backup, module_buffer_backup, other_backup = backups

        def restore_module_state(mod: torch.nn.Module, prefix: str):
            for name, param in mod.named_parameters():
                full_key = prefix + name
                if full_key in module_param_backup:
                    param.data = module_param_backup[full_key]

            for name, buffer in mod.named_buffers():
                full_key = prefix + name
                if full_key in module_buffer_backup:
                    buffer.data = module_buffer_backup[full_key]

        if isinstance(self.model, torch.nn.Module):
            restore_module_state(self.model, "")
        else:
            for name, attr_val in self.model.__dict__.items():
                if isinstance(attr_val, torch.nn.Module):
                    restore_module_state(attr_val, name + ".")

        if not isinstance(self.model, torch.nn.Module):
            for name, val in other_backup.items():
                setattr(self.model, name, val)

    # Unified on/offload executor
    def _run_with_optional_offload(self, func: Callable[..., Any], *args, **kwargs):
        """Run ``func``; in offload mode, hop the model CPU→GPU→CPU around the call.

        In non-offload mode, tensor arguments are instead moved to the model's
        current device. NOTE(review): the offload branch does not move ``args``
        to GPU — callers appear expected to pass GPU tensors there; confirm.
        """
        if self.is_cpu_offload and self.is_running_on_gpu:
            backups = self._backup_cpu_state()
            self.model.to(self.cuda_device)
            try:
                return func(*args, **kwargs)
            finally:
                # Synchronize before restoring so in-flight kernels still see GPU weights.
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                self._restore_cpu_state(backups)
        else:
            # Make sure model and args are on the same device
            args = [
                arg.to(self.device) if isinstance(arg, torch.Tensor) and arg.device != self.device else arg for arg in args
            ]
            kwargs = {
                k: v.to(self.device) if isinstance(v, torch.Tensor) and v.device != self.device else v
                for k, v in kwargs.items()
            }
            return func(*args, **kwargs)

    # Direct call (equivalent to forward)
    def __call__(self, *args, **kwargs):
        return self._run_with_optional_offload(self.model.__call__, *args, **kwargs)

    # Explicit forward; some code calls model.forward(...)
    def forward(self, *args, **kwargs):
        return self._run_with_optional_offload(self.model.forward, *args, **kwargs)

    # Key: passthrough all attrs/methods. For callables, wrap with on/offload; for non-compute methods, pass-through only with no device switch.
    def __getattr__(self, name: str):
        """Proxy attribute access to the wrapped model.

        Only reached for names NOT in this wrapper's own __dict__/class, so
        wrapper state set via object.__setattr__ never recurses here.
        """
        # Fetch attribute from the wrapped model first
        attr = getattr(self.model, name)

        # Wrap methods (except in whitelist)
        if callable(attr) and name not in self._non_compute_methods:

            def _wrapped(*args, **kwargs):
                return self._run_with_optional_offload(attr, *args, **kwargs)

            return _wrapped

        return attr

    def __dir__(self):
        """Expose both wrapper and wrapped-model attributes for introspection."""
        return sorted(set(list(super().__dir__()) + dir(self.model)))

    def __setattr__(self, name: str, value: Any):
        # Immutability guard: all legitimate state is set via object.__setattr__.
        raise AttributeError("CPUOffloadWrapper is immutable")

    def __repr__(self) -> str:
        return f"CPUOffloadWrapper(is_cpu_offload={self.is_cpu_offload}, is_running_on_gpu={self.is_running_on_gpu}, model={repr(self.model)})"
|
inference/common/sequence_schema.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from dataclasses import dataclass
|
| 16 |
+
from enum import IntEnum
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class Modality(IntEnum):
    """Integer tags identifying the modality of tokens in a packed multimodal sequence."""

    VIDEO = 0
    AUDIO = 1
    TEXT = 2
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass
class VarlenHandler:
    """Metadata bundle for variable-length (packed) attention.

    Field names mirror the varlen attention-kernel convention (cumulative
    sequence lengths plus per-batch maxima for query and key sides) —
    presumably consumed by a flash-attention-style varlen API; confirm at call sites.
    """

    # Cumulative sequence lengths over the packed query/key batches.
    cu_seqlens_q: torch.Tensor
    cu_seqlens_k: torch.Tensor
    # Longest single query/key sequence in the batch.
    max_seqlen_q: int
    max_seqlen_k: int
|
| 33 |
+
|
inference/infra/__init__.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
|
| 17 |
+
from inference.common import parse_config
|
| 18 |
+
from inference.infra.distributed import get_dp_rank, initialize_distributed
|
| 19 |
+
from inference.utils import print_rank_0, set_random_seed
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def initialize_infra():
    """Initialize runtime infrastructure for inference.

    Order is significant: the distributed environment must exist before the
    config is parsed (parse_config reads WORLD_SIZE / rank state), and seeding
    uses the data-parallel rank so each DP rank gets a distinct stream.
    """
    assert torch.cuda.is_available(), "Infra requires CUDA environment."

    # Initialize distributed environment
    initialize_distributed()

    # Initialize config
    config = parse_config(verbose=True)

    # Initialize random seed
    # Offset by 10 * dp_rank so data-parallel ranks draw different randomness.
    set_random_seed(config.engine_config.seed + 10 * get_dp_rank())

    print_rank_0("Infra successfully initialized")
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
__all__ = ["initialize_infra"]
|
inference/infra/checkpoint/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from .load_model_checkpoint import load_model_checkpoint
|
| 16 |
+
|
| 17 |
+
__all__ = [
|
| 18 |
+
# checkpoint loader
|
| 19 |
+
"load_model_checkpoint",
|
| 20 |
+
]
|
inference/infra/checkpoint/load_model_checkpoint.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import io
import json
import os
import subprocess
from concurrent.futures import ThreadPoolExecutor, as_completed

from safetensors.torch import load as load_from_bytes
from safetensors.torch import load_file
from tqdm.auto import tqdm

from inference.common import EngineConfig
from inference.utils import print_rank_0
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _load_shard(shard_path, param_names, num_threads=None):
    """Load the named tensors from one safetensors shard.

    If a zstd-compressed sibling (``<shard_path>.zst``) exists, it is
    decompressed with the external ``zstd`` binary and the bytes are parsed
    in memory; otherwise the shard is read directly from disk.

    Args:
        shard_path: Path to the (uncompressed) ``.safetensors`` shard file.
        param_names: Parameter names to extract from this shard.
        num_threads: Optional zstd decompression thread count (``-T``).

    Returns:
        dict: parameter name -> tensor, restricted to ``param_names``.

    Raises:
        RuntimeError: if the zstd subprocess exits with a non-zero status.
    """
    zstd_path = shard_path + ".zst"
    if os.path.exists(zstd_path):
        cmd = ["zstd", "-d"]
        if num_threads:
            cmd.extend(["-T", str(num_threads)])  # set decompression parallelism

        process = subprocess.Popen(cmd + ["-c", zstd_path], stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=-1)
        # communicate() drains stdout and stderr concurrently. The previous
        # manual stdout.read() loop never drained stderr, which can deadlock
        # if the child fills the stderr pipe; it also called read() twice even
        # though the first size-less read() already consumes the whole stream.
        decompressed_data, stderr_data = process.communicate()
        if process.returncode != 0:
            raise RuntimeError(f"Decompression failed: {stderr_data.decode()}")

        # Parse the raw bytes directly; no BytesIO round-trip is needed.
        weights = load_from_bytes(decompressed_data)
    else:
        weights = load_file(shard_path)

    return {name: weights[name] for name in param_names}
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def load_sharded_safetensors_parallel_with_progress(checkpoint_dir):
    """Load a (possibly sharded) safetensors checkpoint from ``checkpoint_dir``.

    When ``model.safetensors.index.json`` is absent, the checkpoint is assumed
    to be a single ``model.safetensors`` file and is loaded directly.
    Otherwise, the shards listed in the index are loaded in parallel threads
    with a progress bar.

    Args:
        checkpoint_dir: Directory containing the checkpoint files.

    Returns:
        dict: full state dict, parameter name -> tensor.
    """
    index_path = os.path.join(checkpoint_dir, "model.safetensors.index.json")
    if not os.path.exists(index_path):
        # Single-file checkpoint: nothing to parallelize.
        return load_file(os.path.join(checkpoint_dir, "model.safetensors"))

    with open(index_path, "r") as f:
        index = json.load(f)

    # Group parameter names by the shard file that stores them.
    shard_map = {}
    for param_name, shard_file in index["weight_map"].items():
        shard_path = os.path.join(checkpoint_dir, shard_file)
        shard_map.setdefault(shard_path, []).append(param_name)

    # Load shards in parallel. as_completed() lets the progress bar advance as
    # soon as any shard finishes, instead of blocking on each future in
    # submission order (which stalls the bar on the slowest early shard).
    state_dict = {}
    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(_load_shard, shard_path, param_names) for shard_path, param_names in shard_map.items()]
        for future in tqdm(as_completed(futures), desc="Loading shards", total=len(futures)):
            state_dict.update(future.result())

    return state_dict
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def load_model_checkpoint(model, engine_config: EngineConfig):
    """Load safetensors weights from ``engine_config.load`` into ``model``.

    Loading is non-strict; any missing/unexpected keys are logged on rank 0.

    Returns:
        The same ``model`` instance, with the checkpoint weights loaded.
    """
    print_rank_0("Loading checkpoint with safetensors format from pretrained_folder")
    checkpoint_state = load_sharded_safetensors_parallel_with_progress(engine_config.load)

    # strict=False: tolerate keys present in only one of (model, checkpoint).
    missing_keys, unexpected_keys = model.load_state_dict(checkpoint_state, strict=False)
    print_rank_0(f"Load Weight Missing Keys: {missing_keys}")
    print_rank_0(f"Load Weight Unexpected Keys: {unexpected_keys}")
    print_rank_0("Load checkpoint successfully")
    return model
|
inference/infra/distributed/__init__.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from .parallel_state import get_cp_group, get_cp_rank, get_cp_world_size, get_dp_rank, get_pp_rank, get_tp_rank
|
| 16 |
+
from .init_dist_env import initialize_distributed
|
| 17 |
+
|
| 18 |
+
__all__ = [
|
| 19 |
+
# distributed init
|
| 20 |
+
"initialize_distributed",
|
| 21 |
+
# parallel state
|
| 22 |
+
"get_cp_group",
|
| 23 |
+
"get_cp_world_size",
|
| 24 |
+
"get_tp_rank",
|
| 25 |
+
"get_pp_rank",
|
| 26 |
+
"get_dp_rank",
|
| 27 |
+
"get_cp_rank",
|
| 28 |
+
]
|
inference/infra/distributed/init_dist_env.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
from datetime import timedelta
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from inference.common import parse_config
|
| 21 |
+
|
| 22 |
+
from .parallel_state import initialize_model_parallel, model_parallel_is_initialized
|
| 23 |
+
from inference.utils import print_rank_0
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def initialize_distributed():
    """Initialize torch.distributed and core model parallel.

    Steps:
      1. If torch.distributed is not yet initialized, read RANK / WORLD_SIZE
         from the environment, pin this process to a local GPU, and create the
         global process group with the backend/timeout from the engine config.
      2. Create the tp/pp/cp/dp sub-groups via ``initialize_model_parallel``
         (skipped if they already exist, or if no CUDA devices are present).
    """
    config = parse_config()

    device_count = torch.cuda.device_count()
    if torch.distributed.is_initialized():
        # A launcher (or an earlier call) already created the process group.
        if torch.distributed.get_rank() == 0:
            print_rank_0("> torch distributed already initialized, skipping initialization ...")
    else:
        # Fall back to single-process defaults when the env vars are unset.
        rank = int(os.getenv("RANK", "0"))
        world_size = int(os.getenv("WORLD_SIZE", "1"))
        if rank == 0:
            print_rank_0("> initializing torch distributed ...")
        # Manually set the device ids.
        if device_count > 0:
            # Round-robin mapping of global rank to local GPU; must happen
            # before init_process_group so NCCL binds to the right device.
            device = rank % device_count
            torch.cuda.set_device(device)
        # Call the init process
        torch.distributed.init_process_group(
            backend=config.engine_config.distributed_backend,
            world_size=world_size,
            rank=rank,
            timeout=timedelta(minutes=config.engine_config.distributed_timeout_minutes),
        )

    # Set the tp, pp and dp communicators.
    if device_count > 0:
        if model_parallel_is_initialized():
            # Sub-groups already built (e.g. by a previous call); nothing to do.
            return
        initialize_model_parallel(
            tp_size=config.engine_config.tp_size,
            pp_size=config.engine_config.pp_size,
            cp_size=config.engine_config.cp_size,
            nccl_communicator_config_path=None,
            distributed_timeout_minutes=config.engine_config.distributed_timeout_minutes,
            order="tp-cp-pp-dp",
        )
|
inference/infra/distributed/parallel_state.py
ADDED
|
@@ -0,0 +1,659 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
| 2 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Model and data parallel groups."""
|
| 17 |
+
|
| 18 |
+
import warnings
|
| 19 |
+
from datetime import timedelta
|
| 20 |
+
from typing import List, Optional
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
|
| 24 |
+
# Intra-layer model parallel group that the current rank belongs to.
|
| 25 |
+
_TENSOR_MODEL_PARALLEL_GROUP = None
|
| 26 |
+
# Tensor parallel group information with context parallel combined.
|
| 27 |
+
_TENSOR_MODEL_PARALLEL_GROUP_WITH_CP = None
|
| 28 |
+
_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP = None
|
| 29 |
+
# Inter-layer model parallel group that the current rank belongs to.
|
| 30 |
+
_PIPELINE_MODEL_PARALLEL_GROUP = None
|
| 31 |
+
# Model parallel group (both intra- and pipeline) that the current rank belongs to.
|
| 32 |
+
_MODEL_PARALLEL_GROUP = None
|
| 33 |
+
# Data parallel group that the current rank belongs to.
|
| 34 |
+
_DATA_PARALLEL_GROUP = None
|
| 35 |
+
# tensor model parallel group and data parallel group combined
|
| 36 |
+
# used for fp8 and moe training
|
| 37 |
+
_TENSOR_AND_DATA_PARALLEL_GROUP = None
|
| 38 |
+
|
| 39 |
+
# A list of global ranks for each pipeline group to ease calculation of the source
|
| 40 |
+
# rank when broadcasting from the first or last pipeline stage.
|
| 41 |
+
_PIPELINE_GLOBAL_RANKS = None
|
| 42 |
+
|
| 43 |
+
# A list of global ranks for each data parallel group to ease calculation of the source
|
| 44 |
+
# rank when broadcasting weights from src to all other data parallel ranks
|
| 45 |
+
_DATA_PARALLEL_GLOBAL_RANKS = None
|
| 46 |
+
|
| 47 |
+
# A list of global ranks for each tensor model parallel group to ease calculation of
|
| 48 |
+
# the first local rank in the tensor model parallel group
|
| 49 |
+
_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = None
|
| 50 |
+
|
| 51 |
+
# Context parallel group that the current rank belongs to
|
| 52 |
+
_CONTEXT_PARALLEL_GROUP = None
|
| 53 |
+
# A list of global ranks for each context parallel group to ease calculation of the
|
| 54 |
+
# destination rank when exchanging KV/dKV between context parallel_ranks
|
| 55 |
+
_CONTEXT_PARALLEL_GLOBAL_RANKS = None
|
| 56 |
+
|
| 57 |
+
_CONTEXT_PARALLEL_EXTRA_GROUP = None
|
| 58 |
+
|
| 59 |
+
# Data parallel group information with context parallel combined.
|
| 60 |
+
_DATA_PARALLEL_GROUP_WITH_CP = None
|
| 61 |
+
_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = None
|
| 62 |
+
|
| 63 |
+
# combined parallel group of TP, DP, and CP used for fp8
|
| 64 |
+
_TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = None
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _get_nccl_options(pg_name, nccl_comm_cfgs):
|
| 68 |
+
"""Set the NCCL process group options.
|
| 69 |
+
|
| 70 |
+
Args:
|
| 71 |
+
pg_name (str): process group name
|
| 72 |
+
nccl_comm_cfgs (dict): nccl communicator configurations
|
| 73 |
+
|
| 74 |
+
When an option (e.g., max_ctas) is not found in the config, use the NCCL default setting.
|
| 75 |
+
"""
|
| 76 |
+
if pg_name in nccl_comm_cfgs:
|
| 77 |
+
nccl_options = torch.distributed.ProcessGroupNCCL.Options()
|
| 78 |
+
nccl_options.config.cga_cluster_size = nccl_comm_cfgs[pg_name].get("cga_cluster_size", 4)
|
| 79 |
+
nccl_options.config.max_ctas = nccl_comm_cfgs[pg_name].get("max_ctas", 32)
|
| 80 |
+
nccl_options.config.min_ctas = nccl_comm_cfgs[pg_name].get("min_ctas", 1)
|
| 81 |
+
return nccl_options
|
| 82 |
+
else:
|
| 83 |
+
return None
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def generate_masked_orthogonal_rank_groups(world_size: int, parallel_size: List[int], mask: List[bool]) -> List[List[int]]:
    r"""Generate orthogonal parallel rank groups selected by ``mask``.

    Arguments:
        world_size (int): total number of ranks; must equal the product of
            ``parallel_size``.

        parallel_size (List[int]): the size of each orthogonal parallelism
            axis, in rank-initialization order. E.g. for order tp-pp-dp with
            tp=2, pp=3, dp=4, pass [2, 3, 4].

        mask (List[bool]): selects which axes the generated groups span.
            ``mask[i] == True`` means axis ``i`` varies *within* a group;
            the remaining (unmasked) axes index the groups themselves.
            E.g. parallel_size=[tp, pp, dp] with mask=[False, False, True]
            yields the dp groups.

    Every rank decomposes uniquely as
        global_rank = sum(coord[i] * stride[i])
    where ``stride`` is the running product of ``parallel_size``. A group is
    obtained by fixing the unmasked coordinates (the group index) and letting
    the masked coordinates run over all their values.

    Example: world_size=8, parallel_size=[2, 2, 2] (tp-pp-dp), and
    mask=[False, False, True] gives [[0, 4], [1, 5], [2, 6], [3, 7]].
    """

    def running_products(dims: List[int], init=1) -> List[int]:
        # out[i] == init * product(dims[:i]); one element longer than dims.
        out = [init]
        acc = init
        for d in dims:
            acc *= d
            out.append(acc)
        return out

    def dot(a: List[int], b: List[int]) -> int:
        return sum(x * y for x, y in zip(a, b))

    def unravel(index, shape, stride=None):
        # Invert index = sum(coord[i] * stride[i]) for the given shape/stride,
        # returning the coordinate vector. Used to recover per-axis ranks from
        # a group index or a rank-in-group.
        if stride is None:
            stride = running_products(shape)
        idx = [(index // d) % s for s, d in zip(shape, stride)]
        # stride[-1] (the full product) is never consumed by the decomposition.
        assert (
            sum([x * y for x, y in zip(idx, stride[:-1])]) == index
        ), "idx {} with shape {} mismatch the return idx {}".format(index, shape, idx)
        return idx

    masked_shape = [s for s, m in zip(parallel_size, mask) if m]
    unmasked_shape = [s for s, m in zip(parallel_size, mask) if not m]

    global_stride = running_products(parallel_size)
    masked_stride = [d for d, m in zip(global_stride, mask) if m]
    unmasked_stride = [d for d, m in zip(global_stride, mask) if not m]

    group_size = running_products(masked_shape)[-1]
    num_of_group = world_size // group_size

    groups = []
    for group_index in range(num_of_group):
        # Unmasked coordinates fix the group's offset in the global rank space.
        offset = dot(unravel(group_index, unmasked_shape), unmasked_stride)
        groups.append(
            [offset + dot(unravel(rank_in_group, masked_shape), masked_stride) for rank_in_group in range(group_size)]
        )
    return groups
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
class RankGenerator(object):
    """Computes global-rank groupings for orthogonal parallel axes (tp/dp/pp/cp).

    The constructor normalizes the rank-initialization ``order`` string; any
    axis missing from it must have size 1 and is appended at the end.
    """

    def __init__(self, tp: int, dp: int, pp: int, cp: int, order: str) -> None:
        self.tp = tp
        self.dp = dp
        self.pp = pp
        self.cp = cp
        self.world_size = tp * dp * pp * cp

        self.name_to_size = {"tp": self.tp, "pp": self.pp, "dp": self.dp, "cp": self.cp}
        order = order.lower()
        for name in self.name_to_size:
            if name in order:
                continue
            if self.name_to_size[name] != 1:
                # An axis with size > 1 must appear in the order string.
                raise RuntimeError(
                    f"The size of ({name}) is ({self.name_to_size[name]}), but you haven't specified the order ({order})."
                )
            # Size-1 axes may be omitted; append them so every axis has a slot.
            order = order + "-" + name

        self.order = order
        self.ordered_size = [self.name_to_size[token] for token in order.split("-")]

    def get_mask(self, order: str, token: str):
        """Return a boolean mask over ``order``'s axes marking those named in ``token``."""
        axis_names = order.split("-")
        mask = [False] * len(axis_names)
        for name in token.split("-"):
            mask[axis_names.index(name)] = True
        return mask

    def get_ranks(self, token):
        """Get rank group by input token.

        Arguments:
            token (str):
                Specify the ranks type that want to get. If we want
                to obtain multiple parallel types, we can use a hyphen
                '-' to separate them. For example, if we want to obtain
                the TP_DP group, the token should be 'tp-dp'.
        """
        return generate_masked_orthogonal_rank_groups(self.world_size, self.ordered_size, self.get_mask(self.order, token))
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def initialize_model_parallel(
|
| 237 |
+
tp_size: int = 1,
|
| 238 |
+
pp_size: int = 1,
|
| 239 |
+
cp_size: int = 1,
|
| 240 |
+
nccl_communicator_config_path: Optional[str] = None,
|
| 241 |
+
distributed_timeout_minutes: int = 30,
|
| 242 |
+
order: str = "tp-cp-pp-dp",
|
| 243 |
+
) -> None:
|
| 244 |
+
"""Initialize model data parallel groups.
|
| 245 |
+
Borrow from: https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
|
| 246 |
+
|
| 247 |
+
Args:
|
| 248 |
+
tp_size (int, default = 1):
|
| 249 |
+
The number of GPUs to split individual tensors across.
|
| 250 |
+
|
| 251 |
+
pp_size (int, default = 1):
|
| 252 |
+
The number of tensor parallel GPU groups to split the
|
| 253 |
+
Transformer layers across. For example, if tp_size is 4 and
|
| 254 |
+
pp_size is 2, the model will be split into 2 groups of 4 GPUs.
|
| 255 |
+
|
| 256 |
+
cp_size (int, default = 1):
|
| 257 |
+
The number of tensor parallel GPU groups to split the
|
| 258 |
+
network input sequence length across. Compute of attention
|
| 259 |
+
module requires tokens of full sequence length, so GPUs
|
| 260 |
+
in a context parallel group need to communicate with each
|
| 261 |
+
other to exchange information of other sequence chunks.
|
| 262 |
+
Each GPU and its counterparts in other tensor parallel
|
| 263 |
+
groups compose a context parallel group.
|
| 264 |
+
|
| 265 |
+
For example, assume we have 8 GPUs, if tensor model parallel
|
| 266 |
+
size is 4 and context parallel size is 2, the network input
|
| 267 |
+
will be split into two sequence chunks, which are processed
|
| 268 |
+
by 2 different groups of 4 GPUs. One chunk is processed by
|
| 269 |
+
GPU0-3, the other chunk is processed by GPU4-7. Four groups
|
| 270 |
+
are build to do context parallel communications: [GPU0, GPU4],
|
| 271 |
+
[GPU1, GPU5], [GPU2, GPU6], and [GPU3, GPU7].
|
| 272 |
+
|
| 273 |
+
Context parallelism partitions sequence length, so it has no
|
| 274 |
+
impact on weights, which means weights are duplicated among
|
| 275 |
+
GPUs in a context parallel group. Hence, weight gradients
|
| 276 |
+
all-reduce is required in backward. For simplicity, we piggyback
|
| 277 |
+
GPUs of context parallelism on data parallel group for
|
| 278 |
+
weight gradient all-reduce.
|
| 279 |
+
|
| 280 |
+
nccl_communicator_config_path (str, default = None):
|
| 281 |
+
Path to the yaml file of NCCL communicator configurations.
|
| 282 |
+
`min_ctas`, `max_ctas`, and `cga_cluster_size` can be set
|
| 283 |
+
for each communicator.
|
| 284 |
+
|
| 285 |
+
distributed_timeout_minutes (int, default = 30): Timeout, in
|
| 286 |
+
minutes,for operations executed against distributed
|
| 287 |
+
process groups. See PyTorch documentation at
|
| 288 |
+
https://pytorch.org/docs/stable/distributed.html for
|
| 289 |
+
caveats.
|
| 290 |
+
|
| 291 |
+
order (str, default=tp-dp-pp):
|
| 292 |
+
The rank initialization order of parallelism. Now we support
|
| 293 |
+
tp-dp-pp and tp-pp-dp orders.
|
| 294 |
+
|
| 295 |
+
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
|
| 296 |
+
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
|
| 297 |
+
the model pipeline. The present function will
|
| 298 |
+
create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
|
| 299 |
+
and 8 data-parallel groups as:
|
| 300 |
+
8 data_parallel groups:
|
| 301 |
+
[g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
|
| 302 |
+
8 tensor model-parallel groups:
|
| 303 |
+
[g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
|
| 304 |
+
4 pipeline model-parallel groups:
|
| 305 |
+
[g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
|
| 306 |
+
Note that for efficiency, the caller should make sure adjacent ranks
|
| 307 |
+
are on the same DGX box. For example if we are using 2 DGX-1 boxes
|
| 308 |
+
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
|
| 309 |
+
ranks 8 to 15 belong to the second box.
|
| 310 |
+
|
| 311 |
+
"""
|
| 312 |
+
# Get world size and rank. Ensure some consistencies.
|
| 313 |
+
assert torch.distributed.is_initialized()
|
| 314 |
+
world_size: int = torch.distributed.get_world_size()
|
| 315 |
+
if world_size % (tp_size * pp_size * cp_size) != 0:
|
| 316 |
+
raise RuntimeError(
|
| 317 |
+
f"world_size ({world_size}) is not divisible by tp_size "
|
| 318 |
+
f"({tp_size}) x pp_size ({pp_size}) "
|
| 319 |
+
f"x cp_size ({cp_size})"
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
nccl_comm_cfgs = {}
|
| 323 |
+
if nccl_communicator_config_path is not None:
|
| 324 |
+
try:
|
| 325 |
+
import yaml
|
| 326 |
+
except ImportError:
|
| 327 |
+
raise RuntimeError("Cannot import `yaml`. Setting custom nccl communicator configs " "requires the yaml package.")
|
| 328 |
+
|
| 329 |
+
with open(nccl_communicator_config_path, "r") as stream:
|
| 330 |
+
nccl_comm_cfgs = yaml.safe_load(stream)
|
| 331 |
+
|
| 332 |
+
dp_size: int = world_size // (tp_size * pp_size * cp_size)
|
| 333 |
+
rank = torch.distributed.get_rank()
|
| 334 |
+
rank_generator = RankGenerator(tp=tp_size, dp=dp_size, pp=pp_size, cp=cp_size, order=order)
|
| 335 |
+
timeout = timedelta(minutes=distributed_timeout_minutes)
|
| 336 |
+
|
| 337 |
+
# Build the data-parallel groups.
|
| 338 |
+
global _DATA_PARALLEL_GROUP
|
| 339 |
+
global _DATA_PARALLEL_GLOBAL_RANKS
|
| 340 |
+
global _DATA_PARALLEL_GROUP_WITH_CP
|
| 341 |
+
global _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP
|
| 342 |
+
assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized"
|
| 343 |
+
|
| 344 |
+
for ranks in rank_generator.get_ranks("dp"):
|
| 345 |
+
group = torch.distributed.new_group(ranks, timeout=timeout, pg_options=_get_nccl_options("dp", nccl_comm_cfgs))
|
| 346 |
+
if rank in ranks:
|
| 347 |
+
_DATA_PARALLEL_GROUP = group
|
| 348 |
+
_DATA_PARALLEL_GLOBAL_RANKS = ranks
|
| 349 |
+
for ranks_with_cp in rank_generator.get_ranks("dp-cp"):
|
| 350 |
+
group_with_cp = torch.distributed.new_group(
|
| 351 |
+
ranks_with_cp, timeout=timeout, pg_options=_get_nccl_options("dp_cp", nccl_comm_cfgs)
|
| 352 |
+
)
|
| 353 |
+
if rank in ranks_with_cp:
|
| 354 |
+
_DATA_PARALLEL_GROUP_WITH_CP = group_with_cp
|
| 355 |
+
_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = ranks_with_cp
|
| 356 |
+
|
| 357 |
+
# Build the context-parallel groups.
|
| 358 |
+
global _CONTEXT_PARALLEL_GROUP
|
| 359 |
+
global _CONTEXT_PARALLEL_GLOBAL_RANKS
|
| 360 |
+
assert _CONTEXT_PARALLEL_GROUP is None, "context parallel group is already initialized"
|
| 361 |
+
for ranks in rank_generator.get_ranks("cp"):
|
| 362 |
+
group = torch.distributed.new_group(ranks, timeout=timeout, pg_options=_get_nccl_options("cp", nccl_comm_cfgs))
|
| 363 |
+
if rank in ranks:
|
| 364 |
+
_CONTEXT_PARALLEL_GROUP = group
|
| 365 |
+
_CONTEXT_PARALLEL_GLOBAL_RANKS = ranks
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
# Build the model-parallel groups.
|
| 369 |
+
global _MODEL_PARALLEL_GROUP
|
| 370 |
+
assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized"
|
| 371 |
+
for ranks in rank_generator.get_ranks("tp-pp"):
|
| 372 |
+
group = torch.distributed.new_group(ranks, timeout=timeout, pg_options=_get_nccl_options("mp", nccl_comm_cfgs))
|
| 373 |
+
if rank in ranks:
|
| 374 |
+
_MODEL_PARALLEL_GROUP = group
|
| 375 |
+
|
| 376 |
+
# Build the tensor model-parallel groups.
|
| 377 |
+
global _TENSOR_MODEL_PARALLEL_GROUP
|
| 378 |
+
global _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS
|
| 379 |
+
assert _TENSOR_MODEL_PARALLEL_GROUP is None, "tensor model parallel group is already initialized"
|
| 380 |
+
for ranks in rank_generator.get_ranks("tp"):
|
| 381 |
+
group = torch.distributed.new_group(ranks, timeout=timeout, pg_options=_get_nccl_options("tp", nccl_comm_cfgs))
|
| 382 |
+
if rank in ranks:
|
| 383 |
+
_TENSOR_MODEL_PARALLEL_GROUP = group
|
| 384 |
+
_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = ranks
|
| 385 |
+
|
| 386 |
+
# Build the tensor + context parallel groups.
|
| 387 |
+
global _TENSOR_MODEL_PARALLEL_GROUP_WITH_CP
|
| 388 |
+
global _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP
|
| 389 |
+
assert (
|
| 390 |
+
_TENSOR_MODEL_PARALLEL_GROUP_WITH_CP is None
|
| 391 |
+
), "tensor model parallel group with context parallel is already initialized"
|
| 392 |
+
for ranks in rank_generator.get_ranks("tp-cp"):
|
| 393 |
+
group = torch.distributed.new_group(ranks, timeout=timeout, pg_options=_get_nccl_options("tp_cp", nccl_comm_cfgs))
|
| 394 |
+
if rank in ranks:
|
| 395 |
+
_TENSOR_MODEL_PARALLEL_GROUP_WITH_CP = group
|
| 396 |
+
_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP = ranks
|
| 397 |
+
|
| 398 |
+
# Build the pipeline model-parallel groups
|
| 399 |
+
global _PIPELINE_MODEL_PARALLEL_GROUP
|
| 400 |
+
global _PIPELINE_GLOBAL_RANKS
|
| 401 |
+
assert _PIPELINE_MODEL_PARALLEL_GROUP is None, "pipeline model parallel group is already initialized"
|
| 402 |
+
for ranks in rank_generator.get_ranks("pp"):
|
| 403 |
+
group = torch.distributed.new_group(ranks, timeout=timeout, pg_options=_get_nccl_options("pp", nccl_comm_cfgs))
|
| 404 |
+
if rank in ranks:
|
| 405 |
+
_PIPELINE_MODEL_PARALLEL_GROUP = group
|
| 406 |
+
_PIPELINE_GLOBAL_RANKS = ranks
|
| 407 |
+
|
| 408 |
+
# Build the tensor + data parallel groups.
|
| 409 |
+
global _TENSOR_AND_DATA_PARALLEL_GROUP
|
| 410 |
+
global _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP
|
| 411 |
+
assert _TENSOR_AND_DATA_PARALLEL_GROUP is None, "Tensor + data parallel group is already initialized"
|
| 412 |
+
for ranks in rank_generator.get_ranks("tp-cp-dp"):
|
| 413 |
+
group = torch.distributed.new_group(ranks, timeout=timeout, pg_options=_get_nccl_options("tp_cp_dp", nccl_comm_cfgs))
|
| 414 |
+
if rank in ranks:
|
| 415 |
+
_TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = group
|
| 416 |
+
for ranks in rank_generator.get_ranks("tp-dp"):
|
| 417 |
+
group = torch.distributed.new_group(ranks, timeout=timeout, pg_options=_get_nccl_options("tp_dp", nccl_comm_cfgs))
|
| 418 |
+
if rank in ranks:
|
| 419 |
+
_TENSOR_AND_DATA_PARALLEL_GROUP = group
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
def is_initialized():
    """Useful for code segments that may be accessed with or without mpu initialization"""
    # The data-parallel group is always created during initialization, so its
    # presence serves as a proxy for "parallel state has been set up".
    return _DATA_PARALLEL_GROUP is not None
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
def is_unitialized() -> bool:
    """Return True when the parallel state has *not* been initialized.

    Deprecated. Use is_initialized instead.
    """
    warnings.warn("is_unitialized is deprecated, use is_initialized instead", DeprecationWarning)
    initialized = is_initialized()
    return not initialized
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
def model_parallel_is_initialized():
    """Check if model and data parallel groups are initialized."""
    # All three groups are created together, so all must be non-None.
    required_groups = (
        _TENSOR_MODEL_PARALLEL_GROUP,
        _PIPELINE_MODEL_PARALLEL_GROUP,
        _DATA_PARALLEL_GROUP,
    )
    return all(group is not None for group in required_groups)
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
def get_model_parallel_group():
    """Get the model parallel group the caller rank belongs to.

    Raises:
        AssertionError: if the model parallel groups have not been created yet.
    """
    assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized"
    return _MODEL_PARALLEL_GROUP
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
def get_tp_group(check_initialized=True, with_context_parallel=False):
    """Get the tensor model parallel group the caller rank belongs to.

    Args:
        check_initialized: when True, assert that the requested group exists;
            when False, return the group (possibly None) without asserting.
        with_context_parallel: when True, return the combined tensor+context
            parallel group instead of the plain tensor parallel group.
    """
    # BUGFIX: previously the per-branch asserts ran unconditionally, so
    # check_initialized=False had no effect. Asserts are now gated on it.
    if with_context_parallel:
        if check_initialized:
            assert (
                _TENSOR_MODEL_PARALLEL_GROUP_WITH_CP is not None
            ), "tensor model parallel group with context parallel combined is not initialized"
        return _TENSOR_MODEL_PARALLEL_GROUP_WITH_CP
    if check_initialized:
        assert _TENSOR_MODEL_PARALLEL_GROUP is not None, "tensor model parallel group is not initialized"
    return _TENSOR_MODEL_PARALLEL_GROUP
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
def get_pp_group():
    """Get the pipeline model parallel group the caller rank belongs to.

    Raises:
        AssertionError: if the pipeline parallel group has not been created yet.
    """
    assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, "pipeline_model parallel group is not initialized"
    return _PIPELINE_MODEL_PARALLEL_GROUP
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
def get_dp_group(with_context_parallel=False):
    """Get the data parallel group the caller rank belongs to.

    Args:
        with_context_parallel: when True, return the combined data+context
            parallel group instead of the plain data parallel group.
    """
    if with_context_parallel:
        assert (
            _DATA_PARALLEL_GROUP_WITH_CP is not None
        ), "data parallel group with context parallel combined is not initialized"
        return _DATA_PARALLEL_GROUP_WITH_CP
    else:
        assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized"
        return _DATA_PARALLEL_GROUP
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
def get_cp_group(check_initialized=True):
    """Get the context parallel group the caller rank belongs to.

    Args:
        check_initialized: when True, assert the group has been created;
            when False, the return value may be None.
    """
    if check_initialized:
        assert _CONTEXT_PARALLEL_GROUP is not None, "context parallel group is not initialized"
    return _CONTEXT_PARALLEL_GROUP
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
def get_cp_extra_group(check_initialized=True):
    """Get the extra context parallel group the caller rank belongs to.

    Args:
        check_initialized: when True, assert the group has been created;
            when False, the return value may be None.
    """
    if check_initialized:
        assert _CONTEXT_PARALLEL_EXTRA_GROUP is not None, "context parallel extra group is not initialized"
    return _CONTEXT_PARALLEL_EXTRA_GROUP
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
def get_tp_world_size(with_context_parallel=False):
    """Return world size for the tensor model parallel group.

    Args:
        with_context_parallel: when True, measure the combined tensor+context
            parallel group instead.
    """
    return torch.distributed.get_world_size(group=get_tp_group(with_context_parallel=with_context_parallel))
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
def get_pp_world_size():
    """Return world size for the pipeline model parallel group."""
    # Delegates the "group initialized" check to get_pp_group().
    return torch.distributed.get_world_size(group=get_pp_group())
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
def get_tp_rank(with_context_parallel=False):
    """Return my rank for the tensor model parallel group.

    Args:
        with_context_parallel: when True, report the rank within the combined
            tensor+context parallel group instead.
    """
    return torch.distributed.get_rank(group=get_tp_group(with_context_parallel=with_context_parallel))
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
def get_pp_rank():
    """Return my rank for the pipeline model parallel group."""
    # Delegates the "group initialized" check to get_pp_group().
    return torch.distributed.get_rank(group=get_pp_group())
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
def is_pipeline_first_stage():
    """Return True if in the first pipeline model-parallel stage, False otherwise."""
    # Rank 0 within the pipeline group is the first stage by construction.
    return get_pp_rank() == 0
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
def is_pipeline_last_stage():
    """Return True if in the last pipeline model-parallel stage, False otherwise."""
    # The last stage is the highest-indexed rank within the pipeline group.
    return get_pp_rank() == (get_pp_world_size() - 1)
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
def get_tensor_model_parallel_src_rank(with_context_parallel=False):
    """Calculate the global rank corresponding to the first local rank
    in the tensor model parallel group.

    Args:
        with_context_parallel: when True, use the combined tensor+context
            parallel group's rank list instead.
    """
    assert _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS is not None, "Tensor model parallel group is not initialized"
    if with_context_parallel:
        assert (
            _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP is not None
        ), "Tensor model parallel group with context parallel combined is not initialized"
        return _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP[0]
    else:
        return _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS[0]
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
def get_tensor_model_parallel_ranks(with_context_parallel=False):
    """Return all global ranks for the tensor model parallel group.

    Args:
        with_context_parallel: when True, return the rank list of the combined
            tensor+context parallel group instead.
    """
    if with_context_parallel:
        assert (
            _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP is not None
        ), "Tensor model parallel group with context parallel combined is not initialized"
        return _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP
    else:
        assert _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS is not None, "Tensor model parallel group is not initialized"
        return _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS
|
| 548 |
+
|
| 549 |
+
|
| 550 |
+
def get_tensor_model_parallel_last_rank(with_context_parallel=False):
    """Calculate the global rank corresponding to the last local rank
    in the tensor model parallel group.

    (Docstring fixed: it previously said "first local rank", but the code
    returns the final entry of the rank list.)
    """
    assert _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS is not None, "Tensor model parallel group is not initialized"
    if with_context_parallel:
        assert (
            _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP is not None
        ), "Tensor model parallel group with context parallel combined is not initialized"
        return _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP[-1]
    else:
        return _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS[-1]
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
def get_pipeline_model_parallel_first_rank():
    """Return the global rank of the first process in the pipeline for the
    current tensor parallel group"""
    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
    # _PIPELINE_GLOBAL_RANKS is ordered by pipeline stage, so index 0 is stage 0.
    return _PIPELINE_GLOBAL_RANKS[0]
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
def get_pipeline_model_parallel_last_rank():
    """Return the global rank of the last process in the pipeline for the
    current tensor parallel group"""
    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
    # The final pipeline stage sits at index (world_size - 1) of the rank list.
    last_rank_local = get_pp_world_size() - 1
    return _PIPELINE_GLOBAL_RANKS[last_rank_local]
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
def get_pipeline_model_parallel_next_rank():
    """Return the global rank that follows the caller in the pipeline"""
    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
    rank_in_pipeline = get_pp_rank()
    world_size = get_pp_world_size()
    # Modulo wraps the last stage back to the first stage.
    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
def get_pipeline_model_parallel_prev_rank():
    """Return the global rank that precedes the caller in the pipeline"""
    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
    rank_in_pipeline = get_pp_rank()
    world_size = get_pp_world_size()
    # Modulo wraps the first stage back to the last stage.
    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
def get_dp_world_size(with_context_parallel=False):
    """Return world size for the data parallel group.

    Returns 0 when torch.distributed is unavailable or not yet initialized.
    """
    dist_ready = torch.distributed.is_available() and torch.distributed.is_initialized()
    if not dist_ready:
        return 0
    group = get_dp_group(with_context_parallel=with_context_parallel)
    return torch.distributed.get_world_size(group=group)
|
| 600 |
+
|
| 601 |
+
|
| 602 |
+
def get_dp_rank(with_context_parallel=False):
    """Return my rank for the data parallel group.

    Returns 0 when torch.distributed is unavailable or not yet initialized.
    """
    dist_ready = torch.distributed.is_available() and torch.distributed.is_initialized()
    if not dist_ready:
        return 0
    group = get_dp_group(with_context_parallel=with_context_parallel)
    return torch.distributed.get_rank(group=group)
|
| 608 |
+
|
| 609 |
+
|
| 610 |
+
def get_cp_world_size():
    """Return world size for the context parallel group.

    Returns 0 when torch.distributed is unavailable or not yet initialized.
    """
    dist_ready = torch.distributed.is_available() and torch.distributed.is_initialized()
    if not dist_ready:
        return 0
    return torch.distributed.get_world_size(group=get_cp_group())
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
def get_cp_rank():
    """Return my rank for the context parallel group.

    Returns 0 when torch.distributed is unavailable or not yet initialized.
    """
    dist_ready = torch.distributed.is_available() and torch.distributed.is_initialized()
    if not dist_ready:
        return 0
    return torch.distributed.get_rank(group=get_cp_group())
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
def destroy_model_parallel():
    """Set the groups to none.

    Resets every module-level parallel-state global so the process can be
    re-initialized. NOTE(review): this drops Python references to the process
    groups; it does not call torch.distributed.destroy_process_group on them.
    """
    global _MODEL_PARALLEL_GROUP
    _MODEL_PARALLEL_GROUP = None
    global _TENSOR_MODEL_PARALLEL_GROUP
    _TENSOR_MODEL_PARALLEL_GROUP = None
    global _TENSOR_MODEL_PARALLEL_GROUP_WITH_CP
    _TENSOR_MODEL_PARALLEL_GROUP_WITH_CP = None
    global _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP
    _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP = None
    global _PIPELINE_MODEL_PARALLEL_GROUP
    _PIPELINE_MODEL_PARALLEL_GROUP = None
    global _DATA_PARALLEL_GROUP
    _DATA_PARALLEL_GROUP = None
    global _TENSOR_AND_DATA_PARALLEL_GROUP
    _TENSOR_AND_DATA_PARALLEL_GROUP = None
    global _PIPELINE_GLOBAL_RANKS
    _PIPELINE_GLOBAL_RANKS = None
    global _DATA_PARALLEL_GLOBAL_RANKS
    _DATA_PARALLEL_GLOBAL_RANKS = None
    global _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS
    _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = None
    global _CONTEXT_PARALLEL_GROUP
    _CONTEXT_PARALLEL_GROUP = None
    global _CONTEXT_PARALLEL_GLOBAL_RANKS
    _CONTEXT_PARALLEL_GLOBAL_RANKS = None
    global _CONTEXT_PARALLEL_EXTRA_GROUP
    _CONTEXT_PARALLEL_EXTRA_GROUP = None
    global _DATA_PARALLEL_GROUP_WITH_CP
    _DATA_PARALLEL_GROUP_WITH_CP = None
    global _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP
    _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = None
    global _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP
    _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = None
|
inference/infra/distributed/utils.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
|
| 17 |
+
from .parallel_state import get_tp_rank, get_tp_world_size
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def is_last_rank():
    """Return True on the highest-numbered global rank."""
    world = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()
    return rank == world - 1
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def is_last_tp_cp_rank():
    """Return True on the last rank of the combined tensor+context parallel group."""
    return get_tp_rank(with_context_parallel=True) == get_tp_world_size(with_context_parallel=True) - 1
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def get_world_size():
    """Return the global world size, or 1 when torch.distributed is unavailable
    or has not been initialized (single-process fallback)."""
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        return 1
    return torch.distributed.get_world_size()
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def get_device(local_rank=None):
    """Return the torch.device that matches the active distributed backend.

    Args:
        local_rank: CUDA device index to target for the "nccl" backend; when
            None, the generic "cuda" device (current device) is returned.

    Returns:
        torch.device: "cuda"/"cuda:{local_rank}" for nccl, "cpu" for gloo.

    Raises:
        RuntimeError: if the backend is neither "nccl" nor "gloo".
    """
    backend = torch.distributed.get_backend()
    if backend == "nccl":
        if local_rank is None:
            device = torch.device("cuda")
        else:
            device = torch.device(f"cuda:{local_rank}")
    elif backend == "gloo":
        device = torch.device("cpu")
    else:
        # BUGFIX: previously raised a bare RuntimeError with no message.
        raise RuntimeError(f"Unsupported distributed backend: {backend!r} (expected 'nccl' or 'gloo')")
    return device
|
inference/infra/parallelism/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from .ulysses_scheduler import ulysses_scheduler
|
| 16 |
+
|
| 17 |
+
__all__ = [
|
| 18 |
+
# context parallel
|
| 19 |
+
"ulysses_scheduler",
|
| 20 |
+
]
|
inference/infra/parallelism/all_to_all_primitive.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from typing import List, Tuple, Union
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.distributed as dist
|
| 19 |
+
from einops import rearrange
|
| 20 |
+
|
| 21 |
+
from inference.utils import divide
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class FakeHandle:
    """Drop-in stand-in for an async torch.distributed work handle.

    Returned when no communication actually happened (single rank / no group),
    so callers can unconditionally call ``.wait()``.
    """

    def __init__(self):
        # Nothing to track: there is no pending communication.
        pass

    def wait(self):
        """No-op; completes immediately."""
        pass
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def scatter_head_gather_seqlen(
    tensor: torch.Tensor, split_sizes: List[int] = None, group: dist.ProcessGroup = None, async_op: bool = True
) -> Union[torch.Tensor, Tuple[torch.Tensor, Union[dist.Work, FakeHandle]]]:
    """
    Scatter head_number and gather seq_len, for example:
        input: (seq_len, cp * hn, hd)
        output: (seq_len * cp, hn, hd)
    NOTE: seq_len of input maybe not equal, which depends on split_sizes[rank]

    Returns the output tensor; when ``async_op`` is True, returns
    ``(output, handle)`` and ``handle.wait()`` must be called before use.
    """
    if group is None or dist.get_world_size(group) == 1:
        # BUGFIX: previously always returned a tuple here, even when
        # async_op=False, while the real sync path returns a bare tensor.
        return (tensor, FakeHandle()) if async_op else tensor
    group_world_size = dist.get_world_size(group)
    if split_sizes is None:
        split_sizes = [tensor.shape[0]] * group_world_size

    _, hn, _ = tensor.shape
    if group_world_size % hn == 0 and group_world_size != hn:
        # Fewer heads than ranks: replicate heads so each rank receives one.
        tensor = torch.repeat_interleave(tensor, repeats=divide(group_world_size, hn), dim=1).contiguous()
    assert tensor.is_contiguous()
    input_split_sizes = [tensor.shape[0]] * group_world_size
    # Lay out per-rank head shards contiguously along dim 0 for all_to_all.
    inp = rearrange(tensor, "seq (cp hn) hd -> (cp seq) hn hd", cp=group_world_size).contiguous()
    output = torch.empty([sum(split_sizes), *inp.shape[1:]], device=inp.device, dtype=inp.dtype)
    if async_op:
        handle = dist.all_to_all_single(
            output, inp, output_split_sizes=split_sizes, input_split_sizes=input_split_sizes, group=group, async_op=True
        )
        return output, handle
    else:
        dist.all_to_all_single(
            output, inp, output_split_sizes=split_sizes, input_split_sizes=input_split_sizes, group=group, async_op=False
        )
        return output
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def scatter_seqlen_gather_head(
    tensor: torch.Tensor, split_sizes: List[int] = None, group: dist.ProcessGroup = None, async_op: bool = True
) -> Union[torch.Tensor, Tuple[torch.Tensor, Union[dist.Work, FakeHandle]]]:
    """
    Scatter seq_len and gather head_number, for example:
        input: (seq_len * cp, hn, hd)
        output: (seq_len, cp * hn, hd)
    NOTE: seq_len of output maybe not equal, which depends on split_sizes[rank]
    NOTE: rearrange the tensor after communication: (cp, seq, hn, hd) -> (seq, cp * hn, hd)

    Returns the output tensor; when ``async_op`` is True, returns
    ``(output, handle)`` and ``handle.wait()`` must be called before use.
    """
    if group is None or dist.get_world_size(group) == 1:
        # BUGFIX: the original `return tensor, FakeHandle() if async_op else tensor`
        # parsed as `return tensor, (FakeHandle() if async_op else tensor)`,
        # yielding (tensor, tensor) when async_op=False. Parenthesized correctly.
        return (tensor, FakeHandle()) if async_op else tensor
    group_world_size = dist.get_world_size(group)
    if split_sizes is None:
        assert (
            tensor.shape[0] % group_world_size == 0
        ), f"tensor.shape[0] {tensor.shape[0]} % group_world_size {group_world_size} != 0"
        split_sizes = [tensor.shape[0] // group_world_size] * group_world_size
    assert tensor.is_contiguous()
    assert tensor.dim() == 3, f"tensor must be 3D, but got {tensor.dim()}D"
    # Hoisted: this rank's local seq length, used for both output size and splits.
    local_split = split_sizes[dist.get_rank(group)]
    output = torch.empty(
        [group_world_size * local_split, *tensor.shape[1:]], device=tensor.device, dtype=tensor.dtype
    )
    output_split_sizes = [local_split] * group_world_size
    if async_op:
        handle = dist.all_to_all_single(
            output, tensor, output_split_sizes=output_split_sizes, input_split_sizes=split_sizes, group=group, async_op=True
        )
        return output, handle
    else:
        dist.all_to_all_single(
            output, tensor, output_split_sizes=output_split_sizes, input_split_sizes=split_sizes, group=group, async_op=False
        )
        return output
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def batch_scatter_head_gather_seqlen(
    inputs: List[torch.Tensor], split_sizes: List[int] = None, group: dist.ProcessGroup = None
) -> List[torch.Tensor]:
    """
    Batch scatter head_number and gather seq_len, for example:
        inputs[i] input: (seq_len_i, cp * hn_i, hd)
        outputs[i] output: (seq_len_i * cp, hn_i, hd)
    NOTE: seq_len of inputs maybe not equal across ranks, which depends on split_sizes[rank]
    NOTE: fuse along head dim before communication, and split back after
    """
    if group is None or dist.get_world_size(group) == 1:
        return inputs
    rank = dist.get_rank(group)
    group_world_size = dist.get_world_size(group)
    if split_sizes is None:
        split_sizes = [inputs[0].shape[0]] * group_world_size
    assert all(
        inp.shape[0] == split_sizes[rank] for inp in inputs
    ), f"inputs[0].shape[0] {inputs[0].shape[0]} != split_sizes[rank] {split_sizes[rank]}"
    assert all(inp.dim() == 3 for inp in inputs), f"inputs[0].dim() {inputs[0].dim()} != 3"
    # BUGFIX: build a local list instead of overwriting `inputs[idx]` in place,
    # which mutated the caller's list.
    prepared = []
    for inp in inputs:
        _, hn, _ = inp.shape
        if group_world_size % hn == 0 and group_world_size != hn:
            # Fewer heads than ranks: replicate heads so each rank receives one.
            inp = torch.repeat_interleave(inp, repeats=divide(group_world_size, hn), dim=1)
        prepared.append(rearrange(inp, "seq (cp hn) hd -> (cp seq) hn hd", cp=group_world_size).contiguous())

    head_split_number = [t.shape[1] for t in prepared]
    # Fuse along the head dim so a single all_to_all moves every tensor at once.
    fused_input = torch.cat(prepared, dim=1).contiguous()
    input_split_sizes = [fused_input.shape[0] // group_world_size] * group_world_size

    fused_output = torch.empty([sum(split_sizes), *fused_input.shape[1:]], device=fused_input.device, dtype=fused_input.dtype)
    dist.all_to_all_single(
        fused_output,
        fused_input,
        output_split_sizes=split_sizes,
        input_split_sizes=input_split_sizes,
        group=group,
        async_op=False,
    )
    # BUGFIX: torch.split returns a tuple; convert to match the declared
    # List[torch.Tensor] return type.
    return list(torch.split(fused_output, head_split_number, dim=1))
|
inference/infra/parallelism/gather_scatter_primitive.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from functools import partial
|
| 16 |
+
from typing import List, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.distributed as dist
|
| 20 |
+
from torch.utils._pytree import tree_map
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class Metadata:
    """Record describing one tensor's dtype, element count, rank, and shape."""

    def __init__(self, dtype: torch.dtype, numel: int, ndim: int, shape: List[int]):
        self.dtype = dtype
        self.numel = numel
        self.ndim = ndim
        self.shape = shape

    def __repr__(self):
        fields = f"dtype={self.dtype}, numel={self.numel}, ndim={self.ndim}, shape={self.shape}"
        return f"Metadata({fields})"
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _gather_metadata(tensor_list: List[torch.Tensor], group: dist.ProcessGroup) -> List[List[Metadata]]:
    """All-gather per-tensor metadata (dtype, numel, ndim, shape) from every rank.

    Each rank encodes its local tensors as int32 records of the form
    ``[dtype_code, numel, ndim, *shape]``, all ranks exchange first the record
    lengths and then the variable-length payloads, and every rank decodes the
    result back into ``Metadata`` objects.

    Args:
        tensor_list: This rank's local tensors (may be empty).
        group: Process group to gather over.

    Returns:
        One ``List[Metadata]`` per rank in ``group`` (index == rank in group).
    """
    # Return value is discarded; this call is a no-op beyond validating `group`.
    dist.get_rank(group)
    world_size = dist.get_world_size(group)

    # Sanity-check that this process is bound to the CUDA device implied by its
    # global rank (assumes one process per local GPU).
    local_rank = torch.distributed.get_rank() % torch.cuda.device_count()
    assert (
        local_rank == torch.cuda.current_device()
    ), f"local_rank {local_rank} != current_device {torch.cuda.current_device()}"
    # Fall back to the default CUDA device when this rank holds no tensors.
    device = tensor_list[0].device if len(tensor_list) > 0 else torch.device("cuda")

    # ========== Step 1: flatten local tensor list ==========

    # Metadata: [dtype_code, numel, ndim, *shape]
    local_metadata = []

    # Encode dtypes as small ints so they can travel inside an int32 tensor.
    dtype_map = {torch.float32: 0, torch.float16: 1, torch.bfloat16: 2, torch.int32: 3, torch.int64: 4, torch.uint8: 5}
    reverse_dtype_map = {v: k for k, v in dtype_map.items()}

    for t in tensor_list:
        dtype_code = dtype_map[t.dtype]
        shape = list(t.shape)
        numel = t.numel()
        local_metadata.append(torch.tensor([dtype_code, numel, len(shape)] + shape, dtype=torch.int32, device=device))

    if local_metadata:
        local_metadata_tensor = torch.cat(local_metadata)
    else:
        # Empty payload: this rank contributes zero records.
        local_metadata_tensor = torch.empty(0, dtype=torch.int32, device=device)
    local_metadata_tensor = local_metadata_tensor.contiguous()
    local_metadata_len = torch.tensor([local_metadata_tensor.numel()], dtype=torch.int32, device=device)

    # ========== Step 2: all_gather metadata lengths ==========
    metadata_lens = [torch.empty_like(local_metadata_len) for _ in range(world_size)]
    dist.all_gather(metadata_lens, local_metadata_len, group)

    # ========== Step 3: all_gather metadata payloads (with cpu tensor) ==========
    # Payload lengths differ per rank, so each output buffer is sized individually.
    metadata_lists = [torch.empty(m.item(), dtype=torch.int32, device=device) for m in metadata_lens]
    dist.all_gather(metadata_lists, local_metadata_tensor, group)

    # ========== Step 4: decode metadata and reconstruct tensor list ==========
    result = []
    for metadata_list in metadata_lists:
        offset = 0
        local_metadata = []
        while offset < metadata_list.numel():
            # Each record is laid out as [dtype_code, numel, ndim, *shape].
            dtype_code = metadata_list[offset].item()
            numel = metadata_list[offset + 1].item()
            ndim = metadata_list[offset + 2].item()
            shape = metadata_list[offset + 3 : offset + 3 + ndim].tolist()
            offset += 3 + ndim

            local_metadata.append(Metadata(reverse_dtype_map[dtype_code], numel, ndim, shape))
        result.append(local_metadata)

    return result
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def _get_dtype_and_assert_consistency(metadata_lists: List[List[Metadata]]):
|
| 92 |
+
dtype_set = set()
|
| 93 |
+
for metadata_list in metadata_lists:
|
| 94 |
+
for metadata in metadata_list:
|
| 95 |
+
dtype_set.add(metadata.dtype)
|
| 96 |
+
assert len(dtype_set) == 1, f"Metadata lists are not consistent: {dtype_set}"
|
| 97 |
+
return dtype_set.pop()
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def _get_numel_for_each_rank(metadata_lists: List[List[Metadata]]) -> List[int]:
|
| 101 |
+
return [sum(meta.numel for meta in metadata_list) for metadata_list in metadata_lists]
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def gather_arbitrary_tensor_list(tensor_list: List[torch.Tensor], group: dist.ProcessGroup) -> List[torch.Tensor]:
    """
    Magic gather primitive. Provide the following features:
    1. Support tensor list with different length for each rank.
    2. Support arbitrary Tensor, which means the Tensor can have different shapes but same dtype.
    3. Support empty tensor_list in some ranks without padding.

    Args:
        tensor_list: A list of tensors to gather.
        group: The process group to use.

    Returns:
        A list of tensors gathered from all ranks. Note the result is a single
        flat list in rank order (rank 0's tensors first, then rank 1's, ...);
        ranks that contributed no tensors are simply skipped.
    """

    # Return value is discarded; this call is a no-op beyond validating `group`.
    dist.get_rank(group)
    world_size = dist.get_world_size(group)

    # Sanity-check that this process is bound to the CUDA device implied by its
    # global rank (assumes one process per local GPU).
    local_rank = torch.distributed.get_rank() % torch.cuda.device_count()
    assert (
        local_rank == torch.cuda.current_device()
    ), f"local_rank {local_rank} != current_device {torch.cuda.current_device()}"
    # Fall back to the default CUDA device when this rank holds no tensors.
    device = tensor_list[0].device if len(tensor_list) > 0 else torch.device("cuda")

    # Step 1: Gather metadata
    metadata_lists = _gather_metadata(tensor_list, group)
    tensor_dtype = _get_dtype_and_assert_consistency(metadata_lists)

    # Step 2: Flatten local tensors into a single 1D buffer
    if tensor_list:
        flat_tensor = torch.cat([t.flatten() for t in tensor_list], dim=0).contiguous()
    else:
        flat_tensor = torch.empty(0, dtype=tensor_dtype, device=device)  # dummy, will be ignored
    # Step 3: Gather lengths from metadata
    all_numels_int = _get_numel_for_each_rank(metadata_lists)

    # Step 4: Allocate buffers and gather flat tensor data
    # Per-rank buffer sizes differ, so allocate each output individually.
    output_flat_tensors = []
    for numel in all_numels_int:
        output_flat_tensors.append(torch.empty(numel, dtype=tensor_dtype, device=device))
    dist.all_gather(output_flat_tensors, flat_tensor, group)

    # Step 5: Reconstruct individual tensors using metadata
    gathered_tensor_lists = []
    for i in range(world_size):
        flat = output_flat_tensors[i]
        if flat.numel() == 0:
            # Rank i contributed nothing; it leaves no entries in the output.
            continue
        metadata_list = metadata_lists[i]
        offset = 0
        for meta in metadata_list:
            numel = meta.numel
            # `.to(meta.dtype)` is a no-op here: the dtype-consistency assert
            # above guarantees `flat` already has dtype `meta.dtype`.
            t = flat[offset : offset + numel].view(meta.shape).to(meta.dtype)
            offset += numel
            gathered_tensor_lists.append(t)

    return gathered_tensor_lists
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def _scatter_to_context_parallel_region(input: torch.Tensor, split_sizes: List[int], group: dist.ProcessGroup = None):
    """Keep only this rank's slice of ``input`` along its first dimension.

    Rank r receives rows [sum(split_sizes[:r]), sum(split_sizes[:r]) + split_sizes[r]).
    """
    my_rank = dist.get_rank(group)
    start = sum(split_sizes[:my_rank])
    end = start + split_sizes[my_rank]
    return input[start:end].contiguous()
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def scatter_to_context_parallel_region(
    inputs: Union[torch.Tensor, List[torch.Tensor]], split_sizes: List[int] = None, group: dist.ProcessGroup = None
):
    """Split every tensor leaf of ``inputs`` along dim 0 and keep this rank's slice.

    With no group (or a single-rank group) this is the identity. When
    ``split_sizes`` is omitted, the first dimension must divide evenly across
    the group and an even split is used.
    """
    # Identity when there is no real context-parallel group.
    if group is None or torch.distributed.get_world_size(group) == 1:
        return inputs

    if split_sizes is None:
        world_size = dist.get_world_size(group)
        assert (
            inputs.shape[0] % dist.get_world_size(group) == 0
        ), f"inputs.shape[0] {inputs.shape[0]} % dist.get_world_size(group) {dist.get_world_size(group)} != 0"
        split_sizes = [inputs.shape[0] // world_size] * world_size

    # Apply the per-tensor scatter to every leaf of the (possibly nested) input.
    return tree_map(
        lambda leaf: _scatter_to_context_parallel_region(leaf, split_sizes=split_sizes, group=group), inputs
    )
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def _gather_from_context_parallel_region(
    input: Union[torch.Tensor, List[torch.Tensor]], split_sizes: List[int], group: dist.ProcessGroup = None
):
    """All-gather ``input`` along dim 0; rank i contributes ``split_sizes[i]`` rows.

    The full output buffer is allocated up front and split into per-rank views
    so ``all_gather`` can write each rank's contribution into place.
    """
    input = input.contiguous()
    dim_size = list(input.size())
    # Total first-dimension length across all ranks.
    dim_size[0] = sum(split_sizes)

    output = torch.empty(dim_size, dtype=input.dtype, device=input.device)
    outputs = list(torch.split(output, split_sizes, dim=0))
    torch.distributed.all_gather(outputs, input, group=group)
    # Concatenate the per-rank pieces into one contiguous result.
    output = torch.concat(outputs, dim=0)

    return output
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def gather_from_context_parallel_region(
    inputs: Union[torch.Tensor, List[torch.Tensor]], split_sizes: List[int] = None, group: dist.ProcessGroup = None
):
    """Gather tensors and concatenate along the first dimension.

    Inverse of ``scatter_to_context_parallel_region``: every tensor leaf of
    ``inputs`` is all-gathered across ``group`` and concatenated along dim 0.
    With no group (or a single-rank group) this is the identity.

    Args:
        inputs: A tensor or nested structure of tensors, each holding this
            rank's slice along dim 0.
        split_sizes: Rows contributed by each rank (one entry per rank). When
            None, every rank is assumed to contribute ``inputs.shape[0]`` rows.
        group: Process group to gather over.
    """
    if group is None or torch.distributed.get_world_size(group) == 1:
        return inputs

    if split_sizes is None:
        # BUG FIX: this previously built the single-element list
        # [inputs.shape[0] * world_size], but all_gather requires one output
        # slot per rank. Each rank contributes inputs.shape[0] rows.
        split_sizes = [inputs.shape[0]] * dist.get_world_size(group)
    partial_func = partial(_gather_from_context_parallel_region, split_sizes=split_sizes, group=group)
    return tree_map(partial_func, inputs)
|
inference/infra/parallelism/ulysses_scheduler.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from typing import Generic, List, Optional, TypeVar
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
from torch.utils._pytree import tree_map
|
| 19 |
+
|
| 20 |
+
from inference.infra.distributed import get_cp_group, get_cp_world_size
|
| 21 |
+
|
| 22 |
+
from .gather_scatter_primitive import gather_from_context_parallel_region, scatter_to_context_parallel_region
|
| 23 |
+
|
| 24 |
+
T = TypeVar("T")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class UlyssesScheduler(Generic[T]):
    """
    A naive Ulysses scheduler for context parallel processing.

    Splits tensors along the sequence (first) dimension when they enter the
    context parallel region and reassembles them on the way out. Arbitrary
    nested structures of tensors are handled via ``tree_map``. The per-rank
    split sizes are cached on first dispatch and must stay identical for
    every tensor in the round; ``undispatch`` clears the cache.
    """

    def __init__(self):
        """Create a scheduler with no dispatch round in flight."""
        # Cached per-rank split sizes for the current dispatch round.
        self._cp_split_sizes: Optional[List[int]] = None

    @property
    def cp_split_sizes(self):
        """Per-rank split sizes of the in-flight round (None when idle)."""
        return self._cp_split_sizes

    def _dispatch(self, x: torch.Tensor) -> torch.Tensor:
        """
        Split one tensor along dim 0 across the context parallel group.

        The sequence length is spread as evenly as possible: the first
        ``seq_len % world_size`` ranks each receive one extra token.

        Args:
            x: Tensor shaped [seq_len, ...].

        Returns:
            This rank's slice of ``x``.

        Raises:
            AssertionError: if the computed split sizes differ from those of an
                earlier dispatch in the same round.
        """
        total_tokens = x.shape[0]
        world_size = get_cp_world_size()
        base, leftover = divmod(total_tokens, world_size)
        # First `leftover` ranks take base + 1 tokens, the rest take base.
        sizes = [base + 1] * leftover + [base] * (world_size - leftover)
        if self._cp_split_sizes is not None:
            assert (
                self._cp_split_sizes == sizes
            ), f"cp_split_sizes changed from {self._cp_split_sizes} to {sizes}"
        self._cp_split_sizes = sizes
        return scatter_to_context_parallel_region(x, sizes, group=get_cp_group())

    def _undispatch(self, x: torch.Tensor) -> torch.Tensor:
        """Gather one tensor's shards back into the full original sequence."""
        return gather_from_context_parallel_region(x, self._cp_split_sizes, group=get_cp_group())

    def dispatch(self, tensors: T) -> T:
        """
        Dispatch every tensor leaf of a nested structure into the CP region.

        The input structure (tensor, tuple, list, dict, ...) is preserved; all
        tensor leaves must share the same first-dimension length.
        """
        return tree_map(self._dispatch, tensors)

    def undispatch(self, tensors: T) -> T:
        """
        Reconstruct every tensor leaf from the CP region and reset the cache.

        The input structure is preserved; the cached split sizes are cleared so
        the next dispatch round may use a different sequence length.
        """
        gathered = tree_map(self._undispatch, tensors)
        self._cp_split_sizes = None
        return gathered
|
| 138 |
+
|
| 139 |
+
# Process-wide singleton, created eagerly at import time.
_ULYSSES_SCHEDULER = UlyssesScheduler()


def ulysses_scheduler() -> UlyssesScheduler:
    """Return the process-wide UlyssesScheduler singleton.

    The assert can only fire if the module-level assignment above is somehow
    bypassed (e.g. the global is manually reset).
    """
    assert _ULYSSES_SCHEDULER is not None, "ulysses scheduler is not initialized"
    return _ULYSSES_SCHEDULER
|
inference/model/dit/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from .dit_model import get_dit
|
| 16 |
+
from .dit_module import DiTModel
|
| 17 |
+
|
| 18 |
+
__all__ = ["DiTModel", "get_dit"]
|
inference/model/dit/dit_model.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import gc
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
from inference.infra.checkpoint import load_model_checkpoint
|
| 19 |
+
from inference.infra.distributed import get_cp_rank, get_pp_rank, get_tp_rank
|
| 20 |
+
from inference.utils import print_mem_info_rank_0, print_model_size, print_rank_0
|
| 21 |
+
|
| 22 |
+
from .dit_module import DiTModel
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_dit(model_config, engine_config):
    """Build and load DiT model.

    Constructs a DiTModel from ``model_config``, loads its checkpoint per
    ``engine_config``, moves it to the current CUDA device, and switches it to
    eval mode. The ordering of the side-effectful steps below is deliberate.

    Args:
        model_config: Configuration forwarded to the DiTModel constructor.
        engine_config: Configuration forwarded to load_model_checkpoint.

    Returns:
        The loaded model on the current CUDA device, in eval mode.
    """
    model = DiTModel(model_config=model_config)

    print_rank_0("Build dit model successfully")
    print_rank_0(model)
    print_model_size(
        model, prefix=f"(tp, cp, pp) rank ({get_tp_rank()}, {get_cp_rank()}, {get_pp_rank()}): ", print_func=print_rank_0
    )

    # Checkpoint is loaded before the model is moved onto the GPU.
    model = load_model_checkpoint(model, engine_config)
    model.cuda(torch.cuda.current_device())
    model.eval()
    print_mem_info_rank_0("Load model successfully")

    # Reclaim host memory and CUDA cache held by temporary loading buffers.
    gc.collect()
    torch.cuda.empty_cache()
    return model
|
inference/model/dit/dit_module.py
ADDED
|
@@ -0,0 +1,950 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import importlib
|
| 16 |
+
from dataclasses import dataclass
|
| 17 |
+
from enum import Enum
|
| 18 |
+
from typing import Any, Callable, List, Optional, Tuple
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
from einops import rearrange, repeat
|
| 23 |
+
from inference.common import Modality, VarlenHandler, is_hopper_arch
|
| 24 |
+
from inference.infra.parallelism import ulysses_scheduler
|
| 25 |
+
from magi_compiler import magi_compile
|
| 26 |
+
from magi_compiler.api import magi_register_custom_op
|
| 27 |
+
from magi_compiler.config import CompileConfig
|
| 28 |
+
from torch import Tensor
|
| 29 |
+
from torch.nn import Parameter
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@dataclass
class FFAHandler:
    """Argument bundle for a variable-length attention call.

    NOTE(review): this dataclass only carries values; the per-field meanings
    below are inferred from the field names — confirm against the attention
    kernel that consumes them.
    """

    # Per-segment query index ranges — presumably [start, end) pairs.
    q_ranges: torch.Tensor
    # Per-segment key index ranges — presumably [start, end) pairs.
    k_ranges: torch.Tensor
    # Maximum query sequence length across segments.
    max_seqlen_q: int
    # Maximum key sequence length across segments.
    max_seqlen_k: int
    # Per-segment attention type codes (semantics defined by the kernel).
    attn_type_map: torch.Tensor
    # Scaling factor applied to logits before softmax.
    softmax_scale: float
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# Define the MLP activation type
class MLPActivationType(Enum):
    """Enumeration of supported activation functions for MLP"""

    # Gated variant with clamped pre-activations (see `swiglu7`).
    SWIGLU7 = "swiglu7"
    # Non-gated variant with clamped pre-activations (see `gelu7`).
    GELU7 = "gelu7"
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def swiglu7(x, alpha: float = 1.702, limit: float = 7.0, out_dtype: Optional[torch.dtype] = None):
    """SwiGLU variant with clamped pre-activations.

    Even-indexed channels feed the sigmoid gate (clamped above at ``limit``);
    odd-indexed channels form the linear branch (clamped to [-limit, limit]).
    The linear branch gets an extra +1 bias (from GPT-OSS). Math is done in
    float32 and the result cast to ``out_dtype`` (defaults to ``x.dtype``).
    """
    target_dtype = out_dtype if out_dtype is not None else x.dtype
    work = x.to(torch.float32)
    gate = work[..., ::2].clamp(max=limit)
    linear = work[..., 1::2].clamp(min=-limit, max=limit)
    silu_gate = gate * torch.sigmoid(alpha * gate)
    # Extra +1 bias on the linear branch (from GPT-OSS).
    return (silu_gate * (linear + 1)).to(target_dtype)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def gelu7(x, alpha: float = 1.702, limit: float = 7.0, out_dtype: Optional[torch.dtype] = None):
    """Sigmoid-approximated GELU with inputs clamped above at ``limit``.

    Computes ``y * sigmoid(alpha * y)`` with ``y = min(x, limit)`` in float32,
    then casts to ``out_dtype`` (defaults to ``x.dtype``). Unlike ``swiglu7``
    there is no gating split and no +1 bias term.
    """
    target_dtype = out_dtype if out_dtype is not None else x.dtype
    work = x.to(torch.float32).clamp(max=limit)
    return (work * torch.sigmoid(alpha * work)).to(target_dtype)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def create_activation_func(activation_type: MLPActivationType) -> Callable:
    """Resolve an MLPActivationType to its implementing function.

    Raises:
        ValueError: for an unrecognized activation type.
    """
    if activation_type == MLPActivationType.SWIGLU7:
        return swiglu7
    if activation_type == MLPActivationType.GELU7:
        return gelu7
    raise ValueError(f"Unknown activation type: {activation_type}")
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class ModalityDispatcher:
    """
    Routes rows of a tensor into per-modality groups and back.

    Construction precomputes a sort permutation over ``modality_mapping`` (and
    its inverse) plus the size of each modality group, so that dispatch and
    undispatch reduce to cheap split/cat operations.
    """

    permuted_modality_mapping: torch.Tensor
    group_size: torch.Tensor
    group_size_cpu: list[int]
    num_modalities: int

    def __init__(self, modality_mapping: torch.Tensor, num_modalities: int):
        """
        Precompute all permutation and grouping state.

        This runs once during object construction.
        """
        self.modality_mapping = modality_mapping
        self.num_modalities = num_modalities

        self.permuted_modality_mapping = self._precompute_permute_mapping(modality_mapping)

        # Per-modality row counts, both on-device and as host-side ints
        # (torch.split needs plain Python sizes).
        self.group_size = torch.bincount(self.permuted_modality_mapping, minlength=num_modalities).to(torch.int32)
        self.group_size_cpu: list[int] = [int(n) for n in self.group_size.to("cpu").tolist()]

    def _precompute_permute_mapping(self, modality_mapping):
        # argsort yields the permutation that sorts rows by modality id
        # (O(N log N)); argsort of a permutation is its inverse.
        self.permute_mapping = torch.argsort(modality_mapping)
        self.inv_permute_mapping = torch.argsort(self.permute_mapping)

        # The modality ids in sorted order, used for group-size counting.
        return modality_mapping[self.permute_mapping]

    def dispatch(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Split rows into one tensor per modality.

        Assumes ``x`` is already in permuted (modality-sorted) row order.
        """
        return list(torch.split(x, self.group_size_cpu, dim=0))

    def undispatch(self, *processed_groups: list[torch.Tensor]) -> torch.Tensor:
        """Concatenate per-modality outputs back into a single tensor."""
        return torch.cat(processed_groups, dim=0)

    @staticmethod
    def permute(x: torch.Tensor, permute_mapping: torch.Tensor) -> torch.Tensor:
        """Apply forward permutation to tensor."""
        return x[permute_mapping]

    @staticmethod
    def inv_permute(x: torch.Tensor, inv_permute_mapping: torch.Tensor) -> torch.Tensor:
        """Apply inverse permutation to tensor."""
        return x[inv_permute_mapping]
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def freq_bands(
    num_bands: int, temperature: float = 10000.0, step: int = 2, device: Optional[torch.device] = None
) -> torch.Tensor:
    """Inverse-frequency bands: 1 / temperature**(i / num_bands) for i in range(0, num_bands, step)."""
    exponents = torch.arange(0, num_bands, step, dtype=torch.int64, device=device).to(torch.float32) / num_bands
    return (temperature**exponents).reciprocal()
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def rotate_half(x, interleaved=False):
    """Rotate channel pairs by 90 degrees: (a, b) -> (-b, a).

    Non-interleaved pairs the first half of the last dim with the second half;
    interleaved pairs even-indexed channels with odd-indexed ones.
    """
    if interleaved:
        a, b = x[..., ::2], x[..., 1::2]
        # Re-interleave the (-b, a) pairs back into the last dimension;
        # equivalent to einops' "... d two -> ... (d two)".
        return torch.stack((-b, a), dim=-1).flatten(start_dim=-2)
    a, b = x.chunk(2, dim=-1)
    return torch.cat((-b, a), dim=-1)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    Apply rotary position embedding to the leading ``rotary_dim`` channels.

    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
    Channels beyond ``rotary_dim`` pass through unchanged.
    """
    ro_dim = 2 * cos.shape[-1]
    assert ro_dim <= x.shape[-1]
    # Broadcast over heads and expand the half-dim bands to full rotary width.
    pattern = "... d -> ... 1 (d 2)" if interleaved else "... d -> ... 1 (2 d)"
    cos = repeat(cos, pattern)
    sin = repeat(sin, pattern)
    x_rot, x_pass = x[..., :ro_dim], x[..., ro_dim:]
    rotated = x_rot * cos + rotate_half(x_rot, interleaved) * sin
    return torch.cat([rotated, x_pass], dim=-1)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
class ElementWiseFourierEmbed(nn.Module):
    """Per-token sin/cos Fourier embedding of (t, h, w) coordinates.

    Coordinates are rescaled onto a reference grid, spatially centered, then
    projected onto frequency bands and embedded as concatenated sin/cos terms.
    """

    def __init__(
        self,
        dim: int,
        max_res: int = 224,
        temperature: float = 10000.0,
        in_pixels: bool = True,
        linear_bands: bool = False,
        learnable: bool = False,
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ):
        """
        Args:
            dim: Output feature dimension, total channels, must be divisible by 6
            max_res: Max pixel-frequency resolution for pixel-domain bands
            temperature: Temperature in inverse-frequency mode
            in_pixels: True -> pixel-frequency bands, False -> inverse-frequency bands
                (only the inverse-frequency mode is implemented)
            linear_bands: Whether pixel-frequency bands are linearly spaced
            learnable: Whether frequency bands are trainable
        """
        super().__init__()
        self.dim = dim
        self.in_pixels = in_pixels
        self.learnable = learnable
        self.temperature = temperature
        self.max_res = max_res
        self.linear_bands = linear_bands
        self.device = device
        self.dtype = dtype
        # Make frequency bands trainable or register as buffer
        bands = self.get_default_bands()
        if self.learnable:
            self.bands = nn.Parameter(bands)
        else:
            self.register_buffer("bands", bands)

    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        """
        Args:
            coords: [L,9], column order (time, row, col, T, H, W, ref_T, ref_H, ref_W)
        Returns:
            emb: [L, dim] element-wise Fourier embedding
        """
        # Use slicing instead of unbind + stack to reduce intermediates
        coords_xyz = coords[:, :3]  # [L,3] -> (t, h, w)
        sizes = coords[:, 3:6]  # [L,3] -> (T, H, W)
        refs = coords[:, 6:9]  # [L,3] -> (ref_T, ref_H, ref_W)

        # Scale factors that map the (size - 1) grid span onto (ref - 1).
        scales = (refs - 1) / (sizes - 1)  # [L,3]

        # NOTE: if both ref and size are 1, scale is fixed to 1; otherwise invalid
        # (size == 1 with ref != 1 would produce inf and trip the assert below).
        scales[(refs == 1) & (sizes == 1)] = 1
        assert not scales.isnan().any(), "scales has nan"
        assert not scales.isinf().any(), "scales has inf"

        # Center alignment: apply to h,w only (not time)
        centers = (sizes - 1) / 2  # [L,3]
        centers[:, 0] = 0  # Do not center the time dimension
        coords_xyz = coords_xyz - centers  # [L,3]

        # Project to frequency bands in one shot: [L,3,B]
        proj = coords_xyz.unsqueeze(-1) * scales.unsqueeze(-1) * self.bands

        # Compute sin & cos and concatenate
        sin_proj = proj.sin()  # [L,3,B]
        cos_proj = proj.cos()

        # Concatenate along the coordinate axis, then flatten to [L, 6 * B].
        return torch.cat((sin_proj, cos_proj), dim=1).flatten(1)

    def reset_parameters(self):
        # In-place copy works for both the Parameter and the buffer case.
        bands = self.get_default_bands()
        self.bands.copy_(bands)

    def get_default_bands(self):
        if self.in_pixels:
            raise NotImplementedError("in_pixels are not implemented yet")
        else:
            # Inverse-frequency bands; dim // 8 bands shared by all three axes.
            bands = freq_bands(self.dim // 8, temperature=self.temperature, step=1, device=self.device).to(self.dtype)
            return bands
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
class MultiModalityRMSNorm(nn.Module):
    """RMSNorm with one learned ``(weight + 1)`` scale per modality.

    With ``num_modality > 1`` the weight holds one ``dim``-sized scale per
    modality, and the input must arrive modality-permuted so a
    ``ModalityDispatcher`` can split it into contiguous per-modality groups.
    The normalization itself is always computed in float32 and the result is
    cast back to the input dtype.
    """

    __constants__ = ["dim", "eps", "num_modality"]
    dim: int
    eps: float
    num_modality: int

    def __init__(self, dim: int, eps: float = 1e-6, device: torch.device | None = None, num_modality: int = 1):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.num_modality = num_modality

        # Zero-initialized weight; the effective scale is (weight + 1).
        self.weight = torch.nn.Parameter(torch.zeros(dim * num_modality, device=device, dtype=torch.float32))
        # Bind the forward path once at construction instead of branching per call.
        if num_modality > 1:
            self.forward = self.forward_multi_experts
        else:
            self.forward = self.forward_single_expert

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.zeros_(self.weight)

    def rms(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` by its root-mean-square along the last dim, in float32."""
        # Fix: the original also bound an unused `original_dtype` local here.
        t = x.float()
        t = t * torch.rsqrt(torch.mean(t**2, dim=-1, keepdim=True) + self.eps)
        return t

    def forward_multi_experts(self, x: torch.Tensor, modality_dispatcher: ModalityDispatcher) -> torch.Tensor:
        original_dtype = x.dtype
        t = self.rms(x)

        # One (dim,)-sized scale per modality; apply each to its own group.
        weight_chunked = self.weight.chunk(self.num_modality, dim=0)
        t_list = modality_dispatcher.dispatch(t)
        for i in range(self.num_modality):
            t_list[i] = t_list[i] * (weight_chunked[i] + 1)
        t = modality_dispatcher.undispatch(*t_list)

        return t.to(original_dtype)

    def forward_single_expert(self, x: torch.Tensor, modality_dispatcher: Optional[ModalityDispatcher] = None) -> torch.Tensor:
        # `modality_dispatcher` is accepted for call-site uniformity but unused.
        t, original_dtype = x.float(), x.dtype
        t = t * torch.rsqrt(torch.mean(t**2, dim=-1, keepdim=True) + self.eps)
        return (t * (self.weight + 1)).to(original_dtype)
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
class _BF16ComputeLinear(torch.autograd.Function):
|
| 291 |
+
@staticmethod
|
| 292 |
+
def forward(
|
| 293 |
+
ctx,
|
| 294 |
+
input: torch.Tensor,
|
| 295 |
+
weight: torch.Tensor,
|
| 296 |
+
bias: Optional[torch.Tensor],
|
| 297 |
+
output_dtype: Optional[torch.dtype],
|
| 298 |
+
compute_dtype: torch.dtype = torch.bfloat16,
|
| 299 |
+
):
|
| 300 |
+
# Convert input to specified input data type
|
| 301 |
+
input_cast = input.to(compute_dtype)
|
| 302 |
+
# Convert weight to computation data type
|
| 303 |
+
weight_cast = weight.to(compute_dtype)
|
| 304 |
+
# Perform linear operation
|
| 305 |
+
output = torch.matmul(input_cast, weight_cast.t())
|
| 306 |
+
|
| 307 |
+
# Add bias if present
|
| 308 |
+
if bias is not None:
|
| 309 |
+
bias_cast = bias.to(compute_dtype)
|
| 310 |
+
output = output + bias_cast
|
| 311 |
+
else:
|
| 312 |
+
bias_cast = None
|
| 313 |
+
|
| 314 |
+
# Convert output to specified output data type
|
| 315 |
+
return output.to(output_dtype)
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
class BaseLinear(nn.Module):
    """Linear layer with bf16 parameters, computing through ``_BF16ComputeLinear``.

    NOTE(review): the ``dtype`` constructor argument is accepted but ignored —
    parameters are always allocated in bfloat16. Presumably kept for signature
    compatibility with ``create_linear``; confirm intended.
    """

    __constants__ = ["in_features", "out_features", "num_layers", "num_experts"]
    in_features: int
    out_features: int
    num_layers_for_initialization: int
    num_experts: int
    weight: Tensor

    def __init__(
        self, in_features, out_features, num_layers_for_initialization, num_experts, bias=True, device=None, dtype=None
    ):
        super().__init__()
        # Parameters are always bfloat16 regardless of the requested dtype.
        factory_kwargs = {"device": device, "dtype": torch.bfloat16}
        self.in_features = in_features
        self.out_features = out_features
        self.num_layers_for_initialization = num_layers_for_initialization
        self.num_experts = num_experts
        self.use_bias = bias
        # Experts are stacked along the output dimension.
        self.weight = Parameter(torch.empty((out_features * num_experts, in_features), **factory_kwargs))
        if bias:
            self.bias = Parameter(torch.empty(out_features * num_experts, **factory_kwargs))
        else:
            self.register_parameter("bias", None)

    def forward(
        self,
        input: torch.Tensor,
        output_dtype: Optional[torch.dtype] = None,
        modality_dispatcher: Optional[ModalityDispatcher] = None,
    ) -> torch.Tensor:
        """Single-expert projection; ``modality_dispatcher`` is accepted but unused."""
        output_dtype = input.dtype if output_dtype is None else output_dtype
        return _BF16ComputeLinear.apply(input, self.weight, self.bias, output_dtype, torch.bfloat16)
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
class NativeMoELinear(BaseLinear):
    """Multi-expert linear: one weight/bias chunk per modality expert.

    The input must arrive modality-permuted; the dispatcher splits it into
    contiguous per-expert groups, each projected with its own expert weights,
    then re-concatenated in order.
    """

    def forward(
        self,
        input: torch.Tensor,
        output_dtype: Optional[torch.dtype] = None,
        modality_dispatcher: Optional[ModalityDispatcher] = None,
    ) -> torch.Tensor:
        output_dtype = input.dtype if output_dtype is None else output_dtype

        # A dispatcher is required here despite the Optional annotation.
        input_list = modality_dispatcher.dispatch(input)  # type: ignore
        # Expert weights are stacked along dim 0 (see BaseLinear.__init__).
        weight_chunked = self.weight.chunk(self.num_experts, dim=0)

        if self.bias is not None:
            bias_chunked = self.bias.chunk(self.num_experts, dim=0)

        # Project each expert's group in place, then merge back in order.
        for i in range(self.num_experts):
            input_list[i] = _BF16ComputeLinear.apply(
                input_list[i],
                weight_chunked[i],
                bias_chunked[i] if self.bias is not None else None,
                output_dtype,
                torch.bfloat16,
            )
        return modality_dispatcher.undispatch(*input_list)  # type: ignore
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
def create_linear(
    in_features, out_features, num_layers=1, num_experts=1, bias=True, device=None, dtype=None
) -> BaseLinear | NativeMoELinear:
    """Factory: a plain ``BaseLinear`` for one expert, ``NativeMoELinear`` otherwise."""
    linear_cls = BaseLinear if num_experts == 1 else NativeMoELinear
    return linear_cls(in_features, out_features, num_layers, num_experts, bias, device, dtype)
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
# Probe optional fast-attention backends without importing them.
HAS_MAGI_ATTENTION = importlib.util.find_spec("magi_attention") is not None
HAS_FA3 = importlib.util.find_spec("flash_attn_interface") is not None
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
@magi_register_custom_op(name="infra::flash_attn_func", is_subgraph_boundary=True)
def flash_attn_func(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    """Dense attention dispatch: FlashAttention-3 on Hopper when available, else FA-2."""
    if HAS_FA3 and is_hopper_arch():
        from flash_attn_interface import flash_attn_func as fa3_flash_attn_func

        return fa3_flash_attn_func(query, key, value)
    else:
        from flash_attn.flash_attn_interface import flash_attn_func as fa2_flash_attn_func

        return fa2_flash_attn_func(query, key, value)
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
def _split_q_range_with_no_overlap(
|
| 404 |
+
q_ranges: torch.Tensor, k_ranges: torch.Tensor
|
| 405 |
+
) -> Tuple[List[List[int]], List[List[List[int]]]]:
|
| 406 |
+
range_boundary = torch.unique(q_ranges, sorted=True).tolist()
|
| 407 |
+
candidates = [[start, end, []] for start, end in zip(range_boundary[:-1], range_boundary[1:])]
|
| 408 |
+
q_ranges = q_ranges.tolist()
|
| 409 |
+
k_ranges = k_ranges.tolist()
|
| 410 |
+
for q_range, k_range in zip(q_ranges, k_ranges):
|
| 411 |
+
q_start, q_end = q_range
|
| 412 |
+
for q_range_cand in candidates:
|
| 413 |
+
if q_start <= q_range_cand[0] and q_range_cand[1] <= q_end:
|
| 414 |
+
q_range_cand[2].append(k_range)
|
| 415 |
+
q_ranges_out = []
|
| 416 |
+
k_ranges_out = []
|
| 417 |
+
for q_range_cand in candidates:
|
| 418 |
+
if len(q_range_cand[2]) > 0:
|
| 419 |
+
q_ranges_out.append(q_range_cand[0:2])
|
| 420 |
+
k_ranges_out.append(q_range_cand[2])
|
| 421 |
+
return q_ranges_out, k_ranges_out
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
def _flash_attn_with_correction(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, q_ranges: List[List[int]], k_range_list: List[List[List[int]]]
):
    """Emulate range-based attention by running FA per (q interval, k range) and merging.

    For each disjoint q interval, attention is computed against each of its k
    ranges separately; the partial outputs are combined with the standard
    log-sum-exp correction so the result equals attention over the union of
    the k ranges.

    Returns:
        (output, shaped like ``query``; output_lse [Lq, H] in float32)
    """
    output = torch.zeros_like(query)
    output_lse = torch.zeros((query.shape[0], query.shape[1]), dtype=torch.float32, device=query.device)

    from flash_attn.flash_attn_interface import flash_attn_func

    for q_range, k_ranges in zip(q_ranges, k_range_list):
        q_start, q_end = q_range
        qo_out, qo_lse = None, None
        for k_range in k_ranges:
            k_start, k_end = k_range
            # FA expects a batch dim; return_attn_probs also yields the LSE.
            cur_qo_out, cur_qo_lse, _ = flash_attn_func(
                query[q_start:q_end].unsqueeze(0),
                key[k_start:k_end].unsqueeze(0),
                value[k_start:k_end].unsqueeze(0),
                return_attn_probs=True,
            )
            cur_qo_out, cur_qo_lse = cur_qo_out.squeeze(0), cur_qo_lse.squeeze(0)

            if qo_out is None:
                # First k range: nothing to merge with yet.
                qo_out = cur_qo_out
                qo_lse = cur_qo_lse
            else:
                # +inf LSE marks fully-masked rows; map to -inf so exp() -> 0.
                qo_lse[qo_lse == torch.inf] = -torch.inf
                cur_qo_lse[cur_qo_lse == torch.inf] = -torch.inf
                # Numerically stable merge of the two softmax normalizers.
                max_lse = torch.max(qo_lse, cur_qo_lse)
                qo_se, cur_qo_se = torch.exp(qo_lse - max_lse), torch.exp(cur_qo_lse - max_lse)
                sum_se = qo_se + cur_qo_se
                qo_scale, cur_qo_scale = qo_se / sum_se, cur_qo_se / sum_se

                # LSE layout needs a transpose to broadcast over the outputs
                # (see the permute(1, 0) used when writing output_lse below).
                qo_out = qo_out * qo_scale.permute(1, 0).unsqueeze(-1) + cur_qo_out * cur_qo_scale.permute(1, 0).unsqueeze(-1)
                qo_lse = torch.log(sum_se) + max_lse

        output[q_start:q_end] = qo_out
        output_lse[q_start:q_end, :] = qo_lse.permute(1, 0)
    return output, output_lse
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
def _custom_flex_flash_attn_func(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, q_ranges: torch.Tensor, k_ranges: torch.Tensor, **kwargs
):
    """Fallback flex attention: split q into disjoint intervals, then run FA with LSE merging."""
    disjoint_q_ranges, grouped_k_ranges = _split_q_range_with_no_overlap(q_ranges, k_ranges)
    return _flash_attn_with_correction(query, key, value, disjoint_q_ranges, grouped_k_ranges)
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
def _flex_flash_attn_func_infer_output_meta(
|
| 472 |
+
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, q_ranges: torch.Tensor, k_ranges: torch.Tensor
|
| 473 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 474 |
+
output = torch.empty_like(query)
|
| 475 |
+
output_lse = torch.empty((query.shape[0], query.shape[1]), dtype=torch.float32, device=query.device)
|
| 476 |
+
return output, output_lse
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
@magi_register_custom_op(
    name="infra::flex_flash_attn_func",
    mutates_args=(),
    infer_output_meta_fn=_flex_flash_attn_func_infer_output_meta,
    is_subgraph_boundary=True,
)
def flex_flash_attn_func(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, q_ranges: torch.Tensor, k_ranges: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Range-based (flex) attention: magi_attention kernel on Hopper, else FA fallback.

    Returns:
        (attention output, LSE) — see _flex_flash_attn_func_infer_output_meta.
    """
    if HAS_MAGI_ATTENTION and is_hopper_arch():
        from magi_attention.api import flex_flash_attn_func as magi_flex_flash_attn_func

        return magi_flex_flash_attn_func(query, key, value, q_ranges, k_ranges)
    else:
        # Pure-PyTorch/FA2 emulation of the range-based kernel.
        return _custom_flex_flash_attn_func(query, key, value, q_ranges, k_ranges)
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
def _attention_with_cp_infer_output_meta(q: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
| 497 |
+
return torch.empty_like(q, dtype=torch.bfloat16).squeeze(0)
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
@magi_register_custom_op(
    name="infra::flash_attn_with_cp",
    mutates_args=(),
    infer_output_meta_fn=_attention_with_cp_infer_output_meta,
    is_subgraph_boundary=True,
)
def flash_attn_with_cp(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cp_split_sizes: List[int]) -> torch.Tensor:
    """Dense attention under context parallelism.

    Inputs carry a leading batch dim of 1. Under CP the sequence is gathered
    and heads are scattered across ranks before attention; the inverse
    exchange afterwards restores the local sequence shard with all heads
    folded back into the head dim. Returns a [seq, heads, head_dim] bf16 tensor.
    """
    q, k, v = q.to(torch.bfloat16), k.to(torch.bfloat16), v.to(torch.bfloat16)

    # Imported lazily at call time.
    from inference.infra.distributed import get_cp_group, get_cp_world_size
    from inference.infra.parallelism.all_to_all_primitive import batch_scatter_head_gather_seqlen, scatter_seqlen_gather_head

    if get_cp_world_size() > 1:
        # all-to-all: each rank ends up with the full sequence for a head subset.
        q, k, v = batch_scatter_head_gather_seqlen([q.squeeze(0), k.squeeze(0), v.squeeze(0)], cp_split_sizes, get_cp_group())
        q = q.unsqueeze(0)
        k = k.unsqueeze(0)
        v = v.unsqueeze(0)

    self_attn_out = torch.ops.infra.flash_attn_func(q, k, v).squeeze(0)

    if get_cp_world_size() > 1:
        # Inverse all-to-all, then fold the CP factor back into the head dim.
        self_attn_out = scatter_seqlen_gather_head(self_attn_out, cp_split_sizes, get_cp_group(), async_op=False)
        self_attn_out = rearrange(self_attn_out, "(cp sq) hn hd -> sq (cp hn) hd", cp=get_cp_world_size())

    return self_attn_out
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
@magi_register_custom_op(
    name="infra::flex_flash_attn_with_cp",
    mutates_args=(),
    infer_output_meta_fn=_attention_with_cp_infer_output_meta,
    is_subgraph_boundary=True,
)
def flex_flash_attn_with_cp(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_ranges: torch.Tensor,
    k_ranges: torch.Tensor,
    cp_split_sizes: List[int],
) -> torch.Tensor:
    """Range-based (flex) attention under context parallelism.

    Same CP exchange pattern as flash_attn_with_cp, but attention is
    restricted to the provided (q, k) ranges. Returns a
    [seq, heads, head_dim] bf16 tensor.
    """
    q, k, v = q.to(torch.bfloat16).squeeze(0), k.to(torch.bfloat16).squeeze(0), v.to(torch.bfloat16).squeeze(0)

    # Imported lazily at call time.
    from inference.infra.distributed import get_cp_group, get_cp_world_size
    from inference.infra.parallelism.all_to_all_primitive import batch_scatter_head_gather_seqlen, scatter_seqlen_gather_head

    if get_cp_world_size() > 1:
        # all-to-all: each rank gets the full sequence for a head subset.
        q, k, v = batch_scatter_head_gather_seqlen([q, k, v], cp_split_sizes, get_cp_group())

    out, _ = torch.ops.infra.flex_flash_attn_func(q, k, v, q_ranges=q_ranges, k_ranges=k_ranges)

    if get_cp_world_size() > 1:
        # Inverse all-to-all, then fold the CP factor back into the head dim.
        out = scatter_seqlen_gather_head(out, cp_split_sizes, get_cp_group(), async_op=False)
        out = rearrange(out, "(cp sq) hn hd -> sq (cp hn) hd", cp=get_cp_world_size())

    return out
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
@dataclass
class AttentionConfig:
    """Static configuration for the Attention module."""

    hidden_size: int
    num_heads_q: int
    num_heads_kv: int
    head_dim: int
    params_dtype: torch.dtype
    checkpoint_qk_layernorm_rope: bool
    # Number of modality experts for norms/linears (1 = shared weights).
    num_modality: int
    num_layers: int
    # Use range-restricted (flex) attention instead of dense attention.
    use_local_attn: bool = False
    # Apply a learned per-head sigmoid gate to the attention output.
    enable_attn_gating: bool = False
|
| 570 |
+
|
| 571 |
+
|
| 572 |
+
class Attention(torch.nn.Module):
    """Pre-norm multi-head self-attention with per-modality expert weights.

    Norms and linears operate on the modality-permuted token layout (tokens
    grouped per modality); attention itself runs in original sequence order,
    so q/k/v are inverse-permuted before RoPE/attention and the output is
    permuted back before the output projection.
    """

    config: AttentionConfig

    def __init__(self, config: AttentionConfig):
        super().__init__()
        self.config = config

        self.pre_norm = MultiModalityRMSNorm(config.hidden_size, eps=1e-6, num_modality=config.num_modality)
        # Extra per-head gating logits appended to the fused QKV projection.
        self.gating_size = config.num_heads_q if config.enable_attn_gating else 0

        # Fused projection producing q, k, v (and the optional gate) at once.
        self.linear_qkv = create_linear(
            config.hidden_size,
            config.num_heads_q * config.head_dim + config.num_heads_kv * config.head_dim * 2 + self.gating_size,
            num_experts=config.num_modality,
            bias=False,
            dtype=config.params_dtype,
            num_layers=config.num_layers,
        )
        self.linear_proj = create_linear(
            config.num_heads_q * config.head_dim,
            config.hidden_size,
            bias=False,
            num_experts=config.num_modality,
            dtype=config.params_dtype,
            num_layers=config.num_layers,
        )
        # Per-head RMS norms on q and k (qk-norm).
        self.q_norm = MultiModalityRMSNorm(config.head_dim, num_modality=config.num_modality)
        self.k_norm = MultiModalityRMSNorm(config.head_dim, num_modality=config.num_modality)

        self.q_size = config.num_heads_q * config.head_dim
        self.kv_size = config.num_heads_kv * config.head_dim

    def reset_parameters(self):
        # Only resets the output projection, and only if the linear supports it.
        if hasattr(self.linear_proj, "reset_parameters_output_layer"):
            self.linear_proj.reset_parameters_output_layer()

    def forward(
        self,
        hidden_states: torch.Tensor,
        rope: torch.Tensor,
        permute_mapping: torch.Tensor,
        inv_permute_mapping: torch.Tensor,
        varlen_handler: VarlenHandler,
        local_attn_handler: FFAHandler,
        modality_dispatcher: ModalityDispatcher,
        cp_split_sizes: List[int],
    ) -> torch.Tensor:
        # hidden_states arrive modality-permuted; norms/linears rely on that.
        hidden_states = self.pre_norm(hidden_states, modality_dispatcher=modality_dispatcher).to(torch.bfloat16)
        qkv: torch.Tensor = self.linear_qkv(hidden_states, modality_dispatcher=modality_dispatcher).to(torch.float32)

        # Split the fused projection into q, k, v and the (possibly empty) gate.
        q, k, v, g = torch.split(qkv, [self.q_size, self.kv_size, self.kv_size, self.gating_size], dim=1)
        q = q.view(-1, self.config.num_heads_q, self.config.head_dim)
        k = k.view(-1, self.config.num_heads_kv, self.config.head_dim)
        v = v.view(-1, self.config.num_heads_kv, self.config.head_dim)
        g = g.view(k.shape[0], self.config.num_heads_q, -1)

        q = self.q_norm(q, modality_dispatcher=modality_dispatcher)
        k = self.k_norm(k, modality_dispatcher=modality_dispatcher)

        # Back to original sequence order for RoPE/attention; add a batch dim.
        q = ModalityDispatcher.inv_permute(q, inv_permute_mapping).unsqueeze(0)
        k = ModalityDispatcher.inv_permute(k, inv_permute_mapping).unsqueeze(0)
        v = ModalityDispatcher.inv_permute(v, inv_permute_mapping).unsqueeze(0)

        # `rope` packs the sin and cos halves along its last dim.
        sin_emb, cos_emb = rope.tensor_split(2, -1)
        q = apply_rotary_emb_torch(q, cos_emb, sin_emb)
        k = apply_rotary_emb_torch(k, cos_emb, sin_emb)

        if self.config.use_local_attn:
            # Range-restricted attention driven by the FFA handler's ranges.
            self_attn_out = flex_flash_attn_with_cp(
                q, k, v, local_attn_handler.q_ranges, local_attn_handler.k_ranges, cp_split_sizes
            )
        else:
            self_attn_out = flash_attn_with_cp(q, k, v, cp_split_sizes)
        # Back to the modality-permuted layout for gating and output projection.
        self_attn_out = ModalityDispatcher.permute(self_attn_out, permute_mapping)

        if self.config.enable_attn_gating:
            # Per-head sigmoid gate (g is in the permuted layout, like the output).
            self_attn_out = self_attn_out * torch.sigmoid(g)

        self_attn_out = self_attn_out.view(-1, self.config.num_heads_q * self.config.head_dim).to(torch.bfloat16)
        out = self.linear_proj(self_attn_out, modality_dispatcher=modality_dispatcher)
        return out
|
| 653 |
+
|
| 654 |
+
|
| 655 |
+
@dataclass
class MLPConfig:
    """Static configuration for the MLP module."""

    hidden_size: int
    intermediate_size: int
    activation_type: MLPActivationType
    params_dtype: torch.dtype
    # Number of modality experts for norms/linears (1 = shared weights).
    num_modality: int = 1
    num_layers: int = 1
    # True for gated activations (e.g. SwiGLU): the up projection doubles in width.
    gated_act: bool = False
|
| 664 |
+
|
| 665 |
+
|
| 666 |
+
class MLP(torch.nn.Module):
    """Pre-norm feed-forward block with optional gated activation and
    per-modality expert weights.

    Linears compute in bf16; the activation and the returned tensor are fp32.
    """

    config: MLPConfig

    def __init__(self, config: MLPConfig):
        super().__init__()
        num_experts = config.num_modality
        self.pre_norm = MultiModalityRMSNorm(config.hidden_size, num_modality=config.num_modality)
        # Gated activations need both the "up" and "gate" halves in one projection.
        intermediate_size_up = config.intermediate_size * 2 if config.gated_act else config.intermediate_size

        self.up_gate_proj = create_linear(
            config.hidden_size,
            intermediate_size_up,
            bias=False,
            dtype=config.params_dtype,
            num_layers=config.num_layers,
            num_experts=num_experts,
        )
        self.down_proj = create_linear(
            config.intermediate_size,
            config.hidden_size,
            bias=False,
            dtype=config.params_dtype,
            num_layers=config.num_layers,
            num_experts=num_experts,
        )
        self.activation_func = create_activation_func(config.activation_type)

    def forward(self, x: torch.Tensor, modality_dispatcher: ModalityDispatcher) -> torch.Tensor:
        x = self.pre_norm(x, modality_dispatcher=modality_dispatcher).to(torch.bfloat16)
        x = self.up_gate_proj(x, modality_dispatcher=modality_dispatcher).to(torch.float32)
        # Activation in fp32; for gated activations the up/gate split is
        # presumably handled inside activation_func — confirm against its impl.
        x = self.activation_func(x).to(torch.bfloat16)
        x = self.down_proj(x, modality_dispatcher=modality_dispatcher).to(torch.float32)
        return x

    def extra_repr(self) -> str:
        return f"{self.up_gate_proj.weight.shape=}, {self.down_proj.weight.shape=}"
|
| 702 |
+
|
| 703 |
+
|
| 704 |
+
@dataclass
class AdapterConfig:
    """Static configuration for the input Adapter."""

    hidden_size: int
    num_attention_heads: int
    # Per-modality input channel widths (each modality uses a channel prefix).
    text_in_channels: int
    video_in_channels: int
    audio_in_channels: int
    params_dtype: torch.dtype
|
| 712 |
+
|
| 713 |
+
|
| 714 |
+
class Adapter(torch.nn.Module):
    """Embeds raw per-modality tokens into the shared hidden space and builds RoPE.

    Each modality (video / audio / text) has its own fp32 linear embedder;
    tokens are selected by boolean masks over the packed sequence.
    """

    config: AdapterConfig

    def __init__(self, config: AdapterConfig):
        super().__init__()
        self.config = config
        self.video_embedder = nn.Linear(config.video_in_channels, config.hidden_size, bias=True, dtype=torch.float32)
        self.text_embedder = nn.Linear(config.text_in_channels, config.hidden_size, bias=True, dtype=torch.float32)
        self.audio_embedder = nn.Linear(config.audio_in_channels, config.hidden_size, bias=True, dtype=torch.float32)
        # Per-head Fourier position features (head_dim = hidden_size / num_heads).
        self.rope = ElementWiseFourierEmbed(config.hidden_size // config.num_attention_heads, in_pixels=False, learnable=False)

    def forward(
        self,
        x: torch.Tensor,
        coords_mapping: torch.Tensor,
        video_mask: torch.Tensor,
        audio_mask: torch.Tensor,
        text_mask: torch.Tensor,
    ):
        """
        Args:
            x: [L, C] packed tokens; each modality reads only its channel prefix.
            coords_mapping: [L, 9] coordinates consumed by ElementWiseFourierEmbed.
            video_mask / audio_mask / text_mask: [L] boolean selectors per modality.

        Returns:
            (embedded tokens [L, hidden_size], rope features)
        """
        rope = self.rope(coords_mapping)
        output_x = torch.zeros(x.shape[0], self.config.hidden_size, device=x.device, dtype=x.dtype)
        # NOTE(review): masks are assumed disjoint; overlapping tokens would be
        # overwritten in text -> audio -> video order.
        output_x[text_mask] = self.text_embedder(x[text_mask, : self.config.text_in_channels])
        output_x[audio_mask] = self.audio_embedder(x[audio_mask, : self.config.audio_in_channels])
        output_x[video_mask] = self.video_embedder(x[video_mask, : self.config.video_in_channels])
        return output_x, rope
|
| 739 |
+
|
| 740 |
+
|
| 741 |
+
class TransFormerLayer(torch.nn.Module):
    """One pre-norm transformer layer: attention + MLP with residual adds.

    Per-index behavior is driven by the model config: multi-modality experts
    (``mm_layers``), range-restricted attention (``local_attn_layers``),
    optional post-norms on each branch (``post_norm_layers``), and the MLP
    activation choice (``gelu7_layers``).
    """

    def __init__(self, config: Any, layer_idx: int):
        super().__init__()
        # Modality-expert layers use 3 experts (one per modality); others share 1.
        num_modality = 3 if layer_idx in config.mm_layers else 1
        use_local_attn = layer_idx in config.local_attn_layers
        self.post_norm = layer_idx in config.post_norm_layers
        attention_config = AttentionConfig(
            hidden_size=config.hidden_size,
            num_heads_q=config.num_heads_q,
            num_heads_kv=config.num_heads_kv,
            head_dim=config.head_dim,
            params_dtype=config.params_dtype,
            checkpoint_qk_layernorm_rope=config.checkpoint_qk_layernorm_rope,
            num_modality=num_modality,
            num_layers=config.num_layers,
            use_local_attn=use_local_attn,
            enable_attn_gating=config.enable_attn_gating,
        )
        self.attention: Attention = Attention(attention_config)

        activation_type = MLPActivationType.GELU7 if layer_idx in config.gelu7_layers else MLPActivationType.SWIGLU7
        if activation_type == MLPActivationType.SWIGLU7:
            gated_act = True
            # SwiGLU sizing: ~(8/3) * hidden, rounded down to a multiple of 4.
            intermediate_size = int(config.hidden_size * 4 * 2 / 3) // 4 * 4
        else:
            gated_act = False
            intermediate_size = config.hidden_size * 4
        mlp_config = MLPConfig(
            hidden_size=config.hidden_size,
            intermediate_size=intermediate_size,
            activation_type=activation_type,
            params_dtype=config.params_dtype,
            num_modality=num_modality,
            num_layers=config.num_layers,
            gated_act=gated_act,
        )
        self.mlp: MLP = MLP(mlp_config)
        if self.post_norm:
            self.attn_post_norm = MultiModalityRMSNorm(config.hidden_size, num_modality=num_modality)
            self.mlp_post_norm = MultiModalityRMSNorm(config.hidden_size, num_modality=num_modality)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rope: torch.Tensor,
        permute_mapping: torch.Tensor,
        inv_permute_mapping: torch.Tensor,
        varlen_handler: VarlenHandler,
        local_attn_handler: FFAHandler,
        modality_dispatcher: ModalityDispatcher,
        cp_split_sizes: List[int],
    ) -> torch.Tensor:
        attn_out = self.attention(
            hidden_states,
            rope,
            permute_mapping,
            inv_permute_mapping,
            varlen_handler,
            local_attn_handler,
            modality_dispatcher,
            cp_split_sizes,
        )
        # Optional norm on the branch output before the residual add.
        if self.post_norm:
            attn_out = self.attn_post_norm(attn_out, modality_dispatcher=modality_dispatcher)
        hidden_states = hidden_states + attn_out

        mlp_out = self.mlp(hidden_states, modality_dispatcher)
        if self.post_norm:
            mlp_out = self.mlp_post_norm(mlp_out, modality_dispatcher=modality_dispatcher)
        hidden_states = hidden_states + mlp_out
        return hidden_states
|
| 812 |
+
|
| 813 |
+
|
| 814 |
+
# Module-level flag: the first compiled model instance is treated as the base
# model; later instances (e.g. the SR model) are patched differently below.
is_base_model = True


def config_patch(compile_config: CompileConfig) -> CompileConfig:
    """Adjust the compile config per model instance.

    The first invocation (base model) leaves the config unchanged; every
    subsequent invocation fully offloads weights to save GPU memory.
    NOTE(review): relies on construction order via a module-level flag.
    """
    global is_base_model
    if is_base_model:
        is_base_model = False
    else:
        # Fully offload SR model for memory-constrained GPU
        compile_config.offload_config.gpu_resident_weight_ratio = 0.0
    return compile_config
|
| 825 |
+
|
| 826 |
+
|
| 827 |
+
@magi_compile(config_patch=config_patch)
class TransformerBlock(torch.nn.Module):
    """The compiled stack of TransFormerLayer modules."""

    def __init__(self, model_config: Any):
        super().__init__()
        self.layers: list[TransFormerLayer] = nn.ModuleList()
        for layer_idx in range(model_config.num_layers):
            self.layers.append(TransFormerLayer(model_config, layer_idx))

    def forward(
        self,
        x: torch.Tensor,
        rope: torch.Tensor,
        permute_mapping: torch.Tensor,
        inv_permute_mapping: torch.Tensor,
        varlen_handler: VarlenHandler,
        local_attn_handler: FFAHandler,
        modality_dispatcher: ModalityDispatcher,
        cp_split_sizes: List[int],
    ) -> torch.Tensor:
        # Apply layers sequentially; all layers share the same auxiliary inputs.
        for _, layer in enumerate(self.layers):
            x = layer(
                x,
                rope,
                permute_mapping,
                inv_permute_mapping,
                varlen_handler,
                local_attn_handler,
                modality_dispatcher,
                cp_split_sizes,
            )
        return x
|
| 858 |
+
|
| 859 |
+
|
| 860 |
+
@dataclass
class TransformerConfig:
    """Top-level DiT model configuration."""

    hidden_size: int
    video_in_channels: int
    audio_in_channels: int
    text_in_channels: int
    params_dtype: torch.dtype
    # dtype for post-processing layers (DiTModel sets this to float32).
    post_process_dtype: torch.dtype
|
| 869 |
+
|
| 870 |
+
class DiTModel(torch.nn.Module):
    """Multi-modality (video/audio/text) diffusion transformer.

    Embeds packed per-token inputs via the Adapter, runs the TransformerBlock
    stack over modality-grouped tokens, then projects video and audio tokens
    back to their latent channel counts. Text tokens produce zero rows in the
    output buffer.
    """

    # Shape/dtype configuration built from model_config in __init__.
    config: TransformerConfig

    def __init__(self, model_config: Any):
        super().__init__()
        self.config = TransformerConfig(
            hidden_size=model_config.hidden_size,
            video_in_channels=model_config.video_in_channels,
            audio_in_channels=model_config.audio_in_channels,
            text_in_channels=model_config.text_in_channels,
            params_dtype=model_config.params_dtype,
            post_process_dtype=torch.float32,
        )
        adapter_config = AdapterConfig(
            hidden_size=model_config.hidden_size,
            num_attention_heads=model_config.num_heads_q,
            text_in_channels=model_config.text_in_channels,
            video_in_channels=model_config.video_in_channels,
            audio_in_channels=model_config.audio_in_channels,
            params_dtype=torch.float32,  # adapter parameters stay in fp32
        )
        self.adapter: Adapter = Adapter(adapter_config)
        self.block: TransformerBlock = TransformerBlock(model_config=model_config)
        # Per-modality output heads; linear layers are explicitly fp32.
        self.final_norm_video = MultiModalityRMSNorm(self.config.hidden_size)
        self.final_norm_audio = MultiModalityRMSNorm(self.config.hidden_size)
        self.final_linear_video = nn.Linear(
            self.config.hidden_size, self.config.video_in_channels, bias=False, dtype=torch.float32
        )
        self.final_linear_audio = nn.Linear(
            self.config.hidden_size, self.config.audio_in_channels, bias=False, dtype=torch.float32
        )

    def forward(
        self,
        x: torch.Tensor,
        coords_mapping: torch.Tensor,
        modality_mapping: torch.Tensor,
        varlen_handler: VarlenHandler,
        local_attn_handler: FFAHandler,
    ):
        """Run one denoising pass over a packed multi-modal token sequence.

        x, coords_mapping and modality_mapping are token-aligned; they are
        dispatched through the ulysses scheduler (context-parallel sharding —
        presumably along the sequence dimension; confirm against the
        scheduler) and re-gathered before returning.
        """
        # Shard inputs across the context-parallel group.
        x = ulysses_scheduler().dispatch(x)
        coords_mapping = ulysses_scheduler().dispatch(coords_mapping)
        modality_mapping = ulysses_scheduler().dispatch(modality_mapping)
        cp_split_sizes = ulysses_scheduler().cp_split_sizes

        # Group tokens by modality (3 modalities) into contiguous runs; the
        # permutation is inverted after the transformer block.
        modality_dispatcher = ModalityDispatcher(modality_mapping, 3)
        permute_mapping, inv_permute_mapping = modality_dispatcher.permute_mapping, modality_dispatcher.inv_permute_mapping
        video_mask = modality_mapping == Modality.VIDEO
        audio_mask = modality_mapping == Modality.AUDIO
        text_mask = modality_mapping == Modality.TEXT

        # Adapter embeds tokens (fp32) and produces rotary embeddings.
        x, rope = self.adapter(x, coords_mapping, video_mask, audio_mask, text_mask)
        x = x.to(self.config.params_dtype)
        x = ModalityDispatcher.permute(x, permute_mapping)
        x = self.block(
            x,
            rope,
            permute_mapping=permute_mapping,
            inv_permute_mapping=inv_permute_mapping,
            varlen_handler=varlen_handler,
            local_attn_handler=local_attn_handler,
            modality_dispatcher=modality_dispatcher,
            cp_split_sizes=cp_split_sizes,
        )
        x = ModalityDispatcher.inv_permute(x, inv_permute_mapping)

        # Project each modality back to its latent channels via its own head.
        x_video = x[video_mask].to(self.final_norm_video.weight.dtype)
        x_video = self.final_norm_video(x_video)
        x_video = self.final_linear_video(x_video)

        x_audio = x[audio_mask].to(self.final_norm_audio.weight.dtype)
        x_audio = self.final_norm_audio(x_audio)
        x_audio = self.final_linear_audio(x_audio)

        # Scatter per-modality outputs into one zero-initialized buffer padded
        # to the wider of the two channel counts; text rows remain zero.
        x_out = torch.zeros(
            x.shape[0], max(self.config.video_in_channels, self.config.audio_in_channels), device=x.device, dtype=x.dtype
        )
        x_out[video_mask, : self.config.video_in_channels] = x_video
        x_out[audio_mask, : self.config.audio_in_channels] = x_audio
        x_out = ulysses_scheduler().undispatch(x_out)
        return x_out
|
inference/model/sa_audio/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .sa_audio_model import SAAudioFeatureExtractor
|
| 2 |
+
from .sa_audio_module import (
|
| 3 |
+
AudioAutoencoder,
|
| 4 |
+
OobleckDecoder,
|
| 5 |
+
OobleckEncoder,
|
| 6 |
+
VAEBottleneck,
|
| 7 |
+
create_autoencoder_from_config,
|
| 8 |
+
create_bottleneck_from_config,
|
| 9 |
+
create_decoder_from_config,
|
| 10 |
+
create_encoder_from_config,
|
| 11 |
+
create_model_from_config,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
__all__ = [
|
| 15 |
+
"SAAudioFeatureExtractor",
|
| 16 |
+
"AudioAutoencoder",
|
| 17 |
+
"OobleckDecoder",
|
| 18 |
+
"OobleckEncoder",
|
| 19 |
+
"VAEBottleneck",
|
| 20 |
+
"create_autoencoder_from_config",
|
| 21 |
+
"create_bottleneck_from_config",
|
| 22 |
+
"create_decoder_from_config",
|
| 23 |
+
"create_encoder_from_config",
|
| 24 |
+
"create_model_from_config",
|
| 25 |
+
]
|
inference/model/sa_audio/sa_audio_model.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import json
|
| 16 |
+
import os
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
from safetensors.torch import load_file
|
| 21 |
+
|
| 22 |
+
# Set env vars for local T5 loading
|
| 23 |
+
os.environ["TRANSFORMERS_OFFLINE"] = "1"
|
| 24 |
+
os.environ["HF_HUB_OFFLINE"] = "1"
|
| 25 |
+
|
| 26 |
+
from .sa_audio_module import create_model_from_config
|
| 27 |
+
|
| 28 |
+
from inference.utils import print_rank_0
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class SAAudioFeatureExtractor:
    """Stable Audio VAE wrapper that loads the model once and reuses it.

    Only the VAE weights (checkpoint prefix ``pretransform.model.``) are
    loaded; the T5 text encoder and the diffusion model are skipped.
    """

    def __init__(self, device, model_path):
        """Load the VAE from *model_path* (a local directory) onto *device*."""
        self.device = device
        self.vae_model, self.sample_rate = self._get_vae_only(model_path)
        self.resampler = None  # Will be initialized when needed

    def _get_vae_only(self, model_path):
        """Load VAE only, skip T5 and diffusion model.

        Returns:
            tuple: (vae_model, sample_rate).

        Raises:
            ValueError: if *model_path* is not a local directory.
            RuntimeError: if loading from the local directory fails.
        """
        if not (isinstance(model_path, str) and Path(model_path).is_dir()):
            # Bug fix: this branch previously only printed a message and fell
            # through, implicitly returning None, so __init__ crashed with an
            # opaque "cannot unpack non-iterable NoneType" error. Fail loudly.
            raise ValueError(f"Non-local path is not supported in audio model loading: {model_path!r}")

        try:
            # Read the full Stable Audio config and carve out the VAE part.
            model_config_path = os.path.join(model_path, "model_config.json")
            with open(model_config_path) as f:
                full_config = json.load(f)

            vae_config = full_config["model"]["pretransform"]["config"]
            sample_rate = full_config["sample_rate"]

            # Rebuild config structure expected by create_autoencoder_from_config
            autoencoder_config = {
                "model_type": "autoencoder",
                "sample_rate": sample_rate,  # sample_rate is required
                "model": vae_config,  # create_autoencoder_from_config expects key "model"
            }

            vae_model = create_model_from_config(autoencoder_config)

            # Load weights
            weights_path = Path(model_path) / "model.safetensors"
            if not weights_path.exists():
                raise FileNotFoundError(f"Weight file does not exist: {weights_path}")

            # Load the full state dict, then keep only VAE-related weights
            # (prefix: pretransform.model.), stripping the prefix.
            full_state_dict = load_file(weights_path, device=str(self.device))
            prefix = "pretransform.model."
            vae_state_dict = {
                key[len(prefix):]: value
                for key, value in full_state_dict.items()
                if key.startswith(prefix)
            }

            # Print diagnostics before the strict load so a key mismatch is
            # explainable from the logs.
            model_keys = set(vae_model.state_dict().keys())
            vae_keys = set(vae_state_dict.keys())
            missing_keys = model_keys - vae_keys
            extra_keys = vae_keys - model_keys

            if missing_keys:
                print_rank_0(f"Missing keys ({len(missing_keys)}):")
                for key in list(missing_keys)[:5]:
                    print_rank_0(f" - {key}")

            if extra_keys:
                print_rank_0(f"Unexpected keys ({len(extra_keys)}):")
                for key in list(extra_keys)[:5]:
                    print_rank_0(f" + {key}")

            # Load VAE weights (strict) and move the model to the target device.
            vae_model.load_state_dict(vae_state_dict)
            vae_model.to(self.device)

            return vae_model, sample_rate

        except Exception as e:
            print_rank_0(f"audio model loading failed: {e}")
            raise RuntimeError(
                "Failed to load VAE-only Stable Audio model from local path"
            ) from e

    def decode(self, latents):
        """Decode latents to a waveform, without autograd bookkeeping."""
        with torch.no_grad():
            waveform_out = self.vae_model.decode(latents)
        return waveform_out

    def encode(self, waveform):
        """Encode a waveform to latents, without autograd bookkeeping."""
        with torch.no_grad():
            latents = self.vae_model.encode(waveform)
        return latents
|
inference/model/sa_audio/sa_audio_module.py
ADDED
|
@@ -0,0 +1,478 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
from typing import Any, Dict, Literal
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch import nn
|
| 11 |
+
from torch.nn import functional as F
|
| 12 |
+
from torch.nn.utils import weight_norm
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def snake_beta(x, alpha, beta):
    """Snake activation: x + sin(x * alpha)^2 / (beta + eps).

    alpha controls the frequency of the periodic term, beta its magnitude;
    the small epsilon guards against division by zero.
    """
    sin_term = torch.sin(x * alpha)
    return x + sin_term * sin_term / (beta + 1e-9)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class SnakeBeta(nn.Module):
    """Snake activation with learnable per-channel alpha and beta.

    When alpha_logscale is True the stored parameters are log-values
    (initialized to 0, i.e. exp(0) == 1) and are exponentiated in forward.
    """

    # Adapted from BigVGAN activation.
    def __init__(
        self,
        in_features: int,
        alpha: float = 1.0,
        alpha_trainable: bool = True,
        alpha_logscale: bool = True,
    ):
        super().__init__()
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:
            # NOTE(review): `zeros * alpha` is always zeros — the multiply is a
            # no-op retained from the upstream code.
            self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
            self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
        else:
            # Both parameters initialize to `alpha` (beta has no separate init).
            self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
            self.beta = nn.Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

    def forward(self, x):
        # Broadcast (C,) -> (1, C, 1) so parameters apply per channel;
        # assumes x is (batch, channels, time) — TODO confirm with callers.
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        return snake_beta(x, alpha, beta)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def vae_sample(mean, scale):
    """Reparameterized sample from a diagonal Gaussian plus its KL to N(0, I).

    The standard deviation comes from softplus(scale) with a small floor so it
    is always strictly positive. Returns (latents, kl) where kl is summed over
    dim 1 and averaged over the batch.
    """
    stdev = F.softplus(scale) + 1e-4
    var = stdev.square()
    logvar = var.log()
    noise = torch.randn_like(mean)
    latents = mean + noise * stdev
    kl = (mean.square() + var - logvar - 1).sum(1).mean()
    return latents, kl
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class VAEBottleneck(nn.Module):
    """VAE bottleneck: splits channels into (mean, scale) and samples.

    encode halves the channel dimension; decode is the identity (latents are
    already in sample space).
    """

    def __init__(self):
        super().__init__()

    def encode(self, x, return_info=False, **kwargs):
        """Sample latents from x's (mean, scale) halves; optionally return {'kl': ...}."""
        mean, scale = x.chunk(2, dim=1)
        latents, kl = vae_sample(mean, scale)
        if return_info:
            return latents, {"kl": kl}
        return latents

    def decode(self, x):
        """Identity: sampled latents decode to themselves."""
        return x
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def WNConv1d(*args, **kwargs):
    """Build a Conv1d with weight normalization applied to its weight."""
    conv = nn.Conv1d(*args, **kwargs)
    return weight_norm(conv)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def WNConvTranspose1d(*args, **kwargs):
    """Build a ConvTranspose1d with weight normalization applied to its weight."""
    deconv = nn.ConvTranspose1d(*args, **kwargs)
    return weight_norm(deconv)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def checkpoint(function, *args, **kwargs):
    """Activation checkpointing with non-reentrant mode unless overridden."""
    if "use_reentrant" not in kwargs:
        kwargs["use_reentrant"] = False
    return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def get_activation(
    activation: Literal["elu", "snake", "none"], antialias: bool = False, channels=None
) -> nn.Module:
    """Build the requested activation module.

    channels is only needed for "snake" (per-channel parameters); antialias
    is not implemented and always raises.
    """
    if antialias:
        raise NotImplementedError("antialias activation is not supported in sa_audio")

    factories = {
        "elu": nn.ELU,
        "snake": lambda: SnakeBeta(channels),
        "none": nn.Identity,
    }
    try:
        return factories[activation]()
    except KeyError:
        raise ValueError(f"Unknown activation {activation}") from None
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class ResidualUnit(nn.Module):
    """Dilated residual block: act -> dilated 7-tap conv -> act -> 1x1 conv, plus skip.

    The skip connection requires in_channels == out_channels in practice,
    since the output of self.layers is added to the input.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dilation: int,
        use_snake: bool = False,
        antialias_activation: bool = False,
    ):
        super().__init__()
        # "Same" padding for a kernel of 7 at the given dilation.
        padding = (dilation * (7 - 1)) // 2
        self.layers = nn.Sequential(
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=out_channels,
            ),
            WNConv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=7,
                dilation=dilation,
                padding=padding,
            ),
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=out_channels,
            ),
            WNConv1d(in_channels=out_channels, out_channels=out_channels, kernel_size=1),
        )

    def forward(self, x):
        # Use activation checkpointing during training to save memory.
        if self.training:
            y = checkpoint(self.layers, x)
        else:
            y = self.layers(x)
        return y + x
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class EncoderBlock(nn.Module):
    """Downsampling encoder stage: three residual units, then a strided conv."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        use_snake: bool = False,
        antialias_activation: bool = False,
    ):
        super().__init__()
        # Dilations 1/3/9 grow the receptive field before downsampling; the
        # final strided conv performs the actual temporal downsampling.
        self.layers = nn.Sequential(
            ResidualUnit(in_channels, in_channels, 1, use_snake=use_snake),
            ResidualUnit(in_channels, in_channels, 3, use_snake=use_snake),
            ResidualUnit(in_channels, in_channels, 9, use_snake=use_snake),
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=in_channels,
            ),
            WNConv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
        )

    def forward(self, x):
        return self.layers(x)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
class DecoderBlock(nn.Module):
    """Upsampling decoder stage: activation, upsample, three residual units.

    The upsample is either a transposed conv or (when use_nearest_upsample)
    nearest-neighbor interpolation followed by a stride-1 conv.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        use_snake: bool = False,
        antialias_activation: bool = False,
        use_nearest_upsample: bool = False,
    ):
        super().__init__()

        if use_nearest_upsample:
            # Nearest-neighbor upsample by `stride`, then a same-padded conv.
            upsample_layer = nn.Sequential(
                nn.Upsample(scale_factor=stride, mode="nearest"),
                WNConv1d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=2 * stride,
                    stride=1,
                    bias=False,
                    padding="same",
                ),
            )
        else:
            upsample_layer = WNConvTranspose1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            )

        # Mirror of EncoderBlock: upsample first, then dilations 1/3/9.
        self.layers = nn.Sequential(
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=in_channels,
            ),
            upsample_layer,
            ResidualUnit(out_channels, out_channels, 1, use_snake=use_snake),
            ResidualUnit(out_channels, out_channels, 3, use_snake=use_snake),
            ResidualUnit(out_channels, out_channels, 9, use_snake=use_snake),
        )

    def forward(self, x):
        return self.layers(x)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
class OobleckEncoder(nn.Module):
    """Oobleck waveform encoder: stem conv, strided EncoderBlocks, latent head.

    The total downsampling factor is the product of `strides`; channel widths
    scale as channels * c_mults per stage.
    """

    def __init__(
        self,
        in_channels: int = 2,
        channels: int = 128,
        latent_dim: int = 32,
        c_mults=[1, 2, 4, 8],
        strides=[2, 4, 8, 8],
        use_snake: bool = False,
        antialias_activation: bool = False,
    ):
        super().__init__()

        # Prepend the stem multiplier. NOTE: the mutable list defaults are
        # never mutated in place ([1] + c_mults builds a fresh list).
        c_mults = [1] + c_mults
        depth = len(c_mults)

        # Stem: 7-tap conv from input channels to the first stage width.
        layers = [
            WNConv1d(
                in_channels=in_channels,
                out_channels=c_mults[0] * channels,
                kernel_size=7,
                padding=3,
            )
        ]

        # One downsampling block per stride.
        for i in range(depth - 1):
            layers.append(
                EncoderBlock(
                    in_channels=c_mults[i] * channels,
                    out_channels=c_mults[i + 1] * channels,
                    stride=strides[i],
                    use_snake=use_snake,
                )
            )

        # Latent head: final activation plus a 3-tap conv down to latent_dim.
        layers.extend(
            [
                get_activation(
                    "snake" if use_snake else "elu",
                    antialias=antialias_activation,
                    channels=c_mults[-1] * channels,
                ),
                WNConv1d(
                    in_channels=c_mults[-1] * channels,
                    out_channels=latent_dim,
                    kernel_size=3,
                    padding=1,
                ),
            ]
        )

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
class OobleckDecoder(nn.Module):
    """Oobleck waveform decoder: latent stem, DecoderBlocks, waveform head.

    Structural mirror of OobleckEncoder — blocks run from the deepest channel
    multiplier back down, upsampling by `strides` in reverse order. The final
    tanh (when final_tanh) bounds the waveform to [-1, 1].
    """

    def __init__(
        self,
        out_channels: int = 2,
        channels: int = 128,
        latent_dim: int = 32,
        c_mults=[1, 2, 4, 8],
        strides=[2, 4, 8, 8],
        use_snake: bool = False,
        antialias_activation: bool = False,
        use_nearest_upsample: bool = False,
        final_tanh: bool = True,
    ):
        super().__init__()

        # Prepend the stem multiplier (defaults are not mutated in place).
        c_mults = [1] + c_mults
        depth = len(c_mults)

        # Stem: 7-tap conv from latent_dim up to the deepest stage width.
        layers = [
            WNConv1d(
                in_channels=latent_dim,
                out_channels=c_mults[-1] * channels,
                kernel_size=7,
                padding=3,
            )
        ]

        # Upsampling blocks, deepest stage first.
        for i in range(depth - 1, 0, -1):
            layers.append(
                DecoderBlock(
                    in_channels=c_mults[i] * channels,
                    out_channels=c_mults[i - 1] * channels,
                    stride=strides[i - 1],
                    use_snake=use_snake,
                    antialias_activation=antialias_activation,
                    use_nearest_upsample=use_nearest_upsample,
                )
            )

        # Waveform head: activation, bias-free 7-tap conv, optional tanh.
        layers.extend(
            [
                get_activation(
                    "snake" if use_snake else "elu",
                    antialias=antialias_activation,
                    channels=c_mults[0] * channels,
                ),
                WNConv1d(
                    in_channels=c_mults[0] * channels,
                    out_channels=out_channels,
                    kernel_size=7,
                    padding=3,
                    bias=False,
                ),
                nn.Tanh() if final_tanh else nn.Identity(),
            ]
        )

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
class AudioAutoencoder(nn.Module):
|
| 345 |
+
def __init__(
|
| 346 |
+
self,
|
| 347 |
+
encoder: nn.Module,
|
| 348 |
+
decoder: nn.Module,
|
| 349 |
+
latent_dim: int,
|
| 350 |
+
downsampling_ratio: int,
|
| 351 |
+
sample_rate: int,
|
| 352 |
+
io_channels: int = 2,
|
| 353 |
+
bottleneck: nn.Module | None = None,
|
| 354 |
+
in_channels: int | None = None,
|
| 355 |
+
out_channels: int | None = None,
|
| 356 |
+
soft_clip: bool = False,
|
| 357 |
+
):
|
| 358 |
+
super().__init__()
|
| 359 |
+
self.downsampling_ratio = downsampling_ratio
|
| 360 |
+
self.sample_rate = sample_rate
|
| 361 |
+
self.latent_dim = latent_dim
|
| 362 |
+
self.io_channels = io_channels
|
| 363 |
+
self.in_channels = in_channels if in_channels is not None else io_channels
|
| 364 |
+
self.out_channels = out_channels if out_channels is not None else io_channels
|
| 365 |
+
self.bottleneck = bottleneck
|
| 366 |
+
self.encoder = encoder
|
| 367 |
+
self.decoder = decoder
|
| 368 |
+
self.soft_clip = soft_clip
|
| 369 |
+
|
| 370 |
+
def encode(self, audio, skip_bottleneck: bool = False, return_info: bool = False, **kwargs):
|
| 371 |
+
info = {}
|
| 372 |
+
latents = self.encoder(audio)
|
| 373 |
+
info["pre_bottleneck_latents"] = latents
|
| 374 |
+
|
| 375 |
+
if self.bottleneck is not None and not skip_bottleneck:
|
| 376 |
+
latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs)
|
| 377 |
+
info.update(bottleneck_info)
|
| 378 |
+
|
| 379 |
+
if return_info:
|
| 380 |
+
return latents, info
|
| 381 |
+
return latents
|
| 382 |
+
|
| 383 |
+
def decode(self, latents, skip_bottleneck: bool = False, **kwargs):
|
| 384 |
+
if self.bottleneck is not None and not skip_bottleneck:
|
| 385 |
+
latents = self.bottleneck.decode(latents)
|
| 386 |
+
decoded = self.decoder(latents, **kwargs)
|
| 387 |
+
if self.soft_clip:
|
| 388 |
+
decoded = torch.tanh(decoded)
|
| 389 |
+
return decoded
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
# AE factories
|
| 393 |
+
|
| 394 |
+
def create_encoder_from_config(encoder_config: Dict[str, Any]):
    """Instantiate an OobleckEncoder from its config dict.

    Expects {"type": "oobleck", "config": {...}, "requires_grad": bool?};
    any other type is rejected.
    """
    encoder_type = encoder_config.get("type", None)
    assert encoder_type is not None, "Encoder type must be specified"
    if encoder_type != "oobleck":
        raise ValueError(f"Only encoder type 'oobleck' is supported, got: {encoder_type}")

    encoder = OobleckEncoder(**encoder_config["config"])
    if not encoder_config.get("requires_grad", True):
        # Freeze all encoder parameters.
        encoder.requires_grad_(False)
    return encoder
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def create_decoder_from_config(decoder_config: Dict[str, Any]):
    """Instantiate an OobleckDecoder from its config dict.

    Expects {"type": "oobleck", "config": {...}, "requires_grad": bool?};
    any other type is rejected.
    """
    decoder_type = decoder_config.get("type", None)
    assert decoder_type is not None, "Decoder type must be specified"
    if decoder_type != "oobleck":
        raise ValueError(f"Only decoder type 'oobleck' is supported, got: {decoder_type}")

    decoder = OobleckDecoder(**decoder_config["config"])
    if not decoder_config.get("requires_grad", True):
        # Freeze all decoder parameters.
        decoder.requires_grad_(False)
    return decoder
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
def create_bottleneck_from_config(bottleneck_config: Dict[str, Any]):
    """Instantiate a VAEBottleneck from its config dict; only type 'vae' is supported."""
    bottleneck_type = bottleneck_config.get("type", None)
    assert bottleneck_type is not None, "type must be specified in bottleneck config"

    if bottleneck_type != "vae":
        raise NotImplementedError(
            f"Only bottleneck type 'vae' is supported, got: {bottleneck_type}"
        )

    bottleneck = VAEBottleneck()
    if not bottleneck_config.get("requires_grad", True):
        # Freeze any bottleneck parameters (VAEBottleneck has none, but keep
        # the behavior uniform with the encoder/decoder factories).
        for param in bottleneck.parameters():
            param.requires_grad = False
    return bottleneck
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
def create_autoencoder_from_config(config: Dict[str, Any]):
    """Build an AudioAutoencoder from a full config dict.

    `config` must contain "model" (the autoencoder sub-config with "encoder",
    "decoder", optional "bottleneck", and required latent_dim /
    downsampling_ratio / io_channels) and a top-level "sample_rate".
    """
    ae_config = config["model"]

    # Nested pretransforms (pretransform inside the VAE config) are rejected.
    if ae_config.get("pretransform") is not None:
        raise NotImplementedError("Nested pretransform is not supported in sa_audio")

    encoder = create_encoder_from_config(ae_config["encoder"])
    decoder = create_decoder_from_config(ae_config["decoder"])

    # Bottleneck is optional; absent/empty config means no bottleneck.
    bottleneck_cfg = ae_config.get("bottleneck")
    bottleneck = create_bottleneck_from_config(bottleneck_cfg) if bottleneck_cfg else None

    # These four fields are mandatory; assert early with a specific message.
    latent_dim = ae_config.get("latent_dim")
    assert latent_dim is not None, "latent_dim must be specified in model config"
    downsampling_ratio = ae_config.get("downsampling_ratio")
    assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
    io_channels = ae_config.get("io_channels")
    assert io_channels is not None, "io_channels must be specified in model config"
    sample_rate = config.get("sample_rate")
    assert sample_rate is not None, "sample_rate must be specified in model config"

    return AudioAutoencoder(
        encoder=encoder,
        decoder=decoder,
        latent_dim=latent_dim,
        downsampling_ratio=downsampling_ratio,
        sample_rate=sample_rate,
        io_channels=io_channels,
        bottleneck=bottleneck,
        in_channels=ae_config.get("in_channels"),
        out_channels=ae_config.get("out_channels"),
        soft_clip=ae_config["decoder"].get("soft_clip", False),
    )
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
def create_model_from_config(model_config: Dict[str, Any]):
    """Top-level factory: dispatch on model_type (only 'autoencoder' supported)."""
    model_type = model_config.get("model_type", None)
    assert model_type is not None, "model_type must be specified in model config"

    if model_type == "autoencoder":
        return create_autoencoder_from_config(model_config)
    raise NotImplementedError(f"Only 'autoencoder' is supported, got: {model_type}")
|
inference/model/t5_gemma/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .t5_gemma_model import get_t5_gemma_embedding
|
| 2 |
+
|
| 3 |
+
__all__ = ["get_t5_gemma_embedding"]
|
inference/model/t5_gemma/t5_gemma_model.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from transformers import AutoTokenizer
|
| 7 |
+
from transformers.models.t5gemma import T5GemmaEncoderModel
|
| 8 |
+
|
| 9 |
+
from inference.common import CPUOffloadWrapper, get_arch_memory
|
| 10 |
+
from inference.utils import env_is_true
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class T5GemmaEncoder:
    """T5-Gemma encoder-only model plus tokenizer for prompt embedding."""

    def __init__(self, model_path: str, device: str, weight_dtype: torch.dtype):
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = T5GemmaEncoderModel.from_pretrained(
            model_path,
            is_encoder_decoder=False,
            dtype=weight_dtype,
        ).to(device)
        # Offload weights to CPU when CPU_OFFLOAD is set or the device reports
        # <= 48 via get_arch_memory (presumably GB of VRAM — confirm units).
        self.model = CPUOffloadWrapper(model, is_cpu_offload=env_is_true("CPU_OFFLOAD") or get_arch_memory() <= 48)

    def encode(self, prompt: str) -> torch.Tensor:
        # Single-prompt batch; returns the last hidden state cast to float16.
        inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
        outputs = self.model(**inputs)
        return outputs["last_hidden_state"].half()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# Process-wide singleton so repeated calls reuse the loaded encoder.
_t5_gemma_cache: Optional[T5GemmaEncoder] = None


def get_t5_gemma_encoder(model_path: str, device: str, weight_dtype: torch.dtype) -> T5GemmaEncoder:
    """Return the cached ``T5GemmaEncoder``, constructing it on first use.

    NOTE(review): the cache ignores the arguments on later calls — a call
    with a different ``model_path`` still returns the first encoder. The
    process is assumed to use a single model; confirm if that changes.
    """
    global _t5_gemma_cache
    if _t5_gemma_cache is not None:
        return _t5_gemma_cache
    _t5_gemma_cache = T5GemmaEncoder(model_path=model_path, device=device, weight_dtype=weight_dtype)
    return _t5_gemma_cache
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@torch.inference_mode()
def get_t5_gemma_embedding(prompt: str, model_path: str, device: str, weight_dtype: torch.dtype) -> torch.Tensor:
    """Embed ``prompt`` with the (cached) T5Gemma encoder.

    Runs under ``torch.inference_mode`` so no autograd state is recorded.
    """
    text_encoder = get_t5_gemma_encoder(model_path=model_path, device=device, weight_dtype=weight_dtype)
    return text_encoder.encode(prompt)
|
inference/model/turbo_vaed/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .turbo_vaed_module import TurboVAED
|
| 2 |
+
from .turbo_vaed_model import get_turbo_vaed
|
| 3 |
+
|
| 4 |
+
__all__ = ["TurboVAED", "get_turbo_vaed"]
|
inference/model/turbo_vaed/turbo_vaed_model.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
from .turbo_vaed_module import TurboVAED
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def get_turbo_vaed(config_path, ckpt_path, device="cuda", weight_dtype=torch.float32) -> TurboVAED:
    """Build a TurboVAED model from a JSON config and load its checkpoint.

    Args:
        config_path: Path to the JSON model configuration.
        ckpt_path: Path to a torch checkpoint containing ``ema_state_dict``
            (preferred) or ``state_dict``.
        device: Target device for the loaded model.
        weight_dtype: Dtype to cast the weights to.

    Returns:
        The model in eval mode with gradients disabled.

    Raises:
        ValueError: If the checkpoint contains neither ``ema_state_dict``
            nor ``state_dict``.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    student = TurboVAED.from_config(config)

    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources (weights_only=True is safer if the
    # checkpoint format allows it).
    ckpt = torch.load(ckpt_path, map_location="cpu")
    # The original asserted on ema_state_dict only, while its message promised
    # "ema_state_dict or state_dict"; honor the message with a fallback.
    if "ema_state_dict" in ckpt:
        state_dict = ckpt["ema_state_dict"]
    elif "state_dict" in ckpt:
        state_dict = ckpt["state_dict"]
    else:
        raise ValueError("ckpt must contain ema_state_dict or state_dict")

    # Strip a DDP "module." prefix if present so keys match the bare model.
    state_dict = {(key[7:] if key.startswith("module.") else key): value for key, value in state_dict.items()}

    missing, _ = student.load_state_dict(state_dict, strict=False)
    if len(missing) > 0:
        sample_key = next(iter(state_dict.keys()))
        # Checkpoints saved from the decoder alone lack "decoder."/"encoder."
        # prefixes; retry loading them directly into the decoder submodule.
        if not sample_key.startswith("decoder.") and not sample_key.startswith("encoder."):
            student.decoder.load_state_dict(state_dict, strict=False)

    student = student.to(device, dtype=weight_dtype)
    student.eval()
    student.requires_grad_(False)
    return student
|
inference/model/turbo_vaed/turbo_vaed_module.py
ADDED
|
@@ -0,0 +1,1039 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from typing import List, Optional, Tuple, Union
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 20 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 21 |
+
from einops import rearrange
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
__all__ = ["TurboVAED"]
|
| 25 |
+
|
| 26 |
+
# Supported activation names -> module classes; "swish" is an alias for SiLU.
ACT2CLS = {"swish": nn.SiLU, "silu": nn.SiLU, "mish": nn.Mish, "gelu": nn.GELU, "relu": nn.ReLU}


def get_activation(act_fn: str) -> nn.Module:
    """Return a freshly constructed activation module for ``act_fn``.

    The lookup is case-insensitive.

    Raises:
        ValueError: If ``act_fn`` is not a supported activation name.
    """
    act_fn = act_fn.lower()
    if act_fn in ACT2CLS:
        return ACT2CLS[act_fn]()
    # Fixed the message: the mapping is named ACT2CLS (the original error
    # referred to a nonexistent "ACT2FN").
    raise ValueError(f"activation function {act_fn} not found in ACT2CLS mapping {list(ACT2CLS.keys())}")
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def unpatchify(x, patch_size):
    """Invert the patchify operation on a 5-D video tensor.

    Folds the per-patch channel groups back into spatial resolution
    (Wan-VAE-style layout): the channel axis is split as (C, Pa, Pb),
    with Pb interleaved into height and Pa interleaved into width.

    Args:
        x: Tensor of shape [B, C * patch_size**2, T, H, W].
        patch_size: Patch size used during patchification; 1 is a no-op.

    Returns:
        Tensor of shape [B, C, T, H * patch_size, W * patch_size].

    Raises:
        ValueError: If ``x`` is not 5-dimensional.
    """
    if patch_size == 1:
        return x

    if x.dim() != 5:
        raise ValueError(f"Invalid input shape: {x.shape}")

    b, packed_channels, t, h, w = x.shape
    c = packed_channels // (patch_size * patch_size)

    # Split channels into (C, Pa, Pb), then move Pb next to H and Pa next to W.
    out = x.reshape(b, c, patch_size, patch_size, t, h, w)
    out = out.permute(0, 1, 4, 5, 3, 6, 2).contiguous()
    return out.reshape(b, c, t, h * patch_size, w * patch_size)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class RMSNorm(nn.Module):
    r"""Root-mean-square norm (https://huggingface.co/papers/1910.07467).

    Normalizes by the RMS taken over dim 1 (the channel axis for conv
    feature maps), optionally followed by a learned multiplicative scale.

    Args:
        dim: Shape of the learnable ``weight``; an int is promoted to a
            1-tuple. Only used when ``elementwise_affine`` is True.
        eps: Stabilizer added to the variance before the reciprocal sqrt.
        elementwise_affine: Whether to learn a multiplicative ``weight``.
        bias: Accepted for signature compatibility; not used by this module.
    """

    def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool = False):
        super().__init__()

        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if isinstance(dim, int):
            dim = (dim,)
        self.dim = torch.Size(dim)

        self.weight = None
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        # The mean of squares is computed in fp32 for numerical stability.
        mean_sq = hidden_states.to(torch.float32).pow(2).mean(1, keepdim=True)
        normed = hidden_states * torch.rsqrt(mean_sq + self.eps)

        if self.weight is None:
            return normed.to(orig_dtype)

        # Match the weight's half precision before scaling; with an fp32
        # weight the result intentionally stays fp32 (no cast back).
        if self.weight.dtype in (torch.float16, torch.bfloat16):
            normed = normed.to(self.weight.dtype)
        return normed * self.weight
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class TurboVAEDConv2dSplitUpsampler(nn.Module):
    """Same-channel 2D conv followed by a pixel-shuffle spatial upsample.

    The conv runs at stride 1 with "same" padding; upsampling happens
    afterwards via ``pixel_shuffle`` with factor ``stride[0]``, so the
    channel count must be divisible by ``stride[0] ** 2`` at call time.
    """

    def __init__(
        self,
        in_channels: int,
        kernel_size: Union[int, Tuple[int, int]] = 3,
        stride: Union[int, Tuple[int, int]] = 1,
        upscale_factor: int = 1,
        padding_mode: str = "zeros",
    ):
        super().__init__()

        self.in_channels = in_channels
        self.stride = (stride, stride) if not isinstance(stride, tuple) else stride
        # NOTE(review): upscale_factor is stored but never read — confirm intent.
        self.upscale_factor = upscale_factor
        self.kernel_size = (kernel_size, kernel_size) if not isinstance(kernel_size, tuple) else kernel_size

        # "Same" padding for odd kernels.
        pad = (self.kernel_size[0] // 2, self.kernel_size[1] // 2)
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=self.kernel_size,
            stride=1,
            padding=pad,
            padding_mode=padding_mode,
        )

    @torch.compile
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        out = self.conv(hidden_states)
        # Trade channels for spatial resolution (factor = stride[0]).
        return torch.nn.functional.pixel_shuffle(out, self.stride[0])
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class TurboVAEDCausalConv3d(nn.Module):
    """3D conv with symmetric replicate padding along the time axis.

    Despite the name, only the non-causal mode is implemented (``is_causal``
    must be False): edge frames are replicated on *both* sides of the time
    axis in ``forward`` so the temporal extent is preserved. Spatial dims
    use the conv's own "same" padding.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int, int]] = 3,
        stride: Union[int, Tuple[int, int, int]] = 1,
        dilation: Union[int, Tuple[int, int, int]] = 1,
        groups: int = 1,
        padding_mode: str = "zeros",
        is_causal: bool = False,
    ):
        super().__init__()

        assert is_causal == False  # only the non-causal path is implemented
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.is_causal = is_causal
        if not isinstance(kernel_size, tuple):
            kernel_size = (kernel_size,) * 3
        self.kernel_size = kernel_size

        # Scalar dilation applies only along time; spatial dilation stays 1.
        if not isinstance(dilation, tuple):
            dilation = (dilation, 1, 1)
        if not isinstance(stride, tuple):
            stride = (stride,) * 3
        # No temporal padding here — that is handled in forward.
        spatial_pad = (0, self.kernel_size[1] // 2, self.kernel_size[2] // 2)

        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            self.kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
            padding=spatial_pad,
            padding_mode=padding_mode,
        )

    @torch.compile
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        t_kernel = self.kernel_size[0]

        if t_kernel > 1:
            # Replicate edge frames symmetrically instead of zero-padding time.
            n_pad = (t_kernel - 1) // 2
            front = hidden_states[:, :, :1].repeat(1, 1, n_pad, 1, 1)
            back = hidden_states[:, :, -1:].repeat(1, 1, n_pad, 1, 1)
            hidden_states = torch.cat([front, hidden_states, back], dim=2)

        return self.conv(hidden_states)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
class TurboVAEDCausalDepthwiseSeperableConv3d(nn.Module):
    """Depthwise-separable 3D conv with symmetric replicate time padding.

    A per-channel (depthwise) Conv3d followed by a 1x1x1 pointwise conv
    that mixes channels. Edge frames are replicated on both sides of the
    time axis in ``forward``; spatial dims use zero "same" padding.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int, int]] = 3,
        stride: Union[int, Tuple[int, int, int]] = 1,
        dilation: Union[int, Tuple[int, int, int]] = 1,
        padding_mode: str = "zeros",
        is_causal: bool = True,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        # NOTE(review): is_causal is stored but never changes the padding —
        # the time padding below is always symmetric. Confirm intent.
        self.is_causal = is_causal
        self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size,) * 3
        self.stride = stride if isinstance(stride, tuple) else (stride,) * 3
        # Scalar dilation applies only along time; spatial dilation stays 1.
        self.dilation = dilation if isinstance(dilation, tuple) else (dilation, 1, 1)

        # Spatial "same" padding; temporal padding is done in forward.
        self.padding = (0, self.kernel_size[1] // 2, self.kernel_size[2] // 2)

        # Depthwise stage: each input channel convolved independently.
        self.depthwise_conv = nn.Conv3d(
            in_channels,
            in_channels,
            self.kernel_size,
            stride=self.stride,
            dilation=self.dilation,
            groups=in_channels,
            padding=self.padding,
            padding_mode=padding_mode,
        )

        # Pointwise stage: 1x1x1 conv mixes channels.
        self.pointwise_conv = nn.Conv3d(in_channels, out_channels, kernel_size=1)

    @torch.compile
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        t_kernel = self.kernel_size[0]
        if t_kernel > 1:
            # Replicate edge frames symmetrically instead of zero-padding time.
            n_pad = (t_kernel - 1) // 2
            front = hidden_states[:, :, :1].repeat(1, 1, n_pad, 1, 1)
            back = hidden_states[:, :, -1:].repeat(1, 1, n_pad, 1, 1)
            hidden_states = torch.cat([front, hidden_states, back], dim=2)

        return self.pointwise_conv(self.depthwise_conv(hidden_states))
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
class TurboVAEDResnetBlock3d(nn.Module):
    r"""
    A 3D ResNet block used in the TurboVAED model.

    Structure: norm1 -> act -> conv1 -> norm2 -> act -> dropout -> conv2,
    plus a residual connection (optionally normalized and projected when
    the channel count changes).

    Args:
        in_channels (`int`):
            Number of input channels.
        out_channels (`int`, *optional*):
            Number of output channels. If None, defaults to `in_channels`.
        dropout (`float`, defaults to `0.0`):
            Dropout rate applied between the second activation and conv2.
        eps (`float`, defaults to `1e-6`):
            Epsilon for normalization layers. Note: only used for the
            shortcut norm (`norm3`); `norm1`/`norm2` hard-code `1e-8`.
        elementwise_affine (`bool`, defaults to `False`):
            Whether to enable elementwise affinity in the normalization layers.
        non_linearity (`str`, defaults to `"swish"`):
            Activation function used in the main path.
        is_causal (`bool`, defaults to `True`):
            Passed through to the convolution layers.
        is_upsampler_modified (`bool`, defaults to `False`):
            When True, the *first* activation uses ReLU instead of
            `non_linearity` (the second activation is unchanged).
        is_dw_conv (`bool`, defaults to `False`):
            Use depthwise-separable 3D convs instead of plain causal convs.
        dw_kernel_size (`int`, defaults to `3`):
            Kernel size when `is_dw_conv` is True.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        dropout: float = 0.0,
        eps: float = 1e-6,
        elementwise_affine: bool = False,
        non_linearity: str = "swish",
        is_causal: bool = True,
        is_upsampler_modified: bool = False,
        is_dw_conv: bool = False,
        dw_kernel_size: int = 3,
    ) -> None:
        super().__init__()

        out_channels = out_channels or in_channels

        self.nonlinearity = get_activation(non_linearity)

        # Choose the conv implementation (and its kernel size) up front so
        # conv1/conv2/shortcut all use the same variant.
        self.conv_operation = TurboVAEDCausalConv3d if not is_dw_conv else TurboVAEDCausalDepthwiseSeperableConv3d
        self.kernel_size = 3 if not is_dw_conv else dw_kernel_size

        self.is_upsampler_modified = is_upsampler_modified
        # ReLU substitute for the first activation when is_upsampler_modified.
        self.replace_nonlinearity = get_activation("relu")

        # NOTE(review): norm1/norm2 use a hard-coded eps of 1e-8 and ignore
        # the `eps` argument (only norm3 uses it) — presumably intentional
        # for checkpoint compatibility; confirm before changing.
        self.norm1 = RMSNorm(in_channels, eps=1e-8, elementwise_affine=elementwise_affine)
        self.conv1 = self.conv_operation(
            in_channels=in_channels, out_channels=out_channels, kernel_size=self.kernel_size, is_causal=is_causal
        )

        self.norm2 = RMSNorm(out_channels, eps=1e-8, elementwise_affine=elementwise_affine)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = self.conv_operation(
            in_channels=out_channels, out_channels=out_channels, kernel_size=self.kernel_size, is_causal=is_causal
        )

        # Shortcut projection only when the channel count changes.
        self.norm3 = None
        self.conv_shortcut = None
        if in_channels != out_channels:
            self.norm3 = RMSNorm(in_channels, eps=eps, elementwise_affine=elementwise_affine)
            self.conv_shortcut = self.conv_operation(
                in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal
            )

    @torch.compile
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Apply the residual block to a [B, C, T, H, W] tensor."""
        hidden_states = inputs

        hidden_states = self.norm1(hidden_states)

        # First activation: ReLU when modified for the upsampler path,
        # otherwise the configured nonlinearity.
        if self.is_upsampler_modified:
            hidden_states = self.replace_nonlinearity(hidden_states)
        else:
            hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)

        # Second activation is always the configured nonlinearity.
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)

        hidden_states = self.conv2(hidden_states)

        # Residual path: normalize/project only when channels changed.
        if self.norm3 is not None:
            inputs = self.norm3(inputs)

        if self.conv_shortcut is not None:
            inputs = self.conv_shortcut(inputs)

        hidden_states = hidden_states + inputs
        return hidden_states
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
class TurboVAEDUpsampler3d(nn.Module):
    """Joint spatio-temporal upsampler via channel-to-space reshaping.

    A 3x3x3 conv expands channels by ``prod(stride) / upscale_factor``; the
    forward pass then redistributes those channels onto the time, height and
    width axes (one factor per axis).
    """

    def __init__(
        self,
        in_channels: int,
        stride: Union[int, Tuple[int, int, int]] = 1,
        is_causal: bool = True,
        upscale_factor: int = 1,
        padding_mode: str = "zeros",
    ) -> None:
        super().__init__()

        self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
        self.upscale_factor = upscale_factor

        # BUG FIX: compute from the normalized tuple. The original indexed the
        # raw `stride` argument (`stride[0]`), which raised TypeError whenever
        # an int stride was passed.
        s_t, s_h, s_w = self.stride
        out_channels = (in_channels * s_t * s_h * s_w) // upscale_factor

        self.conv = TurboVAEDCausalConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            is_causal=is_causal,
            padding_mode=padding_mode,
        )

    @torch.compile
    def forward(self, hidden_states: torch.Tensor, is_first_chunk: bool = True) -> torch.Tensor:
        """Upsample a [B, C, T, H, W] tensor by the configured stride factors.

        Args:
            hidden_states: Input of shape [B, C, T, H, W].
            is_first_chunk: When True, drops the first ``stride_t - 1`` output
                frames (extra handling for the replicated first frame).
        """
        batch_size, num_channels, num_frames, height, width = hidden_states.shape

        hidden_states = self.conv(hidden_states)

        # Reshape-and-permute (instead of a 3D pixel shuffle) because the
        # former has better performance on cuda kernels: split channels into
        # (C', s_t, s_h, s_w) and interleave them into T/H/W.
        s_t, s_h, s_w = self.stride
        hidden_states = hidden_states.reshape(batch_size, -1, s_t, s_h, s_w, num_frames, height, width)
        hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
        hidden_states = hidden_states.reshape(batch_size, -1, num_frames * s_t, height * s_h, width * s_w)

        # Slice the first chunk: drop the duplicated leading frames.
        if is_first_chunk:
            hidden_states = hidden_states[:, :, self.stride[0] - 1 :]  # NOTE: extra handling for the first frame

        return hidden_states
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
class WanUpsample(nn.Upsample):
    r"""``nn.Upsample`` variant that preserves the input tensor's dtype.

    The interpolation itself runs in fp32; the result is cast back to the
    input's dtype before returning.

    Args:
        x (torch.Tensor): Input tensor to be upsampled.

    Returns:
        torch.Tensor: Upsampled tensor with the same data type as the input.
    """

    def forward(self, x):
        upsampled = super().forward(x.float())
        return upsampled.type_as(x)
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
class WanResample(nn.Module):
    r"""
    Resampling module for 2D and 3D decoder features.

    Args:
        dim (int): Number of input channels.
        mode (str): The resampling mode:
            - 'upsample2d': 2x nearest-exact spatial upsample + 3x3 conv.
            - 'upsample3d': as 'upsample2d', plus a temporal conv that
              doubles the frame count by splitting channels.
            - anything else: identity (no resampling).
        upsample_out_dim (int, optional): Output channels of the spatial
            conv; defaults to ``dim // 2``.
    """

    def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None:
        super().__init__()
        self.dim = dim
        self.mode = mode

        # Default conv output width to half the input channels.
        if upsample_out_dim is None:
            upsample_out_dim = dim // 2

        if mode in ("upsample2d", "upsample3d"):
            self.resample = nn.Sequential(
                WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
                nn.Conv2d(dim, upsample_out_dim, 3, padding=1),
            )
            if mode == "upsample3d":
                # Temporal conv doubles channels; frames are doubled in forward.
                self.time_conv = TurboVAEDCausalConv3d(dim, dim * 2, (3, 1, 1))
        else:
            self.resample = nn.Identity()

    def forward(self, x, is_first_chunk: bool = True):
        b, c, t, h, w = x.shape
        if self.mode == "upsample3d":
            x = self.time_conv(x)
            # Fold the doubled channel dim into time (frame doubling).
            x = rearrange(x, 'b (n_split c) t h w -> b c (t n_split) h w', n_split=2)
            assert x.shape == (b, c, t * 2, h, w), "x.shape: {}, expected: {}".format(x.shape, (b, c, t * 2, h, w))
            if is_first_chunk:
                # Drop the duplicate produced by the replicated first frame.
                x = x[:, :, 1:]

        # Run the 2D resampler frame by frame.
        x = rearrange(x, "b c t h w -> (b t) c h w")
        x = self.resample(x)
        x = rearrange(x, "(b t) c h w -> b c t h w", b=b)

        return x
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
class TurboVAEDMidBlock3d(nn.Module):
    r"""
    Middle block of the TurboVAED model: a stack of 3D ResNet blocks at a
    constant channel width.

    Args:
        in_channels (`int`):
            Number of input (and output) channels.
        num_layers (`int`, defaults to `1`):
            Number of resnet layers.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        resnet_eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        resnet_act_fn (`str`, defaults to `"swish"`):
            Activation function to use.
        is_causal (`bool`, defaults to `True`):
            Whether this layer behaves causally (future frames depend only
            on past frames) or not.
        is_dw_conv (`bool`, defaults to `False`):
            Use depthwise-separable convs inside the resnet blocks.
        dw_kernel_size (`int`, defaults to `3`):
            Kernel size for the depthwise convs.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int,
        num_layers: int = 1,
        dropout: float = 0.0,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        is_causal: bool = True,
        is_dw_conv: bool = False,
        dw_kernel_size: int = 3,
    ) -> None:
        super().__init__()

        self.resnets = nn.ModuleList(
            [
                TurboVAEDResnetBlock3d(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    dropout=dropout,
                    eps=resnet_eps,
                    non_linearity=resnet_act_fn,
                    is_causal=is_causal,
                    is_dw_conv=is_dw_conv,
                    dw_kernel_size=dw_kernel_size,
                )
                for _ in range(num_layers)
            ]
        )

        self.gradient_checkpointing = False

    @torch.compile
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        r"""Apply the resnet stack, optionally under gradient checkpointing."""
        for resnet in self.resnets:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
            else:
                hidden_states = resnet(hidden_states)

        return hidden_states
|
| 527 |
+
|
| 528 |
+
|
| 529 |
+
class TurboVAEDUpBlock3d(nn.Module):
    r"""
    Up block used in the TurboVAED model.

    Optionally projects the channel count with an input resnet, upsamples
    (spatially, and temporally unless `spatio_only`), then applies a stack of
    resnet layers.

    Args:
        in_channels (`int`):
            Number of input channels.
        out_channels (`int`, *optional*):
            Number of output channels. If None, defaults to `in_channels`.
        num_layers (`int`, defaults to `1`):
            Number of resnet layers.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        resnet_eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        resnet_act_fn (`str`, defaults to `"swish"`):
            Activation function to use.
        spatio_temporal_scale (`bool`, defaults to `True`):
            Whether to include an upsampling layer. If not used, output
            dimensions are the same as the input dimensions.
        is_causal (`bool`, defaults to `True`):
            Whether this layer behaves causally (future frames depend only on past frames) or not.
        is_dw_conv (`bool`, defaults to `False`):
            Passed through to the resnet blocks (depthwise-conv variant).
        dw_kernel_size (`int`, defaults to `3`):
            Kernel size forwarded to the resnet blocks when `is_dw_conv` is set.
        spatio_only (`bool`, defaults to `False`):
            When True, the upsampler runs in "upsample2d" mode (spatial only);
            otherwise "upsample3d" (spatial and temporal).
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        spatio_temporal_scale: bool = True,
        is_causal: bool = True,
        is_dw_conv: bool = False,
        dw_kernel_size: int = 3,
        spatio_only: bool = False,
    ):
        super().__init__()

        out_channels = out_channels or in_channels

        # A channel-projecting resnet is only needed when the width changes.
        if in_channels != out_channels:
            self.conv_in = TurboVAEDResnetBlock3d(
                in_channels=in_channels,
                out_channels=out_channels,
                dropout=dropout,
                eps=resnet_eps,
                non_linearity=resnet_act_fn,
                is_causal=is_causal,
                is_dw_conv=is_dw_conv,
                dw_kernel_size=dw_kernel_size,
            )
        else:
            self.conv_in = None

        if spatio_temporal_scale:
            resample_mode = "upsample2d" if spatio_only else "upsample3d"
            self.upsamplers = nn.ModuleList(
                [WanResample(dim=out_channels, mode=resample_mode, upsample_out_dim=out_channels)]
            )
        else:
            self.upsamplers = None

        self.resnets = nn.ModuleList(
            [
                TurboVAEDResnetBlock3d(
                    in_channels=out_channels,
                    out_channels=out_channels,
                    dropout=dropout,
                    eps=resnet_eps,
                    non_linearity=resnet_act_fn,
                    is_causal=is_causal,
                    is_dw_conv=is_dw_conv,
                    dw_kernel_size=dw_kernel_size,
                    is_upsampler_modified=(spatio_temporal_scale),
                )
                for _ in range(num_layers)
            ]
        )

        self.gradient_checkpointing = False

    @torch.compile
    def forward(self, hidden_states: torch.Tensor, is_first_chunk: bool) -> torch.Tensor:
        r"""Project, upsample, then run the resnet stack over `hidden_states`."""
        if self.conv_in is not None:
            hidden_states = self.conv_in(hidden_states)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, is_first_chunk=is_first_chunk)

        for resnet in self.resnets:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
            else:
                hidden_states = resnet(hidden_states)

        return hidden_states
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
class TurboVAEDDecoder3d(nn.Module):
    r"""
    The `TurboVAEDDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output
    sample.

    Args:
        in_channels (`int`, defaults to 128):
            Number of latent channels.
        out_channels (`int`, defaults to 3):
            Number of output channels.
        block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
            The number of output channels for each block (encoder order; reversed internally).
        spatio_temporal_scaling (`Tuple[bool, ...]`, defaults to `(True, True, True, False)`):
            Whether a block should contain spatio-temporal upscaling layers or not.
        layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`):
            The number of layers per block. After reversal, entry 0 feeds the mid block and
            entries 1..N feed the up blocks.
        patch_size (`int`, defaults to `4`):
            The size of spatial patches. NOTE(review): `__init__` asserts `patch_size == 2`,
            so the documented default of 4 cannot actually be used — confirm intended default.
        patch_size_t (`int`, defaults to `1`):
            The size of temporal patches.
        resnet_norm_eps (`float`, defaults to `1e-6`):
            Epsilon value for ResNet normalization layers.
        is_causal (`bool`, defaults to `False`):
            Whether this layer behaves causally (future frames depend only on past frames) or not.
        decoder_is_dw_conv (`Tuple[bool, ...]`, defaults to all-False):
            Per-block depthwise-conv switch; entry 0 (after reversal) is for the mid block.
        decoder_dw_kernel_size (`int`, defaults to `3`):
            Kernel size used when depthwise convolutions are enabled.
        spatio_only (`Tuple[bool, ...]`, defaults to all-False):
            Per-up-block switch for spatial-only (vs spatio-temporal) upsampling.
        upsampling (`bool`, defaults to `False`):
            Stored on the instance; not read inside this class as far as visible here.
        use_unpatchify (`bool`, defaults to `False`):
            When True, `conv_out` emits `out_channels * patch_size**2` channels and
            `unpatchify` restores spatial resolution; when False a learned 2x2 split
            upsampler (`upsampler2d_1`) is used instead.
    """

    def __init__(
        self,
        in_channels: int = 128,
        out_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
        spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
        layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
        patch_size: int = 4,
        patch_size_t: int = 1,
        resnet_norm_eps: float = 1e-6,
        is_causal: bool = False,
        decoder_is_dw_conv: Tuple[bool, ...] = (False, False, False, False, False),
        decoder_dw_kernel_size: int = 3,
        spatio_only: Tuple[bool, ...] = (False, False, False, False),
        upsampling: bool = False,
        use_unpatchify: bool = False,
    ) -> None:
        super().__init__()

        self.patch_size = patch_size
        self.patch_size_t = patch_size_t
        self.out_channels = out_channels

        self.upsampling = upsampling
        self.use_unpatchify = use_unpatchify

        # Config tuples are given in encoder order; the decoder walks them
        # from the deepest (narrowest spatial) stage outwards, so reverse all of them.
        block_out_channels = tuple(reversed(block_out_channels))
        spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling))
        layers_per_block = tuple(reversed(layers_per_block))
        decoder_is_dw_conv = tuple(reversed(decoder_is_dw_conv))
        spatio_only = tuple(reversed(spatio_only))
        output_channel = block_out_channels[0]

        # Lift latents to the widest feature width before the mid block.
        self.conv_in = TurboVAEDCausalConv3d(
            in_channels=in_channels, out_channels=output_channel, kernel_size=3, stride=1, is_causal=is_causal
        )

        self.mid_block = TurboVAEDMidBlock3d(
            in_channels=output_channel,
            num_layers=layers_per_block[0],
            resnet_eps=resnet_norm_eps,
            is_causal=is_causal,
            is_dw_conv=decoder_is_dw_conv[0],
            dw_kernel_size=decoder_dw_kernel_size,
        )

        # up blocks
        num_block_out_channels = len(block_out_channels)
        self.up_blocks = nn.ModuleList([])
        for i in range(num_block_out_channels):
            input_channel = output_channel
            output_channel = block_out_channels[i]

            # layers_per_block / decoder_is_dw_conv have one extra leading entry
            # (consumed by the mid block above), hence the i + 1 offset.
            up_block = TurboVAEDUpBlock3d(
                in_channels=input_channel,
                out_channels=output_channel,
                num_layers=layers_per_block[i + 1],
                resnet_eps=resnet_norm_eps,
                spatio_temporal_scale=spatio_temporal_scaling[i],
                is_causal=is_causal,
                is_dw_conv=decoder_is_dw_conv[i + 1],
                dw_kernel_size=decoder_dw_kernel_size,
                spatio_only=spatio_only[i],
            )

            self.up_blocks.append(up_block)

        # out
        # Only a 2x spatial patch is supported by the output head below.
        assert self.patch_size == 2
        if not self.use_unpatchify:
            # Learned 2x upsampling path: norm -> act -> per-frame split upsampler.
            self.norm_up_1 = RMSNorm(output_channel, eps=1e-8, elementwise_affine=False)
            self.upsampler2d_1 = TurboVAEDConv2dSplitUpsampler(in_channels=output_channel, kernel_size=3, stride=(2, 2))
            # The split upsampler trades channels for spatial resolution (2x2).
            output_channel = output_channel // (2 * 2)

        self.conv_act = nn.SiLU()

        # When use_unpatchify=True, conv_out outputs more channels (out_channels * patch_size^2)
        # and unpatchify will recover the spatial resolution
        conv_out_channels = self.out_channels
        if self.use_unpatchify and self.patch_size >= 2:
            conv_out_channels = self.out_channels * self.patch_size * self.patch_size

        self.conv_out = TurboVAEDCausalConv3d(
            in_channels=output_channel, out_channels=conv_out_channels, kernel_size=3, stride=1, is_causal=is_causal
        )

        self.gradient_checkpointing = False

    @torch.compile
    def forward(self, hidden_states: torch.Tensor, is_first_chunk: bool) -> torch.Tensor:
        r"""Decode latents `hidden_states` (B, C, T, H, W layout per conv usage below) into pixels.

        `is_first_chunk` is forwarded to the up blocks' temporal upsamplers so
        causal chunked decoding can treat the first chunk specially.
        """
        hidden_states = self.conv_in(hidden_states)

        if torch.is_grad_enabled() and self.gradient_checkpointing:
            # Re-wrap each module so checkpoint() sees a plain callable.
            def create_custom_forward(module):
                def create_forward(*inputs):
                    return module(*inputs)

                return create_forward

            hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), hidden_states)

            for up_block in self.up_blocks:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(up_block), hidden_states, is_first_chunk
                )
        else:
            hidden_states = self.mid_block(hidden_states)

            for index, up_block in enumerate(self.up_blocks):
                hidden_states = up_block(hidden_states, is_first_chunk=is_first_chunk)

        if not self.use_unpatchify:
            hidden_states = self.norm_up_1(hidden_states)
            hidden_states = self.conv_act(hidden_states)

            # The 2D split upsampler is applied frame by frame along dim 2 (time),
            # then frames are re-stacked into a video tensor.
            hidden_states_array = []
            for t in range(hidden_states.shape[2]):
                h = self.upsampler2d_1(hidden_states[:, :, t, :, :])
                hidden_states_array.append(h)
            hidden_states = torch.stack(hidden_states_array, dim=2)

        # RMSNorm
        # Inline channel-wise RMS normalization in fp32 for stability, then
        # cast back to the incoming dtype.
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + 1e-8)
        hidden_states = hidden_states.to(input_dtype)

        hidden_states = self.conv_act(hidden_states)

        hidden_states = self.conv_out(hidden_states)

        if self.use_unpatchify:
            # Fold the extra conv_out channels back into spatial resolution.
            hidden_states = unpatchify(hidden_states, self.patch_size)

        return hidden_states
|
| 795 |
+
|
| 796 |
+
|
| 797 |
+
class TurboVAED(ModelMixin, ConfigMixin):
    """Decoder-only turbo VAE wrapper with sliding-window (chunked) temporal decoding.

    Holds a `TurboVAEDDecoder3d` plus per-channel latent mean/std statistics used to
    de-normalize latents before decoding. Decoding is done in overlapping temporal
    windows so long videos fit in memory; see `_sliding_window_decode`.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,  # useless arg for compatibility, we only use latent channels
        out_channels: int = 3,
        latent_channels: int = 128,
        decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
        decoder_layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
        decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
        patch_size: int = 4,
        patch_size_t: int = 1,
        resnet_norm_eps: float = 1e-6,
        scaling_factor: float = 1.0,
        decoder_causal: bool = False,
        decoder_is_dw_conv: Tuple[bool, ...] = (False, False, False, False, False),
        decoder_dw_kernel_size: int = 3,
        decoder_spatio_only: Tuple[bool, ...] = (False, False, False, False),
        first_chunk_size: int = 3,
        step_size: int = 5,
        spatial_compression_ratio: int = 16,
        temporal_compression_ratio: int = 4,
        use_unpatchify: bool = False,
        # below are for training, keep for compatibility
        aligned_feature_projection_mode: Optional[str] = None,
        aligned_feature_projection_dim: Optional[List[Tuple[int, int]]] = None,
        aligned_blks_indices: Optional[List[int]] = None,
    ):
        super().__init__()

        self.decoder = TurboVAEDDecoder3d(
            in_channels=latent_channels,
            out_channels=out_channels,
            block_out_channels=decoder_block_out_channels,
            spatio_temporal_scaling=decoder_spatio_temporal_scaling,
            layers_per_block=decoder_layers_per_block,
            patch_size=patch_size,
            patch_size_t=patch_size_t,
            resnet_norm_eps=resnet_norm_eps,
            is_causal=decoder_causal,
            decoder_is_dw_conv=decoder_is_dw_conv,
            decoder_dw_kernel_size=decoder_dw_kernel_size,
            spatio_only=decoder_spatio_only,
            use_unpatchify=use_unpatchify,
        )

        # Sliding-window parameters, in latent frames.
        self.first_chunk_size = first_chunk_size
        self.step_size = step_size

        self.spatial_compression_ratio = spatial_compression_ratio
        self.temporal_compression_ratio = temporal_compression_ratio

        self.z_dim = latent_channels
        # Per-latent-channel statistics used to de-normalize latents in
        # `_sliding_window_decode`. NOTE(review): created as plain tensors
        # hard-coded to device="cuda" rather than registered buffers, so they
        # do not follow `.to()` / `state_dict()` and require a CUDA build —
        # confirm this is intentional.
        self.mean = torch.tensor(
            [
                -0.2289,
                -0.0052,
                -0.1323,
                -0.2339,
                -0.2799,
                0.0174,
                0.1838,
                0.1557,
                -0.1382,
                0.0542,
                0.2813,
                0.0891,
                0.1570,
                -0.0098,
                0.0375,
                -0.1825,
                -0.2246,
                -0.1207,
                -0.0698,
                0.5109,
                0.2665,
                -0.2108,
                -0.2158,
                0.2502,
                -0.2055,
                -0.0322,
                0.1109,
                0.1567,
                -0.0729,
                0.0899,
                -0.2799,
                -0.1230,
                -0.0313,
                -0.1649,
                0.0117,
                0.0723,
                -0.2839,
                -0.2083,
                -0.0520,
                0.3748,
                0.0152,
                0.1957,
                0.1433,
                -0.2944,
                0.3573,
                -0.0548,
                -0.1681,
                -0.0667,
            ],
            dtype=torch.float32,
            device="cuda",
        )
        self.std = torch.tensor(
            [
                0.4765,
                1.0364,
                0.4514,
                1.1677,
                0.5313,
                0.4990,
                0.4818,
                0.5013,
                0.8158,
                1.0344,
                0.5894,
                1.0901,
                0.6885,
                0.6165,
                0.8454,
                0.4978,
                0.5759,
                0.3523,
                0.7135,
                0.6804,
                0.5833,
                1.4146,
                0.8986,
                0.5659,
                0.7069,
                0.5338,
                0.4889,
                0.4917,
                0.4069,
                0.4999,
                0.6866,
                0.4093,
                0.5709,
                0.6065,
                0.6415,
                0.4944,
                0.5726,
                1.2042,
                0.5458,
                1.6887,
                0.3971,
                1.0600,
                0.3943,
                0.5537,
                0.5444,
                0.4089,
                0.7468,
                0.7744,
            ],
            dtype=torch.float32,
            device="cuda",
        )
        # scale[0] = channel means, scale[1] = reciprocal stds; consumed below.
        self.scale = [self.mean, 1.0 / self.std]

    def _sliding_window_decode(self, z, output_offload=False):
        """Decode latents `z` (B, C, T, H, W) to pixels in overlapping temporal chunks.

        Chunks of `step_size` latent frames (first chunk: `first_chunk_size`) are
        decoded with one latent frame of context on each side; the overlapping
        pixel frames are trimmed before concatenation. When `output_offload` is
        True, decoded chunks are moved to CPU to limit peak GPU memory, and the
        final result is moved back to `z`'s device.
        """
        z_dtype = z.dtype
        z_device = z.device
        scale = self.scale
        assert isinstance(scale[0], torch.Tensor), "scale[0] must be a tensor"
        # De-normalize: scale[1] holds 1/std, so dividing by it multiplies by std,
        # then the channel mean is added back.
        z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
        z = z.to(z_dtype)

        first_chunk_size = self.first_chunk_size
        step = self.step_size

        # Context mapping: 1 latent frame context -> temporal_compression_ratio pixel frames overlap
        num_overlap_pixel_frames = 1 * self.temporal_compression_ratio

        _, _, num_frames, _, _ = z.shape

        # 1. Pad frames to satisfy chunking requirements
        # The total number of frames must follow the formula:
        # num_frames = first_chunk_size + n * step_size
        num_padding_frames = 0

        if num_frames < first_chunk_size:
            # if input is shorter than first_chunk_size
            num_padding_frames = first_chunk_size - num_frames
        elif (num_frames - first_chunk_size) % step != 0:
            num_padding_frames = step - (num_frames - first_chunk_size) % step

        if num_padding_frames > 0:
            # Repeat the last latent frame; the decoded padding is dropped in step 3.
            z = torch.cat([z, z[:, :, -1:].repeat(1, 1, num_padding_frames, 1, 1)], dim=2)
            num_frames = num_frames + num_padding_frames

        # 2. Decode with overlapping windows
        # Collect chunks on CPU to avoid GPU OOM for high resolution (e.g., 1080P) when output_offload=True
        out_chunks = []

        if num_frames == first_chunk_size:
            # if only one chunk, decode directly
            out = self.decoder(z, is_first_chunk=True)
            out_chunks.append(out.cpu() if output_offload else out)
            del out
        else:
            # first chunk: attach the right frame
            out = self.decoder(z[:, :, 0 : first_chunk_size + 1, :, :], is_first_chunk=True)
            out = out[:, :, :-num_overlap_pixel_frames]
            out_chunks.append(out.cpu() if output_offload else out)
            del out

            # middle chunk: attach the left and right frames
            # last chunk: attach the left frame
            for i in range(first_chunk_size, num_frames, step):
                is_last_chunk = i + step == num_frames
                left = i - 1
                right = i + step + 1 if not is_last_chunk else i + step

                assert left >= 0 and right <= num_frames, f"left: {left}, right: {right}, num_frames: {num_frames}"

                out_ = self.decoder(z[:, :, left:right, :, :], is_first_chunk=False)

                # Trim the pixel frames decoded from the context latent frame(s).
                if is_last_chunk:
                    out_ = out_[:, :, num_overlap_pixel_frames:]
                else:
                    out_ = out_[:, :, num_overlap_pixel_frames:-num_overlap_pixel_frames]

                out_chunks.append(out_.cpu() if output_offload else out_)
                del out_

        # Concatenate chunks (on CPU if output_offload, otherwise on GPU)
        out = torch.cat(out_chunks, dim=2)
        del out_chunks

        # 3. Remove padded frames
        if num_padding_frames > 0:
            out = out[:, :, : -num_padding_frames * self.temporal_compression_ratio]

        return out.to(z_device) if output_offload else out

    def decode(self, z: torch.Tensor, output_offload: bool = False):
        """Public decode entry point; see `_sliding_window_decode` for semantics."""
        return self._sliding_window_decode(z, output_offload=output_offload)
|
inference/model/vae2_2/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .vae2_2_model import Wan2_2_VAE, get_vae2_2
|
| 2 |
+
|
| 3 |
+
__all__ = ["Wan2_2_VAE", "get_vae2_2"]
|
inference/model/vae2_2/vae2_2_model.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gc
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from .vae2_2_module import Wan2_2_VAE
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def get_vae2_2(model_path, device="cuda", weight_dtype=torch.float32) -> Wan2_2_VAE:
    """Build a frozen, eval-mode Wan2.2 VAE from a checkpoint path.

    Args:
        model_path: Path to the VAE checkpoint, forwarded as ``vae_pth``.
        device: Target device for the model (default "cuda").
        weight_dtype: Dtype to cast the weights to (default float32).

    Returns:
        The `Wan2_2_VAE` wrapper with its inner module frozen and in eval mode.
    """
    wrapper = Wan2_2_VAE(vae_pth=model_path)
    wrapper = wrapper.to(device).to(weight_dtype)

    # Inference-only: freeze parameters and disable train-mode behavior.
    inner = wrapper.vae
    inner.requires_grad_(False)
    inner.eval()

    # Reclaim host and CUDA memory left over from checkpoint loading.
    gc.collect()
    torch.cuda.empty_cache()
    return wrapper
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
__all__ = ["Wan2_2_VAE", "get_vae2_2"]
|
inference/model/vae2_2/vae2_2_module.py
ADDED
|
@@ -0,0 +1,1086 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# Copyright 2024-2026 The Alibaba Wan Team Authors. All rights reserved.
|
| 16 |
+
import logging
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
from einops import rearrange
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
__all__ = ["Wan2_2_VAE"]
|
| 25 |
+
|
| 26 |
+
CACHE_T = 2
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class ScatterFwdAllGatherBackwardOverlap(torch.autograd.Function):
    """Scatter along W in forward; all-gather and accumulate gradients in backward.

    Forward splits a [B, C, T, H, W] tensor along the last (width) dimension
    across the ranks of ``group``; each rank keeps its chunk extended by
    ``overlap_size`` columns on each side. Backward gathers every rank's
    gradient chunk and sums them back into the full-width gradient, so overlap
    columns accumulate contributions from both neighboring ranks.
    """

    @staticmethod
    def forward(ctx, x, group, overlap_size):
        """Slice this rank's (overlap-extended) chunk of ``x``.

        Args:
            x: Input tensor, shape [B, C, T, H, W].
            group: Distributed communication group.
            overlap_size: Width of overlap region on each side of a chunk.

        Returns:
            The contiguous local chunk x[..., overlap_start:overlap_end].
        """
        W = x.shape[4]
        world_size = torch.distributed.get_world_size(group)
        rank = torch.distributed.get_rank(group)

        # Ceil-divide so the last rank may get a short chunk.
        base_chunk_size = (W + world_size - 1) // world_size

        # This rank's core chunk, extended by the overlap on both sides
        # (clamped at the tensor boundaries).
        chunk_start = rank * base_chunk_size
        chunk_end = min((rank + 1) * base_chunk_size, W)
        overlap_start = max(0, chunk_start - overlap_size)
        overlap_end = min(W, chunk_end + overlap_size)

        x_chunk = x[:, :, :, :, overlap_start:overlap_end].contiguous()

        # Metadata is plain Python ints: stash it as ctx attributes.
        # (save_for_backward is for tensors tracked by autograd; packing ints
        # into a device tensor would force a host/device sync via .item() in
        # backward for no benefit.)
        ctx.group = group
        ctx.overlap_size = overlap_size
        ctx.world_size = world_size
        ctx.base_chunk_size = base_chunk_size
        ctx.full_width = W
        return x_chunk

    @staticmethod
    def backward(ctx, grad_output):
        """All-gather per-rank gradient chunks and sum them into the full gradient."""
        group = ctx.group
        overlap_size = ctx.overlap_size
        world_size = ctx.world_size
        base_chunk_size = ctx.base_chunk_size
        W = ctx.full_width

        grad_output = grad_output.contiguous()
        B, C, T, H = grad_output.shape[:4]

        def chunk_bounds(r):
            # Recompute rank r's overlap-extended slice, mirroring forward().
            r_chunk_start = r * base_chunk_size
            r_chunk_end = min((r + 1) * base_chunk_size, W)
            return max(0, r_chunk_start - overlap_size), min(W, r_chunk_end + overlap_size)

        # Chunk widths differ per rank (boundary clamping), so pre-size each
        # receive buffer individually before the all_gather.
        grad_chunks = []
        for r in range(world_size):
            start, end = chunk_bounds(r)
            grad_chunks.append(grad_output.new_zeros((B, C, T, H, end - start)))
        torch.distributed.all_gather(grad_chunks, grad_output, group=group)

        # Sum every rank's chunk into its position in the full-width gradient.
        # Overlap columns receive contributions from multiple ranks, which is
        # the correct adjoint of forward's overlapping slicing.
        full_grad = grad_output.new_zeros((B, C, T, H, W))
        for r in range(world_size):
            start, end = chunk_bounds(r)
            full_grad[:, :, :, :, start:end] += grad_chunks[r]

        # forward() took (x, group, overlap_size); only x needs a gradient.
        return full_grad, None, None
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def scatter_fwd_all_gather_backward_with_overlap(x, group, overlap_size=0):
    """Functional wrapper around ``ScatterFwdAllGatherBackwardOverlap``.

    Forward: scatter ``x`` along the width axis across ``group``, keeping an
    ``overlap_size`` halo on chunk borders. Backward: all-gather and stitch
    the gradient back to the full width.
    """
    return ScatterFwdAllGatherBackwardOverlap.apply(x, group, overlap_size)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class AllGatherFwdScatterBackwardOverlap(torch.autograd.Function):
    """All-gather in forward / scatter in backward, for width-sharded tensors.

    Each rank holds a local chunk of the width (last) axis that includes halo
    ("overlap") columns shared with its neighbors. Forward strips the halo and
    all-gathers the valid chunks into the full-width tensor; backward slices
    out this rank's gradient and zero-pads the halo region back so the shape
    matches the local input.
    """

    @staticmethod
    def forward(ctx, x, group, overlap_size):
        """
        Forward pass: each rank clips local input, then all-gathers clipped chunks.

        Args:
            x: Input tensor, shape [B, C, T, H, W], already local overlapped chunk per rank
            group: Distributed communication group
            overlap_size: Width of overlap region
        """
        world_size = torch.distributed.get_world_size(group)
        rank = torch.distributed.get_rank(group)

        # Clip local input first (remove overlap area). The first rank only
        # carries a halo on its right edge, the last rank only on its left.
        if rank == 0:
            valid_start = 0
            valid_end = x.shape[-1] - overlap_size
        elif rank == world_size - 1:
            valid_start = overlap_size
            valid_end = x.shape[-1]
        else:
            valid_start = overlap_size
            valid_end = x.shape[-1] - overlap_size

        x_clipped = x[..., valid_start:valid_end].contiguous()
        clipped_width = x_clipped.shape[-1]

        # First all_gather: collect clipped widths across ranks. Chunks may be
        # uneven when W does not divide evenly by world_size.
        width_tensor = torch.tensor([clipped_width], dtype=torch.long, device=x.device)
        all_widths = [torch.zeros_like(width_tensor) for _ in range(world_size)]
        torch.distributed.all_gather(all_widths, width_tensor, group=group)
        clipped_widths = [w.item() for w in all_widths]

        # Second all_gather: collect clipped data across ranks, then stitch
        # along width in rank order.
        B, C, T, H = x_clipped.shape[:4]
        x_clipped_chunks = [torch.zeros(B, C, T, H, w, device=x.device, dtype=x.dtype) for w in clipped_widths]
        torch.distributed.all_gather(x_clipped_chunks, x_clipped, group=group)
        full_x = torch.cat(x_clipped_chunks, dim=-1)

        # Save metadata needed by backward.
        ctx.save_for_backward(torch.tensor([valid_start, valid_end], dtype=torch.long, device=x.device))
        ctx.clipped_widths = clipped_widths
        ctx.group = group
        ctx.overlap_size = overlap_size
        ctx.world_size = world_size
        ctx.rank = rank

        return full_x

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass: each rank restores gradients for its own partition only.

        Returns (grad_x, None, None): no gradients flow to ``group`` or
        ``overlap_size``.
        """
        # Restore saved forward metadata.
        valid_start, valid_end = ctx.saved_tensors[0]
        valid_start = valid_start.item()
        valid_end = valid_end.item()

        clipped_widths = ctx.clipped_widths
        # (fixed: a dead bare `ctx.group` expression statement was removed here)
        overlap_size = ctx.overlap_size
        world_size = ctx.world_size
        rank = ctx.rank

        # Compute current rank's offset into the full-width gradient.
        start_pos = sum(clipped_widths[:rank])
        end_pos = start_pos + clipped_widths[rank]

        # Extract only the current rank's gradient slice.
        grad_clipped = grad_output[:, :, :, :, start_pos:end_pos]

        # Zero-pad to recover the halo columns the forward pass clipped, so
        # the gradient shape matches this rank's local (overlapped) input.
        if rank == 0:
            # First rank: halo only on the right.
            grad_full = F.pad(grad_clipped, (0, overlap_size))
        elif rank == world_size - 1:
            # Last rank: halo only on the left.
            grad_full = F.pad(grad_clipped, (overlap_size, 0))
        else:
            # Middle rank: halo on both sides.
            grad_full = F.pad(grad_clipped, (overlap_size, overlap_size))

        return grad_full, None, None
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def all_gather_fwd_scatter_backward_with_overlap(x, group, overlap_size=0):
    """Functional wrapper around ``AllGatherFwdScatterBackwardOverlap``.

    Forward: strip the ``overlap_size`` halo from the local width chunk and
    all-gather the full-width tensor. Backward: scatter the gradient back to
    this rank's (halo-padded) partition.
    """
    return AllGatherFwdScatterBackwardOverlap.apply(x, group, overlap_size)
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def one_plus_world_size(group):
    """Return True when ``group`` is a real process group with more than one
    rank, i.e. when width-parallel communication is actually required."""
    if group is None:
        return False
    return torch.distributed.get_world_size(group) > 1
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class CausalConv3d(nn.Conv3d):
    """Causal 3D convolution.

    The temporal padding is applied entirely on the left (past) side, so an
    output frame never depends on future frames; spatial padding stays
    symmetric. Padding is performed manually in ``forward`` via ``F.pad``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # F.pad pad order is (W_left, W_right, H_top, H_bottom, T_front, T_back):
        # the full temporal padding (2 * pad_t) goes in front, none behind.
        self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
        # Disable Conv3d's built-in padding; it is done explicitly in forward().
        self.padding = (0, 0, 0)

    @torch.compile
    def forward(self, x, cache_x=None, group: torch.distributed.ProcessGroup = None):
        """Apply the causal convolution.

        Args:
            x: input of shape [B, C, T, H, W] (local width chunk when ``group``
               spans multiple ranks).
            cache_x: trailing frames from the previous chunk; when given they
                replace part of the causal zero padding (streaming decode).
            group: optional process group for width-parallel execution.
        """
        padding = list(self._padding)
        if cache_x is not None and self._padding[4] > 0:
            # Prepend cached frames and shrink the front temporal padding
            # by the number of frames supplied.
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            padding[4] -= cache_x.shape[2]
        if one_plus_world_size(group):
            # Halo wide enough to cover the kernel's receptive field along W.
            overlap_size = self.kernel_size[-1] // 2 * self.stride[-1]
            x = scatter_fwd_all_gather_backward_with_overlap(x, group, overlap_size=overlap_size)
        x = F.pad(x, padding)
        x = super().forward(x)
        if one_plus_world_size(group):
            # Re-assemble the full-width result (overlap columns discarded).
            x = all_gather_fwd_scatter_backward_with_overlap(x, group, overlap_size=overlap_size)
        return x
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
class RMS_norm(nn.Module):
    """RMS normalization with a learnable per-channel gain.

    ``F.normalize`` projects each vector to unit L2 norm; multiplying by
    ``sqrt(dim)`` turns that into RMS normalization. ``channel_first`` selects
    whether the normalized axis is dim 1 (conv layouts) or the last dim;
    ``images`` controls how many trailing singleton axes the gain broadcasts
    over (2 for images, 3 for video).
    """

    def __init__(self, dim, channel_first=True, images=True, bias=False):
        super().__init__()
        if images:
            tail = (1, 1)
        else:
            tail = (1, 1, 1)
        if channel_first:
            shape = (dim, *tail)
        else:
            shape = (dim,)

        self.channel_first = channel_first
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(shape))
        # A plain 0.0 when bias is disabled keeps the forward expression uniform.
        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0

    @torch.compile
    def forward(self, x):
        norm_dim = 1 if self.channel_first else -1
        return F.normalize(x, dim=norm_dim) * self.scale * self.gamma + self.bias
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
class Upsample(nn.Upsample):
    """``nn.Upsample`` with a float32 round-trip.

    Nearest-neighbor interpolation does not support bfloat16, so the input is
    cast to float32 for the interpolation and cast back to its original dtype
    afterwards.
    """

    @torch.compile
    def forward(self, x):
        upsampled = super().forward(x.float())
        return upsampled.type_as(x)
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
class Resample(nn.Module):
    """Spatial/temporal resampling block used by the encoder and decoder.

    Modes:
      - "none":          identity.
      - "upsample2d":    2x nearest-exact spatial upsample + 3x3 conv.
      - "upsample3d":    as above, plus a causal time conv that doubles the
                         channel count; the two halves are interleaved along T
                         to double the frame rate.
      - "downsample2d":  asymmetric zero-pad + stride-2 3x3 conv.
      - "downsample3d":  as above, plus a stride-2 causal time conv.
    """

    def __init__(self, dim, mode):
        assert mode in ("none", "upsample2d", "upsample3d", "downsample2d", "downsample3d")
        super().__init__()
        self.dim = dim
        self.mode = mode

        # layers
        if mode == "upsample2d":
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim, 3, padding=1)
            )
        elif mode == "upsample3d":
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
                nn.Conv2d(dim, dim, 3, padding=1),
                # nn.Conv2d(dim, dim//2, 3, padding=1)
            )
            # Doubles channels; the two halves become two output frames.
            self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
        elif mode == "downsample2d":
            self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
        elif mode == "downsample3d":
            self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
            self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
        else:
            self.resample = nn.Identity()

    @torch.compile
    def forward(self, x, feat_cache=None, feat_idx=[0], group: torch.distributed.ProcessGroup = None):
        """Resample ``x`` [B, C, T, H, W].

        feat_cache/feat_idx implement streaming (chunked) inference: feat_idx[0]
        is a running index into feat_cache that every cached conv advances, so
        call order must match between chunks. ``group`` enables width-parallel
        execution with halo exchange.
        """
        if one_plus_world_size(group):
            # Halo sized for the 3x3 conv (1) or the padded stride-2 conv (2).
            if self.mode in ["upsample3d", "upsample2d"]:
                overlap_size = 1
            elif self.mode in ["downsample3d", "downsample2d"]:
                overlap_size = 2
            else:
                overlap_size = 0
            x = scatter_fwd_all_gather_backward_with_overlap(x, group, overlap_size=overlap_size)

        b, c, t, h, w = x.size()
        if self.mode == "upsample3d":
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    # First chunk: mark with the "Rep" sentinel and skip the
                    # temporal upsample (nothing cached yet).
                    feat_cache[idx] = "Rep"
                    feat_idx[0] += 1
                else:
                    cache_x = x[:, :, -CACHE_T:, :, :].clone()
                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
                        # cache last frame of last two chunk
                        cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
                        # No real history yet: pad the cache with zeros.
                        cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
                    if feat_cache[idx] == "Rep":
                        x = self.time_conv(x)
                    else:
                        x = self.time_conv(x, feat_cache[idx])
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1
                    # Interleave the doubled channels as alternating frames:
                    # [B, 2C, T, H, W] -> [B, C, 2T, H, W].
                    x = x.reshape(b, 2, c, t, h, w)
                    x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
                    x = x.reshape(b, c, t * 2, h, w)
        t = x.shape[2]
        # Fold time into batch for the 2D spatial resample.
        x = rearrange(x, "b c t h w -> (b t) c h w")
        x = self.resample(x)
        x = rearrange(x, "(b t) c h w -> b c t h w", t=t)

        if self.mode == "downsample3d":
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    feat_cache[idx] = x.clone()
                    feat_idx[0] += 1
                else:
                    cache_x = x[:, :, -1:, :, :].clone()
                    # Prepend last cached frame so the stride-2 time conv sees
                    # a continuous sequence across chunk boundaries.
                    x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1

        if one_plus_world_size(group):
            # The local width changed, so the halo to strip changes with it:
            # doubled after upsampling, halved after downsampling.
            if self.mode in ["upsample3d", "upsample2d"]:
                overlap_size = overlap_size * 2
            elif self.mode in ["downsample3d", "downsample2d"]:
                overlap_size = overlap_size // 2
            else:
                overlap_size = overlap_size
            x = all_gather_fwd_scatter_backward_with_overlap(x, group, overlap_size=overlap_size)
        return x

    def init_weight(self, conv):
        # Identity-style init for a temporal conv: pass the input through the
        # center tap. NOTE(review): not called anywhere in this file — appears
        # to be a training-time initialization helper; confirm before removing.
        conv_weight = conv.weight.detach().clone()
        nn.init.zeros_(conv_weight)
        c1, c2, t, h, w = conv_weight.size()
        one_matrix = torch.eye(c1, c2)
        init_matrix = one_matrix
        nn.init.zeros_(conv_weight)
        conv_weight.data[:, :, 1, 0, 0] = init_matrix  # * 0.5
        conv.weight = nn.Parameter(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def init_weight2(self, conv):
        # Duplication-style init for the channel-doubling temporal conv: both
        # output halves copy the input at the last tap. NOTE(review): also
        # unused in this file; likely a training-time helper.
        conv_weight = conv.weight.data.detach().clone()
        nn.init.zeros_(conv_weight)
        c1, c2, t, h, w = conv_weight.size()
        init_matrix = torch.eye(c1 // 2, c2)
        conv_weight[: c1 // 2, :, -1, 0, 0] = init_matrix
        conv_weight[c1 // 2 :, :, -1, 0, 0] = init_matrix
        conv.weight = nn.Parameter(conv_weight)
        nn.init.zeros_(conv.bias.data)
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
class ResidualBlock(nn.Module):
    """Pre-norm residual block: two (RMSNorm -> SiLU -> CausalConv3d) stages
    plus a 1x1x1 shortcut conv when the channel count changes."""

    def __init__(self, in_dim, out_dim, dropout=0.0):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # layers
        self.residual = nn.Sequential(
            RMS_norm(in_dim, images=False),
            nn.SiLU(),
            CausalConv3d(in_dim, out_dim, 3, padding=1),
            RMS_norm(out_dim, images=False),
            nn.SiLU(),
            nn.Dropout(dropout),
            CausalConv3d(out_dim, out_dim, 3, padding=1),
        )
        self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()


    @torch.compile
    def forward(self, x, feat_cache=None, feat_idx=[0], group: torch.distributed.ProcessGroup = None):
        """Apply the block.

        feat_cache/feat_idx carry streaming state: feat_idx[0] indexes
        feat_cache and is advanced once per CausalConv3d, so the traversal
        order of ``self.residual`` must stay fixed across chunks.
        """
        if one_plus_world_size(group):
            # Halo of 2 covers the receptive field of both 3x3x3 convs.
            overlap_size = 2
            x = scatter_fwd_all_gather_backward_with_overlap(x, group, overlap_size=overlap_size)
        h = self.shortcut(x)
        for layer in self.residual:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache last frame of last two chunk
                    cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        x = x + h
        if one_plus_world_size(group):
            x = all_gather_fwd_scatter_backward_with_overlap(x, group, overlap_size=overlap_size)
        return x
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
class AttentionBlock(nn.Module):
    """Single-head self-attention over spatial positions, applied to each
    frame independently (time is folded into the batch dimension)."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        # layers
        self.norm = RMS_norm(dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        self.proj = nn.Conv2d(dim, dim, 1)

        # Zero-init the output projection so the block starts as an identity.
        nn.init.zeros_(self.proj.weight)

    @torch.compile
    def forward(self, x):
        residual_in = x
        b, c, t, h, w = x.size()
        frames = rearrange(x, "b c t h w -> (b t) c h w")
        frames = self.norm(frames)
        # q, k, v each have shape [(b t), 1, h*w, c].
        qkv = self.to_qkv(frames).reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous()
        q, k, v = qkv.chunk(3, dim=-1)

        attn_out = F.scaled_dot_product_attention(q, k, v)
        attn_out = attn_out.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)

        out = self.proj(attn_out)
        out = rearrange(out, "(b t) c h w-> b c t h w", t=t)
        return out + residual_in
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
def patchify(x, patch_size):
    """Fold ``patch_size`` x ``patch_size`` spatial patches into channels.

    Accepts 4D [B, C, H, W] or 5D [B, C, F, H, W] input; H and W must be
    divisible by ``patch_size``. ``patch_size == 1`` is a no-op.
    """
    if patch_size == 1:
        return x

    ndim = x.dim()
    if ndim == 4:
        return rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
    if ndim == 5:
        return rearrange(x, "b c f (h q) (w r) -> b (c r q) f h w", q=patch_size, r=patch_size)
    raise ValueError(f"Invalid input shape: {x.shape}")
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
def unpatchify(x, patch_size):
    """Inverse of ``patchify``: unfold channel groups back into
    ``patch_size`` x ``patch_size`` spatial patches.

    Accepts 4D [B, C*p*p, H, W] or 5D [B, C*p*p, F, H, W] input.
    ``patch_size == 1`` is a no-op.

    Raises:
        ValueError: if ``x`` is neither 4D nor 5D (previously this silently
        returned ``x`` unchanged, inconsistent with ``patchify``).
    """
    if patch_size == 1:
        return x

    if x.dim() == 4:
        x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
    elif x.dim() == 5:
        x = rearrange(x, "b (c r q) f h w -> b c f (h q) (w r)", q=patch_size, r=patch_size)
    else:
        raise ValueError(f"Invalid input shape: {x.shape}")
    return x
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
class AvgDown3D(nn.Module):
    """Parameter-free downsample: space-to-channel fold followed by averaging
    channel groups down to ``out_channels``.

    Downsamples T by ``factor_t`` and H/W by ``factor_s``.
    """

    def __init__(self, in_channels, out_channels, factor_t, factor_s=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.factor_t = factor_t
        self.factor_s = factor_s
        # Total fold factor: each output voxel aggregates this many inputs.
        self.factor = self.factor_t * self.factor_s * self.factor_s

        # Folded channel count must split evenly into out_channels groups.
        assert in_channels * self.factor % out_channels == 0
        self.group_size = in_channels * self.factor // out_channels

    @torch.compile
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Left-pad T so it becomes divisible by factor_t (pads the past side,
        # consistent with the causal convention elsewhere in this file).
        pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
        pad = (0, 0, 0, 0, pad_t, 0)
        x = F.pad(x, pad)
        B, C, T, H, W = x.shape
        # Split each of T/H/W into (coarse, fine) factors ...
        x = x.view(
            B, C, T // self.factor_t, self.factor_t, H // self.factor_s, self.factor_s, W // self.factor_s, self.factor_s
        )
        # ... move the fine factors next to the channel axis ...
        x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
        x = x.view(B, C * self.factor, T // self.factor_t, H // self.factor_s, W // self.factor_s)
        # ... and average each group of folded channels down to out_channels.
        x = x.view(B, self.out_channels, self.group_size, T // self.factor_t, H // self.factor_s, W // self.factor_s)
        x = x.mean(dim=2)
        return x
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
class DupUp3D(nn.Module):
    """Parameter-free upsample: duplicate channels then channel-to-space
    unfold — the approximate inverse of ``AvgDown3D``.

    Upsamples T by ``factor_t`` and H/W by ``factor_s``.
    """

    def __init__(self, in_channels: int, out_channels: int, factor_t, factor_s=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.factor_t = factor_t
        self.factor_s = factor_s
        self.factor = self.factor_t * self.factor_s * self.factor_s

        assert out_channels * self.factor % in_channels == 0
        # How many times each input channel is duplicated before unfolding.
        self.repeats = out_channels * self.factor // in_channels

    @torch.compile
    def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
        x = x.repeat_interleave(self.repeats, dim=1)
        # Split channels into (out_channels, t-factor, s-factor, s-factor) ...
        x = x.view(x.size(0), self.out_channels, self.factor_t, self.factor_s, self.factor_s, x.size(2), x.size(3), x.size(4))
        # ... and interleave each factor with its spatial/temporal axis.
        x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
        x = x.view(
            x.size(0), self.out_channels, x.size(2) * self.factor_t, x.size(4) * self.factor_s, x.size(6) * self.factor_s
        )
        if first_chunk:
            # The first chunk's leading frames are padding artifacts (mirrors
            # AvgDown3D's left temporal pad); drop them.
            x = x[:, :, self.factor_t - 1 :, :, :]
        return x
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
class Down_ResidualBlock(nn.Module):
    """Encoder stage: ``mult`` residual blocks plus an optional Resample,
    with a parameter-free AvgDown3D shortcut added to the output."""

    def __init__(self, in_dim, out_dim, dropout, mult, temperal_downsample=False, down_flag=False):
        super().__init__()

        # Shortcut path with downsample
        self.avg_shortcut = AvgDown3D(
            in_dim, out_dim, factor_t=2 if temperal_downsample else 1, factor_s=2 if down_flag else 1
        )

        # Main path with residual blocks and downsample
        downsamples = []
        for _ in range(mult):
            downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
            in_dim = out_dim

        # Add the final downsample block
        if down_flag:
            mode = "downsample3d" if temperal_downsample else "downsample2d"
            downsamples.append(Resample(out_dim, mode=mode))

        self.downsamples = nn.Sequential(*downsamples)

    @torch.compile
    def forward(self, x, feat_cache=None, feat_idx=[0]):
        """Run the main path then add the AvgDown3D shortcut.

        NOTE: ``feat_idx`` is a deliberately shared mutable accumulator that
        the cached convs advance in place; the default ``[0]`` is only touched
        when ``feat_cache`` is not None, in which case callers pass their own.
        """
        x_copy = x.clone()
        for module in self.downsamples:
            x = module(x, feat_cache, feat_idx)

        return x + self.avg_shortcut(x_copy)
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
class Up_ResidualBlock(nn.Module):
    """Decoder stage: ``mult`` residual blocks plus an optional Resample,
    with a parameter-free DupUp3D shortcut when upsampling."""

    def __init__(self, in_dim, out_dim, dropout, mult, temperal_upsample=False, up_flag=False):
        super().__init__()
        # Shortcut path with upsample
        if up_flag:
            self.avg_shortcut = DupUp3D(in_dim, out_dim, factor_t=2 if temperal_upsample else 1, factor_s=2 if up_flag else 1)
        else:
            self.avg_shortcut = None

        # Main path with residual blocks and upsample
        upsamples = []
        for _ in range(mult):
            upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
            in_dim = out_dim

        # Add the final upsample block
        if up_flag:
            mode = "upsample3d" if temperal_upsample else "upsample2d"
            upsamples.append(Resample(out_dim, mode=mode))

        self.upsamples = nn.Sequential(*upsamples)

    @torch.compile
    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False, group: torch.distributed.ProcessGroup = None):
        """Run the main path; add the DupUp3D shortcut when present.

        ``first_chunk`` lets DupUp3D trim the padded leading frames of the
        first streamed chunk; ``group`` enables width-parallel convs.
        """
        x_main = x.clone()
        for module in self.upsamples:
            x_main = module(x_main, feat_cache, feat_idx, group=group)
        if self.avg_shortcut is not None:
            x_shortcut = self.avg_shortcut(x, first_chunk)
            return x_main + x_shortcut
        else:
            return x_main
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
class Encoder3d(nn.Module):
    """Video encoder: patchified input (12 channels) -> latent of ``z_dim``
    channels via a stack of Down_ResidualBlocks, a middle attention block,
    and a conv head."""

    def __init__(
        self,
        dim=128,
        z_dim=4,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_downsample=[True, True, False],
        dropout=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample

        # dimensions
        dims = [dim * u for u in [1] + dim_mult]
        # NOTE(review): `scale` is tracked but unused here (attn_scales is
        # empty by default) — presumably kept for parity with variants that
        # insert attention at specific scales.
        scale = 1.0

        # init block: input is 2x2-patchified RGB -> 3*4 = 12 channels.
        self.conv1 = CausalConv3d(12, dims[0], 3, padding=1)

        # downsample blocks
        downsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            t_down_flag = temperal_downsample[i] if i < len(temperal_downsample) else False
            downsamples.append(
                Down_ResidualBlock(
                    in_dim=in_dim,
                    out_dim=out_dim,
                    dropout=dropout,
                    mult=num_res_blocks,
                    temperal_downsample=t_down_flag,
                    down_flag=i != len(dim_mult) - 1,
                )
            )
            scale /= 2.0
        self.downsamples = nn.Sequential(*downsamples)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), ResidualBlock(out_dim, out_dim, dropout)
        )

        # # output blocks
        self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, z_dim, 3, padding=1))

    @torch.compile
    def forward(self, x, feat_cache=None, feat_idx=[0]):
        """Encode one (chunk of a) video.

        feat_cache/feat_idx implement streaming: feat_idx[0] is advanced once
        per CausalConv3d in a fixed traversal order, so the cache layout is
        positional and must match across chunks.
        """
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # Extend a 1-frame cache with the last frame of the previous cache.
                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        # downsamples
        for layer in self.downsamples:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        # middle
        for layer in self.middle:
            if isinstance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        # head
        for layer in self.head:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)

        return x
|
| 716 |
+
|
| 717 |
+
|
| 718 |
+
class Decoder3d(nn.Module):
    """Latent -> patchified video decoder (mirror of Encoder3d): conv stem,
    middle attention block, Up_ResidualBlocks, and a 12-channel conv head."""

    def __init__(
        self,
        dim=128,
        z_dim=4,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_upsample=[False, True, True],
        dropout=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_upsample = temperal_upsample

        # dimensions (reverse of the encoder's progression)
        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        # scale = 1.0 / 2 ** (len(dim_mult) - 2)
        # init block
        self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), ResidualBlock(dims[0], dims[0], dropout)
        )

        # upsample blocks
        upsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            t_up_flag = temperal_upsample[i] if i < len(temperal_upsample) else False
            upsamples.append(
                Up_ResidualBlock(
                    in_dim=in_dim,
                    out_dim=out_dim,
                    dropout=dropout,
                    mult=num_res_blocks + 1,
                    temperal_upsample=t_up_flag,
                    up_flag=i != len(dim_mult) - 1,
                )
            )
        self.upsamples = nn.Sequential(*upsamples)

        # output blocks: 12 channels = 2x2-patchified RGB.
        self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, 12, 3, padding=1))

    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False, group: torch.distributed.ProcessGroup = None):
        """Decode one (chunk of a) latent.

        feat_cache/feat_idx carry positional streaming state (see Encoder3d);
        ``first_chunk`` trims DupUp3D's padded leading frames; ``group``
        enables width-parallel convolutions with halo exchange.
        """
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
            x = self.conv1(x, feat_cache[idx], group=group)
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x, group=group)

        for layer in self.middle:
            if isinstance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx, group=group)
            else:
                x = layer(x)

        # upsamples
        for layer in self.upsamples:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx, first_chunk, group=group)
            else:
                x = layer(x, group=group)

        # head
        if one_plus_world_size(group):
            # Halo sized for the head's final causal conv.
            overlap_size = self.head[2].kernel_size[-1] // 2 * self.head[2].stride[-1]
            x = scatter_fwd_all_gather_backward_with_overlap(x, group, overlap_size=overlap_size)
        for layer in self.head:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        if one_plus_world_size(group):
            x = all_gather_fwd_scatter_backward_with_overlap(x, group, overlap_size=overlap_size)
        return x
|
| 810 |
+
|
| 811 |
+
|
| 812 |
+
def count_conv3d(model):
    """Count the CausalConv3d modules in ``model``.

    Used to size the positional feat-cache lists for streaming inference.
    """
    return sum(1 for m in model.modules() if isinstance(m, CausalConv3d))
|
| 818 |
+
|
| 819 |
+
|
| 820 |
+
class WanVAE_(nn.Module):
|
| 821 |
+
def __init__(
|
| 822 |
+
self,
|
| 823 |
+
dim=160,
|
| 824 |
+
dec_dim=256,
|
| 825 |
+
z_dim=16,
|
| 826 |
+
dim_mult=[1, 2, 4, 4],
|
| 827 |
+
num_res_blocks=2,
|
| 828 |
+
attn_scales=[],
|
| 829 |
+
temperal_downsample=[True, True, False],
|
| 830 |
+
dropout=0.0,
|
| 831 |
+
):
|
| 832 |
+
super().__init__()
|
| 833 |
+
self.dim = dim
|
| 834 |
+
self.z_dim = z_dim
|
| 835 |
+
self.dim_mult = dim_mult
|
| 836 |
+
self.num_res_blocks = num_res_blocks
|
| 837 |
+
self.attn_scales = attn_scales
|
| 838 |
+
self.temperal_downsample = temperal_downsample
|
| 839 |
+
self.temperal_upsample = temperal_downsample[::-1]
|
| 840 |
+
|
| 841 |
+
# modules
|
| 842 |
+
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout)
|
| 843 |
+
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
| 844 |
+
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
| 845 |
+
self.decoder = Decoder3d(dec_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout)
|
| 846 |
+
|
| 847 |
+
def forward(self, x, scale=[0, 1]):
|
| 848 |
+
mu = self.encode(x, scale)
|
| 849 |
+
x_recon = self.decode(mu, scale)
|
| 850 |
+
return x_recon, mu
|
| 851 |
+
|
| 852 |
+
def encode(self, x, scale):
|
| 853 |
+
self.clear_cache()
|
| 854 |
+
x = patchify(x, patch_size=2)
|
| 855 |
+
t = x.shape[2]
|
| 856 |
+
iter_ = 1 + (t - 1) // 4
|
| 857 |
+
for i in range(iter_):
|
| 858 |
+
self._enc_conv_idx = [0]
|
| 859 |
+
if i == 0:
|
| 860 |
+
out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
|
| 861 |
+
else:
|
| 862 |
+
out_ = self.encoder(
|
| 863 |
+
x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx
|
| 864 |
+
)
|
| 865 |
+
out = torch.cat([out, out_], 2)
|
| 866 |
+
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
| 867 |
+
if isinstance(scale[0], torch.Tensor):
|
| 868 |
+
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
|
| 869 |
+
else:
|
| 870 |
+
mu = (mu - scale[0]) * scale[1]
|
| 871 |
+
self.clear_cache()
|
| 872 |
+
return mu
|
| 873 |
+
|
| 874 |
+
def decode(self, z, scale, group: torch.distributed.ProcessGroup = None):
|
| 875 |
+
self.clear_cache()
|
| 876 |
+
if isinstance(scale[0], torch.Tensor):
|
| 877 |
+
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
|
| 878 |
+
else:
|
| 879 |
+
z = z / scale[1] + scale[0]
|
| 880 |
+
iter_ = z.shape[2]
|
| 881 |
+
x = self.conv2(z, group=group)
|
| 882 |
+
for i in range(iter_):
|
| 883 |
+
self._conv_idx = [0]
|
| 884 |
+
if i == 0:
|
| 885 |
+
out = self.decoder(
|
| 886 |
+
x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=True, group=group
|
| 887 |
+
)
|
| 888 |
+
else:
|
| 889 |
+
out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, group=group)
|
| 890 |
+
out = torch.cat([out, out_], 2)
|
| 891 |
+
out = unpatchify(out, patch_size=2)
|
| 892 |
+
self.clear_cache()
|
| 893 |
+
return out
|
| 894 |
+
|
| 895 |
+
def reparameterize(self, mu, log_var):
|
| 896 |
+
std = torch.exp(0.5 * log_var)
|
| 897 |
+
eps = torch.randn_like(std)
|
| 898 |
+
return eps * std + mu
|
| 899 |
+
|
| 900 |
+
def sample(self, imgs, deterministic=False):
|
| 901 |
+
mu, log_var = self.encode(imgs)
|
| 902 |
+
if deterministic:
|
| 903 |
+
return mu
|
| 904 |
+
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
|
| 905 |
+
return mu + std * torch.randn_like(std)
|
| 906 |
+
|
| 907 |
+
def clear_cache(self):
    """Reset the causal-conv feature caches for both decoder and encoder."""
    dec_convs = count_conv3d(self.decoder)
    enc_convs = count_conv3d(self.encoder)
    # Decoder-side cache: one slot per CausalConv3d module.
    self._conv_num = dec_convs
    self._conv_idx = [0]
    self._feat_map = [None for _ in range(dec_convs)]
    # Encoder-side cache.
    self._enc_conv_num = enc_convs
    self._enc_conv_idx = [0]
    self._enc_feat_map = [None for _ in range(enc_convs)]
|
| 915 |
+
|
| 916 |
+
|
| 917 |
+
def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", **kwargs):
    """Construct a WanVAE_ and load its weights from ``pretrained_path``.

    Args:
        pretrained_path: path to a torch checkpoint (state dict).
        z_dim: latent channel count.
        dim: base channel width.
        device: map_location for loading the checkpoint.
        **kwargs: overrides for any default architecture hyper-parameter below.
    """
    # Default architecture hyper-parameters; **kwargs may override any of them.
    cfg = dict(
        dim=dim,
        z_dim=z_dim,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_downsample=[True, True, True],
        dropout=0.0,
    )
    cfg.update(**kwargs)

    # Build on the meta device so no real memory is allocated until
    # the checkpoint is assigned below (assign=True).
    with torch.device("meta"):
        model = WanVAE_(**cfg)

    # Load checkpoint.
    logging.info(f"loading {pretrained_path}")
    # NOTE(review): torch.load without weights_only=True unpickles arbitrary
    # objects — only load checkpoints from trusted sources.
    model.load_state_dict(torch.load(pretrained_path, map_location=device), assign=True)

    return model
|
| 939 |
+
|
| 940 |
+
|
| 941 |
+
class Wan2_2_VAE:
    """High-level wrapper around the Wan 2.2 video VAE.

    Holds per-channel latent statistics (mean/std for the 48 latent channels)
    and exposes normalized ``encode`` / ``decode`` entry points over the
    frozen underlying module.
    """

    def __init__(
        self,
        z_dim=48,
        c_dim=160,
        vae_pth=None,
        dim_mult=[1, 2, 4, 4],
        temperal_downsample=[False, True, True],
        dtype=torch.float,
        device="cuda",
    ):
        self.dtype = dtype
        self.device = device

        # Per-channel latent statistics (48 channels), presumably measured
        # offline on the training distribution — treated as fixed constants.
        # fmt: off
        self.mean = torch.tensor(
            [
                -0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
                -0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
                -0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
                -0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
                -0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
                0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
            ],
            dtype=dtype,
            device=device,
        )
        self.std = torch.tensor(
            [
                0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
                0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
                0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
                0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
                0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
                0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744,
            ],
            dtype=dtype,
            device=device,
        )
        # fmt: on
        # scale = (mean, 1/std): encode applies (mu - mean) * (1/std);
        # decode inverts it (see WanVAE_.decode).
        self.scale = [self.mean, 1.0 / self.std]

        # Build the underlying VAE: frozen, eval mode, moved to target device.
        self.vae = (
            _video_vae(
                pretrained_path=vae_pth, z_dim=z_dim, dim=c_dim, dim_mult=dim_mult, temperal_downsample=temperal_downsample
            )
            .eval()
            .requires_grad_(False)
            .to(device)
        )

    def encode(self, video):
        # Returns the normalized latent mean in float32.
        return self.vae.encode(video, self.scale).float()

    def to(self, *args, **kwargs):
        """Move statistics and the wrapped module together, then rebuild scale."""
        self.mean = self.mean.to(*args, **kwargs)
        self.std = self.std.to(*args, **kwargs)
        self.scale = [self.mean, 1.0 / self.std]
        self.vae = self.vae.to(*args, **kwargs)
        return self

    def decode(self, z, group: torch.distributed.ProcessGroup = None):
        # Decode to pixel space in float32, clamped to the model's [-1, 1] range.
        return self.vae.decode(z, self.scale, group=group).float().clamp_(-1, 1)
|
inference/pipeline/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from .pipeline import MagiPipeline
|
| 16 |
+
|
| 17 |
+
__all__ = [
|
| 18 |
+
# pipeline
|
| 19 |
+
"MagiPipeline",
|
| 20 |
+
]
|
inference/pipeline/data_proxy.py
ADDED
|
@@ -0,0 +1,390 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
from dataclasses import dataclass
|
| 17 |
+
from enum import IntEnum
|
| 18 |
+
from typing import Any, Literal, Optional, TYPE_CHECKING
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
from einops import rearrange
|
| 22 |
+
from inference.common import DataProxyConfig, Modality, VarlenHandler
|
| 23 |
+
from inference.model.dit.dit_module import FFAHandler
|
| 24 |
+
from torch.nn import functional as F
|
| 25 |
+
from unfoldNd import UnfoldNd
|
| 26 |
+
|
| 27 |
+
if TYPE_CHECKING:
|
| 28 |
+
from inference.pipeline.video_generate import EvalInput
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def calc_local_qk_range(num_video_tokens, num_audio_and_txt_tokens, num_frames, frame_receptive_field, device="cuda"):
    """Build (q_ranges, k_ranges) pairs for frame-local video attention.

    The packed sequence layout is [video tokens | audio+text tokens]. Three
    groups of range pairs are produced:
      1. per-frame: each frame's queries attend only to frames within
         ``frame_receptive_field`` on either side (clamped to the video span);
      2. video -> audio+text: all video queries attend to the trailing tokens;
      3. audio+text -> everything.

    Args:
        num_video_tokens: number of video tokens (assumed divisible by
            ``num_frames`` — TODO confirm with caller).
        num_audio_and_txt_tokens: number of trailing audio+text tokens.
        num_frames: number of (latent) frames in the video span.
        frame_receptive_field: half-width of the local attention window,
            in frames.
        device: target device for the returned tensors. Defaults to "cuda"
            (the original hard-coded behavior); pass "cpu" for host-side use.

    Returns:
        (q_ranges, k_ranges): int32 tensors of shape (num_frames + 2, 2).
    """
    token_per_frame = num_video_tokens // num_frames
    total_tokens = num_video_tokens + num_audio_and_txt_tokens

    q_range_list = []
    k_range_list = []
    for i in range(num_frames):
        q_range_list.append(torch.tensor([i * token_per_frame, (i + 1) * token_per_frame]))
        k_range_list.append(
            torch.tensor([(i - frame_receptive_field) * token_per_frame, (i + frame_receptive_field + 1) * token_per_frame])
        )
    local_q_range = torch.stack(q_range_list, dim=0)
    local_k_range = torch.stack(k_range_list, dim=0)

    # Clamp key windows that spill outside the video token span.
    local_k_range.clamp_(min=0, max=num_video_tokens)

    # Group 2: video queries attend to the audio+text tail.
    video_q_range = torch.tensor([[0, num_video_tokens]])
    video_k_range = torch.tensor([[num_video_tokens, num_video_tokens + num_audio_and_txt_tokens]])

    # Group 3: audio+text queries attend to the full sequence.
    at_q_ranges = torch.tensor([[num_video_tokens, total_tokens]])
    at_k_ranges = torch.tensor([[0, total_tokens]])

    q_ranges = torch.cat([local_q_range, video_q_range, at_q_ranges], dim=0).to(torch.int32).to(device, non_blocking=True)
    k_ranges = torch.cat([local_k_range, video_k_range, at_k_ranges], dim=0).to(torch.int32).to(device, non_blocking=True)

    return (q_ranges, k_ranges)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def calc_local_attn_ffa_handler(num_video_tokens, num_audio_and_txt_tokens, num_frames, frame_receptive_field):
    """Construct an FFAHandler that realizes frame-local video attention
    plus full audio/text attention (see calc_local_qk_range for the layout).
    """
    q_ranges, k_ranges = calc_local_qk_range(num_video_tokens, num_audio_and_txt_tokens, num_frames, frame_receptive_field)
    # Max sequence lengths cover the full packed sequence on both sides.
    max_seqlen_q = num_video_tokens + num_audio_and_txt_tokens
    max_seqlen_k = num_video_tokens + num_audio_and_txt_tokens
    # One attention-type entry per (q, k) range pair; all zeros
    # (presumably "full attention within the range" — confirm with FFAHandler docs).
    attn_type_map = torch.zeros([q_ranges.shape[0]], device="cuda", dtype=torch.int32)
    softmax_scale = None  # use the attention kernel's default scaling

    ffa_handler = FFAHandler(
        q_ranges=q_ranges,
        k_ranges=k_ranges,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        attn_type_map=attn_type_map,
        softmax_scale=softmax_scale,
    )
    return ffa_handler
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def get_coords(
    shape: list[int],
    ref_feat_shape: list[int],
    offset_thw: tuple[int, int, int] = (0, 0, 0),
    device: torch.device = torch.device("cpu"),
    dtype: torch.dtype = torch.float32,
):
    """Generate flattened feature-grid coordinates plus size metadata.

    Args:
        shape: [T, H, W] original feature-map shape.
        ref_feat_shape: [T_ref, H_ref, W_ref] reference feature-map shape.
        offset_thw: additive (t, h, w) offsets applied to the coordinate grid.
            Default is an immutable tuple (the original used a mutable list,
            which is a shared-default hazard).
        device: device for the coordinate tensors.
        dtype: dtype for the coordinate tensors.

    Returns:
        Tensor of shape (T*H*W, 9); each row is
        (t, h, w, T, H, W, T_ref, H_ref, W_ref).
    """
    ori_t, ori_h, ori_w = shape
    ref_t, ref_h, ref_w = ref_feat_shape

    # Generate index ranges, shifted by the requested offsets.
    offset_t, offset_h, offset_w = offset_thw
    time_rng = torch.arange(ori_t, device=device, dtype=dtype) + offset_t
    height_rng = torch.arange(ori_h, device=device, dtype=dtype) + offset_h
    width_rng = torch.arange(ori_w, device=device, dtype=dtype) + offset_w

    # Use meshgrid to generate a 3D grid (T, H, W); "ij" keeps axis order.
    time_grid, height_grid, width_grid = torch.meshgrid(time_rng, height_rng, width_rng, indexing="ij")

    # Stack and flatten to (T*H*W, 3).
    coords_grid = torch.stack([time_grid, height_grid, width_grid], dim=-1)
    coords_flat = coords_grid.reshape(-1, 3)

    # Build size metadata and broadcast it to every coordinate row.
    meta = torch.tensor([ori_t, ori_h, ori_w, ref_t, ref_h, ref_w], device=device, dtype=dtype)
    meta_expanded = meta.expand(coords_flat.size(0), -1)

    # Merge coordinates with metadata.
    return torch.cat([coords_flat, meta_expanded], dim=-1)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
@dataclass
class SingleData:
    """One sample's packed multimodal token data (video + audio + text).

    Tokens are concatenated in the fixed order [video | audio | text]; the
    ``*_mapping`` properties produce per-token metadata in the same order.
    """

    video_x_t: torch.Tensor  # (num_video_tokens, video_channel) patchified video tokens
    audio_x_t: torch.Tensor  # audio features; truncated to audio_feat_len in __post_init__
    audio_feat_len: int
    txt_feat: torch.Tensor  # text features; truncated to txt_feat_len in __post_init__
    txt_feat_len: int
    t: int  # pre-patchification temporal size
    h: int  # pre-patchification height
    w: int  # pre-patchification width
    patch_size: int
    t_patch_size: int
    spatial_rope_interpolation: Literal["inter", "extra"]
    ref_audio_offset: int  # NOTE(review): stored but not referenced inside this class
    text_offset: int  # temporal offset applied to text coords in "v1" style
    coords_style: Literal["v1", "v2"] = "v1"

    def __post_init__(self):
        # Cache token counts/channel widths and trim padded audio/text
        # features down to their true lengths.
        self.video_token_num = self.video_x_t.shape[0]

        self.audio_x_t = self.audio_x_t[: self.audio_feat_len]
        self.txt_feat = self.txt_feat[: self.txt_feat_len]

        self.video_channel = self.video_x_t.shape[-1]
        self.audio_channel = self.audio_x_t.shape[-1]
        self.txt_channel = self.txt_feat.shape[-1]

    @property
    def device(self):
        """Device of the underlying tensors (taken from the video tokens)."""
        return self.video_x_t.device

    @property
    def default_dtype(self):
        """Dtype of the video tokens; used for generated coordinate tensors."""
        return self.video_x_t.dtype

    @property
    def total_token_num(self):
        """Total packed sequence length: video + audio + text tokens."""
        return self.video_token_num + self.audio_feat_len + self.txt_feat_len

    @property
    def token_sequence(self):
        """Concatenate all modalities along the token dim, zero-padding each
        modality's channel dim up to the widest one.
        """
        tensors_to_concat = [self.video_x_t, self.audio_x_t, self.txt_feat]
        max_channel = max(tensor.shape[-1] for tensor in tensors_to_concat)

        padded_tensors = [F.pad(t, (0, max_channel - t.shape[-1])) for t in tensors_to_concat]
        ret_val = torch.cat(padded_tensors, dim=0)
        return ret_val

    @property
    def modality_mapping(self):
        """Per-token modality id (Modality enum) aligned with token_sequence."""
        v_map = torch.full((self.video_token_num,), Modality.VIDEO, dtype=torch.int64, device=self.device)
        a_map = torch.full((self.audio_feat_len,), Modality.AUDIO, dtype=torch.int64, device=self.device)
        t_map = torch.full((self.txt_feat_len,), Modality.TEXT, dtype=torch.int64, device=self.device)

        modality_mapping = torch.cat([v_map, a_map, t_map], dim=0)
        return modality_mapping

    def default_coords(self, shape, ref_feat_shape, offset_thw=[0, 0, 0]):
        """get_coords bound to this sample's device/dtype."""
        return get_coords(
            shape=shape, ref_feat_shape=ref_feat_shape, offset_thw=offset_thw, device=self.device, dtype=self.default_dtype
        )

    @property
    def coords_mapping(self):
        """Per-token RoPE coordinates + size metadata, aligned with token_sequence.

        "inter" interpolates spatial positions against a fixed 32x32 reference
        grid; any other value uses the sample's own grid as reference.
        """
        if self.spatial_rope_interpolation == "inter":
            video_ref_feat_shape = (self.t // self.t_patch_size, 32, 32)
        else:
            video_ref_feat_shape = (self.t // self.t_patch_size, self.h // self.patch_size, self.w // self.patch_size)

        video_coords = self.default_coords(
            shape=(self.t // self.t_patch_size, self.h // self.patch_size, self.w // self.patch_size),
            ref_feat_shape=video_ref_feat_shape,
        )

        if self.coords_style == "v1":
            audio_coords = self.default_coords(
                shape=(self.audio_feat_len, 1, 1), ref_feat_shape=(self.t // self.t_patch_size, 1, 1)
            )

            text_coords = self.default_coords(
                shape=(self.txt_feat_len, 1, 1), ref_feat_shape=(2, 1, 1), offset_thw=[self.text_offset, 0, 0]
            )

        elif self.coords_style == "v2":
            # Presumably 4 audio feature steps per latent frame (+1 for the
            # leading frame) — TODO confirm against the audio front-end.
            magic_audio_ref_t = (self.audio_feat_len - 1) // 4 + 1
            audio_coords = self.default_coords(
                shape=(self.audio_feat_len, 1, 1), ref_feat_shape=(magic_audio_ref_t // self.t_patch_size, 1, 1)
            )

            text_coords = self.default_coords(
                shape=(self.txt_feat_len, 1, 1), ref_feat_shape=(1, 1, 1), offset_thw=[-self.txt_feat_len, 0, 0]
            )

        # NOTE: a coords_style other than "v1"/"v2" would raise NameError here
        # (no else branch defines audio_coords/text_coords).
        coords_mapping = torch.cat([video_coords, audio_coords, text_coords], dim=0)
        return coords_mapping

    def depack_token_sequence(self, token_sequence):
        """Split a packed sequence back into an un-patchified video tensor
        and the audio token slice (padding channels are dropped).
        """
        video_x_t = token_sequence[: self.video_token_num, : self.video_channel]
        video_x_t = rearrange(
            video_x_t,
            "(T H W) (pT pH pW C) -> C (T pT) (H pH) (W pW)",
            H=self.h // self.patch_size,
            W=self.w // self.patch_size,
            pT=self.t_patch_size,
            pH=self.patch_size,
            pW=self.patch_size,
        ).contiguous()

        audio_x_t = token_sequence[self.video_token_num : self.video_token_num + self.audio_feat_len, : self.audio_channel]
        return video_x_t, audio_x_t
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
@dataclass
|
| 235 |
+
class SimplePackedData:
|
| 236 |
+
items: list[SingleData]
|
| 237 |
+
|
| 238 |
+
@property
|
| 239 |
+
def token_sequence(self):
|
| 240 |
+
return torch.cat([item.token_sequence for item in self.items], dim=0)
|
| 241 |
+
|
| 242 |
+
@property
|
| 243 |
+
def modality_mapping(self):
|
| 244 |
+
return torch.cat([item.modality_mapping for item in self.items], dim=0)
|
| 245 |
+
|
| 246 |
+
@property
|
| 247 |
+
def coords_mapping(self):
|
| 248 |
+
return torch.cat([item.coords_mapping for item in self.items], dim=0)
|
| 249 |
+
|
| 250 |
+
@property
|
| 251 |
+
def total_token_num(self):
|
| 252 |
+
return sum([item.total_token_num for item in self.items])
|
| 253 |
+
|
| 254 |
+
def __getitem__(self, index):
|
| 255 |
+
return self.items[index]
|
| 256 |
+
|
| 257 |
+
@property
|
| 258 |
+
def cu_seqlen(self):
|
| 259 |
+
cu_seqlen = torch.cumsum(torch.tensor([item.total_token_num for item in self.items]), dim=0)
|
| 260 |
+
cu_seqlen = torch.nn.functional.pad(cu_seqlen, (1, 0))
|
| 261 |
+
return cu_seqlen
|
| 262 |
+
|
| 263 |
+
@property
|
| 264 |
+
def max_seqlen(self):
|
| 265 |
+
return torch.tensor(max([item.total_token_num for item in self.items]))
|
| 266 |
+
|
| 267 |
+
def depack_token_sequence(self, token_sequence):
|
| 268 |
+
video_x_t_list = []
|
| 269 |
+
audio_x_t_list = []
|
| 270 |
+
|
| 271 |
+
token_sequence_list = torch.split(token_sequence, [item.total_token_num for item in self.items], dim=0)
|
| 272 |
+
for item, token_sequence in zip(self.items, token_sequence_list):
|
| 273 |
+
video_x_t, audio_x_t = item.depack_token_sequence(token_sequence)
|
| 274 |
+
video_x_t_list.append(video_x_t)
|
| 275 |
+
audio_x_t_list.append(audio_x_t)
|
| 276 |
+
return torch.stack(video_x_t_list, dim=0), torch.stack(audio_x_t_list, dim=0)
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
class MagiDataProxy:
    """Converts batched (video, audio, text) inputs into the packed token
    sequence the DiT consumes, and unpacks model output back into per-modality
    tensors. Intermediate packing metadata is stashed between the two calls.
    """

    def __init__(self, config: DataProxyConfig):
        self.patch_size = config.patch_size
        self.t_patch_size = config.t_patch_size
        # -1 disables frame-local attention (see process_input).
        self.frame_receptive_field = config.frame_receptive_field
        # NOTE(review): hard-coded to 'extra' — any config setting for this
        # field is ignored; confirm intentional.
        self.spatial_rope_interpolation = 'extra'
        self.ref_audio_offset = config.ref_audio_offset
        self.text_offset = config.text_offset
        # Non-overlapping 3D patch extraction (img2col) over (t, h, w).
        self.unfold = UnfoldNd(
            kernel_size=(self.t_patch_size, self.patch_size, self.patch_size),
            stride=(self.t_patch_size, self.patch_size, self.patch_size),
        )
        self.coords_style = config.coords_style

        # Scratch space shared between process_input and process_output.
        self._saved_data: dict[str, Any] = {}

    def saved_for_output(self, **kwargs):
        """
        Store intermediate data used by process_output.
        Supports keyword-argument style calls: saved_for_output(a=1, b=2)
        Can be called multiple times to accumulate data

        Args:
            **kwargs: key-value pairs to store
        """
        # Directly update dict; supports accumulation across calls
        self._saved_data.update(kwargs)

    def get_saved_data(self, key: str):
        """
        Get stored data (raises KeyError if the key was never saved).
        """
        return self._saved_data[key]

    def img2tokens(self, x_t: torch.Tensor):
        """Patchify (N, C, T, H, W) video into (N, num_tokens, col_dim) tokens."""
        x_t_unfolded = self.unfold(x_t)
        # Transpose dimensions from (N, col_dim, num_tokens) -> (N, num_tokens, col_dim)
        x_t = rearrange(x_t_unfolded, "N col_dim num_tokens -> N num_tokens col_dim").contiguous()
        return x_t

    def process_input(self, transported_data: "EvalInput"):
        """Pack a batch into one flat token sequence plus attention metadata.

        Returns:
            (x, coords_mapping, modality_mapping, varlen_handler,
            local_attn_handler) — local_attn_handler is None when
            frame_receptive_field == -1.
        """
        # init img2col module

        batch_size, _, t, h, w = transported_data.x_t.shape
        # 1. Process video features while keeping the batch dimension
        x_t = self.img2tokens(transported_data.x_t)

        # 2. Process audio features while keeping the batch dimension
        # Assume transported_data.audio_x_t shape is already (N, num_tokens, col_dim)
        audio_x_t = transported_data.audio_x_t.contiguous()

        # Here we assume text_in shape is (N, num_tokens, col_dim)
        text_in = transported_data.txt_feat.contiguous()

        # Wrap each batch element as a SingleData and collect them.
        simple_packed_data = SimplePackedData(items=[])
        for i in range(batch_size):
            single_data = SingleData(
                video_x_t=x_t[i],
                audio_x_t=audio_x_t[i],
                audio_feat_len=transported_data.audio_feat_len[i],
                txt_feat=text_in[i],
                txt_feat_len=transported_data.txt_feat_len[i],
                t=t,
                h=h,
                w=w,
                patch_size=self.patch_size,
                t_patch_size=self.t_patch_size,
                spatial_rope_interpolation=self.spatial_rope_interpolation,
                ref_audio_offset=self.ref_audio_offset,
                text_offset=self.text_offset,
                coords_style=self.coords_style,
            )
            simple_packed_data.items.append(single_data)

        if self.frame_receptive_field != -1:
            assert batch_size == 1, "local attention only supports batch size 1"

            # NOTE(review): num_frames=t passes the raw (pre-patchification)
            # frame count; if t_patch_size != 1 this differs from the latent
            # frame count used for the token layout — confirm.
            local_attn_handler = calc_local_attn_ffa_handler(
                num_video_tokens=simple_packed_data[0].video_token_num,
                num_audio_and_txt_tokens=simple_packed_data[0].audio_feat_len + simple_packed_data[0].txt_feat_len,
                num_frames=t,
                frame_receptive_field=self.frame_receptive_field,
            )
            # Normalize max_seqlen_* to plain ints for downstream kernels.
            if isinstance(local_attn_handler.max_seqlen_k, torch.Tensor):
                local_attn_handler.max_seqlen_k = local_attn_handler.max_seqlen_k.item()
            if isinstance(local_attn_handler.max_seqlen_q, torch.Tensor):
                local_attn_handler.max_seqlen_q = local_attn_handler.max_seqlen_q.item()
        else:
            local_attn_handler = None

        varlen_handler = VarlenHandler(
            cu_seqlens_q=simple_packed_data.cu_seqlen.to(torch.int32).cuda(),
            cu_seqlens_k=simple_packed_data.cu_seqlen.to(torch.int32).cuda(),
            max_seqlen_q=simple_packed_data.max_seqlen.to(torch.int32).cuda(),
            max_seqlen_k=simple_packed_data.max_seqlen.to(torch.int32).cuda(),
        )

        # Stash packing metadata for process_output.
        self.saved_for_output(simple_packed_data=simple_packed_data)

        x = simple_packed_data.token_sequence
        coords_mapping = simple_packed_data.coords_mapping
        modality_mapping = simple_packed_data.modality_mapping

        return (x, coords_mapping, modality_mapping, varlen_handler, local_attn_handler)

    def process_output(self, x: torch.Tensor):
        """Unpack the model's flat output back into (video, audio) tensors
        using the metadata saved by the matching process_input call.
        """
        # Inserting operations in between may corrupt parallel-runtime data and cause latent errors

        simple_packed_data: SimplePackedData = self.get_saved_data("simple_packed_data")
        x_video, x_audio = simple_packed_data.depack_token_sequence(x)

        return (x_video, x_audio)
|
inference/pipeline/entry.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
import sys
|
| 17 |
+
|
| 18 |
+
from inference.common import parse_config
|
| 19 |
+
from inference.infra import initialize_infra
|
| 20 |
+
from inference.model.dit import get_dit
|
| 21 |
+
from inference.utils import print_rank_0
|
| 22 |
+
|
| 23 |
+
try:
|
| 24 |
+
from .pipeline import MagiPipeline
|
| 25 |
+
except ImportError:
|
| 26 |
+
# Keep compatibility when entry.py is executed as a script path.
|
| 27 |
+
from inference.pipeline import MagiPipeline
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def parse_arguments():
    """Parse the known command-line options for the offline entry.

    Unknown arguments are deliberately tolerated (parse_known_args) so that
    config-system flags can share the same command line.
    """
    parser = argparse.ArgumentParser(description="Run DiT pipeline with unified offline entry.")
    parser.add_argument("--prompt", type=str)
    parser.add_argument("--save_path_prefix", type=str, help="Path prefix for saving outputs.")
    parser.add_argument("--output_path", type=str, help="Alias of --save_path_prefix for MAGI-style CLI.")

    parser.add_argument("--image_path", type=str, help="Path to image for i2v mode.")
    parser.add_argument(
        "--audio_path", type=str, default=None, help="Path to optional audio for lipsync mode; omit to use i2v or t2v"
    )

    # Optional runtime controls; forwarded to pipeline methods when provided.
    int_flags = (
        "--seed",
        "--seconds",
        "--br_width",
        "--br_height",
        "--sr_width",
        "--sr_height",
        "--output_width",
        "--output_height",
    )
    for flag in int_flags:
        parser.add_argument(flag, type=int)
    parser.add_argument("--upsample_mode", type=str)

    parsed, _ = parser.parse_known_args()
    return parsed
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def main():
    """Offline entry point: build model + pipeline from config, validate CLI
    arguments, and run one generation request.
    """
    args = parse_arguments()
    config = parse_config()
    # NOTE(review): the (expensive) model construction happens before the
    # prompt/image_path validation below, so a missing flag is only reported
    # after the model load.
    model = get_dit(config.arch_config, config.engine_config)
    pipeline = MagiPipeline(model, config.evaluation_config)
    # --output_path is accepted as an alias for --save_path_prefix.
    save_path_prefix = args.save_path_prefix or args.output_path
    if not save_path_prefix:
        print_rank_0("Error: --save_path_prefix (or --output_path) is required.")
        sys.exit(1)

    # Forward only the options the user actually supplied.
    optional_kwargs = {
        "seed": args.seed,
        "seconds": args.seconds,
        "br_width": args.br_width,
        "br_height": args.br_height,
        "sr_width": args.sr_width,
        "sr_height": args.sr_height,
        "output_width": args.output_width,
        "output_height": args.output_height,
        "upsample_mode": args.upsample_mode,
    }
    # The `is not False` clause is redundant for the int/str options above,
    # but kept for safety should a boolean flag be added later.
    optional_kwargs = {k: v for k, v in optional_kwargs.items() if v is not None and v is not False}

    prompt = args.prompt
    image_path = args.image_path
    audio_path = args.audio_path

    if not prompt:
        print_rank_0("Error: --prompt is required.")
        sys.exit(1)
    if not image_path:
        print_rank_0("Error: --image_path is required.")
        sys.exit(1)

    pipeline.run_offline(
        prompt=prompt, image=image_path, audio=audio_path, save_path_prefix=save_path_prefix, **optional_kwargs
    )
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
if __name__ == "__main__":
|
| 95 |
+
initialize_infra()
|
| 96 |
+
main()
|
inference/pipeline/pipeline.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import random
|
| 17 |
+
from typing import Optional, Union
|
| 18 |
+
|
| 19 |
+
import imageio
|
| 20 |
+
import soundfile as sf
|
| 21 |
+
import torch
|
| 22 |
+
from PIL import Image
|
| 23 |
+
|
| 24 |
+
from inference.common import EvaluationConfig, parse_config
|
| 25 |
+
from inference.model.dit import get_dit
|
| 26 |
+
from inference.model.dit import DiTModel
|
| 27 |
+
from .video_generate import MagiEvaluator
|
| 28 |
+
from .video_process import merge_video_and_audio, upsample_video
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class MagiPipeline:
    """Pipeline facade for offline video + audio generation.

    Wraps a base DiT model (and, optionally, a super-resolution DiT model)
    behind a single `run_offline` entry point that writes an .mp4 with the
    generated audio track merged in.
    """

    def __init__(self, model: DiTModel, evaluation_config: EvaluationConfig, device: str = "cuda"):
        """Build the evaluator, loading the SR model when configured.

        Args:
            model: The base DiT model used for generation.
            evaluation_config: Evaluation settings (SR usage, inference steps, fps, ...).
            device: Device string the evaluator runs on.
        """
        self.model = model
        self.evaluation_config = evaluation_config
        config = parse_config()
        if evaluation_config.use_sr_model:
            # Point the engine at the SR checkpoint before instantiating the SR DiT.
            config.engine_config.load = evaluation_config.sr_model_path
            sr_model = get_dit(config.sr_arch_config, config.engine_config)
        else:
            sr_model = None
        self.evaluator = MagiEvaluator(model, sr_model, evaluation_config, device)

    def _validate_offline_request(
        self,
        prompt: str,
        save_path_prefix: str,
    ):
        """Raise ValueError when a required string argument is empty or blank."""
        if not prompt or not prompt.strip():
            raise ValueError("`prompt` must be a non-empty string.")
        if not save_path_prefix or not save_path_prefix.strip():
            raise ValueError("`save_path_prefix` must be a non-empty string.")

    def run_offline(
        self,
        prompt: str,
        image: Union[str, Image.Image, None],
        audio: Optional[str],
        save_path_prefix: str,
        seed: int = 42,
        seconds: int = 4,
        br_width: int = 480,
        br_height: int = 272,
        sr_width: Optional[int] = None,
        sr_height: Optional[int] = None,
        output_width: Optional[int] = None,
        output_height: Optional[int] = None,
        upsample_mode: Optional[str] = None,
    ):
        """Generate a video for `prompt` and save it as `save_path_prefix_*.mp4`.

        Args:
            prompt: Text prompt (required, non-empty).
            image: Conditioning image path or PIL image, or None.
            audio: Optional conditioning audio path.
            save_path_prefix: Prefix for the final merged .mp4 path.
            seed: RNG seed applied inside a forked RNG state.
            seconds: Clip duration in seconds.
            br_width / br_height: Base-resolution frame size.
            sr_width / sr_height: Super-resolution frame size (when SR is enabled).
            output_width / output_height: Optional final upsampling target.
            upsample_mode: Mode forwarded to `upsample_video` when upsampling.

        Returns:
            The path of the merged output video.

        Raises:
            ValueError: If `prompt` or `save_path_prefix` is empty.
            RuntimeError: If the intermediate video file fails to be written.
        """
        self._validate_offline_request(prompt=prompt, save_path_prefix=save_path_prefix)

        if self.evaluator.sr_model is not None:
            save_path = f"{save_path_prefix}_{seconds}s_{br_width}x{br_height}_{sr_width}x{sr_height}.mp4"
        else:
            save_path = f"{save_path_prefix}_{seconds}s_{br_width}x{br_height}.mp4"

        # Fork the RNG state so seeding here does not perturb global randomness.
        with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
            torch.random.manual_seed(seed)
            video_np, audio_np = self.evaluator.evaluate(
                prompt,
                image,
                audio,
                seconds=seconds,
                br_width=br_width,
                br_height=br_height,
                sr_width=sr_width,
                sr_height=sr_height,
                br_num_inference_steps=self.evaluation_config.num_inference_steps,
                sr_num_inference_steps=self.evaluation_config.sr_num_inference_steps,
            )

        if output_width is not None and output_height is not None:
            video_np = upsample_video(video_np, output_width, output_height, upsample_mode)

        # Only the last rank writes output files; the others just hit the barrier.
        if torch.distributed.get_rank() == torch.distributed.get_world_size() - 1:
            saving_name = f"{prompt.replace(' ', '_')[:10]}"
            audio_path = saving_name + str(random.randint(0, 1000000)) + ".wav"
            video_path = saving_name + str(random.randint(0, 1000000)) + ".mp4"
            try:
                sf.write(audio_path, audio_np, self.evaluator.audio_vae.sample_rate)
                imageio.mimwrite(video_path, video_np, fps=self.evaluation_config.fps, quality=8, output_params=["-loglevel", "error"])
                # Explicit check instead of `assert`: asserts vanish under `python -O`.
                if not os.path.exists(video_path):
                    raise RuntimeError(f"Failed to write intermediate video file: {video_path}")
                merge_video_and_audio(video_path, audio_path, save_path)
            finally:
                # Remove the intermediate .wav/.mp4; only the merged result is kept.
                for tmp_path in (audio_path, video_path):
                    if os.path.exists(tmp_path):
                        os.remove(tmp_path)

        if torch.distributed.is_initialized():
            torch.distributed.barrier()
        return save_path
|
| 108 |
+
|
inference/pipeline/prompt_process.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from typing import Tuple
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
from torch.nn import functional as F
|
| 19 |
+
|
| 20 |
+
from inference.model.t5_gemma import get_t5_gemma_embedding
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def pad_or_trim(tensor: torch.Tensor, target_size: int, dim: int, pad_value: float = 0.0) -> Tuple[torch.Tensor, int]:
    """
    Pad (at the end) or trim a tensor along `dim` so it has exactly `target_size` elements there.

    Args:
        tensor (torch.Tensor): The input tensor to be processed.
        target_size (int): The desired size for the specified dimension.
        dim (int): The dimension along which to pad or trim. Negative values are supported.
        pad_value (float, optional): The value used for padding. Defaults to 0.0.

    Returns:
        Tuple[torch.Tensor, int]: The resized tensor, and the number of valid
        (original, un-padded) elements along `dim`, capped at `target_size`.
    """
    # Normalize negative dims so the F.pad index arithmetic below stays in range.
    dim = dim % tensor.dim()
    current_size = tensor.size(dim)
    if current_size < target_size:
        padding_amount = target_size - current_size
        # F.pad takes (left, right) pairs starting from the LAST dimension.
        padding_tuple = [0] * (2 * tensor.dim())
        padding_dim_index = tensor.dim() - 1 - dim
        padding_tuple[2 * padding_dim_index + 1] = padding_amount  # pad only at the end
        return F.pad(tensor, tuple(padding_tuple), "constant", pad_value), current_size

    slicing = [slice(None)] * tensor.dim()
    slicing[dim] = slice(0, target_size)
    return tensor[tuple(slicing)], target_size
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def get_padded_t5_gemma_embedding(
    prompt: str,
    model_path: str,
    device: str,
    weight_dtype: torch.dtype,
    target_length: int,
) -> Tuple[torch.Tensor, int]:
    """Embed `prompt` with T5-Gemma and pad/trim the sequence axis to `target_length`.

    Returns:
        Tuple[torch.Tensor, int]: The float32 embedding with a fixed sequence
        length, and the number of valid (un-padded) positions along dim 1.
    """
    embedding = get_t5_gemma_embedding(prompt, model_path, device, weight_dtype)
    fixed_embedding, valid_length = pad_or_trim(embedding, target_size=target_length, dim=1)
    return fixed_embedding.to(torch.float32), valid_length
|
| 59 |
+
|
| 60 |
+
|
inference/pipeline/scheduler_unipc.py
ADDED
|
@@ -0,0 +1,832 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2026 SandAI. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py
|
| 16 |
+
# Convert unipc for flow matching
|
| 17 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 18 |
+
import math
|
| 19 |
+
from typing import Any, List, Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 24 |
+
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
| 25 |
+
from diffusers.utils import deprecate
|
| 26 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
| 30 |
+
"""
|
| 31 |
+
`UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models.
|
| 32 |
+
|
| 33 |
+
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
| 34 |
+
methods the library implements for all schedulers such as loading and saving.
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
num_train_timesteps (`int`, defaults to 1000):
|
| 38 |
+
The number of diffusion steps to train the model.
|
| 39 |
+
solver_order (`int`, default `2`):
|
| 40 |
+
The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`
|
| 41 |
+
due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for
|
| 42 |
+
unconditional sampling.
|
| 43 |
+
prediction_type (`str`, defaults to "flow_prediction"):
|
| 44 |
+
Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts
|
| 45 |
+
the flow of the diffusion process.
|
| 46 |
+
thresholding (`bool`, defaults to `False`):
|
| 47 |
+
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
|
| 48 |
+
as Stable Diffusion.
|
| 49 |
+
dynamic_thresholding_ratio (`float`, defaults to 0.995):
|
| 50 |
+
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
|
| 51 |
+
sample_max_value (`float`, defaults to 1.0):
|
| 52 |
+
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
|
| 53 |
+
predict_x0 (`bool`, defaults to `True`):
|
| 54 |
+
Whether to use the updating algorithm on the predicted x0.
|
| 55 |
+
solver_type (`str`, default `bh2`):
|
| 56 |
+
Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2`
|
| 57 |
+
otherwise.
|
| 58 |
+
lower_order_final (`bool`, default `True`):
|
| 59 |
+
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
|
| 60 |
+
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
|
| 61 |
+
disable_corrector (`list`, default `[]`):
|
| 62 |
+
Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)`
|
| 63 |
+
and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is
|
| 64 |
+
usually disabled during the first few steps.
|
| 65 |
+
solver_p (`SchedulerMixin`, default `None`):
|
| 66 |
+
Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`.
|
| 67 |
+
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
|
| 68 |
+
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
|
| 69 |
+
the sigmas are determined according to a sequence of noise levels {σi}.
|
| 70 |
+
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
|
| 71 |
+
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
|
| 72 |
+
timestep_spacing (`str`, defaults to `"linspace"`):
|
| 73 |
+
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
| 74 |
+
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
| 75 |
+
steps_offset (`int`, defaults to 0):
|
| 76 |
+
An offset added to the inference steps, as required by some model families.
|
| 77 |
+
final_sigmas_type (`str`, defaults to `"zero"`):
|
| 78 |
+
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
|
| 79 |
+
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
|
| 80 |
+
"""
|
| 81 |
+
|
| 82 |
+
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
| 83 |
+
order = 1
|
| 84 |
+
|
| 85 |
+
    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        solver_order: int = 2,
        prediction_type: str = "flow_prediction",
        shift: float = 1.0,
        use_dynamic_shifting=False,
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        predict_x0: bool = True,
        solver_type: str = "bh2",
        lower_order_final: bool = True,
        disable_corrector: List[int] = [],
        solver_p: SchedulerMixin = None,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
        final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
    ):
        """Initialize scheduler state from the (auto-registered) config values."""
        # Map legacy solver-type aliases onto the supported B(h) variants.
        if solver_type not in ["bh1", "bh2"]:
            if solver_type in ["midpoint", "heun", "logrho"]:
                self.register_to_config(solver_type="bh2")
            else:
                raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")

        self.predict_x0 = predict_x0
        # setable values
        self.num_inference_steps = None
        # Training schedule: alphas ascend to 1, sigmas = 1 - alphas (flow matching).
        alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy()
        sigmas = 1.0 - alphas
        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)

        if not use_dynamic_shifting:
            # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)  # pyright: ignore

        self.sigmas = sigmas
        self.timesteps = sigmas * num_train_timesteps

        # Rolling buffers for the multistep predictor/corrector history.
        self.model_outputs = [None] * solver_order
        self.timestep_list = [None] * solver_order
        self.lower_order_nums = 0
        self.disable_corrector = disable_corrector
        self.solver_p = solver_p
        self.last_sample = None
        self._step_index: Optional[int] = None
        self._begin_index: Optional[int] = None

        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()
|
| 137 |
+
|
| 138 |
+
    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.

        Returns None until the first `step` call initializes it.
        """
        # Read-only view over the private counter maintained by `step`.
        return self._step_index
|
| 144 |
+
|
| 145 |
+
    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set by `inference.pipeline` with `set_begin_index`.

        Returns None if `set_begin_index` has not been called.
        """
        # Read-only view over the private attribute set via `set_begin_index`.
        return self._begin_index
|
| 151 |
+
|
| 152 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
| 153 |
+
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run by `inference.pipeline` before inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        # Exposed read-only through the `begin_index` property.
        self._begin_index = begin_index
|
| 162 |
+
|
| 163 |
+
# Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps
|
| 164 |
+
def set_timesteps(
|
| 165 |
+
self,
|
| 166 |
+
num_inference_steps: Union[int, None] = None,
|
| 167 |
+
device: Union[str, torch.device] = None,
|
| 168 |
+
sigmas: Optional[List[float]] = None,
|
| 169 |
+
mu: Optional[Union[float, None]] = None,
|
| 170 |
+
shift: Optional[Union[float, None]] = None,
|
| 171 |
+
):
|
| 172 |
+
"""
|
| 173 |
+
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
| 174 |
+
Args:
|
| 175 |
+
num_inference_steps (`int`):
|
| 176 |
+
Total number of the spacing of the time steps.
|
| 177 |
+
device (`str` or `torch.device`, *optional*):
|
| 178 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 179 |
+
"""
|
| 180 |
+
|
| 181 |
+
if self.config.use_dynamic_shifting and mu is None:
|
| 182 |
+
raise ValueError(" you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
|
| 183 |
+
|
| 184 |
+
if sigmas is None:
|
| 185 |
+
sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1] # type: ignore
|
| 186 |
+
|
| 187 |
+
if self.config.use_dynamic_shifting:
|
| 188 |
+
sigmas = self.time_shift(mu, 1.0, sigmas) # type: ignore
|
| 189 |
+
else:
|
| 190 |
+
if shift is None:
|
| 191 |
+
shift = self.config.shift
|
| 192 |
+
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) # type: ignore
|
| 193 |
+
|
| 194 |
+
if self.config.final_sigmas_type == "sigma_min":
|
| 195 |
+
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
|
| 196 |
+
elif self.config.final_sigmas_type == "zero":
|
| 197 |
+
sigma_last = 0
|
| 198 |
+
else:
|
| 199 |
+
raise ValueError(
|
| 200 |
+
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
timesteps = sigmas * self.config.num_train_timesteps
|
| 204 |
+
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) # pyright: ignore
|
| 205 |
+
|
| 206 |
+
self.sigmas = torch.from_numpy(sigmas)
|
| 207 |
+
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
|
| 208 |
+
|
| 209 |
+
self.num_inference_steps = len(timesteps) # type: ignore
|
| 210 |
+
|
| 211 |
+
self.model_outputs = [None] * self.config.solver_order
|
| 212 |
+
self.lower_order_nums = 0
|
| 213 |
+
self.last_sample = None
|
| 214 |
+
if self.solver_p:
|
| 215 |
+
self.solver_p.set_timesteps(self.num_inference_steps, device=device)
|
| 216 |
+
|
| 217 |
+
# add an index counter for schedulers that allow duplicated timesteps
|
| 218 |
+
self._step_index = None
|
| 219 |
+
self._begin_index = None
|
| 220 |
+
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
| 221 |
+
|
| 222 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
| 223 |
+
    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
        """
        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

        https://arxiv.org/abs/2205.11487
        """
        dtype = sample.dtype
        batch_size, channels, *remaining_dims = sample.shape

        if dtype not in (torch.float32, torch.float64):
            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each image
        sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
        s = torch.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

        # Restore the original shape and the caller's dtype.
        sample = sample.reshape(batch_size, channels, *remaining_dims)
        sample = sample.to(dtype)

        return sample
|
| 255 |
+
|
| 256 |
+
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t
|
| 257 |
+
def _sigma_to_t(self, sigma):
|
| 258 |
+
return sigma * self.config.num_train_timesteps
|
| 259 |
+
|
| 260 |
+
def _sigma_to_alpha_sigma_t(self, sigma):
|
| 261 |
+
return 1 - sigma, sigma
|
| 262 |
+
|
| 263 |
+
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps
|
| 264 |
+
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
|
| 265 |
+
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
| 266 |
+
|
| 267 |
+
    def convert_model_output(self, model_output: torch.Tensor, *args, sample: torch.Tensor = None, **kwargs) -> torch.Tensor:
        r"""
        Convert the model output to the corresponding type the UniPC algorithm needs.

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `torch.Tensor`:
                The converted model output.
        """
        # Legacy positional-argument handling: `timestep` used to be positional
        # and is now deprecated; `sample` may still arrive positionally.
        timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
        if sample is None:
            if len(args) > 1:
                sample = args[1]
            else:
                raise ValueError("missing `sample` as a required keyward argument")
        if timestep is not None:
            deprecate(
                "timesteps",
                "1.0.0",
                "Passing `timesteps` is deprecated "
                "and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        sigma = self.sigmas[self.step_index]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)

        if self.predict_x0:
            if self.config.prediction_type == "flow_prediction":
                sigma_t = self.sigmas[self.step_index]
                # Flow prediction: recover x0 as x_t - sigma_t * v.
                x0_pred = sample - sigma_t * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
                    " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
                )

            if self.config.thresholding:
                x0_pred = self._threshold_sample(x0_pred)

            return x0_pred
        else:
            if self.config.prediction_type == "flow_prediction":
                sigma_t = self.sigmas[self.step_index]
                # Flow prediction: recover the noise component as x_t - (1 - sigma_t) * v.
                epsilon = sample - (1 - sigma_t) * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
                    " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
                )

            if self.config.thresholding:
                # Threshold in x0 space, then map back to the epsilon parameterization.
                sigma_t = self.sigmas[self.step_index]
                x0_pred = sample - sigma_t * model_output
                x0_pred = self._threshold_sample(x0_pred)
                epsilon = model_output + x0_pred

            return epsilon
|
| 331 |
+
|
| 332 |
+
def multistep_uni_p_bh_update(
    self,
    model_output: torch.Tensor,
    *args,
    sample: Optional[torch.Tensor] = None,
    order: Optional[int] = None,  # pyright: ignore
    **kwargs,
) -> torch.Tensor:
    """
    One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified.

    Args:
        model_output (`torch.Tensor`):
            The direct output from the learned diffusion model at the current timestep.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        order (`int`):
            The order of UniP at this timestep (corresponds to the *p* in UniPC-p).

    Returns:
        `torch.Tensor`:
            The sample tensor at the previous timestep.

    Raises:
        ValueError: If `sample` or `order` is supplied neither positionally nor as a keyword.
    """
    # Legacy argument handling: `prev_timestep` used to be the first positional
    # argument; it is now ignored in favor of the internal `self.step_index`.
    prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
    if sample is None:
        if len(args) > 1:
            sample = args[1]
        else:
            raise ValueError("missing `sample` as a required keyword argument")
    if order is None:
        if len(args) > 2:
            order = args[2]
        else:
            raise ValueError("missing `order` as a required keyword argument")
    if prev_timestep is not None:
        deprecate(
            "prev_timestep",
            "1.0.0",
            "Passing `prev_timestep` is deprecated "
            "and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )
    model_output_list = self.model_outputs

    s0 = self.timestep_list[-1]
    m0 = model_output_list[-1]
    x = sample

    # Delegate to an explicitly configured first-order solver when present.
    if self.solver_p:
        x_t = self.solver_p.step(model_output, s0, x).prev_sample
        return x_t

    sigma_t, sigma_s0 = (self.sigmas[self.step_index + 1], self.sigmas[self.step_index])  # pyright: ignore
    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
    alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)

    # Half log-SNR (lambda) of the target and current noise levels.
    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)

    h = lambda_t - lambda_s0
    device = sample.device

    # Relative step sizes r_k and first-order differences D1 built from the
    # multistep history of converted model outputs.
    rks = []
    D1s: Optional[List[Any]] = []
    for i in range(1, order):
        si = self.step_index - i  # pyright: ignore
        mi = model_output_list[-(i + 1)]
        alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
        lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
        rk = (lambda_si - lambda_s0) / h
        rks.append(rk)
        D1s.append((mi - m0) / rk)  # type: ignore

    rks.append(1.0)
    rks = torch.tensor(rks, device=device)

    R = []
    b = []

    hh = -h if self.predict_x0 else h
    h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
    h_phi_k = h_phi_1 / hh - 1

    factorial_i = 1

    if self.config.solver_type == "bh1":
        B_h = hh
    elif self.config.solver_type == "bh2":
        B_h = torch.expm1(hh)
    else:
        raise NotImplementedError()

    # Assemble the small linear system R @ rhos = b that yields the UniP weights.
    for i in range(1, order + 1):
        R.append(torch.pow(rks, i - 1))
        b.append(h_phi_k * factorial_i / B_h)
        factorial_i *= i + 1
        h_phi_k = h_phi_k / hh - 1 / factorial_i

    R = torch.stack(R)
    b = torch.tensor(b, device=device)

    if len(D1s) > 0:  # type: ignore
        D1s = torch.stack(D1s, dim=1)  # (B, K)
        # for order 2, we use a simplified version
        if order == 2:
            rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
        else:
            rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)  # type: ignore
    else:
        D1s = None

    if self.predict_x0:
        x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
        if D1s is not None:
            pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)  # pyright: ignore
        else:
            pred_res = 0
        x_t = x_t_ - alpha_t * B_h * pred_res
    else:
        x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
        if D1s is not None:
            pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)  # pyright: ignore
        else:
            pred_res = 0
        x_t = x_t_ - sigma_t * B_h * pred_res

    x_t = x_t.to(x.dtype)
    return x_t
def multistep_uni_c_bh_update(
    self,
    this_model_output: torch.Tensor,
    *args,
    last_sample: Optional[torch.Tensor] = None,
    this_sample: Optional[torch.Tensor] = None,
    order: Optional[int] = None,  # pyright: ignore
    **kwargs,
) -> torch.Tensor:
    """
    One step for the UniC (B(h) version).

    Args:
        this_model_output (`torch.Tensor`):
            The model outputs at `x_t`.
        last_sample (`torch.Tensor`):
            The generated sample before the last predictor `x_{t-1}`.
        this_sample (`torch.Tensor`):
            The generated sample after the last predictor `x_{t}`.
        order (`int`):
            The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`.

    Returns:
        `torch.Tensor`:
            The corrected sample tensor at the current timestep.

    Raises:
        ValueError: If `last_sample`, `this_sample` or `order` is supplied neither
            positionally nor as a keyword.
    """
    # Legacy argument handling: `this_timestep` used to be the first positional
    # argument; it is now ignored in favor of the internal `self.step_index`.
    this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None)
    if last_sample is None:
        if len(args) > 1:
            last_sample = args[1]
        else:
            raise ValueError("missing `last_sample` as a required keyword argument")
    if this_sample is None:
        if len(args) > 2:
            this_sample = args[2]
        else:
            raise ValueError("missing `this_sample` as a required keyword argument")
    if order is None:
        if len(args) > 3:
            order = args[3]
        else:
            raise ValueError("missing `order` as a required keyword argument")
    if this_timestep is not None:
        deprecate(
            "this_timestep",
            "1.0.0",
            "Passing `this_timestep` is deprecated "
            "and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    model_output_list = self.model_outputs

    m0 = model_output_list[-1]
    x = last_sample
    model_t = this_model_output

    # The corrector looks backwards: current level vs. the level of the last
    # predictor step.
    sigma_t, sigma_s0 = (self.sigmas[self.step_index], self.sigmas[self.step_index - 1])  # pyright: ignore
    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
    alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)

    # Half log-SNR (lambda) of the current and previous noise levels.
    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)

    h = lambda_t - lambda_s0
    device = this_sample.device

    # Relative step sizes r_k and first-order differences D1 built from the
    # multistep history of converted model outputs.
    rks = []
    D1s: Optional[List[Any]] = []
    for i in range(1, order):
        si = self.step_index - (i + 1)  # pyright: ignore
        mi = model_output_list[-(i + 1)]
        alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
        lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
        rk = (lambda_si - lambda_s0) / h
        rks.append(rk)
        D1s.append((mi - m0) / rk)  # type: ignore

    rks.append(1.0)
    rks = torch.tensor(rks, device=device)

    R = []
    b = []

    hh = -h if self.predict_x0 else h
    h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
    h_phi_k = h_phi_1 / hh - 1

    factorial_i = 1

    if self.config.solver_type == "bh1":
        B_h = hh
    elif self.config.solver_type == "bh2":
        B_h = torch.expm1(hh)
    else:
        raise NotImplementedError()

    # Assemble the small linear system R @ rhos = b that yields the UniC weights.
    for i in range(1, order + 1):
        R.append(torch.pow(rks, i - 1))
        b.append(h_phi_k * factorial_i / B_h)
        factorial_i *= i + 1
        h_phi_k = h_phi_k / hh - 1 / factorial_i

    R = torch.stack(R)
    b = torch.tensor(b, device=device)

    if len(D1s) > 0:  # type: ignore
        D1s = torch.stack(D1s, dim=1)
    else:
        D1s = None

    # for order 1, we use a simplified version
    if order == 1:
        rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
    else:
        rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)

    if self.predict_x0:
        x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
        if D1s is not None:
            corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
        else:
            corr_res = 0
        D1_t = model_t - m0
        x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
    else:
        x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
        if D1s is not None:
            corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
        else:
            corr_res = 0
        D1_t = model_t - m0
        x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
    x_t = x_t.to(x.dtype)
    return x_t
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """Return the schedule index corresponding to `timestep`.

    Falls back to `self.timesteps` when no explicit schedule is supplied.
    """
    if schedule_timesteps is None:
        schedule_timesteps = self.timesteps

    matches = (schedule_timesteps == timestep).nonzero()

    # The sigma index taken for the **very** first `step` is always the second
    # match (or the last one when only a single match exists). This ensures we
    # do not accidentally skip a sigma when starting in the middle of the
    # denoising schedule (e.g. for image-to-image).
    pick = 1 if len(matches) > 1 else 0

    return matches[pick].item()
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
def _init_step_index(self, timestep):
    """Initialize the internal `step_index` counter for the scheduler."""
    # An explicitly configured begin index takes precedence over lookup.
    if self.begin_index is not None:
        self._step_index = self._begin_index
        return
    if isinstance(timestep, torch.Tensor):
        timestep = timestep.to(self.timesteps.device)
    self._step_index = self.index_for_timestep(timestep)
def step(
    self,
    model_output: torch.Tensor,
    timestep: Union[int, torch.Tensor],
    sample: torch.Tensor,
    return_dict: bool = True,
    generator=None,
) -> Union[SchedulerOutput, Tuple]:
    """
    Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
    the multistep UniPC (corrector on the previous prediction, then predictor for the next one).

    Args:
        model_output (`torch.Tensor`):
            The direct output from learned diffusion model.
        timestep (`int`):
            The current discrete timestep in the diffusion chain.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        return_dict (`bool`):
            Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
        generator:
            Unused here; kept for interface compatibility with other schedulers.

    Returns:
        [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
            If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
            tuple is returned where the first element is the sample tensor.

    Raises:
        ValueError: If `set_timesteps` has not been called before stepping.
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # Lazily resolve the internal step counter on the first call.
    if self.step_index is None:  # type: ignore
        self._init_step_index(timestep)

    # The corrector refines the previous predictor output; it requires at least
    # one completed step, the previous index not being in `disable_corrector`,
    # and a stored `last_sample`.
    use_corrector = (
        self.step_index > 0
        and self.step_index - 1 not in self.disable_corrector
        and self.last_sample is not None  # pyright: ignore
    )

    model_output_convert = self.convert_model_output(model_output, sample=sample)
    if use_corrector:
        sample = self.multistep_uni_c_bh_update(
            this_model_output=model_output_convert, last_sample=self.last_sample, this_sample=sample, order=self.this_order
        )

    # Shift the rolling history of model outputs / timesteps one slot left,
    # then append the latest converted output.
    for i in range(self.config.solver_order - 1):
        self.model_outputs[i] = self.model_outputs[i + 1]
        self.timestep_list[i] = self.timestep_list[i + 1]

    self.model_outputs[-1] = model_output_convert
    self.timestep_list[-1] = timestep  # pyright: ignore

    # Optionally reduce the solver order near the end of the schedule.
    if self.config.lower_order_final:
        this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index)  # pyright: ignore
    else:
        this_order = self.config.solver_order

    self.this_order = min(this_order, self.lower_order_nums + 1)  # warmup for multistep
    assert self.this_order > 0

    self.last_sample = sample
    prev_sample = self.multistep_uni_p_bh_update(
        model_output=model_output,  # pass the original non-converted model output, in case solver-p is used
        sample=sample,
        order=self.this_order,
    )

    if self.lower_order_nums < self.config.solver_order:
        self.lower_order_nums += 1

    # upon completion increase step index by one
    self._step_index += 1  # pyright: ignore

    if not return_dict:
        return (prev_sample,)

    return SchedulerOutput(prev_sample=prev_sample)
def step_ddim(
    # https://github.com/yifan123/flow_grpo/blob/main/flow_grpo/diffusers_patch/sd3_sde_with_logprob.py
    self,
    velocity: torch.FloatTensor,
    t: int,
    curr_state: torch.FloatTensor,
    prev_state: Optional[torch.FloatTensor] = None,
    generator: Optional[torch.Generator] = None,
):
    """
    Take one flow-matching (DDIM-style) step from noise level `sigmas[t]` to `sigmas[t + 1]`.

    The current state is first de-noised to a clean estimate using the predicted
    velocity, then re-noised with fresh Gaussian noise at the next noise level.

    Args:
        velocity (`torch.FloatTensor`): (B, C, T, H, W)
            The direct output from the learned flow model.
        t (`int`):
            Index into `self.sigmas` for the current noise level.
        curr_state (`torch.FloatTensor`): (B, C, T, H, W)
            A current instance of a sample created by the diffusion process.
        prev_state (`torch.FloatTensor`, *optional*):
            Unused; kept for interface symmetry with `step_sde`.
        generator (`torch.Generator`, *optional*):
            A random number generator used for the re-noising step.

    Returns:
        `torch.FloatTensor`: The sample at noise level `sigmas[t + 1]`.
    """
    device = curr_state.device
    curr_t = self.sigmas[t]
    prev_t = self.sigmas[t + 1]
    variance_noise = randn_tensor(curr_state.shape, generator=generator, device=device, dtype=curr_state.dtype)
    # Clean-sample estimate under the rectified-flow parameterization
    # x_t = (1 - t) * x0 + t * noise.
    cur_clean_ = curr_state - curr_t * velocity
    prev_state = prev_t * variance_noise + (1 - prev_t) * cur_clean_

    return prev_state
def step_sde(
    # https://github.com/yifan123/flow_grpo/blob/main/flow_grpo/diffusers_patch/sd3_sde_with_logprob.py
    self,
    velocity: torch.FloatTensor,
    t: int,
    curr_state: torch.FloatTensor,
    noise_theta: float = 1.0,
    prev_state: Optional[torch.FloatTensor] = None,
    generator: Optional[torch.Generator] = None,
):
    """
    Predict the sample from the previous timestep by reversing the SDE. This function propagates the flow
    process from the learned model outputs (most often the predicted velocity).

    Args:
        velocity (`torch.FloatTensor`): (B, C, T, H, W)
            The direct output from learned flow model.
        t (`int`):
            Index into `self.sigmas` for the current noise level.
        curr_state (`torch.FloatTensor`): (B, C, T, H, W)
            A current instance of a sample created by the diffusion process.
        noise_theta (`float`, *optional*, defaults to 1.0):
            Interpolation angle (scaled by pi/2) between deterministic flow and
            the stochastic SDE step; 0 degenerates to standard flow matching.
        prev_state (`torch.FloatTensor`, *optional*):
            If given, its value is kept but re-attached to the computed mean
            (no fresh noise is drawn).
        generator (`torch.Generator`, *optional*):
            A random number generator.
    """
    device = curr_state.device
    curr_t = self.sigmas[t]
    prev_t = self.sigmas[t + 1]
    cos = torch.cos(torch.tensor(noise_theta) * torch.pi / 2).to(device)  # if noise_theta is 0, it degenerates to standard flow matching
    sin = torch.sin(torch.tensor(noise_theta) * torch.pi / 2).to(device)
    # Mean of the transition: clean estimate (curr_state - curr_t * velocity)
    # re-noised towards prev_t, with the deterministic part scaled by cos.
    prev_sample_mean = (1 - prev_t + prev_t * cos) * (curr_state - curr_t * velocity) + prev_t * cos * velocity
    std_dev_t = prev_t * sin
    std_dev_t = torch.ones((1, 1)).to(curr_state) * std_dev_t
    if prev_state is None:
        variance_noise = randn_tensor(curr_state.shape, generator=generator, device=device, dtype=curr_state.dtype)
        prev_state = prev_sample_mean + std_dev_t * variance_noise
    else:
        # NOTE(review): numerically returns the supplied `prev_state` while
        # routing gradients through `prev_sample_mean` (straight-through style
        # re-attachment) — confirm against the flow_grpo reference.
        prev_state = prev_sample_mean + (prev_state - prev_sample_mean.detach())

    return prev_state
def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    """
    Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
    current timestep. This scheduler requires no scaling, so the sample is returned untouched.

    Args:
        sample (`torch.Tensor`):
            The input sample.

    Returns:
        `torch.Tensor`:
            The (unscaled) input sample.
    """
    # No-op: UniPC operates directly on the unscaled sample.
    return sample
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
    """Diffuse `original_samples` to the noise levels given by `timesteps`.

    Returns `alpha_t * original_samples + sigma_t * noise`, with the per-sample
    coefficients looked up from the sigma schedule.
    """
    # Align the sigma table with the samples' device and dtype.
    sigma_table = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
    if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
        # mps does not support float64
        schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
        timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
    else:
        schedule_timesteps = self.timesteps.to(original_samples.device)
        timesteps = timesteps.to(original_samples.device)

    # begin_index is None when the scheduler is used for training or the
    # pipeline does not implement set_begin_index.
    if self.begin_index is None:
        step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
    elif self.step_index is not None:
        # add_noise is called after the first denoising step (for inpainting).
        step_indices = [self.step_index] * timesteps.shape[0]
    else:
        # add_noise is called before the first denoising step to create the
        # initial latent (img2img).
        step_indices = [self.begin_index] * timesteps.shape[0]

    sigma = sigma_table[step_indices].flatten()
    # Right-pad with singleton dims so the coefficients broadcast per-sample.
    extra_dims = original_samples.ndim - sigma.ndim
    sigma = sigma.reshape(sigma.shape + (1,) * extra_dims)

    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
    return alpha_t * original_samples + sigma_t * noise
def __len__(self):
    """The scheduler's length equals the configured number of training timesteps."""
    return self.config.num_train_timesteps
|