ethanchern commited on
Commit
873b6ec
·
0 Parent(s):
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. .gitignore +46 -0
  3. .pre-commit-config.yaml +53 -0
  4. README.md +191 -0
  5. example/assets/image.png +3 -0
  6. example/assets/prompt.txt +7 -0
  7. example/base/config.json +16 -0
  8. example/base/run.sh +29 -0
  9. example/distill/config.json +17 -0
  10. example/distill/run.sh +29 -0
  11. example/sr_1080p/config.json +20 -0
  12. example/sr_1080p/run.sh +33 -0
  13. example/sr_540p/config.json +20 -0
  14. example/sr_540p/run.sh +32 -0
  15. inference/common/__init__.py +39 -0
  16. inference/common/arch.py +35 -0
  17. inference/common/config.py +283 -0
  18. inference/common/cpu_offload_wrapper.py +186 -0
  19. inference/common/sequence_schema.py +33 -0
  20. inference/infra/__init__.py +37 -0
  21. inference/infra/checkpoint/__init__.py +20 -0
  22. inference/infra/checkpoint/load_model_checkpoint.py +99 -0
  23. inference/infra/distributed/__init__.py +28 -0
  24. inference/infra/distributed/init_dist_env.py +62 -0
  25. inference/infra/distributed/parallel_state.py +659 -0
  26. inference/infra/distributed/utils.py +47 -0
  27. inference/infra/parallelism/__init__.py +20 -0
  28. inference/infra/parallelism/all_to_all_primitive.py +142 -0
  29. inference/infra/parallelism/gather_scatter_primitive.py +217 -0
  30. inference/infra/parallelism/ulysses_scheduler.py +143 -0
  31. inference/model/dit/__init__.py +18 -0
  32. inference/model/dit/dit_model.py +42 -0
  33. inference/model/dit/dit_module.py +950 -0
  34. inference/model/sa_audio/__init__.py +25 -0
  35. inference/model/sa_audio/sa_audio_model.py +116 -0
  36. inference/model/sa_audio/sa_audio_module.py +478 -0
  37. inference/model/t5_gemma/__init__.py +3 -0
  38. inference/model/t5_gemma/t5_gemma_model.py +43 -0
  39. inference/model/turbo_vaed/__init__.py +4 -0
  40. inference/model/turbo_vaed/turbo_vaed_model.py +33 -0
  41. inference/model/turbo_vaed/turbo_vaed_module.py +1039 -0
  42. inference/model/vae2_2/__init__.py +3 -0
  43. inference/model/vae2_2/vae2_2_model.py +17 -0
  44. inference/model/vae2_2/vae2_2_module.py +1086 -0
  45. inference/pipeline/__init__.py +20 -0
  46. inference/pipeline/data_proxy.py +390 -0
  47. inference/pipeline/entry.py +96 -0
  48. inference/pipeline/pipeline.py +108 -0
  49. inference/pipeline/prompt_process.py +60 -0
  50. inference/pipeline/scheduler_unipc.py +832 -0
