1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import functools
21
22from absl.testing import parameterized
23import numpy as np
24import tensorflow as tf
25
26
27def _jvp(f, primals, tangents):
28  """Compute the jacobian of `f` at `primals` multiplied by `tangents`."""
29  with tf.autodiff.ForwardAccumulator(primals, tangents) as acc:
30    primals_out = f(*primals)
31  return primals_out, acc.jvp(
32      primals_out, unconnected_gradients=tf.UnconnectedGradients.ZERO)
33
34
35def _jacfwd(f, primals):
36  """Compute the jacobian of `f` at `primals` using forward-mode autodiff."""
37  jac_flat = []
38  flat_primals = tf.nest.flatten(primals)
39  tangent_mask = [tf.zeros_like(primal) for primal in flat_primals]
40  for primal_index, primal in enumerate(flat_primals):
41    primal_vector = tf.reshape(primal, [-1])
42    primal_vector_length = tf.size(primal_vector)
43    jac_columns = []
44    for element_index in tf.range(primal_vector_length):
45      mask = tf.one_hot(element_index, primal_vector_length)
46      tangent_mask[primal_index] = tf.reshape(mask, tf.shape(primal))
47      jac_columns.append(
48          tf.nest.map_structure(
49              functools.partial(tf.reshape, shape=[-1]),
50              _jvp(f, primals, tf.nest.pack_sequence_as(primals,
51                                                        tangent_mask))[1]))
52    jac_flat.append(tf.stack(jac_columns, axis=1))
53    tangent_mask[primal_index] = tf.zeros_like(primal)
54  return tf.nest.pack_sequence_as(primals, jac_flat)
55
56
57def _grad(f, argnums=0):
58  """Return a function which computes the gradient of `f`."""
59
60  def _f(*params):
61    with tf.GradientTape() as tape:
62      tape.watch(params)
63      primals_out = f(*params)
64    return tape.gradient(
65        primals_out,
66        params[argnums],
67        unconnected_gradients=tf.UnconnectedGradients.ZERO)
68
69  return _f
70
71
72def _hvp(f, primals, tangents):
73  """Compute a forward-over-back Hessian-vector product."""
74  with tf.autodiff.ForwardAccumulator(primals, tangents) as acc:
75    with tf.GradientTape() as tape:
76      tape.watch(primals)
77      f_out = f(*primals)
78      f_out.shape.assert_is_compatible_with([])
79    return acc.jvp(tape.gradient(f_out, primals))
80
81
82def _vectorize_parameters(f, params, use_pfor, dtype):
83  """Loop over `params`, providing a one-hot mask to `f` for each."""
84  parameter_sizes = [tf.size(param) for param in params]
85  total_size = tf.math.add_n(parameter_sizes)
86
87  def _wrapper(index):
88    full_onehot = tf.one_hot(index, total_size)
89    split_onehot = tf.split(full_onehot, parameter_sizes)
90    tangents = [
91        tf.reshape(v, tf.shape(param))
92        for param, v in zip(params, split_onehot)
93    ]
94    return f(tangents)
95
96  if use_pfor:
97    return tf.vectorized_map(_wrapper, tf.range(total_size))
98  else:
99    return tf.map_fn(_wrapper, tf.range(total_size), dtype)
100
101
102def _forward_over_back_hessian(f, params, use_pfor, dtype=None):
103  """Computes the full Hessian matrix for the scalar-valued f(*params).
104
105  Args:
106    f: A function taking `params` and returning a scalar.
107    params: A possibly nested structure of tensors.
108    use_pfor: If true, uses `tf.vectorized_map` calls instead of looping.
109    dtype: Required if `use_pfor=False`. A possibly nested structure of dtypes
110      (e.g. `tf.float32`) matching the structure of `f`'s returns.
111
112  Returns:
113    A possibly nested structure of matrix slices corresponding to `params`. Each
114    slice has shape [P, p_s] where `p_s` is the number of parameters (`tf.size`)
115    in the corresponding element of `params` and `P` is the total number of
116    parameters (`sum_s(p_s)`). The full matrix can be obtained by concatenating
117    along the second axis.
118  """
119  return _vectorize_parameters(
120      functools.partial(_hvp, f, params),
121      params,
122      use_pfor=use_pfor,
123      dtype=dtype)
124
125
126def _test_gradients(testcase,
127                    f,
128                    primals,
129                    order,
130                    delta=1e-3,
131                    rtol=1e-2,
132                    atol=1e-6):
133  """Tests forward/backward jacobians of `f`'s [0, `order`)-order gradients."""
134  if order < 1:
135    raise ValueError(
136        "`order` should be a positive integer, got '{}'.".format(order))
137  if order > 1:
138    _test_gradients(
139        testcase=testcase,
140        f=_grad(f),
141        primals=primals,
142        order=order - 1,
143        delta=delta,
144        rtol=rtol,
145        atol=atol)
146  sym_jac_back, num_jac = tf.test.compute_gradient(f, primals, delta=delta)
147  testcase.assertAllClose(num_jac, sym_jac_back, rtol=rtol, atol=atol)
148  sym_jac_fwd = _jacfwd(f, primals)
149  testcase.assertAllClose(num_jac, sym_jac_fwd, rtol=rtol, atol=atol)
150  # And the symbolic computations should be much closer.
151  testcase.assertAllClose(sym_jac_back, sym_jac_fwd)
152
153
154class ForwardpropTest(tf.test.TestCase, parameterized.TestCase):
155
156  @parameterized.named_parameters([
157      ("Dense", [[0.1]], functools.partial(tf.keras.layers.Dense, 5)),
158      ("Conv2D",
159       np.reshape(
160           np.arange(start=-1., stop=1., step=2. / (1 * 2 * 4 * 4)),
161           [1, 2, 4, 4]), functools.partial(tf.keras.layers.Conv2D, 2, 2), 1e-3)
162  ])
163  def testKerasLayers(self, value, op_fn, atol=1e-6):
164    layer = op_fn()
165    input_value = tf.constant(value, dtype=tf.float32)
166    layer.build(input_value.shape)
167    # Make sure the test is deterministic by avoiding random variable
168    # initialization.
169    for v in layer.trainable_variables:
170      v.assign(
171          tf.reshape(
172              tf.range(
173                  -1.,
174                  1.,
175                  2. / tf.size(v, out_type=tf.float32),
176                  dtype=tf.float32), v.shape))
177    _test_gradients(
178        self,
179        layer,
180        [input_value],
181        atol=atol,
182        # These are linear, so second-order is pretty boring.
183        order=2)
184
185  @parameterized.named_parameters([
186      ("NonFused", [[0.1], [0.2], [-0.3]],
187       functools.partial(tf.keras.layers.BatchNormalization, fused=False)),
188      ("Fused", [[[[0.1, 2.]]], [[[0.2, -3.]]], [[[-0.3, 4.]]]],
189       functools.partial(tf.keras.layers.BatchNormalization, fused=True))
190  ])
191  def testBatchNorm(self, value, op_fn):
192    for training in [True, False]:
193      layer = op_fn()
194      input_value = tf.constant(value, dtype=tf.float32)
195      layer.build(input_value.shape)
196      _test_gradients(
197          self,
198          functools.partial(layer, training=training), [input_value],
199          order=2,
200          atol=1e-3)
201
202  @parameterized.named_parameters([
203      ("NonFused", [[0.1], [0.2], [-0.3]],
204       functools.partial(tf.keras.layers.BatchNormalization, fused=False)),
205      ("Fused", [[[[0.1, 2.]]], [[[0.2, -3.]]], [[[-0.3, 4.]]]],
206       functools.partial(tf.keras.layers.BatchNormalization, fused=True))
207  ])
208  def testBatchNormLayerParamGrads(self, value, op_fn):
209    for training in [True, False]:
210      layer = op_fn()
211      with tf.GradientTape() as tape:
212        input_value = tf.constant(value, dtype=tf.float32)
213        tape.watch(input_value)
214        output = layer(input_value, training=training)
215      jac_back = tape.jacobian(output,
216                               [input_value] + layer.trainable_variables)
217      jac_forward = _jacfwd(
218          lambda *args: layer(args[0], training=training),  # pylint:disable=cell-var-from-loop
219          [input_value] + layer.trainable_variables)
220      for backward, forward in zip(jac_back, jac_forward):
221        forward = tf.reshape(forward, tf.shape(backward))
222        self.assertAllClose(backward, forward)
223
224  @parameterized.named_parameters([("Function", tf.function),
225                                   ("NoFunction", lambda f: f)])
226  def testVariablesHVP(self, decorator):
227
228    class _Model(tf.Module):
229
230      def __init__(self):
231        self._first_dense = tf.keras.layers.Dense(18)
232        self._conv = tf.keras.layers.Conv2D(2, 2)
233        self._norm = tf.keras.layers.BatchNormalization()
234        self._second_dense = tf.keras.layers.Dense(1)
235
236      def __call__(self, x):
237        x = self._first_dense(x)
238        x = tf.nn.relu(x)
239        x = self._norm(x)
240        x = tf.nn.relu(self._conv(tf.reshape(x, [-1, 2, 3, 3])))
241        return self._second_dense(x)
242
243    model = _Model()
244
245    def _loss():
246      input_value = tf.constant([[-0.5, 1.], [0.5, -1.]])
247      target = tf.constant([[-1.], [2.]])
248      return tf.math.reduce_sum((model(input_value) - target)**2.)
249
250    @decorator
251    def _compute_hvps():
252      with tf.GradientTape() as tape:
253        loss = _loss()
254      vector = tape.gradient(loss, model.trainable_variables)
255      variable_input_fn = lambda unused_variables: _loss()
256      forward_over_back_hvp, = _hvp(variable_input_fn,
257                                    [model.trainable_variables], [vector])
258      with tf.GradientTape(persistent=True) as tape:
259        tape.watch(model.trainable_variables)
260        loss = _loss()
261        first_grads = tape.gradient(loss, model.trainable_variables)
262      back_over_back_hvp = tape.gradient(
263          first_grads, model.trainable_variables, output_gradients=vector)
264      return forward_over_back_hvp, back_over_back_hvp
265
266    self.assertAllClose(*_compute_hvps(), rtol=1e-5, atol=1e-5)
267
268  def testEmbeddingLayerInFunction(self):
269
270    class M(tf.keras.Model):
271
272      def __init__(self):
273        super(M, self).__init__()
274        self.embed = tf.keras.layers.Embedding(5, 1)
275        self.proj = tf.keras.layers.Dense(1)
276
277      @tf.function
278      def call(self, x):
279        return self.proj(self.embed(x))
280
281    model = M()
282    model(tf.zeros([3, 3], dtype=tf.int32))  # pylint: disable=not-callable
283    parameters = model.embed.variables
284    tangents = [tf.ones_like(v) for v in parameters]
285    with tf.autodiff.ForwardAccumulator(parameters, tangents):
286      # Note that forwardprop runs alongside the original computation. This test
287      # is just checking that it doesn't crash; correctness is tested in core
288      # TF.
289      model(tf.zeros([3, 3], dtype=tf.int32))  # pylint: disable=not-callable
290
291
292class HessianTests(tf.test.TestCase, parameterized.TestCase):
293
294  @parameterized.named_parameters([("PFor", True), ("MapFn", False)])
295  def testHessianOfVariables(self, use_pfor):
296    model = tf.keras.layers.Dense(1)
297    model.build([None, 2])
298
299    def _loss(*unused_args):
300      input_value = tf.constant([[-0.5, 1.], [0.5, -1.]])
301      target = tf.constant([[-1.], [2.]])
302      return tf.math.reduce_sum((model(input_value) - target)**2.)
303
304    kernel_hess, bias_hess = _forward_over_back_hessian(
305        _loss, [model.kernel, model.bias],
306        use_pfor=use_pfor,
307        dtype=[tf.float32, tf.float32])
308    # 3 total parameters, the whole hessian is the 3x3 concatenation
309    self.assertEqual([3, 2, 1], kernel_hess.shape)
310    self.assertEqual([3, 1], bias_hess.shape)
311    full_hessian = tf.concat([tf.reshape(kernel_hess, [3, 2]), bias_hess],
312                             axis=1)
313    # The full Hessian should be symmetric.
314    self.assertAllClose(full_hessian, tf.transpose(full_hessian))
315
316
317if __name__ == "__main__":
318  tf.test.main()
319