1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Test for allowing TF ops to work with Keras Functional API."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import time
21
22from absl.testing import parameterized
23import numpy as np
24
25from tensorflow.python import keras
26from tensorflow.python.framework import ops
27from tensorflow.python.keras import keras_parameterized
28from tensorflow.python.keras import testing_utils
29from tensorflow.python.keras.optimizer_v2 import adam
30from tensorflow.python.ops import array_ops
31from tensorflow.python.ops import gen_nn_ops
32from tensorflow.python.ops import math_ops
33from tensorflow.python.platform import test
34from tensorflow.python.util import nest
35
36
def _single_op_at_end():
  """Model fn: a Dense layer whose output is passed through a raw `relu` op.

  Returns:
    `(inputs, outputs)` symbolic tensors for building a functional model.
  """
  model_in = keras.Input(shape=(10,))
  hidden = keras.layers.Dense(10)(model_in)
  return model_in, gen_nn_ops.relu(hidden)
42
43
def _single_identity_op_at_end():
  """Model fn: a Dense layer whose output is passed through a raw `identity`.

  The assertion is a sanity check that the returned tensor's name contains
  'Identity', i.e. it really is the output of the identity op.

  Returns:
    `(inputs, outputs)` symbolic tensors for building a functional model.
  """
  model_in = keras.Input(shape=(10,))
  result = array_ops.identity(keras.layers.Dense(10)(model_in))
  assert 'Identity' in result.name
  return model_in, result
50
51
def _multiple_ops_at_end():
  """Model fn: a Dense layer followed by two chained raw `relu` ops.

  Returns:
    `(inputs, outputs)` symbolic tensors for building a functional model.
  """
  model_in = keras.Input(shape=(10,))
  tensor = keras.layers.Dense(10)(model_in)
  # Apply relu twice so the model ends in a chain of raw ops.
  for _ in range(2):
    tensor = gen_nn_ops.relu(tensor)
  return model_in, tensor
58
59
def _single_op_in_middle():
  """Model fn: a raw `relu` op sandwiched between two Dense layers.

  Returns:
    `(inputs, outputs)` symbolic tensors for building a functional model.
  """
  model_in = keras.Input(shape=(10,))
  hidden = gen_nn_ops.relu(keras.layers.Dense(10)(model_in))
  return model_in, keras.layers.Dense(10)(hidden)
66
67
def _multiple_ops_in_middle():
  """Model fn: two chained raw `relu` ops between two Dense layers.

  Returns:
    `(inputs, outputs)` symbolic tensors for building a functional model.
  """
  model_in = keras.Input(shape=(10,))
  hidden = keras.layers.Dense(10)(model_in)
  for _ in range(2):
    hidden = gen_nn_ops.relu(hidden)
  return model_in, keras.layers.Dense(10)(hidden)
75
76
def _single_standalone_branch():
  """Model fn: a Dense layer whose output is scaled by 2 with operator `*`.

  Returns:
    `(inputs, outputs)` symbolic tensors for building a functional model.
  """
  model_in = keras.Input(shape=(10,))
  dense_out = keras.layers.Dense(10)(model_in)
  return model_in, dense_out * 2
82
83
def _single_op_with_attrs():
  """Model fn: a `reduce_mean` op with keyword attrs feeding a Dense layer.

  Returns:
    `(inputs, outputs)` symbolic tensors for building a functional model.
  """
  model_in = keras.Input(shape=(10,))
  # Mean over axis 1, keeping that axis so the Dense layer sees a rank-2 input.
  row_mean = math_ops.reduce_mean(model_in, axis=1, keepdims=True)
  return model_in, keras.layers.Dense(10)(row_mean)
89
90
def _multiple_uses():
  """Model fn: one raw `reduce_mean` output consumed by two Dense layers.

  Returns:
    `(inputs, outputs)` symbolic tensors for building a functional model.
  """
  model_in = keras.Input(shape=(10,))
  row_mean = math_ops.reduce_mean(model_in, axis=1, keepdims=True)
  branch_a = keras.layers.Dense(10)(row_mean)
  branch_b = keras.layers.Dense(10)(row_mean)
  return model_in, branch_a + branch_b
