1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=g-import-not-at-top 16"""Callbacks: utilities called at certain points during model training. 17""" 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import os 23 24import numpy as np 25 26from tensorflow.python.eager import context 27from tensorflow.python.eager import profiler 28from tensorflow.python.framework import dtypes 29from tensorflow.python.keras import backend as K 30from tensorflow.python.keras import callbacks 31from tensorflow.python.ops import array_ops 32from tensorflow.python.ops import state_ops 33from tensorflow.python.ops import summary_ops_v2 34from tensorflow.python.ops import variables 35from tensorflow.python.platform import tf_logging as logging 36from tensorflow.python.summary import summary as tf_summary 37from tensorflow.python.training import saver 38from tensorflow.python.util.tf_export import keras_export 39 40 41@keras_export(v1=['keras.callbacks.TensorBoard']) 42class TensorBoard(callbacks.Callback): 43 # pylint: disable=line-too-long 44 """Enable visualizations for TensorBoard. 45 46 TensorBoard is a visualization tool provided with TensorFlow. 47 48 This callback logs events for TensorBoard, including: 49 * Metrics summary plots 50 * Training graph visualization 51 * Activation histograms 52 * Sampled profiling 53 54 If you have installed TensorFlow with pip, you should be able 55 to launch TensorBoard from the command line: 56 57 ```sh 58 tensorboard --logdir=path_to_your_logs 59 ``` 60 61 You can find more information about TensorBoard 62 [here](https://www.tensorflow.org/get_started/summaries_and_tensorboard). 63 64 Arguments: 65 log_dir: the path of the directory where to save the log files to be 66 parsed by TensorBoard. 67 histogram_freq: frequency (in epochs) at which to compute activation and 68 weight histograms for the layers of the model. If set to 0, histograms 69 won't be computed. Validation data (or split) must be specified for 70 histogram visualizations. 71 write_graph: whether to visualize the graph in TensorBoard. The log file 72 can become quite large when write_graph is set to True. 73 write_grads: whether to visualize gradient histograms in TensorBoard. 74 `histogram_freq` must be greater than 0. 75 batch_size: size of batch of inputs to feed to the network for histograms 76 computation. 77 write_images: whether to write model weights to visualize as image in 78 TensorBoard. 79 embeddings_freq: frequency (in epochs) at which selected embedding layers 80 will be saved. If set to 0, embeddings won't be computed. Data to be 81 visualized in TensorBoard's Embedding tab must be passed as 82 `embeddings_data`. 83 embeddings_layer_names: a list of names of layers to keep eye on. If None 84 or empty list all the embedding layer will be watched. 85 embeddings_metadata: a dictionary which maps layer name to a file name in 86 which metadata for this embedding layer is saved. See the 87 [details](https://www.tensorflow.org/how_tos/embedding_viz/#metadata_optional) 88 about metadata files format. In case if the same metadata file is 89 used for all embedding layers, string can be passed. 90 embeddings_data: data to be embedded at layers specified in 91 `embeddings_layer_names`. Numpy array (if the model has a single input) 92 or list of Numpy arrays (if the model has multiple inputs). Learn [more 93 about 94 embeddings](https://www.tensorflow.org/programmers_guide/embedding) 95 update_freq: `'batch'` or `'epoch'` or integer. When using `'batch'`, 96 writes the losses and metrics to TensorBoard after each batch. The same 97 applies for `'epoch'`. If using an integer, let's say `1000`, the 98 callback will write the metrics and losses to TensorBoard every 1000 99 samples. Note that writing too frequently to TensorBoard can slow down 100 your training. 101 profile_batch: Profile the batch to sample compute characteristics. By 102 default, it will profile the second batch. Set profile_batch=0 to 103 disable profiling. 104 105 Raises: 106 ValueError: If histogram_freq is set and no validation data is provided. 107 108 @compatibility(eager) 109 Using the `TensorBoard` callback will work when eager execution is enabled, 110 with the restriction that outputting histogram summaries of weights and 111 gradients is not supported. Consequently, `histogram_freq` will be ignored. 112 @end_compatibility 113 """ 114 115 # pylint: enable=line-too-long 116 117 def __init__(self, 118 log_dir='./logs', 119 histogram_freq=0, 120 batch_size=32, 121 write_graph=True, 122 write_grads=False, 123 write_images=False, 124 embeddings_freq=0, 125 embeddings_layer_names=None, 126 embeddings_metadata=None, 127 embeddings_data=None, 128 update_freq='epoch', 129 profile_batch=2): 130 super(TensorBoard, self).__init__() 131 self.log_dir = log_dir 132 self.histogram_freq = histogram_freq 133 if self.histogram_freq and context.executing_eagerly(): 134 logging.warning( 135 UserWarning('Weight and gradient histograms not supported for eager' 136 'execution, setting `histogram_freq` to `0`.')) 137 self.histogram_freq = 0 138 self.merged = None 139 self.write_graph = write_graph 140 self.write_grads = write_grads 141 self.write_images = write_images 142 self.batch_size = batch_size 143 self._current_batch = 0 144 self._total_batches_seen = 0 145 self._total_val_batches_seen = 0 146 self.embeddings_freq = embeddings_freq 147 self.embeddings_layer_names = embeddings_layer_names 148 self.embeddings_metadata = embeddings_metadata 149 self.embeddings_data = embeddings_data 150 if update_freq == 'batch': 151 self.update_freq = 1 152 else: 153 self.update_freq = update_freq 154 self._samples_seen = 0 155 self._samples_seen_at_last_write = 0 156 # TODO(fishx): Add a link to the full profiler tutorial. 157 self._profile_batch = profile_batch 158 # One profiler session is running if it is True. 159 self._is_profiling = False 160 161 # TensorBoard should only write summaries on the chief when in a 162 # Multi-Worker setting. 163 self._chief_worker_only = True 164 165 def _init_writer(self, model): 166 """Sets file writer.""" 167 if context.executing_eagerly(): 168 self.writer = summary_ops_v2.create_file_writer(self.log_dir) 169 if not model.run_eagerly and self.write_graph: 170 with self.writer.as_default(): 171 summary_ops_v2.graph(K.get_graph()) 172 elif self.write_graph: 173 self.writer = tf_summary.FileWriter(self.log_dir, K.get_graph()) 174 else: 175 self.writer = tf_summary.FileWriter(self.log_dir) 176 177 def _make_histogram_ops(self, model): 178 """Defines histogram ops when histogram_freq > 0.""" 179 # only make histogram summary op if it hasn't already been made 180 if self.histogram_freq and self.merged is None: 181 for layer in self.model.layers: 182 for weight in layer.weights: 183 mapped_weight_name = weight.name.replace(':', '_') 184 tf_summary.histogram(mapped_weight_name, weight) 185 if self.write_images: 186 w_img = array_ops.squeeze(weight) 187 shape = K.int_shape(w_img) 188 if len(shape) == 2: # dense layer kernel case 189 if shape[0] > shape[1]: 190 w_img = array_ops.transpose(w_img) 191 shape = K.int_shape(w_img) 192 w_img = array_ops.reshape(w_img, [1, shape[0], shape[1], 1]) 193 elif len(shape) == 3: # convnet case 194 if K.image_data_format() == 'channels_last': 195 # switch to channels_first to display 196 # every kernel as a separate image 197 w_img = array_ops.transpose(w_img, perm=[2, 0, 1]) 198 shape = K.int_shape(w_img) 199 w_img = array_ops.reshape(w_img, 200 [shape[0], shape[1], shape[2], 1]) 201 elif len(shape) == 1: # bias case 202 w_img = array_ops.reshape(w_img, [1, shape[0], 1, 1]) 203 else: 204 # not possible to handle 3D convnets etc. 205 continue 206 207 shape = K.int_shape(w_img) 208 assert len(shape) == 4 and shape[-1] in [1, 3, 4] 209 tf_summary.image(mapped_weight_name, w_img) 210 211 if self.write_grads: 212 for weight in layer.trainable_weights: 213 mapped_weight_name = weight.name.replace(':', '_') 214 grads = model.optimizer.get_gradients(model.total_loss, weight) 215 216 def is_indexed_slices(grad): 217 return type(grad).__name__ == 'IndexedSlices' 218 219 grads = [ 220 grad.values if is_indexed_slices(grad) else grad 221 for grad in grads 222 ] 223 tf_summary.histogram('{}_grad'.format(mapped_weight_name), grads) 224 225 if hasattr(layer, 'output'): 226 if isinstance(layer.output, list): 227 for i, output in enumerate(layer.output): 228 tf_summary.histogram('{}_out_{}'.format(layer.name, i), output) 229 else: 230 tf_summary.histogram('{}_out'.format(layer.name), layer.output) 231 232 def set_model(self, model): 233 """Sets Keras model and creates summary ops.""" 234 235 self.model = model 236 self._init_writer(model) 237 # histogram summaries only enabled in graph mode 238 if not context.executing_eagerly(): 239 self._make_histogram_ops(model) 240 self.merged = tf_summary.merge_all() 241 242 # If both embedding_freq and embeddings_data are available, we will 243 # visualize embeddings. 244 if self.embeddings_freq and self.embeddings_data is not None: 245 # Avoid circular dependency. 246 from tensorflow.python.keras.engine import training_utils # pylint: disable=g-import-not-at-top 247 self.embeddings_data = training_utils.standardize_input_data( 248 self.embeddings_data, model.input_names) 249 250 # If embedding_layer_names are not provided, get all of the embedding 251 # layers from the model. 252 embeddings_layer_names = self.embeddings_layer_names 253 if not embeddings_layer_names: 254 embeddings_layer_names = [ 255 layer.name 256 for layer in self.model.layers 257 if type(layer).__name__ == 'Embedding' 258 ] 259 260 self.assign_embeddings = [] 261 embeddings_vars = {} 262 263 self.batch_id = batch_id = array_ops.placeholder(dtypes.int32) 264 self.step = step = array_ops.placeholder(dtypes.int32) 265 266 for layer in self.model.layers: 267 if layer.name in embeddings_layer_names: 268 embedding_input = self.model.get_layer(layer.name).output 269 embedding_size = np.prod(embedding_input.shape[1:]) 270 embedding_input = array_ops.reshape(embedding_input, 271 (step, int(embedding_size))) 272 shape = (self.embeddings_data[0].shape[0], int(embedding_size)) 273 embedding = variables.Variable( 274 array_ops.zeros(shape), name=layer.name + '_embedding') 275 embeddings_vars[layer.name] = embedding 276 batch = state_ops.assign(embedding[batch_id:batch_id + step], 277 embedding_input) 278 self.assign_embeddings.append(batch) 279 280 self.saver = saver.Saver(list(embeddings_vars.values())) 281 282 # Create embeddings_metadata dictionary 283 if isinstance(self.embeddings_metadata, str): 284 embeddings_metadata = { 285 layer_name: self.embeddings_metadata 286 for layer_name in embeddings_vars.keys() 287 } 288 else: 289 # If embedding_metadata is already a dictionary 290 embeddings_metadata = self.embeddings_metadata 291 292 try: 293 from tensorboard.plugins import projector 294 except ImportError: 295 raise ImportError('Failed to import TensorBoard. Please make sure that ' 296 'TensorBoard integration is complete."') 297 298 # TODO(psv): Add integration tests to test embedding visualization 299 # with TensorBoard callback. We are unable to write a unit test for this 300 # because TensorBoard dependency assumes TensorFlow package is installed. 301 config = projector.ProjectorConfig() 302 for layer_name, tensor in embeddings_vars.items(): 303 embedding = config.embeddings.add() 304 embedding.tensor_name = tensor.name 305 306 if (embeddings_metadata is not None and 307 layer_name in embeddings_metadata): 308 embedding.metadata_path = embeddings_metadata[layer_name] 309 310 projector.visualize_embeddings(self.writer, config) 311 312 def _fetch_callback(self, summary): 313 self.writer.add_summary(summary, self._total_val_batches_seen) 314 self._total_val_batches_seen += 1 315 316 def _write_custom_summaries(self, step, logs=None): 317 """Writes metrics out as custom scalar summaries. 318 319 Arguments: 320 step: the global step to use for TensorBoard. 321 logs: dict. Keys are scalar summary names, values are 322 NumPy scalars. 323 324 """ 325 logs = logs or {} 326 if context.executing_eagerly(): 327 # use v2 summary ops 328 with self.writer.as_default(), summary_ops_v2.always_record_summaries(): 329 for name, value in logs.items(): 330 if isinstance(value, np.ndarray): 331 value = value.item() 332 summary_ops_v2.scalar(name, value, step=step) 333 else: 334 # use FileWriter from v1 summary 335 for name, value in logs.items(): 336 if isinstance(value, np.ndarray): 337 value = value.item() 338 summary = tf_summary.Summary() 339 summary_value = summary.value.add() 340 summary_value.simple_value = value 341 summary_value.tag = name 342 self.writer.add_summary(summary, step) 343 self.writer.flush() 344 345 def on_batch_end(self, batch, logs=None): 346 """Writes scalar summaries for metrics on every training batch. 347 348 Performs profiling if current batch is in profiler_batches. 349 """ 350 # Don't output batch_size and batch number as TensorBoard summaries 351 logs = logs or {} 352 self._samples_seen += logs.get('size', 1) 353 samples_seen_since = self._samples_seen - self._samples_seen_at_last_write 354 if self.update_freq != 'epoch' and samples_seen_since >= self.update_freq: 355 batch_logs = {('batch_' + k): v 356 for k, v in logs.items() 357 if k not in ['batch', 'size', 'num_steps']} 358 self._write_custom_summaries(self._total_batches_seen, batch_logs) 359 self._samples_seen_at_last_write = self._samples_seen 360 self._total_batches_seen += 1 361 if self._is_profiling: 362 profiler.save(self.log_dir, profiler.stop()) 363 self._is_profiling = False 364 elif (not self._is_profiling and 365 self._total_batches_seen == self._profile_batch - 1): 366 profiler.start() 367 self._is_profiling = True 368 369 def on_train_begin(self, logs=None): 370 if self._profile_batch == 1: 371 profiler.start() 372 self._is_profiling = True 373 374 def on_epoch_begin(self, epoch, logs=None): 375 """Add histogram op to Model eval_function callbacks, reset batch count.""" 376 377 # check if histogram summary should be run for this epoch 378 if self.histogram_freq and epoch % self.histogram_freq == 0: 379 self._epoch = epoch 380 # pylint: disable=protected-access 381 # add the histogram summary op if it should run this epoch 382 self.model._make_test_function() 383 if self.merged not in self.model.test_function.fetches: 384 self.model.test_function.fetches.append(self.merged) 385 self.model.test_function.fetch_callbacks[ 386 self.merged] = self._fetch_callback 387 # pylint: enable=protected-access 388 389 def on_epoch_end(self, epoch, logs=None): 390 """Checks if summary ops should run next epoch, logs scalar summaries.""" 391 392 # don't output batch_size and 393 # batch number as TensorBoard summaries 394 logs = {('epoch_' + k): v 395 for k, v in logs.items() 396 if k not in ['batch', 'size', 'num_steps']} 397 if self.update_freq == 'epoch': 398 step = epoch 399 else: 400 step = self._samples_seen 401 self._write_custom_summaries(step, logs) 402 403 # pop the histogram summary op after each epoch 404 if self.histogram_freq: 405 # pylint: disable=protected-access 406 if self.merged in self.model.test_function.fetches: 407 self.model.test_function.fetches.remove(self.merged) 408 if self.merged in self.model.test_function.fetch_callbacks: 409 self.model.test_function.fetch_callbacks.pop(self.merged) 410 # pylint: enable=protected-access 411 412 if self.embeddings_data is None and self.embeddings_freq: 413 raise ValueError('To visualize embeddings, embeddings_data must ' 414 'be provided.') 415 416 if self.embeddings_freq and self.embeddings_data is not None: 417 if epoch % self.embeddings_freq == 0: 418 # We need a second forward-pass here because we're passing 419 # the `embeddings_data` explicitly. This design allows to pass 420 # arbitrary data as `embeddings_data` and results from the fact 421 # that we need to know the size of the `tf.Variable`s which 422 # hold the embeddings in `set_model`. At this point, however, 423 # the `validation_data` is not yet set. 424 425 embeddings_data = self.embeddings_data 426 n_samples = embeddings_data[0].shape[0] 427 i = 0 428 sess = K.get_session() 429 while i < n_samples: 430 step = min(self.batch_size, n_samples - i) 431 batch = slice(i, i + step) 432 433 if isinstance(self.model.input, list): 434 feed_dict = { 435 model_input: embeddings_data[idx][batch] 436 for idx, model_input in enumerate(self.model.input) 437 } 438 else: 439 feed_dict = {self.model.input: embeddings_data[0][batch]} 440 441 feed_dict.update({self.batch_id: i, self.step: step}) 442 443 if not isinstance(K.learning_phase(), int): 444 feed_dict[K.learning_phase()] = False 445 446 sess.run(self.assign_embeddings, feed_dict=feed_dict) 447 self.saver.save(sess, 448 os.path.join(self.log_dir, 'keras_embedding.ckpt'), 449 epoch) 450 451 i += self.batch_size 452 453 def on_train_end(self, logs=None): 454 if self._is_profiling: 455 profiler.save(self.log_dir, profiler.stop()) 456 self._is_profiling = False 457 self.writer.close() 458