.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ *.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ tmp*
2
+ depyf
3
+ torch_compile_cache
4
+
5
+ __pycache__
6
+ *.so
7
+ build
8
+ .coverage_*
9
+ *.egg-info
10
+ *~
11
+ slurm*
12
+ logs
13
+ .vscode
14
+ nsys*
15
+ tmp/*
16
+ .mypy_cache
17
+ output
18
+ *.pyc
19
+ *.log
20
+ .idea
21
+ *.pt
22
+ *.png
23
+ *.jpg
24
+ *.jpeg
25
+ *.gif
26
+ *.mp3
27
+ *.mp4
28
+ *.pickle
29
+ *.nsys-rep
30
+ *.html
31
+ *.mov
32
+ *.safetensors
33
+ *.json
34
+
35
+ # Keep example media assets tracked.
36
+ !example/assets/*.png
37
+ !example/assets/*.mp4
38
+ !example/**/*.json
39
+
40
+ proj*
41
+ .venv
42
+ var
43
+ tags
44
+ fx_graph*.pdf
45
+ /clean_repo.py
46
+ /rm_caches.sh
.pre-commit-config.yaml ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ exclude: \.patch$
2
+ repos:
3
+ - repo: https://github.com/pre-commit/pre-commit-hooks
4
+ rev: v4.4.0
5
+ hooks:
6
+ - id: check-added-large-files
7
+ args:
8
+ - --maxkb=30720
9
+ - id: check-merge-conflict
10
+ - id: check-symlinks
11
+ - id: detect-private-key
12
+ files: (?!.*third_party)^.*$ | (?!.*book)^.*$
13
+ - id: end-of-file-fixer
14
+ - id: trailing-whitespace
15
+ - id: requirements-txt-fixer
16
+ - id: sort-simple-yaml
17
+ - repo: https://github.com/Lucas-C/pre-commit-hooks.git
18
+ rev: v1.5.1
19
+ hooks:
20
+ - id: remove-crlf
21
+ files: (?!.*third_party)^.*$ | (?!.*book)^.*$
22
+ - id: remove-tabs
23
+ name: Tabs remover (C++)
24
+ files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|xpu|kps)$
25
+ args: [--whitespaces-count, '2']
26
+ - id: remove-tabs
27
+ name: Tabs remover (Python)
28
+ files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
29
+ args: [--whitespaces-count, '4']
30
+ - repo: https://github.com/psf/black.git
31
+ rev: 23.3.0
32
+ hooks:
33
+ - id: black
34
+ args: [--line-length=127, --skip-string-normalization, --skip-magic-trailing-comma]
35
+ files: (.*\.(py|pyi|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
36
+ - repo: https://github.com/pre-commit/mirrors-isort
37
+ rev: v5.10.1
38
+ hooks:
39
+ - id: isort
40
+ args: [--profile=black, --line-length=127, --multi-line=3, --force-grid-wrap=0, --src-path=infra, --src-path=pipeline, --src-path=model]
41
+ files: \.py$
42
+ - repo: https://github.com/PyCQA/autoflake
43
+ rev: v2.3.1
44
+ hooks:
45
+ - id: autoflake
46
+ args: [--remove-all-unused-imports, --remove-unused-variables, --in-place, --ignore-init-module-imports, --ignore-pass-after-docstring]
47
+ files: \.py$
48
+ - repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks.git
49
+ rev: v2.9.0
50
+ hooks:
51
+ - id: pretty-format-yaml
52
+ args: [--autofix, --indent, '4']
53
+ additional_dependencies: [setuptools]
README.md ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+
3
+ # daVinci-MagiHuman
4
+
5
+ ### Speed by Simplicity: A Single-Stream Architecture for Fast Audio-Video Generative Foundation Model
6
+
7
+ <p align="center">
8
+ <a href="https://www.sjtu.edu.cn/">SII-GAIR</a> &nbsp;&amp;&nbsp; <a href="https://sand.ai">Sand.ai</a>
9
+ </p>
10
+
11
+ [![Paper](https://img.shields.io/badge/Paper-PDF-red)](https://github.com/GAIR-NLP/daVinci-MagiHuman/blob/main/assets/daVinci_MagiHuman.pdf)
12
+ [![Demo](https://img.shields.io/badge/%F0%9F%A4%97%20Demo-HuggingFace-orange)](https://huggingface.co/spaces/SII-GAIR/daVinci-MagiHuman)
13
+ [![Models](https://img.shields.io/badge/%F0%9F%A4%97%20Models-HuggingFace-yellow)](https://huggingface.co/GAIR-NLP/daVinci-MagiHuman)
14
+ [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
15
+ [![Python](https://img.shields.io/badge/Python-3.12%2B-blue.svg)](https://www.python.org/)
16
+ [![PyTorch](https://img.shields.io/badge/PyTorch-2.9%2B-ee4c2c.svg)](https://pytorch.org/)
17
+
18
+ </div>
19
+
20
+ ## Highlights
21
+
22
+ - **Single-Stream Transformer** — A unified 15B-parameter, 40-layer Transformer that jointly processes text, video, and audio via self-attention only. No cross-attention, no multi-stream complexity.
23
+ - **Exceptional Human-Centric Quality** — Expressive facial performance, natural speech-expression coordination, realistic body motion, and accurate audio-video synchronization.
24
+ - **Multilingual** — Supports Chinese (Mandarin & Cantonese), English, Japanese, Korean, German, and French.
25
+ - **Blazing Fast Inference** — Generates a 5-second 256p video in **2 seconds** and a 5-second 1080p video in **38 seconds** on a single H100 GPU.
26
+ - **State-of-the-Art Results** — Achieves **80.0%** win rate vs Ovi 1.1 and **60.9%** vs LTX 2.3 in pairwise human evaluation over 2,000 comparisons.
27
+ - **Fully Open Source** — We release the complete model stack: base model, distilled model, super-resolution model, and inference code.
28
+
29
+ ## Demo
30
+
31
+ <!--
32
+ To add demo videos:
33
+ 1. Open a GitHub issue on this repo
34
+ 2. Drag & drop your .mp4 files into the issue comment box
35
+ 3. Copy the generated URLs and paste them below
36
+
37
+ Example:
38
+ https://github.com/user-attachments/assets/xxxx-xxxx
39
+ -->
40
+
41
+ https://github.com/user-attachments/assets/PLACEHOLDER_VIDEO_1
42
+
43
+ https://github.com/user-attachments/assets/PLACEHOLDER_VIDEO_2
44
+
45
+ https://github.com/user-attachments/assets/PLACEHOLDER_VIDEO_3
46
+
47
+ ## Architecture
48
+
49
+ <div align="center">
50
+ <img src="assets/architecture.png" width="90%">
51
+ </div>
52
+
53
+ daVinci-MagiHuman uses a single-stream Transformer that takes text tokens, a reference image latent, and noisy video and audio tokens as input, and jointly denoises the video and audio within a unified token sequence.
54
+
55
+ Key design choices:
56
+
57
+ | Component | Description |
58
+ |---|---|
59
+ | **Sandwich Architecture** | First and last 4 layers use modality-specific projections; middle 32 layers share parameters across modalities |
60
+ | **Timestep-Free Denoising** | No explicit timestep embeddings — the model infers the denoising state directly from input latents |
61
+ | **Per-Head Gating** | Learned scalar gates with sigmoid activation on each attention head for training stability |
62
+ | **Unified Conditioning** | Denoising and reference signals handled through a minimal unified interface — no dedicated conditioning branches |
63
+
64
+ ## Performance
65
+
66
+ ### Quantitative Quality Benchmark
67
+
68
+ | Model | Visual Quality ↑ | Text Alignment ↑ | Physical Consistency ↑ | WER ↓ |
69
+ |---|:---:|:---:|:---:|:---:|
70
+ | OVI 1.1 | 4.73 | 4.10 | 4.41 | 40.45% |
71
+ | LTX 2.3 | 4.76 | 4.12 | **4.56** | 19.23% |
72
+ | **daVinci-MagiHuman** | **4.80** | **4.18** | 4.52 | **14.60%** |
73
+
74
+ ### Human Evaluation (2,000 Pairwise Comparisons)
75
+
76
+ | Matchup | daVinci-MagiHuman Win | Tie | Opponent Win |
77
+ |---|:---:|:---:|:---:|
78
+ | vs Ovi 1.1 | **80.0%** | 8.2% | 11.8% |
79
+ | vs LTX 2.3 | **60.9%** | 17.2% | 21.9% |
80
+
81
+ ### Inference Speed (Single H100 GPU, 5-second video)
82
+
83
+ | Resolution | Base (s) | Super-Res (s) | Decode (s) | **Total (s)** |
84
+ |---|:---:|:---:|:---:|:---:|
85
+ | 256p | 1.6 | — | 0.4 | **2.0** |
86
+ | 540p | 1.6 | 5.1 | 1.3 | **8.0** |
87
+ | 1080p | 1.6 | 31.0 | 5.8 | **38.4** |
88
+
89
+ ## Efficient Inference Techniques
90
+
91
+ - **Latent-Space Super-Resolution** — Two-stage pipeline: generate at low resolution, then refine in latent space (not pixel space), avoiding an extra VAE decode-encode round trip.
92
+ - **Turbo VAE Decoder** — A lightweight re-trained decoder that substantially reduces decoding overhead.
93
+ - **Full-Graph Compilation** — [MagiCompiler](https://github.com/sandai/MagiCompiler) fuses operators across Transformer layers for ~1.2x speedup.
94
+ - **Distillation** — DMD-2 distillation enables generation with only 8 denoising steps (no CFG), without sacrificing quality.
95
+
96
+ ## Getting Started
97
+
98
+ ### Option 1: Docker (Recommended)
99
+
100
+ ```bash
101
+ # Pull the MagiCompiler Docker image
102
+ docker pull sandai/magi-compiler:latest
103
+
104
+ # Launch container
105
+ docker run -it --gpus all \
106
+ -v /path/to/models:/models \
107
+ sandai/magi-compiler:latest bash
108
+
109
+ # Install MagiCompiler
110
+ git clone https://github.com/sandai/MagiCompiler
111
+ cd MagiCompiler
112
+ pip install -e . --no-build-isolation --config-settings editable_mode=compat
113
+ cd ..
114
+
115
+ # Clone daVinci-MagiHuman
116
+ git clone https://github.com/GAIR-NLP/daVinci-MagiHuman
117
+ cd daVinci-MagiHuman
118
+ ```
119
+
120
+ ### Option 2: Conda
121
+
122
+ ```bash
123
+ # Create environment
124
+ conda create -n davinci python=3.12
125
+ conda activate davinci
126
+
127
+ # Install PyTorch
128
+ pip install torch==2.9.0 torchvision==0.24.0 torchaudio==2.9.0
129
+
130
+ # Install Flash Attention (Hopper)
131
+ git clone https://github.com/Dao-AILab/flash-attention
132
+ cd flash-attention/hopper && python setup.py install && cd ../..
133
+
134
+ # Install MagiCompiler
135
+ git clone https://github.com/sandai/MagiCompiler
136
+ cd MagiCompiler
137
+ pip install -e . --no-build-isolation --config-settings editable_mode=compat
138
+ cd ..
139
+
140
+ # Clone and install daVinci-MagiHuman
141
+ git clone https://github.com/GAIR-NLP/daVinci-MagiHuman
142
+ cd daVinci-MagiHuman
143
+ pip install -r requirements.txt
144
+ ```
145
+
146
+ ### Download Model Checkpoints
147
+
148
+ Download the complete model stack from [HuggingFace](https://huggingface.co/GAIR-NLP/daVinci-MagiHuman) and update the paths in the config files under `example/`.
149
+
150
+ ## Usage
151
+
152
+ Before running, update the checkpoint paths in the config files (`example/*/config.json`) to point to your local model directory.
153
+
154
+ **Base Model (256p)**
155
+ ```bash
156
+ bash example/base/run.sh
157
+ ```
158
+
159
+ **Distilled Model (256p, 8 steps, no CFG)**
160
+ ```bash
161
+ bash example/distill/run.sh
162
+ ```
163
+
164
+ **Super-Resolution to 540p**
165
+ ```bash
166
+ bash example/sr_540p/run.sh
167
+ ```
168
+
169
+ **Super-Resolution to 1080p**
170
+ ```bash
171
+ bash example/sr_1080p/run.sh
172
+ ```
173
+
174
+ ## Citation
175
+
176
+ ```bibtex
177
+ @misc{davinci-magihuman-2025,
178
+ title = {Speed by Simplicity: A Single-Stream Architecture for Fast Audio-Video Generative Foundation Model},
179
+ author = {SII-GAIR and Sand.ai},
180
+ year = {2025},
181
+ url = {https://github.com/GAIR-NLP/daVinci-MagiHuman}
182
+ }
183
+ ```
184
+
185
+ ## Acknowledgements
186
+
187
+ daVinci-MagiHuman builds upon several outstanding open-source projects, including [Wan2.2](https://github.com/Wan-Video/Wan2.2), [Flash Attention](https://github.com/Dao-AILab/flash-attention), and [Turbo-VAED](https://github.com/zou-group/turbo-vaed). We thank the broader open-source community for making this work possible.
188
+
189
+ ## License
190
+
191
+ This project is released under the [Apache License 2.0](https://opensource.org/licenses/Apache-2.0).
example/assets/image.png ADDED

Git LFS Details

  • SHA256: 0659ddf2d52dea107c8437889d850400929901676916ba3c5fe5feab4b116f65
  • Pointer size: 132 Bytes
  • Size of remote file: 1.01 MB
example/assets/prompt.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ "A man with dark hair and glasses, wearing a green button-up shirt and black gloves, stands behind a counter with a pan, gesturing with his left hand, while a blonde woman with her hair in a bun, dressed in a white shirt, holds a microphone to her mouth, looking intently at the pan. The scene is set outdoors under a bright, overcast sky, with decorative palm tree cutouts and a light green vintage Volkswagen bus visible in the background, suggesting a relaxed, possibly tropical, cooking demonstration. The overall emotional disposition is one of focused engagement and professional presentation. The camera maintains a static medium shot, capturing both individuals from the waist up, with a shallow depth of field that keeps them sharp while blurring the background elements. The lighting is bright and even, typical of outdoor daylight, with soft shadows. The color grading is natural and vibrant, reflecting the outdoor setting. The man, with a slight smile, explains in a clear, steady, and informative tone, ""Pulver mit dran gemacht, gibt's ja auch als Paste, aber als Pulver ist das hier ein bisschen..."" as he gestures towards the pan with his left hand, his right hand resting on the counter. The woman listens attentively, her eyebrows slightly raised, her mouth slightly open in an expression of curiosity and concentration, her gaze fixed on the pan.
2
+
3
+ Dialogue:
4
+ <Man in green shirt, German>: ""Pulver mit dran gemacht, gibt's ja auch als Paste, aber als Pulver ist das hier ein bisschen...""
5
+
6
+ Background Sound:
7
+ <No prominent background sound effects>"
example/base/config.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "engine_config": {
3
+ "load": "/home/niubility2/hongyu/open_source_ckpt/base",
4
+ "cp_size": 1
5
+ },
6
+ "evaluation_config": {
7
+ "cfg_number": 2,
8
+ "num_inference_steps": 32,
9
+ "audio_model_path": "/home/niubility2/hongyu/open_source_ckpt/audio",
10
+ "txt_model_path": "/home/niubility2/hongyu/open_source_ckpt/t5/t5gemma-9b-9b-ul2",
11
+ "vae_model_path": "/home/niubility2/hongyu/open_source_ckpt/wan_vae/Wan2.2-TI2V-5B",
12
+ "use_turbo_vae": true,
13
+ "student_config_path": "/home/niubility2/hongyu/open_source_ckpt/turbo_vae/TurboV3-Wan22-TinyShallow_7_7.json",
14
+ "student_ckpt_path": "/home/niubility2/hongyu/open_source_ckpt/turbo_vae/checkpoint-340000.ckpt"
15
+ }
16
+ }
example/base/run.sh ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Launch the base (256p) model example via torchrun on a single node.

set -euo pipefail

PROJECT_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"
cd "$PROJECT_ROOT"

# Rendezvous / topology settings, all overridable from the environment.
export MASTER_ADDR="${MASTER_ADDR:-localhost}"
export MASTER_PORT="${MASTER_PORT:-6009}"
export NNODES="${NNODES:-1}"
export NODE_RANK="${NODE_RANK:-0}"
export GPUS_PER_NODE="${GPUS_PER_NODE:-1}"
export WORLD_SIZE="$((GPUS_PER_NODE * NNODES))"

export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
export NCCL_ALGO="${NCCL_ALGO:-^NVLS}"
export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}"

DISTRIBUTED_ARGS="--nnodes=${NNODES} --node_rank=${NODE_RANK} --nproc_per_node=${GPUS_PER_NODE} --rdzv-backend=c10d --rdzv-endpoint=${MASTER_ADDR}:${MASTER_PORT}"

# Capture one timestamp so the output directory and log file names always
# match, even if two separate $(date) calls would straddle a second boundary.
TIMESTAMP="$(date '+%Y%m%d_%H%M%S')"

torchrun ${DISTRIBUTED_ARGS} inference/pipeline/entry.py \
    --config-load-path example/base/config.json \
    --prompt "$(<example/assets/prompt.txt)" \
    --image_path example/assets/image.png \
    --seconds 10 \
    --br_width 448 \
    --br_height 256 \
    --output_path "output_example_base_${TIMESTAMP}" \
    2>&1 | tee "log_example_base_${TIMESTAMP}.log"
example/distill/config.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "engine_config": {
3
+ "load": "/home/niubility2/hongyu/open_source_ckpt/distill",
4
+ "distill": true,
5
+ "cp_size": 1
6
+ },
7
+ "evaluation_config": {
8
+ "cfg_number": 1,
9
+ "num_inference_steps": 8,
10
+ "audio_model_path": "/home/niubility2/hongyu/open_source_ckpt/audio",
11
+ "txt_model_path": "/home/niubility2/hongyu/open_source_ckpt/t5/t5gemma-9b-9b-ul2",
12
+ "vae_model_path": "/home/niubility2/hongyu/open_source_ckpt/wan_vae/Wan2.2-TI2V-5B",
13
+ "use_turbo_vae": true,
14
+ "student_config_path": "/home/niubility2/hongyu/open_source_ckpt/turbo_vae/TurboV3-Wan22-TinyShallow_7_7.json",
15
+ "student_ckpt_path": "/home/niubility2/hongyu/open_source_ckpt/turbo_vae/checkpoint-340000.ckpt"
16
+ }
17
+ }
example/distill/run.sh ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Launch the distilled (256p, 8-step, no-CFG) model example via torchrun.

set -euo pipefail

PROJECT_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"
cd "$PROJECT_ROOT"

# Rendezvous / topology settings, all overridable from the environment.
export MASTER_ADDR="${MASTER_ADDR:-localhost}"
export MASTER_PORT="${MASTER_PORT:-6010}"
export NNODES="${NNODES:-1}"
export NODE_RANK="${NODE_RANK:-0}"
export GPUS_PER_NODE="${GPUS_PER_NODE:-1}"
export WORLD_SIZE="$((GPUS_PER_NODE * NNODES))"

export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
export NCCL_ALGO="${NCCL_ALGO:-^NVLS}"
export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}"

DISTRIBUTED_ARGS="--nnodes=${NNODES} --node_rank=${NODE_RANK} --nproc_per_node=${GPUS_PER_NODE} --rdzv-backend=c10d --rdzv-endpoint=${MASTER_ADDR}:${MASTER_PORT}"

# Capture one timestamp so the output directory and log file names always
# match, even if two separate $(date) calls would straddle a second boundary.
TIMESTAMP="$(date '+%Y%m%d_%H%M%S')"

torchrun ${DISTRIBUTED_ARGS} inference/pipeline/entry.py \
    --config-load-path example/distill/config.json \
    --prompt "$(<example/assets/prompt.txt)" \
    --image_path example/assets/image.png \
    --seconds 10 \
    --br_width 448 \
    --br_height 256 \
    --output_path "output_example_distill_${TIMESTAMP}" \
    2>&1 | tee "log_example_distill_${TIMESTAMP}.log"
example/sr_1080p/config.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "engine_config": {
3
+ "load": "/home/niubility2/hongyu/open_source_ckpt/base",
4
+ "cp_size": 1
5
+ },
6
+ "evaluation_config": {
7
+ "cfg_number": 2,
8
+ "num_inference_steps": 32,
9
+ "audio_model_path": "/home/niubility2/hongyu/open_source_ckpt/audio",
10
+ "txt_model_path": "/home/niubility2/hongyu/open_source_ckpt/t5/t5gemma-9b-9b-ul2",
11
+ "vae_model_path": "/home/niubility2/hongyu/open_source_ckpt/wan_vae/Wan2.2-TI2V-5B",
12
+ "use_sr_model": true,
13
+ "sr_model_path": "/home/niubility2/hongyu/open_source_ckpt/1080p_sr",
14
+ "sr_num_inference_steps": 5,
15
+ "sr_cfg_number": 1,
16
+ "use_turbo_vae": true,
17
+ "student_config_path": "/home/niubility2/hongyu/open_source_ckpt/turbo_vae/TurboV3-Wan22-TinyShallow_7_7.json",
18
+ "student_ckpt_path": "/home/niubility2/hongyu/open_source_ckpt/turbo_vae/checkpoint-340000.ckpt"
19
+ }
20
+ }
example/sr_1080p/run.sh ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Launch the base model plus 1080p latent super-resolution example via torchrun.

set -euo pipefail

PROJECT_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"
cd "$PROJECT_ROOT"

# Rendezvous / topology settings, all overridable from the environment.
export MASTER_ADDR="${MASTER_ADDR:-localhost}"
export MASTER_PORT="${MASTER_PORT:-6012}"
export NNODES="${NNODES:-1}"
export NODE_RANK="${NODE_RANK:-0}"
export GPUS_PER_NODE="${GPUS_PER_NODE:-1}"
export WORLD_SIZE="$((GPUS_PER_NODE * NNODES))"

export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
export NCCL_ALGO="${NCCL_ALGO:-^NVLS}"
export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}"
export SR2_1080="${SR2_1080:-true}"
export CPU_OFFLOAD="${CPU_OFFLOAD:-true}"

DISTRIBUTED_ARGS="--nnodes=${NNODES} --node_rank=${NODE_RANK} --nproc_per_node=${GPUS_PER_NODE} --rdzv-backend=c10d --rdzv-endpoint=${MASTER_ADDR}:${MASTER_PORT}"

# Capture one timestamp so the output directory and log file names always
# match, even if two separate $(date) calls would straddle a second boundary.
TIMESTAMP="$(date '+%Y%m%d_%H%M%S')"

torchrun ${DISTRIBUTED_ARGS} inference/pipeline/entry.py \
    --config-load-path example/sr_1080p/config.json \
    --prompt "$(<example/assets/prompt.txt)" \
    --image_path example/assets/image.png \
    --seconds 10 \
    --br_width 448 \
    --br_height 256 \
    --output_path "output_example_sr_1080p_${TIMESTAMP}" \
    --sr_width 1920 \
    --sr_height 1088 \
    2>&1 | tee "log_example_sr_1080p_${TIMESTAMP}.log"
example/sr_540p/config.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "engine_config": {
3
+ "load": "/home/niubility2/hongyu/open_source_ckpt/base",
4
+ "cp_size": 1
5
+ },
6
+ "evaluation_config": {
7
+ "cfg_number": 2,
8
+ "num_inference_steps": 32,
9
+ "audio_model_path": "/home/niubility2/hongyu/open_source_ckpt/audio",
10
+ "txt_model_path": "/home/niubility2/hongyu/open_source_ckpt/t5/t5gemma-9b-9b-ul2",
11
+ "vae_model_path": "/home/niubility2/hongyu/open_source_ckpt/wan_vae/Wan2.2-TI2V-5B",
12
+ "use_sr_model": true,
13
+ "sr_model_path": "/home/niubility2/hongyu/open_source_ckpt/540p_sr",
14
+ "sr_num_inference_steps": 5,
15
+ "sr_cfg_number": 1,
16
+ "use_turbo_vae": true,
17
+ "student_config_path": "/home/niubility2/hongyu/open_source_ckpt/turbo_vae/TurboV3-Wan22-TinyShallow_7_7.json",
18
+ "student_ckpt_path": "/home/niubility2/hongyu/open_source_ckpt/turbo_vae/checkpoint-340000.ckpt"
19
+ }
20
+ }
example/sr_540p/run.sh ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Launch the base model plus 540p latent super-resolution example via torchrun.

set -euo pipefail

PROJECT_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"
cd "$PROJECT_ROOT"

# Rendezvous / topology settings, all overridable from the environment.
export MASTER_ADDR="${MASTER_ADDR:-localhost}"
export MASTER_PORT="${MASTER_PORT:-6011}"
export NNODES="${NNODES:-1}"
export NODE_RANK="${NODE_RANK:-0}"
export GPUS_PER_NODE="${GPUS_PER_NODE:-1}"
export WORLD_SIZE="$((GPUS_PER_NODE * NNODES))"

export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
export NCCL_ALGO="${NCCL_ALGO:-^NVLS}"
export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}"
# Overridable default (was hard-coded), matching sr_1080p/run.sh.
export CPU_OFFLOAD="${CPU_OFFLOAD:-true}"

DISTRIBUTED_ARGS="--nnodes=${NNODES} --node_rank=${NODE_RANK} --nproc_per_node=${GPUS_PER_NODE} --rdzv-backend=c10d --rdzv-endpoint=${MASTER_ADDR}:${MASTER_PORT}"

# Capture one timestamp so the output directory and log file names always
# match, even if two separate $(date) calls would straddle a second boundary.
TIMESTAMP="$(date '+%Y%m%d_%H%M%S')"

torchrun ${DISTRIBUTED_ARGS} inference/pipeline/entry.py \
    --config-load-path example/sr_540p/config.json \
    --prompt "$(<example/assets/prompt.txt)" \
    --image_path example/assets/image.png \
    --seconds 10 \
    --br_width 448 \
    --br_height 256 \
    --output_path "output_example_sr_540p_${TIMESTAMP}" \
    --sr_width 896 \
    --sr_height 512 \
    2>&1 | tee "log_example_sr_540p_${TIMESTAMP}.log"
inference/common/__init__.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .arch import get_arch_memory, is_hopper_arch
16
+ from .config import (
17
+ DataProxyConfig,
18
+ EngineConfig,
19
+ EvaluationConfig,
20
+ parse_config,
21
+ )
22
+ from .cpu_offload_wrapper import CPUOffloadWrapper
23
+ from .sequence_schema import Modality, VarlenHandler
24
+
25
+ __all__ = [
26
+ # arch
27
+ "get_arch_memory",
28
+ "is_hopper_arch",
29
+ # config
30
+ "EngineConfig",
31
+ "DataProxyConfig",
32
+ "EvaluationConfig",
33
+ "parse_config",
34
+ # cpu offload wrapper
35
+ "CPUOffloadWrapper",
36
+ # sequence schema
37
+ "Modality",
38
+ "VarlenHandler",
39
+ ]
inference/common/arch.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+
17
+
18
def is_hopper_arch() -> bool:
    """Return True when the current CUDA device is Hopper-class (SM 9.x).

    Guards on CUDA availability (mirroring ``get_arch_memory``) so that
    CPU-only hosts get a safe ``False`` instead of a RuntimeError from
    ``torch.cuda.get_device_capability``.
    """
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major == 9
20
+
21
+
22
def get_arch_memory(unit: str = "GB"):
    """Return the total memory of the current CUDA device.

    Args:
        unit: Target unit, one of ``"B"``, ``"KB"``, ``"MB"``, ``"GB"``
            (case-insensitive; binary multiples of 1024).

    Returns:
        Total device memory as a float in the requested unit, or ``0``
        when CUDA is unavailable.

    Raises:
        ValueError: If ``unit`` is not a recognized unit. Validation runs
            before the availability check so a bad argument fails fast
            even on CPU-only hosts.
    """
    divisors = {"B": 1, "KB": 1024, "MB": 1024**2, "GB": 1024**3}
    try:
        divisor = divisors[unit.upper()]
    except KeyError:
        raise ValueError(f"Invalid unit: {unit}") from None
    if not torch.cuda.is_available():
        return 0
    total_bytes = torch.cuda.get_device_properties(torch.cuda.current_device()).total_memory
    return float(total_bytes) / divisor
inference/common/config.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+ import copy
17
+ import json
18
+ import os
19
+ import sys
20
+ from pathlib import Path
21
+ from typing import Literal, Tuple
22
+
23
+ import torch
24
+ from inference.utils import env_is_true, print_rank_0
25
+ from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator, model_validator
26
+ from pydantic_settings import (
27
+ BaseSettings,
28
+ CliSettingsSource,
29
+ JsonConfigSettingsSource,
30
+ PydanticBaseSettingsSource,
31
+ SettingsConfigDict,
32
+ )
33
+
34
+
35
class EngineConfig(BaseModel):
    """Engine-level settings: seeding, checkpoint loading, and the
    torch.distributed / model-parallelism layout used at inference time."""

    # Basic settings
    seed: int = Field(1234, description="Random seed used for python, numpy, pytorch, and cuda.")
    load: str | None = Field(None, description="Directory containing a model checkpoint.")

    # Parallelism strategy
    distributed_backend: Literal["nccl", "gloo"] = Field("nccl", description="Distributed backend. Choices: ['nccl', 'gloo'].")
    distributed_timeout_minutes: int = Field(10, description="Timeout minutes for torch.distributed.")
    sequence_parallel: bool = Field(False, description="Enable sequence parallel optimization.")
    tp_size: int = Field(1, description="Degree of tensor model parallelism.")
    pp_size: int = Field(1, description="Degree of pipeline model parallelism.")
    cp_size: int = Field(1, description="Degree of context parallelism.")
    dp_size: int = Field(1, description="Degree of data parallelism.")
48
+
49
+
50
class ModelConfig(BaseModel):
    """Model configuration class defining various parameters for video generation model"""

    # arbitrary_types_allowed lets fields hold torch.dtype values;
    # protected_namespaces=() silences pydantic's "model_" prefix warning.
    model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())

    num_layers: int = Field(default=40, description="Number of Transformer layers")
    hidden_size: int = Field(default=5120, description="Hidden size of the Transformer model")
    head_dim: int = Field(default=128, description="Dimension per attention head")
    num_query_groups: int = Field(default=8, description="Number of query groups for grouped-query attention")
    video_in_channels: int = Field(default=48 * 4, description="Number of video input channels after patch embedding")
    audio_in_channels: int = Field(default=64, description="Number of audio input channels")
    text_in_channels: int = Field(default=3584, description="Number of text input channels")
    checkpoint_qk_layernorm_rope: bool = Field(default=False, description="Enable checkpointing for QK layernorm + RoPE")
    params_dtype: torch.dtype | str = Field(default=torch.float32, description="Parameter dtype")
    tread_config: dict = Field(
        default=dict(
            # Token routing is applied after the forward of layers 0-1
            # (start_layer_idx=2) and before the forward of layers 26-29
            # (end_layer_idx=25).
            selection_rate=0.5, start_layer_idx=2, end_layer_idx=25
        ),
        description="TReAD (Token Routing and Early Drop) configuration",
    )
    mm_layers: list[int] = Field(default=[0, 1, 2, 3, 36, 37, 38, 39], description="Indices of multimodal fusion layers")
    local_attn_layers: list[int] = Field(default=[], description="Indices of local attention layers")
    enable_attn_gating: bool = Field(default=True, description="Enable attention gating")
    activation_type: str = Field(default="swiglu7", description="Activation type")
    gelu7_layers: list[int] = Field(default=[0, 1, 2, 3], description="Indices of gelu7 layers")

    # Add computed fields
    num_heads_q: int = Field(default=0, description="Number of query heads (calculated from hidden_size // head_dim)")
    num_heads_kv: int = Field(default=0, description="Number of key-value heads (calculated from num_query_groups)")
    post_norm_layers: list[int] = Field(default=[], description="Indices of post norm layers")

    @field_serializer("params_dtype")
    def serialize_dtype(self, value: torch.dtype | str) -> str:
        # Serialize e.g. torch.bfloat16 as "torch.bfloat16" so configs round-trip to JSON.
        return str(value)

    @field_validator("params_dtype", mode="before")
    @classmethod
    def validate_dtype(cls, value) -> torch.dtype:
        """Coerce a dtype given as a string (e.g. "bfloat16" or "torch.bfloat16")
        back into a torch.dtype; pass real torch.dtype values through unchanged."""
        if isinstance(value, torch.dtype):
            return value
        if isinstance(value, str):
            if value == "torch.float32" or value == "float32":
                return torch.float32
            elif value == "torch.float16" or value == "float16":
                return torch.float16
            elif value == "torch.bfloat16" or value == "bfloat16":
                return torch.bfloat16
        raise ValueError(f"Unknown torch.dtype string: '{value}'")
98
+
99
+
100
class DataProxyConfig(BaseModel):
    """Configuration for the data proxy: patching granularity, RoPE handling,
    and per-modality positional offsets."""

    t_patch_size: int = Field(default=1, description="Patch size for time dimension")
    patch_size: int = Field(default=2, description="Patch size for spatial dimensions")
    frame_receptive_field: int = Field(default=11, description="Frame receptive field")
    spatial_rope_interpolation: Literal["inter", "extra"] = Field(
        default="extra", description="Spatial rope interpolation method."
    )
    ref_audio_offset: int = Field(default=1000, description="Offset for reference audio.")
    text_offset: int = Field(default=0, description="Offset for text.")
    coords_style: Literal["v1", "v2"] = Field(default="v2", description="Coords style.")
110
+
111
+
112
class EvaluationConfig(BaseModel):
    """Evaluation configuration class defining parameters for model evaluation and inference"""

    # protected_namespaces=() silences pydantic's "model_" prefix warning.
    model_config = ConfigDict(protected_namespaces=())

    data_proxy_config: DataProxyConfig = Field(default=DataProxyConfig(), description="Data proxy configuration")

    fps: int = Field(default=25, description="Frames per second for video generation")
    num_inference_steps: int = Field(default=32, description="Number of denoising steps during inference")
    video_txt_guidance_scale: float = Field(default=5.0, description="Video text guidance scale for text conditioning")
    audio_txt_guidance_scale: float = Field(default=5.0, description="Audio text guidance scale for text conditioning")
    txt_encoder_type: Literal["t5_gemma"] = Field(default="t5_gemma", description="Text encoder type.")
    t5_gemma_target_length: int = Field(default=640, description="Target length for T5-Gemma encoder.")
    support_ref_audio: bool = Field(default=True, description="Whether to support the ref_audio feature")
    shift: float = Field(default=5.0, description="Temporal shift parameter for video generation")
    exp_name: str = Field(default="exp_debug", description="Experiment name with evaluation suffix")
    audio_model_path: str = Field(default="", description="Path to the pretrained audio model")
    txt_model_path: str = Field(default="", description="Path to the pretrained txt model")
    vae_model_path: str = Field(default="", description="Path to the pretrained vae model")
    vae_stride: Tuple[int, int, int] = Field(default=(4, 16, 16), description="VAE stride in format (time, height, width)")
    z_dim: int = Field(default=48, description="Dimension of z space.")
    patch_size: Tuple[int, int, int] = Field(default=(1, 2, 2), description="Patch size in format (time, height, width)")
    cfg_number: int = Field(default=2, description="Classifier-free guidance number")
    sr_cfg_number: int = Field(default=2, description="SR Classifier-free guidance number")

    # flops recording
    enable_flops_recording: bool = Field(default=False, description="Whether to enable flops recording")

    # super resolution model configuration
    use_sr_model: bool = Field(default=False, description="Whether to use the super resolution model")
    sr_model_path: str = Field(default="", description="Path to the pretrained super resolution model")
    sr_num_inference_steps: int = Field(default=5, description="Number of denoising steps during super resolution inference")
    noise_value: int = Field(default=220, description="Noise value for the super resolution model")
    sr_video_txt_guidance_scale: float = Field(
        default=3.5, description="Super resolution video text guidance scale for text conditioning"
    )
    use_cfg_trick: bool = Field(default=True, description="Whether to use the cfg trick")
    cfg_trick_start_frame: int = Field(default=13, description="Start frame for the cfg trick")
    cfg_trick_value: float = Field(default=2.0, description="Value for the cfg trick")
    using_sde_flag: bool = Field(default=False, description="Whether to use the sde flag")
    sr_audio_noise_scale: float = Field(default=0.7, description="Noise scale for the super resolution audio")

    # turbo-vae config
    use_turbo_vae: bool = Field(default=True, description="Whether to use the turbo-vae")
    student_config_path: str = Field(default="", description="Path to the student config")
    student_ckpt_path: str = Field(default="", description="Path to the student checkpoint")
158
+
159
+
160
class MagiPipelineConfig(BaseSettings):
    """Top-level pipeline settings aggregating engine, model, and evaluation configs.

    Values are resolved from environment variables, CLI flags, and an optional
    ``--config-load-path`` JSON file; the precedence is defined by
    ``settings_customise_sources`` (earlier sources win).
    """

    engine_config: EngineConfig = Field(description="Engine configuration.", default_factory=EngineConfig)
    arch_config: ModelConfig = Field(default=ModelConfig(), description="Model configuration.")
    evaluation_config: EvaluationConfig = Field(default=EvaluationConfig(), description="Evaluation configuration.")
    sr_arch_config: ModelConfig = Field(default=ModelConfig(), description="Super resolution model configuration.")
    model_config = SettingsConfigDict(cli_parse_args=True, cli_ignore_unknown_args=True, cli_implicit_flags=True)

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        # Peek at argv only for --config-load-path; all other flags are left to
        # pydantic's own CLI source (parse_known_args ignores them here).
        parser = argparse.ArgumentParser(allow_abbrev=False)
        parser.add_argument("--config-load-path", type=str, default=None, help="Path to load the config.json from")
        args, _ = parser.parse_known_args()
        config_load_path = args.config_load_path
        # Source order defines precedence: env > CLI > JSON file > the rest.
        sources = [env_settings, CliSettingsSource(settings_cls, cli_parse_args=True, cli_ignore_unknown_args=True)]
        if config_load_path:
            sources.append(JsonConfigSettingsSource(settings_cls, json_file=config_load_path))

        sources.extend([init_settings, dotenv_settings, file_secret_settings])
        return tuple(sources)

    def save_to_json(self, json_path: str, indent: int = 4):
        """Write this config to *json_path*, creating parent dirs as needed."""
        path = Path(json_path)
        path.parent.mkdir(parents=True, exist_ok=True)
        # Uses __str__, so the saved file matches the logged representation.
        path.write_text(self.__str__(indent=indent))

    def __str__(self, indent: int = 4):
        data = self.model_dump(mode="json")
        formatted = json.dumps(data, indent=indent, ensure_ascii=False, sort_keys=False)
        class_name = self.__class__.__name__
        # NOTE(review): stripping the double quotes makes this output (and the
        # file written by save_to_json) NOT valid JSON -- confirm that nothing
        # tries to json.load the saved file.
        return f"{class_name}:\n{formatted}".replace('"', "")

    def __repr__(self, indent: int = 4):
        return self.__str__(indent=indent)

    @model_validator(mode="after")
    def validate_engine_config(self):
        """Derive dp_size from WORLD_SIZE and assert the parallel sizes factorize it."""
        world_size = int(os.getenv("WORLD_SIZE", "1"))
        self.engine_config.dp_size = world_size // (
            self.engine_config.tp_size * self.engine_config.pp_size * self.engine_config.cp_size
        )

        assert world_size % self.engine_config.tp_size == 0
        tp_pp_size = self.engine_config.tp_size * self.engine_config.pp_size
        assert world_size % tp_pp_size == 0
        tp_pp_cp_size = tp_pp_size * self.engine_config.cp_size
        assert world_size % tp_pp_cp_size == 0
        assert world_size == self.engine_config.dp_size * tp_pp_cp_size

        # Sequence parallelism only applies together with tensor parallelism.
        if self.engine_config.tp_size == 1:
            self.engine_config.sequence_parallel = False

        return self

    @model_validator(mode="after")
    def post_override_config(self):
        """Fill derived head counts and build the SR arch config from the base one."""
        self.arch_config.num_heads_q = self.arch_config.hidden_size // self.arch_config.head_dim
        self.arch_config.num_heads_kv = self.arch_config.num_query_groups

        # SR model starts as a deep copy of the base arch config.
        self.sr_arch_config = copy.deepcopy(self.arch_config)
        if env_is_true("SR2_1080"):
            # NOTE(review): this second deepcopy is redundant -- it duplicates
            # the unconditional copy just above.
            self.sr_arch_config = copy.deepcopy(self.arch_config)
            # fmt: off
            self.sr_arch_config.local_attn_layers = [
                0, 1, 2,
                4, 5, 6,
                8, 9, 10,
                12, 13, 14,
                16, 17, 18,
                20, 21, 22,
                24, 25, 26,
                28, 29, 30,
                32, 33, 34,
                35, 36, 37,
                38, 39,
            ]
            # fmt: on
            self.evaluation_config.sr_video_txt_guidance_scale = 3.5

        return self
246
+
247
+
248
def prevent_unsupported_list_syntax():
    """
    Check sys.argv before Pydantic parsing to prevent using unsupported list syntax.

    Raises:
        ValueError: if an option flag is followed by two or more space-separated
            values (e.g. ``--layers 1 2 3``), a list syntax that Pydantic's CLI
            parsing does not support.
    """
    args = sys.argv[1:]
    for i, arg in enumerate(args):
        # Only option flags can take list values; skipping non-dash tokens avoids
        # false positives on positional arguments and on the values themselves
        # (e.g. "prog a b c" must not raise).
        if not arg.startswith("-"):
            continue
        if i + 2 < len(args):
            value1, value2 = args[i + 1], args[i + 2]
            if not value1.startswith("-") and not value2.startswith("-"):
                error_msg = (
                    f"\n\nError: Detected list parameter '{arg}' using unsupported command line syntax.\n"
                    f"Error pattern: '{arg} {value1} {value2} ...'\n\n"
                    "Pydantic (or related libraries) do not support passing lists with space-separated multiple values.\n"
                    "Please use one of the following supported formats:\n\n"
                    f"1. JSON style: {arg} '[{value1},{value2},...]'\n"
                    f"2. Argparse style: {arg} {value1} {arg} {value2}\n"
                    f"3. Lazy style: {arg} {value1},{value2}\n"
                )
                raise ValueError(error_msg)
267
+
268
+
269
def parse_config(verbose: bool = False) -> MagiPipelineConfig:
    """Build the pipeline config, optionally persisting it and printing it.

    Args:
        verbose: When True, print the resolved config on rank 0.

    Returns:
        The fully resolved MagiPipelineConfig instance.
    """
    save_parser = argparse.ArgumentParser(description="Load and optionally save config", allow_abbrev=False)
    save_parser.add_argument("--config-save-path", type=str, default=None, help="Path to save the config.json to")
    known, _ = save_parser.parse_known_args()

    # Guard against "--flag a b c" list syntax before pydantic parses argv.
    prevent_unsupported_list_syntax()
    pipeline_config = MagiPipelineConfig()

    save_path = known.config_save_path
    if save_path is not None:
        pipeline_config.save_to_json(save_path)

    if verbose:
        print_rank_0(pipeline_config)

    return pipeline_config
inference/common/cpu_offload_wrapper.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, Callable, Dict, Tuple
16
+
17
+ import torch
18
+
19
+
20
class CPUOffloadWrapper:
    """Transparent proxy that can keep a model's weights on CPU and hop them to
    GPU only for the duration of each compute call.

    Attribute access and method calls are forwarded to the wrapped model;
    compute methods are wrapped so that, when ``is_cpu_offload`` and
    ``is_running_on_gpu`` are both set, the model moves to CUDA for the call
    and its original CPU tensors are re-attached afterwards. The wrapper
    itself is immutable (see ``__setattr__``).
    """

    def __init__(self, model: Any, is_cpu_offload: bool = False, is_running_on_gpu: bool = True):
        # object.__setattr__ is required because this class forbids normal
        # attribute assignment (see __setattr__ below).
        object.__setattr__(self, "model", model)
        object.__setattr__(self, "is_cpu_offload", is_cpu_offload)
        object.__setattr__(self, "is_running_on_gpu", is_running_on_gpu)

        cpu_device = torch.device("cpu")
        cuda_device = torch.device("cuda")
        object.__setattr__(self, "cpu_device", cpu_device)
        object.__setattr__(self, "cuda_device", cuda_device)

        # Initialize placement location
        if is_cpu_offload:
            self.model.to(cpu_device)
        else:
            self.model.to(cuda_device)

        # Whitelist non-compute methods that shouldn't trigger device hops (pass-through only; no device switch)
        object.__setattr__(
            self,
            "_non_compute_methods",
            {
                "to",
                "cpu",
                "cuda",
                "eval",
                "train",
                "state_dict",
                "load_state_dict",
                "parameters",
                "named_parameters",
                "buffers",
                "named_buffers",
                "modules",
                "named_modules",
                "children",
                "named_children",
                "register_forward_hook",
                "register_forward_pre_hook",
                "register_full_backward_hook",
                "zero_grad",
                "share_memory",
                "half",
                "float",
                "bfloat16",
            },
        )

    # Get current primary device (for external reads)
    @property
    def device(self) -> torch.device:
        # nn.Module: use the first parameter's device. Arbitrary objects: probe
        # attributes for a tensor or sub-module. Falls back to the CUDA device
        # when nothing device-bearing is found.
        if isinstance(self.model, torch.nn.Module):
            return next(self.model.parameters()).device
        else:
            for k, v in self.model.__dict__.items():
                if isinstance(v, torch.Tensor):
                    return v.device
                elif isinstance(v, torch.nn.Module):
                    return next(v.parameters()).device
            return self.cuda_device

    def _backup_cpu_state(self) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Dict[str, Any]]:
        """Snapshot references to the current parameter/buffer storages.

        Returns ``(param_backup, buffer_backup, other_tensor_backup)`` keyed by
        dotted names. Only references to the ``.data`` tensors are stored, so
        the originals can be re-attached after a temporary ``.to(cuda)`` move
        without copying back.
        """
        # Backup module parameters and buffers
        module_param_backup = {}
        module_buffer_backup = {}
        other_backup = {}

        def save_module_state(mod: torch.nn.Module, prefix: str):
            for name, param in mod.named_parameters():
                if param is not None:
                    full_key = prefix + name
                    module_param_backup[full_key] = param.data
            for name, buffer in mod.named_buffers():
                if buffer is not None:
                    full_key = prefix + name
                    module_buffer_backup[full_key] = buffer.data

        if isinstance(self.model, torch.nn.Module):
            save_module_state(self.model, "")
        else:
            # Non-Module wrapper objects: back up sub-modules and loose tensors
            # found directly on the instance __dict__.
            for name, attr_val in self.model.__dict__.items():
                if isinstance(attr_val, torch.nn.Module):
                    save_module_state(attr_val, name + ".")
                elif isinstance(attr_val, torch.Tensor):
                    other_backup[name] = attr_val

        return module_param_backup, module_buffer_backup, other_backup

    def _restore_cpu_state(self, backups: Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Dict[str, Any]]):
        """Re-attach the tensors captured by ``_backup_cpu_state``."""
        # Restore module parameters and buffers
        module_param_backup, module_buffer_backup, other_backup = backups

        def restore_module_state(mod: torch.nn.Module, prefix: str):
            for name, param in mod.named_parameters():
                full_key = prefix + name
                if full_key in module_param_backup:
                    param.data = module_param_backup[full_key]

            for name, buffer in mod.named_buffers():
                full_key = prefix + name
                if full_key in module_buffer_backup:
                    buffer.data = module_buffer_backup[full_key]

        if isinstance(self.model, torch.nn.Module):
            restore_module_state(self.model, "")
        else:
            for name, attr_val in self.model.__dict__.items():
                if isinstance(attr_val, torch.nn.Module):
                    restore_module_state(attr_val, name + ".")

        if not isinstance(self.model, torch.nn.Module):
            for name, val in other_backup.items():
                setattr(self.model, name, val)

    # Unified on/offload executor
    def _run_with_optional_offload(self, func: Callable[..., Any], *args, **kwargs):
        if self.is_cpu_offload and self.is_running_on_gpu:
            # Offload path: move model to GPU for the call, then re-attach the
            # saved CPU tensors so the GPU copies can be freed.
            # NOTE(review): tensor args are NOT moved to CUDA on this path --
            # presumably callers already pass CUDA tensors; confirm.
            backups = self._backup_cpu_state()
            self.model.to(self.cuda_device)
            try:
                return func(*args, **kwargs)
            finally:
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                self._restore_cpu_state(backups)
        else:
            # Make sure model and args are on the same device
            args = [
                arg.to(self.device) if isinstance(arg, torch.Tensor) and arg.device != self.device else arg for arg in args
            ]
            kwargs = {
                k: v.to(self.device) if isinstance(v, torch.Tensor) and v.device != self.device else v
                for k, v in kwargs.items()
            }
            return func(*args, **kwargs)

    # Direct call (equivalent to forward)
    def __call__(self, *args, **kwargs):
        return self._run_with_optional_offload(self.model.__call__, *args, **kwargs)

    # Explicit forward; some code calls model.forward(...)
    def forward(self, *args, **kwargs):
        return self._run_with_optional_offload(self.model.forward, *args, **kwargs)

    # Key: passthrough all attrs/methods. For callables, wrap with on/offload; for non-compute methods, pass-through only with no device switch.
    def __getattr__(self, name: str):
        # Only invoked when normal lookup fails, i.e. for names not set via
        # object.__setattr__ above -- so self.model here resolves directly.
        # Fetch attribute from the wrapped model first
        attr = getattr(self.model, name)

        # Wrap methods (except in whitelist)
        if callable(attr) and name not in self._non_compute_methods:

            def _wrapped(*args, **kwargs):
                return self._run_with_optional_offload(attr, *args, **kwargs)

            return _wrapped

        return attr

    def __dir__(self):
        # Merge the wrapper's own names with the wrapped model's for introspection.
        return sorted(set(list(super().__dir__()) + dir(self.model)))

    def __setattr__(self, name: str, value: Any):
        # All state is installed in __init__ via object.__setattr__; anything
        # else would silently shadow the wrapped model's attributes.
        raise AttributeError("CPUOffloadWrapper is immutable")

    def __repr__(self) -> str:
        return f"CPUOffloadWrapper(is_cpu_offload={self.is_cpu_offload}, is_running_on_gpu={self.is_running_on_gpu}, model={repr(self.model)})"
inference/common/sequence_schema.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from enum import IntEnum
17
+
18
+ import torch
19
+
20
+
21
class Modality(IntEnum):
    """Integer tags identifying the modality of a token segment.

    Values are fixed (VIDEO=0, AUDIO=1, TEXT=2) because they are used as
    integer codes, not just symbolic names.
    """

    VIDEO = 0
    AUDIO = 1
    TEXT = 2
25
+
26
+
27
@dataclass
class VarlenHandler:
    """Metadata bundle for variable-length (packed) attention.

    Mirrors the cu_seqlens/max_seqlen arguments of varlen attention kernels.
    NOTE(review): assumes cu_seqlens_* are 1-D cumulative-length integer
    tensors of size (num_sequences + 1) -- confirm against callers.
    """

    cu_seqlens_q: torch.Tensor  # cumulative sequence lengths for queries
    cu_seqlens_k: torch.Tensor  # cumulative sequence lengths for keys
    max_seqlen_q: int  # maximum query sequence length
    max_seqlen_k: int  # maximum key sequence length
33
+
inference/infra/__init__.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+
17
+ from inference.common import parse_config
18
+ from inference.infra.distributed import get_dp_rank, initialize_distributed
19
+ from inference.utils import print_rank_0, set_random_seed
20
+
21
+
22
def initialize_infra():
    """Initialize the inference infrastructure: distributed env, config, and RNG seed.

    Requires CUDA. Order matters: the distributed environment is set up first,
    then the config is parsed, and finally the random seed is derived from it.
    """
    assert torch.cuda.is_available(), "Infra requires CUDA environment."

    # Initialize distributed environment
    initialize_distributed()

    # Initialize config
    config = parse_config(verbose=True)

    # Initialize random seed
    # Offset by 10 * dp_rank so each data-parallel rank seeds differently.
    set_random_seed(config.engine_config.seed + 10 * get_dp_rank())

    print_rank_0("Infra successfully initialized")
35
+
36
+
37
+ __all__ = ["initialize_infra"]
inference/infra/checkpoint/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .load_model_checkpoint import load_model_checkpoint
16
+
17
+ __all__ = [
18
+ # checkpoint loader
19
+ "load_model_checkpoint",
20
+ ]
inference/infra/checkpoint/load_model_checkpoint.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import io
16
+ import json
17
+ import os
18
+ import subprocess
19
+ from concurrent.futures import ThreadPoolExecutor
20
+
21
+ from inference.common import EngineConfig
22
+ from inference.utils import print_rank_0
23
+ from safetensors.torch import load as load_from_bytes
24
+ from safetensors.torch import load_file
25
+ from tqdm.auto import tqdm
26
+
27
+
28
def _load_shard(shard_path, param_names, num_threads=None):
    """Load the requested parameters from a single safetensors shard.

    If a zstd-compressed sibling (``<shard_path>.zst``) exists, it is
    decompressed via the external ``zstd`` binary and loaded from memory;
    otherwise the plain shard file is loaded directly.

    Args:
        shard_path: Path to the ``.safetensors`` shard file.
        param_names: Names of the parameters to extract from this shard.
        num_threads: Optional decompression parallelism passed as ``zstd -T``.

    Returns:
        Dict mapping each requested parameter name to its tensor.

    Raises:
        RuntimeError: If zstd decompression exits with a non-zero status.
    """
    zstd_path = shard_path + ".zst"
    if os.path.exists(zstd_path):
        cmd = ["zstd", "-d"]
        if num_threads:
            cmd.extend(["-T", str(num_threads)])  # set parallelism

        process = subprocess.Popen(cmd + ["-c", zstd_path], stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=-1)
        # communicate() drains stdout and stderr concurrently, avoiding the
        # pipe-buffer deadlock that sequentially reading stdout-then-stderr can
        # hit, and reads stdout to EOF in one shot (the previous manual
        # read-loop after a full .read() was dead code).
        decompressed_data, stderr_data = process.communicate()

        if process.returncode != 0:
            raise RuntimeError(f"Decompression failed: {stderr_data.decode()}")

        # Feed the bytes straight to safetensors; no BytesIO round-trip needed.
        weights = load_from_bytes(decompressed_data)
    else:
        weights = load_file(shard_path)

    return {name: weights[name] for name in param_names}
56
+
57
+
58
def load_sharded_safetensors_parallel_with_progress(checkpoint_dir):
    """Load a (possibly sharded) safetensors checkpoint into one state dict.

    If ``model.safetensors.index.json`` is absent, the checkpoint is assumed to
    be a single ``model.safetensors`` file. Otherwise shards are loaded in
    parallel threads with a progress bar.

    Args:
        checkpoint_dir: Directory containing the checkpoint files.

    Returns:
        Dict mapping parameter names to tensors.
    """
    index_path = os.path.join(checkpoint_dir, "model.safetensors.index.json")
    if not os.path.exists(index_path):
        # Single-file checkpoint: no index, load the lone shard directly.
        return load_file(os.path.join(checkpoint_dir, "model.safetensors"))

    with open(index_path, "r") as f:
        index = json.load(f)

    # Group parameters by the shard file that stores them.
    shard_map = {}
    for param_name, shard_file in index["weight_map"].items():
        shard_path = os.path.join(checkpoint_dir, shard_file)
        shard_map.setdefault(shard_path, []).append(param_name)

    # Load shards in parallel with a progress bar.
    state_dict = {}
    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(_load_shard, shard_path, param_names) for shard_path, param_names in shard_map.items()]
        for future in tqdm(futures, desc="Loading shards", total=len(futures)):
            state_dict.update(future.result())

    return state_dict
89
+
90
+
91
def load_model_checkpoint(model, engine_config: EngineConfig):
    """Load safetensors weights from ``engine_config.load`` into *model*.

    Loading is non-strict: missing and unexpected keys are logged, not fatal.

    Returns:
        The same model instance, with weights loaded.
    """
    print_rank_0("Loading checkpoint with safetensors format from pretrained_folder")
    checkpoint_state = load_sharded_safetensors_parallel_with_progress(engine_config.load)

    missing_keys, unexpected_keys = model.load_state_dict(checkpoint_state, strict=False)

    print_rank_0(f"Load Weight Missing Keys: {missing_keys}")
    print_rank_0(f"Load Weight Unexpected Keys: {unexpected_keys}")
    print_rank_0("Load checkpoint successfully")
    return model
inference/infra/distributed/__init__.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .parallel_state import get_cp_group, get_cp_rank, get_cp_world_size, get_dp_rank, get_pp_rank, get_tp_rank
16
+ from .init_dist_env import initialize_distributed
17
+
18
+ __all__ = [
19
+ # distributed init
20
+ "initialize_distributed",
21
+ # parallel state
22
+ "get_cp_group",
23
+ "get_cp_world_size",
24
+ "get_tp_rank",
25
+ "get_pp_rank",
26
+ "get_dp_rank",
27
+ "get_cp_rank",
28
+ ]
inference/infra/distributed/init_dist_env.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ from datetime import timedelta
17
+
18
+ import torch
19
+
20
+ from inference.common import parse_config
21
+
22
+ from .parallel_state import initialize_model_parallel, model_parallel_is_initialized
23
+ from inference.utils import print_rank_0
24
+
25
+
26
def initialize_distributed():
    """Initialize torch.distributed and core model parallel."""
    config = parse_config()

    device_count = torch.cuda.device_count()
    if torch.distributed.is_initialized():
        # NOTE(review): print_rank_0 presumably already restricts output to
        # rank 0, which would make this explicit rank check redundant -- confirm.
        if torch.distributed.get_rank() == 0:
            print_rank_0("> torch distributed already initialized, skipping initialization ...")
    else:
        # Rank/world size come from the launcher's environment (torchrun-style).
        rank = int(os.getenv("RANK", "0"))
        world_size = int(os.getenv("WORLD_SIZE", "1"))
        if rank == 0:
            print_rank_0("> initializing torch distributed ...")
        # Manually set the device ids.
        if device_count > 0:
            # Map each global rank onto a local GPU round-robin BEFORE
            # init_process_group so the backend binds the right device.
            device = rank % device_count
            torch.cuda.set_device(device)
        # Call the init process
        torch.distributed.init_process_group(
            backend=config.engine_config.distributed_backend,
            world_size=world_size,
            rank=rank,
            timeout=timedelta(minutes=config.engine_config.distributed_timeout_minutes),
        )

    # Set the tp, pp and dp communicators.
    if device_count > 0:
        # Idempotent: skip if the model-parallel groups were already built.
        if model_parallel_is_initialized():
            return
        initialize_model_parallel(
            tp_size=config.engine_config.tp_size,
            pp_size=config.engine_config.pp_size,
            cp_size=config.engine_config.cp_size,
            nccl_communicator_config_path=None,
            distributed_timeout_minutes=config.engine_config.distributed_timeout_minutes,
            order="tp-cp-pp-dp",
        )
inference/infra/distributed/parallel_state.py ADDED
@@ -0,0 +1,659 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Model and data parallel groups."""
17
+
18
+ import warnings
19
+ from datetime import timedelta
20
+ from typing import List, Optional
21
+
22
+ import torch
23
+
24
# Module-level process-group state. These are presumably populated by this
# module's initialize_model_parallel() (imported by init_dist_env) and read
# through the accessor functions -- confirm against the rest of this file.

# Intra-layer model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Tensor parallel group information with context parallel combined.
_TENSOR_MODEL_PARALLEL_GROUP_WITH_CP = None
_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP = None
# Inter-layer model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP = None
# Model parallel group (both intra- and pipeline) that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
# tensor model parallel group and data parallel group combined
# used for fp8 and moe training
_TENSOR_AND_DATA_PARALLEL_GROUP = None

# A list of global ranks for each pipeline group to ease calculation of the source
# rank when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None

# A list of global ranks for each data parallel group to ease calculation of the source
# rank when broadcasting weights from src to all other data parallel ranks
_DATA_PARALLEL_GLOBAL_RANKS = None

# A list of global ranks for each tensor model parallel group to ease calculation of
# the first local rank in the tensor model parallel group
_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = None

# Context parallel group that the current rank belongs to
_CONTEXT_PARALLEL_GROUP = None
# A list of global ranks for each context parallel group to ease calculation of the
# destination rank when exchanging KV/dKV between context parallel_ranks
_CONTEXT_PARALLEL_GLOBAL_RANKS = None

_CONTEXT_PARALLEL_EXTRA_GROUP = None

# Data parallel group information with context parallel combined.
_DATA_PARALLEL_GROUP_WITH_CP = None
_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = None

# combined parallel group of TP, DP, and CP used for fp8
_TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = None
65
+
66
+
67
+ def _get_nccl_options(pg_name, nccl_comm_cfgs):
68
+ """Set the NCCL process group options.
69
+
70
+ Args:
71
+ pg_name (str): process group name
72
+ nccl_comm_cfgs (dict): nccl communicator configurations
73
+
74
+ When an option (e.g., max_ctas) is not found in the config, use the NCCL default setting.
75
+ """
76
+ if pg_name in nccl_comm_cfgs:
77
+ nccl_options = torch.distributed.ProcessGroupNCCL.Options()
78
+ nccl_options.config.cga_cluster_size = nccl_comm_cfgs[pg_name].get("cga_cluster_size", 4)
79
+ nccl_options.config.max_ctas = nccl_comm_cfgs[pg_name].get("max_ctas", 32)
80
+ nccl_options.config.min_ctas = nccl_comm_cfgs[pg_name].get("min_ctas", 1)
81
+ return nccl_options
82
+ else:
83
+ return None
84
+
85
+
86
def generate_masked_orthogonal_rank_groups(world_size: int, parallel_size: List[int], mask: List[bool]) -> List[List[int]]:
    r"""Generate orthogonal parallel groups based on the parallel size and mask.

    Arguments:
        world_size (int): world size

        parallel_size (List[int]):
            The parallel size of each orthogonal parallel type. For example, if
            tensor_parallel_size = 2, pipeline_model_parallel_group = 3, data_parallel_size = 4,
            and the parallel mapping order is tp-pp-dp, then the parallel_size = [2, 3, 4].

        mask (List[bool]):
            The mask controls which parallel methods the generated groups represent. If mask[i] is
            True, it means the generated group contains the i-th parallelism method. For example,
            if parallel_size = [tp_size, pp_size, dp_size], and mask = [True, False , True], then
            the generated group is the `tp-dp` group, if the mask = [False, True, False], then the
            generated group is the `pp` group.

    Algorithm:
        For orthogonal parallelism, such as tp/dp/pp/cp, the global_rank and
        local_rank satisfy the following equation:
            global_rank = tp_rank + dp_rank * tp_size + pp_rank * tp_size * dp_size (1)
                tp_rank \in [0, tp_size)
                dp_rank \in [0, dp_size)
                pp_rank \in [0, pp_size)

        If we want to get the `dp_group` (tp_size * pp_size groups of dp_size ranks each.
        For example, if the gpu size is 8 and order is 'tp-pp-dp', size is '2-2-2', and the
        dp_group here is [[0, 4], [1, 5], [2, 6], [3, 7]].)
        The tp_rank and pp_rank will be combined to form the `dp_group_index`.
            dp_group_index = tp_rank + pp_rank * tp_size (2)

        So, given that tp_rank and pp_rank satisfy equation (2), and dp_rank in
        range(0, dp_size), the ranks in dp_group[dp_group_index] satisfy the
        equation (1).

        This function solves this math problem.

    For example, if the parallel_size = [tp_size, dp_size, pp_size] = [2, 3, 4],
    and the mask = [False, True, False]. Then,
        dp_group_index(0) = tp_rank(0) + pp_rank(0) * 2
        dp_group_index(1) = tp_rank(1) + pp_rank(0) * 2
        ...
        dp_group_index(7) = tp_rank(1) + pp_rank(3) * 2

        dp_group[0] = 0 + range(0, 3) * 2 + 0 = [0, 2, 4]
        dp_group[1] = 1 + range(0, 3) * 2 + 0 = [1, 3, 5]
        ...
        dp_group[7] = 1 + range(0, 3) * 2 + 3 * 2 * 3 = [19, 21, 23]
    """

    def prefix_product(a: List[int], init=1) -> List[int]:
        # Exclusive prefix product, e.g. [2, 3, 4] -> [1, 2, 6, 24].
        # The result is one element longer than the input; element i is the
        # stride of dimension i in the flattened rank index.
        r = [init]
        for v in a:
            init = init * v
            r.append(init)
        return r

    def inner_product(a: List[int], b: List[int]) -> int:
        # Dot product of two equally sized integer lists.
        return sum([x * y for x, y in zip(a, b)])

    def decompose(index, shape, stride=None):
        """
        This function solves the math problem below:
            There is an equation:
                index = sum(idx[i] * stride[i])
            And given the value of index, stride.
            Return the idx.
        This function is used to get the tp/dp/pp rank
        from group_index and rank_in_group.
        """
        if stride is None:
            stride = prefix_product(shape)
        idx = [(index // d) % s for s, d in zip(shape, stride)]
        # stride is a prefix_product result. And the value of stride[-1]
        # is not used.
        assert (
            sum([x * y for x, y in zip(idx, stride[:-1])]) == index
        ), "idx {} with shape {} mismatch the return idx {}".format(index, shape, idx)
        return idx

    # Split the dimensions into those inside a group (masked) and those that
    # enumerate the groups (unmasked); keep the matching global strides.
    masked_shape = [s for s, m in zip(parallel_size, mask) if m]
    unmasked_shape = [s for s, m in zip(parallel_size, mask) if not m]

    global_stride = prefix_product(parallel_size)
    masked_stride = [d for d, m in zip(global_stride, mask) if m]
    unmasked_stride = [d for d, m in zip(global_stride, mask) if not m]

    group_size = prefix_product(masked_shape)[-1]
    num_of_group = world_size // group_size

    ranks = []
    for group_index in range(num_of_group):
        # Get the per-dimension indices within the unmasked dims for this group_index.
        decomposed_group_idx = decompose(group_index, unmasked_shape)
        rank = []
        for rank_in_group in range(group_size):
            # Get the per-dimension indices within the masked dims for rank_in_group,
            # then recombine both halves with their global strides (equation (1)).
            decomposed_rank_idx = decompose(rank_in_group, masked_shape)
            rank.append(
                inner_product(decomposed_rank_idx, masked_stride) + inner_product(decomposed_group_idx, unmasked_stride)
            )
        ranks.append(rank)
    return ranks
190
+
191
+
192
class RankGenerator(object):
    """Enumerates global-rank groups for each parallelism dimension.

    Configured with the tp/dp/pp/cp sizes and an order string such as
    ``"tp-cp-pp-dp"`` describing the rank layout (fastest-varying first).
    Dimensions of size 1 may be omitted from the order; they are appended
    automatically.
    """

    def __init__(self, tp: int, dp: int, pp: int, cp: int, order: str) -> None:
        self.tp = tp
        self.dp = dp
        self.pp = pp
        self.cp = cp
        self.world_size = tp * dp * pp * cp

        self.name_to_size = {"tp": self.tp, "pp": self.pp, "dp": self.dp, "cp": self.cp}
        order = order.lower()
        for name in self.name_to_size.keys():
            if name not in order:
                if self.name_to_size[name] != 1:
                    # A dimension larger than 1 must be placed explicitly by the caller.
                    raise RuntimeError(
                        f"The size of ({name}) is ({self.name_to_size[name]}), but you haven't specified the order ({order})."
                    )
                # Size-1 dimensions are harmless anywhere; append them so every
                # dimension is represented in the order string.
                order = order + "-" + name

        self.order = order
        self.ordered_size = [self.name_to_size[token] for token in order.split("-")]

    def get_mask(self, order: str, token: str):
        """Return a boolean mask over *order*'s dims selecting those named in *token*.

        Both arguments are hyphen-separated dimension names, e.g.
        ``get_mask("tp-cp-pp-dp", "tp-dp") -> [True, False, False, True]``.
        """
        dims_in_order = order.split("-")
        mask = [False] * len(dims_in_order)
        for dim_name in token.split("-"):
            mask[dims_in_order.index(dim_name)] = True
        return mask

    def get_ranks(self, token):
        """Get rank group by input token.

        Arguments:
            token (str):
                Specify the ranks type that want to get. If we want
                to obtain multiple parallel types, we can use a hyphen
                '-' to separate them. For example, if we want to obtain
                the TP_DP group, the token should be 'tp-dp'.
        """
        selected = self.get_mask(self.order, token)
        return generate_masked_orthogonal_rank_groups(self.world_size, self.ordered_size, selected)
234
+
235
+
236
def initialize_model_parallel(
    tp_size: int = 1,
    pp_size: int = 1,
    cp_size: int = 1,
    nccl_communicator_config_path: Optional[str] = None,
    distributed_timeout_minutes: int = 30,
    order: str = "tp-cp-pp-dp",
) -> None:
    """Initialize model data parallel groups.
    Borrow from: https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py

    Args:
        tp_size (int, default = 1):
            The number of GPUs to split individual tensors across.

        pp_size (int, default = 1):
            The number of tensor parallel GPU groups to split the
            Transformer layers across. For example, if tp_size is 4 and
            pp_size is 2, the model will be split into 2 groups of 4 GPUs.

        cp_size (int, default = 1):
            The number of tensor parallel GPU groups to split the
            network input sequence length across. Compute of attention
            module requires tokens of full sequence length, so GPUs
            in a context parallel group need to communicate with each
            other to exchange information of other sequence chunks.
            Each GPU and its counterparts in other tensor parallel
            groups compose a context parallel group.

            For example, assume we have 8 GPUs, if tensor model parallel
            size is 4 and context parallel size is 2, the network input
            will be split into two sequence chunks, which are processed
            by 2 different groups of 4 GPUs. One chunk is processed by
            GPU0-3, the other chunk is processed by GPU4-7. Four groups
            are built to do context parallel communications: [GPU0, GPU4],
            [GPU1, GPU5], [GPU2, GPU6], and [GPU3, GPU7].

            Context parallelism partitions sequence length, so it has no
            impact on weights, which means weights are duplicated among
            GPUs in a context parallel group. Hence, weight gradients
            all-reduce is required in backward. For simplicity, we piggyback
            GPUs of context parallelism on data parallel group for
            weight gradient all-reduce.

        nccl_communicator_config_path (str, default = None):
            Path to the yaml file of NCCL communicator configurations.
            `min_ctas`, `max_ctas`, and `cga_cluster_size` can be set
            for each communicator.

        distributed_timeout_minutes (int, default = 30): Timeout, in
            minutes, for operations executed against distributed
            process groups. See PyTorch documentation at
            https://pytorch.org/docs/stable/distributed.html for
            caveats.

        order (str, default = tp-cp-pp-dp):
            The rank initialization order of parallelism (fastest-varying
            dimension first).

    Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
    and 8 data-parallel groups as:
        8 data_parallel groups:
            [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
        8 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
        4 pipeline model-parallel groups:
            [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.

    NOTE: ``torch.distributed.new_group`` is a collective — every rank
    executes every ``new_group`` call below, and each rank keeps only the
    group(s) it belongs to.
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    if world_size % (tp_size * pp_size * cp_size) != 0:
        raise RuntimeError(
            f"world_size ({world_size}) is not divisible by tp_size "
            f"({tp_size}) x pp_size ({pp_size}) "
            f"x cp_size ({cp_size})"
        )

    nccl_comm_cfgs = {}
    if nccl_communicator_config_path is not None:
        try:
            import yaml
        except ImportError:
            raise RuntimeError("Cannot import `yaml`. Setting custom nccl communicator configs " "requires the yaml package.")

        with open(nccl_communicator_config_path, "r") as stream:
            nccl_comm_cfgs = yaml.safe_load(stream)

    # Data-parallel size is whatever remains after tp/pp/cp are carved out.
    dp_size: int = world_size // (tp_size * pp_size * cp_size)
    rank = torch.distributed.get_rank()
    rank_generator = RankGenerator(tp=tp_size, dp=dp_size, pp=pp_size, cp=cp_size, order=order)
    timeout = timedelta(minutes=distributed_timeout_minutes)

    # Build the data-parallel groups.
    global _DATA_PARALLEL_GROUP
    global _DATA_PARALLEL_GLOBAL_RANKS
    global _DATA_PARALLEL_GROUP_WITH_CP
    global _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP
    assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized"

    for ranks in rank_generator.get_ranks("dp"):
        group = torch.distributed.new_group(ranks, timeout=timeout, pg_options=_get_nccl_options("dp", nccl_comm_cfgs))
        if rank in ranks:
            _DATA_PARALLEL_GROUP = group
            _DATA_PARALLEL_GLOBAL_RANKS = ranks
    for ranks_with_cp in rank_generator.get_ranks("dp-cp"):
        group_with_cp = torch.distributed.new_group(
            ranks_with_cp, timeout=timeout, pg_options=_get_nccl_options("dp_cp", nccl_comm_cfgs)
        )
        if rank in ranks_with_cp:
            _DATA_PARALLEL_GROUP_WITH_CP = group_with_cp
            _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = ranks_with_cp

    # Build the context-parallel groups.
    global _CONTEXT_PARALLEL_GROUP
    global _CONTEXT_PARALLEL_GLOBAL_RANKS
    assert _CONTEXT_PARALLEL_GROUP is None, "context parallel group is already initialized"
    for ranks in rank_generator.get_ranks("cp"):
        group = torch.distributed.new_group(ranks, timeout=timeout, pg_options=_get_nccl_options("cp", nccl_comm_cfgs))
        if rank in ranks:
            _CONTEXT_PARALLEL_GROUP = group
            _CONTEXT_PARALLEL_GLOBAL_RANKS = ranks

    # Build the model-parallel groups (tp x pp combined).
    global _MODEL_PARALLEL_GROUP
    assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized"
    for ranks in rank_generator.get_ranks("tp-pp"):
        group = torch.distributed.new_group(ranks, timeout=timeout, pg_options=_get_nccl_options("mp", nccl_comm_cfgs))
        if rank in ranks:
            _MODEL_PARALLEL_GROUP = group

    # Build the tensor model-parallel groups.
    global _TENSOR_MODEL_PARALLEL_GROUP
    global _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS
    assert _TENSOR_MODEL_PARALLEL_GROUP is None, "tensor model parallel group is already initialized"
    for ranks in rank_generator.get_ranks("tp"):
        group = torch.distributed.new_group(ranks, timeout=timeout, pg_options=_get_nccl_options("tp", nccl_comm_cfgs))
        if rank in ranks:
            _TENSOR_MODEL_PARALLEL_GROUP = group
            _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = ranks

    # Build the tensor + context parallel groups.
    global _TENSOR_MODEL_PARALLEL_GROUP_WITH_CP
    global _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP
    assert (
        _TENSOR_MODEL_PARALLEL_GROUP_WITH_CP is None
    ), "tensor model parallel group with context parallel is already initialized"
    for ranks in rank_generator.get_ranks("tp-cp"):
        group = torch.distributed.new_group(ranks, timeout=timeout, pg_options=_get_nccl_options("tp_cp", nccl_comm_cfgs))
        if rank in ranks:
            _TENSOR_MODEL_PARALLEL_GROUP_WITH_CP = group
            _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP = ranks

    # Build the pipeline model-parallel groups
    global _PIPELINE_MODEL_PARALLEL_GROUP
    global _PIPELINE_GLOBAL_RANKS
    assert _PIPELINE_MODEL_PARALLEL_GROUP is None, "pipeline model parallel group is already initialized"
    for ranks in rank_generator.get_ranks("pp"):
        group = torch.distributed.new_group(ranks, timeout=timeout, pg_options=_get_nccl_options("pp", nccl_comm_cfgs))
        if rank in ranks:
            _PIPELINE_MODEL_PARALLEL_GROUP = group
            _PIPELINE_GLOBAL_RANKS = ranks

    # Build the tensor + data parallel groups (with and without cp), used for fp8/moe.
    global _TENSOR_AND_DATA_PARALLEL_GROUP
    global _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP
    assert _TENSOR_AND_DATA_PARALLEL_GROUP is None, "Tensor + data parallel group is already initialized"
    for ranks in rank_generator.get_ranks("tp-cp-dp"):
        group = torch.distributed.new_group(ranks, timeout=timeout, pg_options=_get_nccl_options("tp_cp_dp", nccl_comm_cfgs))
        if rank in ranks:
            _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = group
    for ranks in rank_generator.get_ranks("tp-dp"):
        group = torch.distributed.new_group(ranks, timeout=timeout, pg_options=_get_nccl_options("tp_dp", nccl_comm_cfgs))
        if rank in ranks:
            _TENSOR_AND_DATA_PARALLEL_GROUP = group
420
+
421
+
422
def is_initialized():
    """Useful for code segments that may be accessed with or without mpu initialization"""
    initialized = _DATA_PARALLEL_GROUP is not None
    return initialized
425
+
426
+
427
def is_unitialized() -> bool:
    """Check whether the parallel state is still uninitialized.

    Deprecated (note the historical misspelling, kept for backward
    compatibility). Use ``is_initialized`` instead.
    """
    warnings.warn("is_unitialized is deprecated, use is_initialized instead", DeprecationWarning)
    return not is_initialized()
435
+
436
+
437
def model_parallel_is_initialized():
    """Check if model and data parallel groups are initialized."""
    required_groups = (
        _TENSOR_MODEL_PARALLEL_GROUP,
        _PIPELINE_MODEL_PARALLEL_GROUP,
        _DATA_PARALLEL_GROUP,
    )
    return all(group is not None for group in required_groups)
442
+
443
+
444
def get_model_parallel_group():
    """Return the combined (tensor + pipeline) model-parallel group of this rank."""
    group = _MODEL_PARALLEL_GROUP
    assert group is not None, "model parallel group is not initialized"
    return group
448
+
449
+
450
def get_tp_group(check_initialized=True, with_context_parallel=False):
    """Return the tensor-model-parallel group of the calling rank.

    Args:
        check_initialized: also assert the plain TP group exists up front.
        with_context_parallel: return the combined TP x CP group instead of
            the plain TP group.
    """
    if check_initialized:
        assert _TENSOR_MODEL_PARALLEL_GROUP is not None, "tensor model parallel group is not initialized"
    if not with_context_parallel:
        assert _TENSOR_MODEL_PARALLEL_GROUP is not None, "tensor model parallel group is not initialized"
        return _TENSOR_MODEL_PARALLEL_GROUP
    assert (
        _TENSOR_MODEL_PARALLEL_GROUP_WITH_CP is not None
    ), "tensor model parallel group with context parallel combined is not initialized"
    return _TENSOR_MODEL_PARALLEL_GROUP_WITH_CP
462
+
463
+
464
def get_pp_group():
    """Return the pipeline-model-parallel group of the calling rank."""
    group = _PIPELINE_MODEL_PARALLEL_GROUP
    assert group is not None, "pipeline_model parallel group is not initialized"
    return group
468
+
469
+
470
def get_dp_group(with_context_parallel=False):
    """Return the data-parallel group of the calling rank.

    When ``with_context_parallel`` is True, the combined DP x CP group
    (used e.g. for weight-gradient all-reduce) is returned instead.
    """
    if not with_context_parallel:
        assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized"
        return _DATA_PARALLEL_GROUP
    assert (
        _DATA_PARALLEL_GROUP_WITH_CP is not None
    ), "data parallel group with context parallel combined is not initialized"
    return _DATA_PARALLEL_GROUP_WITH_CP
480
+
481
+
482
def get_cp_group(check_initialized=True):
    """Return the context-parallel group of the calling rank.

    With ``check_initialized=False`` the (possibly None) group is returned
    without asserting.
    """
    group = _CONTEXT_PARALLEL_GROUP
    if check_initialized:
        assert group is not None, "context parallel group is not initialized"
    return group
487
+
488
+
489
def get_cp_extra_group(check_initialized=True):
    """Return the extra context-parallel group (may be None when the check is disabled)."""
    extra_group = _CONTEXT_PARALLEL_EXTRA_GROUP
    if check_initialized:
        assert extra_group is not None, "context parallel extra group is not initialized"
    return extra_group
493
+
494
+
495
def get_tp_world_size(with_context_parallel=False):
    """Return world size for the tensor model parallel group (or TP x CP group)."""
    group = get_tp_group(with_context_parallel=with_context_parallel)
    return torch.distributed.get_world_size(group=group)
498
+
499
+
500
def get_pp_world_size():
    """Return world size for the pipeline model parallel group (number of stages)."""
    pp_group = get_pp_group()
    return torch.distributed.get_world_size(group=pp_group)
503
+
504
+
505
def get_tp_rank(with_context_parallel=False):
    """Return this rank's index within the tensor model parallel group (or TP x CP group)."""
    group = get_tp_group(with_context_parallel=with_context_parallel)
    return torch.distributed.get_rank(group=group)
508
+
509
+
510
def get_pp_rank():
    """Return this rank's pipeline stage index."""
    pp_group = get_pp_group()
    return torch.distributed.get_rank(group=pp_group)
513
+
514
+
515
def is_pipeline_first_stage():
    """Return True if in the first pipeline model-parallel stage, False otherwise."""
    stage = get_pp_rank()
    return stage == 0
518
+
519
+
520
def is_pipeline_last_stage():
    """Return True if in the last pipeline model-parallel stage, False otherwise."""
    stage = get_pp_rank()
    return stage == get_pp_world_size() - 1
523
+
524
+
525
def get_tensor_model_parallel_src_rank(with_context_parallel=False):
    """Return the global rank of the first local rank in the caller's tensor
    model parallel group (or the combined TP x CP group when requested)."""
    assert _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS is not None, "Tensor model parallel group is not initialized"
    if not with_context_parallel:
        return _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS[0]
    assert (
        _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP is not None
    ), "Tensor model parallel group with context parallel combined is not initialized"
    return _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP[0]
536
+
537
+
538
def get_tensor_model_parallel_ranks(with_context_parallel=False):
    """Return all global ranks of the caller's tensor model parallel group
    (or the combined TP x CP group when requested)."""
    if not with_context_parallel:
        assert _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS is not None, "Tensor model parallel group is not initialized"
        return _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS
    assert (
        _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP is not None
    ), "Tensor model parallel group with context parallel combined is not initialized"
    return _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP
548
+
549
+
550
def get_tensor_model_parallel_last_rank(with_context_parallel=False):
    """Calculate the global rank corresponding to the last local rank
    in the tensor model parallel group (or the combined TP x CP group
    when ``with_context_parallel`` is True)."""
    # The plain TP rank list must exist even when the CP-combined variant is used.
    assert _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS is not None, "Tensor model parallel group is not initialized"
    if with_context_parallel:
        assert (
            _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP is not None
        ), "Tensor model parallel group with context parallel combined is not initialized"
        return _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP[-1]
    else:
        return _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS[-1]
561
+
562
+
563
def get_pipeline_model_parallel_first_rank():
    """Return the global rank of pipeline stage 0 in the caller's pipeline group."""
    pipeline_ranks = _PIPELINE_GLOBAL_RANKS
    assert pipeline_ranks is not None, "Pipeline parallel group is not initialized"
    return pipeline_ranks[0]
568
+
569
+
570
def get_pipeline_model_parallel_last_rank():
    """Return the global rank of the final pipeline stage in the caller's pipeline group."""
    pipeline_ranks = _PIPELINE_GLOBAL_RANKS
    assert pipeline_ranks is not None, "Pipeline parallel group is not initialized"
    return pipeline_ranks[get_pp_world_size() - 1]
576
+
577
+
578
def get_pipeline_model_parallel_next_rank():
    """Return the global rank of the following pipeline stage (wraps past the last)."""
    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
    next_stage = (get_pp_rank() + 1) % get_pp_world_size()
    return _PIPELINE_GLOBAL_RANKS[next_stage]
584
+
585
+
586
def get_pipeline_model_parallel_prev_rank():
    """Return the global rank of the preceding pipeline stage (wraps before the first)."""
    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
    prev_stage = (get_pp_rank() - 1) % get_pp_world_size()
    return _PIPELINE_GLOBAL_RANKS[prev_stage]
592
+
593
+
594
def get_dp_world_size(with_context_parallel=False):
    """World size of the data-parallel group (or DP x CP group).

    Returns 0 when torch.distributed is unavailable or not initialized.
    """
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        return 0
    return torch.distributed.get_world_size(group=get_dp_group(with_context_parallel=with_context_parallel))
600
+
601
+
602
def get_dp_rank(with_context_parallel=False):
    """This rank's index within the data-parallel group (or DP x CP group).

    Returns 0 when torch.distributed is unavailable or not initialized.
    """
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        return 0
    return torch.distributed.get_rank(group=get_dp_group(with_context_parallel=with_context_parallel))
608
+
609
+
610
def get_cp_world_size():
    """World size of the context-parallel group.

    Returns 0 when torch.distributed is unavailable or not initialized.
    """
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        return 0
    return torch.distributed.get_world_size(group=get_cp_group())
616
+
617
+
618
def get_cp_rank():
    """This rank's index within the context-parallel group.

    Returns 0 when torch.distributed is unavailable or not initialized.
    """
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        return 0
    return torch.distributed.get_rank(group=get_cp_group())
624
+
625
+
626
def destroy_model_parallel():
    """Set the groups to none.

    Resets every module-level process-group and rank-list global so the
    parallel state can be initialized again.
    """
    global _MODEL_PARALLEL_GROUP, _TENSOR_MODEL_PARALLEL_GROUP, _TENSOR_MODEL_PARALLEL_GROUP_WITH_CP
    global _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP, _PIPELINE_MODEL_PARALLEL_GROUP, _DATA_PARALLEL_GROUP
    global _TENSOR_AND_DATA_PARALLEL_GROUP, _PIPELINE_GLOBAL_RANKS, _DATA_PARALLEL_GLOBAL_RANKS
    global _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS, _CONTEXT_PARALLEL_GROUP, _CONTEXT_PARALLEL_GLOBAL_RANKS
    global _CONTEXT_PARALLEL_EXTRA_GROUP, _DATA_PARALLEL_GROUP_WITH_CP, _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP
    global _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP

    _MODEL_PARALLEL_GROUP = None
    _TENSOR_MODEL_PARALLEL_GROUP = None
    _TENSOR_MODEL_PARALLEL_GROUP_WITH_CP = None
    _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS_WITH_CP = None
    _PIPELINE_MODEL_PARALLEL_GROUP = None
    _DATA_PARALLEL_GROUP = None
    _TENSOR_AND_DATA_PARALLEL_GROUP = None
    _PIPELINE_GLOBAL_RANKS = None
    _DATA_PARALLEL_GLOBAL_RANKS = None
    _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = None
    _CONTEXT_PARALLEL_GROUP = None
    _CONTEXT_PARALLEL_GLOBAL_RANKS = None
    _CONTEXT_PARALLEL_EXTRA_GROUP = None
    _DATA_PARALLEL_GROUP_WITH_CP = None
    _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = None
    _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = None
inference/infra/distributed/utils.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+
17
+ from .parallel_state import get_tp_rank, get_tp_world_size
18
+
19
+
20
def is_last_rank():
    """Return True iff this process is the highest-numbered global rank."""
    world_size = torch.distributed.get_world_size()
    return torch.distributed.get_rank() == world_size - 1
22
+
23
+
24
def is_last_tp_cp_rank():
    """Return True iff this rank is the last one within its combined TP x CP group."""
    tp_cp_rank = get_tp_rank(with_context_parallel=True)
    tp_cp_size = get_tp_world_size(with_context_parallel=True)
    return tp_cp_rank == tp_cp_size - 1
26
+
27
+
28
def get_world_size():
    """Return the global world size, or 1 when torch.distributed is unavailable or uninitialized."""
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
    return 1
34
+
35
+
36
def get_device(local_rank=None):
    """Return the ``torch.device`` matching the active torch.distributed backend.

    Args:
        local_rank: optional CUDA device index; only used with the NCCL backend.

    Returns:
        ``cuda`` / ``cuda:<local_rank>`` for the NCCL backend, ``cpu`` for gloo.

    Raises:
        RuntimeError: if the active backend is neither nccl nor gloo.
    """
    backend = torch.distributed.get_backend()
    if backend == "nccl":
        if local_rank is None:
            device = torch.device("cuda")
        else:
            device = torch.device(f"cuda:{local_rank}")
    elif backend == "gloo":
        device = torch.device("cpu")
    else:
        # Same exception type as before, but now with a diagnostic message
        # instead of a bare `raise RuntimeError`.
        raise RuntimeError(f"Unsupported torch.distributed backend: {backend!r}")
    return device
inference/infra/parallelism/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .ulysses_scheduler import ulysses_scheduler
16
+
17
+ __all__ = [
18
+ # context parallel
19
+ "ulysses_scheduler",
20
+ ]
inference/infra/parallelism/all_to_all_primitive.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import List, Tuple, Union
16
+
17
+ import torch
18
+ import torch.distributed as dist
19
+ from einops import rearrange
20
+
21
+ from inference.utils import divide
22
+
23
+
24
class FakeHandle:
    """Drop-in stand-in for a ``torch.distributed`` async work handle.

    Returned by the all-to-all helpers when no communication actually takes
    place (``group`` is None or has a single rank), so callers can
    unconditionally call ``wait()``.
    """

    def wait(self):
        """No-op: there is no pending communication to wait for."""
        return None
30
+
31
+
32
def scatter_head_gather_seqlen(
    tensor: torch.Tensor, split_sizes: List[int] = None, group: dist.ProcessGroup = None, async_op: bool = True
) -> Union[torch.Tensor, Tuple[torch.Tensor, Union[dist.Work, FakeHandle]]]:
    """
    All-to-all that scatters the head dimension and gathers the sequence dimension:
        input:  (seq_len, cp * hn, hd)
        output: (seq_len * cp, hn, hd)
    NOTE: seq_len of input maybe not equal across ranks, which depends on split_sizes[rank]

    Args:
        tensor: local activation of shape (seq_len, cp * hn, hd); must be contiguous.
        split_sizes: per-rank output sequence lengths; defaults to equal splits
            of ``tensor.shape[0]``.
        group: process group to communicate over; when None or of size 1 no
            communication happens and the input tensor is returned unchanged.
        async_op: if True return ``(output, handle)`` — call ``handle.wait()``
            before using ``output``; if False return just ``output``.
    """
    if group is None or dist.get_world_size(group) == 1:
        # Degenerate case: nothing to communicate. Mirror the return shape of the
        # communicating path — a (tensor, handle) pair only when async_op is set.
        # (Previously this always returned a tuple, which was inconsistent with
        # the bare-tensor return of the async_op=False path below.)
        return (tensor, FakeHandle()) if async_op else tensor
    group_world_size = dist.get_world_size(group)
    if split_sizes is None:
        split_sizes = [tensor.shape[0]] * group_world_size

    _, hn, _ = tensor.shape
    # When there are fewer heads than ranks (and hn divides the group size),
    # replicate heads along dim 1 so every rank receives at least one head.
    if group_world_size % hn == 0 and group_world_size != hn:
        tensor = torch.repeat_interleave(tensor, repeats=divide(group_world_size, hn), dim=1).contiguous()
    assert tensor.is_contiguous()
    input_split_sizes = [tensor.shape[0]] * group_world_size
    # Pull the cp factor of the head dim to the front so that each contiguous
    # chunk of the buffer is the slice destined for one peer rank.
    input_tensor = rearrange(tensor, "seq (cp hn) hd -> (cp seq) hn hd", cp=group_world_size).contiguous()
    output = torch.empty([sum(split_sizes), *input_tensor.shape[1:]], device=input_tensor.device, dtype=input_tensor.dtype)
    if async_op:
        handle = dist.all_to_all_single(
            output,
            input_tensor,
            output_split_sizes=split_sizes,
            input_split_sizes=input_split_sizes,
            group=group,
            async_op=True,
        )
        return output, handle
    dist.all_to_all_single(
        output,
        input_tensor,
        output_split_sizes=split_sizes,
        input_split_sizes=input_split_sizes,
        group=group,
        async_op=False,
    )
    return output
64
+
65
+
66
def scatter_seqlen_gather_head(
    tensor: torch.Tensor, split_sizes: List[int] = None, group: dist.ProcessGroup = None, async_op: bool = True
) -> Union[torch.Tensor, Tuple[torch.Tensor, Union[dist.Work, FakeHandle]]]:
    """
    Scatter seq_len and gather head_number, for example:
        input: (seq_len * cp, hn, hd)
        output: (seq_len, cp * hn, hd)
    NOTE: seq_len of output maybe not equal, which depends on split_sizes[rank]
    NOTE: rearrange the tensor after communication: (cp, seq, hn, hd) -> (seq, cp * hn, hd)

    Args:
        tensor: Local shard shaped (seq_len * cp, hn, hd).
        split_sizes: Per-rank sequence lengths; defaults to an even split of dim 0.
        group: Process group; ``None`` or world size 1 skips communication.
        async_op: When True return ``(output, handle)``; otherwise return ``output``.
    """
    if group is None or dist.get_world_size(group) == 1:
        # BUGFIX: `return tensor, FakeHandle() if async_op else tensor` parsed
        # as `return tensor, (FakeHandle() if async_op else tensor)`, so
        # async_op=False yielded the tuple (tensor, tensor). Parenthesize the
        # intended tuple so the conditional applies to the whole return value.
        return (tensor, FakeHandle()) if async_op else tensor
    group_world_size = dist.get_world_size(group)
    if split_sizes is None:
        assert (
            tensor.shape[0] % group_world_size == 0
        ), f"tensor.shape[0] {tensor.shape[0]} % group_world_size {group_world_size} != 0"
        split_sizes = [tensor.shape[0] // group_world_size] * group_world_size
    assert tensor.is_contiguous()
    assert tensor.dim() == 3, f"tensor must be 3D, but got {tensor.dim()}D"
    output = torch.empty(
        [group_world_size * split_sizes[dist.get_rank(group)], *tensor.shape[1:]], device=tensor.device, dtype=tensor.dtype
    )
    output_split_sizes = [split_sizes[dist.get_rank(group)]] * group_world_size
    if async_op:
        handle = dist.all_to_all_single(
            output, tensor, output_split_sizes=output_split_sizes, input_split_sizes=split_sizes, group=group, async_op=True
        )
        return output, handle
    else:
        dist.all_to_all_single(
            output, tensor, output_split_sizes=output_split_sizes, input_split_sizes=split_sizes, group=group, async_op=False
        )
        return output
100
+
101
+
102
def batch_scatter_head_gather_seqlen(
    inputs: List[torch.Tensor], split_sizes: List[int] = None, group: dist.ProcessGroup = None
) -> List[torch.Tensor]:
    """
    Batch scatter head_number and gather seq_len, for example:
        inputs[i] input: (seq_len_i, cp * hn_i, hd)
        outputs[i] output: (seq_len_i * cp, hn_i, hd)
    NOTE: seq_len of inputs maybe not equal across ranks, which depends on split_sizes[rank]
    NOTE: fuse along head dim before communication, and split back after

    Args:
        inputs: Tensors to exchange, each shaped (seq_len, cp * hn_i, hd).
        split_sizes: Per-rank sequence lengths; defaults to equal lengths.
        group: Process group; ``None`` or world size 1 returns ``inputs`` unchanged.
    """
    if group is None or dist.get_world_size(group) == 1:
        return inputs
    rank = dist.get_rank(group)
    group_world_size = dist.get_world_size(group)
    if split_sizes is None:
        split_sizes = [inputs[0].shape[0]] * group_world_size
    assert all(
        inp.shape[0] == split_sizes[rank] for inp in inputs
    ), f"inputs[0].shape[0] {inputs[0].shape[0]} != split_sizes[rank] {split_sizes[rank]}"
    assert all(inp.dim() == 3 for inp in inputs), f"inputs[0].dim() {inputs[0].dim()} != 3"

    # BUGFIX: build the head-replicated/permuted tensors in a fresh list
    # instead of assigning back into `inputs`, which mutated the caller's
    # list in place.
    permuted = []
    for inp in inputs:
        _, hn, _ = inp.shape
        if group_world_size % hn == 0 and group_world_size != hn:
            # Fewer heads than ranks: replicate heads so every rank receives one.
            inp = torch.repeat_interleave(inp, repeats=divide(group_world_size, hn), dim=1)
        permuted.append(rearrange(inp, "seq (cp hn) hd -> (cp seq) hn hd", cp=group_world_size).contiguous())

    # Fuse along the head dim so all tensors share a single all-to-all.
    head_split_number = [p.shape[1] for p in permuted]
    fused_input = torch.cat(permuted, dim=1).contiguous()
    input_split_sizes = [fused_input.shape[0] // group_world_size] * group_world_size

    fused_output = torch.empty([sum(split_sizes), *fused_input.shape[1:]], device=fused_input.device, dtype=fused_input.dtype)
    dist.all_to_all_single(
        fused_output,
        fused_input,
        output_split_sizes=split_sizes,
        input_split_sizes=input_split_sizes,
        group=group,
        async_op=False,
    )
    # BUGFIX: torch.split returns a tuple; convert so the declared List
    # return type actually holds.
    return list(torch.split(fused_output, head_split_number, dim=1))
inference/infra/parallelism/gather_scatter_primitive.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from functools import partial
16
+ from typing import List, Union
17
+
18
+ import torch
19
+ import torch.distributed as dist
20
+ from torch.utils._pytree import tree_map
21
+
22
+
23
class Metadata:
    """Compact descriptor of one tensor's dtype, element count, rank, and shape.

    Used by the gather primitives in this module to describe tensors
    exchanged between ranks without shipping the tensor data itself.
    """

    def __init__(self, dtype: torch.dtype, numel: int, ndim: int, shape: List[int]):
        self.dtype = dtype  # element dtype of the described tensor
        self.numel = numel  # total number of elements
        self.ndim = ndim  # number of dimensions (== len(shape))
        self.shape = shape  # per-dimension sizes

    def __repr__(self):
        return f"Metadata(dtype={self.dtype}, numel={self.numel}, ndim={self.ndim}, shape={self.shape})"
32
+
33
+
34
def _gather_metadata(tensor_list: List[torch.Tensor], group: dist.ProcessGroup) -> List[List[Metadata]]:
    """All-gather per-tensor metadata (dtype, numel, shape) from every rank in ``group``.

    Each tensor is encoded as an int32 record ``[dtype_code, numel, ndim, *shape]``;
    the records are concatenated per rank, exchanged with two all_gathers
    (lengths first, then the variable-length payloads), and decoded back into
    ``Metadata`` objects.

    Returns:
        One ``List[Metadata]`` per rank, in rank order within ``group``.
    """
    # NOTE: the original dead `dist.get_rank(group)` statement was removed --
    # its result was never used.
    world_size = dist.get_world_size(group)

    # Sanity-check that this process is bound to the expected CUDA device.
    local_rank = torch.distributed.get_rank() % torch.cuda.device_count()
    assert (
        local_rank == torch.cuda.current_device()
    ), f"local_rank {local_rank} != current_device {torch.cuda.current_device()}"
    device = tensor_list[0].device if len(tensor_list) > 0 else torch.device("cuda")

    # ========== Step 1: flatten local tensor list ==========

    # Metadata: [dtype_code, numel, ndim, *shape]
    local_metadata = []

    dtype_map = {torch.float32: 0, torch.float16: 1, torch.bfloat16: 2, torch.int32: 3, torch.int64: 4, torch.uint8: 5}
    reverse_dtype_map = {v: k for k, v in dtype_map.items()}

    for t in tensor_list:
        dtype_code = dtype_map[t.dtype]
        shape = list(t.shape)
        numel = t.numel()
        local_metadata.append(torch.tensor([dtype_code, numel, len(shape)] + shape, dtype=torch.int32, device=device))

    if local_metadata:
        local_metadata_tensor = torch.cat(local_metadata)
    else:
        # This rank contributes no tensors; send an empty payload (no padding).
        local_metadata_tensor = torch.empty(0, dtype=torch.int32, device=device)
    local_metadata_tensor = local_metadata_tensor.contiguous()
    local_metadata_len = torch.tensor([local_metadata_tensor.numel()], dtype=torch.int32, device=device)

    # ========== Step 2: all_gather metadata lengths ==========
    metadata_lens = [torch.empty_like(local_metadata_len) for _ in range(world_size)]
    dist.all_gather(metadata_lens, local_metadata_len, group)

    # ========== Step 3: all_gather variable-length metadata payloads ==========
    metadata_lists = [torch.empty(m.item(), dtype=torch.int32, device=device) for m in metadata_lens]
    dist.all_gather(metadata_lists, local_metadata_tensor, group)

    # ========== Step 4: decode metadata and reconstruct tensor list ==========
    result = []
    for metadata_list in metadata_lists:
        offset = 0
        # Fresh name per rank: avoids shadowing the `local_metadata` send
        # buffer built in Step 1.
        rank_metadata = []
        while offset < metadata_list.numel():
            dtype_code = metadata_list[offset].item()
            numel = metadata_list[offset + 1].item()
            ndim = metadata_list[offset + 2].item()
            shape = metadata_list[offset + 3 : offset + 3 + ndim].tolist()
            offset += 3 + ndim

            rank_metadata.append(Metadata(reverse_dtype_map[dtype_code], numel, ndim, shape))
        result.append(rank_metadata)

    return result
89
+
90
+
91
def _get_dtype_and_assert_consistency(metadata_lists: List[List[Metadata]]):
    """Return the single dtype shared by all gathered tensors, asserting uniqueness."""
    dtype_set = {meta.dtype for metadata_list in metadata_lists for meta in metadata_list}
    assert len(dtype_set) == 1, f"Metadata lists are not consistent: {dtype_set}"
    return dtype_set.pop()
98
+
99
+
100
def _get_numel_for_each_rank(metadata_lists: List[List[Metadata]]) -> List[int]:
    """Total element count contributed by each rank (one sum per rank, in rank order)."""
    totals = []
    for metadata_list in metadata_lists:
        totals.append(sum(meta.numel for meta in metadata_list))
    return totals
102
+
103
+
104
def gather_arbitrary_tensor_list(tensor_list: List[torch.Tensor], group: dist.ProcessGroup) -> List[torch.Tensor]:
    """
    Magic gather primitive. Provide the following features:
    1. Support tensor list with different length for each rank.
    2. Support arbitrary Tensor, which means the Tensor can have different shapes but same dtype.
    3. Support empty tensor_list in some ranks without padding.

    Args:
        tensor_list: A list of tensors to gather.
        group: The process group to use.

    Returns:
        A flat list containing every rank's tensors, in rank order.
    """
    # NOTE: the original dead `dist.get_rank(group)` statement was removed --
    # its result was never used.
    world_size = dist.get_world_size(group)

    # Sanity-check that this process is bound to the expected CUDA device.
    local_rank = torch.distributed.get_rank() % torch.cuda.device_count()
    assert (
        local_rank == torch.cuda.current_device()
    ), f"local_rank {local_rank} != current_device {torch.cuda.current_device()}"
    device = tensor_list[0].device if len(tensor_list) > 0 else torch.device("cuda")

    # Step 1: Gather metadata
    metadata_lists = _gather_metadata(tensor_list, group)
    tensor_dtype = _get_dtype_and_assert_consistency(metadata_lists)

    # Step 2: Flatten local tensors into a single 1D buffer
    if tensor_list:
        flat_tensor = torch.cat([t.flatten() for t in tensor_list], dim=0).contiguous()
    else:
        flat_tensor = torch.empty(0, dtype=tensor_dtype, device=device)  # dummy, will be ignored

    # Step 3: Gather lengths from metadata
    all_numels_int = _get_numel_for_each_rank(metadata_lists)

    # Step 4: Allocate buffers and gather flat tensor data
    output_flat_tensors = [torch.empty(numel, dtype=tensor_dtype, device=device) for numel in all_numels_int]
    dist.all_gather(output_flat_tensors, flat_tensor, group)

    # Step 5: Reconstruct individual tensors using metadata
    gathered_tensors = []
    for rank in range(world_size):
        flat = output_flat_tensors[rank]
        offset = 0
        # BUGFIX: do not skip ranks whose flat buffer has zero elements -- a
        # rank may legitimately contribute zero-element tensors, which the old
        # `if flat.numel() == 0: continue` guard silently dropped. When the
        # rank's metadata list is empty this loop is a no-op anyway.
        for meta in metadata_lists[rank]:
            t = flat[offset : offset + meta.numel].view(meta.shape).to(meta.dtype)
            offset += meta.numel
            gathered_tensors.append(t)

    return gathered_tensors
162
+
163
+
164
def _scatter_to_context_parallel_region(input: torch.Tensor, split_sizes: List[int], group: dist.ProcessGroup = None):
    """Keep only this rank's dim-0 slice of ``input``, as given by ``split_sizes``."""
    rank = dist.get_rank(group)
    # This rank's slice starts after the shards owned by all lower ranks.
    start = sum(split_sizes[:rank])
    end = start + split_sizes[rank]
    return input[start:end].contiguous()
172
+
173
+
174
+ def scatter_to_context_parallel_region(
175
+ inputs: Union[torch.Tensor, List[torch.Tensor]], split_sizes: List[int] = None, group: dist.ProcessGroup = None
176
+ ):
177
+ """Split the tensor along its first dimension and keep the
178
+ corresponding slice."""
179
+ if group is None or torch.distributed.get_world_size(group) == 1:
180
+ return inputs
181
+
182
+ if split_sizes is None:
183
+ assert (
184
+ inputs.shape[0] % dist.get_world_size(group) == 0
185
+ ), f"inputs.shape[0] {inputs.shape[0]} % dist.get_world_size(group) {dist.get_world_size(group)} != 0"
186
+ split_sizes = [inputs.shape[0] // dist.get_world_size(group)] * dist.get_world_size(group)
187
+
188
+ partial_func = partial(_scatter_to_context_parallel_region, split_sizes=split_sizes, group=group)
189
+ return tree_map(partial_func, inputs)
190
+
191
+
192
def _gather_from_context_parallel_region(
    input: Union[torch.Tensor, List[torch.Tensor]], split_sizes: List[int], group: dist.ProcessGroup = None
):
    """All-gather dim-0 shards of ``input`` from every rank and return the full tensor.

    ``split_sizes[r]`` is the dim-0 length contributed by rank ``r``.
    """
    input = input.contiguous()
    dim_size = list(input.size())
    dim_size[0] = sum(split_sizes)

    # Allocate the full result once and all_gather directly into contiguous
    # per-rank views of it; the views share storage with `output`, so the
    # buffer is fully populated by the collective itself.
    output = torch.empty(dim_size, dtype=input.dtype, device=input.device)
    outputs = list(torch.split(output, split_sizes, dim=0))
    torch.distributed.all_gather(outputs, input, group=group)

    # PERF: the previous `torch.concat(outputs)` re-copied data the all_gather
    # above had already written in place; return the filled buffer directly.
    return output
205
+
206
+
207
+ def gather_from_context_parallel_region(
208
+ inputs: Union[torch.Tensor, List[torch.Tensor]], split_sizes: List[int] = None, group: dist.ProcessGroup = None
209
+ ):
210
+ """Gather tensors and concatinate along the first dimension."""
211
+ if group is None or torch.distributed.get_world_size(group) == 1:
212
+ return inputs
213
+
214
+ if split_sizes is None:
215
+ split_sizes = [inputs.shape[0] * dist.get_world_size(group)]
216
+ partial_func = partial(_gather_from_context_parallel_region, split_sizes=split_sizes, group=group)
217
+ return tree_map(partial_func, inputs)
inference/infra/parallelism/ulysses_scheduler.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Generic, List, Optional, TypeVar
16
+
17
+ import torch
18
+ from torch.utils._pytree import tree_map
19
+
20
+ from inference.infra.distributed import get_cp_group, get_cp_world_size
21
+
22
+ from .gather_scatter_primitive import gather_from_context_parallel_region, scatter_to_context_parallel_region
23
+
24
T = TypeVar("T")


class UlyssesScheduler(Generic[T]):
    """Ulysses-style context-parallel dispatcher.

    Splits tensors along the sequence (first) dimension across the context
    parallel group on entry, and gathers/concatenates them back on exit.
    Works on arbitrary nested pytrees of tensors; all tensors dispatched in
    one round must share the same sequence length, and the computed split
    sizes are checked for consistency across calls within a round.
    """

    def __init__(self):
        """Initialize with no active dispatch round."""
        # Per-rank sequence split sizes of the in-flight round; None when idle.
        self._cp_split_sizes: Optional[List[int]] = None

    @property
    def cp_split_sizes(self):
        """Split sizes of the active dispatch round (None when idle)."""
        return self._cp_split_sizes

    def _dispatch(self, x: torch.Tensor) -> torch.Tensor:
        """Split one tensor along dim 0 across the CP group, as evenly as possible.

        The first ``seq_len % world_size`` ranks receive one extra token each.
        Raises AssertionError if the computed sizes differ from an earlier
        call in the same round (inconsistent sequence lengths).
        """
        world = get_cp_world_size()
        base, extra = divmod(x.shape[0], world)
        cp_split_sizes = [base + 1] * extra + [base] * (world - extra)
        if self._cp_split_sizes is not None:
            assert (
                self._cp_split_sizes == cp_split_sizes
            ), f"cp_split_sizes changed from {self._cp_split_sizes} to {cp_split_sizes}"
        self._cp_split_sizes = cp_split_sizes
        return scatter_to_context_parallel_region(x, cp_split_sizes, group=get_cp_group())

    def _undispatch(self, x: torch.Tensor) -> torch.Tensor:
        """Gather one tensor's shards from the CP group back into the full sequence."""
        return gather_from_context_parallel_region(x, self._cp_split_sizes, group=get_cp_group())

    def dispatch(self, tensors: T) -> T:
        """Dispatch every tensor leaf of a nested structure to the CP region.

        The structure of ``tensors`` is preserved; only tensor leaves are split.
        """
        return tree_map(self._dispatch, tensors)

    def undispatch(self, tensors: T) -> T:
        """Gather every tensor leaf back from the CP region and end the round.

        Resets the recorded split sizes so the next dispatch starts fresh.
        """
        gathered = tree_map(self._undispatch, tensors)
        self._cp_split_sizes = None
        return gathered


_ULYSSES_SCHEDULER = UlyssesScheduler()


def ulysses_scheduler() -> UlyssesScheduler:
    """Return the process-wide UlyssesScheduler singleton."""
    assert _ULYSSES_SCHEDULER is not None, "ulysses scheduler is not initialized"
    return _ULYSSES_SCHEDULER
inference/model/dit/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .dit_model import get_dit
16
+ from .dit_module import DiTModel
17
+
18
+ __all__ = ["DiTModel", "get_dit"]
inference/model/dit/dit_model.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import gc
16
+
17
+ import torch
18
+ from inference.infra.checkpoint import load_model_checkpoint
19
+ from inference.infra.distributed import get_cp_rank, get_pp_rank, get_tp_rank
20
+ from inference.utils import print_mem_info_rank_0, print_model_size, print_rank_0
21
+
22
+ from .dit_module import DiTModel
23
+
24
+
25
def get_dit(model_config, engine_config):
    """Build the DiT model, load its checkpoint, and prepare it for inference.

    Constructs the model on the host, loads weights per ``engine_config``,
    moves it to the current CUDA device in eval mode, and reclaims any
    memory left over from checkpoint loading.
    """
    model = DiTModel(model_config=model_config)

    print_rank_0("Build dit model successfully")
    print_rank_0(model)
    rank_prefix = f"(tp, cp, pp) rank ({get_tp_rank()}, {get_cp_rank()}, {get_pp_rank()}): "
    print_model_size(model, prefix=rank_prefix, print_func=print_rank_0)

    model = load_model_checkpoint(model, engine_config)
    model.cuda(torch.cuda.current_device())
    model.eval()
    print_mem_info_rank_0("Load model successfully")

    # Reclaim host and device memory released during checkpoint loading.
    gc.collect()
    torch.cuda.empty_cache()
    return model
inference/model/dit/dit_module.py ADDED
@@ -0,0 +1,950 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import importlib
16
+ from dataclasses import dataclass
17
+ from enum import Enum
18
+ from typing import Any, Callable, List, Optional, Tuple
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ from einops import rearrange, repeat
23
+ from inference.common import Modality, VarlenHandler, is_hopper_arch
24
+ from inference.infra.parallelism import ulysses_scheduler
25
+ from magi_compiler import magi_compile
26
+ from magi_compiler.api import magi_register_custom_op
27
+ from magi_compiler.config import CompileConfig
28
+ from torch import Tensor
29
+ from torch.nn import Parameter
30
+
31
+
32
@dataclass
class FFAHandler:
    """Precomputed argument bundle for a variable-length attention kernel.

    NOTE(review): field semantics are inferred from names and typical varlen
    attention APIs -- confirm against the kernel that consumes this handler.
    """

    # Per-segment query index ranges (presumably [start, end) pairs -- verify).
    q_ranges: torch.Tensor
    # Per-segment key index ranges.
    k_ranges: torch.Tensor
    # Maximum query segment length across the batch.
    max_seqlen_q: int
    # Maximum key segment length across the batch.
    max_seqlen_k: int
    # Per-segment attention-type selector (e.g. causal vs. full -- confirm).
    attn_type_map: torch.Tensor
    # Scale applied inside the attention softmax.
    softmax_scale: float
40
+
41
+
42
# Define the MLP activation type
class MLPActivationType(Enum):
    """Closed set of activation functions supported by the MLP block."""

    SWIGLU7 = "swiglu7"
    GELU7 = "gelu7"


def swiglu7(x, alpha: float = 1.702, limit: float = 7.0, out_dtype: Optional[torch.dtype] = None):
    """Clamped SwiGLU over interleaved (gate, linear) channel pairs.

    Even channels are the gate, odd channels the linear half; computation
    runs in float32 and is cast back to ``out_dtype`` (input dtype by default).
    """
    if out_dtype is None:
        out_dtype = x.dtype
    xf = x.to(torch.float32)
    gate, linear = xf[..., ::2], xf[..., 1::2]
    # Clamp: gate from above only, linear symmetrically.
    gate = gate.clamp(max=limit)
    linear = linear.clamp(min=-limit, max=limit)
    gated = gate * torch.sigmoid(alpha * gate)
    # Note we add an extra bias of 1 to the linear layer (from GPT-OSS)
    return (gated * (linear + 1)).to(out_dtype)


def gelu7(x, alpha: float = 1.702, limit: float = 7.0, out_dtype: Optional[torch.dtype] = None):
    """Clamped sigmoid-approximated GELU, computed in float32.

    Input is clamped from above at ``limit`` before the x * sigmoid(alpha*x)
    gating; result is cast back to ``out_dtype`` (input dtype by default).
    """
    if out_dtype is None:
        out_dtype = x.dtype
    xf = x.to(torch.float32).clamp(max=limit)
    return (xf * torch.sigmoid(alpha * xf)).to(out_dtype)


def create_activation_func(activation_type: MLPActivationType) -> Callable:
    """Map an ``MLPActivationType`` to its callable implementation."""
    if activation_type is MLPActivationType.SWIGLU7:
        return swiglu7
    if activation_type is MLPActivationType.GELU7:
        return gelu7
    raise ValueError(f"Unknown activation type: {activation_type}")
81
+
82
+
83
class ModalityDispatcher:
    """Groups token rows by modality so each modality is one contiguous chunk.

    On construction the row-to-modality mapping is sorted once. ``permute``
    moves rows into modality-grouped order, ``dispatch`` splits them into one
    tensor per modality, ``undispatch`` concatenates them back, and
    ``inv_permute`` restores the original row order.
    """

    permuted_modality_mapping: torch.Tensor
    group_size: torch.Tensor
    group_size_cpu: list[int]
    num_modalities: int

    def __init__(self, modality_mapping: torch.Tensor, num_modalities: int):
        """Precompute permutations and per-modality group sizes (runs once)."""
        self.modality_mapping = modality_mapping
        self.num_modalities = num_modalities

        self.permuted_modality_mapping = self._precompute_permute_mapping(modality_mapping)

        # Row count of each modality group, on-device and as a plain list
        # (the list form is needed for torch.split in dispatch()).
        self.group_size = torch.bincount(self.permuted_modality_mapping, minlength=num_modalities).to(torch.int32)
        self.group_size_cpu: list[int] = [int(v) for v in self.group_size.to("cpu").tolist()]

    def _precompute_permute_mapping(self, modality_mapping):
        """Compute forward/inverse permutations; returns the sorted modality ids."""
        # argsort is O(N log N); the inverse permutation is argsort of the forward one.
        self.permute_mapping = torch.argsort(modality_mapping)
        self.inv_permute_mapping = torch.argsort(self.permute_mapping)
        return modality_mapping[self.permute_mapping]

    def dispatch(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Split an already-permuted tensor into one chunk per modality (dim 0)."""
        return list(torch.split(x, self.group_size_cpu, dim=0))

    def undispatch(self, *processed_groups: list[torch.Tensor]) -> torch.Tensor:
        """Concatenate per-modality chunks back into one grouped tensor (dim 0)."""
        return torch.cat(processed_groups, dim=0)

    @staticmethod
    def permute(x: torch.Tensor, permute_mapping: torch.Tensor) -> torch.Tensor:
        """Reorder rows into modality-grouped order."""
        return x[permute_mapping]

    @staticmethod
    def inv_permute(x: torch.Tensor, inv_permute_mapping: torch.Tensor) -> torch.Tensor:
        """Restore rows to their original (pre-permute) order."""
        return x[inv_permute_mapping]
130
+
131
+
132
def freq_bands(
    num_bands: int, temperature: float = 10000.0, step: int = 2, device: Optional[torch.device] = None
) -> torch.Tensor:
    """Inverse-frequency bands: 1 / temperature**(k / num_bands) for k = 0, step, 2*step, ..."""
    exponents = torch.arange(0, num_bands, step, dtype=torch.int64, device=device).to(torch.float32) / num_bands
    return torch.reciprocal(temperature**exponents)
138
+
139
+
140
def rotate_half(x, interleaved=False):
    """Rotate channel pairs (x1, x2) -> (-x2, x1).

    Pairs are either the two contiguous halves of the last dim
    (``interleaved=False``) or even/odd interleaved lanes (``interleaved=True``).
    """
    if interleaved:
        even, odd = x[..., ::2], x[..., 1::2]
        # Stack pairwise then flatten: equivalent to einops
        # rearrange(stack((-odd, even), -1), "... d two -> ... (d two)").
        return torch.stack((-odd, even), dim=-1).flatten(start_dim=-2)
    first, second = x.chunk(2, dim=-1)
    return torch.cat((-second, first), dim=-1)


def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]

    def _expand(t):
        # Duplicate each frequency to cover both lanes of a rotary pair, then
        # add a broadcast axis for heads. Torch equivalent of einops
        # "... d -> ... 1 (2 d)" (halves) / "... d -> ... 1 (d 2)" (interleaved).
        doubled = torch.repeat_interleave(t, 2, dim=-1) if interleaved else torch.cat((t, t), dim=-1)
        return doubled.unsqueeze(-2)

    cos_e, sin_e = _expand(cos), _expand(sin)
    rotary, passthrough = x[..., :ro_dim], x[..., ro_dim:]
    rotated = rotary * cos_e + rotate_half(rotary, interleaved) * sin_e
    return torch.cat([rotated, passthrough], dim=-1)
159
+
160
+
161
class ElementWiseFourierEmbed(nn.Module):
    """Element-wise Fourier embedding over per-token (t, h, w) coordinates,
    rescaled to a reference resolution, used as the RoPE table."""

    def __init__(
        self,
        dim: int,
        max_res: int = 224,
        temperature: float = 10000.0,
        in_pixels: bool = True,
        linear_bands: bool = False,
        learnable: bool = False,
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ):
        """
        Args:
            dim: Output feature dimension, total channels, must be divisible by 6
            max_res: Max pixel-frequency resolution for pixel-domain bands
            temperature: Temperature in inverse-frequency mode
            in_pixels: True -> pixel-frequency bands, False -> inverse-frequency bands
            linear_bands: Whether pixel-frequency bands are linearly spaced
            learnable: Whether frequency bands are trainable

        NOTE(review): get_default_bands() allocates ``dim // 8`` bands, so
        forward() actually emits 6 * (dim // 8) channels — confirm the
        "divisible by 6" wording and the intended channel count against callers.
        """
        super().__init__()
        self.dim = dim
        self.in_pixels = in_pixels
        self.learnable = learnable
        self.temperature = temperature
        self.max_res = max_res
        self.linear_bands = linear_bands
        self.device = device
        self.dtype = dtype
        # Make frequency bands trainable or register as buffer
        bands = self.get_default_bands()
        if self.learnable:
            self.bands = nn.Parameter(bands)
        else:
            self.register_buffer("bands", bands)

    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        """
        Args:
            coords: [L,9], column order (time, row, col, T, H, W, ref_T, ref_H, ref_W)
        Returns:
            emb: [L, 6 * (dim // 8)] element-wise Fourier embedding,
                 laid out as [sin(t,h,w bands) | cos(t,h,w bands)]
        """
        # Use slicing instead of unbind + stack to reduce intermediates
        coords_xyz = coords[:, :3]  # [L,3] -> (t, h, w)
        sizes = coords[:, 3:6]  # [L,3] -> (T, H, W)
        refs = coords[:, 6:9]  # [L,3] -> (ref_T, ref_H, ref_W)

        # Compute scale factors mapping each axis onto its reference resolution
        scales = (refs - 1) / (sizes - 1)  # [L,3]

        # NOTE: if both ref and size are 1, scale is fixed to 1; otherwise invalid
        scales[(refs == 1) & (sizes == 1)] = 1
        assert not scales.isnan().any(), "scales has nan"
        assert not scales.isinf().any(), "scales has inf"

        # Center alignment: apply to h,w only (not time)
        centers = (sizes - 1) / 2  # [L,3]
        centers[:, 0] = 0  # Do not center the time dimension
        coords_xyz = coords_xyz - centers  # [L,3]

        # Project to frequency bands in one shot: [L,3,B]
        proj = coords_xyz.unsqueeze(-1) * scales.unsqueeze(-1) * self.bands

        # Compute sin & cos and concatenate
        sin_proj = proj.sin()  # [L,3,B]
        cos_proj = proj.cos()

        # Concatenate along the coordinate axis then flatten: [L, 2*3*B]
        return torch.cat((sin_proj, cos_proj), dim=1).flatten(1)

    def reset_parameters(self):
        # Restore the (possibly learnable) bands to their default schedule.
        bands = self.get_default_bands()
        self.bands.copy_(bands)

    def get_default_bands(self):
        if self.in_pixels:
            raise NotImplementedError("in_pixels are not implemented yet")
        else:
            # dim // 8 inverse-frequency bands per coordinate axis (step=1).
            bands = freq_bands(self.dim // 8, temperature=self.temperature, step=1, device=self.device).to(self.dtype)
        return bands
242
+
243
+
244
class MultiModalityRMSNorm(nn.Module):
    """RMSNorm with one learned gain per modality ("expert").

    ``weight`` stores ``num_modality`` gains of size ``dim`` concatenated along
    dim 0; a stored gain ``w`` scales the normalized activation by ``w + 1``,
    so the zero initialization is an identity scaling.
    """

    __constants__ = ["dim", "eps", "num_modality"]
    dim: int
    eps: float
    num_modality: int

    def __init__(self, dim: int, eps: float = 1e-6, device: torch.device | None = None, num_modality: int = 1):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.num_modality = num_modality

        self.weight = torch.nn.Parameter(torch.zeros(dim * num_modality, device=device, dtype=torch.float32))
        # Bind the appropriate forward once so the per-call dispatch is free.
        if num_modality > 1:
            self.forward = self.forward_multi_experts
        else:
            self.forward = self.forward_single_expert

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.zeros_(self.weight)

    def rms(self, x: torch.Tensor) -> torch.Tensor:
        """Root-mean-square normalize ``x`` in float32 (gain not applied)."""
        t = x.float()
        return t * torch.rsqrt(torch.mean(t**2, dim=-1, keepdim=True) + self.eps)

    def forward_multi_experts(self, x: torch.Tensor, modality_dispatcher: ModalityDispatcher) -> torch.Tensor:
        original_dtype = x.dtype
        t = self.rms(x)

        # Apply each modality's gain to that modality's slice of the sequence.
        weight_chunked = self.weight.chunk(self.num_modality, dim=0)
        t_list = modality_dispatcher.dispatch(t)
        for i in range(self.num_modality):
            t_list[i] = t_list[i] * (weight_chunked[i] + 1)
        t = modality_dispatcher.undispatch(*t_list)

        return t.to(original_dtype)

    def forward_single_expert(self, x: torch.Tensor, modality_dispatcher: Optional[ModalityDispatcher] = None) -> torch.Tensor:
        # Same math as the multi-expert path, with a single shared gain;
        # ``modality_dispatcher`` is accepted for interface parity and ignored.
        return (self.rms(x) * (self.weight + 1)).to(x.dtype)
288
+
289
+
290
+ class _BF16ComputeLinear(torch.autograd.Function):
291
+ @staticmethod
292
+ def forward(
293
+ ctx,
294
+ input: torch.Tensor,
295
+ weight: torch.Tensor,
296
+ bias: Optional[torch.Tensor],
297
+ output_dtype: Optional[torch.dtype],
298
+ compute_dtype: torch.dtype = torch.bfloat16,
299
+ ):
300
+ # Convert input to specified input data type
301
+ input_cast = input.to(compute_dtype)
302
+ # Convert weight to computation data type
303
+ weight_cast = weight.to(compute_dtype)
304
+ # Perform linear operation
305
+ output = torch.matmul(input_cast, weight_cast.t())
306
+
307
+ # Add bias if present
308
+ if bias is not None:
309
+ bias_cast = bias.to(compute_dtype)
310
+ output = output + bias_cast
311
+ else:
312
+ bias_cast = None
313
+
314
+ # Convert output to specified output data type
315
+ return output.to(output_dtype)
316
+
317
+
318
class BaseLinear(nn.Module):
    """Bias-optional linear layer with bf16-stored weights and bf16 compute.

    NOTE(review): the ``dtype`` constructor argument is accepted but ignored —
    parameters are always allocated in bfloat16. ``num_layers_for_initialization``
    is stored but never read in this file (presumably a training-time artifact).
    With ``num_experts`` > 1, expert weights are stacked along dim 0; this base
    class ignores the expert structure (see NativeMoELinear for routing).
    """

    __constants__ = ["in_features", "out_features", "num_layers", "num_experts"]
    in_features: int
    out_features: int
    num_layers_for_initialization: int
    num_experts: int
    weight: Tensor

    def __init__(
        self, in_features, out_features, num_layers_for_initialization, num_experts, bias=True, device=None, dtype=None
    ):
        super().__init__()
        # Weights are always bf16 regardless of the requested dtype.
        factory_kwargs = {"device": device, "dtype": torch.bfloat16}
        self.in_features = in_features
        self.out_features = out_features
        self.num_layers_for_initialization = num_layers_for_initialization
        self.num_experts = num_experts
        self.use_bias = bias
        self.weight = Parameter(torch.empty((out_features * num_experts, in_features), **factory_kwargs))
        if bias:
            self.bias = Parameter(torch.empty(out_features * num_experts, **factory_kwargs))
        else:
            self.register_parameter("bias", None)

    def forward(
        self,
        input: torch.Tensor,
        output_dtype: Optional[torch.dtype] = None,
        modality_dispatcher: Optional[ModalityDispatcher] = None,
    ) -> torch.Tensor:
        """Single-expert path; ``modality_dispatcher`` is accepted for interface
        parity with NativeMoELinear and ignored here."""
        output_dtype = input.dtype if output_dtype is None else output_dtype
        return _BF16ComputeLinear.apply(input, self.weight, self.bias, output_dtype, torch.bfloat16)
350
+
351
+
352
class NativeMoELinear(BaseLinear):
    """Per-modality expert linear: each modality's tokens go through that
    modality's slice of the stacked expert weights."""

    def forward(
        self,
        input: torch.Tensor,
        output_dtype: Optional[torch.dtype] = None,
        modality_dispatcher: Optional[ModalityDispatcher] = None,
    ) -> torch.Tensor:
        """``input`` must be in modality-permuted order; ``modality_dispatcher``
        is required on this path (dispatch/undispatch route tokens to experts)."""
        output_dtype = input.dtype if output_dtype is None else output_dtype

        # Split tokens by modality; weights/bias are stacked per-expert on dim 0.
        input_list = modality_dispatcher.dispatch(input)  # type: ignore
        weight_chunked = self.weight.chunk(self.num_experts, dim=0)

        if self.bias is not None:
            bias_chunked = self.bias.chunk(self.num_experts, dim=0)

        for i in range(self.num_experts):
            input_list[i] = _BF16ComputeLinear.apply(
                input_list[i],
                weight_chunked[i],
                bias_chunked[i] if self.bias is not None else None,
                output_dtype,
                torch.bfloat16,
            )
        return modality_dispatcher.undispatch(*input_list)  # type: ignore
376
+
377
+
378
def create_linear(
    in_features, out_features, num_layers=1, num_experts=1, bias=True, device=None, dtype=None
) -> BaseLinear | NativeMoELinear:
    """Build a plain dense linear for one expert, or a per-modality MoE linear otherwise."""
    linear_cls = BaseLinear if num_experts == 1 else NativeMoELinear
    return linear_cls(in_features, out_features, num_layers, num_experts, bias, device, dtype)
385
+
386
+
387
# Optional attention backends, detected without importing them:
# magi_attention (flexible range attention kernels) and FlashAttention-3
# ("flash_attn_interface"). Both are only used on Hopper GPUs (see below).
HAS_MAGI_ATTENTION = importlib.util.find_spec("magi_attention") is not None
HAS_FA3 = importlib.util.find_spec("flash_attn_interface") is not None
389
+
390
+
391
+ @magi_register_custom_op(name="infra::flash_attn_func", is_subgraph_boundary=True)
392
+ def flash_attn_func(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
393
+ if HAS_FA3 and is_hopper_arch():
394
+ from flash_attn_interface import flash_attn_func as fa3_flash_attn_func
395
+
396
+ return fa3_flash_attn_func(query, key, value)
397
+ else:
398
+ from flash_attn.flash_attn_interface import flash_attn_func as fa2_flash_attn_func
399
+
400
+ return fa2_flash_attn_func(query, key, value)
401
+
402
+
403
+ def _split_q_range_with_no_overlap(
404
+ q_ranges: torch.Tensor, k_ranges: torch.Tensor
405
+ ) -> Tuple[List[List[int]], List[List[List[int]]]]:
406
+ range_boundary = torch.unique(q_ranges, sorted=True).tolist()
407
+ candidates = [[start, end, []] for start, end in zip(range_boundary[:-1], range_boundary[1:])]
408
+ q_ranges = q_ranges.tolist()
409
+ k_ranges = k_ranges.tolist()
410
+ for q_range, k_range in zip(q_ranges, k_ranges):
411
+ q_start, q_end = q_range
412
+ for q_range_cand in candidates:
413
+ if q_start <= q_range_cand[0] and q_range_cand[1] <= q_end:
414
+ q_range_cand[2].append(k_range)
415
+ q_ranges_out = []
416
+ k_ranges_out = []
417
+ for q_range_cand in candidates:
418
+ if len(q_range_cand[2]) > 0:
419
+ q_ranges_out.append(q_range_cand[0:2])
420
+ k_ranges_out.append(q_range_cand[2])
421
+ return q_ranges_out, k_ranges_out
422
+
423
+
424
def _flash_attn_with_correction(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, q_ranges: List[List[int]], k_range_list: List[List[List[int]]]
):
    """Attend each disjoint q segment to several k ranges and merge the partial
    results with log-sum-exp (LSE) correction (online-softmax style).

    Args:
        query/key/value: (seqlen, nheads, headdim) tensors.
        q_ranges: disjoint [start, end) q segments.
        k_range_list: for each q segment, the [start, end) k ranges it attends to.

    Returns:
        (output, output_lse): merged attention output shaped like ``query`` and
        the per-(token, head) LSE in float32.
    """
    output = torch.zeros_like(query)
    output_lse = torch.zeros((query.shape[0], query.shape[1]), dtype=torch.float32, device=query.device)

    from flash_attn.flash_attn_interface import flash_attn_func

    for q_range, k_ranges in zip(q_ranges, k_range_list):
        q_start, q_end = q_range
        qo_out, qo_lse = None, None
        for k_range in k_ranges:
            k_start, k_end = k_range
            # flash_attn expects a batch dim; run one (q segment, k range) pair
            # and request the LSE (via return_attn_probs) for the merge below.
            cur_qo_out, cur_qo_lse, _ = flash_attn_func(
                query[q_start:q_end].unsqueeze(0),
                key[k_start:k_end].unsqueeze(0),
                value[k_start:k_end].unsqueeze(0),
                return_attn_probs=True,
            )
            cur_qo_out, cur_qo_lse = cur_qo_out.squeeze(0), cur_qo_lse.squeeze(0)

            if qo_out is None:
                qo_out = cur_qo_out
                qo_lse = cur_qo_lse
            else:
                # NOTE(review): +inf LSE presumably marks fully-masked rows;
                # flipping to -inf gives them zero weight in the merge — confirm
                # against the flash_attn LSE convention.
                qo_lse[qo_lse == torch.inf] = -torch.inf
                cur_qo_lse[cur_qo_lse == torch.inf] = -torch.inf
                # Numerically-stable merge of two softmax partials.
                max_lse = torch.max(qo_lse, cur_qo_lse)
                qo_se, cur_qo_se = torch.exp(qo_lse - max_lse), torch.exp(cur_qo_lse - max_lse)
                sum_se = qo_se + cur_qo_se
                qo_scale, cur_qo_scale = qo_se / sum_se, cur_qo_se / sum_se

                # LSE here is (nheads, seg_len); permute to broadcast over heads.
                qo_out = qo_out * qo_scale.permute(1, 0).unsqueeze(-1) + cur_qo_out * cur_qo_scale.permute(1, 0).unsqueeze(-1)
                qo_lse = torch.log(sum_se) + max_lse

        output[q_start:q_end] = qo_out
        output_lse[q_start:q_end, :] = qo_lse.permute(1, 0)
    return output, output_lse
462
+
463
+
464
def _custom_flex_flash_attn_func(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, q_ranges: torch.Tensor, k_ranges: torch.Tensor, **kwargs
):
    """Pure-PyTorch fallback for flex attention: split q into disjoint segments,
    then attend each segment to its k ranges and merge with LSE correction."""
    disjoint_q, per_segment_k = _split_q_range_with_no_overlap(q_ranges, k_ranges)
    return _flash_attn_with_correction(query, key, value, disjoint_q, per_segment_k)
469
+
470
+
471
+ def _flex_flash_attn_func_infer_output_meta(
472
+ query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, q_ranges: torch.Tensor, k_ranges: torch.Tensor
473
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
474
+ output = torch.empty_like(query)
475
+ output_lse = torch.empty((query.shape[0], query.shape[1]), dtype=torch.float32, device=query.device)
476
+ return output, output_lse
477
+
478
+
479
@magi_register_custom_op(
    name="infra::flex_flash_attn_func",
    mutates_args=(),
    infer_output_meta_fn=_flex_flash_attn_func_infer_output_meta,
    is_subgraph_boundary=True,
)
def flex_flash_attn_func(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, q_ranges: torch.Tensor, k_ranges: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Range-based ("flex") attention: q_ranges[i] attends to k_ranges[i].

    Dispatches to the magi_attention kernel on Hopper GPUs when available,
    otherwise to the pure-PyTorch fallback with LSE merging.
    Returns (output, lse).
    """
    if HAS_MAGI_ATTENTION and is_hopper_arch():
        from magi_attention.api import flex_flash_attn_func as magi_flex_flash_attn_func

        return magi_flex_flash_attn_func(query, key, value, q_ranges, k_ranges)
    else:
        return _custom_flex_flash_attn_func(query, key, value, q_ranges, k_ranges)
494
+
495
+
496
+ def _attention_with_cp_infer_output_meta(q: torch.Tensor, *args, **kwargs) -> torch.Tensor:
497
+ return torch.empty_like(q, dtype=torch.bfloat16).squeeze(0)
498
+
499
+
500
@magi_register_custom_op(
    name="infra::flash_attn_with_cp",
    mutates_args=(),
    infer_output_meta_fn=_attention_with_cp_infer_output_meta,
    is_subgraph_boundary=True,
)
def flash_attn_with_cp(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cp_split_sizes: List[int]) -> torch.Tensor:
    """Full attention under context parallelism (Ulysses-style all-to-all).

    With cp_world_size > 1: exchange so each rank holds a head shard of the
    full sequence, attend, then exchange back to full heads on the local
    sequence shard. q/k/v carry a leading batch dim of 1; the returned tensor
    has it squeezed away.
    """
    q, k, v = q.to(torch.bfloat16), k.to(torch.bfloat16), v.to(torch.bfloat16)

    # Imported lazily to keep this module importable without the distributed stack.
    from inference.infra.distributed import get_cp_group, get_cp_world_size
    from inference.infra.parallelism.all_to_all_primitive import batch_scatter_head_gather_seqlen, scatter_seqlen_gather_head

    if get_cp_world_size() > 1:
        q, k, v = batch_scatter_head_gather_seqlen([q.squeeze(0), k.squeeze(0), v.squeeze(0)], cp_split_sizes, get_cp_group())
        q = q.unsqueeze(0)
        k = k.unsqueeze(0)
        v = v.unsqueeze(0)

    self_attn_out = torch.ops.infra.flash_attn_func(q, k, v).squeeze(0)

    if get_cp_world_size() > 1:
        # Inverse exchange, then fold the gathered cp chunks back into the head dim.
        self_attn_out = scatter_seqlen_gather_head(self_attn_out, cp_split_sizes, get_cp_group(), async_op=False)
        self_attn_out = rearrange(self_attn_out, "(cp sq) hn hd -> sq (cp hn) hd", cp=get_cp_world_size())

    return self_attn_out
525
+
526
+
527
@magi_register_custom_op(
    name="infra::flex_flash_attn_with_cp",
    mutates_args=(),
    infer_output_meta_fn=_attention_with_cp_infer_output_meta,
    is_subgraph_boundary=True,
)
def flex_flash_attn_with_cp(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_ranges: torch.Tensor,
    k_ranges: torch.Tensor,
    cp_split_sizes: List[int],
) -> torch.Tensor:
    """Range-restricted (flex) attention under context parallelism.

    Same head/sequence all-to-all exchange as flash_attn_with_cp, but the inner
    kernel is flex_flash_attn_func with explicit q/k ranges. Inputs carry a
    leading batch dim of 1 which is squeezed before the exchange.
    """
    q, k, v = q.to(torch.bfloat16).squeeze(0), k.to(torch.bfloat16).squeeze(0), v.to(torch.bfloat16).squeeze(0)

    # Imported lazily to keep this module importable without the distributed stack.
    from inference.infra.distributed import get_cp_group, get_cp_world_size
    from inference.infra.parallelism.all_to_all_primitive import batch_scatter_head_gather_seqlen, scatter_seqlen_gather_head

    if get_cp_world_size() > 1:
        q, k, v = batch_scatter_head_gather_seqlen([q, k, v], cp_split_sizes, get_cp_group())

    out, _ = torch.ops.infra.flex_flash_attn_func(q, k, v, q_ranges=q_ranges, k_ranges=k_ranges)

    if get_cp_world_size() > 1:
        # Inverse exchange, then fold the gathered cp chunks back into the head dim.
        out = scatter_seqlen_gather_head(out, cp_split_sizes, get_cp_group(), async_op=False)
        out = rearrange(out, "(cp sq) hn hd -> sq (cp hn) hd", cp=get_cp_world_size())

    return out
556
+
557
+
558
@dataclass
class AttentionConfig:
    """Configuration for one Attention block."""

    # Transformer width and Q/KV head layout (num_heads_kv < num_heads_q => GQA).
    hidden_size: int
    num_heads_q: int
    num_heads_kv: int
    head_dim: int
    params_dtype: torch.dtype
    # Carried through from the model config; not read in this file's forward path.
    checkpoint_qk_layernorm_rope: bool
    # Number of modality experts for norms/projections (1 or 3 in this model).
    num_modality: int
    num_layers: int
    # Use range-restricted (flex) attention instead of full attention.
    use_local_attn: bool = False
    # Multiply attention output by a per-head sigmoid gate taken from the QKV projection.
    enable_attn_gating: bool = False
570
+
571
+
572
class Attention(torch.nn.Module):
    """Multi-modality self-attention block (pre-norm, RoPE, optional head gating).

    Q/K/V come from a fused (possibly per-modality MoE) projection. Tokens
    travel in modality-permuted order for the expert linears and are restored
    to sequence order around the attention kernel.
    """

    config: AttentionConfig

    def __init__(self, config: AttentionConfig):
        super().__init__()
        self.config = config

        self.pre_norm = MultiModalityRMSNorm(config.hidden_size, eps=1e-6, num_modality=config.num_modality)
        # Extra fused output channels used for per-head sigmoid gating (0 = disabled).
        self.gating_size = config.num_heads_q if config.enable_attn_gating else 0

        self.linear_qkv = create_linear(
            config.hidden_size,
            config.num_heads_q * config.head_dim + config.num_heads_kv * config.head_dim * 2 + self.gating_size,
            num_experts=config.num_modality,
            bias=False,
            dtype=config.params_dtype,
            num_layers=config.num_layers,
        )
        self.linear_proj = create_linear(
            config.num_heads_q * config.head_dim,
            config.hidden_size,
            bias=False,
            num_experts=config.num_modality,
            dtype=config.params_dtype,
            num_layers=config.num_layers,
        )
        self.q_norm = MultiModalityRMSNorm(config.head_dim, num_modality=config.num_modality)
        self.k_norm = MultiModalityRMSNorm(config.head_dim, num_modality=config.num_modality)

        self.q_size = config.num_heads_q * config.head_dim
        self.kv_size = config.num_heads_kv * config.head_dim

    def reset_parameters(self):
        if hasattr(self.linear_proj, "reset_parameters_output_layer"):
            self.linear_proj.reset_parameters_output_layer()

    def forward(
        self,
        hidden_states: torch.Tensor,
        rope: torch.Tensor,
        permute_mapping: torch.Tensor,
        inv_permute_mapping: torch.Tensor,
        varlen_handler: VarlenHandler,
        local_attn_handler: FFAHandler,
        modality_dispatcher: ModalityDispatcher,
        cp_split_sizes: List[int],
    ) -> torch.Tensor:
        """Self-attention over modality-permuted hidden states [L, hidden_size].

        NOTE(review): ``varlen_handler`` is accepted but not used in this body.
        """
        hidden_states = self.pre_norm(hidden_states, modality_dispatcher=modality_dispatcher).to(torch.bfloat16)
        qkv: torch.Tensor = self.linear_qkv(hidden_states, modality_dispatcher=modality_dispatcher).to(torch.float32)

        q, k, v, g = torch.split(qkv, [self.q_size, self.kv_size, self.kv_size, self.gating_size], dim=1)
        q = q.view(-1, self.config.num_heads_q, self.config.head_dim)
        k = k.view(-1, self.config.num_heads_kv, self.config.head_dim)
        v = v.view(-1, self.config.num_heads_kv, self.config.head_dim)
        if self.config.enable_attn_gating:
            # Only reshape the gate when gating is enabled: with gating off the
            # slice has 0 channels, and view(..., -1) on a 0-element tensor
            # raises ("unspecified dimension size -1 can be any value").
            g = g.view(k.shape[0], self.config.num_heads_q, -1)

        q = self.q_norm(q, modality_dispatcher=modality_dispatcher)
        k = self.k_norm(k, modality_dispatcher=modality_dispatcher)

        # Restore sequence order: RoPE and attention operate on unpermuted tokens.
        q = ModalityDispatcher.inv_permute(q, inv_permute_mapping).unsqueeze(0)
        k = ModalityDispatcher.inv_permute(k, inv_permute_mapping).unsqueeze(0)
        v = ModalityDispatcher.inv_permute(v, inv_permute_mapping).unsqueeze(0)

        sin_emb, cos_emb = rope.tensor_split(2, -1)
        q = apply_rotary_emb_torch(q, cos_emb, sin_emb)
        k = apply_rotary_emb_torch(k, cos_emb, sin_emb)

        if self.config.use_local_attn:
            self_attn_out = flex_flash_attn_with_cp(
                q, k, v, local_attn_handler.q_ranges, local_attn_handler.k_ranges, cp_split_sizes
            )
        else:
            self_attn_out = flash_attn_with_cp(q, k, v, cp_split_sizes)
        # Back to modality-permuted order for the (per-modality) output projection.
        self_attn_out = ModalityDispatcher.permute(self_attn_out, permute_mapping)

        if self.config.enable_attn_gating:
            # g is in permuted order (it came from the permuted qkv), matching
            # self_attn_out after the permute above.
            self_attn_out = self_attn_out * torch.sigmoid(g)

        self_attn_out = self_attn_out.view(-1, self.config.num_heads_q * self.config.head_dim).to(torch.bfloat16)
        out = self.linear_proj(self_attn_out, modality_dispatcher=modality_dispatcher)
        return out
653
+
654
+
655
@dataclass
class MLPConfig:
    """Configuration for one MLP block."""

    # Transformer width and feed-forward expansion width.
    hidden_size: int
    intermediate_size: int
    # Activation selector consumed by create_activation_func.
    activation_type: MLPActivationType
    params_dtype: torch.dtype
    # Number of modality experts sharing this MLP (1 = plain dense MLP).
    num_modality: int = 1
    num_layers: int = 1
    # True when the activation is gated (up projection emits value + gate, 2x width).
    gated_act: bool = False
664
+
665
+
666
class MLP(torch.nn.Module):
    """Pre-norm feed-forward block with optional gated activation and
    per-modality experts."""

    config: MLPConfig

    def __init__(self, config: MLPConfig):
        super().__init__()
        self.pre_norm = MultiModalityRMSNorm(config.hidden_size, num_modality=config.num_modality)
        # Gated activations (e.g. SwiGLU) need the up projection to emit both
        # the value and the gate, hence twice the intermediate width.
        up_width = config.intermediate_size * 2 if config.gated_act else config.intermediate_size

        shared_kwargs = dict(
            bias=False,
            dtype=config.params_dtype,
            num_layers=config.num_layers,
            num_experts=config.num_modality,
        )
        self.up_gate_proj = create_linear(config.hidden_size, up_width, **shared_kwargs)
        self.down_proj = create_linear(config.intermediate_size, config.hidden_size, **shared_kwargs)
        self.activation_func = create_activation_func(config.activation_type)

    def forward(self, x: torch.Tensor, modality_dispatcher: ModalityDispatcher) -> torch.Tensor:
        """up-project (bf16) -> activate (fp32 in, bf16 out) -> down-project (fp32 out)."""
        normed = self.pre_norm(x, modality_dispatcher=modality_dispatcher).to(torch.bfloat16)
        projected = self.up_gate_proj(normed, modality_dispatcher=modality_dispatcher).to(torch.float32)
        activated = self.activation_func(projected).to(torch.bfloat16)
        return self.down_proj(activated, modality_dispatcher=modality_dispatcher).to(torch.float32)

    def extra_repr(self) -> str:
        return f"{self.up_gate_proj.weight.shape=}, {self.down_proj.weight.shape=}"
702
+
703
+
704
@dataclass
class AdapterConfig:
    """Configuration for the input Adapter."""

    # Transformer width; head count only sizes the per-head RoPE embedding.
    hidden_size: int
    num_attention_heads: int
    # Raw per-token channel counts for each modality's input features.
    text_in_channels: int
    video_in_channels: int
    audio_in_channels: int
    params_dtype: torch.dtype
713
+
714
class Adapter(torch.nn.Module):
    """Input adapter: embeds per-token features into the transformer width and
    builds the rotary embedding table from the token coordinates."""

    config: AdapterConfig

    def __init__(self, config: AdapterConfig):
        super().__init__()
        self.config = config
        # One float32 linear embedder per modality.
        self.video_embedder = nn.Linear(config.video_in_channels, config.hidden_size, bias=True, dtype=torch.float32)
        self.text_embedder = nn.Linear(config.text_in_channels, config.hidden_size, bias=True, dtype=torch.float32)
        self.audio_embedder = nn.Linear(config.audio_in_channels, config.hidden_size, bias=True, dtype=torch.float32)
        # RoPE operates per attention head, hence hidden_size // num_heads channels.
        self.rope = ElementWiseFourierEmbed(config.hidden_size // config.num_attention_heads, in_pixels=False, learnable=False)

    def forward(
        self,
        x: torch.Tensor,
        coords_mapping: torch.Tensor,
        video_mask: torch.Tensor,
        audio_mask: torch.Tensor,
        text_mask: torch.Tensor,
    ):
        """Embed each token via its modality's embedder.

        Args:
            x: [L, C] packed token features; only each modality's leading
               ``<modality>_in_channels`` columns are read for its tokens.
            coords_mapping: [L, 9] coordinates for ElementWiseFourierEmbed.
            video_mask/audio_mask/text_mask: boolean [L] masks. Assumed disjoint
                (later writes overwrite earlier ones; video wins on overlap).

        Returns:
            (embedded [L, hidden_size] tensor, rope embedding table).
        """
        rope = self.rope(coords_mapping)
        output_x = torch.zeros(x.shape[0], self.config.hidden_size, device=x.device, dtype=x.dtype)
        output_x[text_mask] = self.text_embedder(x[text_mask, : self.config.text_in_channels])
        output_x[audio_mask] = self.audio_embedder(x[audio_mask, : self.config.audio_in_channels])
        output_x[video_mask] = self.video_embedder(x[video_mask, : self.config.video_in_channels])
        return output_x, rope
739
+
740
+
741
class TransFormerLayer(torch.nn.Module):
    """One pre-norm transformer layer: attention + MLP, each with a residual add.

    Per-layer variations driven by the model config:
      * layers in ``mm_layers`` get 3 modality experts, others share 1;
      * layers in ``local_attn_layers`` use range-restricted (flex) attention;
      * layers in ``post_norm_layers`` additionally normalize each sublayer
        output before the residual add;
      * layers in ``gelu7_layers`` use a plain GELU MLP; the rest use a gated
        SwiGLU MLP with a correspondingly reduced intermediate size.
    """

    def __init__(self, config: Any, layer_idx: int):
        super().__init__()
        num_modality = 3 if layer_idx in config.mm_layers else 1
        use_local_attn = layer_idx in config.local_attn_layers
        self.post_norm = layer_idx in config.post_norm_layers
        attention_config = AttentionConfig(
            hidden_size=config.hidden_size,
            num_heads_q=config.num_heads_q,
            num_heads_kv=config.num_heads_kv,
            head_dim=config.head_dim,
            params_dtype=config.params_dtype,
            checkpoint_qk_layernorm_rope=config.checkpoint_qk_layernorm_rope,
            num_modality=num_modality,
            num_layers=config.num_layers,
            use_local_attn=use_local_attn,
            enable_attn_gating=config.enable_attn_gating,
        )
        self.attention: Attention = Attention(attention_config)

        activation_type = MLPActivationType.GELU7 if layer_idx in config.gelu7_layers else MLPActivationType.SWIGLU7
        if activation_type == MLPActivationType.SWIGLU7:
            gated_act = True
            # 2/3 of the 4x expansion (gated-MLP convention), rounded down to a multiple of 4.
            intermediate_size = int(config.hidden_size * 4 * 2 / 3) // 4 * 4
        else:
            gated_act = False
            intermediate_size = config.hidden_size * 4
        mlp_config = MLPConfig(
            hidden_size=config.hidden_size,
            intermediate_size=intermediate_size,
            activation_type=activation_type,
            params_dtype=config.params_dtype,
            num_modality=num_modality,
            num_layers=config.num_layers,
            gated_act=gated_act,
        )
        self.mlp: MLP = MLP(mlp_config)
        if self.post_norm:
            self.attn_post_norm = MultiModalityRMSNorm(config.hidden_size, num_modality=num_modality)
            self.mlp_post_norm = MultiModalityRMSNorm(config.hidden_size, num_modality=num_modality)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rope: torch.Tensor,
        permute_mapping: torch.Tensor,
        inv_permute_mapping: torch.Tensor,
        varlen_handler: VarlenHandler,
        local_attn_handler: FFAHandler,
        modality_dispatcher: ModalityDispatcher,
        cp_split_sizes: List[int],
    ) -> torch.Tensor:
        """x -> x + [post_norm](attn(x)) -> x + [post_norm](mlp(x))."""
        attn_out = self.attention(
            hidden_states,
            rope,
            permute_mapping,
            inv_permute_mapping,
            varlen_handler,
            local_attn_handler,
            modality_dispatcher,
            cp_split_sizes,
        )
        if self.post_norm:
            attn_out = self.attn_post_norm(attn_out, modality_dispatcher=modality_dispatcher)
        hidden_states = hidden_states + attn_out

        mlp_out = self.mlp(hidden_states, modality_dispatcher)
        if self.post_norm:
            mlp_out = self.mlp_post_norm(mlp_out, modality_dispatcher=modality_dispatcher)
        hidden_states = hidden_states + mlp_out
        return hidden_states
812
+
813
+
814
# Module-level flag consumed by config_patch below.
# NOTE(review): assumes models are constructed base-first, SR second — the
# first config_patch call keeps the config, later calls force full offload.
is_base_model = True


def config_patch(compile_config: CompileConfig) -> CompileConfig:
    """Per-model compile-config hook passed to @magi_compile (see TransformerBlock)."""
    global is_base_model
    if is_base_model:
        # First compiled model (assumed the base DiT): keep its offload settings.
        is_base_model = False
    else:
        # Fully offload SR model for memory-constrained GPU
        compile_config.offload_config.gpu_resident_weight_ratio = 0.0
    return compile_config
825
+
826
+
827
@magi_compile(config_patch=config_patch)
class TransformerBlock(torch.nn.Module):
    """The compiled stack of transformer layers (everything between the input
    adapter and the output heads)."""

    def __init__(self, model_config: Any):
        super().__init__()
        self.layers: list[TransFormerLayer] = nn.ModuleList()
        for layer_idx in range(model_config.num_layers):
            self.layers.append(TransFormerLayer(model_config, layer_idx))

    def forward(
        self,
        x: torch.Tensor,
        rope: torch.Tensor,
        permute_mapping: torch.Tensor,
        inv_permute_mapping: torch.Tensor,
        varlen_handler: VarlenHandler,
        local_attn_handler: FFAHandler,
        modality_dispatcher: ModalityDispatcher,
        cp_split_sizes: List[int],
    ) -> torch.Tensor:
        # Apply layers sequentially; all share the same auxiliary inputs.
        for _, layer in enumerate(self.layers):
            x = layer(
                x,
                rope,
                permute_mapping,
                inv_permute_mapping,
                varlen_handler,
                local_attn_handler,
                modality_dispatcher,
                cp_split_sizes,
            )
        return x
858
+
859
+
860
@dataclass
class TransformerConfig:
    """Top-level DiT model configuration."""

    # Width and raw input channel counts for the three modalities.
    hidden_size: int
    video_in_channels: int
    audio_in_channels: int
    text_in_channels: int
    params_dtype: torch.dtype
    # Dtype for final norms/heads (DiTModel fixes this to float32).
    post_process_dtype: torch.dtype
869
+
870
class DiTModel(torch.nn.Module):
    """Multi-modal diffusion transformer: adapter -> compiled block -> per-modality heads.

    Video/audio/text tokens are packed into one sequence. The adapter embeds
    them, the transformer block processes them (with Ulysses sequence
    parallelism and modality-permuted MoE linears), and separate norm+linear
    heads project video and audio tokens back to their latent channel counts.
    Text tokens get no output head and stay zero in the output.
    """

    config: TransformerConfig

    def __init__(self, model_config: Any):
        super().__init__()
        self.config = TransformerConfig(
            hidden_size=model_config.hidden_size,
            video_in_channels=model_config.video_in_channels,
            audio_in_channels=model_config.audio_in_channels,
            text_in_channels=model_config.text_in_channels,
            params_dtype=model_config.params_dtype,
            post_process_dtype=torch.float32,
        )
        adapter_config = AdapterConfig(
            hidden_size=model_config.hidden_size,
            num_attention_heads=model_config.num_heads_q,
            text_in_channels=model_config.text_in_channels,
            video_in_channels=model_config.video_in_channels,
            audio_in_channels=model_config.audio_in_channels,
            params_dtype=torch.float32,
        )
        self.adapter: Adapter = Adapter(adapter_config)
        self.block: TransformerBlock = TransformerBlock(model_config=model_config)
        # Output heads are kept in float32.
        self.final_norm_video = MultiModalityRMSNorm(self.config.hidden_size)
        self.final_norm_audio = MultiModalityRMSNorm(self.config.hidden_size)
        self.final_linear_video = nn.Linear(
            self.config.hidden_size, self.config.video_in_channels, bias=False, dtype=torch.float32
        )
        self.final_linear_audio = nn.Linear(
            self.config.hidden_size, self.config.audio_in_channels, bias=False, dtype=torch.float32
        )

    def forward(
        self,
        x: torch.Tensor,
        coords_mapping: torch.Tensor,
        modality_mapping: torch.Tensor,
        varlen_handler: VarlenHandler,
        local_attn_handler: FFAHandler,
    ):
        """Process one packed multi-modal token sequence.

        Args:
            x: [L, C] packed token features (C covers the widest modality).
            coords_mapping: [L, 9] per-token coordinates used for RoPE.
            modality_mapping: [L] Modality id per token.
            varlen_handler / local_attn_handler: attention range metadata.

        Returns:
            [L, max(video_in_channels, audio_in_channels)] tensor: video/audio
            tokens hold their head outputs in the leading channels, text rows zero.
        """
        # Shard the sequence across context-parallel (Ulysses) ranks.
        x = ulysses_scheduler().dispatch(x)
        coords_mapping = ulysses_scheduler().dispatch(coords_mapping)
        modality_mapping = ulysses_scheduler().dispatch(modality_mapping)
        cp_split_sizes = ulysses_scheduler().cp_split_sizes

        # Group tokens by modality so the MoE linears can slice contiguously.
        modality_dispatcher = ModalityDispatcher(modality_mapping, 3)
        permute_mapping, inv_permute_mapping = modality_dispatcher.permute_mapping, modality_dispatcher.inv_permute_mapping
        video_mask = modality_mapping == Modality.VIDEO
        audio_mask = modality_mapping == Modality.AUDIO
        text_mask = modality_mapping == Modality.TEXT

        x, rope = self.adapter(x, coords_mapping, video_mask, audio_mask, text_mask)
        x = x.to(self.config.params_dtype)
        x = ModalityDispatcher.permute(x, permute_mapping)
        x = self.block(
            x,
            rope,
            permute_mapping=permute_mapping,
            inv_permute_mapping=inv_permute_mapping,
            varlen_handler=varlen_handler,
            local_attn_handler=local_attn_handler,
            modality_dispatcher=modality_dispatcher,
            cp_split_sizes=cp_split_sizes,
        )
        x = ModalityDispatcher.inv_permute(x, inv_permute_mapping)

        # Per-modality output heads (float32).
        x_video = x[video_mask].to(self.final_norm_video.weight.dtype)
        x_video = self.final_norm_video(x_video)
        x_video = self.final_linear_video(x_video)

        x_audio = x[audio_mask].to(self.final_norm_audio.weight.dtype)
        x_audio = self.final_norm_audio(x_audio)
        x_audio = self.final_linear_audio(x_audio)

        # Scatter head outputs back into one padded tensor; text rows remain zero.
        x_out = torch.zeros(
            x.shape[0], max(self.config.video_in_channels, self.config.audio_in_channels), device=x.device, dtype=x.dtype
        )
        x_out[video_mask, : self.config.video_in_channels] = x_video
        x_out[audio_mask, : self.config.audio_in_channels] = x_audio
        # Gather the full sequence back from the context-parallel shards.
        x_out = ulysses_scheduler().undispatch(x_out)
        return x_out
inference/model/sa_audio/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .sa_audio_model import SAAudioFeatureExtractor
2
+ from .sa_audio_module import (
3
+ AudioAutoencoder,
4
+ OobleckDecoder,
5
+ OobleckEncoder,
6
+ VAEBottleneck,
7
+ create_autoencoder_from_config,
8
+ create_bottleneck_from_config,
9
+ create_decoder_from_config,
10
+ create_encoder_from_config,
11
+ create_model_from_config,
12
+ )
13
+
14
+ __all__ = [
15
+ "SAAudioFeatureExtractor",
16
+ "AudioAutoencoder",
17
+ "OobleckDecoder",
18
+ "OobleckEncoder",
19
+ "VAEBottleneck",
20
+ "create_autoencoder_from_config",
21
+ "create_bottleneck_from_config",
22
+ "create_decoder_from_config",
23
+ "create_encoder_from_config",
24
+ "create_model_from_config",
25
+ ]
inference/model/sa_audio/sa_audio_model.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ import os
17
+ from pathlib import Path
18
+
19
+ import torch
20
+ from safetensors.torch import load_file
21
+
22
+ # Set env vars for local T5 loading
23
+ os.environ["TRANSFORMERS_OFFLINE"] = "1"
24
+ os.environ["HF_HUB_OFFLINE"] = "1"
25
+
26
+ from .sa_audio_module import create_model_from_config
27
+
28
+ from inference.utils import print_rank_0
29
+
30
+
31
class SAAudioFeatureExtractor:
    """Stable Audio feature extractor that loads the VAE once and reuses it.

    Only the VAE ("pretransform") part of a Stable Audio checkpoint is loaded;
    the T5 text encoder and the diffusion model are skipped entirely.
    """

    def __init__(self, device, model_path):
        """Load the VAE from ``model_path`` (a local checkpoint dir) onto ``device``."""
        self.device = device
        self.vae_model, self.sample_rate = self._get_vae_only(model_path)
        self.resampler = None  # Will be initialized when needed

    def _get_vae_only(self, model_path):
        """Load only the VAE submodule; return ``(vae_model, sample_rate)``.

        Raises:
            ValueError: if ``model_path`` is not a local directory.
            RuntimeError: if loading from the local directory fails.
        """
        if not (isinstance(model_path, str) and Path(model_path).is_dir()):
            # Previously this only logged and implicitly returned None, which
            # made __init__ fail later with an opaque unpacking TypeError.
            raise ValueError(
                f"Non-local path is not supported in audio model loading: {model_path!r}"
            )
        try:
            # Read the full Stable Audio config and carve out the VAE section.
            model_config_path = os.path.join(model_path, "model_config.json")
            with open(model_config_path) as f:
                full_config = json.load(f)

            vae_config = full_config["model"]["pretransform"]["config"]
            sample_rate = full_config["sample_rate"]

            # Rebuild config structure expected by create_autoencoder_from_config
            autoencoder_config = {
                "model_type": "autoencoder",
                "sample_rate": sample_rate,  # sample_rate is required
                "model": vae_config,  # create_autoencoder_from_config expects key "model"
            }

            vae_model = create_model_from_config(autoencoder_config)

            weights_path = Path(model_path) / "model.safetensors"
            if not weights_path.exists():
                raise FileNotFoundError(f"Weight file does not exist: {weights_path}")

            # Load the full checkpoint and keep only the VAE weights
            # (prefix "pretransform.model."), stripping the prefix.
            full_state_dict = load_file(weights_path, device=str(self.device))
            prefix = "pretransform.model."
            vae_state_dict = {
                key[len(prefix):]: value for key, value in full_state_dict.items() if key.startswith(prefix)
            }

            # Report (a sample of) any key mismatches before the strict load below.
            model_keys = set(vae_model.state_dict().keys())
            vae_keys = set(vae_state_dict.keys())
            missing_keys = model_keys - vae_keys
            extra_keys = vae_keys - model_keys

            if missing_keys:
                print_rank_0(f"Missing keys ({len(missing_keys)}):")
                for key in list(missing_keys)[:5]:
                    print_rank_0(f" - {key}")

            if extra_keys:
                print_rank_0(f"Unexpected keys ({len(extra_keys)}):")
                for key in list(extra_keys)[:5]:
                    print_rank_0(f" + {key}")

            vae_model.load_state_dict(vae_state_dict)
            vae_model.to(self.device)

            return vae_model, sample_rate

        except Exception as e:
            print_rank_0(f"audio model loading failed: {e}")
            raise RuntimeError(
                "Failed to load VAE-only Stable Audio model from local path"
            ) from e

    def decode(self, latents):
        """Decode latents to a waveform (no grad)."""
        with torch.no_grad():
            return self.vae_model.decode(latents)

    def encode(self, waveform):
        """Encode a waveform to latents (no grad)."""
        with torch.no_grad():
            return self.vae_model.encode(waveform)
inference/model/sa_audio/sa_audio_module.py ADDED
@@ -0,0 +1,478 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+
6
+ import math
7
+ from typing import Any, Dict, Literal
8
+
9
+ import torch
10
+ from torch import nn
11
+ from torch.nn import functional as F
12
+ from torch.nn.utils import weight_norm
13
+
14
+
15
+ def snake_beta(x, alpha, beta):
16
+ return x + (1.0 / (beta + 1e-9)) * torch.pow(torch.sin(x * alpha), 2)
17
+
18
+
19
class SnakeBeta(nn.Module):
    """Snake activation with per-channel trainable alpha/beta (BigVGAN-style).

    With ``alpha_logscale`` the parameters live in log space (initialised to
    zero, i.e. exp(0) == 1); otherwise they are stored directly and
    initialised to ``alpha``.
    """

    def __init__(
        self,
        in_features: int,
        alpha: float = 1.0,
        alpha_trainable: bool = True,
        alpha_logscale: bool = True,
    ):
        super().__init__()
        self.alpha_logscale = alpha_logscale
        if alpha_logscale:
            # Log-space init: zeros (the `* alpha` of the reference impl is a no-op).
            init = torch.zeros(in_features)
        else:
            init = torch.ones(in_features) * alpha
        self.alpha = nn.Parameter(init.clone())
        self.beta = nn.Parameter(init.clone())
        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

    def forward(self, x):
        # Broadcast (C,) -> (1, C, 1) against (B, C, T) inputs.
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        return snake_beta(x, alpha, beta)
47
+
48
+
49
def vae_sample(mean, scale):
    """Reparameterised Gaussian sample plus KL(q || N(0, I)).

    `scale` is passed through softplus (floored at 1e-4) to get a strictly
    positive stdev; KL is summed over dim 1 and averaged over the batch.
    """
    stdev = F.softplus(scale) + 1e-4
    var = stdev.square()
    latents = mean + stdev * torch.randn_like(mean)
    kl_per_sample = (mean.square() + var - torch.log(var) - 1).sum(1)
    return latents, kl_per_sample.mean()
56
+
57
+
58
class VAEBottleneck(nn.Module):
    """Gaussian VAE bottleneck: splits channels into (mean, scale) and samples."""

    def __init__(self):
        super().__init__()

    def encode(self, x, return_info=False, **kwargs):
        mean, scale = x.chunk(2, dim=1)
        latents, kl = vae_sample(mean, scale)
        if return_info:
            return latents, {"kl": kl}
        return latents

    def decode(self, x):
        # Sampling is not invertible; latents pass through unchanged.
        return x
73
+
74
+
75
def WNConv1d(*args, **kwargs):
    """Conv1d wrapped with weight normalization."""
    conv = nn.Conv1d(*args, **kwargs)
    return weight_norm(conv)
77
+
78
+
79
def WNConvTranspose1d(*args, **kwargs):
    """ConvTranspose1d wrapped with weight normalization."""
    conv = nn.ConvTranspose1d(*args, **kwargs)
    return weight_norm(conv)
81
+
82
+
83
def checkpoint(function, *args, **kwargs):
    """Gradient checkpointing, defaulting to the non-reentrant implementation."""
    if "use_reentrant" not in kwargs:
        kwargs["use_reentrant"] = False
    return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
86
+
87
+
88
def get_activation(
    activation: Literal["elu", "snake", "none"], antialias: bool = False, channels=None
) -> nn.Module:
    """Build an activation module by name ('snake' requires `channels`)."""
    if antialias:
        raise NotImplementedError("antialias activation is not supported in sa_audio")

    factories = {
        "elu": nn.ELU,
        "snake": lambda: SnakeBeta(channels),
        "none": nn.Identity,
    }
    try:
        factory = factories[activation]
    except KeyError:
        raise ValueError(f"Unknown activation {activation}") from None
    return factory()
101
+
102
+
103
class ResidualUnit(nn.Module):
    """Dilated 1-D residual unit: act -> dilated conv(k=7) -> act -> conv(k=1), added to the input.

    NOTE(review): the residual add requires in_channels == out_channels;
    all callers in this file pass equal values — confirm before reuse.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dilation: int,
        use_snake: bool = False,
        antialias_activation: bool = False,
    ):
        super().__init__()
        # "Same" padding for the dilated kernel-7 convolution.
        padding = (dilation * (7 - 1)) // 2
        self.layers = nn.Sequential(
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=out_channels,
            ),
            WNConv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=7,
                dilation=dilation,
                padding=padding,
            ),
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=out_channels,
            ),
            WNConv1d(in_channels=out_channels, out_channels=out_channels, kernel_size=1),
        )

    def forward(self, x):
        # Checkpoint activations only during training to save memory.
        if self.training:
            y = checkpoint(self.layers, x)
        else:
            y = self.layers(x)
        return y + x
141
+
142
+
143
class EncoderBlock(nn.Module):
    """Encoder stage: three dilated residual units, then a strided downsampling conv."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        use_snake: bool = False,
        antialias_activation: bool = False,
    ):
        super().__init__()
        self.layers = nn.Sequential(
            # Growing dilations (1, 3, 9) widen the receptive field cheaply.
            ResidualUnit(in_channels, in_channels, 1, use_snake=use_snake),
            ResidualUnit(in_channels, in_channels, 3, use_snake=use_snake),
            ResidualUnit(in_channels, in_channels, 9, use_snake=use_snake),
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=in_channels,
            ),
            # Downsample by `stride` with a kernel of 2*stride.
            WNConv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
        )

    def forward(self, x):
        return self.layers(x)
173
+
174
+
175
class DecoderBlock(nn.Module):
    """Decoder stage: upsample by `stride`, then three dilated residual units."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        use_snake: bool = False,
        antialias_activation: bool = False,
        use_nearest_upsample: bool = False,
    ):
        super().__init__()

        # Two upsampling flavours: nearest-neighbour + conv (fewer checkerboard
        # artifacts) or a strided transposed conv (the default).
        if use_nearest_upsample:
            upsample_layer = nn.Sequential(
                nn.Upsample(scale_factor=stride, mode="nearest"),
                WNConv1d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=2 * stride,
                    stride=1,
                    bias=False,
                    padding="same",
                ),
            )
        else:
            upsample_layer = WNConvTranspose1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            )

        self.layers = nn.Sequential(
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=in_channels,
            ),
            upsample_layer,
            # Mirror of EncoderBlock's dilation schedule (1, 3, 9).
            ResidualUnit(out_channels, out_channels, 1, use_snake=use_snake),
            ResidualUnit(out_channels, out_channels, 3, use_snake=use_snake),
            ResidualUnit(out_channels, out_channels, 9, use_snake=use_snake),
        )

    def forward(self, x):
        return self.layers(x)
222
+
223
+
224
class OobleckEncoder(nn.Module):
    """Oobleck convolutional audio encoder: waveform -> latent sequence.

    A kernel-7 stem conv, one strided EncoderBlock per entry of `strides`
    (channel width scaled by `c_mults`), then a final activation and a
    kernel-3 projection down to `latent_dim` channels.

    Fix: `c_mults` / `strides` previously defaulted to shared mutable lists
    (classic Python pitfall); they are now tuples, and the argument is
    copied before use so tuple or list inputs both work and the caller's
    sequence is never aliased.
    """

    def __init__(
        self,
        in_channels: int = 2,
        channels: int = 128,
        latent_dim: int = 32,
        c_mults=(1, 2, 4, 8),
        strides=(2, 4, 8, 8),
        use_snake: bool = False,
        antialias_activation: bool = False,
    ):
        super().__init__()

        # Prepend the stem multiplier; copy so the caller's sequence is untouched.
        c_mults = [1] + list(c_mults)
        depth = len(c_mults)

        layers = [
            WNConv1d(
                in_channels=in_channels,
                out_channels=c_mults[0] * channels,
                kernel_size=7,
                padding=3,
            )
        ]

        for i in range(depth - 1):
            layers.append(
                EncoderBlock(
                    in_channels=c_mults[i] * channels,
                    out_channels=c_mults[i + 1] * channels,
                    stride=strides[i],
                    use_snake=use_snake,
                )
            )

        layers.extend(
            [
                get_activation(
                    "snake" if use_snake else "elu",
                    antialias=antialias_activation,
                    channels=c_mults[-1] * channels,
                ),
                WNConv1d(
                    in_channels=c_mults[-1] * channels,
                    out_channels=latent_dim,
                    kernel_size=3,
                    padding=1,
                ),
            ]
        )

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
279
+
280
+
281
class OobleckDecoder(nn.Module):
    """Oobleck convolutional audio decoder: latent sequence -> waveform.

    Mirror of OobleckEncoder: a stem conv from `latent_dim` to the widest
    channel count, upsampling DecoderBlocks in reverse stride order, and a
    final kernel-7 conv to `out_channels`, optionally squashed by tanh.

    Fix: `c_mults` / `strides` previously defaulted to shared mutable lists
    (classic Python pitfall); they are now tuples, and the argument is
    copied before use so tuple or list inputs both work and the caller's
    sequence is never aliased.
    """

    def __init__(
        self,
        out_channels: int = 2,
        channels: int = 128,
        latent_dim: int = 32,
        c_mults=(1, 2, 4, 8),
        strides=(2, 4, 8, 8),
        use_snake: bool = False,
        antialias_activation: bool = False,
        use_nearest_upsample: bool = False,
        final_tanh: bool = True,
    ):
        super().__init__()

        # Prepend the stem multiplier; copy so the caller's sequence is untouched.
        c_mults = [1] + list(c_mults)
        depth = len(c_mults)

        layers = [
            WNConv1d(
                in_channels=latent_dim,
                out_channels=c_mults[-1] * channels,
                kernel_size=7,
                padding=3,
            )
        ]

        # Walk the channel multipliers from widest to narrowest.
        for i in range(depth - 1, 0, -1):
            layers.append(
                DecoderBlock(
                    in_channels=c_mults[i] * channels,
                    out_channels=c_mults[i - 1] * channels,
                    stride=strides[i - 1],
                    use_snake=use_snake,
                    antialias_activation=antialias_activation,
                    use_nearest_upsample=use_nearest_upsample,
                )
            )

        layers.extend(
            [
                get_activation(
                    "snake" if use_snake else "elu",
                    antialias=antialias_activation,
                    channels=c_mults[0] * channels,
                ),
                WNConv1d(
                    in_channels=c_mults[0] * channels,
                    out_channels=out_channels,
                    kernel_size=7,
                    padding=3,
                    bias=False,
                ),
                nn.Tanh() if final_tanh else nn.Identity(),
            ]
        )

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
342
+
343
+
344
+ class AudioAutoencoder(nn.Module):
345
+ def __init__(
346
+ self,
347
+ encoder: nn.Module,
348
+ decoder: nn.Module,
349
+ latent_dim: int,
350
+ downsampling_ratio: int,
351
+ sample_rate: int,
352
+ io_channels: int = 2,
353
+ bottleneck: nn.Module | None = None,
354
+ in_channels: int | None = None,
355
+ out_channels: int | None = None,
356
+ soft_clip: bool = False,
357
+ ):
358
+ super().__init__()
359
+ self.downsampling_ratio = downsampling_ratio
360
+ self.sample_rate = sample_rate
361
+ self.latent_dim = latent_dim
362
+ self.io_channels = io_channels
363
+ self.in_channels = in_channels if in_channels is not None else io_channels
364
+ self.out_channels = out_channels if out_channels is not None else io_channels
365
+ self.bottleneck = bottleneck
366
+ self.encoder = encoder
367
+ self.decoder = decoder
368
+ self.soft_clip = soft_clip
369
+
370
+ def encode(self, audio, skip_bottleneck: bool = False, return_info: bool = False, **kwargs):
371
+ info = {}
372
+ latents = self.encoder(audio)
373
+ info["pre_bottleneck_latents"] = latents
374
+
375
+ if self.bottleneck is not None and not skip_bottleneck:
376
+ latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs)
377
+ info.update(bottleneck_info)
378
+
379
+ if return_info:
380
+ return latents, info
381
+ return latents
382
+
383
+ def decode(self, latents, skip_bottleneck: bool = False, **kwargs):
384
+ if self.bottleneck is not None and not skip_bottleneck:
385
+ latents = self.bottleneck.decode(latents)
386
+ decoded = self.decoder(latents, **kwargs)
387
+ if self.soft_clip:
388
+ decoded = torch.tanh(decoded)
389
+ return decoded
390
+
391
+
392
+ # AE factories
393
+
394
def create_encoder_from_config(encoder_config: Dict[str, Any]):
    """Instantiate an OobleckEncoder from its config dict; optionally freeze it."""
    encoder_type = encoder_config.get("type", None)
    assert encoder_type is not None, "Encoder type must be specified"
    if encoder_type != "oobleck":
        raise ValueError(f"Only encoder type 'oobleck' is supported, got: {encoder_type}")

    encoder = OobleckEncoder(**encoder_config["config"])
    if not encoder_config.get("requires_grad", True):
        encoder.requires_grad_(False)
    return encoder
405
+
406
+
407
def create_decoder_from_config(decoder_config: Dict[str, Any]):
    """Instantiate an OobleckDecoder from its config dict; optionally freeze it."""
    decoder_type = decoder_config.get("type", None)
    assert decoder_type is not None, "Decoder type must be specified"
    if decoder_type != "oobleck":
        raise ValueError(f"Only decoder type 'oobleck' is supported, got: {decoder_type}")

    decoder = OobleckDecoder(**decoder_config["config"])
    if not decoder_config.get("requires_grad", True):
        decoder.requires_grad_(False)
    return decoder
418
+
419
+
420
def create_bottleneck_from_config(bottleneck_config: Dict[str, Any]):
    """Instantiate the (only supported) VAE bottleneck; optionally freeze it."""
    bottleneck_type = bottleneck_config.get("type", None)
    assert bottleneck_type is not None, "type must be specified in bottleneck config"

    if bottleneck_type != "vae":
        raise NotImplementedError(
            f"Only bottleneck type 'vae' is supported, got: {bottleneck_type}"
        )

    bottleneck = VAEBottleneck()
    if not bottleneck_config.get("requires_grad", True):
        bottleneck.requires_grad_(False)
    return bottleneck
434
+
435
+
436
def create_autoencoder_from_config(config: Dict[str, Any]):
    """Build an AudioAutoencoder from a full model config.

    `config` must carry a "model" sub-dict (encoder/decoder/bottleneck
    configs plus latent_dim / downsampling_ratio / io_channels) and a
    top-level "sample_rate".
    """
    ae_config = config["model"]

    if ae_config.get("pretransform") is not None:
        raise NotImplementedError("Nested pretransform is not supported in sa_audio")

    encoder = create_encoder_from_config(ae_config["encoder"])
    decoder = create_decoder_from_config(ae_config["decoder"])

    bottleneck_cfg = ae_config.get("bottleneck")
    bottleneck = create_bottleneck_from_config(bottleneck_cfg) if bottleneck_cfg else None

    def _require(mapping, key):
        # Shared check so every missing key yields the same message shape.
        value = mapping.get(key)
        assert value is not None, f"{key} must be specified in model config"
        return value

    latent_dim = _require(ae_config, "latent_dim")
    downsampling_ratio = _require(ae_config, "downsampling_ratio")
    io_channels = _require(ae_config, "io_channels")
    sample_rate = _require(config, "sample_rate")

    return AudioAutoencoder(
        encoder=encoder,
        decoder=decoder,
        latent_dim=latent_dim,
        downsampling_ratio=downsampling_ratio,
        sample_rate=sample_rate,
        io_channels=io_channels,
        bottleneck=bottleneck,
        in_channels=ae_config.get("in_channels"),
        out_channels=ae_config.get("out_channels"),
        soft_clip=ae_config["decoder"].get("soft_clip", False),
    )
469
+
470
+
471
def create_model_from_config(model_config: Dict[str, Any]):
    """Top-level factory; only the 'autoencoder' model_type is supported."""
    model_type = model_config.get("model_type", None)
    assert model_type is not None, "model_type must be specified in model config"

    if model_type == "autoencoder":
        return create_autoencoder_from_config(model_config)
    raise NotImplementedError(f"Only 'autoencoder' is supported, got: {model_type}")
inference/model/t5_gemma/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .t5_gemma_model import get_t5_gemma_embedding
2
+
3
+ __all__ = ["get_t5_gemma_embedding"]
inference/model/t5_gemma/t5_gemma_model.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from transformers import AutoTokenizer
7
+ from transformers.models.t5gemma import T5GemmaEncoderModel
8
+
9
+ from inference.common import CPUOffloadWrapper, get_arch_memory
10
+ from inference.utils import env_is_true
11
+
12
+
13
class T5GemmaEncoder:
    """Thin wrapper around a T5Gemma encoder used for prompt embedding.

    Loads the tokenizer and encoder weights from `model_path`, moves the
    model to `device`, and optionally offloads it to CPU between calls.
    """

    def __init__(self, model_path: str, device: str, weight_dtype: torch.dtype):
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        # Encoder-only use: disable the decoder half of the T5Gemma checkpoint.
        model = T5GemmaEncoderModel.from_pretrained(
            model_path,
            is_encoder_decoder=False,
            dtype=weight_dtype,
        ).to(device)
        # Offload when explicitly requested via env, or when the accelerator
        # is small (<= 48, presumably GiB as reported by get_arch_memory —
        # TODO confirm unit).
        self.model = CPUOffloadWrapper(model, is_cpu_offload=env_is_true("CPU_OFFLOAD") or get_arch_memory() <= 48)

    def encode(self, prompt: str) -> torch.Tensor:
        """Tokenize a single prompt and return the encoder's last hidden state.

        NOTE(review): the output is always cast to float16 regardless of
        `weight_dtype` — confirm downstream consumers expect half precision.
        """
        inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
        outputs = self.model(**inputs)
        return outputs["last_hidden_state"].half()
28
+
29
+
30
+ _t5_gemma_cache: Optional[T5GemmaEncoder] = None
31
+
32
+
33
def get_t5_gemma_encoder(model_path: str, device: str, weight_dtype: torch.dtype) -> T5GemmaEncoder:
    """Return a process-wide cached T5Gemma encoder.

    Fix: the previous cache ignored its arguments, so a second call with a
    different model_path/device/dtype silently returned the first encoder.
    The cached instance is now rebuilt whenever any argument changes.
    """
    global _t5_gemma_cache
    cache_key = (model_path, str(device), weight_dtype)
    if _t5_gemma_cache is None or getattr(_t5_gemma_cache, "_cache_key", None) != cache_key:
        _t5_gemma_cache = T5GemmaEncoder(model_path=model_path, device=device, weight_dtype=weight_dtype)
        # Remember which configuration this instance was built for.
        _t5_gemma_cache._cache_key = cache_key
    return _t5_gemma_cache
38
+
39
+
40
@torch.inference_mode()
def get_t5_gemma_embedding(prompt: str, model_path: str, device: str, weight_dtype: torch.dtype) -> torch.Tensor:
    """Embed `prompt` with the cached T5Gemma encoder under inference mode."""
    encoder = get_t5_gemma_encoder(model_path=model_path, device=device, weight_dtype=weight_dtype)
    return encoder.encode(prompt)
inference/model/turbo_vaed/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .turbo_vaed_module import TurboVAED
2
+ from .turbo_vaed_model import get_turbo_vaed
3
+
4
+ __all__ = ["TurboVAED", "get_turbo_vaed"]
inference/model/turbo_vaed/turbo_vaed_model.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import torch
3
+
4
+ from .turbo_vaed_module import TurboVAED
5
+
6
+
7
+ def get_turbo_vaed(config_path, ckpt_path, device="cuda", weight_dtype=torch.float32) -> TurboVAED:
8
+ with open(config_path, "r", encoding="utf-8") as f:
9
+ config = json.load(f)
10
+ student = TurboVAED.from_config(config)
11
+
12
+ ckpt = torch.load(ckpt_path, map_location="cpu")
13
+ assert "ema_state_dict" in ckpt, "ckpt must contain ema_state_dict or state_dict"
14
+
15
+ state_dict = ckpt["ema_state_dict"]
16
+ new_state_dict = {}
17
+ for key, value in state_dict.items():
18
+ if key.startswith("module."):
19
+ new_state_dict[key[7:]] = value
20
+ else:
21
+ new_state_dict[key] = value
22
+ state_dict = new_state_dict
23
+
24
+ missing, _ = student.load_state_dict(state_dict, strict=False)
25
+ if len(missing) > 0:
26
+ sample_key = next(iter(state_dict.keys()))
27
+ if not sample_key.startswith("decoder.") and not sample_key.startswith("encoder."):
28
+ student.decoder.load_state_dict(state_dict, strict=False)
29
+
30
+ student = student.to(device, dtype=weight_dtype)
31
+ student.eval()
32
+ student.requires_grad_(False)
33
+ return student
inference/model/turbo_vaed/turbo_vaed_module.py ADDED
@@ -0,0 +1,1039 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
20
+ from diffusers.models.modeling_utils import ModelMixin
21
+ from einops import rearrange
22
+
23
+
24
+ __all__ = ["TurboVAED"]
25
+
26
+ ACT2CLS = {"swish": nn.SiLU, "silu": nn.SiLU, "mish": nn.Mish, "gelu": nn.GELU, "relu": nn.ReLU}
27
+
28
+
29
+ def get_activation(act_fn: str) -> nn.Module:
30
+ act_fn = act_fn.lower()
31
+ if act_fn in ACT2CLS:
32
+ return ACT2CLS[act_fn]()
33
+ else:
34
+ raise ValueError(f"activation function {act_fn} not found in ACT2FN mapping {list(ACT2CLS.keys())}")
35
+
36
+
37
+ def unpatchify(x, patch_size):
38
+ """
39
+ Unpatchify operation: convert patched representation back to original spatial resolution.
40
+ Similar to Wan VAE's unpatchify.
41
+
42
+ Args:
43
+ x: Input tensor with shape [batch_size, (channels * patch_size * patch_size), frame, height, width]
44
+ patch_size: The patch size used during patchification
45
+
46
+ Returns:
47
+ Tensor with shape [batch_size, channels, frame, height * patch_size, width * patch_size]
48
+ """
49
+ if patch_size == 1:
50
+ return x
51
+
52
+ if x.dim() != 5:
53
+ raise ValueError(f"Invalid input shape: {x.shape}")
54
+
55
+ # x shape: [batch_size, (channels * patch_size * patch_size), frame, height, width]
56
+ batch_size, c_patches, frames, height, width = x.shape
57
+ channels = c_patches // (patch_size * patch_size)
58
+
59
+ # Reshape to [b, c, patch_size, patch_size, f, h, w]
60
+ x = x.view(batch_size, channels, patch_size, patch_size, frames, height, width)
61
+
62
+ # Rearrange to [b, c, f, h * patch_size, w * patch_size]
63
+ x = x.permute(0, 1, 4, 5, 3, 6, 2).contiguous()
64
+ x = x.view(batch_size, channels, frames, height * patch_size, width * patch_size)
65
+
66
+ return x
67
+
68
+
69
+ class RMSNorm(nn.Module):
70
+ r"""
71
+ RMS Norm as introduced in https://huggingface.co/papers/1910.07467 by Zhang et al.
72
+
73
+ Args:
74
+ dim (`int`): Number of dimensions to use for `weights`. Only effective when `elementwise_affine` is True.
75
+ eps (`float`): Small value to use when calculating the reciprocal of the square-root.
76
+ elementwise_affine (`bool`, defaults to `True`):
77
+ Boolean flag to denote if affine transformation should be applied.
78
+ bias (`bool`, defaults to False): If also training the `bias` param.
79
+ """
80
+
81
+ def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool = False):
82
+ super().__init__()
83
+
84
+ self.eps = eps
85
+ self.elementwise_affine = elementwise_affine
86
+
87
+ if isinstance(dim, int):
88
+ dim = (dim,)
89
+
90
+ self.dim = torch.Size(dim)
91
+
92
+ self.weight = None
93
+
94
+ if elementwise_affine:
95
+ self.weight = nn.Parameter(torch.ones(dim))
96
+
97
+ def forward(self, hidden_states):
98
+ input_dtype = hidden_states.dtype
99
+ variance = hidden_states.to(torch.float32).pow(2).mean(1, keepdim=True)
100
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
101
+
102
+ if self.weight is not None:
103
+ # convert into half-precision if necessary
104
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
105
+ hidden_states = hidden_states.to(self.weight.dtype)
106
+ hidden_states = hidden_states * self.weight
107
+ else:
108
+ hidden_states = hidden_states.to(input_dtype)
109
+
110
+ return hidden_states
111
+
112
+
113
+ class TurboVAEDConv2dSplitUpsampler(nn.Module):
114
+ def __init__(
115
+ self,
116
+ in_channels: int,
117
+ kernel_size: Union[int, Tuple[int, int]] = 3,
118
+ stride: Union[int, Tuple[int, int]] = 1,
119
+ upscale_factor: int = 1,
120
+ padding_mode: str = "zeros",
121
+ ):
122
+ super().__init__()
123
+
124
+ self.in_channels = in_channels
125
+ self.stride = stride if isinstance(stride, tuple) else (stride, stride)
126
+ self.upscale_factor = upscale_factor
127
+
128
+ out_channels = in_channels
129
+
130
+ self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
131
+
132
+ height_pad = self.kernel_size[0] // 2
133
+ width_pad = self.kernel_size[1] // 2
134
+ padding = (height_pad, width_pad)
135
+
136
+ self.conv = nn.Conv2d(
137
+ in_channels=in_channels,
138
+ out_channels=out_channels,
139
+ kernel_size=self.kernel_size,
140
+ stride=1,
141
+ padding=padding,
142
+ padding_mode=padding_mode,
143
+ )
144
+
145
+ @torch.compile
146
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
147
+ hidden_states = self.conv(hidden_states)
148
+ hidden_states = torch.nn.functional.pixel_shuffle(hidden_states, self.stride[0])
149
+
150
+ return hidden_states
151
+
152
+
153
+ class TurboVAEDCausalConv3d(nn.Module):
154
+ def __init__(
155
+ self,
156
+ in_channels: int,
157
+ out_channels: int,
158
+ kernel_size: Union[int, Tuple[int, int, int]] = 3,
159
+ stride: Union[int, Tuple[int, int, int]] = 1,
160
+ dilation: Union[int, Tuple[int, int, int]] = 1,
161
+ groups: int = 1,
162
+ padding_mode: str = "zeros",
163
+ is_causal: bool = False,
164
+ ):
165
+ super().__init__()
166
+
167
+ assert is_causal == False
168
+ self.in_channels = in_channels
169
+ self.out_channels = out_channels
170
+ self.is_causal = is_causal
171
+ self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size)
172
+
173
+ dilation = dilation if isinstance(dilation, tuple) else (dilation, 1, 1)
174
+ stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
175
+ height_pad = self.kernel_size[1] // 2
176
+ width_pad = self.kernel_size[2] // 2
177
+ padding = (0, height_pad, width_pad)
178
+
179
+ self.conv = nn.Conv3d(
180
+ in_channels,
181
+ out_channels,
182
+ self.kernel_size,
183
+ stride=stride,
184
+ dilation=dilation,
185
+ groups=groups,
186
+ padding=padding,
187
+ padding_mode=padding_mode,
188
+ )
189
+
190
+ @torch.compile
191
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
192
+ time_kernel_size = self.kernel_size[0]
193
+
194
+ if time_kernel_size > 1:
195
+ pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, (time_kernel_size - 1) // 2, 1, 1))
196
+ pad_right = hidden_states[:, :, -1:, :, :].repeat((1, 1, (time_kernel_size - 1) // 2, 1, 1))
197
+ hidden_states = torch.cat([pad_left, hidden_states, pad_right], dim=2)
198
+
199
+ hidden_states = self.conv(hidden_states)
200
+ return hidden_states
201
+
202
+
203
+ class TurboVAEDCausalDepthwiseSeperableConv3d(nn.Module):
204
+ def __init__(
205
+ self,
206
+ in_channels: int,
207
+ out_channels: int,
208
+ kernel_size: Union[int, Tuple[int, int, int]] = 3,
209
+ stride: Union[int, Tuple[int, int, int]] = 1,
210
+ dilation: Union[int, Tuple[int, int, int]] = 1,
211
+ padding_mode: str = "zeros",
212
+ is_causal: bool = True,
213
+ ):
214
+ super().__init__()
215
+
216
+ self.in_channels = in_channels
217
+ self.out_channels = out_channels
218
+ self.is_causal = is_causal
219
+ self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size)
220
+ self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
221
+ self.dilation = dilation if isinstance(dilation, tuple) else (dilation, 1, 1)
222
+
223
+ # Calculate padding for height and width dimensions
224
+ height_pad = self.kernel_size[1] // 2
225
+ width_pad = self.kernel_size[2] // 2
226
+ self.padding = (0, height_pad, width_pad)
227
+
228
+ # Depthwise Convolution
229
+ self.depthwise_conv = nn.Conv3d(
230
+ in_channels,
231
+ in_channels,
232
+ self.kernel_size,
233
+ stride=self.stride,
234
+ dilation=self.dilation,
235
+ groups=in_channels, # Each input channel is convolved separately
236
+ padding=self.padding,
237
+ padding_mode=padding_mode,
238
+ )
239
+
240
+ # Pointwise Convolution
241
+ self.pointwise_conv = nn.Conv3d(in_channels, out_channels, kernel_size=1) # 1x1x1 convolution to mix channels
242
+
243
+ @torch.compile
244
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
245
+ time_kernel_size = self.kernel_size[0]
246
+ if time_kernel_size > 1:
247
+ pad_count = (time_kernel_size - 1) // 2
248
+ pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, pad_count, 1, 1))
249
+ pad_right = hidden_states[:, :, -1:, :, :].repeat((1, 1, pad_count, 1, 1))
250
+ hidden_states = torch.cat([pad_left, hidden_states, pad_right], dim=2)
251
+
252
+ # Apply depthwise convolution
253
+ hidden_states = self.depthwise_conv(hidden_states)
254
+ # Apply pointwise convolution
255
+ hidden_states = self.pointwise_conv(hidden_states)
256
+
257
+ return hidden_states
258
+
259
+
260
+ class TurboVAEDResnetBlock3d(nn.Module):
261
+ r"""
262
+ A 3D ResNet block used in the TurboVAED model.
263
+
264
+ Args:
265
+ in_channels (`int`):
266
+ Number of input channels.
267
+ out_channels (`int`, *optional*):
268
+ Number of output channels. If None, defaults to `in_channels`.
269
+ dropout (`float`, defaults to `0.0`):
270
+ Dropout rate.
271
+ eps (`float`, defaults to `1e-6`):
272
+ Epsilon value for normalization layers.
273
+ elementwise_affine (`bool`, defaults to `False`):
274
+ Whether to enable elementwise affinity in the normalization layers.
275
+ non_linearity (`str`, defaults to `"swish"`):
276
+ Activation function to use.
277
+ conv_shortcut (bool, defaults to `False`):
278
+ Whether or not to use a convolution shortcut.
279
+ """
280
+
281
+ def __init__(
282
+ self,
283
+ in_channels: int,
284
+ out_channels: Optional[int] = None,
285
+ dropout: float = 0.0,
286
+ eps: float = 1e-6,
287
+ elementwise_affine: bool = False,
288
+ non_linearity: str = "swish",
289
+ is_causal: bool = True,
290
+ is_upsampler_modified: bool = False,
291
+ is_dw_conv: bool = False,
292
+ dw_kernel_size: int = 3,
293
+ ) -> None:
294
+ super().__init__()
295
+
296
+ out_channels = out_channels or in_channels
297
+
298
+ self.nonlinearity = get_activation(non_linearity)
299
+
300
+ self.conv_operation = TurboVAEDCausalConv3d if not is_dw_conv else TurboVAEDCausalDepthwiseSeperableConv3d
301
+ self.kernel_size = 3 if not is_dw_conv else dw_kernel_size
302
+
303
+ self.is_upsampler_modified = is_upsampler_modified
304
+ self.replace_nonlinearity = get_activation("relu")
305
+
306
+ self.norm1 = RMSNorm(in_channels, eps=1e-8, elementwise_affine=elementwise_affine)
307
+ self.conv1 = self.conv_operation(
308
+ in_channels=in_channels, out_channels=out_channels, kernel_size=self.kernel_size, is_causal=is_causal
309
+ )
310
+
311
+ self.norm2 = RMSNorm(out_channels, eps=1e-8, elementwise_affine=elementwise_affine)
312
+ self.dropout = nn.Dropout(dropout)
313
+ self.conv2 = self.conv_operation(
314
+ in_channels=out_channels, out_channels=out_channels, kernel_size=self.kernel_size, is_causal=is_causal
315
+ )
316
+
317
+ self.norm3 = None
318
+ self.conv_shortcut = None
319
+ if in_channels != out_channels:
320
+ self.norm3 = RMSNorm(in_channels, eps=eps, elementwise_affine=elementwise_affine)
321
+ self.conv_shortcut = self.conv_operation(
322
+ in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal
323
+ )
324
+
325
+ @torch.compile
326
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
327
+ hidden_states = inputs
328
+
329
+ hidden_states = self.norm1(hidden_states)
330
+
331
+ if self.is_upsampler_modified:
332
+ hidden_states = self.replace_nonlinearity(hidden_states)
333
+ else:
334
+ hidden_states = self.nonlinearity(hidden_states)
335
+
336
+ hidden_states = self.conv1(hidden_states)
337
+
338
+ hidden_states = self.norm2(hidden_states)
339
+
340
+ hidden_states = self.nonlinearity(hidden_states)
341
+ hidden_states = self.dropout(hidden_states)
342
+
343
+ hidden_states = self.conv2(hidden_states)
344
+
345
+ if self.norm3 is not None:
346
+ inputs = self.norm3(inputs)
347
+
348
+ if self.conv_shortcut is not None:
349
+ inputs = self.conv_shortcut(inputs)
350
+
351
+ hidden_states = hidden_states + inputs
352
+ return hidden_states
353
+
354
+
355
class TurboVAEDUpsampler3d(nn.Module):
    r"""
    Spatio-temporal pixel-shuffle upsampler.

    A causal 3D convolution expands the channel dimension, then the extra
    channels are rearranged (depth-to-space style) into the temporal and
    spatial dimensions, upscaling by ``stride`` along (T, H, W).

    Args:
        in_channels (`int`):
            Number of input channels.
        stride (`int` or `Tuple[int, int, int]`, defaults to `1`):
            Upscaling factor per (temporal, height, width) dimension; an int
            is broadcast to all three dimensions.
        is_causal (`bool`, defaults to `True`):
            Whether the convolution is temporally causal.
        upscale_factor (`int`, defaults to `1`):
            Extra channel-reduction factor applied to the conv's output width.
        padding_mode (`str`, defaults to `"zeros"`):
            Padding mode for the convolution.
    """

    def __init__(
        self,
        in_channels: int,
        stride: Union[int, Tuple[int, int, int]] = 1,
        is_causal: bool = True,
        upscale_factor: int = 1,
        padding_mode: str = "zeros",
    ) -> None:
        super().__init__()

        # Normalize `stride` to a (t, h, w) tuple before it is indexed below.
        self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
        self.upscale_factor = upscale_factor

        # FIX: index the normalized `self.stride` tuple. The original code
        # indexed the raw `stride` argument, which raises TypeError whenever
        # the caller passes a plain int (including the default `stride=1`).
        out_channels = (in_channels * self.stride[0] * self.stride[1] * self.stride[2]) // upscale_factor

        self.conv = TurboVAEDCausalConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            is_causal=is_causal,
            padding_mode=padding_mode,
        )

    @torch.compile
    def forward(self, hidden_states: torch.Tensor, is_first_chunk: bool = True) -> torch.Tensor:
        """Upsample `hidden_states` of shape (B, C, T, H, W) by `self.stride`."""
        batch_size, num_channels, num_frames, height, width = hidden_states.shape

        hidden_states = self.conv(hidden_states)

        # Depth-to-space: fold the expanded channels into (T, H, W).
        # reshape + permute is used rather than einops
        # because the former has better performance on cuda kernels.
        s_t, s_h, s_w = self.stride
        hidden_states = hidden_states.reshape(batch_size, -1, s_t, s_h, s_w, num_frames, height, width)
        hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
        hidden_states = hidden_states.reshape(batch_size, -1, num_frames * s_t, height * s_h, width * s_w)

        # slice the first chunk
        if is_first_chunk:
            hidden_states = hidden_states[:, :, self.stride[0] - 1 :]  # NOTE: extra handling for the first frame

        return hidden_states
397
+
398
+
399
class WanUpsample(nn.Upsample):
    r"""
    `nn.Upsample` variant that runs interpolation in float32 and casts the
    result back to the input's dtype.

    Args:
        x (torch.Tensor): Input tensor to be upsampled.

    Returns:
        torch.Tensor: Upsampled tensor with the same data type as the input.
    """

    def forward(self, x):
        # Interpolate in full precision, then restore the caller's dtype.
        upsampled = super().forward(x.float())
        return upsampled.type_as(x)
412
+
413
+
414
class WanResample(nn.Module):
    r"""
    A custom resampling module for 2D and 3D data.

    Args:
        dim (int): The number of input/output channels.
        mode (str): The resampling mode. Must be one of:
            - 'none': No resampling (identity operation).
            - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.
            - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.
        upsample_out_dim (int, optional): Output channels of the spatial
            upsampling convolution. Defaults to `dim // 2` when not given.
    """

    def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None:
        super().__init__()
        self.dim = dim
        self.mode = mode

        # default to dim //2
        if upsample_out_dim is None:
            upsample_out_dim = dim // 2

        # layers
        if mode == "upsample2d":
            self.resample = nn.Sequential(
                WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, upsample_out_dim, 3, padding=1)
            )
        elif mode == "upsample3d":
            # Same spatial path as 'upsample2d', plus a causal temporal conv
            # that doubles the channel count (split into 2 time steps in forward).
            self.resample = nn.Sequential(
                WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, upsample_out_dim, 3, padding=1)
            )
            self.time_conv = TurboVAEDCausalConv3d(dim, dim * 2, (3, 1, 1))
        else:
            self.resample = nn.Identity()

    def forward(self, x, is_first_chunk: bool = True):
        # x is expected to be (B, C, T, H, W).
        b, c, t, h, w = x.shape
        if self.mode == "upsample3d":
            # Temporal 2x: time_conv doubles channels, then the channel pairs
            # are interleaved into the time dimension.
            x = self.time_conv(x)
            x = rearrange(x, 'b (n_split c) t h w -> b c (t n_split) h w', n_split=2)
            assert x.shape == (b, c, t * 2, h, w), "x.shape: {}, expected: {}".format(x.shape, (b, c, t * 2, h, w))
            # The causal conv's padding yields one synthetic leading frame on
            # the first chunk; drop it.
            if is_first_chunk:
                x = x[:, :, 1:]

        # Spatial 2x: fold time into the batch dim and run the 2D resampler.
        x = rearrange(x, "b c t h w -> (b t) c h w")
        x = self.resample(x)
        x = rearrange(x, "(b t) c h w -> b c t h w", b=b)

        return x
462
+
463
+
464
class TurboVAEDMidBlock3d(nn.Module):
    r"""
    A middle block used in the TurboVAED model.

    Stacks `num_layers` residual blocks at a constant channel width.

    Args:
        in_channels (`int`):
            Number of input channels.
        num_layers (`int`, defaults to `1`):
            Number of resnet layers.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        resnet_eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        resnet_act_fn (`str`, defaults to `"swish"`):
            Activation function to use.
        is_causal (`bool`, defaults to `True`):
            Whether this layer behaves causally (future frames depend only on past frames) or not.
        is_dw_conv (`bool`, defaults to `False`):
            Whether the resnet blocks use depthwise-separable convolutions.
        dw_kernel_size (`int`, defaults to `3`):
            Kernel size for depthwise-separable convolutions.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int,
        num_layers: int = 1,
        dropout: float = 0.0,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        is_causal: bool = True,
        is_dw_conv: bool = False,
        dw_kernel_size: int = 3,
    ) -> None:
        super().__init__()

        # Build the constant-width residual stack in one pass.
        self.resnets = nn.ModuleList(
            [
                TurboVAEDResnetBlock3d(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    dropout=dropout,
                    eps=resnet_eps,
                    non_linearity=resnet_act_fn,
                    is_causal=is_causal,
                    is_dw_conv=is_dw_conv,
                    dw_kernel_size=dw_kernel_size,
                )
                for _ in range(num_layers)
            ]
        )

        self.gradient_checkpointing = False

    @torch.compile
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        r"""Run the input through every residual block in order."""
        use_checkpoint = torch.is_grad_enabled() and self.gradient_checkpointing
        for resnet in self.resnets:
            if use_checkpoint:
                # Trade compute for memory during training.
                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
            else:
                hidden_states = resnet(hidden_states)

        return hidden_states
527
+
528
+
529
class TurboVAEDUpBlock3d(nn.Module):
    r"""
    Up block used in the TurboVAED model.

    Structure: an optional channel-changing resnet (`conv_in`), an optional 2x
    upsampler, then a stack of constant-width resnet blocks.

    Args:
        in_channels (`int`):
            Number of input channels.
        out_channels (`int`, *optional*):
            Number of output channels. If None, defaults to `in_channels`.
        num_layers (`int`, defaults to `1`):
            Number of resnet layers.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        resnet_eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        resnet_act_fn (`str`, defaults to `"swish"`):
            Activation function to use.
        spatio_temporal_scale (`bool`, defaults to `True`):
            Whether or not to add an upsampling layer. If not used, output dimension
            would be same as input dimension.
        is_causal (`bool`, defaults to `True`):
            Whether this layer behaves causally (future frames depend only on past frames) or not.
        is_dw_conv (`bool`, defaults to `False`):
            Whether resnet blocks use depthwise-separable convolutions.
        dw_kernel_size (`int`, defaults to `3`):
            Kernel size for depthwise-separable convolutions.
        spatio_only (`bool`, defaults to `False`):
            If True the upsampler scales spatially only ('upsample2d');
            otherwise spatio-temporally ('upsample3d').
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        spatio_temporal_scale: bool = True,
        is_causal: bool = True,
        is_dw_conv: bool = False,
        dw_kernel_size: int = 3,
        spatio_only: bool = False,
    ):
        super().__init__()

        out_channels = out_channels or in_channels

        # Channel-projection resnet, only needed when channel counts differ.
        self.conv_in = None
        if in_channels != out_channels:
            self.conv_in = TurboVAEDResnetBlock3d(
                in_channels=in_channels,
                out_channels=out_channels,
                dropout=dropout,
                eps=resnet_eps,
                non_linearity=resnet_act_fn,
                is_causal=is_causal,
                is_dw_conv=is_dw_conv,
                dw_kernel_size=dw_kernel_size,
            )

        self.upsamplers = None
        if spatio_temporal_scale:
            self.upsamplers = nn.ModuleList(
                [
                    WanResample(
                        dim=out_channels, mode="upsample2d" if spatio_only else "upsample3d", upsample_out_dim=out_channels
                    )
                ]
            )

        resnets = []
        for _ in range(num_layers):
            resnets.append(
                TurboVAEDResnetBlock3d(
                    in_channels=out_channels,
                    out_channels=out_channels,
                    dropout=dropout,
                    eps=resnet_eps,
                    non_linearity=resnet_act_fn,
                    is_causal=is_causal,
                    is_dw_conv=is_dw_conv,
                    dw_kernel_size=dw_kernel_size,
                    is_upsampler_modified=(spatio_temporal_scale),
                )
            )
        self.resnets = nn.ModuleList(resnets)

        self.gradient_checkpointing = False

    @torch.compile
    def forward(self, hidden_states: torch.Tensor, is_first_chunk: bool) -> torch.Tensor:
        r"""Project channels (if needed), upsample (if configured), then run the resnet stack."""
        if self.conv_in is not None:
            hidden_states = self.conv_in(hidden_states)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                # `is_first_chunk` lets the upsampler trim the synthetic
                # leading frame produced by causal temporal padding.
                hidden_states = upsampler(hidden_states, is_first_chunk=is_first_chunk)

        for i, resnet in enumerate(self.resnets):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
            else:
                hidden_states = resnet(hidden_states)

        return hidden_states
