{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3tN0poN3zZ4L"
      },
      "source": [
        "# Meta-Learning\n",
        "\n",
        "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.sandbox.google.com/github/google-deepmind/optax/blob/main/examples/meta_learning.ipynb)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wGHOAQQNWSvl"
      },
      "source": [
        "Here we investigate using optax to meta-learn the learning rate of an optax optimizer. For a concrete example, we define a model where $y$ is linearly related to $x$, with some added noise,\n",
        "\n",
        "$$y = f(x) = 10 \\cdot x + \\mathcal{N}(0, 1).$$\n",
        "\n",
        "Suppose we have access to data generated by $f(\\cdot)$, but we do not know what $f$ actually is. We can approximate $f(\\cdot)$ with a parametrized function $f(\\theta, x)$ and find a good value of $\\theta$ using gradient descent with optax.\n",
        "\n",
        "Gradient descent requires choosing hyperparameters, such as the learning rate of the optimizer. With meta-learning, we can optimize the parameters of the optimizer at the same time as the parameters of the function we are fitting.\n",
        "\n",
        "Many meta-learning algorithms take the following form:\n",
        "\n",
        "1. An inner loop performs $N$ updates to a model's parameters $\\theta$ using the current value of a hyperparameter $\\eta$.\n",
        "1. An outer loop updates the hyperparameter $\\eta$ to minimize the loss reached after the inner updates to $\\theta$, by differentiating through those inner updates.\n",
        "\n",
        "Since we are meta-learning the learning rate, which we want to keep between 0 and 1, we parametrize it as the sigmoid of the meta-parameter $\\eta$, i.e. the learning rate is $\\sigma(\\eta) = 1 / (1 + e^{-\\eta})$.\n",
        "\n",
        "In the following snippets, we solve this problem using optax. To begin, we define a generator that samples from the hidden underlying distribution."
      ]
    },
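    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Here is a quick sketch of this parametrization. It uses only `jax.nn.sigmoid` and `jax.scipy.special.logit`, the inverse of the sigmoid. Any real-valued $\\eta$ maps to a learning rate in $(0, 1)$, and the logit recovers $\\eta$ from a chosen learning rate. Since $\\mathrm{logit}(p) = \\log\\frac{p}{1 - p} = -\\log(1/p - 1)$, this is the same inverse sigmoid used later in the notebook to initialize $\\eta$. The names `etas` and `eta0` are illustrative."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "from jax.scipy.special import logit\n",
        "\n",
        "# Any real-valued meta-parameter eta is squashed into (0, 1) by the sigmoid.\n",
        "etas = jnp.array([-5.0, 0.0, 5.0])\n",
        "print(jax.nn.sigmoid(etas))\n",
        "\n",
        "# The logit (inverse sigmoid) maps a desired learning rate back to eta,\n",
        "# e.g. to start the inner optimizer at a learning rate of 0.1.\n",
        "eta0 = logit(jnp.array(0.1))\n",
        "print(eta0, jax.nn.sigmoid(eta0))"
      ]
    },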
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Z6PrJuBOebo1"
      },
      "outputs": [],
      "source": [
        "from typing import Callable, Iterator, Tuple\n",
        "import chex\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import optax"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NA3VjroNZjlo"
      },
      "outputs": [],
      "source": [
        "def generator() -\u003e Iterator[Tuple[chex.Array, chex.Array]]:\n",
        "  rng = jax.random.PRNGKey(0)\n",
        "\n",
        "  while True:\n",
        "    rng, k1, k2 = jax.random.split(rng, num=3)\n",
        "    x = jax.random.uniform(k1, minval=0.0, maxval=10.0)\n",
        "    y = 10.0 * x + jax.random.normal(k2)\n",
        "    yield x, y"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lUfPZ8FVfqc9"
      },
      "source": [
        "We can sample from the generator as follows:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tDpR8GFfftPD"
      },
      "outputs": [],
      "source": [
        "g = generator()\n",
        "\n",
        "for _ in range(5):\n",
        "  x, y = next(g)\n",
        "  print(f\"Sampled y = {y:.3f}, x = {x:.3f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J-L8Vyc9eFTy"
      },
      "source": [
        "We now define our parametrized function $f(\\theta, x)$, and choose a random initial value for the parameter $\\theta$."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "N8PNwdASeNZc"
      },
      "outputs": [],
      "source": [
        "def f(theta: chex.Array, x: chex.Array) -\u003e chex.Array:\n",
        "  return x * theta\n",
        "\n",
        "theta = jax.random.normal(jax.random.PRNGKey(42))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wUldd3_2hObs"
      },
      "source": [
        "For the inner optimizer, which fits $\\theta$, we use RMSProp. For the outer optimizer, which fits the learning rate, we use Adam. We wrap the inner optimizer with `optax.inject_hyperparams`, which stores its hyperparameters (here, the learning rate) in the optimizer state, so that the outer loop can modify the learning rate of the inner optimizer.\n",
        "\n",
        "The inner optimizer starts with a learning rate of 0.1, and the outer optimizer uses a learning rate of 0.03."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cWP2HN_KqQ6o"
      },
      "outputs": [],
      "source": [
        "init_learning_rate = jnp.array(0.1)\n",
        "meta_learning_rate = jnp.array(0.03)\n",
        "\n",
        "opt = optax.inject_hyperparams(optax.rmsprop)(learning_rate=init_learning_rate)\n",
        "meta_opt = optax.adam(learning_rate=meta_learning_rate)"
      ]
    },
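    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see what `optax.inject_hyperparams` provides, here is a small sketch with an illustrative scalar parameter and gradient. It reads the learning rate from the optimizer state, overwrites it the same way the outer loop does, and compares the updates computed from the same fresh RMSProp state. Because RMSProp's update is proportional to the learning rate, halving the learning rate halves the update. The `demo_` names are hypothetical and chosen only to avoid clobbering the notebook's own variables."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import jax.numpy as jnp\n",
        "import optax\n",
        "\n",
        "# RMSProp wrapped so that its learning rate is stored in the optimizer state.\n",
        "demo_opt = optax.inject_hyperparams(optax.rmsprop)(learning_rate=0.1)\n",
        "demo_params = jnp.array(1.0)  # illustrative parameter\n",
        "demo_grad = jnp.array(2.0)    # illustrative gradient\n",
        "\n",
        "demo_state = demo_opt.init(demo_params)\n",
        "print(demo_state.hyperparams['learning_rate'])\n",
        "\n",
        "# Update computed with the initial learning rate.\n",
        "updates_a, _ = demo_opt.update(demo_grad, demo_state)\n",
        "\n",
        "# Overwrite the injected learning rate in place, then update again\n",
        "# from the same (fresh) optimizer state.\n",
        "demo_state.hyperparams['learning_rate'] = jnp.array(0.05)\n",
        "updates_b, _ = demo_opt.update(demo_grad, demo_state)\n",
        "\n",
        "# The second update is half as large as the first.\n",
        "print(updates_a, updates_b)"
      ]
    },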
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uFG9tczunv3H"
      },
      "source": [
        "Next, we define the squared-error loss $\\frac{1}{2}(f(\\theta, x) - y)^2$ and a single inner update step, which applies the inner optimizer to the gradient of that loss."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "suvw4J9Rn9Ju"
      },
      "outputs": [],
      "source": [
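        "def loss(theta, x, y):\n",
        "  # Squared error 0.5 * (f(theta, x) - y) ** 2; l2_loss takes (predictions, targets).\n",
        "  return optax.l2_loss(f(theta, x), y)\n",
        "\n",
        "\n",
        "def step(theta, state, x, y):\n",
        "  # One inner update: compute the gradient and apply the inner optimizer to it.\n",
        "  grad = jax.grad(loss)(theta, x, y)\n",
        "  updates, state = opt.update(grad, state)\n",
        "  theta = optax.apply_updates(theta, updates)\n",
        "  return theta, state"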
      ]
    },
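    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a sanity check on the inner step, here is a minimal sketch that compares `jax.grad` of the squared-error loss $\\frac{1}{2}(\\theta x - y)^2$ with its analytic gradient $(\\theta x - y)\\,x$. It uses the illustrative values $\\theta = 2$, $x = 3$, $y = 30$. The loss is redefined locally under the hypothetical name `demo_loss` so the cell is self-contained."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import optax\n",
        "\n",
        "def demo_loss(theta, x, y):\n",
        "  # 0.5 * (theta * x - y) ** 2, the same form as the notebook's loss.\n",
        "  return optax.l2_loss(x * theta, y)\n",
        "\n",
        "theta0, x0, y0 = jnp.array(2.0), jnp.array(3.0), jnp.array(30.0)\n",
        "auto_grad = jax.grad(demo_loss)(theta0, x0, y0)\n",
        "manual_grad = (theta0 * x0 - y0) * x0  # analytic derivative w.r.t. theta\n",
        "\n",
        "print(auto_grad, manual_grad)  # both -72.0"
      ]
    },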
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cTbJC8h1oGe6"
      },
      "source": [
        "For the meta-learning part of the problem, we build an _outer_ loss and an _outer_ step on top of the inner step.\n",
        "\n",
        "The outer loss first writes the learning rate, the sigmoid of $\\eta$, into the inner optimizer's state. It then runs the inner step on every sample except the last, and finally evaluates the loss on the last sample. The outer step differentiates this loss with respect to $\\eta$ through all of the inner updates and applies an Adam update to $\\eta$.\n",
        "\n",
        "We wrap both functions in `jax.jit` so that JAX compiles them."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2KojrF7FiJYD"
      },
      "outputs": [],
      "source": [
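        "@jax.jit\n",
        "def outer_loss(eta, theta, state, samples):\n",
        "  # Overwrite the inner learning rate with sigmoid(eta), which lies in (0, 1).\n",
        "  state.hyperparams['learning_rate'] = jax.nn.sigmoid(eta)\n",
        "\n",
        "  # Inner loop: update theta on every sample except the last.\n",
        "  for x, y in samples[:-1]:\n",
        "    theta, state = step(theta, state, x, y)\n",
        "\n",
        "  # Evaluate the updated theta on the held-out last sample.\n",
        "  x, y = samples[-1]\n",
        "\n",
        "  return loss(theta, x, y), (theta, state)\n",
        "\n",
        "\n",
        "@jax.jit\n",
        "def outer_step(eta, theta, meta_state, state, samples):\n",
        "  # Differentiate the outer loss w.r.t. eta through all inner updates;\n",
        "  # has_aux=True also returns the updated theta and inner state.\n",
        "  grad, (theta, state) = jax.grad(\n",
        "      outer_loss, has_aux=True)(eta, theta, state, samples)\n",
        "\n",
        "  # Outer update of eta with Adam.\n",
        "  meta_updates, meta_state = meta_opt.update(grad, meta_state)\n",
        "  eta = optax.apply_updates(eta, meta_updates)\n",
        "\n",
        "  return eta, theta, meta_state, state"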
      ]
    },
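    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The following minimal sketch makes differentiating through the inner updates concrete. It replaces RMSProp with plain SGD, an assumption made only so the hypergradient has a short closed form, and it uses illustrative sample values. With one inner step $\\theta_1 = \\theta_0 - \\alpha g_0$, where $g_0 = (\\theta_0 x_0 - y_0)\\,x_0$, and outer loss $L = \\frac{1}{2}(\\theta_1 x_1 - y_1)^2$, the chain rule gives $\\partial L / \\partial \\alpha = -(\\theta_1 x_1 - y_1)\\,x_1\\,g_0$. `jax.grad` reproduces this value. The outer step above relies on the same mechanism, only with several RMSProp steps in the inner loop. The function names here are hypothetical."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import jax\n",
        "\n",
        "def sq_loss(theta, x, y):\n",
        "  return 0.5 * (theta * x - y) ** 2\n",
        "\n",
        "def sgd_outer_loss(alpha, theta0, inner_sample, outer_sample):\n",
        "  # One inner SGD step on the first sample; theta1 depends on alpha.\n",
        "  x0, y0 = inner_sample\n",
        "  theta1 = theta0 - alpha * jax.grad(sq_loss)(theta0, x0, y0)\n",
        "  # Outer loss: evaluate the updated parameter on a second sample.\n",
        "  x1, y1 = outer_sample\n",
        "  return sq_loss(theta1, x1, y1)\n",
        "\n",
        "alpha, theta0 = 0.01, 1.0\n",
        "inner_sample, outer_sample = (2.0, 20.0), (3.0, 30.0)\n",
        "\n",
        "# Hypergradient computed by differentiating through the inner step.\n",
        "hypergrad = jax.grad(sgd_outer_loss)(alpha, theta0, inner_sample, outer_sample)\n",
        "\n",
        "# Closed-form hypergradient from the chain rule.\n",
        "g0 = (theta0 * 2.0 - 20.0) * 2.0\n",
        "theta1 = theta0 - alpha * g0\n",
        "manual = -(theta1 * 3.0 - 30.0) * 3.0 * g0\n",
        "\n",
        "print(hypergrad, manual)  # both approximately -2799.36"
      ]
    },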
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wr7oTKoxoVSE"
      },
      "source": [
        "Putting all of the code above together, we jointly fit $\\theta$ and the learning rate of the inner optimizer."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "N69sYlo_oUTZ"
      },
      "outputs": [],
      "source": [
        "state = opt.init(theta)\n",
        "# Inverse sigmoid (logit), so that sigmoid(eta) equals the initial learning rate.\n",
        "eta = -np.log(1. / init_learning_rate - 1)\n",
        "meta_state = meta_opt.init(eta)\n",
        "\n",
        "# Samples per outer step: N - 1 inner updates plus 1 held-out evaluation sample.\n",
        "N = 7\n",
        "learning_rates = []\n",
        "thetas = []\n",
        "\n",
        "for _ in range(2000):\n",
        "  samples = [next(g) for _ in range(N)]\n",
        "  eta, theta, meta_state, state = outer_step(eta, theta, meta_state, state, samples)\n",
        "  learning_rates.append(jax.nn.sigmoid(eta))\n",
        "  thetas.append(theta)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lp0e2sS8pIqL"
      },
      "source": [
        "We can now plot the learning rate and the value of $\\theta$ recorded after each outer step:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Gp45HrNyJuW3"
      },
      "outputs": [],
      "source": [
        "fig, (ax1, ax2) = plt.subplots(2);\n",
        "fig.suptitle('Meta-learning RMSProp\\'s learning rate');\n",
        "\n",
        "ax1.semilogy(range(len(learning_rates)), learning_rates);\n",
        "ax1.set(ylabel='Learning rate');\n",
        "ax1.label_outer();\n",
        "\n",
        "# One point per outer step, so the shared x-axis counts outer updates.\n",
        "ax2.semilogy(range(len(thetas)), thetas);\n",
        "ax2.set(xlabel='Outer step', ylabel='Theta');\n",
        "ax2.label_outer();"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tNvaCoeYpQRL"
      },
      "source": [
        "The learning-rate profile over time looks reasonable: it decays towards zero, which helps convergence."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "Meta-Learning",
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1UF4UGTpZvC0AWbko-4IFzsG43RH4b6WC",
          "timestamp": 1637243991785
        },
        {
          "file_id": "1ZRN4T0N8o-OvX4jIN52Beo3kzpNDM-Mx",
          "timestamp": 1637168142510
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}