# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Support for scaled softplus, a smoothed version of ReLU."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.framework import function
22from tensorflow.python.framework import ops
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import gen_array_ops
25from tensorflow.python.ops import math_ops
26from tensorflow.python.ops import nn
27
28
29def _reduce_and_reshape_grad(g, t):
30  """Returns the gradient, sum-reduced and reshaped to `t`'s shape."""
  shape = array_ops.shape(t)
  g_shape = array_ops.shape(g)
  bcast_dims, _ = gen_array_ops.broadcast_gradient_args(shape, g_shape)
  return array_ops.reshape(math_ops.reduce_sum(g, bcast_dims), shape)


def scaled_softplus(x, alpha, clip=None, name=None):
  """Returns `y = alpha * ln(1 + exp(x / alpha))` or `min(y, clip)`.

  This can be seen as a softplus applied to the scaled input, with the output
  appropriately scaled. Clipping is optional. As `alpha` tends to 0,
  `scaled_softplus(x, alpha)` tends to `relu(x)`, and
  `scaled_softplus(x, alpha, clip=6)` tends to `relu6(x)`.

  Note: the gradient of this op is defined in terms of the op's output
  (the forward value) as well as the incoming backprop gradient.

  Args:
    x: A `Tensor` of inputs.
    alpha: A `Tensor`, indicating the amount of smoothness. The caller
        must ensure that `alpha > 0`.
    clip: (optional) A `Tensor`, the upper bound to clip the values.
    name: A name for the scope of the operations (optional).

  Returns:
    A `Tensor` of the shape and type determined by broadcasting the inputs.

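  Example (an illustrative sketch, assuming `import tensorflow as tf` and that
  this function is in scope):

    x = tf.constant([-3., 0., 3.])
    y = scaled_softplus(x, alpha=0.1)            # Close to relu(x).
    y6 = scaled_softplus(x, alpha=0.1, clip=6.)  # Close to relu6(x).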
58  """
59  clipping = clip is not None
60  with ops.name_scope(name, 'scaled_softplus',
61                      [x, alpha] + ([clip] if clipping else [])):
62    x = ops.convert_to_tensor(x, name='x')
63    dtype = x.dtype
64    alpha = ops.convert_to_tensor(alpha, dtype=dtype, name='alpha')
65    # Compute the forward value.
66    y = alpha * nn.softplus(x / alpha)
67    if clipping:
68      clip = ops.convert_to_tensor(clip, dtype=dtype, name='clip')
69      y = math_ops.minimum(y, clip)
70
71    def _grad(op, g):
72      """Backprop for scaled softplus, with optional clipping."""
73      y, x, alpha = op.inputs[:3]
74      # Prevent the memory-expensive computations from happening before g is
75      # available.
76      with ops.control_dependencies([g]):
77        y = array_ops.identity(y)
78      clip_grad = []
79      if clipping:
80        clip = op.inputs[3]
81        unclipped = math_ops.cast(y < clip, g.dtype)
82        clip_grad = [_reduce_and_reshape_grad(g * (1. - unclipped), clip)]
83        g *= unclipped
84      y /= alpha
85      emy = math_ops.exp(-y)
86      dy_dx = 1. - emy
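      # Derivation sketch: ignoring clipping (which only zeroes g above), the
      # forward value is alpha*s with s = softplus(x/alpha), and after the
      # division above y == s. Then
      #   d(alpha*s)/dx     = sigmoid(x/alpha) = 1 - exp(-s) = dy_dx,
      #   d(alpha*s)/dalpha = s - (x/alpha)*sigmoid(x/alpha)
      #                     = s*exp(-s) - dy_dx*log(dy_dx),
      # where the last step uses x/alpha = s + log(1 - exp(-s)).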
      # The eps below avoids log(0). Note that t*log(t) -> 0 as t->0.
      eps = 1e-8
      dy_dalpha = y * emy - dy_dx * math_ops.log(dy_dx + eps)
      # Backprop to the actual inputs, but not to the output.
      return [None,
              _reduce_and_reshape_grad(g * dy_dx, x),
              _reduce_and_reshape_grad(g * dy_dalpha, alpha)] + clip_grad

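    # The Defun helpers below wrap an identity on `y` so that `_grad` above is
    # attached as the gradient. Passing `x`, `alpha` (and `clip`) as extra
    # inputs exposes them to `_grad` via `op.inputs`, and passing the forward
    # value `y` lets the backward pass reuse it instead of recomputing it.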
    if clipping:
      @function.Defun(dtype, dtype, dtype, dtype,
                      func_name='ScaledSoftplusHelper_clip_%s' % dtype.name,
                      shape_func=lambda op: [op.inputs[0].shape],
                      python_grad_func=_grad)
      def _forward_helper_clip(y, x, alpha, clip):
        del x, alpha, clip  # Unused.
        return y
      return _forward_helper_clip(y, x, alpha, clip)
    # No clipping.
    @function.Defun(dtype, dtype, dtype,
                    func_name='ScaledSoftplusHelper_%s' % dtype.name,
                    shape_func=lambda op: [op.inputs[0].shape],
                    python_grad_func=_grad)
    def _forward_helper(y, x, alpha):
      del x, alpha  # Unused.
      return y
    return _forward_helper(y, x, alpha)