# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for automatic outside compilation for TF 2.0/Keras."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl import flags
import numpy as np

from tensorboard.plugins.histogram import summary_v2 as histogram_summary_v2
from tensorboard.plugins.image import summary_v2 as image_summary_v2
from tensorboard.plugins.scalar import summary_v2 as scalar_summary_v2
from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import tpu_strategy as tpu_strategy_lib
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import def_function
from tensorflow.python.eager import remote
from tensorflow.python.eager.context import set_soft_device_placement
from tensorflow.python.framework import ops
from tensorflow.python.keras import callbacks
from tensorflow.python.keras import initializers
from tensorflow.python.keras.distribute import distribute_strategy_test
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import sequential as sequential_model_lib
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import convolutional as conv_layer_lib
from tensorflow.python.keras.layers import core as layer_lib
from tensorflow.python.keras.layers import pooling as pool_layer_lib
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.platform import test
from tensorflow.python.summary import summary_iterator
from tensorflow.python.tpu import tpu_strategy_util

# Width of the (all-zero) target vectors and of the final Dense layer.
NUM_CLASSES = 4

FLAGS = flags.FLAGS
flags.DEFINE_string('tpu', '', 'Name of TPU to connect to.')
flags.DEFINE_string('project', None, 'Name of GCP project with TPU.')
flags.DEFINE_string('zone', None, 'Name of GCP zone with TPU.')


def get_tpu_cluster_resolver():
  """Returns a TPUClusterResolver built from the --tpu/--zone/--project flags."""
  resolver = tpu_cluster_resolver.TPUClusterResolver(
      tpu=FLAGS.tpu,
      zone=FLAGS.zone,
      project=FLAGS.project,
  )
  return resolver


def get_tpu_strategy():
  """Connects to the flag-selected TPU, initializes it, returns a TPUStrategy."""
  resolver = get_tpu_cluster_resolver()
  remote.connect_to_cluster(resolver)
  tpu_strategy_util.initialize_tpu_system(resolver)
  return tpu_strategy_lib.TPUStrategy(resolver)


class LayerForScalarSummary(base_layer.Layer):
  """A pass-through layer that only records scalar values to summary."""

  def call(self, x):
    # Record the sum of the activations as a scalar summary (compat v2 API).
    # No explicit step is passed; it presumably comes from the surrounding
    # summary context (e.g. the Keras TensorBoard callback).
    scalar_summary_v2.scalar('custom_scalar_summary_v2', math_ops.reduce_sum(x))
    return x


class LayerForImageSummary(base_layer.Layer):
  """A pass-through layer that only records image values to summary."""

  def call(self, x):
    # Record the input as an image summary (compat v2 API). In this file it
    # receives the batched (28, 28, 3) images produced by get_image_dataset().
    image_summary_v2.image('custom_image_summary_v2', x)

    return x


class LayerForHistogramSummary(base_layer.Layer):
  """A pass-through layer that records histogram values to summary."""

  def call(self, x):
    # Record the distribution of the activations as a histogram (compat v2 API).
    histogram_summary_v2.histogram('custom_histogram_summary_v2', x)

    return x


class CustomModel(training.Model):
  """Custom model with summary ops in model call definition.

  Two bias-free Dense layers (4096 then 4 units) are followed by a scalar
  summary layer and a histogram summary layer.
  """

  def __init__(self, name=None):
    """Creates the model.

    Args:
      name: Optional model name. It is forwarded to the Keras Model so the
        argument is no longer silently ignored. When None, Keras picks its
        default name ('custom_model'), which the summary tags checked in
        testV2SummaryWithKerasSubclassedModel rely on.
    """
    super(CustomModel, self).__init__(name=name)
    self._my_layers = [
        layer_lib.Dense(
            4096,
            name='dense1',
            kernel_initializer=initializers.glorot_normal(seed=0),
            use_bias=False),
        layer_lib.Dense(
            4,
            name='dense2',
            kernel_initializer=initializers.glorot_normal(seed=0),
            use_bias=False),
    ]
    self.histogram_summary_layer = LayerForHistogramSummary()
    self.scalar_summary_layer = LayerForScalarSummary()

  def call(self, x):
    for layer in self._my_layers:
      x = layer(x)
    # Each summary layer is applied once per call, after the dense stack.
    x = self.scalar_summary_layer(x)
    return self.histogram_summary_layer(x)


def get_image_dataset():
  """Returns a dataset of all-zero 28x28x3 images and all-zero targets.

  The 10 (image, target) examples are repeated 100 times and grouped into
  batches of 10 with the remainder dropped, which gives 100 batches.
  """
  inputs = np.zeros((10, 28, 28, 3), dtype=np.float32)
  targets = np.zeros((10, NUM_CLASSES), dtype=np.float32)
  dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
  dataset = dataset.repeat(100)
  dataset = dataset.batch(10, drop_remainder=True)
  return dataset


def mnist_model(input_shape):
  """Creates a MNIST-style CNN wrapped by summary pass-through layers.

  Args:
    input_shape: Per-example input shape, passed to the first Conv2D layer.

  Returns:
    An uncompiled Sequential model: image summary -> Conv2D/Conv2D/MaxPool/
    Dropout/Flatten/Dense/Dropout/Dense(softmax) -> histogram summary.
  """
  model = sequential_model_lib.Sequential()

  # Adding custom pass-through layer to visualize input images.
  model.add(LayerForImageSummary())

  model.add(
      conv_layer_lib.Conv2D(
          32, kernel_size=(3, 3), activation='relu', input_shape=input_shape))
  model.add(conv_layer_lib.Conv2D(64, (3, 3), activation='relu'))
  model.add(pool_layer_lib.MaxPooling2D(pool_size=(2, 2)))
  model.add(layer_lib.Dropout(0.25))
  model.add(layer_lib.Flatten())
  model.add(layer_lib.Dense(128, activation='relu'))
  model.add(layer_lib.Dropout(0.5))
  model.add(layer_lib.Dense(NUM_CLASSES, activation='softmax'))

  # Adding custom pass-through layer for summary recording.
  model.add(LayerForHistogramSummary())
  return model


