1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Test utils for composite op definition."""
15from __future__ import absolute_import
16from __future__ import division
17from __future__ import print_function
18
19from tensorflow.python.eager import backprop
20from tensorflow.python.platform import test
21
22
class OpsDefsTest(test.TestCase):
  """Test utils for checking an op against its composite definition."""

  def _computeWithGradients(self, vars_, compute_fn, kwargs):
    """Calls `compute_fn(**kwargs)` under a gradient tape and differentiates.

    Args:
      vars_: Iterable of values to watch on the tape; gradients are taken with
        respect to it.
      compute_fn: Callable invoked as `compute_fn(**kwargs)`.
      kwargs: Dict of keyword arguments forwarded to `compute_fn`.

    Returns:
      An `(outputs, grads)` tuple, where `outputs` is the return value of
      `compute_fn` and `grads` is `tape.gradient(outputs, vars_)`.
    """
    with backprop.GradientTape() as tape:
      for var_ in vars_:
        tape.watch(var_)
      outputs = compute_fn(**kwargs)
      grads = tape.gradient(outputs, vars_)
    return outputs, grads

  def _assertAllPairsClose(self, expected, actual, what):
    """Asserts two iterables have equal length and pairwise-close elements.

    Args:
      expected: Iterable of reference values.
      actual: Iterable of values to compare against `expected`.
      what: Short description of the compared values, used in the failure
        message of the length check.
    """
    # Materialize both sides so their lengths can be compared: a bare `zip`
    # stops at the shorter input, which would silently skip the unmatched
    # trailing elements and let a mismatch pass unnoticed.
    expected = list(expected)
    actual = list(actual)
    self.assertEqual(
        len(expected), len(actual),
        'Op and composite produced a different number of %s.' % what)
    for expected_value, actual_value in zip(expected, actual):
      self.assertAllClose(expected_value, actual_value)

  def _assertOpAndComposite(self, vars_, compute_op, compute_composite, kwargs,
                            op_kwargs=None):
    """Asserts that an op and its composite agree on outputs and gradients.

    Both callables are run under their own gradient tape with every element of
    `vars_` watched. Their outputs, and the gradients of those outputs with
    respect to `vars_`, are then compared element by element with
    `assertAllClose`.

    Args:
      vars_: Iterable of values to watch and differentiate with respect to.
      compute_op: Callable for the op path, invoked as
        `compute_op(**op_kwargs)`.
      compute_composite: Callable for the composite path, invoked as
        `compute_composite(**kwargs)`.
      kwargs: Dict of keyword arguments for `compute_composite`.
      op_kwargs: Optional dict of keyword arguments for `compute_op`. Defaults
        to `kwargs` when None.
    """
    if op_kwargs is None:
      op_kwargs = kwargs

    # Compute with the op: uses the op, which is decomposed by the graph pass,
    # and the registered gradient function.
    y, grads = self._computeWithGradients(vars_, compute_op, op_kwargs)

    # Compute with the composition: uses the composite function and the
    # gradients of the composite function.
    re_y, re_grads = self._computeWithGradients(
        vars_, compute_composite, kwargs)

    self._assertAllPairsClose(y, re_y, 'outputs')
    self._assertAllPairsClose(grads, re_grads, 'gradients')
49