# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Keras TF utils."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python import keras
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


@test_util.run_all_in_graph_and_eager_modes
class TestIsSymbolicTensor(test.TestCase):
  """Tests for `tf_utils.is_symbolic_tensor`.

  The class decorator runs every `test*` method in both graph and eager mode.
  The tensor-like values probed here are expected to be reported as symbolic
  in graph mode and as non-symbolic when executing eagerly.
  """

  def _make_builtin_tensor_likes(self):
    """Returns a `Variable`, a dense `Tensor` and a `SparseTensor`, in order."""
    return [
        variables.Variable(name='blah', initial_value=0.),
        ops.convert_to_tensor(0.),
        sparse_tensor.SparseTensor(
            indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]),
    ]

  def _assert_symbolic_only_in_graph_mode(self, value):
    """Asserts `is_symbolic_tensor(value)` is false eagerly, true in graphs."""
    if context.executing_eagerly():
      self.assertFalse(tf_utils.is_symbolic_tensor(value))
    else:
      self.assertTrue(tf_utils.is_symbolic_tensor(value))

  def test_default_behavior(self):
    """Built-in tensor-likes are symbolic exactly when building a graph."""
    for candidate in self._make_builtin_tensor_likes():
      self._assert_symbolic_only_in_graph_mode(candidate)

  def test_works_with_registered(self):
    """A user type registered as symbolic behaves like built-in tensors."""

    class CustomClass(object):

      def value(self):
        return ops.convert_to_tensor(42.)

    # Make `CustomClass` convertible to a tensor, then register it as a
    # symbolic tensor type. NOTE(review): neither registration is undone at
    # the end of the test; presumably both persist for the rest of the process.
    ops.register_tensor_conversion_function(
        CustomClass, lambda value, **_: value.value())

    tf_utils.register_symbolic_tensor_type(CustomClass)

    for candidate in self._make_builtin_tensor_likes() + [CustomClass()]:
      self._assert_symbolic_only_in_graph_mode(candidate)

  def test_enables_nontensor_plumbing(self):
    """A registered non-Tensor type flows through a Keras model unchanged."""
    # Setup.

    class Foo(object):
      # Minimal non-Tensor type: holds a constant tensor in `value` and
      # exposes its dtype. `input_` is stored but not otherwise used here.

      def __init__(self, input_):
        self._input = input_
        self.value = ops.convert_to_tensor(42.)

      @property
      def dtype(self):
        return self.value.dtype

    # `Foo` converts to a tensor via its `value` attribute and is registered
    # as a symbolic tensor type.
    ops.register_tensor_conversion_function(
        Foo, lambda x, *args, **kwargs: x.value)
    tf_utils.register_symbolic_tensor_type(Foo)

    class PlumbingLayer(keras.layers.Lambda):
      # Lambda layer whose function output `d` is paired with its tensor
      # conversion `x`. `d` borrows `shape`/`get_shape` from `x` (presumably
      # so Keras can treat it like a tensor -- confirm).

      def __init__(self, fn, **kwargs):
        def _fn(*fargs, **fkwargs):
          d = fn(*fargs, **fkwargs)
          x = ops.convert_to_tensor(d)
          d.shape = x.shape
          d.get_shape = x.get_shape
          return d, x
        super(PlumbingLayer, self).__init__(_fn, **kwargs)
        self._enter_dunder_call = False

      def __call__(self, inputs, *args, **kwargs):
        # While inside `__call__`, `call` returns the `(d, tensor)` pair; only
        # `d` is handed back to the caller.
        self._enter_dunder_call = True
        d, _ = super(PlumbingLayer, self).__call__(inputs, *args, **kwargs)
        self._enter_dunder_call = False
        return d

      def call(self, inputs, *args, **kwargs):
        d, v = super(PlumbingLayer, self).call(inputs, *args, **kwargs)
        if self._enter_dunder_call:
          return d, v
        # Invoked outside `__call__`: return only the `Foo` object.
        return d

    # User-land.
    model = keras.Sequential([
        keras.layers.InputLayer([]),
        PlumbingLayer(Foo),  # Makes a `Foo` object.
    ])
    # Let's ensure Keras graph history is preserved by composing the models.
    model = keras.Model(model.inputs, model(model.outputs))
    # Now we instantiate the model and verify we have a `Foo` object, not a
    # `Tensor`.
    y = model(ops.convert_to_tensor(7.))
    self.assertIsInstance(y, Foo)
    # Confirm that (custom) loss sees `Foo` instance, not Tensor.
    # A one-element list is used as a mutable cell so the nested function can
    # record its argument without `nonlocal` (keeps Python 2 compatibility).
    obtained_prediction_box = [None]
    def custom_loss(y_obs, y_pred):
      del y_obs
      obtained_prediction_box[0] = y_pred
      return y_pred
    # `compile` apparently invokes the loss function, which records `y_pred`
    # through the side effect above.
    model.compile('SGD', loss=custom_loss)
    self.assertIsInstance(obtained_prediction_box[0], Foo)


class ConvertInnerNodeDataTest(test.TestCase):
  """Tests for `tf_utils.convert_inner_node_data`."""

  def test_convert_inner_node_data(self):
    # Unwrapping: `ListWrapper` leaves are replaced by the lists they wrap.
    data = tf_utils.convert_inner_node_data((tf_utils.ListWrapper(['l', 2, 3]),
                                             tf_utils.ListWrapper(['l', 5, 6])))
    self.assertEqual(data, (['l', 2, 3], ['l', 5, 6]))

    # Wrapping: each node-data list should come back as a `ListWrapper`.
    data = tf_utils.convert_inner_node_data((['l', 2, 3], ['l', 5, 6]),
                                            wrap=True)
    # Check the length explicitly: an `all(...)` check over an empty result
    # would pass vacuously.
    self.assertEqual(len(data), 2)
    for element in data:
      self.assertIsInstance(element, tf_utils.ListWrapper)
    # Unwrapping the wrapped result must recover the original lists.
    self.assertEqual(tf_utils.convert_inner_node_data(data),
                     (['l', 2, 3], ['l', 5, 6]))


if __name__ == '__main__':
  test.main()