class AutoOutsideCompilationWithKerasTest(test.TestCase):
  """Checks that summary ops inside TPU-compiled Keras code get recorded."""

  def setUp(self):
    super(AutoOutsideCompilationWithKerasTest, self).setUp()
    v2_compat.enable_v2_behavior()
    # Soft placement is presumably needed so ops without TPU kernels
    # (the summary writes) can be placed outside the TPU computation.
    set_soft_device_placement(True)
    self.summary_dir = self.get_temp_dir()

  def validate_recorded_sumary_file(self, event_files, summary_dict,
                                    expected_count):
    """Asserts every tag in `summary_dict` was recorded `expected_count` times.

    Args:
      event_files: Paths of summary event files to scan.
      summary_dict: Dict mapping summary tag -> running count (callers pass
        0). Updated in place with the number of values seen for each tag;
        tags not in the dict are ignored.
      expected_count: Number of occurrences expected for every tag.
    """
    for event_file in event_files:
      for event in summary_iterator.summary_iterator(event_file):
        for value in event.summary.value:
          if value.tag in summary_dict:
            summary_dict[value.tag] += 1

    for tag, count in summary_dict.items():
      # Name the failing tag and the number of files scanned, so that a
      # missing event file is distinguishable from a wrong recording count.
      self.assertEqual(
          count, expected_count,
          msg='Summary tag %r was recorded %d time(s) across %d event '
          'file(s); expected %d.' % (tag, count, len(event_files),
                                     expected_count))

  def testV2SummaryWithKerasSequentialModel(self):
    strategy = get_tpu_strategy()

    with strategy.scope():
      model = mnist_model((28, 28, 3))
      model.compile('sgd', 'mse')

      dataset = get_image_dataset()
      tensorboard_callback = callbacks.TensorBoard(
          self.summary_dir, update_freq=2)
      model.fit(
          dataset,
          steps_per_epoch=10,
          epochs=1,
          callbacks=[tensorboard_callback])

      events_count_dictionary = {
          'sequential/layer_for_histogram_summary/custom_histogram_summary_v2':
              0,
          'sequential/layer_for_image_summary/custom_image_summary_v2':
              0,
      }

      event_files = file_io.get_matching_files_v2(
          os.path.join(self.summary_dir, 'train', 'event*'))
      # A total of 10 steps are run and the TensorBoard callback records
      # summaries every 2 batches (update_freq=2), so each tag should
      # appear in 5 events.
      self.validate_recorded_sumary_file(event_files, events_count_dictionary,
                                         5)

  def testV2SummaryWithKerasSubclassedModel(self):
    strategy = get_tpu_strategy()

    with strategy.scope():
      model = CustomModel()
      model.compile('sgd', 'mse')

      dataset = distribute_strategy_test.get_dataset(strategy)
      tensorboard_callback = callbacks.TensorBoard(
          self.summary_dir, update_freq=2)
      model.fit(
          dataset,
          steps_per_epoch=10,
          epochs=1,
          callbacks=[tensorboard_callback])

      event_files = file_io.get_matching_files_v2(
          os.path.join(self.summary_dir, 'train', 'event*'))
      events_count_dictionary = {
          ('custom_model/layer_for_scalar_summary/'
           'custom_scalar_summary_v2'):
              0,
          ('custom_model/layer_for_histogram_summary/'
           'custom_histogram_summary_v2'):
              0
      }

      # A total of 10 steps are run and the TensorBoard callback records
      # summaries every 2 batches (update_freq=2), so each tag should
      # appear in 5 events.
      self.validate_recorded_sumary_file(event_files, events_count_dictionary,
                                         5)

  def testSummaryWithCustomTrainingLoop(self):
    strategy = get_tpu_strategy()

    writer = summary_ops_v2.create_file_writer_v2(self.summary_dir)
    with strategy.scope():
      model = distribute_strategy_test.get_model()
      model.compile('sgd', 'mse')

    @def_function.function
    def custom_function(dataset):

      def _custom_step(features, labels):
        del labels
        logits = model(features)
        with summary_ops_v2.record_if(True), writer.as_default():
          scalar_summary_v2.scalar(
              'logits',
              math_ops.reduce_sum(logits),
              step=model.optimizer.iterations)
        return logits

      iterator = iter(dataset)
      # Each distributed element is a (features, labels) pair. Unpack it
      # explicitly: the previous `args=(next(iterator))` read like a
      # one-element tuple but was just the element itself.
      features, labels = next(iterator)
      output = strategy.unwrap(
          strategy.run(_custom_step, args=(features, labels)))
      return output

    dataset = strategy.experimental_distribute_dataset(
        distribute_strategy_test.get_dataset(strategy))

    # Runs a single step, so exactly one 'logits' summary is expected.
    custom_function(dataset)
    writer.close()

    event_files = file_io.get_matching_files_v2(
        os.path.join(self.summary_dir, 'event*'))
    events_count_dictionary = {
        'logits': 0,
    }
    self.validate_recorded_sumary_file(event_files, events_count_dictionary,
                                       1)


if __name__ == '__main__':
  ops.enable_eager_execution()
  test.main()