# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for automatic outside compilation for TF 2.0/Keras."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl import flags
import numpy as np

from tensorboard.plugins.histogram import summary_v2 as histogram_summary_v2
from tensorboard.plugins.image import summary_v2 as image_summary_v2
from tensorboard.plugins.scalar import summary_v2 as scalar_summary_v2
from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import tpu_strategy as tpu_strategy_lib
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import def_function
from tensorflow.python.eager import remote
from tensorflow.python.eager.context import set_soft_device_placement
from tensorflow.python.framework import ops
from tensorflow.python.keras import callbacks
from tensorflow.python.keras import initializers
from tensorflow.python.keras.distribute import distribute_strategy_test
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import sequential as sequential_model_lib
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import convolutional as conv_layer_lib
from tensorflow.python.keras.layers import core as layer_lib
from tensorflow.python.keras.layers import pooling as pool_layer_lib
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.platform import test
from tensorflow.python.summary import summary_iterator
from tensorflow.python.tpu import tpu_strategy_util

NUM_CLASSES = 4

FLAGS = flags.FLAGS
flags.DEFINE_string('tpu', '', 'Name of TPU to connect to.')
flags.DEFINE_string('project', None, 'Name of GCP project with TPU.')
flags.DEFINE_string('zone', None, 'Name of GCP zone with TPU.')
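# These command-line flags identify the TPU worker the tests run against. A
# typical invocation might pass values such as
#   --tpu=my-tpu-name --zone=us-central1-b --project=my-gcp-project
# (placeholder names for illustration, not real resources).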


def get_tpu_cluster_resolver():
  """Builds a TPUClusterResolver from the --tpu/--zone/--project flags."""
  resolver = tpu_cluster_resolver.TPUClusterResolver(
      tpu=FLAGS.tpu,
      zone=FLAGS.zone,
      project=FLAGS.project,
  )
  return resolver


def get_tpu_strategy():
  """Connects to and initializes the TPU system, returning a TPUStrategy."""
  resolver = get_tpu_cluster_resolver()
  remote.connect_to_cluster(resolver)
  tpu_strategy_util.initialize_tpu_system(resolver)
  return tpu_strategy_lib.TPUStrategy(resolver)


class LayerForScalarSummary(base_layer.Layer):
  """A pass-through layer that only records scalar values to summary."""

  def call(self, x):
    # Add summary scalar using compat v2 implementation.
    scalar_summary_v2.scalar('custom_scalar_summary_v2', math_ops.reduce_sum(x))
    return x


class LayerForImageSummary(base_layer.Layer):
  """A pass-through layer that only records image values to summary."""

  def call(self, x):
    # Add summary image using compat v2 implementation.
    image_summary_v2.image('custom_image_summary_v2', x)

    return x


class LayerForHistogramSummary(base_layer.Layer):
  """A pass-through layer that records histogram values to summary."""

  def call(self, x):
    # Add summary histogram using compat v2 implementation.
    histogram_summary_v2.histogram('custom_histogram_summary_v2', x)

    return x

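# Note: the summary ops used by the pass-through layers above are not
# supported inside TPU-compiled computations. With soft device placement
# enabled (see `setUp` in the test class below), they are expected to be
# wrapped in outside compilation automatically and executed on the host,
# which is the behavior these tests exercise.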

class CustomModel(training.Model):
  """Custom model with summary ops in model call definition."""

  def __init__(self, name=None):
    super(CustomModel, self).__init__()
    self._my_layers = [
        layer_lib.Dense(
            4096,
            name='dense1',
            kernel_initializer=initializers.glorot_normal(seed=0),
            use_bias=False),
        layer_lib.Dense(
            4,
            name='dense2',
            kernel_initializer=initializers.glorot_normal(seed=0),
            use_bias=False),
    ]
    self.histogram_summary_layer = LayerForHistogramSummary()
    self.scalar_summary_layer = LayerForScalarSummary()

  def call(self, x):
    for layer in self._my_layers:
      x = layer(x)
    x = self.scalar_summary_layer(x)
    return self.histogram_summary_layer(x)


def get_image_dataset():
  inputs = np.zeros((10, 28, 28, 3), dtype=np.float32)
  targets = np.zeros((10, NUM_CLASSES), dtype=np.float32)
  dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
  dataset = dataset.repeat(100)
  dataset = dataset.batch(10, drop_remainder=True)
  return dataset


def mnist_model(input_shape):
  """Creates an MNIST model."""
  model = sequential_model_lib.Sequential()

  # Add a custom pass-through layer to visualize input images.
  model.add(LayerForImageSummary())

  model.add(
      conv_layer_lib.Conv2D(
          32, kernel_size=(3, 3), activation='relu', input_shape=input_shape))
  model.add(conv_layer_lib.Conv2D(64, (3, 3), activation='relu'))
  model.add(pool_layer_lib.MaxPooling2D(pool_size=(2, 2)))
  model.add(layer_lib.Dropout(0.25))
  model.add(layer_lib.Flatten())
  model.add(layer_lib.Dense(128, activation='relu'))
  model.add(layer_lib.Dropout(0.5))
  model.add(layer_lib.Dense(NUM_CLASSES, activation='softmax'))

  # Add a custom pass-through layer for summary recording.
  model.add(LayerForHistogramSummary())
  return model


class AutoOutsideCompilationWithKerasTest(test.TestCase):

  def setUp(self):
    super(AutoOutsideCompilationWithKerasTest, self).setUp()
    v2_compat.enable_v2_behavior()
    set_soft_device_placement(True)
    self.summary_dir = self.get_temp_dir()

  def validate_recorded_summary_file(self, event_files, summary_dict,
                                     expected_count):
    """Checks that each tagged summary was recorded `expected_count` times."""
    for event_file in event_files:
      for e in summary_iterator.summary_iterator(event_file):
        for v in e.summary.value:
          if v.tag in summary_dict:
            summary_dict[v.tag] += 1

    for key in summary_dict:
      self.assertEqual(summary_dict[key], expected_count)

  def testV2SummaryWithKerasSequentialModel(self):
    strategy = get_tpu_strategy()

    with strategy.scope():
      model = mnist_model((28, 28, 3))
      model.compile('sgd', 'mse')

      dataset = get_image_dataset()
      tensorboard_callback = callbacks.TensorBoard(
          self.summary_dir, update_freq=2)
      model.fit(
          dataset,
          steps_per_epoch=10,
          epochs=1,
          callbacks=[tensorboard_callback])

      events_count_dictionary = {
          'sequential/layer_for_histogram_summary/custom_histogram_summary_v2':
              0,
          'sequential/layer_for_image_summary/custom_image_summary_v2':
              0,
      }

      event_files = file_io.get_matching_files_v2(
          os.path.join(self.summary_dir, 'train', 'event*'))
      # Since a total of 10 steps are run and summary ops are invoked every
      # 2 batches, we expect each summary tag to be recorded 5 times.
      self.validate_recorded_summary_file(event_files, events_count_dictionary,
                                          5)

  def testV2SummaryWithKerasSubclassedModel(self):
    strategy = get_tpu_strategy()

    with strategy.scope():
      model = CustomModel()
      model.compile('sgd', 'mse')

      dataset = distribute_strategy_test.get_dataset(strategy)
      tensorboard_callback = callbacks.TensorBoard(
          self.summary_dir, update_freq=2)
      model.fit(
          dataset,
          steps_per_epoch=10,
          epochs=1,
          callbacks=[tensorboard_callback])

      event_files = file_io.get_matching_files_v2(
          os.path.join(self.summary_dir, 'train', 'event*'))
      events_count_dictionary = {
          ('custom_model/layer_for_scalar_summary/'
           'custom_scalar_summary_v2'):
              0,
          ('custom_model/layer_for_histogram_summary/'
           'custom_histogram_summary_v2'):
              0
      }

      # Since a total of 10 steps are run and summary ops are invoked every
      # 2 batches, we expect each summary tag to be recorded 5 times.
      self.validate_recorded_summary_file(event_files, events_count_dictionary,
                                          5)

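  # Unlike the Keras-driven tests above, the test below uses a custom training
  # loop: it creates its own summary file writer and enables recording with
  # `record_if(True)` inside the replica step. The scalar summary emitted there
  # should still be handled by automatic outside compilation.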
  def testSummaryWithCustomTrainingLoop(self):
    strategy = get_tpu_strategy()

    writer = summary_ops_v2.create_file_writer_v2(self.summary_dir)
    with strategy.scope():
      model = distribute_strategy_test.get_model()
      model.compile('sgd', 'mse')

    @def_function.function
    def custom_function(dataset):

      def _custom_step(features, labels):
        del labels
        logits = model(features)
        with summary_ops_v2.record_if(True), writer.as_default():
          scalar_summary_v2.scalar(
              'logits',
              math_ops.reduce_sum(logits),
              step=model.optimizer.iterations)
        return logits

      iterator = iter(dataset)
      # `next(iterator)` yields a (features, labels) tuple, which supplies the
      # positional arguments of `_custom_step`.
      output = strategy.unwrap(
          strategy.run(_custom_step, args=(next(iterator))))
      return output

    dataset = strategy.experimental_distribute_dataset(
        distribute_strategy_test.get_dataset(strategy))

    custom_function(dataset)
    writer.close()

    event_files = file_io.get_matching_files_v2(
        os.path.join(self.summary_dir, 'event*'))
    events_count_dictionary = {
        'logits': 0,
    }
    # The custom function runs a single step, so the scalar summary should be
    # recorded exactly once.
    self.validate_recorded_summary_file(event_files, events_count_dictionary,
                                        1)


if __name__ == '__main__':
  ops.enable_eager_execution()
  test.main()