# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Support for scaled softplus, a smoothed version of ReLU."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn


def _reduce_and_reshape_grad(g, t):
  """Returns the gradient, sum-reduced and reshaped to `t`'s shape."""
  shape = array_ops.shape(t)
  g_shape = array_ops.shape(g)
  bcast_dims, _ = gen_array_ops.broadcast_gradient_args(shape, g_shape)
  return array_ops.reshape(math_ops.reduce_sum(g, bcast_dims), shape)


def scaled_softplus(x, alpha, clip=None, name=None):
  """Returns `y = alpha * ln(1 + exp(x / alpha))` or `min(y, clip)`.

  This can be seen as a softplus applied to the scaled input, with the output
  appropriately scaled. The clipping is optional. As `alpha` tends to 0,
  `scaled_softplus(x, alpha)` tends to `relu(x)`, and
  `scaled_softplus(x, alpha, clip=6)` tends to `relu6(x)`.

  Note: the gradient for this operation is defined to depend on the backprop
  inputs as well as the outputs of this operation.

  Args:
    x: A `Tensor` of inputs.
    alpha: A `Tensor`, indicating the amount of smoothness. The caller
        must ensure that `alpha > 0`.
    clip: (optional) A `Tensor`, the upper bound to clip the values.
    name: A name for the scope of the operations (optional).

  Returns:
    A tensor of the size and type determined by broadcasting of the inputs.

  """
  clipping = clip is not None
  with ops.name_scope(name, 'scaled_softplus',
                      [x, alpha] + ([clip] if clipping else [])):
    x = ops.convert_to_tensor(x, name='x')
    dtype = x.dtype
    alpha = ops.convert_to_tensor(alpha, dtype=dtype, name='alpha')
    # Compute the forward value.
    y = alpha * nn.softplus(x / alpha)
    if clipping:
      clip = ops.convert_to_tensor(clip, dtype=dtype, name='clip')
      y = math_ops.minimum(y, clip)

    def _grad(op, g):
      """Backprop for scaled softplus, with optional clipping."""
      y, x, alpha = op.inputs[:3]
      # Prevent the memory-expensive computations from happening before g is
      # available.
      with ops.control_dependencies([g]):
        y = array_ops.identity(y)
      clip_grad = []
      if clipping:
        clip = op.inputs[3]
        unclipped = math_ops.cast(y < clip, g.dtype)
        clip_grad = [_reduce_and_reshape_grad(g * (1. - unclipped), clip)]
        g *= unclipped
      y /= alpha
      emy = math_ops.exp(-y)
      dy_dx = 1. - emy
      # Here y = softplus(x/alpha) wherever g is nonzero (if clipping, the
      # clipped elements already have g zeroed above). With s = x/alpha:
      #   d/dx    [alpha*softplus(s)] = sigmoid(s) = 1 - exp(-y) = dy_dx,
      #   d/dalpha[alpha*softplus(s)] = softplus(s) - s*sigmoid(s)
      #                               = y*exp(-y) - dy_dx*log(dy_dx),
      # where the last equality uses s = y + log(sigmoid(s)).
      # The eps below avoids log(0). Note that t*log(t) -> 0 as t->0.
      eps = 1e-8
      dy_dalpha = y * emy - dy_dx * math_ops.log(dy_dx + eps)
      # Backprop to the actual inputs, but not to the output.
      return [None,
              _reduce_and_reshape_grad(g * dy_dx, x),
              _reduce_and_reshape_grad(g * dy_dalpha, alpha)] + clip_grad

    if clipping:
      @function.Defun(dtype, dtype, dtype, dtype,
                      func_name='ScaledSoftplusHelper_clip_%s' % dtype.name,
                      shape_func=lambda op: [op.inputs[0].shape],
                      python_grad_func=_grad)
      def _forward_helper_clip(y, x, alpha, clip):
        del x, alpha, clip  # Unused.
        return y
      return _forward_helper_clip(y, x, alpha, clip)
    # No clipping.
    @function.Defun(dtype, dtype, dtype,
                    func_name='ScaledSoftplusHelper_%s' % dtype.name,
                    shape_func=lambda op: [op.inputs[0].shape],
                    python_grad_func=_grad)
    def _forward_helper(y, x, alpha):
      del x, alpha  # Unused.
      return y
    return _forward_helper(y, x, alpha)
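

# A minimal usage sketch (not part of the original module). It assumes a
# TF 1.x graph/session runtime, which is what this contrib code targets, and
# simply illustrates the docstring: for small `alpha`, scaled_softplus
# approaches relu, and with `clip=6.` it approaches relu6.
if __name__ == '__main__':
  from tensorflow.python.client import session
  from tensorflow.python.framework import constant_op

  x_demo = constant_op.constant([-2., 0., 2., 10.])
  y_smooth = scaled_softplus(x_demo, alpha=1.)            # Smoothed relu.
  y_sharp = scaled_softplus(x_demo, alpha=0.01, clip=6.)  # Approximates relu6.
  with session.Session() as sess:
    smooth_vals, sharp_vals = sess.run([y_smooth, y_sharp])
    print('alpha=1.0          :', smooth_vals)
    print('alpha=0.01, clip=6.:', sharp_vals)  # Approximately [0., 0., 2., 6.].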