1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Optimizer ops for use in layers and tf.learn."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import six
22
23from tensorflow.contrib import framework as contrib_framework
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import clip_ops
28from tensorflow.python.ops import control_flow_ops
29from tensorflow.python.ops import init_ops
30from tensorflow.python.ops import math_ops
31from tensorflow.python.ops import random_ops
32from tensorflow.python.ops import variable_scope as vs
33from tensorflow.python.ops import variables as vars_
34from tensorflow.python.summary import summary
35from tensorflow.python.training import moving_averages
36from tensorflow.python.training import optimizer as optimizer_
37from tensorflow.python.training import training as train
38
# Maps the optimizer names accepted by `optimize_loss` (when `optimizer` is
# given as a string) to callables that build the optimizer. Each callable is
# invoked with a single `learning_rate` keyword argument, which is why
# "Momentum" is wrapped in a lambda that pins `momentum=0.9`.
OPTIMIZER_CLS_NAMES = {
    "Adagrad": train.AdagradOptimizer,
    "Adam": train.AdamOptimizer,
    "Ftrl": train.FtrlOptimizer,
    "Momentum": lambda learning_rate: train.MomentumOptimizer(learning_rate, momentum=0.9),  # pylint: disable=line-too-long
    "RMSProp": train.RMSPropOptimizer,
    "SGD": train.GradientDescentOptimizer,
}
47
# Summary names that may be requested through the `summaries` argument of
# `optimize_loss`; any other name is rejected there with a `ValueError`.
OPTIMIZER_SUMMARIES = [
    "learning_rate",
    "loss",
    "gradients",
    "gradient_norm",
    "global_gradient_norm",
]
55
56
def optimize_loss(loss,
                  global_step,
                  learning_rate,
                  optimizer,
                  gradient_noise_scale=None,
                  gradient_multipliers=None,
                  clip_gradients=None,
                  learning_rate_decay_fn=None,
                  update_ops=None,
                  variables=None,
                  name=None,
                  summaries=None,
                  colocate_gradients_with_ops=False,
                  increment_global_step=True):
  """Given loss and parameters for optimizer, returns a training op.

  Various ways of passing optimizers include:

  - by string specifying the name of the optimizer. See OPTIMIZER_CLS_NAMES
      for full list. E.g. `optimize_loss(..., optimizer='Adam')`.
  - by function taking learning rate `Tensor` as argument and returning an
      `Optimizer` instance. E.g. `optimize_loss(...,
      optimizer=lambda lr: tf.train.MomentumOptimizer(lr, momentum=0.5))`.
    Alternatively, if `learning_rate` is `None`, the function takes no
    arguments. E.g. `optimize_loss(..., learning_rate=None,
      optimizer=lambda: tf.train.MomentumOptimizer(0.5, momentum=0.5))`.
  - by a subclass of `Optimizer` having a single-argument constructor
      (the argument is the learning rate), such as AdamOptimizer or
      AdagradOptimizer. E.g. `optimize_loss(...,
      optimizer=tf.train.AdagradOptimizer)`.
  - by an instance of a subclass of `Optimizer`.
      E.g., `optimize_loss(..., optimizer=tf.train.AdagradOptimizer(0.5))`.

  Args:
    loss: Scalar `Tensor`.
    global_step: Scalar int `Tensor`, step counter to update on each step
                 unless `increment_global_step` is `False`. If not supplied,
                 it will be fetched from the default graph (see
                 `tf.train.get_global_step` for details). If it has
                 not been created, no step will be incremented with each weight
                 update. `learning_rate_decay_fn` requires `global_step`.
    learning_rate: float or `Tensor`, magnitude of update per each training
                   step. Can be `None`. A Python float must be non-negative;
                   a `Tensor` must be 0-dimensional.
    optimizer: string, class or optimizer instance, used as trainer.
               string should be name of optimizer, like 'SGD',
                 'Adam', 'Adagrad'. Full list in OPTIMIZER_CLS_NAMES constant.
               class should be sub-class of `tf.Optimizer` that implements
                 `compute_gradients` and `apply_gradients` functions.
               optimizer instance should be instantiation of `tf.Optimizer`
                 sub-class and have `compute_gradients` and `apply_gradients`
                 functions.
    gradient_noise_scale: float or None, adds 0-mean normal noise scaled by this
                          value.
    gradient_multipliers: dict of variables or variable names to floats.
                          If present, gradients for specified
                          variables will be multiplied by given constant.
    clip_gradients: float, callable or `None`. If a float is provided, a global
      clipping is applied to prevent the norm of the gradient from exceeding
      this value. Alternatively, a callable can be provided, e.g.,
      `adaptive_clipping_fn()`.  This callable takes a list of
      `(gradients, variables)` tuples and returns the same thing with the
      gradients modified.
    learning_rate_decay_fn: function, takes `learning_rate` and `global_step`
                            `Tensor`s, returns `Tensor`.
                            Can be used to implement any learning rate decay
                            functions.
                            For example: `tf.train.exponential_decay`.
                            Ignored if `learning_rate` is not supplied.
    update_ops: list of update `Operation`s to execute at each step. If `None`,
                uses elements of UPDATE_OPS collection. The order of execution
                between `update_ops` and `loss` is non-deterministic.
    variables: list of variables to optimize or
               `None` to use all trainable variables.
    name: The name for this operation is used to scope operations and summaries.
    summaries: List of internal quantities to visualize on tensorboard. If not
               set, the loss, the learning rate, and the global norm of the
               gradients will be reported. The complete list of possible values
               is in OPTIMIZER_SUMMARIES. Note that the learning rate summary
               is only emitted when `learning_rate_decay_fn` is applied.
    colocate_gradients_with_ops: If True, try colocating gradients with the
                                 corresponding op.
    increment_global_step: Whether to increment `global_step`. If your model
      calls `optimize_loss` multiple times per training step (e.g. to optimize
      different parts of the model), use this arg to avoid incrementing
      `global_step` more times than necessary.

  Returns:
    Training op.

  Raises:
    ValueError: if:
        * `loss` is an invalid type or shape.
        * `global_step` is an invalid type or shape.
        * `learning_rate` is an invalid type or value.
        * `optimizer` has the wrong type.
        * `clip_gradients` is neither float nor callable.
        * `learning_rate` and `learning_rate_decay_fn` are supplied, but no
          `global_step` is available.
        * `summaries` contains a name that is not in OPTIMIZER_SUMMARIES.
        * `gradients` is empty after applying `gradient_multipliers`.
  """
  loss = ops.convert_to_tensor(loss)
  contrib_framework.assert_scalar(loss)
  if global_step is None:
    global_step = train.get_global_step()
  else:
    train.assert_global_step(global_step)
  with vs.variable_scope(name, "OptimizeLoss", [loss, global_step]):
    # Update ops take UPDATE_OPS collection if not provided.
    if update_ops is None:
      update_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS))
    # Make sure update ops are run before computing loss.
    if update_ops:
      loss = control_flow_ops.with_dependencies(list(update_ops), loss)

    # Learning rate variable, with possible decay.
    lr = None
    if learning_rate is not None:
      if (isinstance(learning_rate, ops.Tensor) and
          learning_rate.get_shape().ndims == 0):
        lr = learning_rate
      elif isinstance(learning_rate, float):
        if learning_rate < 0.0:
          # Interpolate the value into the message; passing it as a second
          # positional argument would leave the "%s" placeholder unformatted.
          raise ValueError("Invalid learning_rate %s." % learning_rate)
        lr = vs.get_variable(
            "learning_rate", [],
            trainable=False,
            initializer=init_ops.constant_initializer(learning_rate))
      else:
        raise ValueError("Learning rate should be 0d Tensor or float. "
                         "Got %s of type %s" % (str(learning_rate),
                                                str(type(learning_rate))))
    if summaries is None:
      summaries = ["loss", "learning_rate", "global_gradient_norm"]
    else:
      for summ in summaries:
        if summ not in OPTIMIZER_SUMMARIES:
          raise ValueError("Summaries should be one of [%s], you provided %s." %
                           (", ".join(OPTIMIZER_SUMMARIES), summ))
    if learning_rate is not None and learning_rate_decay_fn is not None:
      if global_step is None:
        raise ValueError("global_step is required for learning_rate_decay_fn.")
      lr = learning_rate_decay_fn(lr, global_step)
      # The learning rate summary is only written for a decayed learning rate.
      if "learning_rate" in summaries:
        summary.scalar("learning_rate", lr)

    # Create optimizer, given specified parameters.
    if isinstance(optimizer, six.string_types):
      if lr is None:
        raise ValueError("Learning rate is None, but should be specified if "
                         "optimizer is string (%s)." % optimizer)
      if optimizer not in OPTIMIZER_CLS_NAMES:
        raise ValueError(
            "Optimizer name should be one of [%s], you provided %s." %
            (", ".join(OPTIMIZER_CLS_NAMES), optimizer))
      opt = OPTIMIZER_CLS_NAMES[optimizer](learning_rate=lr)
    elif (isinstance(optimizer, type) and
          issubclass(optimizer, optimizer_.Optimizer)):
      if lr is None:
        raise ValueError("Learning rate is None, but should be specified if "
                         "optimizer is class (%s)." % optimizer)
      opt = optimizer(learning_rate=lr)
    elif isinstance(optimizer, optimizer_.Optimizer):
      opt = optimizer
    elif callable(optimizer):
      # A factory takes the learning rate only when one was supplied.
      if learning_rate is not None:
        opt = optimizer(lr)
      else:
        opt = optimizer()
      if not isinstance(opt, optimizer_.Optimizer):
        raise ValueError("Unrecognized optimizer: function should return "
                         "subclass of Optimizer. Got %s." % str(opt))
    else:
      raise ValueError("Unrecognized optimizer: should be string, "
                       "subclass of Optimizer, instance of "
                       "subclass of Optimizer or function with one argument. "
                       "Got %s." % str(optimizer))

    # All trainable variables, if specific variables are not specified.
    if variables is None:
      variables = vars_.trainable_variables()

    # Compute gradients.
    gradients = opt.compute_gradients(
        loss,
        variables,
        colocate_gradients_with_ops=colocate_gradients_with_ops)

    # Optionally add gradient noise.
    if gradient_noise_scale is not None:
      gradients = _add_scaled_noise_to_gradients(gradients,
                                                 gradient_noise_scale)

    # Multiply some gradients.
    if gradient_multipliers is not None:
      gradients = _multiply_gradients(gradients, gradient_multipliers)
      if not gradients:
        raise ValueError(
            "Empty list of (gradient, var) pairs encountered. This is most "
            "likely to be caused by an improper value of gradient_multipliers.")

    # Global norm of the gradients before any clipping is applied.
    if "global_gradient_norm" in summaries or "gradient_norm" in summaries:
      summary.scalar("global_norm/gradient_norm",
                     clip_ops.global_norm(list(zip(*gradients))[0]))

    # Optionally clip gradients by global norm.
    if isinstance(clip_gradients, float):
      gradients = _clip_gradients_by_norm(gradients, clip_gradients)
    elif callable(clip_gradients):
      gradients = clip_gradients(gradients)
    elif clip_gradients is not None:
      raise ValueError(
          "Unknown type %s for clip_gradients" % type(clip_gradients))

    # Add scalar summary for loss.
    if "loss" in summaries:
      summary.scalar("loss", loss)

    # Add histograms for variables, gradients and gradient norms.
    for gradient, variable in gradients:
      if isinstance(gradient, ops.IndexedSlices):
        grad_values = gradient.values
      else:
        grad_values = gradient

      if grad_values is not None:
        # ":" is replaced before the variable name is used in a summary name.
        var_name = variable.name.replace(":", "_")
        if "gradients" in summaries:
          summary.histogram("gradients/%s" % var_name, grad_values)
        if "gradient_norm" in summaries:
          summary.scalar("gradient_norm/%s" % var_name,
                         clip_ops.global_norm([grad_values]))

    # Global norm after clipping, reported only when clipping was requested.
    if clip_gradients is not None and ("global_gradient_norm" in summaries or
                                       "gradient_norm" in summaries):
      summary.scalar("global_norm/clipped_gradient_norm",
                     clip_ops.global_norm(list(zip(*gradients))[0]))

    # Create gradient updates.
    grad_updates = opt.apply_gradients(
        gradients,
        global_step=global_step if increment_global_step else None,
        name="train")

    # Ensure the train_tensor computes grad_updates.
    train_tensor = control_flow_ops.with_dependencies([grad_updates], loss)

    return train_tensor
