1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for Keras TF utils."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python import keras
22from tensorflow.python.eager import context
23from tensorflow.python.framework import ops
24from tensorflow.python.framework import sparse_tensor
25from tensorflow.python.framework import test_util
26from tensorflow.python.keras.utils import tf_utils
27from tensorflow.python.ops import variables
28from tensorflow.python.platform import test
29
30
@test_util.run_all_in_graph_and_eager_modes
class TestIsSymbolicTensor(test.TestCase):
  """Tests for `tf_utils.is_symbolic_tensor` and symbolic-type registration.

  Each test runs in both graph and eager modes: values built in graph mode are
  expected to be reported as symbolic, values built eagerly are not.
  """

  def _probe_values(self):
    """Builds the built-in values whose classification is checked.

    Returns:
      A list holding a `Variable`, a dense `Tensor` and a `SparseTensor`, in
      that order, created in whatever execution mode is currently active.
    """
    return [
        variables.Variable(name='blah', initial_value=0.),
        ops.convert_to_tensor(0.),
        sparse_tensor.SparseTensor(
            indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]),
    ]

  def _assert_symbolic(self, value, expected):
    """Asserts `tf_utils.is_symbolic_tensor(value)` is truthy iff `expected`.

    Args:
      value: The object to classify.
      expected: Whether `value` should be reported as symbolic.
    """
    # Name the type in the failure message so a loop failure is traceable.
    msg = 'is_symbolic_tensor(%s)' % type(value).__name__
    if expected:
      self.assertTrue(tf_utils.is_symbolic_tensor(value), msg=msg)
    else:
      self.assertFalse(tf_utils.is_symbolic_tensor(value), msg=msg)

  def test_default_behavior(self):
    # Graph-mode values are symbolic; eagerly created values are not.
    expected = not context.executing_eagerly()
    for value in self._probe_values():
      self._assert_symbolic(value, expected)

  def test_works_with_registered(self):

    class CustomClass(object):

      def value(self):
        return ops.convert_to_tensor(42.)

    ops.register_tensor_conversion_function(
        CustomClass, lambda value, **_: value.value())

    tf_utils.register_symbolic_tensor_type(CustomClass)

    # Once registered, `CustomClass` must be classified exactly like the
    # built-in tensor types in the current execution mode.
    expected = not context.executing_eagerly()
    for value in self._probe_values() + [CustomClass()]:
      self._assert_symbolic(value, expected)

  def test_enables_nontensor_plumbing(self):
    """Checks that a registered non-Tensor type flows through a Keras model."""
    # Setup: `Foo` converts to a tensor and is registered as a symbolic type,
    # so Keras should hand it through layers and losses unchanged.

    class Foo(object):

      def __init__(self, input_):
        self._input = input_
        self.value = ops.convert_to_tensor(42.)

      @property
      def dtype(self):
        return self.value.dtype

    ops.register_tensor_conversion_function(
        Foo, lambda x, *args, **kwargs: x.value)
    tf_utils.register_symbolic_tensor_type(Foo)

    class PlumbingLayer(keras.layers.Lambda):
      """`Lambda` layer whose output is a `Foo` rather than a `Tensor`."""

      def __init__(self, fn, **kwargs):
        def _fn(*fargs, **fkwargs):
          d = fn(*fargs, **fkwargs)
          x = ops.convert_to_tensor(d)
          # Give the `Foo` a tensor-like shape interface.
          d.shape = x.shape
          d.get_shape = x.get_shape
          return d, x
        super(PlumbingLayer, self).__init__(_fn, **kwargs)
        # True only while `__call__` is running; tells `call` whether to
        # return the `(Foo, Tensor)` pair or just the `Foo`.
        self._enter_dunder_call = False

      def __call__(self, inputs, *args, **kwargs):
        self._enter_dunder_call = True
        try:
          d, _ = super(PlumbingLayer, self).__call__(inputs, *args, **kwargs)
        finally:
          # Reset even if the base `__call__` raises, so later direct calls
          # to `call` do not see a stale flag.
          self._enter_dunder_call = False
        return d

      def call(self, inputs, *args, **kwargs):
        d, v = super(PlumbingLayer, self).call(inputs, *args, **kwargs)
        if self._enter_dunder_call:
          return d, v
        # NOTE(review): presumably reached when Keras invokes `call` directly
        # rather than through `__call__` — confirm.
        return d

    # User-land.
    model = keras.Sequential([
        keras.layers.InputLayer([]),
        PlumbingLayer(Foo),  # Makes a `Foo` object.
    ])
    # Compose the model with itself to ensure Keras graph history is preserved.
    model = keras.Model(model.inputs, model(model.outputs))
    # Call the model on a concrete input and verify the output is a `Foo`
    # object, not a `Tensor`.
    y = model(ops.convert_to_tensor(7.))
    self.assertIsInstance(y, Foo)
    # Confirm that a (custom) loss sees a `Foo` instance, not a Tensor.
    # A one-element list lets the nested function record its argument
    # without `nonlocal` (Python 2 compatible).
    obtained_prediction_box = [None]
    def custom_loss(y_obs, y_pred):
      del y_obs
      obtained_prediction_box[0] = y_pred
      return y_pred
    # Relies on `compile` invoking `custom_loss` on the model output, which
    # records `y_pred` as a side effect.
    model.compile('SGD', loss=custom_loss)
    self.assertIsInstance(obtained_prediction_box[0], Foo)
145
146
class ConvertInnerNodeDataTest(test.TestCase):
  """Tests for `tf_utils.convert_inner_node_data`."""

  def test_convert_inner_node_data(self):
    # Without `wrap`, the result compares equal to the plain-list tuple.
    data = tf_utils.convert_inner_node_data((tf_utils.ListWrapper(['l', 2, 3]),
                                             tf_utils.ListWrapper(['l', 5, 6])))
    self.assertEqual(data, (['l', 2, 3], ['l', 5, 6]))

    # With `wrap=True`, every element must come back as a `ListWrapper`.
    # Check the length first so an empty result cannot pass vacuously, then
    # assert per element so a failure identifies the offending value.
    data = tf_utils.convert_inner_node_data((['l', 2, 3], ['l', 5, 6]),
                                            wrap=True)
    self.assertEqual(len(data), 2)
    for ele in data:
      self.assertIsInstance(ele, tf_utils.ListWrapper)
157
158
if __name__ == '__main__':
  # Run the tests in this module via the TensorFlow test runner.
  test.main()
161