---
license: mit
library_name: optax
tags:
- optimizer
- learned-optimizer
- meta-learning
- jax
---

# Celo2: Towards Learned Optimization Free Lunch


Official pretrained weights for **Celo2**. This variant applies (and is meta-trained with) a harness that adds Newton-Schulz orthogonalization on top of the learned update for matrix parameters and uses AdamW for biases and embeddings. For a fully learned variant without any harness, see [celo2-base](https://huggingface.co/amoudgl/celo2-base).

## Quickstart

Install the package and download the checkpoint:

```bash
pip install git+https://github.com/amoudgl/celo2.git
hf download amoudgl/celo2 --local-dir ./celo2
```

Use the `load_checkpoint` method to fetch pretrained params from the checkpoint path:

```python
from celo2_optax import load_checkpoint

pretrained_params = load_checkpoint('./celo2/theta.state')
```

Standard optax usage with the `scale_by_celo2` method, which takes the pretrained params as input:

```python
import jax
import optax

from celo2_optax import scale_by_celo2

optimizer = optax.multi_transform(
    transforms={
        'celo2': optax.chain(
            scale_by_celo2(pretrained_params, orthogonalize=True),
            optax.add_decayed_weights(weight_decay),
            optax.scale_by_learning_rate(lr_schedule),
        ),
        'adam': optax.adamw(lr_schedule, 0.9, 0.95, weight_decay=weight_decay),
    },
    param_labels=lambda params: jax.tree.map_with_path(
        lambda path, val: 'adam'
        if val.ndim <= 1 or 'embed' in jax.tree_util.keystr(path)
        else 'celo2',
        params,
    ),
)
```

Here `weight_decay` and `lr_schedule` are your chosen weight decay coefficient and learning-rate schedule; the labeling function routes 1-D parameters (biases) and anything on an `embed` path to AdamW, and matrix parameters to the learned optimizer.

## Loading and inspecting MLP update rule weights

```python
import jax

from celo2_optax import load_checkpoint

pretrained_params = load_checkpoint('./celo2/theta.state')  # dictionary containing weights
print(jax.tree.map(lambda x: x.shape, pretrained_params))
```

The checkpoint contains a small MLP stored under the `ff_mod_stack` key, with weight matrices (`w0__*`, `w1`, `w2`) and biases (`b0`, `b1`, `b2`). Each `w0__*` key holds the weights corresponding to a particular input feature such as momentum, gradient, or parameter value.
## Meta-training config

| Key                        | Value                                                        |
| -------------------------- | ------------------------------------------------------------ |
| **Optimizer architecture** | MLP, 2 hidden layers, 8 units each                           |
| **Meta-training tasks**    | 4 image classification tasks (MNIST, FMNIST, CIFAR-10, SVHN) |
| **Task architecture**      | MLP (64-32-10)                                               |
| **Meta-trainer**           | Persistent Evolution Strategies (PES)                        |
| **Outer iterations**       | 100K                                                         |
| **Truncation length**      | 50                                                           |
| **Min unroll length**      | 100                                                          |
| **Max unroll length**      | 2000                                                         |

For more details, see the config JSON included in the repo [here](./config.json).

## Files

| File          | Description                      |
| ------------- | -------------------------------- |
| `theta.state` | Pretrained MLP optimizer weights |
| `config.json` | Meta-training configuration      |

## Citation

```bibtex
@misc{moudgil2026celo2,
  title={Celo2: Towards Learned Optimization Free Lunch},
  author={Abhinav Moudgil and Boris Knyazev and Eugene Belilovsky},
  year={2026},
  eprint={2602.19142},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2602.19142},
}
```