bazaar-research commited on
Commit
0a7036f
·
verified ·
1 Parent(s): 90764a1

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +3 -0
  2. .gitignore +2 -0
  3. .vscode/launch.json +58 -0
  4. INSTALL.md +55 -0
  5. LICENSE.txt +201 -0
  6. LingBot_VA_paper.pdf +3 -0
  7. Makefile +5 -0
  8. README.md +371 -0
  9. assets/teaser.mp4 +3 -0
  10. assets/teaser_v3.png +3 -0
  11. debug/place_fan/call1_reset.msgpack +3 -0
  12. debug/place_fan/call2.msgpack +3 -0
  13. debug/place_fan/call3.msgpack +3 -0
  14. evaluation/robotwin/calc_stat.py +132 -0
  15. evaluation/robotwin/eval_polict_client_openpi.py +696 -0
  16. evaluation/robotwin/geometry.py +463 -0
  17. evaluation/robotwin/launch_client.sh +40 -0
  18. evaluation/robotwin/launch_client_multigpus.sh +81 -0
  19. evaluation/robotwin/launch_server.sh +15 -0
  20. evaluation/robotwin/launch_server_multigpus.sh +31 -0
  21. evaluation/robotwin/msgpack_numpy.py +57 -0
  22. evaluation/robotwin/test_render.py +81 -0
  23. evaluation/robotwin/websocket_client_policy.py +108 -0
  24. example/franka/observation.images.cam_high.png +0 -0
  25. example/franka/observation.images.cam_left_wrist.png +0 -0
  26. example/franka/observation.images.cam_right_wrist.png +0 -0
  27. example/robotwin/observation.images.cam_high.png +0 -0
  28. example/robotwin/observation.images.cam_left_wrist.png +0 -0
  29. example/robotwin/observation.images.cam_right_wrist.png +0 -0
  30. lingbot_robotwin_policy.py +506 -0
  31. pyproject.toml +61 -0
  32. requirements.txt +11 -0
  33. script/run_launch_va_server_sync.sh +34 -0
  34. wan_va/__init__.py +2 -0
  35. wan_va/configs/__init__.py +12 -0
  36. wan_va/configs/shared_config.py +13 -0
  37. wan_va/configs/va_franka_cfg.py +59 -0
  38. wan_va/configs/va_franka_i2va.py +11 -0
  39. wan_va/configs/va_robotwin_cfg.py +54 -0
  40. wan_va/configs/va_robotwin_i2va.py +11 -0
  41. wan_va/distributed/__init__.py +1 -0
  42. wan_va/distributed/fsdp.py +42 -0
  43. wan_va/distributed/util.py +29 -0
  44. wan_va/modules/__init__.py +7 -0
  45. wan_va/modules/model.py +580 -0
  46. wan_va/modules/utils.py +95 -0
  47. wan_va/utils/Simple_Remote_Infer/LEGAL.md +7 -0
  48. wan_va/utils/Simple_Remote_Infer/README.md +16 -0
  49. wan_va/utils/Simple_Remote_Infer/deploy/__init__.py +0 -0
  50. wan_va/utils/Simple_Remote_Infer/deploy/image_tools.py +66 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ LingBot_VA_paper.pdf filter=lfs diff=lfs merge=lfs -text