98
99
def _op_with_tensor_list():
  """Model fn: a `concat` op that takes a list of tensors, then a Dense layer.

  Returns:
    `(inputs, outputs)` symbolic tensors for building a functional model.
  """
  model_in = keras.Input(shape=(10,))
  # A list holding the same input tensor twice.
  joined = array_ops.concat([model_in] * 2, axis=1)
  return model_in, keras.layers.Dense(10)(joined)
105
106
def _add_n():
  """Model fn: an `add_n` op over a list containing the input three times.

  Returns:
    `(inputs, outputs)` symbolic tensors for building a functional model.
  """
  model_in = keras.Input(shape=(10,))
  return model_in, math_ops.add_n([model_in] * 3)
111
112
def _reuse_op():
  """Model fn: a single raw `relu` whose output feeds two separate branches.

  Returns:
    `(inputs, outputs)` symbolic tensors for building a functional model.
  """
  model_in = keras.Input(shape=(10,))
  # This relu feeds both branches below, so it needs to be checked
  # multiple times.
  shared = gen_nn_ops.relu(model_in)
  left = keras.layers.Dense(10)(shared)
  scaled = shared * 2
  right = keras.layers.Dense(10)(scaled)
  return model_in, left + right
122
123
class LayerWithLayer(keras.layers.Layer):
  """Layer that multiplies its input by a `bias` weight, then applies Dense(10).

  Used to check that TF ops executed inside a layer's own `call` (here the
  multiplication by `self.bias`) do not get Keras History created for them.
  """

  def build(self, input_shape):
    """Creates the `bias` weight and the nested Dense layer.

    Args:
      input_shape: Shape of the input this layer is first called on.
    """
    # No `shape` is passed, so `bias` gets add_weight's default shape
    # (presumably a scalar -- TODO confirm for the TF version in use).
    self.bias = self.add_weight(name='bias', dtype='float32')
    # The nested layer is created lazily, when this layer is first built.
    self.layer = keras.layers.Dense(10)
    # Per the Keras custom-layer convention, finish build() via the base class
    # so the layer is marked as built.
    super(LayerWithLayer, self).build(input_shape)

  def call(self, inputs):
    """Scales `inputs` by `bias` and applies the nested Dense layer."""
    inputs = inputs * self.bias
    # Would throw an error if Keras History was created here.
    return self.layer(inputs)
134
135
def _inner_layer():
  """Model fn: a single `LayerWithLayer` applied to the input.

  Returns:
    `(inputs, outputs)` symbolic tensors for building a functional model.
  """
  model_in = keras.Input(shape=(10,))
  return model_in, LayerWithLayer()(model_in)
