upload Astra checkpoint

#1
by wjque - opened
This view is limited to 50 files because it contains too many changes. See the raw diff here.
Files changed (50) hide show
  1. .DS_Store +0 -0
  2. .gitattributes +0 -1
  3. .gitignore +0 -3
  4. README.md +45 -72
  5. Try ReCamMaster with Your Own Videos Here.txt +1 -0
  6. diffsynth/pipelines/__init__.py +1 -1
  7. diffsynth/pipelines/wan_video_recammaster.py +2 -2
  8. examples/.DS_Store +0 -0
  9. models/Astra/checkpoints/diffusion_pytorch_model.ckpt → examples/output_videos/output_moe_framepack_sliding.mp4 +2 -2
  10. assets/images/pipeline.png → logo-text-2.png +2 -2
  11. models/.DS_Store +0 -0
  12. models/Astra/.DS_Store +0 -0
  13. models/Astra/checkpoints/{Put Astra ckpt file here.txt → Put ReCamMaster ckpt file here.txt} +0 -0
  14. pip-list.txt +197 -0
  15. scripts/add_text_emb.py +161 -0
  16. scripts/add_text_emb_rl.py +161 -0
  17. scripts/add_text_emb_spatialvid.py +173 -0
  18. scripts/analyze_openx.py +243 -0
  19. scripts/analyze_pose.py +188 -0
  20. scripts/batch_drone.py +44 -0
  21. scripts/batch_infer.py +186 -0
  22. scripts/batch_nus.py +42 -0
  23. scripts/batch_rt.py +41 -0
  24. scripts/batch_spa.py +43 -0
  25. scripts/batch_walk.py +42 -0
  26. scripts/check.py +263 -0
  27. scripts/decode_openx.py +428 -0
  28. scripts/download_recam.py +7 -0
  29. scripts/encode_dynamic_videos.py +141 -0
  30. scripts/encode_openx.py +466 -0
  31. scripts/encode_rlbench_video.py +170 -0
  32. scripts/encode_sekai_video.py +162 -0
  33. scripts/encode_sekai_walking.py +249 -0
  34. scripts/encode_spatialvid.py +409 -0
  35. scripts/encode_spatialvid_first_frame.py +285 -0
  36. scripts/hud_logo.py +1 -1
  37. scripts/infer_demo.py +318 -494
  38. scripts/infer_moe.py +1023 -0
  39. scripts/infer_moe_spatialvid.py +1008 -0
  40. scripts/infer_moe_test.py +976 -0
  41. scripts/infer_nus.py +500 -0
  42. scripts/infer_openx.py +614 -0
  43. scripts/infer_origin.py +1108 -0
  44. scripts/infer_recam.py +272 -0
  45. scripts/infer_rlbench.py +447 -0
  46. scripts/infer_sekai.py +497 -0
  47. scripts/infer_sekai_framepack.py +675 -0
  48. scripts/infer_sekai_framepack_4.py +682 -0
  49. scripts/infer_sekai_framepack_test.py +551 -0
  50. scripts/infer_spatialvid.py +608 -0
.DS_Store DELETED
Binary file (6.15 kB)
 
.gitattributes CHANGED
@@ -40,4 +40,3 @@ diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt filter=lfs diff=lfs merge
40
  examples/output_videos/output_moe_framepack_sliding.mp4 filter=lfs diff=lfs merge=lfs -text
41
  logo-text-2.png filter=lfs diff=lfs merge=lfs -text
42
  assets/images/logo-text-2.png filter=lfs diff=lfs merge=lfs -text
43
- pipeline.png filter=lfs diff=lfs merge=lfs -text
 
40
  examples/output_videos/output_moe_framepack_sliding.mp4 filter=lfs diff=lfs merge=lfs -text
41
  logo-text-2.png filter=lfs diff=lfs merge=lfs -text
42
  assets/images/logo-text-2.png filter=lfs diff=lfs merge=lfs -text
 
.gitignore DELETED
@@ -1,3 +0,0 @@
1
- # Ignore all checkpoint files
2
- *.ckpt
3
- *.ckpt.*
 
 
 
 
README.md CHANGED
@@ -20,7 +20,7 @@ model-index:
20
 
21
  <h3 style="margin-top: 0;">
22
  📄
23
- [<a href="https://arxiv.org/abs/2512.08931" target="_blank">arXiv</a>]
24
  &nbsp;&nbsp;
25
  🏠
26
  [<a href="https://eternalevan.github.io/Astra-project/" target="_blank">Project Page</a>]
@@ -48,23 +48,36 @@ model-index:
48
 
49
  <div align="center">
50
 
51
- **[Yixuan Zhu<sup>1</sup>](https://eternalevan.github.io/), [Jiaqi Feng<sup>1</sup>](https://github.com/Aurora-edu/), [Wenzhao Zheng<sup>1 †</sup>](https://wzzheng.net), [Yuan Gao<sup>2</sup>](https://openreview.net/profile?id=~Yuan_Gao32), [Xin Tao<sup>2</sup>](https://www.xtao.website), [Pengfei Wan<sup>2</sup>](https://scholar.google.com/citations?user=P6MraaYAAAAJ&hl=en), [Jie Zhou <sup>1</sup>](https://scholar.google.com/citations?user=6a79aPwAAAAJ&hl=en&authuser=1), [Jiwen Lu<sup>1</sup>](https://ivg.au.tsinghua.edu.cn/Jiwen_Lu/)**
52
  <!-- <br> -->
53
- ( Project leader)
 
54
 
55
  <sup>1</sup>Tsinghua University, <sup>2</sup>Kuaishou Technology.
56
  </div>
57
 
 
 
 
58
 
59
- ## 📖 Introduction
60
 
61
- **TL;DR:** Astra is an **interactive world model** that delivers realistic long-horizon video rollouts under a wide range of scenarios and action inputs.
 
 
 
62
 
63
- **Astra** is an **interactive**, action-driven world model that predicts long-horizon future videos across diverse real-world scenarios. Built on an autoregressive diffusion transformer with temporal causal attention, Astra supports **streaming prediction** while preserving strong temporal coherence. Astra introduces **noise-augmented history memory** to stabilize long rollouts, an **action-aware adapter** for precise control signals, and a **mixture of action experts** to route heterogeneous action modalities. Through these key innovations, Astra delivers consistent, controllable, and high-fidelity video futures for applications such as autonomous driving, robot manipulation, and camera motion.
64
 
65
- <div align="center">
66
- <img src="./assets/images/pipeline.png" alt="Astra Pipeline" width="90%">
67
- </div>
 
 
 
 
 
 
 
68
 
69
  ## Gallery
70
 
@@ -119,27 +132,7 @@ model-index:
119
 
120
  If you would like to use ReCamMaster as a baseline and need qualitative or quantitative comparisons, please feel free to drop an email to [jianhongbai@zju.edu.cn](mailto:jianhongbai@zju.edu.cn). We can assist you with batch inference of our model. -->
121
 
122
- ## 🔥 Updates
123
- - __[2025.11.17]__: Release the [project page](https://eternalevan.github.io/Astra-project/).
124
- - __[2025.12.09]__: Release the inference code, model checkpoint.
125
-
126
- ## 🎯 TODO List
127
-
128
- - [ ] **Release full inference pipelines** for additional scenarios:
129
- - [ ] 🚗 Autonomous driving
130
- - [ ] 🤖 Robotic manipulation
131
- - [ ] 🛸 Drone navigation / exploration
132
-
133
-
134
- - [ ] **Open-source training scripts**:
135
- - [ ] ⬆️ Action-conditioned autoregressive denoising training
136
- - [ ] 🔄 Multi-scenario joint training pipeline
137
-
138
- - [ ] **Release dataset preprocessing tools**
139
-
140
- - [ ] **Provide unified evaluation toolkit**
141
-
142
- ## ⚙️ Run Astra (Inference)
143
  Astra is built upon [Wan2.1-1.3B](https://github.com/Wan-Video/Wan2.1), a diffusion-based video generation model. We provide inference scripts to help you quickly generate videos from images and action inputs. Follow the steps below:
144
 
145
  ### Inference
@@ -167,57 +160,37 @@ python download_wan2.1.py
167
  ```
168
  2. Download the pre-trained Astra checkpoint
169
 
170
- Please download from [huggingface](https://huggingface.co/EvanEternal/Astra/blob/main/models/Astra/checkpoints/diffusion_pytorch_model.ckpt) and place it in ```models/Astra/checkpoints```.
171
 
172
- Step 3: Test the example image
173
  ```shell
174
- python infer_demo.py \
175
- --dit_path ../models/Astra/checkpoints/diffusion_pytorch_model.ckpt \
176
- --wan_model_path ../models/Wan-AI/Wan2.1-T2V-1.3B \
177
- --condition_image ../examples/condition_images/garden_1.png \
178
- --cam_type 4 \
179
- --prompt "A sunlit European street lined with historic buildings and vibrant greenery creates a warm, charming, and inviting atmosphere. The scene shows a picturesque open square paved with red bricks, surrounded by classic narrow townhouses featuring tall windows, gabled roofs, and dark-painted facades. On the right side, a lush arrangement of potted plants and blooming flowers adds rich color and texture to the foreground. A vintage-style streetlamp stands prominently near the center-right, contributing to the timeless character of the street. Mature trees frame the background, their leaves glowing in the warm afternoon sunlight. Bicycles are visible along the edges of the buildings, reinforcing the urban yet leisurely feel. The sky is bright blue with scattered clouds, and soft sun flares enter the frame from the left, enhancing the scene’s inviting, peaceful mood." \
180
- --output_path ../examples/output_videos/output_moe_framepack_sliding.mp4 \
181
  ```
182
 
183
- Step 4: Test your own images
184
 
185
- To test with your own custom images, you need to prepare the target images and their corresponding text prompts. **We recommend that the size of the input images is close to 832×480 (width × height)**, which is consistent with the resolution of the generated video and can help achieve better video generation effects. For prompts generation, you can refer to the [Prompt Extension section](https://github.com/Wan-Video/Wan2.1?tab=readme-ov-file#2-using-prompt-extension) in Wan2.1 for guidance on crafting the captions.
186
 
187
  ```shell
188
- python infer_demo.py \
189
- --dit_path path/to/your/dit_ckpt \
190
- --wan_model_path path/to/your/Wan2.1-T2V-1.3B \
191
- --condition_image path/to/your/image \
192
- --cam_type your_cam_type \
193
- --prompt your_prompt \
194
- --output_path path/to/your/output_video
195
  ```
196
 
197
  We provide several preset camera types, as shown in the table below. Additionally, you can generate new trajectories for testing.
198
 
199
- | cam_type | Trajectory |
200
- |:-----------:|-----------------------------|
201
- | 1 | Move Forward (Straight) |
202
- | 2 | Rotate Left In Place |
203
- | 3 | Rotate Right In Place |
204
- | 4 | Move Forward + Rotate Left |
205
- | 5 | Move Forward + Rotate Right |
206
- | 6 | S-shaped Trajectory |
207
- | 7 | Rotate Left Rotate Right |
208
-
209
-
210
- ## Future Work 🚀
211
-
212
- Looking ahead, we plan to further enhance Astra in several directions:
213
-
214
- - **Training with Wan-2.2:** Upgrade our model using the latest Wan-2.2 framework to release a more powerful version with improved generation quality.
215
- - **3D Spatial Consistency:** Explore techniques to better preserve 3D consistency across frames for more coherent and realistic video generation.
216
- - **Long-Term Memory:** Incorporate mechanisms for long-term memory, enabling the model to handle extended temporal dependencies and complex action sequences.
217
-
218
- These directions aim to push Astra towards more robust and interactive video world modeling.
219
 
220
- <!-- ### Training
221
 
222
  Step 1: Set up the environment
223
 
@@ -249,7 +222,7 @@ Step 4: Test the model
249
 
250
  ```shell
251
  python inference_recammaster.py --cam_type 1 --ckpt_path path/to/the/checkpoint
252
- ``` -->
253
 
254
  <!-- ## 📷 Dataset: MultiCamVideo Dataset
255
  ### 1. Dataset Introduction
@@ -408,10 +381,10 @@ Feel free to explore these outstanding related works, including but not limited
408
 
409
  Please leave us a star 🌟 and cite our paper if you find our work helpful.
410
  ```
411
- @article{zhu2025astra,
412
  title={Astra: General Interactive World Model with Autoregressive Denoising},
413
  author={Zhu, Yixuan and Feng, Jiaqi and Zheng, Wenzhao and Gao, Yuan and Tao, Xin and Wan, Pengfei and Zhou, Jie and Lu, Jiwen},
414
- journal={arXiv preprint arXiv:2512.08931},
415
  year={2025}
416
  }
417
  ```
 
20
 
21
  <h3 style="margin-top: 0;">
22
  📄
23
+ [<a href="https://arxiv.org/abs/2503.11647" target="_blank">arXiv</a>]
24
  &nbsp;&nbsp;
25
  🏠
26
  [<a href="https://eternalevan.github.io/Astra-project/" target="_blank">Project Page</a>]
 
48
 
49
  <div align="center">
50
 
51
+ **[Yixuan Zhu<sup>1</sup>](https://jianhongbai.github.io/), [Jiaqi Feng<sup>1</sup>](https://menghanxia.github.io/), [Wenzhao Zheng<sup>1 †</sup>](https://fuxiao0719.github.io/), [Yuan Gao<sup>2</sup>](https://xinntao.github.io/), [Xin Tao<sup>2</sup>](https://scholar.google.com/citations?user=dCik-2YAAAAJ&hl=en), [Pengfei Wan<sup>2</sup>](https://openreview.net/profile?id=~Jinwen_Cao1), [Jie Zhou <sup>1</sup>](https://person.zju.edu.cn/en/lzz), [Jiwen Lu<sup>1</sup>](https://person.zju.edu.cn/en/huhaoji)**
52
  <!-- <br> -->
53
+ (*Work done during an internship at Kuaishou Technology,
54
+ † Project leader)
55
 
56
  <sup>1</sup>Tsinghua University, <sup>2</sup>Kuaishou Technology.
57
  </div>
58
 
59
+ ## 🔥 Updates
60
+ - __[2025.11.17]__: Release the [project page](https://eternalevan.github.io/Astra-project/).
61
+ - __[2025.12.09]__: Release the training and inference code, model checkpoint.
62
 
63
+ ## 🎯 TODO List
64
 
65
+ - [ ] **Release full inference pipelines** for additional scenarios:
66
+ - [ ] 🚗 Autonomous driving
67
+ - [ ] 🤖 Robotic manipulation
68
+ - [ ] 🛸 Drone navigation / exploration
69
 
 
70
 
71
+ - [ ] **Open-source training scripts**:
72
+ - [ ] ⬆️ Action-conditioned autoregressive denoising training
73
+ - [ ] 🔄 Multi-scenario joint training pipeline
74
+
75
+ - [ ] **Release dataset preprocessing tools**
76
+
77
+ - [ ] **Provide unified evaluation toolkit**
78
+ ## 📖 Introduction
79
+
80
+ **TL;DR:** Astra is an **interactive world model** that delivers realistic long-horizon video rollouts under a wide range of scenarios and action inputs.
81
 
82
  ## Gallery
83
 
 
132
 
133
  If you would like to use ReCamMaster as a baseline and need qualitative or quantitative comparisons, please feel free to drop an email to [jianhongbai@zju.edu.cn](mailto:jianhongbai@zju.edu.cn). We can assist you with batch inference of our model. -->
134
 
135
+ ## ⚙️ Code: Astra + Wan2.1 (Inference & Training)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  Astra is built upon [Wan2.1-1.3B](https://github.com/Wan-Video/Wan2.1), a diffusion-based video generation model. We provide inference scripts to help you quickly generate videos from images and action inputs. Follow the steps below:
137
 
138
  ### Inference
 
160
  ```
161
  2. Download the pre-trained Astra checkpoint
162
 
163
+ Please download from [huggingface](https://huggingface.co/wjque/lyra/blob/main/diffusion_pytorch_model.ckpt) and place it in ```models/Astra/checkpoints```.
164
 
165
+ Step 3: Test the example videos
166
  ```shell
167
+ python inference_astra.py --cam_type 1
 
 
 
 
 
 
168
  ```
169
 
170
+ Step 4: Test your own videos
171
 
172
+ If you want to test your own videos, you need to prepare your test data following the structure of the ```example_test_data``` folder. This includes N mp4 videos, each with at least 81 frames, and a ```metadata.csv``` file that stores their paths and corresponding captions. You can refer to the [Prompt Extension section](https://github.com/Wan-Video/Wan2.1?tab=readme-ov-file#2-using-prompt-extension) in Wan2.1 for guidance on preparing video captions.
173
 
174
  ```shell
175
+ python inference_astra.py --cam_type 1 --dataset_path path/to/your/data
 
 
 
 
 
 
176
  ```
177
 
178
  We provide several preset camera types, as shown in the table below. Additionally, you can generate new trajectories for testing.
179
 
180
+ | cam_type | Trajectory |
181
+ |-------------------|-----------------------------|
182
+ | 1 | Pan Right |
183
+ | 2 | Pan Left |
184
+ | 3 | Tilt Up |
185
+ | 4 | Tilt Down |
186
+ | 5 | Zoom In |
187
+ | 6 | Zoom Out |
188
+ | 7 | Translate Up (with rotation) |
189
+ | 8 | Translate Down (with rotation) |
190
+ | 9 | Arc Left (with rotation) |
191
+ | 10 | Arc Right (with rotation) |
 
 
 
 
 
 
 
 
192
 
193
+ ### Training
194
 
195
  Step 1: Set up the environment
196
 
 
222
 
223
  ```shell
224
  python inference_recammaster.py --cam_type 1 --ckpt_path path/to/the/checkpoint
225
+ ```
226
 
227
  <!-- ## 📷 Dataset: MultiCamVideo Dataset
228
  ### 1. Dataset Introduction
 
381
 
382
  Please leave us a star 🌟 and cite our paper if you find our work helpful.
383
  ```
384
+ @inproceedings{zhu2025astra,
385
  title={Astra: General Interactive World Model with Autoregressive Denoising},
386
  author={Zhu, Yixuan and Feng, Jiaqi and Zheng, Wenzhao and Gao, Yuan and Tao, Xin and Wan, Pengfei and Zhou, Jie and Lu, Jiwen},
387
+ booktitle={arxiv},
388
  year={2025}
389
  }
390
  ```
Try ReCamMaster with Your Own Videos Here.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ https://docs.google.com/forms/d/e/1FAIpQLSezOzGPbm8JMXQDq6EINiDf6iXn7rV4ozj6KcbQCSAzE8Vsnw/viewform?usp=dialog
diffsynth/pipelines/__init__.py CHANGED
@@ -12,5 +12,5 @@ from .pipeline_runner import SDVideoPipelineRunner
12
  from .hunyuan_video import HunyuanVideoPipeline
13
  from .step_video import StepVideoPipeline
14
  from .wan_video import WanVideoPipeline
15
- from .wan_video_recammaster import WanVideoAstraPipeline
16
  KolorsImagePipeline = SDXLImagePipeline
 
12
  from .hunyuan_video import HunyuanVideoPipeline
13
  from .step_video import StepVideoPipeline
14
  from .wan_video import WanVideoPipeline
15
+ from .wan_video_recammaster import WanVideoReCamMasterPipeline
16
  KolorsImagePipeline = SDXLImagePipeline
diffsynth/pipelines/wan_video_recammaster.py CHANGED
@@ -21,7 +21,7 @@ from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
21
 
22
 
23
 
24
- class WanVideoAstraPipeline(BasePipeline):
25
 
26
  def __init__(self, device="cuda", torch_dtype=torch.float16, tokenizer_path=None,condition_frames=None,target_frames=None):
27
  super().__init__(device=device, torch_dtype=torch_dtype)
@@ -141,7 +141,7 @@ class WanVideoAstraPipeline(BasePipeline):
141
  def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None):
142
  if device is None: device = model_manager.device
143
  if torch_dtype is None: torch_dtype = model_manager.torch_dtype
144
- pipe = WanVideoAstraPipeline(device=device, torch_dtype=torch_dtype)
145
  pipe.fetch_models(model_manager)
146
  return pipe
147
 
 
21
 
22
 
23
 
24
+ class WanVideoReCamMasterPipeline(BasePipeline):
25
 
26
  def __init__(self, device="cuda", torch_dtype=torch.float16, tokenizer_path=None,condition_frames=None,target_frames=None):
27
  super().__init__(device=device, torch_dtype=torch_dtype)
 
141
  def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None):
142
  if device is None: device = model_manager.device
143
  if torch_dtype is None: torch_dtype = model_manager.torch_dtype
144
+ pipe = WanVideoReCamMasterPipeline(device=device, torch_dtype=torch_dtype)
145
  pipe.fetch_models(model_manager)
146
  return pipe
147
 
examples/.DS_Store DELETED
Binary file (6.15 kB)
 
models/Astra/checkpoints/diffusion_pytorch_model.ckpt → examples/output_videos/output_moe_framepack_sliding.mp4 RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:5ca9c5e04e26fdbe37ea4819ea623d76c69747f9fcbacd5ca47f43e49cbfbd9f
3
- size 3153044423
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:94e522d491ab43a0e597b13eefc4a6a48640c4490489005adef1bcb48c22e15c
3
+ size 2085028
assets/images/pipeline.png → logo-text-2.png RENAMED
File without changes
models/.DS_Store DELETED
Binary file (6.15 kB)
 
models/Astra/.DS_Store DELETED
Binary file (6.15 kB)
 
models/Astra/checkpoints/{Put Astra ckpt file here.txt → Put ReCamMaster ckpt file here.txt} RENAMED
File without changes
pip-list.txt ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ (ddrm) zyx@zwl:~$ pip list
2
+ Package Version Editable project location
3
+ ------------------------ --------------- ---------------------------------------------------------
4
+ absl-py 2.0.0
5
+ accelerate 1.0.1
6
+ addict 2.4.0
7
+ aiohttp 3.8.5
8
+ aiosignal 1.3.1
9
+ albumentations 1.4.6
10
+ annotated-types 0.6.0
11
+ antlr4-python3-runtime 4.9.3
12
+ appdirs 1.4.4
13
+ asttokens 2.4.1
14
+ async-timeout 4.0.3
15
+ attrs 23.1.0
16
+ backcall 0.2.0
17
+ basicsr 1.2.0+1.4.2 /home/zyx/Retinexformer-master
18
+ beautifulsoup4 4.12.3
19
+ blessed 1.20.0
20
+ blobfile 2.0.2
21
+ cachetools 5.3.1
22
+ certifi 2023.5.7
23
+ cffi 1.15.1
24
+ charset-normalizer 3.2.0
25
+ click 8.1.7
26
+ cmake 3.26.4
27
+ contourpy 1.1.1
28
+ cycler 0.11.0
29
+ decorator 5.1.1
30
+ decord 0.6.0
31
+ diffusers 0.31.0
32
+ dlib 19.24.6
33
+ docker-pycreds 0.4.0
34
+ einops 0.6.1
35
+ executing 2.1.0
36
+ face-alignment 1.4.1
37
+ facexlib 0.3.0
38
+ filelock 3.12.2
39
+ filterpy 1.4.5
40
+ fire 0.5.0
41
+ flatbuffers 23.5.26
42
+ fonttools 4.42.1
43
+ frozenlist 1.4.0
44
+ fsspec 2023.9.1
45
+ ftfy 6.1.1
46
+ future 0.18.3
47
+ gdown 5.2.0
48
+ gfpgan 1.3.8 /home/zyx/anaconda3/envs/ddrm/lib/python3.8/site-packages
49
+ gitdb 4.0.10
50
+ GitPython 3.1.37
51
+ google-auth 2.23.0
52
+ google-auth-oauthlib 1.0.0
53
+ gpustat 1.1
54
+ grpcio 1.58.0
55
+ guided-diffusion 0.0.0 /home/zyx/GenerativeDiffusionPrior
56
+ huggingface-hub 0.30.2
57
+ idna 3.4
58
+ imageio 2.31.5
59
+ imgaug 0.4.0
60
+ importlib-metadata 6.8.0
61
+ importlib-resources 6.1.0
62
+ ip-adapter 0.1.0
63
+ ipython 8.12.3
64
+ jedi 0.19.2
65
+ Jinja2 3.1.2
66
+ joblib 1.3.2
67
+ kiwisolver 1.4.5
68
+ lazy_loader 0.3
69
+ lightning-utilities 0.9.0
70
+ lit 16.0.6
71
+ llvmlite 0.41.0
72
+ lmdb 1.4.1
73
+ loguru 0.7.2
74
+ lora-diffusion 0.1.7
75
+ loralib 0.1.2
76
+ lpips 0.1.4
77
+ lxml 4.9.3
78
+ Markdown 3.4.4
79
+ markdown-it-py 3.0.0
80
+ MarkupSafe 2.1.3
81
+ matplotlib 3.7.3
82
+ matplotlib-inline 0.1.7
83
+ mdurl 0.1.2
84
+ mediapipe 0.10.5
85
+ mmcv 1.7.0
86
+ mmengine 0.10.7
87
+ mpi4py 3.1.4
88
+ mpmath 1.3.0
89
+ multidict 6.0.4
90
+ mypy-extensions 1.0.0
91
+ natsort 8.4.0
92
+ networkx 3.1
93
+ ninja 1.11.1.1
94
+ numba 0.58.0
95
+ numpy 1.24.4
96
+ nvidia-cublas-cu11 11.10.3.66
97
+ nvidia-cuda-cupti-cu11 11.7.101
98
+ nvidia-cuda-nvrtc-cu11 11.7.99
99
+ nvidia-cuda-runtime-cu11 11.7.99
100
+ nvidia-cudnn-cu11 8.5.0.96
101
+ nvidia-cufft-cu11 10.9.0.58
102
+ nvidia-curand-cu11 10.2.10.91
103
+ nvidia-cusolver-cu11 11.4.0.1
104
+ nvidia-cusparse-cu11 11.7.4.91
105
+ nvidia-ml-py 12.535.77
106
+ nvidia-nccl-cu11 2.14.3
107
+ nvidia-nvtx-cu11 11.7.91
108
+ oauthlib 3.2.2
109
+ omegaconf 2.3.0
110
+ open-clip-torch 2.20.0
111
+ openai-clip 1.0.1
112
+ opencv-contrib-python 4.8.0.76
113
+ opencv-python 4.8.0.74
114
+ opencv-python-headless 4.9.0.80
115
+ packaging 23.1
116
+ pandas 2.0.3
117
+ parso 0.8.4
118
+ pathtools 0.1.2
119
+ peft 0.13.2
120
+ pexpect 4.9.0
121
+ pickleshare 0.7.5
122
+ Pillow 10.0.0
123
+ pip 23.1.2
124
+ platformdirs 3.11.0
125
+ prompt_toolkit 3.0.48
126
+ protobuf 3.20.3
127
+ psutil 5.9.5
128
+ ptyprocess 0.7.0
129
+ pure_eval 0.2.3
130
+ pyasn1 0.5.0
131
+ pyasn1-modules 0.3.0
132
+ pycparser 2.21
133
+ pycryptodomex 3.18.0
134
+ pydantic 2.7.1
135
+ pydantic_core 2.18.2
136
+ pyDeprecate 0.3.1
137
+ Pygments 2.18.0
138
+ pyiqa 0.1.8
139
+ pyparsing 3.1.1
140
+ pyre-extensions 0.0.23
141
+ PySocks 1.7.1
142
+ python-dateutil 2.8.2
143
+ pytorch-fid 0.3.0
144
+ pytorch-lightning 1.4.2
145
+ pytz 2023.3.post1
146
+ PyWavelets 1.4.1
147
+ PyYAML 6.0.1
148
+ realesrgan 0.3.0
149
+ regex 2023.8.8
150
+ requests 2.31.0
151
+ requests-oauthlib 1.3.1
152
+ rich 14.0.0
153
+ rsa 4.9
154
+ safetensors 0.4.5
155
+ scikit-image 0.21.0
156
+ scikit-learn 1.3.2
157
+ scipy 1.10.1
158
+ sentencepiece 0.1.99
159
+ sentry-sdk 1.31.0
160
+ setproctitle 1.3.2
161
+ setuptools 68.0.0
162
+ shapely 2.0.1
163
+ six 1.16.0
164
+ smmap 5.0.1
165
+ sounddevice 0.4.6
166
+ soupsieve 2.5
167
+ stack-data 0.6.3
168
+ sympy 1.12
169
+ tb-nightly 2.14.0a20230808
170
+ tensorboard 2.14.0
171
+ tensorboard-data-server 0.7.1
172
+ termcolor 2.3.0
173
+ threadpoolctl 3.2.0
174
+ tifffile 2023.7.10
175
+ timm 0.9.7
176
+ tokenizers 0.15.0
177
+ tomli 2.0.1
178
+ torch 1.13.1+cu116
179
+ torchmetrics 0.5.0
180
+ torchvision 0.14.1+cu116
181
+ tqdm 4.65.0
182
+ traitlets 5.14.3
183
+ transformers 4.36.2
184
+ triton 2.0.0
185
+ typing_extensions 4.11.0
186
+ typing-inspect 0.9.0
187
+ tzdata 2023.3
188
+ urllib3 1.26.16
189
+ vqfr 2.0.0 /home/zyx/VQFR
190
+ wandb 0.15.11
191
+ wcwidth 0.2.6
192
+ Werkzeug 2.3.7
193
+ wheel 0.40.0
194
+ xformers 0.0.16
195
+ yapf 0.40.2
196
+ yarl 1.9.2
197
+ zipp 3.17.0
scripts/add_text_emb.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import lightning as pl
4
+ from PIL import Image
5
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
6
+ import json
7
+ import imageio
8
+ from torchvision.transforms import v2
9
+ from einops import rearrange
10
+ import argparse
11
+ import numpy as np
12
+ import pdb
13
+ from tqdm import tqdm
14
+
15
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
16
+
17
+ class VideoEncoder(pl.LightningModule):
18
+ def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
19
+ super().__init__()
20
+ model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
21
+ model_manager.load_models([text_encoder_path, vae_path])
22
+ self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
23
+ self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
24
+
25
+ self.frame_process = v2.Compose([
26
+ # v2.CenterCrop(size=(900, 1600)),
27
+ # v2.Resize(size=(900, 1600), antialias=True),
28
+ v2.ToTensor(),
29
+ v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
30
+ ])
31
+
32
+ def crop_and_resize(self, image):
33
+ width, height = image.size
34
+ # print(width,height)
35
+ width_ori, height_ori_ = 832 , 480
36
+ image = v2.functional.resize(
37
+ image,
38
+ (round(height_ori_), round(width_ori)),
39
+ interpolation=v2.InterpolationMode.BILINEAR
40
+ )
41
+ return image
42
+
43
+ def load_video_frames(self, video_path):
44
+ """加载完整视频"""
45
+ reader = imageio.get_reader(video_path)
46
+ frames = []
47
+
48
+ for frame_data in reader:
49
+ frame = Image.fromarray(frame_data)
50
+ frame = self.crop_and_resize(frame)
51
+ frame = self.frame_process(frame)
52
+ frames.append(frame)
53
+
54
+ reader.close()
55
+
56
+ if len(frames) == 0:
57
+ return None
58
+
59
+ frames = torch.stack(frames, dim=0)
60
+ frames = rearrange(frames, "T C H W -> C T H W")
61
+ return frames
62
+
63
+ def encode_scenes(scenes_path, text_encoder_path, vae_path,output_dir):
64
+ """编码所有场景的视频"""
65
+
66
+ encoder = VideoEncoder(text_encoder_path, vae_path)
67
+ encoder = encoder.cuda()
68
+ encoder.pipe.device = "cuda"
69
+
70
+ processed_count = 0
71
+ prompt_emb = 0
72
+
73
+ os.makedirs(output_dir,exist_ok=True)
74
+
75
+ required_keys = ["latents", "cam_emb", "prompt_emb"]
76
+
77
+
78
+ for i, scene_name in tqdm(enumerate(os.listdir(scenes_path)),total=len(os.listdir(scenes_path))):
79
+
80
+ scene_dir = os.path.join(scenes_path, scene_name)
81
+ save_dir = os.path.join(output_dir,scene_name.split('.')[0])
82
+ # print('in:',scene_dir)
83
+ # print('out:',save_dir)
84
+
85
+
86
+ # 检查是否已编码
87
+ encoded_path = os.path.join(save_dir, "encoded_video.pth")
88
+ # if os.path.exists(encoded_path):
89
+ print(f"Checking scene {scene_name}...")
90
+ # continue
91
+
92
+ # 加载场景信息
93
+
94
+ # print(encoded_path)
95
+ data = torch.load(encoded_path,weights_only=False)
96
+ missing_keys = [key for key in required_keys if key not in data]
97
+
98
+ if missing_keys:
99
+ print(f"警告: 文件中缺少以下必要元素: {missing_keys}")
100
+ else:
101
+ print("文件包含所有必要元素: latents 和 cam_emb 和 prompt_emb")
102
+ continue
103
+ # with np.load(scene_cam_path) as data:
104
+ # cam_data = data.files
105
+ # cam_emb = {k: data[k].cpu() if isinstance(data[k], torch.Tensor) else data[k] for k in cam_data}
106
+ # with open(scene_cam_path, 'rb') as f:
107
+ # cam_data = np.load(f) # 此时cam_data仅包含数据,无文件句柄引用
108
+
109
+
110
+
111
+ # 加载和编码视频
112
+ # video_frames = encoder.load_video_frames(video_path)
113
+ # if video_frames is None:
114
+ # print(f"Failed to load video: {video_path}")
115
+ # continue
116
+
117
+ # video_frames = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16)
118
+ # print(video_frames.shape)
119
+ # 编码视频
120
+ with torch.no_grad():
121
+ # latents = encoder.pipe.encode_video(video_frames, **encoder.tiler_kwargs)[0]
122
+
123
+ # 编码文本
124
+ if processed_count == 0:
125
+ print('encode prompt!!!')
126
+ prompt_emb = encoder.pipe.encode_prompt("A video of a scene shot using a pedestrian's front camera while walking")#A video of a scene shot using a drone's front camera
127
+ del encoder.pipe.prompter
128
+
129
+ data["prompt_emb"] = prompt_emb
130
+
131
+ print("已添加/更新 prompt_emb 元素")
132
+
133
+ # 保存修改后的文件(可改为新路径避免覆盖原文件)
134
+ torch.save(data, encoded_path)
135
+
136
+ # pdb.set_trace()
137
+ # 保存编码结果
138
+
139
+
140
+ print(f"Saved encoded data: {encoded_path}")
141
+ processed_count += 1
142
+
143
+ # except Exception as e:
144
+ # print(f"Error encoding scene {scene_name}: {e}")
145
+ # continue
146
+ print(processed_count)
147
+ print(f"Encoding completed! Processed {processed_count} scenes.")
148
+
149
+ if __name__ == "__main__":
150
+ parser = argparse.ArgumentParser()
151
+ parser.add_argument("--scenes_path", type=str, default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking")
152
+ parser.add_argument("--text_encoder_path", type=str,
153
+ default="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth")
154
+ parser.add_argument("--vae_path", type=str,
155
+ default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth")
156
+
157
+ parser.add_argument("--output_dir",type=str,
158
+ default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking")
159
+
160
+ args = parser.parse_args()
161
+ encode_scenes(args.scenes_path, args.text_encoder_path, args.vae_path,args.output_dir)
scripts/add_text_emb_rl.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import lightning as pl
4
+ from PIL import Image
5
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
6
+ import json
7
+ import imageio
8
+ from torchvision.transforms import v2
9
+ from einops import rearrange
10
+ import argparse
11
+ import numpy as np
12
+ import pdb
13
+ from tqdm import tqdm
14
+
15
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
16
+
17
+ class VideoEncoder(pl.LightningModule):
18
+ def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
19
+ super().__init__()
20
+ model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
21
+ model_manager.load_models([text_encoder_path, vae_path])
22
+ self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
23
+ self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
24
+
25
+ self.frame_process = v2.Compose([
26
+ # v2.CenterCrop(size=(900, 1600)),
27
+ # v2.Resize(size=(900, 1600), antialias=True),
28
+ v2.ToTensor(),
29
+ v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
30
+ ])
31
+
32
+ def crop_and_resize(self, image):
33
+ width, height = image.size
34
+ # print(width,height)
35
+ width_ori, height_ori_ = 832 , 480
36
+ image = v2.functional.resize(
37
+ image,
38
+ (round(height_ori_), round(width_ori)),
39
+ interpolation=v2.InterpolationMode.BILINEAR
40
+ )
41
+ return image
42
+
43
+ def load_video_frames(self, video_path):
44
+ """加载完整视频"""
45
+ reader = imageio.get_reader(video_path)
46
+ frames = []
47
+
48
+ for frame_data in reader:
49
+ frame = Image.fromarray(frame_data)
50
+ frame = self.crop_and_resize(frame)
51
+ frame = self.frame_process(frame)
52
+ frames.append(frame)
53
+
54
+ reader.close()
55
+
56
+ if len(frames) == 0:
57
+ return None
58
+
59
+ frames = torch.stack(frames, dim=0)
60
+ frames = rearrange(frames, "T C H W -> C T H W")
61
+ return frames
62
+
63
+ def encode_scenes(scenes_path, text_encoder_path, vae_path,output_dir):
64
+ """编码所有场景的视频"""
65
+
66
+ encoder = VideoEncoder(text_encoder_path, vae_path)
67
+ encoder = encoder.cuda()
68
+ encoder.pipe.device = "cuda"
69
+
70
+ processed_count = 0
71
+ prompt_emb = 0
72
+
73
+ os.makedirs(output_dir,exist_ok=True)
74
+
75
+ required_keys = ["latents", "cam_emb", "prompt_emb"]
76
+
77
+
78
+ for i, scene_name in tqdm(enumerate(os.listdir(scenes_path)),total=len(os.listdir(scenes_path))):
79
+
80
+ scene_dir = os.path.join(scenes_path, scene_name)
81
+ save_dir = os.path.join(output_dir,scene_name.split('.')[0])
82
+ # print('in:',scene_dir)
83
+ # print('out:',save_dir)
84
+
85
+
86
+ # 检查是否已编码
87
+ encoded_path = os.path.join(save_dir, "encoded_video.pth")
88
+ # if os.path.exists(encoded_path):
89
+ print(f"Checking scene {scene_name}...")
90
+ # continue
91
+
92
+ # 加载场景信息
93
+
94
+ # print(encoded_path)
95
+ data = torch.load(encoded_path,weights_only=False)
96
+ missing_keys = [key for key in required_keys if key not in data]
97
+
98
+ if missing_keys:
99
+ print(f"警告: 文件中缺少以下必要元素: {missing_keys}")
100
+ else:
101
+ print("文件包含所有必要元素: latents 和 cam_emb 和 prompt_emb")
102
+ continue
103
+ # with np.load(scene_cam_path) as data:
104
+ # cam_data = data.files
105
+ # cam_emb = {k: data[k].cpu() if isinstance(data[k], torch.Tensor) else data[k] for k in cam_data}
106
+ # with open(scene_cam_path, 'rb') as f:
107
+ # cam_data = np.load(f) # 此时cam_data仅包含数据,无文件句柄引用
108
+
109
+
110
+
111
+ # 加载和编码视频
112
+ # video_frames = encoder.load_video_frames(video_path)
113
+ # if video_frames is None:
114
+ # print(f"Failed to load video: {video_path}")
115
+ # continue
116
+
117
+ # video_frames = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16)
118
+ # print(video_frames.shape)
119
+ # 编码视频
120
+ with torch.no_grad():
121
+ # latents = encoder.pipe.encode_video(video_frames, **encoder.tiler_kwargs)[0]
122
+
123
+ # 编码文本
124
+ if processed_count == 0:
125
+ print('encode prompt!!!')
126
+ prompt_emb = encoder.pipe.encode_prompt("a robotic arm executing precise manipulation tasks on a clean, organized desk")#A video of a scene shot using a drone's front camera + “A video of a scene shot using a pedestrian's front camera while walking”
127
+ del encoder.pipe.prompter
128
+
129
+ data["prompt_emb"] = prompt_emb
130
+
131
+ print("已添加/更新 prompt_emb 元素")
132
+
133
+ # 保存修改后的文件(可改为新路径避免覆盖原文件)
134
+ torch.save(data, encoded_path)
135
+
136
+ # pdb.set_trace()
137
+ # 保存编码结果
138
+
139
+
140
+ print(f"Saved encoded data: {encoded_path}")
141
+ processed_count += 1
142
+
143
+ # except Exception as e:
144
+ # print(f"Error encoding scene {scene_name}: {e}")
145
+ # continue
146
+ print(processed_count)
147
+ print(f"Encoding completed! Processed {processed_count} scenes.")
148
+
149
+ if __name__ == "__main__":
150
+ parser = argparse.ArgumentParser()
151
+ parser.add_argument("--scenes_path", type=str, default="/share_zhuyixuan05/zhuyixuan05/rlbench")
152
+ parser.add_argument("--text_encoder_path", type=str,
153
+ default="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth")
154
+ parser.add_argument("--vae_path", type=str,
155
+ default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth")
156
+
157
+ parser.add_argument("--output_dir",type=str,
158
+ default="/share_zhuyixuan05/zhuyixuan05/rlbench")
159
+
160
+ args = parser.parse_args()
161
+ encode_scenes(args.scenes_path, args.text_encoder_path, args.vae_path,args.output_dir)
scripts/add_text_emb_spatialvid.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import lightning as pl
4
+ from PIL import Image
5
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
6
+ import json
7
+ import imageio
8
+ from torchvision.transforms import v2
9
+ from einops import rearrange
10
+ import argparse
11
+ import numpy as np
12
+ import pdb
13
+ from tqdm import tqdm
14
+
15
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
16
+
17
+ class VideoEncoder(pl.LightningModule):
18
+ def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
19
+ super().__init__()
20
+ model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
21
+ model_manager.load_models([text_encoder_path, vae_path])
22
+ self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
23
+ self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
24
+
25
+ self.frame_process = v2.Compose([
26
+ # v2.CenterCrop(size=(900, 1600)),
27
+ # v2.Resize(size=(900, 1600), antialias=True),
28
+ v2.ToTensor(),
29
+ v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
30
+ ])
31
+
32
+ def crop_and_resize(self, image):
33
+ width, height = image.size
34
+ # print(width,height)
35
+ width_ori, height_ori_ = 832 , 480
36
+ image = v2.functional.resize(
37
+ image,
38
+ (round(height_ori_), round(width_ori)),
39
+ interpolation=v2.InterpolationMode.BILINEAR
40
+ )
41
+ return image
42
+
43
+ def load_video_frames(self, video_path):
44
+ """加载完整视频"""
45
+ reader = imageio.get_reader(video_path)
46
+ frames = []
47
+
48
+ for frame_data in reader:
49
+ frame = Image.fromarray(frame_data)
50
+ frame = self.crop_and_resize(frame)
51
+ frame = self.frame_process(frame)
52
+ frames.append(frame)
53
+
54
+ reader.close()
55
+
56
+ if len(frames) == 0:
57
+ return None
58
+
59
+ frames = torch.stack(frames, dim=0)
60
+ frames = rearrange(frames, "T C H W -> C T H W")
61
+ return frames
62
+
63
+ def encode_scenes(scenes_path, text_encoder_path, vae_path,output_dir):
64
+ """编码所有场景的视频"""
65
+
66
+ encoder = VideoEncoder(text_encoder_path, vae_path)
67
+ encoder = encoder.cuda()
68
+ encoder.pipe.device = "cuda"
69
+
70
+ processed_count = 0
71
+ prompt_emb = 0
72
+
73
+ os.makedirs(output_dir,exist_ok=True)
74
+
75
+ required_keys = ["latents", "cam_emb", "prompt_emb"]
76
+
77
+
78
+ for i, scene_name in tqdm(enumerate(os.listdir(scenes_path)),total=len(os.listdir(scenes_path))):
79
+
80
+ scene_dir = os.path.join(scenes_path, scene_name)
81
+ save_dir = os.path.join(output_dir,scene_name.split('.')[0])
82
+ # print('in:',scene_dir)
83
+ # print('out:',save_dir)
84
+
85
+
86
+ # 检查是否已编码
87
+ encoded_path = os.path.join(save_dir, "encoded_video.pth")
88
+ # if os.path.exists(encoded_path):
89
+ # print(f"Checking scene {scene_name}...")
90
+ # continue
91
+
92
+ # 加载场景信息
93
+
94
+ # print(encoded_path)
95
+ data = torch.load(encoded_path,weights_only=False,
96
+ map_location="cpu")
97
+ missing_keys = [key for key in required_keys if key not in data]
98
+
99
+ if missing_keys:
100
+ print(f"警告: 文件 {encoded_path} 中缺少以下必要元素: {missing_keys}")
101
+ # else:
102
+ # # print("文件包含所有必要元素: latents 和 cam_emb 和 prompt_emb")
103
+ # continue
104
+ # pdb.set_trace()
105
+ if data['prompt_emb']['context'].requires_grad:
106
+ print(f"警告: 文件 {encoded_path} 中存在含梯度变量,已消除")
107
+
108
+ data['prompt_emb']['context'] = data['prompt_emb']['context'].detach().clone()
109
+
110
+ # 双重保险:显式关闭梯度
111
+ data['prompt_emb']['context'].requires_grad_(False)
112
+
113
+ # 验证是否成功(可选)
114
+ assert not data['prompt_emb']['context'].requires_grad, "梯度仍未消除!"
115
+ torch.save(data, encoded_path)
116
+ # with np.load(scene_cam_path) as data:
117
+ # cam_data = data.files
118
+ # cam_emb = {k: data[k].cpu() if isinstance(data[k], torch.Tensor) else data[k] for k in cam_data}
119
+ # with open(scene_cam_path, 'rb') as f:
120
+ # cam_data = np.load(f) # 此时cam_data仅包含数据,无文件句柄引用
121
+
122
+
123
+
124
+ # 加载和编码视频
125
+ # video_frames = encoder.load_video_frames(video_path)
126
+ # if video_frames is None:
127
+ # print(f"Failed to load video: {video_path}")
128
+ # continue
129
+
130
+ # video_frames = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16)
131
+ # print(video_frames.shape)
132
+ # 编码视频
133
+ '''with torch.no_grad():
134
+ # latents = encoder.pipe.encode_video(video_frames, **encoder.tiler_kwargs)[0]
135
+
136
+ # 编码文本
137
+ if processed_count == 0:
138
+ print('encode prompt!!!')
139
+ prompt_emb = encoder.pipe.encode_prompt("A video of a scene shot using a pedestrian's front camera while walking")#A video of a scene shot using a drone's front camera
140
+ del encoder.pipe.prompter
141
+
142
+ data["prompt_emb"] = prompt_emb
143
+
144
+ print("已添加/更新 prompt_emb 元素")
145
+
146
+ # 保存修改后的文件(可改为新路径避免覆盖原文件)
147
+ torch.save(data, encoded_path)
148
+
149
+ # pdb.set_trace()
150
+ # 保存编码结果
151
+
152
+ print(f"Saved encoded data: {encoded_path}")'''
153
+ processed_count += 1
154
+
155
+ # except Exception as e:
156
+ # print(f"Error encoding scene {scene_name}: {e}")
157
+ # continue
158
+ print(processed_count)
159
+ print(f"Encoding completed! Processed {processed_count} scenes.")
160
+
161
+ if __name__ == "__main__":
162
+ parser = argparse.ArgumentParser()
163
+ parser.add_argument("--scenes_path", type=str, default="/share_zhuyixuan05/zhuyixuan05/spatialvid")
164
+ parser.add_argument("--text_encoder_path", type=str,
165
+ default="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth")
166
+ parser.add_argument("--vae_path", type=str,
167
+ default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth")
168
+
169
+ parser.add_argument("--output_dir",type=str,
170
+ default="/share_zhuyixuan05/zhuyixuan05/spatialvid")
171
+
172
+ args = parser.parse_args()
173
+ encode_scenes(args.scenes_path, args.text_encoder_path, args.vae_path,args.output_dir)
scripts/analyze_openx.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from tqdm import tqdm
4
+
5
+ def analyze_openx_dataset_frame_counts(dataset_path):
6
+ """分析OpenX数据集中的帧数分布"""
7
+
8
+ print(f"🔧 分析OpenX数据集: {dataset_path}")
9
+
10
+ if not os.path.exists(dataset_path):
11
+ print(f" ⚠️ 路径不存在: {dataset_path}")
12
+ return
13
+
14
+ episode_dirs = []
15
+ total_episodes = 0
16
+ valid_episodes = 0
17
+
18
+ # 收集所有episode目录
19
+ for item in os.listdir(dataset_path):
20
+ episode_dir = os.path.join(dataset_path, item)
21
+ if os.path.isdir(episode_dir):
22
+ total_episodes += 1
23
+ encoded_path = os.path.join(episode_dir, "encoded_video.pth")
24
+ if os.path.exists(encoded_path):
25
+ episode_dirs.append(episode_dir)
26
+ valid_episodes += 1
27
+
28
+ print(f"📊 总episode数: {total_episodes}")
29
+ print(f"📊 有效episode数: {valid_episodes}")
30
+
31
+ if len(episode_dirs) == 0:
32
+ print("❌ 没有找到有效的episode")
33
+ return
34
+
35
+ # 统计帧数分布
36
+ frame_counts = []
37
+ less_than_10 = 0
38
+ less_than_8 = 0
39
+ less_than_5 = 0
40
+ error_count = 0
41
+
42
+ print("🔧 开始分析帧数分布...")
43
+
44
+ for episode_dir in tqdm(episode_dirs, desc="分析episodes"):
45
+ try:
46
+ encoded_data = torch.load(
47
+ os.path.join(episode_dir, "encoded_video.pth"),
48
+ weights_only=False,
49
+ map_location="cpu"
50
+ )
51
+
52
+ latents = encoded_data['latents'] # [C, T, H, W]
53
+ frame_count = latents.shape[1] # T维度
54
+ frame_counts.append(frame_count)
55
+
56
+ if frame_count < 10:
57
+ less_than_10 += 1
58
+ if frame_count < 8:
59
+ less_than_8 += 1
60
+ if frame_count < 5:
61
+ less_than_5 += 1
62
+
63
+ except Exception as e:
64
+ error_count += 1
65
+ if error_count <= 5: # 只打印前5个错误
66
+ print(f"❌ 加载episode {os.path.basename(episode_dir)} 时出错: {e}")
67
+
68
+ # 统计结果
69
+ total_valid = len(frame_counts)
70
+ print(f"\n📈 帧数分布统计:")
71
+ print(f" 总有效episodes: {total_valid}")
72
+ print(f" 错误episodes: {error_count}")
73
+ print(f" 最小帧数: {min(frame_counts) if frame_counts else 0}")
74
+ print(f" 最大帧数: {max(frame_counts) if frame_counts else 0}")
75
+ print(f" 平均帧数: {sum(frame_counts) / len(frame_counts):.2f}" if frame_counts else 0)
76
+
77
+ print(f"\n🎯 关键统计:")
78
+ print(f" 帧数 < 5: {less_than_5:6d} episodes ({less_than_5/total_valid*100:.2f}%)")
79
+ print(f" 帧数 < 8: {less_than_8:6d} episodes ({less_than_8/total_valid*100:.2f}%)")
80
+ print(f" 帧数 < 10: {less_than_10:6d} episodes ({less_than_10/total_valid*100:.2f}%)")
81
+ print(f" 帧数 >= 10: {total_valid-less_than_10:6d} episodes ({(total_valid-less_than_10)/total_valid*100:.2f}%)")
82
+
83
+ # 详细分布
84
+ frame_counts.sort()
85
+ print(f"\n📊 详细帧数分布:")
86
+
87
+ # 按范围统计
88
+ ranges = [
89
+ (1, 4, "1-4帧"),
90
+ (5, 7, "5-7帧"),
91
+ (8, 9, "8-9帧"),
92
+ (10, 19, "10-19帧"),
93
+ (20, 49, "20-49帧"),
94
+ (50, 99, "50-99帧"),
95
+ (100, float('inf'), "100+帧")
96
+ ]
97
+
98
+ for min_f, max_f, label in ranges:
99
+ count = sum(1 for f in frame_counts if min_f <= f <= max_f)
100
+ percentage = count / total_valid * 100
101
+ print(f" {label:8s}: {count:6d} episodes ({percentage:5.2f}%)")
102
+
103
+ # 建议的训练配置
104
+ print(f"\n💡 训练配置建议:")
105
+ time_compression_ratio = 4
106
+ min_condition_compressed = 4 // time_compression_ratio # 1帧
107
+ target_frames_compressed = 32 // time_compression_ratio # 8帧
108
+ min_required_compressed = min_condition_compressed + target_frames_compressed # 9帧
109
+
110
+ usable_episodes = sum(1 for f in frame_counts if f >= min_required_compressed)
111
+ usable_percentage = usable_episodes / total_valid * 100
112
+
113
+ print(f" 最小条件帧数(压缩后): {min_condition_compressed}")
114
+ print(f" 目标帧数(压缩后): {target_frames_compressed}")
115
+ print(f" 最小所需帧数(压缩后): {min_required_compressed}")
116
+ print(f" 可用于训练的episodes: {usable_episodes} ({usable_percentage:.2f}%)")
117
+
118
+ # 保存详细统计到文件
119
+ output_file = os.path.join(dataset_path, "frame_count_analysis.txt")
120
+ with open(output_file, 'w') as f:
121
+ f.write(f"OpenX Dataset Frame Count Analysis\n")
122
+ f.write(f"Dataset Path: {dataset_path}\n")
123
+ f.write(f"Analysis Date: {__import__('datetime').datetime.now()}\n\n")
124
+
125
+ f.write(f"Total Episodes: {total_episodes}\n")
126
+ f.write(f"Valid Episodes: {total_valid}\n")
127
+ f.write(f"Error Episodes: {error_count}\n\n")
128
+
129
+ f.write(f"Frame Count Statistics:\n")
130
+ f.write(f" Min Frames: {min(frame_counts) if frame_counts else 0}\n")
131
+ f.write(f" Max Frames: {max(frame_counts) if frame_counts else 0}\n")
132
+ f.write(f" Avg Frames: {sum(frame_counts) / len(frame_counts):.2f}\n\n" if frame_counts else " Avg Frames: 0\n\n")
133
+
134
+ f.write(f"Key Statistics:\n")
135
+ f.write(f" < 5 frames: {less_than_5} ({less_than_5/total_valid*100:.2f}%)\n")
136
+ f.write(f" < 8 frames: {less_than_8} ({less_than_8/total_valid*100:.2f}%)\n")
137
+ f.write(f" < 10 frames: {less_than_10} ({less_than_10/total_valid*100:.2f}%)\n")
138
+ f.write(f" >= 10 frames: {total_valid-less_than_10} ({(total_valid-less_than_10)/total_valid*100:.2f}%)\n\n")
139
+
140
+ f.write(f"Detailed Distribution:\n")
141
+ for min_f, max_f, label in ranges:
142
+ count = sum(1 for f in frame_counts if min_f <= f <= max_f)
143
+ percentage = count / total_valid * 100
144
+ f.write(f" {label}: {count} ({percentage:.2f}%)\n")
145
+
146
+ f.write(f"\nTraining Configuration Recommendation:\n")
147
+ f.write(f" Usable Episodes (>= {min_required_compressed} compressed frames): {usable_episodes} ({usable_percentage:.2f}%)\n")
148
+
149
+ # 写入所有帧数
150
+ f.write(f"\nAll Frame Counts:\n")
151
+ for i, count in enumerate(frame_counts):
152
+ f.write(f"{count}")
153
+ if (i + 1) % 20 == 0:
154
+ f.write("\n")
155
+ else:
156
+ f.write(", ")
157
+
158
+ print(f"\n💾 详细统计已保存到: {output_file}")
159
+
160
+ return {
161
+ 'total_valid': total_valid,
162
+ 'less_than_10': less_than_10,
163
+ 'less_than_8': less_than_8,
164
+ 'less_than_5': less_than_5,
165
+ 'frame_counts': frame_counts,
166
+ 'usable_episodes': usable_episodes
167
+ }
168
+
169
+ def quick_sample_analysis(dataset_path, sample_size=1000):
170
+ """快速采样分析,用于大数据集的初步估计"""
171
+
172
+ print(f"🚀 快速采样分析 (样本数: {sample_size})")
173
+
174
+ episode_dirs = []
175
+ for item in os.listdir(dataset_path):
176
+ episode_dir = os.path.join(dataset_path, item)
177
+ if os.path.isdir(episode_dir):
178
+ encoded_path = os.path.join(episode_dir, "encoded_video.pth")
179
+ if os.path.exists(encoded_path):
180
+ episode_dirs.append(episode_dir)
181
+
182
+ if len(episode_dirs) == 0:
183
+ print("❌ 没有找到有效的episode")
184
+ return
185
+
186
+ # 随机采样
187
+ import random
188
+ sample_dirs = random.sample(episode_dirs, min(sample_size, len(episode_dirs)))
189
+
190
+ frame_counts = []
191
+ less_than_10 = 0
192
+
193
+ for episode_dir in tqdm(sample_dirs, desc="采样分析"):
194
+ try:
195
+ encoded_data = torch.load(
196
+ os.path.join(episode_dir, "encoded_video.pth"),
197
+ weights_only=False,
198
+ map_location="cpu"
199
+ )
200
+
201
+ frame_count = encoded_data['latents'].shape[1]
202
+ frame_counts.append(frame_count)
203
+
204
+ if frame_count < 10:
205
+ less_than_10 += 1
206
+
207
+ except Exception as e:
208
+ continue
209
+
210
+ total_sample = len(frame_counts)
211
+ percentage_less_than_10 = less_than_10 / total_sample * 100
212
+
213
+ print(f"📊 采样结果:")
214
+ print(f" 采样数量: {total_sample}")
215
+ print(f" < 10帧: {less_than_10} ({percentage_less_than_10:.2f}%)")
216
+ print(f" >= 10帧: {total_sample - less_than_10} ({100 - percentage_less_than_10:.2f}%)")
217
+ print(f" 平均帧数: {sum(frame_counts) / len(frame_counts):.2f}")
218
+
219
+ # 估算全数据集
220
+ total_episodes = len(episode_dirs)
221
+ estimated_less_than_10 = int(total_episodes * percentage_less_than_10 / 100)
222
+
223
+ print(f"\n🔮 全数据集估算:")
224
+ print(f" 总episodes: {total_episodes}")
225
+ print(f" 估算 < 10帧: {estimated_less_than_10} ({percentage_less_than_10:.2f}%)")
226
+ print(f" 估算 >= 10帧: {total_episodes - estimated_less_than_10} ({100 - percentage_less_than_10:.2f}%)")
227
+
228
+ if __name__ == "__main__":
229
+ import argparse
230
+
231
+ parser = argparse.ArgumentParser(description="分析OpenX数据集的帧数分布")
232
+ parser.add_argument("--dataset_path", type=str,
233
+ default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded",
234
+ help="OpenX编码数据集路径")
235
+ parser.add_argument("--quick", action="store_true", help="快速采样分析模式")
236
+ parser.add_argument("--sample_size", type=int, default=1000, help="快速模式的采样数量")
237
+
238
+ args = parser.parse_args()
239
+
240
+ if args.quick:
241
+ quick_sample_analysis(args.dataset_path, args.sample_size)
242
+ else:
243
+ analyze_openx_dataset_frame_counts(args.dataset_path)
scripts/analyze_pose.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ from pose_classifier import PoseClassifier
6
+ import torch
7
+ from collections import defaultdict
8
+
9
+ def analyze_turning_patterns_detailed(dataset_path, num_samples=50):
10
+ """详细分析转弯模式,基于相对于reference的pose变化"""
11
+ classifier = PoseClassifier()
12
+ samples_path = os.path.join(dataset_path, "samples")
13
+
14
+ all_analyses = []
15
+ sample_count = 0
16
+
17
+ # 用于统计每个类别的样本
18
+ class_samples = defaultdict(list)
19
+
20
+ print("=== 开始分析样本(基于相对于reference的变化)===")
21
+
22
+ for item in sorted(os.listdir(samples_path)): # 排序以便有序输出
23
+ if sample_count >= num_samples:
24
+ break
25
+
26
+ sample_dir = os.path.join(samples_path, item)
27
+ if os.path.isdir(sample_dir):
28
+ poses_path = os.path.join(sample_dir, "poses.json")
29
+ if os.path.exists(poses_path):
30
+ try:
31
+ with open(poses_path, 'r') as f:
32
+ poses_data = json.load(f)
33
+
34
+ target_relative_poses = poses_data['target_relative_poses']
35
+
36
+ if len(target_relative_poses) > 0:
37
+ # 🔧 创建相对pose向量(已经是相对于reference的)
38
+ pose_vecs = []
39
+ for pose_data in target_relative_poses:
40
+ # 相对位移(已经是相对于reference计算的)
41
+ translation = torch.tensor(pose_data['relative_translation'], dtype=torch.float32)
42
+
43
+ # 🔧 相对旋转(需要从current和reference计算)
44
+ current_rotation = torch.tensor(pose_data['current_rotation'], dtype=torch.float32)
45
+ reference_rotation = torch.tensor(pose_data['reference_rotation'], dtype=torch.float32)
46
+
47
+ # 计算相对旋转:q_relative = q_ref^-1 * q_current
48
+ relative_rotation = calculate_relative_rotation(current_rotation, reference_rotation)
49
+
50
+ # 组合为7D向量:[relative_translation, relative_rotation]
51
+ pose_vec = torch.cat([translation, relative_rotation], dim=0)
52
+ pose_vecs.append(pose_vec)
53
+
54
+ if pose_vecs:
55
+ pose_sequence = torch.stack(pose_vecs, dim=0)
56
+
57
+ # 🔧 使用新的分析方法
58
+ analysis = classifier.analyze_pose_sequence(pose_sequence)
59
+ analysis['sample_name'] = item
60
+ all_analyses.append(analysis)
61
+
62
+ # 🔧 详细输出每个样本的分类信息
63
+ print(f"\n--- 样本 {sample_count + 1}: {item} ---")
64
+ print(f"总帧数: {analysis['total_frames']}")
65
+ print(f"总距离: {analysis['total_distance']:.4f}")
66
+
67
+ # 分类分布
68
+ class_dist = analysis['class_distribution']
69
+ print(f"分类分布:")
70
+ for class_name, count in class_dist.items():
71
+ percentage = count / analysis['total_frames'] * 100
72
+ print(f" {class_name}: {count} 帧 ({percentage:.1f}%)")
73
+
74
+ # 🔧 调试前几个pose的分类过程
75
+ print(f"前3帧的详细分类过程:")
76
+ for i in range(min(3, len(pose_vecs))):
77
+ debug_info = classifier.debug_single_pose(
78
+ pose_vecs[i][:3], pose_vecs[i][3:7]
79
+ )
80
+ print(f" 帧{i}: {debug_info['classification']} "
81
+ f"(yaw: {debug_info['yaw_angle_deg']:.2f}°, "
82
+ f"forward: {debug_info['forward_movement']:.3f})")
83
+
84
+ # 运动段落
85
+ print(f"运动段落:")
86
+ for i, segment in enumerate(analysis['motion_segments']):
87
+ print(f" 段落{i+1}: {segment['class']} (帧 {segment['start_frame']}-{segment['end_frame']}, 持续 {segment['duration']} 帧)")
88
+
89
+ # 🔧 确定主要运动类型
90
+ dominant_class = max(class_dist.items(), key=lambda x: x[1])
91
+ dominant_class_name = dominant_class[0]
92
+ dominant_percentage = dominant_class[1] / analysis['total_frames'] * 100
93
+
94
+ print(f"主要运动类型: {dominant_class_name} ({dominant_percentage:.1f}%)")
95
+
96
+ # 将样本添加到对应类别
97
+ class_samples[dominant_class_name].append({
98
+ 'name': item,
99
+ 'percentage': dominant_percentage,
100
+ 'analysis': analysis
101
+ })
102
+
103
+ sample_count += 1
104
+
105
+ except Exception as e:
106
+ print(f"❌ 处理样本 {item} 时出错: {e}")
107
+
108
+ print("\n" + "="*60)
109
+ print("=== 按类别分组的样本统计(基于相对于reference的变化)===")
110
+
111
+ # 🔧 按类别输出样本列表
112
+ for class_name in ['forward', 'backward', 'left_turn', 'right_turn']:
113
+ samples = class_samples[class_name]
114
+ print(f"\n🔸 {class_name.upper()} 类样本 (共 {len(samples)} 个):")
115
+
116
+ if samples:
117
+ # 按主要类别占比排序
118
+ samples.sort(key=lambda x: x['percentage'], reverse=True)
119
+
120
+ for i, sample_info in enumerate(samples, 1):
121
+ print(f" {i:2d}. {sample_info['name']} ({sample_info['percentage']:.1f}%)")
122
+
123
+ # 显示详细的段落信息
124
+ segments = sample_info['analysis']['motion_segments']
125
+ segment_summary = []
126
+ for seg in segments:
127
+ if seg['duration'] >= 2: # 只显示持续时间>=2帧的段落
128
+ segment_summary.append(f"{seg['class']}({seg['duration']})")
129
+
130
+ if segment_summary:
131
+ print(f" 段落: {' -> '.join(segment_summary)}")
132
+ else:
133
+ print(" (无样本)")
134
+
135
+ # 🔧 统计总体模式
136
+ print(f"\n" + "="*60)
137
+ print("=== 总体统计 ===")
138
+
139
+ total_forward = sum(a['class_distribution']['forward'] for a in all_analyses)
140
+ total_backward = sum(a['class_distribution']['backward'] for a in all_analyses)
141
+ total_left_turn = sum(a['class_distribution']['left_turn'] for a in all_analyses)
142
+ total_right_turn = sum(a['class_distribution']['right_turn'] for a in all_analyses)
143
+ total_frames = total_forward + total_backward + total_left_turn + total_right_turn
144
+
145
+ print(f"总样本数: {len(all_analyses)}")
146
+ print(f"总帧数: {total_frames}")
147
+ print(f"Forward: {total_forward} 帧 ({total_forward/total_frames*100:.1f}%)")
148
+ print(f"Backward: {total_backward} 帧 ({total_backward/total_frames*100:.1f}%)")
149
+ print(f"Left Turn: {total_left_turn} 帧 ({total_left_turn/total_frames*100:.1f}%)")
150
+ print(f"Right Turn: {total_right_turn} 帧 ({total_right_turn/total_frames*100:.1f}%)")
151
+
152
+ # 🔧 样本分布统计
153
+ print(f"\n按主要类型的样本分布:")
154
+ for class_name in ['forward', 'backward', 'left_turn', 'right_turn']:
155
+ count = len(class_samples[class_name])
156
+ percentage = count / len(all_analyses) * 100 if all_analyses else 0
157
+ print(f" {class_name}: {count} 样本 ({percentage:.1f}%)")
158
+
159
+ return all_analyses, class_samples
160
+
161
+ def calculate_relative_rotation(current_rotation, reference_rotation):
162
+ """计算相对旋转四元数"""
163
+ q_current = torch.tensor(current_rotation, dtype=torch.float32)
164
+ q_ref = torch.tensor(reference_rotation, dtype=torch.float32)
165
+
166
+ # 计算参考旋转的逆 (q_ref^-1)
167
+ q_ref_inv = torch.tensor([q_ref[0], -q_ref[1], -q_ref[2], -q_ref[3]])
168
+
169
+ # 四元数乘法计算相对旋转: q_relative = q_ref^-1 * q_current
170
+ w1, x1, y1, z1 = q_ref_inv
171
+ w2, x2, y2, z2 = q_current
172
+
173
+ relative_rotation = torch.tensor([
174
+ w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
175
+ w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
176
+ w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
177
+ w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
178
+ ])
179
+
180
+ return relative_rotation
181
+
182
+ if __name__ == "__main__":
183
+ dataset_path = "/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_2"
184
+
185
+ print("开始详细分析pose分类(基于相对于reference的变化)...")
186
+ all_analyses, class_samples = analyze_turning_patterns_detailed(dataset_path, num_samples=4000)
187
+
188
+ print(f"\n🎉 分析完成! 共处理 {len(all_analyses)} 个样本")
scripts/batch_drone.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import subprocess
4
+ import time
5
+
6
+ src_root = "/share_zhuyixuan05/zhuyixuan05/spatialvid"
7
+ dst_root = "/share_zhuyixuan05/zhuyixuan05/New_spatialvid_drone_first"
8
+ infer_script = "/home/zhuyixuan05/ReCamMaster/infer_origin.py" # 修改为你的实际路径
9
+
10
+ while True:
11
+ # 随机选择一个子文件夹
12
+ subdirs = [d for d in os.listdir(src_root) if os.path.isdir(os.path.join(src_root, d))]
13
+ if not subdirs:
14
+ print("没有可用的子文件夹")
15
+ break
16
+ chosen = random.choice(subdirs)
17
+ chosen_dir = os.path.join(src_root, chosen)
18
+ pth_file = os.path.join(chosen_dir, "encoded_video.pth")
19
+ if not os.path.exists(pth_file):
20
+ print(f"{pth_file} 不存在,跳过")
21
+ continue
22
+
23
+ # 生成输出文件名
24
+ out_file = os.path.join(dst_root, f"{chosen}.mp4")
25
+ print(f"开始生成: {pth_file} -> {out_file}")
26
+
27
+ # 构造命令
28
+ cmd = [
29
+ "python", infer_script,
30
+ "--condition_pth", pth_file,
31
+ "--output_path", out_file,
32
+ "--prompt", "exploring the world",
33
+ "--modality_type", "sekai",
34
+ "--direction", "right",
35
+ "--dit_path", "/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step25000_first.ckpt",
36
+ "--use_gt_prompt"
37
+ ]
38
+
39
+ # 仅使用第二张 GPU
40
+ env = os.environ.copy()
41
+ env["CUDA_VISIBLE_DEVICES"] = "0"
42
+
43
+ # 执行推理
44
+ subprocess.run(cmd, env=env)
scripts/batch_infer.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ import argparse
4
+ from pathlib import Path
5
+ import glob
6
+
7
+ def find_video_files(videos_dir):
8
+ """查找视频目录下的所有视频文件"""
9
+ video_extensions = ['.mp4']
10
+ video_files = []
11
+
12
+ for ext in video_extensions:
13
+ pattern = os.path.join(videos_dir, f"*{ext}")
14
+ video_files.extend(glob.glob(pattern))
15
+
16
+ return sorted(video_files)
17
+
18
+ def run_inference(condition_video, direction, dit_path, output_dir):
19
+ """运行单个推理任务"""
20
+ # 构建输出文件名
21
+ input_filename = os.path.basename(condition_video)
22
+ name_parts = os.path.splitext(input_filename)
23
+ output_filename = f"{name_parts[0]}_{direction}{name_parts[1]}"
24
+ output_path = os.path.join(output_dir, output_filename)
25
+
26
+ # 构建推理命令
27
+ cmd = [
28
+ "python", "infer_nus.py",
29
+ "--condition_video", condition_video,
30
+ "--direction", direction,
31
+ "--dit_path", dit_path,
32
+ "--output_path", output_path,
33
+ ]
34
+
35
+ print(f"🎬 生成 {direction} 方向视频: {input_filename} -> {output_filename}")
36
+ print(f" 命令: {' '.join(cmd)}")
37
+
38
+ try:
39
+ # 运行推理
40
+ result = subprocess.run(cmd, capture_output=True, text=True, check=True)
41
+ print(f"✅ 成功生成: {output_path}")
42
+ return True
43
+ except subprocess.CalledProcessError as e:
44
+ print(f"❌ 生成失败: {e}")
45
+ print(f" 错误输出: {e.stderr}")
46
+ return False
47
+
48
+ def batch_inference(args):
49
+ """批量推理主函数"""
50
+ videos_dir = args.videos_dir
51
+ output_dir = args.output_dir
52
+ directions = args.directions
53
+ dit_path = args.dit_path
54
+
55
+ # 检查输入目录
56
+ if not os.path.exists(videos_dir):
57
+ print(f"❌ 视频目录不存在: {videos_dir}")
58
+ return
59
+
60
+ # 创建输出目录
61
+ os.makedirs(output_dir, exist_ok=True)
62
+ print(f"📁 输出目录: {output_dir}")
63
+
64
+ # 查找所有视频文件
65
+ video_files = find_video_files(videos_dir)
66
+
67
+ if not video_files:
68
+ print(f"❌ 在 {videos_dir} 中没有找到视频文件")
69
+ return
70
+
71
+ print(f"🎥 找到 {len(video_files)} 个视频文件:")
72
+ for video in video_files:
73
+ print(f" - {os.path.basename(video)}")
74
+
75
+ print(f"🎯 将为每个视频生成以下方向: {', '.join(directions)}")
76
+ print(f"📊 总共将生成 {len(video_files) * len(directions)} 个视频")
77
+
78
+ # 统计信息
79
+ total_tasks = len(video_files) * len(directions)
80
+ completed_tasks = 0
81
+ failed_tasks = 0
82
+
83
+ # 批量处理
84
+ for i, video_file in enumerate(video_files, 1):
85
+ print(f"\n{'='*60}")
86
+ print(f"处理视频 {i}/{len(video_files)}: {os.path.basename(video_file)}")
87
+ print(f"{'='*60}")
88
+
89
+ for j, direction in enumerate(directions, 1):
90
+ print(f"\n--- 方向 {j}/{len(directions)}: {direction} ---")
91
+
92
+ # 检查输出文件是否已存在
93
+ input_filename = os.path.basename(video_file)
94
+ name_parts = os.path.splitext(input_filename)
95
+ output_filename = f"{name_parts[0]}_{direction}{name_parts[1]}"
96
+ output_path = os.path.join(output_dir, output_filename)
97
+
98
+ if os.path.exists(output_path) and not args.overwrite:
99
+ print(f"⏭️ 文件已存在,跳过: {output_filename}")
100
+ completed_tasks += 1
101
+ continue
102
+
103
+ # 运行推理
104
+ success = run_inference(
105
+ condition_video=video_file,
106
+ direction=direction,
107
+ dit_path=dit_path,
108
+ output_dir=output_dir,
109
+ )
110
+
111
+ if success:
112
+ completed_tasks += 1
113
+ else:
114
+ failed_tasks += 1
115
+
116
+ # 显示进度
117
+ current_progress = completed_tasks + failed_tasks
118
+ print(f"📈 进度: {current_progress}/{total_tasks} "
119
+ f"(成功: {completed_tasks}, 失败: {failed_tasks})")
120
+
121
+ # 最终统计
122
+ print(f"\n{'='*60}")
123
+ print(f"🎉 批量推理完成!")
124
+ print(f"📊 总任务数: {total_tasks}")
125
+ print(f"✅ 成功: {completed_tasks}")
126
+ print(f"❌ 失败: {failed_tasks}")
127
+ print(f"📁 输出目录: {output_dir}")
128
+
129
+ if failed_tasks > 0:
130
+ print(f"⚠️ 有 {failed_tasks} 个任务失败,请检查日志")
131
+
132
+ # 列出生成的文件
133
+ if completed_tasks > 0:
134
+ print(f"\n📋 生成的文件:")
135
+ generated_files = glob.glob(os.path.join(output_dir, "*.mp4"))
136
+ for file_path in sorted(generated_files):
137
+ print(f" - {os.path.basename(file_path)}")
138
+
139
+ def main():
140
+ parser = argparse.ArgumentParser(description="批量对nus/videos目录下的所有视频生成不同方向的输出")
141
+
142
+ parser.add_argument("--videos_dir", type=str, default="/home/zhuyixuan05/ReCamMaster/nus/videos/4032",
143
+ help="输入视频目录路径")
144
+
145
+ parser.add_argument("--output_dir", type=str, default="nus/infer_results/batch_dynamic_4032_noise",
146
+ help="输出视频目录路径")
147
+
148
+ parser.add_argument("--directions", nargs="+",
149
+ default=["left_turn", "right_turn"],
150
+ choices=["forward", "backward", "left_turn", "right_turn"],
151
+ help="要生成的方向列表")
152
+
153
+ parser.add_argument("--dit_path", type=str, default="/home/zhuyixuan05/ReCamMaster/nus_dynamic/step15000_dynamic.ckpt",
154
+ help="训练好的DiT模型路径")
155
+
156
+ parser.add_argument("--overwrite", action="store_true",
157
+ help="是否覆盖已存在的输出文件")
158
+
159
+ parser.add_argument("--dry_run", action="store_true",
160
+ help="只显示将要执行的任务,不实际运行")
161
+
162
+ args = parser.parse_args()
163
+
164
+ if args.dry_run:
165
+ print("🔍 预览模式 - 只显示任务,不执行")
166
+ videos_dir = args.videos_dir
167
+ video_files = find_video_files(videos_dir)
168
+
169
+ print(f"📁 输入目录: {videos_dir}")
170
+ print(f"📁 输出目录: {args.output_dir}")
171
+ print(f"🎥 找到视频: {len(video_files)} 个")
172
+ print(f"🎯 生成方向: {', '.join(args.directions)}")
173
+ print(f"📊 总任务数: {len(video_files) * len(args.directions)}")
174
+
175
+ print(f"\n将要执行的任务:")
176
+ for video in video_files:
177
+ for direction in args.directions:
178
+ input_name = os.path.basename(video)
179
+ name_parts = os.path.splitext(input_name)
180
+ output_name = f"{name_parts[0]}_{direction}{name_parts[1]}"
181
+ print(f" {input_name} -> {output_name} ({direction})")
182
+ else:
183
+ batch_inference(args)
184
+
185
+ if __name__ == "__main__":
186
+ main()
scripts/batch_nus.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import subprocess
4
+ import time
5
+
6
+ src_root = "/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic/scenes"
7
+ dst_root = "/share_zhuyixuan05/zhuyixuan05/New_nus_right_2"
8
+ infer_script = "/home/zhuyixuan05/ReCamMaster/infer_moe.py" # 修改为你的实际路径
9
+
10
+ while True:
11
+ # 随机选择一个子文件夹
12
+ subdirs = [d for d in os.listdir(src_root) if os.path.isdir(os.path.join(src_root, d))]
13
+ if not subdirs:
14
+ print("没有可用的子文件夹")
15
+ break
16
+ chosen = random.choice(subdirs)
17
+ chosen_dir = os.path.join(src_root, chosen)
18
+ pth_file = os.path.join(chosen_dir, "encoded_video-480p.pth")
19
+ if not os.path.exists(pth_file):
20
+ print(f"{pth_file} 不存在,跳过")
21
+ continue
22
+
23
+ # 生成输出文件名
24
+ out_file = os.path.join(dst_root, f"{chosen}.mp4")
25
+ print(f"开始生成: {pth_file} -> {out_file}")
26
+
27
+ # 构造命令
28
+ cmd = [
29
+ "python", infer_script,
30
+ "--condition_pth", pth_file,
31
+ "--output_path", out_file,
32
+ "--prompt", "a car is driving",
33
+ "--modality_type", "nuscenes",
34
+ "--dit_path", "/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step175000_origin_other_continue3.ckpt"
35
+ ]
36
+
37
+ # 仅使用第二张 GPU
38
+ env = os.environ.copy()
39
+ env["CUDA_VISIBLE_DEVICES"] = "1"
40
+
41
+ # 执行推理
42
+ subprocess.run(cmd, env=env)
scripts/batch_rt.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import random
import subprocess
import time

# Directory of per-episode subfolders containing pre-encoded Open-X latents.
src_root = "/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded"
# Where generated mp4 files are written.
dst_root = "/share_zhuyixuan05/zhuyixuan05/New_RT"
infer_script = "/home/zhuyixuan05/ReCamMaster/infer_moe.py"  # 修改为你的实际路径


def main():
    """Endlessly pick random encoded Open-X episodes and render them.

    Fixes over the original flat script: the loop no longer runs at import
    time, already-rendered episodes are skipped, the destination directory is
    created up front, and failed child processes are reported.
    """
    os.makedirs(dst_root, exist_ok=True)
    while True:
        # Pick a random episode subfolder each iteration.
        subdirs = [d for d in os.listdir(src_root)
                   if os.path.isdir(os.path.join(src_root, d))]
        if not subdirs:
            print("没有可用的子文件夹")
            break
        chosen = random.choice(subdirs)
        pth_file = os.path.join(src_root, chosen, "encoded_video.pth")
        if not os.path.exists(pth_file):
            print(f"{pth_file} 不存在,跳过")
            continue

        out_file = os.path.join(dst_root, f"{chosen}.mp4")
        # Skip episodes that were already rendered by this or another worker.
        if os.path.exists(out_file):
            continue
        print(f"开始生成: {pth_file} -> {out_file}")

        # Build the inference command (uses the inference script's default DiT checkpoint).
        cmd = [
            "python", infer_script,
            "--condition_pth", pth_file,
            "--output_path", out_file,
            "--prompt", "A robotic arm is moving the object",
            "--modality_type", "openx",
        ]

        # Restrict the child process to the second GPU only.
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = "1"

        # Run inference; surface failures instead of silently looping on.
        result = subprocess.run(cmd, env=env)
        if result.returncode != 0:
            print(f"推理失败 (returncode={result.returncode}): {pth_file}")


if __name__ == "__main__":
    main()
scripts/batch_spa.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import random
import subprocess
import time

# Directory of per-clip subfolders containing pre-encoded SpatialVid latents.
src_root = "/share_zhuyixuan05/zhuyixuan05/spatialvid"
# Where generated mp4 files are written.
dst_root = "/share_zhuyixuan05/zhuyixuan05/New_spatialvid_right"
infer_script = "/home/zhuyixuan05/ReCamMaster/infer_moe.py"  # 修改为你的实际路径


def main():
    """Endlessly pick random encoded SpatialVid clips and render them.

    Fixes over the original flat script: the loop no longer runs at import
    time, already-rendered clips are skipped, the destination directory is
    created up front, and failed child processes are reported.
    """
    os.makedirs(dst_root, exist_ok=True)
    while True:
        # Pick a random clip subfolder each iteration.
        subdirs = [d for d in os.listdir(src_root)
                   if os.path.isdir(os.path.join(src_root, d))]
        if not subdirs:
            print("没有可用的子文件夹")
            break
        chosen = random.choice(subdirs)
        pth_file = os.path.join(src_root, chosen, "encoded_video.pth")
        if not os.path.exists(pth_file):
            print(f"{pth_file} 不存在,跳过")
            continue

        out_file = os.path.join(dst_root, f"{chosen}.mp4")
        # Skip clips that were already rendered by this or another worker.
        if os.path.exists(out_file):
            continue
        print(f"开始生成: {pth_file} -> {out_file}")

        # Build the inference command.
        cmd = [
            "python", infer_script,
            "--condition_pth", pth_file,
            "--output_path", out_file,
            "--prompt", "exploring the world",
            "--modality_type", "sekai",
            # "--direction", "left",
            "--dit_path", "/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step175000_origin_other_continue3.ckpt",
        ]

        # Restrict the child process to the first GPU only.
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = "0"

        # Run inference; surface failures instead of silently looping on.
        result = subprocess.run(cmd, env=env)
        if result.returncode != 0:
            print(f"推理失败 (returncode={result.returncode}): {pth_file}")


if __name__ == "__main__":
    main()
scripts/batch_walk.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import random
import subprocess
import time

# NOTE(review): src_root is identical to batch_nus.py (nuScenes scenes) even
# though outputs go to "New_walk" — looks like a copy-paste; confirm the
# intended source dataset before relying on this script.
src_root = "/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic/scenes"
# Where generated mp4 files are written.
dst_root = "/share_zhuyixuan05/zhuyixuan05/New_walk"
infer_script = "/home/zhuyixuan05/ReCamMaster/infer_moe.py"  # 修改为你的实际路径


def main():
    """Endlessly pick random encoded scenes and render them via the MoE inference script.

    Fixes over the original flat script: the loop no longer runs at import
    time, already-rendered scenes are skipped, the destination directory is
    created up front, and failed child processes are reported.
    """
    os.makedirs(dst_root, exist_ok=True)
    while True:
        # Pick a random scene subfolder each iteration.
        subdirs = [d for d in os.listdir(src_root)
                   if os.path.isdir(os.path.join(src_root, d))]
        if not subdirs:
            print("没有可用的子文件夹")
            break
        chosen = random.choice(subdirs)
        pth_file = os.path.join(src_root, chosen, "encoded_video-480p.pth")
        if not os.path.exists(pth_file):
            print(f"{pth_file} 不存在,跳过")
            continue

        out_file = os.path.join(dst_root, f"{chosen}.mp4")
        # Skip scenes that were already rendered by this or another worker.
        if os.path.exists(out_file):
            continue
        print(f"开始生成: {pth_file} -> {out_file}")

        # Build the inference command.
        cmd = [
            "python", infer_script,
            "--condition_pth", pth_file,
            "--output_path", out_file,
            "--prompt", "a car is driving",
            "--modality_type", "nuscenes",
            "--dit_path", "/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step175000_origin_other_continue3.ckpt",
        ]

        # Restrict the child process to the second GPU only.
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = "1"

        # Run inference; surface failures instead of silently looping on.
        result = subprocess.run(cmd, env=env)
        if result.returncode != 0:
            print(f"推理失败 (returncode={result.returncode}): {pth_file}")


if __name__ == "__main__":
    main()
scripts/check.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import argparse
4
+ from collections import defaultdict
5
+ import time
6
+
7
+ def load_checkpoint(ckpt_path):
8
+ """加载检查点文件"""
9
+ if not os.path.exists(ckpt_path):
10
+ return None
11
+
12
+ try:
13
+ state_dict = torch.load(ckpt_path, map_location='cpu')
14
+ return state_dict
15
+ except Exception as e:
16
+ print(f"❌ 加载检查点失败: {e}")
17
+ return None
18
+
19
+ def compare_parameters(state_dict1, state_dict2, threshold=1e-8):
20
+ """比较两个状态字典的参数差异"""
21
+ if state_dict1 is None or state_dict2 is None:
22
+ return None
23
+
24
+ updated_params = {}
25
+ unchanged_params = {}
26
+
27
+ for name, param1 in state_dict1.items():
28
+ if name in state_dict2:
29
+ param2 = state_dict2[name]
30
+
31
+ # 计算参数差异
32
+ diff = torch.abs(param1 - param2)
33
+ max_diff = torch.max(diff).item()
34
+ mean_diff = torch.mean(diff).item()
35
+
36
+ if max_diff > threshold:
37
+ updated_params[name] = {
38
+ 'max_diff': max_diff,
39
+ 'mean_diff': mean_diff,
40
+ 'shape': param1.shape
41
+ }
42
+ else:
43
+ unchanged_params[name] = {
44
+ 'max_diff': max_diff,
45
+ 'mean_diff': mean_diff,
46
+ 'shape': param1.shape
47
+ }
48
+
49
+ return updated_params, unchanged_params
50
+
51
+ def categorize_parameters(param_dict):
52
+ """将参数按类型分类"""
53
+ categories = {
54
+ 'moe_related': {},
55
+ 'camera_related': {},
56
+ 'framepack_related': {},
57
+ 'attention': {},
58
+ 'other': {}
59
+ }
60
+
61
+ for name, info in param_dict.items():
62
+ if any(keyword in name.lower() for keyword in ['moe', 'gate', 'expert', 'processor']):
63
+ categories['moe_related'][name] = info
64
+ elif any(keyword in name.lower() for keyword in ['cam_encoder', 'projector', 'camera']):
65
+ categories['camera_related'][name] = info
66
+ elif any(keyword in name.lower() for keyword in ['clean_x_embedder', 'framepack']):
67
+ categories['framepack_related'][name] = info
68
+ elif any(keyword in name.lower() for keyword in ['attn', 'attention']):
69
+ categories['attention'][name] = info
70
+ else:
71
+ categories['other'][name] = info
72
+
73
+ return categories
74
+
75
+ def print_category_summary(category_name, params, color_code=''):
76
+ """打印某类参数的摘要"""
77
+ if not params:
78
+ print(f"{color_code} {category_name}: 无参数")
79
+ return
80
+
81
+ total_params = len(params)
82
+ max_diffs = [info['max_diff'] for info in params.values()]
83
+ mean_diffs = [info['mean_diff'] for info in params.values()]
84
+
85
+ print(f"{color_code} {category_name} ({total_params} 个参数):")
86
+ print(f" 最大差异范围: {min(max_diffs):.2e} ~ {max(max_diffs):.2e}")
87
+ print(f" 平均差异范围: {min(mean_diffs):.2e} ~ {max(mean_diffs):.2e}")
88
+
89
+ # 显示前5个最大变化的参数
90
+ sorted_params = sorted(params.items(), key=lambda x: x[1]['max_diff'], reverse=True)
91
+ print(f" 变化最大的参数:")
92
+ for i, (name, info) in enumerate(sorted_params[:100]):
93
+ shape_str = 'x'.join(map(str, info['shape']))
94
+ print(f" {i+1}. {name} [{shape_str}]: max_diff={info['max_diff']:.2e}")
95
+
96
+ def monitor_training(checkpoint_dir, check_interval=60):
97
+ """监控训练过程中的参数更新"""
98
+ print(f"🔍 开始监控训练进度...")
99
+ print(f"📁 检查点目录: {checkpoint_dir}")
100
+ print(f"⏰ 检查间隔: {check_interval}秒")
101
+ print("=" * 80)
102
+
103
+ previous_ckpt = None
104
+ previous_step = -1
105
+
106
+ while True:
107
+ try:
108
+ # 查找最新的检查点
109
+ if not os.path.exists(checkpoint_dir):
110
+ print(f"❌ 检查点目录不存在: {checkpoint_dir}")
111
+ time.sleep(check_interval)
112
+ continue
113
+
114
+ ckpt_files = [f for f in os.listdir(checkpoint_dir) if f.startswith('step') and f.endswith('.ckpt')]
115
+ if not ckpt_files:
116
+ print("⏳ 未找到检查点文件,等待中...")
117
+ time.sleep(check_interval)
118
+ continue
119
+
120
+ # 按步数排序,获取最新的
121
+ ckpt_files.sort(key=lambda x: int(x.replace('step', '').replace('.ckpt', '')))
122
+ latest_ckpt_file = ckpt_files[-1]
123
+ latest_ckpt_path = os.path.join(checkpoint_dir, latest_ckpt_file)
124
+
125
+ # 提取步数
126
+ current_step = int(latest_ckpt_file.replace('step', '').replace('.ckpt', ''))
127
+
128
+ if current_step <= previous_step:
129
+ print(f"⏳ 等待新的检查点... (当前: step{current_step})")
130
+ time.sleep(check_interval)
131
+ continue
132
+
133
+ print(f"\n🔍 发现新检查点: {latest_ckpt_file}")
134
+
135
+ # 加载当前检查点
136
+ current_state_dict = load_checkpoint(latest_ckpt_path)
137
+ if current_state_dict is None:
138
+ print("❌ 无法加载当前检查点")
139
+ time.sleep(check_interval)
140
+ continue
141
+
142
+ if previous_ckpt is not None:
143
+ print(f"📊 比较 step{previous_step} -> step{current_step}")
144
+
145
+ # 比较参数
146
+ updated_params, unchanged_params = compare_parameters(
147
+ previous_ckpt, current_state_dict, threshold=1e-8
148
+ )
149
+
150
+ if updated_params is None:
151
+ print("❌ 参数比较失败")
152
+ else:
153
+ # 分类显示结果
154
+ updated_categories = categorize_parameters(updated_params)
155
+ unchanged_categories = categorize_parameters(unchanged_params)
156
+
157
+ print(f"\n✅ 已更新的参数 (总共 {len(updated_params)} 个):")
158
+ print_category_summary("MoE相关", updated_categories['moe_related'], '🔥')
159
+ print_category_summary("Camera相关", updated_categories['camera_related'], '📷')
160
+ print_category_summary("FramePack相关", updated_categories['framepack_related'], '🎞️')
161
+ print_category_summary("注意力相关", updated_categories['attention'], '👁️')
162
+ print_category_summary("其他", updated_categories['other'], '📦')
163
+
164
+ print(f"\n⚠️ 未更新的参数 (总共 {len(unchanged_params)} 个):")
165
+ print_category_summary("MoE相关", unchanged_categories['moe_related'], '❄️')
166
+ print_category_summary("Camera相关", unchanged_categories['camera_related'], '❄️')
167
+ print_category_summary("FramePack相关", unchanged_categories['framepack_related'], '❄️')
168
+ print_category_summary("注意力相关", unchanged_categories['attention'], '❄️')
169
+ print_category_summary("其他", unchanged_categories['other'], '❄️')
170
+
171
+ # 检查关键组件是否在更新
172
+ critical_keywords = ['moe', 'cam_encoder', 'projector', 'clean_x_embedder']
173
+ critical_updated = any(
174
+ any(keyword in name.lower() for keyword in critical_keywords)
175
+ for name in updated_params.keys()
176
+ )
177
+
178
+ if critical_updated:
179
+ print("\n✅ 关键组件正在更新!")
180
+ else:
181
+ print("\n❌ 警告:关键组件可能未在更新!")
182
+
183
+ # 计算更新率
184
+ total_params = len(updated_params) + len(unchanged_params)
185
+ update_rate = len(updated_params) / total_params * 100
186
+ print(f"\n📈 参数更新率: {update_rate:.1f}% ({len(updated_params)}/{total_params})")
187
+
188
+ # 保存当前状态用于下次比较
189
+ previous_ckpt = current_state_dict
190
+ previous_step = current_step
191
+
192
+ print("=" * 80)
193
+ time.sleep(check_interval)
194
+
195
+ except KeyboardInterrupt:
196
+ print("\n👋 监控已停止")
197
+ break
198
+ except Exception as e:
199
+ print(f"❌ 监控过程中出错: {e}")
200
+ time.sleep(check_interval)
201
+
202
+ def compare_two_checkpoints(ckpt1_path, ckpt2_path):
203
+ """比较两个特定的检查点"""
204
+ print(f"🔍 比较两个检查点:")
205
+ print(f" 检查点1: {ckpt1_path}")
206
+ print(f" 检查点2: {ckpt2_path}")
207
+ print("=" * 80)
208
+
209
+ # 加载检查点
210
+ state_dict1 = load_checkpoint(ckpt1_path)
211
+ state_dict2 = load_checkpoint(ckpt2_path)
212
+
213
+ if state_dict1 is None or state_dict2 is None:
214
+ print("❌ 无法加载检查点文件")
215
+ return
216
+
217
+ # 比较参数
218
+ updated_params, unchanged_params = compare_parameters(state_dict1, state_dict2)
219
+
220
+ if updated_params is None:
221
+ print("❌ 参数比较失败")
222
+ return
223
+
224
+ # 分类显示结果
225
+ updated_categories = categorize_parameters(updated_params)
226
+ unchanged_categories = categorize_parameters(unchanged_params)
227
+
228
+ print(f"\n✅ 已更新的参数 (总共 {len(updated_params)} 个):")
229
+ for category_name, params in updated_categories.items():
230
+ print_category_summary(category_name.replace('_', ' ').title(), params, '🔥')
231
+
232
+ print(f"\n⚠️ 未更新的参数 (总共 {len(unchanged_params)} 个):")
233
+ for category_name, params in unchanged_categories.items():
234
+ print_category_summary(category_name.replace('_', ' ').title(), params, '❄️')
235
+
236
+ # 计算更新率
237
+ total_params = len(updated_params) + len(unchanged_params)
238
+ update_rate = len(updated_params) / total_params * 100
239
+ print(f"\n📈 参数更新率: {update_rate:.1f}% ({len(updated_params)}/{total_params})")
240
+
241
+ if __name__ == '__main__':
242
+ parser = argparse.ArgumentParser(description="检查模型参数更新情况")
243
+ parser.add_argument("--checkpoint_dir", type=str,
244
+ default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe",
245
+ help="检查点目录路径")
246
+ parser.add_argument("--compare", default=True,
247
+ help="比较两个特定检查点,而不是监控")
248
+ parser.add_argument("--ckpt1", type=str, default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step1500_origin_cam_4.ckpt")
249
+ parser.add_argument("--ckpt2", type=str, default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step500_origin_cam_4.ckpt")
250
+ parser.add_argument("--interval", type=int, default=60,
251
+ help="监控检查间隔(秒)")
252
+ parser.add_argument("--threshold", type=float, default=1e-8,
253
+ help="参数变化阈值")
254
+
255
+ args = parser.parse_args()
256
+
257
+ if args.compare:
258
+ if not args.ckpt1 or not args.ckpt2:
259
+ print("❌ 比较模式需要指定 --ckpt1 和 --ckpt2")
260
+ else:
261
+ compare_two_checkpoints(args.ckpt1, args.ckpt2)
262
+ else:
263
+ monitor_training(args.checkpoint_dir, args.interval)
scripts/decode_openx.py ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ import imageio
6
+ import argparse
7
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
8
+ from tqdm import tqdm
9
+ import json
10
+
11
class VideoDecoder:
    """Wraps the ReCamMaster pipeline's VAE to decode video latents back to frames."""

    def __init__(self, vae_path, device="cuda"):
        """Initialize the decoder: load the VAE and move every component to ``device``."""
        self.device = device

        # Models are loaded on CPU first, then moved to the target device below.
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        model_manager.load_models([vae_path])

        # Build the pipeline; only its VAE is used for decoding.
        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
        self.pipe = self.pipe.to(device)

        # Key fix: make sure the VAE and all of its submodules are on the target device.
        self.pipe.vae = self.pipe.vae.to(device)
        if hasattr(self.pipe.vae, 'model'):
            self.pipe.vae.model = self.pipe.vae.model.to(device)

        print(f"✅ VAE解码器初始化完成,设备: {device}")

    def decode_latents_to_video(self, latents, output_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        """Decode latents and write them to an mp4 file.

        Args:
            latents: latent tensor, [C, T, H, W] or [B, C, T, H, W].
            output_path: destination mp4 path (parent dirs are created).
            tiled / tile_size / tile_stride: forwarded to the pipeline's
                tiled decoder when ``tiled`` is True.

        Returns:
            The decoded frames as a uint8 numpy array of shape [T, H, W, C].

        Raises:
            ValueError when the decoded tensor's layout cannot be recognized.
        """
        print(f"🔧 开始解码latents...")
        print(f"输入latents形状: {latents.shape}")
        print(f"输入latents设备: {latents.device}")
        print(f"输入latents数据类型: {latents.dtype}")

        # Make sure latents carry a batch dimension.
        if len(latents.shape) == 4:  # [C, T, H, W]
            latents = latents.unsqueeze(0)  # -> [1, C, T, H, W]

        # Key fix: match the latents' device and dtype to the VAE's parameters.
        model_dtype = next(self.pipe.vae.parameters()).dtype
        model_device = next(self.pipe.vae.parameters()).device

        print(f"模型设备: {model_device}")
        print(f"模型数据类型: {model_dtype}")

        # Move the latents onto the model's device/dtype.
        latents = latents.to(device=model_device, dtype=model_dtype)

        print(f"解码latents形状: {latents.shape}")
        print(f"解码latents设备: {latents.device}")
        print(f"解码latents数据类型: {latents.dtype}")

        # Force the pipeline onto the same device so all ops run in one place.
        self.pipe.device = model_device

        # Decode with the VAE.
        with torch.no_grad():
            try:
                if tiled:
                    print("🔧 尝试tiled解码...")
                    decoded_video = self.pipe.decode_video(
                        latents,
                        tiled=True,
                        tile_size=tile_size,
                        tile_stride=tile_stride
                    )
                else:
                    print("🔧 使用非tiled解码...")
                    decoded_video = self.pipe.decode_video(latents, tiled=False)

            except Exception as e:
                print(f"decode_video失败,错误: {e}")
                import traceback
                traceback.print_exc()

                # Fallback: call the VAE directly without the pipeline wrapper.
                try:
                    print("🔧 尝试直接调用VAE解码...")
                    decoded_video = self.pipe.vae.decode(
                        latents.squeeze(0),  # drop the batch dim -> [C, T, H, W]
                        device=model_device,
                        tiled=False
                    )
                    # Re-add a batch dim: VAE output [T, H, W, C] -> [1, T, H, W, C]
                    if len(decoded_video.shape) == 4:  # [T, H, W, C]
                        decoded_video = decoded_video.unsqueeze(0)  # -> [1, T, H, W, C]
                except Exception as e2:
                    print(f"直接VAE解码也失败: {e2}")
                    raise e2

        print(f"解码后视频形状: {decoded_video.shape}")

        # Normalize whatever layout came back into a [T, H, W, C] float array.
        video_np = None

        if len(decoded_video.shape) == 5:
            # Probe the possible 5D layouts.
            if decoded_video.shape == torch.Size([1, 3, 113, 480, 832]):
                # Layout [B, C, T, H, W] -> convert to [T, H, W, C].
                print("🔧 检测到格式: [B, C, T, H, W]")
                video_np = decoded_video[0].permute(1, 2, 3, 0).to(torch.float32).cpu().numpy()  # [T, H, W, C]
            elif decoded_video.shape[1] == 3:
                # Second dim of 3 suggests [B, C, T, H, W].
                print("🔧 检测到可能的格式: [B, C, T, H, W]")
                video_np = decoded_video[0].permute(1, 2, 3, 0).to(torch.float32).cpu().numpy()  # [T, H, W, C]
            elif decoded_video.shape[-1] == 3:
                # Trailing dim of 3 suggests [B, T, H, W, C].
                print("🔧 检测到格式: [B, T, H, W, C]")
                video_np = decoded_video[0].to(torch.float32).cpu().numpy()  # [T, H, W, C]
            else:
                # Fall back to locating the size-3 (channel) dimension.
                shape = list(decoded_video.shape)
                if 3 in shape:
                    channel_dim = shape.index(3)
                    print(f"🔧 检测到通道维度在位置: {channel_dim}")

                    if channel_dim == 1:  # [B, C, T, H, W]
                        video_np = decoded_video[0].permute(1, 2, 3, 0).to(torch.float32).cpu().numpy()
                    elif channel_dim == 4:  # [B, T, H, W, C]
                        video_np = decoded_video[0].to(torch.float32).cpu().numpy()
                    else:
                        print(f"⚠️ 未知的通道维度位置: {channel_dim}")
                        raise ValueError(f"Cannot handle channel dimension at position {channel_dim}")
                else:
                    print(f"⚠️ 未找到通道维度为3的位置,形状: {decoded_video.shape}")
                    raise ValueError(f"Cannot find channel dimension of size 3 in shape {decoded_video.shape}")

        elif len(decoded_video.shape) == 4:
            # 4D tensor: check which end holds the channels.
            if decoded_video.shape[-1] == 3:  # [T, H, W, C]
                video_np = decoded_video.to(torch.float32).cpu().numpy()
            elif decoded_video.shape[0] == 3:  # [C, T, H, W]
                video_np = decoded_video.permute(1, 2, 3, 0).to(torch.float32).cpu().numpy()
            else:
                print(f"⚠️ 无法处理的4D视频形状: {decoded_video.shape}")
                raise ValueError(f"Cannot handle 4D video tensor shape: {decoded_video.shape}")
        else:
            print(f"⚠️ 意外的视频维度数: {len(decoded_video.shape)}")
            raise ValueError(f"Unexpected video tensor dimensions: {decoded_video.shape}")

        if video_np is None:
            raise ValueError("Failed to convert video tensor to numpy array")

        print(f"转换后视频数组形状: {video_np.shape}")

        # Validate the final layout before writing frames.
        if len(video_np.shape) != 4:
            raise ValueError(f"Expected 4D array [T, H, W, C], got {video_np.shape}")

        if video_np.shape[-1] != 3:
            print(f"⚠️ 通道数异常: 期望3,实际{video_np.shape[-1]}")
            print(f"完整形状: {video_np.shape}")
            # Last resort: try the remaining channel placements.
            if video_np.shape[0] == 3:  # [C, T, H, W]
                print("🔧 尝试重新排列: [C, T, H, W] -> [T, H, W, C]")
                video_np = np.transpose(video_np, (1, 2, 3, 0))
            elif video_np.shape[1] == 3:  # [T, C, H, W]
                print("🔧 尝试重新排列: [T, C, H, W] -> [T, H, W, C]")
                video_np = np.transpose(video_np, (0, 2, 3, 1))
            else:
                raise ValueError(f"Expected 3 channels (RGB), got {video_np.shape[-1]} channels")

        # Undo the [-1, 1] normalization and convert to uint8 pixels.
        # NOTE(review): assumes the VAE emits values in [-1, 1] — confirm.
        video_np = (video_np * 0.5 + 0.5).clip(0, 1)
        video_np = (video_np * 255).astype(np.uint8)

        print(f"最终视频数组形状: {video_np.shape}")
        print(f"视频数组值范围: {video_np.min()} - {video_np.max()}")

        # Write the mp4 (create parent directories first).
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

        try:
            with imageio.get_writer(output_path, fps=10, quality=8) as writer:
                for frame_idx, frame in enumerate(video_np):
                    # Skip frames whose shape is not [H, W, 3].
                    if len(frame.shape) != 3 or frame.shape[-1] != 3:
                        print(f"⚠️ 帧 {frame_idx} 形状异常: {frame.shape}")
                        continue

                    writer.append_data(frame)
                    if frame_idx % 10 == 0:
                        print(f" 写入帧 {frame_idx}/{len(video_np)}")
        except Exception as e:
            print(f"保存视频失败: {e}")
            # Dump the first few frames as PNGs for debugging, then re-raise.
            debug_dir = os.path.join(os.path.dirname(output_path), "debug_frames")
            os.makedirs(debug_dir, exist_ok=True)

            for i in range(min(5, len(video_np))):
                frame = video_np[i]
                debug_path = os.path.join(debug_dir, f"debug_frame_{i}.png")
                try:
                    if len(frame.shape) == 3 and frame.shape[-1] == 3:
                        Image.fromarray(frame).save(debug_path)
                        print(f"调试: 保存帧 {i} 到 {debug_path}")
                    else:
                        print(f"调试: 帧 {i} 形状异常: {frame.shape}")
                except Exception as e2:
                    print(f"调试: 保存帧 {i} 失败: {e2}")
            raise e

        print(f"✅ 视频保存到: {output_path}")
        return video_np

    def save_frames_as_images(self, video_np, output_dir, prefix="frame"):
        """Save each frame of a [T, H, W, 3] uint8 array as an individual PNG."""
        os.makedirs(output_dir, exist_ok=True)

        for i, frame in enumerate(video_np):
            frame_path = os.path.join(output_dir, f"{prefix}_{i:04d}.png")
            # Only write frames that are valid [H, W, 3] images.
            if len(frame.shape) == 3 and frame.shape[-1] == 3:
                Image.fromarray(frame).save(frame_path)
            else:
                print(f"⚠️ 跳过形状异常的帧 {i}: {frame.shape}")

        print(f"✅ 保存了 {len(video_np)} 帧到: {output_dir}")
224
+
225
def decode_single_episode(encoded_pth_path, vae_path, output_base_dir, device="cuda"):
    """Decode one episode's encoded latents back to video for verification.

    Loads the .pth produced by the encoder, decodes its 'latents' entry with
    the VAE, writes an mp4 plus a few sample frames, and records a
    decode_info.json summary next to them.

    Args:
        encoded_pth_path: path to the episode's encoded_video.pth.
        vae_path: VAE checkpoint path for the decoder.
        output_base_dir: results are written under output_base_dir/episode_*.
        device: compute device for the decoder.

    Returns:
        True on success, False on any load/decode failure.
    """
    print(f"\n🔧 解码episode: {encoded_pth_path}")

    # Load the encoded data onto the CPU.
    try:
        encoded_data = torch.load(encoded_pth_path, weights_only=False, map_location="cpu")
        print(f"✅ 成功加载编码数据")
    except Exception as e:
        print(f"❌ 加载编码数据失败: {e}")
        return False

    # Print the structure of the loaded dict for inspection.
    print("🔍 编码数据结构:")
    for key, value in encoded_data.items():
        if isinstance(value, torch.Tensor):
            print(f" - {key}: {value.shape}, dtype: {value.dtype}, device: {value.device}")
        elif isinstance(value, dict):
            print(f" - {key}: dict with keys {list(value.keys())}")
        else:
            print(f" - {key}: {type(value)}")

    # Extract the latents tensor.
    latents = encoded_data.get('latents')
    if latents is None:
        print("❌ 未找到latents数据")
        return False

    # Keep latents on the CPU here; the decoder moves them where needed.
    if latents.device != torch.device('cpu'):
        latents = latents.cpu()
        print(f"🔧 将latents移动到CPU: {latents.device}")

    episode_info = encoded_data.get('episode_info', {})
    episode_idx = episode_info.get('episode_idx', 'unknown')
    # Estimate the original frame count when it isn't recorded.
    # NOTE(review): assumes a 4x temporal VAE compression along dim 1 — confirm.
    total_frames = episode_info.get('total_frames', latents.shape[1] * 4)

    print(f"Episode信息:")
    print(f" - Episode索引: {episode_idx}")
    print(f" - Latents形状: {latents.shape}")
    print(f" - Latents设备: {latents.device}")
    print(f" - Latents数据类型: {latents.dtype}")
    print(f" - 原始总帧数: {total_frames}")
    print(f" - 压缩后帧数: {latents.shape[1]}")

    # Create the per-episode output directory.
    episode_name = f"episode_{episode_idx:06d}" if isinstance(episode_idx, int) else f"episode_{episode_idx}"
    output_dir = os.path.join(output_base_dir, episode_name)
    os.makedirs(output_dir, exist_ok=True)

    # Build the VAE decoder.
    try:
        decoder = VideoDecoder(vae_path, device)
    except Exception as e:
        print(f"❌ 初始化解码器失败: {e}")
        return False

    # Decode to an mp4 file.
    video_output_path = os.path.join(output_dir, "decoded_video.mp4")
    try:
        video_np = decoder.decode_latents_to_video(
            latents,
            video_output_path,
            tiled=False,  # non-tiled first: simpler path, avoids tiling edge cases
            tile_size=(34, 34),
            tile_stride=(18, 16)
        )

        # Save the first few frames as images for a quick visual check.
        frames_dir = os.path.join(output_dir, "frames")
        sample_frames = video_np[:min(10, len(video_np))]  # at most 10 frames
        decoder.save_frames_as_images(sample_frames, frames_dir, f"frame_{episode_idx}")

        # Record a machine-readable summary of the decode.
        decode_info = {
            "source_pth": encoded_pth_path,
            "decoded_video_path": video_output_path,
            "latents_shape": list(latents.shape),
            "decoded_video_shape": list(video_np.shape),
            "original_total_frames": total_frames,
            "decoded_frames": len(video_np),
            "compression_ratio": total_frames / len(video_np) if len(video_np) > 0 else 0,
            "latents_dtype": str(latents.dtype),
            "latents_device": str(latents.device),
            "vae_compression_ratio": total_frames / latents.shape[1] if latents.shape[1] > 0 else 0
        }

        info_path = os.path.join(output_dir, "decode_info.json")
        with open(info_path, 'w') as f:
            json.dump(decode_info, f, indent=2)

        print(f"✅ Episode {episode_idx} 解码完成")
        print(f" - 原始帧数: {total_frames}")
        print(f" - 解码帧数: {len(video_np)}")
        print(f" - 压缩比: {decode_info['compression_ratio']:.2f}")
        print(f" - VAE时间压缩比: {decode_info['vae_compression_ratio']:.2f}")
        return True

    except Exception as e:
        print(f"❌ 解码失败: {e}")
        import traceback
        traceback.print_exc()
        return False
328
+
329
def batch_decode_episodes(encoded_base_dir, vae_path, output_base_dir, max_episodes=None, device="cuda"):
    """Decode every encoded episode found under ``encoded_base_dir``.

    Scans for ``<episode_dir>/encoded_video.pth`` files (sorted for
    deterministic order), decodes up to ``max_episodes`` of them via
    ``decode_single_episode``, and prints a running and final success rate.

    Args:
        encoded_base_dir: directory whose subfolders hold encoded_video.pth files.
        vae_path: VAE checkpoint path for the decoder.
        output_base_dir: root directory for decoded outputs.
        max_episodes: optional cap on how many episodes to process.
        device: compute device for decoding.
    """
    print(f"🔧 批量解码Open-X episodes")
    print(f"源目录: {encoded_base_dir}")
    print(f"输出目录: {output_base_dir}")

    # Find every encoded episode (sorted for consistent ordering).
    episode_dirs = []
    if os.path.exists(encoded_base_dir):
        for item in sorted(os.listdir(encoded_base_dir)):
            episode_dir = os.path.join(encoded_base_dir, item)
            if os.path.isdir(episode_dir):
                encoded_path = os.path.join(episode_dir, "encoded_video.pth")
                if os.path.exists(encoded_path):
                    episode_dirs.append(encoded_path)

    print(f"找到 {len(episode_dirs)} 个编码的episodes")

    # Bug fix: bail out when nothing was found — the final success-rate print
    # divided by len(episode_dirs) and raised ZeroDivisionError on empty input.
    if not episode_dirs:
        print("没有找到可解码的episodes")
        return

    if max_episodes and len(episode_dirs) > max_episodes:
        episode_dirs = episode_dirs[:max_episodes]
        print(f"限制处理前 {max_episodes} 个episodes")

    # Decode each episode, tracking the success count.
    success_count = 0
    for i, encoded_pth_path in enumerate(tqdm(episode_dirs, desc="解码episodes")):
        print(f"\n{'='*60}")
        print(f"处理 {i+1}/{len(episode_dirs)}: {os.path.basename(os.path.dirname(encoded_pth_path))}")

        success = decode_single_episode(encoded_pth_path, vae_path, output_base_dir, device)
        if success:
            success_count += 1

        print(f"当前成功率: {success_count}/{i+1} ({success_count/(i+1)*100:.1f}%)")

    print(f"\n🎉 批量解码完成!")
    print(f"总处理: {len(episode_dirs)} 个episodes")
    print(f"成功解码: {success_count} 个episodes")
    print(f"成功率: {success_count/len(episode_dirs)*100:.1f}%")
367
+
368
def main():
    """CLI entry point: decode encoded Open-X latents in single or batch mode."""
    cli = argparse.ArgumentParser(description="解码Open-X编码的latents以验证正确性 - 修正版本")
    cli.add_argument("--mode", type=str, choices=["single", "batch"], default="batch",
                     help="解码模式:single (单个episode) 或 batch (批量)")
    cli.add_argument("--encoded_pth", type=str,
                     default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded/episode_000000/encoded_video.pth",
                     help="单个编码文件路径(single模式)")
    cli.add_argument("--encoded_base_dir", type=str,
                     default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded",
                     help="编码数据基础目录(batch模式)")
    cli.add_argument("--vae_path", type=str,
                     default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
                     help="VAE模型路径")
    cli.add_argument("--output_dir", type=str,
                     default="./decoded_results_fixed",
                     help="解码输出目录")
    cli.add_argument("--max_episodes", type=int, default=5,
                     help="最大解码episodes数量(batch模式,用于测试)")
    cli.add_argument("--device", type=str, default="cuda",
                     help="计算设备")
    opts = cli.parse_args()

    print("🔧 Open-X Latents 解码验证工具 (修正版本 - Fixed)")
    print(f"模式: {opts.mode}")
    print(f"VAE路径: {opts.vae_path}")
    print(f"输出目录: {opts.output_dir}")
    print(f"设备: {opts.device}")

    # Fall back to CPU when CUDA was requested but is not available.
    if opts.device == "cuda" and not torch.cuda.is_available():
        print("⚠️ CUDA不可用,切换到CPU")
        opts.device = "cpu"

    # Make sure the output directory exists.
    os.makedirs(opts.output_dir, exist_ok=True)

    if opts.mode == "single":
        print(f"输入文件: {opts.encoded_pth}")
        if not os.path.exists(opts.encoded_pth):
            print(f"❌ 输入文件不存在: {opts.encoded_pth}")
            return
        ok = decode_single_episode(opts.encoded_pth, opts.vae_path, opts.output_dir, opts.device)
        print("✅ 单个episode解码成功" if ok else "❌ 单个episode解码失败")
    elif opts.mode == "batch":
        print(f"输入目录: {opts.encoded_base_dir}")
        print(f"最大episodes: {opts.max_episodes}")
        if not os.path.exists(opts.encoded_base_dir):
            print(f"❌ 输入目录不存在: {opts.encoded_base_dir}")
            return
        batch_decode_episodes(opts.encoded_base_dir, opts.vae_path, opts.output_dir,
                              opts.max_episodes, opts.device)


if __name__ == "__main__":
    main()
scripts/download_recam.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
# Download the ReCamMaster-Wan2.1 checkpoint snapshot from the Hugging Face Hub.
from huggingface_hub import snapshot_download

REPO_ID = "KwaiVGI/ReCamMaster-Wan2.1"
LOCAL_DIR = "models/ReCamMaster/checkpoints"

snapshot_download(
    repo_id=REPO_ID,
    local_dir=LOCAL_DIR,
    resume_download=True,  # resume interrupted downloads
)
scripts/encode_dynamic_videos.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import lightning as pl
4
+ from PIL import Image
5
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
6
+ import json
7
+ import imageio
8
+ from torchvision.transforms import v2
9
+ from einops import rearrange
10
+ import argparse
11
+ from tqdm import tqdm
12
class VideoEncoder(pl.LightningModule):
    """Wraps the WanVideo ReCamMaster text encoder + VAE so raw videos can be
    turned into training latents.

    The heavy models are loaded on CPU; the caller is expected to move the
    module to GPU (``encoder.cuda()``) before encoding.
    """

    def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        super().__init__()
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        model_manager.load_models([text_encoder_path, vae_path])
        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
        # Tiled VAE encoding keeps peak GPU memory bounded for long clips.
        self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # PIL frame -> float tensor normalized to [-1, 1].
        self.frame_process = v2.Compose([
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def crop_and_resize(self, image, target_width=832, target_height=480):
        """Resize a PIL image to the model resolution (default 832x480)."""
        return v2.functional.resize(
            image,
            (target_height, target_width),
            interpolation=v2.InterpolationMode.BILINEAR,
        )

    def load_video_frames(self, video_path):
        """Load a whole video as a (C, T, H, W) tensor, or None if it has no frames.

        The reader is closed in a ``finally`` block so a decode error cannot
        leak the underlying file handle.
        """
        reader = imageio.get_reader(video_path)
        frames = []
        try:
            for frame_data in reader:
                frame = Image.fromarray(frame_data)
                frame = self.crop_and_resize(frame)
                frame = self.frame_process(frame)
                frames.append(frame)
        finally:
            reader.close()

        if not frames:
            return None

        frames = torch.stack(frames, dim=0)
        return rearrange(frames, "T C H W -> C T H W")
56
+
57
def encode_scenes(scenes_path, text_encoder_path, vae_path, start_idx=450):
    """Encode every scene video under *scenes_path* into VAE latents.

    For each scene directory this reads ``scene_info.json``, encodes the
    referenced video plus a shared text prompt, and stores the result as
    ``encoded_video-480p-1.pth`` inside the scene directory.

    Args:
        scenes_path: Root directory containing one sub-directory per scene.
        text_encoder_path: Path to the T5 text-encoder weights.
        vae_path: Path to the Wan VAE weights.
        start_idx: Scenes with a listing index below this value are skipped
            (resume offset; the default preserves the original behaviour).
    """
    encoder = VideoEncoder(text_encoder_path, vae_path)
    encoder = encoder.cuda()
    encoder.pipe.device = "cuda"

    processed_count = 0
    # The prompt embedding is identical for every scene, so it is computed
    # lazily once and the text encoder is released afterwards to free memory.
    prompt_emb = None

    for idx, scene_name in enumerate(tqdm(os.listdir(scenes_path))):
        if idx < start_idx:
            continue
        scene_dir = os.path.join(scenes_path, scene_name)
        if not os.path.isdir(scene_dir):
            continue

        # Skip scenes that were already encoded.
        encoded_path = os.path.join(scene_dir, "encoded_video-480p-1.pth")
        if os.path.exists(encoded_path):
            print(f"Scene {scene_name} already encoded, skipping...")
            continue

        # Load per-scene metadata; scenes without it are silently skipped.
        scene_info_path = os.path.join(scene_dir, "scene_info.json")
        if not os.path.exists(scene_info_path):
            continue

        with open(scene_info_path, 'r') as f:
            scene_info = json.load(f)

        # Resolve the video referenced by the scene metadata.
        video_path = os.path.join(scene_dir, scene_info['video_path'])
        if not os.path.exists(video_path):
            print(f"Video not found: {video_path}")
            continue

        try:
            print(f"Encoding scene {scene_name}...")

            video_frames = encoder.load_video_frames(video_path)
            if video_frames is None:
                print(f"Failed to load video: {video_path}")
                continue

            video_frames = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16)

            with torch.no_grad():
                latents = encoder.pipe.encode_video(video_frames, **encoder.tiler_kwargs)[0]

            # Encode the shared prompt exactly once.  Keying on prompt_emb
            # (instead of processed_count) stays correct even if an earlier
            # scene failed after the prompter was already released.
            if prompt_emb is None:
                print('encode prompt!!!')
                prompt_emb = encoder.pipe.encode_prompt("A car driving scene captured by front camera")
                del encoder.pipe.prompter

            encoded_data = {
                "latents": latents.cpu(),
                "prompt_emb": {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in prompt_emb.items()},
                "image_emb": {}
            }

            torch.save(encoded_data, encoded_path)
            print(f"Saved encoded data: {encoded_path}")
            processed_count += 1

        except Exception as e:
            print(f"Error encoding scene {scene_name}: {e}")
            continue

    print(f"Encoding completed! Processed {processed_count} scenes.")
131
+
132
if __name__ == "__main__":
    # CLI entry point: every option is a plain string path.
    cli = argparse.ArgumentParser()
    for flag, default in (
        ("--scenes_path", "/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic/scenes"),
        ("--text_encoder_path", "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth"),
        ("--vae_path", "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"),
    ):
        cli.add_argument(flag, type=str, default=default)
    opts = cli.parse_args()
    encode_scenes(opts.scenes_path, opts.text_encoder_path, opts.vae_path)
scripts/encode_openx.py ADDED
@@ -0,0 +1,466 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import lightning as pl
4
+ from PIL import Image
5
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
6
+ import json
7
+ import imageio
8
+ from torchvision.transforms import v2
9
+ from einops import rearrange
10
+ import argparse
11
+ import numpy as np
12
+ from tqdm import tqdm
13
+
14
+ # 🔧 关键修复:设置环境变量避免GCS连接
15
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
16
+ os.environ["TFDS_DISABLE_GCS"] = "1"
17
+
18
+ import tensorflow_datasets as tfds
19
+ import tensorflow as tf
20
+
21
class VideoEncoder(pl.LightningModule):
    """Encodes Open-X (fractal20220817) episodes.

    Wraps the ReCamMaster text encoder + VAE, extracts RGB frames from the
    episode's observation stream, and integrates per-step action deltas into
    cumulative camera poses / 4x4 extrinsic matrices.
    """

    def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        super().__init__()
        # Models are loaded on CPU first; the caller moves the module to GPU.
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        model_manager.load_models([text_encoder_path, vae_path])
        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
        # Tiled VAE encoding bounds peak GPU memory for long episodes.
        self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # PIL frame -> float tensor normalized to [-1, 1].
        self.frame_process = v2.Compose([
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def crop_and_resize(self, image, target_width=832, target_height=480):
        """Resize a PIL image to the model resolution (default 832x480)."""
        return v2.functional.resize(
            image,
            (target_height, target_width),
            interpolation=v2.InterpolationMode.BILINEAR,
        )

    @staticmethod
    def _euler_to_quaternion(euler):
        """Convert (roll, pitch, yaw) Euler angles (ZYX order) to a [w, x, y, z] quaternion.

        Factored out because the original code duplicated this conversion in
        two places inside extract_camera_poses.
        """
        roll, pitch, yaw = euler[0], euler[1], euler[2]
        cy, sy = np.cos(yaw * 0.5), np.sin(yaw * 0.5)
        cp, sp = np.cos(pitch * 0.5), np.sin(pitch * 0.5)
        cr, sr = np.cos(roll * 0.5), np.sin(roll * 0.5)
        return np.array([
            cr * cp * cy + sr * sp * sy,   # w
            sr * cp * cy - cr * sp * sy,   # x
            cr * sp * cy + sr * cp * sy,   # y
            cr * cp * sy - sr * sp * cy,   # z
        ], dtype=np.float32)

    def load_episode_frames(self, episode_data, max_frames=300):
        """Extract up to *max_frames* RGB frames from an episode as a (C, T, H, W) tensor.

        Returns None when no usable frames were found.  'image' is the field
        actually present in fractal20220817_data; the other keys are fallbacks.
        """
        frames = []
        steps = episode_data['steps']
        frame_count = 0

        print(f"开始提取帧,最多 {max_frames} 帧...")

        # Candidate observation keys, most likely first.
        image_keys_to_try = [
            'image',
            'rgb',
            'camera_image',
            'exterior_image_1_left',
            'wrist_image',
        ]

        for step_idx, step in enumerate(steps):
            if frame_count >= max_frames:
                break

            try:
                obs = step['observation']

                img_data = None
                for img_key in image_keys_to_try:
                    if img_key in obs:
                        try:
                            img_data = obs[img_key].numpy()
                            if step_idx < 3:  # only log the first few steps
                                print(f"✅ 找到图像字段: {img_key}, 形状: {img_data.shape}")
                            break
                        except Exception as e:
                            if step_idx < 3:
                                print(f"尝试字段 {img_key} 失败: {e}")
                            continue

                if img_data is not None:
                    if len(img_data.shape) == 3:  # expected [H, W, C]
                        # Normalize dtype to uint8 before handing to PIL.
                        if img_data.dtype != np.uint8:
                            if img_data.max() <= 1.0:
                                img_data = (img_data * 255).astype(np.uint8)
                            else:
                                img_data = img_data.astype(np.uint8)
                        frame = Image.fromarray(img_data)

                        if frame.mode != 'RGB':
                            frame = frame.convert('RGB')

                        frame = self.crop_and_resize(frame)
                        frame = self.frame_process(frame)
                        frames.append(frame)
                        frame_count += 1

                        if frame_count % 50 == 0:
                            print(f"已处理 {frame_count} 帧")
                    else:
                        if step_idx < 5:
                            print(f"步骤 {step_idx}: 图像形状不正确 {img_data.shape}")
                else:
                    # No image found: log the available observation keys.
                    if step_idx < 5:
                        available_keys = list(obs.keys())
                        print(f"步骤 {step_idx}: 未找到图像,可用键: {available_keys}")

            except Exception as e:
                print(f"处理步骤 {step_idx} 时出错: {e}")
                continue

        print(f"成功提取 {len(frames)} 帧")

        if not frames:
            return None

        frames = torch.stack(frames, dim=0)
        return rearrange(frames, "T C H W -> C T H W")

    def extract_camera_poses(self, episode_data, num_frames):
        """Integrate per-step action deltas into cumulative camera poses.

        Returns a list of dicts, one per frame (up to *num_frames*), each with
        'translation' (3,) and 'rotation' ([w, x, y, z] quaternion).  Both keys
        are always present so downstream matrix construction cannot KeyError.
        """
        camera_poses = []
        steps = episode_data['steps']
        frame_count = 0

        print("提取相机位姿信息...")

        # Poses are accumulated from per-step deltas.
        cumulative_translation = np.array([0.0, 0.0, 0.0], dtype=np.float32)
        cumulative_rotation = np.array([0.0, 0.0, 0.0], dtype=np.float32)  # Euler angles

        for step_idx, step in enumerate(steps):
            if frame_count >= num_frames:
                break

            try:
                action = step.get('action', {})
                pose_data = {}

                # Translation: accumulate the action's world-frame displacement.
                if 'world_vector' in action:
                    try:
                        world_vector = action['world_vector'].numpy()
                        if len(world_vector) == 3:
                            cumulative_translation += world_vector
                            pose_data['translation'] = cumulative_translation.copy()
                            if step_idx < 3:
                                print(f"使用action.world_vector: {world_vector}, 累积位移: {cumulative_translation}")
                    except Exception as e:
                        if step_idx < 3:
                            print(f"action.world_vector提取失败: {e}")

                # Rotation: accumulate Euler deltas, then convert to a quaternion.
                if 'rotation_delta' in action:
                    try:
                        rotation_delta = action['rotation_delta'].numpy()
                        if len(rotation_delta) == 3:
                            cumulative_rotation += rotation_delta
                            pose_data['rotation'] = self._euler_to_quaternion(cumulative_rotation)
                            if step_idx < 3:
                                print(f"使用action.rotation_delta: {rotation_delta}, 累积旋转: {cumulative_rotation}")
                    except Exception as e:
                        if step_idx < 3:
                            print(f"action.rotation_delta提取失败: {e}")

                # Guarantee both keys exist even when this step had no deltas.
                if 'rotation' not in pose_data:
                    pose_data['rotation'] = self._euler_to_quaternion(cumulative_rotation)
                if 'translation' not in pose_data:
                    # Fixes a latent KeyError in create_camera_matrices when
                    # a step carried no world_vector.
                    pose_data['translation'] = cumulative_translation.copy()

                camera_poses.append(pose_data)
                frame_count += 1

            except Exception as e:
                print(f"提取位姿步骤 {step_idx} 时出错: {e}")
                # Fall back to the last known translation and identity rotation.
                camera_poses.append({
                    'translation': cumulative_translation.copy(),
                    'rotation': np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
                })
                frame_count += 1

        print(f"提取了 {len(camera_poses)} 个位姿")
        print(f"最终累积位移: {cumulative_translation}")
        print(f"最终累积旋转: {cumulative_rotation}")

        return camera_poses

    def create_camera_matrices(self, camera_poses):
        """Convert pose dicts into an (N, 4, 4) array of homogeneous transforms."""
        matrices = []

        for pose in camera_poses:
            matrix = np.eye(4, dtype=np.float32)

            # Translation (origin fallback keeps old, translation-less poses safe).
            matrix[:3, 3] = pose.get('translation', np.zeros(3, dtype=np.float32))

            rotation = pose['rotation']
            if len(rotation) == 4:
                # Quaternion [w, x, y, z] -> rotation matrix.
                w, x, y, z = rotation[0], rotation[1], rotation[2], rotation[3]
                matrix[0, 0] = 1 - 2 * (y * y + z * z)
                matrix[0, 1] = 2 * (x * y - w * z)
                matrix[0, 2] = 2 * (x * z + w * y)
                matrix[1, 0] = 2 * (x * y + w * z)
                matrix[1, 1] = 1 - 2 * (x * x + z * z)
                matrix[1, 2] = 2 * (y * z - w * x)
                matrix[2, 0] = 2 * (x * z - w * y)
                matrix[2, 1] = 2 * (y * z + w * x)
                matrix[2, 2] = 1 - 2 * (x * x + y * y)
            # len == 3 (raw Euler angles) was never implemented upstream;
            # the identity rotation is kept in that case.

            matrices.append(matrix)

        return np.array(matrices)
275
+
276
def encode_fractal_dataset(dataset_path, text_encoder_path, vae_path, output_dir, max_episodes=None):
    """Encode the fractal20220817_data dataset, based on its actual field structure.

    For each episode the RGB frames are VAE-encoded, camera poses are
    integrated from the action stream, and everything is written to
    ``<output_dir>/episode_XXXXXX/encoded_video.pth``.
    """
    encoder = VideoEncoder(text_encoder_path, vae_path)
    encoder = encoder.cuda()
    encoder.pipe.device = "cuda"

    os.makedirs(output_dir, exist_ok=True)

    processed_count = 0
    prompt_emb = None  # shared prompt embedding, computed lazily once below

    try:
        # Load the dataset from the local data_dir (GCS access is disabled
        # via the environment variables set at module import time).
        ds = tfds.load(
            "fractal20220817_data",
            split="train",
            data_dir=dataset_path,
        )

        print(f"✅ 成功加载fractal20220817_data数据集")

        # Optionally cap the number of episodes processed.
        if max_episodes:
            ds = ds.take(max_episodes)
            print(f"限制处理episodes数量: {max_episodes}")

    except Exception as e:
        print(f"❌ 加载数据集失败: {e}")
        return

    for episode_idx, episode in enumerate(tqdm(ds, desc="处理episodes")):
        try:
            episode_name = f"episode_{episode_idx:06d}"
            save_episode_dir = os.path.join(output_dir, episode_name)

            # Skip episodes that were already encoded on a previous run.
            encoded_path = os.path.join(save_episode_dir, "encoded_video.pth")
            if os.path.exists(encoded_path):
                print(f"Episode {episode_name} 已处理,跳过...")
                processed_count += 1
                continue

            os.makedirs(save_episode_dir, exist_ok=True)

            print(f"\n🔧 处理episode {episode_name}...")

            # Debug: dump episode/step structure for the first two episodes only.
            if episode_idx < 2:
                print("Episode结构分析:")
                for key in episode.keys():
                    print(f" - {key}: {type(episode[key])}")

                # Inspect the structure of the first step.
                steps = episode['steps']
                for step in steps.take(1):
                    print("第一个step结构:")
                    for key in step.keys():
                        print(f" - {key}: {type(step[key])}")

                    if 'observation' in step:
                        obs = step['observation']
                        print(" observation键:")
                        print(f" 🔍 可用字段: {list(obs.keys())}")

                        # Image- and pose-related observation fields of interest.
                        key_fields = ['image', 'vector_to_go', 'rotation_delta_to_go', 'base_pose_tool_reached']
                        for key in key_fields:
                            if key in obs:
                                try:
                                    value = obs[key]
                                    if hasattr(value, 'shape'):
                                        print(f" ✅ {key}: {type(value)}, shape: {value.shape}")
                                    else:
                                        print(f" ✅ {key}: {type(value)}")
                                except Exception as e:
                                    print(f" ❌ {key}: 无法访问 ({e})")

                    if 'action' in step:
                        action = step['action']
                        print(" action键:")
                        print(f" 🔍 可用字段: {list(action.keys())}")

                        # Pose-related action fields of interest.
                        key_fields = ['world_vector', 'rotation_delta', 'base_displacement_vector']
                        for key in key_fields:
                            if key in action:
                                try:
                                    value = action[key]
                                    if hasattr(value, 'shape'):
                                        print(f" ✅ {key}: {type(value)}, shape: {value.shape}")
                                    else:
                                        print(f" ✅ {key}: {type(value)}")
                                except Exception as e:
                                    print(f" ❌ {key}: 无法访问 ({e})")

            # Extract the episode's RGB frames as a (C, T, H, W) tensor.
            video_frames = encoder.load_episode_frames(episode)
            if video_frames is None:
                print(f"❌ 无法加载episode {episode_name}的视频帧")
                continue

            print(f"✅ Episode {episode_name} 视频形状: {video_frames.shape}")

            # Integrate camera poses from the action stream, one per frame.
            num_frames = video_frames.shape[1]
            camera_poses = encoder.extract_camera_poses(episode, num_frames)
            camera_matrices = encoder.create_camera_matrices(camera_poses)

            print(f"🔧 编码episode {episode_name}...")

            # Camera conditioning: per-frame extrinsics, placeholder intrinsics.
            cam_emb = {
                'extrinsic': camera_matrices,
                'intrinsic': np.eye(3, dtype=np.float32)
            }

            # VAE-encode the video (batch dimension added, bfloat16 on GPU).
            frames_batch = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16)

            with torch.no_grad():
                latents = encoder.pipe.encode_video(frames_batch, **encoder.tiler_kwargs)[0]

            # Encode the shared text prompt once, on first use.
            if prompt_emb is None:
                print('🔧 编码prompt...')
                prompt_emb = encoder.pipe.encode_prompt(
                    "A video of robotic manipulation task with camera movement"
                )
                # Release the prompter to save GPU memory.
                del encoder.pipe.prompter

            # Persist latents + embeddings + pose metadata.
            encoded_data = {
                "latents": latents.cpu(),
                "prompt_emb": {k: v.cpu() if isinstance(v, torch.Tensor) else v
                               for k, v in prompt_emb.items()},
                "cam_emb": cam_emb,
                "episode_info": {
                    "episode_idx": episode_idx,
                    "total_frames": video_frames.shape[1],
                    "pose_extraction_method": "observation_action_based"
                }
            }

            torch.save(encoded_data, encoded_path)
            print(f"✅ 保存编码数据: {encoded_path}")

            processed_count += 1
            print(f"✅ 已处理 {processed_count} 个episodes")

        except Exception as e:
            print(f"❌ 处理episode {episode_idx}时出错: {e}")
            import traceback
            traceback.print_exc()
            continue

    print(f"🎉 编码完成! 总共处理了 {processed_count} 个episodes")
434
if __name__ == "__main__":
    # Command-line entry point for the Fractal20220817 encoding job.
    arg_parser = argparse.ArgumentParser(description="Encode Open-X Fractal20220817 Dataset - Based on Real Structure")
    str_options = (
        ("--dataset_path", "/share_zhuyixuan05/public_datasets/open-x/0.1.0", "Path to tensorflow_datasets directory"),
        ("--text_encoder_path", "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth", None),
        ("--vae_path", "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth", None),
        ("--output_dir", "/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded", None),
    )
    for flag, default, help_text in str_options:
        if help_text is None:
            arg_parser.add_argument(flag, type=str, default=default)
        else:
            arg_parser.add_argument(flag, type=str, default=default, help=help_text)
    arg_parser.add_argument("--max_episodes", type=int, default=10000,
                            help="Maximum number of episodes to process (default: 10 for testing)")
    cfg = arg_parser.parse_args()

    # Make sure the output directory exists before any episode is written.
    os.makedirs(cfg.output_dir, exist_ok=True)

    print("🚀 开始编码Open-X Fractal数据集 (基于实际字段结构)...")
    print(f"📁 数据集路径: {cfg.dataset_path}")
    print(f"💾 输出目录: {cfg.output_dir}")
    print(f"🔢 最大处理episodes: {cfg.max_episodes}")
    print("🔧 基于实际observation和action字段的位姿提取方法")
    print("✅ 优先使用 'image' 字段获取图像数据")

    encode_fractal_dataset(
        cfg.dataset_path,
        cfg.text_encoder_path,
        cfg.vae_path,
        cfg.output_dir,
        cfg.max_episodes
    )
scripts/encode_rlbench_video.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import lightning as pl
4
+ from PIL import Image
5
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
6
+ import json
7
+ import imageio
8
+ from torchvision.transforms import v2
9
+ from einops import rearrange
10
+ import argparse
11
+ import numpy as np
12
+ import pdb
13
+ from tqdm import tqdm
14
+
15
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
16
+
17
class VideoEncoder(pl.LightningModule):
    """Wraps the WanVideo ReCamMaster text encoder + VAE for RLBench clips.

    Models are loaded on CPU; the caller moves the module to GPU before use.
    """

    def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        super().__init__()
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        model_manager.load_models([text_encoder_path, vae_path])
        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
        # Tiled VAE encoding keeps peak GPU memory bounded.
        self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # PIL frame -> float tensor normalized to [-1, 1].
        self.frame_process = v2.Compose([
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def crop_and_resize(self, image, target_width=512, target_height=512):
        """Resize a PIL image to the RLBench model resolution (default 512x512)."""
        return v2.functional.resize(
            image,
            (target_height, target_width),
            interpolation=v2.InterpolationMode.BILINEAR,
        )

    def load_video_frames(self, video_path):
        """Load a whole video as a (C, T, H, W) tensor, or None if it has no frames.

        The reader is closed in a ``finally`` block so a decode error cannot
        leak the underlying file handle.
        """
        reader = imageio.get_reader(video_path)
        frames = []
        try:
            for frame_data in reader:
                frame = Image.fromarray(frame_data)
                frame = self.crop_and_resize(frame)
                frame = self.frame_process(frame)
                frames.append(frame)
        finally:
            reader.close()

        if not frames:
            return None

        frames = torch.stack(frames, dim=0)
        return rearrange(frames, "T C H W -> C T H W")
62
+
63
def encode_scenes(scenes_path, text_encoder_path, vae_path, output_dir):
    """Encode every RLBench demo video under *scenes_path* into VAE latents.

    Layout: ``scenes_path/<scene>/<demo>/*side.mp4`` with a matching
    ``*data.npy`` camera file.  Each encoded result is written to
    ``output_dir/<scene>_<demo>/encoded_video.pth``.
    """
    encoder = VideoEncoder(text_encoder_path, vae_path)
    encoder = encoder.cuda()
    encoder.pipe.device = "cuda"

    processed_count = 0

    os.makedirs(output_dir, exist_ok=True)

    for scene_name in os.listdir(scenes_path):
        scene_dir = os.path.join(scenes_path, scene_name)
        # Guard: a stray file here would crash os.listdir below.
        if not os.path.isdir(scene_dir):
            continue

        demo_names = os.listdir(scene_dir)
        for demo_name in tqdm(demo_names, total=len(demo_names)):
            demo_dir = os.path.join(scene_dir, demo_name)
            if not os.path.isdir(demo_dir):
                continue

            for filename in os.listdir(demo_dir):
                # Only process video files (case-insensitive extension check).
                if not filename.lower().endswith('.mp4'):
                    continue

                full_path = os.path.join(demo_dir, filename)
                print(full_path)
                save_dir = os.path.join(output_dir, scene_name + '_' + demo_name)
                os.makedirs(save_dir, exist_ok=True)

                # Skip demos that were already encoded.
                encoded_path = os.path.join(save_dir, "encoded_video.pth")
                if os.path.exists(encoded_path):
                    print(f"Scene {scene_name} already encoded, skipping...")
                    continue

                # The camera trajectory lives next to the video.
                scene_cam_path = full_path.replace("side.mp4", "data.npy")
                print(scene_cam_path)
                if not os.path.exists(scene_cam_path):
                    continue

                cam_emb = np.load(scene_cam_path)
                print(cam_emb.shape)

                try:
                    print(f"Encoding scene {scene_name}...Demo {demo_name}")

                    video_frames = encoder.load_video_frames(full_path)
                    if video_frames is None:
                        print(f"Failed to load video: {full_path}")
                        continue

                    video_frames = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16)
                    print('video shape:', video_frames.shape)

                    with torch.no_grad():
                        latents = encoder.pipe.encode_video(video_frames, **encoder.tiler_kwargs)[0]

                    encoded_data = {
                        "latents": latents.cpu(),
                        "cam_emb": cam_emb
                    }
                    torch.save(encoded_data, encoded_path)
                    print(f"Saved encoded data: {encoded_path}")
                    processed_count += 1

                except Exception as e:
                    # One broken demo must not abort the whole batch.
                    print(f"Error encoding scene {scene_name}: {e}")
                    continue

    print(f"Encoding completed! Processed {processed_count} scenes.")
157
+
158
if __name__ == "__main__":
    # CLI entry point: every option is a plain string path.
    cli = argparse.ArgumentParser()
    for flag, default in (
        ("--scenes_path", "/share_zhuyixuan05/zhuyixuan05/RLBench"),
        ("--text_encoder_path", "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth"),
        ("--vae_path", "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"),
        ("--output_dir", "/share_zhuyixuan05/zhuyixuan05/rlbench"),
    ):
        cli.add_argument(flag, type=str, default=default)
    opts = cli.parse_args()
    encode_scenes(opts.scenes_path, opts.text_encoder_path, opts.vae_path, opts.output_dir)
scripts/encode_sekai_video.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import lightning as pl
4
+ from PIL import Image
5
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
6
+ import json
7
+ import imageio
8
+ from torchvision.transforms import v2
9
+ from einops import rearrange
10
+ import argparse
11
+ import numpy as np
12
+ import pdb
13
+ from tqdm import tqdm
14
+
15
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
16
+
17
class VideoEncoder(pl.LightningModule):
    """Wraps the WanVideo ReCamMaster text encoder + VAE for Sekai clips.

    Models are loaded on CPU; the caller moves the module to GPU before use.
    """

    def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        super().__init__()
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        model_manager.load_models([text_encoder_path, vae_path])
        self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
        # Tiled VAE encoding keeps peak GPU memory bounded for long clips.
        self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # PIL frame -> float tensor normalized to [-1, 1].
        self.frame_process = v2.Compose([
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def crop_and_resize(self, image, target_width=832, target_height=480):
        """Resize a PIL image to the model resolution (default 832x480)."""
        return v2.functional.resize(
            image,
            (target_height, target_width),
            interpolation=v2.InterpolationMode.BILINEAR,
        )

    def load_video_frames(self, video_path):
        """Load a whole video as a (C, T, H, W) tensor, or None if it has no frames.

        The reader is closed in a ``finally`` block so a decode error cannot
        leak the underlying file handle.
        """
        reader = imageio.get_reader(video_path)
        frames = []
        try:
            for frame_data in reader:
                frame = Image.fromarray(frame_data)
                frame = self.crop_and_resize(frame)
                frame = self.frame_process(frame)
                frames.append(frame)
        finally:
            reader.close()

        if not frames:
            return None

        frames = torch.stack(frames, dim=0)
        return rearrange(frames, "T C H W -> C T H W")
62
+
63
def encode_scenes(scenes_path, text_encoder_path, vae_path, output_dir):
    """Encode every Sekai walking clip under *scenes_path* into VAE latents.

    Each ``<clip>.mp4`` must have a sibling ``<clip>.npz`` camera file; the
    result is stored as ``output_dir/<clip>/encoded_video.pth`` containing
    the latents and the camera embedding arrays.
    """
    encoder = VideoEncoder(text_encoder_path, vae_path)
    encoder = encoder.cuda()
    encoder.pipe.device = "cuda"

    processed_count = 0
    # The shared prompt embedding is computed lazily once, then the text
    # encoder is released to free GPU memory.
    prompt_emb = None

    os.makedirs(output_dir, exist_ok=True)

    entries = os.listdir(scenes_path)
    for scene_name in tqdm(entries, total=len(entries)):
        scene_dir = os.path.join(scenes_path, scene_name)
        save_dir = os.path.join(output_dir, scene_name.split('.')[0])

        # Only .mp4 clips are encoded; sibling .npz files etc. are skipped.
        if not scene_dir.endswith(".mp4"):
            continue

        os.makedirs(save_dir, exist_ok=True)

        # Skip clips that were already encoded.
        encoded_path = os.path.join(save_dir, "encoded_video.pth")
        if os.path.exists(encoded_path):
            print(f"Scene {scene_name} already encoded, skipping...")
            continue

        # Camera trajectory lives next to the clip as a .npz archive.
        scene_cam_path = scene_dir.replace(".mp4", ".npz")
        if not os.path.exists(scene_cam_path):
            continue

        # np.load yields numpy arrays (never torch tensors), so no tensor
        # handling is needed; copying the arrays out inside the context
        # manager detaches them from the open file handle.
        with np.load(scene_cam_path) as data:
            cam_emb = {k: data[k] for k in data.files}

        print(f"Encoding scene {scene_name}...")

        video_frames = encoder.load_video_frames(scene_dir)
        if video_frames is None:
            print(f"Failed to load video: {scene_dir}")
            continue

        video_frames = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16)
        print('video shape:', video_frames.shape)

        with torch.no_grad():
            latents = encoder.pipe.encode_video(video_frames, **encoder.tiler_kwargs)[0]

        # Encode the prompt once, keyed on prompt_emb rather than
        # processed_count, so an early failure cannot leave the prompter
        # deleted while the embedding is still missing.
        if prompt_emb is None:
            print('encode prompt!!!')
            prompt_emb = encoder.pipe.encode_prompt("A video of a scene shot using a pedestrian's front camera while walking")
            del encoder.pipe.prompter

        # Note: the prompt embedding is intentionally not persisted here
        # (matches the original behaviour — only latents + camera data).
        encoded_data = {
            "latents": latents.cpu(),
            "cam_emb": cam_emb
        }
        torch.save(encoded_data, encoded_path)
        print(f"Saved encoded data: {encoded_path}")
        processed_count += 1

    print(f"Encoding completed! Processed {processed_count} scenes.")
149
+
150
if __name__ == "__main__":
    # CLI entry point: every option is a plain string path.
    cli = argparse.ArgumentParser()
    for flag, default in (
        ("--scenes_path", "/share_zhuyixuan05/public_datasets/sekai/Sekai-Project/sekai-game-walking"),
        ("--text_encoder_path", "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth"),
        ("--vae_path", "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"),
        ("--output_dir", "/share_zhuyixuan05/zhuyixuan05/sekai-game-walking"),
    ):
        cli.add_argument(flag, type=str, default=default)
    opts = cli.parse_args()
    encode_scenes(opts.scenes_path, opts.text_encoder_path, opts.vae_path, opts.output_dir)
scripts/encode_sekai_walking.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import torch
4
+ import lightning as pl
5
+ from PIL import Image
6
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
7
+ import json
8
+ import imageio
9
+ from torchvision.transforms import v2
10
+ from einops import rearrange
11
+ import argparse
12
+ import numpy as np
13
+ import pdb
14
+ from tqdm import tqdm
15
+
16
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
17
+
18
+ class VideoEncoder(pl.LightningModule):
19
+ def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
20
+ super().__init__()
21
+ model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
22
+ model_manager.load_models([text_encoder_path, vae_path])
23
+ self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
24
+ self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
25
+
26
+ self.frame_process = v2.Compose([
27
+ # v2.CenterCrop(size=(900, 1600)),
28
+ # v2.Resize(size=(900, 1600), antialias=True),
29
+ v2.ToTensor(),
30
+ v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
31
+ ])
32
+
33
+ def crop_and_resize(self, image):
34
+ width, height = image.size
35
+ # print(width,height)
36
+ width_ori, height_ori_ = 832 , 480
37
+ image = v2.functional.resize(
38
+ image,
39
+ (round(height_ori_), round(width_ori)),
40
+ interpolation=v2.InterpolationMode.BILINEAR
41
+ )
42
+ return image
43
+
44
+ def load_video_frames(self, video_path):
45
+ """加载完整视频"""
46
+ reader = imageio.get_reader(video_path)
47
+ frames = []
48
+
49
+ for frame_data in reader:
50
+ frame = Image.fromarray(frame_data)
51
+ frame = self.crop_and_resize(frame)
52
+ frame = self.frame_process(frame)
53
+ frames.append(frame)
54
+
55
+ reader.close()
56
+
57
+ if len(frames) == 0:
58
+ return None
59
+
60
+ frames = torch.stack(frames, dim=0)
61
+ frames = rearrange(frames, "T C H W -> C T H W")
62
+ return frames
63
+
64
+ def encode_scenes(scenes_path, text_encoder_path, vae_path,output_dir):
65
+ """编码所有场景的视频"""
66
+
67
+ encoder = VideoEncoder(text_encoder_path, vae_path)
68
+ encoder = encoder.cuda()
69
+ encoder.pipe.device = "cuda"
70
+
71
+ processed_count = 0
72
+
73
+ processed_chunk_count = 0
74
+
75
+ prompt_emb = 0
76
+
77
+ os.makedirs(output_dir,exist_ok=True)
78
+ chunk_size = 300
79
+ for i, scene_name in tqdm(enumerate(os.listdir(scenes_path)),total=len(os.listdir(scenes_path))):
80
+ # print('index-----:',type(i))
81
+ # if i < 3000 :#or i >=2000:
82
+ # # print('index-----:',i)
83
+ # continue
84
+ # print('index:',i)
85
+ print('index:',i)
86
+ scene_dir = os.path.join(scenes_path, scene_name)
87
+
88
+ # save_dir = os.path.join(output_dir,scene_name.split('.')[0])
89
+ # print('in:',scene_dir)
90
+ # print('out:',save_dir)
91
+
92
+ if not scene_dir.endswith(".mp4"):# or os.path.isdir(output_dir):
93
+ continue
94
+
95
+
96
+ scene_cam_path = scene_dir.replace(".mp4", ".npz")
97
+ if not os.path.exists(scene_cam_path):
98
+ continue
99
+
100
+ with np.load(scene_cam_path) as data:
101
+ cam_data = data.files
102
+ cam_emb = {k: data[k].cpu() if isinstance(data[k], torch.Tensor) else data[k] for k in cam_data}
103
+ # with open(scene_cam_path, 'rb') as f:
104
+ # cam_data = np.load(f) # 此时cam_data仅包含数据,无文件句柄引用
105
+
106
+ video_name = scene_name[:-4].split('_')[0]
107
+ start_frame = int(scene_name[:-4].split('_')[1])
108
+ end_frame = int(scene_name[:-4].split('_')[2])
109
+
110
+ sampled_range = range(start_frame, end_frame , chunk_size)
111
+ sampled_frames = list(sampled_range)
112
+
113
+ sampled_chunk_end = sampled_frames[0] + 300
114
+ start_str = f"{sampled_frames[0]:07d}"
115
+ end_str = f"{sampled_chunk_end:07d}"
116
+
117
+ chunk_name = f"{video_name}_{start_str}_{end_str}"
118
+ save_chunk_path = os.path.join(output_dir,chunk_name,"encoded_video.pth")
119
+
120
+ if os.path.exists(save_chunk_path):
121
+ print(f"Video {video_name} already encoded, skipping...")
122
+ continue
123
+
124
+ # 加载视频
125
+ video_path = scene_dir
126
+ if not os.path.exists(video_path):
127
+ print(f"Video not found: {video_path}")
128
+ continue
129
+
130
+ video_frames = encoder.load_video_frames(video_path)
131
+ if video_frames is None:
132
+ print(f"Failed to load video: {video_path}")
133
+ continue
134
+
135
+ video_frames = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16)
136
+ print('video shape:',video_frames.shape)
137
+
138
+
139
+
140
+ # print(sampled_frames)
141
+
142
+ print(f"Encoding scene {scene_name}...")
143
+ for sampled_chunk_start in sampled_frames:
144
+ sampled_chunk_end = sampled_chunk_start + 300
145
+ start_str = f"{sampled_chunk_start:07d}"
146
+ end_str = f"{sampled_chunk_end:07d}"
147
+
148
+ # 生成保存目录名(假设video_name已定义)
149
+ chunk_name = f"{video_name}_{start_str}_{end_str}"
150
+ save_chunk_dir = os.path.join(output_dir,chunk_name)
151
+
152
+ os.makedirs(save_chunk_dir,exist_ok=True)
153
+ print(f"Encoding chunk {chunk_name}...")
154
+
155
+ encoded_path = os.path.join(save_chunk_dir, "encoded_video.pth")
156
+
157
+ if os.path.exists(encoded_path):
158
+ print(f"Chunk {chunk_name} already encoded, skipping...")
159
+ continue
160
+
161
+
162
+ chunk_frames = video_frames[:,:, sampled_chunk_start - start_frame : sampled_chunk_end - start_frame,...]
163
+ # print('extrinsic:',cam_emb['extrinsic'].shape)
164
+ chunk_cam_emb ={'extrinsic':cam_emb['extrinsic'][sampled_chunk_start - start_frame : sampled_chunk_end - start_frame],
165
+ 'intrinsic':cam_emb['intrinsic']}
166
+
167
+ # print('chunk shape:',chunk_frames.shape)
168
+
169
+ with torch.no_grad():
170
+ latents = encoder.pipe.encode_video(chunk_frames, **encoder.tiler_kwargs)[0]
171
+
172
+ # 编码文本
173
+ # if processed_count == 0:
174
+ # print('encode prompt!!!')
175
+ # prompt_emb = encoder.pipe.encode_prompt("A video of a scene shot using a pedestrian's front camera while walking")
176
+ # del encoder.pipe.prompter
177
+ # pdb.set_trace()
178
+ # 保存编码结果
179
+ encoded_data = {
180
+ "latents": latents.cpu(),
181
+ # "prompt_emb": {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in prompt_emb.items()},
182
+ "cam_emb": chunk_cam_emb
183
+ }
184
+ # pdb.set_trace()
185
+ torch.save(encoded_data, encoded_path)
186
+ print(f"Saved encoded data: {encoded_path}")
187
+ processed_chunk_count += 1
188
+
189
+ processed_count += 1
190
+
191
+ print("Encoded scene numebr:",processed_count)
192
+ print("Encoded chunk numebr:",processed_chunk_count)
193
+
194
+ # os.makedirs(save_dir,exist_ok=True)
195
+ # # 检查是否已编码
196
+ # encoded_path = os.path.join(save_dir, "encoded_video.pth")
197
+ # if os.path.exists(encoded_path):
198
+ # print(f"Scene {scene_name} already encoded, skipping...")
199
+ # continue
200
+
201
+ # 加载场景信息
202
+
203
+
204
+
205
+ # try:
206
+ # print(f"Encoding scene {scene_name}...")
207
+
208
+ # 加载和编码视频
209
+
210
+ # 编码视频
211
+ # with torch.no_grad():
212
+ # latents = encoder.pipe.encode_video(video_frames, **encoder.tiler_kwargs)[0]
213
+
214
+ # # 编码文本
215
+ # if processed_count == 0:
216
+ # print('encode prompt!!!')
217
+ # prompt_emb = encoder.pipe.encode_prompt("A video of a scene shot using a pedestrian's front camera while walking")
218
+ # del encoder.pipe.prompter
219
+ # # pdb.set_trace()
220
+ # # 保存编码结果
221
+ # encoded_data = {
222
+ # "latents": latents.cpu(),
223
+ # #"prompt_emb": {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in prompt_emb.items()},
224
+ # "cam_emb": cam_emb
225
+ # }
226
+ # # pdb.set_trace()
227
+ # torch.save(encoded_data, encoded_path)
228
+ # print(f"Saved encoded data: {encoded_path}")
229
+ # processed_count += 1
230
+
231
+ # except Exception as e:
232
+ # print(f"Error encoding scene {scene_name}: {e}")
233
+ # continue
234
+
235
+ print(f"Encoding completed! Processed {processed_count} scenes.")
236
+
237
+ if __name__ == "__main__":
238
+ parser = argparse.ArgumentParser()
239
+ parser.add_argument("--scenes_path", type=str, default="/share_zhuyixuan05/public_datasets/sekai/Sekai-Project/sekai-game-walking")
240
+ parser.add_argument("--text_encoder_path", type=str,
241
+ default="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth")
242
+ parser.add_argument("--vae_path", type=str,
243
+ default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth")
244
+
245
+ parser.add_argument("--output_dir",type=str,
246
+ default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking")
247
+
248
+ args = parser.parse_args()
249
+ encode_scenes(args.scenes_path, args.text_encoder_path, args.vae_path,args.output_dir)
scripts/encode_spatialvid.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import torch
4
+ import lightning as pl
5
+ from PIL import Image
6
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
7
+ import json
8
+ import imageio
9
+ from torchvision.transforms import v2
10
+ from einops import rearrange
11
+ import argparse
12
+ import numpy as np
13
+ import pdb
14
+ from tqdm import tqdm
15
+ import pandas as pd
16
+
17
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
18
+
19
+ from scipy.spatial.transform import Slerp
20
+ from scipy.spatial.transform import Rotation as R
21
+
22
+ def interpolate_camera_poses(original_frames, original_poses, target_frames):
23
+ """
24
+ 对相机姿态进行插值,生成目标帧对应的姿态参数
25
+
26
+ 参数:
27
+ original_frames: 原始帧索引列表,如[0,6,12,...]
28
+ original_poses: 原始姿态数组,形状为(n,7),每行[tx, ty, tz, qx, qy, qz, qw]
29
+ target_frames: 目标帧索引列表,如[0,4,8,12,...]
30
+
31
+ 返回:
32
+ target_poses: 插值后的姿态数组,形状为(m,7),m为目标帧数量
33
+ """
34
+ # 确保输入有效
35
+ print('original_frames:',len(original_frames))
36
+ print('original_poses:',len(original_poses))
37
+ if len(original_frames) != len(original_poses):
38
+ raise ValueError("原始帧数量与姿态数量不匹配")
39
+
40
+ if original_poses.shape[1] != 7:
41
+ raise ValueError(f"原始姿态应为(n,7)格式,实际为{original_poses.shape}")
42
+
43
+ target_poses = []
44
+
45
+ # 提取旋转部分并转换为Rotation对象
46
+ rotations = R.from_quat(original_poses[:, 3:7]) # 提取四元数部分
47
+
48
+ for t in target_frames:
49
+ # 找到t前后的原始帧索引
50
+ idx = np.searchsorted(original_frames, t, side='left')
51
+
52
+ # 处理边界情况
53
+ if idx == 0:
54
+ # 使用第一个姿态
55
+ target_poses.append(original_poses[0])
56
+ continue
57
+ if idx >= len(original_frames):
58
+ # 使用最后一个姿态
59
+ target_poses.append(original_poses[-1])
60
+ continue
61
+
62
+ # 获取前后帧的信息
63
+ t_prev, t_next = original_frames[idx-1], original_frames[idx]
64
+ pose_prev, pose_next = original_poses[idx-1], original_poses[idx]
65
+
66
+ # 计算插值权重
67
+ alpha = (t - t_prev) / (t_next - t_prev)
68
+
69
+ # 1. 平移向量的线性插值
70
+ translation_prev = pose_prev[:3]
71
+ translation_next = pose_next[:3]
72
+ interpolated_translation = translation_prev + alpha * (translation_next - translation_prev)
73
+
74
+ # 2. 旋转四元数的球面线性插值(SLERP)
75
+ # 创建Slerp对象
76
+ slerp = Slerp([t_prev, t_next], rotations[idx-1:idx+1])
77
+ interpolated_rotation = slerp(t)
78
+
79
+ # 组合平移和旋转
80
+ interpolated_pose = np.concatenate([
81
+ interpolated_translation,
82
+ interpolated_rotation.as_quat() # 转换回四元数
83
+ ])
84
+
85
+ target_poses.append(interpolated_pose)
86
+
87
+ return np.array(target_poses)
88
+
89
+
90
+ class VideoEncoder(pl.LightningModule):
91
+ def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
92
+ super().__init__()
93
+ model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
94
+ model_manager.load_models([text_encoder_path, vae_path])
95
+ self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
96
+ self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
97
+
98
+ self.frame_process = v2.Compose([
99
+ # v2.CenterCrop(size=(900, 1600)),
100
+ # v2.Resize(size=(900, 1600), antialias=True),
101
+ v2.ToTensor(),
102
+ v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
103
+ ])
104
+
105
+ def crop_and_resize(self, image):
106
+ width, height = image.size
107
+ # print(width,height)
108
+ width_ori, height_ori_ = 832 , 480
109
+ image = v2.functional.resize(
110
+ image,
111
+ (round(height_ori_), round(width_ori)),
112
+ interpolation=v2.InterpolationMode.BILINEAR
113
+ )
114
+ return image
115
+
116
+ def load_video_frames(self, video_path):
117
+ """加载完整视频"""
118
+ reader = imageio.get_reader(video_path)
119
+ frames = []
120
+
121
+ for frame_data in reader:
122
+ frame = Image.fromarray(frame_data)
123
+ frame = self.crop_and_resize(frame)
124
+ frame = self.frame_process(frame)
125
+ frames.append(frame)
126
+
127
+ reader.close()
128
+
129
+ if len(frames) == 0:
130
+ return None
131
+
132
+ frames = torch.stack(frames, dim=0)
133
+ frames = rearrange(frames, "T C H W -> C T H W")
134
+ return frames
135
+
136
+ def encode_scenes(scenes_path, text_encoder_path, vae_path,output_dir):
137
+ """编码所有场景的视频"""
138
+
139
+ encoder = VideoEncoder(text_encoder_path, vae_path)
140
+ encoder = encoder.cuda()
141
+ encoder.pipe.device = "cuda"
142
+
143
+ processed_count = 0
144
+
145
+ processed_chunk_count = 0
146
+
147
+ prompt_emb = 0
148
+
149
+ metadata = pd.read_csv('/share_zhuyixuan05/public_datasets/SpatialVID-HQ/data/train/SpatialVID_HQ_metadata.csv')
150
+
151
+
152
+ os.makedirs(output_dir,exist_ok=True)
153
+ chunk_size = 300
154
+ required_keys = ["latents", "cam_emb", "prompt_emb"]
155
+
156
+ for i, scene_name in enumerate(os.listdir(scenes_path)):
157
+ # print('index-----:',type(i))
158
+ if i < 3 :#or i >=2000:
159
+ # # print('index-----:',i)
160
+ continue
161
+ # print('index:',i)
162
+ print('group:',i)
163
+ scene_dir = os.path.join(scenes_path, scene_name)
164
+
165
+ # save_dir = os.path.join(output_dir,scene_name.split('.')[0])
166
+ print('in:',scene_dir)
167
+ # print('out:',save_dir)
168
+ for j, video_name in tqdm(enumerate(os.listdir(scene_dir)),total=len(os.listdir(scene_dir))):
169
+
170
+ # if j < 1000 :#or i >=2000:
171
+ # print('index:',j)
172
+ # continue
173
+ print(video_name)
174
+ video_path = os.path.join(scene_dir, video_name)
175
+ if not video_path.endswith(".mp4"):# or os.path.isdir(output_dir):
176
+ continue
177
+
178
+ video_info = metadata[metadata['id'] == video_name[:-4]]
179
+ num_frames = video_info['num frames'].iloc[0]
180
+
181
+ scene_cam_dir = video_path.replace( "videos","annotations")[:-4]
182
+ scene_cam_path = os.path.join(scene_cam_dir,'poses.npy')
183
+
184
+ scene_caption_path = os.path.join(scene_cam_dir,'caption.json')
185
+
186
+ with open(scene_caption_path, 'r', encoding='utf-8') as f:
187
+ caption_data = json.load(f)
188
+ caption = caption_data["SceneSummary"]
189
+ if not os.path.exists(scene_cam_path):
190
+ print(f"Pose not found: {scene_cam_path}")
191
+ continue
192
+
193
+ camera_poses = np.load(scene_cam_path)
194
+ cam_data_len = camera_poses.shape[0]
195
+
196
+ # cam_emb = {k: data[k].cpu() if isinstance(data[k], torch.Tensor) else data[k] for k in cam_data}
197
+ # with open(scene_cam_path, 'rb') as f:
198
+ # cam_data = np.load(f) # 此时cam_data仅包含数据,无文件句柄引用
199
+
200
+ # 加载视频
201
+ # video_path = scene_dir
202
+ if not os.path.exists(video_path):
203
+ print(f"Video not found: {video_path}")
204
+ continue
205
+
206
+ start_str = f"{0:07d}"
207
+ end_str = f"{chunk_size:07d}"
208
+ chunk_name = f"{video_name[:-4]}_{start_str}_{end_str}"
209
+ first_save_chunk_dir = os.path.join(output_dir,chunk_name)
210
+
211
+ first_chunk_encoded_path = os.path.join(first_save_chunk_dir, "encoded_video.pth")
212
+ # print(first_chunk_encoded_path)
213
+ if os.path.exists(first_chunk_encoded_path):
214
+ data = torch.load(first_chunk_encoded_path,weights_only=False)
215
+ if 'latents' in data:
216
+ video_frames = 1
217
+ else:
218
+ video_frames = encoder.load_video_frames(video_path)
219
+ if video_frames is None:
220
+ print(f"Failed to load video: {video_path}")
221
+ continue
222
+ print('video shape:',video_frames.shape)
223
+
224
+
225
+
226
+ video_frames = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16)
227
+ print('video shape:',video_frames.shape)
228
+
229
+ video_name = video_name[:-4].split('_')[0]
230
+ start_frame = 0
231
+ end_frame = num_frames
232
+ # print("num_frames:",num_frames)
233
+
234
+ cam_interval = end_frame // (cam_data_len - 1)
235
+
236
+ cam_frames = np.linspace(start_frame, end_frame, cam_data_len, endpoint=True)
237
+ cam_frames = np.round(cam_frames).astype(int)
238
+ cam_frames = cam_frames.tolist()
239
+ # list(range(0, end_frame + 1 , cam_interval))
240
+
241
+
242
+ sampled_range = range(start_frame, end_frame , chunk_size)
243
+ sampled_frames = list(sampled_range)
244
+
245
+ sampled_chunk_end = sampled_frames[0] + chunk_size
246
+ start_str = f"{sampled_frames[0]:07d}"
247
+ end_str = f"{sampled_chunk_end:07d}"
248
+
249
+ chunk_name = f"{video_name}_{start_str}_{end_str}"
250
+ # save_chunk_path = os.path.join(output_dir,chunk_name,"encoded_video.pth")
251
+
252
+ # if os.path.exists(save_chunk_path):
253
+ # print(f"Video {video_name} already encoded, skipping...")
254
+ # continue
255
+
256
+
257
+
258
+
259
+
260
+ # print(sampled_frames)
261
+
262
+ print(f"Encoding scene {video_name}...")
263
+ chunk_count_in_one_video = 0
264
+ for sampled_chunk_start in sampled_frames:
265
+ if num_frames - sampled_chunk_start < 100:
266
+ continue
267
+ sampled_chunk_end = sampled_chunk_start + chunk_size
268
+ start_str = f"{sampled_chunk_start:07d}"
269
+ end_str = f"{sampled_chunk_end:07d}"
270
+
271
+ resample_cam_frame = list(range(sampled_chunk_start, sampled_chunk_end , 4))
272
+
273
+ # 生成保存目录名(假设video_name已定义)
274
+ chunk_name = f"{video_name}_{start_str}_{end_str}"
275
+ save_chunk_dir = os.path.join(output_dir,chunk_name)
276
+
277
+ os.makedirs(save_chunk_dir,exist_ok=True)
278
+ print(f"Encoding chunk {chunk_name}...")
279
+
280
+ encoded_path = os.path.join(save_chunk_dir, "encoded_video.pth")
281
+
282
+ missing_keys = required_keys
283
+ if os.path.exists(encoded_path):
284
+ print('error:',encoded_path)
285
+ data = torch.load(encoded_path,weights_only=False)
286
+ missing_keys = [key for key in required_keys if key not in data]
287
+ # print(missing_keys)
288
+ # print(f"Chunk {chunk_name} already encoded, skipping...")
289
+ if missing_keys:
290
+ print(f"警告: 文件中缺少以下必要元素: {missing_keys}")
291
+ if len(missing_keys) == 0 :
292
+ continue
293
+ else:
294
+ print(f"警告: 缺少pth文件: {encoded_path}")
295
+ if not isinstance(video_frames, torch.Tensor):
296
+
297
+ video_frames = encoder.load_video_frames(video_path)
298
+ if video_frames is None:
299
+ print(f"Failed to load video: {video_path}")
300
+ continue
301
+
302
+ video_frames = video_frames.unsqueeze(0).to("cuda", dtype=torch.bfloat16)
303
+
304
+ print('video shape:',video_frames.shape)
305
+ if "latents" in missing_keys:
306
+ chunk_frames = video_frames[:,:, sampled_chunk_start - start_frame : sampled_chunk_end - start_frame,...]
307
+
308
+ # print('extrinsic:',cam_emb['extrinsic'].shape)
309
+
310
+ # chunk_cam_emb ={'extrinsic':cam_emb['extrinsic'][sampled_chunk_start - start_frame : sampled_chunk_end - start_frame],
311
+ # 'intrinsic':cam_emb['intrinsic']}
312
+
313
+ # print('chunk shape:',chunk_frames.shape)
314
+
315
+ with torch.no_grad():
316
+ latents = encoder.pipe.encode_video(chunk_frames, **encoder.tiler_kwargs)[0]
317
+ else:
318
+ latents = data['latents']
319
+ if "cam_emb" in missing_keys:
320
+ cam_emb = interpolate_camera_poses(cam_frames, camera_poses,resample_cam_frame)
321
+ chunk_cam_emb ={'extrinsic':cam_emb}
322
+ print(f"视频长度:{chunk_size},重采样相机长度:{cam_emb.shape[0]}")
323
+ else:
324
+ chunk_cam_emb = data['cam_emb']
325
+
326
+ if "prompt_emb" in missing_keys:
327
+ # 编码文本
328
+ if chunk_count_in_one_video == 0:
329
+ print(caption)
330
+ with torch.no_grad():
331
+ prompt_emb = encoder.pipe.encode_prompt(caption)
332
+ else:
333
+ prompt_emb = data['prompt_emb']
334
+
335
+ # del encoder.pipe.prompter
336
+ # pdb.set_trace()
337
+ # 保存编码结果
338
+ encoded_data = {
339
+ "latents": latents.cpu(),
340
+ "prompt_emb": {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in prompt_emb.items()},
341
+ "cam_emb": chunk_cam_emb
342
+ }
343
+ # pdb.set_trace()
344
+ torch.save(encoded_data, encoded_path)
345
+ print(f"Saved encoded data: {encoded_path}")
346
+ processed_chunk_count += 1
347
+ chunk_count_in_one_video += 1
348
+
349
+ processed_count += 1
350
+
351
+ print("Encoded scene numebr:",processed_count)
352
+ print("Encoded chunk numebr:",processed_chunk_count)
353
+
354
+ # os.makedirs(save_dir,exist_ok=True)
355
+ # # 检查是否已编码
356
+ # encoded_path = os.path.join(save_dir, "encoded_video.pth")
357
+ # if os.path.exists(encoded_path):
358
+ # print(f"Scene {scene_name} already encoded, skipping...")
359
+ # continue
360
+
361
+ # 加载场景信息
362
+
363
+
364
+
365
+ # try:
366
+ # print(f"Encoding scene {scene_name}...")
367
+
368
+ # 加载和编码视频
369
+
370
+ # 编码视频
371
+ # with torch.no_grad():
372
+ # latents = encoder.pipe.encode_video(video_frames, **encoder.tiler_kwargs)[0]
373
+
374
+ # # 编码文本
375
+ # if processed_count == 0:
376
+ # print('encode prompt!!!')
377
+ # prompt_emb = encoder.pipe.encode_prompt("A video of a scene shot using a pedestrian's front camera while walking")
378
+ # del encoder.pipe.prompter
379
+ # # pdb.set_trace()
380
+ # # 保存编码结果
381
+ # encoded_data = {
382
+ # "latents": latents.cpu(),
383
+ # #"prompt_emb": {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in prompt_emb.items()},
384
+ # "cam_emb": cam_emb
385
+ # }
386
+ # # pdb.set_trace()
387
+ # torch.save(encoded_data, encoded_path)
388
+ # print(f"Saved encoded data: {encoded_path}")
389
+ # processed_count += 1
390
+
391
+ # except Exception as e:
392
+ # print(f"Error encoding scene {scene_name}: {e}")
393
+ # continue
394
+
395
+ print(f"Encoding completed! Processed {processed_count} scenes.")
396
+
397
+ if __name__ == "__main__":
398
+ parser = argparse.ArgumentParser()
399
+ parser.add_argument("--scenes_path", type=str, default="/share_zhuyixuan05/public_datasets/SpatialVID-HQ/SpatialVid/HQ/videos/")
400
+ parser.add_argument("--text_encoder_path", type=str,
401
+ default="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth")
402
+ parser.add_argument("--vae_path", type=str,
403
+ default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth")
404
+
405
+ parser.add_argument("--output_dir",type=str,
406
+ default="/share_zhuyixuan05/zhuyixuan05/spatialvid")
407
+
408
+ args = parser.parse_args()
409
+ encode_scenes(args.scenes_path, args.text_encoder_path, args.vae_path,args.output_dir)
scripts/encode_spatialvid_first_frame.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import torch
4
+ import lightning as pl
5
+ from PIL import Image
6
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
7
+ import json
8
+ import imageio
9
+ from torchvision.transforms import v2
10
+ from einops import rearrange
11
+ import argparse
12
+ import numpy as np
13
+ import pdb
14
+ from tqdm import tqdm
15
+ import pandas as pd
16
+
17
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
18
+
19
+ from scipy.spatial.transform import Slerp
20
+ from scipy.spatial.transform import Rotation as R
21
+
22
+ def interpolate_camera_poses(original_frames, original_poses, target_frames):
23
+ """
24
+ 对相机姿态进行插值,生成目标帧对应的姿态参数
25
+
26
+ 参数:
27
+ original_frames: 原始帧索引列表,如[0,6,12,...]
28
+ original_poses: 原始姿态数组,形状为(n,7),每行[tx, ty, tz, qx, qy, qz, qw]
29
+ target_frames: 目标帧索引列表,如[0,4,8,12,...]
30
+
31
+ 返回:
32
+ target_poses: 插值后的姿态数组,形状为(m,7),m为目标帧数量
33
+ """
34
+ # 确保输入有效
35
+ print('original_frames:',len(original_frames))
36
+ print('original_poses:',len(original_poses))
37
+ if len(original_frames) != len(original_poses):
38
+ raise ValueError("原始帧数量与姿态数量不匹配")
39
+
40
+ if original_poses.shape[1] != 7:
41
+ raise ValueError(f"原始姿态应为(n,7)格式,实际为{original_poses.shape}")
42
+
43
+ target_poses = []
44
+
45
+ # 提取旋转部分并转换为Rotation对象
46
+ rotations = R.from_quat(original_poses[:, 3:7]) # 提取四元数部分
47
+
48
+ for t in target_frames:
49
+ # 找到t前后的原始帧索引
50
+ idx = np.searchsorted(original_frames, t, side='left')
51
+
52
+ # 处理边界情况
53
+ if idx == 0:
54
+ # 使用第一个姿态
55
+ target_poses.append(original_poses[0])
56
+ continue
57
+ if idx >= len(original_frames):
58
+ # 使用最后一个姿态
59
+ target_poses.append(original_poses[-1])
60
+ continue
61
+
62
+ # 获取前后帧的信息
63
+ t_prev, t_next = original_frames[idx-1], original_frames[idx]
64
+ pose_prev, pose_next = original_poses[idx-1], original_poses[idx]
65
+
66
+ # 计算插值权重
67
+ alpha = (t - t_prev) / (t_next - t_prev)
68
+
69
+ # 1. 平移向量的线性插值
70
+ translation_prev = pose_prev[:3]
71
+ translation_next = pose_next[:3]
72
+ interpolated_translation = translation_prev + alpha * (translation_next - translation_prev)
73
+
74
+ # 2. 旋转四元数的球面线性插值(SLERP)
75
+ # 创建Slerp对象
76
+ slerp = Slerp([t_prev, t_next], rotations[idx-1:idx+1])
77
+ interpolated_rotation = slerp(t)
78
+
79
+ # 组合平移和旋转
80
+ interpolated_pose = np.concatenate([
81
+ interpolated_translation,
82
+ interpolated_rotation.as_quat() # 转换回四元数
83
+ ])
84
+
85
+ target_poses.append(interpolated_pose)
86
+
87
+ return np.array(target_poses)
88
+
89
+ class VideoEncoder(pl.LightningModule):
90
+ def __init__(self, text_encoder_path, vae_path, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
91
+ super().__init__()
92
+ model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
93
+ model_manager.load_models([text_encoder_path, vae_path])
94
+ self.pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager)
95
+ self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
96
+
97
+ self.frame_process = v2.Compose([
98
+ v2.ToTensor(),
99
+ v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
100
+ ])
101
+
102
+ def crop_and_resize(self, image):
103
+ width, height = image.size
104
+ width_ori, height_ori_ = 832 , 480
105
+ image = v2.functional.resize(
106
+ image,
107
+ (round(height_ori_), round(width_ori)),
108
+ interpolation=v2.InterpolationMode.BILINEAR
109
+ )
110
+ return image
111
+
112
+ def load_single_frame(self, video_path, frame_idx):
113
+ """只加载指定的单帧"""
114
+ reader = imageio.get_reader(video_path)
115
+
116
+ try:
117
+ # 直接跳转到指定帧
118
+ frame_data = reader.get_data(frame_idx)
119
+ frame = Image.fromarray(frame_data)
120
+ frame = self.crop_and_resize(frame)
121
+ frame = self.frame_process(frame)
122
+
123
+ # 添加batch和time维度: [C, H, W] -> [1, C, 1, H, W]
124
+ frame = frame.unsqueeze(0).unsqueeze(2)
125
+
126
+ except Exception as e:
127
+ print(f"Error loading frame {frame_idx} from {video_path}: {e}")
128
+ return None
129
+ finally:
130
+ reader.close()
131
+
132
+ return frame
133
+
134
+ def load_video_frames(self, video_path):
135
+ """加载完整视频(保留用于兼容性)"""
136
+ reader = imageio.get_reader(video_path)
137
+ frames = []
138
+
139
+ for frame_data in reader:
140
+ frame = Image.fromarray(frame_data)
141
+ frame = self.crop_and_resize(frame)
142
+ frame = self.frame_process(frame)
143
+ frames.append(frame)
144
+
145
+ reader.close()
146
+
147
+ if len(frames) == 0:
148
+ return None
149
+
150
+ frames = torch.stack(frames, dim=0)
151
+ frames = rearrange(frames, "T C H W -> C T H W")
152
+ return frames
153
+
154
+ def encode_scenes(scenes_path, text_encoder_path, vae_path,output_dir):
155
+ """编码所有场景的视频"""
156
+
157
+ encoder = VideoEncoder(text_encoder_path, vae_path)
158
+ encoder = encoder.cuda()
159
+ encoder.pipe.device = "cuda"
160
+
161
+ processed_count = 0
162
+ processed_chunk_count = 0
163
+
164
+ metadata = pd.read_csv('/share_zhuyixuan05/public_datasets/SpatialVID-HQ/data/train/SpatialVID_HQ_metadata.csv')
165
+
166
+ os.makedirs(output_dir,exist_ok=True)
167
+ chunk_size = 300
168
+
169
+ for i, scene_name in enumerate(os.listdir(scenes_path)):
170
+ if i < 2:
171
+ continue
172
+ print('group:',i)
173
+ scene_dir = os.path.join(scenes_path, scene_name)
174
+
175
+ print('in:',scene_dir)
176
+ for j, video_name in tqdm(enumerate(os.listdir(scene_dir)),total=len(os.listdir(scene_dir))):
177
+ print(video_name)
178
+ video_path = os.path.join(scene_dir, video_name)
179
+ if not video_path.endswith(".mp4"):
180
+ continue
181
+
182
+ video_info = metadata[metadata['id'] == video_name[:-4]]
183
+ num_frames = video_info['num frames'].iloc[0]
184
+
185
+ scene_cam_dir = video_path.replace("videos","annotations")[:-4]
186
+ scene_cam_path = os.path.join(scene_cam_dir,'poses.npy')
187
+ scene_caption_path = os.path.join(scene_cam_dir,'caption.json')
188
+
189
+ with open(scene_caption_path, 'r', encoding='utf-8') as f:
190
+ caption_data = json.load(f)
191
+ caption = caption_data["SceneSummary"]
192
+
193
+ if not os.path.exists(scene_cam_path):
194
+ print(f"Pose not found: {scene_cam_path}")
195
+ continue
196
+
197
+ camera_poses = np.load(scene_cam_path)
198
+ cam_data_len = camera_poses.shape[0]
199
+
200
+ if not os.path.exists(video_path):
201
+ print(f"Video not found: {video_path}")
202
+ continue
203
+
204
+ video_name = video_name[:-4].split('_')[0]
205
+ start_frame = 0
206
+ end_frame = num_frames
207
+
208
+ cam_interval = end_frame // (cam_data_len - 1)
209
+
210
+ cam_frames = np.linspace(start_frame, end_frame, cam_data_len, endpoint=True)
211
+ cam_frames = np.round(cam_frames).astype(int)
212
+ cam_frames = cam_frames.tolist()
213
+
214
+ sampled_range = range(start_frame, end_frame, chunk_size)
215
+ sampled_frames = list(sampled_range)
216
+
217
+ print(f"Encoding scene {video_name}...")
218
+ chunk_count_in_one_video = 0
219
+
220
+ for sampled_chunk_start in sampled_frames:
221
+ if num_frames - sampled_chunk_start < 100:
222
+ continue
223
+
224
+ sampled_chunk_end = sampled_chunk_start + chunk_size
225
+ start_str = f"{sampled_chunk_start:07d}"
226
+ end_str = f"{sampled_chunk_end:07d}"
227
+
228
+ chunk_name = f"{video_name}_{start_str}_{end_str}"
229
+ save_chunk_dir = os.path.join(output_dir, chunk_name)
230
+ os.makedirs(save_chunk_dir, exist_ok=True)
231
+
232
+ print(f"Encoding chunk {chunk_name}...")
233
+
234
+ first_latent_path = os.path.join(save_chunk_dir, "first_latent.pth")
235
+
236
+ if os.path.exists(first_latent_path):
237
+ print(f"First latent for chunk {chunk_name} already exists, skipping...")
238
+ continue
239
+
240
+ # 只加载需要的那一帧
241
+ first_frame_idx = sampled_chunk_start
242
+ print(f"first_frame:{first_frame_idx}")
243
+ first_frame = encoder.load_single_frame(video_path, first_frame_idx)
244
+
245
+ if first_frame is None:
246
+ print(f"Failed to load frame {first_frame_idx} from: {video_path}")
247
+ continue
248
+
249
+ first_frame = first_frame.to("cuda", dtype=torch.bfloat16)
250
+
251
+ # 重复4次
252
+ repeated_first_frame = first_frame.repeat(1, 1, 4, 1, 1)
253
+ print(f"Repeated first frame shape: {repeated_first_frame.shape}")
254
+
255
+ with torch.no_grad():
256
+ first_latents = encoder.pipe.encode_video(repeated_first_frame, **encoder.tiler_kwargs)[0]
257
+
258
+ first_latent_data = {
259
+ "latents": first_latents.cpu(),
260
+ }
261
+ torch.save(first_latent_data, first_latent_path)
262
+ print(f"Saved first latent: {first_latent_path}")
263
+
264
+ processed_chunk_count += 1
265
+ chunk_count_in_one_video += 1
266
+
267
+ processed_count += 1
268
+ print("Encoded scene number:", processed_count)
269
+ print("Encoded chunk number:", processed_chunk_count)
270
+
271
+ print(f"Encoding completed! Processed {processed_count} scenes.")
272
+
273
+ if __name__ == "__main__":
274
+ parser = argparse.ArgumentParser()
275
+ parser.add_argument("--scenes_path", type=str, default="/share_zhuyixuan05/public_datasets/SpatialVID-HQ/SpatialVid/HQ/videos/")
276
+ parser.add_argument("--text_encoder_path", type=str,
277
+ default="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth")
278
+ parser.add_argument("--vae_path", type=str,
279
+ default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth")
280
+
281
+ parser.add_argument("--output_dir",type=str,
282
+ default="/share_zhuyixuan05/zhuyixuan05/spatialvid")
283
+
284
+ args = parser.parse_args()
285
+ encode_scenes(args.scenes_path, args.text_encoder_path, args.vae_path,args.output_dir)
scripts/hud_logo.py CHANGED
@@ -7,7 +7,7 @@ os.makedirs("wasd_ui", exist_ok=True)
7
  key_size = (48, 48)
8
  corner = 10
9
  bg_padding = 6
10
- font = ImageFont.truetype("arial.ttf", 28) # Replace with locally supported font
11
 
12
  def rounded_rect(im, bbox, radius, fill):
13
  draw = ImageDraw.Draw(im, "RGBA")
 
7
  key_size = (48, 48)
8
  corner = 10
9
  bg_padding = 6
10
+ font = ImageFont.truetype("arial.ttf", 28) # 替换成本地支持的字体
11
 
12
  def rounded_rect(im, bbox, radius, fill):
13
  draw = ImageDraw.Draw(im, "RGBA")
scripts/infer_demo.py CHANGED
@@ -1,7 +1,5 @@
1
  import os
2
  import sys
3
- from pathlib import Path
4
- from typing import Optional
5
 
6
  ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
7
  sys.path.append(ROOT_DIR)
@@ -12,158 +10,48 @@ import numpy as np
12
  from PIL import Image
13
  import imageio
14
  import json
15
- from diffsynth import WanVideoAstraPipeline, ModelManager
16
  import argparse
17
  from torchvision.transforms import v2
18
  from einops import rearrange
19
- from scipy.spatial.transform import Rotation as R
20
  import random
21
  import copy
22
  from datetime import datetime
23
 
24
- VALID_IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg"}
25
- class InlineVideoEncoder:
26
-
27
- def __init__(self, pipe: WanVideoAstraPipeline, device="cuda"):
28
- self.device = getattr(pipe, "device", device)
29
- self.tiler_kwargs = {"tiled": True, "tile_size": (34, 34), "tile_stride": (18, 16)}
30
- self.frame_process = v2.Compose([
31
- v2.ToTensor(),
32
- v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
33
- ])
34
-
35
- self.pipe = pipe
36
-
37
- @staticmethod
38
- def _crop_and_resize(image: Image.Image) -> Image.Image:
39
- target_w, target_h = 832, 480
40
- return v2.functional.resize(
41
- image,
42
- (round(target_h), round(target_w)),
43
- interpolation=v2.InterpolationMode.BILINEAR,
44
- )
45
-
46
- def preprocess_frame(self, image: Image.Image) -> torch.Tensor:
47
- image = image.convert("RGB")
48
- image = self._crop_and_resize(image)
49
- return self.frame_process(image)
50
-
51
- def load_video_frames(self, video_path: Path) -> Optional[torch.Tensor]:
52
- reader = imageio.get_reader(str(video_path))
53
- frames = []
54
- for frame_data in reader:
55
- frame = Image.fromarray(frame_data)
56
- frames.append(self.preprocess_frame(frame))
57
- reader.close()
58
-
59
- if not frames:
60
- return None
61
-
62
- frames = torch.stack(frames, dim=0)
63
- return rearrange(frames, "T C H W -> C T H W")
64
-
65
- def encode_frames_to_latents(self, frames: torch.Tensor) -> torch.Tensor:
66
- frames = frames.unsqueeze(0).to(self.device, dtype=torch.bfloat16)
67
- with torch.no_grad():
68
- latents = self.pipe.encode_video(frames, **self.tiler_kwargs)[0]
69
-
70
- if latents.dim() == 5 and latents.shape[0] == 1:
71
- latents = latents.squeeze(0)
72
- return latents.cpu()
73
-
74
- def image_to_frame_stack(
75
- image_path: Path,
76
- encoder: InlineVideoEncoder,
77
- repeat_count: int = 10
78
- ) -> torch.Tensor:
79
- """Repeat a single image into a tensor with specified number of frames, shape [C, T, H, W]"""
80
- if image_path.suffix.lower() not in VALID_IMAGE_EXTENSIONS:
81
- raise ValueError(f"Unsupported image format: {image_path.suffix}")
82
-
83
- image = Image.open(str(image_path))
84
- frame = encoder.preprocess_frame(image)
85
- frames = torch.stack([frame for _ in range(repeat_count)], dim=0)
86
- return rearrange(frames, "T C H W -> C T H W")
87
-
88
-
89
- def load_or_encode_condition(
90
- condition_pth_path: Optional[str],
91
- condition_video: Optional[str],
92
- condition_image: Optional[str],
93
- start_frame: int,
94
- num_frames: int,
95
- device: str,
96
- pipe: WanVideoAstraPipeline,
97
- ) -> tuple[torch.Tensor, dict]:
98
- if condition_pth_path:
99
- return load_encoded_video_from_pth(condition_pth_path, start_frame, num_frames)
100
-
101
- encoder = InlineVideoEncoder(pipe=pipe, device=device)
102
-
103
- if condition_video:
104
- video_path = Path(condition_video).expanduser().resolve()
105
- if not video_path.exists():
106
- raise FileNotFoundError(f"File not Found: {video_path}")
107
- frames = encoder.load_video_frames(video_path)
108
- if frames is None:
109
- raise ValueError(f"no valid frames in {video_path}")
110
- elif condition_image:
111
- image_path = Path(condition_image).expanduser().resolve()
112
- if not image_path.exists():
113
- raise FileNotFoundError(f"File not Found: {image_path}")
114
- frames = image_to_frame_stack(image_path, encoder, repeat_count=10)
115
- else:
116
- raise ValueError("condition video or image is needed for video generation.")
117
-
118
- latents = encoder.encode_frames_to_latents(frames)
119
- encoded_data = {"latents": latents}
120
-
121
- if start_frame + num_frames > latents.shape[1]:
122
- raise ValueError(
123
- f"Not enough frames after encoding: requested {start_frame + num_frames}, available {latents.shape[1]}"
124
- )
125
-
126
- condition_latents = latents[:, start_frame:start_frame + num_frames, :, :]
127
- return condition_latents, encoded_data
128
-
129
-
130
-
131
  def compute_relative_pose_matrix(pose1, pose2):
132
  """
133
- Compute relative pose between two consecutive frames, return 3x4 camera matrix [R_rel | t_rel]
134
 
135
- Args:
136
- pose1: Camera pose of frame i, shape (7,) array [tx1, ty1, tz1, qx1, qy1, qz1, qw1]
137
- pose2: Camera pose of frame i+1, shape (7,) array [tx2, ty2, tz2, qx2, qy2, qz2, qw2]
138
 
139
- Returns:
140
- relative_matrix: 3x4 relative pose matrix,
141
- first 3 columns are rotation matrix R_rel,
142
- last column is translation vector t_rel
143
  """
144
- # Separate translation vector and quaternion
145
- t1 = pose1[:3] # Translation of frame i [tx1, ty1, tz1]
146
- q1 = pose1[3:] # Quaternion of frame i [qx1, qy1, qz1, qw1]
147
- t2 = pose2[:3] # Translation of frame i+1
148
- q2 = pose2[3:] # Quaternion of frame i+1
149
-
150
- # 1. Compute relative rotation matrix R_rel
151
- rot1 = R.from_quat(q1) # Rotation of frame i
152
- rot2 = R.from_quat(q2) # Rotation of frame i+1
153
- rot_rel = rot2 * rot1.inv() # Relative rotation = next frame rotation × inverse of current frame rotation
154
- R_rel = rot_rel.as_matrix() # Convert to 3x3 matrix
155
-
156
- # 2. Compute relative translation vector t_rel
157
- R1_T = rot1.as_matrix().T # Transpose of current frame rotation matrix (equivalent to inverse)
158
- t_rel = R1_T @ (t2 - t1) # Relative translation = R1^T × (t2 - t1)
159
-
160
- # 3. Combine into 3x4 matrix [R_rel | t_rel]
161
  relative_matrix = np.hstack([R_rel, t_rel.reshape(3, 1)])
162
 
163
  return relative_matrix
164
 
165
  def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
166
- """Load pre-encoded video data from pth file"""
167
  print(f"Loading encoded video from {pth_path}")
168
 
169
  encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu")
@@ -180,10 +68,11 @@ def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
180
 
181
  return condition_latents, encoded_data
182
 
 
183
  def compute_relative_pose(pose_a, pose_b, use_torch=False):
184
- """Compute relative pose matrix of camera B with respect to camera A"""
185
- assert pose_a.shape == (4, 4), f"Camera A extrinsic matrix should be (4,4), got {pose_a.shape}"
186
- assert pose_b.shape == (4, 4), f"Camera B extrinsic matrix should be (4,4), got {pose_b.shape}"
187
 
188
  if use_torch:
189
  if not isinstance(pose_a, torch.Tensor):
@@ -206,7 +95,7 @@ def compute_relative_pose(pose_a, pose_b, use_torch=False):
206
 
207
 
208
  def replace_dit_model_in_manager():
209
- """Replace DiT model class with MoE version"""
210
  from diffsynth.models.wan_video_dit_moe import WanModelMoe
211
  from diffsynth.configs.model_config import model_loader_configs
212
 
@@ -221,7 +110,7 @@ def replace_dit_model_in_manager():
221
  if name == 'wan_video_dit':
222
  new_model_names.append(name)
223
  new_model_classes.append(WanModelMoe)
224
- print(f"Replaced model class: {name} -> WanModelMoe")
225
  else:
226
  new_model_names.append(name)
227
  new_model_classes.append(cls)
@@ -230,7 +119,7 @@ def replace_dit_model_in_manager():
230
 
231
 
232
  def add_framepack_components(dit_model):
233
- """Add FramePack related components"""
234
  if not hasattr(dit_model, 'clean_x_embedder'):
235
  inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]
236
 
@@ -257,37 +146,37 @@ def add_framepack_components(dit_model):
257
  dit_model.clean_x_embedder = CleanXEmbedder(inner_dim)
258
  model_dtype = next(dit_model.parameters()).dtype
259
  dit_model.clean_x_embedder = dit_model.clean_x_embedder.to(dtype=model_dtype)
260
- print("Added FramePack clean_x_embedder component")
261
 
262
 
263
  def add_moe_components(dit_model, moe_config):
264
- """Add MoE related components - corrected version"""
265
  if not hasattr(dit_model, 'moe_config'):
266
  dit_model.moe_config = moe_config
267
- print("Added MoE config to model")
268
  dit_model.top_k = moe_config.get("top_k", 1)
269
 
270
- # Dynamically add MoE components for each block
271
  dim = dit_model.blocks[0].self_attn.q.weight.shape[0]
272
  unified_dim = moe_config.get("unified_dim", 25)
273
  num_experts = moe_config.get("num_experts", 4)
274
  from diffsynth.models.wan_video_dit_moe import ModalityProcessor, MultiModalMoE
275
  dit_model.sekai_processor = ModalityProcessor("sekai", 13, unified_dim)
276
  dit_model.nuscenes_processor = ModalityProcessor("nuscenes", 8, unified_dim)
277
- dit_model.openx_processor = ModalityProcessor("openx", 13, unified_dim) # OpenX uses 13-dim input, similar to sekai but handled independently
278
  dit_model.global_router = nn.Linear(unified_dim, num_experts)
279
 
280
 
281
  for i, block in enumerate(dit_model.blocks):
282
- # MoE network - input unified_dim, output dim
283
  block.moe = MultiModalMoE(
284
  unified_dim=unified_dim,
285
- output_dim=dim, # Output dimension matches transformer block dim
286
  num_experts=moe_config.get("num_experts", 4),
287
  top_k=moe_config.get("top_k", 2)
288
  )
289
 
290
- print(f"Block {i} added MoE component (unified_dim: {unified_dim}, experts: {moe_config.get('num_experts', 4)})")
291
 
292
 
293
  def generate_sekai_camera_embeddings_sliding(
@@ -299,45 +188,45 @@ def generate_sekai_camera_embeddings_sliding(
299
  use_real_poses=True,
300
  direction="left"):
301
  """
302
- Generate camera embeddings for Sekai dataset - sliding window version
303
 
304
  Args:
305
- cam_data: Dictionary containing Sekai camera extrinsic parameters, key 'extrinsic' corresponds to an N*4*4 numpy array
306
- start_frame: Current generation start frame index
307
- initial_condition_frames: Initial condition frame count
308
- new_frames: Number of new frames to generate this time
309
- total_generated: Total frames already generated
310
- use_real_poses: Whether to use real Sekai camera poses
311
- direction: Camera movement direction, default "left"
312
 
313
  Returns:
314
- camera_embedding: Torch tensor of shape (M, 3*4 + 1), where M is the total number of generated frames
315
  """
316
  time_compression_ratio = 4
317
 
318
- # Calculate the actual number of camera frames needed for FramePack
319
- # 1 initial frame + 16 frames 4x + 2 frames 2x + 1 frame 1x + new_frames
320
  framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames
321
 
322
  if use_real_poses and cam_data is not None and 'extrinsic' in cam_data:
323
- print("🔧 Using real Sekai camera data")
324
  cam_extrinsic = cam_data['extrinsic']
325
 
326
- # Ensure generating a sufficiently long camera sequence
327
  max_needed_frames = max(
328
  start_frame + initial_condition_frames + new_frames,
329
  framepack_needed_frames,
330
  30
331
  )
332
 
333
- print(f"🔧 Calculating Sekai camera sequence length:")
334
- print(f" - Basic requirement: {start_frame + initial_condition_frames + new_frames}")
335
- print(f" - FramePack requirement: {framepack_needed_frames}")
336
- print(f" - Final generation: {max_needed_frames}")
337
 
338
  relative_poses = []
339
  for i in range(max_needed_frames):
340
- # Calculate the position of the current frame in the original sequence
341
  frame_idx = i * time_compression_ratio
342
  next_frame_idx = frame_idx + time_compression_ratio
343
 
@@ -347,72 +236,52 @@ def generate_sekai_camera_embeddings_sliding(
347
  relative_pose = compute_relative_pose(cam_prev, cam_next)
348
  relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
349
  else:
350
- # Out of range, use zero motion
351
- print(f"⚠️ Frame {frame_idx} exceeds camera data range, using zero motion")
352
  relative_poses.append(torch.zeros(3, 4))
353
 
354
  pose_embedding = torch.stack(relative_poses, dim=0)
355
  pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
356
 
357
- # Create mask sequence of corresponding length
358
  mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
359
- # Mark from start_frame to start_frame+initial_condition_frames as condition
360
  condition_end = min(start_frame + initial_condition_frames, max_needed_frames)
361
  mask[start_frame:condition_end] = 1.0
362
 
363
  camera_embedding = torch.cat([pose_embedding, mask], dim=1)
364
- print(f"🔧 Sekai real camera embedding shape: {camera_embedding.shape}")
365
  return camera_embedding.to(torch.bfloat16)
366
 
367
  else:
368
- # Ensure generating a sufficiently long camera sequence
369
  max_needed_frames = max(
370
  start_frame + initial_condition_frames + new_frames,
371
  framepack_needed_frames,
372
  30)
373
 
374
- print(f"🔧 Generating Sekai synthetic camera frames: {max_needed_frames}")
375
 
376
  CONDITION_FRAMES = initial_condition_frames
377
  STAGE_1 = new_frames//2
378
  STAGE_2 = new_frames - STAGE_1
379
 
380
- if direction=="forward":
381
- print("--------------- FORWARD MODE ---------------")
382
- relative_poses = []
383
- for i in range(max_needed_frames):
384
- if i < CONDITION_FRAMES:
385
- # Input condition frames default to zero motion camera pose
386
- pose = np.eye(4, dtype=np.float32)
387
- elif i < CONDITION_FRAMES+STAGE_1+STAGE_2:
388
- # Forward
389
- forward_speed = 0.03
390
-
391
- pose = np.eye(4, dtype=np.float32)
392
- pose[2, 3] = -forward_speed
393
- else:
394
- # The part beyond condition frames and target frames remains stationary
395
- pose = np.eye(4, dtype=np.float32)
396
-
397
- relative_pose = pose[:3, :]
398
- relative_poses.append(torch.as_tensor(relative_pose))
399
-
400
- elif direction=="left":
401
  print("--------------- LEFT TURNING MODE ---------------")
402
  relative_poses = []
403
  for i in range(max_needed_frames):
404
  if i < CONDITION_FRAMES:
405
- # Input condition frames default to zero motion camera pose
406
  pose = np.eye(4, dtype=np.float32)
407
  elif i < CONDITION_FRAMES+STAGE_1+STAGE_2:
408
- # Left turn
409
  yaw_per_frame = 0.03
410
 
411
- # Rotation matrix
412
  cos_yaw = np.cos(yaw_per_frame)
413
  sin_yaw = np.sin(yaw_per_frame)
414
 
415
- # Forward
416
  forward_speed = 0.00
417
 
418
  pose = np.eye(4, dtype=np.float32)
@@ -423,7 +292,7 @@ def generate_sekai_camera_embeddings_sliding(
423
  pose[2, 2] = cos_yaw
424
  pose[2, 3] = -forward_speed
425
  else:
426
- # The part beyond condition frames and target frames remains stationary
427
  pose = np.eye(4, dtype=np.float32)
428
 
429
  relative_pose = pose[:3, :]
@@ -434,17 +303,17 @@ def generate_sekai_camera_embeddings_sliding(
434
  relative_poses = []
435
  for i in range(max_needed_frames):
436
  if i < CONDITION_FRAMES:
437
- # Input condition frames default to zero motion camera pose
438
  pose = np.eye(4, dtype=np.float32)
439
  elif i < CONDITION_FRAMES+STAGE_1+STAGE_2:
440
- # Right turn
441
  yaw_per_frame = -0.03
442
 
443
- # Rotation matrix
444
  cos_yaw = np.cos(yaw_per_frame)
445
  sin_yaw = np.sin(yaw_per_frame)
446
 
447
- # Forward
448
  forward_speed = 0.00
449
 
450
  pose = np.eye(4, dtype=np.float32)
@@ -455,7 +324,7 @@ def generate_sekai_camera_embeddings_sliding(
455
  pose[2, 2] = cos_yaw
456
  pose[2, 3] = -forward_speed
457
  else:
458
- # The part beyond condition frames and target frames remains stationary
459
  pose = np.eye(4, dtype=np.float32)
460
 
461
  relative_pose = pose[:3, :]
@@ -466,17 +335,17 @@ def generate_sekai_camera_embeddings_sliding(
466
  relative_poses = []
467
  for i in range(max_needed_frames):
468
  if i < CONDITION_FRAMES:
469
- # Input condition frames default to zero motion camera pose
470
  pose = np.eye(4, dtype=np.float32)
471
  elif i < CONDITION_FRAMES+STAGE_1+STAGE_2:
472
- # Left turn
473
  yaw_per_frame = 0.03
474
 
475
- # Rotation matrix
476
  cos_yaw = np.cos(yaw_per_frame)
477
  sin_yaw = np.sin(yaw_per_frame)
478
 
479
- # Forward
480
  forward_speed = 0.03
481
 
482
  pose = np.eye(4, dtype=np.float32)
@@ -488,7 +357,7 @@ def generate_sekai_camera_embeddings_sliding(
488
  pose[2, 3] = -forward_speed
489
 
490
  else:
491
- # The part beyond condition frames and target frames remains stationary
492
  pose = np.eye(4, dtype=np.float32)
493
 
494
  relative_pose = pose[:3, :]
@@ -499,17 +368,17 @@ def generate_sekai_camera_embeddings_sliding(
499
  relative_poses = []
500
  for i in range(max_needed_frames):
501
  if i < CONDITION_FRAMES:
502
- # Input condition frames default to zero motion camera pose
503
  pose = np.eye(4, dtype=np.float32)
504
  elif i < CONDITION_FRAMES+STAGE_1+STAGE_2:
505
- # Right turn
506
  yaw_per_frame = -0.03
507
 
508
- # Rotation matrix
509
  cos_yaw = np.cos(yaw_per_frame)
510
  sin_yaw = np.sin(yaw_per_frame)
511
 
512
- # Forward
513
  forward_speed = 0.03
514
 
515
  pose = np.eye(4, dtype=np.float32)
@@ -521,7 +390,7 @@ def generate_sekai_camera_embeddings_sliding(
521
  pose[2, 3] = -forward_speed
522
 
523
  else:
524
- # The part beyond condition frames and target frames remains stationary
525
  pose = np.eye(4, dtype=np.float32)
526
 
527
  relative_pose = pose[:3, :]
@@ -532,17 +401,17 @@ def generate_sekai_camera_embeddings_sliding(
532
  relative_poses = []
533
  for i in range(max_needed_frames):
534
  if i < CONDITION_FRAMES:
535
- # Input condition frames default to zero motion camera pose
536
  pose = np.eye(4, dtype=np.float32)
537
  elif i < CONDITION_FRAMES+STAGE_1:
538
- # Left turn
539
  yaw_per_frame = 0.03
540
 
541
- # Rotation matrix
542
  cos_yaw = np.cos(yaw_per_frame)
543
  sin_yaw = np.sin(yaw_per_frame)
544
 
545
- # Forward
546
  forward_speed = 0.03
547
 
548
  pose = np.eye(4, dtype=np.float32)
@@ -554,16 +423,16 @@ def generate_sekai_camera_embeddings_sliding(
554
  pose[2, 3] = -forward_speed
555
 
556
  elif i < CONDITION_FRAMES+STAGE_1+STAGE_2:
557
- # Right turn
558
  yaw_per_frame = -0.03
559
 
560
- # Rotation matrix
561
  cos_yaw = np.cos(yaw_per_frame)
562
  sin_yaw = np.sin(yaw_per_frame)
563
 
564
- # Forward
565
  forward_speed = 0.03
566
- # Slight left drift to maintain inertia
567
  if i < CONDITION_FRAMES+STAGE_1+STAGE_2//3:
568
  radius_shift = -0.01
569
  else:
@@ -579,7 +448,7 @@ def generate_sekai_camera_embeddings_sliding(
579
  pose[0, 3] = radius_shift
580
 
581
  else:
582
- # The part beyond condition frames and target frames remains stationary
583
  pose = np.eye(4, dtype=np.float32)
584
 
585
  relative_pose = pose[:3, :]
@@ -590,17 +459,17 @@ def generate_sekai_camera_embeddings_sliding(
590
  relative_poses = []
591
  for i in range(max_needed_frames):
592
  if i < CONDITION_FRAMES:
593
- # Input condition frames default to zero motion camera pose
594
  pose = np.eye(4, dtype=np.float32)
595
  elif i < CONDITION_FRAMES+STAGE_1:
596
- # Left turn
597
  yaw_per_frame = 0.03
598
 
599
- # Rotation matrix
600
  cos_yaw = np.cos(yaw_per_frame)
601
  sin_yaw = np.sin(yaw_per_frame)
602
 
603
- # Forward
604
  forward_speed = 0.00
605
 
606
  pose = np.eye(4, dtype=np.float32)
@@ -612,14 +481,14 @@ def generate_sekai_camera_embeddings_sliding(
612
  pose[2, 3] = -forward_speed
613
 
614
  elif i < CONDITION_FRAMES+STAGE_1+STAGE_2:
615
- # Right turn
616
  yaw_per_frame = -0.03
617
 
618
- # Rotation matrix
619
  cos_yaw = np.cos(yaw_per_frame)
620
  sin_yaw = np.sin(yaw_per_frame)
621
 
622
- # Forward
623
  forward_speed = 0.00
624
 
625
  pose = np.eye(4, dtype=np.float32)
@@ -631,55 +500,55 @@ def generate_sekai_camera_embeddings_sliding(
631
  pose[2, 3] = -forward_speed
632
 
633
  else:
634
- # The part beyond condition frames and target frames remains stationary
635
  pose = np.eye(4, dtype=np.float32)
636
 
637
  relative_pose = pose[:3, :]
638
  relative_poses.append(torch.as_tensor(relative_pose))
639
 
640
  else:
641
- raise ValueError(f"Not Defined Direction: {direction}")
642
 
643
  pose_embedding = torch.stack(relative_poses, dim=0)
644
  pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
645
 
646
- # Create mask sequence of corresponding length
647
  mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
648
  condition_end = min(start_frame + initial_condition_frames + 1, max_needed_frames)
649
  mask[start_frame:condition_end] = 1.0
650
 
651
  camera_embedding = torch.cat([pose_embedding, mask], dim=1)
652
- print(f"🔧 Sekai synthetic camera embedding shape: {camera_embedding.shape}")
653
  return camera_embedding.to(torch.bfloat16)
654
 
655
 
656
  def generate_openx_camera_embeddings_sliding(
657
  encoded_data, start_frame, initial_condition_frames, new_frames, use_real_poses):
658
- """Generate camera embeddings for OpenX dataset - sliding window version"""
659
  time_compression_ratio = 4
660
 
661
- # Calculate the actual number of camera frames needed for FramePack
662
  framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames
663
 
664
  if use_real_poses and encoded_data is not None and 'cam_emb' in encoded_data and 'extrinsic' in encoded_data['cam_emb']:
665
- print("🔧 Using OpenX real camera data")
666
  cam_extrinsic = encoded_data['cam_emb']['extrinsic']
667
 
668
- # Ensure generating a sufficiently long camera sequence
669
  max_needed_frames = max(
670
  start_frame + initial_condition_frames + new_frames,
671
  framepack_needed_frames,
672
  30
673
  )
674
 
675
- print(f"🔧 Calculating OpenX camera sequence length:")
676
- print(f" - Basic requirement: {start_frame + initial_condition_frames + new_frames}")
677
- print(f" - FramePack requirement: {framepack_needed_frames}")
678
- print(f" - Final generation: {max_needed_frames}")
679
 
680
  relative_poses = []
681
  for i in range(max_needed_frames):
682
- # OpenX uses 4x interval, similar to sekai but handles shorter sequences
683
  frame_idx = i * time_compression_ratio
684
  next_frame_idx = frame_idx + time_compression_ratio
685
 
@@ -689,25 +558,25 @@ def generate_openx_camera_embeddings_sliding(
689
  relative_pose = compute_relative_pose(cam_prev, cam_next)
690
  relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
691
  else:
692
- # Out of range, use zero motion
693
- print(f"⚠️ Frame {frame_idx} exceeds OpenX camera data range, using zero motion")
694
  relative_poses.append(torch.zeros(3, 4))
695
 
696
  pose_embedding = torch.stack(relative_poses, dim=0)
697
  pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
698
 
699
- # Create mask sequence of corresponding length
700
  mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
701
- # Mark from start_frame to start_frame + initial_condition_frames as condition
702
  condition_end = min(start_frame + initial_condition_frames, max_needed_frames)
703
  mask[start_frame:condition_end] = 1.0
704
 
705
  camera_embedding = torch.cat([pose_embedding, mask], dim=1)
706
- print(f"🔧 OpenX real camera embedding shape: {camera_embedding.shape}")
707
  return camera_embedding.to(torch.bfloat16)
708
 
709
  else:
710
- print("🔧 Using OpenX synthetic camera data")
711
 
712
  max_needed_frames = max(
713
  start_frame + initial_condition_frames + new_frames,
@@ -715,30 +584,30 @@ def generate_openx_camera_embeddings_sliding(
715
  30
716
  )
717
 
718
- print(f"🔧 Generating OpenX synthetic camera frames: {max_needed_frames}")
719
  relative_poses = []
720
  for i in range(max_needed_frames):
721
- # OpenX robot operation motion mode - smaller motion amplitude
722
- # Simulate fine operation motion of robot arm
723
- roll_per_frame = 0.02 # Slight roll
724
- pitch_per_frame = 0.01 # Slight pitch
725
- yaw_per_frame = 0.015 # Slight yaw
726
- forward_speed = 0.003 # Slower forward speed
727
 
728
  pose = np.eye(4, dtype=np.float32)
729
 
730
- # Compound rotation - simulate complex motion of robot arm
731
- # Rotate around X-axis (roll)
732
  cos_roll = np.cos(roll_per_frame)
733
  sin_roll = np.sin(roll_per_frame)
734
- # Rotate around Y-axis (pitch)
735
  cos_pitch = np.cos(pitch_per_frame)
736
  sin_pitch = np.sin(pitch_per_frame)
737
- # Rotate around Z-axis (yaw)
738
  cos_yaw = np.cos(yaw_per_frame)
739
  sin_yaw = np.sin(yaw_per_frame)
740
 
741
- # Simplified compound rotation matrix (ZYX order)
742
  pose[0, 0] = cos_yaw * cos_pitch
743
  pose[0, 1] = cos_yaw * sin_pitch * sin_roll - sin_yaw * cos_roll
744
  pose[0, 2] = cos_yaw * sin_pitch * cos_roll + sin_yaw * sin_roll
@@ -749,10 +618,10 @@ def generate_openx_camera_embeddings_sliding(
749
  pose[2, 1] = cos_pitch * sin_roll
750
  pose[2, 2] = cos_pitch * cos_roll
751
 
752
- # Translation - simulate fine movement of robot operation
753
- pose[0, 3] = forward_speed * 0.5 # Slight movement in X direction
754
- pose[1, 3] = forward_speed * 0.3 # Slight movement in Y direction
755
- pose[2, 3] = -forward_speed # Main movement in Z direction (depth)
756
 
757
  relative_pose = pose[:3, :]
758
  relative_poses.append(torch.as_tensor(relative_pose))
@@ -760,34 +629,30 @@ def generate_openx_camera_embeddings_sliding(
760
  pose_embedding = torch.stack(relative_poses, dim=0)
761
  pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
762
 
763
- # Create mask sequence of corresponding length
764
  mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
765
  condition_end = min(start_frame + initial_condition_frames, max_needed_frames)
766
  mask[start_frame:condition_end] = 1.0
767
 
768
  camera_embedding = torch.cat([pose_embedding, mask], dim=1)
769
- print(f"🔧 OpenX synthetic camera embedding shape: {camera_embedding.shape}")
770
  return camera_embedding.to(torch.bfloat16)
771
 
772
 
773
  def generate_nuscenes_camera_embeddings_sliding(
774
  scene_info, start_frame, initial_condition_frames, new_frames):
775
- """
776
- Generate camera embeddings for NuScenes dataset - sliding window version
777
-
778
- corrected version, consistent with train_moe.py
779
- """
780
  time_compression_ratio = 4
781
 
782
- # Calculate the actual number of camera frames needed for FramePack
783
  framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames
784
 
785
  if scene_info is not None and 'keyframe_poses' in scene_info:
786
- print("🔧 Using NuScenes real pose data")
787
  keyframe_poses = scene_info['keyframe_poses']
788
 
789
  if len(keyframe_poses) == 0:
790
- print("⚠️ NuScenes keyframe_poses is empty, using zero pose")
791
  max_needed_frames = max(framepack_needed_frames, 30)
792
 
793
  pose_sequence = torch.zeros(max_needed_frames, 7, dtype=torch.float32)
@@ -797,10 +662,10 @@ def generate_nuscenes_camera_embeddings_sliding(
797
  mask[start_frame:condition_end] = 1.0
798
 
799
  camera_embedding = torch.cat([pose_sequence, mask], dim=1) # [max_needed_frames, 8]
800
- print(f"🔧 NuScenes zero pose embedding shape: {camera_embedding.shape}")
801
  return camera_embedding.to(torch.bfloat16)
802
 
803
- # Use first pose as reference
804
  reference_pose = keyframe_poses[0]
805
 
806
  max_needed_frames = max(framepack_needed_frames, 30)
@@ -810,18 +675,18 @@ def generate_nuscenes_camera_embeddings_sliding(
810
  if i < len(keyframe_poses):
811
  current_pose = keyframe_poses[i]
812
 
813
- # Calculate relative displacement
814
  translation = torch.tensor(
815
  np.array(current_pose['translation']) - np.array(reference_pose['translation']),
816
  dtype=torch.float32
817
  )
818
 
819
- # Calculate relative rotation (simplified version)
820
  rotation = torch.tensor(current_pose['rotation'], dtype=torch.float32)
821
 
822
  pose_vec = torch.cat([translation, rotation], dim=0) # [7D]
823
  else:
824
- # Out of range, use zero pose
825
  pose_vec = torch.cat([
826
  torch.zeros(3, dtype=torch.float32),
827
  torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32)
@@ -831,41 +696,41 @@ def generate_nuscenes_camera_embeddings_sliding(
831
 
832
  pose_sequence = torch.stack(pose_vecs, dim=0) # [max_needed_frames, 7]
833
 
834
- # Create mask
835
  mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
836
  condition_end = min(start_frame + initial_condition_frames, max_needed_frames)
837
  mask[start_frame:condition_end] = 1.0
838
 
839
  camera_embedding = torch.cat([pose_sequence, mask], dim=1) # [max_needed_frames, 8]
840
- print(f"🔧 NuScenes real pose embedding shape: {camera_embedding.shape}")
841
  return camera_embedding.to(torch.bfloat16)
842
 
843
  else:
844
- print("🔧 Using NuScenes synthetic pose data")
845
  max_needed_frames = max(framepack_needed_frames, 30)
846
 
847
- # Create synthetic motion sequence
848
  pose_vecs = []
849
  for i in range(max_needed_frames):
850
- # Left turn motion mode - similar to left turns in city driving
851
- angle = i * 0.04 # Rotate 0.08 radians per frame (slightly slower turn)
852
- radius = 15.0 # Larger turning radius, more suitable for car turns
853
 
854
- # Calculate position on circular arc trajectory
855
  x = radius * np.sin(angle)
856
- y = 0.0 # Keep horizontal plane motion
857
  z = radius * (1 - np.cos(angle))
858
 
859
  translation = torch.tensor([x, y, z], dtype=torch.float32)
860
 
861
- # Vehicle orientation - always along trajectory tangent direction
862
- yaw = angle + np.pi/2 # Yaw angle relative to initial forward direction
863
- # Quaternion representation of rotation around Y-axis
864
  rotation = torch.tensor([
865
- np.cos(yaw/2), # w (real part)
866
  0.0, # x
867
  0.0, # y
868
- np.sin(yaw/2) # z (imaginary part, around Y-axis)
869
  ], dtype=torch.float32)
870
 
871
  pose_vec = torch.cat([translation, rotation], dim=0) # [7D: tx,ty,tz,qw,qx,qy,qz]
@@ -873,15 +738,15 @@ def generate_nuscenes_camera_embeddings_sliding(
873
 
874
  pose_sequence = torch.stack(pose_vecs, dim=0)
875
 
876
- # Create mask
877
  mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
878
  condition_end = min(start_frame + initial_condition_frames, max_needed_frames)
879
  mask[start_frame:condition_end] = 1.0
880
 
881
  camera_embedding = torch.cat([pose_sequence, mask], dim=1) # [max_needed_frames, 8]
882
- print(f"🔧 NuScenes synthetic left turn pose embedding shape: {camera_embedding.shape}")
883
  return camera_embedding.to(torch.bfloat16)
884
-
885
  def prepare_framepack_sliding_window_with_camera_moe(
886
  history_latents,
887
  target_frames_to_generate,
@@ -889,12 +754,12 @@ def prepare_framepack_sliding_window_with_camera_moe(
889
  start_frame,
890
  modality_type,
891
  max_history_frames=49):
892
- """FramePack sliding window mechanism - MoE version"""
893
- # history_latents: [C, T, H, W] current history latents
894
  C, T, H, W = history_latents.shape
895
 
896
- # Fixed index structure (this determines the number of camera frames needed)
897
- # 1 start frame + 16 frames 4x + 2 frames 2x + 1 frame 1x + target_frames_to_generate
898
  total_indices_length = 1 + 16 + 2 + 1 + target_frames_to_generate
899
  indices = torch.arange(0, total_indices_length)
900
  split_sizes = [1, 16, 2, 1, target_frames_to_generate]
@@ -902,44 +767,44 @@ def prepare_framepack_sliding_window_with_camera_moe(
902
  indices.split(split_sizes, dim=0)
903
  clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=0)
904
 
905
- # Check if camera length is sufficient
906
  if camera_embedding_full.shape[0] < total_indices_length:
907
- print(f"⚠️ camera_embedding length insufficient, performing zero padding: current length {camera_embedding_full.shape[0]}, required length {total_indices_length}")
908
  shortage = total_indices_length - camera_embedding_full.shape[0]
909
  padding = torch.zeros(shortage, camera_embedding_full.shape[1],
910
  dtype=camera_embedding_full.dtype, device=camera_embedding_full.device)
911
  camera_embedding_full = torch.cat([camera_embedding_full, padding], dim=0)
912
 
913
- # Select corresponding part from complete camera sequence
914
  combined_camera = torch.zeros(
915
  total_indices_length,
916
  camera_embedding_full.shape[1],
917
  dtype=camera_embedding_full.dtype,
918
  device=camera_embedding_full.device)
919
 
920
- # Camera poses for historical condition frames
921
  history_slice = camera_embedding_full[max(T - 19, 0):T, :].clone()
922
  combined_camera[19 - history_slice.shape[0]:19, :] = history_slice
923
 
924
- # Camera poses for target frames
925
  target_slice = camera_embedding_full[T:T + target_frames_to_generate, :].clone()
926
  combined_camera[19:19 + target_slice.shape[0], :] = target_slice
927
 
928
- # Reset mask according to current history length
929
- combined_camera[:, -1] = 0.0 # First set all to target (0)
930
 
931
- # Set condition mask: first 19 frames determined by actual history length
932
  if T > 0:
933
  available_frames = min(T, 19)
934
  start_pos = 19 - available_frames
935
- combined_camera[start_pos:19, -1] = 1.0 # Mark cameras corresponding to valid clean latents as condition
936
 
937
- print(f"🔧 MoE Camera mask update:")
938
- print(f" - History frames: {T}")
939
- print(f" - Valid condition frames: {available_frames if T > 0 else 0}")
940
- print(f" - Modality type: {modality_type}")
941
 
942
- # Process latents
943
  clean_latents_combined = torch.zeros(C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device)
944
 
945
  if T > 0:
@@ -967,54 +832,54 @@ def prepare_framepack_sliding_window_with_camera_moe(
967
  'clean_latent_2x_indices': clean_latent_2x_indices,
968
  'clean_latent_4x_indices': clean_latent_4x_indices,
969
  'camera_embedding': combined_camera,
970
- 'modality_type': modality_type, # Added modality type information
971
  'current_length': T,
972
  'next_length': T + target_frames_to_generate
973
  }
974
 
975
  def overlay_controls(frame_img, pose_vec, icons):
976
  """
977
- Overlay control icons (WASD and arrows) on frame based on camera pose
978
- pose_vec: 12 elements (flattened 3x4 matrix) + mask
979
  """
980
  if pose_vec is None or np.all(pose_vec[:12] == 0):
981
  return frame_img
982
 
983
- # Extract translation vector (based on flattened 3x4 matrix indices)
984
  # [r00, r01, r02, tx, r10, r11, r12, ty, r20, r21, r22, tz]
985
  tx = pose_vec[3]
986
  # ty = pose_vec[7]
987
  tz = pose_vec[11]
988
 
989
- # Extract rotation (yaw and pitch)
990
- # Yaw: around Y axis. sin(yaw) = r02, cos(yaw) = r00
991
  r00 = pose_vec[0]
992
  r02 = pose_vec[2]
993
  yaw = np.arctan2(r02, r00)
994
 
995
- # Pitch: around X axis. sin(pitch) = -r12, cos(pitch) = r22
996
  r12 = pose_vec[6]
997
  r22 = pose_vec[10]
998
  pitch = np.arctan2(-r12, r22)
999
 
1000
- # Threshold for key activation
1001
  TRANS_THRESH = 0.01
1002
  ROT_THRESH = 0.005
1003
 
1004
- # Determine key states
1005
- # Translation (WASD)
1006
- # Assume -Z is forward, +X is right
1007
  is_forward = tz < -TRANS_THRESH
1008
  is_backward = tz > TRANS_THRESH
1009
  is_left = tx < -TRANS_THRESH
1010
  is_right = tx > TRANS_THRESH
1011
 
1012
- # Rotation (arrows)
1013
- # Yaw: + is left, - is right
1014
  is_turn_left = yaw > ROT_THRESH
1015
  is_turn_right = yaw < -ROT_THRESH
1016
 
1017
- # Pitch: + is down, - is up
1018
  is_turn_up = pitch < -ROT_THRESH
1019
  is_turn_down = pitch > ROT_THRESH
1020
 
@@ -1025,10 +890,10 @@ def overlay_controls(frame_img, pose_vec, icons):
1025
  name = name_active if is_active else name_inactive
1026
  if name in icons:
1027
  icon = icons[name]
1028
- # Paste using alpha channel
1029
  frame_img.paste(icon, (int(x), int(y)), icon)
1030
 
1031
- # Overlay WASD (bottom left)
1032
  base_x_right = 100
1033
  base_y = H - 100
1034
 
@@ -1041,7 +906,7 @@ def overlay_controls(frame_img, pose_vec, icons):
1041
  # D
1042
  paste_icon('move_right.png', 'not_move_right.png', is_right, base_x_right + spacing, base_y)
1043
 
1044
- # Overlay arrows (bottom right)
1045
  base_x_left = W - 150
1046
 
1047
  # ↑
@@ -1057,11 +922,8 @@ def overlay_controls(frame_img, pose_vec, icons):
1057
 
1058
 
1059
  def inference_moe_framepack_sliding_window(
1060
- condition_pth_path=None,
1061
- condition_video=None,
1062
- condition_image=None,
1063
- dit_path=None,
1064
- wan_model_path=None,
1065
  output_path="../examples/output_videos/output_moe_framepack_sliding.mp4",
1066
  start_frame=0,
1067
  initial_condition_frames=8,
@@ -1070,14 +932,14 @@ def inference_moe_framepack_sliding_window(
1070
  max_history_frames=49,
1071
  device="cuda",
1072
  prompt="A video of a scene shot using a pedestrian's front camera while walking",
1073
- modality_type="sekai", # "sekai" or "nuscenes"
1074
  use_real_poses=True,
1075
- scene_info_path=None, # For NuScenes dataset
1076
- # CFG parameters
1077
  use_camera_cfg=True,
1078
  camera_guidance_scale=2.0,
1079
  text_guidance_scale=1.0,
1080
- # MoE parameters
1081
  moe_num_experts=4,
1082
  moe_top_k=2,
1083
  moe_hidden_dim=None,
@@ -1086,30 +948,30 @@ def inference_moe_framepack_sliding_window(
1086
  add_icons=False
1087
  ):
1088
  """
1089
- MoE FramePack sliding window video generation - multi-modal support
1090
  """
1091
- # Create output directory
1092
  dir_path = os.path.dirname(output_path)
1093
  os.makedirs(dir_path, exist_ok=True)
1094
 
1095
- print(f"🔧 Starting MoE FramePack sliding window generation...")
1096
- print(f" Modality type: {modality_type}")
1097
- print(f" Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}")
1098
- print(f" Text guidance scale: {text_guidance_scale}")
1099
- print(f" MoE config: experts={moe_num_experts}, top_k={moe_top_k}")
1100
 
1101
- # 1. Model initialization
1102
  replace_dit_model_in_manager()
1103
 
1104
  model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
1105
  model_manager.load_models([
1106
- os.path.join(wan_model_path, "diffusion_pytorch_model.safetensors"),
1107
- os.path.join(wan_model_path, "models_t5_umt5-xxl-enc-bf16.pth"),
1108
- os.path.join(wan_model_path, "Wan2.1_VAE.pth"),
1109
  ])
1110
- pipe = WanVideoAstraPipeline.from_model_manager(model_manager, device="cuda")
1111
 
1112
- # 2. Add traditional camera encoder (compatibility)
1113
  dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
1114
  for block in pipe.dit.blocks:
1115
  block.cam_encoder = nn.Linear(13, dim)
@@ -1119,45 +981,41 @@ def inference_moe_framepack_sliding_window(
1119
  block.projector.weight = nn.Parameter(torch.eye(dim))
1120
  block.projector.bias = nn.Parameter(torch.zeros(dim))
1121
 
1122
- # 3. Add FramePack components
1123
  add_framepack_components(pipe.dit)
1124
 
1125
- # 4. Add MoE components
1126
  moe_config = {
1127
  "num_experts": moe_num_experts,
1128
  "top_k": moe_top_k,
1129
  "hidden_dim": moe_hidden_dim or dim * 2,
1130
- "sekai_input_dim": 13, # Sekai: 12-dim pose + 1-dim mask
1131
- "nuscenes_input_dim": 8, # NuScenes: 7-dim pose + 1-dim mask
1132
- "openx_input_dim": 13 # OpenX: 12-dim pose + 1-dim mask (similar to sekai)
1133
  }
1134
  add_moe_components(pipe.dit, moe_config)
1135
 
1136
- # 5. Load trained weights
1137
  dit_state_dict = torch.load(dit_path, map_location="cpu")
1138
- pipe.dit.load_state_dict(dit_state_dict, strict=False) # Use strict=False to be compatible with newly added MoE components
1139
  pipe = pipe.to(device)
1140
  model_dtype = next(pipe.dit.parameters()).dtype
1141
 
1142
  if hasattr(pipe.dit, 'clean_x_embedder'):
1143
  pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)
1144
 
1145
- # Set denoising steps
1146
  pipe.scheduler.set_timesteps(50)
1147
 
1148
- # 6. Load initial conditions
1149
  print("Loading initial condition frames...")
1150
- initial_latents, encoded_data = load_or_encode_condition(
1151
- condition_pth_path,
1152
- condition_video,
1153
- condition_image,
1154
- start_frame,
1155
- initial_condition_frames,
1156
- device,
1157
- pipe,
1158
  )
1159
 
1160
- # Spatial cropping
1161
  target_height, target_width = 60, 104
1162
  C, T, H, W = initial_latents.shape
1163
 
@@ -1169,50 +1027,50 @@ def inference_moe_framepack_sliding_window(
1169
 
1170
  history_latents = initial_latents.to(device, dtype=model_dtype)
1171
 
1172
- print(f"Initial history_latents shape: {history_latents.shape}")
1173
 
1174
- # 7. Encode prompt - support CFG
1175
  if use_gt_prompt and 'prompt_emb' in encoded_data:
1176
- print("✅ Using pre-encoded GT prompt embedding")
1177
  prompt_emb_pos = encoded_data['prompt_emb']
1178
- # Move prompt_emb to correct device and dtype
1179
  if 'context' in prompt_emb_pos:
1180
  prompt_emb_pos['context'] = prompt_emb_pos['context'].to(device, dtype=model_dtype)
1181
  if 'context_mask' in prompt_emb_pos:
1182
  prompt_emb_pos['context_mask'] = prompt_emb_pos['context_mask'].to(device, dtype=model_dtype)
1183
 
1184
- # Generate negative prompt if using Text CFG
1185
  if text_guidance_scale > 1.0:
1186
  prompt_emb_neg = pipe.encode_prompt("")
1187
- print(f"Using Text CFG with GT prompt, guidance scale: {text_guidance_scale}")
1188
  else:
1189
  prompt_emb_neg = None
1190
- print("Not using Text CFG")
1191
 
1192
- # Print GT prompt text if available
1193
  if 'prompt' in encoded_data['prompt_emb']:
1194
  gt_prompt_text = encoded_data['prompt_emb']['prompt']
1195
- print(f"📝 GT Prompt text: {gt_prompt_text}")
1196
  else:
1197
- # Re-encode using provided prompt parameter
1198
- print(f"🔄 Re-encoding prompt: {prompt}")
1199
  if text_guidance_scale > 1.0:
1200
  prompt_emb_pos = pipe.encode_prompt(prompt)
1201
  prompt_emb_neg = pipe.encode_prompt("")
1202
- print(f"Using Text CFG, guidance scale: {text_guidance_scale}")
1203
  else:
1204
  prompt_emb_pos = pipe.encode_prompt(prompt)
1205
  prompt_emb_neg = None
1206
- print("Not using Text CFG")
1207
 
1208
- # 8. Load scene information (for NuScenes)
1209
  scene_info = None
1210
  if modality_type == "nuscenes" and scene_info_path and os.path.exists(scene_info_path):
1211
  with open(scene_info_path, 'r') as f:
1212
  scene_info = json.load(f)
1213
- print(f"Loading NuScenes scene information: {scene_info_path}")
1214
 
1215
- # 9. Pre-generate complete camera embedding sequence
1216
  if modality_type == "sekai":
1217
  camera_embedding_full = generate_sekai_camera_embeddings_sliding(
1218
  encoded_data.get('cam_emb', None),
@@ -1239,25 +1097,25 @@ def inference_moe_framepack_sliding_window(
1239
  use_real_poses=use_real_poses
1240
  ).to(device, dtype=model_dtype)
1241
  else:
1242
- raise ValueError(f"Unsupported modality type: {modality_type}")
1243
 
1244
- print(f"Complete camera sequence shape: {camera_embedding_full.shape}")
1245
 
1246
- # 10. Create unconditional camera embedding for Camera CFG
1247
  if use_camera_cfg:
1248
  camera_embedding_uncond = torch.zeros_like(camera_embedding_full)
1249
- print(f"Creating unconditional camera embedding for CFG")
1250
 
1251
- # 11. Sliding window generation loop
1252
  total_generated = 0
1253
  all_generated_frames = []
1254
 
1255
  while total_generated < total_frames_to_generate:
1256
  current_generation = min(frames_per_generation, total_frames_to_generate - total_generated)
1257
- print(f"\nGeneration step {total_generated // frames_per_generation + 1}")
1258
- print(f"Current history length: {history_latents.shape[1]}, generating: {current_generation}")
1259
 
1260
- # FramePack data preparation - MoE version
1261
  framepack_data = prepare_framepack_sliding_window_with_camera_moe(
1262
  history_latents,
1263
  current_generation,
@@ -1267,27 +1125,27 @@ def inference_moe_framepack_sliding_window(
1267
  max_history_frames
1268
  )
1269
 
1270
- # Prepare input
1271
  clean_latents = framepack_data['clean_latents'].unsqueeze(0)
1272
  clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0)
1273
  clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0)
1274
  camera_embedding = framepack_data['camera_embedding'].unsqueeze(0)
1275
 
1276
- # Prepare modality_inputs
1277
  modality_inputs = {modality_type: camera_embedding}
1278
 
1279
- # Prepare unconditional camera embedding for CFG
1280
  if use_camera_cfg:
1281
  camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0)
1282
  modality_inputs_uncond = {modality_type: camera_embedding_uncond_batch}
1283
 
1284
- # Index processing
1285
  latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu()
1286
  clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu()
1287
  clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu()
1288
  clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu()
1289
 
1290
- # Initialize latents to generate
1291
  new_latents = torch.randn(
1292
  1, C, current_generation, H, W,
1293
  device=device, dtype=model_dtype
@@ -1296,26 +1154,26 @@ def inference_moe_framepack_sliding_window(
1296
  extra_input = pipe.prepare_extra_input(new_latents)
1297
 
1298
  print(f"Camera embedding shape: {camera_embedding.shape}")
1299
- print(f"Camera mask distribution - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}")
1300
 
1301
- # Denoising loop - supports CFG
1302
  timesteps = pipe.scheduler.timesteps
1303
 
1304
  for i, timestep in enumerate(timesteps):
1305
  if i % 10 == 0:
1306
- print(f" Denoising step {i+1}/{len(timesteps)}")
1307
 
1308
  timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)
1309
 
1310
  with torch.no_grad():
1311
- # CFG inference
1312
  if use_camera_cfg and camera_guidance_scale > 1.0:
1313
- # Conditional prediction (with camera)
1314
  noise_pred_cond, moe_loess = pipe.dit(
1315
  new_latents,
1316
  timestep=timestep_tensor,
1317
  cam_emb=camera_embedding,
1318
- modality_inputs=modality_inputs, # MoE modality input
1319
  latent_indices=latent_indices,
1320
  clean_latents=clean_latents,
1321
  clean_latent_indices=clean_latent_indices,
@@ -1327,12 +1185,12 @@ def inference_moe_framepack_sliding_window(
1327
  **extra_input
1328
  )
1329
 
1330
- # Unconditional prediction (no camera)
1331
  noise_pred_uncond, moe_loess = pipe.dit(
1332
  new_latents,
1333
  timestep=timestep_tensor,
1334
  cam_emb=camera_embedding_uncond_batch,
1335
- modality_inputs=modality_inputs_uncond, # MoE unconditional modality input
1336
  latent_indices=latent_indices,
1337
  clean_latents=clean_latents,
1338
  clean_latent_indices=clean_latent_indices,
@@ -1347,7 +1205,7 @@ def inference_moe_framepack_sliding_window(
1347
  # Camera CFG
1348
  noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_cond - noise_pred_uncond)
1349
 
1350
- # If using Text CFG at the same time
1351
  if text_guidance_scale > 1.0 and prompt_emb_neg:
1352
  noise_pred_text_uncond, moe_loess = pipe.dit(
1353
  new_latents,
@@ -1365,11 +1223,11 @@ def inference_moe_framepack_sliding_window(
1365
  **extra_input
1366
  )
1367
 
1368
- # Apply Text CFG to results that have already applied Camera CFG
1369
  noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond)
1370
 
1371
  elif text_guidance_scale > 1.0 and prompt_emb_neg:
1372
- # Use Text CFG only
1373
  noise_pred_cond, moe_loess = pipe.dit(
1374
  new_latents,
1375
  timestep=timestep_tensor,
@@ -1405,12 +1263,12 @@ def inference_moe_framepack_sliding_window(
1405
  noise_pred = noise_pred_uncond + text_guidance_scale * (noise_pred_cond - noise_pred_uncond)
1406
 
1407
  else:
1408
- # Standard inference (no CFG)
1409
  noise_pred, moe_loess = pipe.dit(
1410
  new_latents,
1411
  timestep=timestep_tensor,
1412
  cam_emb=camera_embedding,
1413
- modality_inputs=modality_inputs, # MoE modality input
1414
  latent_indices=latent_indices,
1415
  clean_latents=clean_latents,
1416
  clean_latent_indices=clean_latent_indices,
@@ -1424,31 +1282,31 @@ def inference_moe_framepack_sliding_window(
1424
 
1425
  new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents)
1426
 
1427
- # Update history
1428
  new_latents_squeezed = new_latents.squeeze(0)
1429
  history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1)
1430
 
1431
- # Maintain sliding window
1432
  if history_latents.shape[1] > max_history_frames:
1433
  first_frame = history_latents[:, 0:1, :, :]
1434
  recent_frames = history_latents[:, -(max_history_frames-1):, :, :]
1435
  history_latents = torch.cat([first_frame, recent_frames], dim=1)
1436
- print(f"⚠️ History window full, keeping first frame + latest {max_history_frames-1} frames")
1437
 
1438
- print(f"History_latents shape after update: {history_latents.shape}")
1439
 
1440
  all_generated_frames.append(new_latents_squeezed)
1441
  total_generated += current_generation
1442
 
1443
- print(f"✅ Generated {total_generated}/{total_frames_to_generate} frames")
1444
 
1445
- # 12. Decode and save
1446
- print("\nDecoding generated video...")
1447
 
1448
  all_generated = torch.cat(all_generated_frames, dim=1)
1449
  final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0)
1450
 
1451
- print(f"Final video shape: {final_video.shape}")
1452
 
1453
  decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))
1454
 
@@ -1461,7 +1319,7 @@ def inference_moe_framepack_sliding_window(
1461
  icons = {}
1462
  video_camera_poses = None
1463
  if add_icons:
1464
- # Load icon resources for overlay
1465
  icons_dir = os.path.join(ROOT_DIR, 'icons')
1466
  icon_names = ['move_forward.png', 'not_move_forward.png',
1467
  'move_backward.png', 'not_move_backward.png',
@@ -1476,15 +1334,15 @@ def inference_moe_framepack_sliding_window(
1476
  if os.path.exists(path):
1477
  try:
1478
  icon = Image.open(path).convert("RGBA")
1479
- # Adjust icon size
1480
  icon = icon.resize((50, 50), Image.Resampling.LANCZOS)
1481
  icons[name] = icon
1482
  except Exception as e:
1483
  print(f"Error loading icon {name}: {e}")
1484
  else:
1485
- print(f"⚠️ Warning: Icon {name} not found at {path}")
1486
 
1487
- # Get camera poses corresponding to video frames
1488
  time_compression_ratio = 4
1489
  camera_poses = camera_embedding_full.detach().float().cpu().numpy()
1490
  video_camera_poses = [x for x in camera_poses for _ in range(time_compression_ratio)]
@@ -1503,108 +1361,74 @@ def inference_moe_framepack_sliding_window(
1503
 
1504
  writer.append_data(np.array(img))
1505
 
1506
- print(f" MoE FramePack sliding window generation completed! Saved to: {output_path}")
1507
- print(f" Total generated {total_generated} frames (compressed), corresponding to original {total_generated * 4} frames")
1508
- print(f" Using modality: {modality_type}")
1509
 
1510
 
1511
  def main():
1512
- parser = argparse.ArgumentParser(description="MoE FramePack sliding window video generation - supports multi-modal")
1513
-
1514
- # Basic parameters
1515
- parser.add_argument("--condition_pth",
1516
- type=str,
1517
- default=None,
1518
- help="Path to pre-encoded condition pth file")
1519
- parser.add_argument("--condition_video",
1520
- type=str,
1521
- default=None,
1522
- help="Input video for novel view synthesis.")
1523
- parser.add_argument("--condition_image",
1524
- type=str,
1525
- default=None,
1526
- required=True,
1527
- help="Input image for novel view synthesis.")
1528
  parser.add_argument("--start_frame", type=int, default=0)
1529
  parser.add_argument("--initial_condition_frames", type=int, default=1)
1530
  parser.add_argument("--frames_per_generation", type=int, default=8)
1531
  parser.add_argument("--total_frames_to_generate", type=int, default=24)
1532
  parser.add_argument("--max_history_frames", type=int, default=100)
1533
  parser.add_argument("--use_real_poses", default=False)
1534
- parser.add_argument("--dit_path", type=str,
1535
- default="../models/Astra/checkpoints/diffusion_pytorch_model.ckpt",
1536
  help="path to the pretrained DiT MoE model checkpoint")
1537
- parser.add_argument("--wan_model_path",
1538
- type=str,
1539
- default="../models/Wan-AI/Wan2.1-T2V-1.3B",
1540
- help="path to Wan2.1-T2V-1.3B")
1541
  parser.add_argument("--output_path", type=str,
1542
- default='../examples/output_videos/output_moe_framepack_sliding.mp4')
1543
- parser.add_argument("--prompt",
1544
- type=str,
1545
- default="",
1546
  help="text prompt for video generation")
1547
  parser.add_argument("--device", type=str, default="cuda")
1548
  parser.add_argument("--add_icons", action="store_true", default=False,
1549
- help="Overlay control icons on generated video")
1550
 
1551
- # Modality type parameters
1552
  parser.add_argument("--modality_type", type=str, choices=["sekai", "nuscenes", "openx"],
1553
- default="sekai", help="Modality type: sekai, nuscenes, or openx")
1554
  parser.add_argument("--scene_info_path", type=str, default=None,
1555
- help="NuScenes scene info file path (for nuscenes modality only)")
1556
 
1557
- # CFG parameters
1558
  parser.add_argument("--use_camera_cfg", default=False,
1559
- help="Use Camera CFG")
1560
  parser.add_argument("--camera_guidance_scale", type=float, default=2.0,
1561
  help="Camera guidance scale for CFG")
1562
  parser.add_argument("--text_guidance_scale", type=float, default=1.0,
1563
  help="Text guidance scale for CFG")
1564
 
1565
- # MoE parameters
1566
- parser.add_argument("--moe_num_experts", type=int, default=3, help="Number of experts")
1567
- parser.add_argument("--moe_top_k", type=int, default=1, help="Top-K experts")
1568
- parser.add_argument("--moe_hidden_dim", type=int, default=None, help="MoE hidden dimension")
1569
- parser.add_argument("--direction", type=str, default="left", help="Direction of video trajectory")
1570
  parser.add_argument("--use_gt_prompt", action="store_true", default=False,
1571
- help="Use ground truth prompt embedding from dataset")
1572
 
1573
  args = parser.parse_args()
1574
 
1575
- print(f"MoE FramePack CFG generation settings:")
1576
- print(f"Modality type: {args.modality_type}")
1577
  print(f"Camera CFG: {args.use_camera_cfg}")
1578
  if args.use_camera_cfg:
1579
  print(f"Camera guidance scale: {args.camera_guidance_scale}")
1580
- print(f"Using GT Prompt: {args.use_gt_prompt}")
1581
  print(f"Text guidance scale: {args.text_guidance_scale}")
1582
- print(f"MoE config: experts={args.moe_num_experts}, top_k={args.moe_top_k}")
1583
  print(f"DiT{args.dit_path}")
1584
 
1585
- # Validate NuScenes parameters
1586
  if args.modality_type == "nuscenes" and not args.scene_info_path:
1587
- print("⚠️ Warning: Using NuScenes modality but scene_info_path not provided, will use synthetic pose data")
1588
-
1589
- if not args.use_gt_prompt and (args.prompt is None or args.prompt.strip() == ""):
1590
- print("⚠️ Warning: No prompt provided, will use empty string as prompt")
1591
-
1592
- if not any([args.condition_pth, args.condition_video, args.condition_image]):
1593
- raise ValueError("Need to provide condition_pth, condition_video, or condition_image as condition input")
1594
-
1595
- if args.condition_pth:
1596
- print(f"Using pre-encoded pth: {args.condition_pth}")
1597
- elif args.condition_video:
1598
- print(f"Using condition video for online encoding: {args.condition_video}")
1599
- elif args.condition_image:
1600
- print(f"Using condition image for online encoding: {args.condition_image} (repeat 10 frames)")
1601
 
1602
  inference_moe_framepack_sliding_window(
1603
  condition_pth_path=args.condition_pth,
1604
- condition_video=args.condition_video,
1605
- condition_image=args.condition_image,
1606
  dit_path=args.dit_path,
1607
- wan_model_path=args.wan_model_path,
1608
  output_path=args.output_path,
1609
  start_frame=args.start_frame,
1610
  initial_condition_frames=args.initial_condition_frames,
@@ -1616,11 +1440,11 @@ def main():
1616
  modality_type=args.modality_type,
1617
  use_real_poses=args.use_real_poses,
1618
  scene_info_path=args.scene_info_path,
1619
- # CFG parameters
1620
  use_camera_cfg=args.use_camera_cfg,
1621
  camera_guidance_scale=args.camera_guidance_scale,
1622
  text_guidance_scale=args.text_guidance_scale,
1623
- # MoE parameters
1624
  moe_num_experts=args.moe_num_experts,
1625
  moe_top_k=args.moe_top_k,
1626
  moe_hidden_dim=args.moe_hidden_dim,
 
1
  import os
2
  import sys
 
 
3
 
4
  ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
5
  sys.path.append(ROOT_DIR)
 
10
  from PIL import Image
11
  import imageio
12
  import json
13
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
14
  import argparse
15
  from torchvision.transforms import v2
16
  from einops import rearrange
 
17
  import random
18
  import copy
19
  from datetime import datetime
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  def compute_relative_pose_matrix(pose1, pose2):
22
  """
23
+ 计算相邻两帧的相对位姿,返回3×4的相机矩阵 [R_rel | t_rel]
24
 
25
+ 参数:
26
+ pose1: i帧的相机位姿,形状为(7,)的数组 [tx1, ty1, tz1, qx1, qy1, qz1, qw1]
27
+ pose2: i+1帧的相机位姿,形状为(7,)的数组 [tx2, ty2, tz2, qx2, qy2, qz2, qw2]
28
 
29
+ 返回:
30
+ relative_matrix: 3×4的相对位姿矩阵,前3列是旋转矩阵R_rel,第4列是平移向量t_rel
 
 
31
  """
32
+ # 分离平移向量和四元数
33
+ t1 = pose1[:3] # i帧平移 [tx1, ty1, tz1]
34
+ q1 = pose1[3:] # i帧四元数 [qx1, qy1, qz1, qw1]
35
+ t2 = pose2[:3] # i+1帧平移
36
+ q2 = pose2[3:] # i+1帧四元数
37
+
38
+ # 1. 计算相对旋转矩阵 R_rel
39
+ rot1 = R.from_quat(q1) # i帧旋转
40
+ rot2 = R.from_quat(q2) # i+1帧旋转
41
+ rot_rel = rot2 * rot1.inv() # 相对旋转 = 后一帧旋转 × 前一帧旋转的逆
42
+ R_rel = rot_rel.as_matrix() # 转换为3×3矩阵
43
+
44
+ # 2. 计算相对平移向量 t_rel
45
+ R1_T = rot1.as_matrix().T # 前一帧旋转矩阵的转置(等价于逆)
46
+ t_rel = R1_T @ (t2 - t1) # 相对平移 = R1^T × (t2 - t1)
47
+
48
+ # 3. 组合为3×4矩阵 [R_rel | t_rel]
49
  relative_matrix = np.hstack([R_rel, t_rel.reshape(3, 1)])
50
 
51
  return relative_matrix
52
 
53
  def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
54
+ """pth文件加载预编码的视频数据"""
55
  print(f"Loading encoded video from {pth_path}")
56
 
57
  encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu")
 
68
 
69
  return condition_latents, encoded_data
70
 
71
+
72
  def compute_relative_pose(pose_a, pose_b, use_torch=False):
73
+ """计算相机B相对于相机A的相对位姿矩阵"""
74
+ assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
75
+ assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"
76
 
77
  if use_torch:
78
  if not isinstance(pose_a, torch.Tensor):
 
95
 
96
 
97
  def replace_dit_model_in_manager():
98
+ """替换DiT模型类为MoE版本"""
99
  from diffsynth.models.wan_video_dit_moe import WanModelMoe
100
  from diffsynth.configs.model_config import model_loader_configs
101
 
 
110
  if name == 'wan_video_dit':
111
  new_model_names.append(name)
112
  new_model_classes.append(WanModelMoe)
113
+ print(f" 替换了模型类: {name} -> WanModelMoe")
114
  else:
115
  new_model_names.append(name)
116
  new_model_classes.append(cls)
 
119
 
120
 
121
  def add_framepack_components(dit_model):
122
+ """添加FramePack相关组件"""
123
  if not hasattr(dit_model, 'clean_x_embedder'):
124
  inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]
125
 
 
146
  dit_model.clean_x_embedder = CleanXEmbedder(inner_dim)
147
  model_dtype = next(dit_model.parameters()).dtype
148
  dit_model.clean_x_embedder = dit_model.clean_x_embedder.to(dtype=model_dtype)
149
+ print(" 添加了FramePackclean_x_embedder组件")
150
 
151
 
152
  def add_moe_components(dit_model, moe_config):
153
+ """🔧 添加MoE相关组件 - 修正版本"""
154
  if not hasattr(dit_model, 'moe_config'):
155
  dit_model.moe_config = moe_config
156
+ print(" 添加了MoE配置到模型")
157
  dit_model.top_k = moe_config.get("top_k", 1)
158
 
159
+ # 为每个block动态添加MoE组件
160
  dim = dit_model.blocks[0].self_attn.q.weight.shape[0]
161
  unified_dim = moe_config.get("unified_dim", 25)
162
  num_experts = moe_config.get("num_experts", 4)
163
  from diffsynth.models.wan_video_dit_moe import ModalityProcessor, MultiModalMoE
164
  dit_model.sekai_processor = ModalityProcessor("sekai", 13, unified_dim)
165
  dit_model.nuscenes_processor = ModalityProcessor("nuscenes", 8, unified_dim)
166
+ dit_model.openx_processor = ModalityProcessor("openx", 13, unified_dim) # OpenX使用13维输入,类似sekai但独立处理
167
  dit_model.global_router = nn.Linear(unified_dim, num_experts)
168
 
169
 
170
  for i, block in enumerate(dit_model.blocks):
171
+ # MoE网络 - 输入unified_dim,输出dim
172
  block.moe = MultiModalMoE(
173
  unified_dim=unified_dim,
174
+ output_dim=dim, # 输出维度匹配transformer blockdim
175
  num_experts=moe_config.get("num_experts", 4),
176
  top_k=moe_config.get("top_k", 2)
177
  )
178
 
179
+ print(f"Block {i} 添加了MoE组件 (unified_dim: {unified_dim}, experts: {moe_config.get('num_experts', 4)})")
180
 
181
 
182
  def generate_sekai_camera_embeddings_sliding(
 
188
  use_real_poses=True,
189
  direction="left"):
190
  """
191
+ 为Sekai数据集生成camera embeddings - 滑动窗口版本
192
 
193
  Args:
194
+ cam_data: 包含Sekai相机外参的字典, 'extrinsic'对应一个N*4*4numpy数组
195
+ start_frame: 当前生成起始帧索引
196
+ initial_condition_frames: 初始条件帧数
197
+ new_frames: 本次生成的新帧数
198
+ total_generated: 已生成的总帧数
199
+ use_real_poses: 是否使用真实的Sekai相机位姿
200
+ direction: 相机运动方向,默认为"left"
201
 
202
  Returns:
203
+ camera_embedding: 形状为(M, 3*4 + 1)的torch张量, M为生成的总帧数
204
  """
205
  time_compression_ratio = 4
206
 
207
+ # 计算FramePack实际需要的camera帧数
208
+ # 1帧初始 + 164x + 22x + 11x + new_frames
209
  framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames
210
 
211
  if use_real_poses and cam_data is not None and 'extrinsic' in cam_data:
212
+ print("🔧 使用真实Sekai camera数据")
213
  cam_extrinsic = cam_data['extrinsic']
214
 
215
+ # 确保生成足够长的camera序列
216
  max_needed_frames = max(
217
  start_frame + initial_condition_frames + new_frames,
218
  framepack_needed_frames,
219
  30
220
  )
221
 
222
+ print(f"🔧 计算Sekai camera序列长度:")
223
+ print(f" - 基础需求: {start_frame + initial_condition_frames + new_frames}")
224
+ print(f" - FramePack需求: {framepack_needed_frames}")
225
+ print(f" - 最终生成: {max_needed_frames}")
226
 
227
  relative_poses = []
228
  for i in range(max_needed_frames):
229
+ # 计算当前帧在原始序列中的位置
230
  frame_idx = i * time_compression_ratio
231
  next_frame_idx = frame_idx + time_compression_ratio
232
 
 
236
  relative_pose = compute_relative_pose(cam_prev, cam_next)
237
  relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
238
  else:
239
+ # 超出范围,使用零运动
240
+ print(f"⚠️ {frame_idx}超出camera数据范围,使用零运动")
241
  relative_poses.append(torch.zeros(3, 4))
242
 
243
  pose_embedding = torch.stack(relative_poses, dim=0)
244
  pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
245
 
246
+ # 创建对应长度的mask序列
247
  mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
248
+ # start_framestart_frame+initial_condition_frames标记为condition
249
  condition_end = min(start_frame + initial_condition_frames, max_needed_frames)
250
  mask[start_frame:condition_end] = 1.0
251
 
252
  camera_embedding = torch.cat([pose_embedding, mask], dim=1)
253
+ print(f"🔧 Sekai真实camera embedding shape: {camera_embedding.shape}")
254
  return camera_embedding.to(torch.bfloat16)
255
 
256
  else:
257
+ # 确保生成足够长的camera序列
258
  max_needed_frames = max(
259
  start_frame + initial_condition_frames + new_frames,
260
  framepack_needed_frames,
261
  30)
262
 
263
+ print(f"🔧 生成Sekai合成camera帧数: {max_needed_frames}")
264
 
265
  CONDITION_FRAMES = initial_condition_frames
266
  STAGE_1 = new_frames//2
267
  STAGE_2 = new_frames - STAGE_1
268
 
269
+ if direction=="left":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
  print("--------------- LEFT TURNING MODE ---------------")
271
  relative_poses = []
272
  for i in range(max_needed_frames):
273
  if i < CONDITION_FRAMES:
274
+ # 输入的条件帧默认的相机位姿为零运动
275
  pose = np.eye(4, dtype=np.float32)
276
  elif i < CONDITION_FRAMES+STAGE_1+STAGE_2:
277
+ # 左转
278
  yaw_per_frame = 0.03
279
 
280
+ # 旋转矩阵
281
  cos_yaw = np.cos(yaw_per_frame)
282
  sin_yaw = np.sin(yaw_per_frame)
283
 
284
+ # 前进
285
  forward_speed = 0.00
286
 
287
  pose = np.eye(4, dtype=np.float32)
 
292
  pose[2, 2] = cos_yaw
293
  pose[2, 3] = -forward_speed
294
  else:
295
+ # 超出条件帧与目标帧的部分,保持静止
296
  pose = np.eye(4, dtype=np.float32)
297
 
298
  relative_pose = pose[:3, :]
 
303
  relative_poses = []
304
  for i in range(max_needed_frames):
305
  if i < CONDITION_FRAMES:
306
+ # 输入的条件帧默认的相机位姿为零运动
307
  pose = np.eye(4, dtype=np.float32)
308
  elif i < CONDITION_FRAMES+STAGE_1+STAGE_2:
309
+ # 右转
310
  yaw_per_frame = -0.03
311
 
312
+ # 旋转矩阵
313
  cos_yaw = np.cos(yaw_per_frame)
314
  sin_yaw = np.sin(yaw_per_frame)
315
 
316
+ # 前进
317
  forward_speed = 0.00
318
 
319
  pose = np.eye(4, dtype=np.float32)
 
324
  pose[2, 2] = cos_yaw
325
  pose[2, 3] = -forward_speed
326
  else:
327
+ # 超出条件帧与目标帧的部分,保持静止
328
  pose = np.eye(4, dtype=np.float32)
329
 
330
  relative_pose = pose[:3, :]
 
335
  relative_poses = []
336
  for i in range(max_needed_frames):
337
  if i < CONDITION_FRAMES:
338
+ # 输入的条件帧默认的相机位姿为零运动
339
  pose = np.eye(4, dtype=np.float32)
340
  elif i < CONDITION_FRAMES+STAGE_1+STAGE_2:
341
+ # 左转
342
  yaw_per_frame = 0.03
343
 
344
+ # 旋转矩阵
345
  cos_yaw = np.cos(yaw_per_frame)
346
  sin_yaw = np.sin(yaw_per_frame)
347
 
348
+ # 前进
349
  forward_speed = 0.03
350
 
351
  pose = np.eye(4, dtype=np.float32)
 
357
  pose[2, 3] = -forward_speed
358
 
359
  else:
360
+ # 超出条件帧与目标帧的部分,保持静止
361
  pose = np.eye(4, dtype=np.float32)
362
 
363
  relative_pose = pose[:3, :]
 
368
  relative_poses = []
369
  for i in range(max_needed_frames):
370
  if i < CONDITION_FRAMES:
371
+ # 输入的条件帧默认的相机位姿为零运动
372
  pose = np.eye(4, dtype=np.float32)
373
  elif i < CONDITION_FRAMES+STAGE_1+STAGE_2:
374
+ # 右转
375
  yaw_per_frame = -0.03
376
 
377
+ # 旋转矩阵
378
  cos_yaw = np.cos(yaw_per_frame)
379
  sin_yaw = np.sin(yaw_per_frame)
380
 
381
+ # 前进
382
  forward_speed = 0.03
383
 
384
  pose = np.eye(4, dtype=np.float32)
 
390
  pose[2, 3] = -forward_speed
391
 
392
  else:
393
+ # 超出条件帧与目标帧的部分,保持静止
394
  pose = np.eye(4, dtype=np.float32)
395
 
396
  relative_pose = pose[:3, :]
 
401
  relative_poses = []
402
  for i in range(max_needed_frames):
403
  if i < CONDITION_FRAMES:
404
+ # 输入的条件帧默认的相机位姿为零运动
405
  pose = np.eye(4, dtype=np.float32)
406
  elif i < CONDITION_FRAMES+STAGE_1:
407
+ # 左转
408
  yaw_per_frame = 0.03
409
 
410
+ # 旋转矩阵
411
  cos_yaw = np.cos(yaw_per_frame)
412
  sin_yaw = np.sin(yaw_per_frame)
413
 
414
+ # 前进
415
  forward_speed = 0.03
416
 
417
  pose = np.eye(4, dtype=np.float32)
 
423
  pose[2, 3] = -forward_speed
424
 
425
  elif i < CONDITION_FRAMES+STAGE_1+STAGE_2:
426
+ # 右转
427
  yaw_per_frame = -0.03
428
 
429
+ # 旋转矩阵
430
  cos_yaw = np.cos(yaw_per_frame)
431
  sin_yaw = np.sin(yaw_per_frame)
432
 
433
+ # 前进
434
  forward_speed = 0.03
435
+ # 轻微向左漂移,保持惯性
436
  if i < CONDITION_FRAMES+STAGE_1+STAGE_2//3:
437
  radius_shift = -0.01
438
  else:
 
448
  pose[0, 3] = radius_shift
449
 
450
  else:
451
+ # 超出条件帧与目标帧的部分,保持静止
452
  pose = np.eye(4, dtype=np.float32)
453
 
454
  relative_pose = pose[:3, :]
 
459
  relative_poses = []
460
  for i in range(max_needed_frames):
461
  if i < CONDITION_FRAMES:
462
+ # 输入的条件帧默认的相机位姿为零运动
463
  pose = np.eye(4, dtype=np.float32)
464
  elif i < CONDITION_FRAMES+STAGE_1:
465
+ # 左转
466
  yaw_per_frame = 0.03
467
 
468
+ # 旋转矩阵
469
  cos_yaw = np.cos(yaw_per_frame)
470
  sin_yaw = np.sin(yaw_per_frame)
471
 
472
+ # 前进
473
  forward_speed = 0.00
474
 
475
  pose = np.eye(4, dtype=np.float32)
 
481
  pose[2, 3] = -forward_speed
482
 
483
  elif i < CONDITION_FRAMES+STAGE_1+STAGE_2:
484
+ # 右转
485
  yaw_per_frame = -0.03
486
 
487
+ # 旋转矩阵
488
  cos_yaw = np.cos(yaw_per_frame)
489
  sin_yaw = np.sin(yaw_per_frame)
490
 
491
+ # 前进
492
  forward_speed = 0.00
493
 
494
  pose = np.eye(4, dtype=np.float32)
 
500
  pose[2, 3] = -forward_speed
501
 
502
  else:
503
+ # 超出条件帧与目标帧的部分,保持静止
504
  pose = np.eye(4, dtype=np.float32)
505
 
506
  relative_pose = pose[:3, :]
507
  relative_poses.append(torch.as_tensor(relative_pose))
508
 
509
  else:
510
+ raise ValueError(f"未定义的相机运动方向: {direction}")
511
 
512
  pose_embedding = torch.stack(relative_poses, dim=0)
513
  pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
514
 
515
+ # 创建对应长度的mask序列
516
  mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
517
  condition_end = min(start_frame + initial_condition_frames + 1, max_needed_frames)
518
  mask[start_frame:condition_end] = 1.0
519
 
520
  camera_embedding = torch.cat([pose_embedding, mask], dim=1)
521
+ print(f"🔧 Sekai合成camera embedding shape: {camera_embedding.shape}")
522
  return camera_embedding.to(torch.bfloat16)
523
 
524
 
525
  def generate_openx_camera_embeddings_sliding(
526
  encoded_data, start_frame, initial_condition_frames, new_frames, use_real_poses):
527
+ """为OpenX数据集生成camera embeddings - 滑动窗口版本"""
528
  time_compression_ratio = 4
529
 
530
+ # 计算FramePack实际需要的camera帧数
531
  framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames
532
 
533
  if use_real_poses and encoded_data is not None and 'cam_emb' in encoded_data and 'extrinsic' in encoded_data['cam_emb']:
534
+ print("🔧 使用OpenX真实camera数据")
535
  cam_extrinsic = encoded_data['cam_emb']['extrinsic']
536
 
537
+ # 确保生成足够长的camera序列
538
  max_needed_frames = max(
539
  start_frame + initial_condition_frames + new_frames,
540
  framepack_needed_frames,
541
  30
542
  )
543
 
544
+ print(f"🔧 计算OpenX camera序列长度:")
545
+ print(f" - 基础需求: {start_frame + initial_condition_frames + new_frames}")
546
+ print(f" - FramePack需求: {framepack_needed_frames}")
547
+ print(f" - 最终生成: {max_needed_frames}")
548
 
549
  relative_poses = []
550
  for i in range(max_needed_frames):
551
+ # OpenX使用4倍间隔,类似sekai但处理更短的序列
552
  frame_idx = i * time_compression_ratio
553
  next_frame_idx = frame_idx + time_compression_ratio
554
 
 
558
  relative_pose = compute_relative_pose(cam_prev, cam_next)
559
  relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
560
  else:
561
+ # 超出范围,使用零运动
562
+ print(f"⚠️ {frame_idx}超出OpenX camera数据范围,使用零运动")
563
  relative_poses.append(torch.zeros(3, 4))
564
 
565
  pose_embedding = torch.stack(relative_poses, dim=0)
566
  pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
567
 
568
+ # 创建对应长度的mask序列
569
  mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
570
+ # start_framestart_frame + initial_condition_frames标记为condition
571
  condition_end = min(start_frame + initial_condition_frames, max_needed_frames)
572
  mask[start_frame:condition_end] = 1.0
573
 
574
  camera_embedding = torch.cat([pose_embedding, mask], dim=1)
575
+ print(f"🔧 OpenX真实camera embedding shape: {camera_embedding.shape}")
576
  return camera_embedding.to(torch.bfloat16)
577
 
578
  else:
579
+ print("🔧 使用OpenX合成camera数据")
580
 
581
  max_needed_frames = max(
582
  start_frame + initial_condition_frames + new_frames,
 
584
  30
585
  )
586
 
587
+ print(f"🔧 生成OpenX合成camera帧数: {max_needed_frames}")
588
  relative_poses = []
589
  for i in range(max_needed_frames):
590
+ # OpenX机器人操作运动模式 - 较小的运动幅度
591
+ # 模拟机器人手臂的精细操作运动
592
+ roll_per_frame = 0.02 # 轻微翻滚
593
+ pitch_per_frame = 0.01 # 轻微俯仰
594
+ yaw_per_frame = 0.015 # 轻微偏航
595
+ forward_speed = 0.003 # 较慢的前进速度
596
 
597
  pose = np.eye(4, dtype=np.float32)
598
 
599
+ # 复合旋转 - 模拟机器人手臂的复杂运动
600
+ # X轴旋转(roll
601
  cos_roll = np.cos(roll_per_frame)
602
  sin_roll = np.sin(roll_per_frame)
603
+ # Y轴旋转(pitch
604
  cos_pitch = np.cos(pitch_per_frame)
605
  sin_pitch = np.sin(pitch_per_frame)
606
+ # Z轴旋转(yaw
607
  cos_yaw = np.cos(yaw_per_frame)
608
  sin_yaw = np.sin(yaw_per_frame)
609
 
610
+ # 简化的复合旋转矩阵(ZYX顺序)
611
  pose[0, 0] = cos_yaw * cos_pitch
612
  pose[0, 1] = cos_yaw * sin_pitch * sin_roll - sin_yaw * cos_roll
613
  pose[0, 2] = cos_yaw * sin_pitch * cos_roll + sin_yaw * sin_roll
 
618
  pose[2, 1] = cos_pitch * sin_roll
619
  pose[2, 2] = cos_pitch * cos_roll
620
 
621
+ # 平移 - ���拟机器人操作的精细移动
622
+ pose[0, 3] = forward_speed * 0.5 # X方向轻微移动
623
+ pose[1, 3] = forward_speed * 0.3 # Y方向轻微移动
624
+ pose[2, 3] = -forward_speed # Z方向(深度)主要移动
625
 
626
  relative_pose = pose[:3, :]
627
  relative_poses.append(torch.as_tensor(relative_pose))
 
629
  pose_embedding = torch.stack(relative_poses, dim=0)
630
  pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
631
 
632
+ # 创建对应长度的mask序列
633
  mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
634
  condition_end = min(start_frame + initial_condition_frames, max_needed_frames)
635
  mask[start_frame:condition_end] = 1.0
636
 
637
  camera_embedding = torch.cat([pose_embedding, mask], dim=1)
638
+ print(f"🔧 OpenX合成camera embedding shape: {camera_embedding.shape}")
639
  return camera_embedding.to(torch.bfloat16)
640
 
641
 
642
  def generate_nuscenes_camera_embeddings_sliding(
643
  scene_info, start_frame, initial_condition_frames, new_frames):
644
+ """为NuScenes数据集生成camera embeddings - 滑动窗口版本 - 修正版,与train_moe.py保持一致"""
 
 
 
 
645
  time_compression_ratio = 4
646
 
647
+ # 计算FramePack实际需要的camera帧数
648
  framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames
649
 
650
  if scene_info is not None and 'keyframe_poses' in scene_info:
651
+ print("🔧 使用NuScenes真实pose数据")
652
  keyframe_poses = scene_info['keyframe_poses']
653
 
654
  if len(keyframe_poses) == 0:
655
+ print("⚠️ NuScenes keyframe_poses为空,使用零pose")
656
  max_needed_frames = max(framepack_needed_frames, 30)
657
 
658
  pose_sequence = torch.zeros(max_needed_frames, 7, dtype=torch.float32)
 
662
  mask[start_frame:condition_end] = 1.0
663
 
664
  camera_embedding = torch.cat([pose_sequence, mask], dim=1) # [max_needed_frames, 8]
665
+ print(f"🔧 NuScenespose embedding shape: {camera_embedding.shape}")
666
  return camera_embedding.to(torch.bfloat16)
667
 
668
+ # 使用第一个pose作为参考
669
  reference_pose = keyframe_poses[0]
670
 
671
  max_needed_frames = max(framepack_needed_frames, 30)
 
675
  if i < len(keyframe_poses):
676
  current_pose = keyframe_poses[i]
677
 
678
+ # 计算相对位移
679
  translation = torch.tensor(
680
  np.array(current_pose['translation']) - np.array(reference_pose['translation']),
681
  dtype=torch.float32
682
  )
683
 
684
+ # 计算相对旋转(简化版本)
685
  rotation = torch.tensor(current_pose['rotation'], dtype=torch.float32)
686
 
687
  pose_vec = torch.cat([translation, rotation], dim=0) # [7D]
688
  else:
689
+ # 超出范围,使用零pose
690
  pose_vec = torch.cat([
691
  torch.zeros(3, dtype=torch.float32),
692
  torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32)
 
696
 
697
  pose_sequence = torch.stack(pose_vecs, dim=0) # [max_needed_frames, 7]
698
 
699
+ # 创建mask
700
  mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
701
  condition_end = min(start_frame + initial_condition_frames, max_needed_frames)
702
  mask[start_frame:condition_end] = 1.0
703
 
704
  camera_embedding = torch.cat([pose_sequence, mask], dim=1) # [max_needed_frames, 8]
705
+ print(f"🔧 NuScenes真实pose embedding shape: {camera_embedding.shape}")
706
  return camera_embedding.to(torch.bfloat16)
707
 
708
  else:
709
+ print("🔧 使用NuScenes合成pose数据")
710
  max_needed_frames = max(framepack_needed_frames, 30)
711
 
712
+ # 创建合成运动序列
713
  pose_vecs = []
714
  for i in range(max_needed_frames):
715
+ # 左转运动模式 - 类似城市驾驶中的左转弯
716
+ angle = i * 0.04 # 每帧转动0.08弧度(稍微慢一点的转弯)
717
+ radius = 15.0 # 较大的转弯半径,更符合汽车转弯
718
 
719
+ # 计算圆弧轨迹上的位置
720
  x = radius * np.sin(angle)
721
+ y = 0.0 # 保持水平面运动
722
  z = radius * (1 - np.cos(angle))
723
 
724
  translation = torch.tensor([x, y, z], dtype=torch.float32)
725
 
726
+ # 车辆朝向 - 始终沿着轨迹切线方向
727
+ yaw = angle + np.pi/2 # 相对于初始前进方向的偏航角
728
+ # 四元数表示绕Y轴的旋转
729
  rotation = torch.tensor([
730
+ np.cos(yaw/2), # w (实部)
731
  0.0, # x
732
  0.0, # y
733
+ np.sin(yaw/2) # z (虚部,绕Y)
734
  ], dtype=torch.float32)
735
 
736
  pose_vec = torch.cat([translation, rotation], dim=0) # [7D: tx,ty,tz,qw,qx,qy,qz]
 
738
 
739
  pose_sequence = torch.stack(pose_vecs, dim=0)
740
 
741
+ # 创建mask
742
  mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
743
  condition_end = min(start_frame + initial_condition_frames, max_needed_frames)
744
  mask[start_frame:condition_end] = 1.0
745
 
746
  camera_embedding = torch.cat([pose_sequence, mask], dim=1) # [max_needed_frames, 8]
747
+ print(f"🔧 NuScenes合成左转pose embedding shape: {camera_embedding.shape}")
748
  return camera_embedding.to(torch.bfloat16)
749
+
750
  def prepare_framepack_sliding_window_with_camera_moe(
751
  history_latents,
752
  target_frames_to_generate,
 
754
  start_frame,
755
  modality_type,
756
  max_history_frames=49):
757
+ """FramePack滑动窗口机制 - MoE版本"""
758
+ # history_latents: [C, T, H, W] 当前的历史latents
759
  C, T, H, W = history_latents.shape
760
 
761
+ # 固定索引结构(这决定了需要的camera帧数)
762
+ # 1帧起始 + 164x + 22x + 11x + target_frames_to_generate
763
  total_indices_length = 1 + 16 + 2 + 1 + target_frames_to_generate
764
  indices = torch.arange(0, total_indices_length)
765
  split_sizes = [1, 16, 2, 1, target_frames_to_generate]
 
767
  indices.split(split_sizes, dim=0)
768
  clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=0)
769
 
770
+ # 检查camera长度是否足够
771
  if camera_embedding_full.shape[0] < total_indices_length:
772
+ print(f"⚠️ camera_embedding长度不足,进行零补齐: 当前长度 {camera_embedding_full.shape[0]}, 需要长度 {total_indices_length}")
773
  shortage = total_indices_length - camera_embedding_full.shape[0]
774
  padding = torch.zeros(shortage, camera_embedding_full.shape[1],
775
  dtype=camera_embedding_full.dtype, device=camera_embedding_full.device)
776
  camera_embedding_full = torch.cat([camera_embedding_full, padding], dim=0)
777
 
778
+ # 从完整camera序列中选取对应部分
779
  combined_camera = torch.zeros(
780
  total_indices_length,
781
  camera_embedding_full.shape[1],
782
  dtype=camera_embedding_full.dtype,
783
  device=camera_embedding_full.device)
784
 
785
+ # 历史条件帧的相机位姿
786
  history_slice = camera_embedding_full[max(T - 19, 0):T, :].clone()
787
  combined_camera[19 - history_slice.shape[0]:19, :] = history_slice
788
 
789
+ # 目标帧的相机位姿
790
  target_slice = camera_embedding_full[T:T + target_frames_to_generate, :].clone()
791
  combined_camera[19:19 + target_slice.shape[0], :] = target_slice
792
 
793
+ # 根据当前history length重新设置mask
794
+ combined_camera[:, -1] = 0.0 # 先全部设为target (0)
795
 
796
+ # 设置condition mask:前19帧根据实际历史长度决定
797
  if T > 0:
798
  available_frames = min(T, 19)
799
  start_pos = 19 - available_frames
800
+ combined_camera[start_pos:19, -1] = 1.0 # 将有效的clean latents对应的camera标记为condition
801
 
802
+ print(f"🔧 MoE Camera mask更新:")
803
+ print(f" - 历史帧数: {T}")
804
+ print(f" - 有效condition帧数: {available_frames if T > 0 else 0}")
805
+ print(f" - 模态类型: {modality_type}")
806
 
807
+ # 处理latents
808
  clean_latents_combined = torch.zeros(C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device)
809
 
810
  if T > 0:
 
832
  'clean_latent_2x_indices': clean_latent_2x_indices,
833
  'clean_latent_4x_indices': clean_latent_4x_indices,
834
  'camera_embedding': combined_camera,
835
+ 'modality_type': modality_type, # 新增模态类型信息
836
  'current_length': T,
837
  'next_length': T + target_frames_to_generate
838
  }
839
 
840
  def overlay_controls(frame_img, pose_vec, icons):
841
  """
842
+ 根据相机位姿在帧上叠加控制图标(WASD 和箭头)
843
+ pose_vec: 12 个元素(展平的 3x4 矩阵)+ mask
844
  """
845
  if pose_vec is None or np.all(pose_vec[:12] == 0):
846
  return frame_img
847
 
848
+ # 提取平移向量(基于展平的 3x4 矩阵的索引)
849
  # [r00, r01, r02, tx, r10, r11, r12, ty, r20, r21, r22, tz]
850
  tx = pose_vec[3]
851
  # ty = pose_vec[7]
852
  tz = pose_vec[11]
853
 
854
+ # 提取旋转(偏航和俯仰)
855
+ # 偏航:绕 Y 轴。sin(偏航) = r02, cos(偏航) = r00
856
  r00 = pose_vec[0]
857
  r02 = pose_vec[2]
858
  yaw = np.arctan2(r02, r00)
859
 
860
+ # 俯仰:绕 X 轴。sin(俯仰) = -r12, cos(俯仰) = r22
861
  r12 = pose_vec[6]
862
  r22 = pose_vec[10]
863
  pitch = np.arctan2(-r12, r22)
864
 
865
+ # 按键激活的阈值
866
  TRANS_THRESH = 0.01
867
  ROT_THRESH = 0.005
868
 
869
+ # 确定按键状态
870
+ # 平移(WASD
871
+ # 假设 -Z 为前进,+X 为右
872
  is_forward = tz < -TRANS_THRESH
873
  is_backward = tz > TRANS_THRESH
874
  is_left = tx < -TRANS_THRESH
875
  is_right = tx > TRANS_THRESH
876
 
877
+ # 旋转(箭头)
878
+ # 偏航:+ 为左,- 为右
879
  is_turn_left = yaw > ROT_THRESH
880
  is_turn_right = yaw < -ROT_THRESH
881
 
882
+ # 俯仰:+ 为下,- 为上
883
  is_turn_up = pitch < -ROT_THRESH
884
  is_turn_down = pitch > ROT_THRESH
885
 
 
890
  name = name_active if is_active else name_inactive
891
  if name in icons:
892
  icon = icons[name]
893
+ # 使用 alpha 通道粘贴
894
  frame_img.paste(icon, (int(x), int(y)), icon)
895
 
896
+ # 叠加 WASD(左下角)
897
  base_x_right = 100
898
  base_y = H - 100
899
 
 
906
  # D
907
  paste_icon('move_right.png', 'not_move_right.png', is_right, base_x_right + spacing, base_y)
908
 
909
+ # 叠加 ↑↓←→(右下角)
910
  base_x_left = W - 150
911
 
912
  # ↑
 
922
 
923
 
924
  def inference_moe_framepack_sliding_window(
925
+ condition_pth_path,
926
+ dit_path,
 
 
 
927
  output_path="../examples/output_videos/output_moe_framepack_sliding.mp4",
928
  start_frame=0,
929
  initial_condition_frames=8,
 
932
  max_history_frames=49,
933
  device="cuda",
934
  prompt="A video of a scene shot using a pedestrian's front camera while walking",
935
+ modality_type="sekai", # "sekai" "nuscenes"
936
  use_real_poses=True,
937
+ scene_info_path=None, # 对于NuScenes数据集
938
+ # CFG参数
939
  use_camera_cfg=True,
940
  camera_guidance_scale=2.0,
941
  text_guidance_scale=1.0,
942
+ # MoE参数
943
  moe_num_experts=4,
944
  moe_top_k=2,
945
  moe_hidden_dim=None,
 
948
  add_icons=False
949
  ):
950
  """
951
+ MoE FramePack滑动窗口视频生成 - 支持多模态
952
  """
953
+ # 创建输出目录
954
  dir_path = os.path.dirname(output_path)
955
  os.makedirs(dir_path, exist_ok=True)
956
 
957
+ print(f"🔧 MoE FramePack滑动窗口生成开始...")
958
+ print(f"模态类型: {modality_type}")
959
+ print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}")
960
+ print(f"Text guidance scale: {text_guidance_scale}")
961
+ print(f"MoE配置: experts={moe_num_experts}, top_k={moe_top_k}")
962
 
963
+ # 1. 模型初始化
964
  replace_dit_model_in_manager()
965
 
966
  model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
967
  model_manager.load_models([
968
+ "/mnt/data/louis_crq/models/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
969
+ "/mnt/data/louis_crq/models/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
970
+ "/mnt/data/louis_crq/models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
971
  ])
972
+ pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")
973
 
974
+ # 2. 添加传统camera编码器(兼容性)
975
  dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
976
  for block in pipe.dit.blocks:
977
  block.cam_encoder = nn.Linear(13, dim)
 
981
  block.projector.weight = nn.Parameter(torch.eye(dim))
982
  block.projector.bias = nn.Parameter(torch.zeros(dim))
983
 
984
+ # 3. 添加FramePack组件
985
  add_framepack_components(pipe.dit)
986
 
987
+ # 4. 添加MoE组件
988
  moe_config = {
989
  "num_experts": moe_num_experts,
990
  "top_k": moe_top_k,
991
  "hidden_dim": moe_hidden_dim or dim * 2,
992
+ "sekai_input_dim": 13, # Sekai: 12pose + 1mask
993
+ "nuscenes_input_dim": 8, # NuScenes: 7pose + 1mask
994
+ "openx_input_dim": 13 # OpenX: 12pose + 1mask (类似sekai)
995
  }
996
  add_moe_components(pipe.dit, moe_config)
997
 
998
+ # 5. 加载训练好的权重
999
  dit_state_dict = torch.load(dit_path, map_location="cpu")
1000
+ pipe.dit.load_state_dict(dit_state_dict, strict=False) # 使用strict=False以兼容新增的MoE组件
1001
  pipe = pipe.to(device)
1002
  model_dtype = next(pipe.dit.parameters()).dtype
1003
 
1004
  if hasattr(pipe.dit, 'clean_x_embedder'):
1005
  pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)
1006
 
1007
+ # 设置去噪步数
1008
  pipe.scheduler.set_timesteps(50)
1009
 
1010
+ # 6. 加载初始条件
1011
  print("Loading initial condition frames...")
1012
+ initial_latents, encoded_data = load_encoded_video_from_pth(
1013
+ condition_pth_path,
1014
+ start_frame=start_frame,
1015
+ num_frames=initial_condition_frames
 
 
 
 
1016
  )
1017
 
1018
+ # 空间裁剪
1019
  target_height, target_width = 60, 104
1020
  C, T, H, W = initial_latents.shape
1021
 
 
1027
 
1028
  history_latents = initial_latents.to(device, dtype=model_dtype)
1029
 
1030
+ print(f"初始history_latents shape: {history_latents.shape}")
1031
 
1032
+ # 7. 编码prompt - 支持CFG
1033
  if use_gt_prompt and 'prompt_emb' in encoded_data:
1034
+ print("✅ 使用预编码的GT prompt embedding")
1035
  prompt_emb_pos = encoded_data['prompt_emb']
1036
+ # prompt_emb移到正确的设备和数据类型
1037
  if 'context' in prompt_emb_pos:
1038
  prompt_emb_pos['context'] = prompt_emb_pos['context'].to(device, dtype=model_dtype)
1039
  if 'context_mask' in prompt_emb_pos:
1040
  prompt_emb_pos['context_mask'] = prompt_emb_pos['context_mask'].to(device, dtype=model_dtype)
1041
 
1042
+ # 如果使用Text CFG,生成负向prompt
1043
  if text_guidance_scale > 1.0:
1044
  prompt_emb_neg = pipe.encode_prompt("")
1045
+ print(f"使用Text CFG with GT promptguidance scale: {text_guidance_scale}")
1046
  else:
1047
  prompt_emb_neg = None
1048
+ print("不使用Text CFG")
1049
 
1050
+ # 🔧 打印GT prompt文本(如果有)
1051
  if 'prompt' in encoded_data['prompt_emb']:
1052
  gt_prompt_text = encoded_data['prompt_emb']['prompt']
1053
+ print(f"📝 GT Prompt文本: {gt_prompt_text}")
1054
  else:
1055
+ # 使用传入的prompt参数重新编码
1056
+ print(f"🔄 重新编码prompt: {prompt}")
1057
  if text_guidance_scale > 1.0:
1058
  prompt_emb_pos = pipe.encode_prompt(prompt)
1059
  prompt_emb_neg = pipe.encode_prompt("")
1060
+ print(f"使用Text CFGguidance scale: {text_guidance_scale}")
1061
  else:
1062
  prompt_emb_pos = pipe.encode_prompt(prompt)
1063
  prompt_emb_neg = None
1064
+ print("不使用Text CFG")
1065
 
1066
+ # 8. 加载场景信息(对于NuScenes
1067
  scene_info = None
1068
  if modality_type == "nuscenes" and scene_info_path and os.path.exists(scene_info_path):
1069
  with open(scene_info_path, 'r') as f:
1070
  scene_info = json.load(f)
1071
+ print(f"加载NuScenes场景信息: {scene_info_path}")
1072
 
1073
+ # 9. 预生成完整的camera embedding序列
1074
  if modality_type == "sekai":
1075
  camera_embedding_full = generate_sekai_camera_embeddings_sliding(
1076
  encoded_data.get('cam_emb', None),
 
1097
  use_real_poses=use_real_poses
1098
  ).to(device, dtype=model_dtype)
1099
  else:
1100
+ raise ValueError(f"不支持的模态类型: {modality_type}")
1101
 
1102
+ print(f"完整camera序列shape: {camera_embedding_full.shape}")
1103
 
1104
+ # 10. 为Camera CFG创建无条件的camera embedding
1105
  if use_camera_cfg:
1106
  camera_embedding_uncond = torch.zeros_like(camera_embedding_full)
1107
+ print(f"创建无条件camera embedding用于CFG")
1108
 
1109
+ # 11. 滑动窗口生成循环
1110
  total_generated = 0
1111
  all_generated_frames = []
1112
 
1113
  while total_generated < total_frames_to_generate:
1114
  current_generation = min(frames_per_generation, total_frames_to_generate - total_generated)
1115
+ print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}")
1116
+ print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}")
1117
 
1118
+ # FramePack数据准备 - MoE版本
1119
  framepack_data = prepare_framepack_sliding_window_with_camera_moe(
1120
  history_latents,
1121
  current_generation,
 
1125
  max_history_frames
1126
  )
1127
 
1128
+ # 准备输入
1129
  clean_latents = framepack_data['clean_latents'].unsqueeze(0)
1130
  clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0)
1131
  clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0)
1132
  camera_embedding = framepack_data['camera_embedding'].unsqueeze(0)
1133
 
1134
+ # 准备modality_inputs
1135
  modality_inputs = {modality_type: camera_embedding}
1136
 
1137
+ # 为CFG准备无条件camera embedding
1138
  if use_camera_cfg:
1139
  camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0)
1140
  modality_inputs_uncond = {modality_type: camera_embedding_uncond_batch}
1141
 
1142
+ # 索引处理
1143
  latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu()
1144
  clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu()
1145
  clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu()
1146
  clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu()
1147
 
1148
+ # 初始化要生成的latents
1149
  new_latents = torch.randn(
1150
  1, C, current_generation, H, W,
1151
  device=device, dtype=model_dtype
 
1154
  extra_input = pipe.prepare_extra_input(new_latents)
1155
 
1156
  print(f"Camera embedding shape: {camera_embedding.shape}")
1157
+ print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}")
1158
 
1159
+ # 去噪循环 - 支持CFG
1160
  timesteps = pipe.scheduler.timesteps
1161
 
1162
  for i, timestep in enumerate(timesteps):
1163
  if i % 10 == 0:
1164
+ print(f" 去噪步骤 {i+1}/{len(timesteps)}")
1165
 
1166
  timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)
1167
 
1168
  with torch.no_grad():
1169
+ # CFG推理
1170
  if use_camera_cfg and camera_guidance_scale > 1.0:
1171
+ # 条件预测(有camera
1172
  noise_pred_cond, moe_loess = pipe.dit(
1173
  new_latents,
1174
  timestep=timestep_tensor,
1175
  cam_emb=camera_embedding,
1176
+ modality_inputs=modality_inputs, # MoE模态输入
1177
  latent_indices=latent_indices,
1178
  clean_latents=clean_latents,
1179
  clean_latent_indices=clean_latent_indices,
 
1185
  **extra_input
1186
  )
1187
 
1188
+ # 无条件预测(无camera
1189
  noise_pred_uncond, moe_loess = pipe.dit(
1190
  new_latents,
1191
  timestep=timestep_tensor,
1192
  cam_emb=camera_embedding_uncond_batch,
1193
+ modality_inputs=modality_inputs_uncond, # MoE无条件模态输入
1194
  latent_indices=latent_indices,
1195
  clean_latents=clean_latents,
1196
  clean_latent_indices=clean_latent_indices,
 
1205
  # Camera CFG
1206
  noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_cond - noise_pred_uncond)
1207
 
1208
+ # 如果同时使用Text CFG
1209
  if text_guidance_scale > 1.0 and prompt_emb_neg:
1210
  noise_pred_text_uncond, moe_loess = pipe.dit(
1211
  new_latents,
 
1223
  **extra_input
1224
  )
1225
 
1226
+ # 应用Text CFG到已经应用Camera CFG的结果
1227
  noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond)
1228
 
1229
  elif text_guidance_scale > 1.0 and prompt_emb_neg:
1230
+ # 只使用Text CFG
1231
  noise_pred_cond, moe_loess = pipe.dit(
1232
  new_latents,
1233
  timestep=timestep_tensor,
 
1263
  noise_pred = noise_pred_uncond + text_guidance_scale * (noise_pred_cond - noise_pred_uncond)
1264
 
1265
  else:
1266
+ # 标准推理(无CFG
1267
  noise_pred, moe_loess = pipe.dit(
1268
  new_latents,
1269
  timestep=timestep_tensor,
1270
  cam_emb=camera_embedding,
1271
+ modality_inputs=modality_inputs, # MoE模态输入
1272
  latent_indices=latent_indices,
1273
  clean_latents=clean_latents,
1274
  clean_latent_indices=clean_latent_indices,
 
1282
 
1283
  new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents)
1284
 
1285
+ # 更新历史
1286
  new_latents_squeezed = new_latents.squeeze(0)
1287
  history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1)
1288
 
1289
+ # 维护滑动窗口
1290
  if history_latents.shape[1] > max_history_frames:
1291
  first_frame = history_latents[:, 0:1, :, :]
1292
  recent_frames = history_latents[:, -(max_history_frames-1):, :, :]
1293
  history_latents = torch.cat([first_frame, recent_frames], dim=1)
1294
+ print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}")
1295
 
1296
+ print(f"更新后history_latents shape: {history_latents.shape}")
1297
 
1298
  all_generated_frames.append(new_latents_squeezed)
1299
  total_generated += current_generation
1300
 
1301
+ print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} ")
1302
 
1303
+ # 12. 解码和保存
1304
+ print("\n🔧 解码生成的视频...")
1305
 
1306
  all_generated = torch.cat(all_generated_frames, dim=1)
1307
  final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0)
1308
 
1309
+ print(f"最终视频shape: {final_video.shape}")
1310
 
1311
  decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))
1312
 
 
1319
  icons = {}
1320
  video_camera_poses = None
1321
  if add_icons:
1322
+ # 加载用于叠加的图标资源
1323
  icons_dir = os.path.join(ROOT_DIR, 'icons')
1324
  icon_names = ['move_forward.png', 'not_move_forward.png',
1325
  'move_backward.png', 'not_move_backward.png',
 
1334
  if os.path.exists(path):
1335
  try:
1336
  icon = Image.open(path).convert("RGBA")
1337
+ # 调整图标尺寸
1338
  icon = icon.resize((50, 50), Image.Resampling.LANCZOS)
1339
  icons[name] = icon
1340
  except Exception as e:
1341
  print(f"Error loading icon {name}: {e}")
1342
  else:
1343
+ print(f"Warning: Icon {name} not found at {path}")
1344
 
1345
+ # 获取与视频帧对应的相机姿态
1346
  time_compression_ratio = 4
1347
  camera_poses = camera_embedding_full.detach().float().cpu().numpy()
1348
  video_camera_poses = [x for x in camera_poses for _ in range(time_compression_ratio)]
 
1361
 
1362
  writer.append_data(np.array(img))
1363
 
1364
+ print(f"🔧 MoE FramePack滑动窗口生成完成! 保存到: {output_path}")
1365
+ print(f"总共生成了 {total_generated} (压缩后), 对应原始 {total_generated * 4} ")
1366
+ print(f"使用模态: {modality_type}")
1367
 
1368
 
1369
  def main():
1370
+ parser = argparse.ArgumentParser(description="MoE FramePack滑动窗口视频生成 - 支持多模态")
1371
+
1372
+ # 基础参数
1373
+ parser.add_argument("--condition_pth", type=str,
1374
+ default="../examples/condition_pth/garden_1.pth")
 
 
 
 
 
 
 
 
 
 
 
1375
  parser.add_argument("--start_frame", type=int, default=0)
1376
  parser.add_argument("--initial_condition_frames", type=int, default=1)
1377
  parser.add_argument("--frames_per_generation", type=int, default=8)
1378
  parser.add_argument("--total_frames_to_generate", type=int, default=24)
1379
  parser.add_argument("--max_history_frames", type=int, default=100)
1380
  parser.add_argument("--use_real_poses", default=False)
1381
+ parser.add_argument("--dit_path", type=str, default=None, required=True,
 
1382
  help="path to the pretrained DiT MoE model checkpoint")
 
 
 
 
1383
  parser.add_argument("--output_path", type=str,
1384
+ default='./examples/output_videos/output_moe_framepack_sliding.mp4')
1385
+ parser.add_argument("--prompt", type=str, default=None,
 
 
1386
  help="text prompt for video generation")
1387
  parser.add_argument("--device", type=str, default="cuda")
1388
  parser.add_argument("--add_icons", action="store_true", default=False,
1389
+ help="在生成的视频上叠加控制图标")
1390
 
1391
+ # 模态类型参数
1392
  parser.add_argument("--modality_type", type=str, choices=["sekai", "nuscenes", "openx"],
1393
+ default="sekai", help="模态类型:sekai nuscenes openx")
1394
  parser.add_argument("--scene_info_path", type=str, default=None,
1395
+ help="NuScenes场景信息文件路径(仅用于nuscenes模态)")
1396
 
1397
+ # CFG参数
1398
  parser.add_argument("--use_camera_cfg", default=False,
1399
+ help="使用Camera CFG")
1400
  parser.add_argument("--camera_guidance_scale", type=float, default=2.0,
1401
  help="Camera guidance scale for CFG")
1402
  parser.add_argument("--text_guidance_scale", type=float, default=1.0,
1403
  help="Text guidance scale for CFG")
1404
 
1405
+ # MoE参数
1406
+ parser.add_argument("--moe_num_experts", type=int, default=3, help="专家数量")
1407
+ parser.add_argument("--moe_top_k", type=int, default=1, help="Top-K专家")
1408
+ parser.add_argument("--moe_hidden_dim", type=int, default=None, help="MoE隐藏层维度")
1409
+ parser.add_argument("--direction", type=str, default="left", help="生成视频的行进轨迹方向")
1410
  parser.add_argument("--use_gt_prompt", action="store_true", default=False,
1411
+ help="使用数据集中的ground truth prompt embedding")
1412
 
1413
  args = parser.parse_args()
1414
 
1415
+ print(f"🔧 MoE FramePack CFG生成设置:")
1416
+ print(f"模态类型: {args.modality_type}")
1417
  print(f"Camera CFG: {args.use_camera_cfg}")
1418
  if args.use_camera_cfg:
1419
  print(f"Camera guidance scale: {args.camera_guidance_scale}")
1420
+ print(f"使用GT Prompt: {args.use_gt_prompt}")
1421
  print(f"Text guidance scale: {args.text_guidance_scale}")
1422
+ print(f"MoE配置: experts={args.moe_num_experts}, top_k={args.moe_top_k}")
1423
  print(f"DiT{args.dit_path}")
1424
 
1425
+ # 验证NuScenes参数
1426
  if args.modality_type == "nuscenes" and not args.scene_info_path:
1427
+ print("⚠️ 使用NuScenes模态但未提供scene_info_path,将使用合成pose数据")
 
 
 
 
 
 
 
 
 
 
 
 
 
1428
 
1429
  inference_moe_framepack_sliding_window(
1430
  condition_pth_path=args.condition_pth,
 
 
1431
  dit_path=args.dit_path,
 
1432
  output_path=args.output_path,
1433
  start_frame=args.start_frame,
1434
  initial_condition_frames=args.initial_condition_frames,
 
1440
  modality_type=args.modality_type,
1441
  use_real_poses=args.use_real_poses,
1442
  scene_info_path=args.scene_info_path,
1443
+ # CFG参数
1444
  use_camera_cfg=args.use_camera_cfg,
1445
  camera_guidance_scale=args.camera_guidance_scale,
1446
  text_guidance_scale=args.text_guidance_scale,
1447
+ # MoE参数
1448
  moe_num_experts=args.moe_num_experts,
1449
  moe_top_k=args.moe_top_k,
1450
  moe_hidden_dim=args.moe_hidden_dim,
scripts/infer_moe.py ADDED
@@ -0,0 +1,1023 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ from PIL import Image
6
+ import imageio
7
+ import json
8
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
9
+ import argparse
10
+ from torchvision.transforms import v2
11
+ from einops import rearrange
12
+ import copy
13
+ from scipy.spatial.transform import Rotation as R
14
+
15
+
16
def compute_relative_pose_matrix(pose1, pose2):
    """Relative pose between two consecutive frames as a 3x4 matrix [R_rel | t_rel].

    Args:
        pose1: camera pose of frame i, shape (7,): [tx, ty, tz, qx, qy, qz, qw].
        pose2: camera pose of frame i+1, same layout.

    Returns:
        np.ndarray of shape (3, 4): the relative rotation in the first three
        columns and the relative translation in the last column.
    """
    # Split each pose into its translation and scalar-last quaternion parts.
    trans_prev, quat_prev = pose1[:3], pose1[3:]
    trans_next, quat_next = pose2[:3], pose2[3:]

    rot_prev = R.from_quat(quat_prev)
    rot_next = R.from_quat(quat_next)

    # Relative rotation: next frame's rotation composed with the inverse of
    # the previous frame's rotation.
    rel_rot_mat = (rot_next * rot_prev.inv()).as_matrix()

    # Relative translation expressed in the previous frame's coordinates:
    # R1^T @ (t2 - t1); the transpose of a rotation matrix is its inverse.
    rel_trans = rot_prev.as_matrix().T @ (trans_next - trans_prev)

    # Assemble [R_rel | t_rel].
    return np.hstack([rel_rot_mat, rel_trans.reshape(3, 1)])
47
+
48
+
49
def calculate_relative_rotation(current_rotation, reference_rotation):
    """Relative rotation quaternion conj(reference) * current - NuScenes-specific.

    Quaternions are in [w, x, y, z] order; the reference is assumed unit-norm,
    so its inverse equals its conjugate.
    """
    cur = torch.tensor(current_rotation, dtype=torch.float32)
    ref = torch.tensor(reference_rotation, dtype=torch.float32)
    # Conjugate of the (unit) reference quaternion.
    w1, x1, y1, z1 = ref[0], -ref[1], -ref[2], -ref[3]
    w2, x2, y2, z2 = cur
    # Hamilton product conj(ref) * cur.
    return torch.tensor([
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
    ])
63
+
64
+
65
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Load pre-encoded video latents from a .pth file and slice out a window.

    Args:
        pth_path: path to a torch-saved dict containing a 'latents' tensor
            of shape [C, T, H, W].
        start_frame: first latent frame of the window.
        num_frames: number of latent frames to extract.

    Returns:
        (window, payload): the sliced [C, num_frames, H, W] latents and the
        full loaded dict.

    Raises:
        ValueError: if the requested window exceeds the stored frame count.
    """
    print(f"Loading encoded video from {pth_path}")

    payload = torch.load(pth_path, weights_only=False, map_location="cpu")
    latents = payload['latents']  # [C, T, H, W]

    print(f"Full latents shape: {latents.shape}")
    print(f"Extracting frames {start_frame} to {start_frame + num_frames}")

    end_frame = start_frame + num_frames
    if end_frame > latents.shape[1]:
        raise ValueError(f"Not enough frames: requested {end_frame}, available {latents.shape[1]}")

    window = latents[:, start_frame:end_frame, :, :]
    print(f"Extracted condition latents shape: {window.shape}")

    return window, payload
82
+
83
+
84
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Relative pose of camera B with respect to camera A: pose_b @ inv(pose_a).

    Both inputs are 4x4 extrinsic matrices; the backend (numpy or torch) is
    selected by ``use_torch``.
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        # Promote numpy inputs to float32 tensors before inverting.
        if not isinstance(pose_a, torch.Tensor):
            pose_a = torch.from_numpy(pose_a).float()
        if not isinstance(pose_b, torch.Tensor):
            pose_b = torch.from_numpy(pose_b).float()
        return torch.matmul(pose_b, torch.inverse(pose_a))

    # numpy backend: coerce to float32 arrays when needed.
    if not isinstance(pose_a, np.ndarray):
        pose_a = np.array(pose_a, dtype=np.float32)
    if not isinstance(pose_b, np.ndarray):
        pose_b = np.array(pose_b, dtype=np.float32)
    return np.matmul(pose_b, np.linalg.inv(pose_a))
107
+
108
+
109
def replace_dit_model_in_manager():
    """Swap the registered DiT model class for the MoE variant.

    Mutates ``model_loader_configs`` in place so that any subsequent
    ModelManager load instantiates ``WanModelMoe`` wherever the
    ``wan_video_dit`` model name is registered.
    """
    from diffsynth.models.wan_video_dit_moe import WanModelMoe
    from diffsynth.configs.model_config import model_loader_configs

    for idx, entry in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, names, classes, resource = entry

        # Only configs that mention the target model need patching.
        if 'wan_video_dit' not in names:
            continue

        patched_names, patched_classes = [], []
        for name, cls in zip(names, classes):
            if name == 'wan_video_dit':
                patched_names.append(name)
                patched_classes.append(WanModelMoe)
                print(f"✅ 替换了模型类: {name} -> WanModelMoe")
            else:
                patched_names.append(name)
                patched_classes.append(cls)

        model_loader_configs[idx] = (keys_hash, keys_hash_with_shape, patched_names, patched_classes, resource)
131
+
132
+
133
def add_framepack_components(dit_model):
    """Attach the FramePack ``clean_x_embedder`` to a DiT model (idempotent)."""
    if hasattr(dit_model, 'clean_x_embedder'):
        return

    # Infer the transformer width from the first attention projection.
    inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]

    class CleanXEmbedder(nn.Module):
        """Patchify clean latents at 1x / 2x / 4x temporal-spatial scales."""

        def __init__(self, inner_dim):
            super().__init__()
            self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
            self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
            self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

        def forward(self, x, scale="1x"):
            # Select the projection for the requested scale and match its dtype.
            layers = {"1x": self.proj, "2x": self.proj_2x, "4x": self.proj_4x}
            if scale not in layers:
                raise ValueError(f"Unsupported scale: {scale}")
            conv = layers[scale]
            return conv(x.to(conv.weight.dtype))

    embedder = CleanXEmbedder(inner_dim)
    dit_model.clean_x_embedder = embedder
    # Cast the new module to the model's parameter dtype.
    model_dtype = next(dit_model.parameters()).dtype
    dit_model.clean_x_embedder = embedder.to(dtype=model_dtype)
    print("✅ 添加了FramePack的clean_x_embedder组件")
162
+
163
+
164
def add_moe_components(dit_model, moe_config):
    """🔧 Attach MoE components (modality processors, router, per-block experts) - corrected version.

    Guarded by the ``moe_config`` attribute so repeated calls are no-ops.
    Adds per-modality input processors mapping raw camera embeddings into a
    shared ``unified_dim`` space, a global linear router over experts, and a
    ``MultiModalMoE`` head on every transformer block.
    """
    if not hasattr(dit_model, 'moe_config'):
        dit_model.moe_config = moe_config
        print("✅ 添加了MoE配置到模型")
        # NOTE(review): default top_k here is 1 but the per-block MoE below
        # defaults to 2 when "top_k" is absent — confirm which is intended.
        dit_model.top_k = moe_config.get("top_k", 1)

        # Dynamically add MoE components to every block.
        # Transformer width inferred from the first attention projection.
        dim = dit_model.blocks[0].self_attn.q.weight.shape[0]
        unified_dim = moe_config.get("unified_dim", 25)
        num_experts = moe_config.get("num_experts", 4)
        from diffsynth.models.wan_video_dit_moe import ModalityProcessor, MultiModalMoE
        # Per-modality input dims: sekai/openx use 13 (12-d flattened pose +
        # 1-d mask), nuscenes uses 8 (7-d pose vector + 1-d mask).
        dit_model.sekai_processor = ModalityProcessor("sekai", 13, unified_dim)
        dit_model.nuscenes_processor = ModalityProcessor("nuscenes", 8, unified_dim)
        dit_model.openx_processor = ModalityProcessor("openx", 13, unified_dim)  # OpenX uses 13-dim input, like sekai but processed independently
        dit_model.global_router = nn.Linear(unified_dim, num_experts)


        for i, block in enumerate(dit_model.blocks):
            # MoE network - consumes unified_dim features, emits dim features.
            block.moe = MultiModalMoE(
                unified_dim=unified_dim,
                output_dim=dim,  # output dimension matches the transformer block's dim
                num_experts=moe_config.get("num_experts", 4),
                top_k=moe_config.get("top_k", 2)
            )

            print(f"✅ Block {i} 添加了MoE组件 (unified_dim: {unified_dim}, experts: {moe_config.get('num_experts', 4)})")
192
+
193
+
194
def generate_sekai_camera_embeddings_sliding(cam_data, start_frame, current_history_length, new_frames, total_generated, use_real_poses=True):
    """Build camera embeddings for the Sekai dataset - sliding-window version.

    Returns a [max_needed_frames, 13] bfloat16 tensor: a flattened 3x4
    relative pose (12 values) per latent frame plus a trailing 0/1 mask
    column (1 = condition frame, 0 = target frame).

    NOTE(review): ``total_generated`` is accepted but never read in this
    body — confirm whether it can be dropped from the signature.
    """
    # One latent frame corresponds to 4 raw video frames.
    time_compression_ratio = 4

    # Number of camera frames FramePack actually consumes:
    # 1 start + 16 (4x) + 2 (2x) + 1 (1x) clean slots + the new target frames.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    if use_real_poses and cam_data is not None and 'extrinsic' in cam_data:
        print("🔧 使用真实Sekai camera数据")
        cam_extrinsic = cam_data['extrinsic']

        # Make the camera sequence long enough for every consumer (window,
        # FramePack layout, and a floor of 30 frames).
        max_needed_frames = max(
            start_frame + current_history_length + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 计算Sekai camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        relative_poses = []
        for i in range(max_needed_frames):
            # Map latent-frame index i back to raw-frame indices.
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio

            if next_frame_idx < len(cam_extrinsic):
                cam_prev = cam_extrinsic[frame_idx]
                cam_next = cam_extrinsic[next_frame_idx]
                relative_pose = compute_relative_pose(cam_prev, cam_next)
                relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
            else:
                # Out of range: fall back to zero motion.
                print(f"⚠️ 帧{frame_idx}超出camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))

        pose_embedding = torch.stack(relative_poses, dim=0)
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')

        # Mask column: frames [start_frame, start_frame + history) are condition.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 Sekai真实camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    else:
        print("🔧 使用Sekai合成camera数据")

        max_needed_frames = max(
            start_frame + current_history_length + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 生成Sekai合成camera帧数: {max_needed_frames}")
        relative_poses = []
        for i in range(max_needed_frames):
            # Constant left-turn motion pattern.
            # NOTE(review): the original comment claimed "positive = left"
            # but the value is negative — confirm the sign convention.
            yaw_per_frame = -0.1  # per-frame yaw
            forward_speed = 0.005  # forward distance per frame

            pose = np.eye(4, dtype=np.float32)

            # Rotation about the Y axis (turning).
            cos_yaw = np.cos(yaw_per_frame)
            sin_yaw = np.sin(yaw_per_frame)

            pose[0, 0] = cos_yaw
            pose[0, 2] = sin_yaw
            pose[2, 0] = -sin_yaw
            pose[2, 2] = cos_yaw

            # Translation: advance along the local -Z (forward) direction.
            pose[2, 3] = -forward_speed

            # Slight centripetal drift to approximate a circular trajectory.
            radius_drift = 0.002
            pose[0, 3] = radius_drift

            relative_pose = pose[:3, :]
            relative_poses.append(torch.as_tensor(relative_pose))

        pose_embedding = torch.stack(relative_poses, dim=0)
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')

        # Mask column: frames [start_frame, start_frame + history) are condition.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 Sekai合成camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)
294
+
295
def generate_openx_camera_embeddings_sliding(encoded_data, start_frame, current_history_length, new_frames, use_real_poses):
    """Build camera embeddings for the OpenX dataset - sliding-window version.

    Returns a [max_needed_frames, 13] bfloat16 tensor: a flattened 3x4
    relative pose (12 values) per latent frame plus a trailing 0/1 mask
    column (1 = condition frame, 0 = target frame).
    """
    # One latent frame corresponds to 4 raw video frames.
    time_compression_ratio = 4

    # Number of camera frames FramePack actually consumes:
    # 1 start + 16 (4x) + 2 (2x) + 1 (1x) clean slots + the new target frames.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    if use_real_poses and encoded_data is not None and 'cam_emb' in encoded_data and 'extrinsic' in encoded_data['cam_emb']:
        print("🔧 使用OpenX真实camera数据")
        cam_extrinsic = encoded_data['cam_emb']['extrinsic']

        # Make the camera sequence long enough for every consumer (window,
        # FramePack layout, and a floor of 30 frames).
        max_needed_frames = max(
            start_frame + current_history_length + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 计算OpenX camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        relative_poses = []
        for i in range(max_needed_frames):
            # OpenX uses the same 4x stride as sekai, but its sequences tend
            # to be shorter.
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio

            if next_frame_idx < len(cam_extrinsic):
                cam_prev = cam_extrinsic[frame_idx]
                cam_next = cam_extrinsic[next_frame_idx]
                relative_pose = compute_relative_pose(cam_prev, cam_next)
                relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
            else:
                # Out of range: fall back to zero motion.
                print(f"⚠️ 帧{frame_idx}超出OpenX camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))

        pose_embedding = torch.stack(relative_poses, dim=0)
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')

        # Mask column: frames [start_frame, start_frame + history) are condition.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 OpenX真实camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    else:
        print("🔧 使用OpenX合成camera数据")

        max_needed_frames = max(
            start_frame + current_history_length + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 生成OpenX合成camera帧数: {max_needed_frames}")
        relative_poses = []
        for i in range(max_needed_frames):
            # Synthetic robot-manipulation motion: small per-frame rotations
            # and slow translation, mimicking fine arm movements.
            roll_per_frame = 0.02   # slight roll
            pitch_per_frame = 0.01  # slight pitch
            yaw_per_frame = 0.015   # slight yaw
            forward_speed = 0.003   # slow forward speed

            pose = np.eye(4, dtype=np.float32)

            # Composite rotation terms.
            # About X (roll):
            cos_roll = np.cos(roll_per_frame)
            sin_roll = np.sin(roll_per_frame)
            # About Y (pitch):
            cos_pitch = np.cos(pitch_per_frame)
            sin_pitch = np.sin(pitch_per_frame)
            # About Z (yaw):
            cos_yaw = np.cos(yaw_per_frame)
            sin_yaw = np.sin(yaw_per_frame)

            # Composite rotation matrix (ZYX order).
            pose[0, 0] = cos_yaw * cos_pitch
            pose[0, 1] = cos_yaw * sin_pitch * sin_roll - sin_yaw * cos_roll
            pose[0, 2] = cos_yaw * sin_pitch * cos_roll + sin_yaw * sin_roll
            pose[1, 0] = sin_yaw * cos_pitch
            pose[1, 1] = sin_yaw * sin_pitch * sin_roll + cos_yaw * cos_roll
            pose[1, 2] = sin_yaw * sin_pitch * cos_roll - cos_yaw * sin_roll
            pose[2, 0] = -sin_pitch
            pose[2, 1] = cos_pitch * sin_roll
            pose[2, 2] = cos_pitch * cos_roll

            # Translation: small lateral drift, dominant depth motion.
            pose[0, 3] = forward_speed * 0.5  # slight X motion
            pose[1, 3] = forward_speed * 0.3  # slight Y motion
            pose[2, 3] = -forward_speed       # main motion along Z (depth)

            relative_pose = pose[:3, :]
            relative_poses.append(torch.as_tensor(relative_pose))

        pose_embedding = torch.stack(relative_poses, dim=0)
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')

        # Mask column: frames [start_frame, start_frame + history) are condition.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 OpenX合成camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)
409
+
410
+
411
def generate_nuscenes_camera_embeddings_sliding(scene_info, start_frame, current_history_length, new_frames):
    """Build camera embeddings for NuScenes - sliding-window version, kept
    consistent with train_moe.py.

    Returns a [N, 8] bfloat16 tensor per latent frame: a 3-d relative
    translation, a 4-d relative rotation quaternion, and a trailing 0/1
    condition mask column.
    """
    time_compression_ratio = 4
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames
    max_needed_frames = max(framepack_needed_frames, 30)

    def _finalize(pose_rows):
        # Shared tail: append the condition-mask column.
        seq = torch.stack(pose_rows, dim=0)  # [max_needed_frames, 7]
        mask_col = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        cond_end = min(start_frame + current_history_length, max_needed_frames)
        mask_col[start_frame:cond_end] = 1.0
        return torch.cat([seq, mask_col], dim=1)

    if scene_info is not None and 'keyframe_poses' in scene_info:
        print("🔧 使用NuScenes真实pose数据")
        keyframe_poses = scene_info['keyframe_poses']
        last = len(keyframe_poses) - 1
        # Keyframe index per latent frame, clamped to the available range;
        # one extra entry because relative motion needs a prev/next pair.
        kf_idx = [min((start_frame + i) * time_compression_ratio, last)
                  for i in range(max_needed_frames + 1)]

        rows = []
        for i in range(max_needed_frames):
            prev_pose = keyframe_poses[kf_idx[i]]
            next_pose = keyframe_poses[kf_idx[i + 1]]
            # Relative displacement between consecutive keyframes.
            shift = torch.tensor(
                np.array(next_pose['translation']) - np.array(prev_pose['translation']),
                dtype=torch.float32
            )
            # Relative rotation quaternion.
            rel_quat = calculate_relative_rotation(next_pose['rotation'], prev_pose['rotation'])
            rows.append(torch.cat([shift, rel_quat], dim=0))  # [7D]

        camera_embedding = _finalize(rows)
        print(f"🔧 NuScenes真实pose embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    print("🔧 使用NuScenes合成pose数据")
    # Synthesize an absolute circular trajectory first (one extra sample so
    # every frame has a successor for the relative computation).
    positions = []
    orientations = []
    for i in range(max_needed_frames + 1):
        angle = -i * 0.12
        radius = 8.0
        positions.append(np.array(
            [radius * np.sin(angle), 0.0, radius * (1 - np.cos(angle))],
            dtype=np.float32
        ))
        yaw = angle + np.pi / 2
        orientations.append(np.array(
            [np.cos(yaw / 2), 0.0, 0.0, np.sin(yaw / 2)],
            dtype=np.float32
        ))

    rows = []
    for i in range(max_needed_frames):
        # Per-frame motion relative to the previous frame.
        shift = torch.tensor(positions[i + 1] - positions[i], dtype=torch.float32)
        # Relative quaternion conj(prev) * next, Hamilton product written out.
        q_next = orientations[i + 1]
        q_prev = orientations[i]
        w1, x1, y1, z1 = np.array([q_prev[0], -q_prev[1], -q_prev[2], -q_prev[3]], dtype=np.float32)
        w2, x2, y2, z2 = q_next
        rel_quat = torch.tensor([
            w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
            w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
            w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
            w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
        ], dtype=torch.float32)
        rows.append(torch.cat([shift, rel_quat], dim=0))  # [7D]

    camera_embedding = _finalize(rows)
    print(f"🔧 NuScenes合成相对pose embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
496
+
497
def prepare_framepack_sliding_window_with_camera_moe(history_latents, target_frames_to_generate, camera_embedding_full, start_frame, modality_type, max_history_frames=49):
    """FramePack sliding-window preparation - MoE variant.

    Splits the fixed index layout (1 start + 16 coarse + 2 medium + 1 fine +
    target frames), packs the most recent history into the 19 clean slots,
    and aligns the camera embedding (whose last column is the condition mask)
    with that layout.
    """
    # history_latents: [C, T, H, W] current history latents.
    num_channels, history_len, lat_h, lat_w = history_latents.shape

    # Fixed index layout; this also determines how many camera rows are used.
    window_len = 1 + 16 + 2 + 1 + target_frames_to_generate
    section_sizes = [1, 16, 2, 1, target_frames_to_generate]
    start_idx, idx_4x, idx_2x, idx_1x, target_idx = \
        torch.arange(0, window_len).split(section_sizes, dim=0)
    clean_idx = torch.cat([start_idx, idx_1x], dim=0)

    # Zero-pad the camera track when it is shorter than the window.
    cam_rows, cam_dim = camera_embedding_full.shape
    if cam_rows < window_len:
        pad = torch.zeros(window_len - cam_rows, cam_dim,
                          dtype=camera_embedding_full.dtype,
                          device=camera_embedding_full.device)
        camera_embedding_full = torch.cat([camera_embedding_full, pad], dim=0)

    cam_window = camera_embedding_full[:window_len, :].clone()

    # Rebuild the mask column: everything becomes target (0) first ...
    cam_window[:, -1] = 0.0

    # ... then the tail of the 19 clean slots that actually holds history
    # is marked as condition (1).
    valid_history = min(history_len, 19) if history_len > 0 else 0
    fill_from = 19 - valid_history
    if history_len > 0:
        cam_window[fill_from:19, -1] = 1.0

    print(f"🔧 MoE Camera mask更新:")
    print(f" - 历史帧数: {history_len}")
    print(f" - 有效condition帧数: {valid_history if history_len > 0 else 0}")
    print(f" - 模态类型: {modality_type}")

    # Pack the newest history frames into the rightmost clean slots.
    clean_buffer = torch.zeros(num_channels, 19, lat_h, lat_w,
                               dtype=history_latents.dtype, device=history_latents.device)
    if history_len > 0:
        clean_buffer[:, fill_from:, :, :] = history_latents[:, -valid_history:, :, :]

    latents_4x = clean_buffer[:, 0:16, :, :]
    latents_2x = clean_buffer[:, 16:18, :, :]
    latents_1x = clean_buffer[:, 18:19, :, :]

    # Anchor slot: the very first history frame, or zeros when empty.
    anchor = (history_latents[:, 0:1, :, :] if history_len > 0
              else torch.zeros(num_channels, 1, lat_h, lat_w,
                               dtype=history_latents.dtype, device=history_latents.device))

    return {
        'latent_indices': target_idx,
        'clean_latents': torch.cat([anchor, latents_1x], dim=1),
        'clean_latents_2x': latents_2x,
        'clean_latents_4x': latents_4x,
        'clean_latent_indices': clean_idx,
        'clean_latent_2x_indices': idx_2x,
        'clean_latent_4x_indices': idx_4x,
        'camera_embedding': cam_window,
        'modality_type': modality_type,  # modality tag passed through to the caller
        'current_length': history_len,
        'next_length': history_len + target_frames_to_generate
    }
566
+
567
+
568
+ def inference_moe_framepack_sliding_window(
569
+ condition_pth_path,
570
+ dit_path,
571
+ output_path="moe/infer_results/output_moe_framepack_sliding.mp4",
572
+ start_frame=0,
573
+ initial_condition_frames=8,
574
+ frames_per_generation=4,
575
+ total_frames_to_generate=32,
576
+ max_history_frames=49,
577
+ device="cuda",
578
+ prompt="A video of a scene shot using a pedestrian's front camera while walking",
579
+ modality_type="sekai", # "sekai" 或 "nuscenes"
580
+ use_real_poses=True,
581
+ scene_info_path=None, # 对于NuScenes数据集
582
+ # CFG参数
583
+ use_camera_cfg=True,
584
+ camera_guidance_scale=2.0,
585
+ text_guidance_scale=1.0,
586
+ # MoE参数
587
+ moe_num_experts=4,
588
+ moe_top_k=2,
589
+ moe_hidden_dim=None
590
+ ):
591
+ """
592
+ MoE FramePack滑动窗口视频生成 - 支持多模态
593
+ """
594
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
595
+ print(f"🔧 MoE FramePack滑动窗口生成开始...")
596
+ print(f"模态类型: {modality_type}")
597
+ print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}")
598
+ print(f"Text guidance scale: {text_guidance_scale}")
599
+ print(f"MoE配置: experts={moe_num_experts}, top_k={moe_top_k}")
600
+
601
+ # 1. 模型初始化
602
+ replace_dit_model_in_manager()
603
+
604
+ model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
605
+ model_manager.load_models([
606
+ "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
607
+ "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
608
+ "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
609
+ ])
610
+ pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")
611
+
612
+ # 2. 添加传统camera编码器(兼容性)
613
+ dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
614
+ for block in pipe.dit.blocks:
615
+ block.cam_encoder = nn.Linear(13, dim)
616
+ block.projector = nn.Linear(dim, dim)
617
+ block.cam_encoder.weight.data.zero_()
618
+ block.cam_encoder.bias.data.zero_()
619
+ block.projector.weight = nn.Parameter(torch.eye(dim))
620
+ block.projector.bias = nn.Parameter(torch.zeros(dim))
621
+
622
+ # 3. 添加FramePack组件
623
+ add_framepack_components(pipe.dit)
624
+
625
+ # 4. 添加MoE组件
626
+ moe_config = {
627
+ "num_experts": moe_num_experts,
628
+ "top_k": moe_top_k,
629
+ "hidden_dim": moe_hidden_dim or dim * 2,
630
+ "sekai_input_dim": 13, # Sekai: 12维pose + 1维mask
631
+ "nuscenes_input_dim": 8, # NuScenes: 7维pose + 1维mask
632
+ "openx_input_dim": 13 # OpenX: 12维pose + 1维mask (类似sekai)
633
+ }
634
+ add_moe_components(pipe.dit, moe_config)
635
+
636
+ # 5. 加载训练好的权重
637
+ dit_state_dict = torch.load(dit_path, map_location="cpu")
638
+ pipe.dit.load_state_dict(dit_state_dict, strict=False) # 使用strict=False以兼容新增的MoE组件
639
+ pipe = pipe.to(device)
640
+ model_dtype = next(pipe.dit.parameters()).dtype
641
+
642
+ if hasattr(pipe.dit, 'clean_x_embedder'):
643
+ pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)
644
+
645
+ pipe.scheduler.set_timesteps(50)
646
+
647
+ # 6. 加载初始条件
648
+ print("Loading initial condition frames...")
649
+ initial_latents, encoded_data = load_encoded_video_from_pth(
650
+ condition_pth_path,
651
+ start_frame=start_frame,
652
+ num_frames=initial_condition_frames
653
+ )
654
+
655
+ # 空间裁剪
656
+ target_height, target_width = 60, 104
657
+ C, T, H, W = initial_latents.shape
658
+
659
+ if H > target_height or W > target_width:
660
+ h_start = (H - target_height) // 2
661
+ w_start = (W - target_width) // 2
662
+ initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width]
663
+ H, W = target_height, target_width
664
+
665
+ history_latents = initial_latents.to(device, dtype=model_dtype)
666
+
667
+ print(f"初始history_latents shape: {history_latents.shape}")
668
+
669
+ # 7. 编码prompt - 支持CFG
670
+ if text_guidance_scale > 1.0:
671
+ prompt_emb_pos = pipe.encode_prompt(prompt)
672
+ prompt_emb_neg = pipe.encode_prompt("")
673
+ print(f"使用Text CFG,guidance scale: {text_guidance_scale}")
674
+ else:
675
+ prompt_emb_pos = pipe.encode_prompt(prompt)
676
+ prompt_emb_neg = None
677
+ print("不使用Text CFG")
678
+
679
+ # 8. 加载场景信息(对于NuScenes)
680
+ scene_info = None
681
+ if modality_type == "nuscenes" and scene_info_path and os.path.exists(scene_info_path):
682
+ with open(scene_info_path, 'r') as f:
683
+ scene_info = json.load(f)
684
+ print(f"加载NuScenes场景信息: {scene_info_path}")
685
+
686
+ # 9. 预生成完整的camera embedding序列
687
+ if modality_type == "sekai":
688
+ camera_embedding_full = generate_sekai_camera_embeddings_sliding(
689
+ encoded_data.get('cam_emb', None),
690
+ 0,
691
+ max_history_frames,
692
+ 0,
693
+ 0,
694
+ use_real_poses=use_real_poses
695
+ ).to(device, dtype=model_dtype)
696
+ elif modality_type == "nuscenes":
697
+ camera_embedding_full = generate_nuscenes_camera_embeddings_sliding(
698
+ scene_info,
699
+ 0,
700
+ max_history_frames,
701
+ 0
702
+ ).to(device, dtype=model_dtype)
703
+ elif modality_type == "openx":
704
+ camera_embedding_full = generate_openx_camera_embeddings_sliding(
705
+ encoded_data,
706
+ 0,
707
+ max_history_frames,
708
+ 0,
709
+ use_real_poses=use_real_poses
710
+ ).to(device, dtype=model_dtype)
711
+ else:
712
+ raise ValueError(f"不支持的模态类型: {modality_type}")
713
+
714
+ print(f"完整camera序列shape: {camera_embedding_full.shape}")
715
+
716
+ # 10. 为Camera CFG创建无条件的camera embedding
717
+ if use_camera_cfg:
718
+ camera_embedding_uncond = torch.zeros_like(camera_embedding_full)
719
+ print(f"创建无条件camera embedding用于CFG")
720
+
721
+ # 11. 滑动窗口生成循环
722
+ total_generated = 0
723
+ all_generated_frames = []
724
+
725
+ while total_generated < total_frames_to_generate:
726
+ current_generation = min(frames_per_generation, total_frames_to_generate - total_generated)
727
+ print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}")
728
+ print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}")
729
+
730
+ # FramePack数据准备 - MoE版本
731
+ framepack_data = prepare_framepack_sliding_window_with_camera_moe(
732
+ history_latents,
733
+ current_generation,
734
+ camera_embedding_full,
735
+ start_frame,
736
+ modality_type,
737
+ max_history_frames
738
+ )
739
+
740
+ # 准备输入
741
+ clean_latents = framepack_data['clean_latents'].unsqueeze(0)
742
+ clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0)
743
+ clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0)
744
+ camera_embedding = framepack_data['camera_embedding'].unsqueeze(0)
745
+
746
+ # 准备modality_inputs
747
+ modality_inputs = {modality_type: camera_embedding}
748
+
749
+ # 为CFG准备无条件camera embedding
750
+ if use_camera_cfg:
751
+ camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0)
752
+ modality_inputs_uncond = {modality_type: camera_embedding_uncond_batch}
753
+
754
+ # 索引处理
755
+ latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu()
756
+ clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu()
757
+ clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu()
758
+ clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu()
759
+
760
+ # 初始化要生成的latents
761
+ new_latents = torch.randn(
762
+ 1, C, current_generation, H, W,
763
+ device=device, dtype=model_dtype
764
+ )
765
+
766
+ extra_input = pipe.prepare_extra_input(new_latents)
767
+
768
+ print(f"Camera embedding shape: {camera_embedding.shape}")
769
+ print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}")
770
+
771
+ # 去噪循环 - 支持CFG
772
+ timesteps = pipe.scheduler.timesteps
773
+
774
+ for i, timestep in enumerate(timesteps):
775
+ if i % 10 == 0:
776
+ print(f" 去噪步骤 {i+1}/{len(timesteps)}")
777
+
778
+ timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)
779
+
780
+ with torch.no_grad():
781
+ # CFG推理
782
+ if use_camera_cfg and camera_guidance_scale > 1.0:
783
+ # 条件预测(有camera)
784
+ noise_pred_cond, moe_loess = pipe.dit(
785
+ new_latents,
786
+ timestep=timestep_tensor,
787
+ cam_emb=camera_embedding,
788
+ modality_inputs=modality_inputs, # MoE模态输入
789
+ latent_indices=latent_indices,
790
+ clean_latents=clean_latents,
791
+ clean_latent_indices=clean_latent_indices,
792
+ clean_latents_2x=clean_latents_2x,
793
+ clean_latent_2x_indices=clean_latent_2x_indices,
794
+ clean_latents_4x=clean_latents_4x,
795
+ clean_latent_4x_indices=clean_latent_4x_indices,
796
+ **prompt_emb_pos,
797
+ **extra_input
798
+ )
799
+
800
+ # 无条件预测(无camera)
801
+ noise_pred_uncond, moe_loess = pipe.dit(
802
+ new_latents,
803
+ timestep=timestep_tensor,
804
+ cam_emb=camera_embedding_uncond_batch,
805
+ modality_inputs=modality_inputs_uncond, # MoE无条件模态输入
806
+ latent_indices=latent_indices,
807
+ clean_latents=clean_latents,
808
+ clean_latent_indices=clean_latent_indices,
809
+ clean_latents_2x=clean_latents_2x,
810
+ clean_latent_2x_indices=clean_latent_2x_indices,
811
+ clean_latents_4x=clean_latents_4x,
812
+ clean_latent_4x_indices=clean_latent_4x_indices,
813
+ **(prompt_emb_neg if prompt_emb_neg else prompt_emb_pos),
814
+ **extra_input
815
+ )
816
+
817
+ # Camera CFG
818
+ noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_cond - noise_pred_uncond)
819
+
820
+ # 如果同时使用Text CFG
821
+ if text_guidance_scale > 1.0 and prompt_emb_neg:
822
+ noise_pred_text_uncond, moe_loess = pipe.dit(
823
+ new_latents,
824
+ timestep=timestep_tensor,
825
+ cam_emb=camera_embedding,
826
+ modality_inputs=modality_inputs,
827
+ latent_indices=latent_indices,
828
+ clean_latents=clean_latents,
829
+ clean_latent_indices=clean_latent_indices,
830
+ clean_latents_2x=clean_latents_2x,
831
+ clean_latent_2x_indices=clean_latent_2x_indices,
832
+ clean_latents_4x=clean_latents_4x,
833
+ clean_latent_4x_indices=clean_latent_4x_indices,
834
+ **prompt_emb_neg,
835
+ **extra_input
836
+ )
837
+
838
+ # 应用Text CFG到已经应用Camera CFG的结果
839
+ noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond)
840
+
841
+ elif text_guidance_scale > 1.0 and prompt_emb_neg:
842
+ # 只使用Text CFG
843
+ noise_pred_cond, moe_loess = pipe.dit(
844
+ new_latents,
845
+ timestep=timestep_tensor,
846
+ cam_emb=camera_embedding,
847
+ modality_inputs=modality_inputs,
848
+ latent_indices=latent_indices,
849
+ clean_latents=clean_latents,
850
+ clean_latent_indices=clean_latent_indices,
851
+ clean_latents_2x=clean_latents_2x,
852
+ clean_latent_2x_indices=clean_latent_2x_indices,
853
+ clean_latents_4x=clean_latents_4x,
854
+ clean_latent_4x_indices=clean_latent_4x_indices,
855
+ **prompt_emb_pos,
856
+ **extra_input
857
+ )
858
+
859
+ noise_pred_uncond, moe_loess= pipe.dit(
860
+ new_latents,
861
+ timestep=timestep_tensor,
862
+ cam_emb=camera_embedding,
863
+ modality_inputs=modality_inputs,
864
+ latent_indices=latent_indices,
865
+ clean_latents=clean_latents,
866
+ clean_latent_indices=clean_latent_indices,
867
+ clean_latents_2x=clean_latents_2x,
868
+ clean_latent_2x_indices=clean_latent_2x_indices,
869
+ clean_latents_4x=clean_latents_4x,
870
+ clean_latent_4x_indices=clean_latent_4x_indices,
871
+ **prompt_emb_neg,
872
+ **extra_input
873
+ )
874
+
875
+ noise_pred = noise_pred_uncond + text_guidance_scale * (noise_pred_cond - noise_pred_uncond)
876
+
877
+ else:
878
+ # 标准推理(无CFG)
879
+ noise_pred, moe_loess = pipe.dit(
880
+ new_latents,
881
+ timestep=timestep_tensor,
882
+ cam_emb=camera_embedding,
883
+ modality_inputs=modality_inputs, # MoE模态输入
884
+ latent_indices=latent_indices,
885
+ clean_latents=clean_latents,
886
+ clean_latent_indices=clean_latent_indices,
887
+ clean_latents_2x=clean_latents_2x,
888
+ clean_latent_2x_indices=clean_latent_2x_indices,
889
+ clean_latents_4x=clean_latents_4x,
890
+ clean_latent_4x_indices=clean_latent_4x_indices,
891
+ **prompt_emb_pos,
892
+ **extra_input
893
+ )
894
+
895
+ new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents)
896
+
897
+ # 更新历史
898
+ new_latents_squeezed = new_latents.squeeze(0)
899
+ history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1)
900
+
901
+ # 维护滑动窗口
902
+ if history_latents.shape[1] > max_history_frames:
903
+ first_frame = history_latents[:, 0:1, :, :]
904
+ recent_frames = history_latents[:, -(max_history_frames-1):, :, :]
905
+ history_latents = torch.cat([first_frame, recent_frames], dim=1)
906
+ print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧")
907
+
908
+ print(f"更新后history_latents shape: {history_latents.shape}")
909
+
910
+ all_generated_frames.append(new_latents_squeezed)
911
+ total_generated += current_generation
912
+
913
+ print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧")
914
+
915
+ # 12. 解码和保存
916
+ print("\n🔧 解码生成的视频...")
917
+
918
+ all_generated = torch.cat(all_generated_frames, dim=1)
919
+ final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0)
920
+
921
+ print(f"最终视频shape: {final_video.shape}")
922
+
923
+ decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))
924
+
925
+ print(f"Saving video to {output_path}")
926
+
927
+ video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
928
+ video_np = (video_np * 0.5 + 0.5).clip(0, 1)
929
+ video_np = (video_np * 255).astype(np.uint8)
930
+
931
+ with imageio.get_writer(output_path, fps=20) as writer:
932
+ for frame in video_np:
933
+ writer.append_data(frame)
934
+
935
+ print(f"🔧 MoE FramePack滑动窗口生成完成! 保存到: {output_path}")
936
+ print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧")
937
+ print(f"使用模态: {modality_type}")
938
+
939
+
940
+ def main():
941
+ parser = argparse.ArgumentParser(description="MoE FramePack滑动窗口视频生成 - 支持多模态")
942
+
943
+ # 基础参数
944
+ parser.add_argument("--condition_pth", type=str,
945
+ default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking/00100100001_0004650_0004950/encoded_video.pth")
946
+ #default="/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic/scenes/scene-0001_CAM_FRONT/encoded_video-480p.pth")
947
+ #default="/share_zhuyixuan05/zhuyixuan05/spatialvid/a9a6d37f-0a6c-548a-a494-7d902469f3f2_0000000_0000300/encoded_video.pth")
948
+ #default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded/episode_000001/encoded_video.pth")
949
+ parser.add_argument("--start_frame", type=int, default=0)
950
+ parser.add_argument("--initial_condition_frames", type=int, default=16)
951
+ parser.add_argument("--frames_per_generation", type=int, default=8)
952
+ parser.add_argument("--total_frames_to_generate", type=int, default=24)
953
+ parser.add_argument("--max_history_frames", type=int, default=100)
954
+ parser.add_argument("--use_real_poses", default=True)
955
+ parser.add_argument("--dit_path", type=str,
956
+ default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step25000_first.ckpt")
957
+ parser.add_argument("--output_path", type=str,
958
+ default='/home/zhuyixuan05/ReCamMaster/moe/infer_results/output_moe_framepack_sliding.mp4')
959
+ parser.add_argument("--prompt", type=str,
960
+ default="A drone flying scene in a game world ")
961
+ parser.add_argument("--device", type=str, default="cuda")
962
+
963
+ # 模态类型参数
964
+ parser.add_argument("--modality_type", type=str, choices=["sekai", "nuscenes", "openx"], default="sekai",
965
+ help="模态类型:sekai 或 nuscenes 或 openx")
966
+ parser.add_argument("--scene_info_path", type=str, default=None,
967
+ help="NuScenes场景信息文件路径(仅用于nuscenes模态)")
968
+
969
+ # CFG参数
970
+ parser.add_argument("--use_camera_cfg", default=False,
971
+ help="使用Camera CFG")
972
+ parser.add_argument("--camera_guidance_scale", type=float, default=2.0,
973
+ help="Camera guidance scale for CFG")
974
+ parser.add_argument("--text_guidance_scale", type=float, default=1.0,
975
+ help="Text guidance scale for CFG")
976
+
977
+ # MoE参数
978
+ parser.add_argument("--moe_num_experts", type=int, default=3, help="专家数量")
979
+ parser.add_argument("--moe_top_k", type=int, default=1, help="Top-K专家")
980
+ parser.add_argument("--moe_hidden_dim", type=int, default=None, help="MoE隐藏层维度")
981
+
982
+ args = parser.parse_args()
983
+
984
+ print(f"🔧 MoE FramePack CFG生成设置:")
985
+ print(f"模态类型: {args.modality_type}")
986
+ print(f"Camera CFG: {args.use_camera_cfg}")
987
+ if args.use_camera_cfg:
988
+ print(f"Camera guidance scale: {args.camera_guidance_scale}")
989
+ print(f"Text guidance scale: {args.text_guidance_scale}")
990
+ print(f"MoE配置: experts={args.moe_num_experts}, top_k={args.moe_top_k}")
991
+ print(f"DiT{args.dit_path}")
992
+
993
+ # 验证NuScenes参数
994
+ if args.modality_type == "nuscenes" and not args.scene_info_path:
995
+ print("⚠️ 使用NuScenes模态但未提供scene_info_path,将使用合成pose数据")
996
+
997
+ inference_moe_framepack_sliding_window(
998
+ condition_pth_path=args.condition_pth,
999
+ dit_path=args.dit_path,
1000
+ output_path=args.output_path,
1001
+ start_frame=args.start_frame,
1002
+ initial_condition_frames=args.initial_condition_frames,
1003
+ frames_per_generation=args.frames_per_generation,
1004
+ total_frames_to_generate=args.total_frames_to_generate,
1005
+ max_history_frames=args.max_history_frames,
1006
+ device=args.device,
1007
+ prompt=args.prompt,
1008
+ modality_type=args.modality_type,
1009
+ use_real_poses=args.use_real_poses,
1010
+ scene_info_path=args.scene_info_path,
1011
+ # CFG参数
1012
+ use_camera_cfg=args.use_camera_cfg,
1013
+ camera_guidance_scale=args.camera_guidance_scale,
1014
+ text_guidance_scale=args.text_guidance_scale,
1015
+ # MoE参数
1016
+ moe_num_experts=args.moe_num_experts,
1017
+ moe_top_k=args.moe_top_k,
1018
+ moe_hidden_dim=args.moe_hidden_dim
1019
+ )
1020
+
1021
+
1022
+ if __name__ == "__main__":
1023
+ main()
scripts/infer_moe_spatialvid.py ADDED
@@ -0,0 +1,1008 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ from PIL import Image
6
+ import imageio
7
+ import json
8
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
9
+ import argparse
10
+ from torchvision.transforms import v2
11
+ from einops import rearrange
12
+ import copy
13
+ from scipy.spatial.transform import Rotation as R
14
+
15
+ def compute_relative_pose_matrix(pose1, pose2):
16
+ """
17
+ 计算相邻两帧的相对位姿,返回3×4的相机矩阵 [R_rel | t_rel]
18
+
19
+ 参数:
20
+ pose1: 第i帧的相机位姿,形状为(7,)的数组 [tx1, ty1, tz1, qx1, qy1, qz1, qw1]
21
+ pose2: 第i+1帧的相机位姿,形状为(7,)的数组 [tx2, ty2, tz2, qx2, qy2, qz2, qw2]
22
+
23
+ 返回:
24
+ relative_matrix: 3×4的相对位姿矩阵,前3列是旋转矩阵R_rel,第4列是平移向量t_rel
25
+ """
26
+ # 分离平移向量和四元数
27
+ t1 = pose1[:3] # 第i帧平移 [tx1, ty1, tz1]
28
+ q1 = pose1[3:] # 第i帧四元数 [qx1, qy1, qz1, qw1]
29
+ t2 = pose2[:3] # 第i+1帧平移
30
+ q2 = pose2[3:] # 第i+1帧四元数
31
+
32
+ # 1. 计算相对旋转矩阵 R_rel
33
+ rot1 = R.from_quat(q1) # 第i帧旋转
34
+ rot2 = R.from_quat(q2) # 第i+1帧旋转
35
+ rot_rel = rot2 * rot1.inv() # 相对旋转 = 后一帧旋转 × 前一帧旋转的逆
36
+ R_rel = rot_rel.as_matrix() # 转换为3×3矩阵
37
+
38
+ # 2. 计算相对平移向量 t_rel
39
+ R1_T = rot1.as_matrix().T # 前一帧旋转矩阵的转置(等价于逆)
40
+ t_rel = R1_T @ (t2 - t1) # 相对平移 = R1^T × (t2 - t1)
41
+
42
+ # 3. 组合为3×4矩阵 [R_rel | t_rel]
43
+ relative_matrix = np.hstack([R_rel, t_rel.reshape(3, 1)])
44
+
45
+ return relative_matrix
46
+
47
+ def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
48
+ """从pth文件加载预编码的视频数据"""
49
+ print(f"Loading encoded video from {pth_path}")
50
+
51
+ encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu")
52
+ full_latents = encoded_data['latents'] # [C, T, H, W]
53
+
54
+ print(f"Full latents shape: {full_latents.shape}")
55
+ print(f"Extracting frames {start_frame} to {start_frame + num_frames}")
56
+
57
+ if start_frame + num_frames > full_latents.shape[1]:
58
+ raise ValueError(f"Not enough frames: requested {start_frame + num_frames}, available {full_latents.shape[1]}")
59
+
60
+ condition_latents = full_latents[:, start_frame:start_frame + num_frames, :, :]
61
+ print(f"Extracted condition latents shape: {condition_latents.shape}")
62
+
63
+ return condition_latents, encoded_data
64
+
65
+
66
+ def compute_relative_pose(pose_a, pose_b, use_torch=False):
67
+ """计算相机B相对于相机A的相对位姿矩阵"""
68
+ assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
69
+ assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"
70
+
71
+ if use_torch:
72
+ if not isinstance(pose_a, torch.Tensor):
73
+ pose_a = torch.from_numpy(pose_a).float()
74
+ if not isinstance(pose_b, torch.Tensor):
75
+ pose_b = torch.from_numpy(pose_b).float()
76
+
77
+ pose_a_inv = torch.inverse(pose_a)
78
+ relative_pose = torch.matmul(pose_b, pose_a_inv)
79
+ else:
80
+ if not isinstance(pose_a, np.ndarray):
81
+ pose_a = np.array(pose_a, dtype=np.float32)
82
+ if not isinstance(pose_b, np.ndarray):
83
+ pose_b = np.array(pose_b, dtype=np.float32)
84
+
85
+ pose_a_inv = np.linalg.inv(pose_a)
86
+ relative_pose = np.matmul(pose_b, pose_a_inv)
87
+
88
+ return relative_pose
89
+
90
+
91
+ def replace_dit_model_in_manager():
92
+ """替换DiT模型类为MoE版本"""
93
+ from diffsynth.models.wan_video_dit_moe import WanModelMoe
94
+ from diffsynth.configs.model_config import model_loader_configs
95
+
96
+ for i, config in enumerate(model_loader_configs):
97
+ keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource = config
98
+
99
+ if 'wan_video_dit' in model_names:
100
+ new_model_names = []
101
+ new_model_classes = []
102
+
103
+ for name, cls in zip(model_names, model_classes):
104
+ if name == 'wan_video_dit':
105
+ new_model_names.append(name)
106
+ new_model_classes.append(WanModelMoe)
107
+ print(f"✅ 替换了模型类: {name} -> WanModelMoe")
108
+ else:
109
+ new_model_names.append(name)
110
+ new_model_classes.append(cls)
111
+
112
+ model_loader_configs[i] = (keys_hash, keys_hash_with_shape, new_model_names, new_model_classes, model_resource)
113
+
114
+
115
+ def add_framepack_components(dit_model):
116
+ """添加FramePack相关组件"""
117
+ if not hasattr(dit_model, 'clean_x_embedder'):
118
+ inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]
119
+
120
+ class CleanXEmbedder(nn.Module):
121
+ def __init__(self, inner_dim):
122
+ super().__init__()
123
+ self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
124
+ self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
125
+ self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))
126
+
127
+ def forward(self, x, scale="1x"):
128
+ if scale == "1x":
129
+ x = x.to(self.proj.weight.dtype)
130
+ return self.proj(x)
131
+ elif scale == "2x":
132
+ x = x.to(self.proj_2x.weight.dtype)
133
+ return self.proj_2x(x)
134
+ elif scale == "4x":
135
+ x = x.to(self.proj_4x.weight.dtype)
136
+ return self.proj_4x(x)
137
+ else:
138
+ raise ValueError(f"Unsupported scale: {scale}")
139
+
140
+ dit_model.clean_x_embedder = CleanXEmbedder(inner_dim)
141
+ model_dtype = next(dit_model.parameters()).dtype
142
+ dit_model.clean_x_embedder = dit_model.clean_x_embedder.to(dtype=model_dtype)
143
+ print("✅ 添加了FramePack的clean_x_embedder组件")
144
+
145
+
146
+ def add_moe_components(dit_model, moe_config):
147
+ """🔧 添加MoE相关组件 - 修正版本"""
148
+ if not hasattr(dit_model, 'moe_config'):
149
+ dit_model.moe_config = moe_config
150
+ print("✅ 添加了MoE配置到模型")
151
+
152
+ # 为每个block动态添加MoE组件
153
+ dim = dit_model.blocks[0].self_attn.q.weight.shape[0]
154
+ unified_dim = moe_config.get("unified_dim", 25)
155
+
156
+ for i, block in enumerate(dit_model.blocks):
157
+ from diffsynth.models.wan_video_dit_moe import ModalityProcessor, MultiModalMoE
158
+
159
+ # Sekai模态处理器 - 输出unified_dim
160
+ block.sekai_processor = ModalityProcessor("sekai", 13, unified_dim)
161
+
162
+ # # NuScenes模态处理器 - 输出unified_dim
163
+ # block.nuscenes_processor = ModalityProcessor("nuscenes", 8, unified_dim)
164
+
165
+ # MoE网络 - 输入unified_dim,输出dim
166
+ block.moe = MultiModalMoE(
167
+ unified_dim=unified_dim,
168
+ output_dim=dim, # 输出维度匹配transformer block的dim
169
+ num_experts=moe_config.get("num_experts", 4),
170
+ top_k=moe_config.get("top_k", 2)
171
+ )
172
+
173
+ print(f"✅ Block {i} 添加了MoE组件 (unified_dim: {unified_dim}, experts: {moe_config.get('num_experts', 4)})")
174
+
175
+
176
+ def generate_sekai_camera_embeddings_sliding(cam_data, start_frame, current_history_length, new_frames, total_generated, use_real_poses=True):
177
+ """为Sekai数据集生成camera embeddings - 滑动窗口版本"""
178
+ time_compression_ratio = 4
179
+
180
+ # 计算FramePack实际需要的camera帧数
181
+ framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames
182
+
183
+ if use_real_poses and cam_data is not None and 'extrinsic' in cam_data:
184
+ print("🔧 使用真实Sekai camera数据")
185
+ cam_extrinsic = cam_data['extrinsic']
186
+
187
+ # 确保生成足够长的camera序列
188
+ max_needed_frames = max(
189
+ start_frame + current_history_length + new_frames,
190
+ framepack_needed_frames,
191
+ 30
192
+ )
193
+
194
+ print(f"🔧 计算Sekai camera序列长度:")
195
+ print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
196
+ print(f" - FramePack需求: {framepack_needed_frames}")
197
+ print(f" - 最终生成: {max_needed_frames}")
198
+
199
+ relative_poses = []
200
+ for i in range(max_needed_frames):
201
+ # 计算当前帧在原始序列中的位置
202
+ frame_idx = i * time_compression_ratio
203
+ next_frame_idx = frame_idx + time_compression_ratio
204
+
205
+ if next_frame_idx < len(cam_extrinsic):
206
+ cam_prev = cam_extrinsic[frame_idx]
207
+ cam_next = cam_extrinsic[next_frame_idx]
208
+ relative_pose = compute_relative_pose_matrix(cam_prev, cam_next)
209
+ relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
210
+ else:
211
+ # 超出范围,使用零运动
212
+ print(f"⚠️ 帧{frame_idx}超出camera数据范围,使用零运动")
213
+ relative_poses.append(torch.zeros(3, 4))
214
+
215
+ pose_embedding = torch.stack(relative_poses, dim=0)
216
+ pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
217
+
218
+ # 创建对应长度的mask序列
219
+ mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
220
+ # 从start_frame到current_history_length标记为condition
221
+ condition_end = min(start_frame + current_history_length, max_needed_frames)
222
+ mask[start_frame:condition_end] = 1.0
223
+
224
+ camera_embedding = torch.cat([pose_embedding, mask], dim=1)
225
+ print(f"🔧 Sekai真实camera embedding shape: {camera_embedding.shape}")
226
+ return camera_embedding.to(torch.bfloat16)
227
+
228
+ else:
229
+ print("🔧 使用Sekai合成camera数据")
230
+
231
+ max_needed_frames = max(
232
+ start_frame + current_history_length + new_frames,
233
+ framepack_needed_frames,
234
+ 30
235
+ )
236
+
237
+ print(f"🔧 生成Sekai合成camera帧数: {max_needed_frames}")
238
+ relative_poses = []
239
+ for i in range(max_needed_frames):
240
+ # 持续左转运动模式
241
+ yaw_per_frame = 0.05 # 每帧左转(正角度表示左转)
242
+ forward_speed = 0.005 # 每帧前进距离
243
+
244
+ pose = np.eye(4, dtype=np.float32)
245
+
246
+ # 旋转矩阵(绕Y轴左转)
247
+ cos_yaw = np.cos(yaw_per_frame)
248
+ sin_yaw = np.sin(yaw_per_frame)
249
+
250
+ pose[0, 0] = cos_yaw
251
+ pose[0, 2] = sin_yaw
252
+ pose[2, 0] = -sin_yaw
253
+ pose[2, 2] = cos_yaw
254
+
255
+ # 平移(在旋转后的局部坐标系中前进)
256
+ pose[2, 3] = -forward_speed # 局部Z轴负方向(前进)
257
+
258
+ # 添加轻微的向心运动,模拟圆形轨迹
259
+ radius_drift = 0.002 # 向圆心的轻微漂移
260
+ pose[0, 3] = -radius_drift # 局部X轴负方向(向左)
261
+
262
+ relative_pose = pose[:3, :]
263
+ relative_poses.append(torch.as_tensor(relative_pose))
264
+
265
+ pose_embedding = torch.stack(relative_poses, dim=0)
266
+ pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
267
+
268
+ # 创建对应长度的mask序列
269
+ mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
270
+ condition_end = min(start_frame + current_history_length, max_needed_frames)
271
+ mask[start_frame:condition_end] = 1.0
272
+
273
+ camera_embedding = torch.cat([pose_embedding, mask], dim=1)
274
+ print(f"🔧 Sekai合成camera embedding shape: {camera_embedding.shape}")
275
+ return camera_embedding.to(torch.bfloat16)
276
+
277
+ def generate_openx_camera_embeddings_sliding(encoded_data, start_frame, current_history_length, new_frames, use_real_poses):
278
+ """为OpenX数据集生成camera embeddings - 滑动窗口版本"""
279
+ time_compression_ratio = 4
280
+
281
+ # 计算FramePack实际需要的camera帧数
282
+ framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames
283
+
284
+ if use_real_poses and encoded_data is not None and 'cam_emb' in encoded_data and 'extrinsic' in encoded_data['cam_emb']:
285
+ print("🔧 使用OpenX真实camera数据")
286
+ cam_extrinsic = encoded_data['cam_emb']['extrinsic']
287
+
288
+ # 确保生成足够长的camera序列
289
+ max_needed_frames = max(
290
+ start_frame + current_history_length + new_frames,
291
+ framepack_needed_frames,
292
+ 30
293
+ )
294
+
295
+ print(f"🔧 计算OpenX camera序列长度:")
296
+ print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
297
+ print(f" - FramePack需求: {framepack_needed_frames}")
298
+ print(f" - 最终生成: {max_needed_frames}")
299
+
300
+ relative_poses = []
301
+ for i in range(max_needed_frames):
302
+ # OpenX使用4倍间隔,类似sekai但处理更短的序列
303
+ frame_idx = i * time_compression_ratio
304
+ next_frame_idx = frame_idx + time_compression_ratio
305
+
306
+ if next_frame_idx < len(cam_extrinsic):
307
+ cam_prev = cam_extrinsic[frame_idx]
308
+ cam_next = cam_extrinsic[next_frame_idx]
309
+ relative_pose = compute_relative_pose(cam_prev, cam_next)
310
+ relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
311
+ else:
312
+ # 超出范围,使用零运动
313
+ print(f"⚠️ 帧{frame_idx}超出OpenX camera数据范围,使用零运动")
314
+ relative_poses.append(torch.zeros(3, 4))
315
+
316
+ pose_embedding = torch.stack(relative_poses, dim=0)
317
+ pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
318
+
319
+ # 创建对应长度的mask序列
320
+ mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
321
+ # 从start_frame到current_history_length标记为condition
322
+ condition_end = min(start_frame + current_history_length, max_needed_frames)
323
+ mask[start_frame:condition_end] = 1.0
324
+
325
+ camera_embedding = torch.cat([pose_embedding, mask], dim=1)
326
+ print(f"🔧 OpenX真实camera embedding shape: {camera_embedding.shape}")
327
+ return camera_embedding.to(torch.bfloat16)
328
+
329
+ else:
330
+ print("🔧 使用OpenX合成camera数据")
331
+
332
+ max_needed_frames = max(
333
+ start_frame + current_history_length + new_frames,
334
+ framepack_needed_frames,
335
+ 30
336
+ )
337
+
338
+ print(f"🔧 生成OpenX合成camera帧数: {max_needed_frames}")
339
+ relative_poses = []
340
+ for i in range(max_needed_frames):
341
+ # OpenX机器人操作运动模式 - 较小的运动幅度
342
+ # 模拟机器人手臂的精细操作运动
343
+ roll_per_frame = 0.02 # 轻微翻滚
344
+ pitch_per_frame = 0.01 # 轻微俯仰
345
+ yaw_per_frame = 0.015 # 轻微偏航
346
+ forward_speed = 0.003 # 较慢的前进速度
347
+
348
+ pose = np.eye(4, dtype=np.float32)
349
+
350
+ # 复合旋转 - 模拟机器人手臂的复杂运动
351
+ # 绕X轴旋转(roll)
352
+ cos_roll = np.cos(roll_per_frame)
353
+ sin_roll = np.sin(roll_per_frame)
354
+ # 绕Y轴旋转(pitch)
355
+ cos_pitch = np.cos(pitch_per_frame)
356
+ sin_pitch = np.sin(pitch_per_frame)
357
+ # 绕Z轴旋转(yaw)
358
+ cos_yaw = np.cos(yaw_per_frame)
359
+ sin_yaw = np.sin(yaw_per_frame)
360
+
361
+ # 简化的复合旋转矩阵(ZYX顺序)
362
+ pose[0, 0] = cos_yaw * cos_pitch
363
+ pose[0, 1] = cos_yaw * sin_pitch * sin_roll - sin_yaw * cos_roll
364
+ pose[0, 2] = cos_yaw * sin_pitch * cos_roll + sin_yaw * sin_roll
365
+ pose[1, 0] = sin_yaw * cos_pitch
366
+ pose[1, 1] = sin_yaw * sin_pitch * sin_roll + cos_yaw * cos_roll
367
+ pose[1, 2] = sin_yaw * sin_pitch * cos_roll - cos_yaw * sin_roll
368
+ pose[2, 0] = -sin_pitch
369
+ pose[2, 1] = cos_pitch * sin_roll
370
+ pose[2, 2] = cos_pitch * cos_roll
371
+
372
+ # 平移 - 模拟机器人操作的精细移动
373
+ pose[0, 3] = forward_speed * 0.5 # X方向轻微移动
374
+ pose[1, 3] = forward_speed * 0.3 # Y方向轻微移动
375
+ pose[2, 3] = -forward_speed # Z方向(深度)主要移动
376
+
377
+ relative_pose = pose[:3, :]
378
+ relative_poses.append(torch.as_tensor(relative_pose))
379
+
380
+ pose_embedding = torch.stack(relative_poses, dim=0)
381
+ pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
382
+
383
+ # 创建对应长度的mask序列
384
+ mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
385
+ condition_end = min(start_frame + current_history_length, max_needed_frames)
386
+ mask[start_frame:condition_end] = 1.0
387
+
388
+ camera_embedding = torch.cat([pose_embedding, mask], dim=1)
389
+ print(f"🔧 OpenX合成camera embedding shape: {camera_embedding.shape}")
390
+ return camera_embedding.to(torch.bfloat16)
391
+
392
+
393
+ def generate_nuscenes_camera_embeddings_sliding(scene_info, start_frame, current_history_length, new_frames):
394
+ """为NuScenes数据集生成camera embeddings - 滑动窗口版本 - 修正版,与train_moe.py保持一致"""
395
+ time_compression_ratio = 4
396
+
397
+ # 计算FramePack实际需要的camera帧数
398
+ framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames
399
+
400
+ if scene_info is not None and 'keyframe_poses' in scene_info:
401
+ print("🔧 使用NuScenes真实pose数据")
402
+ keyframe_poses = scene_info['keyframe_poses']
403
+
404
+ if len(keyframe_poses) == 0:
405
+ print("⚠️ NuScenes keyframe_poses为空,使用零pose")
406
+ max_needed_frames = max(framepack_needed_frames, 30)
407
+
408
+ pose_sequence = torch.zeros(max_needed_frames, 7, dtype=torch.float32)
409
+
410
+ mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
411
+ condition_end = min(start_frame + current_history_length, max_needed_frames)
412
+ mask[start_frame:condition_end] = 1.0
413
+
414
+ camera_embedding = torch.cat([pose_sequence, mask], dim=1) # [max_needed_frames, 8]
415
+ print(f"🔧 NuScenes零pose embedding shape: {camera_embedding.shape}")
416
+ return camera_embedding.to(torch.bfloat16)
417
+
418
+ # 使用第一个pose作为参考
419
+ reference_pose = keyframe_poses[0]
420
+
421
+ max_needed_frames = max(framepack_needed_frames, 30)
422
+
423
+ pose_vecs = []
424
+ for i in range(max_needed_frames):
425
+ if i < len(keyframe_poses):
426
+ current_pose = keyframe_poses[i]
427
+
428
+ # 计算相对位移
429
+ translation = torch.tensor(
430
+ np.array(current_pose['translation']) - np.array(reference_pose['translation']),
431
+ dtype=torch.float32
432
+ )
433
+
434
+ # 计算相对旋转(简化版本)
435
+ rotation = torch.tensor(current_pose['rotation'], dtype=torch.float32)
436
+
437
+ pose_vec = torch.cat([translation, rotation], dim=0) # [7D]
438
+ else:
439
+ # 超出范围,使用零pose
440
+ pose_vec = torch.cat([
441
+ torch.zeros(3, dtype=torch.float32),
442
+ torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32)
443
+ ], dim=0) # [7D]
444
+
445
+ pose_vecs.append(pose_vec)
446
+
447
+ pose_sequence = torch.stack(pose_vecs, dim=0) # [max_needed_frames, 7]
448
+
449
+ # 创建mask
450
+ mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
451
+ condition_end = min(start_frame + current_history_length, max_needed_frames)
452
+ mask[start_frame:condition_end] = 1.0
453
+
454
+ camera_embedding = torch.cat([pose_sequence, mask], dim=1) # [max_needed_frames, 8]
455
+ print(f"🔧 NuScenes真实pose embedding shape: {camera_embedding.shape}")
456
+ return camera_embedding.to(torch.bfloat16)
457
+
458
+ else:
459
+ print("🔧 使用NuScenes合成pose数据")
460
+ max_needed_frames = max(framepack_needed_frames, 30)
461
+
462
+ # 创建合成运动序列
463
+ pose_vecs = []
464
+ for i in range(max_needed_frames):
465
+ # 简单的前进运动
466
+ translation = torch.tensor([0.0, 0.0, i * 0.1], dtype=torch.float32) # 沿Z轴前进
467
+ rotation = torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32) # 无旋转
468
+
469
+ pose_vec = torch.cat([translation, rotation], dim=0) # [7D]
470
+ pose_vecs.append(pose_vec)
471
+
472
+ pose_sequence = torch.stack(pose_vecs, dim=0)
473
+
474
+ # 创建mask
475
+ mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
476
+ condition_end = min(start_frame + current_history_length, max_needed_frames)
477
+ mask[start_frame:condition_end] = 1.0
478
+
479
+ camera_embedding = torch.cat([pose_sequence, mask], dim=1) # [max_needed_frames, 8]
480
+ print(f"🔧 NuScenes合成pose embedding shape: {camera_embedding.shape}")
481
+ return camera_embedding.to(torch.bfloat16)
482
+
483
+ def prepare_framepack_sliding_window_with_camera_moe(history_latents, target_frames_to_generate, camera_embedding_full, start_frame, modality_type, max_history_frames=49):
484
+ """FramePack滑动窗口机制 - MoE版本"""
485
+ # history_latents: [C, T, H, W] 当前的历史latents
486
+ C, T, H, W = history_latents.shape
487
+
488
+ # 固定索引结构(这决定了需要的camera帧数)
489
+ total_indices_length = 1 + 16 + 2 + 1 + target_frames_to_generate
490
+ indices = torch.arange(0, total_indices_length)
491
+ split_sizes = [1, 16, 2, 1, target_frames_to_generate]
492
+ clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = \
493
+ indices.split(split_sizes, dim=0)
494
+ clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=0)
495
+
496
+ # 检查camera长度是否足够
497
+ if camera_embedding_full.shape[0] < total_indices_length:
498
+ shortage = total_indices_length - camera_embedding_full.shape[0]
499
+ padding = torch.zeros(shortage, camera_embedding_full.shape[1],
500
+ dtype=camera_embedding_full.dtype, device=camera_embedding_full.device)
501
+ camera_embedding_full = torch.cat([camera_embedding_full, padding], dim=0)
502
+
503
+ # 从完整camera序列中选取对应部分
504
+ combined_camera = camera_embedding_full[:total_indices_length, :].clone()
505
+
506
+ # 根据当前history length重新设置mask
507
+ combined_camera[:, -1] = 0.0 # 先全部设为target (0)
508
+
509
+ # 设置condition mask:前19帧根据实际历史长度决定
510
+ if T > 0:
511
+ available_frames = min(T, 19)
512
+ start_pos = 19 - available_frames
513
+ combined_camera[start_pos:19, -1] = 1.0 # 将有效的clean latents对应的camera标记为condition
514
+
515
+ print(f"🔧 MoE Camera mask更新:")
516
+ print(f" - 历史帧数: {T}")
517
+ print(f" - 有效condition帧数: {available_frames if T > 0 else 0}")
518
+ print(f" - 模态类型: {modality_type}")
519
+
520
+ # 处理latents
521
+ clean_latents_combined = torch.zeros(C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device)
522
+
523
+ if T > 0:
524
+ available_frames = min(T, 19)
525
+ start_pos = 19 - available_frames
526
+ clean_latents_combined[:, start_pos:, :, :] = history_latents[:, -available_frames:, :, :]
527
+
528
+ clean_latents_4x = clean_latents_combined[:, 0:16, :, :]
529
+ clean_latents_2x = clean_latents_combined[:, 16:18, :, :]
530
+ clean_latents_1x = clean_latents_combined[:, 18:19, :, :]
531
+
532
+ if T > 0:
533
+ start_latent = history_latents[:, 0:1, :, :]
534
+ else:
535
+ start_latent = torch.zeros(C, 1, H, W, dtype=history_latents.dtype, device=history_latents.device)
536
+
537
+ clean_latents = torch.cat([start_latent, clean_latents_1x], dim=1)
538
+
539
+ return {
540
+ 'latent_indices': latent_indices,
541
+ 'clean_latents': clean_latents,
542
+ 'clean_latents_2x': clean_latents_2x,
543
+ 'clean_latents_4x': clean_latents_4x,
544
+ 'clean_latent_indices': clean_latent_indices,
545
+ 'clean_latent_2x_indices': clean_latent_2x_indices,
546
+ 'clean_latent_4x_indices': clean_latent_4x_indices,
547
+ 'camera_embedding': combined_camera,
548
+ 'modality_type': modality_type, # 新增模态类型信息
549
+ 'current_length': T,
550
+ 'next_length': T + target_frames_to_generate
551
+ }
552
+
553
+
554
def inference_moe_framepack_sliding_window(
    condition_pth_path,
    dit_path,
    output_path="moe/infer_results/output_moe_framepack_sliding.mp4",
    start_frame=0,
    initial_condition_frames=8,
    frames_per_generation=4,
    total_frames_to_generate=32,
    max_history_frames=49,
    device="cuda",
    prompt="A video of a scene shot using a pedestrian's front camera while walking",
    modality_type="sekai",  # "sekai", "nuscenes" or "openx"
    use_real_poses=True,
    scene_info_path=None,  # only used for the NuScenes modality
    # CFG parameters
    use_camera_cfg=True,
    camera_guidance_scale=2.0,
    text_guidance_scale=1.0,
    # MoE parameters
    moe_num_experts=4,
    moe_top_k=2,
    moe_hidden_dim=None
):
    """
    MoE FramePack sliding-window video generation with multi-modal camera control.

    Loads the Wan2.1 T2V pipeline, patches it with FramePack multi-scale
    clean-latent embedders and per-block MoE camera adapters, then
    autoregressively generates `total_frames_to_generate` latent frames in
    chunks of `frames_per_generation`. Each chunk is conditioned on a sliding
    window of history latents plus a camera-pose embedding, with optional
    camera and text classifier-free guidance. The decoded video is written to
    `output_path`.

    Args:
        condition_pth_path: .pth file with pre-encoded latents (and camera
            data) used as the initial condition.
        dit_path: checkpoint containing the fine-tuned DiT weights (including
            FramePack/MoE components; missing keys are tolerated).
        output_path: destination .mp4 path.
        start_frame: first latent frame of the condition clip to use.
        initial_condition_frames: number of latent frames taken as initial history.
        frames_per_generation: latent frames generated per sliding-window step.
        total_frames_to_generate: total latent frames to generate.
        max_history_frames: cap on the history window; when exceeded, the very
            first frame is kept as an anchor plus the most recent frames.
        device: torch device string for inference.
        prompt: text prompt.
        modality_type: camera-modality key ("sekai", "nuscenes" or "openx").
        use_real_poses: use recorded camera poses when available, else synthetic.
        scene_info_path: JSON with NuScenes scene info (nuscenes modality only).
        use_camera_cfg: enable camera classifier-free guidance.
        camera_guidance_scale: camera CFG scale (>1.0 activates camera CFG).
        text_guidance_scale: text CFG scale (>1.0 activates text CFG).
        moe_num_experts / moe_top_k / moe_hidden_dim: MoE configuration.

    Raises:
        ValueError: for an unsupported `modality_type`.
    """
    # NOTE(review): fails if output_path has no directory component
    # (os.path.dirname returns "" and os.makedirs("") raises) — confirm callers
    # always pass a path with a directory.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    print(f"🔧 MoE FramePack滑动窗口生成开始...")
    print(f"模态类型: {modality_type}")
    print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}")
    print(f"Text guidance scale: {text_guidance_scale}")
    print(f"MoE配置: experts={moe_num_experts}, top_k={moe_top_k}")

    # 1. Model initialization: swap in the MoE DiT class before loading.
    replace_dit_model_in_manager()

    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # 2. Add the legacy per-block camera encoder (kept for checkpoint
    # compatibility). Zero-initialized encoder + identity projector so the
    # untrained path is a no-op until the checkpoint overwrites it.
    dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        block.cam_encoder = nn.Linear(13, dim)
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    # 3. Add FramePack components (multi-scale clean-latent embedder).
    add_framepack_components(pipe.dit)

    # 4. Add MoE components. Input dims = flattened pose + 1 mask column.
    moe_config = {
        "num_experts": moe_num_experts,
        "top_k": moe_top_k,
        "hidden_dim": moe_hidden_dim or dim * 2,
        "sekai_input_dim": 13,  # Sekai: 12-dim pose + 1-dim mask
        "nuscenes_input_dim": 8,  # NuScenes: 7-dim pose + 1-dim mask
        "openx_input_dim": 13  # OpenX: 12-dim pose + 1-dim mask (like sekai)
    }
    add_moe_components(pipe.dit, moe_config)

    # 5. Load the fine-tuned weights.
    dit_state_dict = torch.load(dit_path, map_location="cpu")
    # strict=False so newly added MoE components may be absent from older checkpoints.
    pipe.dit.load_state_dict(dit_state_dict, strict=False)
    pipe = pipe.to(device)
    model_dtype = next(pipe.dit.parameters()).dtype

    if hasattr(pipe.dit, 'clean_x_embedder'):
        pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)

    pipe.scheduler.set_timesteps(50)

    # 6. Load the initial condition latents.
    print("Loading initial condition frames...")
    initial_latents, encoded_data = load_encoded_video_from_pth(
        condition_pth_path,
        start_frame=start_frame,
        num_frames=initial_condition_frames
    )

    # Center-crop latents spatially to the expected latent resolution.
    target_height, target_width = 60, 104
    C, T, H, W = initial_latents.shape

    if H > target_height or W > target_width:
        h_start = (H - target_height) // 2
        w_start = (W - target_width) // 2
        initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width]
        H, W = target_height, target_width

    history_latents = initial_latents.to(device, dtype=model_dtype)

    print(f"初始history_latents shape: {history_latents.shape}")

    # 7. Encode the prompt; also encode an empty prompt when text CFG is on.
    if text_guidance_scale > 1.0:
        prompt_emb_pos = pipe.encode_prompt(prompt)
        prompt_emb_neg = pipe.encode_prompt("")
        print(f"使用Text CFG,guidance scale: {text_guidance_scale}")
    else:
        prompt_emb_pos = pipe.encode_prompt(prompt)
        prompt_emb_neg = None
        print("不使用Text CFG")

    # 8. Load scene info (NuScenes modality only).
    scene_info = None
    if modality_type == "nuscenes" and scene_info_path and os.path.exists(scene_info_path):
        with open(scene_info_path, 'r') as f:
            scene_info = json.load(f)
        print(f"加载NuScenes场景信息: {scene_info_path}")

    # 9. Pre-generate the full camera-embedding sequence for the run.
    if modality_type == "sekai":
        camera_embedding_full = generate_sekai_camera_embeddings_sliding(
            encoded_data.get('cam_emb', None),
            0,
            max_history_frames,
            0,
            0,
            use_real_poses=use_real_poses
        ).to(device, dtype=model_dtype)
    elif modality_type == "nuscenes":
        camera_embedding_full = generate_nuscenes_camera_embeddings_sliding(
            scene_info,
            0,
            max_history_frames,
            0
        ).to(device, dtype=model_dtype)
    elif modality_type == "openx":
        camera_embedding_full = generate_openx_camera_embeddings_sliding(
            encoded_data,
            0,
            max_history_frames,
            0,
            use_real_poses=use_real_poses
        ).to(device, dtype=model_dtype)
    else:
        raise ValueError(f"不支持的模态类型: {modality_type}")

    print(f"完整camera序列shape: {camera_embedding_full.shape}")

    # 10. All-zero camera embedding as the unconditional input for camera CFG.
    if use_camera_cfg:
        camera_embedding_uncond = torch.zeros_like(camera_embedding_full)
        print(f"创建无条件camera embedding用于CFG")

    # 11. Sliding-window generation loop.
    total_generated = 0
    all_generated_frames = []

    while total_generated < total_frames_to_generate:
        current_generation = min(frames_per_generation, total_frames_to_generate - total_generated)
        print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}")
        print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}")

        # FramePack data preparation (MoE variant): packs history into
        # start/4x/2x/1x context latents plus aligned camera embeddings.
        framepack_data = prepare_framepack_sliding_window_with_camera_moe(
            history_latents,
            current_generation,
            camera_embedding_full,
            start_frame,
            modality_type,
            max_history_frames
        )

        # Add the batch dimension expected by the DiT.
        clean_latents = framepack_data['clean_latents'].unsqueeze(0)
        clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0)
        clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0)
        camera_embedding = framepack_data['camera_embedding'].unsqueeze(0)

        # Route the camera embedding to the matching MoE modality branch.
        modality_inputs = {modality_type: camera_embedding}

        # Unconditional camera embedding, truncated to the same length, for CFG.
        if use_camera_cfg:
            camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0)
            modality_inputs_uncond = {modality_type: camera_embedding_uncond_batch}

        # Index tensors stay on CPU (presumably consumed as plain indices by
        # the DiT — TODO confirm).
        latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu()
        clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu()
        clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu()
        clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu()

        # Initialize the latents to be generated with pure noise.
        new_latents = torch.randn(
            1, C, current_generation, H, W,
            device=device, dtype=model_dtype
        )

        extra_input = pipe.prepare_extra_input(new_latents)

        print(f"Camera embedding shape: {camera_embedding.shape}")
        print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}")

        # Denoising loop with optional camera/text CFG.
        timesteps = pipe.scheduler.timesteps

        for i, timestep in enumerate(timesteps):
            if i % 10 == 0:
                print(f" 去噪步骤 {i+1}/{len(timesteps)}")

            timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)

            with torch.no_grad():
                # CFG inference.
                if use_camera_cfg and camera_guidance_scale > 1.0:
                    # Conditional prediction (with camera).
                    noise_pred_cond, moe_loss = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,  # MoE modality input
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

                    # Unconditional prediction (zeroed camera).
                    noise_pred_uncond, moe_loss = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding_uncond_batch,
                        modality_inputs=modality_inputs_uncond,  # MoE unconditional modality input
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **(prompt_emb_neg if prompt_emb_neg else prompt_emb_pos),
                        **extra_input
                    )

                    # Camera CFG.
                    noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_cond - noise_pred_uncond)

                    # If text CFG is enabled as well, apply it on top.
                    if text_guidance_scale > 1.0 and prompt_emb_neg:
                        noise_pred_text_uncond, moe_loss = pipe.dit(
                            new_latents,
                            timestep=timestep_tensor,
                            cam_emb=camera_embedding,
                            modality_inputs=modality_inputs,
                            latent_indices=latent_indices,
                            clean_latents=clean_latents,
                            clean_latent_indices=clean_latent_indices,
                            clean_latents_2x=clean_latents_2x,
                            clean_latent_2x_indices=clean_latent_2x_indices,
                            clean_latents_4x=clean_latents_4x,
                            clean_latent_4x_indices=clean_latent_4x_indices,
                            **prompt_emb_neg,
                            **extra_input
                        )

                        # Apply text CFG to the camera-CFG-combined result.
                        noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond)

                elif text_guidance_scale > 1.0 and prompt_emb_neg:
                    # Text CFG only.
                    noise_pred_cond, moe_loss = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

                    noise_pred_uncond, moe_loss = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_neg,
                        **extra_input
                    )

                    noise_pred = noise_pred_uncond + text_guidance_scale * (noise_pred_cond - noise_pred_uncond)

                else:
                    # Standard inference (no CFG).
                    noise_pred, moe_loss = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,  # MoE modality input
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

            new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents)

        # Append the freshly denoised chunk to the history.
        new_latents_squeezed = new_latents.squeeze(0)
        history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1)

        # Maintain the sliding window: keep frame 0 as anchor + most recent frames.
        if history_latents.shape[1] > max_history_frames:
            first_frame = history_latents[:, 0:1, :, :]
            recent_frames = history_latents[:, -(max_history_frames-1):, :, :]
            history_latents = torch.cat([first_frame, recent_frames], dim=1)
            print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧")

        print(f"更新后history_latents shape: {history_latents.shape}")

        all_generated_frames.append(new_latents_squeezed)
        total_generated += current_generation

        print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧")

    # 12. Decode and save.
    print("\n🔧 解码生成的视频...")

    all_generated = torch.cat(all_generated_frames, dim=1)
    final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0)

    print(f"最终视频shape: {final_video.shape}")

    decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))

    print(f"Saving video to {output_path}")

    # [C, T, H, W] -> [T, H, W, C], map [-1, 1] to uint8 [0, 255].
    video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
    video_np = (video_np * 0.5 + 0.5).clip(0, 1)
    video_np = (video_np * 255).astype(np.uint8)

    with imageio.get_writer(output_path, fps=20) as writer:
        for frame in video_np:
            writer.append_data(frame)

    print(f"🔧 MoE FramePack滑动窗口生成完成! 保存到: {output_path}")
    # One latent frame corresponds to 4 original video frames (VAE temporal compression).
    print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧")
    print(f"使用模态: {modality_type}")
924
+
925
+
926
def str2bool(value):
    """Parse a CLI boolean such as 'true'/'false', '1'/'0', 'yes'/'no'.

    Without an explicit type, argparse stores the raw string, so
    ``--use_camera_cfg False`` would yield the truthy string "False".

    Args:
        value: raw CLI token (or an actual bool, returned unchanged).

    Returns:
        The parsed bool.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("true", "1", "yes", "y"):
        return True
    if lowered in ("false", "0", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"invalid boolean value: {value!r}")


def main():
    """Parse CLI arguments and run the MoE FramePack sliding-window inference."""
    parser = argparse.ArgumentParser(description="MoE FramePack滑动窗口视频生成 - 支持多模态")

    # Basic parameters
    parser.add_argument("--condition_pth", type=str,
                        #default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking/00100100001_0004650_0004950/encoded_video.pth")
                        #default="/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic/scenes/scene-0001_CAM_FRONT/encoded_video-480p.pth")
                        default="/share_zhuyixuan05/zhuyixuan05/spatialvid/a9a6d37f-0a6c-548a-a494-7d902469f3f2_0000000_0000300/encoded_video.pth")
                        #default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded/episode_000001/encoded_video.pth")
    parser.add_argument("--start_frame", type=int, default=0)
    parser.add_argument("--initial_condition_frames", type=int, default=16)
    parser.add_argument("--frames_per_generation", type=int, default=8)
    parser.add_argument("--total_frames_to_generate", type=int, default=8)
    parser.add_argument("--max_history_frames", type=int, default=100)
    parser.add_argument("--use_real_poses", action="store_true", default=False)
    parser.add_argument("--dit_path", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe_spatialvid/step250_moe.ckpt")
    parser.add_argument("--output_path", type=str,
                        default='/home/zhuyixuan05/ReCamMaster/moe/infer_results/output_moe_framepack_sliding.mp4')
    parser.add_argument("--prompt", type=str,
                        default="A man enter the room")
    parser.add_argument("--device", type=str, default="cuda")

    # Modality parameters
    parser.add_argument("--modality_type", type=str, choices=["sekai", "nuscenes", "openx"], default="sekai",
                        help="模态类型:sekai 或 nuscenes 或 openx")
    parser.add_argument("--scene_info_path", type=str, default=None,
                        help="NuScenes场景信息文件路径(仅用于nuscenes模态)")

    # CFG parameters
    # FIX: was `default=True` with no type, so any provided value (even
    # "False") parsed to a truthy string and camera CFG could not be disabled.
    parser.add_argument("--use_camera_cfg", type=str2bool, default=True,
                        help="使用Camera CFG")
    parser.add_argument("--camera_guidance_scale", type=float, default=2.0,
                        help="Camera guidance scale for CFG")
    parser.add_argument("--text_guidance_scale", type=float, default=1.0,
                        help="Text guidance scale for CFG")

    # MoE parameters
    parser.add_argument("--moe_num_experts", type=int, default=1, help="专家数量")
    parser.add_argument("--moe_top_k", type=int, default=1, help="Top-K专家")
    parser.add_argument("--moe_hidden_dim", type=int, default=None, help="MoE隐藏层维度")

    args = parser.parse_args()

    print(f"🔧 MoE FramePack CFG生成设置:")
    print(f"模态类型: {args.modality_type}")
    print(f"Camera CFG: {args.use_camera_cfg}")
    if args.use_camera_cfg:
        print(f"Camera guidance scale: {args.camera_guidance_scale}")
    print(f"Text guidance scale: {args.text_guidance_scale}")
    print(f"MoE配置: experts={args.moe_num_experts}, top_k={args.moe_top_k}")

    # Validate NuScenes parameters.
    if args.modality_type == "nuscenes" and not args.scene_info_path:
        print("⚠️ 使用NuScenes模态但未提供scene_info_path,将使用合成pose数据")

    inference_moe_framepack_sliding_window(
        condition_pth_path=args.condition_pth,
        dit_path=args.dit_path,
        output_path=args.output_path,
        start_frame=args.start_frame,
        initial_condition_frames=args.initial_condition_frames,
        frames_per_generation=args.frames_per_generation,
        total_frames_to_generate=args.total_frames_to_generate,
        max_history_frames=args.max_history_frames,
        device=args.device,
        prompt=args.prompt,
        modality_type=args.modality_type,
        use_real_poses=args.use_real_poses,
        scene_info_path=args.scene_info_path,
        # CFG parameters
        use_camera_cfg=args.use_camera_cfg,
        camera_guidance_scale=args.camera_guidance_scale,
        text_guidance_scale=args.text_guidance_scale,
        # MoE parameters
        moe_num_experts=args.moe_num_experts,
        moe_top_k=args.moe_top_k,
        moe_hidden_dim=args.moe_hidden_dim
    )
1005
+
1006
+
1007
# Script entry point.
if __name__ == "__main__":
    main()
scripts/infer_moe_test.py ADDED
@@ -0,0 +1,976 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ from PIL import Image
6
+ import imageio
7
+ import json
8
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
9
+ import argparse
10
+ from torchvision.transforms import v2
11
+ from einops import rearrange
12
+ import copy
13
+
14
+
15
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Load a window of pre-encoded video latents from a .pth file.

    Args:
        pth_path: path to a torch-saved dict that contains at least the key
            'latents' with a [C, T, H, W] tensor.
        start_frame: first latent frame of the window.
        num_frames: number of consecutive latent frames to extract.

    Returns:
        (condition_latents, encoded_data): the [C, num_frames, H, W] slice
        and the full loaded dict (which may carry camera data etc.).

    Raises:
        ValueError: if the requested window exceeds the available frames.
    """
    print(f"Loading encoded video from {pth_path}")

    # weights_only=False because the file stores a plain dict with extra metadata.
    encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu")
    full_latents = encoded_data['latents']  # [C, T, H, W]

    print(f"Full latents shape: {full_latents.shape}")
    print(f"Extracting frames {start_frame} to {start_frame + num_frames}")

    end_frame = start_frame + num_frames
    if end_frame > full_latents.shape[1]:
        raise ValueError(f"Not enough frames: requested {start_frame + num_frames}, available {full_latents.shape[1]}")

    condition_latents = full_latents[:, start_frame:end_frame, :, :]
    print(f"Extracted condition latents shape: {condition_latents.shape}")

    return condition_latents, encoded_data
32
+
33
+
34
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Compute camera B's pose relative to camera A.

    Both inputs are 4x4 extrinsic matrices; the result is
    ``pose_b @ inverse(pose_a)``.

    Args:
        pose_a: 4x4 extrinsic matrix of camera A (ndarray or tensor).
        pose_b: 4x4 extrinsic matrix of camera B (ndarray or tensor).
        use_torch: compute with torch (float32 tensors) instead of numpy.

    Returns:
        4x4 relative pose — torch.Tensor if use_torch else np.ndarray.
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        # Promote numpy inputs to float32 tensors before inverting.
        if not isinstance(pose_a, torch.Tensor):
            pose_a = torch.from_numpy(pose_a).float()
        if not isinstance(pose_b, torch.Tensor):
            pose_b = torch.from_numpy(pose_b).float()
        return torch.matmul(pose_b, torch.inverse(pose_a))

    # numpy path: coerce non-ndarray inputs to float32 arrays.
    if not isinstance(pose_a, np.ndarray):
        pose_a = np.array(pose_a, dtype=np.float32)
    if not isinstance(pose_b, np.ndarray):
        pose_b = np.array(pose_b, dtype=np.float32)
    return np.matmul(pose_b, np.linalg.inv(pose_a))
57
+
58
+
59
def replace_dit_model_in_manager():
    """Swap the registered 'wan_video_dit' loader class for the MoE variant.

    Mutates diffsynth's global ``model_loader_configs`` in place so that any
    subsequent ModelManager.load_models call instantiates ``WanModelMoe``
    instead of the stock DiT class for the matching weights.
    """
    from diffsynth.models.wan_video_dit_moe import WanModelMoe
    from diffsynth.configs.model_config import model_loader_configs

    for idx, entry in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource = entry

        if 'wan_video_dit' not in model_names:
            continue

        patched_names = []
        patched_classes = []
        for name, cls in zip(model_names, model_classes):
            patched_names.append(name)
            if name == 'wan_video_dit':
                patched_classes.append(WanModelMoe)
                print(f"✅ 替换了模型类: {name} -> WanModelMoe")
            else:
                patched_classes.append(cls)

        model_loader_configs[idx] = (
            keys_hash, keys_hash_with_shape, patched_names, patched_classes, model_resource
        )
81
+
82
+
83
def add_framepack_components(dit_model):
    """Attach FramePack's multi-scale clean-latent embedder to the DiT.

    Adds a ``clean_x_embedder`` module (if absent) that holds three Conv3d
    patchify projections for clean context latents at 1x/2x/4x compression,
    cast to the model's parameter dtype. No-op when already present.
    """
    if hasattr(dit_model, 'clean_x_embedder'):
        return

    inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]

    class CleanXEmbedder(nn.Module):
        def __init__(self, inner_dim):
            super().__init__()
            # Attribute names must stay proj/proj_2x/proj_4x to match checkpoint keys.
            self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
            self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
            self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

        def forward(self, x, scale="1x"):
            # Dispatch on the compression scale; cast input to the conv's dtype.
            conv = {"1x": self.proj, "2x": self.proj_2x, "4x": self.proj_4x}.get(scale)
            if conv is None:
                raise ValueError(f"Unsupported scale: {scale}")
            return conv(x.to(conv.weight.dtype))

    embedder = CleanXEmbedder(inner_dim)
    model_dtype = next(dit_model.parameters()).dtype
    dit_model.clean_x_embedder = embedder.to(dtype=model_dtype)
    print("✅ 添加了FramePack的clean_x_embedder组件")
112
+
113
+
114
def add_moe_components(dit_model, moe_config):
    """Attach per-block multi-modal MoE adapters to the DiT.

    Stores ``moe_config`` on the model (once), then gives every transformer
    block a Sekai modality processor (13 -> unified_dim) and a MultiModalMoE
    head (unified_dim -> block hidden dim).
    """
    from diffsynth.models.wan_video_dit_moe import ModalityProcessor, MultiModalMoE

    if not hasattr(dit_model, 'moe_config'):
        dit_model.moe_config = moe_config
        print("✅ 添加了MoE配置到模型")

    # Dynamically attach the MoE components to each block.
    dim = dit_model.blocks[0].self_attn.q.weight.shape[0]
    unified_dim = moe_config.get("unified_dim", 25)

    for i, block in enumerate(dit_model.blocks):
        # Sekai modality processor: 12-dim relative pose + 1-dim mask -> unified_dim.
        block.sekai_processor = ModalityProcessor("sekai", 13, unified_dim)

        # NOTE: a NuScenes processor (8 -> unified_dim) existed here but is
        # currently disabled.

        # MoE head: unified_dim in, transformer-block hidden dim out.
        block.moe = MultiModalMoE(
            unified_dim=unified_dim,
            output_dim=dim,
            num_experts=moe_config.get("num_experts", 4),
            top_k=moe_config.get("top_k", 2)
        )

        print(f"✅ Block {i} 添加了MoE组件 (unified_dim: {unified_dim}, experts: {moe_config.get('num_experts', 4)})")
142
+
143
+
144
def generate_sekai_camera_embeddings_sliding(cam_data, start_frame, current_history_length, new_frames, total_generated, use_real_poses=True):
    """Build the per-latent-frame camera embedding for the Sekai modality.

    Each row is [12-dim flattened 3x4 relative pose | 1-dim condition mask]
    where the pose is the motion between consecutive latent frames (4x
    temporal compression) and mask=1 marks condition/history frames.

    Args:
        cam_data: dict with key 'extrinsic' (list of 4x4 matrices) or None.
        start_frame: first condition frame index in the embedding.
        current_history_length: number of condition/history latent frames.
        new_frames: latent frames to be generated this step.
        total_generated: unused here; kept for signature compatibility.
        use_real_poses: use recorded extrinsics when available, else a
            synthetic "walk forward while turning left" trajectory.

    Returns:
        bfloat16 tensor of shape [max_needed_frames, 13].
    """
    time_compression_ratio = 4

    # FramePack index layout: start(1) + 4x(16) + 2x(2) + 1x(1) + targets.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames
    # The embedding must cover the base demand, the FramePack layout, and a
    # floor of 30 frames.
    max_needed_frames = max(
        start_frame + current_history_length + new_frames,
        framepack_needed_frames,
        30
    )

    has_real = use_real_poses and cam_data is not None and 'extrinsic' in cam_data
    relative_poses = []

    if has_real:
        print("🔧 使用真实Sekai camera数据")
        cam_extrinsic = cam_data['extrinsic']

        print(f"🔧 计算Sekai camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        for i in range(max_needed_frames):
            # Map latent frame i to original-video frame indices.
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio

            if next_frame_idx < len(cam_extrinsic):
                rel = compute_relative_pose(cam_extrinsic[frame_idx], cam_extrinsic[next_frame_idx])
                relative_poses.append(torch.as_tensor(rel[:3, :]))
            else:
                # Past the recorded trajectory: fall back to zero motion.
                print(f"⚠️ 帧{frame_idx}超出camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))
    else:
        print("🔧 使用Sekai合成camera数据")
        print(f"🔧 生成Sekai合成camera帧数: {max_needed_frames}")

        # Synthetic trajectory: constant left turn + slow forward motion.
        yaw_per_frame = 0.05      # left turn per frame (positive = left)
        forward_speed = 0.005     # forward distance per frame
        radius_drift = 0.002      # slight drift toward the circle center

        for _ in range(max_needed_frames):
            pose = np.eye(4, dtype=np.float32)

            # Rotation about the Y axis (left turn).
            cos_yaw = np.cos(yaw_per_frame)
            sin_yaw = np.sin(yaw_per_frame)
            pose[0, 0] = cos_yaw
            pose[0, 2] = sin_yaw
            pose[2, 0] = -sin_yaw
            pose[2, 2] = cos_yaw

            # Translation: forward along local -Z, slight drift along local -X.
            pose[2, 3] = -forward_speed
            pose[0, 3] = -radius_drift

            relative_poses.append(torch.as_tensor(pose[:3, :]))

    # [T, 3, 4] -> [T, 12] (same as rearrange 'b c d -> b (c d)').
    pose_embedding = torch.stack(relative_poses, dim=0)
    pose_embedding = torch.flatten(pose_embedding, start_dim=1)

    # Condition mask: 1 for history frames, 0 for frames to generate.
    mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
    condition_end = min(start_frame + current_history_length, max_needed_frames)
    mask[start_frame:condition_end] = 1.0

    camera_embedding = torch.cat([pose_embedding, mask], dim=1)
    if has_real:
        print(f"🔧 Sekai真实camera embedding shape: {camera_embedding.shape}")
    else:
        print(f"🔧 Sekai合成camera embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
244
+
245
+ def generate_openx_camera_embeddings_sliding(encoded_data, start_frame, current_history_length, new_frames, use_real_poses):
246
+ """为OpenX数据集生成camera embeddings - 滑动窗口版本"""
247
+ time_compression_ratio = 4
248
+
249
+ # 计算FramePack实际需要的camera帧数
250
+ framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames
251
+
252
+ if use_real_poses and encoded_data is not None and 'cam_emb' in encoded_data and 'extrinsic' in encoded_data['cam_emb']:
253
+ print("🔧 使用OpenX真实camera数据")
254
+ cam_extrinsic = encoded_data['cam_emb']['extrinsic']
255
+
256
+ # 确保生成足够长的camera序列
257
+ max_needed_frames = max(
258
+ start_frame + current_history_length + new_frames,
259
+ framepack_needed_frames,
260
+ 30
261
+ )
262
+
263
+ print(f"🔧 计算OpenX camera序列长度:")
264
+ print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
265
+ print(f" - FramePack需求: {framepack_needed_frames}")
266
+ print(f" - 最终生成: {max_needed_frames}")
267
+
268
+ relative_poses = []
269
+ for i in range(max_needed_frames):
270
+ # OpenX使用4倍间隔,类似sekai但处理更短的序列
271
+ frame_idx = i * time_compression_ratio
272
+ next_frame_idx = frame_idx + time_compression_ratio
273
+
274
+ if next_frame_idx < len(cam_extrinsic):
275
+ cam_prev = cam_extrinsic[frame_idx]
276
+ cam_next = cam_extrinsic[next_frame_idx]
277
+ relative_pose = compute_relative_pose(cam_prev, cam_next)
278
+ relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
279
+ else:
280
+ # 超出范围,使用零运动
281
+ print(f"⚠️ 帧{frame_idx}超出OpenX camera数据范围,使用零运动")
282
+ relative_poses.append(torch.zeros(3, 4))
283
+
284
+ pose_embedding = torch.stack(relative_poses, dim=0)
285
+ pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
286
+
287
+ # 创建对应长度的mask序列
288
+ mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
289
+ # 从start_frame到current_history_length标记为condition
290
+ condition_end = min(start_frame + current_history_length, max_needed_frames)
291
+ mask[start_frame:condition_end] = 1.0
292
+
293
+ camera_embedding = torch.cat([pose_embedding, mask], dim=1)
294
+ print(f"🔧 OpenX真实camera embedding shape: {camera_embedding.shape}")
295
+ return camera_embedding.to(torch.bfloat16)
296
+
297
+ else:
298
+ print("🔧 使用OpenX合成camera数据")
299
+
300
+ max_needed_frames = max(
301
+ start_frame + current_history_length + new_frames,
302
+ framepack_needed_frames,
303
+ 30
304
+ )
305
+
306
+ print(f"🔧 生成OpenX合成camera帧数: {max_needed_frames}")
307
+ relative_poses = []
308
+ for i in range(max_needed_frames):
309
+ # OpenX机器人操作运动模式 - 较小的运动幅度
310
+ # 模拟机器人手臂的精细操作运动
311
+ roll_per_frame = 0.02 # 轻微翻滚
312
+ pitch_per_frame = 0.01 # 轻微俯仰
313
+ yaw_per_frame = 0.015 # 轻微偏航
314
+ forward_speed = 0.003 # 较慢的前进速度
315
+
316
+ pose = np.eye(4, dtype=np.float32)
317
+
318
+ # 复合旋转 - 模拟机器人手臂的复杂运动
319
+ # 绕X轴旋转(roll)
320
+ cos_roll = np.cos(roll_per_frame)
321
+ sin_roll = np.sin(roll_per_frame)
322
+ # 绕Y轴旋转(pitch)
323
+ cos_pitch = np.cos(pitch_per_frame)
324
+ sin_pitch = np.sin(pitch_per_frame)
325
+ # 绕Z轴旋转(yaw)
326
+ cos_yaw = np.cos(yaw_per_frame)
327
+ sin_yaw = np.sin(yaw_per_frame)
328
+
329
+ # 简化的复合旋转矩阵(ZYX顺序)
330
+ pose[0, 0] = cos_yaw * cos_pitch
331
+ pose[0, 1] = cos_yaw * sin_pitch * sin_roll - sin_yaw * cos_roll
332
+ pose[0, 2] = cos_yaw * sin_pitch * cos_roll + sin_yaw * sin_roll
333
+ pose[1, 0] = sin_yaw * cos_pitch
334
+ pose[1, 1] = sin_yaw * sin_pitch * sin_roll + cos_yaw * cos_roll
335
+ pose[1, 2] = sin_yaw * sin_pitch * cos_roll - cos_yaw * sin_roll
336
+ pose[2, 0] = -sin_pitch
337
+ pose[2, 1] = cos_pitch * sin_roll
338
+ pose[2, 2] = cos_pitch * cos_roll
339
+
340
+ # 平移 - 模拟机器人操作的精细移动
341
+ pose[0, 3] = forward_speed * 0.5 # X方向轻微移动
342
+ pose[1, 3] = forward_speed * 0.3 # Y方向轻微移动
343
+ pose[2, 3] = -forward_speed # Z方向(深度)主要移动
344
+
345
+ relative_pose = pose[:3, :]
346
+ relative_poses.append(torch.as_tensor(relative_pose))
347
+
348
+ pose_embedding = torch.stack(relative_poses, dim=0)
349
+ pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
350
+
351
+ # 创建对应长度的mask序列
352
+ mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
353
+ condition_end = min(start_frame + current_history_length, max_needed_frames)
354
+ mask[start_frame:condition_end] = 1.0
355
+
356
+ camera_embedding = torch.cat([pose_embedding, mask], dim=1)
357
+ print(f"🔧 OpenX合成camera embedding shape: {camera_embedding.shape}")
358
+ return camera_embedding.to(torch.bfloat16)
359
+
360
+
361
def generate_nuscenes_camera_embeddings_sliding(scene_info, start_frame, current_history_length, new_frames):
    """Build per-frame camera embeddings for the NuScenes modality (sliding-window).

    Returns a [T, 8] bfloat16 tensor per frame: 3-D translation + 4-D quaternion
    + 1-D condition mask. Layout kept consistent with train_moe.py.
    """
    time_compression_ratio = 4  # kept for parity with the other modalities

    # Camera frames consumed by the FramePack index layout:
    # 1 start + 16 (4x) + 2 (2x) + 1 (1x) + the newly generated frames.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    def _condition_mask(length):
        # Frames [start_frame, start_frame + current_history_length) are conditions.
        mask = torch.zeros(length, 1, dtype=torch.float32)
        mask[start_frame:min(start_frame + current_history_length, length)] = 1.0
        return mask

    have_real_poses = scene_info is not None and 'keyframe_poses' in scene_info

    if not have_real_poses:
        print("🔧 使用NuScenes合成pose数据")
        total_frames = max(framepack_needed_frames, 30)

        # Synthetic trajectory: constant forward motion along Z, identity rotation.
        identity_quat = torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32)
        pose_rows = []
        for idx in range(total_frames):
            shift = torch.tensor([0.0, 0.0, idx * 0.1], dtype=torch.float32)
            pose_rows.append(torch.cat([shift, identity_quat], dim=0))
        pose_sequence = torch.stack(pose_rows, dim=0)

        camera_embedding = torch.cat([pose_sequence, _condition_mask(total_frames)], dim=1)
        print(f"🔧 NuScenes合成pose embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    print("🔧 使用NuScenes真实pose数据")
    keyframe_poses = scene_info['keyframe_poses']
    total_frames = max(framepack_needed_frames, 30)

    if len(keyframe_poses) == 0:
        # No recorded trajectory at all: fall back to all-zero pose vectors.
        print("⚠️ NuScenes keyframe_poses为空,使用零pose")
        pose_sequence = torch.zeros(total_frames, 7, dtype=torch.float32)
        camera_embedding = torch.cat([pose_sequence, _condition_mask(total_frames)], dim=1)
        print(f"🔧 NuScenes零pose embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    reference_pose = keyframe_poses[0]  # first keyframe anchors all translations
    pose_rows = []
    for idx in range(total_frames):
        if idx < len(keyframe_poses):
            current = keyframe_poses[idx]
            # Translation relative to the reference keyframe.
            shift = torch.tensor(
                np.array(current['translation']) - np.array(reference_pose['translation']),
                dtype=torch.float32,
            )
            # Absolute rotation quaternion (simplified: not made relative to reference).
            quat = torch.tensor(current['rotation'], dtype=torch.float32)
            pose_rows.append(torch.cat([shift, quat], dim=0))
        else:
            # Past the recorded trajectory: zero translation, identity rotation.
            pose_rows.append(torch.cat([
                torch.zeros(3, dtype=torch.float32),
                torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32),
            ], dim=0))
    pose_sequence = torch.stack(pose_rows, dim=0)  # [total_frames, 7]

    camera_embedding = torch.cat([pose_sequence, _condition_mask(total_frames)], dim=1)
    print(f"🔧 NuScenes真实pose embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
451
def prepare_framepack_sliding_window_with_camera_moe(history_latents, target_frames_to_generate, camera_embedding_full, start_frame, modality_type, max_history_frames=49):
    """Assemble FramePack sliding-window inputs (MoE variant).

    history_latents: [C, T, H, W] latents accumulated so far.
    Returns a dict with clean latents at 1x/2x/4x granularity, their index
    tensors, the matching camera-embedding slice (mask channel rewritten to
    reflect the current history length), plus bookkeeping lengths.
    """
    C, T, H, W = history_latents.shape

    # Fixed FramePack index layout: 1 start + 16 (4x) + 2 (2x) + 1 (1x) + targets.
    total_len = 1 + 16 + 2 + 1 + target_frames_to_generate
    all_indices = torch.arange(0, total_len)
    start_idx, idx_4x, idx_2x, idx_1x, latent_indices = all_indices.split(
        [1, 16, 2, 1, target_frames_to_generate], dim=0)
    clean_latent_indices = torch.cat([start_idx, idx_1x], dim=0)

    # Zero-pad the camera sequence if it is shorter than the index layout needs.
    if camera_embedding_full.shape[0] < total_len:
        deficit = total_len - camera_embedding_full.shape[0]
        pad = torch.zeros(deficit, camera_embedding_full.shape[1],
                          dtype=camera_embedding_full.dtype, device=camera_embedding_full.device)
        camera_embedding_full = torch.cat([camera_embedding_full, pad], dim=0)

    # Slice out the rows backing this window (cloned so the mask rewrite below
    # does not touch the full sequence).
    combined_camera = camera_embedding_full[:total_len, :].clone()

    # Mask channel: everything defaults to target (0) ...
    combined_camera[:, -1] = 0.0
    # ... then the rows backing valid clean latents become condition (1).
    if T > 0:
        usable = min(T, 19)
        pad_front = 19 - usable
        combined_camera[pad_front:19, -1] = 1.0

    print(f"🔧 MoE Camera mask更新:")
    print(f" - 历史帧数: {T}")
    print(f" - 有效condition帧数: {usable if T > 0 else 0}")
    print(f" - 模态类型: {modality_type}")

    # Clean latents: right-align up to the 19 most recent history frames.
    clean_buf = torch.zeros(C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device)
    if T > 0:
        clean_buf[:, 19 - usable:, :, :] = history_latents[:, -usable:, :, :]

    clean_latents_4x = clean_buf[:, 0:16, :, :]
    clean_latents_2x = clean_buf[:, 16:18, :, :]
    clean_latents_1x = clean_buf[:, 18:19, :, :]

    # Very first history frame acts as a global anchor (zeros when no history).
    if T > 0:
        anchor = history_latents[:, 0:1, :, :]
    else:
        anchor = torch.zeros(C, 1, H, W, dtype=history_latents.dtype, device=history_latents.device)
    clean_latents = torch.cat([anchor, clean_latents_1x], dim=1)

    return {
        'latent_indices': latent_indices,
        'clean_latents': clean_latents,
        'clean_latents_2x': clean_latents_2x,
        'clean_latents_4x': clean_latents_4x,
        'clean_latent_indices': clean_latent_indices,
        'clean_latent_2x_indices': idx_2x,
        'clean_latent_4x_indices': idx_4x,
        'camera_embedding': combined_camera,
        'modality_type': modality_type,  # modality tag for the MoE router
        'current_length': T,
        'next_length': T + target_frames_to_generate,
    }
522
def inference_moe_framepack_sliding_window(
    condition_pth_path,
    dit_path,
    output_path="moe/infer_results/output_moe_framepack_sliding.mp4",
    start_frame=0,
    initial_condition_frames=8,
    frames_per_generation=4,
    total_frames_to_generate=32,
    max_history_frames=49,
    device="cuda",
    prompt="A video of a scene shot using a pedestrian's front camera while walking",
    modality_type="sekai",  # "sekai", "nuscenes" or "openx"
    use_real_poses=True,
    scene_info_path=None,  # only used for the NuScenes modality
    # CFG parameters
    use_camera_cfg=True,
    camera_guidance_scale=2.0,
    text_guidance_scale=1.0,
    # MoE parameters
    moe_num_experts=4,
    moe_top_k=2,
    moe_hidden_dim=None
):
    """
    MoE FramePack sliding-window video generation with multi-modality support.

    Loads the Wan2.1 pipeline, augments its DiT with camera encoders,
    FramePack and MoE components, then autoregressively generates
    `total_frames_to_generate` latent frames in chunks of
    `frames_per_generation`, optionally applying camera CFG and/or text CFG
    at every denoising step. The decoded video is written to `output_path`.

    All frame counts here are in latent (temporally compressed) frames; the
    final log line reports the x4 original-frame equivalent.
    """
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    print(f"🔧 MoE FramePack滑动窗口生成开始...")
    print(f"模态类型: {modality_type}")
    print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}")
    print(f"Text guidance scale: {text_guidance_scale}")
    print(f"MoE配置: experts={moe_num_experts}, top_k={moe_top_k}")

    # 1. Model initialization (swaps the DiT class before loading weights)
    replace_dit_model_in_manager()

    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # 2. Legacy per-block camera encoder (kept for checkpoint compatibility).
    #    cam_encoder starts at zero, projector at identity, so the untrained
    #    modules are a no-op until checkpoint weights overwrite them.
    dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        block.cam_encoder = nn.Linear(13, dim)
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    # 3. FramePack components (clean-latent embedders etc.)
    add_framepack_components(pipe.dit)

    # 4. MoE components: one input dim per modality (pose dims + 1 mask dim)
    moe_config = {
        "num_experts": moe_num_experts,
        "top_k": moe_top_k,
        "hidden_dim": moe_hidden_dim or dim * 2,
        "sekai_input_dim": 13,  # Sekai: 12-D pose + 1-D mask
        "nuscenes_input_dim": 8,  # NuScenes: 7-D pose + 1-D mask
        "openx_input_dim": 13  # OpenX: 12-D pose + 1-D mask (sekai-like)
    }
    add_moe_components(pipe.dit, moe_config)

    # 5. Load the fine-tuned weights; strict=False tolerates the newly added
    #    MoE components missing from older checkpoints.
    dit_state_dict = torch.load(dit_path, map_location="cpu")
    pipe.dit.load_state_dict(dit_state_dict, strict=False)
    pipe = pipe.to(device)
    model_dtype = next(pipe.dit.parameters()).dtype

    if hasattr(pipe.dit, 'clean_x_embedder'):
        pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)

    pipe.scheduler.set_timesteps(50)

    # 6. Load initial condition latents from the pre-encoded .pth file
    print("Loading initial condition frames...")
    initial_latents, encoded_data = load_encoded_video_from_pth(
        condition_pth_path,
        start_frame=start_frame,
        num_frames=initial_condition_frames
    )

    # Center-crop spatially to the model's latent resolution (60 x 104)
    target_height, target_width = 60, 104
    C, T, H, W = initial_latents.shape

    if H > target_height or W > target_width:
        h_start = (H - target_height) // 2
        w_start = (W - target_width) // 2
        initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width]
        H, W = target_height, target_width

    history_latents = initial_latents.to(device, dtype=model_dtype)

    print(f"初始history_latents shape: {history_latents.shape}")

    # 7. Encode the text prompt; also encode the empty prompt when text CFG is on
    if text_guidance_scale > 1.0:
        prompt_emb_pos = pipe.encode_prompt(prompt)
        prompt_emb_neg = pipe.encode_prompt("")
        print(f"使用Text CFG,guidance scale: {text_guidance_scale}")
    else:
        prompt_emb_pos = pipe.encode_prompt(prompt)
        prompt_emb_neg = None
        print("不使用Text CFG")

    # 8. Optional NuScenes scene info (real keyframe poses)
    scene_info = None
    if modality_type == "nuscenes" and scene_info_path and os.path.exists(scene_info_path):
        with open(scene_info_path, 'r') as f:
            scene_info = json.load(f)
        print(f"加载NuScenes场景信息: {scene_info_path}")

    # 9. Pre-generate the full camera embedding sequence for the chosen modality.
    #    Note: generated once with current_history_length=max_history_frames;
    #    the per-window mask is rewritten later inside
    #    prepare_framepack_sliding_window_with_camera_moe.
    if modality_type == "sekai":
        camera_embedding_full = generate_sekai_camera_embeddings_sliding(
            encoded_data.get('cam_emb', None),
            0,
            max_history_frames,
            0,
            0,
            use_real_poses=use_real_poses
        ).to(device, dtype=model_dtype)
    elif modality_type == "nuscenes":
        camera_embedding_full = generate_nuscenes_camera_embeddings_sliding(
            scene_info,
            0,
            max_history_frames,
            0
        ).to(device, dtype=model_dtype)
    elif modality_type == "openx":
        camera_embedding_full = generate_openx_camera_embeddings_sliding(
            encoded_data,
            0,
            max_history_frames,
            0,
            use_real_poses=use_real_poses
        ).to(device, dtype=model_dtype)
    else:
        raise ValueError(f"不支持的模态类型: {modality_type}")

    print(f"完整camera序列shape: {camera_embedding_full.shape}")

    # 10. All-zero camera embedding serves as the unconditional input for camera CFG
    if use_camera_cfg:
        camera_embedding_uncond = torch.zeros_like(camera_embedding_full)
        print(f"创建无条件camera embedding用于CFG")

    # 11. Sliding-window autoregressive generation loop
    total_generated = 0
    all_generated_frames = []

    while total_generated < total_frames_to_generate:
        # Last chunk may be smaller than frames_per_generation
        current_generation = min(frames_per_generation, total_frames_to_generate - total_generated)
        print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}")
        print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}")

        # FramePack data preparation (MoE variant): slices clean latents and
        # the camera window, and rewrites the condition mask for this step
        framepack_data = prepare_framepack_sliding_window_with_camera_moe(
            history_latents,
            current_generation,
            camera_embedding_full,
            start_frame,
            modality_type,
            max_history_frames
        )

        # Add the batch dimension expected by the DiT
        clean_latents = framepack_data['clean_latents'].unsqueeze(0)
        clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0)
        clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0)
        camera_embedding = framepack_data['camera_embedding'].unsqueeze(0)

        # MoE routing input keyed by modality name
        modality_inputs = {modality_type: camera_embedding}

        # Unconditional camera input for CFG, trimmed to the window length
        if use_camera_cfg:
            camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0)
            modality_inputs_uncond = {modality_type: camera_embedding_uncond_batch}

        # Index tensors are kept on CPU (consumed as indices, not activations)
        latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu()
        clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu()
        clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu()
        clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu()

        # Start from pure noise for the frames to be generated
        new_latents = torch.randn(
            1, C, current_generation, H, W,
            device=device, dtype=model_dtype
        )

        extra_input = pipe.prepare_extra_input(new_latents)

        print(f"Camera embedding shape: {camera_embedding.shape}")
        print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}")

        # Denoising loop with optional CFG
        timesteps = pipe.scheduler.timesteps

        for i, timestep in enumerate(timesteps):
            if i % 10 == 0:
                print(f" 去噪步骤 {i+1}/{len(timesteps)}")

            timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)

            with torch.no_grad():
                # CFG inference: camera CFG first, optionally composed with text CFG
                if use_camera_cfg and camera_guidance_scale > 1.0:
                    # Conditional prediction (with camera)
                    noise_pred_cond, moe_loss = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,  # MoE modality input
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

                    # Unconditional prediction (zeroed camera)
                    noise_pred_uncond, moe_loss = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding_uncond_batch,
                        modality_inputs=modality_inputs_uncond,  # MoE unconditional modality input
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **(prompt_emb_neg if prompt_emb_neg else prompt_emb_pos),
                        **extra_input
                    )

                    # Camera CFG combination
                    noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_cond - noise_pred_uncond)

                    # Optionally compose text CFG on top of the camera-CFG result
                    if text_guidance_scale > 1.0 and prompt_emb_neg:
                        noise_pred_text_uncond, moe_loss = pipe.dit(
                            new_latents,
                            timestep=timestep_tensor,
                            cam_emb=camera_embedding,
                            modality_inputs=modality_inputs,
                            latent_indices=latent_indices,
                            clean_latents=clean_latents,
                            clean_latent_indices=clean_latent_indices,
                            clean_latents_2x=clean_latents_2x,
                            clean_latent_2x_indices=clean_latent_2x_indices,
                            clean_latents_4x=clean_latents_4x,
                            clean_latent_4x_indices=clean_latent_4x_indices,
                            **prompt_emb_neg,
                            **extra_input
                        )

                        # Apply text CFG to the camera-CFG-combined prediction
                        noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond)

                elif text_guidance_scale > 1.0 and prompt_emb_neg:
                    # Text CFG only
                    noise_pred_cond, moe_loss = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

                    noise_pred_uncond, moe_loss = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_neg,
                        **extra_input
                    )

                    noise_pred = noise_pred_uncond + text_guidance_scale * (noise_pred_cond - noise_pred_uncond)

                else:
                    # Plain inference (no CFG)
                    noise_pred, moe_loss = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        modality_inputs=modality_inputs,  # MoE modality input
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

            new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents)

        # Append the freshly denoised frames to the history
        new_latents_squeezed = new_latents.squeeze(0)
        history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1)

        # Maintain the sliding window: keep the very first frame (global
        # anchor) plus the most recent max_history_frames-1 frames
        if history_latents.shape[1] > max_history_frames:
            first_frame = history_latents[:, 0:1, :, :]
            recent_frames = history_latents[:, -(max_history_frames-1):, :, :]
            history_latents = torch.cat([first_frame, recent_frames], dim=1)
            print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧")

        print(f"更新后history_latents shape: {history_latents.shape}")

        all_generated_frames.append(new_latents_squeezed)
        total_generated += current_generation

        print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧")

    # 12. Decode latents and write the output video
    print("\n🔧 解码生成的视频...")

    all_generated = torch.cat(all_generated_frames, dim=1)
    # Prepend the original condition latents so the saved video starts from them
    final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0)

    print(f"最终视频shape: {final_video.shape}")

    decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))

    print(f"Saving video to {output_path}")

    # Map decoded output from [-1, 1] to uint8 [0, 255]
    video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
    video_np = (video_np * 0.5 + 0.5).clip(0, 1)
    video_np = (video_np * 255).astype(np.uint8)

    with imageio.get_writer(output_path, fps=20) as writer:
        for frame in video_np:
            writer.append_data(frame)

    print(f"🔧 MoE FramePack滑动窗口生成完成! 保存到: {output_path}")
    print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧")
    print(f"使用模态: {modality_type}")
894
def main():
    """CLI entry point: parse arguments and launch MoE FramePack sliding-window inference."""
    parser = argparse.ArgumentParser(description="MoE FramePack滑动窗口视频生成 - 支持多模态")

    # Basic parameters
    parser.add_argument("--condition_pth", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking/00100100001_0004650_0004950/encoded_video.pth")
    # Sample inputs for the other modalities:
    # default="/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic/scenes/scene-0001_CAM_FRONT/encoded_video-480p.pth")
    # default="/share_zhuyixuan05/zhuyixuan05/spatialvid/a9a6d37f-0a6c-548a-a494-7d902469f3f2_0000000_0000300/encoded_video.pth")
    # default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded/episode_000001/encoded_video.pth")
    parser.add_argument("--start_frame", type=int, default=0)
    parser.add_argument("--initial_condition_frames", type=int, default=16)
    parser.add_argument("--frames_per_generation", type=int, default=8)
    parser.add_argument("--total_frames_to_generate", type=int, default=40)
    parser.add_argument("--max_history_frames", type=int, default=100)
    parser.add_argument("--use_real_poses", action="store_true", default=False)
    parser.add_argument("--dit_path", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe_test/step1000_moe.ckpt")
    parser.add_argument("--output_path", type=str,
                        default='/home/zhuyixuan05/ReCamMaster/moe/infer_results/output_moe_framepack_sliding.mp4')
    parser.add_argument("--prompt", type=str,
                        default="A drone flying scene in a game world")
    parser.add_argument("--device", type=str, default="cuda")

    # Modality selection
    parser.add_argument("--modality_type", type=str, choices=["sekai", "nuscenes", "openx"], default="sekai",
                        help="模态类型:sekai 或 nuscenes 或 openx")
    parser.add_argument("--scene_info_path", type=str, default=None,
                        help="NuScenes场景信息文件路径(仅用于nuscenes模态)")

    # CFG parameters
    # FIX: the original declaration had neither `action` nor `type`, so any
    # value supplied on the command line arrived as a non-empty (always truthy)
    # string and camera CFG could never be disabled. BooleanOptionalAction
    # keeps the True default and adds proper --use_camera_cfg /
    # --no-use_camera_cfg switches (requires Python >= 3.9).
    parser.add_argument("--use_camera_cfg", action=argparse.BooleanOptionalAction, default=True,
                        help="使用Camera CFG")
    parser.add_argument("--camera_guidance_scale", type=float, default=2.0,
                        help="Camera guidance scale for CFG")
    parser.add_argument("--text_guidance_scale", type=float, default=1.0,
                        help="Text guidance scale for CFG")

    # MoE parameters
    parser.add_argument("--moe_num_experts", type=int, default=1, help="专家数量")
    parser.add_argument("--moe_top_k", type=int, default=1, help="Top-K专家")
    parser.add_argument("--moe_hidden_dim", type=int, default=None, help="MoE隐藏层维度")

    args = parser.parse_args()

    print(f"🔧 MoE FramePack CFG生成设置:")
    print(f"模态类型: {args.modality_type}")
    print(f"Camera CFG: {args.use_camera_cfg}")
    if args.use_camera_cfg:
        print(f"Camera guidance scale: {args.camera_guidance_scale}")
    print(f"Text guidance scale: {args.text_guidance_scale}")
    print(f"MoE配置: experts={args.moe_num_experts}, top_k={args.moe_top_k}")

    # NuScenes needs a scene-info file for real poses; warn when it is absent.
    if args.modality_type == "nuscenes" and not args.scene_info_path:
        print("⚠️ 使用NuScenes模态但未提供scene_info_path,将使用合成pose数据")

    inference_moe_framepack_sliding_window(
        condition_pth_path=args.condition_pth,
        dit_path=args.dit_path,
        output_path=args.output_path,
        start_frame=args.start_frame,
        initial_condition_frames=args.initial_condition_frames,
        frames_per_generation=args.frames_per_generation,
        total_frames_to_generate=args.total_frames_to_generate,
        max_history_frames=args.max_history_frames,
        device=args.device,
        prompt=args.prompt,
        modality_type=args.modality_type,
        use_real_poses=args.use_real_poses,
        scene_info_path=args.scene_info_path,
        # CFG parameters
        use_camera_cfg=args.use_camera_cfg,
        camera_guidance_scale=args.camera_guidance_scale,
        text_guidance_scale=args.text_guidance_scale,
        # MoE parameters
        moe_num_experts=args.moe_num_experts,
        moe_top_k=args.moe_top_k,
        moe_hidden_dim=args.moe_hidden_dim
    )


if __name__ == "__main__":
    main()
scripts/infer_nus.py ADDED
@@ -0,0 +1,500 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ import imageio
6
+ import json
7
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
8
+ import argparse
9
+ from torchvision.transforms import v2
10
+ from einops import rearrange
11
+ import torch.nn as nn
12
+ from pose_classifier import PoseClassifier
13
+
14
+
15
def load_video_frames(video_path, num_frames=20, height=900, width=1600):
    """Read up to `num_frames` frames from a video, resize each to 480x832,
    and normalize to [-1, 1].

    Returns a [C, T, H, W] float tensor, or None when the video yields no
    frames. NOTE(review): the `height`/`width` parameters are currently
    unused — the resize target is hard-coded to (480, 832).
    """
    normalize = v2.Compose([
        v2.ToTensor(),
        v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    def _resize(img):
        # Fixed output resolution, bilinear interpolation (no cropping).
        return v2.functional.resize(
            img,
            (round(480), round(832)),
            interpolation=v2.InterpolationMode.BILINEAR
        )

    reader = imageio.get_reader(video_path)
    collected = []
    for idx, raw_frame in enumerate(reader):
        if idx >= num_frames:
            break
        collected.append(normalize(_resize(Image.fromarray(raw_frame))))
    reader.close()

    if not collected:
        return None

    stacked = torch.stack(collected, dim=0)
    return rearrange(stacked, "T C H W -> C T H W")
54
def calculate_relative_rotation(current_rotation, reference_rotation):
    """Return the relative quaternion q_rel = q_ref^-1 * q_current.

    Quaternions are in (w, x, y, z) order. The reference quaternion is
    inverted via its conjugate, which equals the true inverse only for unit
    quaternions (assumed here).
    """
    q_cur = torch.tensor(current_rotation, dtype=torch.float32)
    q_ref = torch.tensor(reference_rotation, dtype=torch.float32)

    # Conjugate of the reference quaternion acts as its inverse.
    w1 = q_ref[0]
    x1, y1, z1 = -q_ref[1], -q_ref[2], -q_ref[3]
    w2, x2, y2, z2 = q_cur

    # Hamilton product: q_ref_conj * q_current.
    return torch.tensor([
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
    ])
75
def generate_direction_poses(direction="left", target_frames=10, condition_frames=20):
    """
    Build a pose-class embedding for a requested motion direction, covering
    both condition and target frames.

    Args:
        direction: one of 'forward', 'backward', 'left_turn', 'right_turn'.
            NOTE(review): the default "left" is not an accepted value and
            falls through to the ValueError branch — presumably a leftover.
        target_frames: number of frames to generate poses for (frame type 1.0).
        condition_frames: number of conditioning frames (frame type 0.0).

    Returns:
        The tensor produced by create_enhanced_class_embedding_for_inference
        (one row per frame, condition frames first).
    """
    # NOTE(review): external classifier; per-frame class semantics
    # (0=forward, 1=backward, 2=left_turn, 3=right_turn) inferred from the
    # direction_vectors table in the embedding helper — confirm against
    # pose_classifier.py.
    classifier = PoseClassifier()

    total_frames = condition_frames + target_frames
    print(f"conditon{condition_frames}")
    print(f"target{target_frames}")
    poses = []

    # Condition frames: steady, slow forward motion with no rotation.
    for i in range(condition_frames):
        t = i / max(1, condition_frames - 1)  # normalized progress in [0, 1]

        translation = [-t * 0.5, 0.0, 0.0]  # slow advance along -X
        rotation = [1.0, 0.0, 0.0, 0.0]  # identity quaternion (w, x, y, z)
        frame_type = 0.0  # condition marker

        pose_vec = translation + rotation + [frame_type]  # 8-D vector
        poses.append(pose_vec)

    # Target frames: motion determined by the requested direction,
    # continuing from where the condition segment ended (x = -condition_frames*0.5).
    for i in range(target_frames):
        t = i / max(1, target_frames - 1)  # 0 to 1

        if direction == "forward":
            # Keep advancing along -X, no rotation.
            translation = [-(condition_frames * 0.5 + t * 2.0), 0.0, 0.0]
            rotation = [1.0, 0.0, 0.0, 0.0]  # unit quaternion

        elif direction == "backward":
            # Reverse: move back toward +X, no rotation.
            translation = [-(condition_frames * 0.5) + t * 2.0, 0.0, 0.0]
            rotation = [1.0, 0.0, 0.0, 0.0]

        elif direction == "left_turn":
            # Advance while drifting left and yawing about +Z.
            translation = [-(condition_frames * 0.5 + t * 1.5), t * 0.5, 0.0]
            yaw = t * 0.3  # left-turn angle (radians)
            rotation = [
                np.cos(yaw/2),  # w
                0.0,  # x
                0.0,  # y
                np.sin(yaw/2)  # z (positive for a left turn)
            ]

        elif direction == "right_turn":
            # Advance while drifting right and yawing about -Z.
            translation = [-(condition_frames * 0.5 + t * 1.5), -t * 0.5, 0.0]
            yaw = -t * 0.3  # right-turn angle (radians)
            rotation = [
                np.cos(abs(yaw)/2),  # w (cos is even, so equal to cos(yaw/2))
                0.0,  # x
                0.0,  # y
                np.sin(yaw/2)  # z (negative for a right turn)
            ]
        else:
            raise ValueError(f"Unknown direction: {direction}")

        frame_type = 1.0  # target marker
        pose_vec = translation + rotation + [frame_type]  # 8-D vector
        poses.append(pose_vec)

    pose_sequence = torch.tensor(poses, dtype=torch.float32)

    # Classify only the target rows, dropping the frame-type column (first 7 dims).
    target_pose_sequence = pose_sequence[condition_frames:, :7]

    # Condition frames are all labeled class 0 (forward); targets are classified.
    condition_classes = torch.full((condition_frames,), 0, dtype=torch.long)
    target_classes = classifier.classify_pose_sequence(target_pose_sequence)
    full_classes = torch.cat([condition_classes, target_classes], dim=0)

    # Build the enhanced per-frame embedding.
    class_embeddings = create_enhanced_class_embedding_for_inference(
        full_classes, pose_sequence, embed_dim=512
    )

    print(f"Generated {direction} poses:")
    print(f" Total frames: {total_frames} (condition: {condition_frames}, target: {target_frames})")
    analysis = classifier.analyze_pose_sequence(target_pose_sequence)
    print(f" Target class distribution: {analysis['class_distribution']}")
    print(f" Target motion segments: {len(analysis['motion_segments'])}")

    return class_embeddings
167
def create_enhanced_class_embedding_for_inference(class_labels: torch.Tensor, pose_sequence: torch.Tensor, embed_dim: int = 512) -> torch.Tensor:
    """Create enhanced per-frame class embeddings for inference.

    Concatenates a direction embedding (from the class label), a frame-type
    flag pair (condition vs. target) and the raw pose geometry (translation
    + quaternion) into a 13-dim feature per frame, then expands it to
    ``embed_dim`` with a fixed projection whose first 13 columns are the
    identity (so the raw features survive unchanged).

    Args:
        class_labels: [num_frames] long tensor with values in {0..3}
            (forward / backward / left_turn / right_turn).
        pose_sequence: [num_frames, >=8] tensor; columns 0:3 are translation,
            3:7 a quaternion, and the last column the frame-type flag
            (0 = condition, 1 = target).
        embed_dim: output embedding width.

    Returns:
        [num_frames, embed_dim] float32 tensor.
    """
    num_classes = 4
    num_frames = len(class_labels)

    # Canonical direction vector per class (rows indexed by class id).
    direction_vectors = torch.tensor([
        [1.0, 0.0, 0.0, 0.0],   # forward
        [-1.0, 0.0, 0.0, 0.0],  # backward
        [0.0, 1.0, 0.0, 0.0],   # left_turn
        [0.0, -1.0, 0.0, 0.0],  # right_turn
    ], dtype=torch.float32)

    # One-hot encode the class labels.
    one_hot = torch.zeros(num_frames, num_classes)
    one_hot.scatter_(1, class_labels.unsqueeze(1), 1)

    # Base embedding: each frame mapped to its direction vector.
    base_embeddings = one_hot @ direction_vectors  # [num_frames, 4]

    # Frame-type flags as a 2-dim one-hot (condition, target).
    frame_types = pose_sequence[:, -1]  # last column carries the frame type
    frame_type_embeddings = torch.zeros(num_frames, 2)
    frame_type_embeddings[:, 0] = (frame_types == 0).float()  # condition
    frame_type_embeddings[:, 1] = (frame_types == 1).float()  # target

    # Raw geometric pose information.
    translations = pose_sequence[:, :3]  # [num_frames, 3]
    rotations = pose_sequence[:, 3:7]    # [num_frames, 4]

    combined_features = torch.cat([
        base_embeddings,        # [num_frames, 4]
        frame_type_embeddings,  # [num_frames, 2]
        translations,           # [num_frames, 3]
        rotations,              # [num_frames, 4]
    ], dim=1)  # [num_frames, 13]

    # Expand to the target dimension.
    # BUGFIX: the projection matrix was previously drawn with torch.randn and
    # no fixed seed, so the same poses produced a different embedding on every
    # call (and across runs). Use a seeded generator so inference is
    # deterministic and reproducible.
    if embed_dim > 13:
        gen = torch.Generator().manual_seed(0)
        expand_matrix = torch.randn(13, embed_dim, generator=gen) * 0.1
        expand_matrix[:13, :13] = torch.eye(13)  # identity keeps the raw features in dims 0..12
        embeddings = combined_features @ expand_matrix
    else:
        embeddings = combined_features[:, :embed_dim]

    return embeddings
214
+
215
def generate_poses_from_file(poses_path, target_frames=10):
    """Build per-frame class embeddings from a poses.json file.

    Reads ``target_relative_poses`` from the file, resamples them onto
    ``target_frames`` steps, converts each entry to a 7D vector
    (translation + relative quaternion), classifies the sequence and
    returns the classifier's class embedding. Falls back to a synthetic
    forward trajectory when the file contains no poses.
    """
    classifier = PoseClassifier()

    with open(poses_path, 'r') as f:
        poses_data = json.load(f)

    target_relative_poses = poses_data['target_relative_poses']

    # Empty file -> fall back to synthetic forward motion.
    if not target_relative_poses:
        print("No poses found in file, using forward direction")
        return generate_direction_poses("forward", target_frames)

    num_available = len(target_relative_poses)

    def _entry_for(frame_idx):
        # Nearest available pose, resampled onto the target timeline.
        if num_available == 1:
            return target_relative_poses[0]
        idx = min(frame_idx * num_available // target_frames, num_available - 1)
        return target_relative_poses[idx]

    pose_vecs = []
    for frame_idx in range(target_frames):
        entry = _entry_for(frame_idx)

        # Relative displacement plus the two absolute rotations.
        trans = torch.tensor(entry['relative_translation'], dtype=torch.float32)
        cur_rot = torch.tensor(entry['current_rotation'], dtype=torch.float32)
        ref_rot = torch.tensor(entry['reference_rotation'], dtype=torch.float32)

        # Derive the relative rotation between the two frames.
        rel_rot = calculate_relative_rotation(cur_rot, ref_rot)

        # 7D pose vector: translation (3) + relative quaternion (4).
        pose_vecs.append(torch.cat([trans, rel_rot], dim=0))

    pose_sequence = torch.stack(pose_vecs, dim=0)

    # Classify the motion and turn the labels into class embeddings.
    class_embeddings = classifier.create_class_embedding(
        classifier.classify_pose_sequence(pose_sequence),
        embed_dim=512
    )

    print(f"Generated poses from file:")
    analysis = classifier.analyze_pose_sequence(pose_sequence)
    print(f" Class distribution: {analysis['class_distribution']}")
    print(f" Motion segments: {len(analysis['motion_segments'])}")

    return class_embeddings
264
+
265
def inference_nuscenes_video(
    condition_video_path,
    dit_path,
    text_encoder_path,
    vae_path,
    output_path="nus/infer_results/output_nuscenes.mp4",
    condition_frames=20,
    target_frames=3,
    height=900,
    width=1600,
    device="cuda",
    prompt="A car driving scene captured by front camera",
    poses_path=None,
    direction="forward"
):
    """
    Direction-class-controlled inference - distinguishes condition and target poses.

    Loads the Wan2.1 T2V pipeline plus a trained ReCamMaster DiT checkpoint,
    encodes the condition video and a direction-conditioned camera embedding,
    appends noise latents for the target frames, denoises only the target
    part, then decodes and saves the combined video as an mp4.

    NOTE(review): text_encoder_path, vae_path and poses_path are accepted but
    unused here - the model paths are hard-coded below and poses are derived
    from `direction`; confirm whether file-based poses should be wired in.
    """
    os.makedirs(os.path.dirname(output_path),exist_ok=True)

    print(f"Setting up models for {direction} movement...")

    # 1. Load models (same as before)
    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # Add camera components to DiT
    dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        block.cam_encoder = nn.Linear(512, dim)  # keep the 512-dim class embedding
        block.projector = nn.Linear(dim, dim)
        # Zero-init the camera encoder and identity-init the projector so the
        # freshly-added layers are a no-op before the checkpoint is loaded.
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    # Load trained DiT weights (strict: checkpoint must include the camera layers)
    dit_state_dict = torch.load(dit_path, map_location="cpu")
    pipe.dit.load_state_dict(dit_state_dict, strict=True)
    pipe = pipe.to(device)
    pipe.scheduler.set_timesteps(50)

    print("Loading condition video...")

    # Load condition video
    condition_video = load_video_frames(
        condition_video_path,
        num_frames=condition_frames,
        height=height,
        width=width
    )

    if condition_video is None:
        raise ValueError(f"Failed to load condition video from {condition_video_path}")

    condition_video = condition_video.unsqueeze(0).to(device, dtype=pipe.torch_dtype)

    print("Processing poses...")

    # Generate a pose embedding covering both condition and target frames
    print(f"Generating {direction} movement poses...")
    camera_embedding = generate_direction_poses(
        direction=direction,
        target_frames=target_frames,
        condition_frames=int(condition_frames/4)  # condition length after 4x temporal compression
    )

    camera_embedding = camera_embedding.unsqueeze(0).to(device, dtype=torch.bfloat16)

    print(f"Camera embedding shape: {camera_embedding.shape}")
    print(f"Generated poses for direction: {direction}")

    print("Encoding inputs...")

    # Encode text prompt
    prompt_emb = pipe.encode_prompt(prompt)

    # Encode condition video into latents (tiled VAE to bound memory)
    condition_latents = pipe.encode_video(condition_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))[0]

    print("Generating video...")

    # Generate target latents
    batch_size = 1
    channels = condition_latents.shape[0]
    latent_height = condition_latents.shape[2]
    latent_width = condition_latents.shape[3]
    target_height, target_width = 60, 104  # adjust to your requirements

    if latent_height > target_height or latent_width > target_width:
        # Center-crop the latents to the target spatial size
        h_start = (latent_height - target_height) // 2
        w_start = (latent_width - target_width) // 2
        condition_latents = condition_latents[:, :,
                                            h_start:h_start+target_height,
                                            w_start:w_start+target_width]
        latent_height = target_height
        latent_width = target_width
    condition_latents = condition_latents.to(device, dtype=pipe.torch_dtype)
    condition_latents = condition_latents.unsqueeze(0)
    condition_latents = condition_latents + 0.05 * torch.randn_like(condition_latents)  # add a little noise for diversity

    # Initialize target latents with noise
    target_latents = torch.randn(
        batch_size, channels, target_frames, latent_height, latent_width,
        device=device, dtype=pipe.torch_dtype
    )
    print(target_latents.shape)
    print(camera_embedding.shape)
    # Combine condition and target latents along the time axis
    combined_latents = torch.cat([condition_latents, target_latents], dim=2)
    print(combined_latents.shape)

    # Prepare extra inputs
    extra_input = pipe.prepare_extra_input(combined_latents)

    # Denoising loop
    timesteps = pipe.scheduler.timesteps

    for i, timestep in enumerate(timesteps):
        print(f"Denoising step {i+1}/{len(timesteps)}")

        # Prepare timestep
        timestep_tensor = timestep.unsqueeze(0).to(device, dtype=pipe.torch_dtype)

        # Predict noise over the full (condition + target) sequence
        with torch.no_grad():
            noise_pred = pipe.dit(
                combined_latents,
                timestep=timestep_tensor,
                cam_emb=camera_embedding,
                **prompt_emb,
                **extra_input
            )

        # Update only target part (condition frames stay fixed)
        target_noise_pred = noise_pred[:, :, int(condition_frames/4):, :, :]
        target_latents = pipe.scheduler.step(target_noise_pred, timestep, target_latents)

        # Update combined latents
        combined_latents[:, :, int(condition_frames/4):, :, :] = target_latents

    print("Decoding video...")

    # Decode final video
    final_video = torch.cat([condition_latents, target_latents], dim=2)
    decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))

    # Save video
    print(f"Saving video to {output_path}")

    # Convert to numpy and save (cast to float32 before .numpy())
    video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
    video_np = (video_np * 0.5 + 0.5).clip(0, 1)  # Denormalize
    video_np = (video_np * 255).astype(np.uint8)

    with imageio.get_writer(output_path, fps=20) as writer:
        for frame in video_np:
            writer.append_data(frame)

    print(f"Video generation completed! Saved to {output_path}")
431
+
432
def main():
    """CLI entry point for direction-controlled NuScenes video generation.

    Parses arguments, derives the output path (input video name + direction
    unless --output_path is given explicitly) and runs
    ``inference_nuscenes_video``.
    """
    parser = argparse.ArgumentParser(description="NuScenes Video Generation Inference with Direction Control")
    parser.add_argument("--condition_video", type=str, default="/home/zhuyixuan05/ReCamMaster/nus/videos/4032/right.mp4",
                        help="Path to condition video")
    parser.add_argument("--direction", type=str, default="left_turn",
                        choices=["forward", "backward", "left_turn", "right_turn"],
                        help="Direction of camera movement")
    parser.add_argument("--dit_path", type=str, default="/home/zhuyixuan05/ReCamMaster/nus_dynamic/step15000_dynamic.ckpt",
                        help="Path to trained DiT checkpoint")
    parser.add_argument("--text_encoder_path", type=str,
                        default="models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
                        help="Path to text encoder")
    parser.add_argument("--vae_path", type=str,
                        default="models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
                        help="Path to VAE")
    # BUGFIX: the default used to be a hard-coded mp4 path, which made the
    # `args.output_path is None` branch below (deriving a direction-suffixed
    # filename) unreachable. Default to None so the derived name is used
    # whenever the flag is omitted.
    parser.add_argument("--output_path", type=str, default=None,
                        help="Output video path (default: derived from input video name and direction)")
    parser.add_argument("--poses_path", type=str, default=None,
                        help="Path to poses.json file (optional, will use direction if not provided)")
    parser.add_argument("--prompt", type=str,
                        default="A car driving scene captured by front camera",
                        help="Text prompt for generation")
    parser.add_argument("--condition_frames", type=int, default=40,
                        help="Number of condition frames")
    # Raw (uncompressed) frame count; divided by 4 inside inference.
    parser.add_argument("--target_frames", type=int, default=8,
                        help="Number of target frames to generate")
    # NOTE(review): original comment said this value "must be divided by 4",
    # but inference uses it directly for the latent time axis - confirm units.
    parser.add_argument("--height", type=int, default=900,
                        help="Video height")
    parser.add_argument("--width", type=int, default=1600,
                        help="Video width")
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device to run inference on")

    args = parser.parse_args()

    condition_video_path = args.condition_video
    input_filename = os.path.basename(condition_video_path)
    output_dir = "nus/infer_results"
    os.makedirs(output_dir, exist_ok=True)

    # Embed the movement direction in the output filename when no explicit
    # path was requested.
    if args.output_path is None:
        name_parts = os.path.splitext(input_filename)
        output_filename = f"{name_parts[0]}_{args.direction}{name_parts[1]}"
        output_path = os.path.join(output_dir, output_filename)
    else:
        output_path = args.output_path

    print(f"Output video will be saved to: {output_path}")
    inference_nuscenes_video(
        condition_video_path=args.condition_video,
        dit_path=args.dit_path,
        text_encoder_path=args.text_encoder_path,
        vae_path=args.vae_path,
        output_path=output_path,
        condition_frames=args.condition_frames,
        target_frames=args.target_frames,
        height=args.height,
        width=args.width,
        device=args.device,
        prompt=args.prompt,
        poses_path=args.poses_path,
        direction=args.direction
    )

if __name__ == "__main__":
    main()
scripts/infer_openx.py ADDED
@@ -0,0 +1,614 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
3
+ from torchvision.transforms import v2
4
+ from einops import rearrange
5
+ import os
6
+ import torch
7
+ import torch.nn as nn
8
+ import argparse
9
+ import numpy as np
10
+ import imageio
11
+ import copy
12
+ import random
13
+
14
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Load a window of pre-encoded video latents from a .pth file.

    Returns ``(condition_latents, encoded_data)`` where ``condition_latents``
    is the [C, num_frames, H, W] slice starting at ``start_frame`` and
    ``encoded_data`` is the full dict loaded from disk.

    Raises:
        ValueError: if the requested window runs past the stored frames.
    """
    print(f"Loading encoded video from {pth_path}")

    # NOTE(review): weights_only=False deserializes arbitrary pickled objects;
    # only ever load trusted checkpoint files here.
    encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu")
    full_latents = encoded_data['latents']  # [C, T, H, W]

    print(f"Full latents shape: {full_latents.shape}")
    print(f"Extracting frames {start_frame} to {start_frame + num_frames}")

    end_frame = start_frame + num_frames
    if end_frame > full_latents.shape[1]:
        raise ValueError(f"Not enough frames: requested {start_frame + num_frames}, available {full_latents.shape[1]}")

    condition_latents = full_latents[:, start_frame:end_frame, :, :]
    print(f"Extracted condition latents shape: {condition_latents.shape}")

    return condition_latents, encoded_data
31
+
32
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Compute camera B's pose relative to camera A.

    Both inputs are 4x4 extrinsic matrices (numpy arrays or torch tensors).
    Returns ``pose_b @ inv(pose_a)`` - as a torch tensor when ``use_torch``
    is True, otherwise as a numpy array.
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        # Promote numpy inputs to float tensors before inverting.
        mat_a = pose_a if isinstance(pose_a, torch.Tensor) else torch.from_numpy(pose_a).float()
        mat_b = pose_b if isinstance(pose_b, torch.Tensor) else torch.from_numpy(pose_b).float()
        return torch.matmul(mat_b, torch.inverse(mat_a))

    # Numpy path: coerce sequences/tensors into float32 arrays.
    mat_a = pose_a if isinstance(pose_a, np.ndarray) else np.array(pose_a, dtype=np.float32)
    mat_b = pose_b if isinstance(pose_b, np.ndarray) else np.array(pose_b, dtype=np.float32)
    return np.matmul(mat_b, np.linalg.inv(mat_a))
55
+
56
def replace_dit_model_in_manager():
    """Swap the registered DiT class for WanModelFuture before models load.

    Rewrites entries of ``model_loader_configs`` in place so that any config
    containing a 'wan_video_dit' model instantiates ``WanModelFuture``
    instead of the stock class; other entries are left untouched.
    """
    from diffsynth.models.wan_video_dit_recam_future import WanModelFuture
    from diffsynth.configs.model_config import model_loader_configs

    for idx, entry in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, names, classes, resource = entry

        # Skip configs that do not register a wan_video_dit model.
        if 'wan_video_dit' not in names:
            continue

        patched_names = []
        patched_classes = []
        for model_name, model_cls in zip(names, classes):
            if model_name == 'wan_video_dit':
                # Keep the registered name, replace only the class.
                patched_names.append(model_name)
                patched_classes.append(WanModelFuture)
                print(f"✅ 替换了模型类: {model_name} -> WanModelFuture")
            else:
                patched_names.append(model_name)
                patched_classes.append(model_cls)

        model_loader_configs[idx] = (keys_hash, keys_hash_with_shape, patched_names, patched_classes, resource)
82
+
83
def add_framepack_components(dit_model):
    """Attach a FramePack-style multi-scale clean-latent embedder to the DiT.

    Adds a ``clean_x_embedder`` module (1x/2x/4x Conv3d patchifiers mapping
    16-channel latents into the transformer width) unless one already exists,
    then casts it to the model's parameter dtype. Attribute names
    (``proj`` / ``proj_2x`` / ``proj_4x``) match the trained checkpoint.
    """
    if hasattr(dit_model, 'clean_x_embedder'):
        return

    # Transformer inner width, read off the first attention Q projection.
    inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]

    class CleanXEmbedder(nn.Module):
        """Multi-scale latent patchifier (mirrors hunyuan_video_packed.py)."""

        def __init__(self, inner_dim):
            super().__init__()
            self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
            self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
            self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

        def forward(self, x, scale="1x"):
            # Dispatch on the requested compression scale.
            layer_name = {"1x": "proj", "2x": "proj_2x", "4x": "proj_4x"}.get(scale)
            if layer_name is None:
                raise ValueError(f"Unsupported scale: {scale}")
            return getattr(self, layer_name)(x)

    dit_model.clean_x_embedder = CleanXEmbedder(inner_dim)
    # Match the embedder dtype to the rest of the model's parameters.
    model_dtype = next(dit_model.parameters()).dtype
    dit_model.clean_x_embedder = dit_model.clean_x_embedder.to(dtype=model_dtype)
    print("✅ 添加了FramePack的clean_x_embedder组件")
110
+
111
def generate_openx_camera_embeddings_sliding(cam_data, start_frame, current_history_length, new_frames, total_generated, use_real_poses=True):
    """Build the per-frame camera embedding sequence for OpenX sliding-window inference.

    Each frame is 13-dim: a flattened 3x4 relative pose (12) plus one
    condition/target mask bit. Real extrinsics from ``cam_data`` are used when
    available; otherwise a synthetic gentle-manipulation motion is generated.
    ``total_generated`` is accepted for interface compatibility but unused.

    Returns:
        [num_frames, 13] bfloat16 tensor.
    """
    time_compression_ratio = 4

    # Frames consumed by the FramePack index layout (1 start + 16 + 2 + 1 + new).
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    # Cover the sliding window, the FramePack layout, and a floor of 30 frames.
    max_needed_frames = max(
        start_frame + current_history_length + new_frames,
        framepack_needed_frames,
        30
    )

    use_real = use_real_poses and cam_data is not None and 'extrinsic' in cam_data

    if use_real:
        print("🔧 使用真实OpenX camera数据")
        cam_extrinsic = cam_data['extrinsic']

        print(f"🔧 计算OpenX camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        relative_poses = []
        for frame_pos in range(max_needed_frames):
            # OpenX latents are temporally compressed: one pose per 4 raw frames.
            raw_idx = frame_pos * time_compression_ratio
            raw_next = raw_idx + time_compression_ratio

            if raw_next < len(cam_extrinsic):
                rel = compute_relative_pose(cam_extrinsic[raw_idx], cam_extrinsic[raw_next])
                relative_poses.append(torch.as_tensor(rel[:3, :]))
            else:
                # Past the recorded trajectory: fall back to zero motion.
                print(f"⚠️ 帧{raw_idx}超出camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))
    else:
        print("🔧 使用OpenX合成camera数据")
        print(f"🔧 生成OpenX合成camera帧数: {max_needed_frames}")

        def _synthetic_relative_pose(i):
            # Gentle robot-manipulation motion: tiny translation + small yaw/pitch.
            forward_speed = 0.001                      # per-frame advance (fine manipulation)
            lateral_motion = 0.0005 * np.sin(i * 0.05)  # slight left/right sway
            vertical_motion = 0.0003 * np.cos(i * 0.1)  # slight up/down motion
            yaw_change = 0.01 * np.sin(i * 0.03)        # slight yaw adjustment
            pitch_change = 0.008 * np.cos(i * 0.04)     # slight pitch adjustment

            pose = np.eye(4, dtype=np.float32)

            cos_yaw, sin_yaw = np.cos(yaw_change), np.sin(yaw_change)
            cos_pitch, sin_pitch = np.cos(pitch_change), np.sin(pitch_change)

            # Combined small-angle rotation (pitch, then yaw), written directly.
            pose[0, 0] = cos_yaw
            pose[0, 2] = sin_yaw
            pose[1, 1] = cos_pitch
            pose[1, 2] = -sin_pitch
            pose[2, 0] = -sin_yaw
            pose[2, 1] = sin_pitch
            pose[2, 2] = cos_yaw * cos_pitch

            # Fine-manipulation translation; negative z means forward.
            pose[0, 3] = lateral_motion
            pose[1, 3] = vertical_motion
            pose[2, 3] = -forward_speed
            return pose[:3, :]

        relative_poses = [torch.as_tensor(_synthetic_relative_pose(i)) for i in range(max_needed_frames)]

    # [N, 3, 4] -> [N, 12] flattened pose rows.
    pose_embedding = rearrange(torch.stack(relative_poses, dim=0), 'b c d -> b (c d)')

    # Mask column: 1 marks condition frames, 0 marks frames to be generated.
    mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
    condition_end = min(start_frame + current_history_length, max_needed_frames)
    mask[start_frame:condition_end] = 1.0

    camera_embedding = torch.cat([pose_embedding, mask], dim=1)
    if use_real:
        print(f"🔧 OpenX真实camera embedding shape: {camera_embedding.shape}")
    else:
        print(f"🔧 OpenX合成camera embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
221
+
222
def prepare_framepack_sliding_window_with_camera(history_latents, target_frames_to_generate, camera_embedding_full, start_frame, max_history_frames=49):
    """Assemble FramePack sliding-window inputs for one OpenX generation step.

    Uses the fixed index layout 1 (start) + 16 (4x) + 2 (2x) + 1 (1x) + N
    (new frames), right-aligns the most recent history frames into a 19-slot
    clean-latent buffer, and slices/re-masks the precomputed camera embedding
    sequence to the same length. ``start_frame`` and ``max_history_frames``
    are currently unused beyond the signature.

    Args:
        history_latents: [C, T, H, W] latents accumulated so far.
        target_frames_to_generate: number of new latent frames N.
        camera_embedding_full: [T_cam, 13] full camera embedding sequence
            (zero-padded here if shorter than the layout).

    Returns:
        dict with index tensors, clean latents at each scale, the re-masked
        camera embedding, and current/next history lengths.
    """
    # history_latents: [C, T, H, W]
    C, T, H, W = history_latents.shape

    # Fixed index layout; its total length also dictates camera consumption.
    total_indices_length = 1 + 16 + 2 + 1 + target_frames_to_generate
    section_sizes = [1, 16, 2, 1, target_frames_to_generate]
    (start_index, clean_latent_4x_indices, clean_latent_2x_indices,
     clean_latent_1x_indices, latent_indices) = torch.arange(0, total_indices_length).split(section_sizes, dim=0)
    clean_latent_indices = torch.cat([start_index, clean_latent_1x_indices], dim=0)

    # Zero-pad the camera sequence when it is shorter than the layout.
    if camera_embedding_full.shape[0] < total_indices_length:
        shortage = total_indices_length - camera_embedding_full.shape[0]
        pad = torch.zeros(shortage, camera_embedding_full.shape[1],
                          dtype=camera_embedding_full.dtype, device=camera_embedding_full.device)
        camera_embedding_full = torch.cat([camera_embedding_full, pad], dim=0)

    combined_camera = camera_embedding_full[:total_indices_length, :].clone()

    # Rebuild the condition/target mask: everything target (0) by default...
    combined_camera[:, -1] = 0.0

    # ...then mark the camera rows backing the (up to 19) valid clean latents.
    if T > 0:
        available_frames = min(T, 19)
        first_clean = 19 - available_frames
        combined_camera[first_clean:19, -1] = 1.0

    print(f"🔧 OpenX Camera mask更新:")
    print(f" - 历史帧数: {T}")
    print(f" - 有效condition帧数: {available_frames if T > 0 else 0}")

    # Right-align the newest history frames into a 19-slot clean buffer.
    clean_latents_combined = torch.zeros(C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device)
    if T > 0:
        available_frames = min(T, 19)
        first_clean = 19 - available_frames
        clean_latents_combined[:, first_clean:, :, :] = history_latents[:, -available_frames:, :, :]

    # Slice the buffer into the three FramePack scales.
    clean_latents_4x = clean_latents_combined[:, 0:16, :, :]
    clean_latents_2x = clean_latents_combined[:, 16:18, :, :]
    clean_latents_1x = clean_latents_combined[:, 18:19, :, :]

    # The very first history frame anchors the sequence start.
    if T > 0:
        start_latent = history_latents[:, 0:1, :, :]
    else:
        start_latent = torch.zeros(C, 1, H, W, dtype=history_latents.dtype, device=history_latents.device)

    clean_latents = torch.cat([start_latent, clean_latents_1x], dim=1)

    return {
        'latent_indices': latent_indices,
        'clean_latents': clean_latents,
        'clean_latents_2x': clean_latents_2x,
        'clean_latents_4x': clean_latents_4x,
        'clean_latent_indices': clean_latent_indices,
        'clean_latent_2x_indices': clean_latent_2x_indices,
        'clean_latent_4x_indices': clean_latent_4x_indices,
        'camera_embedding': combined_camera,
        'current_length': T,
        'next_length': T + target_frames_to_generate
    }
+ }
289
+
290
+ def inference_openx_framepack_sliding_window(
291
+ condition_pth_path,
292
+ dit_path,
293
+ output_path="openx_results/output_openx_framepack_sliding.mp4",
294
+ start_frame=0,
295
+ initial_condition_frames=8,
296
+ frames_per_generation=4,
297
+ total_frames_to_generate=32,
298
+ max_history_frames=49,
299
+ device="cuda",
300
+ prompt="A video of robotic manipulation task with camera movement",
301
+ use_real_poses=True,
302
+ # CFG参数
303
+ use_camera_cfg=True,
304
+ camera_guidance_scale=2.0,
305
+ text_guidance_scale=1.0
306
+ ):
307
+ """
308
+ OpenX FramePack滑动窗口视频生成
309
+ """
310
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
311
+ print(f"🔧 OpenX FramePack滑动窗口生成开始...")
312
+ print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}")
313
+ print(f"Text guidance scale: {text_guidance_scale}")
314
+
315
+ # 1. 模型初始化
316
+ replace_dit_model_in_manager()
317
+
318
+ model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
319
+ model_manager.load_models([
320
+ "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
321
+ "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
322
+ "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
323
+ ])
324
+ pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")
325
+
326
+ # 2. 添加camera编码器
327
+ dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
328
+ for block in pipe.dit.blocks:
329
+ block.cam_encoder = nn.Linear(13, dim)
330
+ block.projector = nn.Linear(dim, dim)
331
+ block.cam_encoder.weight.data.zero_()
332
+ block.cam_encoder.bias.data.zero_()
333
+ block.projector.weight = nn.Parameter(torch.eye(dim))
334
+ block.projector.bias = nn.Parameter(torch.zeros(dim))
335
+
336
+ # 3. 添加FramePack组件
337
+ add_framepack_components(pipe.dit)
338
+
339
+ # 4. 加载训练好的权重
340
+ dit_state_dict = torch.load(dit_path, map_location="cpu")
341
+ pipe.dit.load_state_dict(dit_state_dict, strict=True)
342
+ pipe = pipe.to(device)
343
+ model_dtype = next(pipe.dit.parameters()).dtype
344
+
345
+ if hasattr(pipe.dit, 'clean_x_embedder'):
346
+ pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)
347
+
348
+ pipe.scheduler.set_timesteps(50)
349
+
350
+ # 5. 加载初始条件
351
+ print("Loading initial condition frames...")
352
+ initial_latents, encoded_data = load_encoded_video_from_pth(
353
+ condition_pth_path,
354
+ start_frame=start_frame,
355
+ num_frames=initial_condition_frames
356
+ )
357
+
358
+ # 空间裁剪(适配OpenX数据尺寸)
359
+ target_height, target_width = 60, 104
360
+ C, T, H, W = initial_latents.shape
361
+
362
+ if H > target_height or W > target_width:
363
+ h_start = (H - target_height) // 2
364
+ w_start = (W - target_width) // 2
365
+ initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width]
366
+ H, W = target_height, target_width
367
+
368
+ history_latents = initial_latents.to(device, dtype=model_dtype)
369
+
370
+ print(f"初始history_latents shape: {history_latents.shape}")
371
+
372
+ # 6. 编码prompt - 支持CFG
373
+ if text_guidance_scale > 1.0:
374
+ prompt_emb_pos = pipe.encode_prompt(prompt)
375
+ prompt_emb_neg = pipe.encode_prompt("")
376
+ print(f"使用Text CFG,guidance scale: {text_guidance_scale}")
377
+ else:
378
+ prompt_emb_pos = pipe.encode_prompt(prompt)
379
+ prompt_emb_neg = None
380
+ print("不使用Text CFG")
381
+
382
+ # 7. 预生成完整的camera embedding序列
383
+ camera_embedding_full = generate_openx_camera_embeddings_sliding(
384
+ encoded_data.get('cam_emb', None),
385
+ 0,
386
+ max_history_frames,
387
+ 0,
388
+ 0,
389
+ use_real_poses=use_real_poses
390
+ ).to(device, dtype=model_dtype)
391
+
392
+ print(f"完整camera序列shape: {camera_embedding_full.shape}")
393
+
394
+ # 8. 为Camera CFG创建无条件的camera embedding
395
+ if use_camera_cfg:
396
+ camera_embedding_uncond = torch.zeros_like(camera_embedding_full)
397
+ print(f"创建无条件camera embedding用于CFG")
398
+
399
+ # 9. 滑动窗口生成循环
400
+ total_generated = 0
401
+ all_generated_frames = []
402
+
403
+ while total_generated < total_frames_to_generate:
404
+ current_generation = min(frames_per_generation, total_frames_to_generate - total_generated)
405
+ print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}")
406
+ print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}")
407
+
408
+ # FramePack数据准备 - OpenX版本
409
+ framepack_data = prepare_framepack_sliding_window_with_camera(
410
+ history_latents,
411
+ current_generation,
412
+ camera_embedding_full,
413
+ start_frame,
414
+ max_history_frames
415
+ )
416
+
417
+ # 准备输入
418
+ clean_latents = framepack_data['clean_latents'].unsqueeze(0)
419
+ clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0)
420
+ clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0)
421
+ camera_embedding = framepack_data['camera_embedding'].unsqueeze(0)
422
+
423
+ # 为CFG准备无条件camera embedding
424
+ if use_camera_cfg:
425
+ camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0)
426
+
427
+ # 索引处理
428
+ latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu()
429
+ clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu()
430
+ clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu()
431
+ clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu()
432
+
433
+ # 初始化要生成的latents
434
+ new_latents = torch.randn(
435
+ 1, C, current_generation, H, W,
436
+ device=device, dtype=model_dtype
437
+ )
438
+
439
+ extra_input = pipe.prepare_extra_input(new_latents)
440
+
441
+ print(f"Camera embedding shape: {camera_embedding.shape}")
442
+ print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}")
443
+
444
+ # 去噪循环 - 支持CFG
445
+ timesteps = pipe.scheduler.timesteps
446
+
447
+ for i, timestep in enumerate(timesteps):
448
+ if i % 10 == 0:
449
+ print(f" 去噪步骤 {i}/{len(timesteps)}")
450
+
451
+ timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)
452
+
453
+ with torch.no_grad():
454
+ # 正向预测(带条件)
455
+ noise_pred_pos = pipe.dit(
456
+ new_latents,
457
+ timestep=timestep_tensor,
458
+ cam_emb=camera_embedding,
459
+ latent_indices=latent_indices,
460
+ clean_latents=clean_latents,
461
+ clean_latent_indices=clean_latent_indices,
462
+ clean_latents_2x=clean_latents_2x,
463
+ clean_latent_2x_indices=clean_latent_2x_indices,
464
+ clean_latents_4x=clean_latents_4x,
465
+ clean_latent_4x_indices=clean_latent_4x_indices,
466
+ **prompt_emb_pos,
467
+ **extra_input
468
+ )
469
+
470
+ # CFG处理
471
+ if use_camera_cfg and camera_guidance_scale > 1.0:
472
+ # 无条件预测(无camera条件)
473
+ noise_pred_uncond = pipe.dit(
474
+ new_latents,
475
+ timestep=timestep_tensor,
476
+ cam_emb=camera_embedding_uncond_batch,
477
+ latent_indices=latent_indices,
478
+ clean_latents=clean_latents,
479
+ clean_latent_indices=clean_latent_indices,
480
+ clean_latents_2x=clean_latents_2x,
481
+ clean_latent_2x_indices=clean_latent_2x_indices,
482
+ clean_latents_4x=clean_latents_4x,
483
+ clean_latent_4x_indices=clean_latent_4x_indices,
484
+ **prompt_emb_pos,
485
+ **extra_input
486
+ )
487
+
488
+ # Camera CFG
489
+ noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_pos - noise_pred_uncond)
490
+ else:
491
+ noise_pred = noise_pred_pos
492
+
493
+ # Text CFG
494
+ if prompt_emb_neg is not None and text_guidance_scale > 1.0:
495
+ noise_pred_text_uncond = pipe.dit(
496
+ new_latents,
497
+ timestep=timestep_tensor,
498
+ cam_emb=camera_embedding,
499
+ latent_indices=latent_indices,
500
+ clean_latents=clean_latents,
501
+ clean_latent_indices=clean_latent_indices,
502
+ clean_latents_2x=clean_latents_2x,
503
+ clean_latent_2x_indices=clean_latent_2x_indices,
504
+ clean_latents_4x=clean_latents_4x,
505
+ clean_latent_4x_indices=clean_latent_4x_indices,
506
+ **prompt_emb_neg,
507
+ **extra_input
508
+ )
509
+
510
+ # Text CFG
511
+ noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond)
512
+
513
+ new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents)
514
+
515
+ # 更新历史
516
+ new_latents_squeezed = new_latents.squeeze(0)
517
+ history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1)
518
+
519
+ # 维护滑动窗口
520
+ if history_latents.shape[1] > max_history_frames:
521
+ first_frame = history_latents[:, 0:1, :, :]
522
+ recent_frames = history_latents[:, -(max_history_frames-1):, :, :]
523
+ history_latents = torch.cat([first_frame, recent_frames], dim=1)
524
+ print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧")
525
+
526
+ print(f"更新后history_latents shape: {history_latents.shape}")
527
+
528
+ all_generated_frames.append(new_latents_squeezed)
529
+ total_generated += current_generation
530
+
531
+ print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧")
532
+
533
+ # 10. 解码和保存
534
+ print("\n🔧 解码生成的视频...")
535
+
536
+ all_generated = torch.cat(all_generated_frames, dim=1)
537
+ final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0)
538
+
539
+ print(f"最终视频shape: {final_video.shape}")
540
+
541
+ decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))
542
+
543
+ print(f"Saving video to {output_path}")
544
+
545
+ video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
546
+ video_np = (video_np * 0.5 + 0.5).clip(0, 1)
547
+ video_np = (video_np * 255).astype(np.uint8)
548
+
549
+ with imageio.get_writer(output_path, fps=20) as writer:
550
+ for frame in video_np:
551
+ writer.append_data(frame)
552
+
553
+ print(f"🔧 OpenX FramePack滑动窗口生成完成! 保存到: {output_path}")
554
+ print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧")
555
+
556
def main():
    """CLI entry point for OpenX FramePack sliding-window video generation.

    Parses command-line arguments and dispatches to
    ``inference_openx_framepack_sliding_window``.
    """
    parser = argparse.ArgumentParser(description="OpenX FramePack滑动窗口视频生成")

    # Basic generation parameters.
    parser.add_argument("--condition_pth", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded/episode_000001/encoded_video.pth",
                        help="输入编码视频路径")
    parser.add_argument("--start_frame", type=int, default=0)
    parser.add_argument("--initial_condition_frames", type=int, default=16)
    parser.add_argument("--frames_per_generation", type=int, default=8)
    parser.add_argument("--total_frames_to_generate", type=int, default=24)
    parser.add_argument("--max_history_frames", type=int, default=100)
    parser.add_argument("--use_real_poses", action="store_true", default=False)
    parser.add_argument("--dit_path", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/openx/openx_framepack/step2000.ckpt",
                        help="训练好的模型权重路径")
    parser.add_argument("--output_path", type=str,
                        default='openx_results/output_openx_framepack_sliding.mp4')
    parser.add_argument("--prompt", type=str,
                        default="A video of robotic manipulation task with camera movement")
    parser.add_argument("--device", type=str, default="cuda")

    # CFG parameters.
    # BUG FIX: `--use_camera_cfg` was `action="store_true"` with `default=True`,
    # which made the flag a no-op and left no way to disable camera CFG from
    # the command line. The original flag is kept for backward compatibility;
    # `--no_camera_cfg` now turns it off.
    parser.add_argument("--use_camera_cfg", action="store_true", default=True,
                        help="使用Camera CFG")
    parser.add_argument("--no_camera_cfg", dest="use_camera_cfg", action="store_false",
                        help="Disable camera CFG")
    parser.add_argument("--camera_guidance_scale", type=float, default=2.0,
                        help="Camera guidance scale for CFG")
    parser.add_argument("--text_guidance_scale", type=float, default=1.0,
                        help="Text guidance scale for CFG")

    args = parser.parse_args()

    print(f"🔧 OpenX FramePack CFG生成设置:")
    print(f"Camera CFG: {args.use_camera_cfg}")
    if args.use_camera_cfg:
        print(f"Camera guidance scale: {args.camera_guidance_scale}")
    print(f"Text guidance scale: {args.text_guidance_scale}")
    print(f"OpenX特有特性: camera间隔为4帧,适用于机器人操作任务")

    inference_openx_framepack_sliding_window(
        condition_pth_path=args.condition_pth,
        dit_path=args.dit_path,
        output_path=args.output_path,
        start_frame=args.start_frame,
        initial_condition_frames=args.initial_condition_frames,
        frames_per_generation=args.frames_per_generation,
        total_frames_to_generate=args.total_frames_to_generate,
        max_history_frames=args.max_history_frames,
        device=args.device,
        prompt=args.prompt,
        use_real_poses=args.use_real_poses,
        # CFG parameters
        use_camera_cfg=args.use_camera_cfg,
        camera_guidance_scale=args.camera_guidance_scale,
        text_guidance_scale=args.text_guidance_scale
    )


if __name__ == "__main__":
    main()
scripts/infer_origin.py ADDED
@@ -0,0 +1,1108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import argparse
import copy
import json
import os

import imageio
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from PIL import Image
from scipy.spatial.transform import Rotation as R
from torchvision.transforms import v2

from diffsynth import WanVideoReCamMasterPipeline, ModelManager
13
+
14
def compute_relative_pose_matrix(pose1, pose2):
    """Relative pose between two consecutive frames as a 3x4 [R_rel | t_rel] matrix.

    Args:
        pose1: frame i camera pose, shape (7,): [tx, ty, tz, qx, qy, qz, qw].
        pose2: frame i+1 camera pose, same layout.

    Returns:
        np.ndarray of shape (3, 4): the relative rotation in the first three
        columns and the relative translation in the last column.
    """
    t1, q1 = pose1[:3], pose1[3:]
    t2, q2 = pose2[:3], pose2[3:]

    rot1 = R.from_quat(q1)
    rot2 = R.from_quat(q2)

    # Relative rotation: next frame's rotation composed with the inverse of the previous.
    R_rel = (rot2 * rot1.inv()).as_matrix()

    # Relative translation expressed in frame i's coordinates: R1^T (t2 - t1).
    t_rel = rot1.as_matrix().T @ (t2 - t1)

    return np.hstack([R_rel, t_rel.reshape(3, 1)])
45
+
46
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Load pre-encoded video latents from a .pth file and slice out a clip.

    Returns a tuple of (condition latents [C, num_frames, H, W], the full
    checkpoint dict). Raises ValueError when the requested window runs past
    the stored sequence length.
    """
    print(f"Loading encoded video from {pth_path}")

    payload = torch.load(pth_path, weights_only=False, map_location="cpu")
    latents = payload['latents']  # [C, T, H, W]
    end_frame = start_frame + num_frames

    print(f"Full latents shape: {latents.shape}")
    print(f"Extracting frames {start_frame} to {end_frame}")

    if end_frame > latents.shape[1]:
        raise ValueError(f"Not enough frames: requested {end_frame}, available {latents.shape[1]}")

    clip = latents[:, start_frame:end_frame, :, :]
    print(f"Extracted condition latents shape: {clip.shape}")

    return clip, payload
63
+
64
+
65
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Return the pose of camera B relative to camera A, i.e. ``B @ A^-1``.

    Both inputs are 4x4 extrinsic matrices. With ``use_torch`` the math runs
    in torch and a Tensor is returned; otherwise numpy is used.
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        a = pose_a if isinstance(pose_a, torch.Tensor) else torch.from_numpy(pose_a).float()
        b = pose_b if isinstance(pose_b, torch.Tensor) else torch.from_numpy(pose_b).float()
        return torch.matmul(b, torch.inverse(a))

    a = pose_a if isinstance(pose_a, np.ndarray) else np.array(pose_a, dtype=np.float32)
    b = pose_b if isinstance(pose_b, np.ndarray) else np.array(pose_b, dtype=np.float32)
    return np.matmul(b, np.linalg.inv(a))
88
+
89
+
90
def replace_dit_model_in_manager():
    """Swap the registered 'wan_video_dit' model class for its MoE variant.

    Mutates ``model_loader_configs`` in place so subsequent model loading
    instantiates ``WanModelMoe`` instead of the stock DiT class.
    """
    from diffsynth.models.wan_video_dit_moe import WanModelMoe
    from diffsynth.configs.model_config import model_loader_configs

    for idx, entry in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, names, classes, resource = entry

        if 'wan_video_dit' not in names:
            continue

        patched_names = []
        patched_classes = []
        for model_name, model_cls in zip(names, classes):
            patched_names.append(model_name)
            if model_name == 'wan_video_dit':
                patched_classes.append(WanModelMoe)
                print(f"✅ 替换了模型类: {model_name} -> WanModelMoe")
            else:
                patched_classes.append(model_cls)

        model_loader_configs[idx] = (keys_hash, keys_hash_with_shape, patched_names, patched_classes, resource)
112
+
113
+
114
def add_framepack_components(dit_model):
    """Attach FramePack's multi-scale clean-latent embedder to the model.

    Idempotent: does nothing if ``clean_x_embedder`` is already present.
    The hidden size is inferred from the first block's query projection.
    """
    if hasattr(dit_model, 'clean_x_embedder'):
        return

    inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]

    class CleanXEmbedder(nn.Module):
        """Projects 16-channel clean latents at 1x/2x/4x compression scales."""

        def __init__(self, inner_dim):
            super().__init__()
            self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
            self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
            self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

        def forward(self, x, scale="1x"):
            conv = {"1x": self.proj, "2x": self.proj_2x, "4x": self.proj_4x}.get(scale)
            if conv is None:
                raise ValueError(f"Unsupported scale: {scale}")
            # Cast the input to the conv's dtype so mixed-precision models work.
            return conv(x.to(conv.weight.dtype))

    embedder = CleanXEmbedder(inner_dim)
    model_dtype = next(dit_model.parameters()).dtype
    dit_model.clean_x_embedder = embedder.to(dtype=model_dtype)
    print("✅ 添加了FramePack的clean_x_embedder组件")
143
+
144
+
145
def add_moe_components(dit_model, moe_config):
    """Attach MoE (mixture-of-experts) components to a DiT model in place.

    Adds per-modality input processors (sekai/nuscenes/openx), a global
    expert router, and a ``MultiModalMoE`` module on every transformer block.

    Args:
        dit_model: the DiT model; must expose ``.blocks`` where each block
            has ``self_attn.q``. Mutated in place.
        moe_config: dict with optional keys ``"unified_dim"`` (default 25),
            ``"num_experts"`` (default 4) and ``"top_k"``.
    """
    if not hasattr(dit_model, 'moe_config'):
        dit_model.moe_config = moe_config
        print("✅ 添加了MoE配置到模型")
        # NOTE(review): default top_k here is 1, while the per-block MoE
        # below defaults to 2 when "top_k" is absent — confirm which default
        # is intended.
        dit_model.top_k = moe_config.get("top_k", 1)

    # Dynamically attach MoE components to every block.
    # The hidden size is inferred from the first block's query projection.
    dim = dit_model.blocks[0].self_attn.q.weight.shape[0]
    unified_dim = moe_config.get("unified_dim", 25)
    num_experts = moe_config.get("num_experts", 4)
    from diffsynth.models.wan_video_dit_moe import ModalityProcessor, MultiModalMoE
    dit_model.sekai_processor = ModalityProcessor("sekai", 13, unified_dim)
    dit_model.nuscenes_processor = ModalityProcessor("nuscenes", 8, unified_dim)
    # OpenX uses a 13-dim input like sekai but gets its own processor.
    dit_model.openx_processor = ModalityProcessor("openx", 13, unified_dim)
    dit_model.global_router = nn.Linear(unified_dim, num_experts)


    for i, block in enumerate(dit_model.blocks):
        # MoE network: consumes unified_dim features, emits block-dim features.
        block.moe = MultiModalMoE(
            unified_dim=unified_dim,
            output_dim=dim,  # output dimension matches the transformer block's dim
            num_experts=moe_config.get("num_experts", 4),
            top_k=moe_config.get("top_k", 2)
        )

        print(f"✅ Block {i} 添加了MoE组件 (unified_dim: {unified_dim}, experts: {moe_config.get('num_experts', 4)})")
173
+
174
+
175
def _sekai_synthetic_relative_pose(yaw_per_frame, forward_speed, lateral_drift):
    """Build one synthetic 3x4 relative pose: yaw about Y, forward along local -Z.

    ``lateral_drift`` adds a small X-axis offset that models motion toward the
    center of a circular trajectory.
    """
    pose = np.eye(4, dtype=np.float32)

    # Rotation about the Y axis (positive yaw = turn left).
    cos_yaw = np.cos(yaw_per_frame)
    sin_yaw = np.sin(yaw_per_frame)
    pose[0, 0] = cos_yaw
    pose[0, 2] = sin_yaw
    pose[2, 0] = -sin_yaw
    pose[2, 2] = cos_yaw

    pose[2, 3] = -forward_speed  # local -Z is "forward"
    pose[0, 3] = lateral_drift   # slight centripetal drift

    return torch.as_tensor(pose[:3, :])


def generate_sekai_camera_embeddings_sliding(cam_data, start_frame, current_history_length, new_frames, total_generated, use_real_poses=True, direction="left"):
    """Generate camera embeddings for the Sekai dataset (sliding-window version).

    Produces a [max_needed_frames, 13] tensor: a flattened 3x4 relative pose
    per frame (12 values) plus a condition/target mask column, in bfloat16.
    Real extrinsics are used when available; otherwise a synthetic left- or
    right-turn trajectory is synthesized.

    Raises:
        ValueError: if ``direction`` is neither "left" nor "right" on the
            synthetic path (the original implementation silently returned None).
    """
    time_compression_ratio = 4

    # Number of camera frames FramePack actually consumes:
    # [start | 16x 4x-context | 2x 2x-context | 1x context | new frames].
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    # Make sure the camera sequence is long enough for every consumer.
    max_needed_frames = max(
        start_frame + current_history_length + new_frames,
        framepack_needed_frames,
        30
    )

    if use_real_poses and cam_data is not None and 'extrinsic' in cam_data:
        print("🔧 使用真实Sekai camera数据")
        cam_extrinsic = cam_data['extrinsic']

        print(f"🔧 计算Sekai camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        relative_poses = []
        for i in range(max_needed_frames):
            # Map the compressed frame index back to the raw sequence.
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio

            if next_frame_idx < len(cam_extrinsic):
                relative_pose = compute_relative_pose(cam_extrinsic[frame_idx], cam_extrinsic[next_frame_idx])
                relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
            else:
                # Out of range: fall back to zero motion.
                print(f"⚠️ 帧{frame_idx}超出camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))
        pose_source = "真实"
    elif direction == "left":
        print("-----Left-------")
        print(f"🔧 生成Sekai合成camera帧数: {max_needed_frames}")
        # Constant left-turn motion: yaw +0.05 rad, 0.05 forward, slight inward drift.
        relative_poses = [
            _sekai_synthetic_relative_pose(0.05, 0.05, -0.002)
            for _ in range(max_needed_frames)
        ]
        pose_source = "合成"
    elif direction == "right":
        print("------------Right----------")
        print(f"🔧 生成Sekai合成camera帧数: {max_needed_frames}")
        # NOTE(review): yaw is -0.00 and drift 0.000, so "right" is currently
        # straight forward motion at 0.1/frame — confirm this is intentional.
        relative_poses = [
            _sekai_synthetic_relative_pose(-0.00, 0.1, 0.000)
            for _ in range(max_needed_frames)
        ]
        pose_source = "合成"
    else:
        # BUG FIX: the original fell through and implicitly returned None here.
        raise ValueError(f"Unsupported direction: {direction!r} (expected 'left' or 'right')")

    pose_embedding = torch.stack(relative_poses, dim=0)
    # Flatten each 3x4 pose row-major; equivalent to einops 'b c d -> b (c d)'.
    pose_embedding = pose_embedding.reshape(max_needed_frames, -1)

    # Mask column: 1.0 marks condition frames, 0.0 marks targets.
    mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
    condition_end = min(start_frame + current_history_length, max_needed_frames)
    mask[start_frame:condition_end] = 1.0

    camera_embedding = torch.cat([pose_embedding, mask], dim=1)
    print(f"🔧 Sekai{pose_source}camera embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
324
+
325
+
326
def _openx_synthetic_relative_pose():
    """Synthetic per-frame relative pose mimicking fine-grained robot-arm motion.

    Small composite ZYX rotation plus a slow translation; identical for every
    frame, so callers can reuse the returned tensor.
    """
    roll_per_frame = 0.02   # slight roll
    pitch_per_frame = 0.01  # slight pitch
    yaw_per_frame = 0.015   # slight yaw
    forward_speed = 0.003   # slow forward speed

    cos_roll, sin_roll = np.cos(roll_per_frame), np.sin(roll_per_frame)
    cos_pitch, sin_pitch = np.cos(pitch_per_frame), np.sin(pitch_per_frame)
    cos_yaw, sin_yaw = np.cos(yaw_per_frame), np.sin(yaw_per_frame)

    pose = np.eye(4, dtype=np.float32)

    # Composite rotation matrix (ZYX order).
    pose[0, 0] = cos_yaw * cos_pitch
    pose[0, 1] = cos_yaw * sin_pitch * sin_roll - sin_yaw * cos_roll
    pose[0, 2] = cos_yaw * sin_pitch * cos_roll + sin_yaw * sin_roll
    pose[1, 0] = sin_yaw * cos_pitch
    pose[1, 1] = sin_yaw * sin_pitch * sin_roll + cos_yaw * cos_roll
    pose[1, 2] = sin_yaw * sin_pitch * cos_roll - cos_yaw * sin_roll
    pose[2, 0] = -sin_pitch
    pose[2, 1] = cos_pitch * sin_roll
    pose[2, 2] = cos_pitch * cos_roll

    # Translation: fine manipulation-scale movement, mostly along depth.
    pose[0, 3] = forward_speed * 0.5
    pose[1, 3] = forward_speed * 0.3
    pose[2, 3] = -forward_speed

    return torch.as_tensor(pose[:3, :])


def generate_openx_camera_embeddings_sliding(encoded_data, start_frame, current_history_length, new_frames, use_real_poses):
    """Generate camera embeddings for the OpenX dataset (sliding-window version).

    Produces a [max_needed_frames, 13] bfloat16 tensor: a flattened 3x4
    relative pose per frame plus a condition/target mask column. Uses real
    extrinsics from ``encoded_data['cam_emb']['extrinsic']`` when available,
    otherwise a synthetic robot-manipulation motion pattern.
    """
    time_compression_ratio = 4

    # Number of camera frames FramePack actually consumes.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    max_needed_frames = max(
        start_frame + current_history_length + new_frames,
        framepack_needed_frames,
        30
    )

    has_real = (use_real_poses and encoded_data is not None
                and 'cam_emb' in encoded_data and 'extrinsic' in encoded_data['cam_emb'])

    if has_real:
        print("🔧 使用OpenX真实camera数据")
        cam_extrinsic = encoded_data['cam_emb']['extrinsic']

        print(f"🔧 计算OpenX camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        relative_poses = []
        for i in range(max_needed_frames):
            # OpenX uses a 4-frame stride, like sekai but on shorter sequences.
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio

            if next_frame_idx < len(cam_extrinsic):
                relative_pose = compute_relative_pose(cam_extrinsic[frame_idx], cam_extrinsic[next_frame_idx])
                relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
            else:
                # Out of range: fall back to zero motion.
                print(f"⚠️ 帧{frame_idx}超出OpenX camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))
        pose_source = "真实"
    else:
        print("🔧 使用OpenX合成camera数据")
        print(f"🔧 生成OpenX合成camera帧数: {max_needed_frames}")
        # The synthetic pose is loop-invariant: build it once and replicate.
        synthetic_pose = _openx_synthetic_relative_pose()
        relative_poses = [synthetic_pose.clone() for _ in range(max_needed_frames)]
        pose_source = "合成"

    pose_embedding = torch.stack(relative_poses, dim=0)
    # Flatten each 3x4 pose row-major; equivalent to einops 'b c d -> b (c d)'.
    pose_embedding = pose_embedding.reshape(max_needed_frames, -1)

    # Mask column: 1.0 marks condition frames, 0.0 marks targets.
    mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
    condition_end = min(start_frame + current_history_length, max_needed_frames)
    mask[start_frame:condition_end] = 1.0

    camera_embedding = torch.cat([pose_embedding, mask], dim=1)
    print(f"🔧 OpenX{pose_source}camera embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
440
+
441
+
442
def generate_nuscenes_camera_embeddings_sliding(scene_info, start_frame, current_history_length, new_frames):
    """Generate camera embeddings for the NuScenes dataset (sliding-window version).

    Produces a [max_needed_frames, 8] bfloat16 tensor: a 7D pose per frame
    (3D translation relative to the first keyframe + 4D quaternion) plus a
    condition/target mask column. Kept consistent with train_moe.py.

    Falls back to a synthetic left-turn driving arc when ``scene_info`` has
    no usable ``keyframe_poses``.
    """
    # Number of camera frames FramePack actually consumes.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    # NOTE(review): unlike the sekai/openx variants, this length ignores
    # start_frame + current_history_length — confirm that is intentional.
    max_needed_frames = max(framepack_needed_frames, 30)

    if scene_info is not None and 'keyframe_poses' in scene_info:
        print("🔧 使用NuScenes真实pose数据")
        keyframe_poses = scene_info['keyframe_poses']

        if len(keyframe_poses) == 0:
            print("⚠️ NuScenes keyframe_poses为空,使用零pose")
            pose_sequence = torch.zeros(max_needed_frames, 7, dtype=torch.float32)
            pose_source = "零pose"
        else:
            # The first keyframe is the translation reference.
            reference_pose = keyframe_poses[0]

            pose_vecs = []
            for i in range(max_needed_frames):
                if i < len(keyframe_poses):
                    current_pose = keyframe_poses[i]

                    # Translation relative to the reference keyframe.
                    translation = torch.tensor(
                        np.array(current_pose['translation']) - np.array(reference_pose['translation']),
                        dtype=torch.float32
                    )

                    # Absolute rotation quaternion (simplified: not made relative).
                    rotation = torch.tensor(current_pose['rotation'], dtype=torch.float32)

                    pose_vec = torch.cat([translation, rotation], dim=0)  # [7D]
                else:
                    # Past the recorded keyframes: identity pose.
                    pose_vec = torch.cat([
                        torch.zeros(3, dtype=torch.float32),
                        torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32)
                    ], dim=0)  # [7D]

                pose_vecs.append(pose_vec)

            pose_sequence = torch.stack(pose_vecs, dim=0)  # [max_needed_frames, 7]
            pose_source = "真实pose"
    else:
        print("🔧 使用NuScenes合成pose数据")

        pose_vecs = []
        for i in range(max_needed_frames):
            # Left-turn driving arc: 0.04 rad of heading change per frame on a
            # 15 m radius, which suits a car-scale turn.
            angle = i * 0.04
            radius = 15.0

            # Position on the circular arc, staying in the horizontal plane.
            x = radius * np.sin(angle)
            y = 0.0
            z = radius * (1 - np.cos(angle))

            translation = torch.tensor([x, y, z], dtype=torch.float32)

            # Heading always tangent to the arc; quaternion for a Y-axis rotation.
            yaw = angle + np.pi / 2
            rotation = torch.tensor([
                np.cos(yaw / 2),  # w (real part)
                0.0,              # x
                0.0,              # y
                np.sin(yaw / 2)   # z (about the Y axis)
            ], dtype=torch.float32)

            pose_vecs.append(torch.cat([translation, rotation], dim=0))  # [7D: tx,ty,tz,qw,qx,qy,qz]

        pose_sequence = torch.stack(pose_vecs, dim=0)
        pose_source = "合成左转pose"

    # Mask column: 1.0 marks condition frames, 0.0 marks targets.
    mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
    condition_end = min(start_frame + current_history_length, max_needed_frames)
    mask[start_frame:condition_end] = 1.0

    camera_embedding = torch.cat([pose_sequence, mask], dim=1)  # [max_needed_frames, 8]
    print(f"🔧 NuScenes{pose_source} embedding shape: {camera_embedding.shape}")
    return camera_embedding.to(torch.bfloat16)
548
+
549
def prepare_framepack_sliding_window_with_camera_moe(history_latents, target_frames_to_generate, camera_embedding_full, start_frame, modality_type, max_history_frames=49):
    """Assemble FramePack sliding-window inputs (MoE variant).

    Splits a fixed index layout [start | 16x 4x-context | 2x 2x-context |
    1x context | new frames], right-aligns the most recent history latents
    into a 19-slot context buffer, and rebuilds the camera condition mask
    from the actual history length.

    Args:
        history_latents: [C, T, H, W] latents accumulated so far.
        target_frames_to_generate: number of new latent frames to denoise.
        camera_embedding_full: [N, D] camera sequence; zero-padded if short.
        start_frame: starting offset (kept for interface compatibility).
        modality_type: dataset modality tag forwarded to the caller.
        max_history_frames: sliding-window cap (kept for interface compatibility).

    Returns:
        dict with latents, index tensors, the masked camera embedding, the
        modality tag, and the current/next history lengths.
    """
    C, T, H, W = history_latents.shape

    # Fixed index layout; its total length also fixes the camera window size.
    section_sizes = [1, 16, 2, 1, target_frames_to_generate]
    total_len = sum(section_sizes)
    start_idx, idx_4x, idx_2x, idx_1x, latent_indices = torch.arange(0, total_len).split(section_sizes, dim=0)
    clean_latent_indices = torch.cat([start_idx, idx_1x], dim=0)

    # Zero-pad the camera sequence when it is shorter than the window.
    if camera_embedding_full.shape[0] < total_len:
        pad = torch.zeros(
            total_len - camera_embedding_full.shape[0],
            camera_embedding_full.shape[1],
            dtype=camera_embedding_full.dtype,
            device=camera_embedding_full.device,
        )
        camera_embedding_full = torch.cat([camera_embedding_full, pad], dim=0)

    combined_camera = camera_embedding_full[:total_len, :].clone()

    # Rebuild the condition/target mask: everything is target (0) by default,
    # then the slots backed by real history become condition (1).
    combined_camera[:, -1] = 0.0
    if T > 0:
        available_frames = min(T, 19)
        combined_camera[19 - available_frames:19, -1] = 1.0

    print(f"🔧 MoE Camera mask更新:")
    print(f" - 历史帧数: {T}")
    print(f" - 有效condition帧数: {available_frames if T > 0 else 0}")
    print(f" - 模态类型: {modality_type}")

    # Right-align the most recent history frames into a 19-slot context buffer.
    context = torch.zeros(C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device)
    if T > 0:
        tail = min(T, 19)
        context[:, 19 - tail:, :, :] = history_latents[:, -tail:, :, :]

    clean_latents_4x = context[:, 0:16, :, :]
    clean_latents_2x = context[:, 16:18, :, :]
    clean_latents_1x = context[:, 18:19, :, :]

    # Anchor frame: the very first history frame, or zeros when there is none.
    if T > 0:
        start_latent = history_latents[:, 0:1, :, :]
    else:
        start_latent = torch.zeros(C, 1, H, W, dtype=history_latents.dtype, device=history_latents.device)

    return {
        'latent_indices': latent_indices,
        'clean_latents': torch.cat([start_latent, clean_latents_1x], dim=1),
        'clean_latents_2x': clean_latents_2x,
        'clean_latents_4x': clean_latents_4x,
        'clean_latent_indices': clean_latent_indices,
        'clean_latent_2x_indices': idx_2x,
        'clean_latent_4x_indices': idx_4x,
        'camera_embedding': combined_camera,
        'modality_type': modality_type,  # modality tag for the MoE router
        'current_length': T,
        'next_length': T + target_frames_to_generate,
    }
618
+
619
+
620
+ def inference_moe_framepack_sliding_window(
621
+ condition_pth_path,
622
+ dit_path,
623
+ output_path="moe/infer_results/output_moe_framepack_sliding.mp4",
624
+ start_frame=0,
625
+ initial_condition_frames=8,
626
+ frames_per_generation=4,
627
+ total_frames_to_generate=32,
628
+ max_history_frames=49,
629
+ device="cuda",
630
+ prompt="A video of a scene shot using a pedestrian's front camera while walking",
631
+ modality_type="sekai", # "sekai" 或 "nuscenes"
632
+ use_real_poses=True,
633
+ scene_info_path=None, # 对于NuScenes数据集
634
+ # CFG参数
635
+ use_camera_cfg=True,
636
+ camera_guidance_scale=2.0,
637
+ text_guidance_scale=1.0,
638
+ # MoE参数
639
+ moe_num_experts=4,
640
+ moe_top_k=2,
641
+ moe_hidden_dim=None,
642
+ direction="left",
643
+ use_gt_prompt=True
644
+ ):
645
+ """
646
+ MoE FramePack滑动窗口视频生成 - 支持多模态
647
+ """
648
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
649
+ print(f"🔧 MoE FramePack滑动窗口生成开始...")
650
+ print(f"模态类型: {modality_type}")
651
+ print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}")
652
+ print(f"Text guidance scale: {text_guidance_scale}")
653
+ print(f"MoE配置: experts={moe_num_experts}, top_k={moe_top_k}")
654
+
655
+ # 1. 模型初始化
656
+ replace_dit_model_in_manager()
657
+
658
+ model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
659
+ model_manager.load_models([
660
+ "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
661
+ "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
662
+ "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
663
+ ])
664
+ pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")
665
+
666
+ # 2. 添加传统camera编码器(兼容性)
667
+ dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
668
+ for block in pipe.dit.blocks:
669
+ block.cam_encoder = nn.Linear(13, dim)
670
+ block.projector = nn.Linear(dim, dim)
671
+ block.cam_encoder.weight.data.zero_()
672
+ block.cam_encoder.bias.data.zero_()
673
+ block.projector.weight = nn.Parameter(torch.eye(dim))
674
+ block.projector.bias = nn.Parameter(torch.zeros(dim))
675
+
676
+ # 3. 添加FramePack组件
677
+ add_framepack_components(pipe.dit)
678
+
679
+ # 4. 添加MoE组件
680
+ moe_config = {
681
+ "num_experts": moe_num_experts,
682
+ "top_k": moe_top_k,
683
+ "hidden_dim": moe_hidden_dim or dim * 2,
684
+ "sekai_input_dim": 13, # Sekai: 12维pose + 1维mask
685
+ "nuscenes_input_dim": 8, # NuScenes: 7维pose + 1维mask
686
+ "openx_input_dim": 13 # OpenX: 12维pose + 1维mask (类似sekai)
687
+ }
688
+ add_moe_components(pipe.dit, moe_config)
689
+
690
+ # 5. 加载训练好的权重
691
+ dit_state_dict = torch.load(dit_path, map_location="cpu")
692
+ pipe.dit.load_state_dict(dit_state_dict, strict=False) # 使用strict=False以兼容新增的MoE组件
693
+ pipe = pipe.to(device)
694
+ model_dtype = next(pipe.dit.parameters()).dtype
695
+
696
+ if hasattr(pipe.dit, 'clean_x_embedder'):
697
+ pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)
698
+
699
+ pipe.scheduler.set_timesteps(50)
700
+
701
+ # 6. 加载初始条件
702
+ print("Loading initial condition frames...")
703
+ initial_latents, encoded_data = load_encoded_video_from_pth(
704
+ condition_pth_path,
705
+ start_frame=start_frame,
706
+ num_frames=initial_condition_frames
707
+ )
708
+
709
+ # 空间裁剪
710
+ target_height, target_width = 60, 104
711
+ C, T, H, W = initial_latents.shape
712
+
713
+ if H > target_height or W > target_width:
714
+ h_start = (H - target_height) // 2
715
+ w_start = (W - target_width) // 2
716
+ initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width]
717
+ H, W = target_height, target_width
718
+
719
+ history_latents = initial_latents.to(device, dtype=model_dtype)
720
+
721
+ print(f"初始history_latents shape: {history_latents.shape}")
722
+
723
+ # 7. 编码prompt - 支持CFG
724
+ if use_gt_prompt and 'prompt_emb' in encoded_data:
725
+ print("✅ 使用预编码的GT prompt embedding")
726
+ prompt_emb_pos = encoded_data['prompt_emb']
727
+ # 将prompt_emb移到正确的设备和数据类型
728
+ if 'context' in prompt_emb_pos:
729
+ prompt_emb_pos['context'] = prompt_emb_pos['context'].to(device, dtype=model_dtype)
730
+ if 'context_mask' in prompt_emb_pos:
731
+ prompt_emb_pos['context_mask'] = prompt_emb_pos['context_mask'].to(device, dtype=model_dtype)
732
+
733
+ # 如果使用Text CFG,生成负向prompt
734
+ if text_guidance_scale > 1.0:
735
+ prompt_emb_neg = pipe.encode_prompt("")
736
+ print(f"使用Text CFG with GT prompt,guidance scale: {text_guidance_scale}")
737
+ else:
738
+ prompt_emb_neg = None
739
+ print("不使用Text CFG")
740
+
741
+ # 🔧 打印GT prompt文本(如果有)
742
+ if 'prompt' in encoded_data['prompt_emb']:
743
+ gt_prompt_text = encoded_data['prompt_emb']['prompt']
744
+ print(f"📝 GT Prompt文本: {gt_prompt_text}")
745
+ else:
746
+ # 使用传入的prompt参数重新编码
747
+ print(f"🔄 重新编码prompt: {prompt}")
748
+ if text_guidance_scale > 1.0:
749
+ prompt_emb_pos = pipe.encode_prompt(prompt)
750
+ prompt_emb_neg = pipe.encode_prompt("")
751
+ print(f"使用Text CFG,guidance scale: {text_guidance_scale}")
752
+ else:
753
+ prompt_emb_pos = pipe.encode_prompt(prompt)
754
+ prompt_emb_neg = None
755
+ print("不使用Text CFG")
756
+
757
+ # 8. 加载场景信息(对于NuScenes)
758
+ scene_info = None
759
+ if modality_type == "nuscenes" and scene_info_path and os.path.exists(scene_info_path):
760
+ with open(scene_info_path, 'r') as f:
761
+ scene_info = json.load(f)
762
+ print(f"加载NuScenes场景信息: {scene_info_path}")
763
+
764
+ # 9. 预生成完整的camera embedding序列
765
+ if modality_type == "sekai":
766
+ camera_embedding_full = generate_sekai_camera_embeddings_sliding(
767
+ encoded_data.get('cam_emb', None),
768
+ 0,
769
+ max_history_frames,
770
+ 0,
771
+ 0,
772
+ use_real_poses=use_real_poses,
773
+ direction=direction
774
+ ).to(device, dtype=model_dtype)
775
+ elif modality_type == "nuscenes":
776
+ camera_embedding_full = generate_nuscenes_camera_embeddings_sliding(
777
+ scene_info,
778
+ 0,
779
+ max_history_frames,
780
+ 0
781
+ ).to(device, dtype=model_dtype)
782
+ elif modality_type == "openx":
783
+ camera_embedding_full = generate_openx_camera_embeddings_sliding(
784
+ encoded_data,
785
+ 0,
786
+ max_history_frames,
787
+ 0,
788
+ use_real_poses=use_real_poses
789
+ ).to(device, dtype=model_dtype)
790
+ else:
791
+ raise ValueError(f"不支持的模态类型: {modality_type}")
792
+
793
+ print(f"完整camera序列shape: {camera_embedding_full.shape}")
794
+
795
+ # 10. 为Camera CFG创建无条件的camera embedding
796
+ if use_camera_cfg:
797
+ camera_embedding_uncond = torch.zeros_like(camera_embedding_full)
798
+ print(f"创建无条件camera embedding用于CFG")
799
+
800
+ # 11. 滑动窗口生成循环
801
+ total_generated = 0
802
+ all_generated_frames = []
803
+
804
+ while total_generated < total_frames_to_generate:
805
+ current_generation = min(frames_per_generation, total_frames_to_generate - total_generated)
806
+ print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}")
807
+ print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}")
808
+
809
+ # FramePack数据准备 - MoE版本
810
+ framepack_data = prepare_framepack_sliding_window_with_camera_moe(
811
+ history_latents,
812
+ current_generation,
813
+ camera_embedding_full,
814
+ start_frame,
815
+ modality_type,
816
+ max_history_frames
817
+ )
818
+
819
+ # 准备输入
820
+ clean_latents = framepack_data['clean_latents'].unsqueeze(0)
821
+ clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0)
822
+ clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0)
823
+ camera_embedding = framepack_data['camera_embedding'].unsqueeze(0)
824
+
825
+ # 准备modality_inputs
826
+ modality_inputs = {modality_type: camera_embedding}
827
+
828
+ # 为CFG准备无条件camera embedding
829
+ if use_camera_cfg:
830
+ camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0)
831
+ modality_inputs_uncond = {modality_type: camera_embedding_uncond_batch}
832
+
833
+ # 索引处理
834
+ latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu()
835
+ clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu()
836
+ clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu()
837
+ clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu()
838
+
839
+ # 初始化要生成的latents
840
+ new_latents = torch.randn(
841
+ 1, C, current_generation, H, W,
842
+ device=device, dtype=model_dtype
843
+ )
844
+
845
+ extra_input = pipe.prepare_extra_input(new_latents)
846
+
847
+ print(f"Camera embedding shape: {camera_embedding.shape}")
848
+ print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}")
849
+
850
+ # 去噪循环 - 支持CFG
851
+ timesteps = pipe.scheduler.timesteps
852
+
853
+ for i, timestep in enumerate(timesteps):
854
+ if i % 10 == 0:
855
+ print(f" 去噪步骤 {i+1}/{len(timesteps)}")
856
+
857
+ timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)
858
+
859
+ with torch.no_grad():
860
+ # CFG推理
861
+ if use_camera_cfg and camera_guidance_scale > 1.0:
862
+ # 条件预测(有camera)
863
+ noise_pred_cond, moe_loess = pipe.dit(
864
+ new_latents,
865
+ timestep=timestep_tensor,
866
+ cam_emb=camera_embedding,
867
+ modality_inputs=modality_inputs, # MoE模态输入
868
+ latent_indices=latent_indices,
869
+ clean_latents=clean_latents,
870
+ clean_latent_indices=clean_latent_indices,
871
+ clean_latents_2x=clean_latents_2x,
872
+ clean_latent_2x_indices=clean_latent_2x_indices,
873
+ clean_latents_4x=clean_latents_4x,
874
+ clean_latent_4x_indices=clean_latent_4x_indices,
875
+ **prompt_emb_pos,
876
+ **extra_input
877
+ )
878
+
879
+ # 无条件预测(无camera)
880
+ noise_pred_uncond, moe_loess = pipe.dit(
881
+ new_latents,
882
+ timestep=timestep_tensor,
883
+ cam_emb=camera_embedding_uncond_batch,
884
+ modality_inputs=modality_inputs_uncond, # MoE无条件模态输入
885
+ latent_indices=latent_indices,
886
+ clean_latents=clean_latents,
887
+ clean_latent_indices=clean_latent_indices,
888
+ clean_latents_2x=clean_latents_2x,
889
+ clean_latent_2x_indices=clean_latent_2x_indices,
890
+ clean_latents_4x=clean_latents_4x,
891
+ clean_latent_4x_indices=clean_latent_4x_indices,
892
+ **(prompt_emb_neg if prompt_emb_neg else prompt_emb_pos),
893
+ **extra_input
894
+ )
895
+
896
+ # Camera CFG
897
+ noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_cond - noise_pred_uncond)
898
+
899
+ # 如果同时使用Text CFG
900
+ if text_guidance_scale > 1.0 and prompt_emb_neg:
901
+ noise_pred_text_uncond, moe_loess = pipe.dit(
902
+ new_latents,
903
+ timestep=timestep_tensor,
904
+ cam_emb=camera_embedding,
905
+ modality_inputs=modality_inputs,
906
+ latent_indices=latent_indices,
907
+ clean_latents=clean_latents,
908
+ clean_latent_indices=clean_latent_indices,
909
+ clean_latents_2x=clean_latents_2x,
910
+ clean_latent_2x_indices=clean_latent_2x_indices,
911
+ clean_latents_4x=clean_latents_4x,
912
+ clean_latent_4x_indices=clean_latent_4x_indices,
913
+ **prompt_emb_neg,
914
+ **extra_input
915
+ )
916
+
917
+ # 应用Text CFG到已经应用Camera CFG的结果
918
+ noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond)
919
+
920
+ elif text_guidance_scale > 1.0 and prompt_emb_neg:
921
+ # 只使用Text CFG
922
+ noise_pred_cond, moe_loess = pipe.dit(
923
+ new_latents,
924
+ timestep=timestep_tensor,
925
+ cam_emb=camera_embedding,
926
+ modality_inputs=modality_inputs,
927
+ latent_indices=latent_indices,
928
+ clean_latents=clean_latents,
929
+ clean_latent_indices=clean_latent_indices,
930
+ clean_latents_2x=clean_latents_2x,
931
+ clean_latent_2x_indices=clean_latent_2x_indices,
932
+ clean_latents_4x=clean_latents_4x,
933
+ clean_latent_4x_indices=clean_latent_4x_indices,
934
+ **prompt_emb_pos,
935
+ **extra_input
936
+ )
937
+
938
+ noise_pred_uncond, moe_loess= pipe.dit(
939
+ new_latents,
940
+ timestep=timestep_tensor,
941
+ cam_emb=camera_embedding,
942
+ modality_inputs=modality_inputs,
943
+ latent_indices=latent_indices,
944
+ clean_latents=clean_latents,
945
+ clean_latent_indices=clean_latent_indices,
946
+ clean_latents_2x=clean_latents_2x,
947
+ clean_latent_2x_indices=clean_latent_2x_indices,
948
+ clean_latents_4x=clean_latents_4x,
949
+ clean_latent_4x_indices=clean_latent_4x_indices,
950
+ **prompt_emb_neg,
951
+ **extra_input
952
+ )
953
+
954
+ noise_pred = noise_pred_uncond + text_guidance_scale * (noise_pred_cond - noise_pred_uncond)
955
+
956
+ else:
957
+ # 标准推理(无CFG)
958
+ noise_pred, moe_loess = pipe.dit(
959
+ new_latents,
960
+ timestep=timestep_tensor,
961
+ cam_emb=camera_embedding,
962
+ modality_inputs=modality_inputs, # MoE模态输入
963
+ latent_indices=latent_indices,
964
+ clean_latents=clean_latents,
965
+ clean_latent_indices=clean_latent_indices,
966
+ clean_latents_2x=clean_latents_2x,
967
+ clean_latent_2x_indices=clean_latent_2x_indices,
968
+ clean_latents_4x=clean_latents_4x,
969
+ clean_latent_4x_indices=clean_latent_4x_indices,
970
+ **prompt_emb_pos,
971
+ **extra_input
972
+ )
973
+
974
+ new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents)
975
+
976
+ # 更新历史
977
+ new_latents_squeezed = new_latents.squeeze(0)
978
+ history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1)
979
+
980
+ # 维护滑动窗口
981
+ if history_latents.shape[1] > max_history_frames:
982
+ first_frame = history_latents[:, 0:1, :, :]
983
+ recent_frames = history_latents[:, -(max_history_frames-1):, :, :]
984
+ history_latents = torch.cat([first_frame, recent_frames], dim=1)
985
+ print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧")
986
+
987
+ print(f"更新后history_latents shape: {history_latents.shape}")
988
+
989
+ all_generated_frames.append(new_latents_squeezed)
990
+ total_generated += current_generation
991
+
992
+ print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧")
993
+
994
+ # 12. 解码和保存
995
+ print("\n🔧 解码生成的视频...")
996
+
997
+ all_generated = torch.cat(all_generated_frames, dim=1)
998
+ final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0)
999
+
1000
+ print(f"最终视频shape: {final_video.shape}")
1001
+
1002
+ decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))
1003
+
1004
+ print(f"Saving video to {output_path}")
1005
+
1006
+ video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
1007
+ video_np = (video_np * 0.5 + 0.5).clip(0, 1)
1008
+ video_np = (video_np * 255).astype(np.uint8)
1009
+
1010
+ with imageio.get_writer(output_path, fps=20) as writer:
1011
+ for frame in video_np:
1012
+ writer.append_data(frame)
1013
+
1014
+ print(f"🔧 MoE FramePack滑动窗口生成完成! 保存到: {output_path}")
1015
+ print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧")
1016
+ print(f"使用模态: {modality_type}")
1017
+
1018
+
1019
+ def main():
1020
+ parser = argparse.ArgumentParser(description="MoE FramePack滑动窗口视频生成 - 支持多模态")
1021
+
1022
+ # 基础参数
1023
+ parser.add_argument("--condition_pth", type=str,
1024
+ #default="/share_zhuyixuan05/zhuyixuan05/sekai-game-drone/00500210001_0012150_0012450/encoded_video.pth")
1025
+ default="/share_zhuyixuan05/zhuyixuan05/nuscenes_video_generation_dynamic/scenes/scene-0001_CAM_FRONT/encoded_video-480p.pth")
1026
+ #default="/share_zhuyixuan05/zhuyixuan05/spatialvid/a9a6d37f-0a6c-548a-a494-7d902469f3f2_0000000_0000300/encoded_video.pth")
1027
+ #default="/share_zhuyixuan05/zhuyixuan05/openx-fractal-encoded/episode_000001/encoded_video.pth")
1028
+ parser.add_argument("--start_frame", type=int, default=0)
1029
+ parser.add_argument("--initial_condition_frames", type=int, default=16)
1030
+ parser.add_argument("--frames_per_generation", type=int, default=8)
1031
+ parser.add_argument("--total_frames_to_generate", type=int, default=24)
1032
+ parser.add_argument("--max_history_frames", type=int, default=100)
1033
+ parser.add_argument("--use_real_poses", default=False)
1034
+ parser.add_argument("--dit_path", type=str,
1035
+ default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/framepack_moe/step175000_origin_other_continue3.ckpt")
1036
+ parser.add_argument("--output_path", type=str,
1037
+ default='/home/zhuyixuan05/ReCamMaster/moe/infer_results/output_moe_framepack_sliding.mp4')
1038
+ parser.add_argument("--prompt", type=str,
1039
+ default="A car is driving")
1040
+ parser.add_argument("--device", type=str, default="cuda")
1041
+
1042
+ # 模态类型参数
1043
+ parser.add_argument("--modality_type", type=str, choices=["sekai", "nuscenes", "openx"], default="nuscenes",
1044
+ help="模态类型:sekai 或 nuscenes 或 openx")
1045
+ parser.add_argument("--scene_info_path", type=str, default=None,
1046
+ help="NuScenes场景信息文件路径(仅用于nuscenes模态)")
1047
+
1048
+ # CFG参数
1049
+ parser.add_argument("--use_camera_cfg", default=False,
1050
+ help="使用Camera CFG")
1051
+ parser.add_argument("--camera_guidance_scale", type=float, default=2.0,
1052
+ help="Camera guidance scale for CFG")
1053
+ parser.add_argument("--text_guidance_scale", type=float, default=1.0,
1054
+ help="Text guidance scale for CFG")
1055
+
1056
+ # MoE参数
1057
+ parser.add_argument("--moe_num_experts", type=int, default=3, help="专家数量")
1058
+ parser.add_argument("--moe_top_k", type=int, default=1, help="Top-K专家")
1059
+ parser.add_argument("--moe_hidden_dim", type=int, default=None, help="MoE隐藏层维度")
1060
+ parser.add_argument("--direction", type=str, default="left")
1061
+ parser.add_argument("--use_gt_prompt", action="store_true", default=False,
1062
+ help="使用数据集中的ground truth prompt embedding")
1063
+
1064
+ args = parser.parse_args()
1065
+
1066
+ print(f"🔧 MoE FramePack CFG生成设置:")
1067
+ print(f"模态类型: {args.modality_type}")
1068
+ print(f"Camera CFG: {args.use_camera_cfg}")
1069
+ if args.use_camera_cfg:
1070
+ print(f"Camera guidance scale: {args.camera_guidance_scale}")
1071
+ print(f"使用GT Prompt: {args.use_gt_prompt}")
1072
+ print(f"Text guidance scale: {args.text_guidance_scale}")
1073
+ print(f"MoE配置: experts={args.moe_num_experts}, top_k={args.moe_top_k}")
1074
+ print(f"DiT{args.dit_path}")
1075
+
1076
+ # 验证NuScenes参数
1077
+ if args.modality_type == "nuscenes" and not args.scene_info_path:
1078
+ print("⚠️ 使用NuScenes模态但未提供scene_info_path,将使用合成pose数据")
1079
+
1080
+ inference_moe_framepack_sliding_window(
1081
+ condition_pth_path=args.condition_pth,
1082
+ dit_path=args.dit_path,
1083
+ output_path=args.output_path,
1084
+ start_frame=args.start_frame,
1085
+ initial_condition_frames=args.initial_condition_frames,
1086
+ frames_per_generation=args.frames_per_generation,
1087
+ total_frames_to_generate=args.total_frames_to_generate,
1088
+ max_history_frames=args.max_history_frames,
1089
+ device=args.device,
1090
+ prompt=args.prompt,
1091
+ modality_type=args.modality_type,
1092
+ use_real_poses=args.use_real_poses,
1093
+ scene_info_path=args.scene_info_path,
1094
+ # CFG参数
1095
+ use_camera_cfg=args.use_camera_cfg,
1096
+ camera_guidance_scale=args.camera_guidance_scale,
1097
+ text_guidance_scale=args.text_guidance_scale,
1098
+ # MoE参数
1099
+ moe_num_experts=args.moe_num_experts,
1100
+ moe_top_k=args.moe_top_k,
1101
+ moe_hidden_dim=args.moe_hidden_dim,
1102
+ direction=args.direction,
1103
+ use_gt_prompt=args.use_gt_prompt
1104
+ )
1105
+
1106
+
1107
+ if __name__ == "__main__":
1108
+ main()
scripts/infer_recam.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import torch
3
+ import torch.nn as nn
4
+ from diffsynth import ModelManager, WanVideoReCamMasterPipeline, save_video, VideoData
5
+ import torch, os, imageio, argparse
6
+ from torchvision.transforms import v2
7
+ from einops import rearrange
8
+ import pandas as pd
9
+ import torchvision
10
+ from PIL import Image
11
+ import numpy as np
12
+ import json
13
+
14
+ class Camera(object):
15
+ def __init__(self, c2w):
16
+ c2w_mat = np.array(c2w).reshape(4, 4)
17
+ self.c2w_mat = c2w_mat
18
+ self.w2c_mat = np.linalg.inv(c2w_mat)
19
+
20
class TextVideoCameraDataset(torch.utils.data.Dataset):
    """Dataset yielding (text, video, target-camera-trajectory) triples for
    ReCamMaster inference. Videos are listed in a metadata CSV; the target
    camera trajectory is read from a fixed extrinsics JSON and converted to
    per-frame relative poses flattened to 12-dim embeddings.
    """

    def __init__(self, base_path, metadata_path, args, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False, condition_frames=40, target_frames=20):
        # metadata CSV must have "file_name" and "text" columns.
        metadata = pd.read_csv(metadata_path)
        self.path = [os.path.join(base_path, "videos", file_name) for file_name in metadata["file_name"]]
        self.text = metadata["text"].to_list()

        self.max_num_frames = max_num_frames
        self.frame_interval = frame_interval
        self.num_frames = num_frames
        self.height = height
        self.width = width
        self.is_i2v = is_i2v
        self.args = args
        self.cam_type = self.args.cam_type

        # Frame-count configuration: condition frames fed to the model and
        # target frames to generate (target_frames controls the length of
        # the camera trajectory produced in __getitem__).
        self.condition_frames = condition_frames
        self.target_frames = target_frames

        # Per-frame preprocessing: crop/resize to (height, width) and
        # normalize to [-1, 1].
        self.frame_process = v2.Compose([
            v2.CenterCrop(size=(height, width)),
            v2.Resize(size=(height, width), antialias=True),
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def crop_and_resize(self, image):
        # Scale up just enough that both dimensions cover the target size;
        # the CenterCrop in frame_process then trims the excess.
        width, height = image.size
        scale = max(self.width / width, self.height / height)
        image = torchvision.transforms.functional.resize(
            image,
            (round(height*scale), round(width*scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR
        )
        return image

    def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process):
        """Read `num_frames` frames at the given stride; returns None if the
        clip is too short. Output is a [C, T, H, W] tensor (plus the raw
        first frame when is_i2v is set)."""
        reader = imageio.get_reader(file_path)
        if reader.count_frames() < max_num_frames or reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval:
            reader.close()
            return None

        frames = []
        first_frame = None
        for frame_id in range(num_frames):
            frame = reader.get_data(start_frame_id + frame_id * interval)
            frame = Image.fromarray(frame)
            frame = self.crop_and_resize(frame)
            # Keep the unnormalized first frame for image-to-video mode.
            if first_frame is None:
                first_frame = np.array(frame)
            frame = frame_process(frame)
            frames.append(frame)
        reader.close()

        frames = torch.stack(frames, dim=0)
        frames = rearrange(frames, "T C H W -> C T H W")

        if self.is_i2v:
            return frames, first_frame
        else:
            return frames

    def is_image(self, file_path):
        # Classify by file extension only.
        file_ext_name = file_path.split(".")[-1]
        if file_ext_name.lower() in ["jpg", "jpeg", "png", "webp"]:
            return True
        return False

    def load_video(self, file_path):
        # Random temporal crop within the clip's valid start range.
        start_frame_id = torch.randint(0, self.max_num_frames - (self.num_frames - 1) * self.frame_interval, (1,))[0]
        frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, self.frame_interval, self.num_frames, self.frame_process)
        return frames

    def parse_matrix(self, matrix_str):
        # Parse a string like "[a b c d] [e f g h] ..." into an ndarray.
        rows = matrix_str.strip().split('] [')
        matrix = []
        for row in rows:
            row = row.replace('[', '').replace(']', '')
            matrix.append(list(map(float, row.split())))
        return np.array(matrix)

    def get_relative_pose(self, cam_params):
        """Re-express the given c2w poses relative to the first camera,
        which is mapped onto a canonical camera at the origin."""
        abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
        abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]

        cam_to_origin = 0
        # Canonical target frame (identity with zero y-offset here).
        target_cam_c2w = np.array([
            [1, 0, 0, 0],
            [0, 1, 0, -cam_to_origin],
            [0, 0, 1, 0],
            [0, 0, 0, 1]
        ])
        # Transform taking absolute poses into the first camera's frame.
        abs2rel = target_cam_c2w @ abs_w2cs[0]
        ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
        ret_poses = np.array(ret_poses, dtype=np.float32)
        return ret_poses

    def __getitem__(self, data_id):
        text = self.text[data_id]
        path = self.path[data_id]
        video = self.load_video(path)
        if video is None:
            raise ValueError(f"{path} is not a valid video.")
        num_frames = video.shape[1]
        # The pipeline assumes exactly 81 frames per clip.
        assert num_frames == 81
        data = {"text": text, "video": video, "path": path}

        # load camera
        tgt_camera_path = "./example_test_data/cameras/camera_extrinsics.json"
        with open(tgt_camera_path, 'r') as file:
            cam_data = json.load(file)

        # Sample target_frames camera poses evenly across the 81 frames.
        cam_idx = np.linspace(0, 80, self.target_frames, dtype=int).tolist()
        traj = [self.parse_matrix(cam_data[f"frame{idx}"][f"cam{int(self.cam_type):02d}"]) for idx in cam_idx]
        traj = np.stack(traj).transpose(0, 2, 1)
        c2ws = []
        for c2w in traj:
            # Axis permutation + sign flip to the pipeline's camera
            # convention; translations are rescaled from cm to m
            # (presumably — TODO confirm units against the JSON source).
            c2w = c2w[:, [1, 2, 0, 3]]
            c2w[:3, 1] *= -1.
            c2w[:3, 3] /= 100
            c2ws.append(c2w)
        tgt_cam_params = [Camera(cam_param) for cam_param in c2ws]
        relative_poses = []
        for i in range(len(tgt_cam_params)):
            # Pose of frame i relative to frame 0; keep the top 3x4 block.
            relative_pose = self.get_relative_pose([tgt_cam_params[0], tgt_cam_params[i]])
            relative_poses.append(torch.as_tensor(relative_pose)[:,:3,:][1])
        pose_embedding = torch.stack(relative_poses, dim=0)  # [target_frames, 3, 4]
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')  # [target_frames, 12]
        data['camera'] = pose_embedding.to(torch.bfloat16)
        return data

    def __len__(self):
        return len(self.path)
154
+
155
def parse_args():
    """Build and parse the command-line arguments for ReCamMaster inference."""
    parser = argparse.ArgumentParser(description="ReCamMaster Inference")
    parser.add_argument("--dataset_path", type=str, default="./example_test_data",
                        help="The path of the Dataset.")
    parser.add_argument("--ckpt_path", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/recam_future_checkpoint/step1000.ckpt",
                        help="Path to save the model.")
    parser.add_argument("--output_dir", type=str, default="./results",
                        help="Path to save the results.")
    parser.add_argument("--dataloader_num_workers", type=int, default=1,
                        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.")
    # NOTE: default stays an int while type=str — argparse does not convert
    # defaults, and the dataset later calls int(self.cam_type) either way.
    parser.add_argument("--cam_type", type=str, default=1)
    parser.add_argument("--cfg_scale", type=float, default=5.0)
    # Condition/target frame split used by the pipeline.
    parser.add_argument("--condition_frames", type=int, default=15,
                        help="Number of condition frames")
    parser.add_argument("--target_frames", type=int, default=15,
                        help="Number of target frames to generate")
    return parser.parse_args()
206
+
207
if __name__ == '__main__':
    # Script entry: load models, restore the ReCamMaster checkpoint, and run
    # camera-controlled re-generation over every clip in the test dataset.
    args = parse_args()

    # 1. Load Wan2.1 pre-trained models
    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # 2. Initialize additional modules introduced in ReCamMaster.
    # cam_encoder starts at zero and projector starts as identity so the
    # untrained camera path is initially a no-op.
    dim=pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        block.cam_encoder = nn.Linear(12, dim)
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    # 3. Load ReCamMaster checkpoint (strict: all keys must match).
    state_dict = torch.load(args.ckpt_path, map_location="cpu")
    pipe.dit.load_state_dict(state_dict, strict=True)
    pipe.to("cuda")
    pipe.to(dtype=torch.bfloat16)

    # One output folder per camera-trajectory type.
    output_dir = os.path.join(args.output_dir, f"cam_type{args.cam_type}")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # 4. Prepare test data (source video, target camera, target trajectory)
    dataset = TextVideoCameraDataset(
        args.dataset_path,
        os.path.join(args.dataset_path, "metadata.csv"),
        args,
        condition_frames=args.condition_frames,  # forwarded frame split
        target_frames=args.target_frames,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=False,
        batch_size=1,
        num_workers=args.dataloader_num_workers
    )

    # 5. Inference: one video per dataset entry.
    for batch_idx, batch in enumerate(dataloader):
        target_text = batch["text"]
        source_video = batch["video"]
        target_camera = batch["camera"]

        video = pipe(
            prompt=target_text,
            negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的��景,三条腿,背景人很多,倒着走",
            source_video=source_video,
            target_camera=target_camera,
            cfg_scale=args.cfg_scale,
            num_inference_steps=50,
            seed=0,
            tiled=True,
            condition_frames=args.condition_frames,
            target_frames=args.target_frames,
        )
        save_video(video, os.path.join(output_dir, f"video{batch_idx}.mp4"), fps=30, quality=5)
scripts/infer_rlbench.py ADDED
@@ -0,0 +1,447 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ import imageio
6
+ import json
7
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
8
+ import argparse
9
+ from torchvision.transforms import v2
10
+ from einops import rearrange
11
+ import torch.nn as nn
12
+
13
+
14
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Load a temporal slice of pre-encoded video latents from a .pth file.

    Args:
        pth_path: Path to the .pth produced by the encoding script; must
            contain a 'latents' entry shaped [C, T, H, W].
        start_frame: First latent frame to extract (index in compressed,
            post-VAE latent time).
        num_frames: Number of latent frames to extract.

    Returns:
        Tuple of (condition_latents, encoded_data): the [C, num_frames, H, W]
        slice and the full loaded dictionary.

    Raises:
        ValueError: If the file holds fewer latent frames than requested.
    """
    print(f"Loading encoded video from {pth_path}")

    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # ever load trusted files.
    payload = torch.load(pth_path, weights_only=False, map_location="cpu")
    latents = payload['latents']  # [C, T, H, W]
    end = start_frame + num_frames

    print(f"Full latents shape: {latents.shape}")
    print(f"Extracting frames {start_frame} to {end}")

    # Guard against slicing past the end of the temporal axis.
    if end > latents.shape[1]:
        raise ValueError(f"Not enough frames: requested {end}, available {latents.shape[1]}")

    window = latents[:, start_frame:end, :, :]
    print(f"Extracted condition latents shape: {window.shape}")
    return window, payload
45
+
46
+
47
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Return the pose of camera B relative to camera A.

    Computes ``pose_b @ inverse(pose_a)`` on two 4x4 homogeneous extrinsic
    matrices, in either the numpy or the torch backend.

    Args:
        pose_a: 4x4 extrinsic matrix of reference camera A.
        pose_b: 4x4 extrinsic matrix of camera B.
        use_torch: When True, compute with torch tensors (numpy inputs are
            promoted to float32 tensors); otherwise compute with numpy.

    Returns:
        The 4x4 relative transform, as torch.Tensor or np.ndarray depending
        on ``use_torch``.
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        # Promote numpy inputs to float32 torch tensors.
        mat_a = pose_a if isinstance(pose_a, torch.Tensor) else torch.from_numpy(pose_a).float()
        mat_b = pose_b if isinstance(pose_b, torch.Tensor) else torch.from_numpy(pose_b).float()
        return torch.matmul(mat_b, torch.inverse(mat_a))

    # numpy backend: coerce non-ndarray inputs to float32 arrays.
    mat_a = pose_a if isinstance(pose_a, np.ndarray) else np.array(pose_a, dtype=np.float32)
    mat_b = pose_b if isinstance(pose_b, np.ndarray) else np.array(pose_b, dtype=np.float32)
    return np.matmul(mat_b, np.linalg.inv(mat_a))
72
+
73
+
74
def generate_camera_poses_from_data(cam_data, start_frame, condition_frames, target_frames):
    """Build per-frame camera embeddings from recorded camera data.

    Note: unlike the synthetic path, this variant uses the recorded pose
    vectors directly (no relative-pose computation) and appends a condition
    mask column.

    Args:
        cam_data: sequence of per-frame camera pose vectors, indexed by
            original (uncompressed) frame index. Assumes each entry is a flat
            vector — TODO confirm against the encoder output.
        start_frame: starting frame index (compressed latent frames).
        condition_frames: number of condition frames (compressed).
        target_frames: number of target frames (compressed).

    Returns:
        bfloat16 tensor of shape [condition_frames + target_frames, D + 1],
        where the last column is the condition mask (1 = condition frame).

    Raises:
        ValueError: if cam_data does not cover the requested frame range.
    """
    time_compression_ratio = 4
    total_frames = condition_frames + target_frames

    cam_extrinsic = cam_data

    # Map compressed latent-frame indices back to original frame indices.
    start_frame_original = start_frame * time_compression_ratio
    end_frame_original = (start_frame + total_frames) * time_compression_ratio

    print(f"Using camera data from frame {start_frame_original} to {end_frame_original}")

    # BUGFIX: fail early with a clear message instead of raising a raw
    # IndexError partway through the loop when the data is too short.
    last_needed = start_frame_original + (total_frames - 1) * time_compression_ratio
    if last_needed >= len(cam_extrinsic):
        raise ValueError(
            f"Not enough camera frames: need index {last_needed}, available {len(cam_extrinsic)}"
        )

    # One pose vector per (compressed) output frame, sampled every
    # `time_compression_ratio` original frames. (Removed the unused
    # next_frame_idx and per-frame debug prints.)
    poses = [
        torch.as_tensor(cam_extrinsic[start_frame_original + i * time_compression_ratio])
        for i in range(total_frames)
    ]
    pose_embedding = torch.stack(poses, dim=0)

    # Condition mask: 1 for condition frames, 0 for frames to be generated.
    mask = torch.zeros(total_frames, dtype=torch.float32)
    mask[:condition_frames] = 1.0
    mask = mask.view(-1, 1)

    # Concatenate pose vector and mask per frame.
    camera_embedding = torch.cat([pose_embedding, mask], dim=1)

    print(f"Generated camera embedding shape: {camera_embedding.shape}")

    return camera_embedding.to(torch.bfloat16)
128
+
129
+
130
def generate_camera_poses(direction="forward", target_frames=10, condition_frames=20):
    """Synthesize a relative camera pose sequence for a movement direction.

    Builds absolute 4x4 poses along a simple trajectory, converts them to
    frame-to-frame relative poses (top 3 rows, flattened to 12 values), and
    appends a condition-mask column.

    Returns:
        bfloat16 tensor of shape [condition_frames + target_frames, 13].
    """
    time_compression_ratio = 4
    total_frames = condition_frames + target_frames

    # Build the absolute trajectory.
    trajectory = []
    for idx in range(total_frames):
        t = idx / max(1, total_frames - 1)  # normalized progress in [0, 1]
        mat = np.eye(4, dtype=np.float32)

        if direction == "forward":
            # Forward: translate along -z.
            mat[2, 3] = -t * 0.04
            print('forward!')
        elif direction == "backward":
            # Backward: translate along +z.
            mat[2, 3] = t * 2.0
        elif direction == "left_turn":
            # Left turn: advance, shift left, and rotate about y.
            mat[2, 3] = -t * 0.03
            mat[0, 3] = t * 0.02
            yaw = t * 1
            c, s = np.cos(yaw), np.sin(yaw)
            mat[0, 0], mat[0, 2] = c, s
            mat[2, 0], mat[2, 2] = -s, c
        elif direction == "right_turn":
            # Right turn: advance, shift right, rotate the other way.
            mat[2, 3] = -t * 0.03
            mat[0, 3] = -t * 0.02
            yaw = - t * 1
            c, s = np.cos(yaw), np.sin(yaw)
            mat[0, 0], mat[0, 2] = c, s
            mat[2, 0], mat[2, 2] = -s, c

        trajectory.append(mat)

    # Relative pose between consecutive frames: next @ inv(prev); keep rows 0-2.
    rel = [
        torch.as_tensor(np.matmul(trajectory[k + 1], np.linalg.inv(trajectory[k]))[:3, :])
        for k in range(len(trajectory) - 1)
    ]
    # Pad with the last relative pose so the count matches total_frames.
    if len(rel) < total_frames:
        rel.append(rel[-1])

    pose_embedding = torch.stack(rel[:total_frames], dim=0)

    print('pose_embedding init:', pose_embedding[0])

    print('pose_embedding:', pose_embedding[-5:])

    pose_embedding = pose_embedding.reshape(total_frames, -1)  # [frames, 12]

    # Condition mask: 1 for condition frames, 0 for generated frames.
    mask = torch.zeros(total_frames, dtype=torch.float32)
    mask[:condition_frames] = 1.0
    mask = mask.view(-1, 1)

    camera_embedding = torch.cat([pose_embedding, mask], dim=1)  # [frames, 13]

    print(f"Generated {direction} movement poses:")
    print(f" Total frames: {total_frames}")
    print(f" Camera embedding shape: {camera_embedding.shape}")

    return camera_embedding.to(torch.bfloat16)
210
+
211
+
212
def inference_sekai_video_from_pth(
    condition_pth_path,
    dit_path,
    output_path="sekai/infer_results/output_sekai.mp4",
    start_frame=0,
    condition_frames=10,  # compressed latent frame count
    target_frames=2,  # compressed latent frame count
    device="cuda",
    prompt="a robotic arm executing precise manipulation tasks on a clean, organized desk",
    direction="forward",
    use_real_poses=True
):
    """Run video inference from a pre-encoded .pth latent file.

    Loads the Wan2.1 pipeline, attaches per-block camera conditioning layers
    to the DiT, loads a trained checkpoint, then denoises `target_frames`
    latent frames conditioned on `condition_frames` latent frames and a
    camera-pose embedding, and writes the decoded video to `output_path`.
    """
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    print(f"Setting up models for {direction} movement...")

    # 1. Load base models (DiT weights, T5 text encoder, VAE) on CPU in bf16.
    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    # NOTE(review): device is hard-coded to "cuda" here; the `device` argument
    # only takes effect via the later `pipe.to(device)` call — confirm intended.
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # Add camera-conditioning components to every DiT block. The zero-init
    # encoder plus identity projector make the added path a no-op until the
    # trained checkpoint below overwrites these weights.
    dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        # NOTE(review): original comment said "13-dim embedding (12D pose +
        # 1D mask)" but the input size here is 30 — confirm the pose vector
        # size produced by generate_camera_poses_from_data for this dataset.
        block.cam_encoder = nn.Linear(30, dim)
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    # Load trained DiT weights (must include the cam_encoder/projector keys,
    # since strict=True).
    dit_state_dict = torch.load(dit_path, map_location="cpu")
    pipe.dit.load_state_dict(dit_state_dict, strict=True)
    pipe = pipe.to(device)
    pipe.scheduler.set_timesteps(50)

    print("Loading condition video from pth...")

    # Load condition latents ([C, T, H, W]) plus the full encoded payload.
    condition_latents, encoded_data = load_encoded_video_from_pth(
        condition_pth_path,
        start_frame=start_frame,
        num_frames=condition_frames
    )

    # Add batch dim: [1, C, T, H, W].
    condition_latents = condition_latents.unsqueeze(0).to(device, dtype=pipe.torch_dtype)

    print("Processing poses...")

    # Build the camera pose embedding, preferring recorded poses when present.
    if use_real_poses and 'cam_emb' in encoded_data:
        print("Using real camera poses from data")
        camera_embedding = generate_camera_poses_from_data(
            encoded_data['cam_emb'],
            start_frame=start_frame,
            condition_frames=condition_frames,
            target_frames=target_frames
        )
    else:
        print(f"Using synthetic {direction} poses")
        camera_embedding = generate_camera_poses(
            direction=direction,
            target_frames=target_frames,
            condition_frames=condition_frames
        )

    # Add batch dim: [1, frames, D+1].
    camera_embedding = camera_embedding.unsqueeze(0).to(device, dtype=torch.bfloat16)

    print(f"Camera embedding shape: {camera_embedding.shape}")

    print("Encoding prompt...")

    # Encode text prompt into conditioning embeddings.
    prompt_emb = pipe.encode_prompt(prompt)

    print("Generating video...")

    # Shapes for the noise tensor that seeds the target frames.
    batch_size = 1
    channels = condition_latents.shape[1]
    latent_height = condition_latents.shape[3]
    latent_width = condition_latents.shape[4]

    # Spatial center-crop to save memory if the latents exceed this size.
    target_height, target_width = 64, 64

    if latent_height > target_height or latent_width > target_width:
        # Center crop.
        h_start = (latent_height - target_height) // 2
        w_start = (latent_width - target_width) // 2
        condition_latents = condition_latents[:, :, :,
                                              h_start:h_start+target_height,
                                              w_start:w_start+target_width]
        latent_height = target_height
        latent_width = target_width

    # Initialize target latents with Gaussian noise.
    target_latents = torch.randn(
        batch_size, channels, target_frames, latent_height, latent_width,
        device=device, dtype=pipe.torch_dtype
    )

    print(f"Condition latents shape: {condition_latents.shape}")
    print(f"Target latents shape: {target_latents.shape}")
    print(f"Camera embedding shape: {camera_embedding.shape}")

    # Concatenate condition and target latents along the time axis.
    combined_latents = torch.cat([condition_latents, target_latents], dim=2)
    print(f"Combined latents shape: {combined_latents.shape}")

    # Model-specific extra inputs (e.g. positional info) for the DiT.
    extra_input = pipe.prepare_extra_input(combined_latents)

    # Denoising loop: predict noise over the full sequence but only step the
    # scheduler on the target portion; condition frames stay fixed.
    timesteps = pipe.scheduler.timesteps

    for i, timestep in enumerate(timesteps):
        print(f"Denoising step {i+1}/{len(timesteps)}")

        # Prepare timestep tensor on the right device/dtype.
        timestep_tensor = timestep.unsqueeze(0).to(device, dtype=pipe.torch_dtype)

        # Predict noise for the whole (condition + target) sequence.
        with torch.no_grad():
            noise_pred = pipe.dit(
                combined_latents,
                timestep=timestep_tensor,
                cam_emb=camera_embedding,
                **prompt_emb,
                **extra_input
            )

        # Step the scheduler only on the target frames.
        target_noise_pred = noise_pred[:, :, condition_frames:, :, :]
        target_latents = pipe.scheduler.step(target_noise_pred, timestep, target_latents)

        # Write the updated target frames back into the combined tensor.
        combined_latents[:, :, condition_frames:, :, :] = target_latents

    print("Decoding video...")

    # Decode the full latent sequence back to pixels with tiled VAE decoding.
    final_video = torch.cat([condition_latents, target_latents], dim=2)
    decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))

    # Save video.
    print(f"Saving video to {output_path}")

    # Convert [C, T, H, W] -> [T, H, W, C] uint8 frames.
    video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
    video_np = (video_np * 0.5 + 0.5).clip(0, 1)  # denormalize from [-1, 1]
    video_np = (video_np * 255).astype(np.uint8)

    with imageio.get_writer(output_path, fps=20) as writer:
        for frame in video_np:
            writer.append_data(frame)

    print(f"Video generation completed! Saved to {output_path}")
380
+
381
+
382
def main():
    """CLI entry point: parse arguments and run inference from a .pth file."""
    parser = argparse.ArgumentParser(description="Sekai Video Generation Inference from PTH")
    parser.add_argument("--condition_pth", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/rlbench/OpenBox_demo_49/encoded_video.pth")
    parser.add_argument("--start_frame", type=int, default=0,
                        help="Starting frame index (compressed latent frames)")
    parser.add_argument("--condition_frames", type=int, default=8,
                        help="Number of condition frames (compressed latent frames)")
    parser.add_argument("--target_frames", type=int, default=8,
                        help="Number of target frames to generate (compressed latent frames)")
    parser.add_argument("--direction", type=str, default="left_turn",
                        choices=["forward", "backward", "left_turn", "right_turn"],
                        help="Direction of camera movement (if not using real poses)")
    # BUGFIX: was `default=False` with no type/action, so any supplied value
    # (including the string "False") parsed as truthy. A store_true flag keeps
    # the same default while making `--use_real_poses` behave as a boolean.
    parser.add_argument("--use_real_poses", action="store_true",
                        help="Use real camera poses from data")
    parser.add_argument("--dit_path", type=str, default="/home/zhuyixuan05/ReCamMaster/RLBench-train/step2000_dynamic.ckpt",
                        help="Path to trained DiT checkpoint")
    parser.add_argument("--output_path", type=str, default='/home/zhuyixuan05/ReCamMaster/rlbench/infer_results/output_rl_2.mp4',
                        help="Output video path")
    parser.add_argument("--prompt", type=str,
                        default="a robotic arm executing precise manipulation tasks on a clean, organized desk",
                        help="Text prompt for generation")
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device to run inference on")

    args = parser.parse_args()

    # Derive the output path. NOTE: this branch only triggers when the user
    # explicitly passes `--output_path ""` equivalent (None); the default
    # above is non-None, so auto-naming is normally skipped.
    if args.output_path is None:
        pth_filename = os.path.basename(args.condition_pth)
        name_parts = os.path.splitext(pth_filename)
        output_dir = "rlbench/infer_results"
        os.makedirs(output_dir, exist_ok=True)

        if args.use_real_poses:
            output_filename = f"{name_parts[0]}_real_poses_{args.start_frame}_{args.condition_frames}_{args.target_frames}.mp4"
        else:
            output_filename = f"{name_parts[0]}_{args.direction}_{args.start_frame}_{args.condition_frames}_{args.target_frames}.mp4"

        output_path = os.path.join(output_dir, output_filename)
    else:
        output_path = args.output_path

    print(f"Input pth: {args.condition_pth}")
    print(f"Start frame: {args.start_frame} (compressed)")
    print(f"Condition frames: {args.condition_frames} (compressed, original: {args.condition_frames * 4})")
    print(f"Target frames: {args.target_frames} (compressed, original: {args.target_frames * 4})")
    print(f"Use real poses: {args.use_real_poses}")
    print(f"Output video will be saved to: {output_path}")

    inference_sekai_video_from_pth(
        condition_pth_path=args.condition_pth,
        dit_path=args.dit_path,
        output_path=output_path,
        start_frame=args.start_frame,
        condition_frames=args.condition_frames,
        target_frames=args.target_frames,
        device=args.device,
        prompt=args.prompt,
        direction=args.direction,
        use_real_poses=args.use_real_poses
    )


if __name__ == "__main__":
    main()
scripts/infer_sekai.py ADDED
@@ -0,0 +1,497 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ import imageio
6
+ import json
7
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
8
+ import argparse
9
+ from torchvision.transforms import v2
10
+ from einops import rearrange
11
+ import torch.nn as nn
12
+
13
+
14
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Load a slice of pre-encoded video latents from a .pth file.

    Args:
        pth_path: path to the .pth file produced by the encoding script.
        start_frame: first frame to extract (in compressed latent frames).
        num_frames: number of latent frames to extract.

    Returns:
        Tuple of (condition_latents, encoded_data): a [C, T, H, W] latent
        slice and the full dict loaded from disk.

    Raises:
        ValueError: if the requested range exceeds the stored frame count.
    """
    print(f"Loading encoded video from {pth_path}")

    # Load the whole encoded payload onto CPU.
    payload = torch.load(pth_path, weights_only=False, map_location="cpu")
    latents = payload['latents']  # [C, T, H, W]

    print(f"Full latents shape: {latents.shape}")
    print(f"Extracting frames {start_frame} to {start_frame + num_frames}")

    # Validate the requested range before slicing.
    if start_frame + num_frames > latents.shape[1]:
        raise ValueError(f"Not enough frames: requested {start_frame + num_frames}, available {latents.shape[1]}")

    clip = latents[:, start_frame:start_frame + num_frames, :, :]
    print(f"Extracted condition latents shape: {clip.shape}")

    return clip, payload
45
+
46
+
47
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Return the pose of camera B relative to camera A (pose_b @ inv(pose_a)).

    Both inputs must be 4x4 extrinsic matrices. Computation uses numpy by
    default; pass use_torch=True for a torch result.
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        # Promote numpy inputs to float tensors before inverting.
        a = pose_a if isinstance(pose_a, torch.Tensor) else torch.from_numpy(pose_a).float()
        b = pose_b if isinstance(pose_b, torch.Tensor) else torch.from_numpy(pose_b).float()
        return torch.matmul(b, torch.inverse(a))

    a = pose_a if isinstance(pose_a, np.ndarray) else np.array(pose_a, dtype=np.float32)
    b = pose_b if isinstance(pose_b, np.ndarray) else np.array(pose_b, dtype=np.float32)
    return np.matmul(b, np.linalg.inv(a))
72
+
73
+
74
def generate_camera_poses_from_data(cam_data, start_frame, condition_frames, target_frames):
    """Build relative-pose camera embeddings from recorded camera extrinsics.

    Args:
        cam_data: dict containing 'extrinsic', an [N, 4, 4] sequence of
            camera extrinsic matrices indexed by original frame index.
        start_frame: starting frame index (compressed latent frames).
        condition_frames: number of condition frames (compressed).
        target_frames: number of target frames (compressed).

    Returns:
        bfloat16 tensor of shape [condition_frames + target_frames, 13]:
        12 flattened relative-pose values plus a condition-mask column
        (1 = condition frame).
    """
    time_compression_ratio = 4
    total_frames = condition_frames + target_frames

    # Camera extrinsic sequence.
    cam_extrinsic = cam_data['extrinsic']  # [N, 4, 4]

    # Map compressed latent-frame indices back to original frame indices.
    start_frame_original = start_frame * time_compression_ratio
    end_frame_original = (start_frame + total_frames) * time_compression_ratio

    print(f"Using camera data from frame {start_frame_original} to {end_frame_original}")

    # Frame-to-frame relative poses, sampled every compression step.
    relative_poses = []
    for i in range(total_frames):
        frame_idx = start_frame_original + i * time_compression_ratio
        next_frame_idx = frame_idx + time_compression_ratio

        if next_frame_idx >= len(cam_extrinsic):
            print('out of temporal range!!!')
            # Out of range: repeat the last available relative pose
            # (or zeros when there is none yet).
            relative_poses.append(relative_poses[-1] if relative_poses else torch.zeros(3, 4))
        else:
            cam_prev = cam_extrinsic[frame_idx]
            cam_next = cam_extrinsic[next_frame_idx]

            relative_pose = compute_relative_pose(cam_prev, cam_next)
            relative_poses.append(torch.as_tensor(relative_pose[:3, :]))  # top 3 rows

    pose_embedding = torch.stack(relative_poses, dim=0)
    # BUGFIX: removed leftover debug code — `print(cam_prev)` (which could
    # reference an unbound name when the first iteration hit the out-of-range
    # branch) and an unconditional `assert False` that made this function
    # crash before ever returning.

    # Flatten [frames, 3, 4] -> [frames, 12] (replaces the einops rearrange
    # with an equivalent torch reshape).
    pose_embedding = pose_embedding.reshape(total_frames, -1)

    # Condition mask: 1 for condition frames, 0 for generated frames.
    mask = torch.zeros(total_frames, dtype=torch.float32)
    mask[:condition_frames] = 1.0
    mask = mask.view(-1, 1)

    # Combine pose and mask: [frames, 13].
    camera_embedding = torch.cat([pose_embedding, mask], dim=1)

    print(f"Generated camera embedding shape: {camera_embedding.shape}")

    return camera_embedding.to(torch.bfloat16)
132
+
133
+
134
def generate_camera_poses(direction="forward", target_frames=10, condition_frames=20):
    """Synthesize a relative camera pose sequence for a movement direction.

    Builds absolute 4x4 poses along a simple trajectory, converts them to
    frame-to-frame relative poses (top 3 rows, flattened to 12 values), and
    appends a condition-mask column.

    Returns:
        bfloat16 tensor of shape [condition_frames + target_frames, 13].
    """
    time_compression_ratio = 4
    total_frames = condition_frames + target_frames

    # Build the absolute trajectory.
    trajectory = []
    for idx in range(total_frames):
        t = idx / max(1, total_frames - 1)  # normalized progress in [0, 1]
        mat = np.eye(4, dtype=np.float32)

        if direction == "forward":
            # Forward: translate along -z.
            mat[2, 3] = -t * 0.04
            print('forward!')
        elif direction == "backward":
            # Backward: translate along +z.
            mat[2, 3] = t * 2.0
        elif direction == "left_turn":
            # Left turn: advance, shift left, and rotate about y.
            mat[2, 3] = -t * 0.03
            mat[0, 3] = t * 0.02
            yaw = t * 1
            c, s = np.cos(yaw), np.sin(yaw)
            mat[0, 0], mat[0, 2] = c, s
            mat[2, 0], mat[2, 2] = -s, c
        elif direction == "right_turn":
            # Right turn: advance, shift right, rotate the other way.
            mat[2, 3] = -t * 0.03
            mat[0, 3] = -t * 0.02
            yaw = - t * 1
            c, s = np.cos(yaw), np.sin(yaw)
            mat[0, 0], mat[0, 2] = c, s
            mat[2, 0], mat[2, 2] = -s, c

        trajectory.append(mat)

    # Relative pose between consecutive frames: next @ inv(prev); keep rows 0-2.
    rel = [
        torch.as_tensor(np.matmul(trajectory[k + 1], np.linalg.inv(trajectory[k]))[:3, :])
        for k in range(len(trajectory) - 1)
    ]
    # Pad with the last relative pose so the count matches total_frames.
    if len(rel) < total_frames:
        rel.append(rel[-1])

    pose_embedding = torch.stack(rel[:total_frames], dim=0)

    print('pose_embedding init:', pose_embedding[0])

    print('pose_embedding:', pose_embedding[-5:])

    pose_embedding = pose_embedding.reshape(total_frames, -1)  # [frames, 12]

    # Condition mask: 1 for condition frames, 0 for generated frames.
    mask = torch.zeros(total_frames, dtype=torch.float32)
    mask[:condition_frames] = 1.0
    mask = mask.view(-1, 1)

    camera_embedding = torch.cat([pose_embedding, mask], dim=1)  # [frames, 13]

    print(f"Generated {direction} movement poses:")
    print(f" Total frames: {total_frames}")
    print(f" Camera embedding shape: {camera_embedding.shape}")

    return camera_embedding.to(torch.bfloat16)
214
+
215
+
216
def inference_sekai_video_from_pth(
    condition_pth_path,
    dit_path,
    output_path="sekai/infer_results/output_sekai.mp4",
    start_frame=0,
    condition_frames=10,  # compressed latent frame count
    target_frames=2,  # compressed latent frame count
    device="cuda",
    prompt="A video of a scene shot using a pedestrian's front camera while walking",
    direction="forward",
    use_real_poses=True
):
    """Run Sekai video inference from a pre-encoded .pth latent file.

    Loads the Wan2.1 pipeline, attaches per-block camera conditioning layers
    to the DiT, loads a trained checkpoint, then denoises `target_frames`
    latent frames conditioned on `condition_frames` latent frames and a
    camera-pose embedding, and writes the decoded video to `output_path`.
    """
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    print(f"Setting up models for {direction} movement...")

    # 1. Load base models (DiT weights, T5 text encoder, VAE) on CPU in bf16.
    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    # NOTE(review): device is hard-coded to "cuda" here; the `device` argument
    # only takes effect via the later `pipe.to(device)` call — confirm intended.
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # Add camera-conditioning components to every DiT block. The zero-init
    # encoder plus identity projector make the added path a no-op until the
    # trained checkpoint below overwrites these weights.
    dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        block.cam_encoder = nn.Linear(13, dim)  # 13-dim embedding (12D pose + 1D mask)
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    # Load trained DiT weights (must include the cam_encoder/projector keys,
    # since strict=True).
    dit_state_dict = torch.load(dit_path, map_location="cpu")
    pipe.dit.load_state_dict(dit_state_dict, strict=True)
    pipe = pipe.to(device)
    pipe.scheduler.set_timesteps(50)

    print("Loading condition video from pth...")

    # Load condition latents ([C, T, H, W]) plus the full encoded payload.
    condition_latents, encoded_data = load_encoded_video_from_pth(
        condition_pth_path,
        start_frame=start_frame,
        num_frames=condition_frames
    )

    # Add batch dim: [1, C, T, H, W].
    condition_latents = condition_latents.unsqueeze(0).to(device, dtype=pipe.torch_dtype)

    print("Processing poses...")

    # Build the camera pose embedding, preferring recorded poses when present.
    if use_real_poses and 'cam_emb' in encoded_data:
        print("Using real camera poses from data")
        camera_embedding = generate_camera_poses_from_data(
            encoded_data['cam_emb'],
            start_frame=start_frame,
            condition_frames=condition_frames,
            target_frames=target_frames
        )
    else:
        print(f"Using synthetic {direction} poses")
        camera_embedding = generate_camera_poses(
            direction=direction,
            target_frames=target_frames,
            condition_frames=condition_frames
        )

    # Add batch dim: [1, frames, 13].
    camera_embedding = camera_embedding.unsqueeze(0).to(device, dtype=torch.bfloat16)

    print(f"Camera embedding shape: {camera_embedding.shape}")

    print("Encoding prompt...")

    # Encode text prompt into conditioning embeddings.
    prompt_emb = pipe.encode_prompt(prompt)

    print("Generating video...")

    # Shapes for the noise tensor that seeds the target frames.
    batch_size = 1
    channels = condition_latents.shape[1]
    latent_height = condition_latents.shape[3]
    latent_width = condition_latents.shape[4]

    # Spatial center-crop to save memory if the latents exceed this size.
    target_height, target_width = 60, 104

    if latent_height > target_height or latent_width > target_width:
        # Center crop.
        h_start = (latent_height - target_height) // 2
        w_start = (latent_width - target_width) // 2
        condition_latents = condition_latents[:, :, :,
                                              h_start:h_start+target_height,
                                              w_start:w_start+target_width]
        latent_height = target_height
        latent_width = target_width

    # Initialize target latents with Gaussian noise.
    target_latents = torch.randn(
        batch_size, channels, target_frames, latent_height, latent_width,
        device=device, dtype=pipe.torch_dtype
    )

    print(f"Condition latents shape: {condition_latents.shape}")
    print(f"Target latents shape: {target_latents.shape}")
    print(f"Camera embedding shape: {camera_embedding.shape}")

    # Concatenate condition and target latents along the time axis.
    combined_latents = torch.cat([condition_latents, target_latents], dim=2)
    print(f"Combined latents shape: {combined_latents.shape}")

    # Model-specific extra inputs (e.g. positional info) for the DiT.
    extra_input = pipe.prepare_extra_input(combined_latents)

    # Denoising loop: predict noise over the full sequence but only step the
    # scheduler on the target portion; condition frames stay fixed.
    timesteps = pipe.scheduler.timesteps

    for i, timestep in enumerate(timesteps):
        print(f"Denoising step {i+1}/{len(timesteps)}")

        # Prepare timestep tensor on the right device/dtype.
        timestep_tensor = timestep.unsqueeze(0).to(device, dtype=pipe.torch_dtype)

        # Predict noise for the whole (condition + target) sequence.
        with torch.no_grad():
            noise_pred = pipe.dit(
                combined_latents,
                timestep=timestep_tensor,
                cam_emb=camera_embedding,
                **prompt_emb,
                **extra_input
            )

        # Step the scheduler only on the target frames.
        target_noise_pred = noise_pred[:, :, condition_frames:, :, :]
        target_latents = pipe.scheduler.step(target_noise_pred, timestep, target_latents)

        # Write the updated target frames back into the combined tensor.
        combined_latents[:, :, condition_frames:, :, :] = target_latents

    print("Decoding video...")

    # Decode the full latent sequence back to pixels with tiled VAE decoding.
    final_video = torch.cat([condition_latents, target_latents], dim=2)
    decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))

    # Save video.
    print(f"Saving video to {output_path}")

    # Convert [C, T, H, W] -> [T, H, W, C] uint8 frames.
    video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
    video_np = (video_np * 0.5 + 0.5).clip(0, 1)  # denormalize from [-1, 1]
    video_np = (video_np * 255).astype(np.uint8)

    with imageio.get_writer(output_path, fps=20) as writer:
        for frame in video_np:
            writer.append_data(frame)

    print(f"Video generation completed! Saved to {output_path}")
430
+
431
+
432
def main():
    """CLI entry point: parse arguments and run inference from a .pth file."""
    parser = argparse.ArgumentParser(description="Sekai Video Generation Inference from PTH")
    parser.add_argument("--condition_pth", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking/00100100001_0004650_0004950/encoded_video.pth")
    parser.add_argument("--start_frame", type=int, default=0,
                        help="Starting frame index (compressed latent frames)")
    parser.add_argument("--condition_frames", type=int, default=8,
                        help="Number of condition frames (compressed latent frames)")
    parser.add_argument("--target_frames", type=int, default=8,
                        help="Number of target frames to generate (compressed latent frames)")
    parser.add_argument("--direction", type=str, default="left_turn",
                        choices=["forward", "backward", "left_turn", "right_turn"],
                        help="Direction of camera movement (if not using real poses)")
    # BUGFIX: was `default=False` with no type/action, so any supplied value
    # (including the string "False") parsed as truthy. A store_true flag keeps
    # the same default while making `--use_real_poses` behave as a boolean.
    parser.add_argument("--use_real_poses", action="store_true",
                        help="Use real camera poses from data")
    parser.add_argument("--dit_path", type=str, default="/home/zhuyixuan05/ReCamMaster/sekai_walking_noise/step14000_dynamic.ckpt",
                        help="Path to trained DiT checkpoint")
    parser.add_argument("--output_path", type=str, default='/home/zhuyixuan05/ReCamMaster/sekai/infer_noise_results/output_sekai_right_turn.mp4',
                        help="Output video path")
    parser.add_argument("--prompt", type=str,
                        default="A drone flying scene in a game world",
                        help="Text prompt for generation")
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device to run inference on")

    args = parser.parse_args()

    # Derive the output path. NOTE: the --output_path default above is
    # non-None, so this auto-naming branch is normally skipped.
    if args.output_path is None:
        pth_filename = os.path.basename(args.condition_pth)
        name_parts = os.path.splitext(pth_filename)
        output_dir = "sekai/infer_results"
        os.makedirs(output_dir, exist_ok=True)

        if args.use_real_poses:
            output_filename = f"{name_parts[0]}_real_poses_{args.start_frame}_{args.condition_frames}_{args.target_frames}.mp4"
        else:
            output_filename = f"{name_parts[0]}_{args.direction}_{args.start_frame}_{args.condition_frames}_{args.target_frames}.mp4"

        output_path = os.path.join(output_dir, output_filename)
    else:
        output_path = args.output_path

    print(f"Input pth: {args.condition_pth}")
    print(f"Start frame: {args.start_frame} (compressed)")
    print(f"Condition frames: {args.condition_frames} (compressed, original: {args.condition_frames * 4})")
    print(f"Target frames: {args.target_frames} (compressed, original: {args.target_frames * 4})")
    print(f"Use real poses: {args.use_real_poses}")
    print(f"Output video will be saved to: {output_path}")

    inference_sekai_video_from_pth(
        condition_pth_path=args.condition_pth,
        dit_path=args.dit_path,
        output_path=output_path,
        start_frame=args.start_frame,
        condition_frames=args.condition_frames,
        target_frames=args.target_frames,
        device=args.device,
        prompt=args.prompt,
        direction=args.direction,
        use_real_poses=args.use_real_poses
    )


if __name__ == "__main__":
    main()
scripts/infer_sekai_framepack.py ADDED
@@ -0,0 +1,675 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ from PIL import Image
6
+ import imageio
7
+ import json
8
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
9
+ import argparse
10
+ from torchvision.transforms import v2
11
+ from einops import rearrange
12
+ import copy
13
+
14
+
15
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Load pre-encoded video latents from a .pth file and slice out a window.

    Args:
        pth_path: path to a torch-saved dict containing at least a 'latents'
            entry shaped [C, T, H, W].
        start_frame: first (compressed) frame of the requested window.
        num_frames: number of (compressed) frames to extract.

    Returns:
        (condition_latents, encoded_data): the [C, num_frames, H, W] slice and
        the complete loaded dict.

    Raises:
        ValueError: when the requested window runs past the available frames.
    """
    print(f"Loading encoded video from {pth_path}")

    # NOTE(review): weights_only=False unpickles arbitrary objects — only use
    # on trusted checkpoint files.
    encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu")
    full_latents = encoded_data['latents']  # [C, T, H, W]

    print(f"Full latents shape: {full_latents.shape}")
    print(f"Extracting frames {start_frame} to {start_frame + num_frames}")

    end_frame = start_frame + num_frames
    if end_frame > full_latents.shape[1]:
        raise ValueError(f"Not enough frames: requested {start_frame + num_frames}, available {full_latents.shape[1]}")

    condition_latents = full_latents[:, start_frame:end_frame, :, :]
    print(f"Extracted condition latents shape: {condition_latents.shape}")

    return condition_latents, encoded_data
32
+
33
+
34
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Compute camera B's pose relative to camera A.

    Both poses are 4x4 homogeneous extrinsic matrices; the result is
    ``pose_b @ inv(pose_a)``.

    Args:
        pose_a: 4x4 extrinsic of camera A (ndarray, Tensor, or nested list).
        pose_b: 4x4 extrinsic of camera B.
        use_torch: compute with torch ops (returns a Tensor) instead of numpy.

    Returns:
        The 4x4 relative pose — torch.Tensor if use_torch, else np.ndarray.

    Raises:
        ValueError: if either input is not 4x4 after conversion.
    """
    # Convert BEFORE validating: the original asserted `.shape` first, which
    # crashed with AttributeError on plain-list inputs, and the asserts were
    # silently stripped under `python -O`.
    if use_torch:
        if not isinstance(pose_a, torch.Tensor):
            pose_a = torch.from_numpy(np.asarray(pose_a)).float()
        if not isinstance(pose_b, torch.Tensor):
            pose_b = torch.from_numpy(np.asarray(pose_b)).float()
    else:
        if not isinstance(pose_a, np.ndarray):
            pose_a = np.array(pose_a, dtype=np.float32)
        if not isinstance(pose_b, np.ndarray):
            pose_b = np.array(pose_b, dtype=np.float32)

    # Explicit validation survives -O, unlike assert.
    if tuple(pose_a.shape) != (4, 4):
        raise ValueError(f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}")
    if tuple(pose_b.shape) != (4, 4):
        raise ValueError(f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}")

    if use_torch:
        relative_pose = torch.matmul(pose_b, torch.inverse(pose_a))
    else:
        relative_pose = np.matmul(pose_b, np.linalg.inv(pose_a))

    return relative_pose
57
+
58
def replace_dit_model_in_manager():
    """Swap the registered 'wan_video_dit' loader class for the FramePack
    variant (WanModelFuture) by rewriting diffsynth's global
    ``model_loader_configs`` registry in place.

    NOTE(review): must run before ``ModelManager.load_models`` for the
    substitution to take effect — presumably the registry is consulted at
    load time; confirm against diffsynth internals.
    """
    # Lazy imports: both modules are project-local.
    from diffsynth.models.wan_video_dit_recam_future import WanModelFuture
    from diffsynth.configs.model_config import model_loader_configs

    # Each registry entry is a 5-tuple; rebuild the name/class lists with the
    # FramePack class substituted wherever 'wan_video_dit' appears.
    for i, config in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource = config

        if 'wan_video_dit' in model_names:
            new_model_names = []
            new_model_classes = []

            for name, cls in zip(model_names, model_classes):
                if name == 'wan_video_dit':
                    new_model_names.append(name)
                    new_model_classes.append(WanModelFuture)
                    print(f"✅ 替换了模型类: {name} -> WanModelFuture")
                else:
                    new_model_names.append(name)
                    new_model_classes.append(cls)

            # Mutates the shared module-level list entry in place.
            model_loader_configs[i] = (keys_hash, keys_hash_with_shape, new_model_names, new_model_classes, model_resource)
80
+
81
+
82
def add_framepack_components(dit_model):
    """Attach FramePack's multi-scale clean-latent embedder to a DiT model.

    Adds a ``clean_x_embedder`` submodule (only if one is not already
    present) holding three non-overlapping Conv3d patchifiers — 1x, 2x and
    4x temporal/spatial compression — from 16 latent channels into the
    transformer's inner dimension, then casts it to the model's parameter
    dtype.
    """
    if hasattr(dit_model, 'clean_x_embedder'):
        return

    # Inner dim inferred from the first block's attention Q projection.
    inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]

    class CleanXEmbedder(nn.Module):
        """Project 16-channel clean latents to inner_dim at three scales."""

        def __init__(self, inner_dim):
            super().__init__()
            # kernel == stride: pure patchification, no overlap.
            self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
            self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
            self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

        def forward(self, x, scale="1x"):
            # Dispatch on the requested scale; input is cast to the chosen
            # projection's dtype before the convolution.
            layer = {"1x": self.proj, "2x": self.proj_2x, "4x": self.proj_4x}.get(scale)
            if layer is None:
                raise ValueError(f"Unsupported scale: {scale}")
            return layer(x.to(layer.weight.dtype))

    model_dtype = next(dit_model.parameters()).dtype
    dit_model.clean_x_embedder = CleanXEmbedder(inner_dim).to(dtype=model_dtype)
    print("✅ 添加了FramePack的clean_x_embedder组件")
111
+
112
def generate_camera_embeddings_sliding(cam_data, start_frame, current_history_length, new_frames, total_generated, use_real_poses=True):
    """Build the camera-embedding sequence for sliding-window FramePack generation.

    Returns a ``[max_needed_frames, 13]`` bfloat16 tensor: each row is a
    flattened 3x4 relative camera pose (12 values) plus a condition-mask bit
    (1.0 = condition frame, 0.0 = frame to be generated).

    Args:
        cam_data: dict with an 'extrinsic' sequence of 4x4 poses indexed by
            original-video frame, or None to force the synthetic branch.
        start_frame: first compressed-latent frame of the condition window.
        current_history_length: number of compressed frames currently held.
        new_frames: number of compressed frames to generate this step.
        total_generated: unused in this function; kept for caller compatibility.
        use_real_poses: when True and cam_data carries extrinsics, derive
            relative poses from the data; otherwise synthesize a constant
            turning motion.
    """
    # One latent frame spans 4 original video frames.
    time_compression_ratio = 4

    # Frames consumed by the FramePack layout:
    # 1(start) + 16(4x) + 2(2x) + 1(1x) + target frames.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    if use_real_poses and cam_data is not None and 'extrinsic' in cam_data:
        print("🔧 使用真实camera数据")
        cam_extrinsic = cam_data['extrinsic']

        # The sequence must cover the history window, the FramePack layout,
        # and never be shorter than 30 frames.
        max_needed_frames = max(
            start_frame + current_history_length + new_frames,  # base requirement
            framepack_needed_frames,  # FramePack layout requirement
            30  # guaranteed minimum length
        )

        print(f"🔧 计算camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        relative_poses = []
        for i in range(max_needed_frames):
            # Map compressed index i back to original-video frame indices.
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio

            if next_frame_idx < len(cam_extrinsic):
                cam_prev = cam_extrinsic[frame_idx]
                cam_next = cam_extrinsic[next_frame_idx]
                relative_pose = compute_relative_pose(cam_prev, cam_next)
                # Keep only the 3x4 part of the homogeneous matrix.
                relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
            else:
                # Past the end of the recorded trajectory: fall back to zero motion.
                print(f"⚠️ 帧{frame_idx}超出camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))

        pose_embedding = torch.stack(relative_poses, dim=0)
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')  # [N, 12]

        # Condition mask: 1.0 on [start_frame, start_frame + history), 0.0 elsewhere.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)  # [N, 13]
        print(f"🔧 真实camera embedding shape: {camera_embedding.shape} (总长度:{max_needed_frames})")
        return camera_embedding.to(torch.bfloat16)

    else:
        print("🔧 使用合成camera数据")

        # Same length rule as the real-pose branch.
        max_needed_frames = max(
            start_frame + current_history_length + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 生成合成camera帧数: {max_needed_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")

        relative_poses = []
        for i in range(max_needed_frames):
            # Constant turn-while-advancing motion: every frame applies the
            # same per-frame yaw and forward translation.
            yaw_per_frame = -0.05  # per-frame yaw in radians; original comment says this is a left turn — sign convention unverified
            forward_speed = 0.005  # per-frame forward translation

            # Accumulated yaw; computed but never used below.
            current_yaw = i * yaw_per_frame

            # Relative transform from frame i to frame i+1.
            pose = np.eye(4, dtype=np.float32)

            # Rotation about the Y axis.
            cos_yaw = np.cos(yaw_per_frame)
            sin_yaw = np.sin(yaw_per_frame)

            pose[0, 0] = cos_yaw
            pose[0, 2] = sin_yaw
            pose[2, 0] = -sin_yaw
            pose[2, 2] = cos_yaw

            # Translate along local -Z (forward per the original comment).
            pose[2, 3] = -forward_speed

            # Slight lateral drift to mimic a circular trajectory.
            radius_drift = 0.002
            pose[0, 3] = radius_drift

            relative_pose = pose[:3, :]
            relative_poses.append(torch.as_tensor(relative_pose))

        pose_embedding = torch.stack(relative_poses, dim=0)
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')

        # Condition mask mirrors the real-pose branch.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 合成camera embedding shape: {camera_embedding.shape} (总长度:{max_needed_frames})")
        return camera_embedding.to(torch.bfloat16)
222
+
223
def prepare_framepack_sliding_window_with_camera(history_latents, target_frames_to_generate, camera_embedding_full, start_frame, max_history_frames=49):
    """Assemble FramePack conditioning tensors plus the camera embedding whose
    condition/target mask reflects the current history window.

    Args:
        history_latents: [C, T, H, W] latents accumulated so far.
        target_frames_to_generate: frames to denoise this step.
        camera_embedding_full: [F, 13] pre-generated camera sequence; zero-padded
            here if shorter than the FramePack layout requires.
        start_frame: kept for caller compatibility (not used directly).
        max_history_frames: kept for caller compatibility (not used directly).

    Returns:
        dict of clean latents, index tensors, the masked camera embedding, and
        the window length before/after this step.
    """
    C, T, H, W = history_latents.shape

    # Fixed FramePack layout: 1 start + 16 (4x) + 2 (2x) + 1 (1x) + targets.
    split_sizes = [1, 16, 2, 1, target_frames_to_generate]
    total_indices_length = sum(split_sizes)
    (clean_latent_indices_start,
     clean_latent_4x_indices,
     clean_latent_2x_indices,
     clean_latent_1x_indices,
     latent_indices) = torch.arange(0, total_indices_length).split(split_sizes, dim=0)
    clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=0)

    # Zero-pad the camera sequence when it is shorter than the layout needs.
    deficit = total_indices_length - camera_embedding_full.shape[0]
    if deficit > 0:
        pad = torch.zeros(deficit, camera_embedding_full.shape[1],
                          dtype=camera_embedding_full.dtype, device=camera_embedding_full.device)
        camera_embedding_full = torch.cat([camera_embedding_full, pad], dim=0)

    # Work on a copy so the caller's full sequence is left untouched.
    combined_camera = camera_embedding_full[:total_indices_length, :].clone()

    # Rebuild the condition mask (last channel): first mark everything as
    # target (0.0)...
    combined_camera[:, -1] = 0.0
    if T > 0:
        # ...then flag the camera rows backing the valid clean latents
        # (right-aligned within the 19-slot clean region) as condition.
        available_frames = min(T, 19)
        start_pos = 19 - available_frames
        combined_camera[start_pos:19, -1] = 1.0

    print(f"🔧 Camera mask更新:")
    print(f" - 历史帧数: {T}")
    print(f" - 有效condition帧数: {available_frames if T > 0 else 0}")
    print(f" - Condition mask (前19帧): {combined_camera[:19, -1].cpu().tolist()}")
    print(f" - Target mask (后{target_frames_to_generate}帧): {combined_camera[19:, -1].cpu().tolist()}")

    # Right-align up to the 19 most recent history frames in a zero buffer.
    clean_latents_combined = torch.zeros(C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device)
    if T > 0:
        available_frames = min(T, 19)
        clean_latents_combined[:, 19 - available_frames:, :, :] = history_latents[:, -available_frames:, :, :]

    # Slice the clean region into the three FramePack scales.
    clean_latents_4x = clean_latents_combined[:, 0:16, :, :]
    clean_latents_2x = clean_latents_combined[:, 16:18, :, :]
    clean_latents_1x = clean_latents_combined[:, 18:19, :, :]

    # The very first history frame anchors the sequence start; all-zero when
    # there is no history yet.
    start_latent = (history_latents[:, 0:1, :, :] if T > 0
                    else torch.zeros(C, 1, H, W, dtype=history_latents.dtype, device=history_latents.device))
    clean_latents = torch.cat([start_latent, clean_latents_1x], dim=1)

    return {
        'latent_indices': latent_indices,
        'clean_latents': clean_latents,
        'clean_latents_2x': clean_latents_2x,
        'clean_latents_4x': clean_latents_4x,
        'clean_latent_indices': clean_latent_indices,
        'clean_latent_2x_indices': clean_latent_2x_indices,
        'clean_latent_4x_indices': clean_latent_4x_indices,
        'camera_embedding': combined_camera,
        'current_length': T,
        'next_length': T + target_frames_to_generate
    }
300
+
301
def inference_sekai_framepack_sliding_window(
    condition_pth_path,
    dit_path,
    output_path="sekai/infer_results/output_sekai_framepack_sliding.mp4",
    start_frame=0,
    initial_condition_frames=8,
    frames_per_generation=4,
    total_frames_to_generate=32,
    max_history_frames=49,
    device="cuda",
    prompt="A video of a scene shot using a pedestrian's front camera while walking",
    use_real_poses=True,
    synthetic_direction="forward",
    # CFG options
    use_camera_cfg=True,
    camera_guidance_scale=2.0,
    text_guidance_scale=7.5
):
    """
    FramePack sliding-window video generation with optional camera and text CFG.

    Loads pre-encoded condition latents, repeatedly denoises
    ``frames_per_generation`` new latent frames conditioned on a FramePack
    summary of the history window, then decodes and writes the full video to
    ``output_path``.

    Frame counts are in compressed latent frames (1 latent frame == 4 video
    frames). Requires the Wan2.1 checkpoints on disk and a CUDA device;
    NOTE(review): the pipeline is created on "cuda" regardless of ``device``,
    then moved to ``device`` — confirm this is intended for non-default devices.
    """
    # NOTE(review): fails if output_path has no directory component
    # (os.path.dirname returns "").
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    print(f"🔧 FramePack滑动窗口生成开始...")
    print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}")
    print(f"Text guidance scale: {text_guidance_scale}")
    print(f"初始条件帧: {initial_condition_frames}, 每次生成: {frames_per_generation}, 总生成: {total_frames_to_generate}")
    print(f"使用真实姿态: {use_real_poses}")
    if not use_real_poses:
        print(f"合成camera方向: {synthetic_direction}")

    # 1-3. Model initialization and FramePack component attachment.
    # Swap the registered DiT class BEFORE any model is loaded.
    replace_dit_model_in_manager()

    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # Add per-block camera conditioning layers. cam_encoder starts at zero and
    # projector at identity, so untrained layers are a no-op; the real weights
    # come from the checkpoint loaded below.
    dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        block.cam_encoder = nn.Linear(13, dim)
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    add_framepack_components(pipe.dit)

    # Load the fine-tuned FramePack checkpoint; strict=True so any structural
    # mismatch fails loudly.
    dit_state_dict = torch.load(dit_path, map_location="cpu")
    pipe.dit.load_state_dict(dit_state_dict, strict=True)
    pipe = pipe.to(device)
    model_dtype = next(pipe.dit.parameters()).dtype

    if hasattr(pipe.dit, 'clean_x_embedder'):
        pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)

    pipe.scheduler.set_timesteps(50)

    # 4. Load the initial condition latents.
    print("Loading initial condition frames...")
    initial_latents, encoded_data = load_encoded_video_from_pth(
        condition_pth_path,
        start_frame=start_frame,
        num_frames=initial_condition_frames
    )

    # Center-crop spatially to the model's expected latent resolution.
    target_height, target_width = 60, 104
    C, T, H, W = initial_latents.shape

    if H > target_height or W > target_width:
        h_start = (H - target_height) // 2
        w_start = (W - target_width) // 2
        initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width]
        H, W = target_height, target_width

    history_latents = initial_latents.to(device, dtype=model_dtype)

    print(f"初始history_latents shape: {history_latents.shape}")

    # Encode the prompt; a second, empty-prompt embedding is prepared only
    # when text CFG is enabled.
    if text_guidance_scale > 1.0:
        prompt_emb_pos = pipe.encode_prompt(prompt)
        prompt_emb_neg = pipe.encode_prompt("")
        print(f"使用Text CFG,guidance scale: {text_guidance_scale}")
    else:
        prompt_emb_pos = pipe.encode_prompt(prompt)
        prompt_emb_neg = None
        print("不使用Text CFG")

    # Pre-generate one camera embedding sequence long enough for every step;
    # the per-step slice and mask are rebuilt inside the loop.
    camera_embedding_full = generate_camera_embeddings_sliding(
        encoded_data.get('cam_emb', None),
        0,
        max_history_frames,
        0,
        0,
        use_real_poses=use_real_poses
    ).to(device, dtype=model_dtype)

    print(f"完整camera序列shape: {camera_embedding_full.shape}")

    # All-zero camera embedding serves as the unconditional input for camera CFG.
    if use_camera_cfg:
        camera_embedding_uncond = torch.zeros_like(camera_embedding_full)
        print(f"创建无条件camera embedding用于CFG")

    # 5-6. Sliding-window generation loop.
    total_generated = 0
    all_generated_frames = []

    while total_generated < total_frames_to_generate:
        # Last step may be shorter than frames_per_generation.
        current_generation = min(frames_per_generation, total_frames_to_generate - total_generated)
        print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}")
        print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}")

        # Build FramePack conditioning (clean latents, indices, masked camera).
        framepack_data = prepare_framepack_sliding_window_with_camera(
            history_latents,
            current_generation,
            camera_embedding_full,
            start_frame,
            max_history_frames
        )

        # Add a batch dimension to everything the DiT consumes.
        clean_latents = framepack_data['clean_latents'].unsqueeze(0)
        clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0)
        clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0)
        camera_embedding = framepack_data['camera_embedding'].unsqueeze(0)

        if use_camera_cfg:
            camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0)

        # Index tensors stay on CPU.
        latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu()
        clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu()
        clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu()
        clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu()

        # Start the new frames from pure noise.
        new_latents = torch.randn(
            1, C, current_generation, H, W,
            device=device, dtype=model_dtype
        )

        extra_input = pipe.prepare_extra_input(new_latents)

        print(f"Camera embedding shape: {camera_embedding.shape}")
        print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}")

        # Denoising loop, optionally with camera and/or text CFG.
        timesteps = pipe.scheduler.timesteps

        for i, timestep in enumerate(timesteps):
            if i % 10 == 0:
                print(f" 去噪步骤 {i+1}/{len(timesteps)}")

            timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)

            with torch.no_grad():
                if use_camera_cfg and camera_guidance_scale > 1.0:
                    # Conditional pass (with camera embedding).
                    noise_pred_cond = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

                    # Unconditional pass (zero camera embedding).
                    noise_pred_uncond = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding_uncond_batch,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **(prompt_emb_neg if prompt_emb_neg else prompt_emb_pos),
                        **extra_input
                    )

                    # Camera CFG combination.
                    noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_cond - noise_pred_uncond)

                    # Optionally stack text CFG on top of the camera-guided result.
                    if text_guidance_scale > 1.0 and prompt_emb_neg:
                        # Text-unconditional pass (camera kept, empty prompt).
                        noise_pred_text_uncond = pipe.dit(
                            new_latents,
                            timestep=timestep_tensor,
                            cam_emb=camera_embedding,
                            latent_indices=latent_indices,
                            clean_latents=clean_latents,
                            clean_latent_indices=clean_latent_indices,
                            clean_latents_2x=clean_latents_2x,
                            clean_latent_2x_indices=clean_latent_2x_indices,
                            clean_latents_4x=clean_latents_4x,
                            clean_latent_4x_indices=clean_latent_4x_indices,
                            **prompt_emb_neg,
                            **extra_input
                        )

                        # Apply text CFG to the already camera-guided prediction.
                        noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond)

                elif text_guidance_scale > 1.0 and prompt_emb_neg:
                    # Text CFG only.
                    noise_pred_cond = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

                    noise_pred_uncond = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_neg,
                        **extra_input
                    )

                    noise_pred = noise_pred_uncond + text_guidance_scale * (noise_pred_cond - noise_pred_uncond)

                else:
                    # Plain inference, no CFG.
                    noise_pred = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

            new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents)

        # Append the freshly denoised frames to the history.
        new_latents_squeezed = new_latents.squeeze(0)
        history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1)

        # Maintain the sliding window: always keep the anchor first frame plus
        # the most recent (max_history_frames - 1) frames.
        if history_latents.shape[1] > max_history_frames:
            first_frame = history_latents[:, 0:1, :, :]
            recent_frames = history_latents[:, -(max_history_frames-1):, :, :]
            history_latents = torch.cat([first_frame, recent_frames], dim=1)
            print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧")

        print(f"更新后history_latents shape: {history_latents.shape}")

        all_generated_frames.append(new_latents_squeezed)
        total_generated += current_generation

        print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧")

    # 7. Decode and save: initial condition + everything generated.
    print("\n🔧 解码生成的视频...")

    all_generated = torch.cat(all_generated_frames, dim=1)
    final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0)

    print(f"最终视频shape: {final_video.shape}")

    decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))

    print(f"Saving video to {output_path}")

    # [C, T, H, W] -> [T, H, W, C], map [-1, 1] to uint8 [0, 255].
    video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
    video_np = (video_np * 0.5 + 0.5).clip(0, 1)
    video_np = (video_np * 255).astype(np.uint8)

    with imageio.get_writer(output_path, fps=20) as writer:
        for frame in video_np:
            writer.append_data(frame)

    print(f"🔧 FramePack滑动窗口生成完成! 保存到: {output_path}")
    print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧")
621
+
622
def main():
    """CLI entry point for FramePack sliding-window generation with CFG."""

    def _str2bool(value):
        # Parse CLI booleans explicitly: common "off" spellings map to False,
        # anything else to True.
        if isinstance(value, bool):
            return value
        return value.strip().lower() not in ("false", "0", "no", "off", "")

    parser = argparse.ArgumentParser(description="Sekai FramePack滑动窗口视频生成 - 支持CFG")
    parser.add_argument("--condition_pth", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking/00100100001_0004650_0004950/encoded_video.pth")
    parser.add_argument("--start_frame", type=int, default=0)
    parser.add_argument("--initial_condition_frames", type=int, default=16)
    parser.add_argument("--frames_per_generation", type=int, default=8)
    parser.add_argument("--total_frames_to_generate", type=int, default=40)
    parser.add_argument("--max_history_frames", type=int, default=100)
    parser.add_argument("--use_real_poses", action="store_true", default=False)
    parser.add_argument("--dit_path", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/sekai/sekai_walking_framepack/step1000_framepack.ckpt")
    parser.add_argument("--output_path", type=str,
                        default='/home/zhuyixuan05/ReCamMaster/sekai/infer_framepack_results/output_sekai_framepack_sliding.mp4')
    parser.add_argument("--prompt", type=str,
                        default="A drone flying scene in a game world")
    parser.add_argument("--device", type=str, default="cuda")

    # CFG options. FIX: `--use_camera_cfg` previously had `default=True` and
    # no type, so ANY provided string (even "False") was truthy; _str2bool
    # makes `--use_camera_cfg False` actually disable it while keeping the
    # same default.
    parser.add_argument("--use_camera_cfg", type=_str2bool, default=True,
                        help="使用Camera CFG")
    parser.add_argument("--camera_guidance_scale", type=float, default=2.0,
                        help="Camera guidance scale for CFG")
    parser.add_argument("--text_guidance_scale", type=float, default=1.0,
                        help="Text guidance scale for CFG")

    args = parser.parse_args()

    print(f"🔧 FramePack CFG生成设置:")
    print(f"Camera CFG: {args.use_camera_cfg}")
    if args.use_camera_cfg:
        print(f"Camera guidance scale: {args.camera_guidance_scale}")
    print(f"Text guidance scale: {args.text_guidance_scale}")

    inference_sekai_framepack_sliding_window(
        condition_pth_path=args.condition_pth,
        dit_path=args.dit_path,
        output_path=args.output_path,
        start_frame=args.start_frame,
        initial_condition_frames=args.initial_condition_frames,
        frames_per_generation=args.frames_per_generation,
        total_frames_to_generate=args.total_frames_to_generate,
        max_history_frames=args.max_history_frames,
        device=args.device,
        prompt=args.prompt,
        use_real_poses=args.use_real_poses,
        # CFG parameters
        use_camera_cfg=args.use_camera_cfg,
        camera_guidance_scale=args.camera_guidance_scale,
        text_guidance_scale=args.text_guidance_scale
    )


if __name__ == "__main__":
    main()
scripts/infer_sekai_framepack_4.py ADDED
@@ -0,0 +1,682 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ from PIL import Image
6
+ import imageio
7
+ import json
8
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
9
+ import argparse
10
+ from torchvision.transforms import v2
11
+ from einops import rearrange
12
+ import copy
13
+
14
+
15
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Load pre-encoded video latents from a .pth file and slice a window.

    Args:
        pth_path: path to a torch-saved dict containing at least 'latents'
            shaped [C, T, H, W].
        start_frame: first (compressed) frame of the slice.
        num_frames: number of (compressed) frames to take.

    Returns:
        (condition_latents, encoded_data): the [C, num_frames, H, W] slice and
        the full loaded dict.

    Raises:
        ValueError: if the requested window exceeds the available frames.
    """
    print(f"Loading encoded video from {pth_path}")

    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # load trusted checkpoint files.
    encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu")
    full_latents = encoded_data['latents']  # [C, T, H, W]

    print(f"Full latents shape: {full_latents.shape}")
    print(f"Extracting frames {start_frame} to {start_frame + num_frames}")

    if start_frame + num_frames > full_latents.shape[1]:
        raise ValueError(f"Not enough frames: requested {start_frame + num_frames}, available {full_latents.shape[1]}")

    condition_latents = full_latents[:, start_frame:start_frame + num_frames, :, :]
    print(f"Extracted condition latents shape: {condition_latents.shape}")

    return condition_latents, encoded_data
32
+
33
+
34
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Compute camera B's pose relative to camera A.

    Both poses are 4x4 homogeneous extrinsic matrices; the result is
    ``pose_b @ inv(pose_a)``.

    Args:
        pose_a: 4x4 extrinsic of camera A (ndarray, Tensor, or nested list).
        pose_b: 4x4 extrinsic of camera B.
        use_torch: compute with torch ops (returns a Tensor) instead of numpy.

    Returns:
        The 4x4 relative pose — torch.Tensor if use_torch, else np.ndarray.

    Raises:
        ValueError: if either input is not 4x4 after conversion.
    """
    # Convert BEFORE validating: the original asserted `.shape` first, which
    # crashed with AttributeError on plain-list inputs, and the asserts were
    # silently stripped under `python -O`.
    if use_torch:
        if not isinstance(pose_a, torch.Tensor):
            pose_a = torch.from_numpy(np.asarray(pose_a)).float()
        if not isinstance(pose_b, torch.Tensor):
            pose_b = torch.from_numpy(np.asarray(pose_b)).float()
    else:
        if not isinstance(pose_a, np.ndarray):
            pose_a = np.array(pose_a, dtype=np.float32)
        if not isinstance(pose_b, np.ndarray):
            pose_b = np.array(pose_b, dtype=np.float32)

    # Explicit validation survives -O, unlike assert.
    if tuple(pose_a.shape) != (4, 4):
        raise ValueError(f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}")
    if tuple(pose_b.shape) != (4, 4):
        raise ValueError(f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}")

    if use_torch:
        relative_pose = torch.matmul(pose_b, torch.inverse(pose_a))
    else:
        relative_pose = np.matmul(pose_b, np.linalg.inv(pose_a))

    return relative_pose
57
+
58
def replace_dit_model_in_manager():
    """Swap the registered 'wan_video_dit' loader class for the FramePack
    variant (WanModelFuture4) by rewriting diffsynth's global
    ``model_loader_configs`` registry in place.

    NOTE(review): must run before ``ModelManager.load_models`` for the
    substitution to take effect — presumably the registry is consulted at
    load time; confirm against diffsynth internals.
    """
    # Lazy imports: both modules are project-local.
    from diffsynth.models.wan_video_dit_4 import WanModelFuture4
    from diffsynth.configs.model_config import model_loader_configs

    # Each registry entry is a 5-tuple; rebuild the name/class lists with the
    # FramePack class substituted wherever 'wan_video_dit' appears.
    for i, config in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource = config

        if 'wan_video_dit' in model_names:
            new_model_names = []
            new_model_classes = []

            for name, cls in zip(model_names, model_classes):
                if name == 'wan_video_dit':
                    new_model_names.append(name)
                    new_model_classes.append(WanModelFuture4)
                    print(f"✅ 替换了模型类: {name} -> WanModelFuture4")
                else:
                    new_model_names.append(name)
                    new_model_classes.append(cls)

            # Mutates the shared module-level list entry in place.
            model_loader_configs[i] = (keys_hash, keys_hash_with_shape, new_model_names, new_model_classes, model_resource)
80
+
81
+
82
def add_framepack_components(dit_model):
    """Attach FramePack's multi-scale clean-latent embedder to a DiT model.

    If absent, adds a ``clean_x_embedder`` submodule with three Conv3d
    patchifiers (1x/2x/4x temporal-spatial compression) that project
    16-channel clean latents into the transformer's inner dimension, then
    casts the new module to the model's existing parameter dtype.
    """
    if not hasattr(dit_model, 'clean_x_embedder'):
        # Inner dim inferred from the first block's attention Q projection.
        inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]

        class CleanXEmbedder(nn.Module):
            # Projects clean latents to inner_dim at three scales; kernel ==
            # stride so each conv is a non-overlapping patchification.
            def __init__(self, inner_dim):
                super().__init__()
                self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
                self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
                self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

            def forward(self, x, scale="1x"):
                # Cast input to the selected projection's dtype before conv.
                if scale == "1x":
                    x = x.to(self.proj.weight.dtype)
                    return self.proj(x)
                elif scale == "2x":
                    x = x.to(self.proj_2x.weight.dtype)
                    return self.proj_2x(x)
                elif scale == "4x":
                    x = x.to(self.proj_4x.weight.dtype)
                    return self.proj_4x(x)
                else:
                    raise ValueError(f"Unsupported scale: {scale}")

        dit_model.clean_x_embedder = CleanXEmbedder(inner_dim)
        # Match the host model's dtype (e.g. bfloat16).
        model_dtype = next(dit_model.parameters()).dtype
        dit_model.clean_x_embedder = dit_model.clean_x_embedder.to(dtype=model_dtype)
        print("✅ 添加了FramePack的clean_x_embedder组件")
111
+
112
def generate_camera_embeddings_sliding(cam_data, start_frame, current_history_length, new_frames, total_generated, use_real_poses=True):
    """Generate per-frame camera embeddings for one sliding-window step.

    Builds a sequence of 13-dim embeddings (12 values from a flattened 3x4
    relative pose + 1 condition-mask bit), long enough to cover the FramePack
    index structure. Falls back to a synthetic constant-turn motion when real
    extrinsics are unavailable.

    Args:
        cam_data: dict expected to hold an 'extrinsic' sequence of 4x4
            matrices, or None. Only used when use_real_poses is True.
        start_frame: offset (in latent frames) where the condition window starts.
        current_history_length: number of latent frames currently in history.
        new_frames: number of latent frames to be generated this step.
        total_generated: unused in this function; kept for call-site compatibility.
        use_real_poses: use cam_data extrinsics instead of the synthetic pattern.

    Returns:
        torch.bfloat16 tensor of shape [max_needed_frames, 13].
    """
    # One latent frame corresponds to 4 original video frames.
    time_compression_ratio = 4

    # Frames required by the FramePack structure:
    # 1 (start) + 16 (4x) + 2 (2x) + 1 (1x) + target frames.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    if use_real_poses and cam_data is not None and 'extrinsic' in cam_data:
        print("🔧 使用真实camera数据")
        cam_extrinsic = cam_data['extrinsic']

        # Ensure the camera sequence is long enough for both the current
        # history position and the full FramePack structure.
        max_needed_frames = max(
            start_frame + current_history_length + new_frames,  # base requirement
            framepack_needed_frames,  # FramePack structure requirement
            30  # guaranteed minimum length
        )

        print(f"🔧 计算camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        relative_poses = []
        for i in range(max_needed_frames):
            # Position of this latent frame in the original video sequence.
            frame_idx = i * time_compression_ratio
            next_frame_idx = frame_idx + time_compression_ratio

            if next_frame_idx < len(cam_extrinsic):
                cam_prev = cam_extrinsic[frame_idx]
                cam_next = cam_extrinsic[next_frame_idx]
                relative_pose = compute_relative_pose(cam_prev, cam_next)
                relative_poses.append(torch.as_tensor(relative_pose[:3, :]))
            else:
                # Past the end of the recorded extrinsics: zero motion.
                print(f"⚠️ 帧{frame_idx}超出camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))

        pose_embedding = torch.stack(relative_poses, dim=0)
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')

        # Mask of matching length: 1.0 marks condition frames.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        # Frames from start_frame through the current history are conditions.
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 真实camera embedding shape: {camera_embedding.shape} (总长度:{max_needed_frames})")
        return camera_embedding.to(torch.bfloat16)

    else:
        print("🔧 使用合成camera数据")

        # Synthetic sequence must also cover the full required length.
        max_needed_frames = max(
            start_frame + current_history_length + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 生成合成camera帧数: {max_needed_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")

        relative_poses = []
        for i in range(max_needed_frames):
            # Constant turn pattern: rotate by a fixed yaw per frame while
            # moving forward. NOTE(review): the original comments call this a
            # left turn but the yaw sign is negative — confirm the intended
            # turn direction against the pose convention used downstream.
            yaw_per_frame = -0.05  # per-frame yaw increment
            forward_speed = 0.005  # per-frame forward distance

            # Accumulated yaw at this frame (computed but not used below).
            current_yaw = i * yaw_per_frame

            # Relative transform from frame i to frame i+1.
            pose = np.eye(4, dtype=np.float32)

            # Rotation about the Y axis.
            cos_yaw = np.cos(yaw_per_frame)
            sin_yaw = np.sin(yaw_per_frame)

            pose[0, 0] = cos_yaw
            pose[0, 2] = sin_yaw
            pose[2, 0] = -sin_yaw
            pose[2, 2] = cos_yaw

            # Translation: move forward along the local Z axis.
            pose[2, 3] = -forward_speed  # local -Z (forward)

            # Slight lateral drift to mimic a circular trajectory.
            radius_drift = 0.002  # drift toward the circle center
            pose[0, 3] = radius_drift

            relative_pose = pose[:3, :]
            relative_poses.append(torch.as_tensor(relative_pose))

        pose_embedding = torch.stack(relative_poses, dim=0)
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')

        # Mask of matching length, same convention as the real-pose branch.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 合成camera embedding shape: {camera_embedding.shape} (总长度:{max_needed_frames})")
        return camera_embedding.to(torch.bfloat16)
222
+
223
def prepare_framepack_sliding_window_with_camera(history_latents, target_frames_to_generate, camera_embedding_full, start_frame, max_history_frames=49):
    """Prepare FramePack sliding-window inputs: clean latents at three scales,
    their index tensors, and the camera embedding with an updated condition mask.

    The index layout is [1 start, 16 4x, 2 2x, 5 1x, target] = 24 context
    positions + target positions. clean_latents is 5 frames: the first 4
    history frames plus the last history frame (zero-padded when history is
    shorter than 5).

    Args:
        history_latents: [C, T, H, W] latents accumulated so far (T may be 0).
        target_frames_to_generate: number of latent frames to denoise this step.
        camera_embedding_full: [N, 13] full camera sequence; zero-padded here
            if shorter than the required index length.
        start_frame: offset of the condition window (passed through; not used
            for the mask in this function — the mask is based on history length).
        max_history_frames: unused in this function; kept for call-site compatibility.

    Returns:
        dict of latents/index tensors plus 'camera_embedding',
        'current_length' (T) and 'next_length' (T + target frames).
    """
    # history_latents: [C, T, H, W] — the current history latents.
    C, T, H, W = history_latents.shape

    # Fixed index structure: 4 starting frames + last frame = 5 clean frames.
    total_indices_length = 1 + 16 + 2 + 5 + target_frames_to_generate
    indices = torch.arange(0, total_indices_length)
    split_sizes = [1, 16, 2, 5, target_frames_to_generate]
    clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = \
        indices.split(split_sizes, dim=0)

    # clean_latents indices: the 5-frame 1x slot (first 4 + last frames).
    # NOTE(review): clean_latent_indices_start (the leading slot) is split off
    # but never returned — confirm the model does not expect it.
    clean_latent_indices = clean_latent_1x_indices

    # Zero-pad the camera sequence if it is shorter than the index layout.
    if camera_embedding_full.shape[0] < total_indices_length:
        shortage = total_indices_length - camera_embedding_full.shape[0]
        padding = torch.zeros(shortage, camera_embedding_full.shape[1],
                              dtype=camera_embedding_full.dtype, device=camera_embedding_full.device)
        camera_embedding_full = torch.cat([camera_embedding_full, padding], dim=0)

    # Take the slice matching the index layout (cloned so the mask edit below
    # does not mutate the caller's full sequence).
    combined_camera = camera_embedding_full[:total_indices_length, :].clone()

    # Reset the mask column: 0 = target everywhere first.
    combined_camera[:, -1] = 0.0

    # Condition mask over the first 24 positions, right-aligned by the amount
    # of history actually available.
    if T > 0:
        available_frames = min(T, 24)  # first 24 positions map to clean latents
        start_pos = 24 - available_frames

        # Mark the corresponding camera positions as condition.
        combined_camera[start_pos:24, -1] = 1.0

    # Target part stays 0 (set above).

    print(f"🔧 Camera mask更新:")
    print(f" - 历史帧数: {T}")
    print(f" - 有效condition帧数: {available_frames if T > 0 else 0}")
    print(f" - Condition mask (前24帧): {combined_camera[:24, -1].cpu().tolist()}")
    print(f" - Target mask (后{target_frames_to_generate}帧): {combined_camera[24:, -1].cpu().tolist()}")

    # Clean-latent buffer for the 24 context positions, right-aligned with
    # the most recent history frames (zeros where history is missing).
    clean_latents_combined = torch.zeros(C, 24, H, W, dtype=history_latents.dtype, device=history_latents.device)

    if T > 0:
        available_frames = min(T, 24)
        start_pos = 24 - available_frames
        clean_latents_combined[:, start_pos:, :, :] = history_latents[:, -available_frames:, :, :]

    clean_latents_4x = clean_latents_combined[:, 0:16, :, :]
    clean_latents_2x = clean_latents_combined[:, 16:18, :, :]
    # NOTE(review): clean_latents_1x is computed but never used (clean_latents
    # is built separately below), and 18:23 skips position 23 of the 24-frame
    # buffer — dead code worth removing or fixing if it becomes live.
    clean_latents_1x = clean_latents_combined[:, 18:23, :, :]

    # Build clean_latents: first 4 history frames + last history frame.
    if T >= 5:
        # Enough history: take the first 4 frames plus the last frame.
        start_latent = history_latents[:, 0:4, :, :]  # first 4 frames
        last_latent = history_latents[:, -1:, :, :]  # last frame
        clean_latents = torch.cat([start_latent, last_latent], dim=1)  # 5 frames
    elif T > 0:
        # Fewer than 5 frames of history: zero-pad on the left.
        clean_latents = torch.zeros(C, 5, H, W, dtype=history_latents.dtype, device=history_latents.device)
        # Fill from the back with whatever history exists.
        clean_latents[:, -T:, :, :] = history_latents
    else:
        # No history: all zeros.
        clean_latents = torch.zeros(C, 5, H, W, dtype=history_latents.dtype, device=history_latents.device)

    return {
        'latent_indices': latent_indices,
        'clean_latents': clean_latents,  # 5 frames: first 4 + last
        'clean_latents_2x': clean_latents_2x,
        'clean_latents_4x': clean_latents_4x,
        'clean_latent_indices': clean_latent_indices,
        'clean_latent_2x_indices': clean_latent_2x_indices,
        'clean_latent_4x_indices': clean_latent_4x_indices,
        'camera_embedding': combined_camera,
        'current_length': T,
        'next_length': T + target_frames_to_generate
    }
307
+
308
def inference_sekai_framepack_sliding_window(
    condition_pth_path,
    dit_path,
    output_path="sekai/infer_results/output_sekai_framepack_sliding.mp4",
    start_frame=0,
    initial_condition_frames=8,
    frames_per_generation=4,
    total_frames_to_generate=32,
    max_history_frames=49,
    device="cuda",
    prompt="A video of a scene shot using a pedestrian's front camera while walking",
    use_real_poses=True,
    synthetic_direction="forward",
    # CFG parameters
    use_camera_cfg=True,
    camera_guidance_scale=2.0,
    text_guidance_scale=7.5
):
    """FramePack sliding-window video generation with optional Camera/Text CFG.

    Loads a Wan2.1 pipeline with a FramePack DiT, seeds the history with
    pre-encoded condition latents from `condition_pth_path`, then repeatedly
    denoises `frames_per_generation` new latent frames conditioned on the
    (bounded) history and a camera-pose embedding, finally decoding and
    writing an mp4 to `output_path`.

    All frame counts are in compressed latent frames (1 latent frame = 4
    original frames). Requires CUDA-capable hardware for the default device.
    """
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    print(f"🔧 FramePack滑动窗口生成开始...")
    print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}")
    print(f"Text guidance scale: {text_guidance_scale}")
    print(f"初始条件帧: {initial_condition_frames}, 每次生成: {frames_per_generation}, 总生成: {total_frames_to_generate}")
    print(f"使用真实姿态: {use_real_poses}")
    if not use_real_poses:
        print(f"合成camera方向: {synthetic_direction}")

    # 1-3. Model init: register the FramePack DiT class, load base weights,
    # attach the camera adapters, then load the fine-tuned checkpoint.
    replace_dit_model_in_manager()

    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # Per-block camera adapters: encoder zero-initialized, projector as
    # identity, so they are a no-op until checkpoint weights are loaded.
    dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        block.cam_encoder = nn.Linear(13, dim)
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    add_framepack_components(pipe.dit)

    # strict=True: the checkpoint must contain the adapter and FramePack
    # weights added above.
    dit_state_dict = torch.load(dit_path, map_location="cpu")
    pipe.dit.load_state_dict(dit_state_dict, strict=True)
    pipe = pipe.to(device)
    model_dtype = next(pipe.dit.parameters()).dtype

    if hasattr(pipe.dit, 'clean_x_embedder'):
        pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)

    pipe.scheduler.set_timesteps(50)

    # 4. Load initial condition latents.
    print("Loading initial condition frames...")
    initial_latents, encoded_data = load_encoded_video_from_pth(
        condition_pth_path,
        start_frame=start_frame,
        num_frames=initial_condition_frames
    )

    # Center-crop spatially to the model's expected latent resolution.
    target_height, target_width = 60, 104
    C, T, H, W = initial_latents.shape

    if H > target_height or W > target_width:
        h_start = (H - target_height) // 2
        w_start = (W - target_width) // 2
        initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width]
        H, W = target_height, target_width

    history_latents = initial_latents.to(device, dtype=model_dtype)

    print(f"初始history_latents shape: {history_latents.shape}")

    # Encode the prompt; a negative (empty) prompt is only needed for Text CFG.
    if text_guidance_scale > 1.0:
        # Positive prompt.
        prompt_emb_pos = pipe.encode_prompt(prompt)
        # Negative prompt (empty string).
        prompt_emb_neg = pipe.encode_prompt("")
        print(f"使用Text CFG,guidance scale: {text_guidance_scale}")
    else:
        prompt_emb_pos = pipe.encode_prompt(prompt)
        prompt_emb_neg = None
        print("不使用Text CFG")

    # Pre-generate the full camera embedding sequence once; per-step masks
    # are derived from it inside prepare_framepack_sliding_window_with_camera.
    camera_embedding_full = generate_camera_embeddings_sliding(
        encoded_data.get('cam_emb', None),
        0,
        max_history_frames,
        0,
        0,
        use_real_poses=use_real_poses
    ).to(device, dtype=model_dtype)

    print(f"完整camera序列shape: {camera_embedding_full.shape}")

    # Unconditional (all-zero) camera embedding for Camera CFG.
    if use_camera_cfg:
        camera_embedding_uncond = torch.zeros_like(camera_embedding_full)
        print(f"创建无条件camera embedding用于CFG")

    # Sliding-window generation loop.
    total_generated = 0
    all_generated_frames = []

    while total_generated < total_frames_to_generate:
        current_generation = min(frames_per_generation, total_frames_to_generate - total_generated)
        print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}")
        print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}")

        # FramePack data preparation for this window.
        framepack_data = prepare_framepack_sliding_window_with_camera(
            history_latents,
            current_generation,
            camera_embedding_full,
            start_frame,
            max_history_frames
        )

        # Add batch dimension for the model.
        clean_latents = framepack_data['clean_latents'].unsqueeze(0)
        clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0)
        clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0)
        camera_embedding = framepack_data['camera_embedding'].unsqueeze(0)

        # Matching-length unconditional camera embedding for CFG.
        if use_camera_cfg:
            camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0)

        # Index tensors (kept on CPU).
        latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu()
        clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu()
        clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu()
        clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu()

        # Start from pure noise for the frames to be generated.
        new_latents = torch.randn(
            1, C, current_generation, H, W,
            device=device, dtype=model_dtype
        )

        extra_input = pipe.prepare_extra_input(new_latents)

        print(f"Camera embedding shape: {camera_embedding.shape}")
        print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}")

        # Denoising loop with optional CFG.
        timesteps = pipe.scheduler.timesteps

        for i, timestep in enumerate(timesteps):
            if i % 10 == 0:
                print(f" 去噪步骤 {i+1}/{len(timesteps)}")

            timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)

            with torch.no_grad():
                # CFG inference.
                if use_camera_cfg and camera_guidance_scale > 1.0:
                    # Conditional prediction (with camera).
                    noise_pred_cond = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

                    # Unconditional prediction (zero camera).
                    noise_pred_uncond = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding_uncond_batch,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **(prompt_emb_neg if prompt_emb_neg else prompt_emb_pos),
                        **extra_input
                    )

                    # Camera CFG.
                    noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_cond - noise_pred_uncond)

                    # Optionally stack Text CFG on top of the Camera CFG result.
                    if text_guidance_scale > 1.0 and prompt_emb_neg:
                        # Text-unconditional prediction (with camera).
                        noise_pred_text_uncond = pipe.dit(
                            new_latents,
                            timestep=timestep_tensor,
                            cam_emb=camera_embedding,
                            latent_indices=latent_indices,
                            clean_latents=clean_latents,
                            clean_latent_indices=clean_latent_indices,
                            clean_latents_2x=clean_latents_2x,
                            clean_latent_2x_indices=clean_latent_2x_indices,
                            clean_latents_4x=clean_latents_4x,
                            clean_latent_4x_indices=clean_latent_4x_indices,
                            **prompt_emb_neg,
                            **extra_input
                        )

                        # Apply Text CFG to the Camera-CFG-combined prediction.
                        noise_pred = noise_pred_text_uncond + text_guidance_scale * (noise_pred - noise_pred_text_uncond)

                elif text_guidance_scale > 1.0 and prompt_emb_neg:
                    # Text CFG only.
                    noise_pred_cond = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

                    noise_pred_uncond = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_neg,
                        **extra_input
                    )

                    noise_pred = noise_pred_uncond + text_guidance_scale * (noise_pred_cond - noise_pred_uncond)

                else:
                    # Standard inference (no CFG).
                    noise_pred = pipe.dit(
                        new_latents,
                        timestep=timestep_tensor,
                        cam_emb=camera_embedding,
                        latent_indices=latent_indices,
                        clean_latents=clean_latents,
                        clean_latent_indices=clean_latent_indices,
                        clean_latents_2x=clean_latents_2x,
                        clean_latent_2x_indices=clean_latent_2x_indices,
                        clean_latents_4x=clean_latents_4x,
                        clean_latent_4x_indices=clean_latent_4x_indices,
                        **prompt_emb_pos,
                        **extra_input
                    )

            new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents)

        # Append the freshly generated frames to the history.
        new_latents_squeezed = new_latents.squeeze(0)
        history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1)

        # Maintain the sliding window: keep the very first frame plus the
        # most recent max_history_frames-1 frames.
        if history_latents.shape[1] > max_history_frames:
            first_frame = history_latents[:, 0:1, :, :]
            recent_frames = history_latents[:, -(max_history_frames-1):, :, :]
            history_latents = torch.cat([first_frame, recent_frames], dim=1)
            print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧")

        print(f"更新后history_latents shape: {history_latents.shape}")

        all_generated_frames.append(new_latents_squeezed)
        total_generated += current_generation

        print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧")

    # 7. Decode and save.
    print("\n🔧 解码生成的视频...")

    all_generated = torch.cat(all_generated_frames, dim=1)
    final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0)

    print(f"最终视频shape: {final_video.shape}")

    decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))

    print(f"Saving video to {output_path}")

    # [-1, 1] float frames -> uint8 HWC frames for the video writer.
    video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
    video_np = (video_np * 0.5 + 0.5).clip(0, 1)
    video_np = (video_np * 255).astype(np.uint8)

    with imageio.get_writer(output_path, fps=20) as writer:
        for frame in video_np:
            writer.append_data(frame)

    print(f"🔧 FramePack滑动窗口生成完成! 保存到: {output_path}")
    print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧")
628
+
629
def main():
    """CLI entry point: parse arguments and run FramePack sliding-window inference."""
    parser = argparse.ArgumentParser(description="Sekai FramePack滑动窗口视频生成 - 支持CFG")
    parser.add_argument("--condition_pth", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking/00100100001_0004650_0004950/encoded_video.pth")
    parser.add_argument("--start_frame", type=int, default=0)
    parser.add_argument("--initial_condition_frames", type=int, default=16)
    parser.add_argument("--frames_per_generation", type=int, default=8)
    parser.add_argument("--total_frames_to_generate", type=int, default=60)
    parser.add_argument("--max_history_frames", type=int, default=100)
    # NOTE(review): store_true with default=True means this flag can never be
    # disabled from the CLI; kept as-is for backward compatibility.
    parser.add_argument("--use_real_poses", action="store_true", default=True)
    parser.add_argument("--dit_path", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/sekai/sekai_walking_framepack_4/step34290_framepack.ckpt")
    parser.add_argument("--output_path", type=str,
                        default='/home/zhuyixuan05/ReCamMaster/sekai/infer_framepack_results/output_sekai_framepack_sliding.mp4')
    parser.add_argument("--prompt", type=str,
                        default="A drone flying scene in a game world")
    parser.add_argument("--device", type=str, default="cuda")

    # CFG parameters.
    # Fix: the original declared this with `default=False` but no action/type,
    # so `--use_camera_cfg` required a positional value and ANY string
    # (including "False") parsed as truthy. store_true gives the intended
    # on/off flag semantics.
    parser.add_argument("--use_camera_cfg", action="store_true", default=False,
                        help="使用Camera CFG")
    parser.add_argument("--camera_guidance_scale", type=float, default=2.0,
                        help="Camera guidance scale for CFG")
    parser.add_argument("--text_guidance_scale", type=float, default=1.0,
                        help="Text guidance scale for CFG")

    args = parser.parse_args()

    print(f"🔧 FramePack CFG生成设置:")
    print(f"Camera CFG: {args.use_camera_cfg}")
    if args.use_camera_cfg:
        print(f"Camera guidance scale: {args.camera_guidance_scale}")
    print(f"Text guidance scale: {args.text_guidance_scale}")

    inference_sekai_framepack_sliding_window(
        condition_pth_path=args.condition_pth,
        dit_path=args.dit_path,
        output_path=args.output_path,
        start_frame=args.start_frame,
        initial_condition_frames=args.initial_condition_frames,
        frames_per_generation=args.frames_per_generation,
        total_frames_to_generate=args.total_frames_to_generate,
        max_history_frames=args.max_history_frames,
        device=args.device,
        prompt=args.prompt,
        use_real_poses=args.use_real_poses,
        # CFG parameters
        use_camera_cfg=args.use_camera_cfg,
        camera_guidance_scale=args.camera_guidance_scale,
        text_guidance_scale=args.text_guidance_scale
    )


if __name__ == "__main__":
    main()
scripts/infer_sekai_framepack_test.py ADDED
@@ -0,0 +1,551 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ from PIL import Image
6
+ import imageio
7
+ import json
8
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
9
+ import argparse
10
+ from torchvision.transforms import v2
11
+ from einops import rearrange
12
+ import copy
13
+
14
+
15
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Load a contiguous slice of pre-encoded video latents from a .pth file.

    Args:
        pth_path: path to a torch-saved dict with a 'latents' tensor of
            shape [C, T, H, W] (T counted in compressed latent frames).
        start_frame: first latent frame to take.
        num_frames: number of latent frames to take.

    Returns:
        (condition_latents, encoded_data): the [C, num_frames, H, W] slice
        and the full loaded dict.

    Raises:
        ValueError: if the requested range exceeds the stored frame count.
    """
    print(f"Loading encoded video from {pth_path}")

    encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu")
    full_latents = encoded_data['latents']  # [C, T, H, W]

    end_frame = start_frame + num_frames
    print(f"Full latents shape: {full_latents.shape}")
    print(f"Extracting frames {start_frame} to {end_frame}")

    available = full_latents.shape[1]
    if end_frame > available:
        raise ValueError(f"Not enough frames: requested {end_frame}, available {available}")

    condition_latents = full_latents[:, start_frame:end_frame, :, :]
    print(f"Extracted condition latents shape: {condition_latents.shape}")

    return condition_latents, encoded_data
40
+
41
+
42
def compute_relative_pose(pose_a, pose_b, use_torch=False):
    """Return the pose of camera B relative to camera A: pose_b @ inv(pose_a).

    Args:
        pose_a: 4x4 extrinsic matrix of camera A.
        pose_b: 4x4 extrinsic matrix of camera B.
        use_torch: compute with torch (result is a Tensor); otherwise numpy.
    """
    assert pose_a.shape == (4, 4), f"相机A外参矩阵形状应为(4,4),实际为{pose_a.shape}"
    assert pose_b.shape == (4, 4), f"相机B外参矩阵形状应为(4,4),实际为{pose_b.shape}"

    if use_torch:
        if not isinstance(pose_a, torch.Tensor):
            pose_a = torch.from_numpy(pose_a).float()
        if not isinstance(pose_b, torch.Tensor):
            pose_b = torch.from_numpy(pose_b).float()
        return pose_b @ torch.inverse(pose_a)

    if not isinstance(pose_a, np.ndarray):
        pose_a = np.array(pose_a, dtype=np.float32)
    if not isinstance(pose_b, np.ndarray):
        pose_b = np.array(pose_b, dtype=np.float32)
    return pose_b @ np.linalg.inv(pose_a)
65
+
66
+
67
def prepare_framepack_inputs(full_latents, condition_frames, target_frames, start_frame=0):
    """Build FramePack-style multi-scale conditioning inputs.

    Splits `full_latents` into the target segment to denoise plus 1x/2x/4x
    clean-latent context buffers. The 2x/4x buffers are right-aligned and
    zero-padded; unused index slots are filled with -1.

    Args:
        full_latents: [C, T, H, W] or [B, C, T, H, W] latent tensor.
        condition_frames: number of conditioning frames from start_frame.
        target_frames: number of frames to denoise after the condition window.
        start_frame: offset of the condition window.

    Returns:
        dict with 'latents', 'clean_latents', 'clean_latents_2x',
        'clean_latents_4x' and the matching '*_indices' tensors; latent
        tensors carry a batch dim only if the input did.
    """
    had_batch = full_latents.dim() == 5
    batched = full_latents if had_batch else full_latents.unsqueeze(0)
    B, C, T, H, W = batched.shape

    cond_end = start_frame + condition_frames

    # Target segment: the frames the model will denoise.
    latent_indices = torch.arange(cond_end, cond_end + target_frames)
    main_latents = batched[:, :, latent_indices, :, :]

    # 1x context: first condition frame + last condition frame.
    clean_latent_indices = torch.tensor([start_frame, cond_end - 1])
    clean_latents = batched[:, :, clean_latent_indices, :, :]

    def _tail_buffer(capacity, min_frames):
        """Right-aligned buffer holding the last `capacity` condition frames."""
        buf = torch.zeros(B, C, capacity, H, W, dtype=batched.dtype)
        idx = torch.full((capacity,), -1, dtype=torch.long)
        if condition_frames >= min_frames:
            tail = torch.arange(max(start_frame, cond_end - capacity), cond_end)
            offset = capacity - len(tail)
            buf[:, :, offset:, :, :] = batched[:, :, tail, :, :]
            idx[offset:] = tail
        return buf, idx

    # 2x context (last 2 condition frames) and 4x context (up to 16).
    clean_latents_2x, clean_latent_2x_indices = _tail_buffer(2, 2)
    clean_latents_4x, clean_latent_4x_indices = _tail_buffer(16, 1)

    if not had_batch:
        main_latents = main_latents.squeeze(0)
        clean_latents = clean_latents.squeeze(0)
        clean_latents_2x = clean_latents_2x.squeeze(0)
        clean_latents_4x = clean_latents_4x.squeeze(0)

    return {
        'latents': main_latents,
        'clean_latents': clean_latents,
        'clean_latents_2x': clean_latents_2x,
        'clean_latents_4x': clean_latents_4x,
        'latent_indices': latent_indices,
        'clean_latent_indices': clean_latent_indices,
        'clean_latent_2x_indices': clean_latent_2x_indices,
        'clean_latent_4x_indices': clean_latent_4x_indices,
    }
127
+
128
+
129
def generate_camera_poses_from_data(cam_data, start_frame, condition_frames, target_frames):
    """Build per-frame camera embeddings from recorded extrinsics.

    For each latent frame, computes the relative pose between consecutive
    (4x-compressed) original frames, flattens the 3x4 matrix to 12 values,
    and appends a condition-mask bit (1.0 for condition frames).

    Args:
        cam_data: dict holding an 'extrinsic' sequence of [N, 4, 4] matrices.
        start_frame: offset in latent frames into the extrinsic sequence.
        condition_frames: number of condition frames (mask = 1.0).
        target_frames: number of target frames (mask = 0.0).

    Returns:
        torch.bfloat16 tensor of shape [condition_frames + target_frames, 13].
    """
    time_compression_ratio = 4
    total_frames = condition_frames + target_frames

    cam_extrinsic = cam_data['extrinsic']  # [N, 4, 4]
    base_idx = start_frame * time_compression_ratio

    print(f"Using camera data from frame {base_idx}")

    # Relative pose per latent frame; past the recorded range, repeat the
    # last computed pose (or zeros if nothing was computed yet).
    rel_list = []
    for step in range(total_frames):
        cur = base_idx + step * time_compression_ratio
        nxt = cur + time_compression_ratio

        if nxt >= len(cam_extrinsic):
            print('Out of temporal range, using last available pose')
            rel_list.append(rel_list[-1] if rel_list else torch.zeros(3, 4))
        else:
            rel = compute_relative_pose(cam_extrinsic[cur], cam_extrinsic[nxt])
            rel_list.append(torch.as_tensor(rel[:3, :]))

    # Flatten each 3x4 pose into 12 values.
    pose_embedding = torch.stack(rel_list, dim=0).reshape(total_frames, -1)  # [frames, 12]

    # Condition-mask column: 1.0 for condition frames, 0.0 for targets.
    mask = torch.zeros(total_frames, 1, dtype=torch.float32)
    mask[:condition_frames] = 1.0

    camera_embedding = torch.cat([pose_embedding, mask], dim=1)  # [frames, 13]
    print(f"Generated camera embedding shape: {camera_embedding.shape}")

    return camera_embedding.to(torch.bfloat16)
167
+
168
+
169
def generate_synthetic_camera_poses(direction="forward", target_frames=10, condition_frames=20):
    """Generate a synthetic camera embedding sequence for a canned motion.

    Supported directions: "forward", "backward", "left_turn", "right_turn";
    any other value yields identity poses (static camera). Absolute poses are
    interpolated over the trajectory, then converted to frame-to-frame
    relative poses, flattened to 12 values, and tagged with a condition-mask
    bit (1.0 for condition frames).

    Returns:
        torch.bfloat16 tensor of shape [condition_frames + target_frames, 13].
    """
    total_frames = condition_frames + target_frames

    def _absolute_pose(step):
        """Absolute 4x4 pose at `step`, parameterized by t in [0, 1]."""
        t = step / max(1, total_frames - 1)
        pose = np.eye(4, dtype=np.float32)

        if direction == "forward":
            pose[2, 3] = -t * 0.04
        elif direction == "backward":
            pose[2, 3] = t * 2.0
        elif direction == "left_turn":
            pose[2, 3] = -t * 0.03
            pose[0, 3] = t * 0.02
            yaw = t * 1
            pose[0, 0] = np.cos(yaw)
            pose[0, 2] = np.sin(yaw)
            pose[2, 0] = -np.sin(yaw)
            pose[2, 2] = np.cos(yaw)
        elif direction == "right_turn":
            pose[2, 3] = -t * 0.03
            pose[0, 3] = -t * 0.02
            yaw = -t * 1
            pose[0, 0] = np.cos(yaw)
            pose[0, 2] = np.sin(yaw)
            pose[2, 0] = -np.sin(yaw)
            pose[2, 2] = np.cos(yaw)

        return pose

    poses = [_absolute_pose(i) for i in range(total_frames)]

    # Relative pose between consecutive absolute poses (total_frames - 1 of
    # them); duplicate the last so the sequence length matches total_frames.
    relative_poses = [
        torch.as_tensor(compute_relative_pose(prev, nxt)[:3, :])
        for prev, nxt in zip(poses, poses[1:])
    ]
    if len(relative_poses) < total_frames:
        relative_poses.append(relative_poses[-1])

    # Flatten each 3x4 pose into 12 values.
    pose_embedding = torch.stack(relative_poses[:total_frames], dim=0).reshape(total_frames, -1)  # [frames, 12]

    # Condition-mask column.
    mask = torch.zeros(total_frames, 1, dtype=torch.float32)
    mask[:condition_frames] = 1.0

    camera_embedding = torch.cat([pose_embedding, mask], dim=1)  # [frames, 13]
    print(f"Generated {direction} movement poses: {camera_embedding.shape}")

    return camera_embedding.to(torch.bfloat16)
222
+
223
+
224
def replace_dit_model_in_manager():
    """Swap the registered DiT class for the FramePack variant.

    Mutates ``model_loader_configs`` in place so that any config entry
    containing the name ``wan_video_dit`` maps that name to
    ``WanModelFuture`` instead of the stock class. All other entries are
    left untouched.
    """
    from diffsynth.models.wan_video_dit_recam_future import WanModelFuture
    from diffsynth.configs.model_config import model_loader_configs

    for idx, entry in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, names, classes, resource = entry

        # Only entries that register the stock DiT need patching.
        if 'wan_video_dit' not in names:
            continue

        patched_names = []
        patched_classes = []
        for model_name, model_cls in zip(names, classes):
            patched_names.append(model_name)
            if model_name == 'wan_video_dit':
                patched_classes.append(WanModelFuture)
                print(f"✅ 替换了模型类: {model_name} -> WanModelFuture")
            else:
                patched_classes.append(model_cls)

        model_loader_configs[idx] = (keys_hash, keys_hash_with_shape, patched_names, patched_classes, resource)
246
+
247
def add_framepack_components(dit_model):
    """Attach FramePack's multi-scale clean-latent embedder to a DiT model.

    Idempotent: a no-op when ``dit_model`` already carries a
    ``clean_x_embedder``. The embedder provides three Conv3d projections at
    1x/2x/4x spatio-temporal compression and casts its input to the chosen
    projection's weight dtype before convolving.
    """
    if hasattr(dit_model, 'clean_x_embedder'):
        return

    inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]

    class CleanXEmbedder(nn.Module):
        def __init__(self, inner_dim):
            super().__init__()
            self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
            self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
            self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

        def forward(self, x, scale="1x"):
            # Dispatch to the projection for the requested scale; the input
            # is cast to match that projection's weight dtype.
            conv = {"1x": self.proj, "2x": self.proj_2x, "4x": self.proj_4x}.get(scale)
            if conv is None:
                raise ValueError(f"Unsupported scale: {scale}")
            return conv(x.to(conv.weight.dtype))

    embedder = CleanXEmbedder(inner_dim)
    # Use the dtype of the model's parameters rather than a model-level
    # .dtype attribute (which may not exist).
    target_dtype = next(dit_model.parameters()).dtype
    dit_model.clean_x_embedder = embedder.to(dtype=target_dtype)
    print("✅ 添加了FramePack的clean_x_embedder组件")
278
+
279
def inference_sekai_framepack_from_pth(
    condition_pth_path,
    dit_path,
    output_path="sekai/infer_results/output_sekai_framepack.mp4",
    start_frame=0,
    condition_frames=10,
    target_frames=2,
    device="cuda",
    prompt="A video of a scene shot using a pedestrian's front camera while walking",
    direction="forward",
    use_real_poses=True
):
    """
    FramePack-style Sekai video inference.

    Loads the Wan2.1 pipeline with the DiT class swapped for the FramePack
    variant, attaches camera encoders and the clean_x_embedder, loads a
    trained checkpoint, then denoises ``target_frames`` new latent frames
    conditioned on pre-encoded video latents and a camera-pose embedding,
    and finally decodes and saves the result as an mp4.

    Args:
        condition_pth_path: Path to a .pth file holding pre-encoded latents
            (and optionally 'cam_emb' camera data).
        dit_path: Path to the trained FramePack DiT state dict.
        output_path: Destination mp4 path; parent dirs are created.
        start_frame: Start index into the stored latents. NOTE(review): frame
            counts here appear to be in compressed latent frames (the CLI
            help says so) — confirm against the encoder.
        condition_frames: Number of conditioning latent frames.
        target_frames: Number of latent frames to generate.
        device: Torch device for inference.
        prompt: Text prompt passed to the T5 encoder.
        direction: Synthetic camera motion, used when real poses are absent.
        use_real_poses: Prefer 'cam_emb' from the .pth when available.
    """
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    print(f"Setting up FramePack models for {direction} movement...")

    # 1. Swap the registered DiT class for the FramePack variant, then load models.
    replace_dit_model_in_manager()

    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
        "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
        "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    ])
    pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

    # 2. Add per-block camera components and the FramePack embedder.
    #    cam_encoder starts at zero and projector at identity so the freshly
    #    added layers are initially a no-op before the checkpoint overwrites them.
    dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
    for block in pipe.dit.blocks:
        block.cam_encoder = nn.Linear(13, dim)
        block.projector = nn.Linear(dim, dim)
        block.cam_encoder.weight.data.zero_()
        block.cam_encoder.bias.data.zero_()
        block.projector.weight = nn.Parameter(torch.eye(dim))
        block.projector.bias = nn.Parameter(torch.zeros(dim))

    # Attach the FramePack clean_x_embedder (must exist before load_state_dict
    # with strict=True, since the checkpoint contains its weights).
    add_framepack_components(pipe.dit)

    # 3. Load the trained weights.
    dit_state_dict = torch.load(dit_path, map_location="cpu")
    pipe.dit.load_state_dict(dit_state_dict, strict=True)

    pipe = pipe.to(device)
    model_dtype = next(pipe.dit.parameters()).dtype
    pipe.dit = pipe.dit.to(dtype=model_dtype)
    if hasattr(pipe.dit, 'clean_x_embedder'):
        pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)

    pipe.scheduler.set_timesteps(50)
    print("Loading condition video from pth...")

    # 4. Load the pre-encoded condition video window.
    condition_latents, encoded_data = load_encoded_video_from_pth(
        condition_pth_path,
        start_frame=start_frame,
        num_frames=condition_frames
    )

    print("Preparing FramePack inputs...")

    # 5. Prepare FramePack-style multi-scale inputs from the full latents.
    full_latents = encoded_data['latents']
    framepack_inputs = prepare_framepack_inputs(
        full_latents, condition_frames, target_frames, start_frame
    )

    # Move all tensors to the target device/dtype so they match the DiT model.
    for key in framepack_inputs:
        if torch.is_tensor(framepack_inputs[key]):
            framepack_inputs[key] = framepack_inputs[key].to(device, dtype=model_dtype)

    print("Processing poses...")

    # 6. Build the camera pose embedding (real data when available, else synthetic).
    if use_real_poses and 'cam_emb' in encoded_data:
        print("Using real camera poses from data")
        camera_embedding = generate_camera_poses_from_data(
            encoded_data['cam_emb'],
            start_frame=start_frame,
            condition_frames=condition_frames,
            target_frames=target_frames
        )
    else:
        print(f"Using synthetic {direction} poses")
        camera_embedding = generate_synthetic_camera_poses(
            direction=direction,
            target_frames=target_frames,
            condition_frames=condition_frames
        )

    camera_embedding = camera_embedding.unsqueeze(0).to(device, dtype=model_dtype)
    print("Encoding prompt...")

    # 7. Encode the text prompt.
    prompt_emb = pipe.encode_prompt(prompt)
    print("Generating video...")

    # 8. Derive target-latent dimensions.
    batch_size = 1
    channels = framepack_inputs['latents'].shape[0]  # latents carry no batch dim at this point
    latent_height = framepack_inputs['latents'].shape[2]
    latent_width = framepack_inputs['latents'].shape[3]

    # Spatial center-crop to save memory.
    target_height, target_width = 60, 104

    if latent_height > target_height or latent_width > target_width:
        h_start = (latent_height - target_height) // 2
        w_start = (latent_width - target_width) // 2

        # Crop every latent input consistently.
        for key in ['latents', 'clean_latents', 'clean_latents_2x', 'clean_latents_4x']:
            if key in framepack_inputs and torch.is_tensor(framepack_inputs[key]):
                framepack_inputs[key] = framepack_inputs[key][:, :,
                                                              h_start:h_start+target_height,
                                                              w_start:w_start+target_width]

        latent_height = target_height
        latent_width = target_width

    # Add a batch dimension for inference.
    for key in ['latents', 'clean_latents', 'clean_latents_2x', 'clean_latents_4x']:
        if key in framepack_inputs and torch.is_tensor(framepack_inputs[key]):
            framepack_inputs[key] = framepack_inputs[key].unsqueeze(0)

    # Index tensors: batch them and force long dtype on CPU.
    for key in ['latent_indices', 'clean_latent_indices', 'clean_latent_2x_indices', 'clean_latent_4x_indices']:
        if key in framepack_inputs and torch.is_tensor(framepack_inputs[key]):
            # Indices must be long and live on the CPU.
            framepack_inputs[key] = framepack_inputs[key].long().cpu().unsqueeze(0)

    # Initialise target latents with Gaussian noise.
    target_latents = torch.randn(
        batch_size, channels, target_frames, latent_height, latent_width,
        device=device, dtype=model_dtype  # match the model's dtype
    )

    print(f"FramePack inputs:")
    for key, value in framepack_inputs.items():
        if torch.is_tensor(value):
            print(f" {key}: {value.shape} {value.dtype}")
        else:
            print(f" {key}: {value}")
    print(f"Camera embedding shape: {camera_embedding.shape}")
    print(f"Target latents shape: {target_latents.shape}")

    # 9. Extra inputs required by the pipeline.
    extra_input = pipe.prepare_extra_input(target_latents)

    # 10. FramePack-style denoising loop.
    timesteps = pipe.scheduler.timesteps

    for i, timestep in enumerate(timesteps):
        print(f"Denoising step {i+1}/{len(timesteps)}")

        timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)

        # FramePack-style forward call with multi-scale clean-latent conditioning.
        with torch.no_grad():
            noise_pred = pipe.dit(
                target_latents,
                timestep=timestep_tensor,
                cam_emb=camera_embedding,
                # FramePack conditioning arguments
                latent_indices=framepack_inputs['latent_indices'],
                clean_latents=framepack_inputs['clean_latents'],
                clean_latent_indices=framepack_inputs['clean_latent_indices'],
                clean_latents_2x=framepack_inputs['clean_latents_2x'],
                clean_latent_2x_indices=framepack_inputs['clean_latent_2x_indices'],
                clean_latents_4x=framepack_inputs['clean_latents_4x'],
                clean_latent_4x_indices=framepack_inputs['clean_latent_4x_indices'],
                **prompt_emb,
                **extra_input
            )

        # Scheduler step updates the target latents.
        target_latents = pipe.scheduler.step(noise_pred, timestep, target_latents)

    print("Decoding video...")

    # 11. Decode the final video: prepend the last clean frame as the condition.
    condition_for_decode = framepack_inputs['clean_latents'][:, :, -1:, :, :]  # last clean frame
    final_video = torch.cat([condition_for_decode, target_latents], dim=2)
    decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))

    # 12. Save the video: [-1, 1] float -> [0, 255] uint8, frame by frame.
    print(f"Saving video to {output_path}")

    video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
    video_np = (video_np * 0.5 + 0.5).clip(0, 1)
    video_np = (video_np * 255).astype(np.uint8)

    with imageio.get_writer(output_path, fps=20) as writer:
        for frame in video_np:
            writer.append_data(frame)

    print(f"FramePack video generation completed! Saved to {output_path}")
481
+
482
def main():
    """CLI entry point for FramePack-style Sekai inference from a .pth file.

    Parses arguments, derives an output path when none is given, and runs
    ``inference_sekai_framepack_from_pth``.
    """
    parser = argparse.ArgumentParser(description="Sekai FramePack Video Generation Inference from PTH")
    parser.add_argument("--condition_pth", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/sekai-game-walking/00100100001_0004650_0004950/encoded_video.pth")
    parser.add_argument("--start_frame", type=int, default=0,
                        help="Starting frame index (compressed latent frames)")
    parser.add_argument("--condition_frames", type=int, default=8,
                        help="Number of condition frames (compressed latent frames)")
    parser.add_argument("--target_frames", type=int, default=8,
                        help="Number of target frames to generate (compressed latent frames)")
    parser.add_argument("--direction", type=str, default="left_turn",
                        choices=["forward", "backward", "left_turn", "right_turn"],
                        help="Direction of camera movement (if not using real poses)")
    parser.add_argument("--use_real_poses", action="store_true", default=False,
                        help="Use real camera poses from data")
    parser.add_argument("--dit_path", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/sekai/sekai_walking_framepack/step24000_framepack.ckpt",
                        help="Path to trained FramePack DiT checkpoint")
    # BUG FIX: the original gave --output_path a concrete default path, which
    # made the auto-naming branch below unreachable dead code. Defaulting to
    # None restores the intended behavior; pass --output_path to override.
    parser.add_argument("--output_path", type=str, default=None,
                        help="Output video path (auto-generated from settings when omitted)")
    parser.add_argument("--prompt", type=str,
                        default="A drone flying scene in a game world",
                        help="Text prompt for generation")
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device to run inference on")

    args = parser.parse_args()

    # Derive an output filename from the input name and generation settings
    # when no explicit path was provided.
    if args.output_path is None:
        pth_filename = os.path.basename(args.condition_pth)
        name_parts = os.path.splitext(pth_filename)
        output_dir = "sekai/infer_framepack_results"
        os.makedirs(output_dir, exist_ok=True)

        if args.use_real_poses:
            output_filename = f"{name_parts[0]}_framepack_real_{args.start_frame}_{args.condition_frames}_{args.target_frames}.mp4"
        else:
            output_filename = f"{name_parts[0]}_framepack_{args.direction}_{args.start_frame}_{args.condition_frames}_{args.target_frames}.mp4"

        output_path = os.path.join(output_dir, output_filename)
    else:
        output_path = args.output_path

    print(f"🔧 FramePack Inference Settings:")
    print(f"Input pth: {args.condition_pth}")
    print(f"Start frame: {args.start_frame} (compressed)")
    print(f"Condition frames: {args.condition_frames} (compressed, original: {args.condition_frames * 4})")
    print(f"Target frames: {args.target_frames} (compressed, original: {args.target_frames * 4})")
    print(f"Use real poses: {args.use_real_poses}")
    print(f"Direction: {args.direction}")
    print(f"Output video will be saved to: {output_path}")

    inference_sekai_framepack_from_pth(
        condition_pth_path=args.condition_pth,
        dit_path=args.dit_path,
        output_path=output_path,
        start_frame=args.start_frame,
        condition_frames=args.condition_frames,
        target_frames=args.target_frames,
        device=args.device,
        prompt=args.prompt,
        direction=args.direction,
        use_real_poses=args.use_real_poses
    )


if __name__ == "__main__":
    main()
scripts/infer_spatialvid.py ADDED
@@ -0,0 +1,608 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import os
5
+ import json
6
+ import imageio
7
+ import argparse
8
+ from PIL import Image
9
+ from diffsynth import WanVideoReCamMasterPipeline, ModelManager
10
+ from torchvision.transforms import v2
11
+ from einops import rearrange
12
+ from scipy.spatial.transform import Rotation as R
13
+
14
def compute_relative_pose_matrix(pose1, pose2):
    """Compute the relative pose between two consecutive frames.

    Each pose is a length-7 array ``[tx, ty, tz, qx, qy, qz, qw]``
    (translation followed by an xyzw quaternion).

    The relative rotation is ``rot2 * rot1.inv()`` and the relative
    translation is ``R1^T @ (t2 - t1)``.

    NOTE(review): the rotation uses global-frame composition
    (``rot2 * rot1.inv()``) while the translation is expressed in frame 1's
    local coordinates (``R1^T (t2 - t1)``); a fully local relative pose
    would use ``rot1.inv() * rot2`` — confirm against the training code.

    Returns:
        A 3x4 numpy matrix ``[R_rel | t_rel]``.
    """
    # Split translations and quaternions.
    t_prev, q_prev = pose1[:3], pose1[3:]
    t_next, q_next = pose2[:3], pose2[3:]

    rot_prev = R.from_quat(q_prev)
    rot_next = R.from_quat(q_next)

    # Relative rotation: next rotation composed with the inverse of the
    # previous rotation.
    rel_rotation = (rot_next * rot_prev.inv()).as_matrix()

    # Translation delta rotated into the previous frame's coordinates.
    rel_translation = rot_prev.as_matrix().T @ (t_next - t_prev)

    # Assemble the 3x4 [R | t] matrix.
    return np.hstack([rel_rotation, rel_translation.reshape(3, 1)])
45
+
46
def load_encoded_video_from_pth(pth_path, start_frame=0, num_frames=10):
    """Load pre-encoded video latents from a .pth file and slice a window.

    Returns ``(condition_latents, encoded_data)`` where ``condition_latents``
    is ``latents[:, start_frame:start_frame + num_frames]`` in [C, T, H, W]
    layout and ``encoded_data`` is the full dict loaded from disk.

    Raises:
        ValueError: when the requested window exceeds the stored frame count.
    """
    print(f"Loading encoded video from {pth_path}")

    encoded_data = torch.load(pth_path, weights_only=False, map_location="cpu")
    full_latents = encoded_data['latents']  # [C, T, H, W]

    end_frame = start_frame + num_frames
    print(f"Full latents shape: {full_latents.shape}")
    print(f"Extracting frames {start_frame} to {end_frame}")

    # Guard against reading past the stored temporal extent.
    if end_frame > full_latents.shape[1]:
        raise ValueError(f"Not enough frames: requested {end_frame}, available {full_latents.shape[1]}")

    condition_latents = full_latents[:, start_frame:end_frame, :, :]
    print(f"Extracted condition latents shape: {condition_latents.shape}")

    return condition_latents, encoded_data
63
+
64
def replace_dit_model_in_manager():
    """Patch the model registry so 'wan_video_dit' loads WanModelFuture.

    Must run before the models are loaded: it rewrites every entry of
    ``model_loader_configs`` that registers the name ``wan_video_dit``,
    keeping names and all other classes unchanged.
    """
    from diffsynth.models.wan_video_dit_recam_future import WanModelFuture
    from diffsynth.configs.model_config import model_loader_configs

    # Rewrite matching registry entries in place.
    for idx, entry in enumerate(model_loader_configs):
        keys_hash, keys_hash_with_shape, names, classes, resource = entry

        if 'wan_video_dit' not in names:
            continue

        patched_names = []
        patched_classes = []
        for model_name, model_cls in zip(names, classes):
            patched_names.append(model_name)  # name stays the same
            if model_name == 'wan_video_dit':
                # Swap in the FramePack-capable class.
                patched_classes.append(WanModelFuture)
                print(f"✅ 替换了模型类: {model_name} -> WanModelFuture")
            else:
                patched_classes.append(model_cls)

        model_loader_configs[idx] = (keys_hash, keys_hash_with_shape, patched_names, patched_classes, resource)
90
+
91
def add_framepack_components(dit_model):
    """Attach FramePack's clean_x_embedder (design follows
    hunyuan_video_packed.py) to ``dit_model``.

    No-op when the model already has a ``clean_x_embedder``. The embedder
    exposes three Conv3d projections covering 1x/2x/4x spatio-temporal
    compression and is cast to the dtype of the model's parameters.
    """
    if hasattr(dit_model, 'clean_x_embedder'):
        return

    inner_dim = dit_model.blocks[0].self_attn.q.weight.shape[0]

    class CleanXEmbedder(nn.Module):
        def __init__(self, inner_dim):
            super().__init__()
            self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
            self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
            self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

        def forward(self, x, scale="1x"):
            # Select the projection for the requested compression scale.
            projections = {"1x": self.proj, "2x": self.proj_2x, "4x": self.proj_4x}
            if scale not in projections:
                raise ValueError(f"Unsupported scale: {scale}")
            return projections[scale](x)

    model_dtype = next(dit_model.parameters()).dtype
    dit_model.clean_x_embedder = CleanXEmbedder(inner_dim).to(dtype=model_dtype)
    print("✅ 添加了FramePack的clean_x_embedder组件")
118
+
119
def generate_spatialvid_camera_embeddings_sliding(cam_data, start_frame, current_history_length, new_frames, total_generated, use_real_poses=True):
    """Generate camera embeddings for the SpatialVid dataset (sliding-window version).

    Produces a [max_needed_frames, 13] bfloat16 tensor: 12 values of a
    flattened 3x4 relative pose per frame plus a trailing condition mask
    (1.0 for frames in [start_frame, start_frame + current_history_length)).
    The sequence is sized to cover both the requested history window and
    the fixed FramePack index layout (1 + 16 + 2 + 1 + new_frames), with a
    floor of 30 frames.

    Args:
        cam_data: Dict with an 'extrinsic' pose sequence, or None.
        start_frame: First frame of the condition window.
        current_history_length: Number of condition frames.
        new_frames: Frames to be generated this step.
        total_generated: NOTE(review): unused in this function — kept for
            call-site compatibility; confirm before removing.
        use_real_poses: Use real extrinsics when available, else synthetic.
    """
    # NOTE(review): assigned but never used below — presumably documents the
    # VAE's temporal compression; confirm before removing.
    time_compression_ratio = 4

    # Camera frames consumed by the FramePack index layout:
    # start(1) + 4x(16) + 2x(2) + 1x(1) + new frames.
    framepack_needed_frames = 1 + 16 + 2 + 1 + new_frames

    if use_real_poses and cam_data is not None and 'extrinsic' in cam_data:
        print("🔧 使用真实SpatialVid camera数据")
        cam_extrinsic = cam_data['extrinsic']

        # Make sure the generated camera sequence is long enough for both
        # the history window and the FramePack layout.
        max_needed_frames = max(
            start_frame + current_history_length + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 计算SpatialVid camera序列长度:")
        print(f" - 基础需求: {start_frame + current_history_length + new_frames}")
        print(f" - FramePack需求: {framepack_needed_frames}")
        print(f" - 最终生成: {max_needed_frames}")

        relative_poses = []
        for i in range(max_needed_frames):
            # SpatialVid-specific: stride of 1 frame rather than 4.
            frame_idx = i
            next_frame_idx = frame_idx + 1

            if next_frame_idx < len(cam_extrinsic):
                cam_prev = cam_extrinsic[frame_idx]
                cam_next = cam_extrinsic[next_frame_idx]
                relative_cam = compute_relative_pose_matrix(cam_prev, cam_next)
                relative_poses.append(torch.as_tensor(relative_cam[:3, :]))
            else:
                # Past the end of the camera data: fall back to zero motion.
                print(f"⚠️ 帧{frame_idx}超出camera数据范围,使用零运动")
                relative_poses.append(torch.zeros(3, 4))

        # Flatten each 3x4 pose to 12 values: [frames, 12].
        pose_embedding = torch.stack(relative_poses, dim=0)
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')

        # Condition mask of matching length.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        # Mark [start_frame, start_frame + current_history_length) as condition.
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 SpatialVid真实camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)

    else:
        print("🔧 使用SpatialVid合成camera数据")

        max_needed_frames = max(
            start_frame + current_history_length + new_frames,
            framepack_needed_frames,
            30
        )

        print(f"🔧 生成SpatialVid合成camera帧数: {max_needed_frames}")
        relative_poses = []
        for i in range(max_needed_frames):
            # SpatialVid indoor-walking pattern: slight left/right sway plus
            # steady forward motion.
            yaw_per_frame = 0.03 * np.sin(i * 0.1)  # left/right sway
            forward_speed = 0.008  # forward distance per frame

            pose = np.eye(4, dtype=np.float32)

            # Rotation: sway about the Y axis.
            cos_yaw = np.cos(yaw_per_frame)
            sin_yaw = np.sin(yaw_per_frame)

            pose[0, 0] = cos_yaw
            pose[0, 2] = sin_yaw
            pose[2, 0] = -sin_yaw
            pose[2, 2] = cos_yaw

            # Translation: forward plus a slight vertical bob.
            pose[2, 3] = -forward_speed  # negative local Z (forward)
            pose[1, 3] = 0.002 * np.sin(i * 0.15)  # slight vertical bob

            relative_pose = pose[:3, :]
            relative_poses.append(torch.as_tensor(relative_pose))

        # Flatten each 3x4 pose to 12 values: [frames, 12].
        pose_embedding = torch.stack(relative_poses, dim=0)
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')

        # Condition mask of matching length.
        mask = torch.zeros(max_needed_frames, 1, dtype=torch.float32)
        condition_end = min(start_frame + current_history_length, max_needed_frames)
        mask[start_frame:condition_end] = 1.0

        camera_embedding = torch.cat([pose_embedding, mask], dim=1)
        print(f"🔧 SpatialVid合成camera embedding shape: {camera_embedding.shape}")
        return camera_embedding.to(torch.bfloat16)
216
+
217
def prepare_framepack_sliding_window_with_camera(history_latents, target_frames_to_generate, camera_embedding_full, start_frame, max_history_frames=49):
    """Assemble FramePack conditioning tensors for one sliding-window step
    (SpatialVid version).

    Args:
        history_latents: [C, T, H, W] latents accumulated so far.
        target_frames_to_generate: Latent frames to denoise this step.
        camera_embedding_full: [N, D] camera sequence; zero-padded when
            shorter than the index layout, last channel rewritten as the
            condition mask.
        start_frame: NOTE(review): unused here — kept for call-site
            compatibility; confirm before removing.
        max_history_frames: NOTE(review): likewise unused in this body.

    Returns:
        Dict with latent tensors, their index tensors, the masked camera
        embedding, and the current/next history lengths.
    """
    C, T, H, W = history_latents.shape

    # Fixed index layout: [start(1) | 4x(16) | 2x(2) | 1x(1) | new frames].
    total_len = 1 + 16 + 2 + 1 + target_frames_to_generate
    all_indices = torch.arange(0, total_len)
    start_idx = all_indices[0:1]
    idx_4x = all_indices[1:17]
    idx_2x = all_indices[17:19]
    idx_1x = all_indices[19:20]
    latent_indices = all_indices[20:]
    clean_latent_indices = torch.cat([start_idx, idx_1x], dim=0)

    # Zero-pad the camera sequence when it is shorter than the layout.
    deficit = total_len - camera_embedding_full.shape[0]
    if deficit > 0:
        pad = torch.zeros(deficit, camera_embedding_full.shape[1],
                          dtype=camera_embedding_full.dtype, device=camera_embedding_full.device)
        camera_embedding_full = torch.cat([camera_embedding_full, pad], dim=0)

    # Take the slice backing this step's indices.
    combined_camera = camera_embedding_full[:total_len, :].clone()

    # Rebuild the condition/target mask (last channel): everything becomes
    # target first, then the rows backing valid clean latents become condition.
    combined_camera[:, -1] = 0.0

    if T > 0:
        available_frames = min(T, 19)
        start_pos = 19 - available_frames
        combined_camera[start_pos:19, -1] = 1.0

    print(f"🔧 SpatialVid Camera mask更新:")
    print(f" - 历史帧数: {T}")
    print(f" - 有效condition帧数: {available_frames if T > 0 else 0}")

    # Right-align up to the 19 most recent history frames in a fixed buffer;
    # leading positions stay zero when history is short.
    clean_latents_combined = torch.zeros(C, 19, H, W, dtype=history_latents.dtype, device=history_latents.device)

    if T > 0:
        available_frames = min(T, 19)
        start_pos = 19 - available_frames
        clean_latents_combined[:, start_pos:, :, :] = history_latents[:, -available_frames:, :, :]

    clean_latents_4x = clean_latents_combined[:, 0:16, :, :]
    clean_latents_2x = clean_latents_combined[:, 16:18, :, :]
    clean_latents_1x = clean_latents_combined[:, 18:19, :, :]

    # The very first history frame anchors the sequence start.
    if T > 0:
        start_latent = history_latents[:, 0:1, :, :]
    else:
        start_latent = torch.zeros(C, 1, H, W, dtype=history_latents.dtype, device=history_latents.device)

    clean_latents = torch.cat([start_latent, clean_latents_1x], dim=1)

    return {
        'latent_indices': latent_indices,
        'clean_latents': clean_latents,
        'clean_latents_2x': clean_latents_2x,
        'clean_latents_4x': clean_latents_4x,
        'clean_latent_indices': clean_latent_indices,
        'clean_latent_2x_indices': idx_2x,
        'clean_latent_4x_indices': idx_4x,
        'camera_embedding': combined_camera,
        'current_length': T,
        'next_length': T + target_frames_to_generate
    }
+ }
284
+
285
+ def inference_spatialvid_framepack_sliding_window(
286
+ condition_pth_path,
287
+ dit_path,
288
+ output_path="spatialvid_results/output_spatialvid_framepack_sliding.mp4",
289
+ start_frame=0,
290
+ initial_condition_frames=8,
291
+ frames_per_generation=4,
292
+ total_frames_to_generate=32,
293
+ max_history_frames=49,
294
+ device="cuda",
295
+ prompt="A man walking through indoor spaces with a first-person view",
296
+ use_real_poses=True,
297
+ # CFG参数
298
+ use_camera_cfg=True,
299
+ camera_guidance_scale=2.0,
300
+ text_guidance_scale=1.0
301
+ ):
302
+ """
303
+ SpatialVid FramePack滑动窗口视频生成
304
+ """
305
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
306
+ print(f"🔧 SpatialVid FramePack滑动窗口生成开始...")
307
+ print(f"Camera CFG: {use_camera_cfg}, Camera guidance scale: {camera_guidance_scale}")
308
+ print(f"Text guidance scale: {text_guidance_scale}")
309
+
310
+ # 1. 模型初始化
311
+ replace_dit_model_in_manager()
312
+
313
+ model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
314
+ model_manager.load_models([
315
+ "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
316
+ "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
317
+ "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
318
+ ])
319
+ pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")
320
+
321
+ # 2. 添加camera编码器
322
+ dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
323
+ for block in pipe.dit.blocks:
324
+ block.cam_encoder = nn.Linear(13, dim)
325
+ block.projector = nn.Linear(dim, dim)
326
+ block.cam_encoder.weight.data.zero_()
327
+ block.cam_encoder.bias.data.zero_()
328
+ block.projector.weight = nn.Parameter(torch.eye(dim))
329
+ block.projector.bias = nn.Parameter(torch.zeros(dim))
330
+
331
+ # 3. 添加FramePack组件
332
+ add_framepack_components(pipe.dit)
333
+
334
+ # 4. 加载训练好的权重
335
+ dit_state_dict = torch.load(dit_path, map_location="cpu")
336
+ pipe.dit.load_state_dict(dit_state_dict, strict=True)
337
+ pipe = pipe.to(device)
338
+ model_dtype = next(pipe.dit.parameters()).dtype
339
+
340
+ if hasattr(pipe.dit, 'clean_x_embedder'):
341
+ pipe.dit.clean_x_embedder = pipe.dit.clean_x_embedder.to(dtype=model_dtype)
342
+
343
+ pipe.scheduler.set_timesteps(50)
344
+
345
+ # 5. 加载初始条件
346
+ print("Loading initial condition frames...")
347
+ initial_latents, encoded_data = load_encoded_video_from_pth(
348
+ condition_pth_path,
349
+ start_frame=start_frame,
350
+ num_frames=initial_condition_frames
351
+ )
352
+
353
+ # 空间裁剪
354
+ target_height, target_width = 60, 104
355
+ C, T, H, W = initial_latents.shape
356
+
357
+ if H > target_height or W > target_width:
358
+ h_start = (H - target_height) // 2
359
+ w_start = (W - target_width) // 2
360
+ initial_latents = initial_latents[:, :, h_start:h_start+target_height, w_start:w_start+target_width]
361
+ H, W = target_height, target_width
362
+
363
+ history_latents = initial_latents.to(device, dtype=model_dtype)
364
+
365
+ print(f"初始history_latents shape: {history_latents.shape}")
366
+
367
+ # 6. 编码prompt - 支持CFG
368
+ if text_guidance_scale > 1.0:
369
+ prompt_emb_pos = pipe.encode_prompt(prompt)
370
+ prompt_emb_neg = pipe.encode_prompt("")
371
+ print(f"使用Text CFG,guidance scale: {text_guidance_scale}")
372
+ else:
373
+ prompt_emb_pos = pipe.encode_prompt(prompt)
374
+ prompt_emb_neg = None
375
+ print("不使用Text CFG")
376
+
377
+ # 7. 预生成完整的camera embedding序列
378
+ camera_embedding_full = generate_spatialvid_camera_embeddings_sliding(
379
+ encoded_data.get('cam_emb', None),
380
+ 0,
381
+ max_history_frames,
382
+ 0,
383
+ 0,
384
+ use_real_poses=use_real_poses
385
+ ).to(device, dtype=model_dtype)
386
+
387
+ print(f"完整camera序列shape: {camera_embedding_full.shape}")
388
+
389
+ # 8. 为Camera CFG创建无条件的camera embedding
390
+ if use_camera_cfg:
391
+ camera_embedding_uncond = torch.zeros_like(camera_embedding_full)
392
+ print(f"创建无条件camera embedding用于CFG")
393
+
394
+ # 9. 滑动窗口生成循环
395
+ total_generated = 0
396
+ all_generated_frames = []
397
+
398
+ while total_generated < total_frames_to_generate:
399
+ current_generation = min(frames_per_generation, total_frames_to_generate - total_generated)
400
+ print(f"\n🔧 生成步骤 {total_generated // frames_per_generation + 1}")
401
+ print(f"当前历史长度: {history_latents.shape[1]}, 本次生成: {current_generation}")
402
+
403
+ # FramePack数据准备 - SpatialVid版本
404
+ framepack_data = prepare_framepack_sliding_window_with_camera(
405
+ history_latents,
406
+ current_generation,
407
+ camera_embedding_full,
408
+ start_frame,
409
+ max_history_frames
410
+ )
411
+
412
+ # 准备输入
413
+ clean_latents = framepack_data['clean_latents'].unsqueeze(0)
414
+ clean_latents_2x = framepack_data['clean_latents_2x'].unsqueeze(0)
415
+ clean_latents_4x = framepack_data['clean_latents_4x'].unsqueeze(0)
416
+ camera_embedding = framepack_data['camera_embedding'].unsqueeze(0)
417
+
418
+ # 为CFG准备无条件camera embedding
419
+ if use_camera_cfg:
420
+ camera_embedding_uncond_batch = camera_embedding_uncond[:camera_embedding.shape[1], :].unsqueeze(0)
421
+
422
+ # 索引处理
423
+ latent_indices = framepack_data['latent_indices'].unsqueeze(0).cpu()
424
+ clean_latent_indices = framepack_data['clean_latent_indices'].unsqueeze(0).cpu()
425
+ clean_latent_2x_indices = framepack_data['clean_latent_2x_indices'].unsqueeze(0).cpu()
426
+ clean_latent_4x_indices = framepack_data['clean_latent_4x_indices'].unsqueeze(0).cpu()
427
+
428
+ # 初始化要生成的latents
429
+ new_latents = torch.randn(
430
+ 1, C, current_generation, H, W,
431
+ device=device, dtype=model_dtype
432
+ )
433
+
434
+ extra_input = pipe.prepare_extra_input(new_latents)
435
+
436
+ print(f"Camera embedding shape: {camera_embedding.shape}")
437
+ print(f"Camera mask分布 - condition: {torch.sum(camera_embedding[0, :, -1] == 1.0).item()}, target: {torch.sum(camera_embedding[0, :, -1] == 0.0).item()}")
438
+
439
+ # 去噪循环 - 支持CFG
440
+ timesteps = pipe.scheduler.timesteps
441
+
442
+ for i, timestep in enumerate(timesteps):
443
+ if i % 10 == 0:
444
+ print(f" 去噪步骤 {i}/{len(timesteps)}")
445
+
446
+ timestep_tensor = timestep.unsqueeze(0).to(device, dtype=model_dtype)
447
+
448
+ with torch.no_grad():
449
+ # 正向预测(带条件)
450
+ noise_pred_pos = pipe.dit(
451
+ new_latents,
452
+ timestep=timestep_tensor,
453
+ cam_emb=camera_embedding,
454
+ latent_indices=latent_indices,
455
+ clean_latents=clean_latents,
456
+ clean_latent_indices=clean_latent_indices,
457
+ clean_latents_2x=clean_latents_2x,
458
+ clean_latent_2x_indices=clean_latent_2x_indices,
459
+ clean_latents_4x=clean_latents_4x,
460
+ clean_latent_4x_indices=clean_latent_4x_indices,
461
+ **prompt_emb_pos,
462
+ **extra_input
463
+ )
464
+
465
+ # CFG处理
466
+ if use_camera_cfg and camera_guidance_scale > 1.0:
467
+ # 无条件预测(无camera条件)
468
+ noise_pred_uncond = pipe.dit(
469
+ new_latents,
470
+ timestep=timestep_tensor,
471
+ cam_emb=camera_embedding_uncond_batch,
472
+ latent_indices=latent_indices,
473
+ clean_latents=clean_latents,
474
+ clean_latent_indices=clean_latent_indices,
475
+ clean_latents_2x=clean_latents_2x,
476
+ clean_latent_2x_indices=clean_latent_2x_indices,
477
+ clean_latents_4x=clean_latents_4x,
478
+ clean_latent_4x_indices=clean_latent_4x_indices,
479
+ **prompt_emb_pos,
480
+ **extra_input
481
+ )
482
+
483
+ # Camera CFG
484
+ noise_pred = noise_pred_uncond + camera_guidance_scale * (noise_pred_pos - noise_pred_uncond)
485
+ else:
486
+ noise_pred = noise_pred_pos
487
+
488
+ # Text CFG
489
+ if prompt_emb_neg is not None and text_guidance_scale > 1.0:
490
+ noise_pred_neg = pipe.dit(
491
+ new_latents,
492
+ timestep=timestep_tensor,
493
+ cam_emb=camera_embedding,
494
+ latent_indices=latent_indices,
495
+ clean_latents=clean_latents,
496
+ clean_latent_indices=clean_latent_indices,
497
+ clean_latents_2x=clean_latents_2x,
498
+ clean_latent_2x_indices=clean_latent_2x_indices,
499
+ clean_latents_4x=clean_latents_4x,
500
+ clean_latent_4x_indices=clean_latent_4x_indices,
501
+ **prompt_emb_neg,
502
+ **extra_input
503
+ )
504
+
505
+ noise_pred = noise_pred_neg + text_guidance_scale * (noise_pred - noise_pred_neg)
506
+
507
+ new_latents = pipe.scheduler.step(noise_pred, timestep, new_latents)
508
+
509
+ # 更新历史
510
+ new_latents_squeezed = new_latents.squeeze(0)
511
+ history_latents = torch.cat([history_latents, new_latents_squeezed], dim=1)
512
+
513
+ # 维护滑动窗口
514
+ if history_latents.shape[1] > max_history_frames:
515
+ first_frame = history_latents[:, 0:1, :, :]
516
+ recent_frames = history_latents[:, -(max_history_frames-1):, :, :]
517
+ history_latents = torch.cat([first_frame, recent_frames], dim=1)
518
+ print(f"历史窗口已满,保留第一帧+最新{max_history_frames-1}帧")
519
+
520
+ print(f"更新后history_latents shape: {history_latents.shape}")
521
+
522
+ all_generated_frames.append(new_latents_squeezed)
523
+ total_generated += current_generation
524
+
525
+ print(f"✅ 已生成 {total_generated}/{total_frames_to_generate} 帧")
526
+
527
+ # 10. 解码和保存
528
+ print("\n🔧 解码生成的视频...")
529
+
530
+ all_generated = torch.cat(all_generated_frames, dim=1)
531
+ final_video = torch.cat([initial_latents.to(all_generated.device), all_generated], dim=1).unsqueeze(0)
532
+
533
+ print(f"最终视频shape: {final_video.shape}")
534
+
535
+ decoded_video = pipe.decode_video(final_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16))
536
+
537
+ print(f"Saving video to {output_path}")
538
+
539
+ video_np = decoded_video[0].to(torch.float32).permute(1, 2, 3, 0).cpu().numpy()
540
+ video_np = (video_np * 0.5 + 0.5).clip(0, 1)
541
+ video_np = (video_np * 255).astype(np.uint8)
542
+
543
+ with imageio.get_writer(output_path, fps=20) as writer:
544
+ for frame in video_np:
545
+ writer.append_data(frame)
546
+
547
+ print(f"🔧 SpatialVid FramePack滑动窗口生成完成! 保存到: {output_path}")
548
+ print(f"总共生成了 {total_generated} 帧 (压缩后), 对应原始 {total_generated * 4} 帧")
549
+
550
def main():
    """Parse command-line arguments and launch SpatialVid FramePack
    sliding-window video generation.

    All heavy lifting is delegated to
    ``inference_spatialvid_framepack_sliding_window``; this function only
    wires CLI options through to it and prints a summary of the CFG setup.
    """
    parser = argparse.ArgumentParser(description="SpatialVid FramePack滑动窗口视频生成")

    # Basic parameters
    parser.add_argument("--condition_pth", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/spatialvid/a9a6d37f-0a6c-548a-a494-7d902469f3f2_0000000_0000300/encoded_video.pth",
                        help="输入编码视频路径")
    parser.add_argument("--start_frame", type=int, default=0)
    parser.add_argument("--initial_condition_frames", type=int, default=16)
    parser.add_argument("--frames_per_generation", type=int, default=8)
    parser.add_argument("--total_frames_to_generate", type=int, default=16)
    parser.add_argument("--max_history_frames", type=int, default=100)
    # BUG FIX: the original used action="store_true" with default=True, which
    # makes the flag a permanent no-op (it can never be disabled from the CLI).
    # BooleanOptionalAction keeps the default and the positive flag working
    # exactly as before, and additionally accepts --no-use_real_poses.
    parser.add_argument("--use_real_poses", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--dit_path", type=str,
                        default="/share_zhuyixuan05/zhuyixuan05/ICLR2026/spatialvid/spatialvid_framepack_random/step50.ckpt",
                        help="训练好的模型权重路径")
    parser.add_argument("--output_path", type=str,
                        default='spatialvid_results/output_spatialvid_framepack_sliding.mp4')
    parser.add_argument("--prompt", type=str,
                        default="A man walking through indoor spaces with a first-person view")
    parser.add_argument("--device", type=str, default="cuda")

    # CFG parameters
    # Same store_true/default=True fix as above: --no-use_camera_cfg now
    # actually disables Camera CFG.
    parser.add_argument("--use_camera_cfg", action=argparse.BooleanOptionalAction, default=True,
                        help="使用Camera CFG")
    parser.add_argument("--camera_guidance_scale", type=float, default=2.0,
                        help="Camera guidance scale for CFG")
    parser.add_argument("--text_guidance_scale", type=float, default=1.0,
                        help="Text guidance scale for CFG")

    args = parser.parse_args()

    print(f"🔧 SpatialVid FramePack CFG生成设置:")
    print(f"Camera CFG: {args.use_camera_cfg}")
    if args.use_camera_cfg:
        print(f"Camera guidance scale: {args.camera_guidance_scale}")
    print(f"Text guidance scale: {args.text_guidance_scale}")
    print(f"SpatialVid特有特性: camera间隔为1帧")

    inference_spatialvid_framepack_sliding_window(
        condition_pth_path=args.condition_pth,
        dit_path=args.dit_path,
        output_path=args.output_path,
        start_frame=args.start_frame,
        initial_condition_frames=args.initial_condition_frames,
        frames_per_generation=args.frames_per_generation,
        total_frames_to_generate=args.total_frames_to_generate,
        max_history_frames=args.max_history_frames,
        device=args.device,
        prompt=args.prompt,
        use_real_poses=args.use_real_poses,
        # CFG parameters
        use_camera_cfg=args.use_camera_cfg,
        camera_guidance_scale=args.camera_guidance_scale,
        text_guidance_scale=args.text_guidance_scale
    )
606
+
607
# Script entry point: run generation only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()