---
license: mit
library_name: optax
tags:
- optimizer
- learned-optimizer
- meta-learning
- jax
---

# Celo2-base

[Paper](https://arxiv.org/abs/2602.19142) | [Code](https://github.com/amoudgl/celo2) | License: MIT

Official pretrained weights for the **Celo2-base** learned update rule. This variant applies the learned update rule to all parameters, without any optimization harness. For better performance, see [celo2](https://huggingface.co/amoudgl/celo2), which uses Newton-Schulz orthogonalization and applies AdamW to biases and embeddings.

## Quickstart

Install the package and download the checkpoint:

```bash
pip install git+https://github.com/amoudgl/celo2.git
hf download amoudgl/celo2-base --local-dir ./celo2-base
```

Use the `load_checkpoint` function to load the pretrained params from the checkpoint path:

```python
from celo2_optax import load_checkpoint

pretrained_params = load_checkpoint('./celo2-base/theta.state')
```

Standard optax usage, with `scale_by_celo2` taking the pretrained params as input:

```python
import optax
from celo2_optax import scale_by_celo2

# weight_decay (float) and lr_schedule (float or optax schedule) are hyperparameters you choose
optimizer = optax.chain(
    scale_by_celo2(pretrained_params, orthogonalize=False),  # Celo2-base: no orthogonalization
    optax.add_decayed_weights(weight_decay),
    optax.scale_by_learning_rate(lr_schedule),
)

# Training step: pass params to update(), since add_decayed_weights needs them
# opt_state = optimizer.init(params)
# updates, opt_state = optimizer.update(grads, opt_state, params)
# params = optax.apply_updates(params, updates)
```

## Loading and inspecting MLP update rule weights

```python
import jax
from celo2_optax import load_checkpoint

pretrained_params = load_checkpoint('./celo2-base/theta.state')  # dictionary of weights
print(jax.tree.map(lambda x: x.shape, pretrained_params))  # shape of every weight array
```

The checkpoint contains a small MLP stored under the `ff_mod_stack` key, with weight matrices (`w0__*`, `w1`, `w2`) and biases (`b0`, `b1`, `b2`). Each `w0__*` key holds the input-layer weights for one input feature, such as momentum, gradient, or parameter value.
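To make the per-feature `w0__*` layout concrete, here is a minimal sketch (not the Celo2 implementation) of how an MLP with a split input layer can map per-parameter features to an output. Only the key names (`ff_mod_stack`, `w0__*`, `w1`, `w2`, `b0`, `b1`, `b2`) and the two hidden layers of 8 units come from this card. The feature names (`grad`, `momentum`, `param`), their sizes, the ReLU activation, and the single output per parameter are hypothetical choices for illustration. The sketch shows that summing one projection per `w0__<feature>` matrix is the same as a single matmul over the concatenated features, so storing the input layer per feature loses nothing.

```python
import jax
import jax.numpy as jnp

HIDDEN = 8  # hidden width, from the meta-training config (2 hidden layers, 8 units)


def init_mlp(key, feature_dims, out_dim=1):
    """Build a synthetic params tree with the checkpoint's key layout.

    feature_dims maps a (hypothetical) feature name to its size per parameter.
    """
    keys = jax.random.split(key, len(feature_dims) + 2)
    p = {}
    for k, name in zip(keys, sorted(feature_dims)):
        p[f"w0__{name}"] = 0.1 * jax.random.normal(k, (feature_dims[name], HIDDEN))
    p["b0"] = jnp.zeros(HIDDEN)
    p["w1"] = 0.1 * jax.random.normal(keys[-2], (HIDDEN, HIDDEN))
    p["b1"] = jnp.zeros(HIDDEN)
    p["w2"] = 0.1 * jax.random.normal(keys[-1], (HIDDEN, out_dim))
    p["b2"] = jnp.zeros(out_dim)
    return {"ff_mod_stack": p}


def first_layer(p, features):
    # One projection per input feature, summed: sum_f x_f @ W0_f + b0
    return p["b0"] + sum(features[n] @ p[f"w0__{n}"] for n in sorted(features))


def first_layer_concat(p, features):
    # Equivalent single matmul: concat(x_f) @ stack(W0_f) + b0
    names = sorted(features)
    x = jnp.concatenate([features[n] for n in names], axis=-1)
    w0 = jnp.concatenate([p[f"w0__{n}"] for n in names], axis=0)
    return x @ w0 + p["b0"]


def apply_mlp(tree, features):
    """Map features of shape (num_params, dim) to one output per parameter."""
    p = tree["ff_mod_stack"]
    h = jax.nn.relu(first_layer(p, features))  # hidden layer 1
    h = jax.nn.relu(h @ p["w1"] + p["b1"])  # hidden layer 2
    return h @ p["w2"] + p["b2"]  # output layer


dims = {"grad": 1, "momentum": 3, "param": 1}  # hypothetical features
tree = init_mlp(jax.random.PRNGKey(0), dims)
print(jax.tree.map(lambda x: x.shape, tree))
feats = {n: jnp.ones((5, d)) for n, d in dims.items()}  # 5 scalar params
print(apply_mlp(tree, feats).shape)  # (5, 1)
```

The shape summary printed for the synthetic tree has the same form as the inspection snippet's output, so comparing the two is a quick way to see which input features and sizes the real checkpoint actually uses.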
## Meta-training config

| Key | Value |
| ----------------------- | ------------------------------------------------------------ |
| **Optimizer architecture** | MLP, 2 hidden layers, 8 units each |
| **Meta-training tasks** | 4 image classification tasks (MNIST, FMNIST, CIFAR-10, SVHN) |
| **Task architecture** | MLP (64-32-10) |
| **Meta-trainer** | Persistent Evolution Strategies (PES) |
| **Outer iterations** | 100K |
| **Truncation length** | 50 inner steps |
| **Min unroll length** | 100 inner steps |
| **Max unroll length** | 2000 inner steps |

For more details, see the config JSON included in this repo [here](./config.json).

## Files

| File | Description |
| ------------- | -------------------------------- |
| `theta.state` | Pretrained MLP optimizer weights |
| `config.json` | Meta-training configuration |

## Citation

```bibtex
@misc{moudgil2026celo2,
  title={Celo2: Towards Learned Optimization Free Lunch},
  author={Abhinav Moudgil and Boris Knyazev and Eugene Belilovsky},
  year={2026},
  eprint={2602.19142},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2602.19142},
}
```