lsnu committed on
Commit
d72206d
·
verified ·
1 Parent(s): 408c79e

Add files using upload-large-folder tool

Browse files
Files changed (50) hide show
  1. MANIFEST.txt +0 -0
  2. README.md +51 -1
  3. external/README.md +26 -0
  4. external/peract_bimanual/LICENSE +402 -0
  5. external/peract_bimanual/README.md +300 -0
  6. external/peract_bimanual/helpers/clip/__init__.py +0 -0
  7. external/peract_bimanual/model-card.md +47 -0
  8. external/peract_bimanual/peract_config.py +32 -0
  9. external/peract_bimanual/pyproject.toml +35 -0
  10. external/peract_bimanual/run_seed_fn.py +218 -0
  11. external/peract_bimanual/train.py +116 -0
  12. external/peract_bimanual/voxel/__init__.py +0 -0
  13. external/peract_bimanual/voxel/voxel_grid.py +252 -0
  14. external/yarr/.gitignore +13 -0
  15. external/yarr/LICENSE +201 -0
  16. external/yarr/README.md +28 -0
  17. external/yarr/logo.png +0 -0
  18. external/yarr/requirements.txt +11 -0
  19. external/yarr/setup.py +37 -0
  20. external/yarr/yarr/__init__.py +1 -0
  21. external/yarr/yarr/agents/__init__.py +0 -0
  22. external/yarr/yarr/agents/agent.py +345 -0
  23. external/yarr/yarr/envs/__init__.py +0 -0
  24. external/yarr/yarr/envs/env.py +64 -0
  25. external/yarr/yarr/envs/rlbench_env.py +332 -0
  26. external/yarr/yarr/replay_buffer/__init__.py +0 -0
  27. external/yarr/yarr/replay_buffer/prioritized_replay_buffer.py +217 -0
  28. external/yarr/yarr/replay_buffer/replay_buffer.py +71 -0
  29. external/yarr/yarr/replay_buffer/sum_tree.py +201 -0
  30. external/yarr/yarr/replay_buffer/task_uniform_replay_buffer.py +182 -0
  31. external/yarr/yarr/replay_buffer/uniform_replay_buffer.py +804 -0
  32. external/yarr/yarr/replay_buffer/wrappers/__init__.py +24 -0
  33. external/yarr/yarr/replay_buffer/wrappers/pytorch_replay_buffer.py +82 -0
  34. external/yarr/yarr/runners/__init__.py +0 -0
  35. external/yarr/yarr/runners/_env_runner.py +228 -0
  36. external/yarr/yarr/runners/_independent_env_runner.py +297 -0
  37. external/yarr/yarr/runners/env_runner.py +224 -0
  38. external/yarr/yarr/runners/independent_env_runner.py +130 -0
  39. external/yarr/yarr/runners/offline_train_runner.py +163 -0
  40. external/yarr/yarr/runners/pytorch_train_runner.py +308 -0
  41. external/yarr/yarr/runners/train_runner.py +37 -0
  42. external/yarr/yarr/utils/__init__.py +0 -0
  43. external/yarr/yarr/utils/log_writer.py +128 -0
  44. external/yarr/yarr/utils/multi_task_rollout_generator.py +65 -0
  45. external/yarr/yarr/utils/observation_type.py +10 -0
  46. external/yarr/yarr/utils/process_str.py +5 -0
  47. external/yarr/yarr/utils/rollout_generator.py +89 -0
  48. external/yarr/yarr/utils/stat_accumulator.py +192 -0
  49. external/yarr/yarr/utils/transition.py +33 -0
  50. external/yarr/yarr/utils/video_utils.py +80 -0
MANIFEST.txt CHANGED
The diff for this file is too large to render. See raw diff
 
README.md CHANGED
@@ -14,6 +14,12 @@ This pass is a label study, not a policy study. No `pi0.5` integration is includ
14
 
15
  ## What is in this upload
16
 
 
 
 
 
 
 
17
  - `code/rr_label_study/`
18
  - Core study code, including dense replay, visibility metrics, pregrasp/extraction oracles, keyframe extraction, intervention checks, and summary metric computation.
19
  - `code/scripts/`
@@ -30,6 +36,44 @@ This pass is a label study, not a policy study. No `pi0.5` integration is includ
30
  - `MANIFEST.txt`
31
  - Flat file listing of the uploaded bundle contents.
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  ## Final validated artifact
34
 
35
  The clean single-episode artifact is:
@@ -122,11 +166,13 @@ The local run used:
122
  - `markusgrotz/peract_bimanual` at `bb0232a6ba3fe116566e9568f0c7af980ed6703d`
123
  - `markusgrotz/YARR` at `6822ff78602c77878b27d4cfe759ce029c67bffb`
124
 
 
 
125
  ## Reproducing on the same hardware class
126
 
127
  1. Read `environment/dataset_notes.txt`.
128
  2. Run `environment/setup_same_hardware.sh /workspace`.
129
- 3. Source `environment/activate_rlbench_runtime.sh /workspace`.
130
  4. Run the dense study:
131
 
132
  ```bash
@@ -160,3 +206,7 @@ On a single deterministic episode, normalized time can become a degenerate perfe
160
  ## Dataset note
161
 
162
  The upstream RLBench demonstration dataset itself is not re-uploaded in this bundle. This repo contains the study code and all artifacts generated from the local run. The expected dataset path is documented in `environment/dataset_notes.txt`.
 
 
 
 
 
14
 
15
  ## What is in this upload
16
 
17
+ - `external/`
18
+ - Full local benchmark snapshots copied from the RunPod workspace.
19
+ - `external/rlbench/`: local RLBench tree used for this run.
20
+ - `external/pyrep/`: local PyRep tree used for this run.
21
+ - `external/peract_bimanual/`: local PerAct bimanual tree used for context.
22
+ - `external/yarr/`: local YARR tree used for context.
23
  - `code/rr_label_study/`
24
  - Core study code, including dense replay, visibility metrics, pregrasp/extraction oracles, keyframe extraction, intervention checks, and summary metric computation.
25
  - `code/scripts/`
 
36
  - `MANIFEST.txt`
37
  - Flat file listing of the uploaded bundle contents.
38
 
39
+ ## Repository map
40
+
41
+ Relevant entry points and where to look:
42
+
43
+ - Benchmark snapshots
44
+ - `external/README.md`
45
+ - `external/rlbench/README.md`
46
+ - `external/rlbench/rlbench/bimanual_tasks/`
47
+ - `external/rlbench/rlbench/action_modes/`
48
+ - `external/pyrep/README.md`
49
+ - `external/pyrep/pyrep/`
50
+ - `external/peract_bimanual/`
51
+ - `external/yarr/`
52
+ - Study code
53
+ - `code/rr_label_study/oven_study.py`
54
+ - `code/scripts/run_oven_label_study.py`
55
+ - `code/scripts/launch_parallel_oven_label_study.py`
56
+ - `code/scripts/run_oven_single_frame.py`
57
+ - `code/scripts/repair_oven_episode_dense.py`
58
+ - Final clean artifact
59
+ - `artifacts/results/oven_episode0_repaired_v1/episode0.dense.csv`
60
+ - `artifacts/results/oven_episode0_repaired_v1/episode0.keyframes.csv`
61
+ - `artifacts/results/oven_episode0_repaired_v1/episode0.metrics.json`
62
+ - `artifacts/results/oven_episode0_repaired_v1/summary.json`
63
+ - Intermediate/debug artifacts
64
+ - `artifacts/results/oven_episode0_full*/`
65
+ - `artifacts/results/oven_to240_*/`
66
+ - `artifacts/results/oven_episode0_independent_v1/`
67
+ - `artifacts/results/parallel_smoke_2x10/`
68
+ - Environment/repro
69
+ - `environment/system_info.txt`
70
+ - `environment/repo_revisions.txt`
71
+ - `environment/conda_env_rlbench.yml`
72
+ - `environment/pip_freeze_rlbench.txt`
73
+ - `environment/setup_same_hardware.sh`
74
+ - `environment/activate_rlbench_runtime.sh`
75
+ - `environment/dataset_notes.txt`
76
+
77
  ## Final validated artifact
78
 
79
  The clean single-episode artifact is:
 
166
  - `markusgrotz/peract_bimanual` at `bb0232a6ba3fe116566e9568f0c7af980ed6703d`
167
  - `markusgrotz/YARR` at `6822ff78602c77878b27d4cfe759ce029c67bffb`
168
 
169
+ Those exact local source snapshots are also included under `external/`.
170
+
171
  ## Reproducing on the same hardware class
172
 
173
  1. Read `environment/dataset_notes.txt`.
174
  2. Run `environment/setup_same_hardware.sh /workspace`.
175
+ 3. Source `environment/activate_rlbench_runtime.sh /workspace`.
176
  4. Run the dense study:
177
 
178
  ```bash
 
206
  ## Dataset note
207
 
208
  The upstream RLBench demonstration dataset itself is not re-uploaded in this bundle. This repo contains the study code and all artifacts generated from the local run. The expected dataset path is documented in `environment/dataset_notes.txt`.
