upload Astra checkpoint
#1
by wjque - opened
This view is limited to 50 files because it contains too many changes. See the raw diff here.
- .DS_Store +0 -0
- .gitattributes +0 -1
- .gitignore +0 -3
- README.md +45 -72
- Try ReCamMaster with Your Own Videos Here.txt +1 -0
- diffsynth/pipelines/__init__.py +1 -1
- diffsynth/pipelines/wan_video_recammaster.py +2 -2
- examples/.DS_Store +0 -0
- models/Astra/checkpoints/diffusion_pytorch_model.ckpt → examples/output_videos/output_moe_framepack_sliding.mp4 +2 -2
- assets/images/pipeline.png → logo-text-2.png +2 -2
- models/.DS_Store +0 -0
- models/Astra/.DS_Store +0 -0
- models/Astra/checkpoints/{Put Astra ckpt file here.txt → Put ReCamMaster ckpt file here.txt} +0 -0
- pip-list.txt +197 -0
- scripts/add_text_emb.py +161 -0
- scripts/add_text_emb_rl.py +161 -0
- scripts/add_text_emb_spatialvid.py +173 -0
- scripts/analyze_openx.py +243 -0
- scripts/analyze_pose.py +188 -0
- scripts/batch_drone.py +44 -0
- scripts/batch_infer.py +186 -0
- scripts/batch_nus.py +42 -0
- scripts/batch_rt.py +41 -0
- scripts/batch_spa.py +43 -0
- scripts/batch_walk.py +42 -0
- scripts/check.py +263 -0
- scripts/decode_openx.py +428 -0
- scripts/download_recam.py +7 -0
- scripts/encode_dynamic_videos.py +141 -0
- scripts/encode_openx.py +466 -0
- scripts/encode_rlbench_video.py +170 -0
- scripts/encode_sekai_video.py +162 -0
- scripts/encode_sekai_walking.py +249 -0
- scripts/encode_spatialvid.py +409 -0
- scripts/encode_spatialvid_first_frame.py +285 -0
- scripts/hud_logo.py +1 -1
- scripts/infer_demo.py +318 -494
- scripts/infer_moe.py +1023 -0
- scripts/infer_moe_spatialvid.py +1008 -0
- scripts/infer_moe_test.py +976 -0
- scripts/infer_nus.py +500 -0
- scripts/infer_openx.py +614 -0
- scripts/infer_origin.py +1108 -0
- scripts/infer_recam.py +272 -0
- scripts/infer_rlbench.py +447 -0
- scripts/infer_sekai.py +497 -0
- scripts/infer_sekai_framepack.py +675 -0
- scripts/infer_sekai_framepack_4.py +682 -0
- scripts/infer_sekai_framepack_test.py +551 -0
- scripts/infer_spatialvid.py +608 -0
.DS_Store
DELETED
|
Binary file (6.15 kB)
|
|
|
.gitattributes
CHANGED
|
@@ -40,4 +40,3 @@ diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt filter=lfs diff=lfs merge
|
|
| 40 |
examples/output_videos/output_moe_framepack_sliding.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 41 |
logo-text-2.png filter=lfs diff=lfs merge=lfs -text
|
| 42 |
assets/images/logo-text-2.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
-
pipeline.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 40 |
examples/output_videos/output_moe_framepack_sliding.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 41 |
logo-text-2.png filter=lfs diff=lfs merge=lfs -text
|
| 42 |
assets/images/logo-text-2.png filter=lfs diff=lfs merge=lfs -text
|
|
|
.gitignore
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
# Ignore all checkpoint files
|
| 2 |
-
*.ckpt
|
| 3 |
-
*.ckpt.*
|
|
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
|
@@ -20,7 +20,7 @@ model-index:
|
|
| 20 |
|
| 21 |
<h3 style="margin-top: 0;">
|
| 22 |
📄
|
| 23 |
-
[<a href="https://arxiv.org/abs/
|
| 24 |
|
| 25 |
🏠
|
| 26 |
[<a href="https://eternalevan.github.io/Astra-project/" target="_blank">Project Page</a>]
|
|
@@ -48,23 +48,36 @@ model-index:
|
|
| 48 |
|
| 49 |
<div align="center">
|
| 50 |
|
| 51 |
-
**[Yixuan Zhu<sup>1</sup>](https://
|
| 52 |
<!-- <br> -->
|
| 53 |
-
(
|
|
|
|
| 54 |
|
| 55 |
<sup>1</sup>Tsinghua University, <sup>2</sup>Kuaishou Technology.
|
| 56 |
</div>
|
| 57 |
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
-
##
|
| 60 |
|
| 61 |
-
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
-
**Astra** is an **interactive**, action-driven world model that predicts long-horizon future videos across diverse real-world scenarios. Built on an autoregressive diffusion transformer with temporal causal attention, Astra supports **streaming prediction** while preserving strong temporal coherence. Astra introduces **noise-augmented history memory** to stabilize long rollouts, an **action-aware adapter** for precise control signals, and a **mixture of action experts** to route heterogeneous action modalities. Through these key innovations, Astra delivers consistent, controllable, and high-fidelity video futures for applications such as autonomous driving, robot manipulation, and camera motion.
|
| 64 |
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
## Gallery
|
| 70 |
|
|
@@ -119,27 +132,7 @@ model-index:
|
|
| 119 |
|
| 120 |
If you would like to use ReCamMaster as a baseline and need qualitative or quantitative comparisons, please feel free to drop an email to [jianhongbai@zju.edu.cn](mailto:jianhongbai@zju.edu.cn). We can assist you with batch inference of our model. -->
|
| 121 |
|
| 122 |
-
##
|
| 123 |
-
- __[2025.11.17]__: Release the [project page](https://eternalevan.github.io/Astra-project/).
|
| 124 |
-
- __[2025.12.09]__: Release the inference code, model checkpoint.
|
| 125 |
-
|
| 126 |
-
## 🎯 TODO List
|
| 127 |
-
|
| 128 |
-
- [ ] **Release full inference pipelines** for additional scenarios:
|
| 129 |
-
- [ ] 🚗 Autonomous driving
|
| 130 |
-
- [ ] 🤖 Robotic manipulation
|
| 131 |
-
- [ ] 🛸 Drone navigation / exploration
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
- [ ] **Open-source training scripts**:
|
| 135 |
-
- [ ] ⬆️ Action-conditioned autoregressive denoising training
|
| 136 |
-
- [ ] 🔄 Multi-scenario joint training pipeline
|
| 137 |
-
|
| 138 |
-
- [ ] **Release dataset preprocessing tools**
|
| 139 |
-
|
| 140 |
-
- [ ] **Provide unified evaluation toolkit**
|
| 141 |
-
|
| 142 |
-
## ⚙️ Run Astra (Inference)
|
| 143 |
Astra is built upon [Wan2.1-1.3B](https://github.com/Wan-Video/Wan2.1), a diffusion-based video generation model. We provide inference scripts to help you quickly generate videos from images and action inputs. Follow the steps below:
|
| 144 |
|
| 145 |
### Inference
|
|
@@ -167,57 +160,37 @@ python download_wan2.1.py
|
|
| 167 |
```
|
| 168 |
2. Download the pre-trained Astra checkpoint
|
| 169 |
|
| 170 |
-
Please download from [huggingface](https://huggingface.co/
|
| 171 |
|
| 172 |
-
Step 3: Test the example
|
| 173 |
```shell
|
| 174 |
-
python
|
| 175 |
-
--dit_path ../models/Astra/checkpoints/diffusion_pytorch_model.ckpt \
|
| 176 |
-
--wan_model_path ../models/Wan-AI/Wan2.1-T2V-1.3B \
|
| 177 |
-
--condition_image ../examples/condition_images/garden_1.png \
|
| 178 |
-
--cam_type 4 \
|
| 179 |
-
--prompt "A sunlit European street lined with historic buildings and vibrant greenery creates a warm, charming, and inviting atmosphere. The scene shows a picturesque open square paved with red bricks, surrounded by classic narrow townhouses featuring tall windows, gabled roofs, and dark-painted facades. On the right side, a lush arrangement of potted plants and blooming flowers adds rich color and texture to the foreground. A vintage-style streetlamp stands prominently near the center-right, contributing to the timeless character of the street. Mature trees frame the background, their leaves glowing in the warm afternoon sunlight. Bicycles are visible along the edges of the buildings, reinforcing the urban yet leisurely feel. The sky is bright blue with scattered clouds, and soft sun flares enter the frame from the left, enhancing the scene’s inviting, peaceful mood." \
|
| 180 |
-
--output_path ../examples/output_videos/output_moe_framepack_sliding.mp4 \
|
| 181 |
```
|
| 182 |
|
| 183 |
-
Step 4: Test your own
|
| 184 |
|
| 185 |
-
|
| 186 |
|
| 187 |
```shell
|
| 188 |
-
python
|
| 189 |
-
--dit_path path/to/your/dit_ckpt \
|
| 190 |
-
--wan_model_path path/to/your/Wan2.1-T2V-1.3B \
|
| 191 |
-
--condition_image path/to/your/image \
|
| 192 |
-
--cam_type your_cam_type \
|
| 193 |
-
--prompt your_prompt \
|
| 194 |
-
--output_path path/to/your/output_video
|
| 195 |
```
|
| 196 |
|
| 197 |
We provide several preset camera types, as shown in the table below. Additionally, you can generate new trajectories for testing.
|
| 198 |
|
| 199 |
-
| cam_type
|
| 200 |
-
|
|
| 201 |
-
| 1
|
| 202 |
-
| 2
|
| 203 |
-
| 3
|
| 204 |
-
| 4
|
| 205 |
-
| 5
|
| 206 |
-
| 6
|
| 207 |
-
| 7
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
Looking ahead, we plan to further enhance Astra in several directions:
|
| 213 |
-
|
| 214 |
-
- **Training with Wan-2.2:** Upgrade our model using the latest Wan-2.2 framework to release a more powerful version with improved generation quality.
|
| 215 |
-
- **3D Spatial Consistency:** Explore techniques to better preserve 3D consistency across frames for more coherent and realistic video generation.
|
| 216 |
-
- **Long-Term Memory:** Incorporate mechanisms for long-term memory, enabling the model to handle extended temporal dependencies and complex action sequences.
|
| 217 |
-
|
| 218 |
-
These directions aim to push Astra towards more robust and interactive video world modeling.
|
| 219 |
|
| 220 |
-
|
| 221 |
|
| 222 |
Step 1: Set up the environment
|
| 223 |
|
|
@@ -249,7 +222,7 @@ Step 4: Test the model
|
|
| 249 |
|
| 250 |
```shell
|
| 251 |
python inference_recammaster.py --cam_type 1 --ckpt_path path/to/the/checkpoint
|
| 252 |
-
```
|
| 253 |
|
| 254 |
<!-- ## 📷 Dataset: MultiCamVideo Dataset
|
| 255 |
### 1. Dataset Introduction
|
|
@@ -408,10 +381,10 @@ Feel free to explore these outstanding related works, including but not limited
|
|
| 408 |
|
| 409 |
Please leave us a star 🌟 and cite our paper if you find our work helpful.
|
| 410 |
```
|
| 411 |
-
@
|
| 412 |
title={Astra: General Interactive World Model with Autoregressive Denoising},
|
| 413 |
author={Zhu, Yixuan and Feng, Jiaqi and Zheng, Wenzhao and Gao, Yuan and Tao, Xin and Wan, Pengfei and Zhou, Jie and Lu, Jiwen},
|
| 414 |
-
|
| 415 |
year={2025}
|
| 416 |
}
|
| 417 |
```
|
|
|
|
| 20 |
|
| 21 |
<h3 style="margin-top: 0;">
|
| 22 |
📄
|
| 23 |
+
[<a href="https://arxiv.org/abs/2503.11647" target="_blank">arXiv</a>]
|
| 24 |
|
| 25 |
🏠
|
| 26 |
[<a href="https://eternalevan.github.io/Astra-project/" target="_blank">Project Page</a>]
|
|
|
|
| 48 |
|
| 49 |
<div align="center">
|
| 50 |
|
| 51 |
+
**[Yixuan Zhu<sup>1</sup>](https://jianhongbai.github.io/), [Jiaqi Feng<sup>1</sup>](https://menghanxia.github.io/), [Wenzhao Zheng<sup>1 †</sup>](https://fuxiao0719.github.io/), [Yuan Gao<sup>2</sup>](https://xinntao.github.io/), [Xin Tao<sup>2</sup>](https://scholar.google.com/citations?user=dCik-2YAAAAJ&hl=en), [Pengfei Wan<sup>2</sup>](https://openreview.net/profile?id=~Jinwen_Cao1), [Jie Zhou <sup>1</sup>](https://person.zju.edu.cn/en/lzz), [Jiwen Lu<sup>1</sup>](https://person.zju.edu.cn/en/huhaoji)**
|
| 52 |
<!-- <br> -->
|
| 53 |
+
(*Work done during an internship at Kuaishou Technology,
|
| 54 |
+
† Project leader)
|
| 55 |
|
| 56 |
<sup>1</sup>Tsinghua University, <sup>2</sup>Kuaishou Technology.
|
| 57 |
</div>
|
| 58 |
|
| 59 |
+
## 🔥 Updates
|
| 60 |
+
- __[2025.11.17]__: Release the [project page](https://eternalevan.github.io/Astra-project/).
|
| 61 |
+
- __[2025.12.09]__: Release the training and inference code, model checkpoint.
|
| 62 |
|
| 63 |
+
## 🎯 TODO List
|
| 64 |
|
| 65 |
+
- [ ] **Release full inference pipelines** for additional scenarios:
|
| 66 |
+
- [ ] 🚗 Autonomous driving
|
| 67 |
+
- [ ] 🤖 Robotic manipulation
|
| 68 |
+
- [ ] 🛸 Drone navigation / exploration
|
| 69 |
|
|
|
|
| 70 |
|
| 71 |
+
- [ ] **Open-source training scripts**:
|
| 72 |
+
- [ ] ⬆️ Action-conditioned autoregressive denoising training
|
| 73 |
+
- [ ] 🔄 Multi-scenario joint training pipeline
|
| 74 |
+
|
| 75 |
+
- [ ] **Release dataset preprocessing tools**
|
| 76 |
+
|
| 77 |
+
- [ ] **Provide unified evaluation toolkit**
|
| 78 |
+
## 📖 Introduction
|
| 79 |
+
|
| 80 |
+
**TL;DR:** Astra is an **interactive world model** that delivers realistic long-horizon video rollouts under a wide range of scenarios and action inputs.
|
| 81 |
|
| 82 |
## Gallery
|
| 83 |
|
|
|
|
| 132 |
|
| 133 |
If you would like to use ReCamMaster as a baseline and need qualitative or quantitative comparisons, please feel free to drop an email to [jianhongbai@zju.edu.cn](mailto:jianhongbai@zju.edu.cn). We can assist you with batch inference of our model. -->
|
| 134 |
|
| 135 |
+
## ⚙️ Code: Astra + Wan2.1 (Inference & Training)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
Astra is built upon [Wan2.1-1.3B](https://github.com/Wan-Video/Wan2.1), a diffusion-based video generation model. We provide inference scripts to help you quickly generate videos from images and action inputs. Follow the steps below:
|
| 137 |
|
| 138 |
### Inference
|
|
|
|
| 160 |
```
|
| 161 |
2. Download the pre-trained Astra checkpoint
|
| 162 |
|
| 163 |
+
Please download from [huggingface](https://huggingface.co/wjque/lyra/blob/main/diffusion_pytorch_model.ckpt) and place it in ```models/Astra/checkpoints```.
|
| 164 |
|
| 165 |
+
Step 3: Test the example videos
|
| 166 |
```shell
|
| 167 |
+
python inference_astra.py --cam_type 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
```
|
| 169 |
|
| 170 |
+
Step 4: Test your own videos
|
| 171 |
|
| 172 |
+
If you want to test your own videos, you need to prepare your test data following the structure of the ```example_test_data``` folder. This includes N mp4 videos, each with at least 81 frames, and a ```metadata.csv``` file that stores their paths and corresponding captions. You can refer to the [Prompt Extension section](https://github.com/Wan-Video/Wan2.1?tab=readme-ov-file#2-using-prompt-extension) in Wan2.1 for guidance on preparing video captions.
|
| 173 |
|
| 174 |
```shell
|
| 175 |
+
python inference_astra.py --cam_type 1 --dataset_path path/to/your/data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
```
|
| 177 |
|
| 178 |
We provide several preset camera types, as shown in the table below. Additionally, you can generate new trajectories for testing.
|
| 179 |
|
| 180 |
+
| cam_type | Trajectory |
|
| 181 |
+
|-------------------|-----------------------------|
|
| 182 |
+
| 1 | Pan Right |
|
| 183 |
+
| 2 | Pan Left |
|
| 184 |
+
| 3 | Tilt Up |
|
| 185 |
+
| 4 | Tilt Down |
|
| 186 |
+
| 5 | Zoom In |
|
| 187 |
+
| 6 | Zoom Out |
|
| 188 |
+
| 7 | Translate Up (with rotation) |
|
| 189 |
+
| 8 | Translate Down (with rotation) |
|
| 190 |
+
| 9 | Arc Left (with rotation) |
|
| 191 |
+
| 10 | Arc Right (with rotation) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
|
| 193 |
+
### Training
|
| 194 |
|
| 195 |
Step 1: Set up the environment
|
| 196 |
|
|
|
|
| 222 |
|
| 223 |
```shell
|
| 224 |
python inference_recammaster.py --cam_type 1 --ckpt_path path/to/the/checkpoint
|
| 225 |
+
```
|
| 226 |
|
| 227 |
<!-- ## 📷 Dataset: MultiCamVideo Dataset
|
| 228 |
### 1. Dataset Introduction
|
|
|
|
| 381 |
|
| 382 |
Please leave us a star 🌟 and cite our paper if you find our work helpful.
|
| 383 |
```
|
| 384 |
+
@inproceedings{zhu2025astra,
|
| 385 |
title={Astra: General Interactive World Model with Autoregressive Denoising},
|
| 386 |
author={Zhu, Yixuan and Feng, Jiaqi and Zheng, Wenzhao and Gao, Yuan and Tao, Xin and Wan, Pengfei and Zhou, Jie and Lu, Jiwen},
|
| 387 |
+
booktitle={arxiv},
|
| 388 |
year={2025}
|
| 389 |
}
|
| 390 |
```
|
Try ReCamMaster with Your Own Videos Here.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
https://docs.google.com/forms/d/e/1FAIpQLSezOzGPbm8JMXQDq6EINiDf6iXn7rV4ozj6KcbQCSAzE8Vsnw/viewform?usp=dialog
|
diffsynth/pipelines/__init__.py
CHANGED
|
@@ -12,5 +12,5 @@ from .pipeline_runner import SDVideoPipelineRunner
|
|
| 12 |
from .hunyuan_video import HunyuanVideoPipeline
|
| 13 |
from .step_video import StepVideoPipeline
|
| 14 |
from .wan_video import WanVideoPipeline
|
| 15 |
-
from .wan_video_recammaster import
|
| 16 |
KolorsImagePipeline = SDXLImagePipeline
|
|
|
|
| 12 |
from .hunyuan_video import HunyuanVideoPipeline
|
| 13 |
from .step_video import StepVideoPipeline
|
| 14 |
from .wan_video import WanVideoPipeline
|
| 15 |
+
from .wan_video_recammaster import WanVideoReCamMasterPipeline
|
| 16 |
KolorsImagePipeline = SDXLImagePipeline
|
diffsynth/pipelines/wan_video_recammaster.py
CHANGED
|
@@ -21,7 +21,7 @@ from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
|
|
| 21 |
|
| 22 |
|
| 23 |
|
| 24 |
-
class
|
| 25 |
|
| 26 |
def __init__(self, device="cuda", torch_dtype=torch.float16, tokenizer_path=None,condition_frames=None,target_frames=None):
|
| 27 |
super().__init__(device=device, torch_dtype=torch_dtype)
|
|
@@ -141,7 +141,7 @@ class WanVideoAstraPipeline(BasePipeline):
|
|
| 141 |
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None):
|
| 142 |
if device is None: device = model_manager.device
|
| 143 |
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
|
| 144 |
-
pipe =
|
| 145 |
pipe.fetch_models(model_manager)
|
| 146 |
return pipe
|
| 147 |
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
|
| 24 |
+
class WanVideoReCamMasterPipeline(BasePipeline):
|
| 25 |
|
| 26 |
def __init__(self, device="cuda", torch_dtype=torch.float16, tokenizer_path=None,condition_frames=None,target_frames=None):
|
| 27 |
super().__init__(device=device, torch_dtype=torch_dtype)
|
|
|
|
| 141 |
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None):
|
| 142 |
if device is None: device = model_manager.device
|
| 143 |
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
|
| 144 |
+
pipe = WanVideoReCamMasterPipeline(device=device, torch_dtype=torch_dtype)
|
| 145 |
pipe.fetch_models(model_manager)
|
| 146 |
return pipe
|
| 147 |
|
examples/.DS_Store
DELETED
|
Binary file (6.15 kB)
|
|
|
models/Astra/checkpoints/diffusion_pytorch_model.ckpt → examples/output_videos/output_moe_framepack_sliding.mp4
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:94e522d491ab43a0e597b13eefc4a6a48640c4490489005adef1bcb48c22e15c
|
| 3 |
+
size 2085028
|
assets/images/pipeline.png → logo-text-2.png
RENAMED
|
File without changes
|
models/.DS_Store
DELETED
|
Binary file (6.15 kB)
|
|
|
models/Astra/.DS_Store
DELETED
|
Binary file (6.15 kB)
|
|
|
models/Astra/checkpoints/{Put Astra ckpt file here.txt → Put ReCamMaster ckpt file here.txt}
RENAMED
|
File without changes
|
pip-list.txt
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
(ddrm) zyx@zwl:~$ pip list
|
| 2 |
+
Package Version Editable project location
|
| 3 |
+
------------------------ --------------- ---------------------------------------------------------
|
| 4 |
+
absl-py 2.0.0
|
| 5 |
+
accelerate 1.0.1
|
| 6 |
+
addict 2.4.0
|
| 7 |
+
aiohttp 3.8.5
|
| 8 |
+
aiosignal 1.3.1
|
| 9 |
+
albumentations 1.4.6
|
| 10 |
+
annotated-types 0.6.0
|
| 11 |
+
antlr4-python3-runtime 4.9.3
|
| 12 |
+
appdirs 1.4.4
|
| 13 |
+
asttokens 2.4.1
|
| 14 |
+
async-timeout 4.0.3
|
| 15 |
+
attrs 23.1.0
|
| 16 |
+
backcall 0.2.0
|
| 17 |
+
basicsr 1.2.0+1.4.2 /home/zyx/Retinexformer-master
|
| 18 |
+
beautifulsoup4 4.12.3
|
| 19 |
+
blessed 1.20.0
|
| 20 |
+
blobfile 2.0.2
|
| 21 |
+
cachetools 5.3.1
|
| 22 |
+
certifi 2023.5.7
|
| 23 |
+
cffi 1.15.1
|
| 24 |
+
charset-normalizer 3.2.0
|
| 25 |
+
click 8.1.7
|
| 26 |
+
cmake 3.26.4
|
| 27 |
+
contourpy 1.1.1
|
| 28 |
+
cycler 0.11.0
|
| 29 |
+
decorator 5.1.1
|
| 30 |
+
decord 0.6.0
|
| 31 |
+
diffusers 0.31.0
|
| 32 |
+
dlib 19.24.6
|
| 33 |
+
docker-pycreds 0.4.0
|
| 34 |
+
einops 0.6.1
|
| 35 |
+
executing 2.1.0
|
| 36 |
+
face-alignment 1.4.1
|
| 37 |
+
facexlib 0.3.0
|
| 38 |
+
filelock 3.12.2
|
| 39 |
+
filterpy 1.4.5
|
| 40 |
+
fire 0.5.0
|
| 41 |
+
flatbuffers 23.5.26
|
| 42 |
+
fonttools 4.42.1
|
| 43 |
+
frozenlist 1.4.0
|
| 44 |
+
fsspec 2023.9.1
|
| 45 |
+
ftfy 6.1.1
|
| 46 |
+
future 0.18.3
|
| 47 |
+
gdown 5.2.0
|
| 48 |
+
gfpgan 1.3.8 /home/zyx/anaconda3/envs/ddrm/lib/python3.8/site-packages
|
| 49 |
+
gitdb 4.0.10
|
| 50 |
+
GitPython 3.1.37
|
| 51 |
+
google-auth 2.23.0
|
| 52 |
+
google-auth-oauthlib 1.0.0
|
| 53 |
+
gpustat 1.1
|
| 54 |
+
grpcio 1.58.0
|
| 55 |
+
guided-diffusion 0.0.0 /home/zyx/GenerativeDiffusionPrior
|
| 56 |
+
huggingface-hub 0.30.2
|
| 57 |
+
idna 3.4
|
| 58 |
+
imageio 2.31.5
|
| 59 |
+
imgaug 0.4.0
|
| 60 |
+
importlib-metadata 6.8.0
|
| 61 |
+
importlib-resources 6.1.0
|
| 62 |
+
ip-adapter 0.1.0
|
| 63 |
+
ipython 8.12.3
|
| 64 |
+
jedi 0.19.2
|
| 65 |
+
Jinja2 3.1.2
|
| 66 |
+
joblib 1.3.2
|
| 67 |
+
kiwisolver 1.4.5
|
| 68 |
+
lazy_loader 0.3
|
| 69 |
+
lightning-utilities 0.9.0
|
| 70 |
+
lit 16.0.6
|
| 71 |
+
llvmlite 0.41.0
|
| 72 |
+
lmdb 1.4.1
|
| 73 |
+
loguru 0.7.2
|
| 74 |
+
lora-diffusion 0.1.7
|
| 75 |
+
loralib 0.1.2
|
| 76 |
+
lpips 0.1.4
|
| 77 |
+
lxml 4.9.3
|
| 78 |
+
Markdown 3.4.4
|
| 79 |
+
markdown-it-py 3.0.0
|
| 80 |
+
MarkupSafe 2.1.3
|
| 81 |
+
matplotlib 3.7.3
|
| 82 |
+
matplotlib-inline 0.1.7
|
| 83 |
+
mdurl 0.1.2
|
| 84 |
+
mediapipe 0.10.5
|
| 85 |
+
mmcv 1.7.0
|
| 86 |
+
mmengine 0.10.7
|
| 87 |
+
mpi4py 3.1.4
|
| 88 |
+
mpmath 1.3.0
|
| 89 |
+
multidict 6.0.4
|
| 90 |
+
mypy-extensions 1.0.0
|
| 91 |
+
natsort 8.4.0
|
| 92 |
+
networkx 3.1
|
| 93 |
+
ninja 1.11.1.1
|
| 94 |
+
numba 0.58.0
|
| 95 |
+
numpy 1.24.4
|
| 96 |
+
nvidia-cublas-cu11 11.10.3.66
|
| 97 |
+
nvidia-cuda-cupti-cu11 11.7.101
|
| 98 |
+
nvidia-cuda-nvrtc-cu11 11.7.99
|
| 99 |
+
nvidia-cuda-runtime-cu11 11.7.99
|
| 100 |
+
nvidia-cudnn-cu11 8.5.0.96
|
| 101 |
+
nvidia-cufft-cu11 10.9.0.58
|
| 102 |
+
nvidia-curand-cu11 10.2.10.91
|
| 103 |
+
nvidia-cusolver-cu11 11.4.0.1
|
| 104 |
+
nvidia-cusparse-cu11 11.7.4.91
|
| 105 |
+
nvidia-ml-py 12.535.77
|
| 106 |
+
nvidia-nccl-cu11 2.14.3
|
| 107 |
+
nvidia-nvtx-cu11 11.7.91
|
| 108 |
+
oauthlib 3.2.2
|
| 109 |
+
omegaconf 2.3.0
|
| 110 |
+
open-clip-torch 2.20.0
|
| 111 |
+
openai-clip 1.0.1
|
| 112 |
+
opencv-contrib-python 4.8.0.76
|
| 113 |
+
opencv-python 4.8.0.74
|
| 114 |
+
opencv-python-headless 4.9.0.80
|
| 115 |
+
packaging 23.1
|
| 116 |
+
pandas 2.0.3
|
| 117 |
+
parso 0.8.4
|
| 118 |
+
pathtools 0.1.2
|
| 119 |
+
peft 0.13.2
|
| 120 |
+
pexpect 4.9.0
|
| 121 |
+
pickleshare 0.7.5
|
| 122 |
+
Pillow 10.0.0
|
| 123 |
+
pip 23.1.2
|
| 124 |
+
platformdirs 3.11.0
|
| 125 |
+
prompt_toolkit 3.0.48
|
| 126 |
+
protobuf 3.20.3
|
| 127 |
+
psutil 5.9.5
|
| 128 |
+
ptyprocess 0.7.0
|
| 129 |
+
pure_eval 0.2.3
|
| 130 |
+
pyasn1 0.5.0
|
| 131 |
+
pyasn1-modules 0.3.0
|
| 132 |
+
pycparser 2.21
|
| 133 |
+
pycryptodomex 3.18.0
|
| 134 |
+
pydantic 2.7.1
|
| 135 |
+
pydantic_core 2.18.2
|
| 136 |
+
pyDeprecate 0.3.1
|
| 137 |
+
Pygments 2.18.0
|
| 138 |
+
pyiqa 0.1.8
|
| 139 |
+
pyparsing 3.1.1
|
| 140 |
+
pyre-extensions 0.0.23
|
| 141 |
+
PySocks 1.7.1
|
| 142 |
+
python-dateutil 2.8.2
|
| 143 |
+
pytorch-fid 0.3.0
|
| 144 |
+
pytorch-lightning 1.4.2
|
| 145 |
+
pytz 2023.3.post1
|
| 146 |
+
PyWavelets 1.4.1
|
| 147 |
+
PyYAML 6.0.1
|
| 148 |
+
realesrgan 0.3.0
|
| 149 |
+
regex 2023.8.8
|
| 150 |
+
requests 2.31.0
|
| 151 |
+
requests-oauthlib 1.3.1
|
| 152 |
+
rich 14.0.0
|
| 153 |
+
rsa 4.9
|
| 154 |
+
safetensors 0.4.5
|
| 155 |
+
scikit-image 0.21.0
|
| 156 |
+
scikit-learn 1.3.2
|
| 157 |
+
scipy 1.10.1
|
| 158 |
+
sentencepiece 0.1.99
|
| 159 |
+
sentry-sdk 1.31.0
|
| 160 |
+
setproctitle 1.3.2
|
| 161 |
+
setuptools 68.0.0
|
| 162 |
+
shapely 2.0.1
|
| 163 |
+
six 1.16.0
|
| 164 |
+
smmap 5.0.1
|
| 165 |
+
sounddevice 0.4.6
|
| 166 |
+
soupsieve 2.5
|
| 167 |
+
stack-data 0.6.3
|
| 168 |
+
sympy 1.12
|
| 169 |
+
tb-nightly 2.14.0a20230808
|
| 170 |
+
tensorboard 2.14.0
|
| 171 |
+
tensorboard-data-server 0.7.1
|
| 172 |
+
termcolor 2.3.0
|
| 173 |
+
threadpoolctl 3.2.0
|
| 174 |
+
tifffile 2023.7.10
|
| 175 |
+
timm 0.9.7
|
| 176 |
+
tokenizers 0.15.0
|
| 177 |
+
tomli 2.0.1
|
| 178 |
+
torch 1.13.1+cu116
|
| 179 |
+
torchmetrics 0.5.0
|
| 180 |
+
torchvision 0.14.1+cu116
|
| 181 |
+
tqdm 4.65.0
|
| 182 |
+
traitlets 5.14.3
|
| 183 |
+
transformers 4.36.2
|
| 184 |
+
triton 2.0.0
|
| 185 |
+
typing_extensions 4.11.0
|
| 186 |
+
typing-inspect 0.9.0
|
| 187 |
+
tzdata 2023.3
|
| 188 |
+
urllib3 1.26.16
|
| 189 |
+
vqfr 2.0.0 /home/zyx/VQFR
|
| 190 |
+
wandb 0.15.11
|
| 191 |
+
wcwidth 0.2.6
|
| 192 |
+
Werkzeug 2.3.7
|
| 193 |
+
wheel 0.40.0
|
| 194 |
+
xformers 0.0.16
|
| 195 |
+
yapf 0.40.2
|
| 196 |
+
yarl 1.9.2
|
| 197 |
+
zipp 3.17.0
|
scripts/add_text_emb.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import lightning as pl
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 6 |
+
import json
|
| 7 |
+
import imageio
|
| 8 |
+
from torchvision.transforms import v2
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
import argparse
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pdb
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 16 |
+
|
| 17 |
+
class VideoEncoder(pl.LightningModule):
|
| 18 |
+
def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
| 19 |
+
super().__init__()
|
| 20 |
+
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
|
| 21 |
+
model_manager.load_models([text_encoder_path, vae_path])
|
| 22 |
+
self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
|
| 23 |
+
self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
| 24 |
+
|
| 25 |
+
self.frame_process = v2.Compose([
|
| 26 |
+
# v2.CenterCrop(size=(900, 1600)),
|
| 27 |
+
# v2.Resize(size=(900, 1600), antialias=True),
|
| 28 |
+
v2.ToTensor(),
|
| 29 |
+
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
| 30 |
+
])
|
| 31 |
+
|
| 32 |
+
def crop_and_resize(self, image):
|
| 33 |
+
width, height = image.size
|
| 34 |
+
# print(width,height)
|
| 35 |
+
width_ori, height_ori_ = 832 , 480
|
| 36 |
+
image = v2.functional.resize(
|
| 37 |
+
image,
|
| 38 |
+
(round(height_ori_), round(width_ori)),
|
| 39 |
+
interpolation=v2.InterpolationMode.BILINEAR
|
| 40 |
+
)
|
| 41 |
+
return image
|
| 42 |
+
|
| 43 |
+
def load_video_frames(self, video_path):
|
| 44 |
+
"""加载完整视频"""
|
| 45 |
+
reader = imageio.get_reader(video_path)
|
| 46 |
+
frames = []
|
| 47 |
+
|
| 48 |
+
for frame_data in reader:
|
| 49 |
+
frame = Image.fromarray(frame_data)
|
| 50 |
+
frame = self.crop_and_resize(frame)
|
| 51 |
+
frame = self.frame_process(frame)
|
| 52 |
+
frames.append(frame)
|
| 53 |
+
|
| 54 |
+
reader.close()
|
| 55 |
+
|
| 56 |
+
if len(frames) == 0:
|
| 57 |
+
return None
|
| 58 |
+
|
| 59 |
+
frames = torch.stack(frames, dim=0)
|
| 60 |
+
frames = rearrange(frames, "T C H W -> C T H W")
|
| 61 |
+
return frames
|
| 62 |
+
|
| 63 |
+
def encode_scenes(scenes_path, text_encoder_path, vae_path,output_dir):
|
| 64 |
+
"""编码所有场景的视频"""
|
| 65 |
+
|
| 66 |
+
encoder = VideoEncoder(text_encoder_path, vae_path)
|
| 67 |
+
encoder = encoder.cuda()
|
| 68 |
+
encoder.pipe.device = "cuda"
|
| 69 |
+
|
| 70 |
+
processed_count = 0
|
| 71 |
+
prompt_emb = 0
|
| 72 |
+
|
| 73 |
+
os.makedirs(output_dir,exist_ok=True)
|
| 74 |
+
|
| 75 |
+
required_keys = ["latents", "cam_emb", "prompt_emb"]
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
for i, scene_name in tqdm(enumerate(os.listdir(scenes_path)),total=len(os.listdir(scenes_path))):
|
| 79 |
+
|
| 80 |
+
scene_dir = os.path.join(scenes_path, scene_name)
|
| 81 |
+
save_dir = os.path.join(output_dir,scene_name.split('.')[0])
|
| 82 |
+
# print('in:',scene_dir)
|
| 83 |
+
# print('out:',save_dir)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# 检查是否已编码
|
| 87 |
+
encoded_path = os.path.join(save_dir, "encoded_video.pth")
|
| 88 |
+
# if os.path.exists(encoded_path):
|
| 89 |
+
print(f"Checking scene {scene_name}...")
|
| 90 |
+
# continue
|
| 91 |
+
|
| 92 |
+
# 加载场景信息
|
| 93 |
+
|
| 94 |
+
# print(encoded_path)
|
| 95 |
+
data = torch.load(encoded_path,weights_only=False)
|
| 96 |
+
missing_keys = [key for key in required_keys if key not in data]
|
| 97 |
+
|
| 98 |
+
if missing_keys:
|
| 99 |
+
print(f"警告: 文件中缺少以下必要元素: {missing_keys}")
|
| 100 |
+
else:
|
| 101 |
+
print("文件包含所有必要元素: latents 和 cam_emb 和 prompt_emb")
|
| 102 |
+
continue
|
| 103 |
+
# with np.load(scene_cam_path) as data:
|
| 104 |
+
# cam_data = data.files
|
| 105 |
+
# cam_emb = {k: data[k].cpu() if isinstance(data[k], torch.Tensor) else data[k] for k in cam_data}
|
| 106 |
+
# with open(scene_cam_path, 'rb') as f:
|
| 107 |
+
# cam_data = np.load(f) # 此时cam_data仅包含数据,无文件句柄引用
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
# 加载和编码视频
|
| 112 |
+
# video_frames = encoder.load_video_frames(video_path)
|
| 113 |
+
# if video_frames is None:
|
| 114 |
+
# print(f"Failed to load video: {video_path}")
|
| 115 |
+
# continue
|
| 116 |
+
|
| 117 |
+
# video_frames = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16)
|
| 118 |
+
# print(video_frames.shape)
|
| 119 |
+
# 编码视频
|
| 120 |
+
with torch.no_grad():
|
| 121 |
+
# latents = encoder.pipe.encode_video(video_frames, **encoder.tiler_kwargs)[0]
|
| 122 |
+
|
| 123 |
+
# 编码文本
|
| 124 |
+
if processed_count == 0:
|
| 125 |
+
print('encode prompt!!!')
|
| 126 |
+
prompt_emb = encoder.pipe.encode_prompt("A video of a scene shot using a pedestrian's front camera while walking")#A video of a scene shot using a drone's front camera
|
| 127 |
+
del encoder.pipe.prompter
|
| 128 |
+
|
| 129 |
+
data["prompt_emb"] = prompt_emb
|
| 130 |
+
|
| 131 |
+
print("已添加/更新 prompt_emb 元素")
|
| 132 |
+
|
| 133 |
+
# 保存修改后的文件(可改为新路径避免覆盖原文件)
|
| 134 |
+
torch.save(data, encoded_path)
|
| 135 |
+
|
| 136 |
+
# pdb.set_trace()
|
| 137 |
+
# 保存编码结果
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
print(f"Saved encoded data: {encoded_path}")
|
| 141 |
+
processed_count += 1
|
| 142 |
+
|
| 143 |
+
# except Exception as e:
|
| 144 |
+
# print(f"Error encoding scene {scene_name}: {e}")
|
| 145 |
+
# continue
|
| 146 |
+
print(processed_count)
|
| 147 |
+
print(f"Encoding completed! Processed {processed_count} scenes.")
|
| 148 |
+
|
| 149 |
+
if __name__ == "__main__":
|
| 150 |
+
parser = argparse.ArgumentParser()
|
| 151 |
+
parser.add_argument("--scenes_path", type=str, default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking")
|
| 152 |
+
parser.add_argument("--text_encoder_path", type=str,
|
| 153 |
+
default="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth")
|
| 154 |
+
parser.add_argument("--vae_path", type=str,
|
| 155 |
+
default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth")
|
| 156 |
+
|
| 157 |
+
parser.add_argument("--output_dir",type=str,
|
| 158 |
+
default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking")
|
| 159 |
+
|
| 160 |
+
args = parser.parse_args()
|
| 161 |
+
encode_scenes(args.scenes_path, args.text_encoder_path, args.vae_path,args.output_dir)
|
scripts/add_text_emb_rl.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import lightning as pl
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 6 |
+
import json
|
| 7 |
+
import imageio
|
| 8 |
+
from torchvision.transforms import v2
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
import argparse
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pdb
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 16 |
+
|
| 17 |
+
class VideoEncoder(pl.LightningModule):
    """Lightning wrapper around the ReCamMaster pipeline used for offline encoding.

    Loads the given text-encoder and VAE checkpoints on CPU in bfloat16 and
    exposes helpers to decode a video file into a normalized ``C T H W`` tensor.
    """

    def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        super().__init__()
        manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        manager.load_models([text_encoder_path, vae_path])
        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(manager)
        # Tiling options forwarded verbatim to the VAE video encoder.
        self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
        # PIL frame -> float tensor scaled into [-1, 1].
        self.frame_process = v2.Compose([
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def crop_and_resize(self, image):
        """Resize a PIL image to the fixed 832x480 working resolution (no crop)."""
        width, height = image.size  # queried but unused; kept from the original
        target_w, target_h = 832, 480
        return v2.functional.resize(
            image,
            (round(target_h), round(target_w)),
            interpolation=v2.InterpolationMode.BILINEAR,
        )

    def load_video_frames(self, video_path):
        """Load the full video at ``video_path`` as a ``C T H W`` tensor, or None if it has no frames."""
        reader = imageio.get_reader(video_path)
        processed = []
        for raw in reader:
            pil_frame = Image.fromarray(raw)
            processed.append(self.frame_process(self.crop_and_resize(pil_frame)))
        reader.close()
        if not processed:
            return None
        stacked = torch.stack(processed, dim=0)  # T C H W
        return rearrange(stacked, "T C H W -> C T H W")
|
| 62 |
+
|
| 63 |
+
def encode_scenes(scenes_path, text_encoder_path, vae_path, output_dir):
    """Backfill the text-prompt embedding into every scene's encoded_video.pth.

    Scenes whose payload already contains ``latents``, ``cam_emb`` and
    ``prompt_emb`` are skipped. The (fixed) prompt is encoded exactly once —
    on the first scene that needs it — and the cached embedding is reused.
    """
    encoder = VideoEncoder(text_encoder_path, vae_path)
    encoder = encoder.cuda()
    encoder.pipe.device = "cuda"

    processed_count = 0
    prompt_emb = 0  # placeholder until the first prompt is actually encoded
    os.makedirs(output_dir, exist_ok=True)

    required_keys = ["latents", "cam_emb", "prompt_emb"]
    scene_names = os.listdir(scenes_path)

    for idx, scene_name in tqdm(enumerate(scene_names), total=len(scene_names)):
        scene_dir = os.path.join(scenes_path, scene_name)  # kept from the original; currently unused
        save_dir = os.path.join(output_dir, scene_name.split('.')[0])
        encoded_path = os.path.join(save_dir, "encoded_video.pth")
        print(f"Checking scene {scene_name}...")

        data = torch.load(encoded_path, weights_only=False)
        missing_keys = [key for key in required_keys if key not in data]
        if missing_keys:
            print(f"警告: 文件中缺少以下必要元素: {missing_keys}")
        else:
            print("文件包含所有必要元素: latents 和 cam_emb 和 prompt_emb")
            continue

        with torch.no_grad():
            if processed_count == 0:
                print('encode prompt!!!')
                prompt_emb = encoder.pipe.encode_prompt("a robotic arm executing precise manipulation tasks on a clean, organized desk")
                # The text encoder is no longer needed once the embedding is cached.
                del encoder.pipe.prompter

            data["prompt_emb"] = prompt_emb
            print("已添加/更新 prompt_emb 元素")
            # Overwrites the original file in place.
            torch.save(data, encoded_path)
            print(f"Saved encoded data: {encoded_path}")
            processed_count += 1

    print(processed_count)
    print(f"Encoding completed! Processed {processed_count} scenes.")
|
| 148 |
+
|
| 149 |
+
if __name__ == "__main__":
    # CLI entry point; defaults target the rlbench dataset layout.
    cli = argparse.ArgumentParser()
    for flag, default in [
        ("--scenes_path", "/share_zhuyixuan05/zhuyixuan05/rlbench"),
        ("--text_encoder_path", "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth"),
        ("--vae_path", "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"),
        ("--output_dir", "/share_zhuyixuan05/zhuyixuan05/rlbench"),
    ]:
        cli.add_argument(flag, type=str, default=default)
    opts = cli.parse_args()
    encode_scenes(opts.scenes_path, opts.text_encoder_path, opts.vae_path, opts.output_dir)
|
scripts/add_text_emb_spatialvid.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import lightning as pl
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 6 |
+
import json
|
| 7 |
+
import imageio
|
| 8 |
+
from torchvision.transforms import v2
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
import argparse
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pdb
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 16 |
+
|
| 17 |
+
class VideoEncoder(pl.LightningModule):
    """Lightning wrapper around the ReCamMaster pipeline used for offline encoding.

    Loads the given text-encoder and VAE checkpoints on CPU in bfloat16 and
    exposes helpers to decode a video file into a normalized ``C T H W`` tensor.
    """

    def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        super().__init__()
        manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        manager.load_models([text_encoder_path, vae_path])
        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(manager)
        # Tiling options forwarded verbatim to the VAE video encoder.
        self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
        # PIL frame -> float tensor scaled into [-1, 1].
        self.frame_process = v2.Compose([
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def crop_and_resize(self, image):
        """Resize a PIL image to the fixed 832x480 working resolution (no crop)."""
        width, height = image.size  # queried but unused; kept from the original
        target_w, target_h = 832, 480
        return v2.functional.resize(
            image,
            (round(target_h), round(target_w)),
            interpolation=v2.InterpolationMode.BILINEAR,
        )

    def load_video_frames(self, video_path):
        """Load the full video at ``video_path`` as a ``C T H W`` tensor, or None if it has no frames."""
        reader = imageio.get_reader(video_path)
        processed = []
        for raw in reader:
            pil_frame = Image.fromarray(raw)
            processed.append(self.frame_process(self.crop_and_resize(pil_frame)))
        reader.close()
        if not processed:
            return None
        stacked = torch.stack(processed, dim=0)  # T C H W
        return rearrange(stacked, "T C H W -> C T H W")
|
| 62 |
+
|
| 63 |
+
def encode_scenes(scenes_path, text_encoder_path, vae_path, output_dir):
    """Sanitize encoded scene files: strip gradients from cached prompt embeddings.

    For every scene under ``scenes_path`` whose mirror directory in
    ``output_dir`` contains ``encoded_video.pth``, verify the payload holds
    ``latents``, ``cam_emb`` and ``prompt_emb``, and if the cached
    ``prompt_emb['context']`` tensor still carries ``requires_grad``, detach
    it and rewrite the file in place.

    Args:
        scenes_path: directory whose entries name the scenes to check.
        text_encoder_path / vae_path: checkpoint paths for the VideoEncoder
            (constructed for parity with the sibling scripts; the encoder is
            not actually used by this sanitize pass).
        output_dir: root containing one sub-directory per scene.
    """
    encoder = VideoEncoder(text_encoder_path, vae_path)
    encoder = encoder.cuda()
    encoder.pipe.device = "cuda"

    processed_count = 0
    os.makedirs(output_dir, exist_ok=True)

    required_keys = ["latents", "cam_emb", "prompt_emb"]
    scene_names = os.listdir(scenes_path)

    for scene_name in tqdm(scene_names, total=len(scene_names)):
        save_dir = os.path.join(output_dir, scene_name.split('.')[0])
        encoded_path = os.path.join(save_dir, "encoded_video.pth")

        # map_location keeps cached CUDA tensors off the GPU during the scan.
        data = torch.load(encoded_path, weights_only=False, map_location="cpu")
        missing_keys = [key for key in required_keys if key not in data]
        if missing_keys:
            print(f"警告: 文件 {encoded_path} 中缺少以下必要元素: {missing_keys}")
            if "prompt_emb" in missing_keys:
                # Bug fix: the original fell through and raised KeyError on
                # data['prompt_emb'] when the key was absent; skip instead.
                processed_count += 1
                continue

        context = data['prompt_emb']['context']
        if context.requires_grad:
            print(f"警告: 文件 {encoded_path} 中存在含梯度变量,已消除")
            context = context.detach().clone()
            context.requires_grad_(False)
            # Explicit check instead of `assert`, which is stripped under -O.
            if context.requires_grad:
                raise RuntimeError("梯度仍未消除!")
            data['prompt_emb']['context'] = context
            # Rewrite in place only when something actually changed.
            torch.save(data, encoded_path)

        processed_count += 1

    print(processed_count)
    print(f"Encoding completed! Processed {processed_count} scenes.")
|
| 160 |
+
|
| 161 |
+
if __name__ == "__main__":
    # CLI entry point; defaults target the spatialvid dataset layout.
    cli = argparse.ArgumentParser()
    for flag, default in [
        ("--scenes_path", "/share_zhuyixuan05/zhuyixuan05/spatialvid"),
        ("--text_encoder_path", "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth"),
        ("--vae_path", "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"),
        ("--output_dir", "/share_zhuyixuan05/zhuyixuan05/spatialvid"),
    ]:
        cli.add_argument(flag, type=str, default=default)
    opts = cli.parse_args()
    encode_scenes(opts.scenes_path, opts.text_encoder_path, opts.vae_path, opts.output_dir)
|
scripts/analyze_openx.py
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
|
| 5 |
+
def analyze_openx_dataset_frame_counts(dataset_path):
    """Analyze the latent frame-count distribution of an encoded OpenX dataset.

    Scans every episode directory under ``dataset_path`` containing an
    ``encoded_video.pth``, reads each ``latents`` tensor (layout [C, T, H, W]
    per the indexing below), prints summary statistics, writes a text report
    next to the data, and returns the aggregated counters.

    Returns:
        dict with counters and the sorted frame-count list, or None when the
        path is missing or nothing could be analyzed.
    """
    print(f"🔧 分析OpenX数据集: {dataset_path}")

    if not os.path.exists(dataset_path):
        print(f" ⚠️ 路径不存在: {dataset_path}")
        return

    episode_dirs = []
    total_episodes = 0
    valid_episodes = 0

    # Collect every episode directory that has an encoded payload.
    for item in os.listdir(dataset_path):
        episode_dir = os.path.join(dataset_path, item)
        if os.path.isdir(episode_dir):
            total_episodes += 1
            encoded_path = os.path.join(episode_dir, "encoded_video.pth")
            if os.path.exists(encoded_path):
                episode_dirs.append(episode_dir)
                valid_episodes += 1

    print(f"📊 总episode数: {total_episodes}")
    print(f"📊 有效episode数: {valid_episodes}")

    if len(episode_dirs) == 0:
        print("❌ 没有找到有效的episode")
        return

    frame_counts = []
    less_than_10 = 0
    less_than_8 = 0
    less_than_5 = 0
    error_count = 0

    print("🔧 开始分析帧数分布...")

    for episode_dir in tqdm(episode_dirs, desc="分析episodes"):
        try:
            encoded_data = torch.load(
                os.path.join(episode_dir, "encoded_video.pth"),
                weights_only=False,
                map_location="cpu"
            )
            latents = encoded_data['latents']  # [C, T, H, W]
            frame_count = latents.shape[1]  # T dimension
            frame_counts.append(frame_count)

            if frame_count < 10:
                less_than_10 += 1
            if frame_count < 8:
                less_than_8 += 1
            if frame_count < 5:
                less_than_5 += 1

        except Exception as e:
            error_count += 1
            if error_count <= 5:  # only report the first few failures
                print(f"❌ 加载episode {os.path.basename(episode_dir)} 时出错: {e}")

    total_valid = len(frame_counts)
    # Bug fix: the original divided by total_valid below even when every load
    # failed, raising ZeroDivisionError. Bail out early instead.
    if total_valid == 0:
        print("❌ 没有成功加载任何episode")
        return

    print(f"\n📈 帧数分布统计:")
    print(f" 总有效episodes: {total_valid}")
    print(f" 错误episodes: {error_count}")
    print(f" 最小帧数: {min(frame_counts)}")
    print(f" 最大帧数: {max(frame_counts)}")
    # Bug fix: the original conditional expression printed the literal 0
    # instead of a message when frame_counts was empty; that case is now
    # handled by the early return above.
    print(f" 平均帧数: {sum(frame_counts) / total_valid:.2f}")

    print(f"\n🎯 关键统计:")
    print(f" 帧数 < 5: {less_than_5:6d} episodes ({less_than_5/total_valid*100:.2f}%)")
    print(f" 帧数 < 8: {less_than_8:6d} episodes ({less_than_8/total_valid*100:.2f}%)")
    print(f" 帧数 < 10: {less_than_10:6d} episodes ({less_than_10/total_valid*100:.2f}%)")
    print(f" 帧数 >= 10: {total_valid-less_than_10:6d} episodes ({(total_valid-less_than_10)/total_valid*100:.2f}%)")

    frame_counts.sort()
    print(f"\n📊 详细帧数分布:")

    # (inclusive_min, inclusive_max, label) buckets.
    ranges = [
        (1, 4, "1-4帧"),
        (5, 7, "5-7帧"),
        (8, 9, "8-9帧"),
        (10, 19, "10-19帧"),
        (20, 49, "20-49帧"),
        (50, 99, "50-99帧"),
        (100, float('inf'), "100+帧")
    ]

    for min_f, max_f, label in ranges:
        count = sum(1 for fc in frame_counts if min_f <= fc <= max_f)
        percentage = count / total_valid * 100
        print(f" {label:8s}: {count:6d} episodes ({percentage:5.2f}%)")

    # Recommended training configuration given a 4x temporal compression.
    print(f"\n💡 训练配置建议:")
    time_compression_ratio = 4
    min_condition_compressed = 4 // time_compression_ratio  # 1 compressed frame
    target_frames_compressed = 32 // time_compression_ratio  # 8 compressed frames
    min_required_compressed = min_condition_compressed + target_frames_compressed  # 9

    usable_episodes = sum(1 for fc in frame_counts if fc >= min_required_compressed)
    usable_percentage = usable_episodes / total_valid * 100

    print(f" 最小条件帧数(压缩后): {min_condition_compressed}")
    print(f" 目标帧数(压缩后): {target_frames_compressed}")
    print(f" 最小所需帧数(压缩后): {min_required_compressed}")
    print(f" 可用于训练的episodes: {usable_episodes} ({usable_percentage:.2f}%)")

    # Persist the full report next to the data.
    output_file = os.path.join(dataset_path, "frame_count_analysis.txt")
    with open(output_file, 'w') as f:
        import datetime as _dt  # local import replaces the original __import__ hack
        f.write(f"OpenX Dataset Frame Count Analysis\n")
        f.write(f"Dataset Path: {dataset_path}\n")
        f.write(f"Analysis Date: {_dt.datetime.now()}\n\n")

        f.write(f"Total Episodes: {total_episodes}\n")
        f.write(f"Valid Episodes: {total_valid}\n")
        f.write(f"Error Episodes: {error_count}\n\n")

        f.write(f"Frame Count Statistics:\n")
        f.write(f" Min Frames: {min(frame_counts)}\n")
        f.write(f" Max Frames: {max(frame_counts)}\n")
        f.write(f" Avg Frames: {sum(frame_counts) / total_valid:.2f}\n\n")

        f.write(f"Key Statistics:\n")
        f.write(f" < 5 frames: {less_than_5} ({less_than_5/total_valid*100:.2f}%)\n")
        f.write(f" < 8 frames: {less_than_8} ({less_than_8/total_valid*100:.2f}%)\n")
        f.write(f" < 10 frames: {less_than_10} ({less_than_10/total_valid*100:.2f}%)\n")
        f.write(f" >= 10 frames: {total_valid-less_than_10} ({(total_valid-less_than_10)/total_valid*100:.2f}%)\n\n")

        f.write(f"Detailed Distribution:\n")
        for min_f, max_f, label in ranges:
            count = sum(1 for fc in frame_counts if min_f <= fc <= max_f)
            percentage = count / total_valid * 100
            f.write(f" {label}: {count} ({percentage:.2f}%)\n")

        f.write(f"\nTraining Configuration Recommendation:\n")
        f.write(f" Usable Episodes (>= {min_required_compressed} compressed frames): {usable_episodes} ({usable_percentage:.2f}%)\n")

        # Dump every frame count, 20 per line.
        f.write(f"\nAll Frame Counts:\n")
        for i, count in enumerate(frame_counts):
            f.write(f"{count}")
            if (i + 1) % 20 == 0:
                f.write("\n")
            else:
                f.write(", ")

    print(f"\n💾 详细统计已保存到: {output_file}")

    return {
        'total_valid': total_valid,
        'less_than_10': less_than_10,
        'less_than_8': less_than_8,
        'less_than_5': less_than_5,
        'frame_counts': frame_counts,
        'usable_episodes': usable_episodes
    }
|
| 168 |
+
|
| 169 |
+
def quick_sample_analysis(dataset_path, sample_size=1000):
    """Estimate the frame-count distribution from a random sample of episodes.

    Intended as a fast preview for very large datasets: loads at most
    ``sample_size`` randomly chosen ``encoded_video.pth`` files, reports the
    sampled distribution and extrapolates it to the full dataset.
    """
    print(f"🚀 快速采样分析 (样本数: {sample_size})")

    episode_dirs = []
    for item in os.listdir(dataset_path):
        episode_dir = os.path.join(dataset_path, item)
        if os.path.isdir(episode_dir):
            encoded_path = os.path.join(episode_dir, "encoded_video.pth")
            if os.path.exists(encoded_path):
                episode_dirs.append(episode_dir)

    if len(episode_dirs) == 0:
        print("❌ 没有找到有效的episode")
        return

    # Random sampling without replacement, capped at the dataset size.
    import random
    sample_dirs = random.sample(episode_dirs, min(sample_size, len(episode_dirs)))

    frame_counts = []
    less_than_10 = 0

    for episode_dir in tqdm(sample_dirs, desc="采样分析"):
        try:
            encoded_data = torch.load(
                os.path.join(episode_dir, "encoded_video.pth"),
                weights_only=False,
                map_location="cpu"
            )

            frame_count = encoded_data['latents'].shape[1]
            frame_counts.append(frame_count)

            if frame_count < 10:
                less_than_10 += 1

        except Exception:
            # Quick mode deliberately skips unreadable episodes.
            continue

    total_sample = len(frame_counts)
    # Bug fix: the original divided by total_sample even when every sampled
    # episode failed to load, raising ZeroDivisionError.
    if total_sample == 0:
        print("❌ 采样的episode全部加载失败")
        return

    percentage_less_than_10 = less_than_10 / total_sample * 100

    print(f"📊 采样结果:")
    print(f" 采样数量: {total_sample}")
    print(f" < 10帧: {less_than_10} ({percentage_less_than_10:.2f}%)")
    print(f" >= 10帧: {total_sample - less_than_10} ({100 - percentage_less_than_10:.2f}%)")
    print(f" 平均帧数: {sum(frame_counts) / total_sample:.2f}")

    # Extrapolate the sampled ratio to the whole dataset.
    total_episodes = len(episode_dirs)
    estimated_less_than_10 = int(total_episodes * percentage_less_than_10 / 100)

    print(f"\n🔮 全数据集估算:")
    print(f" 总episodes: {total_episodes}")
    print(f" 估算 < 10帧: {estimated_less_than_10} ({percentage_less_than_10:.2f}%)")
    print(f" 估算 >= 10帧: {total_episodes - estimated_less_than_10} ({100 - percentage_less_than_10:.2f}%)")
|
| 227 |
+
|
| 228 |
+
if __name__ == "__main__":
    import argparse

    # --quick runs the sampled estimator; the default is the full scan.
    cli = argparse.ArgumentParser(description="分析OpenX数据集的帧数分布")
    cli.add_argument("--dataset_path", type=str,
                     default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded",
                     help="OpenX编码数据集路径")
    cli.add_argument("--quick", action="store_true", help="快速采样分析模式")
    cli.add_argument("--sample_size", type=int, default=1000, help="快速模式的采样数量")
    opts = cli.parse_args()

    if opts.quick:
        quick_sample_analysis(opts.dataset_path, opts.sample_size)
    else:
        analyze_openx_dataset_frame_counts(opts.dataset_path)
|
scripts/analyze_pose.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import numpy as np
|
| 5 |
+
from pose_classifier import PoseClassifier
|
| 6 |
+
import torch
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
|
| 9 |
+
def analyze_turning_patterns_detailed(dataset_path, num_samples=50):
    """Classify per-frame motion for up to ``num_samples`` samples.

    Each sample's poses are expressed relative to its reference pose; a
    PoseClassifier labels every frame (forward / backward / left_turn /
    right_turn), and a per-sample plus aggregate report is printed.

    Returns:
        (all_analyses, class_samples): the raw per-sample analyses and a
        dict grouping samples by their dominant motion class.
    """
    classifier = PoseClassifier()
    samples_path = os.path.join(dataset_path, "samples")

    all_analyses = []
    sample_count = 0
    class_samples = defaultdict(list)  # dominant class -> sample records

    print("=== 开始分析样本(基于相对于reference的变化)===")

    for item in sorted(os.listdir(samples_path)):  # sorted for deterministic output
        if sample_count >= num_samples:
            break

        # Guard clauses replace the original nested ifs: skip non-directories
        # and samples without a poses.json.
        sample_dir = os.path.join(samples_path, item)
        if not os.path.isdir(sample_dir):
            continue
        poses_path = os.path.join(sample_dir, "poses.json")
        if not os.path.exists(poses_path):
            continue

        try:
            with open(poses_path, 'r') as fh:
                poses_data = json.load(fh)

            target_relative_poses = poses_data['target_relative_poses']
            if len(target_relative_poses) == 0:
                continue

            # Build 7D pose vectors: [relative_translation(3), relative_rotation(4)],
            # everything already relative to the sample's reference pose.
            pose_vecs = []
            for pose_data in target_relative_poses:
                translation = torch.tensor(pose_data['relative_translation'], dtype=torch.float32)
                current_rotation = torch.tensor(pose_data['current_rotation'], dtype=torch.float32)
                reference_rotation = torch.tensor(pose_data['reference_rotation'], dtype=torch.float32)
                # q_relative = q_ref^-1 * q_current
                relative_rotation = calculate_relative_rotation(current_rotation, reference_rotation)
                pose_vecs.append(torch.cat([translation, relative_rotation], dim=0))

            if not pose_vecs:
                continue

            pose_sequence = torch.stack(pose_vecs, dim=0)

            analysis = classifier.analyze_pose_sequence(pose_sequence)
            analysis['sample_name'] = item
            all_analyses.append(analysis)

            # Per-sample report.
            print(f"\n--- 样本 {sample_count + 1}: {item} ---")
            print(f"总帧数: {analysis['total_frames']}")
            print(f"总距离: {analysis['total_distance']:.4f}")

            class_dist = analysis['class_distribution']
            print(f"分类分布:")
            for class_name, count in class_dist.items():
                percentage = count / analysis['total_frames'] * 100
                print(f" {class_name}: {count} 帧 ({percentage:.1f}%)")

            # Debug trace for the first few frames.
            print(f"前3帧的详细分类过程:")
            for i in range(min(3, len(pose_vecs))):
                debug_info = classifier.debug_single_pose(
                    pose_vecs[i][:3], pose_vecs[i][3:7]
                )
                print(f" 帧{i}: {debug_info['classification']} "
                      f"(yaw: {debug_info['yaw_angle_deg']:.2f}°, "
                      f"forward: {debug_info['forward_movement']:.3f})")

            print(f"运动段落:")
            for i, segment in enumerate(analysis['motion_segments']):
                print(f" 段落{i+1}: {segment['class']} (帧 {segment['start_frame']}-{segment['end_frame']}, 持续 {segment['duration']} 帧)")

            # Dominant motion class decides the sample's grouping.
            dominant_class = max(class_dist.items(), key=lambda x: x[1])
            dominant_class_name = dominant_class[0]
            dominant_percentage = dominant_class[1] / analysis['total_frames'] * 100

            print(f"主要运动类型: {dominant_class_name} ({dominant_percentage:.1f}%)")

            class_samples[dominant_class_name].append({
                'name': item,
                'percentage': dominant_percentage,
                'analysis': analysis
            })

            sample_count += 1

        except Exception as e:
            print(f"❌ 处理样本 {item} 时出错: {e}")

    print("\n" + "="*60)
    print("=== 按类别分组的样本统计(基于相对于reference的变化)===")

    # Per-class sample listing, sorted by dominant-class share.
    for class_name in ['forward', 'backward', 'left_turn', 'right_turn']:
        samples = class_samples[class_name]
        print(f"\n🔸 {class_name.upper()} 类样本 (共 {len(samples)} 个):")

        if samples:
            samples.sort(key=lambda x: x['percentage'], reverse=True)
            for i, sample_info in enumerate(samples, 1):
                print(f" {i:2d}. {sample_info['name']} ({sample_info['percentage']:.1f}%)")

                segments = sample_info['analysis']['motion_segments']
                # Only segments lasting at least 2 frames are summarized.
                segment_summary = [f"{seg['class']}({seg['duration']})"
                                   for seg in segments if seg['duration'] >= 2]
                if segment_summary:
                    print(f" 段落: {' -> '.join(segment_summary)}")
        else:
            print(" (无样本)")

    print(f"\n" + "="*60)
    print("=== 总体统计 ===")

    total_forward = sum(a['class_distribution']['forward'] for a in all_analyses)
    total_backward = sum(a['class_distribution']['backward'] for a in all_analyses)
    total_left_turn = sum(a['class_distribution']['left_turn'] for a in all_analyses)
    total_right_turn = sum(a['class_distribution']['right_turn'] for a in all_analyses)
    total_frames = total_forward + total_backward + total_left_turn + total_right_turn

    print(f"总样本数: {len(all_analyses)}")
    print(f"总帧数: {total_frames}")
    print(f"Forward: {total_forward} 帧 ({total_forward/total_frames*100:.1f}%)")
    print(f"Backward: {total_backward} 帧 ({total_backward/total_frames*100:.1f}%)")
    print(f"Left Turn: {total_left_turn} 帧 ({total_left_turn/total_frames*100:.1f}%)")
    print(f"Right Turn: {total_right_turn} 帧 ({total_right_turn/total_frames*100:.1f}%)")

    print(f"\n按主要类型的样本分布:")
    for class_name in ['forward', 'backward', 'left_turn', 'right_turn']:
        count = len(class_samples[class_name])
        percentage = count / len(all_analyses) * 100 if all_analyses else 0
        print(f" {class_name}: {count} 样本 ({percentage:.1f}%)")

    return all_analyses, class_samples
|
| 160 |
+
|
| 161 |
+
def calculate_relative_rotation(current_rotation, reference_rotation):
    """Return the quaternion rotating the reference frame into the current one.

    Quaternions are ordered (w, x, y, z). The reference is inverted via its
    conjugate, which equals the true inverse only for unit quaternions —
    NOTE(review): assumes upstream poses are normalized; confirm at the caller.
    """
    w2, x2, y2, z2 = torch.tensor(current_rotation, dtype=torch.float32)
    ref = torch.tensor(reference_rotation, dtype=torch.float32)
    # Conjugate of the reference quaternion: negate the vector part.
    w1, x1, y1, z1 = ref[0], -ref[1], -ref[2], -ref[3]
    # Hamilton product: q_relative = q_ref^-1 * q_current.
    return torch.tensor([
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
    ])
|
| 181 |
+
|
| 182 |
+
if __name__ == "__main__":
    # Root of the nuScenes-derived dataset whose pose files are analyzed.
    dataset_path = "/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_2"

    # Run the detailed turning-pattern classification over up to 4000 samples;
    # analyze_turning_patterns_detailed is defined earlier in this file.
    print("开始详细分析pose分类(基于相对于reference的变化)...")
    all_analyses, class_samples = analyze_turning_patterns_detailed(dataset_path, num_samples=4000)

    print(f"\n🎉 分析完成! 共处理 {len(all_analyses)} 个样本")
|
scripts/batch_drone.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import random
import subprocess
import time

# Source directory containing per-sample subfolders with pre-encoded videos,
# and the directory where generated .mp4 files are written.
src_root = "/share_zhuyixuan05/zhuyixuan05/spatialvid"
dst_root = "/share_zhuyixuan05/zhuyixuan05/New_spatialvid_drone_first"
infer_script = "/home/zhuyixuan05/ReCamMaster/infer_origin.py"  # path to the inference entry point

# Make sure the output directory exists before the first write.
os.makedirs(dst_root, exist_ok=True)

while True:
    # Pick a random sample subfolder each iteration.
    subdirs = [d for d in os.listdir(src_root) if os.path.isdir(os.path.join(src_root, d))]
    if not subdirs:
        print("没有可用的子文件夹")
        break
    chosen = random.choice(subdirs)
    chosen_dir = os.path.join(src_root, chosen)
    pth_file = os.path.join(chosen_dir, "encoded_video.pth")
    if not os.path.exists(pth_file):
        print(f"{pth_file} 不存在,跳过")
        continue

    # Output file name derived from the sample folder name.
    out_file = os.path.join(dst_root, f"{chosen}.mp4")
    # Skip samples that were already generated (previously the loop re-ran and
    # overwrote them); back off briefly so the loop does not spin hot once
    # most samples are done.
    if os.path.exists(out_file):
        time.sleep(1)
        continue
    print(f"开始生成: {pth_file} -> {out_file}")

    # Build the inference command.
    cmd = [
        "python", infer_script,
        "--condition_pth", pth_file,
        "--output_path", out_file,
        "--prompt", "exploring the world",
        "--modality_type", "sekai",
        "--direction", "right",
        "--dit_path", "/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step25000_first.ckpt",
        "--use_gt_prompt"
    ]

    # Pin this worker to GPU 0.
    env = os.environ.copy()
    env["CUDA_VISIBLE_DEVICES"] = "0"

    # Run inference; a failing subprocess is tolerated and the loop continues.
    subprocess.run(cmd, env=env)
|
scripts/batch_infer.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import subprocess
|
| 3 |
+
import argparse
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import glob
|
| 6 |
+
|
| 7 |
+
def find_video_files(videos_dir):
    """Collect every video file (currently only .mp4) directly under *videos_dir*.

    Returns the matching paths sorted alphabetically.
    """
    matches = []
    for extension in ['.mp4']:
        matches += glob.glob(os.path.join(videos_dir, f"*{extension}"))
    return sorted(matches)
|
| 17 |
+
|
| 18 |
+
def run_inference(condition_video, direction, dit_path, output_dir):
    """Run one inference job for a single (video, direction) pair.

    Spawns ``infer_nus.py`` as a subprocess with the given condition video,
    motion direction and DiT checkpoint. The output is written to
    ``output_dir`` as ``<stem>_<direction><ext>``.

    Returns:
        True on success, False when the subprocess exits non-zero.
    """
    # Derive the output file name: input stem + direction tag + original ext.
    input_filename = os.path.basename(condition_video)
    name_parts = os.path.splitext(input_filename)
    output_filename = f"{name_parts[0]}_{direction}{name_parts[1]}"
    output_path = os.path.join(output_dir, output_filename)

    # Assemble the inference command. NOTE(review): "infer_nus.py" is a
    # relative path — this relies on the current working directory.
    cmd = [
        "python", "infer_nus.py",
        "--condition_video", condition_video,
        "--direction", direction,
        "--dit_path", dit_path,
        "--output_path", output_path,
    ]

    print(f"🎬 生成 {direction} 方向视频: {input_filename} -> {output_filename}")
    print(f" 命令: {' '.join(cmd)}")

    try:
        # check=True raises CalledProcessError on a non-zero exit code;
        # stdout/stderr are captured (``result`` is currently unused).
        result = subprocess.run(cmd, capture_output=True, text=True, check=True)
        print(f"✅ 成功生成: {output_path}")
        return True
    except subprocess.CalledProcessError as e:
        print(f"❌ 生成失败: {e}")
        print(f" 错误输出: {e.stderr}")
        return False
|
| 47 |
+
|
| 48 |
+
def batch_inference(args):
    """Batch-inference driver: generate every requested direction for every
    video found in ``args.videos_dir``, writing results to ``args.output_dir``.

    Existing outputs are skipped unless ``args.overwrite`` is set. A summary
    (success/failure counts and the generated files) is printed at the end.
    """
    videos_dir = args.videos_dir
    output_dir = args.output_dir
    directions = args.directions
    dit_path = args.dit_path

    # Validate the input directory before doing any work.
    if not os.path.exists(videos_dir):
        print(f"❌ 视频目录不存在: {videos_dir}")
        return

    # Create the output directory (idempotent).
    os.makedirs(output_dir, exist_ok=True)
    print(f"📁 输出目录: {output_dir}")

    # Discover all candidate videos (see find_video_files).
    video_files = find_video_files(videos_dir)

    if not video_files:
        print(f"❌ 在 {videos_dir} 中没有找到视频文件")
        return

    print(f"🎥 找到 {len(video_files)} 个视频文件:")
    for video in video_files:
        print(f" - {os.path.basename(video)}")

    print(f"🎯 将为每个视频生成以下方向: {', '.join(directions)}")
    print(f"📊 总共将生成 {len(video_files) * len(directions)} 个视频")

    # Counters for the final summary. NOTE(review): a pre-existing (skipped)
    # output is counted as "completed", so the progress line can reach the
    # total without any new work having run.
    total_tasks = len(video_files) * len(directions)
    completed_tasks = 0
    failed_tasks = 0

    # Outer loop over videos, inner loop over requested directions.
    for i, video_file in enumerate(video_files, 1):
        print(f"\n{'='*60}")
        print(f"处理视频 {i}/{len(video_files)}: {os.path.basename(video_file)}")
        print(f"{'='*60}")

        for j, direction in enumerate(directions, 1):
            print(f"\n--- 方向 {j}/{len(directions)}: {direction} ---")

            # Recompute the expected output path (same rule as run_inference)
            # so already-generated files can be skipped.
            input_filename = os.path.basename(video_file)
            name_parts = os.path.splitext(input_filename)
            output_filename = f"{name_parts[0]}_{direction}{name_parts[1]}"
            output_path = os.path.join(output_dir, output_filename)

            if os.path.exists(output_path) and not args.overwrite:
                print(f"⏭️ 文件已存在,跳过: {output_filename}")
                completed_tasks += 1
                continue

            # Delegate the actual subprocess call to run_inference.
            success = run_inference(
                condition_video=video_file,
                direction=direction,
                dit_path=dit_path,
                output_dir=output_dir,
            )

            if success:
                completed_tasks += 1
            else:
                failed_tasks += 1

            # Progress line after each task.
            current_progress = completed_tasks + failed_tasks
            print(f"📈 进度: {current_progress}/{total_tasks} "
                  f"(成功: {completed_tasks}, 失败: {failed_tasks})")

    # Final summary.
    print(f"\n{'='*60}")
    print(f"🎉 批量推理完成!")
    print(f"📊 总任务数: {total_tasks}")
    print(f"✅ 成功: {completed_tasks}")
    print(f"❌ 失败: {failed_tasks}")
    print(f"📁 输出目录: {output_dir}")

    if failed_tasks > 0:
        print(f"⚠️ 有 {failed_tasks} 个任务失败,请检查日志")

    # List every .mp4 now present in the output directory (this includes
    # files generated by earlier runs, not only this invocation).
    if completed_tasks > 0:
        print(f"\n📋 生成的文件:")
        generated_files = glob.glob(os.path.join(output_dir, "*.mp4"))
        for file_path in sorted(generated_files):
            print(f" - {os.path.basename(file_path)}")
|
| 138 |
+
|
| 139 |
+
def main():
    """Parse CLI arguments and either preview (--dry_run) or run the batch.

    All defaults point at the author's nuScenes paths; override them on the
    command line for other datasets.
    """
    parser = argparse.ArgumentParser(description="批量对nus/videos目录下的所有视频生成不同方向的输出")

    # Input directory scanned for .mp4 files.
    parser.add_argument("--videos_dir", type=str, default="/home/zhuyixuan05/ReCamMaster/nus/videos/4032",
                        help="输入视频目录路径")

    # Output directory (created if missing by batch_inference).
    parser.add_argument("--output_dir", type=str, default="nus/infer_results/batch_dynamic_4032_noise",
                        help="输出视频目录路径")

    # One or more motion directions; each video is rendered once per direction.
    parser.add_argument("--directions", nargs="+",
                        default=["left_turn", "right_turn"],
                        choices=["forward", "backward", "left_turn", "right_turn"],
                        help="要生成的方向列表")

    # Checkpoint passed through to the per-task inference script.
    parser.add_argument("--dit_path", type=str, default="/home/zhuyixuan05/ReCamMaster/nus_dynamic/step15000_dynamic.ckpt",
                        help="训练好的DiT模型路径")

    parser.add_argument("--overwrite", action="store_true",
                        help="是否覆盖已存在的输出文件")

    parser.add_argument("--dry_run", action="store_true",
                        help="只显示将要执行的任务,不实际运行")

    args = parser.parse_args()

    if args.dry_run:
        # Preview mode: enumerate the planned tasks without running anything.
        print("🔍 预览模式 - 只显示任务,不执行")
        videos_dir = args.videos_dir
        video_files = find_video_files(videos_dir)

        print(f"📁 输入目录: {videos_dir}")
        print(f"📁 输出目录: {args.output_dir}")
        print(f"🎥 找到视频: {len(video_files)} 个")
        print(f"🎯 生成方向: {', '.join(args.directions)}")
        print(f"📊 总任务数: {len(video_files) * len(args.directions)}")

        print(f"\n将要执行的任务:")
        for video in video_files:
            for direction in args.directions:
                # Mirror the output-naming rule used by run_inference.
                input_name = os.path.basename(video)
                name_parts = os.path.splitext(input_name)
                output_name = f"{name_parts[0]}_{direction}{name_parts[1]}"
                print(f" {input_name} -> {output_name} ({direction})")
    else:
        batch_inference(args)

if __name__ == "__main__":
    main()
|
scripts/batch_nus.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import random
import subprocess
import time

# Source directory containing per-scene subfolders with encoded videos, and
# the directory where generated .mp4 files are written.
src_root = "/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic/scenes"
dst_root = "/share_zhuyixuan05/zhuyixuan05/New_nus_right_2"
infer_script = "/home/zhuyixuan05/ReCamMaster/infer_moe.py"  # path to the inference entry point

# Make sure the output directory exists before the first write.
os.makedirs(dst_root, exist_ok=True)

while True:
    # Pick a random scene subfolder each iteration.
    subdirs = [d for d in os.listdir(src_root) if os.path.isdir(os.path.join(src_root, d))]
    if not subdirs:
        print("没有可用的子文件夹")
        break
    chosen = random.choice(subdirs)
    chosen_dir = os.path.join(src_root, chosen)
    pth_file = os.path.join(chosen_dir, "encoded_video-480p.pth")
    if not os.path.exists(pth_file):
        print(f"{pth_file} 不存在,跳过")
        continue

    # Output file name derived from the scene folder name.
    out_file = os.path.join(dst_root, f"{chosen}.mp4")
    # Skip scenes that were already generated (previously the loop re-ran and
    # overwrote them); back off briefly to avoid a hot spin near completion.
    if os.path.exists(out_file):
        time.sleep(1)
        continue
    print(f"开始生成: {pth_file} -> {out_file}")

    # Build the inference command.
    cmd = [
        "python", infer_script,
        "--condition_pth", pth_file,
        "--output_path", out_file,
        "--prompt", "a car is driving",
        "--modality_type", "nuscenes",
        "--dit_path", "/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step175000_origin_other_continue3.ckpt"
    ]

    # Pin this worker to GPU 1.
    env = os.environ.copy()
    env["CUDA_VISIBLE_DEVICES"] = "1"

    # Run inference; a failing subprocess is tolerated and the loop continues.
    subprocess.run(cmd, env=env)
|
scripts/batch_rt.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import random
import subprocess
import time

# Source directory containing per-episode subfolders with encoded videos, and
# the directory where generated .mp4 files are written.
src_root = "/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded"
dst_root = "/share_zhuyixuan05/zhuyixuan05/New_RT"
infer_script = "/home/zhuyixuan05/ReCamMaster/infer_moe.py"  # path to the inference entry point

# Make sure the output directory exists before the first write.
os.makedirs(dst_root, exist_ok=True)

while True:
    # Pick a random episode subfolder each iteration.
    subdirs = [d for d in os.listdir(src_root) if os.path.isdir(os.path.join(src_root, d))]
    if not subdirs:
        print("没有可用的子文件夹")
        break
    chosen = random.choice(subdirs)
    chosen_dir = os.path.join(src_root, chosen)
    pth_file = os.path.join(chosen_dir, "encoded_video.pth")
    if not os.path.exists(pth_file):
        print(f"{pth_file} 不存在,跳过")
        continue

    # Output file name derived from the episode folder name.
    out_file = os.path.join(dst_root, f"{chosen}.mp4")
    # Skip episodes that were already generated (previously the loop re-ran
    # and overwrote them); back off briefly to avoid a hot spin.
    if os.path.exists(out_file):
        time.sleep(1)
        continue
    print(f"开始生成: {pth_file} -> {out_file}")

    # Build the inference command (uses the script's default --dit_path).
    cmd = [
        "python", infer_script,
        "--condition_pth", pth_file,
        "--output_path", out_file,
        "--prompt", "A robotic arm is moving the object",
        "--modality_type", "openx",
    ]

    # Pin this worker to GPU 1.
    env = os.environ.copy()
    env["CUDA_VISIBLE_DEVICES"] = "1"

    # Run inference; a failing subprocess is tolerated and the loop continues.
    subprocess.run(cmd, env=env)
|
scripts/batch_spa.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import random
import subprocess
import time

# Source directory containing per-sample subfolders with encoded videos, and
# the directory where generated .mp4 files are written.
src_root = "/share_zhuyixuan05/zhuyixuan05/spatialvid"
dst_root = "/share_zhuyixuan05/zhuyixuan05/New_spatialvid_right"
infer_script = "/home/zhuyixuan05/ReCamMaster/infer_moe.py"  # path to the inference entry point

# Make sure the output directory exists before the first write.
os.makedirs(dst_root, exist_ok=True)

while True:
    # Pick a random sample subfolder each iteration.
    subdirs = [d for d in os.listdir(src_root) if os.path.isdir(os.path.join(src_root, d))]
    if not subdirs:
        print("没有可用的子文件夹")
        break
    chosen = random.choice(subdirs)
    chosen_dir = os.path.join(src_root, chosen)
    pth_file = os.path.join(chosen_dir, "encoded_video.pth")
    if not os.path.exists(pth_file):
        print(f"{pth_file} 不存在,跳过")
        continue

    # Output file name derived from the sample folder name.
    out_file = os.path.join(dst_root, f"{chosen}.mp4")
    # Skip samples that were already generated (previously the loop re-ran
    # and overwrote them); back off briefly to avoid a hot spin.
    if os.path.exists(out_file):
        time.sleep(1)
        continue
    print(f"开始生成: {pth_file} -> {out_file}")

    # Build the inference command.
    cmd = [
        "python", infer_script,
        "--condition_pth", pth_file,
        "--output_path", out_file,
        "--prompt", "exploring the world",
        "--modality_type", "sekai",
        #"--direction", "left",
        "--dit_path", "/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step175000_origin_other_continue3.ckpt"
    ]

    # Pin this worker to GPU 0.
    env = os.environ.copy()
    env["CUDA_VISIBLE_DEVICES"] = "0"

    # Run inference; a failing subprocess is tolerated and the loop continues.
    subprocess.run(cmd, env=env)
|
scripts/batch_walk.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import random
import subprocess
import time

# Source directory containing per-scene subfolders with encoded videos, and
# the directory where generated .mp4 files are written.
src_root = "/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic/scenes"
dst_root = "/share_zhuyixuan05/zhuyixuan05/New_walk"
infer_script = "/home/zhuyixuan05/ReCamMaster/infer_moe.py"  # path to the inference entry point

# Make sure the output directory exists before the first write.
os.makedirs(dst_root, exist_ok=True)

while True:
    # Pick a random scene subfolder each iteration.
    subdirs = [d for d in os.listdir(src_root) if os.path.isdir(os.path.join(src_root, d))]
    if not subdirs:
        print("没有可用的子文件夹")
        break
    chosen = random.choice(subdirs)
    chosen_dir = os.path.join(src_root, chosen)
    pth_file = os.path.join(chosen_dir, "encoded_video-480p.pth")
    if not os.path.exists(pth_file):
        print(f"{pth_file} 不存在,跳过")
        continue

    # Output file name derived from the scene folder name.
    out_file = os.path.join(dst_root, f"{chosen}.mp4")
    # Skip scenes that were already generated (previously the loop re-ran
    # and overwrote them); back off briefly to avoid a hot spin.
    if os.path.exists(out_file):
        time.sleep(1)
        continue
    print(f"开始生成: {pth_file} -> {out_file}")

    # Build the inference command.
    cmd = [
        "python", infer_script,
        "--condition_pth", pth_file,
        "--output_path", out_file,
        "--prompt", "a car is driving",
        "--modality_type", "nuscenes",
        "--dit_path", "/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step175000_origin_other_continue3.ckpt"
    ]

    # Pin this worker to GPU 1.
    env = os.environ.copy()
    env["CUDA_VISIBLE_DEVICES"] = "1"

    # Run inference; a failing subprocess is tolerated and the loop continues.
    subprocess.run(cmd, env=env)
|
scripts/check.py
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import os
|
| 3 |
+
import argparse
|
| 4 |
+
from collections import defaultdict
|
| 5 |
+
import time
|
| 6 |
+
|
| 7 |
+
def load_checkpoint(ckpt_path):
    """Load a checkpoint state dict from *ckpt_path* onto the CPU.

    Returns None when the file does not exist or torch fails to load it
    (a failure message is printed in the latter case).
    """
    if not os.path.exists(ckpt_path):
        return None
    try:
        return torch.load(ckpt_path, map_location='cpu')
    except Exception as e:
        print(f"❌ 加载检查点失败: {e}")
        return None
|
| 18 |
+
|
| 19 |
+
def compare_parameters(state_dict1, state_dict2, threshold=1e-8):
    """Compare two state dicts and split shared parameters by change size.

    Args:
        state_dict1: first state dict mapping parameter name -> tensor.
        state_dict2: second state dict; only names present in both are
            compared (extra keys in either dict are silently ignored).
        threshold: maximum absolute element difference above which a
            parameter counts as "updated".

    Returns:
        A pair ``(updated_params, unchanged_params)`` of dicts mapping
        parameter name to ``{'max_diff', 'mean_diff', 'shape'}``, or
        ``(None, None)`` when either input is None.
    """
    # Bug fix: this previously returned a bare None, which raised TypeError
    # in every caller that tuple-unpacks the result; return a pair instead.
    if state_dict1 is None or state_dict2 is None:
        return None, None

    updated_params = {}
    unchanged_params = {}

    for name, param1 in state_dict1.items():
        if name not in state_dict2:
            continue
        param2 = state_dict2[name]

        # Element-wise absolute difference statistics.
        diff = torch.abs(param1 - param2)
        max_diff = torch.max(diff).item()
        mean_diff = torch.mean(diff).item()

        info = {
            'max_diff': max_diff,
            'mean_diff': mean_diff,
            'shape': param1.shape,
        }
        if max_diff > threshold:
            updated_params[name] = info
        else:
            unchanged_params[name] = info

    return updated_params, unchanged_params
|
| 50 |
+
|
| 51 |
+
def categorize_parameters(param_dict):
    """Bucket parameter-info entries by component type via name keywords.

    Each name is tested against the rule table in order; the first bucket
    whose keywords match wins, and anything unmatched lands in 'other'.
    """
    rules = [
        ('moe_related', ('moe', 'gate', 'expert', 'processor')),
        ('camera_related', ('cam_encoder', 'projector', 'camera')),
        ('framepack_related', ('clean_x_embedder', 'framepack')),
        ('attention', ('attn', 'attention')),
    ]
    categories = {bucket: {} for bucket, _ in rules}
    categories['other'] = {}

    for name, info in param_dict.items():
        lowered = name.lower()
        for bucket, keywords in rules:
            if any(kw in lowered for kw in keywords):
                categories[bucket][name] = info
                break
        else:
            categories['other'][name] = info

    return categories
|
| 74 |
+
|
| 75 |
+
def print_category_summary(category_name, params, color_code=''):
    """Print a summary for one parameter category.

    Args:
        category_name: human-readable category label.
        params: dict mapping parameter name -> {'max_diff', 'mean_diff',
            'shape'} as produced by compare_parameters.
        color_code: emoji/prefix string prepended to each header line.
    """
    if not params:
        print(f"{color_code} {category_name}: 无参数")
        return

    total_params = len(params)
    max_diffs = [info['max_diff'] for info in params.values()]
    mean_diffs = [info['mean_diff'] for info in params.values()]

    print(f"{color_code} {category_name} ({total_params} 个参数):")
    print(f" 最大差异范围: {min(max_diffs):.2e} ~ {max(max_diffs):.2e}")
    print(f" 平均差异范围: {min(mean_diffs):.2e} ~ {max(mean_diffs):.2e}")

    # Show the parameters with the largest max_diff (up to 100 entries —
    # the original comment said "top 5" but the code slices [:100]).
    sorted_params = sorted(params.items(), key=lambda x: x[1]['max_diff'], reverse=True)
    print(f" 变化最大的参数:")
    for i, (name, info) in enumerate(sorted_params[:100]):
        shape_str = 'x'.join(map(str, info['shape']))
        print(f" {i+1}. {name} [{shape_str}]: max_diff={info['max_diff']:.2e}")
|
| 95 |
+
|
| 96 |
+
def monitor_training(checkpoint_dir, check_interval=60):
    """Poll *checkpoint_dir* forever and report parameter changes between
    consecutive ``step<N>.ckpt`` files as they appear.

    Runs until interrupted with Ctrl-C; any other error is printed and the
    loop continues after the next sleep.

    Args:
        checkpoint_dir: directory containing step-numbered .ckpt files.
        check_interval: seconds to sleep between polls.
    """
    print(f"🔍 开始监控训练进度...")
    print(f"📁 检查点目录: {checkpoint_dir}")
    print(f"⏰ 检查间隔: {check_interval}秒")
    print("=" * 80)

    # State carried across iterations: the last-loaded state dict and step.
    previous_ckpt = None
    previous_step = -1

    while True:
        try:
            # The directory may not exist yet when training starts later.
            if not os.path.exists(checkpoint_dir):
                print(f"❌ 检查点目录不存在: {checkpoint_dir}")
                time.sleep(check_interval)
                continue

            # Only files of the exact form step<N>.ckpt are considered.
            ckpt_files = [f for f in os.listdir(checkpoint_dir) if f.startswith('step') and f.endswith('.ckpt')]
            if not ckpt_files:
                print("⏳ 未找到检查点文件,等待中...")
                time.sleep(check_interval)
                continue

            # Sort numerically by step and take the newest.
            # NOTE(review): this int() parse raises for names like
            # "step25000_first.ckpt", which do exist in this repo's
            # checkpoint dirs — the except below will catch and retry.
            ckpt_files.sort(key=lambda x: int(x.replace('step', '').replace('.ckpt', '')))
            latest_ckpt_file = ckpt_files[-1]
            latest_ckpt_path = os.path.join(checkpoint_dir, latest_ckpt_file)

            # Extract the step number from the file name.
            current_step = int(latest_ckpt_file.replace('step', '').replace('.ckpt', ''))

            if current_step <= previous_step:
                print(f"⏳ 等待新的检查点... (当前: step{current_step})")
                time.sleep(check_interval)
                continue

            print(f"\n🔍 发现新检查点: {latest_ckpt_file}")

            # Load the new checkpoint (None on failure).
            current_state_dict = load_checkpoint(latest_ckpt_path)
            if current_state_dict is None:
                print("❌ 无法加载当前检查点")
                time.sleep(check_interval)
                continue

            # Nothing to diff on the very first successful load.
            if previous_ckpt is not None:
                print(f"📊 比较 step{previous_step} -> step{current_step}")

                # Diff the two snapshots element-wise.
                updated_params, unchanged_params = compare_parameters(
                    previous_ckpt, current_state_dict, threshold=1e-8
                )

                if updated_params is None:
                    print("❌ 参数比较失败")
                else:
                    # Group results by component and print both sides.
                    updated_categories = categorize_parameters(updated_params)
                    unchanged_categories = categorize_parameters(unchanged_params)

                    print(f"\n✅ 已更新的参数 (总共 {len(updated_params)} 个):")
                    print_category_summary("MoE相关", updated_categories['moe_related'], '🔥')
                    print_category_summary("Camera相关", updated_categories['camera_related'], '📷')
                    print_category_summary("FramePack相关", updated_categories['framepack_related'], '🎞️')
                    print_category_summary("注意力相关", updated_categories['attention'], '👁️')
                    print_category_summary("其他", updated_categories['other'], '📦')

                    print(f"\n⚠️ 未更新的参数 (总共 {len(unchanged_params)} 个):")
                    print_category_summary("MoE相关", unchanged_categories['moe_related'], '❄️')
                    print_category_summary("Camera相关", unchanged_categories['camera_related'], '❄️')
                    print_category_summary("FramePack相关", unchanged_categories['framepack_related'], '❄️')
                    print_category_summary("注意力相关", unchanged_categories['attention'], '❄️')
                    print_category_summary("其他", unchanged_categories['other'], '❄️')

                    # Warn if none of the trainable components of interest moved.
                    critical_keywords = ['moe', 'cam_encoder', 'projector', 'clean_x_embedder']
                    critical_updated = any(
                        any(keyword in name.lower() for keyword in critical_keywords)
                        for name in updated_params.keys()
                    )

                    if critical_updated:
                        print("\n✅ 关键组件正在更新!")
                    else:
                        print("\n❌ 警告:关键组件可能未在更新!")

                    # Overall fraction of parameters that changed.
                    total_params = len(updated_params) + len(unchanged_params)
                    update_rate = len(updated_params) / total_params * 100
                    print(f"\n📈 参数更新率: {update_rate:.1f}% ({len(updated_params)}/{total_params})")

            # Remember this snapshot for the next comparison.
            previous_ckpt = current_state_dict
            previous_step = current_step

            print("=" * 80)
            time.sleep(check_interval)

        except KeyboardInterrupt:
            print("\n👋 监控已停止")
            break
        except Exception as e:
            # Best-effort monitor: report and keep polling.
            print(f"❌ 监控过程中出错: {e}")
            time.sleep(check_interval)
|
| 201 |
+
|
| 202 |
+
def compare_two_checkpoints(ckpt1_path, ckpt2_path):
    """Load two specific checkpoints, diff their parameters and print a
    per-category summary plus the overall update rate.

    Uses compare_parameters' default threshold (1e-8).
    """
    print(f"🔍 比较两个检查点:")
    print(f" 检查点1: {ckpt1_path}")
    print(f" 检查点2: {ckpt2_path}")
    print("=" * 80)

    # Load both snapshots (None on failure).
    state_dict1 = load_checkpoint(ckpt1_path)
    state_dict2 = load_checkpoint(ckpt2_path)

    if state_dict1 is None or state_dict2 is None:
        print("❌ 无法加载检查点文件")
        return

    # Element-wise diff of all shared parameters.
    updated_params, unchanged_params = compare_parameters(state_dict1, state_dict2)

    if updated_params is None:
        print("❌ 参数比较失败")
        return

    # Group both result sets by component type.
    updated_categories = categorize_parameters(updated_params)
    unchanged_categories = categorize_parameters(unchanged_params)

    print(f"\n✅ 已更新的参数 (总共 {len(updated_params)} 个):")
    for category_name, params in updated_categories.items():
        print_category_summary(category_name.replace('_', ' ').title(), params, '🔥')

    print(f"\n⚠️ 未更新的参数 (总共 {len(unchanged_params)} 个):")
    for category_name, params in unchanged_categories.items():
        print_category_summary(category_name.replace('_', ' ').title(), params, '❄️')

    # Overall fraction of parameters that changed.
    total_params = len(updated_params) + len(unchanged_params)
    update_rate = len(updated_params) / total_params * 100
    print(f"\n📈 参数更新率: {update_rate:.1f}% ({len(updated_params)}/{total_params})")
|
| 240 |
+
|
| 241 |
+
if __name__ == '__main__':
    def _str_to_bool(value):
        # Accept common false spellings from the CLI; everything else is True.
        return str(value).strip().lower() not in ('0', 'false', 'no', 'off', '')

    parser = argparse.ArgumentParser(description="检查模型参数更新情况")
    parser.add_argument("--checkpoint_dir", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe",
                        help="检查点目录路径")
    # Bug fix: this option previously used a bare `default=True` with no
    # type/action, so `--compare False` passed the truthy string "False" and
    # monitor mode was unreachable from the command line. Parse it as a
    # boolean instead; the default (compare mode) is unchanged.
    parser.add_argument("--compare", type=_str_to_bool, default=True,
                        help="比较两个特定检查点,而不是监控")
    parser.add_argument("--ckpt1", type=str, default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step1500_origin_cam_4.ckpt")
    parser.add_argument("--ckpt2", type=str, default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step500_origin_cam_4.ckpt")
    parser.add_argument("--interval", type=int, default=60,
                        help="监控检查间隔(秒)")
    # NOTE(review): --threshold is parsed but never forwarded;
    # compare_two_checkpoints always uses compare_parameters' default.
    parser.add_argument("--threshold", type=float, default=1e-8,
                        help="参数变化阈值")

    args = parser.parse_args()

    if args.compare:
        if not args.ckpt1 or not args.ckpt2:
            print("❌ 比较模式需要指定 --ckpt1 和 --ckpt2")
        else:
            compare_two_checkpoints(args.ckpt1, args.ckpt2)
    else:
        monitor_training(args.checkpoint_dir, args.interval)
|
scripts/decode_openx.py
ADDED
|
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import imageio
|
| 6 |
+
import argparse
|
| 7 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
import json
|
| 10 |
+
|
| 11 |
+
class VideoDecoder:
    """Decode WanVideo VAE latents back to RGB video for visual verification.

    Wraps a WanVideoReCamMasterPipeline but only uses its VAE component.
    """

    def __init__(self, vae_path, device="cuda"):
        """Initialize the video decoder.

        Args:
            vae_path: path to the VAE weights file.
            device: compute device used for decoding (default "cuda").
        """
        self.device = device

        # Load weights on CPU first, then move to the target device.
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        model_manager.load_models([vae_path])

        # Build the pipeline; only the VAE part is exercised below.
        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
        self.pipe = self.pipe.to(device)

        # Ensure the VAE and all of its submodules are on the target device
        # (some pipeline wrappers keep an inner `.model` that `.to()` misses).
        self.pipe.vae = self.pipe.vae.to(device)
        if hasattr(self.pipe.vae, 'model'):
            self.pipe.vae.model = self.pipe.vae.model.to(device)

        print(f"✅ VAE解码器初始化完成,设备: {device}")

    def decode_latents_to_video(self, latents, output_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        """Decode latents to an mp4 video file.

        Handles several possible output layouts from the VAE by probing for the
        RGB channel dimension, then denormalizes from [-1, 1] to uint8 and
        writes frames with imageio.

        Args:
            latents: latent tensor, either [C, T, H, W] or [B, C, T, H, W].
            output_path: destination mp4 path (parent dirs are created).
            tiled: whether to attempt tiled decoding first.
            tile_size: spatial tile size for tiled decoding.
            tile_stride: spatial tile stride for tiled decoding.

        Returns:
            The decoded video as a uint8 numpy array of shape [T, H, W, C].

        Raises:
            ValueError: if no RGB channel dimension can be identified.
        """
        print(f"🔧 开始解码latents...")
        print(f"输入latents形状: {latents.shape}")
        print(f"输入latents设备: {latents.device}")
        print(f"输入latents数据类型: {latents.dtype}")

        # Ensure latents carry a batch dimension.
        if len(latents.shape) == 4:  # [C, T, H, W]
            latents = latents.unsqueeze(0)  # -> [1, C, T, H, W]

        # Match the VAE's own dtype/device so the decode does not mix devices.
        model_dtype = next(self.pipe.vae.parameters()).dtype
        model_device = next(self.pipe.vae.parameters()).device

        print(f"模型设备: {model_device}")
        print(f"模型数据类型: {model_dtype}")

        latents = latents.to(device=model_device, dtype=model_dtype)

        print(f"解码latents形状: {latents.shape}")
        print(f"解码latents设备: {latents.device}")
        print(f"解码latents数据类型: {latents.dtype}")

        # Force the pipeline device so all internal ops agree with the VAE.
        self.pipe.device = model_device

        with torch.no_grad():
            try:
                if tiled:
                    print("🔧 尝试tiled解码...")
                    decoded_video = self.pipe.decode_video(
                        latents,
                        tiled=True,
                        tile_size=tile_size,
                        tile_stride=tile_stride
                    )
                else:
                    print("🔧 使用非tiled解码...")
                    decoded_video = self.pipe.decode_video(latents, tiled=False)

            except Exception as e:
                print(f"decode_video失败,错误: {e}")
                import traceback
                traceback.print_exc()

                # Fallback: call the VAE directly without the pipeline wrapper.
                try:
                    print("🔧 尝试直接调用VAE解码...")
                    decoded_video = self.pipe.vae.decode(
                        latents.squeeze(0),  # drop batch dim -> [C, T, H, W]
                        device=model_device,
                        tiled=False
                    )
                    # Re-add a batch dimension so the layout probing below
                    # only has to consider 5-D tensors on this path.
                    if len(decoded_video.shape) == 4:  # [T, H, W, C]
                        decoded_video = decoded_video.unsqueeze(0)  # -> [1, T, H, W, C]
                except Exception as e2:
                    print(f"直接VAE解码也失败: {e2}")
                    raise e2

        print(f"解码后视频形状: {decoded_video.shape}")

        # Identify the layout and convert to numpy [T, H, W, C].
        video_np = None

        if len(decoded_video.shape) == 5:
            # Probe the possible dimension orders for a channel axis of size 3.
            if decoded_video.shape == torch.Size([1, 3, 113, 480, 832]):
                # Known layout: [B, C, T, H, W] -> [T, H, W, C]
                print("🔧 检测到格式: [B, C, T, H, W]")
                video_np = decoded_video[0].permute(1, 2, 3, 0).to(torch.float32).cpu().numpy()  # [T, H, W, C]
            elif decoded_video.shape[1] == 3:
                # Second axis of size 3 -> likely [B, C, T, H, W]
                print("🔧 检测到可能的格式: [B, C, T, H, W]")
                video_np = decoded_video[0].permute(1, 2, 3, 0).to(torch.float32).cpu().numpy()  # [T, H, W, C]
            elif decoded_video.shape[-1] == 3:
                # Last axis of size 3 -> [B, T, H, W, C]
                print("🔧 检测到格式: [B, T, H, W, C]")
                video_np = decoded_video[0].to(torch.float32).cpu().numpy()  # [T, H, W, C]
            else:
                # Last resort: search any axis of size 3.
                shape = list(decoded_video.shape)
                if 3 in shape:
                    channel_dim = shape.index(3)
                    print(f"🔧 检测到通道维度在位置: {channel_dim}")

                    if channel_dim == 1:  # [B, C, T, H, W]
                        video_np = decoded_video[0].permute(1, 2, 3, 0).to(torch.float32).cpu().numpy()
                    elif channel_dim == 4:  # [B, T, H, W, C]
                        video_np = decoded_video[0].to(torch.float32).cpu().numpy()
                    else:
                        print(f"⚠️ 未知的通道维度位置: {channel_dim}")
                        raise ValueError(f"Cannot handle channel dimension at position {channel_dim}")
                else:
                    print(f"⚠️ 未找到通道维度为3的位置,形状: {decoded_video.shape}")
                    raise ValueError(f"Cannot find channel dimension of size 3 in shape {decoded_video.shape}")

        elif len(decoded_video.shape) == 4:
            # 4-D tensor: either channels-last or channels-first.
            if decoded_video.shape[-1] == 3:  # [T, H, W, C]
                video_np = decoded_video.to(torch.float32).cpu().numpy()
            elif decoded_video.shape[0] == 3:  # [C, T, H, W]
                video_np = decoded_video.permute(1, 2, 3, 0).to(torch.float32).cpu().numpy()
            else:
                print(f"⚠️ 无法处理的4D视频形状: {decoded_video.shape}")
                raise ValueError(f"Cannot handle 4D video tensor shape: {decoded_video.shape}")
        else:
            print(f"⚠️ 意外的视频维度数: {len(decoded_video.shape)}")
            raise ValueError(f"Unexpected video tensor dimensions: {decoded_video.shape}")

        if video_np is None:
            raise ValueError("Failed to convert video tensor to numpy array")

        print(f"转换后视频数组形状: {video_np.shape}")

        # Validate the final layout is [T, H, W, C] with 3 channels.
        if len(video_np.shape) != 4:
            raise ValueError(f"Expected 4D array [T, H, W, C], got {video_np.shape}")

        if video_np.shape[-1] != 3:
            print(f"⚠️ 通道数异常: 期望3,实际{video_np.shape[-1]}")
            print(f"完整形状: {video_np.shape}")
            # Try alternative axis orders before giving up.
            if video_np.shape[0] == 3:  # [C, T, H, W]
                print("🔧 尝试重新排列: [C, T, H, W] -> [T, H, W, C]")
                video_np = np.transpose(video_np, (1, 2, 3, 0))
            elif video_np.shape[1] == 3:  # [T, C, H, W]
                print("🔧 尝试重新排列: [T, C, H, W] -> [T, H, W, C]")
                video_np = np.transpose(video_np, (0, 2, 3, 1))
            else:
                raise ValueError(f"Expected 3 channels (RGB), got {video_np.shape[-1]} channels")

        # Denormalize: assumes the VAE output is in [-1, 1] -- TODO confirm.
        video_np = (video_np * 0.5 + 0.5).clip(0, 1)
        video_np = (video_np * 255).astype(np.uint8)

        print(f"最终视频数组形状: {video_np.shape}")
        print(f"视频数组值范围: {video_np.min()} - {video_np.max()}")

        # Write the mp4, skipping any malformed frames.
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

        try:
            with imageio.get_writer(output_path, fps=10, quality=8) as writer:
                for frame_idx, frame in enumerate(video_np):
                    # Guard: every frame must be [H, W, 3].
                    if len(frame.shape) != 3 or frame.shape[-1] != 3:
                        print(f"⚠️ 帧 {frame_idx} 形状异常: {frame.shape}")
                        continue

                    writer.append_data(frame)
                    if frame_idx % 10 == 0:
                        print(f" 写入帧 {frame_idx}/{len(video_np)}")
        except Exception as e:
            print(f"保存视频失败: {e}")
            # Dump the first few frames as PNGs to aid debugging, then re-raise.
            debug_dir = os.path.join(os.path.dirname(output_path), "debug_frames")
            os.makedirs(debug_dir, exist_ok=True)

            for i in range(min(5, len(video_np))):
                frame = video_np[i]
                debug_path = os.path.join(debug_dir, f"debug_frame_{i}.png")
                try:
                    if len(frame.shape) == 3 and frame.shape[-1] == 3:
                        Image.fromarray(frame).save(debug_path)
                        print(f"调试: 保存帧 {i} 到 {debug_path}")
                    else:
                        print(f"调试: 帧 {i} 形状异常: {frame.shape}")
                except Exception as e2:
                    print(f"调试: 保存帧 {i} 失败: {e2}")
            raise e

        print(f"✅ 视频保存到: {output_path}")
        return video_np

    def save_frames_as_images(self, video_np, output_dir, prefix="frame"):
        """Save each video frame as an individual PNG image.

        Args:
            video_np: uint8 array of frames, [T, H, W, C].
            output_dir: destination directory (created if missing).
            prefix: filename prefix for each frame.
        """
        os.makedirs(output_dir, exist_ok=True)

        for i, frame in enumerate(video_np):
            frame_path = os.path.join(output_dir, f"{prefix}_{i:04d}.png")
            # Only save well-formed RGB frames.
            if len(frame.shape) == 3 and frame.shape[-1] == 3:
                Image.fromarray(frame).save(frame_path)
            else:
                print(f"⚠️ 跳过形状异常的帧 {i}: {frame.shape}")

        print(f"✅ 保存了 {len(video_np)} 帧到: {output_dir}")
|
| 224 |
+
|
| 225 |
+
def decode_single_episode(encoded_pth_path, vae_path, output_base_dir, device="cuda"):
    """Decode one episode's encoded latents to video and write results to disk.

    Loads ``encoded_video.pth``, decodes its ``latents`` with a fresh
    VideoDecoder, saves the mp4, a sample of frames, and a JSON summary.

    Args:
        encoded_pth_path: path to the episode's encoded_video.pth file.
        vae_path: path to the VAE weights used for decoding.
        output_base_dir: root output directory; a per-episode subdir is created.
        device: compute device for the decoder.

    Returns:
        True on success, False on any failure (errors are printed, not raised).
    """
    print(f"\n🔧 解码episode: {encoded_pth_path}")

    # Load the encoded data onto CPU.
    try:
        encoded_data = torch.load(encoded_pth_path, weights_only=False, map_location="cpu")
        print(f"✅ 成功加载编码数据")
    except Exception as e:
        print(f"❌ 加载编码数据失败: {e}")
        return False

    # Report the structure of the loaded payload for debugging.
    print("🔍 编码数据结构:")
    for key, value in encoded_data.items():
        if isinstance(value, torch.Tensor):
            print(f" - {key}: {value.shape}, dtype: {value.dtype}, device: {value.device}")
        elif isinstance(value, dict):
            print(f" - {key}: dict with keys {list(value.keys())}")
        else:
            print(f" - {key}: {type(value)}")

    # Extract the latents tensor.
    latents = encoded_data.get('latents')
    if latents is None:
        print("❌ 未找到latents数据")
        return False

    # Keep latents on CPU here; the decoder moves them to the right device.
    if latents.device != torch.device('cpu'):
        latents = latents.cpu()
        print(f"🔧 将latents移动到CPU: {latents.device}")

    episode_info = encoded_data.get('episode_info', {})
    episode_idx = episode_info.get('episode_idx', 'unknown')
    # Fall back to an estimate of the original frame count (the VAE compresses
    # time roughly 4x -- presumably; verify against the encoder).
    total_frames = episode_info.get('total_frames', latents.shape[1] * 4)

    print(f"Episode信息:")
    print(f" - Episode索引: {episode_idx}")
    print(f" - Latents形状: {latents.shape}")
    print(f" - Latents设备: {latents.device}")
    print(f" - Latents数据类型: {latents.dtype}")
    print(f" - 原始总帧数: {total_frames}")
    print(f" - 压缩后帧数: {latents.shape[1]}")

    # Create the per-episode output directory.
    episode_name = f"episode_{episode_idx:06d}" if isinstance(episode_idx, int) else f"episode_{episode_idx}"
    output_dir = os.path.join(output_base_dir, episode_name)
    os.makedirs(output_dir, exist_ok=True)

    # Initialize the decoder (loads the VAE).
    try:
        decoder = VideoDecoder(vae_path, device)
    except Exception as e:
        print(f"❌ 初始化解码器失败: {e}")
        return False

    # Decode to video and write all artifacts.
    video_output_path = os.path.join(output_dir, "decoded_video.mp4")
    try:
        video_np = decoder.decode_latents_to_video(
            latents,
            video_output_path,
            tiled=False,  # non-tiled first: simpler, avoids tiled edge cases
            tile_size=(34, 34),
            tile_stride=(18, 16)
        )

        # Save the first few frames as images for a quick visual check.
        frames_dir = os.path.join(output_dir, "frames")
        sample_frames = video_np[:min(10, len(video_np))]
        decoder.save_frames_as_images(sample_frames, frames_dir, f"frame_{episode_idx}")

        # Record a JSON summary of the decode for later inspection.
        decode_info = {
            "source_pth": encoded_pth_path,
            "decoded_video_path": video_output_path,
            "latents_shape": list(latents.shape),
            "decoded_video_shape": list(video_np.shape),
            "original_total_frames": total_frames,
            "decoded_frames": len(video_np),
            "compression_ratio": total_frames / len(video_np) if len(video_np) > 0 else 0,
            "latents_dtype": str(latents.dtype),
            "latents_device": str(latents.device),
            "vae_compression_ratio": total_frames / latents.shape[1] if latents.shape[1] > 0 else 0
        }

        info_path = os.path.join(output_dir, "decode_info.json")
        with open(info_path, 'w') as f:
            json.dump(decode_info, f, indent=2)

        print(f"✅ Episode {episode_idx} 解码完成")
        print(f" - 原始帧数: {total_frames}")
        print(f" - 解码帧数: {len(video_np)}")
        print(f" - 压缩比: {decode_info['compression_ratio']:.2f}")
        print(f" - VAE时间压缩比: {decode_info['vae_compression_ratio']:.2f}")
        return True

    except Exception as e:
        print(f"❌ 解码失败: {e}")
        import traceback
        traceback.print_exc()
        return False
|
| 328 |
+
|
| 329 |
+
def batch_decode_episodes(encoded_base_dir, vae_path, output_base_dir, max_episodes=None, device="cuda"):
    """Decode every encoded episode found under ``encoded_base_dir``.

    Each per-episode subdirectory is expected to contain ``encoded_video.pth``.
    Decoded artifacts are written under ``output_base_dir``.

    Args:
        encoded_base_dir: directory holding per-episode subdirectories.
        vae_path: path to the VAE weights used for decoding.
        output_base_dir: root directory for decoded outputs.
        max_episodes: optional cap on how many episodes to process.
        device: compute device passed to the decoder.
    """
    print(f"🔧 批量解码Open-X episodes")
    print(f"源目录: {encoded_base_dir}")
    print(f"输出目录: {output_base_dir}")

    # Collect all encoded episode files (sorted for deterministic order).
    episode_dirs = []
    if os.path.exists(encoded_base_dir):
        for item in sorted(os.listdir(encoded_base_dir)):
            episode_dir = os.path.join(encoded_base_dir, item)
            if os.path.isdir(episode_dir):
                encoded_path = os.path.join(episode_dir, "encoded_video.pth")
                if os.path.exists(encoded_path):
                    episode_dirs.append(encoded_path)

    print(f"找到 {len(episode_dirs)} 个编码的episodes")

    # BUG FIX: bail out when nothing was found. Previously the final
    # success-rate print divided by len(episode_dirs) == 0 and raised
    # ZeroDivisionError on an empty or missing input directory.
    if not episode_dirs:
        print("⚠️ 未找到可解码的episodes")
        return

    if max_episodes and len(episode_dirs) > max_episodes:
        episode_dirs = episode_dirs[:max_episodes]
        print(f"限制处理前 {max_episodes} 个episodes")

    # Decode each episode, tracking a running success rate.
    success_count = 0
    for i, encoded_pth_path in enumerate(tqdm(episode_dirs, desc="解码episodes")):
        print(f"\n{'='*60}")
        print(f"处理 {i+1}/{len(episode_dirs)}: {os.path.basename(os.path.dirname(encoded_pth_path))}")

        success = decode_single_episode(encoded_pth_path, vae_path, output_base_dir, device)
        if success:
            success_count += 1

        print(f"当前成功率: {success_count}/{i+1} ({success_count/(i+1)*100:.1f}%)")

    print(f"\n🎉 批量解码完成!")
    print(f"总处理: {len(episode_dirs)} 个episodes")
    print(f"成功解码: {success_count} 个episodes")
    print(f"成功率: {success_count/len(episode_dirs)*100:.1f}%")
|
| 367 |
+
|
| 368 |
+
def main():
    """CLI entry point: decode Open-X encoded latents in single or batch mode."""
    parser = argparse.ArgumentParser(description="解码Open-X编码的latents以验证正确性 - 修正版本")
    parser.add_argument("--mode", type=str, choices=["single", "batch"], default="batch",
                        help="解码模式:single (单个episode) 或 batch (批量)")
    parser.add_argument("--encoded_pth", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded/episode_000000/encoded_video.pth",
                        help="单个编码文件路径(single模式)")
    parser.add_argument("--encoded_base_dir", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded",
                        help="编码数据基础目录(batch模式)")
    parser.add_argument("--vae_path", type=str,
                        default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
                        help="VAE模型路径")
    parser.add_argument("--output_dir", type=str,
                        default="./decoded_results_fixed",
                        help="解码输出目录")
    parser.add_argument("--max_episodes", type=int, default=5,
                        help="最大解码episodes数量(batch模式,用于测试)")
    parser.add_argument("--device", type=str, default="cuda",
                        help="计算设备")

    args = parser.parse_args()

    print("🔧 Open-X Latents 解码验证工具 (修正版本 - Fixed)")
    print(f"模式: {args.mode}")
    print(f"VAE路径: {args.vae_path}")
    print(f"输出目录: {args.output_dir}")
    print(f"设备: {args.device}")

    # Fall back to CPU when CUDA was requested but is unavailable.
    if args.device == "cuda" and not torch.cuda.is_available():
        print("⚠️ CUDA不可用,切换到CPU")
        args.device = "cpu"

    # Ensure the output directory exists before either mode runs.
    os.makedirs(args.output_dir, exist_ok=True)

    if args.mode == "single":
        print(f"输入文件: {args.encoded_pth}")
        if not os.path.exists(args.encoded_pth):
            print(f"❌ 输入文件不存在: {args.encoded_pth}")
            return

        success = decode_single_episode(args.encoded_pth, args.vae_path, args.output_dir, args.device)
        if success:
            print("✅ 单个episode解码成功")
        else:
            print("❌ 单个episode解码失败")

    elif args.mode == "batch":
        print(f"输入目录: {args.encoded_base_dir}")
        print(f"最大episodes: {args.max_episodes}")

        if not os.path.exists(args.encoded_base_dir):
            print(f"❌ 输入目录不存在: {args.encoded_base_dir}")
            return

        batch_decode_episodes(args.encoded_base_dir, args.vae_path, args.output_dir, args.max_episodes, args.device)

if __name__ == "__main__":
    main()
|
scripts/download_recam.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from huggingface_hub import snapshot_download

# Download the ReCamMaster-Wan2.1 checkpoint snapshot from the Hugging Face Hub
# into the local models directory.
snapshot_download(
    repo_id="KwaiVGI/ReCamMaster-Wan2.1",
    local_dir="models/ReCamMaster/checkpoints",
    resume_download=True  # allow resuming an interrupted download
)
|
scripts/encode_dynamic_videos.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import lightning as pl
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 6 |
+
import json
|
| 7 |
+
import imageio
|
| 8 |
+
from torchvision.transforms import v2
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
import argparse
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
class VideoEncoder(pl.LightningModule):
    """Loads the WanVideo text encoder + VAE and preprocesses/encodes videos."""

    def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        """Build the encoding pipeline.

        Args:
            text_encoder_path: path to the T5 text-encoder weights.
            vae_path: path to the VAE weights.
            tiled: whether the VAE should encode in tiles (memory saver).
            tile_size: spatial tile size forwarded to encode_video.
            tile_stride: spatial tile stride forwarded to encode_video.
        """
        super().__init__()
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        model_manager.load_models([text_encoder_path, vae_path])
        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
        # Keyword args forwarded to the VAE's tiled encoder.
        self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # Per-frame preprocessing: PIL image -> tensor normalized to [-1, 1].
        self.frame_process = v2.Compose([
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def crop_and_resize(self, image):
        """Resize a PIL image to the fixed 832x480 target resolution.

        Note: despite the name, no cropping is performed -- only a resize.
        """
        target_width, target_height = 832, 480
        image = v2.functional.resize(
            image,
            (target_height, target_width),
            interpolation=v2.InterpolationMode.BILINEAR
        )
        return image

    def load_video_frames(self, video_path):
        """Load a whole video as a (C, T, H, W) tensor in [-1, 1].

        Args:
            video_path: path to the input video file.

        Returns:
            A (C, T, H, W) float tensor, or None if the video has no frames.
        """
        reader = imageio.get_reader(video_path)
        frames = []
        try:
            for frame_data in reader:
                frame = Image.fromarray(frame_data)
                frame = self.crop_and_resize(frame)
                frame = self.frame_process(frame)
                frames.append(frame)
        finally:
            # BUG FIX: close the reader even if preprocessing a frame raises;
            # previously the reader (and its ffmpeg subprocess) leaked on error.
            reader.close()

        if len(frames) == 0:
            return None

        frames = torch.stack(frames, dim=0)
        frames = rearrange(frames, "T C H W -> C T H W")
        return frames
|
| 56 |
+
|
| 57 |
+
def encode_scenes(scenes_path, text_encoder_path, vae_path):
    """Encode every scene video under ``scenes_path`` to VAE latents.

    For each scene directory containing ``scene_info.json``, the referenced
    video is loaded, encoded with the VAE, and saved alongside a shared text
    prompt embedding as ``encoded_video-480p-1.pth``. Already-encoded scenes
    are skipped.

    Args:
        scenes_path: directory containing one subdirectory per scene.
        text_encoder_path: path to the T5 text-encoder weights.
        vae_path: path to the VAE weights.
    """
    encoder = VideoEncoder(text_encoder_path, vae_path)
    encoder = encoder.cuda()
    encoder.pipe.device = "cuda"

    processed_count = 0
    # The same prompt embedding is shared by every scene; encode it lazily,
    # exactly once.
    prompt_emb = None

    for idx, scene_name in enumerate(tqdm(os.listdir(scenes_path))):
        # NOTE(review): hard-coded resume offset -- skips the first 450
        # entries; confirm this is still the intended restart point.
        if idx < 450:
            continue
        scene_dir = os.path.join(scenes_path, scene_name)
        if not os.path.isdir(scene_dir):
            continue

        # Skip scenes that already have an encoded output.
        encoded_path = os.path.join(scene_dir, "encoded_video-480p-1.pth")
        if os.path.exists(encoded_path):
            print(f"Scene {scene_name} already encoded, skipping...")
            continue

        # Load the scene metadata.
        scene_info_path = os.path.join(scene_dir, "scene_info.json")
        if not os.path.exists(scene_info_path):
            continue

        with open(scene_info_path, 'r') as f:
            scene_info = json.load(f)

        # Resolve the video path from the metadata.
        video_path = os.path.join(scene_dir, scene_info['video_path'])
        if not os.path.exists(video_path):
            print(f"Video not found: {video_path}")
            continue

        try:
            print(f"Encoding scene {scene_name}...")

            video_frames = encoder.load_video_frames(video_path)
            if video_frames is None:
                print(f"Failed to load video: {video_path}")
                continue

            video_frames = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16)

            # Encode the video to latents.
            with torch.no_grad():
                latents = encoder.pipe.encode_video(video_frames, **encoder.tiler_kwargs)[0]

            # BUG FIX: gate on `prompt_emb is None` rather than
            # `processed_count == 0`. Previously, if the first scene failed
            # *after* the prompter was deleted (e.g. torch.save error),
            # processed_count stayed 0, so every later scene retried
            # encode_prompt on a deleted prompter and failed forever.
            if prompt_emb is None:
                print('encode prompt!!!')
                prompt_emb = encoder.pipe.encode_prompt("A car driving scene captured by front camera")
                # Free the text encoder; the single embedding is reused.
                del encoder.pipe.prompter

            encoded_data = {
                "latents": latents.cpu(),
                "prompt_emb": {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in prompt_emb.items()},
                "image_emb": {}
            }

            torch.save(encoded_data, encoded_path)
            print(f"Saved encoded data: {encoded_path}")
            processed_count += 1

        except Exception as e:
            print(f"Error encoding scene {scene_name}: {e}")
            continue

    print(f"Encoding completed! Processed {processed_count} scenes.")
|
| 131 |
+
|
| 132 |
+
if __name__ == "__main__":
    # CLI entry point: collect paths and kick off scene encoding.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--scenes_path",
        type=str,
        default="/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic/scenes",
    )
    arg_parser.add_argument(
        "--text_encoder_path",
        type=str,
        default="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
    )
    arg_parser.add_argument(
        "--vae_path",
        type=str,
        default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    )
    cli_args = arg_parser.parse_args()
    encode_scenes(cli_args.scenes_path, cli_args.text_encoder_path, cli_args.vae_path)
|
scripts/encode_openx.py
ADDED
|
@@ -0,0 +1,466 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import lightning as pl
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 6 |
+
import json
|
| 7 |
+
import imageio
|
| 8 |
+
from torchvision.transforms import v2
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
import argparse
|
| 11 |
+
import numpy as np
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
# Set environment flags BEFORE importing TensorFlow/TFDS so they take effect:
# disable tokenizer thread parallelism and stop TFDS from contacting GCS.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["TFDS_DISABLE_GCS"] = "1"

import tensorflow_datasets as tfds
import tensorflow as tf
|
| 20 |
+
|
| 21 |
+
class VideoEncoder(pl.LightningModule):
|
| 22 |
+
    def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        """Load the text encoder + VAE into a WanVideo pipeline for encoding.

        Args:
            text_encoder_path: path to the T5 text-encoder weights.
            vae_path: path to the VAE weights.
            tiled: whether the VAE should encode in tiles (memory saver).
            tile_size: spatial tile size forwarded to encode_video.
            tile_stride: spatial tile stride forwarded to encode_video.
        """
        super().__init__()
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        model_manager.load_models([text_encoder_path, vae_path])
        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
        # Keyword args forwarded to the VAE's tiled encoder.
        self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # Per-frame preprocessing: PIL image -> tensor normalized to [-1, 1].
        self.frame_process = v2.Compose([
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
|
| 33 |
+
|
| 34 |
+
    def crop_and_resize(self, image, target_width=832, target_height=480):
        """Resize a PIL image to the target resolution.

        Note: despite the name, no cropping is performed -- only a bilinear
        resize to (target_height, target_width).
        """
        image = v2.functional.resize(
            image,
            (target_height, target_width),
            interpolation=v2.InterpolationMode.BILINEAR
        )
        return image
|
| 42 |
+
|
| 43 |
+
def load_episode_frames(self, episode_data, max_frames=300):
|
| 44 |
+
"""🔧 从fractal数据集加载视频帧 - 基于实际observation字段优化"""
|
| 45 |
+
frames = []
|
| 46 |
+
|
| 47 |
+
steps = episode_data['steps']
|
| 48 |
+
frame_count = 0
|
| 49 |
+
|
| 50 |
+
print(f"开始提取帧,最多 {max_frames} 帧...")
|
| 51 |
+
|
| 52 |
+
for step_idx, step in enumerate(steps):
|
| 53 |
+
if frame_count >= max_frames:
|
| 54 |
+
break
|
| 55 |
+
|
| 56 |
+
try:
|
| 57 |
+
obs = step['observation']
|
| 58 |
+
|
| 59 |
+
# 🔧 基于实际的observation字段,优先使用'image'
|
| 60 |
+
img_data = None
|
| 61 |
+
image_keys_to_try = [
|
| 62 |
+
'image', # ✅ 确认存在的主要图像字段
|
| 63 |
+
'rgb', # 备用RGB图像
|
| 64 |
+
'camera_image', # 备用相机图像
|
| 65 |
+
'exterior_image_1_left', # 可能的外部摄像头
|
| 66 |
+
'wrist_image', # 可能的手腕摄像头
|
| 67 |
+
]
|
| 68 |
+
|
| 69 |
+
for img_key in image_keys_to_try:
|
| 70 |
+
if img_key in obs:
|
| 71 |
+
try:
|
| 72 |
+
img_tensor = obs[img_key]
|
| 73 |
+
img_data = img_tensor.numpy()
|
| 74 |
+
if step_idx < 3: # 只为前几个步骤打印
|
| 75 |
+
print(f"✅ 找到图像字段: {img_key}, 形状: {img_data.shape}")
|
| 76 |
+
break
|
| 77 |
+
except Exception as e:
|
| 78 |
+
if step_idx < 3:
|
| 79 |
+
print(f"尝试字段 {img_key} 失败: {e}")
|
| 80 |
+
continue
|
| 81 |
+
|
| 82 |
+
if img_data is not None:
|
| 83 |
+
# 确保图像数据格式正确
|
| 84 |
+
if len(img_data.shape) == 3: # [H, W, C]
|
| 85 |
+
if img_data.dtype == np.uint8:
|
| 86 |
+
frame = Image.fromarray(img_data)
|
| 87 |
+
else:
|
| 88 |
+
# 如果是归一化的浮点数,转换为uint8
|
| 89 |
+
if img_data.max() <= 1.0:
|
| 90 |
+
img_data = (img_data * 255).astype(np.uint8)
|
| 91 |
+
else:
|
| 92 |
+
img_data = img_data.astype(np.uint8)
|
| 93 |
+
frame = Image.fromarray(img_data)
|
| 94 |
+
|
| 95 |
+
# 转换为RGB如果需要
|
| 96 |
+
if frame.mode != 'RGB':
|
| 97 |
+
frame = frame.convert('RGB')
|
| 98 |
+
|
| 99 |
+
frame = self.crop_and_resize(frame)
|
| 100 |
+
frame = self.frame_process(frame)
|
| 101 |
+
frames.append(frame)
|
| 102 |
+
frame_count += 1
|
| 103 |
+
|
| 104 |
+
if frame_count % 50 == 0:
|
| 105 |
+
print(f"已处理 {frame_count} 帧")
|
| 106 |
+
else:
|
| 107 |
+
if step_idx < 5:
|
| 108 |
+
print(f"步骤 {step_idx}: 图像形状不正确 {img_data.shape}")
|
| 109 |
+
else:
|
| 110 |
+
# 如果找不到图像,打印可用的观测键
|
| 111 |
+
if step_idx < 5: # 只为前几个步骤打印
|
| 112 |
+
available_keys = list(obs.keys())
|
| 113 |
+
print(f"步骤 {step_idx}: 未找到图像,可用键: {available_keys}")
|
| 114 |
+
|
| 115 |
+
except Exception as e:
|
| 116 |
+
print(f"处理步骤 {step_idx} 时出错: {e}")
|
| 117 |
+
continue
|
| 118 |
+
|
| 119 |
+
print(f"成功提取 {len(frames)} 帧")
|
| 120 |
+
|
| 121 |
+
if len(frames) == 0:
|
| 122 |
+
return None
|
| 123 |
+
|
| 124 |
+
frames = torch.stack(frames, dim=0)
|
| 125 |
+
frames = rearrange(frames, "T C H W -> C T H W")
|
| 126 |
+
return frames
|
| 127 |
+
|
| 128 |
+
def extract_camera_poses(self, episode_data, num_frames):
|
| 129 |
+
"""🔧 从fractal数据集提取相机位姿信息 - 基于实际observation和action字段优化"""
|
| 130 |
+
camera_poses = []
|
| 131 |
+
|
| 132 |
+
steps = episode_data['steps']
|
| 133 |
+
frame_count = 0
|
| 134 |
+
|
| 135 |
+
print("提取相机位姿信息...")
|
| 136 |
+
|
| 137 |
+
# 🔧 累积位姿信息
|
| 138 |
+
cumulative_translation = np.array([0.0, 0.0, 0.0], dtype=np.float32)
|
| 139 |
+
cumulative_rotation = np.array([0.0, 0.0, 0.0], dtype=np.float32) # 欧拉角
|
| 140 |
+
|
| 141 |
+
for step_idx, step in enumerate(steps):
|
| 142 |
+
if frame_count >= num_frames:
|
| 143 |
+
break
|
| 144 |
+
|
| 145 |
+
try:
|
| 146 |
+
obs = step['observation']
|
| 147 |
+
action = step.get('action', {})
|
| 148 |
+
|
| 149 |
+
# 🔧 基于实际的字段提取位姿变化
|
| 150 |
+
pose_data = {}
|
| 151 |
+
found_pose = False
|
| 152 |
+
|
| 153 |
+
# 1. 优先使用action中的world_vector(世界坐标系中的位移)
|
| 154 |
+
if 'world_vector' in action:
|
| 155 |
+
try:
|
| 156 |
+
world_vector = action['world_vector'].numpy()
|
| 157 |
+
if len(world_vector) == 3:
|
| 158 |
+
# 累积世界坐标位移
|
| 159 |
+
cumulative_translation += world_vector
|
| 160 |
+
pose_data['translation'] = cumulative_translation.copy()
|
| 161 |
+
found_pose = True
|
| 162 |
+
|
| 163 |
+
if step_idx < 3:
|
| 164 |
+
print(f"使用action.world_vector: {world_vector}, 累积位移: {cumulative_translation}")
|
| 165 |
+
except Exception as e:
|
| 166 |
+
if step_idx < 3:
|
| 167 |
+
print(f"action.world_vector提取失败: {e}")
|
| 168 |
+
|
| 169 |
+
# 2. 使用action中的rotation_delta(旋转变化)
|
| 170 |
+
if 'rotation_delta' in action:
|
| 171 |
+
try:
|
| 172 |
+
rotation_delta = action['rotation_delta'].numpy()
|
| 173 |
+
if len(rotation_delta) == 3:
|
| 174 |
+
# 累积旋转变化
|
| 175 |
+
cumulative_rotation += rotation_delta
|
| 176 |
+
|
| 177 |
+
# 转换为四元数(简化版本)
|
| 178 |
+
euler_angles = cumulative_rotation
|
| 179 |
+
# 欧拉角转四元数(ZYX顺序)
|
| 180 |
+
roll, pitch, yaw = euler_angles[0], euler_angles[1], euler_angles[2]
|
| 181 |
+
|
| 182 |
+
# 简化的欧拉角到四元数转换
|
| 183 |
+
cy = np.cos(yaw * 0.5)
|
| 184 |
+
sy = np.sin(yaw * 0.5)
|
| 185 |
+
cp = np.cos(pitch * 0.5)
|
| 186 |
+
sp = np.sin(pitch * 0.5)
|
| 187 |
+
cr = np.cos(roll * 0.5)
|
| 188 |
+
sr = np.sin(roll * 0.5)
|
| 189 |
+
|
| 190 |
+
qw = cr * cp * cy + sr * sp * sy
|
| 191 |
+
qx = sr * cp * cy - cr * sp * sy
|
| 192 |
+
qy = cr * sp * cy + sr * cp * sy
|
| 193 |
+
qz = cr * cp * sy - sr * sp * cy
|
| 194 |
+
|
| 195 |
+
pose_data['rotation'] = np.array([qw, qx, qy, qz], dtype=np.float32)
|
| 196 |
+
found_pose = True
|
| 197 |
+
|
| 198 |
+
if step_idx < 3:
|
| 199 |
+
print(f"使用action.rotation_delta: {rotation_delta}, 累积旋转: {cumulative_rotation}")
|
| 200 |
+
except Exception as e:
|
| 201 |
+
if step_idx < 3:
|
| 202 |
+
print(f"action.rotation_delta提取失败: {e}")
|
| 203 |
+
|
| 204 |
+
# 确保rotation字段存在
|
| 205 |
+
if 'rotation' not in pose_data:
|
| 206 |
+
# 使用当前累积的旋转计算四元数
|
| 207 |
+
roll, pitch, yaw = cumulative_rotation[0], cumulative_rotation[1], cumulative_rotation[2]
|
| 208 |
+
|
| 209 |
+
cy = np.cos(yaw * 0.5)
|
| 210 |
+
sy = np.sin(yaw * 0.5)
|
| 211 |
+
cp = np.cos(pitch * 0.5)
|
| 212 |
+
sp = np.sin(pitch * 0.5)
|
| 213 |
+
cr = np.cos(roll * 0.5)
|
| 214 |
+
sr = np.sin(roll * 0.5)
|
| 215 |
+
|
| 216 |
+
qw = cr * cp * cy + sr * sp * sy
|
| 217 |
+
qx = sr * cp * cy - cr * sp * sy
|
| 218 |
+
qy = cr * sp * cy + sr * cp * sy
|
| 219 |
+
qz = cr * cp * sy - sr * sp * cy
|
| 220 |
+
|
| 221 |
+
pose_data['rotation'] = np.array([qw, qx, qy, qz], dtype=np.float32)
|
| 222 |
+
|
| 223 |
+
camera_poses.append(pose_data)
|
| 224 |
+
frame_count += 1
|
| 225 |
+
|
| 226 |
+
except Exception as e:
|
| 227 |
+
print(f"提取位姿步骤 {step_idx} 时出错: {e}")
|
| 228 |
+
# 添加默认位姿
|
| 229 |
+
pose_data = {
|
| 230 |
+
'translation': cumulative_translation.copy(),
|
| 231 |
+
'rotation': np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)
|
| 232 |
+
}
|
| 233 |
+
camera_poses.append(pose_data)
|
| 234 |
+
frame_count += 1
|
| 235 |
+
|
| 236 |
+
print(f"提取了 {len(camera_poses)} 个位姿")
|
| 237 |
+
print(f"最终累积位移: {cumulative_translation}")
|
| 238 |
+
print(f"最终累积旋转: {cumulative_rotation}")
|
| 239 |
+
|
| 240 |
+
return camera_poses
|
| 241 |
+
|
| 242 |
+
def create_camera_matrices(self, camera_poses):
|
| 243 |
+
"""将位姿转换为4x4变换矩阵"""
|
| 244 |
+
matrices = []
|
| 245 |
+
|
| 246 |
+
for pose in camera_poses:
|
| 247 |
+
matrix = np.eye(4, dtype=np.float32)
|
| 248 |
+
|
| 249 |
+
# 设置平移
|
| 250 |
+
matrix[:3, 3] = pose['translation']
|
| 251 |
+
|
| 252 |
+
# 设置旋转 - 假设是四元数 [w, x, y, z]
|
| 253 |
+
if len(pose['rotation']) == 4:
|
| 254 |
+
# 四元数转旋转矩阵
|
| 255 |
+
q = pose['rotation']
|
| 256 |
+
w, x, y, z = q[0], q[1], q[2], q[3]
|
| 257 |
+
|
| 258 |
+
# 四元数到旋转矩阵的转换
|
| 259 |
+
matrix[0, 0] = 1 - 2*(y*y + z*z)
|
| 260 |
+
matrix[0, 1] = 2*(x*y - w*z)
|
| 261 |
+
matrix[0, 2] = 2*(x*z + w*y)
|
| 262 |
+
matrix[1, 0] = 2*(x*y + w*z)
|
| 263 |
+
matrix[1, 1] = 1 - 2*(x*x + z*z)
|
| 264 |
+
matrix[1, 2] = 2*(y*z - w*x)
|
| 265 |
+
matrix[2, 0] = 2*(x*z - w*y)
|
| 266 |
+
matrix[2, 1] = 2*(y*z + w*x)
|
| 267 |
+
matrix[2, 2] = 1 - 2*(x*x + y*y)
|
| 268 |
+
elif len(pose['rotation']) == 3:
|
| 269 |
+
# 欧拉角转换(如果需要)
|
| 270 |
+
pass
|
| 271 |
+
|
| 272 |
+
matrices.append(matrix)
|
| 273 |
+
|
| 274 |
+
return np.array(matrices)
|
| 275 |
+
|
| 276 |
+
def encode_fractal_dataset(dataset_path, text_encoder_path, vae_path, output_dir, max_episodes=None):
    """Encode the fractal20220817_data dataset into per-episode latent files.

    For each episode: extracts frames and cumulative camera poses via
    VideoEncoder, VAE-encodes the video, and saves latents + camera data +
    a shared text-prompt embedding to ``<output_dir>/episode_XXXXXX/encoded_video.pth``.
    Already-encoded episodes are skipped, so the job can be resumed.
    """

    encoder = VideoEncoder(text_encoder_path, vae_path)
    encoder = encoder.cuda()
    encoder.pipe.device = "cuda"

    os.makedirs(output_dir, exist_ok=True)

    processed_count = 0
    # The prompt embedding is computed once and reused for every episode.
    prompt_emb = None

    try:
        # Load the dataset from a local tfds data_dir (no GCS access).
        ds = tfds.load(
            "fractal20220817_data",
            split="train",
            data_dir=dataset_path,
        )

        print(f"✅ 成功加载fractal20220817_data数据集")

        # Optionally cap the number of episodes processed.
        if max_episodes:
            ds = ds.take(max_episodes)
            print(f"限制处理episodes数量: {max_episodes}")

    except Exception as e:
        print(f"❌ 加载数据集失败: {e}")
        return

    for episode_idx, episode in enumerate(tqdm(ds, desc="处理episodes")):
        try:
            episode_name = f"episode_{episode_idx:06d}"
            save_episode_dir = os.path.join(output_dir, episode_name)

            # Skip episodes that already have an output file (resume support).
            encoded_path = os.path.join(save_episode_dir, "encoded_video.pth")
            if os.path.exists(encoded_path):
                print(f"Episode {episode_name} 已处理,跳过...")
                processed_count += 1
                continue

            os.makedirs(save_episode_dir, exist_ok=True)

            print(f"\n🔧 处理episode {episode_name}...")

            # Debug-dump the episode structure for the first two episodes only.
            if episode_idx < 2:
                print("Episode结构分析:")
                for key in episode.keys():
                    print(f" - {key}: {type(episode[key])}")

                # Inspect the structure of the first step.
                steps = episode['steps']
                for step in steps.take(1):
                    print("第一个step结构:")
                    for key in step.keys():
                        print(f" - {key}: {type(step[key])}")

                    if 'observation' in step:
                        obs = step['observation']
                        print(" observation键:")
                        print(f" 🔍 可用字段: {list(obs.keys())}")

                        # Focus on the image- and pose-related fields.
                        key_fields = ['image', 'vector_to_go', 'rotation_delta_to_go', 'base_pose_tool_reached']
                        for key in key_fields:
                            if key in obs:
                                try:
                                    value = obs[key]
                                    if hasattr(value, 'shape'):
                                        print(f" ✅ {key}: {type(value)}, shape: {value.shape}")
                                    else:
                                        print(f" ✅ {key}: {type(value)}")
                                except Exception as e:
                                    print(f" ❌ {key}: 无法访问 ({e})")

                    if 'action' in step:
                        action = step['action']
                        print(" action键:")
                        print(f" 🔍 可用字段: {list(action.keys())}")

                        # Focus on the pose-related action fields.
                        key_fields = ['world_vector', 'rotation_delta', 'base_displacement_vector']
                        for key in key_fields:
                            if key in action:
                                try:
                                    value = action[key]
                                    if hasattr(value, 'shape'):
                                        print(f" ✅ {key}: {type(value)}, shape: {value.shape}")
                                    else:
                                        print(f" ✅ {key}: {type(value)}")
                                except Exception as e:
                                    print(f" ❌ {key}: 无法访问 ({e})")

            # Load the video frames.
            video_frames = encoder.load_episode_frames(episode)
            if video_frames is None:
                print(f"❌ 无法加载episode {episode_name}的视频帧")
                continue

            print(f"✅ Episode {episode_name} 视频形状: {video_frames.shape}")

            # Extract one camera pose per extracted frame.
            num_frames = video_frames.shape[1]
            camera_poses = encoder.extract_camera_poses(episode, num_frames)
            camera_matrices = encoder.create_camera_matrices(camera_poses)

            print(f"🔧 编码episode {episode_name}...")

            # Camera data: per-frame extrinsics; intrinsic is a placeholder identity.
            cam_emb = {
                'extrinsic': camera_matrices,
                'intrinsic': np.eye(3, dtype=np.float32)
            }

            # Encode the video through the VAE (bfloat16 on GPU).
            frames_batch = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16)

            with torch.no_grad():
                latents = encoder.pipe.encode_video(frames_batch, **encoder.tiler_kwargs)[0]

            # Encode the text prompt once, on first use.
            if prompt_emb is None:
                print('🔧 编码prompt...')
                prompt_emb = encoder.pipe.encode_prompt(
                    "A video of robotic manipulation task with camera movement"
                )
                # Free the prompter to save GPU memory; it is not needed again.
                del encoder.pipe.prompter

            # Save everything to disk on CPU.
            encoded_data = {
                "latents": latents.cpu(),
                "prompt_emb": {k: v.cpu() if isinstance(v, torch.Tensor) else v
                               for k, v in prompt_emb.items()},
                "cam_emb": cam_emb,
                "episode_info": {
                    "episode_idx": episode_idx,
                    "total_frames": video_frames.shape[1],
                    "pose_extraction_method": "observation_action_based"
                }
            }

            torch.save(encoded_data, encoded_path)
            print(f"✅ 保存编码数据: {encoded_path}")

            processed_count += 1
            print(f"✅ 已处理 {processed_count} 个episodes")

        except Exception as e:
            # Best-effort batch job: log the failure and keep going.
            print(f"❌ 处理episode {episode_idx}时出错: {e}")
            import traceback
            traceback.print_exc()
            continue

    print(f"🎉 编码完成! 总共处理了 {processed_count} 个episodes")
|
| 434 |
+
if __name__ == "__main__":
    # CLI entry point: encode the Open-X fractal dataset into latents on disk.
    parser = argparse.ArgumentParser(description="Encode Open-X Fractal20220817 Dataset - Based on Real Structure")
    parser.add_argument("--dataset_path", type=str,
                        default="/share_zhuyixuan05/public_datasets/open-x/0.1.0",
                        help="Path to tensorflow_datasets directory")
    parser.add_argument("--text_encoder_path", type=str,
                        default="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth")
    parser.add_argument("--vae_path", type=str,
                        default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth")
    parser.add_argument("--output_dir", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded")
    # Fix: the help text previously claimed "default: 10 for testing" while the
    # actual default was 10000; the message now matches the default.
    parser.add_argument("--max_episodes", type=int, default=10000,
                        help="Maximum number of episodes to process (default: 10000)")

    args = parser.parse_args()

    # Make sure the output directory exists before any episode is written.
    os.makedirs(args.output_dir, exist_ok=True)

    print("🚀 开始编码Open-X Fractal数据集 (基于实际字段结构)...")
    print(f"📁 数据集路径: {args.dataset_path}")
    print(f"💾 输出目录: {args.output_dir}")
    print(f"🔢 最大处理episodes: {args.max_episodes}")
    print("🔧 基于实际observation和action字段的位姿提取方法")
    print("✅ 优先使用 'image' 字段获取图像数据")

    encode_fractal_dataset(
        args.dataset_path,
        args.text_encoder_path,
        args.vae_path,
        args.output_dir,
        args.max_episodes
    )
|
scripts/encode_rlbench_video.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import lightning as pl
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 6 |
+
import json
|
| 7 |
+
import imageio
|
| 8 |
+
from torchvision.transforms import v2
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
import argparse
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pdb
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 16 |
+
|
| 17 |
+
class VideoEncoder(pl.LightningModule):
    """Encode RLBench demo videos through the ReCamMaster VAE pipeline."""

    def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        super().__init__()
        # Models are loaded on CPU; the caller moves the module to GPU.
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        model_manager.load_models([text_encoder_path, vae_path])
        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
        self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # PIL frame -> float tensor normalized to [-1, 1].
        self.frame_process = v2.Compose([
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def crop_and_resize(self, image):
        """Resize a frame to the fixed 512x512 model resolution (bilinear).

        Fix: the original read ``image.size`` into unused locals and rounded
        already-integer constants; both removed with identical behavior.
        """
        return v2.functional.resize(
            image,
            (512, 512),
            interpolation=v2.InterpolationMode.BILINEAR
        )

    def load_video_frames(self, video_path):
        """Load a whole video as a ``C T H W`` tensor; None when it has no frames."""
        reader = imageio.get_reader(video_path)
        frames = []

        for frame_data in reader:
            frame = Image.fromarray(frame_data)
            frame = self.crop_and_resize(frame)
            frame = self.frame_process(frame)
            frames.append(frame)

        reader.close()

        if len(frames) == 0:
            return None

        frames = torch.stack(frames, dim=0)
        frames = rearrange(frames, "T C H W -> C T H W")
        return frames
|
| 62 |
+
|
| 63 |
+
def encode_scenes(scenes_path, text_encoder_path, vae_path,output_dir):
    """Encode every RLBench demo video under ``scenes_path``.

    Walks scenes_path/<scene>/<demo>/*.mp4; for each video, loads the matching
    camera array ("side.mp4" -> "data.npy"), VAE-encodes the frames, and saves
    latents + camera data to ``<output_dir>/<scene>_<demo>/encoded_video.pth``.
    Already-encoded demos are skipped, so the job can be resumed.
    """

    encoder = VideoEncoder(text_encoder_path, vae_path)
    encoder = encoder.cuda()
    encoder.pipe.device = "cuda"

    processed_count = 0
    prompt_emb = 0

    os.makedirs(output_dir,exist_ok=True)

    for i, scene_name in enumerate(os.listdir(scenes_path)):
        # if i < 1700:
        #     continue
        scene_dir = os.path.join(scenes_path, scene_name)
        for j, demo_name in tqdm(enumerate(os.listdir(scene_dir)),total=len(os.listdir(scene_dir))):
            demo_dir = os.path.join(scene_dir, demo_name)
            for filename in os.listdir(demo_dir):
                # Only process .mp4 files (case-insensitive).
                if filename.lower().endswith('.mp4'):
                    full_path = os.path.join(demo_dir, filename)
                    print(full_path)
                    save_dir = os.path.join(output_dir,scene_name+'_'+demo_name)
                    # print('in:',scene_dir)
                    # print('out:',save_dir)



                    os.makedirs(save_dir,exist_ok=True)
                    # Skip demos that already have an encoded output (resume support).
                    encoded_path = os.path.join(save_dir, "encoded_video.pth")
                    if os.path.exists(encoded_path):
                        print(f"Scene {scene_name} already encoded, skipping...")
                        continue

                    # Camera data lives next to the video: "...side.mp4" -> "...data.npy".
                    scene_cam_path = full_path.replace("side.mp4", "data.npy")
                    print(scene_cam_path)
                    if not os.path.exists(scene_cam_path):
                        continue

                    # with np.load(scene_cam_path) as data:
                    cam_data = np.load(scene_cam_path)
                    cam_emb = cam_data
                    print(cam_data.shape)
                    # with open(scene_cam_path, 'rb') as f:
                    #     cam_data = np.load(f)  # cam_data would then hold data only, no file handle

                    # Load the video.
                    video_path = full_path
                    if not os.path.exists(video_path):
                        print(f"Video not found: {video_path}")
                        continue

                    # try:
                    print(f"Encoding scene {scene_name}...Demo {demo_name}")

                    # Load and encode the video.
                    video_frames = encoder.load_video_frames(video_path)
                    if video_frames is None:
                        print(f"Failed to load video: {video_path}")
                        continue

                    video_frames = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16)
                    print('video shape:',video_frames.shape)
                    # VAE-encode the frames.
                    with torch.no_grad():
                        latents = encoder.pipe.encode_video(video_frames, **encoder.tiler_kwargs)[0]

                    # Text encoding is currently disabled for this dataset.
                    # if processed_count == 0:
                    #     print('encode prompt!!!')
                    #     prompt_emb = encoder.pipe.encode_prompt("A video of a scene shot using a pedestrian's front camera while walking")
                    #     del encoder.pipe.prompter
                    # pdb.set_trace()
                    # Save the encoded result.
                    encoded_data = {
                        "latents": latents.cpu(),
                        #"prompt_emb": {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in prompt_emb.items()},
                        "cam_emb": cam_emb
                    }
                    # pdb.set_trace()
                    torch.save(encoded_data, encoded_path)
                    print(f"Saved encoded data: {encoded_path}")
                    processed_count += 1

                    # except Exception as e:
                    #     print(f"Error encoding scene {scene_name}: {e}")
                    #     continue

    print(f"Encoding completed! Processed {processed_count} scenes.")
|
| 157 |
+
|
| 158 |
+
if __name__ == "__main__":
    # CLI entry point: encode all RLBench demo videos under --scenes_path.
    parser = argparse.ArgumentParser()
    parser.add_argument("--scenes_path", type=str, default="/share_zhuyixuan05/zhuyixuan05/RLBench")
    parser.add_argument("--text_encoder_path", type=str,
                        default="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth")
    parser.add_argument("--vae_path", type=str,
                        default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth")

    parser.add_argument("--output_dir",type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/rlbench")

    args = parser.parse_args()
    encode_scenes(args.scenes_path, args.text_encoder_path, args.vae_path,args.output_dir)
|
scripts/encode_sekai_video.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import lightning as pl
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 6 |
+
import json
|
| 7 |
+
import imageio
|
| 8 |
+
from torchvision.transforms import v2
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
import argparse
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pdb
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 16 |
+
|
| 17 |
+
class VideoEncoder(pl.LightningModule):
|
| 18 |
+
def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
| 19 |
+
super().__init__()
|
| 20 |
+
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
|
| 21 |
+
model_manager.load_models([text_encoder_path, vae_path])
|
| 22 |
+
self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
|
| 23 |
+
self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
| 24 |
+
|
| 25 |
+
self.frame_process = v2.Compose([
|
| 26 |
+
# v2.CenterCrop(size=(900, 1600)),
|
| 27 |
+
# v2.Resize(size=(900, 1600), antialias=True),
|
| 28 |
+
v2.ToTensor(),
|
| 29 |
+
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
| 30 |
+
])
|
| 31 |
+
|
| 32 |
+
def crop_and_resize(self, image):
|
| 33 |
+
width, height = image.size
|
| 34 |
+
# print(width,height)
|
| 35 |
+
width_ori, height_ori_ = 832 , 480
|
| 36 |
+
image = v2.functional.resize(
|
| 37 |
+
image,
|
| 38 |
+
(round(height_ori_), round(width_ori)),
|
| 39 |
+
interpolation=v2.InterpolationMode.BILINEAR
|
| 40 |
+
)
|
| 41 |
+
return image
|
| 42 |
+
|
| 43 |
+
def load_video_frames(self, video_path):
|
| 44 |
+
"""加载完整视频"""
|
| 45 |
+
reader = imageio.get_reader(video_path)
|
| 46 |
+
frames = []
|
| 47 |
+
|
| 48 |
+
for frame_data in reader:
|
| 49 |
+
frame = Image.fromarray(frame_data)
|
| 50 |
+
frame = self.crop_and_resize(frame)
|
| 51 |
+
frame = self.frame_process(frame)
|
| 52 |
+
frames.append(frame)
|
| 53 |
+
|
| 54 |
+
reader.close()
|
| 55 |
+
|
| 56 |
+
if len(frames) == 0:
|
| 57 |
+
return None
|
| 58 |
+
|
| 59 |
+
frames = torch.stack(frames, dim=0)
|
| 60 |
+
frames = rearrange(frames, "T C H W -> C T H W")
|
| 61 |
+
return frames
|
| 62 |
+
|
| 63 |
+
def encode_scenes(scenes_path, text_encoder_path, vae_path,output_dir):
    """Encode every .mp4 scene directly under ``scenes_path``.

    For each video, the matching camera file ("<name>.mp4" -> "<name>.npz")
    is loaded, the frames are VAE-encoded, and latents + camera arrays are
    saved to ``<output_dir>/<name>/encoded_video.pth``. A text-prompt
    embedding is computed once on the first processed scene (but is not
    stored in the output — see the commented-out "prompt_emb" entry).
    Already-encoded scenes are skipped, so the job can be resumed.
    """

    encoder = VideoEncoder(text_encoder_path, vae_path)
    encoder = encoder.cuda()
    encoder.pipe.device = "cuda"

    processed_count = 0
    prompt_emb = 0

    os.makedirs(output_dir,exist_ok=True)

    for i, scene_name in tqdm(enumerate(os.listdir(scenes_path)),total=len(os.listdir(scenes_path))):
        # if i < 1700:
        #     continue
        scene_dir = os.path.join(scenes_path, scene_name)
        save_dir = os.path.join(output_dir,scene_name.split('.')[0])
        # print('in:',scene_dir)
        # print('out:',save_dir)

        # Only .mp4 entries are videos; skip everything else.
        if not scene_dir.endswith(".mp4"):# or os.path.isdir(output_dir):
            continue


        os.makedirs(save_dir,exist_ok=True)
        # Skip scenes that already have an encoded output (resume support).
        encoded_path = os.path.join(save_dir, "encoded_video.pth")
        if os.path.exists(encoded_path):
            print(f"Scene {scene_name} already encoded, skipping...")
            continue

        # Camera data lives next to the video: "<name>.mp4" -> "<name>.npz".
        scene_cam_path = scene_dir.replace(".mp4", ".npz")
        if not os.path.exists(scene_cam_path):
            continue

        # Materialize all arrays before the npz file handle closes.
        with np.load(scene_cam_path) as data:
            cam_data = data.files
            cam_emb = {k: data[k].cpu() if isinstance(data[k], torch.Tensor) else data[k] for k in cam_data}
        # with open(scene_cam_path, 'rb') as f:
        #     cam_data = np.load(f)  # cam_data would then hold data only, no file handle

        # Load the video.
        video_path = scene_dir
        if not os.path.exists(video_path):
            print(f"Video not found: {video_path}")
            continue

        # try:
        print(f"Encoding scene {scene_name}...")

        # Load and encode the video.
        video_frames = encoder.load_video_frames(video_path)
        if video_frames is None:
            print(f"Failed to load video: {video_path}")
            continue

        video_frames = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16)
        print('video shape:',video_frames.shape)
        # VAE-encode the frames.
        with torch.no_grad():
            latents = encoder.pipe.encode_video(video_frames, **encoder.tiler_kwargs)[0]

        # Encode the text prompt once, then free the prompter to save memory.
        if processed_count == 0:
            print('encode prompt!!!')
            prompt_emb = encoder.pipe.encode_prompt("A video of a scene shot using a pedestrian's front camera while walking")
            del encoder.pipe.prompter
        # pdb.set_trace()
        # Save the encoded result.
        encoded_data = {
            "latents": latents.cpu(),
            #"prompt_emb": {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in prompt_emb.items()},
            "cam_emb": cam_emb
        }
        # pdb.set_trace()
        torch.save(encoded_data, encoded_path)
        print(f"Saved encoded data: {encoded_path}")
        processed_count += 1

        # except Exception as e:
        #     print(f"Error encoding scene {scene_name}: {e}")
        #     continue

    print(f"Encoding completed! Processed {processed_count} scenes.")
|
| 149 |
+
|
| 150 |
+
if __name__ == "__main__":
    # CLI entry point: encode all sekai-game-walking videos under --scenes_path.
    parser = argparse.ArgumentParser()
    parser.add_argument("--scenes_path", type=str, default="/share_zhuyixuan05/public_datasets/sekai/Sekai-Project/sekai-game-walking")
    parser.add_argument("--text_encoder_path", type=str,
                        default="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth")
    parser.add_argument("--vae_path", type=str,
                        default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth")

    parser.add_argument("--output_dir",type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking")

    args = parser.parse_args()
    encode_scenes(args.scenes_path, args.text_encoder_path, args.vae_path,args.output_dir)
|
scripts/encode_sekai_walking.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
import lightning as pl
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 7 |
+
import json
|
| 8 |
+
import imageio
|
| 9 |
+
from torchvision.transforms import v2
|
| 10 |
+
from einops import rearrange
|
| 11 |
+
import argparse
|
| 12 |
+
import numpy as np
|
| 13 |
+
import pdb
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
|
| 16 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 17 |
+
|
| 18 |
+
class VideoEncoder(pl.LightningModule):
|
| 19 |
+
def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
| 20 |
+
super().__init__()
|
| 21 |
+
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
|
| 22 |
+
model_manager.load_models([text_encoder_path, vae_path])
|
| 23 |
+
self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
|
| 24 |
+
self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
| 25 |
+
|
| 26 |
+
self.frame_process = v2.Compose([
|
| 27 |
+
# v2.CenterCrop(size=(900, 1600)),
|
| 28 |
+
# v2.Resize(size=(900, 1600), antialias=True),
|
| 29 |
+
v2.ToTensor(),
|
| 30 |
+
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
| 31 |
+
])
|
| 32 |
+
|
| 33 |
+
def crop_and_resize(self, image):
|
| 34 |
+
width, height = image.size
|
| 35 |
+
# print(width,height)
|
| 36 |
+
width_ori, height_ori_ = 832 , 480
|
| 37 |
+
image = v2.functional.resize(
|
| 38 |
+
image,
|
| 39 |
+
(round(height_ori_), round(width_ori)),
|
| 40 |
+
interpolation=v2.InterpolationMode.BILINEAR
|
| 41 |
+
)
|
| 42 |
+
return image
|
| 43 |
+
|
| 44 |
+
def load_video_frames(self, video_path):
|
| 45 |
+
"""加载完整视频"""
|
| 46 |
+
reader = imageio.get_reader(video_path)
|
| 47 |
+
frames = []
|
| 48 |
+
|
| 49 |
+
for frame_data in reader:
|
| 50 |
+
frame = Image.fromarray(frame_data)
|
| 51 |
+
frame = self.crop_and_resize(frame)
|
| 52 |
+
frame = self.frame_process(frame)
|
| 53 |
+
frames.append(frame)
|
| 54 |
+
|
| 55 |
+
reader.close()
|
| 56 |
+
|
| 57 |
+
if len(frames) == 0:
|
| 58 |
+
return None
|
| 59 |
+
|
| 60 |
+
frames = torch.stack(frames, dim=0)
|
| 61 |
+
frames = rearrange(frames, "T C H W -> C T H W")
|
| 62 |
+
return frames
|
| 63 |
+
|
| 64 |
+
def encode_scenes(scenes_path, text_encoder_path, vae_path,output_dir):
|
| 65 |
+
"""编码所有场景的视频"""
|
| 66 |
+
|
| 67 |
+
encoder = VideoEncoder(text_encoder_path, vae_path)
|
| 68 |
+
encoder = encoder.cuda()
|
| 69 |
+
encoder.pipe.device = "cuda"
|
| 70 |
+
|
| 71 |
+
processed_count = 0
|
| 72 |
+
|
| 73 |
+
processed_chunk_count = 0
|
| 74 |
+
|
| 75 |
+
prompt_emb = 0
|
| 76 |
+
|
| 77 |
+
os.makedirs(output_dir,exist_ok=True)
|
| 78 |
+
chunk_size = 300
|
| 79 |
+
for i, scene_name in tqdm(enumerate(os.listdir(scenes_path)),total=len(os.listdir(scenes_path))):
|
| 80 |
+
# print('index-----:',type(i))
|
| 81 |
+
# if i < 3000 :#or i >=2000:
|
| 82 |
+
# # print('index-----:',i)
|
| 83 |
+
# continue
|
| 84 |
+
# print('index:',i)
|
| 85 |
+
print('index:',i)
|
| 86 |
+
scene_dir = os.path.join(scenes_path, scene_name)
|
| 87 |
+
|
| 88 |
+
# save_dir = os.path.join(output_dir,scene_name.split('.')[0])
|
| 89 |
+
# print('in:',scene_dir)
|
| 90 |
+
# print('out:',save_dir)
|
| 91 |
+
|
| 92 |
+
if not scene_dir.endswith(".mp4"):# or os.path.isdir(output_dir):
|
| 93 |
+
continue
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
scene_cam_path = scene_dir.replace(".mp4", ".npz")
|
| 97 |
+
if not os.path.exists(scene_cam_path):
|
| 98 |
+
continue
|
| 99 |
+
|
| 100 |
+
with np.load(scene_cam_path) as data:
|
| 101 |
+
cam_data = data.files
|
| 102 |
+
cam_emb = {k: data[k].cpu() if isinstance(data[k], torch.Tensor) else data[k] for k in cam_data}
|
| 103 |
+
# with open(scene_cam_path, 'rb') as f:
|
| 104 |
+
# cam_data = np.load(f) # 此时cam_data仅包含数据,无文件句柄引用
|
| 105 |
+
|
| 106 |
+
video_name = scene_name[:-4].split('_')[0]
|
| 107 |
+
start_frame = int(scene_name[:-4].split('_')[1])
|
| 108 |
+
end_frame = int(scene_name[:-4].split('_')[2])
|
| 109 |
+
|
| 110 |
+
sampled_range = range(start_frame, end_frame , chunk_size)
|
| 111 |
+
sampled_frames = list(sampled_range)
|
| 112 |
+
|
| 113 |
+
sampled_chunk_end = sampled_frames[0] + 300
|
| 114 |
+
start_str = f"{sampled_frames[0]:07d}"
|
| 115 |
+
end_str = f"{sampled_chunk_end:07d}"
|
| 116 |
+
|
| 117 |
+
chunk_name = f"{video_name}_{start_str}_{end_str}"
|
| 118 |
+
save_chunk_path = os.path.join(output_dir,chunk_name,"encoded_video.pth")
|
| 119 |
+
|
| 120 |
+
if os.path.exists(save_chunk_path):
|
| 121 |
+
print(f"Video {video_name} already encoded, skipping...")
|
| 122 |
+
continue
|
| 123 |
+
|
| 124 |
+
# 加载视频
|
| 125 |
+
video_path = scene_dir
|
| 126 |
+
if not os.path.exists(video_path):
|
| 127 |
+
print(f"Video not found: {video_path}")
|
| 128 |
+
continue
|
| 129 |
+
|
| 130 |
+
video_frames = encoder.load_video_frames(video_path)
|
| 131 |
+
if video_frames is None:
|
| 132 |
+
print(f"Failed to load video: {video_path}")
|
| 133 |
+
continue
|
| 134 |
+
|
| 135 |
+
video_frames = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16)
|
| 136 |
+
print('video shape:',video_frames.shape)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
# print(sampled_frames)
|
| 141 |
+
|
| 142 |
+
print(f"Encoding scene {scene_name}...")
|
| 143 |
+
for sampled_chunk_start in sampled_frames:
|
| 144 |
+
sampled_chunk_end = sampled_chunk_start + 300
|
| 145 |
+
start_str = f"{sampled_chunk_start:07d}"
|
| 146 |
+
end_str = f"{sampled_chunk_end:07d}"
|
| 147 |
+
|
| 148 |
+
# 生成保存目录名(假设video_name已定义)
|
| 149 |
+
chunk_name = f"{video_name}_{start_str}_{end_str}"
|
| 150 |
+
save_chunk_dir = os.path.join(output_dir,chunk_name)
|
| 151 |
+
|
| 152 |
+
os.makedirs(save_chunk_dir,exist_ok=True)
|
| 153 |
+
print(f"Encoding chunk {chunk_name}...")
|
| 154 |
+
|
| 155 |
+
encoded_path = os.path.join(save_chunk_dir, "encoded_video.pth")
|
| 156 |
+
|
| 157 |
+
if os.path.exists(encoded_path):
|
| 158 |
+
print(f"Chunk {chunk_name} already encoded, skipping...")
|
| 159 |
+
continue
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
chunk_frames = video_frames[:,:, sampled_chunk_start - start_frame : sampled_chunk_end - start_frame,...]
|
| 163 |
+
# print('extrinsic:',cam_emb['extrinsic'].shape)
|
| 164 |
+
chunk_cam_emb ={'extrinsic':cam_emb['extrinsic'][sampled_chunk_start - start_frame : sampled_chunk_end - start_frame],
|
| 165 |
+
'intrinsic':cam_emb['intrinsic']}
|
| 166 |
+
|
| 167 |
+
# print('chunk shape:',chunk_frames.shape)
|
| 168 |
+
|
| 169 |
+
with torch.no_grad():
|
| 170 |
+
latents = encoder.pipe.encode_video(chunk_frames, **encoder.tiler_kwargs)[0]
|
| 171 |
+
|
| 172 |
+
# 编码文本
|
| 173 |
+
# if processed_count == 0:
|
| 174 |
+
# print('encode prompt!!!')
|
| 175 |
+
# prompt_emb = encoder.pipe.encode_prompt("A video of a scene shot using a pedestrian's front camera while walking")
|
| 176 |
+
# del encoder.pipe.prompter
|
| 177 |
+
# pdb.set_trace()
|
| 178 |
+
# 保存编码结果
|
| 179 |
+
encoded_data = {
|
| 180 |
+
"latents": latents.cpu(),
|
| 181 |
+
# "prompt_emb": {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in prompt_emb.items()},
|
| 182 |
+
"cam_emb": chunk_cam_emb
|
| 183 |
+
}
|
| 184 |
+
# pdb.set_trace()
|
| 185 |
+
torch.save(encoded_data, encoded_path)
|
| 186 |
+
print(f"Saved encoded data: {encoded_path}")
|
| 187 |
+
processed_chunk_count += 1
|
| 188 |
+
|
| 189 |
+
processed_count += 1
|
| 190 |
+
|
| 191 |
+
print("Encoded scene numebr:",processed_count)
|
| 192 |
+
print("Encoded chunk numebr:",processed_chunk_count)
|
| 193 |
+
|
| 194 |
+
# os.makedirs(save_dir,exist_ok=True)
|
| 195 |
+
# # 检查是否已编码
|
| 196 |
+
# encoded_path = os.path.join(save_dir, "encoded_video.pth")
|
| 197 |
+
# if os.path.exists(encoded_path):
|
| 198 |
+
# print(f"Scene {scene_name} already encoded, skipping...")
|
| 199 |
+
# continue
|
| 200 |
+
|
| 201 |
+
# 加载场景信息
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
# try:
|
| 206 |
+
# print(f"Encoding scene {scene_name}...")
|
| 207 |
+
|
| 208 |
+
# 加载和编码视频
|
| 209 |
+
|
| 210 |
+
# 编码视频
|
| 211 |
+
# with torch.no_grad():
|
| 212 |
+
# latents = encoder.pipe.encode_video(video_frames, **encoder.tiler_kwargs)[0]
|
| 213 |
+
|
| 214 |
+
# # 编码文本
|
| 215 |
+
# if processed_count == 0:
|
| 216 |
+
# print('encode prompt!!!')
|
| 217 |
+
# prompt_emb = encoder.pipe.encode_prompt("A video of a scene shot using a pedestrian's front camera while walking")
|
| 218 |
+
# del encoder.pipe.prompter
|
| 219 |
+
# # pdb.set_trace()
|
| 220 |
+
# # 保存编码结果
|
| 221 |
+
# encoded_data = {
|
| 222 |
+
# "latents": latents.cpu(),
|
| 223 |
+
# #"prompt_emb": {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in prompt_emb.items()},
|
| 224 |
+
# "cam_emb": cam_emb
|
| 225 |
+
# }
|
| 226 |
+
# # pdb.set_trace()
|
| 227 |
+
# torch.save(encoded_data, encoded_path)
|
| 228 |
+
# print(f"Saved encoded data: {encoded_path}")
|
| 229 |
+
# processed_count += 1
|
| 230 |
+
|
| 231 |
+
# except Exception as e:
|
| 232 |
+
# print(f"Error encoding scene {scene_name}: {e}")
|
| 233 |
+
# continue
|
| 234 |
+
|
| 235 |
+
print(f"Encoding completed! Processed {processed_count} scenes.")
|
| 236 |
+
|
| 237 |
+
if __name__ == "__main__":
|
| 238 |
+
parser = argparse.ArgumentParser()
|
| 239 |
+
parser.add_argument("--scenes_path", type=str, default="/share_zhuyixuan05/public_datasets/sekai/Sekai-Project/sekai-game-walking")
|
| 240 |
+
parser.add_argument("--text_encoder_path", type=str,
|
| 241 |
+
default="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth")
|
| 242 |
+
parser.add_argument("--vae_path", type=str,
|
| 243 |
+
default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth")
|
| 244 |
+
|
| 245 |
+
parser.add_argument("--output_dir",type=str,
|
| 246 |
+
default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking")
|
| 247 |
+
|
| 248 |
+
args = parser.parse_args()
|
| 249 |
+
encode_scenes(args.scenes_path, args.text_encoder_path, args.vae_path,args.output_dir)
|
scripts/encode_spatialvid.py
ADDED
|
@@ -0,0 +1,409 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
import lightning as pl
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 7 |
+
import json
|
| 8 |
+
import imageio
|
| 9 |
+
from torchvision.transforms import v2
|
| 10 |
+
from einops import rearrange
|
| 11 |
+
import argparse
|
| 12 |
+
import numpy as np
|
| 13 |
+
import pdb
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
import pandas as pd
|
| 16 |
+
|
| 17 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 18 |
+
|
| 19 |
+
from scipy.spatial.transform import Slerp
|
| 20 |
+
from scipy.spatial.transform import Rotation as R
|
| 21 |
+
|
| 22 |
+
def interpolate_camera_poses(original_frames, original_poses, target_frames):
|
| 23 |
+
"""
|
| 24 |
+
对相机姿态进行插值,生成目标帧对应的姿态参数
|
| 25 |
+
|
| 26 |
+
参数:
|
| 27 |
+
original_frames: 原始帧索引列表,如[0,6,12,...]
|
| 28 |
+
original_poses: 原始姿态数组,形状为(n,7),每行[tx, ty, tz, qx, qy, qz, qw]
|
| 29 |
+
target_frames: 目标帧索引列表,如[0,4,8,12,...]
|
| 30 |
+
|
| 31 |
+
返回:
|
| 32 |
+
target_poses: 插值后的姿态数组,形状为(m,7),m为目标帧数量
|
| 33 |
+
"""
|
| 34 |
+
# 确保输入有效
|
| 35 |
+
print('original_frames:',len(original_frames))
|
| 36 |
+
print('original_poses:',len(original_poses))
|
| 37 |
+
if len(original_frames) != len(original_poses):
|
| 38 |
+
raise ValueError("原始帧数量与姿态数量不匹配")
|
| 39 |
+
|
| 40 |
+
if original_poses.shape[1] != 7:
|
| 41 |
+
raise ValueError(f"原始姿态应为(n,7)格式,实际为{original_poses.shape}")
|
| 42 |
+
|
| 43 |
+
target_poses = []
|
| 44 |
+
|
| 45 |
+
# 提取旋转部分并转换为Rotation对象
|
| 46 |
+
rotations = R.from_quat(original_poses[:, 3:7]) # 提取四元数部分
|
| 47 |
+
|
| 48 |
+
for t in target_frames:
|
| 49 |
+
# 找到t前后的原始帧索引
|
| 50 |
+
idx = np.searchsorted(original_frames, t, side='left')
|
| 51 |
+
|
| 52 |
+
# 处理边界情况
|
| 53 |
+
if idx == 0:
|
| 54 |
+
# 使用第一个姿态
|
| 55 |
+
target_poses.append(original_poses[0])
|
| 56 |
+
continue
|
| 57 |
+
if idx >= len(original_frames):
|
| 58 |
+
# 使用最后一个姿态
|
| 59 |
+
target_poses.append(original_poses[-1])
|
| 60 |
+
continue
|
| 61 |
+
|
| 62 |
+
# 获取前后帧的信息
|
| 63 |
+
t_prev, t_next = original_frames[idx-1], original_frames[idx]
|
| 64 |
+
pose_prev, pose_next = original_poses[idx-1], original_poses[idx]
|
| 65 |
+
|
| 66 |
+
# 计算插值权重
|
| 67 |
+
alpha = (t - t_prev) / (t_next - t_prev)
|
| 68 |
+
|
| 69 |
+
# 1. 平移向量的线性插值
|
| 70 |
+
translation_prev = pose_prev[:3]
|
| 71 |
+
translation_next = pose_next[:3]
|
| 72 |
+
interpolated_translation = translation_prev + alpha * (translation_next - translation_prev)
|
| 73 |
+
|
| 74 |
+
# 2. 旋转四元数的球面线性插值(SLERP)
|
| 75 |
+
# 创建Slerp对象
|
| 76 |
+
slerp = Slerp([t_prev, t_next], rotations[idx-1:idx+1])
|
| 77 |
+
interpolated_rotation = slerp(t)
|
| 78 |
+
|
| 79 |
+
# 组合平移和旋转
|
| 80 |
+
interpolated_pose = np.concatenate([
|
| 81 |
+
interpolated_translation,
|
| 82 |
+
interpolated_rotation.as_quat() # 转换回四元数
|
| 83 |
+
])
|
| 84 |
+
|
| 85 |
+
target_poses.append(interpolated_pose)
|
| 86 |
+
|
| 87 |
+
return np.array(target_poses)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class VideoEncoder(pl.LightningModule):
|
| 91 |
+
def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
| 92 |
+
super().__init__()
|
| 93 |
+
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
|
| 94 |
+
model_manager.load_models([text_encoder_path, vae_path])
|
| 95 |
+
self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
|
| 96 |
+
self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
| 97 |
+
|
| 98 |
+
self.frame_process = v2.Compose([
|
| 99 |
+
# v2.CenterCrop(size=(900, 1600)),
|
| 100 |
+
# v2.Resize(size=(900, 1600), antialias=True),
|
| 101 |
+
v2.ToTensor(),
|
| 102 |
+
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
| 103 |
+
])
|
| 104 |
+
|
| 105 |
+
def crop_and_resize(self, image):
|
| 106 |
+
width, height = image.size
|
| 107 |
+
# print(width,height)
|
| 108 |
+
width_ori, height_ori_ = 832 , 480
|
| 109 |
+
image = v2.functional.resize(
|
| 110 |
+
image,
|
| 111 |
+
(round(height_ori_), round(width_ori)),
|
| 112 |
+
interpolation=v2.InterpolationMode.BILINEAR
|
| 113 |
+
)
|
| 114 |
+
return image
|
| 115 |
+
|
| 116 |
+
def load_video_frames(self, video_path):
|
| 117 |
+
"""加载完整视频"""
|
| 118 |
+
reader = imageio.get_reader(video_path)
|
| 119 |
+
frames = []
|
| 120 |
+
|
| 121 |
+
for frame_data in reader:
|
| 122 |
+
frame = Image.fromarray(frame_data)
|
| 123 |
+
frame = self.crop_and_resize(frame)
|
| 124 |
+
frame = self.frame_process(frame)
|
| 125 |
+
frames.append(frame)
|
| 126 |
+
|
| 127 |
+
reader.close()
|
| 128 |
+
|
| 129 |
+
if len(frames) == 0:
|
| 130 |
+
return None
|
| 131 |
+
|
| 132 |
+
frames = torch.stack(frames, dim=0)
|
| 133 |
+
frames = rearrange(frames, "T C H W -> C T H W")
|
| 134 |
+
return frames
|
| 135 |
+
|
| 136 |
+
def encode_scenes(scenes_path, text_encoder_path, vae_path,output_dir):
|
| 137 |
+
"""编码所有场景的视频"""
|
| 138 |
+
|
| 139 |
+
encoder = VideoEncoder(text_encoder_path, vae_path)
|
| 140 |
+
encoder = encoder.cuda()
|
| 141 |
+
encoder.pipe.device = "cuda"
|
| 142 |
+
|
| 143 |
+
processed_count = 0
|
| 144 |
+
|
| 145 |
+
processed_chunk_count = 0
|
| 146 |
+
|
| 147 |
+
prompt_emb = 0
|
| 148 |
+
|
| 149 |
+
metadata = pd.read_csv('/share_zhuyixuan05/public_datasets/SpatialVID-HQ/data/train/SpatialVID_HQ_metadata.csv')
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
os.makedirs(output_dir,exist_ok=True)
|
| 153 |
+
chunk_size = 300
|
| 154 |
+
required_keys = ["latents", "cam_emb", "prompt_emb"]
|
| 155 |
+
|
| 156 |
+
for i, scene_name in enumerate(os.listdir(scenes_path)):
|
| 157 |
+
# print('index-----:',type(i))
|
| 158 |
+
if i < 3 :#or i >=2000:
|
| 159 |
+
# # print('index-----:',i)
|
| 160 |
+
continue
|
| 161 |
+
# print('index:',i)
|
| 162 |
+
print('group:',i)
|
| 163 |
+
scene_dir = os.path.join(scenes_path, scene_name)
|
| 164 |
+
|
| 165 |
+
# save_dir = os.path.join(output_dir,scene_name.split('.')[0])
|
| 166 |
+
print('in:',scene_dir)
|
| 167 |
+
# print('out:',save_dir)
|
| 168 |
+
for j, video_name in tqdm(enumerate(os.listdir(scene_dir)),total=len(os.listdir(scene_dir))):
|
| 169 |
+
|
| 170 |
+
# if j < 1000 :#or i >=2000:
|
| 171 |
+
# print('index:',j)
|
| 172 |
+
# continue
|
| 173 |
+
print(video_name)
|
| 174 |
+
video_path = os.path.join(scene_dir, video_name)
|
| 175 |
+
if not video_path.endswith(".mp4"):# or os.path.isdir(output_dir):
|
| 176 |
+
continue
|
| 177 |
+
|
| 178 |
+
video_info = metadata[metadata['id'] == video_name[:-4]]
|
| 179 |
+
num_frames = video_info['num frames'].iloc[0]
|
| 180 |
+
|
| 181 |
+
scene_cam_dir = video_path.replace( "videos","annotations")[:-4]
|
| 182 |
+
scene_cam_path = os.path.join(scene_cam_dir,'poses.npy')
|
| 183 |
+
|
| 184 |
+
scene_caption_path = os.path.join(scene_cam_dir,'caption.json')
|
| 185 |
+
|
| 186 |
+
with open(scene_caption_path, 'r', encoding='utf-8') as f:
|
| 187 |
+
caption_data = json.load(f)
|
| 188 |
+
caption = caption_data["SceneSummary"]
|
| 189 |
+
if not os.path.exists(scene_cam_path):
|
| 190 |
+
print(f"Pose not found: {scene_cam_path}")
|
| 191 |
+
continue
|
| 192 |
+
|
| 193 |
+
camera_poses = np.load(scene_cam_path)
|
| 194 |
+
cam_data_len = camera_poses.shape[0]
|
| 195 |
+
|
| 196 |
+
# cam_emb = {k: data[k].cpu() if isinstance(data[k], torch.Tensor) else data[k] for k in cam_data}
|
| 197 |
+
# with open(scene_cam_path, 'rb') as f:
|
| 198 |
+
# cam_data = np.load(f) # 此时cam_data仅包含数据,无文件句柄引用
|
| 199 |
+
|
| 200 |
+
# 加载视频
|
| 201 |
+
# video_path = scene_dir
|
| 202 |
+
if not os.path.exists(video_path):
|
| 203 |
+
print(f"Video not found: {video_path}")
|
| 204 |
+
continue
|
| 205 |
+
|
| 206 |
+
start_str = f"{0:07d}"
|
| 207 |
+
end_str = f"{chunk_size:07d}"
|
| 208 |
+
chunk_name = f"{video_name[:-4]}_{start_str}_{end_str}"
|
| 209 |
+
first_save_chunk_dir = os.path.join(output_dir,chunk_name)
|
| 210 |
+
|
| 211 |
+
first_chunk_encoded_path = os.path.join(first_save_chunk_dir, "encoded_video.pth")
|
| 212 |
+
# print(first_chunk_encoded_path)
|
| 213 |
+
if os.path.exists(first_chunk_encoded_path):
|
| 214 |
+
data = torch.load(first_chunk_encoded_path,weights_only=False)
|
| 215 |
+
if 'latents' in data:
|
| 216 |
+
video_frames = 1
|
| 217 |
+
else:
|
| 218 |
+
video_frames = encoder.load_video_frames(video_path)
|
| 219 |
+
if video_frames is None:
|
| 220 |
+
print(f"Failed to load video: {video_path}")
|
| 221 |
+
continue
|
| 222 |
+
print('video shape:',video_frames.shape)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
video_frames = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16)
|
| 227 |
+
print('video shape:',video_frames.shape)
|
| 228 |
+
|
| 229 |
+
video_name = video_name[:-4].split('_')[0]
|
| 230 |
+
start_frame = 0
|
| 231 |
+
end_frame = num_frames
|
| 232 |
+
# print("num_frames:",num_frames)
|
| 233 |
+
|
| 234 |
+
cam_interval = end_frame // (cam_data_len - 1)
|
| 235 |
+
|
| 236 |
+
cam_frames = np.linspace(start_frame, end_frame, cam_data_len, endpoint=True)
|
| 237 |
+
cam_frames = np.round(cam_frames).astype(int)
|
| 238 |
+
cam_frames = cam_frames.tolist()
|
| 239 |
+
# list(range(0, end_frame + 1 , cam_interval))
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
sampled_range = range(start_frame, end_frame , chunk_size)
|
| 243 |
+
sampled_frames = list(sampled_range)
|
| 244 |
+
|
| 245 |
+
sampled_chunk_end = sampled_frames[0] + chunk_size
|
| 246 |
+
start_str = f"{sampled_frames[0]:07d}"
|
| 247 |
+
end_str = f"{sampled_chunk_end:07d}"
|
| 248 |
+
|
| 249 |
+
chunk_name = f"{video_name}_{start_str}_{end_str}"
|
| 250 |
+
# save_chunk_path = os.path.join(output_dir,chunk_name,"encoded_video.pth")
|
| 251 |
+
|
| 252 |
+
# if os.path.exists(save_chunk_path):
|
| 253 |
+
# print(f"Video {video_name} already encoded, skipping...")
|
| 254 |
+
# continue
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
# print(sampled_frames)
|
| 261 |
+
|
| 262 |
+
print(f"Encoding scene {video_name}...")
|
| 263 |
+
chunk_count_in_one_video = 0
|
| 264 |
+
for sampled_chunk_start in sampled_frames:
|
| 265 |
+
if num_frames - sampled_chunk_start < 100:
|
| 266 |
+
continue
|
| 267 |
+
sampled_chunk_end = sampled_chunk_start + chunk_size
|
| 268 |
+
start_str = f"{sampled_chunk_start:07d}"
|
| 269 |
+
end_str = f"{sampled_chunk_end:07d}"
|
| 270 |
+
|
| 271 |
+
resample_cam_frame = list(range(sampled_chunk_start, sampled_chunk_end , 4))
|
| 272 |
+
|
| 273 |
+
# 生成保存目录名(假设video_name已定义)
|
| 274 |
+
chunk_name = f"{video_name}_{start_str}_{end_str}"
|
| 275 |
+
save_chunk_dir = os.path.join(output_dir,chunk_name)
|
| 276 |
+
|
| 277 |
+
os.makedirs(save_chunk_dir,exist_ok=True)
|
| 278 |
+
print(f"Encoding chunk {chunk_name}...")
|
| 279 |
+
|
| 280 |
+
encoded_path = os.path.join(save_chunk_dir, "encoded_video.pth")
|
| 281 |
+
|
| 282 |
+
missing_keys = required_keys
|
| 283 |
+
if os.path.exists(encoded_path):
|
| 284 |
+
print('error:',encoded_path)
|
| 285 |
+
data = torch.load(encoded_path,weights_only=False)
|
| 286 |
+
missing_keys = [key for key in required_keys if key not in data]
|
| 287 |
+
# print(missing_keys)
|
| 288 |
+
# print(f"Chunk {chunk_name} already encoded, skipping...")
|
| 289 |
+
if missing_keys:
|
| 290 |
+
print(f"警告: 文件中缺少以下必要元素: {missing_keys}")
|
| 291 |
+
if len(missing_keys) == 0 :
|
| 292 |
+
continue
|
| 293 |
+
else:
|
| 294 |
+
print(f"警告: 缺少pth文件: {encoded_path}")
|
| 295 |
+
if not isinstance(video_frames, torch.Tensor):
|
| 296 |
+
|
| 297 |
+
video_frames = encoder.load_video_frames(video_path)
|
| 298 |
+
if video_frames is None:
|
| 299 |
+
print(f"Failed to load video: {video_path}")
|
| 300 |
+
continue
|
| 301 |
+
|
| 302 |
+
video_frames = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16)
|
| 303 |
+
|
| 304 |
+
print('video shape:',video_frames.shape)
|
| 305 |
+
if "latents" in missing_keys:
|
| 306 |
+
chunk_frames = video_frames[:,:, sampled_chunk_start - start_frame : sampled_chunk_end - start_frame,...]
|
| 307 |
+
|
| 308 |
+
# print('extrinsic:',cam_emb['extrinsic'].shape)
|
| 309 |
+
|
| 310 |
+
# chunk_cam_emb ={'extrinsic':cam_emb['extrinsic'][sampled_chunk_start - start_frame : sampled_chunk_end - start_frame],
|
| 311 |
+
# 'intrinsic':cam_emb['intrinsic']}
|
| 312 |
+
|
| 313 |
+
# print('chunk shape:',chunk_frames.shape)
|
| 314 |
+
|
| 315 |
+
with torch.no_grad():
|
| 316 |
+
latents = encoder.pipe.encode_video(chunk_frames, **encoder.tiler_kwargs)[0]
|
| 317 |
+
else:
|
| 318 |
+
latents = data['latents']
|
| 319 |
+
if "cam_emb" in missing_keys:
|
| 320 |
+
cam_emb = interpolate_camera_poses(cam_frames, camera_poses,resample_cam_frame)
|
| 321 |
+
chunk_cam_emb ={'extrinsic':cam_emb}
|
| 322 |
+
print(f"视频长度:{chunk_size},重采样相机长度:{cam_emb.shape[0]}")
|
| 323 |
+
else:
|
| 324 |
+
chunk_cam_emb = data['cam_emb']
|
| 325 |
+
|
| 326 |
+
if "prompt_emb" in missing_keys:
|
| 327 |
+
# 编码文本
|
| 328 |
+
if chunk_count_in_one_video == 0:
|
| 329 |
+
print(caption)
|
| 330 |
+
with torch.no_grad():
|
| 331 |
+
prompt_emb = encoder.pipe.encode_prompt(caption)
|
| 332 |
+
else:
|
| 333 |
+
prompt_emb = data['prompt_emb']
|
| 334 |
+
|
| 335 |
+
# del encoder.pipe.prompter
|
| 336 |
+
# pdb.set_trace()
|
| 337 |
+
# 保存编码结果
|
| 338 |
+
encoded_data = {
|
| 339 |
+
"latents": latents.cpu(),
|
| 340 |
+
"prompt_emb": {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in prompt_emb.items()},
|
| 341 |
+
"cam_emb": chunk_cam_emb
|
| 342 |
+
}
|
| 343 |
+
# pdb.set_trace()
|
| 344 |
+
torch.save(encoded_data, encoded_path)
|
| 345 |
+
print(f"Saved encoded data: {encoded_path}")
|
| 346 |
+
processed_chunk_count += 1
|
| 347 |
+
chunk_count_in_one_video += 1
|
| 348 |
+
|
| 349 |
+
processed_count += 1
|
| 350 |
+
|
| 351 |
+
print("Encoded scene numebr:",processed_count)
|
| 352 |
+
print("Encoded chunk numebr:",processed_chunk_count)
|
| 353 |
+
|
| 354 |
+
# os.makedirs(save_dir,exist_ok=True)
|
| 355 |
+
# # 检查是否已编码
|
| 356 |
+
# encoded_path = os.path.join(save_dir, "encoded_video.pth")
|
| 357 |
+
# if os.path.exists(encoded_path):
|
| 358 |
+
# print(f"Scene {scene_name} already encoded, skipping...")
|
| 359 |
+
# continue
|
| 360 |
+
|
| 361 |
+
# 加载场景信息
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
# try:
|
| 366 |
+
# print(f"Encoding scene {scene_name}...")
|
| 367 |
+
|
| 368 |
+
# 加载和编码视频
|
| 369 |
+
|
| 370 |
+
# 编码视频
|
| 371 |
+
# with torch.no_grad():
|
| 372 |
+
# latents = encoder.pipe.encode_video(video_frames, **encoder.tiler_kwargs)[0]
|
| 373 |
+
|
| 374 |
+
# # 编码文本
|
| 375 |
+
# if processed_count == 0:
|
| 376 |
+
# print('encode prompt!!!')
|
| 377 |
+
# prompt_emb = encoder.pipe.encode_prompt("A video of a scene shot using a pedestrian's front camera while walking")
|
| 378 |
+
# del encoder.pipe.prompter
|
| 379 |
+
# # pdb.set_trace()
|
| 380 |
+
# # 保存编码结果
|
| 381 |
+
# encoded_data = {
|
| 382 |
+
# "latents": latents.cpu(),
|
| 383 |
+
# #"prompt_emb": {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in prompt_emb.items()},
|
| 384 |
+
# "cam_emb": cam_emb
|
| 385 |
+
# }
|
| 386 |
+
# # pdb.set_trace()
|
| 387 |
+
# torch.save(encoded_data, encoded_path)
|
| 388 |
+
# print(f"Saved encoded data: {encoded_path}")
|
| 389 |
+
# processed_count += 1
|
| 390 |
+
|
| 391 |
+
# except Exception as e:
|
| 392 |
+
# print(f"Error encoding scene {scene_name}: {e}")
|
| 393 |
+
# continue
|
| 394 |
+
|
| 395 |
+
print(f"Encoding completed! Processed {processed_count} scenes.")
|
| 396 |
+
|
| 397 |
+
if __name__ == "__main__":
|
| 398 |
+
parser = argparse.ArgumentParser()
|
| 399 |
+
parser.add_argument("--scenes_path", type=str, default="/share_zhuyixuan05/public_datasets/SpatialVID-HQ/SpatialVid/HQ/videos/")
|
| 400 |
+
parser.add_argument("--text_encoder_path", type=str,
|
| 401 |
+
default="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth")
|
| 402 |
+
parser.add_argument("--vae_path", type=str,
|
| 403 |
+
default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth")
|
| 404 |
+
|
| 405 |
+
parser.add_argument("--output_dir",type=str,
|
| 406 |
+
default="/share_zhuyixuan05/zhuyixuan05/spatialvid")
|
| 407 |
+
|
| 408 |
+
args = parser.parse_args()
|
| 409 |
+
encode_scenes(args.scenes_path, args.text_encoder_path, args.vae_path,args.output_dir)
|
scripts/encode_spatialvid_first_frame.py
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
import lightning as pl
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 7 |
+
import json
|
| 8 |
+
import imageio
|
| 9 |
+
from torchvision.transforms import v2
|
| 10 |
+
from einops import rearrange
|
| 11 |
+
import argparse
|
| 12 |
+
import numpy as np
|
| 13 |
+
import pdb
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
import pandas as pd
|
| 16 |
+
|
| 17 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 18 |
+
|
| 19 |
+
from scipy.spatial.transform import Slerp
|
| 20 |
+
from scipy.spatial.transform import Rotation as R
|
| 21 |
+
|
| 22 |
+
def interpolate_camera_poses(original_frames, original_poses, target_frames):
    """Interpolate camera poses so that every target frame has a pose.

    Translations are linearly interpolated between the two surrounding key
    poses; rotations use spherical linear interpolation (SLERP). Target
    frames outside the original frame range are clamped to the first/last
    original pose (returned verbatim).

    Args:
        original_frames: frame indices the poses were sampled at, e.g. [0, 6, 12, ...].
        original_poses: array of shape (n, 7); each row is [tx, ty, tz, qx, qy, qz, qw].
        target_frames: frame indices to produce poses for, e.g. [0, 4, 8, 12, ...].

    Returns:
        np.ndarray of shape (m, 7), one interpolated pose per target frame.

    Raises:
        ValueError: if frame/pose counts mismatch or poses are not (n, 7).
    """
    original_frames = np.asarray(original_frames)
    original_poses = np.asarray(original_poses)

    print('original_frames:', len(original_frames))
    print('original_poses:', len(original_poses))
    if len(original_frames) != len(original_poses):
        raise ValueError("原始帧数量与姿态数量不匹配")

    if original_poses.shape[1] != 7:
        raise ValueError(f"原始姿态应为(n,7)格式,实际为{original_poses.shape}")

    # Build the rotations and ONE Slerp interpolator up front. The original
    # code constructed a fresh Slerp object for every target frame inside the
    # loop; piecewise SLERP over all keyframes yields identical results.
    rotations = R.from_quat(original_poses[:, 3:7])
    slerp = Slerp(original_frames, rotations) if len(original_frames) >= 2 else None

    target_poses = []
    for t in target_frames:
        idx = np.searchsorted(original_frames, t, side='left')

        # Clamp targets outside the keyframe range to the boundary poses.
        if idx == 0:
            target_poses.append(original_poses[0])
            continue
        if idx >= len(original_frames):
            target_poses.append(original_poses[-1])
            continue

        t_prev, t_next = original_frames[idx - 1], original_frames[idx]
        pose_prev, pose_next = original_poses[idx - 1], original_poses[idx]

        # Interpolation weight within the [t_prev, t_next] segment.
        alpha = (t - t_prev) / (t_next - t_prev)

        # 1. Linear interpolation of the translation component.
        interpolated_translation = pose_prev[:3] + alpha * (pose_next[:3] - pose_prev[:3])

        # 2. Spherical linear interpolation (SLERP) of the rotation.
        interpolated_rotation = slerp(t)

        target_poses.append(np.concatenate([
            interpolated_translation,
            interpolated_rotation.as_quat(),  # back to quaternion form
        ]))

    return np.array(target_poses)
|
| 88 |
+
|
| 89 |
+
class VideoEncoder(pl.LightningModule):
    """Thin wrapper around a WanVideoReCamMaster pipeline (text encoder + VAE)
    used to preprocess video frames and encode them to latents on the CPU/GPU.
    """

    def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        super().__init__()
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        model_manager.load_models([text_encoder_path, vae_path])
        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
        # Tiling options forwarded to the VAE encoder to bound peak memory.
        self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # PIL image -> float tensor normalized to [-1, 1].
        self.frame_process = v2.Compose([
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def crop_and_resize(self, image):
        """Resize a PIL image to the fixed 832x480 target resolution."""
        # (the original unpacked image.size into unused locals — removed)
        target_width, target_height = 832, 480
        image = v2.functional.resize(
            image,
            (target_height, target_width),
            interpolation=v2.InterpolationMode.BILINEAR
        )
        return image

    def load_single_frame(self, video_path, frame_idx):
        """Load and preprocess one frame.

        Returns a [1, C, 1, H, W] tensor, or None if the frame cannot be read.
        """
        reader = imageio.get_reader(video_path)
        try:
            # Seek directly to the requested frame instead of decoding the
            # whole video.
            frame_data = reader.get_data(frame_idx)
            frame = Image.fromarray(frame_data)
            frame = self.crop_and_resize(frame)
            frame = self.frame_process(frame)
            # Add batch and time dimensions: [C, H, W] -> [1, C, 1, H, W]
            frame = frame.unsqueeze(0).unsqueeze(2)
        except Exception as e:
            print(f"Error loading frame {frame_idx} from {video_path}: {e}")
            return None
        finally:
            reader.close()
        return frame

    def load_video_frames(self, video_path):
        """Load a full video as a [C, T, H, W] tensor (kept for compatibility).

        Returns None if the video contains no frames.
        """
        reader = imageio.get_reader(video_path)
        frames = []
        try:
            for frame_data in reader:
                frame = Image.fromarray(frame_data)
                frame = self.crop_and_resize(frame)
                frame = self.frame_process(frame)
                frames.append(frame)
        finally:
            # Close the reader even if decoding raises mid-iteration;
            # the original leaked the handle on error.
            reader.close()

        if len(frames) == 0:
            return None

        frames = torch.stack(frames, dim=0)
        frames = rearrange(frames, "T C H W -> C T H W")
        return frames
|
| 153 |
+
|
| 154 |
+
def encode_scenes(scenes_path, text_encoder_path, vae_path, output_dir):
    """Encode the first frame of every chunk of every scene video to VAE latents.

    Walks ``scenes_path`` (one sub-directory per scene group, each containing
    .mp4 files), splits each video into fixed-size chunks, encodes each chunk's
    first frame (repeated 4x along time) with the pipeline VAE, and saves it as
    ``first_latent.pth`` under ``output_dir/<chunk_name>/``. Chunks with an
    existing latent file are skipped, so the run is resumable.
    """
    encoder = VideoEncoder(text_encoder_path, vae_path)
    encoder = encoder.cuda()
    encoder.pipe.device = "cuda"

    processed_count = 0
    processed_chunk_count = 0

    # NOTE(review): hard-coded metadata path — consider promoting to a CLI argument.
    metadata = pd.read_csv('/share_zhuyixuan05/public_datasets/SpatialVID-HQ/data/train/SpatialVID_HQ_metadata.csv')

    os.makedirs(output_dir, exist_ok=True)
    chunk_size = 300  # frames per chunk

    for i, scene_name in enumerate(os.listdir(scenes_path)):
        # NOTE(review): the first two groups are intentionally skipped
        # (looks like resume/debug logic — confirm before removing).
        if i < 2:
            continue
        print('group:', i)
        scene_dir = os.path.join(scenes_path, scene_name)

        print('in:', scene_dir)
        video_names = os.listdir(scene_dir)  # listed once, not twice
        for j, video_name in tqdm(enumerate(video_names), total=len(video_names)):
            print(video_name)
            video_path = os.path.join(scene_dir, video_name)
            if not video_path.endswith(".mp4"):
                continue

            video_info = metadata[metadata['id'] == video_name[:-4]]
            if video_info.empty:
                # Guard: without a metadata row the frame count is unknown
                # (the original crashed with IndexError here).
                print(f"Metadata not found for: {video_name}")
                continue
            num_frames = video_info['num frames'].iloc[0]

            scene_cam_dir = video_path.replace("videos", "annotations")[:-4]
            scene_cam_path = os.path.join(scene_cam_dir, 'poses.npy')
            scene_caption_path = os.path.join(scene_cam_dir, 'caption.json')

            if not os.path.exists(scene_caption_path):
                # Guard like the pose check below; a missing caption file used
                # to abort the whole run with FileNotFoundError.
                print(f"Caption not found: {scene_caption_path}")
                continue
            with open(scene_caption_path, 'r', encoding='utf-8') as f:
                caption_data = json.load(f)
            caption = caption_data["SceneSummary"]  # NOTE: loaded for validation; unused downstream

            if not os.path.exists(scene_cam_path):
                print(f"Pose not found: {scene_cam_path}")
                continue

            camera_poses = np.load(scene_cam_path)
            cam_data_len = camera_poses.shape[0]

            if not os.path.exists(video_path):
                print(f"Video not found: {video_path}")
                continue

            # Chunk names use the base id before the first underscore.
            video_name = video_name[:-4].split('_')[0]
            start_frame = 0
            end_frame = num_frames

            # (removed unused cam_interval/cam_frames computations; cam_interval
            # divided by cam_data_len - 1 and raised ZeroDivisionError for a
            # single-pose file)

            # Chunk start offsets: [0, chunk_size, 2*chunk_size, ...)
            sampled_frames = list(range(start_frame, end_frame, chunk_size))

            print(f"Encoding scene {video_name}...")
            chunk_count_in_one_video = 0

            for sampled_chunk_start in sampled_frames:
                # Skip tail chunks too short to be useful.
                if num_frames - sampled_chunk_start < 100:
                    continue

                sampled_chunk_end = sampled_chunk_start + chunk_size
                chunk_name = f"{video_name}_{sampled_chunk_start:07d}_{sampled_chunk_end:07d}"
                save_chunk_dir = os.path.join(output_dir, chunk_name)
                os.makedirs(save_chunk_dir, exist_ok=True)

                print(f"Encoding chunk {chunk_name}...")

                first_latent_path = os.path.join(save_chunk_dir, "first_latent.pth")
                if os.path.exists(first_latent_path):
                    print(f"First latent for chunk {chunk_name} already exists, skipping...")
                    continue

                # Only the chunk's first frame is needed.
                first_frame_idx = sampled_chunk_start
                print(f"first_frame:{first_frame_idx}")
                first_frame = encoder.load_single_frame(video_path, first_frame_idx)

                if first_frame is None:
                    print(f"Failed to load frame {first_frame_idx} from: {video_path}")
                    continue

                first_frame = first_frame.to("cuda", dtype=torch.bfloat16)

                # Repeat the frame 4 times along the time axis — presumably to
                # match the VAE's temporal compression window (TODO confirm).
                repeated_first_frame = first_frame.repeat(1, 1, 4, 1, 1)
                print(f"Repeated first frame shape: {repeated_first_frame.shape}")

                with torch.no_grad():
                    first_latents = encoder.pipe.encode_video(repeated_first_frame, **encoder.tiler_kwargs)[0]

                torch.save({"latents": first_latents.cpu()}, first_latent_path)
                print(f"Saved first latent: {first_latent_path}")

                processed_chunk_count += 1
                chunk_count_in_one_video += 1

            processed_count += 1
            print("Encoded scene number:", processed_count)
            print("Encoded chunk number:", processed_chunk_count)

    print(f"Encoding completed! Processed {processed_count} scenes.")
|
| 272 |
+
|
| 273 |
+
if __name__ == "__main__":
|
| 274 |
+
parser = argparse.ArgumentParser()
|
| 275 |
+
parser.add_argument("--scenes_path", type=str, default="/share_zhuyixuan05/public_datasets/SpatialVID-HQ/SpatialVid/HQ/videos/")
|
| 276 |
+
parser.add_argument("--text_encoder_path", type=str,
|
| 277 |
+
default="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth")
|
| 278 |
+
parser.add_argument("--vae_path", type=str,
|
| 279 |
+
default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth")
|
| 280 |
+
|
| 281 |
+
parser.add_argument("--output_dir",type=str,
|
| 282 |
+
default="/share_zhuyixuan05/zhuyixuan05/spatialvid")
|
| 283 |
+
|
| 284 |
+
args = parser.parse_args()
|
| 285 |
+
encode_scenes(args.scenes_path, args.text_encoder_path, args.vae_path,args.output_dir)
|
scripts/hud_logo.py
CHANGED
|
@@ -7,7 +7,7 @@ os.makedirs("wasd_ui", exist_ok=True)
|
|
| 7 |
key_size = (48, 48)
|
| 8 |
corner = 10
|
| 9 |
bg_padding = 6
|
| 10 |
-
font = ImageFont.truetype("arial.ttf", 28) #
|
| 11 |
|
| 12 |
def rounded_rect(im, bbox, radius, fill):
|
| 13 |
draw = ImageDraw.Draw(im, "RGBA")
|
|
|
|
| 7 |
key_size = (48, 48)
|
| 8 |
corner = 10
|
| 9 |
bg_padding = 6
|
| 10 |
+
font = ImageFont.truetype("arial.ttf", 28) # 替换成本地支持的字体
|
| 11 |
|
| 12 |
def rounded_rect(im, bbox, radius, fill):
|
| 13 |
draw = ImageDraw.Draw(im, "RGBA")
|
scripts/infer_demo.py
CHANGED
|
@@ -1,7 +1,5 @@
|
|
| 1 |
import os
|
| 2 |
import sys
|
| 3 |
-
from pathlib import Path
|
| 4 |
-
from typing import Optional
|
| 5 |
|
| 6 |
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
| 7 |
sys.path.append(ROOT_DIR)
|
|
@@ -12,158 +10,48 @@ import numpy as np
|
|
| 12 |
from PIL import Image
|
| 13 |
import imageio
|
| 14 |
import json
|
| 15 |
-
from diffsynth import
|
| 16 |
import argparse
|
| 17 |
from torchvision.transforms import v2
|
| 18 |
from einops import rearrange
|
| 19 |
-
from scipy.spatial.transform import Rotation as R
|
| 20 |
import random
|
| 21 |
import copy
|
| 22 |
from datetime import datetime
|
| 23 |
|
| 24 |
-
VALID_IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg"}
|
| 25 |
-
class InlineVideoEncoder:
|
| 26 |
-
|
| 27 |
-
def __init__(self, pipe: WanVideoAstraPipeline, device="cuda"):
|
| 28 |
-
self.device = getattr(pipe, "device", device)
|
| 29 |
-
self.tiler_kwargs = {"tiled": True, "tile_size": (34, 34), "tile_stride": (18, 16)}
|
| 30 |
-
self.frame_process = v2.Compose([
|
| 31 |
-
v2.ToTensor(),
|
| 32 |
-
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
| 33 |
-
])
|
| 34 |
-
|
| 35 |
-
self.pipe = pipe
|
| 36 |
-
|
| 37 |
-
@staticmethod
|
| 38 |
-
def _crop_and_resize(image: Image.Image) -> Image.Image:
|
| 39 |
-
target_w, target_h = 832, 480
|
| 40 |
-
return v2.functional.resize(
|
| 41 |
-
image,
|
| 42 |
-
(round(target_h), round(target_w)),
|
| 43 |
-
interpolation=v2.InterpolationMode.BILINEAR,
|
| 44 |
-
)
|
| 45 |
-
|
| 46 |
-
def preprocess_frame(self, image: Image.Image) -> torch.Tensor:
|
| 47 |
-
image = image.convert("RGB")
|
| 48 |
-
image = self._crop_and_resize(image)
|
| 49 |
-
return self.frame_process(image)
|
| 50 |
-
|
| 51 |
-
def load_video_frames(self, video_path: Path) -> Optional[torch.Tensor]:
|
| 52 |
-
reader = imageio.get_reader(str(video_path))
|
| 53 |
-
frames = []
|
| 54 |
-
for frame_data in reader:
|
| 55 |
-
frame = Image.fromarray(frame_data)
|
| 56 |
-
frames.append(self.preprocess_frame(frame))
|
| 57 |
-
reader.close()
|
| 58 |
-
|
| 59 |
-
if not frames:
|
| 60 |
-
return None
|
| 61 |
-
|
| 62 |
-
frames = torch.stack(frames, dim=0)
|
| 63 |
-
return rearrange(frames, "T C H W -> C T H W")
|
| 64 |
-
|
| 65 |
-
def encode_frames_to_latents(self, frames: torch.Tensor) -> torch.Tensor:
|
| 66 |
-
frames = frames.unsqueeze(0).to(self.device, dtype=torch.bfloat16)
|
| 67 |
-
with torch.no_grad():
|
| 68 |
-
latents = self.pipe.encode_video(frames, **self.tiler_kwargs)[0]
|
| 69 |
-
|
| 70 |
-
if latents.dim() == 5 and latents.shape[0] == 1:
|
| 71 |
-
latents = latents.squeeze(0)
|
| 72 |
-
return latents.cpu()
|
| 73 |
-
|
| 74 |
-
def image_to_frame_stack(
|
| 75 |
-
image_path: Path,
|
| 76 |
-
encoder: InlineVideoEncoder,
|
| 77 |
-
repeat_count: int = 10
|
| 78 |
-
) -> torch.Tensor:
|
| 79 |
-
"""Repeat a single image into a tensor with specified number of frames, shape [C, T, H, W]"""
|
| 80 |
-
if image_path.suffix.lower() not in VALID_IMAGE_EXTENSIONS:
|
| 81 |
-
raise ValueError(f"Unsupported image format: {image_path.suffix}")
|
| 82 |
-
|
| 83 |
-
image = Image.open(str(image_path))
|
| 84 |
-
frame = encoder.preprocess_frame(image)
|
| 85 |
-
frames = torch.stack([frame for _ in range(repeat_count)], dim=0)
|
| 86 |
-
return rearrange(frames, "T C H W -> C T H W")
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
def load_or_encode_condition(
|
| 90 |
-
condition_pth_path: Optional[str],
|
| 91 |
-
condition_video: Optional[str],
|
| 92 |
-
condition_image: Optional[str],
|
| 93 |
-
start_frame: int,
|
| 94 |
-
num_frames: int,
|
| 95 |
-
device: str,
|
| 96 |
-
pipe: WanVideoAstraPipeline,
|
| 97 |
-
) -> tuple[torch.Tensor, dict]:
|
| 98 |
-
if condition_pth_path:
|
| 99 |
-
return load_encoded_video_from_pth(condition_pth_path, start_frame, num_frames)
|
| 100 |
-
|
| 101 |
-
encoder = InlineVideoEncoder(pipe=pipe, device=device)
|
| 102 |
-
|
| 103 |
-
if condition_video:
|
| 104 |
-
video_path = Path(condition_video).expanduser().resolve()
|
| 105 |
-
if not video_path.exists():
|
| 106 |
-
raise FileNotFoundError(f"File not Found: {video_path}")
|
| 107 |
-
frames = encoder.load_video_frames(video_path)
|
| 108 |
-
if frames is None:
|
| 109 |
-
raise ValueError(f"no valid frames in {video_path}")
|
| 110 |
-
elif condition_image:
|
| 111 |
-
image_path = Path(condition_image).expanduser().resolve()
|
| 112 |
-
if not image_path.exists():
|
| 113 |
-
raise FileNotFoundError(f"File not Found: {image_path}")
|
| 114 |
-
frames = image_to_frame_stack(image_path, encoder, repeat_count=10)
|
| 115 |
-
else:
|
| 116 |
-
raise ValueError("condition video or image is needed for video generation.")
|
| 117 |
-
|
| 118 |
-
latents = encoder.encode_frames_to_latents(frames)
|
| 119 |
-
encoded_data = {"latents": latents}
|
| 120 |
-
|
| 121 |
-
if start_frame + num_frames > latents.shape[1]:
|
| 122 |
-
raise ValueError(
|
| 123 |
-
f"Not enough frames after encoding: requested {start_frame + num_frames}, available {latents.shape[1]}"
|
| 124 |
-
)
|
| 125 |
-
|
| 126 |
-
condition_latents = latents[:, start_frame:start_frame + num_frames, :, :]
|
| 127 |
-
return condition_latents, encoded_data
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
def compute_relative_pose_matrix(pose1, pose2):
|
| 132 |
"""
|
| 133 |
-
|
| 134 |
|
| 135 |
-
|
| 136 |
-
pose1:
|
| 137 |
-
pose2:
|
| 138 |
|
| 139 |
-
|
| 140 |
-
relative_matrix:
|
| 141 |
-
first 3 columns are rotation matrix R_rel,
|
| 142 |
-
last column is translation vector t_rel
|
| 143 |
"""
|
| 144 |
-
#
|
| 145 |
-
t1 = pose1[:3] #
|
| 146 |
-
q1 = pose1[3:] #
|
| 147 |
-
t2 = pose2[:3] #
|
| 148 |
-
q2 = pose2[3:] #
|
| 149 |
-
|
| 150 |
-
# 1.
|
| 151 |
-
rot1 = R.from_quat(q1) #
|
| 152 |
-
rot2 = R.from_quat(q2) #
|
| 153 |
-
rot_rel = rot2 * rot1.inv() #
|
| 154 |
-
R_rel = rot_rel.as_matrix() #
|
| 155 |
-
|
| 156 |
-
# 2.
|
| 157 |
-
R1_T = rot1.as_matrix().T #
|
| 158 |
-
t_rel = R1_T @ (t2 - t1) #
|
| 159 |
-
|
| 160 |
-
# 3.
|
| 161 |
relative_matrix = np.hstack([R_rel, t_rel.reshape(3, 1)])
|
| 162 |
|
| 163 |
return relative_matrix
|
| 164 |
|
| 165 |
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
|
| 166 |
-
"""
|
| 167 |
print(f"Loading encoded video from {pth_path}")
|
| 168 |
|
| 169 |
encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu")
|
|
@@ -180,10 +68,11 @@ def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
|
|
| 180 |
|
| 181 |
return condition_latents, encoded_data
|
| 182 |
|
|
|
|
| 183 |
def compute_relative_pose(pose_a, pose_b, use_torch=False):
|
| 184 |
-
"""
|
| 185 |
-
assert pose_a.shape == (4, 4), f"
|
| 186 |
-
assert pose_b.shape == (4, 4), f"
|
| 187 |
|
| 188 |
if use_torch:
|
| 189 |
if not isinstance(pose_a, torch.Tensor):
|
|
@@ -206,7 +95,7 @@ def compute_relative_pose(pose_a, pose_b, use_torch=False):
|
|
| 206 |
|
| 207 |
|
| 208 |
def replace_dit_model_in_manager():
|
| 209 |
-
"""
|
| 210 |
from diffsynth.models.wan_video_dit_moe import WanModelMoe
|
| 211 |
from diffsynth.configs.model_config import model_loader_configs
|
| 212 |
|
|
@@ -221,7 +110,7 @@ def replace_dit_model_in_manager():
|
|
| 221 |
if name == 'wan_video_dit':
|
| 222 |
new_model_names.append(name)
|
| 223 |
new_model_classes.append(WanModelMoe)
|
| 224 |
-
print(f"
|
| 225 |
else:
|
| 226 |
new_model_names.append(name)
|
| 227 |
new_model_classes.append(cls)
|
|
@@ -230,7 +119,7 @@ def replace_dit_model_in_manager():
|
|
| 230 |
|
| 231 |
|
| 232 |
def add_framepack_components(dit_model):
|
| 233 |
-
"""
|
| 234 |
if not hasattr(dit_model, 'clean_x_embedder'):
|
| 235 |
inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]
|
| 236 |
|
|
@@ -257,37 +146,37 @@ def add_framepack_components(dit_model):
|
|
| 257 |
dit_model.clean_x_embedder = CleanXEmbedder(inner_dim)
|
| 258 |
model_dtype = next(dit_model.parameters()).dtype
|
| 259 |
dit_model.clean_x_embedder = dit_model.clean_x_embedder.to(dtype=model_dtype)
|
| 260 |
-
print("
|
| 261 |
|
| 262 |
|
| 263 |
def add_moe_components(dit_model, moe_config):
|
| 264 |
-
"""
|
| 265 |
if not hasattr(dit_model, 'moe_config'):
|
| 266 |
dit_model.moe_config = moe_config
|
| 267 |
-
print("
|
| 268 |
dit_model.top_k = moe_config.get("top_k", 1)
|
| 269 |
|
| 270 |
-
#
|
| 271 |
dim = dit_model.blocks[0].self_attn.q.weight.shape[0]
|
| 272 |
unified_dim = moe_config.get("unified_dim", 25)
|
| 273 |
num_experts = moe_config.get("num_experts", 4)
|
| 274 |
from diffsynth.models.wan_video_dit_moe import ModalityProcessor, MultiModalMoE
|
| 275 |
dit_model.sekai_processor = ModalityProcessor("sekai", 13, unified_dim)
|
| 276 |
dit_model.nuscenes_processor = ModalityProcessor("nuscenes", 8, unified_dim)
|
| 277 |
-
dit_model.openx_processor = ModalityProcessor("openx", 13, unified_dim) # OpenX
|
| 278 |
dit_model.global_router = nn.Linear(unified_dim, num_experts)
|
| 279 |
|
| 280 |
|
| 281 |
for i, block in enumerate(dit_model.blocks):
|
| 282 |
-
# MoE
|
| 283 |
block.moe = MultiModalMoE(
|
| 284 |
unified_dim=unified_dim,
|
| 285 |
-
output_dim=dim, #
|
| 286 |
num_experts=moe_config.get("num_experts", 4),
|
| 287 |
top_k=moe_config.get("top_k", 2)
|
| 288 |
)
|
| 289 |
|
| 290 |
-
print(f"Block {i}
|
| 291 |
|
| 292 |
|
| 293 |
def generate_sekai_camera_embeddings_sliding(
|
|
@@ -299,45 +188,45 @@ def generate_sekai_camera_embeddings_sliding(
|
|
| 299 |
use_real_poses=True,
|
| 300 |
direction="left"):
|
| 301 |
"""
|
| 302 |
-
|
| 303 |
|
| 304 |
Args:
|
| 305 |
-
cam_data:
|
| 306 |
-
start_frame:
|
| 307 |
-
initial_condition_frames:
|
| 308 |
-
new_frames:
|
| 309 |
-
total_generated:
|
| 310 |
-
use_real_poses:
|
| 311 |
-
direction:
|
| 312 |
|
| 313 |
Returns:
|
| 314 |
-
camera_embedding:
|
| 315 |
"""
|
| 316 |
time_compression_ratio = 4
|
| 317 |
|
| 318 |
-
#
|
| 319 |
-
# 1
|
| 320 |
framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames
|
| 321 |
|
| 322 |
if use_real_poses and cam_data is not None and 'extrinsic' in cam_data:
|
| 323 |
-
print("🔧
|
| 324 |
cam_extrinsic = cam_data['extrinsic']
|
| 325 |
|
| 326 |
-
#
|
| 327 |
max_needed_frames = max(
|
| 328 |
start_frame + initial_condition_frames + new_frames,
|
| 329 |
framepack_needed_frames,
|
| 330 |
30
|
| 331 |
)
|
| 332 |
|
| 333 |
-
print(f"🔧
|
| 334 |
-
print(f" -
|
| 335 |
-
print(f" - FramePack
|
| 336 |
-
print(f" -
|
| 337 |
|
| 338 |
relative_poses = []
|
| 339 |
for i in range(max_needed_frames):
|
| 340 |
-
#
|
| 341 |
frame_idx = i * time_compression_ratio
|
| 342 |
next_frame_idx = frame_idx + time_compression_ratio
|
| 343 |
|
|
@@ -347,72 +236,52 @@ def generate_sekai_camera_embeddings_sliding(
|
|
| 347 |
relative_pose = compute_relative_pose(cam_prev, cam_next)
|
| 348 |
relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
|
| 349 |
else:
|
| 350 |
-
#
|
| 351 |
-
print(f"⚠️
|
| 352 |
relative_poses.append(torch.zeros(3, 4))
|
| 353 |
|
| 354 |
pose_embedding = torch.stack(relative_poses, dim=0)
|
| 355 |
pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
|
| 356 |
|
| 357 |
-
#
|
| 358 |
mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
|
| 359 |
-
#
|
| 360 |
condition_end = min(start_frame + initial_condition_frames, max_needed_frames)
|
| 361 |
mask[start_frame:condition_end] = 1.0
|
| 362 |
|
| 363 |
camera_embedding = torch.cat([pose_embedding, mask], dim=1)
|
| 364 |
-
print(f"🔧 Sekai
|
| 365 |
return camera_embedding.to(torch.bfloat16)
|
| 366 |
|
| 367 |
else:
|
| 368 |
-
#
|
| 369 |
max_needed_frames = max(
|
| 370 |
start_frame + initial_condition_frames + new_frames,
|
| 371 |
framepack_needed_frames,
|
| 372 |
30)
|
| 373 |
|
| 374 |
-
print(f"🔧
|
| 375 |
|
| 376 |
CONDITION_FRAMES = initial_condition_frames
|
| 377 |
STAGE_1 = new_frames//2
|
| 378 |
STAGE_2 = new_frames - STAGE_1
|
| 379 |
|
| 380 |
-
if direction=="
|
| 381 |
-
print("--------------- FORWARD MODE ---------------")
|
| 382 |
-
relative_poses = []
|
| 383 |
-
for i in range(max_needed_frames):
|
| 384 |
-
if i < CONDITION_FRAMES:
|
| 385 |
-
# Input condition frames default to zero motion camera pose
|
| 386 |
-
pose = np.eye(4, dtype=np.float32)
|
| 387 |
-
elif i < CONDITION_FRAMES+STAGE_1+STAGE_2:
|
| 388 |
-
# Forward
|
| 389 |
-
forward_speed = 0.03
|
| 390 |
-
|
| 391 |
-
pose = np.eye(4, dtype=np.float32)
|
| 392 |
-
pose[2, 3] = -forward_speed
|
| 393 |
-
else:
|
| 394 |
-
# The part beyond condition frames and target frames remains stationary
|
| 395 |
-
pose = np.eye(4, dtype=np.float32)
|
| 396 |
-
|
| 397 |
-
relative_pose = pose[:3, :]
|
| 398 |
-
relative_poses.append(torch.as_tensor(relative_pose))
|
| 399 |
-
|
| 400 |
-
elif direction=="left":
|
| 401 |
print("--------------- LEFT TURNING MODE ---------------")
|
| 402 |
relative_poses = []
|
| 403 |
for i in range(max_needed_frames):
|
| 404 |
if i < CONDITION_FRAMES:
|
| 405 |
-
#
|
| 406 |
pose = np.eye(4, dtype=np.float32)
|
| 407 |
elif i < CONDITION_FRAMES+STAGE_1+STAGE_2:
|
| 408 |
-
#
|
| 409 |
yaw_per_frame = 0.03
|
| 410 |
|
| 411 |
-
#
|
| 412 |
cos_yaw = np.cos(yaw_per_frame)
|
| 413 |
sin_yaw = np.sin(yaw_per_frame)
|
| 414 |
|
| 415 |
-
#
|
| 416 |
forward_speed = 0.00
|
| 417 |
|
| 418 |
pose = np.eye(4, dtype=np.float32)
|
|
@@ -423,7 +292,7 @@ def generate_sekai_camera_embeddings_sliding(
|
|
| 423 |
pose[2, 2] = cos_yaw
|
| 424 |
pose[2, 3] = -forward_speed
|
| 425 |
else:
|
| 426 |
-
#
|
| 427 |
pose = np.eye(4, dtype=np.float32)
|
| 428 |
|
| 429 |
relative_pose = pose[:3, :]
|
|
@@ -434,17 +303,17 @@ def generate_sekai_camera_embeddings_sliding(
|
|
| 434 |
relative_poses = []
|
| 435 |
for i in range(max_needed_frames):
|
| 436 |
if i < CONDITION_FRAMES:
|
| 437 |
-
#
|
| 438 |
pose = np.eye(4, dtype=np.float32)
|
| 439 |
elif i < CONDITION_FRAMES+STAGE_1+STAGE_2:
|
| 440 |
-
#
|
| 441 |
yaw_per_frame = -0.03
|
| 442 |
|
| 443 |
-
#
|
| 444 |
cos_yaw = np.cos(yaw_per_frame)
|
| 445 |
sin_yaw = np.sin(yaw_per_frame)
|
| 446 |
|
| 447 |
-
#
|
| 448 |
forward_speed = 0.00
|
| 449 |
|
| 450 |
pose = np.eye(4, dtype=np.float32)
|
|
@@ -455,7 +324,7 @@ def generate_sekai_camera_embeddings_sliding(
|
|
| 455 |
pose[2, 2] = cos_yaw
|
| 456 |
pose[2, 3] = -forward_speed
|
| 457 |
else:
|
| 458 |
-
#
|
| 459 |
pose = np.eye(4, dtype=np.float32)
|
| 460 |
|
| 461 |
relative_pose = pose[:3, :]
|
|
@@ -466,17 +335,17 @@ def generate_sekai_camera_embeddings_sliding(
|
|
| 466 |
relative_poses = []
|
| 467 |
for i in range(max_needed_frames):
|
| 468 |
if i < CONDITION_FRAMES:
|
| 469 |
-
#
|
| 470 |
pose = np.eye(4, dtype=np.float32)
|
| 471 |
elif i < CONDITION_FRAMES+STAGE_1+STAGE_2:
|
| 472 |
-
#
|
| 473 |
yaw_per_frame = 0.03
|
| 474 |
|
| 475 |
-
#
|
| 476 |
cos_yaw = np.cos(yaw_per_frame)
|
| 477 |
sin_yaw = np.sin(yaw_per_frame)
|
| 478 |
|
| 479 |
-
#
|
| 480 |
forward_speed = 0.03
|
| 481 |
|
| 482 |
pose = np.eye(4, dtype=np.float32)
|
|
@@ -488,7 +357,7 @@ def generate_sekai_camera_embeddings_sliding(
|
|
| 488 |
pose[2, 3] = -forward_speed
|
| 489 |
|
| 490 |
else:
|
| 491 |
-
#
|
| 492 |
pose = np.eye(4, dtype=np.float32)
|
| 493 |
|
| 494 |
relative_pose = pose[:3, :]
|
|
@@ -499,17 +368,17 @@ def generate_sekai_camera_embeddings_sliding(
|
|
| 499 |
relative_poses = []
|
| 500 |
for i in range(max_needed_frames):
|
| 501 |
if i < CONDITION_FRAMES:
|
| 502 |
-
#
|
| 503 |
pose = np.eye(4, dtype=np.float32)
|
| 504 |
elif i < CONDITION_FRAMES+STAGE_1+STAGE_2:
|
| 505 |
-
#
|
| 506 |
yaw_per_frame = -0.03
|
| 507 |
|
| 508 |
-
#
|
| 509 |
cos_yaw = np.cos(yaw_per_frame)
|
| 510 |
sin_yaw = np.sin(yaw_per_frame)
|
| 511 |
|
| 512 |
-
#
|
| 513 |
forward_speed = 0.03
|
| 514 |
|
| 515 |
pose = np.eye(4, dtype=np.float32)
|
|
@@ -521,7 +390,7 @@ def generate_sekai_camera_embeddings_sliding(
|
|
| 521 |
pose[2, 3] = -forward_speed
|
| 522 |
|
| 523 |
else:
|
| 524 |
-
#
|
| 525 |
pose = np.eye(4, dtype=np.float32)
|
| 526 |
|
| 527 |
relative_pose = pose[:3, :]
|
|
@@ -532,17 +401,17 @@ def generate_sekai_camera_embeddings_sliding(
|
|
| 532 |
relative_poses = []
|
| 533 |
for i in range(max_needed_frames):
|
| 534 |
if i < CONDITION_FRAMES:
|
| 535 |
-
#
|
| 536 |
pose = np.eye(4, dtype=np.float32)
|
| 537 |
elif i < CONDITION_FRAMES+STAGE_1:
|
| 538 |
-
#
|
| 539 |
yaw_per_frame = 0.03
|
| 540 |
|
| 541 |
-
#
|
| 542 |
cos_yaw = np.cos(yaw_per_frame)
|
| 543 |
sin_yaw = np.sin(yaw_per_frame)
|
| 544 |
|
| 545 |
-
#
|
| 546 |
forward_speed = 0.03
|
| 547 |
|
| 548 |
pose = np.eye(4, dtype=np.float32)
|
|
@@ -554,16 +423,16 @@ def generate_sekai_camera_embeddings_sliding(
|
|
| 554 |
pose[2, 3] = -forward_speed
|
| 555 |
|
| 556 |
elif i < CONDITION_FRAMES+STAGE_1+STAGE_2:
|
| 557 |
-
#
|
| 558 |
yaw_per_frame = -0.03
|
| 559 |
|
| 560 |
-
#
|
| 561 |
cos_yaw = np.cos(yaw_per_frame)
|
| 562 |
sin_yaw = np.sin(yaw_per_frame)
|
| 563 |
|
| 564 |
-
#
|
| 565 |
forward_speed = 0.03
|
| 566 |
-
#
|
| 567 |
if i < CONDITION_FRAMES+STAGE_1+STAGE_2//3:
|
| 568 |
radius_shift = -0.01
|
| 569 |
else:
|
|
@@ -579,7 +448,7 @@ def generate_sekai_camera_embeddings_sliding(
|
|
| 579 |
pose[0, 3] = radius_shift
|
| 580 |
|
| 581 |
else:
|
| 582 |
-
#
|
| 583 |
pose = np.eye(4, dtype=np.float32)
|
| 584 |
|
| 585 |
relative_pose = pose[:3, :]
|
|
@@ -590,17 +459,17 @@ def generate_sekai_camera_embeddings_sliding(
|
|
| 590 |
relative_poses = []
|
| 591 |
for i in range(max_needed_frames):
|
| 592 |
if i < CONDITION_FRAMES:
|
| 593 |
-
#
|
| 594 |
pose = np.eye(4, dtype=np.float32)
|
| 595 |
elif i < CONDITION_FRAMES+STAGE_1:
|
| 596 |
-
#
|
| 597 |
yaw_per_frame = 0.03
|
| 598 |
|
| 599 |
-
#
|
| 600 |
cos_yaw = np.cos(yaw_per_frame)
|
| 601 |
sin_yaw = np.sin(yaw_per_frame)
|
| 602 |
|
| 603 |
-
#
|
| 604 |
forward_speed = 0.00
|
| 605 |
|
| 606 |
pose = np.eye(4, dtype=np.float32)
|
|
@@ -612,14 +481,14 @@ def generate_sekai_camera_embeddings_sliding(
|
|
| 612 |
pose[2, 3] = -forward_speed
|
| 613 |
|
| 614 |
elif i < CONDITION_FRAMES+STAGE_1+STAGE_2:
|
| 615 |
-
#
|
| 616 |
yaw_per_frame = -0.03
|
| 617 |
|
| 618 |
-
#
|
| 619 |
cos_yaw = np.cos(yaw_per_frame)
|
| 620 |
sin_yaw = np.sin(yaw_per_frame)
|
| 621 |
|
| 622 |
-
#
|
| 623 |
forward_speed = 0.00
|
| 624 |
|
| 625 |
pose = np.eye(4, dtype=np.float32)
|
|
@@ -631,55 +500,55 @@ def generate_sekai_camera_embeddings_sliding(
|
|
| 631 |
pose[2, 3] = -forward_speed
|
| 632 |
|
| 633 |
else:
|
| 634 |
-
#
|
| 635 |
pose = np.eye(4, dtype=np.float32)
|
| 636 |
|
| 637 |
relative_pose = pose[:3, :]
|
| 638 |
relative_poses.append(torch.as_tensor(relative_pose))
|
| 639 |
|
| 640 |
else:
|
| 641 |
-
raise ValueError(f"
|
| 642 |
|
| 643 |
pose_embedding = torch.stack(relative_poses, dim=0)
|
| 644 |
pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
|
| 645 |
|
| 646 |
-
#
|
| 647 |
mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
|
| 648 |
condition_end = min(start_frame + initial_condition_frames + 1, max_needed_frames)
|
| 649 |
mask[start_frame:condition_end] = 1.0
|
| 650 |
|
| 651 |
camera_embedding = torch.cat([pose_embedding, mask], dim=1)
|
| 652 |
-
print(f"🔧 Sekai
|
| 653 |
return camera_embedding.to(torch.bfloat16)
|
| 654 |
|
| 655 |
|
| 656 |
def generate_openx_camera_embeddings_sliding(
|
| 657 |
encoded_data, start_frame, initial_condition_frames, new_frames, use_real_poses):
|
| 658 |
-
"""
|
| 659 |
time_compression_ratio = 4
|
| 660 |
|
| 661 |
-
#
|
| 662 |
framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames
|
| 663 |
|
| 664 |
if use_real_poses and encoded_data is not None and 'cam_emb' in encoded_data and 'extrinsic' in encoded_data['cam_emb']:
|
| 665 |
-
print("🔧
|
| 666 |
cam_extrinsic = encoded_data['cam_emb']['extrinsic']
|
| 667 |
|
| 668 |
-
#
|
| 669 |
max_needed_frames = max(
|
| 670 |
start_frame + initial_condition_frames + new_frames,
|
| 671 |
framepack_needed_frames,
|
| 672 |
30
|
| 673 |
)
|
| 674 |
|
| 675 |
-
print(f"🔧
|
| 676 |
-
print(f" -
|
| 677 |
-
print(f" - FramePack
|
| 678 |
-
print(f" -
|
| 679 |
|
| 680 |
relative_poses = []
|
| 681 |
for i in range(max_needed_frames):
|
| 682 |
-
# OpenX
|
| 683 |
frame_idx = i * time_compression_ratio
|
| 684 |
next_frame_idx = frame_idx + time_compression_ratio
|
| 685 |
|
|
@@ -689,25 +558,25 @@ def generate_openx_camera_embeddings_sliding(
|
|
| 689 |
relative_pose = compute_relative_pose(cam_prev, cam_next)
|
| 690 |
relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
|
| 691 |
else:
|
| 692 |
-
#
|
| 693 |
-
print(f"⚠️
|
| 694 |
relative_poses.append(torch.zeros(3, 4))
|
| 695 |
|
| 696 |
pose_embedding = torch.stack(relative_poses, dim=0)
|
| 697 |
pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
|
| 698 |
|
| 699 |
-
#
|
| 700 |
mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
|
| 701 |
-
#
|
| 702 |
condition_end = min(start_frame + initial_condition_frames, max_needed_frames)
|
| 703 |
mask[start_frame:condition_end] = 1.0
|
| 704 |
|
| 705 |
camera_embedding = torch.cat([pose_embedding, mask], dim=1)
|
| 706 |
-
print(f"🔧 OpenX
|
| 707 |
return camera_embedding.to(torch.bfloat16)
|
| 708 |
|
| 709 |
else:
|
| 710 |
-
print("🔧
|
| 711 |
|
| 712 |
max_needed_frames = max(
|
| 713 |
start_frame + initial_condition_frames + new_frames,
|
|
@@ -715,30 +584,30 @@ def generate_openx_camera_embeddings_sliding(
|
|
| 715 |
30
|
| 716 |
)
|
| 717 |
|
| 718 |
-
print(f"🔧
|
| 719 |
relative_poses = []
|
| 720 |
for i in range(max_needed_frames):
|
| 721 |
-
# OpenX
|
| 722 |
-
#
|
| 723 |
-
roll_per_frame = 0.02 #
|
| 724 |
-
pitch_per_frame = 0.01 #
|
| 725 |
-
yaw_per_frame = 0.015 #
|
| 726 |
-
forward_speed = 0.003 #
|
| 727 |
|
| 728 |
pose = np.eye(4, dtype=np.float32)
|
| 729 |
|
| 730 |
-
#
|
| 731 |
-
#
|
| 732 |
cos_roll = np.cos(roll_per_frame)
|
| 733 |
sin_roll = np.sin(roll_per_frame)
|
| 734 |
-
#
|
| 735 |
cos_pitch = np.cos(pitch_per_frame)
|
| 736 |
sin_pitch = np.sin(pitch_per_frame)
|
| 737 |
-
#
|
| 738 |
cos_yaw = np.cos(yaw_per_frame)
|
| 739 |
sin_yaw = np.sin(yaw_per_frame)
|
| 740 |
|
| 741 |
-
#
|
| 742 |
pose[0, 0] = cos_yaw * cos_pitch
|
| 743 |
pose[0, 1] = cos_yaw * sin_pitch * sin_roll - sin_yaw * cos_roll
|
| 744 |
pose[0, 2] = cos_yaw * sin_pitch * cos_roll + sin_yaw * sin_roll
|
|
@@ -749,10 +618,10 @@ def generate_openx_camera_embeddings_sliding(
|
|
| 749 |
pose[2, 1] = cos_pitch * sin_roll
|
| 750 |
pose[2, 2] = cos_pitch * cos_roll
|
| 751 |
|
| 752 |
-
#
|
| 753 |
-
pose[0, 3] = forward_speed * 0.5 #
|
| 754 |
-
pose[1, 3] = forward_speed * 0.3 #
|
| 755 |
-
pose[2, 3] = -forward_speed #
|
| 756 |
|
| 757 |
relative_pose = pose[:3, :]
|
| 758 |
relative_poses.append(torch.as_tensor(relative_pose))
|
|
@@ -760,34 +629,30 @@ def generate_openx_camera_embeddings_sliding(
|
|
| 760 |
pose_embedding = torch.stack(relative_poses, dim=0)
|
| 761 |
pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
|
| 762 |
|
| 763 |
-
#
|
| 764 |
mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
|
| 765 |
condition_end = min(start_frame + initial_condition_frames, max_needed_frames)
|
| 766 |
mask[start_frame:condition_end] = 1.0
|
| 767 |
|
| 768 |
camera_embedding = torch.cat([pose_embedding, mask], dim=1)
|
| 769 |
-
print(f"🔧 OpenX
|
| 770 |
return camera_embedding.to(torch.bfloat16)
|
| 771 |
|
| 772 |
|
| 773 |
def generate_nuscenes_camera_embeddings_sliding(
|
| 774 |
scene_info, start_frame, initial_condition_frames, new_frames):
|
| 775 |
-
"""
|
| 776 |
-
Generate camera embeddings for NuScenes dataset - sliding window version
|
| 777 |
-
|
| 778 |
-
corrected version, consistent with train_moe.py
|
| 779 |
-
"""
|
| 780 |
time_compression_ratio = 4
|
| 781 |
|
| 782 |
-
#
|
| 783 |
framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames
|
| 784 |
|
| 785 |
if scene_info is not None and 'keyframe_poses' in scene_info:
|
| 786 |
-
print("🔧
|
| 787 |
keyframe_poses = scene_info['keyframe_poses']
|
| 788 |
|
| 789 |
if len(keyframe_poses) == 0:
|
| 790 |
-
print("⚠️ NuScenes keyframe_poses
|
| 791 |
max_needed_frames = max(framepack_needed_frames, 30)
|
| 792 |
|
| 793 |
pose_sequence = torch.zeros(max_needed_frames, 7, dtype=torch.float32)
|
|
@@ -797,10 +662,10 @@ def generate_nuscenes_camera_embeddings_sliding(
|
|
| 797 |
mask[start_frame:condition_end] = 1.0
|
| 798 |
|
| 799 |
camera_embedding = torch.cat([pose_sequence, mask], dim=1) # [max_needed_frames, 8]
|
| 800 |
-
print(f"🔧 NuScenes
|
| 801 |
return camera_embedding.to(torch.bfloat16)
|
| 802 |
|
| 803 |
-
#
|
| 804 |
reference_pose = keyframe_poses[0]
|
| 805 |
|
| 806 |
max_needed_frames = max(framepack_needed_frames, 30)
|
|
@@ -810,18 +675,18 @@ def generate_nuscenes_camera_embeddings_sliding(
|
|
| 810 |
if i < len(keyframe_poses):
|
| 811 |
current_pose = keyframe_poses[i]
|
| 812 |
|
| 813 |
-
#
|
| 814 |
translation = torch.tensor(
|
| 815 |
np.array(current_pose['translation']) - np.array(reference_pose['translation']),
|
| 816 |
dtype=torch.float32
|
| 817 |
)
|
| 818 |
|
| 819 |
-
#
|
| 820 |
rotation = torch.tensor(current_pose['rotation'], dtype=torch.float32)
|
| 821 |
|
| 822 |
pose_vec = torch.cat([translation, rotation], dim=0) # [7D]
|
| 823 |
else:
|
| 824 |
-
#
|
| 825 |
pose_vec = torch.cat([
|
| 826 |
torch.zeros(3, dtype=torch.float32),
|
| 827 |
torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32)
|
|
@@ -831,41 +696,41 @@ def generate_nuscenes_camera_embeddings_sliding(
|
|
| 831 |
|
| 832 |
pose_sequence = torch.stack(pose_vecs, dim=0) # [max_needed_frames, 7]
|
| 833 |
|
| 834 |
-
#
|
| 835 |
mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
|
| 836 |
condition_end = min(start_frame + initial_condition_frames, max_needed_frames)
|
| 837 |
mask[start_frame:condition_end] = 1.0
|
| 838 |
|
| 839 |
camera_embedding = torch.cat([pose_sequence, mask], dim=1) # [max_needed_frames, 8]
|
| 840 |
-
print(f"🔧 NuScenes
|
| 841 |
return camera_embedding.to(torch.bfloat16)
|
| 842 |
|
| 843 |
else:
|
| 844 |
-
print("🔧
|
| 845 |
max_needed_frames = max(framepack_needed_frames, 30)
|
| 846 |
|
| 847 |
-
#
|
| 848 |
pose_vecs = []
|
| 849 |
for i in range(max_needed_frames):
|
| 850 |
-
#
|
| 851 |
-
angle = i * 0.04 #
|
| 852 |
-
radius = 15.0 #
|
| 853 |
|
| 854 |
-
#
|
| 855 |
x = radius * np.sin(angle)
|
| 856 |
-
y = 0.0 #
|
| 857 |
z = radius * (1 - np.cos(angle))
|
| 858 |
|
| 859 |
translation = torch.tensor([x, y, z], dtype=torch.float32)
|
| 860 |
|
| 861 |
-
#
|
| 862 |
-
yaw = angle + np.pi/2 #
|
| 863 |
-
#
|
| 864 |
rotation = torch.tensor([
|
| 865 |
-
np.cos(yaw/2), # w (
|
| 866 |
0.0, # x
|
| 867 |
0.0, # y
|
| 868 |
-
np.sin(yaw/2) # z (
|
| 869 |
], dtype=torch.float32)
|
| 870 |
|
| 871 |
pose_vec = torch.cat([translation, rotation], dim=0) # [7D: tx,ty,tz,qw,qx,qy,qz]
|
|
@@ -873,15 +738,15 @@ def generate_nuscenes_camera_embeddings_sliding(
|
|
| 873 |
|
| 874 |
pose_sequence = torch.stack(pose_vecs, dim=0)
|
| 875 |
|
| 876 |
-
#
|
| 877 |
mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
|
| 878 |
condition_end = min(start_frame + initial_condition_frames, max_needed_frames)
|
| 879 |
mask[start_frame:condition_end] = 1.0
|
| 880 |
|
| 881 |
camera_embedding = torch.cat([pose_sequence, mask], dim=1) # [max_needed_frames, 8]
|
| 882 |
-
print(f"🔧 NuScenes
|
| 883 |
return camera_embedding.to(torch.bfloat16)
|
| 884 |
-
|
| 885 |
def prepare_framepack_sliding_window_with_camera_moe(
|
| 886 |
history_latents,
|
| 887 |
target_frames_to_generate,
|
|
@@ -889,12 +754,12 @@ def prepare_framepack_sliding_window_with_camera_moe(
|
|
| 889 |
start_frame,
|
| 890 |
modality_type,
|
| 891 |
max_history_frames=49):
|
| 892 |
-
"""FramePack
|
| 893 |
-
# history_latents: [C, T, H, W]
|
| 894 |
C, T, H, W = history_latents.shape
|
| 895 |
|
| 896 |
-
#
|
| 897 |
-
# 1
|
| 898 |
total_indices_length = 1 + 16 + 2 + 1 + target_frames_to_generate
|
| 899 |
indices = torch.arange(0, total_indices_length)
|
| 900 |
split_sizes = [1, 16, 2, 1, target_frames_to_generate]
|
|
@@ -902,44 +767,44 @@ def prepare_framepack_sliding_window_with_camera_moe(
|
|
| 902 |
indices.split(split_sizes, dim=0)
|
| 903 |
clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=0)
|
| 904 |
|
| 905 |
-
#
|
| 906 |
if camera_embedding_full.shape[0] < total_indices_length:
|
| 907 |
-
print(f"⚠️ camera_embedding
|
| 908 |
shortage = total_indices_length - camera_embedding_full.shape[0]
|
| 909 |
padding = torch.zeros(shortage, camera_embedding_full.shape[1],
|
| 910 |
dtype=camera_embedding_full.dtype, device=camera_embedding_full.device)
|
| 911 |
camera_embedding_full = torch.cat([camera_embedding_full, padding], dim=0)
|
| 912 |
|
| 913 |
-
#
|
| 914 |
combined_camera = torch.zeros(
|
| 915 |
total_indices_length,
|
| 916 |
camera_embedding_full.shape[1],
|
| 917 |
dtype=camera_embedding_full.dtype,
|
| 918 |
device=camera_embedding_full.device)
|
| 919 |
|
| 920 |
-
#
|
| 921 |
history_slice = camera_embedding_full[max(T - 19, 0):T, :].clone()
|
| 922 |
combined_camera[19 - history_slice.shape[0]:19, :] = history_slice
|
| 923 |
|
| 924 |
-
#
|
| 925 |
target_slice = camera_embedding_full[T:T + target_frames_to_generate, :].clone()
|
| 926 |
combined_camera[19:19 + target_slice.shape[0], :] = target_slice
|
| 927 |
|
| 928 |
-
#
|
| 929 |
-
combined_camera[:, -1] = 0.0 #
|
| 930 |
|
| 931 |
-
#
|
| 932 |
if T > 0:
|
| 933 |
available_frames = min(T, 19)
|
| 934 |
start_pos = 19 - available_frames
|
| 935 |
-
combined_camera[start_pos:19, -1] = 1.0 #
|
| 936 |
|
| 937 |
-
print(f"🔧 MoE Camera mask
|
| 938 |
-
print(f" -
|
| 939 |
-
print(f" -
|
| 940 |
-
print(f" -
|
| 941 |
|
| 942 |
-
#
|
| 943 |
clean_latents_combined = torch.zeros(C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device)
|
| 944 |
|
| 945 |
if T > 0:
|
|
@@ -967,54 +832,54 @@ def prepare_framepack_sliding_window_with_camera_moe(
|
|
| 967 |
'clean_latent_2x_indices': clean_latent_2x_indices,
|
| 968 |
'clean_latent_4x_indices': clean_latent_4x_indices,
|
| 969 |
'camera_embedding': combined_camera,
|
| 970 |
-
'modality_type': modality_type, #
|
| 971 |
'current_length': T,
|
| 972 |
'next_length': T + target_frames_to_generate
|
| 973 |
}
|
| 974 |
|
| 975 |
def overlay_controls(frame_img, pose_vec, icons):
|
| 976 |
"""
|
| 977 |
-
|
| 978 |
-
pose_vec: 12
|
| 979 |
"""
|
| 980 |
if pose_vec is None or np.all(pose_vec[:12] == 0):
|
| 981 |
return frame_img
|
| 982 |
|
| 983 |
-
#
|
| 984 |
# [r00, r01, r02, tx, r10, r11, r12, ty, r20, r21, r22, tz]
|
| 985 |
tx = pose_vec[3]
|
| 986 |
# ty = pose_vec[7]
|
| 987 |
tz = pose_vec[11]
|
| 988 |
|
| 989 |
-
#
|
| 990 |
-
#
|
| 991 |
r00 = pose_vec[0]
|
| 992 |
r02 = pose_vec[2]
|
| 993 |
yaw = np.arctan2(r02, r00)
|
| 994 |
|
| 995 |
-
#
|
| 996 |
r12 = pose_vec[6]
|
| 997 |
r22 = pose_vec[10]
|
| 998 |
pitch = np.arctan2(-r12, r22)
|
| 999 |
|
| 1000 |
-
#
|
| 1001 |
TRANS_THRESH = 0.01
|
| 1002 |
ROT_THRESH = 0.005
|
| 1003 |
|
| 1004 |
-
#
|
| 1005 |
-
#
|
| 1006 |
-
#
|
| 1007 |
is_forward = tz < -TRANS_THRESH
|
| 1008 |
is_backward = tz > TRANS_THRESH
|
| 1009 |
is_left = tx < -TRANS_THRESH
|
| 1010 |
is_right = tx > TRANS_THRESH
|
| 1011 |
|
| 1012 |
-
#
|
| 1013 |
-
#
|
| 1014 |
is_turn_left = yaw > ROT_THRESH
|
| 1015 |
is_turn_right = yaw < -ROT_THRESH
|
| 1016 |
|
| 1017 |
-
#
|
| 1018 |
is_turn_up = pitch < -ROT_THRESH
|
| 1019 |
is_turn_down = pitch > ROT_THRESH
|
| 1020 |
|
|
@@ -1025,10 +890,10 @@ def overlay_controls(frame_img, pose_vec, icons):
|
|
| 1025 |
name = name_active if is_active else name_inactive
|
| 1026 |
if name in icons:
|
| 1027 |
icon = icons[name]
|
| 1028 |
-
#
|
| 1029 |
frame_img.paste(icon, (int(x), int(y)), icon)
|
| 1030 |
|
| 1031 |
-
#
|
| 1032 |
base_x_right = 100
|
| 1033 |
base_y = H - 100
|
| 1034 |
|
|
@@ -1041,7 +906,7 @@ def overlay_controls(frame_img, pose_vec, icons):
|
|
| 1041 |
# D
|
| 1042 |
paste_icon('move_right.png', 'not_move_right.png', is_right, base_x_right + spacing, base_y)
|
| 1043 |
|
| 1044 |
-
#
|
| 1045 |
base_x_left = W - 150
|
| 1046 |
|
| 1047 |
# ↑
|
|
@@ -1057,11 +922,8 @@ def overlay_controls(frame_img, pose_vec, icons):
|
|
| 1057 |
|
| 1058 |
|
| 1059 |
def inference_moe_framepack_sliding_window(
|
| 1060 |
-
condition_pth_path
|
| 1061 |
-
|
| 1062 |
-
condition_image=None,
|
| 1063 |
-
dit_path=None,
|
| 1064 |
-
wan_model_path=None,
|
| 1065 |
output_path="../examples/output_videos/output_moe_framepack_sliding.mp4",
|
| 1066 |
start_frame=0,
|
| 1067 |
initial_condition_frames=8,
|
|
@@ -1070,14 +932,14 @@ def inference_moe_framepack_sliding_window(
|
|
| 1070 |
max_history_frames=49,
|
| 1071 |
device="cuda",
|
| 1072 |
prompt="A video of a scene shot using a pedestrian's front camera while walking",
|
| 1073 |
-
modality_type="sekai", # "sekai"
|
| 1074 |
use_real_poses=True,
|
| 1075 |
-
scene_info_path=None, #
|
| 1076 |
-
# CFG
|
| 1077 |
use_camera_cfg=True,
|
| 1078 |
camera_guidance_scale=2.0,
|
| 1079 |
text_guidance_scale=1.0,
|
| 1080 |
-
# MoE
|
| 1081 |
moe_num_experts=4,
|
| 1082 |
moe_top_k=2,
|
| 1083 |
moe_hidden_dim=None,
|
|
@@ -1086,30 +948,30 @@ def inference_moe_framepack_sliding_window(
|
|
| 1086 |
add_icons=False
|
| 1087 |
):
|
| 1088 |
"""
|
| 1089 |
-
MoE FramePack
|
| 1090 |
"""
|
| 1091 |
-
#
|
| 1092 |
dir_path = os.path.dirname(output_path)
|
| 1093 |
os.makedirs(dir_path, exist_ok=True)
|
| 1094 |
|
| 1095 |
-
print(f"🔧
|
| 1096 |
-
print(f"
|
| 1097 |
-
print(f"
|
| 1098 |
-
print(f"
|
| 1099 |
-
print(f"
|
| 1100 |
|
| 1101 |
-
# 1.
|
| 1102 |
replace_dit_model_in_manager()
|
| 1103 |
|
| 1104 |
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
|
| 1105 |
model_manager.load_models([
|
| 1106 |
-
|
| 1107 |
-
|
| 1108 |
-
|
| 1109 |
])
|
| 1110 |
-
pipe =
|
| 1111 |
|
| 1112 |
-
# 2.
|
| 1113 |
dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
|
| 1114 |
for block in pipe.dit.blocks:
|
| 1115 |
block.cam_encoder = nn.Linear(13, dim)
|
|
@@ -1119,45 +981,41 @@ def inference_moe_framepack_sliding_window(
|
|
| 1119 |
block.projector.weight = nn.Parameter(torch.eye(dim))
|
| 1120 |
block.projector.bias = nn.Parameter(torch.zeros(dim))
|
| 1121 |
|
| 1122 |
-
# 3.
|
| 1123 |
add_framepack_components(pipe.dit)
|
| 1124 |
|
| 1125 |
-
# 4.
|
| 1126 |
moe_config = {
|
| 1127 |
"num_experts": moe_num_experts,
|
| 1128 |
"top_k": moe_top_k,
|
| 1129 |
"hidden_dim": moe_hidden_dim or dim * 2,
|
| 1130 |
-
"sekai_input_dim": 13, # Sekai: 12
|
| 1131 |
-
"nuscenes_input_dim": 8, # NuScenes: 7
|
| 1132 |
-
"openx_input_dim": 13 # OpenX: 12
|
| 1133 |
}
|
| 1134 |
add_moe_components(pipe.dit, moe_config)
|
| 1135 |
|
| 1136 |
-
# 5.
|
| 1137 |
dit_state_dict = torch.load(dit_path, map_location="cpu")
|
| 1138 |
-
pipe.dit.load_state_dict(dit_state_dict, strict=False) #
|
| 1139 |
pipe = pipe.to(device)
|
| 1140 |
model_dtype = next(pipe.dit.parameters()).dtype
|
| 1141 |
|
| 1142 |
if hasattr(pipe.dit, 'clean_x_embedder'):
|
| 1143 |
pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)
|
| 1144 |
|
| 1145 |
-
#
|
| 1146 |
pipe.scheduler.set_timesteps(50)
|
| 1147 |
|
| 1148 |
-
# 6.
|
| 1149 |
print("Loading initial condition frames...")
|
| 1150 |
-
initial_latents, encoded_data =
|
| 1151 |
-
condition_pth_path,
|
| 1152 |
-
|
| 1153 |
-
|
| 1154 |
-
start_frame,
|
| 1155 |
-
initial_condition_frames,
|
| 1156 |
-
device,
|
| 1157 |
-
pipe,
|
| 1158 |
)
|
| 1159 |
|
| 1160 |
-
#
|
| 1161 |
target_height, target_width = 60, 104
|
| 1162 |
C, T, H, W = initial_latents.shape
|
| 1163 |
|
|
@@ -1169,50 +1027,50 @@ def inference_moe_framepack_sliding_window(
|
|
| 1169 |
|
| 1170 |
history_latents = initial_latents.to(device, dtype=model_dtype)
|
| 1171 |
|
| 1172 |
-
print(f"
|
| 1173 |
|
| 1174 |
-
# 7.
|
| 1175 |
if use_gt_prompt and 'prompt_emb' in encoded_data:
|
| 1176 |
-
print("✅
|
| 1177 |
prompt_emb_pos = encoded_data['prompt_emb']
|
| 1178 |
-
#
|
| 1179 |
if 'context' in prompt_emb_pos:
|
| 1180 |
prompt_emb_pos['context'] = prompt_emb_pos['context'].to(device, dtype=model_dtype)
|
| 1181 |
if 'context_mask' in prompt_emb_pos:
|
| 1182 |
prompt_emb_pos['context_mask'] = prompt_emb_pos['context_mask'].to(device, dtype=model_dtype)
|
| 1183 |
|
| 1184 |
-
#
|
| 1185 |
if text_guidance_scale > 1.0:
|
| 1186 |
prompt_emb_neg = pipe.encode_prompt("")
|
| 1187 |
-
print(f"
|
| 1188 |
else:
|
| 1189 |
prompt_emb_neg = None
|
| 1190 |
-
print("
|
| 1191 |
|
| 1192 |
-
#
|
| 1193 |
if 'prompt' in encoded_data['prompt_emb']:
|
| 1194 |
gt_prompt_text = encoded_data['prompt_emb']['prompt']
|
| 1195 |
-
print(f"📝 GT Prompt
|
| 1196 |
else:
|
| 1197 |
-
#
|
| 1198 |
-
print(f"🔄
|
| 1199 |
if text_guidance_scale > 1.0:
|
| 1200 |
prompt_emb_pos = pipe.encode_prompt(prompt)
|
| 1201 |
prompt_emb_neg = pipe.encode_prompt("")
|
| 1202 |
-
print(f"
|
| 1203 |
else:
|
| 1204 |
prompt_emb_pos = pipe.encode_prompt(prompt)
|
| 1205 |
prompt_emb_neg = None
|
| 1206 |
-
print("
|
| 1207 |
|
| 1208 |
-
# 8.
|
| 1209 |
scene_info = None
|
| 1210 |
if modality_type == "nuscenes" and scene_info_path and os.path.exists(scene_info_path):
|
| 1211 |
with open(scene_info_path, 'r') as f:
|
| 1212 |
scene_info = json.load(f)
|
| 1213 |
-
print(f"
|
| 1214 |
|
| 1215 |
-
# 9.
|
| 1216 |
if modality_type == "sekai":
|
| 1217 |
camera_embedding_full = generate_sekai_camera_embeddings_sliding(
|
| 1218 |
encoded_data.get('cam_emb', None),
|
|
@@ -1239,25 +1097,25 @@ def inference_moe_framepack_sliding_window(
|
|
| 1239 |
use_real_poses=use_real_poses
|
| 1240 |
).to(device, dtype=model_dtype)
|
| 1241 |
else:
|
| 1242 |
-
raise ValueError(f"
|
| 1243 |
|
| 1244 |
-
print(f"
|
| 1245 |
|
| 1246 |
-
# 10.
|
| 1247 |
if use_camera_cfg:
|
| 1248 |
camera_embedding_uncond = torch.zeros_like(camera_embedding_full)
|
| 1249 |
-
print(f"
|
| 1250 |
|
| 1251 |
-
# 11.
|
| 1252 |
total_generated = 0
|
| 1253 |
all_generated_frames = []
|
| 1254 |
|
| 1255 |
while total_generated < total_frames_to_generate:
|
| 1256 |
current_generation = min(frames_per_generation, total_frames_to_generate - total_generated)
|
| 1257 |
-
print(f"\
|
| 1258 |
-
print(f"
|
| 1259 |
|
| 1260 |
-
# FramePack
|
| 1261 |
framepack_data = prepare_framepack_sliding_window_with_camera_moe(
|
| 1262 |
history_latents,
|
| 1263 |
current_generation,
|
|
@@ -1267,27 +1125,27 @@ def inference_moe_framepack_sliding_window(
|
|
| 1267 |
max_history_frames
|
| 1268 |
)
|
| 1269 |
|
| 1270 |
-
#
|
| 1271 |
clean_latents = framepack_data['clean_latents'].unsqueeze(0)
|
| 1272 |
clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0)
|
| 1273 |
clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0)
|
| 1274 |
camera_embedding = framepack_data['camera_embedding'].unsqueeze(0)
|
| 1275 |
|
| 1276 |
-
#
|
| 1277 |
modality_inputs = {modality_type: camera_embedding}
|
| 1278 |
|
| 1279 |
-
#
|
| 1280 |
if use_camera_cfg:
|
| 1281 |
camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0)
|
| 1282 |
modality_inputs_uncond = {modality_type: camera_embedding_uncond_batch}
|
| 1283 |
|
| 1284 |
-
#
|
| 1285 |
latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu()
|
| 1286 |
clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu()
|
| 1287 |
clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu()
|
| 1288 |
clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu()
|
| 1289 |
|
| 1290 |
-
#
|
| 1291 |
new_latents = torch.randn(
|
| 1292 |
1, C, current_generation, H, W,
|
| 1293 |
device=device, dtype=model_dtype
|
|
@@ -1296,26 +1154,26 @@ def inference_moe_framepack_sliding_window(
|
|
| 1296 |
extra_input = pipe.prepare_extra_input(new_latents)
|
| 1297 |
|
| 1298 |
print(f"Camera embedding shape: {camera_embedding.shape}")
|
| 1299 |
-
print(f"Camera mask
|
| 1300 |
|
| 1301 |
-
#
|
| 1302 |
timesteps = pipe.scheduler.timesteps
|
| 1303 |
|
| 1304 |
for i, timestep in enumerate(timesteps):
|
| 1305 |
if i % 10 == 0:
|
| 1306 |
-
print(f"
|
| 1307 |
|
| 1308 |
timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)
|
| 1309 |
|
| 1310 |
with torch.no_grad():
|
| 1311 |
-
# CFG
|
| 1312 |
if use_camera_cfg and camera_guidance_scale > 1.0:
|
| 1313 |
-
#
|
| 1314 |
noise_pred_cond, moe_loess = pipe.dit(
|
| 1315 |
new_latents,
|
| 1316 |
timestep=timestep_tensor,
|
| 1317 |
cam_emb=camera_embedding,
|
| 1318 |
-
modality_inputs=modality_inputs, # MoE
|
| 1319 |
latent_indices=latent_indices,
|
| 1320 |
clean_latents=clean_latents,
|
| 1321 |
clean_latent_indices=clean_latent_indices,
|
|
@@ -1327,12 +1185,12 @@ def inference_moe_framepack_sliding_window(
|
|
| 1327 |
**extra_input
|
| 1328 |
)
|
| 1329 |
|
| 1330 |
-
#
|
| 1331 |
noise_pred_uncond, moe_loess = pipe.dit(
|
| 1332 |
new_latents,
|
| 1333 |
timestep=timestep_tensor,
|
| 1334 |
cam_emb=camera_embedding_uncond_batch,
|
| 1335 |
-
modality_inputs=modality_inputs_uncond, # MoE
|
| 1336 |
latent_indices=latent_indices,
|
| 1337 |
clean_latents=clean_latents,
|
| 1338 |
clean_latent_indices=clean_latent_indices,
|
|
@@ -1347,7 +1205,7 @@ def inference_moe_framepack_sliding_window(
|
|
| 1347 |
# Camera CFG
|
| 1348 |
noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
| 1349 |
|
| 1350 |
-
#
|
| 1351 |
if text_guidance_scale > 1.0 and prompt_emb_neg:
|
| 1352 |
noise_pred_text_uncond, moe_loess = pipe.dit(
|
| 1353 |
new_latents,
|
|
@@ -1365,11 +1223,11 @@ def inference_moe_framepack_sliding_window(
|
|
| 1365 |
**extra_input
|
| 1366 |
)
|
| 1367 |
|
| 1368 |
-
#
|
| 1369 |
noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond)
|
| 1370 |
|
| 1371 |
elif text_guidance_scale > 1.0 and prompt_emb_neg:
|
| 1372 |
-
#
|
| 1373 |
noise_pred_cond, moe_loess = pipe.dit(
|
| 1374 |
new_latents,
|
| 1375 |
timestep=timestep_tensor,
|
|
@@ -1405,12 +1263,12 @@ def inference_moe_framepack_sliding_window(
|
|
| 1405 |
noise_pred = noise_pred_uncond + text_guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
| 1406 |
|
| 1407 |
else:
|
| 1408 |
-
#
|
| 1409 |
noise_pred, moe_loess = pipe.dit(
|
| 1410 |
new_latents,
|
| 1411 |
timestep=timestep_tensor,
|
| 1412 |
cam_emb=camera_embedding,
|
| 1413 |
-
modality_inputs=modality_inputs, # MoE
|
| 1414 |
latent_indices=latent_indices,
|
| 1415 |
clean_latents=clean_latents,
|
| 1416 |
clean_latent_indices=clean_latent_indices,
|
|
@@ -1424,31 +1282,31 @@ def inference_moe_framepack_sliding_window(
|
|
| 1424 |
|
| 1425 |
new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents)
|
| 1426 |
|
| 1427 |
-
#
|
| 1428 |
new_latents_squeezed = new_latents.squeeze(0)
|
| 1429 |
history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1)
|
| 1430 |
|
| 1431 |
-
#
|
| 1432 |
if history_latents.shape[1] > max_history_frames:
|
| 1433 |
first_frame = history_latents[:, 0:1, :, :]
|
| 1434 |
recent_frames = history_latents[:, -(max_history_frames-1):, :, :]
|
| 1435 |
history_latents = torch.cat([first_frame, recent_frames], dim=1)
|
| 1436 |
-
print(f"
|
| 1437 |
|
| 1438 |
-
print(f"
|
| 1439 |
|
| 1440 |
all_generated_frames.append(new_latents_squeezed)
|
| 1441 |
total_generated += current_generation
|
| 1442 |
|
| 1443 |
-
print(f"✅
|
| 1444 |
|
| 1445 |
-
# 12.
|
| 1446 |
-
print("\
|
| 1447 |
|
| 1448 |
all_generated = torch.cat(all_generated_frames, dim=1)
|
| 1449 |
final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0)
|
| 1450 |
|
| 1451 |
-
print(f"
|
| 1452 |
|
| 1453 |
decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))
|
| 1454 |
|
|
@@ -1461,7 +1319,7 @@ def inference_moe_framepack_sliding_window(
|
|
| 1461 |
icons = {}
|
| 1462 |
video_camera_poses = None
|
| 1463 |
if add_icons:
|
| 1464 |
-
#
|
| 1465 |
icons_dir = os.path.join(ROOT_DIR, 'icons')
|
| 1466 |
icon_names = ['move_forward.png', 'not_move_forward.png',
|
| 1467 |
'move_backward.png', 'not_move_backward.png',
|
|
@@ -1476,15 +1334,15 @@ def inference_moe_framepack_sliding_window(
|
|
| 1476 |
if os.path.exists(path):
|
| 1477 |
try:
|
| 1478 |
icon = Image.open(path).convert("RGBA")
|
| 1479 |
-
#
|
| 1480 |
icon = icon.resize((50, 50), Image.Resampling.LANCZOS)
|
| 1481 |
icons[name] = icon
|
| 1482 |
except Exception as e:
|
| 1483 |
print(f"Error loading icon {name}: {e}")
|
| 1484 |
else:
|
| 1485 |
-
print(f"
|
| 1486 |
|
| 1487 |
-
#
|
| 1488 |
time_compression_ratio = 4
|
| 1489 |
camera_poses = camera_embedding_full.detach().float().cpu().numpy()
|
| 1490 |
video_camera_poses = [x for x in camera_poses for _ in range(time_compression_ratio)]
|
|
@@ -1503,108 +1361,74 @@ def inference_moe_framepack_sliding_window(
|
|
| 1503 |
|
| 1504 |
writer.append_data(np.array(img))
|
| 1505 |
|
| 1506 |
-
print(f"
|
| 1507 |
-
print(f"
|
| 1508 |
-
print(f"
|
| 1509 |
|
| 1510 |
|
| 1511 |
def main():
|
| 1512 |
-
parser = argparse.ArgumentParser(description="MoE FramePack
|
| 1513 |
-
|
| 1514 |
-
#
|
| 1515 |
-
parser.add_argument("--condition_pth",
|
| 1516 |
-
|
| 1517 |
-
default=None,
|
| 1518 |
-
help="Path to pre-encoded condition pth file")
|
| 1519 |
-
parser.add_argument("--condition_video",
|
| 1520 |
-
type=str,
|
| 1521 |
-
default=None,
|
| 1522 |
-
help="Input video for novel view synthesis.")
|
| 1523 |
-
parser.add_argument("--condition_image",
|
| 1524 |
-
type=str,
|
| 1525 |
-
default=None,
|
| 1526 |
-
required=True,
|
| 1527 |
-
help="Input image for novel view synthesis.")
|
| 1528 |
parser.add_argument("--start_frame", type=int, default=0)
|
| 1529 |
parser.add_argument("--initial_condition_frames", type=int, default=1)
|
| 1530 |
parser.add_argument("--frames_per_generation", type=int, default=8)
|
| 1531 |
parser.add_argument("--total_frames_to_generate", type=int, default=24)
|
| 1532 |
parser.add_argument("--max_history_frames", type=int, default=100)
|
| 1533 |
parser.add_argument("--use_real_poses", default=False)
|
| 1534 |
-
parser.add_argument("--dit_path", type=str,
|
| 1535 |
-
default="../models/Astra/checkpoints/diffusion_pytorch_model.ckpt",
|
| 1536 |
help="path to the pretrained DiT MoE model checkpoint")
|
| 1537 |
-
parser.add_argument("--wan_model_path",
|
| 1538 |
-
type=str,
|
| 1539 |
-
default="../models/Wan-AI/Wan2.1-T2V-1.3B",
|
| 1540 |
-
help="path to Wan2.1-T2V-1.3B")
|
| 1541 |
parser.add_argument("--output_path", type=str,
|
| 1542 |
-
default='.
|
| 1543 |
-
parser.add_argument("--prompt",
|
| 1544 |
-
type=str,
|
| 1545 |
-
default="",
|
| 1546 |
help="text prompt for video generation")
|
| 1547 |
parser.add_argument("--device", type=str, default="cuda")
|
| 1548 |
parser.add_argument("--add_icons", action="store_true", default=False,
|
| 1549 |
-
help="
|
| 1550 |
|
| 1551 |
-
#
|
| 1552 |
parser.add_argument("--modality_type", type=str, choices=["sekai", "nuscenes", "openx"],
|
| 1553 |
-
default="sekai", help="
|
| 1554 |
parser.add_argument("--scene_info_path", type=str, default=None,
|
| 1555 |
-
help="NuScenes
|
| 1556 |
|
| 1557 |
-
# CFG
|
| 1558 |
parser.add_argument("--use_camera_cfg", default=False,
|
| 1559 |
-
help="
|
| 1560 |
parser.add_argument("--camera_guidance_scale", type=float, default=2.0,
|
| 1561 |
help="Camera guidance scale for CFG")
|
| 1562 |
parser.add_argument("--text_guidance_scale", type=float, default=1.0,
|
| 1563 |
help="Text guidance scale for CFG")
|
| 1564 |
|
| 1565 |
-
# MoE
|
| 1566 |
-
parser.add_argument("--moe_num_experts", type=int, default=3, help="
|
| 1567 |
-
parser.add_argument("--moe_top_k", type=int, default=1, help="Top-K
|
| 1568 |
-
parser.add_argument("--moe_hidden_dim", type=int, default=None, help="MoE
|
| 1569 |
-
parser.add_argument("--direction", type=str, default="left", help="
|
| 1570 |
parser.add_argument("--use_gt_prompt", action="store_true", default=False,
|
| 1571 |
-
help="
|
| 1572 |
|
| 1573 |
args = parser.parse_args()
|
| 1574 |
|
| 1575 |
-
print(f"MoE FramePack CFG
|
| 1576 |
-
print(f"
|
| 1577 |
print(f"Camera CFG: {args.use_camera_cfg}")
|
| 1578 |
if args.use_camera_cfg:
|
| 1579 |
print(f"Camera guidance scale: {args.camera_guidance_scale}")
|
| 1580 |
-
print(f"
|
| 1581 |
print(f"Text guidance scale: {args.text_guidance_scale}")
|
| 1582 |
-
print(f"MoE
|
| 1583 |
print(f"DiT{args.dit_path}")
|
| 1584 |
|
| 1585 |
-
#
|
| 1586 |
if args.modality_type == "nuscenes" and not args.scene_info_path:
|
| 1587 |
-
print("⚠️
|
| 1588 |
-
|
| 1589 |
-
if not args.use_gt_prompt and (args.prompt is None or args.prompt.strip() == ""):
|
| 1590 |
-
print("⚠️ Warning: No prompt provided, will use empty string as prompt")
|
| 1591 |
-
|
| 1592 |
-
if not any([args.condition_pth, args.condition_video, args.condition_image]):
|
| 1593 |
-
raise ValueError("Need to provide condition_pth, condition_video, or condition_image as condition input")
|
| 1594 |
-
|
| 1595 |
-
if args.condition_pth:
|
| 1596 |
-
print(f"Using pre-encoded pth: {args.condition_pth}")
|
| 1597 |
-
elif args.condition_video:
|
| 1598 |
-
print(f"Using condition video for online encoding: {args.condition_video}")
|
| 1599 |
-
elif args.condition_image:
|
| 1600 |
-
print(f"Using condition image for online encoding: {args.condition_image} (repeat 10 frames)")
|
| 1601 |
|
| 1602 |
inference_moe_framepack_sliding_window(
|
| 1603 |
condition_pth_path=args.condition_pth,
|
| 1604 |
-
condition_video=args.condition_video,
|
| 1605 |
-
condition_image=args.condition_image,
|
| 1606 |
dit_path=args.dit_path,
|
| 1607 |
-
wan_model_path=args.wan_model_path,
|
| 1608 |
output_path=args.output_path,
|
| 1609 |
start_frame=args.start_frame,
|
| 1610 |
initial_condition_frames=args.initial_condition_frames,
|
|
@@ -1616,11 +1440,11 @@ def main():
|
|
| 1616 |
modality_type=args.modality_type,
|
| 1617 |
use_real_poses=args.use_real_poses,
|
| 1618 |
scene_info_path=args.scene_info_path,
|
| 1619 |
-
# CFG
|
| 1620 |
use_camera_cfg=args.use_camera_cfg,
|
| 1621 |
camera_guidance_scale=args.camera_guidance_scale,
|
| 1622 |
text_guidance_scale=args.text_guidance_scale,
|
| 1623 |
-
# MoE
|
| 1624 |
moe_num_experts=args.moe_num_experts,
|
| 1625 |
moe_top_k=args.moe_top_k,
|
| 1626 |
moe_hidden_dim=args.moe_hidden_dim,
|
|
|
|
| 1 |
import os
|
| 2 |
import sys
|
|
|
|
|
|
|
| 3 |
|
| 4 |
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
| 5 |
sys.path.append(ROOT_DIR)
|
|
|
|
| 10 |
from PIL import Image
|
| 11 |
import imageio
|
| 12 |
import json
|
| 13 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 14 |
import argparse
|
| 15 |
from torchvision.transforms import v2
|
| 16 |
from einops import rearrange
|
|
|
|
| 17 |
import random
|
| 18 |
import copy
|
| 19 |
from datetime import datetime
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
def compute_relative_pose_matrix(pose1, pose2):
|
| 22 |
"""
|
| 23 |
+
计算相邻两帧的相对位姿,返回3×4的相机矩阵 [R_rel | t_rel]
|
| 24 |
|
| 25 |
+
参数:
|
| 26 |
+
pose1: 第i帧的相机位姿,形状为(7,)的数组 [tx1, ty1, tz1, qx1, qy1, qz1, qw1]
|
| 27 |
+
pose2: 第i+1帧的相机位姿,形状为(7,)的数组 [tx2, ty2, tz2, qx2, qy2, qz2, qw2]
|
| 28 |
|
| 29 |
+
返回:
|
| 30 |
+
relative_matrix: 3×4的相对位姿矩阵,前3列是旋转矩阵R_rel,第4列是平移向量t_rel
|
|
|
|
|
|
|
| 31 |
"""
|
| 32 |
+
# 分离平移向量和四元数
|
| 33 |
+
t1 = pose1[:3] # 第i帧平移 [tx1, ty1, tz1]
|
| 34 |
+
q1 = pose1[3:] # 第i帧四元数 [qx1, qy1, qz1, qw1]
|
| 35 |
+
t2 = pose2[:3] # 第i+1帧平移
|
| 36 |
+
q2 = pose2[3:] # 第i+1帧四元数
|
| 37 |
+
|
| 38 |
+
# 1. 计算相对旋转矩阵 R_rel
|
| 39 |
+
rot1 = R.from_quat(q1) # 第i帧旋转
|
| 40 |
+
rot2 = R.from_quat(q2) # 第i+1帧旋转
|
| 41 |
+
rot_rel = rot2 * rot1.inv() # 相对旋转 = 后一帧旋转 × 前一帧旋转的逆
|
| 42 |
+
R_rel = rot_rel.as_matrix() # 转换为3×3矩阵
|
| 43 |
+
|
| 44 |
+
# 2. 计算相对平移向量 t_rel
|
| 45 |
+
R1_T = rot1.as_matrix().T # 前一帧旋转矩阵的转置(等价于逆)
|
| 46 |
+
t_rel = R1_T @ (t2 - t1) # 相对平移 = R1^T × (t2 - t1)
|
| 47 |
+
|
| 48 |
+
# 3. 组合为3×4矩阵 [R_rel | t_rel]
|
| 49 |
relative_matrix = np.hstack([R_rel, t_rel.reshape(3, 1)])
|
| 50 |
|
| 51 |
return relative_matrix
|
| 52 |
|
| 53 |
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
|
| 54 |
+
"""从pth文件加载预编码的视频数据"""
|
| 55 |
print(f"Loading encoded video from {pth_path}")
|
| 56 |
|
| 57 |
encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu")
|
|
|
|
| 68 |
|
| 69 |
return condition_latents, encoded_data
|
| 70 |
|
| 71 |
+
|
| 72 |
def compute_relative_pose(pose_a, pose_b, use_torch=False):
|
| 73 |
+
"""计算相机B相对于相机A的相对位姿矩阵"""
|
| 74 |
+
assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
|
| 75 |
+
assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"
|
| 76 |
|
| 77 |
if use_torch:
|
| 78 |
if not isinstance(pose_a, torch.Tensor):
|
|
|
|
| 95 |
|
| 96 |
|
| 97 |
def replace_dit_model_in_manager():
|
| 98 |
+
"""替换DiT模型类为MoE版本"""
|
| 99 |
from diffsynth.models.wan_video_dit_moe import WanModelMoe
|
| 100 |
from diffsynth.configs.model_config import model_loader_configs
|
| 101 |
|
|
|
|
| 110 |
if name == 'wan_video_dit':
|
| 111 |
new_model_names.append(name)
|
| 112 |
new_model_classes.append(WanModelMoe)
|
| 113 |
+
print(f"✅ 替换了模型类: {name} -> WanModelMoe")
|
| 114 |
else:
|
| 115 |
new_model_names.append(name)
|
| 116 |
new_model_classes.append(cls)
|
|
|
|
| 119 |
|
| 120 |
|
| 121 |
def add_framepack_components(dit_model):
|
| 122 |
+
"""添加FramePack相关组件"""
|
| 123 |
if not hasattr(dit_model, 'clean_x_embedder'):
|
| 124 |
inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]
|
| 125 |
|
|
|
|
| 146 |
dit_model.clean_x_embedder = CleanXEmbedder(inner_dim)
|
| 147 |
model_dtype = next(dit_model.parameters()).dtype
|
| 148 |
dit_model.clean_x_embedder = dit_model.clean_x_embedder.to(dtype=model_dtype)
|
| 149 |
+
print("✅ 添加了FramePack的clean_x_embedder组件")
|
| 150 |
|
| 151 |
|
| 152 |
def add_moe_components(dit_model, moe_config):
|
| 153 |
+
"""🔧 添加MoE相关组件 - 修正版本"""
|
| 154 |
if not hasattr(dit_model, 'moe_config'):
|
| 155 |
dit_model.moe_config = moe_config
|
| 156 |
+
print("✅ 添加了MoE配置到模型")
|
| 157 |
dit_model.top_k = moe_config.get("top_k", 1)
|
| 158 |
|
| 159 |
+
# 为每个block动态添加MoE组件
|
| 160 |
dim = dit_model.blocks[0].self_attn.q.weight.shape[0]
|
| 161 |
unified_dim = moe_config.get("unified_dim", 25)
|
| 162 |
num_experts = moe_config.get("num_experts", 4)
|
| 163 |
from diffsynth.models.wan_video_dit_moe import ModalityProcessor, MultiModalMoE
|
| 164 |
dit_model.sekai_processor = ModalityProcessor("sekai", 13, unified_dim)
|
| 165 |
dit_model.nuscenes_processor = ModalityProcessor("nuscenes", 8, unified_dim)
|
| 166 |
+
dit_model.openx_processor = ModalityProcessor("openx", 13, unified_dim) # OpenX使用13维输入,类似sekai但独立处理
|
| 167 |
dit_model.global_router = nn.Linear(unified_dim, num_experts)
|
| 168 |
|
| 169 |
|
| 170 |
for i, block in enumerate(dit_model.blocks):
|
| 171 |
+
# MoE网络 - 输入unified_dim,输出dim
|
| 172 |
block.moe = MultiModalMoE(
|
| 173 |
unified_dim=unified_dim,
|
| 174 |
+
output_dim=dim, # 输出维度匹配transformer block的dim
|
| 175 |
num_experts=moe_config.get("num_experts", 4),
|
| 176 |
top_k=moe_config.get("top_k", 2)
|
| 177 |
)
|
| 178 |
|
| 179 |
+
print(f"✅ Block {i} 添加了MoE组件 (unified_dim: {unified_dim}, experts: {moe_config.get('num_experts', 4)})")
|
| 180 |
|
| 181 |
|
| 182 |
def generate_sekai_camera_embeddings_sliding(
|
|
|
|
| 188 |
use_real_poses=True,
|
| 189 |
direction="left"):
|
| 190 |
"""
|
| 191 |
+
为Sekai数据集生成camera embeddings - 滑动窗口版本
|
| 192 |
|
| 193 |
Args:
|
| 194 |
+
cam_data: 包含Sekai相机外参的字典, 键'extrinsic'对应一个N*4*4的numpy数组
|
| 195 |
+
start_frame: 当前生成起始帧索引
|
| 196 |
+
initial_condition_frames: 初始条件帧数
|
| 197 |
+
new_frames: 本次生成的新帧数
|
| 198 |
+
total_generated: 已生成的总帧数
|
| 199 |
+
use_real_poses: 是否使用真实的Sekai相机位姿
|
| 200 |
+
direction: 相机运动方向,默认为"left"
|
| 201 |
|
| 202 |
Returns:
|
| 203 |
+
camera_embedding: 形状为(M, 3*4 + 1)的torch张量, M为生成的总帧数
|
| 204 |
"""
|
| 205 |
time_compression_ratio = 4
|
| 206 |
|
| 207 |
+
# 计算FramePack实际需要的camera帧数
|
| 208 |
+
# 1帧初始 + 16帧4x + 2帧2x + 1帧1x + new_frames
|
| 209 |
framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames
|
| 210 |
|
| 211 |
if use_real_poses and cam_data is not None and 'extrinsic' in cam_data:
|
| 212 |
+
print("🔧 使用真实Sekai camera数据")
|
| 213 |
cam_extrinsic = cam_data['extrinsic']
|
| 214 |
|
| 215 |
+
# 确保生成足够长的camera序列
|
| 216 |
max_needed_frames = max(
|
| 217 |
start_frame + initial_condition_frames + new_frames,
|
| 218 |
framepack_needed_frames,
|
| 219 |
30
|
| 220 |
)
|
| 221 |
|
| 222 |
+
print(f"🔧 计算Sekai camera序列长度:")
|
| 223 |
+
print(f" - 基础需求: {start_frame + initial_condition_frames + new_frames}")
|
| 224 |
+
print(f" - FramePack需求: {framepack_needed_frames}")
|
| 225 |
+
print(f" - 最终生成: {max_needed_frames}")
|
| 226 |
|
| 227 |
relative_poses = []
|
| 228 |
for i in range(max_needed_frames):
|
| 229 |
+
# 计算当前帧在原始序列中的位置
|
| 230 |
frame_idx = i * time_compression_ratio
|
| 231 |
next_frame_idx = frame_idx + time_compression_ratio
|
| 232 |
|
|
|
|
| 236 |
relative_pose = compute_relative_pose(cam_prev, cam_next)
|
| 237 |
relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
|
| 238 |
else:
|
| 239 |
+
# 超出范围,使用零运动
|
| 240 |
+
print(f"⚠️ 帧{frame_idx}超出camera数据范围,使用零运动")
|
| 241 |
relative_poses.append(torch.zeros(3, 4))
|
| 242 |
|
| 243 |
pose_embedding = torch.stack(relative_poses, dim=0)
|
| 244 |
pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
|
| 245 |
|
| 246 |
+
# 创建对应长度的mask序列
|
| 247 |
mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
|
| 248 |
+
# 从start_frame到start_frame+initial_condition_frames标记为condition
|
| 249 |
condition_end = min(start_frame + initial_condition_frames, max_needed_frames)
|
| 250 |
mask[start_frame:condition_end] = 1.0
|
| 251 |
|
| 252 |
camera_embedding = torch.cat([pose_embedding, mask], dim=1)
|
| 253 |
+
print(f"🔧 Sekai真实camera embedding shape: {camera_embedding.shape}")
|
| 254 |
return camera_embedding.to(torch.bfloat16)
|
| 255 |
|
| 256 |
else:
|
| 257 |
+
# 确保生成足够长的camera序列
|
| 258 |
max_needed_frames = max(
|
| 259 |
start_frame + initial_condition_frames + new_frames,
|
| 260 |
framepack_needed_frames,
|
| 261 |
30)
|
| 262 |
|
| 263 |
+
print(f"🔧 生成Sekai合成camera帧数: {max_needed_frames}")
|
| 264 |
|
| 265 |
CONDITION_FRAMES = initial_condition_frames
|
| 266 |
STAGE_1 = new_frames//2
|
| 267 |
STAGE_2 = new_frames - STAGE_1
|
| 268 |
|
| 269 |
+
if direction=="left":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
print("--------------- LEFT TURNING MODE ---------------")
|
| 271 |
relative_poses = []
|
| 272 |
for i in range(max_needed_frames):
|
| 273 |
if i < CONDITION_FRAMES:
|
| 274 |
+
# 输入的条件帧默认的相机位姿为零运动
|
| 275 |
pose = np.eye(4, dtype=np.float32)
|
| 276 |
elif i < CONDITION_FRAMES+STAGE_1+STAGE_2:
|
| 277 |
+
# 左转
|
| 278 |
yaw_per_frame = 0.03
|
| 279 |
|
| 280 |
+
# 旋转矩阵
|
| 281 |
cos_yaw = np.cos(yaw_per_frame)
|
| 282 |
sin_yaw = np.sin(yaw_per_frame)
|
| 283 |
|
| 284 |
+
# 前进
|
| 285 |
forward_speed = 0.00
|
| 286 |
|
| 287 |
pose = np.eye(4, dtype=np.float32)
|
|
|
|
| 292 |
pose[2, 2] = cos_yaw
|
| 293 |
pose[2, 3] = -forward_speed
|
| 294 |
else:
|
| 295 |
+
# 超出条件帧与目标帧的部分,保持静止
|
| 296 |
pose = np.eye(4, dtype=np.float32)
|
| 297 |
|
| 298 |
relative_pose = pose[:3, :]
|
|
|
|
| 303 |
relative_poses = []
|
| 304 |
for i in range(max_needed_frames):
|
| 305 |
if i < CONDITION_FRAMES:
|
| 306 |
+
# 输入的条件帧默认的相机位姿为零运动
|
| 307 |
pose = np.eye(4, dtype=np.float32)
|
| 308 |
elif i < CONDITION_FRAMES+STAGE_1+STAGE_2:
|
| 309 |
+
# 右转
|
| 310 |
yaw_per_frame = -0.03
|
| 311 |
|
| 312 |
+
# 旋转矩阵
|
| 313 |
cos_yaw = np.cos(yaw_per_frame)
|
| 314 |
sin_yaw = np.sin(yaw_per_frame)
|
| 315 |
|
| 316 |
+
# 前进
|
| 317 |
forward_speed = 0.00
|
| 318 |
|
| 319 |
pose = np.eye(4, dtype=np.float32)
|
|
|
|
| 324 |
pose[2, 2] = cos_yaw
|
| 325 |
pose[2, 3] = -forward_speed
|
| 326 |
else:
|
| 327 |
+
# 超出条件帧与目标帧的部分,保持静止
|
| 328 |
pose = np.eye(4, dtype=np.float32)
|
| 329 |
|
| 330 |
relative_pose = pose[:3, :]
|
|
|
|
| 335 |
relative_poses = []
|
| 336 |
for i in range(max_needed_frames):
|
| 337 |
if i < CONDITION_FRAMES:
|
| 338 |
+
# 输入的条件帧默认的相机位姿为零运动
|
| 339 |
pose = np.eye(4, dtype=np.float32)
|
| 340 |
elif i < CONDITION_FRAMES+STAGE_1+STAGE_2:
|
| 341 |
+
# 左转
|
| 342 |
yaw_per_frame = 0.03
|
| 343 |
|
| 344 |
+
# 旋转矩阵
|
| 345 |
cos_yaw = np.cos(yaw_per_frame)
|
| 346 |
sin_yaw = np.sin(yaw_per_frame)
|
| 347 |
|
| 348 |
+
# 前进
|
| 349 |
forward_speed = 0.03
|
| 350 |
|
| 351 |
pose = np.eye(4, dtype=np.float32)
|
|
|
|
| 357 |
pose[2, 3] = -forward_speed
|
| 358 |
|
| 359 |
else:
|
| 360 |
+
# 超出条件帧与目标帧的部分,保持静止
|
| 361 |
pose = np.eye(4, dtype=np.float32)
|
| 362 |
|
| 363 |
relative_pose = pose[:3, :]
|
|
|
|
| 368 |
relative_poses = []
|
| 369 |
for i in range(max_needed_frames):
|
| 370 |
if i < CONDITION_FRAMES:
|
| 371 |
+
# 输入的条件帧默认的相机位姿为零运动
|
| 372 |
pose = np.eye(4, dtype=np.float32)
|
| 373 |
elif i < CONDITION_FRAMES+STAGE_1+STAGE_2:
|
| 374 |
+
# 右转
|
| 375 |
yaw_per_frame = -0.03
|
| 376 |
|
| 377 |
+
# 旋转矩阵
|
| 378 |
cos_yaw = np.cos(yaw_per_frame)
|
| 379 |
sin_yaw = np.sin(yaw_per_frame)
|
| 380 |
|
| 381 |
+
# 前进
|
| 382 |
forward_speed = 0.03
|
| 383 |
|
| 384 |
pose = np.eye(4, dtype=np.float32)
|
|
|
|
| 390 |
pose[2, 3] = -forward_speed
|
| 391 |
|
| 392 |
else:
|
| 393 |
+
# 超出条件帧与目标帧的部分,保持静止
|
| 394 |
pose = np.eye(4, dtype=np.float32)
|
| 395 |
|
| 396 |
relative_pose = pose[:3, :]
|
|
|
|
| 401 |
relative_poses = []
|
| 402 |
for i in range(max_needed_frames):
|
| 403 |
if i < CONDITION_FRAMES:
|
| 404 |
+
# 输入的条件帧默认的相机位姿为零运动
|
| 405 |
pose = np.eye(4, dtype=np.float32)
|
| 406 |
elif i < CONDITION_FRAMES+STAGE_1:
|
| 407 |
+
# 左转
|
| 408 |
yaw_per_frame = 0.03
|
| 409 |
|
| 410 |
+
# 旋转矩阵
|
| 411 |
cos_yaw = np.cos(yaw_per_frame)
|
| 412 |
sin_yaw = np.sin(yaw_per_frame)
|
| 413 |
|
| 414 |
+
# 前进
|
| 415 |
forward_speed = 0.03
|
| 416 |
|
| 417 |
pose = np.eye(4, dtype=np.float32)
|
|
|
|
| 423 |
pose[2, 3] = -forward_speed
|
| 424 |
|
| 425 |
elif i < CONDITION_FRAMES+STAGE_1+STAGE_2:
|
| 426 |
+
# 右转
|
| 427 |
yaw_per_frame = -0.03
|
| 428 |
|
| 429 |
+
# 旋转矩阵
|
| 430 |
cos_yaw = np.cos(yaw_per_frame)
|
| 431 |
sin_yaw = np.sin(yaw_per_frame)
|
| 432 |
|
| 433 |
+
# 前进
|
| 434 |
forward_speed = 0.03
|
| 435 |
+
# 轻微向左漂移,保持惯性
|
| 436 |
if i < CONDITION_FRAMES+STAGE_1+STAGE_2//3:
|
| 437 |
radius_shift = -0.01
|
| 438 |
else:
|
|
|
|
| 448 |
pose[0, 3] = radius_shift
|
| 449 |
|
| 450 |
else:
|
| 451 |
+
# 超出条件帧与目标帧的部分,保持静止
|
| 452 |
pose = np.eye(4, dtype=np.float32)
|
| 453 |
|
| 454 |
relative_pose = pose[:3, :]
|
|
|
|
| 459 |
relative_poses = []
|
| 460 |
for i in range(max_needed_frames):
|
| 461 |
if i < CONDITION_FRAMES:
|
| 462 |
+
# 输入的条件帧默认的相机位姿为零运动
|
| 463 |
pose = np.eye(4, dtype=np.float32)
|
| 464 |
elif i < CONDITION_FRAMES+STAGE_1:
|
| 465 |
+
# 左转
|
| 466 |
yaw_per_frame = 0.03
|
| 467 |
|
| 468 |
+
# 旋转矩阵
|
| 469 |
cos_yaw = np.cos(yaw_per_frame)
|
| 470 |
sin_yaw = np.sin(yaw_per_frame)
|
| 471 |
|
| 472 |
+
# 前进
|
| 473 |
forward_speed = 0.00
|
| 474 |
|
| 475 |
pose = np.eye(4, dtype=np.float32)
|
|
|
|
| 481 |
pose[2, 3] = -forward_speed
|
| 482 |
|
| 483 |
elif i < CONDITION_FRAMES+STAGE_1+STAGE_2:
|
| 484 |
+
# 右转
|
| 485 |
yaw_per_frame = -0.03
|
| 486 |
|
| 487 |
+
# 旋转矩阵
|
| 488 |
cos_yaw = np.cos(yaw_per_frame)
|
| 489 |
sin_yaw = np.sin(yaw_per_frame)
|
| 490 |
|
| 491 |
+
# 前进
|
| 492 |
forward_speed = 0.00
|
| 493 |
|
| 494 |
pose = np.eye(4, dtype=np.float32)
|
|
|
|
| 500 |
pose[2, 3] = -forward_speed
|
| 501 |
|
| 502 |
else:
|
| 503 |
+
# 超出条件帧与目标帧的部分,保持静止
|
| 504 |
pose = np.eye(4, dtype=np.float32)
|
| 505 |
|
| 506 |
relative_pose = pose[:3, :]
|
| 507 |
relative_poses.append(torch.as_tensor(relative_pose))
|
| 508 |
|
| 509 |
else:
|
| 510 |
+
raise ValueError(f"未定义的相机运动方向: {direction}")
|
| 511 |
|
| 512 |
pose_embedding = torch.stack(relative_poses, dim=0)
|
| 513 |
pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
|
| 514 |
|
| 515 |
+
# 创建对应长度的mask序列
|
| 516 |
mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
|
| 517 |
condition_end = min(start_frame + initial_condition_frames + 1, max_needed_frames)
|
| 518 |
mask[start_frame:condition_end] = 1.0
|
| 519 |
|
| 520 |
camera_embedding = torch.cat([pose_embedding, mask], dim=1)
|
| 521 |
+
print(f"🔧 Sekai合成camera embedding shape: {camera_embedding.shape}")
|
| 522 |
return camera_embedding.to(torch.bfloat16)
|
| 523 |
|
| 524 |
|
| 525 |
def generate_openx_camera_embeddings_sliding(
|
| 526 |
encoded_data, start_frame, initial_condition_frames, new_frames, use_real_poses):
|
| 527 |
+
"""为OpenX数据集生成camera embeddings - 滑动窗口版本"""
|
| 528 |
time_compression_ratio = 4
|
| 529 |
|
| 530 |
+
# 计算FramePack实际需要的camera帧数
|
| 531 |
framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames
|
| 532 |
|
| 533 |
if use_real_poses and encoded_data is not None and 'cam_emb' in encoded_data and 'extrinsic' in encoded_data['cam_emb']:
|
| 534 |
+
print("🔧 使用OpenX真实camera数据")
|
| 535 |
cam_extrinsic = encoded_data['cam_emb']['extrinsic']
|
| 536 |
|
| 537 |
+
# 确保生成足够长的camera序列
|
| 538 |
max_needed_frames = max(
|
| 539 |
start_frame + initial_condition_frames + new_frames,
|
| 540 |
framepack_needed_frames,
|
| 541 |
30
|
| 542 |
)
|
| 543 |
|
| 544 |
+
print(f"🔧 计算OpenX camera序列长度:")
|
| 545 |
+
print(f" - 基础需求: {start_frame + initial_condition_frames + new_frames}")
|
| 546 |
+
print(f" - FramePack需求: {framepack_needed_frames}")
|
| 547 |
+
print(f" - 最终生成: {max_needed_frames}")
|
| 548 |
|
| 549 |
relative_poses = []
|
| 550 |
for i in range(max_needed_frames):
|
| 551 |
+
# OpenX使用4倍间隔,类似sekai但处理更短的序列
|
| 552 |
frame_idx = i * time_compression_ratio
|
| 553 |
next_frame_idx = frame_idx + time_compression_ratio
|
| 554 |
|
|
|
|
| 558 |
relative_pose = compute_relative_pose(cam_prev, cam_next)
|
| 559 |
relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
|
| 560 |
else:
|
| 561 |
+
# 超出范围,使用零运动
|
| 562 |
+
print(f"⚠️ 帧{frame_idx}超出OpenX camera数据范围,使用零运动")
|
| 563 |
relative_poses.append(torch.zeros(3, 4))
|
| 564 |
|
| 565 |
pose_embedding = torch.stack(relative_poses, dim=0)
|
| 566 |
pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
|
| 567 |
|
| 568 |
+
# 创建对应长度的mask序列
|
| 569 |
mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
|
| 570 |
+
# 从start_frame到start_frame + initial_condition_frames标记为condition
|
| 571 |
condition_end = min(start_frame + initial_condition_frames, max_needed_frames)
|
| 572 |
mask[start_frame:condition_end] = 1.0
|
| 573 |
|
| 574 |
camera_embedding = torch.cat([pose_embedding, mask], dim=1)
|
| 575 |
+
print(f"🔧 OpenX真实camera embedding shape: {camera_embedding.shape}")
|
| 576 |
return camera_embedding.to(torch.bfloat16)
|
| 577 |
|
| 578 |
else:
|
| 579 |
+
print("🔧 使用OpenX合成camera数据")
|
| 580 |
|
| 581 |
max_needed_frames = max(
|
| 582 |
start_frame + initial_condition_frames + new_frames,
|
|
|
|
| 584 |
30
|
| 585 |
)
|
| 586 |
|
| 587 |
+
print(f"🔧 生成OpenX合成camera帧数: {max_needed_frames}")
|
| 588 |
relative_poses = []
|
| 589 |
for i in range(max_needed_frames):
|
| 590 |
+
# OpenX机器人操作运动模式 - 较小的运动幅度
|
| 591 |
+
# 模拟机器人手臂的精细操作运动
|
| 592 |
+
roll_per_frame = 0.02 # 轻微翻滚
|
| 593 |
+
pitch_per_frame = 0.01 # 轻微俯仰
|
| 594 |
+
yaw_per_frame = 0.015 # 轻微偏航
|
| 595 |
+
forward_speed = 0.003 # 较慢的前进速度
|
| 596 |
|
| 597 |
pose = np.eye(4, dtype=np.float32)
|
| 598 |
|
| 599 |
+
# 复合旋转 - 模拟机器人手臂的复杂运动
|
| 600 |
+
# 绕X轴旋转(roll)
|
| 601 |
cos_roll = np.cos(roll_per_frame)
|
| 602 |
sin_roll = np.sin(roll_per_frame)
|
| 603 |
+
# 绕Y轴旋转(pitch)
|
| 604 |
cos_pitch = np.cos(pitch_per_frame)
|
| 605 |
sin_pitch = np.sin(pitch_per_frame)
|
| 606 |
+
# 绕Z轴旋转(yaw)
|
| 607 |
cos_yaw = np.cos(yaw_per_frame)
|
| 608 |
sin_yaw = np.sin(yaw_per_frame)
|
| 609 |
|
| 610 |
+
# 简化的复合旋转矩阵(ZYX顺序)
|
| 611 |
pose[0, 0] = cos_yaw * cos_pitch
|
| 612 |
pose[0, 1] = cos_yaw * sin_pitch * sin_roll - sin_yaw * cos_roll
|
| 613 |
pose[0, 2] = cos_yaw * sin_pitch * cos_roll + sin_yaw * sin_roll
|
|
|
|
| 618 |
pose[2, 1] = cos_pitch * sin_roll
|
| 619 |
pose[2, 2] = cos_pitch * cos_roll
|
| 620 |
|
| 621 |
+
# 平移 - 模拟机器人操作的精细移动
|
| 622 |
+
pose[0, 3] = forward_speed * 0.5 # X方向轻微移动
|
| 623 |
+
pose[1, 3] = forward_speed * 0.3 # Y方向轻微移动
|
| 624 |
+
pose[2, 3] = -forward_speed # Z方向(深度)主要移动
|
| 625 |
|
| 626 |
relative_pose = pose[:3, :]
|
| 627 |
relative_poses.append(torch.as_tensor(relative_pose))
|
|
|
|
| 629 |
pose_embedding = torch.stack(relative_poses, dim=0)
|
| 630 |
pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
|
| 631 |
|
| 632 |
+
# 创建对应长度的mask序列
|
| 633 |
mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
|
| 634 |
condition_end = min(start_frame + initial_condition_frames, max_needed_frames)
|
| 635 |
mask[start_frame:condition_end] = 1.0
|
| 636 |
|
| 637 |
camera_embedding = torch.cat([pose_embedding, mask], dim=1)
|
| 638 |
+
print(f"🔧 OpenX合成camera embedding shape: {camera_embedding.shape}")
|
| 639 |
return camera_embedding.to(torch.bfloat16)
|
| 640 |
|
| 641 |
|
| 642 |
def generate_nuscenes_camera_embeddings_sliding(
|
| 643 |
scene_info, start_frame, initial_condition_frames, new_frames):
|
| 644 |
+
"""为NuScenes数据集生成camera embeddings - 滑动窗口版本 - 修正版,与train_moe.py保持一致"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 645 |
time_compression_ratio = 4
|
| 646 |
|
| 647 |
+
# 计算FramePack实际需要的camera帧数
|
| 648 |
framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames
|
| 649 |
|
| 650 |
if scene_info is not None and 'keyframe_poses' in scene_info:
|
| 651 |
+
print("🔧 使用NuScenes真实pose数据")
|
| 652 |
keyframe_poses = scene_info['keyframe_poses']
|
| 653 |
|
| 654 |
if len(keyframe_poses) == 0:
|
| 655 |
+
print("⚠️ NuScenes keyframe_poses为空,使用零pose")
|
| 656 |
max_needed_frames = max(framepack_needed_frames, 30)
|
| 657 |
|
| 658 |
pose_sequence = torch.zeros(max_needed_frames, 7, dtype=torch.float32)
|
|
|
|
| 662 |
mask[start_frame:condition_end] = 1.0
|
| 663 |
|
| 664 |
camera_embedding = torch.cat([pose_sequence, mask], dim=1) # [max_needed_frames, 8]
|
| 665 |
+
print(f"🔧 NuScenes零pose embedding shape: {camera_embedding.shape}")
|
| 666 |
return camera_embedding.to(torch.bfloat16)
|
| 667 |
|
| 668 |
+
# 使用第一个pose作为参考
|
| 669 |
reference_pose = keyframe_poses[0]
|
| 670 |
|
| 671 |
max_needed_frames = max(framepack_needed_frames, 30)
|
|
|
|
| 675 |
if i < len(keyframe_poses):
|
| 676 |
current_pose = keyframe_poses[i]
|
| 677 |
|
| 678 |
+
# 计算相对位移
|
| 679 |
translation = torch.tensor(
|
| 680 |
np.array(current_pose['translation']) - np.array(reference_pose['translation']),
|
| 681 |
dtype=torch.float32
|
| 682 |
)
|
| 683 |
|
| 684 |
+
# 计算相对旋转(简化版本)
|
| 685 |
rotation = torch.tensor(current_pose['rotation'], dtype=torch.float32)
|
| 686 |
|
| 687 |
pose_vec = torch.cat([translation, rotation], dim=0) # [7D]
|
| 688 |
else:
|
| 689 |
+
# 超出范围,使用零pose
|
| 690 |
pose_vec = torch.cat([
|
| 691 |
torch.zeros(3, dtype=torch.float32),
|
| 692 |
torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32)
|
|
|
|
| 696 |
|
| 697 |
pose_sequence = torch.stack(pose_vecs, dim=0) # [max_needed_frames, 7]
|
| 698 |
|
| 699 |
+
# 创建mask
|
| 700 |
mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
|
| 701 |
condition_end = min(start_frame + initial_condition_frames, max_needed_frames)
|
| 702 |
mask[start_frame:condition_end] = 1.0
|
| 703 |
|
| 704 |
camera_embedding = torch.cat([pose_sequence, mask], dim=1) # [max_needed_frames, 8]
|
| 705 |
+
print(f"🔧 NuScenes真实pose embedding shape: {camera_embedding.shape}")
|
| 706 |
return camera_embedding.to(torch.bfloat16)
|
| 707 |
|
| 708 |
else:
|
| 709 |
+
print("🔧 使用NuScenes合成pose数据")
|
| 710 |
max_needed_frames = max(framepack_needed_frames, 30)
|
| 711 |
|
| 712 |
+
# 创建合成运动序列
|
| 713 |
pose_vecs = []
|
| 714 |
for i in range(max_needed_frames):
|
| 715 |
+
# 左转运动模式 - 类似城市驾驶中的左转弯
|
| 716 |
+
angle = i * 0.04 # 每帧转动0.04弧度(稍微慢一点的转弯)
|
| 717 |
+
radius = 15.0 # 较大的转弯半径,更符合汽车转弯
|
| 718 |
|
| 719 |
+
# 计算圆弧轨迹上的位置
|
| 720 |
x = radius * np.sin(angle)
|
| 721 |
+
y = 0.0 # 保持水平面运动
|
| 722 |
z = radius * (1 - np.cos(angle))
|
| 723 |
|
| 724 |
translation = torch.tensor([x, y, z], dtype=torch.float32)
|
| 725 |
|
| 726 |
+
# 车辆朝向 - 始终沿着轨迹切线方向
|
| 727 |
+
yaw = angle + np.pi/2 # 相对于初始前进方向的偏航角
|
| 728 |
+
# 四元数表示绕Y轴的旋转
|
| 729 |
rotation = torch.tensor([
|
| 730 |
+
np.cos(yaw/2), # w (实部)
|
| 731 |
0.0, # x
|
| 732 |
0.0, # y
|
| 733 |
+
np.sin(yaw/2) # z (虚部,绕Y轴)
|
| 734 |
], dtype=torch.float32)
|
| 735 |
|
| 736 |
pose_vec = torch.cat([translation, rotation], dim=0) # [7D: tx,ty,tz,qw,qx,qy,qz]
|
|
|
|
| 738 |
|
| 739 |
pose_sequence = torch.stack(pose_vecs, dim=0)
|
| 740 |
|
| 741 |
+
# 创建mask
|
| 742 |
mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
|
| 743 |
condition_end = min(start_frame + initial_condition_frames, max_needed_frames)
|
| 744 |
mask[start_frame:condition_end] = 1.0
|
| 745 |
|
| 746 |
camera_embedding = torch.cat([pose_sequence, mask], dim=1) # [max_needed_frames, 8]
|
| 747 |
+
print(f"🔧 NuScenes合成左转pose embedding shape: {camera_embedding.shape}")
|
| 748 |
return camera_embedding.to(torch.bfloat16)
|
| 749 |
+
|
| 750 |
def prepare_framepack_sliding_window_with_camera_moe(
|
| 751 |
history_latents,
|
| 752 |
target_frames_to_generate,
|
|
|
|
| 754 |
start_frame,
|
| 755 |
modality_type,
|
| 756 |
max_history_frames=49):
|
| 757 |
+
"""FramePack滑动窗口机制 - MoE版本"""
|
| 758 |
+
# history_latents: [C, T, H, W] 当前的历史latents
|
| 759 |
C, T, H, W = history_latents.shape
|
| 760 |
|
| 761 |
+
# 固定索引结构(这决定了需要的camera帧数)
|
| 762 |
+
# 1帧起始 + 16帧4x + 2帧2x + 1帧1x + target_frames_to_generate
|
| 763 |
total_indices_length = 1 + 16 + 2 + 1 + target_frames_to_generate
|
| 764 |
indices = torch.arange(0, total_indices_length)
|
| 765 |
split_sizes = [1, 16, 2, 1, target_frames_to_generate]
|
|
|
|
| 767 |
indices.split(split_sizes, dim=0)
|
| 768 |
clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=0)
|
| 769 |
|
| 770 |
+
# 检查camera长度是否足够
|
| 771 |
if camera_embedding_full.shape[0] < total_indices_length:
|
| 772 |
+
print(f"⚠️ camera_embedding长度不足,进行零补齐: 当前长度 {camera_embedding_full.shape[0]}, 需要长度 {total_indices_length}")
|
| 773 |
shortage = total_indices_length - camera_embedding_full.shape[0]
|
| 774 |
padding = torch.zeros(shortage, camera_embedding_full.shape[1],
|
| 775 |
dtype=camera_embedding_full.dtype, device=camera_embedding_full.device)
|
| 776 |
camera_embedding_full = torch.cat([camera_embedding_full, padding], dim=0)
|
| 777 |
|
| 778 |
+
# 从完整camera序列中选取对应部分
|
| 779 |
combined_camera = torch.zeros(
|
| 780 |
total_indices_length,
|
| 781 |
camera_embedding_full.shape[1],
|
| 782 |
dtype=camera_embedding_full.dtype,
|
| 783 |
device=camera_embedding_full.device)
|
| 784 |
|
| 785 |
+
# 历史条件帧的相机位姿
|
| 786 |
history_slice = camera_embedding_full[max(T - 19, 0):T, :].clone()
|
| 787 |
combined_camera[19 - history_slice.shape[0]:19, :] = history_slice
|
| 788 |
|
| 789 |
+
# 目标帧的相机位姿
|
| 790 |
target_slice = camera_embedding_full[T:T + target_frames_to_generate, :].clone()
|
| 791 |
combined_camera[19:19 + target_slice.shape[0], :] = target_slice
|
| 792 |
|
| 793 |
+
# 根据当前history length重新设置mask
|
| 794 |
+
combined_camera[:, -1] = 0.0 # 先全部设为target (0)
|
| 795 |
|
| 796 |
+
# 设置condition mask:前19帧根据实际历史长度决定
|
| 797 |
if T > 0:
|
| 798 |
available_frames = min(T, 19)
|
| 799 |
start_pos = 19 - available_frames
|
| 800 |
+
combined_camera[start_pos:19, -1] = 1.0 # 将有效的clean latents对应的camera标记为condition
|
| 801 |
|
| 802 |
+
print(f"🔧 MoE Camera mask更新:")
|
| 803 |
+
print(f" - 历史帧数: {T}")
|
| 804 |
+
print(f" - 有效condition帧数: {available_frames if T > 0 else 0}")
|
| 805 |
+
print(f" - 模态类型: {modality_type}")
|
| 806 |
|
| 807 |
+
# 处理latents
|
| 808 |
clean_latents_combined = torch.zeros(C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device)
|
| 809 |
|
| 810 |
if T > 0:
|
|
|
|
| 832 |
'clean_latent_2x_indices': clean_latent_2x_indices,
|
| 833 |
'clean_latent_4x_indices': clean_latent_4x_indices,
|
| 834 |
'camera_embedding': combined_camera,
|
| 835 |
+
'modality_type': modality_type, # 新增模态类型信息
|
| 836 |
'current_length': T,
|
| 837 |
'next_length': T + target_frames_to_generate
|
| 838 |
}
|
| 839 |
|
| 840 |
def overlay_controls(frame_img, pose_vec, icons):
|
| 841 |
"""
|
| 842 |
+
根据相机位姿在帧上叠加控制图标(WASD 和箭头)
|
| 843 |
+
pose_vec: 12 个元素(展平的 3x4 矩阵)+ mask
|
| 844 |
"""
|
| 845 |
if pose_vec is None or np.all(pose_vec[:12] == 0):
|
| 846 |
return frame_img
|
| 847 |
|
| 848 |
+
# 提取平移向量(基于展平的 3x4 矩阵的索引)
|
| 849 |
# [r00, r01, r02, tx, r10, r11, r12, ty, r20, r21, r22, tz]
|
| 850 |
tx = pose_vec[3]
|
| 851 |
# ty = pose_vec[7]
|
| 852 |
tz = pose_vec[11]
|
| 853 |
|
| 854 |
+
# 提取旋转(偏航和俯仰)
|
| 855 |
+
# 偏航:绕 Y 轴。sin(偏航) = r02, cos(偏航) = r00
|
| 856 |
r00 = pose_vec[0]
|
| 857 |
r02 = pose_vec[2]
|
| 858 |
yaw = np.arctan2(r02, r00)
|
| 859 |
|
| 860 |
+
# 俯仰:绕 X 轴。sin(俯仰) = -r12, cos(俯仰) = r22
|
| 861 |
r12 = pose_vec[6]
|
| 862 |
r22 = pose_vec[10]
|
| 863 |
pitch = np.arctan2(-r12, r22)
|
| 864 |
|
| 865 |
+
# 按键激活的阈值
|
| 866 |
TRANS_THRESH = 0.01
|
| 867 |
ROT_THRESH = 0.005
|
| 868 |
|
| 869 |
+
# 确定按键状态
|
| 870 |
+
# 平移(WASD)
|
| 871 |
+
# 假设 -Z 为前进,+X 为右
|
| 872 |
is_forward = tz < -TRANS_THRESH
|
| 873 |
is_backward = tz > TRANS_THRESH
|
| 874 |
is_left = tx < -TRANS_THRESH
|
| 875 |
is_right = tx > TRANS_THRESH
|
| 876 |
|
| 877 |
+
# 旋转(箭头)
|
| 878 |
+
# 偏航:+ 为左,- 为右
|
| 879 |
is_turn_left = yaw > ROT_THRESH
|
| 880 |
is_turn_right = yaw < -ROT_THRESH
|
| 881 |
|
| 882 |
+
# 俯仰:+ 为下,- 为上
|
| 883 |
is_turn_up = pitch < -ROT_THRESH
|
| 884 |
is_turn_down = pitch > ROT_THRESH
|
| 885 |
|
|
|
|
| 890 |
name = name_active if is_active else name_inactive
|
| 891 |
if name in icons:
|
| 892 |
icon = icons[name]
|
| 893 |
+
# 使用 alpha 通道粘贴
|
| 894 |
frame_img.paste(icon, (int(x), int(y)), icon)
|
| 895 |
|
| 896 |
+
# 叠加 WASD(左下角)
|
| 897 |
base_x_right = 100
|
| 898 |
base_y = H - 100
|
| 899 |
|
|
|
|
| 906 |
# D
|
| 907 |
paste_icon('move_right.png', 'not_move_right.png', is_right, base_x_right + spacing, base_y)
|
| 908 |
|
| 909 |
+
# 叠加 ↑↓←→(右下角)
|
| 910 |
base_x_left = W - 150
|
| 911 |
|
| 912 |
# ↑
|
|
|
|
| 922 |
|
| 923 |
|
| 924 |
def inference_moe_framepack_sliding_window(
|
| 925 |
+
condition_pth_path,
|
| 926 |
+
dit_path,
|
|
|
|
|
|
|
|
|
|
| 927 |
output_path="../examples/output_videos/output_moe_framepack_sliding.mp4",
|
| 928 |
start_frame=0,
|
| 929 |
initial_condition_frames=8,
|
|
|
|
| 932 |
max_history_frames=49,
|
| 933 |
device="cuda",
|
| 934 |
prompt="A video of a scene shot using a pedestrian's front camera while walking",
|
| 935 |
+
modality_type="sekai", # "sekai" 或 "nuscenes"
|
| 936 |
use_real_poses=True,
|
| 937 |
+
scene_info_path=None, # 对于NuScenes数据集
|
| 938 |
+
# CFG参数
|
| 939 |
use_camera_cfg=True,
|
| 940 |
camera_guidance_scale=2.0,
|
| 941 |
text_guidance_scale=1.0,
|
| 942 |
+
# MoE参数
|
| 943 |
moe_num_experts=4,
|
| 944 |
moe_top_k=2,
|
| 945 |
moe_hidden_dim=None,
|
|
|
|
| 948 |
add_icons=False
|
| 949 |
):
|
| 950 |
"""
|
| 951 |
+
MoE FramePack滑动窗口视频生成 - 支持多模态
|
| 952 |
"""
|
| 953 |
+
# 创建输出目录
|
| 954 |
dir_path = os.path.dirname(output_path)
|
| 955 |
os.makedirs(dir_path, exist_ok=True)
|
| 956 |
|
| 957 |
+
print(f"🔧 MoE FramePack滑动窗口生成开始...")
|
| 958 |
+
print(f"模态类型: {modality_type}")
|
| 959 |
+
print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}")
|
| 960 |
+
print(f"Text guidance scale: {text_guidance_scale}")
|
| 961 |
+
print(f"MoE配置: experts={moe_num_experts}, top_k={moe_top_k}")
|
| 962 |
|
| 963 |
+
# 1. 模型初始化
|
| 964 |
replace_dit_model_in_manager()
|
| 965 |
|
| 966 |
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
|
| 967 |
model_manager.load_models([
|
| 968 |
+
"/mnt/data/louis_crq/models/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
|
| 969 |
+
"/mnt/data/louis_crq/models/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
|
| 970 |
+
"/mnt/data/louis_crq/models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
|
| 971 |
])
|
| 972 |
+
pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")
|
| 973 |
|
| 974 |
+
# 2. 添加传统camera编码器(兼容性)
|
| 975 |
dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
|
| 976 |
for block in pipe.dit.blocks:
|
| 977 |
block.cam_encoder = nn.Linear(13, dim)
|
|
|
|
| 981 |
block.projector.weight = nn.Parameter(torch.eye(dim))
|
| 982 |
block.projector.bias = nn.Parameter(torch.zeros(dim))
|
| 983 |
|
| 984 |
+
# 3. 添加FramePack组件
|
| 985 |
add_framepack_components(pipe.dit)
|
| 986 |
|
| 987 |
+
# 4. 添加MoE组件
|
| 988 |
moe_config = {
|
| 989 |
"num_experts": moe_num_experts,
|
| 990 |
"top_k": moe_top_k,
|
| 991 |
"hidden_dim": moe_hidden_dim or dim * 2,
|
| 992 |
+
"sekai_input_dim": 13, # Sekai: 12维pose + 1维mask
|
| 993 |
+
"nuscenes_input_dim": 8, # NuScenes: 7维pose + 1维mask
|
| 994 |
+
"openx_input_dim": 13 # OpenX: 12维pose + 1维mask (类似sekai)
|
| 995 |
}
|
| 996 |
add_moe_components(pipe.dit, moe_config)
|
| 997 |
|
| 998 |
+
# 5. 加载训练好的权重
|
| 999 |
dit_state_dict = torch.load(dit_path, map_location="cpu")
|
| 1000 |
+
pipe.dit.load_state_dict(dit_state_dict, strict=False) # 使用strict=False以兼容新增的MoE组件
|
| 1001 |
pipe = pipe.to(device)
|
| 1002 |
model_dtype = next(pipe.dit.parameters()).dtype
|
| 1003 |
|
| 1004 |
if hasattr(pipe.dit, 'clean_x_embedder'):
|
| 1005 |
pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)
|
| 1006 |
|
| 1007 |
+
# 设置去噪步数
|
| 1008 |
pipe.scheduler.set_timesteps(50)
|
| 1009 |
|
| 1010 |
+
# 6. 加载初始条件
|
| 1011 |
print("Loading initial condition frames...")
|
| 1012 |
+
initial_latents, encoded_data = load_encoded_video_from_pth(
|
| 1013 |
+
condition_pth_path,
|
| 1014 |
+
start_frame=start_frame,
|
| 1015 |
+
num_frames=initial_condition_frames
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1016 |
)
|
| 1017 |
|
| 1018 |
+
# 空间裁剪
|
| 1019 |
target_height, target_width = 60, 104
|
| 1020 |
C, T, H, W = initial_latents.shape
|
| 1021 |
|
|
|
|
| 1027 |
|
| 1028 |
history_latents = initial_latents.to(device, dtype=model_dtype)
|
| 1029 |
|
| 1030 |
+
print(f"初始history_latents shape: {history_latents.shape}")
|
| 1031 |
|
| 1032 |
+
# 7. 编码prompt - 支持CFG
|
| 1033 |
if use_gt_prompt and 'prompt_emb' in encoded_data:
|
| 1034 |
+
print("✅ 使用预编码的GT prompt embedding")
|
| 1035 |
prompt_emb_pos = encoded_data['prompt_emb']
|
| 1036 |
+
# 将prompt_emb移到正确的设备和数据类型
|
| 1037 |
if 'context' in prompt_emb_pos:
|
| 1038 |
prompt_emb_pos['context'] = prompt_emb_pos['context'].to(device, dtype=model_dtype)
|
| 1039 |
if 'context_mask' in prompt_emb_pos:
|
| 1040 |
prompt_emb_pos['context_mask'] = prompt_emb_pos['context_mask'].to(device, dtype=model_dtype)
|
| 1041 |
|
| 1042 |
+
# 如果使用Text CFG,生成负向prompt
|
| 1043 |
if text_guidance_scale > 1.0:
|
| 1044 |
prompt_emb_neg = pipe.encode_prompt("")
|
| 1045 |
+
print(f"使用Text CFG with GT prompt,guidance scale: {text_guidance_scale}")
|
| 1046 |
else:
|
| 1047 |
prompt_emb_neg = None
|
| 1048 |
+
print("不使用Text CFG")
|
| 1049 |
|
| 1050 |
+
# 🔧 打印GT prompt文本(如果有)
|
| 1051 |
if 'prompt' in encoded_data['prompt_emb']:
|
| 1052 |
gt_prompt_text = encoded_data['prompt_emb']['prompt']
|
| 1053 |
+
print(f"📝 GT Prompt文本: {gt_prompt_text}")
|
| 1054 |
else:
|
| 1055 |
+
# 使用传入的prompt参数重新编码
|
| 1056 |
+
print(f"🔄 重新编码prompt: {prompt}")
|
| 1057 |
if text_guidance_scale > 1.0:
|
| 1058 |
prompt_emb_pos = pipe.encode_prompt(prompt)
|
| 1059 |
prompt_emb_neg = pipe.encode_prompt("")
|
| 1060 |
+
print(f"使用Text CFG,guidance scale: {text_guidance_scale}")
|
| 1061 |
else:
|
| 1062 |
prompt_emb_pos = pipe.encode_prompt(prompt)
|
| 1063 |
prompt_emb_neg = None
|
| 1064 |
+
print("不使用Text CFG")
|
| 1065 |
|
| 1066 |
+
# 8. 加载场景信息(对于NuScenes)
|
| 1067 |
scene_info = None
|
| 1068 |
if modality_type == "nuscenes" and scene_info_path and os.path.exists(scene_info_path):
|
| 1069 |
with open(scene_info_path, 'r') as f:
|
| 1070 |
scene_info = json.load(f)
|
| 1071 |
+
print(f"加载NuScenes场景信息: {scene_info_path}")
|
| 1072 |
|
| 1073 |
+
# 9. 预生成完整的camera embedding序列
|
| 1074 |
if modality_type == "sekai":
|
| 1075 |
camera_embedding_full = generate_sekai_camera_embeddings_sliding(
|
| 1076 |
encoded_data.get('cam_emb', None),
|
|
|
|
| 1097 |
use_real_poses=use_real_poses
|
| 1098 |
).to(device, dtype=model_dtype)
|
| 1099 |
else:
|
| 1100 |
+
raise ValueError(f"不支持的模态类型: {modality_type}")
|
| 1101 |
|
| 1102 |
+
print(f"完整camera序列shape: {camera_embedding_full.shape}")
|
| 1103 |
|
| 1104 |
+
# 10. 为Camera CFG创建无条件的camera embedding
|
| 1105 |
if use_camera_cfg:
|
| 1106 |
camera_embedding_uncond = torch.zeros_like(camera_embedding_full)
|
| 1107 |
+
print(f"创建无条件camera embedding用于CFG")
|
| 1108 |
|
| 1109 |
+
# 11. 滑动窗口生成循环
|
| 1110 |
total_generated = 0
|
| 1111 |
all_generated_frames = []
|
| 1112 |
|
| 1113 |
while total_generated < total_frames_to_generate:
|
| 1114 |
current_generation = min(frames_per_generation, total_frames_to_generate - total_generated)
|
| 1115 |
+
print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}")
|
| 1116 |
+
print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}")
|
| 1117 |
|
| 1118 |
+
# FramePack数据准备 - MoE版本
|
| 1119 |
framepack_data = prepare_framepack_sliding_window_with_camera_moe(
|
| 1120 |
history_latents,
|
| 1121 |
current_generation,
|
|
|
|
| 1125 |
max_history_frames
|
| 1126 |
)
|
| 1127 |
|
| 1128 |
+
# 准备输入
|
| 1129 |
clean_latents = framepack_data['clean_latents'].unsqueeze(0)
|
| 1130 |
clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0)
|
| 1131 |
clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0)
|
| 1132 |
camera_embedding = framepack_data['camera_embedding'].unsqueeze(0)
|
| 1133 |
|
| 1134 |
+
# 准备modality_inputs
|
| 1135 |
modality_inputs = {modality_type: camera_embedding}
|
| 1136 |
|
| 1137 |
+
# 为CFG准备无条件camera embedding
|
| 1138 |
if use_camera_cfg:
|
| 1139 |
camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0)
|
| 1140 |
modality_inputs_uncond = {modality_type: camera_embedding_uncond_batch}
|
| 1141 |
|
| 1142 |
+
# 索引处理
|
| 1143 |
latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu()
|
| 1144 |
clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu()
|
| 1145 |
clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu()
|
| 1146 |
clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu()
|
| 1147 |
|
| 1148 |
+
# 初始化要生成的latents
|
| 1149 |
new_latents = torch.randn(
|
| 1150 |
1, C, current_generation, H, W,
|
| 1151 |
device=device, dtype=model_dtype
|
|
|
|
| 1154 |
extra_input = pipe.prepare_extra_input(new_latents)
|
| 1155 |
|
| 1156 |
print(f"Camera embedding shape: {camera_embedding.shape}")
|
| 1157 |
+
print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}")
|
| 1158 |
|
| 1159 |
+
# 去噪循环 - 支持CFG
|
| 1160 |
timesteps = pipe.scheduler.timesteps
|
| 1161 |
|
| 1162 |
for i, timestep in enumerate(timesteps):
|
| 1163 |
if i % 10 == 0:
|
| 1164 |
+
print(f" 去噪步骤 {i+1}/{len(timesteps)}")
|
| 1165 |
|
| 1166 |
timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)
|
| 1167 |
|
| 1168 |
with torch.no_grad():
|
| 1169 |
+
# CFG推理
|
| 1170 |
if use_camera_cfg and camera_guidance_scale > 1.0:
|
| 1171 |
+
# 条件预测(有camera)
|
| 1172 |
noise_pred_cond, moe_loess = pipe.dit(
|
| 1173 |
new_latents,
|
| 1174 |
timestep=timestep_tensor,
|
| 1175 |
cam_emb=camera_embedding,
|
| 1176 |
+
modality_inputs=modality_inputs, # MoE模态输入
|
| 1177 |
latent_indices=latent_indices,
|
| 1178 |
clean_latents=clean_latents,
|
| 1179 |
clean_latent_indices=clean_latent_indices,
|
|
|
|
| 1185 |
**extra_input
|
| 1186 |
)
|
| 1187 |
|
| 1188 |
+
# 无条件预测(无camera)
|
| 1189 |
noise_pred_uncond, moe_loess = pipe.dit(
|
| 1190 |
new_latents,
|
| 1191 |
timestep=timestep_tensor,
|
| 1192 |
cam_emb=camera_embedding_uncond_batch,
|
| 1193 |
+
modality_inputs=modality_inputs_uncond, # MoE无条件模态输入
|
| 1194 |
latent_indices=latent_indices,
|
| 1195 |
clean_latents=clean_latents,
|
| 1196 |
clean_latent_indices=clean_latent_indices,
|
|
|
|
| 1205 |
# Camera CFG
|
| 1206 |
noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
| 1207 |
|
| 1208 |
+
# 如果同时使用Text CFG
|
| 1209 |
if text_guidance_scale > 1.0 and prompt_emb_neg:
|
| 1210 |
noise_pred_text_uncond, moe_loess = pipe.dit(
|
| 1211 |
new_latents,
|
|
|
|
| 1223 |
**extra_input
|
| 1224 |
)
|
| 1225 |
|
| 1226 |
+
# 应用Text CFG到已经应用Camera CFG的结果
|
| 1227 |
noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond)
|
| 1228 |
|
| 1229 |
elif text_guidance_scale > 1.0 and prompt_emb_neg:
|
| 1230 |
+
# 只使用Text CFG
|
| 1231 |
noise_pred_cond, moe_loess = pipe.dit(
|
| 1232 |
new_latents,
|
| 1233 |
timestep=timestep_tensor,
|
|
|
|
| 1263 |
noise_pred = noise_pred_uncond + text_guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
| 1264 |
|
| 1265 |
else:
|
| 1266 |
+
# 标准推理(无CFG)
|
| 1267 |
noise_pred, moe_loess = pipe.dit(
|
| 1268 |
new_latents,
|
| 1269 |
timestep=timestep_tensor,
|
| 1270 |
cam_emb=camera_embedding,
|
| 1271 |
+
modality_inputs=modality_inputs, # MoE模态输入
|
| 1272 |
latent_indices=latent_indices,
|
| 1273 |
clean_latents=clean_latents,
|
| 1274 |
clean_latent_indices=clean_latent_indices,
|
|
|
|
| 1282 |
|
| 1283 |
new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents)
|
| 1284 |
|
| 1285 |
+
# 更新历史
|
| 1286 |
new_latents_squeezed = new_latents.squeeze(0)
|
| 1287 |
history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1)
|
| 1288 |
|
| 1289 |
+
# 维护滑动窗口
|
| 1290 |
if history_latents.shape[1] > max_history_frames:
|
| 1291 |
first_frame = history_latents[:, 0:1, :, :]
|
| 1292 |
recent_frames = history_latents[:, -(max_history_frames-1):, :, :]
|
| 1293 |
history_latents = torch.cat([first_frame, recent_frames], dim=1)
|
| 1294 |
+
print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧")
|
| 1295 |
|
| 1296 |
+
print(f"更新后history_latents shape: {history_latents.shape}")
|
| 1297 |
|
| 1298 |
all_generated_frames.append(new_latents_squeezed)
|
| 1299 |
total_generated += current_generation
|
| 1300 |
|
| 1301 |
+
print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧")
|
| 1302 |
|
| 1303 |
+
# 12. 解码和保存
|
| 1304 |
+
print("\n🔧 解码生成的视频...")
|
| 1305 |
|
| 1306 |
all_generated = torch.cat(all_generated_frames, dim=1)
|
| 1307 |
final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0)
|
| 1308 |
|
| 1309 |
+
print(f"最终视频shape: {final_video.shape}")
|
| 1310 |
|
| 1311 |
decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))
|
| 1312 |
|
|
|
|
| 1319 |
icons = {}
|
| 1320 |
video_camera_poses = None
|
| 1321 |
if add_icons:
|
| 1322 |
+
# 加载用于叠加的图标资源
|
| 1323 |
icons_dir = os.path.join(ROOT_DIR, 'icons')
|
| 1324 |
icon_names = ['move_forward.png', 'not_move_forward.png',
|
| 1325 |
'move_backward.png', 'not_move_backward.png',
|
|
|
|
| 1334 |
if os.path.exists(path):
|
| 1335 |
try:
|
| 1336 |
icon = Image.open(path).convert("RGBA")
|
| 1337 |
+
# 调整图标尺寸
|
| 1338 |
icon = icon.resize((50, 50), Image.Resampling.LANCZOS)
|
| 1339 |
icons[name] = icon
|
| 1340 |
except Exception as e:
|
| 1341 |
print(f"Error loading icon {name}: {e}")
|
| 1342 |
else:
|
| 1343 |
+
print(f"Warning: Icon {name} not found at {path}")
|
| 1344 |
|
| 1345 |
+
# 获取与视频帧对应的相机姿态
|
| 1346 |
time_compression_ratio = 4
|
| 1347 |
camera_poses = camera_embedding_full.detach().float().cpu().numpy()
|
| 1348 |
video_camera_poses = [x for x in camera_poses for _ in range(time_compression_ratio)]
|
|
|
|
| 1361 |
|
| 1362 |
writer.append_data(np.array(img))
|
| 1363 |
|
| 1364 |
+
print(f"🔧 MoE FramePack滑动窗口生成完成! 保存到: {output_path}")
|
| 1365 |
+
print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧")
|
| 1366 |
+
print(f"使用模态: {modality_type}")
|
| 1367 |
|
| 1368 |
|
| 1369 |
def main():
|
| 1370 |
+
parser = argparse.ArgumentParser(description="MoE FramePack滑动窗口视频生成 - 支持多模态")
|
| 1371 |
+
|
| 1372 |
+
# 基础参数
|
| 1373 |
+
parser.add_argument("--condition_pth", type=str,
|
| 1374 |
+
default="../examples/condition_pth/garden_1.pth")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1375 |
parser.add_argument("--start_frame", type=int, default=0)
|
| 1376 |
parser.add_argument("--initial_condition_frames", type=int, default=1)
|
| 1377 |
parser.add_argument("--frames_per_generation", type=int, default=8)
|
| 1378 |
parser.add_argument("--total_frames_to_generate", type=int, default=24)
|
| 1379 |
parser.add_argument("--max_history_frames", type=int, default=100)
|
| 1380 |
parser.add_argument("--use_real_poses", default=False)
|
| 1381 |
+
parser.add_argument("--dit_path", type=str, default=None, required=True,
|
|
|
|
| 1382 |
help="path to the pretrained DiT MoE model checkpoint")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1383 |
parser.add_argument("--output_path", type=str,
|
| 1384 |
+
default='./examples/output_videos/output_moe_framepack_sliding.mp4')
|
| 1385 |
+
parser.add_argument("--prompt", type=str, default=None,
|
|
|
|
|
|
|
| 1386 |
help="text prompt for video generation")
|
| 1387 |
parser.add_argument("--device", type=str, default="cuda")
|
| 1388 |
parser.add_argument("--add_icons", action="store_true", default=False,
|
| 1389 |
+
help="在生成的视频上叠加控制图标")
|
| 1390 |
|
| 1391 |
+
# 模态类型参数
|
| 1392 |
parser.add_argument("--modality_type", type=str, choices=["sekai", "nuscenes", "openx"],
|
| 1393 |
+
default="sekai", help="模态类型:sekai 或 nuscenes 或 openx")
|
| 1394 |
parser.add_argument("--scene_info_path", type=str, default=None,
|
| 1395 |
+
help="NuScenes场景信息文件路径(仅用于nuscenes模态)")
|
| 1396 |
|
| 1397 |
+
# CFG参数
|
| 1398 |
parser.add_argument("--use_camera_cfg", default=False,
|
| 1399 |
+
help="使用Camera CFG")
|
| 1400 |
parser.add_argument("--camera_guidance_scale", type=float, default=2.0,
|
| 1401 |
help="Camera guidance scale for CFG")
|
| 1402 |
parser.add_argument("--text_guidance_scale", type=float, default=1.0,
|
| 1403 |
help="Text guidance scale for CFG")
|
| 1404 |
|
| 1405 |
+
# MoE参数
|
| 1406 |
+
parser.add_argument("--moe_num_experts", type=int, default=3, help="专家数量")
|
| 1407 |
+
parser.add_argument("--moe_top_k", type=int, default=1, help="Top-K专家")
|
| 1408 |
+
parser.add_argument("--moe_hidden_dim", type=int, default=None, help="MoE隐藏层维度")
|
| 1409 |
+
parser.add_argument("--direction", type=str, default="left", help="生成视频的行进轨迹方向")
|
| 1410 |
parser.add_argument("--use_gt_prompt", action="store_true", default=False,
|
| 1411 |
+
help="使用数据集中的ground truth prompt embedding")
|
| 1412 |
|
| 1413 |
args = parser.parse_args()
|
| 1414 |
|
| 1415 |
+
print(f"🔧 MoE FramePack CFG生成设置:")
|
| 1416 |
+
print(f"模态类型: {args.modality_type}")
|
| 1417 |
print(f"Camera CFG: {args.use_camera_cfg}")
|
| 1418 |
if args.use_camera_cfg:
|
| 1419 |
print(f"Camera guidance scale: {args.camera_guidance_scale}")
|
| 1420 |
+
print(f"使用GT Prompt: {args.use_gt_prompt}")
|
| 1421 |
print(f"Text guidance scale: {args.text_guidance_scale}")
|
| 1422 |
+
print(f"MoE配置: experts={args.moe_num_experts}, top_k={args.moe_top_k}")
|
| 1423 |
print(f"DiT{args.dit_path}")
|
| 1424 |
|
| 1425 |
+
# 验证NuScenes参数
|
| 1426 |
if args.modality_type == "nuscenes" and not args.scene_info_path:
|
| 1427 |
+
print("⚠️ 使用NuScenes模态但未提供scene_info_path,将使用合成pose数据")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1428 |
|
| 1429 |
inference_moe_framepack_sliding_window(
|
| 1430 |
condition_pth_path=args.condition_pth,
|
|
|
|
|
|
|
| 1431 |
dit_path=args.dit_path,
|
|
|
|
| 1432 |
output_path=args.output_path,
|
| 1433 |
start_frame=args.start_frame,
|
| 1434 |
initial_condition_frames=args.initial_condition_frames,
|
|
|
|
| 1440 |
modality_type=args.modality_type,
|
| 1441 |
use_real_poses=args.use_real_poses,
|
| 1442 |
scene_info_path=args.scene_info_path,
|
| 1443 |
+
# CFG参数
|
| 1444 |
use_camera_cfg=args.use_camera_cfg,
|
| 1445 |
camera_guidance_scale=args.camera_guidance_scale,
|
| 1446 |
text_guidance_scale=args.text_guidance_scale,
|
| 1447 |
+
# MoE参数
|
| 1448 |
moe_num_experts=args.moe_num_experts,
|
| 1449 |
moe_top_k=args.moe_top_k,
|
| 1450 |
moe_hidden_dim=args.moe_hidden_dim,
|
scripts/infer_moe.py
ADDED
|
@@ -0,0 +1,1023 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import numpy as np
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import imageio
|
| 7 |
+
import json
|
| 8 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 9 |
+
import argparse
|
| 10 |
+
from torchvision.transforms import v2
|
| 11 |
+
from einops import rearrange
|
| 12 |
+
import copy
|
| 13 |
+
from scipy.spatial.transform import Rotation as R
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def compute_relative_pose_matrix(pose1, pose2):
    """Compute the relative pose between two adjacent frames as a 3x4
    camera matrix [R_rel | t_rel].

    Args:
        pose1: camera pose of frame i, shape (7,): [tx, ty, tz, qx, qy, qz, qw].
        pose2: camera pose of frame i+1, same layout.

    Returns:
        A 3x4 numpy array whose first three columns hold the relative
        rotation matrix and whose last column holds the relative
        translation vector.
    """
    # Split each pose into translation (first 3) and quaternion (last 4).
    trans_a, quat_a = pose1[:3], pose1[3:]
    trans_b, quat_b = pose2[:3], pose2[3:]

    rot_a = R.from_quat(quat_a)
    rot_b = R.from_quat(quat_b)

    # Relative rotation: the later frame's rotation composed with the
    # inverse of the earlier frame's rotation.
    rel_rot_mat = (rot_b * rot_a.inv()).as_matrix()

    # Relative translation expressed in frame i's coordinates:
    # t_rel = R1^T @ (t2 - t1).
    rel_trans = rot_a.as_matrix().T @ (trans_b - trans_a)

    # Assemble the 3x4 [R_rel | t_rel] matrix.
    return np.concatenate([rel_rot_mat, rel_trans.reshape(3, 1)], axis=1)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def calculate_relative_rotation(current_rotation, reference_rotation):
    """Return the relative rotation quaternion (NuScenes-specific).

    Quaternions are laid out as [w, x, y, z]. The result is the Hamilton
    product conj(reference) * current, i.e. the rotation that carries the
    reference orientation to the current one (assuming unit quaternions,
    where the conjugate equals the inverse).
    """
    q_cur = torch.tensor(current_rotation, dtype=torch.float32)
    q_ref = torch.tensor(reference_rotation, dtype=torch.float32)

    # Conjugate of the reference quaternion.
    w1 = q_ref[0]
    x1, y1, z1 = -q_ref[1], -q_ref[2], -q_ref[3]
    w2, x2, y2, z2 = q_cur

    # Hamilton product conj(q_ref) * q_cur, written out componentwise.
    return torch.tensor([
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
    ])
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Load pre-encoded video latents from a .pth file and slice out a clip.

    Args:
        pth_path: path to a torch checkpoint whose 'latents' entry is a
            tensor of shape [C, T, H, W] (other auxiliary entries may exist).
        start_frame: index of the first latent frame to extract.
        num_frames: number of latent frames to extract.

    Returns:
        (condition_latents, encoded_data): the [C, num_frames, H, W] slice
        and the full loaded dictionary.

    Raises:
        ValueError: if the requested range exceeds the available frames.
    """
    print(f"Loading encoded video from {pth_path}")

    encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu")
    full_latents = encoded_data['latents']  # [C, T, H, W]

    print(f"Full latents shape: {full_latents.shape}")
    print(f"Extracting frames {start_frame} to {start_frame + num_frames}")

    end_frame = start_frame + num_frames
    if end_frame > full_latents.shape[1]:
        raise ValueError(f"Not enough frames: requested {start_frame + num_frames}, available {full_latents.shape[1]}")

    condition_latents = full_latents[:, start_frame:end_frame, :, :]
    print(f"Extracted condition latents shape: {condition_latents.shape}")

    return condition_latents, encoded_data
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Compute camera B's pose relative to camera A.

    Both inputs are 4x4 homogeneous extrinsic matrices. The result is
    ``pose_b @ inverse(pose_a)``, evaluated with torch or numpy depending
    on ``use_torch``; inputs of the other array family are converted first.
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        # Promote numpy inputs to float tensors.
        if not isinstance(pose_a, torch.Tensor):
            pose_a = torch.from_numpy(pose_a).float()
        if not isinstance(pose_b, torch.Tensor):
            pose_b = torch.from_numpy(pose_b).float()
        return torch.matmul(pose_b, torch.inverse(pose_a))

    # numpy path: coerce anything array-like to float32 ndarrays.
    if not isinstance(pose_a, np.ndarray):
        pose_a = np.array(pose_a, dtype=np.float32)
    if not isinstance(pose_b, np.ndarray):
        pose_b = np.array(pose_b, dtype=np.float32)
    return np.matmul(pose_b, np.linalg.inv(pose_a))
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def replace_dit_model_in_manager():
    """Swap the registered DiT model class for the MoE variant.

    Walks diffsynth's global ``model_loader_configs`` registry in place
    and, in every entry whose model-name list contains ``wan_video_dit``,
    replaces that name's model class with ``WanModelMoe`` while leaving
    every other name/class pair untouched.
    """
    from diffsynth.models.wan_video_dit_moe import WanModelMoe
    from diffsynth.configs.model_config import model_loader_configs

    for idx, entry in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, names, classes, resource = entry

        # Only entries that register the plain DiT need patching.
        if 'wan_video_dit' not in names:
            continue

        patched_names, patched_classes = [], []
        for model_name, model_cls in zip(names, classes):
            if model_name == 'wan_video_dit':
                patched_names.append(model_name)
                patched_classes.append(WanModelMoe)
                print(f"✅ 替换了模型类: {model_name} -> WanModelMoe")
            else:
                patched_names.append(model_name)
                patched_classes.append(model_cls)

        # Write the rebuilt tuple back into the shared registry.
        model_loader_configs[idx] = (keys_hash, keys_hash_with_shape, patched_names, patched_classes, resource)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def add_framepack_components(dit_model):
    """Attach the FramePack ``clean_x_embedder`` to the DiT model (idempotent).

    The embedder bundles three Conv3d projections that patchify clean
    latent frames at 1x/2x/4x temporal-spatial compression into the
    transformer's inner dimension. Does nothing if the model already
    carries an embedder.
    """
    if hasattr(dit_model, 'clean_x_embedder'):
        return

    # Transformer inner dimension, read off the first block's Q projection.
    inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]

    class CleanXEmbedder(nn.Module):
        def __init__(self, inner_dim):
            super().__init__()
            self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
            self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
            self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

        def forward(self, x, scale="1x"):
            # Dispatch by compression level; input is cast to the chosen
            # convolution's weight dtype before projecting.
            layers = {"1x": self.proj, "2x": self.proj_2x, "4x": self.proj_4x}
            if scale not in layers:
                raise ValueError(f"Unsupported scale: {scale}")
            conv = layers[scale]
            return conv(x.to(conv.weight.dtype))

    embedder = CleanXEmbedder(inner_dim)
    # Align the new embedder's dtype with the rest of the model parameters.
    target_dtype = next(dit_model.parameters()).dtype
    dit_model.clean_x_embedder = embedder.to(dtype=target_dtype)
    print("✅ 添加了FramePack的clean_x_embedder组件")
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def add_moe_components(dit_model, moe_config):
    """Attach MoE components: modality processors, router, per-block experts.

    Stores ``moe_config`` on the model (first call only), registers one
    ModalityProcessor per supported modality plus a global router, and
    gives every transformer block its own MultiModalMoE network.
    """
    if not hasattr(dit_model, 'moe_config'):
        dit_model.moe_config = moe_config
        print("✅ 添加了MoE配置到模型")
    # NOTE(review): original indentation was ambiguous in the diff view —
    # this line is assumed to run unconditionally; on a freshly built model
    # (no moe_config yet) both placements behave identically.
    dit_model.top_k = moe_config.get("top_k", 1)

    # Transformer inner dimension, read off the first block's Q projection.
    dim = dit_model.blocks[0].self_attn.q.weight.shape[0]
    unified_dim = moe_config.get("unified_dim", 25)
    num_experts = moe_config.get("num_experts", 4)

    from diffsynth.models.wan_video_dit_moe import ModalityProcessor, MultiModalMoE

    # One processor per modality maps raw pose inputs into the unified space:
    # Sekai/OpenX take 13-d input (12-d pose + mask), NuScenes 8-d (7 + mask).
    dit_model.sekai_processor = ModalityProcessor("sekai", 13, unified_dim)
    dit_model.nuscenes_processor = ModalityProcessor("nuscenes", 8, unified_dim)
    dit_model.openx_processor = ModalityProcessor("openx", 13, unified_dim)
    dit_model.global_router = nn.Linear(unified_dim, num_experts)

    for idx, block in enumerate(dit_model.blocks):
        # Per-block expert mixture projecting unified_dim features into the
        # transformer's inner dimension.
        block.moe = MultiModalMoE(
            unified_dim=unified_dim,
            output_dim=dim,
            num_experts=moe_config.get("num_experts", 4),
            top_k=moe_config.get("top_k", 2)
        )
        print(f"✅ Block {idx} 添加了MoE组件 (unified_dim: {unified_dim}, experts: {moe_config.get('num_experts', 4)})")
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def generate_sekai_camera_embeddings_sliding(cam_data, start_frame, current_history_length, new_frames, total_generated, use_real_poses=True):
    """Build the camera-embedding sequence for Sekai sliding-window inference.

    Returns a [N, 13] bfloat16 tensor: per latent frame a flattened 3x4
    relative camera pose (12 values) plus one trailing mask column
    (1.0 = condition/history frame, 0.0 = frame to be generated).
    """
    time_compression_ratio = 4

    # Camera frames FramePack consumes: start(1) + 4x(16) + 2x(2) + 1x(1) + new frames.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    if use_real_poses and cam_data is not None and 'extrinsic' in cam_data:
        print("🔧 使用真实Sekai camera数据")
        cam_extrinsic = cam_data['extrinsic']

        # Long enough for the base window, FramePack's layout, and a floor of 30.
        max_needed_frames = max(
            start_frame + current_history_length + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 计算Sekai camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        relative_poses = []
        for frame in range(max_needed_frames):
            # One latent frame spans `time_compression_ratio` raw video frames.
            src_idx = frame * time_compression_ratio
            dst_idx = src_idx + time_compression_ratio
            if dst_idx < len(cam_extrinsic):
                rel = compute_relative_pose(cam_extrinsic[src_idx], cam_extrinsic[dst_idx])
                relative_poses.append(torch.as_tensor(rel[:3, :]))
            else:
                # Past the recorded range: fall back to zero motion.
                print(f"⚠️ 帧{src_idx}超出camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))

        # [N, 3, 4] -> [N, 12] (equivalent to rearrange 'b c d -> b (c d)').
        pose_embedding = torch.stack(relative_poses, dim=0).reshape(max_needed_frames, -1)

        # Mark [start_frame, start_frame + history) as condition frames.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        mask[start_frame:min(start_frame + current_history_length, max_needed_frames)] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 Sekai真实camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    print("🔧 使用Sekai合成camera数据")

    max_needed_frames = max(
        start_frame + current_history_length + new_frames,
        framepack_needed_frames,
        30
    )
    print(f"🔧 生成Sekai合成camera帧数: {max_needed_frames}")

    # Synthetic motion: a constant per-frame left turn with slight forward
    # and centripetal drift — approximates a circular walking trajectory.
    yaw_per_frame = -0.1
    forward_speed = 0.005
    radius_drift = 0.002

    pose = np.eye(4, dtype=np.float32)
    cos_yaw = np.cos(yaw_per_frame)
    sin_yaw = np.sin(yaw_per_frame)
    # Rotation about the Y axis.
    pose[0, 0] = cos_yaw
    pose[0, 2] = sin_yaw
    pose[2, 0] = -sin_yaw
    pose[2, 2] = cos_yaw
    pose[2, 3] = -forward_speed  # advance along local -Z (forward)
    pose[0, 3] = radius_drift    # slight drift toward the circle centre

    # Identical relative motion for every frame.
    relative_poses = [torch.as_tensor(pose[:3, :]) for _ in range(max_needed_frames)]
    pose_embedding = torch.stack(relative_poses, dim=0).reshape(max_needed_frames, -1)

    mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
    mask[start_frame:min(start_frame + current_history_length, max_needed_frames)] = 1.0

    camera_embedding = torch.cat([pose_embedding, mask], dim=1)
    print(f"🔧 Sekai合成camera embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
|
| 294 |
+
|
| 295 |
+
def generate_openx_camera_embeddings_sliding(encoded_data, start_frame, current_history_length, new_frames, use_real_poses):
    """Build the camera-embedding sequence for OpenX sliding-window inference.

    Returns a [N, 13] bfloat16 tensor: per latent frame a flattened 3x4
    relative pose (12 values) plus one trailing condition-mask column.
    """
    time_compression_ratio = 4

    # Camera frames FramePack consumes: start(1) + 4x(16) + 2x(2) + 1x(1) + new frames.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    have_real = (
        use_real_poses
        and encoded_data is not None
        and 'cam_emb' in encoded_data
        and 'extrinsic' in encoded_data['cam_emb']
    )

    if have_real:
        print("🔧 使用OpenX真实camera数据")
        cam_extrinsic = encoded_data['cam_emb']['extrinsic']

        # Long enough for the base window, FramePack's layout, and a floor of 30.
        max_needed_frames = max(
            start_frame + current_history_length + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 计算OpenX camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        relative_poses = []
        for frame in range(max_needed_frames):
            # 4x temporal stride, mirroring the Sekai handling but for the
            # (typically shorter) OpenX episodes.
            src_idx = frame * time_compression_ratio
            dst_idx = src_idx + time_compression_ratio
            if dst_idx < len(cam_extrinsic):
                rel = compute_relative_pose(cam_extrinsic[src_idx], cam_extrinsic[dst_idx])
                relative_poses.append(torch.as_tensor(rel[:3, :]))
            else:
                # Past the recorded range: fall back to zero motion.
                print(f"⚠️ 帧{src_idx}超出OpenX camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))

        # [N, 3, 4] -> [N, 12] (equivalent to rearrange 'b c d -> b (c d)').
        pose_embedding = torch.stack(relative_poses, dim=0).reshape(max_needed_frames, -1)

        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        mask[start_frame:min(start_frame + current_history_length, max_needed_frames)] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 OpenX真实camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    print("🔧 使用OpenX合成camera数据")

    max_needed_frames = max(
        start_frame + current_history_length + new_frames,
        framepack_needed_frames,
        30
    )
    print(f"🔧 生成OpenX合成camera帧数: {max_needed_frames}")

    # Synthetic robot-manipulation motion: a small compound roll/pitch/yaw
    # with slow translation, approximating fine robot-arm camera movement.
    roll_per_frame = 0.02
    pitch_per_frame = 0.01
    yaw_per_frame = 0.015
    forward_speed = 0.003

    cos_roll = np.cos(roll_per_frame)
    sin_roll = np.sin(roll_per_frame)
    cos_pitch = np.cos(pitch_per_frame)
    sin_pitch = np.sin(pitch_per_frame)
    cos_yaw = np.cos(yaw_per_frame)
    sin_yaw = np.sin(yaw_per_frame)

    pose = np.eye(4, dtype=np.float32)
    # Composite rotation matrix (ZYX order).
    pose[0, 0] = cos_yaw * cos_pitch
    pose[0, 1] = cos_yaw * sin_pitch * sin_roll - sin_yaw * cos_roll
    pose[0, 2] = cos_yaw * sin_pitch * cos_roll + sin_yaw * sin_roll
    pose[1, 0] = sin_yaw * cos_pitch
    pose[1, 1] = sin_yaw * sin_pitch * sin_roll + cos_yaw * cos_roll
    pose[1, 2] = sin_yaw * sin_pitch * cos_roll - cos_yaw * sin_roll
    pose[2, 0] = -sin_pitch
    pose[2, 1] = cos_pitch * sin_roll
    pose[2, 2] = cos_pitch * cos_roll
    # Translation: mostly along depth, with small lateral components.
    pose[0, 3] = forward_speed * 0.5
    pose[1, 3] = forward_speed * 0.3
    pose[2, 3] = -forward_speed

    # Identical relative motion for every frame.
    relative_poses = [torch.as_tensor(pose[:3, :]) for _ in range(max_needed_frames)]
    pose_embedding = torch.stack(relative_poses, dim=0).reshape(max_needed_frames, -1)

    mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
    mask[start_frame:min(start_frame + current_history_length, max_needed_frames)] = 1.0

    camera_embedding = torch.cat([pose_embedding, mask], dim=1)
    print(f"🔧 OpenX合成camera embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
def generate_nuscenes_camera_embeddings_sliding(scene_info, start_frame, current_history_length, new_frames):
    """Build the camera-embedding sequence for NuScenes sliding-window inference.

    Returns a [N, 8] bfloat16 tensor per latent frame: 3-D relative
    translation, 4-D relative rotation quaternion (wxyz) and one trailing
    condition-mask column. Matches the pose construction in train_moe.py.
    """
    time_compression_ratio = 4
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames
    max_needed_frames = max(framepack_needed_frames, 30)

    if scene_info is not None and 'keyframe_poses' in scene_info:
        print("🔧 使用NuScenes真实pose数据")
        keyframe_poses = scene_info['keyframe_poses']

        # One keyframe index per latent frame, plus one extra for the pairwise
        # differences; clamp to the recorded range.
        keyframe_indices = [
            min((start_frame + i) * time_compression_ratio, len(keyframe_poses) - 1)
            for i in range(max_needed_frames + 1)
        ]

        pose_vecs = []
        for i in range(max_needed_frames):
            prev = keyframe_poses[keyframe_indices[i]]
            nxt = keyframe_poses[keyframe_indices[i + 1]]
            # Relative displacement between consecutive keyframes.
            translation = torch.tensor(
                np.array(nxt['translation']) - np.array(prev['translation']),
                dtype=torch.float32
            )
            # Relative rotation.
            relative_rotation = calculate_relative_rotation(
                nxt['rotation'],
                prev['rotation']
            )
            pose_vecs.append(torch.cat([translation, relative_rotation], dim=0))  # [7]

        pose_sequence = torch.stack(pose_vecs, dim=0)  # [max_needed_frames, 7]
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        mask[start_frame:min(start_frame + current_history_length, max_needed_frames)] = 1.0
        camera_embedding = torch.cat([pose_sequence, mask], dim=1)
        print(f"🔧 NuScenes真实pose embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    print("🔧 使用NuScenes合成pose数据")

    # Synthetic trajectory: drive along a circle of radius 8 with the heading
    # tangent to the circle; one extra sample for the relative differences.
    abs_translations = []
    abs_rotations = []
    for i in range(max_needed_frames + 1):
        angle = -i * 0.12
        radius = 8.0
        abs_translations.append(np.array(
            [radius * np.sin(angle), 0.0, radius * (1 - np.cos(angle))],
            dtype=np.float32
        ))
        yaw = angle + np.pi / 2
        abs_rotations.append(np.array(
            [np.cos(yaw / 2), 0.0, 0.0, np.sin(yaw / 2)],
            dtype=np.float32
        ))

    # Per-frame motion relative to the previous frame.
    pose_vecs = []
    for i in range(max_needed_frames):
        translation = torch.tensor(abs_translations[i + 1] - abs_translations[i], dtype=torch.float32)
        # Relative rotation as a Hamilton quaternion product: conj(q_prev) * q_next.
        q_prev = abs_rotations[i]
        q_next = abs_rotations[i + 1]
        w1, x1, y1, z1 = q_prev[0], -q_prev[1], -q_prev[2], -q_prev[3]
        w2, x2, y2, z2 = q_next
        relative_rotation = torch.tensor([
            w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
            w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
            w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
            w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
        ], dtype=torch.float32)
        pose_vecs.append(torch.cat([translation, relative_rotation], dim=0))  # [7]

    pose_sequence = torch.stack(pose_vecs, dim=0)
    mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
    mask[start_frame:min(start_frame + current_history_length, max_needed_frames)] = 1.0
    camera_embedding = torch.cat([pose_sequence, mask], dim=1)
    print(f"🔧 NuScenes合成相对pose embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
|
| 496 |
+
|
| 497 |
+
def prepare_framepack_sliding_window_with_camera_moe(history_latents, target_frames_to_generate, camera_embedding_full, start_frame, modality_type, max_history_frames=49):
    """Assemble FramePack inputs for one sliding-window step (MoE variant).

    ``history_latents`` is [C, T, H, W]. Returns a dict with the latent/index
    tensors FramePack expects plus the camera slice (with its condition mask
    rewritten to reflect the current history length) and the modality tag.
    """
    C, T, H, W = history_latents.shape

    # Fixed index layout: [start | 16 x 4x-context | 2 x 2x | 1 x 1x | frames to generate].
    total_len = 1 + 16 + 2 + 1 + target_frames_to_generate
    idx_start, idx_4x, idx_2x, idx_1x, latent_indices = torch.arange(0, total_len).split(
        [1, 16, 2, 1, target_frames_to_generate], dim=0)
    clean_latent_indices = torch.cat([idx_start, idx_1x], dim=0)

    # Zero-pad the camera sequence if it is shorter than the index layout.
    missing = total_len - camera_embedding_full.shape[0]
    if missing > 0:
        pad = torch.zeros(missing, camera_embedding_full.shape[1],
                          dtype=camera_embedding_full.dtype, device=camera_embedding_full.device)
        camera_embedding_full = torch.cat([camera_embedding_full, pad], dim=0)

    # Take the prefix that matches the index layout and rebuild its mask.
    combined_camera = camera_embedding_full[:total_len, :].clone()
    combined_camera[:, -1] = 0.0  # default every frame to "target"

    # The last `n_ctx` of the 19 context slots hold real history; mark them
    # as condition frames in the camera mask.
    n_ctx = min(T, 19)
    if T > 0:
        combined_camera[19 - n_ctx:19, -1] = 1.0

    print(f"🔧 MoE Camera mask更新:")
    print(f" - 历史帧数: {T}")
    print(f" - 有效condition帧数: {n_ctx}")
    print(f" - 模态类型: {modality_type}")

    # Right-align the most recent history frames into the 19 context slots.
    context = torch.zeros(C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device)
    if n_ctx > 0:
        context[:, 19 - n_ctx:, :, :] = history_latents[:, -n_ctx:, :, :]

    clean_latents_4x = context[:, 0:16, :, :]
    clean_latents_2x = context[:, 16:18, :, :]
    clean_latents_1x = context[:, 18:19, :, :]

    # The anchor frame is always the very first history frame (or zeros).
    if T > 0:
        start_latent = history_latents[:, 0:1, :, :]
    else:
        start_latent = torch.zeros(C, 1, H, W, dtype=history_latents.dtype, device=history_latents.device)

    clean_latents = torch.cat([start_latent, clean_latents_1x], dim=1)

    return {
        'latent_indices': latent_indices,
        'clean_latents': clean_latents,
        'clean_latents_2x': clean_latents_2x,
        'clean_latents_4x': clean_latents_4x,
        'clean_latent_indices': clean_latent_indices,
        'clean_latent_2x_indices': idx_2x,
        'clean_latent_4x_indices': idx_4x,
        'camera_embedding': combined_camera,
        'modality_type': modality_type,  # modality tag for the MoE router
        'current_length': T,
        'next_length': T + target_frames_to_generate
    }
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
def inference_moe_framepack_sliding_window(
    condition_pth_path,
    dit_path,
    output_path="moe/infer_results/output_moe_framepack_sliding.mp4",
    start_frame=0,
    initial_condition_frames=8,
    frames_per_generation=4,
    total_frames_to_generate=32,
    max_history_frames=49,
    device="cuda",
    prompt="A video of a scene shot using a pedestrian's front camera while walking",
    modality_type="sekai",  # "sekai", "nuscenes" or "openx" (see dispatch below)
    use_real_poses=True,
    scene_info_path=None,  # scene-info JSON (NuScenes modality only)
    # CFG parameters
    use_camera_cfg=True,
    camera_guidance_scale=2.0,
    text_guidance_scale=1.0,
    # MoE parameters
    moe_num_experts=4,
    moe_top_k=2,
    moe_hidden_dim=None
):
    """
    MoE FramePack sliding-window video generation with multi-modality support.

    Loads the Wan2.1 T2V pipeline, grafts on the legacy camera encoder,
    FramePack components and MoE components, restores the fine-tuned DiT
    weights from ``dit_path``, then autoregressively generates
    ``total_frames_to_generate`` latent frames in chunks of
    ``frames_per_generation``, maintaining a sliding history window of at
    most ``max_history_frames`` latent frames. Supports optional camera CFG
    and text CFG during denoising. The final latent video is decoded with
    the VAE and written to ``output_path`` as an mp4.
    """
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    print(f"🔧 MoE FramePack滑动窗口生成开始...")
    print(f"模态类型: {modality_type}")
    print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}")
    print(f"Text guidance scale: {text_guidance_scale}")
    print(f"MoE配置: experts={moe_num_experts}, top_k={moe_top_k}")

    # 1. Model initialization — swap the DiT class inside the model manager
    # before loading weights.
    replace_dit_model_in_manager()

    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # 2. Attach the legacy per-block camera encoder (kept for checkpoint
    # compatibility). Initialized to identity/zero so it is a no-op until
    # trained weights are loaded.
    dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        block.cam_encoder = nn.Linear(13, dim)
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    # 3. FramePack components (clean_x_embedder).
    add_framepack_components(pipe.dit)

    # 4. MoE components (modality processors, router, per-block experts).
    moe_config = {
        "num_experts": moe_num_experts,
        "top_k": moe_top_k,
        "hidden_dim": moe_hidden_dim or dim * 2,
        "sekai_input_dim": 13,  # Sekai: 12-D pose + 1-D mask
        "nuscenes_input_dim": 8,  # NuScenes: 7-D pose + 1-D mask
        "openx_input_dim": 13  # OpenX: 12-D pose + 1-D mask (like sekai)
    }
    add_moe_components(pipe.dit, moe_config)

    # 5. Restore the fine-tuned weights.
    dit_state_dict = torch.load(dit_path, map_location="cpu")
    pipe.dit.load_state_dict(dit_state_dict, strict=False)  # strict=False tolerates the freshly added MoE components
    pipe = pipe.to(device)
    model_dtype = next(pipe.dit.parameters()).dtype

    # Re-cast the embedder in case .to(device) left it in a different dtype.
    if hasattr(pipe.dit, 'clean_x_embedder'):
        pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)

    pipe.scheduler.set_timesteps(50)

    # 6. Load the initial condition latents from the pre-encoded .pth file.
    print("Loading initial condition frames...")
    initial_latents, encoded_data = load_encoded_video_from_pth(
        condition_pth_path,
        start_frame=start_frame,
        num_frames=initial_condition_frames
    )

    # Spatial center-crop to the training latent resolution (60 x 104).
    target_height, target_width = 60, 104
    C, T, H, W = initial_latents.shape

    if H > target_height or W > target_width:
        h_start = (H - target_height) // 2
        w_start = (W - target_width) // 2
        initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width]
        H, W = target_height, target_width

    history_latents = initial_latents.to(device, dtype=model_dtype)

    print(f"初始history_latents shape: {history_latents.shape}")

    # 7. Encode the prompt; also encode an empty prompt when text CFG is on.
    if text_guidance_scale > 1.0:
        prompt_emb_pos = pipe.encode_prompt(prompt)
        prompt_emb_neg = pipe.encode_prompt("")
        print(f"使用Text CFG,guidance scale: {text_guidance_scale}")
    else:
        prompt_emb_pos = pipe.encode_prompt(prompt)
        prompt_emb_neg = None
        print("不使用Text CFG")

    # 8. Scene info (NuScenes only).
    scene_info = None
    if modality_type == "nuscenes" and scene_info_path and os.path.exists(scene_info_path):
        with open(scene_info_path, 'r') as f:
            scene_info = json.load(f)
        print(f"加载NuScenes场景信息: {scene_info_path}")

    # 9. Pre-generate the full camera-embedding sequence for the whole run.
    if modality_type == "sekai":
        camera_embedding_full = generate_sekai_camera_embeddings_sliding(
            encoded_data.get('cam_emb', None),
            0,
            max_history_frames,
            0,
            0,
            use_real_poses=use_real_poses
        ).to(device, dtype=model_dtype)
    elif modality_type == "nuscenes":
        camera_embedding_full = generate_nuscenes_camera_embeddings_sliding(
            scene_info,
            0,
            max_history_frames,
            0
        ).to(device, dtype=model_dtype)
    elif modality_type == "openx":
        camera_embedding_full = generate_openx_camera_embeddings_sliding(
            encoded_data,
            0,
            max_history_frames,
            0,
            use_real_poses=use_real_poses
        ).to(device, dtype=model_dtype)
    else:
        raise ValueError(f"不支持的模态类型: {modality_type}")

    print(f"完整camera序列shape: {camera_embedding_full.shape}")

    # 10. Zeroed camera embedding used as the unconditional input for CFG.
    if use_camera_cfg:
        camera_embedding_uncond = torch.zeros_like(camera_embedding_full)
        print(f"创建无条件camera embedding用于CFG")

    # 11. Sliding-window generation loop.
    total_generated = 0
    all_generated_frames = []

    while total_generated < total_frames_to_generate:
        current_generation = min(frames_per_generation, total_frames_to_generate - total_generated)
        print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}")
        print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}")

        # FramePack data preparation (MoE variant).
        framepack_data = prepare_framepack_sliding_window_with_camera_moe(
            history_latents,
            current_generation,
            camera_embedding_full,
            start_frame,
            modality_type,
            max_history_frames
        )

        # Add batch dimensions.
        clean_latents = framepack_data['clean_latents'].unsqueeze(0)
        clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0)
        clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0)
        camera_embedding = framepack_data['camera_embedding'].unsqueeze(0)

        # Per-modality inputs routed through the MoE processors.
        modality_inputs = {modality_type: camera_embedding}

        # Unconditional camera batch for camera CFG.
        # NOTE(review): camera_embedding is [1, F, D] after unsqueeze, so
        # shape[1] is the frame count F — the slice below takes the first F
        # rows of the (unbatched) zero embedding; verify this matches the
        # frame count prepare_* produced.
        if use_camera_cfg:
            camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0)
            modality_inputs_uncond = {modality_type: camera_embedding_uncond_batch}

        # Index tensors are kept on CPU.
        latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu()
        clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu()
        clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu()
        clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu()

        # Fresh Gaussian noise for the frames to be generated.
        new_latents = torch.randn(
            1, C, current_generation, H, W,
            device=device, dtype=model_dtype
        )

        extra_input = pipe.prepare_extra_input(new_latents)

        print(f"Camera embedding shape: {camera_embedding.shape}")
        print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}")

        # Denoising loop, with optional camera/text CFG.
        timesteps = pipe.scheduler.timesteps

        for i, timestep in enumerate(timesteps):
            if i % 10 == 0:
                print(f" 去噪步骤 {i+1}/{len(timesteps)}")

            timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)

            # NOTE(review): pipe.dit returns (noise prediction, MoE aux value);
            # the second element ("moe_loess", presumably a typo for moe_loss)
            # is discarded everywhere below.
            with torch.no_grad():
                # CFG inference.
                if use_camera_cfg and camera_guidance_scale > 1.0:
                    # Conditional prediction (with camera).
                    noise_pred_cond, moe_loess = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,  # MoE modality inputs
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

                    # Unconditional prediction (zeroed camera).
                    noise_pred_uncond, moe_loess = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding_uncond_batch,
                        modality_inputs=modality_inputs_uncond,  # unconditional MoE inputs
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **(prompt_emb_neg if prompt_emb_neg else prompt_emb_pos),
                        **extra_input
                    )

                    # Camera CFG.
                    noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_cond - noise_pred_uncond)

                    # Optionally stack text CFG on top of camera CFG.
                    if text_guidance_scale > 1.0 and prompt_emb_neg:
                        noise_pred_text_uncond, moe_loess = pipe.dit(
                            new_latents,
                            timestep=timestep_tensor,
                            cam_emb=camera_embedding,
                            modality_inputs=modality_inputs,
                            latent_indices=latent_indices,
                            clean_latents=clean_latents,
                            clean_latent_indices=clean_latent_indices,
                            clean_latents_2x=clean_latents_2x,
                            clean_latent_2x_indices=clean_latent_2x_indices,
                            clean_latents_4x=clean_latents_4x,
                            clean_latent_4x_indices=clean_latent_4x_indices,
                            **prompt_emb_neg,
                            **extra_input
                        )

                        # Apply text CFG to the camera-CFG-combined result.
                        noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond)

                elif text_guidance_scale > 1.0 and prompt_emb_neg:
                    # Text CFG only (camera embedding is the same in both passes).
                    noise_pred_cond, moe_loess = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

                    noise_pred_uncond, moe_loess = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_neg,
                        **extra_input
                    )

                    noise_pred = noise_pred_uncond + text_guidance_scale * (noise_pred_cond - noise_pred_uncond)

                else:
                    # Standard inference (no CFG).
                    noise_pred, moe_loess = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,  # MoE modality inputs
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

            new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents)

        # Append the newly generated frames to the history.
        new_latents_squeezed = new_latents.squeeze(0)
        history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1)

        # Maintain the sliding window: always keep the first (anchor) frame
        # plus the most recent max_history_frames-1 frames.
        if history_latents.shape[1] > max_history_frames:
            first_frame = history_latents[:, 0:1, :, :]
            recent_frames = history_latents[:, -(max_history_frames-1):, :, :]
            history_latents = torch.cat([first_frame, recent_frames], dim=1)
            print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧")

        print(f"更新后history_latents shape: {history_latents.shape}")

        all_generated_frames.append(new_latents_squeezed)
        total_generated += current_generation

        print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧")

    # 12. Decode and save.
    print("\n🔧 解码生成的视频...")

    all_generated = torch.cat(all_generated_frames, dim=1)
    # NOTE(review): only the device is aligned here, not the dtype — if
    # initial_latents' dtype differs from model_dtype this cat could fail;
    # confirm load_encoded_video_from_pth's output dtype.
    final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0)

    print(f"最终视频shape: {final_video.shape}")

    decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))

    print(f"Saving video to {output_path}")

    # [-1, 1] float frames -> uint8 HWC frames.
    video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
    video_np = (video_np * 0.5 + 0.5).clip(0, 1)
    video_np = (video_np * 255).astype(np.uint8)

    with imageio.get_writer(output_path, fps=20) as writer:
        for frame in video_np:
            writer.append_data(frame)

    print(f"🔧 MoE FramePack滑动窗口生成完成! 保存到: {output_path}")
    print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧")
    print(f"使用模态: {modality_type}")
|
| 938 |
+
|
| 939 |
+
|
| 940 |
+
def main():
|
| 941 |
+
parser = argparse.ArgumentParser(description="MoE FramePack滑动窗口视频生成 - 支持多模态")
|
| 942 |
+
|
| 943 |
+
# 基础参数
|
| 944 |
+
parser.add_argument("--condition_pth", type=str,
|
| 945 |
+
default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking/00100100001_0004650_0004950/encoded_video.pth")
|
| 946 |
+
#default="/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic/scenes/scene-0001_CAM_FRONT/encoded_video-480p.pth")
|
| 947 |
+
#default="/share_zhuyixuan05/zhuyixuan05/spatialvid/a9a6d37f-0a6c-548a-a494-7d902469f3f2_0000000_0000300/encoded_video.pth")
|
| 948 |
+
#default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded/episode_000001/encoded_video.pth")
|
| 949 |
+
parser.add_argument("--start_frame", type=int, default=0)
|
| 950 |
+
parser.add_argument("--initial_condition_frames", type=int, default=16)
|
| 951 |
+
parser.add_argument("--frames_per_generation", type=int, default=8)
|
| 952 |
+
parser.add_argument("--total_frames_to_generate", type=int, default=24)
|
| 953 |
+
parser.add_argument("--max_history_frames", type=int, default=100)
|
| 954 |
+
parser.add_argument("--use_real_poses", default=True)
|
| 955 |
+
parser.add_argument("--dit_path", type=str,
|
| 956 |
+
default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step25000_first.ckpt")
|
| 957 |
+
parser.add_argument("--output_path", type=str,
|
| 958 |
+
default='/home/zhuyixuan05/ReCamMaster/moe/infer_results/output_moe_framepack_sliding.mp4')
|
| 959 |
+
parser.add_argument("--prompt", type=str,
|
| 960 |
+
default="A drone flying scene in a game world ")
|
| 961 |
+
parser.add_argument("--device", type=str, default="cuda")
|
| 962 |
+
|
| 963 |
+
# 模态类型参数
|
| 964 |
+
parser.add_argument("--modality_type", type=str, choices=["sekai", "nuscenes", "openx"], default="sekai",
|
| 965 |
+
help="模态类型:sekai 或 nuscenes 或 openx")
|
| 966 |
+
parser.add_argument("--scene_info_path", type=str, default=None,
|
| 967 |
+
help="NuScenes场景信息文件路径(仅用于nuscenes模态)")
|
| 968 |
+
|
| 969 |
+
# CFG参数
|
| 970 |
+
parser.add_argument("--use_camera_cfg", default=False,
|
| 971 |
+
help="使用Camera CFG")
|
| 972 |
+
parser.add_argument("--camera_guidance_scale", type=float, default=2.0,
|
| 973 |
+
help="Camera guidance scale for CFG")
|
| 974 |
+
parser.add_argument("--text_guidance_scale", type=float, default=1.0,
|
| 975 |
+
help="Text guidance scale for CFG")
|
| 976 |
+
|
| 977 |
+
# MoE参数
|
| 978 |
+
parser.add_argument("--moe_num_experts", type=int, default=3, help="专家数量")
|
| 979 |
+
parser.add_argument("--moe_top_k", type=int, default=1, help="Top-K专家")
|
| 980 |
+
parser.add_argument("--moe_hidden_dim", type=int, default=None, help="MoE隐藏层维度")
|
| 981 |
+
|
| 982 |
+
args = parser.parse_args()
|
| 983 |
+
|
| 984 |
+
print(f"🔧 MoE FramePack CFG生成设置:")
|
| 985 |
+
print(f"模态类型: {args.modality_type}")
|
| 986 |
+
print(f"Camera CFG: {args.use_camera_cfg}")
|
| 987 |
+
if args.use_camera_cfg:
|
| 988 |
+
print(f"Camera guidance scale: {args.camera_guidance_scale}")
|
| 989 |
+
print(f"Text guidance scale: {args.text_guidance_scale}")
|
| 990 |
+
print(f"MoE配置: experts={args.moe_num_experts}, top_k={args.moe_top_k}")
|
| 991 |
+
print(f"DiT{args.dit_path}")
|
| 992 |
+
|
| 993 |
+
# 验证NuScenes参数
|
| 994 |
+
if args.modality_type == "nuscenes" and not args.scene_info_path:
|
| 995 |
+
print("⚠️ 使用NuScenes模态但未提供scene_info_path,将使用合成pose数据")
|
| 996 |
+
|
| 997 |
+
inference_moe_framepack_sliding_window(
|
| 998 |
+
condition_pth_path=args.condition_pth,
|
| 999 |
+
dit_path=args.dit_path,
|
| 1000 |
+
output_path=args.output_path,
|
| 1001 |
+
start_frame=args.start_frame,
|
| 1002 |
+
initial_condition_frames=args.initial_condition_frames,
|
| 1003 |
+
frames_per_generation=args.frames_per_generation,
|
| 1004 |
+
total_frames_to_generate=args.total_frames_to_generate,
|
| 1005 |
+
max_history_frames=args.max_history_frames,
|
| 1006 |
+
device=args.device,
|
| 1007 |
+
prompt=args.prompt,
|
| 1008 |
+
modality_type=args.modality_type,
|
| 1009 |
+
use_real_poses=args.use_real_poses,
|
| 1010 |
+
scene_info_path=args.scene_info_path,
|
| 1011 |
+
# CFG参数
|
| 1012 |
+
use_camera_cfg=args.use_camera_cfg,
|
| 1013 |
+
camera_guidance_scale=args.camera_guidance_scale,
|
| 1014 |
+
text_guidance_scale=args.text_guidance_scale,
|
| 1015 |
+
# MoE参数
|
| 1016 |
+
moe_num_experts=args.moe_num_experts,
|
| 1017 |
+
moe_top_k=args.moe_top_k,
|
| 1018 |
+
moe_hidden_dim=args.moe_hidden_dim
|
| 1019 |
+
)
|
| 1020 |
+
|
| 1021 |
+
|
| 1022 |
+
if __name__ == "__main__":
|
| 1023 |
+
main()
|
scripts/infer_moe_spatialvid.py
ADDED
|
@@ -0,0 +1,1008 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import numpy as np
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import imageio
|
| 7 |
+
import json
|
| 8 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 9 |
+
import argparse
|
| 10 |
+
from torchvision.transforms import v2
|
| 11 |
+
from einops import rearrange
|
| 12 |
+
import copy
|
| 13 |
+
from scipy.spatial.transform import Rotation as R
|
| 14 |
+
|
| 15 |
+
def compute_relative_pose_matrix(pose1, pose2):
    """Compute the relative pose between two adjacent frames as a 3x4 matrix [R_rel | t_rel].

    Args:
        pose1: camera pose of frame i, shape (7,): [tx, ty, tz, qx, qy, qz, qw].
        pose2: camera pose of frame i+1, same layout.

    Returns:
        3x4 numpy array whose first three columns are the relative rotation
        R_rel and whose last column is the relative translation t_rel,
        both expressed in frame i's coordinates.
    """
    # Split each 7-vector into translation and (x, y, z, w) quaternion parts.
    trans_prev, quat_prev = pose1[:3], pose1[3:]
    trans_next, quat_next = pose2[:3], pose2[3:]

    rot_prev = R.from_quat(quat_prev)
    rot_next = R.from_quat(quat_next)

    # Relative rotation: R_rel = R_next · R_prev⁻¹.
    R_rel = (rot_next * rot_prev.inv()).as_matrix()

    # Relative translation in the previous frame: t_rel = R_prev^T · (t_next - t_prev).
    t_rel = rot_prev.as_matrix().T @ (trans_next - trans_prev)

    # Assemble [R_rel | t_rel] as a single 3x4 matrix.
    return np.hstack([R_rel, t_rel.reshape(3, 1)])
|
| 46 |
+
|
| 47 |
+
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Load pre-encoded video latents from a .pth file and slice a condition window.

    Args:
        pth_path: path to a torch-saved dict containing a 'latents' tensor.
        start_frame: first (compressed) frame of the window to extract.
        num_frames: number of frames in the window.

    Returns:
        Tuple of (condition_latents [C, num_frames, H, W], full encoded dict).

    Raises:
        ValueError: if the requested window extends past the stored sequence.
    """
    print(f"Loading encoded video from {pth_path}")

    encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu")
    full_latents = encoded_data['latents']  # layout: [C, T, H, W]

    print(f"Full latents shape: {full_latents.shape}")
    print(f"Extracting frames {start_frame} to {start_frame + num_frames}")

    # Guard: the requested window must fit inside the stored sequence.
    if start_frame + num_frames > full_latents.shape[1]:
        raise ValueError(f"Not enough frames: requested {start_frame + num_frames}, available {full_latents.shape[1]}")

    condition_latents = full_latents[:, start_frame:start_frame + num_frames, :, :]
    print(f"Extracted condition latents shape: {condition_latents.shape}")

    return condition_latents, encoded_data
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Return camera B's pose relative to camera A, i.e. pose_b @ pose_a⁻¹.

    Args:
        pose_a: 4x4 extrinsic matrix of camera A (numpy array or torch tensor).
        pose_b: 4x4 extrinsic matrix of camera B.
        use_torch: compute with torch ops when True, numpy ops otherwise.

    Returns:
        The 4x4 relative pose matrix in the chosen backend's type.
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        # Torch path: promote numpy inputs to float tensors first.
        if not isinstance(pose_a, torch.Tensor):
            pose_a = torch.from_numpy(pose_a).float()
        if not isinstance(pose_b, torch.Tensor):
            pose_b = torch.from_numpy(pose_b).float()
        return torch.matmul(pose_b, torch.inverse(pose_a))

    # NumPy path: coerce arbitrary array-likes to float32 arrays.
    if not isinstance(pose_a, np.ndarray):
        pose_a = np.array(pose_a, dtype=np.float32)
    if not isinstance(pose_b, np.ndarray):
        pose_b = np.array(pose_b, dtype=np.float32)
    return np.matmul(pose_b, np.linalg.inv(pose_a))
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def replace_dit_model_in_manager():
    """Swap the registered 'wan_video_dit' loader class for the MoE variant.

    Mutates diffsynth's global ``model_loader_configs`` in place so that any
    subsequent model loading instantiates ``WanModelMoe`` instead of the
    stock DiT class. Entries that do not mention 'wan_video_dit' are untouched.
    """
    from diffsynth.models.wan_video_dit_moe import WanModelMoe
    from diffsynth.configs.model_config import model_loader_configs

    for idx, entry in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource = entry

        if 'wan_video_dit' not in model_names:
            continue

        patched_names = []
        patched_classes = []
        for name, cls in zip(model_names, model_classes):
            patched_names.append(name)
            if name == 'wan_video_dit':
                patched_classes.append(WanModelMoe)
                print(f"✅ 替换了模型类: {name} -> WanModelMoe")
            else:
                patched_classes.append(cls)

        # Rebuild the tuple with the patched class list.
        model_loader_configs[idx] = (keys_hash, keys_hash_with_shape, patched_names, patched_classes, model_resource)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def add_framepack_components(dit_model):
    """Attach FramePack's multi-scale clean-latent embedder to the DiT model.

    Idempotent: if ``clean_x_embedder`` already exists nothing happens.
    The embedder projects 16-channel clean latents into the transformer's
    inner dimension at 1x, 2x and 4x compression via strided Conv3d layers,
    and is cast to the model's parameter dtype.
    """
    if hasattr(dit_model, 'clean_x_embedder'):
        return

    # Infer the transformer inner dimension from the first attention projection.
    inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]

    class CleanXEmbedder(nn.Module):
        """Projects clean latents at three compression scales."""

        def __init__(self, inner_dim):
            super().__init__()
            self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
            self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
            self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

        def forward(self, x, scale="1x"):
            # Dispatch on scale; cast input to the chosen conv's weight dtype.
            projections = {"1x": self.proj, "2x": self.proj_2x, "4x": self.proj_4x}
            if scale not in projections:
                raise ValueError(f"Unsupported scale: {scale}")
            conv = projections[scale]
            return conv(x.to(conv.weight.dtype))

    # Match the embedder's dtype to the rest of the model's parameters.
    model_dtype = next(dit_model.parameters()).dtype
    dit_model.clean_x_embedder = CleanXEmbedder(inner_dim).to(dtype=model_dtype)
    print("✅ 添加了FramePack的clean_x_embedder组件")
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def add_moe_components(dit_model, moe_config):
    """🔧 Attach MoE routing components to every transformer block (corrected version).

    Stores ``moe_config`` on the model (once) and gives each block a Sekai
    modality processor plus a MultiModalMoE head whose output dimension
    matches the block's inner dimension.
    """
    if not hasattr(dit_model, 'moe_config'):
        dit_model.moe_config = moe_config
        print("✅ 添加了MoE配置到模型")

    from diffsynth.models.wan_video_dit_moe import ModalityProcessor, MultiModalMoE

    # Transformer inner dimension, read from the first attention projection.
    dim = dit_model.blocks[0].self_attn.q.weight.shape[0]
    unified_dim = moe_config.get("unified_dim", 25)

    for i, block in enumerate(dit_model.blocks):
        # Sekai modality processor: raw 13-D camera vectors -> unified space.
        block.sekai_processor = ModalityProcessor("sekai", 13, unified_dim)

        # MoE head: consumes unified_dim features, emits block-dim features.
        block.moe = MultiModalMoE(
            unified_dim=unified_dim,
            output_dim=dim,
            num_experts=moe_config.get("num_experts", 4),
            top_k=moe_config.get("top_k", 2)
        )

        print(f"✅ Block {i} 添加了MoE组件 (unified_dim: {unified_dim}, experts: {moe_config.get('num_experts', 4)})")
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def generate_sekai_camera_embeddings_sliding(cam_data, start_frame, current_history_length, new_frames, total_generated, use_real_poses=True):
    """Generate camera embeddings for the Sekai dataset - sliding-window version.

    Builds a per-frame embedding of flattened 3x4 relative camera poses plus a
    trailing condition/target mask column.

    Args:
        cam_data: dict expected to carry per-frame camera poses under
            'extrinsic' (presumably 7-D [t, quat] vectors consumed by
            compute_relative_pose_matrix — TODO confirm against the encoder).
        start_frame: index of the first condition frame on the compressed timeline.
        current_history_length: number of compressed frames currently in history.
        new_frames: number of new compressed frames to generate this step.
        total_generated: unused here; kept for call-site symmetry with siblings.
        use_real_poses: use real extrinsics when available; otherwise fall back
            to a synthetic left-turning trajectory.

    Returns:
        bfloat16 tensor [max_needed_frames, 13]: 12 flattened pose values + 1 mask.
    """
    # Latents are temporally compressed 4x relative to raw video frames.
    time_compression_ratio = 4

    # Camera frames FramePack actually needs:
    # 1 start + 16 (4x clean) + 2 (2x clean) + 1 (1x clean) + new target frames.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    if use_real_poses and cam_data is not None and 'extrinsic' in cam_data:
        print("🔧 使用真实Sekai camera数据")
        cam_extrinsic = cam_data['extrinsic']

        # Ensure the camera sequence is long enough for every consumer.
        max_needed_frames = max(
            start_frame + current_history_length + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 计算Sekai camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        relative_poses = []
        for i in range(max_needed_frames):
            # Map compressed-frame index i back to raw-video frame indices.
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio

            if next_frame_idx < len(cam_extrinsic):
                cam_prev = cam_extrinsic[frame_idx]
                cam_next = cam_extrinsic[next_frame_idx]
                relative_pose = compute_relative_pose_matrix(cam_prev, cam_next)
                relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
            else:
                # Past the recorded trajectory: fall back to zero motion.
                print(f"⚠️ 帧{frame_idx}超出camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))

        pose_embedding = torch.stack(relative_poses, dim=0)
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')

        # Mask column: 1.0 marks condition frames, 0.0 marks frames to generate.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        # Frames [start_frame, start_frame + current_history_length) are conditions.
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 Sekai真实camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    else:
        print("🔧 使用Sekai合成camera数据")

        max_needed_frames = max(
            start_frame + current_history_length + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 生成Sekai合成camera帧数: {max_needed_frames}")
        relative_poses = []
        for i in range(max_needed_frames):
            # Constant left-turn motion pattern.
            yaw_per_frame = 0.05  # left turn per frame (positive angle = left)
            forward_speed = 0.005  # forward distance per frame

            pose = np.eye(4, dtype=np.float32)

            # Rotation matrix (left turn about the Y axis).
            cos_yaw = np.cos(yaw_per_frame)
            sin_yaw = np.sin(yaw_per_frame)

            pose[0, 0] = cos_yaw
            pose[0, 2] = sin_yaw
            pose[2, 0] = -sin_yaw
            pose[2, 2] = cos_yaw

            # Translation: advance in the rotated local coordinate frame.
            pose[2, 3] = -forward_speed  # local -Z direction (forward)

            # Slight centripetal drift to approximate a circular trajectory.
            radius_drift = 0.002  # small drift toward the circle center
            pose[0, 3] = -radius_drift  # local -X direction (leftward)

            relative_pose = pose[:3, :]
            relative_poses.append(torch.as_tensor(relative_pose))

        pose_embedding = torch.stack(relative_poses, dim=0)
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')

        # Mask column matching the sequence length (see real-pose branch).
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 Sekai合成camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)
|
| 276 |
+
|
| 277 |
+
def generate_openx_camera_embeddings_sliding(encoded_data, start_frame, current_history_length, new_frames, use_real_poses):
    """Generate camera embeddings for the OpenX dataset - sliding-window version.

    Same embedding layout as the Sekai variant (flattened 3x4 relative pose +
    mask column), but reads 4x4 extrinsics via compute_relative_pose and uses
    a robot-arm-like synthetic motion as the fallback.

    Args:
        encoded_data: dict expected to carry encoded_data['cam_emb']['extrinsic']
            with per-frame 4x4 extrinsic matrices — TODO confirm shape upstream.
        start_frame: index of the first condition frame on the compressed timeline.
        current_history_length: number of compressed frames currently in history.
        new_frames: number of new compressed frames to generate this step.
        use_real_poses: use real extrinsics when available; otherwise synthetic.

    Returns:
        bfloat16 tensor [max_needed_frames, 13]: 12 flattened pose values + 1 mask.
    """
    # Latents are temporally compressed 4x relative to raw video frames.
    time_compression_ratio = 4

    # Camera frames FramePack actually needs (see sliding-window index layout).
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    if use_real_poses and encoded_data is not None and 'cam_emb' in encoded_data and 'extrinsic' in encoded_data['cam_emb']:
        print("🔧 使用OpenX真实camera数据")
        cam_extrinsic = encoded_data['cam_emb']['extrinsic']

        # Ensure the camera sequence is long enough for every consumer.
        max_needed_frames = max(
            start_frame + current_history_length + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 计算OpenX camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        relative_poses = []
        for i in range(max_needed_frames):
            # OpenX uses the same 4x stride as sekai but typically shorter sequences.
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio

            if next_frame_idx < len(cam_extrinsic):
                cam_prev = cam_extrinsic[frame_idx]
                cam_next = cam_extrinsic[next_frame_idx]
                relative_pose = compute_relative_pose(cam_prev, cam_next)
                relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
            else:
                # Past the recorded trajectory: fall back to zero motion.
                print(f"⚠️ 帧{frame_idx}超出OpenX camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))

        pose_embedding = torch.stack(relative_poses, dim=0)
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')

        # Mask column: 1.0 marks condition frames, 0.0 marks frames to generate.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        # Frames [start_frame, start_frame + current_history_length) are conditions.
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 OpenX真实camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    else:
        print("🔧 使用OpenX合成camera数据")

        max_needed_frames = max(
            start_frame + current_history_length + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 生成OpenX合成camera帧数: {max_needed_frames}")
        relative_poses = []
        for i in range(max_needed_frames):
            # OpenX robot-manipulation motion pattern - small motion amplitudes
            # mimicking fine robot-arm movements.
            roll_per_frame = 0.02   # slight roll
            pitch_per_frame = 0.01  # slight pitch
            yaw_per_frame = 0.015   # slight yaw
            forward_speed = 0.003   # slow forward speed

            pose = np.eye(4, dtype=np.float32)

            # Composite rotation - models the arm's compound motion.
            # About X (roll):
            cos_roll = np.cos(roll_per_frame)
            sin_roll = np.sin(roll_per_frame)
            # About Y (pitch):
            cos_pitch = np.cos(pitch_per_frame)
            sin_pitch = np.sin(pitch_per_frame)
            # About Z (yaw):
            cos_yaw = np.cos(yaw_per_frame)
            sin_yaw = np.sin(yaw_per_frame)

            # Simplified composite rotation matrix (ZYX order).
            pose[0, 0] = cos_yaw * cos_pitch
            pose[0, 1] = cos_yaw * sin_pitch * sin_roll - sin_yaw * cos_roll
            pose[0, 2] = cos_yaw * sin_pitch * cos_roll + sin_yaw * sin_roll
            pose[1, 0] = sin_yaw * cos_pitch
            pose[1, 1] = sin_yaw * sin_pitch * sin_roll + cos_yaw * cos_roll
            pose[1, 2] = sin_yaw * sin_pitch * cos_roll - cos_yaw * sin_roll
            pose[2, 0] = -sin_pitch
            pose[2, 1] = cos_pitch * sin_roll
            pose[2, 2] = cos_pitch * cos_roll

            # Translation - fine manipulation-like movement.
            pose[0, 3] = forward_speed * 0.5  # slight X movement
            pose[1, 3] = forward_speed * 0.3  # slight Y movement
            pose[2, 3] = -forward_speed       # main movement along Z (depth)

            relative_pose = pose[:3, :]
            relative_poses.append(torch.as_tensor(relative_pose))

        pose_embedding = torch.stack(relative_poses, dim=0)
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')

        # Mask column matching the sequence length (see real-pose branch).
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 OpenX合成camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def generate_nuscenes_camera_embeddings_sliding(scene_info, start_frame, current_history_length, new_frames):
    """Generate camera embeddings for NuScenes - sliding-window version,
    corrected to stay consistent with train_moe.py.

    Unlike the Sekai/OpenX variants this emits absolute 7-D pose vectors
    (3-D translation relative to the first keyframe + 4-D quaternion) rather
    than flattened relative matrices, plus the trailing mask column.

    Args:
        scene_info: dict expected to carry 'keyframe_poses', a list of dicts
            with 'translation' (3-D) and 'rotation' (4-D quaternion) —
            TODO confirm quaternion ordering against the NuScenes exporter.
        start_frame: index of the first condition frame on the compressed timeline.
        current_history_length: number of compressed frames currently in history.
        new_frames: number of new compressed frames to generate this step.

    Returns:
        bfloat16 tensor [max_needed_frames, 8]: 7 pose values + 1 mask.
    """
    # Latents are temporally compressed 4x relative to raw video frames.
    time_compression_ratio = 4

    # Camera frames FramePack actually needs (see sliding-window index layout).
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    if scene_info is not None and 'keyframe_poses' in scene_info:
        print("🔧 使用NuScenes真实pose数据")
        keyframe_poses = scene_info['keyframe_poses']

        if len(keyframe_poses) == 0:
            # Degenerate scene: no keyframes at all — emit all-zero poses.
            print("⚠️ NuScenes keyframe_poses为空,使用零pose")
            max_needed_frames = max(framepack_needed_frames, 30)

            pose_sequence = torch.zeros(max_needed_frames, 7, dtype=torch.float32)

            # Mask column: 1.0 marks condition frames, 0.0 marks frames to generate.
            mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
            condition_end = min(start_frame + current_history_length, max_needed_frames)
            mask[start_frame:condition_end] = 1.0

            camera_embedding = torch.cat([pose_sequence, mask], dim=1)  # [max_needed_frames, 8]
            print(f"🔧 NuScenes零pose embedding shape: {camera_embedding.shape}")
            return camera_embedding.to(torch.bfloat16)

        # Use the first keyframe pose as the reference for relative translation.
        reference_pose = keyframe_poses[0]

        max_needed_frames = max(framepack_needed_frames, 30)

        pose_vecs = []
        for i in range(max_needed_frames):
            if i < len(keyframe_poses):
                current_pose = keyframe_poses[i]

                # Relative displacement w.r.t. the reference keyframe.
                translation = torch.tensor(
                    np.array(current_pose['translation']) - np.array(reference_pose['translation']),
                    dtype=torch.float32
                )

                # Rotation kept absolute (simplified; not made relative).
                rotation = torch.tensor(current_pose['rotation'], dtype=torch.float32)

                pose_vec = torch.cat([translation, rotation], dim=0)  # [7D]
            else:
                # Past the recorded keyframes: identity pose (zero translation,
                # unit quaternion).
                pose_vec = torch.cat([
                    torch.zeros(3, dtype=torch.float32),
                    torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32)
                ], dim=0)  # [7D]

            pose_vecs.append(pose_vec)

        pose_sequence = torch.stack(pose_vecs, dim=0)  # [max_needed_frames, 7]

        # Mask column: 1.0 marks condition frames, 0.0 marks frames to generate.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_sequence, mask], dim=1)  # [max_needed_frames, 8]
        print(f"🔧 NuScenes真实pose embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    else:
        print("🔧 使用NuScenes合成pose数据")
        max_needed_frames = max(framepack_needed_frames, 30)

        # Synthetic motion sequence: constant forward drive, no rotation.
        pose_vecs = []
        for i in range(max_needed_frames):
            translation = torch.tensor([0.0, 0.0, i * 0.1], dtype=torch.float32)  # advance along Z
            rotation = torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32)  # identity quaternion

            pose_vec = torch.cat([translation, rotation], dim=0)  # [7D]
            pose_vecs.append(pose_vec)

        pose_sequence = torch.stack(pose_vecs, dim=0)

        # Mask column matching the sequence length (see real-pose branch).
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_sequence, mask], dim=1)  # [max_needed_frames, 8]
        print(f"🔧 NuScenes合成pose embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)
|
| 482 |
+
|
| 483 |
+
def prepare_framepack_sliding_window_with_camera_moe(history_latents, target_frames_to_generate, camera_embedding_full, start_frame, modality_type, max_history_frames=49):
    """Assemble FramePack sliding-window conditioning inputs (MoE variant).

    Builds the fixed FramePack index layout (1 start slot + 16 coarse 4x +
    2 mid 2x + 1 fine 1x + N target slots), slices/pads the per-frame camera
    embedding to that length, and packs the history latents into the
    multi-scale "clean" latent buffers consumed by the DiT.

    Args:
        history_latents: [C, T, H, W] latents accumulated so far.
        target_frames_to_generate: number of new latent frames this step.
        camera_embedding_full: [N, D] per-frame camera embedding; the last
            column is a condition/target mask (1 = condition, 0 = target).
        start_frame: unused in this function — kept for caller symmetry.
        modality_type: modality tag passed through to the result dict.
        max_history_frames: unused here; the caller maintains the window.

    Returns:
        dict with the clean latent tensors, their index tensors, the sliced
        camera embedding, plus 'modality_type', 'current_length' and
        'next_length' bookkeeping fields.
    """
    # history_latents: [C, T, H, W] — current history latents.
    C, T, H, W = history_latents.shape

    # Fixed index structure (this also fixes how many camera frames are needed):
    # 1 start + 16 coarse (4x) + 2 mid (2x) + 1 fine (1x) + target frames.
    total_indices_length = 1 + 16 + 2 + 1 + target_frames_to_generate
    indices = torch.arange(0, total_indices_length)
    split_sizes = [1, 16, 2, 1, target_frames_to_generate]
    clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = \
        indices.split(split_sizes, dim=0)
    # The "clean" indices pair the start slot with the single fine (1x) slot.
    clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=0)

    # Zero-pad the camera sequence when it is shorter than the index layout.
    if camera_embedding_full.shape[0] < total_indices_length:
        shortage = total_indices_length - camera_embedding_full.shape[0]
        padding = torch.zeros(shortage, camera_embedding_full.shape[1],
                            dtype=camera_embedding_full.dtype, device=camera_embedding_full.device)
        camera_embedding_full = torch.cat([camera_embedding_full, padding], dim=0)

    # Take the leading slice matching the layout (clone: mask edited in place below).
    combined_camera = camera_embedding_full[:total_indices_length, :].clone()

    # Rebuild the condition/target mask from the current history length.
    combined_camera[:, -1] = 0.0  # first mark everything as target (0)

    # The first 19 slots (start + 4x + 2x + 1x) are condition slots; mark as
    # many of them as we have real history for, right-aligned.
    if T > 0:
        available_frames = min(T, 19)
        start_pos = 19 - available_frames
        combined_camera[start_pos:19, -1] = 1.0  # camera rows backing valid clean latents become conditions

    print(f"🔧 MoE Camera mask更新:")
    print(f" - 历史帧数: {T}")
    print(f" - 有效condition帧数: {available_frames if T > 0 else 0}")
    print(f" - 模态类型: {modality_type}")

    # Pack history latents into the 19 clean slots, right-aligned; slots with
    # no history stay zero.
    clean_latents_combined = torch.zeros(C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device)

    if T > 0:
        available_frames = min(T, 19)
        start_pos = 19 - available_frames
        clean_latents_combined[:, start_pos:, :, :] = history_latents[:, -available_frames:, :, :]

    # Split the 19-slot buffer into the three FramePack scales.
    clean_latents_4x = clean_latents_combined[:, 0:16, :, :]
    clean_latents_2x = clean_latents_combined[:, 16:18, :, :]
    clean_latents_1x = clean_latents_combined[:, 18:19, :, :]

    # Anchor frame: the very first history frame, or zeros when no history yet.
    if T > 0:
        start_latent = history_latents[:, 0:1, :, :]
    else:
        start_latent = torch.zeros(C, 1, H, W, dtype=history_latents.dtype, device=history_latents.device)

    clean_latents = torch.cat([start_latent, clean_latents_1x], dim=1)

    return {
        'latent_indices': latent_indices,
        'clean_latents': clean_latents,
        'clean_latents_2x': clean_latents_2x,
        'clean_latents_4x': clean_latents_4x,
        'clean_latent_indices': clean_latent_indices,
        'clean_latent_2x_indices': clean_latent_2x_indices,
        'clean_latent_4x_indices': clean_latent_4x_indices,
        'camera_embedding': combined_camera,
        'modality_type': modality_type,  # modality tag for the MoE router
        'current_length': T,
        'next_length': T + target_frames_to_generate
    }
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
def inference_moe_framepack_sliding_window(
    condition_pth_path,
    dit_path,
    output_path="moe/infer_results/output_moe_framepack_sliding.mp4",
    start_frame=0,
    initial_condition_frames=8,
    frames_per_generation=4,
    total_frames_to_generate=32,
    max_history_frames=49,
    device="cuda",
    prompt="A video of a scene shot using a pedestrian's front camera while walking",
    modality_type="sekai",  # "sekai", "nuscenes" or "openx"
    use_real_poses=True,
    scene_info_path=None,  # only used for the NuScenes modality
    # CFG parameters
    use_camera_cfg=True,
    camera_guidance_scale=2.0,
    text_guidance_scale=1.0,
    # MoE parameters
    moe_num_experts=4,
    moe_top_k=2,
    moe_hidden_dim=None
):
    """Generate a video with the MoE FramePack sliding-window pipeline.

    Loads the Wan2.1 pipeline, patches in FramePack and MoE components,
    then autoregressively generates `total_frames_to_generate` latent frames
    in chunks of `frames_per_generation`, conditioning each chunk on a
    sliding window of history latents and per-frame camera embeddings.
    Supports optional camera CFG and text CFG, and finally decodes and
    writes the result to `output_path` as an mp4.

    NOTE(review): relies on several helpers defined elsewhere in this file
    (`add_moe_components`, `generate_*_camera_embeddings_sliding`,
    `add_framepack_components`) whose exact contracts are not visible here.
    """
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    print(f"🔧 MoE FramePack滑动窗口生成开始...")
    print(f"模态类型: {modality_type}")
    print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}")
    print(f"Text guidance scale: {text_guidance_scale}")
    print(f"MoE配置: experts={moe_num_experts}, top_k={moe_top_k}")

    # 1. Model initialization — swap the registered DiT class for the MoE one
    # before the ModelManager instantiates it.
    replace_dit_model_in_manager()

    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # 2. Legacy per-block camera encoder (kept for checkpoint compatibility);
    # initialized to identity/zero so it is a no-op until weights are loaded.
    dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        block.cam_encoder = nn.Linear(13, dim)
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    # 3. FramePack components (multi-scale clean-latent embedder).
    add_framepack_components(pipe.dit)

    # 4. MoE components — per-modality camera input dims (pose dims + 1 mask).
    moe_config = {
        "num_experts": moe_num_experts,
        "top_k": moe_top_k,
        "hidden_dim": moe_hidden_dim or dim * 2,
        "sekai_input_dim": 13,    # Sekai: 12-d pose + 1-d mask
        "nuscenes_input_dim": 8,  # NuScenes: 7-d pose + 1-d mask
        "openx_input_dim": 13     # OpenX: 12-d pose + 1-d mask (like sekai)
    }
    add_moe_components(pipe.dit, moe_config)

    # 5. Load trained weights; strict=False tolerates the freshly added
    # FramePack/MoE modules that may be absent from older checkpoints.
    dit_state_dict = torch.load(dit_path, map_location="cpu")
    pipe.dit.load_state_dict(dit_state_dict, strict=False)
    pipe = pipe.to(device)
    model_dtype = next(pipe.dit.parameters()).dtype

    if hasattr(pipe.dit, 'clean_x_embedder'):
        pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)

    pipe.scheduler.set_timesteps(50)

    # 6. Load the initial conditioning latents from the pre-encoded .pth.
    print("Loading initial condition frames...")
    initial_latents, encoded_data = load_encoded_video_from_pth(
        condition_pth_path,
        start_frame=start_frame,
        num_frames=initial_condition_frames
    )

    # Center-crop spatially to the model's expected latent resolution.
    target_height, target_width = 60, 104
    C, T, H, W = initial_latents.shape

    if H > target_height or W > target_width:
        h_start = (H - target_height) // 2
        w_start = (W - target_width) // 2
        initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width]
        H, W = target_height, target_width

    history_latents = initial_latents.to(device, dtype=model_dtype)

    print(f"初始history_latents shape: {history_latents.shape}")

    # 7. Encode the prompt; an empty negative prompt is encoded only when
    # text CFG is actually requested.
    if text_guidance_scale > 1.0:
        prompt_emb_pos = pipe.encode_prompt(prompt)
        prompt_emb_neg = pipe.encode_prompt("")
        print(f"使用Text CFG,guidance scale: {text_guidance_scale}")
    else:
        prompt_emb_pos = pipe.encode_prompt(prompt)
        prompt_emb_neg = None
        print("不使用Text CFG")

    # 8. Optional NuScenes scene metadata (real poses).
    scene_info = None
    if modality_type == "nuscenes" and scene_info_path and os.path.exists(scene_info_path):
        with open(scene_info_path, 'r') as f:
            scene_info = json.load(f)
        print(f"加载NuScenes场景信息: {scene_info_path}")

    # 9. Pre-generate the full per-frame camera embedding sequence for the
    # chosen modality (helpers defined elsewhere in this file).
    if modality_type == "sekai":
        camera_embedding_full = generate_sekai_camera_embeddings_sliding(
            encoded_data.get('cam_emb', None),
            0,
            max_history_frames,
            0,
            0,
            use_real_poses=use_real_poses
        ).to(device, dtype=model_dtype)
    elif modality_type == "nuscenes":
        camera_embedding_full = generate_nuscenes_camera_embeddings_sliding(
            scene_info,
            0,
            max_history_frames,
            0
        ).to(device, dtype=model_dtype)
    elif modality_type == "openx":
        camera_embedding_full = generate_openx_camera_embeddings_sliding(
            encoded_data,
            0,
            max_history_frames,
            0,
            use_real_poses=use_real_poses
        ).to(device, dtype=model_dtype)
    else:
        raise ValueError(f"不支持的模态类型: {modality_type}")

    print(f"完整camera序列shape: {camera_embedding_full.shape}")

    # 10. All-zero camera embedding as the unconditional branch for camera CFG.
    if use_camera_cfg:
        camera_embedding_uncond = torch.zeros_like(camera_embedding_full)
        print(f"创建无条件camera embedding用于CFG")

    # 11. Sliding-window autoregressive generation loop.
    total_generated = 0
    all_generated_frames = []

    while total_generated < total_frames_to_generate:
        current_generation = min(frames_per_generation, total_frames_to_generate - total_generated)
        print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}")
        print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}")

        # FramePack data preparation — MoE version.
        framepack_data = prepare_framepack_sliding_window_with_camera_moe(
            history_latents,
            current_generation,
            camera_embedding_full,
            start_frame,
            modality_type,
            max_history_frames
        )

        # Add the batch dimension expected by the DiT.
        clean_latents = framepack_data['clean_latents'].unsqueeze(0)
        clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0)
        clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0)
        camera_embedding = framepack_data['camera_embedding'].unsqueeze(0)

        # MoE routing input: one entry keyed by the active modality.
        modality_inputs = {modality_type: camera_embedding}

        # Unconditional camera embedding for CFG, length-matched to this step.
        # NOTE(review): slices rows with camera_embedding.shape[1] (the frame
        # count of the batched [1, L, D] tensor) — verify intent.
        if use_camera_cfg:
            camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0)
            modality_inputs_uncond = {modality_type: camera_embedding_uncond_batch}

        # Index tensors are kept on CPU.
        latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu()
        clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu()
        clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu()
        clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu()

        # Fresh noise for the frames to be generated this step.
        new_latents = torch.randn(
            1, C, current_generation, H, W,
            device=device, dtype=model_dtype
        )

        extra_input = pipe.prepare_extra_input(new_latents)

        print(f"Camera embedding shape: {camera_embedding.shape}")
        print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}")

        # Denoising loop — with optional camera/text CFG.
        timesteps = pipe.scheduler.timesteps

        for i, timestep in enumerate(timesteps):
            if i % 10 == 0:
                print(f"  去噪步骤 {i+1}/{len(timesteps)}")

            timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)

            with torch.no_grad():
                # CFG inference
                if use_camera_cfg and camera_guidance_scale > 1.0:
                    # Conditional prediction (with camera).
                    noise_pred_cond, moe_loss = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,  # MoE modality input
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

                    # Unconditional prediction (zeroed camera).
                    noise_pred_uncond, moe_loss = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding_uncond_batch,
                        modality_inputs=modality_inputs_uncond,  # MoE unconditional modality input
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **(prompt_emb_neg if prompt_emb_neg else prompt_emb_pos),
                        **extra_input
                    )

                    # Camera CFG combination.
                    noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_cond - noise_pred_uncond)

                    # Optionally stack text CFG on top of the camera-CFG result.
                    if text_guidance_scale > 1.0 and prompt_emb_neg:
                        noise_pred_text_uncond, moe_loss = pipe.dit(
                            new_latents,
                            timestep=timestep_tensor,
                            cam_emb=camera_embedding,
                            modality_inputs=modality_inputs,
                            latent_indices=latent_indices,
                            clean_latents=clean_latents,
                            clean_latent_indices=clean_latent_indices,
                            clean_latents_2x=clean_latents_2x,
                            clean_latent_2x_indices=clean_latent_2x_indices,
                            clean_latents_4x=clean_latents_4x,
                            clean_latent_4x_indices=clean_latent_4x_indices,
                            **prompt_emb_neg,
                            **extra_input
                        )

                        # Apply text CFG to the camera-CFG-combined prediction.
                        noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond)

                elif text_guidance_scale > 1.0 and prompt_emb_neg:
                    # Text CFG only.
                    noise_pred_cond, moe_loss = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

                    noise_pred_uncond, moe_loss = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_neg,
                        **extra_input
                    )

                    noise_pred = noise_pred_uncond + text_guidance_scale * (noise_pred_cond - noise_pred_uncond)

                else:
                    # Standard inference (no CFG).
                    noise_pred, moe_loss = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,  # MoE modality input
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

            new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents)

        # Append this chunk to the history.
        new_latents_squeezed = new_latents.squeeze(0)
        history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1)

        # Maintain the sliding window: keep the very first frame as an anchor
        # plus the most recent (max_history_frames - 1) frames.
        if history_latents.shape[1] > max_history_frames:
            first_frame = history_latents[:, 0:1, :, :]
            recent_frames = history_latents[:, -(max_history_frames-1):, :, :]
            history_latents = torch.cat([first_frame, recent_frames], dim=1)
            print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧")

        print(f"更新后history_latents shape: {history_latents.shape}")

        all_generated_frames.append(new_latents_squeezed)
        total_generated += current_generation

        print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧")

    # 12. Decode latents and write the mp4.
    print("\n🔧 解码生成的视频...")

    all_generated = torch.cat(all_generated_frames, dim=1)
    final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0)

    print(f"最终视频shape: {final_video.shape}")

    decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))

    print(f"Saving video to {output_path}")

    # [-1, 1] latents-decoded video -> uint8 HWC frames.
    video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
    video_np = (video_np * 0.5 + 0.5).clip(0, 1)
    video_np = (video_np * 255).astype(np.uint8)

    with imageio.get_writer(output_path, fps=20) as writer:
        for frame in video_np:
            writer.append_data(frame)

    print(f"🔧 MoE FramePack滑动窗口生成完成! 保存到: {output_path}")
    print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧")
    print(f"使用模态: {modality_type}")
| 925 |
+
|
| 926 |
+
def _str2bool(value):
    """Parse a CLI boolean value; accepts true/false, 1/0, yes/no, t/f, y/n."""
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() in ("true", "1", "yes", "y", "t")


def main():
    """CLI entry point for MoE FramePack sliding-window video generation.

    Parses arguments, prints the effective configuration, and dispatches to
    `inference_moe_framepack_sliding_window` (defined above in this file).
    """
    parser = argparse.ArgumentParser(description="MoE FramePack滑动窗口视频生成 - 支持多模态")

    # Basic parameters.
    parser.add_argument("--condition_pth", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/spatialvid/a9a6d37f-0a6c-548a-a494-7d902469f3f2_0000000_0000300/encoded_video.pth")
    parser.add_argument("--start_frame", type=int, default=0)
    parser.add_argument("--initial_condition_frames", type=int, default=16)
    parser.add_argument("--frames_per_generation", type=int, default=8)
    parser.add_argument("--total_frames_to_generate", type=int, default=8)
    parser.add_argument("--max_history_frames", type=int, default=100)
    parser.add_argument("--use_real_poses", action="store_true", default=False)
    parser.add_argument("--dit_path", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe_spatialvid/step250_moe.ckpt")
    parser.add_argument("--output_path", type=str,
                        default='/home/zhuyixuan05/ReCamMaster/moe/infer_results/output_moe_framepack_sliding.mp4')
    parser.add_argument("--prompt", type=str,
                        default="A man enter the room")
    parser.add_argument("--device", type=str, default="cuda")

    # Modality parameters.
    parser.add_argument("--modality_type", type=str, choices=["sekai", "nuscenes", "openx"], default="sekai",
                        help="模态类型:sekai 或 nuscenes 或 openx")
    parser.add_argument("--scene_info_path", type=str, default=None,
                        help="NuScenes场景信息文件路径(仅用于nuscenes模态)")

    # CFG parameters.
    # BUGFIX: the original had `default=True` with no type/action, so any CLI
    # value (including "False") arrived as a non-empty, truthy string and the
    # flag could never be disabled. `type=_str2bool` keeps the same one-value
    # CLI shape while parsing booleans correctly.
    parser.add_argument("--use_camera_cfg", type=_str2bool, default=True,
                        help="使用Camera CFG")
    parser.add_argument("--camera_guidance_scale", type=float, default=2.0,
                        help="Camera guidance scale for CFG")
    parser.add_argument("--text_guidance_scale", type=float, default=1.0,
                        help="Text guidance scale for CFG")

    # MoE parameters.
    parser.add_argument("--moe_num_experts", type=int, default=1, help="专家数量")
    parser.add_argument("--moe_top_k", type=int, default=1, help="Top-K专家")
    parser.add_argument("--moe_hidden_dim", type=int, default=None, help="MoE隐藏层维度")

    args = parser.parse_args()

    print(f"🔧 MoE FramePack CFG生成设置:")
    print(f"模态类型: {args.modality_type}")
    print(f"Camera CFG: {args.use_camera_cfg}")
    if args.use_camera_cfg:
        print(f"Camera guidance scale: {args.camera_guidance_scale}")
    print(f"Text guidance scale: {args.text_guidance_scale}")
    print(f"MoE配置: experts={args.moe_num_experts}, top_k={args.moe_top_k}")

    # Validate NuScenes parameters: real poses need the scene info file.
    if args.modality_type == "nuscenes" and not args.scene_info_path:
        print("⚠️ 使用NuScenes模态但未提供scene_info_path,将使用合成pose数据")

    inference_moe_framepack_sliding_window(
        condition_pth_path=args.condition_pth,
        dit_path=args.dit_path,
        output_path=args.output_path,
        start_frame=args.start_frame,
        initial_condition_frames=args.initial_condition_frames,
        frames_per_generation=args.frames_per_generation,
        total_frames_to_generate=args.total_frames_to_generate,
        max_history_frames=args.max_history_frames,
        device=args.device,
        prompt=args.prompt,
        modality_type=args.modality_type,
        use_real_poses=args.use_real_poses,
        scene_info_path=args.scene_info_path,
        # CFG parameters
        use_camera_cfg=args.use_camera_cfg,
        camera_guidance_scale=args.camera_guidance_scale,
        text_guidance_scale=args.text_guidance_scale,
        # MoE parameters
        moe_num_experts=args.moe_num_experts,
        moe_top_k=args.moe_top_k,
        moe_hidden_dim=args.moe_hidden_dim
    )


if __name__ == "__main__":
    main()
|
scripts/infer_moe_test.py
ADDED
|
@@ -0,0 +1,976 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import numpy as np
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import imageio
|
| 7 |
+
import json
|
| 8 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 9 |
+
import argparse
|
| 10 |
+
from torchvision.transforms import v2
|
| 11 |
+
from einops import rearrange
|
| 12 |
+
import copy
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Load a pre-encoded latent clip from a .pth checkpoint.

    Reads the payload dict saved by the encoding scripts, slices
    `num_frames` frames starting at `start_frame` along the temporal axis
    of the [C, T, H, W] latents, and returns (clip, full_payload).

    Raises:
        ValueError: if the requested range exceeds the available frames.
    """
    print(f"Loading encoded video from {pth_path}")

    payload = torch.load(pth_path, weights_only=False, map_location="cpu")
    all_latents = payload['latents']  # [C, T, H, W]

    print(f"Full latents shape: {all_latents.shape}")
    print(f"Extracting frames {start_frame} to {start_frame + num_frames}")

    frame_end = start_frame + num_frames
    if frame_end > all_latents.shape[1]:
        raise ValueError(f"Not enough frames: requested {start_frame + num_frames}, available {all_latents.shape[1]}")

    clip = all_latents[:, start_frame:frame_end, :, :]
    print(f"Extracted condition latents shape: {clip.shape}")

    return clip, payload
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Return camera B's pose expressed relative to camera A.

    Both inputs are 4x4 homogeneous extrinsic matrices. The result is
    pose_b @ inverse(pose_a), computed with torch when `use_torch` is
    True, otherwise with numpy (non-array inputs are converted first).
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        a = pose_a if isinstance(pose_a, torch.Tensor) else torch.from_numpy(pose_a).float()
        b = pose_b if isinstance(pose_b, torch.Tensor) else torch.from_numpy(pose_b).float()
        return torch.matmul(b, torch.inverse(a))

    a = pose_a if isinstance(pose_a, np.ndarray) else np.array(pose_a, dtype=np.float32)
    b = pose_b if isinstance(pose_b, np.ndarray) else np.array(pose_b, dtype=np.float32)
    return np.matmul(b, np.linalg.inv(a))
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def replace_dit_model_in_manager():
    """Swap the registered 'wan_video_dit' model class for its MoE variant.

    Mutates ``model_loader_configs`` in place so that any subsequent model
    loading instantiates ``WanModelMoe`` instead of the stock DiT class.
    """
    from diffsynth.models.wan_video_dit_moe import WanModelMoe
    from diffsynth.configs.model_config import model_loader_configs

    for idx, entry in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource = entry

        # Only entries that register the DiT need rewriting.
        if 'wan_video_dit' not in model_names:
            continue

        updated_names, updated_classes = [], []
        for model_name, model_cls in zip(model_names, model_classes):
            updated_names.append(model_name)
            if model_name == 'wan_video_dit':
                updated_classes.append(WanModelMoe)
                print(f"✅ 替换了模型类: {model_name} -> WanModelMoe")
            else:
                updated_classes.append(model_cls)

        model_loader_configs[idx] = (keys_hash, keys_hash_with_shape, updated_names, updated_classes, model_resource)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def add_framepack_components(dit_model):
    """Attach the FramePack ``clean_x_embedder`` to a DiT model (idempotent).

    The embedder projects clean latents into the token stream at three
    temporal/spatial compression levels (1x, 2x, 4x) via Conv3d layers.
    Does nothing when the model already carries a ``clean_x_embedder``.
    """
    if hasattr(dit_model, 'clean_x_embedder'):
        return  # already installed

    # Infer the transformer width from the first block's query projection.
    inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]

    class CleanXEmbedder(nn.Module):
        """Projects 16-channel clean latents to inner_dim at 1x/2x/4x scales."""

        def __init__(self, inner_dim):
            super().__init__()
            self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
            self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
            self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

        def forward(self, x, scale="1x"):
            # Dispatch on the requested compression scale.
            conv = {"1x": self.proj, "2x": self.proj_2x, "4x": self.proj_4x}.get(scale)
            if conv is None:
                raise ValueError(f"Unsupported scale: {scale}")
            # Match the conv's dtype so mixed-precision inputs don't error.
            return conv(x.to(conv.weight.dtype))

    embedder = CleanXEmbedder(inner_dim)
    # Align the new module's dtype with the rest of the model.
    model_dtype = next(dit_model.parameters()).dtype
    dit_model.clean_x_embedder = embedder.to(dtype=model_dtype)
    print("✅ 添加了FramePack的clean_x_embedder组件")
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def add_moe_components(dit_model, moe_config):
    """Attach MoE components to every transformer block of the DiT model.

    For each block this installs:
      * ``sekai_processor`` — maps the 13-D Sekai camera embedding
        (12-D pose + 1-D mask) to ``unified_dim``.
      * ``moe`` — a MultiModalMoE mapping ``unified_dim`` to the block width.

    Args:
        dit_model: DiT model whose ``blocks`` are augmented in place.
        moe_config: dict with optional keys ``unified_dim`` (default 25),
            ``num_experts`` (default 4) and ``top_k`` (default 2); stored on
            the model as ``moe_config`` when not already present.
    """
    # Fix: import once, outside the per-block loop (was re-imported on every
    # iteration, which is redundant loop-invariant work).
    from diffsynth.models.wan_video_dit_moe import ModalityProcessor, MultiModalMoE

    if not hasattr(dit_model, 'moe_config'):
        dit_model.moe_config = moe_config
        print("✅ 添加了MoE配置到模型")

    # Block width inferred from the first block's query projection.
    dim = dit_model.blocks[0].self_attn.q.weight.shape[0]
    unified_dim = moe_config.get("unified_dim", 25)

    for i, block in enumerate(dit_model.blocks):
        # Sekai modality processor: 13-D input -> unified_dim.
        block.sekai_processor = ModalityProcessor("sekai", 13, unified_dim)

        # NOTE(review): a NuScenes processor (8-D input) was commented out
        # upstream — confirm whether it should be installed:
        # block.nuscenes_processor = ModalityProcessor("nuscenes", 8, unified_dim)

        # MoE network: unified_dim in, transformer-block dim out.
        block.moe = MultiModalMoE(
            unified_dim=unified_dim,
            output_dim=dim,  # output width matches the transformer block
            num_experts=moe_config.get("num_experts", 4),
            top_k=moe_config.get("top_k", 2)
        )

        print(f"✅ Block {i} 添加了MoE组件 (unified_dim: {unified_dim}, experts: {moe_config.get('num_experts', 4)})")
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def generate_sekai_camera_embeddings_sliding(cam_data, start_frame, current_history_length, new_frames, total_generated, use_real_poses=True):
    """Generate camera embeddings for the Sekai dataset — sliding-window variant.

    Each output row is 13-D: a flattened 3x4 relative pose (12-D) plus a
    1-D condition mask. Rows [start_frame, start_frame+current_history_length)
    are marked as condition (mask=1), the rest as target (mask=0).

    NOTE(review): ``total_generated`` is currently unused — confirm intent.

    Returns:
        [max_needed_frames, 13] tensor in bfloat16.
    """
    # Latent frames are 4x temporally compressed w.r.t. raw video frames.
    time_compression_ratio = 4

    # Camera frames consumed by the FramePack index layout:
    # 1 start + 16 (4x) + 2 (2x) + 1 (1x) + new_frames targets.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    # Cover the base window, the FramePack layout, and a floor of 30 frames.
    max_needed_frames = max(
        start_frame + current_history_length + new_frames,
        framepack_needed_frames,
        30,
    )

    def _assemble(pose_rows):
        # Flatten each 3x4 pose to 12-D and append the condition-mask column.
        flat = torch.stack(pose_rows, dim=0).reshape(len(pose_rows), -1)
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0
        return torch.cat([flat, mask], dim=1)

    if use_real_poses and cam_data is not None and 'extrinsic' in cam_data:
        print("🔧 使用真实Sekai camera数据")
        cam_extrinsic = cam_data['extrinsic']

        print(f"🔧 计算Sekai camera序列长度:")
        print(f"  - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f"  - FramePack需求: {framepack_needed_frames}")
        print(f"  - 最终生成: {max_needed_frames}")

        pose_rows = []
        for slot in range(max_needed_frames):
            # Map the latent-frame slot back onto raw-video frame indices.
            frame_idx = slot * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio
            if next_frame_idx < len(cam_extrinsic):
                rel = compute_relative_pose(cam_extrinsic[frame_idx], cam_extrinsic[next_frame_idx])
                pose_rows.append(torch.as_tensor(rel[:3, :]))
            else:
                # Past the recorded trajectory: zero motion.
                print(f"⚠️ 帧{frame_idx}超出camera数据范围,使用零运动")
                pose_rows.append(torch.zeros(3, 4))

        camera_embedding = _assemble(pose_rows)
        print(f"🔧 Sekai真实camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    print("🔧 使用Sekai合成camera数据")
    print(f"🔧 生成Sekai合成camera帧数: {max_needed_frames}")

    # Synthetic motion: a constant per-frame left turn with slow forward drift,
    # identical for every frame, approximating a circular walking trajectory.
    yaw_per_frame = 0.05    # left turn per frame (positive = left)
    forward_speed = 0.005   # forward translation per frame
    radius_drift = 0.002    # slight drift toward the circle centre

    step = np.eye(4, dtype=np.float32)
    cos_yaw = np.cos(yaw_per_frame)
    sin_yaw = np.sin(yaw_per_frame)
    # Rotation about the Y axis (left turn).
    step[0, 0] = cos_yaw
    step[0, 2] = sin_yaw
    step[2, 0] = -sin_yaw
    step[2, 2] = cos_yaw
    step[2, 3] = -forward_speed  # local -Z = forward
    step[0, 3] = -radius_drift   # local -X = leftward (centripetal) drift

    pose_rows = [torch.as_tensor(step[:3, :]) for _ in range(max_needed_frames)]

    camera_embedding = _assemble(pose_rows)
    print(f"🔧 Sekai合成camera embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
|
| 244 |
+
|
| 245 |
+
def generate_openx_camera_embeddings_sliding(encoded_data, start_frame, current_history_length, new_frames, use_real_poses):
    """Generate camera embeddings for the OpenX dataset — sliding-window variant.

    Each output row is 13-D: a flattened 3x4 relative pose (12-D) plus a
    1-D condition mask. Real poses are read from
    ``encoded_data['cam_emb']['extrinsic']`` when available; otherwise a
    synthetic robot-arm motion pattern is used.

    Returns:
        [max_needed_frames, 13] tensor in bfloat16.
    """
    # Latent frames are 4x temporally compressed w.r.t. raw video frames.
    time_compression_ratio = 4

    # Camera frames consumed by the FramePack index layout:
    # 1 start + 16 (4x) + 2 (2x) + 1 (1x) + new_frames targets.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    max_needed_frames = max(
        start_frame + current_history_length + new_frames,
        framepack_needed_frames,
        30,
    )

    def _assemble(pose_rows):
        # Flatten each 3x4 pose to 12-D and append the condition-mask column.
        flat = torch.stack(pose_rows, dim=0).reshape(len(pose_rows), -1)
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0
        return torch.cat([flat, mask], dim=1)

    have_real = (use_real_poses and encoded_data is not None
                 and 'cam_emb' in encoded_data and 'extrinsic' in encoded_data['cam_emb'])

    if have_real:
        print("🔧 使用OpenX真实camera数据")
        cam_extrinsic = encoded_data['cam_emb']['extrinsic']

        print(f"🔧 计算OpenX camera序列长度:")
        print(f"  - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f"  - FramePack需求: {framepack_needed_frames}")
        print(f"  - 最终生成: {max_needed_frames}")

        pose_rows = []
        for slot in range(max_needed_frames):
            # OpenX uses the same 4x stride as Sekai, on shorter sequences.
            frame_idx = slot * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio
            if next_frame_idx < len(cam_extrinsic):
                rel = compute_relative_pose(cam_extrinsic[frame_idx], cam_extrinsic[next_frame_idx])
                pose_rows.append(torch.as_tensor(rel[:3, :]))
            else:
                # Past the recorded trajectory: zero motion.
                print(f"⚠️ 帧{frame_idx}超出OpenX camera数据范围,使用零运动")
                pose_rows.append(torch.zeros(3, 4))

        camera_embedding = _assemble(pose_rows)
        print(f"🔧 OpenX真实camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    print("🔧 使用OpenX合成camera数据")
    print(f"🔧 生成OpenX合成camera帧数: {max_needed_frames}")

    # Synthetic robot-manipulation motion: a small compound rotation plus a
    # fine translation, identical for every frame.
    roll_per_frame = 0.02    # slight roll
    pitch_per_frame = 0.01   # slight pitch
    yaw_per_frame = 0.015    # slight yaw
    forward_speed = 0.003    # slow advance

    cos_roll, sin_roll = np.cos(roll_per_frame), np.sin(roll_per_frame)
    cos_pitch, sin_pitch = np.cos(pitch_per_frame), np.sin(pitch_per_frame)
    cos_yaw, sin_yaw = np.cos(yaw_per_frame), np.sin(yaw_per_frame)

    step = np.eye(4, dtype=np.float32)
    # Simplified compound rotation matrix (ZYX order).
    step[0, 0] = cos_yaw * cos_pitch
    step[0, 1] = cos_yaw * sin_pitch * sin_roll - sin_yaw * cos_roll
    step[0, 2] = cos_yaw * sin_pitch * cos_roll + sin_yaw * sin_roll
    step[1, 0] = sin_yaw * cos_pitch
    step[1, 1] = sin_yaw * sin_pitch * sin_roll + cos_yaw * cos_roll
    step[1, 2] = sin_yaw * sin_pitch * cos_roll - cos_yaw * sin_roll
    step[2, 0] = -sin_pitch
    step[2, 1] = cos_pitch * sin_roll
    step[2, 2] = cos_pitch * cos_roll
    # Fine translation: mostly along depth (-Z), slight X/Y motion.
    step[0, 3] = forward_speed * 0.5
    step[1, 3] = forward_speed * 0.3
    step[2, 3] = -forward_speed

    pose_rows = [torch.as_tensor(step[:3, :]) for _ in range(max_needed_frames)]

    camera_embedding = _assemble(pose_rows)
    print(f"🔧 OpenX合成camera embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def generate_nuscenes_camera_embeddings_sliding(scene_info, start_frame, current_history_length, new_frames):
    """Generate camera embeddings for NuScenes — sliding-window variant,
    kept consistent with train_moe.py.

    Each output row is 8-D: a 7-D pose (3-D translation relative to the first
    keyframe + 4-D rotation, presumably a quaternion — TODO confirm) plus a
    1-D condition mask.

    Returns:
        [max_needed_frames, 8] tensor in bfloat16.
    """
    time_compression_ratio = 4  # NOTE(review): unused in this function

    # Camera frames consumed by the FramePack index layout:
    # 1 start + 16 (4x) + 2 (2x) + 1 (1x) + new_frames targets.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames
    max_needed_frames = max(framepack_needed_frames, 30)

    def _condition_mask():
        # Mark [start_frame, start_frame + current_history_length) as condition.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:end] = 1.0
        return mask

    if scene_info is not None and 'keyframe_poses' in scene_info:
        print("🔧 使用NuScenes真实pose数据")
        keyframe_poses = scene_info['keyframe_poses']

        if not keyframe_poses:
            # No keyframes at all: fall back to an all-zero pose sequence.
            print("⚠️ NuScenes keyframe_poses为空,使用零pose")
            camera_embedding = torch.cat(
                [torch.zeros(max_needed_frames, 7, dtype=torch.float32), _condition_mask()],
                dim=1)  # [max_needed_frames, 8]
            print(f"🔧 NuScenes零pose embedding shape: {camera_embedding.shape}")
            return camera_embedding.to(torch.bfloat16)

        # Translations are expressed relative to the first keyframe.
        reference_pose = keyframe_poses[0]
        # Identity pose for slots beyond the recorded trajectory.
        identity_vec = torch.cat([
            torch.zeros(3, dtype=torch.float32),
            torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32),
        ], dim=0)

        rows = []
        for idx in range(max_needed_frames):
            if idx >= len(keyframe_poses):
                rows.append(identity_vec)
                continue
            cur = keyframe_poses[idx]
            # Relative displacement w.r.t. the reference keyframe.
            translation = torch.tensor(
                np.array(cur['translation']) - np.array(reference_pose['translation']),
                dtype=torch.float32)
            # Rotation taken as stored (simplified — no relative rotation).
            rotation = torch.tensor(cur['rotation'], dtype=torch.float32)
            rows.append(torch.cat([translation, rotation], dim=0))  # [7]

        camera_embedding = torch.cat(
            [torch.stack(rows, dim=0), _condition_mask()], dim=1)  # [max_needed_frames, 8]
        print(f"🔧 NuScenes真实pose embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    print("🔧 使用NuScenes合成pose数据")
    # Synthetic motion: advance 0.1 per frame along Z with identity rotation.
    rows = [
        torch.cat([
            torch.tensor([0.0, 0.0, idx * 0.1], dtype=torch.float32),
            torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32),
        ], dim=0)
        for idx in range(max_needed_frames)
    ]
    camera_embedding = torch.cat(
        [torch.stack(rows, dim=0), _condition_mask()], dim=1)  # [max_needed_frames, 8]
    print(f"🔧 NuScenes合成pose embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
|
| 450 |
+
|
| 451 |
+
def prepare_framepack_sliding_window_with_camera_moe(history_latents, target_frames_to_generate, camera_embedding_full, start_frame, modality_type, max_history_frames=49):
    """Prepare FramePack sliding-window inputs — MoE variant.

    Args:
        history_latents: [C, T, H, W] latents accumulated so far.
        target_frames_to_generate: number of new latent frames this step.
        camera_embedding_full: [N, D] full camera embedding sequence; the
            last column is the condition mask (rewritten here).
        start_frame: NOTE(review): not used below — confirm intent.
        modality_type: passed through to the result (e.g. "sekai").
        max_history_frames: NOTE(review): currently unused here.

    Returns:
        dict with packed clean latents, index tensors, the windowed camera
        embedding, the modality type, and current/next sequence lengths.
    """
    C, T, H, W = history_latents.shape

    # Fixed index layout (this also fixes the camera frames needed):
    # [1 start | 16 4x | 2 2x | 1 1x | targets].
    total_indices_length = 1 + 16 + 2 + 1 + target_frames_to_generate
    all_indices = torch.arange(0, total_indices_length)
    start_idx, idx_4x, idx_2x, idx_1x, latent_indices = all_indices.split(
        [1, 16, 2, 1, target_frames_to_generate], dim=0)
    clean_latent_indices = torch.cat([start_idx, idx_1x], dim=0)

    # Zero-pad the camera sequence if it is shorter than the index layout.
    deficit = total_indices_length - camera_embedding_full.shape[0]
    if deficit > 0:
        pad = torch.zeros(deficit, camera_embedding_full.shape[1],
                          dtype=camera_embedding_full.dtype,
                          device=camera_embedding_full.device)
        camera_embedding_full = torch.cat([camera_embedding_full, pad], dim=0)

    # Take the window and rebuild its condition mask from the actual history.
    combined_camera = camera_embedding_full[:total_indices_length, :].clone()
    combined_camera[:, -1] = 0.0  # everything is target (0) by default

    if T > 0:
        available_frames = min(T, 19)
        # The trailing slots of the 19 clean positions hold real history.
        combined_camera[19 - available_frames:19, -1] = 1.0

    print(f"🔧 MoE Camera mask更新:")
    print(f"  - 历史帧数: {T}")
    print(f"  - 有效condition帧数: {available_frames if T > 0 else 0}")
    print(f"  - 模态类型: {modality_type}")

    # Lay the most recent history into the 19 clean-latent slots, right-aligned.
    clean_latents_combined = torch.zeros(C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device)
    if T > 0:
        available_frames = min(T, 19)
        clean_latents_combined[:, 19 - available_frames:, :, :] = history_latents[:, -available_frames:, :, :]

    clean_latents_4x = clean_latents_combined[:, :16, :, :]
    clean_latents_2x = clean_latents_combined[:, 16:18, :, :]
    clean_latents_1x = clean_latents_combined[:, 18:19, :, :]

    # The very first history frame anchors the sequence start.
    if T > 0:
        start_latent = history_latents[:, 0:1, :, :]
    else:
        start_latent = torch.zeros(C, 1, H, W, dtype=history_latents.dtype, device=history_latents.device)

    clean_latents = torch.cat([start_latent, clean_latents_1x], dim=1)

    return {
        'latent_indices': latent_indices,
        'clean_latents': clean_latents,
        'clean_latents_2x': clean_latents_2x,
        'clean_latents_4x': clean_latents_4x,
        'clean_latent_indices': clean_latent_indices,
        'clean_latent_2x_indices': idx_2x,
        'clean_latent_4x_indices': idx_4x,
        'camera_embedding': combined_camera,
        'modality_type': modality_type,  # modality tag for downstream routing
        'current_length': T,
        'next_length': T + target_frames_to_generate
    }
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
def inference_moe_framepack_sliding_window(
|
| 523 |
+
condition_pth_path,
|
| 524 |
+
dit_path,
|
| 525 |
+
output_path="moe/infer_results/output_moe_framepack_sliding.mp4",
|
| 526 |
+
start_frame=0,
|
| 527 |
+
initial_condition_frames=8,
|
| 528 |
+
frames_per_generation=4,
|
| 529 |
+
total_frames_to_generate=32,
|
| 530 |
+
max_history_frames=49,
|
| 531 |
+
device="cuda",
|
| 532 |
+
prompt="A video of a scene shot using a pedestrian's front camera while walking",
|
| 533 |
+
modality_type="sekai", # "sekai" 或 "nuscenes"
|
| 534 |
+
use_real_poses=True,
|
| 535 |
+
scene_info_path=None, # 对于NuScenes数据集
|
| 536 |
+
# CFG参数
|
| 537 |
+
use_camera_cfg=True,
|
| 538 |
+
camera_guidance_scale=2.0,
|
| 539 |
+
text_guidance_scale=1.0,
|
| 540 |
+
# MoE参数
|
| 541 |
+
moe_num_experts=4,
|
| 542 |
+
moe_top_k=2,
|
| 543 |
+
moe_hidden_dim=None
|
| 544 |
+
):
|
| 545 |
+
"""
|
| 546 |
+
MoE FramePack滑动窗口视频生成 - 支持多模态
|
| 547 |
+
"""
|
| 548 |
+
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
| 549 |
+
print(f"🔧 MoE FramePack滑动窗口生成开始...")
|
| 550 |
+
print(f"模态类型: {modality_type}")
|
| 551 |
+
print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}")
|
| 552 |
+
print(f"Text guidance scale: {text_guidance_scale}")
|
| 553 |
+
print(f"MoE配置: experts={moe_num_experts}, top_k={moe_top_k}")
|
| 554 |
+
|
| 555 |
+
# 1. 模型初始化
|
| 556 |
+
replace_dit_model_in_manager()
|
| 557 |
+
|
| 558 |
+
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
|
| 559 |
+
model_manager.load_models([
|
| 560 |
+
"models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
|
| 561 |
+
"models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
|
| 562 |
+
"models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
|
| 563 |
+
])
|
| 564 |
+
pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")
|
| 565 |
+
|
| 566 |
+
# 2. 添加传统camera编码器(兼容性)
|
| 567 |
+
dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
|
| 568 |
+
for block in pipe.dit.blocks:
|
| 569 |
+
block.cam_encoder = nn.Linear(13, dim)
|
| 570 |
+
block.projector = nn.Linear(dim, dim)
|
| 571 |
+
block.cam_encoder.weight.data.zero_()
|
| 572 |
+
block.cam_encoder.bias.data.zero_()
|
| 573 |
+
block.projector.weight = nn.Parameter(torch.eye(dim))
|
| 574 |
+
block.projector.bias = nn.Parameter(torch.zeros(dim))
|
| 575 |
+
|
| 576 |
+
# 3. 添加FramePack组件
|
| 577 |
+
add_framepack_components(pipe.dit)
|
| 578 |
+
|
| 579 |
+
# 4. 添加MoE组件
|
| 580 |
+
moe_config = {
|
| 581 |
+
"num_experts": moe_num_experts,
|
| 582 |
+
"top_k": moe_top_k,
|
| 583 |
+
"hidden_dim": moe_hidden_dim or dim * 2,
|
| 584 |
+
"sekai_input_dim": 13, # Sekai: 12维pose + 1维mask
|
| 585 |
+
"nuscenes_input_dim": 8, # NuScenes: 7维pose + 1维mask
|
| 586 |
+
"openx_input_dim": 13 # OpenX: 12维pose + 1维mask (类似sekai)
|
| 587 |
+
}
|
| 588 |
+
add_moe_components(pipe.dit, moe_config)
|
| 589 |
+
|
| 590 |
+
# 5. 加载训练好的权重
|
| 591 |
+
dit_state_dict = torch.load(dit_path, map_location="cpu")
|
| 592 |
+
pipe.dit.load_state_dict(dit_state_dict, strict=False) # 使用strict=False以兼容新增的MoE组件
|
| 593 |
+
pipe = pipe.to(device)
|
| 594 |
+
model_dtype = next(pipe.dit.parameters()).dtype
|
| 595 |
+
|
| 596 |
+
if hasattr(pipe.dit, 'clean_x_embedder'):
|
| 597 |
+
pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)
|
| 598 |
+
|
| 599 |
+
pipe.scheduler.set_timesteps(50)
|
| 600 |
+
|
| 601 |
+
# 6. 加载初始条件
|
| 602 |
+
print("Loading initial condition frames...")
|
| 603 |
+
initial_latents, encoded_data = load_encoded_video_from_pth(
|
| 604 |
+
condition_pth_path,
|
| 605 |
+
start_frame=start_frame,
|
| 606 |
+
num_frames=initial_condition_frames
|
| 607 |
+
)
|
| 608 |
+
|
| 609 |
+
# 空间裁剪
|
| 610 |
+
target_height, target_width = 60, 104
|
| 611 |
+
C, T, H, W = initial_latents.shape
|
| 612 |
+
|
| 613 |
+
if H > target_height or W > target_width:
|
| 614 |
+
h_start = (H - target_height) // 2
|
| 615 |
+
w_start = (W - target_width) // 2
|
| 616 |
+
initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width]
|
| 617 |
+
H, W = target_height, target_width
|
| 618 |
+
|
| 619 |
+
history_latents = initial_latents.to(device, dtype=model_dtype)
|
| 620 |
+
|
| 621 |
+
print(f"初始history_latents shape: {history_latents.shape}")
|
| 622 |
+
|
| 623 |
+
# 7. 编码prompt - 支持CFG
|
| 624 |
+
if text_guidance_scale > 1.0:
|
| 625 |
+
prompt_emb_pos = pipe.encode_prompt(prompt)
|
| 626 |
+
prompt_emb_neg = pipe.encode_prompt("")
|
| 627 |
+
print(f"使用Text CFG,guidance scale: {text_guidance_scale}")
|
| 628 |
+
else:
|
| 629 |
+
prompt_emb_pos = pipe.encode_prompt(prompt)
|
| 630 |
+
prompt_emb_neg = None
|
| 631 |
+
print("不使用Text CFG")
|
| 632 |
+
|
| 633 |
+
# 8. 加载场景信息(对于NuScenes)
|
| 634 |
+
scene_info = None
|
| 635 |
+
if modality_type == "nuscenes" and scene_info_path and os.path.exists(scene_info_path):
|
| 636 |
+
with open(scene_info_path, 'r') as f:
|
| 637 |
+
scene_info = json.load(f)
|
| 638 |
+
print(f"加载NuScenes场景信息: {scene_info_path}")
|
| 639 |
+
|
| 640 |
+
# 9. 预生成完整的camera embedding序列
|
| 641 |
+
if modality_type == "sekai":
|
| 642 |
+
camera_embedding_full = generate_sekai_camera_embeddings_sliding(
|
| 643 |
+
encoded_data.get('cam_emb', None),
|
| 644 |
+
0,
|
| 645 |
+
max_history_frames,
|
| 646 |
+
0,
|
| 647 |
+
0,
|
| 648 |
+
use_real_poses=use_real_poses
|
| 649 |
+
).to(device, dtype=model_dtype)
|
| 650 |
+
elif modality_type == "nuscenes":
|
| 651 |
+
camera_embedding_full = generate_nuscenes_camera_embeddings_sliding(
|
| 652 |
+
scene_info,
|
| 653 |
+
0,
|
| 654 |
+
max_history_frames,
|
| 655 |
+
0
|
| 656 |
+
).to(device, dtype=model_dtype)
|
| 657 |
+
elif modality_type == "openx":
|
| 658 |
+
camera_embedding_full = generate_openx_camera_embeddings_sliding(
|
| 659 |
+
encoded_data,
|
| 660 |
+
0,
|
| 661 |
+
max_history_frames,
|
| 662 |
+
0,
|
| 663 |
+
use_real_poses=use_real_poses
|
| 664 |
+
).to(device, dtype=model_dtype)
|
| 665 |
+
else:
|
| 666 |
+
raise ValueError(f"不支持的模态类型: {modality_type}")
|
| 667 |
+
|
| 668 |
+
print(f"完整camera序列shape: {camera_embedding_full.shape}")
|
| 669 |
+
|
| 670 |
+
# 10. 为Camera CFG创建无条件的camera embedding
|
| 671 |
+
if use_camera_cfg:
|
| 672 |
+
camera_embedding_uncond = torch.zeros_like(camera_embedding_full)
|
| 673 |
+
print(f"创建无条件camera embedding用于CFG")
|
| 674 |
+
|
| 675 |
+
# 11. 滑动窗口生成循环
|
| 676 |
+
total_generated = 0
|
| 677 |
+
all_generated_frames = []
|
| 678 |
+
|
| 679 |
+
while total_generated < total_frames_to_generate:
|
| 680 |
+
current_generation = min(frames_per_generation, total_frames_to_generate - total_generated)
|
| 681 |
+
print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}")
|
| 682 |
+
print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}")
|
| 683 |
+
|
| 684 |
+
# FramePack数据准备 - MoE版本
|
| 685 |
+
framepack_data = prepare_framepack_sliding_window_with_camera_moe(
|
| 686 |
+
history_latents,
|
| 687 |
+
current_generation,
|
| 688 |
+
camera_embedding_full,
|
| 689 |
+
start_frame,
|
| 690 |
+
modality_type,
|
| 691 |
+
max_history_frames
|
| 692 |
+
)
|
| 693 |
+
|
| 694 |
+
# 准备输入
|
| 695 |
+
clean_latents = framepack_data['clean_latents'].unsqueeze(0)
|
| 696 |
+
clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0)
|
| 697 |
+
clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0)
|
| 698 |
+
camera_embedding = framepack_data['camera_embedding'].unsqueeze(0)
|
| 699 |
+
|
| 700 |
+
# 准备modality_inputs
|
| 701 |
+
modality_inputs = {modality_type: camera_embedding}
|
| 702 |
+
|
| 703 |
+
# 为CFG准备无条件camera embedding
|
| 704 |
+
if use_camera_cfg:
|
| 705 |
+
camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0)
|
| 706 |
+
modality_inputs_uncond = {modality_type: camera_embedding_uncond_batch}
|
| 707 |
+
|
| 708 |
+
# 索引处理
|
| 709 |
+
latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu()
|
| 710 |
+
clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu()
|
| 711 |
+
clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu()
|
| 712 |
+
clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu()
|
| 713 |
+
|
| 714 |
+
# 初始化要生成的latents
|
| 715 |
+
new_latents = torch.randn(
|
| 716 |
+
1, C, current_generation, H, W,
|
| 717 |
+
device=device, dtype=model_dtype
|
| 718 |
+
)
|
| 719 |
+
|
| 720 |
+
extra_input = pipe.prepare_extra_input(new_latents)
|
| 721 |
+
|
| 722 |
+
print(f"Camera embedding shape: {camera_embedding.shape}")
|
| 723 |
+
print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}")
|
| 724 |
+
|
| 725 |
+
# 去噪循环 - 支持CFG
|
| 726 |
+
timesteps = pipe.scheduler.timesteps
|
| 727 |
+
|
| 728 |
+
for i, timestep in enumerate(timesteps):
|
| 729 |
+
if i % 10 == 0:
|
| 730 |
+
print(f" 去噪步骤 {i+1}/{len(timesteps)}")
|
| 731 |
+
|
| 732 |
+
timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)
|
| 733 |
+
|
| 734 |
+
with torch.no_grad():
|
| 735 |
+
# CFG推理
|
| 736 |
+
if use_camera_cfg and camera_guidance_scale > 1.0:
|
| 737 |
+
# 条件预测(有camera)
|
| 738 |
+
noise_pred_cond, moe_loss = pipe.dit(
|
| 739 |
+
new_latents,
|
| 740 |
+
timestep=timestep_tensor,
|
| 741 |
+
cam_emb=camera_embedding,
|
| 742 |
+
modality_inputs=modality_inputs, # MoE模态输入
|
| 743 |
+
latent_indices=latent_indices,
|
| 744 |
+
clean_latents=clean_latents,
|
| 745 |
+
clean_latent_indices=clean_latent_indices,
|
| 746 |
+
clean_latents_2x=clean_latents_2x,
|
| 747 |
+
clean_latent_2x_indices=clean_latent_2x_indices,
|
| 748 |
+
clean_latents_4x=clean_latents_4x,
|
| 749 |
+
clean_latent_4x_indices=clean_latent_4x_indices,
|
| 750 |
+
**prompt_emb_pos,
|
| 751 |
+
**extra_input
|
| 752 |
+
)
|
| 753 |
+
|
| 754 |
+
# 无条件预测(无camera)
|
| 755 |
+
noise_pred_uncond, moe_loss = pipe.dit(
|
| 756 |
+
new_latents,
|
| 757 |
+
timestep=timestep_tensor,
|
| 758 |
+
cam_emb=camera_embedding_uncond_batch,
|
| 759 |
+
modality_inputs=modality_inputs_uncond, # MoE无条件模态输入
|
| 760 |
+
latent_indices=latent_indices,
|
| 761 |
+
clean_latents=clean_latents,
|
| 762 |
+
clean_latent_indices=clean_latent_indices,
|
| 763 |
+
clean_latents_2x=clean_latents_2x,
|
| 764 |
+
clean_latent_2x_indices=clean_latent_2x_indices,
|
| 765 |
+
clean_latents_4x=clean_latents_4x,
|
| 766 |
+
clean_latent_4x_indices=clean_latent_4x_indices,
|
| 767 |
+
**(prompt_emb_neg if prompt_emb_neg else prompt_emb_pos),
|
| 768 |
+
**extra_input
|
| 769 |
+
)
|
| 770 |
+
|
| 771 |
+
# Camera CFG
|
| 772 |
+
noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
| 773 |
+
|
| 774 |
+
# 如果同时使用Text CFG
|
| 775 |
+
if text_guidance_scale > 1.0 and prompt_emb_neg:
|
| 776 |
+
noise_pred_text_uncond, moe_loss = pipe.dit(
|
| 777 |
+
new_latents,
|
| 778 |
+
timestep=timestep_tensor,
|
| 779 |
+
cam_emb=camera_embedding,
|
| 780 |
+
modality_inputs=modality_inputs,
|
| 781 |
+
latent_indices=latent_indices,
|
| 782 |
+
clean_latents=clean_latents,
|
| 783 |
+
clean_latent_indices=clean_latent_indices,
|
| 784 |
+
clean_latents_2x=clean_latents_2x,
|
| 785 |
+
clean_latent_2x_indices=clean_latent_2x_indices,
|
| 786 |
+
clean_latents_4x=clean_latents_4x,
|
| 787 |
+
clean_latent_4x_indices=clean_latent_4x_indices,
|
| 788 |
+
**prompt_emb_neg,
|
| 789 |
+
**extra_input
|
| 790 |
+
)
|
| 791 |
+
|
| 792 |
+
# 应用Text CFG到已经应用Camera CFG的结果
|
| 793 |
+
noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond)
|
| 794 |
+
|
| 795 |
+
elif text_guidance_scale > 1.0 and prompt_emb_neg:
|
| 796 |
+
# 只使用Text CFG
|
| 797 |
+
noise_pred_cond, moe_loss = pipe.dit(
|
| 798 |
+
new_latents,
|
| 799 |
+
timestep=timestep_tensor,
|
| 800 |
+
cam_emb=camera_embedding,
|
| 801 |
+
modality_inputs=modality_inputs,
|
| 802 |
+
latent_indices=latent_indices,
|
| 803 |
+
clean_latents=clean_latents,
|
| 804 |
+
clean_latent_indices=clean_latent_indices,
|
| 805 |
+
clean_latents_2x=clean_latents_2x,
|
| 806 |
+
clean_latent_2x_indices=clean_latent_2x_indices,
|
| 807 |
+
clean_latents_4x=clean_latents_4x,
|
| 808 |
+
clean_latent_4x_indices=clean_latent_4x_indices,
|
| 809 |
+
**prompt_emb_pos,
|
| 810 |
+
**extra_input
|
| 811 |
+
)
|
| 812 |
+
|
| 813 |
+
noise_pred_uncond, moe_loss = pipe.dit(
|
| 814 |
+
new_latents,
|
| 815 |
+
timestep=timestep_tensor,
|
| 816 |
+
cam_emb=camera_embedding,
|
| 817 |
+
modality_inputs=modality_inputs,
|
| 818 |
+
latent_indices=latent_indices,
|
| 819 |
+
clean_latents=clean_latents,
|
| 820 |
+
clean_latent_indices=clean_latent_indices,
|
| 821 |
+
clean_latents_2x=clean_latents_2x,
|
| 822 |
+
clean_latent_2x_indices=clean_latent_2x_indices,
|
| 823 |
+
clean_latents_4x=clean_latents_4x,
|
| 824 |
+
clean_latent_4x_indices=clean_latent_4x_indices,
|
| 825 |
+
**prompt_emb_neg,
|
| 826 |
+
**extra_input
|
| 827 |
+
)
|
| 828 |
+
|
| 829 |
+
noise_pred = noise_pred_uncond + text_guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
| 830 |
+
|
| 831 |
+
else:
|
| 832 |
+
# 标准推理(无CFG)
|
| 833 |
+
noise_pred, moe_loss = pipe.dit(
|
| 834 |
+
new_latents,
|
| 835 |
+
timestep=timestep_tensor,
|
| 836 |
+
cam_emb=camera_embedding,
|
| 837 |
+
modality_inputs=modality_inputs, # MoE模态输入
|
| 838 |
+
latent_indices=latent_indices,
|
| 839 |
+
clean_latents=clean_latents,
|
| 840 |
+
clean_latent_indices=clean_latent_indices,
|
| 841 |
+
clean_latents_2x=clean_latents_2x,
|
| 842 |
+
clean_latent_2x_indices=clean_latent_2x_indices,
|
| 843 |
+
clean_latents_4x=clean_latents_4x,
|
| 844 |
+
clean_latent_4x_indices=clean_latent_4x_indices,
|
| 845 |
+
**prompt_emb_pos,
|
| 846 |
+
**extra_input
|
| 847 |
+
)
|
| 848 |
+
|
| 849 |
+
new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents)
|
| 850 |
+
|
| 851 |
+
# 更新历史
|
| 852 |
+
new_latents_squeezed = new_latents.squeeze(0)
|
| 853 |
+
history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1)
|
| 854 |
+
|
| 855 |
+
# 维护滑动窗口
|
| 856 |
+
if history_latents.shape[1] > max_history_frames:
|
| 857 |
+
first_frame = history_latents[:, 0:1, :, :]
|
| 858 |
+
recent_frames = history_latents[:, -(max_history_frames-1):, :, :]
|
| 859 |
+
history_latents = torch.cat([first_frame, recent_frames], dim=1)
|
| 860 |
+
print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧")
|
| 861 |
+
|
| 862 |
+
print(f"更新后history_latents shape: {history_latents.shape}")
|
| 863 |
+
|
| 864 |
+
all_generated_frames.append(new_latents_squeezed)
|
| 865 |
+
total_generated += current_generation
|
| 866 |
+
|
| 867 |
+
print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧")
|
| 868 |
+
|
| 869 |
+
# 12. 解码和保存
|
| 870 |
+
print("\n🔧 解码生成的视频...")
|
| 871 |
+
|
| 872 |
+
all_generated = torch.cat(all_generated_frames, dim=1)
|
| 873 |
+
final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0)
|
| 874 |
+
|
| 875 |
+
print(f"最终视频shape: {final_video.shape}")
|
| 876 |
+
|
| 877 |
+
decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))
|
| 878 |
+
|
| 879 |
+
print(f"Saving video to {output_path}")
|
| 880 |
+
|
| 881 |
+
video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
|
| 882 |
+
video_np = (video_np * 0.5 + 0.5).clip(0, 1)
|
| 883 |
+
video_np = (video_np * 255).astype(np.uint8)
|
| 884 |
+
|
| 885 |
+
with imageio.get_writer(output_path, fps=20) as writer:
|
| 886 |
+
for frame in video_np:
|
| 887 |
+
writer.append_data(frame)
|
| 888 |
+
|
| 889 |
+
print(f"🔧 MoE FramePack滑动窗口生成完成! 保存到: {output_path}")
|
| 890 |
+
print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧")
|
| 891 |
+
print(f"使用模态: {modality_type}")
|
| 892 |
+
|
| 893 |
+
|
| 894 |
+
def main():
|
| 895 |
+
parser = argparse.ArgumentParser(description="MoE FramePack滑动窗口视频生成 - 支持多模态")
|
| 896 |
+
|
| 897 |
+
# 基��参数
|
| 898 |
+
parser.add_argument("--condition_pth", type=str,
|
| 899 |
+
default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking/00100100001_0004650_0004950/encoded_video.pth")
|
| 900 |
+
#default="/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic/scenes/scene-0001_CAM_FRONT/encoded_video-480p.pth")
|
| 901 |
+
#default="/share_zhuyixuan05/zhuyixuan05/spatialvid/a9a6d37f-0a6c-548a-a494-7d902469f3f2_0000000_0000300/encoded_video.pth")
|
| 902 |
+
#default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded/episode_000001/encoded_video.pth")
|
| 903 |
+
parser.add_argument("--start_frame", type=int, default=0)
|
| 904 |
+
parser.add_argument("--initial_condition_frames", type=int, default=16)
|
| 905 |
+
parser.add_argument("--frames_per_generation", type=int, default=8)
|
| 906 |
+
parser.add_argument("--total_frames_to_generate", type=int, default=40)
|
| 907 |
+
parser.add_argument("--max_history_frames", type=int, default=100)
|
| 908 |
+
parser.add_argument("--use_real_poses", action="store_true", default=False)
|
| 909 |
+
parser.add_argument("--dit_path", type=str,
|
| 910 |
+
default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe_test/step1000_moe.ckpt")
|
| 911 |
+
parser.add_argument("--output_path", type=str,
|
| 912 |
+
default='/home/zhuyixuan05/ReCamMaster/moe/infer_results/output_moe_framepack_sliding.mp4')
|
| 913 |
+
parser.add_argument("--prompt", type=str,
|
| 914 |
+
default="A drone flying scene in a game world")
|
| 915 |
+
parser.add_argument("--device", type=str, default="cuda")
|
| 916 |
+
|
| 917 |
+
# 模态类型参数
|
| 918 |
+
parser.add_argument("--modality_type", type=str, choices=["sekai", "nuscenes", "openx"], default="sekai",
|
| 919 |
+
help="模态类型:sekai 或 nuscenes 或 openx")
|
| 920 |
+
parser.add_argument("--scene_info_path", type=str, default=None,
|
| 921 |
+
help="NuScenes场景信息文件路径(仅用于nuscenes模态)")
|
| 922 |
+
|
| 923 |
+
# CFG参数
|
| 924 |
+
parser.add_argument("--use_camera_cfg", default=True,
|
| 925 |
+
help="使用Camera CFG")
|
| 926 |
+
parser.add_argument("--camera_guidance_scale", type=float, default=2.0,
|
| 927 |
+
help="Camera guidance scale for CFG")
|
| 928 |
+
parser.add_argument("--text_guidance_scale", type=float, default=1.0,
|
| 929 |
+
help="Text guidance scale for CFG")
|
| 930 |
+
|
| 931 |
+
# MoE参数
|
| 932 |
+
parser.add_argument("--moe_num_experts", type=int, default=1, help="专家数量")
|
| 933 |
+
parser.add_argument("--moe_top_k", type=int, default=1, help="Top-K专家")
|
| 934 |
+
parser.add_argument("--moe_hidden_dim", type=int, default=None, help="MoE隐藏层维度")
|
| 935 |
+
|
| 936 |
+
args = parser.parse_args()
|
| 937 |
+
|
| 938 |
+
print(f"🔧 MoE FramePack CFG生成设置:")
|
| 939 |
+
print(f"模态类型: {args.modality_type}")
|
| 940 |
+
print(f"Camera CFG: {args.use_camera_cfg}")
|
| 941 |
+
if args.use_camera_cfg:
|
| 942 |
+
print(f"Camera guidance scale: {args.camera_guidance_scale}")
|
| 943 |
+
print(f"Text guidance scale: {args.text_guidance_scale}")
|
| 944 |
+
print(f"MoE配置: experts={args.moe_num_experts}, top_k={args.moe_top_k}")
|
| 945 |
+
|
| 946 |
+
# 验证NuScenes参数
|
| 947 |
+
if args.modality_type == "nuscenes" and not args.scene_info_path:
|
| 948 |
+
print("⚠️ 使用NuScenes模态但未提供scene_info_path,将使用合成pose数据")
|
| 949 |
+
|
| 950 |
+
inference_moe_framepack_sliding_window(
|
| 951 |
+
condition_pth_path=args.condition_pth,
|
| 952 |
+
dit_path=args.dit_path,
|
| 953 |
+
output_path=args.output_path,
|
| 954 |
+
start_frame=args.start_frame,
|
| 955 |
+
initial_condition_frames=args.initial_condition_frames,
|
| 956 |
+
frames_per_generation=args.frames_per_generation,
|
| 957 |
+
total_frames_to_generate=args.total_frames_to_generate,
|
| 958 |
+
max_history_frames=args.max_history_frames,
|
| 959 |
+
device=args.device,
|
| 960 |
+
prompt=args.prompt,
|
| 961 |
+
modality_type=args.modality_type,
|
| 962 |
+
use_real_poses=args.use_real_poses,
|
| 963 |
+
scene_info_path=args.scene_info_path,
|
| 964 |
+
# CFG参数
|
| 965 |
+
use_camera_cfg=args.use_camera_cfg,
|
| 966 |
+
camera_guidance_scale=args.camera_guidance_scale,
|
| 967 |
+
text_guidance_scale=args.text_guidance_scale,
|
| 968 |
+
# MoE参数
|
| 969 |
+
moe_num_experts=args.moe_num_experts,
|
| 970 |
+
moe_top_k=args.moe_top_k,
|
| 971 |
+
moe_hidden_dim=args.moe_hidden_dim
|
| 972 |
+
)
|
| 973 |
+
|
| 974 |
+
|
| 975 |
+
if __name__ == "__main__":
|
| 976 |
+
main()
|
scripts/infer_nus.py
ADDED
|
@@ -0,0 +1,500 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import imageio
|
| 6 |
+
import json
|
| 7 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 8 |
+
import argparse
|
| 9 |
+
from torchvision.transforms import v2
|
| 10 |
+
from einops import rearrange
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
from pose_classifier import PoseClassifier
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def load_video_frames(video_path, num_frames=20, height=900, width=1600):
    """Load up to `num_frames` frames from a video and preprocess them.

    Each frame is resized to a hard-coded 480x832, converted to a tensor, and
    normalized to [-1, 1].

    Args:
        video_path: Path to a video file readable by imageio.
        num_frames: Maximum number of leading frames to load.
        height: Unused; kept for interface compatibility (output height is
            fixed at 480 — TODO confirm whether it should honor this arg).
        width: Unused; kept for interface compatibility (output width is
            fixed at 832).

    Returns:
        A float tensor of shape (C, T, H, W), or None if the video yielded
        no frames.
    """
    frame_process = v2.Compose([
        v2.ToTensor(),
        v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    def crop_and_resize(image):
        # NOTE: target size is fixed at 480x832 regardless of height/width args.
        return v2.functional.resize(
            image,
            (480, 832),
            interpolation=v2.InterpolationMode.BILINEAR
        )

    reader = imageio.get_reader(video_path)
    frames = []
    # FIX: close the reader even if decoding/preprocessing raises — the
    # original leaked the reader handle on any mid-loop exception.
    try:
        for i, frame_data in enumerate(reader):
            if i >= num_frames:
                break
            frame = Image.fromarray(frame_data)
            frame = crop_and_resize(frame)
            frames.append(frame_process(frame))
    finally:
        reader.close()

    if not frames:
        return None

    frames = torch.stack(frames, dim=0)
    return rearrange(frames, "T C H W -> C T H W")
|
| 53 |
+
|
| 54 |
+
def calculate_relative_rotation(current_rotation, reference_rotation):
    """Return the rotation of `current_rotation` relative to `reference_rotation`.

    Both inputs are quaternions in (w, x, y, z) order. The result is the
    Hamilton product q_ref^-1 * q_current, taking the conjugate as the
    inverse (valid for unit quaternions — assumes normalized inputs,
    TODO confirm at call sites).
    """
    q_cur = torch.tensor(current_rotation, dtype=torch.float32)
    q_ref = torch.tensor(reference_rotation, dtype=torch.float32)

    # Conjugate of the reference quaternion (its inverse when unit-norm).
    a_w, a_x, a_y, a_z = q_ref[0], -q_ref[1], -q_ref[2], -q_ref[3]
    b_w, b_x, b_y, b_z = q_cur

    # Hamilton product: q_ref_conj ⊗ q_current.
    return torch.tensor([
        a_w * b_w - a_x * b_x - a_y * b_y - a_z * b_z,
        a_w * b_x + a_x * b_w + a_y * b_z - a_z * b_y,
        a_w * b_y - a_x * b_z + a_y * b_w + a_z * b_x,
        a_w * b_z + a_x * b_y - a_y * b_x + a_z * b_w,
    ])
|
| 74 |
+
|
| 75 |
+
def generate_direction_poses(direction="left", target_frames=10, condition_frames=20):
    """
    Generate pose class embeddings for a chosen motion direction, covering
    both condition and target frames.

    Condition frames get a slow, steady forward motion; target frames get a
    synthetic trajectory for `direction`. Each frame is an 8D vector:
    3D translation + 4D quaternion (w, x, y, z) + 1D frame-type flag
    (0.0 = condition, 1.0 = target).

    Args:
        direction: 'forward', 'backward', 'left_turn', 'right_turn'
            (note: the default "left" is not one of these and would raise
            ValueError — presumably always overridden by callers).
        target_frames: Number of target frames to synthesize.
        condition_frames: Number of condition frames to synthesize.

    Returns:
        Tensor of per-frame class embeddings, shape
        (condition_frames + target_frames, 512), from
        create_enhanced_class_embedding_for_inference.

    Raises:
        ValueError: If `direction` is not one of the four supported values.
    """
    classifier = PoseClassifier()

    total_frames = condition_frames + target_frames
    print(f"conditon{condition_frames}")
    print(f"target{target_frames}")
    poses = []

    # Generate condition-frame poses (relatively stable forward motion).
    for i in range(condition_frames):
        t = i / max(1, condition_frames - 1)  # normalized progress, 0 to 1

        # Condition frames keep a slow, steady forward drift.
        translation = [-t * 0.5, 0.0, 0.0]  # slow advance along -x
        rotation = [1.0, 0.0, 0.0, 0.0]  # identity quaternion, no rotation
        frame_type = 0.0  # condition frame

        pose_vec = translation + rotation + [frame_type]  # 8D vector
        poses.append(pose_vec)

    # Generate target-frame poses according to the requested direction.
    # Target translations continue from where the condition trajectory ends
    # (offset by condition_frames * 0.5 along -x).
    for i in range(target_frames):
        t = i / max(1, target_frames - 1)  # normalized progress, 0 to 1

        if direction == "forward":
            # Forward: keep moving along -x, no rotation.
            translation = [-(condition_frames * 0.5 + t * 2.0), 0.0, 0.0]
            rotation = [1.0, 0.0, 0.0, 0.0]  # unit quaternion

        elif direction == "backward":
            # Backward: move along +x, no rotation.
            translation = [-(condition_frames * 0.5) + t * 2.0, 0.0, 0.0]
            rotation = [1.0, 0.0, 0.0, 0.0]

        elif direction == "left_turn":
            # Left turn: advance plus a positive yaw about the z axis.
            translation = [-(condition_frames * 0.5 + t * 1.5), t * 0.5, 0.0]  # advance and drift left
            yaw = t * 0.3  # left-turn angle (radians)
            rotation = [
                np.cos(yaw/2),  # w
                0.0,            # x
                0.0,            # y
                np.sin(yaw/2)   # z (positive for a left turn)
            ]

        elif direction == "right_turn":
            # Right turn: advance plus a negative yaw about the z axis.
            translation = [-(condition_frames * 0.5 + t * 1.5), -t * 0.5, 0.0]  # advance and drift right
            yaw = -t * 0.3  # right-turn angle (radians)
            rotation = [
                np.cos(abs(yaw)/2),  # w (cos is even, abs() is redundant but harmless)
                0.0,                 # x
                0.0,                 # y
                np.sin(yaw/2)        # z (negative for a right turn)
            ]
        else:
            raise ValueError(f"Unknown direction: {direction}")

        frame_type = 1.0  # target frame
        pose_vec = translation + rotation + [frame_type]  # 8D vector
        poses.append(pose_vec)

    pose_sequence = torch.tensor(poses, dtype=torch.float32)

    # Classify only the target portion (first 7 dims; drop the frame-type flag).
    target_pose_sequence = pose_sequence[condition_frames:, :7]

    # Condition frames are all labeled class 0 (forward); target labels come
    # from the classifier.
    condition_classes = torch.full((condition_frames,), 0, dtype=torch.long)  # condition frames are all "forward"
    target_classes = classifier.classify_pose_sequence(target_pose_sequence)
    full_classes = torch.cat([condition_classes, target_classes], dim=0)

    # Build the enhanced embedding (class + frame type + raw pose geometry).
    class_embeddings = create_enhanced_class_embedding_for_inference(
        full_classes, pose_sequence, embed_dim=512
    )

    print(f"Generated {direction} poses:")
    print(f" Total frames: {total_frames} (condition: {condition_frames}, target: {target_frames})")
    analysis = classifier.analyze_pose_sequence(target_pose_sequence)
    print(f" Target class distribution: {analysis['class_distribution']}")
    print(f" Target motion segments: {len(analysis['motion_segments'])}")

    return class_embeddings
|
| 166 |
+
|
| 167 |
+
def create_enhanced_class_embedding_for_inference(class_labels: torch.Tensor, pose_sequence: torch.Tensor, embed_dim: int = 512) -> torch.Tensor:
    """Create enhanced per-frame class embeddings for inference.

    Combines, per frame: a 4D motion-direction vector looked up from the
    class label, a 2D condition/target frame-type indicator, the 3D
    translation, and the 4D rotation quaternion (13 features total), then
    projects to `embed_dim`.

    Args:
        class_labels: Long tensor of shape (num_frames,) with values in
            {0: forward, 1: backward, 2: left_turn, 3: right_turn}.
        pose_sequence: Float tensor of shape (num_frames, 8):
            translation (3) + quaternion (4) + frame-type flag (1).
        embed_dim: Output embedding dimension.

    Returns:
        Float tensor of shape (num_frames, embed_dim). When
        embed_dim > 13 the 13 raw features are passed through identically
        (identity sub-block) plus a small random projection; otherwise the
        features are simply truncated.
    """
    num_classes = 4
    num_frames = len(class_labels)

    # Canonical direction vectors, one per motion class.
    direction_vectors = torch.tensor([
        [1.0, 0.0, 0.0, 0.0],   # forward
        [-1.0, 0.0, 0.0, 0.0],  # backward
        [0.0, 1.0, 0.0, 0.0],   # left_turn
        [0.0, -1.0, 0.0, 0.0],  # right_turn
    ], dtype=torch.float32)

    # One-hot encode the class labels.
    one_hot = torch.zeros(num_frames, num_classes)
    one_hot.scatter_(1, class_labels.unsqueeze(1), 1)

    # Base embedding from the direction vectors.
    base_embeddings = one_hot @ direction_vectors  # [num_frames, 4]

    # Frame-type indicator (condition vs target) from the last pose dim.
    frame_types = pose_sequence[:, -1]
    frame_type_embeddings = torch.zeros(num_frames, 2)
    frame_type_embeddings[:, 0] = (frame_types == 0).float()  # condition
    frame_type_embeddings[:, 1] = (frame_types == 1).float()  # target

    # Raw geometric information from the pose.
    translations = pose_sequence[:, :3]   # [num_frames, 3]
    rotations = pose_sequence[:, 3:7]     # [num_frames, 4]

    # Concatenate all features.
    combined_features = torch.cat([
        base_embeddings,        # [num_frames, 4]
        frame_type_embeddings,  # [num_frames, 2]
        translations,           # [num_frames, 3]
        rotations,              # [num_frames, 4]
    ], dim=1)  # [num_frames, 13]

    # Expand to the target dimension.
    if embed_dim > 13:
        # FIX: seed the projection so it is identical on every call — the
        # original drew a fresh torch.randn each time, making embeddings
        # non-reproducible between calls (and inconsistent with whatever
        # projection training saw).
        gen = torch.Generator().manual_seed(0)
        expand_matrix = torch.randn(13, embed_dim, generator=gen) * 0.1
        # Keep the 13 raw features identically in the first 13 output dims.
        expand_matrix[:13, :13] = torch.eye(13)
        embeddings = combined_features @ expand_matrix
    else:
        embeddings = combined_features[:, :embed_dim]

    return embeddings
|
| 214 |
+
|
| 215 |
+
def generate_poses_from_file(poses_path, target_frames=10):
    """Build per-frame class embeddings from a poses.json file.

    Reads `target_relative_poses` from the JSON file, resamples them evenly
    over `target_frames`, converts each entry to a 7D pose (translation +
    relative quaternion), and asks PoseClassifier for 512-dim class
    embeddings. Falls back to synthetic forward motion when the file holds
    no poses.
    """
    classifier = PoseClassifier()

    with open(poses_path, 'r') as f:
        poses_data = json.load(f)

    target_relative_poses = poses_data['target_relative_poses']

    if not target_relative_poses:
        print("No poses found in file, using forward direction")
        return generate_direction_poses("forward", target_frames)

    available = len(target_relative_poses)
    pose_vecs = []
    for frame_idx in range(target_frames):
        # Evenly resample the available poses over the requested frame count.
        if available == 1:
            entry = target_relative_poses[0]
        else:
            src_idx = min(frame_idx * available // target_frames, available - 1)
            entry = target_relative_poses[src_idx]

        # Pull the relative displacement and the two absolute rotations.
        shift = torch.tensor(entry['relative_translation'], dtype=torch.float32)
        cur_q = torch.tensor(entry['current_rotation'], dtype=torch.float32)
        ref_q = torch.tensor(entry['reference_rotation'], dtype=torch.float32)

        # Relative rotation of the current frame w.r.t. the reference frame.
        rel_q = calculate_relative_rotation(cur_q, ref_q)

        # 7D pose vector: 3D translation followed by 4D quaternion.
        pose_vecs.append(torch.cat([shift, rel_q], dim=0))

    pose_sequence = torch.stack(pose_vecs, dim=0)

    # Classify the pose sequence and turn the labels into embeddings.
    class_embeddings = classifier.create_class_embedding(
        classifier.classify_pose_sequence(pose_sequence),
        embed_dim=512
    )

    print(f"Generated poses from file:")
    analysis = classifier.analyze_pose_sequence(pose_sequence)
    print(f" Class distribution: {analysis['class_distribution']}")
    print(f" Motion segments: {len(analysis['motion_segments'])}")

    return class_embeddings
|
| 264 |
+
|
| 265 |
+
def inference_nuscenes_video(
    condition_video_path,
    dit_path,
    text_encoder_path,
    vae_path,
    output_path="nus/infer_results/output_nuscenes.mp4",
    condition_frames=20,
    target_frames=3,
    height=900,
    width=1600,
    device="cuda",
    prompt="A car driving scene captured by front camera",
    poses_path=None,
    direction="forward"
):
    """
    Direction-controlled inference: extend a NuScenes condition video with
    generated target frames, distinguishing condition vs target poses.

    Pipeline: load Wan2.1 models, graft camera-conditioning layers onto the
    DiT, load the trained DiT checkpoint, encode the condition video and a
    synthetic camera trajectory for `direction`, run a 50-step denoising
    loop that updates only the target latents, then decode and save an mp4.

    Args:
        condition_video_path: Path of the video providing condition frames.
        dit_path: Trained DiT checkpoint (state dict, loaded strict=True).
        text_encoder_path: Unused here — models are loaded from hard-coded
            Wan2.1 paths below; kept for interface compatibility.
        vae_path: Unused here, same reason as text_encoder_path.
        output_path: Output mp4 path (parent dirs are created).
        condition_frames: Number of raw condition frames; latent/time axis
            uses condition_frames/4 (VAE 4x temporal compression —
            presumably; confirm against the VAE config).
        target_frames: Number of target LATENT frames to generate.
        height, width: Forwarded to load_video_frames (which currently
            resizes to a fixed 480x832 regardless).
        device: Torch device string.
        prompt: Text prompt for generation.
        poses_path: Unused in this function; direction-based poses are
            always synthesized.
        direction: One of 'forward', 'backward', 'left_turn', 'right_turn'.

    Raises:
        ValueError: If the condition video yields no frames.
    """
    os.makedirs(os.path.dirname(output_path),exist_ok=True)

    print(f"Setting up models for {direction} movement...")

    # 1. Load models (same as before). NOTE: model paths are hard-coded;
    # the text_encoder_path / vae_path arguments are ignored.
    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # Add camera components to DiT: a zero-initialized camera encoder and an
    # identity-initialized projector per block, so the untrained layers are
    # a no-op until the checkpoint below overwrites them.
    dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        block.cam_encoder = nn.Linear(512, dim)  # keep 512-dim camera embedding
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    # Load trained DiT weights (must include the cam_encoder/projector keys,
    # since strict=True).
    dit_state_dict = torch.load(dit_path, map_location="cpu")
    pipe.dit.load_state_dict(dit_state_dict, strict=True)
    pipe = pipe.to(device)
    pipe.scheduler.set_timesteps(50)

    print("Loading condition video...")

    # Load condition video as a (C, T, H, W) tensor, or None on failure.
    condition_video = load_video_frames(
        condition_video_path,
        num_frames=condition_frames,
        height=height,
        width=width
    )

    if condition_video is None:
        raise ValueError(f"Failed to load condition video from {condition_video_path}")

    condition_video = condition_video.unsqueeze(0).to(device, dtype=pipe.torch_dtype)

    print("Processing poses...")

    # Generate pose embeddings covering both condition and target frames.
    print(f"Generating {direction} movement poses...")
    camera_embedding = generate_direction_poses(
        direction=direction,
        target_frames=target_frames,
        condition_frames=int(condition_frames/4)  # condition count after temporal compression
    )

    camera_embedding = camera_embedding.unsqueeze(0).to(device, dtype=torch.bfloat16)

    print(f"Camera embedding shape: {camera_embedding.shape}")
    print(f"Generated poses for direction: {direction}")

    print("Encoding inputs...")

    # Encode text prompt
    prompt_emb = pipe.encode_prompt(prompt)

    # Encode condition video into latents (tiled VAE encode).
    condition_latents = pipe.encode_video(condition_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))[0]

    print("Generating video...")

    # Generate target latents with the same channel/spatial dims as the
    # (possibly cropped) condition latents.
    batch_size = 1
    channels = condition_latents.shape[0]
    latent_height = condition_latents.shape[2]
    latent_width = condition_latents.shape[3]
    target_height, target_width = 60, 104  # adjust to your setup

    if latent_height > target_height or latent_width > target_width:
        # Center-crop the condition latents down to the target latent size.
        h_start = (latent_height - target_height) // 2
        w_start = (latent_width - target_width) // 2
        condition_latents = condition_latents[:, :,
                                              h_start:h_start+target_height,
                                              w_start:w_start+target_width]
        latent_height = target_height
        latent_width = target_width
    condition_latents = condition_latents.to(device, dtype=pipe.torch_dtype)
    condition_latents = condition_latents.unsqueeze(0)
    condition_latents = condition_latents + 0.05 * torch.randn_like(condition_latents)  # add slight noise for diversity

    # Initialize target latents with noise
    target_latents = torch.randn(
        batch_size, channels, target_frames, latent_height, latent_width,
        device=device, dtype=pipe.torch_dtype
    )
    print(target_latents.shape)
    print(camera_embedding.shape)
    # Combine condition and target latents along the time axis.
    combined_latents = torch.cat([condition_latents, target_latents], dim=2)
    print(combined_latents.shape)

    # Prepare extra inputs
    extra_input = pipe.prepare_extra_input(combined_latents)

    # Denoising loop: the DiT sees the full (condition + target) sequence,
    # but only the target slice is stepped by the scheduler each iteration.
    timesteps = pipe.scheduler.timesteps

    for i, timestep in enumerate(timesteps):
        print(f"Denoising step {i+1}/{len(timesteps)}")

        # Prepare timestep
        timestep_tensor = timestep.unsqueeze(0).to(device, dtype=pipe.torch_dtype)

        # Predict noise
        with torch.no_grad():
            noise_pred = pipe.dit(
                combined_latents,
                timestep=timestep_tensor,
                cam_emb=camera_embedding,
                **prompt_emb,
                **extra_input
            )

        # Update only target part (frames past the compressed condition span).
        target_noise_pred = noise_pred[:, :, int(condition_frames/4):, :, :]
        target_latents = pipe.scheduler.step(target_noise_pred, timestep, target_latents)

        # Update combined latents
        combined_latents[:, :, int(condition_frames/4):, :, :] = target_latents

    print("Decoding video...")

    # Decode final video (tiled VAE decode).
    final_video = torch.cat([condition_latents, target_latents], dim=2)
    decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))

    # Save video
    print(f"Saving video to {output_path}")

    # Convert to numpy and save: (C, T, H, W) -> (T, H, W, C) in float32.
    video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
    video_np = (video_np * 0.5 + 0.5).clip(0, 1)  # denormalize from [-1, 1] to [0, 1]
    video_np = (video_np * 255).astype(np.uint8)

    with imageio.get_writer(output_path, fps=20) as writer:
        for frame in video_np:
            writer.append_data(frame)

    print(f"Video generation completed! Saved to {output_path}")
|
| 431 |
+
|
| 432 |
+
def main():
    """CLI entry point for direction-controlled NuScenes video inference.

    Parses arguments, derives the output path (direction-suffixed filename
    when --output_path is explicitly None), and runs
    inference_nuscenes_video().
    """
    parser = argparse.ArgumentParser(description="NuScenes Video Generation Inference with Direction Control")
    parser.add_argument("--condition_video", type=str, default="/home/zhuyixuan05/ReCamMaster/nus/videos/4032/right.mp4",
                        help="Path to condition video")
    parser.add_argument("--direction", type=str, default="left_turn",
                        choices=["forward", "backward", "left_turn", "right_turn"],
                        help="Direction of camera movement")
    parser.add_argument("--dit_path", type=str, default="/home/zhuyixuan05/ReCamMaster/nus_dynamic/step15000_dynamic.ckpt",
                        help="Path to trained DiT checkpoint")
    parser.add_argument("--text_encoder_path", type=str,
                        default="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
                        help="Path to text encoder")
    parser.add_argument("--vae_path", type=str,
                        default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
                        help="Path to VAE")
    parser.add_argument("--output_path", type=str, default="nus/infer_results-15000/right_left.mp4",
                        help="Output video path")
    parser.add_argument("--poses_path", type=str, default=None,
                        help="Path to poses.json file (optional, will use direction if not provided)")
    parser.add_argument("--prompt", type=str,
                        default="A car driving scene captured by front camera",
                        help="Text prompt for generation")
    parser.add_argument("--condition_frames", type=int, default=40,
                        help="Number of condition frames")
    # This is the raw (uncompressed) frame count.
    parser.add_argument("--target_frames", type=int, default=8,
                        help="Number of target frames to generate")
    # This one is divided by 4 downstream.
    parser.add_argument("--height", type=int, default=900,
                        help="Video height")
    parser.add_argument("--width", type=int, default=1600,
                        help="Video width")
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device to run inference on")

    args = parser.parse_args()

    source_video = args.condition_video
    source_name = os.path.basename(source_video)
    results_dir = "nus/infer_results"
    os.makedirs(results_dir, exist_ok=True)

    # Derive a direction-tagged output filename when no explicit path was
    # given (note: with the default above this branch never triggers).
    if args.output_path is None:
        stem, ext = os.path.splitext(source_name)
        output_path = os.path.join(results_dir, f"{stem}_{args.direction}{ext}")
    else:
        output_path = args.output_path

    print(f"Output video will be saved to: {output_path}")
    inference_nuscenes_video(
        condition_video_path=args.condition_video,
        dit_path=args.dit_path,
        text_encoder_path=args.text_encoder_path,
        vae_path=args.vae_path,
        output_path=output_path,
        condition_frames=args.condition_frames,
        target_frames=args.target_frames,
        height=args.height,
        width=args.width,
        device=args.device,
        prompt=args.prompt,
        poses_path=args.poses_path,
        direction=args.direction
    )
|
scripts/infer_openx.py
ADDED
|
@@ -0,0 +1,614 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image
|
| 2 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 3 |
+
from torchvision.transforms import v2
|
| 4 |
+
from einops import rearrange
|
| 5 |
+
import os
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import argparse
|
| 9 |
+
import numpy as np
|
| 10 |
+
import imageio
|
| 11 |
+
import copy
|
| 12 |
+
import random
|
| 13 |
+
|
| 14 |
+
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Load a window of pre-encoded video latents from a .pth file.

    The file must contain a dict with key 'latents' holding a [C, T, H, W]
    tensor (VAE-encoded video). Returns the slice
    latents[:, start_frame:start_frame+num_frames] together with the full
    loaded dict (which may carry extra keys such as 'cam_emb').

    Raises:
        ValueError: if the requested window runs past the stored frame count.
    """
    print(f"Loading encoded video from {pth_path}")

    payload = torch.load(pth_path, weights_only=False, map_location="cpu")
    latents = payload['latents']  # [C, T, H, W]

    print(f"Full latents shape: {latents.shape}")
    print(f"Extracting frames {start_frame} to {start_frame + num_frames}")

    end_frame = start_frame + num_frames
    if end_frame > latents.shape[1]:
        raise ValueError(f"Not enough frames: requested {end_frame}, available {latents.shape[1]}")

    window = latents[:, start_frame:end_frame, :, :]
    print(f"Extracted condition latents shape: {window.shape}")

    return window, payload
+
|
| 32 |
+
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Return the pose of camera B relative to camera A: pose_b @ inv(pose_a).

    Both inputs must be 4x4 homogeneous extrinsic matrices. With
    use_torch=True the computation (and result) is torch; otherwise numpy.
    Inputs of the other kind are converted to float32 first.
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        a = pose_a if isinstance(pose_a, torch.Tensor) else torch.from_numpy(pose_a).float()
        b = pose_b if isinstance(pose_b, torch.Tensor) else torch.from_numpy(pose_b).float()
        return torch.matmul(b, torch.inverse(a))

    a = pose_a if isinstance(pose_a, np.ndarray) else np.array(pose_a, dtype=np.float32)
    b = pose_b if isinstance(pose_b, np.ndarray) else np.array(pose_b, dtype=np.float32)
    return np.matmul(b, np.linalg.inv(a))
| 55 |
+
|
| 56 |
+
def replace_dit_model_in_manager():
    """Swap the registered 'wan_video_dit' loader class for WanModelFuture.

    Mutates diffsynth's global `model_loader_configs` in place so that any
    subsequent ModelManager.load_models() call instantiates WanModelFuture
    instead of the stock WanModel for the 'wan_video_dit' entry. Must run
    before the models are loaded.
    """
    from diffsynth.models.wan_video_dit_recam_future import WanModelFuture
    from diffsynth.configs.model_config import model_loader_configs

    for idx, entry in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, names, classes, resource = entry

        # Only entries that register a 'wan_video_dit' model need patching.
        if 'wan_video_dit' not in names:
            continue

        patched_names = []
        patched_classes = []
        for name, cls in zip(names, classes):
            if name == 'wan_video_dit':
                patched_names.append(name)           # name stays the same
                patched_classes.append(WanModelFuture)  # class is swapped
                print(f"✅ 替换了模型类: {name} -> WanModelFuture")
            else:
                patched_names.append(name)
                patched_classes.append(cls)

        # Write the patched tuple back into the global registry.
        model_loader_configs[idx] = (keys_hash, keys_hash_with_shape, patched_names, patched_classes, resource)
|
| 83 |
+
def add_framepack_components(dit_model):
    """Attach FramePack's multi-scale clean-latent embedder to a DiT model.

    No-op if the model already has a `clean_x_embedder`. The embedder follows
    hunyuan_video_packed.py: one Conv3d projection per temporal/spatial
    downsampling scale (1x, 2x, 4x), mapping 16 latent channels to the
    transformer's inner dimension, and is cast to the model's dtype.
    """
    if hasattr(dit_model, 'clean_x_embedder'):
        return

    # Inner dim is read off the first block's attention projection weight.
    inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]

    class CleanXEmbedder(nn.Module):
        def __init__(self, inner_dim):
            super().__init__()
            self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
            self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
            self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

        def forward(self, x, scale="1x"):
            # Dispatch on the requested downsampling scale.
            conv = {"1x": self.proj, "2x": self.proj_2x, "4x": self.proj_4x}.get(scale)
            if conv is None:
                raise ValueError(f"Unsupported scale: {scale}")
            return conv(x)

    embedder = CleanXEmbedder(inner_dim)
    model_dtype = next(dit_model.parameters()).dtype
    dit_model.clean_x_embedder = embedder.to(dtype=model_dtype)
    print("✅ 添加了FramePack的clean_x_embedder组件")
+
|
| 111 |
+
def generate_openx_camera_embeddings_sliding(cam_data, start_frame, current_history_length, new_frames, total_generated, use_real_poses=True):
    """Build the camera-conditioning sequence for OpenX sliding-window inference.

    Returns a bfloat16 tensor of shape [T, 13]: each row is a flattened 3x4
    relative camera pose (12 values) followed by a 0/1 condition-mask bit
    (1 = condition frame, 0 = frame to be generated). Real poses are taken
    from cam_data['extrinsic'] when available (and use_real_poses is True);
    otherwise a small synthetic robot-manipulation motion is synthesized.

    Fix vs. original: the trivial [T, 3, 4] -> [T, 12] flatten is done with
    Tensor.reshape instead of einops.rearrange, dropping an unnecessary
    third-party dependency for this function; the duplicated sequence-length
    computation is hoisted out of the two branches. Output is unchanged.

    NOTE(review): `total_generated` is never used here — kept only for
    caller compatibility; confirm whether it can be dropped.
    """
    time_compression_ratio = 4

    # Frames required by the fixed FramePack index layout:
    # start(1) + 4x(16) + 2x(2) + 1x(1) + frames to generate.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    # The sequence must cover the sliding window AND the FramePack layout,
    # with a floor of 30 frames (same formula as the original, both branches).
    max_needed_frames = max(
        start_frame + current_history_length + new_frames,
        framepack_needed_frames,
        30
    )

    relative_poses = []
    if use_real_poses and cam_data is not None and 'extrinsic' in cam_data:
        print("🔧 使用真实OpenX camera数据")
        cam_extrinsic = cam_data['extrinsic']

        print(f"🔧 计算OpenX camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        for i in range(max_needed_frames):
            # OpenX stores extrinsics at the raw frame rate; latents are 4x
            # temporally compressed, so consecutive latent frames are 4 raw
            # frames apart.
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio

            if next_frame_idx < len(cam_extrinsic):
                relative_cam = compute_relative_pose(cam_extrinsic[frame_idx], cam_extrinsic[next_frame_idx])
                relative_poses.append(torch.as_tensor(relative_cam[:3, :]))
            else:
                # Past the end of the recorded trajectory: zero motion.
                print(f"⚠️ 帧{frame_idx}超出camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))
        pose_kind = "真实"
    else:
        print("🔧 使用OpenX合成camera数据")
        print(f"🔧 生成OpenX合成camera帧数: {max_needed_frames}")

        for i in range(max_needed_frames):
            # Synthetic robot-manipulation motion: tiny translations plus a
            # gentle sinusoidal yaw/pitch wobble (fine-grained manipulation,
            # hence the very small magnitudes).
            forward_speed = 0.001                      # per-frame advance
            lateral_motion = 0.0005 * np.sin(i * 0.05) # slight left/right
            vertical_motion = 0.0003 * np.cos(i * 0.1) # slight up/down

            yaw_change = 0.01 * np.sin(i * 0.03)       # slight yaw
            pitch_change = 0.008 * np.cos(i * 0.04)    # slight pitch

            pose = np.eye(4, dtype=np.float32)

            cos_yaw = np.cos(yaw_change)
            sin_yaw = np.sin(yaw_change)
            cos_pitch = np.cos(pitch_change)
            sin_pitch = np.sin(pitch_change)

            # Combined small-angle rotation (pitch, then yaw).
            pose[0, 0] = cos_yaw
            pose[0, 2] = sin_yaw
            pose[1, 1] = cos_pitch
            pose[1, 2] = -sin_pitch
            pose[2, 0] = -sin_yaw
            pose[2, 1] = sin_pitch
            pose[2, 2] = cos_yaw * cos_pitch

            # Translation: x = lateral, y = vertical, z = forward (negative).
            pose[0, 3] = lateral_motion
            pose[1, 3] = vertical_motion
            pose[2, 3] = -forward_speed

            relative_poses.append(torch.as_tensor(pose[:3, :]))
        pose_kind = "合成"

    # [T, 3, 4] -> [T, 12] row-major flatten (was einops rearrange 'b c d -> b (c d)').
    pose_embedding = torch.stack(relative_poses, dim=0).reshape(max_needed_frames, -1)

    # Condition mask: frames [start_frame, start_frame + history) are conditions.
    mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
    condition_end = min(start_frame + current_history_length, max_needed_frames)
    mask[start_frame:condition_end] = 1.0

    camera_embedding = torch.cat([pose_embedding, mask], dim=1)
    print(f"🔧 OpenX{pose_kind}camera embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
+
|
| 222 |
+
def prepare_framepack_sliding_window_with_camera(history_latents, target_frames_to_generate, camera_embedding_full, start_frame, max_history_frames=49):
    """Assemble FramePack sliding-window inputs (OpenX variant).

    Given the accumulated history latents [C, T, H, W] and the full camera
    embedding sequence [T', 13], builds the fixed FramePack index layout
    [start(1) | 4x(16) | 2x(2) | 1x(1) | target frames], the corresponding
    multi-scale clean-latent tensors (zero-padded on the left when fewer
    than 19 history frames exist), and a camera slice whose last column is
    re-masked to 1 on the condition positions. Returns everything in a dict.
    """
    C, T, H, W = history_latents.shape

    # Fixed index layout; its length also determines the camera slice length.
    total_indices_length = 1 + 16 + 2 + 1 + target_frames_to_generate
    all_indices = torch.arange(0, total_indices_length)
    idx_start, clean_latent_4x_indices, clean_latent_2x_indices, idx_1x, latent_indices = \
        all_indices.split([1, 16, 2, 1, target_frames_to_generate], dim=0)
    clean_latent_indices = torch.cat([idx_start, idx_1x], dim=0)

    # Zero-pad the camera sequence if it is shorter than the layout needs.
    shortfall = total_indices_length - camera_embedding_full.shape[0]
    if shortfall > 0:
        padding = torch.zeros(shortfall, camera_embedding_full.shape[1],
                              dtype=camera_embedding_full.dtype, device=camera_embedding_full.device)
        camera_embedding_full = torch.cat([camera_embedding_full, padding], dim=0)

    # Take the layout-sized prefix and rebuild its condition mask.
    combined_camera = camera_embedding_full[:total_indices_length, :].clone()
    combined_camera[:, -1] = 0.0  # default: everything is a target frame

    # The last `cond_frames` of the 19 clean slots correspond to real history.
    cond_frames = min(T, 19) if T > 0 else 0
    if T > 0:
        combined_camera[19 - cond_frames:19, -1] = 1.0

    print(f"🔧 OpenX Camera mask更新:")
    print(f" - 历史帧数: {T}")
    print(f" - 有效condition帧数: {cond_frames}")

    # Right-align the most recent history frames into the 19 clean slots.
    clean_buffer = torch.zeros(C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device)
    if T > 0:
        clean_buffer[:, 19 - cond_frames:, :, :] = history_latents[:, -cond_frames:, :, :]

    clean_latents_4x = clean_buffer[:, 0:16, :, :]
    clean_latents_2x = clean_buffer[:, 16:18, :, :]
    clean_latents_1x = clean_buffer[:, 18:19, :, :]

    # Anchor frame: the very first history frame (zeros if no history yet).
    if T > 0:
        start_latent = history_latents[:, 0:1, :, :]
    else:
        start_latent = torch.zeros(C, 1, H, W, dtype=history_latents.dtype, device=history_latents.device)

    clean_latents = torch.cat([start_latent, clean_latents_1x], dim=1)

    return {
        'latent_indices': latent_indices,
        'clean_latents': clean_latents,
        'clean_latents_2x': clean_latents_2x,
        'clean_latents_4x': clean_latents_4x,
        'clean_latent_indices': clean_latent_indices,
        'clean_latent_2x_indices': clean_latent_2x_indices,
        'clean_latent_4x_indices': clean_latent_4x_indices,
        'camera_embedding': combined_camera,
        'current_length': T,
        'next_length': T + target_frames_to_generate
    }
+
|
| 290 |
+
def inference_openx_framepack_sliding_window(
    condition_pth_path,
    dit_path,
    output_path="openx_results/output_openx_framepack_sliding.mp4",
    start_frame=0,
    initial_condition_frames=8,
    frames_per_generation=4,
    total_frames_to_generate=32,
    max_history_frames=49,
    device="cuda",
    prompt="A video of robotic manipulation task with camera movement",
    use_real_poses=True,
    # classifier-free guidance parameters
    use_camera_cfg=True,
    camera_guidance_scale=2.0,
    text_guidance_scale=1.0
):
    """OpenX FramePack sliding-window video generation.

    Loads a WanVideo ReCamMaster pipeline patched with FramePack components,
    then autoregressively generates `total_frames_to_generate` latent frames
    in chunks of `frames_per_generation`, conditioning each chunk on a
    sliding window of up to `max_history_frames` history latents and on
    per-frame camera embeddings. Optionally applies camera and/or text
    classifier-free guidance. The decoded video is written to `output_path`.

    Args:
        condition_pth_path: .pth file with pre-encoded latents (dict with
            'latents' [C, T, H, W]; may also carry 'cam_emb' camera data).
        dit_path: checkpoint with fine-tuned DiT weights. Must match the
            patched architecture (cam_encoder/projector/clean_x_embedder) —
            loaded with strict=True.
        start_frame: first latent frame of the conditioning window.
        initial_condition_frames: number of latent frames used as the seed.
        frames_per_generation: latent frames produced per sliding-window step.
        total_frames_to_generate: total latent frames to produce.
        max_history_frames: cap on the history window (first frame is always
            kept as an anchor when trimming).
        use_real_poses: prefer real camera extrinsics from the .pth file.
        use_camera_cfg / camera_guidance_scale: camera CFG switch and scale.
        text_guidance_scale: text CFG scale (> 1.0 enables a negative prompt).
    """
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    print(f"🔧 OpenX FramePack滑动窗口生成开始...")
    print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}")
    print(f"Text guidance scale: {text_guidance_scale}")

    # 1. Model initialization — patch the DiT loader class BEFORE loading.
    replace_dit_model_in_manager()

    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # 2. Add per-block camera encoders (13-dim camera rows -> inner dim).
    #    cam_encoder starts at zero and projector at identity so the patched
    #    model initially reproduces the unpatched behavior.
    dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        block.cam_encoder = nn.Linear(13, dim)
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    # 3. Add FramePack components (clean_x_embedder).
    add_framepack_components(pipe.dit)

    # 4. Load the fine-tuned weights (strict: checkpoint must match patches).
    dit_state_dict = torch.load(dit_path, map_location="cpu")
    pipe.dit.load_state_dict(dit_state_dict, strict=True)
    pipe = pipe.to(device)
    model_dtype = next(pipe.dit.parameters()).dtype

    if hasattr(pipe.dit, 'clean_x_embedder'):
        pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)

    pipe.scheduler.set_timesteps(50)

    # 5. Load the initial conditioning latents.
    print("Loading initial condition frames...")
    initial_latents, encoded_data = load_encoded_video_from_pth(
        condition_pth_path,
        start_frame=start_frame,
        num_frames=initial_condition_frames
    )

    # Center-crop in latent space to the OpenX training resolution.
    target_height, target_width = 60, 104
    C, T, H, W = initial_latents.shape

    if H > target_height or W > target_width:
        h_start = (H - target_height) // 2
        w_start = (W - target_width) // 2
        initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width]
        H, W = target_height, target_width

    history_latents = initial_latents.to(device, dtype=model_dtype)

    print(f"初始history_latents shape: {history_latents.shape}")

    # 6. Encode the prompt (with an empty negative prompt when text CFG is on).
    if text_guidance_scale > 1.0:
        prompt_emb_pos = pipe.encode_prompt(prompt)
        prompt_emb_neg = pipe.encode_prompt("")
        print(f"使用Text CFG,guidance scale: {text_guidance_scale}")
    else:
        prompt_emb_pos = pipe.encode_prompt(prompt)
        prompt_emb_neg = None
        print("不使用Text CFG")

    # 7. Pre-generate the full camera embedding sequence once; each window
    #    step slices/re-masks it inside prepare_framepack_sliding_window_with_camera.
    camera_embedding_full = generate_openx_camera_embeddings_sliding(
        encoded_data.get('cam_emb', None),
        0,
        max_history_frames,
        0,
        0,
        use_real_poses=use_real_poses
    ).to(device, dtype=model_dtype)

    print(f"完整camera序列shape: {camera_embedding_full.shape}")

    # 8. Zeroed camera embedding as the unconditional branch for camera CFG.
    if use_camera_cfg:
        camera_embedding_uncond = torch.zeros_like(camera_embedding_full)
        print(f"创建无条件camera embedding用于CFG")

    # 9. Sliding-window generation loop.
    total_generated = 0
    all_generated_frames = []

    while total_generated < total_frames_to_generate:
        current_generation = min(frames_per_generation, total_frames_to_generate - total_generated)
        print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}")
        print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}")

        # FramePack input preparation (clean latents + re-masked camera slice).
        framepack_data = prepare_framepack_sliding_window_with_camera(
            history_latents,
            current_generation,
            camera_embedding_full,
            start_frame,
            max_history_frames
        )

        # Add the batch dimension expected by the DiT.
        clean_latents = framepack_data['clean_latents'].unsqueeze(0)
        clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0)
        clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0)
        camera_embedding = framepack_data['camera_embedding'].unsqueeze(0)

        # Unconditional camera slice for CFG, matched to the window length.
        if use_camera_cfg:
            camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0)

        # Index tensors (kept on CPU — presumably what the DiT expects;
        # NOTE(review): confirm against WanModelFuture's forward signature).
        latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu()
        clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu()
        clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu()
        clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu()

        # Fresh noise for the frames to be generated this step.
        new_latents = torch.randn(
            1, C, current_generation, H, W,
            device=device, dtype=model_dtype
        )

        extra_input = pipe.prepare_extra_input(new_latents)

        print(f"Camera embedding shape: {camera_embedding.shape}")
        print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}")

        # Denoising loop with optional camera and text CFG.
        timesteps = pipe.scheduler.timesteps

        for i, timestep in enumerate(timesteps):
            if i % 10 == 0:
                print(f" 去噪步骤 {i}/{len(timesteps)}")

            timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)

            with torch.no_grad():
                # Conditional prediction (camera + positive prompt).
                noise_pred_pos = pipe.dit(
                    new_latents,
                    timestep=timestep_tensor,
                    cam_emb=camera_embedding,
                    latent_indices=latent_indices,
                    clean_latents=clean_latents,
                    clean_latent_indices=clean_latent_indices,
                    clean_latents_2x=clean_latents_2x,
                    clean_latent_2x_indices=clean_latent_2x_indices,
                    clean_latents_4x=clean_latents_4x,
                    clean_latent_4x_indices=clean_latent_4x_indices,
                    **prompt_emb_pos,
                    **extra_input
                )

                # Camera CFG: extrapolate away from the zero-camera prediction.
                if use_camera_cfg and camera_guidance_scale > 1.0:
                    noise_pred_uncond = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding_uncond_batch,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

                    noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_pos - noise_pred_uncond)
                else:
                    noise_pred = noise_pred_pos

                # Text CFG: extrapolate away from the empty-prompt prediction.
                if prompt_emb_neg is not None and text_guidance_scale > 1.0:
                    noise_pred_text_uncond = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_neg,
                        **extra_input
                    )

                    noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond)

            new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents)

        # Append the denoised chunk to the history.
        new_latents_squeezed = new_latents.squeeze(0)
        history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1)

        # Maintain the sliding window: always keep the first (anchor) frame
        # plus the most recent max_history_frames-1 frames.
        if history_latents.shape[1] > max_history_frames:
            first_frame = history_latents[:, 0:1, :, :]
            recent_frames = history_latents[:, -(max_history_frames-1):, :, :]
            history_latents = torch.cat([first_frame, recent_frames], dim=1)
            print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧")

        print(f"更新后history_latents shape: {history_latents.shape}")

        all_generated_frames.append(new_latents_squeezed)
        total_generated += current_generation

        print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧")

    # 10. Decode the full latent sequence (seed + generated) and save.
    print("\n🔧 解码生成的视频...")

    all_generated = torch.cat(all_generated_frames, dim=1)
    final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0)

    print(f"最终视频shape: {final_video.shape}")

    decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))

    print(f"Saving video to {output_path}")

    # [-1, 1] float -> uint8 frames in THWC order for imageio.
    video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
    video_np = (video_np * 0.5 + 0.5).clip(0, 1)
    video_np = (video_np * 255).astype(np.uint8)

    with imageio.get_writer(output_path, fps=20) as writer:
        for frame in video_np:
            writer.append_data(frame)

    print(f"🔧 OpenX FramePack滑动窗口生成完成! 保存到: {output_path}")
    print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧")
|
| 556 |
+
def main():
    """CLI entry point: parse arguments and run OpenX sliding-window inference."""
    cli = argparse.ArgumentParser(description="OpenX FramePack滑动窗口视频生成")

    # Basic arguments.
    cli.add_argument("--condition_pth", type=str,
                     default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded/episode_000001/encoded_video.pth",
                     help="输入编码视频路径")
    cli.add_argument("--start_frame", type=int, default=0)
    cli.add_argument("--initial_condition_frames", type=int, default=16)
    cli.add_argument("--frames_per_generation", type=int, default=8)
    cli.add_argument("--total_frames_to_generate", type=int, default=24)
    cli.add_argument("--max_history_frames", type=int, default=100)
    cli.add_argument("--use_real_poses", action="store_true", default=False)
    cli.add_argument("--dit_path", type=str,
                     default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/openx/openx_framepack/step2000.ckpt",
                     help="训练好的模型权重路径")
    cli.add_argument("--output_path", type=str,
                     default='openx_results/output_openx_framepack_sliding.mp4')
    cli.add_argument("--prompt", type=str,
                     default="A video of robotic manipulation task with camera movement")
    cli.add_argument("--device", type=str, default="cuda")

    # CFG arguments.
    # NOTE(review): store_true with default=True means this flag can never be
    # turned off from the command line — confirm whether that is intended.
    cli.add_argument("--use_camera_cfg", action="store_true", default=True,
                     help="使用Camera CFG")
    cli.add_argument("--camera_guidance_scale", type=float, default=2.0,
                     help="Camera guidance scale for CFG")
    cli.add_argument("--text_guidance_scale", type=float, default=1.0,
                     help="Text guidance scale for CFG")

    args = cli.parse_args()

    print(f"🔧 OpenX FramePack CFG生成设置:")
    print(f"Camera CFG: {args.use_camera_cfg}")
    if args.use_camera_cfg:
        print(f"Camera guidance scale: {args.camera_guidance_scale}")
        print(f"Text guidance scale: {args.text_guidance_scale}")
    print(f"OpenX特有特性: camera间隔为4帧,适用于机器人操作任务")

    run_kwargs = dict(
        condition_pth_path=args.condition_pth,
        dit_path=args.dit_path,
        output_path=args.output_path,
        start_frame=args.start_frame,
        initial_condition_frames=args.initial_condition_frames,
        frames_per_generation=args.frames_per_generation,
        total_frames_to_generate=args.total_frames_to_generate,
        max_history_frames=args.max_history_frames,
        device=args.device,
        prompt=args.prompt,
        use_real_poses=args.use_real_poses,
        # CFG parameters
        use_camera_cfg=args.use_camera_cfg,
        camera_guidance_scale=args.camera_guidance_scale,
        text_guidance_scale=args.text_guidance_scale,
    )
    inference_openx_framepack_sliding_window(**run_kwargs)

if __name__ == "__main__":
    main()
|
scripts/infer_origin.py
ADDED
|
@@ -0,0 +1,1108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import numpy as np
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import imageio
|
| 7 |
+
import json
|
| 8 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 9 |
+
import argparse
|
| 10 |
+
from torchvision.transforms import v2
|
| 11 |
+
from einops import rearrange
|
| 12 |
+
import copy
|
| 13 |
+
|
| 14 |
+
def compute_relative_pose_matrix(pose1, pose2):
    """Compute the relative pose between two adjacent frames as a 3x4 matrix [R_rel | t_rel].

    Args:
        pose1: Camera pose of frame i, shape (7,): [tx, ty, tz, qx, qy, qz, qw].
        pose2: Camera pose of frame i+1, same layout.

    Returns:
        np.ndarray of shape (3, 4): the relative rotation R_rel in the first
        three columns and the relative translation t_rel (expressed in frame i's
        coordinate frame) in the last column.
    """
    # Bug fix: `R` was referenced but never imported anywhere in this file,
    # which made every call raise NameError. Import scipy's Rotation locally
    # so only this helper depends on scipy.
    from scipy.spatial.transform import Rotation as R

    # Split translation and quaternion; asarray also accepts plain lists/tuples.
    pose1 = np.asarray(pose1, dtype=np.float64)
    pose2 = np.asarray(pose2, dtype=np.float64)
    t1 = pose1[:3]  # translation of frame i
    q1 = pose1[3:]  # quaternion of frame i, scalar-last [qx, qy, qz, qw] (scipy's from_quat order)
    t2 = pose2[:3]  # translation of frame i+1
    q2 = pose2[3:]  # quaternion of frame i+1

    # 1. Relative rotation: R_rel = R2 * R1^{-1}
    rot1 = R.from_quat(q1)
    rot2 = R.from_quat(q2)
    rot_rel = rot2 * rot1.inv()
    R_rel = rot_rel.as_matrix()  # 3x3

    # 2. Relative translation expressed in frame i: t_rel = R1^T @ (t2 - t1)
    #    (for a rotation matrix the transpose equals the inverse)
    R1_T = rot1.as_matrix().T
    t_rel = R1_T @ (t2 - t1)

    # 3. Assemble the 3x4 matrix [R_rel | t_rel]
    relative_matrix = np.hstack([R_rel, t_rel.reshape(3, 1)])

    return relative_matrix
| 45 |
+
|
| 46 |
+
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Load pre-encoded video latents from a .pth file and slice a frame window.

    Args:
        pth_path: Path to a torch checkpoint containing a 'latents' tensor of
            shape [C, T, H, W].
        start_frame: First latent frame of the window (inclusive).
        num_frames: Number of latent frames to extract.

    Returns:
        Tuple of (window latents [C, num_frames, H, W], the full loaded dict).

    Raises:
        ValueError: If the requested window extends past the stored frames.
    """
    print(f"Loading encoded video from {pth_path}")

    payload = torch.load(pth_path, weights_only=False, map_location="cpu")
    full_latents = payload['latents']  # [C, T, H, W]

    print(f"Full latents shape: {full_latents.shape}")
    print(f"Extracting frames {start_frame} to {start_frame + num_frames}")

    # Guard: the window must fit entirely within the stored temporal extent.
    end = start_frame + num_frames
    if end > full_latents.shape[1]:
        raise ValueError(f"Not enough frames: requested {start_frame + num_frames}, available {full_latents.shape[1]}")

    window = full_latents[:, start_frame:end, :, :]
    print(f"Extracted condition latents shape: {window.shape}")

    return window, payload
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Return the pose of camera B relative to camera A: pose_b @ inverse(pose_a)."""
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        # Torch path: coerce numpy inputs to float tensors first.
        mat_a = pose_a if isinstance(pose_a, torch.Tensor) else torch.from_numpy(pose_a).float()
        mat_b = pose_b if isinstance(pose_b, torch.Tensor) else torch.from_numpy(pose_b).float()
        return torch.matmul(mat_b, torch.inverse(mat_a))

    # NumPy path: coerce non-array inputs to float32 arrays.
    mat_a = pose_a if isinstance(pose_a, np.ndarray) else np.array(pose_a, dtype=np.float32)
    mat_b = pose_b if isinstance(pose_b, np.ndarray) else np.array(pose_b, dtype=np.float32)
    return np.matmul(mat_b, np.linalg.inv(mat_a))
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def replace_dit_model_in_manager():
    """Swap the registered 'wan_video_dit' model class for the MoE variant.

    Rewrites diffsynth's global ``model_loader_configs`` table in place so that
    subsequent model loading instantiates ``WanModelMoe`` instead of the stock
    DiT class. All other entries are left untouched.
    """
    from diffsynth.models.wan_video_dit_moe import WanModelMoe
    from diffsynth.configs.model_config import model_loader_configs

    for idx, entry in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, names, classes, resource = entry

        # Only rewrite config entries that actually register the DiT model.
        if 'wan_video_dit' not in names:
            continue

        updated_names = []
        updated_classes = []
        for name, cls in zip(names, classes):
            updated_names.append(name)
            if name == 'wan_video_dit':
                updated_classes.append(WanModelMoe)
                print(f"✅ 替换了模型类: {name} -> WanModelMoe")
            else:
                updated_classes.append(cls)

        # Tuples are immutable, so replace the whole entry in the global list.
        model_loader_configs[idx] = (keys_hash, keys_hash_with_shape, updated_names, updated_classes, resource)
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def add_framepack_components(dit_model):
    """Attach FramePack's clean-latent embedder to ``dit_model`` (idempotent).

    Adds a ``clean_x_embedder`` module that projects 16-channel clean (history)
    latents into the transformer width at three temporal/spatial compressions
    (1x, 2x, 4x). Does nothing if the attribute already exists.
    """
    if hasattr(dit_model, 'clean_x_embedder'):
        return

    # Transformer hidden width, read off the first block's attention projection.
    inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]

    class CleanXEmbedder(nn.Module):
        """Projects clean latents at 1x / 2x / 4x compression into inner_dim."""

        def __init__(self, inner_dim):
            super().__init__()
            self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
            self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
            self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

        def forward(self, x, scale="1x"):
            # Dispatch to the matching projection; cast input to its dtype.
            layers = {"1x": self.proj, "2x": self.proj_2x, "4x": self.proj_4x}
            layer = layers.get(scale)
            if layer is None:
                raise ValueError(f"Unsupported scale: {scale}")
            return layer(x.to(layer.weight.dtype))

    embedder = CleanXEmbedder(inner_dim)
    # Match the host model's parameter dtype (e.g. bfloat16).
    model_dtype = next(dit_model.parameters()).dtype
    dit_model.clean_x_embedder = embedder.to(dtype=model_dtype)
    print("✅ 添加了FramePack的clean_x_embedder组件")
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def add_moe_components(dit_model, moe_config):
    """Attach MoE components (modality processors, global router, per-block experts).

    Args:
        dit_model: The DiT model to augment in place.
        moe_config: Dict with optional keys ``unified_dim`` (default 25),
            ``num_experts`` (default 4) and ``top_k``.
    """
    if not hasattr(dit_model, 'moe_config'):
        dit_model.moe_config = moe_config
        print("✅ 添加了MoE配置到模型")
    # NOTE(review): top_k defaults to 1 here but to 2 for the per-block MoE
    # below — confirm this asymmetry is intentional.
    dit_model.top_k = moe_config.get("top_k", 1)

    # Transformer hidden width, read from the first block's attention projection.
    dim = dit_model.blocks[0].self_attn.q.weight.shape[0]
    unified_dim = moe_config.get("unified_dim", 25)
    num_experts = moe_config.get("num_experts", 4)
    from diffsynth.models.wan_video_dit_moe import ModalityProcessor, MultiModalMoE
    # One processor per dataset modality, all mapping into the unified space.
    dit_model.sekai_processor = ModalityProcessor("sekai", 13, unified_dim)
    dit_model.nuscenes_processor = ModalityProcessor("nuscenes", 8, unified_dim)
    dit_model.openx_processor = ModalityProcessor("openx", 13, unified_dim)  # OpenX uses 13-dim input like sekai, but is processed independently
    dit_model.global_router = nn.Linear(unified_dim, num_experts)

    for i, block in enumerate(dit_model.blocks):
        # Per-block MoE: consumes unified_dim features, emits the block width.
        block.moe = MultiModalMoE(
            unified_dim=unified_dim,
            output_dim=dim,
            num_experts=num_experts,
            top_k=moe_config.get("top_k", 2)
        )

        print(f"✅ Block {i} 添加了MoE组件 (unified_dim: {unified_dim}, experts: {moe_config.get('num_experts', 4)})")
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def generate_sekai_camera_embeddings_sliding(cam_data, start_frame, current_history_length, new_frames, total_generated, use_real_poses=True,direction="left"):
    """Build per-latent-frame camera embeddings for the Sekai dataset (sliding-window version).

    Returns a [max_needed_frames, 13] bfloat16 tensor: a flattened 3x4 relative
    pose (12 values) per latent frame, plus one condition-mask column
    (1.0 = condition/history frame, 0.0 = frame to be generated).

    NOTE(review): when real poses are unavailable and ``direction`` is neither
    "left" nor "right", the function falls through and implicitly returns
    None — confirm callers never pass other directions.
    """
    # Each latent frame corresponds to 4 original video frames (VAE temporal compression).
    time_compression_ratio = 4

    # Camera frames FramePack consumes: 1 start anchor + 16 (4x history)
    # + 2 (2x history) + 1 (1x history) + newly generated frames.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    if use_real_poses and cam_data is not None and 'extrinsic' in cam_data:
        print("🔧 使用真实Sekai camera数据")
        cam_extrinsic = cam_data['extrinsic']

        # Make the camera sequence long enough for both the sliding window and
        # FramePack's fixed index layout (with a floor of 30 frames).
        max_needed_frames = max(
            start_frame + current_history_length + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 计算Sekai camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        relative_poses = []
        for i in range(max_needed_frames):
            # Map the latent frame index back to the original video frame index.
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio

            if next_frame_idx < len(cam_extrinsic):
                cam_prev = cam_extrinsic[frame_idx]
                cam_next = cam_extrinsic[next_frame_idx]
                # Relative pose of the next frame w.r.t. the previous one;
                # keep only the top 3x4 [R | t] part of the 4x4 matrix.
                relative_pose = compute_relative_pose(cam_prev, cam_next)
                relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
            else:
                # Past the end of the recorded poses: fall back to zero motion.
                print(f"⚠️ 帧{frame_idx}超出camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))

        pose_embedding = torch.stack(relative_poses, dim=0)
        # Flatten each 3x4 pose into a 12-dim vector.
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')

        # Condition mask: frames [start_frame, start_frame + history) are condition frames.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 Sekai真实camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    else:
        if direction=="left":
            print("-----Left-------")

            max_needed_frames = max(
                start_frame + current_history_length + new_frames,
                framepack_needed_frames,
                30
            )

            print(f"🔧 生成Sekai合成camera帧数: {max_needed_frames}")
            relative_poses = []
            for i in range(max_needed_frames):
                # Continuous left-turn motion pattern.
                yaw_per_frame = 0.05  # per-frame left turn (positive angle = left)
                forward_speed = 0.05  # per-frame forward distance

                pose = np.eye(4, dtype=np.float32)

                # Rotation matrix (left turn about the Y axis).
                cos_yaw = np.cos(yaw_per_frame)
                sin_yaw = np.sin(yaw_per_frame)

                pose[0, 0] = cos_yaw
                pose[0, 2] = sin_yaw
                pose[2, 0] = -sin_yaw
                pose[2, 2] = cos_yaw

                # Translation: advance in the rotated local frame.
                pose[2, 3] = -forward_speed  # negative local Z = forward

                # Slight centripetal drift to mimic a circular trajectory.
                radius_drift = 0.002  # small drift toward the circle center
                pose[0, 3] = -radius_drift  # negative local X = leftward

                relative_pose = pose[:3, :]
                relative_poses.append(torch.as_tensor(relative_pose))

            pose_embedding = torch.stack(relative_poses, dim=0)
            pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')

            # Condition mask over the synthetic sequence.
            mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
            condition_end = min(start_frame + current_history_length, max_needed_frames)
            mask[start_frame:condition_end] = 1.0

            camera_embedding = torch.cat([pose_embedding, mask], dim=1)
            print(f"🔧 Sekai合成camera embedding shape: {camera_embedding.shape}")
            return camera_embedding.to(torch.bfloat16)
        elif direction=="right":
            print("------------Right----------")

            max_needed_frames = max(
                start_frame + current_history_length + new_frames,
                framepack_needed_frames,
                30
            )

            print(f"🔧 生成Sekai合成camera帧数: {max_needed_frames}")
            relative_poses = []
            for i in range(max_needed_frames):
                # NOTE(review): yaw is -0.00 (no rotation) and drift is 0.0, so
                # "right" currently produces straight forward motion at
                # 0.1/frame rather than a right turn — confirm intended.
                yaw_per_frame = -0.00
                forward_speed = 0.1  # per-frame forward distance

                pose = np.eye(4, dtype=np.float32)

                # Rotation matrix (about the Y axis) — identity for zero yaw.
                cos_yaw = np.cos(yaw_per_frame)
                sin_yaw = np.sin(yaw_per_frame)

                pose[0, 0] = cos_yaw
                pose[0, 2] = sin_yaw
                pose[2, 0] = -sin_yaw
                pose[2, 2] = cos_yaw

                # Translation: advance in the rotated local frame.
                pose[2, 3] = -forward_speed  # negative local Z = forward

                # Lateral drift term (currently zero).
                radius_drift = 0.000
                pose[0, 3] = radius_drift

                relative_pose = pose[:3, :]
                relative_poses.append(torch.as_tensor(relative_pose))

            pose_embedding = torch.stack(relative_poses, dim=0)
            pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')

            # Condition mask over the synthetic sequence.
            mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
            condition_end = min(start_frame + current_history_length, max_needed_frames)
            mask[start_frame:condition_end] = 1.0

            camera_embedding = torch.cat([pose_embedding, mask], dim=1)
            print(f"🔧 Sekai合成camera embedding shape: {camera_embedding.shape}")
            return camera_embedding.to(torch.bfloat16)
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def generate_openx_camera_embeddings_sliding(encoded_data, start_frame, current_history_length, new_frames, use_real_poses):
    """Build per-latent-frame camera embeddings for the OpenX dataset (sliding-window version).

    Same output layout as the Sekai variant: [max_needed_frames, 13] bfloat16
    tensor (12-dim flattened 3x4 relative pose + 1 condition-mask column).
    Real poses are read from ``encoded_data['cam_emb']['extrinsic']`` when
    available; otherwise a synthetic robot-manipulation motion is generated.
    """
    # Each latent frame corresponds to 4 original video frames.
    time_compression_ratio = 4

    # Camera frames FramePack consumes: 1 anchor + 16 (4x) + 2 (2x) + 1 (1x) + new frames.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    if use_real_poses and encoded_data is not None and 'cam_emb' in encoded_data and 'extrinsic' in encoded_data['cam_emb']:
        print("🔧 使用OpenX真实camera数据")
        cam_extrinsic = encoded_data['cam_emb']['extrinsic']

        # Make the camera sequence long enough for window + FramePack layout (floor 30).
        max_needed_frames = max(
            start_frame + current_history_length + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 计算OpenX camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        relative_poses = []
        for i in range(max_needed_frames):
            # OpenX also uses the 4-frame stride, like sekai, over shorter sequences.
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio

            if next_frame_idx < len(cam_extrinsic):
                cam_prev = cam_extrinsic[frame_idx]
                cam_next = cam_extrinsic[next_frame_idx]
                # Keep only the top 3x4 [R | t] part of the relative 4x4 pose.
                relative_pose = compute_relative_pose(cam_prev, cam_next)
                relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
            else:
                # Past the end of the recorded poses: fall back to zero motion.
                print(f"⚠️ 帧{frame_idx}超出OpenX camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))

        pose_embedding = torch.stack(relative_poses, dim=0)
        # Flatten each 3x4 pose into a 12-dim vector.
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')

        # Condition mask: frames [start_frame, start_frame + history) are condition frames.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 OpenX真实camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    else:
        print("🔧 使用OpenX合成camera数据")

        max_needed_frames = max(
            start_frame + current_history_length + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 生成OpenX合成camera帧数: {max_needed_frames}")
        relative_poses = []
        for i in range(max_needed_frames):
            # Synthetic robot-manipulation motion: small per-frame rotations and
            # a slow translation, mimicking fine robot-arm movement.
            roll_per_frame = 0.02   # slight roll
            pitch_per_frame = 0.01  # slight pitch
            yaw_per_frame = 0.015   # slight yaw
            forward_speed = 0.003   # slow forward speed

            pose = np.eye(4, dtype=np.float32)

            # Precompute sines/cosines for the composite rotation.
            cos_roll = np.cos(roll_per_frame)
            sin_roll = np.sin(roll_per_frame)
            cos_pitch = np.cos(pitch_per_frame)
            sin_pitch = np.sin(pitch_per_frame)
            cos_yaw = np.cos(yaw_per_frame)
            sin_yaw = np.sin(yaw_per_frame)

            # Composite rotation matrix in ZYX (yaw-pitch-roll) order.
            pose[0, 0] = cos_yaw * cos_pitch
            pose[0, 1] = cos_yaw * sin_pitch * sin_roll - sin_yaw * cos_roll
            pose[0, 2] = cos_yaw * sin_pitch * cos_roll + sin_yaw * sin_roll
            pose[1, 0] = sin_yaw * cos_pitch
            pose[1, 1] = sin_yaw * sin_pitch * sin_roll + cos_yaw * cos_roll
            pose[1, 2] = sin_yaw * sin_pitch * cos_roll - cos_yaw * sin_roll
            pose[2, 0] = -sin_pitch
            pose[2, 1] = cos_pitch * sin_roll
            pose[2, 2] = cos_pitch * cos_roll

            # Translation: small lateral components, main motion along depth (Z).
            pose[0, 3] = forward_speed * 0.5
            pose[1, 3] = forward_speed * 0.3
            pose[2, 3] = -forward_speed

            relative_pose = pose[:3, :]
            relative_poses.append(torch.as_tensor(relative_pose))

        pose_embedding = torch.stack(relative_poses, dim=0)
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')

        # Condition mask over the synthetic sequence.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 OpenX合成camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)
| 440 |
+
|
| 441 |
+
|
| 442 |
+
def generate_nuscenes_camera_embeddings_sliding(scene_info, start_frame, current_history_length, new_frames):
    """Build per-frame pose embeddings for NuScenes (sliding-window version).

    Unlike the Sekai/OpenX variants, this emits [max_needed_frames, 8]: a 7-dim
    absolute-ish pose (3 translation + 4 quaternion) plus one condition-mask
    column. Kept consistent with train_moe.py.

    NOTE(review): the real-pose branch uses the dataset's rotation as-is while
    the synthetic branch builds [qw, 0, 0, qz] — confirm both use the same
    quaternion convention downstream.
    """
    # Declared for parity with the other generators; not used in this function.
    time_compression_ratio = 4

    # Camera frames FramePack consumes: 1 anchor + 16 (4x) + 2 (2x) + 1 (1x) + new frames.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    if scene_info is not None and 'keyframe_poses' in scene_info:
        print("🔧 使用NuScenes真实pose数据")
        keyframe_poses = scene_info['keyframe_poses']

        if len(keyframe_poses) == 0:
            # Degenerate case: no keyframes recorded — emit all-zero poses.
            print("⚠️ NuScenes keyframe_poses为空,使用零pose")
            max_needed_frames = max(framepack_needed_frames, 30)

            pose_sequence = torch.zeros(max_needed_frames, 7, dtype=torch.float32)

            mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
            condition_end = min(start_frame + current_history_length, max_needed_frames)
            mask[start_frame:condition_end] = 1.0

            camera_embedding = torch.cat([pose_sequence, mask], dim=1)  # [max_needed_frames, 8]
            print(f"🔧 NuScenes零pose embedding shape: {camera_embedding.shape}")
            return camera_embedding.to(torch.bfloat16)

        # Use the first keyframe pose as the reference frame.
        reference_pose = keyframe_poses[0]

        max_needed_frames = max(framepack_needed_frames, 30)

        pose_vecs = []
        for i in range(max_needed_frames):
            if i < len(keyframe_poses):
                current_pose = keyframe_poses[i]

                # Translation relative to the reference keyframe.
                translation = torch.tensor(
                    np.array(current_pose['translation']) - np.array(reference_pose['translation']),
                    dtype=torch.float32
                )

                # Rotation taken as-is (simplified: not made relative to the reference).
                rotation = torch.tensor(current_pose['rotation'], dtype=torch.float32)

                pose_vec = torch.cat([translation, rotation], dim=0)  # [7D]
            else:
                # Past the recorded keyframes: zero translation + identity quaternion.
                pose_vec = torch.cat([
                    torch.zeros(3, dtype=torch.float32),
                    torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32)
                ], dim=0)  # [7D]

            pose_vecs.append(pose_vec)

        pose_sequence = torch.stack(pose_vecs, dim=0)  # [max_needed_frames, 7]

        # Condition mask: frames [start_frame, start_frame + history) are condition frames.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_sequence, mask], dim=1)  # [max_needed_frames, 8]
        print(f"🔧 NuScenes真实pose embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    else:
        print("🔧 使用NuScenes合成pose数据")
        max_needed_frames = max(framepack_needed_frames, 30)

        # Synthesize a left-turn driving trajectory.
        pose_vecs = []
        for i in range(max_needed_frames):
            # Left-turn pattern, like an urban left-hand bend:
            # 0.04 rad of turn per frame on a 15 m radius arc.
            angle = i * 0.04
            radius = 15.0

            # Position on the circular arc.
            x = radius * np.sin(angle)
            y = 0.0  # stay in the horizontal plane
            z = radius * (1 - np.cos(angle))

            translation = torch.tensor([x, y, z], dtype=torch.float32)

            # Heading always tangent to the arc, offset from the initial direction.
            yaw = angle + np.pi/2
            # Quaternion for a rotation about the Y axis.
            rotation = torch.tensor([
                np.cos(yaw/2),  # w (real part)
                0.0,            # x
                0.0,            # y
                np.sin(yaw/2)   # z (imaginary part, about Y)
            ], dtype=torch.float32)

            pose_vec = torch.cat([translation, rotation], dim=0)  # [7D: tx,ty,tz,qw,qx,qy,qz]
            pose_vecs.append(pose_vec)

        pose_sequence = torch.stack(pose_vecs, dim=0)

        # Condition mask over the synthetic sequence.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_sequence, mask], dim=1)  # [max_needed_frames, 8]
        print(f"🔧 NuScenes合成左转pose embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)
| 548 |
+
|
| 549 |
+
def prepare_framepack_sliding_window_with_camera_moe(history_latents, target_frames_to_generate, camera_embedding_full, start_frame, modality_type, max_history_frames=49):
|
| 550 |
+
"""FramePack滑动窗口机制 - MoE版本"""
|
| 551 |
+
# history_latents: [C, T, H, W] 当前的历史latents
|
| 552 |
+
C, T, H, W = history_latents.shape
|
| 553 |
+
|
| 554 |
+
# 固定索引结构(这决定了需要的camera帧数)
|
| 555 |
+
total_indices_length = 1 + 16 + 2 + 1 + target_frames_to_generate
|
| 556 |
+
indices = torch.arange(0, total_indices_length)
|
| 557 |
+
split_sizes = [1, 16, 2, 1, target_frames_to_generate]
|
| 558 |
+
clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = \
|
| 559 |
+
indices.split(split_sizes, dim=0)
|
| 560 |
+
clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=0)
|
| 561 |
+
|
| 562 |
+
# 检查camera长度是否足够
|
| 563 |
+
if camera_embedding_full.shape[0] < total_indices_length:
|
| 564 |
+
shortage = total_indices_length - camera_embedding_full.shape[0]
|
| 565 |
+
padding = torch.zeros(shortage, camera_embedding_full.shape[1],
|
| 566 |
+
dtype=camera_embedding_full.dtype, device=camera_embedding_full.device)
|
| 567 |
+
camera_embedding_full = torch.cat([camera_embedding_full, padding], dim=0)
|
| 568 |
+
|
| 569 |
+
# 从完整camera序列中选取对应部分
|
| 570 |
+
combined_camera = camera_embedding_full[:total_indices_length, :].clone()
|
| 571 |
+
|
| 572 |
+
# 根据当前history length重新设置mask
|
| 573 |
+
combined_camera[:, -1] = 0.0 # 先全部设为target (0)
|
| 574 |
+
|
| 575 |
+
# 设置condition mask:前19帧根据实际历史长度决定
|
| 576 |
+
if T > 0:
|
| 577 |
+
available_frames = min(T, 19)
|
| 578 |
+
start_pos = 19 - available_frames
|
| 579 |
+
combined_camera[start_pos:19, -1] = 1.0 # 将有效的clean latents对应的camera标记为condition
|
| 580 |
+
|
| 581 |
+
print(f"🔧 MoE Camera mask更新:")
|
| 582 |
+
print(f" - 历史帧数: {T}")
|
| 583 |
+
print(f" - 有效condition帧数: {available_frames if T > 0 else 0}")
|
| 584 |
+
print(f" - 模态类型: {modality_type}")
|
| 585 |
+
|
| 586 |
+
# 处理latents
|
| 587 |
+
clean_latents_combined = torch.zeros(C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device)
|
| 588 |
+
|
| 589 |
+
if T > 0:
|
| 590 |
+
available_frames = min(T, 19)
|
| 591 |
+
start_pos = 19 - available_frames
|
| 592 |
+
clean_latents_combined[:, start_pos:, :, :] = history_latents[:, -available_frames:, :, :]
|
| 593 |
+
|
| 594 |
+
clean_latents_4x = clean_latents_combined[:, 0:16, :, :]
|
| 595 |
+
clean_latents_2x = clean_latents_combined[:, 16:18, :, :]
|
| 596 |
+
clean_latents_1x = clean_latents_combined[:, 18:19, :, :]
|
| 597 |
+
|
| 598 |
+
if T > 0:
|
| 599 |
+
start_latent = history_latents[:, 0:1, :, :]
|
| 600 |
+
else:
|
| 601 |
+
start_latent = torch.zeros(C, 1, H, W, dtype=history_latents.dtype, device=history_latents.device)
|
| 602 |
+
|
| 603 |
+
clean_latents = torch.cat([start_latent, clean_latents_1x], dim=1)
|
| 604 |
+
|
| 605 |
+
return {
|
| 606 |
+
'latent_indices': latent_indices,
|
| 607 |
+
'clean_latents': clean_latents,
|
| 608 |
+
'clean_latents_2x': clean_latents_2x,
|
| 609 |
+
'clean_latents_4x': clean_latents_4x,
|
| 610 |
+
'clean_latent_indices': clean_latent_indices,
|
| 611 |
+
'clean_latent_2x_indices': clean_latent_2x_indices,
|
| 612 |
+
'clean_latent_4x_indices': clean_latent_4x_indices,
|
| 613 |
+
'camera_embedding': combined_camera,
|
| 614 |
+
'modality_type': modality_type, # 新增模态类型信息
|
| 615 |
+
'current_length': T,
|
| 616 |
+
'next_length': T + target_frames_to_generate
|
| 617 |
+
}
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
def inference_moe_framepack_sliding_window(
|
| 621 |
+
condition_pth_path,
|
| 622 |
+
dit_path,
|
| 623 |
+
output_path="moe/infer_results/output_moe_framepack_sliding.mp4",
|
| 624 |
+
start_frame=0,
|
| 625 |
+
initial_condition_frames=8,
|
| 626 |
+
frames_per_generation=4,
|
| 627 |
+
total_frames_to_generate=32,
|
| 628 |
+
max_history_frames=49,
|
| 629 |
+
device="cuda",
|
| 630 |
+
prompt="A video of a scene shot using a pedestrian's front camera while walking",
|
| 631 |
+
modality_type="sekai", # "sekai" 或 "nuscenes"
|
| 632 |
+
use_real_poses=True,
|
| 633 |
+
scene_info_path=None, # 对于NuScenes数据集
|
| 634 |
+
# CFG参数
|
| 635 |
+
use_camera_cfg=True,
|
| 636 |
+
camera_guidance_scale=2.0,
|
| 637 |
+
text_guidance_scale=1.0,
|
| 638 |
+
# MoE参数
|
| 639 |
+
moe_num_experts=4,
|
| 640 |
+
moe_top_k=2,
|
| 641 |
+
moe_hidden_dim=None,
|
| 642 |
+
direction="left",
|
| 643 |
+
use_gt_prompt=True
|
| 644 |
+
):
|
| 645 |
+
"""
|
| 646 |
+
MoE FramePack滑动窗口视频生成 - 支持多模态
|
| 647 |
+
"""
|
| 648 |
+
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
| 649 |
+
print(f"🔧 MoE FramePack滑动窗口生成开始...")
|
| 650 |
+
print(f"模态类型: {modality_type}")
|
| 651 |
+
print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}")
|
| 652 |
+
print(f"Text guidance scale: {text_guidance_scale}")
|
| 653 |
+
print(f"MoE配置: experts={moe_num_experts}, top_k={moe_top_k}")
|
| 654 |
+
|
| 655 |
+
# 1. 模型初始化
|
| 656 |
+
replace_dit_model_in_manager()
|
| 657 |
+
|
| 658 |
+
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
|
| 659 |
+
model_manager.load_models([
|
| 660 |
+
"models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
|
| 661 |
+
"models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
|
| 662 |
+
"models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
|
| 663 |
+
])
|
| 664 |
+
pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")
|
| 665 |
+
|
| 666 |
+
# 2. 添加传统camera编码器(兼容性)
|
| 667 |
+
dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
|
| 668 |
+
for block in pipe.dit.blocks:
|
| 669 |
+
block.cam_encoder = nn.Linear(13, dim)
|
| 670 |
+
block.projector = nn.Linear(dim, dim)
|
| 671 |
+
block.cam_encoder.weight.data.zero_()
|
| 672 |
+
block.cam_encoder.bias.data.zero_()
|
| 673 |
+
block.projector.weight = nn.Parameter(torch.eye(dim))
|
| 674 |
+
block.projector.bias = nn.Parameter(torch.zeros(dim))
|
| 675 |
+
|
| 676 |
+
# 3. 添加FramePack组件
|
| 677 |
+
add_framepack_components(pipe.dit)
|
| 678 |
+
|
| 679 |
+
# 4. 添加MoE组件
|
| 680 |
+
moe_config = {
|
| 681 |
+
"num_experts": moe_num_experts,
|
| 682 |
+
"top_k": moe_top_k,
|
| 683 |
+
"hidden_dim": moe_hidden_dim or dim * 2,
|
| 684 |
+
"sekai_input_dim": 13, # Sekai: 12维pose + 1维mask
|
| 685 |
+
"nuscenes_input_dim": 8, # NuScenes: 7维pose + 1维mask
|
| 686 |
+
"openx_input_dim": 13 # OpenX: 12维pose + 1维mask (类似sekai)
|
| 687 |
+
}
|
| 688 |
+
add_moe_components(pipe.dit, moe_config)
|
| 689 |
+
|
| 690 |
+
# 5. 加载训练好的权重
|
| 691 |
+
dit_state_dict = torch.load(dit_path, map_location="cpu")
|
| 692 |
+
pipe.dit.load_state_dict(dit_state_dict, strict=False) # 使用strict=False以兼容新增的MoE组件
|
| 693 |
+
pipe = pipe.to(device)
|
| 694 |
+
model_dtype = next(pipe.dit.parameters()).dtype
|
| 695 |
+
|
| 696 |
+
if hasattr(pipe.dit, 'clean_x_embedder'):
|
| 697 |
+
pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)
|
| 698 |
+
|
| 699 |
+
pipe.scheduler.set_timesteps(50)
|
| 700 |
+
|
| 701 |
+
# 6. 加载初始条件
|
| 702 |
+
print("Loading initial condition frames...")
|
| 703 |
+
initial_latents, encoded_data = load_encoded_video_from_pth(
|
| 704 |
+
condition_pth_path,
|
| 705 |
+
start_frame=start_frame,
|
| 706 |
+
num_frames=initial_condition_frames
|
| 707 |
+
)
|
| 708 |
+
|
| 709 |
+
# 空间裁剪
|
| 710 |
+
target_height, target_width = 60, 104
|
| 711 |
+
C, T, H, W = initial_latents.shape
|
| 712 |
+
|
| 713 |
+
if H > target_height or W > target_width:
|
| 714 |
+
h_start = (H - target_height) // 2
|
| 715 |
+
w_start = (W - target_width) // 2
|
| 716 |
+
initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width]
|
| 717 |
+
H, W = target_height, target_width
|
| 718 |
+
|
| 719 |
+
history_latents = initial_latents.to(device, dtype=model_dtype)
|
| 720 |
+
|
| 721 |
+
print(f"初始history_latents shape: {history_latents.shape}")
|
| 722 |
+
|
| 723 |
+
# 7. 编码prompt - 支持CFG
|
| 724 |
+
if use_gt_prompt and 'prompt_emb' in encoded_data:
|
| 725 |
+
print("✅ 使用预编码的GT prompt embedding")
|
| 726 |
+
prompt_emb_pos = encoded_data['prompt_emb']
|
| 727 |
+
# 将prompt_emb移到正确的设备和数据类型
|
| 728 |
+
if 'context' in prompt_emb_pos:
|
| 729 |
+
prompt_emb_pos['context'] = prompt_emb_pos['context'].to(device, dtype=model_dtype)
|
| 730 |
+
if 'context_mask' in prompt_emb_pos:
|
| 731 |
+
prompt_emb_pos['context_mask'] = prompt_emb_pos['context_mask'].to(device, dtype=model_dtype)
|
| 732 |
+
|
| 733 |
+
# 如果使用Text CFG,生成负向prompt
|
| 734 |
+
if text_guidance_scale > 1.0:
|
| 735 |
+
prompt_emb_neg = pipe.encode_prompt("")
|
| 736 |
+
print(f"使用Text CFG with GT prompt,guidance scale: {text_guidance_scale}")
|
| 737 |
+
else:
|
| 738 |
+
prompt_emb_neg = None
|
| 739 |
+
print("不使用Text CFG")
|
| 740 |
+
|
| 741 |
+
# 🔧 打印GT prompt文本(如果有)
|
| 742 |
+
if 'prompt' in encoded_data['prompt_emb']:
|
| 743 |
+
gt_prompt_text = encoded_data['prompt_emb']['prompt']
|
| 744 |
+
print(f"📝 GT Prompt文本: {gt_prompt_text}")
|
| 745 |
+
else:
|
| 746 |
+
# 使用传入的prompt参数重新编码
|
| 747 |
+
print(f"🔄 重新编码prompt: {prompt}")
|
| 748 |
+
if text_guidance_scale > 1.0:
|
| 749 |
+
prompt_emb_pos = pipe.encode_prompt(prompt)
|
| 750 |
+
prompt_emb_neg = pipe.encode_prompt("")
|
| 751 |
+
print(f"使用Text CFG,guidance scale: {text_guidance_scale}")
|
| 752 |
+
else:
|
| 753 |
+
prompt_emb_pos = pipe.encode_prompt(prompt)
|
| 754 |
+
prompt_emb_neg = None
|
| 755 |
+
print("不使用Text CFG")
|
| 756 |
+
|
| 757 |
+
# 8. 加载场景信息(对于NuScenes)
|
| 758 |
+
scene_info = None
|
| 759 |
+
if modality_type == "nuscenes" and scene_info_path and os.path.exists(scene_info_path):
|
| 760 |
+
with open(scene_info_path, 'r') as f:
|
| 761 |
+
scene_info = json.load(f)
|
| 762 |
+
print(f"加载NuScenes场景信息: {scene_info_path}")
|
| 763 |
+
|
| 764 |
+
# 9. 预生成完整的camera embedding序列
|
| 765 |
+
if modality_type == "sekai":
|
| 766 |
+
camera_embedding_full = generate_sekai_camera_embeddings_sliding(
|
| 767 |
+
encoded_data.get('cam_emb', None),
|
| 768 |
+
0,
|
| 769 |
+
max_history_frames,
|
| 770 |
+
0,
|
| 771 |
+
0,
|
| 772 |
+
use_real_poses=use_real_poses,
|
| 773 |
+
direction=direction
|
| 774 |
+
).to(device, dtype=model_dtype)
|
| 775 |
+
elif modality_type == "nuscenes":
|
| 776 |
+
camera_embedding_full = generate_nuscenes_camera_embeddings_sliding(
|
| 777 |
+
scene_info,
|
| 778 |
+
0,
|
| 779 |
+
max_history_frames,
|
| 780 |
+
0
|
| 781 |
+
).to(device, dtype=model_dtype)
|
| 782 |
+
elif modality_type == "openx":
|
| 783 |
+
camera_embedding_full = generate_openx_camera_embeddings_sliding(
|
| 784 |
+
encoded_data,
|
| 785 |
+
0,
|
| 786 |
+
max_history_frames,
|
| 787 |
+
0,
|
| 788 |
+
use_real_poses=use_real_poses
|
| 789 |
+
).to(device, dtype=model_dtype)
|
| 790 |
+
else:
|
| 791 |
+
raise ValueError(f"不支持的模态类型: {modality_type}")
|
| 792 |
+
|
| 793 |
+
print(f"完整camera序列shape: {camera_embedding_full.shape}")
|
| 794 |
+
|
| 795 |
+
# 10. 为Camera CFG创建无条件的camera embedding
|
| 796 |
+
if use_camera_cfg:
|
| 797 |
+
camera_embedding_uncond = torch.zeros_like(camera_embedding_full)
|
| 798 |
+
print(f"创建无条件camera embedding用于CFG")
|
| 799 |
+
|
| 800 |
+
# 11. 滑动窗口生成循环
|
| 801 |
+
total_generated = 0
|
| 802 |
+
all_generated_frames = []
|
| 803 |
+
|
| 804 |
+
while total_generated < total_frames_to_generate:
|
| 805 |
+
current_generation = min(frames_per_generation, total_frames_to_generate - total_generated)
|
| 806 |
+
print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}")
|
| 807 |
+
print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}")
|
| 808 |
+
|
| 809 |
+
# FramePack数据准备 - MoE版本
|
| 810 |
+
framepack_data = prepare_framepack_sliding_window_with_camera_moe(
|
| 811 |
+
history_latents,
|
| 812 |
+
current_generation,
|
| 813 |
+
camera_embedding_full,
|
| 814 |
+
start_frame,
|
| 815 |
+
modality_type,
|
| 816 |
+
max_history_frames
|
| 817 |
+
)
|
| 818 |
+
|
| 819 |
+
# 准备输入
|
| 820 |
+
clean_latents = framepack_data['clean_latents'].unsqueeze(0)
|
| 821 |
+
clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0)
|
| 822 |
+
clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0)
|
| 823 |
+
camera_embedding = framepack_data['camera_embedding'].unsqueeze(0)
|
| 824 |
+
|
| 825 |
+
# 准备modality_inputs
|
| 826 |
+
modality_inputs = {modality_type: camera_embedding}
|
| 827 |
+
|
| 828 |
+
# 为CFG准备无条件camera embedding
|
| 829 |
+
if use_camera_cfg:
|
| 830 |
+
camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0)
|
| 831 |
+
modality_inputs_uncond = {modality_type: camera_embedding_uncond_batch}
|
| 832 |
+
|
| 833 |
+
# 索引处理
|
| 834 |
+
latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu()
|
| 835 |
+
clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu()
|
| 836 |
+
clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu()
|
| 837 |
+
clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu()
|
| 838 |
+
|
| 839 |
+
# 初始化要生成的latents
|
| 840 |
+
new_latents = torch.randn(
|
| 841 |
+
1, C, current_generation, H, W,
|
| 842 |
+
device=device, dtype=model_dtype
|
| 843 |
+
)
|
| 844 |
+
|
| 845 |
+
extra_input = pipe.prepare_extra_input(new_latents)
|
| 846 |
+
|
| 847 |
+
print(f"Camera embedding shape: {camera_embedding.shape}")
|
| 848 |
+
print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}")
|
| 849 |
+
|
| 850 |
+
# 去噪循环 - 支持CFG
|
| 851 |
+
timesteps = pipe.scheduler.timesteps
|
| 852 |
+
|
| 853 |
+
for i, timestep in enumerate(timesteps):
|
| 854 |
+
if i % 10 == 0:
|
| 855 |
+
print(f" 去噪步骤 {i+1}/{len(timesteps)}")
|
| 856 |
+
|
| 857 |
+
timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)
|
| 858 |
+
|
| 859 |
+
with torch.no_grad():
|
| 860 |
+
# CFG推理
|
| 861 |
+
if use_camera_cfg and camera_guidance_scale > 1.0:
|
| 862 |
+
# 条件预测(有camera)
|
| 863 |
+
noise_pred_cond, moe_loess = pipe.dit(
|
| 864 |
+
new_latents,
|
| 865 |
+
timestep=timestep_tensor,
|
| 866 |
+
cam_emb=camera_embedding,
|
| 867 |
+
modality_inputs=modality_inputs, # MoE模态输入
|
| 868 |
+
latent_indices=latent_indices,
|
| 869 |
+
clean_latents=clean_latents,
|
| 870 |
+
clean_latent_indices=clean_latent_indices,
|
| 871 |
+
clean_latents_2x=clean_latents_2x,
|
| 872 |
+
clean_latent_2x_indices=clean_latent_2x_indices,
|
| 873 |
+
clean_latents_4x=clean_latents_4x,
|
| 874 |
+
clean_latent_4x_indices=clean_latent_4x_indices,
|
| 875 |
+
**prompt_emb_pos,
|
| 876 |
+
**extra_input
|
| 877 |
+
)
|
| 878 |
+
|
| 879 |
+
# 无条件预测(无camera)
|
| 880 |
+
noise_pred_uncond, moe_loess = pipe.dit(
|
| 881 |
+
new_latents,
|
| 882 |
+
timestep=timestep_tensor,
|
| 883 |
+
cam_emb=camera_embedding_uncond_batch,
|
| 884 |
+
modality_inputs=modality_inputs_uncond, # MoE无条件模态输入
|
| 885 |
+
latent_indices=latent_indices,
|
| 886 |
+
clean_latents=clean_latents,
|
| 887 |
+
clean_latent_indices=clean_latent_indices,
|
| 888 |
+
clean_latents_2x=clean_latents_2x,
|
| 889 |
+
clean_latent_2x_indices=clean_latent_2x_indices,
|
| 890 |
+
clean_latents_4x=clean_latents_4x,
|
| 891 |
+
clean_latent_4x_indices=clean_latent_4x_indices,
|
| 892 |
+
**(prompt_emb_neg if prompt_emb_neg else prompt_emb_pos),
|
| 893 |
+
**extra_input
|
| 894 |
+
)
|
| 895 |
+
|
| 896 |
+
# Camera CFG
|
| 897 |
+
noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
| 898 |
+
|
| 899 |
+
# 如果同时使用Text CFG
|
| 900 |
+
if text_guidance_scale > 1.0 and prompt_emb_neg:
|
| 901 |
+
noise_pred_text_uncond, moe_loess = pipe.dit(
|
| 902 |
+
new_latents,
|
| 903 |
+
timestep=timestep_tensor,
|
| 904 |
+
cam_emb=camera_embedding,
|
| 905 |
+
modality_inputs=modality_inputs,
|
| 906 |
+
latent_indices=latent_indices,
|
| 907 |
+
clean_latents=clean_latents,
|
| 908 |
+
clean_latent_indices=clean_latent_indices,
|
| 909 |
+
clean_latents_2x=clean_latents_2x,
|
| 910 |
+
clean_latent_2x_indices=clean_latent_2x_indices,
|
| 911 |
+
clean_latents_4x=clean_latents_4x,
|
| 912 |
+
clean_latent_4x_indices=clean_latent_4x_indices,
|
| 913 |
+
**prompt_emb_neg,
|
| 914 |
+
**extra_input
|
| 915 |
+
)
|
| 916 |
+
|
| 917 |
+
# 应用Text CFG到已经应用Camera CFG的结果
|
| 918 |
+
noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond)
|
| 919 |
+
|
| 920 |
+
elif text_guidance_scale > 1.0 and prompt_emb_neg:
|
| 921 |
+
# 只使用Text CFG
|
| 922 |
+
noise_pred_cond, moe_loess = pipe.dit(
|
| 923 |
+
new_latents,
|
| 924 |
+
timestep=timestep_tensor,
|
| 925 |
+
cam_emb=camera_embedding,
|
| 926 |
+
modality_inputs=modality_inputs,
|
| 927 |
+
latent_indices=latent_indices,
|
| 928 |
+
clean_latents=clean_latents,
|
| 929 |
+
clean_latent_indices=clean_latent_indices,
|
| 930 |
+
clean_latents_2x=clean_latents_2x,
|
| 931 |
+
clean_latent_2x_indices=clean_latent_2x_indices,
|
| 932 |
+
clean_latents_4x=clean_latents_4x,
|
| 933 |
+
clean_latent_4x_indices=clean_latent_4x_indices,
|
| 934 |
+
**prompt_emb_pos,
|
| 935 |
+
**extra_input
|
| 936 |
+
)
|
| 937 |
+
|
| 938 |
+
noise_pred_uncond, moe_loess= pipe.dit(
|
| 939 |
+
new_latents,
|
| 940 |
+
timestep=timestep_tensor,
|
| 941 |
+
cam_emb=camera_embedding,
|
| 942 |
+
modality_inputs=modality_inputs,
|
| 943 |
+
latent_indices=latent_indices,
|
| 944 |
+
clean_latents=clean_latents,
|
| 945 |
+
clean_latent_indices=clean_latent_indices,
|
| 946 |
+
clean_latents_2x=clean_latents_2x,
|
| 947 |
+
clean_latent_2x_indices=clean_latent_2x_indices,
|
| 948 |
+
clean_latents_4x=clean_latents_4x,
|
| 949 |
+
clean_latent_4x_indices=clean_latent_4x_indices,
|
| 950 |
+
**prompt_emb_neg,
|
| 951 |
+
**extra_input
|
| 952 |
+
)
|
| 953 |
+
|
| 954 |
+
noise_pred = noise_pred_uncond + text_guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
| 955 |
+
|
| 956 |
+
else:
|
| 957 |
+
# 标准推理(无CFG)
|
| 958 |
+
noise_pred, moe_loess = pipe.dit(
|
| 959 |
+
new_latents,
|
| 960 |
+
timestep=timestep_tensor,
|
| 961 |
+
cam_emb=camera_embedding,
|
| 962 |
+
modality_inputs=modality_inputs, # MoE模态输入
|
| 963 |
+
latent_indices=latent_indices,
|
| 964 |
+
clean_latents=clean_latents,
|
| 965 |
+
clean_latent_indices=clean_latent_indices,
|
| 966 |
+
clean_latents_2x=clean_latents_2x,
|
| 967 |
+
clean_latent_2x_indices=clean_latent_2x_indices,
|
| 968 |
+
clean_latents_4x=clean_latents_4x,
|
| 969 |
+
clean_latent_4x_indices=clean_latent_4x_indices,
|
| 970 |
+
**prompt_emb_pos,
|
| 971 |
+
**extra_input
|
| 972 |
+
)
|
| 973 |
+
|
| 974 |
+
new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents)
|
| 975 |
+
|
| 976 |
+
# 更新历史
|
| 977 |
+
new_latents_squeezed = new_latents.squeeze(0)
|
| 978 |
+
history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1)
|
| 979 |
+
|
| 980 |
+
# 维护滑动窗口
|
| 981 |
+
if history_latents.shape[1] > max_history_frames:
|
| 982 |
+
first_frame = history_latents[:, 0:1, :, :]
|
| 983 |
+
recent_frames = history_latents[:, -(max_history_frames-1):, :, :]
|
| 984 |
+
history_latents = torch.cat([first_frame, recent_frames], dim=1)
|
| 985 |
+
print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧")
|
| 986 |
+
|
| 987 |
+
print(f"更新后history_latents shape: {history_latents.shape}")
|
| 988 |
+
|
| 989 |
+
all_generated_frames.append(new_latents_squeezed)
|
| 990 |
+
total_generated += current_generation
|
| 991 |
+
|
| 992 |
+
print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧")
|
| 993 |
+
|
| 994 |
+
# 12. 解码和保存
|
| 995 |
+
print("\n🔧 解码生成的视频...")
|
| 996 |
+
|
| 997 |
+
all_generated = torch.cat(all_generated_frames, dim=1)
|
| 998 |
+
final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0)
|
| 999 |
+
|
| 1000 |
+
print(f"最终视频shape: {final_video.shape}")
|
| 1001 |
+
|
| 1002 |
+
decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))
|
| 1003 |
+
|
| 1004 |
+
print(f"Saving video to {output_path}")
|
| 1005 |
+
|
| 1006 |
+
video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
|
| 1007 |
+
video_np = (video_np * 0.5 + 0.5).clip(0, 1)
|
| 1008 |
+
video_np = (video_np * 255).astype(np.uint8)
|
| 1009 |
+
|
| 1010 |
+
with imageio.get_writer(output_path, fps=20) as writer:
|
| 1011 |
+
for frame in video_np:
|
| 1012 |
+
writer.append_data(frame)
|
| 1013 |
+
|
| 1014 |
+
print(f"🔧 MoE FramePack滑动窗口生成完成! 保存到: {output_path}")
|
| 1015 |
+
print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧")
|
| 1016 |
+
print(f"使用模态: {modality_type}")
|
| 1017 |
+
|
| 1018 |
+
|
| 1019 |
+
def main():
|
| 1020 |
+
parser = argparse.ArgumentParser(description="MoE FramePack滑动窗口视频生成 - 支持多模态")
|
| 1021 |
+
|
| 1022 |
+
# 基础参数
|
| 1023 |
+
parser.add_argument("--condition_pth", type=str,
|
| 1024 |
+
#default="/share_zhuyixuan05/zhuyixuan05/sekai-game-drone/00500210001_0012150_0012450/encoded_video.pth")
|
| 1025 |
+
default="/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic/scenes/scene-0001_CAM_FRONT/encoded_video-480p.pth")
|
| 1026 |
+
#default="/share_zhuyixuan05/zhuyixuan05/spatialvid/a9a6d37f-0a6c-548a-a494-7d902469f3f2_0000000_0000300/encoded_video.pth")
|
| 1027 |
+
#default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded/episode_000001/encoded_video.pth")
|
| 1028 |
+
parser.add_argument("--start_frame", type=int, default=0)
|
| 1029 |
+
parser.add_argument("--initial_condition_frames", type=int, default=16)
|
| 1030 |
+
parser.add_argument("--frames_per_generation", type=int, default=8)
|
| 1031 |
+
parser.add_argument("--total_frames_to_generate", type=int, default=24)
|
| 1032 |
+
parser.add_argument("--max_history_frames", type=int, default=100)
|
| 1033 |
+
parser.add_argument("--use_real_poses", default=False)
|
| 1034 |
+
parser.add_argument("--dit_path", type=str,
|
| 1035 |
+
default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step175000_origin_other_continue3.ckpt")
|
| 1036 |
+
parser.add_argument("--output_path", type=str,
|
| 1037 |
+
default='/home/zhuyixuan05/ReCamMaster/moe/infer_results/output_moe_framepack_sliding.mp4')
|
| 1038 |
+
parser.add_argument("--prompt", type=str,
|
| 1039 |
+
default="A car is driving")
|
| 1040 |
+
parser.add_argument("--device", type=str, default="cuda")
|
| 1041 |
+
|
| 1042 |
+
# 模态类型参数
|
| 1043 |
+
parser.add_argument("--modality_type", type=str, choices=["sekai", "nuscenes", "openx"], default="nuscenes",
|
| 1044 |
+
help="模态类型:sekai 或 nuscenes 或 openx")
|
| 1045 |
+
parser.add_argument("--scene_info_path", type=str, default=None,
|
| 1046 |
+
help="NuScenes场景信息文件路径(仅用于nuscenes模态)")
|
| 1047 |
+
|
| 1048 |
+
# CFG参数
|
| 1049 |
+
parser.add_argument("--use_camera_cfg", default=False,
|
| 1050 |
+
help="使用Camera CFG")
|
| 1051 |
+
parser.add_argument("--camera_guidance_scale", type=float, default=2.0,
|
| 1052 |
+
help="Camera guidance scale for CFG")
|
| 1053 |
+
parser.add_argument("--text_guidance_scale", type=float, default=1.0,
|
| 1054 |
+
help="Text guidance scale for CFG")
|
| 1055 |
+
|
| 1056 |
+
# MoE参数
|
| 1057 |
+
parser.add_argument("--moe_num_experts", type=int, default=3, help="专家数量")
|
| 1058 |
+
parser.add_argument("--moe_top_k", type=int, default=1, help="Top-K专家")
|
| 1059 |
+
parser.add_argument("--moe_hidden_dim", type=int, default=None, help="MoE隐藏层维度")
|
| 1060 |
+
parser.add_argument("--direction", type=str, default="left")
|
| 1061 |
+
parser.add_argument("--use_gt_prompt", action="store_true", default=False,
|
| 1062 |
+
help="使用数据集中的ground truth prompt embedding")
|
| 1063 |
+
|
| 1064 |
+
args = parser.parse_args()
|
| 1065 |
+
|
| 1066 |
+
print(f"🔧 MoE FramePack CFG生成设置:")
|
| 1067 |
+
print(f"模态类型: {args.modality_type}")
|
| 1068 |
+
print(f"Camera CFG: {args.use_camera_cfg}")
|
| 1069 |
+
if args.use_camera_cfg:
|
| 1070 |
+
print(f"Camera guidance scale: {args.camera_guidance_scale}")
|
| 1071 |
+
print(f"使用GT Prompt: {args.use_gt_prompt}")
|
| 1072 |
+
print(f"Text guidance scale: {args.text_guidance_scale}")
|
| 1073 |
+
print(f"MoE配置: experts={args.moe_num_experts}, top_k={args.moe_top_k}")
|
| 1074 |
+
print(f"DiT{args.dit_path}")
|
| 1075 |
+
|
| 1076 |
+
# 验证NuScenes参数
|
| 1077 |
+
if args.modality_type == "nuscenes" and not args.scene_info_path:
|
| 1078 |
+
print("⚠️ 使用NuScenes模态但未提供scene_info_path,将使用合成pose数据")
|
| 1079 |
+
|
| 1080 |
+
inference_moe_framepack_sliding_window(
|
| 1081 |
+
condition_pth_path=args.condition_pth,
|
| 1082 |
+
dit_path=args.dit_path,
|
| 1083 |
+
output_path=args.output_path,
|
| 1084 |
+
start_frame=args.start_frame,
|
| 1085 |
+
initial_condition_frames=args.initial_condition_frames,
|
| 1086 |
+
frames_per_generation=args.frames_per_generation,
|
| 1087 |
+
total_frames_to_generate=args.total_frames_to_generate,
|
| 1088 |
+
max_history_frames=args.max_history_frames,
|
| 1089 |
+
device=args.device,
|
| 1090 |
+
prompt=args.prompt,
|
| 1091 |
+
modality_type=args.modality_type,
|
| 1092 |
+
use_real_poses=args.use_real_poses,
|
| 1093 |
+
scene_info_path=args.scene_info_path,
|
| 1094 |
+
# CFG参数
|
| 1095 |
+
use_camera_cfg=args.use_camera_cfg,
|
| 1096 |
+
camera_guidance_scale=args.camera_guidance_scale,
|
| 1097 |
+
text_guidance_scale=args.text_guidance_scale,
|
| 1098 |
+
# MoE参数
|
| 1099 |
+
moe_num_experts=args.moe_num_experts,
|
| 1100 |
+
moe_top_k=args.moe_top_k,
|
| 1101 |
+
moe_hidden_dim=args.moe_hidden_dim,
|
| 1102 |
+
direction=args.direction,
|
| 1103 |
+
use_gt_prompt=args.use_gt_prompt
|
| 1104 |
+
)
|
| 1105 |
+
|
| 1106 |
+
|
| 1107 |
+
if __name__ == "__main__":
|
| 1108 |
+
main()
|
scripts/infer_recam.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from diffsynth import ModelManager, WanVideoReCamMasterPipeline, save_video, VideoData
|
| 5 |
+
import torch, os, imageio, argparse
|
| 6 |
+
from torchvision.transforms import v2
|
| 7 |
+
from einops import rearrange
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import torchvision
|
| 10 |
+
from PIL import Image
|
| 11 |
+
import numpy as np
|
| 12 |
+
import json
|
| 13 |
+
|
| 14 |
+
class Camera(object):
|
| 15 |
+
def __init__(self, c2w):
|
| 16 |
+
c2w_mat = np.array(c2w).reshape(4, 4)
|
| 17 |
+
self.c2w_mat = c2w_mat
|
| 18 |
+
self.w2c_mat = np.linalg.inv(c2w_mat)
|
| 19 |
+
|
| 20 |
+
class TextVideoCameraDataset(torch.utils.data.Dataset):
|
| 21 |
+
def __init__(self, base_path, metadata_path, args, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False, condition_frames=40, target_frames=20):
|
| 22 |
+
metadata = pd.read_csv(metadata_path)
|
| 23 |
+
self.path = [os.path.join(base_path, "videos", file_name) for file_name in metadata["file_name"]]
|
| 24 |
+
self.text = metadata["text"].to_list()
|
| 25 |
+
|
| 26 |
+
self.max_num_frames = max_num_frames
|
| 27 |
+
self.frame_interval = frame_interval
|
| 28 |
+
self.num_frames = num_frames
|
| 29 |
+
self.height = height
|
| 30 |
+
self.width = width
|
| 31 |
+
self.is_i2v = is_i2v
|
| 32 |
+
self.args = args
|
| 33 |
+
self.cam_type = self.args.cam_type
|
| 34 |
+
|
| 35 |
+
# 🔧 新增:保存帧数配置
|
| 36 |
+
self.condition_frames = condition_frames
|
| 37 |
+
self.target_frames = target_frames
|
| 38 |
+
|
| 39 |
+
self.frame_process = v2.Compose([
|
| 40 |
+
v2.CenterCrop(size=(height, width)),
|
| 41 |
+
v2.Resize(size=(height, width), antialias=True),
|
| 42 |
+
v2.ToTensor(),
|
| 43 |
+
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
| 44 |
+
])
|
| 45 |
+
|
| 46 |
+
def crop_and_resize(self, image):
|
| 47 |
+
width, height = image.size
|
| 48 |
+
scale = max(self.width / width, self.height / height)
|
| 49 |
+
image = torchvision.transforms.functional.resize(
|
| 50 |
+
image,
|
| 51 |
+
(round(height*scale), round(width*scale)),
|
| 52 |
+
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
|
| 53 |
+
)
|
| 54 |
+
return image
|
| 55 |
+
|
| 56 |
+
def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process):
|
| 57 |
+
reader = imageio.get_reader(file_path)
|
| 58 |
+
if reader.count_frames() < max_num_frames or reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval:
|
| 59 |
+
reader.close()
|
| 60 |
+
return None
|
| 61 |
+
|
| 62 |
+
frames = []
|
| 63 |
+
first_frame = None
|
| 64 |
+
for frame_id in range(num_frames):
|
| 65 |
+
frame = reader.get_data(start_frame_id + frame_id * interval)
|
| 66 |
+
frame = Image.fromarray(frame)
|
| 67 |
+
frame = self.crop_and_resize(frame)
|
| 68 |
+
if first_frame is None:
|
| 69 |
+
first_frame = np.array(frame)
|
| 70 |
+
frame = frame_process(frame)
|
| 71 |
+
frames.append(frame)
|
| 72 |
+
reader.close()
|
| 73 |
+
|
| 74 |
+
frames = torch.stack(frames, dim=0)
|
| 75 |
+
frames = rearrange(frames, "T C H W -> C T H W")
|
| 76 |
+
|
| 77 |
+
if self.is_i2v:
|
| 78 |
+
return frames, first_frame
|
| 79 |
+
else:
|
| 80 |
+
return frames
|
| 81 |
+
|
| 82 |
+
def is_image(self, file_path):
|
| 83 |
+
file_ext_name = file_path.split(".")[-1]
|
| 84 |
+
if file_ext_name.lower() in ["jpg", "jpeg", "png", "webp"]:
|
| 85 |
+
return True
|
| 86 |
+
return False
|
| 87 |
+
|
| 88 |
+
def load_video(self, file_path):
|
| 89 |
+
start_frame_id = torch.randint(0, self.max_num_frames - (self.num_frames - 1) * self.frame_interval, (1,))[0]
|
| 90 |
+
frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, self.frame_interval, self.num_frames, self.frame_process)
|
| 91 |
+
return frames
|
| 92 |
+
|
| 93 |
+
def parse_matrix(self, matrix_str):
|
| 94 |
+
rows = matrix_str.strip().split('] [')
|
| 95 |
+
matrix = []
|
| 96 |
+
for row in rows:
|
| 97 |
+
row = row.replace('[', '').replace(']', '')
|
| 98 |
+
matrix.append(list(map(float, row.split())))
|
| 99 |
+
return np.array(matrix)
|
| 100 |
+
|
| 101 |
+
def get_relative_pose(self, cam_params):
|
| 102 |
+
abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
|
| 103 |
+
abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
|
| 104 |
+
|
| 105 |
+
cam_to_origin = 0
|
| 106 |
+
target_cam_c2w = np.array([
|
| 107 |
+
[1, 0, 0, 0],
|
| 108 |
+
[0, 1, 0, -cam_to_origin],
|
| 109 |
+
[0, 0, 1, 0],
|
| 110 |
+
[0, 0, 0, 1]
|
| 111 |
+
])
|
| 112 |
+
abs2rel = target_cam_c2w @ abs_w2cs[0]
|
| 113 |
+
ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
|
| 114 |
+
ret_poses = np.array(ret_poses, dtype=np.float32)
|
| 115 |
+
return ret_poses
|
| 116 |
+
|
| 117 |
+
def __getitem__(self, data_id):
    """Return one sample: text prompt, an 81-frame video tensor, its path,
    and a per-frame camera pose embedding of length ``self.target_frames``.
    """
    text = self.text[data_id]
    path = self.path[data_id]
    video = self.load_video(path)
    if video is None:
        raise ValueError(f"{path} is not a valid video.")
    num_frames = video.shape[1]
    # The pipeline assumes exactly 81 frames per clip.
    assert num_frames == 81
    data = {"text": text, "video": video, "path": path}

    # load camera
    # NOTE(review): the camera file path is hard-coded relative to the CWD —
    # confirm this matches args.dataset_path in all deployments.
    tgt_camera_path = "./example_test_data/cameras/camera_extrinsics.json"
    with open(tgt_camera_path, 'r') as file:
        cam_data = json.load(file)

    # 🔧 Changed: build a camera trajectory of length target_frames by
    # sampling frame indices 0..80 evenly.
    cam_idx = np.linspace(0, 80, self.target_frames, dtype=int).tolist() # length target_frames
    traj = [self.parse_matrix(cam_data[f"frame{idx}"][f"cam{int(self.cam_type):02d}"]) for idx in cam_idx]
    traj = np.stack(traj).transpose(0, 2, 1)
    c2ws = []
    for c2w in traj:
        # Reorder axes, flip the y basis vector, and rescale translation
        # (data appears to store positions in centimetres — TODO confirm).
        c2w = c2w[:, [1, 2, 0, 3]]
        c2w[:3, 1] *= -1.
        c2w[:3, 3] /= 100
        c2ws.append(c2w)
    tgt_cam_params = [Camera(cam_param) for cam_param in c2ws]
    relative_poses = []
    for i in range(len(tgt_cam_params)):
        # Each pose is expressed relative to the first camera; keep only the
        # second (i-th) entry's top 3 rows of the returned pair.
        relative_pose = self.get_relative_pose([tgt_cam_params[0], tgt_cam_params[i]])
        relative_poses.append(torch.as_tensor(relative_pose)[:,:3,:][1])
    pose_embedding = torch.stack(relative_poses, dim=0) # [target_frames, 3, 4]
    pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') # [target_frames, 12]
    data['camera'] = pose_embedding.to(torch.bfloat16)
    return data
|
| 151 |
+
|
| 152 |
+
def __len__(self):
    # One sample per video path listed in the metadata.
    return len(self.path)
|
| 154 |
+
|
| 155 |
+
def parse_args():
    """Parse CLI arguments for ReCamMaster inference.

    Returns:
        argparse.Namespace with dataset/checkpoint/output paths, the camera
        type identifier, CFG scale, and condition/target frame counts.
    """
    parser = argparse.ArgumentParser(description="ReCamMaster Inference")
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="./example_test_data",
        help="The path of the Dataset.",
    )
    parser.add_argument(
        "--ckpt_path",
        type=str,
        default="/share_zhuyixuan05/zhuyixuan05/recam_future_checkpoint/step1000.ckpt",
        help="Path to save the model.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./results",
        help="Path to save the results.",
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=1,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    # Fix: the argument is declared type=str but its default was the int 1,
    # so args.cam_type had a different type depending on whether the flag was
    # supplied. Use the string "1" — downstream int(args.cam_type) and
    # f"cam_type{args.cam_type}" produce identical results.
    parser.add_argument(
        "--cam_type",
        type=str,
        default="1",
    )
    parser.add_argument(
        "--cfg_scale",
        type=float,
        default=5.0,
    )
    # Condition/target frame-count parameters.
    parser.add_argument(
        "--condition_frames",
        type=int,
        default=15,
        help="Number of condition frames",
    )
    parser.add_argument(
        "--target_frames",
        type=int,
        default=15,
        help="Number of target frames to generate",
    )
    args = parser.parse_args()
    return args
|
| 206 |
+
|
| 207 |
+
if __name__ == '__main__':
    args = parse_args()

    # 1. Load Wan2.1 pre-trained models
    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # 2. Initialize additional modules introduced in ReCamMaster
    # cam_encoder is zero-initialized and projector identity-initialized so
    # the new camera conditioning is a no-op until the checkpoint overwrites it.
    dim=pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        block.cam_encoder = nn.Linear(12, dim)
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    # 3. Load ReCamMaster checkpoint
    state_dict = torch.load(args.ckpt_path, map_location="cpu")
    pipe.dit.load_state_dict(state_dict, strict=True)
    pipe.to("cuda")
    pipe.to(dtype=torch.bfloat16)

    output_dir = os.path.join(args.output_dir, f"cam_type{args.cam_type}")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # 4. Prepare test data (source video, target camera, target trajectory)
    dataset = TextVideoCameraDataset(
        args.dataset_path,
        os.path.join(args.dataset_path, "metadata.csv"),
        args,
        condition_frames=args.condition_frames, # 🔧 forwarded CLI parameter
        target_frames=args.target_frames, # 🔧 forwarded CLI parameter
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=False,
        batch_size=1,
        num_workers=args.dataloader_num_workers
    )

    # 5. Inference
    for batch_idx, batch in enumerate(dataloader):
        target_text = batch["text"]
        source_video = batch["video"]
        target_camera = batch["camera"]

        # The Chinese negative prompt lists common video-generation artifacts
        # to suppress; it is a runtime string and intentionally left as-is.
        video = pipe(
            prompt=target_text,
            negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的��景,三条腿,背景人很多,倒着走",
            source_video=source_video,
            target_camera=target_camera,
            cfg_scale=args.cfg_scale,
            num_inference_steps=50,
            seed=0,
            tiled=True,
            condition_frames=args.condition_frames,
            target_frames=args.target_frames,
        )
        save_video(video, os.path.join(output_dir, f"video{batch_idx}.mp4"), fps=30, quality=5)
|
scripts/infer_rlbench.py
ADDED
|
@@ -0,0 +1,447 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import imageio
|
| 6 |
+
import json
|
| 7 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 8 |
+
import argparse
|
| 9 |
+
from torchvision.transforms import v2
|
| 10 |
+
from einops import rearrange
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Load a pre-encoded video latent from a .pth file and slice a window.

    Args:
        pth_path: path to the .pth file produced by the encoder.
        start_frame: first latent frame to extract (compressed index).
        num_frames: number of latent frames to extract (compressed index).

    Returns:
        Tuple of (condition latents in [C, T, H, W], the full loaded dict).

    Raises:
        ValueError: if the requested window runs past the available frames.
    """
    print(f"Loading encoded video from {pth_path}")

    payload = torch.load(pth_path, weights_only=False, map_location="cpu")
    latents = payload['latents']  # [C, T, H, W]

    print(f"Full latents shape: {latents.shape}")
    print(f"Extracting frames {start_frame} to {start_frame + num_frames}")

    if start_frame + num_frames > latents.shape[1]:
        raise ValueError(f"Not enough frames: requested {start_frame + num_frames}, available {latents.shape[1]}")

    window = latents[:, start_frame:start_frame + num_frames, :, :]
    print(f"Extracted condition latents shape: {window.shape}")

    return window, payload
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Compute camera B's pose relative to camera A (``B @ inv(A)``).

    Both inputs are 4x4 extrinsic matrices; the computation runs in torch or
    numpy depending on ``use_torch``, converting inputs as needed.
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        a = pose_a if isinstance(pose_a, torch.Tensor) else torch.from_numpy(pose_a).float()
        b = pose_b if isinstance(pose_b, torch.Tensor) else torch.from_numpy(pose_b).float()
        return torch.matmul(b, torch.inverse(a))

    a = pose_a if isinstance(pose_a, np.ndarray) else np.array(pose_a, dtype=np.float32)
    b = pose_b if isinstance(pose_b, np.ndarray) else np.array(pose_b, dtype=np.float32)
    return np.matmul(b, np.linalg.inv(a))
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def generate_camera_poses_from_data(cam_data, start_frame, condition_frames, target_frames):
    """Build a pose embedding from recorded camera data.

    Args:
        cam_data: camera extrinsic sequence indexed by original frame number.
        start_frame: starting frame (compressed latent index).
        condition_frames: number of condition frames (compressed).
        target_frames: number of target frames (compressed).

    Returns:
        bfloat16 tensor ``[total_frames, D+1]`` — per-frame pose data with a
        trailing condition-mask column (1.0 for condition frames).
    """
    time_compression_ratio = 4
    total_frames = condition_frames + target_frames

    # Camera extrinsic sequence.
    # NOTE(review): the original comment says [N, 4, 4], but the torch.cat
    # along dim=1 below only works if each entry is a 1-D vector (stack gives
    # a 2-D [T, D] tensor); with 4x4 matrices this would raise — confirm the
    # stored cam_emb format against the encoder script.
    cam_extrinsic = cam_data # [N, 4, 4]

    # Map compressed latent indices back to original frame indices.
    start_frame_original = start_frame * time_compression_ratio
    end_frame_original = (start_frame + total_frames) * time_compression_ratio

    print(f"Using camera data from frame {start_frame_original} to {end_frame_original}")

    # Collect one pose per compressed frame.
    relative_poses = []
    for i in range(total_frames):
        frame_idx = start_frame_original + i * time_compression_ratio
        # next_frame_idx is computed but unused — presumably a leftover from a
        # frame-to-frame relative-pose variant; kept for byte-identity.
        next_frame_idx = frame_idx + time_compression_ratio

        cam_prev = cam_extrinsic[frame_idx]

        # Despite the original "take first 3 rows" note, the full entry is
        # appended unmodified.
        relative_poses.append(torch.as_tensor(cam_prev))

        print(cam_prev)
    # Assemble the pose embedding.
    pose_embedding = torch.stack(relative_poses, dim=0)
    # print('pose_embedding init:',pose_embedding[0])
    print('pose_embedding:',pose_embedding)
    # assert False

    # pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)') # [frames, 12]

    # Condition-frame mask: 1.0 marks condition frames, 0.0 target frames.
    mask = torch.zeros(total_frames, dtype=torch.float32)
    mask[:condition_frames] = 1.0 # condition frames
    mask = mask.view(-1, 1)

    # Concatenate pose values and mask column.
    camera_embedding = torch.cat([pose_embedding, mask], dim=1) # [frames, 13]

    print(f"Generated camera embedding shape: {camera_embedding.shape}")

    return camera_embedding.to(torch.bfloat16)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def generate_camera_poses(direction="forward", target_frames=10, condition_frames=20):
    """Synthesize a camera pose embedding for a scripted movement.

    Builds ``condition_frames + target_frames`` absolute 4x4 poses along the
    requested trajectory, converts consecutive pairs into relative poses,
    flattens each to 12 values, and appends a condition-mask column — giving
    a ``[frames, 13]`` bfloat16 tensor.
    """
    time_compression_ratio = 4
    total_frames = condition_frames + target_frames

    absolute_poses = []
    denom = max(1, total_frames - 1)
    for step in range(total_frames):
        t = step / denom  # progress in [0, 1]

        pose = np.eye(4, dtype=np.float32)

        if direction == "forward":
            # Advance along negative z.
            pose[2, 3] = -t * 0.04
            print('forward!')

        elif direction == "backward":
            # Retreat along positive z.
            pose[2, 3] = t * 2.0

        elif direction == "left_turn":
            # Move forward, drift left, and yaw around the y axis.
            pose[2, 3] = -t * 0.03
            pose[0, 3] = t * 0.02
            yaw = t * 1
            pose[0, 0] = np.cos(yaw)
            pose[0, 2] = np.sin(yaw)
            pose[2, 0] = -np.sin(yaw)
            pose[2, 2] = np.cos(yaw)

        elif direction == "right_turn":
            # Move forward, drift right, and yaw the opposite way.
            pose[2, 3] = -t * 0.03
            pose[0, 3] = -t * 0.02
            yaw = - t * 1
            pose[0, 0] = np.cos(yaw)
            pose[0, 2] = np.sin(yaw)
            pose[2, 0] = -np.sin(yaw)
            pose[2, 2] = np.cos(yaw)

        absolute_poses.append(pose)

    # Relative pose between each consecutive pair; keep the top 3 rows.
    relative_poses = [
        torch.as_tensor(compute_relative_pose(prev, nxt)[:3, :])
        for prev, nxt in zip(absolute_poses, absolute_poses[1:])
    ]

    # One fewer relative pose than absolute poses: pad with the last one.
    if len(relative_poses) < total_frames:
        relative_poses.append(relative_poses[-1])

    pose_embedding = torch.stack(relative_poses[:total_frames], dim=0)

    print('pose_embedding init:',pose_embedding[0])

    print('pose_embedding:',pose_embedding[-5:])

    pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')  # [frames, 12]

    # Condition frames are flagged with 1.0 in the mask column.
    mask = torch.zeros(total_frames, dtype=torch.float32)
    mask[:condition_frames] = 1.0
    mask = mask.view(-1, 1)

    # Combine pose values with the mask column.
    camera_embedding = torch.cat([pose_embedding, mask], dim=1)  # [frames, 13]

    print(f"Generated {direction} movement poses:")
    print(f"  Total frames: {total_frames}")
    print(f"  Camera embedding shape: {camera_embedding.shape}")

    return camera_embedding.to(torch.bfloat16)
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def inference_sekai_video_from_pth(
    condition_pth_path,
    dit_path,
    output_path="sekai/infer_results/output_sekai.mp4",
    start_frame=0,
    condition_frames=10, # frame count after temporal compression
    target_frames=2, # frame count after temporal compression
    device="cuda",
    prompt="a robotic arm executing precise manipulation tasks on a clean, organized desk",
    direction="forward",
    use_real_poses=True
):
    """Run Sekai video inference from a pre-encoded .pth latent file.

    Loads the Wan2.1 pipeline plus a ReCamMaster-style DiT checkpoint,
    conditions on the latent frames from ``condition_pth_path``, denoises
    ``target_frames`` new latent frames guided by a camera pose embedding
    (real poses from the data or a synthetic trajectory), then decodes and
    writes the full clip to ``output_path``.
    """
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    print(f"Setting up models for {direction} movement...")

    # 1. Load models
    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # Add camera components to DiT
    dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        # NOTE(review): the original comment claimed a 13-dim embedding
        # (12D pose + 1D mask) but the layer accepts 30 inputs — confirm the
        # expected camera-embedding width against the training script.
        block.cam_encoder = nn.Linear(30, dim)
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    # Load trained DiT weights
    dit_state_dict = torch.load(dit_path, map_location="cpu")
    pipe.dit.load_state_dict(dit_state_dict, strict=True)
    pipe = pipe.to(device)
    pipe.scheduler.set_timesteps(50)

    print("Loading condition video from pth...")

    # Load condition video from pth
    condition_latents, encoded_data = load_encoded_video_from_pth(
        condition_pth_path,
        start_frame=start_frame,
        num_frames=condition_frames
    )

    condition_latents = condition_latents.unsqueeze(0).to(device, dtype=pipe.torch_dtype)

    print("Processing poses...")

    # Build the camera pose embedding: prefer recorded poses when available.
    if use_real_poses and 'cam_emb' in encoded_data:
        print("Using real camera poses from data")
        camera_embedding = generate_camera_poses_from_data(
            encoded_data['cam_emb'],
            start_frame=start_frame,
            condition_frames=condition_frames,
            target_frames=target_frames
        )
    else:
        print(f"Using synthetic {direction} poses")
        camera_embedding = generate_camera_poses(
            direction=direction,
            target_frames=target_frames,
            condition_frames=condition_frames
        )

    camera_embedding = camera_embedding.unsqueeze(0).to(device, dtype=torch.bfloat16)

    print(f"Camera embedding shape: {camera_embedding.shape}")

    print("Encoding prompt...")

    # Encode text prompt
    prompt_emb = pipe.encode_prompt(prompt)

    print("Generating video...")

    # Generate target latents
    batch_size = 1
    channels = condition_latents.shape[1]
    latent_height = condition_latents.shape[3]
    latent_width = condition_latents.shape[4]

    # Spatial crop to save memory (if needed)
    target_height, target_width = 64, 64

    if latent_height > target_height or latent_width > target_width:
        # Center crop
        h_start = (latent_height - target_height) // 2
        w_start = (latent_width - target_width) // 2
        condition_latents = condition_latents[:, :, :,
                                              h_start:h_start+target_height,
                                              w_start:w_start+target_width]
        latent_height = target_height
        latent_width = target_width

    # Initialize target latents with noise
    target_latents = torch.randn(
        batch_size, channels, target_frames, latent_height, latent_width,
        device=device, dtype=pipe.torch_dtype
    )

    print(f"Condition latents shape: {condition_latents.shape}")
    print(f"Target latents shape: {target_latents.shape}")
    print(f"Camera embedding shape: {camera_embedding.shape}")

    # Combine condition and target latents along the time axis.
    combined_latents = torch.cat([condition_latents, target_latents], dim=2)
    print(f"Combined latents shape: {combined_latents.shape}")

    # Prepare extra inputs
    extra_input = pipe.prepare_extra_input(combined_latents)

    # Denoising loop: only the target frames are updated each step; the
    # condition frames stay fixed as context.
    timesteps = pipe.scheduler.timesteps

    for i, timestep in enumerate(timesteps):
        print(f"Denoising step {i+1}/{len(timesteps)}")

        # Prepare timestep
        timestep_tensor = timestep.unsqueeze(0).to(device, dtype=pipe.torch_dtype)

        # Predict noise
        with torch.no_grad():
            noise_pred = pipe.dit(
                combined_latents,
                timestep=timestep_tensor,
                cam_emb=camera_embedding,
                **prompt_emb,
                **extra_input
            )

        # Update only target part
        target_noise_pred = noise_pred[:, :, condition_frames:, :, :]
        target_latents = pipe.scheduler.step(target_noise_pred, timestep, target_latents)

        # Update combined latents
        combined_latents[:, :, condition_frames:, :, :] = target_latents

    print("Decoding video...")

    # Decode final video
    final_video = torch.cat([condition_latents, target_latents], dim=2)
    decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))

    # Save video
    print(f"Saving video to {output_path}")

    # Convert to numpy and save
    video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
    video_np = (video_np * 0.5 + 0.5).clip(0, 1) # Denormalize
    video_np = (video_np * 255).astype(np.uint8)

    with imageio.get_writer(output_path, fps=20) as writer:
        for frame in video_np:
            writer.append_data(frame)

    print(f"Video generation completed! Saved to {output_path}")
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
def main():
    """CLI entry point: parse arguments and run Sekai inference from a .pth file."""
    parser = argparse.ArgumentParser(description="Sekai Video Generation Inference from PTH")
    parser.add_argument("--condition_pth", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/rlbench/OpenBox_demo_49/encoded_video.pth")
    parser.add_argument("--start_frame", type=int, default=0,
                        help="Starting frame index (compressed latent frames)")
    parser.add_argument("--condition_frames", type=int, default=8,
                        help="Number of condition frames (compressed latent frames)")
    parser.add_argument("--target_frames", type=int, default=8,
                        help="Number of target frames to generate (compressed latent frames)")
    parser.add_argument("--direction", type=str, default="left_turn",
                        choices=["forward", "backward", "left_turn", "right_turn"],
                        help="Direction of camera movement (if not using real poses)")
    # Fix: the flag was declared with default=False but no action/type, so
    # any supplied value — even the string "False" — parsed as truthy.
    # store_true makes it a real boolean flag; the default (False) is kept.
    parser.add_argument("--use_real_poses", action="store_true",
                        help="Use real camera poses from data")
    parser.add_argument("--dit_path", type=str, default="/home/zhuyixuan05/ReCamMaster/RLBench-train/step2000_dynamic.ckpt",
                        help="Path to trained DiT checkpoint")
    parser.add_argument("--output_path", type=str, default='/home/zhuyixuan05/ReCamMaster/rlbench/infer_results/output_rl_2.mp4',
                        help="Output video path")
    parser.add_argument("--prompt", type=str,
                        default="a robotic arm executing precise manipulation tasks on a clean, organized desk",
                        help="Text prompt for generation")
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device to run inference on")

    args = parser.parse_args()

    # Derive an output filename from the input .pth when no output path is
    # given. (With the current non-None default this branch only triggers if
    # the default is removed or overridden programmatically.)
    if args.output_path is None:
        pth_filename = os.path.basename(args.condition_pth)
        name_parts = os.path.splitext(pth_filename)
        output_dir = "rlbench/infer_results"
        os.makedirs(output_dir, exist_ok=True)

        if args.use_real_poses:
            output_filename = f"{name_parts[0]}_real_poses_{args.start_frame}_{args.condition_frames}_{args.target_frames}.mp4"
        else:
            output_filename = f"{name_parts[0]}_{args.direction}_{args.start_frame}_{args.condition_frames}_{args.target_frames}.mp4"

        output_path = os.path.join(output_dir, output_filename)
    else:
        output_path = args.output_path

    print(f"Input pth: {args.condition_pth}")
    print(f"Start frame: {args.start_frame} (compressed)")
    print(f"Condition frames: {args.condition_frames} (compressed, original: {args.condition_frames * 4})")
    print(f"Target frames: {args.target_frames} (compressed, original: {args.target_frames * 4})")
    print(f"Use real poses: {args.use_real_poses}")
    print(f"Output video will be saved to: {output_path}")

    inference_sekai_video_from_pth(
        condition_pth_path=args.condition_pth,
        dit_path=args.dit_path,
        output_path=output_path,
        start_frame=args.start_frame,
        condition_frames=args.condition_frames,
        target_frames=args.target_frames,
        device=args.device,
        prompt=args.prompt,
        direction=args.direction,
        use_real_poses=args.use_real_poses
    )


if __name__ == "__main__":
    main()
|
scripts/infer_sekai.py
ADDED
|
@@ -0,0 +1,497 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import imageio
|
| 6 |
+
import json
|
| 7 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 8 |
+
import argparse
|
| 9 |
+
from torchvision.transforms import v2
|
| 10 |
+
from einops import rearrange
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Read a pre-encoded latent video and slice out a frame window.

    Parameters
    ----------
    pth_path : path to the .pth file (a dict saved with torch.save).
    start_frame : first latent frame to keep (compressed index).
    num_frames : how many latent frames to keep (compressed index).

    Returns
    -------
    Tuple of (latents [C, T, H, W] restricted to the window, full dict).
    """
    print(f"Loading encoded video from {pth_path}")

    blob = torch.load(pth_path, weights_only=False, map_location="cpu")
    all_latents = blob['latents']  # [C, T, H, W]

    print(f"Full latents shape: {all_latents.shape}")
    print(f"Extracting frames {start_frame} to {start_frame + num_frames}")

    # Fail loudly when the requested window runs past the end.
    if start_frame + num_frames > all_latents.shape[1]:
        raise ValueError(f"Not enough frames: requested {start_frame + num_frames}, available {all_latents.shape[1]}")

    window = all_latents[:, start_frame:start_frame + num_frames, :, :]
    print(f"Extracted condition latents shape: {window.shape}")

    return window, blob
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Return the 4x4 pose of camera B relative to camera A.

    Computes ``pose_b @ inv(pose_a)`` with either torch or numpy depending
    on ``use_torch``; inputs are coerced to the matching array type first.
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        # Coerce numpy inputs to float tensors before inverting.
        if not isinstance(pose_a, torch.Tensor):
            pose_a = torch.from_numpy(pose_a).float()
        if not isinstance(pose_b, torch.Tensor):
            pose_b = torch.from_numpy(pose_b).float()
        return torch.matmul(pose_b, torch.inverse(pose_a))

    # numpy path: coerce anything array-like to float32 ndarrays.
    if not isinstance(pose_a, np.ndarray):
        pose_a = np.array(pose_a, dtype=np.float32)
    if not isinstance(pose_b, np.ndarray):
        pose_b = np.array(pose_b, dtype=np.float32)
    return np.matmul(pose_b, np.linalg.inv(pose_a))
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def generate_camera_poses_from_data(cam_data, start_frame, condition_frames, target_frames):
    """Build a [frames, 13] camera embedding from recorded camera extrinsics.

    Bug fix: the original body contained a leftover debug ``assert False``
    (plus debug prints of ``cam_prev`` and the full pose embedding — the
    ``print(cam_prev)`` could even hit an unbound name when every frame was
    out of range), so the function always crashed before returning. The
    debug leftovers are removed; everything else is unchanged.

    Args:
        cam_data: dict with key 'extrinsic' holding an [N, 4, 4] pose
            sequence sampled at the original (uncompressed) frame rate.
        start_frame: start index in compressed latent frames.
        condition_frames: number of condition frames (compressed).
        target_frames: number of target frames (compressed).

    Returns:
        torch.bfloat16 tensor of shape [condition_frames + target_frames, 13]:
        a flattened 3x4 relative pose (12 values) plus a 1/0 condition mask.
    """
    time_compression_ratio = 4  # one latent frame spans 4 original frames
    total_frames = condition_frames + target_frames

    # Camera extrinsic sequence at original frame rate.
    cam_extrinsic = cam_data['extrinsic']  # [N, 4, 4]

    # Map compressed-frame indices back to original-frame indices.
    start_frame_original = start_frame * time_compression_ratio
    end_frame_original = (start_frame + total_frames) * time_compression_ratio

    print(f"Using camera data from frame {start_frame_original} to {end_frame_original}")

    # Relative pose between consecutive latent frames (stride = compression ratio).
    relative_poses = []
    for i in range(total_frames):
        frame_idx = start_frame_original + i * time_compression_ratio
        next_frame_idx = frame_idx + time_compression_ratio

        if next_frame_idx >= len(cam_extrinsic):
            print('out of temporal range!!!')
            # Out of range: repeat the last available pose (zero motion if none yet).
            relative_poses.append(relative_poses[-1] if relative_poses else torch.zeros(3, 4))
        else:
            cam_prev = cam_extrinsic[frame_idx]
            cam_next = cam_extrinsic[next_frame_idx]

            relative_pose = compute_relative_pose(cam_prev, cam_next)
            relative_poses.append(torch.as_tensor(relative_pose[:3, :]))  # top 3 rows only

    # Assemble the pose embedding: [frames, 3, 4] -> [frames, 12].
    pose_embedding = torch.stack(relative_poses, dim=0)
    pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')  # [frames, 12]

    # Mask: 1 marks condition frames, 0 marks frames to generate.
    mask = torch.zeros(total_frames, dtype=torch.float32)
    mask[:condition_frames] = 1.0  # condition frames
    mask = mask.view(-1, 1)

    # Combine pose and mask into the final embedding.
    camera_embedding = torch.cat([pose_embedding, mask], dim=1)  # [frames, 13]

    print(f"Generated camera embedding shape: {camera_embedding.shape}")

    return camera_embedding.to(torch.bfloat16)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def generate_camera_poses(direction="forward", target_frames=10, condition_frames=20):
    """Synthesize a [frames, 13] camera embedding for a canned motion.

    Builds an absolute pose for every frame according to ``direction``
    ("forward", "backward", "left_turn", "right_turn"), converts them to
    frame-to-frame relative poses, and appends a 1/0 condition mask column.
    Returns a torch.bfloat16 tensor of shape [frames, 13].
    """
    time_compression_ratio = 4  # kept from original (not used below)
    total_frames = condition_frames + target_frames

    def _absolute_pose(step):
        # Normalized progress along the trajectory, in [0, 1].
        t = step / max(1, total_frames - 1)
        pose = np.eye(4, dtype=np.float32)

        if direction == "forward":
            # Advance along -z.
            pose[2, 3] = -t * 0.04
            print('forward!')
        elif direction == "backward":
            # Retreat along +z.
            pose[2, 3] = t * 2.0
        elif direction == "left_turn":
            # Forward motion plus a yaw rotation about the y axis.
            pose[2, 3] = -t * 0.03  # forward
            pose[0, 3] = t * 0.02   # drift left
            yaw = t * 1
            pose[0, 0] = np.cos(yaw)
            pose[0, 2] = np.sin(yaw)
            pose[2, 0] = -np.sin(yaw)
            pose[2, 2] = np.cos(yaw)
        elif direction == "right_turn":
            # Forward motion plus the opposite yaw rotation.
            pose[2, 3] = -t * 0.03   # forward
            pose[0, 3] = -t * 0.02   # drift right
            yaw = - t * 1
            pose[0, 0] = np.cos(yaw)
            pose[0, 2] = np.sin(yaw)
            pose[2, 0] = -np.sin(yaw)
            pose[2, 2] = np.cos(yaw)

        return pose

    poses = [_absolute_pose(i) for i in range(total_frames)]

    # Frame-to-frame relative transforms, keeping only the top 3 rows.
    relative_poses = [
        torch.as_tensor(compute_relative_pose(prev, nxt)[:3, :])
        for prev, nxt in zip(poses, poses[1:])
    ]

    # One fewer relative pose than frames: pad by repeating the last motion.
    if len(relative_poses) < total_frames:
        relative_poses.append(relative_poses[-1])

    pose_embedding = torch.stack(relative_poses[:total_frames], dim=0)

    print('pose_embedding init:', pose_embedding[0])

    print('pose_embedding:', pose_embedding[-5:])

    pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')  # [frames, 12]

    # 1 marks condition frames, 0 marks frames to generate.
    mask = torch.zeros(total_frames, dtype=torch.float32)
    mask[:condition_frames] = 1.0
    mask = mask.view(-1, 1)

    camera_embedding = torch.cat([pose_embedding, mask], dim=1)  # [frames, 13]

    print(f"Generated {direction} movement poses:")
    print(f" Total frames: {total_frames}")
    print(f" Camera embedding shape: {camera_embedding.shape}")

    return camera_embedding.to(torch.bfloat16)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def inference_sekai_video_from_pth(
    condition_pth_path,
    dit_path,
    output_path="sekai/infer_results/output_sekai.mp4",
    start_frame=0,
    condition_frames=10,  # compressed latent frames
    target_frames=2,  # compressed latent frames
    device="cuda",
    prompt="A video of a scene shot using a pedestrian's front camera while walking",
    direction="forward",
    use_real_poses=True
):
    """Run Sekai video inference from a pre-encoded .pth latent clip.

    Loads the Wan2.1 pipeline, grafts camera-conditioning layers onto each
    DiT block, loads trained DiT weights, then denoises ``target_frames``
    new latent frames conditioned on ``condition_frames`` existing ones and
    a per-frame camera embedding. The decoded video (condition + generated
    frames) is written to ``output_path`` as an mp4 at 20 fps.

    Args:
        condition_pth_path: .pth file with pre-encoded latents (and optionally
            'cam_emb' camera data).
        dit_path: checkpoint for the camera-conditioned DiT.
        output_path: destination mp4 path.
        start_frame: start index into the stored latents (compressed frames).
        condition_frames / target_frames: clip sizes in compressed latent frames.
        device: inference device.
        prompt: text prompt fed to the T5 encoder.
        direction: synthetic motion used when real poses are unavailable.
        use_real_poses: prefer 'cam_emb' from the .pth when present.
    """
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    print(f"Setting up models for {direction} movement...")

    # 1. Load models (weights staged on CPU, moved to device after surgery).
    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # Add camera components to DiT. The encoder starts at zero and the
    # projector at identity so the untrained layers are a no-op until the
    # trained checkpoint below overwrites them.
    dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        block.cam_encoder = nn.Linear(13, dim)  # 13-dim embedding (12D pose + 1D mask)
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    # Load trained DiT weights (strict: checkpoint must match the patched model).
    dit_state_dict = torch.load(dit_path, map_location="cpu")
    pipe.dit.load_state_dict(dit_state_dict, strict=True)
    pipe = pipe.to(device)
    pipe.scheduler.set_timesteps(50)

    print("Loading condition video from pth...")

    # Load condition video from pth.
    condition_latents, encoded_data = load_encoded_video_from_pth(
        condition_pth_path,
        start_frame=start_frame,
        num_frames=condition_frames
    )

    condition_latents = condition_latents.unsqueeze(0).to(device, dtype=pipe.torch_dtype)

    print("Processing poses...")

    # Build the camera pose embedding: real recorded poses when available,
    # otherwise a synthetic trajectory in the requested direction.
    if use_real_poses and 'cam_emb' in encoded_data:
        print("Using real camera poses from data")
        camera_embedding = generate_camera_poses_from_data(
            encoded_data['cam_emb'],
            start_frame=start_frame,
            condition_frames=condition_frames,
            target_frames=target_frames
        )
    else:
        print(f"Using synthetic {direction} poses")
        camera_embedding = generate_camera_poses(
            direction=direction,
            target_frames=target_frames,
            condition_frames=condition_frames
        )

    camera_embedding = camera_embedding.unsqueeze(0).to(device, dtype=torch.bfloat16)

    print(f"Camera embedding shape: {camera_embedding.shape}")

    print("Encoding prompt...")

    # Encode text prompt.
    prompt_emb = pipe.encode_prompt(prompt)

    print("Generating video...")

    # Generate target latents.
    batch_size = 1
    channels = condition_latents.shape[1]
    latent_height = condition_latents.shape[3]
    latent_width = condition_latents.shape[4]

    # Spatial center-crop to save memory (if the stored latents are larger).
    # NOTE(review): 60x104 appears tuned to this model's latent resolution —
    # confirm against the training configuration.
    target_height, target_width = 60, 104

    if latent_height > target_height or latent_width > target_width:
        # Center crop.
        h_start = (latent_height - target_height) // 2
        w_start = (latent_width - target_width) // 2
        condition_latents = condition_latents[:, :, :,
                                            h_start:h_start+target_height,
                                            w_start:w_start+target_width]
        latent_height = target_height
        latent_width = target_width

    # Initialize target latents with noise.
    target_latents = torch.randn(
        batch_size, channels, target_frames, latent_height, latent_width,
        device=device, dtype=pipe.torch_dtype
    )

    print(f"Condition latents shape: {condition_latents.shape}")
    print(f"Target latents shape: {target_latents.shape}")
    print(f"Camera embedding shape: {camera_embedding.shape}")

    # Combine condition and target latents along the time axis.
    combined_latents = torch.cat([condition_latents, target_latents], dim=2)
    print(f"Combined latents shape: {combined_latents.shape}")

    # Prepare extra inputs.
    extra_input = pipe.prepare_extra_input(combined_latents)

    # Denoising loop: the DiT sees condition + target frames, but only the
    # target slice is stepped by the scheduler; condition frames stay fixed.
    timesteps = pipe.scheduler.timesteps

    for i, timestep in enumerate(timesteps):
        print(f"Denoising step {i+1}/{len(timesteps)}")

        # Prepare timestep.
        timestep_tensor = timestep.unsqueeze(0).to(device, dtype=pipe.torch_dtype)

        # Predict noise.
        with torch.no_grad():
            noise_pred = pipe.dit(
                combined_latents,
                timestep=timestep_tensor,
                cam_emb=camera_embedding,
                **prompt_emb,
                **extra_input
            )

        # Update only the target part.
        target_noise_pred = noise_pred[:, :, condition_frames:, :, :]
        target_latents = pipe.scheduler.step(target_noise_pred, timestep, target_latents)

        # Write the stepped targets back into the combined tensor.
        combined_latents[:, :, condition_frames:, :, :] = target_latents

    print("Decoding video...")

    # Decode final video (condition + generated frames) with tiled VAE decode.
    final_video = torch.cat([condition_latents, target_latents], dim=2)
    decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))

    # Save video.
    print(f"Saving video to {output_path}")

    # Convert to numpy HWC frames and save.
    video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
    video_np = (video_np * 0.5 + 0.5).clip(0, 1)  # denormalize from [-1, 1] to [0, 1]
    video_np = (video_np * 255).astype(np.uint8)

    with imageio.get_writer(output_path, fps=20) as writer:
        for frame in video_np:
            writer.append_data(frame)

    print(f"Video generation completed! Saved to {output_path}")
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
def main():
    """CLI entry point for Sekai inference from a pre-encoded .pth clip.

    Bug fix: ``--use_real_poses`` was declared with ``default=False`` but no
    ``action`` or ``type``, so any value passed on the command line — even
    the string "False" — parsed as a truthy string. It is now a proper
    ``store_true`` flag; default (off) behavior is unchanged.
    """
    parser = argparse.ArgumentParser(description="Sekai Video Generation Inference from PTH")
    parser.add_argument("--condition_pth", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking/00100100001_0004650_0004950/encoded_video.pth")
    parser.add_argument("--start_frame", type=int, default=0,
                        help="Starting frame index (compressed latent frames)")
    parser.add_argument("--condition_frames", type=int, default=8,
                        help="Number of condition frames (compressed latent frames)")
    parser.add_argument("--target_frames", type=int, default=8,
                        help="Number of target frames to generate (compressed latent frames)")
    parser.add_argument("--direction", type=str, default="left_turn",
                        choices=["forward", "backward", "left_turn", "right_turn"],
                        help="Direction of camera movement (if not using real poses)")
    # Fixed: boolean flag instead of an untyped value (see docstring).
    parser.add_argument("--use_real_poses", action="store_true", default=False,
                        help="Use real camera poses from data")
    parser.add_argument("--dit_path", type=str, default="/home/zhuyixuan05/ReCamMaster/sekai_walking_noise/step14000_dynamic.ckpt",
                        help="Path to trained DiT checkpoint")
    parser.add_argument("--output_path", type=str, default='/home/zhuyixuan05/ReCamMaster/sekai/infer_noise_results/output_sekai_right_turn.mp4',
                        help="Output video path")
    parser.add_argument("--prompt", type=str,
                        default="A drone flying scene in a game world",
                        help="Text prompt for generation")
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device to run inference on")

    args = parser.parse_args()

    # NOTE(review): --output_path has a non-None default, so this auto-naming
    # branch is currently unreachable; it only fires if the default is ever
    # changed to None. Kept for backward compatibility.
    if args.output_path is None:
        pth_filename = os.path.basename(args.condition_pth)
        name_parts = os.path.splitext(pth_filename)
        output_dir = "sekai/infer_results"
        os.makedirs(output_dir, exist_ok=True)

        if args.use_real_poses:
            output_filename = f"{name_parts[0]}_real_poses_{args.start_frame}_{args.condition_frames}_{args.target_frames}.mp4"
        else:
            output_filename = f"{name_parts[0]}_{args.direction}_{args.start_frame}_{args.condition_frames}_{args.target_frames}.mp4"

        output_path = os.path.join(output_dir, output_filename)
    else:
        output_path = args.output_path

    print(f"Input pth: {args.condition_pth}")
    print(f"Start frame: {args.start_frame} (compressed)")
    print(f"Condition frames: {args.condition_frames} (compressed, original: {args.condition_frames * 4})")
    print(f"Target frames: {args.target_frames} (compressed, original: {args.target_frames * 4})")
    print(f"Use real poses: {args.use_real_poses}")
    print(f"Output video will be saved to: {output_path}")

    inference_sekai_video_from_pth(
        condition_pth_path=args.condition_pth,
        dit_path=args.dit_path,
        output_path=output_path,
        start_frame=args.start_frame,
        condition_frames=args.condition_frames,
        target_frames=args.target_frames,
        device=args.device,
        prompt=args.prompt,
        direction=args.direction,
        use_real_poses=args.use_real_poses
    )


if __name__ == "__main__":
    main()
|
scripts/infer_sekai_framepack.py
ADDED
|
@@ -0,0 +1,675 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import numpy as np
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import imageio
|
| 7 |
+
import json
|
| 8 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 9 |
+
import argparse
|
| 10 |
+
from torchvision.transforms import v2
|
| 11 |
+
from einops import rearrange
|
| 12 |
+
import copy
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Load pre-encoded video latents from disk and slice out a clip.

    Returns the [C, num_frames, H, W] latent slice starting at
    ``start_frame`` together with the full dict loaded from the .pth file.
    Raises ValueError when the stored clip is shorter than requested.
    """
    print(f"Loading encoded video from {pth_path}")

    encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu")
    full_latents = encoded_data['latents']  # [C, T, H, W]

    print(f"Full latents shape: {full_latents.shape}")
    print(f"Extracting frames {start_frame} to {start_frame + num_frames}")

    available = full_latents.shape[1]
    if start_frame + num_frames > available:
        raise ValueError(f"Not enough frames: requested {start_frame + num_frames}, available {full_latents.shape[1]}")

    condition_latents = full_latents[:, start_frame:start_frame + num_frames, :, :]
    print(f"Extracted condition latents shape: {condition_latents.shape}")

    return condition_latents, encoded_data
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Compute camera B's pose relative to camera A as pose_b @ inv(pose_a)."""
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        # Torch backend: promote numpy inputs before inverting.
        if not isinstance(pose_a, torch.Tensor):
            pose_a = torch.from_numpy(pose_a).float()
        if not isinstance(pose_b, torch.Tensor):
            pose_b = torch.from_numpy(pose_b).float()
        inverse_a = torch.inverse(pose_a)
        return torch.matmul(pose_b, inverse_a)

    # Numpy backend: accept any array-like, normalized to float32.
    if not isinstance(pose_a, np.ndarray):
        pose_a = np.array(pose_a, dtype=np.float32)
    if not isinstance(pose_b, np.ndarray):
        pose_b = np.array(pose_b, dtype=np.float32)
    inverse_a = np.linalg.inv(pose_a)
    return np.matmul(pose_b, inverse_a)
|
| 57 |
+
|
| 58 |
+
def replace_dit_model_in_manager():
    """Swap the registered 'wan_video_dit' model class for the FramePack variant.

    Mutates ``model_loader_configs`` in place so subsequent model loads
    instantiate ``WanModelFuture`` instead of the stock DiT class.
    """
    from diffsynth.models.wan_video_dit_recam_future import WanModelFuture
    from diffsynth.configs.model_config import model_loader_configs

    for idx, entry in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, names, classes, resource = entry

        # Only rewrite entries that register the stock DiT.
        if 'wan_video_dit' not in names:
            continue

        patched_names = []
        patched_classes = []
        for name, cls in zip(names, classes):
            if name == 'wan_video_dit':
                patched_names.append(name)
                patched_classes.append(WanModelFuture)
                print(f"✅ 替换了模型类: {name} -> WanModelFuture")
            else:
                patched_names.append(name)
                patched_classes.append(cls)

        model_loader_configs[idx] = (keys_hash, keys_hash_with_shape, patched_names, patched_classes, resource)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def add_framepack_components(dit_model):
    """Attach the FramePack multi-scale clean-latent embedder to a DiT model.

    Idempotent: does nothing when ``clean_x_embedder`` already exists. The
    new module is cast to the dtype of the model's existing parameters so
    mixed-precision models stay consistent.
    """
    if hasattr(dit_model, 'clean_x_embedder'):
        return

    # Hidden size inferred from the first block's attention projection.
    inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]

    class CleanXEmbedder(nn.Module):
        """Patchify 16-channel clean latents at 1x / 2x / 4x downsampling."""

        def __init__(self, inner_dim):
            super().__init__()
            self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
            self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
            self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

        def forward(self, x, scale="1x"):
            projections = {"1x": self.proj, "2x": self.proj_2x, "4x": self.proj_4x}
            if scale not in projections:
                raise ValueError(f"Unsupported scale: {scale}")
            conv = projections[scale]
            # Match the conv's dtype so bf16/fp16 models accept fp32 input.
            return conv(x.to(conv.weight.dtype))

    dit_model.clean_x_embedder = CleanXEmbedder(inner_dim)
    model_dtype = next(dit_model.parameters()).dtype
    dit_model.clean_x_embedder = dit_model.clean_x_embedder.to(dtype=model_dtype)
    print("✅ 添加了FramePack的clean_x_embedder组件")
|
| 111 |
+
|
| 112 |
+
def generate_camera_embeddings_sliding(cam_data, start_frame, current_history_length, new_frames, total_generated, use_real_poses=True):
    """Generate camera embeddings for a FramePack sliding window.

    Produces a [max_needed_frames, 13] bfloat16 tensor: a flattened 3x4
    frame-to-frame relative pose (12 values) plus a 1/0 condition-mask
    column. The sequence is sized to cover both the sliding-window position
    (start_frame + history + new frames) and the fixed FramePack index
    structure 1(start) + 16(4x) + 2(2x) + 1(1x) + new_frames, with a floor
    of 30 frames.

    Args:
        cam_data: dict with an 'extrinsic' [N, 4, 4] pose sequence at the
            original frame rate, or None.
        start_frame: window start in compressed latent frames.
        current_history_length: number of already-generated/condition frames.
        new_frames: number of frames to be generated this step.
        total_generated: running total of generated frames.
            NOTE(review): not used in this body — confirm whether callers
            still need to pass it.
        use_real_poses: use cam_data when it carries extrinsics; otherwise a
            synthetic constant left-turn trajectory is produced.
    """
    time_compression_ratio = 4  # one latent frame spans 4 original frames

    # Number of camera frames the FramePack layout itself requires:
    # 1(start) + 16(4x) + 2(2x) + 1(1x) + target_frames.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    if use_real_poses and cam_data is not None and 'extrinsic' in cam_data:
        print("🔧 使用真实camera数据")
        cam_extrinsic = cam_data['extrinsic']

        # Make the camera sequence long enough for both the window position
        # and the FramePack structure.
        max_needed_frames = max(
            start_frame + current_history_length + new_frames,  # window requirement
            framepack_needed_frames,  # FramePack structure requirement
            30  # minimum guaranteed length
        )

        print(f"🔧 计算camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        relative_poses = []
        for i in range(max_needed_frames):
            # Position of this latent frame in the original (uncompressed)
            # sequence. NOTE(review): indexing starts at 0, not at
            # start_frame * ratio — confirm this offset is intentional.
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio

            if next_frame_idx < len(cam_extrinsic):
                cam_prev = cam_extrinsic[frame_idx]
                cam_next = cam_extrinsic[next_frame_idx]
                relative_pose = compute_relative_pose(cam_prev, cam_next)
                relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
            else:
                # Past the end of the recorded poses: fall back to zero motion.
                print(f"⚠️ 帧{frame_idx}超出camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))

        pose_embedding = torch.stack(relative_poses, dim=0)
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')

        # Condition mask of matching length: frames in
        # [start_frame, start_frame + current_history_length) are condition.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 真实camera embedding shape: {camera_embedding.shape} (总长度:{max_needed_frames})")
        return camera_embedding.to(torch.bfloat16)

    else:
        print("🔧 使用合成camera数据")

        # Synthetic path: same length guarantee as the real-pose branch.
        max_needed_frames = max(
            start_frame + current_history_length + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 生成合成camera帧数: {max_needed_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")

        relative_poses = []
        for i in range(max_needed_frames):
            # Constant left-turn motion: rotate a fixed yaw per frame while
            # advancing forward.
            yaw_per_frame = -0.05  # per-frame yaw (sign convention: left turn)
            forward_speed = 0.005  # per-frame forward distance

            # Cumulative heading. NOTE(review): computed but never used below.
            current_yaw = i * yaw_per_frame

            # Relative transform from frame i to frame i+1.
            pose = np.eye(4, dtype=np.float32)

            # Rotation about the Y axis.
            cos_yaw = np.cos(yaw_per_frame)
            sin_yaw = np.sin(yaw_per_frame)

            pose[0, 0] = cos_yaw
            pose[0, 2] = sin_yaw
            pose[2, 0] = -sin_yaw
            pose[2, 2] = cos_yaw

            # Translation: advance along the local -Z axis (forward).
            pose[2, 3] = -forward_speed

            # Optional slight drift toward the circle center to mimic a
            # circular trajectory.
            radius_drift = 0.002
            pose[0, 3] = radius_drift

            relative_pose = pose[:3, :]
            relative_poses.append(torch.as_tensor(relative_pose))

        pose_embedding = torch.stack(relative_poses, dim=0)
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')

        # Condition mask, same layout as the real-pose branch.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 合成camera embedding shape: {camera_embedding.shape} (总长度:{max_needed_frames})")
        return camera_embedding.to(torch.bfloat16)
|
| 222 |
+
|
| 223 |
+
def prepare_framepack_sliding_window_with_camera(history_latents, target_frames_to_generate, camera_embedding_full, start_frame, max_history_frames=49):
    """FramePack sliding-window preparation with corrected camera-mask logic.

    Builds the fixed FramePack index layout
    ``[1 (start) + 16 (4x) + 2 (2x) + 1 (1x) + target_frames_to_generate]``,
    right-aligns the most recent history frames into the 19-slot clean-latent
    buffer, and re-derives the camera condition mask (last channel of the
    camera embedding) from the actual history length.

    Args:
        history_latents: [C, T, H, W] latents accumulated so far (T may be 0).
        target_frames_to_generate: number of new latent frames to denoise.
        camera_embedding_full: [N, D] per-frame camera embedding whose last
            column is the condition mask; zero-padded here if too short.
        start_frame: kept for interface compatibility (unused in this body).
        max_history_frames: kept for interface compatibility (unused here).

    Returns:
        dict with multi-scale clean latents, index tensors, the mask-corrected
        camera embedding, and the current/next history lengths.
    """
    # history_latents: [C, T, H, W] current history latents
    C, T, H, W = history_latents.shape

    # Fixed index structure (this determines how many camera frames we need):
    # start + 16 (4x) + 2 (2x) + 1 (1x) + target frames.
    total_indices_length = 1 + 16 + 2 + 1 + target_frames_to_generate
    indices = torch.arange(0, total_indices_length)
    split_sizes = [1, 16, 2, 1, target_frames_to_generate]
    clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = \
        indices.split(split_sizes, dim=0)
    clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=0)

    # Zero-pad the camera sequence if it is shorter than the index layout.
    if camera_embedding_full.shape[0] < total_indices_length:
        shortage = total_indices_length - camera_embedding_full.shape[0]
        padding = torch.zeros(shortage, camera_embedding_full.shape[1],
                              dtype=camera_embedding_full.dtype, device=camera_embedding_full.device)
        camera_embedding_full = torch.cat([camera_embedding_full, padding], dim=0)

    # Slice the needed prefix; clone so the caller's tensor is untouched.
    combined_camera = camera_embedding_full[:total_indices_length, :].clone()

    # How many of the 19 clean slots are backed by real history, and where
    # they start (history is right-aligned into the clean buffer). Computed
    # once here instead of twice as in the original.
    available_frames = min(T, 19)
    start_pos = 19 - available_frames

    # Rebuild the condition mask: clear everything to target (0), then mark
    # the rows that correspond to real history frames as condition (1).
    combined_camera[:, -1] = 0.0
    if T > 0:
        combined_camera[start_pos:19, -1] = 1.0

    # The target rows stay 0 (already cleared above).
    print(f"🔧 Camera mask更新:")
    print(f" - 历史帧数: {T}")
    print(f" - 有效condition帧数: {available_frames if T > 0 else 0}")
    print(f" - Condition mask (前19帧): {combined_camera[:19, -1].cpu().tolist()}")
    print(f" - Target mask (后{target_frames_to_generate}帧): {combined_camera[19:, -1].cpu().tolist()}")

    # Right-align the available history into the 19-slot clean buffer.
    clean_latents_combined = torch.zeros(C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device)
    if T > 0:
        clean_latents_combined[:, start_pos:, :, :] = history_latents[:, -available_frames:, :, :]

    # Split the buffer into the FramePack multi-scale sections.
    clean_latents_4x = clean_latents_combined[:, 0:16, :, :]
    clean_latents_2x = clean_latents_combined[:, 16:18, :, :]
    clean_latents_1x = clean_latents_combined[:, 18:19, :, :]

    # The very first history frame anchors the sequence start (zeros if empty).
    if T > 0:
        start_latent = history_latents[:, 0:1, :, :]
    else:
        start_latent = torch.zeros(C, 1, H, W, dtype=history_latents.dtype, device=history_latents.device)

    clean_latents = torch.cat([start_latent, clean_latents_1x], dim=1)

    return {
        'latent_indices': latent_indices,
        'clean_latents': clean_latents,
        'clean_latents_2x': clean_latents_2x,
        'clean_latents_4x': clean_latents_4x,
        'clean_latent_indices': clean_latent_indices,
        'clean_latent_2x_indices': clean_latent_2x_indices,
        'clean_latent_4x_indices': clean_latent_4x_indices,
        'camera_embedding': combined_camera,  # mask is now consistent with T
        'current_length': T,
        'next_length': T + target_frames_to_generate
    }
|
| 300 |
+
|
| 301 |
+
def inference_sekai_framepack_sliding_window(
    condition_pth_path,
    dit_path,
    output_path="sekai/infer_results/output_sekai_framepack_sliding.mp4",
    start_frame=0,
    initial_condition_frames=8,
    frames_per_generation=4,
    total_frames_to_generate=32,
    max_history_frames=49,
    device="cuda",
    prompt="A video of a scene shot using a pedestrian's front camera while walking",
    use_real_poses=True,
    synthetic_direction="forward",
    # CFG knobs: camera guidance and text guidance are applied independently below.
    use_camera_cfg=True,
    camera_guidance_scale=2.0,
    text_guidance_scale=7.5
):
    """FramePack sliding-window video generation with optional Camera/Text CFG.

    Loads a pre-encoded latent clip as the initial condition, then repeatedly
    denoises `frames_per_generation` new latent frames conditioned on a
    sliding history window (FramePack multi-scale clean latents) until
    `total_frames_to_generate` frames exist, finally decoding and writing an
    MP4 to `output_path`.

    Parameters:
        condition_pth_path: .pth file with pre-encoded latents (and optionally
            'cam_emb' camera extrinsics).
        dit_path: checkpoint loaded (strict) into the patched DiT.
        start_frame / initial_condition_frames: slice of the encoded clip used
            as the seed history.
        use_camera_cfg / camera_guidance_scale: classifier-free guidance
            against a zeroed camera embedding.
        text_guidance_scale: CFG against an empty-prompt embedding (>1 enables).
    """
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    print(f"🔧 FramePack滑动窗口生成开始...")
    print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}")
    print(f"Text guidance scale: {text_guidance_scale}")
    print(f"初始条件帧: {initial_condition_frames}, 每次生成: {frames_per_generation}, 总生成: {total_frames_to_generate}")
    print(f"使用真实姿态: {use_real_poses}")
    if not use_real_poses:
        print(f"合成camera方向: {synthetic_direction}")

    # 1-3. Model initialization and component patching.
    # Swap the registered DiT class for the FramePack variant BEFORE loading.
    replace_dit_model_in_manager()

    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    # NOTE(review): device is hard-coded to "cuda" here even though the
    # function takes a `device` parameter — confirm whether that is intended.
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # Attach per-block camera conditioning layers; cam_encoder starts at zero
    # and projector starts as identity so the untrained layers are a no-op
    # until the checkpoint below overwrites them.
    dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        block.cam_encoder = nn.Linear(13, dim)
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    add_framepack_components(pipe.dit)

    # Strict load: the checkpoint must contain the cam/projector/FramePack
    # weights added above.
    dit_state_dict = torch.load(dit_path, map_location="cpu")
    pipe.dit.load_state_dict(dit_state_dict, strict=True)
    pipe = pipe.to(device)
    model_dtype = next(pipe.dit.parameters()).dtype

    if hasattr(pipe.dit, 'clean_x_embedder'):
        pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)

    pipe.scheduler.set_timesteps(50)

    # 4. Load the initial condition latents.
    print("Loading initial condition frames...")
    initial_latents, encoded_data = load_encoded_video_from_pth(
        condition_pth_path,
        start_frame=start_frame,
        num_frames=initial_condition_frames
    )

    # Center-crop latents spatially to the model's expected resolution.
    target_height, target_width = 60, 104
    C, T, H, W = initial_latents.shape

    if H > target_height or W > target_width:
        h_start = (H - target_height) // 2
        w_start = (W - target_width) // 2
        initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width]
        H, W = target_height, target_width

    history_latents = initial_latents.to(device, dtype=model_dtype)

    print(f"初始history_latents shape: {history_latents.shape}")

    # Encode the prompt; a negative (empty) prompt is only encoded when text
    # CFG is enabled.
    if text_guidance_scale > 1.0:
        prompt_emb_pos = pipe.encode_prompt(prompt)
        prompt_emb_neg = pipe.encode_prompt("")
        print(f"使用Text CFG,guidance scale: {text_guidance_scale}")
    else:
        prompt_emb_pos = pipe.encode_prompt(prompt)
        prompt_emb_neg = None
        print("不使用Text CFG")

    # Pre-generate the full camera embedding sequence once; later windows
    # slice from it inside prepare_framepack_sliding_window_with_camera.
    camera_embedding_full = generate_camera_embeddings_sliding(
        encoded_data.get('cam_emb', None),
        0,
        max_history_frames,
        0,
        0,
        use_real_poses=use_real_poses
    ).to(device, dtype=model_dtype)

    print(f"完整camera序列shape: {camera_embedding_full.shape}")

    # Unconditional (all-zero) camera embedding used as the CFG baseline.
    if use_camera_cfg:
        camera_embedding_uncond = torch.zeros_like(camera_embedding_full)
        print(f"创建无条件camera embedding用于CFG")

    # Sliding-window generation loop.
    total_generated = 0
    all_generated_frames = []

    while total_generated < total_frames_to_generate:
        current_generation = min(frames_per_generation, total_frames_to_generate - total_generated)
        print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}")
        print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}")

        # Build FramePack inputs (clean latents at 1x/2x/4x + updated camera mask).
        framepack_data = prepare_framepack_sliding_window_with_camera(
            history_latents,
            current_generation,
            camera_embedding_full,
            start_frame,
            max_history_frames
        )

        # Add the batch dimension expected by the DiT.
        clean_latents = framepack_data['clean_latents'].unsqueeze(0)
        clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0)
        clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0)
        camera_embedding = framepack_data['camera_embedding'].unsqueeze(0)

        # Slice the zero camera sequence to the same window length for CFG.
        # NOTE(review): indexes with camera_embedding.shape[1] (the window
        # length after unsqueeze) — verify this matches the intended axis.
        if use_camera_cfg:
            camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0)

        # Index tensors are kept on CPU (consumed as positions, not data).
        latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu()
        clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu()
        clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu()
        clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu()

        # Start each window from pure noise.
        new_latents = torch.randn(
            1, C, current_generation, H, W,
            device=device, dtype=model_dtype
        )

        extra_input = pipe.prepare_extra_input(new_latents)

        print(f"Camera embedding shape: {camera_embedding.shape}")
        print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}")

        # Denoising loop with optional camera/text CFG.
        timesteps = pipe.scheduler.timesteps

        for i, timestep in enumerate(timesteps):
            if i % 10 == 0:
                print(f" 去噪步骤 {i+1}/{len(timesteps)}")

            # NOTE(review): casts the timestep to the model dtype (bfloat16);
            # unusual for a timestep — confirm the scheduler/DiT expects this.
            timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)

            with torch.no_grad():
                # CFG inference paths.
                if use_camera_cfg and camera_guidance_scale > 1.0:
                    # Conditional prediction (with camera embedding).
                    noise_pred_cond = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

                    # Unconditional prediction (zero camera embedding).
                    noise_pred_uncond = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding_uncond_batch,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **(prompt_emb_neg if prompt_emb_neg else prompt_emb_pos),
                        **extra_input
                    )

                    # Camera CFG: extrapolate away from the camera-free prediction.
                    noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_cond - noise_pred_uncond)

                    # Optionally stack text CFG on top of the camera-guided result.
                    if text_guidance_scale > 1.0 and prompt_emb_neg:
                        # Requires a third forward pass with the negative prompt.
                        noise_pred_text_uncond = pipe.dit(
                            new_latents,
                            timestep=timestep_tensor,
                            cam_emb=camera_embedding,
                            latent_indices=latent_indices,
                            clean_latents=clean_latents,
                            clean_latent_indices=clean_latent_indices,
                            clean_latents_2x=clean_latents_2x,
                            clean_latent_2x_indices=clean_latent_2x_indices,
                            clean_latents_4x=clean_latents_4x,
                            clean_latent_4x_indices=clean_latent_4x_indices,
                            **prompt_emb_neg,
                            **extra_input
                        )

                        # Apply text CFG to the already camera-guided prediction.
                        noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond)

                elif text_guidance_scale > 1.0 and prompt_emb_neg:
                    # Text CFG only (camera CFG disabled).
                    noise_pred_cond = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

                    noise_pred_uncond = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_neg,
                        **extra_input
                    )

                    noise_pred = noise_pred_uncond + text_guidance_scale * (noise_pred_cond - noise_pred_uncond)

                else:
                    # Plain inference (no CFG).
                    noise_pred = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

            new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents)

        # Append the freshly generated frames to the history.
        new_latents_squeezed = new_latents.squeeze(0)
        history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1)

        # Maintain the sliding window: keep the anchor (first) frame plus the
        # most recent max_history_frames-1 frames.
        if history_latents.shape[1] > max_history_frames:
            first_frame = history_latents[:, 0:1, :, :]
            recent_frames = history_latents[:, -(max_history_frames-1):, :, :]
            history_latents = torch.cat([first_frame, recent_frames], dim=1)
            print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧")

        print(f"更新后history_latents shape: {history_latents.shape}")

        all_generated_frames.append(new_latents_squeezed)
        total_generated += current_generation

        print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧")

    # 7. Decode and save.
    print("\n🔧 解码生成的视频...")

    all_generated = torch.cat(all_generated_frames, dim=1)
    # NOTE(review): initial_latents was never cast to model_dtype (only
    # history_latents was) — if the .pth stores a different dtype this cat
    # may fail or upcast; confirm the encoded dtype matches.
    final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0)

    print(f"最终视频shape: {final_video.shape}")

    decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))

    print(f"Saving video to {output_path}")

    # [C, T, H, W] -> [T, H, W, C], de-normalize from [-1, 1] to uint8.
    video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
    video_np = (video_np * 0.5 + 0.5).clip(0, 1)
    video_np = (video_np * 255).astype(np.uint8)

    with imageio.get_writer(output_path, fps=20) as writer:
        for frame in video_np:
            writer.append_data(frame)

    print(f"🔧 FramePack滑动窗口生成完成! 保存到: {output_path}")
    # x4: each latent frame corresponds to 4 original video frames
    # (VAE temporal compression).
    print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧")
|
| 621 |
+
|
| 622 |
+
def main():
|
| 623 |
+
parser = argparse.ArgumentParser(description="Sekai FramePack滑动窗口视频生成 - 支持CFG")
|
| 624 |
+
parser.add_argument("--condition_pth", type=str,
|
| 625 |
+
default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking/00100100001_0004650_0004950/encoded_video.pth")
|
| 626 |
+
parser.add_argument("--start_frame", type=int, default=0)
|
| 627 |
+
parser.add_argument("--initial_condition_frames", type=int, default=16)
|
| 628 |
+
parser.add_argument("--frames_per_generation", type=int, default=8)
|
| 629 |
+
parser.add_argument("--total_frames_to_generate", type=int, default=40)
|
| 630 |
+
parser.add_argument("--max_history_frames", type=int, default=100)
|
| 631 |
+
parser.add_argument("--use_real_poses", action="store_true", default=False)
|
| 632 |
+
parser.add_argument("--dit_path", type=str,
|
| 633 |
+
default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/sekai/sekai_walking_framepack/step1000_framepack.ckpt")
|
| 634 |
+
parser.add_argument("--output_path", type=str,
|
| 635 |
+
default='/home/zhuyixuan05/ReCamMaster/sekai/infer_framepack_results/output_sekai_framepack_sliding.mp4')
|
| 636 |
+
parser.add_argument("--prompt", type=str,
|
| 637 |
+
default="A drone flying scene in a game world")
|
| 638 |
+
parser.add_argument("--device", type=str, default="cuda")
|
| 639 |
+
|
| 640 |
+
# 🔧 新增CFG参数
|
| 641 |
+
parser.add_argument("--use_camera_cfg", default=True,
|
| 642 |
+
help="使用Camera CFG")
|
| 643 |
+
parser.add_argument("--camera_guidance_scale", type=float, default=2.0,
|
| 644 |
+
help="Camera guidance scale for CFG")
|
| 645 |
+
parser.add_argument("--text_guidance_scale", type=float, default=1.0,
|
| 646 |
+
help="Text guidance scale for CFG")
|
| 647 |
+
|
| 648 |
+
args = parser.parse_args()
|
| 649 |
+
|
| 650 |
+
print(f"🔧 FramePack CFG生成设置:")
|
| 651 |
+
print(f"Camera CFG: {args.use_camera_cfg}")
|
| 652 |
+
if args.use_camera_cfg:
|
| 653 |
+
print(f"Camera guidance scale: {args.camera_guidance_scale}")
|
| 654 |
+
print(f"Text guidance scale: {args.text_guidance_scale}")
|
| 655 |
+
|
| 656 |
+
inference_sekai_framepack_sliding_window(
|
| 657 |
+
condition_pth_path=args.condition_pth,
|
| 658 |
+
dit_path=args.dit_path,
|
| 659 |
+
output_path=args.output_path,
|
| 660 |
+
start_frame=args.start_frame,
|
| 661 |
+
initial_condition_frames=args.initial_condition_frames,
|
| 662 |
+
frames_per_generation=args.frames_per_generation,
|
| 663 |
+
total_frames_to_generate=args.total_frames_to_generate,
|
| 664 |
+
max_history_frames=args.max_history_frames,
|
| 665 |
+
device=args.device,
|
| 666 |
+
prompt=args.prompt,
|
| 667 |
+
use_real_poses=args.use_real_poses,
|
| 668 |
+
# 🔧 CFG参数
|
| 669 |
+
use_camera_cfg=args.use_camera_cfg,
|
| 670 |
+
camera_guidance_scale=args.camera_guidance_scale,
|
| 671 |
+
text_guidance_scale=args.text_guidance_scale
|
| 672 |
+
)
|
| 673 |
+
|
| 674 |
+
if __name__ == "__main__":
|
| 675 |
+
main()
|
scripts/infer_sekai_framepack_4.py
ADDED
|
@@ -0,0 +1,682 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import numpy as np
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import imageio
|
| 7 |
+
import json
|
| 8 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 9 |
+
import argparse
|
| 10 |
+
from torchvision.transforms import v2
|
| 11 |
+
from einops import rearrange
|
| 12 |
+
import copy
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Load a pre-encoded latent clip from a .pth file.

    Returns ``(clip, payload)`` where ``clip`` is the ``[C, num_frames, H, W]``
    slice of the stored latents starting at ``start_frame`` and ``payload`` is
    the full dict saved in the file. Raises ValueError if the requested range
    exceeds the stored frame count.
    """
    print(f"Loading encoded video from {pth_path}")

    # NOTE(review): weights_only=False unpickles arbitrary objects; only load
    # .pth files from trusted sources.
    payload = torch.load(pth_path, weights_only=False, map_location="cpu")
    latents = payload['latents']  # [C, T, H, W]

    print(f"Full latents shape: {latents.shape}")
    print(f"Extracting frames {start_frame} to {start_frame + num_frames}")

    end_frame = start_frame + num_frames
    if end_frame > latents.shape[1]:
        raise ValueError(f"Not enough frames: requested {start_frame + num_frames}, available {latents.shape[1]}")

    clip = latents[:, start_frame:end_frame, :, :]
    print(f"Extracted condition latents shape: {clip.shape}")

    return clip, payload
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Return the pose of camera B relative to camera A.

    Computes ``pose_b @ inverse(pose_a)`` for 4x4 extrinsic matrices, using
    torch when ``use_torch`` is True and NumPy otherwise. Non-tensor/array
    inputs are converted to float32 first.
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        mat_a = pose_a if isinstance(pose_a, torch.Tensor) else torch.from_numpy(pose_a).float()
        mat_b = pose_b if isinstance(pose_b, torch.Tensor) else torch.from_numpy(pose_b).float()
        return torch.matmul(mat_b, torch.inverse(mat_a))

    mat_a = pose_a if isinstance(pose_a, np.ndarray) else np.array(pose_a, dtype=np.float32)
    mat_b = pose_b if isinstance(pose_b, np.ndarray) else np.array(pose_b, dtype=np.float32)
    return np.matmul(mat_b, np.linalg.inv(mat_a))
|
| 57 |
+
|
| 58 |
+
def replace_dit_model_in_manager():
    """Swap the registered DiT model class for the FramePack variant.

    Mutates ``model_loader_configs`` in place so that any loader entry whose
    model-name list contains ``'wan_video_dit'`` will instantiate
    ``WanModelFuture4`` instead of the stock class. Must run before models
    are loaded through the ModelManager.
    """
    from diffsynth.models.wan_video_dit_4 import WanModelFuture4
    from diffsynth.configs.model_config import model_loader_configs

    for idx, entry in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource = entry

        # Leave entries untouched unless they register the stock DiT.
        if 'wan_video_dit' not in model_names:
            continue

        patched_names, patched_classes = [], []
        for name, cls in zip(model_names, model_classes):
            patched_names.append(name)
            if name == 'wan_video_dit':
                patched_classes.append(WanModelFuture4)
                print(f"✅ 替换了模型类: {name} -> WanModelFuture4")
            else:
                patched_classes.append(cls)

        model_loader_configs[idx] = (keys_hash, keys_hash_with_shape, patched_names, patched_classes, model_resource)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def add_framepack_components(dit_model):
    """Attach FramePack's multi-scale clean-latent embedder to the DiT.

    Idempotent: does nothing if ``clean_x_embedder`` already exists. The
    embedder projects 16-channel clean latents at 1x/2x/4x compression via
    three Conv3d layers sized to the model's inner dimension, and is cast to
    the model's parameter dtype.
    """
    if hasattr(dit_model, 'clean_x_embedder'):
        return

    inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]

    class CleanXEmbedder(nn.Module):
        """Conv3d projections of clean latents at three compression scales."""

        def __init__(self, inner_dim):
            super().__init__()
            self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
            self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
            self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

        def forward(self, x, scale="1x"):
            # Dispatch on the requested compression scale; cast the input to
            # the layer's weight dtype before convolving.
            scale_to_layer = {"1x": self.proj, "2x": self.proj_2x, "4x": self.proj_4x}
            if scale not in scale_to_layer:
                raise ValueError(f"Unsupported scale: {scale}")
            layer = scale_to_layer[scale]
            return layer(x.to(layer.weight.dtype))

    dit_model.clean_x_embedder = CleanXEmbedder(inner_dim)
    model_dtype = next(dit_model.parameters()).dtype
    dit_model.clean_x_embedder = dit_model.clean_x_embedder.to(dtype=model_dtype)
    print("✅ 添加了FramePack的clean_x_embedder组件")
|
| 111 |
+
|
| 112 |
+
def generate_camera_embeddings_sliding(cam_data, start_frame, current_history_length, new_frames, total_generated, use_real_poses=True):
    """Build the per-frame camera embedding sequence for sliding-window FramePack.

    Each row is a flattened 3x4 relative pose (12 values) plus a one-value
    condition-mask column, giving a ``[max_needed_frames, 13]`` bfloat16 tensor.

    Args:
        cam_data: dict holding an ``'extrinsic'`` sequence of 4x4 camera
            matrices, or None to fall back to a synthetic trajectory.
        start_frame: first latent frame index marked as condition.
        current_history_length: number of latent frames already available;
            rows [start_frame, start_frame + current_history_length) get
            mask value 1.0.
        new_frames: number of latent frames to be generated next.
        total_generated: unused; kept for call-site compatibility.
        use_real_poses: when False, always use the synthetic trajectory.

    Returns:
        torch.Tensor of shape ``[max_needed_frames, 13]``, dtype bfloat16.
    """
    time_compression_ratio = 4

    # FramePack token layout: 1 (start) + 16 (4x) + 2 (2x) + 1 (1x) + targets.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    # The sequence must cover the history window, the FramePack layout,
    # and a minimum floor of 30 frames (identical in both branches, so
    # computed once here).
    max_needed_frames = max(
        start_frame + current_history_length + new_frames,  # base requirement
        framepack_needed_frames,                            # FramePack layout
        30                                                  # minimum length
    )

    if use_real_poses and cam_data is not None and 'extrinsic' in cam_data:
        used_real = True
        print("🔧 使用真实camera数据")
        cam_extrinsic = cam_data['extrinsic']

        print(f"🔧 计算camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        relative_poses = []
        for i in range(max_needed_frames):
            # Map latent frame i onto the raw (uncompressed) frame index.
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio

            if next_frame_idx < len(cam_extrinsic):
                cam_prev = cam_extrinsic[frame_idx]
                cam_next = cam_extrinsic[next_frame_idx]
                relative_pose = compute_relative_pose(cam_prev, cam_next)
                relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
            else:
                # Past the end of the recorded trajectory: zero motion.
                print(f"⚠️ 帧{frame_idx}超出camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))
    else:
        used_real = False
        print("🔧 使用合成camera数据")

        print(f"🔧 生成合成camera帧数: {max_needed_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")

        # Constant per-frame motion: yaw (negative = clockwise about +Y here)
        # while advancing along local -Z with a slight inward drift, i.e. a
        # circular walk. The per-frame relative pose does not depend on the
        # frame index, so it is built once outside the loop. (The previous
        # revision recomputed it every iteration and also kept an unused
        # cumulative-yaw variable.)
        yaw_per_frame = -0.05   # per-frame yaw angle (radians)
        forward_speed = 0.005   # per-frame forward translation
        radius_drift = 0.002    # slight lateral drift toward the circle center

        pose = np.eye(4, dtype=np.float32)
        cos_yaw = np.cos(yaw_per_frame)
        sin_yaw = np.sin(yaw_per_frame)
        pose[0, 0] = cos_yaw
        pose[0, 2] = sin_yaw
        pose[2, 0] = -sin_yaw
        pose[2, 2] = cos_yaw
        pose[2, 3] = -forward_speed  # advance along local -Z (forward)
        pose[0, 3] = radius_drift    # lateral drift

        relative_poses = [torch.as_tensor(pose[:3, :]) for _ in range(max_needed_frames)]

    pose_embedding = torch.stack(relative_poses, dim=0)
    # Flatten each 3x4 pose into a 12-vector (equivalent to the previous
    # einops rearrange 'b c d -> b (c d)').
    pose_embedding = pose_embedding.reshape(max_needed_frames, -1)

    # Condition mask: 1.0 on rows already available as history.
    mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
    condition_end = min(start_frame + current_history_length, max_needed_frames)
    mask[start_frame:condition_end] = 1.0

    camera_embedding = torch.cat([pose_embedding, mask], dim=1)
    label = "真实" if used_real else "合成"
    print(f"🔧 {label}camera embedding shape: {camera_embedding.shape} (总长度:{max_needed_frames})")
    return camera_embedding.to(torch.bfloat16)
|
| 222 |
+
|
| 223 |
+
def prepare_framepack_sliding_window_with_camera(history_latents, target_frames_to_generate, camera_embedding_full, start_frame, max_history_frames=49):
    """FramePack sliding-window preparation - clean_latents = first 4 frames + last frame.

    (Fixes the mojibake in the previous docstring and drops a dead
    ``clean_latents_1x`` slice that was computed but never used.)

    Args:
        history_latents: ``[C, T, H, W]`` latents generated so far.
        target_frames_to_generate: number of new latent frames to denoise.
        camera_embedding_full: ``[N, 13]`` per-frame camera embedding
            (flattened pose + mask column); zero-padded when shorter than
            the token layout.
        start_frame: kept for call-site compatibility; not used here.
        max_history_frames: kept for call-site compatibility; not used here.

    Returns:
        dict with multi-scale clean latents, their index tensors, the
        trimmed camera embedding, and the current/next history lengths.
    """
    C, T, H, W = history_latents.shape

    # Fixed token layout: 1 (start) + 16 (4x) + 2 (2x) + 5 (1x) + targets.
    total_indices_length = 1 + 16 + 2 + 5 + target_frames_to_generate
    indices = torch.arange(0, total_indices_length)
    split_sizes = [1, 16, 2, 5, target_frames_to_generate]
    clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = \
        indices.split(split_sizes, dim=0)

    # The 5-slot 1x section carries clean_latents (first 4 frames + last frame).
    clean_latent_indices = clean_latent_1x_indices

    # Zero-pad the camera sequence when it is shorter than the token layout.
    if camera_embedding_full.shape[0] < total_indices_length:
        shortage = total_indices_length - camera_embedding_full.shape[0]
        padding = torch.zeros(shortage, camera_embedding_full.shape[1],
                              dtype=camera_embedding_full.dtype, device=camera_embedding_full.device)
        camera_embedding_full = torch.cat([camera_embedding_full, padding], dim=0)

    # Take the layout-sized prefix; clone so the mask rewrite below does not
    # mutate the caller's full sequence.
    combined_camera = camera_embedding_full[:total_indices_length, :].clone()

    # Rebuild the condition mask (last column): everything is target (0.0),
    # then the tail of the 24 clean slots that actually holds history is
    # marked as condition (1.0).
    combined_camera[:, -1] = 0.0
    if T > 0:
        available_frames = min(T, 24)
        start_pos = 24 - available_frames
        combined_camera[start_pos:24, -1] = 1.0

    print(f"🔧 Camera mask更新:")
    print(f" - 历史帧数: {T}")
    print(f" - 有效condition帧数: {available_frames if T > 0 else 0}")
    print(f" - Condition mask (前24帧): {combined_camera[:24, -1].cpu().tolist()}")
    print(f" - Target mask (后{target_frames_to_generate}帧): {combined_camera[24:, -1].cpu().tolist()}")

    # Right-align the most recent history into the 24 clean slots; earlier
    # slots stay zero when history is short.
    clean_latents_combined = torch.zeros(C, 24, H, W, dtype=history_latents.dtype, device=history_latents.device)
    if T > 0:
        available_frames = min(T, 24)
        start_pos = 24 - available_frames
        clean_latents_combined[:, start_pos:, :, :] = history_latents[:, -available_frames:, :, :]

    clean_latents_4x = clean_latents_combined[:, 0:16, :, :]
    clean_latents_2x = clean_latents_combined[:, 16:18, :, :]

    # clean_latents: first 4 history frames + last history frame (5 total).
    if T >= 5:
        start_latent = history_latents[:, 0:4, :, :]   # first 4 frames
        last_latent = history_latents[:, -1:, :, :]    # most recent frame
        clean_latents = torch.cat([start_latent, last_latent], dim=1)
    elif T > 0:
        # Fewer than 5 frames of history: right-align over zeros.
        clean_latents = torch.zeros(C, 5, H, W, dtype=history_latents.dtype, device=history_latents.device)
        clean_latents[:, -T:, :, :] = history_latents
    else:
        # No history at all: all-zero clean latents.
        clean_latents = torch.zeros(C, 5, H, W, dtype=history_latents.dtype, device=history_latents.device)

    return {
        'latent_indices': latent_indices,
        'clean_latents': clean_latents,  # 5 frames: first 4 + last 1
        'clean_latents_2x': clean_latents_2x,
        'clean_latents_4x': clean_latents_4x,
        'clean_latent_indices': clean_latent_indices,
        'clean_latent_2x_indices': clean_latent_2x_indices,
        'clean_latent_4x_indices': clean_latent_4x_indices,
        'camera_embedding': combined_camera,
        'current_length': T,
        'next_length': T + target_frames_to_generate
    }
|
| 307 |
+
|
| 308 |
+
def inference_sekai_framepack_sliding_window(
    condition_pth_path,
    dit_path,
    output_path="sekai/infer_results/output_sekai_framepack_sliding.mp4",
    start_frame=0,
    initial_condition_frames=8,
    frames_per_generation=4,
    total_frames_to_generate=32,
    max_history_frames=49,
    device="cuda",
    prompt="A video of a scene shot using a pedestrian's front camera while walking",
    use_real_poses=True,
    synthetic_direction="forward",
    # CFG controls
    use_camera_cfg=True,
    camera_guidance_scale=2.0,
    text_guidance_scale=7.5
):
    """
    FramePack sliding-window video generation with Camera CFG support.

    Loads the Wan2.1 pipeline, patches camera-encoder and FramePack
    components into the DiT, then repeatedly denoises
    `frames_per_generation` new latent frames conditioned on a bounded
    history window until `total_frames_to_generate` latent frames exist;
    finally decodes the latents and writes an mp4 to `output_path`.

    NOTE(review): `synthetic_direction` is accepted but never used in this
    body — the synthetic trajectory is fixed inside
    generate_camera_embeddings_sliding. Verify against callers before
    removing.
    """
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    print(f"🔧 FramePack滑动窗口生成开始...")
    print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}")
    print(f"Text guidance scale: {text_guidance_scale}")
    print(f"初始条件帧: {initial_condition_frames}, 每次生成: {frames_per_generation}, 总生成: {total_frames_to_generate}")
    print(f"使用真实姿态: {use_real_poses}")
    if not use_real_poses:
        print(f"合成camera方向: {synthetic_direction}")

    # 1-3. Model initialization and component patching.
    replace_dit_model_in_manager()

    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # Inject a camera encoder (13-dim pose+mask -> hidden dim) and an
    # identity-initialized projector into every DiT block; the checkpoint
    # loaded below overwrites these weights (strict=True requires them to
    # exist first).
    dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        block.cam_encoder = nn.Linear(13, dim)
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    add_framepack_components(pipe.dit)

    dit_state_dict = torch.load(dit_path, map_location="cpu")
    pipe.dit.load_state_dict(dit_state_dict, strict=True)
    pipe = pipe.to(device)
    model_dtype = next(pipe.dit.parameters()).dtype

    # Keep the FramePack embedder in the same dtype as the rest of the model.
    if hasattr(pipe.dit, 'clean_x_embedder'):
        pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)

    pipe.scheduler.set_timesteps(50)

    # 4. Load the initial condition latents.
    print("Loading initial condition frames...")
    initial_latents, encoded_data = load_encoded_video_from_pth(
        condition_pth_path,
        start_frame=start_frame,
        num_frames=initial_condition_frames
    )

    # Center-crop the latents spatially to the model's working resolution.
    target_height, target_width = 60, 104
    C, T, H, W = initial_latents.shape

    if H > target_height or W > target_width:
        h_start = (H - target_height) // 2
        w_start = (W - target_width) // 2
        initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width]
        H, W = target_height, target_width

    history_latents = initial_latents.to(device, dtype=model_dtype)

    print(f"初始history_latents shape: {history_latents.shape}")

    # Encode the prompt; a negative (empty) prompt is only needed for text CFG.
    if text_guidance_scale > 1.0:
        prompt_emb_pos = pipe.encode_prompt(prompt)
        prompt_emb_neg = pipe.encode_prompt("")
        print(f"使用Text CFG,guidance scale: {text_guidance_scale}")
    else:
        prompt_emb_pos = pipe.encode_prompt(prompt)
        prompt_emb_neg = None
        print("不使用Text CFG")

    # Pre-generate the full camera-embedding sequence once for all windows.
    camera_embedding_full = generate_camera_embeddings_sliding(
        encoded_data.get('cam_emb', None),
        0,
        max_history_frames,
        0,
        0,
        use_real_poses=use_real_poses
    ).to(device, dtype=model_dtype)

    print(f"完整camera序列shape: {camera_embedding_full.shape}")

    # Unconditional (all-zero) camera embedding for Camera CFG.
    if use_camera_cfg:
        camera_embedding_uncond = torch.zeros_like(camera_embedding_full)
        print(f"创建无条件camera embedding用于CFG")

    # Sliding-window generation loop.
    total_generated = 0
    all_generated_frames = []

    while total_generated < total_frames_to_generate:
        current_generation = min(frames_per_generation, total_frames_to_generate - total_generated)
        print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}")
        print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}")

        # Build the FramePack inputs (multi-scale clean latents + camera mask)
        # for this window.
        framepack_data = prepare_framepack_sliding_window_with_camera(
            history_latents,
            current_generation,
            camera_embedding_full,
            start_frame,
            max_history_frames
        )

        # Add the batch dimension expected by the DiT.
        clean_latents = framepack_data['clean_latents'].unsqueeze(0)
        clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0)
        clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0)
        camera_embedding = framepack_data['camera_embedding'].unsqueeze(0)

        # Unconditional camera embedding trimmed to this window's length.
        # NOTE(review): camera_embedding.shape[1] is the per-frame length
        # after unsqueeze(0), so this slices rows of the [N, 13] uncond
        # tensor — appears intentional, but confirm against the DiT API.
        if use_camera_cfg:
            camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0)

        # Index tensors (kept on CPU).
        latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu()
        clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu()
        clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu()
        clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu()

        # Start this window's target frames from pure noise.
        new_latents = torch.randn(
            1, C, current_generation, H, W,
            device=device, dtype=model_dtype
        )

        extra_input = pipe.prepare_extra_input(new_latents)

        print(f"Camera embedding shape: {camera_embedding.shape}")
        print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}")

        # Denoising loop with optional camera and/or text CFG.
        timesteps = pipe.scheduler.timesteps

        for i, timestep in enumerate(timesteps):
            if i % 10 == 0:
                print(f" 去噪步骤 {i+1}/{len(timesteps)}")

            timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)

            with torch.no_grad():
                if use_camera_cfg and camera_guidance_scale > 1.0:
                    # Conditional prediction (with camera embedding).
                    noise_pred_cond = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

                    # Unconditional prediction (zero camera embedding).
                    noise_pred_uncond = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding_uncond_batch,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **(prompt_emb_neg if prompt_emb_neg else prompt_emb_pos),
                        **extra_input
                    )

                    # Camera CFG combination.
                    noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_cond - noise_pred_uncond)

                    # Optionally stack text CFG on top of the camera-guided result.
                    if text_guidance_scale > 1.0 and prompt_emb_neg:
                        # Text-unconditional prediction (negative prompt, with camera).
                        noise_pred_text_uncond = pipe.dit(
                            new_latents,
                            timestep=timestep_tensor,
                            cam_emb=camera_embedding,
                            latent_indices=latent_indices,
                            clean_latents=clean_latents,
                            clean_latent_indices=clean_latent_indices,
                            clean_latents_2x=clean_latents_2x,
                            clean_latent_2x_indices=clean_latent_2x_indices,
                            clean_latents_4x=clean_latents_4x,
                            clean_latent_4x_indices=clean_latent_4x_indices,
                            **prompt_emb_neg,
                            **extra_input
                        )

                        # Apply text CFG to the already camera-guided prediction.
                        noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond)

                elif text_guidance_scale > 1.0 and prompt_emb_neg:
                    # Text CFG only (camera CFG disabled).
                    noise_pred_cond = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

                    noise_pred_uncond = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_neg,
                        **extra_input
                    )

                    noise_pred = noise_pred_uncond + text_guidance_scale * (noise_pred_cond - noise_pred_uncond)

                else:
                    # Plain inference, no CFG.
                    noise_pred = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

            new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents)

        # Append the freshly denoised frames to the history.
        new_latents_squeezed = new_latents.squeeze(0)
        history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1)

        # Maintain the sliding window: keep the first frame (anchor) plus the
        # most recent max_history_frames-1 frames.
        if history_latents.shape[1] > max_history_frames:
            first_frame = history_latents[:, 0:1, :, :]
            recent_frames = history_latents[:, -(max_history_frames-1):, :, :]
            history_latents = torch.cat([first_frame, recent_frames], dim=1)
            print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧")

        print(f"更新后history_latents shape: {history_latents.shape}")

        all_generated_frames.append(new_latents_squeezed)
        total_generated += current_generation

        print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧")

    # 7. Decode the full latent sequence and write the video.
    print("\n🔧 解码生成的视频...")

    all_generated = torch.cat(all_generated_frames, dim=1)
    final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0)

    print(f"最终视频shape: {final_video.shape}")

    decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))

    print(f"Saving video to {output_path}")

    # [-1, 1] -> [0, 255] uint8, frame-major layout for the writer.
    video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
    video_np = (video_np * 0.5 + 0.5).clip(0, 1)
    video_np = (video_np * 255).astype(np.uint8)

    with imageio.get_writer(output_path, fps=20) as writer:
        for frame in video_np:
            writer.append_data(frame)

    print(f"🔧 FramePack滑动窗口生成完成! 保存到: {output_path}")
    print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧")
|
| 628 |
+
|
| 629 |
+
def main():
    """CLI entry point: parse arguments and run sliding-window generation."""
    parser = argparse.ArgumentParser(description="Sekai FramePack滑动窗口视频生成 - 支持CFG")
    parser.add_argument("--condition_pth", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking/00100100001_0004650_0004950/encoded_video.pth")
    parser.add_argument("--start_frame", type=int, default=0)
    parser.add_argument("--initial_condition_frames", type=int, default=16)
    parser.add_argument("--frames_per_generation", type=int, default=8)
    parser.add_argument("--total_frames_to_generate", type=int, default=60)
    parser.add_argument("--max_history_frames", type=int, default=100)
    # NOTE(review): store_true with default=True means this flag can never
    # be switched off from the CLI; kept as-is for compatibility.
    parser.add_argument("--use_real_poses", action="store_true", default=True)
    parser.add_argument("--dit_path", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/sekai/sekai_walking_framepack_4/step34290_framepack.ckpt")
    parser.add_argument("--output_path", type=str,
                        default='/home/zhuyixuan05/ReCamMaster/sekai/infer_framepack_results/output_sekai_framepack_sliding.mp4')
    parser.add_argument("--prompt", type=str,
                        default="A drone flying scene in a game world")
    parser.add_argument("--device", type=str, default="cuda")

    # CFG options.
    # BUG FIX: previously declared with only default=False (no action), so
    # passing a bare `--use_camera_cfg` made argparse error out expecting a
    # value; store_true makes it a proper boolean switch with the same
    # default.
    parser.add_argument("--use_camera_cfg", action="store_true", default=False,
                        help="使用Camera CFG")
    parser.add_argument("--camera_guidance_scale", type=float, default=2.0,
                        help="Camera guidance scale for CFG")
    parser.add_argument("--text_guidance_scale", type=float, default=1.0,
                        help="Text guidance scale for CFG")

    args = parser.parse_args()

    print(f"🔧 FramePack CFG生成设置:")
    print(f"Camera CFG: {args.use_camera_cfg}")
    if args.use_camera_cfg:
        print(f"Camera guidance scale: {args.camera_guidance_scale}")
    print(f"Text guidance scale: {args.text_guidance_scale}")

    inference_sekai_framepack_sliding_window(
        condition_pth_path=args.condition_pth,
        dit_path=args.dit_path,
        output_path=args.output_path,
        start_frame=args.start_frame,
        initial_condition_frames=args.initial_condition_frames,
        frames_per_generation=args.frames_per_generation,
        total_frames_to_generate=args.total_frames_to_generate,
        max_history_frames=args.max_history_frames,
        device=args.device,
        prompt=args.prompt,
        use_real_poses=args.use_real_poses,
        # CFG parameters
        use_camera_cfg=args.use_camera_cfg,
        camera_guidance_scale=args.camera_guidance_scale,
        text_guidance_scale=args.text_guidance_scale
    )
|
| 680 |
+
|
| 681 |
+
# Script entry point: parse CLI arguments and launch generation.
if __name__ == "__main__":
    main()
|
scripts/infer_sekai_framepack_test.py
ADDED
|
@@ -0,0 +1,551 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import numpy as np
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import imageio
|
| 7 |
+
import json
|
| 8 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 9 |
+
import argparse
|
| 10 |
+
from torchvision.transforms import v2
|
| 11 |
+
from einops import rearrange
|
| 12 |
+
import copy
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Load a pre-encoded latent video clip from a .pth checkpoint.

    Args:
        pth_path: path to the .pth file holding a dict with a 'latents'
            tensor of shape [C, T, H, W] (latent-time indices).
        start_frame: first latent frame of the window to extract.
        num_frames: number of latent frames to extract.

    Returns:
        Tuple of (condition latents [C, num_frames, H, W], full payload dict).

    Raises:
        ValueError: when the requested window extends past the stored clip.
    """
    print(f"Loading encoded video from {pth_path}")

    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # load .pth files from trusted sources.
    payload = torch.load(pth_path, weights_only=False, map_location="cpu")
    all_latents = payload['latents']  # [C, T, H, W]

    end_frame = start_frame + num_frames
    print(f"Full latents shape: {all_latents.shape}")
    print(f"Extracting frames {start_frame} to {end_frame}")

    if end_frame > all_latents.shape[1]:
        raise ValueError(f"Not enough frames: requested {end_frame}, available {all_latents.shape[1]}")

    clip = all_latents[:, start_frame:end_frame, :, :]
    print(f"Extracted condition latents shape: {clip.shape}")

    return clip, payload
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Return the pose of camera B relative to camera A, i.e. B @ A^-1.

    Both inputs must be 4x4 extrinsic matrices (numpy arrays or, in torch
    mode, optionally torch tensors). The result type follows `use_torch`.
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        # Promote numpy inputs to float tensors before inverting.
        if not isinstance(pose_a, torch.Tensor):
            pose_a = torch.from_numpy(pose_a).float()
        if not isinstance(pose_b, torch.Tensor):
            pose_b = torch.from_numpy(pose_b).float()
        return torch.matmul(pose_b, torch.inverse(pose_a))

    # Numpy path: coerce any array-likes to float32 arrays first.
    if not isinstance(pose_a, np.ndarray):
        pose_a = np.array(pose_a, dtype=np.float32)
    if not isinstance(pose_b, np.ndarray):
        pose_b = np.array(pose_b, dtype=np.float32)
    return np.matmul(pose_b, np.linalg.inv(pose_a))
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def prepare_framepack_inputs(full_latents, condition_frames, target_frames, start_frame=0):
    """🔧 Build FramePack-style multi-scale conditioning inputs.

    Args:
        full_latents: latent tensor, [C, T, H, W] or [B, C, T, H, W].
        condition_frames: number of (compressed) history frames used as condition.
        target_frames: number of (compressed) frames that will be denoised.
        start_frame: index of the first condition frame within full_latents.

    Returns:
        dict with the main target latents, 1x/2x/4x clean-latent tensors, and the
        frame-index tensors matching each scale (unused padded slots hold -1).
    """
    # Normalise to a batched [B, C, T, H, W] layout; remember whether to undo it.
    if len(full_latents.shape) == 4:  # [C, T, H, W]
        full_latents = full_latents.unsqueeze(0)  # -> [1, C, T, H, W]
        squeeze_batch = True
    else:
        squeeze_batch = False

    B, C, T, H, W = full_latents.shape
    device = full_latents.device

    # Main latents: the target window to be denoised.
    target_start = start_frame + condition_frames
    target_end = target_start + target_frames
    latent_indices = torch.arange(target_start, target_end)
    main_latents = full_latents[:, :, latent_indices, :, :]

    # 🔧 1x condition frames: first condition frame + last condition frame.
    clean_latent_indices = torch.tensor([start_frame, start_frame + condition_frames - 1])
    clean_latents = full_latents[:, :, clean_latent_indices, :, :]

    # 🔧 2x condition frames: last 2 condition frames, right-aligned, zero-padded.
    # Fix: allocate padded buffers on the same device as full_latents so that
    # GPU-resident latents don't fail the slice assignment below.
    clean_latents_2x = torch.zeros(B, C, 2, H, W, dtype=full_latents.dtype, device=device)
    clean_latent_2x_indices = torch.full((2,), -1, dtype=torch.long)

    if condition_frames >= 2:
        actual_indices = torch.arange(max(start_frame, start_frame + condition_frames - 2),
                                      start_frame + condition_frames)
        start_pos = 2 - len(actual_indices)
        clean_latents_2x[:, :, start_pos:, :, :] = full_latents[:, :, actual_indices, :, :]
        clean_latent_2x_indices[start_pos:] = actual_indices

    # 🔧 4x condition frames: up to the last 16 condition frames, right-aligned.
    clean_latents_4x = torch.zeros(B, C, 16, H, W, dtype=full_latents.dtype, device=device)
    clean_latent_4x_indices = torch.full((16,), -1, dtype=torch.long)

    if condition_frames >= 1:
        actual_indices = torch.arange(max(start_frame, start_frame + condition_frames - 16),
                                      start_frame + condition_frames)
        start_pos = 16 - len(actual_indices)
        clean_latents_4x[:, :, start_pos:, :, :] = full_latents[:, :, actual_indices, :, :]
        clean_latent_4x_indices[start_pos:] = actual_indices

    # Drop the batch dimension again if the caller passed an unbatched tensor.
    if squeeze_batch:
        main_latents = main_latents.squeeze(0)
        clean_latents = clean_latents.squeeze(0)
        clean_latents_2x = clean_latents_2x.squeeze(0)
        clean_latents_4x = clean_latents_4x.squeeze(0)

    return {
        'latents': main_latents,
        'clean_latents': clean_latents,
        'clean_latents_2x': clean_latents_2x,
        'clean_latents_4x': clean_latents_4x,
        'latent_indices': latent_indices,
        'clean_latent_indices': clean_latent_indices,
        'clean_latent_2x_indices': clean_latent_2x_indices,
        'clean_latent_4x_indices': clean_latent_4x_indices,
    }
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def generate_camera_poses_from_data(cam_data, start_frame, condition_frames, target_frames):
    """Build a [frames, 13] camera embedding (12 pose values + condition mask)
    from recorded camera extrinsics, stepping at the latent compression ratio."""
    time_compression_ratio = 4
    total_frames = condition_frames + target_frames

    cam_extrinsic = cam_data['extrinsic']  # [N, 4, 4]
    start_frame_original = start_frame * time_compression_ratio

    print(f"Using camera data from frame {start_frame_original}")

    # Relative pose between each pair of consecutive (compressed) frames.
    relative_poses = []
    for i in range(total_frames):
        frame_idx = start_frame_original + i * time_compression_ratio
        next_frame_idx = frame_idx + time_compression_ratio

        if next_frame_idx >= len(cam_extrinsic):
            # Past the end of the recording: repeat the last pose (or zero motion).
            print('Out of temporal range, using last available pose')
            fallback = relative_poses[-1] if relative_poses else torch.zeros(3, 4)
            relative_poses.append(fallback)
        else:
            rel = compute_relative_pose(cam_extrinsic[frame_idx], cam_extrinsic[next_frame_idx])
            relative_poses.append(torch.as_tensor(rel[:3, :]))

    # Flatten each 3x4 pose into 12 values: [frames, 12].
    pose_embedding = torch.stack(relative_poses, dim=0).reshape(total_frames, -1)

    # Mask column: 1 for condition frames, 0 for frames to be generated.
    mask = torch.zeros(total_frames, 1, dtype=torch.float32)
    mask[:condition_frames] = 1.0

    camera_embedding = torch.cat([pose_embedding, mask], dim=1)  # [frames, 13]
    print(f"Generated camera embedding shape: {camera_embedding.shape}")

    return camera_embedding.to(torch.bfloat16)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def generate_synthetic_camera_poses(direction="forward", target_frames=10, condition_frames=20):
    """Generate a synthetic camera trajectory for the requested movement direction
    and return the [frames, 13] embedding (12 flattened pose values + condition mask)."""
    total_frames = condition_frames + target_frames

    def _set_yaw(mat, yaw):
        # Write a yaw (Y-axis) rotation into the 4x4 pose in place.
        mat[0, 0] = np.cos(yaw)
        mat[0, 2] = np.sin(yaw)
        mat[2, 0] = -np.sin(yaw)
        mat[2, 2] = np.cos(yaw)

    # Absolute poses along the synthetic trajectory, parameterised by t in [0, 1].
    poses = []
    for i in range(total_frames):
        t = i / max(1, total_frames - 1)
        pose = np.eye(4, dtype=np.float32)

        if direction == "forward":
            pose[2, 3] = -t * 0.04
        elif direction == "backward":
            pose[2, 3] = t * 2.0
        elif direction == "left_turn":
            pose[2, 3] = -t * 0.03
            pose[0, 3] = t * 0.02
            _set_yaw(pose, t * 1)
        elif direction == "right_turn":
            pose[2, 3] = -t * 0.03
            pose[0, 3] = -t * 0.02
            _set_yaw(pose, -t * 1)

        poses.append(pose)

    # Consecutive-frame relative poses; duplicate the last one so counts match.
    relative_poses = [
        torch.as_tensor(compute_relative_pose(prev, nxt)[:3, :])
        for prev, nxt in zip(poses, poses[1:])
    ]
    if len(relative_poses) < total_frames:
        relative_poses.append(relative_poses[-1])

    pose_embedding = torch.stack(relative_poses[:total_frames], dim=0)
    pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')  # [frames, 12]

    # Mask column: 1 for condition frames, 0 for frames to be generated.
    mask = torch.zeros(total_frames, dtype=torch.float32)
    mask[:condition_frames] = 1.0
    mask = mask.view(-1, 1)

    camera_embedding = torch.cat([pose_embedding, mask], dim=1)  # [frames, 13]
    print(f"Generated {direction} movement poses: {camera_embedding.shape}")

    return camera_embedding.to(torch.bfloat16)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def replace_dit_model_in_manager():
    """Patch diffsynth's loader registry so 'wan_video_dit' maps to the
    FramePack-capable WanModelFuture class instead of the stock DiT."""
    from diffsynth.models.wan_video_dit_recam_future import WanModelFuture
    from diffsynth.configs.model_config import model_loader_configs

    for idx, entry in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, names, classes, resource = entry

        # Only touch registry entries that actually contain the DiT model.
        if 'wan_video_dit' not in names:
            continue

        patched_names = []
        patched_classes = []
        for name, cls in zip(names, classes):
            patched_names.append(name)  # names are kept as-is
            if name == 'wan_video_dit':
                patched_classes.append(WanModelFuture)
                print(f"✅ 替换了模型类: {name} -> WanModelFuture")
            else:
                patched_classes.append(cls)

        model_loader_configs[idx] = (keys_hash, keys_hash_with_shape, patched_names, patched_classes, resource)
|
| 246 |
+
|
| 247 |
+
def add_framepack_components(dit_model):
    """Attach FramePack's multi-scale clean_x_embedder to a DiT model.

    No-op when the model already carries a clean_x_embedder. The embedder holds
    three Conv3d patchify projections (1x / 2x / 4x temporal-spatial scales) and
    is cast to the dtype of the model's parameters.
    """
    if hasattr(dit_model, 'clean_x_embedder'):
        return

    # Infer the transformer width from the first block's query projection.
    inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]

    class CleanXEmbedder(nn.Module):
        def __init__(self, inner_dim):
            super().__init__()
            self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
            self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
            self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

        def forward(self, x, scale="1x"):
            # 🔧 Cast the input to the selected projection's weight dtype so
            # mixed-precision models accept fp32 latents.
            scale_to_layer = {"1x": self.proj, "2x": self.proj_2x, "4x": self.proj_4x}
            if scale not in scale_to_layer:
                raise ValueError(f"Unsupported scale: {scale}")
            layer = scale_to_layer[scale]
            return layer(x.to(layer.weight.dtype))

    dit_model.clean_x_embedder = CleanXEmbedder(inner_dim)
    # 🔧 Use the dtype of the model's parameters rather than a model-level attribute.
    model_dtype = next(dit_model.parameters()).dtype
    dit_model.clean_x_embedder = dit_model.clean_x_embedder.to(dtype=model_dtype)
    print("✅ 添加了FramePack的clean_x_embedder组件")
|
| 278 |
+
|
| 279 |
+
def inference_sekai_framepack_from_pth(
    condition_pth_path,
    dit_path,
    output_path="sekai/infer_results/output_sekai_framepack.mp4",
    start_frame=0,
    condition_frames=10,
    target_frames=2,
    device="cuda",
    prompt="A video of a scene shot using a pedestrian's front camera while walking",
    direction="forward",
    use_real_poses=True
):
    """FramePack-style Sekai video inference.

    Loads pre-encoded latents from a .pth file, conditions a FramePack-patched
    Wan DiT on multi-scale history latents plus a camera-pose embedding, runs a
    50-step denoising loop, decodes with the VAE and writes an mp4.

    Args:
        condition_pth_path: path to a .pth dict containing at least 'latents'
            ([C, T, H, W], compressed frames) and optionally 'cam_emb'.
        dit_path: checkpoint with the trained FramePack DiT state dict.
        output_path: destination mp4 path (parent dirs are created).
        start_frame: first compressed-latent frame of the condition window.
        condition_frames: number of compressed condition frames.
        target_frames: number of compressed frames to generate.
        device: torch device string for inference.
        prompt: text prompt fed to the T5 encoder.
        direction: synthetic camera motion used when real poses are unavailable.
        use_real_poses: prefer 'cam_emb' from the .pth when present.
    """
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    print(f"Setting up FramePack models for {direction} movement...")

    # 1. Swap the registered DiT class for the FramePack variant, then load models.
    replace_dit_model_in_manager()

    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # 2. Add per-block camera conditioning layers and the FramePack components.
    #    cam_encoder starts at zero and projector at identity so the untrained
    #    layers are a no-op until the checkpoint below overwrites them.
    dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        block.cam_encoder = nn.Linear(13, dim)
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    # Add the FramePack clean_x_embedder (must exist before load_state_dict).
    add_framepack_components(pipe.dit)

    # 3. Load the trained weights (strict: every key must match).
    dit_state_dict = torch.load(dit_path, map_location="cpu")
    pipe.dit.load_state_dict(dit_state_dict, strict=True)

    pipe = pipe.to(device)
    model_dtype = next(pipe.dit.parameters()).dtype
    pipe.dit = pipe.dit.to(dtype=model_dtype)
    if hasattr(pipe.dit, 'clean_x_embedder'):
        pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)

    pipe.scheduler.set_timesteps(50)
    print("Loading condition video from pth...")

    # 4. Load the pre-encoded condition latents.
    condition_latents, encoded_data = load_encoded_video_from_pth(
        condition_pth_path,
        start_frame=start_frame,
        num_frames=condition_frames
    )

    print("Preparing FramePack inputs...")

    # 5. 🔧 Build the FramePack-style multi-scale conditioning inputs.
    full_latents = encoded_data['latents']
    framepack_inputs = prepare_framepack_inputs(
        full_latents, condition_frames, target_frames, start_frame
    )

    # 🔧 Move tensors to the target device/dtype to match the DiT model.
    for key in framepack_inputs:
        if torch.is_tensor(framepack_inputs[key]):
            framepack_inputs[key] = framepack_inputs[key].to(device, dtype=model_dtype)

    print("Processing poses...")

    # 6. Camera pose embedding: recorded poses when available, else synthetic.
    if use_real_poses and 'cam_emb' in encoded_data:
        print("Using real camera poses from data")
        camera_embedding = generate_camera_poses_from_data(
            encoded_data['cam_emb'],
            start_frame=start_frame,
            condition_frames=condition_frames,
            target_frames=target_frames
        )
    else:
        print(f"Using synthetic {direction} poses")
        camera_embedding = generate_synthetic_camera_poses(
            direction=direction,
            target_frames=target_frames,
            condition_frames=condition_frames
        )

    camera_embedding = camera_embedding.unsqueeze(0).to(device, dtype=model_dtype)
    print("Encoding prompt...")

    # 7. Encode the text prompt.
    prompt_emb = pipe.encode_prompt(prompt)
    print("Generating video...")

    # 8. Set up the target latents to denoise.
    batch_size = 1
    channels = framepack_inputs['latents'].shape[0]  # latents carry no batch dim here
    latent_height = framepack_inputs['latents'].shape[2]
    latent_width = framepack_inputs['latents'].shape[3]

    # Center-crop spatially to bound memory use.
    target_height, target_width = 60, 104

    if latent_height > target_height or latent_width > target_width:
        h_start = (latent_height - target_height) // 2
        w_start = (latent_width - target_width) // 2

        # Crop every latent tensor consistently.
        for key in ['latents', 'clean_latents', 'clean_latents_2x', 'clean_latents_4x']:
            if key in framepack_inputs and torch.is_tensor(framepack_inputs[key]):
                framepack_inputs[key] = framepack_inputs[key][:, :,
                    h_start:h_start+target_height,
                    w_start:w_start+target_width]

        latent_height = target_height
        latent_width = target_width

    # Add a batch dimension for inference.
    for key in ['latents', 'clean_latents', 'clean_latents_2x', 'clean_latents_4x']:
        if key in framepack_inputs and torch.is_tensor(framepack_inputs[key]):
            framepack_inputs[key] = framepack_inputs[key].unsqueeze(0)

    # 🔧 Index tensors: batch them too, and keep them long-typed on CPU.
    for key in ['latent_indices', 'clean_latent_indices', 'clean_latent_2x_indices', 'clean_latent_4x_indices']:
        if key in framepack_inputs and torch.is_tensor(framepack_inputs[key]):
            # Indices must be long and on CPU for indexing inside the model.
            framepack_inputs[key] = framepack_inputs[key].long().cpu().unsqueeze(0)

    # Initialise the target latents with Gaussian noise.
    target_latents = torch.randn(
        batch_size, channels, target_frames, latent_height, latent_width,
        device=device, dtype=model_dtype  # 🔧 match the model's dtype
    )

    print(f"FramePack inputs:")
    for key, value in framepack_inputs.items():
        if torch.is_tensor(value):
            print(f"  {key}: {value.shape} {value.dtype}")
        else:
            print(f"  {key}: {value}")
    print(f"Camera embedding shape: {camera_embedding.shape}")
    print(f"Target latents shape: {target_latents.shape}")

    # 9. Pipeline-specific extra inputs (e.g. positional metadata).
    extra_input = pipe.prepare_extra_input(target_latents)

    # 10. 🔧 FramePack-style denoising loop over the scheduler's 50 timesteps.
    timesteps = pipe.scheduler.timesteps

    for i, timestep in enumerate(timesteps):
        print(f"Denoising step {i+1}/{len(timesteps)}")

        timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)

        # 🔧 Forward pass with the FramePack conditioning arguments.
        with torch.no_grad():
            noise_pred = pipe.dit(
                target_latents,
                timestep=timestep_tensor,
                cam_emb=camera_embedding,
                # FramePack arguments
                latent_indices=framepack_inputs['latent_indices'],
                clean_latents=framepack_inputs['clean_latents'],
                clean_latent_indices=framepack_inputs['clean_latent_indices'],
                clean_latents_2x=framepack_inputs['clean_latents_2x'],
                clean_latent_2x_indices=framepack_inputs['clean_latent_2x_indices'],
                clean_latents_4x=framepack_inputs['clean_latents_4x'],
                clean_latent_4x_indices=framepack_inputs['clean_latent_4x_indices'],
                **prompt_emb,
                **extra_input
            )

        # Scheduler update of the target latents.
        target_latents = pipe.scheduler.step(noise_pred, timestep, target_latents)

    print("Decoding video...")

    # 11. Decode: prepend the last condition frame so the clip starts from context.
    condition_for_decode = framepack_inputs['clean_latents'][:, :, -1:, :, :]
    final_video = torch.cat([condition_for_decode, target_latents], dim=2)
    decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))

    # 12. Save as mp4. Decoder output is assumed to be in [-1, 1] — rescaled
    #     to [0, 255] uint8 below; TODO confirm against decode_video's contract.
    print(f"Saving video to {output_path}")

    video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
    video_np = (video_np * 0.5 + 0.5).clip(0, 1)
    video_np = (video_np * 255).astype(np.uint8)

    with imageio.get_writer(output_path, fps=20) as writer:
        for frame in video_np:
            writer.append_data(frame)

    print(f"FramePack video generation completed! Saved to {output_path}")
|
| 481 |
+
|
| 482 |
+
def main():
    """CLI entry point for FramePack-style Sekai inference.

    Parses arguments, derives an output path (auto-named from the input .pth
    stem unless --output_path is given explicitly) and runs the pipeline.
    """
    parser = argparse.ArgumentParser(description="Sekai FramePack Video Generation Inference from PTH")
    parser.add_argument("--condition_pth", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking/00100100001_0004650_0004950/encoded_video.pth")
    parser.add_argument("--start_frame", type=int, default=0,
                        help="Starting frame index (compressed latent frames)")
    parser.add_argument("--condition_frames", type=int, default=8,
                        help="Number of condition frames (compressed latent frames)")
    parser.add_argument("--target_frames", type=int, default=8,
                        help="Number of target frames to generate (compressed latent frames)")
    parser.add_argument("--direction", type=str, default="left_turn",
                        choices=["forward", "backward", "left_turn", "right_turn"],
                        help="Direction of camera movement (if not using real poses)")
    parser.add_argument("--use_real_poses", action="store_true", default=False,
                        help="Use real camera poses from data")
    parser.add_argument("--dit_path", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/sekai/sekai_walking_framepack/step24000_framepack.ckpt",
                        help="Path to trained FramePack DiT checkpoint")
    # 🔧 fix: the old hard-coded default made the `is None` auto-naming branch
    # below unreachable dead code. Default to None so auto-naming actually runs.
    parser.add_argument("--output_path", type=str, default=None,
                        help="Output video path (auto-generated from the input pth if omitted)")
    parser.add_argument("--prompt", type=str,
                        default="A drone flying scene in a game world",
                        help="Text prompt for generation")
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device to run inference on")

    args = parser.parse_args()

    # Derive an output path encoding the inference settings when none was given.
    if args.output_path is None:
        pth_filename = os.path.basename(args.condition_pth)
        name_parts = os.path.splitext(pth_filename)
        output_dir = "sekai/infer_framepack_results"
        os.makedirs(output_dir, exist_ok=True)

        if args.use_real_poses:
            output_filename = f"{name_parts[0]}_framepack_real_{args.start_frame}_{args.condition_frames}_{args.target_frames}.mp4"
        else:
            output_filename = f"{name_parts[0]}_framepack_{args.direction}_{args.start_frame}_{args.condition_frames}_{args.target_frames}.mp4"

        output_path = os.path.join(output_dir, output_filename)
    else:
        output_path = args.output_path

    print(f"🔧 FramePack Inference Settings:")
    print(f"Input pth: {args.condition_pth}")
    print(f"Start frame: {args.start_frame} (compressed)")
    print(f"Condition frames: {args.condition_frames} (compressed, original: {args.condition_frames * 4})")
    print(f"Target frames: {args.target_frames} (compressed, original: {args.target_frames * 4})")
    print(f"Use real poses: {args.use_real_poses}")
    print(f"Direction: {args.direction}")
    print(f"Output video will be saved to: {output_path}")

    inference_sekai_framepack_from_pth(
        condition_pth_path=args.condition_pth,
        dit_path=args.dit_path,
        output_path=output_path,
        start_frame=args.start_frame,
        condition_frames=args.condition_frames,
        target_frames=args.target_frames,
        device=args.device,
        prompt=args.prompt,
        direction=args.direction,
        use_real_poses=args.use_real_poses
    )


if __name__ == "__main__":
    main()
|
scripts/infer_spatialvid.py
ADDED
|
@@ -0,0 +1,608 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import numpy as np
|
| 4 |
+
import os
|
| 5 |
+
import json
|
| 6 |
+
import imageio
|
| 7 |
+
import argparse
|
| 8 |
+
from PIL import Image
|
| 9 |
+
from diffsynth import WanVideoReCamMasterPipeline, ModelManager
|
| 10 |
+
from torchvision.transforms import v2
|
| 11 |
+
from einops import rearrange
|
| 12 |
+
from scipy.spatial.transform import Rotation as R
|
| 13 |
+
|
| 14 |
+
def compute_relative_pose_matrix(pose1, pose2):
    """
    Compute the relative pose between two consecutive frames as a 3x4 matrix [R_rel | t_rel].

    Args:
        pose1: camera pose of frame i, shape (7,): [tx, ty, tz, qx, qy, qz, qw]
        pose2: camera pose of frame i+1, same layout.

    Returns:
        relative_matrix: 3x4 numpy array; left 3x3 block is R_rel, last column is t_rel.
    """
    # Split each pose into translation and quaternion parts.
    t1, q1 = pose1[:3], pose1[3:]
    t2, q2 = pose2[:3], pose2[3:]

    rot1 = R.from_quat(q1)
    rot2 = R.from_quat(q2)

    # 1. Relative rotation: next rotation composed with the inverse of the previous.
    R_rel = (rot2 * rot1.inv()).as_matrix()

    # 2. Relative translation expressed in frame i's coordinates: R1^T @ (t2 - t1).
    # NOTE(review): rotation is composed in the world frame while the translation
    # is mapped into the previous camera frame — confirm this convention is intended.
    t_rel = rot1.as_matrix().T @ (t2 - t1)

    # 3. Stack into a single [R_rel | t_rel] matrix.
    return np.hstack([R_rel, t_rel.reshape(3, 1)])
|
| 45 |
+
|
| 46 |
+
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Load pre-encoded video latents from a .pth file and slice out a frame window.

    Returns the [C, num_frames, H, W] condition slice together with the full
    decoded dict, and raises ValueError when the file holds too few frames.
    """
    print(f"Loading encoded video from {pth_path}")

    encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu")
    full_latents = encoded_data['latents']  # [C, T, H, W]

    print(f"Full latents shape: {full_latents.shape}")
    print(f"Extracting frames {start_frame} to {start_frame + num_frames}")

    end_frame = start_frame + num_frames
    if end_frame > full_latents.shape[1]:
        raise ValueError(f"Not enough frames: requested {start_frame + num_frames}, available {full_latents.shape[1]}")

    condition_latents = full_latents[:, start_frame:end_frame, :, :]
    print(f"Extracted condition latents shape: {condition_latents.shape}")

    return condition_latents, encoded_data
|
| 63 |
+
|
| 64 |
+
def replace_dit_model_in_manager():
    """Patch diffsynth's model-loader registry before loading so that the
    'wan_video_dit' entry resolves to WanModelFuture."""
    from diffsynth.models.wan_video_dit_recam_future import WanModelFuture
    from diffsynth.configs.model_config import model_loader_configs

    def _swap_class(name, cls):
        # Keep the registered name; swap only the class for the DiT entry.
        if name == 'wan_video_dit':
            print(f"✅ 替换了模型类: {name} -> WanModelFuture")
            return WanModelFuture
        return cls

    # Rewrite every registry tuple that mentions the DiT model.
    for i, (keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource) in enumerate(model_loader_configs):
        if 'wan_video_dit' not in model_names:
            continue
        new_classes = [_swap_class(n, c) for n, c in zip(model_names, model_classes)]
        model_loader_configs[i] = (keys_hash, keys_hash_with_shape, list(model_names), new_classes, model_resource)
|
| 90 |
+
|
| 91 |
+
def add_framepack_components(dit_model):
    """Attach FramePack's multi-scale clean_x_embedder to the DiT model.

    No-op when the model already has one. The embedder holds three Conv3d
    patchify projections (1x / 2x / 4x scales, design follows
    hunyuan_video_packed.py) and is cast to the model parameters' dtype.
    """
    if not hasattr(dit_model, 'clean_x_embedder'):
        # Transformer width taken from the first block's query projection.
        inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]

        class CleanXEmbedder(nn.Module):
            def __init__(self, inner_dim):
                super().__init__()
                self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
                self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
                self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

            def forward(self, x, scale="1x"):
                # 🔧 fix: cast the input to the conv weight dtype (consistent with
                # the sekai framepack script's corrected variant) so bf16 models
                # fed fp32 latents don't raise a dtype-mismatch error.
                if scale == "1x":
                    return self.proj(x.to(self.proj.weight.dtype))
                elif scale == "2x":
                    return self.proj_2x(x.to(self.proj_2x.weight.dtype))
                elif scale == "4x":
                    return self.proj_4x(x.to(self.proj_4x.weight.dtype))
                else:
                    raise ValueError(f"Unsupported scale: {scale}")

        dit_model.clean_x_embedder = CleanXEmbedder(inner_dim)
        # Align the new embedder with the dtype of the existing parameters.
        model_dtype = next(dit_model.parameters()).dtype
        dit_model.clean_x_embedder = dit_model.clean_x_embedder.to(dtype=model_dtype)
        print("✅ 添加了FramePack的clean_x_embedder组件")
|
| 118 |
+
|
| 119 |
+
def _finalize_camera_embedding(relative_poses, max_needed_frames, start_frame, current_history_length):
    """Stack per-frame 3x4 relative poses, append the condition mask column.

    Returns a ``(max_needed_frames, 13)`` bfloat16 tensor: 12 flattened pose
    values plus one mask value (1.0 = condition/history frame, 0.0 = target).
    """
    pose_embedding = torch.stack(relative_poses, dim=0)
    # Row-major flatten of each 3x4 pose; equivalent to einops rearrange('b c d -> b (c d)').
    pose_embedding = pose_embedding.reshape(pose_embedding.shape[0], -1)

    mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
    # Frames [start_frame, start_frame + history) are marked as conditioning.
    condition_end = min(start_frame + current_history_length, max_needed_frames)
    mask[start_frame:condition_end] = 1.0

    return torch.cat([pose_embedding, mask], dim=1).to(torch.bfloat16)


def generate_spatialvid_camera_embeddings_sliding(cam_data, start_frame, current_history_length, new_frames, total_generated, use_real_poses=True):
    """Generate camera embeddings for SpatialVid (sliding-window variant).

    Uses real extrinsics from ``cam_data['extrinsic']`` when available
    (consecutive-frame relative poses, stride 1), otherwise a synthetic
    indoor-walking motion (slight yaw sway + constant forward motion).

    Args:
        cam_data: dict with an 'extrinsic' pose sequence, or None.
        start_frame: index where the conditioning window starts.
        current_history_length: number of history (condition) frames.
        new_frames: frames to be generated in this window.
        total_generated: unused; kept for call-site compatibility.
        use_real_poses: prefer real extrinsics when cam_data provides them.

    Returns:
        ``(N, 13)`` bfloat16 tensor (12 pose values + 1 mask value per frame).
    """
    # Camera frames FramePack consumes: start(1) + 4x(16) + 2x(2) + 1x(1) + new frames.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    # Sequence must cover the window, the FramePack layout, and a floor of 30.
    max_needed_frames = max(
        start_frame + current_history_length + new_frames,
        framepack_needed_frames,
        30
    )

    if use_real_poses and cam_data is not None and 'extrinsic' in cam_data:
        print("🔧 使用真实SpatialVid camera数据")
        cam_extrinsic = cam_data['extrinsic']

        print(f"🔧 计算SpatialVid camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        relative_poses = []
        for i in range(max_needed_frames):
            # SpatialVid pairs consecutive frames (stride 1, not 4).
            frame_idx = i
            next_frame_idx = frame_idx + 1

            if next_frame_idx < len(cam_extrinsic):
                cam_prev = cam_extrinsic[frame_idx]
                cam_next = cam_extrinsic[next_frame_idx]
                relative_cam = compute_relative_pose_matrix(cam_prev, cam_next)
                relative_poses.append(torch.as_tensor(relative_cam[:3, :]))
            else:
                # Past the end of the recorded trajectory: zero motion.
                print(f"⚠️ 帧{frame_idx}超出camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))

        camera_embedding = _finalize_camera_embedding(
            relative_poses, max_needed_frames, start_frame, current_history_length)
        print(f"🔧 SpatialVid真实camera embedding shape: {camera_embedding.shape}")
        return camera_embedding

    else:
        print("🔧 使用SpatialVid合成camera数据")
        print(f"🔧 生成SpatialVid合成camera帧数: {max_needed_frames}")

        relative_poses = []
        for i in range(max_needed_frames):
            # Indoor-walking prior: slight lateral yaw sway + constant forward motion.
            yaw_per_frame = 0.03 * np.sin(i * 0.1)
            forward_speed = 0.008  # per-frame forward displacement

            pose = np.eye(4, dtype=np.float32)

            # Rotation about the Y axis (sway).
            cos_yaw = np.cos(yaw_per_frame)
            sin_yaw = np.sin(yaw_per_frame)
            pose[0, 0] = cos_yaw
            pose[0, 2] = sin_yaw
            pose[2, 0] = -sin_yaw
            pose[2, 2] = cos_yaw

            # Translation: forward along local -Z plus a small vertical bob.
            pose[2, 3] = -forward_speed
            pose[1, 3] = 0.002 * np.sin(i * 0.15)

            relative_poses.append(torch.as_tensor(pose[:3, :]))

        camera_embedding = _finalize_camera_embedding(
            relative_poses, max_needed_frames, start_frame, current_history_length)
        print(f"🔧 SpatialVid合成camera embedding shape: {camera_embedding.shape}")
        return camera_embedding
def prepare_framepack_sliding_window_with_camera(history_latents, target_frames_to_generate, camera_embedding_full, start_frame, max_history_frames=49):
    """Build FramePack sliding-window inputs (SpatialVid variant).

    Splits a fixed index layout ``[start | 16x 4x-history | 2x 2x-history |
    1x-history | new frames]``, selects the matching slice of the camera
    sequence (zero-padded if too short), rewrites its mask column from the
    actual history length, and assembles the clean-latent tensors.

    Args:
        history_latents: ``(C, T, H, W)`` latents generated so far.
        target_frames_to_generate: latent frames to denoise this window.
        camera_embedding_full: ``(N, 13)`` full camera sequence.
        start_frame: kept for interface compatibility.
        max_history_frames: kept for interface compatibility.

    Returns:
        dict of latents, index tensors, the per-window camera embedding,
        and the current/next history lengths.
    """
    C, T, H, W = history_latents.shape

    # Fixed FramePack index layout; its total length fixes the camera span too.
    section_sizes = [1, 16, 2, 1, target_frames_to_generate]
    total_len = sum(section_sizes)
    idx_start, idx_4x, idx_2x, idx_1x, latent_indices = \
        torch.arange(total_len).split(section_sizes, dim=0)
    clean_latent_indices = torch.cat([idx_start, idx_1x], dim=0)

    # Right-pad the camera sequence with zeros if it is shorter than the layout.
    missing = total_len - camera_embedding_full.shape[0]
    if missing > 0:
        pad = camera_embedding_full.new_zeros(missing, camera_embedding_full.shape[1])
        camera_embedding_full = torch.cat([camera_embedding_full, pad], dim=0)

    # Take this window's slice and rebuild the condition mask from scratch.
    combined_camera = camera_embedding_full[:total_len, :].clone()
    combined_camera[:, -1] = 0.0  # default: everything is a target frame

    # The last (up to) 19 layout slots hold real history -> mark as condition.
    n_cond = min(T, 19) if T > 0 else 0
    if n_cond:
        combined_camera[19 - n_cond:19, -1] = 1.0

    print(f"🔧 SpatialVid Camera mask更新:")
    print(f" - 历史帧数: {T}")
    print(f" - 有效condition帧数: {n_cond}")

    # Right-align the available history inside the 19 clean-latent slots.
    clean_latents_combined = history_latents.new_zeros(C, 19, H, W)
    if n_cond:
        clean_latents_combined[:, 19 - n_cond:, :, :] = history_latents[:, -n_cond:, :, :]

    clean_latents_4x = clean_latents_combined[:, 0:16, :, :]
    clean_latents_2x = clean_latents_combined[:, 16:18, :, :]
    clean_latents_1x = clean_latents_combined[:, 18:19, :, :]

    # Anchor frame: the very first history frame, or zeros when no history yet.
    if T > 0:
        start_latent = history_latents[:, 0:1, :, :]
    else:
        start_latent = history_latents.new_zeros(C, 1, H, W)

    clean_latents = torch.cat([start_latent, clean_latents_1x], dim=1)

    return {
        'latent_indices': latent_indices,
        'clean_latents': clean_latents,
        'clean_latents_2x': clean_latents_2x,
        'clean_latents_4x': clean_latents_4x,
        'clean_latent_indices': clean_latent_indices,
        'clean_latent_2x_indices': idx_2x,
        'clean_latent_4x_indices': idx_4x,
        'camera_embedding': combined_camera,
        'current_length': T,
        'next_length': T + target_frames_to_generate
    }
def inference_spatialvid_framepack_sliding_window(
    condition_pth_path,
    dit_path,
    output_path="spatialvid_results/output_spatialvid_framepack_sliding.mp4",
    start_frame=0,
    initial_condition_frames=8,
    frames_per_generation=4,
    total_frames_to_generate=32,
    max_history_frames=49,
    device="cuda",
    prompt="A man walking through indoor spaces with a first-person view",
    use_real_poses=True,
    # CFG parameters
    use_camera_cfg=True,
    camera_guidance_scale=2.0,
    text_guidance_scale=1.0
):
    """SpatialVid FramePack sliding-window video generation.

    Loads the Wan2.1-T2V pipeline, attaches per-block camera encoders and the
    FramePack clean-latent embedder, restores a trained DiT checkpoint, then
    autoregressively denoises ``frames_per_generation`` latent frames per
    sliding-window step until ``total_frames_to_generate`` frames exist, and
    finally decodes the full latent sequence to an mp4.

    Args:
        condition_pth_path: pre-encoded video .pth with latents and camera data.
        dit_path: trained DiT state-dict checkpoint (loaded with strict=True).
        output_path: destination mp4 (parent directory is created).
        start_frame: first frame index of the conditioning clip.
        initial_condition_frames: latent frames used to seed the history.
        frames_per_generation: latent frames denoised per window.
        total_frames_to_generate: total latent frames to produce.
        max_history_frames: history cap; the first frame is always retained.
        device: torch device for the pipeline.
        prompt: text prompt; the empty prompt serves as the text-CFG negative.
        use_real_poses: prefer real extrinsics from the .pth when present.
        use_camera_cfg: enable camera classifier-free guidance.
        camera_guidance_scale: camera CFG scale (active when > 1.0).
        text_guidance_scale: text CFG scale (> 1.0 enables a negative pass).
    """
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    print(f"🔧 SpatialVid FramePack滑动窗口生成开始...")
    print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}")
    print(f"Text guidance scale: {text_guidance_scale}")

    # 1. Model initialization (swap the registered DiT class for WanModelFuture).
    replace_dit_model_in_manager()

    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # 2. Attach camera encoders to every DiT block. Encoders start at zero and
    # projectors start as identity so the untrained additions are a no-op;
    # the real weights come from the checkpoint loaded in step 4.
    dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        block.cam_encoder = nn.Linear(13, dim)
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    # 3. Attach the FramePack clean_x_embedder (must exist before load_state_dict).
    add_framepack_components(pipe.dit)

    # 4. Restore the trained weights.
    dit_state_dict = torch.load(dit_path, map_location="cpu")
    pipe.dit.load_state_dict(dit_state_dict, strict=True)
    pipe = pipe.to(device)
    model_dtype = next(pipe.dit.parameters()).dtype

    if hasattr(pipe.dit, 'clean_x_embedder'):
        pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)

    pipe.scheduler.set_timesteps(50)

    # 5. Load the initial conditioning latents.
    print("Loading initial condition frames...")
    initial_latents, encoded_data = load_encoded_video_from_pth(
        condition_pth_path,
        start_frame=start_frame,
        num_frames=initial_condition_frames
    )

    # Center-crop latents spatially to the model's expected 60x104 grid.
    target_height, target_width = 60, 104
    C, T, H, W = initial_latents.shape

    if H > target_height or W > target_width:
        h_start = (H - target_height) // 2
        w_start = (W - target_width) // 2
        initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width]
        H, W = target_height, target_width

    history_latents = initial_latents.to(device, dtype=model_dtype)

    print(f"初始history_latents shape: {history_latents.shape}")

    # 6. Encode the prompt; a second empty-prompt encoding is only needed for text CFG.
    if text_guidance_scale > 1.0:
        prompt_emb_pos = pipe.encode_prompt(prompt)
        prompt_emb_neg = pipe.encode_prompt("")
        print(f"使用Text CFG,guidance scale: {text_guidance_scale}")
    else:
        prompt_emb_pos = pipe.encode_prompt(prompt)
        prompt_emb_neg = None
        print("不使用Text CFG")

    # 7. Pre-generate the full camera embedding sequence once. The mask set
    # here is rewritten per window by prepare_framepack_sliding_window_with_camera.
    camera_embedding_full = generate_spatialvid_camera_embeddings_sliding(
        encoded_data.get('cam_emb', None),
        0,
        max_history_frames,
        0,
        0,
        use_real_poses=use_real_poses
    ).to(device, dtype=model_dtype)

    print(f"完整camera序列shape: {camera_embedding_full.shape}")

    # 8. All-zero camera embedding acts as the unconditional branch for camera CFG.
    if use_camera_cfg:
        camera_embedding_uncond = torch.zeros_like(camera_embedding_full)
        print(f"创建无条件camera embedding用于CFG")

    # 9. Sliding-window autoregressive generation loop.
    total_generated = 0
    all_generated_frames = []

    while total_generated < total_frames_to_generate:
        current_generation = min(frames_per_generation, total_frames_to_generate - total_generated)
        print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}")
        print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}")

        # Assemble FramePack inputs for this window from the current history.
        framepack_data = prepare_framepack_sliding_window_with_camera(
            history_latents,
            current_generation,
            camera_embedding_full,
            start_frame,
            max_history_frames
        )

        # Add the batch dimension expected by the DiT.
        clean_latents = framepack_data['clean_latents'].unsqueeze(0)
        clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0)
        clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0)
        camera_embedding = framepack_data['camera_embedding'].unsqueeze(0)

        # Match the unconditional camera sequence to this window's length
        # (camera_embedding.shape[1] is the sequence length after unsqueeze).
        if use_camera_cfg:
            camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0)

        # Index tensors are kept on CPU.
        latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu()
        clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu()
        clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu()
        clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu()

        # Fresh Gaussian noise for the frames to be generated in this window.
        new_latents = torch.randn(
            1, C, current_generation, H, W,
            device=device, dtype=model_dtype
        )

        extra_input = pipe.prepare_extra_input(new_latents)

        print(f"Camera embedding shape: {camera_embedding.shape}")
        print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}")

        # Denoising loop over the scheduler's 50 timesteps, with optional CFG.
        timesteps = pipe.scheduler.timesteps

        for i, timestep in enumerate(timesteps):
            if i % 10 == 0:
                print(f" 去噪步骤 {i}/{len(timesteps)}")

            timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)

            with torch.no_grad():
                # Conditional prediction (with camera conditioning).
                noise_pred_pos = pipe.dit(
                    new_latents,
                    timestep=timestep_tensor,
                    cam_emb=camera_embedding,
                    latent_indices=latent_indices,
                    clean_latents=clean_latents,
                    clean_latent_indices=clean_latent_indices,
                    clean_latents_2x=clean_latents_2x,
                    clean_latent_2x_indices=clean_latent_2x_indices,
                    clean_latents_4x=clean_latents_4x,
                    clean_latent_4x_indices=clean_latent_4x_indices,
                    **prompt_emb_pos,
                    **extra_input
                )

                # Camera CFG: second pass with the zeroed camera embedding.
                if use_camera_cfg and camera_guidance_scale > 1.0:
                    noise_pred_uncond = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding_uncond_batch,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

                    # Standard CFG combination for the camera condition.
                    noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_pos - noise_pred_uncond)
                else:
                    noise_pred = noise_pred_pos

                # Text CFG: third pass with the empty-prompt embedding.
                if prompt_emb_neg is not None and text_guidance_scale > 1.0:
                    noise_pred_neg = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_neg,
                        **extra_input
                    )

                    noise_pred = noise_pred_neg + text_guidance_scale * (noise_pred - noise_pred_neg)

            # Advance the latents one denoising step.
            new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents)

        # Append this window's result to the history.
        new_latents_squeezed = new_latents.squeeze(0)
        history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1)

        # Maintain the sliding window: always keep the first frame as an anchor,
        # plus the most recent (max_history_frames - 1) frames.
        if history_latents.shape[1] > max_history_frames:
            first_frame = history_latents[:, 0:1, :, :]
            recent_frames = history_latents[:, -(max_history_frames-1):, :, :]
            history_latents = torch.cat([first_frame, recent_frames], dim=1)
            print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧")

        print(f"更新后history_latents shape: {history_latents.shape}")

        all_generated_frames.append(new_latents_squeezed)
        total_generated += current_generation

        print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧")

    # 10. Decode the full latent sequence (conditioning + generated) and save.
    print("\n🔧 解码生成的视频...")

    all_generated = torch.cat(all_generated_frames, dim=1)
    final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0)

    print(f"最终视频shape: {final_video.shape}")

    decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))

    print(f"Saving video to {output_path}")

    # Decoder output is in [-1, 1]; map to uint8 [0, 255] frames (T, H, W, C).
    video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
    video_np = (video_np * 0.5 + 0.5).clip(0, 1)
    video_np = (video_np * 255).astype(np.uint8)

    with imageio.get_writer(output_path, fps=20) as writer:
        for frame in video_np:
            writer.append_data(frame)

    print(f"🔧 SpatialVid FramePack滑动窗口生成完成! 保存到: {output_path}")
    # NOTE(review): the x4 below assumes a temporal VAE compression factor of 4 — confirm.
    print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧")
def main():
    """CLI entry point for SpatialVid FramePack sliding-window generation.

    Parses command-line arguments and forwards them to
    ``inference_spatialvid_framepack_sliding_window``.
    """
    parser = argparse.ArgumentParser(description="SpatialVid FramePack滑动窗口视频生成")

    # Basic input/output and generation-length parameters.
    parser.add_argument("--condition_pth", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/spatialvid/a9a6d37f-0a6c-548a-a494-7d902469f3f2_0000000_0000300/encoded_video.pth",
                        help="输入编码视频路径")
    parser.add_argument("--start_frame", type=int, default=0)
    parser.add_argument("--initial_condition_frames", type=int, default=16)
    parser.add_argument("--frames_per_generation", type=int, default=8)
    parser.add_argument("--total_frames_to_generate", type=int, default=16)
    parser.add_argument("--max_history_frames", type=int, default=100)
    # BUGFIX: `action="store_true"` combined with `default=True` made these flags
    # impossible to disable from the CLI. The original flags remain accepted
    # (backward compatible); the new `--no_...` switches actually turn them off.
    parser.add_argument("--use_real_poses", action="store_true", default=True)
    parser.add_argument("--no_use_real_poses", dest="use_real_poses", action="store_false",
                        help="Disable real camera poses (use synthetic motion instead)")
    parser.add_argument("--dit_path", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/spatialvid/spatialvid_framepack_random/step50.ckpt",
                        help="训练好的模型权重路径")
    parser.add_argument("--output_path", type=str,
                        default='spatialvid_results/output_spatialvid_framepack_sliding.mp4')
    parser.add_argument("--prompt", type=str,
                        default="A man walking through indoor spaces with a first-person view")
    parser.add_argument("--device", type=str, default="cuda")

    # Classifier-free-guidance parameters.
    parser.add_argument("--use_camera_cfg", action="store_true", default=True,
                        help="使用Camera CFG")
    parser.add_argument("--no_use_camera_cfg", dest="use_camera_cfg", action="store_false",
                        help="Disable camera classifier-free guidance")
    parser.add_argument("--camera_guidance_scale", type=float, default=2.0,
                        help="Camera guidance scale for CFG")
    parser.add_argument("--text_guidance_scale", type=float, default=1.0,
                        help="Text guidance scale for CFG")

    args = parser.parse_args()

    print(f"🔧 SpatialVid FramePack CFG生成设置:")
    print(f"Camera CFG: {args.use_camera_cfg}")
    if args.use_camera_cfg:
        print(f"Camera guidance scale: {args.camera_guidance_scale}")
    print(f"Text guidance scale: {args.text_guidance_scale}")
    print(f"SpatialVid特有特性: camera间隔为1帧")

    inference_spatialvid_framepack_sliding_window(
        condition_pth_path=args.condition_pth,
        dit_path=args.dit_path,
        output_path=args.output_path,
        start_frame=args.start_frame,
        initial_condition_frames=args.initial_condition_frames,
        frames_per_generation=args.frames_per_generation,
        total_frames_to_generate=args.total_frames_to_generate,
        max_history_frames=args.max_history_frames,
        device=args.device,
        prompt=args.prompt,
        use_real_poses=args.use_real_poses,
        # CFG parameters
        use_camera_cfg=args.use_camera_cfg,
        camera_guidance_scale=args.camera_guidance_scale,
        text_guidance_scale=args.text_guidance_scale
    )
# Script entry point: parse CLI arguments and run SpatialVid FramePack
# sliding-window inference.
if __name__ == "__main__":
    main()