Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +3 -0
- .gitignore +2 -0
- .vscode/launch.json +58 -0
- INSTALL.md +55 -0
- LICENSE.txt +201 -0
- LingBot_VA_paper.pdf +3 -0
- Makefile +5 -0
- README.md +371 -0
- assets/teaser.mp4 +3 -0
- assets/teaser_v3.png +3 -0
- debug/place_fan/call1_reset.msgpack +3 -0
- debug/place_fan/call2.msgpack +3 -0
- debug/place_fan/call3.msgpack +3 -0
- evaluation/robotwin/calc_stat.py +132 -0
- evaluation/robotwin/eval_polict_client_openpi.py +696 -0
- evaluation/robotwin/geometry.py +463 -0
- evaluation/robotwin/launch_client.sh +40 -0
- evaluation/robotwin/launch_client_multigpus.sh +81 -0
- evaluation/robotwin/launch_server.sh +15 -0
- evaluation/robotwin/launch_server_multigpus.sh +31 -0
- evaluation/robotwin/msgpack_numpy.py +57 -0
- evaluation/robotwin/test_render.py +81 -0
- evaluation/robotwin/websocket_client_policy.py +108 -0
- example/franka/observation.images.cam_high.png +0 -0
- example/franka/observation.images.cam_left_wrist.png +0 -0
- example/franka/observation.images.cam_right_wrist.png +0 -0
- example/robotwin/observation.images.cam_high.png +0 -0
- example/robotwin/observation.images.cam_left_wrist.png +0 -0
- example/robotwin/observation.images.cam_right_wrist.png +0 -0
- lingbot_robotwin_policy.py +506 -0
- pyproject.toml +61 -0
- requirements.txt +11 -0
- script/run_launch_va_server_sync.sh +34 -0
- wan_va/__init__.py +2 -0
- wan_va/configs/__init__.py +12 -0
- wan_va/configs/shared_config.py +13 -0
- wan_va/configs/va_franka_cfg.py +59 -0
- wan_va/configs/va_franka_i2va.py +11 -0
- wan_va/configs/va_robotwin_cfg.py +54 -0
- wan_va/configs/va_robotwin_i2va.py +11 -0
- wan_va/distributed/__init__.py +1 -0
- wan_va/distributed/fsdp.py +42 -0
- wan_va/distributed/util.py +29 -0
- wan_va/modules/__init__.py +7 -0
- wan_va/modules/model.py +580 -0
- wan_va/modules/utils.py +95 -0
- wan_va/utils/Simple_Remote_Infer/LEGAL.md +7 -0
- wan_va/utils/Simple_Remote_Infer/README.md +16 -0
- wan_va/utils/Simple_Remote_Infer/deploy/__init__.py +0 -0
- wan_va/utils/Simple_Remote_Infer/deploy/image_tools.py +66 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
LingBot_VA_paper.pdf filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
assets/teaser.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
assets/teaser_v3.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.pyc
|
| 2 |
+
visualization/
|
.vscode/launch.json
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
// Use IntelliSense to learn about possible attributes.
|
| 3 |
+
// Hover to view descriptions of existing attributes.
|
| 4 |
+
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
| 5 |
+
"version": "0.2.0",
|
| 6 |
+
"configurations": [
|
| 7 |
+
|
| 8 |
+
{
|
| 9 |
+
"name": "Wan Server",
|
| 10 |
+
"type": "debugpy",
|
| 11 |
+
"request": "launch",
|
| 12 |
+
"program": "${file}",
|
| 13 |
+
"console": "integratedTerminal",
|
| 14 |
+
"justMyCode": false,
|
| 15 |
+
"args": [
|
| 16 |
+
"--config-name",
|
| 17 |
+
"robotwin",
|
| 18 |
+
"--port",
|
| 19 |
+
"29056",
|
| 20 |
+
"--save_root",
|
| 21 |
+
"visualization/",
|
| 22 |
+
"--debug_infer_once"
|
| 23 |
+
],
|
| 24 |
+
"env": {
|
| 25 |
+
"CUDA_VISIBLE_DEVICES": "6"
|
| 26 |
+
}
|
| 27 |
+
},
|
| 28 |
+
{
|
| 29 |
+
"name": "Robotwin Client",
|
| 30 |
+
"type": "debugpy",
|
| 31 |
+
"request": "launch",
|
| 32 |
+
"module": "evaluation.robotwin.eval_polict_client_openpi",
|
| 33 |
+
"console": "integratedTerminal",
|
| 34 |
+
"justMyCode": false,
|
| 35 |
+
"cwd": "${workspaceFolder}",
|
| 36 |
+
"args": [
|
| 37 |
+
"--config", "policy/ACT/deploy_policy.yml",
|
| 38 |
+
"--overrides",
|
| 39 |
+
"--task_name", "adjust_bottle",
|
| 40 |
+
"--task_config", "demo_clean",
|
| 41 |
+
"--train_config_name", "0",
|
| 42 |
+
"--model_name", "0",
|
| 43 |
+
"--ckpt_setting", "0",
|
| 44 |
+
"--seed", "0",
|
| 45 |
+
"--policy_name", "ACT",
|
| 46 |
+
"--save_root", "./results",
|
| 47 |
+
"--video_guidance_scale", "5",
|
| 48 |
+
"--action_guidance_scale", "1",
|
| 49 |
+
"--test_num", "100",
|
| 50 |
+
"--port", "29056"
|
| 51 |
+
],
|
| 52 |
+
"env": {
|
| 53 |
+
"LD_LIBRARY_PATH": "/usr/lib64:/usr/lib:${env:LD_LIBRARY_PATH}",
|
| 54 |
+
"XLA_PYTHON_CLIENT_MEM_FRACTION": "0.9"
|
| 55 |
+
}
|
| 56 |
+
}
|
| 57 |
+
]
|
| 58 |
+
}
|
INSTALL.md
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Installation Guide
|
| 2 |
+
|
| 3 |
+
## Install with pip
|
| 4 |
+
|
| 5 |
+
```bash
|
| 6 |
+
pip install .
|
| 7 |
+
pip install .[dev] # Installe aussi les outils de dev
|
| 8 |
+
```
|
| 9 |
+
|
| 10 |
+
## Install with Poetry
|
| 11 |
+
|
| 12 |
+
Ensure you have [Poetry](https://python-poetry.org/docs/#installation) installed on your system.
|
| 13 |
+
|
| 14 |
+
To install all dependencies:
|
| 15 |
+
|
| 16 |
+
```bash
|
| 17 |
+
poetry install
|
| 18 |
+
```
|
| 19 |
+
|
| 20 |
+
### Handling `flash-attn` Installation Issues
|
| 21 |
+
|
| 22 |
+
If `flash-attn` fails due to **PEP 517 build issues**, you can try one of the following fixes.
|
| 23 |
+
|
| 24 |
+
#### No-Build-Isolation Installation (Recommended)
|
| 25 |
+
```bash
|
| 26 |
+
poetry run pip install --upgrade pip setuptools wheel
|
| 27 |
+
poetry run pip install flash-attn --no-build-isolation
|
| 28 |
+
poetry install
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
#### Install from Git (Alternative)
|
| 32 |
+
```bash
|
| 33 |
+
poetry run pip install git+https://github.com/Dao-AILab/flash-attention.git
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
---
|
| 37 |
+
|
| 38 |
+
### Running the Model
|
| 39 |
+
|
| 40 |
+
Once the installation is complete, you can run **Wan2.2** using:
|
| 41 |
+
|
| 42 |
+
```bash
|
| 43 |
+
poetry run python generate.py --task t2v-A14B --size '1280*720' --ckpt_dir ./Wan2.2-T2V-A14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
#### Test
|
| 47 |
+
```bash
|
| 48 |
+
bash tests/test.sh
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
#### Format
|
| 52 |
+
```bash
|
| 53 |
+
black .
|
| 54 |
+
isort .
|
| 55 |
+
```
|
LICENSE.txt
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
LingBot_VA_paper.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e791faff04ff10eccb62eef7952eabbcb2c654abc4d73f4b4a8d1f683a6e48ba
|
| 3 |
+
size 7834795
|
Makefile
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.PHONY: format
|
| 2 |
+
|
| 3 |
+
format:
|
| 4 |
+
isort wan_va
|
| 5 |
+
yapf -i -r *.py wan_va
|
README.md
ADDED
|
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<h1 align="center">LingBot-VA: Causal World Modeling for Robot Control</h1>
|
| 2 |
+
|
| 3 |
+
<p align="center">
|
| 4 |
+
<a href="https://arxiv.org/abs/2601.21998"><img src="https://img.shields.io/static/v1?label=Paper&message=PDF&color=red&logo=arxiv"></a>
|
| 5 |
+
<a href="https://technology.robbyant.com/lingbot-va"><img src="https://img.shields.io/badge/Project-Website-blue"></a>
|
| 6 |
+
<a href="https://huggingface.co/collections/robbyant/lingbot-va"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Model&message=HuggingFace&color=orange"></a>
|
| 7 |
+
<a href="https://modelscope.cn/collections/Robbyant/LingBot-VA"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%96%20Model&message=ModelScope&color=purple"></a>
|
| 8 |
+
<a href="LICENSE.txt"><img src="https://img.shields.io/badge/License-Apache--2.0-green"></a>
|
| 9 |
+
</p>
|
| 10 |
+
|
| 11 |
+
<p align="center">
|
| 12 |
+
<img src="assets/teaser_v3.png" width="100%">
|
| 13 |
+
</p>
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
https://github.com/user-attachments/assets/cec7b7a6-953b-4fa4-8f1a-47efc1fce547
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
## 💫 Meet **LingBot-VA**! We've built an AR diffusion framework for simultaneous world modeling and action! 🤖✨
|
| 23 |
+
|
| 24 |
+
**LingBot-VA** has focused on:
|
| 25 |
+
- **Autoregressive Video-Action World Modeling**: Architecturally unifies visual dynamics prediction and action inference within a single interleaved sequence while maintaining their conceptual distinction.
|
| 26 |
+
- **High-efficiency Execution**: A dual-stream mixture-of-transformers(MoT) architecture with Asynchronous Execution and KV Cache.
|
| 27 |
+
- **Long-Horizon Performance and Generalization**: High improvements in sample efficiency, long-horizon success rates, and generalization to novel scenes.
|
| 28 |
+
|
| 29 |
+
# 🚀 News
|
| 30 |
+
- **[2026-01-29]** Weights and code for shared backbone released! Please stay tuned for our separated version!
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
---
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# 📦 Model Download
|
| 40 |
+
- **Pretrained Checkpoints for Post-Training**
|
| 41 |
+
|
| 42 |
+
| Model Name | Huggingface Repository | ModelScope Repository | Description |
|
| 43 |
+
| :--- | :--- | :--- | :--- |
|
| 44 |
+
| lingbot-va-base | [🤗 robbyant/lingbot-va-base ](https://huggingface.co/robbyant/lingbot-va-base) | [🤖 Robbyant/lingbot-va-base ](https://modelscope.cn/models/Robbyant/lingbot-va-base) | LingBot-VA w/ shared backbone|
|
| 45 |
+
| lingbot-va-posttrain-robotwin | [🤗 robbyant/lingbot-va-posttrain-robotwin ](https://huggingface.co/robbyant/lingbot-va-posttrain-robotwin) | [🤖 Robbyant/lingbot-va-posttrain-robotwin ](https://modelscope.cn/models/Robbyant/lingbot-va-posttrain-robotwin) | LingBot-VA-Posttrain-Robotwin w/ shared backbone|
|
| 46 |
+
---
|
| 47 |
+
|
| 48 |
+
# 🛠️ Quick Start
|
| 49 |
+
|
| 50 |
+
## Installation
|
| 51 |
+
**Requirements**
|
| 52 |
+
• Python == 3.10.16
|
| 53 |
+
• Pytorch == 2.9.0
|
| 54 |
+
• CUDA 12.6
|
| 55 |
+
|
| 56 |
+
```bash
|
| 57 |
+
pip install torch==2.9.0 torchvision==0.24.0 torchaudio==2.9.0 --index-url https://download.pytorch.org/whl/cu126
|
| 58 |
+
pip install websockets einops diffusers==0.36.0 transformers==4.55.2 accelerate msgpack opencv-python matplotlib ftfy easydict
|
| 59 |
+
pip install flash-attn --no-build-isolation
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
## Deploying LingBot-VA for Inference
|
| 64 |
+
LingBot-VA supports both standalone execution and Server-Client architecture which separates the model environment from simulation. By isolating dependencies, the design avoids package clashes and supports distributed inference on GPUs, clusters, and other devices.
|
| 65 |
+
|
| 66 |
+
<!-- ### Standalone Inference
|
| 67 |
+
```python
|
| 68 |
+
python inference.py
|
| 69 |
+
```
|
| 70 |
+
This processes the example data from `examples/0/` and saves visualizations to `result/`. -->
|
| 71 |
+
|
| 72 |
+
### Evaluation on RoboTwin-2.0
|
| 73 |
+
|
| 74 |
+
**Preparing the Environment**
|
| 75 |
+
|
| 76 |
+
You can follow the official instructions from the original RoboTwin-2.0 repository:
|
| 77 |
+
[https://robotwin-platform.github.io/doc/usage/robotwin-install.html](https://robotwin-platform.github.io/doc/usage/robotwin-install.html)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
In summary:
|
| 81 |
+
|
| 82 |
+
1.
|
| 83 |
+
```bash
|
| 84 |
+
sudo apt install libvulkan1 mesa-vulkan-drivers vulkan-tools
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
2.
|
| 88 |
+
```bash
|
| 89 |
+
git clone https://github.com/RoboTwin-Platform/RoboTwin.git && cd RoboTwin
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
3. modify script/requirements.txt
|
| 93 |
+
```bash
|
| 94 |
+
transforms3d==0.4.2
|
| 95 |
+
sapien==3.0.0b1
|
| 96 |
+
scipy==1.10.1
|
| 97 |
+
mplib==0.2.1
|
| 98 |
+
gymnasium==0.29.1
|
| 99 |
+
trimesh==4.4.3
|
| 100 |
+
open3d==0.18.0
|
| 101 |
+
imageio==2.34.2
|
| 102 |
+
pydantic
|
| 103 |
+
zarr
|
| 104 |
+
openai
|
| 105 |
+
huggingface_hub==0.36.2
|
| 106 |
+
h5py
|
| 107 |
+
# For Description Generation
|
| 108 |
+
azure==4.0.0
|
| 109 |
+
azure-ai-inference
|
| 110 |
+
pyglet<2
|
| 111 |
+
wandb
|
| 112 |
+
moviepy
|
| 113 |
+
imageio
|
| 114 |
+
termcolor
|
| 115 |
+
av
|
| 116 |
+
matplotlib
|
| 117 |
+
ffmpeg
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
4. modify line 8 of script/_install.sh:
|
| 121 |
+
```bash
|
| 122 |
+
pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable" --no-build-isolation
|
| 123 |
+
```
|
| 124 |
+
|
| 125 |
+
5. run:
|
| 126 |
+
```bash
|
| 127 |
+
bash script/_install.sh
|
| 128 |
+
```
|
| 129 |
+
|
| 130 |
+
6. run:
|
| 131 |
+
```bash
|
| 132 |
+
bash script/_download_assets.sh
|
| 133 |
+
```
|
| 134 |
+
|
| 135 |
+
**Deploying the Inference Server**
|
| 136 |
+
```bash
|
| 137 |
+
# single GPU
|
| 138 |
+
bash evaluation/robotwin/launch_server.sh
|
| 139 |
+
|
| 140 |
+
# multi-GPU
|
| 141 |
+
bash evaluation/robotwin/launch_server_multigpus.sh
|
| 142 |
+
```
|
| 143 |
+
|
| 144 |
+
**Executing the Inference Client**
|
| 145 |
+
```bash
|
| 146 |
+
# single GPU
|
| 147 |
+
task_name="adjust_bottle";
|
| 148 |
+
save_root="results/";
|
| 149 |
+
bash evaluation/robotwin/launch_client.sh ${save_root} ${task_name}
|
| 150 |
+
|
| 151 |
+
# multi-GPU
|
| 152 |
+
save_root="results/"
|
| 153 |
+
task_group_id=0;
|
| 154 |
+
bash evaluation/robotwin/launch_client_multigpus.sh ${save_root} ${task_group_id}
|
| 155 |
+
```
|
| 156 |
+
|
| 157 |
+
Related experiments results will be save in `/path/to/your/RoboTwin/${save_root}`. Please note that an `eval_result` folder is also generated. This is a native output from RoboTwin and is identical to the contents in the results folder; it can be safely ignored.
|
| 158 |
+
It is important to note that the inference server and client must be deployed on the same machine. For launching multi-GPU client, we padded the original 50 tasks to 56 via duplication and partitioned them into 7 groups to align with the 8-GPU configuration of our inference node. You can specify the `task_group_id` (0-6) to select a particular group for inference. For detailed grouping configurations, please refer to `evaluation/robotwin/launch_client_multigpus.sh`.
|
| 159 |
+
|
| 160 |
+
### Run Image to Video-Action Generation
|
| 161 |
+
|
| 162 |
+
We also provide a script for image to video-action generation:
|
| 163 |
+
|
| 164 |
+
```bash
|
| 165 |
+
NGPU=1 CONFIG_NAME='robotwin_i2av' bash script/run_launch_va_server_sync.sh
|
| 166 |
+
```
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
---
|
| 171 |
+
|
| 172 |
+
# 📊 Performance
|
| 173 |
+
|
| 174 |
+
We evaluate our model on both simulation benchmarks and real-world scenarios, and achieve state-of-the-art performance.
|
| 175 |
+
|
| 176 |
+
## Simulation Evaluation
|
| 177 |
+
|
| 178 |
+
- **RoboTwin 2.0**
|
| 179 |
+
|
| 180 |
+
We are the first to propel RoboTwin 2.0 metrics performance past the 90+ threshold!
|
| 181 |
+
<table style="border-collapse: collapse; width: auto; font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Arial, sans-serif; font-size: 13px; line-height: 1.2;">
|
| 182 |
+
<!-- Metric notes -->
|
| 183 |
+
<p style="font-size: 12px; color: #666; margin-bottom: 5px;">* All metrics are reported in percentage (%). Higher values are <b>bolded</b>.</p>
|
| 184 |
+
<thead>
|
| 185 |
+
<tr style="border-top: 2px solid black; border-bottom: 1px solid black;">
|
| 186 |
+
<th align="left" style="padding: 6px 12px; white-space: nowrap;">Method (Average 50 Tasks)</th>
|
| 187 |
+
<th align="center" style="padding: 6px 12px;">Easy SR (%)</th>
|
| 188 |
+
<th align="center" style="padding: 6px 12px;">Hard SR (%)</th>
|
| 189 |
+
</tr>
|
| 190 |
+
</thead>
|
| 191 |
+
<tbody>
|
| 192 |
+
<tr>
|
| 193 |
+
<td style="padding: 4px 12px; white-space: nowrap;">X-VLA</td>
|
| 194 |
+
<td align="center">72.9</td>
|
| 195 |
+
<td align="center">72.8</td>
|
| 196 |
+
</tr>
|
| 197 |
+
<tr>
|
| 198 |
+
<td style="padding: 4px 12px; white-space: nowrap;">π<sub>0</sub></td>
|
| 199 |
+
<td align="center">65.9</td>
|
| 200 |
+
<td align="center">58.4</td>
|
| 201 |
+
</tr>
|
| 202 |
+
<tr>
|
| 203 |
+
<td style="padding: 4px 12px; white-space: nowrap;">π<sub>0.5</sub></td>
|
| 204 |
+
<td align="center">82.7</td>
|
| 205 |
+
<td align="center">76.8</td>
|
| 206 |
+
</tr>
|
| 207 |
+
<tr>
|
| 208 |
+
<td style="padding: 4px 12px; white-space: nowrap;">Motus</td>
|
| 209 |
+
<td align="center"><u>88.7</u></td>
|
| 210 |
+
<td align="center"><u>87.0</u></td>
|
| 211 |
+
</tr>
|
| 212 |
+
<tr style="border-top: 1px solid black; border-bottom: 2px solid black;">
|
| 213 |
+
<td style="padding: 6px 12px; white-space: nowrap;"><b>LingBot-VA (Ours)</b></td>
|
| 214 |
+
<td align="center"><b>92.9</b> <small>(+4.2)</small></td>
|
| 215 |
+
<td align="center"><b>91.6</b> <small>(+4.6)</small></td>
|
| 216 |
+
</tr>
|
| 217 |
+
</tbody>
|
| 218 |
+
</table>
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
- **LIBERO**
|
| 222 |
+
|
| 223 |
+
<table style="border-collapse: collapse; width: auto; font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Arial, sans-serif; font-size: 13px; line-height: 1.2;">
|
| 224 |
+
<!-- Metric notes -->
|
| 225 |
+
<p style="font-size: 12px; color: #666; margin-bottom: 5px;">* All metrics are reported in percentage (%). Higher values are <b>bolded</b>.</p>
|
| 226 |
+
<thead>
|
| 227 |
+
<tr style="border-top: 2px solid black; border-bottom: 1px solid black;">
|
| 228 |
+
<th align="left" style="padding: 6px 10px; border-right: 1px solid black; white-space: nowrap;">Methods</th>
|
| 229 |
+
<th align="center" style="padding: 6px 8px;">Spatial</th>
|
| 230 |
+
<th align="center" style="padding: 6px 8px;">Object</th>
|
| 231 |
+
<th align="center" style="padding: 6px 8px;">Goal</th>
|
| 232 |
+
<th align="center" style="padding: 6px 8px;">Long</th>
|
| 233 |
+
<th align="center" style="padding: 6px 8px;">Avg</th>
|
| 234 |
+
</tr>
|
| 235 |
+
</thead>
|
| 236 |
+
<tbody>
|
| 237 |
+
<tr>
|
| 238 |
+
<td style="padding: 4px 10px; border-right: 1px solid black; white-space: nowrap;">π<sub>0</sub></td>
|
| 239 |
+
<td align="center">96.8</td><td align="center">98.8</td><td align="center">95.8</td><td align="center">85.2</td><td align="center">94.1</td>
|
| 240 |
+
</tr>
|
| 241 |
+
<tr>
|
| 242 |
+
<td style="padding: 4px 10px; border-right: 1px solid black; white-space: nowrap;">π<sub>0.5</sub></td>
|
| 243 |
+
<td align="center">98.8</td><td align="center">98.2</td><td align="center">98.0</td><td align="center">92.4</td><td align="center">96.9</td>
|
| 244 |
+
</tr>
|
| 245 |
+
<tr>
|
| 246 |
+
<td style="padding: 4px 10px; border-right: 1px solid black; white-space: nowrap;">OpenVLA</td>
|
| 247 |
+
<td align="center">84.7</td><td align="center">88.4</td><td align="center">79.2</td><td align="center">53.7</td><td align="center">76.5</td>
|
| 248 |
+
</tr>
|
| 249 |
+
<tr>
|
| 250 |
+
<td style="padding: 4px 10px; border-right: 1px solid black; white-space: nowrap;">X-VLA</td>
|
| 251 |
+
<td align="center">98.2</td><td align="center">98.6</td><td align="center">97.8</td><td align="center">97.6</td><td align="center">98.1</td>
|
| 252 |
+
</tr>
|
| 253 |
+
<tr style="border-top: 1.5px solid black; border-bottom: 2px solid black;">
|
| 254 |
+
<td style="padding: 5px 10px; border-right: 1px solid black; white-space: nowrap;"><b>LingBot-VA (Ours)</b></td>
|
| 255 |
+
<td align="center"><b>98.5 ± 0.3</b></td>
|
| 256 |
+
<td align="center"><b>99.6 ± 0.3</b></td>
|
| 257 |
+
<td align="center"><b>97.2 ± 0.2</b></td>
|
| 258 |
+
<td align="center"><b>98.5 ± 0.5</b></td>
|
| 259 |
+
<td align="center"><b>98.5</b></td>
|
| 260 |
+
</tr>
|
| 261 |
+
</tbody>
|
| 262 |
+
</table>
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
## Real-world Deployment
|
| 269 |
+
|
| 270 |
+
Six manipulation tasks across three categories: long-horizon tasks (Make Breakfast, Pick Screws), precision tasks (Insert Tube, Unpack Delivery), and deformable & articulated object
|
| 271 |
+
manipulation (Fold Clothes, Fold Pants). Our method achieves state-of-the-art performance on both metrics (Progress Rate and Success Rate) with <b>only 50 trials</b> per task, substantially outperforming strong baseline π<sub>0.5</sub>.
|
| 272 |
+
|
| 273 |
+
<div style="text-align: left; font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Arial, sans-serif; line-height: 1.6;">
|
| 274 |
+
|
| 275 |
+
<!-- Part 1: Progress Score (PS) definition -->
|
| 276 |
+
<div style="margin-bottom: 5px;"><strong>Progress Score (PS):</strong> The average score across all trials divided by the maximum possible score, expressed as a percentage:</div>
|
| 277 |
+
|
| 278 |
+
PS = Average_Progress / Max_Steps × 100%
|
| 279 |
+
|
| 280 |
+
<!-- Part 2: Success Rate (SR) definition -->
|
| 281 |
+
<div style="margin-bottom: 5px;"><strong>Success Rate (SR):</strong> The number of successful trials divided by the total number of trials, expressed as a percentage:</div>
|
| 282 |
+
|
| 283 |
+
SR = Successful_Trials / N × 100%
|
| 284 |
+
|
| 285 |
+
</div>
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
<div style="font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Arial, sans-serif;">
|
| 290 |
+
<!-- Metric notes -->
|
| 291 |
+
<p style="font-size: 12px; color: #666; margin-bottom: 5px;">* All metrics are reported in percentage (%). Higher values are <b>bolded</b>.</p>
|
| 292 |
+
|
| 293 |
+
<table style="border-collapse: collapse; width: auto; font-size: 13px; line-height: 1.2;">
|
| 294 |
+
<thead>
|
| 295 |
+
<tr style="border-top: 2px solid black;">
|
| 296 |
+
<th rowspan="2" align="left" style="padding: 4px 10px; border-bottom: 1px solid black; white-space: nowrap;"><b>Task</b></th>
|
| 297 |
+
<th colspan="2" style="padding: 4px 10px; border-bottom: 1px solid black;">Make Breakfast</th>
|
| 298 |
+
<th colspan="2" style="padding: 4px 10px; border-bottom: 1px solid black;">Pick Screws</th>
|
| 299 |
+
<th colspan="2" style="padding: 4px 10px; border-bottom: 1px solid black;">Insert Tube</th>
|
| 300 |
+
<th colspan="2" style="padding: 4px 10px; border-bottom: 1px solid black;">Unpack Delivery</th>
|
| 301 |
+
<th colspan="2" style="padding: 4px 10px; border-bottom: 1px solid black;">Fold Clothes</th>
|
| 302 |
+
<th colspan="2" style="padding: 4px 10px; border-bottom: 1px solid black;">Fold Pants</th>
|
| 303 |
+
</tr>
|
| 304 |
+
<tr style="border-bottom: 1px solid black;">
|
| 305 |
+
<th style="padding: 4px 8px;">PS</th>
|
| 306 |
+
<th style="padding: 4px 8px;">SR</th>
|
| 307 |
+
<th style="padding: 4px 8px;">PS</th>
|
| 308 |
+
<th style="padding: 4px 8px;">SR</th>
|
| 309 |
+
<th style="padding: 4px 8px;">PS</th>
|
| 310 |
+
<th style="padding: 4px 8px;">SR</th>
|
| 311 |
+
<th style="padding: 4px 8px;">PS</th>
|
| 312 |
+
<th style="padding: 4px 8px;">SR</th>
|
| 313 |
+
<th style="padding: 4px 8px;">PS</th>
|
| 314 |
+
<th style="padding: 4px 8px;">SR</th>
|
| 315 |
+
<th style="padding: 4px 8px;">PS</th>
|
| 316 |
+
<th style="padding: 4px 8px;">SR</th>
|
| 317 |
+
</tr>
|
| 318 |
+
</thead>
|
| 319 |
+
<tbody>
|
| 320 |
+
<tr>
|
| 321 |
+
<td style="padding: 6px 10px; white-space: nowrap;">π<sub>0.5</sub></td>
|
| 322 |
+
<td align="center">73.0</td><td align="center">70.0</td>
|
| 323 |
+
<td align="center">74.0</td><td align="center">50.0</td>
|
| 324 |
+
<td align="center">79.2</td><td align="center">30.0</td>
|
| 325 |
+
<td align="center">73.0</td><td align="center">25.0</td>
|
| 326 |
+
<td align="center"><b>62.9</b></td><td align="center">30.0</td>
|
| 327 |
+
<td align="center">30.0</td><td align="center">30.0</td>
|
| 328 |
+
</tr>
|
| 329 |
+
<tr style="border-bottom: 2px solid black;">
|
| 330 |
+
<td style="padding: 6px 10px; white-space: nowrap;"><b>LingBot-VA (Ours)</b></td>
|
| 331 |
+
<td align="center"><b>97.0</b></td><td align="center"><b>75.0</b></td>
|
| 332 |
+
<td align="center"><b>82.5</b></td><td align="center"><b>70.0</b></td>
|
| 333 |
+
<td align="center"><b>85.8</b></td><td align="center"><b>40.0</b></td>
|
| 334 |
+
<td align="center"><b>84.5</b></td><td align="center"><b>65.0</b></td>
|
| 335 |
+
<td align="center">48.8</td><td align="center"><b>35.0</b></td>
|
| 336 |
+
<td align="center"><b>76.7</b></td><td align="center"><b>70.0</b></td>
|
| 337 |
+
</tr>
|
| 338 |
+
</tbody>
|
| 339 |
+
</table>
|
| 340 |
+
</div>
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
# 🪪 License
|
| 344 |
+
|
| 345 |
+
This project is released under the Apache License 2.0. See [LICENSE](LICENSE.txt) file for details.
|
| 346 |
+
|
| 347 |
+
# 📚 Citation
|
| 348 |
+
|
| 349 |
+
```bibtex
|
| 350 |
+
@article{lingbot-va2026,
|
| 351 |
+
title={Causal World Modeling for Robot Control},
|
| 352 |
+
author={Li, Lin and Zhang, Qihang and Luo, Yiming and Yang, Shuai and Wang, Ruilin and Han, Fei and Yu, Mingrui and Gao, Zelin and Xue, Nan and Zhu, Xing and Shen, Yujun and Xu, Yinghao},
|
| 353 |
+
journal={arXiv preprint arXiv:2601.21998},
|
| 354 |
+
year={2026}
|
| 355 |
+
}
|
| 356 |
+
```
|
| 357 |
+
|
| 358 |
+
# 🧩 Acknowledgments
|
| 359 |
+
|
| 360 |
+
This work builds upon several excellent open-source projects:
|
| 361 |
+
|
| 362 |
+
- [Wan-Video](https://github.com/Wan-Video) - Vision transformer backbone
|
| 363 |
+
- [MoT](https://github.com/facebookresearch/Mixture-of-Transformers) - Mixture-of-Transformers architecture
|
| 364 |
+
- The broader open-source computer vision and robotics communities
|
| 365 |
+
|
| 366 |
+
---
|
| 367 |
+
|
| 368 |
+
For questions, discussions, or collaborations:
|
| 369 |
+
|
| 370 |
+
- **Issues**: Open an [issue](https://github.com/robbyant/lingbot-va/issues) on GitHub
|
| 371 |
+
- **Email**: Contact Dr. [Qihang Zhang](https://zqh0253.github.io/) (liuhuan.zqh@antgroup.com) or Dr. [Lin Li](https://lilin-hitcrt.github.io/) (fengchang.ll@antgroup.com)
|
assets/teaser.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2b23b4170e7784b82d8a6287451e886ac88fd0d7d841c55c4bc290b068c9f394
|
| 3 |
+
size 12144486
|
assets/teaser_v3.png
ADDED
|
Git LFS Details
|
debug/place_fan/call1_reset.msgpack
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c0d8e7cb0edd8bb74e138e0cc15f50470a7db8316c4ac39e8a2178fde2e1a3da
|
| 3 |
+
size 140
|
debug/place_fan/call2.msgpack
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:23fc9654673dfad3b4c70a545782bab385b9b18f439fbfb6005b582de0e7005e
|
| 3 |
+
size 691918
|
debug/place_fan/call3.msgpack
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b0a92955b91b046dc90ee8e1c671be2938f53d90712245b126b20e2a569651c6
|
| 3 |
+
size 2769097
|
evaluation/robotwin/calc_stat.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
|
| 3 |
+
def compute_success_rates(root_dir: str, true_suffix="True.mp4", false_suffix="False.mp4"):
    """Tally success/failure videos under each immediate sub-directory of *root_dir*.

    A recording is counted as a success when its filename ends with
    *true_suffix* and as a failure when it ends with *false_suffix*;
    other ``.mp4`` files are ignored.

    Args:
        root_dir: Directory whose sub-directories (one per task) are scanned.
        true_suffix: Filename suffix marking a successful episode.
        false_suffix: Filename suffix marking a failed episode.

    Returns:
        List of ``(folder_name, true_count, false_count, total, rate)`` tuples,
        sorted by folder path; ``rate`` is ``None`` when no matching videos exist.

    Raises:
        FileNotFoundError: If *root_dir* does not exist.
    """
    base = Path(root_dir)
    if not base.exists():
        raise FileNotFoundError(f"Root dir not found: {base}")

    stats = []
    for task_dir in sorted(p for p in base.iterdir() if p.is_dir()):
        succ = 0
        fail = 0
        # Recursive scan: episode videos may live in nested sub-folders.
        for clip in task_dir.rglob("*.mp4"):
            clip_name = clip.name
            if clip_name.endswith(true_suffix):
                succ += 1
            elif clip_name.endswith(false_suffix):
                fail += 1

        count = succ + fail
        ratio = succ / count if count > 0 else None
        stats.append((task_dir.name, succ, fail, count, ratio))

    return stats
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# Task name -> class id (1/2/3). Used by print_table() below to report
# per-class mean success rates in addition to the overall mean.
# NOTE(review): the precise meaning of the three classes is not stated in
# this file — presumably a task-category grouping; confirm with the authors.
TASK_CLASS = {
    "adjust_bottle": 1,
    "beat_block_hammer": 1,
    "blocks_ranking_rgb": 3,
    "blocks_ranking_size": 3,
    "click_alarmclock": 1,
    "click_bell": 1,
    "dump_bin_bigbin": 1,
    "grab_roller": 1,
    "handover_block": 2,
    "handover_mic": 2,
    "hanging_mug": 2,
    "lift_pot": 1,
    "move_can_pot": 1,
    "move_pillbottle_pad": 1,
    "move_playingcard_away": 1,
    "move_stapler_pad": 1,
    "open_laptop": 1,
    "open_microwave": 1,
    "pick_diverse_bottles": 2,
    "pick_dual_bottles": 2,
    "place_a2b_left": 1,
    "place_a2b_right": 1,
    "place_bread_basket": 1,
    "place_bread_skillet": 2,
    "place_burger_fries": 2,
    "place_can_basket": 2,
    "place_cans_plasticbox": 2,
    "place_container_plate": 1,
    "place_dual_shoes": 2,
    "place_empty_cup": 1,
    "place_fan": 1,
    "place_mouse_pad": 1,
    "place_object_basket": 2,
    "place_object_scale": 1,
    "place_object_stand": 1,
    "place_phone_stand": 1,
    "place_shoe": 1,
    "press_stapler": 1,
    "put_bottles_dustbin": 3,
    "put_object_cabinet": 2,
    "rotate_qrcode": 1,
    "scan_object": 2,
    "shake_bottle_horizontally": 1,
    "shake_bottle": 1,
    "stack_blocks_three": 3,
    "stack_blocks_two": 2,
    "stack_bowls_three": 3,
    "stack_bowls_two": 2,
    "stamp_seal": 1,
    "turn_switch": 1,
}
|
| 80 |
+
|
| 81 |
+
def mean_rate_of(results_subset):
    """Average the success rate (index 4) of the given result rows.

    Rows whose rate is ``None`` are skipped; returns ``None`` when no row
    carries a valid rate.
    """
    total = 0.0
    count = 0
    for row in results_subset:
        rate = row[4]
        if rate is not None:
            total += rate
            count += 1
    return total / count if count else None
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def print_table(results):
    """Pretty-print per-folder success counts plus overall / per-class / unknown means.

    *results* rows are ``(folder, true_count, false_count, total, rate)`` as
    produced by compute_success_rates(). Rows are shown best-rate first, with
    rate-less rows (N/A) last.
    """

    def pct(value):
        # Render a 0-1 rate as a padded percentage string, or N/A when missing.
        return "N/A" if value is None else f"{value*100:9.2f}%"

    # Sort key: valid rates first (descending), then the None rows.
    ordered = sorted(results, key=lambda row: (row[4] is None, -(row[4] or 0.0)))

    print(f"{'folder':30s} {'True':>6s} {'False':>6s} {'Total':>6s} {'SuccessRate':>12s} {'Class':>6s}")
    print("-" * 90)

    for folder, ok, bad, total, rate in ordered:
        cls = TASK_CLASS.get(folder, None)
        cls_str = "N/A" if cls is None else str(cls)
        print(f"{folder:30s} {ok:6d} {bad:6d} {total:6d} {pct(rate):>12s} {cls_str:>6s}")

    print("-" * 90)

    def summary_row(label, subset):
        # One aggregate line: label + mean success rate of *subset*.
        print(f"{label:30s} {'':6s} {'':6s} {'':6s} {pct(mean_rate_of(subset)):>12s}")

    summary_row("MEAN (ALL)", ordered)

    # Per-class means for the three known classes.
    for cls_id in (1, 2, 3):
        summary_row(f"MEAN (CLASS {cls_id})", [r for r in ordered if TASK_CLASS.get(r[0]) == cls_id])

    # Folders that are not in the TASK_CLASS mapping, if any.
    unknown = [r for r in ordered if r[0] not in TASK_CLASS]
    if unknown:
        summary_row("MEAN (UNKNOWN)", unknown)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
if __name__ == "__main__":
    import sys

    # Each positional argument is one results root; their stats are merged.
    root_dirs = sys.argv[1:]
    if not root_dirs:
        raise SystemExit("Usage: python a.py <root_folder1> [<root_folder2> ...]")

    combined = [row for directory in root_dirs for row in compute_success_rates(directory)]
    print_table(combined)
|
evaluation/robotwin/eval_polict_client_openpi.py
ADDED
|
@@ -0,0 +1,696 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
import subprocess
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
|
| 6 |
+
import cv2
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
robowin_root = Path("/group/ossdphi_algo_scratch_11/weicxu/pythonproject/RoboTwin")
|
| 10 |
+
if str(robowin_root) not in sys.path:
|
| 11 |
+
sys.path.insert(0, str(robowin_root))
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
import os
|
| 15 |
+
os.chdir(robowin_root)
|
| 16 |
+
|
| 17 |
+
from envs import CONFIGS_PATH
|
| 18 |
+
from envs.utils.create_actor import UnStableError
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
from collections import deque
|
| 23 |
+
import traceback
|
| 24 |
+
|
| 25 |
+
import yaml
|
| 26 |
+
from datetime import datetime
|
| 27 |
+
import importlib
|
| 28 |
+
import argparse
|
| 29 |
+
import pdb
|
| 30 |
+
from evaluation.robotwin.geometry import euler2quat
|
| 31 |
+
import numpy as np
|
| 32 |
+
|
| 33 |
+
from description.utils.generate_episode_instructions import *
|
| 34 |
+
import traceback
|
| 35 |
+
|
| 36 |
+
import imageio
|
| 37 |
+
import numpy as np
|
| 38 |
+
from pathlib import Path
|
| 39 |
+
from scipy.spatial.transform import Rotation as R
|
| 40 |
+
import json
|
| 41 |
+
from pathlib import Path
|
| 42 |
+
|
| 43 |
+
from evaluation.robotwin.websocket_client_policy import WebsocketClientPolicy
|
| 44 |
+
from evaluation.robotwin.test_render import Sapien_TEST
|
| 45 |
+
|
| 46 |
+
def write_json(data: dict, fpath: Path) -> None:
    """Write *data* to a JSON file, creating parent directories if needed.

    Args:
        data (dict): The dictionary to write.
        fpath (Path): The path to the output JSON file.
    """
    fpath.parent.mkdir(parents=True, exist_ok=True)
    with fpath.open("w") as handle:
        json.dump(data, handle, indent=4, ensure_ascii=False)
|
| 58 |
+
|
| 59 |
+
def add_title_bar(img, text, font_scale=0.8, thickness=2):
    """Stack a 40-px black banner with centered white *text* above *img* (H, W, 3)."""
    bar_height = 40
    _, width, _ = img.shape

    # Black background strip, same width as the image.
    banner = np.zeros((bar_height, width, 3), dtype=np.uint8)

    # Center the text horizontally and vertically within the banner.
    (text_w, text_h), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)
    text_x = (width - text_w) // 2
    text_y = (bar_height + text_h) // 2 - 5

    cv2.putText(banner, text, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX,
                font_scale, (255, 255, 255), thickness, cv2.LINE_AA)

    return np.vstack([banner, img])
|
| 76 |
+
|
| 77 |
+
def quaternion_to_euler(quat):
    """
    Convert a quaternion to 'xyz' Euler angles (roll, pitch, yaw).

    quat: [rx, ry, rz, rw] format — this matches scipy's native [x, y, z, w]
    layout, so the values are passed through unchanged.
    Return: [roll, pitch, yaw] (radians)
    """
    return R.from_quat(quat).as_euler('xyz', degrees=False)
|
| 87 |
+
|
| 88 |
+
def visualize_action_step(action_history, step_idx, window=50):
    """
    Plot dual-arm action curves and return the rendered figure as an image.

    Subplot 1: Left arm XYZ Position + Gripper
    Subplot 2: Left arm Euler angles (Roll, Pitch, Yaw) - converted from quaternion
    Subplot 3: Right arm XYZ Position + Gripper
    Subplot 4: Right arm Euler angles (Roll, Pitch, Yaw) - converted from quaternion

    Input data format: [left_x, left_y, left_z, left_rx, left_ry, left_rz, left_rw, left_gripper,
                        right_x, right_y, right_z, right_rx, right_ry, right_rz, right_rw, right_gripper]
    Total 16 dimensions

    Args:
        action_history: Sequence of 16-dim action vectors (one per control step).
        step_idx: Index of the current step; the plot shows a sliding window
            ending at this step.
        window: Number of past steps to display (default 50).

    Returns:
        np.uint8 RGB image (H, W, 3) of the rendered 2x2 figure. If the
        window slice is empty or actions have fewer than 16 dims, the axes
        are left blank but an image is still returned.
    """
    # Create four subplots, sharing the X-axis
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(14, 8), dpi=100, sharex=True)

    # 1. Determine slice range (clamped at 0 on the left)
    start = max(0, step_idx - window)
    end = step_idx + 1

    # 2. Get data subset
    history_subset = np.array(action_history)[start:end]

    # 3. Generate X-axis based on actual data length
    actual_len = len(history_subset)
    x_axis = range(start, start + actual_len)

    if actual_len > 0 and history_subset.shape[1] >= 16:
        # Convert quaternions to Euler angles
        left_euler = []
        right_euler = []

        for action in history_subset:
            # Left arm quaternion to Euler angles
            left_quat = action[3:7]  # [rx, ry, rz, rw]
            left_rpy = quaternion_to_euler(left_quat)
            left_euler.append(left_rpy)

            # Right arm quaternion to Euler angles
            right_quat = action[11:15]  # [rx, ry, rz, rw]
            right_rpy = quaternion_to_euler(right_quat)
            right_euler.append(right_rpy)

        left_euler = np.array(left_euler)
        right_euler = np.array(right_euler)

        # --- Left Arm ---
        # Subplot 1: Left Arm Translation (XYZ) + Gripper
        ax1.plot(x_axis, history_subset[:, 0], label='left_x', color='r', linewidth=1.5)
        ax1.plot(x_axis, history_subset[:, 1], label='left_y', color='g', linewidth=1.5)
        ax1.plot(x_axis, history_subset[:, 2], label='left_z', color='b', linewidth=1.5)
        ax1.plot(x_axis, history_subset[:, 7], label='left_grip', color='orange',
                 linestyle=':', linewidth=2, alpha=0.8)
        ax1.set_ylabel('Position (m)')
        ax1.legend(loc='upper right', fontsize='x-small', ncol=4)
        ax1.grid(True, alpha=0.3)
        ax1.set_title(f"Step {step_idx}: Left Arm Position & Gripper")

        # Subplot 2: Left Arm Euler Angles (Roll, Pitch, Yaw)
        ax2.plot(x_axis, left_euler[:, 0], label='left_roll', color='c', linewidth=1.5)
        ax2.plot(x_axis, left_euler[:, 1], label='left_pitch', color='m', linewidth=1.5)
        ax2.plot(x_axis, left_euler[:, 2], label='left_yaw', color='y', linewidth=1.5)
        ax2.set_ylabel('Rotation (rad)')
        ax2.legend(loc='upper right', fontsize='x-small', ncol=3)
        ax2.grid(True, alpha=0.3)
        ax2.set_title("Left Arm Rotation (RPY from Quaternion)")

        # --- Right Arm ---
        # Subplot 3: Right Arm Translation (XYZ) + Gripper (dashed to mirror left arm's solid)
        ax3.plot(x_axis, history_subset[:, 8], label='right_x', color='r', linewidth=1.5, linestyle='--')
        ax3.plot(x_axis, history_subset[:, 9], label='right_y', color='g', linewidth=1.5, linestyle='--')
        ax3.plot(x_axis, history_subset[:, 10], label='right_z', color='b', linewidth=1.5, linestyle='--')
        ax3.plot(x_axis, history_subset[:, 15], label='right_grip', color='orange',
                 linestyle=':', linewidth=2, alpha=0.8)
        ax3.set_ylabel('Position (m)')
        ax3.legend(loc='upper right', fontsize='x-small', ncol=4)
        ax3.grid(True, alpha=0.3)
        ax3.set_title("Right Arm Position & Gripper")

        # Subplot 4: Right Arm Euler Angles (Roll, Pitch, Yaw)
        ax4.plot(x_axis, right_euler[:, 0], label='right_roll', color='c', linewidth=1.5, linestyle='--')
        ax4.plot(x_axis, right_euler[:, 1], label='right_pitch', color='m', linewidth=1.5, linestyle='--')
        ax4.plot(x_axis, right_euler[:, 2], label='right_yaw', color='y', linewidth=1.5, linestyle='--')
        ax4.set_ylabel('Rotation (rad)')
        ax4.legend(loc='upper right', fontsize='x-small', ncol=3)
        ax4.grid(True, alpha=0.3)
        ax4.set_title("Right Arm Rotation (RPY from Quaternion)")

    # Set X-axis display range to maintain sliding window effect
    # (shared X-axis, so setting it on ax1 applies to all four subplots)
    ax1.set_xlim(max(0, step_idx - window), max(window, step_idx))
    ax3.set_xlabel('Step')
    ax4.set_xlabel('Step')

    plt.tight_layout()
    # Render off-screen via the Agg canvas and grab the RGBA pixel buffer.
    canvas = FigureCanvas(fig)
    canvas.draw()
    img = np.asarray(canvas.buffer_rgba())
    img = img[:, :, :3]  # drop the alpha channel

    # Convert to uint8
    if img.dtype != np.uint8:
        img = (img * 255).astype(np.uint8)

    plt.close(fig)  # release figure memory; this runs once per control step
    return img
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def save_comparison_video(real_obs_list, imagined_video, action_history, save_path, fps=15):
    """Write a video stacking real camera observations above the imagined stream.

    Each output frame has two rows: the top row is the three real camera views
    (high / left wrist / right wrist) side by side; the bottom row is the
    corresponding imagined frame, or a "Coming soon" placeholder when the
    imagined stream is absent or shorter than the real one.

    Args:
        real_obs_list: Sequence of observation dicts carrying the keys
            "observation.images.cam_high", "...cam_left_wrist" and
            "...cam_right_wrist" (H, W, 3 image arrays).
        imagined_video: List of frame chunks to concatenate along axis 0,
            or None when no imagined stream is available.
        action_history: Unused by this function (kept for call-site
            compatibility).
        save_path: Output video file path (passed to imageio.mimsave).
        fps: Frames per second of the written video.
    """
    if not real_obs_list:
        return

    n_real = len(real_obs_list)
    if imagined_video is not None:
        # Chunks arrive as a list of clips; merge them into one frame array.
        imagined_video = np.concatenate(imagined_video, 0)
        n_imagined = len(imagined_video)
    else:
        n_imagined = 0
    n_frames = n_real  # Based on real observation frames

    print(f"Saving video: Real {n_real} frames, Imagined {n_imagined} frames...")

    final_frames = []

    for i in range(n_frames):
        obs = real_obs_list[i]
        cam_high = obs["observation.images.cam_high"]
        cam_left = obs["observation.images.cam_left_wrist"]
        cam_right = obs["observation.images.cam_right_wrist"]

        # All three views are scaled to the high camera's height before hstack.
        base_h = cam_high.shape[0]

        def resize_h(img, h):
            # Scale img to height h, preserving aspect ratio (no-op if equal).
            if img.shape[0] != h:
                w = int(img.shape[1] * h / img.shape[0])
                return cv2.resize(img, (w, h))
            return img

        row_real = np.hstack([
            resize_h(cam_high, base_h),
            resize_h(cam_left, base_h),
            resize_h(cam_right, base_h)
        ])

        # Assumes non-uint8 inputs are in [0, 1] — rescale to 0-255.
        if row_real.dtype != np.uint8:
            row_real = (row_real * 255).astype(np.uint8)

        row_real = add_title_bar(row_real, "Real Observation (High / Left / Right)")

        target_width = row_real.shape[1]

        if imagined_video is not None and i < n_imagined:
            img_frame = imagined_video[i]
            # Normalize dtype: [0,1] floats are rescaled, other dtypes cast.
            if img_frame.dtype != np.uint8 and img_frame.max() <= 1.0001:
                img_frame = (img_frame * 255).astype(np.uint8)
            elif img_frame.dtype != np.uint8:
                img_frame = img_frame.astype(np.uint8)

            # Match the real row's width, keeping the imagined aspect ratio.
            h = int(img_frame.shape[0] * target_width / img_frame.shape[1])
            row_imagined = cv2.resize(img_frame, (target_width, h))
        else:
            # Placeholder row when no imagined frame exists for this index.
            row_imagined = np.zeros((300, target_width, 3), dtype=np.uint8)
            cv2.putText(row_imagined, "Coming soon", (target_width//2 - 100, 150),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (100, 100, 100), 2)

        row_imagined = add_title_bar(row_imagined, "Imagined Video Stream")
        full_frame = np.vstack([row_real, row_imagined])
        final_frames.append(full_frame)

    imageio.mimsave(save_path, final_frames, fps=fps)
    print(f"Combined video saved to: {save_path}")
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def class_decorator(task_name):
    """Import ``envs.<task_name>`` and instantiate the class of the same name.

    Args:
        task_name: module and class name of the task environment.

    Returns:
        An instance of the task environment class.

    Raises:
        ModuleNotFoundError: if the ``envs.<task_name>`` module does not exist.
        SystemExit: if the module exists but defines no class named task_name.
    """
    envs_module = importlib.import_module(f"envs.{task_name}")
    try:
        env_class = getattr(envs_module, task_name)
    except AttributeError as e:
        # Bug fix: was a bare `except:` around both the lookup AND the
        # constructor call, which mislabeled any constructor failure as
        # "No Task". Only a missing class name maps to SystemExit now;
        # constructor errors propagate with their real traceback.
        raise SystemExit("No Task") from e
    env_instance = env_class()
    return env_instance
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def eval_function_decorator(policy_name, model_name):
    """Import module ``policy_name`` and return its attribute ``model_name``.

    Args:
        policy_name: importable module path of the policy package.
        model_name: attribute (function/class) name to fetch from the module.

    Returns:
        The requested attribute.

    Raises:
        ImportError: if the module cannot be imported.
        AttributeError: if the module has no such attribute.
    """
    # Improvement: the original wrapped this in `try/except ImportError as e:
    # raise e`, which re-raised the exception unchanged — a no-op.
    policy_module = importlib.import_module(policy_name)
    return getattr(policy_module, model_name)
|
| 275 |
+
|
| 276 |
+
def get_camera_config(camera_type):
    """Return the settings for ``camera_type`` from the shared camera YAML.

    Reads ``task_config/_camera_config.yml`` under the RoboTwin root and
    asserts both that the file exists and that the camera type is defined.
    """
    config_file = os.path.join(robowin_root, "task_config/_camera_config.yml")

    assert os.path.isfile(config_file), "task config file is missing"

    with open(config_file, "r", encoding="utf-8") as f:
        camera_args = yaml.load(f.read(), Loader=yaml.FullLoader)

    assert camera_type in camera_args, f"camera {camera_type} is not defined"
    return camera_args[camera_type]
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def get_embodiment_config(robot_file):
    """Load and return ``<robot_file>/config.yml`` as a dict."""
    config_path = os.path.join(robot_file, "config.yml")
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.load(f.read(), Loader=yaml.FullLoader)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def main(usr_args):
    """Build the full task/embodiment configuration, run evaluation, save results.

    Args:
        usr_args: dict merged from the YAML config and CLI overrides; must
            contain task_name, task_config, ckpt_setting, save_root,
            policy_name, video_guidance_scale, action_guidance_scale, seed,
            test_num and port.
    """
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    task_name = usr_args["task_name"]
    task_config = usr_args["task_config"]
    ckpt_setting = usr_args["ckpt_setting"]
    save_root = usr_args["save_root"]
    policy_name = usr_args["policy_name"]
    video_guidance_scale = usr_args["video_guidance_scale"]
    action_guidance_scale = usr_args["action_guidance_scale"]
    instruction_type = 'seen'
    save_dir = None
    video_save_dir = None
    video_size = None

    # Base task configuration from the per-task YAML.
    with open(f"./task_config/{task_config}.yml", "r", encoding="utf-8") as f:
        args = yaml.load(f.read(), Loader=yaml.FullLoader)

    args['task_name'] = task_name
    args["task_config"] = task_config
    args["ckpt_setting"] = ckpt_setting
    args["save_root"] = save_root

    embodiment_type = args.get("embodiment")
    embodiment_config_path = os.path.join(CONFIGS_PATH, "_embodiment_config.yml")

    with open(embodiment_config_path, "r", encoding="utf-8") as f:
        _embodiment_types = yaml.load(f.read(), Loader=yaml.FullLoader)

    def get_embodiment_file(embodiment_type):
        # Resolve an embodiment name to its robot description path.
        robot_file = _embodiment_types[embodiment_type]["file_path"]
        if robot_file is None:
            # Bug fix: was `raise "No embodiment files"` — raising a string is
            # itself a TypeError at runtime. Use a real exception.
            raise FileNotFoundError("No embodiment files")
        return robot_file

    with open(CONFIGS_PATH + "_camera_config.yml", "r", encoding="utf-8") as f:
        _camera_config = yaml.load(f.read(), Loader=yaml.FullLoader)

    head_camera_type = args["camera"]["head_camera_type"]
    args["head_camera_h"] = _camera_config[head_camera_type]["h"]
    args["head_camera_w"] = _camera_config[head_camera_type]["w"]

    # One embodiment entry -> same robot on both arms; three entries ->
    # left robot, right robot, and the distance between them.
    if len(embodiment_type) == 1:
        args["left_robot_file"] = get_embodiment_file(embodiment_type[0])
        args["right_robot_file"] = get_embodiment_file(embodiment_type[0])
        args["dual_arm_embodied"] = True
    elif len(embodiment_type) == 3:
        args["left_robot_file"] = get_embodiment_file(embodiment_type[0])
        args["right_robot_file"] = get_embodiment_file(embodiment_type[1])
        args["embodiment_dis"] = embodiment_type[2]
        args["dual_arm_embodied"] = False
    else:
        # Bug fix: was `raise "embodiment items should be 1 or 3"` (a string).
        raise ValueError("embodiment items should be 1 or 3")

    args["left_embodiment_config"] = get_embodiment_config(args["left_robot_file"])
    args["right_embodiment_config"] = get_embodiment_config(args["right_robot_file"])

    if len(embodiment_type) == 1:
        embodiment_name = str(embodiment_type[0])
    else:
        embodiment_name = str(embodiment_type[0]) + "+" + str(embodiment_type[1])

    save_dir = Path(f"eval_result/{task_name}/{policy_name}/{task_config}/{ckpt_setting}/{current_time}")
    save_dir.mkdir(parents=True, exist_ok=True)

    if args["eval_video_log"]:
        video_save_dir = save_dir
        camera_config = get_camera_config(args["camera"]["head_camera_type"])
        # ffmpeg-style "WxH" string for the raw-video encoder.
        video_size = str(camera_config["w"]) + "x" + str(camera_config["h"])
        video_save_dir.mkdir(parents=True, exist_ok=True)
        args["eval_video_save_dir"] = video_save_dir

    print("============= Config =============\n")
    print("\033[95mMessy Table:\033[0m " + str(args["domain_randomization"]["cluttered_table"]))
    print("\033[95mRandom Background:\033[0m " + str(args["domain_randomization"]["random_background"]))
    if args["domain_randomization"]["random_background"]:
        print(" - Clean Background Rate: " + str(args["domain_randomization"]["clean_background_rate"]))
    print("\033[95mRandom Light:\033[0m " + str(args["domain_randomization"]["random_light"]))
    if args["domain_randomization"]["random_light"]:
        print(" - Crazy Random Light Rate: " + str(args["domain_randomization"]["crazy_random_light_rate"]))
    print("\033[95mRandom Table Height:\033[0m " + str(args["domain_randomization"]["random_table_height"]))
    print("\033[95mRandom Head Camera Distance:\033[0m " + str(args["domain_randomization"]["random_head_camera_dis"]))

    print("\033[94mHead Camera Config:\033[0m " + str(args["camera"]["head_camera_type"]) + f", " +
          str(args["camera"]["collect_head_camera"]))
    print("\033[94mWrist Camera Config:\033[0m " + str(args["camera"]["wrist_camera_type"]) + f", " +
          str(args["camera"]["collect_wrist_camera"]))
    print("\033[94mEmbodiment Config:\033[0m " + embodiment_name)
    print("\n==================================")

    TASK_ENV = class_decorator(args["task_name"])
    args["policy_name"] = policy_name
    usr_args["left_arm_dim"] = len(args["left_embodiment_config"]["arm_joints_name"][0])
    usr_args["right_arm_dim"] = len(args["right_embodiment_config"]["arm_joints_name"][1])

    seed = usr_args["seed"]

    # Deterministic per-run seed namespace: 10000 seeds per user seed.
    st_seed = 10000 * (1 + seed)
    suc_nums = []
    test_num = usr_args["test_num"]

    model = WebsocketClientPolicy(port=usr_args['port'])

    st_seed, suc_num = eval_policy(task_name,
                                   TASK_ENV,
                                   args,
                                   model,
                                   st_seed,
                                   test_num=test_num,
                                   video_size=video_size,
                                   instruction_type=instruction_type,
                                   save_visualization=True,
                                   video_guidance_scale=video_guidance_scale,
                                   action_guidance_scale=action_guidance_scale)
    suc_nums.append(suc_num)

    file_path = os.path.join(save_dir, f"_result.txt")
    with open(file_path, "w") as file:
        file.write(f"Timestamp: {current_time}\n\n")
        file.write(f"Instruction Type: {instruction_type}\n\n")
        file.write("\n".join(map(str, np.array(suc_nums) / test_num)))

    print(f"Data has been saved to {file_path}")
|
| 418 |
+
|
| 419 |
+
def format_obs(observation, prompt):
    """Flatten a RoboTwin observation dict into the policy's input keys."""
    cameras = observation["observation"]
    return {
        "observation.images.cam_high": cameras["head_camera"]["rgb"],  # H,W,3
        "observation.images.cam_left_wrist": cameras["left_camera"]["rgb"],
        "observation.images.cam_right_wrist": cameras["right_camera"]["rgb"],
        "observation.state": observation["joint_action"]["vector"],
        "task": prompt,
    }
|
| 427 |
+
|
| 428 |
+
def add_eef_pose(new_pose, init_pose):
    """Compose an 8-D eef pose [xyz, quat(xyzw), gripper] with an initial pose.

    Translation is summed component-wise, orientation is the product
    init_rot * new_rot (scipy xyzw quaternions), and the gripper value of
    ``new_pose`` is passed through unchanged.
    """
    delta_rot = R.from_quat(new_pose[3:7][None])
    base_rot = R.from_quat(init_pose[3:7][None])
    composed_quat = (base_rot * delta_rot).as_quat().reshape(-1)
    composed_trans = init_pose[:3] + new_pose[:3]
    return np.concatenate([composed_trans, composed_quat, new_pose[7:8]])
|
| 434 |
+
|
| 435 |
+
def add_init_pose(new_pose, init_pose):
    """Apply the initial-pose offset to both arms of a 16-D dual-arm pose.

    The first 8 values are the left arm's [xyz, quat, gripper], the last 8
    the right arm's; each half is composed independently via add_eef_pose.
    """
    left = add_eef_pose(new_pose[:8], init_pose[:8])
    right = add_eef_pose(new_pose[8:], init_pose[8:])
    return np.concatenate([left, right])
|
| 439 |
+
|
| 440 |
+
def eval_policy(task_name,
                TASK_ENV,
                args,
                model,
                st_seed,
                test_num=100,
                video_size=None,
                instruction_type=None,
                save_visualization=False,
                video_guidance_scale=5.0,
                action_guidance_scale=5.0):
    """Run closed-loop evaluation of a remote policy over `test_num` seeds.

    Each candidate seed is first solved by the scripted expert; only solvable
    seeds count toward `test_num`. For counted seeds the episode is re-created
    and the remote `model` (websocket policy) is queried for action chunks,
    which are executed as end-effector commands until success or step limit.

    Returns:
        (next_seed, success_count): seed counter after the last episode and
        the number of successful policy rollouts (also on TASK_ENV.suc).
    """
    print(f"\033[34mTask Name: {args['task_name']}\033[0m")
    print(f"\033[34mPolicy Name: {args['policy_name']}\033[0m")

    # When True, each seed must pass an expert rollout before being counted.
    expert_check = True
    TASK_ENV.suc = 0
    TASK_ENV.test_num = 0

    now_id = 0              # episode index
    succ_seed = 0           # number of seeds that passed the expert check
    suc_test_seed_list = []

    now_seed = st_seed
    clear_cache_freq = args["clear_cache_freq"]

    args["eval_mode"] = True

    while succ_seed < test_num:
        # Disable rendering during the expert pre-check; restored below.
        render_freq = args["render_freq"]
        args["render_freq"] = 0

        if expert_check:
            try:
                TASK_ENV.setup_demo(now_ep_num=now_id, seed=now_seed, is_test=True, **args)
                episode_info = TASK_ENV.play_once()
                TASK_ENV.close_env()
            except UnStableError as e:
                # Physically unstable initialization: skip to the next seed.
                TASK_ENV.close_env()
                now_seed += 1
                args["render_freq"] = render_freq
                continue
            except Exception as e:
                # Any other expert failure: log it and try the next seed.
                TASK_ENV.close_env()
                now_seed += 1
                args["render_freq"] = render_freq
                print(f"error occurs ! {e}")
                traceback.print_exc()
                continue

        if (not expert_check) or (TASK_ENV.plan_success and TASK_ENV.check_success()):
            succ_seed += 1
            suc_test_seed_list.append(now_seed)
        else:
            # Expert could not solve this seed: do not count it.
            now_seed += 1
            args["render_freq"] = render_freq
            continue

        args["render_freq"] = render_freq

        # Re-create the episode with the validated seed for the policy rollout.
        TASK_ENV.setup_demo(now_ep_num=now_id, seed=now_seed, is_test=True, **args)
        episode_info_list = [episode_info["info"]]
        results = generate_episode_descriptions(args["task_name"], episode_info_list, test_num)
        instruction = np.random.choice(results[0][instruction_type])
        TASK_ENV.set_instruction(instruction=instruction)  # set language instruction

        if TASK_ENV.eval_video_path is not None:
            # Stream raw RGB frames to an ffmpeg subprocess that encodes the
            # per-episode evaluation video.
            ffmpeg = subprocess.Popen(
                [
                    "ffmpeg",
                    "-y",
                    "-loglevel",
                    "error",
                    "-f",
                    "rawvideo",
                    "-pixel_format",
                    "rgb24",
                    "-video_size",
                    video_size,
                    "-framerate",
                    "10",
                    "-i",
                    "-",
                    "-pix_fmt",
                    "yuv420p",
                    "-vcodec",
                    "libx264",
                    "-crf",
                    "23",
                    f"{TASK_ENV.eval_video_path}/episode{TASK_ENV.test_num}.mp4",
                ],
                stdin=subprocess.PIPE,
            )
            TASK_ENV._set_eval_video_ffmpeg(ffmpeg)

        succ = False

        prompt = TASK_ENV.get_instruction()
        # Reset the remote policy's internal state for the new episode.
        ret = model.infer(dict(reset = True, prompt=prompt, save_visualization=save_visualization))

        first = True
        full_obs_list = []         # every key-frame observation (for the video)
        gen_video_list = []        # imagined video chunks returned by the model
        full_action_history = []   # every raw action step

        initial_obs = TASK_ENV.get_obs()
        # 16-D initial eef pose: left xyz+quat+gripper then right xyz+quat+gripper.
        inint_eef_pose = initial_obs['endpose']['left_endpose'] + \
            [initial_obs['endpose']['left_gripper']] + \
            initial_obs['endpose']['right_endpose'] + \
            [initial_obs['endpose']['right_gripper']]
        inint_eef_pose = np.array(inint_eef_pose, dtype=np.float64)
        initial_formatted_obs = format_obs(initial_obs, prompt)
        full_obs_list.append(initial_formatted_obs)
        first_obs = None
        while TASK_ENV.take_action_cnt < TASK_ENV.step_lim:
            if first:
                observation = TASK_ENV.get_obs()
                first_obs = format_obs(observation, prompt)

            # NOTE(review): first_obs is only refreshed on the first chunk;
            # later iterations resend the stale first observation while the
            # server presumably relies on the KV cache fed below — confirm.
            ret = model.infer(dict(obs=first_obs, prompt=prompt, save_visualization=save_visualization, video_guidance_scale=video_guidance_scale, action_guidance_scale=action_guidance_scale)) #(TASK_ENV, model, observation)
            action = ret['action']
            if 'video' in ret:
                imagined_video = ret['video']
                gen_video_list.append(imagined_video)
            key_frame_list = []

            # Action steps come in groups of 4 frames per chunk.
            assert action.shape[2] % 4 == 0
            action_per_frame = action.shape[2] // 4

            # On the first chunk, frame 0 corresponds to the current state.
            start_idx = 1 if first else 0
            for i in range(start_idx, action.shape[1]):
                for j in range(action.shape[2]):
                    raw_action_step = action[:, i, j].flatten()
                    full_action_history.append(raw_action_step)

                    ee_action = action[:, i, j]
                    if action.shape[0] == 14:
                        # 14-D: xyz + euler + gripper per arm; convert the
                        # euler angles to quaternions for the controller.
                        ee_action = np.concatenate([
                            ee_action[:3],
                            euler2quat(ee_action[3], ee_action[4], ee_action[5]),
                            ee_action[6:10],
                            euler2quat(ee_action[10], ee_action[11], ee_action[12]),
                            ee_action[13:14]
                        ])
                    elif action.shape[0] == 16:
                        # 16-D: deltas relative to the initial eef pose;
                        # compose with it and re-normalize the quaternions.
                        ee_action = add_init_pose(ee_action, inint_eef_pose)
                        ee_action = np.concatenate([
                            ee_action[:3],
                            ee_action[3:7] / np.linalg.norm(ee_action[3:7]),
                            ee_action[7:11],
                            ee_action[11:15] / np.linalg.norm(ee_action[11:15]),
                            ee_action[15:16]
                        ])
                    else:
                        raise NotImplementedError
                    TASK_ENV.take_action(ee_action, action_type='ee')

                    if (j + 1) % action_per_frame == 0:
                        # Record a key frame after each frame's worth of steps.
                        obs = format_obs(TASK_ENV.get_obs(), prompt)
                        full_obs_list.append(obs)
                        key_frame_list.append(obs)

            first = False

            # Feed executed key frames back so the server extends its KV cache.
            model.infer(dict(obs = key_frame_list, compute_kv_cache=True, imagine=False, save_visualization=save_visualization, state=action))

            if TASK_ENV.eval_success:
                succ = True
                break

        # Save the real-vs-imagined comparison video for this episode.
        vis_dir = Path(args['save_root']) / f'stseed-{st_seed}' / 'visualization' / task_name
        vis_dir.mkdir(parents=True, exist_ok=True)
        video_name = f"{TASK_ENV.test_num}_{prompt.replace(' ', '_')}_{succ}.mp4"
        out_img_file = vis_dir / video_name
        save_comparison_video(
            real_obs_list=full_obs_list,
            imagined_video=None, #gen_video_list,
            action_history=full_action_history,
            save_path=str(out_img_file),
            fps=15 # Suggest adjusting fps based on simulation step
        )
        if TASK_ENV.eval_video_path is not None:
            TASK_ENV._del_eval_video_ffmpeg()

        if succ:
            TASK_ENV.suc += 1
            print("\033[92mSuccess!\033[0m")
        else:
            print("\033[91mFail!\033[0m")

        now_id += 1
        TASK_ENV.close_env(clear_cache=((succ_seed + 1) % clear_cache_freq == 0))

        if TASK_ENV.render_freq:
            TASK_ENV.viewer.close()

        TASK_ENV.test_num += 1

        # Persist running metrics after every episode.
        save_dir = Path(args['save_root']) / f'stseed-{st_seed}' / 'metrics' / task_name
        save_dir.mkdir(parents=True, exist_ok=True)
        out_json_file = save_dir / 'res.json'
        write_json({
            "succ_num": float(TASK_ENV.suc),
            "total_num": float(TASK_ENV.test_num),
            "succ_rate": float(TASK_ENV.suc / TASK_ENV.test_num),
        }, out_json_file)

        print(
            f"\033[93m{task_name}\033[0m | \033[94m{args['policy_name']}\033[0m | \033[92m{args['task_config']}\033[0m | \033[91m{args['ckpt_setting']}\033[0m\n"
            f"Success rate: \033[96m{TASK_ENV.suc}/{TASK_ENV.test_num}\033[0m => \033[95m{round(TASK_ENV.suc/TASK_ENV.test_num*100, 1)}%\033[0m, current seed: \033[90m{now_seed}\033[0m\n"
        )
        now_seed += 1

    return now_seed, TASK_ENV.suc
|
| 655 |
+
|
| 656 |
+
|
| 657 |
+
def parse_args_and_config():
    """Parse CLI arguments, load the YAML config, and apply ``--key value`` overrides.

    Returns:
        dict: the config from --config with any --overrides pairs merged in.
    """
    import ast

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--overrides", nargs=argparse.REMAINDER)
    parser.add_argument("--port", type=int, default=8000, help='remote policy socket port.')
    parser.add_argument("--save_root", type=str, default="results/default_vis_path")
    parser.add_argument("--video_guidance_scale", type=float, default=5.0)
    parser.add_argument("--action_guidance_scale", type=float, default=5.0)
    parser.add_argument("--test_num", type=int, default=100)
    args = parser.parse_args()

    # NOTE(review): the dedicated flags above (--port, --save_root, ...) are
    # parsed but never merged into the returned config; their values must come
    # from the YAML file or from --overrides — confirm this is intended.
    with open(args.config, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # Parse overrides
    def parse_override_pairs(pairs):
        # Interpret ["--key", "value", ...] pairs; values that parse as Python
        # literals (ints, floats, bools, lists, ...) are converted, anything
        # else stays a string.
        override_dict = {}
        for i in range(0, len(pairs), 2):
            key = pairs[i].lstrip("-")
            value = pairs[i + 1]
            try:
                # Security fix: was eval(), which executes arbitrary code from
                # the command line; literal_eval accepts only Python literals.
                value = ast.literal_eval(value)
            except (ValueError, SyntaxError):
                pass
            override_dict[key] = value
        return override_dict

    if args.overrides:
        overrides = parse_override_pairs(args.overrides)
        config.update(overrides)

    return config
|
| 689 |
+
|
| 690 |
+
|
| 691 |
+
if __name__ == "__main__":

    # Sanity-check the simulation/rendering setup before evaluation
    # (Sapien_TEST is defined elsewhere in the project — presumably a SAPIEN
    # render smoke test; confirm).
    Sapien_TEST()
    usr_args = parse_args_and_config()
    main(usr_args)
|
| 696 |
+
|
evaluation/robotwin/geometry.py
ADDED
|
@@ -0,0 +1,463 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Mostly copied from transforms3d library
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
# Machine epsilon for float64; used to detect degenerate (near-zero) quaternions.
_FLOAT_EPS = np.finfo(np.float64).eps

# axis sequences for Euler angles
# _NEXT_AXIS[i + parity] is the axis following axis i for the chosen parity;
# the trailing duplicate entry makes the wrap-around indexing work.
_NEXT_AXIS = [1, 2, 0, 1]

# map axes strings to/from tuples of inner axis, parity, repetition, frame
# Key format: 's'tatic or 'r'otating frame + three axis letters (e.g. "sxyz");
# value: (first axis index, parity, repetition flag, frame flag).
_AXES2TUPLE = {
    "sxyz": (0, 0, 0, 0),
    "sxyx": (0, 0, 1, 0),
    "sxzy": (0, 1, 0, 0),
    "sxzx": (0, 1, 1, 0),
    "syzx": (1, 0, 0, 0),
    "syzy": (1, 0, 1, 0),
    "syxz": (1, 1, 0, 0),
    "syxy": (1, 1, 1, 0),
    "szxy": (2, 0, 0, 0),
    "szxz": (2, 0, 1, 0),
    "szyx": (2, 1, 0, 0),
    "szyz": (2, 1, 1, 0),
    "rzyx": (0, 0, 0, 1),
    "rxyx": (0, 0, 1, 1),
    "ryzx": (0, 1, 0, 1),
    "rxzx": (0, 1, 1, 1),
    "rxzy": (1, 0, 0, 1),
    "ryzy": (1, 0, 1, 1),
    "rzxy": (1, 1, 0, 1),
    "ryxy": (1, 1, 1, 1),
    "ryxz": (2, 0, 0, 1),
    "rzxz": (2, 0, 1, 1),
    "rxyz": (2, 1, 0, 1),
    "rzyz": (2, 1, 1, 1),
}

# Reverse lookup: encoded tuple -> axes string (also used to validate tuples).
_TUPLE2AXES = dict((v, k) for k, v in _AXES2TUPLE.items())

# For testing whether a number is close to zero
_EPS4 = np.finfo(float).eps * 4.0
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def mat2euler(mat, axes="sxyz"):
    """Return Euler angles from rotation matrix for specified axis sequence.

    Note that many Euler angle triplets can describe one matrix.

    Parameters
    ----------
    mat : array-like shape (3, 3) or (4, 4)
        Rotation matrix or affine.
    axes : str, optional
        Axis specification; one of 24 axis sequences as string or encoded
        tuple - e.g. ``sxyz`` (the default).

    Returns
    -------
    ai : float
        First rotation angle (according to `axes`).
    aj : float
        Second rotation angle (according to `axes`).
    ak : float
        Third rotation angle (according to `axes`).

    Examples
    --------
    >>> R0 = euler2mat(1, 2, 3, 'syxz')
    >>> al, be, ga = mat2euler(R0, 'syxz')
    >>> R1 = euler2mat(al, be, ga, 'syxz')
    >>> np.allclose(R0, R1)
    True
    """
    try:
        firstaxis, parity, repetition, frame = _AXES2TUPLE[axes.lower()]
    except (AttributeError, KeyError):
        _TUPLE2AXES[axes]  # validation
        firstaxis, parity, repetition, frame = axes

    i = firstaxis
    j = _NEXT_AXIS[i + parity]
    k = _NEXT_AXIS[i - parity + 1]

    # Bug fix: np.array(..., copy=False) raises ValueError under NumPy >= 2.0
    # whenever a copy is required (e.g. a nested-list input); np.asarray keeps
    # the original "copy only if needed" semantics on all NumPy versions.
    M = np.asarray(mat, dtype=np.float64)[:3, :3]
    if repetition:
        sy = math.sqrt(M[i, j] * M[i, j] + M[i, k] * M[i, k])
        if sy > _EPS4:
            ax = math.atan2(M[i, j], M[i, k])
            ay = math.atan2(sy, M[i, i])
            az = math.atan2(M[j, i], -M[k, i])
        else:
            # Gimbal-locked case: the third angle is not uniquely determined.
            ax = math.atan2(-M[j, k], M[j, j])
            ay = math.atan2(sy, M[i, i])
            az = 0.0
    else:
        cy = math.sqrt(M[i, i] * M[i, i] + M[j, i] * M[j, i])
        if cy > _EPS4:
            ax = math.atan2(M[k, j], M[k, k])
            ay = math.atan2(-M[k, i], cy)
            az = math.atan2(M[j, i], M[i, i])
        else:
            # Gimbal-locked case: the third angle is not uniquely determined.
            ax = math.atan2(-M[j, k], M[j, j])
            ay = math.atan2(-M[k, i], cy)
            az = 0.0

    if parity:
        ax, ay, az = -ax, -ay, -az
    if frame:
        ax, az = az, ax
    return ax, ay, az
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def quat2mat(q):
    """Calculate the rotation matrix for a quaternion.

    Parameters
    ----------
    q : 4 element array-like
        Quaternion in (w, x, y, z) order; need not be normalized.

    Returns
    -------
    M : (3,3) array
        Rotation matrix corresponding to *q*. A (near-)zero quaternion maps
        to the identity.

    Examples
    --------
    >>> import numpy as np
    >>> np.allclose(quat2mat([1, 0, 0, 0]), np.eye(3))
    True
    >>> np.allclose(quat2mat([0, 1, 0, 0]), np.diag([1, -1, -1]))
    True
    """
    w, x, y, z = q
    norm_sq = w * w + x * x + y * y + z * z
    if norm_sq < np.finfo(np.float64).eps:
        # Degenerate quaternion: fall back to the identity rotation.
        return np.eye(3)
    # Scale folds the normalization into the product terms, so unnormalized
    # quaternions are handled without an explicit square root.
    s = 2.0 / norm_sq
    wx, wy, wz = s * w * x, s * w * y, s * w * z
    xx, xy, xz = s * x * x, s * x * y, s * x * z
    yy, yz = s * y * y, s * y * z
    zz = s * z * z
    return np.array(
        [
            [1.0 - (yy + zz), xy - wz, xz + wy],
            [xy + wz, 1.0 - (xx + zz), yz - wx],
            [xz - wy, yz + wx, 1.0 - (xx + yy)],
        ]
    )
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
# Checks if a matrix is a valid rotation matrix.
|
| 177 |
+
# Checks if a matrix is a valid rotation matrix.
def isrotation(
    R: np.ndarray,
    thresh=1e-6,
) -> bool:
    """Return True when R is orthonormal, i.e. ``R^T @ R`` is the identity
    to within Frobenius norm `thresh`."""
    deviation = np.identity(3, dtype=R.dtype) - np.dot(R.T, R)
    return np.linalg.norm(deviation) < thresh
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def euler2mat(ai, aj, ak, axes="sxyz"):
    """Return rotation matrix from Euler angles and axis sequence.

    Parameters
    ----------
    ai : float
        First rotation angle (according to `axes`).
    aj : float
        Second rotation angle (according to `axes`).
    ak : float
        Third rotation angle (according to `axes`).
    axes : str, optional
        Axis specification; one of 24 axis sequences as string or encoded
        tuple - e.g. ``sxyz`` (the default).

    Returns
    -------
    mat : array (3, 3)
        Rotation matrix.

    Examples
    --------
    >>> R = euler2mat(1, 2, 3, 'syxz')
    >>> np.allclose(np.sum(R[0]), -1.34786452)
    True
    >>> R = euler2mat(1, 2, 3, (0, 1, 0, 1))
    >>> np.allclose(np.sum(R[0]), -0.383436184)
    True
    """
    try:
        firstaxis, parity, repetition, frame = _AXES2TUPLE[axes]
    except (AttributeError, KeyError):
        _TUPLE2AXES[axes]  # validation
        firstaxis, parity, repetition, frame = axes

    # Resolve the three matrix axes from the encoded specification.
    ax_i = firstaxis
    ax_j = _NEXT_AXIS[ax_i + parity]
    ax_k = _NEXT_AXIS[ax_i - parity + 1]

    if frame:
        ai, ak = ak, ai
    if parity:
        ai, aj, ak = -ai, -aj, -ak

    sin_i, sin_j, sin_k = math.sin(ai), math.sin(aj), math.sin(ak)
    cos_i, cos_j, cos_k = math.cos(ai), math.cos(aj), math.cos(ak)
    # Precomputed products of the first/third-angle trig terms.
    cc = cos_i * cos_k
    cs = cos_i * sin_k
    sc = sin_i * cos_k
    ss = sin_i * sin_k

    M = np.eye(3)
    if repetition:
        M[ax_i, ax_i] = cos_j
        M[ax_i, ax_j] = sin_j * sin_i
        M[ax_i, ax_k] = sin_j * cos_i
        M[ax_j, ax_i] = sin_j * sin_k
        M[ax_j, ax_j] = -cos_j * ss + cc
        M[ax_j, ax_k] = -cos_j * cs - sc
        M[ax_k, ax_i] = -sin_j * cos_k
        M[ax_k, ax_j] = cos_j * sc + cs
        M[ax_k, ax_k] = cos_j * cc - ss
    else:
        M[ax_i, ax_i] = cos_j * cos_k
        M[ax_i, ax_j] = sin_j * sc - cs
        M[ax_i, ax_k] = sin_j * cc + ss
        M[ax_j, ax_i] = cos_j * sin_k
        M[ax_j, ax_j] = sin_j * ss + cc
        M[ax_j, ax_k] = sin_j * cs - sc
        M[ax_k, ax_i] = -sin_j
        M[ax_k, ax_j] = cos_j * sin_i
        M[ax_k, ax_k] = cos_j * cos_i
    return M
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def euler2axangle(ai, aj, ak, axes="sxyz"):
    """Return angle, axis corresponding to Euler angles, axis specification

    Parameters
    ----------
    ai : float
        First rotation angle (according to `axes`).
    aj : float
        Second rotation angle (according to `axes`).
    ak : float
        Third rotation angle (according to `axes`).
    axes : str, optional
        Axis specification; one of 24 axis sequences as string or encoded
        tuple - e.g. ``sxyz`` (the default).

    Returns
    -------
    vector : array shape (3,)
        axis around which rotation occurs
    theta : scalar
        angle of rotation

    Examples
    --------
    >>> vec, theta = euler2axangle(0, 1.5, 0, 'szyx')
    >>> np.allclose(vec, [0, 1, 0])
    True
    >>> theta
    1.5
    """
    # Go through the quaternion representation: Euler -> quaternion,
    # then quaternion -> (axis, angle).
    quaternion = euler2quat(ai, aj, ak, axes)
    return quat2axangle(quaternion)
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def euler2quat(ai, aj, ak, axes="sxyz"):
    """Return `quaternion` from Euler angles and axis sequence `axes`

    Parameters
    ----------
    ai : float
        First rotation angle (according to `axes`).
    aj : float
        Second rotation angle (according to `axes`).
    ak : float
        Third rotation angle (according to `axes`).
    axes : str, optional
        Axis specification; one of 24 axis sequences as string or encoded
        tuple - e.g. ``sxyz`` (the default).

    Returns
    -------
    quat : array shape (4,)
        Quaternion in w, x, y z (real, then vector) format

    Examples
    --------
    >>> q = euler2quat(1, 2, 3, 'ryxz')
    >>> np.allclose(q, [0.435953, 0.310622, -0.718287, 0.444435])
    True
    """
    # Decode the axis sequence; a string key is looked up case-insensitively,
    # while an already-encoded 4-tuple is validated against _TUPLE2AXES.
    try:
        firstaxis, parity, repetition, frame = _AXES2TUPLE[axes.lower()]
    except (AttributeError, KeyError):
        _TUPLE2AXES[axes]  # validation
        firstaxis, parity, repetition, frame = axes

    # Axis indices are shifted by +1 here because q[0] holds the scalar (w)
    # component and the vector part lives in q[1:4].
    i = firstaxis + 1
    j = _NEXT_AXIS[i + parity - 1] + 1
    k = _NEXT_AXIS[i - parity] + 1

    # Rotating frame: swap first and last angles.
    if frame:
        ai, ak = ak, ai
    # Odd parity flips the sign of the middle angle.
    if parity:
        aj = -aj

    # Quaternions encode half angles.
    ai = ai / 2.0
    aj = aj / 2.0
    ak = ak / 2.0
    ci = math.cos(ai)
    si = math.sin(ai)
    cj = math.cos(aj)
    sj = math.sin(aj)
    ck = math.cos(ak)
    sk = math.sin(ak)
    # Products reused by the branch below.
    cc = ci * ck
    cs = ci * sk
    sc = si * ck
    ss = si * sk

    q = np.empty((4,))
    if repetition:
        # Sequences that repeat the first axis (e.g. 'zxz').
        q[0] = cj * (cc - ss)
        q[i] = cj * (cs + sc)
        q[j] = sj * (cc + ss)
        q[k] = sj * (cs - sc)
    else:
        q[0] = cj * cc + sj * ss
        q[i] = cj * sc - sj * cs
        q[j] = cj * ss + sj * cc
        q[k] = cj * cs - sj * sc
    if parity:
        q[j] *= -1.0

    return q
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
def quat2axangle(quat, identity_thresh=None):
    """Convert quaternion to rotation of angle around axis

    Parameters
    ----------
    quat : 4 element sequence
        w, x, y, z forming quaternion.
    identity_thresh : None or scalar, optional
        Threshold below which the norm of the vector part of the quaternion (x,
        y, z) is deemed to be 0, leading to the identity rotation. None (the
        default) leads to a threshold estimated based on the precision of the
        input.

    Returns
    -------
    theta : scalar
        angle of rotation.
    vector : array shape (3,)
        axis around which rotation occurs.

    Examples
    --------
    >>> vec, theta = quat2axangle([0, 1, 0, 0])
    >>> vec
    array([1., 0., 0.])
    >>> np.allclose(theta, np.pi)
    True

    If this is an identity rotation, we return a zero angle and an arbitrary
    vector:

    >>> quat2axangle([1, 0, 0, 0])
    (array([1., 0., 0.]), 0.0)

    If any of the quaternion values are not finite, we return a NaN in the
    angle, and an arbitrary vector:

    >>> quat2axangle([1, np.inf, 0, 0])
    (array([1., 0., 0.]), nan)

    Notes
    -----
    A quaternion for which x, y, z are all equal to 0, is an identity rotation.
    In this case we return a 0 angle and an arbitrary vector, here [1, 0, 0].

    The algorithm allows for quaternions that have not been normalized.
    """
    quat = np.asarray(quat)
    Nq = np.sum(quat**2)
    if not np.isfinite(Nq):
        return np.array([1.0, 0, 0]), float("nan")
    if identity_thresh is None:
        try:
            # BUG FIX: numpy scalars expose ``.dtype``, not ``.type``; the old
            # ``np.finfo(Nq.type)`` always raised AttributeError, so the
            # precision-based threshold promised in the docstring was never
            # actually used. ``np.finfo`` accepts a dtype directly.
            identity_thresh = np.finfo(Nq.dtype).eps * 3
        except (AttributeError, ValueError):  # Not a numpy type or not float
            identity_thresh = _FLOAT_EPS * 3
    if Nq < _FLOAT_EPS**2:  # Results unreliable after normalization
        return np.array([1.0, 0, 0]), 0.0
    if Nq != 1:  # Normalize if not normalized
        s = math.sqrt(Nq)
        quat = quat / s
    xyz = quat[1:]
    len2 = np.sum(xyz**2)
    if len2 < identity_thresh**2:
        # if vec is nearly 0,0,0, this is an identity rotation
        return np.array([1.0, 0, 0]), 0.0
    # Make sure w is not slightly above 1 or below -1
    theta = 2 * math.acos(max(min(quat[0], 1), -1))
    return xyz / math.sqrt(len2), theta
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
def quat2euler(quaternion, axes="sxyz"):
    """Euler angles from `quaternion` for specified axis sequence `axes`

    Parameters
    ----------
    q : 4 element sequence
        w, x, y, z of quaternion
    axes : str, optional
        Axis specification; one of 24 axis sequences as string or encoded
        tuple - e.g. ``sxyz`` (the default).

    Returns
    -------
    ai : float
        First rotation angle (according to `axes`).
    aj : float
        Second rotation angle (according to `axes`).
    ak : float
        Third rotation angle (according to `axes`).

    Examples
    --------
    >>> angles = quat2euler([0.99810947, 0.06146124, 0, 0])
    >>> np.allclose(angles, [0.123, 0, 0])
    True
    """
    # Convert to a rotation matrix first, then extract the Euler angles
    # for the requested axis sequence.
    rotation_matrix = quat2mat(quaternion)
    return mat2euler(rotation_matrix, axes)
|
evaluation/robotwin/launch_client.sh
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Launch a single RoboTwin evaluation client for one task against an
# already-running policy server (see launch_server.sh).
export LD_LIBRARY_PATH=/usr/lib64:/usr/lib:$LD_LIBRARY_PATH

# Task groups used by the multi-GPU launcher; kept here for reference only.
# This script evaluates only the single task set in ``task_name`` below.
task_groups=(
"stack_bowls_three handover_block hanging_mug scan_object lift_pot put_object_cabinet stack_blocks_three place_shoe"
"adjust_bottle place_mouse_pad dump_bin_bigbin move_pillbottle_pad pick_dual_bottles shake_bottle place_fan turn_switch"
"shake_bottle_horizontally place_container_plate rotate_qrcode place_object_stand put_bottles_dustbin move_stapler_pad place_burger_fries place_bread_basket"
"pick_diverse_bottles open_microwave beat_block_hammer press_stapler click_bell move_playingcard_away open_laptop move_can_pot"
"stack_bowls_two place_a2b_right stamp_seal place_object_basket handover_mic place_bread_skillet stack_blocks_two place_cans_plasticbox"
"click_alarmclock blocks_ranking_size place_phone_stand place_can_basket place_object_scale place_a2b_left grab_roller place_dual_shoes"
"place_empty_cup blocks_ranking_rgb place_empty_cup blocks_ranking_rgb place_empty_cup blocks_ranking_rgb place_empty_cup blocks_ranking_rgb"
)

save_root='./results'
task_name="adjust_bottle"

# Evaluation parameters forwarded to the client entry point.
policy_name=ACT
task_config=demo_clean
train_config_name=0
model_name=0
seed=0
PORT=29056  # must match the server's START_PORT

PYTHONWARNINGS=ignore::UserWarning \
XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 python -m evaluation.robotwin.eval_polict_client_openpi --config policy/$policy_name/deploy_policy.yml \
    --overrides \
    --task_name ${task_name} \
    --task_config ${task_config} \
    --train_config_name ${train_config_name} \
    --model_name ${model_name} \
    --ckpt_setting ${model_name} \
    --seed ${seed} \
    --policy_name ${policy_name} \
    --save_root ${save_root} \
    --video_guidance_scale 5 \
    --action_guidance_scale 1 \
    --test_num 100 \
    --port ${PORT}

evaluation/robotwin/launch_client_multigpus.sh
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Launch one evaluation client per task in the selected task group, spreading
# them across GPUs round-robin and giving each a consecutive server port.
#
# Usage: launch_client_multigpus.sh [save_root] [task_list_id] [seed] [test_num]
export LD_LIBRARY_PATH=/usr/lib64:/usr/lib:$LD_LIBRARY_PATH


save_root=${1:-'./results'}

# General parameters
policy_name=ACT
task_config=demo_clean
train_config_name=0
model_name=0
seed=${3:-0}
test_num=${4:-100}
start_port=29556  # must line up with launch_server_multigpus.sh START_PORT
num_gpus=8

task_list_id=${2:-0}

# Each entry is a space-separated list of 8 task names (one per GPU/port).
task_groups=(
"stack_bowls_three handover_block hanging_mug scan_object lift_pot put_object_cabinet stack_blocks_three place_shoe"
"adjust_bottle place_mouse_pad dump_bin_bigbin move_pillbottle_pad pick_dual_bottles shake_bottle place_fan turn_switch"
"shake_bottle_horizontally place_container_plate rotate_qrcode place_object_stand put_bottles_dustbin move_stapler_pad place_burger_fries place_bread_basket"
"pick_diverse_bottles open_microwave beat_block_hammer press_stapler click_bell move_playingcard_away open_laptop move_can_pot"
"stack_bowls_two place_a2b_right stamp_seal place_object_basket handover_mic place_bread_skillet stack_blocks_two place_cans_plasticbox"
"click_alarmclock blocks_ranking_size place_phone_stand place_can_basket place_object_scale place_a2b_left grab_roller place_dual_shoes"
"place_empty_cup blocks_ranking_rgb place_empty_cup blocks_ranking_rgb place_empty_cup blocks_ranking_rgb place_empty_cup blocks_ranking_rgb"
)

if (( task_list_id < 0 || task_list_id >= ${#task_groups[@]} )); then
    echo "task_list_id out of range: $task_list_id (0..$(( ${#task_groups[@]} - 1 )))" >&2
    exit 1
fi

# Split the selected group string into an array of task names.
read -r -a task_names <<< "${task_groups[$task_list_id]}"

echo "task_list_id=$task_list_id"
printf 'task_names (%d): %s\n' "${#task_names[@]}" "${task_names[*]}"

log_dir="./logs"
mkdir -p "$log_dir"

echo -e "\033[32mLaunching ${#task_names[@]} tasks. GPUs assigned by mod ${num_gpus}, ports starting from ${start_port} incrementing.\033[0m"

# PIDs of all launched clients are recorded here so the batch can be killed.
pid_file="pids.txt"
> "$pid_file"

batch_time=$(date +%Y%m%d_%H%M%S)

for i in "${!task_names[@]}"; do
    task_name="${task_names[$i]}"
    gpu_id=$(( i % num_gpus ))
    port=$(( start_port + i ))

    export CUDA_VISIBLE_DEVICES=${gpu_id}

    log_file="${log_dir}/${task_name}_${batch_time}.log"

    echo -e "\033[33m[Task $i] Task: ${task_name}, GPU: ${gpu_id}, PORT: ${port}, Log: ${log_file}\033[0m"

    # Run the client in the background; all output goes to its log file.
    PYTHONWARNINGS=ignore::UserWarning \
    XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 python -m evaluation.robotwin.eval_polict_client_openpi --config policy/$policy_name/deploy_policy.yml \
        --overrides \
        --task_name ${task_name} \
        --task_config ${task_config} \
        --train_config_name ${train_config_name} \
        --model_name ${model_name} \
        --ckpt_setting ${model_name} \
        --seed ${seed} \
        --policy_name ${policy_name} \
        --save_root ${save_root} \
        --video_guidance_scale 5 \
        --action_guidance_scale 1 \
        --test_num ${test_num} \
        --port ${port} > "$log_file" 2>&1 &

    pid=$!
    echo "${pid}" | tee -a "$pid_file"
done

echo -e "\033[32mAll tasks launched. PIDs saved to ${pid_file}\033[0m"
echo -e "\033[36mTo terminate all processes, run: kill \$(cat ${pid_file})\033[0m"
|
evaluation/robotwin/launch_server.sh
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Launch a single wan_va policy server instance.
#   START_PORT  - websocket port the policy server listens on (client connects here)
#   MASTER_PORT - rendezvous port for torch.distributed.run
START_PORT=${START_PORT:-29056}
MASTER_PORT=${MASTER_PORT:-29061}

# Server-side visualizations are written here.
save_root='visualization/'
mkdir -p "$save_root"

python -m torch.distributed.run \
    --nproc_per_node 1 \
    --master_port "$MASTER_PORT" \
    wan_va/wan_va_server.py \
    --config-name robotwin \
    --port "$START_PORT" \
    --save_root "$save_root"

evaluation/robotwin/launch_server_multigpus.sh
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Launch 8 wan_va policy server instances, one per GPU, on consecutive
# websocket/master ports, logging each to its own file.
START_PORT=${START_PORT:-29556}
MASTER_PORT=${MASTER_PORT:-29661}
LOG_DIR='./logs'
mkdir -p $LOG_DIR

save_root='./visualization/'
mkdir -p $save_root

batch_time=$(date +%Y%m%d_%H%M%S)


for i in {0..7}; do
    CURRENT_PORT=$((START_PORT + i))
    CURRENT_MASTER_PORT=$((MASTER_PORT + i))

    LOG_FILE="${LOG_DIR}/server_${i}_${batch_time}.log"
    # BUG FIX: the status line referenced undefined ${j}; the loop variable is ${i}.
    echo "[Task ${i}] GPU: ${i} | PORT: ${CURRENT_PORT} | MASTER_PORT: ${CURRENT_MASTER_PORT} | Log: ${LOG_FILE}"

    # Pin each server to one GPU and run it in the background.
    CUDA_VISIBLE_DEVICES=$i \
    nohup python -m torch.distributed.run \
        --nproc_per_node 1 \
        --master_port $CURRENT_MASTER_PORT \
        wan_va/wan_va_server.py \
        --config-name robotwin \
        --save_root $save_root \
        --port $CURRENT_PORT > $LOG_FILE 2>&1 &
    sleep 2;
done

echo "All 8 instances have been launched in the background."
wait
|
evaluation/robotwin/msgpack_numpy.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Adds NumPy array support to msgpack.
|
| 2 |
+
|
| 3 |
+
msgpack is good for (de)serializing data over a network for multiple reasons:
|
| 4 |
+
- msgpack is secure (as opposed to pickle/dill/etc which allow for arbitrary code execution)
|
| 5 |
+
- msgpack is widely used and has good cross-language support
|
| 6 |
+
- msgpack does not require a schema (as opposed to protobuf/flatbuffers/etc) which is convenient in dynamically typed
|
| 7 |
+
languages like Python and JavaScript
|
| 8 |
+
- msgpack is fast and efficient (as opposed to readable formats like JSON/YAML/etc); I found that msgpack was ~4x faster
|
| 9 |
+
than pickle for serializing large arrays using the below strategy
|
| 10 |
+
|
| 11 |
+
The code below is adapted from https://github.com/lebedov/msgpack-numpy. The reason not to use that library directly is
|
| 12 |
+
that it falls back to pickle for object arrays.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import functools
|
| 16 |
+
|
| 17 |
+
import msgpack
|
| 18 |
+
import numpy as np
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def pack_array(obj):
    """msgpack ``default`` hook: encode NumPy arrays and scalars as plain dicts.

    Arrays become ``{b"__ndarray__": True, b"data", b"dtype", b"shape"}``;
    scalars become ``{b"__npgeneric__": True, b"data", b"dtype"}``. Anything
    that is not a NumPy object is returned unchanged so msgpack serializes it
    natively.

    Raises
    ------
    ValueError
        For void ("V"), object ("O"), or complex ("c") dtypes, which have no
        portable byte representation here.
    """
    is_numpy = isinstance(obj, (np.ndarray, np.generic))
    if is_numpy and obj.dtype.kind in ("V", "O", "c"):
        raise ValueError(f"Unsupported dtype: {obj.dtype}")

    if isinstance(obj, np.ndarray):
        encoded = {
            b"__ndarray__": True,
            b"data": obj.tobytes(),
            b"dtype": obj.dtype.str,
            b"shape": obj.shape,
        }
        return encoded

    if isinstance(obj, np.generic):
        return {
            b"__npgeneric__": True,
            b"data": obj.item(),
            b"dtype": obj.dtype.str,
        }

    return obj
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def unpack_array(obj):
    """msgpack ``object_hook``: decode dicts produced by :func:`pack_array`.

    Dicts without the sentinel keys pass through untouched.
    """
    if b"__ndarray__" in obj:
        dtype = np.dtype(obj[b"dtype"])
        return np.ndarray(buffer=obj[b"data"], dtype=dtype, shape=obj[b"shape"])

    if b"__npgeneric__" in obj:
        scalar_type = np.dtype(obj[b"dtype"]).type
        return scalar_type(obj[b"data"])

    return obj
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# Drop-in replacements for msgpack's (un)packers with NumPy support wired in
# via the hooks above.
Packer = functools.partial(msgpack.Packer, default=pack_array)
packb = functools.partial(msgpack.packb, default=pack_array)

Unpacker = functools.partial(msgpack.Unpacker, object_hook=unpack_array)
unpackb = functools.partial(msgpack.unpackb, object_hook=unpack_array)
|
evaluation/robotwin/test_render.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import warnings
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
warnings.simplefilter(action="ignore", category=FutureWarning)
|
| 6 |
+
warnings.simplefilter(action="ignore", category=UserWarning)
|
| 7 |
+
current_file_path = os.path.abspath(__file__)
|
| 8 |
+
parent_dir = os.path.dirname(current_file_path)
|
| 9 |
+
|
| 10 |
+
sys.path.append(os.path.join(parent_dir, "../../tools"))
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pdb
|
| 13 |
+
import json
|
| 14 |
+
import torch
|
| 15 |
+
import sapien.core as sapien
|
| 16 |
+
from sapien.utils.viewer import Viewer
|
| 17 |
+
import gymnasium as gym
|
| 18 |
+
import toppra as ta
|
| 19 |
+
import transforms3d as t3d
|
| 20 |
+
from collections import OrderedDict
|
| 21 |
+
|
| 22 |
+
import sys
|
| 23 |
+
import warnings
|
| 24 |
+
import os
|
| 25 |
+
|
| 26 |
+
warnings.simplefilter(action="ignore", category=FutureWarning)
|
| 27 |
+
warnings.simplefilter(action="ignore", category=UserWarning)
|
| 28 |
+
current_file_path = os.path.abspath(__file__)
|
| 29 |
+
parent_dir = os.path.dirname(current_file_path)
|
| 30 |
+
|
| 31 |
+
sys.path.append(os.path.join(parent_dir, "../../tools"))
|
| 32 |
+
import numpy as np
|
| 33 |
+
import pdb
|
| 34 |
+
import json
|
| 35 |
+
import torch
|
| 36 |
+
import sapien.core as sapien
|
| 37 |
+
from sapien.utils.viewer import Viewer
|
| 38 |
+
import gymnasium as gym
|
| 39 |
+
import toppra as ta
|
| 40 |
+
import transforms3d as t3d
|
| 41 |
+
from collections import OrderedDict
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class Sapien_TEST(gym.Env):
    """Smoke test that SAPIEN's ray-tracing renderer can be initialized.

    Instantiating the class immediately tries to build a renderer + scene and
    prints "Render Well" on success or "Render Error" (and exits) on failure.
    """

    def __init__(self):
        super().__init__()
        ta.setup_logging("CRITICAL")  # hide logging
        try:
            self.setup_scene()
            print("\033[32m" + "Render Well" + "\033[0m")
        except Exception as exc:
            # Was a bare ``except:``, which also swallowed SystemExit and
            # KeyboardInterrupt and hid the failure cause entirely; report
            # the exception before exiting so failures are diagnosable.
            print("\033[31m" + "Render Error" + "\033[0m")
            print(f"\033[31m{exc!r}\033[0m")
            exit()

    def setup_scene(self, **kwargs):
        """
        Set the scene
        - Set up the basic scene: light source, viewer.
        """
        self.engine = sapien.Engine()
        # declare sapien renderer
        from sapien.render import set_global_config

        set_global_config(max_num_materials=50000, max_num_textures=50000)
        self.renderer = sapien.SapienRenderer()
        # give renderer to sapien sim
        self.engine.set_renderer(self.renderer)

        # Ray-tracing shader with OIDN denoiser; this is the path being smoke-tested.
        sapien.render.set_camera_shader_dir("rt")
        sapien.render.set_ray_tracing_samples_per_pixel(32)
        sapien.render.set_ray_tracing_path_depth(8)
        sapien.render.set_ray_tracing_denoiser("oidn")

        # declare sapien scene
        scene_config = sapien.SceneConfig()
        self.scene = self.engine.create_scene(scene_config)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
if __name__ == "__main__":
    # Instantiating the class runs the render smoke test immediately.
    a = Sapien_TEST()
|
evaluation/robotwin/websocket_client_policy.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import time
|
| 3 |
+
from typing import Dict, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
from typing_extensions import override
|
| 6 |
+
import websockets.sync.client
|
| 7 |
+
from .msgpack_numpy import Packer, unpackb
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class WebsocketClientPolicy:
    """Implements the Policy interface by communicating with a server over websocket.

    See WebsocketPolicyServer for a corresponding server implementation.
    """

    def __init__(self, host: str = "0.0.0.0", port: Optional[int] = None, api_key: Optional[str] = None) -> None:
        self._uri = f"ws://{host}"
        if port is not None:
            self._uri += f":{port}"
        self._packer = Packer()
        self._api_key = api_key
        # Blocks until the server is reachable and has sent its metadata.
        self._ws, self._server_metadata = self._wait_for_server()

    def get_server_metadata(self) -> Dict:
        """Return the metadata dict the server sent when the connection opened."""
        return self._server_metadata

    def _wait_for_server(self) -> Tuple[websockets.sync.client.ClientConnection, Dict]:
        """Connect to the server, retrying every 5 seconds until it is up.

        Returns the open connection and the server's initial metadata message.
        """
        logging.info(f"Waiting for server at {self._uri}...")
        while True:
            try:
                headers = {"Authorization": f"Api-Key {self._api_key}"} if self._api_key else None
                # Disable the websocket ping mechanism so long-running
                # inference calls do not trigger keepalive timeouts.
                conn = websockets.sync.client.connect(
                    self._uri,
                    compression=None,
                    max_size=None,
                    additional_headers=headers,
                    ping_interval=None,
                    close_timeout=10,
                )
                metadata = unpackb(conn.recv())
                return conn, metadata
            except Exception as e:
                # ``except (ConnectionRefusedError, Exception)`` was redundant —
                # Exception already covers ConnectionRefusedError.
                logging.info(f"Still waiting for server... (Error: {e})")
                time.sleep(5)

    @override
    def infer(self, obs: Dict) -> Dict:  # noqa: UP006
        """Send one observation dict and return the server's decoded response.

        Raises
        ------
        RuntimeError
            If the server replies with a text frame (its error message).
        """
        data = self._packer.pack(obs)
        self._ws.send(data)
        response = self._ws.recv()
        if isinstance(response, str):
            # we're expecting bytes; if the server sends a string, it's an error.
            raise RuntimeError(f"Error in inference server:\n{response}")
        return unpackb(response)

    @override
    def reset(self) -> None:
        """No-op: this client keeps no per-episode state."""
        pass
|
| 74 |
+
|
| 75 |
+
if __name__ == "__main__":
    # Smoke test: connect to a locally running server and send one random
    # observation.
    # NOTE(review): the relative import of .image_tools below requires running
    # this file as a module (``python -m ...``) — confirm intended invocation.
    policy_on_device = WebsocketClientPolicy(port=8000)
    import torch
    import numpy as np
    from PIL import Image
    from .image_tools import convert_to_uint8
    device = torch.device("cuda")

    # Random uint8 camera frames, presumably (batch, channel, height, width)
    # layout — TODO confirm against the server's expected format.
    base_0_rgb = np.random.randint(0, 256, size=(1, 3, 224, 224), dtype=np.uint8)
    left_wrist_0_rgb = np.random.randint(0, 256, size=(1, 3, 224, 224), dtype=np.uint8)
    state = np.random.rand(1,8).astype(np.float32)
    prompt = ["do something"]

    # observation = {
    #     "image": {
    #         "base_0_rgb": torch.from_numpy(base_0_rgb).to(device)[None],
    #         "left_wrist_0_rgb": torch.from_numpy(left_wrist_0_rgb).to(device)[None],
    #     },
    #     "state": torch.from_numpy(state).to(device)[None],
    #     "prompt": prompt,
    # }

    observation = {
        "image": {
            "base_0_rgb": convert_to_uint8(base_0_rgb),
            "left_wrist_0_rgb": convert_to_uint8(left_wrist_0_rgb),
            # Right-wrist frame re-uses the left-wrist image for this test.
            "right_wrist_0_rgb": convert_to_uint8(left_wrist_0_rgb),
        },
        "state": state,
        "prompt": prompt,
    }

    policy_on_device.infer(observation)
    from IPython import embed;embed()
|
example/franka/observation.images.cam_high.png
ADDED
|
example/franka/observation.images.cam_left_wrist.png
ADDED
|
example/franka/observation.images.cam_right_wrist.png
ADDED
|
example/robotwin/observation.images.cam_high.png
ADDED
|
example/robotwin/observation.images.cam_left_wrist.png
ADDED
|
example/robotwin/observation.images.cam_right_wrist.png
ADDED
|
lingbot_robotwin_policy.py
ADDED
|
@@ -0,0 +1,506 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import time
|
| 4 |
+
import random
|
| 5 |
+
import numpy as np
|
| 6 |
+
from collections import deque
|
| 7 |
+
import torchvision
|
| 8 |
+
import yaml
|
| 9 |
+
from types import SimpleNamespace
|
| 10 |
+
from packaging.version import Version
|
| 11 |
+
from typing import Callable, Dict, List, Optional, Type, Union, Tuple, Any, Sequence
|
| 12 |
+
from glob import glob
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
from safetensors import safe_open
|
| 15 |
+
from safetensors.torch import load_file
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from PIL import Image
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
from torch import Tensor, nn
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
import transformers
|
| 24 |
+
from transformers.models.auto.tokenization_auto import AutoTokenizer
|
| 25 |
+
from transformers import (
|
| 26 |
+
AutoConfig,
|
| 27 |
+
PretrainedConfig,
|
| 28 |
+
PreTrainedModel,
|
| 29 |
+
AutoProcessor,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
from lerobot.configs.policies import PreTrainedConfig
|
| 33 |
+
from lingbotvla.models.vla.pi0.modeling_pi0 import PI0Policy
|
| 34 |
+
from lingbotvla.models.vla.pi0.modeling_lingbot_vla import LingbotVlaPolicy
|
| 35 |
+
from lingbotvla.data.vla_data.transform import Normalizer, prepare_images, prepare_language, prepare_state
|
| 36 |
+
from lingbotvla.models import build_processor
|
| 37 |
+
|
| 38 |
+
|
def set_seed_everywhere(seed: int):
    """Seed every RNG used by the process (Python, NumPy, PyTorch, CUDA).

    Also forces deterministic cuDNN kernels and pins PYTHONHASHSEED so hash
    randomization cannot introduce run-to-run variation.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
# Seed all RNGs at import time so the whole server process is deterministic.
set_seed_everywhere(42)

# Backbone checkpoint locations keyed by policy family; overridable through
# environment variables so deployments can point at locally mirrored weights.
BASE_MODEL_PATH = {
    'pi0': os.environ.get('PALIGEMMA_PATH', './paligemma-3b-pt-224/'),
    'lingbotvla': os.environ.get('QWEN25_PATH', './Qwen2.5-VL-3B-Instruct/'),
}
def load_model_weights(policy, path_to_pi_model, strict=True):
    """Load every *.safetensors shard under ``path_to_pi_model`` into ``policy``.

    All shards are merged into a single state dict on CPU before calling
    ``load_state_dict`` so cross-shard tensors resolve in one pass.
    """
    shard_paths = glob(os.path.join(path_to_pi_model, "*.safetensors"))
    state_dict = {}

    for shard in tqdm(shard_paths):
        with safe_open(shard, framework="pt", device="cpu") as handle:
            for name in handle.keys():
                state_dict[name] = handle.get_tensor(name)
    policy.load_state_dict(state_dict, strict=strict)
def center_crop_image(image: Union[np.ndarray, Image.Image]) -> Image.Image:
    """Center-crop 90% of the image area and resize to a 224x224 RGB image.

    The per-side crop factor is sqrt(0.9) so that the retained *area* is 90%.
    Numpy inputs are converted to uint8 first (floats in [0,1] are rescaled,
    uint16 is mapped to 8-bit, everything else is cast).
    """
    crop_scale = 0.9
    side_scale = float(np.sqrt(np.clip(crop_scale, 0.0, 1.0)))
    out_size = (224, 224)

    if isinstance(image, Image.Image):
        pil = image
    elif isinstance(image, np.ndarray):
        arr = image
        if arr.dtype.kind == "f":
            # Floats that look like [0,1] are rescaled; other float ranges are
            # clamped straight to [0,255].
            if arr.max() <= 1.0 and arr.min() >= 0.0:
                arr = (np.clip(arr, 0.0, 1.0) * 255.0).astype(np.uint8)
            else:
                arr = np.clip(arr, 0.0, 255.0).astype(np.uint8)
        elif arr.dtype == np.uint16:
            # 65535 / 255 == 257: full-range 16-bit -> 8-bit.
            arr = (arr / 257).astype(np.uint8)
        elif arr.dtype != np.uint8:
            arr = arr.astype(np.uint8)
        pil = Image.fromarray(arr)
    else:
        raise TypeError("image must be a numpy array or PIL.Image.Image")

    # Normalize the mode so the output is always 3-channel RGB.
    pil = pil.convert("RGB")
    width, height = pil.size

    # Centered integer-pixel crop box; at least one pixel per side.
    crop_w = max(1, int(round(width * side_scale)))
    crop_h = max(1, int(round(height * side_scale)))
    x0 = (width - crop_w) // 2
    y0 = (height - crop_h) // 2
    box = (x0, y0, x0 + crop_w, y0 + crop_h)

    return pil.crop(box).resize(out_size, resample=Image.BILINEAR)
def resize_with_pad(img, width, height, pad_value=-1):
    """Resize a (b, c, h, w) batch to fit inside (height, width), padding with ``pad_value``.

    Aspect ratio is preserved; any leftover space is padded on the LEFT and
    TOP only. Channel-last input (b, h, w, c) is converted to channel-first.
    """
    if img.ndim != 4:
        raise ValueError(f"(b,c,h,w) expected, but {img.shape}")

    # Accept channel-last input and permute it to channel-first.
    if img.shape[1] not in (1, 3) and img.shape[-1] in (1, 3):
        img = img.permute(0, 3, 1, 2)

    src_h, src_w = img.shape[2:]
    scale = max(src_w / width, src_h / height)
    new_h = int(src_h / scale)
    new_w = int(src_w / scale)

    scaled = F.interpolate(
        img, size=(new_h, new_w), mode="bilinear", align_corners=False
    )

    top_pad = max(0, int(height - new_h))
    left_pad = max(0, int(width - new_w))

    # F.pad's last-two-dims order is (left, right, top, bottom); only the
    # left and top edges receive padding here.
    return F.pad(scaled, (left_pad, 0, top_pad, 0), value=pad_value)
class PolicyPreprocessMixin:
    """Mixin adding an inference-only ``select_action`` on top of a policy class.

    Expects the host class to provide ``self.model`` (exposing
    ``sample_actions``) and ``self.normalizer`` (exposing ``unnormalize``).
    """

    @torch.no_grad
    def select_action(
        self, observation: dict[str, Tensor], use_bf16: bool = False, vlm_causal: bool = False, noise: Tensor | None = None
    ):
        """Run one sampling pass on GPU and return the unnormalized outputs.

        ``observation`` is mutated in place: an ``'action'`` entry is added
        and ``'state'`` may be cast back to float32. ``noise`` is accepted
        but currently unused by this implementation.
        """
        self.eval()
        device = 'cuda'
        if use_bf16:
            dtype = torch.bfloat16
        else:
            dtype = torch.float32
        s1 = time.time()

        # Add a batch dimension when a single (num_cam, c, h, w) stack is given.
        if len(observation['images'].shape) == 4:
            observation['images'] = observation['images'].unsqueeze(0)
            observation['img_masks'] = observation['img_masks'].unsqueeze(0)

        # The two calls differ only by the optional expert-view images tensor.
        if 'expert_imgs' in observation:
            actions = self.model.sample_actions(
                observation['images'].to(dtype=dtype, device=device),
                observation['img_masks'].to(device=device),
                observation['lang_tokens'].unsqueeze(0).to(device=device),
                observation['lang_masks'].unsqueeze(0).to(device=device),
                observation['state'].unsqueeze(0).to(dtype=dtype, device=device),
                observation['expert_imgs'].to(dtype=dtype, device=device),
                vlm_causal = vlm_causal
            )
        else:
            actions = self.model.sample_actions(
                observation['images'].to(dtype=dtype, device=device),
                observation['img_masks'].to(device=device),
                observation['lang_tokens'].unsqueeze(0).to(device=device),
                observation['lang_masks'].unsqueeze(0).to(device=device),
                observation['state'].unsqueeze(0).to(dtype=dtype, device=device),
                vlm_causal = vlm_causal
            )
        delta_time = time.time() - s1
        print(f'sample_actions cost {delta_time} s')
        # Keep only the first 14 action channels.  NOTE(review): presumably the
        # dual-arm 7-DoF layout -- confirm against the training action spec.
        observation['action'] = actions.squeeze(0)[:, :14].to(dtype=torch.float32, device='cpu')
        if use_bf16:
            # Normalizer statistics are float32; cast state back before unnormalize.
            observation['state'] = observation['state'].to(dtype=torch.float32)
        data = self.normalizer.unnormalize(observation)
        return data
class LingBotVlaInferencePolicy(PolicyPreprocessMixin, LingbotVlaPolicy):
    """LingbotVlaPolicy combined with the inference-only ``select_action`` mixin."""
    pass # Only combine necessary functions
class PI0InfernecePolicy(PolicyPreprocessMixin, PI0Policy):
    """PI0Policy combined with the inference-only ``select_action`` mixin.

    NOTE(review): the class name misspells "Inference"; it is referenced by
    this spelling elsewhere in the file, so renaming requires a coordinated change.
    """
    pass # Only combine necessary functions
def merge_qwen_config(policy_config, qwen_config):
    """Copy the language/vision fields of a Qwen config onto the policy config.

    Args:
        policy_config: target config object; mutated in place.
        qwen_config: a transformers ``PretrainedConfig`` (anything exposing
            ``to_dict``) or a plain mapping of config fields.

    Returns:
        The same ``policy_config`` instance, for chaining.
    """
    # Normalize the source to a mapping so membership tests work for both
    # config objects and plain dicts.
    if hasattr(qwen_config, 'to_dict'):
        config_dict = qwen_config.to_dict()
    else:
        config_dict = qwen_config

    # Text-model fields mirrored one-to-one onto the policy config.
    text_keys = {
        "hidden_size",
        "intermediate_size",
        "num_hidden_layers",
        "num_attention_heads",
        "num_key_value_heads",
        "rms_norm_eps",
        "rope_theta",
        "vocab_size",
        "max_position_embeddings",
        "hidden_act",
        "tie_word_embeddings",
        "tokenizer_path",
    }

    for key in text_keys:
        if key in config_dict:
            setattr(policy_config, key, config_dict[key])
            print(f"✅ Merged LLM: {key} = {config_dict[key]}")

    if "vision_config" in config_dict:
        # Prefer the live sub-config object when qwen_config is a config
        # instance; fall back to the dict entry when it is a plain mapping.
        # (The original unconditionally read the attribute, which raised
        # AttributeError for dict inputs despite the dict-aware branch above.)
        policy_config.vision_config = getattr(qwen_config, "vision_config", config_dict["vision_config"])
    else:
        print("⚠️ Warning: 'vision_config' not found in qwen_config!")

    return policy_config
class QwenPiServer:
    '''
    policy wrapper to support action ensemble or chunk execution

    Loads a PI0 / LingbotVLA policy from a checkpoint directory plus the
    training-time YAML config, and serves single-step or whole-chunk actions
    through ``infer``.
    '''
    def __init__(
        self,
        path_to_pi_model="",
        adaptive_ensemble_alpha=0.1,
        action_ensemble_horizon=8,  # NOTE(review): stored nowhere / unused here
        use_length=1, # to control the execution length of the action chunk, -1 denotes using action ensemble
        chunk_ret=False,
        use_bf16=True,
        use_fp32=False,
    ) -> None:
        # The two precision switches are mutually exclusive.
        assert not (use_bf16 and use_fp32), 'Bfloat16 or Float32!!!'
        self.adaptive_ensemble_alpha = adaptive_ensemble_alpha
        self.use_length = use_length
        self.chunk_ret = chunk_ret

        self.task_description = None

        # Build the policy, move it to GPU and freeze it for inference.
        self.vla = self.load_vla(path_to_pi_model)
        self.vla = self.vla.cuda().eval()
        if use_bf16:
            self.vla = self.vla.to(torch.bfloat16)
        elif use_fp32:
            self.vla.model.float()
        self.global_step = 0
        self.last_action_chunk = None
        self.use_bf16 = use_bf16
        self.use_fp32 = use_fp32

    def load_vla(self, path_to_pi_model) -> LingbotVlaPolicy:
        """Build the policy from a checkpoint dir + its training YAML config.

        Side effects: populates ``self.processor`` / tokenizer / image
        processor, ``self.data_config``, ``self.config``, action-dimension
        fields and the normalizer, then returns the weight-loaded policy.
        """
        # load model

        print(f"loading model from: {path_to_pi_model}")
        config = PreTrainedConfig.from_pretrained(path_to_pi_model)

        # load training config
        # The YAML is expected three directories above the checkpoint dir.
        training_config_path = Path(path_to_pi_model).parent.parent.parent/'lingbotvla_cli.yaml'
        with open(training_config_path, 'r') as f:
            training_config = yaml.safe_load(f)
        f.close()

        # update model config according to training config
        training_model_config = training_model_config = training_config['model']
        training_model_config.update(training_config['train'])
        for k, v in training_model_config.items():
            # NOTE(review): the loop variable is overwritten -- values already
            # present on `config` win over the YAML values; confirm intended.
            v = getattr(config, k, training_model_config[k])
            setattr(config, k, v)

        # Set attention_implementation to 'eager' to speed up evaluation.
        config.attention_implementation = 'eager'

        # set base model according to training config
        training_base_model = training_config['model']['tokenizer_path']
        if 'paligemma' in training_base_model:
            model_name = 'pi0'
            config.vocab_size = 257152 # set vocab size for paligamma
        elif 'qwen2' in training_base_model.lower():
            model_name = 'lingbotvla'
        else:
            raise ValueError(f"Unsupported base model of {path_to_pi_model}")
        base_model_path = BASE_MODEL_PATH[model_name]
        config.tokenizer_path = base_model_path
        self.model_name = model_name

        # Merge the backbone's LLM/vision fields into the policy config.
        qwen_config = AutoConfig.from_pretrained(base_model_path)
        config = merge_qwen_config(config, qwen_config)

        # A non-zero vocab_size in the training config overrides the backbone's.
        if 'vocab_size' in training_config['model'] and training_config['model']['vocab_size'] != 0:
            config.vocab_size = training_config['model']['vocab_size']
        # load processors
        self.processor = build_processor(base_model_path)
        self.language_tokenizer = self.processor.tokenizer
        self.image_processor = self.processor.image_processor
        data_config = SimpleNamespace(**training_config['data'])

        print('Initializing model ... ')

        if 'paligemma' in training_base_model:
            policy = PI0InfernecePolicy(config, tokenizer_path=base_model_path)
        else:
            policy = LingBotVlaInferencePolicy(config, tokenizer_path=base_model_path)

        load_model_weights(policy, path_to_pi_model, strict=True)

        policy.feature_transform = None
        self.data_config = data_config
        self.config = config
        self.joint_max_dim = training_config['train']['max_action_dim']
        self.action_dim = training_config['train']['action_dim']
        self.chunk_size = training_config['train']['chunk_size']
        policy.action_dim = self.action_dim
        policy.chunk_size = self.chunk_size
        self.norm_stats_file = data_config.norm_stats_file
        if 'align_params' in training_config['train']:
            self.use_depth_align = True
        else: self.use_depth_align = False
        with open(self.norm_stats_file) as f:
            self.norm_stats = json.load(f)
        # Images stay un-normalized here ('identity'); state/action use the
        # training-time normalization type.
        policy.normalizer = Normalizer(
            norm_stats=self.norm_stats['norm_stats'],
            from_file=True,
            data_type='robotwin',
            norm_type={
                "observation.images.cam_high": "identity",
                "observation.images.cam_left_wrist": "identity",
                "observation.images.cam_right_wrist": "identity",
                "observation.state": self.data_config.norm_type,
                "action": self.data_config.norm_type,
            },
        )

        print('Model initialized ... ')

        return policy

    def reset(self, robo_name, path_to_pi_model = None) -> None:
        """Reset episode state; optionally reload the policy from a new checkpoint.

        ``robo_name`` is accepted for the client protocol but not used here.
        """
        if path_to_pi_model is not None:
            self.vla = self.load_vla(path_to_pi_model)
            self.vla = self.vla.cuda().eval()
            if self.use_bf16:
                self.vla = self.vla.to(torch.bfloat16)
            elif self.use_fp32:
                self.vla.model.float()

        self.global_step = 0
        self.last_action_chunk = None

        # Backfill defaults for configs saved before these fields existed.
        if getattr(self.data_config, 'norm_type', None) is None:
            self.data_config.norm_type = 'meanstd'
        if getattr(self.config, 'vlm_causal', None) is None:
            self.config.vlm_causal = False
        if getattr(self.config, 'qwenvl_bos', None) is None:
            self.config.qwenvl_bos = False

        # if update ckpt path
        # NOTE(review): load_vla already loads weights via load_model_weights;
        # this second pass appears redundant -- confirm before removing.
        if path_to_pi_model is not None:
            all_safetensors = glob(os.path.join(path_to_pi_model, "*.safetensors"))
            merged_weights = {}

            for file_path in tqdm(all_safetensors):
                with safe_open(file_path, framework="pt", device="cpu") as f:
                    for key in f.keys():
                        merged_weights[key] = f.get_tensor(key)

            self.vla.load_state_dict(merged_weights, strict=True)

    def resize_image(self, observation):
        """Resize the three camera frames in place to (img_size, img_size).

        Input frames must be HxWx3 uint8-like arrays; outputs become float
        CHW arrays scaled to [0, 1].
        """
        for image_feature in ['observation.images.cam_high', 'observation.images.cam_left_wrist', 'observation.images.cam_right_wrist']:
            assert image_feature in observation
            assert len(observation[image_feature].shape)==3 and observation[image_feature].shape[-1] == 3
            image = observation[image_feature]
            img_pil = Image.fromarray(image)
            image_size = getattr(self.data_config, 'img_size', 224)
            img_pil = img_pil.resize((image_size, image_size), Image.BILINEAR)

            # img_resized shape: C*H*W
            img_resized = np.transpose(np.array(img_pil), (2,0,1)) # (3,224,224)
            observation[image_feature] = img_resized / 255.

    def infer(self, observation, center_crop=True):
        """Generates an action with the VLA policy.

        A fresh chunk is sampled when ``use_length == -1`` or every
        ``use_length`` calls; otherwise the cached chunk is replayed.
        ``center_crop`` is accepted but not used in this implementation.
        """

        # (If trained with image augmentations) Center crop image and then resize back up to original size.
        # IMPORTANT: Let's say crop scale == 0.9. To get the new height and width (post-crop), multiply
        # the original height and width by sqrt(0.9) -- not 0.9!
        if 'reset' in observation and observation['reset']:
            self.reset(robo_name=observation['robo_name'], path_to_pi_model=observation['path_to_pi_model'] if 'path_to_pi_model' in observation else None)
            return dict(action = None)

        self.resize_image(observation)
        for k, v in observation.items():
            if isinstance(v, np.ndarray):
                observation[k] = torch.from_numpy(v)

        # Preprocess (normalize, tokenize, pack) only when a new chunk will be
        # sampled on this step.
        if self.use_length == -1 or self.global_step % self.use_length == 0:
            joint_max_dim = getattr(self, 'joint_max_dim')  # NOTE(review): unused locals
            action_dim = getattr(self, 'action_dim')
            chunk_size = getattr(self, 'chunk_size')
            normalized_observation = self.vla.normalizer.normalize(observation)
            base_image = (normalized_observation["observation.images.cam_high"] * 255).to(torch.uint8)
            left_wrist_image = (normalized_observation["observation.images.cam_left_wrist"] * 255).to(
                torch.uint8
            )
            right_wrist_image = (normalized_observation["observation.images.cam_right_wrist"] * 255).to(
                torch.uint8
            )
            obs_dict = {
                "image": {"base_0_rgb": base_image, "left_wrist_0_rgb": left_wrist_image, "right_wrist_0_rgb": right_wrist_image},
                "state": normalized_observation["observation.state"].to(torch.float32),
                "prompt": [observation["task"]],
            }
            state = prepare_state(self.config, obs_dict)
            lang_tokens, lang_masks = prepare_language(self.config, self.language_tokenizer, obs_dict)
            images, img_masks, _ = prepare_images(self.config, self.image_processor, obs_dict)
            observation = {
                'images': images,
                'img_masks': img_masks,
                'state': state,
                'lang_tokens': lang_tokens,
                'lang_masks': lang_masks,
            }

            if self.use_bf16:
                observation['state'] = observation['state'].to(torch.bfloat16)

        org_actions = ['action']
        assert len(org_actions)==1, "Only support single action feature"
        if self.chunk_ret:
            # NOTE(review): this branch always calls select_action; it relies
            # on the preprocessing block above having run on this call --
            # confirm the client's calling pattern guarantees that.
            action = self.vla.select_action(observation, self.use_bf16, self.config.vlm_causal)[org_actions[0]].float().cpu().numpy()
            action = action[:self.use_length, :self.action_dim]
        else:
            if self.use_length == -1 or self.global_step % self.use_length == 0:
                action = self.vla.select_action(observation, self.use_bf16, self.config.vlm_causal)[org_actions[0]]
                self.last_action_chunk = action.float().cpu().numpy()

            if self.use_length > 0:
                action = self.last_action_chunk[self.global_step % self.use_length]
                # NOTE(review): 2-D slicing of a single chunk entry -- verify
                # last_action_chunk's rank matches this indexing.
                action = action[:, :self.action_dim]
        print(f"on server step: {self.global_step}")
        self.global_step+=1

        return dict(action = action)
| 445 |
+
|
| 446 |
+
|
| 447 |
+
import argparse
|
| 448 |
+
from .websocket_policy_server import WebsocketPolicyServer
|
| 449 |
+
|
| 450 |
+
def main():
|
| 451 |
+
parser = argparse.ArgumentParser(description="启动 QwenPi WebSocket 策略服务器")
|
| 452 |
+
|
| 453 |
+
parser.add_argument(
|
| 454 |
+
"--model_path",
|
| 455 |
+
type=str,
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
parser.add_argument(
|
| 459 |
+
"--use_length",
|
| 460 |
+
type=int,
|
| 461 |
+
default=50,
|
| 462 |
+
help="used length of action chunk"
|
| 463 |
+
)
|
| 464 |
+
|
| 465 |
+
parser.add_argument(
|
| 466 |
+
"--chunk_ret",
|
| 467 |
+
type=bool,
|
| 468 |
+
default=True,
|
| 469 |
+
help=" True: The returned action tensor includes the horizon dimension. This allows the model to output a sequence of actions for each horizon step. False: The horizon dimension is omitted. The model selects and returns the next step autonomously based on its policy."
|
| 470 |
+
)
|
| 471 |
+
|
| 472 |
+
parser.add_argument(
|
| 473 |
+
"--port",
|
| 474 |
+
type=int,
|
| 475 |
+
default=8006,
|
| 476 |
+
help="port of WebSocket"
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
parser.add_argument(
|
| 480 |
+
"--debug_infer_once",
|
| 481 |
+
action="store_true",
|
| 482 |
+
help="Run one infer with dummy observation then exit (for debugging infer() without WebSocket client)",
|
| 483 |
+
)
|
| 484 |
+
|
| 485 |
+
args = parser.parse_args()
|
| 486 |
+
|
| 487 |
+
model = QwenPiServer(args.model_path, use_length=args.use_length, chunk_ret=args.chunk_ret)
|
| 488 |
+
if args.debug_infer_once:
|
| 489 |
+
# 调试用:不启动 WebSocket,只跑一次 infer,可在 infer / select_action 里下断点
|
| 490 |
+
dummy_obs = {
|
| 491 |
+
"observation.images.cam_high": np.zeros((224, 224, 3), dtype=np.uint8),
|
| 492 |
+
"observation.images.cam_left_wrist": np.zeros((224, 224, 3), dtype=np.uint8),
|
| 493 |
+
"observation.images.cam_right_wrist": np.zeros((224, 224, 3), dtype=np.uint8),
|
| 494 |
+
"observation.state": np.zeros(model.action_dim, dtype=np.float32),
|
| 495 |
+
"task": "dummy task for debug",
|
| 496 |
+
"reset": False,
|
| 497 |
+
}
|
| 498 |
+
out = model.infer(dummy_obs)
|
| 499 |
+
print("debug_infer_once result keys:", out.keys())
|
| 500 |
+
return
|
| 501 |
+
model_server = WebsocketPolicyServer(model, port=args.port)
|
| 502 |
+
model_server.serve_forever()
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
if __name__ == "__main__":
|
| 506 |
+
main()
|
pyproject.toml
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=61.0"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "LingBot_VA"
|
| 7 |
+
version = "0.0.0"
|
| 8 |
+
description = "LingBot-VA: A Pragmatic VA Foundation Model"
|
| 9 |
+
authors = [
|
| 10 |
+
{ name = "Robbyant Team", email = "fengchang.ll@antgroup.com" }
|
| 11 |
+
]
|
| 12 |
+
license = { file = "LICENSE.txt" }
|
| 13 |
+
readme = "README.md"
|
| 14 |
+
requires-python = ">=3.10,<4.0"
|
| 15 |
+
dependencies = [
|
| 16 |
+
"torch>=2.9.0",
|
| 17 |
+
"torchvision>=0.24.0",
|
| 18 |
+
"diffusers>=0.36.0",
|
| 19 |
+
"transformers>=4.55.4",
|
| 20 |
+
"tokenizers>=0.21.4",
|
| 21 |
+
"tqdm",
|
| 22 |
+
"imageio",
|
| 23 |
+
"easydict",
|
| 24 |
+
"flash_attn",
|
| 25 |
+
"numpy>=1.26.4,<2"
|
| 26 |
+
]
|
| 27 |
+
|
| 28 |
+
[project.optional-dependencies]
|
| 29 |
+
dev = [
|
| 30 |
+
"pytest",
|
| 31 |
+
"black",
|
| 32 |
+
"flake8",
|
| 33 |
+
"isort",
|
| 34 |
+
"mypy",
|
| 35 |
+
"huggingface-hub[cli]"
|
| 36 |
+
]
|
| 37 |
+
|
| 38 |
+
[project.urls]
|
| 39 |
+
homepage = "https://github.com/Robbyant"
|
| 40 |
+
documentation = "https://github.com/Robbyant"
|
| 41 |
+
repository = "https://github.com/Robbyant"
|
| 42 |
+
huggingface = "https://github.com/Robbyant"
|
| 43 |
+
modelscope = "https://github.com/Robbyant"
|
| 44 |
+
discord = "https://github.com/Robbyant"
|
| 45 |
+
|
| 46 |
+
[tool.setuptools]
packages = ["wan_va"]

[tool.setuptools.package-data]
"wan_va" = ["**/*.py"]
|
| 51 |
+
|
| 52 |
+
[tool.black]
|
| 53 |
+
line-length = 88
|
| 54 |
+
|
| 55 |
+
[tool.isort]
|
| 56 |
+
profile = "black"
|
| 57 |
+
|
| 58 |
+
[tool.mypy]
|
| 59 |
+
strict = true
|
| 60 |
+
|
| 61 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.9.0
|
| 2 |
+
torchvision>=0.24.0
|
| 3 |
+
torchaudio
|
| 4 |
+
diffusers>=0.36.0
|
| 5 |
+
transformers>=4.55.4
|
| 6 |
+
tokenizers>=0.21.4
|
| 7 |
+
tqdm
|
| 8 |
+
imageio[ffmpeg]
|
| 9 |
+
easydict
|
| 10 |
+
flash_attn
|
| 11 |
+
numpy>=1.26.4,<2
|
script/run_launch_va_server_sync.sh
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/usr/bin/bash
# Launch the wan_va server module under torchrun.  Every knob below is
# overridable via an environment variable; any extra CLI arguments are
# forwarded verbatim to the python module as config overrides.

set -x

umask 007

NGPU=${NGPU:-"8"}
MASTER_PORT=${MASTER_PORT:-"29501"}
PORT=${PORT:-"1106"}
LOG_RANK=${LOG_RANK:-"0"}
TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"}
CONFIG_NAME=${CONFIG_NAME:-"robotwin"}

# Collect any extra CLI arguments as config overrides.
overrides=""
if [ $# -ne 0 ]; then
    overrides="$*"
fi

## node setting
num_gpu=${NGPU}
master_port=${MASTER_PORT}
log_rank=${LOG_RANK}
torchft_lighthouse=${TORCHFT_LIGHTHOUSE}
config_name=${CONFIG_NAME}

## cmd setting
# Avoid tokenizer fork warnings/deadlocks in worker processes.
export TOKENIZERS_PARALLELISM=false
PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" TORCHFT_LIGHTHOUSE=${torchft_lighthouse} \
python -m torch.distributed.run \
--nproc_per_node=${num_gpu} \
--local-ranks-filter=${log_rank} \
--master_port ${master_port} \
--tee 3 \
-m wan_va.wan_va_server --config-name ${config_name} $overrides
wan_va/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
from . import configs, distributed, modules
|
wan_va/configs/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
|
| 2 |
+
from .va_franka_cfg import va_franka_cfg
|
| 3 |
+
from .va_robotwin_cfg import va_robotwin_cfg
|
| 4 |
+
from .va_franka_i2va import va_franka_i2va_cfg
|
| 5 |
+
from .va_robotwin_i2va import va_robotwin_i2va_cfg
|
| 6 |
+
|
# Registry mapping the server's --config-name argument to its config dict.
# NOTE(review): the keys spell "i2av" while the modules/variables spell
# "i2va" -- confirm which spelling callers pass before renaming either side.
VA_CONFIGS = {
    'robotwin': va_robotwin_cfg,
    'franka': va_franka_cfg,
    'robotwin_i2av': va_robotwin_i2va_cfg,
    'franka_i2av': va_franka_i2va_cfg,
}
wan_va/configs/shared_config.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
import torch
from easydict import EasyDict

# Settings shared by every VA config; concrete configs copy these via
# `cfg.update(va_shared_cfg)` and may override individual fields.
va_shared_cfg = EasyDict()

# Bind address/port used by the inference server.
va_shared_cfg.host = '0.0.0.0'
va_shared_cfg.port = 29536

# Parameter dtype for the loaded model; outputs are saved under save_root.
va_shared_cfg.param_dtype = torch.bfloat16
va_shared_cfg.save_root = './visualization'

# (t, h, w) patch size.  NOTE(review): consumer not visible here -- confirm
# it matches the video backbone's patchification.
va_shared_cfg.patch_size = (1, 2, 2)
wan_va/configs/va_franka_cfg.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
import torch
from easydict import EasyDict

from .shared_config import va_shared_cfg

va_franka_cfg = EasyDict(__name__='Config: VA franka')
va_franka_cfg.update(va_shared_cfg)
# FIX: the original only assigned infer_mode on va_shared_cfg, AFTER the
# update() above had already copied the shared fields, so va_franka_cfg
# itself never carried infer_mode.  Set it on the franka config directly.
# The shared-side assignment is kept because configs imported later (e.g.
# robotwin) appear to pick infer_mode up from the shared dict.
va_franka_cfg.infer_mode = 'server'
va_shared_cfg.infer_mode = 'server'

va_franka_cfg.wan22_pretrained_model_name_or_path = "/path/to/pretrained/model"

va_franka_cfg.attn_window = 30
va_franka_cfg.frame_chunk_size = 4
va_franka_cfg.env_type = 'none'

# Observation rendering size and action layout.
va_franka_cfg.height = 224
va_franka_cfg.width = 320
va_franka_cfg.action_dim = 30
va_franka_cfg.action_per_frame = 20
va_franka_cfg.obs_cam_keys = [
    'observation.images.cam_high', 'observation.images.cam_left_wrist',
    'observation.images.cam_right_wrist'
]
va_franka_cfg.guidance_scale = 5
va_franka_cfg.action_guidance_scale = 1

va_franka_cfg.num_inference_steps = 5
va_franka_cfg.video_exec_step = -1
va_franka_cfg.action_num_inference_steps = 10

va_franka_cfg.snr_shift = 5.0
va_franka_cfg.action_snr_shift = 1.0

# Channel reordering: left arm joints (0-6) + left gripper (28), then right
# arm joints (7-13) + right gripper (29).
va_franka_cfg.used_action_channel_ids = list(range(0, 7)) + list(range(
    28, 29)) + list(range(7, 14)) + list(range(29, 30))
# Inverse permutation; unused channels map to the out-of-range sentinel index.
inverse_used_action_channel_ids = [len(va_franka_cfg.used_action_channel_ids)
                                   ] * va_franka_cfg.action_dim
for i, j in enumerate(va_franka_cfg.used_action_channel_ids):
    inverse_used_action_channel_ids[j] = i
va_franka_cfg.inverse_used_action_channel_ids = inverse_used_action_channel_ids

# Per-channel quantile statistics used for action normalization.
va_franka_cfg.action_norm_method = 'quantiles'
va_franka_cfg.norm_stat = {
    "q01": [
        0.3051295876502991, -0.22647984325885773, 0.19957000017166138,
        -0.022680532187223434, -0.05553057789802551, -0.2693849802017212,
        -0.29341773986816405, 0.2935442328453064, -0.4431332051753998,
        0.21256473660469055, -0.7962440848350525, -0.40816226601600647,
        -0.28359392285346985, -0.44507765769958496
    ] + [0.] * 16,
    "q99": [
        0.7572150230407715, 0.47736290097236633, 0.6428080797195435,
        0.9835678935050964, 0.9927203059196472, 0.28041139245033264,
        0.47529348731040877, 0.7564866304397571, 0.04082797020673729,
        0.5355993628501885, 0.9976375699043274, 0.8973174452781656,
        0.6016915678977965, 0.5027598619461056
    ] + [0.] * 14 + [1.0, 1.0],
}
wan_va/configs/va_franka_i2va.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
from easydict import EasyDict
from .va_franka_cfg import va_franka_cfg

# Image-to-video-action (i2va) variant of the franka config: inherits the
# base franka settings and overrides the offline-inference fields.
va_franka_i2va_cfg = EasyDict(__name__='Config: VA franka i2va')
va_franka_i2va_cfg.update(va_franka_cfg)

# Directory containing the example camera frames used as the initial input.
va_franka_i2va_cfg.input_img_path = 'example/franka'
va_franka_i2va_cfg.num_chunks_to_infer = 10
# NOTE(review): 'pick bunk' looks like a typo'd task prompt -- confirm the
# intended text before changing (it is consumed at runtime).
va_franka_i2va_cfg.prompt = 'pick bunk'
va_franka_i2va_cfg.infer_mode = 'i2va'
wan_va/configs/va_robotwin_cfg.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
from easydict import EasyDict

from .shared_config import va_shared_cfg

# RoboTwin benchmark configuration; starts from the shared VA defaults.
va_robotwin_cfg = EasyDict(__name__='Config: VA robotwin')
va_robotwin_cfg.update(va_shared_cfg)

# NOTE(review): machine-specific absolute cache path; should be overridden
# or parameterized per deployment.
va_robotwin_cfg.wan22_pretrained_model_name_or_path = "/group/ossdphi_algo_scratch_11/weicxu/huggingface_cache/hub/models--robbyant--lingbot-va-posttrain-robotwin/snapshots/ef7242af28caff0af2ff8e947c78806f94719a39"

va_robotwin_cfg.attn_window = 72
va_robotwin_cfg.frame_chunk_size = 2
va_robotwin_cfg.env_type = 'robotwin_tshape'

# Observation rendering size and action layout.
va_robotwin_cfg.height = 256
va_robotwin_cfg.width = 320
va_robotwin_cfg.action_dim = 30
va_robotwin_cfg.action_per_frame = 16
va_robotwin_cfg.obs_cam_keys = [
    'observation.images.cam_high', 'observation.images.cam_left_wrist',
    'observation.images.cam_right_wrist'
]
va_robotwin_cfg.guidance_scale = 5
va_robotwin_cfg.action_guidance_scale = 1

va_robotwin_cfg.num_inference_steps = 25
va_robotwin_cfg.video_exec_step = -1
va_robotwin_cfg.action_num_inference_steps = 50

va_robotwin_cfg.snr_shift = 5.0
va_robotwin_cfg.action_snr_shift = 1.0

# Channel reordering: left arm joints (0-6) + left gripper (28), then right
# arm joints (7-13) + right gripper (29).
va_robotwin_cfg.used_action_channel_ids = list(range(0, 7)) + list(
    range(28, 29)) + list(range(7, 14)) + list(range(29, 30))
# Inverse permutation; unused channels map to the out-of-range sentinel index.
inverse_used_action_channel_ids = [
    len(va_robotwin_cfg.used_action_channel_ids)
] * va_robotwin_cfg.action_dim
for i, j in enumerate(va_robotwin_cfg.used_action_channel_ids):
    inverse_used_action_channel_ids[j] = i
va_robotwin_cfg.inverse_used_action_channel_ids = inverse_used_action_channel_ids

# Per-channel quantile statistics used for action normalization.
va_robotwin_cfg.action_norm_method = 'quantiles'
va_robotwin_cfg.norm_stat = {
    "q01": [
        -0.06172713458538055, -3.6716461181640625e-05, -0.08783501386642456,
        -1, -1, -1, -1, -0.3547105032205582, -1.3113021850585938e-06,
        -0.11975435614585876, -1, -1, -1, -1
    ] + [0.] * 16,
    "q99": [
        0.3462600058317184, 0.39966784834861746, 0.14745532035827624, 1, 1, 1,
        1, 0.034201726913452024, 0.39142737388610793, 0.1792279863357542, 1, 1,
        1, 1
    ] + [0.] * 14 + [1.0, 1.0],
}
wan_va/configs/va_robotwin_i2va.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
from easydict import EasyDict
from .va_robotwin_cfg import va_robotwin_cfg

# Image-to-video+action ("i2va") inference configuration for the RoboTwin
# setup. It clones the shared RoboTwin base config and overrides only the
# inference-time inputs below.
va_robotwin_i2va_cfg = EasyDict(__name__='Config: VA robotwin i2va')
va_robotwin_i2va_cfg.update(va_robotwin_cfg)

# Directory containing the conditioning camera images
# (presumably one image per key in obs_cam_keys — TODO confirm against loader).
va_robotwin_i2va_cfg.input_img_path = 'example/robotwin'
# Number of frame chunks to roll out during autoregressive inference.
va_robotwin_i2va_cfg.num_chunks_to_infer = 10
# Natural-language task instruction fed to the text encoder.
va_robotwin_i2va_cfg.prompt = 'Grab the medium-sized white mug, rotate it, place it on the table, and hook it onto the smooth dark gray rack.'
# Selects the image-to-video+action inference path.
va_robotwin_i2va_cfg.infer_mode = 'i2va'
|
wan_va/distributed/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
wan_va/distributed/fsdp.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
import gc
|
| 3 |
+
from functools import partial
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
| 7 |
+
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
|
| 8 |
+
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
|
| 9 |
+
from torch.distributed.utils import _free_storage
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def shard_model(model,
                device_id,
                param_dtype=torch.bfloat16,
                reduce_dtype=torch.float32,
                buffer_dtype=torch.float32,
                process_group=None,
                sharding_strategy=ShardingStrategy.FULL_SHARD,
                sync_module_states=True,
                use_lora=False):
    """Wrap ``model`` in FSDP with mixed precision.

    Each entry of ``model.blocks`` (the transformer blocks) becomes its own
    FSDP unit via the lambda auto-wrap policy; everything else falls into
    the root wrapper.

    Args:
        model: module exposing a ``blocks`` iterable of submodules.
        device_id: CUDA device id FSDP shards/moves parameters onto.
        param_dtype / reduce_dtype / buffer_dtype: mixed-precision dtypes.
        process_group: optional process group (default: the world group).
        sharding_strategy: FSDP sharding strategy (default FULL_SHARD).
        sync_module_states: broadcast rank-0 weights to all ranks on wrap.
        use_lora: when truthy, keep original parameters visible
            (``use_orig_params=True``) so LoRA adapters can be addressed.

    Returns:
        The FSDP-wrapped model.
    """
    block_wrap_policy = partial(lambda_auto_wrap_policy,
                                lambda_fn=lambda m: m in model.blocks)
    precision = MixedPrecision(param_dtype=param_dtype,
                               reduce_dtype=reduce_dtype,
                               buffer_dtype=buffer_dtype)
    return FSDP(module=model,
                process_group=process_group,
                sharding_strategy=sharding_strategy,
                auto_wrap_policy=block_wrap_policy,
                mixed_precision=precision,
                device_id=device_id,
                sync_module_states=sync_module_states,
                use_orig_params=bool(use_lora))
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def free_model(model):
    """Release the GPU memory held by an FSDP-wrapped model.

    Frees the storage behind every FSDP unit's flat parameter, drops the
    local reference, then forces a GC pass and a CUDA cache flush.

    NOTE(review): depends on private FSDP internals (``_handle``,
    ``flat_param``) and ``torch.distributed.utils._free_storage`` — these
    are version-specific; confirm on torch upgrades.
    """
    for m in model.modules():
        if isinstance(m, FSDP):
            # Free the flat-parameter storage eagerly instead of waiting
            # for Python refcounting to reach it.
            _free_storage(m._handle.flat_param.data)
    del model
    gc.collect()
    torch.cuda.empty_cache()
|
wan_va/distributed/util.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
import torch
|
| 3 |
+
import torch.distributed as dist
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def _configure_model(model, shard_fn, param_dtype, device):
|
| 7 |
+
"""
|
| 8 |
+
TODO
|
| 9 |
+
"""
|
| 10 |
+
model.eval().requires_grad_(False)
|
| 11 |
+
if dist.is_initialized():
|
| 12 |
+
dist.barrier()
|
| 13 |
+
|
| 14 |
+
if dist.is_initialized():
|
| 15 |
+
model = shard_fn(model)
|
| 16 |
+
else:
|
| 17 |
+
model.to(param_dtype)
|
| 18 |
+
model.to(device)
|
| 19 |
+
|
| 20 |
+
return model
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def init_distributed(world_size, local_rank, rank):
    """Bind this process to its GPU and join the NCCL process group.

    Args:
        world_size: total number of processes; the group is only created
            when it is > 1, so single-process runs skip initialization.
        local_rank: GPU index on this node (``torch.cuda.set_device``).
        rank: global rank of this process.

    Rendezvous uses ``env://``, so MASTER_ADDR / MASTER_PORT (and friends)
    must already be set in the environment.
    """
    torch.cuda.set_device(local_rank)
    if world_size > 1:
        dist.init_process_group(backend="nccl",
                                init_method="env://",
                                rank=rank,
                                world_size=world_size)
|
wan_va/modules/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
# Public API of wan_va.modules.
#
# Fix: ``__all__`` listed 'WanVAEStreamingWrapper' but the name was never
# imported, so ``from wan_va.modules import *`` raised AttributeError.
from .utils import (
    WanVAEStreamingWrapper,
    load_text_encoder,
    load_tokenizer,
    load_transformer,
    load_vae,
)

__all__ = [
    'load_transformer', 'load_text_encoder', 'load_tokenizer', 'load_vae',
    'WanVAEStreamingWrapper'
]
|
wan_va/modules/model.py
ADDED
|
@@ -0,0 +1,580 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
|
| 2 |
+
import math
|
| 3 |
+
from copy import deepcopy
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 9 |
+
from diffusers.models.attention import FeedForward
|
| 10 |
+
from diffusers.models.embeddings import (
|
| 11 |
+
PixArtAlphaTextProjection,
|
| 12 |
+
TimestepEmbedding,
|
| 13 |
+
Timesteps,
|
| 14 |
+
)
|
| 15 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 16 |
+
from diffusers.models.normalization import FP32LayerNorm
|
| 17 |
+
from einops import rearrange
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
from flash_attn_interface import flash_attn_func
|
| 21 |
+
except:
|
| 22 |
+
from flash_attn import flash_attn_func
|
| 23 |
+
|
| 24 |
+
__all__ = ['WanTransformer3DModel']
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def custom_sdpa(q, k, v):
    """Scaled dot-product attention for (batch, seq, heads, head_dim) inputs.

    ``F.scaled_dot_product_attention`` expects (batch, heads, seq, head_dim),
    so the head and sequence axes are swapped on the way in and back out.
    """
    q_bhld = q.transpose(1, 2)
    k_bhld = k.transpose(1, 2)
    v_bhld = v.transpose(1, 2)
    attn = F.scaled_dot_product_attention(q_bhld, k_bhld, v_bhld)
    return attn.transpose(1, 2)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class WanTimeTextImageEmbedding(nn.Module):
    """Timestep (and text) conditioning embedder for the Wan transformer.

    Produces, per timestep, a base embedding ``temb`` of width ``dim`` and a
    projected embedding of width ``time_proj_dim`` (used downstream for the
    6-way AdaLN modulation). ``text_embedder`` maps text-encoder features of
    width ``text_embed_dim`` to ``dim`` and is invoked separately by the
    caller.

    NOTE(review): ``pos_embed_seq_len`` is accepted but not used anywhere in
    this class — confirm whether it is dead or consumed via the config.
    """

    def __init__(
        self,
        dim,
        time_freq_dim,
        time_proj_dim,
        text_embed_dim,
        pos_embed_seq_len,
    ):
        super().__init__()

        # Sinusoidal frequencies for raw (integer/float) diffusion timesteps.
        self.timesteps_proj = Timesteps(num_channels=time_freq_dim,
                                        flip_sin_to_cos=True,
                                        downscale_freq_shift=0)
        # MLP lifting the sinusoidal features to the model width.
        self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim,
                                               time_embed_dim=dim)
        self.act_fn = nn.SiLU()
        # Projects temb to the (typically 6 * dim) modulation width.
        self.time_proj = nn.Linear(dim, time_proj_dim)
        # Text-feature projection (text_embed_dim -> dim).
        self.text_embedder = PixArtAlphaTextProjection(text_embed_dim,
                                                       dim,
                                                       act_fn="gelu_tanh")

    def forward(
        self,
        timestep: torch.Tensor,
        dtype=None,
    ):
        """Embed a (B, L) grid of per-token timesteps.

        Returns:
            temb: (B, L, dim) base timestep embedding, cast to ``dtype``.
            timestep_proj: (B, L, time_proj_dim) modulation embedding.
        """
        B, L = timestep.shape
        # Flatten so Timesteps sees a 1-D batch of scalar timesteps.
        timestep = timestep.reshape(-1)
        timestep = self.timesteps_proj(timestep)
        # time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
        # Reading the weight dtype directly avoids iterating parameters
        # (which can be surprising under FSDP-style wrapping).
        time_embedder_dtype = self.time_embedder.linear_1.weight.dtype
        if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
            timestep = timestep.to(time_embedder_dtype)
        temb = self.time_embedder(timestep).to(dtype=dtype)
        timestep_proj = self.time_proj(self.act_fn(temb))
        return temb.reshape(B, L, -1), timestep_proj.reshape(B, L, -1)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class WanRotaryPosEmbed(nn.Module):
|
| 74 |
+
|
| 75 |
+
def __init__(
|
| 76 |
+
self,
|
| 77 |
+
attention_head_dim,
|
| 78 |
+
patch_size,
|
| 79 |
+
max_seq_len,
|
| 80 |
+
theta=10000.0,
|
| 81 |
+
):
|
| 82 |
+
super().__init__()
|
| 83 |
+
|
| 84 |
+
self.attention_head_dim = attention_head_dim
|
| 85 |
+
self.patch_size = patch_size
|
| 86 |
+
self.max_seq_len = max_seq_len
|
| 87 |
+
self.theta = theta
|
| 88 |
+
|
| 89 |
+
self.f_dim = self.attention_head_dim - 2 * (self.attention_head_dim //
|
| 90 |
+
3)
|
| 91 |
+
self.h_dim = self.attention_head_dim // 3
|
| 92 |
+
self.w_dim = self.attention_head_dim // 3
|
| 93 |
+
|
| 94 |
+
# Precompute and register buffers
|
| 95 |
+
f_freqs_base, h_freqs_base, w_freqs_base = self._precompute_freqs_base(
|
| 96 |
+
)
|
| 97 |
+
self.register_buffer("f_freqs_base", f_freqs_base, persistent=False)
|
| 98 |
+
self.register_buffer("h_freqs_base", h_freqs_base, persistent=False)
|
| 99 |
+
self.register_buffer("w_freqs_base", w_freqs_base, persistent=False)
|
| 100 |
+
|
| 101 |
+
def _precompute_freqs_base(self):
|
| 102 |
+
# freqs_base = 1.0 / (theta ** (2k / dim))
|
| 103 |
+
f_freqs_base = 1.0 / (self.theta**(torch.arange(
|
| 104 |
+
0, self.f_dim, 2)[:(self.f_dim // 2)].double() / self.f_dim))
|
| 105 |
+
h_freqs_base = 1.0 / (self.theta**(torch.arange(
|
| 106 |
+
0, self.h_dim, 2)[:(self.h_dim // 2)].double() / self.h_dim))
|
| 107 |
+
w_freqs_base = 1.0 / (self.theta**(torch.arange(
|
| 108 |
+
0, self.w_dim, 2)[:(self.w_dim // 2)].double() / self.w_dim))
|
| 109 |
+
return f_freqs_base, h_freqs_base, w_freqs_base
|
| 110 |
+
|
| 111 |
+
def forward(self, grid_ids):
|
| 112 |
+
with torch.no_grad():
|
| 113 |
+
f_freqs = grid_ids[:, 0, :].unsqueeze(-1) * self.f_freqs_base
|
| 114 |
+
h_freqs = grid_ids[:, 1, :].unsqueeze(-1) * self.h_freqs_base
|
| 115 |
+
w_freqs = grid_ids[:, 2, :].unsqueeze(-1) * self.w_freqs_base
|
| 116 |
+
freqs = torch.cat([f_freqs, h_freqs, w_freqs], dim=-1).float()
|
| 117 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
| 118 |
+
|
| 119 |
+
return freqs_cis
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class WanAttention(torch.nn.Module):
    """Multi-head attention with an optional slot-based KV cache.

    Used both for self-attention (``cross_attention_dim_head is None``,
    KV caching enabled) and cross-attention (caching disabled). Inputs
    ``q``/``k``/``v`` are (B, L, dim); internally heads are unflattened to
    (B, L, heads, head_dim), which matches both ``custom_sdpa`` and
    ``flash_attn_func`` layouts.

    The cache is a fixed-size pool of slots per ``cache_name``:
      - 'k' / 'v': (B, total, heads, head_dim) key/value pools,
      - 'mask':    per-slot occupancy,
      - 'id':      monotonically increasing insertion stamp (for LRU-style
                   eviction of the oldest entries),
      - 'is_pred': marks speculative ("prediction") entries so they can be
                   dropped wholesale via clear_pred_cache.

    NOTE(review): forward indexes ``self.attn_caches[cache_name]`` directly,
    so for self-attention ``init_kv_cache`` (or ``clear_cache``) must be
    called for that name before the first forward — confirm callers do so.
    """

    def __init__(
        self,
        dim,
        heads=8,
        dim_head=64,
        eps=1e-5,
        dropout=0.0,
        cross_attention_dim_head=None,
        attn_mode='torch',
    ):
        super().__init__()
        # Select the attention kernel once at construction time.
        if attn_mode == 'torch':
            self.attn_op = custom_sdpa
        elif attn_mode == 'flashattn':
            self.attn_op = flash_attn_func
        else:
            raise ValueError(
                f"Unsupported attention mode: {attn_mode}, only support torch and flashattn"
            )

        self.inner_dim = dim_head * heads
        self.heads = heads
        self.cross_attention_dim_head = cross_attention_dim_head
        # K/V projection width differs only when cross-attention overrides
        # the per-head dim.
        self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads

        self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
        self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
        self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
        self.to_out = torch.nn.ModuleList([
            torch.nn.Linear(self.inner_dim, dim, bias=True),
            torch.nn.Dropout(dropout),
        ])
        # Q/K RMSNorm over the full inner dim (applied before head split).
        self.norm_q = torch.nn.RMSNorm(dim_head * heads,
                                       eps=eps,
                                       elementwise_affine=True)
        self.norm_k = torch.nn.RMSNorm(dim_head * heads,
                                       eps=eps,
                                       elementwise_affine=True)
        # KV caching only for self-attention; None disables it entirely.
        self.attn_caches = {} if cross_attention_dim_head is None else None

    def clear_pred_cache(self, cache_name):
        """Invalidate only the slots flagged as prediction entries."""
        if self.attn_caches is None:
            return
        cache = self.attn_caches[cache_name]
        is_pred = cache['is_pred']
        cache['mask'][is_pred] = False

    def clear_cache(self, cache_name):
        """Drop the entire cache for ``cache_name`` (lazily re-initialized)."""
        if self.attn_caches is None:
            return
        self.attn_caches[cache_name] = None

    def init_kv_cache(self, cache_name, total_tolen, num_head, head_dim,
                      device, dtype, batch_size):
        """Allocate an empty slot pool of ``total_tolen`` entries.

        (``total_tolen`` is a pre-existing misspelling of "total_token";
        kept for compatibility with callers.)
        """
        if self.attn_caches is None:
            return
        self.attn_caches[cache_name] = {
            'k':
            torch.empty([batch_size, total_tolen, num_head, head_dim],
                        device=device,
                        dtype=dtype),
            'v':
            torch.empty([batch_size, total_tolen, num_head, head_dim],
                        device=device,
                        dtype=dtype),
            'id':
            torch.full((total_tolen, ), -1, device=device),
            "mask":
            torch.zeros((total_tolen, ), dtype=torch.bool, device=device),
            "is_pred":
            torch.zeros((total_tolen, ), dtype=torch.bool, device=device),
        }

    def allocate_slots(self, cache_name, key_size):
        """Return ``key_size`` free slot indices, evicting oldest if needed."""
        cache = self.attn_caches[cache_name]
        mask = cache["mask"]
        ids = cache["id"]
        free = (~mask).nonzero(as_tuple=False).squeeze(-1)

        if free.numel() < key_size:
            used = mask.nonzero(as_tuple=False).squeeze(-1)

            # Evict the entries with the smallest insertion stamps (oldest).
            used_ids = ids[used]
            order = torch.argsort(used_ids)
            need = key_size - free.numel()
            to_free = used[order[:need]]

            mask[to_free] = False
            ids[to_free] = -1
            free = (~mask).nonzero(as_tuple=False).squeeze(-1)

        assert free.numel() >= key_size
        return free[:key_size]

    def _next_cache_id(self, cache_name):
        """Next insertion stamp: max live stamp + 1, or 0 when empty."""
        ids = self.attn_caches[cache_name]['id']
        mask = self.attn_caches[cache_name]['mask']

        if mask.any():
            return ids[mask].max() + 1
        else:
            return torch.tensor(0, device=ids.device, dtype=ids.dtype)

    def update_cache(self, cache_name, key, value, is_pred):
        """Write new K/V rows into freshly allocated slots; return the slots."""
        cache = self.attn_caches[cache_name]

        key_size = key.shape[1]
        slots = self.allocate_slots(cache_name, key_size)

        new_id = self._next_cache_id(cache_name)

        cache['k'][:, slots] = key
        cache['v'][:, slots] = value
        cache['mask'][slots] = True
        cache['id'][slots] = new_id
        cache['is_pred'][slots] = is_pred
        return slots

    def restore_cache(self, cache_name, slots):
        """Free ``slots`` again (undo a temporary update_cache insertion)."""
        self.attn_caches[cache_name]['mask'][slots] = False

    def forward(
        self,
        q,
        k,
        v,
        rotary_emb,
        update_cache=0,
        cache_name='pos',
    ):
        """Attention over (B, L, dim) inputs with optional KV caching.

        ``update_cache`` semantics: 0 = transient (new K/V are inserted for
        this call only, then freed); 1 = persist as prediction entries;
        any other value = persist as regular entries.
        """
        kv_cache = self.attn_caches[
            cache_name] if self.attn_caches is not None else None

        query, key, value = self.to_q(q), self.to_k(k), self.to_v(v)
        query = self.norm_q(query)
        query = query.unflatten(2, (self.heads, -1))
        key = self.norm_k(key)
        key = key.unflatten(2, (self.heads, -1))
        value = value.unflatten(2, (self.heads, -1))
        if rotary_emb is not None:

            def apply_rotary_emb(x, freqs):
                # Rotate consecutive channel pairs as complex numbers;
                # float64 is used to keep the rotation numerically exact.
                x_out = torch.view_as_complex(
                    x.to(torch.float64).reshape(x.shape[0], x.shape[1],
                                                x.shape[2], -1, 2))
                x_out = torch.view_as_real(x_out * freqs).flatten(3)
                return x_out.to(x.dtype)

            query = apply_rotary_emb(query, rotary_emb)
            key = apply_rotary_emb(key, rotary_emb)
        slots = None
        if kv_cache is not None and kv_cache['k'] is not None:
            # Insert this call's K/V, then attend over ALL live entries.
            slots = self.update_cache(cache_name,
                                      key,
                                      value,
                                      is_pred=(update_cache == 1))
            key_pool = self.attn_caches[cache_name]['k']
            value_pool = self.attn_caches[cache_name]['v']
            mask = self.attn_caches[cache_name]['mask']
            valid = mask.nonzero(as_tuple=False).squeeze(-1)
            key = key_pool[:, valid]
            value = value_pool[:, valid]

        hidden_states = self.attn_op(query, key, value)

        if update_cache == 0:
            # Transient call: drop the rows we just inserted.
            if kv_cache is not None and kv_cache['k'] is not None:
                self.restore_cache(cache_name, slots)

        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.type_as(query)
        hidden_states = self.to_out[0](hidden_states)
        hidden_states = self.to_out[1](hidden_states)
        return hidden_states
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
class WanTransformerBlock(nn.Module):
    """One Wan transformer layer: AdaLN self-attn, cross-attn, AdaLN FFN.

    Self-attention and the feed-forward path are modulated by six per-token
    (shift, scale, gate) vectors derived from the timestep embedding;
    cross-attention over the text features is not modulated.
    """

    def __init__(
        self,
        dim,
        ffn_dim,
        num_heads,
        cross_attn_norm=False,
        eps=1e-6,
        attn_mode: str = "flashattn",
    ):
        super().__init__()
        self.attn_mode = attn_mode

        # 1. Self-attention (KV caching enabled: cross_attention_dim_head=None)
        self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
        self.attn1 = WanAttention(
            dim=dim,
            heads=num_heads,
            dim_head=dim // num_heads,
            eps=eps,
            cross_attention_dim_head=None,
            attn_mode=attn_mode,
        )

        # 2. Cross-attention over the text features (no KV cache)
        self.attn2 = WanAttention(
            dim=dim,
            heads=num_heads,
            dim_head=dim // num_heads,
            eps=eps,
            cross_attention_dim_head=dim // num_heads,
            attn_mode=attn_mode,
        )
        self.norm2 = FP32LayerNorm(
            dim, eps,
            elementwise_affine=True) if cross_attn_norm else nn.Identity()

        # 3. Feed-forward
        self.ffn = FeedForward(dim,
                               inner_dim=ffn_dim,
                               activation_fn="gelu-approximate")
        self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)

        # Learned base modulation; temb is added on top per token.
        self.scale_shift_table = nn.Parameter(
            torch.randn(1, 6, dim) / dim**0.5)

    def forward(
        self,
        hidden_states,
        encoder_hidden_states,
        temb,
        rotary_emb,
        update_cache=0,
        cache_name='pos',
    ) -> torch.Tensor:
        """Run the block. ``temb`` is (B, L, 6, dim); ``update_cache`` and
        ``cache_name`` are forwarded to the self-attention KV cache."""
        # Per-token modulation = learned table + timestep embedding (fp32).
        temb_scale_shift_table = self.scale_shift_table[None] + temb.float()
        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = \
            rearrange(temb_scale_shift_table, 'b l n c -> b n l c').chunk(6, dim=1)
        shift_msa = shift_msa.squeeze(1)
        scale_msa = scale_msa.squeeze(1)
        gate_msa = gate_msa.squeeze(1)
        c_shift_msa = c_shift_msa.squeeze(1)
        c_scale_msa = c_scale_msa.squeeze(1)
        c_gate_msa = c_gate_msa.squeeze(1)

        # 1. Self-attention (normalize/modulate in fp32, compute in input dtype)
        norm_hidden_states = (self.norm1(hidden_states.float()) *
                              (1. + scale_msa) +
                              shift_msa).type_as(hidden_states)
        attn_output = self.attn1(norm_hidden_states,
                                 norm_hidden_states,
                                 norm_hidden_states,
                                 rotary_emb,
                                 update_cache=update_cache,
                                 cache_name=cache_name)
        hidden_states = (hidden_states.float() +
                         attn_output * gate_msa).type_as(hidden_states)

        # 2. Cross-attention (update_cache forced to 0: nothing cached here)
        norm_hidden_states = self.norm2(
            hidden_states.float()).type_as(hidden_states)
        attn_output = self.attn2(norm_hidden_states,
                                 encoder_hidden_states,
                                 encoder_hidden_states,
                                 None,
                                 update_cache=0,
                                 cache_name=cache_name)
        hidden_states = hidden_states + attn_output

        # 3. Feed-forward with its own (c_*) modulation triple
        norm_hidden_states = (self.norm3(hidden_states.float()) *
                              (1. + c_scale_msa) +
                              c_shift_msa).type_as(hidden_states)

        ff_output = self.ffn(norm_hidden_states)

        hidden_states = (hidden_states.float() +
                         ff_output.float() * c_gate_msa).type_as(hidden_states)
        return hidden_states
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
class WanTransformer3DModel(ModelMixin, ConfigMixin):
    r"""Wan video+action diffusion transformer.

    A stack of ``num_layers`` WanTransformerBlocks shared between two input
    modes: video latents (patch-embedded, ``action_mode=False``) and action
    vectors (linearly embedded, ``action_mode=True``). Each mode has its own
    timestep condition embedder; the text embedder of the video branch is
    used for the prompt features in both modes.
    """

    @register_to_config
    def __init__(self,
                 patch_size=[1, 2, 2],
                 num_attention_heads=24,
                 attention_head_dim=128,
                 in_channels=48,
                 out_channels=48,
                 action_dim=30,
                 text_dim=4096,
                 freq_dim=256,
                 ffn_dim=14336,
                 num_layers=30,
                 cross_attn_norm=True,
                 eps=1e-06,
                 rope_max_seq_len=1024,
                 pos_embed_seq_len=None,
                 attn_mode="torch"):
        r"""Build the transformer.

        ``patch_size`` is (frame, height, width) patching of the latent
        grid; ``in_channels``/``out_channels`` are latent channel counts;
        ``action_dim`` is the per-step action vector width.
        """
        super().__init__()
        self.patch_size = patch_size
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim
        self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size,
                                      rope_max_seq_len)
        # Video latents: one linear layer over flattened patches.
        self.patch_embedding_mlp = nn.Linear(
            in_channels * patch_size[0] * patch_size[1] * patch_size[2],
            inner_dim)
        # Actions: plain linear embedding to the model width.
        self.action_embedder = nn.Linear(action_dim, inner_dim)
        self.condition_embedder = WanTimeTextImageEmbedding(
            dim=inner_dim,
            time_freq_dim=freq_dim,
            time_proj_dim=inner_dim * 6,
            text_embed_dim=text_dim,
            pos_embed_seq_len=pos_embed_seq_len,
        )
        # Separate (identically shaped) timestep embedder for action mode.
        self.condition_embedder_action = deepcopy(self.condition_embedder)

        self.blocks = nn.ModuleList([
            WanTransformerBlock(inner_dim,
                                ffn_dim,
                                num_attention_heads,
                                cross_attn_norm,
                                eps,
                                attn_mode=attn_mode) for _ in range(num_layers)
        ])

        self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
        self.proj_out = nn.Linear(inner_dim,
                                  out_channels * math.prod(patch_size))
        self.action_proj_out = nn.Linear(inner_dim, action_dim)
        # Final-layer (shift, scale) modulation base.
        self.scale_shift_table = nn.Parameter(
            torch.randn(1, 2, inner_dim) / inner_dim**0.5)

    def clear_cache(self, cache_name):
        """Drop the self-attention KV cache ``cache_name`` in every block."""
        for block in self.blocks:
            block.attn1.clear_cache(cache_name)

    def clear_pred_cache(self, cache_name):
        """Invalidate only prediction entries of ``cache_name`` in every block."""
        for block in self.blocks:
            block.attn1.clear_pred_cache(cache_name)

    def create_empty_cache(self, cache_name, attn_window,
                           latent_token_per_chunk, action_token_per_chunk,
                           device, dtype, batch_size):
        """Allocate per-block KV caches sized for ``attn_window // 2`` chunks
        of latent tokens plus the same number of action-token chunks."""
        total_tolen = (attn_window // 2) * latent_token_per_chunk + (
            attn_window // 2) * action_token_per_chunk
        for block in self.blocks:
            block.attn1.init_kv_cache(cache_name, total_tolen,
                                      self.num_attention_heads,
                                      self.attention_head_dim, device, dtype, batch_size)

    def forward(
        self,
        input_dict,
        update_cache=0,
        cache_name="pos",
        action_mode=False,
    ):
        r"""Denoise one chunk of video latents or actions.

        Args:
            input_dict: dict with
                'noisy_latents' (Tensor): video mode — (B, C, F, H, W)
                    latents; action mode — (B, action_dim, F, H, W) where the
                    trailing spatial dims are flattened into the token axis.
                'text_emb' (Tensor): text-encoder features, (B, L_text, text_dim).
                'grid_id' (Tensor): (B, 3, L) frame/height/width coordinates
                    for rotary embedding.
                'timesteps' (Tensor): per-frame diffusion timesteps, expanded
                    here to one timestep per token.
            update_cache: forwarded to each block's self-attention KV cache
                (0 = transient, 1 = persist as prediction, else persist).
            cache_name: which KV cache to use.
            action_mode: selects the action branch (embedder, timestep
                conditioner, and output head).

        Returns:
            Tensor: action mode — (B, L, action_dim) denoised actions;
            video mode — (B, L * prod(patch_size), out_channels) unpatchified
            latent tokens.
        """
        if action_mode:  # action input emb
            latent_hidden_states = rearrange(input_dict['noisy_latents'],
                                             'b c f h w -> b (f h w) c')
            latent_hidden_states = self.action_embedder(
                latent_hidden_states)  # B L1 C
        else:  # latent input emb
            latent_hidden_states = rearrange(
                input_dict['noisy_latents'],
                'b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)',
                p1=self.patch_size[0],
                p2=self.patch_size[1],
                p3=self.patch_size[2])
            latent_hidden_states = self.patch_embedding_mlp(
                latent_hidden_states)
        text_hidden_states = self.condition_embedder.text_embedder(
            input_dict["text_emb"])  # B L2 C

        latent_grid_id = input_dict['grid_id']
        rotary_emb = self.rope(latent_grid_id)[:, :, None]  # 1 L 1 C
        # Spatial patching only divides the token count in video mode.
        pach_scale_h, pach_scale_w = (1, 1) if action_mode else (
            self.patch_size[1], self.patch_size[2])

        # One timestep per token: repeat each frame timestep across its
        # (patched) spatial extent.
        latent_time_steps = torch.repeat_interleave(
            input_dict['timesteps'],
            (input_dict['noisy_latents'].shape[-2] // pach_scale_h) *
            (input_dict['noisy_latents'].shape[-1] // pach_scale_w), dim=1)  # L
        current_condition_embedder = self.condition_embedder_action if action_mode else self.condition_embedder
        temb, timestep_proj = current_condition_embedder(
            latent_time_steps, dtype=latent_hidden_states.dtype)
        timestep_proj = timestep_proj.unflatten(2, (6, -1))  # B L 6 C

        for block in self.blocks:
            latent_hidden_states = block(latent_hidden_states,
                                         text_hidden_states,
                                         timestep_proj,
                                         rotary_emb,
                                         update_cache=update_cache,
                                         cache_name=cache_name)
        # Final AdaLN: learned (shift, scale) base + per-token temb.
        temb_scale_shift_table = self.scale_shift_table[None] + temb[:, :, None, ...]
        shift, scale = rearrange(temb_scale_shift_table,
                                 'b l n c -> b n l c').chunk(2, dim=1)
        shift = shift.to(latent_hidden_states.device).squeeze(1)
        scale = scale.to(latent_hidden_states.device).squeeze(1)
        latent_hidden_states = (self.norm_out(latent_hidden_states.float()) *
                                (1. + scale) +
                                shift).type_as(latent_hidden_states)

        if action_mode:
            latent_hidden_states = self.action_proj_out(latent_hidden_states)
        else:
            latent_hidden_states = self.proj_out(latent_hidden_states)
            # Unfold the per-token patch pixels back into the token axis.
            latent_hidden_states = rearrange(latent_hidden_states,
                                             'b l (n c) -> b (l n) c',
                                             n=math.prod(self.patch_size))  #

        return latent_hidden_states
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
if __name__ == '__main__':
    # Smoke test: construct the transformer at its default scale and print
    # the module tree. No weights are loaded and no forward pass is run.
    model = WanTransformer3DModel(patch_size=[1, 2, 2],
                                  num_attention_heads=24,
                                  attention_head_dim=128,
                                  in_channels=48,
                                  out_channels=48,
                                  action_dim=30,
                                  text_dim=4096,
                                  freq_dim=256,
                                  ffn_dim=14336,
                                  num_layers=30,
                                  cross_attn_norm=True,
                                  eps=1e-6,
                                  rope_max_seq_len=1024,
                                  pos_embed_seq_len=None,
                                  attn_mode="torch")
    print(model)
|
wan_va/modules/utils.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Robbyant Team Authors. All rights reserved.
|
| 2 |
+
import torch
|
| 3 |
+
from diffusers import AutoencoderKLWan
|
| 4 |
+
from transformers import (
|
| 5 |
+
T5TokenizerFast,
|
| 6 |
+
UMT5EncoderModel,
|
| 7 |
+
)
|
| 8 |
+
|
| 9 |
+
from .model import WanTransformer3DModel
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def load_vae(
    vae_path,
    torch_dtype,
    torch_device,
):
    """Load the Wan VAE from ``vae_path`` in ``torch_dtype`` and move it to
    ``torch_device``."""
    model = AutoencoderKLWan.from_pretrained(vae_path,
                                             torch_dtype=torch_dtype)
    model = model.to(torch_device)
    return model
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def load_text_encoder(
    text_encoder_path,
    torch_dtype,
    torch_device,
):
    """Load the UMT5 text encoder from ``text_encoder_path`` in
    ``torch_dtype`` and move it to ``torch_device``."""
    model = UMT5EncoderModel.from_pretrained(text_encoder_path,
                                             torch_dtype=torch_dtype)
    model = model.to(torch_device)
    return model
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def load_tokenizer(tokenizer_path):
    """Load the T5 fast tokenizer paired with the UMT5 text encoder."""
    return T5TokenizerFast.from_pretrained(tokenizer_path)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def load_transformer(
    transformer_path,
    torch_dtype,
    torch_device,
):
    """Load the pretrained Wan 3D transformer and place it on the target device.

    Args:
        transformer_path: Local directory or hub id of the transformer weights.
        torch_dtype: Dtype to load the weights in.
        torch_device: Device the module is moved to after loading.

    Returns:
        The ``WanTransformer3DModel`` instance on ``torch_device``.
    """
    transformer = WanTransformer3DModel.from_pretrained(transformer_path,
                                                        torch_dtype=torch_dtype)
    return transformer.to(torch_device)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def patchify(x, patch_size):
    """Fold non-overlapping spatial patches into the channel dimension.

    Equivalent to the einops rearrange
    ``b c f (h q) (w r) -> b (c r q) f h w`` with ``q = r = patch_size``,
    i.e. the layout expected by the Wan VAE when it is configured with a
    spatial patch size.

    Args:
        x: Video tensor of shape (batch, channels, frames, height, width).
            Height and width must be divisible by ``patch_size``.
        patch_size: Patch edge length; ``None`` or ``1`` disables patching.

    Returns:
        Tensor of shape (batch, channels * patch_size**2, frames,
        height // patch_size, width // patch_size), or ``x`` unchanged when
        patching is disabled.
    """
    if patch_size is None or patch_size == 1:
        return x
    b, c, f, h, w = x.shape
    p = patch_size
    # Expose the patch axes, move them next to the channel axis, then merge.
    patched = x.view(b, c, f, h // p, p, w // p, p)
    patched = patched.permute(0, 1, 6, 4, 2, 3, 5).contiguous()
    return patched.view(b, c * p * p, f, h // p, w // p)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class WanVAEStreamingWrapper:
    """Chunked-encoding facade over a Wan causal VAE.

    The Wan encoder is built from causal 3D convolutions that each carry a
    temporal feature cache, so a video can be encoded one temporal chunk at
    a time. This wrapper owns that cache: call ``encode_chunk`` for each
    successive chunk of the same video and ``clear_cache`` before starting
    a new one.
    """

    def __init__(self, vae_model):
        self.vae = vae_model
        self.encoder = vae_model.encoder
        self.quant_conv = vae_model.quant_conv

        # One cache slot per causal conv in the encoder. Prefer the count the
        # VAE has already precomputed; otherwise count the modules directly.
        if hasattr(self.vae, "_cached_conv_counts"):
            self.enc_conv_num = self.vae._cached_conv_counts["encoder"]
        else:
            self.enc_conv_num = sum(
                1 for module in self.encoder.modules()
                if module.__class__.__name__ == "WanCausalConv3d")

        self.clear_cache()

    def clear_cache(self):
        """Reset all per-convolution temporal caches (start of a new video)."""
        self.feat_cache = [None] * self.enc_conv_num

    def encode_chunk(self, x_chunk):
        """Encode one temporal chunk, consuming and updating the caches.

        Args:
            x_chunk: Video chunk — presumably (batch, channels, frames,
                height, width); confirm against the encoder's contract.

        Returns:
            The latent produced by the encoder followed by ``quant_conv``.
        """
        patch_size = getattr(self.vae.config, "patch_size", None)
        if patch_size is not None:
            x_chunk = patchify(x_chunk, patch_size)
        feat_idx = [0]  # mutable cursor advanced by each causal conv
        features = self.encoder(x_chunk,
                                feat_cache=self.feat_cache,
                                feat_idx=feat_idx)
        return self.quant_conv(features)
|
wan_va/utils/Simple_Remote_Infer/LEGAL.md
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Legal Disclaimer
|
| 2 |
+
|
| 3 |
+
Within this source code, the comments in Chinese shall be the original, governing version. Any comment in other languages are for reference only. In the event of any conflict between the Chinese language version comments and other language version comments, the Chinese language version shall prevail.
|
| 4 |
+
|
| 5 |
+
法律免责声明
|
| 6 |
+
|
| 7 |
+
关于代码注释部分,中文注释为官方版本,其它语言注释仅做参考。中文注释可能与其它语言注释存在不一致,当中文注释与其它语言注释存在不一致时,请以中文注释为准。
|
wan_va/utils/Simple_Remote_Infer/README.md
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 通用的server-client
|
| 2 |
+
|
| 3 |
+
## /Simple_Remote_Infer/deploy/qwenpi_policy.py
|
| 4 |
+
|
| 5 |
+
- QwenPiServer: 一个用于示范的模型,拥有 `init` 和 `infer` 方法
|
| 6 |
+
|
| 7 |
+
将加载好的模型用`WebsocketPolicyServer`包裹,并指定端口即可
|
| 8 |
+
|
| 9 |
+
```python
|
| 10 |
+
model_server = WebsocketPolicyServer(model, port=8002)
|
| 11 |
+
model_server.serve_forever() # 开启监听
|
| 12 |
+
```
|
| 13 |
+
|
| 14 |
+
## ./websocket_client_policy.py
|
| 15 |
+
|
| 16 |
+
在 `__main__` 入口中展示了如何创建一个假模型向真模型发送环境信息,只需要用 `WebsocketClientPolicy` 代替原有的模型即可
|
wan_va/utils/Simple_Remote_Infer/deploy/__init__.py
ADDED
|
File without changes
|
wan_va/utils/Simple_Remote_Infer/deploy/image_tools.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from PIL import Image
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def convert_to_uint8(img: np.ndarray) -> np.ndarray:
    """Convert a float image (assumed in [0, 1]) to uint8; pass others through.

    Shrinking float images to uint8 matters when shipping frames over the
    network, since it cuts the payload to a quarter (or an eighth) of the
    float size.
    """
    if not np.issubdtype(img.dtype, np.floating):
        return img
    return (255 * img).astype(np.uint8)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def resize_with_pad(images: np.ndarray,
                    height: int,
                    width: int,
                    method=Image.BILINEAR) -> np.ndarray:
    """Batched ``tf.image.resize_with_pad`` equivalent implemented with PIL.

    Each image is scaled to fit inside (height, width) without distortion,
    then zero-padded symmetrically up to the exact target size.

    Args:
        images: Images shaped [..., height, width, channel]; any number of
            leading batch dimensions is allowed.
        height: Target height.
        width: Target width.
        method: PIL resampling filter; bilinear by default.

    Returns:
        The resized images shaped [..., height, width, channel].
    """
    # Fast path: the spatial size already matches the target.
    if images.shape[-3:-1] == (height, width):
        return images

    lead_shape = images.shape[:-3]
    flat = images.reshape(-1, *images.shape[-3:])
    resized = np.stack([
        _resize_with_pad_pil(Image.fromarray(frame), height, width,
                             method=method) for frame in flat
    ])
    return resized.reshape(*lead_shape, *resized.shape[-3:])
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _resize_with_pad_pil(image: Image.Image, height: int, width: int,
                         method: int) -> Image.Image:
    """Resize one PIL image to fit (height, width) and center it on zeros.

    Single-image counterpart of :func:`resize_with_pad`. Note that PIL sizes
    are (width, height) ordered, unlike the [..., h, w, c] array convention.
    """
    cur_width, cur_height = image.size
    if (cur_width, cur_height) == (width, height):
        # Already the target size; skip the resample entirely.
        return image

    # Scale by the dominant axis so the result fits inside the target box.
    scale = max(cur_width / width, cur_height / height)
    new_height = int(cur_height / scale)
    new_width = int(cur_width / scale)
    scaled = image.resize((new_width, new_height), resample=method)

    # Center the scaled image on a zero-filled canvas of the target size.
    canvas = Image.new(scaled.mode, (width, height), 0)
    top = max(0, int((height - new_height) / 2))
    left = max(0, int((width - new_width) / 2))
    canvas.paste(scaled, (left, top))
    assert canvas.size == (width, height)
    return canvas
|