631
+
632
+
633
class TurboVAEDDecoder3d(nn.Module):
    r"""
    The `TurboVAEDDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output
    sample.

    Pipeline: conv_in -> mid block -> up blocks (coarse-to-fine) -> final
    upsampling (either an explicit 2D split-upsampler or channel-unpatchify)
    -> conv_out.

    Args:
        in_channels (`int`, defaults to 128):
            Number of latent channels.
        out_channels (`int`, defaults to 3):
            Number of output channels.
        block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
            The number of output channels for each block.
        spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, False)`:
            Whether a block should contain spatio-temporal upscaling layers or not.
        layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`):
            The number of layers per block (entry 0 is the mid block, the rest are up blocks).
        patch_size (`int`, defaults to `4`):
            The size of spatial patches.
            NOTE(review): `__init__` asserts `patch_size == 2`, so the default
            of 4 cannot actually be used — confirm the intended default.
        patch_size_t (`int`, defaults to `1`):
            The size of temporal patches.
        resnet_norm_eps (`float`, defaults to `1e-6`):
            Epsilon value for ResNet normalization layers.
        is_causal (`bool`, defaults to `False`):
            Whether this layer behaves causally (future frames depend only on past frames) or not.
        decoder_is_dw_conv (`Tuple[bool, ...]`, defaults to all-False):
            Per-block depthwise-separable-conv flags (index 0 is the mid block).
        decoder_dw_kernel_size (`int`, defaults to `3`):
            Kernel size for depthwise-separable convolutions.
        spatio_only (`Tuple[bool, ...]`, defaults to all-False):
            Per-up-block flag: upsample spatially only instead of spatio-temporally.
        upsampling (`bool`, defaults to `False`):
            Stored but not read inside this class — presumably consumed by
            callers; TODO confirm.
        use_unpatchify (`bool`, defaults to `False`):
            If True, `conv_out` emits `out_channels * patch_size**2` channels
            and `unpatchify` recovers the spatial resolution; otherwise an
            explicit split-upsampler + manual RMSNorm path is used.
    """

    def __init__(
        self,
        in_channels: int = 128,
        out_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
        spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
        layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
        patch_size: int = 4,
        patch_size_t: int = 1,
        resnet_norm_eps: float = 1e-6,
        is_causal: bool = False,
        decoder_is_dw_conv: Tuple[bool, ...] = (False, False, False, False, False),
        decoder_dw_kernel_size: int = 3,
        spatio_only: Tuple[bool, ...] = (False, False, False, False),
        upsampling: bool = False,
        use_unpatchify: bool = False,
    ) -> None:
        super().__init__()

        self.patch_size = patch_size
        self.patch_size_t = patch_size_t
        self.out_channels = out_channels

        self.upsampling = upsampling
        self.use_unpatchify = use_unpatchify

        # The config lists blocks encoder-side (fine-to-coarse); the decoder
        # consumes them coarse-to-fine, hence the reversals.
        block_out_channels = tuple(reversed(block_out_channels))
        spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling))
        layers_per_block = tuple(reversed(layers_per_block))
        decoder_is_dw_conv = tuple(reversed(decoder_is_dw_conv))
        spatio_only = tuple(reversed(spatio_only))
        output_channel = block_out_channels[0]

        self.conv_in = TurboVAEDCausalConv3d(
            in_channels=in_channels, out_channels=output_channel, kernel_size=3, stride=1, is_causal=is_causal
        )

        self.mid_block = TurboVAEDMidBlock3d(
            in_channels=output_channel,
            num_layers=layers_per_block[0],
            resnet_eps=resnet_norm_eps,
            is_causal=is_causal,
            is_dw_conv=decoder_is_dw_conv[0],
            dw_kernel_size=decoder_dw_kernel_size,
        )

        # up blocks
        num_block_out_channels = len(block_out_channels)
        self.up_blocks = nn.ModuleList([])
        for i in range(num_block_out_channels):
            input_channel = output_channel
            output_channel = block_out_channels[i]

            up_block = TurboVAEDUpBlock3d(
                in_channels=input_channel,
                out_channels=output_channel,
                num_layers=layers_per_block[i + 1],
                resnet_eps=resnet_norm_eps,
                spatio_temporal_scale=spatio_temporal_scaling[i],
                is_causal=is_causal,
                is_dw_conv=decoder_is_dw_conv[i + 1],
                dw_kernel_size=decoder_dw_kernel_size,
                spatio_only=spatio_only[i],
            )

            self.up_blocks.append(up_block)

        # out
        assert self.patch_size == 2
        if not self.use_unpatchify:
            # Explicit final 2x spatial upsampling path.
            self.norm_up_1 = RMSNorm(output_channel, eps=1e-8, elementwise_affine=False)
            self.upsampler2d_1 = TurboVAEDConv2dSplitUpsampler(in_channels=output_channel, kernel_size=3, stride=(2, 2))
            output_channel = output_channel // (2 * 2)

        self.conv_act = nn.SiLU()

        # When use_unpatchify=True, conv_out outputs more channels (out_channels * patch_size^2)
        # and unpatchify will recover the spatial resolution
        conv_out_channels = self.out_channels
        if self.use_unpatchify and self.patch_size >= 2:
            conv_out_channels = self.out_channels * self.patch_size * self.patch_size

        self.conv_out = TurboVAEDCausalConv3d(
            in_channels=output_channel, out_channels=conv_out_channels, kernel_size=3, stride=1, is_causal=is_causal
        )

        self.gradient_checkpointing = False

    @torch.compile
    def forward(self, hidden_states: torch.Tensor, is_first_chunk: bool) -> torch.Tensor:
        r"""Decode latents (B, C, T, H, W) into pixel space for one temporal chunk."""
        hidden_states = self.conv_in(hidden_states)

        if torch.is_grad_enabled() and self.gradient_checkpointing:

            def create_custom_forward(module):
                def create_forward(*inputs):
                    return module(*inputs)

                return create_forward

            hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), hidden_states)

            for up_block in self.up_blocks:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(up_block), hidden_states, is_first_chunk
                )
        else:
            hidden_states = self.mid_block(hidden_states)

            for index, up_block in enumerate(self.up_blocks):
                hidden_states = up_block(hidden_states, is_first_chunk=is_first_chunk)

        if not self.use_unpatchify:
            hidden_states = self.norm_up_1(hidden_states)
            hidden_states = self.conv_act(hidden_states)

            # Apply the 2D upsampler frame by frame, then restack along time.
            hidden_states_array = []
            for t in range(hidden_states.shape[2]):
                h = self.upsampler2d_1(hidden_states[:, :, t, :, :])
                hidden_states_array.append(h)
            hidden_states = torch.stack(hidden_states_array, dim=2)

            # RMSNorm (inlined, no affine parameters), computed in float32.
            input_dtype = hidden_states.dtype
            variance = hidden_states.to(torch.float32).pow(2).mean(1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(variance + 1e-8)
            hidden_states = hidden_states.to(input_dtype)

        hidden_states = self.conv_act(hidden_states)

        hidden_states = self.conv_out(hidden_states)

        if self.use_unpatchify:
            hidden_states = unpatchify(hidden_states, self.patch_size)

        return hidden_states
795
+
796
+
797
class TurboVAED(ModelMixin, ConfigMixin):
    """Decoder-only VAE wrapper: denormalizes latents with fixed per-channel
    statistics, then decodes them chunk-by-chunk with a sliding temporal
    window (see `_sliding_window_decode`).

    NOTE(review): only a decoder is built here; several constructor arguments
    exist purely for config compatibility with the training-time model.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,  # useless arg for compatibility, we only use latent channels
        out_channels: int = 3,
        latent_channels: int = 128,
        decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
        decoder_layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
        decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
        patch_size: int = 4,
        patch_size_t: int = 1,
        resnet_norm_eps: float = 1e-6,
        scaling_factor: float = 1.0,
        decoder_causal: bool = False,
        decoder_is_dw_conv: Tuple[bool, ...] = (False, False, False, False, False),
        decoder_dw_kernel_size: int = 3,
        decoder_spatio_only: Tuple[bool, ...] = (False, False, False, False),
        first_chunk_size: int = 3,
        step_size: int = 5,
        spatial_compression_ratio: int = 16,
        temporal_compression_ratio: int = 4,
        use_unpatchify: bool = False,
        # below are for training, keep for compatibility
        aligned_feature_projection_mode: Optional[str] = None,
        aligned_feature_projection_dim: Optional[List[Tuple[int, int]]] = None,
        aligned_blks_indices: Optional[List[int]] = None,
    ):
        super().__init__()

        self.decoder = TurboVAEDDecoder3d(
            in_channels=latent_channels,
            out_channels=out_channels,
            block_out_channels=decoder_block_out_channels,
            spatio_temporal_scaling=decoder_spatio_temporal_scaling,
            layers_per_block=decoder_layers_per_block,
            patch_size=patch_size,
            patch_size_t=patch_size_t,
            resnet_norm_eps=resnet_norm_eps,
            is_causal=decoder_causal,
            decoder_is_dw_conv=decoder_is_dw_conv,
            decoder_dw_kernel_size=decoder_dw_kernel_size,
            spatio_only=decoder_spatio_only,
            use_unpatchify=use_unpatchify,
        )

        # Sliding-window decode parameters (in latent frames).
        self.first_chunk_size = first_chunk_size
        self.step_size = step_size

        self.spatial_compression_ratio = spatial_compression_ratio
        self.temporal_compression_ratio = temporal_compression_ratio

        # Fixed per-channel latent statistics used for denormalization.
        # NOTE(review): `device="cuda"` is hard-coded — this fails on CPU-only
        # hosts, and these are plain tensors (not registered buffers), so
        # `.to()` / `state_dict()` will not track them. Confirm intended.
        # NOTE(review): the vectors below have 48 entries while
        # `latent_channels` defaults to 128; `_sliding_window_decode` views
        # them as (1, z_dim, 1, 1, 1), so the config must set
        # latent_channels to match — verify against shipped configs.
        self.z_dim = latent_channels
        self.mean = torch.tensor(
            [
                -0.2289,
                -0.0052,
                -0.1323,
                -0.2339,
                -0.2799,
                0.0174,
                0.1838,
                0.1557,
                -0.1382,
                0.0542,
                0.2813,
                0.0891,
                0.1570,
                -0.0098,
                0.0375,
                -0.1825,
                -0.2246,
                -0.1207,
                -0.0698,
                0.5109,
                0.2665,
                -0.2108,
                -0.2158,
                0.2502,
                -0.2055,
                -0.0322,
                0.1109,
                0.1567,
                -0.0729,
                0.0899,
                -0.2799,
                -0.1230,
                -0.0313,
                -0.1649,
                0.0117,
                0.0723,
                -0.2839,
                -0.2083,
                -0.0520,
                0.3748,
                0.0152,
                0.1957,
                0.1433,
                -0.2944,
                0.3573,
                -0.0548,
                -0.1681,
                -0.0667,
            ],
            dtype=torch.float32,
            device="cuda",
        )
        self.std = torch.tensor(
            [
                0.4765,
                1.0364,
                0.4514,
                1.1677,
                0.5313,
                0.4990,
                0.4818,
                0.5013,
                0.8158,
                1.0344,
                0.5894,
                1.0901,
                0.6885,
                0.6165,
                0.8454,
                0.4978,
                0.5759,
                0.3523,
                0.7135,
                0.6804,
                0.5833,
                1.4146,
                0.8986,
                0.5659,
                0.7069,
                0.5338,
                0.4889,
                0.4917,
                0.4069,
                0.4999,
                0.6866,
                0.4093,
                0.5709,
                0.6065,
                0.6415,
                0.4944,
                0.5726,
                1.2042,
                0.5458,
                1.6887,
                0.3971,
                1.0600,
                0.3943,
                0.5537,
                0.5444,
                0.4089,
                0.7468,
                0.7744,
            ],
            dtype=torch.float32,
            device="cuda",
        )
        # scale = [mean, 1/std]; consumed by `_sliding_window_decode`.
        self.scale = [self.mean, 1.0 / self.std]

    def _sliding_window_decode(self, z, output_offload=False):
        """Decode latents `z` (B, C, T_lat, H_lat, W_lat) in overlapping
        temporal chunks; with `output_offload=True`, per-chunk outputs are
        staged on CPU to bound peak GPU memory, then moved back at the end.
        """
        z_dtype = z.dtype
        z_device = z.device
        scale = self.scale
        assert isinstance(scale[0], torch.Tensor), "scale[0] must be a tensor"
        # Denormalize: z / (1/std) + mean == z * std + mean.
        z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
        z = z.to(z_dtype)

        first_chunk_size = self.first_chunk_size
        step = self.step_size

        # Context mapping: 1 latent frame context -> temporal_compression_ratio pixel frames overlap
        num_overlap_pixel_frames = 1 * self.temporal_compression_ratio

        _, _, num_frames, _, _ = z.shape

        # 1. Pad frames to satisfy chunking requirements
        # The total number of frames must follow the formula:
        # num_frames = first_chunk_size + n * step_size
        num_padding_frames = 0

        if num_frames < first_chunk_size:
            # if input is shorter than first_chunk_size
            num_padding_frames = first_chunk_size - num_frames
        elif (num_frames - first_chunk_size) % step != 0:
            num_padding_frames = step - (num_frames - first_chunk_size) % step

        if num_padding_frames > 0:
            # Pad by repeating the last latent frame.
            z = torch.cat([z, z[:, :, -1:].repeat(1, 1, num_padding_frames, 1, 1)], dim=2)
            num_frames = num_frames + num_padding_frames

        # 2. Decode with overlapping windows
        # Collect chunks on CPU to avoid GPU OOM for high resolution (e.g., 1080P) when output_offload=True
        out_chunks = []

        if num_frames == first_chunk_size:
            # if only one chunk, decode directly
            out = self.decoder(z, is_first_chunk=True)
            out_chunks.append(out.cpu() if output_offload else out)
            del out
        else:
            # first chunk: attach the right frame
            out = self.decoder(z[:, :, 0 : first_chunk_size + 1, :, :], is_first_chunk=True)
            out = out[:, :, :-num_overlap_pixel_frames]
            out_chunks.append(out.cpu() if output_offload else out)
            del out

            # middle chunk: attach the left and right frames
            # last chunk: attach the left frame
            for i in range(first_chunk_size, num_frames, step):
                is_last_chunk = i + step == num_frames
                left = i - 1
                right = i + step + 1 if not is_last_chunk else i + step

                assert left >= 0 and right <= num_frames, f"left: {left}, right: {right}, num_frames: {num_frames}"

                out_ = self.decoder(z[:, :, left:right, :, :], is_first_chunk=False)

                # Trim the pixel frames decoded from context-only latent frames.
                if is_last_chunk:
                    out_ = out_[:, :, num_overlap_pixel_frames:]
                else:
                    out_ = out_[:, :, num_overlap_pixel_frames:-num_overlap_pixel_frames]

                out_chunks.append(out_.cpu() if output_offload else out_)
                del out_

        # Concatenate chunks (on CPU if output_offload, otherwise on GPU)
        out = torch.cat(out_chunks, dim=2)
        del out_chunks

        # 3. Remove padded frames
        if num_padding_frames > 0:
            out = out[:, :, : -num_padding_frames * self.temporal_compression_ratio]

        return out.to(z_device) if output_offload else out

    def decode(self, z: torch.Tensor, output_offload: bool = False):
        """Public decode entry point; see `_sliding_window_decode`."""
        return self._sliding_window_decode(z, output_offload=output_offload)
inference/model/vae2_2/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .vae2_2_model import Wan2_2_VAE, get_vae2_2
2
+
3
+ __all__ = ["Wan2_2_VAE", "get_vae2_2"]
inference/model/vae2_2/vae2_2_model.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+
3
+ import torch
4
+
5
+ from .vae2_2_module import Wan2_2_VAE
6
+
7
+
8
+ def get_vae2_2(model_path, device="cuda", weight_dtype=torch.float32) -> Wan2_2_VAE:
9
+ vae = Wan2_2_VAE(vae_pth=model_path).to(device).to(weight_dtype)
10
+ vae.vae.requires_grad_(False)
11
+ vae.vae.eval()
12
+ gc.collect()
13
+ torch.cuda.empty_cache()
14
+ return vae
15
+
16
+
17
+ __all__ = ["Wan2_2_VAE", "get_vae2_2"]
inference/model/vae2_2/vae2_2_module.py ADDED
@@ -0,0 +1,1086 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Copyright 2024-2026 The Alibaba Wan Team Authors. All rights reserved.
16
+ import logging
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from einops import rearrange
22
+
23
+
24
+ __all__ = ["Wan2_2_VAE"]
25
+
26
+ CACHE_T = 2
27
+
28
+
29
class ScatterFwdAllGatherBackwardOverlap(torch.autograd.Function):
    """Autograd op: scatter along W in forward, all-gather gradients in backward.

    Forward: each rank keeps only its width chunk of the input, plus an
    `overlap_size` halo on each side. Backward: per-rank gradient chunks are
    all-gathered and accumulated into a full-width gradient, summing
    contributions wherever halo regions overlap.
    """

    @staticmethod
    def forward(ctx, x, group, overlap_size):
        """
        Forward pass: split input tensor along W; each rank processes its local
        chunk including overlap regions.

        Args:
            x: Input tensor, shape [B, C, T, H, W]
            group: Distributed communication group
            overlap_size: Width of overlap region
        """
        W = x.shape[4]
        world_size = torch.distributed.get_world_size(group)
        rank = torch.distributed.get_rank(group)

        # Compute base chunk size (ceil division so every column is covered)
        base_chunk_size = (W + world_size - 1) // world_size

        # Compute chunk range for current rank
        chunk_start = rank * base_chunk_size
        chunk_end = min((rank + 1) * base_chunk_size, W)

        # Extend range with overlap (clamped at tensor edges)
        overlap_start = max(0, chunk_start - overlap_size)
        overlap_end = min(W, chunk_end + overlap_size)

        # Slice local chunk
        x_chunk = x[:, :, :, :, overlap_start:overlap_end].contiguous()

        # Save metadata needed by backward
        ctx.save_for_backward(torch.tensor([overlap_start, overlap_end, W], dtype=torch.long, device=x.device))
        ctx.group = group
        ctx.overlap_size = overlap_size
        ctx.world_size = world_size
        ctx.rank = rank
        ctx.base_chunk_size = base_chunk_size
        return x_chunk

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass: all-gather gradients from all ranks and trim overlap.
        """
        # Restore saved forward metadata
        overlap_start, overlap_end, W = ctx.saved_tensors[0]
        overlap_start = overlap_start.item()
        overlap_end = overlap_end.item()
        W = W.item()

        group = ctx.group
        overlap_size = ctx.overlap_size
        world_size = ctx.world_size
        # FIX: removed a bare no-op `ctx.rank` expression statement here
        # (attribute access with no assignment — dead code).
        base_chunk_size = ctx.base_chunk_size

        # Collect gradients from all ranks via all_gather. Chunk widths differ
        # per rank (edge ranks have one-sided halos), so compute each shape.
        grad_output = grad_output.contiguous()
        B, C, T, H = grad_output.shape[:4]
        grad_shapes = []
        for r in range(world_size):
            r_chunk_start = r * base_chunk_size
            r_chunk_end = min((r + 1) * base_chunk_size, W)

            r_overlap_start = max(0, r_chunk_start - overlap_size)
            r_overlap_end = min(W, r_chunk_end + overlap_size)

            # Compute gradient shape for each rank
            chunk_width = r_overlap_end - r_overlap_start
            grad_shapes.append((B, C, T, H, chunk_width))
        grad_chunks = [
            torch.zeros(grad_shape, device=grad_output.device, dtype=grad_output.dtype) for grad_shape in grad_shapes
        ]
        torch.distributed.all_gather(grad_chunks, grad_output, group=group)

        # Stitch gathered chunks into full gradient tensor
        full_grad = torch.zeros(B, C, T, H, W, device=grad_output.device, dtype=grad_output.dtype)

        # Place each rank's gradient chunk at the correct position,
        # accumulating (+=) where halo regions of neighbors overlap.
        for r in range(world_size):
            r_chunk_start = r * base_chunk_size
            r_chunk_end = min((r + 1) * base_chunk_size, W)

            r_overlap_start = max(0, r_chunk_start - overlap_size)
            r_overlap_end = min(W, r_chunk_end + overlap_size)

            # Position in full gradient
            grad_start_in_full = r_overlap_start
            grad_end_in_full = r_overlap_end

            # Position inside gathered chunk
            grad_start_in_chunk = 0
            grad_end_in_chunk = r_overlap_end - r_overlap_start

            # Handle left boundary for first rank
            if r == 0:
                grad_start_in_chunk = 0
                grad_end_in_chunk = min(r_chunk_end + overlap_size, W) - r_overlap_start
            # Handle right boundary for last rank
            elif r == world_size - 1:
                grad_start_in_chunk = max(0, r_chunk_start - overlap_size) - r_overlap_start
                grad_end_in_chunk = r_overlap_end - r_overlap_start

            # Accumulate into full gradient
            full_grad[:, :, :, :, grad_start_in_full:grad_end_in_full] += grad_chunks[r][
                :, :, :, :, grad_start_in_chunk:grad_end_in_chunk
            ]

        # No gradients for the `group` and `overlap_size` inputs.
        return full_grad, None, None
+ return full_grad, None, None
138
+
139
+
140
+ def scatter_fwd_all_gather_backward_with_overlap(x, group, overlap_size=0):
141
+ return ScatterFwdAllGatherBackwardOverlap.apply(x, group, overlap_size)
142
+
143
+
144
class AllGatherFwdScatterBackwardOverlap(torch.autograd.Function):
    """All-gather in forward / scatter in backward, for width-sharded tensors
    that carry an `overlap_size`-wide halo on the last dimension.

    Forward strips each rank's halo and concatenates the valid chunks into the
    full tensor (identical on every rank); backward hands each rank only the
    gradient slice for its own valid region, zero-padded back to the halo'd
    local shape.
    """

    @staticmethod
    def forward(ctx, x, group, overlap_size):
        """
        Forward pass: each rank clips local input, then all-gathers clipped chunks.

        Args:
            x: Input tensor, shape [B, C, T, H, W], already local overlapped chunk per rank
            group: Distributed communication group
            overlap_size: Width of overlap region
        """
        world_size = torch.distributed.get_world_size(group)
        rank = torch.distributed.get_rank(group)

        # Clip local input first (remove overlap area).
        # Boundary ranks only carry a halo on their interior side.
        if rank == 0:
            valid_start = 0
            valid_end = x.shape[-1] - overlap_size
        elif rank == world_size - 1:
            valid_start = overlap_size
            valid_end = x.shape[-1]
        else:
            valid_start = overlap_size
            valid_end = x.shape[-1] - overlap_size

        x_clipped = x[..., valid_start:valid_end].contiguous()
        clipped_width = x_clipped.shape[-1]

        # First all_gather: collect clipped widths across ranks (chunks may be uneven)
        width_tensor = torch.tensor([clipped_width], dtype=torch.long, device=x.device)
        all_widths = [torch.zeros_like(width_tensor) for _ in range(world_size)]
        torch.distributed.all_gather(all_widths, width_tensor, group=group)
        clipped_widths = [w.item() for w in all_widths]

        # Second all_gather: collect clipped data across ranks
        B, C, T, H = x_clipped.shape[:4]
        x_clipped_chunks = [torch.zeros(B, C, T, H, w, device=x.device, dtype=x.dtype) for w in clipped_widths]
        torch.distributed.all_gather(x_clipped_chunks, x_clipped, group=group)
        full_x = torch.cat(x_clipped_chunks, dim=-1)

        # Save metadata needed by backward
        ctx.save_for_backward(torch.tensor([valid_start, valid_end], dtype=torch.long, device=x.device))
        ctx.clipped_widths = clipped_widths
        ctx.group = group
        ctx.overlap_size = overlap_size
        ctx.world_size = world_size
        ctx.rank = rank

        return full_x

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass: each rank restores gradients for its own partition only.
        """
        # Restore saved forward metadata. (valid_start/valid_end are kept for
        # symmetry with forward; the rank offset is recomputed from widths.)
        valid_start, valid_end = ctx.saved_tensors[0]
        valid_start = valid_start.item()
        valid_end = valid_end.item()

        clipped_widths = ctx.clipped_widths
        overlap_size = ctx.overlap_size
        world_size = ctx.world_size
        rank = ctx.rank

        # Compute current rank offset in full gradient
        start_pos = sum(clipped_widths[:rank])
        end_pos = start_pos + clipped_widths[rank]

        # Extract only current rank gradient slice
        grad_clipped = grad_output[:, :, :, :, start_pos:end_pos]

        # Pad zeros to recover overlap area for current rank
        if rank == 0:
            # First rank: pad right
            grad_full = F.pad(grad_clipped, (0, overlap_size))
        elif rank == world_size - 1:
            # Last rank: pad left
            grad_full = F.pad(grad_clipped, (overlap_size, 0))
        else:
            # Middle rank: pad both sides
            grad_full = F.pad(grad_clipped, (overlap_size, overlap_size))

        return grad_full, None, None
229
+
230
+
231
def all_gather_fwd_scatter_backward_with_overlap(x, group, overlap_size=0):
    """Reassemble the full tensor from halo'd width shards in forward; in
    backward, give each rank only the gradient of its own partition, zero-padded
    back to the overlapped local width. Wrapper over the autograd.Function."""
    return AllGatherFwdScatterBackwardOverlap.apply(x, group, overlap_size)
233
+
234
+
235
def one_plus_world_size(group):
    """Return True iff `group` is a real process group spanning more than one rank.

    A ``None`` group short-circuits to False, so this is safe to call before
    (or without) distributed initialization.
    """
    if group is None:
        return False
    return torch.distributed.get_world_size(group) > 1
237
+
238
+
239
class CausalConv3d(nn.Conv3d):
    """
    Causal 3D convolution.

    Temporal padding is applied only on the left (past) side so frame t never
    sees frames > t; spatial padding stays symmetric. Width-parallel execution
    is supported by exchanging halo columns through the process `group`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # F.pad order is (W_left, W_right, H_top, H_bottom, T_front, T_back);
        # all temporal padding (2 * pad_t) goes to the front for causality.
        self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
        # Disable the built-in symmetric padding; we pad manually in forward.
        self.padding = (0, 0, 0)

    @torch.compile
    def forward(self, x, cache_x=None, group: torch.distributed.ProcessGroup = None):
        # x: [B, C, T, H, W]. cache_x carries trailing frames from the previous
        # chunk so chunked (streaming) processing matches a full-sequence pass.
        padding = list(self._padding)
        if cache_x is not None and self._padding[4] > 0:
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            # Cached frames stand in for (part of) the causal front padding.
            padding[4] -= cache_x.shape[2]
        if one_plus_world_size(group):
            # Halo width so every output column sees its full receptive field.
            overlap_size = self.kernel_size[-1] // 2 * self.stride[-1]
            x = scatter_fwd_all_gather_backward_with_overlap(x, group, overlap_size=overlap_size)
        x = F.pad(x, padding)
        x = super().forward(x)
        if one_plus_world_size(group):
            x = all_gather_fwd_scatter_backward_with_overlap(x, group, overlap_size=overlap_size)
        return x
264
+
265
+
266
class RMS_norm(nn.Module):
    """RMS normalization with a learned per-channel gain (and optional bias).

    Normalizes along the channel dim (dim 1 if `channel_first`, else the last
    dim) and rescales by sqrt(dim) so unit-RMS inputs keep unit magnitude.
    """

    def __init__(self, dim, channel_first=True, images=True, bias=False):
        super().__init__()
        # Trailing singleton dims let gamma/bias broadcast over the (T,)H,W axes.
        broadcastable_dims = (1, 1, 1) if not images else (1, 1)
        shape = (dim, *broadcastable_dims) if channel_first else (dim,)

        self.channel_first = channel_first
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(shape))
        # Plain 0.0 keeps the bias-free path allocation-free.
        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0

    @torch.compile
    def forward(self, x):
        # F.normalize yields unit L2 norm; multiplying by sqrt(dim) turns that
        # into unit RMS before applying the learned gain/bias.
        return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