140
141
@keras_parameterized.run_all_keras_modes
class AutoLambdaTest(keras_parameterized.TestCase):
  """Tests that raw TF ops work on Keras symbolic tensors in functional models.

  Every test runs in each execution mode selected by
  `keras_parameterized.run_all_keras_modes`.
  """

  @parameterized.named_parameters(
      ('single_op_at_end', _single_op_at_end),
      ('single_identity_op_at_end', _single_identity_op_at_end),
      ('multiple_ops_at_end', _multiple_ops_at_end),
      ('single_op_in_middle', _single_op_in_middle),
      ('multiple_ops_in_middle', _multiple_ops_in_middle),
      ('single_standalone_branch', _single_standalone_branch),
      ('single_op_with_attrs', _single_op_with_attrs),
      ('multiple_uses', _multiple_uses),
      ('op_with_tensor_list', _op_with_tensor_list),
      ('add_n', _add_n),
      ('reuse_op', _reuse_op),
      ('inner_layer', _inner_layer))
  def test_autolambda(self, model_fn):
    """Fits, calls, and config-round-trips the model built by `model_fn`.

    Args:
      model_fn: Zero-argument callable returning `(inputs, outputs)` symbolic
        tensors for a functional model.
    """
    inputs, outputs = model_fn()
    model = keras.Model(inputs, outputs)
    model.compile(
        adam.Adam(0.001), 'mse', run_eagerly=testing_utils.should_run_eagerly())

    # Every model_fn above declares 10 input features and ends in a 10-wide
    # tensor, so (10, 10) all-ones arrays fit both inputs and targets.
    np_inputs = nest.map_structure(lambda _: np.ones((10, 10), 'float32'),
                                   inputs)
    np_outputs = nest.map_structure(lambda _: np.ones((10, 10), 'float32'),
                                    outputs)
    model.fit(np_inputs, np_outputs, batch_size=2)
    model(np_inputs)  # Test calling the model directly on inputs.

    new_model = keras.Model.from_config(
        model.get_config(), custom_objects={'LayerWithLayer': LayerWithLayer})
    new_model.compile(
        adam.Adam(0.001), 'mse', run_eagerly=testing_utils.should_run_eagerly())
    new_model.fit(np_inputs, np_outputs, batch_size=2)
    new_model(np_inputs)  # Test calling the new model directly on inputs.

  def test_numerical_correctness_simple(self):
    """Checks a model consisting only of a raw relu computes relu."""
    x = ops.convert_to_tensor([[-1., 0., -2., 1.]])
    inputs = keras.Input(shape=(4,))
    outputs = gen_nn_ops.relu(inputs)
    model = keras.Model(inputs, outputs)
    y = self.evaluate(model(x))
    self.assertAllClose(y, [[0., 0., 0., 1.]])

  def test_numerical_correctness_with_attrs(self):
    """Checks an op with keyword attrs (reduce_mean over axis 1) is exact."""
    x = ops.convert_to_tensor([[1.5, 1.5], [2.5, 3.5]])
    # The declared feature size must match `x`, which has 2 features per row.
    inputs = keras.Input(shape=(2,))
    outputs = math_ops.reduce_mean(inputs, axis=1)
    model = keras.Model(inputs, outputs)
    y = self.evaluate(model(x))
    self.assertAllClose(y, [1.5, 3.])

  def test_numerical_correctness_serialization(self):
    """Checks a model rebuilt from its config still computes relu."""
    # Rank 2 (batch of one) to match `keras.Input(shape=(4,))`.
    x = ops.convert_to_tensor([[-1., 0., -2., 1.]])
    inputs = keras.Input(shape=(4,))
    outputs = gen_nn_ops.relu(inputs)
    model1 = keras.Model(inputs, outputs)
    y1 = self.evaluate(model1(x))
    # `from_config` is a classmethod. Call it on the class to make clear that
    # a brand-new model is constructed.
    model2 = keras.Model.from_config(model1.get_config())
    y2 = self.evaluate(model2(x))
    self.assertAllClose(y1, [[0., 0., 0., 1.]])
    self.assertAllClose(y1, y2)

  def test_no_tracking(self):
    """Checks a tensor fed to a layer is flagged as history-checked."""
    x = keras.backend.placeholder((10, 10))
    keras.layers.Dense(1)(x)
    self.assertTrue(x._keras_history_checked)

  def test_timing_scales_linearly(self):
    """Checks graph construction time grows roughly linearly with depth."""

    def _construct_graph_of_size(size):
      """Returns seconds spent stacking `size` Dense+relu pairs."""
      start = time.time()
      x = keras.backend.placeholder(shape=(10, 4))

      for _ in range(size):
        x = keras.layers.Dense(4)(x)
        x = gen_nn_ops.relu(x)

      end = time.time()
      return end - start

    size_50 = _construct_graph_of_size(50)
    size_500 = _construct_graph_of_size(500)

    # Check construction time grows approx. linearly with size.
    e = 3  # Fudge factor to prevent flakiness.
    self.assertLess(size_500, (10 * e) * size_50)

  def test_no_mask_tracking(self):
    """Checks a layer's output mask is flagged as history-checked."""
    x = keras.backend.placeholder((10, 10))
    y = keras.layers.Masking(0.)(x)
    self.assertTrue(y._keras_mask._keras_history_checked)

  def test_built(self):
    """Checks all layers of a model ending in a raw TF op are built."""
    inputs = keras.Input(shape=(10,))
    outputs = gen_nn_ops.relu(inputs)
    model = keras.Model(inputs, outputs)
    # Honor the execution mode under test, like test_autolambda does.
    model.compile(
        'sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
    for layer in model.layers:
      self.assertTrue(layer.built)
    # Test something that requires Layers to be built.
    model.summary()
242
if __name__ == '__main__':
  # Run this module's tests through TensorFlow's test entry point.
  test.main()
245