37
+ assets/teaser.mp4 filter=lfs diff=lfs merge=lfs -text
38
+ assets/teaser_v3.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *.pyc
2
+ visualization/
.vscode/launch.json ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ // Use IntelliSense to learn about possible attributes.
3
+ // Hover to view descriptions of existing attributes.
4
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5
+ "version": "0.2.0",
6
+ "configurations": [
7
+
8
+ {
9
+ "name": "Wan Server",
10
+ "type": "debugpy",
11
+ "request": "launch",
12
+ "program": "${file}",
13
+ "console": "integratedTerminal",
14
+ "justMyCode": false,
15
+ "args": [
16
+ "--config-name",
17
+ "robotwin",
18
+ "--port",
19
+ "29056",
20
+ "--save_root",
21
+ "visualization/",
22
+ "--debug_infer_once"
23
+ ],
24
+ "env": {
25
+ "CUDA_VISIBLE_DEVICES": "6"
26
+ }
27
+ },
28
+ {
29
+ "name": "Robotwin Client",
30
+ "type": "debugpy",
31
+ "request": "launch",
32
+ "module": "evaluation.robotwin.eval_polict_client_openpi",
33
+ "console": "integratedTerminal",
34
+ "justMyCode": false,
35
+ "cwd": "${workspaceFolder}",
36
+ "args": [
37
+ "--config", "policy/ACT/deploy_policy.yml",
38
+ "--overrides",
39
+ "--task_name", "adjust_bottle",
40
+ "--task_config", "demo_clean",
41
+ "--train_config_name", "0",
42
+ "--model_name", "0",
43
+ "--ckpt_setting", "0",
44
+ "--seed", "0",
45
+ "--policy_name", "ACT",
46
+ "--save_root", "./results",
47
+ "--video_guidance_scale", "5",
48
+ "--action_guidance_scale", "1",
49
+ "--test_num", "100",
50
+ "--port", "29056"
51
+ ],
52
+ "env": {
53
+ "LD_LIBRARY_PATH": "/usr/lib64:/usr/lib:${env:LD_LIBRARY_PATH}",
54
+ "XLA_PYTHON_CLIENT_MEM_FRACTION": "0.9"
55
+ }
56
+ }
57
+ ]
58
+ }
INSTALL.md ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Installation Guide
2
+
3
+ ## Install with pip
4
+
5
+ ```bash
6
+ pip install .
7
+ pip install .[dev] # Also installs the dev tools
8
+ ```
9
+
10
+ ## Install with Poetry
11
+
12
+ Ensure you have [Poetry](https://python-poetry.org/docs/#installation) installed on your system.
13
+
14
+ To install all dependencies:
15
+
16
+ ```bash
17
+ poetry install
18
+ ```
19
+
20
+ ### Handling `flash-attn` Installation Issues
21
+
22
+ If `flash-attn` fails due to **PEP 517 build issues**, you can try one of the following fixes.
23
+
24
+ #### No-Build-Isolation Installation (Recommended)
25
+ ```bash
26
+ poetry run pip install --upgrade pip setuptools wheel
27
+ poetry run pip install flash-attn --no-build-isolation
28
+ poetry install
29
+ ```
30
+
31
+ #### Install from Git (Alternative)
32
+ ```bash
33
+ poetry run pip install git+https://github.com/Dao-AILab/flash-attention.git
34
+ ```
35
+
36
+ ---
37
+
38
+ ### Running the Model
39
+
40
+ Once the installation is complete, you can run **Wan2.2** using:
41
+
42
+ ```bash
43
+ poetry run python generate.py --task t2v-A14B --size '1280*720' --ckpt_dir ./Wan2.2-T2V-A14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
44
+ ```
45
+
46
+ #### Test
47
+ ```bash
48
+ bash tests/test.sh
49
+ ```
50
+
51
+ #### Format
52
+ ```bash
53
+ black .
54
+ isort .
55
+ ```
LICENSE.txt ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
LingBot_VA_paper.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e791faff04ff10eccb62eef7952eabbcb2c654abc4d73f4b4a8d1f683a6e48ba
3
+ size 7834795
Makefile ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ .PHONY: format
2
+
3
+ format:
4
+ isort wan_va
5
+ yapf -i -r *.py wan_va
README.md ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <h1 align="center">LingBot-VA: Causal World Modeling for Robot Control</h1>
2
+
3
+ <p align="center">
4
+ <a href="https://arxiv.org/abs/2601.21998"><img src="https://img.shields.io/static/v1?label=Paper&message=PDF&color=red&logo=arxiv"></a>
5
+ <a href="https://technology.robbyant.com/lingbot-va"><img src="https://img.shields.io/badge/Project-Website-blue"></a>
6
+ <a href="https://huggingface.co/collections/robbyant/lingbot-va"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Model&message=HuggingFace&color=orange"></a>
7
+ <a href="https://modelscope.cn/collections/Robbyant/LingBot-VA"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%96%20Model&message=ModelScope&color=purple"></a>
8
+ <a href="LICENSE.txt"><img src="https://img.shields.io/badge/License-Apache--2.0-green"></a>
9
+ </p>
10
+
11
+ <p align="center">
12
+ <img src="assets/teaser_v3.png" width="100%">
13
+ </p>
14
+
15
+
16
+
17
+ https://github.com/user-attachments/assets/cec7b7a6-953b-4fa4-8f1a-47efc1fce547
18
+
19
+
20
+
21
+
22
+ ## 💫 Meet **LingBot-VA**! We've built an AR diffusion framework for simultaneous world modeling and action! 🤖✨
23
+
24
+ **LingBot-VA** focuses on:
25
+ - **Autoregressive Video-Action World Modeling**: Architecturally unifies visual dynamics prediction and action inference within a single interleaved sequence while maintaining their conceptual distinction.
26
+ - **High-efficiency Execution**: A dual-stream mixture-of-transformers (MoT) architecture with Asynchronous Execution and KV Cache.
27
+ - **Long-Horizon Performance and Generalization**: High improvements in sample efficiency, long-horizon success rates, and generalization to novel scenes.
28
+
29
+ # 🚀 News
30
+ - **[2026-01-29]** Weights and code for shared backbone released! Please stay tuned for our separated version!
31
+
32
+
33
+
34
+
35
+ ---
36
+
37
+
38
+
39
+ # 📦 Model Download
40
+ - **Pretrained Checkpoints for Post-Training**
41
+
42
+ | Model Name | Huggingface Repository | ModelScope Repository | Description |
43
+ | :--- | :--- | :--- | :--- |
44
+ | lingbot-va-base &nbsp; | [🤗 robbyant/lingbot-va-base &nbsp;](https://huggingface.co/robbyant/lingbot-va-base) | [🤖 Robbyant/lingbot-va-base &nbsp;](https://modelscope.cn/models/Robbyant/lingbot-va-base) | LingBot-VA w/ shared backbone|
45
+ | lingbot-va-posttrain-robotwin &nbsp; | [🤗 robbyant/lingbot-va-posttrain-robotwin &nbsp;](https://huggingface.co/robbyant/lingbot-va-posttrain-robotwin) | [🤖 Robbyant/lingbot-va-posttrain-robotwin &nbsp;](https://modelscope.cn/models/Robbyant/lingbot-va-posttrain-robotwin) | LingBot-VA-Posttrain-Robotwin w/ shared backbone|
46
+ ---
47
+
48
+ # 🛠️ Quick Start
49
+
50
+ ## Installation
51
+ **Requirements**
52
+ • Python == 3.10.16
53
+ • Pytorch == 2.9.0
54
+ • CUDA 12.6
55
+
56
+ ```bash
57
+ pip install torch==2.9.0 torchvision==0.24.0 torchaudio==2.9.0 --index-url https://download.pytorch.org/whl/cu126
58
+ pip install websockets einops diffusers==0.36.0 transformers==4.55.2 accelerate msgpack opencv-python matplotlib ftfy easydict
59
+ pip install flash-attn --no-build-isolation
60
+ ```
61
+
62
+
63
+ ## Deploying LingBot-VA for Inference
64
+ LingBot-VA supports both standalone execution and Server-Client architecture which separates the model environment from simulation. By isolating dependencies, the design avoids package clashes and supports distributed inference on GPUs, clusters, and other devices.
65
+
66
+ <!-- ### Standalone Inference
67
+ ```python
68
+ python inference.py
69
+ ```
70
+ This processes the example data from `examples/0/` and saves visualizations to `result/`. -->
71
+
72
+ ### Evaluation on RoboTwin-2.0
73
+
74
+ **Preparing the Environment**
75
+
76
+ You can follow the official instructions from the original RoboTwin-2.0 repository:
77
+ [https://robotwin-platform.github.io/doc/usage/robotwin-install.html](https://robotwin-platform.github.io/doc/usage/robotwin-install.html)
78
+
79
+
80
+ In summary:
81
+
82
+ 1.
83
+ ```bash
84
+ sudo apt install libvulkan1 mesa-vulkan-drivers vulkan-tools
85
+ ```
86
+
87
+ 2.
88
+ ```bash
89
+ git clone https://github.com/RoboTwin-Platform/RoboTwin.git && cd RoboTwin
90
+ ```
91
+
92
+ 3. modify script/requirements.txt
93
+ ```bash
94
+ transforms3d==0.4.2
95
+ sapien==3.0.0b1
96
+ scipy==1.10.1
97
+ mplib==0.2.1
98
+ gymnasium==0.29.1
99
+ trimesh==4.4.3
100
+ open3d==0.18.0
101
+ imageio==2.34.2
102
+ pydantic
103
+ zarr
104
+ openai
105
+ huggingface_hub==0.36.2
106
+ h5py
107
+ # For Description Generation
108
+ azure==4.0.0
109
+ azure-ai-inference
110
+ pyglet<2
111
+ wandb
112
+ moviepy
113
+ imageio
114
+ termcolor
115
+ av
116
+ matplotlib
117
+ ffmpeg
118
+ ```
119
+
120
+ 4. modify line 8 of script/_install.sh:
121
+ ```bash
122
+ pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable" --no-build-isolation
123
+ ```
124
+
125
+ 5. run:
126
+ ```bash
127
+ bash script/_install.sh
128
+ ```
129
+
130
+ 6. run:
131
+ ```bash
132
+ bash script/_download_assets.sh
133
+ ```
134
+
135
+ **Deploying the Inference Server**
136
+ ```bash
137
+ # single GPU
138
+ bash evaluation/robotwin/launch_server.sh
139
+
140
+ # multi-GPU
141
+ bash evaluation/robotwin/launch_server_multigpus.sh
142
+ ```
143
+
144
+ **Executing the Inference Client**
145
+ ```bash
146
+ # single GPU
147
+ task_name="adjust_bottle";
148
+ save_root="results/";
149
+ bash evaluation/robotwin/launch_client.sh ${save_root} ${task_name}
150
+
151
+ # multi-GPU
152
+ save_root="results/"
153
+ task_group_id=0;
154
+ bash evaluation/robotwin/launch_client_multigpus.sh ${save_root} ${task_group_id}
155
+ ```
156
+
157
+ Related experiment results will be saved in `/path/to/your/RoboTwin/${save_root}`. Please note that an `eval_result` folder is also generated. This is a native output from RoboTwin and is identical to the contents in the results folder; it can be safely ignored.
158
+ It is important to note that the inference server and client must be deployed on the same machine. For launching multi-GPU client, we padded the original 50 tasks to 56 via duplication and partitioned them into 7 groups to align with the 8-GPU configuration of our inference node. You can specify the `task_group_id` (0-6) to select a particular group for inference. For detailed grouping configurations, please refer to `evaluation/robotwin/launch_client_multigpus.sh`.
159
+
160
+ ### Run Image to Video-Action Generation
161
+
162
+ We also provide a script for image to video-action generation:
163
+
164
+ ```bash
165
+ NGPU=1 CONFIG_NAME='robotwin_i2va' bash script/run_launch_va_server_sync.sh
166
+ ```
167
+
168
+
169
+
170
+ ---
171
+
172
+ # 📊 Performance
173
+
174
+ We evaluate our model on both simulation benchmarks and real-world scenarios, and achieve state-of-the-art performance.
175
+
176
+ ## Simulation Evaluation
177
+
178
+ - **RoboTwin 2.0**
179
+
180
+ We are the first to push RoboTwin 2.0 performance past the 90% threshold!
181
+ <table style="border-collapse: collapse; width: auto; font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Arial, sans-serif; font-size: 13px; line-height: 1.2;">
182
+ <!-- Metric notes -->
183
+ <p style="font-size: 12px; color: #666; margin-bottom: 5px;">* All metrics are reported in percentage (%). Higher values are <b>bolded</b>.</p>
184
+ <thead>
185
+ <tr style="border-top: 2px solid black; border-bottom: 1px solid black;">
186
+ <th align="left" style="padding: 6px 12px; white-space: nowrap;">Method (Average 50 Tasks)</th>
187
+ <th align="center" style="padding: 6px 12px;">Easy SR (%)</th>
188
+ <th align="center" style="padding: 6px 12px;">Hard SR (%)</th>
189
+ </tr>
190
+ </thead>
191
+ <tbody>
192
+ <tr>
193
+ <td style="padding: 4px 12px; white-space: nowrap;">X-VLA</td>
194
+ <td align="center">72.9</td>
195
+ <td align="center">72.8</td>
196
+ </tr>
197
+ <tr>
198
+ <td style="padding: 4px 12px; white-space: nowrap;">&pi;<sub>0</sub></td>
199
+ <td align="center">65.9</td>
200
+ <td align="center">58.4</td>
201
+ </tr>
202
+ <tr>
203
+ <td style="padding: 4px 12px; white-space: nowrap;">&pi;<sub>0.5</sub></td>
204
+ <td align="center">82.7</td>
205
+ <td align="center">76.8</td>
206
+ </tr>
207
+ <tr>
208
+ <td style="padding: 4px 12px; white-space: nowrap;">Motus</td>
209
+ <td align="center"><u>88.7</u></td>
210
+ <td align="center"><u>87.0</u></td>
211
+ </tr>
212
+ <tr style="border-top: 1px solid black; border-bottom: 2px solid black;">
213
+ <td style="padding: 6px 12px; white-space: nowrap;"><b>LingBot-VA (Ours)</b></td>
214
+ <td align="center"><b>92.9</b> <small>(+4.2)</small></td>
215
+ <td align="center"><b>91.6</b> <small>(+4.6)</small></td>
216
+ </tr>
217
+ </tbody>
218
+ </table>
219
+
220
+
221
+ - **LIBERO**
222
+
223
+ <table style="border-collapse: collapse; width: auto; font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Arial, sans-serif; font-size: 13px; line-height: 1.2;">
224
+ <!-- Metric notes -->
225
+ <p style="font-size: 12px; color: #666; margin-bottom: 5px;">* All metrics are reported in percentage (%). Higher values are <b>bolded</b>.</p>
226
+ <thead>
227
+ <tr style="border-top: 2px solid black; border-bottom: 1px solid black;">
228
+ <th align="left" style="padding: 6px 10px; border-right: 1px solid black; white-space: nowrap;">Methods</th>
229
+ <th align="center" style="padding: 6px 8px;">Spatial</th>
230
+ <th align="center" style="padding: 6px 8px;">Object</th>
231
+ <th align="center" style="padding: 6px 8px;">Goal</th>
232
+ <th align="center" style="padding: 6px 8px;">Long</th>
233
+ <th align="center" style="padding: 6px 8px;">Avg</th>
234
+ </tr>
235
+ </thead>
236
+ <tbody>
237
+ <tr>
238
+ <td style="padding: 4px 10px; border-right: 1px solid black; white-space: nowrap;">&pi;<sub>0</sub></td>
239
+ <td align="center">96.8</td><td align="center">98.8</td><td align="center">95.8</td><td align="center">85.2</td><td align="center">94.1</td>
240
+ </tr>
241
+ <tr>
242
+ <td style="padding: 4px 10px; border-right: 1px solid black; white-space: nowrap;">&pi;<sub>0.5</sub></td>
243
+ <td align="center">98.8</td><td align="center">98.2</td><td align="center">98.0</td><td align="center">92.4</td><td align="center">96.9</td>
244
+ </tr>
245
+ <tr>
246
+ <td style="padding: 4px 10px; border-right: 1px solid black; white-space: nowrap;">OpenVLA</td>
247
+ <td align="center">84.7</td><td align="center">88.4</td><td align="center">79.2</td><td align="center">53.7</td><td align="center">76.5</td>
248
+ </tr>
249
+ <tr>
250
+ <td style="padding: 4px 10px; border-right: 1px solid black; white-space: nowrap;">X-VLA</td>
251
+ <td align="center">98.2</td><td align="center">98.6</td><td align="center">97.8</td><td align="center">97.6</td><td align="center">98.1</td>
252
+ </tr>
253
+ <tr style="border-top: 1.5px solid black; border-bottom: 2px solid black;">
254
+ <td style="padding: 5px 10px; border-right: 1px solid black; white-space: nowrap;"><b>LingBot-VA (Ours)</b></td>
255
+ <td align="center"><b>98.5 &plusmn; 0.3</b></td>
256
+ <td align="center"><b>99.6 &plusmn; 0.3</b></td>
257
+ <td align="center"><b>97.2 &plusmn; 0.2</b></td>
258
+ <td align="center"><b>98.5 &plusmn; 0.5</b></td>
259
+ <td align="center"><b>98.5</b></td>
260
+ </tr>
261
+ </tbody>
262
+ </table>
263
+
264
+
265
+
266
+ &nbsp;
267
+
268
+ ## Real-world Deployment
269
+
270
+ Six manipulation tasks across three categories: long-horizon tasks (Make Breakfast, Pick Screws), precision tasks (Insert Tube, Unpack Delivery), and deformable & articulated object
271
+ manipulation (Fold Clothes, Fold Pants). Our method achieves state-of-the-art performance on both metrics (Progress Rate and Success Rate) with <b>only 50 trials</b> per task, substantially outperforming strong baseline &pi;<sub>0.5</sub>.
272
+
273
+ <div style="text-align: left; font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Arial, sans-serif; line-height: 1.6;">
274
+
275
+ <!-- Part 1: Progress Score (PS) explanation -->
276
+ <div style="margin-bottom: 5px;"><strong>Progress Score (PS):</strong> The average score across all trials divided by the maximum possible score, expressed as a percentage:</div>
277
+
278
+ PS = Average_Progress / Max_Steps &times; 100%
279
+
280
+ <!-- Part 2: Success Rate (SR) explanation -->
281
+ <div style="margin-bottom: 5px;"><strong>Success Rate (SR):</strong> The number of successful trials divided by the total number of trials, expressed as a percentage:</div>
282
+
283
+ SR = Successful_Trials / N &times; 100%
284
+
285
+ </div>
286
+
287
+
288
+
289
+ <div style="font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Arial, sans-serif;">
290
+ <!-- Metric notes -->
291
+ <p style="font-size: 12px; color: #666; margin-bottom: 5px;">* All metrics are reported in percentage (%). Higher values are <b>bolded</b>.</p>
292
+
293
+ <table style="border-collapse: collapse; width: auto; font-size: 13px; line-height: 1.2;">
294
+ <thead>
295
+ <tr style="border-top: 2px solid black;">
296
+ <th rowspan="2" align="left" style="padding: 4px 10px; border-bottom: 1px solid black; white-space: nowrap;"><b>Task</b></th>
297
+ <th colspan="2" style="padding: 4px 10px; border-bottom: 1px solid black;">Make Breakfast</th>
298
+ <th colspan="2" style="padding: 4px 10px; border-bottom: 1px solid black;">Pick Screws</th>
299
+ <th colspan="2" style="padding: 4px 10px; border-bottom: 1px solid black;">Insert Tube</th>
300
+ <th colspan="2" style="padding: 4px 10px; border-bottom: 1px solid black;">Unpack Delivery</th>
301
+ <th colspan="2" style="padding: 4px 10px; border-bottom: 1px solid black;">Fold Clothes</th>
302
+ <th colspan="2" style="padding: 4px 10px; border-bottom: 1px solid black;">Fold Pants</th>
303
+ </tr>
304
+ <tr style="border-bottom: 1px solid black;">
305
+ <th style="padding: 4px 8px;">PS</th>
306
+ <th style="padding: 4px 8px;">SR</th>
307
+ <th style="padding: 4px 8px;">PS</th>
308
+ <th style="padding: 4px 8px;">SR</th>
309
+ <th style="padding: 4px 8px;">PS</th>
310
+ <th style="padding: 4px 8px;">SR</th>
311
+ <th style="padding: 4px 8px;">PS</th>
312
+ <th style="padding: 4px 8px;">SR</th>
313
+ <th style="padding: 4px 8px;">PS</th>
314
+ <th style="padding: 4px 8px;">SR</th>
315
+ <th style="padding: 4px 8px;">PS</th>
316
+ <th style="padding: 4px 8px;">SR</th>
317
+ </tr>
318
+ </thead>
319
+ <tbody>
320
+ <tr>
321
+ <td style="padding: 6px 10px; white-space: nowrap;">&pi;<sub>0.5</sub></td>
322
+ <td align="center">73.0</td><td align="center">70.0</td>
323
+ <td align="center">74.0</td><td align="center">50.0</td>
324
+ <td align="center">79.2</td><td align="center">30.0</td>
325
+ <td align="center">73.0</td><td align="center">25.0</td>
326
+ <td align="center"><b>62.9</b></td><td align="center">30.0</td>
327
+ <td align="center">30.0</td><td align="center">30.0</td>
328
+ </tr>
329
+ <tr style="border-bottom: 2px solid black;">
330
+ <td style="padding: 6px 10px; white-space: nowrap;"><b>LingBot-VA (Ours)</b></td>
331
+ <td align="center"><b>97.0</b></td><td align="center"><b>75.0</b></td>
332
+ <td align="center"><b>82.5</b></td><td align="center"><b>70.0</b></td>
333
+ <td align="center"><b>85.8</b></td><td align="center"><b>40.0</b></td>
334
+ <td align="center"><b>84.5</b></td><td align="center"><b>65.0</b></td>
335
+ <td align="center">48.8</td><td align="center"><b>35.0</b></td>
336
+ <td align="center"><b>76.7</b></td><td align="center"><b>70.0</b></td>
337
+ </tr>
338
+ </tbody>
339
+ </table>
340
+ </div>
341
+
342
+
343
+ # 🪪 License
344
+
345
+ This project is released under the Apache License 2.0. See [LICENSE](LICENSE.txt) file for details.
346
+
347
+ # 📚Citation
348
+
349
+ ```bibtex
350
+ @article{lingbot-va2026,
351
+ title={Causal World Modeling for Robot Control},
352
+ author={Li, Lin and Zhang, Qihang and Luo, Yiming and Yang, Shuai and Wang, Ruilin and Han, Fei and Yu, Mingrui and Gao, Zelin and Xue, Nan and Zhu, Xing and Shen, Yujun and Xu, Yinghao},
353
+ journal={arXiv preprint arXiv:2601.21998},
354
+ year={2026}
355
+ }
356
+ ```
357
+
358
+ # 🧩 Acknowledgments
359
+
360
+ This work builds upon several excellent open-source projects:
361
+
362
+ - [Wan-Video](https://github.com/Wan-Video) - Vision transformer backbone
363
+ - [MoT](https://github.com/facebookresearch/Mixture-of-Transformers) - Mixture-of-Transformers architecture
364
+ - The broader open-source computer vision and robotics communities
365
+
366
+ ---
367
+
368
+ For questions, discussions, or collaborations:
369
+
370
+ - **Issues**: Open an [issue](https://github.com/robbyant/lingbot-va/issues) on GitHub
371
+ - **Email**: Contact Dr. [Qihang Zhang](https://zqh0253.github.io/) (liuhuan.zqh@antgroup.com) or Dr. [Lin Li](https://lilin-hitcrt.github.io/) (fengchang.ll@antgroup.com)
assets/teaser.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2b23b4170e7784b82d8a6287451e886ac88fd0d7d841c55c4bc290b068c9f394
3
+ size 12144486
assets/teaser_v3.png ADDED

Git LFS Details

  • SHA256: f27d456c7af839b5929aeffaffa9ac2cb9a3b69e5e0df2440b37b4f241aa933d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.11 MB
debug/place_fan/call1_reset.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c0d8e7cb0edd8bb74e138e0cc15f50470a7db8316c4ac39e8a2178fde2e1a3da
3
+ size 140
debug/place_fan/call2.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:23fc9654673dfad3b4c70a545782bab385b9b18f439fbfb6005b582de0e7005e
3
+ size 691918
debug/place_fan/call3.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b0a92955b91b046dc90ee8e1c671be2938f53d90712245b126b20e2a569651c6
3
+ size 2769097
evaluation/robotwin/calc_stat.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
def compute_success_rates(root_dir: str, true_suffix="True.mp4", false_suffix="False.mp4"):
    """Scan each sub-directory of *root_dir* for result videos and tally outcomes.

    Every ``*.mp4`` found (recursively) below a sub-directory is classified by
    filename suffix: a success if it ends with *true_suffix*, otherwise a
    failure if it ends with *false_suffix*; any other ``.mp4`` is ignored.

    Args:
        root_dir: Directory whose immediate sub-directories are per-task result folders.
        true_suffix: Filename suffix marking a successful episode video.
        false_suffix: Filename suffix marking a failed episode video.

    Returns:
        list[tuple]: One ``(folder_name, n_success, n_failure, n_total, rate)``
        per sub-directory (sorted by path), where ``rate`` is ``None`` when
        no matching videos were found.

    Raises:
        FileNotFoundError: If *root_dir* does not exist.
    """
    base = Path(root_dir)
    if not base.exists():
        raise FileNotFoundError(f"Root dir not found: {base}")

    stats = []
    subdirs = sorted(child for child in base.iterdir() if child.is_dir())
    for subdir in subdirs:
        names = [video.name for video in subdir.rglob("*.mp4")]
        n_true = sum(1 for name in names if name.endswith(true_suffix))
        # Mirror the original if/elif: a name matching both suffixes counts as success.
        n_false = sum(
            1 for name in names
            if not name.endswith(true_suffix) and name.endswith(false_suffix)
        )
        n_total = n_true + n_false
        success_rate = n_true / n_total if n_total else None
        stats.append((subdir.name, n_true, n_false, n_total, success_rate))
    return stats
25
+
26
+
27
# Task -> difficulty class mapping (1 / 2 / 3), used by print_table for
# the per-class mean success-rate summary rows.
TASK_CLASS = {
    "adjust_bottle": 1,
    "beat_block_hammer": 1,
    "blocks_ranking_rgb": 3,
    "blocks_ranking_size": 3,
    "click_alarmclock": 1,
    "click_bell": 1,
    "dump_bin_bigbin": 1,
    "grab_roller": 1,
    "handover_block": 2,
    "handover_mic": 2,
    "hanging_mug": 2,
    "lift_pot": 1,
    "move_can_pot": 1,
    "move_pillbottle_pad": 1,
    "move_playingcard_away": 1,
    "move_stapler_pad": 1,
    "open_laptop": 1,
    "open_microwave": 1,
    "pick_diverse_bottles": 2,
    "pick_dual_bottles": 2,
    "place_a2b_left": 1,
    "place_a2b_right": 1,
    "place_bread_basket": 1,
    "place_bread_skillet": 2,
    "place_burger_fries": 2,
    "place_can_basket": 2,
    "place_cans_plasticbox": 2,
    "place_container_plate": 1,
    "place_dual_shoes": 2,
    "place_empty_cup": 1,
    "place_fan": 1,
    "place_mouse_pad": 1,
    "place_object_basket": 2,
    "place_object_scale": 1,
    "place_object_stand": 1,
    "place_phone_stand": 1,
    "place_shoe": 1,
    "press_stapler": 1,
    "put_bottles_dustbin": 3,
    "put_object_cabinet": 2,
    "rotate_qrcode": 1,
    "scan_object": 2,
    "shake_bottle_horizontally": 1,
    "shake_bottle": 1,
    "stack_blocks_three": 3,
    "stack_blocks_two": 2,
    "stack_bowls_three": 3,
    "stack_bowls_two": 2,
    "stamp_seal": 1,
    "turn_switch": 1,
}
80
+
81
def mean_rate_of(results_subset):
    """Average the success-rate column (index 4) of result tuples.

    Entries whose rate is ``None`` (no recorded trials) are ignored.
    Returns ``None`` when every entry was skipped or the input is empty.
    """
    valid = [entry[4] for entry in results_subset if entry[4] is not None]
    if not valid:
        return None
    return sum(valid) / len(valid)
84
+
85
+
86
def print_table(results):
    """Pretty-print per-task success statistics plus overall / per-class means.

    Expects ``(folder, n_true, n_false, n_total, rate)`` tuples as produced by
    ``compute_success_rates``; difficulty classes come from ``TASK_CLASS``.
    """

    def fmt_rate(value):
        # Shared "N/A or right-aligned percentage" rendering for every rate cell.
        return "N/A" if value is None else f"{value*100:9.2f}%"

    # Sort by success rate: rows without a rate (N/A) go last, the rest high-to-low.
    ordered = sorted(results, key=lambda row: (row[4] is None, -(row[4] or 0.0)))

    print(f"{'folder':30s} {'True':>6s} {'False':>6s} {'Total':>6s} {'SuccessRate':>12s} {'Class':>6s}")
    print("-" * 90)

    for folder, t, f, total, rate in ordered:
        cls = TASK_CLASS.get(folder, None)
        cls_str = "N/A" if cls is None else str(cls)
        print(f"{folder:30s} {t:6d} {f:6d} {total:6d} {fmt_rate(rate):>12s} {cls_str:>6s}")

    print("-" * 90)

    # Overall mean across every row.
    print(f"{'MEAN (ALL)':30s} {'':6s} {'':6s} {'':6s} {fmt_rate(mean_rate_of(ordered)):>12s}")

    # Per-difficulty-class means (classes 1 / 2 / 3).
    for cls_id in (1, 2, 3):
        class_rows = [row for row in ordered if TASK_CLASS.get(row[0]) == cls_id]
        print(f"{('MEAN (CLASS '+str(cls_id)+')'):30s} {'':6s} {'':6s} {'':6s} {fmt_rate(mean_rate_of(class_rows)):>12s}")

    # Tasks missing from the TASK_CLASS mapping, if any.
    unmapped = [row for row in ordered if row[0] not in TASK_CLASS]
    if unmapped:
        print(f"{'MEAN (UNKNOWN)':30s} {'':6s} {'':6s} {'':6s} {fmt_rate(mean_rate_of(unmapped)):>12s}")
119
+
120
+
121
if __name__ == "__main__":
    import sys

    roots = sys.argv[1:]
    if not roots:
        # Fix: the usage message previously referenced a non-existent "a.py";
        # point it at this script's actual filename.
        raise SystemExit("Usage: python calc_stat.py <root_folder1> [<root_folder2> ...]")

    # Aggregate results across every root folder given on the command line,
    # then print one combined table.
    all_results = []
    for root_dir in roots:
        all_results.extend(compute_success_rates(root_dir))

    print_table(all_results)
evaluation/robotwin/eval_polict_client_openpi.py ADDED
@@ -0,0 +1,696 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import subprocess
4
+ import matplotlib.pyplot as plt
5
+ from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
6
+ import cv2
7
+ from pathlib import Path
8
+
9
+ robowin_root = Path("/group/ossdphi_algo_scratch_11/weicxu/pythonproject/RoboTwin")
10
+ if str(robowin_root) not in sys.path:
11
+ sys.path.insert(0, str(robowin_root))
12
+
13
+
14
+ import os
15
+ os.chdir(robowin_root)
16
+
17
+ from envs import CONFIGS_PATH
18
+ from envs.utils.create_actor import UnStableError
19
+
20
+ import numpy as np
21
+ from pathlib import Path
22
+ from collections import deque
23
+ import traceback
24
+
25
+ import yaml
26
+ from datetime import datetime
27
+ import importlib
28
+ import argparse
29
+ import pdb
30
+ from evaluation.robotwin.geometry import euler2quat
31
+ import numpy as np
32
+
33
+ from description.utils.generate_episode_instructions import *
34
+ import traceback
35
+
36
+ import imageio
37
+ import numpy as np
38
+ from pathlib import Path
39
+ from scipy.spatial.transform import Rotation as R
40
+ import json
41
+ from pathlib import Path
42
+
43
+ from evaluation.robotwin.websocket_client_policy import WebsocketClientPolicy
44
+ from evaluation.robotwin.test_render import Sapien_TEST
45
+
46
def write_json(data: dict, fpath: Path) -> None:
    """Serialize *data* to *fpath* as human-readable JSON.

    Missing parent directories are created on demand, so callers never need
    to pre-create the output tree.

    Args:
        data (dict): The dictionary to write.
        fpath (Path): The path to the output JSON file.
    """
    fpath.parent.mkdir(parents=True, exist_ok=True)
    payload = json.dumps(data, indent=4, ensure_ascii=False)
    with open(fpath, "w") as sink:
        sink.write(payload)
58
+
59
def add_title_bar(img, text, font_scale=0.8, thickness=2):
    """Prepend a 40-px black banner with centered white *text* above *img*.

    Args:
        img: H x W x 3 image array.
        text: Caption rendered in the banner.
        font_scale: OpenCV font scale for the caption.
        thickness: Stroke thickness for the caption.

    Returns:
        A new array of height H + 40 with the banner stacked on top.
    """
    height, width, _ = img.shape
    bar_height = 40

    # Black background strip the same width as the image.
    banner = np.zeros((bar_height, width, 3), dtype=np.uint8)

    # Measure the text so it can be centered inside the banner.
    (text_w, text_h), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)
    origin = ((width - text_w) // 2, (bar_height + text_h) // 2 - 5)

    cv2.putText(banner, text, origin, cv2.FONT_HERSHEY_SIMPLEX,
                font_scale, (255, 255, 255), thickness, cv2.LINE_AA)

    return np.vstack([banner, img])
76
+
77
def quaternion_to_euler(quat):
    """Convert a quaternion to 'xyz' Euler angles via scipy.

    Args:
        quat: Quaternion in ``[rx, ry, rz, rw]`` (scalar-last) order — the
            layout ``scipy.spatial.transform.Rotation.from_quat`` expects.

    Returns:
        ndarray: ``[roll, pitch, yaw]`` in radians.
    """
    return R.from_quat(quat).as_euler('xyz', degrees=False)
87
+
88
def visualize_action_step(action_history, step_idx, window=50):
    """Render a 2x2 dashboard of recent dual-arm actions as an RGB image.

    Subplot 1: Left arm XYZ position + gripper
    Subplot 2: Left arm Euler angles (roll, pitch, yaw) converted from quaternion
    Subplot 3: Right arm XYZ position + gripper
    Subplot 4: Right arm Euler angles (roll, pitch, yaw) converted from quaternion

    Args:
        action_history: Sequence of 16-dim action vectors laid out as
            [left_x, left_y, left_z, left_rx, left_ry, left_rz, left_rw, left_gripper,
             right_x, right_y, right_z, right_rx, right_ry, right_rz, right_rw, right_gripper].
        step_idx: Index of the newest action to display.
        window: Number of trailing steps shown (sliding window).

    Returns:
        np.ndarray: H x W x 3 uint8 image of the rendered figure.
    """
    # Create four subplots, sharing the X-axis
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(14, 8), dpi=100, sharex=True)

    # 1. Determine slice range
    start = max(0, step_idx - window)
    end = step_idx + 1

    # 2. Get data subset
    history_subset = np.array(action_history)[start:end]

    # 3. Generate X-axis based on actual data length
    actual_len = len(history_subset)
    x_axis = range(start, start + actual_len)

    # Only draw curves when there is data with the full 16-dim layout;
    # otherwise an empty 2x2 grid is still rendered and returned.
    if actual_len > 0 and history_subset.shape[1] >= 16:
        # Convert quaternions to Euler angles
        left_euler = []
        right_euler = []

        for action in history_subset:
            # Left arm quaternion to Euler angles
            left_quat = action[3:7]  # [rx, ry, rz, rw]
            left_rpy = quaternion_to_euler(left_quat)
            left_euler.append(left_rpy)

            # Right arm quaternion to Euler angles
            right_quat = action[11:15]  # [rx, ry, rz, rw]
            right_rpy = quaternion_to_euler(right_quat)
            right_euler.append(right_rpy)

        left_euler = np.array(left_euler)
        right_euler = np.array(right_euler)

        # --- Left Arm ---
        # Subplot 1: Left Arm Translation (XYZ) + Gripper
        ax1.plot(x_axis, history_subset[:, 0], label='left_x', color='r', linewidth=1.5)
        ax1.plot(x_axis, history_subset[:, 1], label='left_y', color='g', linewidth=1.5)
        ax1.plot(x_axis, history_subset[:, 2], label='left_z', color='b', linewidth=1.5)
        ax1.plot(x_axis, history_subset[:, 7], label='left_grip', color='orange',
                 linestyle=':', linewidth=2, alpha=0.8)
        ax1.set_ylabel('Position (m)')
        ax1.legend(loc='upper right', fontsize='x-small', ncol=4)
        ax1.grid(True, alpha=0.3)
        ax1.set_title(f"Step {step_idx}: Left Arm Position & Gripper")

        # Subplot 2: Left Arm Euler Angles (Roll, Pitch, Yaw)
        ax2.plot(x_axis, left_euler[:, 0], label='left_roll', color='c', linewidth=1.5)
        ax2.plot(x_axis, left_euler[:, 1], label='left_pitch', color='m', linewidth=1.5)
        ax2.plot(x_axis, left_euler[:, 2], label='left_yaw', color='y', linewidth=1.5)
        ax2.set_ylabel('Rotation (rad)')
        ax2.legend(loc='upper right', fontsize='x-small', ncol=3)
        ax2.grid(True, alpha=0.3)
        ax2.set_title("Left Arm Rotation (RPY from Quaternion)")

        # --- Right Arm ---
        # Subplot 3: Right Arm Translation (XYZ) + Gripper
        ax3.plot(x_axis, history_subset[:, 8], label='right_x', color='r', linewidth=1.5, linestyle='--')
        ax3.plot(x_axis, history_subset[:, 9], label='right_y', color='g', linewidth=1.5, linestyle='--')
        ax3.plot(x_axis, history_subset[:, 10], label='right_z', color='b', linewidth=1.5, linestyle='--')
        ax3.plot(x_axis, history_subset[:, 15], label='right_grip', color='orange',
                 linestyle=':', linewidth=2, alpha=0.8)
        ax3.set_ylabel('Position (m)')
        ax3.legend(loc='upper right', fontsize='x-small', ncol=4)
        ax3.grid(True, alpha=0.3)
        ax3.set_title("Right Arm Position & Gripper")

        # Subplot 4: Right Arm Euler Angles (Roll, Pitch, Yaw)
        ax4.plot(x_axis, right_euler[:, 0], label='right_roll', color='c', linewidth=1.5, linestyle='--')
        ax4.plot(x_axis, right_euler[:, 1], label='right_pitch', color='m', linewidth=1.5, linestyle='--')
        ax4.plot(x_axis, right_euler[:, 2], label='right_yaw', color='y', linewidth=1.5, linestyle='--')
        ax4.set_ylabel('Rotation (rad)')
        ax4.legend(loc='upper right', fontsize='x-small', ncol=3)
        ax4.grid(True, alpha=0.3)
        ax4.set_title("Right Arm Rotation (RPY from Quaternion)")

    # Set X-axis display range to maintain sliding window effect
    # (sharex=True propagates this limit to all four subplots).
    ax1.set_xlim(max(0, step_idx - window), max(window, step_idx))
    ax3.set_xlabel('Step')
    ax4.set_xlabel('Step')

    plt.tight_layout()
    # Rasterize the figure off-screen and strip the alpha channel.
    canvas = FigureCanvas(fig)
    canvas.draw()
    img = np.asarray(canvas.buffer_rgba())
    img = img[:, :, :3]

    # Convert to uint8
    if img.dtype != np.uint8:
        img = (img * 255).astype(np.uint8)

    plt.close(fig)
    return img
192
+
193
+
194
def save_comparison_video(real_obs_list, imagined_video, action_history, save_path, fps=15):
    """Write a stacked comparison video to *save_path*.

    Each output frame has two rows: on top, the three real camera views
    (head / left wrist / right wrist) concatenated horizontally; below, the
    corresponding imagined (generated) frame when available, otherwise a
    "Coming soon" placeholder.

    Args:
        real_obs_list: Formatted observation dicts with the
            ``observation.images.cam_*`` keys (see ``format_obs``).
        imagined_video: List of frame arrays concatenated along time, or
            ``None`` to render only placeholders.
        action_history: Recorded raw actions.
            NOTE(review): currently unused in this function.
        save_path: Output video file path.
        fps: Output frame rate.
    """
    if not real_obs_list:
        return

    n_real = len(real_obs_list)
    if imagined_video is not None:
        # Chunks of generated frames are concatenated into one time axis.
        imagined_video = np.concatenate(imagined_video, 0)
        n_imagined = len(imagined_video)
    else:
        n_imagined = 0
    n_frames = n_real # Based on real observation frames

    print(f"Saving video: Real {n_real} frames, Imagined {n_imagined} frames...")

    final_frames = []

    for i in range(n_frames):
        obs = real_obs_list[i]
        cam_high = obs["observation.images.cam_high"]
        cam_left = obs["observation.images.cam_left_wrist"]
        cam_right = obs["observation.images.cam_right_wrist"]

        # All three views are scaled to the head camera's height before stacking.
        base_h = cam_high.shape[0]

        # NOTE(review): redefined on every iteration; could be hoisted out of the loop.
        def resize_h(img, h):
            if img.shape[0] != h:
                w = int(img.shape[1] * h / img.shape[0])
                return cv2.resize(img, (w, h))
            return img

        row_real = np.hstack([
            resize_h(cam_high, base_h),
            resize_h(cam_left, base_h),
            resize_h(cam_right, base_h)
        ])

        # Assumes non-uint8 inputs are floats in [0, 1] — TODO confirm upstream dtype.
        if row_real.dtype != np.uint8:
            row_real = (row_real * 255).astype(np.uint8)

        row_real = add_title_bar(row_real, "Real Observation (High / Left / Right)")

        target_width = row_real.shape[1]

        if imagined_video is not None and i < n_imagined:
            img_frame = imagined_video[i]
            # Normalize generated frames to uint8 ([0, 1] floats are rescaled).
            if img_frame.dtype != np.uint8 and img_frame.max() <= 1.0001:
                img_frame = (img_frame * 255).astype(np.uint8)
            elif img_frame.dtype != np.uint8:
                img_frame = img_frame.astype(np.uint8)

            h = int(img_frame.shape[0] * target_width / img_frame.shape[1])
            row_imagined = cv2.resize(img_frame, (target_width, h))
        else:
            # No generated frame for this step: draw a gray placeholder row.
            row_imagined = np.zeros((300, target_width, 3), dtype=np.uint8)
            cv2.putText(row_imagined, "Coming soon", (target_width//2 - 100, 150),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (100, 100, 100), 2)

        row_imagined = add_title_bar(row_imagined, "Imagined Video Stream")
        full_frame = np.vstack([row_real, row_imagined])
        final_frames.append(full_frame)

    imageio.mimsave(save_path, final_frames, fps=fps)
    print(f"Combined video saved to: {save_path}")
257
+
258
+
259
def class_decorator(task_name):
    """Dynamically import ``envs.<task_name>`` and instantiate its task class.

    The RoboTwin convention is that module ``envs.foo`` defines a class ``foo``.

    Args:
        task_name: Name of the task (module and class share this name).

    Returns:
        An instance of the task environment class.

    Raises:
        SystemExit: If the module does not define a class named *task_name*.
        ImportError: If the ``envs.<task_name>`` module cannot be imported.
    """
    envs_module = importlib.import_module(f"envs.{task_name}")
    try:
        env_class = getattr(envs_module, task_name)
        env_instance = env_class()
    except AttributeError as e:
        # Fix: the original bare ``except:`` also swallowed genuine errors raised
        # by the environment constructor. Only a missing class now maps to
        # "No Task"; constructor failures propagate with their real traceback.
        raise SystemExit("No Task") from e
    return env_instance
267
+
268
+
269
def eval_function_decorator(policy_name, model_name):
    """Resolve attribute *model_name* from the module named *policy_name*.

    Args:
        policy_name: Importable module path of the policy package.
        model_name: Attribute (class or function) to fetch from that module.

    Returns:
        The resolved attribute.

    Raises:
        ImportError: If the policy module cannot be imported.
        AttributeError: If the module lacks *model_name*.
    """
    # Fix: the original ``except ImportError as e: raise e`` was a no-op that
    # only truncated the traceback; letting the exception propagate naturally
    # is equivalent and clearer.
    policy_module = importlib.import_module(policy_name)
    return getattr(policy_module, model_name)
275
+
276
def get_camera_config(camera_type):
    """Look up the parameters for *camera_type* in the shared camera YAML config.

    Args:
        camera_type: Key in ``task_config/_camera_config.yml``.

    Returns:
        dict: That camera's configuration section.

    Raises:
        FileNotFoundError: If the camera config file is missing.
        KeyError: If *camera_type* is not defined in the file.
    """
    camera_config_path = os.path.join(robowin_root, "task_config/_camera_config.yml")

    # Fix: explicit exceptions instead of ``assert`` so the validation
    # survives ``python -O`` (asserts are stripped under optimization).
    if not os.path.isfile(camera_config_path):
        raise FileNotFoundError("task config file is missing")

    with open(camera_config_path, "r", encoding="utf-8") as f:
        args = yaml.load(f.read(), Loader=yaml.FullLoader)

    if camera_type not in args:
        raise KeyError(f"camera {camera_type} is not defined")
    return args[camera_type]
286
+
287
+
288
def get_embodiment_config(robot_file):
    """Load and return ``config.yml`` from the embodiment directory *robot_file*."""
    config_path = os.path.join(robot_file, "config.yml")
    with open(config_path, "r", encoding="utf-8") as stream:
        return yaml.load(stream.read(), Loader=yaml.FullLoader)
293
+
294
+
295
def main(usr_args):
    """Build the task environment from the YAML configs and run evaluation.

    Loads the task / camera / embodiment configuration, prints the resolved
    settings, connects to the policy server over websocket, runs
    ``eval_policy`` and writes the aggregate success rate to ``_result.txt``.

    Args:
        usr_args (dict): Parsed CLI options; must contain ``task_name``,
            ``task_config``, ``ckpt_setting``, ``save_root``, ``policy_name``,
            ``video_guidance_scale``, ``action_guidance_scale``, ``seed``,
            ``test_num`` and ``port``.
    """
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    task_name = usr_args["task_name"]
    task_config = usr_args["task_config"]
    ckpt_setting = usr_args["ckpt_setting"]
    save_root = usr_args["save_root"]
    policy_name = usr_args["policy_name"]
    video_guidance_scale = usr_args["video_guidance_scale"]
    action_guidance_scale = usr_args["action_guidance_scale"]
    instruction_type = 'seen'
    save_dir = None
    video_save_dir = None
    video_size = None

    with open(f"./task_config/{task_config}.yml", "r", encoding="utf-8") as f:
        args = yaml.load(f.read(), Loader=yaml.FullLoader)

    args['task_name'] = task_name
    args["task_config"] = task_config
    args["ckpt_setting"] = ckpt_setting
    args["save_root"] = save_root

    embodiment_type = args.get("embodiment")
    embodiment_config_path = os.path.join(CONFIGS_PATH, "_embodiment_config.yml")

    with open(embodiment_config_path, "r", encoding="utf-8") as f:
        _embodiment_types = yaml.load(f.read(), Loader=yaml.FullLoader)

    def get_embodiment_file(embodiment_type):
        # Resolve the robot description directory for one embodiment entry.
        robot_file = _embodiment_types[embodiment_type]["file_path"]
        if robot_file is None:
            # Fix: ``raise "No embodiment files"`` raised a TypeError at
            # runtime (strings are not exceptions).
            raise FileNotFoundError("No embodiment files")
        return robot_file

    # Consistent with the embodiment config above: use os.path.join rather
    # than string concatenation (works whether CONFIGS_PATH ends in a slash).
    with open(os.path.join(CONFIGS_PATH, "_camera_config.yml"), "r", encoding="utf-8") as f:
        _camera_config = yaml.load(f.read(), Loader=yaml.FullLoader)

    head_camera_type = args["camera"]["head_camera_type"]
    args["head_camera_h"] = _camera_config[head_camera_type]["h"]
    args["head_camera_w"] = _camera_config[head_camera_type]["w"]

    # One entry = same robot for both arms; three entries = left robot,
    # right robot, and the distance between them.
    if len(embodiment_type) == 1:
        args["left_robot_file"] = get_embodiment_file(embodiment_type[0])
        args["right_robot_file"] = get_embodiment_file(embodiment_type[0])
        args["dual_arm_embodied"] = True
    elif len(embodiment_type) == 3:
        args["left_robot_file"] = get_embodiment_file(embodiment_type[0])
        args["right_robot_file"] = get_embodiment_file(embodiment_type[1])
        args["embodiment_dis"] = embodiment_type[2]
        args["dual_arm_embodied"] = False
    else:
        # Fix: was ``raise "embodiment items should be 1 or 3"`` (TypeError).
        raise ValueError("embodiment items should be 1 or 3")

    args["left_embodiment_config"] = get_embodiment_config(args["left_robot_file"])
    args["right_embodiment_config"] = get_embodiment_config(args["right_robot_file"])

    if len(embodiment_type) == 1:
        embodiment_name = str(embodiment_type[0])
    else:
        embodiment_name = str(embodiment_type[0]) + "+" + str(embodiment_type[1])

    save_dir = Path(f"eval_result/{task_name}/{policy_name}/{task_config}/{ckpt_setting}/{current_time}")
    save_dir.mkdir(parents=True, exist_ok=True)

    if args["eval_video_log"]:
        video_save_dir = save_dir
        camera_config = get_camera_config(args["camera"]["head_camera_type"])
        video_size = str(camera_config["w"]) + "x" + str(camera_config["h"])
        video_save_dir.mkdir(parents=True, exist_ok=True)
        args["eval_video_save_dir"] = video_save_dir

    # Human-readable dump of the resolved configuration (ANSI-colored labels).
    print("============= Config =============\n")
    print("\033[95mMessy Table:\033[0m " + str(args["domain_randomization"]["cluttered_table"]))
    print("\033[95mRandom Background:\033[0m " + str(args["domain_randomization"]["random_background"]))
    if args["domain_randomization"]["random_background"]:
        print(" - Clean Background Rate: " + str(args["domain_randomization"]["clean_background_rate"]))
    print("\033[95mRandom Light:\033[0m " + str(args["domain_randomization"]["random_light"]))
    if args["domain_randomization"]["random_light"]:
        print(" - Crazy Random Light Rate: " + str(args["domain_randomization"]["crazy_random_light_rate"]))
    print("\033[95mRandom Table Height:\033[0m " + str(args["domain_randomization"]["random_table_height"]))
    print("\033[95mRandom Head Camera Distance:\033[0m " + str(args["domain_randomization"]["random_head_camera_dis"]))

    print("\033[94mHead Camera Config:\033[0m " + str(args["camera"]["head_camera_type"]) + ", " +
          str(args["camera"]["collect_head_camera"]))
    print("\033[94mWrist Camera Config:\033[0m " + str(args["camera"]["wrist_camera_type"]) + ", " +
          str(args["camera"]["collect_wrist_camera"]))
    print("\033[94mEmbodiment Config:\033[0m " + embodiment_name)
    print("\n==================================")

    TASK_ENV = class_decorator(args["task_name"])
    args["policy_name"] = policy_name
    usr_args["left_arm_dim"] = len(args["left_embodiment_config"]["arm_joints_name"][0])
    usr_args["right_arm_dim"] = len(args["right_embodiment_config"]["arm_joints_name"][1])

    seed = usr_args["seed"]

    # Offset the seed range so different --seed values never overlap.
    st_seed = 10000 * (1 + seed)
    suc_nums = []
    test_num = usr_args["test_num"]

    model = WebsocketClientPolicy(port=usr_args['port'])

    st_seed, suc_num = eval_policy(task_name,
                                   TASK_ENV,
                                   args,
                                   model,
                                   st_seed,
                                   test_num=test_num,
                                   video_size=video_size,
                                   instruction_type=instruction_type,
                                   save_visualization=True,
                                   video_guidance_scale=video_guidance_scale,
                                   action_guidance_scale=action_guidance_scale)
    suc_nums.append(suc_num)

    file_path = os.path.join(save_dir, "_result.txt")
    with open(file_path, "w") as file:
        file.write(f"Timestamp: {current_time}\n\n")
        file.write(f"Instruction Type: {instruction_type}\n\n")
        file.write("\n".join(map(str, np.array(suc_nums) / test_num)))

    print(f"Data has been saved to {file_path}")
418
+
419
def format_obs(observation, prompt):
    """Flatten a raw environment observation into the policy-server schema.

    Args:
        observation: Nested observation dict with camera RGB images and the
            joint-state vector.
        prompt: Language instruction to attach under the ``task`` key.

    Returns:
        dict: LeRobot-style flat keys expected by the policy server.
    """
    cameras = observation["observation"]
    return {
        "observation.images.cam_high": cameras["head_camera"]["rgb"],  # H,W,3
        "observation.images.cam_left_wrist": cameras["left_camera"]["rgb"],
        "observation.images.cam_right_wrist": cameras["right_camera"]["rgb"],
        "observation.state": observation["joint_action"]["vector"],
        "task": prompt,
    }
427
+
428
def add_eef_pose(new_pose, init_pose):
    """Compose a relative end-effector pose with an initial pose.

    Both inputs are 8-vectors ``[x, y, z, qx, qy, qz, qw, gripper]``.
    Translations are summed element-wise, orientations are composed as
    ``init * new`` (scipy rotation product), and the gripper value is taken
    from *new_pose* unchanged.
    """
    delta_rot = R.from_quat(new_pose[3:7][None])
    base_rot = R.from_quat(init_pose[3:7][None])
    composed_quat = (base_rot * delta_rot).as_quat().reshape(-1)
    composed_trans = init_pose[:3] + new_pose[:3]
    return np.concatenate([composed_trans, composed_quat, new_pose[7:8]])
434
+
435
def add_init_pose(new_pose, init_pose):
    """Apply ``add_eef_pose`` independently to the left arm (first 8 dims)
    and the right arm (remaining dims) of a dual-arm 16-dim pose vector."""
    composed = (
        add_eef_pose(new_pose[:8], init_pose[:8]),
        add_eef_pose(new_pose[8:], init_pose[8:]),
    )
    return np.concatenate(composed)
439
+
440
+ def eval_policy(task_name,
441
+ TASK_ENV,
442
+ args,
443
+ model,
444
+ st_seed,
445
+ test_num=100,
446
+ video_size=None,
447
+ instruction_type=None,
448
+ save_visualization=False,
449
+ video_guidance_scale=5.0,
450
+ action_guidance_scale=5.0):
451
+ print(f"\033[34mTask Name: {args['task_name']}\033[0m")
452
+ print(f"\033[34mPolicy Name: {args['policy_name']}\033[0m")
453
+
454
+ expert_check = True
455
+ TASK_ENV.suc = 0
456
+ TASK_ENV.test_num = 0
457
+
458
+ now_id = 0
459
+ succ_seed = 0
460
+ suc_test_seed_list = []
461
+
462
+
463
+ now_seed = st_seed
464
+ clear_cache_freq = args["clear_cache_freq"]
465
+
466
+ args["eval_mode"] = True
467
+
468
+ while succ_seed < test_num:
469
+ render_freq = args["render_freq"]
470
+ args["render_freq"] = 0
471
+
472
+ if expert_check:
473
+ try:
474
+ TASK_ENV.setup_demo(now_ep_num=now_id, seed=now_seed, is_test=True, **args)
475
+ episode_info = TASK_ENV.play_once()
476
+ TASK_ENV.close_env()
477
+ except UnStableError as e:
478
+ TASK_ENV.close_env()
479
+ now_seed += 1
480
+ args["render_freq"] = render_freq
481
+ continue
482
+ except Exception as e:
483
+ TASK_ENV.close_env()
484
+ now_seed += 1
485
+ args["render_freq"] = render_freq
486
+ print(f"error occurs ! {e}")
487
+ traceback.print_exc()
488
+ continue
489
+
490
+ if (not expert_check) or (TASK_ENV.plan_success and TASK_ENV.check_success()):
491
+ succ_seed += 1
492
+ suc_test_seed_list.append(now_seed)
493
+ else:
494
+ now_seed += 1
495
+ args["render_freq"] = render_freq
496
+ continue
497
+
498
+ args["render_freq"] = render_freq
499
+
500
+ TASK_ENV.setup_demo(now_ep_num=now_id, seed=now_seed, is_test=True, **args)
501
+ episode_info_list = [episode_info["info"]]
502
+ results = generate_episode_descriptions(args["task_name"], episode_info_list, test_num)
503
+ instruction = np.random.choice(results[0][instruction_type])
504
+ TASK_ENV.set_instruction(instruction=instruction) # set language instruction
505
+
506
+ if TASK_ENV.eval_video_path is not None:
507
+ ffmpeg = subprocess.Popen(
508
+ [
509
+ "ffmpeg",
510
+ "-y",
511
+ "-loglevel",
512
+ "error",
513
+ "-f",
514
+ "rawvideo",
515
+ "-pixel_format",
516
+ "rgb24",
517
+ "-video_size",
518
+ video_size,
519
+ "-framerate",
520
+ "10",
521
+ "-i",
522
+ "-",
523
+ "-pix_fmt",
524
+ "yuv420p",
525
+ "-vcodec",
526
+ "libx264",
527
+ "-crf",
528
+ "23",
529
+ f"{TASK_ENV.eval_video_path}/episode{TASK_ENV.test_num}.mp4",
530
+ ],
531
+ stdin=subprocess.PIPE,
532
+ )
533
+ TASK_ENV._set_eval_video_ffmpeg(ffmpeg)
534
+
535
+ succ = False
536
+
537
+ prompt = TASK_ENV.get_instruction()
538
+ ret = model.infer(dict(reset = True, prompt=prompt, save_visualization=save_visualization))
539
+
540
+ first = True
541
+ full_obs_list = []
542
+ gen_video_list = []
543
+ full_action_history = []
544
+
545
+ initial_obs = TASK_ENV.get_obs()
546
+ inint_eef_pose = initial_obs['endpose']['left_endpose'] + \
547
+ [initial_obs['endpose']['left_gripper']] + \
548
+ initial_obs['endpose']['right_endpose'] + \
549
+ [initial_obs['endpose']['right_gripper']]
550
+ inint_eef_pose = np.array(inint_eef_pose, dtype=np.float64)
551
+ initial_formatted_obs = format_obs(initial_obs, prompt)
552
+ full_obs_list.append(initial_formatted_obs)
553
+ first_obs = None
554
+ while TASK_ENV.take_action_cnt<TASK_ENV.step_lim:
555
+ if first:
556
+ observation = TASK_ENV.get_obs()
557
+ first_obs = format_obs(observation, prompt)
558
+
559
+ ret = model.infer(dict(obs=first_obs, prompt=prompt, save_visualization=save_visualization, video_guidance_scale=video_guidance_scale, action_guidance_scale=action_guidance_scale)) #(TASK_ENV, model, observation)
560
+ action = ret['action']
561
+ if 'video' in ret:
562
+ imagined_video = ret['video']
563
+ gen_video_list.append(imagined_video)
564
+ key_frame_list = []
565
+
566
+ assert action.shape[2] % 4 == 0
567
+ action_per_frame = action.shape[2] // 4
568
+
569
+ start_idx = 1 if first else 0
570
+ for i in range(start_idx, action.shape[1]):
571
+ for j in range(action.shape[2]):
572
+ raw_action_step = action[:, i, j].flatten()
573
+ full_action_history.append(raw_action_step)
574
+
575
+ ee_action = action[:, i, j]
576
+ if action.shape[0] == 14:
577
+ ee_action = np.concatenate([
578
+ ee_action[:3],
579
+ euler2quat(ee_action[3], ee_action[4], ee_action[5]),
580
+ ee_action[6:10],
581
+ euler2quat(ee_action[10], ee_action[11], ee_action[12]),
582
+ ee_action[13:14]
583
+ ])
584
+ elif action.shape[0] == 16:
585
+ ee_action = add_init_pose(ee_action, inint_eef_pose)
586
+ ee_action = np.concatenate([
587
+ ee_action[:3],
588
+ ee_action[3:7] / np.linalg.norm(ee_action[3:7]),
589
+ ee_action[7:11],
590
+ ee_action[11:15] / np.linalg.norm(ee_action[11:15]),
591
+ ee_action[15:16]
592
+ ])
593
+ else:
594
+ raise NotImplementedError
595
+ TASK_ENV.take_action(ee_action, action_type='ee')
596
+
597
+ if (j+1) % action_per_frame == 0:
598
+ obs = format_obs(TASK_ENV.get_obs(), prompt)
599
+ full_obs_list.append(obs)
600
+ key_frame_list.append(obs)
601
+
602
+ first = False
603
+
604
+ model.infer(dict(obs = key_frame_list, compute_kv_cache=True, imagine=False, save_visualization=save_visualization, state=action))
605
+
606
+ if TASK_ENV.eval_success:
607
+ succ = True
608
+ break
609
+
610
+
611
+ vis_dir = Path(args['save_root']) / f'stseed-{st_seed}' / 'visualization' / task_name
612
+ vis_dir.mkdir(parents=True, exist_ok=True)
613
+ video_name = f"{TASK_ENV.test_num}_{prompt.replace(' ', '_')}_{succ}.mp4"
614
+ out_img_file = vis_dir / video_name
615
+ save_comparison_video(
616
+ real_obs_list=full_obs_list,
617
+ imagined_video=None, #gen_video_list,
618
+ action_history=full_action_history,
619
+ save_path=str(out_img_file),
620
+ fps=15 # Suggest adjusting fps based on simulation step
621
+ )
622
+ if TASK_ENV.eval_video_path is not None:
623
+ TASK_ENV._del_eval_video_ffmpeg()
624
+
625
+ if succ:
626
+ TASK_ENV.suc += 1
627
+ print("\033[92mSuccess!\033[0m")
628
+ else:
629
+ print("\033[91mFail!\033[0m")
630
+
631
+ now_id += 1
632
+ TASK_ENV.close_env(clear_cache=((succ_seed + 1) % clear_cache_freq == 0))
633
+
634
+ if TASK_ENV.render_freq:
635
+ TASK_ENV.viewer.close()
636
+
637
+ TASK_ENV.test_num += 1
638
+
639
+ save_dir = Path(args['save_root']) / f'stseed-{st_seed}' / 'metrics' / task_name
640
+ save_dir.mkdir(parents=True, exist_ok=True)
641
+ out_json_file = save_dir / 'res.json'
642
+ write_json({
643
+ "succ_num": float(TASK_ENV.suc),
644
+ "total_num": float(TASK_ENV.test_num),
645
+ "succ_rate": float(TASK_ENV.suc / TASK_ENV.test_num),
646
+ }, out_json_file)
647
+
648
+ print(
649
+ f"\033[93m{task_name}\033[0m | \033[94m{args['policy_name']}\033[0m | \033[92m{args['task_config']}\033[0m | \033[91m{args['ckpt_setting']}\033[0m\n"
650
+ f"Success rate: \033[96m{TASK_ENV.suc}/{TASK_ENV.test_num}\033[0m => \033[95m{round(TASK_ENV.suc/TASK_ENV.test_num*100, 1)}%\033[0m, current seed: \033[90m{now_seed}\033[0m\n"
651
+ )
652
+ now_seed += 1
653
+
654
+ return now_seed, TASK_ENV.suc
655
+
656
+
657
def parse_args_and_config():
    """Parse command-line arguments and merge a YAML config with CLI overrides.

    Command line shape::

        --config path/to/cfg.yml [--port N] [--save_root DIR] ...
        --overrides --key1 value1 --key2 value2 ...

    Everything after ``--overrides`` is consumed as flat ``--key value`` pairs
    and merged into the YAML config. Values are parsed as Python literals when
    possible (numbers, booleans, lists, ...), otherwise kept as strings.

    Returns:
        dict: the loaded YAML config updated with the override pairs.
            NOTE(review): the dedicated argparse options (port, save_root,
            guidance scales, test_num) are NOT merged into the returned dict;
            the launch scripts pass them after ``--overrides`` instead.
    """
    import ast

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--overrides", nargs=argparse.REMAINDER)
    parser.add_argument("--port", type=int, default=8000, help='remote policy socket port.')
    parser.add_argument("--save_root", type=str, default="results/default_vis_path")
    parser.add_argument("--video_guidance_scale", type=float, default=5.0)
    parser.add_argument("--action_guidance_scale", type=float, default=5.0)
    parser.add_argument("--test_num", type=int, default=100)
    args = parser.parse_args()

    with open(args.config, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # Parse overrides
    def parse_override_pairs(pairs):
        """Turn ["--k1", "v1", "--k2", "v2", ...] into {"k1": v1, "k2": v2, ...}."""
        if len(pairs) % 2 != 0:
            raise ValueError(f"--overrides expects key/value pairs, got odd count: {pairs}")
        override_dict = {}
        for i in range(0, len(pairs), 2):
            key = pairs[i].lstrip("-")
            value = pairs[i + 1]
            try:
                # Safely interpret literal values; unlike eval() this cannot
                # execute arbitrary code supplied on the command line.
                value = ast.literal_eval(value)
            except (ValueError, SyntaxError):
                pass  # keep plain strings (e.g. task names) as-is
            override_dict[key] = value
        return override_dict

    if args.overrides:
        overrides = parse_override_pairs(args.overrides)
        config.update(overrides)

    return config
689
+
690
+
691
if __name__ == "__main__":

    # Sanity-check that the SAPIEN renderer initialises before running the
    # (long) evaluation; Sapien_TEST exits the process on render failure.
    Sapien_TEST()
    # Load the YAML config merged with CLI --overrides, then run evaluation.
    usr_args = parse_args_and_config()
    main(usr_args)
696
+
evaluation/robotwin/geometry.py ADDED
@@ -0,0 +1,463 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Mostly copied from transforms3d library
3
+
4
+ """
5
+
6
+ import math
7
+
8
+ import numpy as np
9
+
10
+ _FLOAT_EPS = np.finfo(np.float64).eps
11
+
12
+ # axis sequences for Euler angles
13
+ _NEXT_AXIS = [1, 2, 0, 1]
14
+
15
+ # map axes strings to/from tuples of inner axis, parity, repetition, frame
16
+ _AXES2TUPLE = {
17
+ "sxyz": (0, 0, 0, 0),
18
+ "sxyx": (0, 0, 1, 0),
19
+ "sxzy": (0, 1, 0, 0),
20
+ "sxzx": (0, 1, 1, 0),
21
+ "syzx": (1, 0, 0, 0),
22
+ "syzy": (1, 0, 1, 0),
23
+ "syxz": (1, 1, 0, 0),
24
+ "syxy": (1, 1, 1, 0),
25
+ "szxy": (2, 0, 0, 0),
26
+ "szxz": (2, 0, 1, 0),
27
+ "szyx": (2, 1, 0, 0),
28
+ "szyz": (2, 1, 1, 0),
29
+ "rzyx": (0, 0, 0, 1),
30
+ "rxyx": (0, 0, 1, 1),
31
+ "ryzx": (0, 1, 0, 1),
32
+ "rxzx": (0, 1, 1, 1),
33
+ "rxzy": (1, 0, 0, 1),
34
+ "ryzy": (1, 0, 1, 1),
35
+ "rzxy": (1, 1, 0, 1),
36
+ "ryxy": (1, 1, 1, 1),
37
+ "ryxz": (2, 0, 0, 1),
38
+ "rzxz": (2, 0, 1, 1),
39
+ "rxyz": (2, 1, 0, 1),
40
+ "rzyz": (2, 1, 1, 1),
41
+ }
42
+
43
+ _TUPLE2AXES = dict((v, k) for k, v in _AXES2TUPLE.items())
44
+
45
+ # For testing whether a number is close to zero
46
+ _EPS4 = np.finfo(float).eps * 4.0
47
+
48
+
49
def mat2euler(mat, axes="sxyz"):
    """Return Euler angles from rotation matrix for specified axis sequence.

    Note that many Euler angle triplets can describe one matrix.

    Parameters
    ----------
    mat : array-like shape (3, 3) or (4, 4)
        Rotation matrix or affine.
    axes : str, optional
        Axis specification; one of 24 axis sequences as string or encoded
        tuple - e.g. ``sxyz`` (the default).

    Returns
    -------
    ai : float
        First rotation angle (according to `axes`).
    aj : float
        Second rotation angle (according to `axes`).
    ak : float
        Third rotation angle (according to `axes`).

    Examples
    --------
    >>> R0 = euler2mat(1, 2, 3, 'syxz')
    >>> al, be, ga = mat2euler(R0, 'syxz')
    >>> R1 = euler2mat(al, be, ga, 'syxz')
    >>> np.allclose(R0, R1)
    True
    """
    try:
        # String spec ("sxyz", ...) -> encoded tuple.
        firstaxis, parity, repetition, frame = _AXES2TUPLE[axes.lower()]
    except (AttributeError, KeyError):
        _TUPLE2AXES[axes]  # validation
        firstaxis, parity, repetition, frame = axes

    # i, j, k: the three matrix axes visited by this Euler convention.
    i = firstaxis
    j = _NEXT_AXIS[i + parity]
    k = _NEXT_AXIS[i - parity + 1]

    # Accept 3x3 or 4x4 (affine) input; only the rotation part is used.
    M = np.array(mat, dtype=np.float64, copy=False)[:3, :3]
    if repetition:
        sy = math.sqrt(M[i, j] * M[i, j] + M[i, k] * M[i, k])
        if sy > _EPS4:
            ax = math.atan2(M[i, j], M[i, k])
            ay = math.atan2(sy, M[i, i])
            az = math.atan2(M[j, i], -M[k, i])
        else:
            # Gimbal lock: second angle ~0 or ~pi; third angle set to 0.
            ax = math.atan2(-M[j, k], M[j, j])
            ay = math.atan2(sy, M[i, i])
            az = 0.0
    else:
        cy = math.sqrt(M[i, i] * M[i, i] + M[j, i] * M[j, i])
        if cy > _EPS4:
            ax = math.atan2(M[k, j], M[k, k])
            ay = math.atan2(-M[k, i], cy)
            az = math.atan2(M[j, i], M[i, i])
        else:
            # Gimbal lock: second angle ~ +/-pi/2; third angle set to 0.
            ax = math.atan2(-M[j, k], M[j, j])
            ay = math.atan2(-M[k, i], cy)
            az = 0.0

    # Undo the sign/ordering adjustments implied by parity and frame.
    if parity:
        ax, ay, az = -ax, -ay, -az
    if frame:
        ax, az = az, ax
    return ax, ay, az
116
+
117
+
118
def quat2mat(q):
    """Build the 3x3 rotation matrix for a quaternion *q* given as (w, x, y, z).

    The quaternion does not need to be normalized; the scaling is absorbed
    into the conversion. A quaternion with (near-)zero norm yields the
    identity matrix. The returned matrix rotates column vectors (applied on
    the left of coordinate vectors).

    Algorithm from http://en.wikipedia.org/wiki/Rotation_matrix#Quaternion

    Examples
    --------
    >>> import numpy as np
    >>> np.allclose(quat2mat([1, 0, 0, 0]), np.eye(3))
    True
    >>> np.allclose(quat2mat([0, 1, 0, 0]), np.diag([1, -1, -1]))
    True
    """
    w, x, y, z = q
    norm_sq = w * w + x * x + y * y + z * z
    if norm_sq < _FLOAT_EPS:
        # Degenerate quaternion: treat as no rotation.
        return np.eye(3)
    scale = 2.0 / norm_sq
    xs, ys, zs = x * scale, y * scale, z * scale
    wx, wy, wz = w * xs, w * ys, w * zs
    xx, xy, xz = x * xs, x * ys, x * zs
    yy, yz = y * ys, y * zs
    zz = z * zs
    return np.array(
        [
            [1.0 - (yy + zz), xy - wz, xz + wy],
            [xy + wz, 1.0 - (xx + zz), yz - wx],
            [xz - wy, yz + wx, 1.0 - (xx + yy)],
        ]
    )
174
+
175
+
176
def isrotation(
    R: np.ndarray,
    thresh=1e-6,
) -> bool:
    """Check whether *R* is (numerically) an orthogonal 3x3 matrix.

    Returns True when ``R.T @ R`` is within *thresh* of the identity in
    Frobenius norm. NOTE(review): this tests orthogonality only — a
    reflection (det = -1) also passes.
    """
    residual = np.identity(3, dtype=R.dtype) - np.dot(np.transpose(R), R)
    return np.linalg.norm(residual) < thresh
186
+
187
+
188
def euler2mat(ai, aj, ak, axes="sxyz"):
    """Return rotation matrix from Euler angles and axis sequence.

    Parameters
    ----------
    ai : float
        First rotation angle (according to `axes`).
    aj : float
        Second rotation angle (according to `axes`).
    ak : float
        Third rotation angle (according to `axes`).
    axes : str, optional
        Axis specification; one of 24 axis sequences as string or encoded
        tuple - e.g. ``sxyz`` (the default).

    Returns
    -------
    mat : array (3, 3)
        Rotation matrix or affine.

    Examples
    --------
    >>> R = euler2mat(1, 2, 3, 'syxz')
    >>> np.allclose(np.sum(R[0]), -1.34786452)
    True
    >>> R = euler2mat(1, 2, 3, (0, 1, 0, 1))
    >>> np.allclose(np.sum(R[0]), -0.383436184)
    True
    """
    try:
        # NOTE(review): unlike mat2euler/euler2quat this lookup does not call
        # axes.lower(), so upper-case specs raise AttributeError here and fall
        # into tuple validation — confirm whether that asymmetry is intended.
        firstaxis, parity, repetition, frame = _AXES2TUPLE[axes]
    except (AttributeError, KeyError):
        _TUPLE2AXES[axes]  # validation
        firstaxis, parity, repetition, frame = axes

    # i, j, k: the three matrix axes visited by this Euler convention.
    i = firstaxis
    j = _NEXT_AXIS[i + parity]
    k = _NEXT_AXIS[i - parity + 1]

    if frame:
        ai, ak = ak, ai
    if parity:
        ai, aj, ak = -ai, -aj, -ak

    si, sj, sk = math.sin(ai), math.sin(aj), math.sin(ak)
    ci, cj, ck = math.cos(ai), math.cos(aj), math.cos(ak)
    cc, cs = ci * ck, ci * sk
    sc, ss = si * ck, si * sk

    # Fill the matrix entry by entry; the two branches correspond to
    # repeating (e.g. zxz) vs non-repeating (e.g. xyz) axis sequences.
    M = np.eye(3)
    if repetition:
        M[i, i] = cj
        M[i, j] = sj * si
        M[i, k] = sj * ci
        M[j, i] = sj * sk
        M[j, j] = -cj * ss + cc
        M[j, k] = -cj * cs - sc
        M[k, i] = -sj * ck
        M[k, j] = cj * sc + cs
        M[k, k] = cj * cc - ss
    else:
        M[i, i] = cj * ck
        M[i, j] = sj * sc - cs
        M[i, k] = sj * cc + ss
        M[j, i] = cj * sk
        M[j, j] = sj * ss + cc
        M[j, k] = sj * cs - sc
        M[k, i] = -sj
        M[k, j] = cj * si
        M[k, k] = cj * ci
    return M
259
+
260
+
261
def euler2axangle(ai, aj, ak, axes="sxyz"):
    """Convert Euler angles (for axis sequence *axes*) to an axis/angle pair.

    Parameters
    ----------
    ai, aj, ak : float
        First, second and third rotation angles (according to `axes`).
    axes : str, optional
        Axis specification; one of 24 axis sequences as string or encoded
        tuple - e.g. ``sxyz`` (the default).

    Returns
    -------
    vector : array shape (3,)
        Axis around which rotation occurs.
    theta : scalar
        Angle of rotation.

    Examples
    --------
    >>> vec, theta = euler2axangle(0, 1.5, 0, 'szyx')
    >>> np.allclose(vec, [0, 1, 0])
    True
    >>> theta
    1.5
    """
    # Go via quaternion form, which has a direct axis/angle decomposition.
    quat = euler2quat(ai, aj, ak, axes)
    return quat2axangle(quat)
292
+
293
+
294
def euler2quat(ai, aj, ak, axes="sxyz"):
    """Return `quaternion` from Euler angles and axis sequence `axes`

    Parameters
    ----------
    ai : float
        First rotation angle (according to `axes`).
    aj : float
        Second rotation angle (according to `axes`).
    ak : float
        Third rotation angle (according to `axes`).
    axes : str, optional
        Axis specification; one of 24 axis sequences as string or encoded
        tuple - e.g. ``sxyz`` (the default).

    Returns
    -------
    quat : array shape (4,)
        Quaternion in w, x, y z (real, then vector) format

    Examples
    --------
    >>> q = euler2quat(1, 2, 3, 'ryxz')
    >>> np.allclose(q, [0.435953, 0.310622, -0.718287, 0.444435])
    True
    """
    try:
        firstaxis, parity, repetition, frame = _AXES2TUPLE[axes.lower()]
    except (AttributeError, KeyError):
        _TUPLE2AXES[axes]  # validation
        firstaxis, parity, repetition, frame = axes

    # Indices are shifted by +1 because q[0] holds the scalar (w) part.
    i = firstaxis + 1
    j = _NEXT_AXIS[i + parity - 1] + 1
    k = _NEXT_AXIS[i - parity] + 1

    if frame:
        ai, ak = ak, ai
    if parity:
        aj = -aj

    # Quaternions use half-angles.
    ai = ai / 2.0
    aj = aj / 2.0
    ak = ak / 2.0
    ci = math.cos(ai)
    si = math.sin(ai)
    cj = math.cos(aj)
    sj = math.sin(aj)
    ck = math.cos(ak)
    sk = math.sin(ak)
    cc = ci * ck
    cs = ci * sk
    sc = si * ck
    ss = si * sk

    q = np.empty((4,))
    if repetition:
        q[0] = cj * (cc - ss)
        q[i] = cj * (cs + sc)
        q[j] = sj * (cc + ss)
        q[k] = sj * (cs - sc)
    else:
        q[0] = cj * cc + sj * ss
        q[i] = cj * sc - sj * cs
        q[j] = cj * ss + sj * cc
        q[k] = cj * cs - sj * sc
    if parity:
        q[j] *= -1.0

    return q
364
+
365
+
366
def quat2axangle(quat, identity_thresh=None):
    """Decompose a (w, x, y, z) quaternion into a rotation axis and angle.

    The quaternion need not be normalized. When the vector part (x, y, z)
    is (near) zero the rotation is the identity and an arbitrary axis
    ``[1, 0, 0]`` with angle 0 is returned; non-finite input yields that
    axis with a NaN angle.

    Parameters
    ----------
    quat : 4 element sequence
        w, x, y, z forming quaternion.
    identity_thresh : None or scalar, optional
        Threshold below which the norm of the vector part is deemed to be 0.
        None (the default) estimates a threshold from the input precision.

    Returns
    -------
    vector : array shape (3,)
        Axis around which rotation occurs.
    theta : scalar
        Angle of rotation.

    Examples
    --------
    >>> vec, theta = quat2axangle([0, 1, 0, 0])
    >>> vec
    array([1., 0., 0.])
    >>> np.allclose(theta, np.pi)
    True
    """
    quat = np.asarray(quat)
    norm_sq = np.sum(quat**2)
    if not np.isfinite(norm_sq):
        return np.array([1.0, 0, 0]), float("nan")
    if identity_thresh is None:
        try:
            identity_thresh = np.finfo(norm_sq.type).eps * 3
        except (AttributeError, ValueError):  # not a numpy float type
            identity_thresh = _FLOAT_EPS * 3
    if norm_sq < _FLOAT_EPS**2:
        # Norm too small to normalize reliably: report identity rotation.
        return np.array([1.0, 0, 0]), 0.0
    if norm_sq != 1:
        quat = quat / math.sqrt(norm_sq)
    vec = quat[1:]
    vec_norm_sq = np.sum(vec**2)
    if vec_norm_sq < identity_thresh**2:
        # Vector part nearly zero -> identity rotation, arbitrary axis.
        return np.array([1.0, 0, 0]), 0.0
    # Clamp w into [-1, 1] so acos never sees a rounding overshoot.
    w_clamped = max(min(quat[0], 1), -1)
    return vec / math.sqrt(vec_norm_sq), 2 * math.acos(w_clamped)
435
+
436
+
437
def quat2euler(quaternion, axes="sxyz"):
    """Euler angles from a (w, x, y, z) `quaternion` for axis sequence `axes`.

    Parameters
    ----------
    quaternion : 4 element sequence
        w, x, y, z of quaternion.
    axes : str, optional
        Axis specification; one of 24 axis sequences as string or encoded
        tuple - e.g. ``sxyz`` (the default).

    Returns
    -------
    ai, aj, ak : float
        First, second and third rotation angles (according to `axes`).

    Examples
    --------
    >>> angles = quat2euler([0.99810947, 0.06146124, 0, 0])
    >>> np.allclose(angles, [0.123, 0, 0])
    True
    """
    # Convert through matrix form, which already knows all 24 conventions.
    rotation = quat2mat(quaternion)
    return mat2euler(rotation, axes)
evaluation/robotwin/launch_client.sh ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Launch a single RoboTwin evaluation client for one task, connecting to a
# policy server on $PORT. Results are written under $save_root.
export LD_LIBRARY_PATH=/usr/lib64:/usr/lib:$LD_LIBRARY_PATH

# Reference list of task groups (8 tasks each); not used below — this script
# evaluates only the single $task_name chosen under it.
task_groups=(
    "stack_bowls_three handover_block hanging_mug scan_object lift_pot put_object_cabinet stack_blocks_three place_shoe"
    "adjust_bottle place_mouse_pad dump_bin_bigbin move_pillbottle_pad pick_dual_bottles shake_bottle place_fan turn_switch"
    "shake_bottle_horizontally place_container_plate rotate_qrcode place_object_stand put_bottles_dustbin move_stapler_pad place_burger_fries place_bread_basket"
    "pick_diverse_bottles open_microwave beat_block_hammer press_stapler click_bell move_playingcard_away open_laptop move_can_pot"
    "stack_bowls_two place_a2b_right stamp_seal place_object_basket handover_mic place_bread_skillet stack_blocks_two place_cans_plasticbox"
    "click_alarmclock blocks_ranking_size place_phone_stand place_can_basket place_object_scale place_a2b_left grab_roller place_dual_shoes"
    "place_empty_cup blocks_ranking_rgb place_empty_cup blocks_ranking_rgb place_empty_cup blocks_ranking_rgb place_empty_cup blocks_ranking_rgb"
)

save_root='./results'
task_name="adjust_bottle"

# Policy / checkpoint selection; everything after --overrides below is merged
# into the deploy_policy.yml config by eval_polict_client_openpi.
policy_name=ACT
task_config=demo_clean
train_config_name=0
model_name=0
seed=0
PORT=29056

PYTHONWARNINGS=ignore::UserWarning \
XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 python -m evaluation.robotwin.eval_polict_client_openpi --config policy/$policy_name/deploy_policy.yml \
    --overrides \
    --task_name ${task_name} \
    --task_config ${task_config} \
    --train_config_name ${train_config_name} \
    --model_name ${model_name} \
    --ckpt_setting ${model_name} \
    --seed ${seed} \
    --policy_name ${policy_name} \
    --save_root ${save_root} \
    --video_guidance_scale 5 \
    --action_guidance_scale 1 \
    --test_num 100 \
    --port ${PORT}
+ --port ${PORT}
39
+
40
+
evaluation/robotwin/launch_client_multigpus.sh ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Launch one evaluation client per task in a task group, spreading them over
# $num_gpus GPUs and consecutive server ports. Usage:
#   launch_client_multigpus.sh [save_root] [task_list_id] [seed] [test_num]
export LD_LIBRARY_PATH=/usr/lib64:/usr/lib:$LD_LIBRARY_PATH


save_root=${1:-'./results'}

# General parameters
policy_name=ACT
task_config=demo_clean
train_config_name=0
model_name=0
seed=${3:-0}
test_num=${4:-100}
start_port=29556
num_gpus=8

# Which of the task_groups below to run (0-based index).
task_list_id=${2:-0}

task_groups=(
    "stack_bowls_three handover_block hanging_mug scan_object lift_pot put_object_cabinet stack_blocks_three place_shoe"
    "adjust_bottle place_mouse_pad dump_bin_bigbin move_pillbottle_pad pick_dual_bottles shake_bottle place_fan turn_switch"
    "shake_bottle_horizontally place_container_plate rotate_qrcode place_object_stand put_bottles_dustbin move_stapler_pad place_burger_fries place_bread_basket"
    "pick_diverse_bottles open_microwave beat_block_hammer press_stapler click_bell move_playingcard_away open_laptop move_can_pot"
    "stack_bowls_two place_a2b_right stamp_seal place_object_basket handover_mic place_bread_skillet stack_blocks_two place_cans_plasticbox"
    "click_alarmclock blocks_ranking_size place_phone_stand place_can_basket place_object_scale place_a2b_left grab_roller place_dual_shoes"
    "place_empty_cup blocks_ranking_rgb place_empty_cup blocks_ranking_rgb place_empty_cup blocks_ranking_rgb place_empty_cup blocks_ranking_rgb"
)

if (( task_list_id < 0 || task_list_id >= ${#task_groups[@]} )); then
    echo "task_list_id out of range: $task_list_id (0..$(( ${#task_groups[@]} - 1 )))" >&2
    exit 1
fi

# Split the chosen group (a space-separated string) into an array of tasks.
read -r -a task_names <<< "${task_groups[$task_list_id]}"

echo "task_list_id=$task_list_id"
printf 'task_names (%d): %s\n' "${#task_names[@]}" "${task_names[*]}"

log_dir="./logs"
mkdir -p "$log_dir"

echo -e "\033[32mLaunching ${#task_names[@]} tasks. GPUs assigned by mod ${num_gpus}, ports starting from ${start_port} incrementing.\033[0m"

# Collect background PIDs so the whole batch can be killed later.
pid_file="pids.txt"
> "$pid_file"

batch_time=$(date +%Y%m%d_%H%M%S)

for i in "${!task_names[@]}"; do
    task_name="${task_names[$i]}"
    # Round-robin GPU assignment; each task talks to its own server port.
    gpu_id=$(( i % num_gpus ))
    port=$(( start_port + i ))

    export CUDA_VISIBLE_DEVICES=${gpu_id}

    log_file="${log_dir}/${task_name}_${batch_time}.log"

    echo -e "\033[33m[Task $i] Task: ${task_name}, GPU: ${gpu_id}, PORT: ${port}, Log: ${log_file}\033[0m"

    PYTHONWARNINGS=ignore::UserWarning \
    XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 python -m evaluation.robotwin.eval_polict_client_openpi --config policy/$policy_name/deploy_policy.yml \
        --overrides \
        --task_name ${task_name} \
        --task_config ${task_config} \
        --train_config_name ${train_config_name} \
        --model_name ${model_name} \
        --ckpt_setting ${model_name} \
        --seed ${seed} \
        --policy_name ${policy_name} \
        --save_root ${save_root} \
        --video_guidance_scale 5 \
        --action_guidance_scale 1 \
        --test_num ${test_num} \
        --port ${port} > "$log_file" 2>&1 &

    pid=$!
    echo "${pid}" | tee -a "$pid_file"
done

echo -e "\033[32mAll tasks launched. PIDs saved to ${pid_file}\033[0m"
echo -e "\033[36mTo terminate all processes, run: kill \$(cat ${pid_file})\033[0m"
+ echo -e "\033[36mTo terminate all processes, run: kill \$(cat ${pid_file})\033[0m"
evaluation/robotwin/launch_server.sh ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ START_PORT=${START_PORT:-29056}
2
+ MASTER_PORT=${MASTER_PORT:-29061}
3
+
4
+ save_root='visualization/'
5
+ mkdir -p $save_root
6
+
7
+ python -m torch.distributed.run \
8
+ --nproc_per_node 1 \
9
+ --master_port $MASTER_PORT \
10
+ wan_va/wan_va_server.py \
11
+ --config-name robotwin \
12
+ --port $START_PORT \
13
+ --save_root $save_root
14
+
15
+
evaluation/robotwin/launch_server_multigpus.sh ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ START_PORT=${START_PORT:-29556}
2
+ MASTER_PORT=${MASTER_PORT:-29661}
3
+ LOG_DIR='./logs'
4
+ mkdir -p $LOG_DIR
5
+
6
+ save_root='./visualization/'
7
+ mkdir -p $save_root
8
+
9
+ batch_time=$(date +%Y%m%d_%H%M%S)
10
+
11
+
12
+ for i in {0..7}; do
13
+ CURRENT_PORT=$((START_PORT + i))
14
+ CURRENT_MASTER_PORT=$((MASTER_PORT + i))
15
+
16
+ LOG_FILE="${LOG_DIR}/server_${i}_${batch_time}.log"
17
+ echo "[Task ${j}] GPU: ${i} | PORT: ${CURRENT_PORT} | MASTER_PORT: ${CURRENT_MASTER_PORT} | Log: ${LOG_FILE}"
18
+
19
+ CUDA_VISIBLE_DEVICES=$i \
20
+ nohup python -m torch.distributed.run \
21
+ --nproc_per_node 1 \
22
+ --master_port $CURRENT_MASTER_PORT \
23
+ wan_va/wan_va_server.py \
24
+ --config-name robotwin \
25
+ --save_root $save_root \
26
+ --port $CURRENT_PORT > $LOG_FILE 2>&1 &
27
+ sleep 2;
28
+ done
29
+
30
+ echo "All 8 instances have been launched in the background."
31
+ wait
evaluation/robotwin/msgpack_numpy.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Adds NumPy array support to msgpack.
2
+
3
+ msgpack is good for (de)serializing data over a network for multiple reasons:
4
+ - msgpack is secure (as opposed to pickle/dill/etc which allow for arbitrary code execution)
5
+ - msgpack is widely used and has good cross-language support
6
+ - msgpack does not require a schema (as opposed to protobuf/flatbuffers/etc) which is convenient in dynamically typed
7
+ languages like Python and JavaScript
8
+ - msgpack is fast and efficient (as opposed to readable formats like JSON/YAML/etc); I found that msgpack was ~4x faster
9
+ than pickle for serializing large arrays using the below strategy
10
+
11
+ The code below is adapted from https://github.com/lebedov/msgpack-numpy. The reason not to use that library directly is
12
+ that it falls back to pickle for object arrays.
13
+ """
14
+
15
+ import functools
16
+
17
+ import msgpack
18
+ import numpy as np
19
+
20
+
21
+ def pack_array(obj):
22
+ if (isinstance(obj, (np.ndarray, np.generic))) and obj.dtype.kind in ("V", "O", "c"):
23
+ raise ValueError(f"Unsupported dtype: {obj.dtype}")
24
+
25
+ if isinstance(obj, np.ndarray):
26
+ return {
27
+ b"__ndarray__": True,
28
+ b"data": obj.tobytes(),
29
+ b"dtype": obj.dtype.str,
30
+ b"shape": obj.shape,
31
+ }
32
+
33
+ if isinstance(obj, np.generic):
34
+ return {
35
+ b"__npgeneric__": True,
36
+ b"data": obj.item(),
37
+ b"dtype": obj.dtype.str,
38
+ }
39
+
40
+ return obj
41
+
42
+
43
+ def unpack_array(obj):
44
+ if b"__ndarray__" in obj:
45
+ return np.ndarray(buffer=obj[b"data"], dtype=np.dtype(obj[b"dtype"]), shape=obj[b"shape"])
46
+
47
+ if b"__npgeneric__" in obj:
48
+ return np.dtype(obj[b"dtype"]).type(obj[b"data"])
49
+
50
+ return obj
51
+
52
+
53
+ Packer = functools.partial(msgpack.Packer, default=pack_array)
54
+ packb = functools.partial(msgpack.packb, default=pack_array)
55
+
56
+ Unpacker = functools.partial(msgpack.Unpacker, object_hook=unpack_array)
57
+ unpackb = functools.partial(msgpack.unpackb, object_hook=unpack_array)
evaluation/robotwin/test_render.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import warnings
3
+ import os
4
+
5
+ warnings.simplefilter(action="ignore", category=FutureWarning)
6
+ warnings.simplefilter(action="ignore", category=UserWarning)
7
+ current_file_path = os.path.abspath(__file__)
8
+ parent_dir = os.path.dirname(current_file_path)
9
+
10
+ sys.path.append(os.path.join(parent_dir, "../../tools"))
11
+ import numpy as np
12
+ import pdb
13
+ import json
14
+ import torch
15
+ import sapien.core as sapien
16
+ from sapien.utils.viewer import Viewer
17
+ import gymnasium as gym
18
+ import toppra as ta
19
+ import transforms3d as t3d
20
+ from collections import OrderedDict
21
+
22
+ import sys
23
+ import warnings
24
+ import os
25
+
26
+ warnings.simplefilter(action="ignore", category=FutureWarning)
27
+ warnings.simplefilter(action="ignore", category=UserWarning)
28
+ current_file_path = os.path.abspath(__file__)
29
+ parent_dir = os.path.dirname(current_file_path)
30
+
31
+ sys.path.append(os.path.join(parent_dir, "../../tools"))
32
+ import numpy as np
33
+ import pdb
34
+ import json
35
+ import torch
36
+ import sapien.core as sapien
37
+ from sapien.utils.viewer import Viewer
38
+ import gymnasium as gym
39
+ import toppra as ta
40
+ import transforms3d as t3d
41
+ from collections import OrderedDict
42
+
43
+
44
+ class Sapien_TEST(gym.Env):
45
+
46
+ def __init__(self):
47
+ super().__init__()
48
+ ta.setup_logging("CRITICAL") # hide logging
49
+ try:
50
+ self.setup_scene()
51
+ print("\033[32m" + "Render Well" + "\033[0m")
52
+ except:
53
+ print("\033[31m" + "Render Error" + "\033[0m")
54
+ exit()
55
+
56
+ def setup_scene(self, **kwargs):
57
+ """
58
+ Set the scene
59
+ - Set up the basic scene: light source, viewer.
60
+ """
61
+ self.engine = sapien.Engine()
62
+ # declare sapien renderer
63
+ from sapien.render import set_global_config
64
+
65
+ set_global_config(max_num_materials=50000, max_num_textures=50000)
66
+ self.renderer = sapien.SapienRenderer()
67
+ # give renderer to sapien sim
68
+ self.engine.set_renderer(self.renderer)
69
+
70
+ sapien.render.set_camera_shader_dir("rt")
71
+ sapien.render.set_ray_tracing_samples_per_pixel(32)
72
+ sapien.render.set_ray_tracing_path_depth(8)
73
+ sapien.render.set_ray_tracing_denoiser("oidn")
74
+
75
+ # declare sapien scene
76
+ scene_config = sapien.SceneConfig()
77
+ self.scene = self.engine.create_scene(scene_config)
78
+
79
+
80
+ if __name__ == "__main__":
81
+ a = Sapien_TEST()
evaluation/robotwin/websocket_client_policy.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import time
3
+ from typing import Dict, Optional, Tuple
4
+
5
+ from typing_extensions import override
6
+ import websockets.sync.client
7
+ from .msgpack_numpy import Packer, unpackb
8
+
9
+
10
class WebsocketClientPolicy:
    """Implements the Policy interface by communicating with a server over websocket.

    See WebsocketPolicyServer for a corresponding server implementation.
    """

    def __init__(self, host: str = "0.0.0.0", port: Optional[int] = None, api_key: Optional[str] = None) -> None:
        """Connect to the policy server, blocking (with retries) until it is up.

        Args:
            host: server hostname or IP.
            port: optional server port appended to the URI.
            api_key: optional API key sent as an ``Authorization`` header.
        """
        self._uri = f"ws://{host}"
        if port is not None:
            self._uri += f":{port}"
        self._packer = Packer()
        self._api_key = api_key
        self._ws, self._server_metadata = self._wait_for_server()

    def get_server_metadata(self) -> Dict:
        """Return the metadata dict the server sent on connection."""
        return self._server_metadata

    def _wait_for_server(self) -> Tuple[websockets.sync.client.ClientConnection, Dict]:
        """Retry connecting every 5 s until the server accepts and sends metadata."""
        logging.info(f"Waiting for server at {self._uri}...")
        while True:
            try:
                headers = {"Authorization": f"Api-Key {self._api_key}"} if self._api_key else None
                # Disable the websocket ping mechanism: a long-running
                # inference call would otherwise trip the keepalive timeout.
                conn = websockets.sync.client.connect(
                    self._uri,
                    compression=None,
                    max_size=None,
                    additional_headers=headers,
                    ping_interval=None,
                    close_timeout=10
                )
                metadata = unpackb(conn.recv())
                return conn, metadata
            except Exception as e:
                # Deliberately broad: the server may not be up yet, so DNS,
                # refused-connection and handshake errors all mean "retry".
                logging.info(f"Still waiting for server... (Error: {e})")
                time.sleep(5)

    @override
    def infer(self, obs: Dict) -> Dict:  # noqa: UP006
        """Send one observation dict and return the server's inference result.

        Raises:
            RuntimeError: if the server replies with a text frame (its error
                reporting channel) instead of msgpack bytes.
        """
        data = self._packer.pack(obs)
        self._ws.send(data)
        response = self._ws.recv()
        if isinstance(response, str):
            # we're expecting bytes; if the server sends a string, it's an error.
            raise RuntimeError(f"Error in inference server:\n{response}")
        return unpackb(response)

    @override
    def reset(self) -> None:
        """No client-side state to reset; kept for Policy-interface parity."""
        pass
74
+
75
if __name__ == "__main__":
    # Manual smoke test: connect to a local policy server and send one
    # random observation, then drop into an IPython shell for inspection.
    policy_on_device = WebsocketClientPolicy(port=8000)
    import torch
    import numpy as np
    from PIL import Image
    from .image_tools import convert_to_uint8
    device = torch.device("cuda")

    # Random camera frames and state in the shapes the server expects.
    # NOTE(review): shapes assume (1, 3, 224, 224) images and an 8-dim state
    # — confirm against the deployed server's config.
    base_0_rgb = np.random.randint(0, 256, size=(1, 3, 224, 224), dtype=np.uint8)
    left_wrist_0_rgb = np.random.randint(0, 256, size=(1, 3, 224, 224), dtype=np.uint8)
    state = np.random.rand(1,8).astype(np.float32)
    prompt = ["do something"]

    # observation = {
    #     "image": {
    #         "base_0_rgb": torch.from_numpy(base_0_rgb).to(device)[None],
    #         "left_wrist_0_rgb": torch.from_numpy(left_wrist_0_rgb).to(device)[None],
    #     },
    #     "state": torch.from_numpy(state).to(device)[None],
    #     "prompt": prompt,
    # }

    # The right-wrist camera reuses the left-wrist frame for this smoke test.
    observation = {
        "image": {
            "base_0_rgb": convert_to_uint8(base_0_rgb),
            "left_wrist_0_rgb": convert_to_uint8(left_wrist_0_rgb),
            "right_wrist_0_rgb": convert_to_uint8(left_wrist_0_rgb),
        },
        "state": state,
        "prompt": prompt,
    }

    policy_on_device.infer(observation)
    from IPython import embed;embed()
example/franka/observation.images.cam_high.png ADDED
example/franka/observation.images.cam_left_wrist.png ADDED
example/franka/observation.images.cam_right_wrist.png ADDED
example/robotwin/observation.images.cam_high.png ADDED
example/robotwin/observation.images.cam_left_wrist.png ADDED
example/robotwin/observation.images.cam_right_wrist.png ADDED
lingbot_robotwin_policy.py ADDED
@@ -0,0 +1,506 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import time
4
+ import random
5
+ import numpy as np
6
+ from collections import deque
7
+ import torchvision
8
+ import yaml
9
+ from types import SimpleNamespace
10
+ from packaging.version import Version
11
+ from typing import Callable, Dict, List, Optional, Type, Union, Tuple, Any, Sequence
12
+ from glob import glob
13
+ from tqdm import tqdm
14
+ from safetensors import safe_open
15
+ from safetensors.torch import load_file
16
+ from pathlib import Path
17
+ from PIL import Image
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from torch import Tensor, nn
21
+
22
+
23
+ import transformers
24
+ from transformers.models.auto.tokenization_auto import AutoTokenizer
25
+ from transformers import (
26
+ AutoConfig,
27
+ PretrainedConfig,
28
+ PreTrainedModel,
29
+ AutoProcessor,
30
+ )
31
+
32
+ from lerobot.configs.policies import PreTrainedConfig
33
+ from lingbotvla.models.vla.pi0.modeling_pi0 import PI0Policy
34
+ from lingbotvla.models.vla.pi0.modeling_lingbot_vla import LingbotVlaPolicy
35
+ from lingbotvla.data.vla_data.transform import Normalizer, prepare_images, prepare_language, prepare_state
36
+ from lingbotvla.models import build_processor
37
+
38
+
39
def set_seed_everywhere(seed: int):
    """Sets the random seed for Python, NumPy, and PyTorch functions."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning for reproducible kernel selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ["PYTHONHASHSEED"] = str(seed)
48
+
49
# Seed all RNGs at import time so evaluation runs are reproducible.
set_seed_everywhere(42)

# Filesystem paths of the base VLM checkpoints, overridable via environment
# variables; keyed by the policy family chosen in QwenPiServer.load_vla().
BASE_MODEL_PATH = {
    'pi0': os.environ.get('PALIGEMMA_PATH', './paligemma-3b-pt-224/'),
    'lingbotvla': os.environ.get('QWEN25_PATH', './Qwen2.5-VL-3B-Instruct/'),
}
55
+
56
def load_model_weights(policy, path_to_pi_model, strict=True):
    """Load every ``*.safetensors`` shard under ``path_to_pi_model`` into ``policy``."""
    shard_paths = glob(os.path.join(path_to_pi_model, "*.safetensors"))
    state_dict = {}
    for shard in tqdm(shard_paths):
        with safe_open(shard, framework="pt", device="cpu") as handle:
            for name in handle.keys():
                state_dict[name] = handle.get_tensor(name)
    policy.load_state_dict(state_dict, strict=strict)
65
+
66
+
67
def center_crop_image(image: Union[np.ndarray, Image.Image]) -> Image.Image:
    """Center-crop ~90% of the image AREA and resize to a 224x224 RGB image."""
    crop_scale = 0.9
    # Keeping `crop_scale` of the area means each SIDE shrinks by sqrt(crop_scale).
    side_scale = float(np.sqrt(np.clip(crop_scale, 0.0, 1.0)))
    out_size = (224, 224)

    # Normalize the input into a PIL image.
    if isinstance(image, Image.Image):
        pil = image
    elif isinstance(image, np.ndarray):
        arr = image
        if arr.dtype.kind == "f":
            # Floats in [0, 1] are rescaled; other floats are clamped to [0, 255].
            if arr.min() >= 0.0 and arr.max() <= 1.0:
                arr = (np.clip(arr, 0.0, 1.0) * 255.0).astype(np.uint8)
            else:
                arr = np.clip(arr, 0.0, 255.0).astype(np.uint8)
        elif arr.dtype == np.uint16:
            arr = (arr / 257).astype(np.uint8)  # map 16-bit to 8-bit
        elif arr.dtype != np.uint8:
            arr = arr.astype(np.uint8)
        pil = Image.fromarray(arr)
    else:
        raise TypeError("image must be a numpy array or PIL.Image.Image")

    # Force RGB so the output channel layout is consistent.
    pil = pil.convert("RGB")
    W, H = pil.size

    # Centered crop box in integer pixels (never degenerate).
    crop_w = max(1, int(round(W * side_scale)))
    crop_h = max(1, int(round(H * side_scale)))
    left = (W - crop_w) // 2
    top = (H - crop_h) // 2
    box = (left, top, left + crop_w, top + crop_h)

    return pil.crop(box).resize(out_size, resample=Image.BILINEAR)
107
+
108
def resize_with_pad(img, width, height, pad_value=-1):
    """Resize a (b, c, h, w) batch to fit (height, width), padding left/top.

    Aspect ratio is preserved; content ends up in the bottom-right corner and
    the left/top margins are filled with ``pad_value``.
    """
    if img.ndim != 4:
        raise ValueError(f"(b,c,h,w) expected, but {img.shape}")

    # Accept channel-last input and move channels to dim 1.
    if img.shape[1] not in (1, 3) and img.shape[-1] in (1, 3):
        img = img.permute(0, 3, 1, 2)

    src_h, src_w = img.shape[2:]
    scale = max(src_w / width, src_h / height)
    new_h = int(src_h / scale)
    new_w = int(src_w / scale)

    resized = F.interpolate(
        img, size=(new_h, new_w), mode="bilinear", align_corners=False
    )

    extra_h = max(0, int(height - new_h))
    extra_w = max(0, int(width - new_w))

    # F.pad order is (left, right, top, bottom): pad only left and top.
    return F.pad(resized, (extra_w, 0, extra_h, 0), value=pad_value)
132
+
133
class PolicyPreprocessMixin:
    """Adds a ``select_action`` convenience wrapper around ``model.sample_actions``.

    The host class must provide ``self.model`` (exposing ``sample_actions``)
    and ``self.normalizer`` (exposing ``unnormalize``).
    """

    @torch.no_grad
    def select_action(
        self, observation: dict[str, Tensor], use_bf16: bool = False, vlm_causal: bool = False, noise: Tensor | None = None
    ):
        """Run one sampling pass and return the unnormalized observation dict.

        Args:
            observation: preprocessed inputs — ``images``, ``img_masks``,
                ``lang_tokens``, ``lang_masks``, ``state`` (optionally
                ``expert_imgs``); mutated in place (``'action'`` is added).
            use_bf16: cast floating inputs to bfloat16 before inference.
            vlm_causal: forwarded to ``sample_actions``.
            noise: unused; kept for interface compatibility.
        """
        self.eval()
        device = 'cuda'
        dtype = torch.bfloat16 if use_bf16 else torch.float32
        start = time.time()

        # Allow an unbatched (n_cams, c, h, w) image stack by adding a batch dim.
        if len(observation['images'].shape) == 4:
            observation['images'] = observation['images'].unsqueeze(0)
            observation['img_masks'] = observation['img_masks'].unsqueeze(0)

        # The two original call sites differed only by the optional
        # ``expert_imgs`` argument — build the argument list once instead.
        args = [
            observation['images'].to(dtype=dtype, device=device),
            observation['img_masks'].to(device=device),
            observation['lang_tokens'].unsqueeze(0).to(device=device),
            observation['lang_masks'].unsqueeze(0).to(device=device),
            observation['state'].unsqueeze(0).to(dtype=dtype, device=device),
        ]
        if 'expert_imgs' in observation:
            args.append(observation['expert_imgs'].to(dtype=dtype, device=device))
        actions = self.model.sample_actions(*args, vlm_causal=vlm_causal)

        delta_time = time.time() - start
        print(f'sample_actions cost {delta_time} s')
        # NOTE(review): only the first 14 action channels are kept here —
        # confirm this matches the robot's action dimensionality upstream.
        observation['action'] = actions.squeeze(0)[:, :14].to(dtype=torch.float32, device='cpu')
        if use_bf16:
            observation['state'] = observation['state'].to(dtype=torch.float32)
        return self.normalizer.unnormalize(observation)
177
+
178
# Inference-only variant: mixes select_action() into the LingBot-VLA policy.
class LingBotVlaInferencePolicy(PolicyPreprocessMixin, LingbotVlaPolicy):
    pass  # Only combine necessary functions
180
+
181
# Inference-only variant for PI0. NOTE: "Infernece" is a typo, kept as-is so
# existing imports of this name keep working.
class PI0InfernecePolicy(PolicyPreprocessMixin, PI0Policy):
    pass  # Only combine necessary functions
183
+
184
+
185
def merge_qwen_config(policy_config, qwen_config):
    """Copy Qwen text-model and vision settings onto ``policy_config``.

    Args:
        policy_config: object whose attributes are set in place.
        qwen_config: a HF config object (anything with ``to_dict``) or a
            plain dict of the same shape.

    Returns:
        The same ``policy_config`` object, for chaining.
    """
    if hasattr(qwen_config, 'to_dict'):
        config_dict = qwen_config.to_dict()
    else:
        config_dict = qwen_config

    # Text-model hyperparameters mirrored onto the policy config.
    text_keys = {
        "hidden_size",
        "intermediate_size",
        "num_hidden_layers",
        "num_attention_heads",
        "num_key_value_heads",
        "rms_norm_eps",
        "rope_theta",
        "vocab_size",
        "max_position_embeddings",
        "hidden_act",
        "tie_word_embeddings",
        "tokenizer_path",
    }

    for key in text_keys:
        if key in config_dict:
            setattr(policy_config, key, config_dict[key])
            print(f"✅ Merged LLM: {key} = {config_dict[key]}")

    if "vision_config" in config_dict:
        # BUG FIX: the original read ``qwen_config.vision_config`` directly,
        # which raises AttributeError when ``qwen_config`` is a plain dict —
        # a case this function explicitly supports. Prefer the attribute (a
        # config object on HF configs), fall back to the dict entry.
        policy_config.vision_config = getattr(
            qwen_config, "vision_config", config_dict["vision_config"]
        )
    else:
        print("⚠️ Warning: 'vision_config' not found in qwen_config!")

    return policy_config
217
+
218
+
219
class QwenPiServer:
    '''
    Policy wrapper to support action ensemble or chunk execution.

    Loads a trained PI0 / LingBot-VLA checkpoint plus its training-time
    configuration, and serves ``infer(observation) -> {'action': ...}`` calls
    (typically behind a WebsocketPolicyServer).

    NOTE(review): this block was reconstructed from a whitespace-mangled
    dump; statement nesting inside ``infer`` was inferred and should be
    checked against the original file.
    '''
    def __init__(
        self,
        path_to_pi_model="",
        adaptive_ensemble_alpha=0.1,
        action_ensemble_horizon=8,
        use_length=1,  # to control the execution length of the action chunk, -1 denotes using action ensemble
        chunk_ret=False,
        use_bf16=True,
        use_fp32=False,
    ) -> None:
        # Exactly one precision mode may be forced.
        assert not (use_bf16 and use_fp32), 'Bfloat16 or Float32!!!'
        self.adaptive_ensemble_alpha = adaptive_ensemble_alpha
        self.use_length = use_length
        self.chunk_ret = chunk_ret

        self.task_description = None

        self.vla = self.load_vla(path_to_pi_model)
        self.vla = self.vla.cuda().eval()
        if use_bf16:
            self.vla = self.vla.to(torch.bfloat16)
        elif use_fp32:
            self.vla.model.float()
        # Counts infer() calls; used to index into the cached action chunk.
        self.global_step = 0
        self.last_action_chunk = None
        self.use_bf16 = use_bf16
        self.use_fp32 = use_fp32

    def load_vla(self, path_to_pi_model) -> LingbotVlaPolicy:
        """Build the policy from a checkpoint dir plus its training YAML.

        Side effects: also populates ``self.processor``, ``self.data_config``,
        ``self.config`` and the action-dimension bookkeeping attributes.
        """
        print(f"loading model from: {path_to_pi_model}")
        config = PreTrainedConfig.from_pretrained(path_to_pi_model)

        # The training config is expected three directories above the
        # checkpoint: <run>/lingbotvla_cli.yaml.
        training_config_path = Path(path_to_pi_model).parent.parent.parent / 'lingbotvla_cli.yaml'
        with open(training_config_path, 'r') as f:
            training_config = yaml.safe_load(f)
        f.close()  # NOTE(review): redundant — the `with` block already closed f

        # Update model config according to training config.
        # NOTE(review): because of the getattr() default, values already
        # present on `config` take precedence over the training YAML — the
        # opposite of what the comment above suggests. Confirm intent.
        training_model_config = training_config['model']
        training_model_config.update(training_config['train'])
        for k, v in training_model_config.items():
            v = getattr(config, k, training_model_config[k])
            setattr(config, k, v)

        # Set attention_implementation to 'eager' to speed up evaluation.
        config.attention_implementation = 'eager'

        # Choose the policy family from the tokenizer path in the YAML.
        training_base_model = training_config['model']['tokenizer_path']
        if 'paligemma' in training_base_model:
            model_name = 'pi0'
            config.vocab_size = 257152  # set vocab size for paligamma
        elif 'qwen2' in training_base_model.lower():
            model_name = 'lingbotvla'
        else:
            raise ValueError(f"Unsupported base model of {path_to_pi_model}")
        base_model_path = BASE_MODEL_PATH[model_name]
        config.tokenizer_path = base_model_path
        self.model_name = model_name

        qwen_config = AutoConfig.from_pretrained(base_model_path)
        config = merge_qwen_config(config, qwen_config)

        # A non-zero vocab_size in the YAML overrides the base model's.
        if 'vocab_size' in training_config['model'] and training_config['model']['vocab_size'] != 0:
            config.vocab_size = training_config['model']['vocab_size']

        # Load processors (tokenizer + image processor).
        self.processor = build_processor(base_model_path)
        self.language_tokenizer = self.processor.tokenizer
        self.image_processor = self.processor.image_processor
        data_config = SimpleNamespace(**training_config['data'])

        print('Initializing model ... ')

        if 'paligemma' in training_base_model:
            policy = PI0InfernecePolicy(config, tokenizer_path=base_model_path)
        else:
            policy = LingBotVlaInferencePolicy(config, tokenizer_path=base_model_path)

        load_model_weights(policy, path_to_pi_model, strict=True)

        policy.feature_transform = None
        self.data_config = data_config
        self.config = config
        self.joint_max_dim = training_config['train']['max_action_dim']
        self.action_dim = training_config['train']['action_dim']
        self.chunk_size = training_config['train']['chunk_size']
        policy.action_dim = self.action_dim
        policy.chunk_size = self.chunk_size
        self.norm_stats_file = data_config.norm_stats_file
        if 'align_params' in training_config['train']:
            self.use_depth_align = True
        else:
            self.use_depth_align = False
        with open(self.norm_stats_file) as f:
            self.norm_stats = json.load(f)
        # Images pass through untouched; only state/action get normalized.
        policy.normalizer = Normalizer(
            norm_stats=self.norm_stats['norm_stats'],
            from_file=True,
            data_type='robotwin',
            norm_type={
                "observation.images.cam_high": "identity",
                "observation.images.cam_left_wrist": "identity",
                "observation.images.cam_right_wrist": "identity",
                "observation.state": self.data_config.norm_type,
                "action": self.data_config.norm_type,
            },
        )

        print('Model initialized ... ')

        return policy

    def reset(self, robo_name, path_to_pi_model=None) -> None:
        """Reset episode state; optionally hot-swap to a new checkpoint.

        Args:
            robo_name: accepted for interface compatibility; not used here.
            path_to_pi_model: when given, reloads the model from this path.
        """
        if path_to_pi_model is not None:
            self.vla = self.load_vla(path_to_pi_model)
            self.vla = self.vla.cuda().eval()
            if self.use_bf16:
                self.vla = self.vla.to(torch.bfloat16)
            elif self.use_fp32:
                self.vla.model.float()

        self.global_step = 0
        self.last_action_chunk = None

        # Backfill defaults for fields older configs may be missing.
        if getattr(self.data_config, 'norm_type', None) is None:
            self.data_config.norm_type = 'meanstd'
        if getattr(self.config, 'vlm_causal', None) is None:
            self.config.vlm_causal = False
        if getattr(self.config, 'qwenvl_bos', None) is None:
            self.config.qwenvl_bos = False

        # If the ckpt path was updated, reload the merged safetensors weights.
        # NOTE(review): load_vla() above already loaded these same weights.
        if path_to_pi_model is not None:
            all_safetensors = glob(os.path.join(path_to_pi_model, "*.safetensors"))
            merged_weights = {}

            for file_path in tqdm(all_safetensors):
                with safe_open(file_path, framework="pt", device="cpu") as f:
                    for key in f.keys():
                        merged_weights[key] = f.get_tensor(key)

            self.vla.load_state_dict(merged_weights, strict=True)

    def resize_image(self, observation):
        """In place: resize each camera image to (img_size, img_size),
        transpose HWC -> CHW and scale to floats in [0, 1]."""
        for image_feature in ['observation.images.cam_high', 'observation.images.cam_left_wrist', 'observation.images.cam_right_wrist']:
            assert image_feature in observation
            assert len(observation[image_feature].shape) == 3 and observation[image_feature].shape[-1] == 3
            image = observation[image_feature]
            img_pil = Image.fromarray(image)
            image_size = getattr(self.data_config, 'img_size', 224)
            img_pil = img_pil.resize((image_size, image_size), Image.BILINEAR)

            # img_resized shape: C*H*W
            img_resized = np.transpose(np.array(img_pil), (2, 0, 1))  # (3,224,224)
            observation[image_feature] = img_resized / 255.

    def infer(self, observation, center_crop=True):
        """Generates an action with the VLA policy."""

        # (If trained with image augmentations) Center crop image and then resize back up to original size.
        # IMPORTANT: Let's say crop scale == 0.9. To get the new height and width (post-crop), multiply
        # the original height and width by sqrt(0.9) -- not 0.9!

        # A reset request carries no observation to act on.
        if 'reset' in observation and observation['reset']:
            self.reset(robo_name=observation['robo_name'], path_to_pi_model=observation['path_to_pi_model'] if 'path_to_pi_model' in observation else None)
            return dict(action=None)

        # Mutates the incoming dict: images resized/normalized, arrays -> tensors.
        self.resize_image(observation)
        for k, v in observation.items():
            if isinstance(v, np.ndarray):
                observation[k] = torch.from_numpy(v)

        # Re-run the model only at chunk boundaries (or always if use_length == -1).
        if self.use_length == -1 or self.global_step % self.use_length == 0:
            joint_max_dim = getattr(self, 'joint_max_dim')
            action_dim = getattr(self, 'action_dim')
            chunk_size = getattr(self, 'chunk_size')
            normalized_observation = self.vla.normalizer.normalize(observation)
            base_image = (normalized_observation["observation.images.cam_high"] * 255).to(torch.uint8)
            left_wrist_image = (normalized_observation["observation.images.cam_left_wrist"] * 255).to(
                torch.uint8
            )
            right_wrist_image = (normalized_observation["observation.images.cam_right_wrist"] * 255).to(
                torch.uint8
            )
            obs_dict = {
                "image": {"base_0_rgb": base_image, "left_wrist_0_rgb": left_wrist_image, "right_wrist_0_rgb": right_wrist_image},
                "state": normalized_observation["observation.state"].to(torch.float32),
                "prompt": [observation["task"]],
            }
            state = prepare_state(self.config, obs_dict)
            lang_tokens, lang_masks = prepare_language(self.config, self.language_tokenizer, obs_dict)
            images, img_masks, _ = prepare_images(self.config, self.image_processor, obs_dict)
            observation = {
                'images': images,
                'img_masks': img_masks,
                'state': state,
                'lang_tokens': lang_tokens,
                'lang_masks': lang_masks,
            }

            if self.use_bf16:
                observation['state'] = observation['state'].to(torch.bfloat16)

            org_actions = ['action']
            assert len(org_actions) == 1, "Only support single action feature"
            if self.chunk_ret:
                # Return the whole horizon: (use_length, action_dim).
                action = self.vla.select_action(observation, self.use_bf16, self.config.vlm_causal)[org_actions[0]].float().cpu().numpy()
                action = action[:self.use_length, :self.action_dim]
            else:
                if self.use_length == -1 or self.global_step % self.use_length == 0:
                    action = self.vla.select_action(observation, self.use_bf16, self.config.vlm_causal)[org_actions[0]]
                    self.last_action_chunk = action.float().cpu().numpy()

                if self.use_length > 0:
                    action = self.last_action_chunk[self.global_step % self.use_length]
                # NOTE(review): when use_length > 0 the line above yields a
                # 1-D vector, so this 2-D slice would raise — confirm the
                # intended shape handling in the non-chunk_ret path.
                action = action[:, :self.action_dim]
        print(f"on server step: {self.global_step}")
        self.global_step += 1

        return dict(action=action)
445
+
446
+
447
+ import argparse
448
+ from .websocket_policy_server import WebsocketPolicyServer
449
+
450
+ def main():
451
+ parser = argparse.ArgumentParser(description="启动 QwenPi WebSocket 策略服务器")
452
+
453
+ parser.add_argument(
454
+ "--model_path",
455
+ type=str,
456
+ )
457
+
458
+ parser.add_argument(
459
+ "--use_length",
460
+ type=int,
461
+ default=50,
462
+ help="used length of action chunk"
463
+ )
464
+
465
+ parser.add_argument(
466
+ "--chunk_ret",
467
+ type=bool,
468
+ default=True,
469
+ help=" True: The returned action tensor includes the horizon dimension. This allows the model to output a sequence of actions for each horizon step. False: The horizon dimension is omitted. The model selects and returns the next step autonomously based on its policy."
470
+ )
471
+
472
+ parser.add_argument(
473
+ "--port",
474
+ type=int,
475
+ default=8006,
476
+ help="port of WebSocket"
477
+ )
478
+
479
+ parser.add_argument(
480
+ "--debug_infer_once",
481
+ action="store_true",
482
+ help="Run one infer with dummy observation then exit (for debugging infer() without WebSocket client)",
483
+ )
484
+
485
+ args = parser.parse_args()
486
+
487
+ model = QwenPiServer(args.model_path, use_length=args.use_length, chunk_ret=args.chunk_ret)
488
+ if args.debug_infer_once:
489
+ # 调试用:不启动 WebSocket,只跑一次 infer,可在 infer / select_action 里下断点
490
+ dummy_obs = {
491
+ "observation.images.cam_high": np.zeros((224, 224, 3), dtype=np.uint8),
492
+ "observation.images.cam_left_wrist": np.zeros((224, 224, 3), dtype=np.uint8),
493
+ "observation.images.cam_right_wrist": np.zeros((224, 224, 3), dtype=np.uint8),
494
+ "observation.state": np.zeros(model.action_dim, dtype=np.float32),
495
+ "task": "dummy task for debug",
496
+ "reset": False,
497
+ }
498
+ out = model.infer(dummy_obs)
499
+ print("debug_infer_once result keys:", out.keys())
500
+ return
501
+ model_server = WebsocketPolicyServer(model, port=args.port)
502
+ model_server.serve_forever()
503
+
504
+
505
if __name__ == "__main__":
    # CLI entry point: start the WebSocket policy server.
    main()
pyproject.toml ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=61.0"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "LingBot_VA"
7
+ version = "0.0.0"
8
+ description = "LingBot-VA: A Pragmatic VA Foundation Model"
9
+ authors = [
10
+ { name = "Robbyant Team", email = "fengchang.ll@antgroup.com" }
11
+ ]
12
+ license = { file = "LICENSE.txt" }
13
+ readme = "README.md"
14
+ requires-python = ">=3.10,<4.0"
15
+ dependencies = [
16
+ "torch>=2.9.0",
17
+ "torchvision>=0.24.0",
18
+ "diffusers>=0.36.0",
19
+ "transformers>=4.55.4",
20
+ "tokenizers>=0.21.4",
21
+ "tqdm",
22
+ "imageio",
23
+ "easydict",
24
+ "flash_attn",
25
+ "numpy>=1.26.4,<2"
26
+ ]
27
+
28
+ [project.optional-dependencies]
29
+ dev = [
30
+ "pytest",
31
+ "black",
32
+ "flake8",
33
+ "isort",
34
+ "mypy",
35
+ "huggingface-hub[cli]"
36
+ ]
37
+
38
+ [project.urls]
39
+ homepage = "https://github.com/Robbyant"
40
+ documentation = "https://github.com/Robbyant"
41
+ repository = "https://github.com/Robbyant"
42
+ huggingface = "https://github.com/Robbyant"
43
+ modelscope = "https://github.com/Robbyant"
44
+ discord = "https://github.com/Robbyant"
45
+
46
+ [tool.setuptools]
47
+ packages = ["lingbot_va"]
48
+
49
+ [tool.setuptools.package-data]
50
+ "lingbot_va" = ["**/*.py"]
51
+
52
+ [tool.black]
53
+ line-length = 88
54
+
55
+ [tool.isort]
56
+ profile = "black"
57
+
58
+ [tool.mypy]
59
+ strict = true
60
+
61
+
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=2.9.0
2
+ torchvision>=0.24.0
3
+ torchaudio
4
+ diffusers>=0.36.0
5
+ transformers>=4.55.4
6
+ tokenizers>=0.21.4
7
+ tqdm
8
+ imageio[ffmpeg]
9
+ easydict
10
+ flash_attn
11
+ numpy>=1.26.4,<2
script/run_launch_va_server_sync.sh ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/bash
# Launch the synchronous VA WebSocket server via torch.distributed.run.
# All knobs come from environment variables; extra CLI arguments are passed
# through verbatim as config overrides.

set -x

umask 007

NGPU=${NGPU:-"8"}                      # processes (GPUs) per node
MASTER_PORT=${MASTER_PORT:-"29501"}    # torch.distributed rendezvous port
PORT=${PORT:-"1106"}                   # NOTE(review): defined but not referenced below — confirm the server reads it from the environment
LOG_RANK=${LOG_RANK:-"0"}              # only this rank's output is shown
TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"}
CONFIG_NAME=${CONFIG_NAME:-"robotwin"} # key into wan_va VA_CONFIGS

# Forward any extra CLI arguments as overrides.
overrides=""
if [ $# -ne 0 ]; then
overrides="$*"
fi

## node setting
num_gpu=${NGPU}
master_port=${MASTER_PORT}
log_rank=${LOG_RANK}
torchft_lighthouse=${TORCHFT_LIGHTHOUSE}
config_name=${CONFIG_NAME}

## cmd setting
export TOKENIZERS_PARALLELISM=false
PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" TORCHFT_LIGHTHOUSE=${torchft_lighthouse} \
python -m torch.distributed.run \
--nproc_per_node=${num_gpu} \
--local-ranks-filter=${log_rank} \
--master_port ${master_port} \
--tee 3 \
-m wan_va.wan_va_server --config-name ${config_name} $overrides
wan_va/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from . import configs, distributed, modules
wan_va/configs/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
2
+ from .va_franka_cfg import va_franka_cfg
3
+ from .va_robotwin_cfg import va_robotwin_cfg
4
+ from .va_franka_i2va import va_franka_i2va_cfg
5
+ from .va_robotwin_i2va import va_robotwin_i2va_cfg
6
+
7
# Registry mapping CLI --config-name values to config objects.
# NOTE(review): the '*_i2av' keys are spelled with "av" while the modules and
# variables use "i2va" — confirm callers use these exact key spellings before
# renaming.
VA_CONFIGS = {
    'robotwin': va_robotwin_cfg,
    'franka': va_franka_cfg,
    'robotwin_i2av': va_robotwin_i2va_cfg,
    'franka_i2av': va_franka_i2va_cfg,
}
wan_va/configs/shared_config.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
2
+ import torch
3
+ from easydict import EasyDict
4
+
5
# Defaults shared by every VA config; concrete configs copy these via
# ``cfg.update(va_shared_cfg)`` and override fields afterwards.
va_shared_cfg = EasyDict()

# WebSocket server bind address.
va_shared_cfg.host = '0.0.0.0'
va_shared_cfg.port = 29536

# Model parameter dtype.
va_shared_cfg.param_dtype = torch.bfloat16
# Output directory for saved visualizations.
va_shared_cfg.save_root = './visualization'

# Patchification factors — presumably (temporal, height, width); confirm
# against the transformer's patch embedding.
va_shared_cfg.patch_size = (1, 2, 2)
wan_va/configs/va_franka_cfg.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
2
+ import torch
3
+ from easydict import EasyDict
4
+
5
+ from .shared_config import va_shared_cfg
6
+
7
# Franka-arm VA config: copies the shared defaults, then overrides below.
va_franka_cfg = EasyDict(__name__='Config: VA franka')
va_franka_cfg.update(va_shared_cfg)
# NOTE(review): this assigns infer_mode on the SHARED config object, not on
# va_franka_cfg (which copied the shared values on the line above, before
# this field existed). It therefore leaks into every config module imported
# afterwards and never appears on va_franka_cfg itself. Likely intended:
# ``va_franka_cfg.infer_mode = 'server'`` — confirm before changing.
va_shared_cfg.infer_mode = 'server'

# Placeholder: point this at the downloaded Wan2.2 checkpoint.
va_franka_cfg.wan22_pretrained_model_name_or_path = "/path/to/pretrained/model"

va_franka_cfg.attn_window = 30
va_franka_cfg.frame_chunk_size = 4
va_franka_cfg.env_type = 'none'

# Rendered frame size and action-head layout.
va_franka_cfg.height = 224
va_franka_cfg.width = 320
va_franka_cfg.action_dim = 30
va_franka_cfg.action_per_frame = 20
va_franka_cfg.obs_cam_keys = [
    'observation.images.cam_high', 'observation.images.cam_left_wrist',
    'observation.images.cam_right_wrist'
]
va_franka_cfg.guidance_scale = 5
va_franka_cfg.action_guidance_scale = 1

# Diffusion sampling steps for video and action heads.
va_franka_cfg.num_inference_steps = 5
va_franka_cfg.video_exec_step = -1
va_franka_cfg.action_num_inference_steps = 10

va_franka_cfg.snr_shift = 5.0
va_franka_cfg.action_snr_shift = 1.0

# Permutation selecting/reordering action channels for the model;
# inverse_used_action_channel_ids is the inverse mapping, with unused
# channels left at the out-of-range sentinel index len(used_ids).
va_franka_cfg.used_action_channel_ids = list(range(0, 7)) + list(range(
    28, 29)) + list(range(7, 14)) + list(range(29, 30))
inverse_used_action_channel_ids = [len(va_franka_cfg.used_action_channel_ids)
                                   ] * va_franka_cfg.action_dim
for i, j in enumerate(va_franka_cfg.used_action_channel_ids):
    inverse_used_action_channel_ids[j] = i
va_franka_cfg.inverse_used_action_channel_ids = inverse_used_action_channel_ids

# Per-channel 1%/99% quantiles used for action normalization.
va_franka_cfg.action_norm_method = 'quantiles'
va_franka_cfg.norm_stat = {
    "q01": [
        0.3051295876502991, -0.22647984325885773, 0.19957000017166138,
        -0.022680532187223434, -0.05553057789802551, -0.2693849802017212,
        -0.29341773986816405, 0.2935442328453064, -0.4431332051753998,
        0.21256473660469055, -0.7962440848350525, -0.40816226601600647,
        -0.28359392285346985, -0.44507765769958496
    ] + [0.] * 16,
    "q99": [
        0.7572150230407715, 0.47736290097236633, 0.6428080797195435,
        0.9835678935050964, 0.9927203059196472, 0.28041139245033264,
        0.47529348731040877, 0.7564866304397571, 0.04082797020673729,
        0.5355993628501885, 0.9976375699043274, 0.8973174452781656,
        0.6016915678977965, 0.5027598619461056
    ] + [0.] * 14 + [1.0, 1.0],
}
wan_va/configs/va_franka_i2va.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
2
+ from easydict import EasyDict
3
+ from .va_franka_cfg import va_franka_cfg
4
+
5
# Image-to-video-action (i2va) variant of the franka config: runs offline
# from example images instead of serving a live environment.
va_franka_i2va_cfg = EasyDict(__name__='Config: VA franka i2va')
va_franka_i2va_cfg.update(va_franka_cfg)

# Directory holding the three example camera frames.
va_franka_i2va_cfg.input_img_path = 'example/franka'
va_franka_i2va_cfg.num_chunks_to_infer = 10
va_franka_i2va_cfg.prompt = 'pick bunk'
va_franka_i2va_cfg.infer_mode = 'i2va'
wan_va/configs/va_robotwin_cfg.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
2
+ from easydict import EasyDict
3
+
4
+ from .shared_config import va_shared_cfg
5
+
6
# RoboTwin VA config: copies the shared defaults, then overrides below.
va_robotwin_cfg = EasyDict(__name__='Config: VA robotwin')
va_robotwin_cfg.update(va_shared_cfg)

# NOTE(review): user-specific absolute path baked into the repo — replace
# with your local snapshot of robbyant/lingbot-va-posttrain-robotwin.
va_robotwin_cfg.wan22_pretrained_model_name_or_path = "/group/ossdphi_algo_scratch_11/weicxu/huggingface_cache/hub/models--robbyant--lingbot-va-posttrain-robotwin/snapshots/ef7242af28caff0af2ff8e947c78806f94719a39"

va_robotwin_cfg.attn_window = 72
va_robotwin_cfg.frame_chunk_size = 2
va_robotwin_cfg.env_type = 'robotwin_tshape'

# Rendered frame size and action-head layout.
va_robotwin_cfg.height = 256
va_robotwin_cfg.width = 320
va_robotwin_cfg.action_dim = 30
va_robotwin_cfg.action_per_frame = 16
va_robotwin_cfg.obs_cam_keys = [
    'observation.images.cam_high', 'observation.images.cam_left_wrist',
    'observation.images.cam_right_wrist'
]
va_robotwin_cfg.guidance_scale = 5
va_robotwin_cfg.action_guidance_scale = 1

# Diffusion sampling steps for video and action heads.
va_robotwin_cfg.num_inference_steps = 25
va_robotwin_cfg.video_exec_step = -1
va_robotwin_cfg.action_num_inference_steps = 50

va_robotwin_cfg.snr_shift = 5.0
va_robotwin_cfg.action_snr_shift = 1.0

# Permutation selecting/reordering action channels for the model;
# inverse_used_action_channel_ids is the inverse mapping, with unused
# channels left at the out-of-range sentinel index len(used_ids).
va_robotwin_cfg.used_action_channel_ids = list(range(0, 7)) + list(
    range(28, 29)) + list(range(7, 14)) + list(range(29, 30))
inverse_used_action_channel_ids = [
    len(va_robotwin_cfg.used_action_channel_ids)
] * va_robotwin_cfg.action_dim
for i, j in enumerate(va_robotwin_cfg.used_action_channel_ids):
    inverse_used_action_channel_ids[j] = i
va_robotwin_cfg.inverse_used_action_channel_ids = inverse_used_action_channel_ids

# Per-channel 1%/99% quantiles used for action normalization.
va_robotwin_cfg.action_norm_method = 'quantiles'
va_robotwin_cfg.norm_stat = {
    "q01": [
        -0.06172713458538055, -3.6716461181640625e-05, -0.08783501386642456,
        -1, -1, -1, -1, -0.3547105032205582, -1.3113021850585938e-06,
        -0.11975435614585876, -1, -1, -1, -1
    ] + [0.] * 16,
    "q99": [
        0.3462600058317184, 0.39966784834861746, 0.14745532035827624, 1, 1, 1,
        1, 0.034201726913452024, 0.39142737388610793, 0.1792279863357542, 1, 1,
        1, 1
    ] + [0.] * 14 + [1.0, 1.0],
}
wan_va/configs/va_robotwin_i2va.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
2
+ from easydict import EasyDict
3
+ from .va_robotwin_cfg import va_robotwin_cfg
4
+
5
# Image-to-video-action (i2va) variant of the robotwin config: runs offline
# from example images instead of serving a live environment.
va_robotwin_i2va_cfg = EasyDict(__name__='Config: VA robotwin i2va')
va_robotwin_i2va_cfg.update(va_robotwin_cfg)

# Directory holding the three example camera frames.
va_robotwin_i2va_cfg.input_img_path = 'example/robotwin'
va_robotwin_i2va_cfg.num_chunks_to_infer = 10
va_robotwin_i2va_cfg.prompt = 'Grab the medium-sized white mug, rotate it, place it on the table, and hook it onto the smooth dark gray rack.'
va_robotwin_i2va_cfg.infer_mode = 'i2va'
wan_va/distributed/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
wan_va/distributed/fsdp.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import gc
3
+ from functools import partial
4
+
5
+ import torch
6
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
7
+ from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
8
+ from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
9
+ from torch.distributed.utils import _free_storage
10
+
11
+
12
def shard_model(model,
                device_id,
                param_dtype=torch.bfloat16,
                reduce_dtype=torch.float32,
                buffer_dtype=torch.float32,
                process_group=None,
                sharding_strategy=ShardingStrategy.FULL_SHARD,
                sync_module_states=True,
                use_lora=False):
    """Wrap ``model`` in FSDP, sharding each transformer block individually.

    Args:
        model: module exposing a ``blocks`` container of transformer layers.
        device_id: CUDA device this rank shards onto.
        param_dtype / reduce_dtype / buffer_dtype: FSDP mixed-precision policy.
        process_group: optional process group (defaults to the world group).
        sharding_strategy: FSDP sharding strategy.
        sync_module_states: broadcast rank-0 weights at wrap time.
        use_lora: enable ``use_orig_params`` so adapter parameters keep their
            original views under FSDP.

    Returns:
        The FSDP-wrapped model.
    """
    model = FSDP(module=model,
                 process_group=process_group,
                 sharding_strategy=sharding_strategy,
                 # One FSDP unit per transformer block, nothing else.
                 auto_wrap_policy=partial(
                     lambda_auto_wrap_policy,
                     lambda_fn=lambda m: m in model.blocks),
                 mixed_precision=MixedPrecision(param_dtype=param_dtype,
                                                reduce_dtype=reduce_dtype,
                                                buffer_dtype=buffer_dtype),
                 device_id=device_id,
                 sync_module_states=sync_module_states,
                 # Idiom fix: was ``True if use_lora else False``.
                 use_orig_params=bool(use_lora))
    return model
34
+
35
+
36
def free_model(model):
    """Release the flat-parameter storage of every FSDP submodule of ``model``.

    Relies on the private ``FSDP._handle`` and ``_free_storage`` internals,
    so it is coupled to a specific torch version; revisit on upgrades.
    """
    for m in model.modules():
        if isinstance(m, FSDP):
            _free_storage(m._handle.flat_param.data)
    del model
    gc.collect()
    torch.cuda.empty_cache()
wan_va/distributed/util.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+ import torch.distributed as dist
4
+
5
+
6
+ def _configure_model(model, shard_fn, param_dtype, device):
7
+ """
8
+ TODO
9
+ """
10
+ model.eval().requires_grad_(False)
11
+ if dist.is_initialized():
12
+ dist.barrier()
13
+
14
+ if dist.is_initialized():
15
+ model = shard_fn(model)
16
+ else:
17
+ model.to(param_dtype)
18
+ model.to(device)
19
+
20
+ return model
21
+
22
+
23
def init_distributed(world_size, local_rank, rank):
    """Bind this process to its local GPU and join the NCCL process group.

    Single-process runs (``world_size`` of 1 or less) only select the
    device and skip process-group creation.
    """
    torch.cuda.set_device(local_rank)
    if world_size <= 1:
        return
    dist.init_process_group(backend="nccl",
                            init_method="env://",
                            rank=rank,
                            world_size=world_size)
wan_va/modules/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
2
# ``WanVAEStreamingWrapper`` was listed in ``__all__`` without being imported,
# which made ``from wan_va.modules import *`` raise AttributeError. Import it
# alongside the loaders so the declared public API is actually resolvable.
from .utils import (
    WanVAEStreamingWrapper,
    load_text_encoder,
    load_tokenizer,
    load_transformer,
    load_vae,
)

__all__ = [
    'load_transformer', 'load_text_encoder', 'load_tokenizer', 'load_vae',
    'WanVAEStreamingWrapper'
]
wan_va/modules/model.py ADDED
@@ -0,0 +1,580 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
2
+ import math
3
+ from copy import deepcopy
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.models.attention import FeedForward
10
+ from diffusers.models.embeddings import (
11
+ PixArtAlphaTextProjection,
12
+ TimestepEmbedding,
13
+ Timesteps,
14
+ )
15
+ from diffusers.models.modeling_utils import ModelMixin
16
+ from diffusers.models.normalization import FP32LayerNorm
17
+ from einops import rearrange
18
+
19
+ try:
20
+ from flash_attn_interface import flash_attn_func
21
+ except:
22
+ from flash_attn import flash_attn_func
23
+
24
+ __all__ = ['WanTransformer3DModel']
25
+
26
+
27
def custom_sdpa(q, k, v):
    """Scaled dot-product attention for (B, L, H, D) layout tensors.

    ``F.scaled_dot_product_attention`` expects (B, H, L, D), so the head
    and sequence axes are swapped going in and swapped back coming out.
    """
    q_bhld, k_bhld, v_bhld = (t.transpose(1, 2) for t in (q, k, v))
    attn = F.scaled_dot_product_attention(q_bhld, k_bhld, v_bhld)
    return attn.transpose(1, 2)
31
+
32
+
33
class WanTimeTextImageEmbedding(nn.Module):
    """Per-token timestep embedding plus text-feature projection.

    Despite the name (inherited from the upstream Wan codebase), only the
    timestep and text branches are present here; there is no image branch.
    ``text_embedder`` is invoked directly by the transformer's ``forward``,
    not by this module's own ``forward``.
    """

    def __init__(
        self,
        dim,
        time_freq_dim,
        time_proj_dim,
        text_embed_dim,
        pos_embed_seq_len,
    ):
        # NOTE(review): ``pos_embed_seq_len`` is accepted but never used in
        # this implementation — presumably kept for upstream interface
        # compatibility; confirm before removing.
        super().__init__()

        self.timesteps_proj = Timesteps(num_channels=time_freq_dim,
                                        flip_sin_to_cos=True,
                                        downscale_freq_shift=0)
        self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim,
                                               time_embed_dim=dim)
        self.act_fn = nn.SiLU()
        self.time_proj = nn.Linear(dim, time_proj_dim)
        self.text_embedder = PixArtAlphaTextProjection(text_embed_dim,
                                                       dim,
                                                       act_fn="gelu_tanh")

    def forward(
        self,
        timestep: torch.Tensor,
        dtype=None,
    ):
        """Embed per-token timesteps.

        Args:
            timestep: (B, L) tensor of diffusion timesteps, one per token;
                flattened so each token gets its own sinusoidal embedding.
            dtype: Optional dtype to cast the time embedding to.

        Returns:
            Tuple ``(temb, timestep_proj)``, shaped (B, L, dim) and
            (B, L, time_proj_dim) respectively.
        """
        B, L = timestep.shape
        timestep = timestep.reshape(-1)
        timestep = self.timesteps_proj(timestep)
        # time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
        # Peek at the first linear's weight dtype instead of iterating all
        # parameters.
        time_embedder_dtype = self.time_embedder.linear_1.weight.dtype
        # Match the embedder's dtype, except for int8 (quantized) weights.
        if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
            timestep = timestep.to(time_embedder_dtype)
        temb = self.time_embedder(timestep).to(dtype=dtype)
        timestep_proj = self.time_proj(self.act_fn(temb))
        return temb.reshape(B, L, -1), timestep_proj.reshape(B, L, -1)
71
+
72
+
73
class WanRotaryPosEmbed(nn.Module):
    """3D rotary position embedding over (frame, height, width) axes.

    The head dimension is split into three sub-bands — one per axis — and a
    standard RoPE inverse-frequency table is precomputed for each band. The
    forward pass turns integer grid coordinates into complex rotation
    factors (unit-magnitude phasors) to be multiplied into Q/K.
    """

    def __init__(
        self,
        attention_head_dim,
        patch_size,
        max_seq_len,
        theta=10000.0,
    ):
        super().__init__()

        self.attention_head_dim = attention_head_dim
        self.patch_size = patch_size
        self.max_seq_len = max_seq_len
        self.theta = theta

        # Height and width each take dim // 3; the frame axis absorbs the
        # remainder so the three bands always sum to attention_head_dim.
        axis_dim = attention_head_dim // 3
        self.f_dim = attention_head_dim - 2 * axis_dim
        self.h_dim = axis_dim
        self.w_dim = axis_dim

        # Precompute once and keep out of the state dict.
        f_freqs_base, h_freqs_base, w_freqs_base = self._precompute_freqs_base()
        self.register_buffer("f_freqs_base", f_freqs_base, persistent=False)
        self.register_buffer("h_freqs_base", h_freqs_base, persistent=False)
        self.register_buffer("w_freqs_base", w_freqs_base, persistent=False)

    def _precompute_freqs_base(self):
        """Return inverse frequencies 1 / theta**(2k / dim) for each axis."""

        def inv_freq(dim):
            exponents = torch.arange(0, dim, 2)[:dim // 2].double() / dim
            return 1.0 / (self.theta ** exponents)

        return inv_freq(self.f_dim), inv_freq(self.h_dim), inv_freq(self.w_dim)

    def forward(self, grid_ids):
        """Map (B, 3, L) grid coordinates to complex rotary factors.

        Row 0/1/2 of ``grid_ids`` holds frame/height/width coordinates.
        Returns a complex tensor of shape (B, L, head_dim // 2).
        """
        with torch.no_grad():
            per_axis = (
                grid_ids[:, 0, :].unsqueeze(-1) * self.f_freqs_base,
                grid_ids[:, 1, :].unsqueeze(-1) * self.h_freqs_base,
                grid_ids[:, 2, :].unsqueeze(-1) * self.w_freqs_base,
            )
            angles = torch.cat(per_axis, dim=-1).float()
            # Unit-magnitude phasors e^{i * angle}.
            freqs_cis = torch.polar(torch.ones_like(angles), angles)

        return freqs_cis
120
+
121
+
122
class WanAttention(torch.nn.Module):
    """Multi-head attention with an optional slot-based KV cache.

    When ``cross_attention_dim_head`` is None the module is self-attention
    and owns ``attn_caches``: a dict of named caches, each a fixed pool of
    key/value slots with per-slot insertion-order ids, an occupancy mask,
    and an "is_pred" flag. Oldest slots (lowest id) are evicted first when
    the pool is full. Cross-attention instances have ``attn_caches = None``
    and all cache methods become no-ops.
    """

    def __init__(
        self,
        dim,
        heads=8,
        dim_head=64,
        eps=1e-5,
        dropout=0.0,
        cross_attention_dim_head=None,
        attn_mode='torch',
    ):
        super().__init__()
        # Select the attention kernel: PyTorch SDPA or FlashAttention.
        if attn_mode == 'torch':
            self.attn_op = custom_sdpa
        elif attn_mode == 'flashattn':
            self.attn_op = flash_attn_func
        else:
            raise ValueError(
                f"Unsupported attention mode: {attn_mode}, only support torch and flashattn"
            )

        self.inner_dim = dim_head * heads
        self.heads = heads
        self.cross_attention_dim_head = cross_attention_dim_head
        self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads

        self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
        self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
        self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
        self.to_out = torch.nn.ModuleList([
            torch.nn.Linear(self.inner_dim, dim, bias=True),
            torch.nn.Dropout(dropout),
        ])
        self.norm_q = torch.nn.RMSNorm(dim_head * heads,
                                       eps=eps,
                                       elementwise_affine=True)
        self.norm_k = torch.nn.RMSNorm(dim_head * heads,
                                       eps=eps,
                                       elementwise_affine=True)
        # Self-attention keeps named KV caches; cross-attention keeps none.
        self.attn_caches = {} if cross_attention_dim_head is None else None

    def clear_pred_cache(self, cache_name):
        """Invalidate only the slots that were written as predictions."""
        if self.attn_caches is None:
            return
        cache = self.attn_caches[cache_name]
        is_pred = cache['is_pred']
        cache['mask'][is_pred] = False

    def clear_cache(self, cache_name):
        """Drop the whole named cache (it must be re-initialized before use)."""
        if self.attn_caches is None:
            return
        self.attn_caches[cache_name] = None

    def init_kv_cache(self, cache_name, total_tolen, num_head, head_dim,
                      device, dtype, batch_size):
        """Allocate an empty slot pool of ``total_tolen`` KV positions.

        NOTE(review): ``total_tolen`` looks like a typo for ``total_token``;
        kept as-is since callers pass it positionally.
        """
        if self.attn_caches is None:
            return
        self.attn_caches[cache_name] = {
            # Key/value pools, one slot per token position.
            'k':
            torch.empty([batch_size, total_tolen, num_head, head_dim],
                        device=device,
                        dtype=dtype),
            'v':
            torch.empty([batch_size, total_tolen, num_head, head_dim],
                        device=device,
                        dtype=dtype),
            # Monotonic insertion id per slot; -1 means unused.
            'id':
            torch.full((total_tolen, ), -1, device=device),
            # Occupancy mask: True where the slot holds a live entry.
            "mask":
            torch.zeros((total_tolen, ), dtype=torch.bool, device=device),
            # Marks slots written during prediction (see clear_pred_cache).
            "is_pred":
            torch.zeros((total_tolen, ), dtype=torch.bool, device=device),
        }

    def allocate_slots(self, cache_name, key_size):
        """Return ``key_size`` free slot indices, evicting oldest entries if needed."""
        cache = self.attn_caches[cache_name]
        mask = cache["mask"]
        ids = cache["id"]
        free = (~mask).nonzero(as_tuple=False).squeeze(-1)

        if free.numel() < key_size:
            used = mask.nonzero(as_tuple=False).squeeze(-1)

            # Evict the entries with the smallest (oldest) insertion ids.
            used_ids = ids[used]
            order = torch.argsort(used_ids)
            need = key_size - free.numel()
            to_free = used[order[:need]]

            mask[to_free] = False
            ids[to_free] = -1
            free = (~mask).nonzero(as_tuple=False).squeeze(-1)

        assert free.numel() >= key_size
        return free[:key_size]

    def _next_cache_id(self, cache_name):
        """Next insertion id: one past the largest live id, or 0 when empty."""
        ids = self.attn_caches[cache_name]['id']
        mask = self.attn_caches[cache_name]['mask']

        if mask.any():
            return ids[mask].max() + 1
        else:
            return torch.tensor(0, device=ids.device, dtype=ids.dtype)

    def update_cache(self, cache_name, key, value, is_pred):
        """Write ``key``/``value`` (B, L, H, D) into freshly allocated slots.

        All slots of this write share a single insertion id, so they age
        (and are evicted) together. Returns the slot indices used.
        """
        cache = self.attn_caches[cache_name]

        key_size = key.shape[1]
        slots = self.allocate_slots(cache_name, key_size)

        new_id = self._next_cache_id(cache_name)

        cache['k'][:, slots] = key
        cache['v'][:, slots] = value
        cache['mask'][slots] = True
        cache['id'][slots] = new_id
        cache['is_pred'][slots] = is_pred
        return slots

    def restore_cache(self, cache_name, slots):
        """Undo a speculative write by marking its slots free again."""
        self.attn_caches[cache_name]['mask'][slots] = False

    def forward(
        self,
        q,
        k,
        v,
        rotary_emb,
        update_cache=0,
        cache_name='pos',
    ):
        """Attend ``q`` over ``k``/``v`` (B, L, C), optionally through the KV cache.

        ``update_cache``: 0 = transient (the write is rolled back after
        attending), 1 = persist as prediction, other values persist as
        non-prediction. ``rotary_emb`` is applied to Q and K when given.
        """
        # None when this is cross-attention or the cache was cleared.
        kv_cache = self.attn_caches[
            cache_name] if self.attn_caches is not None else None

        query, key, value = self.to_q(q), self.to_k(k), self.to_v(v)
        query = self.norm_q(query)
        query = query.unflatten(2, (self.heads, -1))
        key = self.norm_k(key)
        key = key.unflatten(2, (self.heads, -1))
        value = value.unflatten(2, (self.heads, -1))
        if rotary_emb is not None:

            def apply_rotary_emb(x, freqs):
                # Pair up the last dim as (real, imag), rotate by the complex
                # phasors in float64, then flatten back to the input dtype.
                x_out = torch.view_as_complex(
                    x.to(torch.float64).reshape(x.shape[0], x.shape[1],
                                                x.shape[2], -1, 2))
                x_out = torch.view_as_real(x_out * freqs).flatten(3)
                return x_out.to(x.dtype)
            query = apply_rotary_emb(query, rotary_emb)
            key = apply_rotary_emb(key, rotary_emb)
        slots = None
        if kv_cache is not None and kv_cache['k'] is not None:
            # Insert the new K/V, then attend over every live slot in the pool.
            slots = self.update_cache(cache_name,
                                      key,
                                      value,
                                      is_pred=(update_cache == 1))
            key_pool = self.attn_caches[cache_name]['k']
            value_pool = self.attn_caches[cache_name]['v']
            mask = self.attn_caches[cache_name]['mask']
            valid = mask.nonzero(as_tuple=False).squeeze(-1)
            key = key_pool[:, valid]
            value = value_pool[:, valid]

        hidden_states = self.attn_op(query, key, value)

        if update_cache == 0:
            # Transient pass: roll back this call's cache write.
            if kv_cache is not None and kv_cache['k'] is not None:
                self.restore_cache(cache_name, slots)

        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.type_as(query)
        hidden_states = self.to_out[0](hidden_states)
        hidden_states = self.to_out[1](hidden_states)
        return hidden_states
297
+
298
+
299
class WanTransformerBlock(nn.Module):
    """One DiT-style block: self-attn → cross-attn → FFN, with AdaLN modulation.

    The six shift/scale/gate modulation vectors come from a learned table
    plus the per-token conditioning ``temb``; self-attention and the FFN are
    modulated, while cross-attention only gets an (optional) LayerNorm.
    """

    def __init__(
        self,
        dim,
        ffn_dim,
        num_heads,
        cross_attn_norm=False,
        eps=1e-6,
        attn_mode: str = "flashattn",
    ):
        super().__init__()
        self.attn_mode = attn_mode

        # 1. Self-attention (norm has no affine params: AdaLN supplies them).
        self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
        self.attn1 = WanAttention(
            dim=dim,
            heads=num_heads,
            dim_head=dim // num_heads,
            eps=eps,
            cross_attention_dim_head=None,
            attn_mode=attn_mode,
        )

        # 2. Cross-attention over the text conditioning.
        self.attn2 = WanAttention(
            dim=dim,
            heads=num_heads,
            dim_head=dim // num_heads,
            eps=eps,
            cross_attention_dim_head=dim // num_heads,
            attn_mode=attn_mode,
        )
        self.norm2 = FP32LayerNorm(
            dim, eps,
            elementwise_affine=True) if cross_attn_norm else nn.Identity()

        # 3. Feed-forward
        self.ffn = FeedForward(dim,
                               inner_dim=ffn_dim,
                               activation_fn="gelu-approximate")
        self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)

        # Learned base for the 6 AdaLN modulation vectors.
        self.scale_shift_table = nn.Parameter(
            torch.randn(1, 6, dim) / dim**0.5)

    def forward(
        self,
        hidden_states,
        encoder_hidden_states,
        temb,
        rotary_emb,
        update_cache=0,
        cache_name='pos',
    ) -> torch.Tensor:
        """Run the block.

        Args:
            hidden_states: (B, L, C) token features.
            encoder_hidden_states: text conditioning for cross-attention.
            temb: per-token modulation input, broadcast-added to the table.
            rotary_emb: rotary factors for self-attention (None disables).
            update_cache / cache_name: forwarded to the self-attention KV cache.
        """
        # Split the modulated table into the 6 per-token AdaLN vectors.
        temb_scale_shift_table = self.scale_shift_table[None] + temb.float()
        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = \
            rearrange(temb_scale_shift_table, 'b l n c -> b n l c').chunk(6, dim=1)
        shift_msa = shift_msa.squeeze(1)
        scale_msa = scale_msa.squeeze(1)
        gate_msa = gate_msa.squeeze(1)
        c_shift_msa = c_shift_msa.squeeze(1)
        c_scale_msa = c_scale_msa.squeeze(1)
        c_gate_msa = c_gate_msa.squeeze(1)

        # 1. Self-attention (modulated, gated residual; math in float32).
        norm_hidden_states = (self.norm1(hidden_states.float()) *
                              (1. + scale_msa) +
                              shift_msa).type_as(hidden_states)
        attn_output = self.attn1(norm_hidden_states,
                                 norm_hidden_states,
                                 norm_hidden_states,
                                 rotary_emb,
                                 update_cache=update_cache,
                                 cache_name=cache_name)
        hidden_states = (hidden_states.float() +
                         attn_output * gate_msa).type_as(hidden_states)

        # 2. Cross-attention (plain residual; never cached, no rotary).
        norm_hidden_states = self.norm2(
            hidden_states.float()).type_as(hidden_states)
        attn_output = self.attn2(norm_hidden_states,
                                 encoder_hidden_states,
                                 encoder_hidden_states,
                                 None,
                                 update_cache=0,
                                 cache_name=cache_name)
        hidden_states = hidden_states + attn_output

        # 3. Feed-forward (modulated, gated residual).
        norm_hidden_states = (self.norm3(hidden_states.float()) *
                              (1. + c_scale_msa) +
                              c_shift_msa).type_as(hidden_states)

        ff_output = self.ffn(norm_hidden_states)

        hidden_states = (hidden_states.float() +
                         ff_output.float() * c_gate_msa).type_as(hidden_states)
        return hidden_states
399
+
400
+
401
class WanTransformer3DModel(ModelMixin, ConfigMixin):
    r"""Wan video-action diffusion transformer.

    A stack of :class:`WanTransformerBlock` layers shared between two input
    modes: video latents (patchified via an MLP) and action vectors
    (embedded via a linear layer). Each mode has its own timestep/text
    conditioning embedder and its own output projection; which path runs is
    selected by ``action_mode`` at forward time.
    """

    @register_to_config
    def __init__(self,
                 patch_size=[1, 2, 2],
                 num_attention_heads=24,
                 attention_head_dim=128,
                 in_channels=48,
                 out_channels=48,
                 action_dim=30,
                 text_dim=4096,
                 freq_dim=256,
                 ffn_dim=14336,
                 num_layers=30,
                 cross_attn_norm=True,
                 eps=1e-06,
                 rope_max_seq_len=1024,
                 pos_embed_seq_len=None,
                 attn_mode="torch"):
        r"""Build the transformer.

        Args:
            patch_size: (f, h, w) patchification factors for video latents.
            num_attention_heads / attention_head_dim: attention geometry;
                the model width is their product.
            in_channels / out_channels: latent channels in and out.
            action_dim: dimensionality of one action vector.
            text_dim: width of incoming text embeddings (e.g. T5 features).
            freq_dim: sinusoidal timestep embedding width.
            ffn_dim: feed-forward hidden width.
            num_layers: number of transformer blocks.
            cross_attn_norm: apply an affine LayerNorm before cross-attention.
            rope_max_seq_len: maximum sequence length for rotary embeddings.
            pos_embed_seq_len: passed through to the condition embedder
                (unused there; see WanTimeTextImageEmbedding).
            attn_mode: 'torch' (SDPA) or 'flashattn'.
        """
        super().__init__()
        self.patch_size = patch_size
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim
        self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size,
                                      rope_max_seq_len)
        # Video latents: one patch worth of channels -> model width.
        self.patch_embedding_mlp = nn.Linear(
            in_channels * patch_size[0] * patch_size[1] * patch_size[2],
            inner_dim)
        # Action vectors: action_dim -> model width.
        self.action_embedder = nn.Linear(action_dim, inner_dim)
        self.condition_embedder = WanTimeTextImageEmbedding(
            dim=inner_dim,
            time_freq_dim=freq_dim,
            time_proj_dim=inner_dim * 6,
            text_embed_dim=text_dim,
            pos_embed_seq_len=pos_embed_seq_len,
        )
        # Separate (identically initialized) conditioning path for actions.
        self.condition_embedder_action = deepcopy(self.condition_embedder)

        self.blocks = nn.ModuleList([
            WanTransformerBlock(inner_dim,
                                ffn_dim,
                                num_attention_heads,
                                cross_attn_norm,
                                eps,
                                attn_mode=attn_mode) for _ in range(num_layers)
        ])

        self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
        self.proj_out = nn.Linear(inner_dim,
                                  out_channels * math.prod(patch_size))
        self.action_proj_out = nn.Linear(inner_dim, action_dim)
        # Learned base for the final 2-vector (shift, scale) modulation.
        self.scale_shift_table = nn.Parameter(
            torch.randn(1, 2, inner_dim) / inner_dim**0.5)

    def clear_cache(self, cache_name):
        """Drop the named KV cache in every block's self-attention."""
        for block in self.blocks:
            block.attn1.clear_cache(cache_name)

    def clear_pred_cache(self, cache_name):
        """Invalidate only prediction-written KV entries in every block."""
        for block in self.blocks:
            block.attn1.clear_pred_cache(cache_name)

    def create_empty_cache(self, cache_name, attn_window,
                           latent_token_per_chunk, action_token_per_chunk,
                           device, dtype, batch_size):
        """Allocate per-block KV caches sized for an ``attn_window`` of chunks.

        Half the window budget goes to latent tokens and half to action
        tokens. NOTE(review): ``total_tolen`` looks like a typo for
        ``total_token``; kept to match WanAttention.init_kv_cache.
        """
        total_tolen = (attn_window // 2) * latent_token_per_chunk + (
            attn_window // 2) * action_token_per_chunk
        for block in self.blocks:
            block.attn1.init_kv_cache(cache_name, total_tolen,
                                      self.num_attention_heads,
                                      self.attention_head_dim, device, dtype, batch_size)

    def forward(
        self,
        input_dict,
        update_cache=0,
        cache_name="pos",
        action_mode=False,
    ):
        r"""Denoise one chunk of video latents or actions.

        Args:
            input_dict: dict with
                ``noisy_latents`` — (B, C, F, H, W) noisy video latents, or
                    action input when ``action_mode`` (C then plays the role
                    of ``action_dim``);
                ``timesteps`` — per-frame diffusion timesteps, expanded here
                    to one value per token;
                ``text_emb`` — text features of width ``text_dim``;
                ``grid_id`` — (B, 3, L) rotary grid coordinates.
            update_cache: forwarded to self-attention KV caching
                (0 = transient, 1 = persist as prediction).
            cache_name: which named KV cache to use.
            action_mode: select the action path (action embedder/conditioner
                and ``action_proj_out``) instead of the video-latent path.

        Returns:
            Tensor: (B, L * prod(patch_size), C_out) unpatchified latent
            prediction, or (B, L, action_dim) in action mode.
        """
        if action_mode:  # action input emb
            latent_hidden_states = rearrange(input_dict['noisy_latents'],
                                             'b c f h w -> b (f h w) c')
            latent_hidden_states = self.action_embedder(
                latent_hidden_states)  # B L1 C
        else:  # latent input emb
            latent_hidden_states = rearrange(
                input_dict['noisy_latents'],
                'b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)',
                p1=self.patch_size[0],
                p2=self.patch_size[1],
                p3=self.patch_size[2])
            latent_hidden_states = self.patch_embedding_mlp(
                latent_hidden_states)
        text_hidden_states = self.condition_embedder.text_embedder(
            input_dict["text_emb"])  # B L2 C

        latent_grid_id = input_dict['grid_id']
        rotary_emb = self.rope(latent_grid_id)[:, :, None]  # 1 L 1 C
        # Actions are not spatially patchified, so no h/w downscaling there.
        pach_scale_h, pach_scale_w = (1, 1) if action_mode else (
            self.patch_size[1], self.patch_size[2])

        # Expand per-frame timesteps to one timestep per token.
        latent_time_steps = torch.repeat_interleave(
            input_dict['timesteps'],
            (input_dict['noisy_latents'].shape[-2] // pach_scale_h) *
            (input_dict['noisy_latents'].shape[-1] // pach_scale_w), dim=1)  # L
        current_condition_embedder = self.condition_embedder_action if action_mode else self.condition_embedder
        temb, timestep_proj = current_condition_embedder(
            latent_time_steps, dtype=latent_hidden_states.dtype)
        timestep_proj = timestep_proj.unflatten(2, (6, -1))  # B L 6 C

        for block in self.blocks:
            latent_hidden_states = block(latent_hidden_states,
                                         text_hidden_states,
                                         timestep_proj,
                                         rotary_emb,
                                         update_cache=update_cache,
                                         cache_name=cache_name)
        # Final AdaLN-style (shift, scale) modulation before projection.
        temb_scale_shift_table = self.scale_shift_table[None] + temb[:, :, None, ...]
        shift, scale = rearrange(temb_scale_shift_table,
                                 'b l n c -> b n l c').chunk(2, dim=1)
        shift = shift.to(latent_hidden_states.device).squeeze(1)
        scale = scale.to(latent_hidden_states.device).squeeze(1)
        latent_hidden_states = (self.norm_out(latent_hidden_states.float()) *
                                (1. + scale) +
                                shift).type_as(latent_hidden_states)

        if action_mode:
            latent_hidden_states = self.action_proj_out(latent_hidden_states)
        else:
            latent_hidden_states = self.proj_out(latent_hidden_states)
            # Unfold each token's prediction back into prod(patch_size) tokens.
            latent_hidden_states = rearrange(latent_hidden_states,
                                             'b l (n c) -> b (l n) c',
                                             n=math.prod(self.patch_size))  #

        return latent_hidden_states
562
+
563
+
564
if __name__ == '__main__':
    # Smoke test: build the transformer at its default scale and print the
    # module tree (no forward pass, no weights loaded).
    model = WanTransformer3DModel(patch_size=[1, 2, 2],
                                  num_attention_heads=24,
                                  attention_head_dim=128,
                                  in_channels=48,
                                  out_channels=48,
                                  action_dim=30,
                                  text_dim=4096,
                                  freq_dim=256,
                                  ffn_dim=14336,
                                  num_layers=30,
                                  cross_attn_norm=True,
                                  eps=1e-6,
                                  rope_max_seq_len=1024,
                                  pos_embed_seq_len=None,
                                  attn_mode="torch")
    print(model)
wan_va/modules/utils.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
2
+ import torch
3
+ from diffusers import AutoencoderKLWan
4
+ from transformers import (
5
+ T5TokenizerFast,
6
+ UMT5EncoderModel,
7
+ )
8
+
9
+ from .model import WanTransformer3DModel
10
+
11
+
12
def load_vae(
    vae_path,
    torch_dtype,
    torch_device,
):
    """Load the Wan VAE checkpoint and move it onto ``torch_device``."""
    model = AutoencoderKLWan.from_pretrained(vae_path, torch_dtype=torch_dtype)
    return model.to(torch_device)
22
+
23
+
24
def load_text_encoder(
    text_encoder_path,
    torch_dtype,
    torch_device,
):
    """Load the UMT5 text encoder and move it onto ``torch_device``."""
    model = UMT5EncoderModel.from_pretrained(text_encoder_path,
                                             torch_dtype=torch_dtype)
    return model.to(torch_device)
34
+
35
+
36
def load_tokenizer(tokenizer_path):
    """Load the T5 fast tokenizer used for text conditioning."""
    return T5TokenizerFast.from_pretrained(tokenizer_path)
39
+
40
+
41
def load_transformer(
    transformer_path,
    torch_dtype,
    torch_device,
):
    """Load the WanTransformer3DModel weights and move them onto ``torch_device``."""
    transformer = WanTransformer3DModel.from_pretrained(
        transformer_path, torch_dtype=torch_dtype)
    return transformer.to(torch_device)
51
+
52
+
53
def patchify(x, patch_size):
    """Fold ``patch_size`` x ``patch_size`` spatial patches into channels.

    Maps (B, C, F, H, W) to (B, C*p*p, F, H/p, W/p); returns the input
    unchanged when ``patch_size`` is None or 1.
    """
    if patch_size in (None, 1):
        return x
    b, c, f, h, w = x.shape
    p = patch_size
    patched = x.view(b, c, f, h // p, p, w // p, p)
    # Bring the two intra-patch axes next to channels:
    # (b, c, p_w, p_h, f, h/p, w/p), then merge (c, p_w, p_h).
    patched = patched.permute(0, 1, 6, 4, 2, 3, 5).contiguous()
    return patched.view(b, c * p * p, f, h // p, w // p)
63
+
64
+
65
class WanVAEStreamingWrapper:
    """Chunk-by-chunk encoder wrapper around a Wan VAE.

    Keeps one feature-cache entry per causal conv in the encoder so that
    successive calls to :meth:`encode_chunk` see the temporal context of
    earlier chunks. Call :meth:`clear_cache` between independent videos.
    """

    def __init__(self, vae_model):
        # vae_model: a diffusers AutoencoderKLWan (or compatible) instance.
        self.vae = vae_model
        self.encoder = vae_model.encoder
        self.quant_conv = vae_model.quant_conv

        # One cache slot per WanCausalConv3d; prefer the VAE's own count
        # when it exposes one, otherwise count the modules by class name.
        if hasattr(self.vae, "_cached_conv_counts"):
            self.enc_conv_num = self.vae._cached_conv_counts["encoder"]
        else:
            count = 0
            for m in self.encoder.modules():
                if m.__class__.__name__ == "WanCausalConv3d":
                    count += 1
            self.enc_conv_num = count

        self.clear_cache()

    def clear_cache(self):
        """Reset the per-conv temporal feature caches (start of a new stream)."""
        self.feat_cache = [None] * self.enc_conv_num

    def encode_chunk(self, x_chunk):
        """Encode one chunk, reusing and updating the streaming caches.

        ``x_chunk`` is spatially patchified first when the VAE config asks
        for it. Returns the quantized encoder output (pre-sampling moments).
        """
        if hasattr(self.vae.config,
                   "patch_size") and self.vae.config.patch_size is not None:
            x_chunk = patchify(x_chunk, self.vae.config.patch_size)
        # feat_idx is a mutable cursor the encoder advances per causal conv.
        feat_idx = [0]
        out = self.encoder(x_chunk,
                           feat_cache=self.feat_cache,
                           feat_idx=feat_idx)
        enc = self.quant_conv(out)
        return enc
wan_va/utils/Simple_Remote_Infer/LEGAL.md ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ Legal Disclaimer
2
+
3
+ Within this source code, the comments in Chinese shall be the original, governing version. Any comment in other languages are for reference only. In the event of any conflict between the Chinese language version comments and other language version comments, the Chinese language version shall prevail.
4
+
5
+ 法律免责声明
6
+
7
+ 关于代码注释部分,中文注释为官方版本,其它语言注释仅做参考。中文注释可能与其它语言注释存在不一致,当中文注释与其它语言注释存在不一致时,请以中文注释为准。
wan_va/utils/Simple_Remote_Infer/README.md ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 通用的server-client
2
+
3
+ ## /Simple_Remote_Infer/deploy/qwenpi_policy.py
4
+
5
+ - QwenPiServer: 一个用于示范的模型,拥有init和infer方法
6
+
7
+ 将加载好的模型用`WebsocketPolicyServer`包裹,并指定端口即可
8
+
9
+ ```python
10
+ model_server = WebsocketPolicyServer(model, port=8002)
11
+ model_server.serve_forever() # 开启监听
12
+ ```
13
+
14
+ ## ./websocket_client_policy.py
15
+
16
+ 在`__main__()`中展示了如何创造一个假模型向真模型发送环境信息,只需要用`WebsocketClientPolicy`代替原有的模型即可
wan_va/utils/Simple_Remote_Infer/deploy/__init__.py ADDED
File without changes
wan_va/utils/Simple_Remote_Infer/deploy/image_tools.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from PIL import Image
3
+
4
+
5
def convert_to_uint8(img: np.ndarray) -> np.ndarray:
    """Return *img* as uint8, scaling float images by 255.

    Non-float arrays are returned untouched. Shrinking floats to uint8
    keeps images small when sent over the network.
    """
    if not np.issubdtype(img.dtype, np.floating):
        return img
    return (255 * img).astype(np.uint8)
13
+
14
+
15
def resize_with_pad(images: np.ndarray,
                    height: int,
                    width: int,
                    method=Image.BILINEAR) -> np.ndarray:
    """Resize a batch of [..., H, W, C] images to (height, width), zero-padded.

    Aspect ratio is preserved per image and the shortfall on each side is
    padded with zeros, mirroring tf.image.resize_with_pad.

    Args:
        images: Batch of images in [..., height, width, channel] layout.
        height: Target height.
        width: Target width.
        method: PIL resampling filter (bilinear by default).

    Returns:
        The resized batch in [..., height, width, channel].
    """
    # Fast path: spatial size already matches the target.
    if images.shape[-3:-1] == (height, width):
        return images

    leading_shape = images.shape[:-3]
    flat = images.reshape(-1, *images.shape[-3:])
    resized = np.stack([
        _resize_with_pad_pil(Image.fromarray(im), height, width, method=method)
        for im in flat
    ])
    return resized.reshape(*leading_shape, *resized.shape[-3:])
42
+
43
+
44
def _resize_with_pad_pil(image: Image.Image, height: int, width: int,
                         method: int) -> Image.Image:
    """Resize one PIL image to (width, height) preserving aspect, pad with zeros.

    Note PIL sizes are (width, height), unlike the numpy [h, w, c] layout
    used by the batched caller.
    """
    cur_width, cur_height = image.size
    if (cur_width, cur_height) == (width, height):
        # Already the target size; skip the resample entirely.
        return image

    # Shrink by the dominant ratio so both dims fit within the target box.
    scale = max(cur_width / width, cur_height / height)
    new_height = int(cur_height / scale)
    new_width = int(cur_width / scale)
    shrunk = image.resize((new_width, new_height), resample=method)

    # Center the shrunk image on a zero-filled canvas of the target size.
    canvas = Image.new(shrunk.mode, (width, height), 0)
    top = max(0, int((height - new_height) / 2))
    left = max(0, int((width - new_width) / 2))
    canvas.paste(shrunk, (left, top))
    assert canvas.size == (width, height)
    return canvas
+ return zero_image