{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "j_LlXHYcmRaC"
},
"source": [
"# Gradient Accumulation\n",
"\n",
"[](https://colab.sandbox.google.com/github/google-deepmind/optax/blob/main/examples/gradient_accumulation.ipynb)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vQro77whXULU"
},
"source": [
"_Gradient accumulation_ is a technique where the gradients for several consecutive optimization steps are combined together, so that they can be applied at regular repeating intervals.\n",
"\n",
"One example where this is useful is to simulate training with a larger batch size than would fit into the available device memory. Another example is in the context of multi-task learning, where batches for different tasks may be visited in a round-robin fashion. Gradient accumulation makes it possible to simulate training on one large batch containing all of the tasks together.\n",
"\n",
"In this example, we give an example of implementing gradient accumulation using `optax.MultiSteps`. We start by bringing in some imports and defining some type annotations."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9cu0kFNrnJj7"
},
"outputs": [],
"source": [
"from typing import Iterable\n",
"\n",
"import flax.linen as nn\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import optax\n",
"import chex\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TIKfATeXnW3B"
},
"source": [
"The following implements a network and loss function that could be used in an image classification problem."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "bJ1RWa4rnZmR"
},
"outputs": [],
"source": [
"class MLP(nn.Module):\n",
" \"\"\"A simple multilayer perceptron model.\"\"\"\n",
"\n",
" @nn.compact\n",
" def __call__(self, x):\n",
" # Flattens inputs in the batch.\n",
" x = x.reshape((x.shape[0], -1))\n",
" x = nn.Dense(features=512)(x)\n",
" x = nn.relu(x)\n",
" x = nn.Dense(features=512)(x)\n",
" x = nn.relu(x)\n",
" x = nn.Dense(features=10)(x)\n",
" return x\n",
"\n",
"net = MLP()\n",
"\n",
"def loss_fn(params, batch):\n",
" \"\"\"Computes loss over a mini-batch.\n",
" \"\"\"\n",
" logits = net.apply(params, batch['image'])\n",
" loss = optax.softmax_cross_entropy_with_integer_labels(\n",
" logits=logits, labels=batch['label']\n",
" ).mean()\n",
" return loss"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FDq-pRJGnksx"
},
"source": [
"We implement a training loop to perform gradient descent as follows."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uqKt4aBJXiBj"
},
"outputs": [],
"source": [
"def build_train_step(optimizer: optax.GradientTransformation):\n",
" \"\"\"Builds a function for executing a single step in the optimization.\"\"\"\n",
"\n",
" @jax.jit\n",
" def update(params, opt_state, batch):\n",
" grads = jax.grad(loss_fn)(params, batch)\n",
" updates, opt_state = optimizer.update(grads, opt_state)\n",
" params = optax.apply_updates(params, updates)\n",
" return params, opt_state\n",
"\n",
" return update\n",
"\n",
"\n",
"def fit(\n",
" optimizer: optax.GradientTransformation,\n",
" params: optax.Params,\n",
" batches: Iterable[dict[str, jnp.ndarray]],\n",
") -\u003e optax.Params:\n",
" \"\"\"Executes a train loop over the train batches using the given optimizer.\"\"\"\n",
"\n",
" train_step = build_train_step(optimizer)\n",
" opt_state = optimizer.init(params)\n",
"\n",
" for batch in batches:\n",
" params, opt_state = train_step(params, opt_state, batch)\n",
"\n",
" return params"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pTaorGnceOGs"
},
"source": [
"The following generates some random image-like data to test with our networks. The shapes used here correspond to the shapes that might appear in an MNIST classifier.\n",
"\n",
"We also initialize some parameters and a base optimizer to share through the following examples."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yK75QHOML7h9"
},
"outputs": [],
"source": [
"EXAMPLES = jax.random.uniform(jax.random.PRNGKey(0), (9, 28, 28, 1))\n",
"LABELS = jax.random.randint(jax.random.PRNGKey(0), (9,), minval=0, maxval=10)\n",
"\n",
"optimizer = optax.sgd(1e-4)\n",
"params = net.init(jax.random.PRNGKey(0), EXAMPLES)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "luk1TPW6efgQ"
},
"source": [
"## Splitting updates for one batch over multiple steps\n",
"\n",
"\n",
"The following two snippets will compute numerically identical results, but with the difference that the second snippet will use gradient accumulation over three batches to mimic the first snippet, which performs a single step with one large batch.\n",
"\n",
"We start with the snippet that runs a training loop over a single batch containing all examples,"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hyykkSEio0Tx"
},
"outputs": [],
"source": [
"new_params_single_batch = fit(\n",
" optimizer,\n",
" params,\n",
" batches=[dict(image=EXAMPLES, label=LABELS),]\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7qIpPp0Jo6WT"
},
"source": [
"In this second snippet, our training loop will execute three training steps that together also contain all of the examples. In this case, the optimizer is wrapped with `optax.MultiSteps`, with `every_k_schedule=3`. This means that instead of applying gradient updates directly, the raw gradients will be combined together until the third step, where the wrapped optimizer will be applied to the average over the raw gradients seen up until now. For the \"interim\" steps, the updates returned by the optimizer will be all-zeros, resulting in no change to the parameters during these steps."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pV1yZRxio2LS"
},
"outputs": [],
"source": [
"new_params_gradient_accumulation = fit(\n",
" optax.MultiSteps(optimizer, every_k_schedule=3),\n",
" params,\n",
" batches=[\n",
" dict(image=EXAMPLES[0:3], label=LABELS[0:3]),\n",
" dict(image=EXAMPLES[3:6], label=LABELS[3:6]),\n",
" dict(image=EXAMPLES[6:9], label=LABELS[6:9]),\n",
" ],\n",
")"
]
},
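{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the behavior of the interim steps concrete, here is a minimal sketch that feeds three hand-written gradients through `optax.MultiSteps` and prints the update returned at each step. The names `demo_params`, `demo_grads`, and `demo_optimizer` are illustrative assumptions and are not used elsewhere in this notebook. With plain SGD at a learning rate of 1.0, we would expect all-zero updates on the first two steps and the negated average gradient on the third."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"import optax\n",
"\n",
"# Illustrative toy parameters and gradients (assumptions for this sketch only).\n",
"demo_params = {'w': jnp.zeros(3)}\n",
"demo_grads = [\n",
"    {'w': jnp.array([1.0, 2.0, 3.0])},\n",
"    {'w': jnp.array([3.0, 2.0, 1.0])},\n",
"    {'w': jnp.array([2.0, 2.0, 2.0])},\n",
"]\n",
"\n",
"demo_optimizer = optax.MultiSteps(optax.sgd(1.0), every_k_schedule=3)\n",
"demo_state = demo_optimizer.init(demo_params)\n",
"\n",
"for step, demo_grad in enumerate(demo_grads):\n",
"  # Interim steps should return all-zero updates; the last step should return\n",
"  # the inner optimizer's update computed from the averaged gradient.\n",
"  demo_updates, demo_state = demo_optimizer.update(\n",
"      demo_grad, demo_state, demo_params\n",
"  )\n",
"  print(step, demo_updates['w'])"
]
},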
{
"cell_type": "markdown",
"metadata": {
"id": "gu8JnqgKo9Jq"
},
"source": [
"We can now verify that both training loops compute identical results as follows."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "X2hWzwFkK43k"
},
"outputs": [],
"source": [
"chex.assert_trees_all_close(\n",
" new_params_single_batch,\n",
" new_params_gradient_accumulation,\n",
" atol=1e-7,\n",
")"
]
},
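{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to see why the two loops agree, as a sketch under the assumptions that the loss is a mean over examples and that the mini-batches $B_1, \\dots, B_k$ are disjoint, equally sized pieces of the full batch $B$ (here $k = 3$ and $|B| = 9$): the parameters $\\theta$ do not change during the interim steps, so every mini-batch gradient is evaluated at the same $\\theta$, and their average equals the full-batch gradient.\n",
"\n",
"$$\\frac{1}{k} \\sum_{j=1}^{k} \\nabla_\\theta L(\\theta; B_j) = \\frac{1}{k} \\sum_{j=1}^{k} \\frac{k}{|B|} \\sum_{i \\in B_j} \\nabla_\\theta \\ell_i(\\theta) = \\frac{1}{|B|} \\sum_{i \\in B} \\nabla_\\theta \\ell_i(\\theta) = \\nabla_\\theta L(\\theta; B)$$\n",
"\n",
"Here $\\ell_i$ denotes the per-example loss and $L(\\theta; B)$ its mean over a batch; these symbols are introduced only for this illustration. If the mini-batches had different sizes, a plain average of their gradients would in general no longer match the full-batch gradient."
]
},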
{
"cell_type": "markdown",
"metadata": {
"id": "Ub0GHPvvhIKI"
},
"source": [
"## Interaction of `optax.MultiStep` with schedules.\n",
"\n",
"The snippet below is identical to the snippet above, except we additionally introduce a learning rate schedule. As above, the second call to `fit` is using gradient accumulation. Similarly to before, we find that both train loops compute compute identical outputs (up to numerical errors).\n",
"\n",
"This happens because the learning rate schedule in `optax.MultiStep` is only updated once for each of the _outer_ steps. In particular, the state of the inner optimizer is only updated each time `every_k_schedule` optimizer steps have been taken."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "o9CS96VjMuON"
},
"outputs": [],
"source": [
"learning_rate_schedule = optax.piecewise_constant_schedule(\n",
" init_value=1.0,\n",
" boundaries_and_scales={\n",
" 0: 1e-4,\n",
" 1: 1e-1,\n",
" },\n",
")\n",
"\n",
"optimizer = optax.sgd(learning_rate_schedule)\n",
"\n",
"new_params_single_batch = fit(\n",
" optimizer,\n",
" params,\n",
" batches=[\n",
" dict(image=EXAMPLES, label=LABELS),\n",
" ],\n",
")\n",
"\n",
"new_params_gradient_accumulation = fit(\n",
" optax.MultiSteps(optimizer, every_k_schedule=3),\n",
" params,\n",
" batches=[\n",
" dict(image=EXAMPLES[0:3], label=LABELS[0:3]),\n",
" dict(image=EXAMPLES[3:6], label=LABELS[3:6]),\n",
" dict(image=EXAMPLES[6:9], label=LABELS[6:9]),\n",
" ],\n",
")\n",
"\n",
"chex.assert_trees_all_close(\n",
" new_params_single_batch,\n",
" new_params_gradient_accumulation,\n",
" atol=1e-7,\n",
")"
]
}
],
"metadata": {
"colab": {
"last_runtime": {
"build_target": "//learning/grp/tools/ml_python:ml_notebook",
"kind": "private"
},
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}