209
+
210
+ The cloned benchmark code is included directly in this upload under `external/`.
211
+
212
+ CoppeliaSim binaries are not included in this repo. The setup helpers expect a local extraction at `/workspace/coppelia_sim`.
external/README.md ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # External Benchmark Snapshots
2
+
3
+ This directory contains the local benchmark/source trees copied from the RunPod workspace used for the study.
4
+
5
+ Included trees:
6
+
7
+ - `rlbench/`
8
+ - Source snapshot of `/workspace/rlbench`
9
+ - Upstream: `https://github.com/markusgrotz/RLBench.git`
10
+ - Commit: `8af748c51287989294e00c9c670e3330a0e35ed5`
11
+ - `pyrep/`
12
+ - Source snapshot of `/workspace/pyrep`
13
+ - Upstream: `https://github.com/markusgrotz/PyRep.git`
14
+ - Commit: `b8bd1d7a3182adcd570d001649c0849047ebf197`
15
+ - `peract_bimanual/`
16
+ - Source snapshot of `/workspace/peract_bimanual`
17
+ - Upstream: `https://github.com/markusgrotz/peract_bimanual.git`
18
+ - Commit: `bb0232a6ba3fe116566e9568f0c7af980ed6703d`
19
+ - `yarr/`
20
+ - Source snapshot of `/workspace/yarr`
21
+ - Upstream: `https://github.com/markusgrotz/YARR.git`
22
+ - Commit: `6822ff78602c77878b27d4cfe759ce029c67bffb`
23
+
24
+ These are source snapshots, not git clones with `.git/` metadata.
25
+
26
+ See `../environment/repo_revisions.txt` for the recorded origin URLs and revisions.
external/peract_bimanual/LICENSE ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
202
+ Apache License
203
+ Version 2.0, January 2004
204
+ http://www.apache.org/licenses/
205
+
206
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
207
+
208
+ 1. Definitions.
209
+
210
+ "License" shall mean the terms and conditions for use, reproduction,
211
+ and distribution as defined by Sections 1 through 9 of this document.
212
+
213
+ "Licensor" shall mean the copyright owner or entity authorized by
214
+ the copyright owner that is granting the License.
215
+
216
+ "Legal Entity" shall mean the union of the acting entity and all
217
+ other entities that control, are controlled by, or are under common
218
+ control with that entity. For the purposes of this definition,
219
+ "control" means (i) the power, direct or indirect, to cause the
220
+ direction or management of such entity, whether by contract or
221
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
222
+ outstanding shares, or (iii) beneficial ownership of such entity.
223
+
224
+ "You" (or "Your") shall mean an individual or Legal Entity
225
+ exercising permissions granted by this License.
226
+
227
+ "Source" form shall mean the preferred form for making modifications,
228
+ including but not limited to software source code, documentation
229
+ source, and configuration files.
230
+
231
+ "Object" form shall mean any form resulting from mechanical
232
+ transformation or translation of a Source form, including but
233
+ not limited to compiled object code, generated documentation,
234
+ and conversions to other media types.
235
+
236
+ "Work" shall mean the work of authorship, whether in Source or
237
+ Object form, made available under the License, as indicated by a
238
+ copyright notice that is included in or attached to the work
239
+ (an example is provided in the Appendix below).
240
+
241
+ "Derivative Works" shall mean any work, whether in Source or Object
242
+ form, that is based on (or derived from) the Work and for which the
243
+ editorial revisions, annotations, elaborations, or other modifications
244
+ represent, as a whole, an original work of authorship. For the purposes
245
+ of this License, Derivative Works shall not include works that remain
246
+ separable from, or merely link (or bind by name) to the interfaces of,
247
+ the Work and Derivative Works thereof.
248
+
249
+ "Contribution" shall mean any work of authorship, including
250
+ the original version of the Work and any modifications or additions
251
+ to that Work or Derivative Works thereof, that is intentionally
252
+ submitted to Licensor for inclusion in the Work by the copyright owner
253
+ or by an individual or Legal Entity authorized to submit on behalf of
254
+ the copyright owner. For the purposes of this definition, "submitted"
255
+ means any form of electronic, verbal, or written communication sent
256
+ to the Licensor or its representatives, including but not limited to
257
+ communication on electronic mailing lists, source code control systems,
258
+ and issue tracking systems that are managed by, or on behalf of, the
259
+ Licensor for the purpose of discussing and improving the Work, but
260
+ excluding communication that is conspicuously marked or otherwise
261
+ designated in writing by the copyright owner as "Not a Contribution."
262
+
263
+ "Contributor" shall mean Licensor and any individual or Legal Entity
264
+ on behalf of whom a Contribution has been received by Licensor and
265
+ subsequently incorporated within the Work.
266
+
267
+ 2. Grant of Copyright License. Subject to the terms and conditions of
268
+ this License, each Contributor hereby grants to You a perpetual,
269
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
270
+ copyright license to reproduce, prepare Derivative Works of,
271
+ publicly display, publicly perform, sublicense, and distribute the
272
+ Work and such Derivative Works in Source or Object form.
273
+
274
+ 3. Grant of Patent License. Subject to the terms and conditions of
275
+ this License, each Contributor hereby grants to You a perpetual,
276
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
277
+ (except as stated in this section) patent license to make, have made,
278
+ use, offer to sell, sell, import, and otherwise transfer the Work,
279
+ where such license applies only to those patent claims licensable
280
+ by such Contributor that are necessarily infringed by their
281
+ Contribution(s) alone or by combination of their Contribution(s)
282
+ with the Work to which such Contribution(s) was submitted. If You
283
+ institute patent litigation against any entity (including a
284
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
285
+ or a Contribution incorporated within the Work constitutes direct
286
+ or contributory patent infringement, then any patent licenses
287
+ granted to You under this License for that Work shall terminate
288
+ as of the date such litigation is filed.
289
+
290
+ 4. Redistribution. You may reproduce and distribute copies of the
291
+ Work or Derivative Works thereof in any medium, with or without
292
+ modifications, and in Source or Object form, provided that You
293
+ meet the following conditions:
294
+
295
+ (a) You must give any other recipients of the Work or
296
+ Derivative Works a copy of this License; and
297
+
298
+ (b) You must cause any modified files to carry prominent notices
299
+ stating that You changed the files; and
300
+
301
+ (c) You must retain, in the Source form of any Derivative Works
302
+ that You distribute, all copyright, patent, trademark, and
303
+ attribution notices from the Source form of the Work,
304
+ excluding those notices that do not pertain to any part of
305
+ the Derivative Works; and
306
+
307
+ (d) If the Work includes a "NOTICE" text file as part of its
308
+ distribution, then any Derivative Works that You distribute must
309
+ include a readable copy of the attribution notices contained
310
+ within such NOTICE file, excluding those notices that do not
311
+ pertain to any part of the Derivative Works, in at least one
312
+ of the following places: within a NOTICE text file distributed
313
+ as part of the Derivative Works; within the Source form or
314
+ documentation, if provided along with the Derivative Works; or,
315
+ within a display generated by the Derivative Works, if and
316
+ wherever such third-party notices normally appear. The contents
317
+ of the NOTICE file are for informational purposes only and
318
+ do not modify the License. You may add Your own attribution
319
+ notices within Derivative Works that You distribute, alongside
320
+ or as an addendum to the NOTICE text from the Work, provided
321
+ that such additional attribution notices cannot be construed
322
+ as modifying the License.
323
+
324
+ You may add Your own copyright statement to Your modifications and
325
+ may provide additional or different license terms and conditions
326
+ for use, reproduction, or distribution of Your modifications, or
327
+ for any such Derivative Works as a whole, provided Your use,
328
+ reproduction, and distribution of the Work otherwise complies with
329
+ the conditions stated in this License.
330
+
331
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
332
+ any Contribution intentionally submitted for inclusion in the Work
333
+ by You to the Licensor shall be under the terms and conditions of
334
+ this License, without any additional terms or conditions.
335
+ Notwithstanding the above, nothing herein shall supersede or modify
336
+ the terms of any separate license agreement you may have executed
337
+ with Licensor regarding such Contributions.
338
+
339
+ 6. Trademarks. This License does not grant permission to use the trade
340
+ names, trademarks, service marks, or product names of the Licensor,
341
+ except as required for reasonable and customary use in describing the
342
+ origin of the Work and reproducing the content of the NOTICE file.
343
+
344
+ 7. Disclaimer of Warranty. Unless required by applicable law or
345
+ agreed to in writing, Licensor provides the Work (and each
346
+ Contributor provides its Contributions) on an "AS IS" BASIS,
347
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
348
+ implied, including, without limitation, any warranties or conditions
349
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
350
+ PARTICULAR PURPOSE. You are solely responsible for determining the
351
+ appropriateness of using or redistributing the Work and assume any
352
+ risks associated with Your exercise of permissions under this License.
353
+
354
+ 8. Limitation of Liability. In no event and under no legal theory,
355
+ whether in tort (including negligence), contract, or otherwise,
356
+ unless required by applicable law (such as deliberate and grossly
357
+ negligent acts) or agreed to in writing, shall any Contributor be
358
+ liable to You for damages, including any direct, indirect, special,
359
+ incidental, or consequential damages of any character arising as a
360
+ result of this License or out of the use or inability to use the
361
+ Work (including but not limited to damages for loss of goodwill,
362
+ work stoppage, computer failure or malfunction, or any and all
363
+ other commercial damages or losses), even if such Contributor
364
+ has been advised of the possibility of such damages.
365
+
366
+ 9. Accepting Warranty or Additional Liability. While redistributing
367
+ the Work or Derivative Works thereof, You may choose to offer,
368
+ and charge a fee for, acceptance of support, warranty, indemnity,
369
+ or other liability obligations and/or rights consistent with this
370
+ License. However, in accepting such obligations, You may act only
371
+ on Your own behalf and on Your sole responsibility, not on behalf
372
+ of any other Contributor, and only if You agree to indemnify,
373
+ defend, and hold each Contributor harmless for any liability
374
+ incurred by, or claims asserted against, such Contributor by reason
375
+ of your accepting any such warranty or additional liability.
376
+
377
+ END OF TERMS AND CONDITIONS
378
+
379
+ APPENDIX: How to apply the Apache License to your work.
380
+
381
+ To apply the Apache License to your work, attach the following
382
+ boilerplate notice, with the fields enclosed by brackets "[]"
383
+ replaced with your own identifying information. (Don't include
384
+ the brackets!) The text should be enclosed in the appropriate
385
+ comment syntax for the file format. We also recommend that a
386
+ file or class name and description of purpose be included on the
387
+ same "printed page" as the copyright notice for easier
388
+ identification within third-party archives.
389
+
390
+ Copyright [yyyy] [name of copyright owner]
391
+
392
+ Licensed under the Apache License, Version 2.0 (the "License");
393
+ you may not use this file except in compliance with the License.
394
+ You may obtain a copy of the License at
395
+
396
+ http://www.apache.org/licenses/LICENSE-2.0
397
+
398
+ Unless required by applicable law or agreed to in writing, software
399
+ distributed under the License is distributed on an "AS IS" BASIS,
400
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
401
+ See the License for the specific language governing permissions and
402
+ limitations under the License.
external/peract_bimanual/README.md ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Perceiver-Actor^2: A Multi-Task Transformer for Bimanual Robotic Manipulation Tasks
2
+
3
+ [![Code style](https://img.shields.io/badge/code%20style-black-black)](https://black.readthedocs.io/en/stable/)
4
+
5
+ This work extends previous work [PerAct](https://peract.github.io) as well as
6
+ [RLBench](https://sites.google.com/view/rlbench) for bimanual manipulation
7
+ tasks.
8
+
9
+ The repository and documentation are still work in progress.
10
+
11
+ For the latest updates, see: [bimanual.github.io](https://bimanual.github.io)
12
+
13
+
14
+ ## Installation
15
+
16
+
17
+ Please see [Installation](INSTALLATION.md) for further details.
18
+
19
+ ### Prerequisites
20
+
21
+ The code PerAct^2 is built-off the [PerAct](https://peract.github.io) which itself is
22
+ built on the [ARM repository](https://github.com/stepjam/ARM) by James et al.
23
+ The prerequisites are the same as PerAct or ARM.
24
+
25
+ #### 1. Environment
26
+
27
+
28
+ Install miniconda if not already present on the current system.
29
+ You can use `scripts/install_conda.sh` for this step:
30
+
31
+ ```bash
32
+
33
+ sudo apt install curl
34
+
35
+ curl -L -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
36
+ chmod +x Miniconda3-latest-Linux-x86_64.sh
37
+ ./Miniconda3-latest-Linux-x86_64.sh
38
+
39
+ SHELL_NAME=`basename $SHELL`
40
+ eval "$($HOME/miniconda3/bin/conda shell.${SHELL_NAME} hook)"
41
+ conda init ${SHELL_NAME}
42
+ conda install mamba -c conda-forge
43
+ conda config --set auto_activate_base false
44
+ ```
45
+
46
+ Next, create the rlbench environment and install the dependencies
47
+
48
+ ```bash
49
+ conda create -n rlbench python=3.8
50
+ conda activate rlbench
51
+ conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
52
+ ```
53
+
54
+
55
+ #### 2. Dependencies
56
+
57
+ You need to setup [RLBench](https://github.com/markusgrotz/rlbench/), [Pyrep](https://github.com/markusgrotz/Pyrep/), and [YARR](https://github.com/markusgrotz/YARR/).
58
+ Please note that due to the bimanual functionallity the main repository does not work.
59
+ You can use `scripts/install_dependencies.sh` to do so.
60
+ See [Installation](INSTALLATION.md) for details.
61
+
62
+ ```bash
63
+ ./scripts/install_dependencies.sh
64
+ ```
65
+
66
+
67
+
68
+ ### Pre-Generated Datasets
69
+
70
+
71
+ Please checkout the website for [pre-generated RLBench
72
+ demonstrations](https://bimanual.github.io). If you directly use these
73
+ datasets, you don't need to run `tools/bimanual_data_generator.py` from
74
+ RLBench. Using these datasets will also help reproducibility since each scene
75
+ is randomly sampled in `data_generator_bimanual.py`.
76
+
77
+ ### Training
78
+
79
+
80
+ #### Single-GPU Training
81
+
82
+ To configure and train the model, follow these guidelines:
83
+
84
+ - **General Parameters**: You can find and modify general parameters in the `conf/config.yaml` file. This file contains overall settings for the training environment, such as the number of cameras or the the tasks to use.
85
+
86
+ - **Method-Specific Parameters**: For parameters specific to each method, refer to the corresponding files located in the `conf/method` directory. These files define configurations tailored to each method's requirements.
87
+
88
+
89
+
90
+ When training adjust the `replay.batch_size` parameter to maximize the utilization of your GPU resources. Increasing this value can improve training efficiency based on the capacity of your available hardware.
91
+ You can either modify the config files directly or you can pass parameters directly through the command line when running the training script. This allows for quick adjustments without editing configuration files:
92
+
93
+ ```bash
94
+ python train.py replay.batch_size=3 method=BIMANUAL_PERACT
95
+ ```
96
+
97
+ In this example, the command sets replay.batch_size to 3 and specifies the use of the BIMANUAL_PERACT method for training.
98
+ Another important parameter to specify the tasks is `rlbench.task_name`, which sets the overall task, and `rlbench.tasks`, which is a list of tasks used for training. Note that these can be different for evaluation.
99
+ A complete set of tasks is shown below:
100
+
101
+ ```yaml
102
+
103
+ rlbench:
104
+ task_name: multi
105
+ tasks:
106
+ - bimanual_push_box
107
+ - bimanual_lift_ball
108
+ - bimanual_dual_push_buttons
109
+ - bimanual_pick_plate
110
+ - bimanual_put_item_in_drawer
111
+ - bimanual_put_bottle_in_fridge
112
+ - bimanual_handover_item
113
+ - bimanual_pick_laptop
114
+ - bimanual_straighten_rope
115
+ - bimanual_sweep_to_dustpan
116
+ - bimanual_lift_tray
117
+ - bimanual_handover_item_easy
118
+ - bimanual_take_tray_out_of_oven
119
+ ```
120
+
121
+
122
+ #### Multi-GPU and Multi-Node Training
123
+
124
+ This repository supports multi-GPU training and distributed training across multiple nodes using [PyTorch Distributed Data Parallel (DDP)](https://pytorch.org/docs/stable/notes/ddp.html).
125
+ Follow the instructions below to configure and run training across multiple GPUs and nodes.
126
+
127
+ #### Multi-GPU Training on a Single Node
128
+
129
+ To train using multiple GPUs on a single node, set the parameter `ddp.num_devices` to the number of GPUs available. For example, if you have 4 GPUs, you can start the training process as follows:
130
+
131
+ ```bash
132
+ python train.py replay.batch_size=3 method=BIMANUAL_PERACT ddp.num_devices=4
133
+ ```
134
+
135
+ This command will utilize 4 GPUs on the current node for training. Remember to set the `replay.batch_size`, which is per GPU.
136
+
137
+ #### Multi-Node Training Across Different Nodes
138
+
139
+ If you want to perform distributed training across multiple nodes, you need to set additional parameters: ddp.master_addr and ddp.master_port. These parameters should be configured as follows:
140
+
141
+ `ddp.master_addr`: The IP address of the master node (usually the node where the training is initiated).
142
+ `ddp.master_port`: A port number to be used for communication across nodes.
143
+
144
+ Example Command:
145
+
146
+ ```bash
147
+ python train.py replay.batch_size=3 method=BIMANUAL_PERACT ddp.num_devices=4 ddp.master_addr=192.168.1.1 ddp.master_port=29500
148
+ ```
149
+
150
+ Note: Ensure that all nodes can communicate with each other through the specified IP and port, and that they have the same codebase, data access, and configurations for a successful distributed training run.
151
+
152
+
153
+
154
+ ### Evaluation
155
+
156
+
157
+ Similar to training you can find general parameters in `conf/eval.yaml` and method specific parameters in the `conf/method` directory.
158
+ For each method, you have to set the execution mode in RLBench. For bimanual agents such as `BIMANUAL_PERACT` or `PERACT_BC` this is:
159
+
160
+ ```yaml
161
+ rlbench:
162
+ gripper_mode: 'BimanualDiscrete'
163
+ arm_action_mode: 'BimanualEndEffectorPoseViaPlanning'
164
+ action_mode: 'BimanualMoveArmThenGripper'
165
+ ```
166
+
167
+
168
+ To generate videos of the current evaluation you can set `cinematic_recorder.enabled` to `True`.
169
+ It is recommended during evalution to disable the recorder, i.e. `cinematic_recorder.enabled=False`, as rendering the video increases the total evaluation time.
170
+
171
+
172
+ ## Acknowledgements
173
+
174
+ This repository uses code from the following open-source projects:
175
+
176
+ #### ARM
177
+ Original: [https://github.com/stepjam/ARM](https://github.com/stepjam/ARM)
178
+ License: [ARM License](https://github.com/stepjam/ARM/LICENSE)
179
+ Changes: Data loading was modified for PerAct. Voxelization code was modified for DDP training.
180
+
181
+ #### PerceiverIO
182
+ Original: [https://github.com/lucidrains/perceiver-pytorch](https://github.com/lucidrains/perceiver-pytorch)
183
+ License: [MIT](https://github.com/lucidrains/perceiver-pytorch/blob/main/LICENSE)
184
+ Changes: PerceiverIO adapted for 6-DoF manipulation.
185
+
186
+ #### ViT
187
+ Original: [https://github.com/lucidrains/vit-pytorch](https://github.com/lucidrains/vit-pytorch)
188
+ License: [MIT](https://github.com/lucidrains/vit-pytorch/blob/main/LICENSE)
189
+ Changes: ViT adapted for baseline.
190
+
191
+ #### LAMB Optimizer
192
+ Original: [https://github.com/cybertronai/pytorch-lamb](https://github.com/cybertronai/pytorch-lamb)
193
+ License: [MIT](https://github.com/cybertronai/pytorch-lamb/blob/master/LICENSE)
194
+ Changes: None.
195
+
196
+ #### OpenAI CLIP
197
+ Original: [https://github.com/openai/CLIP](https://github.com/openai/CLIP)
198
+ License: [MIT](https://github.com/openai/CLIP/blob/main/LICENSE)
199
+ Changes: Minor modifications to extract token and sentence features.
200
+
201
+ Thanks for open-sourcing!
202
+
203
+ ## Licenses
204
+ - [PerAct License (Apache 2.0)](LICENSE) - Perceiver-Actor Transformer
205
+ - [ARM License](ARM_LICENSE) - Voxelization and Data Preprocessing
206
+ - [YARR Licence (Apache 2.0)](https://github.com/stepjam/YARR/blob/main/LICENSE)
207
+ - [RLBench Licence](https://github.com/stepjam/RLBench/blob/master/LICENSE)
208
+ - [PyRep License (MIT)](https://github.com/stepjam/PyRep/blob/master/LICENSE)
209
+ - [Perceiver PyTorch License (MIT)](https://github.com/lucidrains/perceiver-pytorch/blob/main/LICENSE)
210
+ - [LAMB License (MIT)](https://github.com/cybertronai/pytorch-lamb/blob/master/LICENSE)
211
+ - [CLIP License (MIT)](https://github.com/openai/CLIP/blob/main/LICENSE)
212
+
213
+ ## Release Notes
214
+
215
+ **Update 2025-02-20**
216
+
217
+ - Update instructions
218
+ - Add missing dependency for install script
219
+ - Add docker build file
220
+
221
+ **Update 2024-11-06**
222
+
223
+ - Regenerat and repack dataset. Closes #13. Task names are now more consistent. Dataset now includes waypoint information.
224
+
225
+
226
+ **Update 2024-10-17**
227
+
228
+ - Update Readme
229
+
230
+
231
+
232
+ **Update 2024-07-10**
233
+
234
+ - Initial release
235
+
236
+
237
+ ## Citations
238
+
239
+
240
+ **PerAct^2**
241
+ ```
242
+ @misc{grotz2024peract2,
243
+ title={PerAct2: Benchmarking and Learning for Robotic Bimanual Manipulation Tasks},
244
+ author={Markus Grotz and Mohit Shridhar and Tamim Asfour and Dieter Fox},
245
+ year={2024},
246
+ eprint={2407.00278},
247
+ archivePrefix={arXiv},
248
+ primaryClass={cs.RO},
249
+ url={https://arxiv.org/abs/2407.00278},
250
+ }
251
+ ```
252
+
253
+ **PerAct**
254
+ ```
255
+ @inproceedings{shridhar2022peract,
256
+ title = {Perceiver-Actor: A Multi-Task Transformer for Robotic Manipulation},
257
+ author = {Shridhar, Mohit and Manuelli, Lucas and Fox, Dieter},
258
+ booktitle = {Proceedings of the 6th Conference on Robot Learning (CoRL)},
259
+ year = {2022},
260
+ }
261
+ ```
262
+
263
+ **C2FARM**
264
+ ```
265
+ @inproceedings{james2022coarse,
266
+ title={Coarse-to-fine q-attention: Efficient learning for visual robotic manipulation via discretisation},
267
+ author={James, Stephen and Wada, Kentaro and Laidlow, Tristan and Davison, Andrew J},
268
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
269
+ pages={13739--13748},
270
+ year={2022}
271
+ }
272
+ ```
273
+
274
+ **PerceiverIO**
275
+ ```
276
+ @article{jaegle2021perceiver,
277
+ title={Perceiver io: A general architecture for structured inputs \& outputs},
278
+ author={Jaegle, Andrew and Borgeaud, Sebastian and Alayrac, Jean-Baptiste and Doersch, Carl and Ionescu, Catalin and Ding, David and Koppula, Skanda and Zoran, Daniel and Brock, Andrew and Shelhamer, Evan and others},
279
+ journal={arXiv preprint arXiv:2107.14795},
280
+ year={2021}
281
+ }
282
+ ```
283
+
284
+ **RLBench**
285
+ ```
286
+ @article{james2020rlbench,
287
+ title={Rlbench: The robot learning benchmark \& learning environment},
288
+ author={James, Stephen and Ma, Zicong and Arrojo, David Rovick and Davison, Andrew J},
289
+ journal={IEEE Robotics and Automation Letters},
290
+ volume={5},
291
+ number={2},
292
+ pages={3019--3026},
293
+ year={2020},
294
+ publisher={IEEE}
295
+ }
296
+ ```
297
+
298
+ ## Questions or Issues?
299
+
300
+ Please file an issue with the issue tracker.
external/peract_bimanual/helpers/clip/__init__.py ADDED
File without changes
external/peract_bimanual/model-card.md ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Model Card: Perceiver-Actor
2
+
3
+ Following [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993) and [Lessons from Archives (Jo & Gebru)](https://arxiv.org/pdf/1912.10389.pdf) we provide additional information on PerAct.
4
+
5
+ ## Model Details
6
+
7
+
8
+ ### Overview
9
+ - Developed by Shridhar et al. at University of Washington and NVIDIA. PerAct is an end-to-end behavior cloning agent that learns to perform a wide variety of language-conditioned manipulation tasks. PerAct uses a Transformer that exploits the 3D structure of _voxel patches_ to learn policies with just a few demonstrations per task.
10
+ - Architecture: Transformer trained from scratch with end-to-end supervised learning.
11
+ - Trained for 6-DoF manipulation tasks with objects that appear in tabletop scenes.
12
+
13
+ ### Model Date
14
+
15
+ Nov 2022
16
+
17
+ ### Documents
18
+
19
+ - [PerAct Paper](https://peract.github.io/paper/peract_corl2022.pdf)
20
+ - [PerceiverIO Paper](https://arxiv.org/abs/2107.14795)
21
+ - [C2FARM Paper](https://arxiv.org/abs/2106.12534)
22
+
23
+
24
+ ## Model Use
25
+
26
+ - **Primary intended use case**: PerAct is intended for robotic manipulation research. We hope the benchmark and pre-trained models will enable researchers to study the capabilities of Transformers for end-to-end 6-DoF Manipulation. Specifically, we hope the setup serves a reproducible framework for evaluating robustness and scaling capabilities of manipulation agents.
27
+ - **Primary intended users**: Robotics researchers.
28
+ - **Out-of-scope use cases**: Deployed use cases in real-world autonomous systems without human supervision during test-time is currently out-of-scope. Use cases that involve manipulating novel objects and observations with people, are not recommended for safety-critical systems. The agent is also intended to be trained and evaluated with English language instructions.
29
+
30
+ ## Data
31
+
32
+ - Pre-training Data for CLIP's language encoder: See [OpenAI's Model Card](https://github.com/openai/CLIP/blob/main/model-card.md#data) for full details. **Note:** We do not use CLIP's vision encoders for any agents in the repo.
33
+ - Manipulation Data for PerAct: The agent was trained with expert demonstrations. In simulation, we use oracle agents and in real-world we use human demonstrations. Since the agent is used in few-shot settings with very limited data, the agent might exploit intended and un-intented biases in the training demonstrations. Currently, these biases are limited to just objects that appear on tabletops.
34
+
35
+
36
+ ## Limitations
37
+
38
+ - Depends on a sampling-based motion planner.
39
+ - Hard to extend to dexterous and continuous manipulation tasks.
40
+ - Lacks memory to solve tasks with ordering and history-based sequencing.
41
+ - Exploits biases in training demonstrations.
42
+ - Needs good hand-eye calibration.
43
+ - Doesn't generalize to novel objects.
44
+ - Struggles with grounding complex spatial relationships.
45
+ - Does not predict task completion.
46
+
47
+ See Appendix L in the [paper](https://peract.github.io/paper/peract_corl2022.pdf) for an extended discussion.
external/peract_bimanual/peract_config.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ System configuration for peract
3
+ """
4
+ import os
5
+ import logging
6
+
7
+ import torch.multiprocessing as mp
8
+
9
+
10
+ def config_logging(logging_level=logging.INFO, reset=False):
11
+ if reset:
12
+ root = logging.getLogger()
13
+ list(map(root.removeHandler, root.handlers))
14
+ list(map(root.removeFilter, root.filters))
15
+
16
+ from rich.logging import RichHandler
17
+
18
+ logging.basicConfig(level=logging_level, handlers=[RichHandler()])
19
+
20
+
21
+ def on_init():
22
+ config_logging(logging.INFO)
23
+
24
+ logging.debug("Configuring environment.")
25
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
26
+ mp.set_start_method("spawn", force=True)
27
+ mp.set_sharing_strategy("file_system")
28
+
29
+
30
+ def on_config(cfg):
31
+ os.environ["MASTER_ADDR"] = str(cfg.ddp.master_addr)
32
+ os.environ["MASTER_PORT"] = str(cfg.ddp.master_port)
external/peract_bimanual/pyproject.toml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.poetry]
2
+ name = "peract_bimanual"
3
+ version = "0.0.1"
4
+ description = "A perceiver actor framework for bimanual manipulation tasks"
5
+ authors = [ "Markus Grotz <grotz@uw.edu>",
6
+ "Mohit Shridhar <mshr@cs.washington.edu>"]
7
+ packages = [{include = "agents"}, {include = "helpers"}, {include = "voxel"}]
8
+
9
+
10
+ readme = "README.md"
11
+ classifiers = [
12
+ "Programming Language :: Python :: 3",
13
+ "Framework :: Robot Framework "
14
+ ]
15
+
16
+ [tool.poetry.dependencies]
17
+ python = ">=3.8,<4.0"
18
+ einops = "0.3.2"
19
+ ftfy = "^6.1.1"
20
+ hydra-core = ">=1.0.5"
21
+ matplotlib = "^3.7.1"
22
+ pandas = "1.4.1"
23
+ regex = "^2023.6.3"
24
+ tensorboard = "^2.13.0"
25
+ perceiver-pytorch = "^0.8.7"
26
+ transformers = "^4.21"
27
+
28
+
29
+
30
+ [tool.poetry.extras]
31
+ docs = ["sphinx"]
32
+
33
+ [build-system]
34
+ requires = ["setuptools", "wheel", "poetry-core>=1.0.0"]
35
+ build-backend = "poetry.core.masonry.api"
external/peract_bimanual/run_seed_fn.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+ import gc
4
+ from typing import List
5
+
6
+ import hydra
7
+ import numpy as np
8
+ import torch
9
+ from omegaconf import DictConfig
10
+
11
+ from rlbench import CameraConfig, ObservationConfig
12
+ from yarr.replay_buffer.wrappers.pytorch_replay_buffer import PyTorchReplayBuffer
13
+ from yarr.runners.offline_train_runner import OfflineTrainRunner
14
+ from yarr.utils.stat_accumulator import SimpleAccumulator
15
+
16
+ from helpers.custom_rlbench_env import CustomRLBenchEnv, CustomMultiTaskRLBenchEnv
17
+ import torch.distributed as dist
18
+
19
+ from agents import agent_factory
20
+ from agents import replay_utils
21
+
22
+ import peract_config
23
+ from functools import partial
24
+
25
+
26
+ def run_seed(
27
+ rank,
28
+ cfg: DictConfig,
29
+ obs_config: ObservationConfig,
30
+ seed,
31
+ world_size,
32
+ ) -> None:
33
+ peract_config.config_logging()
34
+
35
+ dist.init_process_group("gloo", rank=rank, world_size=world_size)
36
+
37
+ tasks = cfg.rlbench.tasks
38
+ cams = cfg.rlbench.cameras
39
+
40
+ task_folder = "multi" if len(tasks) > 1 else tasks[0]
41
+ replay_path = os.path.join(
42
+ cfg.replay.path, task_folder, cfg.method.name, "seed%d" % seed
43
+ )
44
+
45
+ agent = agent_factory.create_agent(cfg)
46
+
47
+ if not agent:
48
+ print("Unable to create agent")
49
+ return
50
+
51
+ if cfg.method.name == "ARM":
52
+ raise NotImplementedError("ARM is not supported yet")
53
+ elif cfg.method.name == "BC_LANG":
54
+ from agents.baselines import bc_lang
55
+
56
+ assert cfg.ddp.num_devices == 1, "BC_LANG only supports single GPU training"
57
+ replay_buffer = bc_lang.launch_utils.create_replay(
58
+ cfg.replay.batch_size,
59
+ cfg.replay.timesteps,
60
+ cfg.replay.prioritisation,
61
+ cfg.replay.task_uniform,
62
+ replay_path if cfg.replay.use_disk else None,
63
+ cams,
64
+ cfg.rlbench.camera_resolution,
65
+ )
66
+
67
+ bc_lang.launch_utils.fill_multi_task_replay(
68
+ cfg,
69
+ obs_config,
70
+ rank,
71
+ replay_buffer,
72
+ tasks,
73
+ cfg.rlbench.demos,
74
+ cfg.method.demo_augmentation,
75
+ cfg.method.demo_augmentation_every_n,
76
+ cams,
77
+ )
78
+
79
+ elif cfg.method.name == "VIT_BC_LANG":
80
+ from agents.baselines import vit_bc_lang
81
+
82
+ assert cfg.ddp.num_devices == 1, "VIT_BC_LANG only supports single GPU training"
83
+ replay_buffer = vit_bc_lang.launch_utils.create_replay(
84
+ cfg.replay.batch_size,
85
+ cfg.replay.timesteps,
86
+ cfg.replay.prioritisation,
87
+ cfg.replay.task_uniform,
88
+ replay_path if cfg.replay.use_disk else None,
89
+ cams,
90
+ cfg.rlbench.camera_resolution,
91
+ )
92
+
93
+ vit_bc_lang.launch_utils.fill_multi_task_replay(
94
+ cfg,
95
+ obs_config,
96
+ rank,
97
+ replay_buffer,
98
+ tasks,
99
+ cfg.rlbench.demos,
100
+ cfg.method.demo_augmentation,
101
+ cfg.method.demo_augmentation_every_n,
102
+ cams,
103
+ )
104
+
105
+ elif cfg.method.name.startswith("ACT_BC_LANG"):
106
+ from agents import act_bc_lang
107
+
108
+ assert cfg.ddp.num_devices == 1, "ACT_BC_LANG only supports single GPU training"
109
+ replay_buffer = act_bc_lang.launch_utils.create_replay(
110
+ cfg.replay.batch_size,
111
+ cfg.replay.timesteps,
112
+ cfg.replay.prioritisation,
113
+ cfg.replay.task_uniform,
114
+ replay_path if cfg.replay.use_disk else None,
115
+ cams,
116
+ cfg.rlbench.camera_resolution,
117
+ replay_size=3e5,
118
+ prev_action_horizon=cfg.method.prev_action_horizon,
119
+ next_action_horizon=cfg.method.next_action_horizon,
120
+ )
121
+
122
+ act_bc_lang.launch_utils.fill_multi_task_replay(
123
+ cfg,
124
+ obs_config,
125
+ rank,
126
+ replay_buffer,
127
+ tasks,
128
+ cfg.rlbench.demos,
129
+ cfg.method.demo_augmentation,
130
+ cfg.method.demo_augmentation_every_n,
131
+ cams,
132
+ )
133
+
134
+ elif cfg.method.name == "C2FARM_LINGUNET_BC":
135
+ from agents import c2farm_lingunet_bc
136
+
137
+ replay_buffer = c2farm_lingunet_bc.launch_utils.create_replay(
138
+ cfg.replay.batch_size,
139
+ cfg.replay.timesteps,
140
+ cfg.replay.prioritisation,
141
+ cfg.replay.task_uniform,
142
+ replay_path if cfg.replay.use_disk else None,
143
+ cams,
144
+ cfg.method.voxel_sizes,
145
+ cfg.rlbench.camera_resolution,
146
+ )
147
+
148
+ c2farm_lingunet_bc.launch_utils.fill_multi_task_replay(
149
+ cfg,
150
+ obs_config,
151
+ rank,
152
+ replay_buffer,
153
+ tasks,
154
+ cfg.rlbench.demos,
155
+ cfg.method.demo_augmentation,
156
+ cfg.method.demo_augmentation_every_n,
157
+ cams,
158
+ cfg.rlbench.scene_bounds,
159
+ cfg.method.voxel_sizes,
160
+ cfg.method.bounds_offset,
161
+ cfg.method.rotation_resolution,
162
+ cfg.method.crop_augmentation,
163
+ keypoint_method=cfg.method.keypoint_method,
164
+ )
165
+
166
+ elif (
167
+ cfg.method.name.startswith("BIMANUAL_PERACT")
168
+ or cfg.method.name.startswith("RVT")
169
+ or cfg.method.name.startswith("PERACT_BC")
170
+ ):
171
+ replay_buffer = replay_utils.create_replay(cfg, replay_path)
172
+
173
+ replay_utils.fill_multi_task_replay(cfg, obs_config, rank, replay_buffer, tasks)
174
+
175
+ elif cfg.method.name == "PERACT_RL":
176
+ raise NotImplementedError("PERACT_RL is not supported yet")
177
+
178
+ else:
179
+ raise ValueError("Method %s does not exists." % cfg.method.name)
180
+
181
+ wrapped_replay = PyTorchReplayBuffer(
182
+ replay_buffer, num_workers=cfg.framework.num_workers
183
+ )
184
+ stat_accum = SimpleAccumulator(eval_video_fps=30)
185
+
186
+ cwd = os.getcwd()
187
+ weightsdir = os.path.join(cwd, "seed%d" % seed, "weights")
188
+ logdir = os.path.join(cwd, "seed%d" % seed)
189
+
190
+ train_runner = OfflineTrainRunner(
191
+ agent=agent,
192
+ wrapped_replay_buffer=wrapped_replay,
193
+ train_device=rank,
194
+ stat_accumulator=stat_accum,
195
+ iterations=cfg.framework.training_iterations,
196
+ logdir=logdir,
197
+ logging_level=cfg.framework.logging_level,
198
+ log_freq=cfg.framework.log_freq,
199
+ weightsdir=weightsdir,
200
+ num_weights_to_keep=cfg.framework.num_weights_to_keep,
201
+ save_freq=cfg.framework.save_freq,
202
+ tensorboard_logging=cfg.framework.tensorboard_logging,
203
+ csv_logging=cfg.framework.csv_logging,
204
+ load_existing_weights=cfg.framework.load_existing_weights,
205
+ rank=rank,
206
+ world_size=world_size,
207
+ )
208
+
209
+ train_runner._on_thread_start = partial(
210
+ peract_config.config_logging, cfg.framework.logging_level
211
+ )
212
+
213
+ train_runner.start()
214
+
215
+ del train_runner
216
+ del agent
217
+ gc.collect()
218
+ torch.cuda.empty_cache()
external/peract_bimanual/train.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ import logging
3
+ import os
4
+ import sys
5
+ from datetime import datetime
6
+
7
+ import peract_config
8
+
9
+ import hydra
10
+ from omegaconf import DictConfig, OmegaConf, ListConfig
11
+
12
+ import run_seed_fn
13
+ from helpers.observation_utils import create_obs_config
14
+
15
+ import torch.multiprocessing as mp
16
+
17
+
18
+ @hydra.main(config_name="config", config_path="conf")
19
+ def main(cfg: DictConfig) -> None:
20
+ cfg_yaml = OmegaConf.to_yaml(cfg)
21
+ logging.info("\n" + cfg_yaml)
22
+
23
+ peract_config.on_config(cfg)
24
+
25
+ cfg.rlbench.cameras = (
26
+ cfg.rlbench.cameras
27
+ if isinstance(cfg.rlbench.cameras, ListConfig)
28
+ else [cfg.rlbench.cameras]
29
+ )
30
+
31
+ # sanity check if rgb is not used as camera name
32
+ for camera_name in cfg.rlbench.cameras:
33
+ assert "rgb" not in camera_name
34
+
35
+ obs_config = create_obs_config(
36
+ cfg.rlbench.cameras, cfg.rlbench.camera_resolution, cfg.method.name
37
+ )
38
+
39
+ cwd = os.getcwd()
40
+ logging.info("CWD:" + os.getcwd())
41
+
42
+ if cfg.framework.start_seed >= 0:
43
+ # seed specified
44
+ start_seed = cfg.framework.start_seed
45
+ elif (
46
+ cfg.framework.start_seed == -1
47
+ and len(list(filter(lambda x: "seed" in x, os.listdir(cwd)))) > 0
48
+ ):
49
+ # unspecified seed; use largest existing seed plus one
50
+ largest_seed = max(
51
+ [
52
+ int(n.replace("seed", ""))
53
+ for n in list(filter(lambda x: "seed" in x, os.listdir(cwd)))
54
+ ]
55
+ )
56
+ start_seed = largest_seed + 1
57
+ else:
58
+ # start with seed 0
59
+ start_seed = 0
60
+
61
+ seed_folder = os.path.join(os.getcwd(), "seed%d" % start_seed)
62
+ os.makedirs(seed_folder, exist_ok=True)
63
+
64
+ start_time = datetime.now()
65
+ with open(os.path.join(seed_folder, "config.yaml"), "w") as f:
66
+ f.write(cfg_yaml)
67
+
68
+ # check if previous checkpoints already exceed the number of desired training iterations
69
+ # if so, exit the script
70
+ latest_weight = 0
71
+ weights_folder = os.path.join(seed_folder, "weights")
72
+ if os.path.isdir(weights_folder) and len(os.listdir(weights_folder)) > 0:
73
+ weights = os.listdir(weights_folder)
74
+ latest_weight = sorted(map(int, weights))[-1]
75
+ if latest_weight >= cfg.framework.training_iterations:
76
+ logging.info(
77
+ "Agent was already trained for %d iterations. Exiting." % latest_weight
78
+ )
79
+ sys.exit(0)
80
+
81
+ with open(os.path.join(seed_folder, "training.log"), "a") as f:
82
+ f.write(
83
+ f"# Starting training from weights: {latest_weight} to {cfg.framework.training_iterations}"
84
+ )
85
+ f.write(f"# Training started on: {start_time.isoformat()}")
86
+ f.write(os.linesep)
87
+
88
+ # run train jobs with multiple seeds (sequentially)
89
+ for seed in range(start_seed, start_seed + cfg.framework.seeds):
90
+ logging.info("Starting seed %d." % seed)
91
+
92
+ world_size = cfg.ddp.num_devices
93
+ mp.spawn(
94
+ run_seed_fn.run_seed,
95
+ args=(
96
+ cfg,
97
+ obs_config,
98
+ seed,
99
+ world_size,
100
+ ),
101
+ nprocs=world_size,
102
+ join=True,
103
+ )
104
+
105
+ end_time = datetime.now()
106
+ duration = end_time - start_time
107
+ with open(os.path.join(seed_folder, "training.log"), "a") as f:
108
+ f.write(f"# Training finished on: {end_time.isoformat()}")
109
+ f.write(f"# Took {duration.total_seconds()}")
110
+ f.write(os.linesep)
111
+ f.write(os.linesep)
112
+
113
+
114
+ if __name__ == "__main__":
115
+ peract_config.on_init()
116
+ main()
external/peract_bimanual/voxel/__init__.py ADDED
File without changes
external/peract_bimanual/voxel/voxel_grid.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Voxelizer modified from ARM for DDP training
2
+ # Source: https://github.com/stepjam/ARM
3
+ # License: https://github.com/stepjam/ARM/LICENSE
4
+
5
+ from functools import reduce
6
+ from operator import mul
7
+
8
+ import torch
9
+ from torch import nn
10
+
11
+ MIN_DENOMINATOR = 1e-12
12
+ INCLUDE_PER_VOXEL_COORD = False
13
+
14
+
15
class VoxelGrid(nn.Module):
    """Voxelises batches of 3-D point coordinates (plus optional per-point
    features) into a fixed-size, axis-aligned voxel grid.

    The grid is built one voxel wider on every side (+2 per axis) and the
    border is cropped at the end, so points that clamp to the boundary land
    in a throwaway shell rather than polluting edge voxels.

    Relies on module-level constants ``MIN_DENOMINATOR`` (epsilon guarding
    divisions) and ``INCLUDE_PER_VOXEL_COORD`` defined earlier in the file,
    plus ``reduce``/``mul`` from functools/operator.
    """

    def __init__(
        self,
        coord_bounds,
        voxel_size: int,
        device,
        batch_size,
        feature_size,  # e.g. rgb or image features
        max_num_coords: int,
    ):
        """Precompute all shape/index/scale tensors as buffers.

        coord_bounds: 6 values [x_min, y_min, z_min, x_max, y_max, z_max]
            (sliced as [0:3]/[3:6] below).
        voxel_size: number of voxels per axis of the *output* grid.
        feature_size: per-point feature channels appended to xyz; the stored
            per-voxel feature is xyz + features + occupancy (4 + feature_size).
        max_num_coords: maximum number of input points per batch element.
        """
        super(VoxelGrid, self).__init__()
        self._device = device
        self._voxel_size = voxel_size
        self._voxel_shape = [voxel_size] * 3
        self._voxel_d = float(self._voxel_shape[-1])
        # xyz (3) + occupancy (1) + per-point features
        self._voxel_feature_size = 4 + feature_size
        self._voxel_shape_spec = (
            torch.tensor(
                self._voxel_shape,
            ).unsqueeze(0)
            + 2
        )  # +2 because we crop the edges.
        self._coord_bounds = torch.tensor(
            coord_bounds,
            dtype=torch.float,
        ).unsqueeze(0)
        max_dims = self._voxel_shape_spec[0]
        # Full (uncropped) scatter-target shape: [B, X+2, Y+2, Z+2, 4+F]
        self._total_dims_list = torch.cat(
            [
                torch.tensor(
                    [batch_size],
                ),
                max_dims,
                torch.tensor(
                    [4 + feature_size],
                ),
            ],
            -1,
        ).tolist()

        # Constant ones appended per point as the occupancy channel.
        self.register_buffer(
            "_ones_max_coords", torch.ones((batch_size, max_num_coords, 1))
        )
        self._num_coords = max_num_coords

        shape = self._total_dims_list
        # Row-major strides of `shape`, used to flatten N-D scatter indices.
        result_dim_sizes = torch.tensor(
            [reduce(mul, shape[i + 1 :], 1) for i in range(len(shape) - 1)] + [1],
        )
        self.register_buffer("_result_dim_sizes", result_dim_sizes)
        flat_result_size = reduce(mul, shape, 1)

        self._initial_val = torch.tensor(0, dtype=torch.float)
        # Flat zero-initialised template for the scatter output.
        flat_output = (
            torch.ones(flat_result_size, dtype=torch.float) * self._initial_val
        )
        self.register_buffer("_flat_output", flat_output)

        # 0..(4+F-1): per-channel offsets added to each flat voxel index.
        self.register_buffer("_arange_to_max_coords", torch.arange(4 + feature_size))
        self._flat_zeros = torch.zeros(flat_result_size, dtype=torch.float)

        self._const_1 = torch.tensor(
            1.0,
        )
        self._batch_size = batch_size

        # Coordinate Bounds:
        bb_mins = self._coord_bounds[..., 0:3]
        self.register_buffer("_bb_mins", bb_mins)
        bb_maxs = self._coord_bounds[..., 3:6]
        bb_ranges = bb_maxs - bb_mins
        # get voxel dimensions. 'DIMS' mode
        self._dims = dims = self._voxel_shape_spec.int()
        dims_orig = self._voxel_shape_spec.int() - 2
        self.register_buffer("_dims_orig", dims_orig)

        # self._dims_m_one = (dims - 1).int()
        dims_m_one = (dims - 1).int()
        self.register_buffer("_dims_m_one", dims_m_one)

        # BS x 1 x 3
        # Metric size of one voxel per axis (epsilon avoids div-by-zero).
        res = bb_ranges / (dims_orig.float() + MIN_DENOMINATOR)
        self._res_minis_2 = bb_ranges / (dims.float() - 2 + MIN_DENOMINATOR)
        self.register_buffer("_res", res)

        voxel_indicy_denmominator = res + MIN_DENOMINATOR
        self.register_buffer("_voxel_indicy_denmominator", voxel_indicy_denmominator)

        self.register_buffer("_dims_m_one_zeros", torch.zeros_like(dims_m_one))

        # Batch index column to prepend to each point's 3-D voxel index.
        batch_indices = torch.arange(self._batch_size, dtype=torch.int).view(
            self._batch_size, 1, 1
        )
        self.register_buffer(
            "_tiled_batch_indices", batch_indices.repeat([1, self._num_coords, 1])
        )

        # Dense [B, w, w, w, 3] grid of integer voxel coordinates (w = size+2),
        # used to attach voxel-centre positions / normalised indices later.
        w = self._voxel_shape[0] + 2
        arange = torch.arange(
            0,
            w,
            dtype=torch.float,
        )
        index_grid = (
            torch.cat(
                [
                    arange.view(w, 1, 1, 1).repeat([1, w, w, 1]),
                    arange.view(1, w, 1, 1).repeat([w, 1, w, 1]),
                    arange.view(1, 1, w, 1).repeat([w, w, 1, 1]),
                ],
                dim=-1,
            )
            .unsqueeze(0)
            .repeat([self._batch_size, 1, 1, 1, 1])
        )
        self.register_buffer("_index_grid", index_grid)

    def _broadcast(self, src: torch.Tensor, other: torch.Tensor, dim: int):
        """Expand 1-D `src` so it broadcasts against `other` along `dim`.

        Mirrors the broadcast helper used by scatter libraries: unsqueeze
        leading dims up to `dim`, then trailing dims, then expand.
        """
        if dim < 0:
            dim = other.dim() + dim
        if src.dim() == 1:
            for _ in range(0, dim):
                src = src.unsqueeze(0)
        for _ in range(src.dim(), other.dim()):
            src = src.unsqueeze(-1)
        src = src.expand_as(other)
        return src

    def _scatter_mean(
        self, src: torch.Tensor, index: torch.Tensor, out: torch.Tensor, dim: int = -1
    ):
        """Scatter `src` into `out` at `index`, averaging duplicate indices.

        Implementation: scatter_add the values, scatter_add ones to count
        contributions per slot, clamp counts to >=1 (so empty slots are left
        untouched by the divide), then divide in place. NOTE: mutates `out`.
        """
        out = out.scatter_add_(dim, index, src)

        # Resolve the dimension `index` actually indexes over.
        index_dim = dim
        if index_dim < 0:
            index_dim = index_dim + src.dim()
        if index.dim() <= index_dim:
            index_dim = index.dim() - 1

        ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
        out_count = torch.zeros(out.size(), dtype=out.dtype, device=out.device)
        out_count = out_count.scatter_add_(index_dim, index, ones)
        out_count.clamp_(1)  # avoid dividing untouched slots by zero
        count = self._broadcast(out_count, out, dim)
        if torch.is_floating_point(out):
            out.true_divide_(count)
        else:
            out.floor_divide_(count)
        return out

    def _scatter_nd(self, indices, updates):
        """TensorFlow-style ``scatter_nd`` with mean-reduction of duplicates.

        `indices`: [..., num_index_dims] integer indices into the voxel grid
        (batch + xyz). `updates`: matching rows of per-voxel features. The
        N-D indices are flattened via the precomputed strides, expanded with
        a per-channel offset, and mean-scattered into a flat buffer which is
        finally reshaped to [B, X+2, Y+2, Z+2, 4+F].
        """
        indices_shape = indices.shape
        num_index_dims = indices_shape[-1]
        flat_updates = updates.view((-1,))
        indices_scales = self._result_dim_sizes[0:num_index_dims].view(
            [1] * (len(indices_shape) - 1) + [num_index_dims]
        )
        # Flat (scalar) index of each target voxel, repeated per channel.
        indices_for_flat_tiled = (
            ((indices * indices_scales).sum(dim=-1, keepdims=True))
            .view(-1, 1)
            .repeat(*[1, self._voxel_feature_size])
        )

        # Channel offsets 0..(4+F-1) added to every voxel's flat index.
        implicit_indices = (
            self._arange_to_max_coords[: self._voxel_feature_size]
            .unsqueeze(0)
            .repeat(*[indices_for_flat_tiled.shape[0], 1])
        )
        indices_for_flat = indices_for_flat_tiled + implicit_indices
        flat_indices_for_flat = indices_for_flat.view((-1,)).long()

        flat_scatter = self._scatter_mean(
            flat_updates, flat_indices_for_flat, out=torch.zeros_like(self._flat_output)
        )
        return flat_scatter.view(self._total_dims_list)

    def coords_to_bounding_voxel_grid(
        self, coords, coord_features=None, coord_bounds=None
    ):
        """Voxelise `coords` (and optional `coord_features`) into the grid.

        coords: [B, N, 3] points — inferred from the xyz slicing/clamping
            below; confirm against callers.
        coord_features: optional [B, N, F] per-point features concatenated
            onto xyz before scattering.
        coord_bounds: optional per-call override of the constructor bounds.

        Returns [B, X, Y, Z, 3 + F (+3) + 1]: mean xyz (+features) per voxel,
        normalised voxel-index channels, and a binary occupancy channel last.
        """
        voxel_indicy_denmominator = self._voxel_indicy_denmominator
        res, bb_mins = self._res, self._bb_mins
        if coord_bounds is not None:
            # Recompute resolution/denominator for the per-call bounds.
            bb_mins = coord_bounds[..., 0:3]
            bb_maxs = coord_bounds[..., 3:6]
            bb_ranges = bb_maxs - bb_mins
            res = bb_ranges / (self._dims_orig.float() + MIN_DENOMINATOR)
            voxel_indicy_denmominator = res + MIN_DENOMINATOR

        bb_mins_shifted = bb_mins - res  # shift back by one
        # Integer voxel index per point, clamped into [0, dims-1] so
        # out-of-bounds points collect in the (later cropped) border shell.
        floor = torch.floor(
            (coords - bb_mins_shifted.unsqueeze(1))
            / voxel_indicy_denmominator.unsqueeze(1)
        ).int()
        voxel_indices = torch.min(floor, self._dims_m_one)
        voxel_indices = torch.max(voxel_indices, self._dims_m_one_zeros)

        # BS x NC x 3
        voxel_values = coords
        if coord_features is not None:
            voxel_values = torch.cat([voxel_values, coord_features], -1)

        _, num_coords, _ = voxel_indices.shape
        # BS x N x (num_batch_dims + 2)
        all_indices = torch.cat(
            [self._tiled_batch_indices[:, :num_coords], voxel_indices], -1
        )

        # BS x N x 4
        # Append a constant 1 per point: after mean-scatter this becomes the
        # occupancy signal (non-zero iff any point hit the voxel).
        voxel_values_pruned_flat = torch.cat(
            [voxel_values, self._ones_max_coords[:, :num_coords]], -1
        )

        # BS x x_max x y_max x z_max x 4
        scattered = self._scatter_nd(
            all_indices.view([-1, 1 + 3]),
            voxel_values_pruned_flat.view(-1, self._voxel_feature_size),
        )

        # Crop the one-voxel border shell on every axis.
        vox = scattered[:, 1:-1, 1:-1, 1:-1]
        if INCLUDE_PER_VOXEL_COORD:
            # Optionally insert metric voxel-centre positions before the
            # occupancy channel.
            res_expanded = res.unsqueeze(1).unsqueeze(1).unsqueeze(1)
            res_centre = (res_expanded * self._index_grid) + res_expanded / 2.0
            coord_positions = (
                res_centre + bb_mins_shifted.unsqueeze(1).unsqueeze(1).unsqueeze(1)
            )[:, 1:-1, 1:-1, 1:-1]
            vox = torch.cat([vox[..., :-1], coord_positions, vox[..., -1:]], -1)

        # Binarise the accumulated occupancy channel.
        occupied = (vox[..., -1:] > 0).float()
        vox = torch.cat([vox[..., :-1], occupied], -1)

        # Insert normalised (0..1) voxel-index channels before occupancy.
        return torch.cat(
            [
                vox[..., :-1],
                self._index_grid[:, :-2, :-2, :-2] / self._voxel_d,
                vox[..., -1:],
            ],
            -1,
        )
external/yarr/.gitignore ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__
2
+ venv
3
+ .idea
4
+ .bash_history
5
+ .cache/
6
+ .local/
7
+ .python_history
8
+ nvidia-persistenced/
9
+ results/
10
+ rlight.egg-info
11
+ dist/
12
+ build/
13
+ yarr.egg-info
external/yarr/LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
external/yarr/README.md ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ![Logo Missing](logo.png)
2
+
3
+ **Note**: Pirate qualification not needed to use this library.
4
+
5
+ YARR is **Y**et **A**nother **R**obotics and **R**einforcement learning framework for PyTorch.
6
+
7
+ The framework allows for asynchronous training (i.e. agent and learner running in separate processes), which makes it suitable for robot learning.
8
+ For an example of how to use this framework, see my [Attention-driven Robot Manipulation (ARM) repo](https://github.com/stepjam/ARM).
9
+
10
+ This project is mostly intended for my personal use (Stephen James) and to facilitate my research.
11
+
12
+ ## Modifications
13
+
14
+ This is my (Mohit Shridhar) fork of YARR. Honestly, I don't understand what exactly is happening in a lot of places, so there are a lot of hacks to make it work for my purposes. If you are doing simple behavior cloning, you can probably write simpler training and evaluation routines, but YARR might be useful if you also want to do RL. Here is a quick summary of my modifications:
15
+
16
+ - Switched from randomly sampling evaluation episodes to deterministic reloading of val/test dataset episodes for one-to-one comparisons across models.
17
+ - Separated training and evaluation routines.
18
+ - Task-uniform replay buffer for multi-task training. Each batch has a uniform distribution of tasks.
19
+ - Added cinematic recorder for rollouts.
20
+ - Some other weird hacks to prevent memory leaks.
21
+
22
+ ## Install
23
+
24
+ Ensure you have [PyTorch installed](https://pytorch.org/get-started/locally/).
25
+ Then simply run:
26
+ ```bash
27
+ python setup.py develop
28
+ ```
external/yarr/logo.png ADDED
external/yarr/requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ tensorboard
2
+ moviepy
3
+ natsort
4
+ psutil
5
+ timeout-decorator
6
+ pyrender==0.1.45
7
+ omegaconf
8
+ hydra-core
9
+ pandas==1.4.1
10
+ opencv-python
11
+
external/yarr/setup.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import codecs
2
+ import os
3
+
4
+ import setuptools
5
+
6
+
7
def read(rel_path):
    """Return the full text of *rel_path*, resolved relative to this file.

    Resolving against the setup.py directory (not the CWD) keeps the build
    working no matter where pip/setuptools invokes it from.
    """
    base_dir = os.path.abspath(os.path.dirname(__file__))
    with codecs.open(os.path.join(base_dir, rel_path), 'r') as fp:
        return fp.read()
11
+
12
+
13
def get_version(rel_path):
    """Extract the ``__version__`` string from the module at *rel_path*.

    Scans the file line by line; the delimiter is whichever quote character
    the assignment uses. Raises RuntimeError when no version line is found.
    """
    for line in read(rel_path).splitlines():
        if not line.startswith('__version__'):
            continue
        delim = '"' if '"' in line else "'"
        return line.split(delim)[1]
    raise RuntimeError("Unable to find version string.")
20
+
21
+
22
def get_install_requires():
    """Read ``requirements.txt`` (from the CWD) into a list of stripped lines."""
    with open('requirements.txt') as f:
        return [req.strip() for req in f]
28
+
29
+
30
# Package metadata. The version is read out of yarr/__init__.py via
# get_version() so it is defined in exactly one place.
setuptools.setup(
    version=get_version("yarr/__init__.py"),
    name='yarr',
    author='Stephen James',
    author_email='slj12@ic.ac.uk',
    packages=setuptools.find_packages(),
    install_requires=get_install_requires()
)
external/yarr/yarr/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
__version__ = '0.1'  # single source of truth; parsed by setup.py's get_version()
external/yarr/yarr/agents/__init__.py ADDED
File without changes
external/yarr/yarr/agents/agent.py ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import Any, List
3
+
4
+
5
class Summary(object):
    """A named value emitted by an agent for logging/visualisation."""

    def __init__(self, name: str, value: Any):
        # Stored verbatim; subclasses decide how the value is interpreted.
        self.name, self.value = name, value
9
+
10
+
11
class ScalarSummary(Summary):
    """Marker subclass for scalar-valued summaries; adds no extra state."""
    pass


class HistogramSummary(Summary):
    """Marker subclass for histogram summaries; adds no extra state."""
    pass


class ImageSummary(Summary):
    """Marker subclass for image summaries; adds no extra state."""
    pass


class TextSummary(Summary):
    """Marker subclass for text summaries; adds no extra state."""
    pass
25
+
26
+
27
class VideoSummary(Summary):
    """Summary carrying video frames plus the frame rate to encode them at."""

    def __init__(self, name: str, value: Any, fps: int = 30):
        super().__init__(name, value)
        self.fps = fps
31
+
32
+
33
class ActResult(object):
    """Outcome of ``Agent.act``: the chosen action plus auxiliary dicts
    destined for the observation, the replay buffer, and diagnostics."""

    def __init__(
        self,
        action: Any,
        observation_elements: dict = None,
        replay_elements: dict = None,
        info: dict = None,
    ):
        self.action = action
        # Falsy inputs (None or empty dict) are replaced by fresh empty dicts.
        self.observation_elements = (
            observation_elements if observation_elements else {}
        )
        self.replay_elements = replay_elements if replay_elements else {}
        self.info = info if info else {}
43
+
44
+
45
class Agent(ABC):
    """Abstract policy interface used by the training/eval runners."""

    @abstractmethod
    def build(self, training: bool, device=None) -> None:
        """Construct networks/optimisers; must be called before update/act."""
        pass

    @abstractmethod
    def update(self, step: int, replay_sample: dict) -> dict:
        """Run one training step on `replay_sample`; returns summary values."""
        pass

    @abstractmethod
    def act(self, step: int, observation: dict, deterministic: bool) -> ActResult:
        # returns dict of values that get put in the replay.
        # One of these must be 'action'.
        pass

    def reset(self) -> None:
        """Hook called at episode boundaries; no-op by default."""
        pass

    @abstractmethod
    def update_summaries(self) -> List[Summary]:
        """Summaries produced by training updates."""
        pass

    @abstractmethod
    def act_summaries(self) -> List[Summary]:
        """Summaries produced while acting."""
        pass

    @abstractmethod
    def load_weights(self, savedir: str) -> None:
        """Load model weights from `savedir`."""
        pass

    @abstractmethod
    def save_weights(self, savedir: str) -> None:
        """Save model weights to `savedir`."""
        pass
79
+
80
+
81
class BimanualAgent(Agent):
    """Composite agent for a two-armed robot: wraps one single-arm agent per
    arm and routes observations/actions between them.

    Keys containing "right_"/"left_" are stripped of that prefix and sent to
    the matching sub-agent; camera-like keys (containing "rgb",
    "point_cloud" or "camera") and all remaining keys are shared by both.
    """

    def __init__(self, right_agent: Agent, left_agent: Agent):
        self.right_agent = right_agent
        self.left_agent = left_agent
        self._summaries = {}

    def build(self, training: bool, device=None) -> None:
        self.right_agent.build(training, device)
        self.left_agent.build(training, device)

    def update(self, step: int, replay_sample: dict) -> dict:
        """Split the replay sample per arm, update both agents, and return
        combined summaries (sum of the two total losses)."""
        right_observation = {}
        left_observation = {}

        for k, v in replay_sample.items():
            if "rgb" in k or "point_cloud" in k or "camera" in k:
                # Camera data is shared between both arms.
                right_observation[k] = v
                left_observation[k] = v
            elif "right_" in k:
                # NOTE(review): substring match, but the strip assumes the
                # prefix sits at position 0 — confirm keys are always prefixed.
                right_observation[k[6:]] = v
            elif "left_" in k:
                left_observation[k[5:]] = v
            else:
                right_observation[k] = v
                left_observation[k] = v

        # Joint action tensor is split in half along dim 2: right arm first.
        action = replay_sample["action"]
        right_action, left_action = action.chunk(2, dim=2)
        right_observation["action"] = right_action
        left_observation["action"] = left_action

        right_update_dict = self.right_agent.update(step, right_observation)
        left_update_dict = self.left_agent.update(step, left_observation)

        total_losses = right_update_dict["total_losses"] + left_update_dict["total_losses"]
        self._summaries.update({"total_losses": total_losses})
        return self._summaries


    def act(self, step: int, observation: dict, deterministic: bool) -> ActResult:
        """Act with both sub-agents and concatenate their actions
        (right arm first), merging their extra outputs."""
        observation_elements = {}
        info = {}

        right_observation = {}
        left_observation = {}

        # Same routing as update(): shared camera keys, per-arm prefixes.
        for k, v in observation.items():
            if "rgb" in k or "point_cloud" in k or "camera" in k:
                right_observation[k] = v
                left_observation[k] = v
            elif "right_" in k:
                right_observation[k[6:]] = v
            elif "left_" in k:
                left_observation[k[5:]] = v
            else:
                right_observation[k] = v
                left_observation[k] = v

        right_act_result = self.right_agent.act(step, right_observation, deterministic)
        left_act_result = self.left_agent.act(step, left_observation, deterministic)

        action = (*right_act_result.action, *left_act_result.action)

        # NOTE(review): dict merges let the left arm overwrite same-named
        # right-arm entries — confirm keys are disjoint.
        observation_elements.update(right_act_result.observation_elements)
        observation_elements.update(left_act_result.observation_elements)

        info.update(right_act_result.info)
        info.update(left_act_result.info)

        return ActResult(action, observation_elements=observation_elements, info=info)

    def reset(self) -> None:
        self.right_agent.reset()
        self.left_agent.reset()

    def update_summaries(self) -> List[Summary]:
        """Combined summaries, with per-arm names namespaced (images excluded)."""
        summaries = []
        for k, v in self._summaries.items():
            summaries.append(ScalarSummary(f"{k}", v))

        right_summaries = self.right_agent.update_summaries()
        left_summaries = self.left_agent.update_summaries()

        # Mutates the sub-agents' summary objects in place.
        for summary in right_summaries:
            if not isinstance(summary, ImageSummary):
                summary.name = f"agent_right/{summary.name}"

        for summary in left_summaries:
            if not isinstance(summary, ImageSummary):
                summary.name = f"agent_left/{summary.name}"

        return right_summaries + left_summaries + summaries


    def act_summaries(self) -> List[Summary]:
        """Per-arm act summaries, namespaced like update_summaries()."""
        right_summaries = self.right_agent.act_summaries()
        left_summaries = self.left_agent.act_summaries()

        for summary in right_summaries:
            if not isinstance(summary, ImageSummary):
                summary.name = f"agent_right/{summary.name}"

        for summary in left_summaries:
            if not isinstance(summary, ImageSummary):
                summary.name = f"agent_left/{summary.name}"

        return right_summaries + left_summaries


    def load_weights(self, savedir: str) -> None:
        # Both sub-agents load from the same directory.
        self.right_agent.load_weights(savedir)
        self.left_agent.load_weights(savedir)

    def save_weights(self, savedir: str) -> None:
        self.right_agent.save_weights(savedir)
        self.left_agent.save_weights(savedir)
202
+
203
+
204
class LeaderFollowerAgent(Agent):
    """Two-arm agent where the follower (left arm) is conditioned on the
    leader's (right arm's) predicted action.

    Routing mirrors BimanualAgent; additionally the leader's action indices
    are concatenated onto the follower's ``low_dim_state``.
    """

    def __init__(self, leader_agent: Agent, follower_agent: Agent):
        self.leader_agent = leader_agent
        self.follower_agent = follower_agent
        self._summaries = {}

    def build(self, training: bool, device=None) -> None:
        self.leader_agent.build(training, device)
        self.follower_agent.build(training, device)

    def update(self, step: int, replay_sample: dict) -> dict:
        """Update leader on the right-arm sample, then update the follower
        with the ground-truth right-arm action indices appended to its
        low_dim_state (teacher forcing)."""

        leader_observation = {}
        follower_observation = {}


        for k, v in replay_sample.items():
            if "rgb" in k or "point_cloud" in k or "camera" in k:
                # Camera data shared by both agents.
                leader_observation[k] = v
                follower_observation[k] = v
            elif "right_" in k:
                # NOTE(review): substring match with a positional strip —
                # assumes keys are prefix-style (see BimanualAgent).
                leader_observation[k[6:]] = v
            elif "left_" in k:
                follower_observation[k[5:]] = v
            else:
                leader_observation[k] = v
                follower_observation[k] = v

        # Joint action split along dim 2: right (leader) half first.
        action = replay_sample["action"]
        right_action, left_action = action.chunk(2, dim=2)
        leader_observation["action"] = right_action
        follower_observation["action"] = left_action

        leader_update_dict = self.leader_agent.update(step, leader_observation)
        # Local import (torch is only needed here at runtime).
        import torch
        # Condition the follower on the ground-truth leader action indices.
        follower_observation['low_dim_state'] = torch.cat([follower_observation['low_dim_state'],
                                                           replay_sample["right_trans_action_indicies"],
                                                           replay_sample["right_rot_grip_action_indicies"],
                                                           replay_sample["right_ignore_collisions"]], dim=-1)

        follower_update_dict = self.follower_agent.update(step, follower_observation)

        total_losses = leader_update_dict["total_losses"] + follower_update_dict["total_losses"]
        self._summaries.update({"total_losses": total_losses})
        return self._summaries

    def act(self, step: int, observation: dict, deterministic: bool) -> ActResult:
        """Act with the leader first, then feed its predicted action indices
        into the follower's low_dim_state before the follower acts."""

        observation_elements = {}
        info = {}

        leader_observation = {}
        follower_observation = {}

        # Camera-like keys are shared; other per-arm keys are prefix-stripped.
        for k, v in observation.items():
            if "right_" in k and not "rgb" in k and not "point_cloud" in k and not "camera" in k:
                leader_observation[k[6:]] = v
            elif "left_" in k and not "rgb" in k and not "point_cloud" in k and not "camera" in k:
                follower_observation[k[5:]] = v
            else:
                leader_observation[k] = v
                follower_observation[k] = v

        right_act_result = self.leader_agent.act(step, leader_observation, deterministic)

        right_observation_elements = right_act_result.observation_elements

        # Local import (torch only needed at act time).
        import torch

        device = follower_observation['low_dim_state'].device
        if "trans_action_indicies" in right_observation_elements:
            # Leader exposed its discretised action: lift numpy arrays to
            # [1, 1, D] tensors on the follower's device.
            right_trans_action_indicies = torch.from_numpy(right_observation_elements["trans_action_indicies"]).unsqueeze(0).unsqueeze(0).to(device)
            right_rot_grip_action_indicies = torch.from_numpy(right_observation_elements["rot_grip_action_indicies"]).unsqueeze(0).unsqueeze(0).to(device)
            right_ignore_collisions = torch.from_numpy(right_act_result.action[-1:]).unsqueeze(0).unsqueeze(0).to(device)
        else:
            # NOTE(review): torch.empty is uninitialised memory, not zeros —
            # the fallback conditioning is arbitrary; confirm this is intended.
            right_trans_action_indicies = torch.empty((1, 1, 3)).to(device)
            right_rot_grip_action_indicies = torch.empty((1, 1, 4)).to(device)
            right_ignore_collisions = torch.empty((1, 1, 1)).to(device)


        # Same conditioning layout as in update(): trans, rot+grip, collisions.
        follower_observation['low_dim_state'] = torch.cat([follower_observation['low_dim_state'],
                                                           right_trans_action_indicies,
                                                           right_rot_grip_action_indicies,
                                                           right_ignore_collisions], dim=-1)

        left_act_result = self.follower_agent.act(step, follower_observation, deterministic)

        # Combined action: leader (right) components first.
        action = (*right_act_result.action, *left_act_result.action)

        observation_elements.update(right_act_result.observation_elements)
        observation_elements.update(left_act_result.observation_elements)

        info.update(right_act_result.info)
        info.update(left_act_result.info)

        return ActResult(action, observation_elements=observation_elements, info=info)


    def reset(self) -> None:
        self.leader_agent.reset()
        self.follower_agent.reset()

    def update_summaries(self) -> List[Summary]:
        """Combined summaries with per-role namespacing (images excluded)."""

        summaries = []
        for k, v in self._summaries.items():
            summaries.append(ScalarSummary(f"{k}", v))

        leader_summaries = self.leader_agent.update_summaries()
        follower_summaries = self.follower_agent.update_summaries()

        # Mutates the sub-agents' summary objects in place.
        for summary in leader_summaries:
            if not isinstance(summary, ImageSummary):
                summary.name = f"agent_leader/{summary.name}"
        for summary in follower_summaries:
            if not isinstance(summary, ImageSummary):
                summary.name = f"agent_follower/{summary.name}"

        return leader_summaries + follower_summaries + summaries


    def act_summaries(self) -> List[Summary]:
        """Per-role act summaries, namespaced like update_summaries()."""
        leader_summaries = self.leader_agent.act_summaries()
        follower_summaries = self.follower_agent.act_summaries()

        for summary in leader_summaries:
            if not isinstance(summary, ImageSummary):
                summary.name = f"agent_leader/{summary.name}"
        for summary in follower_summaries:
            if not isinstance(summary, ImageSummary):
                summary.name = f"agent_follower/{summary.name}"

        return leader_summaries + follower_summaries

    def load_weights(self, savedir: str) -> None:
        # Both sub-agents load from the same directory.
        self.leader_agent.load_weights(savedir)
        self.follower_agent.load_weights(savedir)

    def save_weights(self, savedir: str) -> None:
        self.leader_agent.save_weights(savedir)
        self.follower_agent.save_weights(savedir)
external/yarr/yarr/envs/__init__.py ADDED
File without changes
external/yarr/yarr/envs/env.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import Any, List
3
+
4
+ import numpy as np
5
+
6
+ from yarr.utils.observation_type import ObservationElement
7
+ from yarr.utils.transition import Transition
8
+
9
+
10
class Env(ABC):
    """Abstract base class for environments driven by YARR runners."""

    def __init__(self):
        # Index of the currently active task; single-task envs keep 0.
        self._active_task_id = 0
        # True when this instance is used for evaluation rather than training.
        self._eval_env = False

    @property
    def eval(self):
        """Whether the environment is in evaluation mode."""
        return self._eval_env

    @eval.setter
    def eval(self, value):
        self._eval_env = value

    @property
    def active_task_id(self) -> int:
        """Index of the task currently being executed."""
        return self._active_task_id

    @abstractmethod
    def launch(self) -> None:
        """Start the underlying simulator/backend."""

    def shutdown(self) -> None:
        """Tear the environment down; no-op by default."""

    @abstractmethod
    def reset(self) -> dict:
        """Begin a new episode and return the first observation dict."""

    @abstractmethod
    def step(self, action: np.ndarray) -> Transition:
        """Apply *action* and return the resulting transition."""

    @property
    @abstractmethod
    def observation_elements(self) -> List[ObservationElement]:
        """Descriptors (name/shape/dtype) for every observation entry."""

    @property
    @abstractmethod
    def action_shape(self) -> tuple:
        """Shape of a single action array."""

    @property
    @abstractmethod
    def env(self) -> Any:
        """The wrapped backend-specific environment object."""
57
+
58
+
59
class MultiTaskEnv(Env):
    """Environment that serves several tasks behind one interface."""

    @property
    @abstractmethod
    def num_tasks(self) -> int:
        """Total number of tasks this environment can execute."""
external/yarr/yarr/envs/rlbench_env.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from typing import Type, List
3
+
4
+ import numpy as np
5
+ try:
6
+ from rlbench import ObservationConfig, Environment, CameraConfig
7
+ except (ModuleNotFoundError, ImportError) as e:
8
+ print("You need to install RLBench: 'https://github.com/stepjam/RLBench'")
9
+ raise e
10
+ from rlbench.action_modes.action_mode import ActionMode
11
+ from rlbench.backend.observation import BimanualObservation, Observation
12
+ from rlbench.backend.task import Task
13
+ from rlbench.backend.task import BimanualTask
14
+
15
+ from helpers.clip.core.clip import tokenize
16
+
17
+ from yarr.envs.env import Env, MultiTaskEnv
18
+ from yarr.utils.observation_type import ObservationElement
19
+ from yarr.utils.transition import Transition
20
+ from yarr.utils.process_str import change_case
21
+
22
+ import logging
23
+
24
+
25
# Low-dimensional robot-state attributes stripped from the raw observation
# dict before image/point-cloud processing; the state is re-added as a single
# flat 'low_dim_state' vector. 'left'/'right' cover the per-arm
# sub-observations of a bimanual observation.
ROBOT_STATE_KEYS = ['joint_velocities', 'joint_positions', 'joint_forces',
                    'gripper_open', 'gripper_pose',
                    'gripper_joint_positions', 'gripper_touch_forces',
                    'task_low_dim_state', 'misc', 'left', 'right']
29
+
30
+
31
+ # ..todo:: possibly duplicated code.
32
def _extract_obs_bimanual(obs: BimanualObservation, channels_last: bool, observation_config: ObservationConfig):
    """Convert a bimanual RLBench observation into a flat dict of numpy arrays.

    Depending on ``observation_config.robot_name`` ('right', 'left', or
    anything else meaning both arms), either one arm's low-dim state or both
    arms' states are included. Point clouds are down-cast to float16.

    Args:
        obs: Raw bimanual observation from RLBench.
        channels_last: If False, image-like arrays are transposed to
            channel-first layout.
        observation_config: Config describing enabled cameras and state.

    Returns:
        dict mapping observation names to numpy arrays.
    """
    right_robot_state = obs.get_low_dim_data(obs.right)
    left_robot_state = obs.get_low_dim_data(obs.left)

    # Fix: the original first built obs_dict from vars(obs) and filtered it,
    # then unconditionally rebuilt it from obs.perception_data — the initial
    # construction was dead code and has been removed.
    if not channels_last:
        # Swap channels from last dim to 1st dim; 2-D data gets a leading axis.
        obs_dict = {k: np.transpose(v, [2, 0, 1]) if v.ndim == 3 else np.expand_dims(v, 0)
                    for k, v in obs.perception_data.items() if v is not None}
    else:
        # Add extra trailing dim to depth (2-D) data.
        obs_dict = {k: v if v.ndim == 3 else np.expand_dims(v, -1)
                    for k, v in obs.perception_data.items() if v is not None}

    if observation_config.robot_name == 'right':
        obs_dict['low_dim_state'] = right_robot_state.astype(np.float32)
        obs_dict['ignore_collisions'] = np.array([obs.right.ignore_collisions], dtype=np.float32)
    elif observation_config.robot_name == 'left':
        obs_dict['low_dim_state'] = left_robot_state.astype(np.float32)
        obs_dict['ignore_collisions'] = np.array([obs.left.ignore_collisions], dtype=np.float32)
    else:
        # Both arms: keep separate per-arm state vectors.
        obs_dict['right_low_dim_state'] = right_robot_state.astype(np.float32)
        obs_dict['left_low_dim_state'] = left_robot_state.astype(np.float32)
        obs_dict['right_ignore_collisions'] = np.array([obs.right.ignore_collisions], dtype=np.float32)
        obs_dict['left_ignore_collisions'] = np.array([obs.left.ignore_collisions], dtype=np.float32)

    # Point clouds are stored as float16 to reduce memory.
    # NOTE(review): the unimanual path keeps float32 — confirm this asymmetry
    # is intentional.
    for k in [k for k, v in obs_dict.items() if 'point_cloud' in k]:
        obs_dict[k] = obs_dict[k].astype(np.float16)

    for camera_name, config in observation_config.camera_configs.items():
        if config.point_cloud:
            obs_dict[f'{camera_name}_camera_extrinsics'] = obs.misc[f'{camera_name}_camera_extrinsics']
            obs_dict[f'{camera_name}_camera_intrinsics'] = obs.misc[f'{camera_name}_camera_intrinsics']
    return obs_dict
72
+
73
+
74
def _extract_obs_unimanual(obs: Observation, channels_last: bool, observation_config):
    """Convert a single-arm RLBench observation into a flat dict of arrays."""
    robot_state = obs.get_low_dim_data()

    # Drop None entries and all individual robot-state attributes; the state
    # is re-added below as a single 'low_dim_state' vector.
    raw = {k: v for k, v in vars(obs).items()
           if v is not None and k not in ROBOT_STATE_KEYS}

    if channels_last:
        # Depth/2-D data gets a trailing channel axis.
        obs_dict = {k: v if v.ndim == 3 else np.expand_dims(v, -1)
                    for k, v in raw.items()}
    else:
        # Channel-first layout: HWC -> CHW; 2-D data gets a leading axis.
        obs_dict = {k: np.transpose(v, [2, 0, 1]) if v.ndim == 3 else np.expand_dims(v, 0)
                    for k, v in raw.items()}

    obs_dict['low_dim_state'] = np.array(robot_state, dtype=np.float32)
    obs_dict['ignore_collisions'] = np.array([obs.ignore_collisions], dtype=np.float32)

    for key in [k for k in obs_dict if 'point_cloud' in k]:
        obs_dict[key] = obs_dict[key].astype(np.float32)

    camera_prefixes = [
        (observation_config.left_shoulder_camera, 'left_shoulder'),
        (observation_config.right_shoulder_camera, 'right_shoulder'),
        (observation_config.front_camera, 'front'),
        (observation_config.wrist_camera, 'wrist'),
        (observation_config.overhead_camera, 'overhead'),
    ]
    for config, name in camera_prefixes:
        if config.point_cloud:
            obs_dict['%s_camera_extrinsics' % name] = obs.misc['%s_camera_extrinsics' % name]
            obs_dict['%s_camera_intrinsics' % name] = obs.misc['%s_camera_intrinsics' % name]
    return obs_dict
105
+
106
+
107
def _get_cam_observation_elements(camera: CameraConfig, prefix: str, channels_last):
    """Build the ObservationElement descriptors for one camera.

    RGB and point-cloud streams share the 3-channel image shape; enabling
    point clouds also registers the camera's extrinsics and intrinsics.
    """
    elements = []
    img_size = list(camera.image_size)
    rgb_shape = img_size + [3] if channels_last else [3] + img_size
    if camera.rgb:
        elements.append(ObservationElement('%s_rgb' % prefix, rgb_shape, np.uint8))
    if camera.point_cloud:
        elements.extend([
            ObservationElement('%s_point_cloud' % prefix, rgb_shape, np.float32),
            ObservationElement('%s_camera_extrinsics' % prefix, (4, 4), np.float32),
            ObservationElement('%s_camera_intrinsics' % prefix, (3, 3), np.float32),
        ])
    if camera.depth:
        depth_shape = img_size + [1] if channels_last else [1] + img_size
        elements.append(ObservationElement('%s_depth' % prefix, depth_shape, np.float32))
    if camera.mask:
        raise NotImplementedError()

    return elements
131
+
132
+
133
def _observation_elements(observation_config, channels_last) -> List[ObservationElement]:
    """Assemble the full observation signature for the configured robot."""
    # Each enabled low-dim component contributes a fixed number of values.
    state_components = [
        (observation_config.joint_velocities, 7),
        (observation_config.joint_positions, 7),
        (observation_config.joint_forces, 7),
        (observation_config.gripper_open, 1),
        (observation_config.gripper_pose, 7),
        (observation_config.gripper_joint_positions, 2),
        (observation_config.gripper_touch_forces, 2),
    ]
    robot_state_len = sum(size for enabled, size in state_components if enabled)
    if observation_config.task_low_dim_state:
        raise NotImplementedError()

    elements = []
    if robot_state_len > 0:
        if observation_config.robot_name == 'bimanual':
            # One state vector per arm.
            elements.append(ObservationElement(
                'right_low_dim_state', (robot_state_len,), np.float32))
            elements.append(ObservationElement(
                'left_low_dim_state', (robot_state_len,), np.float32))
        elif observation_config.robot_name in ['unimanual', 'left', 'right']:
            elements.append(ObservationElement('low_dim_state', (robot_state_len,), np.float32))

    # NOTE(review): the overhead camera is not part of the signature even
    # though the extraction functions can produce it — confirm intentional.
    for camera, prefix in [
            (observation_config.left_shoulder_camera, 'left_shoulder'),
            (observation_config.right_shoulder_camera, 'right_shoulder'),
            (observation_config.front_camera, 'front'),
            (observation_config.wrist_camera, 'wrist')]:
        elements.extend(_get_cam_observation_elements(camera, prefix, channels_last))
    return elements
169
+
170
+
171
class RLBenchEnv(Env):
    """Single-task YARR environment wrapping one RLBench task.

    Handles launching the simulator, flattening raw RLBench observations
    into numpy dicts, and (optionally) attaching CLIP-tokenized language
    goals to every observation.
    """

    def __init__(self, task_class: Type[Task],
                 observation_config: ObservationConfig,
                 action_mode: ActionMode,
                 dataset_root: str = '',
                 channels_last: bool = False,
                 headless: bool = True,
                 include_lang_goal_in_obs: bool = False):
        """Create the RLBench environment (does not launch the simulator).

        Args:
            task_class: The RLBench task to run.
            observation_config: Which cameras/state entries are enabled.
            action_mode: RLBench action mode controlling the robot.
            dataset_root: Optional demo dataset location.
            channels_last: Keep images HWC if True, else CHW.
            headless: Run the simulator without a GUI.
            include_lang_goal_in_obs: Add 'lang_goal_tokens' to observations.
        """
        super(RLBenchEnv, self).__init__()
        self._task_class = task_class
        self._observation_config = observation_config
        self._channels_last = channels_last
        self._include_lang_goal_in_obs = include_lang_goal_in_obs
        # Bimanual tasks require the dual-arm robot setup in RLBench.
        if issubclass(task_class, BimanualTask):
            robot_setup = "dual_panda"
        else:
            robot_setup = "panda"
        self._rlbench_env = Environment(
            action_mode=action_mode, obs_config=observation_config,
            dataset_root=dataset_root, headless=headless, robot_setup=robot_setup)
        # Populated in launch(); lang goal is refreshed on every reset().
        self._task = None
        self._lang_goal = 'unknown goal'

    def extract_obs(self, obs: Observation):
        """Flatten a raw observation; optionally append language-goal tokens."""
        if isinstance(obs, BimanualObservation):
            extracted_obs = _extract_obs_bimanual(obs, self._channels_last, self._observation_config)
        else:
            extracted_obs = _extract_obs_unimanual(obs, self._channels_last, self._observation_config)
        if self._include_lang_goal_in_obs:
            extracted_obs['lang_goal_tokens'] = tokenize([self._lang_goal])[0].numpy()
        return extracted_obs

    def launch(self):
        """Start the simulator and instantiate the task."""
        self._rlbench_env.launch()
        self._task = self._rlbench_env.get_task(self._task_class)

    def shutdown(self):
        """Shut down the underlying RLBench environment."""
        self._rlbench_env.shutdown()

    def reset(self) -> dict:
        """Reset the task and return the first extracted observation."""
        descriptions, obs = self._task.reset()
        self._lang_goal = descriptions[0]  # first description variant
        extracted_obs = self.extract_obs(obs)
        return extracted_obs

    def step(self, action: np.ndarray) -> Transition:
        """Execute *action* and wrap the result as a Transition."""
        obs, reward, terminal = self._task.step(action)
        obs = self.extract_obs(obs)
        return Transition(obs, reward, terminal)

    @property
    def observation_elements(self) -> List[ObservationElement]:
        """Observation signature derived from the observation config."""
        return _observation_elements(self._observation_config, self._channels_last)

    @property
    def action_shape(self):
        # Single flat action vector; size is defined by the action mode.
        return (self._rlbench_env.action_size, )

    @property
    def env(self) -> Environment:
        """The wrapped RLBench Environment."""
        return self._rlbench_env
234
+
235
+
236
class MultiTaskRLBenchEnv(MultiTaskEnv):
    """YARR environment that cycles through several RLBench tasks.

    A new task is activated every ``swap_task_every`` episodes, either
    round-robin (default) or uniformly at random.
    """

    def __init__(self,
                 task_classes: List[Type[Task]],
                 observation_config: ObservationConfig,
                 action_mode: ActionMode,
                 dataset_root: str = '',
                 channels_last=False,
                 headless=True,
                 swap_task_every: int = 1,
                 include_lang_goal_in_obs=False):
        """Create the multi-task environment (does not launch the simulator).

        Args:
            task_classes: The RLBench tasks to rotate through.
            observation_config: Which cameras/state entries are enabled.
            action_mode: RLBench action mode controlling the robot.
            dataset_root: Optional demo dataset location.
            channels_last: Keep images HWC if True, else CHW.
            headless: Run the simulator without a GUI.
            swap_task_every: Number of episodes before switching task.
            include_lang_goal_in_obs: Add 'lang_goal_tokens' to observations.
        """
        super(MultiTaskRLBenchEnv, self).__init__()
        self._task_classes = task_classes
        self._observation_config = observation_config
        self._channels_last = channels_last
        self._include_lang_goal_in_obs = include_lang_goal_in_obs
        # Bimanual tasks require the dual-arm robot setup; the first task is
        # taken as representative for the whole task set.
        if issubclass(task_classes[0], BimanualTask):
            robot_setup = "dual_panda"
        else:
            robot_setup = "panda"
        self._rlbench_env = Environment(
            action_mode=action_mode, obs_config=observation_config,
            dataset_root=dataset_root, headless=headless, robot_setup=robot_setup)
        self._task = None
        self._task_name = ''
        self._lang_goal = 'unknown goal'
        self._swap_task_every = swap_task_every
        # Fix: removed a stray no-op statement (`self._rlbench_env` on its
        # own line) that the original had here.
        self._episodes_this_task = 0
        # Start at -1 so the first _set_new_task() activates task 0.
        self._active_task_id = -1

        self._task_name_to_idx = {change_case(tc.__name__): i
                                  for i, tc in enumerate(self._task_classes)}

    def _set_new_task(self, shuffle=False):
        """Activate the next task round-robin, or a random one if *shuffle*."""
        if shuffle:
            self._active_task_id = np.random.randint(0, len(self._task_classes))
        else:
            self._active_task_id = (self._active_task_id + 1) % len(self._task_classes)
        task = self._task_classes[self._active_task_id]
        self._task = self._rlbench_env.get_task(task)

    def set_task(self, task_name: str):
        """Activate the named task and refresh the language goal.

        Note: this resets the task once to obtain its description.
        """
        self._active_task_id = self._task_name_to_idx[task_name]
        task = self._task_classes[self._active_task_id]
        self._task = self._rlbench_env.get_task(task)

        descriptions, _ = self._task.reset()
        self._lang_goal = descriptions[0]  # first description variant

    def extract_obs(self, obs: Observation):
        """Flatten a raw observation; optionally append language-goal tokens."""
        if obs.is_bimanual:
            extracted_obs = _extract_obs_bimanual(obs, self._channels_last, self._observation_config)
        else:
            extracted_obs = _extract_obs_unimanual(obs, self._channels_last, self._observation_config)
        if self._include_lang_goal_in_obs:
            extracted_obs['lang_goal_tokens'] = tokenize([self._lang_goal])[0].numpy()
        return extracted_obs

    def launch(self):
        """Start the simulator and activate the first task."""
        self._rlbench_env.launch()
        self._set_new_task()

    def shutdown(self):
        """Shut down the underlying RLBench environment."""
        self._rlbench_env.shutdown()

    def reset(self) -> dict:
        """Reset the episode, rotating to the next task when due."""
        if self._episodes_this_task == self._swap_task_every:
            self._set_new_task()
            self._episodes_this_task = 0
        self._episodes_this_task += 1

        descriptions, obs = self._task.reset()
        self._lang_goal = descriptions[0]  # first description variant
        extracted_obs = self.extract_obs(obs)

        return extracted_obs

    def step(self, action: np.ndarray) -> Transition:
        """Execute *action* in the active task and wrap the result."""
        obs, reward, terminal = self._task.step(action)
        obs = self.extract_obs(obs)
        return Transition(obs, reward, terminal)

    @property
    def observation_elements(self) -> List[ObservationElement]:
        """Observation signature derived from the observation config."""
        return _observation_elements(self._observation_config, self._channels_last)

    @property
    def action_shape(self):
        # Single flat action vector; size is defined by the action mode.
        return (self._rlbench_env.action_size, )

    @property
    def env(self) -> Environment:
        """The wrapped RLBench Environment."""
        return self._rlbench_env

    @property
    def num_tasks(self) -> int:
        """Number of tasks this environment rotates through."""
        return len(self._task_classes)
external/yarr/yarr/replay_buffer/__init__.py ADDED
File without changes
external/yarr/yarr/replay_buffer/prioritized_replay_buffer.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """An implementation of Prioritized Experience Replay (PER).
2
+
3
+ This implementation is based on the paper "Prioritized Experience Replay"
4
+ by Tom Schaul et al. (2015).
5
+ """
6
+ from __future__ import absolute_import
7
+ from __future__ import division
8
+ from __future__ import print_function
9
+
10
+ from .uniform_replay_buffer import *
11
+ from .sum_tree import *
12
+ import numpy as np
13
+
14
+
15
# Storage-signature key under which per-transition priorities are stored.
PRIORITY = 'priority'
16
+
17
+
18
class PrioritizedReplayBuffer(UniformReplayBuffer):
    """An out-of-graph Replay Buffer for Prioritized Experience Replay.

    Priorities are held in a SumTree indexed by buffer cursor position, so
    sampling is proportional to priority. See uniform_replay_buffer.py for
    the storage details.
    """

    def __init__(self, *args, **kwargs):
        """Initializes PrioritizedReplayBuffer.

        Accepts the same arguments as UniformReplayBuffer.
        """
        super(PrioritizedReplayBuffer, self).__init__(*args, **kwargs)
        # One leaf per buffer slot; leaf value is the slot's priority.
        self._sum_tree = SumTree(self._replay_capacity)

    def get_storage_signature(self) -> Tuple[List[ReplayElement],
                                             List[ReplayElement]]:
        """Returns a default list of elements to be stored in this replay memory.

        Note - Derived classes may return a different signature.

        Returns:
          Tuple of (storage elements, observation elements); the parent's
          storage signature extended with a scalar PRIORITY element.
        """
        storage_elements, obs_elements = super(
            PrioritizedReplayBuffer, self).get_storage_signature()
        storage_elements.append(ReplayElement(PRIORITY, (), np.float32),)

        return storage_elements, obs_elements

    def add(self, action, reward, terminal, timeout, priority=None, **kwargs):
        """Adds a transition; `priority=None` defers to the max recorded priority."""
        kwargs['priority'] = priority
        super(PrioritizedReplayBuffer, self).add(
            action, reward, terminal, timeout, **kwargs)

    def _add(self, kwargs: dict):
        """Internal add method to add to the storage arrays.

        Args:
          kwargs: All the elements in a transition.
        """
        with self._lock:
            cursor = self.cursor()
            priority = kwargs[PRIORITY]
            # New transitions default to the highest priority seen so far,
            # so they are sampled soon after insertion.
            if priority is None:
                priority = self._sum_tree.max_recorded_priority

            if self._disk_saving:
                # Terminal flags stay in memory even when the rest of the
                # transition is pickled to disk.
                term = self._store[TERMINAL]
                term[cursor] = kwargs[TERMINAL]
                self._store[TERMINAL] = term

                with open(join(self._save_dir, '%d.replay' % cursor), 'wb') as f:
                    pickle.dump(kwargs, f)
                # If first add, then pad for correct wrapping
                if self._add_count.value == 0:
                    self._add_initial_to_disk(kwargs)
            else:
                for name, data in kwargs.items():
                    item = self._store[name]
                    item[cursor] = data
                    self._store[name] = item

            # Record the slot's priority, then advance the cursor bookkeeping.
            self._sum_tree.set(self.cursor(), priority)
            self._add_count.value += 1
            self.invalid_range = invalid_range(
                self.cursor(), self._replay_capacity, self._timesteps,
                self._update_horizon)

    def add_final(self, **kwargs):
        """Adds the final (terminal) observation of an episode.

        Args:
          **kwargs: The observation elements of the final frame.
        """
        # if self.is_empty() or self._store['terminal'][self.cursor() - 1] != 1:
        #     raise ValueError('The previous transition was not terminal.')
        self._check_add_types(kwargs, self._obs_signature)
        transition = self._final_transition(kwargs)
        for element_type in self._storage_signature:
            # 0 priority for final observation.
            if element_type.name == PRIORITY:
                transition[element_type.name] = 0.0
        self._add(transition)

    def sample_index_batch(self, batch_size):
        """Returns a batch of valid indices sampled as in Schaul et al. (2015).

        Args:
          batch_size: int, number of indices returned.

        Returns:
          list of ints, a batch of valid indices sampled proportionally to
          priority.

        Raises:
          RuntimeError: If the batch was not constructed after the maximum
            number of tries.
        """
        # Sample stratified indices. Some of them might be invalid.
        indices = self._sum_tree.stratified_sample(batch_size)
        allowed_attempts = self._max_sample_attempts
        for i in range(len(indices)):
            if not self.is_valid_transition(indices[i]):
                if allowed_attempts == 0:
                    raise RuntimeError(
                        'Max sample attempts: Tried {} times but only sampled {}'
                        ' valid indices. Batch size is {}'.
                        format(self._max_sample_attempts, i, batch_size))
                index = indices[i]
                while not self.is_valid_transition(
                        index) and allowed_attempts > 0:
                    # If index i is not valid keep sampling others. Note that this
                    # is not stratified.
                    index = self._sum_tree.sample()
                    allowed_attempts -= 1
                indices[i] = index
        return indices

    def sample_transition_batch(self, batch_size=None, indices=None,
                                pack_in_dict=True):
        """Returns a batch of transitions with extra storage and the priorities.

        The extra storage are defined through the extra_storage_types constructor
        argument.

        When the transition is terminal next_state_batch has undefined contents.

        Args:
          batch_size: int, number of transitions returned. If None, the default
            batch_size will be used.
          indices: None or list of ints, the indices of every transition in the
            batch. If None, sample the indices uniformly.
          pack_in_dict: bool, whether to return a name-keyed dict instead of
            the raw tuple of arrays.

        Returns:
          transition_batch: tuple of np.arrays with the shape and type as in
            get_transition_elements().
        """
        transition = super(
            PrioritizedReplayBuffer, self).sample_transition_batch(
                batch_size, indices, pack_in_dict=False)

        transition_elements = self.get_transition_elements(batch_size)
        transition_names = [e.name for e in transition_elements]
        probabilities_index = transition_names.index('sampling_probabilities')
        indices_index = transition_names.index('indices')
        indices = transition[indices_index]
        # The parent returned an empty array for the probabilities. Fill it with the
        # contents of the sum tree.
        transition[probabilities_index][:] = self.get_priority(indices)
        batch_arrays = transition
        if pack_in_dict:
            batch_arrays = self.unpack_transition(transition,
                                                  transition_elements)
        return batch_arrays

    def set_priority(self, indices, priorities):
        """Sets the priority of the given elements according to Schaul et al.

        Args:
          indices: np.array with dtype int32, of indices in range
            [0, replay_capacity).
          priorities: float, the corresponding priorities.
        """
        assert indices.dtype == np.int32, ('Indices must be integers, '
                                           'given: {}'.format(indices.dtype))
        for index, priority in zip(indices, priorities):
            self._sum_tree.set(index, priority)

    def get_priority(self, indices):
        """Fetches the priorities correspond to a batch of memory indices.

        For any memory location not yet used, the corresponding priority is 0.

        Args:
          indices: np.array with dtype int32, of indices in range
            [0, replay_capacity).

        Returns:
          priorities: float, the corresponding priorities.
        """
        assert indices.shape, 'Indices must be an array.'
        assert indices.dtype == np.int32, ('Indices must be int32s, '
                                           'given: {}'.format(indices.dtype))
        batch_size = len(indices)
        priority_batch = np.empty((batch_size), dtype=np.float32)
        for i, memory_index in enumerate(indices):
            priority_batch[i] = self._sum_tree.get(memory_index)
        return priority_batch

    def get_transition_elements(self, batch_size=None):
        """Returns a 'type signature' for sample_transition_batch.

        Args:
          batch_size: int, number of transitions returned. If None, the default
            batch_size will be used.
        Returns:
          signature: A namedtuple describing the method's return type signature.
        """
        parent_transition_type = (
            super(PrioritizedReplayBuffer,
                  self).get_transition_elements(batch_size))
        probablilities_type = [
            ReplayElement('sampling_probabilities', (batch_size,), np.float32)
        ]
        return parent_transition_type + probablilities_type
external/yarr/yarr/replay_buffer/replay_buffer.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC
2
+ from typing import Tuple, List
3
+
4
class ReplayElement(object):
    """Descriptor for one named entry stored in a replay buffer.

    Attributes:
        name: Key under which the element is stored and sampled.
        shape: Shape of a single (unbatched) element.
        type: Dtype-like object describing the element's values.
        is_observation: Whether the element is part of the observation.
    """

    def __init__(self, name, shape, type, is_observation=False):
        # NOTE: parameter 'type' shadows the builtin; kept for backwards
        # compatibility with existing keyword callers.
        self.is_observation = is_observation
        self.type = type
        self.shape = shape
        self.name = name
10
+
11
+
12
class ReplayBuffer(ABC):
    """Abstract interface for replay buffers.

    Concrete implementations (e.g. UniformReplayBuffer) override these
    stubs; the base class only fixes the method names and signatures.
    """

    def replay_capacity(self):
        """Maximum number of transitions the buffer can hold."""
        pass

    def batch_size(self):
        """Default number of transitions returned per sampled batch."""
        pass

    def get_storage_signature(self) -> Tuple[List[ReplayElement],
                                             List[ReplayElement]]:
        """Return (storage elements, observation elements) describing contents."""
        pass

    def add(self, action, reward, terminal, timeout, **kwargs):
        """Add one transition; extra elements are passed via **kwargs."""
        pass

    def add_final(self, **kwargs):
        """Add the final (terminal) observation of an episode."""
        pass

    def is_empty(self):
        """Whether no transitions have been added yet."""
        pass

    def is_full(self):
        """Whether the buffer has reached its capacity."""
        pass

    def cursor(self):
        """Index of the next slot to be written."""
        pass

    def set_cursor(self):
        """Reposition the write cursor."""
        pass

    def get_range(self, array, start_index, end_index):
        """Return array[start_index:end_index], handling wrap-around."""
        pass

    def get_range_stack(self, array, start_index, end_index, terminals=None):
        """Return a stacked range, optionally masked by terminal flags."""
        pass

    def get_terminal_stack(self, index):
        """Return the stack of terminal flags ending at *index*."""
        pass

    def is_valid_transition(self, index):
        """Whether *index* can be sampled as a complete transition."""
        pass

    def sample_index_batch(self, batch_size):
        """Sample *batch_size* valid transition indices."""
        pass

    def unpack_transition(self, transition_tensors, transition_type):
        """Convert a tuple of sampled arrays into a name-keyed dict."""
        pass

    def sample_transition_batch(self, batch_size=None, indices=None,
                                pack_in_dict=True):
        """Sample a batch of transitions (optionally at given indices)."""
        pass

    def get_transition_elements(self, batch_size=None):
        """Return the type signature of sample_transition_batch's output."""
        pass

    def shutdown(self):
        """Release any resources held by the buffer."""
        pass

    def using_disk(self):
        """Whether transitions are persisted to disk rather than RAM."""
        pass
external/yarr/yarr/replay_buffer/sum_tree.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """A sum tree data structure.
2
+
3
+ Used for prioritized experience replay. See prioritized_replay_buffer.py
4
+ and Schaul et al. (2015).
5
+ """
6
+ from __future__ import absolute_import
7
+ from __future__ import division
8
+ from __future__ import print_function
9
+
10
+ import math
11
+ import random
12
+
13
+ import numpy as np
14
+
15
+
16
class SumTree(object):
    """A sum tree data structure for storing replay priorities.

    A complete binary tree whose leaves hold priorities; each internal node
    holds the sum of the priorities in its subtree, so sampling an index
    proportionally to priority and updating a leaf both take
    O(log capacity). Levels are stored as numpy arrays in ``self.nodes``
    (sizes are powers of two; excess leaves are zero-padded).
    """

    def __init__(self, capacity, nodes=None):
        """Creates the sum tree data structure for the given replay capacity.

        Args:
          capacity: int, the maximum number of elements that can be stored in
            this data structure.
          nodes: optional storage list; level arrays are appended to it.

        Raises:
          ValueError: If requested capacity is not positive.
        """
        assert isinstance(capacity, int)
        if capacity <= 0:
            raise ValueError('Sum tree capacity should be positive. Got: {}'.
                             format(capacity))

        self.nodes = [] if nodes is None else nodes
        tree_depth = int(math.ceil(np.log2(capacity)))
        level_size = 1
        for _ in range(tree_depth + 1):
            self.nodes.append(np.zeros(level_size))
            level_size *= 2

        # Used as the default priority for newly-added transitions.
        self.max_recorded_priority = 1.0

    def _total_priority(self):
        """Returns the sum of all priorities stored in this sum tree."""
        return self.nodes[0][0]

    def sample(self, query_value=None):
        """Samples a leaf index with probability proportional to its priority.

        Args:
          query_value: float in [0, 1], used as the random value to select a
            sample. If None, will select one randomly in [0, 1).

        Returns:
          int, a random element from the sum tree.

        Raises:
          Exception: If the sum tree is empty (i.e. its node values sum to 0).
          ValueError: If the supplied query_value lies outside [0, 1].
        """
        if self._total_priority() == 0.0:
            raise Exception('Cannot sample from an empty sum tree.')

        # Fix: the original used a truthiness test ('if query_value and ...')
        # which silently skipped validation when query_value == 0.0.
        if query_value is not None and (query_value < 0. or query_value > 1.):
            raise ValueError('query_value must be in [0, 1].')

        # Sample a value in range [0, R), where R is the value stored at the root.
        query_value = random.random() if query_value is None else query_value
        query_value *= self._total_priority()

        # Traverse the tree: descend left while the query falls inside the
        # left subtree's range, else descend right and shift the query.
        node_index = 0
        for nodes_at_this_depth in self.nodes[1:]:
            left_child = node_index * 2
            left_sum = nodes_at_this_depth[left_child]
            if query_value < left_sum:
                node_index = left_child
            else:
                node_index = left_child + 1
                query_value -= left_sum

        return node_index

    def stratified_sample(self, batch_size):
        """Samples ``batch_size`` indices from the sum tree.

        Intended to implement stratified sampling as in Schaul et al. (2015),
        but stratification is currently disabled (see TODO below): each draw
        uses an independent uniform value over the whole range.

        Args:
          batch_size: int, the number of samples to draw.

        Returns:
          list of batch_size elements sampled from the sum tree.

        Raises:
          Exception: If the sum tree is empty (i.e. its node values sum to 0).
        """
        if self._total_priority() == 0.0:
            raise Exception('Cannot sample from an empty sum tree.')

        # TODO: re-enable true stratification (one uniform draw per segment
        # [i/batch_size, (i+1)/batch_size)); the dead segment computation the
        # original carried has been removed.
        query_values = [random.uniform(0, 1) for _ in range(batch_size)]
        return [self.sample(query_value=x) for x in query_values]

    def get(self, node_index):
        """Returns the value of the leaf node corresponding to the index.

        Args:
          node_index: The index of the leaf node.
        Returns:
          The value of the leaf node.
        """
        return self.nodes[-1][node_index]

    def set(self, node_index, value):
        """Sets the value of a leaf node and updates internal nodes accordingly.

        This operation takes O(log(capacity)).

        Args:
          node_index: int, the index of the leaf node to be updated.
          value: float, the value which we assign to the node. This value must
            be nonnegative. Setting value = 0 will cause the element to never
            be sampled.

        Raises:
          ValueError: If the given value is negative.
        """
        if value < 0.0:
            raise ValueError('Sum tree values should be nonnegative. Got {}'.
                             format(value))
        self.max_recorded_priority = max(value, self.max_recorded_priority)

        delta_value = value - self.nodes[-1][node_index]

        # Walk leaf-to-root, adding the delta at each level. Adding a delta
        # leads to some tolerable numerical inaccuracy.
        for nodes_at_this_depth in reversed(self.nodes):
            nodes_at_this_depth[node_index] += delta_value
            node_index //= 2

        assert node_index == 0, ('Sum tree traversal failed, final node index '
                                 'is not 0.')
external/yarr/yarr/replay_buffer/task_uniform_replay_buffer.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import os
3
+ from os.path import join
4
+ import pickle
5
+ import math
6
+ from yarr.replay_buffer.uniform_replay_buffer import UniformReplayBuffer
7
+ from yarr.replay_buffer.uniform_replay_buffer import invalid_range
8
+
9
+ from yarr.replay_buffer.replay_buffer import ReplayBuffer, ReplayElement
10
+ from yarr.utils.observation_type import ObservationElement
11
+
12
+ ACTION = 'action'
13
+ REWARD = 'reward'
14
+ TERMINAL = 'terminal'
15
+ TIMEOUT = 'timeout'
16
+ INDICES = 'indices'
17
+ TASK = 'task'
18
+
19
+
20
class TaskUniformReplayBuffer(UniformReplayBuffer):
    """A uniform replay buffer that samples tasks uniformly for each batch.

    Transitions are indexed per task so that each batch can be drawn with
    a uniform distribution over tasks rather than over raw transitions.
    """

    def __init__(self, *args, **kwargs):
        """Initializes TaskUniformReplayBuffer."""
        super(TaskUniformReplayBuffer, self).__init__(*args, **kwargs)
        # Maps task name -> list of cursor slots holding that task's data.
        self._task_idxs = dict()

    def _add(self, kwargs: dict):
        """Internal add method to add to the storage arrays.

        Also records which cursor slot belongs to which task so that
        sample_index_batch can sample uniformly over tasks.

        Args:
            kwargs: All the elements in a transition.
        """
        with self._lock:
            cursor = self.cursor()

            if self._disk_saving:
                # Terminals always stay in RAM so validity checks avoid disk.
                term = self._store[TERMINAL]
                term[cursor] = kwargs[TERMINAL]
                self._store[TERMINAL] = term

                # Downcast large float arrays to float16 to reduce disk usage.
                # BUG FIX: was a bare `except: pass`, which also swallowed
                # KeyboardInterrupt and real errors; only AttributeError
                # (value has no .dtype/.size, i.e. not an ndarray) is expected.
                for k, v in kwargs.items():
                    try:
                        if 'float' in v.dtype.name and v.size > 100:
                            kwargs[k] = v.astype(np.float16)
                    except AttributeError:
                        pass

                with open(join(self._save_dir, '%d.replay' % cursor),
                          'wb') as f:
                    pickle.dump(kwargs, f)
                # If first add, then pad for correct wrapping.
                if self._add_count.value == 0:
                    self._add_initial_to_disk(kwargs)
            else:
                for name, data in kwargs.items():
                    item = self._store[name]
                    item[cursor] = data
                    self._store[name] = item

            with self._add_count.get_lock():
                task = kwargs[TASK]
                if task not in self._task_idxs:
                    self._task_idxs[task] = [cursor]
                else:
                    # O(1) append instead of rebuilding the list each add.
                    self._task_idxs[task].append(cursor)
                self._add_count.value += 1

            self.invalid_range = invalid_range(
                self.cursor(), self._replay_capacity, self._timesteps,
                self._update_horizon)

    def sample_index_batch(self, batch_size):
        """Returns a batch of valid indices sampled uniformly across tasks.

        Args:
            batch_size: int, number of indices returned.

        Returns:
            list of ints, a batch of valid indices sampled uniformly
            across tasks.

        Raises:
            RuntimeError: If the batch was not constructed after the
                maximum number of tries.
        """
        if self.is_full():
            min_id = (self.cursor() - self._replay_capacity +
                      self._timesteps - 1)
            max_id = self.cursor() - self._update_horizon
        else:
            min_id = 0
            max_id = self.cursor() - self._update_horizon
            if max_id <= min_id:
                raise RuntimeError(
                    'Cannot sample a batch with fewer than stack size '
                    '({}) + update_horizon ({}) transitions.'.
                    format(self._timesteps, self._update_horizon))

        tasks = list(self._task_idxs.keys())
        attempt_count = 0
        found_indices = False
        indices = []

        # Uniform distribution over tasks: keep drawing task/index pairs
        # until a full batch of valid transition indices is assembled.
        while not found_indices and attempt_count < 1000:
            # Sample batch_size tasks (with replacement if fewer tasks).
            sampled_tasks = list(np.random.choice(
                tasks, batch_size, replace=(batch_size > len(tasks))))
            potential_indices = []
            for task in sampled_tasks:
                # DDP setting where each GPU only sees a fraction of the data.
                # reference: https://github.com/pytorch/pytorch/blob/master/torch/utils/data/distributed.py
                task_data_size = len(self._task_idxs[task])
                num_samples = math.ceil(task_data_size / self._num_replicas)
                total_size = num_samples * self._num_replicas
                task_indices = self._task_idxs[task][
                    self._rank:total_size:self._num_replicas]

                sampled_task_idx = np.random.choice(task_indices, 1)[0]
                per_task_attempt_count = 0

                # Rejection-sample until the index is a valid transition
                # (slow, but bounded by _max_sample_attempts).
                while not self.is_valid_transition(sampled_task_idx) and \
                        per_task_attempt_count < self._max_sample_attempts:
                    sampled_task_idx = np.random.choice(task_indices, 1)[0]
                    per_task_attempt_count += 1

                if not self.is_valid_transition(sampled_task_idx):
                    attempt_count += 1
                    continue
                else:
                    potential_indices.append(sampled_task_idx)
            found_indices = len(potential_indices) == batch_size
            indices = potential_indices

        if len(indices) != batch_size:
            raise RuntimeError(
                'Max sample attempts: Tried {} times but only sampled {}'
                ' valid indices. Batch size is {}'.
                format(self._max_sample_attempts, len(indices), batch_size))

        return indices

    def get_transition_elements(self, batch_size=None):
        """Returns a 'type signature' for sample_transition_batch.

        Args:
            batch_size: int, number of transitions returned. If None, the
                default batch_size will be used.
        Returns:
            signature: A list of ReplayElements describing the method's
                return type signature.
        """
        batch_size = self._batch_size if batch_size is None else batch_size

        transition_elements = [
            ReplayElement(ACTION,
                          (batch_size, self._timesteps) + self._action_shape,
                          self._action_dtype),
            ReplayElement(REWARD,
                          (batch_size, self._timesteps) + self._reward_shape,
                          self._reward_dtype),
            ReplayElement(TERMINAL, (batch_size, self._timesteps), np.int8),
            ReplayElement(TIMEOUT, (batch_size, self._timesteps), bool),
            ReplayElement(INDICES, (batch_size, self._timesteps), np.int32),
        ]

        for element in self._observation_elements:
            transition_elements.append(ReplayElement(
                element.name,
                (batch_size, self._timesteps) + tuple(element.shape),
                element.type, True))
            transition_elements.append(ReplayElement(
                element.name + '_tp1',
                (batch_size, self._timesteps) + tuple(element.shape),
                element.type, True))

        for element in self._extra_replay_elements:
            transition_elements.append(ReplayElement(
                element.name,
                (batch_size,) + tuple(element.shape),
                element.type))
        return transition_elements
external/yarr/yarr/replay_buffer/uniform_replay_buffer.py ADDED
@@ -0,0 +1,804 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """The standard DQN replay memory.
2
+
3
+ This implementation is an out-of-graph replay memory + in-graph wrapper. It
4
+ supports vanilla n-step updates of the form typically found in the literature,
5
+ i.e. where rewards are accumulated for n steps and the intermediate trajectory
6
+ is not exposed to the agent. This does not allow, for example, performing
7
+ off-policy corrections.
8
+ """
9
+ import ctypes
10
+ import collections
11
+ import concurrent.futures
12
+ import os
13
+ from os.path import join
14
+ import pickle
15
+ from typing import List, Tuple, Type
16
+ import time
17
+ import math
18
+ # from threading import Lock
19
+ import multiprocessing as mp
20
+ from multiprocessing import Lock
21
+ import numpy as np
22
+ import logging
23
+
24
+ from natsort import natsort
25
+
26
+ from yarr.replay_buffer.replay_buffer import ReplayBuffer, ReplayElement
27
+ from yarr.utils.observation_type import ObservationElement
28
+
29
+ import torch.distributed as dist
30
+
31
+ # Defines a type describing part of the tuple returned by the replay
32
+ # memory. Each element of the tuple is a tensor of shape [batch, ...] where
33
+ # ... is defined the 'shape' field of ReplayElement. The tensor type is
34
+ # given by the 'type' field. The 'name' field is for convenience and ease of
35
+ # debugging.
36
+
37
+
38
+ # String constants for storage
39
+ ACTION = 'action'
40
+ REWARD = 'reward'
41
+ TERMINAL = 'terminal'
42
+ TIMEOUT = 'timeout'
43
+ INDICES = 'indices'
44
+
45
+
46
def invalid_range(cursor, replay_capacity, stack_size, update_horizon):
    """Returns an array with the indices of cursor-related invalid transitions.

    There are update_horizon + stack_size invalid indices:
      - the update_horizon indices before the cursor (no complete N-step
        transition exists for them yet), and
      - the stack_size indices on or immediately after the cursor.
    With N = update_horizon, K = stack_size and cursor c, the invalid
    indices are c - N, ..., c, ..., c + K - 1, wrapped circularly.

    Args:
        cursor: int, the position of the cursor.
        replay_capacity: int, the size of the replay memory.
        stack_size: int, the size of the stacks returned by the replay memory.
        update_horizon: int, the agent's update horizon.

    Returns:
        np.array of size stack_size + update_horizon with the invalid indices.
    """
    assert cursor < replay_capacity
    first_invalid = cursor - update_horizon
    span = stack_size + update_horizon
    return np.array(
        [(first_invalid + offset) % replay_capacity for offset in range(span)])
71
+
72
+
73
+ class UniformReplayBuffer(ReplayBuffer):
74
+ """A simple out-of-graph Replay Buffer.
75
+
76
+ Stores transitions, state, action, reward, next_state, terminal (and any
77
+ extra contents specified) in a circular buffer and provides a uniform
78
+ transition sampling function.
79
+
80
+ When the states consist of stacks of observations storing the states is
81
+ inefficient. This class writes observations and constructs the stacked states
82
+ at sample time.
83
+
84
+ Attributes:
85
+ _add_count: int, counter of how many transitions have been added (including
86
+ the blank ones at the beginning of an episode).
87
+ invalid_range: np.array, an array with the indices of cursor-related invalid
88
+ transitions
89
+ """
90
+
91
def __init__(self,
             batch_size: int = 32,
             timesteps: int = 1,
             replay_capacity: int = int(1e6),
             update_horizon: int = 1,
             gamma: float = 0.99,
             max_sample_attempts: int = 10000,
             action_shape: tuple = (),
             action_dtype: Type[np.dtype] = np.float32,
             reward_shape: tuple = (),
             reward_dtype: Type[np.dtype] = np.float32,
             observation_elements: List[ObservationElement] = None,
             extra_replay_elements: List[ReplayElement] = None,
             save_dir: str = None,
             purge_replay_on_shutdown: bool = True,
             num_replicas: int = None,
             rank: int = None,
             ):
    """Initializes the replay buffer.

    Args:
        batch_size: int, default number of transitions per sample.
        timesteps: int, number of frames to use in the state stack.
        replay_capacity: int, number of transitions to keep in memory.
        update_horizon: int, length of update ('n' in n-step update).
        gamma: float, the discount factor.
        max_sample_attempts: int, maximum attempts allowed to get a sample.
        action_shape: tuple of ints, shape of the action vector
            (empty tuple means the action is a scalar).
        action_dtype: np.dtype, type of elements in the action.
        reward_shape: tuple of ints, shape of the reward vector
            (empty tuple means the reward is a scalar).
        reward_dtype: np.dtype, type of elements in the reward.
        observation_elements: list of ObservationElement defining extra
            observation contents that will be stored and returned.
        extra_replay_elements: list of ReplayElement defining extra
            contents that will be stored and returned.
        save_dir: str or None; when given, transitions are pickled to
            this directory instead of kept in RAM.
        purge_replay_on_shutdown: bool, whether to delete on-disk replay
            files at shutdown.
        num_replicas: int or None; DDP world size. Falls back to
            torch.distributed when None.
        rank: int or None; DDP rank of this process. Falls back to
            torch.distributed when None.

    Raises:
        RuntimeError: If num_replicas/rank must be inferred but the
            distributed package is unavailable.
        ValueError: If rank is out of range, or replay_capacity is too
            small to hold at least one transition.
    """
    # BUG FIX: previously an explicitly supplied num_replicas/rank was
    # ignored — the attribute was only ever assigned from
    # torch.distributed inside the `is None` branch, leaving
    # self._num_replicas/_rank unset (AttributeError) when the caller
    # passed values. Honor caller-provided values and only fall back to
    # torch.distributed when they are None.
    if num_replicas is None:
        if not dist.is_available():
            raise RuntimeError("Requires distributed package to be available")
        num_replicas = dist.get_world_size()
    self._num_replicas = num_replicas
    if rank is None:
        if not dist.is_available():
            raise RuntimeError("Requires distributed package to be available")
        rank = dist.get_rank()
    self._rank = rank
    if self._rank >= self._num_replicas or self._rank < 0:
        raise ValueError(
            "Invalid rank {}, rank should be in the interval"
            " [0, {}]".format(self._rank, self._num_replicas - 1))

    if observation_elements is None:
        observation_elements = []
    if extra_replay_elements is None:
        extra_replay_elements = []

    if replay_capacity < update_horizon + timesteps:
        raise ValueError('There is not enough capacity to cover '
                         'update_horizon and stack_size.')

    logging.info(
        'Creating a %s replay memory with the following parameters:',
        self.__class__.__name__)
    logging.info('\t timesteps: %d', timesteps)
    logging.info('\t replay_capacity: %d', replay_capacity)
    logging.info('\t batch_size: %d', batch_size)
    logging.info('\t update_horizon: %d', update_horizon)
    logging.info('\t gamma: %f', gamma)

    self._disk_saving = save_dir is not None
    self._save_dir = save_dir
    self._purge_replay_on_shutdown = purge_replay_on_shutdown
    if self._disk_saving:
        logging.info('\t saving to disk: %s', self._save_dir)
        os.makedirs(save_dir, exist_ok=True)
    else:
        logging.info('\t saving to RAM')

    self._action_shape = action_shape
    self._action_dtype = action_dtype
    self._reward_shape = reward_shape
    self._reward_dtype = reward_dtype
    self._timesteps = timesteps
    self._replay_capacity = replay_capacity
    self._batch_size = batch_size
    self._update_horizon = update_horizon
    self._gamma = gamma
    self._max_sample_attempts = max_sample_attempts

    self._observation_elements = observation_elements
    self._extra_replay_elements = extra_replay_elements

    self._storage_signature, self._obs_signature = \
        self.get_storage_signature()
    self._create_storage()

    # Multiprocessing-safe bookkeeping: lock guards _store, the shared
    # counter tracks total adds across processes.
    self._lock = Lock()
    self._add_count = mp.Value('i', 0)

    self.invalid_range = np.zeros((self._timesteps))

    # When the horizon is > 1, we compute the sum of discounted rewards as a
    # dot product using the precomputed vector <gamma^0, ..., gamma^{n-1}>.
    self._cumulative_discount_vector = np.array(
        [math.pow(self._gamma, n) for n in range(update_horizon)],
        dtype=np.float32)
204
+
205
@property
def timesteps(self):
    """Number of stacked frames per returned state."""
    return self._timesteps


@property
def replay_capacity(self):
    """Maximum number of transitions held by the circular buffer."""
    return self._replay_capacity


@property
def batch_size(self):
    """Default number of transitions returned per sample call."""
    return self._batch_size
216
+
217
def _create_storage(self, store=None):
    """Allocate (or adopt) the per-element storage arrays.

    Args:
        store: Optional pre-existing dict to populate; a fresh dict is
            created when None.
    """
    self._store = store if store is not None else {}
    for element in self._storage_signature:
        array_shape = [self._replay_capacity] + list(element.shape)
        if element.name == TERMINAL:
            # Terminals are kept in RAM even when disk-saving, and start
            # at -1 so un-written slots are recognisable.
            self._store[element.name] = np.full(
                array_shape, -1, dtype=element.type)
        elif not self._disk_saving:
            # If saving to disk, we don't need to store anything else.
            self._store[element.name] = np.empty(
                array_shape, dtype=element.type)
230
+
231
def get_storage_signature(self) -> Tuple[List[ReplayElement],
                                         List[ReplayElement]]:
    """Describe every element stored per transition.

    Note - derived classes may return a different signature.

    Returns:
        Tuple of (all storage elements, observation-only elements).
    """
    storage_elements = [
        ReplayElement(ACTION, self._action_shape, self._action_dtype),
        ReplayElement(REWARD, self._reward_shape, self._reward_dtype),
        ReplayElement(TERMINAL, (), np.int8),
        ReplayElement(TIMEOUT, (), bool),
    ]

    obs_elements = [
        ReplayElement(obs.name, obs.shape, obs.type)
        for obs in self._observation_elements
    ]
    storage_elements.extend(obs_elements)
    storage_elements.extend(self._extra_replay_elements)

    return storage_elements, obs_elements
258
+
259
def add(self, action, reward, terminal, timeout, **kwargs):
    """Add a transition to the replay memory.

    Checks element types and delegates to _add. The next observation is
    not passed: it is simply the observation added on the following call
    (tp1 elements are only stored explicitly on the final frame). If the
    replay memory is at capacity, the oldest transition is overwritten.

    Args:
        action: the action in the transition.
        reward: float, the reward received in the transition.
        terminal: uint8-like flag, 1 if the transition was terminal,
            else 0.
        timeout: bool, whether the episode ended due to a timeout.
        **kwargs: the remaining transition elements.
    """
    kwargs[ACTION] = action
    kwargs[REWARD] = reward
    kwargs[TERMINAL] = terminal
    kwargs[TIMEOUT] = timeout
    self._check_add_types(kwargs, self._storage_signature)
    self._add(kwargs)
291
+
292
def add_final(self, **kwargs):
    """Add the final (terminal) observation of an episode.

    Args:
        **kwargs: The observation elements of the final frame.
    """
    self._check_add_types(kwargs, self._obs_signature)
    self._add(self._final_transition(kwargs))
302
+
303
def _final_transition(self, kwargs):
    """Build a full storage record for an episode's final observation.

    Elements not supplied are filled with empty placeholders; the
    terminal flag is set to -1, which downstream code uses to detect
    episode boundaries / padding slots.
    """
    transition = {}
    for element in self._storage_signature:
        name = element.name
        if name in kwargs:
            transition[name] = kwargs[name]
        elif name == TERMINAL:
            # Used to check that the user is correctly adding transitions.
            transition[name] = -1
        else:
            transition[name] = np.empty(
                element.shape, dtype=element.type)
    return transition
315
+
316
def _add_initial_to_disk(self, kwargs: dict):
    """Pad the end of the on-disk ring with copies of the first transition.

    Writing timesteps - 1 copies into the top slots of the circular file
    range keeps frame stacking correct when indices wrap around.
    """
    for offset in range(self._timesteps - 1):
        slot = self._replay_capacity - 1 - offset
        with open(join(self._save_dir, '%d.replay' % slot), 'wb') as f:
            pickle.dump(kwargs, f)
321
+
322
def _add(self, kwargs: dict):
    """Write one transition into storage (disk or RAM) under the lock.

    Args:
        kwargs: All the elements in a transition.
    """
    with self._lock:
        cursor = self.cursor()

        if self._disk_saving:
            # Terminals always stay in RAM so validity checks avoid disk.
            term = self._store[TERMINAL]
            term[cursor] = kwargs[TERMINAL]
            self._store[TERMINAL] = term
            with open(join(self._save_dir, '%d.replay' % cursor), 'wb') as f:
                pickle.dump(kwargs, f)
            if self._add_count.value == 0:
                # First ever add: pad for correct wraparound stacking.
                self._add_initial_to_disk(kwargs)
        else:
            for name, data in kwargs.items():
                item = self._store[name]
                item[cursor] = data
                self._store[name] = item

        with self._add_count.get_lock():
            self._add_count.value += 1
        self.invalid_range = invalid_range(
            self.cursor(), self._replay_capacity, self._timesteps,
            self._update_horizon)
350
+
351
def _get_from_disk(self, start_index, end_index):
    """Load a contiguous (possibly wrapped) index range from disk.

    Args:
        start_index: int, index to the start of the range to be returned.
            Range will wrap around if start_index is smaller than 0.
        end_index: int, exclusive end index. Range will wrap around if
            end_index exceeds replay_capacity.

    Returns:
        dict mapping element name -> {index: value} for the range,
        mimicking the layout of the in-RAM store.
    """
    assert end_index > start_index, 'end_index must be larger than start_index'
    assert end_index >= 0
    assert start_index < self._replay_capacity
    if not self.is_full():
        assert end_index <= self.cursor(), (
            'Index {} has not been added.'.format(start_index))

    # Here we fake a mini store (buffer) keyed like the real one.
    store = {element.name: {} for element in self._storage_signature}
    if start_index % self._replay_capacity < end_index % self._replay_capacity:
        indices = list(range(start_index, end_index))
    else:
        indices = [(start_index + i) % self._replay_capacity
                   for i in range(end_index - start_index)]
    for idx in indices:
        with open(join(self._save_dir, '%d.replay' % idx), 'rb') as f:
            record = pickle.load(f)
        for key, value in record.items():
            store[key][idx] = value
    return store
387
+
388
+ def _check_add_types(self, kwargs, signature):
389
+ """Checks if args passed to the add method match those of the storage.
390
+
391
+ Args:
392
+ *args: Args whose types need to be validated.
393
+
394
+ Raises:
395
+ ValueError: If args have wrong shape or dtype.
396
+ """
397
+
398
+ if (len(kwargs)) != len(signature):
399
+ expected = str(natsort.natsorted([e.name for e in signature]))
400
+ actual = str(natsort.natsorted(list(kwargs.keys())))
401
+ error_list = '\nList of expected:\n{}\nList of actual:\n{}'.format(
402
+ expected, actual)
403
+ raise ValueError('Add expects {} elements, received {}.'.format(
404
+ len(signature), len(kwargs)) + error_list)
405
+
406
+ for store_element in signature:
407
+ arg_element = kwargs[store_element.name]
408
+ if isinstance(arg_element, np.ndarray):
409
+ arg_shape = arg_element.shape
410
+ elif isinstance(arg_element, tuple) or isinstance(arg_element, list):
411
+ # TODO: This is not efficient when arg_element is a list.
412
+ arg_shape = np.array(arg_element).shape
413
+ else:
414
+ # Assume it is scalar.
415
+ arg_shape = tuple()
416
+ store_element_shape = tuple(store_element.shape)
417
+ if arg_shape != store_element_shape:
418
+ raise ValueError('arg {} has shape {}, expected {}'.format(store_element.name,
419
+ arg_shape, store_element_shape))
420
+
421
def is_empty(self):
    """True when no transitions have been added yet."""
    return self._add_count.value == 0


def is_full(self):
    """True once the circular buffer has wrapped at least once."""
    return self._add_count.value >= self._replay_capacity


def cursor(self):
    """Index to the location where the next transition will be written."""
    return self._add_count.value % self._replay_capacity


@property
def add_count(self):
    """Total number of added transitions, as a numpy scalar."""
    return np.array(self._add_count.value)


@add_count.setter
def add_count(self, count):
    # Accept either a plain int (wrapped into a shared mp.Value) or an
    # existing shared counter object.
    if isinstance(count, int):
        self._add_count = mp.Value('i', count)
    else:
        self._add_count = count
443
+
444
+
445
def get_range(self, array, start_index, end_index):
    """Read array[start_index:end_index] with circular wraparound.

    Args:
        array: np.array (or indexable), the array to read from.
        start_index: int, start of the range; wraps if negative.
        end_index: int, exclusive end index; wraps past replay_capacity.

    Returns:
        np.array with shape [end_index - start_index, ...].
    """
    assert end_index > start_index, 'end_index must be larger than start_index'
    assert end_index >= 0
    assert start_index < self._replay_capacity
    if not self.is_full():
        assert end_index <= self.cursor(), (
            'Index {} has not been added.'.format(start_index))

    if start_index % self._replay_capacity < end_index % self._replay_capacity:
        # Fast path: the range does not wrap.
        picked = [array[i] for i in range(start_index, end_index)]
    else:
        # Slow path: map each offset through the modulus.
        picked = [array[(start_index + i) % self._replay_capacity]
                  for i in range(end_index - start_index)]
    return np.array(picked)
476
+
477
def get_range_stack(self, array, start_index, end_index, terminals=None):
    """Read a wrapped range and pad entries that cross an episode start.

    A terminal value of -1 marks a slot that belongs to the previous
    episode; every such entry (and everything before it in the window)
    is replaced by the earliest frame of the current episode so a stack
    never mixes episodes.

    Args:
        array: np.array, the array to get the stack from.
        start_index: int, start of the range; wraps if negative.
        end_index: int, exclusive end index; wraps past replay_capacity.
        terminals: optional pre-fetched terminal flags for the range;
            read from the store when None.

    Returns:
        np.array with shape [end_index - start_index, ...].
    """
    return_array = np.array(self.get_range(array, start_index, end_index))
    if terminals is None:
        terminals = self.get_range(
            self._store[TERMINAL], start_index, end_index)

    terminals = terminals[:-1]

    # A -1 terminal means the window spills into the previous episode,
    # so the leading frames must be padded.
    if np.any(terminals == -1):
        padding_item = return_array[-1]
        head = list(return_array)[:-1]
        head_len = len(head)
        pad_from_now = False
        for offset, (frame, term) in enumerate(
                zip(reversed(head), reversed(terminals))):
            if term == -1 or pad_from_now:
                # The first -1 seen (scanning backwards) marks the start
                # of this episode; pad this slot and every earlier one.
                # pad_from_now is needed because the next slot (in
                # reverse) will not carry a -1 terminal.
                pad_from_now = True
                return_array[head_len - 1 - offset] = padding_item
            else:
                # Still inside the current episode: remember the earliest
                # in-episode frame as the padding source.
                padding_item = frame

    return return_array
519
+
520
+ def _get_element_stack(self, array, index, terminals=None):
521
+ state = self.get_range_stack(array,
522
+ index - self._timesteps + 1, index + 1,
523
+ terminals=terminals)
524
+ return state
525
+
526
+ def get_terminal_stack(self, index):
527
+ terminal_stack = self.get_range(self._store[TERMINAL],
528
+ index - self._timesteps + 1,
529
+ index + 1)
530
+ return terminal_stack
531
+
532
def is_valid_transition(self, index):
    """Checks whether index can anchor a sampled transition.

    Rejects indices outside the buffer, too close to the cursor (no
    complete n-step target yet), inside the cursor's invalid range, or
    whose stack ends on an episode-padding slot (terminal == -1).

    Args:
        index: int, the index to the state in the transition.

    Returns:
        Boolean validity of the index.
    """
    if index < 0 or index >= self._replay_capacity:
        return False
    if not self.is_full():
        # The indices and next_indices must be smaller than the cursor.
        if index >= self.cursor() - self._update_horizon:
            return False

    # Skip transitions that straddle the cursor.
    if index in set(self.invalid_range):
        return False

    # A trailing -1 terminal marks episode padding, not real data.
    if self.get_terminal_stack(index)[-1] == -1:
        return False

    return True
562
+
563
+ def _create_batch_arrays(self, batch_size):
564
+ """Create a tuple of arrays with the type of get_transition_elements.
565
+
566
+ When using the WrappedReplayBuffer with staging enabled it is important
567
+ to create new arrays every sample because StaginArea keeps a pointer to
568
+ the returned arrays.
569
+
570
+ Args:
571
+ batch_size: (int) number of transitions returned. If None the default
572
+ batch_size will be used.
573
+
574
+ Returns:
575
+ Tuple of np.arrays with the shape and type of get_transition_elements.
576
+ """
577
+ transition_elements = self.get_transition_elements(batch_size)
578
+ batch_arrays = []
579
+ for element in transition_elements:
580
+ batch_arrays.append(np.empty(element.shape, dtype=element.type))
581
+ return tuple(batch_arrays)
582
+
583
def sample_index_batch(self, batch_size):
    """Returns a batch of valid indices sampled uniformly.

    Args:
        batch_size: int, number of indices returned.

    Returns:
        list of ints, a batch of valid indices sampled uniformly.

    Raises:
        RuntimeError: If there are too few transitions to sample from,
            or a full batch could not be assembled within
            self._max_sample_attempts rejection attempts.
    """
    if self.is_full():
        # add_count >= replay_capacity, so the oldest usable index sits
        # just past the cursor's stale region.
        min_id = (self.cursor() - self._replay_capacity +
                  self._timesteps - 1)
        max_id = self.cursor() - self._update_horizon
    else:
        min_id = 0
        max_id = self.cursor() - self._update_horizon
        if max_id <= min_id:
            raise RuntimeError(
                'Cannot sample a batch with fewer than stack size '
                '({}) + update_horizon ({}) transitions.'.
                format(self._timesteps, self._update_horizon))

    indices = []
    attempt_count = 0
    # Rejection-sample uniformly over the usable index window.
    while (len(indices) < batch_size and
           attempt_count < self._max_sample_attempts):
        candidate = np.random.randint(min_id, max_id) % self._replay_capacity
        if self.is_valid_transition(candidate):
            indices.append(candidate)
        else:
            attempt_count += 1
    if len(indices) != batch_size:
        raise RuntimeError(
            'Max sample attempts: Tried {} times but only sampled {}'
            ' valid indices. Batch size is {}'.
            format(self._max_sample_attempts, len(indices), batch_size))

    return indices
626
+
627
def unpack_transition(self, transition_tensors, transition_type):
    """Zip sampled tensors with their spec into an ordered name->value map.

    Args:
        transition_tensors: tuple of tensors/arrays, one per element.
        transition_type: tuple of ReplayElements matching
            transition_tensors.

    Returns:
        collections.OrderedDict mapping element name to tensor; the same
        mapping is also stored on self.transition.
    """
    self.transition = collections.OrderedDict(
        (spec.name, tensor)
        for tensor, spec in zip(transition_tensors, transition_type))
    return self.transition
638
+
639
    def sample_transition_batch(self, batch_size=None, indices=None,
                                pack_in_dict=True):
        """Returns a batch of transitions (including any extra contents).

        If get_transition_elements has been overridden and defines elements not
        stored in self._store, an empty array will be returned and it will be
        left to the child class to fill it. For example, for the child class
        OutOfGraphPrioritizedReplayBuffer, the contents of the
        sampling_probabilities are stored separately in a sum tree.

        When the transition is terminal next_state_batch has undefined contents.

        NOTE: This transition contains the indices of the sampled elements.
        These are only valid during the call to sample_transition_batch,
        i.e. they may be used by subclasses of this replay buffer but may
        point to different data as soon as sampling is done.

        Args:
            batch_size: int, number of transitions returned. If None, the default
                batch_size will be used.
            indices: None or list of ints, the indices of every transition in the
                batch. If None, sample the indices uniformly.
            pack_in_dict: bool, if True the result is packed into an
                OrderedDict keyed by element name (via unpack_transition);
                otherwise the raw tuple of arrays is returned.

        Returns:
            transition_batch: tuple of np.arrays with the shape and type as in
                get_transition_elements(), or an OrderedDict of the same
                arrays when pack_in_dict is True (with 'task'/'task_tp1'
                entries removed).

        Raises:
            ValueError: If an element to be sampled is missing from the
                replay buffer.
        """

        if batch_size is None:
            batch_size = self._batch_size
        # Sampling and array filling happen under the buffer lock so that
        # concurrent add() calls cannot move the cursor mid-sample.
        with self._lock:
            if indices is None:
                indices = self.sample_index_batch(batch_size)
            assert len(indices) == batch_size

            transition_elements = self.get_transition_elements(batch_size)
            batch_arrays = self._create_batch_arrays(batch_size)

            for batch_element, state_index in enumerate(indices):

                if not self.is_valid_transition(state_index):
                    raise ValueError('Invalid index %d.' % state_index)

                # The n-step trajectory starting at state_index, wrapping
                # around the ring buffer as needed.
                trajectory_indices = [(state_index + j) % self._replay_capacity
                                      for j in range(self._update_horizon)]
                trajectory_terminals = self._store['terminal'][
                    trajectory_indices]
                is_terminal_transition = trajectory_terminals.any()
                if not is_terminal_transition:
                    trajectory_length = self._update_horizon
                else:
                    # np.argmax of a bool array returns index of the first True,
                    # so the trajectory is truncated at the first terminal.
                    trajectory_length = np.argmax(
                        trajectory_terminals.astype(bool),
                        0) + 1

                next_state_index = state_index + trajectory_length

                store = self._store
                if self._disk_saving:
                    # Pull the needed slice (including the frame stack before
                    # state_index) from disk into an in-memory store.
                    store = self._get_from_disk(
                        state_index - (self._timesteps - 1),
                        next_state_index + 1)

                # Discounted n-step return components for this trajectory.
                trajectory_discount_vector = (
                    self._cumulative_discount_vector[:trajectory_length])
                trajectory_rewards = self.get_range(store['reward'],
                                                    state_index,
                                                    next_state_index)

                terminal_stack = self.get_terminal_stack(state_index)
                terminal_stack_tp1 = self.get_terminal_stack(
                    next_state_index % self._replay_capacity)

                # Fill the contents of each array in the sampled batch.
                assert len(transition_elements) == len(batch_arrays)
                for element_array, element in zip(batch_arrays,
                                                  transition_elements):
                    if element.is_observation:
                        if element.name.endswith('tp1'):
                            # '_tp1' elements are the observation at the end
                            # of the n-step trajectory; strip the suffix to
                            # find the stored observation name.
                            element_array[
                                batch_element] = self._get_element_stack(
                                store[element.name[:-4]],
                                next_state_index % self._replay_capacity,
                                terminal_stack_tp1)
                        else:
                            element_array[
                                batch_element] = self._get_element_stack(
                                store[element.name],
                                state_index, terminal_stack)
                    elif element.name == REWARD:
                        # compute discounted sum of rewards in the trajectory.
                        element_array[batch_element] = np.sum(
                            trajectory_discount_vector * trajectory_rewards,
                            axis=0)
                    elif element.name == TERMINAL:
                        element_array[batch_element] = is_terminal_transition
                    elif element.name == INDICES:
                        element_array[batch_element] = state_index
                    elif element.name in store.keys():
                        element_array[batch_element] = (
                            store[element.name][state_index])

        if pack_in_dict:
            batch_arrays = self.unpack_transition(
                batch_arrays, transition_elements)

        # TODO(Mohit): proper fix to discard task names
        if 'task' in batch_arrays:
            del batch_arrays['task']
        if 'task_tp1' in batch_arrays:
            del batch_arrays['task_tp1']

        return batch_arrays
757
+
758
+ def get_transition_elements(self, batch_size=None):
759
+ """Returns a 'type signature' for sample_transition_batch.
760
+
761
+ Args:
762
+ batch_size: int, number of transitions returned. If None, the default
763
+ batch_size will be used.
764
+ Returns:
765
+ signature: A namedtuple describing the method's return type signature.
766
+ """
767
+ batch_size = self._batch_size if batch_size is None else batch_size
768
+
769
+ transition_elements = [
770
+ ReplayElement(ACTION, (batch_size,) + self._action_shape,
771
+ self._action_dtype),
772
+ ReplayElement(REWARD, (batch_size,) + self._reward_shape,
773
+ self._reward_dtype),
774
+ ReplayElement(TERMINAL, (batch_size,), np.int8),
775
+ ReplayElement(TIMEOUT, (batch_size,), bool),
776
+ ReplayElement(INDICES, (batch_size,), np.int32),
777
+ ]
778
+
779
+ for element in self._observation_elements:
780
+ transition_elements.append(ReplayElement(
781
+ element.name,
782
+ (batch_size, self._timesteps) + tuple(element.shape),
783
+ element.type, True))
784
+ transition_elements.append(ReplayElement(
785
+ element.name + '_tp1',
786
+ (batch_size, self._timesteps) + tuple(element.shape),
787
+ element.type, True))
788
+
789
+ for element in self._extra_replay_elements:
790
+ transition_elements.append(ReplayElement(
791
+ element.name,
792
+ (batch_size,) + tuple(element.shape),
793
+ element.type))
794
+ return transition_elements
795
+
796
+ def shutdown(self):
797
+ if self._purge_replay_on_shutdown:
798
+ # Safely delete replay
799
+ logging.info('Clearing disk replay buffer.')
800
+ for f in [f for f in os.listdir(self._save_dir) if '.replay' in f]:
801
+ os.remove(join(self._save_dir, f))
802
+
803
    def using_disk(self):
        """Return True if transitions are persisted to disk (see _get_from_disk) rather than held only in memory."""
        return self._disk_saving
external/yarr/yarr/replay_buffer/wrappers/__init__.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import Any
3
+
4
+ from yarr.replay_buffer.replay_buffer import ReplayBuffer
5
+
6
+
7
class WrappedReplayBuffer(ABC):
    """Abstract base for wrappers that expose a ReplayBuffer to a framework.

    Subclasses implement `dataset()` to return a framework-specific
    iterable over sampled transition batches (e.g. a PyTorch DataLoader).
    """

    def __init__(self, replay_buffer: ReplayBuffer):
        """Initializes WrappedReplayBuffer.

        Args:
            replay_buffer: the underlying ReplayBuffer instance to wrap.
        """
        self._replay_buffer = replay_buffer

    @property
    def replay_buffer(self) -> ReplayBuffer:
        # Direct access to the wrapped buffer (e.g. for adding transitions).
        return self._replay_buffer

    @abstractmethod
    def dataset(self) -> Any:
        """Return a framework-specific iterable of sampled transition batches."""
        pass
external/yarr/yarr/replay_buffer/wrappers/pytorch_replay_buffer.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from threading import Thread
3
+
4
+ from torch.utils.data import IterableDataset, DataLoader
5
+
6
+ from yarr.replay_buffer.replay_buffer import ReplayBuffer
7
+ from yarr.replay_buffer.wrappers import WrappedReplayBuffer
8
+
9
+
10
class PyTorchIterableReplayDataset(IterableDataset):
    """Endless IterableDataset that draws batches straight from a replay buffer.

    Each item yielded is already a full batch (an OrderedDict of arrays), as
    produced by ``sample_transition_batch(pack_in_dict=True)``.
    """

    def __init__(self, replay_buffer: ReplayBuffer):
        self._replay_buffer = replay_buffer

    def __iter__(self):
        # Infinite stream: every next() samples a fresh batch.
        while True:
            yield self._replay_buffer.sample_transition_batch(
                pack_in_dict=True)
21
+
22
+ # class PyTorchIterableReplayDataset(IterableDataset):
23
+ #
24
+ # BUFFER = 4
25
+ #
26
+ # def __init__(self, replay_buffer: ReplayBuffer, num_workers: int):
27
+ # self._replay_buffer = replay_buffer
28
+ # self._num_wokers = num_workers
29
+ # self._samples = []
30
+ # self._lock = Lock()
31
+ #
32
+ # def _run(self):
33
+ # while True:
34
+ # # Check if replay buffer is ig enough to be sampled
35
+ # while self._replay_buffer.add_count < self._replay_buffer.batch_size:
36
+ # time.sleep(1.)
37
+ # s = self._replay_buffer.sample_transition_batch(pack_in_dict=True)
38
+ # while len(self._samples) >= PyTorchIterableReplayDataset.BUFFER:
39
+ # time.sleep(0.25)
40
+ # with self._lock:
41
+ # self._samples.append(s)
42
+ #
43
+ # def _generator(self):
44
+ # ts = [Thread(
45
+ # target=self._run, args=()) for _ in range(self._num_wokers)]
46
+ # [t.start() for t in ts]
47
+ # while True:
48
+ # while len(self._samples) == 0:
49
+ # time.sleep(0.1)
50
+ # with self._lock:
51
+ # s = self._samples.pop(0)
52
+ # yield s
53
+ #
54
+ # def __iter__(self):
55
+ # i = iter(self._generator())
56
+ # return i
57
+
58
+
59
class PyTorchReplayBuffer(WrappedReplayBuffer):
    """Wraps a ReplayBuffer and exposes it as a PyTorch ``DataLoader``.

    Usage:
        To add a transition: call ``add`` on the wrapped ``replay_buffer``.

        To sample batches: iterate over the ``DataLoader`` returned by
        :meth:`dataset`; every iteration draws a new transition batch
        from the underlying buffer.

    (The previous docstring described TensorFlow ``sess.run`` in-graph
    sampling, which does not apply to this PyTorch wrapper.)
    """

    def __init__(self, replay_buffer: ReplayBuffer, num_workers: int = 2):
        """Initializes the wrapper.

        Args:
            replay_buffer: the underlying buffer to sample from.
            num_workers: number of DataLoader worker processes.
        """
        super(PyTorchReplayBuffer, self).__init__(replay_buffer)
        self._num_workers = num_workers

    def dataset(self, batch_size=None, drop_last=False) -> DataLoader:
        """Return a DataLoader that streams sampled transition batches.

        Args:
            batch_size: forwarded to the DataLoader. ``None`` disables
                automatic batching, which is the intended mode here since
                the replay buffer already returns whole batches.
            drop_last: forwarded to the DataLoader.
        """
        ds = PyTorchIterableReplayDataset(self._replay_buffer)
        return DataLoader(ds, batch_size=batch_size,
                          drop_last=drop_last,
                          num_workers=self._num_workers, pin_memory=True)
external/yarr/yarr/runners/__init__.py ADDED
File without changes
external/yarr/yarr/runners/_env_runner.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import logging
3
+ import os
4
+ import time
5
+ import pandas as pd
6
+
7
+ from multiprocessing import Process, Manager
8
+ from multiprocessing import get_start_method, set_start_method
9
+ from typing import Any
10
+
11
+ import numpy as np
12
+ import torch
13
+ from yarr.agents.agent import Agent
14
+ from yarr.agents.agent import ScalarSummary
15
+ from yarr.agents.agent import Summary
16
+ from yarr.envs.env import Env
17
+ from yarr.utils.rollout_generator import RolloutGenerator
18
+ from yarr.utils.log_writer import LogWriter
19
+ from yarr.utils.process_str import change_case
20
+ from yarr.utils.video_utils import CircleCameraMotion, TaskRecorder
21
+
22
+ from pyrep.objects.dummy import Dummy
23
+ from pyrep.objects.vision_sensor import VisionSensor
24
+
25
# Force the 'spawn' multiprocessing start method for child env processes —
# presumably because forked processes are unsafe with the simulator/CUDA
# state held by the parent (TODO confirm). set_start_method raises
# RuntimeError if the context was already fixed elsewhere; in that case
# keep whatever was configured.
try:
    if get_start_method() != 'spawn':
        set_start_method('spawn', force=True)
except RuntimeError:
    pass
30
+
31
+
32
class _EnvRunner(object):
    """Internal worker that rolls out episodes in child processes.

    Spawns one process per environment (`spin_up_envs`), each of which
    repeatedly loads the newest agent weights, generates episodes via the
    rollout generator, and pushes the resulting transitions / agent
    summaries into Manager-backed shared lists for the parent EnvRunner
    to consume.
    """

    def __init__(self,
                 train_env: Env,
                 eval_env: Env,
                 agent: Agent,
                 timesteps: int,
                 train_envs: int,
                 eval_envs: int,
                 rollout_episodes: int,
                 eval_episodes: int,
                 training_iterations: int,
                 eval_from_eps_number: int,
                 episode_length: int,
                 kill_signal: Any,
                 step_signal: Any,
                 num_eval_episodes_signal: Any,
                 eval_epochs_signal: Any,
                 eval_report_signal: Any,
                 log_freq: int,
                 rollout_generator: RolloutGenerator,
                 save_load_lock,
                 current_replay_ratio,
                 target_replay_ratio,
                 weightsdir: str = None,
                 logdir: str = None,
                 env_device: torch.device = None,
                 previous_loaded_weight_folder: str = '',
                 num_eval_runs: int = 1,
                 ):
        self._train_env = train_env
        self._eval_env = eval_env
        self._agent = agent
        self._train_envs = train_envs
        self._eval_envs = eval_envs
        self._rollout_episodes = rollout_episodes
        self._eval_episodes = eval_episodes
        self._training_iterations = training_iterations
        self._num_eval_runs = num_eval_runs
        self._eval_from_eps_number = eval_from_eps_number
        self._episode_length = episode_length
        self._rollout_generator = rollout_generator
        self._weightsdir = weightsdir
        self._logdir = logdir
        self._env_device = env_device
        self._previous_loaded_weight_folder = previous_loaded_weight_folder

        self._timesteps = timesteps

        # Per-process bookkeeping: spawn args and failure counts, keyed by
        # process name, so crashed env processes can be restarted.
        self._p_args = {}
        self.p_failures = {}
        manager = Manager()
        # Shared state consumed by the parent runner process.
        self.write_lock = manager.Lock()
        self.stored_transitions = manager.list()
        self.agent_summaries = manager.list()
        self._kill_signal = kill_signal
        self._step_signal = step_signal
        self._num_eval_episodes_signal = num_eval_episodes_signal
        self._eval_epochs_signal = eval_epochs_signal
        self._eval_report_signal = eval_report_signal
        self._save_load_lock = save_load_lock
        self._current_replay_ratio = current_replay_ratio
        self._target_replay_ratio = target_replay_ratio
        self._log_freq = log_freq

        self._new_weights = False

    def restart_process(self, name: str):
        """Respawn a previously registered env process (after a failure)."""
        p = Process(target=self._run_env, args=self._p_args[name], name=name)
        p.start()
        return p

    def spin_up_envs(self, name: str, num_envs: int, eval: bool):
        """Start `num_envs` rollout processes named `name0`, `name1`, ..."""
        ps = []
        for i in range(num_envs):
            n = name + str(i)
            self._p_args[n] = (n, eval)
            self.p_failures[n] = 0
            p = Process(target=self._run_env, args=self._p_args[n], name=n)
            p.start()
            ps.append(p)
        return ps

    def _load_save(self):
        """Block until agent weights exist, then load the newest checkpoint.

        Sets `self._new_weights` to True only when a checkpoint newer than
        the previously loaded one was found.
        """
        if self._weightsdir is None:
            logging.info("'weightsdir' was None, so not loading weights.")
            return
        while True:
            weight_folders = []
            with self._save_load_lock:
                if os.path.exists(self._weightsdir):
                    weight_folders = os.listdir(self._weightsdir)
                if len(weight_folders) > 0:
                    weight_folders = sorted(map(int, weight_folders))
                    # Only load if there has been a new weight saving
                    if self._previous_loaded_weight_folder != weight_folders[-1]:
                        self._previous_loaded_weight_folder = weight_folders[-1]
                        d = os.path.join(self._weightsdir, str(weight_folders[-1]))
                        try:
                            self._agent.load_weights(d)
                        except FileNotFoundError:
                            # Rare case when agent hasn't finished writing.
                            time.sleep(1)
                            self._agent.load_weights(d)
                        print('Agent %s: Loaded weights: %s' % (self._name, d))
                        self._new_weights = True
                    else:
                        self._new_weights = False
                    break
            print('Waiting for weights to become available.')
            time.sleep(1)

    def _get_type(self, x):
        # Downcast float64 arrays to float32 when describing dtypes.
        if x.dtype == np.float64:
            return np.float32
        return x.dtype

    def _get_task_name(self):
        """Return (task_name, multi_task) for the eval env.

        Raises:
            Exception: if the eval env exposes neither `_task_class` nor
                `_task_classes`.
        """
        if hasattr(self._eval_env, '_task_class'):
            eval_task_name = change_case(self._eval_env._task_class.__name__)
            multi_task = False
        elif hasattr(self._eval_env, '_task_classes'):
            if self._eval_env.active_task_id != -1:
                task_id = (self._eval_env.active_task_id) % len(self._eval_env._task_classes)
                eval_task_name = change_case(self._eval_env._task_classes[task_id].__name__)
            else:
                eval_task_name = ''
            multi_task = True
        else:
            raise Exception('Neither task_class nor task_classes found in eval env')
        return eval_task_name, multi_task

    def _run_env(self, name: str, eval: bool):
        """Child-process entry point: roll out `self._rollout_episodes` episodes."""
        self._name = name

        # Deep-copy so this process owns its agent; build on the env device.
        self._agent = copy.deepcopy(self._agent)

        self._agent.build(training=False, device=self._env_device)

        logging.info('%s: Launching env.' % name)
        np.random.seed()

        logging.info('Agent information:')
        logging.info(self._agent)

        env = self._train_env
        if eval:
            env = self._eval_env
        env.eval = eval
        env.launch()
        for ep in range(self._rollout_episodes):
            self._load_save()
            logging.debug('%s: Starting episode %d.' % (name, ep))
            episode_rollout = []
            # BUGFIX: `eval_demo_seed` and `rec_cfg` were previously
            # referenced here without ever being defined, raising NameError
            # on the first episode. Derive the demo seed the same way
            # _IndependentEnvRunner does, and disable recording since no
            # recorder is configured in this code path.
            eval_demo_seed = ep + self._eval_from_eps_number
            generator = self._rollout_generator.generator(
                self._step_signal, env, self._agent,
                self._episode_length, self._timesteps,
                eval, eval_demo_seed=eval_demo_seed,
                record_enabled=False)
            try:
                for replay_transition in generator:
                    # Throttle rollouts so training keeps up with the
                    # target replay ratio (training-time only).
                    while True:
                        if self._kill_signal.value:
                            env.shutdown()
                            return
                        if (eval or self._target_replay_ratio is None or
                                self._step_signal.value <= 0 or (
                                        self._current_replay_ratio.value >
                                        self._target_replay_ratio)):
                            break
                        time.sleep(1)
                        logging.debug(
                            'Agent. Waiting for replay_ratio %f to be more than %f' %
                            (self._current_replay_ratio.value, self._target_replay_ratio))

                    with self.write_lock:
                        if len(self.agent_summaries) == 0:
                            # Only store new summaries if the previous ones
                            # have been popped by the main env runner.
                            for s in self._agent.act_summaries():
                                self.agent_summaries.append(s)
                    episode_rollout.append(replay_transition)
            except StopIteration as e:
                continue
            except Exception as e:
                env.shutdown()
                raise e

            with self.write_lock:
                for transition in episode_rollout:
                    self.stored_transitions.append((name, transition, eval))
        env.shutdown()

    def kill(self):
        """Signal all rollout processes to exit at their next check."""
        self._kill_signal.value = True
external/yarr/yarr/runners/_independent_env_runner.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import logging
3
+ import os
4
+ import time
5
+ import pandas as pd
6
+
7
+ from multiprocessing import Process, Manager
8
+ from multiprocessing import get_start_method, set_start_method
9
+ from typing import Any
10
+
11
+ import numpy as np
12
+ import torch
13
+ from yarr.agents.agent import Agent
14
+ from yarr.agents.agent import ScalarSummary
15
+ from yarr.agents.agent import Summary
16
+ from yarr.envs.env import Env
17
+ from yarr.utils.rollout_generator import RolloutGenerator
18
+ from yarr.utils.log_writer import LogWriter
19
+ from yarr.utils.process_str import change_case
20
+ from yarr.utils.video_utils import CircleCameraMotion, TaskRecorder
21
+
22
+ from pyrep.objects.dummy import Dummy
23
+ from pyrep.objects.vision_sensor import VisionSensor
24
+
25
+ from yarr.runners._env_runner import _EnvRunner
26
+
27
+
28
class _IndependentEnvRunner(_EnvRunner):
    """Evaluation-only runner that loads specific checkpoints and scores them.

    Unlike _EnvRunner's continuous rollout loop, `_run_eval_independent`
    evaluates either a single weight (int) across all tasks, or a per-task
    mapping of best weights (dict), optionally recording cinematic videos
    of each episode.
    """

    def __init__(self,
                 train_env: Env,
                 eval_env: Env,
                 agent: Agent,
                 timesteps: int,
                 train_envs: int,
                 eval_envs: int,
                 rollout_episodes: int,
                 eval_episodes: int,
                 training_iterations: int,
                 eval_from_eps_number: int,
                 episode_length: int,
                 kill_signal: Any,
                 step_signal: Any,
                 num_eval_episodes_signal: Any,
                 eval_epochs_signal: Any,
                 eval_report_signal: Any,
                 log_freq: int,
                 rollout_generator: RolloutGenerator,
                 save_load_lock,
                 current_replay_ratio,
                 target_replay_ratio,
                 weightsdir: str = None,
                 logdir: str = None,
                 env_device: torch.device = None,
                 previous_loaded_weight_folder: str = '',
                 num_eval_runs: int = 1,
                 ):
        # All state is held by the parent class; this subclass only
        # overrides behavior.
        super().__init__(train_env, eval_env, agent, timesteps,
                         train_envs, eval_envs, rollout_episodes, eval_episodes,
                         training_iterations, eval_from_eps_number, episode_length,
                         kill_signal, step_signal, num_eval_episodes_signal,
                         eval_epochs_signal, eval_report_signal, log_freq,
                         rollout_generator, save_load_lock, current_replay_ratio,
                         target_replay_ratio, weightsdir, logdir, env_device,
                         previous_loaded_weight_folder, num_eval_runs)

    def _load_save(self):
        """Block until weights exist, then load the newest checkpoint.

        Same as _EnvRunner._load_save but reports via logging instead of
        print.
        """
        if self._weightsdir is None:
            logging.info("'weightsdir' was None, so not loading weights.")
            return
        while True:
            weight_folders = []
            with self._save_load_lock:
                if os.path.exists(self._weightsdir):
                    weight_folders = os.listdir(self._weightsdir)
                if len(weight_folders) > 0:
                    weight_folders = sorted(map(int, weight_folders))
                    # only load if there has been a new weight saving
                    if self._previous_loaded_weight_folder != weight_folders[-1]:
                        self._previous_loaded_weight_folder = weight_folders[-1]
                        d = os.path.join(self._weightsdir, str(weight_folders[-1]))
                        try:
                            self._agent.load_weights(d)
                        except FileNotFoundError:
                            # rare case when agent hasn't finished writing.
                            time.sleep(1)
                            self._agent.load_weights(d)
                        logging.info('Agent %s: Loaded weights: %s' % (self._name, d))
                        self._new_weights = True
                    else:
                        self._new_weights = False
                    break
            logging.info('Waiting for weights to become available.')
            time.sleep(1)

    def _get_task_name(self):
        """Return (task_name, multi_task) for the eval env.

        Raises:
            Exception: if the eval env exposes neither `_task_class` nor
                `_task_classes`.
        """
        if hasattr(self._eval_env, '_task_class'):
            eval_task_name = change_case(self._eval_env._task_class.__name__)
            multi_task = False
        elif hasattr(self._eval_env, '_task_classes'):
            if self._eval_env.active_task_id != -1:
                task_id = (self._eval_env.active_task_id) % len(self._eval_env._task_classes)
                eval_task_name = change_case(self._eval_env._task_classes[task_id].__name__)
            else:
                eval_task_name = ''
            multi_task = True
        else:
            raise Exception('Neither task_class nor task_classes found in eval env')
        return eval_task_name, multi_task

    def _run_eval_independent(self, name: str,
                              stats_accumulator,
                              weight,
                              writer_lock,
                              eval=True,
                              device_idx=0,
                              save_metrics=True,
                              cinematic_recorder_cfg=None):
        """Evaluate checkpoint(s) over `num_eval_runs` runs of M episodes.

        Args:
            name: process name used in logs.
            stats_accumulator: accumulates per-transition stats; popped per run.
            weight: int (one checkpoint for all tasks / validation) or dict
                mapping task name -> best checkpoint (test evaluation).
                NOTE(review): other types leave weight_path/seed_path/
                weight_name unbound and would crash below — confirm callers.
            writer_lock: inter-process lock guarding log writing (also
                abused to serialize agent builds; see comment below).
            eval: forwarded to env/rollout generator; defaults True.
            device_idx: CUDA device index when more than one GPU is visible.
            save_metrics: if False, skip CSV metric writing (video-only runs).
            cinematic_recorder_cfg: recorder config with fields enabled,
                camera_resolution, rotate_speed, fps.
                NOTE(review): the default None crashes at `rec_cfg.enabled`
                — callers apparently always pass a config; confirm.
        """

        self._name = name
        self._save_metrics = save_metrics
        # dict-valued weight implies per-task best weights, i.e. a test-set run.
        self._is_test_set = type(weight) == dict

        self._agent = copy.deepcopy(self._agent)

        device = torch.device('cuda:%d' % device_idx) if torch.cuda.device_count() > 1 else torch.device('cuda:0')
        with writer_lock: # hack to prevent multiple CLIP downloads ... argh should use a separate lock
            self._agent.build(training=False, device=device)

        logging.info('%s: Launching env.' % name)
        np.random.seed()

        logging.info('Agent information:')
        logging.info(self._agent)

        env = self._eval_env
        env.eval = eval
        env.launch()

        # initialize cinematic recorder if specified
        rec_cfg = cinematic_recorder_cfg
        if rec_cfg.enabled:
            cam_placeholder = Dummy('cam_cinematic_placeholder')
            cam = VisionSensor.create(rec_cfg.camera_resolution)
            cam.set_pose(cam_placeholder.get_pose())
            cam.set_parent(cam_placeholder)

            cam_motion = CircleCameraMotion(cam, Dummy('cam_cinematic_base'), rec_cfg.rotate_speed)
            tr = TaskRecorder(env, cam_motion, fps=rec_cfg.fps)

            # record a frame after every arm action step
            env.env._action_mode.arm_action_mode.set_callable_each_step(tr.take_snap)

        if not os.path.exists(self._weightsdir):
            raise Exception('No weights directory found.')

        # to save or not to save evaluation metrics (set as False for recording videos)
        if self._save_metrics:
            csv_file = 'eval_data.csv' if not self._is_test_set else 'test_data.csv'
            writer = LogWriter(self._logdir, True, True,
                               env_csv=csv_file)

        # one weight for all tasks (used for validation)
        if type(weight) == int:
            logging.info('Evaluating weight %s' % weight)
            weight_path = os.path.join(self._weightsdir, str(weight))
            seed_path = self._weightsdir.replace('/weights', '')
            self._agent.load_weights(weight_path)
            weight_name = str(weight)

        new_transitions = {'train_envs': 0, 'eval_envs': 0}
        total_transitions = {'train_envs': 0, 'eval_envs': 0}
        current_task_id = -1

        for n_eval in range(self._num_eval_runs):
            if rec_cfg.enabled:
                tr._cam_motion.save_pose()

            # best weight for each task (used for test evaluation)
            if type(weight) == dict:
                task_name = list(weight.keys())[n_eval]
                task_weight = weight[task_name]
                weight_path = os.path.join(self._weightsdir, str(task_weight))
                seed_path = self._weightsdir.replace('/weights', '')
                self._agent.load_weights(weight_path)
                weight_name = str(task_weight)
                print('Evaluating weight %s for %s' % (weight_name, task_name))

            # evaluate on N tasks * M episodes per task = total eval episodes
            for ep in range(self._eval_episodes):
                # demo seed makes each episode's scene reproducible
                eval_demo_seed = ep + self._eval_from_eps_number
                logging.info('%s: Starting episode %d, seed %d.' % (name, ep, eval_demo_seed))

                # the current task gets reset after every M episodes
                episode_rollout = []
                generator = self._rollout_generator.generator(
                    self._step_signal, env, self._agent,
                    self._episode_length, self._timesteps,
                    eval, eval_demo_seed=eval_demo_seed,
                    record_enabled=rec_cfg.enabled)
                try:
                    for replay_transition in generator:
                        # honor kill signal / replay-ratio throttling
                        # (throttling is a no-op when eval=True)
                        while True:
                            if self._kill_signal.value:
                                env.shutdown()
                                return
                            if (eval or self._target_replay_ratio is None or
                                    self._step_signal.value <= 0 or (
                                            self._current_replay_ratio.value >
                                            self._target_replay_ratio)):
                                break
                            time.sleep(1)
                            logging.debug(
                                'Agent. Waiting for replay_ratio %f to be more than %f' %
                                (self._current_replay_ratio.value, self._target_replay_ratio))

                        with self.write_lock:
                            if len(self.agent_summaries) == 0:
                                # Only store new summaries if the previous ones
                                # have been popped by the main env runner.
                                for s in self._agent.act_summaries():
                                    self.agent_summaries.append(s)
                        episode_rollout.append(replay_transition)
                except StopIteration as e:
                    continue
                except Exception as e:
                    env.shutdown()
                    raise e

                with self.write_lock:
                    for transition in episode_rollout:
                        self.stored_transitions.append((name, transition, eval))

                        new_transitions['eval_envs'] += 1
                        total_transitions['eval_envs'] += 1
                        stats_accumulator.step(transition, eval)
                        current_task_id = transition.info['active_task_id']

                self._num_eval_episodes_signal.value += 1

                task_name, _ = self._get_task_name()
                # final transition's reward is the episode score
                reward = episode_rollout[-1].reward
                lang_goal = env._lang_goal
                print(f"Evaluating {task_name} | Episode {ep} | Score: {reward} | Lang Goal: {lang_goal}")

                # save recording
                if rec_cfg.enabled:
                    success = reward > 0.99
                    record_file = os.path.join(seed_path, 'videos',
                                               '%s_w%s_s%s_%s.mp4' % (task_name,
                                                                      weight_name,
                                                                      eval_demo_seed,
                                                                      'succ' if success else 'fail'))

                    lang_goal = self._eval_env._lang_goal

                    tr.save(record_file, lang_goal, reward)
                    tr._cam_motion.restore_pose()

            # report summaries
            summaries = []
            summaries.extend(stats_accumulator.pop())

            eval_task_name, multi_task = self._get_task_name()

            # multi-task runs namespace eval summaries by task name
            if eval_task_name and multi_task:
                for s in summaries:
                    if 'eval' in s.name:
                        s.name = '%s/%s' % (s.name, eval_task_name)

            if len(summaries) > 0:
                if multi_task:
                    task_score = [s.value for s in summaries if f'eval_envs/return/{eval_task_name}' in s.name][0]
                else:
                    task_score = [s.value for s in summaries if f'eval_envs/return' in s.name][0]
            else:
                task_score = "unknown"

            print(f"Finished {eval_task_name} | Final Score: {task_score}\n")

            if self._save_metrics:
                with writer_lock:
                    writer.add_summaries(weight_name, summaries)

            # reset shared state for the next eval run
            self._new_transitions = {'train_envs': 0, 'eval_envs': 0}
            self.agent_summaries[:] = []
            self.stored_transitions[:] = []

        if self._save_metrics:
            with writer_lock:
                writer.end_iteration()

        logging.info('Finished evaluation.')
        env.shutdown()

    def kill(self):
        """Signal the eval loop to exit at its next check."""
        self._kill_signal.value = True
external/yarr/yarr/runners/env_runner.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import logging
3
+ import os
4
+ import signal
5
+ import time
6
+ from multiprocessing import Value
7
+ from threading import Thread
8
+ from typing import List
9
+ from typing import Union
10
+
11
+ import numpy as np
12
+ import torch
13
+ from yarr.agents.agent import Agent
14
+ from yarr.agents.agent import ScalarSummary
15
+ from yarr.agents.agent import Summary
16
+ from yarr.envs.env import Env
17
+ from yarr.replay_buffer.replay_buffer import ReplayBuffer
18
+ from yarr.runners._env_runner import _EnvRunner
19
+ from yarr.utils.rollout_generator import RolloutGenerator
20
+ from yarr.utils.stat_accumulator import StatAccumulator, SimpleAccumulator
21
+ from yarr.utils.process_str import change_case
22
+ from helpers.custom_rlbench_env import CustomRLBenchEnv, CustomMultiTaskRLBenchEnv
23
+
24
class EnvRunner(object):
    """Supervises asynchronous train/eval environment processes.

    A background thread (started via :meth:`start`) delegates to an internal
    ``_EnvRunner`` that spawns one process per environment, drains the
    transitions those processes produce into the replay buffer(s), restarts
    crashed or hung processes, and collects logging summaries.
    Communication with the child processes happens through shared-memory
    ``multiprocessing.Value`` flags created in ``__init__``.
    """

    def __init__(self,
                 train_env: Env,
                 agent: Agent,
                 train_replay_buffer: Union[ReplayBuffer, List[ReplayBuffer]],
                 num_train_envs: int,
                 num_eval_envs: int,
                 rollout_episodes: int,
                 eval_episodes: int,
                 training_iterations: int,
                 eval_from_eps_number: int,
                 episode_length: int,
                 eval_env: Union[Env, None] = None,
                 eval_replay_buffer: Union[ReplayBuffer, List[ReplayBuffer], None] = None,
                 stat_accumulator: Union[StatAccumulator, None] = None,
                 rollout_generator: RolloutGenerator = None,
                 weightsdir: str = None,
                 logdir: str = None,
                 max_fails: int = 10,
                 num_eval_runs: int = 1,
                 env_device: torch.device = None,
                 multi_task: bool = False):
        self._train_env = train_env
        # Fall back to the training env when no dedicated eval env is given.
        self._eval_env = eval_env if eval_env else train_env
        self._agent = agent
        self._train_envs = num_train_envs
        self._eval_envs = num_eval_envs
        # Normalise to a list so single-buffer and multi-buffer (multi-task)
        # setups share one code path.
        self._train_replay_buffer = train_replay_buffer if isinstance(train_replay_buffer, list) else [train_replay_buffer]
        self._timesteps = self._train_replay_buffer[0].timesteps if self._train_replay_buffer[0] is not None else 1

        if eval_replay_buffer is not None:
            eval_replay_buffer = eval_replay_buffer if isinstance(eval_replay_buffer, list) else [eval_replay_buffer]
        self._eval_replay_buffer = eval_replay_buffer
        self._rollout_episodes = rollout_episodes
        self._eval_episodes = eval_episodes
        self._num_eval_runs = num_eval_runs
        self._training_iterations = training_iterations
        self._eval_from_eps_number = eval_from_eps_number
        self._episode_length = episode_length
        self._stat_accumulator = stat_accumulator
        self._rollout_generator = (
            RolloutGenerator() if rollout_generator is None
            else rollout_generator)
        # NOTE(review): pokes a private attribute of RolloutGenerator —
        # relies on that class's internals; confirm the attribute exists.
        self._rollout_generator._env_device = env_device
        self._weightsdir = weightsdir
        self._logdir = logdir
        self._max_fails = max_fails
        self._env_device = env_device
        self._previous_loaded_weight_folder = ''
        self._p = None  # the supervising thread, created in start()
        # Shared-memory flags for communicating with child env processes.
        self._kill_signal = Value('b', 0)
        self._step_signal = Value('i', -1)
        self._num_eval_episodes_signal = Value('i', 0)
        self._eval_epochs_signal = Value('i', 0)
        self._eval_report_signal = Value('b', 0)
        # Transition counters since the last summaries() call / overall.
        self._new_transitions = {'train_envs': 0, 'eval_envs': 0}
        self._total_transitions = {'train_envs': 0, 'eval_envs': 0}
        self.log_freq = 1000  # Will get overridden later
        self.target_replay_ratio = None  # Will get overridden later
        self.current_replay_ratio = Value('f', -1)
        self._current_task_id = -1  # id of the task last evaluated, -1 = none
        self._multi_task = multi_task

    def summaries(self) -> List[Summary]:
        """Collect the summaries accumulated since the last call.

        Pops accumulator stats, appends new/total transition counters and the
        agent's summaries, and — in multi-task mode — suffixes eval summary
        names with the current task name.  Resets the per-interval
        transition counters as a side effect.
        """
        summaries = []
        if self._stat_accumulator is not None:
            summaries.extend(self._stat_accumulator.pop())
        for key, value in self._new_transitions.items():
            summaries.append(ScalarSummary('%s/new_transitions' % key, value))
        for key, value in self._total_transitions.items():
            summaries.append(ScalarSummary('%s/total_transitions' % key, value))
        self._new_transitions = {'train_envs': 0, 'eval_envs': 0}
        # NOTE(review): _agent_summaries is first assigned in _update();
        # calling summaries() before the runner has started would raise
        # AttributeError — confirm callers always start the runner first.
        summaries.extend(self._agent_summaries)

        # add current task_name to eval summaries .... argh this should be inside a helper function
        if hasattr(self._eval_env, '_task_class'):
            eval_task_name = change_case(self._eval_env._task_class.__name__)
        elif hasattr(self._eval_env, '_task_classes'):
            if self._current_task_id != -1:
                task_id = (self._current_task_id) % len(self._eval_env._task_classes)
                eval_task_name = change_case(self._eval_env._task_classes[task_id].__name__)
            else:
                eval_task_name = ''
        else:
            raise Exception('Neither task_class nor task_classes found in eval env')

        # multi-task summaries
        if eval_task_name and self._multi_task:
            for s in summaries:
                if 'eval' in s.name:
                    s.name = '%s/%s' % (s.name, eval_task_name)

        return summaries

    def _update(self):
        # Move the stored transitions to the replay and accumulate statistics.
        new_transitions = collections.defaultdict(int)
        with self._internal_env_runner.write_lock:
            self._agent_summaries = list(
                self._internal_env_runner.agent_summaries)
            # Clear the agent summaries once a full eval round has completed.
            if self._num_eval_episodes_signal.value % self._eval_episodes == 0 and self._num_eval_episodes_signal.value > 0:
                self._internal_env_runner.agent_summaries[:] = []
            for name, transition, eval in self._internal_env_runner.stored_transitions:
                # Eval transitions are only stored when an eval buffer exists.
                add_to_buffer = (not eval) or self._eval_replay_buffer is not None
                if add_to_buffer:
                    kwargs = dict(transition.observation)
                    # Route the transition to the buffer of its task.
                    replay_index = transition.info["active_task_id"]
                    rb = self._eval_replay_buffer[replay_index] if eval else self._train_replay_buffer[replay_index]
                    rb.add(
                        np.array(transition.action), transition.reward,
                        transition.terminal,
                        transition.timeout, **kwargs)
                    if transition.terminal:
                        rb.add_final(
                            **transition.final_observation)
                new_transitions[name] += 1
                self._new_transitions[
                    'eval_envs' if eval else 'train_envs'] += 1
                self._total_transitions[
                    'eval_envs' if eval else 'train_envs'] += 1
                if self._stat_accumulator is not None:
                    self._stat_accumulator.step(transition, eval)
                self._current_task_id = transition.info["active_task_id"] if eval else -1
            self._internal_env_runner.stored_transitions[:] = []  # Clear list
        return new_transitions

    def _run(self, save_load_lock):
        """Main supervision loop (runs on the EnvRunnerThread).

        Spins up the env processes, then loops: harvests transitions,
        restarts processes that crashed (up to ``max_fails`` times) or have
        produced no transitions for a long time, until all processes exit.
        """
        self._internal_env_runner = _EnvRunner(
            self._train_env, self._eval_env, self._agent, self._timesteps, self._train_envs,
            self._eval_envs, self._rollout_episodes, self._eval_episodes,
            self._training_iterations, self._eval_from_eps_number, self._episode_length, self._kill_signal,
            self._step_signal, self._num_eval_episodes_signal,
            self._eval_epochs_signal, self._eval_report_signal,
            self.log_freq, self._rollout_generator, save_load_lock,
            self.current_replay_ratio, self.target_replay_ratio,
            self._weightsdir, self._logdir,
            self._env_device, self._previous_loaded_weight_folder,
            num_eval_runs=self._num_eval_runs)
        training_envs = self._internal_env_runner.spin_up_envs('train_env', self._train_envs, False)
        eval_envs = self._internal_env_runner.spin_up_envs('eval_env', self._eval_envs, True)
        envs = training_envs + eval_envs
        no_transitions = {env.name: 0 for env in envs}
        while True:
            # NOTE(review): `envs` is mutated (remove/append) while being
            # iterated, which can skip the element after a removal; verify
            # this is acceptable here (the next loop pass re-checks).
            for p in envs:
                if p.exitcode is not None:
                    envs.remove(p)
                    if p.exitcode != 0:
                        # Non-zero exit: count the failure and restart unless
                        # the process has already failed too many times.
                        self._internal_env_runner.p_failures[p.name] += 1
                        n_failures = self._internal_env_runner.p_failures[p.name]
                        if n_failures > self._max_fails:
                            logging.error('Env %s failed too many times (%d times > %d)' %
                                          (p.name, n_failures, self._max_fails))
                            raise RuntimeError('Too many process failures.')
                        logging.warning('Env %s failed (%d times <= %d). restarting' %
                                        (p.name, n_failures, self._max_fails))
                        p = self._internal_env_runner.restart_process(p.name)
                        envs.append(p)

            if not self._kill_signal.value:
                new_transitions = self._update()
                for p in envs:
                    if new_transitions[p.name] == 0:
                        no_transitions[p.name] += 1
                    else:
                        no_transitions[p.name] = 0
                    # Kill and restart an env that has been silent too long.
                    if no_transitions[p.name] > 1200:  # 600: # 10min
                        logging.warning("Env %s hangs, so restarting" % p.name)
                        envs.remove(p)
                        os.kill(p.pid, signal.SIGTERM)
                        p = self._internal_env_runner.restart_process(p.name)
                        envs.append(p)
                        no_transitions[p.name] = 0

            if len(envs) == 0:
                break
            time.sleep(1)

    def start(self, save_load_lock):
        """Launch the supervision loop on a daemon thread."""
        self._p = Thread(target=self._run, args=(save_load_lock,), daemon=True)
        self._p.name = 'EnvRunnerThread'
        self._p.start()

    def wait(self):
        """Block until the supervision thread finishes."""
        if self._p.is_alive():
            self._p.join()

    def stop(self):
        """Signal all env processes to stop and join the thread."""
        if self._p.is_alive():
            self._kill_signal.value = True
            self._p.join()

    def set_step(self, step):
        """Publish the current training step to the env processes."""
        self._step_signal.value = step

    def set_eval_report(self, report):
        """Set/clear the flag indicating eval results are ready to report."""
        self._eval_report_signal.value = report

    def set_eval_epochs(self, epochs):
        """Set the eval-epoch counter (used when resuming a run)."""
        self._eval_epochs_signal.value = epochs
224
+
external/yarr/yarr/runners/independent_env_runner.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from typing import List
4
+ from typing import Union
5
+
6
+ from yarr.agents.agent import Agent
7
+ from yarr.envs.env import Env
8
+ from yarr.replay_buffer.replay_buffer import ReplayBuffer
9
+ from yarr.runners._independent_env_runner import _IndependentEnvRunner
10
+ from yarr.utils.rollout_generator import RolloutGenerator
11
+ from yarr.utils.stat_accumulator import StatAccumulator, SimpleAccumulator
12
+ from yarr.agents.agent import Summary
13
+ from helpers.custom_rlbench_env import CustomRLBenchEnv, CustomMultiTaskRLBenchEnv
14
+
15
+ from yarr.runners.env_runner import EnvRunner
16
+
17
+
18
class IndependentEnvRunner(EnvRunner):
    """EnvRunner variant that evaluates a single weight snapshot serially.

    Unlike the parent class it does not spin up background training envs:
    :meth:`start` builds the eval environment itself from ``env_config`` and
    runs the evaluation synchronously in the calling thread.
    """

    def __init__(self,
                 train_env: Env,
                 agent: Agent,
                 train_replay_buffer: Union[ReplayBuffer, List[ReplayBuffer]],
                 num_train_envs: int,
                 num_eval_envs: int,
                 rollout_episodes: int,
                 eval_episodes: int,
                 training_iterations: int,
                 eval_from_eps_number: int,
                 episode_length: int,
                 eval_env: Union[Env, None] = None,
                 eval_replay_buffer: Union[ReplayBuffer, List[ReplayBuffer], None] = None,
                 stat_accumulator: Union[StatAccumulator, None] = None,
                 rollout_generator: RolloutGenerator = None,
                 weightsdir: str = None,
                 logdir: str = None,
                 max_fails: int = 10,
                 num_eval_runs: int = 1,
                 env_device: torch.device = None,
                 multi_task: bool = False):
        super().__init__(train_env, agent, train_replay_buffer, num_train_envs, num_eval_envs,
                         rollout_episodes, eval_episodes, training_iterations, eval_from_eps_number,
                         episode_length, eval_env, eval_replay_buffer, stat_accumulator,
                         rollout_generator, weightsdir, logdir, max_fails, num_eval_runs,
                         env_device, multi_task)

    def summaries(self) -> List[Summary]:
        """Collect accumulated eval summaries, tagged with the task name.

        Unlike the parent implementation this does not emit transition
        counters (this runner does not feed a replay buffer).
        """
        # BUG FIX: `change_case` was referenced here but never imported in
        # this module (the parent module imports it; this one did not),
        # which raised NameError as soon as the env exposed `_task_class`.
        from yarr.utils.process_str import change_case

        summaries = []
        if self._stat_accumulator is not None:
            summaries.extend(self._stat_accumulator.pop())
        self._new_transitions = {'train_envs': 0, 'eval_envs': 0}
        summaries.extend(self._agent_summaries)

        # Resolve the name of the task currently being evaluated.
        if hasattr(self._eval_env, '_task_class'):
            eval_task_name = change_case(self._eval_env._task_class.__name__)
        elif hasattr(self._eval_env, '_task_classes'):
            if self._current_task_id != -1:
                task_id = (self._current_task_id) % len(self._eval_env._task_classes)
                eval_task_name = change_case(self._eval_env._task_classes[task_id].__name__)
            else:
                eval_task_name = ''
        else:
            raise Exception('Neither task_class nor task_classes found in eval env')

        # In multi-task mode, suffix eval summary names with the task name.
        if eval_task_name and self._multi_task:
            for s in summaries:
                if 'eval' in s.name:
                    s.name = '%s/%s' % (s.name, eval_task_name)

        return summaries

    # serialized evaluator for individual tasks
    def start(self, weight,
              save_load_lock, writer_lock,
              env_config,
              device_idx,
              save_metrics,
              cinematic_recorder_cfg):
        """Build the eval env from ``env_config`` and run one evaluation.

        Args:
            weight: Weight snapshot (folder/step) to evaluate.
            save_load_lock: Lock guarding weight load/save (unused positions
                kept for interface parity with the parent class).
            writer_lock: Lock serialising metric writes across evaluators.
            env_config: Positional tuple of env constructor arguments; a list
                in slot 0 indicates a multi-task configuration.
            device_idx: GPU index the evaluation should run on.
            save_metrics: Whether to persist evaluation metrics.
            cinematic_recorder_cfg: Config for optional video recording.
        """
        if hasattr(self, "_on_thread_start"):
            self._on_thread_start()

        # A list of task classes in slot 0 signals a multi-task env config.
        multi_task = isinstance(env_config[0], list)
        if multi_task:
            eval_env = CustomMultiTaskRLBenchEnv(
                task_classes=env_config[0],
                observation_config=env_config[1],
                action_mode=env_config[2],
                dataset_root=env_config[3],
                episode_length=env_config[4],
                headless=env_config[5],
                swap_task_every=env_config[6],
                include_lang_goal_in_obs=env_config[7],
                time_in_state=env_config[8],
                record_every_n=env_config[9])
        else:
            eval_env = CustomRLBenchEnv(
                task_class=env_config[0],
                observation_config=env_config[1],
                action_mode=env_config[2],
                dataset_root=env_config[3],
                episode_length=env_config[4],
                headless=env_config[5],
                include_lang_goal_in_obs=env_config[6],
                time_in_state=env_config[7],
                record_every_n=env_config[8])

        self._internal_env_runner = _IndependentEnvRunner(
            self._train_env, eval_env, self._agent, self._timesteps, self._train_envs,
            self._eval_envs, self._rollout_episodes, self._eval_episodes,
            self._training_iterations, self._eval_from_eps_number, self._episode_length, self._kill_signal,
            self._step_signal, self._num_eval_episodes_signal,
            self._eval_epochs_signal, self._eval_report_signal,
            self.log_freq, self._rollout_generator, None,
            self.current_replay_ratio, self.target_replay_ratio,
            self._weightsdir, self._logdir,
            self._env_device, self._previous_loaded_weight_folder,
            num_eval_runs=self._num_eval_runs)

        stat_accumulator = SimpleAccumulator(eval_video_fps=30)
        # Run the evaluation synchronously in this thread.
        self._internal_env_runner._run_eval_independent('eval_env',
                                                        stat_accumulator,
                                                        weight,
                                                        writer_lock,
                                                        True,
                                                        device_idx,
                                                        save_metrics,
                                                        cinematic_recorder_cfg)
external/yarr/yarr/runners/offline_train_runner.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import logging
3
+ import os
4
+ import shutil
5
+ import time
6
+ from typing import List
7
+ from typing import Union
8
+
9
+ import psutil
10
+ import torch
11
+ import pandas as pd
12
+ from yarr.agents.agent import Agent
13
+ from yarr.replay_buffer.wrappers.pytorch_replay_buffer import \
14
+ PyTorchReplayBuffer
15
+ from yarr.utils.log_writer import LogWriter
16
+ from yarr.utils.stat_accumulator import StatAccumulator
17
+
18
+
19
class OfflineTrainRunner():
    """Offline (dataset-only) training loop.

    Repeatedly samples batches from a pre-filled replay buffer, calls
    ``agent.update`` and periodically logs metrics and saves weight
    snapshots.  No environments are stepped here.  In distributed training
    only ``rank == 0`` writes logs.
    """

    def __init__(self,
                 agent: Agent,
                 wrapped_replay_buffer: PyTorchReplayBuffer,
                 train_device: torch.device,
                 stat_accumulator: Union[StatAccumulator, None] = None,
                 iterations: int = int(6e6),
                 logdir: str = '/tmp/yarr/logs',
                 logging_level: int = logging.INFO,
                 log_freq: int = 10,
                 weightsdir: str = '/tmp/yarr/weights',
                 num_weights_to_keep: int = 60,
                 save_freq: int = 100,
                 tensorboard_logging: bool = True,
                 csv_logging: bool = False,
                 load_existing_weights: bool = True,
                 rank: int = None,
                 world_size: int = None):
        self._agent = agent
        # NOTE: the original assigned _wrapped_buffer twice; once suffices.
        self._wrapped_buffer = wrapped_replay_buffer
        self._stat_accumulator = stat_accumulator
        self._iterations = iterations
        self._logdir = logdir
        self._logging_level = logging_level
        self._log_freq = log_freq
        self._weightsdir = weightsdir
        self._num_weights_to_keep = num_weights_to_keep
        self._save_freq = save_freq
        self._train_device = train_device
        self._tensorboard_logging = tensorboard_logging
        self._csv_logging = csv_logging
        self._load_existing_weights = load_existing_weights
        self._rank = rank
        self._world_size = world_size

        self._writer = None
        if logdir is None:
            logging.info("'logdir' was None. No logging will take place.")
        else:
            self._writer = LogWriter(
                self._logdir, tensorboard_logging, csv_logging)

        if weightsdir is None:
            logging.info(
                "'weightsdir' was None. No weight saving will take place.")
        else:
            os.makedirs(self._weightsdir, exist_ok=True)

    def _save_model(self, i):
        """Save agent weights for step ``i`` and prune the oldest snapshot."""
        d = os.path.join(self._weightsdir, str(i))
        os.makedirs(d, exist_ok=True)
        self._agent.save_weights(d)

        # Remove the snapshot that falls out of the retention window.
        prev_dir = os.path.join(self._weightsdir, str(
            i - self._save_freq * self._num_weights_to_keep))
        if os.path.exists(prev_dir):
            shutil.rmtree(prev_dir)

    def _step(self, i, sampled_batch):
        """Run one agent update and return the total loss."""
        update_dict = self._agent.update(i, sampled_batch)
        total_losses = update_dict['total_losses']
        return total_losses

    def _get_resume_eval_epoch(self):
        """Return the last eval epoch recorded in the eval CSV, or 0."""
        starting_epoch = 0
        eval_csv_file = self._weightsdir.replace('weights', 'eval_data.csv')  # TODO(mohit): check if it's supposed be 'env_data.csv'
        if os.path.exists(eval_csv_file):
            eval_dict = pd.read_csv(eval_csv_file).to_dict()
            epochs = list(eval_dict['step'].values())
            return epochs[-1] if len(epochs) > 0 else starting_epoch
        else:
            return starting_epoch

    def start(self):
        """Run the full offline training loop (blocking)."""
        if hasattr(self, "_on_thread_start"):
            self._on_thread_start()
        else:
            logging.getLogger().setLevel(self._logging_level)

        # Work on a private copy of the agent, built on the training device.
        self._agent = copy.deepcopy(self._agent)
        self._agent.build(training=True, device=self._train_device)

        # BUG FIX: `start_iter` was previously assigned only inside the
        # `weightsdir is not None` branch, so passing weightsdir=None raised
        # UnboundLocalError at the training loop below.
        start_iter = 0
        if self._weightsdir is not None:
            existing_weights = sorted([int(f) for f in os.listdir(self._weightsdir)])
            if self._load_existing_weights and len(existing_weights) > 0:
                # Resume from the newest snapshot on disk.
                resume_iteration = existing_weights[-1]
                self._agent.load_weights(os.path.join(self._weightsdir, str(resume_iteration)))
                start_iter = resume_iteration + 1
                if self._rank == 0:
                    logging.info(f"Resuming training from iteration {resume_iteration} ...")

        dataset = self._wrapped_buffer.dataset()
        data_iter = iter(dataset)

        process = psutil.Process(os.getpid())
        num_cpu = psutil.cpu_count()

        for i in range(start_iter, self._iterations):
            log_iteration = i % self._log_freq == 0 and i > 0

            if log_iteration:
                # Prime cpu_percent() so the next call reports a real delta.
                process.cpu_percent(interval=None)

            t = time.time()
            sampled_batch = next(data_iter)
            sample_time = time.time() - t

            # Move only tensor entries to the training device.
            batch = {k: v.to(self._train_device) for k, v in sampled_batch.items() if isinstance(v, torch.Tensor)}
            t = time.time()
            loss = self._step(i, batch)
            step_time = time.time() - t

            if self._rank == 0:
                if log_iteration and self._writer is not None:
                    agent_summaries = self._agent.update_summaries()
                    self._writer.add_summaries(i, agent_summaries)

                    self._writer.add_scalar(
                        i, 'monitoring/memory_gb',
                        process.memory_info().rss * 1e-9)
                    self._writer.add_scalar(
                        i, 'monitoring/cpu_percent',
                        process.cpu_percent(interval=None) / num_cpu)

                    logging.info(f"Train Step {i:06d} | Loss: {loss:0.5f} | Sample time: {sample_time:0.6f} | Step time: {step_time:0.4f}.")

                # BUG FIX: previously called unconditionally; when
                # logdir=None leaves _writer as None this raised
                # AttributeError on every iteration.
                if self._writer is not None:
                    self._writer.end_iteration()

            if i % self._save_freq == 0 and self._weightsdir is not None:
                self._save_model(i)

        if self._rank == 0 and self._writer is not None:
            self._writer.close()
            logging.info('Stopping envs ...')

        self._wrapped_buffer.replay_buffer.shutdown()
external/yarr/yarr/runners/pytorch_train_runner.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import logging
3
+ import os
4
+ import shutil
5
+ import signal
6
+ import sys
7
+ import threading
8
+ import time
9
+ from multiprocessing import Lock
10
+ from typing import Optional, List
11
+ from typing import Union
12
+
13
+ import gc
14
+ import numpy as np
15
+ import psutil
16
+ import torch
17
+ import pandas as pd
18
+ from yarr.agents.agent import Agent
19
+ from yarr.replay_buffer.wrappers.pytorch_replay_buffer import \
20
+ PyTorchReplayBuffer
21
+ from yarr.runners.env_runner import EnvRunner
22
+ from yarr.runners.train_runner import TrainRunner
23
+ from yarr.utils.log_writer import LogWriter
24
+ from yarr.utils.stat_accumulator import StatAccumulator
25
+ from yarr.replay_buffer.prioritized_replay_buffer import PrioritizedReplayBuffer
26
+
27
+ NUM_WEIGHTS_TO_KEEP = 60
28
+
29
+
30
class PyTorchTrainRunner(TrainRunner):
    """Online training loop: trains while an ``EnvRunner`` collects data.

    Samples batches from one or more replay buffers, updates the agent, and
    optionally throttles training so the replay ratio (samples consumed per
    sample collected) stays below ``replay_ratio``.
    """

    def __init__(self,
                 agent: Agent,
                 env_runner: EnvRunner,
                 wrapped_replay_buffer: Union[
                     PyTorchReplayBuffer, List[PyTorchReplayBuffer]],
                 train_device: torch.device,
                 replay_buffer_sample_rates: List[float] = None,
                 stat_accumulator: Union[StatAccumulator, None] = None,
                 iterations: int = int(1e6),
                 num_train_envs: int = 1,
                 num_eval_envs: int = 1,
                 eval_episodes: int = 10,
                 logdir: str = '/tmp/yarr/logs',
                 log_freq: int = 10,
                 transitions_before_train: int = 1000,
                 weightsdir: str = '/tmp/yarr/weights',
                 save_freq: int = 100,
                 replay_ratio: Optional[float] = None,
                 tensorboard_logging: bool = True,
                 csv_logging: bool = False,
                 buffers_per_batch: int = -1,  # -1 = all
                 load_existing_weights: bool = True):
        super(PyTorchTrainRunner, self).__init__(
            agent, env_runner, wrapped_replay_buffer,
            stat_accumulator,
            iterations, logdir, log_freq, transitions_before_train, weightsdir,
            save_freq)

        env_runner.log_freq = log_freq
        env_runner.target_replay_ratio = replay_ratio
        # Normalise to a list so single- and multi-buffer setups share code.
        self._wrapped_buffer = wrapped_replay_buffer if isinstance(
            wrapped_replay_buffer, list) else [wrapped_replay_buffer]
        self._replay_buffer_sample_rates = (
            [1.0] if replay_buffer_sample_rates is None else
            replay_buffer_sample_rates)
        # BUG FIX: compare against the normalised list; `len()` on a single
        # non-list buffer raised TypeError here.
        if len(self._replay_buffer_sample_rates) != len(self._wrapped_buffer):
            logging.warning(
                'Numbers of replay buffers differs from sampling rates. Setting as uniform sampling.')
            self._replay_buffer_sample_rates = [1.0 / len(self._wrapped_buffer)] * len(self._wrapped_buffer)
        # BUG FIX: exact float equality rejected valid rates such as
        # [1/3, 1/3, 1/3]; tolerate floating-point rounding instead.
        if abs(sum(self._replay_buffer_sample_rates) - 1.0) > 1e-6:
            raise ValueError('Sum of sampling rates should be 1.')

        self._train_device = train_device
        self._tensorboard_logging = tensorboard_logging
        self._csv_logging = csv_logging
        self._num_train_envs = num_train_envs
        self._num_eval_envs = num_eval_envs
        self._eval_episodes = eval_episodes
        self._load_existing_weights = load_existing_weights

        if replay_ratio is not None and replay_ratio < 0:
            raise ValueError("max_replay_ratio must be positive.")
        self._target_replay_ratio = replay_ratio

        self._writer = None
        if logdir is None:
            logging.info("'logdir' was None. No logging will take place.")
        else:
            self._writer = LogWriter(
                self._logdir, tensorboard_logging, csv_logging)
        if weightsdir is None:
            logging.info(
                "'weightsdir' was None. No weight saving will take place.")
        else:
            os.makedirs(self._weightsdir, exist_ok=True)
        self._buffers_per_batch = buffers_per_batch if buffers_per_batch > 0 else len(self._wrapped_buffer)

    def _save_model(self, i):
        """Save agent weights for step ``i`` and prune the oldest snapshot."""
        with self._save_load_lock:
            d = os.path.join(self._weightsdir, str(i))
            os.makedirs(d, exist_ok=True)
            self._agent.save_weights(d)
            # Remove the snapshot that falls out of the retention window.
            prev_dir = os.path.join(self._weightsdir, str(
                i - self._save_freq * NUM_WEIGHTS_TO_KEEP))
            if os.path.exists(prev_dir):
                shutil.rmtree(prev_dir)

    def _step(self, i, sampled_batch):
        """Run one agent update and push new priorities back to the buffers."""
        update_dict = self._agent.update(i, sampled_batch)
        if "priority" in update_dict:
            # BUG FIX: `np.numpy(...)` does not exist — use np.asarray for
            # the non-tensor case.
            priority = update_dict['priority'].cpu().detach().numpy() if isinstance(update_dict['priority'], torch.Tensor) else np.asarray(update_dict['priority'])
        else:
            priority = None
        indices = sampled_batch['indices'].cpu().detach().numpy()
        acc_bs = 0
        for wb_idx, wb in enumerate(self._wrapped_buffer):
            bs = wb.replay_buffer.batch_size
            if 'priority' in update_dict:
                indices_ = indices[:, wb_idx]
                if hasattr(wb, "replay_buffer"):
                    if len(priority.shape) > 1:
                        # Per-buffer priority columns.
                        priority_ = priority[:, wb_idx]
                    else:
                        # legacy version: flat priorities, sliced per buffer.
                        priority_ = priority[acc_bs: acc_bs + bs]
                    wb.replay_buffer.set_priority(indices_, priority_)
            acc_bs += bs

    def _signal_handler(self, sig, frame):
        """SIGINT handler: shut down envs and buffers, then exit."""
        if threading.current_thread().name != 'MainThread':
            return
        logging.info('SIGINT captured. Shutting down.'
                     'This may take a few seconds.')
        self._env_runner.stop()
        [r.replay_buffer.shutdown() for r in self._wrapped_buffer]
        sys.exit(0)

    def _get_add_counts(self):
        """Per-buffer counts of transitions added so far."""
        return np.array([
            r.replay_buffer.add_count for r in self._wrapped_buffer])

    def _get_sum_add_counts(self):
        """Total transitions added across all buffers."""
        return sum([
            r.replay_buffer.add_count for r in self._wrapped_buffer])

    def _get_resume_eval_epoch(self):
        """Return the last eval epoch recorded in the eval CSV, or 0."""
        starting_epoch = 0
        eval_csv_file = self._weightsdir.replace('weights', 'eval_data.csv')  # TODO(mohit): check if it's supposed be 'env_data.csv'
        if os.path.exists(eval_csv_file):
            eval_dict = pd.read_csv(eval_csv_file).to_dict()
            epochs = list(eval_dict['step'].values())
            return epochs[-1] if len(epochs) > 0 else starting_epoch
        else:
            return starting_epoch

    def start(self):
        """Run the full online training loop (blocking)."""
        signal.signal(signal.SIGINT, self._signal_handler)

        self._save_load_lock = Lock()

        # Kick off the environments
        self._env_runner.start(self._save_load_lock)

        self._agent = copy.deepcopy(self._agent)
        self._agent.build(training=True, device=self._train_device)

        # BUG FIX: `start_iter` was previously assigned only inside the
        # `weightsdir is not None` branch, so passing weightsdir=None raised
        # UnboundLocalError at the training loop below.
        start_iter = 0
        if self._weightsdir is not None:
            existing_weights = sorted([int(f) for f in os.listdir(self._weightsdir)])
            if (not self._load_existing_weights) or len(existing_weights) == 0:
                self._save_model(0)
            else:
                # Resume from the newest snapshot on disk.
                resume_iteration = existing_weights[-1]
                self._agent.load_weights(os.path.join(self._weightsdir, str(resume_iteration)))
                start_iter = resume_iteration + 1
                print(f"Resuming training from iteration {resume_iteration} ...")

                if self._num_eval_envs > 0:
                    eval_epoch = self._get_resume_eval_epoch()
                    self._env_runner.set_eval_epochs(eval_epoch)
                    self._writer.set_resumed_from_prev_run(True)
                    print(f"Resuming evaluation from epoch {eval_epoch} ...")

        # Block until every buffer holds enough transitions to train on.
        while (np.any(self._get_add_counts() < self._transitions_before_train)):
            time.sleep(1)
            logging.info(
                'Waiting for %d samples before training. Currently have %s.' %
                (self._transitions_before_train, str(self._get_add_counts())))

        datasets = [r.dataset() for r in self._wrapped_buffer]
        data_iter = [iter(d) for d in datasets]

        # BUG FIX: `.astype(float)` fails when the summed add counts are
        # plain Python ints; float() handles both int and numpy scalars.
        init_replay_size = float(self._get_sum_add_counts())
        batch_times_buffers_per_sample = sum([
            r.replay_buffer.batch_size for r in self._wrapped_buffer[:self._buffers_per_batch]])
        process = psutil.Process(os.getpid())
        num_cpu = psutil.cpu_count()

        for i in range(start_iter, self._iterations):
            self._env_runner.set_step(i)

            if self._num_train_envs > 0 or self._num_eval_envs == 0:
                log_iteration = i % self._log_freq == 0 and i > 0
            else:
                # Eval-only mode: log when an eval report is pending.
                num_eval_episodes = self._env_runner._num_eval_episodes_signal.value
                log_iteration = self._env_runner._eval_report_signal.value and num_eval_episodes > 0

            if log_iteration:
                # Prime cpu_percent() so the next call reports a real delta.
                process.cpu_percent(interval=None)

            def get_replay_ratio():
                # Samples consumed per sample collected since start().
                size_used = batch_times_buffers_per_sample * i
                size_added = (
                    self._get_sum_add_counts()
                    - init_replay_size
                )
                replay_ratio = size_used / (size_added + 1e-6)
                return replay_ratio

            if self._target_replay_ratio is not None:
                # wait for env_runner collecting enough samples
                while True:
                    replay_ratio = get_replay_ratio()
                    self._env_runner.current_replay_ratio.value = replay_ratio
                    if replay_ratio < self._target_replay_ratio:
                        break
                    time.sleep(1)
                    logging.debug(
                        'Waiting for replay_ratio %f to be less than %f.' %
                        (replay_ratio, self._target_replay_ratio))
                del replay_ratio

            t = time.time()

            # Sample one batch per chosen buffer and stack along dim 1.
            sampled_task_ids = np.random.choice(
                range(len(datasets)), self._buffers_per_batch, replace=False)
            sampled_batch = [next(data_iter[j]) for j in sampled_task_ids]
            result = {}
            for key in sampled_batch[0]:
                result[key] = torch.stack([d[key] for d in sampled_batch], 1)
            sampled_batch = result
            sample_time = time.time() - t

            batch = {k: v.to(self._train_device) for k, v in sampled_batch.items()}
            t = time.time()
            self._step(i, batch)
            step_time = time.time() - t

            if log_iteration and self._writer is not None:
                replay_ratio = get_replay_ratio()
                logging.info('Train Step %d. Eval Epoch %d. Sample time: %s. Step time: %s. Replay ratio: %s.' % (
                    i, self._env_runner._eval_epochs_signal.value, sample_time, step_time, replay_ratio))
                agent_summaries = self._agent.update_summaries()
                env_summaries = self._env_runner.summaries()

                # agent summaries
                self._writer.add_summaries(i, agent_summaries)

                # env summaries (indexed by eval epoch rather than step)
                self._writer.add_summaries(self._env_runner._eval_epochs_signal.value, env_summaries)

                for r_i, wrapped_buffer in enumerate(self._wrapped_buffer):
                    self._writer.add_scalar(
                        i, 'replay%d/add_count' % r_i,
                        wrapped_buffer.replay_buffer.add_count)
                    self._writer.add_scalar(
                        i, 'replay%d/size' % r_i,
                        wrapped_buffer.replay_buffer.replay_capacity
                        if wrapped_buffer.replay_buffer.is_full()
                        else wrapped_buffer.replay_buffer.add_count)

                self._writer.add_scalar(
                    i, 'replay/replay_ratio', replay_ratio)
                self._writer.add_scalar(
                    i, 'replay/update_to_insert_ratio',
                    float(i) / float(
                        self._get_sum_add_counts() -
                        init_replay_size + 1e-6))

                self._writer.add_scalar(
                    i, 'monitoring/sample_time_per_item',
                    sample_time / batch_times_buffers_per_sample)
                self._writer.add_scalar(
                    i, 'monitoring/train_time_per_item',
                    step_time / batch_times_buffers_per_sample)
                self._writer.add_scalar(
                    i, 'monitoring/memory_gb',
                    process.memory_info().rss * 1e-9)
                self._writer.add_scalar(
                    i, 'monitoring/cpu_percent',
                    process.cpu_percent(interval=None) / num_cpu)

                self._env_runner.set_eval_report(False)

                self._writer.end_iteration()

            if i % self._save_freq == 0 and self._weightsdir is not None:
                self._save_model(i)

        if self._writer is not None:
            self._writer.close()

        logging.info('Stopping envs ...')
        self._env_runner.stop()
        [r.replay_buffer.shutdown() for r in self._wrapped_buffer]
external/yarr/yarr/runners/train_runner.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import abstractmethod, ABC
2
+ from typing import Union, List
3
+
4
+ from yarr.agents.agent import Agent
5
+ from yarr.replay_buffer.wrappers import WrappedReplayBuffer
6
+ from yarr.runners.env_runner import EnvRunner
7
+ from yarr.utils.stat_accumulator import StatAccumulator
8
+
9
+
10
class TrainRunner(ABC):
    """Abstract base class for training loops.

    Holds the collaborators and hyper-parameters shared by every concrete
    runner (e.g. the PyTorch online/offline runners); subclasses implement
    :meth:`start`, which runs the actual loop.
    """

    def __init__(self,
                 agent: Agent,
                 env_runner: EnvRunner,
                 wrapped_replay_buffer: WrappedReplayBuffer,
                 stat_accumulator: Union[StatAccumulator, None] = None,
                 iterations: int = int(1e6),
                 logdir: str = '/tmp/yarr/logs',
                 log_freq: int = 500,
                 transitions_before_train: int = 1000,
                 weightsdir: str = '/tmp/yarr/weights',
                 save_freq: int = 100,
                 ):
        # Core collaborators.
        self._agent = agent
        self._env_runner = env_runner
        self._wrapped_buffer = wrapped_replay_buffer
        self._stat_accumulator = stat_accumulator
        # Loop control.
        self._iterations = iterations
        self._transitions_before_train = transitions_before_train
        # Logging and checkpointing configuration.
        self._logdir = logdir
        self._log_freq = log_freq
        self._weightsdir = weightsdir
        self._save_freq = save_freq

    @abstractmethod
    def start(self):
        """Run the training loop; must be provided by subclasses."""
        pass
external/yarr/yarr/utils/__init__.py ADDED
File without changes
external/yarr/yarr/utils/log_writer.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import logging
3
+ import os
4
+ from collections import OrderedDict
5
+
6
+ import numpy as np
7
+ import torch
8
+ from yarr.agents.agent import ScalarSummary, HistogramSummary, ImageSummary, \
9
+ VideoSummary, TextSummary
10
+ from torch.utils.tensorboard import SummaryWriter
11
+
12
+
13
class LogWriter(object):
    """Logs scalar values and rich summaries to TensorBoard and/or CSV files.

    Scalars are buffered per iteration and flushed by ``end_iteration()``.
    Scalars whose name contains 'env', 'eval' or 'test' are routed to the env
    CSV; everything else goes to the train CSV.
    """

    def __init__(self,
                 logdir: str,
                 tensorboard_logging: bool,
                 csv_logging: bool,
                 train_csv: str = 'train_data.csv',
                 env_csv: str = 'env_data.csv'):
        self._tensorboard_logging = tensorboard_logging
        self._csv_logging = csv_logging
        os.makedirs(logdir, exist_ok=True)
        if tensorboard_logging:
            # Imported lazily so CSV-only usage does not require tensorboard.
            from torch.utils.tensorboard import SummaryWriter
            self._tf_writer = SummaryWriter(logdir)
        if csv_logging:
            # BUG FIX: previously `prev_row_data` and `row_data` were aliased
            # to the *same* OrderedDict; keep them distinct from the start.
            self._train_row_data = OrderedDict()
            self._train_prev_row_data = OrderedDict()
            self._train_csv_file = os.path.join(logdir, train_csv)
            self._env_row_data = OrderedDict()
            self._env_prev_row_data = OrderedDict()
            self._env_csv_file = os.path.join(logdir, env_csv)
            self._train_field_names = None
            self._env_field_names = None

    def add_scalar(self, i, name, value):
        """Buffer scalar `value` for step `i` (and mirror it to TensorBoard)."""
        if self._tensorboard_logging:
            self._tf_writer.add_scalar(name, value, i)
        if self._csv_logging:
            if isinstance(value, torch.Tensor):
                value = value.item()
            # env/eval/test scalars go to the env CSV, the rest to train.
            row = (self._env_row_data
                   if ('env' in name or 'eval' in name or 'test' in name)
                   else self._train_row_data)
            if len(row) == 0:
                # First scalar of the iteration pins the step column.
                row['step'] = i
            row[name] = value

    def add_summaries(self, i, summaries):
        """Dispatch each summary object to the appropriate writer."""
        for summary in summaries:
            try:
                if isinstance(summary, ScalarSummary):
                    self.add_scalar(i, summary.name, summary.value)
                elif self._tensorboard_logging:
                    if isinstance(summary, HistogramSummary):
                        self._tf_writer.add_histogram(
                            summary.name, summary.value, i)
                    elif isinstance(summary, ImageSummary):
                        # Only grab first item in batch
                        v = (summary.value if summary.value.ndim == 3 else
                             summary.value[0])
                        self._tf_writer.add_image(summary.name, v, i)
                    elif isinstance(summary, VideoSummary):
                        # Only grab first item in batch
                        v = (summary.value if summary.value.ndim == 5 else
                             np.array([summary.value]))
                        self._tf_writer.add_video(
                            summary.name, v, i, fps=summary.fps)
                    elif isinstance(summary, TextSummary):
                        self._tf_writer.add_text(summary.name, summary.value, i)
            except Exception as e:
                logging.error('Error on summary: %s' % summary.name)
                raise e

    def _write_row(self, csv_file, row_data, prev_row_data, field_names):
        """Append one buffered row to `csv_file`; return the field names used.

        If logging outpaces incoming summaries, columns missing from this row
        are backfilled from the previous row so the CSV stays rectangular.
        """
        should_write_header = not os.path.exists(csv_file)
        # BUG FIX: guard on `field_names is not None` — previously appending
        # to an existing file with no recorded names raised TypeError via
        # set(None).
        if (not should_write_header and field_names is not None
                and not np.array_equal(field_names, list(row_data.keys()))):
            # Special case when we are logging faster than new summaries are
            # coming in: backfill missing columns from the previous row.
            for missing_key in set(field_names) - set(row_data.keys()):
                row_data[missing_key] = prev_row_data[missing_key]
        # BUG FIX: snapshot the keys into a list; the old code stored a live
        # dict view that silently changed when the row dict was replaced.
        names = list(row_data.keys())
        with open(csv_file, mode='a+') as csv_f:
            writer = csv.DictWriter(csv_f, fieldnames=names)
            if should_write_header and field_names is None:
                writer.writeheader()
            try:
                writer.writerow(row_data)
            except Exception as e:
                print(e)
        return names

    def end_iteration(self):
        """Flush the buffered train and env rows to their CSV files."""
        if not self._csv_logging:
            return
        if len(self._train_row_data) > 0:
            self._train_field_names = self._write_row(
                self._train_csv_file, self._train_row_data,
                self._train_prev_row_data, self._train_field_names)
            self._train_prev_row_data = self._train_row_data
            self._train_row_data = OrderedDict()
        # env rows also cover eval/test scalars recorded during evaluation.
        if len(self._env_row_data) > 0:
            self._env_field_names = self._write_row(
                self._env_csv_file, self._env_row_data,
                self._env_prev_row_data, self._env_field_names)
            self._env_prev_row_data = self._env_row_data
            self._env_row_data = OrderedDict()

    def close(self):
        """Close the TensorBoard writer if one was opened."""
        if self._tensorboard_logging:
            self._tf_writer.close()
external/yarr/yarr/utils/multi_task_rollout_generator.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from multiprocessing import Value
2
+
3
+ import numpy as np
4
+
5
+ from yarr.agents.agent import Agent
6
+ from yarr.envs.env import Env
7
+ from yarr.envs.multi_task_env import MultiTaskEnv
8
+ from yarr.utils.transition import ReplayTransition
9
+
10
+
11
class RolloutGenerator(object):
    """Rolls out an agent in a multi-task env, yielding ReplayTransitions."""

    def _get_type(self, x):
        # Downcast float64 observations to float32 (halves replay memory).
        if x.dtype == np.float64:
            return np.float32
        return x.dtype

    def generator(self, step_signal: Value, env: MultiTaskEnv, agent: Agent,
                  episode_length: int, timesteps: int, eval: bool):
        """Run one episode, yielding a ReplayTransition per env step.

        Args:
            step_signal: shared counter of global training steps.
            env: multi-task environment to act in.
            agent: policy queried each step.
            episode_length: maximum number of steps before timing out.
            timesteps: number of stacked past observations fed to the agent.
            eval: if True, act deterministically.
        """
        obs = env.reset()
        agent.reset()
        # Rolling window of the last `timesteps` observations per key, seeded
        # with `timesteps` copies of the initial observation.
        obs_history = {k: [np.array(v, dtype=self._get_type(v))] * timesteps for k, v in obs.items()}
        for step in range(episode_length):

            # Add a leading batch dimension of 1 for the agent.
            prepped_data = {k: np.array([v]) for k, v in obs_history.items()}

            act_result = agent.act(step_signal.value, prepped_data,
                                   deterministic=eval)

            # Convert to np if not already
            agent_obs_elems = {k: np.array(v) for k, v in
                               act_result.observation_elements.items()}
            agent_extra_elems = {k: np.array(v) for k, v in
                                 act_result.replay_elements.items()}

            transition = env.step(act_result)
            timeout = False
            if step == episode_length - 1:
                # If last transition, and not terminal, then we timed out
                timeout = not transition.terminal
                if timeout:
                    transition.terminal = True
                    if "needs_reset" in transition.info:
                        transition.info["needs_reset"] = True

            # Fold agent-produced observation elements into the stored obs.
            obs.update(agent_obs_elems)
            obs_tp1 = dict(transition.observation)

            # Slide the observation-history window one step forward.
            for k in obs_history.keys():
                obs_history[k].append(transition.observation[k])
                obs_history[k].pop(0)

            transition.info["active_task_id"] = env.active_task_id

            # NOTE(review): passed positionally, obs_tp1 binds to
            # `final_observation` and agent_extra_elems binds to `summaries`
            # in ReplayTransition — confirm that `summaries` is the intended
            # slot for the extra replay elements.
            replay_transition = ReplayTransition(
                obs, act_result.action, transition.reward,
                transition.terminal,
                timeout, obs_tp1, agent_extra_elems,
                transition.info)

            obs = transition.observation
            yield replay_transition

            # Stop once the env reports it needs a reset (defaults to the
            # terminal flag when the key is absent).
            if transition.info.get("needs_reset", transition.terminal):
                return
external/yarr/yarr/utils/observation_type.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Type
2
+ import numpy as np
3
+
4
+
5
class ObservationElement(object):
    """Specification of one observation entry: its name, shape and dtype."""

    def __init__(self, name: str, shape: tuple, type: Type[np.dtype]):
        # NOTE: `type` shadows the builtin, but the parameter name is part of
        # the public interface and is kept for caller compatibility.
        self.type = type
        self.shape = shape
        self.name = name
external/yarr/yarr/utils/process_str.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from functools import reduce
2
+
3
+
4
def change_case(str):
    """Convert a CamelCase name to snake_case.

    'CamelCase' -> 'camel_case', 'ABC' -> 'a_b_c'. Unlike the previous
    reduce-based version, an empty string returns '' instead of raising
    TypeError.

    Note: the parameter name `str` shadows the builtin; it is kept because it
    is part of the public signature.
    """
    parts = []
    for ch in str:
        # Underscore before every uppercase letter except a leading one.
        if ch.isupper() and parts:
            parts.append('_')
        parts.append(ch)
    return ''.join(parts).lower()
external/yarr/yarr/utils/rollout_generator.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from multiprocessing import Value
2
+
3
+ import numpy as np
4
+ import torch
5
+ from yarr.agents.agent import Agent
6
+ from yarr.envs.env import Env
7
+ from yarr.utils.transition import ReplayTransition
8
+
9
+
10
class RolloutGenerator(object):
    """Rolls out an agent in an environment, yielding one ReplayTransition
    per environment step.

    NOTE(review): ``self._env_device`` is never set in this class; the env
    runner is expected to attach it to the instance before ``generator`` is
    called — confirm against the caller.
    """

    def _get_type(self, x):
        # Store float64 observations as float32 to halve replay memory.
        if x.dtype == np.float64:
            return np.float32
        return x.dtype

    def generator(self, step_signal: Value, env: Env, agent: Agent,
                  episode_length: int, timesteps: int,
                  eval: bool, eval_demo_seed: int = 0,
                  record_enabled: bool = False):
        """Run a single episode.

        Args:
            step_signal: shared counter of global training steps.
            env: environment to act in.
            agent: policy queried each step.
            episode_length: maximum number of steps before timing out.
            timesteps: number of stacked past observations fed to the agent.
            eval: if True, reset to the demo state for `eval_demo_seed` and
                act deterministically.
            eval_demo_seed: demo variation seed used when `eval` is True.
            record_enabled: if True, finalize the recording at episode end.

        Yields:
            A ReplayTransition for every environment step taken.
        """
        if eval:
            obs = env.reset_to_demo(eval_demo_seed)
        else:
            obs = env.reset()

        agent.reset()
        # Rolling window of the last `timesteps` observations per key, seeded
        # with copies of the initial observation.
        obs_history = {k: [np.array(v, dtype=self._get_type(v))] * timesteps
                       for k, v in obs.items()}
        for step in range(episode_length):

            # Batch dim of 1 in front; move to the env device for the agent.
            prepped_data = {k: torch.tensor(np.array(v)[None],
                                            device=self._env_device)
                            for k, v in obs_history.items()}

            act_result = agent.act(step_signal.value, prepped_data,
                                   deterministic=eval)

            # Convert to np if not already
            agent_obs_elems = {k: np.array(v) for k, v in
                               act_result.observation_elements.items()}
            extra_replay_elements = {k: np.array(v) for k, v in
                                     act_result.replay_elements.items()}

            transition = env.step(act_result)
            obs_tp1 = dict(transition.observation)
            timeout = False
            if step == episode_length - 1:
                # If last transition, and not terminal, then we timed out
                timeout = not transition.terminal
                if timeout:
                    transition.terminal = True
                    if "needs_reset" in transition.info:
                        transition.info["needs_reset"] = True

            # Replay stores raw obs plus agent-produced elements.
            obs_and_replay_elems = {}
            obs_and_replay_elems.update(obs)
            obs_and_replay_elems.update(agent_obs_elems)
            obs_and_replay_elems.update(extra_replay_elements)

            # Slide the observation-history window one step forward.
            for k in obs_history.keys():
                obs_history[k].append(transition.observation[k])
                obs_history[k].pop(0)

            transition.info["active_task_id"] = env.active_task_id

            replay_transition = ReplayTransition(
                obs_and_replay_elems, act_result.action, transition.reward,
                transition.terminal, timeout, summaries=transition.summaries,
                info=transition.info)

            if transition.terminal or timeout:
                # If the agent gives us observations then we need to call act
                # one last time (i.e. acting in the terminal state).
                if len(act_result.observation_elements) > 0:
                    # BUG FIX: build the batch the same way as in the main
                    # loop; the old torch.tensor([v], ...) path built the
                    # tensor from a Python list of arrays.
                    prepped_data = {k: torch.tensor(np.array(v)[None],
                                                    device=self._env_device)
                                    for k, v in obs_history.items()}
                    act_result = agent.act(step_signal.value, prepped_data,
                                           deterministic=eval)
                    agent_obs_elems_tp1 = {
                        k: np.array(v) for k, v in
                        act_result.observation_elements.items()}
                    obs_tp1.update(agent_obs_elems_tp1)
                replay_transition.final_observation = obs_tp1

            # BUG FIX: parenthesized the condition — previously
            # `record_enabled and terminal or timeout or last-step` invoked
            # record_end even when recording was disabled.
            if record_enabled and (transition.terminal or timeout or
                                   step == episode_length - 1):
                env.env._action_mode.arm_action_mode.record_end(
                    env.env._scene, steps=60, step_scene=True)

            obs = dict(transition.observation)
            yield replay_transition

            # Stop once the env reports it needs a reset (defaults to the
            # terminal flag when the key is absent).
            if transition.info.get("needs_reset", transition.terminal):
                return
external/yarr/yarr/utils/stat_accumulator.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from multiprocessing import Lock
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ from yarr.agents.agent import Summary, ScalarSummary
6
+ from yarr.utils.transition import ReplayTransition
7
+
8
+
9
class StatAccumulator(object):
    """Interface for accumulating episode statistics from replay transitions."""

    def step(self, transition: ReplayTransition, eval: bool):
        """Record one transition (`eval` selects eval vs train statistics)."""
        pass

    def pop(self) -> List[Summary]:
        """Return the accumulated summaries and clear the internal state."""
        pass

    def peak(self) -> List[Summary]:
        """Return the accumulated summaries without clearing them."""
        pass

    def reset(self) -> None:
        """Discard all accumulated statistics."""
        pass
22
+
23
+
24
class Metric(object):
    """Running tally of one scalar statistic, tracked per episode."""

    def __init__(self):
        self._history = []   # totals of finished episodes
        self._running = 0    # total accumulated for the episode in progress

    def update(self, value):
        """Add `value` to the episode currently being accumulated."""
        self._running += value

    def next(self):
        """Close the current episode: archive its total and start fresh."""
        self._history.append(self._running)
        self._running = 0

    def reset(self):
        """Drop all archived episode totals (the running total is kept)."""
        del self._history[:]

    def min(self):
        return np.min(self._history)

    def max(self):
        return np.max(self._history)

    def mean(self):
        return np.mean(self._history)

    def median(self):
        return np.median(self._history)

    def std(self):
        return np.std(self._history)

    def __len__(self):
        return len(self._history)

    def __getitem__(self, i):
        return self._history[i]
60
+
61
+
62
class _SimpleAccumulator(StatAccumulator):
    """Accumulates episode returns/lengths and summaries under one prefix."""

    def __init__(self, prefix, eval_video_fps: int = 30,
                 mean_only: bool = True):
        # Prefix prepended to every emitted summary name.
        self._prefix = prefix
        # NOTE(review): stored but not used anywhere in this class — confirm
        # whether it is consumed elsewhere or is dead configuration.
        self._eval_video_fps = eval_video_fps
        # If True, only the mean of each metric is reported.
        self._mean_only = mean_only
        # multiprocessing Lock guarding the mutable state below — presumably
        # step()/pop() can be called concurrently; confirm against callers.
        self._lock = Lock()
        self._episode_returns = Metric()
        self._episode_lengths = Metric()
        self._summaries = []
        self._transitions = 0

    def _reset_data(self):
        # Clears per-episode metrics and buffered summaries (but not the
        # total transition count).
        with self._lock:
            self._episode_returns.reset()
            self._episode_lengths.reset()
            self._summaries.clear()

    def step(self, transition: ReplayTransition, eval: bool):
        """Fold one transition into the running episode statistics."""
        with self._lock:
            self._transitions += 1
            self._episode_returns.update(transition.reward)
            self._episode_lengths.update(1)
            if transition.terminal:
                # Episode finished: archive its return and length.
                self._episode_returns.next()
                self._episode_lengths.next()
            self._summaries.extend(list(transition.summaries))

    def _get(self) -> List[Summary]:
        # Build ScalarSummaries for return/length stats plus any buffered
        # env summaries and the total transition count.
        sums = []

        if self._mean_only:
            stat_keys = ["mean"]
        else:
            stat_keys = ["min", "max", "mean", "median", "std"]
        names = ["return", "length"]
        metrics = [self._episode_returns, self._episode_lengths]
        for name, metric in zip(names, metrics):
            for stat_key in stat_keys:
                if self._mean_only:
                    assert stat_key == "mean"
                    sum_name = '%s/%s' % (self._prefix, name)
                else:
                    sum_name = '%s/%s/%s' % (self._prefix, name, stat_key)
                sums.append(
                    ScalarSummary(sum_name, getattr(metric, stat_key)()))
        sums.append(ScalarSummary(
            '%s/total_transitions' % self._prefix, self._transitions))
        sums.extend(self._summaries)
        return sums

    def pop(self) -> List[Summary]:
        """Return accumulated summaries and clear them.

        NOTE(review): data is only returned once *more than one* episode has
        finished (`> 1`); confirm that `> 0` was not intended.
        """
        data = []
        if len(self._episode_returns) > 1:
            data = self._get()
            self._reset_data()
        return data

    def peak(self) -> List[Summary]:
        """Return accumulated summaries without clearing them."""
        return self._get()

    def reset(self):
        """Discard all statistics, including the transition count."""
        self._transitions = 0
        self._reset_data()
127
+
128
+
129
class SimpleAccumulator(StatAccumulator):
    """Routes transitions to a train or eval accumulator and merges results."""

    def __init__(self, eval_video_fps: int = 30, mean_only: bool = True):
        self._train_acc = _SimpleAccumulator(
            'train_envs', eval_video_fps, mean_only=mean_only)
        self._eval_acc = _SimpleAccumulator(
            'eval_envs', eval_video_fps, mean_only=mean_only)

    def step(self, transition: ReplayTransition, eval: bool):
        """Forward the transition to the eval or train accumulator."""
        target = self._eval_acc if eval else self._train_acc
        target.step(transition, eval)

    def pop(self) -> List[Summary]:
        return self._train_acc.pop() + self._eval_acc.pop()

    def peak(self) -> List[Summary]:
        return self._train_acc.peak() + self._eval_acc.peak()

    def reset(self) -> None:
        for acc in (self._train_acc, self._eval_acc):
            acc.reset()
152
+
153
+
154
class MultiTaskAccumulator(StatAccumulator):
    """Per-task train/eval accumulators plus an aggregate over train tasks."""

    def __init__(self, num_tasks,
                 eval_video_fps: int = 30, mean_only: bool = True,
                 train_prefix: str = 'train_task',
                 eval_prefix: str = 'eval_task'):
        self._train_accs = [
            _SimpleAccumulator('%s%d/envs' % (train_prefix, task_idx),
                               eval_video_fps, mean_only=mean_only)
            for task_idx in range(num_tasks)]
        self._eval_accs = [
            _SimpleAccumulator('%s%d/envs' % (eval_prefix, task_idx),
                               eval_video_fps, mean_only=mean_only)
            for task_idx in range(num_tasks)]
        # Aggregate accumulator fed every training transition.
        self._train_accs_mean = _SimpleAccumulator(
            '%s_summary/envs' % train_prefix, eval_video_fps,
            mean_only=mean_only)

    def step(self, transition: ReplayTransition, eval: bool):
        """Route the transition to its task's accumulator (by active_task_id)."""
        task_id = transition.info["active_task_id"]
        if eval:
            self._eval_accs[task_id].step(transition, eval)
        else:
            self._train_accs[task_id].step(transition, eval)
            self._train_accs_mean.step(transition, eval)

    def pop(self) -> List[Summary]:
        summaries = self._train_accs_mean.pop()
        for acc in self._train_accs:
            summaries.extend(acc.pop())
        for acc in self._eval_accs:
            summaries.extend(acc.pop())
        return summaries

    def peak(self) -> List[Summary]:
        summaries = self._train_accs_mean.peak()
        for acc in self._train_accs:
            summaries.extend(acc.peak())
        for acc in self._eval_accs:
            summaries.extend(acc.peak())
        return summaries

    def reset(self) -> None:
        self._train_accs_mean.reset()
        for acc in self._train_accs:
            acc.reset()
        for acc in self._eval_accs:
            acc.reset()
external/yarr/yarr/utils/transition.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import numpy as np
4
+ from yarr.agents.agent import Summary
5
+
6
+
7
class Transition(object):
    """One environment step: observation, reward and terminal flag, plus any
    auxiliary info and summaries the env produced."""

    def __init__(self, observation: dict, reward: float, terminal: bool,
                 info: dict = None, summaries: List[Summary] = None):
        self.observation = observation
        self.reward = reward
        self.terminal = terminal
        # Falsy info/summaries default to fresh empty containers.
        self.info = info or {}
        self.summaries = summaries or []
16
+
17
+
18
class ReplayTransition(object):
    """A transition formatted for insertion into the replay buffer."""

    def __init__(self, observation: dict, action: np.ndarray,
                 reward: float, terminal: bool, timeout: bool,
                 final_observation: dict = None,
                 summaries: List[Summary] = None,
                 info: dict = None):
        self.observation = observation
        self.action = action
        self.reward = reward
        self.terminal = terminal
        # True when the episode ended by hitting the step limit.
        self.timeout = timeout
        # final only populated on last timestep
        self.final_observation = final_observation
        self.summaries = summaries or []
        self.info = info
external/yarr/yarr/utils/video_utils.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ from pyrep.objects.dummy import Dummy
4
+ from pyrep.objects.vision_sensor import VisionSensor
5
+ from rlbench import Environment
6
+ from rlbench.backend.observation import Observation
7
+
8
+
9
class CameraMotion(object):
    """Base class for scripted camera trajectories used while recording."""

    def __init__(self, cam: VisionSensor):
        self.cam = cam

    def step(self):
        """Advance the camera by one tick; subclasses define the motion."""
        raise NotImplementedError()

    def save_pose(self):
        # Remember the current pose so it can be restored after recording.
        self._prev_pose = self.cam.get_pose()

    def restore_pose(self):
        # Return the camera to the pose captured by save_pose().
        self.cam.set_pose(self._prev_pose)
+
22
+
23
class CircleCameraMotion(CameraMotion):
    """Orbits the camera around a pivot dummy at a fixed angular speed."""

    def __init__(self, cam: VisionSensor, origin: Dummy,
                 speed: float, init_rotation: float = np.deg2rad(180)):
        super().__init__(cam)
        self.origin = origin
        self.speed = speed  # angular increment per step, in radians
        # Start the orbit from the requested initial rotation about z.
        self.origin.rotate([0, 0, init_rotation])

    def step(self):
        """Advance the orbit by one increment about the z axis."""
        self.origin.rotate([0, 0, self.speed])
+
35
+
36
class TaskRecorder(object):
    """Captures camera frames during a rollout and saves them as an mp4."""

    def __init__(self, env: Environment, cam_motion: CameraMotion, fps=30):
        self._env = env
        self._cam_motion = cam_motion
        self._fps = fps
        # NOTE(review): _snaps is never used in this class; only
        # _current_snaps is written and cleared — possibly dead state.
        self._snaps = []
        self._current_snaps = []

    def take_snap(self, obs: Observation):
        """Advance the camera and capture one RGB frame as uint8.

        `obs` is unused here; presumably the signature matches a callback
        interface — confirm against the caller.
        """
        self._cam_motion.step()
        self._current_snaps.append(
            (self._cam_motion.cam.capture_rgb() * 255.).astype(np.uint8))

    def save(self, path, lang_goal, reward):
        """Write the buffered frames to `path` as mp4, overlaying `lang_goal`.

        NOTE(review): `reward` is currently unused — confirm whether it was
        meant to be rendered onto the video as well.
        """
        print(f"Converting to video ... {path}")
        os.makedirs(os.path.dirname(path), exist_ok=True)
        # OpenCV QT version can conflict with PyRep, so import here
        import cv2
        image_size = self._cam_motion.cam.get_resolution()
        video = cv2.VideoWriter(
            path, cv2.VideoWriter_fourcc('m', 'p', '4', 'v'), self._fps,
            tuple(image_size))

        for image in self._current_snaps:
            # Frames were captured as RGB; OpenCV expects BGR.
            frame = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

            font = cv2.FONT_HERSHEY_DUPLEX
            # Scale the font relative to a 640px-wide reference frame.
            font_scale = (0.45 * image_size[0]) / 640
            font_thickness = 2

            if lang_goal:
                # Horizontally center the language goal near the bottom edge.
                lang_textsize = cv2.getTextSize(lang_goal, font, font_scale, font_thickness)[0]
                lang_textX = (image_size[0] - lang_textsize[0]) // 2

                frame = cv2.putText(frame, lang_goal, org=(lang_textX, image_size[1] - 35),
                                    fontScale=font_scale, fontFace=font, color=(0, 0, 0),
                                    thickness=font_thickness, lineType=cv2.LINE_AA)

            video.write(frame)
        video.release()
        # Drop the buffered frames once they are on disk.
        self._current_snaps = []