1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""The Shampoo Optimizer.
17
18Variant of Adagrad using one preconditioner matrix per variable dimension.
19For details, see https://arxiv.org/abs/1802.09568
20"""
21from __future__ import absolute_import
22from __future__ import division
23from __future__ import print_function
24
25import numpy as np
26from tensorflow.contrib.opt.python.training import matrix_functions
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import ops
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import control_flow_ops
31from tensorflow.python.ops import linalg_ops
32from tensorflow.python.ops import math_ops
33from tensorflow.python.ops import state_ops
34from tensorflow.python.platform import tf_logging
35from tensorflow.python.training import optimizer
36
37
def GetParam(var, timestep):
  """Returns the value of a hyper-parameter at the given timestep.

  Args:
    var: either a plain value, or a callable mapping a timestep to a value.
    timestep: the step at which a callable `var` is evaluated.

  Returns:
    `var(timestep)` when `var` is callable, otherwise `var` unchanged.
  """
  return var(timestep) if callable(var) else var
43
44
45class ShampooOptimizer(optimizer.Optimizer):
46  """The Shampoo Optimizer
47
48  Variant of Adagrad using one preconditioner matrix per variable dimension.
49  For details, see https://arxiv.org/abs/1802.09568
50
51  gbar is time-weighted accumulated gradient:
52  gbar[t] = gbar_decay[t] * gbar[t-1] + gbar_weight[t] * g[t]
53
54  mat_gbar is time-weighted accumulated gradient square:
55  mat_gbar_j[t] = mat_gbar_decay[t] * mat_gbar_j[t-1]
56                  + mat_gbar_weight[t] * gg_j[t]
57  where if g[t] = g_abcd then gg_a[t] = g_abcd g_a'bcd (Einstein notation)
58
59  Update rule:
60  w[t+1] = w[t] - learning_rate[t] * Prod_j mat_gbar_j[t]^(-alpha/n) gbar[t]
     Here, mat_gbar_j[t]^(-alpha/n) gbar[t] is a tensor contraction along the