303
304
def _clip_gradients_by_norm(grads_and_vars, clip_gradients):
  """Clips the gradients of `grads_and_vars` jointly by their global norm.

  Args:
    grads_and_vars: Sequence of `(gradient, variable)` pairs.
    clip_gradients: Clipping threshold forwarded to
      `clip_ops.clip_by_global_norm`.

  Returns:
    List of `(clipped_gradient, variable)` pairs, in the input order.
  """
  grads, tvars = zip(*grads_and_vars)
  # Only the clipped tensors are needed; the returned global norm is dropped.
  clipped_grads = clip_ops.clip_by_global_norm(grads, clip_gradients)[0]
  return [(grad, var) for grad, var in zip(clipped_grads, tvars)]
310
311
def _adaptive_max_norm(norm, std_factor, decay, global_step, epsilon, name):
  """Derives a clipping threshold from moving statistics of `log(norm)`.

  Args:
    norm: Tensor holding the current gradient norm.
    std_factor: Number of standard deviations added to the mean log-norm.
    decay: Decay passed to the moving averages.
    global_step: Optional step counter. When given, the decay is capped at
      `n / (n + 1)` so the averages adapt faster during the first steps.
    epsilon: Added to `norm` before the log and used as the variance floor.
    name: Optional variable scope name; defaults to "AdaptiveMaxNorm".

  Returns:
    A `(max_norm, mean)` pair where `max_norm = exp(mean + std_factor * std)`
    and `mean` is the result of updating the moving average of
    `log(norm + epsilon)`.
  """
  with vs.variable_scope(name, "AdaptiveMaxNorm", [norm]):
    log_of_norm = math_ops.log(norm + epsilon)

    # Quicker adaptation at the beginning: cap the decay by n / (n + 1).
    effective_decay = decay
    if global_step is not None:
      step = math_ops.cast(global_step, dtypes.float32)
      effective_decay = math_ops.minimum(decay, step / (step + 1.))

    def _tracked_average(var_name, observation):
      """Creates a non-trainable accumulator and folds `observation` in."""
      accumulator = vs.get_variable(
          var_name,
          shape=observation.get_shape(),
          dtype=observation.dtype,
          initializer=init_ops.zeros_initializer(),
          trainable=False)
      return moving_averages.assign_moving_average(
          accumulator, observation, effective_decay, zero_debias=False)

    # Moving first and second moments of the log-norm.
    mean = _tracked_average("mean", log_of_norm)
    squared_log_norm = math_ops.square(log_of_norm)
    sq_mean = _tracked_average("sq_mean", squared_log_norm)

    # Var[x] = E[x^2] - E[x]^2, floored at epsilon before the square root.
    variance = sq_mean - math_ops.square(mean)
    floored_variance = math_ops.maximum(epsilon, variance)
    std = math_ops.sqrt(floored_variance)
    return math_ops.exp(mean + std_factor * std), mean
