File size: 9,634 Bytes
fc0f7bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "j_LlXHYcmRaC"
      },
      "source": [
        "# MLP MNIST\n",
        "\n",
        "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.sandbox.google.com/github/google-deepmind/optax/blob/main/examples/mlp_mnist.ipynb)\n",
        "\n",
        "This notebook trains a simple Multilayer Perceptron (MLP) classifier for hand-written digit recognition (MNIST dataset)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9cu0kFNrnJj7"
      },
      "outputs": [],
      "source": [
        "from typing import Sequence\n",
        "\n",
        "from flax import linen as nn\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import optax\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import tensorflow_datasets as tfds"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2Adl_l_uZs1d"
      },
      "outputs": [],
      "source": [
        "# @markdown The learning rate for the optimizer:\n",
        "LEARNING_RATE = 0.002 # @param{type:\"number\"}\n",
        "# @markdown Number of samples in each batch:\n",
        "BATCH_SIZE = 128 # @param{type:\"integer\"}\n",
        "# @markdown Total number of epochs to train for:\n",
        "N_EPOCHS = 1 # @param{type:\"integer\"}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZZej3FcOhuRE"
      },
      "source": [
        "MNIST is a dataset of 28x28 images with 1 channel. We now load the dataset using `tensorflow_datasets`, apply min-max normalization to the images, shuffle the data in the train set and create batches of size `BATCH_SIZE`.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xPZ0paOehHWg"
      },
      "outputs": [],
      "source": [
        "(train_loader, test_loader), info = tfds.load(\n",
        "    \"mnist\", split=[\"train\", \"test\"], as_supervised=True, with_info=True\n",
        ")\n",
        "\n",
        "min_max_rgb = lambda image, label: (tf.cast(image, tf.float32) / 255., label)\n",
        "train_loader = train_loader.map(min_max_rgb)\n",
        "test_loader = test_loader.map(min_max_rgb)\n",
        "\n",
        "NUM_CLASSES = info.features[\"label\"].num_classes\n",
        "IMG_SIZE = info.features[\"image\"].shape\n",
        "\n",
        "train_loader_batched = train_loader.shuffle(\n",
        "    buffer_size=10_000, reshuffle_each_iteration=True\n",
        ").batch(BATCH_SIZE, drop_remainder=True)\n",
        "\n",
        "test_loader_batched = test_loader.batch(BATCH_SIZE, drop_remainder=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XkLaC2MlbAqa"
      },
      "source": [
        "The data is ready! Next let's define a model. Optax is agnostic to which (if any) neural network library is used. Here we use Flax to implement a simple MLP."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RppusWrcaXzX"
      },
      "outputs": [],
      "source": [
        "class MLP(nn.Module):\n",
        "  \"\"\"A simple multilayer perceptron model for image classification.\"\"\"\n",
        "  hidden_sizes: Sequence[int] = (1000, 1000)\n",
        "\n",
        "  @nn.compact\n",
        "  def __call__(self, x):\n",
        "    # Flattens images in the batch.\n",
        "    x = x.reshape((x.shape[0], -1))\n",
        "    x = nn.Dense(features=self.hidden_sizes[0])(x)\n",
        "    x = nn.relu(x)\n",
        "    x = nn.Dense(features=self.hidden_sizes[1])(x)\n",
        "    x = nn.relu(x)\n",
        "    x = nn.Dense(features=NUM_CLASSES)(x)\n",
        "    return x"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DKOi55MgdPyp"
      },
      "outputs": [],
      "source": [
        "net = MLP()\n",
        "\n",
        "@jax.jit\n",
        "def predict(params, inputs):\n",
        "  return net.apply({\"params\": params}, inputs)\n",
        "\n",
        "\n",
        "@jax.jit\n",
        "def loss_accuracy(params, data):\n",
        "  \"\"\"Computes loss and accuracy over a mini-batch.\n",
        "\n",
        "  Args:\n",
        "    params: parameters of the model.\n",
        "    data: tuple of (inputs, labels).\n",
        "\n",
        "  Returns:\n",
        "    A tuple `(loss, aux)`, where `loss` is the mean softmax cross-entropy\n",
        "    over the mini-batch and `aux` is a dict whose `accuracy` entry holds\n",
        "    the mini-batch accuracy.\n",
        "  \"\"\"\n",
        "  inputs, labels = data\n",
        "  logits = predict(params, inputs)\n",
        "  loss = optax.softmax_cross_entropy_with_integer_labels(\n",
        "      logits=logits, labels=labels\n",
        "  ).mean()\n",
        "  accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)\n",
        "  return loss, {\"accuracy\": accuracy}\n",
        "\n",
        "@jax.jit\n",
        "def update_model(state, grads):\n",
        "  return state.apply_gradients(grads=grads)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0eB2dhIpjTIi"
      },
      "source": [
        "Next we need to initialize the network parameters and the solver state. We also define a convenience function `dataset_stats` that we'll call once before training and then once per epoch to collect the loss and accuracy of the model over the test set."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PBnbq7gui34L"
      },
      "outputs": [],
      "source": [
        "solver = optax.adam(LEARNING_RATE)\n",
        "rng = jax.random.PRNGKey(0)\n",
        "dummy_data = jnp.ones((1,) + IMG_SIZE, dtype=jnp.float32)\n",
        "\n",
        "params = net.init({\"params\": rng}, dummy_data)[\"params\"]\n",
        "\n",
        "solver_state = solver.init(params)\n",
        "\n",
        "def dataset_stats(params, data_loader):\n",
        "  \"\"\"Computes loss and accuracy over the dataset `data_loader`.\"\"\"\n",
        "  all_accuracy = []\n",
        "  all_loss = []\n",
        "  for batch in data_loader.as_numpy_iterator():\n",
        "    batch_loss, batch_aux = loss_accuracy(params, batch)\n",
        "    all_loss.append(batch_loss)\n",
        "    all_accuracy.append(batch_aux[\"accuracy\"])\n",
        "  return {\"loss\": np.mean(all_loss), \"accuracy\": np.mean(all_accuracy)}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4H6GWNJf0XTY"
      },
      "source": [
        "Finally, we do the actual training. The next cell trains the model for `N_EPOCHS` epochs. Within each epoch we iterate over the batched loader `train_loader_batched`, and once per epoch we also compute the test set accuracy and loss."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DeQr0urBjoDj"
      },
      "outputs": [],
      "source": [
        "train_accuracy = []\n",
        "train_losses = []\n",
        "\n",
        "# Computes test set accuracy at initialization.\n",
        "test_stats = dataset_stats(params, test_loader_batched)\n",
        "test_accuracy = [test_stats[\"accuracy\"]]\n",
        "test_losses = [test_stats[\"loss\"]]\n",
        "\n",
        "\n",
        "@jax.jit\n",
        "def train_step(params, solver_state, batch):\n",
        "  # Performs a one step update.\n",
        "  (loss, aux), grad = jax.value_and_grad(loss_accuracy, has_aux=True)(\n",
        "      params, batch\n",
        "  )\n",
        "  updates, solver_state = solver.update(grad, solver_state, params)\n",
        "  params = optax.apply_updates(params, updates)\n",
        "  return params, solver_state, loss, aux\n",
        "\n",
        "\n",
        "for epoch in range(N_EPOCHS):\n",
        "  train_accuracy_epoch = []\n",
        "  train_losses_epoch = []\n",
        "\n",
        "  for step, train_batch in enumerate(train_loader_batched.as_numpy_iterator()):\n",
        "    params, solver_state, train_loss, train_aux = train_step(\n",
        "        params, solver_state, train_batch\n",
        "    )\n",
        "    train_accuracy_epoch.append(train_aux[\"accuracy\"])\n",
        "    train_losses_epoch.append(train_loss)\n",
        "    if step % 20 == 0:\n",
        "      print(\n",
        "          f\"step {step}, train loss: {train_loss:.2e}, train accuracy:\"\n",
        "          f\" {train_aux['accuracy']:.2f}\"\n",
        "      )\n",
        "\n",
        "  test_stats = dataset_stats(params, test_loader_batched)\n",
        "  test_accuracy.append(test_stats[\"accuracy\"])\n",
        "  test_losses.append(test_stats[\"loss\"])\n",
        "  train_accuracy.append(np.mean(train_accuracy_epoch))\n",
        "  train_losses.append(np.mean(train_losses_epoch))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yyS1oRZBtytP"
      },
      "outputs": [],
      "source": [
        "f\"Improved accuracy on test DS from {test_accuracy[0]} to {test_accuracy[-1]}\""
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}