280
+
281
+
282
class Upsample(nn.Upsample):
    @torch.compile
    def forward(self, x):
        """
        Fix bfloat16 support for nearest neighbor interpolation.
        """
        # Interpolate in fp32, then cast back to the input dtype.
        return super().forward(x.float()).type_as(x)
289
+
290
+
291
class Resample(nn.Module):
    """Spatial (and optionally temporal) up/down-sampling block.

    `mode` selects the behavior; the "*3d" modes add a CausalConv3d `time_conv`
    that doubles/halves the temporal length, while spatial resampling is done
    frame-by-frame with 2D layers. feat_cache/feat_idx support streaming
    (chunked) temporal processing; feat_idx is a one-element list used as a
    mutable cursor into feat_cache.
    """

    def __init__(self, dim, mode):
        assert mode in ("none", "upsample2d", "upsample3d", "downsample2d", "downsample3d")
        super().__init__()
        self.dim = dim
        self.mode = mode

        # layers
        if mode == "upsample2d":
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim, 3, padding=1)
            )
        elif mode == "upsample3d":
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
                nn.Conv2d(dim, dim, 3, padding=1),
                # nn.Conv2d(dim, dim//2, 3, padding=1)
            )
            # Doubles channels; forward later splits them into two time steps.
            self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
        elif mode == "downsample2d":
            self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
        elif mode == "downsample3d":
            self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
            self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
        else:
            self.resample = nn.Identity()

    @torch.compile
    def forward(self, x, feat_cache=None, feat_idx=[0], group: torch.distributed.ProcessGroup = None):
        if one_plus_world_size(group):
            # Halo sizes sized for the 3x3 convs inside `resample`.
            if self.mode in ["upsample3d", "upsample2d"]:
                overlap_size = 1
            elif self.mode in ["downsample3d", "downsample2d"]:
                overlap_size = 2
            else:
                overlap_size = 0
            x = scatter_fwd_all_gather_backward_with_overlap(x, group, overlap_size=overlap_size)

        b, c, t, h, w = x.size()
        if self.mode == "upsample3d":
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    # First chunk: mark the slot with the "Rep" sentinel.
                    feat_cache[idx] = "Rep"
                    feat_idx[0] += 1
                else:
                    cache_x = x[:, :, -CACHE_T:, :, :].clone()
                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
                        # cache last frame of last two chunk
                        cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
                        cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
                    if feat_cache[idx] == "Rep":
                        x = self.time_conv(x)
                    else:
                        x = self.time_conv(x, feat_cache[idx])
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1
                    # Split the doubled channels into two interleaved time steps.
                    x = x.reshape(b, 2, c, t, h, w)
                    x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
                    x = x.reshape(b, c, t * 2, h, w)
        t = x.shape[2]
        # Spatial resampling is per-frame: fold time into the batch dim.
        x = rearrange(x, "b c t h w -> (b t) c h w")
        x = self.resample(x)
        x = rearrange(x, "(b t) c h w -> b c t h w", t=t)

        if self.mode == "downsample3d":
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    feat_cache[idx] = x.clone()
                    feat_idx[0] += 1
                else:
                    cache_x = x[:, :, -1:, :, :].clone()
                    x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1

        if one_plus_world_size(group):
            # Spatial resampling changed W, so rescale the halo width to match.
            if self.mode in ["upsample3d", "upsample2d"]:
                overlap_size = overlap_size * 2
            elif self.mode in ["downsample3d", "downsample2d"]:
                overlap_size = overlap_size // 2
            else:
                overlap_size = overlap_size
            x = all_gather_fwd_scatter_backward_with_overlap(x, group, overlap_size=overlap_size)
        return x

    def init_weight(self, conv):
        # Initialize a temporal conv as (near-)identity at the center tap.
        conv_weight = conv.weight.detach().clone()
        nn.init.zeros_(conv_weight)
        c1, c2, t, h, w = conv_weight.size()
        one_matrix = torch.eye(c1, c2)
        init_matrix = one_matrix
        nn.init.zeros_(conv_weight)
        conv_weight.data[:, :, 1, 0, 0] = init_matrix  # * 0.5
        conv.weight = nn.Parameter(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def init_weight2(self, conv):
        # Identity init for a conv whose output duplicates the input channels twice.
        conv_weight = conv.weight.data.detach().clone()
        nn.init.zeros_(conv_weight)
        c1, c2, t, h, w = conv_weight.size()
        init_matrix = torch.eye(c1 // 2, c2)
        conv_weight[: c1 // 2, :, -1, 0, 0] = init_matrix
        conv_weight[c1 // 2 :, :, -1, 0, 0] = init_matrix
        conv.weight = nn.Parameter(conv_weight)
        nn.init.zeros_(conv.bias.data)
399
+
400
+
401
class ResidualBlock(nn.Module):
    """Pre-norm residual block: RMS_norm -> SiLU -> conv -> RMS_norm -> SiLU ->
    dropout -> conv, added to a (possibly 1x1x1-projected) shortcut."""

    def __init__(self, in_dim, out_dim, dropout=0.0):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # layers
        self.residual = nn.Sequential(
            RMS_norm(in_dim, images=False),
            nn.SiLU(),
            CausalConv3d(in_dim, out_dim, 3, padding=1),
            RMS_norm(out_dim, images=False),
            nn.SiLU(),
            nn.Dropout(dropout),
            CausalConv3d(out_dim, out_dim, 3, padding=1),
        )
        # 1x1x1 projection only when the channel count changes.
        self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()

    @torch.compile
    def forward(self, x, feat_cache=None, feat_idx=[0], group: torch.distributed.ProcessGroup = None):
        if one_plus_world_size(group):
            # Halo of 2 covers the two stacked 3x3x3 convs in `residual`.
            overlap_size = 2
            x = scatter_fwd_all_gather_backward_with_overlap(x, group, overlap_size=overlap_size)
        h = self.shortcut(x)
        for layer in self.residual:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache last frame of last two chunk
                    cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        x = x + h
        if one_plus_world_size(group):
            x = all_gather_fwd_scatter_backward_with_overlap(x, group, overlap_size=overlap_size)
        return x
442
+
443
+
444
class AttentionBlock(nn.Module):
    """
    Causal self-attention with a single head.

    Attention runs per frame over the H*W spatial positions; the output
    projection is zero-initialized so the block starts as an identity residual.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        # layers
        self.norm = RMS_norm(dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        self.proj = nn.Conv2d(dim, dim, 1)

        # zero out the last layer params
        nn.init.zeros_(self.proj.weight)

    @torch.compile
    def forward(self, x):
        identity = x
        b, c, t, h, w = x.size()
        # Fold time into the batch: attention is within each frame only.
        x = rearrange(x, "b c t h w -> (b t) c h w")
        x = self.norm(x)
        # compute query, key, value — each shaped (b*t, 1, h*w, c)
        q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous().chunk(3, dim=-1)

        # apply attention
        x = F.scaled_dot_product_attention(q, k, v)
        x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)

        # output
        x = self.proj(x)
        x = rearrange(x, "(b t) c h w-> b c t h w", t=t)
        x = x + identity
        return x
479
+
480
+
481
def patchify(x, patch_size):
    """Fold `patch_size` x `patch_size` spatial patches into the channel dim.

    Accepts 4D (b c h w) or 5D (b c f h w) tensors; `patch_size == 1` is a
    no-op. Raises ValueError for any other rank.
    """
    if patch_size == 1:
        return x

    ndim = x.dim()
    if ndim == 4:
        return rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
    if ndim == 5:
        return rearrange(x, "b c f (h q) (w r) -> b (c r q) f h w", q=patch_size, r=patch_size)
    raise ValueError(f"Invalid input shape: {x.shape}")
492
+
493
+
494
def unpatchify(x, patch_size):
    """Inverse of `patchify`: unfold channel groups back into `patch_size` x
    `patch_size` spatial patches.

    Accepts 4D (b c h w) or 5D (b c f h w) tensors; `patch_size == 1` is a
    no-op. Raises ValueError for any other rank (previously the input was
    silently returned unchanged, inconsistent with `patchify`).
    """
    if patch_size == 1:
        return x

    if x.dim() == 4:
        x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
    elif x.dim() == 5:
        x = rearrange(x, "b (c r q) f h w -> b c f (h q) (w r)", q=patch_size, r=patch_size)
    else:
        # Mirror patchify: fail loudly instead of silently passing data through.
        raise ValueError(f"Invalid input shape: {x.shape}")
    return x
503
+
504
+
505
class AvgDown3D(nn.Module):
    """Parameter-free downsample: pixel-unshuffle factor_t x factor_s x factor_s
    blocks into channels, then mean over fixed-size channel groups to reach
    `out_channels`."""

    def __init__(self, in_channels, out_channels, factor_t, factor_s=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.factor_t = factor_t
        self.factor_s = factor_s
        self.factor = self.factor_t * self.factor_s * self.factor_s

        assert in_channels * self.factor % out_channels == 0
        # How many unshuffled channels are averaged into each output channel.
        self.group_size = in_channels * self.factor // out_channels

    @torch.compile
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Front-pad T so it divides factor_t (padding the past keeps causality).
        pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
        pad = (0, 0, 0, 0, pad_t, 0)
        x = F.pad(x, pad)
        B, C, T, H, W = x.shape
        x = x.view(
            B, C, T // self.factor_t, self.factor_t, H // self.factor_s, self.factor_s, W // self.factor_s, self.factor_s
        )
        # Move the block factors next to the channel dim, then flatten them in.
        x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
        x = x.view(B, C * self.factor, T // self.factor_t, H // self.factor_s, W // self.factor_s)
        x = x.view(B, self.out_channels, self.group_size, T // self.factor_t, H // self.factor_s, W // self.factor_s)
        x = x.mean(dim=2)
        return x
531
+
532
+
533
class DupUp3D(nn.Module):
    """Parameter-free upsample: duplicate channels then pixel-shuffle them into
    a factor_t x factor_s x factor_s larger output (inverse layout of AvgDown3D)."""

    def __init__(self, in_channels: int, out_channels: int, factor_t, factor_s=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.factor_t = factor_t
        self.factor_s = factor_s
        self.factor = self.factor_t * self.factor_s * self.factor_s

        assert out_channels * self.factor % in_channels == 0
        # How many copies of each input channel are needed to fill the output.
        self.repeats = out_channels * self.factor // in_channels

    @torch.compile
    def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
        x = x.repeat_interleave(self.repeats, dim=1)
        x = x.view(x.size(0), self.out_channels, self.factor_t, self.factor_s, self.factor_s, x.size(2), x.size(3), x.size(4))
        # Interleave the factors with T/H/W, then flatten to the upsampled grid.
        x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
        x = x.view(
            x.size(0), self.out_channels, x.size(2) * self.factor_t, x.size(4) * self.factor_s, x.size(6) * self.factor_s
        )
        if first_chunk:
            # Drop the duplicated leading frames that precede the true first frame.
            x = x[:, :, self.factor_t - 1 :, :, :]
        return x
557
+
558
+
559
class Down_ResidualBlock(nn.Module):
    """Stack of residual blocks plus optional downsample, with a parameter-free
    AvgDown3D shortcut so the skip path matches the output resolution."""

    def __init__(self, in_dim, out_dim, dropout, mult, temperal_downsample=False, down_flag=False):
        super().__init__()

        # Shortcut path with downsample
        self.avg_shortcut = AvgDown3D(
            in_dim, out_dim, factor_t=2 if temperal_downsample else 1, factor_s=2 if down_flag else 1
        )

        # Main path with residual blocks and downsample
        downsamples = []
        for _ in range(mult):
            downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
            in_dim = out_dim

        # Add the final downsample block
        if down_flag:
            mode = "downsample3d" if temperal_downsample else "downsample2d"
            downsamples.append(Resample(out_dim, mode=mode))

        self.downsamples = nn.Sequential(*downsamples)

    @torch.compile
    def forward(self, x, feat_cache=None, feat_idx=[0]):
        # Keep the untouched input for the average-pooled shortcut.
        x_copy = x.clone()
        for module in self.downsamples:
            x = module(x, feat_cache, feat_idx)

        return x + self.avg_shortcut(x_copy)
588
+
589
+
590
class Up_ResidualBlock(nn.Module):
    """Stack of residual blocks plus optional upsample, with a parameter-free
    DupUp3D shortcut when upsampling (no skip path otherwise)."""

    def __init__(self, in_dim, out_dim, dropout, mult, temperal_upsample=False, up_flag=False):
        super().__init__()
        # Shortcut path with upsample
        if up_flag:
            self.avg_shortcut = DupUp3D(in_dim, out_dim, factor_t=2 if temperal_upsample else 1, factor_s=2 if up_flag else 1)
        else:
            self.avg_shortcut = None

        # Main path with residual blocks and upsample
        upsamples = []
        for _ in range(mult):
            upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
            in_dim = out_dim

        # Add the final upsample block
        if up_flag:
            mode = "upsample3d" if temperal_upsample else "upsample2d"
            upsamples.append(Resample(out_dim, mode=mode))

        self.upsamples = nn.Sequential(*upsamples)

    @torch.compile
    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False, group: torch.distributed.ProcessGroup = None):
        x_main = x.clone()
        for module in self.upsamples:
            x_main = module(x_main, feat_cache, feat_idx, group=group)
        if self.avg_shortcut is not None:
            # first_chunk lets DupUp3D trim the duplicated leading frames.
            x_shortcut = self.avg_shortcut(x, first_chunk)
            return x_main + x_shortcut
        else:
            return x_main
622
+
623
+
624
class Encoder3d(nn.Module):
    """Video encoder: init conv -> down-residual stages -> middle
    (res + attention + res) -> head producing a z_dim-channel latent.
    Supports chunked (streaming) encoding via feat_cache/feat_idx."""

    def __init__(
        self,
        dim=128,
        z_dim=4,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_downsample=[True, True, False],
        dropout=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample

        # dimensions
        dims = [dim * u for u in [1] + dim_mult]
        scale = 1.0

        # init block — 12 input channels (presumably 3 x 2x2 spatial patchify
        # applied by the caller; see WanVAE_.encode — TODO confirm)
        self.conv1 = CausalConv3d(12, dims[0], 3, padding=1)

        # downsample blocks
        downsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            t_down_flag = temperal_downsample[i] if i < len(temperal_downsample) else False
            downsamples.append(
                Down_ResidualBlock(
                    in_dim=in_dim,
                    out_dim=out_dim,
                    dropout=dropout,
                    mult=num_res_blocks,
                    temperal_downsample=t_down_flag,
                    down_flag=i != len(dim_mult) - 1,
                )
            )
            scale /= 2.0
        self.downsamples = nn.Sequential(*downsamples)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), ResidualBlock(out_dim, out_dim, dropout)
        )

        # # output blocks
        self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, z_dim, 3, padding=1))

    @torch.compile
    def forward(self, x, feat_cache=None, feat_idx=[0]):
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # Prepend the last cached frame so the temporal window spans chunks.
                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        # downsamples
        for layer in self.downsamples:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        # middle
        for layer in self.middle:
            if isinstance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        # head
        for layer in self.head:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)

        return x
716
+
717
+
718
class Decoder3d(nn.Module):
    """Video decoder: init conv -> middle (res + attention + res) ->
    up-residual stages -> head back to 12 channels (spatially unpatchified by
    the caller). Supports chunked decoding (feat_cache/feat_idx) and
    width-parallel execution via `group`."""

    def __init__(
        self,
        dim=128,
        z_dim=4,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_upsample=[False, True, True],
        dropout=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_upsample = temperal_upsample

        # dimensions (mirror of the encoder, widest first)
        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        # scale = 1.0 / 2 ** (len(dim_mult) - 2)
        # init block
        self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), ResidualBlock(dims[0], dims[0], dropout)
        )

        # upsample blocks
        upsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            t_up_flag = temperal_upsample[i] if i < len(temperal_upsample) else False
            upsamples.append(
                Up_ResidualBlock(
                    in_dim=in_dim,
                    out_dim=out_dim,
                    dropout=dropout,
                    mult=num_res_blocks + 1,
                    temperal_upsample=t_up_flag,
                    up_flag=i != len(dim_mult) - 1,
                )
            )
        self.upsamples = nn.Sequential(*upsamples)

        # output blocks
        self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, 12, 3, padding=1))

    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False, group: torch.distributed.ProcessGroup = None):
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # Prepend the last cached frame so the temporal window spans chunks.
                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
            x = self.conv1(x, feat_cache[idx], group=group)
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x, group=group)

        for layer in self.middle:
            if isinstance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx, group=group)
            else:
                x = layer(x)

        # upsamples
        for layer in self.upsamples:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx, first_chunk, group=group)
            else:
                x = layer(x, group=group)

        # head — wrapped in scatter/gather so the final 3x3 conv sees its halo
        if one_plus_world_size(group):
            overlap_size = self.head[2].kernel_size[-1] // 2 * self.head[2].stride[-1]
            x = scatter_fwd_all_gather_backward_with_overlap(x, group, overlap_size=overlap_size)
        for layer in self.head:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        if one_plus_world_size(group):
            x = all_gather_fwd_scatter_backward_with_overlap(x, group, overlap_size=overlap_size)
        return x
810
+
811
+
812
def count_conv3d(model):
    """Return the number of CausalConv3d modules in `model` (this is the size
    of the per-conv streaming feature cache)."""
    return sum(1 for m in model.modules() if isinstance(m, CausalConv3d))
818
+
819
+
820
class WanVAE_(nn.Module):
    """Wan video VAE core: encoder/decoder plus 1x1x1 latent convs. Encoding
    and decoding stream over time in chunks, using per-conv feature caches."""

    def __init__(
        self,
        dim=160,
        dec_dim=256,
        z_dim=16,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_downsample=[True, True, False],
        dropout=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample
        self.temperal_upsample = temperal_downsample[::-1]

        # modules — encoder emits 2*z_dim channels (mean + log-variance)
        self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout)
        self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
        self.conv2 = CausalConv3d(z_dim, z_dim, 1)
        self.decoder = Decoder3d(dec_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout)

    def forward(self, x, scale=[0, 1]):
        # Deterministic round trip: encode to the latent mean, then decode.
        mu = self.encode(x, scale)
        x_recon = self.decode(mu, scale)
        return x_recon, mu

    def encode(self, x, scale):
        """Encode video `x` to a normalized latent mean. `scale` is
        (mean, 1/std) per latent channel (tensors or scalars)."""
        self.clear_cache()
        x = patchify(x, patch_size=2)
        t = x.shape[2]
        # First chunk is 1 frame, then 4 frames per chunk.
        iter_ = 1 + (t - 1) // 4
        for i in range(iter_):
            self._enc_conv_idx = [0]
            if i == 0:
                out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
            else:
                out_ = self.encoder(
                    x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx
                )
                out = torch.cat([out, out_], 2)
        mu, log_var = self.conv1(out).chunk(2, dim=1)
        if isinstance(scale[0], torch.Tensor):
            mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
        else:
            mu = (mu - scale[0]) * scale[1]
        self.clear_cache()
        return mu

    def decode(self, z, scale, group: torch.distributed.ProcessGroup = None):
        """Decode latent `z` (undoing `scale`) one latent frame at a time;
        `group` enables width-parallel decoding."""
        self.clear_cache()
        if isinstance(scale[0], torch.Tensor):
            z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
        else:
            z = z / scale[1] + scale[0]
        iter_ = z.shape[2]
        x = self.conv2(z, group=group)
        for i in range(iter_):
            self._conv_idx = [0]
            if i == 0:
                out = self.decoder(
                    x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=True, group=group
                )
            else:
                out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, group=group)
                out = torch.cat([out, out_], 2)
        out = unpatchify(out, patch_size=2)
        self.clear_cache()
        return out

    def reparameterize(self, mu, log_var):
        # Standard VAE reparameterization trick.
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def sample(self, imgs, deterministic=False):
        # NOTE(review): `encode` requires a `scale` argument and returns only
        # `mu`, so this unpacking will fail at runtime — looks like stale code.
        mu, log_var = self.encode(imgs)
        if deterministic:
            return mu
        std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
        return mu + std * torch.randn_like(std)

    def clear_cache(self):
        """Reset the per-CausalConv3d streaming caches for encoder and decoder."""
        self._conv_num = count_conv3d(self.decoder)
        self._conv_idx = [0]
        self._feat_map = [None] * self._conv_num
        # cache encode
        self._enc_conv_num = count_conv3d(self.encoder)
        self._enc_conv_idx = [0]
        self._enc_feat_map = [None] * self._enc_conv_num
916
+
917
def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", **kwargs):
    """Build a WanVAE_ and load weights from `pretrained_path`.

    kwargs override the default architecture config. The model is created on
    the meta device and materialized via `load_state_dict(..., assign=True)`.
    """
    # params
    cfg = dict(
        dim=dim,
        z_dim=z_dim,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_downsample=[True, True, True],
        dropout=0.0,
    )
    cfg.update(**kwargs)

    # init model
    with torch.device("meta"):
        model = WanVAE_(**cfg)

    # load checkpoint
    logging.info(f"loading {pretrained_path}")
    # NOTE(review): torch.load without weights_only=True unpickles arbitrary
    # objects — only load trusted checkpoints.
    model.load_state_dict(torch.load(pretrained_path, map_location=device), assign=True)

    return model
939
+
940
+
941
class Wan2_2_VAE:
    """User-facing wrapper: builds the frozen WanVAE_ and applies per-channel
    latent normalization (mean, 1/std) around encode/decode."""

    def __init__(
        self,
        z_dim=48,
        c_dim=160,
        vae_pth=None,
        dim_mult=[1, 2, 4, 4],
        temperal_downsample=[False, True, True],
        dtype=torch.float,
        device="cuda",
    ):
        self.dtype = dtype
        self.device = device

        # Per-channel latent statistics (48 channels), presumably measured
        # offline over a training corpus — TODO confirm provenance.
        self.mean = torch.tensor(
            [
                -0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
                -0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
                -0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
                -0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
                -0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
                0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
            ],
            dtype=dtype,
            device=device,
        )
        self.std = torch.tensor(
            [
                0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
                0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
                0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
                0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
                0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
                0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744,
            ],
            dtype=dtype,
            device=device,
        )
        # scale = (mean, 1/std): the form WanVAE_.encode/decode expect.
        self.scale = [self.mean, 1.0 / self.std]

        # init model — frozen and in eval mode
        self.vae = (
            _video_vae(
                pretrained_path=vae_pth, z_dim=z_dim, dim=c_dim, dim_mult=dim_mult, temperal_downsample=temperal_downsample
            )
            .eval()
            .requires_grad_(False)
            .to(device)
        )

    def encode(self, video):
        """Encode a video tensor to normalized latents (float32)."""
        return self.vae.encode(video, self.scale).float()

    def to(self, *args, **kwargs):
        """Move statistics and model to a device/dtype; returns self for chaining."""
        self.mean = self.mean.to(*args, **kwargs)
        self.std = self.std.to(*args, **kwargs)
        self.scale = [self.mean, 1.0 / self.std]
        self.vae = self.vae.to(*args, **kwargs)
        return self

    def decode(self, z, group: torch.distributed.ProcessGroup = None):
        """Decode latents to video in [-1, 1] (float32); `group` enables
        width-parallel decoding."""
        return self.vae.decode(z, self.scale, group=group).float().clamp_(-1, 1)
inference/pipeline/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .pipeline import MagiPipeline
16
+
17
+ __all__ = [
18
+ # pipeline
19
+ "MagiPipeline",
20
+ ]
inference/pipeline/data_proxy.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from dataclasses import dataclass
17
+ from enum import IntEnum
18
+ from typing import Any, Literal, Optional, TYPE_CHECKING
19
+
20
+ import torch
21
+ from einops import rearrange
22
+ from inference.common import DataProxyConfig, Modality, VarlenHandler
23
+ from inference.model.dit.dit_module import FFAHandler
24
+ from torch.nn import functional as F
25
+ from unfoldNd import UnfoldNd
26
+
27
+ if TYPE_CHECKING:
28
+ from inference.pipeline.video_generate import EvalInput
29
+
30
+
31
def calc_local_qk_range(num_video_tokens, num_audio_and_txt_tokens, num_frames, frame_receptive_field):
    """Build (q_ranges, k_ranges) describing frame-local attention windows.

    Layout of the returned int32 CUDA tensors, one row per attention span:
      * rows 0..num_frames-1: queries of frame ``i`` attend to keys of frames
        ``i - frame_receptive_field .. i + frame_receptive_field`` (clamped to
        the video-token range);
      * one row letting all video tokens attend to the audio+text tokens;
      * one row letting the audio+text tokens attend to the full sequence.
    """
    tokens_per_frame = num_video_tokens // num_frames
    total_tokens = num_video_tokens + num_audio_and_txt_tokens

    # Vectorized per-frame windows instead of a Python loop + stack.
    frame_ids = torch.arange(num_frames)
    local_q_range = torch.stack(
        [frame_ids * tokens_per_frame, (frame_ids + 1) * tokens_per_frame], dim=1
    )
    local_k_range = torch.stack(
        [
            (frame_ids - frame_receptive_field) * tokens_per_frame,
            (frame_ids + frame_receptive_field + 1) * tokens_per_frame,
        ],
        dim=1,
    )
    # Clamp key windows into the valid video-token range.
    local_k_range = local_k_range.clamp_(min=0, max=num_video_tokens)

    video_q_range = torch.tensor([[0, num_video_tokens]])
    video_k_range = torch.tensor([[num_video_tokens, total_tokens]])

    at_q_ranges = torch.tensor([[num_video_tokens, total_tokens]])
    at_k_ranges = torch.tensor([[0, total_tokens]])

    q_ranges = torch.cat([local_q_range, video_q_range, at_q_ranges], dim=0)
    k_ranges = torch.cat([local_k_range, video_k_range, at_k_ranges], dim=0)
    q_ranges = q_ranges.to(torch.int32).to("cuda", non_blocking=True)
    k_ranges = k_ranges.to(torch.int32).to("cuda", non_blocking=True)

    return (q_ranges, k_ranges)
62
+
63
+
64
def calc_local_attn_ffa_handler(num_video_tokens, num_audio_and_txt_tokens, num_frames, frame_receptive_field):
    """Wrap the frame-local q/k ranges in an FFAHandler for the DiT attention.

    ``max_seqlen_q``/``max_seqlen_k`` are the full packed sequence length;
    ``attn_type_map`` is all zeros (one entry per range row).
    """
    total_tokens = num_video_tokens + num_audio_and_txt_tokens
    q_ranges, k_ranges = calc_local_qk_range(
        num_video_tokens, num_audio_and_txt_tokens, num_frames, frame_receptive_field
    )
    return FFAHandler(
        q_ranges=q_ranges,
        k_ranges=k_ranges,
        max_seqlen_q=total_tokens,
        max_seqlen_k=total_tokens,
        attn_type_map=torch.zeros([q_ranges.shape[0]], device="cuda", dtype=torch.int32),
        softmax_scale=None,
    )
80
+
81
+
82
def get_coords(
    shape: list[int],
    ref_feat_shape: list[int],
    offset_thw: tuple[int, int, int] = (0, 0, 0),
    device: torch.device = torch.device("cpu"),
    dtype: torch.dtype = torch.float32,
):
    """
    Generate feature-grid coordinates plus original/reference size metadata.

    Args:
        shape: [T, H, W] original feature-map shape.
        ref_feat_shape: [T_ref, H_ref, W_ref] reference feature-map shape.
        offset_thw: per-axis (t, h, w) offsets added to the coordinates.
            Default is an immutable tuple (the previous mutable-list default
            would be shared across calls).
        device: device for the coordinate tensors.
        dtype: dtype for the coordinate tensors.

    Returns:
        coords: tensor of shape (T*H*W, 9); each row is
            (t, h, w, T, H, W, ref_T, ref_H, ref_W).
    """
    ori_t, ori_h, ori_w = shape
    ref_t, ref_h, ref_w = ref_feat_shape
    offset_t, offset_h, offset_w = offset_thw

    # Per-axis index ranges, shifted by the requested offsets.
    time_rng = torch.arange(ori_t, device=device, dtype=dtype) + offset_t
    height_rng = torch.arange(ori_h, device=device, dtype=dtype) + offset_h
    width_rng = torch.arange(ori_w, device=device, dtype=dtype) + offset_w

    # 3D grid over (T, H, W); "ij" keeps time as the slowest-varying axis.
    time_grid, height_grid, width_grid = torch.meshgrid(time_rng, height_rng, width_rng, indexing="ij")

    # Stack into (T, H, W, 3) and flatten to one row per grid position.
    coords_grid = torch.stack([time_grid, height_grid, width_grid], dim=-1)
    coords_flat = coords_grid.reshape(-1, 3)

    # Broadcast the (original, reference) shape metadata onto every row.
    meta = torch.tensor([ori_t, ori_h, ori_w, ref_t, ref_h, ref_w], device=device, dtype=dtype)
    meta_expanded = meta.expand(coords_flat.size(0), -1)

    return torch.cat([coords_flat, meta_expanded], dim=-1)
120
+
121
+
122
@dataclass
class SingleData:
    """One sample's video/audio/text tokens plus the metadata needed to build
    its packed token sequence, modality map, and RoPE coordinate map.
    """

    video_x_t: torch.Tensor  # (num_video_tokens, video_channel) patchified video tokens
    audio_x_t: torch.Tensor  # audio tokens; trimmed to audio_feat_len in __post_init__
    audio_feat_len: int  # valid (un-padded) audio token count
    txt_feat: torch.Tensor  # text embedding; trimmed to txt_feat_len in __post_init__
    txt_feat_len: int  # valid (un-padded) text token count
    t: int  # latent frame count (pre-patchify)
    h: int  # latent height (pre-patchify)
    w: int  # latent width (pre-patchify)
    patch_size: int  # spatial patch size
    t_patch_size: int  # temporal patch size
    spatial_rope_interpolation: Literal["inter", "extra"]  # spatial RoPE mode
    ref_audio_offset: int  # NOTE(review): stored by callers but unused inside this class -- confirm
    text_offset: int  # temporal coordinate offset for text tokens (v1 style)
    coords_style: Literal["v1", "v2"] = "v1"  # coordinate-layout version

    def __post_init__(self):
        """Trim audio/text to their valid lengths and cache per-modality sizes."""
        self.video_token_num = self.video_x_t.shape[0]

        # Drop padding beyond the valid lengths.
        self.audio_x_t = self.audio_x_t[: self.audio_feat_len]
        self.txt_feat = self.txt_feat[: self.txt_feat_len]

        self.video_channel = self.video_x_t.shape[-1]
        self.audio_channel = self.audio_x_t.shape[-1]
        self.txt_channel = self.txt_feat.shape[-1]

    @property
    def device(self):
        """Device of the video tokens (all modalities are assumed colocated)."""
        return self.video_x_t.device

    @property
    def default_dtype(self):
        """Dtype of the video tokens; used for coordinate tensors."""
        return self.video_x_t.dtype

    @property
    def total_token_num(self):
        """Packed sequence length: video + valid audio + valid text tokens."""
        return self.video_token_num + self.audio_feat_len + self.txt_feat_len

    @property
    def token_sequence(self):
        """Concatenate video, audio, and text tokens along the sequence dim.

        Each modality is right-padded with zeros in the channel dim so all
        share the widest channel width before concatenation.
        """
        tensors_to_concat = [self.video_x_t, self.audio_x_t, self.txt_feat]
        max_channel = max(tensor.shape[-1] for tensor in tensors_to_concat)

        padded_tensors = [F.pad(t, (0, max_channel - t.shape[-1])) for t in tensors_to_concat]
        ret_val = torch.cat(padded_tensors, dim=0)
        return ret_val

    @property
    def modality_mapping(self):
        """Per-token modality labels (Modality enum values) matching token_sequence."""
        v_map = torch.full((self.video_token_num,), Modality.VIDEO, dtype=torch.int64, device=self.device)
        a_map = torch.full((self.audio_feat_len,), Modality.AUDIO, dtype=torch.int64, device=self.device)
        t_map = torch.full((self.txt_feat_len,), Modality.TEXT, dtype=torch.int64, device=self.device)

        modality_mapping = torch.cat([v_map, a_map, t_map], dim=0)
        return modality_mapping

    def default_coords(self, shape, ref_feat_shape, offset_thw=[0, 0, 0]):
        """Shortcut to get_coords on this sample's device/dtype."""
        return get_coords(
            shape=shape, ref_feat_shape=ref_feat_shape, offset_thw=offset_thw, device=self.device, dtype=self.default_dtype
        )

    @property
    def coords_mapping(self):
        """Per-token (coords + shape metadata) rows matching token_sequence.

        "inter" mode pins the spatial reference grid to 32x32 (interpolation);
        "extra" uses the actual token grid. Audio/text coordinate layouts
        differ between coords_style v1 and v2 (see branches below).
        """
        if self.spatial_rope_interpolation == "inter":
            video_ref_feat_shape = (self.t // self.t_patch_size, 32, 32)
        else:
            video_ref_feat_shape = (self.t // self.t_patch_size, self.h // self.patch_size, self.w // self.patch_size)

        video_coords = self.default_coords(
            shape=(self.t // self.t_patch_size, self.h // self.patch_size, self.w // self.patch_size),
            ref_feat_shape=video_ref_feat_shape,
        )

        if self.coords_style == "v1":
            # v1: audio referenced against the video's temporal grid; text
            # placed at a fixed offset with a 2-step reference axis.
            audio_coords = self.default_coords(
                shape=(self.audio_feat_len, 1, 1), ref_feat_shape=(self.t // self.t_patch_size, 1, 1)
            )

            text_coords = self.default_coords(
                shape=(self.txt_feat_len, 1, 1), ref_feat_shape=(2, 1, 1), offset_thw=[self.text_offset, 0, 0]
            )

        elif self.coords_style == "v2":
            # v2: audio reference length derived from a 4x temporal grouping
            # (ceil(audio_feat_len / 4)); text placed at negative offsets so
            # it precedes t=0. NOTE(review): the 4 here looks tied to an
            # audio-codec frame ratio -- confirm against the audio VAE.
            magic_audio_ref_t = (self.audio_feat_len - 1) // 4 + 1
            audio_coords = self.default_coords(
                shape=(self.audio_feat_len, 1, 1), ref_feat_shape=(magic_audio_ref_t // self.t_patch_size, 1, 1)
            )

            text_coords = self.default_coords(
                shape=(self.txt_feat_len, 1, 1), ref_feat_shape=(1, 1, 1), offset_thw=[-self.txt_feat_len, 0, 0]
            )

        coords_mapping = torch.cat([video_coords, audio_coords, text_coords], dim=0)
        return coords_mapping

    def depack_token_sequence(self, token_sequence):
        """Split a packed sequence back into (video, audio) tensors.

        The video part is un-patchified (inverse of the UnfoldNd packing) to
        shape (C, T, H, W); channel padding added by token_sequence is dropped
        via the per-modality channel slices. Text tokens are not returned.
        """
        video_x_t = token_sequence[: self.video_token_num, : self.video_channel]
        video_x_t = rearrange(
            video_x_t,
            "(T H W) (pT pH pW C) -> C (T pT) (H pH) (W pW)",
            H=self.h // self.patch_size,
            W=self.w // self.patch_size,
            pT=self.t_patch_size,
            pH=self.patch_size,
            pW=self.patch_size,
        ).contiguous()

        audio_x_t = token_sequence[self.video_token_num : self.video_token_num + self.audio_feat_len, : self.audio_channel]
        return video_x_t, audio_x_t
232
+
233
+
234
@dataclass
class SimplePackedData:
    """A batch of SingleData items packed back-to-back into one sequence."""

    items: list[SingleData]

    @property
    def token_sequence(self):
        """All items' padded token sequences concatenated along dim 0."""
        return torch.cat([entry.token_sequence for entry in self.items], dim=0)

    @property
    def modality_mapping(self):
        """Concatenated per-token modality labels for the whole batch."""
        return torch.cat([entry.modality_mapping for entry in self.items], dim=0)

    @property
    def coords_mapping(self):
        """Concatenated per-token coordinate rows for the whole batch."""
        return torch.cat([entry.coords_mapping for entry in self.items], dim=0)

    @property
    def total_token_num(self):
        """Total packed length across all items."""
        return sum([entry.total_token_num for entry in self.items])

    def __getitem__(self, index):
        return self.items[index]

    @property
    def cu_seqlen(self):
        """Cumulative sequence lengths with a leading 0 (varlen-attention form)."""
        lengths = torch.tensor([entry.total_token_num for entry in self.items])
        return torch.nn.functional.pad(torch.cumsum(lengths, dim=0), (1, 0))

    @property
    def max_seqlen(self):
        """Longest per-item packed length, as a 0-dim tensor."""
        return torch.tensor(max([entry.total_token_num for entry in self.items]))

    def depack_token_sequence(self, token_sequence):
        """Split the packed batch and un-patchify each item.

        Returns stacked (video, audio) tensors with a leading batch dim.
        """
        per_item_lengths = [entry.total_token_num for entry in self.items]
        video_parts, audio_parts = [], []
        for entry, chunk in zip(self.items, torch.split(token_sequence, per_item_lengths, dim=0)):
            video_part, audio_part = entry.depack_token_sequence(chunk)
            video_parts.append(video_part)
            audio_parts.append(audio_part)
        return torch.stack(video_parts, dim=0), torch.stack(audio_parts, dim=0)
277
+
278
+
279
class MagiDataProxy:
    """Adapter between batched (video, audio, text) tensors and the flat
    packed token layout (plus attention metadata) consumed by the DiT model.
    """

    def __init__(self, config: DataProxyConfig):
        self.patch_size = config.patch_size
        self.t_patch_size = config.t_patch_size
        self.frame_receptive_field = config.frame_receptive_field
        # NOTE(review): hardcoded to 'extra' rather than read from `config`
        # -- confirm the config is not meant to control this.
        self.spatial_rope_interpolation = 'extra'
        self.ref_audio_offset = config.ref_audio_offset
        self.text_offset = config.text_offset
        # Non-overlapping 3D patchify: kernel size equals stride.
        self.unfold = UnfoldNd(
            kernel_size=(self.t_patch_size, self.patch_size, self.patch_size),
            stride=(self.t_patch_size, self.patch_size, self.patch_size),
        )
        self.coords_style = config.coords_style

        # Scratch storage shared between process_input and process_output.
        self._saved_data: dict[str, Any] = {}

    def saved_for_output(self, **kwargs):
        """
        Store intermediate data used by process_output.
        Supports keyword-argument style calls: saved_for_output(a=1, b=2)
        Can be called multiple times to accumulate data

        Args:
            **kwargs: key-value pairs to store
        """
        # Directly update dict; supports accumulation across calls
        self._saved_data.update(kwargs)

    def get_saved_data(self, key: str):
        """
        Get stored data; raises KeyError if `key` was never saved.
        """
        return self._saved_data[key]

    def img2tokens(self, x_t: torch.Tensor):
        """Patchify a video tensor into tokens of shape (N, num_tokens, col_dim)."""
        x_t_unfolded = self.unfold(x_t)
        # Transpose dimensions from (N, col_dim, num_tokens) -> (N, num_tokens, col_dim)
        x_t = rearrange(x_t_unfolded, "N col_dim num_tokens -> N num_tokens col_dim").contiguous()
        return x_t

    def process_input(self, transported_data: "EvalInput"):
        """Pack an EvalInput batch into one flat token sequence plus metadata.

        Returns:
            Tuple of (token sequence, coords_mapping, modality_mapping,
            varlen_handler, local_attn_handler). `local_attn_handler` is None
            when frame_receptive_field == -1 (no frame-local attention).
        """
        # init img2col module

        batch_size, _, t, h, w = transported_data.x_t.shape
        # 1. Process video features while keeping the batch dimension
        x_t = self.img2tokens(transported_data.x_t)

        # 2. Process audio features while keeping the batch dimension
        # Assume transported_data.audio_x_t shape is already (N, num_tokens, col_dim)
        audio_x_t = transported_data.audio_x_t.contiguous()

        # Here we assume text_in shape is (N, num_tokens, col_dim)
        text_in = transported_data.txt_feat.contiguous()

        # Build one SingleData per batch element; SingleData trims audio/text
        # to their valid lengths in its __post_init__.
        simple_packed_data = SimplePackedData(items=[])
        for i in range(batch_size):
            single_data = SingleData(
                video_x_t=x_t[i],
                audio_x_t=audio_x_t[i],
                audio_feat_len=transported_data.audio_feat_len[i],
                txt_feat=text_in[i],
                txt_feat_len=transported_data.txt_feat_len[i],
                t=t,
                h=h,
                w=w,
                patch_size=self.patch_size,
                t_patch_size=self.t_patch_size,
                spatial_rope_interpolation=self.spatial_rope_interpolation,
                ref_audio_offset=self.ref_audio_offset,
                text_offset=self.text_offset,
                coords_style=self.coords_style,
            )
            simple_packed_data.items.append(single_data)

        if self.frame_receptive_field != -1:
            assert batch_size == 1, "local attention only supports batch size 1"

            local_attn_handler = calc_local_attn_ffa_handler(
                num_video_tokens=simple_packed_data[0].video_token_num,
                num_audio_and_txt_tokens=simple_packed_data[0].audio_feat_len + simple_packed_data[0].txt_feat_len,
                num_frames=t,
                frame_receptive_field=self.frame_receptive_field,
            )
            # FFA expects plain ints for max seqlens; unwrap 0-dim tensors.
            if isinstance(local_attn_handler.max_seqlen_k, torch.Tensor):
                local_attn_handler.max_seqlen_k = local_attn_handler.max_seqlen_k.item()
            if isinstance(local_attn_handler.max_seqlen_q, torch.Tensor):
                local_attn_handler.max_seqlen_q = local_attn_handler.max_seqlen_q.item()
        else:
            local_attn_handler = None

        # Varlen metadata (cu_seqlens / max_seqlen) for packed attention.
        varlen_handler = VarlenHandler(
            cu_seqlens_q=simple_packed_data.cu_seqlen.to(torch.int32).cuda(),
            cu_seqlens_k=simple_packed_data.cu_seqlen.to(torch.int32).cuda(),
            max_seqlen_q=simple_packed_data.max_seqlen.to(torch.int32).cuda(),
            max_seqlen_k=simple_packed_data.max_seqlen.to(torch.int32).cuda(),
        )

        # Keep the packing layout around so process_output can invert it.
        self.saved_for_output(simple_packed_data=simple_packed_data)

        x = simple_packed_data.token_sequence
        coords_mapping = simple_packed_data.coords_mapping
        modality_mapping = simple_packed_data.modality_mapping

        return (x, coords_mapping, modality_mapping, varlen_handler, local_attn_handler)

    def process_output(self, x: torch.Tensor):
        """Invert process_input: split the model output back into
        batched (video, audio) tensors using the saved packing layout.
        """
        # Inserting operations in between may corrupt parallel-runtime data and cause latent errors

        simple_packed_data: SimplePackedData = self.get_saved_data("simple_packed_data")
        x_video, x_audio = simple_packed_data.depack_token_sequence(x)

        return (x_video, x_audio)
inference/pipeline/entry.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+ import sys
17
+
18
+ from inference.common import parse_config
19
+ from inference.infra import initialize_infra
20
+ from inference.model.dit import get_dit
21
+ from inference.utils import print_rank_0
22
+
23
+ try:
24
+ from .pipeline import MagiPipeline
25
+ except ImportError:
26
+ # Keep compatibility when entry.py is executed as a script path.
27
+ from inference.pipeline import MagiPipeline
28
+
29
+
30
def parse_arguments(argv=None):
    """Parse CLI arguments for the unified offline entry point.

    Args:
        argv: optional list of argument strings. Defaults to None, in which
            case ``sys.argv[1:]`` is used (the original behavior); passing an
            explicit list enables programmatic/test use.

    Returns:
        argparse.Namespace with the known arguments. Unknown arguments are
        ignored (``parse_known_args``) so flags consumed elsewhere (e.g. by
        the config parser) can share the command line.
    """
    parser = argparse.ArgumentParser(description="Run DiT pipeline with unified offline entry.")
    parser.add_argument("--prompt", type=str)
    parser.add_argument("--save_path_prefix", type=str, help="Path prefix for saving outputs.")
    parser.add_argument("--output_path", type=str, help="Alias of --save_path_prefix for MAGI-style CLI.")

    parser.add_argument("--image_path", type=str, help="Path to image for i2v mode.")
    parser.add_argument(
        "--audio_path", type=str, default=None, help="Path to optional audio for lipsync mode; omit to use i2v or t2v"
    )

    # Optional runtime controls; forwarded to pipeline methods when provided.
    parser.add_argument("--seed", type=int)
    parser.add_argument("--seconds", type=int)
    parser.add_argument("--br_width", type=int)
    parser.add_argument("--br_height", type=int)
    parser.add_argument("--sr_width", type=int)
    parser.add_argument("--sr_height", type=int)
    parser.add_argument("--output_width", type=int)
    parser.add_argument("--output_height", type=int)
    parser.add_argument("--upsample_mode", type=str)
    args, _ = parser.parse_known_args(argv)
    return args
53
+
54
+
55
def main():
    """Offline CLI entry: build the DiT model, wrap it in a MagiPipeline,
    and run a single prompt + image (+ optional audio) generation.

    Exits with status 1 (after printing an error) when required arguments
    are missing.
    """
    args = parse_arguments()
    config = parse_config()
    model = get_dit(config.arch_config, config.engine_config)
    pipeline = MagiPipeline(model, config.evaluation_config)
    # --output_path is accepted as an alias of --save_path_prefix.
    save_path_prefix = args.save_path_prefix or args.output_path
    if not save_path_prefix:
        print_rank_0("Error: --save_path_prefix (or --output_path) is required.")
        sys.exit(1)

    # Runtime controls are only forwarded when explicitly supplied, so
    # run_offline's own defaults apply otherwise.
    optional_kwargs = {
        "seed": args.seed,
        "seconds": args.seconds,
        "br_width": args.br_width,
        "br_height": args.br_height,
        "sr_width": args.sr_width,
        "sr_height": args.sr_height,
        "output_width": args.output_width,
        "output_height": args.output_height,
        "upsample_mode": args.upsample_mode,
    }
    # NOTE(review): all values above are int/str-or-None, so the
    # `is not False` clause never filters anything today.
    optional_kwargs = {k: v for k, v in optional_kwargs.items() if v is not None and v is not False}

    prompt = args.prompt
    image_path = args.image_path
    audio_path = args.audio_path

    if not prompt:
        print_rank_0("Error: --prompt is required.")
        sys.exit(1)
    if not image_path:
        print_rank_0("Error: --image_path is required.")
        sys.exit(1)

    pipeline.run_offline(
        prompt=prompt, image=image_path, audio=audio_path, save_path_prefix=save_path_prefix, **optional_kwargs
    )
92
+
93
+
94
+ if __name__ == "__main__":
95
+ initialize_infra()
96
+ main()
inference/pipeline/pipeline.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import random
17
+ from typing import Optional, Union
18
+
19
+ import imageio
20
+ import soundfile as sf
21
+ import torch
22
+ from PIL import Image
23
+
24
+ from inference.common import EvaluationConfig, parse_config
25
+ from inference.model.dit import get_dit
26
+ from inference.model.dit import DiTModel
27
+ from .video_generate import MagiEvaluator
28
+ from .video_process import merge_video_and_audio, upsample_video
29
+
30
+
31
class MagiPipeline:
    """Pipeline facade for inference.

    Wraps a base DiT model (and an optional super-resolution DiT when
    `evaluation_config.use_sr_model` is set) behind a MagiEvaluator and
    exposes a single offline entry point, `run_offline`.
    """

    def __init__(self, model: DiTModel, evaluation_config: EvaluationConfig, device: str = "cuda"):
        self.model = model
        self.evaluation_config = evaluation_config
        # NOTE(review): parse_config() is re-parsed here (and its
        # engine_config.load mutated for the SR model) -- confirm this does
        # not conflict with the caller's config object.
        config = parse_config()
        if evaluation_config.use_sr_model:
            config.engine_config.load = evaluation_config.sr_model_path
            sr_model = get_dit(config.sr_arch_config, config.engine_config)
        else:
            sr_model = None
        self.evaluator = MagiEvaluator(model, sr_model, evaluation_config, device)

    def _validate_offline_request(
        self,
        prompt: str,
        save_path_prefix: str,
    ):
        """Reject empty/blank prompt or save-path prefix with a ValueError."""
        if not prompt or not prompt.strip():
            raise ValueError("`prompt` must be a non-empty string.")
        if not save_path_prefix or not save_path_prefix.strip():
            raise ValueError("`save_path_prefix` must be a non-empty string.")

    def run_offline(
        self,
        prompt: str,
        image: Union[str, Image.Image, None],
        audio: Optional[str],
        save_path_prefix: str,
        seed: int = 42,
        seconds: int = 4,
        br_width: int = 480,
        br_height: int = 272,
        sr_width: Optional[int] = None,
        sr_height: Optional[int] = None,
        output_width: Optional[int] = None,
        output_height: Optional[int] = None,
        upsample_mode: Optional[str] = None,
    ):
        """Generate a video (and audio track) for one prompt and save it.

        The RNG is forked and seeded locally so generation is reproducible
        without disturbing the caller's random state. Only the last rank
        writes files; all ranks synchronize on a barrier before returning.

        Returns:
            str: the final .mp4 path (returned on every rank, even though
            only the last rank actually writes it).
        """
        self._validate_offline_request(prompt=prompt, save_path_prefix=save_path_prefix)

        # Output name encodes duration and resolutions for traceability.
        if self.evaluator.sr_model is not None:
            save_path = f"{save_path_prefix}_{seconds}s_{br_width}x{br_height}_{sr_width}x{sr_height}.mp4"
        else:
            save_path = f"{save_path_prefix}_{seconds}s_{br_width}x{br_height}.mp4"

        with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
            torch.random.manual_seed(seed)
            video_np, audio_np = self.evaluator.evaluate(
                prompt,
                image,
                audio,
                seconds=seconds,
                br_width=br_width,
                br_height=br_height,
                sr_width=sr_width,
                sr_height=sr_height,
                br_num_inference_steps=self.evaluation_config.num_inference_steps,
                sr_num_inference_steps=self.evaluation_config.sr_num_inference_steps,
            )

        if output_width is not None and output_height is not None:
            video_np = upsample_video(video_np, output_width, output_height, upsample_mode)

        # Only the last rank materializes files to avoid concurrent writes.
        if torch.distributed.get_rank() == torch.distributed.get_world_size() - 1:
            # NOTE(review): intermediate .wav/.mp4 files are written to the
            # CWD with random-suffix names and never deleted; random suffixes
            # can also collide in principle -- consider tempfile + cleanup.
            saving_name = f"{prompt.replace(' ', '_')[:10]}"
            audio_path = saving_name + str(random.randint(0, 1000000)) + ".wav"
            video_path = saving_name + str(random.randint(0, 1000000)) + ".mp4"
            sf.write(audio_path, audio_np, self.evaluator.audio_vae.sample_rate)
            imageio.mimwrite(video_path, video_np, fps=self.evaluation_config.fps, quality=8, output_params=["-loglevel", "error"])
            # NOTE(review): assert is stripped under `python -O`; an explicit
            # raise would be safer for this runtime check.
            assert os.path.exists(video_path)
            merge_video_and_audio(video_path, audio_path, save_path)

        # Keep all ranks in lockstep before returning the path.
        if torch.distributed.is_initialized():
            torch.distributed.barrier()
        return save_path
108
+
inference/pipeline/prompt_process.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Tuple
16
+
17
+ import torch
18
+ from torch.nn import functional as F
19
+
20
+ from inference.model.t5_gemma import get_t5_gemma_embedding
21
+
22
+
23
def pad_or_trim(tensor: torch.Tensor, target_size: int, dim: int, pad_value: float = 0.0) -> Tuple[torch.Tensor, int]:
    """
    Bring `tensor` to exactly `target_size` along dimension `dim`.

    If the tensor is shorter it is right-padded with `pad_value`; if it is
    longer (or equal) it is sliced down to `target_size`.

    Args:
        tensor (torch.Tensor): input tensor.
        target_size (int): desired size along `dim`.
        dim (int): dimension to pad or trim.
        pad_value (float, optional): fill value for padding. Defaults to 0.0.

    Returns:
        Tuple[torch.Tensor, int]: the resized tensor and the valid (original)
        length along `dim` (equals `target_size` when trimming occurred).
    """
    size_along_dim = tensor.size(dim)

    # Trim (or no-op when sizes already match): slice down to target_size.
    if size_along_dim >= target_size:
        index = [slice(None)] * tensor.dim()
        index[dim] = slice(0, target_size)
        return tensor[tuple(index)], target_size

    # Pad: F.pad takes (left, right) pairs starting from the LAST dim, so the
    # right-pad slot for `dim` sits at 2 * (ndim - 1 - dim) + 1.
    pad_spec = [0] * (2 * tensor.dim())
    pad_spec[2 * (tensor.dim() - 1 - dim) + 1] = target_size - size_along_dim
    return F.pad(tensor, tuple(pad_spec), "constant", pad_value), size_along_dim
47
+
48
+
49
def get_padded_t5_gemma_embedding(
    prompt: str,
    model_path: str,
    device: str,
    weight_dtype: torch.dtype,
    target_length: int,
) -> Tuple[torch.Tensor, int]:
    """Embed `prompt` with T5-Gemma, then pad/trim dim 1 to `target_length`.

    Returns:
        Tuple of (embedding cast to float32, valid sequence length). Note
        the returned length equals `target_length` when the raw embedding
        was trimmed (pad_or_trim reports target_size in that case).
    """
    txt_feat = get_t5_gemma_embedding(prompt, model_path, device, weight_dtype)
    txt_feat, original_len = pad_or_trim(txt_feat, target_size=target_length, dim=1)
    return txt_feat.to(torch.float32), original_len
59
+
60
+
inference/pipeline/scheduler_unipc.py ADDED
@@ -0,0 +1,832 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 SandAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py
16
+ # Convert unipc for flow matching
17
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
18
+ import math
19
+ from typing import Any, List, Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
25
+ from diffusers.utils import deprecate
26
+ from diffusers.utils.torch_utils import randn_tensor
27
+
28
+
29
class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
    """
    `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        solver_order (`int`, default `2`):
            The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`
            due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for
            unconditional sampling.
        prediction_type (`str`, defaults to "flow_prediction"):
            Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts
            the flow of the diffusion process.
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
            as Stable Diffusion.
        dynamic_thresholding_ratio (`float`, defaults to 0.995):
            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
        sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
        predict_x0 (`bool`, defaults to `True`):
            Whether to use the updating algorithm on the predicted x0.
        solver_type (`str`, default `bh2`):
            Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2`
            otherwise.
        lower_order_final (`bool`, default `True`):
            Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
            stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
        disable_corrector (`list`, default `[]`):
            Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)`
            and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is
            usually disabled during the first few steps.
        solver_p (`SchedulerMixin`, default `None`):
            Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`.
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families.
        final_sigmas_type (`str`, defaults to `"zero"`):
            The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
            sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
    """

    # Names of schedulers this one is config-compatible with (diffusers convention).
    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    # `order` is the diffusers pipeline convention for how many model calls one
    # scheduler step consumes; UniPC uses one model evaluation per step.
    order = 1
    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        solver_order: int = 2,
        prediction_type: str = "flow_prediction",
        shift: float = 1.0,
        use_dynamic_shifting=False,
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        predict_x0: bool = True,
        solver_type: str = "bh2",
        # NOTE(review): mutable default argument. It is only read (membership
        # tests in `step`) and never mutated here, so the shared-list hazard is
        # latent, but changing the default interacts with `@register_to_config`
        # (the default is recorded into `self.config`) — confirm before altering.
        disable_corrector: List[int] = [],
        solver_p: SchedulerMixin = None,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
        final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
    ):
        """Build the training-time sigma schedule and reset the multistep state.

        All constructor arguments are also captured into `self.config` by the
        `@register_to_config` decorator.
        """
        # Accept a few legacy solver-type names by silently mapping them to "bh2".
        if solver_type not in ["bh1", "bh2"]:
            if solver_type in ["midpoint", "heun", "logrho"]:
                self.register_to_config(solver_type="bh2")
            else:
                raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")

        self.predict_x0 = predict_x0
        # setable values
        self.num_inference_steps = None
        # alphas ascend from 1/N to 1; sigmas = 1 - alphas therefore descend
        # from 1 - 1/N down to 0 (flow-matching noise levels).
        alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy()
        sigmas = 1.0 - alphas
        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)

        if not use_dynamic_shifting:
            # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)  # pyright: ignore

        self.sigmas = sigmas
        self.timesteps = sigmas * num_train_timesteps

        # Rolling buffers of the last `solver_order` converted model outputs and
        # their timesteps; filled in by `step` as sampling progresses.
        self.model_outputs = [None] * solver_order
        self.timestep_list = [None] * solver_order
        self.lower_order_nums = 0
        self.disable_corrector = disable_corrector
        self.solver_p = solver_p
        self.last_sample = None
        self._step_index: Optional[int] = None
        self._begin_index: Optional[int] = None

        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()
+ @property
139
+ def step_index(self):
140
+ """
141
+ The index counter for current timestep. It will increase 1 after each scheduler step.
142
+ """
143
+ return self._step_index
144
+
145
+ @property
146
+ def begin_index(self):
147
+ """
148
+ The index for the first timestep. It should be set by `inference.pipeline` with `set_begin_index`.
149
+ """
150
+ return self._begin_index
151
+
152
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
153
+ def set_begin_index(self, begin_index: int = 0):
154
+ """
155
+ Sets the begin index for the scheduler. This function should be run by `inference.pipeline` before inference.
156
+
157
+ Args:
158
+ begin_index (`int`):
159
+ The begin index for the scheduler.
160
+ """
161
+ self._begin_index = begin_index
162
+
163
+ # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps
164
+ def set_timesteps(
165
+ self,
166
+ num_inference_steps: Union[int, None] = None,
167
+ device: Union[str, torch.device] = None,
168
+ sigmas: Optional[List[float]] = None,
169
+ mu: Optional[Union[float, None]] = None,
170
+ shift: Optional[Union[float, None]] = None,
171
+ ):
172
+ """
173
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
174
+ Args:
175
+ num_inference_steps (`int`):
176
+ Total number of the spacing of the time steps.
177
+ device (`str` or `torch.device`, *optional*):
178
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
179
+ """
180
+
181
+ if self.config.use_dynamic_shifting and mu is None:
182
+ raise ValueError(" you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
183
+
184
+ if sigmas is None:
185
+ sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1] # type: ignore
186
+
187
+ if self.config.use_dynamic_shifting:
188
+ sigmas = self.time_shift(mu, 1.0, sigmas) # type: ignore
189
+ else:
190
+ if shift is None:
191
+ shift = self.config.shift
192
+ sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) # type: ignore
193
+
194
+ if self.config.final_sigmas_type == "sigma_min":
195
+ sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
196
+ elif self.config.final_sigmas_type == "zero":
197
+ sigma_last = 0
198
+ else:
199
+ raise ValueError(
200
+ f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
201
+ )
202
+
203
+ timesteps = sigmas * self.config.num_train_timesteps
204
+ sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) # pyright: ignore
205
+
206
+ self.sigmas = torch.from_numpy(sigmas)
207
+ self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
208
+
209
+ self.num_inference_steps = len(timesteps) # type: ignore
210
+
211
+ self.model_outputs = [None] * self.config.solver_order
212
+ self.lower_order_nums = 0
213
+ self.last_sample = None
214
+ if self.solver_p:
215
+ self.solver_p.set_timesteps(self.num_inference_steps, device=device)
216
+
217
+ # add an index counter for schedulers that allow duplicated timesteps
218
+ self._step_index = None
219
+ self._begin_index = None
220
+ self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
221
+
222
    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
        """
        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

        https://arxiv.org/abs/2205.11487

        Args:
            sample (`torch.Tensor`): predicted x0, shape (batch, channels, *spatial).

        Returns:
            `torch.Tensor`: thresholded sample with the same shape and dtype as the input.
        """
        dtype = sample.dtype
        batch_size, channels, *remaining_dims = sample.shape

        if dtype not in (torch.float32, torch.float64):
            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each image
        sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
        s = torch.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

        # Restore the original layout and dtype.
        sample = sample.reshape(batch_size, channels, *remaining_dims)
        sample = sample.to(dtype)

        return sample
+ # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t
257
+ def _sigma_to_t(self, sigma):
258
+ return sigma * self.config.num_train_timesteps
259
+
260
+ def _sigma_to_alpha_sigma_t(self, sigma):
261
+ return 1 - sigma, sigma
262
+
263
+ # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps
264
+ def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
265
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
266
+
267
    def convert_model_output(self, model_output: torch.Tensor, *args, sample: torch.Tensor = None, **kwargs) -> torch.Tensor:
        r"""
        Convert the model output to the corresponding type the UniPC algorithm needs:
        the predicted clean sample ``x0`` when ``self.predict_x0`` is True, otherwise
        the implied noise ``epsilon``.

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model (a flow/velocity prediction).
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `torch.Tensor`:
                The converted model output.

        Raises:
            ValueError: if `sample` is missing or `prediction_type` is not `"flow_prediction"`.
        """
        # Back-compat shim: `timestep` used to be positional; it is now ignored in
        # favor of the internal `self.step_index` counter.
        timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
        if sample is None:
            if len(args) > 1:
                sample = args[1]
            else:
                raise ValueError("missing `sample` as a required keyward argument")
        if timestep is not None:
            deprecate(
                "timesteps",
                "1.0.0",
                "Passing `timesteps` is deprecated "
                "and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        sigma = self.sigmas[self.step_index]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)

        if self.predict_x0:
            if self.config.prediction_type == "flow_prediction":
                sigma_t = self.sigmas[self.step_index]
                # Rectified flow: x_t = (1 - sigma) * x0 + sigma * noise with the model
                # predicting velocity, hence x0 = x_t - sigma * velocity.
                x0_pred = sample - sigma_t * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
                    " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
                )

            if self.config.thresholding:
                x0_pred = self._threshold_sample(x0_pred)

            return x0_pred
        else:
            if self.config.prediction_type == "flow_prediction":
                sigma_t = self.sigmas[self.step_index]
                # Implied noise: epsilon = x_t - (1 - sigma) * velocity.
                epsilon = sample - (1 - sigma_t) * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
                    " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
                )

            if self.config.thresholding:
                # Threshold the x0 estimate, then rebuild the noise consistent with it.
                sigma_t = self.sigmas[self.step_index]
                x0_pred = sample - sigma_t * model_output
                x0_pred = self._threshold_sample(x0_pred)
                epsilon = model_output + x0_pred

            return epsilon
    def multistep_uni_p_bh_update(
        self,
        model_output: torch.Tensor,
        *args,
        sample: Optional[torch.Tensor] = None,
        order: Optional[int] = None,  # pyright: ignore
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the UniP (B(h) version) predictor. Alternatively, `self.solver_p` is used if it is specified.

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model at the current timestep
                (only used when delegating to `self.solver_p`; the multistep update itself
                reads the converted outputs cached in `self.model_outputs`).
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            order (`int`):
                The order of UniP at this timestep (corresponds to the *p* in UniPC-p).

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.
        """
        # Back-compat shim for the removed positional `prev_timestep` argument.
        prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 1:
                sample = args[1]
            else:
                raise ValueError(" missing `sample` as a required keyward argument")
        if order is None:
            if len(args) > 2:
                order = args[2]
            else:
                raise ValueError(" missing `order` as a required keyward argument")
        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated "
                "and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )
        model_output_list = self.model_outputs

        s0 = self.timestep_list[-1]
        m0 = model_output_list[-1]
        x = sample

        # Optional delegation: if a companion predictor scheduler was supplied,
        # let it take the step (the algorithm then becomes solver_p + UniC).
        if self.solver_p:
            x_t = self.solver_p.step(model_output, s0, x).prev_sample
            return x_t

        # Noise levels of the target (t) and the most recent source (s0) steps.
        sigma_t, sigma_s0 = (self.sigmas[self.step_index + 1], self.sigmas[self.step_index])  # pyright: ignore
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)

        # Half log-SNR values; h is the integration interval in lambda-space.
        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)

        h = lambda_t - lambda_s0
        device = sample.device

        # rks: relative lambda offsets of the older cached outputs; D1s: scaled
        # divided differences of the cached model outputs.
        rks = []
        D1s: Optional[List[Any]] = []
        for i in range(1, order):
            si = self.step_index - i  # pyright: ignore
            mi = model_output_list[-(i + 1)]
            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
            lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)  # type: ignore

        rks.append(1.0)
        rks = torch.tensor(rks, device=device)

        # Build the linear system R @ rhos = b whose solution gives the UniP weights.
        R = []
        b = []

        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.config.solver_type == "bh1":
            B_h = hh
        elif self.config.solver_type == "bh2":
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()

        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=device)

        if len(D1s) > 0:  # type: ignore
            D1s = torch.stack(D1s, dim=1)  # (B, K)
            # for order 2, we use a simplified version
            if order == 2:
                rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
            else:
                rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)  # type: ignore
        else:
            D1s = None

        # Combine the base Euler-like term with the higher-order correction.
        if self.predict_x0:
            x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
            if D1s is not None:
                pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)  # pyright: ignore
            else:
                pred_res = 0
            x_t = x_t_ - alpha_t * B_h * pred_res
        else:
            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
            if D1s is not None:
                pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)  # pyright: ignore
            else:
                pred_res = 0
            x_t = x_t_ - sigma_t * B_h * pred_res

        x_t = x_t.to(x.dtype)
        return x_t
    def multistep_uni_c_bh_update(
        self,
        this_model_output: torch.Tensor,
        *args,
        last_sample: torch.Tensor = None,
        this_sample: torch.Tensor = None,
        order: Optional[int] = None,  # pyright: ignore
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the UniC (B(h) version) corrector: refine the predictor result
        `this_sample` using the fresh model output evaluated at it.

        Args:
            this_model_output (`torch.Tensor`):
                The model outputs at `x_t`.
            last_sample (`torch.Tensor`):
                The generated sample before the last predictor `x_{t-1}`.
            this_sample (`torch.Tensor`):
                The generated sample after the last predictor `x_{t}`.
            order (`int`):
                The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`.

        Returns:
            `torch.Tensor`:
                The corrected sample tensor at the current timestep.
        """
        # Back-compat shim for the removed positional `this_timestep` argument.
        this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None)
        if last_sample is None:
            if len(args) > 1:
                last_sample = args[1]
            else:
                raise ValueError(" missing`last_sample` as a required keyward argument")
        if this_sample is None:
            if len(args) > 2:
                this_sample = args[2]
            else:
                raise ValueError(" missing`this_sample` as a required keyward argument")
        if order is None:
            if len(args) > 3:
                order = args[3]
            else:
                raise ValueError(" missing`order` as a required keyward argument")
        if this_timestep is not None:
            deprecate(
                "this_timestep",
                "1.0.0",
                "Passing `this_timestep` is deprecated "
                "and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        model_output_list = self.model_outputs

        m0 = model_output_list[-1]
        x = last_sample
        x_t = this_sample
        model_t = this_model_output

        # The corrector integrates from the previous index (s0) to the current one (t).
        sigma_t, sigma_s0 = (self.sigmas[self.step_index], self.sigmas[self.step_index - 1])  # pyright: ignore
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)

        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)

        h = lambda_t - lambda_s0
        device = this_sample.device

        # rks / D1s: same construction as in the predictor, but indexed one step
        # further back because the newest cached output already lives at s0.
        rks = []
        D1s: Optional[List[Any]] = []
        for i in range(1, order):
            si = self.step_index - (i + 1)  # pyright: ignore
            mi = model_output_list[-(i + 1)]
            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
            lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)  # type: ignore

        rks.append(1.0)
        rks = torch.tensor(rks, device=device)

        # Linear system R @ rhos = b yielding the UniC combination weights.
        R = []
        b = []

        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.config.solver_type == "bh1":
            B_h = hh
        elif self.config.solver_type == "bh2":
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()

        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=device)

        if len(D1s) > 0:  # type: ignore
            D1s = torch.stack(D1s, dim=1)
        else:
            D1s = None

        # for order 1, we use a simplified version
        if order == 1:
            rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
        else:
            rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)

        # Apply the correction: base term + weighted history residual + weighted
        # residual of the fresh model output (D1_t).
        if self.predict_x0:
            x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
            if D1s is not None:
                corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
            else:
                corr_res = 0
            D1_t = model_t - m0
            x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
        else:
            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
            if D1s is not None:
                corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
            else:
                corr_res = 0
            D1_t = model_t - m0
            x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
        x_t = x_t.to(x.dtype)
        return x_t
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
601
+ if schedule_timesteps is None:
602
+ schedule_timesteps = self.timesteps
603
+
604
+ indices = (schedule_timesteps == timestep).nonzero()
605
+
606
+ # The sigma index that is taken for the **very** first `step`
607
+ # is always the second index (or the last index if there is only 1)
608
+ # This way we can ensure we don't accidentally skip a sigma in
609
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
610
+ pos = 1 if len(indices) > 1 else 0
611
+
612
+ return indices[pos].item()
613
+
614
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
615
+ def _init_step_index(self, timestep):
616
+ """
617
+ Initialize the step_index counter for the scheduler.
618
+ """
619
+
620
+ if self.begin_index is None:
621
+ if isinstance(timestep, torch.Tensor):
622
+ timestep = timestep.to(self.timesteps.device)
623
+ self._step_index = self.index_for_timestep(timestep)
624
+ else:
625
+ self._step_index = self._begin_index
626
+
627
    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[int, torch.Tensor],
        sample: torch.Tensor,
        return_dict: bool = True,
        generator=None,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
        the multistep UniPC: first (optionally) correct the previous predictor result with the fresh model output
        (UniC), then take the next predictor step (UniP).

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        Raises:
            ValueError: if `set_timesteps` has not been called before stepping.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # Lazily initialize the internal step counter on the first call.
        if self.step_index is None:  # type: ignore
            self._init_step_index(timestep)

        # The corrector needs a previous predictor result and must not be
        # explicitly disabled for the previous step index.
        use_corrector = (
            self.step_index > 0
            and self.step_index - 1 not in self.disable_corrector
            and self.last_sample is not None  # pyright: ignore
        )

        model_output_convert = self.convert_model_output(model_output, sample=sample)
        if use_corrector:
            # UniC: refine the current sample using the fresh model output.
            sample = self.multistep_uni_c_bh_update(
                this_model_output=model_output_convert, last_sample=self.last_sample, this_sample=sample, order=self.this_order
            )

        # Shift the rolling history buffers left and append the newest entry.
        for i in range(self.config.solver_order - 1):
            self.model_outputs[i] = self.model_outputs[i + 1]
            self.timestep_list[i] = self.timestep_list[i + 1]

        self.model_outputs[-1] = model_output_convert
        self.timestep_list[-1] = timestep  # pyright: ignore

        # Near the end of the schedule, cap the order by the steps remaining.
        if self.config.lower_order_final:
            this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index)  # pyright: ignore
        else:
            this_order = self.config.solver_order

        self.this_order = min(this_order, self.lower_order_nums + 1)  # warmup for multistep
        assert self.this_order > 0

        self.last_sample = sample
        # UniP: predictor step to the next noise level.
        prev_sample = self.multistep_uni_p_bh_update(
            model_output=model_output,  # pass the original non-converted model output, in case solver-p is used
            sample=sample,
            order=self.this_order,
        )

        if self.lower_order_nums < self.config.solver_order:
            self.lower_order_nums += 1

        # upon completion increase step index by one
        self._step_index += 1  # pyright: ignore

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)
    def step_ddim(
        # https://github.com/yifan123/flow_grpo/blob/main/flow_grpo/diffusers_patch/sd3_sde_with_logprob.py
        self,
        velocity: torch.FloatTensor,
        t: int,
        curr_state: torch.FloatTensor,
        prev_state: Optional[torch.FloatTensor] = None,
        generator: Optional[torch.Generator] = None,
    ):
        """
        Single re-noising step: estimate the clean sample from the current state and
        the predicted velocity, then re-noise that estimate to the next noise level
        `self.sigmas[t + 1]` with freshly drawn Gaussian noise.

        Args:
            velocity (`torch.FloatTensor`):
                The flow/velocity prediction from the model at sigma index `t`.
            t (`int`):
                Index into `self.sigmas` for the current step; `t + 1` must be valid.
            curr_state (`torch.FloatTensor`):
                The current noisy sample.
            prev_state (`torch.FloatTensor`, *optional*):
                NOTE(review): accepted but ignored — it is unconditionally overwritten
                below. Confirm whether it should act as a pass-through like in `step_sde`.
            generator (`torch.Generator`, *optional*):
                A random number generator for the fresh noise draw.

        Returns:
            `torch.FloatTensor`: the sample at noise level `self.sigmas[t + 1]`.
        """
        device = curr_state.device
        curr_t = self.sigmas[t]
        prev_t = self.sigmas[t + 1]
        variance_noise = randn_tensor(curr_state.shape, generator=generator, device=device, dtype=curr_state.dtype)
        # Clean-sample estimate under rectified flow: x0 = x_t - sigma_t * velocity.
        cur_clean_ = curr_state - curr_t * velocity
        # Re-noise the estimate to the next level: x_{t'} = sigma_{t'} * eps + (1 - sigma_{t'}) * x0.
        prev_state = prev_t * variance_noise + (1 - prev_t) * cur_clean_

        return prev_state
    def step_sde(
        # https://github.com/yifan123/flow_grpo/blob/main/flow_grpo/diffusers_patch/sd3_sde_with_logprob.py
        self,
        velocity: torch.FloatTensor,
        t: int,
        curr_state: torch.FloatTensor,
        noise_theta: float = 1.0,
        prev_state: Optional[torch.FloatTensor] = None,
        generator: Optional[torch.Generator] = None,
    ):
        """
        Single stochastic (SDE-style) flow step: compute the mean of the next state
        from the current state and predicted velocity, then either add scaled fresh
        noise or re-center a supplied `prev_state` around the new mean.

        Args:
            velocity (`torch.FloatTensor`): (B, C, T, H, W)
                The direct output from learned flow model.
            t (`int`):
                Index into `self.sigmas` for the current step; `t + 1` must be valid.
            curr_state (`torch.FloatTensor`): (B, C, T, H, W)
                A current instance of a sample created by the diffusion process.
            noise_theta (`float`, defaults to 1.0):
                Interpolates between deterministic flow matching (0) and the fully
                stochastic step (1) via cos/sin of `noise_theta * pi / 2`.
            prev_state (`torch.FloatTensor`, *optional*):
                If given, no fresh noise is drawn; the result is `prev_state` shifted so
                its mean matches `prev_sample_mean` (the mean's gradient is kept while
                the offset is detached — appears to be a straight-through-style
                re-centering for RL fine-tuning; confirm against the linked reference).
            generator (`torch.Generator`, *optional*):
                A random number generator.

        Returns:
            `torch.FloatTensor`: the sample at noise level `self.sigmas[t + 1]`.
        """
        device = curr_state.device
        curr_t = self.sigmas[t]
        prev_t = self.sigmas[t + 1]
        cos = torch.cos(torch.tensor(noise_theta) * torch.pi / 2).to(device)  # if noise_theta is 0, it degenerates to standard flow matching
        sin = torch.sin(torch.tensor(noise_theta) * torch.pi / 2).to(device)
        # Mean of the next state; with cos=1, sin=0 this reduces to the deterministic
        # update x - curr_t * v + prev_t * v.
        prev_sample_mean = (1 - prev_t + prev_t * cos) * (curr_state - curr_t * velocity) + prev_t * cos * velocity
        std_dev_t = prev_t * sin
        std_dev_t = torch.ones((1, 1)).to(curr_state) * std_dev_t
        if prev_state is None:
            variance_noise = randn_tensor(curr_state.shape, generator=generator, device=device, dtype=curr_state.dtype)
            prev_state = prev_sample_mean + std_dev_t * variance_noise
        else:
            prev_state = prev_sample_mean + (prev_state - prev_sample_mean.detach())

        return prev_state
+ def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
787
+ """
788
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
789
+ current timestep.
790
+
791
+ Args:
792
+ sample (`torch.Tensor`):
793
+ The input sample.
794
+
795
+ Returns:
796
+ `torch.Tensor`:
797
+ A scaled input sample.
798
+ """
799
+ return sample
800
+
801
    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
    def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
        """Noise clean samples to the levels given by `timesteps`.

        Returns `alpha_t * original_samples + sigma_t * noise` with per-sample
        coefficients looked up from the schedule.
        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
        elif self.step_index is not None:
            # add_noise is called after first denoising step (for inpainting)
            step_indices = [self.step_index] * timesteps.shape[0]
        else:
            # add noise is called before first denoising step to create initial latent(img2img)
            step_indices = [self.begin_index] * timesteps.shape[0]

        # Broadcast the per-sample sigma over the sample's trailing dimensions.
        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
        noisy_samples = alpha_t * original_samples + sigma_t * noise
        return noisy_samples
+ def __len__(self):
832
+ return self.config.num_train_timesteps