Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- conditional-flow-matching/.github/workflows/code-quality.yaml +26 -0
- conditional-flow-matching/.github/workflows/python-publish.yml +38 -0
- conditional-flow-matching/.github/workflows/test.yaml +71 -0
- conditional-flow-matching/.github/workflows/test_runner.yaml +75 -0
- conditional-flow-matching/runner/configs/callbacks/default.yaml +22 -0
- conditional-flow-matching/runner/configs/callbacks/early_stopping.yaml +17 -0
- conditional-flow-matching/runner/configs/callbacks/model_checkpoint.yaml +19 -0
- conditional-flow-matching/runner/configs/callbacks/model_summary.yaml +7 -0
- conditional-flow-matching/runner/configs/callbacks/no_stopping.yaml +15 -0
- conditional-flow-matching/runner/configs/callbacks/none.yaml +0 -0
- conditional-flow-matching/runner/configs/callbacks/rich_progress_bar.yaml +6 -0
- conditional-flow-matching/runner/configs/datamodule/cifar.yaml +11 -0
- conditional-flow-matching/runner/configs/datamodule/custom_dist.yaml +17 -0
- conditional-flow-matching/runner/configs/datamodule/eb_full.yaml +14 -0
- conditional-flow-matching/runner/configs/datamodule/funnel.yaml +12 -0
- conditional-flow-matching/runner/configs/datamodule/gaussians.yaml +13 -0
- conditional-flow-matching/runner/configs/datamodule/moons.yaml +10 -0
- conditional-flow-matching/runner/configs/datamodule/scurve.yaml +10 -0
- conditional-flow-matching/runner/configs/datamodule/sklearn.yaml +10 -0
- conditional-flow-matching/runner/configs/datamodule/time_dist.yaml +12 -0
- conditional-flow-matching/runner/configs/datamodule/torchdyn.yaml +13 -0
- conditional-flow-matching/runner/configs/datamodule/tree.yaml +8 -0
- conditional-flow-matching/runner/configs/datamodule/twodim.yaml +8 -0
- conditional-flow-matching/runner/configs/debug/default.yaml +35 -0
- conditional-flow-matching/runner/configs/debug/fdr.yaml +9 -0
- conditional-flow-matching/runner/configs/debug/limit.yaml +12 -0
- conditional-flow-matching/runner/configs/debug/overfit.yaml +13 -0
- conditional-flow-matching/runner/configs/debug/profiler.yaml +12 -0
- conditional-flow-matching/runner/configs/eval.yaml +18 -0
- conditional-flow-matching/runner/configs/experiment/cfm.yaml +22 -0
- conditional-flow-matching/runner/configs/experiment/cnf.yaml +22 -0
- conditional-flow-matching/runner/configs/experiment/icnn.yaml +18 -0
- conditional-flow-matching/runner/configs/experiment/image_cfm.yaml +33 -0
- conditional-flow-matching/runner/configs/experiment/image_fm.yaml +33 -0
- conditional-flow-matching/runner/configs/experiment/image_otcfm.yaml +34 -0
- conditional-flow-matching/runner/configs/experiment/trajectorynet.yaml +22 -0
- conditional-flow-matching/runner/configs/extras/default.yaml +8 -0
- conditional-flow-matching/runner/configs/hparams_search/optuna.yaml +49 -0
- conditional-flow-matching/runner/configs/hydra/default.yaml +15 -0
- conditional-flow-matching/runner/configs/launcher/mila_cluster.yaml +18 -0
- conditional-flow-matching/runner/configs/launcher/mila_cpu_cluster.yaml +16 -0
- conditional-flow-matching/runner/configs/local/.gitkeep +0 -0
- conditional-flow-matching/runner/configs/local/default.yaml +12 -0
- conditional-flow-matching/runner/configs/logger/comet.yaml +12 -0
- conditional-flow-matching/runner/configs/logger/csv.yaml +7 -0
- conditional-flow-matching/runner/configs/logger/many_loggers.yaml +9 -0
- conditional-flow-matching/runner/configs/logger/mlflow.yaml +12 -0
- conditional-flow-matching/runner/configs/logger/neptune.yaml +9 -0
- conditional-flow-matching/runner/configs/logger/tensorboard.yaml +10 -0
- conditional-flow-matching/runner/configs/logger/wandb.yaml +16 -0
conditional-flow-matching/.github/workflows/code-quality.yaml
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Same as `code-quality-pr.yaml` but triggered on commit to main branch
|
| 2 |
+
# and runs on all files (instead of only the changed ones)
|
| 3 |
+
|
| 4 |
+
name: Code Quality Main
|
| 5 |
+
|
| 6 |
+
on:
|
| 7 |
+
push:
|
| 8 |
+
branches: [main]
|
| 9 |
+
pull_request:
|
| 10 |
+
branches: [main, "release/*"]
|
| 11 |
+
|
| 12 |
+
jobs:
|
| 13 |
+
code-quality:
|
| 14 |
+
runs-on: ubuntu-latest
|
| 15 |
+
|
| 16 |
+
steps:
|
| 17 |
+
- name: Checkout
|
| 18 |
+
uses: actions/checkout@v5
|
| 19 |
+
|
| 20 |
+
- name: Set up Python
|
| 21 |
+
uses: actions/setup-python@v6
|
| 22 |
+
with:
|
| 23 |
+
python-version: "3.13"
|
| 24 |
+
|
| 25 |
+
- name: Run pre-commits
|
| 26 |
+
uses: pre-commit/action@v3.0.1
|
conditional-flow-matching/.github/workflows/python-publish.yml
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This workflow will upload a Python Package using Twine when a release is created
|
| 2 |
+
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
|
| 3 |
+
|
| 4 |
+
# This workflow uses actions that are not certified by GitHub.
|
| 5 |
+
# They are provided by a third-party and are governed by
|
| 6 |
+
# separate terms of service, privacy policy, and support
|
| 7 |
+
# documentation.
|
| 8 |
+
|
| 9 |
+
name: Upload Python Package
|
| 10 |
+
|
| 11 |
+
on:
|
| 12 |
+
release:
|
| 13 |
+
types: [published]
|
| 14 |
+
|
| 15 |
+
permissions:
|
| 16 |
+
contents: read
|
| 17 |
+
|
| 18 |
+
jobs:
|
| 19 |
+
deploy:
|
| 20 |
+
runs-on: ubuntu-latest
|
| 21 |
+
|
| 22 |
+
steps:
|
| 23 |
+
- uses: actions/checkout@v3
|
| 24 |
+
- name: Set up Python
|
| 25 |
+
uses: actions/setup-python@v3
|
| 26 |
+
with:
|
| 27 |
+
python-version: "3.x"
|
| 28 |
+
- name: Install dependencies
|
| 29 |
+
run: |
|
| 30 |
+
python -m pip install --upgrade pip
|
| 31 |
+
pip install build
|
| 32 |
+
- name: Build package
|
| 33 |
+
run: python -m build
|
| 34 |
+
- name: Publish package
|
| 35 |
+
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
|
| 36 |
+
with:
|
| 37 |
+
user: __token__
|
| 38 |
+
password: ${{ secrets.PYPI_API_TOKEN }}
|
conditional-flow-matching/.github/workflows/test.yaml
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: TorchCFM Tests
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
branches: [main]
|
| 6 |
+
pull_request:
|
| 7 |
+
branches: [main, "release/*"]
|
| 8 |
+
|
| 9 |
+
jobs:
|
| 10 |
+
run_tests_ubuntu:
|
| 11 |
+
runs-on: ${{ matrix.os }}
|
| 12 |
+
|
| 13 |
+
strategy:
|
| 14 |
+
fail-fast: false
|
| 15 |
+
matrix:
|
| 16 |
+
os: [ubuntu-latest, macos-latest, windows-latest]
|
| 17 |
+
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
| 18 |
+
|
| 19 |
+
steps:
|
| 20 |
+
- name: Checkout
|
| 21 |
+
uses: actions/checkout@v5
|
| 22 |
+
|
| 23 |
+
- name: Set up Python ${{ matrix.python-version }}
|
| 24 |
+
uses: actions/setup-python@v6
|
| 25 |
+
with:
|
| 26 |
+
python-version: ${{ matrix.python-version }}
|
| 27 |
+
|
| 28 |
+
- name: Install dependencies
|
| 29 |
+
run: |
|
| 30 |
+
python -m pip install --upgrade pip
|
| 31 |
+
pip install pytest
|
| 32 |
+
pip install sh
|
| 33 |
+
pip install -e .
|
| 34 |
+
|
| 35 |
+
- name: List dependencies
|
| 36 |
+
run: |
|
| 37 |
+
python -m pip list
|
| 38 |
+
|
| 39 |
+
- name: Run pytest
|
| 40 |
+
run: |
|
| 41 |
+
pytest -v --ignore=examples --ignore=runner
|
| 42 |
+
|
| 43 |
+
# upload code coverage report
|
| 44 |
+
code-coverage-torchcfm:
|
| 45 |
+
runs-on: ubuntu-latest
|
| 46 |
+
|
| 47 |
+
steps:
|
| 48 |
+
- name: Checkout
|
| 49 |
+
uses: actions/checkout@v5
|
| 50 |
+
|
| 51 |
+
- name: Set up Python 3.10
|
| 52 |
+
uses: actions/setup-python@v6
|
| 53 |
+
with:
|
| 54 |
+
python-version: "3.10"
|
| 55 |
+
|
| 56 |
+
- name: Install dependencies
|
| 57 |
+
run: |
|
| 58 |
+
python -m pip install --upgrade pip
|
| 59 |
+
pip install pytest
|
| 60 |
+
pip install pytest-cov[toml]
|
| 61 |
+
pip install sh
|
| 62 |
+
pip install -e .
|
| 63 |
+
|
| 64 |
+
- name: Run tests and collect coverage
|
| 65 |
+
run: pytest . --cov torchcfm --ignore=runner --ignore=examples --ignore=torchcfm/models/ --cov-fail-under=30
|
| 66 |
+
|
| 67 |
+
- name: Upload coverage to Codecov
|
| 68 |
+
uses: codecov/codecov-action@v3
|
| 69 |
+
with:
|
| 70 |
+
name: codecov-torchcfm
|
| 71 |
+
verbose: true
|
conditional-flow-matching/.github/workflows/test_runner.yaml
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Runner Tests
|
| 2 |
+
|
| 3 |
+
#on:
|
| 4 |
+
# push:
|
| 5 |
+
# branches: [main]
|
| 6 |
+
# pull_request:
|
| 7 |
+
# branches: [main, "release/*"]
|
| 8 |
+
|
| 9 |
+
jobs:
|
| 10 |
+
run_tests_ubuntu:
|
| 11 |
+
runs-on: ${{ matrix.os }}
|
| 12 |
+
|
| 13 |
+
strategy:
|
| 14 |
+
fail-fast: false
|
| 15 |
+
matrix:
|
| 16 |
+
os: [ubuntu-latest, ubuntu-20.04, macos-latest, windows-latest]
|
| 17 |
+
python-version: ["3.9", "3.10"]
|
| 18 |
+
|
| 19 |
+
steps:
|
| 20 |
+
- name: Checkout
|
| 21 |
+
uses: actions/checkout@v5
|
| 22 |
+
|
| 23 |
+
- name: Set up Python ${{ matrix.python-version }}
|
| 24 |
+
uses: actions/setup-python@v6
|
| 25 |
+
with:
|
| 26 |
+
python-version: ${{ matrix.python-version }}
|
| 27 |
+
|
| 28 |
+
- name: Install dependencies
|
| 29 |
+
run: |
|
| 30 |
+
# Fix pip version < 24.1 due to lightning incomaptibility
|
| 31 |
+
python -m pip install pip==23.2.1
|
| 32 |
+
pip install -r runner-requirements.txt
|
| 33 |
+
pip install pytest
|
| 34 |
+
pip install sh
|
| 35 |
+
pip install -e .
|
| 36 |
+
|
| 37 |
+
- name: List dependencies
|
| 38 |
+
run: |
|
| 39 |
+
python -m pip list
|
| 40 |
+
|
| 41 |
+
- name: Run pytest
|
| 42 |
+
run: |
|
| 43 |
+
pytest -v runner
|
| 44 |
+
|
| 45 |
+
# upload code coverage report
|
| 46 |
+
code-coverage-runner:
|
| 47 |
+
runs-on: ubuntu-latest
|
| 48 |
+
|
| 49 |
+
steps:
|
| 50 |
+
- name: Checkout
|
| 51 |
+
uses: actions/checkout@v5
|
| 52 |
+
|
| 53 |
+
- name: Set up Python 3.10
|
| 54 |
+
uses: actions/setup-python@v6
|
| 55 |
+
with:
|
| 56 |
+
python-version: "3.10"
|
| 57 |
+
|
| 58 |
+
- name: Install dependencies
|
| 59 |
+
run: |
|
| 60 |
+
# Fix pip version < 24.1 due to lightning incomaptibility
|
| 61 |
+
python -m pip install pip==23.2.1
|
| 62 |
+
pip install -r runner-requirements.txt
|
| 63 |
+
pip install pytest
|
| 64 |
+
pip install pytest-cov[toml]
|
| 65 |
+
pip install sh
|
| 66 |
+
pip install -e .
|
| 67 |
+
|
| 68 |
+
- name: Run tests and collect coverage
|
| 69 |
+
run: pytest runner --cov runner --cov-fail-under=30 # NEEDS TO BE UPDATED WHEN CHANGING THE NAME OF "src" FOLDER
|
| 70 |
+
|
| 71 |
+
- name: Upload coverage to Codecov
|
| 72 |
+
uses: codecov/codecov-action@v3
|
| 73 |
+
with:
|
| 74 |
+
name: codecov-runner
|
| 75 |
+
verbose: true
|
conditional-flow-matching/runner/configs/callbacks/default.yaml
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- model_checkpoint.yaml
|
| 3 |
+
- early_stopping.yaml
|
| 4 |
+
- model_summary.yaml
|
| 5 |
+
- rich_progress_bar.yaml
|
| 6 |
+
- _self_
|
| 7 |
+
|
| 8 |
+
model_checkpoint:
|
| 9 |
+
dirpath: ${paths.output_dir}/checkpoints
|
| 10 |
+
filename: "epoch_{epoch:04d}"
|
| 11 |
+
monitor: "val/loss"
|
| 12 |
+
mode: "min"
|
| 13 |
+
save_last: True
|
| 14 |
+
auto_insert_metric_name: False
|
| 15 |
+
|
| 16 |
+
early_stopping:
|
| 17 |
+
monitor: "val/loss"
|
| 18 |
+
patience: 100
|
| 19 |
+
mode: "min"
|
| 20 |
+
|
| 21 |
+
model_summary:
|
| 22 |
+
max_depth: -1
|
conditional-flow-matching/runner/configs/callbacks/early_stopping.yaml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.EarlyStopping.html
|
| 2 |
+
|
| 3 |
+
# Monitor a metric and stop training when it stops improving.
|
| 4 |
+
# Look at the above link for more detailed information.
|
| 5 |
+
early_stopping:
|
| 6 |
+
_target_: pytorch_lightning.callbacks.EarlyStopping
|
| 7 |
+
monitor: ??? # quantity to be monitored, must be specified !!!
|
| 8 |
+
min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement
|
| 9 |
+
patience: 3 # number of checks with no improvement after which training will be stopped
|
| 10 |
+
verbose: False # verbosity mode
|
| 11 |
+
mode: "min" # "max" means higher metric value is better, can be also "min"
|
| 12 |
+
strict: True # whether to crash the training if monitor is not found in the validation metrics
|
| 13 |
+
check_finite: True # when set True, stops training when the monitor becomes NaN or infinite
|
| 14 |
+
stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold
|
| 15 |
+
divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold
|
| 16 |
+
check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch
|
| 17 |
+
# log_rank_zero_only: False # this keyword argument isn't available in stable version
|
conditional-flow-matching/runner/configs/callbacks/model_checkpoint.yaml
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.ModelCheckpoint.html
|
| 2 |
+
|
| 3 |
+
# Save the model periodically by monitoring a quantity.
|
| 4 |
+
# Look at the above link for more detailed information.
|
| 5 |
+
model_checkpoint:
|
| 6 |
+
_target_: pytorch_lightning.callbacks.ModelCheckpoint
|
| 7 |
+
dirpath: null # directory to save the model file
|
| 8 |
+
filename: null # checkpoint filename
|
| 9 |
+
monitor: null # name of the logged metric which determines when model is improving
|
| 10 |
+
verbose: False # verbosity mode
|
| 11 |
+
save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt
|
| 12 |
+
save_top_k: 1 # save k best models (determined by above metric)
|
| 13 |
+
mode: "min" # "max" means higher metric value is better, can be also "min"
|
| 14 |
+
auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name
|
| 15 |
+
save_weights_only: False # if True, then only the model’s weights will be saved
|
| 16 |
+
every_n_train_steps: null # number of training steps between checkpoints
|
| 17 |
+
train_time_interval: null # checkpoints are monitored at the specified time interval
|
| 18 |
+
every_n_epochs: null # number of epochs between checkpoints
|
| 19 |
+
save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation
|
conditional-flow-matching/runner/configs/callbacks/model_summary.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.RichModelSummary.html
|
| 2 |
+
|
| 3 |
+
# Generates a summary of all layers in a LightningModule with rich text formatting.
|
| 4 |
+
# Look at the above link for more detailed information.
|
| 5 |
+
model_summary:
|
| 6 |
+
_target_: pytorch_lightning.callbacks.RichModelSummary
|
| 7 |
+
max_depth: 1 # the maximum depth of layer nesting that the summary will include
|
conditional-flow-matching/runner/configs/callbacks/no_stopping.yaml
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- model_checkpoint.yaml
|
| 3 |
+
- model_summary.yaml
|
| 4 |
+
- rich_progress_bar.yaml
|
| 5 |
+
- _self_
|
| 6 |
+
|
| 7 |
+
model_checkpoint:
|
| 8 |
+
dirpath: ${paths.output_dir}/checkpoints
|
| 9 |
+
filename: "epoch_{epoch:04d}"
|
| 10 |
+
save_last: True
|
| 11 |
+
every_n_epochs: 100 # number of epochs between checkpoints
|
| 12 |
+
auto_insert_metric_name: False
|
| 13 |
+
|
| 14 |
+
model_summary:
|
| 15 |
+
max_depth: 3
|
conditional-flow-matching/runner/configs/callbacks/none.yaml
ADDED
|
File without changes
|
conditional-flow-matching/runner/configs/callbacks/rich_progress_bar.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.RichProgressBar.html
|
| 2 |
+
|
| 3 |
+
# Create a progress bar with rich text formatting.
|
| 4 |
+
# Look at the above link for more detailed information.
|
| 5 |
+
rich_progress_bar:
|
| 6 |
+
_target_: pytorch_lightning.callbacks.RichProgressBar
|
conditional-flow-matching/runner/configs/datamodule/cifar.yaml
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: src.datamodules.cifar10_datamodule.CIFAR10DataModule
|
| 2 |
+
#_target_: pl_bolts.datamodules.cifar10_datamodule.CIFAR10DataModule
|
| 3 |
+
data_dir: ${paths.data_dir}
|
| 4 |
+
batch_size: 128
|
| 5 |
+
val_split: 0.0
|
| 6 |
+
num_workers: 0
|
| 7 |
+
normalize: True
|
| 8 |
+
seed: 42
|
| 9 |
+
shuffle: True
|
| 10 |
+
pin_memory: True
|
| 11 |
+
drop_last: False
|
conditional-flow-matching/runner/configs/datamodule/custom_dist.yaml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: src.datamodules.distribution_datamodule.TrajectoryNetDistributionTrajectoryDataModule
|
| 2 |
+
#_target_: src.datamodules.distribution_datamodule.DistributionDataModule
|
| 3 |
+
|
| 4 |
+
data_dir: ${paths.data_dir} # data_dir is specified in config.yaml
|
| 5 |
+
train_val_test_split: 1000
|
| 6 |
+
batch_size: 100
|
| 7 |
+
num_workers: 0
|
| 8 |
+
pin_memory: False
|
| 9 |
+
|
| 10 |
+
system: ${paths.data_dir}/embryoid_anndata_small_v2.h5ad
|
| 11 |
+
|
| 12 |
+
system_kwargs:
|
| 13 |
+
max_dim: 1e10
|
| 14 |
+
embedding_name: "phate"
|
| 15 |
+
#embedding_name: "highly_variable"
|
| 16 |
+
whiten: True
|
| 17 |
+
#whiten: False
|
conditional-flow-matching/runner/configs/datamodule/eb_full.yaml
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://github.com/DiffEqML/torchdyn/blob/master/torchdyn/datasets/static_datasets.py
|
| 2 |
+
_target_: src.datamodules.distribution_datamodule.TorchDynDataModule
|
| 3 |
+
#_target_: src.datamodules.distribution_datamodule.DistributionDataModule
|
| 4 |
+
|
| 5 |
+
train_val_test_split: [0.8, 0.1, 0.1]
|
| 6 |
+
batch_size: 128
|
| 7 |
+
num_workers: 0
|
| 8 |
+
pin_memory: False
|
| 9 |
+
|
| 10 |
+
system: ${paths.data_dir}/eb_velocity_v5.npz
|
| 11 |
+
|
| 12 |
+
system_kwargs:
|
| 13 |
+
max_dim: 100
|
| 14 |
+
whiten: False
|
conditional-flow-matching/runner/configs/datamodule/funnel.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: src.datamodules.distribution_datamodule.TorchDynDataModule
|
| 2 |
+
#_target_: src.datamodules.distribution_datamodule.DistributionDataModule
|
| 3 |
+
|
| 4 |
+
train_val_test_split: [10000, 1000, 1000]
|
| 5 |
+
batch_size: 128
|
| 6 |
+
num_workers: 0
|
| 7 |
+
pin_memory: False
|
| 8 |
+
|
| 9 |
+
system: "funnel"
|
| 10 |
+
|
| 11 |
+
system_kwargs:
|
| 12 |
+
dim: 10
|
conditional-flow-matching/runner/configs/datamodule/gaussians.yaml
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://github.com/DiffEqML/torchdyn/blob/master/torchdyn/datasets/static_datasets.py
|
| 2 |
+
_target_: src.datamodules.distribution_datamodule.TorchDynDataModule
|
| 3 |
+
#_target_: src.datamodules.distribution_datamodule.DistributionDataModule
|
| 4 |
+
|
| 5 |
+
train_val_test_split: [10000, 1000, 1000]
|
| 6 |
+
batch_size: 128
|
| 7 |
+
num_workers: 0
|
| 8 |
+
pin_memory: False
|
| 9 |
+
|
| 10 |
+
system: "gaussians"
|
| 11 |
+
|
| 12 |
+
system_kwargs:
|
| 13 |
+
noise: 1e-4
|
conditional-flow-matching/runner/configs/datamodule/moons.yaml
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://github.com/DiffEqML/torchdyn/blob/master/torchdyn/datasets/static_datasets.py
|
| 2 |
+
_target_: src.datamodules.distribution_datamodule.SKLearnDataModule
|
| 3 |
+
#_target_: src.datamodules.distribution_datamodule.DistributionDataModule
|
| 4 |
+
|
| 5 |
+
train_val_test_split: [10000, 1000, 1000]
|
| 6 |
+
batch_size: 128
|
| 7 |
+
num_workers: 0
|
| 8 |
+
pin_memory: False
|
| 9 |
+
|
| 10 |
+
system: "moons"
|
conditional-flow-matching/runner/configs/datamodule/scurve.yaml
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://github.com/DiffEqML/torchdyn/blob/master/torchdyn/datasets/static_datasets.py
|
| 2 |
+
_target_: src.datamodules.distribution_datamodule.SKLearnDataModule
|
| 3 |
+
#_target_: src.datamodules.distribution_datamodule.DistributionDataModule
|
| 4 |
+
|
| 5 |
+
train_val_test_split: [10000, 1000, 1000]
|
| 6 |
+
batch_size: 128
|
| 7 |
+
num_workers: 0
|
| 8 |
+
pin_memory: False
|
| 9 |
+
|
| 10 |
+
system: "scurve"
|
conditional-flow-matching/runner/configs/datamodule/sklearn.yaml
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://github.com/DiffEqML/torchdyn/blob/master/torchdyn/datasets/static_datasets.py
|
| 2 |
+
_target_: src.datamodules.distribution_datamodule.SKLearnDataModule
|
| 3 |
+
#_target_: src.datamodules.distribution_datamodule.DistributionDataModule
|
| 4 |
+
|
| 5 |
+
train_val_test_split: [10000, 1000, 1000]
|
| 6 |
+
batch_size: 128
|
| 7 |
+
num_workers: 0
|
| 8 |
+
pin_memory: False
|
| 9 |
+
|
| 10 |
+
system: "scurve"
|
conditional-flow-matching/runner/configs/datamodule/time_dist.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: src.datamodules.distribution_datamodule.CustomTrajectoryDataModule
|
| 2 |
+
#_target_: src.datamodules.distribution_datamodule.DistributionDataModule
|
| 3 |
+
|
| 4 |
+
data_dir: ${paths.data_dir} # data_dir is specified in config.yaml
|
| 5 |
+
train_val_test_split: [0.8, 0.1, 0.1]
|
| 6 |
+
batch_size: 128
|
| 7 |
+
num_workers: 0
|
| 8 |
+
pin_memory: False
|
| 9 |
+
max_dim: 5
|
| 10 |
+
whiten: True
|
| 11 |
+
|
| 12 |
+
system: ${paths.data_dir}/eb_velocity_v5.npz
|
conditional-flow-matching/runner/configs/datamodule/torchdyn.yaml
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://github.com/DiffEqML/torchdyn/blob/master/torchdyn/datasets/static_datasets.py
|
| 2 |
+
_target_: src.datamodules.distribution_datamodule.TorchDynDataModule
|
| 3 |
+
#_target_: src.datamodules.distribution_datamodule.DistributionDataModule
|
| 4 |
+
|
| 5 |
+
train_val_test_split: [10000, 1000, 1000]
|
| 6 |
+
batch_size: 128
|
| 7 |
+
num_workers: 0
|
| 8 |
+
pin_memory: False
|
| 9 |
+
|
| 10 |
+
system: "moons"
|
| 11 |
+
|
| 12 |
+
system_kwargs:
|
| 13 |
+
noise: 1e-4
|
conditional-flow-matching/runner/configs/datamodule/tree.yaml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: src.datamodules.distribution_datamodule.DistributionDataModule
|
| 2 |
+
|
| 3 |
+
data_dir: ${data_dir} # data_dir is specified in config.yaml
|
| 4 |
+
train_val_test_split: 1000
|
| 5 |
+
batch_size: 100
|
| 6 |
+
num_workers: 0
|
| 7 |
+
pin_memory: False
|
| 8 |
+
p: 2
|
conditional-flow-matching/runner/configs/datamodule/twodim.yaml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: src.datamodules.distribution_datamodule.TwoDimDataModule
|
| 2 |
+
|
| 3 |
+
train_val_test_split: [10000, 1000, 1000]
|
| 4 |
+
batch_size: 128
|
| 5 |
+
num_workers: 0
|
| 6 |
+
pin_memory: False
|
| 7 |
+
|
| 8 |
+
system: "moon-8gaussians"
|
conditional-flow-matching/runner/configs/debug/default.yaml
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# default debugging setup, runs 1 full epoch
|
| 4 |
+
# other debugging configs can inherit from this one
|
| 5 |
+
|
| 6 |
+
# overwrite task name so debugging logs are stored in separate folder
|
| 7 |
+
task_name: "debug"
|
| 8 |
+
|
| 9 |
+
# disable callbacks and loggers during debugging
|
| 10 |
+
callbacks: null
|
| 11 |
+
logger: null
|
| 12 |
+
|
| 13 |
+
extras:
|
| 14 |
+
ignore_warnings: False
|
| 15 |
+
enforce_tags: False
|
| 16 |
+
|
| 17 |
+
# sets level of all command line loggers to 'DEBUG'
|
| 18 |
+
# https://hydra.cc/docs/tutorials/basic/running_your_app/logging/
|
| 19 |
+
hydra:
|
| 20 |
+
job_logging:
|
| 21 |
+
root:
|
| 22 |
+
level: DEBUG
|
| 23 |
+
|
| 24 |
+
# use this to also set hydra loggers to 'DEBUG'
|
| 25 |
+
# verbose: True
|
| 26 |
+
|
| 27 |
+
trainer:
|
| 28 |
+
max_epochs: 1
|
| 29 |
+
accelerator: cpu # debuggers don't like gpus
|
| 30 |
+
devices: 1 # debuggers don't like multiprocessing
|
| 31 |
+
detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor
|
| 32 |
+
|
| 33 |
+
datamodule:
|
| 34 |
+
num_workers: 0 # debuggers don't like multiprocessing
|
| 35 |
+
pin_memory: False # disable gpu memory pin
|
conditional-flow-matching/runner/configs/debug/fdr.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# runs 1 train, 1 validation and 1 test step
|
| 4 |
+
|
| 5 |
+
defaults:
|
| 6 |
+
- default.yaml
|
| 7 |
+
|
| 8 |
+
trainer:
|
| 9 |
+
fast_dev_run: true
|
conditional-flow-matching/runner/configs/debug/limit.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# uses only 1% of the training data and 5% of validation/test data
|
| 4 |
+
|
| 5 |
+
defaults:
|
| 6 |
+
- default.yaml
|
| 7 |
+
|
| 8 |
+
trainer:
|
| 9 |
+
max_epochs: 3
|
| 10 |
+
limit_train_batches: 0.01
|
| 11 |
+
limit_val_batches: 0.05
|
| 12 |
+
limit_test_batches: 0.05
|
conditional-flow-matching/runner/configs/debug/overfit.yaml
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# overfits to 3 batches
|
| 4 |
+
|
| 5 |
+
defaults:
|
| 6 |
+
- default.yaml
|
| 7 |
+
|
| 8 |
+
trainer:
|
| 9 |
+
max_epochs: 20
|
| 10 |
+
overfit_batches: 3
|
| 11 |
+
|
| 12 |
+
# model ckpt and early stopping need to be disabled during overfitting
|
| 13 |
+
callbacks: null
|
conditional-flow-matching/runner/configs/debug/profiler.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# runs with execution time profiling
|
| 4 |
+
|
| 5 |
+
defaults:
|
| 6 |
+
- default.yaml
|
| 7 |
+
|
| 8 |
+
trainer:
|
| 9 |
+
max_epochs: 1
|
| 10 |
+
profiler: "simple"
|
| 11 |
+
# profiler: "advanced"
|
| 12 |
+
# profiler: "pytorch"
|
conditional-flow-matching/runner/configs/eval.yaml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
defaults:
|
| 4 |
+
- _self_
|
| 5 |
+
- datamodule: sklearn # choose datamodule with `test_dataloader()` for evaluation
|
| 6 |
+
- model: cfm
|
| 7 |
+
- logger: null
|
| 8 |
+
- trainer: default.yaml
|
| 9 |
+
- paths: default.yaml
|
| 10 |
+
- extras: default.yaml
|
| 11 |
+
- hydra: default.yaml
|
| 12 |
+
|
| 13 |
+
task_name: "eval"
|
| 14 |
+
|
| 15 |
+
tags: ["dev"]
|
| 16 |
+
|
| 17 |
+
# passing checkpoint path is necessary for evaluation
|
| 18 |
+
ckpt_path: ???
|
conditional-flow-matching/runner/configs/experiment/cfm.yaml
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
defaults:
|
| 4 |
+
- override /model: cfm.yaml
|
| 5 |
+
- override /logger:
|
| 6 |
+
- csv.yaml
|
| 7 |
+
- wandb.yaml
|
| 8 |
+
- override /datamodule: sklearn.yaml
|
| 9 |
+
|
| 10 |
+
name: "cfm"
|
| 11 |
+
seed: 42
|
| 12 |
+
|
| 13 |
+
datamodule:
|
| 14 |
+
batch_size: 512
|
| 15 |
+
|
| 16 |
+
model:
|
| 17 |
+
optimizer:
|
| 18 |
+
weight_decay: 1e-5
|
| 19 |
+
|
| 20 |
+
trainer:
|
| 21 |
+
max_epochs: 1000
|
| 22 |
+
check_val_every_n_epoch: 10
|
conditional-flow-matching/runner/configs/experiment/cnf.yaml
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
defaults:
|
| 4 |
+
- override /model: cnf.yaml
|
| 5 |
+
- override /logger:
|
| 6 |
+
- csv.yaml
|
| 7 |
+
- wandb.yaml
|
| 8 |
+
- override /datamodule: sklearn.yaml
|
| 9 |
+
|
| 10 |
+
name: "cnf"
|
| 11 |
+
seed: 42
|
| 12 |
+
|
| 13 |
+
datamodule:
|
| 14 |
+
batch_size: 1024
|
| 15 |
+
|
| 16 |
+
model:
|
| 17 |
+
optimizer:
|
| 18 |
+
weight_decay: 1e-5
|
| 19 |
+
|
| 20 |
+
trainer:
|
| 21 |
+
max_epochs: 1000
|
| 22 |
+
check_val_every_n_epoch: 10
|
conditional-flow-matching/runner/configs/experiment/icnn.yaml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
defaults:
|
| 4 |
+
- override /model: icnn
|
| 5 |
+
- override /logger:
|
| 6 |
+
- csv
|
| 7 |
+
- wandb
|
| 8 |
+
- override /datamodule: sklearn
|
| 9 |
+
|
| 10 |
+
name: "icnn"
|
| 11 |
+
seed: 42
|
| 12 |
+
|
| 13 |
+
datamodule:
|
| 14 |
+
batch_size: 256
|
| 15 |
+
|
| 16 |
+
trainer:
|
| 17 |
+
max_epochs: 10000
|
| 18 |
+
check_val_every_n_epoch: 100
|
conditional-flow-matching/runner/configs/experiment/image_cfm.yaml
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
defaults:
|
| 4 |
+
- override /model: image_cfm.yaml
|
| 5 |
+
- override /callbacks: no_stopping
|
| 6 |
+
- override /logger:
|
| 7 |
+
- csv.yaml
|
| 8 |
+
- wandb.yaml
|
| 9 |
+
- override /datamodule: cifar.yaml
|
| 10 |
+
- override /trainer: ddp.yaml
|
| 11 |
+
|
| 12 |
+
name: "cfm"
|
| 13 |
+
seed: 42
|
| 14 |
+
|
| 15 |
+
datamodule:
|
| 16 |
+
batch_size: 128
|
| 17 |
+
|
| 18 |
+
model:
|
| 19 |
+
_target_: src.models.cfm_module.CFMLitModule
|
| 20 |
+
sigma_min: 1e-4
|
| 21 |
+
|
| 22 |
+
scheduler:
|
| 23 |
+
_target_: timm.scheduler.PolyLRScheduler
|
| 24 |
+
_partial_: True
|
| 25 |
+
warmup_t: 200
|
| 26 |
+
warmup_lr_init: 1e-8
|
| 27 |
+
t_initial: 2000
|
| 28 |
+
|
| 29 |
+
trainer:
|
| 30 |
+
devices: 2
|
| 31 |
+
max_epochs: 2000
|
| 32 |
+
check_val_every_n_epoch: 10
|
| 33 |
+
limit_val_batches: 0.01
|
conditional-flow-matching/runner/configs/experiment/image_fm.yaml
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
defaults:
|
| 4 |
+
- override /model: image_cfm.yaml
|
| 5 |
+
- override /callbacks: no_stopping
|
| 6 |
+
- override /logger:
|
| 7 |
+
- csv.yaml
|
| 8 |
+
- wandb.yaml
|
| 9 |
+
- override /datamodule: cifar.yaml
|
| 10 |
+
- override /trainer: ddp.yaml
|
| 11 |
+
|
| 12 |
+
name: "cfm"
|
| 13 |
+
seed: 42
|
| 14 |
+
|
| 15 |
+
datamodule:
|
| 16 |
+
batch_size: 128
|
| 17 |
+
|
| 18 |
+
model:
|
| 19 |
+
_target_: src.models.cfm_module.FMLitModule
|
| 20 |
+
sigma_min: 1e-4
|
| 21 |
+
|
| 22 |
+
scheduler:
|
| 23 |
+
_target_: timm.scheduler.PolyLRScheduler
|
| 24 |
+
_partial_: True
|
| 25 |
+
warmup_t: 200
|
| 26 |
+
warmup_lr_init: 1e-8
|
| 27 |
+
t_initial: 2000
|
| 28 |
+
|
| 29 |
+
trainer:
|
| 30 |
+
devices: 2
|
| 31 |
+
max_epochs: 2000
|
| 32 |
+
check_val_every_n_epoch: 10
|
| 33 |
+
limit_val_batches: 0.01
|
conditional-flow-matching/runner/configs/experiment/image_otcfm.yaml
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
defaults:
|
| 4 |
+
- override /model: image_cfm.yaml
|
| 5 |
+
- override /callbacks: no_stopping
|
| 6 |
+
- override /logger:
|
| 7 |
+
- csv.yaml
|
| 8 |
+
- wandb.yaml
|
| 9 |
+
- override /datamodule: cifar.yaml
|
| 10 |
+
- override /trainer: ddp.yaml
|
| 11 |
+
|
| 12 |
+
name: "cfm"
|
| 13 |
+
seed: 42
|
| 14 |
+
|
| 15 |
+
datamodule:
|
| 16 |
+
batch_size: 128
|
| 17 |
+
|
| 18 |
+
model:
|
| 19 |
+
_target_: src.models.cfm_module.CFMLitModule
|
| 20 |
+
sigma_min: 1e-4
|
| 21 |
+
|
| 22 |
+
scheduler:
|
| 23 |
+
_target_: timm.scheduler.PolyLRScheduler
|
| 24 |
+
_partial_: True
|
| 25 |
+
warmup_t: 200
|
| 26 |
+
warmup_lr_init: 1e-8
|
| 27 |
+
t_initial: 2000
|
| 28 |
+
ot_sampler: "exact"
|
| 29 |
+
|
| 30 |
+
trainer:
|
| 31 |
+
devices: 2
|
| 32 |
+
max_epochs: 2000
|
| 33 |
+
check_val_every_n_epoch: 10
|
| 34 |
+
limit_val_batches: 0.01
|
conditional-flow-matching/runner/configs/experiment/trajectorynet.yaml
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
defaults:
|
| 4 |
+
- override /model: trajectorynet.yaml
|
| 5 |
+
- override /logger:
|
| 6 |
+
- csv.yaml
|
| 7 |
+
- wandb.yaml
|
| 8 |
+
- override /datamodule: twodim.yaml
|
| 9 |
+
|
| 10 |
+
name: "cnf"
|
| 11 |
+
seed: 42
|
| 12 |
+
|
| 13 |
+
datamodule:
|
| 14 |
+
batch_size: 1024
|
| 15 |
+
|
| 16 |
+
model:
|
| 17 |
+
optimizer:
|
| 18 |
+
weight_decay: 1e-5
|
| 19 |
+
|
| 20 |
+
trainer:
|
| 21 |
+
max_epochs: 1000
|
| 22 |
+
check_val_every_n_epoch: 10
|
conditional-flow-matching/runner/configs/extras/default.yaml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# disable python warnings if they annoy you
|
| 2 |
+
ignore_warnings: False
|
| 3 |
+
|
| 4 |
+
# ask user for tags if none are provided in the config
|
| 5 |
+
enforce_tags: True
|
| 6 |
+
|
| 7 |
+
# pretty print config tree at the start of the run using Rich library
|
| 8 |
+
print_config: True
|
conditional-flow-matching/runner/configs/hparams_search/optuna.yaml
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# example hyperparameter optimization of some experiment with Optuna:
|
| 4 |
+
# python train.py -m hparams_search=mnist_optuna experiment=example
|
| 5 |
+
|
| 6 |
+
defaults:
|
| 7 |
+
- override /hydra/sweeper: optuna
|
| 8 |
+
|
| 9 |
+
# choose metric which will be optimized by Optuna
|
| 10 |
+
# make sure this is the correct name of some metric logged in lightning module!
|
| 11 |
+
optimized_metric: "val/2-Wasserstein"
|
| 12 |
+
|
| 13 |
+
# here we define Optuna hyperparameter search
|
| 14 |
+
# it optimizes for value returned from function with @hydra.main decorator
|
| 15 |
+
# docs: https://hydra.cc/docs/next/plugins/optuna_sweeper
|
| 16 |
+
hydra:
|
| 17 |
+
mode: "MULTIRUN" # set hydra to multirun by default if this config is attached
|
| 18 |
+
|
| 19 |
+
sweeper:
|
| 20 |
+
_target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
|
| 21 |
+
|
| 22 |
+
# storage URL to persist optimization results
|
| 23 |
+
# for example, you can use SQLite if you set 'sqlite:///example.db'
|
| 24 |
+
storage: null
|
| 25 |
+
|
| 26 |
+
# name of the study to persist optimization results
|
| 27 |
+
study_name: null
|
| 28 |
+
|
| 29 |
+
# number of parallel workers
|
| 30 |
+
n_jobs: 1
|
| 31 |
+
|
| 32 |
+
# 'minimize' or 'maximize' the objective
|
| 33 |
+
direction: maximize
|
| 34 |
+
|
| 35 |
+
# total number of runs that will be executed
|
| 36 |
+
n_trials: 20
|
| 37 |
+
|
| 38 |
+
# choose Optuna hyperparameter sampler
|
| 39 |
+
# you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others
|
| 40 |
+
# docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html
|
| 41 |
+
sampler:
|
| 42 |
+
_target_: optuna.samplers.TPESampler
|
| 43 |
+
seed: 1234
|
| 44 |
+
n_startup_trials: 10 # number of random sampling runs before optimization starts
|
| 45 |
+
|
| 46 |
+
# define hyperparameter search space
|
| 47 |
+
params:
|
| 48 |
+
model.optimizer.lr: interval(0.0001, 0.1)
|
| 49 |
+
datamodule.batch_size: choice(32, 64, 128, 256)
|
conditional-flow-matching/runner/configs/hydra/default.yaml
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://hydra.cc/docs/configure_hydra/intro/
|
| 2 |
+
|
| 3 |
+
# enable color logging
|
| 4 |
+
defaults:
|
| 5 |
+
- override hydra_logging: colorlog
|
| 6 |
+
- override job_logging: colorlog
|
| 7 |
+
|
| 8 |
+
# output directory, generated dynamically on each run
|
| 9 |
+
run:
|
| 10 |
+
dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S}
|
| 11 |
+
sweep:
|
| 12 |
+
dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S}
|
| 13 |
+
subdir: ${hydra.job.num}
|
| 14 |
+
job:
|
| 15 |
+
chdir: true
|
conditional-flow-matching/runner/configs/launcher/mila_cluster.yaml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
#
|
| 3 |
+
defaults:
|
| 4 |
+
- override /hydra/launcher: submitit_slurm
|
| 5 |
+
|
| 6 |
+
hydra:
|
| 7 |
+
launcher:
|
| 8 |
+
partition: long
|
| 9 |
+
cpus_per_task: 2
|
| 10 |
+
mem_gb: 20
|
| 11 |
+
gres: gpu:1
|
| 12 |
+
timeout_min: 1440
|
| 13 |
+
array_parallelism: 10 # max num of tasks to run in parallel (via job array)
|
| 14 |
+
setup:
|
| 15 |
+
- "module purge"
|
| 16 |
+
- "module load miniconda/3"
|
| 17 |
+
- "conda activate myenv"
|
| 18 |
+
- "unset CUDA_VISIBLE_DEVICES"
|
conditional-flow-matching/runner/configs/launcher/mila_cpu_cluster.yaml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
#
|
| 3 |
+
defaults:
|
| 4 |
+
- override /hydra/launcher: submitit_slurm
|
| 5 |
+
|
| 6 |
+
hydra:
|
| 7 |
+
launcher:
|
| 8 |
+
partition: long-cpu
|
| 9 |
+
cpus_per_task: 1
|
| 10 |
+
mem_gb: 5
|
| 11 |
+
timeout_min: 100
|
| 12 |
+
array_parallelism: 64
|
| 13 |
+
setup:
|
| 14 |
+
- "module purge"
|
| 15 |
+
- "module load miniconda/3"
|
| 16 |
+
- "conda activate myenv"
|
conditional-flow-matching/runner/configs/local/.gitkeep
ADDED
|
File without changes
|
conditional-flow-matching/runner/configs/local/default.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# path to root directory
|
| 2 |
+
# this requires PROJECT_ROOT environment variable to exist
|
| 3 |
+
# PROJECT_ROOT is inferred and set by pyrootutils package in `train.py` and `eval.py`
|
| 4 |
+
root_dir: ${oc.env:PROJECT_ROOT}
|
| 5 |
+
|
| 6 |
+
scratch_dir: ${oc.env:PROJECT_ROOT}
|
| 7 |
+
|
| 8 |
+
# path to data directory
|
| 9 |
+
data_dir: ${local.scratch_dir}/data/
|
| 10 |
+
|
| 11 |
+
# path to logging directory
|
| 12 |
+
log_dir: ${local.scratch_dir}/logs/
|
conditional-flow-matching/runner/configs/logger/comet.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://www.comet.ml
|
| 2 |
+
|
| 3 |
+
comet:
|
| 4 |
+
_target_: pytorch_lightning.loggers.comet.CometLogger
|
| 5 |
+
api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable
|
| 6 |
+
save_dir: "${paths.output_dir}"
|
| 7 |
+
project_name: "lightning-hydra-template"
|
| 8 |
+
rest_api_key: null
|
| 9 |
+
# experiment_name: ""
|
| 10 |
+
experiment_key: null # set to resume experiment
|
| 11 |
+
offline: False
|
| 12 |
+
prefix: ""
|
conditional-flow-matching/runner/configs/logger/csv.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# csv logger built in lightning
|
| 2 |
+
|
| 3 |
+
csv:
|
| 4 |
+
_target_: pytorch_lightning.loggers.csv_logs.CSVLogger
|
| 5 |
+
save_dir: "${paths.output_dir}"
|
| 6 |
+
name: "csv/"
|
| 7 |
+
prefix: ""
|
conditional-flow-matching/runner/configs/logger/many_loggers.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# train with many loggers at once
|
| 2 |
+
|
| 3 |
+
defaults:
|
| 4 |
+
# - comet.yaml
|
| 5 |
+
- csv.yaml
|
| 6 |
+
# - mlflow.yaml
|
| 7 |
+
# - neptune.yaml
|
| 8 |
+
- tensorboard.yaml
|
| 9 |
+
- wandb.yaml
|
conditional-flow-matching/runner/configs/logger/mlflow.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://mlflow.org
|
| 2 |
+
|
| 3 |
+
mlflow:
|
| 4 |
+
_target_: pytorch_lightning.loggers.mlflow.MLFlowLogger
|
| 5 |
+
# experiment_name: ""
|
| 6 |
+
# run_name: ""
|
| 7 |
+
tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI
|
| 8 |
+
tags: null
|
| 9 |
+
# save_dir: "./mlruns"
|
| 10 |
+
prefix: ""
|
| 11 |
+
artifact_location: null
|
| 12 |
+
# run_id: ""
|
conditional-flow-matching/runner/configs/logger/neptune.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://neptune.ai
|
| 2 |
+
|
| 3 |
+
neptune:
|
| 4 |
+
_target_: pytorch_lightning.loggers.neptune.NeptuneLogger
|
| 5 |
+
api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable
|
| 6 |
+
project: username/lightning-hydra-template
|
| 7 |
+
# name: ""
|
| 8 |
+
log_model_checkpoints: True
|
| 9 |
+
prefix: ""
|
conditional-flow-matching/runner/configs/logger/tensorboard.yaml
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://www.tensorflow.org/tensorboard/
|
| 2 |
+
|
| 3 |
+
tensorboard:
|
| 4 |
+
_target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger
|
| 5 |
+
save_dir: "${paths.output_dir}/tensorboard/"
|
| 6 |
+
name: null
|
| 7 |
+
log_graph: False
|
| 8 |
+
default_hp_metric: True
|
| 9 |
+
prefix: ""
|
| 10 |
+
# version: ""
|
conditional-flow-matching/runner/configs/logger/wandb.yaml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://wandb.ai
|
| 2 |
+
|
| 3 |
+
wandb:
|
| 4 |
+
_target_: pytorch_lightning.loggers.wandb.WandbLogger
|
| 5 |
+
# name: "" # name of the run (normally generated by wandb)
|
| 6 |
+
save_dir: "${paths.output_dir}"
|
| 7 |
+
offline: False
|
| 8 |
+
id: null # pass correct id to resume experiment!
|
| 9 |
+
anonymous: null # enable anonymous logging
|
| 10 |
+
project: "conditional-flow-model"
|
| 11 |
+
log_model: False # upload lightning ckpts
|
| 12 |
+
prefix: "" # a string to put at the beginning of metric keys
|
| 13 |
+
# entity: "" # set to name of your wandb team
|
| 14 |
+
group: ""
|
| 15 |
+
tags: []
|
| 16 |
+
job_type: ""
|