Add files using upload-large-folder tool
Browse files- MANIFEST.txt +0 -0
- README.md +51 -1
- external/README.md +26 -0
- external/peract_bimanual/LICENSE +402 -0
- external/peract_bimanual/README.md +300 -0
- external/peract_bimanual/helpers/clip/__init__.py +0 -0
- external/peract_bimanual/model-card.md +47 -0
- external/peract_bimanual/peract_config.py +32 -0
- external/peract_bimanual/pyproject.toml +35 -0
- external/peract_bimanual/run_seed_fn.py +218 -0
- external/peract_bimanual/train.py +116 -0
- external/peract_bimanual/voxel/__init__.py +0 -0
- external/peract_bimanual/voxel/voxel_grid.py +252 -0
- external/yarr/.gitignore +13 -0
- external/yarr/LICENSE +201 -0
- external/yarr/README.md +28 -0
- external/yarr/logo.png +0 -0
- external/yarr/requirements.txt +11 -0
- external/yarr/setup.py +37 -0
- external/yarr/yarr/__init__.py +1 -0
- external/yarr/yarr/agents/__init__.py +0 -0
- external/yarr/yarr/agents/agent.py +345 -0
- external/yarr/yarr/envs/__init__.py +0 -0
- external/yarr/yarr/envs/env.py +64 -0
- external/yarr/yarr/envs/rlbench_env.py +332 -0
- external/yarr/yarr/replay_buffer/__init__.py +0 -0
- external/yarr/yarr/replay_buffer/prioritized_replay_buffer.py +217 -0
- external/yarr/yarr/replay_buffer/replay_buffer.py +71 -0
- external/yarr/yarr/replay_buffer/sum_tree.py +201 -0
- external/yarr/yarr/replay_buffer/task_uniform_replay_buffer.py +182 -0
- external/yarr/yarr/replay_buffer/uniform_replay_buffer.py +804 -0
- external/yarr/yarr/replay_buffer/wrappers/__init__.py +24 -0
- external/yarr/yarr/replay_buffer/wrappers/pytorch_replay_buffer.py +82 -0
- external/yarr/yarr/runners/__init__.py +0 -0
- external/yarr/yarr/runners/_env_runner.py +228 -0
- external/yarr/yarr/runners/_independent_env_runner.py +297 -0
- external/yarr/yarr/runners/env_runner.py +224 -0
- external/yarr/yarr/runners/independent_env_runner.py +130 -0
- external/yarr/yarr/runners/offline_train_runner.py +163 -0
- external/yarr/yarr/runners/pytorch_train_runner.py +308 -0
- external/yarr/yarr/runners/train_runner.py +37 -0
- external/yarr/yarr/utils/__init__.py +0 -0
- external/yarr/yarr/utils/log_writer.py +128 -0
- external/yarr/yarr/utils/multi_task_rollout_generator.py +65 -0
- external/yarr/yarr/utils/observation_type.py +10 -0
- external/yarr/yarr/utils/process_str.py +5 -0
- external/yarr/yarr/utils/rollout_generator.py +89 -0
- external/yarr/yarr/utils/stat_accumulator.py +192 -0
- external/yarr/yarr/utils/transition.py +33 -0
- external/yarr/yarr/utils/video_utils.py +80 -0
MANIFEST.txt
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
README.md
CHANGED
|
@@ -14,6 +14,12 @@ This pass is a label study, not a policy study. No `pi0.5` integration is includ
|
|
| 14 |
|
| 15 |
## What is in this upload
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
- `code/rr_label_study/`
|
| 18 |
- Core study code, including dense replay, visibility metrics, pregrasp/extraction oracles, keyframe extraction, intervention checks, and summary metric computation.
|
| 19 |
- `code/scripts/`
|
|
@@ -30,6 +36,44 @@ This pass is a label study, not a policy study. No `pi0.5` integration is includ
|
|
| 30 |
- `MANIFEST.txt`
|
| 31 |
- Flat file listing of the uploaded bundle contents.
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
## Final validated artifact
|
| 34 |
|
| 35 |
The clean single-episode artifact is:
|
|
@@ -122,11 +166,13 @@ The local run used:
|
|
| 122 |
- `markusgrotz/peract_bimanual` at `bb0232a6ba3fe116566e9568f0c7af980ed6703d`
|
| 123 |
- `markusgrotz/YARR` at `6822ff78602c77878b27d4cfe759ce029c67bffb`
|
| 124 |
|
|
|
|
|
|
|
| 125 |
## Reproducing on the same hardware class
|
| 126 |
|
| 127 |
1. Read `environment/dataset_notes.txt`.
|
| 128 |
2. Run `environment/setup_same_hardware.sh /workspace`.
|
| 129 |
-
3. Source `environment/activate_rlbench_runtime.sh /workspace`.
|
| 130 |
4. Run the dense study:
|
| 131 |
|
| 132 |
```bash
|
|
@@ -160,3 +206,7 @@ On a single deterministic episode, normalized time can become a degenerate perfe
|
|
| 160 |
## Dataset note
|
| 161 |
|
| 162 |
The upstream RLBench demonstration dataset itself is not re-uploaded in this bundle. This repo contains the study code and all artifacts generated from the local run. The expected dataset path is documented in `environment/dataset_notes.txt`.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
## What is in this upload
|
| 16 |
|
| 17 |
+
- `external/`
|
| 18 |
+
- Full local benchmark snapshots copied from the RunPod workspace.
|
| 19 |
+
- `external/rlbench/`: local RLBench tree used for this run.
|
| 20 |
+
- `external/pyrep/`: local PyRep tree used for this run.
|
| 21 |
+
- `external/peract_bimanual/`: local PerAct bimanual tree used for context.
|
| 22 |
+
- `external/yarr/`: local YARR tree used for context.
|
| 23 |
- `code/rr_label_study/`
|
| 24 |
- Core study code, including dense replay, visibility metrics, pregrasp/extraction oracles, keyframe extraction, intervention checks, and summary metric computation.
|
| 25 |
- `code/scripts/`
|
|
|
|
| 36 |
- `MANIFEST.txt`
|
| 37 |
- Flat file listing of the uploaded bundle contents.
|
| 38 |
|
| 39 |
+
## Repository map
|
| 40 |
+
|
| 41 |
+
Relevant entry points and where to look:
|
| 42 |
+
|
| 43 |
+
- Benchmark snapshots
|
| 44 |
+
- `external/README.md`
|
| 45 |
+
- `external/rlbench/README.md`
|
| 46 |
+
- `external/rlbench/rlbench/bimanual_tasks/`
|
| 47 |
+
- `external/rlbench/rlbench/action_modes/`
|
| 48 |
+
- `external/pyrep/README.md`
|
| 49 |
+
- `external/pyrep/pyrep/`
|
| 50 |
+
- `external/peract_bimanual/`
|
| 51 |
+
- `external/yarr/`
|
| 52 |
+
- Study code
|
| 53 |
+
- `code/rr_label_study/oven_study.py`
|
| 54 |
+
- `code/scripts/run_oven_label_study.py`
|
| 55 |
+
- `code/scripts/launch_parallel_oven_label_study.py`
|
| 56 |
+
- `code/scripts/run_oven_single_frame.py`
|
| 57 |
+
- `code/scripts/repair_oven_episode_dense.py`
|
| 58 |
+
- Final clean artifact
|
| 59 |
+
- `artifacts/results/oven_episode0_repaired_v1/episode0.dense.csv`
|
| 60 |
+
- `artifacts/results/oven_episode0_repaired_v1/episode0.keyframes.csv`
|
| 61 |
+
- `artifacts/results/oven_episode0_repaired_v1/episode0.metrics.json`
|
| 62 |
+
- `artifacts/results/oven_episode0_repaired_v1/summary.json`
|
| 63 |
+
- Intermediate/debug artifacts
|
| 64 |
+
- `artifacts/results/oven_episode0_full*/`
|
| 65 |
+
- `artifacts/results/oven_to240_*/`
|
| 66 |
+
- `artifacts/results/oven_episode0_independent_v1/`
|
| 67 |
+
- `artifacts/results/parallel_smoke_2x10/`
|
| 68 |
+
- Environment/repro
|
| 69 |
+
- `environment/system_info.txt`
|
| 70 |
+
- `environment/repo_revisions.txt`
|
| 71 |
+
- `environment/conda_env_rlbench.yml`
|
| 72 |
+
- `environment/pip_freeze_rlbench.txt`
|
| 73 |
+
- `environment/setup_same_hardware.sh`
|
| 74 |
+
- `environment/activate_rlbench_runtime.sh`
|
| 75 |
+
- `environment/dataset_notes.txt`
|
| 76 |
+
|
| 77 |
## Final validated artifact
|
| 78 |
|
| 79 |
The clean single-episode artifact is:
|
|
|
|
| 166 |
- `markusgrotz/peract_bimanual` at `bb0232a6ba3fe116566e9568f0c7af980ed6703d`
|
| 167 |
- `markusgrotz/YARR` at `6822ff78602c77878b27d4cfe759ce029c67bffb`
|
| 168 |
|
| 169 |
+
Those exact local source snapshots are also included under `external/`.
|
| 170 |
+
|
| 171 |
## Reproducing on the same hardware class
|
| 172 |
|
| 173 |
1. Read `environment/dataset_notes.txt`.
|
| 174 |
2. Run `environment/setup_same_hardware.sh /workspace`.
|
| 175 |
+
3. Source `environment/activate_rlbench_runtime.sh /workspace`.
|
| 176 |
4. Run the dense study:
|
| 177 |
|
| 178 |
```bash
|
|
|
|
| 206 |
## Dataset note
|
| 207 |
|
| 208 |
The upstream RLBench demonstration dataset itself is not re-uploaded in this bundle. This repo contains the study code and all artifacts generated from the local run. The expected dataset path is documented in `environment/dataset_notes.txt`.
|
| 209 |
+
|
| 210 |
+
The cloned benchmark code is included directly in this upload under `external/`.
|
| 211 |
+
|
| 212 |
+
CoppeliaSim binaries are not included in this repo. The setup helpers expect a local extraction at `/workspace/coppelia_sim`.
|
external/README.md
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# External Benchmark Snapshots
|
| 2 |
+
|
| 3 |
+
This directory contains the local benchmark/source trees copied from the RunPod workspace used for the study.
|
| 4 |
+
|
| 5 |
+
Included trees:
|
| 6 |
+
|
| 7 |
+
- `rlbench/`
|
| 8 |
+
- Source snapshot of `/workspace/rlbench`
|
| 9 |
+
- Upstream: `https://github.com/markusgrotz/RLBench.git`
|
| 10 |
+
- Commit: `8af748c51287989294e00c9c670e3330a0e35ed5`
|
| 11 |
+
- `pyrep/`
|
| 12 |
+
- Source snapshot of `/workspace/pyrep`
|
| 13 |
+
- Upstream: `https://github.com/markusgrotz/PyRep.git`
|
| 14 |
+
- Commit: `b8bd1d7a3182adcd570d001649c0849047ebf197`
|
| 15 |
+
- `peract_bimanual/`
|
| 16 |
+
- Source snapshot of `/workspace/peract_bimanual`
|
| 17 |
+
- Upstream: `https://github.com/markusgrotz/peract_bimanual.git`
|
| 18 |
+
- Commit: `bb0232a6ba3fe116566e9568f0c7af980ed6703d`
|
| 19 |
+
- `yarr/`
|
| 20 |
+
- Source snapshot of `/workspace/yarr`
|
| 21 |
+
- Upstream: `https://github.com/markusgrotz/YARR.git`
|
| 22 |
+
- Commit: `6822ff78602c77878b27d4cfe759ce029c67bffb`
|
| 23 |
+
|
| 24 |
+
These are source snapshots, not git clones with `.git/` metadata.
|
| 25 |
+
|
| 26 |
+
See `../environment/repo_revisions.txt` for the recorded origin URLs and revisions.
|
external/peract_bimanual/LICENSE
ADDED
|
@@ -0,0 +1,402 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
| 202 |
+
Apache License
|
| 203 |
+
Version 2.0, January 2004
|
| 204 |
+
http://www.apache.org/licenses/
|
| 205 |
+
|
| 206 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 207 |
+
|
| 208 |
+
1. Definitions.
|
| 209 |
+
|
| 210 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 211 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 212 |
+
|
| 213 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 214 |
+
the copyright owner that is granting the License.
|
| 215 |
+
|
| 216 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 217 |
+
other entities that control, are controlled by, or are under common
|
| 218 |
+
control with that entity. For the purposes of this definition,
|
| 219 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 220 |
+
direction or management of such entity, whether by contract or
|
| 221 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 222 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 223 |
+
|
| 224 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 225 |
+
exercising permissions granted by this License.
|
| 226 |
+
|
| 227 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 228 |
+
including but not limited to software source code, documentation
|
| 229 |
+
source, and configuration files.
|
| 230 |
+
|
| 231 |
+
"Object" form shall mean any form resulting from mechanical
|
| 232 |
+
transformation or translation of a Source form, including but
|
| 233 |
+
not limited to compiled object code, generated documentation,
|
| 234 |
+
and conversions to other media types.
|
| 235 |
+
|
| 236 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 237 |
+
Object form, made available under the License, as indicated by a
|
| 238 |
+
copyright notice that is included in or attached to the work
|
| 239 |
+
(an example is provided in the Appendix below).
|
| 240 |
+
|
| 241 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 242 |
+
form, that is based on (or derived from) the Work and for which the
|
| 243 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 244 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 245 |
+
of this License, Derivative Works shall not include works that remain
|
| 246 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 247 |
+
the Work and Derivative Works thereof.
|
| 248 |
+
|
| 249 |
+
"Contribution" shall mean any work of authorship, including
|
| 250 |
+
the original version of the Work and any modifications or additions
|
| 251 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 252 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 253 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 254 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 255 |
+
means any form of electronic, verbal, or written communication sent
|
| 256 |
+
to the Licensor or its representatives, including but not limited to
|
| 257 |
+
communication on electronic mailing lists, source code control systems,
|
| 258 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 259 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 260 |
+
excluding communication that is conspicuously marked or otherwise
|
| 261 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 262 |
+
|
| 263 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 264 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 265 |
+
subsequently incorporated within the Work.
|
| 266 |
+
|
| 267 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 268 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 269 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 270 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 271 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 272 |
+
Work and such Derivative Works in Source or Object form.
|
| 273 |
+
|
| 274 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 275 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 276 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 277 |
+
(except as stated in this section) patent license to make, have made,
|
| 278 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 279 |
+
where such license applies only to those patent claims licensable
|
| 280 |
+
by such Contributor that are necessarily infringed by their
|
| 281 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 282 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 283 |
+
institute patent litigation against any entity (including a
|
| 284 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 285 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 286 |
+
or contributory patent infringement, then any patent licenses
|
| 287 |
+
granted to You under this License for that Work shall terminate
|
| 288 |
+
as of the date such litigation is filed.
|
| 289 |
+
|
| 290 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 291 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 292 |
+
modifications, and in Source or Object form, provided that You
|
| 293 |
+
meet the following conditions:
|
| 294 |
+
|
| 295 |
+
(a) You must give any other recipients of the Work or
|
| 296 |
+
Derivative Works a copy of this License; and
|
| 297 |
+
|
| 298 |
+
(b) You must cause any modified files to carry prominent notices
|
| 299 |
+
stating that You changed the files; and
|
| 300 |
+
|
| 301 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 302 |
+
that You distribute, all copyright, patent, trademark, and
|
| 303 |
+
attribution notices from the Source form of the Work,
|
| 304 |
+
excluding those notices that do not pertain to any part of
|
| 305 |
+
the Derivative Works; and
|
| 306 |
+
|
| 307 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 308 |
+
distribution, then any Derivative Works that You distribute must
|
| 309 |
+
include a readable copy of the attribution notices contained
|
| 310 |
+
within such NOTICE file, excluding those notices that do not
|
| 311 |
+
pertain to any part of the Derivative Works, in at least one
|
| 312 |
+
of the following places: within a NOTICE text file distributed
|
| 313 |
+
as part of the Derivative Works; within the Source form or
|
| 314 |
+
documentation, if provided along with the Derivative Works; or,
|
| 315 |
+
within a display generated by the Derivative Works, if and
|
| 316 |
+
wherever such third-party notices normally appear. The contents
|
| 317 |
+
of the NOTICE file are for informational purposes only and
|
| 318 |
+
do not modify the License. You may add Your own attribution
|
| 319 |
+
notices within Derivative Works that You distribute, alongside
|
| 320 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 321 |
+
that such additional attribution notices cannot be construed
|
| 322 |
+
as modifying the License.
|
| 323 |
+
|
| 324 |
+
You may add Your own copyright statement to Your modifications and
|
| 325 |
+
may provide additional or different license terms and conditions
|
| 326 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 327 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 328 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 329 |
+
the conditions stated in this License.
|
| 330 |
+
|
| 331 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 332 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 333 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 334 |
+
this License, without any additional terms or conditions.
|
| 335 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 336 |
+
the terms of any separate license agreement you may have executed
|
| 337 |
+
with Licensor regarding such Contributions.
|
| 338 |
+
|
| 339 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 340 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 341 |
+
except as required for reasonable and customary use in describing the
|
| 342 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 343 |
+
|
| 344 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 345 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 346 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 347 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 348 |
+
implied, including, without limitation, any warranties or conditions
|
| 349 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 350 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 351 |
+
appropriateness of using or redistributing the Work and assume any
|
| 352 |
+
risks associated with Your exercise of permissions under this License.
|
| 353 |
+
|
| 354 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 355 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 356 |
+
unless required by applicable law (such as deliberate and grossly
|
| 357 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 358 |
+
liable to You for damages, including any direct, indirect, special,
|
| 359 |
+
incidental, or consequential damages of any character arising as a
|
| 360 |
+
result of this License or out of the use or inability to use the
|
| 361 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 362 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 363 |
+
other commercial damages or losses), even if such Contributor
|
| 364 |
+
has been advised of the possibility of such damages.
|
| 365 |
+
|
| 366 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 367 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 368 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 369 |
+
or other liability obligations and/or rights consistent with this
|
| 370 |
+
License. However, in accepting such obligations, You may act only
|
| 371 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 372 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 373 |
+
defend, and hold each Contributor harmless for any liability
|
| 374 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 375 |
+
of your accepting any such warranty or additional liability.
|
| 376 |
+
|
| 377 |
+
END OF TERMS AND CONDITIONS
|
| 378 |
+
|
| 379 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 380 |
+
|
| 381 |
+
To apply the Apache License to your work, attach the following
|
| 382 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 383 |
+
replaced with your own identifying information. (Don't include
|
| 384 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 385 |
+
comment syntax for the file format. We also recommend that a
|
| 386 |
+
file or class name and description of purpose be included on the
|
| 387 |
+
same "printed page" as the copyright notice for easier
|
| 388 |
+
identification within third-party archives.
|
| 389 |
+
|
| 390 |
+
Copyright [yyyy] [name of copyright owner]
|
| 391 |
+
|
| 392 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 393 |
+
you may not use this file except in compliance with the License.
|
| 394 |
+
You may obtain a copy of the License at
|
| 395 |
+
|
| 396 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 397 |
+
|
| 398 |
+
Unless required by applicable law or agreed to in writing, software
|
| 399 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 400 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 401 |
+
See the License for the specific language governing permissions and
|
| 402 |
+
limitations under the License.
|
external/peract_bimanual/README.md
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Perceiver-Actor^2: A Multi-Task Transformer for Bimanual Robotic Manipulation Tasks
|
| 2 |
+
|
| 3 |
+
[](https://black.readthedocs.io/en/stable/)
|
| 4 |
+
|
| 5 |
+
This work extends previous work [PerAct](https://peract.github.io) as well as
|
| 6 |
+
[RLBench](https://sites.google.com/view/rlbench) for bimanual manipulation
|
| 7 |
+
tasks.
|
| 8 |
+
|
| 9 |
+
The repository and documentation are still work in progress.
|
| 10 |
+
|
| 11 |
+
For the latest updates, see: [bimanual.github.io](https://bimanual.github.io)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
## Installation
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
Please see [Installation](INSTALLATION.md) for further details.
|
| 18 |
+
|
| 19 |
+
### Prerequisites
|
| 20 |
+
|
| 21 |
+
The PerAct^2 code is built off of [PerAct](https://peract.github.io), which itself is
|
| 22 |
+
built on the [ARM repository](https://github.com/stepjam/ARM) by James et al.
|
| 23 |
+
The prerequisites are the same as PerAct or ARM.
|
| 24 |
+
|
| 25 |
+
#### 1. Environment
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
Install miniconda if not already present on the current system.
|
| 29 |
+
You can use `scripts/install_conda.sh` for this step:
|
| 30 |
+
|
| 31 |
+
```bash
|
| 32 |
+
|
| 33 |
+
sudo apt install curl
|
| 34 |
+
|
| 35 |
+
curl -L -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
|
| 36 |
+
chmod +x Miniconda3-latest-Linux-x86_64.sh
|
| 37 |
+
./Miniconda3-latest-Linux-x86_64.sh
|
| 38 |
+
|
| 39 |
+
SHELL_NAME=`basename $SHELL`
|
| 40 |
+
eval "$($HOME/miniconda3/bin/conda shell.${SHELL_NAME} hook)"
|
| 41 |
+
conda init ${SHELL_NAME}
|
| 42 |
+
conda install mamba -c conda-forge
|
| 43 |
+
conda config --set auto_activate_base false
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
Next, create the rlbench environment and install the dependencies
|
| 47 |
+
|
| 48 |
+
```bash
|
| 49 |
+
conda create -n rlbench python=3.8
|
| 50 |
+
conda activate rlbench
|
| 51 |
+
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
#### 2. Dependencies
|
| 56 |
+
|
| 57 |
+
You need to setup [RLBench](https://github.com/markusgrotz/rlbench/), [Pyrep](https://github.com/markusgrotz/Pyrep/), and [YARR](https://github.com/markusgrotz/YARR/).
|
| 58 |
+
Please note that due to the bimanual functionality the main repository does not work.
|
| 59 |
+
You can use `scripts/install_dependencies.sh` to do so.
|
| 60 |
+
See [Installation](INSTALLATION.md) for details.
|
| 61 |
+
|
| 62 |
+
```bash
|
| 63 |
+
./scripts/install_dependencies.sh
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
### Pre-Generated Datasets
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
Please checkout the website for [pre-generated RLBench
|
| 72 |
+
demonstrations](https://bimanual.github.io). If you directly use these
|
| 73 |
+
datasets, you don't need to run `tools/bimanual_data_generator.py` from
|
| 74 |
+
RLBench. Using these datasets will also help reproducibility since each scene
|
| 75 |
+
is randomly sampled in `data_generator_bimanual.py`.
|
| 76 |
+
|
| 77 |
+
### Training
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
#### Single-GPU Training
|
| 81 |
+
|
| 82 |
+
To configure and train the model, follow these guidelines:
|
| 83 |
+
|
| 84 |
+
- **General Parameters**: You can find and modify general parameters in the `conf/config.yaml` file. This file contains overall settings for the training environment, such as the number of cameras or the tasks to use.
|
| 85 |
+
|
| 86 |
+
- **Method-Specific Parameters**: For parameters specific to each method, refer to the corresponding files located in the `conf/method` directory. These files define configurations tailored to each method's requirements.
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
When training adjust the `replay.batch_size` parameter to maximize the utilization of your GPU resources. Increasing this value can improve training efficiency based on the capacity of your available hardware.
|
| 91 |
+
You can either modify the config files directly or you can pass parameters directly through the command line when running the training script. This allows for quick adjustments without editing configuration files:
|
| 92 |
+
|
| 93 |
+
```bash
|
| 94 |
+
python train.py replay.batch_size=3 method=BIMANUAL_PERACT
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
In this example, the command sets replay.batch_size to 3 and specifies the use of the BIMANUAL_PERACT method for training.
|
| 98 |
+
Another important parameter to specify the tasks is `rlbench.task_name`, which sets the overall task, and `rlbench.tasks`, which is a list of tasks used for training. Note that these can be different for evaluation.
|
| 99 |
+
A complete set of tasks is shown below:
|
| 100 |
+
|
| 101 |
+
```yaml
|
| 102 |
+
|
| 103 |
+
rlbench:
|
| 104 |
+
task_name: multi
|
| 105 |
+
tasks:
|
| 106 |
+
- bimanual_push_box
|
| 107 |
+
- bimanual_lift_ball
|
| 108 |
+
- bimanual_dual_push_buttons
|
| 109 |
+
- bimanual_pick_plate
|
| 110 |
+
- bimanual_put_item_in_drawer
|
| 111 |
+
- bimanual_put_bottle_in_fridge
|
| 112 |
+
- bimanual_handover_item
|
| 113 |
+
- bimanual_pick_laptop
|
| 114 |
+
- bimanual_straighten_rope
|
| 115 |
+
- bimanual_sweep_to_dustpan
|
| 116 |
+
- bimanual_lift_tray
|
| 117 |
+
- bimanual_handover_item_easy
|
| 118 |
+
- bimanual_take_tray_out_of_oven
|
| 119 |
+
```
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
#### Multi-GPU and Multi-Node Training
|
| 123 |
+
|
| 124 |
+
This repository supports multi-GPU training and distributed training across multiple nodes using [PyTorch Distributed Data Parallel (DDP)](https://pytorch.org/docs/stable/notes/ddp.html).
|
| 125 |
+
Follow the instructions below to configure and run training across multiple GPUs and nodes.
|
| 126 |
+
|
| 127 |
+
#### Multi-GPU Training on a Single Node
|
| 128 |
+
|
| 129 |
+
To train using multiple GPUs on a single node, set the parameter `ddp.num_devices` to the number of GPUs available. For example, if you have 4 GPUs, you can start the training process as follows:
|
| 130 |
+
|
| 131 |
+
```bash
|
| 132 |
+
python train.py replay.batch_size=3 method=BIMANUAL_PERACT ddp.num_devices=4
|
| 133 |
+
```
|
| 134 |
+
|
| 135 |
+
This command will utilize 4 GPUs on the current node for training. Remember to set the `replay.batch_size`, which is per GPU.
|
| 136 |
+
|
| 137 |
+
#### Multi-Node Training Across Different Nodes
|
| 138 |
+
|
| 139 |
+
If you want to perform distributed training across multiple nodes, you need to set additional parameters: ddp.master_addr and ddp.master_port. These parameters should be configured as follows:
|
| 140 |
+
|
| 141 |
+
`ddp.master_addr`: The IP address of the master node (usually the node where the training is initiated).
|
| 142 |
+
`ddp.master_port`: A port number to be used for communication across nodes.
|
| 143 |
+
|
| 144 |
+
Example Command:
|
| 145 |
+
|
| 146 |
+
```bash
|
| 147 |
+
python train.py replay.batch_size=3 method=BIMANUAL_PERACT ddp.num_devices=4 ddp.master_addr=192.168.1.1 ddp.master_port=29500
|
| 148 |
+
```
|
| 149 |
+
|
| 150 |
+
Note: Ensure that all nodes can communicate with each other through the specified IP and port, and that they have the same codebase, data access, and configurations for a successful distributed training run.
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
### Evaluation
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
Similar to training you can find general parameters in `conf/eval.yaml` and method specific parameters in the `conf/method` directory.
|
| 158 |
+
For each method, you have to set the execution mode in RLBench. For bimanual agents such as `BIMANUAL_PERACT` or `PERACT_BC` this is:
|
| 159 |
+
|
| 160 |
+
```yaml
|
| 161 |
+
rlbench:
|
| 162 |
+
gripper_mode: 'BimanualDiscrete'
|
| 163 |
+
arm_action_mode: 'BimanualEndEffectorPoseViaPlanning'
|
| 164 |
+
action_mode: 'BimanualMoveArmThenGripper'
|
| 165 |
+
```
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
To generate videos of the current evaluation you can set `cinematic_recorder.enabled` to `True`.
|
| 169 |
+
It is recommended during evaluation to disable the recorder, i.e. `cinematic_recorder.enabled=False`, as rendering the video increases the total evaluation time.
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
## Acknowledgements
|
| 173 |
+
|
| 174 |
+
This repository uses code from the following open-source projects:
|
| 175 |
+
|
| 176 |
+
#### ARM
|
| 177 |
+
Original: [https://github.com/stepjam/ARM](https://github.com/stepjam/ARM)
|
| 178 |
+
License: [ARM License](https://github.com/stepjam/ARM/LICENSE)
|
| 179 |
+
Changes: Data loading was modified for PerAct. Voxelization code was modified for DDP training.
|
| 180 |
+
|
| 181 |
+
#### PerceiverIO
|
| 182 |
+
Original: [https://github.com/lucidrains/perceiver-pytorch](https://github.com/lucidrains/perceiver-pytorch)
|
| 183 |
+
License: [MIT](https://github.com/lucidrains/perceiver-pytorch/blob/main/LICENSE)
|
| 184 |
+
Changes: PerceiverIO adapted for 6-DoF manipulation.
|
| 185 |
+
|
| 186 |
+
#### ViT
|
| 187 |
+
Original: [https://github.com/lucidrains/vit-pytorch](https://github.com/lucidrains/vit-pytorch)
|
| 188 |
+
License: [MIT](https://github.com/lucidrains/vit-pytorch/blob/main/LICENSE)
|
| 189 |
+
Changes: ViT adapted for baseline.
|
| 190 |
+
|
| 191 |
+
#### LAMB Optimizer
|
| 192 |
+
Original: [https://github.com/cybertronai/pytorch-lamb](https://github.com/cybertronai/pytorch-lamb)
|
| 193 |
+
License: [MIT](https://github.com/cybertronai/pytorch-lamb/blob/master/LICENSE)
|
| 194 |
+
Changes: None.
|
| 195 |
+
|
| 196 |
+
#### OpenAI CLIP
|
| 197 |
+
Original: [https://github.com/openai/CLIP](https://github.com/openai/CLIP)
|
| 198 |
+
License: [MIT](https://github.com/openai/CLIP/blob/main/LICENSE)
|
| 199 |
+
Changes: Minor modifications to extract token and sentence features.
|
| 200 |
+
|
| 201 |
+
Thanks for open-sourcing!
|
| 202 |
+
|
| 203 |
+
## Licenses
|
| 204 |
+
- [PerAct License (Apache 2.0)](LICENSE) - Perceiver-Actor Transformer
|
| 205 |
+
- [ARM License](ARM_LICENSE) - Voxelization and Data Preprocessing
|
| 206 |
+
- [YARR Licence (Apache 2.0)](https://github.com/stepjam/YARR/blob/main/LICENSE)
|
| 207 |
+
- [RLBench Licence](https://github.com/stepjam/RLBench/blob/master/LICENSE)
|
| 208 |
+
- [PyRep License (MIT)](https://github.com/stepjam/PyRep/blob/master/LICENSE)
|
| 209 |
+
- [Perceiver PyTorch License (MIT)](https://github.com/lucidrains/perceiver-pytorch/blob/main/LICENSE)
|
| 210 |
+
- [LAMB License (MIT)](https://github.com/cybertronai/pytorch-lamb/blob/master/LICENSE)
|
| 211 |
+
- [CLIP License (MIT)](https://github.com/openai/CLIP/blob/main/LICENSE)
|
| 212 |
+
|
| 213 |
+
## Release Notes
|
| 214 |
+
|
| 215 |
+
**Update 2025-02-20**
|
| 216 |
+
|
| 217 |
+
- Update instructions
|
| 218 |
+
- Add missing dependency for install script
|
| 219 |
+
- Add docker build file
|
| 220 |
+
|
| 221 |
+
**Update 2024-11-06**
|
| 222 |
+
|
| 223 |
+
- Regenerate and repack dataset. Closes #13. Task names are now more consistent. Dataset now includes waypoint information.
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
**Update 2024-10-17**
|
| 227 |
+
|
| 228 |
+
- Update Readme
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
**Update 2024-07-10**
|
| 233 |
+
|
| 234 |
+
- Initial release
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
## Citations
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
**PerAct^2**
|
| 241 |
+
```
|
| 242 |
+
@misc{grotz2024peract2,
|
| 243 |
+
title={PerAct2: Benchmarking and Learning for Robotic Bimanual Manipulation Tasks},
|
| 244 |
+
author={Markus Grotz and Mohit Shridhar and Tamim Asfour and Dieter Fox},
|
| 245 |
+
year={2024},
|
| 246 |
+
eprint={2407.00278},
|
| 247 |
+
archivePrefix={arXiv},
|
| 248 |
+
primaryClass={cs.RO},
|
| 249 |
+
url={https://arxiv.org/abs/2407.00278},
|
| 250 |
+
}
|
| 251 |
+
```
|
| 252 |
+
|
| 253 |
+
**PerAct**
|
| 254 |
+
```
|
| 255 |
+
@inproceedings{shridhar2022peract,
|
| 256 |
+
title = {Perceiver-Actor: A Multi-Task Transformer for Robotic Manipulation},
|
| 257 |
+
author = {Shridhar, Mohit and Manuelli, Lucas and Fox, Dieter},
|
| 258 |
+
booktitle = {Proceedings of the 6th Conference on Robot Learning (CoRL)},
|
| 259 |
+
year = {2022},
|
| 260 |
+
}
|
| 261 |
+
```
|
| 262 |
+
|
| 263 |
+
**C2FARM**
|
| 264 |
+
```
|
| 265 |
+
@inproceedings{james2022coarse,
|
| 266 |
+
title={Coarse-to-fine q-attention: Efficient learning for visual robotic manipulation via discretisation},
|
| 267 |
+
author={James, Stephen and Wada, Kentaro and Laidlow, Tristan and Davison, Andrew J},
|
| 268 |
+
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
|
| 269 |
+
pages={13739--13748},
|
| 270 |
+
year={2022}
|
| 271 |
+
}
|
| 272 |
+
```
|
| 273 |
+
|
| 274 |
+
**PerceiverIO**
|
| 275 |
+
```
|
| 276 |
+
@article{jaegle2021perceiver,
|
| 277 |
+
title={Perceiver io: A general architecture for structured inputs \& outputs},
|
| 278 |
+
author={Jaegle, Andrew and Borgeaud, Sebastian and Alayrac, Jean-Baptiste and Doersch, Carl and Ionescu, Catalin and Ding, David and Koppula, Skanda and Zoran, Daniel and Brock, Andrew and Shelhamer, Evan and others},
|
| 279 |
+
journal={arXiv preprint arXiv:2107.14795},
|
| 280 |
+
year={2021}
|
| 281 |
+
}
|
| 282 |
+
```
|
| 283 |
+
|
| 284 |
+
**RLBench**
|
| 285 |
+
```
|
| 286 |
+
@article{james2020rlbench,
|
| 287 |
+
title={Rlbench: The robot learning benchmark \& learning environment},
|
| 288 |
+
author={James, Stephen and Ma, Zicong and Arrojo, David Rovick and Davison, Andrew J},
|
| 289 |
+
journal={IEEE Robotics and Automation Letters},
|
| 290 |
+
volume={5},
|
| 291 |
+
number={2},
|
| 292 |
+
pages={3019--3026},
|
| 293 |
+
year={2020},
|
| 294 |
+
publisher={IEEE}
|
| 295 |
+
}
|
| 296 |
+
```
|
| 297 |
+
|
| 298 |
+
## Questions or Issues?
|
| 299 |
+
|
| 300 |
+
Please file an issue with the issue tracker.
|
external/peract_bimanual/helpers/clip/__init__.py
ADDED
|
File without changes
|
external/peract_bimanual/model-card.md
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Model Card: Perceiver-Actor
|
| 2 |
+
|
| 3 |
+
Following [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993) and [Lessons from Archives (Jo & Gebru)](https://arxiv.org/pdf/1912.10389.pdf) we provide additional information on PerAct.
|
| 4 |
+
|
| 5 |
+
## Model Details
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
### Overview
|
| 9 |
+
- Developed by Shridhar et al. at University of Washington and NVIDIA. PerAct is an end-to-end behavior cloning agent that learns to perform a wide variety of language-conditioned manipulation tasks. PerAct uses a Transformer that exploits the 3D structure of _voxel patches_ to learn policies with just a few demonstrations per task.
|
| 10 |
+
- Architecture: Transformer trained from scratch with end-to-end supervised learning.
|
| 11 |
+
- Trained for 6-DoF manipulation tasks with objects that appear in tabletop scenes.
|
| 12 |
+
|
| 13 |
+
### Model Date
|
| 14 |
+
|
| 15 |
+
Nov 2022
|
| 16 |
+
|
| 17 |
+
### Documents
|
| 18 |
+
|
| 19 |
+
- [PerAct Paper](https://peract.github.io/paper/peract_corl2022.pdf)
|
| 20 |
+
- [PerceiverIO Paper](https://arxiv.org/abs/2107.14795)
|
| 21 |
+
- [C2FARM Paper](https://arxiv.org/abs/2106.12534)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
## Model Use
|
| 25 |
+
|
| 26 |
+
- **Primary intended use case**: PerAct is intended for robotic manipulation research. We hope the benchmark and pre-trained models will enable researchers to study the capabilities of Transformers for end-to-end 6-DoF Manipulation. Specifically, we hope the setup serves as a reproducible framework for evaluating robustness and scaling capabilities of manipulation agents.
|
| 27 |
+
- **Primary intended users**: Robotics researchers.
|
| 28 |
+
- **Out-of-scope use cases**: Deployed use cases in real-world autonomous systems without human supervision during test-time are currently out-of-scope. Use cases that involve manipulating novel objects and observations with people are not recommended for safety-critical systems. The agent is also intended to be trained and evaluated with English language instructions.
|
| 29 |
+
|
| 30 |
+
## Data
|
| 31 |
+
|
| 32 |
+
- Pre-training Data for CLIP's language encoder: See [OpenAI's Model Card](https://github.com/openai/CLIP/blob/main/model-card.md#data) for full details. **Note:** We do not use CLIP's vision encoders for any agents in the repo.
|
| 33 |
+
- Manipulation Data for PerAct: The agent was trained with expert demonstrations. In simulation, we use oracle agents and in the real world we use human demonstrations. Since the agent is used in few-shot settings with very limited data, the agent might exploit intended and unintended biases in the training demonstrations. Currently, these biases are limited to just objects that appear on tabletops.
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
## Limitations
|
| 37 |
+
|
| 38 |
+
- Depends on a sampling-based motion planner.
|
| 39 |
+
- Hard to extend to dexterous and continuous manipulation tasks.
|
| 40 |
+
- Lacks memory to solve tasks with ordering and history-based sequencing.
|
| 41 |
+
- Exploits biases in training demonstrations.
|
| 42 |
+
- Needs good hand-eye calibration.
|
| 43 |
+
- Doesn't generalize to novel objects.
|
| 44 |
+
- Struggles with grounding complex spatial relationships.
|
| 45 |
+
- Does not predict task completion.
|
| 46 |
+
|
| 47 |
+
See Appendix L in the [paper](https://peract.github.io/paper/peract_corl2022.pdf) for an extended discussion.
|
external/peract_bimanual/peract_config.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
System configuration for peract
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import logging
|
| 6 |
+
|
| 7 |
+
import torch.multiprocessing as mp
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def config_logging(logging_level=logging.INFO, reset=False):
|
| 11 |
+
if reset:
|
| 12 |
+
root = logging.getLogger()
|
| 13 |
+
list(map(root.removeHandler, root.handlers))
|
| 14 |
+
list(map(root.removeFilter, root.filters))
|
| 15 |
+
|
| 16 |
+
from rich.logging import RichHandler
|
| 17 |
+
|
| 18 |
+
logging.basicConfig(level=logging_level, handlers=[RichHandler()])
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def on_init():
|
| 22 |
+
config_logging(logging.INFO)
|
| 23 |
+
|
| 24 |
+
logging.debug("Configuring environment.")
|
| 25 |
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
| 26 |
+
mp.set_start_method("spawn", force=True)
|
| 27 |
+
mp.set_sharing_strategy("file_system")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def on_config(cfg):
|
| 31 |
+
os.environ["MASTER_ADDR"] = str(cfg.ddp.master_addr)
|
| 32 |
+
os.environ["MASTER_PORT"] = str(cfg.ddp.master_port)
|
external/peract_bimanual/pyproject.toml
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[tool.poetry]
|
| 2 |
+
name = "peract_bimanual"
|
| 3 |
+
version = "0.0.1"
|
| 4 |
+
description = "A perceiver actor framework for bimanual manipulation tasks"
|
| 5 |
+
authors = [ "Markus Grotz <grotz@uw.edu>",
|
| 6 |
+
"Mohit Shridhar <mshr@cs.washington.edu>"]
|
| 7 |
+
packages = [{include = "agents"}, {include = "helpers"}, {include = "voxel"}]
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
readme = "README.md"
|
| 11 |
+
classifiers = [
|
| 12 |
+
"Programming Language :: Python :: 3",
|
| 13 |
+
"Framework :: Robot Framework "
|
| 14 |
+
]
|
| 15 |
+
|
| 16 |
+
[tool.poetry.dependencies]
|
| 17 |
+
python = ">=3.8,<4.0"
|
| 18 |
+
einops = "0.3.2"
|
| 19 |
+
ftfy = "^6.1.1"
|
| 20 |
+
hydra-core = ">=1.0.5"
|
| 21 |
+
matplotlib = "^3.7.1"
|
| 22 |
+
pandas = "1.4.1"
|
| 23 |
+
regex = "^2023.6.3"
|
| 24 |
+
tensorboard = "^2.13.0"
|
| 25 |
+
perceiver-pytorch = "^0.8.7"
|
| 26 |
+
transformers = "^4.21"
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
[tool.poetry.extras]
|
| 31 |
+
docs = ["sphinx"]
|
| 32 |
+
|
| 33 |
+
[build-system]
|
| 34 |
+
requires = ["setuptools", "wheel", "poetry-core>=1.0.0"]
|
| 35 |
+
build-backend = "poetry.core.masonry.api"
|
external/peract_bimanual/run_seed_fn.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pickle
|
| 3 |
+
import gc
|
| 4 |
+
from typing import List
|
| 5 |
+
|
| 6 |
+
import hydra
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from omegaconf import DictConfig
|
| 10 |
+
|
| 11 |
+
from rlbench import CameraConfig, ObservationConfig
|
| 12 |
+
from yarr.replay_buffer.wrappers.pytorch_replay_buffer import PyTorchReplayBuffer
|
| 13 |
+
from yarr.runners.offline_train_runner import OfflineTrainRunner
|
| 14 |
+
from yarr.utils.stat_accumulator import SimpleAccumulator
|
| 15 |
+
|
| 16 |
+
from helpers.custom_rlbench_env import CustomRLBenchEnv, CustomMultiTaskRLBenchEnv
|
| 17 |
+
import torch.distributed as dist
|
| 18 |
+
|
| 19 |
+
from agents import agent_factory
|
| 20 |
+
from agents import replay_utils
|
| 21 |
+
|
| 22 |
+
import peract_config
|
| 23 |
+
from functools import partial
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def run_seed(
    rank,
    cfg: DictConfig,
    obs_config: ObservationConfig,
    seed,
    world_size,
) -> None:
    """Train a single seed on one DDP worker.

    Spawned once per device (by ``torch.multiprocessing.spawn`` in train.py):
    joins the process group, builds the method-specific replay buffer, fills
    it from demonstrations, then runs ``OfflineTrainRunner`` to completion.

    Args:
        rank: index of this DDP worker; also passed as the training device.
        cfg: full Hydra experiment config.
        obs_config: RLBench observation config (cameras, resolutions, etc.).
        seed: seed index; used here only for replay/weights/log folder naming.
        world_size: total number of DDP workers.
    """
    peract_config.config_logging()

    # "gloo" backend works without GPUs/NCCL. Rendezvous presumably uses
    # MASTER_ADDR/MASTER_PORT from the environment — TODO confirm in caller.
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    tasks = cfg.rlbench.tasks
    cams = cfg.rlbench.cameras

    # Multi-task runs share one "multi" folder; single-task runs get their own.
    task_folder = "multi" if len(tasks) > 1 else tasks[0]
    replay_path = os.path.join(
        cfg.replay.path, task_folder, cfg.method.name, "seed%d" % seed
    )

    agent = agent_factory.create_agent(cfg)

    if not agent:
        print("Unable to create agent")
        return

    # Method dispatch: every non-raising branch must leave `replay_buffer`
    # bound, since it is consumed unconditionally after the chain.
    # Baseline methods import lazily to avoid pulling in unused dependencies.
    if cfg.method.name == "ARM":
        raise NotImplementedError("ARM is not supported yet")
    elif cfg.method.name == "BC_LANG":
        from agents.baselines import bc_lang

        assert cfg.ddp.num_devices == 1, "BC_LANG only supports single GPU training"
        replay_buffer = bc_lang.launch_utils.create_replay(
            cfg.replay.batch_size,
            cfg.replay.timesteps,
            cfg.replay.prioritisation,
            cfg.replay.task_uniform,
            replay_path if cfg.replay.use_disk else None,
            cams,
            cfg.rlbench.camera_resolution,
        )

        bc_lang.launch_utils.fill_multi_task_replay(
            cfg,
            obs_config,
            rank,
            replay_buffer,
            tasks,
            cfg.rlbench.demos,
            cfg.method.demo_augmentation,
            cfg.method.demo_augmentation_every_n,
            cams,
        )

    elif cfg.method.name == "VIT_BC_LANG":
        from agents.baselines import vit_bc_lang

        assert cfg.ddp.num_devices == 1, "VIT_BC_LANG only supports single GPU training"
        replay_buffer = vit_bc_lang.launch_utils.create_replay(
            cfg.replay.batch_size,
            cfg.replay.timesteps,
            cfg.replay.prioritisation,
            cfg.replay.task_uniform,
            replay_path if cfg.replay.use_disk else None,
            cams,
            cfg.rlbench.camera_resolution,
        )

        vit_bc_lang.launch_utils.fill_multi_task_replay(
            cfg,
            obs_config,
            rank,
            replay_buffer,
            tasks,
            cfg.rlbench.demos,
            cfg.method.demo_augmentation,
            cfg.method.demo_augmentation_every_n,
            cams,
        )

    elif cfg.method.name.startswith("ACT_BC_LANG"):
        from agents import act_bc_lang

        assert cfg.ddp.num_devices == 1, "ACT_BC_LANG only supports single GPU training"
        # ACT additionally conditions on past/future action chunks, hence the
        # extra horizon arguments and an explicit (larger) replay capacity.
        replay_buffer = act_bc_lang.launch_utils.create_replay(
            cfg.replay.batch_size,
            cfg.replay.timesteps,
            cfg.replay.prioritisation,
            cfg.replay.task_uniform,
            replay_path if cfg.replay.use_disk else None,
            cams,
            cfg.rlbench.camera_resolution,
            replay_size=3e5,
            prev_action_horizon=cfg.method.prev_action_horizon,
            next_action_horizon=cfg.method.next_action_horizon,
        )

        act_bc_lang.launch_utils.fill_multi_task_replay(
            cfg,
            obs_config,
            rank,
            replay_buffer,
            tasks,
            cfg.rlbench.demos,
            cfg.method.demo_augmentation,
            cfg.method.demo_augmentation_every_n,
            cams,
        )

    elif cfg.method.name == "C2FARM_LINGUNET_BC":
        from agents import c2farm_lingunet_bc

        # Voxel-based method: replay also stores per-level voxelisation info.
        replay_buffer = c2farm_lingunet_bc.launch_utils.create_replay(
            cfg.replay.batch_size,
            cfg.replay.timesteps,
            cfg.replay.prioritisation,
            cfg.replay.task_uniform,
            replay_path if cfg.replay.use_disk else None,
            cams,
            cfg.method.voxel_sizes,
            cfg.rlbench.camera_resolution,
        )

        c2farm_lingunet_bc.launch_utils.fill_multi_task_replay(
            cfg,
            obs_config,
            rank,
            replay_buffer,
            tasks,
            cfg.rlbench.demos,
            cfg.method.demo_augmentation,
            cfg.method.demo_augmentation_every_n,
            cams,
            cfg.rlbench.scene_bounds,
            cfg.method.voxel_sizes,
            cfg.method.bounds_offset,
            cfg.method.rotation_resolution,
            cfg.method.crop_augmentation,
            keypoint_method=cfg.method.keypoint_method,
        )

    elif (
        cfg.method.name.startswith("BIMANUAL_PERACT")
        or cfg.method.name.startswith("RVT")
        or cfg.method.name.startswith("PERACT_BC")
    ):
        # These methods share generic replay helpers that read everything
        # they need from the config directly.
        replay_buffer = replay_utils.create_replay(cfg, replay_path)

        replay_utils.fill_multi_task_replay(cfg, obs_config, rank, replay_buffer, tasks)

    elif cfg.method.name == "PERACT_RL":
        raise NotImplementedError("PERACT_RL is not supported yet")

    else:
        raise ValueError("Method %s does not exists." % cfg.method.name)

    # Wrap the replay buffer for multi-worker PyTorch data loading.
    wrapped_replay = PyTorchReplayBuffer(
        replay_buffer, num_workers=cfg.framework.num_workers
    )
    stat_accum = SimpleAccumulator(eval_video_fps=30)

    # Hydra changes the CWD to the run's output directory, so these paths
    # land inside the experiment folder. TODO confirm Hydra chdir is enabled.
    cwd = os.getcwd()
    weightsdir = os.path.join(cwd, "seed%d" % seed, "weights")
    logdir = os.path.join(cwd, "seed%d" % seed)

    train_runner = OfflineTrainRunner(
        agent=agent,
        wrapped_replay_buffer=wrapped_replay,
        train_device=rank,
        stat_accumulator=stat_accum,
        iterations=cfg.framework.training_iterations,
        logdir=logdir,
        logging_level=cfg.framework.logging_level,
        log_freq=cfg.framework.log_freq,
        weightsdir=weightsdir,
        num_weights_to_keep=cfg.framework.num_weights_to_keep,
        save_freq=cfg.framework.save_freq,
        tensorboard_logging=cfg.framework.tensorboard_logging,
        csv_logging=cfg.framework.csv_logging,
        load_existing_weights=cfg.framework.load_existing_weights,
        rank=rank,
        world_size=world_size,
    )

    # Re-configure logging inside runner worker threads (spawned threads do
    # not inherit this process's logging setup).
    train_runner._on_thread_start = partial(
        peract_config.config_logging, cfg.framework.logging_level
    )

    train_runner.start()

    # Explicitly drop large objects before the worker exits so CUDA memory
    # is returned promptly (process may be reused across seeds).
    del train_runner
    del agent
    gc.collect()
    torch.cuda.empty_cache()
|
external/peract_bimanual/train.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
|
| 7 |
+
import peract_config
|
| 8 |
+
|
| 9 |
+
import hydra
|
| 10 |
+
from omegaconf import DictConfig, OmegaConf, ListConfig
|
| 11 |
+
|
| 12 |
+
import run_seed_fn
|
| 13 |
+
from helpers.observation_utils import create_obs_config
|
| 14 |
+
|
| 15 |
+
import torch.multiprocessing as mp
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@hydra.main(config_name="config", config_path="conf")
def main(cfg: DictConfig) -> None:
    """Entry point for multi-seed offline training.

    Resolves the starting seed (explicit, resumed from existing ``seedN``
    folders, or 0), snapshots the config, exits early when existing
    checkpoints already cover the requested iterations, then spawns one DDP
    worker per device for each seed via ``run_seed_fn.run_seed``.

    Fixes over the previous revision:
    - training.log entries are now separated by newlines (consecutive
      ``f.write`` calls used to concatenate onto one line);
    - the ``seedN`` directory listing is scanned once instead of twice;
    - the latest checkpoint is found with ``max`` instead of a full sort.
    """
    cfg_yaml = OmegaConf.to_yaml(cfg)
    logging.info("\n" + cfg_yaml)

    peract_config.on_config(cfg)

    # A single camera may be given as a scalar; normalise to a list.
    cfg.rlbench.cameras = (
        cfg.rlbench.cameras
        if isinstance(cfg.rlbench.cameras, ListConfig)
        else [cfg.rlbench.cameras]
    )

    # sanity check if rgb is not used as camera name
    for camera_name in cfg.rlbench.cameras:
        assert "rgb" not in camera_name

    obs_config = create_obs_config(
        cfg.rlbench.cameras, cfg.rlbench.camera_resolution, cfg.method.name
    )

    # Hydra sets the CWD to this run's output directory.
    cwd = os.getcwd()
    logging.info("CWD:" + cwd)

    # Scan once for existing per-seed output folders (named "seedN").
    seed_dirs = [d for d in os.listdir(cwd) if "seed" in d]
    if cfg.framework.start_seed >= 0:
        # seed specified
        start_seed = cfg.framework.start_seed
    elif cfg.framework.start_seed == -1 and len(seed_dirs) > 0:
        # unspecified seed; resume from the largest existing seed plus one
        start_seed = max(int(d.replace("seed", "")) for d in seed_dirs) + 1
    else:
        # start with seed 0
        start_seed = 0

    seed_folder = os.path.join(os.getcwd(), "seed%d" % start_seed)
    os.makedirs(seed_folder, exist_ok=True)

    start_time = datetime.now()
    with open(os.path.join(seed_folder, "config.yaml"), "w") as f:
        f.write(cfg_yaml)

    # check if previous checkpoints already exceed the number of desired
    # training iterations; if so, exit the script
    latest_weight = 0
    weights_folder = os.path.join(seed_folder, "weights")
    if os.path.isdir(weights_folder) and len(os.listdir(weights_folder)) > 0:
        # Checkpoint folders are named by iteration number.
        latest_weight = max(map(int, os.listdir(weights_folder)))
        if latest_weight >= cfg.framework.training_iterations:
            logging.info(
                "Agent was already trained for %d iterations. Exiting." % latest_weight
            )
            sys.exit(0)

    with open(os.path.join(seed_folder, "training.log"), "a") as f:
        f.write(
            f"# Starting training from weights: {latest_weight} to {cfg.framework.training_iterations}"
        )
        f.write(os.linesep)
        f.write(f"# Training started on: {start_time.isoformat()}")
        f.write(os.linesep)

    # run train jobs with multiple seeds (sequentially)
    for seed in range(start_seed, start_seed + cfg.framework.seeds):
        logging.info("Starting seed %d." % seed)

        world_size = cfg.ddp.num_devices
        mp.spawn(
            run_seed_fn.run_seed,
            args=(
                cfg,
                obs_config,
                seed,
                world_size,
            ),
            nprocs=world_size,
            join=True,
        )

    end_time = datetime.now()
    duration = end_time - start_time
    with open(os.path.join(seed_folder, "training.log"), "a") as f:
        f.write(f"# Training finished on: {end_time.isoformat()}")
        f.write(os.linesep)
        f.write(f"# Took {duration.total_seconds()}")
        f.write(os.linesep)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
if __name__ == "__main__":
    # One-time process initialisation (logging/env) before Hydra takes over.
    peract_config.on_init()
    main()
|
external/peract_bimanual/voxel/__init__.py
ADDED
|
File without changes
|
external/peract_bimanual/voxel/voxel_grid.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Voxelizer modified from ARM for DDP training
|
| 2 |
+
# Source: https://github.com/stepjam/ARM
|
| 3 |
+
# License: https://github.com/stepjam/ARM/LICENSE
|
| 4 |
+
|
| 5 |
+
from functools import reduce
|
| 6 |
+
from operator import mul
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch import nn
|
| 10 |
+
|
| 11 |
+
MIN_DENOMINATOR = 1e-12
|
| 12 |
+
INCLUDE_PER_VOXEL_COORD = False
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class VoxelGrid(nn.Module):
    """Scatter a point cloud into a dense bounded voxel grid.

    All static tensors (index grids, dimension products, batch indices) are
    precomputed in ``__init__`` and registered as buffers so they follow the
    module across devices — the modification that makes the ARM voxelizer
    DDP-friendly. The grid is built with one extra voxel on every side
    (the "+2" padding) which is cropped off again when producing the output.

    NOTE(review): buffer names (including the ``_voxel_indicy_denmominator``
    typo) are load-bearing for existing checkpoints — do not rename.
    """

    def __init__(
        self,
        coord_bounds,
        voxel_size: int,
        device,
        batch_size,
        feature_size,  # e.g. rgb or image features
        max_num_coords: int,
    ):
        # coord_bounds: 6 values [x_min, y_min, z_min, x_max, y_max, z_max]
        # (sliced as [..., 0:3] / [..., 3:6] below).
        super(VoxelGrid, self).__init__()
        self._device = device
        self._voxel_size = voxel_size
        # Cubic grid: same number of voxels along x, y and z.
        self._voxel_shape = [voxel_size] * 3
        self._voxel_d = float(self._voxel_shape[-1])
        # Per-voxel channels: features + 3 coord channels + 1 occupancy.
        self._voxel_feature_size = 4 + feature_size
        self._voxel_shape_spec = (
            torch.tensor(
                self._voxel_shape,
            ).unsqueeze(0)
            + 2
        )  # +2 because we crop the edges.
        self._coord_bounds = torch.tensor(
            coord_bounds,
            dtype=torch.float,
        ).unsqueeze(0)
        max_dims = self._voxel_shape_spec[0]
        # Full scatter-target shape: [B, W+2, W+2, W+2, 4 + feature_size].
        self._total_dims_list = torch.cat(
            [
                torch.tensor(
                    [batch_size],
                ),
                max_dims,
                torch.tensor(
                    [4 + feature_size],
                ),
            ],
            -1,
        ).tolist()

        # Constant 1s appended to each point as the occupancy-count channel.
        self.register_buffer(
            "_ones_max_coords", torch.ones((batch_size, max_num_coords, 1))
        )
        self._num_coords = max_num_coords

        shape = self._total_dims_list
        # Row-major stride (in elements) for each dimension of the flat grid,
        # used to turn N-d indices into flat offsets in _scatter_nd.
        result_dim_sizes = torch.tensor(
            [reduce(mul, shape[i + 1 :], 1) for i in range(len(shape) - 1)] + [1],
        )
        self.register_buffer("_result_dim_sizes", result_dim_sizes)
        flat_result_size = reduce(mul, shape, 1)

        self._initial_val = torch.tensor(0, dtype=torch.float)
        # Template for the flattened output grid (all zeros here since
        # _initial_val is 0); cloned per call via torch.zeros_like.
        flat_output = (
            torch.ones(flat_result_size, dtype=torch.float) * self._initial_val
        )
        self.register_buffer("_flat_output", flat_output)

        # 0..C-1 channel offsets added to each point's flat base index.
        self.register_buffer("_arange_to_max_coords", torch.arange(4 + feature_size))
        self._flat_zeros = torch.zeros(flat_result_size, dtype=torch.float)

        self._const_1 = torch.tensor(
            1.0,
        )
        self._batch_size = batch_size

        # Coordinate Bounds:
        bb_mins = self._coord_bounds[..., 0:3]
        self.register_buffer("_bb_mins", bb_mins)
        bb_maxs = self._coord_bounds[..., 3:6]
        bb_ranges = bb_maxs - bb_mins
        # get voxel dimensions. 'DIMS' mode
        self._dims = dims = self._voxel_shape_spec.int()
        # Grid size without the 2-voxel crop padding.
        dims_orig = self._voxel_shape_spec.int() - 2
        self.register_buffer("_dims_orig", dims_orig)

        # self._dims_m_one = (dims - 1).int()
        dims_m_one = (dims - 1).int()
        self.register_buffer("_dims_m_one", dims_m_one)

        # BS x 1 x 3
        # Metric size of one voxel per axis (MIN_DENOMINATOR avoids div-by-0).
        res = bb_ranges / (dims_orig.float() + MIN_DENOMINATOR)
        self._res_minis_2 = bb_ranges / (dims.float() - 2 + MIN_DENOMINATOR)
        self.register_buffer("_res", res)

        voxel_indicy_denmominator = res + MIN_DENOMINATOR
        self.register_buffer("_voxel_indicy_denmominator", voxel_indicy_denmominator)

        # Lower clamp (all zeros) for voxel indices.
        self.register_buffer("_dims_m_one_zeros", torch.zeros_like(dims_m_one))

        # Per-point batch index, prepended to voxel indices so scattering
        # targets the correct batch entry.
        batch_indices = torch.arange(self._batch_size, dtype=torch.int).view(
            self._batch_size, 1, 1
        )
        self.register_buffer(
            "_tiled_batch_indices", batch_indices.repeat([1, self._num_coords, 1])
        )

        # Dense [B, W+2, W+2, W+2, 3] grid of integer voxel positions, used
        # to emit per-voxel normalized coordinates in the output.
        w = self._voxel_shape[0] + 2
        arange = torch.arange(
            0,
            w,
            dtype=torch.float,
        )
        index_grid = (
            torch.cat(
                [
                    arange.view(w, 1, 1, 1).repeat([1, w, w, 1]),
                    arange.view(1, w, 1, 1).repeat([w, 1, w, 1]),
                    arange.view(1, 1, w, 1).repeat([w, w, 1, 1]),
                ],
                dim=-1,
            )
            .unsqueeze(0)
            .repeat([self._batch_size, 1, 1, 1, 1])
        )
        self.register_buffer("_index_grid", index_grid)

    def _broadcast(self, src: torch.Tensor, other: torch.Tensor, dim: int):
        # Expand `src` so it broadcasts against `other` along `dim`
        # (same helper as torch_scatter's `broadcast`).
        if dim < 0:
            dim = other.dim() + dim
        if src.dim() == 1:
            for _ in range(0, dim):
                src = src.unsqueeze(0)
        for _ in range(src.dim(), other.dim()):
            src = src.unsqueeze(-1)
        src = src.expand_as(other)
        return src

    def _scatter_mean(
        self, src: torch.Tensor, index: torch.Tensor, out: torch.Tensor, dim: int = -1
    ):
        # Scatter-add `src` into `out`, then divide by the per-slot hit count
        # to get the mean. NOTE: mutates `out` in place (clamp keeps empty
        # slots at their initial value instead of dividing by zero).
        out = out.scatter_add_(dim, index, src)

        index_dim = dim
        if index_dim < 0:
            index_dim = index_dim + src.dim()
        if index.dim() <= index_dim:
            index_dim = index.dim() - 1

        ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
        out_count = torch.zeros(out.size(), dtype=out.dtype, device=out.device)
        out_count = out_count.scatter_add_(index_dim, index, ones)
        out_count.clamp_(1)
        count = self._broadcast(out_count, out, dim)
        if torch.is_floating_point(out):
            out.true_divide_(count)
        else:
            out.floor_divide_(count)
        return out

    def _scatter_nd(self, indices, updates):
        # TF-style scatter_nd with mean-combination of duplicate indices:
        # converts [P, 4] (batch, x, y, z) indices into flat offsets via the
        # precomputed strides, fans each out across the channel dimension,
        # then scatter-means the per-point update rows into a fresh flat grid.
        indices_shape = indices.shape
        num_index_dims = indices_shape[-1]
        flat_updates = updates.view((-1,))
        indices_scales = self._result_dim_sizes[0:num_index_dims].view(
            [1] * (len(indices_shape) - 1) + [num_index_dims]
        )
        indices_for_flat_tiled = (
            ((indices * indices_scales).sum(dim=-1, keepdims=True))
            .view(-1, 1)
            .repeat(*[1, self._voxel_feature_size])
        )

        implicit_indices = (
            self._arange_to_max_coords[: self._voxel_feature_size]
            .unsqueeze(0)
            .repeat(*[indices_for_flat_tiled.shape[0], 1])
        )
        indices_for_flat = indices_for_flat_tiled + implicit_indices
        flat_indices_for_flat = indices_for_flat.view((-1,)).long()

        flat_scatter = self._scatter_mean(
            flat_updates, flat_indices_for_flat, out=torch.zeros_like(self._flat_output)
        )
        return flat_scatter.view(self._total_dims_list)

    def coords_to_bounding_voxel_grid(
        self, coords, coord_features=None, coord_bounds=None
    ):
        """Voxelize a batch of points into the bounded grid.

        Args:
            coords: [B, N, 3] world-frame point coordinates.
            coord_features: optional [B, N, F] per-point features to average
                into each voxel alongside the coordinates.
            coord_bounds: optional per-call [B, 6] bounds overriding the
                bounds given at construction time.

        Returns:
            [B, W, W, W, C] grid whose last dimension is
            (mean coords+features, normalized voxel index xyz, occupancy).
        """
        voxel_indicy_denmominator = self._voxel_indicy_denmominator
        res, bb_mins = self._res, self._bb_mins
        if coord_bounds is not None:
            # Recompute resolution for dynamically supplied bounds.
            bb_mins = coord_bounds[..., 0:3]
            bb_maxs = coord_bounds[..., 3:6]
            bb_ranges = bb_maxs - bb_mins
            res = bb_ranges / (self._dims_orig.float() + MIN_DENOMINATOR)
            voxel_indicy_denmominator = res + MIN_DENOMINATOR

        bb_mins_shifted = bb_mins - res  # shift back by one
        floor = torch.floor(
            (coords - bb_mins_shifted.unsqueeze(1))
            / voxel_indicy_denmominator.unsqueeze(1)
        ).int()
        # Clamp out-of-bounds points into the 1-voxel border (cropped later).
        voxel_indices = torch.min(floor, self._dims_m_one)
        voxel_indices = torch.max(voxel_indices, self._dims_m_one_zeros)

        # BS x NC x 3
        voxel_values = coords
        if coord_features is not None:
            voxel_values = torch.cat([voxel_values, coord_features], -1)

        _, num_coords, _ = voxel_indices.shape
        # BS x N x (num_batch_dims + 2)
        all_indices = torch.cat(
            [self._tiled_batch_indices[:, :num_coords], voxel_indices], -1
        )

        # BS x N x 4 — trailing 1s channel becomes the occupancy signal.
        voxel_values_pruned_flat = torch.cat(
            [voxel_values, self._ones_max_coords[:, :num_coords]], -1
        )

        # BS x x_max x y_max x z_max x 4
        scattered = self._scatter_nd(
            all_indices.view([-1, 1 + 3]),
            voxel_values_pruned_flat.view(-1, self._voxel_feature_size),
        )

        # Crop the 1-voxel border that absorbed clamped out-of-bounds points.
        vox = scattered[:, 1:-1, 1:-1, 1:-1]
        if INCLUDE_PER_VOXEL_COORD:
            # Optionally insert each voxel's metric centre position.
            res_expanded = res.unsqueeze(1).unsqueeze(1).unsqueeze(1)
            res_centre = (res_expanded * self._index_grid) + res_expanded / 2.0
            coord_positions = (
                res_centre + bb_mins_shifted.unsqueeze(1).unsqueeze(1).unsqueeze(1)
            )[:, 1:-1, 1:-1, 1:-1]
            vox = torch.cat([vox[..., :-1], coord_positions, vox[..., -1:]], -1)

        # Binarise the scatter-mean of the 1s channel into 0/1 occupancy.
        occupied = (vox[..., -1:] > 0).float()
        vox = torch.cat([vox[..., :-1], occupied], -1)

        # Append each voxel's own index normalized by the grid size, keeping
        # occupancy as the final channel.
        return torch.cat(
            [
                vox[..., :-1],
                self._index_grid[:, :-2, :-2, :-2] / self._voxel_d,
                vox[..., -1:],
            ],
            -1,
        )
|
external/yarr/.gitignore
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__
|
| 2 |
+
venv
|
| 3 |
+
.idea
|
| 4 |
+
.bash_history
|
| 5 |
+
.cache/
|
| 6 |
+
.local/
|
| 7 |
+
.python_history
|
| 8 |
+
nvidia-persistenced/
|
| 9 |
+
results/
|
| 10 |
+
rlight.egg-info
|
| 11 |
+
dist/
|
| 12 |
+
build/
|
| 13 |
+
yarr.egg-info
|
external/yarr/LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
external/yarr/README.md
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+

|
| 2 |
+
|
| 3 |
+
**Note**: Pirate qualification not needed to use this library.
|
| 4 |
+
|
| 5 |
+
YARR is **Y**et **A**nother **R**obotics and **R**einforcement learning framework for PyTorch.
|
| 6 |
+
|
| 7 |
+
The framework allows for asynchronous training (i.e. agent and learner running in separate processes), which makes it suitable for robot learning.
|
| 8 |
+
For an example of how to use this framework, see my [Attention-driven Robot Manipulation (ARM) repo](https://github.com/stepjam/ARM).
|
| 9 |
+
|
| 10 |
+
This project is mostly intended for my personal use (Stephen James) and facilitate my research.
|
| 11 |
+
|
| 12 |
+
## Modifications
|
| 13 |
+
|
| 14 |
+
This is my (Mohit Shridhar) fork of YARR. Honestly, I don't understand what exactly is happening in a lot of places, so there are a lot of hacks to make it work for my purposes. If you are doing simple behavior cloning, you can probably write simpler training and evaluation routines, but YARR might be useful if you also want to do RL. Here is a quick summary of my modifications:
|
| 15 |
+
|
| 16 |
+
- Switched from randomly sampling evaluation episodes to deterministic reloading of val/test dataset episodes for one-to-one comparisons across models.
|
| 17 |
+
- Separated training and evaluation routines.
|
| 18 |
+
- Task-uniform replay buffer for multi-task training. Each batch has a uniform distribution of tasks.
|
| 19 |
+
- Added cinematic recorder for rollouts.
|
| 20 |
+
- Some other weird hacks to prevent memory leaks.
|
| 21 |
+
|
| 22 |
+
## Install
|
| 23 |
+
|
| 24 |
+
Ensure you have [PyTorch installed](https://pytorch.org/get-started/locally/).
|
| 25 |
+
Then simply run:
|
| 26 |
+
```bash
|
| 27 |
+
python setup.py develop
|
| 28 |
+
```
|
external/yarr/logo.png
ADDED
|
external/yarr/requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
tensorboard
|
| 2 |
+
moviepy
|
| 3 |
+
natsort
|
| 4 |
+
psutil
|
| 5 |
+
timeout-decorator
|
| 6 |
+
pyrender==0.1.45
|
| 7 |
+
omegaconf
|
| 8 |
+
hydra-core
|
| 9 |
+
pandas==1.4.1
|
| 10 |
+
opencv-python
|
| 11 |
+
|
external/yarr/setup.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import codecs
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import setuptools
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def read(rel_path):
    """Return the full text of *rel_path*, resolved relative to this file."""
    base_dir = os.path.abspath(os.path.dirname(__file__))
    with codecs.open(os.path.join(base_dir, rel_path), 'r') as fp:
        return fp.read()
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def get_version(rel_path):
    """Extract the ``__version__`` string literal from the given source file.

    Raises RuntimeError when no ``__version__`` assignment is found.
    """
    for line in read(rel_path).splitlines():
        if not line.startswith('__version__'):
            continue
        # The version literal may use either quote style; it is field 1
        # after splitting on whichever quote character appears.
        quote = '"' if '"' in line else "'"
        return line.split(quote)[1]
    raise RuntimeError("Unable to find version string.")
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def get_install_requires():
    """Read requirements.txt (relative to the CWD) into a list of specs.

    Lines are stripped but otherwise kept as-is, including empty ones,
    matching the historical behaviour of this helper.
    """
    with open('requirements.txt') as f:
        return [req.strip() for req in f]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# Package metadata. The version string is parsed out of yarr/__init__.py
# and the dependency pins come straight from requirements.txt (see the
# helpers above). Runs at import time, as is conventional for setup.py.
setuptools.setup(
    version=get_version("yarr/__init__.py"),
    name='yarr',
    author='Stephen James',
    author_email='slj12@ic.ac.uk',
    packages=setuptools.find_packages(),
    install_requires=get_install_requires()
)
|
external/yarr/yarr/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Bumped manually; parsed by setup.py's get_version() at install time.
__version__ = '0.1'
|
external/yarr/yarr/agents/__init__.py
ADDED
|
File without changes
|
external/yarr/yarr/agents/agent.py
ADDED
|
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from typing import Any, List
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class Summary(object):
    """A named value emitted by an agent for logging.

    Subclasses act as tags describing how ``value`` should be rendered
    (scalar, histogram, image, text, video).
    """

    def __init__(self, name: str, value: Any):
        # Logging key and payload; interpretation is left to the consumer.
        self.name = name
        self.value = value
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ScalarSummary(Summary):
    """Marker type: ``value`` is a single scalar."""
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class HistogramSummary(Summary):
    """Marker type: ``value`` is a collection to render as a histogram."""
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class ImageSummary(Summary):
    """Marker type: ``value`` is image data."""
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class TextSummary(Summary):
    """Marker type: ``value`` is free-form text."""
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class VideoSummary(Summary):
    """Summary whose ``value`` is a video clip.

    ``fps`` is the playback rate to use when the clip is written out.
    """

    def __init__(self, name: str, value: Any, fps: int = 30):
        super().__init__(name, value)
        self.fps = fps
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class ActResult(object):
    """Outcome of a single ``Agent.act`` call.

    Bundles the chosen ``action`` with optional extras: elements to merge
    into the next observation, elements destined for the replay buffer,
    and free-form diagnostic ``info``.
    """

    def __init__(self, action: Any,
                 observation_elements: dict = None,
                 replay_elements: dict = None,
                 info: dict = None):
        self.action = action
        # Falsy inputs (None or empty) collapse to a fresh empty dict so
        # callers can always iterate/mutate safely.
        self.observation_elements = observation_elements or {}
        self.replay_elements = replay_elements or {}
        self.info = info or {}
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class Agent(ABC):
    """Abstract interface for a trainable, acting policy.

    Concrete agents implement network construction (``build``), learning
    from replay samples (``update``), action selection (``act``) and
    weight (de)serialisation.
    """

    @abstractmethod
    def build(self, training: bool, device=None) -> None:
        """Construct the agent's internals; ``device`` selects where to place them."""
        pass

    @abstractmethod
    def update(self, step: int, replay_sample: dict) -> dict:
        """Run one learning step on ``replay_sample``; returns a dict of stats."""
        pass

    @abstractmethod
    def act(self, step: int, observation: dict, deterministic: bool) -> ActResult:
        # returns dict of values that get put in the replay.
        # One of these must be 'action'.
        pass

    def reset(self) -> None:
        """Called at episode boundaries; the default is a no-op."""
        pass

    @abstractmethod
    def update_summaries(self) -> List[Summary]:
        """Summaries describing the most recent ``update`` call."""
        pass

    @abstractmethod
    def act_summaries(self) -> List[Summary]:
        """Summaries describing the most recent ``act`` call."""
        pass

    @abstractmethod
    def load_weights(self, savedir: str) -> None:
        """Load model parameters from ``savedir``."""
        pass

    @abstractmethod
    def save_weights(self, savedir: str) -> None:
        """Write model parameters to ``savedir``."""
        pass
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class BimanualAgent(Agent):
    """Composite agent driving a two-arm robot with two independent
    single-arm sub-agents.

    Observation routing: keys containing "rgb", "point_cloud" or "camera"
    (shared visual inputs) are given to both arms; keys prefixed
    "right_"/"left_" go to the corresponding arm with the prefix stripped;
    everything else is duplicated to both arms.
    """

    def __init__(self, right_agent: Agent, left_agent: Agent):
        self.right_agent = right_agent
        self.left_agent = left_agent
        # Aggregated scalar stats reported via update_summaries().
        self._summaries = {}

    def build(self, training: bool, device=None) -> None:
        self.right_agent.build(training, device)
        self.left_agent.build(training, device)

    @staticmethod
    def _split_observation(observation: dict):
        """Route a flat observation dict into (right, left) per-arm dicts.

        Extracted from the previously duplicated loops in ``update`` and
        ``act``. Uses ``startswith`` rather than substring containment so a
        key that merely *contains* "right_"/"left_" mid-string is no longer
        sliced incorrectly (the old ``k[6:]`` assumed a prefix).
        """
        right_observation = {}
        left_observation = {}
        for k, v in observation.items():
            if "rgb" in k or "point_cloud" in k or "camera" in k:
                right_observation[k] = v
                left_observation[k] = v
            elif k.startswith("right_"):
                right_observation[k[len("right_"):]] = v
            elif k.startswith("left_"):
                left_observation[k[len("left_"):]] = v
            else:
                right_observation[k] = v
                left_observation[k] = v
        return right_observation, left_observation

    @staticmethod
    def _tag_summaries(summaries: List[Summary], prefix: str) -> List[Summary]:
        """Namespace summary names per arm; images keep their raw name."""
        for summary in summaries:
            if not isinstance(summary, ImageSummary):
                summary.name = f"{prefix}/{summary.name}"
        return summaries

    def update(self, step: int, replay_sample: dict) -> dict:
        """One learning step: split the sample per arm, split the joint
        action along its last dim (right half first), update both agents.

        Returns the accumulated summary dict (the summed ``total_losses``
        of both arms).
        """
        right_observation, left_observation = self._split_observation(replay_sample)

        # The joint action tensor stacks [right | left] along dim 2.
        action = replay_sample["action"]
        right_action, left_action = action.chunk(2, dim=2)
        right_observation["action"] = right_action
        left_observation["action"] = left_action

        right_update_dict = self.right_agent.update(step, right_observation)
        left_update_dict = self.left_agent.update(step, left_observation)

        total_losses = right_update_dict["total_losses"] + left_update_dict["total_losses"]
        self._summaries.update({"total_losses": total_losses})
        return self._summaries

    def act(self, step: int, observation: dict, deterministic: bool) -> ActResult:
        """Act with both arms; the joint action is right components first,
        then left. Observation elements and info dicts are merged (left
        overwrites right on key collisions, as before).
        """
        right_observation, left_observation = self._split_observation(observation)

        right_act_result = self.right_agent.act(step, right_observation, deterministic)
        left_act_result = self.left_agent.act(step, left_observation, deterministic)

        action = (*right_act_result.action, *left_act_result.action)

        observation_elements = {}
        observation_elements.update(right_act_result.observation_elements)
        observation_elements.update(left_act_result.observation_elements)

        info = {}
        info.update(right_act_result.info)
        info.update(left_act_result.info)

        return ActResult(action, observation_elements=observation_elements, info=info)

    def reset(self) -> None:
        self.right_agent.reset()
        self.left_agent.reset()

    def update_summaries(self) -> List[Summary]:
        summaries = [ScalarSummary(f"{k}", v) for k, v in self._summaries.items()]
        right_summaries = self._tag_summaries(
            self.right_agent.update_summaries(), "agent_right")
        left_summaries = self._tag_summaries(
            self.left_agent.update_summaries(), "agent_left")
        return right_summaries + left_summaries + summaries

    def act_summaries(self) -> List[Summary]:
        right_summaries = self._tag_summaries(
            self.right_agent.act_summaries(), "agent_right")
        left_summaries = self._tag_summaries(
            self.left_agent.act_summaries(), "agent_left")
        return right_summaries + left_summaries

    def load_weights(self, savedir: str) -> None:
        self.right_agent.load_weights(savedir)
        self.left_agent.load_weights(savedir)

    def save_weights(self, savedir: str) -> None:
        self.right_agent.save_weights(savedir)
        self.left_agent.save_weights(savedir)
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
class LeaderFollowerAgent(Agent):
    """Two-arm agent where the right ("leader") arm acts first and its
    outputs are appended to the left ("follower") arm's low-dim state
    before the follower is queried.
    """

    def __init__(self, leader_agent: Agent, follower_agent: Agent):
        self.leader_agent = leader_agent
        self.follower_agent = follower_agent
        # Aggregated scalar stats reported via update_summaries().
        self._summaries = {}

    def build(self, training: bool, device=None) -> None:
        self.leader_agent.build(training, device)
        self.follower_agent.build(training, device)

    def update(self, step: int, replay_sample: dict) -> dict:
        """One learning step for both agents on a shared replay sample."""

        leader_observation = {}
        follower_observation = {}

        # Shared visual keys ("rgb"/"point_cloud"/"camera") go to both
        # agents; "right_"/"left_" prefixed keys are routed (prefix
        # stripped) to leader/follower respectively; everything else is
        # duplicated.
        for k, v in replay_sample.items():
            if "rgb" in k or "point_cloud" in k or "camera" in k:
                leader_observation[k] = v
                follower_observation[k] = v
            elif "right_" in k :
                leader_observation[k[6:]] = v
            elif "left_" in k:
                follower_observation[k[5:]] = v
            else:
                leader_observation[k] = v
                follower_observation[k] = v

        # The joint action stacks the [right | left] halves on dim 2.
        action = replay_sample["action"]
        right_action, left_action = action.chunk(2, dim=2)
        leader_observation["action"] = right_action
        follower_observation["action"] = left_action

        leader_update_dict = self.leader_agent.update(step, leader_observation)
        import torch
        # Condition the follower by appending the replayed leader action
        # indices to its low-dim state along the last dim.
        follower_observation['low_dim_state'] = torch.cat([follower_observation['low_dim_state'],
                                                           replay_sample["right_trans_action_indicies"],
                                                           replay_sample["right_rot_grip_action_indicies"],
                                                           replay_sample["right_ignore_collisions"]], dim=-1)

        follower_update_dict = self.follower_agent.update(step, follower_observation)

        total_losses = leader_update_dict["total_losses"] + follower_update_dict["total_losses"]
        self._summaries.update({"total_losses": total_losses})
        return self._summaries

    def act(self, step: int, observation: dict, deterministic: bool) -> ActResult:
        """Act with the leader, then the follower conditioned on the leader's output."""

        observation_elements = {}
        info = {}

        leader_observation = {}
        follower_observation = {}

        # Same routing as update(), expressed with inverted conditions:
        # non-visual prefixed keys go to one agent, everything else to both.
        for k,v in observation.items():
            if "right_" in k and not "rgb" in k and not "point_cloud" in k and not "camera" in k:
                leader_observation[k[6:]] = v
            elif "left_" in k and not "rgb" in k and not "point_cloud" in k and not "camera" in k:
                follower_observation[k[5:]] = v
            else:
                leader_observation[k] = v
                follower_observation[k] = v

        right_act_result = self.leader_agent.act(step, leader_observation, deterministic)

        right_observation_elements = right_act_result.observation_elements

        import torch

        device = follower_observation['low_dim_state'].device
        if "trans_action_indicies" in right_observation_elements:
            # Wrap the leader's numpy outputs as (1, 1, N) tensors on the
            # follower's device; ignore_collisions comes from the last
            # element of the leader's action.
            right_trans_action_indicies = torch.from_numpy(right_observation_elements["trans_action_indicies"]).unsqueeze(0).unsqueeze(0).to(device)
            right_rot_grip_action_indicies = torch.from_numpy(right_observation_elements["rot_grip_action_indicies"]).unsqueeze(0).unsqueeze(0).to(device)
            right_ignore_collisions = torch.from_numpy(right_act_result.action[-1:]).unsqueeze(0).unsqueeze(0).to(device)
        else:
            # NOTE(review): torch.empty leaves these uninitialised, so the
            # follower sees arbitrary conditioning values on this path —
            # confirm whether zeros were intended.
            right_trans_action_indicies = torch.empty((1, 1, 3)).to(device)
            right_rot_grip_action_indicies = torch.empty((1, 1, 4)).to(device)
            right_ignore_collisions = torch.empty((1, 1, 1)).to(device)

        follower_observation['low_dim_state'] = torch.cat([follower_observation['low_dim_state'],
                                                           right_trans_action_indicies,
                                                           right_rot_grip_action_indicies,
                                                           right_ignore_collisions], dim=-1)

        left_act_result = self.follower_agent.act(step, follower_observation, deterministic)

        # Joint action: leader (right) components first, then follower.
        action = (*right_act_result.action, *left_act_result.action)

        observation_elements.update(right_act_result.observation_elements)
        observation_elements.update(left_act_result.observation_elements)

        info.update(right_act_result.info)
        info.update(left_act_result.info)

        return ActResult(action, observation_elements=observation_elements, info=info)

    def reset(self) -> None:
        self.leader_agent.reset()
        self.follower_agent.reset()

    def update_summaries(self) -> List[Summary]:
        """Merge both agents' update summaries, namespacing non-image names."""
        summaries = []
        for k, v in self._summaries.items():
            summaries.append(ScalarSummary(f"{k}", v))

        leader_summaries = self.leader_agent.update_summaries()
        follower_summaries = self.follower_agent.update_summaries()

        for summary in leader_summaries:
            if not isinstance(summary, ImageSummary):
                summary.name = f"agent_leader/{summary.name}"
        for summary in follower_summaries:
            if not isinstance(summary, ImageSummary):
                summary.name = f"agent_follower/{summary.name}"

        return leader_summaries + follower_summaries + summaries

    def act_summaries(self) -> List[Summary]:
        """Merge both agents' act summaries, namespacing non-image names."""
        leader_summaries = self.leader_agent.act_summaries()
        follower_summaries = self.follower_agent.act_summaries()

        for summary in leader_summaries:
            if not isinstance(summary, ImageSummary):
                summary.name = f"agent_leader/{summary.name}"
        for summary in follower_summaries:
            if not isinstance(summary, ImageSummary):
                summary.name = f"agent_follower/{summary.name}"

        return leader_summaries + follower_summaries

    def load_weights(self, savedir: str) -> None:
        self.leader_agent.load_weights(savedir)
        self.follower_agent.load_weights(savedir)

    def save_weights(self, savedir: str) -> None:
        self.leader_agent.save_weights(savedir)
        self.follower_agent.save_weights(savedir)
|
external/yarr/yarr/envs/__init__.py
ADDED
|
File without changes
|
external/yarr/yarr/envs/env.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from typing import Any, List
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
from yarr.utils.observation_type import ObservationElement
|
| 7 |
+
from yarr.utils.transition import Transition
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Env(ABC):
    """Abstract environment interface used by the YARR runners."""

    def __init__(self):
        # Index of the currently active task (relevant in multi-task setups).
        self._active_task_id = 0
        # Whether this instance is used for evaluation rather than training.
        self._eval_env = False

    @property
    def eval(self):
        """True when the env is in evaluation mode."""
        return self._eval_env

    @eval.setter
    def eval(self, eval):
        self._eval_env = eval

    @property
    def active_task_id(self) -> int:
        return self._active_task_id

    @abstractmethod
    def launch(self) -> None:
        """Start the underlying simulator/robot."""
        pass

    def shutdown(self) -> None:
        """Tear down resources; the default is a no-op."""
        pass

    @abstractmethod
    def reset(self) -> dict:
        """Begin a new episode and return the first observation."""
        pass

    @abstractmethod
    def step(self, action: np.ndarray) -> Transition:
        """Apply ``action`` and return the resulting transition."""
        pass

    @property
    @abstractmethod
    def observation_elements(self) -> List[ObservationElement]:
        """Specs (name/shape/dtype) for every entry in the observation dict."""
        pass

    @property
    @abstractmethod
    def action_shape(self) -> tuple:
        pass

    @property
    @abstractmethod
    def env(self) -> Any:
        """The wrapped, backend-specific environment object."""
        pass
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class MultiTaskEnv(Env):
    """Env exposing several tasks behind a single interface."""

    @property
    @abstractmethod
    def num_tasks(self) -> int:
        """Total number of tasks this env can switch between."""
        pass
|
external/yarr/yarr/envs/rlbench_env.py
ADDED
|
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from typing import Type, List
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
try:
|
| 6 |
+
from rlbench import ObservationConfig, Environment, CameraConfig
|
| 7 |
+
except (ModuleNotFoundError, ImportError) as e:
|
| 8 |
+
print("You need to install RLBench: 'https://github.com/stepjam/RLBench'")
|
| 9 |
+
raise e
|
| 10 |
+
from rlbench.action_modes.action_mode import ActionMode
|
| 11 |
+
from rlbench.backend.observation import BimanualObservation, Observation
|
| 12 |
+
from rlbench.backend.task import Task
|
| 13 |
+
from rlbench.backend.task import BimanualTask
|
| 14 |
+
|
| 15 |
+
from helpers.clip.core.clip import tokenize
|
| 16 |
+
|
| 17 |
+
from yarr.envs.env import Env, MultiTaskEnv
|
| 18 |
+
from yarr.utils.observation_type import ObservationElement
|
| 19 |
+
from yarr.utils.transition import Transition
|
| 20 |
+
from yarr.utils.process_str import change_case
|
| 21 |
+
|
| 22 |
+
import logging
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# Observation attributes that are filtered out of the flat obs dict by the
# _extract_obs_* helpers below; the low-dim state is instead obtained via
# Observation.get_low_dim_data().
ROBOT_STATE_KEYS = ['joint_velocities', 'joint_positions', 'joint_forces',
                    'gripper_open', 'gripper_pose',
                    'gripper_joint_positions', 'gripper_touch_forces',
                    'task_low_dim_state', 'misc', 'left', 'right']
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# ..todo:: possibly duplicated code.
|
| 32 |
+
def _extract_obs_bimanual(obs: BimanualObservation, channels_last: bool, observation_config: ObservationConfig):
    """Flatten a bimanual RLBench Observation into a dict of numpy arrays.

    Depending on ``observation_config.robot_name`` ('right', 'left', or
    anything else, i.e. full bimanual), the low-dim state and collision
    flags are stored unprefixed or with 'right_'/'left_' prefixes.
    """
    obs_dict = vars(obs)
    obs_dict = {k: v for k, v in obs_dict.items() if v is not None}

    right_robot_state = obs.get_low_dim_data(obs.right)
    left_robot_state = obs.get_low_dim_data(obs.left)

    # NOTE(review): this filtered dict is discarded — obs_dict is rebuilt
    # from obs.perception_data in both branches of the next `if`, so the
    # ROBOT_STATE_KEYS filtering here is dead code.
    obs_dict = {k: v for k, v in obs_dict.items()
                if k not in ROBOT_STATE_KEYS}

    if not channels_last:
        # Swap channels from last dim to 1st dim
        obs_dict = {k: np.transpose(v, [2, 0, 1]) if v.ndim == 3 else np.expand_dims(v, 0)
                    for k, v in obs.perception_data.items() if v is not None}
    else:
        # Add extra dim to depth data
        obs_dict = {k: v if v.ndim == 3 else np.expand_dims(v, -1)
                    for k, v in obs.perception_data.items() if v is not None}

    # Route proprioceptive state: single-arm configs use unprefixed keys,
    # the full bimanual config uses right_/left_ prefixed ones.
    if observation_config.robot_name == 'right':
        obs_dict['low_dim_state'] = right_robot_state.astype(np.float32)
        obs_dict['ignore_collisions'] = np.array([obs.right.ignore_collisions], dtype=np.float32)
    elif observation_config.robot_name == 'left':
        obs_dict['low_dim_state'] = left_robot_state.astype(np.float32)
        obs_dict['ignore_collisions'] = np.array([obs.left.ignore_collisions], dtype=np.float32)
    else:
        obs_dict['right_low_dim_state'] = right_robot_state.astype(np.float32)
        obs_dict['left_low_dim_state'] = left_robot_state.astype(np.float32)
        obs_dict['right_ignore_collisions'] = np.array([obs.right.ignore_collisions], dtype=np.float32)
        obs_dict['left_ignore_collisions'] = np.array([obs.left.ignore_collisions], dtype=np.float32)

    for (k, v) in [(k, v) for k, v in obs_dict.items() if 'point_cloud' in k]:
        # ..TODO::
        # Point clouds are downcast to float16 here — presumably to save
        # memory; note the unimanual extractor keeps float32. Confirm
        # which precision is intended.
        obs_dict[k] = v.astype(np.float16)

    # Cameras that provide point clouds also expose calibration matrices.
    for camera_name, config in observation_config.camera_configs.items():
        if config.point_cloud:
            obs_dict[f'{camera_name}_camera_extrinsics'] = obs.misc[f'{camera_name}_camera_extrinsics']
            obs_dict[f'{camera_name}_camera_intrinsics'] = obs.misc[f'{camera_name}_camera_intrinsics']
    return obs_dict
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def _extract_obs_unimanual(obs: Observation, channels_last: bool, observation_config):
    """Flatten a single-arm RLBench Observation into a dict of numpy arrays."""
    obs_dict = vars(obs)
    obs_dict = {k: v for k, v in obs_dict.items() if v is not None}
    robot_state = obs.get_low_dim_data()
    # Remove all of the individual state elements
    obs_dict = {k: v for k, v in obs_dict.items()
                if k not in ROBOT_STATE_KEYS}
    if not channels_last:
        # Swap channels from last dim to 1st dim
        obs_dict = {k: np.transpose(
            v, [2, 0, 1]) if v.ndim == 3 else np.expand_dims(v, 0)
            for k, v in obs_dict.items()}
    else:
        # Add extra dim to depth data
        obs_dict = {k: v if v.ndim == 3 else np.expand_dims(v, -1)
                    for k, v in obs_dict.items()}
    obs_dict['low_dim_state'] = np.array(robot_state, dtype=np.float32)
    obs_dict['ignore_collisions'] = np.array([obs.ignore_collisions], dtype=np.float32)
    # Point clouds stay float32 here (the bimanual extractor uses float16).
    for (k, v) in [(k, v) for k, v in obs_dict.items() if 'point_cloud' in k]:
        obs_dict[k] = v.astype(np.float32)

    # Cameras that provide point clouds also expose calibration matrices.
    for config, name in [
            (observation_config.left_shoulder_camera, 'left_shoulder'),
            (observation_config.right_shoulder_camera, 'right_shoulder'),
            (observation_config.front_camera, 'front'),
            (observation_config.wrist_camera, 'wrist'),
            (observation_config.overhead_camera, 'overhead')]:
        if config.point_cloud:
            obs_dict['%s_camera_extrinsics' % name] = obs.misc['%s_camera_extrinsics' % name]
            obs_dict['%s_camera_intrinsics' % name] = obs.misc['%s_camera_intrinsics' % name]
    return obs_dict
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def _get_cam_observation_elements(camera: CameraConfig, prefix: str, channels_last):
    """Build the ObservationElement specs contributed by one camera.

    Emits rgb, point-cloud (plus camera calibration matrices) and depth
    entries, each keyed '<prefix>_<modality>'. Mask rendering is not
    supported.
    """
    elements = []
    img_s = list(camera.image_size)
    # Channel position depends on the layout convention in use.
    rgb_shape = img_s + [3] if channels_last else [3] + img_s

    if camera.rgb:
        elements.append(ObservationElement('%s_rgb' % prefix, rgb_shape, np.uint8))

    if camera.point_cloud:
        elements.append(ObservationElement('%s_point_cloud' % prefix, rgb_shape, np.float32))
        # Cameras with point clouds also report calibration matrices.
        elements.append(ObservationElement('%s_camera_extrinsics' % prefix, (4, 4), np.float32))
        elements.append(ObservationElement('%s_camera_intrinsics' % prefix, (3, 3), np.float32))

    if camera.depth:
        depth_shape = img_s + [1] if channels_last else [1] + img_s
        elements.append(ObservationElement('%s_depth' % prefix, depth_shape, np.float32))

    if camera.mask:
        raise NotImplementedError()

    return elements
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def _observation_elements(observation_config, channels_last) -> List[ObservationElement]:
    """Assemble the full ObservationElement list for the configured robot."""
    elements = []

    # The proprioceptive state length is the sum of the sizes of whichever
    # channels are enabled in the config.
    state_sizes = [
        (observation_config.joint_velocities, 7),
        (observation_config.joint_positions, 7),
        (observation_config.joint_forces, 7),
        (observation_config.gripper_open, 1),
        (observation_config.gripper_pose, 7),
        (observation_config.gripper_joint_positions, 2),
        (observation_config.gripper_touch_forces, 2),
    ]
    robot_state_len = sum(size for enabled, size in state_sizes if enabled)

    if observation_config.task_low_dim_state:
        raise NotImplementedError()

    if robot_state_len > 0:
        if observation_config.robot_name == 'bimanual':
            # Bimanual setups expose one low-dim state per arm.
            elements.append(ObservationElement(
                'right_low_dim_state', (robot_state_len,), np.float32))
            elements.append(ObservationElement(
                'left_low_dim_state', (robot_state_len,), np.float32))
        elif observation_config.robot_name in ['unimanual', 'left', 'right']:
            elements.append(ObservationElement('low_dim_state', (robot_state_len,), np.float32))

    # NOTE(review): the overhead camera is handled by the obs extractors
    # but is not listed here — confirm whether that omission is intended.
    for camera, cam_prefix in [
            (observation_config.left_shoulder_camera, 'left_shoulder'),
            (observation_config.right_shoulder_camera, 'right_shoulder'),
            (observation_config.front_camera, 'front'),
            (observation_config.wrist_camera, 'wrist')]:
        elements.extend(_get_cam_observation_elements(camera, cam_prefix, channels_last))
    return elements
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class RLBenchEnv(Env):
    """Single-task YARR environment backed by an RLBench ``Environment``."""

    def __init__(self, task_class: Type[Task],
                 observation_config: ObservationConfig,
                 action_mode: ActionMode,
                 dataset_root: str = '',
                 channels_last=False,
                 headless=True,
                 include_lang_goal_in_obs=False):
        """Create the wrapper (the simulator is not started until ``launch``).

        Args:
            task_class: RLBench task to load on ``launch``.
            observation_config: which cameras / low-dim fields to record.
            action_mode: RLBench action mode forwarded to the simulator.
            dataset_root: optional path to demonstration data.
            channels_last: image layout flag forwarded to obs extraction.
            headless: run the simulator without a GUI.
            include_lang_goal_in_obs: if True, every extracted observation
                also carries tokenized language-goal ids under
                ``'lang_goal_tokens'``.
        """
        super(RLBenchEnv, self).__init__()
        self._task_class = task_class
        self._observation_config = observation_config
        self._channels_last = channels_last
        self._include_lang_goal_in_obs = include_lang_goal_in_obs
        # Bimanual tasks require the two-arm scene; everything else uses a
        # single panda.
        if issubclass(task_class, BimanualTask):
            robot_setup = "dual_panda"
        else:
            robot_setup = "panda"
        self._rlbench_env = Environment(
            action_mode=action_mode, obs_config=observation_config,
            dataset_root=dataset_root, headless=headless, robot_setup=robot_setup)
        self._task = None
        # Placeholder until the first reset() provides a real description.
        self._lang_goal = 'unknown goal'

    def extract_obs(self, obs: Observation):
        """Convert a raw RLBench observation into a flat dict of arrays."""
        # Dispatch on observation type: bimanual observations are extracted
        # with the per-arm variant.
        if isinstance(obs, BimanualObservation):
            extracted_obs = _extract_obs_bimanual(obs, self._channels_last, self._observation_config)
        else:
            extracted_obs = _extract_obs_unimanual(obs, self._channels_last, self._observation_config)
        if self._include_lang_goal_in_obs:
            extracted_obs['lang_goal_tokens'] = tokenize([self._lang_goal])[0].numpy()
        return extracted_obs

    def launch(self):
        """Start the simulator and load the configured task."""
        self._rlbench_env.launch()
        self._task = self._rlbench_env.get_task(self._task_class)

    def shutdown(self):
        """Tear down the simulator."""
        self._rlbench_env.shutdown()

    def reset(self) -> dict:
        """Reset the task and return the first extracted observation."""
        descriptions, obs = self._task.reset()
        self._lang_goal = descriptions[0]  # first description variant
        extracted_obs = self.extract_obs(obs)
        return extracted_obs

    def step(self, action: np.ndarray) -> Transition:
        """Apply ``action`` and return the resulting (obs, reward, terminal)."""
        obs, reward, terminal = self._task.step(action)
        obs = self.extract_obs(obs)
        return Transition(obs, reward, terminal)

    @property
    def observation_elements(self) -> List[ObservationElement]:
        # Shape/dtype descriptors for what extract_obs produces.
        return _observation_elements(self._observation_config, self._channels_last)

    @property
    def action_shape(self):
        return (self._rlbench_env.action_size, )

    @property
    def env(self) -> Environment:
        # Direct access to the underlying RLBench environment.
        return self._rlbench_env
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class MultiTaskRLBenchEnv(MultiTaskEnv):
    """Multi-task YARR environment that cycles/samples RLBench tasks.

    A single RLBench ``Environment`` is shared across tasks; the active task
    is swapped every ``swap_task_every`` episodes (round-robin) or on demand
    via ``set_task``.
    """

    def __init__(self,
                 task_classes: List[Type[Task]],
                 observation_config: ObservationConfig,
                 action_mode: ActionMode,
                 dataset_root: str = '',
                 channels_last=False,
                 headless=True,
                 swap_task_every: int = 1,
                 include_lang_goal_in_obs=False):
        """Create the wrapper (the simulator is not started until ``launch``).

        Args:
            task_classes: the pool of RLBench tasks to rotate through.
            observation_config: which cameras / low-dim fields to record.
            action_mode: RLBench action mode forwarded to the simulator.
            dataset_root: optional path to demonstration data.
            channels_last: image layout flag forwarded to obs extraction.
            headless: run the simulator without a GUI.
            swap_task_every: number of episodes before moving to the next task.
            include_lang_goal_in_obs: if True, every extracted observation
                carries tokenized language-goal ids under 'lang_goal_tokens'.
        """
        super(MultiTaskRLBenchEnv, self).__init__()
        self._task_classes = task_classes
        self._observation_config = observation_config
        self._channels_last = channels_last
        self._include_lang_goal_in_obs = include_lang_goal_in_obs
        # The robot scene is chosen from the FIRST task only; assumes all
        # tasks in the pool share the same (bi/uni)manual setup.
        if issubclass(task_classes[0], BimanualTask):
            robot_setup = "dual_panda"
        else:
            robot_setup = "panda"
        self._rlbench_env = Environment(
            action_mode=action_mode, obs_config=observation_config,
            dataset_root=dataset_root, headless=headless, robot_setup=robot_setup)
        self._task = None
        self._task_name = ''
        # Placeholder until the first reset() provides a real description.
        self._lang_goal = 'unknown goal'
        self._swap_task_every = swap_task_every
        # (fix: removed a stray no-op statement `self._rlbench_env` here)
        self._episodes_this_task = 0
        self._active_task_id = -1

        self._task_name_to_idx = {change_case(tc.__name__): i for i, tc in enumerate(self._task_classes)}

    def _set_new_task(self, shuffle=False):
        """Activate the next task (round-robin) or a uniformly random one."""
        if shuffle:
            self._active_task_id = np.random.randint(0, len(self._task_classes))
        else:
            self._active_task_id = (self._active_task_id + 1) % len(self._task_classes)
        task = self._task_classes[self._active_task_id]
        self._task = self._rlbench_env.get_task(task)

    def set_task(self, task_name: str):
        """Activate a task by its snake_case name and cache its language goal."""
        self._active_task_id = self._task_name_to_idx[task_name]
        task = self._task_classes[self._active_task_id]
        self._task = self._rlbench_env.get_task(task)

        descriptions, _ = self._task.reset()
        self._lang_goal = descriptions[0]  # first description variant

    def extract_obs(self, obs: Observation):
        """Convert a raw RLBench observation into a flat dict of arrays."""
        # NOTE(review): dispatches on obs.is_bimanual here, while RLBenchEnv
        # uses isinstance(obs, BimanualObservation) — presumably equivalent.
        if obs.is_bimanual:
            extracted_obs = _extract_obs_bimanual(obs, self._channels_last, self._observation_config)
        else:
            extracted_obs = _extract_obs_unimanual(obs, self._channels_last, self._observation_config)
        if self._include_lang_goal_in_obs:
            extracted_obs['lang_goal_tokens'] = tokenize([self._lang_goal])[0].numpy()
        return extracted_obs

    def launch(self):
        """Start the simulator and activate the first task."""
        self._rlbench_env.launch()
        self._set_new_task()

    def shutdown(self):
        """Tear down the simulator."""
        self._rlbench_env.shutdown()

    def reset(self) -> dict:
        """Reset the active task (swapping first if due) and return the obs."""
        if self._episodes_this_task == self._swap_task_every:
            self._set_new_task()
            self._episodes_this_task = 0
        self._episodes_this_task += 1

        descriptions, obs = self._task.reset()
        self._lang_goal = descriptions[0]  # first description variant
        extracted_obs = self.extract_obs(obs)

        return extracted_obs

    def step(self, action: np.ndarray) -> Transition:
        """Apply ``action`` and return the resulting (obs, reward, terminal)."""
        obs, reward, terminal = self._task.step(action)
        obs = self.extract_obs(obs)
        return Transition(obs, reward, terminal)

    @property
    def observation_elements(self) -> List[ObservationElement]:
        # Shape/dtype descriptors for what extract_obs produces.
        return _observation_elements(self._observation_config, self._channels_last)

    @property
    def action_shape(self):
        return (self._rlbench_env.action_size, )

    @property
    def env(self) -> Environment:
        # Direct access to the underlying RLBench environment.
        return self._rlbench_env

    @property
    def num_tasks(self) -> int:
        """Number of tasks in the rotation pool."""
        return len(self._task_classes)
|
external/yarr/yarr/replay_buffer/__init__.py
ADDED
|
File without changes
|
external/yarr/yarr/replay_buffer/prioritized_replay_buffer.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""An implementation of Prioritized Experience Replay (PER).
|
| 2 |
+
|
| 3 |
+
This implementation is based on the paper "Prioritized Experience Replay"
|
| 4 |
+
by Tom Schaul et al. (2015).
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import absolute_import
|
| 7 |
+
from __future__ import division
|
| 8 |
+
from __future__ import print_function
|
| 9 |
+
|
| 10 |
+
from .uniform_replay_buffer import *
|
| 11 |
+
from .sum_tree import *
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
PRIORITY = 'priority'
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class PrioritizedReplayBuffer(UniformReplayBuffer):
    """An out-of-graph Replay Buffer for Prioritized Experience Replay.

    See uniform_replay_buffer.py for details.
    """

    def __init__(self, *args, **kwargs):
        """Initializes OutOfGraphPrioritizedReplayBuffer."""
        super(PrioritizedReplayBuffer, self).__init__(*args, **kwargs)
        # Sum tree holds one priority per replay slot; drives sampling.
        self._sum_tree = SumTree(self._replay_capacity)

    def get_storage_signature(self) -> Tuple[List[ReplayElement],
                                             List[ReplayElement]]:
        """Returns a default list of elements to be stored in this replay memory.

        Note - Derived classes may return a different signature.

        Returns:
          dict of ReplayElements defining the type of the contents stored.
        """
        storage_elements, obs_elements = super(
            PrioritizedReplayBuffer, self).get_storage_signature()
        # Extend the parent signature with a scalar priority per transition.
        storage_elements.append(ReplayElement(PRIORITY, (), np.float32),)

        return storage_elements, obs_elements

    def add(self, action, reward, terminal, timeout, priority=None, **kwargs):
        # priority=None means "use max recorded priority" (resolved in _add).
        kwargs['priority'] = priority
        super(PrioritizedReplayBuffer, self).add(
            action, reward, terminal, timeout, **kwargs)

    def _add(self, kwargs: dict):
        """Internal add method to add to the storage arrays.

        Args:
          kwargs: All the elements in a transition.
        """
        with self._lock:
            cursor = self.cursor()
            priority = kwargs[PRIORITY]
            if priority is None:
                # New transitions default to the largest priority seen so far,
                # so they are sampled at least once.
                priority = self._sum_tree.max_recorded_priority

            if self._disk_saving:
                # Terminal flags stay in memory even in disk mode (read back
                # through the store, mutated, and written back).
                term = self._store[TERMINAL]
                term[cursor] = kwargs[TERMINAL]
                self._store[TERMINAL] = term

                with open(join(self._save_dir, '%d.replay' % cursor), 'wb') as f:
                    pickle.dump(kwargs, f)
                # If first add, then pad for correct wrapping
                if self._add_count.value == 0:
                    self._add_initial_to_disk(kwargs)
            else:
                for name, data in kwargs.items():
                    item = self._store[name]
                    item[cursor] = data
                    self._store[name] = item

            self._sum_tree.set(self.cursor(), priority)
            self._add_count.value += 1
            self.invalid_range = invalid_range(
                self.cursor(), self._replay_capacity, self._timesteps,
                self._update_horizon)

    def add_final(self, **kwargs):
        """Adds a transition to the replay memory.
        Args:
          **kwargs: The remaining args
        """
        # if self.is_empty() or self._store['terminal'][self.cursor() - 1] != 1:
        #     raise ValueError('The previous transition was not terminal.')
        self._check_add_types(kwargs, self._obs_signature)
        transition = self._final_transition(kwargs)
        for element_type in self._storage_signature:
            # 0 priority for final observation.
            if element_type.name == PRIORITY:
                transition[element_type.name] = 0.0
        self._add(transition)

    def sample_index_batch(self, batch_size):
        """Returns a batch of valid indices sampled as in Schaul et al. (2015).

        Args:
          batch_size: int, number of indices returned.

        Returns:
          list of ints, a batch of valid indices sampled uniformly.

        Raises:
          Exception: If the batch was not constructed after maximum number of tries.
        """
        # Sample stratified indices. Some of them might be invalid.
        indices = self._sum_tree.stratified_sample(batch_size)
        allowed_attempts = self._max_sample_attempts
        for i in range(len(indices)):
            if not self.is_valid_transition(indices[i]):
                if allowed_attempts == 0:
                    raise RuntimeError(
                        'Max sample attempts: Tried {} times but only sampled {}'
                        ' valid indices. Batch size is {}'.
                        format(self._max_sample_attempts, i, batch_size))
                index = indices[i]
                while not self.is_valid_transition(
                        index) and allowed_attempts > 0:
                    # If index i is not valid keep sampling others. Note that this
                    # is not stratified.
                    index = self._sum_tree.sample()
                    allowed_attempts -= 1
                indices[i] = index
        return indices

    def sample_transition_batch(self, batch_size=None, indices=None,
                                pack_in_dict=True):
        """Returns a batch of transitions with extra storage and the priorities.

        The extra storage are defined through the extra_storage_types constructor
        argument.

        When the transition is terminal next_state_batch has undefined contents.

        Args:
          batch_size: int, number of transitions returned. If None, the default
            batch_size will be used.
          indices: None or list of ints, the indices of every transition in the
            batch. If None, sample the indices uniformly.

        Returns:
          transition_batch: tuple of np.arrays with the shape and type as in
            get_transition_elements().
        """
        transition = super(
            PrioritizedReplayBuffer, self).sample_transition_batch(
                batch_size, indices, pack_in_dict=False)

        transition_elements = self.get_transition_elements(batch_size)
        transition_names = [e.name for e in transition_elements]
        probabilities_index = transition_names.index('sampling_probabilities')
        indices_index = transition_names.index('indices')
        indices = transition[indices_index]
        # The parent returned an empty array for the probabilities. Fill it with the
        # contents of the sum tree.
        transition[probabilities_index][:] = self.get_priority(indices)
        batch_arrays = transition
        if pack_in_dict:
            batch_arrays = self.unpack_transition(transition,
                                                  transition_elements)
        return batch_arrays

    def set_priority(self, indices, priorities):
        """Sets the priority of the given elements according to Schaul et al.

        Args:
          indices: np.array with dtype int32, of indices in range
            [0, replay_capacity).
          priorities: float, the corresponding priorities.
        """
        assert indices.dtype == np.int32, ('Indices must be integers, '
                                           'given: {}'.format(indices.dtype))
        for index, priority in zip(indices, priorities):
            self._sum_tree.set(index, priority)

    def get_priority(self, indices):
        """Fetches the priorities correspond to a batch of memory indices.

        For any memory location not yet used, the corresponding priority is 0.

        Args:
          indices: np.array with dtype int32, of indices in range
            [0, replay_capacity).

        Returns:
          priorities: float, the corresponding priorities.
        """
        assert indices.shape, 'Indices must be an array.'
        assert indices.dtype == np.int32, ('Indices must be int32s, '
                                           'given: {}'.format(indices.dtype))
        batch_size = len(indices)
        priority_batch = np.empty((batch_size), dtype=np.float32)
        for i, memory_index in enumerate(indices):
            priority_batch[i] = self._sum_tree.get(memory_index)
        return priority_batch

    def get_transition_elements(self, batch_size=None):
        """Returns a 'type signature' for sample_transition_batch.

        Args:
          batch_size: int, number of transitions returned. If None, the default
            batch_size will be used.
        Returns:
          signature: A namedtuple describing the method's return type signature.
        """
        parent_transition_type = (
            super(PrioritizedReplayBuffer,
                  self).get_transition_elements(batch_size))
        probablilities_type = [
            ReplayElement('sampling_probabilities', (batch_size,), np.float32)
        ]
        return parent_transition_type + probablilities_type
|
external/yarr/yarr/replay_buffer/replay_buffer.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC
|
| 2 |
+
from typing import Tuple, List
|
| 3 |
+
|
| 4 |
+
class ReplayElement(object):
    """Lightweight descriptor for one named field stored in a replay buffer."""

    def __init__(self, name, shape, type, is_observation=False):
        # NOTE: the parameter name `type` shadows the builtin, but is kept
        # unchanged for backward compatibility with keyword callers.
        self.is_observation = is_observation
        self.type = type
        self.shape = shape
        self.name = name
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ReplayBuffer(ABC):
    """Abstract interface shared by the replay-buffer implementations.

    All methods are no-op stubs (note: not marked ``@abstractmethod``), so
    subclasses override only what they need.
    """

    def replay_capacity(self):
        """Maximum number of transitions the buffer can hold."""
        pass

    def batch_size(self):
        """Default number of transitions returned per sample."""
        pass

    def get_storage_signature(self) -> Tuple[List[ReplayElement],
                                             List[ReplayElement]]:
        """Describe the stored elements (storage elements, observation elements)."""
        pass

    def add(self, action, reward, terminal, timeout, **kwargs):
        """Add one transition; extra named elements go in **kwargs."""
        pass

    def add_final(self, **kwargs):
        """Add the final observation of an episode."""
        pass

    def is_empty(self):
        """True when no transitions have been added."""
        pass

    def is_full(self):
        """True when the buffer holds replay_capacity transitions."""
        pass

    def cursor(self):
        """Index of the next write position."""
        pass

    def set_cursor(self):
        """Reposition the write cursor."""
        pass

    def get_range(self, array, start_index, end_index):
        """Slice `array` over [start_index, end_index), wrapping if needed."""
        pass

    def get_range_stack(self, array, start_index, end_index, terminals=None):
        """Like get_range, stacked for timestep history."""
        pass

    def get_terminal_stack(self, index):
        """Terminal flags for the timestep stack ending at `index`."""
        pass

    def is_valid_transition(self, index):
        """True if `index` refers to a sampleable transition."""
        pass

    def sample_index_batch(self, batch_size):
        """Sample `batch_size` valid transition indices."""
        pass

    def unpack_transition(self, transition_tensors, transition_type):
        """Convert a tuple of batch arrays into a name->array dict."""
        pass

    def sample_transition_batch(self, batch_size=None, indices=None,
                                pack_in_dict=True):
        """Sample a batch of transitions (optionally at given indices)."""
        pass

    def get_transition_elements(self, batch_size=None):
        """Describe the elements returned by sample_transition_batch."""
        pass

    def shutdown(self):
        """Release any resources held by the buffer."""
        pass

    def using_disk(self):
        """True if transitions are persisted to disk rather than RAM."""
        pass
|
external/yarr/yarr/replay_buffer/sum_tree.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""A sum tree data structure.
|
| 2 |
+
|
| 3 |
+
Used for prioritized experience replay. See prioritized_replay_buffer.py
|
| 4 |
+
and Schaul et al. (2015).
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import absolute_import
|
| 7 |
+
from __future__ import division
|
| 8 |
+
from __future__ import print_function
|
| 9 |
+
|
| 10 |
+
import math
|
| 11 |
+
import random
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class SumTree(object):
    """A sum tree data structure for storing replay priorities.

    A sum tree is a complete binary tree whose leaves contain values called
    priorities. Internal nodes maintain the sum of the priorities of all leaf
    nodes in their subtree.

    For capacity = 4, the tree may look like this:

                 +---+
                 |2.5|
                 +-+-+
                   |
           +-------+--------+
           |                |
         +-+-+            +-+-+
         |1.5|            |1.0|
         +-+-+            +-+-+
           |                |
      +----+----+      +----+----+
      |         |      |         |
    +-+-+     +-+-+  +-+-+     +-+-+
    |0.5|     |1.0|  |0.5|     |0.5|
    +---+     +---+  +---+     +---+

    This is stored in a list of numpy arrays:
    self.nodes = [ [2.5], [1.5, 1], [0.5, 1, 0.5, 0.5] ]

    For conciseness, we allocate arrays as powers of two, and pad the excess
    elements with zero values.

    This is similar to the usual array-based representation of a complete binary
    tree, but is a little more user-friendly.
    """

    def __init__(self, capacity, nodes=None):
        """Creates the sum tree data structure for the given replay capacity.

        Args:
          capacity: int, the maximum number of elements that can be stored in
            this data structure.
          nodes: storage list for storing nodes. NOTE(review): tree levels are
            appended even when a non-empty list is passed — callers are
            presumably expected to supply a fresh (possibly shared/proxied)
            list; confirm before reusing a populated one.

        Raises:
          ValueError: If requested capacity is not positive.
        """
        assert isinstance(capacity, int)
        if capacity <= 0:
            raise ValueError('Sum tree capacity should be positive. Got: {}'.
                             format(capacity))

        self.nodes = [] if nodes is None else nodes
        tree_depth = int(math.ceil(np.log2(capacity)))
        # One zero-initialised array per level, root (size 1) first.
        level_size = 1
        for _ in range(tree_depth + 1):
            self.nodes.append(np.zeros(level_size))
            level_size *= 2

        # Largest priority ever written; used as the default for new items.
        self.max_recorded_priority = 1.0

    def _total_priority(self):
        """Returns the sum of all priorities stored in this sum tree.

        Returns:
          float, sum of priorities stored in this sum tree.
        """
        return self.nodes[0][0]

    def sample(self, query_value=None):
        """Samples an element from the sum tree.

        Each element has probability p_i / sum_j p_j of being picked, where p_i
        is the (positive) value associated with node i (possibly unnormalized).

        Args:
          query_value: float in [0, 1], used as the random value to select a
            sample. If None, will select one randomly in [0, 1).

        Returns:
          int, a random element from the sum tree.

        Raises:
          Exception: If the sum tree is empty (i.e. its node values sum to 0),
            or if the supplied query_value is larger than the total sum.
        """
        if self._total_priority() == 0.0:
            raise Exception('Cannot sample from an empty sum tree.')

        # Explicit None check (the original used truthiness, which silently
        # skipped validation for query_value == 0.0; 0.0 is valid either way).
        if query_value is not None and (query_value < 0. or query_value > 1.):
            raise ValueError('query_value must be in [0, 1].')

        # Sample a value in range [0, R), where R is the value stored at the root.
        query_value = random.random() if query_value is None else query_value
        query_value *= self._total_priority()

        # Now traverse the sum tree.
        node_index = 0
        for nodes_at_this_depth in self.nodes[1:]:
            # Compute children of previous depth's node.
            left_child = node_index * 2

            left_sum = nodes_at_this_depth[left_child]
            # Each subtree describes a range [0, a), where a is its value.
            if query_value < left_sum:  # Recurse into left subtree.
                node_index = left_child
            else:  # Recurse into right subtree.
                node_index = left_child + 1
                # Adjust query to be relative to right subtree.
                query_value -= left_sum

        return node_index

    def stratified_sample(self, batch_size):
        """Performs stratified sampling using the sum tree.

        Let R be the value at the root (total value of sum tree). This method
        will divide [0, R) into batch_size segments, pick a random number from
        each of those segments, and use that random number to sample from the
        sum_tree. This is as specified in Schaul et al. (2015).

        Args:
          batch_size: int, the number of strata to use.
        Returns:
          list of batch_size elements sampled from the sum tree.

        Raises:
          Exception: If the sum tree is empty (i.e. its node values sum to 0).
        """
        if self._total_priority() == 0.0:
            raise Exception('Cannot sample from an empty sum tree.')

        bounds = np.linspace(0., 1., batch_size + 1)
        assert len(bounds) == batch_size + 1
        segments = [(bounds[i], bounds[i + 1]) for i in range(batch_size)]
        # TODO: true stratification is disabled for now — each query is drawn
        # uniformly from [0, 1) instead of from its own segment:
        # query_values = [random.uniform(x[0], x[1]) for x in segments]
        query_values = [random.uniform(0, 1) for x in segments]
        return [self.sample(query_value=x) for x in query_values]

    def get(self, node_index):
        """Returns the value of the leaf node corresponding to the index.

        Args:
          node_index: The index of the leaf node.
        Returns:
          The value of the leaf node.
        """
        return self.nodes[-1][node_index]

    def set(self, node_index, value):
        """Sets the value of a leaf node and updates internal nodes accordingly.

        This operation takes O(log(capacity)).
        Args:
          node_index: int, the index of the leaf node to be updated.
          value: float, the value which we assign to the node. This value must
            be nonnegative. Setting value = 0 will cause the element to never
            be sampled.

        Raises:
          ValueError: If the given value is negative.
        """
        if value < 0.0:
            raise ValueError('Sum tree values should be nonnegative. Got {}'.
                             format(value))
        self.max_recorded_priority = max(value, self.max_recorded_priority)

        delta_value = value - self.nodes[-1][node_index]

        # Walk from the leaf back to the root, adding the delta at each level.
        # Note: adding a delta leads to some tolerable numerical inaccuracies.
        for depth_idx in reversed(range(len(self.nodes))):
            level = self.nodes[depth_idx]
            level[node_index] += delta_value
            # Write back through the container: presumably `self.nodes` may be
            # a proxy (e.g. a multiprocessing manager list), where indexing
            # returns a copy and this assignment propagates the update —
            # TODO confirm; for a plain list it is a harmless re-bind.
            self.nodes[depth_idx] = level
            node_index //= 2

        assert node_index == 0, ('Sum tree traversal failed, final node index '
                                 'is not 0.')
|
external/yarr/yarr/replay_buffer/task_uniform_replay_buffer.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import os
|
| 3 |
+
from os.path import join
|
| 4 |
+
import pickle
|
| 5 |
+
import math
|
| 6 |
+
from yarr.replay_buffer.uniform_replay_buffer import UniformReplayBuffer
|
| 7 |
+
from yarr.replay_buffer.uniform_replay_buffer import invalid_range
|
| 8 |
+
|
| 9 |
+
from yarr.replay_buffer.replay_buffer import ReplayBuffer, ReplayElement
|
| 10 |
+
from yarr.utils.observation_type import ObservationElement
|
| 11 |
+
|
| 12 |
+
ACTION = 'action'
|
| 13 |
+
REWARD = 'reward'
|
| 14 |
+
TERMINAL = 'terminal'
|
| 15 |
+
TIMEOUT = 'timeout'
|
| 16 |
+
INDICES = 'indices'
|
| 17 |
+
TASK = 'task'
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class TaskUniformReplayBuffer(UniformReplayBuffer):
    """A uniform replay buffer that additionally samples tasks uniformly.

    Transitions are stored exactly as in UniformReplayBuffer, but each
    sampled batch draws its indices so that tasks are represented
    (approximately) uniformly rather than in proportion to how many
    transitions each task has contributed.
    """

    def __init__(self, *args, **kwargs):
        """Initializes TaskUniformReplayBuffer.

        Accepts the same arguments as UniformReplayBuffer.
        """
        super(TaskUniformReplayBuffer, self).__init__(*args, **kwargs)
        # Maps task name -> list of storage indices written for that task.
        self._task_idxs = dict()

    def _add(self, kwargs: dict):
        """Internal add method to add to the storage arrays.

        Args:
            kwargs: All the elements in a transition.
        """
        with self._lock:
            cursor = self.cursor()

            if self._disk_saving:
                # Terminals are kept in RAM even when saving to disk.
                term = self._store[TERMINAL]
                term[cursor] = kwargs[TERMINAL]
                self._store[TERMINAL] = term

                # Reduce disk footprint: store large float arrays as float16.
                # Non-array values have no .dtype/.size, so we skip them
                # explicitly instead of using a bare `except` (which would
                # also swallow KeyboardInterrupt / SystemExit).
                for k, v in kwargs.items():
                    try:
                        if 'float' in v.dtype.name and v.size > 100:
                            kwargs[k] = v.astype(np.float16)
                    except AttributeError:
                        pass

                with open(join(self._save_dir, '%d.replay' % cursor), 'wb') as f:
                    pickle.dump(kwargs, f)
                # If first add, then pad for correct wrapping
                if self._add_count.value == 0:
                    self._add_initial_to_disk(kwargs)
            else:
                for name, data in kwargs.items():
                    item = self._store[name]
                    item[cursor] = data
                    self._store[name] = item

            with self._add_count.get_lock():
                task = kwargs[TASK]
                # Append in place; the original rebuilt the whole list on
                # every add (`l = l + [cursor]`), which is O(n) per add.
                self._task_idxs.setdefault(task, []).append(cursor)
                self._add_count.value += 1

            self.invalid_range = invalid_range(
                self.cursor(), self._replay_capacity, self._timesteps,
                self._update_horizon)

    def sample_index_batch(self, batch_size):
        """Returns a batch of valid indices sampled uniformly across tasks.

        Args:
          batch_size: int, number of indices returned.

        Returns:
          list of ints, a batch of valid indices sampled uniformly across tasks.

        Raises:
          RuntimeError: If the batch was not constructed after maximum number
            of tries.
        """
        if self.is_full():
            min_id = (self.cursor() - self._replay_capacity +
                      self._timesteps - 1)
            max_id = self.cursor() - self._update_horizon
        else:
            min_id = 0
            max_id = self.cursor() - self._update_horizon
            if max_id <= min_id:
                raise RuntimeError(
                    'Cannot sample a batch with fewer than stack size '
                    '({}) + update_horizon ({}) transitions.'.
                    format(self._timesteps, self._update_horizon))

        tasks = list(self._task_idxs.keys())
        attempt_count = 0
        found_indices = False
        indices = []

        # Resample whole batches until every slot holds a valid transition
        # (uniform distribution over tasks).
        while not found_indices and attempt_count < 1000:
            # Sample random tasks of batch_size length.
            sampled_tasks = list(np.random.choice(
                tasks, batch_size, replace=(batch_size > len(tasks))))
            potential_indices = []
            for task in sampled_tasks:
                # DDP setting where each GPU only sees a fraction of the data
                # reference: https://github.com/pytorch/pytorch/blob/master/torch/utils/data/distributed.py
                task_data_size = len(self._task_idxs[task])
                num_samples = math.ceil(task_data_size / self._num_replicas)
                total_size = num_samples * self._num_replicas
                task_indices = self._task_idxs[task][self._rank:total_size:self._num_replicas]

                sampled_task_idx = np.random.choice(task_indices, 1)[0]
                per_task_attempt_count = 0

                # Rejection-sample until the index is valid for this task.
                # (Slow, but matches the original behavior.)
                while not self.is_valid_transition(sampled_task_idx) and \
                        per_task_attempt_count < self._max_sample_attempts:
                    sampled_task_idx = np.random.choice(task_indices, 1)[0]
                    per_task_attempt_count += 1

                if not self.is_valid_transition(sampled_task_idx):
                    attempt_count += 1
                    continue
                else:
                    potential_indices.append(sampled_task_idx)
            found_indices = len(potential_indices) == batch_size
            indices = potential_indices

        if len(indices) != batch_size:
            raise RuntimeError(
                'Max sample attempts: Tried {} times but only sampled {}'
                ' valid indices. Batch size is {}'.
                format(self._max_sample_attempts, len(indices), batch_size))

        return indices

    def get_transition_elements(self, batch_size=None):
        """Returns a 'type signature' for sample_transition_batch.

        Args:
          batch_size: int, number of transitions returned. If None, the
            default batch_size will be used.
        Returns:
          signature: A namedtuple describing the method's return type signature.
        """
        batch_size = self._batch_size if batch_size is None else batch_size

        transition_elements = [
            ReplayElement(ACTION, (batch_size, self._timesteps) + self._action_shape,
                          self._action_dtype),
            ReplayElement(REWARD, (batch_size, self._timesteps) + self._reward_shape,
                          self._reward_dtype),
            ReplayElement(TERMINAL, (batch_size, self._timesteps), np.int8),
            ReplayElement(TIMEOUT, (batch_size, self._timesteps), bool),
            ReplayElement(INDICES, (batch_size, self._timesteps), np.int32),
        ]

        # Each observation element is returned twice: at time t and t+1.
        for element in self._observation_elements:
            transition_elements.append(ReplayElement(
                element.name,
                (batch_size, self._timesteps) + tuple(element.shape),
                element.type, True))
            transition_elements.append(ReplayElement(
                element.name + '_tp1',
                (batch_size, self._timesteps) + tuple(element.shape),
                element.type, True))

        # Extra replay elements are not time-stacked.
        for element in self._extra_replay_elements:
            transition_elements.append(ReplayElement(
                element.name,
                (batch_size,) + tuple(element.shape),
                element.type))
        return transition_elements
|
external/yarr/yarr/replay_buffer/uniform_replay_buffer.py
ADDED
|
@@ -0,0 +1,804 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""The standard DQN replay memory.
|
| 2 |
+
|
| 3 |
+
This implementation is an out-of-graph replay memory + in-graph wrapper. It
|
| 4 |
+
supports vanilla n-step updates of the form typically found in the literature,
|
| 5 |
+
i.e. where rewards are accumulated for n steps and the intermediate trajectory
|
| 6 |
+
is not exposed to the agent. This does not allow, for example, performing
|
| 7 |
+
off-policy corrections.
|
| 8 |
+
"""
|
| 9 |
+
import ctypes
|
| 10 |
+
import collections
|
| 11 |
+
import concurrent.futures
|
| 12 |
+
import os
|
| 13 |
+
from os.path import join
|
| 14 |
+
import pickle
|
| 15 |
+
from typing import List, Tuple, Type
|
| 16 |
+
import time
|
| 17 |
+
import math
|
| 18 |
+
# from threading import Lock
|
| 19 |
+
import multiprocessing as mp
|
| 20 |
+
from multiprocessing import Lock
|
| 21 |
+
import numpy as np
|
| 22 |
+
import logging
|
| 23 |
+
|
| 24 |
+
from natsort import natsort
|
| 25 |
+
|
| 26 |
+
from yarr.replay_buffer.replay_buffer import ReplayBuffer, ReplayElement
|
| 27 |
+
from yarr.utils.observation_type import ObservationElement
|
| 28 |
+
|
| 29 |
+
import torch.distributed as dist
|
| 30 |
+
|
| 31 |
+
# Defines a type describing part of the tuple returned by the replay
|
| 32 |
+
# memory. Each element of the tuple is a tensor of shape [batch, ...] where
|
| 33 |
+
# ... is defined the 'shape' field of ReplayElement. The tensor type is
|
| 34 |
+
# given by the 'type' field. The 'name' field is for convenience and ease of
|
| 35 |
+
# debugging.
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# String constants for storage
|
| 39 |
+
ACTION = 'action'
|
| 40 |
+
REWARD = 'reward'
|
| 41 |
+
TERMINAL = 'terminal'
|
| 42 |
+
TIMEOUT = 'timeout'
|
| 43 |
+
INDICES = 'indices'
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def invalid_range(cursor, replay_capacity, stack_size, update_horizon):
    """Return the indices of transitions made invalid by the cursor position.

    There are ``update_horizon + stack_size`` invalid indices:
      - the ``update_horizon`` indices before the cursor, because we do not
        have a valid N-step transition (including the next state) for them;
      - the ``stack_size`` indices on or immediately after the cursor.

    With N = update_horizon, K = stack_size and the cursor at c, the invalid
    indices are c - N, c - N + 1, ..., c, c + 1, ..., c + K - 1, all taken
    modulo the capacity (so wraparound at either end is handled).

    Args:
      cursor: int, the position of the cursor.
      replay_capacity: int, the size of the replay memory.
      stack_size: int, the size of the stacks returned by the replay memory.
      update_horizon: int, the agent's update horizon.
    Returns:
      np.array of size stack_size + update_horizon with the invalid indices.
    """
    assert cursor < replay_capacity
    first_invalid = cursor - update_horizon
    span = stack_size + update_horizon
    return np.array(
        [(first_invalid + offset) % replay_capacity for offset in range(span)])
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class UniformReplayBuffer(ReplayBuffer):
|
| 74 |
+
"""A simple out-of-graph Replay Buffer.
|
| 75 |
+
|
| 76 |
+
Stores transitions, state, action, reward, next_state, terminal (and any
|
| 77 |
+
extra contents specified) in a circular buffer and provides a uniform
|
| 78 |
+
transition sampling function.
|
| 79 |
+
|
| 80 |
+
When the states consist of stacks of observations storing the states is
|
| 81 |
+
inefficient. This class writes observations and constructs the stacked states
|
| 82 |
+
at sample time.
|
| 83 |
+
|
| 84 |
+
Attributes:
|
| 85 |
+
_add_count: int, counter of how many transitions have been added (including
|
| 86 |
+
the blank ones at the beginning of an episode).
|
| 87 |
+
invalid_range: np.array, an array with the indices of cursor-related invalid
|
| 88 |
+
transitions
|
| 89 |
+
"""
|
| 90 |
+
|
| 91 |
+
def __init__(self,
|
| 92 |
+
batch_size: int = 32,
|
| 93 |
+
timesteps: int = 1,
|
| 94 |
+
replay_capacity: int = int(1e6),
|
| 95 |
+
update_horizon: int = 1,
|
| 96 |
+
gamma: float = 0.99,
|
| 97 |
+
max_sample_attempts: int = 10000,
|
| 98 |
+
action_shape: tuple = (),
|
| 99 |
+
action_dtype: Type[np.dtype] = np.float32,
|
| 100 |
+
reward_shape: tuple = (),
|
| 101 |
+
reward_dtype: Type[np.dtype] = np.float32,
|
| 102 |
+
observation_elements: List[ObservationElement] = None,
|
| 103 |
+
extra_replay_elements: List[ReplayElement] = None,
|
| 104 |
+
save_dir: str = None,
|
| 105 |
+
purge_replay_on_shutdown: bool = True,
|
| 106 |
+
num_replicas: int = None,
|
| 107 |
+
rank: int = None,
|
| 108 |
+
):
|
| 109 |
+
"""Initializes OutOfGraphReplayBuffer.
|
| 110 |
+
|
| 111 |
+
Args:
|
| 112 |
+
batch_size: int.
|
| 113 |
+
timesteps: int, number of frames to use in state stack.
|
| 114 |
+
replay_capacity: int, number of transitions to keep in memory.
|
| 115 |
+
update_horizon: int, length of update ('n' in n-step update).
|
| 116 |
+
gamma: int, the discount factor.
|
| 117 |
+
max_sample_attempts: int, the maximum number of attempts allowed to
|
| 118 |
+
get a sample.
|
| 119 |
+
action_shape: tuple of ints, the shape for the action vector.
|
| 120 |
+
Empty tuple means the action is a scalar.
|
| 121 |
+
action_dtype: np.dtype, type of elements in the action.
|
| 122 |
+
reward_shape: tuple of ints, the shape of the reward vector.
|
| 123 |
+
Empty tuple means the reward is a scalar.
|
| 124 |
+
reward_dtype: np.dtype, type of elements in the reward.
|
| 125 |
+
observation_elements: list of ObservationElement defining the type of
|
| 126 |
+
the extra contents that will be stored and returned.
|
| 127 |
+
extra_storage_elements: list of ReplayElement defining the type of
|
| 128 |
+
the extra contents that will be stored and returned.
|
| 129 |
+
|
| 130 |
+
Raises:
|
| 131 |
+
ValueError: If replay_capacity is too small to hold at least one
|
| 132 |
+
transition.
|
| 133 |
+
"""
|
| 134 |
+
if num_replicas is None:
|
| 135 |
+
if not dist.is_available():
|
| 136 |
+
raise RuntimeError("Requires distributed package to be available")
|
| 137 |
+
self._num_replicas = dist.get_world_size()
|
| 138 |
+
if rank is None:
|
| 139 |
+
if not dist.is_available():
|
| 140 |
+
raise RuntimeError("Requires distributed package to be available")
|
| 141 |
+
self._rank = dist.get_rank()
|
| 142 |
+
if self._rank >= self._num_replicas or self._rank < 0:
|
| 143 |
+
raise ValueError(
|
| 144 |
+
"Invalid rank {}, rank should be in the interval"
|
| 145 |
+
" [0, {}]".format(self._rank, self._num_replicas - 1))
|
| 146 |
+
|
| 147 |
+
if observation_elements is None:
|
| 148 |
+
observation_elements = []
|
| 149 |
+
if extra_replay_elements is None:
|
| 150 |
+
extra_replay_elements = []
|
| 151 |
+
|
| 152 |
+
if replay_capacity < update_horizon + timesteps:
|
| 153 |
+
raise ValueError('There is not enough capacity to cover '
|
| 154 |
+
'update_horizon and stack_size.')
|
| 155 |
+
|
| 156 |
+
logging.info(
|
| 157 |
+
'Creating a %s replay memory with the following parameters:',
|
| 158 |
+
self.__class__.__name__)
|
| 159 |
+
logging.info('\t timesteps: %d', timesteps)
|
| 160 |
+
logging.info('\t replay_capacity: %d', replay_capacity)
|
| 161 |
+
logging.info('\t batch_size: %d', batch_size)
|
| 162 |
+
logging.info('\t update_horizon: %d', update_horizon)
|
| 163 |
+
logging.info('\t gamma: %f', gamma)
|
| 164 |
+
|
| 165 |
+
self._disk_saving = save_dir is not None
|
| 166 |
+
self._save_dir = save_dir
|
| 167 |
+
self._purge_replay_on_shutdown = purge_replay_on_shutdown
|
| 168 |
+
if self._disk_saving:
|
| 169 |
+
logging.info('\t saving to disk: %s', self._save_dir)
|
| 170 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 171 |
+
else:
|
| 172 |
+
logging.info('\t saving to RAM')
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
self._action_shape = action_shape
|
| 176 |
+
self._action_dtype = action_dtype
|
| 177 |
+
self._reward_shape = reward_shape
|
| 178 |
+
self._reward_dtype = reward_dtype
|
| 179 |
+
self._timesteps = timesteps
|
| 180 |
+
self._replay_capacity = replay_capacity
|
| 181 |
+
self._batch_size = batch_size
|
| 182 |
+
self._update_horizon = update_horizon
|
| 183 |
+
self._gamma = gamma
|
| 184 |
+
self._max_sample_attempts = max_sample_attempts
|
| 185 |
+
|
| 186 |
+
self._observation_elements = observation_elements
|
| 187 |
+
self._extra_replay_elements = extra_replay_elements
|
| 188 |
+
|
| 189 |
+
self._storage_signature, self._obs_signature = self.get_storage_signature()
|
| 190 |
+
self._create_storage()
|
| 191 |
+
|
| 192 |
+
self._lock = Lock()
|
| 193 |
+
self._add_count = mp.Value('i', 0)
|
| 194 |
+
|
| 195 |
+
self._replay_capacity = replay_capacity
|
| 196 |
+
|
| 197 |
+
self.invalid_range = np.zeros((self._timesteps))
|
| 198 |
+
|
| 199 |
+
# When the horizon is > 1, we compute the sum of discounted rewards as a dot
|
| 200 |
+
# product using the precomputed vector <gamma^0, gamma^1, ..., gamma^{n-1}>.
|
| 201 |
+
self._cumulative_discount_vector = np.array(
|
| 202 |
+
[math.pow(self._gamma, n) for n in range(update_horizon)],
|
| 203 |
+
dtype=np.float32)
|
| 204 |
+
|
| 205 |
+
    @property
    def timesteps(self):
        """Number of frames stacked to form a single state."""
        return self._timesteps
|
| 208 |
+
|
| 209 |
+
    @property
    def replay_capacity(self):
        """Maximum number of transitions held by the circular buffer."""
        return self._replay_capacity
|
| 212 |
+
|
| 213 |
+
    @property
    def batch_size(self):
        """Default number of transitions per sampled batch."""
        return self._batch_size
|
| 216 |
+
|
| 217 |
+
def _create_storage(self, store=None):
|
| 218 |
+
"""Creates the numpy arrays used to store transitions.
|
| 219 |
+
"""
|
| 220 |
+
self._store = {} if store is None else store
|
| 221 |
+
for storage_element in self._storage_signature:
|
| 222 |
+
array_shape = [self._replay_capacity] + list(storage_element.shape)
|
| 223 |
+
if storage_element.name == TERMINAL:
|
| 224 |
+
self._store[storage_element.name] = np.full(
|
| 225 |
+
array_shape, -1, dtype=storage_element.type)
|
| 226 |
+
elif not self._disk_saving:
|
| 227 |
+
# If saving to disk, we don't need to store anything else.
|
| 228 |
+
self._store[storage_element.name] = np.empty(
|
| 229 |
+
array_shape, dtype=storage_element.type)
|
| 230 |
+
|
| 231 |
+
def get_storage_signature(self) -> Tuple[List[ReplayElement],
|
| 232 |
+
List[ReplayElement]]:
|
| 233 |
+
"""Returns a default list of elements to be stored in this replay memory.
|
| 234 |
+
|
| 235 |
+
Note - Derived classes may return a different signature.
|
| 236 |
+
|
| 237 |
+
Returns:
|
| 238 |
+
dict of ReplayElements defining the type of the contents stored.
|
| 239 |
+
"""
|
| 240 |
+
storage_elements = [
|
| 241 |
+
ReplayElement(ACTION, self._action_shape, self._action_dtype),
|
| 242 |
+
ReplayElement(REWARD, self._reward_shape, self._reward_dtype),
|
| 243 |
+
ReplayElement(TERMINAL, (), np.int8),
|
| 244 |
+
ReplayElement(TIMEOUT, (), bool),
|
| 245 |
+
]
|
| 246 |
+
|
| 247 |
+
obs_elements = []
|
| 248 |
+
for obs_element in self._observation_elements:
|
| 249 |
+
obs_elements.append(
|
| 250 |
+
ReplayElement(
|
| 251 |
+
obs_element.name, obs_element.shape, obs_element.type))
|
| 252 |
+
storage_elements.extend(obs_elements)
|
| 253 |
+
|
| 254 |
+
for extra_replay_element in self._extra_replay_elements:
|
| 255 |
+
storage_elements.append(extra_replay_element)
|
| 256 |
+
|
| 257 |
+
return storage_elements, obs_elements
|
| 258 |
+
|
| 259 |
+
def add(self, action, reward, terminal, timeout, **kwargs):
|
| 260 |
+
"""Adds a transition to the replay memory.
|
| 261 |
+
|
| 262 |
+
WE ONLY STORE THE TPS1s on the final frame
|
| 263 |
+
|
| 264 |
+
This function checks the types and handles the padding at the beginning of
|
| 265 |
+
an episode. Then it calls the _add function.
|
| 266 |
+
|
| 267 |
+
Since the next_observation in the transition will be the observation added
|
| 268 |
+
next there is no need to pass it.
|
| 269 |
+
|
| 270 |
+
If the replay memory is at capacity the oldest transition will be discarded.
|
| 271 |
+
|
| 272 |
+
Args:
|
| 273 |
+
action: int, the action in the transition.
|
| 274 |
+
reward: float, the reward received in the transition.
|
| 275 |
+
terminal: A uint8 acting as a boolean indicating whether the transition
|
| 276 |
+
was terminal (1) or not (0).
|
| 277 |
+
**kwargs: The remaining args
|
| 278 |
+
"""
|
| 279 |
+
|
| 280 |
+
# If previous transition was a terminal, then add_final wasn't called
|
| 281 |
+
# if not self.is_empty() and self._store['terminal'][self.cursor() - 1] == 1:
|
| 282 |
+
# raise ValueError('The previous transition was a terminal, '
|
| 283 |
+
# 'but add_final was not called.')
|
| 284 |
+
|
| 285 |
+
kwargs[ACTION] = action
|
| 286 |
+
kwargs[REWARD] = reward
|
| 287 |
+
kwargs[TERMINAL] = terminal
|
| 288 |
+
kwargs[TIMEOUT] = timeout
|
| 289 |
+
self._check_add_types(kwargs, self._storage_signature)
|
| 290 |
+
self._add(kwargs)
|
| 291 |
+
|
| 292 |
+
def add_final(self, **kwargs):
|
| 293 |
+
"""Adds a transition to the replay memory.
|
| 294 |
+
Args:
|
| 295 |
+
**kwargs: The remaining args
|
| 296 |
+
"""
|
| 297 |
+
# if self.is_empty() or self._store['terminal'][self.cursor() - 1] != 1:
|
| 298 |
+
# raise ValueError('The previous transition was not terminal.')
|
| 299 |
+
self._check_add_types(kwargs, self._obs_signature)
|
| 300 |
+
transition = self._final_transition(kwargs)
|
| 301 |
+
self._add(transition)
|
| 302 |
+
|
| 303 |
+
def _final_transition(self, kwargs):
|
| 304 |
+
transition = {}
|
| 305 |
+
for element_type in self._storage_signature:
|
| 306 |
+
if element_type.name in kwargs:
|
| 307 |
+
transition[element_type.name] = kwargs[element_type.name]
|
| 308 |
+
elif element_type.name == TERMINAL:
|
| 309 |
+
# Used to check that user is correctly adding transitions
|
| 310 |
+
transition[element_type.name] = -1
|
| 311 |
+
else:
|
| 312 |
+
transition[element_type.name] = np.empty(
|
| 313 |
+
element_type.shape, dtype=element_type.type)
|
| 314 |
+
return transition
|
| 315 |
+
|
| 316 |
+
    def _add_initial_to_disk(self, kwargs: dict):
        """Pad the tail of the on-disk ring buffer with copies of the first
        transition, so timestep stacking wraps around correctly on first use."""
        # Writes to the last (timesteps - 1) slots of the circular buffer.
        for i in range(self._timesteps - 1):
            with open(join(self._save_dir, '%d.replay' % (
                    self._replay_capacity - 1 - i)), 'wb') as f:
                pickle.dump(kwargs, f)
|
| 321 |
+
|
| 322 |
+
    def _add(self, kwargs: dict):
        """Internal add method to add to the storage arrays.

        Args:
          kwargs: All the elements in a transition.
        """
        with self._lock:
            cursor = self.cursor()

            if self._disk_saving:
                # Terminals are kept in RAM even in disk-saving mode (they are
                # needed for cheap validity checks at sample time).
                term = self._store[TERMINAL]
                term[cursor] = kwargs[TERMINAL]
                self._store[TERMINAL] = term
                with open(join(self._save_dir, '%d.replay' % cursor), 'wb') as f:
                    pickle.dump(kwargs, f)
                # If first add, then pad for correct wrapping
                if self._add_count.value == 0:
                    self._add_initial_to_disk(kwargs)
            else:
                # In-RAM mode: write every element at the cursor slot.
                for name, data in kwargs.items():
                    item = self._store[name]
                    item[cursor] = data
                    self._store[name] = item
            # The counter is shared across processes, hence its own lock.
            with self._add_count.get_lock():
                self._add_count.value += 1
            # Recompute the window of indices made unsampleable by the write head.
            self.invalid_range = invalid_range(
                self.cursor(), self._replay_capacity, self._timesteps,
                self._update_horizon)
|
| 350 |
+
|
| 351 |
+
    def _get_from_disk(self, start_index, end_index):
        """Load the pickled transitions in [start_index, end_index) from disk.

        Args:
          start_index: int, index to the start of the range to be returned.
            Range will wraparound if start_index is smaller than 0.
          end_index: int, exclusive end index. Range will wraparound if
            end_index exceeds replay_capacity.

        Returns:
          dict mapping element name -> {absolute index: value}; this "mini
          store" can be indexed like the in-RAM store by get_range.
        """
        assert end_index > start_index, 'end_index must be larger than start_index'
        assert end_index >= 0
        assert start_index < self._replay_capacity
        if not self.is_full():
            assert end_index <= self.cursor(), (
                'Index {} has not been added.'.format(start_index))

        # Here we fake a mini store (buffer)
        store = {store_element.name: {}
                 for store_element in self._storage_signature}
        if start_index % self._replay_capacity < end_index % self._replay_capacity:
            # Contiguous range: read files in order.
            for i in range(start_index, end_index):
                with open(join(self._save_dir, '%d.replay' % i), 'rb') as f:
                    d = pickle.load(f)
                    for k, v in d.items():
                        store[k][i] = v
        else:
            # Wraparound range: reduce each index modulo capacity.
            for i in range(end_index - start_index):
                idx = (start_index + i) % self._replay_capacity
                with open(join(self._save_dir, '%d.replay' % idx), 'rb') as f:
                    d = pickle.load(f)
                    for k, v in d.items():
                        store[k][idx] = v
        return store
|
| 387 |
+
|
| 388 |
+
def _check_add_types(self, kwargs, signature):
|
| 389 |
+
"""Checks if args passed to the add method match those of the storage.
|
| 390 |
+
|
| 391 |
+
Args:
|
| 392 |
+
*args: Args whose types need to be validated.
|
| 393 |
+
|
| 394 |
+
Raises:
|
| 395 |
+
ValueError: If args have wrong shape or dtype.
|
| 396 |
+
"""
|
| 397 |
+
|
| 398 |
+
if (len(kwargs)) != len(signature):
|
| 399 |
+
expected = str(natsort.natsorted([e.name for e in signature]))
|
| 400 |
+
actual = str(natsort.natsorted(list(kwargs.keys())))
|
| 401 |
+
error_list = '\nList of expected:\n{}\nList of actual:\n{}'.format(
|
| 402 |
+
expected, actual)
|
| 403 |
+
raise ValueError('Add expects {} elements, received {}.'.format(
|
| 404 |
+
len(signature), len(kwargs)) + error_list)
|
| 405 |
+
|
| 406 |
+
for store_element in signature:
|
| 407 |
+
arg_element = kwargs[store_element.name]
|
| 408 |
+
if isinstance(arg_element, np.ndarray):
|
| 409 |
+
arg_shape = arg_element.shape
|
| 410 |
+
elif isinstance(arg_element, tuple) or isinstance(arg_element, list):
|
| 411 |
+
# TODO: This is not efficient when arg_element is a list.
|
| 412 |
+
arg_shape = np.array(arg_element).shape
|
| 413 |
+
else:
|
| 414 |
+
# Assume it is scalar.
|
| 415 |
+
arg_shape = tuple()
|
| 416 |
+
store_element_shape = tuple(store_element.shape)
|
| 417 |
+
if arg_shape != store_element_shape:
|
| 418 |
+
raise ValueError('arg {} has shape {}, expected {}'.format(store_element.name,
|
| 419 |
+
arg_shape, store_element_shape))
|
| 420 |
+
|
| 421 |
+
    def is_empty(self):
        """Is the Replay Buffer empty?

        Returns:
          bool, True if no transitions have been added yet.
        """
        return self._add_count.value == 0
|
| 424 |
+
|
| 425 |
+
    def is_full(self):
        """Is the Replay Buffer full?

        Returns:
          bool, True once at least replay_capacity transitions have been
          added (after which the oldest entries are overwritten).
        """
        return self._add_count.value >= self._replay_capacity
|
| 428 |
+
|
| 429 |
+
    def cursor(self):
        """Index to the location where the next transition will be written."""
        # Circular buffer: the write head wraps modulo capacity.
        return self._add_count.value % self._replay_capacity
|
| 432 |
+
|
| 433 |
+
    @property
    def add_count(self):
        """Snapshot of the total number of added transitions, as an np.array."""
        return np.array(self._add_count.value)
|
| 436 |
+
|
| 437 |
+
@add_count.setter
|
| 438 |
+
def add_count(self, count):
|
| 439 |
+
if isinstance(count, int):
|
| 440 |
+
self._add_count = mp.Value('i', count)
|
| 441 |
+
else:
|
| 442 |
+
self._add_count = count
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
def get_range(self, array, start_index, end_index):
|
| 446 |
+
"""Returns the range of array at the index handling wraparound if necessary.
|
| 447 |
+
|
| 448 |
+
Args:
|
| 449 |
+
array: np.array, the array to get the stack from.
|
| 450 |
+
start_index: int, index to the start of the range to be returned. Range
|
| 451 |
+
will wraparound if start_index is smaller than 0.
|
| 452 |
+
end_index: int, exclusive end index. Range will wraparound if end_index
|
| 453 |
+
exceeds replay_capacity.
|
| 454 |
+
|
| 455 |
+
Returns:
|
| 456 |
+
np.array, with shape [end_index - start_index, array.shape[1:]].
|
| 457 |
+
"""
|
| 458 |
+
assert end_index > start_index, 'end_index must be larger than start_index'
|
| 459 |
+
assert end_index >= 0
|
| 460 |
+
assert start_index < self._replay_capacity
|
| 461 |
+
if not self.is_full():
|
| 462 |
+
assert end_index <= self.cursor(), (
|
| 463 |
+
'Index {} has not been added.'.format(start_index))
|
| 464 |
+
|
| 465 |
+
# Fast slice read when there is no wraparound.
|
| 466 |
+
if start_index % self._replay_capacity < end_index % self._replay_capacity:
|
| 467 |
+
return_array = np.array(
|
| 468 |
+
[array[i] for i in range(start_index, end_index)])
|
| 469 |
+
# Slow list read.
|
| 470 |
+
else:
|
| 471 |
+
indices = [(start_index + i) % self._replay_capacity
|
| 472 |
+
for i in range(end_index - start_index)]
|
| 473 |
+
return_array = np.array([array[i] for i in indices])
|
| 474 |
+
|
| 475 |
+
return return_array
|
| 476 |
+
|
| 477 |
+
def get_range_stack(self, array, start_index, end_index, terminals=None):
    """Returns the range of array at the index handling wraparound if necessary.

    Like ``get_range``, but additionally pads positions that spill over
    the start of an episode (marked by a terminal value of -1) with a
    repeated frame, so the returned stack never mixes two episodes.

    Args:
      array: np.array, the array to get the stack from.
      start_index: int, index to the start of the range to be returned. Range
        will wraparound if start_index is smaller than 0.
      end_index: int, exclusive end index. Range will wraparound if end_index
        exceeds replay_capacity.
      terminals: optional pre-fetched terminal flags for the same range;
        when None they are read from the buffer's terminal store.

    Returns:
      np.array, with shape [end_index - start_index, array.shape[1:]].
    """
    return_array = np.array(self.get_range(array, start_index, end_index))
    if terminals is None:
        terminals = self.get_range(
            self._store[TERMINAL], start_index, end_index)

    # The terminal of the last frame in the stack is irrelevant here;
    # only earlier frames can indicate spill-over into a prior episode.
    terminals = terminals[:-1]

    # Here we now check if we need to pad the front episodes.
    # If any have a terminal of -1, then we have spilled over
    # into the previous transition (episode-start padding sentinel).
    if np.any(terminals == -1):
        # Walk backwards; once a -1 terminal is seen, every earlier slot
        # is overwritten in place with the last valid frame seen so far.
        padding_item = return_array[-1]
        _array = list(return_array)[:-1]
        arr_len = len(_array)
        pad_from_now = False
        for i, (ar, term) in enumerate(
                zip(reversed(_array), reversed(terminals))):
            if term == -1 or pad_from_now:
                # The first time we see a -1 term, means we have hit the
                # beginning of this episode, so pad from now.
                # pad_from_now needed because the next transition (reverse)
                # will not be a -1 terminal.
                pad_from_now = True
                return_array[arr_len - 1 - i] = padding_item
            else:
                # After we hit our first -1 terminal, we never reassign.
                padding_item = ar

    return return_array
|
| 519 |
+
|
| 520 |
+
def _get_element_stack(self, array, index, terminals=None):
|
| 521 |
+
state = self.get_range_stack(array,
|
| 522 |
+
index - self._timesteps + 1, index + 1,
|
| 523 |
+
terminals=terminals)
|
| 524 |
+
return state
|
| 525 |
+
|
| 526 |
+
def get_terminal_stack(self, index):
    """Terminal flags for the `self._timesteps` frames ending at `index`."""
    first = index - self._timesteps + 1
    return self.get_range(self._store[TERMINAL], first, index + 1)
|
| 531 |
+
|
| 532 |
+
def is_valid_transition(self, index):
    """Checks if the index contains a valid transition.

    Checks for collisions with the end of episodes and the current position
    of the cursor.

    Args:
      index: int, the index to the state in the transition.

    Returns:
      Is the index valid: Boolean.
    """
    # Outside the addressable range of the ring buffer.
    if not 0 <= index < self._replay_capacity:
        return False

    # Before the buffer wraps, only indices sufficiently behind the
    # cursor have a full update-horizon of data stored after them.
    if not self.is_full() and index >= self.cursor() - self._update_horizon:
        return False

    # Transitions straddling the cursor are being overwritten.
    if index in set(self.invalid_range):
        return False

    # A -1 terminal at the top of the stack marks episode padding.
    return bool(self.get_terminal_stack(index)[-1] != -1)
|
| 562 |
+
|
| 563 |
+
def _create_batch_arrays(self, batch_size):
|
| 564 |
+
"""Create a tuple of arrays with the type of get_transition_elements.
|
| 565 |
+
|
| 566 |
+
When using the WrappedReplayBuffer with staging enabled it is important
|
| 567 |
+
to create new arrays every sample because StaginArea keeps a pointer to
|
| 568 |
+
the returned arrays.
|
| 569 |
+
|
| 570 |
+
Args:
|
| 571 |
+
batch_size: (int) number of transitions returned. If None the default
|
| 572 |
+
batch_size will be used.
|
| 573 |
+
|
| 574 |
+
Returns:
|
| 575 |
+
Tuple of np.arrays with the shape and type of get_transition_elements.
|
| 576 |
+
"""
|
| 577 |
+
transition_elements = self.get_transition_elements(batch_size)
|
| 578 |
+
batch_arrays = []
|
| 579 |
+
for element in transition_elements:
|
| 580 |
+
batch_arrays.append(np.empty(element.shape, dtype=element.type))
|
| 581 |
+
return tuple(batch_arrays)
|
| 582 |
+
|
| 583 |
+
def sample_index_batch(self, batch_size):
    """Returns a batch of valid indices sampled uniformly.

    Args:
      batch_size: int, number of indices returned.

    Returns:
      list of ints, a batch of valid indices sampled uniformly.

    Raises:
      RuntimeError: If the batch was not constructed after maximum number of
        tries.
    """
    if self.is_full():
        # add_count >= self._replay_capacity > self._stack_size
        min_id = (self.cursor() - self._replay_capacity +
                  self._timesteps - 1)
        max_id = self.cursor() - self._update_horizon
    else:
        min_id = 0
        max_id = self.cursor() - self._update_horizon
        if max_id <= min_id:
            raise RuntimeError(
                'Cannot sample a batch with fewer than stack size '
                '({}) + update_horizon ({}) transitions.'.
                format(self._timesteps, self._update_horizon))

    # Rejection-sample until the batch is filled or the attempt budget
    # (counting only rejected candidates) is exhausted.
    indices = []
    failures = 0
    while len(indices) < batch_size and failures < self._max_sample_attempts:
        candidate = np.random.randint(min_id, max_id) % self._replay_capacity
        if self.is_valid_transition(candidate):
            indices.append(candidate)
        else:
            failures += 1

    if len(indices) != batch_size:
        raise RuntimeError(
            'Max sample attempts: Tried {} times but only sampled {}'
            ' valid indices. Batch size is {}'.
            format(self._max_sample_attempts, len(indices), batch_size))

    return indices
|
| 626 |
+
|
| 627 |
+
def unpack_transition(self, transition_tensors, transition_type):
    """Unpacks the given transition into member variables.

    Args:
      transition_tensors: tuple of tf.Tensors.
      transition_type: tuple of ReplayElements matching transition_tensors.
    """
    # Keyed by element name, preserving the order of transition_type.
    self.transition = collections.OrderedDict(
        (spec.name, tensor)
        for tensor, spec in zip(transition_tensors, transition_type))
    return self.transition
|
| 638 |
+
|
| 639 |
+
def sample_transition_batch(self, batch_size=None, indices=None,
                            pack_in_dict=True):
    """Returns a batch of transitions (including any extra contents).

    If get_transition_elements has been overridden and defines elements not
    stored in self._store, an empty array will be returned and it will be
    left to the child class to fill it. For example, for the child class
    OutOfGraphPrioritizedReplayBuffer, the contents of the
    sampling_probabilities are stored separately in a sum tree.

    When the transition is terminal next_state_batch has undefined contents.

    NOTE: This transition contains the indices of the sampled elements.
    These are only valid during the call to sample_transition_batch,
    i.e. they may be used by subclasses of this replay buffer but may
    point to different data as soon as sampling is done.

    Args:
      batch_size: int, number of transitions returned. If None, the default
        batch_size will be used.
      indices: None or list of ints, the indices of every transition in the
        batch. If None, sample the indices uniformly.
      pack_in_dict: bool, when True the result is an OrderedDict keyed by
        element name (via unpack_transition) instead of a tuple of arrays.

    Returns:
      transition_batch: tuple of np.arrays with the shape and type as in
        get_transition_elements().

    Raises:
      ValueError: If an element to be sampled is missing from the
        replay buffer.
    """

    if batch_size is None:
        batch_size = self._batch_size
    # The lock guards concurrent add/sample access to the store.
    with self._lock:
        if indices is None:
            indices = self.sample_index_batch(batch_size)
        assert len(indices) == batch_size

        transition_elements = self.get_transition_elements(batch_size)
        batch_arrays = self._create_batch_arrays(batch_size)

        for batch_element, state_index in enumerate(indices):

            if not self.is_valid_transition(state_index):
                raise ValueError('Invalid index %d.' % state_index)

            # The n-step trajectory following state_index, with wraparound.
            trajectory_indices = [(state_index + j) % self._replay_capacity
                                  for j in range(self._update_horizon)]
            trajectory_terminals = self._store['terminal'][
                trajectory_indices]
            is_terminal_transition = trajectory_terminals.any()
            if not is_terminal_transition:
                trajectory_length = self._update_horizon
            else:
                # np.argmax of a bool array returns index of the first True.
                trajectory_length = np.argmax(
                    trajectory_terminals.astype(bool),
                    0) + 1

            next_state_index = state_index + trajectory_length

            store = self._store
            if self._disk_saving:
                # Page in only the slice of transitions this sample needs.
                store = self._get_from_disk(
                    state_index - (self._timesteps - 1),
                    next_state_index + 1)

            trajectory_discount_vector = (
                self._cumulative_discount_vector[:trajectory_length])
            trajectory_rewards = self.get_range(store['reward'],
                                                state_index,
                                                next_state_index)

            terminal_stack = self.get_terminal_stack(state_index)
            terminal_stack_tp1 = self.get_terminal_stack(
                next_state_index % self._replay_capacity)

            # Fill the contents of each array in the sampled batch.
            assert len(transition_elements) == len(batch_arrays)
            for element_array, element in zip(batch_arrays,
                                              transition_elements):
                if element.is_observation:
                    # '..._tp1' elements are the same observation read at
                    # the trajectory's end index; the '_tp1' suffix (4
                    # chars) is stripped to find the store key.
                    if element.name.endswith('tp1'):
                        element_array[
                            batch_element] = self._get_element_stack(
                            store[element.name[:-4]],
                            next_state_index % self._replay_capacity,
                            terminal_stack_tp1)
                    else:
                        element_array[
                            batch_element] = self._get_element_stack(
                            store[element.name],
                            state_index, terminal_stack)
                elif element.name == REWARD:
                    # compute discounted sum of rewards in the trajectory.
                    element_array[batch_element] = np.sum(
                        trajectory_discount_vector * trajectory_rewards,
                        axis=0)
                elif element.name == TERMINAL:
                    element_array[batch_element] = is_terminal_transition
                elif element.name == INDICES:
                    element_array[batch_element] = state_index
                elif element.name in store.keys():
                    element_array[batch_element] = (
                        store[element.name][state_index])

    if pack_in_dict:
        batch_arrays = self.unpack_transition(
            batch_arrays, transition_elements)

    # TODO(Mohit): proper fix to discard task names
    if 'task' in batch_arrays:
        del batch_arrays['task']
    if 'task_tp1' in batch_arrays:
        del batch_arrays['task_tp1']

    return batch_arrays
|
| 757 |
+
|
| 758 |
+
def get_transition_elements(self, batch_size=None):
    """Returns a 'type signature' for sample_transition_batch.

    Args:
      batch_size: int, number of transitions returned. If None, the default
        batch_size will be used.
    Returns:
      signature: A namedtuple describing the method's return type signature.
    """
    if batch_size is None:
        batch_size = self._batch_size

    # Fixed elements present in every sampled transition.
    signature = [
        ReplayElement(ACTION, (batch_size,) + self._action_shape,
                      self._action_dtype),
        ReplayElement(REWARD, (batch_size,) + self._reward_shape,
                      self._reward_dtype),
        ReplayElement(TERMINAL, (batch_size,), np.int8),
        ReplayElement(TIMEOUT, (batch_size,), bool),
        ReplayElement(INDICES, (batch_size,), np.int32),
    ]

    # Observations are stacked over time, both at t and at t+1.
    stacked_prefix = (batch_size, self._timesteps)
    for obs in self._observation_elements:
        for suffix in ('', '_tp1'):
            signature.append(ReplayElement(
                obs.name + suffix,
                stacked_prefix + tuple(obs.shape),
                obs.type, True))

    # Extras are flat (no timestep stacking).
    for extra in self._extra_replay_elements:
        signature.append(ReplayElement(
            extra.name,
            (batch_size,) + tuple(extra.shape),
            extra.type))
    return signature
|
| 795 |
+
|
| 796 |
+
def shutdown(self):
    """Optionally purge on-disk replay files when the buffer is retired."""
    if not self._purge_replay_on_shutdown:
        return
    # Safely delete replay
    logging.info('Clearing disk replay buffer.')
    for fname in os.listdir(self._save_dir):
        if '.replay' in fname:
            os.remove(join(self._save_dir, fname))
|
| 802 |
+
|
| 803 |
+
def using_disk(self):
    # True when transitions are persisted to disk rather than held in RAM.
    return self._disk_saving
|
external/yarr/yarr/replay_buffer/wrappers/__init__.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from typing import Any
|
| 3 |
+
|
| 4 |
+
from yarr.replay_buffer.replay_buffer import ReplayBuffer
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class WrappedReplayBuffer(ABC):
    """Abstract base wrapping a ReplayBuffer behind a dataset interface.

    Subclasses expose the wrapped buffer as a framework-specific stream of
    sampled batches via ``dataset()``.
    """

    def __init__(self, replay_buffer: ReplayBuffer):
        """Initializes WrappedReplayBuffer.

        Raises:
          ValueError: If update_horizon is not positive.
          ValueError: If discount factor is not in [0, 1].
        """
        # NOTE(review): the documented ValueErrors are not raised anywhere
        # in this class — presumably wording inherited from the wrapped
        # buffer's constructor; confirm or drop from the docstring.
        self._replay_buffer = replay_buffer

    @property
    def replay_buffer(self):
        # The underlying (unwrapped) replay buffer.
        return self._replay_buffer

    @abstractmethod
    def dataset(self) -> Any:
        # Subclasses return an iterable/stream of sampled transition batches.
        pass
|
external/yarr/yarr/replay_buffer/wrappers/pytorch_replay_buffer.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
from threading import Thread
|
| 3 |
+
|
| 4 |
+
from torch.utils.data import IterableDataset, DataLoader
|
| 5 |
+
|
| 6 |
+
from yarr.replay_buffer.replay_buffer import ReplayBuffer
|
| 7 |
+
from yarr.replay_buffer.wrappers import WrappedReplayBuffer
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class PyTorchIterableReplayDataset(IterableDataset):
    """Infinite IterableDataset drawing sampled batches from a replay buffer."""

    def __init__(self, replay_buffer: ReplayBuffer):
        self._replay_buffer = replay_buffer

    def _generator(self):
        # The stream never terminates; consumers decide when to stop.
        while True:
            batch = self._replay_buffer.sample_transition_batch(
                pack_in_dict=True)
            yield batch

    def __iter__(self):
        return iter(self._generator())
|
| 21 |
+
|
| 22 |
+
# class PyTorchIterableReplayDataset(IterableDataset):
|
| 23 |
+
#
|
| 24 |
+
# BUFFER = 4
|
| 25 |
+
#
|
| 26 |
+
# def __init__(self, replay_buffer: ReplayBuffer, num_workers: int):
|
| 27 |
+
# self._replay_buffer = replay_buffer
|
| 28 |
+
# self._num_wokers = num_workers
|
| 29 |
+
# self._samples = []
|
| 30 |
+
# self._lock = Lock()
|
| 31 |
+
#
|
| 32 |
+
# def _run(self):
|
| 33 |
+
# while True:
|
| 34 |
+
# # Check if replay buffer is ig enough to be sampled
|
| 35 |
+
# while self._replay_buffer.add_count < self._replay_buffer.batch_size:
|
| 36 |
+
# time.sleep(1.)
|
| 37 |
+
# s = self._replay_buffer.sample_transition_batch(pack_in_dict=True)
|
| 38 |
+
# while len(self._samples) >= PyTorchIterableReplayDataset.BUFFER:
|
| 39 |
+
# time.sleep(0.25)
|
| 40 |
+
# with self._lock:
|
| 41 |
+
# self._samples.append(s)
|
| 42 |
+
#
|
| 43 |
+
# def _generator(self):
|
| 44 |
+
# ts = [Thread(
|
| 45 |
+
# target=self._run, args=()) for _ in range(self._num_wokers)]
|
| 46 |
+
# [t.start() for t in ts]
|
| 47 |
+
# while True:
|
| 48 |
+
# while len(self._samples) == 0:
|
| 49 |
+
# time.sleep(0.1)
|
| 50 |
+
# with self._lock:
|
| 51 |
+
# s = self._samples.pop(0)
|
| 52 |
+
# yield s
|
| 53 |
+
#
|
| 54 |
+
# def __iter__(self):
|
| 55 |
+
# i = iter(self._generator())
|
| 56 |
+
# return i
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class PyTorchReplayBuffer(WrappedReplayBuffer):
    """Exposes a replay buffer as an endless PyTorch ``DataLoader``.

    Usage:
      To add a transition: call the add function on the wrapped buffer.

      To sample batches: iterate the DataLoader returned by ``dataset``;
      each iteration draws a fresh transition batch from the buffer.
    """

    def __init__(self, replay_buffer: ReplayBuffer, num_workers: int = 2):
        super(PyTorchReplayBuffer, self).__init__(replay_buffer)
        self._num_workers = num_workers

    def dataset(self, batch_size=None, drop_last=False) -> DataLoader:
        source = PyTorchIterableReplayDataset(self._replay_buffer)

        # Batch size None disables automatic batching
        return DataLoader(source,
                          batch_size=batch_size,
                          drop_last=drop_last,
                          num_workers=self._num_workers,
                          pin_memory=True)
|
external/yarr/yarr/runners/__init__.py
ADDED
|
File without changes
|
external/yarr/yarr/runners/_env_runner.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import time
|
| 5 |
+
import pandas as pd
|
| 6 |
+
|
| 7 |
+
from multiprocessing import Process, Manager
|
| 8 |
+
from multiprocessing import get_start_method, set_start_method
|
| 9 |
+
from typing import Any
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
from yarr.agents.agent import Agent
|
| 14 |
+
from yarr.agents.agent import ScalarSummary
|
| 15 |
+
from yarr.agents.agent import Summary
|
| 16 |
+
from yarr.envs.env import Env
|
| 17 |
+
from yarr.utils.rollout_generator import RolloutGenerator
|
| 18 |
+
from yarr.utils.log_writer import LogWriter
|
| 19 |
+
from yarr.utils.process_str import change_case
|
| 20 |
+
from yarr.utils.video_utils import CircleCameraMotion, TaskRecorder
|
| 21 |
+
|
| 22 |
+
from pyrep.objects.dummy import Dummy
|
| 23 |
+
from pyrep.objects.vision_sensor import VisionSensor
|
| 24 |
+
|
| 25 |
+
try:
|
| 26 |
+
if get_start_method() != 'spawn':
|
| 27 |
+
set_start_method('spawn', force=True)
|
| 28 |
+
except RuntimeError:
|
| 29 |
+
pass
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class _EnvRunner(object):
    """Internal helper that runs environment rollouts in child processes.

    Each spawned process deep-copies the agent, loads the newest saved
    weights, generates rollouts, and hands finished episode transitions
    back to the parent through Manager-backed shared lists.
    """

    def __init__(self,
                 train_env: Env,
                 eval_env: Env,
                 agent: Agent,
                 timesteps: int,
                 train_envs: int,
                 eval_envs: int,
                 rollout_episodes: int,
                 eval_episodes: int,
                 training_iterations: int,
                 eval_from_eps_number: int,
                 episode_length: int,
                 kill_signal: Any,
                 step_signal: Any,
                 num_eval_episodes_signal: Any,
                 eval_epochs_signal: Any,
                 eval_report_signal: Any,
                 log_freq: int,
                 rollout_generator: RolloutGenerator,
                 save_load_lock,
                 current_replay_ratio,
                 target_replay_ratio,
                 weightsdir: str = None,
                 logdir: str = None,
                 env_device: torch.device = None,
                 previous_loaded_weight_folder: str = '',
                 num_eval_runs: int = 1,
                 ):
        self._train_env = train_env
        self._eval_env = eval_env
        self._agent = agent
        self._train_envs = train_envs
        self._eval_envs = eval_envs
        self._rollout_episodes = rollout_episodes
        self._eval_episodes = eval_episodes
        self._training_iterations = training_iterations
        self._num_eval_runs = num_eval_runs
        self._eval_from_eps_number = eval_from_eps_number
        self._episode_length = episode_length
        self._rollout_generator = rollout_generator
        self._weightsdir = weightsdir
        self._logdir = logdir
        self._env_device = env_device
        self._previous_loaded_weight_folder = previous_loaded_weight_folder

        self._timesteps = timesteps

        # Per-process launch arguments and failure counters, keyed by
        # process name; used by restart_process after a crash.
        self._p_args = {}
        self.p_failures = {}
        manager = Manager()
        # Manager-backed shared state: child processes append, parent pops.
        self.write_lock = manager.Lock()
        self.stored_transitions = manager.list()
        self.agent_summaries = manager.list()
        self._kill_signal = kill_signal
        self._step_signal = step_signal
        self._num_eval_episodes_signal = num_eval_episodes_signal
        self._eval_epochs_signal = eval_epochs_signal
        self._eval_report_signal = eval_report_signal
        self._save_load_lock = save_load_lock
        self._current_replay_ratio = current_replay_ratio
        self._target_replay_ratio = target_replay_ratio
        self._log_freq = log_freq

        # Set by _load_save when a newer checkpoint was picked up.
        self._new_weights = False

    def restart_process(self, name: str):
        # Relaunch a previously registered env process with its original args.
        p = Process(target=self._run_env, args=self._p_args[name], name=name)
        p.start()
        return p

    def spin_up_envs(self, name: str, num_envs: int, eval: bool):
        # Launch `num_envs` rollout processes named `name0`, `name1`, ...
        ps = []
        for i in range(num_envs):
            n = name + str(i)
            self._p_args[n] = (n, eval)
            self.p_failures[n] = 0
            p = Process(target=self._run_env, args=self._p_args[n], name=n)
            p.start()
            ps.append(p)
        return ps

    def _load_save(self):
        """Block until a checkpoint exists, then load the newest weights."""
        if self._weightsdir is None:
            logging.info("'weightsdir' was None, so not loading weights.")
            return
        while True:
            weight_folders = []
            with self._save_load_lock:
                if os.path.exists(self._weightsdir):
                    weight_folders = os.listdir(self._weightsdir)
                if len(weight_folders) > 0:
                    # Folder names are training-step integers; newest = max.
                    weight_folders = sorted(map(int, weight_folders))
                    # Only load if there has been a new weight saving
                    if self._previous_loaded_weight_folder != weight_folders[-1]:
                        self._previous_loaded_weight_folder = weight_folders[-1]
                        d = os.path.join(self._weightsdir, str(weight_folders[-1]))
                        try:
                            self._agent.load_weights(d)
                        except FileNotFoundError:
                            # Rare case when agent hasn't finished writing.
                            time.sleep(1)
                            self._agent.load_weights(d)
                        print('Agent %s: Loaded weights: %s' % (self._name, d))
                        self._new_weights = True
                    else:
                        self._new_weights = False
                    break
            print('Waiting for weights to become available.')
            time.sleep(1)

    def _get_type(self, x):
        # Downcast float64 arrays; all other dtypes pass through unchanged.
        if x.dtype == np.float64:
            return np.float32
        return x.dtype

    def _get_task_name(self):
        """Return (task_name, multi_task) for the eval env.

        Single-task envs expose `_task_class`; multi-task envs expose
        `_task_classes` plus `active_task_id`.
        """
        if hasattr(self._eval_env, '_task_class'):
            eval_task_name = change_case(self._eval_env._task_class.__name__)
            multi_task = False
        elif hasattr(self._eval_env, '_task_classes'):
            if self._eval_env.active_task_id != -1:
                task_id = (self._eval_env.active_task_id) % len(self._eval_env._task_classes)
                eval_task_name = change_case(self._eval_env._task_classes[task_id].__name__)
            else:
                eval_task_name = ''
            multi_task = True
        else:
            raise Exception('Neither task_class nor task_classes found in eval env')
        return eval_task_name, multi_task

    def _run_env(self, name: str, eval: bool):
        # Child-process entry point: owns its own agent copy and env.

        self._name = name

        # Deep-copy so this process's weight loads don't touch the parent.
        self._agent = copy.deepcopy(self._agent)

        self._agent.build(training=False, device=self._env_device)

        logging.info('%s: Launching env.' % name)
        # Re-seed so forked/spawned workers don't share an RNG stream.
        np.random.seed()

        logging.info('Agent information:')
        logging.info(self._agent)

        env = self._train_env
        if eval:
            env = self._eval_env
        env.eval = eval
        env.launch()
        for ep in range(self._rollout_episodes):
            self._load_save()
            logging.debug('%s: Starting episode %d.' % (name, ep))
            episode_rollout = []
            # NOTE(review): `eval_demo_seed` and `rec_cfg` are not defined
            # anywhere in this class or its visible module scope — this call
            # will raise NameError at runtime. The overriding runner
            # (_IndependentEnvRunner) presumably supplies these; confirm
            # whether this base implementation is ever executed.
            generator = self._rollout_generator.generator(
                self._step_signal, env, self._agent,
                self._episode_length, self._timesteps,
                eval, eval_demo_seed=eval_demo_seed,
                record_enabled=rec_cfg.enabled)
            try:
                for replay_transition in generator:
                    # Throttle rollouts until training has caught up to the
                    # target replay ratio (skipped entirely during eval).
                    while True:
                        if self._kill_signal.value:
                            env.shutdown()
                            return
                        if (eval or self._target_replay_ratio is None or
                                self._step_signal.value <= 0 or (
                                        self._current_replay_ratio.value >
                                        self._target_replay_ratio)):
                            break
                        time.sleep(1)
                        logging.debug(
                            'Agent. Waiting for replay_ratio %f to be more than %f' %
                            (self._current_replay_ratio.value, self._target_replay_ratio))

                    with self.write_lock:
                        if len(self.agent_summaries) == 0:
                            # Only store new summaries if the previous ones
                            # have been popped by the main env runner.
                            for s in self._agent.act_summaries():
                                self.agent_summaries.append(s)
                    episode_rollout.append(replay_transition)
            except StopIteration as e:
                continue
            except Exception as e:
                env.shutdown()
                raise e

            # Publish the finished episode to the parent in one locked pass.
            with self.write_lock:
                for transition in episode_rollout:
                    self.stored_transitions.append((name, transition, eval))
        env.shutdown()

    def kill(self):
        # Signal all rollout processes to exit at their next check.
        self._kill_signal.value = True
|
external/yarr/yarr/runners/_independent_env_runner.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import time
|
| 5 |
+
import pandas as pd
|
| 6 |
+
|
| 7 |
+
from multiprocessing import Process, Manager
|
| 8 |
+
from multiprocessing import get_start_method, set_start_method
|
| 9 |
+
from typing import Any
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
from yarr.agents.agent import Agent
|
| 14 |
+
from yarr.agents.agent import ScalarSummary
|
| 15 |
+
from yarr.agents.agent import Summary
|
| 16 |
+
from yarr.envs.env import Env
|
| 17 |
+
from yarr.utils.rollout_generator import RolloutGenerator
|
| 18 |
+
from yarr.utils.log_writer import LogWriter
|
| 19 |
+
from yarr.utils.process_str import change_case
|
| 20 |
+
from yarr.utils.video_utils import CircleCameraMotion, TaskRecorder
|
| 21 |
+
|
| 22 |
+
from pyrep.objects.dummy import Dummy
|
| 23 |
+
from pyrep.objects.vision_sensor import VisionSensor
|
| 24 |
+
|
| 25 |
+
from yarr.runners._env_runner import _EnvRunner
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class _IndependentEnvRunner(_EnvRunner):
    """Serialized, in-process evaluation runner.

    Unlike `_EnvRunner`, which rolls out episodes in spawned worker
    processes, this runner launches a single eval env in the *current*
    process and evaluates either:

    * one weight snapshot across all tasks (validation; ``weight`` is an
      ``int``), or
    * a per-task dict of best snapshots (test-set evaluation; ``weight``
      is a ``dict`` mapping task name -> snapshot id).

    It can optionally record cinematic videos of each episode.
    """

    def __init__(self,
                 train_env: Env,
                 eval_env: Env,
                 agent: Agent,
                 timesteps: int,
                 train_envs: int,
                 eval_envs: int,
                 rollout_episodes: int,
                 eval_episodes: int,
                 training_iterations: int,
                 eval_from_eps_number: int,
                 episode_length: int,
                 kill_signal: Any,
                 step_signal: Any,
                 num_eval_episodes_signal: Any,
                 eval_epochs_signal: Any,
                 eval_report_signal: Any,
                 log_freq: int,
                 rollout_generator: RolloutGenerator,
                 save_load_lock,
                 current_replay_ratio,
                 target_replay_ratio,
                 weightsdir: str = None,
                 logdir: str = None,
                 env_device: torch.device = None,
                 previous_loaded_weight_folder: str = '',
                 num_eval_runs: int = 1,
                 ):
        # All construction is delegated to _EnvRunner; this subclass only
        # changes *how* evaluation is driven (serially, in-process).
        super().__init__(train_env, eval_env, agent, timesteps,
                         train_envs, eval_envs, rollout_episodes, eval_episodes,
                         training_iterations, eval_from_eps_number, episode_length,
                         kill_signal, step_signal, num_eval_episodes_signal,
                         eval_epochs_signal, eval_report_signal, log_freq,
                         rollout_generator, save_load_lock, current_replay_ratio,
                         target_replay_ratio, weightsdir, logdir, env_device,
                         previous_loaded_weight_folder, num_eval_runs)

    def _load_save(self):
        """Poll the weights directory and load the newest snapshot.

        Blocks (1 s poll) until at least one weight folder exists. Sets
        ``self._new_weights`` to True only when a folder newer than the
        previously loaded one is found.
        """
        if self._weightsdir is None:
            logging.info("'weightsdir' was None, so not loading weights.")
            return
        while True:
            weight_folders = []
            with self._save_load_lock:
                if os.path.exists(self._weightsdir):
                    weight_folders = os.listdir(self._weightsdir)
                if len(weight_folders) > 0:
                    # Folder names are integer step counts; numeric sort picks
                    # the most recent snapshot.
                    weight_folders = sorted(map(int, weight_folders))
                    # only load if there has been a new weight saving
                    if self._previous_loaded_weight_folder != weight_folders[-1]:
                        self._previous_loaded_weight_folder = weight_folders[-1]
                        d = os.path.join(self._weightsdir, str(weight_folders[-1]))
                        try:
                            self._agent.load_weights(d)
                        except FileNotFoundError:
                            # Rare race: the trainer created the folder but has
                            # not finished writing the checkpoint yet.
                            time.sleep(1)
                            self._agent.load_weights(d)
                        logging.info('Agent %s: Loaded weights: %s' % (self._name, d))
                        self._new_weights = True
                    else:
                        self._new_weights = False
                    break
            logging.info('Waiting for weights to become available.')
            time.sleep(1)

    def _get_task_name(self):
        """Return ``(task_name, multi_task)`` for the current eval env.

        Single-task envs expose ``_task_class``; multi-task envs expose
        ``_task_classes`` plus ``active_task_id`` (-1 means no active task
        yet, in which case the name is '').
        """
        if hasattr(self._eval_env, '_task_class'):
            eval_task_name = change_case(self._eval_env._task_class.__name__)
            multi_task = False
        elif hasattr(self._eval_env, '_task_classes'):
            if self._eval_env.active_task_id != -1:
                task_id = (self._eval_env.active_task_id) % len(self._eval_env._task_classes)
                eval_task_name = change_case(self._eval_env._task_classes[task_id].__name__)
            else:
                eval_task_name = ''
            multi_task = True
        else:
            raise Exception('Neither task_class nor task_classes found in eval env')
        return eval_task_name, multi_task

    def _run_eval_independent(self, name: str,
                              stats_accumulator,
                              weight,
                              writer_lock,
                              eval=True,
                              device_idx=0,
                              save_metrics=True,
                              cinematic_recorder_cfg=None):
        """Run the full evaluation loop in the current process.

        Args:
            name: Identifier used in logs and stored transitions.
            stats_accumulator: Accumulates per-episode stats; popped into
                summaries after each eval run.
            weight: ``int`` for a single snapshot evaluated on all tasks, or
                ``dict`` mapping task name -> best snapshot (test set).
            writer_lock: Lock guarding metric writing; also serializes the
                agent build (which may download models).
            eval: Whether the env runs in eval mode.
            device_idx: CUDA device index used when more than one GPU exists.
            save_metrics: If False (e.g. when only recording videos), no CSV
                metrics are written.
            cinematic_recorder_cfg: Optional recorder config. May be ``None``
                (the default), in which case no recording takes place.
        """
        self._name = name
        self._save_metrics = save_metrics
        self._is_test_set = type(weight) == dict

        # Deep-copy so the shared agent template is never mutated by this run.
        self._agent = copy.deepcopy(self._agent)

        # NOTE(review): assumes CUDA is available — confirm for CPU-only hosts.
        device = torch.device('cuda:%d' % device_idx) if torch.cuda.device_count() > 1 else torch.device('cuda:0')
        with writer_lock:  # hack to prevent multiple CLIP downloads ... argh should use a separate lock
            self._agent.build(training=False, device=device)

        logging.info('%s: Launching env.' % name)
        np.random.seed()

        logging.info('Agent information:')
        logging.info(self._agent)

        env = self._eval_env
        env.eval = eval
        env.launch()

        # initialize cinematic recorder if specified
        # BUG FIX: the parameter defaults to None, but the original code
        # dereferenced `.enabled` unconditionally, crashing with the default.
        rec_cfg = cinematic_recorder_cfg
        record = rec_cfg is not None and rec_cfg.enabled
        if record:
            cam_placeholder = Dummy('cam_cinematic_placeholder')
            cam = VisionSensor.create(rec_cfg.camera_resolution)
            cam.set_pose(cam_placeholder.get_pose())
            cam.set_parent(cam_placeholder)

            cam_motion = CircleCameraMotion(cam, Dummy('cam_cinematic_base'), rec_cfg.rotate_speed)
            tr = TaskRecorder(env, cam_motion, fps=rec_cfg.fps)

            # Snap a frame after every arm action step.
            env.env._action_mode.arm_action_mode.set_callable_each_step(tr.take_snap)

        if not os.path.exists(self._weightsdir):
            raise Exception('No weights directory found.')

        # to save or not to save evaluation metrics (set as False for recording videos)
        if self._save_metrics:
            csv_file = 'eval_data.csv' if not self._is_test_set else 'test_data.csv'
            writer = LogWriter(self._logdir, True, True,
                               env_csv=csv_file)

        # one weight for all tasks (used for validation)
        if type(weight) == int:
            logging.info('Evaluating weight %s' % weight)
            weight_path = os.path.join(self._weightsdir, str(weight))
            seed_path = self._weightsdir.replace('/weights', '')
            self._agent.load_weights(weight_path)
            weight_name = str(weight)

        new_transitions = {'train_envs': 0, 'eval_envs': 0}
        total_transitions = {'train_envs': 0, 'eval_envs': 0}
        current_task_id = -1

        for n_eval in range(self._num_eval_runs):
            if record:
                tr._cam_motion.save_pose()

            # best weight for each task (used for test evaluation)
            if type(weight) == dict:
                task_name = list(weight.keys())[n_eval]
                task_weight = weight[task_name]
                weight_path = os.path.join(self._weightsdir, str(task_weight))
                seed_path = self._weightsdir.replace('/weights', '')
                self._agent.load_weights(weight_path)
                weight_name = str(task_weight)
                print('Evaluating weight %s for %s' % (weight_name, task_name))

            # evaluate on N tasks * M episodes per task = total eval episodes
            for ep in range(self._eval_episodes):
                eval_demo_seed = ep + self._eval_from_eps_number
                logging.info('%s: Starting episode %d, seed %d.' % (name, ep, eval_demo_seed))

                # the current task gets reset after every M episodes
                episode_rollout = []
                generator = self._rollout_generator.generator(
                    self._step_signal, env, self._agent,
                    self._episode_length, self._timesteps,
                    eval, eval_demo_seed=eval_demo_seed,
                    record_enabled=record)
                try:
                    for replay_transition in generator:
                        while True:
                            if self._kill_signal.value:
                                env.shutdown()
                                return
                            # During eval (or when no replay-ratio target is
                            # set) proceed immediately; otherwise throttle
                            # rollouts until training catches up.
                            if (eval or self._target_replay_ratio is None or
                                    self._step_signal.value <= 0 or (
                                            self._current_replay_ratio.value >
                                            self._target_replay_ratio)):
                                break
                            time.sleep(1)
                            logging.debug(
                                'Agent. Waiting for replay_ratio %f to be more than %f' %
                                (self._current_replay_ratio.value, self._target_replay_ratio))

                        with self.write_lock:
                            if len(self.agent_summaries) == 0:
                                # Only store new summaries if the previous ones
                                # have been popped by the main env runner.
                                for s in self._agent.act_summaries():
                                    self.agent_summaries.append(s)
                        episode_rollout.append(replay_transition)
                except StopIteration:
                    continue
                except Exception as e:
                    env.shutdown()
                    raise e

                with self.write_lock:
                    for transition in episode_rollout:
                        self.stored_transitions.append((name, transition, eval))

                        new_transitions['eval_envs'] += 1
                        total_transitions['eval_envs'] += 1
                        stats_accumulator.step(transition, eval)
                        current_task_id = transition.info['active_task_id']

                self._num_eval_episodes_signal.value += 1

                task_name, _ = self._get_task_name()
                reward = episode_rollout[-1].reward
                lang_goal = env._lang_goal
                print(f"Evaluating {task_name} | Episode {ep} | Score: {reward} | Lang Goal: {lang_goal}")

                # save recording
                if record:
                    success = reward > 0.99
                    record_file = os.path.join(seed_path, 'videos',
                                               '%s_w%s_s%s_%s.mp4' % (task_name,
                                                                      weight_name,
                                                                      eval_demo_seed,
                                                                      'succ' if success else 'fail'))

                    lang_goal = self._eval_env._lang_goal

                    tr.save(record_file, lang_goal, reward)
                    tr._cam_motion.restore_pose()

            # report summaries
            summaries = []
            summaries.extend(stats_accumulator.pop())

            eval_task_name, multi_task = self._get_task_name()

            if eval_task_name and multi_task:
                for s in summaries:
                    if 'eval' in s.name:
                        s.name = '%s/%s' % (s.name, eval_task_name)

            if len(summaries) > 0:
                if multi_task:
                    task_score = [s.value for s in summaries if f'eval_envs/return/{eval_task_name}' in s.name][0]
                else:
                    task_score = [s.value for s in summaries if f'eval_envs/return' in s.name][0]
            else:
                task_score = "unknown"

            print(f"Finished {eval_task_name} | Final Score: {task_score}\n")

            if self._save_metrics:
                with writer_lock:
                    writer.add_summaries(weight_name, summaries)

            self._new_transitions = {'train_envs': 0, 'eval_envs': 0}
            self.agent_summaries[:] = []
            self.stored_transitions[:] = []

        if self._save_metrics:
            with writer_lock:
                writer.end_iteration()

        logging.info('Finished evaluation.')
        env.shutdown()

    def kill(self):
        """Signal the rollout loop to stop at its next kill-signal check."""
        self._kill_signal.value = True
|
external/yarr/yarr/runners/env_runner.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import signal
|
| 5 |
+
import time
|
| 6 |
+
from multiprocessing import Value
|
| 7 |
+
from threading import Thread
|
| 8 |
+
from typing import List
|
| 9 |
+
from typing import Union
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
from yarr.agents.agent import Agent
|
| 14 |
+
from yarr.agents.agent import ScalarSummary
|
| 15 |
+
from yarr.agents.agent import Summary
|
| 16 |
+
from yarr.envs.env import Env
|
| 17 |
+
from yarr.replay_buffer.replay_buffer import ReplayBuffer
|
| 18 |
+
from yarr.runners._env_runner import _EnvRunner
|
| 19 |
+
from yarr.utils.rollout_generator import RolloutGenerator
|
| 20 |
+
from yarr.utils.stat_accumulator import StatAccumulator, SimpleAccumulator
|
| 21 |
+
from yarr.utils.process_str import change_case
|
| 22 |
+
from helpers.custom_rlbench_env import CustomRLBenchEnv, CustomMultiTaskRLBenchEnv
|
| 23 |
+
|
| 24 |
+
class EnvRunner(object):
    """Coordinates rollout worker processes from a background thread.

    Spins up `_EnvRunner` train/eval worker processes, moves the transitions
    they produce into the replay buffer(s), accumulates statistics, and
    restarts workers that crash or hang.
    """

    def __init__(self,
                 train_env: Env,
                 agent: Agent,
                 train_replay_buffer: Union[ReplayBuffer, List[ReplayBuffer]],
                 num_train_envs: int,
                 num_eval_envs: int,
                 rollout_episodes: int,
                 eval_episodes: int,
                 training_iterations: int,
                 eval_from_eps_number: int,
                 episode_length: int,
                 eval_env: Union[Env, None] = None,
                 eval_replay_buffer: Union[ReplayBuffer, List[ReplayBuffer], None] = None,
                 stat_accumulator: Union[StatAccumulator, None] = None,
                 rollout_generator: RolloutGenerator = None,
                 weightsdir: str = None,
                 logdir: str = None,
                 max_fails: int = 10,
                 num_eval_runs: int = 1,
                 env_device: torch.device = None,
                 multi_task: bool = False):
        self._train_env = train_env
        self._eval_env = eval_env if eval_env else train_env
        self._agent = agent
        self._train_envs = num_train_envs
        self._eval_envs = num_eval_envs
        # Normalize to a list so multi-buffer (per-task) setups and the
        # single-buffer setup share one code path.
        self._train_replay_buffer = train_replay_buffer if isinstance(train_replay_buffer, list) else [train_replay_buffer]
        self._timesteps = self._train_replay_buffer[0].timesteps if self._train_replay_buffer[0] is not None else 1

        if eval_replay_buffer is not None:
            eval_replay_buffer = eval_replay_buffer if isinstance(eval_replay_buffer, list) else [eval_replay_buffer]
        self._eval_replay_buffer = eval_replay_buffer
        self._rollout_episodes = rollout_episodes
        self._eval_episodes = eval_episodes
        self._num_eval_runs = num_eval_runs
        self._training_iterations = training_iterations
        self._eval_from_eps_number = eval_from_eps_number
        self._episode_length = episode_length
        self._stat_accumulator = stat_accumulator
        self._rollout_generator = (
            RolloutGenerator() if rollout_generator is None
            else rollout_generator)
        self._rollout_generator._env_device = env_device
        self._weightsdir = weightsdir
        self._logdir = logdir
        self._max_fails = max_fails
        self._env_device = env_device
        self._previous_loaded_weight_folder = ''
        self._p = None
        # Shared signals between this thread and the worker processes.
        self._kill_signal = Value('b', 0)
        self._step_signal = Value('i', -1)
        self._num_eval_episodes_signal = Value('i', 0)
        self._eval_epochs_signal = Value('i', 0)
        self._eval_report_signal = Value('b', 0)
        self._new_transitions = {'train_envs': 0, 'eval_envs': 0}
        self._total_transitions = {'train_envs': 0, 'eval_envs': 0}
        # ROBUSTNESS: initialize here so summaries() does not raise
        # AttributeError when called before the first _update().
        self._agent_summaries = []
        self.log_freq = 1000  # Will get overridden later
        self.target_replay_ratio = None  # Will get overridden later
        self.current_replay_ratio = Value('f', -1)
        self._current_task_id = -1
        self._multi_task = multi_task

    def summaries(self) -> List[Summary]:
        """Pop accumulated stats and transition counters as Summary objects."""
        summaries = []
        if self._stat_accumulator is not None:
            summaries.extend(self._stat_accumulator.pop())
        for key, value in self._new_transitions.items():
            summaries.append(ScalarSummary('%s/new_transitions' % key, value))
        for key, value in self._total_transitions.items():
            summaries.append(ScalarSummary('%s/total_transitions' % key, value))
        self._new_transitions = {'train_envs': 0, 'eval_envs': 0}
        summaries.extend(self._agent_summaries)

        # add current task_name to eval summaries .... argh this should be inside a helper function
        if hasattr(self._eval_env, '_task_class'):
            eval_task_name = change_case(self._eval_env._task_class.__name__)
        elif hasattr(self._eval_env, '_task_classes'):
            if self._current_task_id != -1:
                task_id = (self._current_task_id) % len(self._eval_env._task_classes)
                eval_task_name = change_case(self._eval_env._task_classes[task_id].__name__)
            else:
                eval_task_name = ''
        else:
            raise Exception('Neither task_class nor task_classes found in eval env')

        # multi-task summaries
        if eval_task_name and self._multi_task:
            for s in summaries:
                if 'eval' in s.name:
                    s.name = '%s/%s' % (s.name, eval_task_name)

        return summaries

    def _update(self):
        """Move stored transitions into the replay buffer(s) and accumulate stats.

        Returns a dict mapping worker name -> number of transitions consumed
        this tick (used by _run to detect hung workers).
        """
        new_transitions = collections.defaultdict(int)
        with self._internal_env_runner.write_lock:
            self._agent_summaries = list(
                self._internal_env_runner.agent_summaries)
            if self._num_eval_episodes_signal.value % self._eval_episodes == 0 and self._num_eval_episodes_signal.value > 0:
                self._internal_env_runner.agent_summaries[:] = []
            for name, transition, eval in self._internal_env_runner.stored_transitions:
                # Eval transitions are only buffered when an eval replay exists.
                add_to_buffer = (not eval) or self._eval_replay_buffer is not None
                if add_to_buffer:
                    kwargs = dict(transition.observation)
                    replay_index = transition.info["active_task_id"]
                    rb = self._eval_replay_buffer[replay_index] if eval else self._train_replay_buffer[replay_index]
                    rb.add(
                        np.array(transition.action), transition.reward,
                        transition.terminal,
                        transition.timeout, **kwargs)
                    if transition.terminal:
                        rb.add_final(
                            **transition.final_observation)
                new_transitions[name] += 1
                self._new_transitions[
                    'eval_envs' if eval else 'train_envs'] += 1
                self._total_transitions[
                    'eval_envs' if eval else 'train_envs'] += 1
                if self._stat_accumulator is not None:
                    self._stat_accumulator.step(transition, eval)
                self._current_task_id = transition.info["active_task_id"] if eval else -1
            self._internal_env_runner.stored_transitions[:] = []  # Clear list
        return new_transitions

    def _run(self, save_load_lock):
        """Main supervision loop: spawn workers, drain transitions, restart failures."""
        self._internal_env_runner = _EnvRunner(
            self._train_env, self._eval_env, self._agent, self._timesteps, self._train_envs,
            self._eval_envs, self._rollout_episodes, self._eval_episodes,
            self._training_iterations, self._eval_from_eps_number, self._episode_length, self._kill_signal,
            self._step_signal, self._num_eval_episodes_signal,
            self._eval_epochs_signal, self._eval_report_signal,
            self.log_freq, self._rollout_generator, save_load_lock,
            self.current_replay_ratio, self.target_replay_ratio,
            self._weightsdir, self._logdir,
            self._env_device, self._previous_loaded_weight_folder,
            num_eval_runs=self._num_eval_runs)
        training_envs = self._internal_env_runner.spin_up_envs('train_env', self._train_envs, False)
        eval_envs = self._internal_env_runner.spin_up_envs('eval_env', self._eval_envs, True)
        envs = training_envs + eval_envs
        no_transitions = {env.name: 0 for env in envs}
        while True:
            # BUG FIX: iterate over snapshots — the original iterated `envs`
            # directly while calling envs.remove()/append() inside the loop,
            # which skips elements.
            for p in list(envs):
                if p.exitcode is not None:
                    envs.remove(p)
                    if p.exitcode != 0:
                        self._internal_env_runner.p_failures[p.name] += 1
                        n_failures = self._internal_env_runner.p_failures[p.name]
                        if n_failures > self._max_fails:
                            logging.error('Env %s failed too many times (%d times > %d)' %
                                          (p.name, n_failures, self._max_fails))
                            raise RuntimeError('Too many process failures.')
                        logging.warning('Env %s failed (%d times <= %d). restarting' %
                                        (p.name, n_failures, self._max_fails))
                        p = self._internal_env_runner.restart_process(p.name)
                        envs.append(p)

            if not self._kill_signal.value:
                new_transitions = self._update()
                for p in list(envs):
                    if new_transitions[p.name] == 0:
                        no_transitions[p.name] += 1
                    else:
                        no_transitions[p.name] = 0
                    if no_transitions[p.name] > 1200:  # ~20 min at 1s polls
                        logging.warning("Env %s hangs, so restarting" % p.name)
                        envs.remove(p)
                        os.kill(p.pid, signal.SIGTERM)
                        p = self._internal_env_runner.restart_process(p.name)
                        envs.append(p)
                        no_transitions[p.name] = 0

            if len(envs) == 0:
                break
            time.sleep(1)

    def start(self, save_load_lock):
        """Launch the supervision loop on a daemon thread."""
        self._p = Thread(target=self._run, args=(save_load_lock,), daemon=True)
        self._p.name = 'EnvRunnerThread'
        self._p.start()

    def wait(self):
        """Block until the supervision thread finishes."""
        if self._p.is_alive():
            self._p.join()

    def stop(self):
        """Signal all workers to die and join the supervision thread."""
        if self._p.is_alive():
            self._kill_signal.value = True
            self._p.join()

    def set_step(self, step):
        # Broadcast the current training step to workers.
        self._step_signal.value = step

    def set_eval_report(self, report):
        self._eval_report_signal.value = report

    def set_eval_epochs(self, epochs):
        self._eval_epochs_signal.value = epochs
|
| 224 |
+
|
external/yarr/yarr/runners/independent_env_runner.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
from typing import List
|
| 4 |
+
from typing import Union
|
| 5 |
+
|
| 6 |
+
from yarr.agents.agent import Agent
|
| 7 |
+
from yarr.envs.env import Env
|
| 8 |
+
from yarr.replay_buffer.replay_buffer import ReplayBuffer
|
| 9 |
+
from yarr.runners._independent_env_runner import _IndependentEnvRunner
|
| 10 |
+
from yarr.utils.rollout_generator import RolloutGenerator
|
| 11 |
+
from yarr.utils.stat_accumulator import StatAccumulator, SimpleAccumulator
|
| 12 |
+
from yarr.agents.agent import Summary
|
| 13 |
+
from helpers.custom_rlbench_env import CustomRLBenchEnv, CustomMultiTaskRLBenchEnv
|
| 14 |
+
|
| 15 |
+
from yarr.runners.env_runner import EnvRunner
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class IndependentEnvRunner(EnvRunner):
    """EnvRunner variant that evaluates tasks serially in this process.

    `start` builds the eval env from a flat config tuple and delegates to
    `_IndependentEnvRunner._run_eval_independent` instead of spawning
    worker processes.
    """

    def __init__(self,
                 train_env: Env,
                 agent: Agent,
                 train_replay_buffer: Union[ReplayBuffer, List[ReplayBuffer]],
                 num_train_envs: int,
                 num_eval_envs: int,
                 rollout_episodes: int,
                 eval_episodes: int,
                 training_iterations: int,
                 eval_from_eps_number: int,
                 episode_length: int,
                 eval_env: Union[Env, None] = None,
                 eval_replay_buffer: Union[ReplayBuffer, List[ReplayBuffer], None] = None,
                 stat_accumulator: Union[StatAccumulator, None] = None,
                 rollout_generator: RolloutGenerator = None,
                 weightsdir: str = None,
                 logdir: str = None,
                 max_fails: int = 10,
                 num_eval_runs: int = 1,
                 env_device: torch.device = None,
                 multi_task: bool = False):
        super().__init__(train_env, agent, train_replay_buffer, num_train_envs, num_eval_envs,
                         rollout_episodes, eval_episodes, training_iterations, eval_from_eps_number,
                         episode_length, eval_env, eval_replay_buffer, stat_accumulator,
                         rollout_generator, weightsdir, logdir, max_fails, num_eval_runs,
                         env_device, multi_task)

    def summaries(self) -> List[Summary]:
        """Pop accumulated stats (without transition counters) as summaries."""
        # BUG FIX: `change_case` was referenced but never imported in this
        # module, raising NameError at runtime; import the same helper the
        # sibling runner modules use.
        from yarr.utils.process_str import change_case

        summaries = []
        if self._stat_accumulator is not None:
            summaries.extend(self._stat_accumulator.pop())
        self._new_transitions = {'train_envs': 0, 'eval_envs': 0}
        summaries.extend(self._agent_summaries)

        # add current task_name to eval summaries .... argh this should be inside a helper function
        if hasattr(self._eval_env, '_task_class'):
            eval_task_name = change_case(self._eval_env._task_class.__name__)
        elif hasattr(self._eval_env, '_task_classes'):
            if self._current_task_id != -1:
                task_id = (self._current_task_id) % len(self._eval_env._task_classes)
                eval_task_name = change_case(self._eval_env._task_classes[task_id].__name__)
            else:
                eval_task_name = ''
        else:
            raise Exception('Neither task_class nor task_classes found in eval env')

        # multi-task summaries
        if eval_task_name and self._multi_task:
            for s in summaries:
                if 'eval' in s.name:
                    s.name = '%s/%s' % (s.name, eval_task_name)

        return summaries

    # serialized evaluator for individual tasks
    def start(self, weight,
              save_load_lock, writer_lock,
              env_config,
              device_idx,
              save_metrics,
              cinematic_recorder_cfg):
        """Build the eval env from `env_config` and run evaluation serially.

        NOTE(review): this deliberately overrides EnvRunner.start with a
        different signature; callers must use the serialized-eval interface.
        `env_config` is a positional tuple — multi-task configs carry a list
        of task classes in slot 0 plus a `swap_task_every` entry, so the two
        layouts have different lengths.
        """
        if hasattr(self, "_on_thread_start"):
            self._on_thread_start()

        multi_task = isinstance(env_config[0], list)
        if multi_task:
            eval_env = CustomMultiTaskRLBenchEnv(
                task_classes=env_config[0],
                observation_config=env_config[1],
                action_mode=env_config[2],
                dataset_root=env_config[3],
                episode_length=env_config[4],
                headless=env_config[5],
                swap_task_every=env_config[6],
                include_lang_goal_in_obs=env_config[7],
                time_in_state=env_config[8],
                record_every_n=env_config[9])
        else:
            eval_env = CustomRLBenchEnv(
                task_class=env_config[0],
                observation_config=env_config[1],
                action_mode=env_config[2],
                dataset_root=env_config[3],
                episode_length=env_config[4],
                headless=env_config[5],
                include_lang_goal_in_obs=env_config[6],
                time_in_state=env_config[7],
                record_every_n=env_config[8])

        self._internal_env_runner = _IndependentEnvRunner(
            self._train_env, eval_env, self._agent, self._timesteps, self._train_envs,
            self._eval_envs, self._rollout_episodes, self._eval_episodes,
            self._training_iterations, self._eval_from_eps_number, self._episode_length, self._kill_signal,
            self._step_signal, self._num_eval_episodes_signal,
            self._eval_epochs_signal, self._eval_report_signal,
            self.log_freq, self._rollout_generator, None,
            self.current_replay_ratio, self.target_replay_ratio,
            self._weightsdir, self._logdir,
            self._env_device, self._previous_loaded_weight_folder,
            num_eval_runs=self._num_eval_runs)

        stat_accumulator = SimpleAccumulator(eval_video_fps=30)
        self._internal_env_runner._run_eval_independent('eval_env',
                                                        stat_accumulator,
                                                        weight,
                                                        writer_lock,
                                                        True,
                                                        device_idx,
                                                        save_metrics,
                                                        cinematic_recorder_cfg)
|
external/yarr/yarr/runners/offline_train_runner.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import shutil
|
| 5 |
+
import time
|
| 6 |
+
from typing import List
|
| 7 |
+
from typing import Union
|
| 8 |
+
|
| 9 |
+
import psutil
|
| 10 |
+
import torch
|
| 11 |
+
import pandas as pd
|
| 12 |
+
from yarr.agents.agent import Agent
|
| 13 |
+
from yarr.replay_buffer.wrappers.pytorch_replay_buffer import \
|
| 14 |
+
PyTorchReplayBuffer
|
| 15 |
+
from yarr.utils.log_writer import LogWriter
|
| 16 |
+
from yarr.utils.stat_accumulator import StatAccumulator
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class OfflineTrainRunner():
    """Trains an agent purely from a pre-filled replay buffer (no env stepping).

    Responsibilities: periodic logging of agent/monitoring summaries,
    checkpoint saving with rotation, and resuming from the newest
    checkpoint found in ``weightsdir``.
    """

    def __init__(self,
                 agent: Agent,
                 wrapped_replay_buffer: PyTorchReplayBuffer,
                 train_device: torch.device,
                 stat_accumulator: Union[StatAccumulator, None] = None,
                 iterations: int = int(6e6),
                 logdir: str = '/tmp/yarr/logs',
                 logging_level: int = logging.INFO,
                 log_freq: int = 10,
                 weightsdir: str = '/tmp/yarr/weights',
                 num_weights_to_keep: int = 60,
                 save_freq: int = 100,
                 tensorboard_logging: bool = True,
                 csv_logging: bool = False,
                 load_existing_weights: bool = True,
                 rank: int = None,
                 world_size: int = None):
        self._agent = agent
        # NOTE: the original assigned self._wrapped_buffer twice; once is enough.
        self._wrapped_buffer = wrapped_replay_buffer
        self._stat_accumulator = stat_accumulator
        self._iterations = iterations
        self._logdir = logdir
        self._logging_level = logging_level
        self._log_freq = log_freq
        self._weightsdir = weightsdir
        self._num_weights_to_keep = num_weights_to_keep
        self._save_freq = save_freq

        self._train_device = train_device
        self._tensorboard_logging = tensorboard_logging
        self._csv_logging = csv_logging
        self._load_existing_weights = load_existing_weights
        self._rank = rank
        self._world_size = world_size

        self._writer = None
        if logdir is None:
            logging.info("'logdir' was None. No logging will take place.")
        else:
            self._writer = LogWriter(
                self._logdir, tensorboard_logging, csv_logging)

        if weightsdir is None:
            logging.info(
                "'weightsdir' was None. No weight saving will take place.")
        else:
            os.makedirs(self._weightsdir, exist_ok=True)

    def _save_model(self, i):
        """Save weights under <weightsdir>/<i> and rotate out the oldest save."""
        d = os.path.join(self._weightsdir, str(i))
        os.makedirs(d, exist_ok=True)
        self._agent.save_weights(d)

        # Remove oldest save so that at most ``num_weights_to_keep``
        # checkpoints are retained on disk.
        prev_dir = os.path.join(self._weightsdir, str(
            i - self._save_freq * self._num_weights_to_keep))
        if os.path.exists(prev_dir):
            shutil.rmtree(prev_dir)

    def _step(self, i, sampled_batch):
        """Run one agent update and return its 'total_losses' entry."""
        update_dict = self._agent.update(i, sampled_batch)
        total_losses = update_dict['total_losses']
        return total_losses

    def _get_resume_eval_epoch(self):
        """Return the last evaluated epoch recorded in the eval CSV, else 0."""
        starting_epoch = 0
        eval_csv_file = self._weightsdir.replace('weights', 'eval_data.csv')  # TODO(mohit): check if it's supposed be 'env_data.csv'
        if os.path.exists(eval_csv_file):
            eval_dict = pd.read_csv(eval_csv_file).to_dict()
            epochs = list(eval_dict['step'].values())
            return epochs[-1] if len(epochs) > 0 else starting_epoch
        else:
            return starting_epoch

    def start(self):
        """Main training loop: sample, update, log, checkpoint."""
        if hasattr(self, "_on_thread_start"):
            self._on_thread_start()
        else:
            logging.getLogger().setLevel(self._logging_level)

        self._agent = copy.deepcopy(self._agent)
        self._agent.build(training=True, device=self._train_device)

        # BUG FIX: start_iter was previously only assigned inside the
        # ``weightsdir is not None`` branch, so running without a weights
        # directory raised NameError below.
        start_iter = 0
        if self._weightsdir is not None:
            existing_weights = sorted([int(f) for f in os.listdir(self._weightsdir)])
            if self._load_existing_weights and len(existing_weights) > 0:
                resume_iteration = existing_weights[-1]
                self._agent.load_weights(os.path.join(self._weightsdir, str(resume_iteration)))
                start_iter = resume_iteration + 1
                if self._rank == 0:
                    logging.info(f"Resuming training from iteration {resume_iteration} ...")

        dataset = self._wrapped_buffer.dataset()
        data_iter = iter(dataset)

        process = psutil.Process(os.getpid())
        num_cpu = psutil.cpu_count()

        for i in range(start_iter, self._iterations):
            log_iteration = i % self._log_freq == 0 and i > 0

            if log_iteration:
                # Prime cpu_percent so the reading below covers this iteration.
                process.cpu_percent(interval=None)

            t = time.time()
            sampled_batch = next(data_iter)
            sample_time = time.time() - t

            # Move only tensor entries onto the training device
            # (isinstance instead of the original ``type(v) == torch.Tensor``).
            batch = {k: v.to(self._train_device)
                     for k, v in sampled_batch.items()
                     if isinstance(v, torch.Tensor)}
            t = time.time()
            loss = self._step(i, batch)
            step_time = time.time() - t

            if self._rank == 0:
                if log_iteration and self._writer is not None:
                    agent_summaries = self._agent.update_summaries()
                    self._writer.add_summaries(i, agent_summaries)

                    self._writer.add_scalar(
                        i, 'monitoring/memory_gb',
                        process.memory_info().rss * 1e-9)
                    self._writer.add_scalar(
                        i, 'monitoring/cpu_percent',
                        process.cpu_percent(interval=None) / num_cpu)

                    logging.info(f"Train Step {i:06d} | Loss: {loss:0.5f} | Sample time: {sample_time:0.6f} | Step time: {step_time:0.4f}.")

                # Guarded: self._writer is None when logdir was None.
                if self._writer is not None:
                    self._writer.end_iteration()

            if i % self._save_freq == 0 and self._weightsdir is not None:
                self._save_model(i)

        if self._rank == 0 and self._writer is not None:
            self._writer.close()
            logging.info('Stopping envs ...')

        self._wrapped_buffer.replay_buffer.shutdown()
|
external/yarr/yarr/runners/pytorch_train_runner.py
ADDED
|
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import shutil
|
| 5 |
+
import signal
|
| 6 |
+
import sys
|
| 7 |
+
import threading
|
| 8 |
+
import time
|
| 9 |
+
from multiprocessing import Lock
|
| 10 |
+
from typing import Optional, List
|
| 11 |
+
from typing import Union
|
| 12 |
+
|
| 13 |
+
import gc
|
| 14 |
+
import numpy as np
|
| 15 |
+
import psutil
|
| 16 |
+
import torch
|
| 17 |
+
import pandas as pd
|
| 18 |
+
from yarr.agents.agent import Agent
|
| 19 |
+
from yarr.replay_buffer.wrappers.pytorch_replay_buffer import \
|
| 20 |
+
PyTorchReplayBuffer
|
| 21 |
+
from yarr.runners.env_runner import EnvRunner
|
| 22 |
+
from yarr.runners.train_runner import TrainRunner
|
| 23 |
+
from yarr.utils.log_writer import LogWriter
|
| 24 |
+
from yarr.utils.stat_accumulator import StatAccumulator
|
| 25 |
+
from yarr.replay_buffer.prioritized_replay_buffer import PrioritizedReplayBuffer
|
| 26 |
+
|
| 27 |
+
NUM_WEIGHTS_TO_KEEP = 60
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class PyTorchTrainRunner(TrainRunner):
    """Online train runner: collects data via an EnvRunner while training
    the agent from one or more replay buffers.

    Supports multi-buffer sampling (one sub-batch per buffer, stacked along
    dim 1), replay-ratio throttling, prioritized-replay priority updates,
    checkpoint rotation, and resuming from the newest checkpoint.
    """

    def __init__(self,
                 agent: Agent,
                 env_runner: EnvRunner,
                 wrapped_replay_buffer: Union[
                     PyTorchReplayBuffer, List[PyTorchReplayBuffer]],
                 train_device: torch.device,
                 replay_buffer_sample_rates: List[float] = None,
                 stat_accumulator: Union[StatAccumulator, None] = None,
                 iterations: int = int(1e6),
                 num_train_envs: int = 1,
                 num_eval_envs: int = 1,
                 eval_episodes: int = 10,
                 logdir: str = '/tmp/yarr/logs',
                 log_freq: int = 10,
                 transitions_before_train: int = 1000,
                 weightsdir: str = '/tmp/yarr/weights',
                 save_freq: int = 100,
                 replay_ratio: Optional[float] = None,
                 tensorboard_logging: bool = True,
                 csv_logging: bool = False,
                 buffers_per_batch: int = -1,  # -1 = all
                 load_existing_weights: bool = True):
        super(PyTorchTrainRunner, self).__init__(
            agent, env_runner, wrapped_replay_buffer,
            stat_accumulator,
            iterations, logdir, log_freq, transitions_before_train, weightsdir,
            save_freq)

        env_runner.log_freq = log_freq
        env_runner.target_replay_ratio = replay_ratio
        # Normalise to a list so the rest of the class can treat single- and
        # multi-buffer setups identically.
        self._wrapped_buffer = wrapped_replay_buffer if isinstance(
            wrapped_replay_buffer, list) else [wrapped_replay_buffer]
        self._replay_buffer_sample_rates = (
            [1.0] if replay_buffer_sample_rates is None else
            replay_buffer_sample_rates)
        # BUG FIX: compare against the normalised list; ``len()`` on a single
        # (non-list) PyTorchReplayBuffer raised TypeError.
        if len(self._replay_buffer_sample_rates) != len(self._wrapped_buffer):
            logging.warning(
                'Numbers of replay buffers differs from sampling rates. Setting as uniform sampling.')
            self._replay_buffer_sample_rates = [1.0 / len(self._wrapped_buffer)] * len(self._wrapped_buffer)
        # BUG FIX: exact float equality could spuriously reject the uniform
        # fallback (e.g. 49 * (1/49) != 1.0 in IEEE-754).
        if not np.isclose(sum(self._replay_buffer_sample_rates), 1.0):
            raise ValueError('Sum of sampling rates should be 1.')

        self._train_device = train_device
        self._tensorboard_logging = tensorboard_logging
        self._csv_logging = csv_logging
        self._num_train_envs = num_train_envs
        self._num_eval_envs = num_eval_envs
        self._eval_episodes = eval_episodes
        self._load_existing_weights = load_existing_weights

        if replay_ratio is not None and replay_ratio < 0:
            raise ValueError("max_replay_ratio must be positive.")
        self._target_replay_ratio = replay_ratio

        self._writer = None
        if logdir is None:
            logging.info("'logdir' was None. No logging will take place.")
        else:
            self._writer = LogWriter(
                self._logdir, tensorboard_logging, csv_logging)
        if weightsdir is None:
            logging.info(
                "'weightsdir' was None. No weight saving will take place.")
        else:
            os.makedirs(self._weightsdir, exist_ok=True)
        # Same normalised-list fix as above.
        self._buffers_per_batch = buffers_per_batch if buffers_per_batch > 0 else len(self._wrapped_buffer)

    def _save_model(self, i):
        """Save weights under <weightsdir>/<i> (holding the save/load lock)
        and rotate out the oldest checkpoint."""
        with self._save_load_lock:
            d = os.path.join(self._weightsdir, str(i))
            os.makedirs(d, exist_ok=True)
            self._agent.save_weights(d)
            # Remove oldest save
            prev_dir = os.path.join(self._weightsdir, str(
                i - self._save_freq * NUM_WEIGHTS_TO_KEEP))
            if os.path.exists(prev_dir):
                shutil.rmtree(prev_dir)

    def _step(self, i, sampled_batch):
        """Run one agent update and push new priorities back into the buffers."""
        update_dict = self._agent.update(i, sampled_batch)
        if "priority" in update_dict:
            # BUG FIX: the original called np.numpy(...), which does not exist;
            # np.asarray is the correct conversion for non-tensor priorities.
            priority = (update_dict['priority'].cpu().detach().numpy()
                        if isinstance(update_dict['priority'], torch.Tensor)
                        else np.asarray(update_dict['priority']))
        else:
            priority = None
        indices = sampled_batch['indices'].cpu().detach().numpy()
        acc_bs = 0
        for wb_idx, wb in enumerate(self._wrapped_buffer):
            bs = wb.replay_buffer.batch_size
            if 'priority' in update_dict:
                indices_ = indices[:, wb_idx]
                if hasattr(wb, "replay_buffer"):
                    if len(priority.shape) > 1:
                        # One priority column per buffer.
                        priority_ = priority[:, wb_idx]
                    else:
                        # legacy version: priorities concatenated along batch.
                        priority_ = priority[acc_bs: acc_bs + bs]
                    wb.replay_buffer.set_priority(indices_, priority_)
            acc_bs += bs

    def _signal_handler(self, sig, frame):
        """SIGINT handler: stop env workers and replay buffers, then exit."""
        if threading.current_thread().name != 'MainThread':
            return
        logging.info('SIGINT captured. Shutting down.'
                     'This may take a few seconds.')
        self._env_runner.stop()
        [r.replay_buffer.shutdown() for r in self._wrapped_buffer]
        sys.exit(0)

    def _get_add_counts(self):
        """Per-buffer transition counts as an array."""
        return np.array([
            r.replay_buffer.add_count for r in self._wrapped_buffer])

    def _get_sum_add_counts(self):
        """Total transitions added across all buffers."""
        return sum([
            r.replay_buffer.add_count for r in self._wrapped_buffer])

    def _get_resume_eval_epoch(self):
        """Return the last evaluated epoch recorded in the eval CSV, else 0."""
        starting_epoch = 0
        eval_csv_file = self._weightsdir.replace('weights', 'eval_data.csv')  # TODO(mohit): check if it's supposed be 'env_data.csv'
        if os.path.exists(eval_csv_file):
            eval_dict = pd.read_csv(eval_csv_file).to_dict()
            epochs = list(eval_dict['step'].values())
            return epochs[-1] if len(epochs) > 0 else starting_epoch
        else:
            return starting_epoch

    def start(self):
        """Main loop: launch envs, wait for warm-up transitions, then train."""
        signal.signal(signal.SIGINT, self._signal_handler)

        self._save_load_lock = Lock()

        # Kick off the environments
        self._env_runner.start(self._save_load_lock)

        self._agent = copy.deepcopy(self._agent)
        self._agent.build(training=True, device=self._train_device)

        # BUG FIX: previously start_iter was undefined if weightsdir was None.
        start_iter = 0
        if self._weightsdir is not None:
            existing_weights = sorted([int(f) for f in os.listdir(self._weightsdir)])
            if (not self._load_existing_weights) or len(existing_weights) == 0:
                self._save_model(0)
                start_iter = 0
            else:
                resume_iteration = existing_weights[-1]
                self._agent.load_weights(os.path.join(self._weightsdir, str(resume_iteration)))
                start_iter = resume_iteration + 1
                print(f"Resuming training from iteration {resume_iteration} ...")

                if self._num_eval_envs > 0:
                    # Also resume the evaluation bookkeeping from its CSV.
                    eval_epoch = self._get_resume_eval_epoch()
                    self._env_runner.set_eval_epochs(eval_epoch)
                    self._writer.set_resumed_from_prev_run(True)
                    print(f"Resuming evaluation from epoch {eval_epoch} ...")

        while (np.any(self._get_add_counts() < self._transitions_before_train)):
            time.sleep(1)
            logging.info(
                'Waiting for %d samples before training. Currently have %s.' %
                (self._transitions_before_train, str(self._get_add_counts())))

        datasets = [r.dataset() for r in self._wrapped_buffer]
        data_iter = [iter(d) for d in datasets]

        init_replay_size = self._get_sum_add_counts().astype(float)
        batch_times_buffers_per_sample = sum([
            r.replay_buffer.batch_size for r in self._wrapped_buffer[:self._buffers_per_batch]])
        process = psutil.Process(os.getpid())
        num_cpu = psutil.cpu_count()

        for i in range(start_iter, self._iterations):
            self._env_runner.set_step(i)

            if self._num_train_envs > 0 or self._num_eval_envs == 0:
                log_iteration = i % self._log_freq == 0 and i > 0
            else:
                # Eval-only mode: log whenever an eval report is pending.
                num_eval_episodes = self._env_runner._num_eval_episodes_signal.value
                log_iteration = self._env_runner._eval_report_signal.value and num_eval_episodes > 0

            if log_iteration:
                # Prime cpu_percent so the reading below covers this iteration.
                process.cpu_percent(interval=None)

            def get_replay_ratio():
                # Ratio of samples consumed by training vs. newly collected.
                size_used = batch_times_buffers_per_sample * i
                size_added = (
                    self._get_sum_add_counts()
                    - init_replay_size
                )
                replay_ratio = size_used / (size_added + 1e-6)
                return replay_ratio

            if self._target_replay_ratio is not None:
                # wait for env_runner collecting enough samples
                while True:
                    replay_ratio = get_replay_ratio()
                    self._env_runner.current_replay_ratio.value = replay_ratio
                    if replay_ratio < self._target_replay_ratio:
                        break
                    time.sleep(1)
                    logging.debug(
                        'Waiting for replay_ratio %f to be less than %f.' %
                        (replay_ratio, self._target_replay_ratio))
                del replay_ratio

            t = time.time()

            # Draw one sub-batch from each of ``buffers_per_batch`` randomly
            # chosen buffers and stack per key along a new buffer dimension.
            sampled_task_ids = np.random.choice(
                range(len(datasets)), self._buffers_per_batch, replace=False)
            sampled_batch = [next(data_iter[j]) for j in sampled_task_ids]
            result = {}
            for key in sampled_batch[0]:
                result[key] = torch.stack([d[key] for d in sampled_batch], 1)
            sampled_batch = result
            sample_time = time.time() - t

            batch = {k: v.to(self._train_device) for k, v in sampled_batch.items()}
            t = time.time()
            self._step(i, batch)
            step_time = time.time() - t

            if log_iteration and self._writer is not None:
                replay_ratio = get_replay_ratio()
                logging.info('Train Step %d. Eval Epoch %d. Sample time: %s. Step time: %s. Replay ratio: %s.' % (
                    i, self._env_runner._eval_epochs_signal.value, sample_time, step_time, replay_ratio))
                agent_summaries = self._agent.update_summaries()
                env_summaries = self._env_runner.summaries()

                # agent summaries
                self._writer.add_summaries(i, agent_summaries)

                # env summaries
                self._writer.add_summaries(self._env_runner._eval_epochs_signal.value, env_summaries)

                for r_i, wrapped_buffer in enumerate(self._wrapped_buffer):
                    self._writer.add_scalar(
                        i, 'replay%d/add_count' % r_i,
                        wrapped_buffer.replay_buffer.add_count)
                    self._writer.add_scalar(
                        i, 'replay%d/size' % r_i,
                        wrapped_buffer.replay_buffer.replay_capacity
                        if wrapped_buffer.replay_buffer.is_full()
                        else wrapped_buffer.replay_buffer.add_count)

                self._writer.add_scalar(
                    i, 'replay/replay_ratio', replay_ratio)
                self._writer.add_scalar(
                    i, 'replay/update_to_insert_ratio',
                    float(i) / float(
                        self._get_sum_add_counts() -
                        init_replay_size + 1e-6))

                self._writer.add_scalar(
                    i, 'monitoring/sample_time_per_item',
                    sample_time / batch_times_buffers_per_sample)
                self._writer.add_scalar(
                    i, 'monitoring/train_time_per_item',
                    step_time / batch_times_buffers_per_sample)
                self._writer.add_scalar(
                    i, 'monitoring/memory_gb',
                    process.memory_info().rss * 1e-9)
                self._writer.add_scalar(
                    i, 'monitoring/cpu_percent',
                    process.cpu_percent(interval=None) / num_cpu)

                self._env_runner.set_eval_report(False)

            # Guarded: self._writer is None when logdir was None.
            if self._writer is not None:
                self._writer.end_iteration()

            if i % self._save_freq == 0 and self._weightsdir is not None:
                self._save_model(i)

        if self._writer is not None:
            self._writer.close()

        logging.info('Stopping envs ...')
        self._env_runner.stop()
        [r.replay_buffer.shutdown() for r in self._wrapped_buffer]
|
external/yarr/yarr/runners/train_runner.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import abstractmethod, ABC
|
| 2 |
+
from typing import Union, List
|
| 3 |
+
|
| 4 |
+
from yarr.agents.agent import Agent
|
| 5 |
+
from yarr.replay_buffer.wrappers import WrappedReplayBuffer
|
| 6 |
+
from yarr.runners.env_runner import EnvRunner
|
| 7 |
+
from yarr.utils.stat_accumulator import StatAccumulator
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class TrainRunner(ABC):
    """Abstract base for training runners.

    Holds the configuration shared by all concrete runners (agent, env
    runner, replay buffer, logging/saving cadence); subclasses implement
    the actual training loop in :meth:`start`.
    """

    def __init__(self,
                 agent: Agent,
                 env_runner: EnvRunner,
                 wrapped_replay_buffer: WrappedReplayBuffer,
                 stat_accumulator: Union[StatAccumulator, None] = None,
                 iterations: int = int(1e6),
                 logdir: str = '/tmp/yarr/logs',
                 log_freq: int = 500,
                 transitions_before_train: int = 1000,
                 weightsdir: str = '/tmp/yarr/weights',
                 save_freq: int = 100,
                 ):
        # Core components.
        self._agent = agent
        self._env_runner = env_runner
        self._wrapped_buffer = wrapped_replay_buffer
        self._stat_accumulator = stat_accumulator
        # Loop / logging / checkpoint configuration.
        self._iterations = iterations
        self._logdir = logdir
        self._log_freq = log_freq
        self._transitions_before_train = transitions_before_train
        self._weightsdir = weightsdir
        self._save_freq = save_freq

    @abstractmethod
    def start(self):
        """Run the training loop; must be provided by subclasses."""
        pass
|
external/yarr/yarr/utils/__init__.py
ADDED
|
File without changes
|
external/yarr/yarr/utils/log_writer.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import csv
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
from collections import OrderedDict
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from yarr.agents.agent import ScalarSummary, HistogramSummary, ImageSummary, \
|
| 9 |
+
VideoSummary, TextSummary
|
| 10 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class LogWriter(object):
    """Writes training/eval summaries to TensorBoard and/or CSV files.

    Scalars whose name contains 'env', 'eval' or 'test' go to the env CSV;
    everything else goes to the train CSV. Rows are buffered per iteration
    and flushed by :meth:`end_iteration`.
    """

    def __init__(self,
                 logdir: str,
                 tensorboard_logging: bool,
                 csv_logging: bool,
                 train_csv: str = 'train_data.csv',
                 env_csv: str = 'env_data.csv'):
        self._tensorboard_logging = tensorboard_logging
        self._csv_logging = csv_logging
        os.makedirs(logdir, exist_ok=True)
        if tensorboard_logging:
            self._tf_writer = SummaryWriter(logdir)
        if csv_logging:
            # Separate dicts (the original aliased row/prev to one object;
            # harmless, but separate objects are clearer).
            self._train_row_data = OrderedDict()
            self._train_prev_row_data = OrderedDict()
            self._train_csv_file = os.path.join(logdir, train_csv)
            self._env_row_data = OrderedDict()
            self._env_prev_row_data = OrderedDict()
            self._env_csv_file = os.path.join(logdir, env_csv)
            self._train_field_names = None
            self._env_field_names = None

    def add_scalar(self, i, name, value):
        """Record a scalar at step ``i`` under ``name``."""
        if self._tensorboard_logging:
            self._tf_writer.add_scalar(name, value, i)
        if self._csv_logging:
            # Route env/eval/test scalars to the env CSV, the rest to train.
            if 'env' in name or 'eval' in name or 'test' in name:
                if len(self._env_row_data) == 0:
                    self._env_row_data['step'] = i
                self._env_row_data[name] = value.item() if isinstance(
                    value, torch.Tensor) else value
            else:
                if len(self._train_row_data) == 0:
                    self._train_row_data['step'] = i
                self._train_row_data[name] = value.item() if isinstance(
                    value, torch.Tensor) else value

    def add_summaries(self, i, summaries):
        """Record a list of Summary objects (scalar/histogram/image/video/text)."""
        for summary in summaries:
            try:
                if isinstance(summary, ScalarSummary):
                    self.add_scalar(i, summary.name, summary.value)
                elif self._tensorboard_logging:
                    if isinstance(summary, HistogramSummary):
                        self._tf_writer.add_histogram(
                            summary.name, summary.value, i)
                    elif isinstance(summary, ImageSummary):
                        # Only grab first item in batch
                        v = (summary.value if summary.value.ndim == 3 else
                             summary.value[0])
                        self._tf_writer.add_image(summary.name, v, i)
                    elif isinstance(summary, VideoSummary):
                        # Only grab first item in batch
                        v = (summary.value if summary.value.ndim == 5 else
                             np.array([summary.value]))
                        self._tf_writer.add_video(
                            summary.name, v, i, fps=summary.fps)
                    elif isinstance(summary, TextSummary):
                        self._tf_writer.add_text(summary.name, summary.value, i)
            except Exception as e:
                logging.error('Error on summary: %s' % summary.name)
                raise e

    def _flush_row(self, csv_file, row_data, prev_row_data, field_names):
        """Append ``row_data`` to ``csv_file``; returns the new field names.

        Writes the header on first creation. When logging faster than new
        summaries arrive, fills keys missing from this row with their value
        from the previous row.
        """
        should_write_header = not os.path.exists(csv_file)
        with open(csv_file, mode='a+') as csv_f:
            names = row_data.keys()
            writer = csv.DictWriter(csv_f, fieldnames=names)
            if should_write_header:
                if field_names is None:
                    writer.writeheader()
            else:
                # BUG FIX: guard against field_names being None (file left
                # over from a previous run) — set(None) raised TypeError.
                if field_names is not None and not np.array_equal(
                        field_names, row_data.keys()):
                    # Special case when we are logging faster than new
                    # summaries are coming in.
                    missing_keys = list(set(field_names) - set(
                        row_data.keys()))
                    for mk in missing_keys:
                        row_data[mk] = prev_row_data[mk]
            field_names = names
            try:
                writer.writerow(row_data)
            except Exception as e:
                print(e)
        return field_names

    def end_iteration(self):
        """Flush the buffered train and env rows (if any) to their CSVs."""
        # write train data
        if self._csv_logging and len(self._train_row_data) > 0:
            self._train_field_names = self._flush_row(
                self._train_csv_file, self._train_row_data,
                self._train_prev_row_data, self._train_field_names)
            self._train_prev_row_data = self._train_row_data
            self._train_row_data = OrderedDict()

        # write env data (also eval or test during evaluation)
        if self._csv_logging and len(self._env_row_data) > 0:
            self._env_field_names = self._flush_row(
                self._env_csv_file, self._env_row_data,
                self._env_prev_row_data, self._env_field_names)
            self._env_prev_row_data = self._env_row_data
            self._env_row_data = OrderedDict()

    def close(self):
        """Close the TensorBoard writer (no-op when TB logging is off)."""
        if self._tensorboard_logging:
            self._tf_writer.close()
|
external/yarr/yarr/utils/multi_task_rollout_generator.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from multiprocessing import Value
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from yarr.agents.agent import Agent
|
| 6 |
+
from yarr.envs.env import Env
|
| 7 |
+
from yarr.envs.multi_task_env import MultiTaskEnv
|
| 8 |
+
from yarr.utils.transition import ReplayTransition
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class RolloutGenerator(object):
    """Rolls an agent out in a multi-task env, yielding ReplayTransitions."""

    def _get_type(self, x):
        # Demote float64 observations to float32 (presumably to save replay
        # memory — TODO confirm); all other dtypes pass through unchanged.
        if x.dtype == np.float64:
            return np.float32
        return x.dtype

    def generator(self, step_signal: Value, env: MultiTaskEnv, agent: Agent,
                  episode_length: int, timesteps: int, eval: bool):
        """Run one episode of at most ``episode_length`` steps.

        Keeps a rolling window of the last ``timesteps`` observations per key
        and feeds it to the agent. Yields one ReplayTransition per step and
        returns early once the env signals terminal / needs_reset.
        """
        obs = env.reset()
        agent.reset()
        # Rolling history, initialised by repeating the reset observation.
        obs_history = {k: [np.array(v, dtype=self._get_type(v))] * timesteps for k, v in obs.items()}
        for step in range(episode_length):

            # Add a leading batch dimension of 1 for the agent.
            prepped_data = {k: np.array([v]) for k, v in obs_history.items()}

            act_result = agent.act(step_signal.value, prepped_data,
                                   deterministic=eval)

            # Convert to np if not already
            agent_obs_elems = {k: np.array(v) for k, v in
                               act_result.observation_elements.items()}
            agent_extra_elems = {k: np.array(v) for k, v in
                                 act_result.replay_elements.items()}

            transition = env.step(act_result)
            timeout = False
            if step == episode_length - 1:
                # If last transition, and not terminal, then we timed out
                timeout = not transition.terminal
                if timeout:
                    transition.terminal = True
                    if "needs_reset" in transition.info:
                        transition.info["needs_reset"] = True

            # Merge agent-produced observation elements into the CURRENT obs
            # (the one stored in the replay transition below).
            obs.update(agent_obs_elems)
            obs_tp1 = dict(transition.observation)

            # Slide the history window forward by one step.
            for k in obs_history.keys():
                obs_history[k].append(transition.observation[k])
                obs_history[k].pop(0)

            transition.info["active_task_id"] = env.active_task_id

            replay_transition = ReplayTransition(
                obs, act_result.action, transition.reward,
                transition.terminal,
                timeout, obs_tp1, agent_extra_elems,
                transition.info)

            # The new observation becomes the current one for the next step.
            obs = transition.observation
            yield replay_transition

            if transition.info.get("needs_reset", transition.terminal):
                return
|
external/yarr/yarr/utils/observation_type.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Type
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class ObservationElement(object):
    """Describes one named observation entry: its name, array shape, and dtype.

    Used to declare the layout of observations stored in replay buffers.
    """

    def __init__(self, name: str, shape: tuple, type: Type[np.dtype]):
        self.name = name
        self.shape = shape
        # NOTE: parameter shadows the builtin `type`; name kept for
        # backward compatibility with existing keyword call sites.
        self.type = type

    def __repr__(self) -> str:
        # Aid debugging when elements are printed during buffer setup.
        return '%s(name=%r, shape=%r, type=%r)' % (
            type(self).__name__, self.name, self.shape, self.type)
|
external/yarr/yarr/utils/process_str.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import reduce
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def change_case(str):
    """Convert a CamelCase string to snake_case.

    E.g. ``"OpenDrawer"`` -> ``"open_drawer"``.

    Returns '' for empty input; the bare ``reduce()`` would otherwise raise
    ``TypeError: reduce() of empty iterable with no initial value``.
    (Parameter name shadows the builtin ``str``; kept for compatibility.)
    """
    if not str:
        return ''
    # Insert '_' before each uppercase letter (except the first), then lower.
    return reduce(lambda x, y: x + ('_' if y.isupper() else '') + y, str).lower()
|
external/yarr/yarr/utils/rollout_generator.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from multiprocessing import Value
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
from yarr.agents.agent import Agent
|
| 6 |
+
from yarr.envs.env import Env
|
| 7 |
+
from yarr.utils.transition import ReplayTransition
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class RolloutGenerator(object):
    """Steps an agent through an environment for one episode, yielding a
    ReplayTransition per step.

    NOTE(review): ``generator`` reads ``self._env_device`` (the torch device
    observations are moved to), but nothing in this class ever sets it —
    presumably it is attached externally by the env runner; confirm.
    """

    def _get_type(self, x):
        # Replay storage uses float32; downcast float64 observations.
        if x.dtype == np.float64:
            return np.float32
        return x.dtype

    def generator(self, step_signal: Value, env: Env, agent: Agent,
                  episode_length: int, timesteps: int,
                  eval: bool, eval_demo_seed: int = 0,
                  record_enabled: bool = False):
        """Roll out one episode of at most `episode_length` steps.

        Args:
            step_signal: shared counter of global training steps, passed to
                ``agent.act()`` on every call.
            env: environment to act in.
            agent: acting policy.
            episode_length: hard cap on the number of env steps.
            timesteps: number of stacked past observations fed to the agent.
            eval: if True, reset to a fixed demo and act deterministically.
            eval_demo_seed: demo index used when `eval` is True.
            record_enabled: if True, finalize scene recording at episode end.

        Yields:
            ReplayTransition for each step taken.
        """
        if eval:
            obs = env.reset_to_demo(eval_demo_seed)
        else:
            obs = env.reset()

        agent.reset()
        # History of the last `timesteps` observations per key, seeded with
        # the initial observation repeated (float64 downcast via _get_type).
        obs_history = {k: [np.array(v, dtype=self._get_type(v))] * timesteps for k, v in obs.items()}
        for step in range(episode_length):

            # Leading batch dimension of 1; history axis follows.
            prepped_data = {k: torch.tensor(np.array(v)[None], device=self._env_device) for k, v in obs_history.items()}

            act_result = agent.act(step_signal.value, prepped_data,
                                   deterministic=eval)

            # Convert to np if not already
            agent_obs_elems = {k: np.array(v) for k, v in
                               act_result.observation_elements.items()}
            extra_replay_elements = {k: np.array(v) for k, v in
                                     act_result.replay_elements.items()}

            transition = env.step(act_result)
            obs_tp1 = dict(transition.observation)
            timeout = False
            if step == episode_length - 1:
                # If last transition, and not terminal, then we timed out
                timeout = not transition.terminal
                if timeout:
                    transition.terminal = True
                    if "needs_reset" in transition.info:
                        transition.info["needs_reset"] = True

            # The stored observation is the env obs plus any agent-produced
            # observation elements and extra replay elements.
            obs_and_replay_elems = {}
            obs_and_replay_elems.update(obs)
            obs_and_replay_elems.update(agent_obs_elems)
            obs_and_replay_elems.update(extra_replay_elements)

            for k in obs_history.keys():
                obs_history[k].append(transition.observation[k])
                obs_history[k].pop(0)

            # Tag the transition with its source task for per-task stats.
            transition.info["active_task_id"] = env.active_task_id

            replay_transition = ReplayTransition(
                obs_and_replay_elems, act_result.action, transition.reward,
                transition.terminal, timeout, summaries=transition.summaries,
                info=transition.info)

            if transition.terminal or timeout:
                # If the agent gives us observations then we need to call act
                # one last time (i.e. acting in the terminal state).
                if len(act_result.observation_elements) > 0:
                    # Consistency: batch the same way as the in-loop call
                    # (was `torch.tensor([v], ...)`, which builds the tensor
                    # from a list of arrays).
                    prepped_data = {k: torch.tensor(np.array(v)[None], device=self._env_device) for k, v in obs_history.items()}
                    act_result = agent.act(step_signal.value, prepped_data,
                                           deterministic=eval)
                    agent_obs_elems_tp1 = {k: np.array(v) for k, v in
                                           act_result.observation_elements.items()}
                    obs_tp1.update(agent_obs_elems_tp1)
                replay_transition.final_observation = obs_tp1

            # BUGFIX: the original `record_enabled and A or B or C` parses as
            # `(record_enabled and A) or B or C`, so recording was finalized
            # even when record_enabled was False. Parenthesize the intent.
            if record_enabled and (transition.terminal or timeout or
                                   step == episode_length - 1):
                env.env._action_mode.arm_action_mode.record_end(env.env._scene,
                                                                steps=60, step_scene=True)

            obs = dict(transition.observation)
            yield replay_transition

            if transition.info.get("needs_reset", transition.terminal):
                return
|
external/yarr/yarr/utils/stat_accumulator.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from multiprocessing import Lock
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
from yarr.agents.agent import Summary, ScalarSummary
|
| 6 |
+
from yarr.utils.transition import ReplayTransition
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class StatAccumulator(object):
    """Interface for accumulating rollout statistics into loggable summaries."""

    def step(self, transition: ReplayTransition, eval: bool):
        """Record one transition (`eval` selects the evaluation stream)."""
        pass

    def pop(self) -> List[Summary]:
        """Return accumulated summaries and clear the internal state."""
        pass

    def peak(self) -> List[Summary]:
        """Return accumulated summaries without clearing state.

        NOTE(review): conventional spelling is "peek"; the name is kept
        because subclasses and callers use `peak`.
        """
        pass

    def reset(self) -> None:
        """Discard all accumulated statistics."""
        pass
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Metric(object):
    """Accumulates a running per-episode total and keeps the totals of
    completed episodes for summary statistics (min/max/mean/median/std)."""

    def __init__(self):
        # Totals of completed episodes, oldest first.
        self._history = []
        # Total accumulated for the episode currently in progress.
        self._running = 0

    def update(self, value):
        """Add `value` to the in-progress episode total."""
        self._running += value

    def next(self):
        """Close out the in-progress episode and start a fresh one."""
        self._history.append(self._running)
        self._running = 0

    def reset(self):
        """Drop all completed-episode totals (in-progress total is kept)."""
        self._history.clear()

    def min(self):
        return np.min(self._history)

    def max(self):
        return np.max(self._history)

    def mean(self):
        return np.mean(self._history)

    def median(self):
        return np.median(self._history)

    def std(self):
        return np.std(self._history)

    def __len__(self):
        return len(self._history)

    def __getitem__(self, i):
        return self._history[i]
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class _SimpleAccumulator(StatAccumulator):
    """Accumulates episode returns/lengths and env summaries for one stream
    (e.g. train or eval), emitting ScalarSummary objects under `prefix`."""

    def __init__(self, prefix, eval_video_fps: int = 30,
                 mean_only: bool = True):
        self._prefix = prefix
        # Stored but not used within this class.
        self._eval_video_fps = eval_video_fps
        # If True, emit only the mean of each metric; otherwise min/max/
        # mean/median/std.
        self._mean_only = mean_only
        # Guards all mutation: step() may run in a different process/thread
        # from pop()/peak().
        self._lock = Lock()
        self._episode_returns = Metric()
        self._episode_lengths = Metric()
        self._summaries = []
        # Lifetime transition count; survives _reset_data(), cleared only
        # by reset().
        self._transitions = 0

    def _reset_data(self):
        with self._lock:
            self._episode_returns.reset()
            self._episode_lengths.reset()
            self._summaries.clear()

    def step(self, transition: ReplayTransition, eval: bool):
        with self._lock:
            self._transitions += 1
            self._episode_returns.update(transition.reward)
            self._episode_lengths.update(1)
            # Episode boundary: roll the running totals into the history.
            if transition.terminal:
                self._episode_returns.next()
                self._episode_lengths.next()
            self._summaries.extend(list(transition.summaries))

    def _get(self) -> List[Summary]:
        """Build the summary list from the current statistics (no clearing)."""
        sums = []

        if self._mean_only:
            stat_keys = ["mean"]
        else:
            stat_keys = ["min", "max", "mean", "median", "std"]
        names = ["return", "length"]
        metrics = [self._episode_returns, self._episode_lengths]
        for name, metric in zip(names, metrics):
            for stat_key in stat_keys:
                if self._mean_only:
                    assert stat_key == "mean"
                    sum_name = '%s/%s' % (self._prefix, name)
                else:
                    sum_name = '%s/%s/%s' % (self._prefix, name, stat_key)
                # Dispatch to Metric.min/max/mean/median/std by name.
                sums.append(
                    ScalarSummary(sum_name, getattr(metric, stat_key)()))
        sums.append(ScalarSummary(
            '%s/total_transitions' % self._prefix, self._transitions))
        sums.extend(self._summaries)
        return sums

    def pop(self) -> List[Summary]:
        data = []
        # Only emit once more than one episode has completed (avoids
        # single-sample statistics); clears accumulated data afterwards.
        if len(self._episode_returns) > 1:
            data = self._get()
            self._reset_data()
        return data

    def peak(self) -> List[Summary]:
        return self._get()

    def reset(self):
        self._transitions = 0
        self._reset_data()
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class SimpleAccumulator(StatAccumulator):
    """Routes transitions to a train or eval sub-accumulator and merges
    their summaries on pop()/peak()."""

    def __init__(self, eval_video_fps: int = 30, mean_only: bool = True):
        self._train_acc = _SimpleAccumulator(
            'train_envs', eval_video_fps, mean_only=mean_only)
        self._eval_acc = _SimpleAccumulator(
            'eval_envs', eval_video_fps, mean_only=mean_only)

    def step(self, transition: ReplayTransition, eval: bool):
        # Pick the stream the transition belongs to.
        target = self._eval_acc if eval else self._train_acc
        target.step(transition, eval)

    def pop(self) -> List[Summary]:
        return [*self._train_acc.pop(), *self._eval_acc.pop()]

    def peak(self) -> List[Summary]:
        return [*self._train_acc.peak(), *self._eval_acc.peak()]

    def reset(self) -> None:
        for acc in (self._train_acc, self._eval_acc):
            acc.reset()
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class MultiTaskAccumulator(StatAccumulator):
    """Keeps one train and one eval accumulator per task, plus an aggregate
    accumulator over all training tasks."""

    def __init__(self, num_tasks,
                 eval_video_fps: int = 30, mean_only: bool = True,
                 train_prefix: str = 'train_task',
                 eval_prefix: str = 'eval_task'):
        self._train_accs = [
            _SimpleAccumulator('%s%d/envs' % (train_prefix, i),
                               eval_video_fps, mean_only=mean_only)
            for i in range(num_tasks)]
        self._eval_accs = [
            _SimpleAccumulator('%s%d/envs' % (eval_prefix, i),
                               eval_video_fps, mean_only=mean_only)
            for i in range(num_tasks)]
        # Aggregate over every training task.
        self._train_accs_mean = _SimpleAccumulator(
            '%s_summary/envs' % train_prefix, eval_video_fps,
            mean_only=mean_only)

    def step(self, transition: ReplayTransition, eval: bool):
        task_id = transition.info["active_task_id"]
        if eval:
            self._eval_accs[task_id].step(transition, eval)
        else:
            self._train_accs[task_id].step(transition, eval)
            self._train_accs_mean.step(transition, eval)

    def pop(self) -> List[Summary]:
        combined = self._train_accs_mean.pop()
        for acc in self._train_accs:
            combined.extend(acc.pop())
        for acc in self._eval_accs:
            combined.extend(acc.pop())
        return combined

    def peak(self) -> List[Summary]:
        combined = self._train_accs_mean.peak()
        for acc in self._train_accs:
            combined.extend(acc.peak())
        for acc in self._eval_accs:
            combined.extend(acc.peak())
        return combined

    def reset(self) -> None:
        self._train_accs_mean.reset()
        for acc in self._train_accs:
            acc.reset()
        for acc in self._eval_accs:
            acc.reset()
|
external/yarr/yarr/utils/transition.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
from yarr.agents.agent import Summary
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class Transition(object):
    """Result of a single environment step.

    Attributes:
        observation: dict of observation arrays after the step.
        reward: scalar reward for the step.
        terminal: whether the episode ended at this step.
        info: auxiliary metadata; defaults to an empty dict.
        summaries: Summary objects emitted by the env; defaults to [].
    """

    def __init__(self, observation: dict, reward: float, terminal: bool,
                 info: dict = None, summaries: List[Summary] = None):
        self.observation = observation
        self.reward = reward
        self.terminal = terminal
        # `x or default` also replaces falsy empties ({} / []) with fresh
        # defaults, so callers never see None here.
        self.info = info or {}
        self.summaries = summaries or []
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ReplayTransition(object):
    """A transition destined for the replay buffer.

    Attributes:
        observation: dict of arrays observed before acting.
        action: action taken at this step.
        reward: reward received.
        terminal: True if the episode ended here (including timeouts).
        timeout: True if termination was due to hitting the step limit.
        final_observation: only populated on the last timestep of an episode.
        summaries: Summary objects to log with this transition; defaults [].
        info: auxiliary metadata; defaults to an empty dict.
    """

    def __init__(self, observation: dict, action: np.ndarray,
                 reward: float, terminal: bool, timeout: bool,
                 final_observation: dict = None,
                 summaries: List[Summary] = None,
                 info: dict = None):
        self.observation = observation
        self.action = action
        self.reward = reward
        self.terminal = terminal
        self.timeout = timeout
        # final only populated on last timestep
        self.final_observation = final_observation
        self.summaries = summaries or []
        # Consistency fix: default to an empty dict (as Transition does) so
        # consumers can read `t.info` without a None check. Previously this
        # stored None when no info was supplied.
        self.info = info or {}
external/yarr/yarr/utils/video_utils.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
from pyrep.objects.dummy import Dummy
|
| 4 |
+
from pyrep.objects.vision_sensor import VisionSensor
|
| 5 |
+
from rlbench import Environment
|
| 6 |
+
from rlbench.backend.observation import Observation
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class CameraMotion(object):
    """Base class for camera trajectories used while recording rollouts."""

    def __init__(self, cam: VisionSensor):
        self.cam = cam

    def step(self):
        """Advance the camera by one frame's worth of motion."""
        raise NotImplementedError()

    def save_pose(self):
        # Remember the current camera pose for a later restore_pose().
        # NOTE(review): calling restore_pose() before any save_pose() raises
        # AttributeError because _prev_pose is only set here.
        self._prev_pose = self.cam.get_pose()

    def restore_pose(self):
        self.cam.set_pose(self._prev_pose)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class CircleCameraMotion(CameraMotion):
    """Orbits the camera by rotating its parent dummy about the z-axis."""

    def __init__(self, cam: VisionSensor, origin: Dummy,
                 speed: float, init_rotation: float = np.deg2rad(180)):
        super().__init__(cam)
        self.origin = origin
        self.speed = speed # in radians
        # Apply the initial yaw offset once at construction time.
        self.origin.rotate([0, 0, init_rotation])

    def step(self):
        # Rotate `speed` radians about z for each captured frame.
        self.origin.rotate([0, 0, self.speed])
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class TaskRecorder(object):
    """Buffers frames from a moving camera during a rollout and writes them
    out as an mp4 with the language goal overlaid."""

    def __init__(self, env: Environment, cam_motion: CameraMotion, fps=30):
        self._env = env
        self._cam_motion = cam_motion
        self._fps = fps
        # NOTE(review): `_snaps` is never used within this class — possibly
        # leftover from an earlier version.
        self._snaps = []
        # Frames (uint8 RGB) captured since the last save().
        self._current_snaps = []

    def take_snap(self, obs: Observation):
        """Advance the camera one motion step and capture an RGB frame."""
        self._cam_motion.step()
        self._current_snaps.append(
            (self._cam_motion.cam.capture_rgb() * 255.).astype(np.uint8))

    def save(self, path, lang_goal, reward):
        """Encode the buffered frames to an mp4 at `path`.

        Args:
            path: output video file path; parent directories are created.
            lang_goal: instruction text drawn near the bottom of each frame;
                skipped when falsy.
            reward: unused here; kept for call-site compatibility.
        """
        print(f"Converting to video ... {path}")
        os.makedirs(os.path.dirname(path), exist_ok=True)
        # OpenCV QT version can conflict with PyRep, so import here
        import cv2
        image_size = self._cam_motion.cam.get_resolution()
        video = cv2.VideoWriter(
            path, cv2.VideoWriter_fourcc('m', 'p', '4', 'v'), self._fps,
            tuple(image_size))

        for image in self._current_snaps:
            # Captured frames are RGB; OpenCV expects BGR.
            frame = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

            font = cv2.FONT_HERSHEY_DUPLEX
            # Scale font relative to a 640-pixel-wide reference frame.
            font_scale = (0.45 * image_size[0]) / 640
            font_thickness = 2


            if lang_goal:

                # Center the goal text horizontally, 35 px above the bottom.
                lang_textsize = cv2.getTextSize(lang_goal, font, font_scale, font_thickness)[0]
                lang_textX = (image_size[0] - lang_textsize[0]) // 2

                frame = cv2.putText(frame, lang_goal, org=(lang_textX, image_size[1] - 35),
                                    fontScale=font_scale, fontFace=font, color=(0, 0, 0),
                                    thickness=font_thickness, lineType=cv2.LINE_AA)


            video.write(frame)
        video.release()
        # Clear the buffer so the next episode starts fresh.
        self._current_snaps = []
|