---
language:
- en
---

# S5: Simplified State Space Layers for Sequence Modeling

This repository provides the implementation for the paper *Simplified State Space Layers for Sequence Modeling*. The preprint is available [here](https://arxiv.org/abs/2208.04933).

<p style="text-align: center;">
Figure 1: S5 uses a single multi-input, multi-output linear state-space model, coupled with non-linearities, to define a non-linear sequence-to-sequence transformation. Parallel scans are used for efficient offline processing.
</p>
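
Since the state-space model is linear, each step of the recurrence is an affine map `x -> a*x + b`, and composing two such maps is associative. That is exactly the structure a parallel scan exploits: all states of a length-`L` sequence can be computed in `O(log L)` parallel depth rather than with a sequential loop. Below is a minimal, real-valued JAX sketch of this idea; the function and variable names are illustrative, and the repository's actual implementation (see `s5/ssm.py`) differs in details such as its complex-valued diagonalized dynamics.

```
import jax
import jax.numpy as jnp

def binary_op(elem_i, elem_j):
    # Compose two affine steps x -> a*x + b:
    # applying (a_i, b_i) then (a_j, b_j) gives (a_j * a_i, a_j * b_i + b_j).
    a_i, b_i = elem_i
    a_j, b_j = elem_j
    return a_j * a_i, a_j * b_i + b_j

def run_ssm(Lambda, B, C, u):
    # Lambda: (P,) diagonal transition, B: (P, H), C: (H, P), u: (L, H).
    Bu = u @ B.T                                          # project inputs into the state space
    A = jnp.broadcast_to(Lambda, Bu.shape)                # per-step transition (constant here)
    _, xs = jax.lax.associative_scan(binary_op, (A, Bu))  # all L states at once
    return xs @ C.T                                       # (L, H) output sequence
```

Because `binary_op` is associative, XLA is free to evaluate the scan as a balanced tree instead of a length-`L` chain.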
The S5 layer builds on the prior S4 work ([paper](https://arxiv.org/abs/2111.00396)). While it has since departed considerably, this repository originally started from much of the JAX implementation of S4 in the Annotated S4 blog by Rush and Karamcheti (available [here](https://github.com/srush/annotated-s4)).
## Requirements & Installation
To run the code on your own machine, run either `pip install -r requirements_cpu.txt` or `pip install -r requirements_gpu.txt`. The GPU installation of JAX can be tricky, so we include requirements that should work for most people; further instructions are available [here](https://github.com/google/jax#installation).
Then, from the repository root, run `pip install -e .` to install the package.
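
In summary, a typical install looks like:

```
pip install -r requirements_gpu.txt   # or requirements_cpu.txt for CPU-only machines
pip install -e .                      # run from the repository root
```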
## Data Download
Downloading the raw data is done differently for each dataset. The following datasets require no action:
- Text (IMDb)
- Image (Cifar black & white)
- sMNIST
- psMNIST
- Cifar (Color)
The remaining datasets need to be manually downloaded. To download _everything_, run `./bin/download_all.sh`. This will download quite a lot of data and will take some time.
Below is a summary of the steps for each dataset:
- ListOps: run `./bin/download_lra.sh` to download the full LRA dataset.
- Retrieval (AAN): run `./bin/download_aan.sh`.
- Pathfinder: run `./bin/download_lra.sh` to download the full LRA dataset.
- Path-X: run `./bin/download_lra.sh` to download the full LRA dataset.
- Speech commands 35: run `./bin/download_sc35.sh` to download the speech commands data.
*With the exception of SC35*, when a dataset is used for the first time a cache is created in `./cache_dir`. Converting the data (e.g. tokenizing) can be quite slow, so this cache stores the processed dataset. The cache can be moved, and its location specified with the `--dir_name` argument (the default is `--dir_name=./cache_dir`), to avoid re-applying the preprocessing every time the code is run somewhere new.
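
For illustration (`--dir_name` is the only flag taken from this README; `...` stands for the remaining experiment arguments):

```
# First run: preprocesses the data and fills the cache.
python run_train.py --dir_name=./cache_dir ...

# After copying cache_dir elsewhere, point at the copy to skip preprocessing.
python run_train.py --dir_name=/path/to/copied/cache_dir ...
```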
SC35 is slightly different: it doesn't use `--dir_name`, and instead requires that the path `./raw_datasets/speech_commands/0.0.2/SpeechCommands` exists (i.e. the directory `./raw_datasets/speech_commands/0.0.2/SpeechCommands/zero` must exist). The cache is then stored in `./raw_datasets/speech_commands/0.0.2/SpeechCommands/processed_data`. This directory can be copied (preserving the directory path) to move the preprocessed dataset to a new location.
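
The expected layout is sketched below (only the paths named above come from this repository; `zero/` is one of the 35 spoken-word classes):

```
raw_datasets/
    speech_commands/
        0.0.2/
            SpeechCommands/
                zero/              Raw audio, one directory per word.
                ...
                processed_data/    Cache created on first use; copy it with its full path preserved.
```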
## Repository Structure
Directories and files that ship with the GitHub repo:
```
s5/                     Source code for models, datasets, etc.
    dataloading.py          Dataloading functions.
    layers.py               Defines the S5 layer which wraps the S5 SSM with nonlinearity, norms, dropout, etc.
    seq_model.py            Defines deep sequence models that consist of stacks of S5 layers.
    ssm.py                  S5 SSM implementation.
    ssm_init.py             Helper functions for initializing the S5 SSM.
    train.py                Training loop code.
    train_helpers.py        Functions for optimization, training and evaluation steps.
    dataloaders/            Code mainly derived from S4 processing each dataset.
    utils/                  Range of utility functions.
bin/                    Shell scripts for downloading data and running example experiments.
requirements_cpu.txt    Requirements for running in CPU mode (not advised).
requirements_gpu.txt    Requirements for running in GPU mode (installation can be highly system-dependent).
run_train.py            Training loop entrypoint.
```

Directories that may be created on the fly:
```
raw_datasets/       Raw data as downloaded.
cache_dir/          Precompiled caches of data. Can be copied to new locations to avoid preprocessing.
wandb/              Local WandB log files.
```
## Experiments
The configurations to run the LRA and 35-way Speech Commands experiments from the paper are located in `bin/run_experiments`. For example, to run the LRA text (character-level IMDb) experiment, run `./bin/run_experiments/run_lra_imdb.sh`.
To log with W&B, adjust the default `USE_WANDB`, `wandb_entity`, and `wandb_project` arguments.
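
For example, a hypothetical invocation (the flag spellings mirror the argument names above; check `run_train.py` for the exact interface):

```
python run_train.py --USE_WANDB=True \
    --wandb_entity=<your_entity> --wandb_project=<your_project>
```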
Note: the pendulum regression dataloading and experiments will be added soon.
## Citation
Please use the following when citing our work:
```
@misc{smith2022s5,
  doi = {10.48550/ARXIV.2208.04933},
  url = {https://arxiv.org/abs/2208.04933},
  author = {Smith, Jimmy T. H. and Warrington, Andrew and Linderman, Scott W.},
  keywords = {Machine Learning (cs.LG), FOS: Computer and information sciences, FOS: Computer and information sciences},
  title = {Simplified State Space Layers for Sequence Modeling},
  publisher = {arXiv},
  year = {2022},
  copyright = {Creative Commons Attribution 4.0 International}
}
```
Please reach out if you have any questions.
-- The S5 authors.