340
341
def adaptive_clipping_fn(std_factor=2.,
                         decay=0.95,
                         static_max_norm=None,
                         global_step=None,
                         report_summary=False,
                         epsilon=1e-8,
                         name=None):
  """Builds a clipping function whose threshold follows gradient statistics.

  Implements adaptive gradient clipping as presented in section 3.2.1 of
  https://arxiv.org/abs/1412.1602.

  Moving averages of the mean and std of the log(norm) of the gradient are
  maintained. Whenever the global norm is not below
  `exp(mean + std_factor*std)`, all gradients are rescaled so that the global
  norm becomes `exp(mean)`.

  Args:
    std_factor: Python scalar (or tensor).
      `max_norm = exp(mean + std_factor*std)`
    decay: The smoothing factor of the moving averages.
    static_max_norm: If provided, will threshold the norm to this value as an
      extra safety.
    global_step: Optional global_step. If provided, the effective decay is
      `min(decay, n/(n+1))`, which gives a quicker adaptation of the averages
      during the first steps.
    report_summary: If `True`, will add a scalar summary of the `max_norm`.
    epsilon: Small value chosen to avoid zero variance.
    name: The name for this operation is used to scope operations and summaries.

  Returns:
    A function that takes a list of `(gradient, variable)` pairs and returns
    the list with the gradients clipped.
  """

  def _rescale(grad, factor):
    """Scales one gradient by `factor`, keeping `None` and sparse layout."""
    if grad is None:
      return None
    if isinstance(grad, ops.IndexedSlices):
      return ops.IndexedSlices(grad.values * factor, grad.indices,
                               grad.dense_shape)
    return grad * factor

  def gradient_clipping(grads_and_vars):
    """Clips the given gradients using the adaptive threshold."""
    grads, variables = zip(*grads_and_vars)
    norm = clip_ops.global_norm(grads)
    max_norm, log_mean = _adaptive_max_norm(norm, std_factor, decay,
                                            global_step, epsilon, name)

    # Expose the current threshold for debugging.
    if report_summary:
      summary.scalar("global_norm/adaptive_max_gradient_norm", max_norm)

    # Keep gradients untouched (factor 1) while the norm stays below the
    # threshold; otherwise shrink them so the norm becomes exp(log_mean).
    below_threshold = norm < max_norm
    unit_factor = array_ops.ones_like(norm)
    shrink_factor = math_ops.exp(log_mean) / norm
    factor = array_ops.where(below_threshold, unit_factor, shrink_factor)

    # Optional hard cap applied on top of the adaptive factor.
    if static_max_norm is not None:
      factor = math_ops.minimum(static_max_norm / norm, factor)

    return [(_rescale(grad, factor), var)
            for grad, var in zip(grads, variables)]

  return gradient_clipping
