
Optimizing Neural Networks with Kronecker-factored Approximate Curvature

James Martens* and Roger Grosse†

Department of Computer Science, University of Toronto

Abstract

We propose an efficient method for approximating natural gradient descent in neural networks which we call Kronecker-factored Approximate Curvature (K-FAC). K-FAC is based on an efficiently invertible approximation of a neural network’s Fisher information matrix which is neither diagonal nor low-rank, and in some cases is completely non-sparse. It is derived by approximating various large blocks of the Fisher (corresponding to entire layers) as being the Kronecker product of two much smaller matrices. While only several times more expensive to compute than the plain stochastic gradient, the updates produced by K-FAC make much more progress optimizing the objective, which results in an algorithm that can be much faster than stochastic gradient descent with momentum in practice. And unlike some previously proposed approximate natural-gradient/Newton methods which use high-quality non-diagonal curvature matrices (such as Hessian-free optimization), K-FAC works very well in highly stochastic optimization regimes. This is because the cost of storing and inverting K-FAC’s approximation to the curvature matrix does not depend on the amount of data used to estimate it, which is a feature typically associated only with diagonal or low-rank approximations to the curvature matrix.

1 Introduction

The problem of training neural networks is one of the most important and highly investigated ones in machine learning. Despite work on layer-wise pretraining schemes, and various sophisticated optimization methods which try to approximate Newton-Raphson updates or natural gradient updates, stochastic gradient descent (SGD), possibly augmented with momentum, remains the method of choice for large-scale neural network training (Sutskever et al., 2013).

From the work on Hessian-free optimization (HF) (Martens, 2010) and related methods (e.g. Vinyals and Povey, 2012), we know that updates computed using local curvature information can make much more progress per iteration than the scaled gradient. There are two reasons that HF sees fewer practical applications than SGD. Firstly, its updates are much more expensive to compute. They involve running linear conjugate gradient (CG) for potentially hundreds of iterations, each of which requires a matrix-vector product with the curvature matrix, and each such product is as expensive to compute as the stochastic gradient on the current mini-batch. Secondly, HF's estimate of the curvature matrix must remain fixed while CG iterates. The method is therefore able to go through much less data than SGD can in a comparable amount of time, which makes it less well suited to stochastic optimization.

*jmartens@cs.toronto.edu

†rgrosse@cs.toronto.edu

As discussed in Martens and Sutskever (2012) and Sutskever et al. (2013), CG has the potential to be much faster at local optimization than gradient descent, when applied to quadratic objective functions. Thus, insofar as the objective can be locally approximated by a quadratic, each step of CG could potentially be doing a lot more work than each iteration of SGD, which would result in HF being much faster overall than SGD. However, there are examples of quadratic functions (e.g. Li, 2005), characterized by curvature matrices with highly spread-out eigenvalue distributions, where CG will have no distinct advantage over well-tuned gradient descent with momentum. Thus, insofar as the quadratic functions being optimized by CG within HF are of this character, HF shouldn't in principle be faster than well-tuned SGD with momentum. The extent to which neural network objective functions give rise to such quadratics is unclear, although Sutskever et al. (2013) provides some preliminary evidence that they do.

CG falls victim to this worst-case analysis because it is a first-order method. This motivates us to consider methods which don't rely on first-order methods like CG as their primary engines of optimization. One such class of methods which have been widely studied are those which work by directly inverting a diagonal, block-diagonal, or low-rank approximation to the curvature matrix (e.g. Becker and LeCun, 1989; Schaul et al., 2013; Zeiler, 2013; Le Roux et al., 2008; Ollivier, 2013). In fact, a diagonal approximation of the Fisher information matrix is used within HF as a preconditioner for CG. However, these methods provide only a limited performance improvement in practice, especially compared to SGD with momentum (see for example Schraudolph et al., 2007; Zeiler, 2013), and many practitioners tend to forgo them in favor of SGD or SGD with momentum.

We know that the curvature associated with neural network objective functions is highly non-diagonal, and that updates which properly respect and account for this non-diagonal curvature, such as those generated by HF, can make much more progress minimizing the objective than the plain gradient or updates computed from diagonal approximations of the curvature (usually $\sim 10^2$ HF updates are required to adequately minimize most objectives, compared to the $\sim 10^4 - 10^5$ required by methods that use diagonal approximations). Thus, if we had an efficient and direct way to compute the inverse of a high-quality non-diagonal approximation to the curvature matrix (i.e. without relying on first-order methods like CG) this could potentially yield an optimization method whose updates would be large and powerful like HF's, while being (almost) as cheap to compute as the stochastic gradient. In this work we develop such a method, which we call Kronecker-factored Approximate Curvature (K-FAC). We show that our method can be much faster in practice than even highly tuned implementations of SGD with momentum on certain standard neural network optimization benchmarks.

The main ingredient in K-FAC is a sophisticated approximation to the Fisher information matrix, which despite being neither diagonal nor low-rank, nor even block-diagonal with small blocks, can be inverted very efficiently, and can be estimated in an online fashion using arbitrarily large subsets of the training data (without increasing the cost of inversion).

This approximation is built in two stages. In the first, the rows and columns of the Fisher are divided into groups, each of which corresponds to all the weights in a given layer, and this gives rise to a block-partitioning of the matrix (where the blocks are much larger than those used by Le Roux et al. (2008) or Ollivier (2013)). These blocks are then approximated as Kronecker products between much smaller matrices, which we show is equivalent to making certain approximating assumptions regarding the statistics of the network’s gradients.

In the second stage, this matrix is further approximated as having an inverse which is either block-diagonal or block-tridiagonal. We justify this approximation through a careful examination of the relationships between inverse covariances, tree-structured graphical models, and linear regression. Notably, this justification doesn’t apply to the Fisher itself, and our experiments confirm that while the inverse Fisher does indeed possess this structure (approximately), the Fisher itself does not.

The rest of this paper is organized as follows. Section 2 gives basic background and notation for neural networks and the natural gradient. Section 3 describes our initial Kronecker product approximation to the Fisher. Section 4 describes our further block-diagonal and block-tridiagonal approximations of the inverse Fisher, and how these can be used to derive an efficient inversion algorithm. Section 5 describes how we compute online estimates of the quantities required by our inverse Fisher approximation over a large “window” of previously processed mini-batches. This makes K-FAC very different from methods like HF or Krylov subspace descent (KSD) (Vinyals and Povey, 2012), which base their estimates of the curvature on a single mini-batch. Section 6 describes how we use our approximate Fisher to obtain a practical and robust optimization algorithm that requires very little manual tuning. It achieves this through the careful application of various theoretically well-founded “damping” techniques that are standard in the optimization literature. Note that damping techniques compensate both for the local quadratic approximation implicitly made to the objective and for our further approximation of the Fisher. They are non-optional for essentially any 2nd-order method like K-FAC to work properly, as is well established by both theory and practice within the optimization literature (Nocedal and Wright, 2006). Section 7 describes a simple and effective way of adding a type of “momentum” to K-FAC, which we have found works very well in practice. Section 8 describes the computational costs associated with K-FAC, and various ways to reduce them to the point where each update is at most only several times more expensive to compute than the stochastic gradient. Section 9 gives complete high-level pseudocode for K-FAC. Section 10 characterizes a broad class of network transformations and reparameterizations to which K-FAC is essentially invariant. Section 11 considers some related prior methods for neural network optimization.
Proofs of formal results are located in the appendix.

2 Background and notation

2.1 Neural Networks

In this section we will define the basic notation for feed-forward neural networks which we will use throughout this paper. Note that this presentation closely follows the one from Martens (2014).

A neural network transforms its input $a_0 = x$ to an output $f(x, \theta) = a_\ell$ through a series of $\ell$ layers, each of which consists of a bank of units/neurons. The units each receive as input a weighted sum of the outputs of units from the previous layer and compute their output via a nonlinear “activation” function. We denote by $s_i$ the vector of these weighted sums for the $i$-th layer, and by $a_i$ the vector of unit outputs (aka “activities”). The precise computation performed at each layer $i \in \{1, \dots, \ell\}$ is given as follows:

$$\begin{aligned} s_i &= W_i \bar{a}_{i-1} \\ a_i &= \phi_i(s_i) \end{aligned}$$

where $\phi_i$ is an element-wise nonlinear function, $W_i$ is a weight matrix, and $\bar{a}_i$ is defined as the vector formed by appending to $a_i$ an additional homogeneous coordinate with value 1. Note that we do not include explicit bias parameters here as these are captured implicitly through our use of homogeneous coordinates. In particular, the last column of each weight matrix $W_i$ corresponds to what is usually thought of as the “bias vector”. Figure 1 illustrates our definition for $\ell = 2$ .

We will define $\theta = [\text{vec}(W_1)^\top \text{vec}(W_2)^\top \dots \text{vec}(W_\ell)^\top]^\top$ , which is the vector consisting of all of the network’s parameters concatenated together, where $\text{vec}$ is the operator which vectorizes matrices by stacking their columns together.

We let $L(y, z)$ denote the loss function which measures the disagreement between a prediction $z$ made by the network, and a target $y$. The training objective function $h(\theta)$ is the average (or expectation) of losses $L(y, f(x, \theta))$ with respect to a training distribution $\hat{Q}_{x,y}$ over input-target pairs $(x, y)$. $h(\theta)$ is a proxy for the objective which we actually care about but don’t have access to, which is the expectation of the loss taken with respect to the true data distribution $Q_{x,y}$.

We will assume that the loss is given by the negative log probability associated with a simple predictive distribution $R_{y|z}$ for $y$ parameterized by $z$ , i.e. that we have

$$L(y, z) = -\log r(y|z)$$

where $r$ is $R_{y|z}$’s density function. This is the case for both the standard least-squares and cross-entropy objective functions, where the predictive distributions are multivariate normal and multinomial, respectively. We will let $P_{y|x}(\theta) = R_{y|f(x,\theta)}$ denote the conditional distribution defined by the neural network, as parameterized by $\theta$, and $p(y|x, \theta) = r(y|f(x, \theta))$ its density function. Note that minimizing the objective function $h(\theta)$ can be seen as maximum likelihood learning of the model $P_{y|x}(\theta)$.
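The following NumPy sketch checks this correspondence numerically. The helper names are our own. It assumes an identity covariance for the normal case, so least squares and the Gaussian negative log density agree up to the additive constant $\tfrac{k}{2}\log 2\pi$. For the multinomial case, $z$ is taken to be the vector of logits (natural parameters).

```python
import numpy as np

def squared_error(y, z):
    """Least-squares loss 0.5 * ||y - z||^2."""
    return 0.5 * np.sum((y - z) ** 2)

def gaussian_nll(y, z):
    """-log r(y|z) for a multivariate normal with mean z and (assumed) identity covariance."""
    k = y.size
    return 0.5 * np.sum((y - z) ** 2) + 0.5 * k * np.log(2 * np.pi)

def cross_entropy(y_onehot, z):
    """Softmax cross-entropy with logits z, written with a stable log-sum-exp."""
    m = np.max(z)
    return m + np.log(np.sum(np.exp(z - m))) - np.dot(y_onehot, z)

def multinomial_nll(y_index, z):
    """-log r(y|z) for a categorical distribution whose natural parameters are the logits z."""
    p = np.exp(z - np.max(z))
    p /= p.sum()
    return -np.log(p[y_index])
```

The least-squares and Gaussian losses differ only by a constant, so they have the same derivatives and hence define the same $\mathcal{D}v$.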

For convenience we will define the following additional notation:

$$\mathcal{D}v = \frac{dL(y, f(x, \theta))}{dv} = -\frac{d \log p(y|x, \theta)}{dv} \quad \text{and} \quad g_i = \mathcal{D}s_i$$

Algorithm 1 shows how to compute the gradient $\mathcal{D}\theta$ of the loss function of a neural network using standard backpropagation.


Algorithm 1 An algorithm for computing the gradient of the loss $L(y, f(x, \theta))$ for a given $(x, y)$ . Note that we are assuming here for simplicity that the $\phi_i$ are defined as coordinate-wise functions.


input: $a_0 = x$; $\theta$ mapped to $(W_1, W_2, \dots, W_\ell)$.

/* Forward pass */

for all $i$ from 1 to $\ell$ do

$s_i \leftarrow W_i \bar{a}_{i-1}$

$a_i \leftarrow \phi_i(s_i)$

end for

/* Loss derivative computation */

$\mathcal{D}a_\ell \leftarrow \left. \frac{\partial L(y, z)}{\partial z} \right|_{z=a_\ell}$

/* Backwards pass */

for all $i$ from $\ell$ down to 1 do

$g_i \leftarrow \mathcal{D}a_i \odot \phi'_i(s_i)$

$\mathcal{D}W_i \leftarrow g_i \bar{a}_{i-1}^\top$

$\mathcal{D}a_{i-1} \leftarrow W_i^\top g_i$ (discarding the last entry, which corresponds to the homogeneous coordinate of $\bar{a}_{i-1}$)

end for

output: $\mathcal{D}\theta = [\text{vec}(\mathcal{D}W_1)^\top\ \text{vec}(\mathcal{D}W_2)^\top \dots \text{vec}(\mathcal{D}W_\ell)^\top]^\top$
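Algorithm 1 can be rendered in NumPy roughly as follows. This is a minimal sketch under these assumptions:

- it processes a single example;
- each element-wise $\phi_i$ is supplied together with its derivative;
- the loss is supplied only through its gradient $\partial L / \partial z$.

All function names are hypothetical. Biases live in the last column of each $W_i$, so $W_i^\top g_i$ has one more entry than $a_{i-1}$. The sketch therefore drops the entry that belongs to the homogeneous coordinate.

```python
import numpy as np

def with_bias(a):
    """a -> a_bar: append the homogeneous coordinate 1."""
    return np.append(a, 1.0)

def vec(M):
    """Column-stacking vec operator."""
    return M.reshape(-1, order="F")

def forward(x, Ws, phis):
    """Forward pass: returns lists s (s[0] unused) and a (a[0] = x)."""
    s, a = [None], [x]
    for W, phi in zip(Ws, phis):
        s.append(W @ with_bias(a[-1]))     # s_i = W_i a_bar_{i-1}
        a.append(phi(s[-1]))               # a_i = phi_i(s_i)
    return s, a

def backprop(x, y, Ws, phis, dphis, dloss):
    """Returns [DW_1, ..., DW_ell] for one (x, y); dloss(y, z) must return dL/dz."""
    s, a = forward(x, Ws, phis)
    Da = dloss(y, a[-1])                   # D a_ell
    grads = [None] * len(Ws)
    for i in range(len(Ws), 0, -1):        # i = ell down to 1
        g = Da * dphis[i - 1](s[i])        # g_i = D a_i ⊙ phi_i'(s_i)
        grads[i - 1] = np.outer(g, with_bias(a[i - 1]))   # D W_i = g_i a_bar_{i-1}^T
        Da = (Ws[i - 1].T @ g)[:-1]        # D a_{i-1}, minus the homogeneous-coordinate entry
    return grads
```

$\mathcal{D}\theta$ is then `np.concatenate([vec(G) for G in grads])`. A finite-difference comparison on a small two-layer tanh network is a cheap way to confirm the gradients.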

[Figure 1 diagram: an input layer $a_0 = x$ (7 units) connected by $W_1$ to a hidden layer with inputs $s_1$ and outputs $a_1$ (7 units), which is connected by $W_2$ to an output layer with inputs $s_2$ and outputs $a_2$ (3 units).]

Figure 1: A depiction of a standard feed-forward neural network for $\ell = 2$ .

2.2 The Natural Gradient

Because our network defines a conditional model $P_{y|x}(\theta)$ , it has an associated Fisher information matrix (which we will simply call “the Fisher”) which is given by

$$F = \mathbb{E}\left[ \frac{d \log p(y|x, \theta)}{d\theta} \frac{d \log p(y|x, \theta)}{d\theta}^\top \right] = \mathbb{E}[\mathcal{D}\theta\, \mathcal{D}\theta^\top]$$

Here, the expectation is taken with respect to the data distribution $Q_x$ over inputs $x$ , and the model’s predictive distribution $P_{y|x}(\theta)$ over $y$ . Since we usually don’t have access to $Q_x$ , and the above expectation would likely be intractable even if we did, we will instead compute $F$ using the training distribution $\hat{Q}_x$ over inputs $x$ .
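The sketch below makes the two expectations explicit for the simplest case we could choose: a one-layer softmax model ($\ell = 1$, $s_1 = W_1 \bar{a}_0$, cross-entropy loss). For this model $g_1 = p - e_y$, where $p$ is the vector of predicted class probabilities and $e_y$ is the one-hot encoding of $y$.

- Inputs are averaged over a fixed training set.
- $y$ is either summed over exactly or sampled from $P_{y|x}(\theta)$, and is never taken from the training targets.

The function names are hypothetical.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def grad_theta(W, x, y):
    """D theta = vec(D W) = a_bar ⊗ g for the model s = W a_bar with cross-entropy loss."""
    a_bar = np.append(x, 1.0)
    g = softmax(W @ a_bar)
    g[y] -= 1.0                                   # g = dL/ds = p - e_y
    return np.kron(a_bar, g)

def fisher_exact(W, X):
    """E_x E_{y ~ P(y|x, theta)}[D theta D theta^T], with x ranging over the rows of X."""
    n = W.size
    F = np.zeros((n, n))
    for x in X:
        p = softmax(W @ np.append(x, 1.0))
        for y in range(len(p)):                   # exact sum over the model's predictive distribution
            d = grad_theta(W, x, y)
            F += p[y] * np.outer(d, d)
    return F / len(X)

def fisher_sampled(W, X, n_samples, rng):
    """Monte Carlo version: y is sampled from the model, not taken from training targets."""
    k, n = W.shape[0], W.size
    F = np.zeros((n, n))
    for x in X:
        a_bar = np.append(x, 1.0)
        p = softmax(W @ a_bar)
        ys = rng.choice(k, size=n_samples, p=p)
        G = p[None, :] - np.eye(k)[ys]            # row r is g for the r-th sampled y
        D = np.kron(a_bar[None, :], G)            # row r is a_bar ⊗ g_r
        F += D.T @ D
    return F / (len(X) * n_samples)
```

For every $x$, summing $p_y\,\mathcal{D}\theta$ over $y$ gives zero. This is why $F$ can also be read as the covariance matrix of $\mathcal{D}\theta$.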

The well-known natural gradient (Amari, 1998) is defined as $F^{-1}\nabla h(\theta)$ . Motivated from the perspective of information geometry (Amari and Nagaoka, 2000), the natural gradient defines the direction in parameter space which gives the largest change in the objective per unit of change in the model, as measured by the KL-divergence. This is to be contrasted with the standard gradient, which can be defined as the direction in parameter space which gives the largest change in the objective per unit of change in the parameters, as measured by the standard Euclidean metric.
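The KL-divergence interpretation rests on a second-order expansion:

$$\mathrm{KL}(P_{y|x}(\theta) \,\|\, P_{y|x}(\theta + \epsilon d)) \approx \tfrac{1}{2}\epsilon^2 d^\top F d,$$

where $F$ is the Fisher for that input. $F$ is therefore the local metric under which the natural gradient is the steepest-descent direction.

The sketch below checks this numerically for a single input of a one-layer softmax model with $\theta = \text{vec}(W)$. Both the model and the helper names `probs`, `fisher_at` and `kl` are our own illustrative choices.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def probs(theta, a_bar, k):
    """Predictive distribution for one input; theta = vec(W) with W of shape (k, len(a_bar))."""
    W = theta.reshape(k, -1, order="F")        # undo the column-stacking vec
    return softmax(W @ a_bar)

def fisher_at(theta, a_bar, k):
    """Exact Fisher for this single input: sum over y of p_y (a_bar ⊗ g_y)(a_bar ⊗ g_y)^T."""
    p = probs(theta, a_bar, k)
    F = np.zeros((theta.size, theta.size))
    for y in range(k):
        d = np.kron(a_bar, p - np.eye(k)[y])   # D theta when the label is y
        F += p[y] * np.outer(d, d)
    return F

def kl(p, q):
    """KL(p || q) for discrete distributions with full support."""
    return float(np.sum(p * np.log(p / q)))
```

For a small $\epsilon$ (say $10^{-3}$), `kl(probs(theta, ...), probs(theta + eps * d, ...))` should match `0.5 * eps**2 * d @ F @ d` to within a relative error of order $\epsilon$.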

The natural gradient also has links to several classical ideas from optimization. It can be shown (Martens, 2014; Pascanu and Bengio, 2014) that in certain important cases the Fisher is equivalent to the Generalized Gauss-Newton matrix (GGN) (Schraudolph, 2002; Martens and Sutskever, 2012). The GGN is a well-known positive semi-definite approximation to the Hessian of the objective function. In particular, Martens (2014) considered the case where the GGN is defined so that the network is linearized up to the loss function, and the loss function corresponds to the negative log probability of observations under an exponential family model $R_{y|z}$ with $z$ representing the natural parameters. He showed that in this case the Fisher corresponds exactly to the GGN.¹
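For softmax outputs with the logits as natural parameters, the equivalence comes down to one fact. The expectation $\mathbb{E}_{y \sim P_{y|x}}[g g^\top]$ equals the Hessian of the loss with respect to the logits, $\operatorname{diag}(p) - pp^\top$, which does not depend on $y$.

The sketch below checks this for one input of a one-layer softmax model, an assumption made to keep it short. For deeper networks, $J$ would be the Jacobian of the network output with respect to $\theta$. The loss Hessian is computed by finite differences so that the comparison does not assume the closed form. The helper names are hypothetical.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def ce_grad(z, y):
    """dL/dz for L(y, z) = -log softmax(z)[y]."""
    g = softmax(z)
    g[y] -= 1.0
    return g

def ce_hessian_fd(z, y, h=1e-5):
    """d^2 L / dz^2 by central differences of ce_grad (avoids assuming the closed form)."""
    H = np.zeros((z.size, z.size))
    for j in range(z.size):
        e = np.zeros(z.size)
        e[j] = h
        H[:, j] = (ce_grad(z + e, y) - ce_grad(z - e, y)) / (2 * h)
    return H

def fisher_and_ggn(W, x):
    """Fisher and GGN w.r.t. theta = vec(W) for one input of the model s = W a_bar."""
    k = W.shape[0]
    a_bar = np.append(x, 1.0)
    s = W @ a_bar
    J = np.kron(a_bar[None, :], np.eye(k))       # ds/dtheta: s = (a_bar^T ⊗ I) vec(W)
    p = softmax(s)
    F = np.zeros((W.size, W.size))
    for y in range(k):                            # expectation over y ~ P(y|x, theta)
        d = J.T @ ce_grad(s, y)                   # D theta = J^T g
        F += p[y] * np.outer(d, d)
    ggn = J.T @ ce_hessian_fd(s, 0) @ J           # J^T H_L J; H_L is the same for every y
    return F, ggn
```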

The GGN has served as the curvature matrix of choice in HF and related methods, and so in light of its equivalence to the Fisher, these 2nd-order methods can be seen as approximate natural gradient methods. Perhaps more importantly from a practical perspective, natural gradient-based optimization methods can conversely be viewed as 2nd-order optimization methods. As pointed out by Martens (2014), this brings to bear the vast wisdom that has accumulated about how to make such methods work well in both theory and practice (e.g. Nocedal and Wright, 2006). In Section 6 we productively make use of these connections in order to design a robust and highly effective optimization method using our approximation to the natural gradient/Fisher (which is developed in Sections 3 and 4).

For some good recent discussion and analysis of the natural gradient, see Arnold et al. (2011); Martens (2014); Pascanu and Bengio (2014).

3 A block-wise Kronecker-factored Fisher approximation

The main computational challenge associated with using the natural gradient is computing $F^{-1}$ (or its product with $\nabla h$ ). For large networks, with potentially millions of parameters, computing this inverse naively is computationally impractical. In this section we develop an initial approximation of $F$ which will be a key ingredient in deriving our efficiently computable approximation to $F^{-1}$ and the natural gradient.

Note that $\mathcal{D}\theta = [\text{vec}(\mathcal{D}W_1)^\top \text{vec}(\mathcal{D}W_2)^\top \cdots \text{vec}(\mathcal{D}W_\ell)^\top]^\top$ and so $F$ can be expressed as

$$\begin{aligned} F &= \mathbb{E}[\mathcal{D}\theta\, \mathcal{D}\theta^\top] \\
&= \mathbb{E}\left[ [\text{vec}(\mathcal{D}W_1)^\top\ \text{vec}(\mathcal{D}W_2)^\top \cdots \text{vec}(\mathcal{D}W_\ell)^\top]^\top\, [\text{vec}(\mathcal{D}W_1)^\top\ \text{vec}(\mathcal{D}W_2)^\top \cdots \text{vec}(\mathcal{D}W_\ell)^\top] \right] \\
&= \begin{bmatrix} \mathbb{E}[\text{vec}(\mathcal{D}W_1)\, \text{vec}(\mathcal{D}W_1)^\top] & \mathbb{E}[\text{vec}(\mathcal{D}W_1)\, \text{vec}(\mathcal{D}W_2)^\top] & \cdots & \mathbb{E}[\text{vec}(\mathcal{D}W_1)\, \text{vec}(\mathcal{D}W_\ell)^\top] \\
\mathbb{E}[\text{vec}(\mathcal{D}W_2)\, \text{vec}(\mathcal{D}W_1)^\top] & \mathbb{E}[\text{vec}(\mathcal{D}W_2)\, \text{vec}(\mathcal{D}W_2)^\top] & \cdots & \mathbb{E}[\text{vec}(\mathcal{D}W_2)\, \text{vec}(\mathcal{D}W_\ell)^\top] \\
\vdots & \vdots & \ddots & \vdots \\
\mathbb{E}[\text{vec}(\mathcal{D}W_\ell)\, \text{vec}(\mathcal{D}W_1)^\top] & \mathbb{E}[\text{vec}(\mathcal{D}W_\ell)\, \text{vec}(\mathcal{D}W_2)^\top] & \cdots & \mathbb{E}[\text{vec}(\mathcal{D}W_\ell)\, \text{vec}(\mathcal{D}W_\ell)^\top] \end{bmatrix} \end{aligned}$$

Thus, we see that $F$ can be viewed as an $\ell$ by $\ell$ block matrix, with the $(i, j)$ -th block $F_{i,j}$ given by $F_{i,j} = \mathbb{E} [\text{vec}(\mathcal{D}W_i) \text{vec}(\mathcal{D}W_j)^\top]$ .


¹ Note that the condition that $z$ represents the natural parameters might require one to formally include the nonlinear transformation usually performed by the final nonlinearity $\phi_\ell$ of the network (such as the logistic-sigmoid transform before a cross-entropy error) as part of the loss function $L$ instead. Equivalently, one could linearize the network only up to the input $s_\ell$ to $\phi_\ell$ when computing the GGN (see Martens and Sutskever (2012)).

Noting that $\mathcal{D}W_i = g_i \bar{a}_{i-1}^\top$ and that $\text{vec}(uv^\top) = v \otimes u$, we have $\text{vec}(\mathcal{D}W_i) = \text{vec}(g_i \bar{a}_{i-1}^\top) = \bar{a}_{i-1} \otimes g_i$, and thus we can rewrite $F_{i,j}$ as

$$\begin{aligned} F_{i,j} &= \mathbb{E}[\text{vec}(\mathcal{D}W_i)\, \text{vec}(\mathcal{D}W_j)^\top] = \mathbb{E}[(\bar{a}_{i-1} \otimes g_i)(\bar{a}_{j-1} \otimes g_j)^\top] = \mathbb{E}[(\bar{a}_{i-1} \otimes g_i)(\bar{a}_{j-1}^\top \otimes g_j^\top)] \\
&= \mathbb{E}[\bar{a}_{i-1} \bar{a}_{j-1}^\top \otimes g_i g_j^\top] \end{aligned}$$
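Both identities used in this derivation can be checked on a single example. The sketch below uses hypothetical sizes, with `vec` implemented as column-stacking. It confirms that $\text{vec}(g\bar{a}^\top) = \bar{a} \otimes g$, and that the per-example outer product $\text{vec}(\mathcal{D}W_i)\text{vec}(\mathcal{D}W_j)^\top$ equals $\bar{a}_{i-1}\bar{a}_{j-1}^\top \otimes g_i g_j^\top$. Because the equality holds for every example, it also holds after taking expectations.

```python
import numpy as np

def vec(M):
    """Column-stacking vec operator."""
    return M.reshape(-1, order="F")

rng = np.random.default_rng(0)
a_i, g_i = rng.standard_normal(4), rng.standard_normal(3)   # a_bar_{i-1}, g_i (one example)
a_j, g_j = rng.standard_normal(5), rng.standard_normal(2)   # a_bar_{j-1}, g_j

DW_i, DW_j = np.outer(g_i, a_i), np.outer(g_j, a_j)         # D W = g a_bar^T
lhs = np.outer(vec(DW_i), vec(DW_j))                        # vec(D W_i) vec(D W_j)^T
rhs = np.kron(np.outer(a_i, a_j), np.outer(g_i, g_j))       # a_bar a_bar^T ⊗ g g^T
```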

where $A \otimes B$ denotes the Kronecker product between $A \in \mathbb{R}^{m \times n}$ and $B$ , and is given by

$$A \otimes B \equiv \begin{bmatrix} [A]_{1,1}B & \cdots & [A]_{1,n}B \\ \vdots & \ddots & \vdots \\ [A]_{m,1}B & \cdots & [A]_{m,n}B \end{bmatrix}$$

Note that the Kronecker product satisfies many convenient properties that we will make use of in this paper, especially the identity $(A \otimes B)^{-1} = A^{-1} \otimes B^{-1}$ . See Van Loan (2000) for a good discussion of the Kronecker product.
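The identities that make Kronecker-factored matrices cheap to work with can be checked numerically. Besides the inverse identity, the sketch below verifies that $(A \otimes B)\,\text{vec}(X) = \text{vec}(B X A^\top)$. This lets one apply $A \otimes B$, or its inverse, without ever forming it. The matrices are random and the sizes are hypothetical.

```python
import numpy as np

def vec(M):
    """Column-stacking vec operator."""
    return M.reshape(-1, order="F")

rng = np.random.default_rng(0)

def random_spd(n):
    """Random symmetric positive definite (hence invertible) matrix."""
    M = rng.standard_normal((n, n))
    return M @ M.T + n * np.eye(n)

A, B = random_spd(4), random_spd(3)
X = rng.standard_normal((3, 4))              # sized so that vec(X) matches A ⊗ B (12 x 12)

inv_kron = np.linalg.inv(np.kron(A, B))      # (A ⊗ B)^{-1}, formed explicitly
kron_inv = np.kron(np.linalg.inv(A), np.linalg.inv(B))

# Applying the 12 x 12 matrix needs only 3 x 3 and 4 x 4 products:
lhs = np.kron(A, B) @ vec(X)
rhs = vec(B @ X @ A.T)
```

Combining the two identities gives $(A \otimes B)^{-1}\text{vec}(X) = \text{vec}(B^{-1} X A^{-\top})$. This is the form that makes Kronecker-factored blocks cheap to invert and apply.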

Our initial approximation $\tilde{F}$ to $F$ will be defined by the following block-wise approximation:

$$F_{i,j} = \mathbb{E}[\bar{a}_{i-1} \bar{a}_{j-1}^\top \otimes g_i g_j^\top] \approx \mathbb{E}[\bar{a}_{i-1} \bar{a}_{j-1}^\top] \otimes \mathbb{E}[g_i g_j^\top] = \bar{A}_{i-1,j-1} \otimes G_{i,j} = \tilde{F}_{i,j} \tag{1}$$

where $\bar{A}_{i,j} = \mathbb{E}[\bar{a}_i \bar{a}_j^\top]$ and $G_{i,j} = \mathbb{E}[g_i g_j^\top]$.
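With samples, the factors are estimated by averaging outer products over examples. The sketch below (function and argument names are hypothetical) forms $\bar{A}_{i-1,j-1}$ and $G_{i,j}$ from $n$ paired samples. It also forms the sample estimate of the exact block $F_{i,j}$, so the two can be compared.

The approximation is exact when $\bar{a}_{i-1}$ is constant across examples. It also becomes accurate for large $n$ when $\bar{a}$ and $g$ are drawn independently, which matches the independence interpretation discussed in Section 3.1.

```python
import numpy as np

def kfac_block(abar_i, g_i, abar_j, g_j):
    """Sample estimates of F_{i,j} and of its approximation A_bar_{i-1,j-1} ⊗ G_{i,j}.

    abar_i: (n, d) array whose rows are a_bar_{i-1}; g_i: (n, k) rows are g_i;
    abar_j, g_j: the same for layer j. Row r of every array comes from example r.
    """
    n = abar_i.shape[0]
    A = abar_i.T @ abar_j / n                 # A_bar_{i-1,j-1} = E[a_bar_{i-1} a_bar_{j-1}^T]
    G = g_i.T @ g_j / n                       # G_{i,j} = E[g_i g_j^T]
    # Row r of D_* is vec(DW) = a_bar ⊗ g for example r.
    D_i = np.einsum("na,ng->nag", abar_i, g_i).reshape(n, -1)
    D_j = np.einsum("na,ng->nag", abar_j, g_j).reshape(n, -1)
    F_exact = D_i.T @ D_j / n                 # E[vec(DW_i) vec(DW_j)^T]
    return F_exact, np.kron(A, G)
```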

This gives

$$\tilde{F} = \begin{bmatrix} \bar{A}_{0,0} \otimes G_{1,1} & \bar{A}_{0,1} \otimes G_{1,2} & \cdots & \bar{A}_{0,\ell-1} \otimes G_{1,\ell} \\
\bar{A}_{1,0} \otimes G_{2,1} & \bar{A}_{1,1} \otimes G_{2,2} & \cdots & \bar{A}_{1,\ell-1} \otimes G_{2,\ell} \\
\vdots & \vdots & \ddots & \vdots \\
\bar{A}_{\ell-1,0} \otimes G_{\ell,1} & \bar{A}_{\ell-1,1} \otimes G_{\ell,2} & \cdots & \bar{A}_{\ell-1,\ell-1} \otimes G_{\ell,\ell} \end{bmatrix}$$

which has the form of what is known as a Khatri-Rao product in multivariate statistics.

The expectation of a Kronecker product is, in general, not equal to the Kronecker product of expectations, and so this is indeed a major approximation to make, and one which likely won't become exact under any realistic set of assumptions, or as a limiting case in some kind of asymptotic analysis. Nevertheless, it seems to be fairly accurate in practice, and is able to successfully capture the “coarse structure” of the Fisher, as demonstrated in Figure 2 for an example network.

As we will see in later sections, this approximation leads to significant computational savings in terms of storage and inversion, which we will be able to leverage in order to design an efficient algorithm for computing an approximation to the natural gradient.
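To put rough numbers on these savings, the sketch below compares one diagonal block $F_{i,i}$ with its two factors. It counts stored entries and uses a naive $n^3$ cost for dense inversion. The layer sizes are hypothetical and the cubic cost model is a simplification.

```python
def block_costs(d_in, d_out):
    """Entries stored and a naive n^3 inversion count for one diagonal block F_{i,i}
    versus its two Kronecker factors (d_in + 1 accounts for the bias coordinate)."""
    n_a, n_g = d_in + 1, d_out
    n = n_a * n_g                       # number of weights in the layer, including biases
    return {
        "full_entries": n * n,
        "factor_entries": n_a * n_a + n_g * n_g,
        "full_inverse_flops": n ** 3,
        "factor_inverse_flops": n_a ** 3 + n_g ** 3,
    }
```

Consider a layer with 1000 inputs and 1000 outputs. `block_costs(1000, 1000)` gives 1,002,001,000,000 stored entries for the full block against 2,002,001 for the factors. It also gives roughly $1.0 \times 10^{18}$ versus $2.0 \times 10^{9}$ operations for inversion.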

3.1 Interpretations of this approximation

Consider an arbitrary pair of weights $[W_i]_{k_1,k_2}$ and $[W_j]_{k_3,k_4}$ from the network, where $[\cdot]_{i,j}$ denotes the value of the $(i, j)$-th entry. Since $\mathcal{D}W_i = g_i \bar{a}_{i-1}^\top$, the corresponding derivatives of these weights are given by $\mathcal{D}[W_i]_{k_1,k_2} = \bar{a}^{(1)} g^{(1)}$ and $\mathcal{D}[W_j]_{k_3,k_4} = \bar{a}^{(2)} g^{(2)}$. Here we denote for convenience $\bar{a}^{(1)} = [\bar{a}_{i-1}]_{k_2}$, $\bar{a}^{(2)} = [\bar{a}_{j-1}]_{k_4}$, $g^{(1)} = [g_i]_{k_1}$, and $g^{(2)} = [g_j]_{k_3}$.

Figure 2: A comparison of the exact Fisher $F$ and our block-wise Kronecker-factored approximation $\tilde{F}$, for the middle 4 layers of a standard deep neural network partially trained to classify a 16x16 down-scaled version of MNIST. The network was trained with 7 iterations of K-FAC in batch mode, achieving 5% error (the error reached 0% after 22 iterations). The network architecture is 256-20-20-20-20-20-10 and uses standard tanh units. On the left is the exact Fisher $F$, in the middle is our approximation $\tilde{F}$, and on the right is the difference of these. The dashed lines delineate the blocks. Note that for the purposes of visibility we plot the absolute values of the entries, with the white level corresponding linearly to the size of these values (up to some maximum, which is the same in each image).

The approximation given by eqn. 1 is equivalent to making the following approximation for each pair of weights:

$$\mathbb{E}[\mathcal{D}[W_i]_{k_1,k_2}\, \mathcal{D}[W_j]_{k_3,k_4}] = \mathbb{E}[(\bar{a}^{(1)} g^{(1)})(\bar{a}^{(2)} g^{(2)})] = \mathbb{E}[\bar{a}^{(1)} \bar{a}^{(2)} g^{(1)} g^{(2)}] \approx \mathbb{E}[\bar{a}^{(1)} \bar{a}^{(2)}]\, \mathbb{E}[g^{(1)} g^{(2)}] \tag{2}$$

And thus one way to interpret the approximation in eqn. 1 is that we are assuming statistical independence between products $\bar{a}^{(1)} \bar{a}^{(2)}$ of unit activities and products $g^{(1)} g^{(2)}$ of unit input derivatives.

Another more detailed interpretation of the approximation emerges by considering the following expression for the approximation error $\mathbb{E} [\bar{a}^{(1)} \bar{a}^{(2)} g^{(1)} g^{(2)}] - \mathbb{E} [\bar{a}^{(1)} \bar{a}^{(2)}] \mathbb{E} [g^{(1)} g^{(2)}]$ (which is derived in the appendix):

$$\kappa(\bar{a}^{(1)}, \bar{a}^{(2)}, g^{(1)}, g^{(2)}) + \mathbb{E}[\bar{a}^{(1)}]\, \kappa(\bar{a}^{(2)}, g^{(1)}, g^{(2)}) + \mathbb{E}[\bar{a}^{(2)}]\, \kappa(\bar{a}^{(1)}, g^{(1)}, g^{(2)}) \tag{3}$$

Here $\kappa(\cdot)$ denotes the cumulant of its arguments. Cumulants are a natural generalization of the concept of mean and variance to higher orders, and indeed 1st-order cumulants are means and 2nd-order cumulants are covariances. Intuitively, cumulants of order $k$ measure the degree to which the interaction between variables is intrinsically of order $k$ , as opposed to arising from many lower-order interactions.

A basic upper bound for the approximation error is

$$|\kappa(\bar{a}^{(1)}, \bar{a}^{(2)}, g^{(1)}, g^{(2)})| + |\mathbb{E}[\bar{a}^{(1)}]|\, |\kappa(\bar{a}^{(2)}, g^{(1)}, g^{(2)})| + |\mathbb{E}[\bar{a}^{(2)}]|\, |\kappa(\bar{a}^{(1)}, g^{(1)}, g^{(2)})| \tag{4}$$

This bound will be small if all of the higher-order cumulants (i.e. those of order 3 or higher) are small. Note that in principle this upper bound may be loose due to possible cancellations between the terms in eqn. 3.

Because higher-order cumulants are zero for variables jointly distributed according to a multivariate Gaussian, it follows that this upper bound on the approximation error will be small insofar as the joint distribution over $\bar{a}^{(1)}$ , $\bar{a}^{(2)}$ , $g^{(1)}$ , and $g^{(2)}$ is well approximated by a multivariate Gaussian. And while we are not aware of an argument for why this should be the case in practice, it does seem to be the case that for the example network from Figure 2, the size of the error is well predicted by the size of the higher-order cumulants. In particular, the total approximation error, summed over all pairs of weights in the middle 4 layers, is 2894.4, and is of roughly the same size as the corresponding upper bound (4134.6), whose size is tied to that of the higher order cumulants (due to the impossibility of cancellations in eqn. 4).
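Eqn. 3 can be checked exactly on a small discrete distribution. The sketch below assumes that $\mathbb{E}[g \mid x] = 0$ for the unit input derivatives, as holds when $y$ is sampled from the model's predictive distribution. Under this assumption, the means of $g^{(1)}$ and $g^{(2)}$ vanish, and so do their covariances with the activities. This is what removes the remaining cumulant terms. The finite input set and all names are hypothetical.

```python
import numpy as np

def factorization_error_terms(w, a1, a2, g1, g2):
    """For a finite distribution with outcome probabilities w, return
    (error of eqn. 2, value of eqn. 3, bound of eqn. 4)."""
    E = lambda v: float(np.dot(w, v))              # exact expectation
    c1, c2 = a1 - E(a1), a2 - E(a2)                # centered activities
    d1, d2 = g1 - E(g1), g2 - E(g2)                # centered input derivatives
    k3_a1 = E(c1 * d1 * d2)                        # kappa(a1, g1, g2): 3rd central moment
    k3_a2 = E(c2 * d1 * d2)                        # kappa(a2, g1, g2)
    k4 = (E(c1 * c2 * d1 * d2) - E(c1 * c2) * E(d1 * d2)
          - E(c1 * d1) * E(c2 * d2) - E(c1 * d2) * E(c2 * d1))   # 4th joint cumulant
    error = E(a1 * a2 * g1 * g2) - E(a1 * a2) * E(g1 * g2)
    eqn3 = k4 + E(a1) * k3_a2 + E(a2) * k3_a1
    bound = abs(k4) + abs(E(a1)) * abs(k3_a2) + abs(E(a2)) * abs(k3_a1)
    return error, eqn3, bound

# Hypothetical example: m equally likely inputs x and a symmetric sign eps, so E[g | x] = 0.
rng = np.random.default_rng(0)
m = 8
a1 = np.repeat(np.tanh(rng.standard_normal(m)), 2)    # activities depend on x only
a2 = np.repeat(np.tanh(rng.standard_normal(m)), 2)
eps = np.tile([-1.0, 1.0], m)                         # outcome 2k + e has x = x_k and sign eps
g1 = np.repeat(rng.standard_normal(m), 2) * eps       # derivatives: scale(x) * eps
g2 = np.repeat(rng.standard_normal(m), 2) * eps
w = np.full(2 * m, 1.0 / (2 * m))
error, eqn3, bound = factorization_error_terms(w, a1, a2, g1, g2)
```

Here the error equals eqn. 3 exactly, so by the triangle inequality its magnitude cannot exceed eqn. 4.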

4 Additional approximations to $\tilde{F}$ and inverse computations

To the best of our knowledge there is no efficient general method for inverting a Khatri-Rao product like $\tilde{F}$ . Thus, we must make further approximations if we hope to obtain an efficiently computable approximation of the inverse Fisher.

In the following subsections we argue that the inverse of $\tilde{F}$ can be reasonably approximated as having one of two special structures, either of which make it efficiently computable. The second of these will be slightly less restrictive than the first (and hence a better approximation) at the cost of some additional complexity. We will then show how matrix-vector products with these approximate inverses can be efficiently computed, which will thus give an efficient algorithm for computing an approximation to the natural gradient.

4.1 Structured inverses and the connection to linear regression

Suppose we are given a multivariate distribution whose associated covariance matrix is $\Sigma$ .

Define the matrix $B$ so that for $i \neq j$, $[B]_{i,j}$ is the coefficient on the $j$-th variable in the optimal linear predictor of the $i$-th variable from all the other variables, and for $i = j$, $[B]_{i,j} = 0$. Then define the matrix $D$ to be the diagonal matrix where $[D]_{i,i}$ is the variance of the error associated with such a predictor of the $i$-th variable.

Pourahmadi (2011) showed that $B$ and $D$ can be obtained from the inverse covariance $\Sigma^{-1}$ by the formulas

$$[B]_{i,j} = -\frac{[\Sigma^{-1}]_{i,j}}{[\Sigma^{-1}]_{i,i}} \quad \text{and} \quad [D]_{i,i} = \frac{1}{[\Sigma^{-1}]_{i,i}}$$

from which it follows that the inverse covariance matrix can be expressed as

$$\Sigma^{-1} = D^{-1}(I - B)$$
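These relations can be verified directly. The sketch below (the helper name is ours) takes a random covariance matrix and, for each variable, computes the optimal linear predictor from all the others by solving the normal equations. It then compares the resulting $B$ and $D$ with $\Sigma^{-1}$. Zero-mean variables are assumed, so no intercept is needed.

```python
import numpy as np

def regression_B_D(Sigma):
    """B and D from the optimal linear predictor of each variable given all others."""
    n = Sigma.shape[0]
    B, D = np.zeros((n, n)), np.zeros((n, n))
    for i in range(n):
        rest = [j for j in range(n) if j != i]
        # normal equations: Sigma[rest, rest] beta = Sigma[rest, i]
        beta = np.linalg.solve(Sigma[np.ix_(rest, rest)], Sigma[rest, i])
        B[i, rest] = beta
        D[i, i] = Sigma[i, i] - Sigma[i, rest] @ beta     # residual (error) variance
    return B, D
```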

Intuitively, this result says that the off-diagonal entries of the $i$-th row of the inverse covariance $\Sigma^{-1}$ are the coefficients of the optimal linear predictor of the $i$-th variable from the others, up to the scaling factor $-[\Sigma^{-1}]_{i,i}$. So if the $j$-th variable is much less “useful” than the other variables for predicting the $i$-th variable, we can expect that the $(i, j)$-th entry of the inverse covariance will be relatively small.

Note that “usefulness” is a subtle property as we have informally defined it. In particular, it is not equivalent to the degree of correlation between the $j$ -th and $i$ -th variables, or any such simple measure. As a simple example, consider the case where the $j$ -th variable is equal to the $k$ -th variable plus independent Gaussian noise. Since any linear predictor can achieve a lower variance simply by shifting weight from the $j$ -th variable to the $k$ -th variable, we have that the $j$ -th variable is not useful (and its coefficient will thus be zero) in the task of predicting the $i$ -th variable for any setting of $i$ other than $i = j$ or $i = k$ .
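The example can be made concrete with three variables in a hypothetical setup. $x_i$ and $x_k$ are correlated, and $x_j = x_k + n$, where $n$ is independent noise. The sketch builds $\Sigma$ for $(x_i, x_k, x_j)$. It shows that the $(i, j)$ entry of $\Sigma^{-1}$ vanishes even though $x_i$ and $x_j$ have a correlation of about 0.65.

```python
import numpy as np

cov_ik = np.array([[1.0, 0.8],
                   [0.8, 1.0]])          # covariance of (x_i, x_k)
noise_var = 0.5                          # variance of the independent noise n

# Covariance of (x_i, x_k, n), then map it to (x_i, x_k, x_j) using x_j = x_k + n.
cov_ikn = np.zeros((3, 3))
cov_ikn[:2, :2] = cov_ik
cov_ikn[2, 2] = noise_var
T = np.array([[1.0, 0.0, 0.0],
              [0.0, 1.0, 0.0],
              [0.0, 1.0, 1.0]])
Sigma = T @ cov_ikn @ T.T

P = np.linalg.inv(Sigma)                                  # inverse covariance
corr_ij = Sigma[0, 2] / np.sqrt(Sigma[0, 0] * Sigma[2, 2])
```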

Noting that the Fisher $F$ is a covariance matrix over $\mathcal{D}\theta$ w.r.t. the model’s distribution (because $\mathbb{E}[\mathcal{D}\theta] = 0$ by Lemma 4), we can thus apply the above analysis to the distribution over $\mathcal{D}\theta$ to gain insight into the approximate structure of $F^{-1}$ , and by extension its approximation $\tilde{F}^{-1}$ .

Consider the derivative $\mathcal{D}W_i$ of the loss with respect to the weights $W_i$ of layer $i$ . Intuitively, if we are trying to predict one of the entries of $\mathcal{D}W_i$ from the other entries of $\mathcal{D}\theta$ , those entries also in $\mathcal{D}W_i$ will likely be the most useful in this regard. Thus, it stands to reason that the largest entries of $\tilde{F}^{-1}$ will be those on the diagonal blocks, so that $\tilde{F}^{-1}$ will be well approximated as block-diagonal, with each block corresponding to a different $\mathcal{D}W_i$ .

Beyond the other entries of $\mathcal{D}W_i$, it is the entries of $\mathcal{D}W_{i+1}$ and $\mathcal{D}W_{i-1}$ (i.e. those associated with adjacent layers) that will arguably be the most useful in predicting a given entry of $\mathcal{D}W_i$. This is because the true process for computing the loss gradient only uses information from the layer below (during the forward pass) and from the layer above (during the backwards pass). Thus, approximating $\tilde{F}^{-1}$ as block-tridiagonal seems like a reasonable and milder alternative to taking it to be block-diagonal. Indeed, this approximation would be exact if the distribution over $\mathcal{D}\theta$ were given by a directed graphical model which generated each of the $\mathcal{D}W_i$’s, one layer at a time, from either $\mathcal{D}W_{i+1}$ or $\mathcal{D}W_{i-1}$. Equivalently, it would be exact if $\mathcal{D}\theta$ were distributed according to an undirected Gaussian graphical model with binary potentials only between entries in the same or adjacent layers. Both of these models are depicted in Figure 4.

Now, while in reality the $\mathcal{D}W_i$'s are generated using information from adjacent layers according to a process that is neither linear nor Gaussian, it nonetheless stands to reason that their joint statistics might be reasonably approximated by such a model. In fact, the idea of approximating the distribution over loss gradients with a directed graphical model forms the basis of the recent FANG method of Grosse and Salakhutdinov (2015).

Figure 3 examines the extent to which the inverse Fisher is well approximated as block-diagonal or block-tridiagonal for an example network.

Figure 3: A comparison of our block-wise Kronecker-factored approximation $\tilde{F}$, and its inverse, using the example neural network from Figure 2. On the left is $\tilde{F}$, in the middle is its exact inverse, and on the right is a $4 \times 4$ matrix containing the averages of the absolute values of the entries in each block of the inverse. As predicted by our theory, the inverse exhibits an approximate block-tridiagonal structure, whereas $\tilde{F}$ itself does not. Note that the corresponding plots for the exact $F$ and its inverse look similar. The very small blocks visible on the diagonal of the inverse each correspond to the weights on the outgoing connections of a particular unit. The inverse was computed subject to the factored Tikhonov damping technique described in Sections 6.3 and 6.6, using the same value of $\gamma$ that was used by K-FAC at the iteration from which this example was taken (see Figure 2). Note that for the purposes of visibility we plot the absolute values of the entries, with the white level corresponding linearly to the size of these values (up to some maximum, which is chosen differently for the Fisher approximation and its inverse, due to the highly differing scales of these matrices).

In the following two subsections we show how both the block-diagonal and block-tridiagonal approximations to $\tilde{F}^{-1}$ give rise to computationally efficient methods for computing matrix-vector products with it. At the end of Section 4 we present two figures (Figures 5 and 6) which examine the quality of these approximations for an example network.

4.2 Approximating $\tilde{F}^{-1}$ as block-diagonal

Approximating $\tilde{F}^{-1}$ as block-diagonal is equivalent to approximating $\tilde{F}$ as block-diagonal. A natural choice for such an approximation $\check{F}$ of $\tilde{F}$ is to take the block-diagonal of $\check{F}$ to be that of $\tilde{F}$. This gives the matrix

$$\check{F} = \text{diag}\left( \tilde{F}_{1,1}, \tilde{F}_{2,2}, \dots, \tilde{F}_{\ell,\ell} \right) = \text{diag}\left( \bar{A}_{0,0} \otimes G_{1,1}, \bar{A}_{1,1} \otimes G_{2,2}, \dots, \bar{A}_{\ell-1,\ell-1} \otimes G_{\ell,\ell} \right)$$

Using the identity $(A \otimes B)^{-1} = A^{-1} \otimes B^{-1}$ we can easily compute the inverse of $\check{F}$ as

$$\check{F}^{-1} = \text{diag}\left( \bar{A}_{0,0}^{-1} \otimes G_{1,1}^{-1}, \bar{A}_{1,1}^{-1} \otimes G_{2,2}^{-1}, \dots, \bar{A}_{\ell-1,\ell-1}^{-1} \otimes G_{\ell,\ell}^{-1} \right)$$

Thus, computing $\check{F}^{-1}$ amounts to computing the inverses of $2\ell$ smaller matrices.

Then to compute $u = \check{F}^{-1}v$, we can make use of the well-known identity $(A \otimes B)\,\text{vec}(X) = \text{vec}(BXA^\top)$, together with the symmetry of $\bar{A}_{i-1,i-1}$, to get

$$U_i = G_{i,i}^{-1} V_i \bar{A}_{i-1,i-1}^{-1}$$

where $v$ maps to $(V_1, V_2, \dots, V_\ell)$ and $u$ maps to $(U_1, U_2, \dots, U_\ell)$ in an analogous way to how $\theta$ maps to $(W_1, W_2, \dots, W_\ell)$.
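The block-diagonal computation above can be sketched in NumPy as follows. The layer widths, the random symmetric positive-definite stand-ins for $\bar{A}_{i-1,i-1}$ and $G_{i,i}$, and the helper names are illustrative assumptions. The final check compares each $U_i$ with a dense solve against $\bar{A}_{i-1,i-1} \otimes G_{i,i}$, using column-stacking $\text{vec}$ so that $(A \otimes B)\,\text{vec}(X) = \text{vec}(BXA^\top)$ holds.

```python
import numpy as np

def vec(x):
    """Column-stacking vectorization, so (A ⊗ B) vec(X) = vec(B X A^T)."""
    return x.reshape(-1, order="F")

def random_spd(n, rng):
    """Hypothetical stand-in for a (damped) symmetric positive-definite factor."""
    m = rng.standard_normal((n, n))
    return m @ m.T + n * np.eye(n)

def block_diag_apply_inverse(A_bars, Gs, Vs):
    """Return U_i = G_{i,i}^{-1} V_i Abar_{i-1,i-1}^{-1} for every layer.

    A_bars[k], Gs[k], Vs[k] hold Abar_{i-1,i-1}, G_{i,i}, V_i for layer i = k + 1.
    Linear solves are used instead of explicit inverses.
    """
    Us = []
    for A, G, V in zip(A_bars, Gs, Vs):
        GinvV = np.linalg.solve(G, V)               # G^{-1} V
        Us.append(np.linalg.solve(A, GinvV.T).T)    # (G^{-1} V) A^{-1}, A symmetric
    return Us

rng = np.random.default_rng(0)
d = [3, 4, 2]  # assumed widths d_0, d_1, d_2 (so ell = 2)
# Abar_{i-1,i-1} has size d_{i-1} + 1 (the appended constant entry of abar).
A_bars = [random_spd(d[k] + 1, rng) for k in range(len(d) - 1)]
Gs = [random_spd(d[k + 1], rng) for k in range(len(d) - 1)]
Vs = [rng.standard_normal((d[k + 1], d[k] + 1)) for k in range(len(d) - 1)]

Us = block_diag_apply_inverse(A_bars, Gs, Vs)

# Dense check against the Abar ⊗ G block of the block-diagonal approximation.
for A, G, V, U in zip(A_bars, Gs, Vs, Us):
    assert np.allclose(np.kron(A, G) @ vec(U), vec(V))
```

Solving rather than forming explicit inverses is just a numerical convenience here; in an optimizer one would likely factor each damped $\bar{A}$ and $G$ once and reuse the factorizations across many products.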

Note that block-diagonal approximations to the Fisher information have been proposed before in TONGA (Le Roux et al., 2008), where each block corresponds to the weights associated with a particular unit. In our block-diagonal approximation, the blocks correspond to all the parameters in a given layer, and are thus much larger. In fact, they are so large that they would be impractical to invert as general matrices.

4.3 Approximating $\tilde{F}^{-1}$ as block-tridiagonal

Note that unlike in the block-diagonal case above, approximating $\tilde{F}^{-1}$ as block-tridiagonal is not equivalent to approximating $\tilde{F}$ as block-tridiagonal. Thus we require a more sophisticated approach to deal with such an approximation, which we develop in this subsection.

To start, we define $\hat{F}$ to be the matrix which agrees with $\tilde{F}$ on the tridiagonal blocks, and which satisfies the property that $\hat{F}^{-1}$ is block-tridiagonal. Note that this definition implies certain values for the off-tridiagonal blocks of $\hat{F}$, which will differ from those of $\tilde{F}$ insofar as $\tilde{F}^{-1}$ is not actually block-tridiagonal.

Figure 4: A diagram depicting the UGGM corresponding to $\hat{F}^{-1}$ and its equivalent DGGM. The UGGM's edges are labeled with the corresponding weights of the model (these are distinct from the network's weights). Here, $(\hat{F}^{-1})_{i,j}$ denotes the $(i,j)$-th block of $\hat{F}^{-1}$. The DGGM's edges are labeled with the matrices that specify the linear mapping from the source node to the conditional mean of the destination node (whose conditional covariance is given by its label).

To establish that such a matrix $\hat{F}$ is well defined and can be inverted efficiently, we first observe that assuming that $\hat{F}^{-1}$ is block-tridiagonal is equivalent to assuming that it is the precision matrix of an undirected Gaussian graphical model (UGGM) over $\mathcal{D}\theta$ (as depicted in Figure 4), whose density function is proportional to $\exp(-\frac{1}{2}\mathcal{D}\theta^\top \hat{F}^{-1} \mathcal{D}\theta)$. As this graphical model has a tree structure, there is an equivalent directed graphical model with the same distribution and the same (undirected) graphical structure (e.g. Bishop, 2006), where the directionality of the edges is given by a directed acyclic graph (DAG). Moreover, this equivalent directed model will also be linear-Gaussian, and hence a directed Gaussian graphical model (DGGM).

Next we will show how the parameters of such a DGGM corresponding to $\hat{F}$ can be efficiently recovered from the tridiagonal blocks of $\hat{F}$, so that $\hat{F}$ is uniquely determined by these blocks (and hence well-defined). We will assume here that the direction of the edges is from the higher layers to the lower ones. Note that a different choice for these directions would yield a superficially different algorithm for computing the inverse of $\hat{F}$ that would nonetheless yield the same output.

For each $i$, we will denote the conditional covariance matrix of $\text{vec}(\mathcal{D}W_i)$ on $\text{vec}(\mathcal{D}W_{i+1})$ by $\Sigma_{i|i+1}$ and the linear coefficients from $\text{vec}(\mathcal{D}W_{i+1})$ to $\text{vec}(\mathcal{D}W_i)$ by the matrix $\Psi_{i,i+1}$, so that the conditional distributions defining the model are

$$\text{vec}(\mathcal{D}W_i) \sim \mathcal{N}\left(\Psi_{i,i+1} \text{vec}(\mathcal{D}W_{i+1}), \Sigma_{i|i+1}\right) \quad \text{and} \quad \text{vec}(\mathcal{D}W_\ell) \sim \mathcal{N}\left(\vec{0}, \Sigma_\ell\right)$$

Since $\Sigma_\ell$ is just the covariance of $\text{vec}(\mathcal{D}W_\ell)$, it is given simply by $\hat{F}_{\ell,\ell} = \tilde{F}_{\ell,\ell}$. And for $i \leq \ell - 1$, we can see that $\Psi_{i,i+1}$ is given by

$$\Psi_{i,i+1} = \hat{F}_{i,i+1} \hat{F}_{i+1,i+1}^{-1} = \tilde{F}_{i,i+1} \tilde{F}_{i+1,i+1}^{-1} = \left(\bar{A}_{i-1,i} \otimes G_{i,i+1}\right) \left(\bar{A}_{i,i} \otimes G_{i+1,i+1}\right)^{-1} = \Psi_{i-1,i}^{\bar{A}} \otimes \Psi_{i,i+1}^G$$

where

$$\Psi_{i-1,i}^{\bar{A}} = \bar{A}_{i-1,i} \bar{A}_{i,i}^{-1} \quad \text{and} \quad \Psi_{i,i+1}^G = G_{i,i+1} G_{i+1,i+1}^{-1}$$

The conditional covariance $\Sigma_{i|i+1}$ is thus given by

$$\begin{aligned} \Sigma_{i|i+1} &= \hat{F}_{i,i} - \Psi_{i,i+1} \hat{F}_{i+1,i+1} \Psi_{i,i+1}^\top = \tilde{F}_{i,i} - \Psi_{i,i+1} \tilde{F}_{i+1,i+1} \Psi_{i,i+1}^\top \\ &= \bar{A}_{i-1,i-1} \otimes G_{i,i} - \Psi_{i-1,i}^{\bar{A}} \bar{A}_{i,i} \Psi_{i-1,i}^{\bar{A}\top} \otimes \Psi_{i,i+1}^G G_{i+1,i+1} \Psi_{i,i+1}^{G\top} \end{aligned}$$

Both Kronecker factorizations follow from the mixed-product property $(A \otimes B)(C \otimes D) = AC \otimes BD$, together with $(A \otimes B)^{-1} = A^{-1} \otimes B^{-1}$ and $(A \otimes B)^\top = A^\top \otimes B^\top$.

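Following the work of Grosse and Salakhutdinov (2015), we use the block generalization of the well-known “Cholesky” decomposition of the precision matrix of DGGMs (Pourahmadi, 1999), which gives

$$\hat{F}^{-1} = \Xi^\top \Lambda \Xi$$

where

$$\Lambda = \text{diag}\left( \Sigma_{1|2}^{-1}, \Sigma_{2|3}^{-1}, \dots, \Sigma_{\ell-1|\ell}^{-1}, \Sigma_{\ell}^{-1} \right) \quad \text{and} \quad \Xi = \begin{bmatrix} I & -\Psi_{1,2} & & & \\ & I & -\Psi_{2,3} & & \\ & & I & \ddots & \\ & & & \ddots & -\Psi_{\ell-1,\ell} \\ & & & & I \end{bmatrix}$$

To see why, note that $\Xi\,\mathcal{D}\theta$ stacks the residuals $\text{vec}(\mathcal{D}W_i) - \Psi_{i,i+1}\text{vec}(\mathcal{D}W_{i+1})$ (and $\text{vec}(\mathcal{D}W_\ell)$ itself), which under the DGGM are independent Gaussians with covariances $\Sigma_{i|i+1}$ (and $\Sigma_\ell$), so that their joint precision matrix is $\Lambda$.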
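To make the construction concrete, the following dense NumPy sketch starts from only the diagonal and first off-diagonal blocks of a matrix, forms $\Psi_{i,i+1}$, $\Sigma_{i|i+1}$, $\Xi$ and $\Lambda$ as above, and checks that $\hat{F} = (\Xi^\top \Lambda \Xi)^{-1}$ reproduces those tridiagonal blocks while $\hat{F}^{-1}$ is block-tridiagonal. It deliberately ignores the Kronecker structure (a random symmetric positive-definite matrix stands in for $\tilde{F}$, and the block sizes are arbitrary assumptions), so it illustrates only the algebra, not an efficient implementation.

```python
import numpy as np

rng = np.random.default_rng(1)
sizes = [3, 2, 4]                                  # assumed block sizes, one per layer (ell = 3)
offsets = np.concatenate([[0], np.cumsum(sizes)])
n, L = offsets[-1], len(sizes)

m = rng.standard_normal((n, n))
F_tilde = m @ m.T + n * np.eye(n)                  # random SPD stand-in for F_tilde

def blk(M, i, j):
    """Block (i, j) of M under the layer partition (0-based indices)."""
    return M[offsets[i]:offsets[i + 1], offsets[j]:offsets[j + 1]]

# Psi_{i,i+1} = F_{i,i+1} F_{i+1,i+1}^{-1}, computed with a solve (F is symmetric).
Psi = [np.linalg.solve(blk(F_tilde, i + 1, i + 1), blk(F_tilde, i + 1, i)).T
       for i in range(L - 1)]
# Sigma_{i|i+1} = F_{i,i} - Psi F_{i+1,i+1} Psi^T; Sigma_ell is just F_{ell,ell}.
Sigma = [blk(F_tilde, i, i) - Psi[i] @ blk(F_tilde, i + 1, i + 1) @ Psi[i].T
         for i in range(L - 1)]
Sigma.append(blk(F_tilde, L - 1, L - 1))

# Assemble Xi (unit upper block-bidiagonal) and Lambda (block-diagonal).
Xi, Lam = np.eye(n), np.zeros((n, n))
for i in range(L):
    rows = slice(offsets[i], offsets[i + 1])
    Lam[rows, rows] = np.linalg.inv(Sigma[i])
    if i < L - 1:
        Xi[rows, offsets[i + 1]:offsets[i + 2]] = -Psi[i]

F_hat_inv = Xi.T @ Lam @ Xi
F_hat = np.linalg.inv(F_hat_inv)

# F_hat agrees with F_tilde on the diagonal and first off-diagonal blocks...
for i in range(L):
    for j in range(L):
        if abs(i - j) <= 1:
            assert np.allclose(blk(F_hat, i, j), blk(F_tilde, i, j))
# ...while F_hat^{-1} is zero outside the block-tridiagonal band.
assert np.allclose(blk(F_hat_inv, 0, 2), 0.0)
```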

Thus, matrix-vector multiplication with $\hat{F}^{-1}$ amounts to performing matrix-vector multiplication by $\Xi$, followed by $\Lambda$, and then by $\Xi^\top$.

As in the block-diagonal case considered in the previous subsection, matrix-vector products with $\Xi$ (and $\Xi^\top$) can be efficiently computed using the well-known identity $(A \otimes B)\,\text{vec}(X) = \text{vec}(BXA^\top)$, since each $\Psi_{i,i+1} = \Psi_{i-1,i}^{\bar{A}} \otimes \Psi_{i,i+1}^G$ is a Kronecker product. In particular, $u = \Xi^\top v$ can be computed as

$$U_i = V_i - \Psi_{i-1,i}^{G\top} V_{i-1} \Psi_{i-2,i-1}^{\bar{A}} \quad \text{and} \quad U_1 = V_1$$

and similarly $u = \Xi v$ can be computed as

$$U_i = V_i - \Psi_{i,i+1}^G V_{i+1} \Psi_{i-1,i}^{\bar{A}\top} \quad \text{and} \quad U_\ell = V_\ell$$

where the $U_i$'s and $V_i$'s are defined in terms of $u$ and $v$ as in the previous subsection.
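The two recursions can be sketched as follows. Random matrices stand in for the factors $\Psi^{\bar{A}}_{i-1,i}$ and $\Psi^{G}_{i,i+1}$ (in practice they would be formed from the $\bar{A}$'s and $G$'s as above), the layer widths are assumptions, and lists are 0-based, so index `k` corresponds to layer $i = k + 1$. The final assertions compare both recursions with a dense $\Xi$ assembled from explicit Kronecker products.

```python
import numpy as np

def vec(x):
    """Column-stacking vectorization, so (A ⊗ B) vec(X) = vec(B X A^T)."""
    return x.reshape(-1, order="F")

def xi_apply(PsiA, PsiG, Vs):
    """u = Xi v:  U_i = V_i - Psi^G_{i,i+1} V_{i+1} (Psi^Abar_{i-1,i})^T,  U_ell = V_ell."""
    Us = [V.copy() for V in Vs]
    for k in range(len(Vs) - 1):
        Us[k] -= PsiG[k] @ Vs[k + 1] @ PsiA[k].T
    return Us

def xi_t_apply(PsiA, PsiG, Vs):
    """u = Xi^T v:  U_i = V_i - (Psi^G_{i-1,i})^T V_{i-1} Psi^Abar_{i-2,i-1},  U_1 = V_1."""
    Us = [V.copy() for V in Vs]
    for k in range(1, len(Vs)):
        Us[k] -= PsiG[k - 1].T @ Vs[k - 1] @ PsiA[k - 1]
    return Us

rng = np.random.default_rng(2)
d = [3, 4, 2, 3]                                   # assumed widths d_0..d_3 (ell = 3)
L = len(d) - 1
shapes = [(d[k + 1], d[k] + 1) for k in range(L)]  # shape of V_i for i = k + 1
PsiA = [rng.standard_normal((d[k] + 1, d[k + 1] + 1)) for k in range(L - 1)]  # Psi^Abar_{i-1,i}
PsiG = [rng.standard_normal((d[k + 1], d[k + 2])) for k in range(L - 1)]      # Psi^G_{i,i+1}
Vs = [rng.standard_normal(s) for s in shapes]

# Dense Xi with -(Psi^Abar_{i-1,i} ⊗ Psi^G_{i,i+1}) on the block superdiagonal.
offs = np.concatenate([[0], np.cumsum([r * c for r, c in shapes])])
Xi = np.eye(offs[-1])
for k in range(L - 1):
    Xi[offs[k]:offs[k + 1], offs[k + 1]:offs[k + 2]] = -np.kron(PsiA[k], PsiG[k])

v = np.concatenate([vec(V) for V in Vs])
assert np.allclose(Xi @ v, np.concatenate([vec(U) for U in xi_apply(PsiA, PsiG, Vs)]))
assert np.allclose(Xi.T @ v, np.concatenate([vec(U) for U in xi_t_apply(PsiA, PsiG, Vs)]))
```

Each step touches only one pair of adjacent layers and only small factor matrices, which is what keeps the cost of a product with $\Xi$ or $\Xi^\top$ comparable to that of the block-diagonal case.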

Multiplying a vector $v$ by $\Lambda$ amounts to multiplying each $\text{vec}(V_i)$ by the corresponding $\Sigma_{i|i+1}^{-1}$ (or by $\Sigma_\ell^{-1}$ for $i = \ell$). This is slightly tricky because $\Sigma_{i|i+1}$ is the difference of two Kronecker products, so we cannot use the straightforward identity $(A \otimes B)^{-1} = A^{-1} \otimes B^{-1}$. Fortunately, there are efficient techniques for inverting such matrices, which we discuss in detail in Appendix B.

4.4 Examining the approximation quality

Figures 5 and 6 examine the quality of the approximations $\check{F}$ and $\hat{F}$ of $\tilde{F}$, which are derived by approximating $\tilde{F}^{-1}$ as block-diagonal and block-tridiagonal (resp.), for an example network.

From Figure 5, which compares $\check{F}$ and $\hat{F}$ directly to $\tilde{F}$, we can see that while $\check{F}$ and $\hat{F}$ exactly capture the diagonal and tridiagonal blocks (resp.) of $\tilde{F}$, as they must by definition, $\hat{F}$ ends up approximating the off-tridiagonal blocks of $\tilde{F}$ very well too. This is likely due to the fact that the approximating assumption used to derive $\hat{F}$, that $\tilde{F}^{-1}$ is block-tridiagonal, is a very reasonable one in practice (judging by Figure 3).

Figure 6, which compares $\check{F}^{-1}$ and $\hat{F}^{-1}$ to $\tilde{F}^{-1}$, paints an arguably more interesting and relevant picture, as the quality of the approximation of the natural gradient will be roughly proportional (see footnote 2) to the quality of approximation of the inverse Fisher. We can see from this figure that due to the approximate block-diagonal structure of $\tilde{F}^{-1}$, $\check{F}^{-1}$ is actually a reasonably good approximation of $\tilde{F}^{-1}$, despite $\check{F}$ being a rather poor approximation of $\tilde{F}$ (based on Figure 5). Meanwhile, we can see that by accounting for the tridiagonal blocks, $\hat{F}^{-1}$ is indeed a significantly better approximation of $\tilde{F}^{-1}$ than $\check{F}^{-1}$ is, even on the diagonal blocks.


Footnote 2: The error in any approximation $F_0^{-1}\nabla h$ of the natural gradient $F^{-1}\nabla h$ will be roughly proportional to the error in the approximation $F_0^{-1}$ of the associated inverse Fisher $F^{-1}$, since $\|F^{-1}\nabla h - F_0^{-1}\nabla h\| \leq \|\nabla h\| \|F^{-1} - F_0^{-1}\|$.

Figure 5: A comparison of our block-wise Kronecker-factored approximation $\tilde{F}$, and its approximations $\check{F}$ and $\hat{F}$ (which are based on approximating the inverse $\tilde{F}^{-1}$ as either block-diagonal or block-tridiagonal, respectively), using the example neural network from Figure 2. On the left is $\tilde{F}$, in the middle its approximation, and on the right is the absolute difference of these. The top row compares to $\check{F}$ and the bottom row compares to $\hat{F}$. While the diagonal blocks of the top right matrix, and the tridiagonal blocks of the bottom right matrix, are exactly zero due to how $\check{F}$ and $\hat{F}$ (resp.) are constructed, the off-tridiagonal blocks of the bottom right matrix, while being very close to zero, are actually non-zero (which is hard to see from the plot). Note that for the purposes of visibility we plot the absolute values of the entries, with the white level corresponding linearly to the size of these values (up to some maximum, which is the same in each image).

Figure 6: A comparison of the exact inverse $\tilde{F}^{-1}$ of our block-wise Kronecker-factored approximation $\tilde{F}$, and its block-diagonal and block-tridiagonal approximations $\check{F}^{-1}$ and $\hat{F}^{-1}$ (resp.), using the example neural network from Figure 2. On the left is $\tilde{F}^{-1}$, in the middle its approximation, and on the right is the absolute difference of these. The top row compares to $\check{F}^{-1}$ and the bottom row compares to $\hat{F}^{-1}$. The inverse was computed subject to the factored Tikhonov damping technique described in Sections 6.3 and 6.6, using the same value of $\gamma$ that was used by K-FAC at the iteration from which this example was taken (see Figure 2). Note that for the purposes of visibility we plot the absolute values of the entries, with the white level corresponding linearly to the size of these values (up to some maximum, which is the same in each image).

5 Estimating the required statistics

Recall that $\bar{A}_{i,j} = \mathbb{E}[\bar{a}_i \bar{a}_j^\top]$ and $G_{i,j} = \mathbb{E}[g_i g_j^\top]$. Both approximate Fisher inverses discussed in Section 4 require some subset of these. In particular, the block-diagonal approximation requires them for $i = j$, while the block-tridiagonal approximation requires them for $j \in \{i, i + 1\}$ (noting that $\bar{A}_{i,j}^\top = \bar{A}_{j,i}$ and $G_{i,j}^\top = G_{j,i}$).

Since the $\bar{a}_i$'s don't depend on $y$, we can take the expectation $\mathbb{E}[\bar{a}_i \bar{a}_j^\top]$ with respect to just the training distribution $\hat{Q}_x$ over the inputs $x$. On the other hand, the $g_i$'s do depend on $y$, and so the expectation $\mathbb{E}[g_i g_j^\top]$ (see footnote 3) must be taken with respect to both $\hat{Q}_x$ and the network's predictive distribution $P_{y|x}$.

Footnote 3: It is important to note that this expectation should not be taken with respect to the training/data distribution over $y$ (i.e. $\hat{Q}_{y|x}$ or $Q_{y|x}$). Using the training/data distribution for $y$ would perhaps give an approximation to a quantity known as the “empirical Fisher information matrix”, which lacks the previously discussed equivalence to the Generalized Gauss-Newton matrix, and would not be compatible with the theoretical analysis performed in Section 3.1 (in particular, Lemma 4 would break down). Moreover, such a choice would not give rise to what is usually thought of as the natural gradient, and based on the findings of Martens (2010), would likely perform worse in practice as part of an optimization algorithm. See Martens (2014) for a more detailed discussion of the empirical Fisher and reasons why it may be a poor choice for a curvature matrix compared to the standard Fisher.

While computing matrix-vector products with the $G_{i,j}$ could be done exactly and efficiently for a given input $x$ (or small mini-batch of $x$'s) by adapting the methods of Schraudolph (2002), there doesn't seem to be a sufficiently efficient method for computing the entire matrix itself. Indeed, the hardness results of Martens et al. (2012) suggest that this would require, for each example $x$ in the mini-batch, work that is asymptotically equivalent to matrix-matrix multiplication involving matrices the same size as $G_{i,j}$. While a small constant number of such multiplications is arguably an acceptable cost (see Section 8), a number which grows with the size of the mini-batch would not be.

Instead, we will approximate the expectation over $y$ by a standard Monte-Carlo estimate obtained by sampling $y$'s from the network's predictive distribution and then rerunning the backwards phase of backpropagation (see Algorithm 1) as if these were the training targets.

Note that computing/estimating the required $\bar{A}_{i,j}$'s and $G_{i,j}$'s involves computing averages over outer products of various $\bar{a}_i$'s from the network's usual forward pass, and $g_i$'s from the modified backwards pass (with targets sampled as above). Thus we can compute/estimate these quantities on the same input data used to compute the gradient $\nabla h$, at the cost of one or more additional backwards passes, and a few additional outer-product averages. Fortunately, this turns out to be quite inexpensive, as we have found that just one modified backwards pass is sufficient to obtain a good quality estimate in practice, and the required outer-product averages are similar to those already used to compute the gradient in the usual backpropagation algorithm.

In the case of online/stochastic optimization we have found that the best strategy is to maintain running estimates of the required $\bar{A}_{i,j}$'s and $G_{i,j}$'s using a simple exponentially decaying averaging scheme. In particular, we take the new running estimate to be the old one weighted by $\epsilon$, plus the estimate on the new mini-batch weighted by $1 - \epsilon$, for some $0 \leq \epsilon < 1$. In our experiments we used $\epsilon = \min\{1 - 1/k, 0.95\}$, where $k$ is the iteration number.

Note that the more naive averaging scheme where the estimates from each iteration are given equal weight would be inappropriate here. This is because the $\bar{A}_{i,j}$'s and $G_{i,j}$'s depend on the network's parameters $\theta$, and these will slowly change over time as optimization proceeds, so that estimates computed many iterations ago will become stale.

This kind of exponentially decaying averaging scheme is commonly used in methods involving diagonal or block-diagonal approximations (with much smaller blocks than ours) to the curvature matrix (e.g. LeCun et al., 1998; Park et al., 2000; Schaul et al., 2013). Such schemes have the desirable property that they allow the curvature estimate to depend on much more data than can be reasonably processed in a single mini-batch.

Notably, for methods like HF which deal with the exact Fisher indirectly via matrix-vector products, such a scheme would be impossible to implement efficiently, as the exact Fisher matrix (or GGN) seemingly cannot be summarized using a compact data structure whose size is independent of the amount of data used to estimate it. Indeed, it seems that the only representation of the exact Fisher which would be independent of the amount of data used to estimate it would be an explicit $n \times n$ matrix (which is far too big to be practical). Because of this, HF and related methods must base their curvature estimates only on subsets of data that can be reasonably processed all at once, which limits their effectiveness in the stochastic optimization regime.
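By contrast, the Kronecker factors have a fixed size, so they are easy to average over time. The estimation steps of this section can be sketched for the simplest case, the output layer of a softmax classifier trained with cross-entropy loss (an illustrative assumption; the paper's procedure reruns the full backwards pass for every layer). There the backwards pass for a target $y$ reduces to $g = \text{softmax}(z) - \text{onehot}(y)$, so sampling $y$ from the model's predictive distribution gives a Monte-Carlo estimate of $G$ whose exact expectation is the average of $\text{diag}(p) - pp^\top$. Each mini-batch estimate is folded into a running estimate with $\epsilon = \min\{1 - 1/k, 0.95\}$. The function names and problem sizes are hypothetical, and the logits are held fixed so that the running estimate can be compared with the exact value.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)       # stabilize the exponentials
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def sampled_output_G(logits, rng):
    """One-sample Monte-Carlo estimate of G for a softmax output layer.

    Targets are drawn from the model's predictive distribution (not the training
    labels), and g = dL/dz = softmax(z) - onehot(y) for L = -log p(y|x).
    """
    p = softmax(logits)
    m, c = p.shape
    u = rng.random((m, 1))
    y = np.minimum((p.cumsum(axis=1) < u).sum(axis=1), c - 1)   # inverse-CDF sampling
    g = p - np.eye(c)[y]
    return g.T @ g / m                                          # average outer product

def update_running_estimate(running, estimate, k, max_decay=0.95):
    """Exponentially decaying average with eps = min(1 - 1/k, max_decay), k >= 1."""
    eps = min(1.0 - 1.0 / k, max_decay)
    return estimate.copy() if running is None else eps * running + (1.0 - eps) * estimate

rng = np.random.default_rng(4)
logits = rng.standard_normal((500, 4))   # fixed stand-in logits for one mini-batch of inputs
p = softmax(logits)
G_exact = np.mean([np.diag(q) - np.outer(q, q) for q in p], axis=0)

G_running = None
for k in range(1, 201):
    G_running = update_running_estimate(G_running, sampled_output_G(logits, rng), k)

print(np.abs(G_running - G_exact).max())   # small Monte-Carlo error
```

With $\epsilon = 1 - 1/k$, the first 20 running estimates are exactly the plain means of the estimates seen so far; after that the cap of 0.95 makes older estimates decay geometrically, which is what lets the average track a slowly changing $\theta$.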

6 Update damping

6.1 Background and motivation

The idealized natural gradient approach is to follow the smooth path in the Riemannian manifold (implied by the Fisher information matrix viewed as a metric tensor) that is generated by taking a series of infinitesimally small steps (in the original parameter space) in the direction of the natural gradient (which gets recomputed at each point). While this is clearly impractical as a real optimization method, one can take larger steps and still follow these paths approximately. But in our experience, to obtain an update which satisfies the minimal requirement of not worsening the objective function value, it is often the case that one must make the step size so small that the resulting optimization algorithm performs poorly in practice.

The reason that the natural gradient can only be reliably followed a short distance is that it is defined merely as an optimal direction (which trades off improvement in the objective versus change in the predictive distribution), and not a discrete update.

Fortunately, as observed by Martens (2014), the natural gradient can be understood using a more traditional optimization-theoretic perspective which implies how it can be used to generate updates that will be useful over larger distances. In particular, when $R_{y|z}$ is an exponential family model with $z$ as its natural parameters (as it will be in our experiments), Martens (2014) showed that the Fisher becomes equivalent to the Generalized Gauss-Newton matrix (GGN), which is a positive semi-definite approximation of the Hessian of $h$. Additionally, there is the well-known fact that when $L(y, f(x, \theta))$ is the negative log-likelihood function associated with a given $(x, y)$ pair (as we are assuming in this work), the Hessian $H$ of $h$ and the Fisher $F$ are closely related in the sense that $H$ is the expected Hessian of $L$ under the training distribution $\hat{Q}_{x,y}$, while $F$ is the expected Hessian of $L$ under the model's distribution $P_{x,y}$ (defined by the density $p(x, y) = p(y|x)q(x)$).

From these observations it follows that

$$M(\delta) = \frac{1}{2} \delta^\top F \delta + \nabla h(\theta)^\top \delta + h(\theta) \quad (5)$$

can be viewed as a convex approximation of the 2nd-order Taylor series expansion of $h(\delta + \theta)$, whose minimizer $\delta^*$ is the (negative) natural gradient $-F^{-1}\nabla h(\theta)$. Note that if we add an $\ell_2$ or “weight-decay” regularization term to $h$ of the form $\frac{\eta}{2}\|\theta\|_2^2$, then similarly $F + \eta I$ can be viewed as an approximation of the Hessian of $h$, and replacing $F$ with $F + \eta I$ in $M(\delta)$ yields an approximation of the 2nd-order Taylor series, whose minimizer is a kind of “regularized” (negative) natural gradient $-(F + \eta I)^{-1}\nabla h(\theta)$, which is what we end up using in practice.
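As a quick numerical check of this interpretation, the sketch below uses a random positive semi-definite matrix as a stand-in for $F$ and a random vector as a stand-in for $\nabla h(\theta)$ (both assumptions, as are the sizes and $\eta$), and confirms that $-(F + \eta I)^{-1}\nabla h(\theta)$ is a stationary point of the regularized quadratic model.

```python
import numpy as np

rng = np.random.default_rng(5)
n, eta = 5, 1e-2                     # assumed dimension and l2 coefficient
J = rng.standard_normal((n, n))
F = J @ J.T                          # stand-in PSD curvature matrix
grad = rng.standard_normal(n)        # stand-in for grad h(theta)
h0 = 1.0                             # h(theta); does not affect the minimizer
C = F + eta * np.eye(n)

def M(delta):
    """Quadratic model of eqn. (5) with F replaced by F + eta * I."""
    return 0.5 * delta @ C @ delta + grad @ delta + h0

delta_star = -np.linalg.solve(C, grad)         # "regularized" negative natural gradient

print(np.linalg.norm(C @ delta_star + grad))   # gradient of M at delta_star, ~0
```

Because $F + \eta I$ is positive definite, this stationary point is the unique minimizer: any perturbation $p$ increases $M$ by exactly $\frac{1}{2}p^\top(F + \eta I)p > 0$.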

From the interpretation of the natural gradient as the minimizer of $M(\delta)$, we can see that it fails to be useful as a local update only insofar as $M(\delta)$ fails to be a good local approximation to $h(\delta + \theta)$. So, as argued by Martens (2014), it is natural to make use of the various “damping” techniques that have been developed in the optimization literature for dealing with the breakdowns in local quadratic approximations that inevitably occur during optimization. Notably, this breakdown usually won’t occur in the final “local convergence” stage of optimization where the function becomes well approximated as a convex quadratic within a sufficiently large neighborhood of the local optimum. This is the phase traditionally analyzed in most theoretical results, and while it is important that an optimizer be able to converge well in this final phase, it is arguably much more important from a practical standpoint that it behaves sensibly before this phase.

This initial “exploration phase” (Darken and Moody, 1990) is where damping techniques help in ways that are not apparent from the asymptotic convergence theorems alone, which is not to say there are not strong mathematical arguments that support their use (see Nocedal and Wright, 2006). In particular, in the exploration phase it will often still be true that $h(\theta + \delta)$ is accurately approximated by a convex quadratic locally within some region around $\delta = 0$, and that therefore optimization can be most efficiently performed by minimizing a sequence of such convex quadratic approximations within adaptively sized local regions.

Note that well-designed damping techniques, such as the ones we will employ, automatically adapt to the local properties of the function, and effectively “turn themselves off” when the quadratic model becomes a sufficiently accurate local approximation of $h$, allowing the optimizer to achieve the desired asymptotic convergence behavior (Moré, 1978).

Successful and theoretically well-founded damping techniques include Tikhonov damping (aka Tikhonov regularization, which is closely connected to the trust-region method) with Levenberg-Marquardt style adaptation (Moré, 1978), line searches, trust regions, truncation, etc., all of which tend to be much more effective in practice than merely applying a learning rate to the update, or adding a fixed multiple of the identity to the curvature matrix. Indeed, a subset of these techniques was exploited in the work of Martens (2010), and primitive versions of them have appeared implicitly in older works such as Becker and LeCun (1989), and also in many recent diagonal methods like that of Zeiler (2013), although often without a good understanding of what they are doing and why they help.

Crucially, more powerful 2nd-order optimizers like HF and K-FAC, which have the capability of taking much larger steps than 1st-order methods (or methods which use diagonal curvature matrices), require more sophisticated damping solutions to work well, and will usually completely fail without them, which is consistent with predictions made in various theoretical analyses (e.g. Nocedal and Wright, 2006). As an analogy one can think of such powerful 2nd-order optimizers as extremely fast racing cars that need more sophisticated control systems than standard cars to prevent them from flying off the road. Arguably one of the reasons why high-powered 2nd-order optimization methods have historically tended to under-perform in machine learning applications, and in neural network training in particular, is that their designers did not understand or take seriously the issue of quadratic model approximation quality, and did not employ the more sophisticated and effective damping techniques that are available to deal with this issue.

For a detailed review and discussion of various damping techniques and their crucial role in practical 2nd-order optimization methods, we refer the reader to Martens and Sutskever (2012).

6.2 A highly effective damping scheme for K-FAC

Methods like HF which use the exact Fisher seem to work reasonably well with an adaptive Tikhonov regularization technique where $\lambda I$ is added to $F + \eta I$, and where $\lambda$ is adapted according to a Levenberg-Marquardt style adjustment rule. This common and well-studied method can be shown to be equivalent to imposing an adaptive spherical region (known as a “trust region”) which constrains the optimization of the quadratic model (e.g. Nocedal and Wright, 2006). However, we found that this simple technique is insufficient when used with our approximate natural gradient update proposals. In particular, we have found that there never seems to be a “good” choice for $\lambda$ that gives rise to updates which are of a quality comparable to those produced by methods that use the exact Fisher, such as HF.

One possible explanation for this finding is that, unlike quadratic models based on the exact Fisher (or equivalently, the GGN), the one underlying K-FAC has no guarantee of being accurate up to 2nd-order. Thus, $\lambda$ must remain large in order to compensate for this intrinsic 2nd-order inaccuracy of the model, which has the side effect of “washing out” the small eigenvalues (which represent important low-curvature directions).

Fortunately, through trial and error, we were able to find a relatively simple and highly effective damping scheme, which combines several different techniques, and which works well within K-FAC. Our scheme works by computing an initial update proposal using a version of the above described adaptive Tikhonov damping/regularization method, and then re-scaling this according to the quadratic model computed using the exact Fisher. This second step is made practical by the fact that it only requires a single matrix-vector product with the exact Fisher, and this can be computed efficiently using standard methods. We discuss the details of this scheme in the following subsections.

6.3 A factored Tikhonov regularization technique

In the first stage of our damping scheme we generate a candidate update proposal $\Delta$ by applying a slightly modified form of Tikhonov damping to our approximate Fisher, before multiplying $-\nabla h$ by its inverse.

In the usual Tikhonov regularization/damping technique, one adds $(\lambda + \eta)I$ to the curvature matrix (where $\eta$ accounts for the $\ell_2$ regularization), which is equivalent to adding a term of the form $\frac{\lambda + \eta}{2} \|\delta\|_2^2$ to the corresponding quadratic model (given by $M(\delta)$ with $F$ replaced by our approximation). For the block-diagonal approximation $\check{F}$ of $\tilde{F}$ (from Section 4.2) this amounts to adding $(\lambda + \eta)I$ (for a lower dimensional $I$) to each of the individual diagonal blocks, which gives modified diagonal blocks of the form

$$\bar{A}_{i-1,i-1} \otimes G_{i,i} + (\lambda + \eta)I = \bar{A}_{i-1,i-1} \otimes G_{i,i} + (\lambda + \eta)I \otimes I \quad (6)$$

Because this is the sum of two Kronecker products we cannot use the simple identity $(A \otimes B)^{-1} = A^{-1} \otimes B^{-1}$ anymore. Fortunately however, there are efficient techniques for inverting such matrices, which we discuss in detail in Appendix B.
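One standard way to apply the inverse of such a sum is to simultaneously diagonalize both Kronecker terms using two small symmetric eigendecompositions. The sketch below does this under the assumption that $A$ and $B$ are symmetric positive definite, $C$ and $D$ are symmetric, and the sum is nonsingular; we make no claim that it matches the technique of Appendix B in every detail, and the function names and test matrices are illustrative. With $A = \bar{A}_{i-1,i-1}$, $B = G_{i,i}$, $C = (\lambda + \eta)I$ and $D = I$ it handles eqn. 6, and with $C = -\Psi^{\bar{A}}_{i-1,i}\bar{A}_{i,i}\Psi^{\bar{A}\top}_{i-1,i}$ and $D = \Psi^{G}_{i,i+1}G_{i+1,i+1}\Psi^{G\top}_{i,i+1}$ the same routine applies to the difference of Kronecker products $\Sigma_{i|i+1}$ of Section 4.3.

```python
import numpy as np

def inv_sqrt_spd(M):
    """M^{-1/2} for a symmetric positive-definite matrix M."""
    w, Q = np.linalg.eigh(M)
    return (Q / np.sqrt(w)) @ Q.T

def kron_sum_solve(A, B, C, D, V):
    """Return U with vec(U) = (A ⊗ B + C ⊗ D)^{-1} vec(V), vec column-stacking.

    Assumes A, B symmetric positive definite, C, D symmetric, and the sum
    nonsingular; V has shape (B.shape[0], A.shape[0]).
    """
    A_is, B_is = inv_sqrt_spd(A), inv_sqrt_spd(B)
    MA, MB = A_is @ C @ A_is, B_is @ D @ B_is
    s1, E1 = np.linalg.eigh(0.5 * (MA + MA.T))     # symmetrize against round-off
    s2, E2 = np.linalg.eigh(0.5 * (MB + MB.T))
    P1, P2 = A_is @ E1, B_is @ E2
    # A ⊗ B + C ⊗ D = (P1 ⊗ P2)^{-T} (I ⊗ I + diag(s1) ⊗ diag(s2)) (P1 ⊗ P2)^{-1},
    # so its inverse is (P1 ⊗ P2) (I ⊗ I + diag(s1) ⊗ diag(s2))^{-1} (P1 ⊗ P2)^T.
    Y = P2.T @ V @ P1
    Y = Y / (1.0 + np.outer(s2, s1))
    return P2 @ Y @ P1.T

def vec(x):
    return x.reshape(-1, order="F")

rng = np.random.default_rng(3)
na, nb = 3, 4

def rand_spd(n):
    m = rng.standard_normal((n, n))
    return m @ m.T + n * np.eye(n)

A, B = rand_spd(na), rand_spd(nb)
V = rng.standard_normal((nb, na))

# Exact Tikhonov damping (eqn. 6): A ⊗ B + (lambda + eta) I ⊗ I.
damping = 0.1
U = kron_sum_solve(A, B, damping * np.eye(na), np.eye(nb), V)
assert np.allclose((np.kron(A, B) + damping * np.eye(na * nb)) @ vec(U), vec(V))

# A difference of Kronecker products, as in Sigma_{i|i+1}.
c, e = rng.standard_normal((na, na)), rng.standard_normal((nb, nb))
C, D = -0.2 * (c @ c.T) / na, 0.2 * (e @ e.T) / nb
U = kron_sum_solve(A, B, C, D, V)
assert np.allclose((np.kron(A, B) + np.kron(C, D)) @ vec(U), vec(V))
```

The eigendecompositions involve only the small factor matrices, so they can be computed once per curvature update and reused for every subsequent product.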

If we try to apply this same Tikhonov technique to our more sophisticated approximation $\hat{F}$ of $\tilde{F}$ (from Section 4.3) by adding $(\lambda + \eta)I$ to each of the diagonal blocks of $\hat{F}$, it is no longer clear how to efficiently invert $\hat{F}$. Instead, a solution which we have found works very well in practice (and which we also use for the block-diagonal approximation $\check{F}$) is to add $\pi_i(\sqrt{\lambda + \eta})I$ and $\frac{1}{\pi_i}(\sqrt{\lambda + \eta})I$, for a scalar constant $\pi_i$, to the individual Kronecker factors $\bar{A}_{i-1,i-1}$ and $G_{i,i}$ (resp.) of each diagonal block, giving

$$\left( \bar{A}_{i-1,i-1} + \pi_i(\sqrt{\lambda + \eta})I \right) \otimes \left( G_{i,i} + \frac{1}{\pi_i}(\sqrt{\lambda + \eta})I \right) \quad (7)$$

As this is a single Kronecker product, all of the computations described in Sections 4.2 and 4.3 can still be used here too, simply by replacing each $\bar{A}_{i-1,i-1}$ and $G_{i,i}$ with their modified versions $\bar{A}_{i-1,i-1} + \pi_i(\sqrt{\lambda + \eta})I$ and $G_{i,i} + \frac{1}{\pi_i}(\sqrt{\lambda + \eta})I$.

To see why the expression in eqn. 7 is a reasonable approximation to eqn. 6, note that expanding it gives

$$\bar{A}_{i-1,i-1} \otimes G_{i,i} + \pi_i(\sqrt{\lambda + \eta})I \otimes G_{i,i} + \frac{1}{\pi_i}(\sqrt{\lambda + \eta})\bar{A}_{i-1,i-1} \otimes I + (\lambda + \eta)I \otimes I$$

which differs from eqn. 6 by the residual error expression

$$\pi_i(\sqrt{\lambda + \eta})I \otimes G_{i,i} + \frac{1}{\pi_i}(\sqrt{\lambda + \eta})\bar{A}_{i-1,i-1} \otimes I$$
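The factored version only requires damping and factoring two small matrices per layer. The sketch below, with random positive semi-definite stand-ins for $\bar{A}_{i-1,i-1}$ and $G_{i,i}$ and an assumed $\pi_i = 1.5$, applies the inverse of eqn. 7 through $U_i = (G_{i,i} + \frac{1}{\pi_i}\sqrt{\lambda + \eta}\,I)^{-1} V_i (\bar{A}_{i-1,i-1} + \pi_i\sqrt{\lambda + \eta}\,I)^{-1}$, and then checks numerically that eqn. 7 minus eqn. 6 equals the residual expression just derived.

```python
import numpy as np

def vec(x):
    """Column-stacking vectorization, so (A ⊗ B) vec(X) = vec(B X A^T)."""
    return x.reshape(-1, order="F")

def damped_factors(A, G, damping, pi):
    """Factored Tikhonov damping of eqn. 7 (damping stands for lambda + eta)."""
    r = np.sqrt(damping)
    return A + pi * r * np.eye(A.shape[0]), G + (r / pi) * np.eye(G.shape[0])

def kron_solve(A_d, G_d, V):
    """U with vec(U) = (A_d ⊗ G_d)^{-1} vec(V), i.e. U = G_d^{-1} V A_d^{-1} (A_d symmetric)."""
    return np.linalg.solve(A_d, np.linalg.solve(G_d, V).T).T

rng = np.random.default_rng(7)
a, g = rng.standard_normal((4, 4)), rng.standard_normal((3, 3))
A, G = a @ a.T, g @ g.T          # PSD stand-ins for Abar_{i-1,i-1} and G_{i,i}
V = rng.standard_normal((3, 4))
damping, pi = 0.1, 1.5           # assumed lambda + eta and pi_i

A_d, G_d = damped_factors(A, G, damping, pi)
U = kron_solve(A_d, G_d, V)
assert np.allclose(np.kron(A_d, G_d) @ vec(U), vec(V))

# Expanding eqn. 7 and subtracting eqn. 6 leaves exactly the residual expression.
r = np.sqrt(damping)
eqn7 = np.kron(A_d, G_d)
eqn6 = np.kron(A, G) + damping * np.eye(12)
residual = pi * r * np.kron(np.eye(4), G) + (r / pi) * np.kron(A, np.eye(3))
assert np.allclose(eqn7 - eqn6, residual)
```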

While the choice of $\pi_i = 1$ is simple and can sometimes work well in practice, a slightly more principled choice can be found by minimizing the obvious upper bound (following from the triangle inequality) on the matrix norm of this residual expression, for some matrix norm $|\cdot|_v$ .This gives

$$\pi_i = \sqrt{\frac{\|\bar{A}_{i-1,i-1} \otimes I\|_v}{\|I \otimes G_{i,i}\|_v}}$$

Evaluating this expression can be done efficiently for various common choices of the matrix norm $\|\cdot\|_v$. For example, for a general $B$ we have $\|I \otimes B\|_F = \|B \otimes I\|_F = \sqrt{d}\|B\|_F$ where $d$ is the height/dimension of $I$, and also $\|I \otimes B\|_2 = \|B \otimes I\|_2 = \|B\|_2$.
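These identities mean $\pi_i$ never requires forming a Kronecker product. A minimal NumPy sketch (function names are ours, for illustration; $d_A$ and $d_G$ denote the dimensions of $\bar{A}$ and $G$) uses $\|\bar{A} \otimes I\|_F = \sqrt{d_G}\,\|\bar{A}\|_F$ and $\|I \otimes G\|_F = \sqrt{d_A}\,\|G\|_F$ for the Frobenius norm, and the spectral-norm identity directly, then checks both against a dense evaluation of the formula for $\pi_i$.

```python
import numpy as np

def pi_frobenius(A, G):
    """pi_i under the Frobenius norm, using only the small factors."""
    d_A, d_G = A.shape[0], G.shape[0]
    num = np.sqrt(d_G) * np.linalg.norm(A, "fro")   # ||A (x) I_G||_F
    den = np.sqrt(d_A) * np.linalg.norm(G, "fro")   # ||I_A (x) G||_F
    return np.sqrt(num / den)

def pi_spectral(A, G):
    """pi_i under the spectral norm: ||A (x) I||_2 = ||A||_2, ||I (x) G||_2 = ||G||_2."""
    return np.sqrt(np.linalg.norm(A, 2) / np.linalg.norm(G, 2))

def pi_dense(A, G, norm_ord):
    """Reference value computed from the full Kronecker products."""
    I_A, I_G = np.eye(A.shape[0]), np.eye(G.shape[0])
    num = np.linalg.norm(np.kron(A, I_G), norm_ord)
    den = np.linalg.norm(np.kron(I_A, G), norm_ord)
    return np.sqrt(num / den)

rng = np.random.default_rng(1)
X, Y = rng.standard_normal((3, 6)), rng.standard_normal((4, 8))
A, G = X @ X.T, Y @ Y.T
print(np.isclose(pi_frobenius(A, G), pi_dense(A, G, "fro")))  # True
print(np.isclose(pi_spectral(A, G), pi_dense(A, G, 2)))       # True
```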

In our experience, one of the best and most robust choices for the norm $\|\cdot\|_v$ is the trace-norm, which for PSD matrices is given by the trace. With this choice, the formula for $\pi_i$ has the following simple form:

$$\pi_i = \sqrt{\frac{\text{tr}(\bar{A}_{i-1,i-1})/(d_{i-1} + 1)}{\text{tr}(G_{i,i})/d_i}}$$

where $d_i$ is the dimension (number of units) in layer $i$. Intuitively, the inner fraction is just the average eigenvalue of $\bar{A}_{i-1,i-1}$ divided by the average eigenvalue of $G_{i,i}$.

Interestingly, we have found that this factored approximate Tikhonov approach, which was originally motivated by computational concerns, often works better than the exact version (eqn. 6) in practice. The reasons for this are still somewhat mysterious to us, but it may have to do with the fact that the inverse of the product of two quantities is often most robustly estimated as the inverse of the product of their individually regularized estimates.
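To make the factored scheme concrete, here is a minimal NumPy sketch of our own. It assumes symmetric factors, a column-stacking $\mathrm{vec}$, and a layer gradient matrix $V$ of shape $d_i \times (d_{i-1}+1)$; the names `pi_trace` and `damped_kfac_solve` are illustrative, and `gamma` stands for $\sqrt{\lambda + \eta}$. It computes $\pi_i$ with the trace-norm formula and applies the inverse of eqn. 7 via $(\bar{A} \otimes G)^{-1}\mathrm{vec}(V) = \mathrm{vec}(G^{-1} V \bar{A}^{-1})$, so only the two small damped factors are ever solved against.

```python
import numpy as np

def pi_trace(A_bar, G):
    """Trace-norm pi_i: sqrt(avg eigenvalue of A_bar / avg eigenvalue of G).

    A_bar has dimension d_{i-1} + 1 and G has dimension d_i.
    """
    return np.sqrt((np.trace(A_bar) / A_bar.shape[0]) / (np.trace(G) / G.shape[0]))

def damped_kfac_solve(A_bar, G, V, gamma):
    """Return U with vec(U) = [(A_bar + pi*gamma*I) (x) (G + gamma/pi*I)]^{-1} vec(V).

    V has shape (d_i, d_{i-1} + 1); gamma plays the role of sqrt(lambda + eta).
    """
    pi = pi_trace(A_bar, G)
    A_damped = A_bar + pi * gamma * np.eye(A_bar.shape[0])
    G_damped = G + (gamma / pi) * np.eye(G.shape[0])
    # U = G_damped^{-1} V A_damped^{-1}; A_damped is symmetric, so V A^{-1} = (A^{-1} V^T)^T.
    return np.linalg.solve(G_damped, np.linalg.solve(A_damped, V.T).T)

# Usage: compare against a dense solve with the full Kronecker product.
rng = np.random.default_rng(0)
X, Y = rng.standard_normal((5, 50)), rng.standard_normal((3, 50))
A_bar, G = X @ X.T / 50, Y @ Y.T / 50     # d_{i-1} + 1 = 5, d_i = 3
V = rng.standard_normal((3, 5))           # stand-in for a layer's gradient matrix
gamma = 0.1
U = damped_kfac_solve(A_bar, G, V, gamma)

pi = pi_trace(A_bar, G)
K = np.kron(A_bar + pi * gamma * np.eye(5), G + (gamma / pi) * np.eye(3))
U_dense = np.linalg.solve(K, V.flatten(order="F")).reshape(3, 5, order="F")
print(np.allclose(U, U_dense))  # True
```

The dense comparison is only there as a check; in practice the point of eqn. 7 is that the $d_i^2 (d_{i-1}+1)^2$-sized matrix `K` is never formed.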

6.4 Re-scaling according to the exact $F$

Given an update proposal $\Delta$ produced by multiplying the negative gradient $-\nabla h$ by our approximate Fisher inverse (subject to the Tikhonov technique described in the previous subsection), the second stage of our proposed damping scheme re-scales $\Delta$ according to the quadratic model $M$ as computed with the exact $F$ , to produce a final update $\delta = \alpha\Delta$ .

More precisely, we optimize $\alpha$ according to the value of the quadratic model

$$M(\delta) = M(\alpha\Delta) = \frac{\alpha^2}{2}\Delta^\top(F + (\lambda + \eta)I)\Delta + \alpha\nabla h^\top\Delta + h(\theta)$$

as computed using an estimate of the exact Fisher $F$ (to which we add the $\ell_2$ regularization + Tikhonov term $(\lambda + \eta)I$ ). Because this is a 1-dimensional quadratic minimization problem, the formula for the optimal $\alpha$ can be computed very efficiently as

$$\alpha^* = \frac{-\nabla h^\top\Delta}{\Delta^\top(F + (\lambda + \eta)I)\Delta} = \frac{-\nabla h^\top\Delta}{\Delta^\top F\Delta + (\lambda + \eta)\|\Delta\|_2^2}$$

To evaluate this formula we use the current stochastic gradient $\nabla h$ (i.e. the same one used to produce $\Delta$), and compute matrix-vector products with $F$ using the input data from the same mini-batch. While using a mini-batch to compute $F$ gets away from the idea of basing our estimate of the curvature on a long history of data (as we do with our approximate Fisher), it is made slightly less objectionable by the fact that we are only using it to estimate a single scalar quantity ($\Delta^\top F \Delta$). This is to be contrasted with methods like HF which perform a long and careful optimization of $M(\delta)$ using such an estimate of $F$.
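The re-scaling step amounts to a few inner products. The sketch below is a hypothetical NumPy illustration: `F` is a small dense PSD stand-in for the mini-batch Fisher (K-FAC itself only needs the product $F\Delta$), `damping` is $\lambda + \eta$, and the proposal $\Delta$ is produced by an arbitrary diagonal preconditioner rather than the approximate Fisher. It returns $\alpha^*$ together with the predicted change $M(\alpha^*\Delta) - M(0)$.

```python
import numpy as np

def rescale_update(grad, Delta, F, damping):
    """Optimal alpha for delta = alpha * Delta under the quadratic model M.

    `damping` is lambda + eta. F is a dense matrix here purely for illustration;
    only the product F @ Delta is actually needed.
    Returns (alpha, M(alpha * Delta) - M(0)).
    """
    F_Delta = F @ Delta
    curvature = Delta @ F_Delta + damping * (Delta @ Delta)  # Delta^T (F + (lambda+eta) I) Delta
    slope = grad @ Delta                                     # grad h^T Delta
    alpha = -slope / curvature
    model_change = 0.5 * alpha**2 * curvature + alpha * slope
    return alpha, model_change

rng = np.random.default_rng(0)
J = rng.standard_normal((20, 6))
F = J.T @ J / 20                                   # PSD stand-in for the exact Fisher
grad = rng.standard_normal(6)
P = np.diag(np.diag(F)) + 0.1 * np.eye(6)          # arbitrary stand-in preconditioner
Delta = -np.linalg.solve(P, grad)                  # update proposal
alpha, model_change = rescale_update(grad, Delta, F, damping=0.01)
```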

Because the matrix-vector products with $F$ are only used to compute scalar quantities in K-FAC, we can reduce their computational cost by roughly one half (versus standard matrix-vector products with $F$ ) using a simple trick which is discussed in Appendix C.

Intuitively, this second stage of our damping scheme effectively compensates for the intrinsic inaccuracy of the approximate quadratic model (based on our approximate Fisher) used to generate the initial update proposal $\Delta$ , by essentially falling back on a more accurate quadratic model based on the exact Fisher.

Interestingly, by re-scaling $\Delta$ according to $M(\delta)$, K-FAC can be viewed as a version of HF which uses our approximate Fisher as a preconditioning matrix (instead of the traditional diagonal preconditioner), and runs CG for only 1 step, initializing it from 0. This observation suggests running CG for longer, thus obtaining an algorithm which is even closer to HF (although using a much better preconditioner for CG). Indeed, this approach works reasonably well in our experience, but suffers from some of the same problems that HF has in the stochastic setting, due to its much heavier reliance on the mini-batch-estimated exact $F$.

Figure 7 demonstrates the effectiveness of this re-scaling technique versus the simpler method of just using the raw $\Delta$ as an update proposal. We can see that $\Delta$ , without being re-scaled, is a very poor update to $\theta$ , and won’t even give any improvement in the objective function unless the strength of the factored Tikhonov damping terms is made very large. On the other hand, when the update is re-scaled, we can afford to compute $\Delta$ using a much smaller strength for the factored Tikhonov damping terms, and overall this yields a much larger and more effective update to $\theta$ .

6.5 Adapting $\lambda$

Tikhonov damping can be interpreted as implementing a trust-region constraint on the update $\delta$, so that in particular the constraint $\|\delta\| \leq r$ is imposed for some $r$, where $r$ depends on $\lambda$ and the curvature matrix (e.g. Nocedal and Wright, 2006). While some approaches adjust $r$ and then seek to find the matching $\lambda$, it is often simpler just to adjust $\lambda$ directly, as the precise relationship between $\lambda$ and $r$ is complicated, and the curvature matrix is constantly evolving as optimization takes place.

The theoretically well-founded Levenberg-Marquardt style rule used by HF for doing this, which we will adopt for K-FAC, is given by

$$\text{if } \rho > 3/4 \text{ then } \lambda \leftarrow \omega_1 \lambda$$

$$\text{if } \rho < 1/4 \text{ then } \lambda \leftarrow \frac{1}{\omega_1} \lambda$$

where $\rho \equiv \frac{h(\theta + \delta) - h(\theta)}{M(\delta) - M(0)}$ is the “reduction ratio” and $0 < \omega_1 < 1$ is some decay constant, and all quantities are computed on the current mini-batch (and $M$ uses the exact $F$).

Figure 7: A comparison of the effectiveness of the proposed damping scheme, with and without the re-scaling techniques described in Section 6.4. The network used for this comparison is the one produced at iteration 500 by K-FAC (with the block-tridiagonal inverse approximation) on the MNIST autoencoder problem described in Section 13. The y-axis is the improvement in the objective function $h$ (i.e. $h(\theta) - h(\theta + \delta)$) produced by the update $\delta$, while the x-axis is the strength constant used in the factored Tikhonov damping technique (which is denoted by “$\gamma$” as described in Section 6.6). In the legend, “no moment.” indicates that the momentum technique developed for K-FAC in Section 7 (which relies on the use of re-scaling) was not used.

Intuitively, this rule tries to make $\lambda$ as small as possible (and hence the implicit trust-region as large as possible) while maintaining the property that the quadratic model $M(\delta)$ remains a good local approximation to $h$ (in the sense that it accurately predicts the value of $h(\theta + \delta)$ for the $\delta$ which gets chosen at each iteration). It has the desirable property that as the optimization enters the final convergence stage where $M$ becomes an almost exact approximation in a sufficiently large neighborhood of the local minimum, the value of $\lambda$ will go rapidly enough towards 0 that it doesn’t interfere with the asymptotic local convergence theory enjoyed by 2nd-order methods (Moré, 1978).

In our experiments we applied this rule every $T_1$ iterations of K-FAC, with $\omega_1 = (19/20)^{T_1}$ and $T_1 = 5$ , from a starting value of $\lambda = 150$ . Note that the optimal value of $\omega_1$ and the starting value of $\lambda$ may be application dependent, and setting them inappropriately could significantly slow down K-FAC in practice.
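The rule and the constants above fit in a few lines of Python. This is our own sketch; the function names are hypothetical, and the predicted change $M(\delta) - M(0)$ is simply passed in, as it would be produced when solving for $\alpha$.

```python
def reduction_ratio(h_old, h_new, model_change):
    """rho = (h(theta + delta) - h(theta)) / (M(delta) - M(0))."""
    return (h_new - h_old) / model_change

def adjust_lambda(lam, rho, T1=5):
    """Levenberg-Marquardt style update of lambda, applied once every T1 iterations."""
    omega1 = (19 / 20) ** T1   # decay constant used in the experiments
    if rho > 3 / 4:            # model predicts well: allow a larger trust region
        return omega1 * lam
    if rho < 1 / 4:            # model over-promised: shrink the trust region
        return lam / omega1
    return lam

lam = 150.0                                           # starting value used in the experiments
rho = reduction_ratio(h_old=2.0, h_new=1.9, model_change=-0.11)
lam = adjust_lambda(lam, rho)                         # rho is about 0.91 > 3/4, so lambda shrinks
```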

Computing $\rho$ can be done quite efficiently. Note that for the optimal $\delta$, $M(\delta) - M(0) = \frac{1}{2} \nabla h^\top \delta$, and $h(\theta)$ is available from the usual forward pass. The only remaining quantity which is needed to evaluate $\rho$ is thus $h(\theta + \delta)$, which will require an additional forward pass. But fortunately, we only need to perform this once every $T_1$ iterations.

6.6 Maintaining a separate damping strength for the approximate Fisher

While the scheme described in the previous sections works reasonably well in most situations, we have found that in order to avoid certain failure cases and to be truly robust in a large variety of situations, the Tikhonov damping strength parameter for the factored Tikhonov technique described in Section 6.3 should be maintained and adjusted independently of $\lambda$ . To this end we replace the expression $\sqrt{\lambda + \eta}$ in Section 6.3 with a separate constant $\gamma$ , which we initialize to $\sqrt{\lambda + \eta}$ but which is then adjusted using a different rule, which is described at the end of this section.

The reasoning behind this modification is as follows. The role of $\lambda$, according to the Levenberg-Marquardt theory (Moré, 1978), is to be as small as possible while maintaining the property that the quadratic model $M$ remains a trustworthy approximation of the true objective. Meanwhile, $\gamma$'s role is to ensure that the initial update proposal $\Delta$ is as good an approximation as possible to the true optimum of $M$ (as computed using a mini-batch estimate of the exact $F$), so that in particular the re-scaling performed in Section 6.4 is as benign as possible. While one might hope that adding the same multiple of the identity to our approximate Fisher as we do to the exact $F$ (as it appears in $M$) would produce the best $\Delta$ in this regard, this isn't obviously the case. In particular, using a larger multiple may help compensate for the approximation we are making to the Fisher when computing $\Delta$, and thus help produce a more “conservative” but ultimately more useful initial update proposal $\Delta$, which is what we observe happens in practice.

A simple measure of the quality of our choice of $\gamma$ is the (negative) value of the quadratic model $M(\delta) = M(\alpha\Delta)$ for the optimally chosen $\alpha$ . To adjust $\gamma$ based on this measure (or others like it) we use a simple greedy adjustment rule. In particular, every $T_2$ iterations during the optimization we try 3 different values of $\gamma$ ( $\gamma_0$ , $\omega_2\gamma_0$ , and $(1/\omega_2)\gamma_0$ , where $\gamma_0$ is the current value) and choose the new $\gamma$ to be the best of these, as measured by our quality metric. In our experiments we used $T_2 = 20$ (which must be a multiple of the constant $T_3$ as defined in Section 8), and $\omega_2 = (\sqrt{19/20})^{T_2}$ .

We have found that $M(\delta)$ works well in practice as a measure of the quality of $\gamma$ , and has the added bonus that it can be computed at essentially no additional cost from the incidental quantities already computed when solving for the optimal $\alpha$ . In our initial experiments we found that using it gave similar results to those obtained by using other obvious measures for the quality of $\gamma$ , such as $h(\theta + \delta)$ .
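The greedy $\gamma$ rule can be sketched as follows (our own illustration). `quality` is a hypothetical callback that would rebuild $\Delta$ with damping strength $\gamma$, re-scale it, and return $M(\delta)$; lower values are better. The toy quadratic used below merely stands in for it.

```python
import math

def adjust_gamma(gamma0, quality, T2=20):
    """Greedy update of gamma, applied once every T2 iterations.

    `quality(gamma)` is a hypothetical callback returning M(delta) for the
    re-scaled update built with damping strength gamma; lower is better.
    """
    omega2 = math.sqrt(19 / 20) ** T2
    candidates = (gamma0, omega2 * gamma0, gamma0 / omega2)
    return min(candidates, key=quality)

# Toy quality function whose minimum lies below the current gamma.
new_gamma = adjust_gamma(1.0, quality=lambda g: (g - 0.5) ** 2)
```

Since each candidate needs its own $\Delta$, this costs two extra approximate-Fisher solves every $T_2$ iterations, which is one reason for applying it only periodically.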

7 Momentum

Sutskever et al. (2013) found that momentum (Polyak, 1964; Plaut et al., 1986) was very helpful in the context of stochastic gradient descent optimization of deep neural networks. A version of momentum is also present in the original HF method, and it plays an arguably even more important role in more “stochastic” versions of HF (Martens and Sutskever, 2012; Kiros, 2013).

A natural way of adding momentum to K-FAC, and one which we have found works well in practice, is to take the update to be $\delta = \alpha\Delta + \mu\delta_0$, where $\delta_0$ is the final update computed at the previous iteration, and where $\alpha$ and $\mu$ are chosen to minimize $M(\delta)$. This allows K-FAC to effectively build up a better solution to the local quadratic optimization problem $\min_{\delta} M(\delta)$ (where $M$ uses the exact $F$) over many iterations, somewhat similarly to how Matrix Momentum (Scarpetta et al., 1999) and HF do this (see Sutskever et al., 2013).

The optimal solution for $\alpha$ and $\mu$ can be computed as

$$\begin{bmatrix} \alpha^* \\ \mu^* \end{bmatrix} = - \begin{bmatrix} \Delta^\top F \Delta + (\lambda + \eta) \|\Delta\|_2^2 & \Delta^\top F \delta_0 + (\lambda + \eta) \Delta^\top \delta_0 \\ \Delta^\top F \delta_0 + (\lambda + \eta) \Delta^\top \delta_0 & \delta_0^\top F \delta_0 + (\lambda + \eta) \|\delta_0\|_2^2 \end{bmatrix}^{-1} \begin{bmatrix} \nabla h^\top \Delta \\ \nabla h^\top \delta_0 \end{bmatrix}$$

The main cost in evaluating this formula is computing the two matrix-vector products $F\Delta$ and $F\delta_0$ . Fortunately, the technique discussed in Appendix C can be applied here to compute the 4 required scalars at the cost of only two forwards passes (equivalent to the cost of only one matrix-vector product with $F$ ).
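Once the four scalars are available, the $2 \times 2$ solve is trivial. The NumPy sketch below is a hypothetical illustration: it forms `F @ Delta` and `F @ delta0` from a small dense PSD stand-in for $F$ (K-FAC would instead obtain the needed scalars via Appendix C), with `damping` equal to $\lambda + \eta$ and random vectors in place of $\nabla h$, $\Delta$ and $\delta_0$.

```python
import numpy as np

def momentum_coefficients(grad, Delta, delta0, F, damping):
    """Jointly optimal (alpha, mu) for delta = alpha*Delta + mu*delta0 under M.

    `damping` is lambda + eta; F is dense here only for illustration.
    """
    F_Delta, F_delta0 = F @ Delta, F @ delta0
    off = Delta @ F_delta0 + damping * (Delta @ delta0)
    Q = np.array([
        [Delta @ F_Delta + damping * (Delta @ Delta), off],
        [off, delta0 @ F_delta0 + damping * (delta0 @ delta0)],
    ])
    b = np.array([grad @ Delta, grad @ delta0])
    alpha, mu = -np.linalg.solve(Q, b)
    return alpha, mu

rng = np.random.default_rng(0)
J = rng.standard_normal((30, 8))
F = J.T @ J / 30                                  # PSD stand-in for the exact Fisher
grad, Delta, delta0 = rng.standard_normal((3, 8))
alpha, mu = momentum_coefficients(grad, Delta, delta0, F, damping=0.05)
delta = alpha * Delta + mu * delta0
```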

Empirically we have found that this type of momentum provides substantial acceleration in regimes where the gradient signal has a low noise to signal ratio, which is usually the case in the early to mid stages of stochastic optimization, but can also be the case in later stages if the mini-batch size is made sufficiently large. These findings are consistent with predictions made by convex optimization theory, and with older empirical work done on neural network optimization (LeCun et al., 1998).

Notably, because the implicit “momentum decay constant” $\mu$ in our method is being computed on the fly, one doesn’t have to worry about setting schedules for it, or adjusting it via heuristics, as one often does in the context of SGD.

Interestingly, if $h$ is a quadratic function (so the definition of $M(\delta)$ remains fixed at each iteration) and all quantities are computed deterministically (i.e. without noise), then using this type of momentum makes K-FAC equivalent to performing preconditioned linear CG on $M(\delta)$ , with the preconditioner given by our approximate Fisher. This follows from the fact that linear CG can be interpreted as a momentum method where the learning rate $\alpha$ and momentum decay coefficient $\mu$ are chosen to jointly minimize $M(\delta)$ at the current iteration.
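This equivalence can be observed numerically. In the sketch below (our own illustration), $h$ is an exactly known quadratic, there is no damping and no noise, and a fixed diagonal matrix `P` stands in for the approximate Fisher as the preconditioner. Each step takes $\Delta = -P^{-1}\nabla h$ and chooses $(\alpha, \mu)$ to minimize $h(x + \delta)$ exactly. Because preconditioned CG reaches the minimizer of an $n$-dimensional quadratic in at most $n$ steps (in exact arithmetic), the iterate after $n$ steps should match $-H^{-1}b$ up to rounding error.

```python
import numpy as np

def momentum_kfac_on_quadratic(H, b, P, x0, iters):
    """Minimize h(x) = 0.5 x^T H x + b^T x with updates delta = alpha*Delta + mu*delta0.

    Delta = -P^{-1} grad plays the role of the approximate-Fisher proposal and
    (alpha, mu) exactly minimize h(x + delta). There is no damping and no noise.
    """
    x, delta0 = x0.copy(), np.zeros_like(x0)
    for _ in range(iters):
        grad = H @ x + b
        Delta = -np.linalg.solve(P, grad)
        # On the first iteration there is no previous update, so only alpha is used.
        D = Delta[:, None] if not delta0.any() else np.stack([Delta, delta0], axis=1)
        coeffs = -np.linalg.solve(D.T @ H @ D, D.T @ grad)
        delta0 = D @ coeffs
        x = x + delta0
    return x

rng = np.random.default_rng(0)
n = 6
X = rng.standard_normal((n, n))
H = X @ X.T + n * np.eye(n)          # SPD quadratic
P = np.diag(np.diag(H))              # simple SPD preconditioner (stand-in)
b = rng.standard_normal(n)
x = momentum_kfac_on_quadratic(H, b, P, np.zeros(n), iters=n)
```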

8 Computational Costs and Efficiency Improvements

Let $d$ be the typical number of units in each layer and $m$ the mini-batch size. The significant computational tasks required to compute a single update/iteration of K-FAC, and rough estimates of their associated computational costs, are as follows:

1. standard forwards and backwards pass: $2C_1\ell d^2m$
2. computation of the gradient $\nabla h$ on the current mini-batch using quantities computed in backwards pass: $C_2\ell d^2m$
3. additional backwards pass with random targets (as described in Section 5): $C_1\ell d^2m$
4. updating the estimates of the required $\bar{A}_{i,j}$'s and $G_{i,j}$'s from quantities computed in the forwards pass and the additional randomized backwards pass: $2C_2\ell d^2m$
5. matrix inverses (or SVDs for the block-tridiagonal inverse, as described in Appendix B) required to compute the inverse of the approximate Fisher: $C_3\ell d^3$ for the block-diagonal inverse, $C_4\ell d^3$ for the block-tridiagonal inverse
6. various matrix-matrix products required to compute the matrix-vector product of the approximate inverse with the stochastic gradient: $C_5\ell d^3$ for the block-diagonal inverse, $C_6\ell d^3$ for the block-tridiagonal inverse
7. matrix-vector products with the exact $F$ on the current mini-batch using the approach in Appendix C: $4C_1\ell d^2m$ with momentum, $2C_1\ell d^2m$ without momentum
8. additional forward pass required to evaluate the reduction ratio $\rho$ needed to apply the $\lambda$ adjustment rule described in Section 6.5: $C_1\ell d^2m$ every $T_1$ iterations

Here the $C_i$ are various constants that account for implementation details, and we are assuming the use of the naive cubic matrix-matrix multiplication and inversion algorithms when producing the cost estimates. Note that it is hard to assign precise values to the constants, as they very much depend on how these various tasks are implemented.

Note that most of the computations required for these tasks will be sped up greatly by performing them in parallel across units, layers, training cases, or all of these. The above cost estimates however measure sequential operations, and thus may not accurately reflect the true computation times enjoyed by a parallel implementation. In our experiments we used a vectorized implementation that performed the computations in parallel over units and training cases, although not over layers (which is possible for computations that don't involve a sequential forwards or backwards "pass" over the layers).

Tasks 1 and 2 represent the standard stochastic gradient computation.

The costs of tasks 3 and 4 are similar to, and slightly smaller than, those of tasks 1 and 2. One way to significantly reduce them is to use a random subset of the current mini-batch of size $\tau_1 m$ to update the estimates of the required $\bar{A}_{i,j}$'s and $G_{i,j}$'s. One can similarly reduce the cost of task 7 by computing the (factored) matrix-vector product with $F$ using such a subset of size $\tau_2 m$, although we recommend proceeding with caution when doing this, as using inconsistent sets of data for the quadratic and linear terms in $M(\delta)$ can hypothetically cause instability problems which are avoided by using consistent data (see Martens and Sutskever (2012), Section 13.1). In our experiments in Section 13 we used $\tau_1 = 1/8$ and $\tau_2 = 1/4$, which seemed to have a negligible effect on the quality of the resultant updates, while significantly reducing per-iteration computation time. In a separate set of unreported experiments we found that in certain situations, such as when $\ell_2$ regularization isn't used and the network starts heavily overfitting the data, or when smaller mini-batches were used, we had to revert to using $\tau_2 = 1$ to prevent significant deterioration in the quality of the updates.

The cost of task 8 can be made relatively insignificant by making the adjustment period $T_1$ for $\lambda$ large enough. We used $T_1 = 5$ in our experiments.

The costs of tasks 5 and 6 are hard to compare directly with the costs associated with computing the gradient, as their relative sizes will depend on factors such as the architecture of the neural network being trained, as well as the particulars of the implementation. However, one quick observation we can make is that both tasks 5 and 6 involve computations that can be performed in parallel across the different layers, in contrast to many of the other tasks, which require sequential passes over the layers of the network.

Clearly, if $m \gg d$, then the cost of tasks 5 and 6 becomes negligible in comparison to the others. However, it is more often the case that $m$ is comparable to or perhaps smaller than $d$. Moreover, while algorithms for inverses and SVDs tend to have the same asymptotic cost as matrix-matrix multiplication, they are at least several times more expensive in practice, in addition to being harder to parallelize on modern GPU architectures (indeed, CPU implementations are often faster in our experience). Thus, $C_3$ and $C_4$ will typically be (much) larger than $C_5$ and $C_6$, and so in a basic/naive implementation of K-FAC, task 5 can dominate the overall per-iteration cost.

Fortunately, there are several possible ways to mitigate the cost of task 5. As mentioned above, one way is to perform the computations for each layer in parallel, and even simultaneously with the gradient computation and other tasks. In the case of our block-tridiagonal approximation to the inverse, one can avoid computing any SVDs or matrix square roots by using an iterative Stein-equation solver (see Appendix B). There are also ways of reducing matrix inversion (and even the matrix square root) to a short sequence of matrix-matrix multiplications using iterative methods (Pan and Schreiber, 1991). Furthermore, because the matrices in question only change slowly over time, one can consider hot-starting these iterative inversion methods from previous solutions. In the extreme case where $d$ is very large, one can also consider using low-rank + diagonal approximations of the $\bar{A}_{i,j}$ and $G_{i,j}$ matrices maintained online (e.g. using a similar strategy as Le Roux et al. (2008)), from which inverses and/or SVDs can be more easily computed. However, in our experience such approximations can, in some cases, lead to a substantial degradation in the quality of the updates.

While these ideas work reasonably well in practice, perhaps the simplest method, and the one we ended up settling on for our experiments, is to simply recompute the approximate Fisher inverse only every $T_3$ iterations (we used $T_3 = 20$ in our experiments). As it turns out, the curvature of the objective stays relatively stable during optimization, especially in the later stages, and so in our experience this strategy results in only a modest decrease in the quality of the updates.
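The effect of amortizing tasks 5 and 8 can be seen by turning the task list into a small cost model. This is a sketch under stated assumptions: it covers only the block-diagonal inverse, and because the true constants $C_i$ are implementation dependent, the default of setting them all to 1 is a placeholder, not a measurement. The example sizes are likewise arbitrary.

```python
def kfac_cost_per_iteration(ell, d, m, T1=5, T3=20, momentum=True, C=None):
    """Rough sequential cost of each task in one K-FAC iteration (block-diagonal inverse).

    C maps i -> C_i. The default of all ones is a placeholder, not a measurement.
    Task 5 is amortized over T3 iterations and task 8 over T1 iterations.
    """
    C = C or {i: 1.0 for i in range(1, 7)}
    return {
        1: 2 * C[1] * ell * d**2 * m,                       # forwards + backwards pass
        2: C[2] * ell * d**2 * m,                           # gradient
        3: C[1] * ell * d**2 * m,                           # extra randomized backwards pass
        4: 2 * C[2] * ell * d**2 * m,                       # updating the A_bar and G estimates
        5: C[3] * ell * d**3 / T3,                          # factor inverses, every T3 iterations
        6: C[5] * ell * d**3,                               # approximate inverse times gradient
        7: (4 if momentum else 2) * C[1] * ell * d**2 * m,  # products with the exact F
        8: C[1] * ell * d**2 * m / T1,                      # extra forward pass for rho
    }

costs = kfac_cost_per_iteration(ell=8, d=1000, m=500)
```

With $m < d$ as in this example, task 6 is already as large as task 1, which anticipates the discussion that follows.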

If $m$ is much smaller than $d$ , the costs associated with task 6 can begin to dominate (provided $T_3$ is sufficiently large so that the cost of task 5 is relatively small). And unlike task 5, task 6 must be performed at every iteration. While the simplest solution is to increase $m$ (while reaping the benefits of a less noisy gradient), in the case of the block-diagonal inverse it turns out that we can change the cost of task 6 from $C_5 \ell d^3$ to $C_5 \ell d^2 m$ by taking advantage of the low-rank structure of the stochastic gradient. The method for doing this is described below.
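Where the saving comes from can be illustrated under the assumption that a layer's mini-batch gradient has the form $\frac{1}{m}\sum_k g_k a_k^\top$ (one outer product per training case), and so has rank at most $m$. The inverse factors can then be applied to the two thin matrices separately, with every product costing $O(d^2 m)$. The NumPy sketch below is our own illustration of this idea (names and shapes are hypothetical), not a claim that it matches the method described in the text in every detail.

```python
import numpy as np

def kfac_layer_update_lowrank(G_inv, A_inv, g, a):
    """Compute G_inv @ (g.T @ a / m) @ A_inv without forming the gradient matrix.

    g: (m, d_out) per-case back-propagated derivatives; a: (m, d_in) per-case inputs;
    G_inv, A_inv: precomputed inverses of the (damped) factors.
    Every product below costs O(d^2 m) rather than O(d^3).
    """
    m = g.shape[0]
    left = G_inv @ g.T   # (d_out, m)
    right = a @ A_inv    # (m, d_in)
    return left @ right / m

rng = np.random.default_rng(0)
m, d_out, d_in = 4, 50, 60
g, a = rng.standard_normal((m, d_out)), rng.standard_normal((m, d_in))
Y, X = rng.standard_normal((d_out, d_out)), rng.standard_normal((d_in, d_in))
G_inv = np.linalg.inv(Y @ Y.T + np.eye(d_out))
A_inv = np.linalg.inv(X @ X.T + np.eye(d_in))
U = kfac_layer_update_lowrank(G_inv, A_inv, g, a)
```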