62     j'th dimension of gbar[t] with the first dimension of
63     mat_gbar_j[t]^(-alpha/n), where alpha is a hyperparameter,
64     and n = rank of the variable.
65     Prod_j represents doing this contraction for all j in 0..n-1.
66
67  Typically learning_rate is constant, but could be time dependent by passing
68  a lambda function that depends on step.
69  """
70
71  def __init__(self,
72               global_step=0,
73               max_matrix_size=768,
74               gbar_decay=0.0,
75               gbar_weight=1.0,
76               mat_gbar_decay=1.0,
77               mat_gbar_weight=1.0,
78               learning_rate=1.0,
79               svd_interval=1,
80               precond_update_interval=1,
81               epsilon=1e-4,
82               alpha=0.5,
83               use_iterative_root=False,
84               use_locking=False,
85               name="Shampoo"):
86    """Default values of the various hyper-parameters.
87
88    gbar_decay, gbar_weight etc. can be a float or a time varying parameter.
89    For time-varying parameters use e.g. "lambda T: T / (T + 1.0)"
90    where the expression in the lambda is a tensorflow expression
91
92    Args:
93      global_step: tensorflow variable indicating the step.
94      max_matrix_size: We do not perform SVD for matrices larger than this.
95      gbar_decay:
96      gbar_weight:  Used to update gbar:
97            gbar[t] = gbar_decay[t] * gbar[t-1] + gbar_weight[t] * g[t]
98      mat_gbar_decay:
99      mat_gbar_weight:  Used to update mat_gbar:
100           mat_gbar_j[t] = mat_gbar_decay[t] * mat_gbar_j[t-1]
101                           + mat_gbar_weight[t] * gg_j[t]
102      learning_rate: Similar to SGD
103      svd_interval: We should do SVD after this many steps. Default = 1, i.e.
104                    every step. Usually 20 leads to no loss of accuracy, and
105                    50 or 100 is also OK. May also want more often early,
106                    and less often later - set in caller as for example:
107                    "svd_interval = lambda(T): tf.cond(
108                        T < 2000, lambda: 20.0, lambda: 1000.0)"
109      precond_update_interval: We should update the preconditioners after
110                               this many steps. Default = 1. Usually less than
111                               svd_interval.
112      epsilon:  epsilon * I_n is added to each mat_gbar_j for stability for
113                non-diagonal version of shampoo.
114      alpha:  total power of the preconditioners.
115      use_iterative_root: should the optimizer use SVD (faster) or the
116                          iterative root method (for TPU) for finding the
117                          roots of PSD matrices.
118      use_locking:
119      name: name of optimizer.
120    """
121
122    super(ShampooOptimizer, self).__init__(use_locking, name)
123
124    self._global_step = math_ops.cast(global_step, dtypes.float32)
125    self._max_matrix_size = max_matrix_size
126    self._gbar_decay = gbar_decay
127    self._gbar_weight = gbar_weight
128    self._mat_gbar_decay = mat_gbar_decay
129    self._mat_gbar_weight = mat_gbar_weight
130    self._learning_rate = learning_rate
131    self._svd_interval = svd_interval
132    self._precond_update_interval = precond_update_interval
133    self._epsilon = epsilon
134    self._alpha = alpha
135    self._use_iterative_root = use_iterative_root
136    self._name = name
137
138  def _create_slots(self, var_list):
139    for v in var_list:
140      with ops.colocate_with(v):
141        _ = self._zeros_slot(v, "gbar", self._name)
142        shape = np.array(v.get_shape())
143        for i, d in enumerate(shape):
144          d_tensor = ops.convert_to_tensor(d)
145          if d <= self._max_matrix_size:
146            mat_g_init = array_ops.zeros_like(linalg_ops.eye(d_tensor))
147            if self._svd_interval > 1:
148              _ = self._get_or_make_slot(v, linalg_ops.eye(d_tensor),
149                                         "H_" + str(i), self._name)
150          else:
151            mat_g_init = array_ops.zeros([d_tensor])
152
153          _ = self._get_or_make_slot(v, mat_g_init, "Gbar_" + str(i),
154                                     self._name)
155
156  def _resource_apply_dense(self, grad, var):
157    return self._apply_dense(grad, var)
158
159  def _apply_dense(self, grad, var):
160    return self._apply_gradient(grad, var)
161
162  def _resource_apply_sparse(self, grad_values, var, grad_indices):
163    return self._apply_sparse_shared(grad_values, grad_indices, var)
164
165  def _apply_sparse(self, grad, var):
166    return self._apply_sparse_shared(grad.values, grad.indices, var)
167
168  def _apply_sparse_shared(self, grad_values, grad_indices, var):
169    if var.get_shape()[0] <= self._max_matrix_size or self._gbar_decay != 0.0:
170      # The dimension is small enough, we can make the variable dense and
171      # do a dense update
172      dense_grad = array_ops.scatter_nd(
173          array_ops.expand_dims(grad_indices, axis=1), grad_values,
174          array_ops.shape(var, out_type=grad_indices.dtype))
175      return self._apply_gradient(dense_grad, var)
176    return self._apply_gradient(grad_values, var, grad_indices)
177
178  def _weighted_average(self, var, weight, weight_t, rest):
179    """Computes exponential weighted average: var = weight_t * var + rest.
180
181    Important to ensure that var does not occur in rest, otherwise
182    we can get race conditions in a distributed setting.
183
184    Args:
185      var: variable to be updated
186      weight: parameter to be checked. If it is a constant, we can optimize.
187      weight_t: current value of parameter, used for weighting
188      rest: the remaining tensor to be added
189
190    Returns:
191      updated variable.
192    """
193    if weight == 0.0:
194      return rest       # no need to update var, we will never use it.
195    if weight == 1.0:   # common case
196      return state_ops.assign_add(var, rest)
197    # The op below can cause race conditions in a distributed setting,
198    # since computing weight_t * var + rest can take some time, during
199    # which var may be set by another worker. To prevent this, it should
200    # be implemented as a C++ op.
201    return var.assign_add((weight_t - 1) * var + rest)
202
203  def _update_mat_g(self, mat_g, grad, axes, mat_gbar_decay,
204                    mat_gbar_weight, i):
205    """Updates the cumulative outer products of the gradients.
206
207    Args:
208      mat_g: the matrix to be updated
209      grad: the gradient of the variable
210      axes: a list of k-1 integers 0 to k-1, except i
211      mat_gbar_decay: constant for weighted average:
212          mat_g = mat_g * decay + grad * weight
213      mat_gbar_weight: constant for weighted average
214      i: index of dimension to be updated.
215
216    Returns:
217      updated mat_g = mat_g * mat_gbar_decay + grad_outer * mat_gbar_weight
218
219    In Einstein notation if i = 0: grad_outer_aa'= g_abcd g_a'bcd
220    thus grad_outer is a matrix d_i x d_i, where d_i is the size of the
221    i'th dimension of g.
222    Alternate view: If mat_i(grad) is the flattening of grad to a
223    d_i x (d_1d_2...d_{i-1}d_{i+1}...d_k) matrix, then
224         grad_outer = mat_i(grad) mat_i(grad).transpose
225    """
226    grad_outer = math_ops.tensordot(grad, grad, axes=(axes, axes),
227                                    name="grad_outer_" + str(i))
228    return self._weighted_average(mat_g, self._mat_gbar_decay, mat_gbar_decay,
229                                  mat_gbar_weight * grad_outer)
230
231  def _compute_power_svd(self, var, mat_g, mat_g_size, alpha, mat_h_slot_name):
232    """Computes mat_h = mat_g^alpha using svd. mat_g is a symmetric PSD matrix.
233
234    Args:
235      var: the variable we are updating.
236      mat_g: the symmetric PSD matrix whose power it to be computed
237      mat_g_size: size of mat_g
238      alpha: a real number
239      mat_h_slot_name: name of slot to store the power, if needed.
240
241    Returns:
242      mat_h = mat_g^alpha
243
244    Stores mat_h in the appropriate slot, if it exists.
245    Note that mat_g is PSD. So we could use linalg_ops.self_adjoint_eig.
246    """
247    if mat_g_size == 1:
248      mat_h = math_ops.pow(mat_g + self._epsilon, alpha)
249    else:
250      damping = self._epsilon * linalg_ops.eye(
251          math_ops.cast(mat_g_size, dtypes.int32))
252      diag_d, mat_u, mat_v = linalg_ops.svd(mat_g + damping, full_matrices=True)
253      mat_h = math_ops.matmul(
254          mat_v * math_ops.pow(math_ops.maximum(diag_d, self._epsilon), alpha),
255          array_ops.transpose(mat_u))
256    if mat_h_slot_name is not None:
257      return state_ops.assign(self.get_slot(var, mat_h_slot_name), mat_h)
258    return mat_h
259
260  def _compute_power_iter(self, var, mat_g, mat_g_size, alpha, mat_h_slot_name,
261                          iter_count=100, epsilon=1e-6):
262    """Computes mat_g^alpha, where alpha = -1/p, p a positive integer."""
263
264    mat_g_sqrt = matrix_functions.matrix_square_root(mat_g, mat_g_size,
265                                                     iter_count, self._epsilon)
266    mat_h = matrix_functions.matrix_inverse_pth_root(
267        mat_g_sqrt,
268        mat_g_size,
269        2 * alpha,
270        iter_count,
271        epsilon,
272        ridge_epsilon=0.0)
273
274    if mat_h_slot_name is not None:
275      return state_ops.assign(self.get_slot(var, mat_h_slot_name), mat_h)
276    return mat_h
277
278  def _compute_power(self, var, mat_g, mat_g_size, alpha, mat_h_slot_name=None):
279    """Just a switch between the iterative power vs svd."""
280    with ops.name_scope("matrix_iterative_power"):
281      if self._use_iterative_root:
282        return self._compute_power_iter(var, mat_g, mat_g_size, alpha,
283                                        mat_h_slot_name)
284      else:
285        return self._compute_power_svd(var, mat_g, mat_g_size, alpha,
286                                       mat_h_slot_name)
287
  def _apply_gradient(self, grad, var, indices=None):
    """The main function to update a variable.

    Builds the ops that (1) update the running average of the gradient,
    (2) update the per-dimension statistics "Gbar_<i>" and turn them into
    preconditioners, (3) contract the averaged gradient with every
    preconditioner, and (4) subtract the scaled result from var.

    Args:
      grad: A Tensor containing gradient to apply.
      var: A Tensor containing the variable to update.
      indices: An array of integers, for sparse update.

    Returns:
      The op updating var, i.e.
      var = var - learning_rate * preconditioner * grad
      (a scatter_add for sparse updates, an assign_sub for dense ones).

    If the gradient is dense, var and grad have the same shape.
    If the update is sparse, then the first dimension of the gradient and var
    may differ, others are all the same. In this case the indices array
    provides the set of indices of the variable which are to be updated with
    each row of the gradient.
    """
    # All time-varying hyper-parameters are evaluated at step + 1.
    global_step = self._global_step + 1

    # Update accumulated weighted average of gradients
    gbar = self.get_slot(var, "gbar")
    gbar_decay_t = GetParam(self._gbar_decay, global_step)
    gbar_weight_t = GetParam(self._gbar_weight, global_step)
    if indices is not None:
      # Note - the sparse update is not easily implemented, since the
      # algorithm needs all indices of gbar to be updated
      # if mat_gbar_decay != 1 and mat_gbar_decay != 0.
      # One way to make mat_gbar_decay = 1 is by rescaling.
      # If we want the update:
      #         G_{t+1} = a_{t+1} G_t + b_{t+1} w_t
      # define:
      #         r_{t+1} = a_{t+1} * r_t
      #         h_t = G_t / r_t
      # Then:
      #         h_{t+1} = h_t + (b_{t+1} / r_{t+1}) * w_t
      # So we get the mat_gbar_decay = 1 as desired.
      # We can implement this in a future version as needed.
      # However we still need gbar_decay = 0, otherwise all indices
      # of the variable will need to be updated.
      if self._gbar_decay != 0.0:
        tf_logging.warning("Not applying momentum for variable: %s" % var.name)
      # Sparse path: the raw gradient is used; the gbar slot is not touched.
      gbar_updated = grad
    else:
      gbar_updated = self._weighted_average(gbar, self._gbar_decay,
                                            gbar_decay_t,
                                            gbar_weight_t * grad)

    # Update the preconditioners and compute the preconditioned gradient
    shape = var.get_shape()
    # One statistics slot per dimension of the variable.
    mat_g_list = []
    for i in range(len(shape)):
      mat_g_list.append(self.get_slot(var, "Gbar_" + str(i)))
    mat_gbar_decay_t = GetParam(self._mat_gbar_decay, global_step)
    mat_gbar_weight_t = GetParam(self._mat_gbar_weight, global_step)

    preconditioned_grad = gbar_updated
    v_rank = len(mat_g_list)
    # The total power alpha is split evenly over the v_rank preconditioners.
    neg_alpha = - GetParam(self._alpha, global_step) / v_rank
    svd_interval = GetParam(self._svd_interval, global_step)
    precond_update_interval = GetParam(self._precond_update_interval,
                                       global_step)
    for i, mat_g in enumerate(mat_g_list):
      # axes is the list of indices to reduce - everything but the current i.
      axes = list(range(i)) + list(range(i+1, v_rank))
      if shape[i] <= self._max_matrix_size:
        # If the tensor size is sufficiently small perform full Shampoo update
        # Note if precond_update_interval > 1 and mat_gbar_decay_t != 1, this
        # is not strictly correct. However we will use it for now, and
        # fix if needed. (G_1 = aG + bg ==> G_n = a^n G + (1+a+..+a^{n-1})bg)

        # The statistics are only refreshed every precond_update_interval
        # steps; the weight is scaled by the interval to compensate for the
        # skipped steps. Note that the raw grad (not gbar_updated) is used.
        # NOTE(review): the lambdas below close over the loop variables
        # (mat_g, axes, i); this presumably relies on cond invoking them while
        # the graph is being built, i.e. within this iteration - confirm.
        # pylint: disable=g-long-lambda,cell-var-from-loop
        mat_g_updated = control_flow_ops.cond(
            math_ops.mod(global_step, precond_update_interval) < 1,
            lambda: self._update_mat_g(
                mat_g, grad, axes, mat_gbar_decay_t,
                mat_gbar_weight_t * precond_update_interval, i),
            lambda: mat_g)

        # Normalize by the size of the dimension before taking the power; the
        # slot itself keeps the un-normalized statistics.
        mat_g_updated = mat_g_updated / float(shape[i].value)

        if self._svd_interval == 1:
          # Recompute the power at every step; no "H_<i>" slot is involved.
          mat_h = self._compute_power(var, mat_g_updated, shape[i], neg_alpha)
        else:
          # Recompute the power (and store it in slot "H_<i>") only every
          # svd_interval steps, otherwise reuse the stored value.
          mat_h = control_flow_ops.cond(
              math_ops.mod(global_step, svd_interval) < 1,
              lambda: self._compute_power(var, mat_g_updated, shape[i],
                                          neg_alpha, "H_" + str(i)),
              lambda: self.get_slot(var, "H_" + str(i)))

        # mat_h is a square matrix of size d_i x d_i
        # preconditioned_grad is a
        # d_i x ... x d_{n-1} x d_0 x ... x d_{i-1} tensor
        # After contraction with a d_i x d_i tensor
        # it becomes a d_{i+1} x ... x d_{n-1} x d_0 x ... x d_i tensor
        # (the first dimension is contracted out, and the second dimension of
        # mat_h is appended).  After going through all the indices, it becomes
        # a d_0 x ... x d_{n-1} tensor again.
        preconditioned_grad = math_ops.tensordot(preconditioned_grad, mat_h,
                                                 axes=([0], [0]),
                                                 name="precond_" + str(i))
      else:
        # Tensor size is too large -- perform diagonal Shampoo update
        # Only normalize non-vector cases.
        if axes:
          # No normalization is applied for sparse updates.
          normalizer = 1.0 if indices is not None else float(shape[i].value)
          grad_outer = math_ops.reduce_sum(grad * grad, axis=axes) / normalizer
        else:
          grad_outer = grad * grad

        if i == 0 and indices is not None:
          # Sparse update of the first dimension: only the rows named by
          # indices are accumulated, which is only valid without decay.
          # NOTE(review): this check is an assert, so it disappears when
          # Python runs with -O.
          assert self._mat_gbar_decay == 1.0
          mat_g_updated = state_ops.scatter_add(mat_g, indices,
                                                mat_gbar_weight_t * grad_outer)
          mat_g_updated_slice = array_ops.gather(mat_g_updated, indices)
          # Entries that are not positive get a zero preconditioner instead of
          # being raised to a negative power.
          mat_h = array_ops.where(
              math_ops.greater(mat_g_updated_slice, 0),
              math_ops.pow(mat_g_updated_slice, neg_alpha),
              array_ops.zeros_like(mat_g_updated_slice))
        else:
          mat_g_updated = self._weighted_average(mat_g,
                                                 self._mat_gbar_decay,
                                                 mat_gbar_decay_t,
                                                 mat_gbar_weight_t * grad_outer)
          # Same guard as above: non-positive entries map to zero.
          mat_h = array_ops.where(
              math_ops.greater(mat_g_updated, 0),
              math_ops.pow(mat_g_updated, neg_alpha),
              array_ops.zeros_like(mat_g_updated))

        # Need to do the transpose to ensure that the tensor becomes
        # a d_{i+1} x ... x d_{n-1} x d_0 x ... x d_i tensor as described for
        # the full-matrix branch. The vector mat_h then scales along the last
        # dimension.
        preconditioned_grad = array_ops.transpose(
            preconditioned_grad, perm=list(range(1, v_rank)) + [0]) * mat_h

    # Update the variable based on the Shampoo update
    learning_rate_t = GetParam(self._learning_rate, global_step)
    if indices is not None:
      # Sparse: only the rows named by indices are modified.
      var_updated = state_ops.scatter_add(
          var, indices, -learning_rate_t * preconditioned_grad)
    else:
      var_updated = state_ops.assign_sub(var,
                                         learning_rate_t * preconditioned_grad)
    return var_updated
429