410
411
412def _add_scaled_noise_to_gradients(grads_and_vars, gradient_noise_scale):
413  """Adds scaled noise from a 0-mean normal distribution to gradients."""
414  gradients, variables = zip(*grads_and_vars)
415  noisy_gradients = []
416  for gradient in gradients:
417    if gradient is None:
418      noisy_gradients.append(None)
419      continue
420    if isinstance(gradient, ops.IndexedSlices):
421      gradient_shape = gradient.dense_shape
422    else:
423      gradient_shape = gradient.get_shape()
424    noise = random_ops.truncated_normal(gradient_shape) * gradient_noise_scale
425    noisy_gradients.append(gradient + noise)
426  return list(zip(noisy_gradients, variables))
427
428
429def _multiply_gradients(grads_and_vars, gradient_multipliers):
430  """Multiply specified gradients."""
431  multiplied_grads_and_vars = []
432  for grad, var in grads_and_vars:
433    if (grad is not None and
434        (var in gradient_multipliers or var.name in gradient_multipliers)):
435      key = var if var in gradient_multipliers else var.name
436      multiplier = gradient_multipliers[key]
437      if isinstance(grad, ops.IndexedSlices):
438        grad_values = grad.values * multiplier
439        grad = ops.IndexedSlices(grad_values, grad.indices, grad.dense_shape)
440      else:
441        grad *= math_ops.cast(multiplier, grad.dtype)
442    multiplied_grads_and_vars.append((grad, var))
443  return multiplied_grads_and_vars
444