1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=g-import-not-at-top
16"""Callbacks: utilities called at certain points during model training.
17"""
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import os
23
24import numpy as np
25
26from tensorflow.python.eager import context
27from tensorflow.python.framework import dtypes
28from tensorflow.python.keras import backend as K
29from tensorflow.python.keras import callbacks
30from tensorflow.python.ops import array_ops
31from tensorflow.python.ops import state_ops
32from tensorflow.python.ops import summary_ops_v2
33from tensorflow.python.ops import variables
34from tensorflow.python.platform import tf_logging as logging
35from tensorflow.python.profiler import profiler_v2 as profiler
36from tensorflow.python.summary import summary as tf_summary
37from tensorflow.python.training import saver
38from tensorflow.python.util.tf_export import keras_export
39
40
41@keras_export(v1=['keras.callbacks.TensorBoard'])
42class TensorBoard(callbacks.TensorBoard):
43  # pylint: disable=line-too-long
44  """Enable visualizations for TensorBoard.
45
46  TensorBoard is a visualization tool provided with TensorFlow.
47
48  This callback logs events for TensorBoard, including:
49  * Metrics summary plots
50  * Training graph visualization
51  * Activation histograms
52  * Sampled profiling
53
54  If you have installed TensorFlow with pip, you should be able
55  to launch TensorBoard from the command line:
56
57  ```sh
58  tensorboard --logdir=path_to_your_logs
59  ```
60
61  You can find more information about TensorBoard
62  [here](https://www.tensorflow.org/get_started/summaries_and_tensorboard).
63
64  Args:
65      log_dir: the path of the directory where to save the log files to be
66        parsed by TensorBoard.
67      histogram_freq: frequency (in epochs) at which to compute activation and
68        weight histograms for the layers of the model. If set to 0, histograms
69        won't be computed. Validation data (or split) must be specified for
70        histogram visualizations.
71      write_graph: whether to visualize the graph in TensorBoard. The log file
72        can become quite large when write_graph is set to True.
73      write_grads: whether to visualize gradient histograms in TensorBoard.
74        `histogram_freq` must be greater than 0.
75      batch_size: size of batch of inputs to feed to the network for histograms
76        computation.
77      write_images: whether to write model weights to visualize as image in
78        TensorBoard.
79      embeddings_freq: frequency (in epochs) at which selected embedding layers
80        will be saved. If set to 0, embeddings won't be computed. Data to be
81        visualized in TensorBoard's Embedding tab must be passed as
82        `embeddings_data`.
      embeddings_layer_names: a list of names of layers to keep an eye on. If
        None or an empty list, all the embedding layers will be watched.
85      embeddings_metadata: a dictionary which maps layer name to a file name in
86        which metadata for this embedding layer is saved.
87          [Here are details](
88            https://www.tensorflow.org/how_tos/embedding_viz/#metadata_optional)
            about the metadata file format. If the same metadata file is
            used for all embedding layers, a single string can be passed.
91      embeddings_data: data to be embedded at layers specified in
92        `embeddings_layer_names`. Numpy array (if the model has a single input)
93        or list of Numpy arrays (if the model has multiple inputs). Learn more
94        about embeddings [in this guide](
95          https://www.tensorflow.org/programmers_guide/embedding).
96      update_freq: `'batch'` or `'epoch'` or integer. When using `'batch'`,
97        writes the losses and metrics to TensorBoard after each batch. The same
98        applies for `'epoch'`. If using an integer, let's say `1000`, the
99        callback will write the metrics and losses to TensorBoard every 1000
100        samples. Note that writing too frequently to TensorBoard can slow down
101        your training.
102      profile_batch: Profile the batch to sample compute characteristics. By
103        default, it will profile the second batch. Set profile_batch=0 to
104        disable profiling.
105
106  Raises:
      ValueError: If histogram_freq is set and no validation data is provided,
        or if `embeddings_freq` is set and `embeddings_data` is not provided.
108
109  @compatibility(eager)
110  Using the `TensorBoard` callback will work when eager execution is enabled,
111  with the restriction that outputting histogram summaries of weights and
112  gradients is not supported. Consequently, `histogram_freq` will be ignored.
113  @end_compatibility
114  """
115
116  # pylint: enable=line-too-long
117
118  def __init__(self,
119               log_dir='./logs',
120               histogram_freq=0,
121               batch_size=32,
122               write_graph=True,
123               write_grads=False,
124               write_images=False,
125               embeddings_freq=0,
126               embeddings_layer_names=None,
127               embeddings_metadata=None,
128               embeddings_data=None,
129               update_freq='epoch',
130               profile_batch=2):
131    # Don't call super's init since it is an eager-only version.
132    callbacks.Callback.__init__(self)
133    self.log_dir = log_dir
134    self.histogram_freq = histogram_freq
135    if self.histogram_freq and context.executing_eagerly():
136      logging.warning(
137          UserWarning('Weight and gradient histograms not supported for eager'
138                      'execution, setting `histogram_freq` to `0`.'))
139      self.histogram_freq = 0
140    self.merged = None
141    self.write_graph = write_graph
142    self.write_grads = write_grads
143    self.write_images = write_images
144    self.batch_size = batch_size
145    self._current_batch = 0
146    self._total_batches_seen = 0
147    self._total_val_batches_seen = 0
148    self.embeddings_freq = embeddings_freq
149    self.embeddings_layer_names = embeddings_layer_names
150    self.embeddings_metadata = embeddings_metadata
151    self.embeddings_data = embeddings_data
152    if update_freq == 'batch':
153      self.update_freq = 1
154    else:
155      self.update_freq = update_freq
156    self._samples_seen = 0
157    self._samples_seen_at_last_write = 0
158    # TODO(fishx): Add a link to the full profiler tutorial.
159    self._profile_batch = profile_batch
160    # One profiler session is running if it is True.
161    self._is_profiling = False
162
163    # TensorBoard should only write summaries on the chief when in a
164    # Multi-Worker setting.
165    self._chief_worker_only = True
166
167  def _init_writer(self, model):
168    """Sets file writer."""
169    if context.executing_eagerly():
170      self.writer = summary_ops_v2.create_file_writer_v2(self.log_dir)
171      if not model.run_eagerly and self.write_graph:
172        with self.writer.as_default():
173          summary_ops_v2.graph(K.get_graph())
174    elif self.write_graph:
175      self.writer = tf_summary.FileWriter(self.log_dir, K.get_graph())
176    else:
177      self.writer = tf_summary.FileWriter(self.log_dir)
178
179  def _make_histogram_ops(self, model):
180    """Defines histogram ops when histogram_freq > 0."""
181    # only make histogram summary op if it hasn't already been made
182    if self.histogram_freq and self.merged is None:
183      for layer in self.model.layers:
184        for weight in layer.weights:
185          mapped_weight_name = weight.name.replace(':', '_')
186          tf_summary.histogram(mapped_weight_name, weight)
187          if self.write_images:
188            w_img = array_ops.squeeze(weight)
189            shape = K.int_shape(w_img)
190            if len(shape) == 2:  # dense layer kernel case
191              if shape[0] > shape[1]:
192                w_img = array_ops.transpose(w_img)
193                shape = K.int_shape(w_img)
194              w_img = array_ops.reshape(w_img, [1, shape[0], shape[1], 1])
195            elif len(shape) == 3:  # convnet case
196              if K.image_data_format() == 'channels_last':
197                # switch to channels_first to display
198                # every kernel as a separate image
199                w_img = array_ops.transpose(w_img, perm=[2, 0, 1])
200                shape = K.int_shape(w_img)
201              w_img = array_ops.reshape(w_img,
202                                        [shape[0], shape[1], shape[2], 1])
203            elif len(shape) == 1:  # bias case
204              w_img = array_ops.reshape(w_img, [1, shape[0], 1, 1])
205            else:
206              # not possible to handle 3D convnets etc.
207              continue
208
209            shape = K.int_shape(w_img)
210            assert len(shape) == 4 and shape[-1] in [1, 3, 4]
211            tf_summary.image(mapped_weight_name, w_img)
212
213        if self.write_grads:
214          for weight in layer.trainable_weights:
215            mapped_weight_name = weight.name.replace(':', '_')
216            grads = model.optimizer.get_gradients(model.total_loss, weight)
217
218            def is_indexed_slices(grad):
219              return type(grad).__name__ == 'IndexedSlices'
220
221            grads = [
222                grad.values if is_indexed_slices(grad) else grad
223                for grad in grads
224            ]
225            tf_summary.histogram('{}_grad'.format(mapped_weight_name), grads)
226
227        if hasattr(layer, 'output'):
228          if isinstance(layer.output, list):
229            for i, output in enumerate(layer.output):
230              tf_summary.histogram('{}_out_{}'.format(layer.name, i), output)
231          else:
232            tf_summary.histogram('{}_out'.format(layer.name), layer.output)
233
234  def set_model(self, model):
235    """Sets Keras model and creates summary ops."""
236
237    self.model = model
238    self._init_writer(model)
239    # histogram summaries only enabled in graph mode
240    if not context.executing_eagerly():
241      self._make_histogram_ops(model)
242      self.merged = tf_summary.merge_all()
243
244    # If both embedding_freq and embeddings_data are available, we will
245    # visualize embeddings.
246    if self.embeddings_freq and self.embeddings_data is not None:
247      # Avoid circular dependency.
248      from tensorflow.python.keras.engine import training_utils_v1  # pylint: disable=g-import-not-at-top
249      self.embeddings_data = training_utils_v1.standardize_input_data(
250          self.embeddings_data, model.input_names)
251
252      # If embedding_layer_names are not provided, get all of the embedding
253      # layers from the model.
254      embeddings_layer_names = self.embeddings_layer_names
255      if not embeddings_layer_names:
256        embeddings_layer_names = [
257            layer.name
258            for layer in self.model.layers
259            if type(layer).__name__ == 'Embedding'
260        ]
261
262      self.assign_embeddings = []
263      embeddings_vars = {}
264
265      self.batch_id = batch_id = array_ops.placeholder(dtypes.int32)
266      self.step = step = array_ops.placeholder(dtypes.int32)
267
268      for layer in self.model.layers:
269        if layer.name in embeddings_layer_names:
270          embedding_input = self.model.get_layer(layer.name).output
271          embedding_size = np.prod(embedding_input.shape[1:])
272          embedding_input = array_ops.reshape(embedding_input,
273                                              (step, int(embedding_size)))
274          shape = (self.embeddings_data[0].shape[0], int(embedding_size))
275          embedding = variables.Variable(
276              array_ops.zeros(shape), name=layer.name + '_embedding')
277          embeddings_vars[layer.name] = embedding
278          batch = state_ops.assign(embedding[batch_id:batch_id + step],
279                                   embedding_input)
280          self.assign_embeddings.append(batch)
281
282      self.saver = saver.Saver(list(embeddings_vars.values()))
283
284      # Create embeddings_metadata dictionary
285      if isinstance(self.embeddings_metadata, str):
286        embeddings_metadata = {
287            layer_name: self.embeddings_metadata
288            for layer_name in embeddings_vars.keys()
289        }
290      else:
291        # If embedding_metadata is already a dictionary
292        embeddings_metadata = self.embeddings_metadata
293
294      try:
295        from tensorboard.plugins import projector
296      except ImportError:
297        raise ImportError('Failed to import TensorBoard. Please make sure that '
298                          'TensorBoard integration is complete."')
299
300      # TODO(psv): Add integration tests to test embedding visualization
301      # with TensorBoard callback. We are unable to write a unit test for this
302      # because TensorBoard dependency assumes TensorFlow package is installed.
303      config = projector.ProjectorConfig()
304      for layer_name, tensor in embeddings_vars.items():
305        embedding = config.embeddings.add()
306        embedding.tensor_name = tensor.name
307
308        if (embeddings_metadata is not None and
309            layer_name in embeddings_metadata):
310          embedding.metadata_path = embeddings_metadata[layer_name]
311
312      projector.visualize_embeddings(self.writer, config)
313
314  def _fetch_callback(self, summary):
315    self.writer.add_summary(summary, self._total_val_batches_seen)
316    self._total_val_batches_seen += 1
317
318  def _write_custom_summaries(self, step, logs=None):
319    """Writes metrics out as custom scalar summaries.
320
321    Args:
322        step: the global step to use for TensorBoard.
323        logs: dict. Keys are scalar summary names, values are
324            NumPy scalars.
325
326    """
327    logs = logs or {}
328    if context.executing_eagerly():
329      # use v2 summary ops
330      with self.writer.as_default(), summary_ops_v2.record_if(True):
331        for name, value in logs.items():
332          if isinstance(value, np.ndarray):
333            value = value.item()
334          summary_ops_v2.scalar(name, value, step=step)
335    else:
336      # use FileWriter from v1 summary
337      for name, value in logs.items():
338        if isinstance(value, np.ndarray):
339          value = value.item()
340        summary = tf_summary.Summary()
341        summary_value = summary.value.add()
342        summary_value.simple_value = value
343        summary_value.tag = name
344        self.writer.add_summary(summary, step)
345    self.writer.flush()
346
347  def on_train_batch_begin(self, batch, logs=None):
348    if (not self._is_profiling and
349        self._total_batches_seen == self._profile_batch - 1):
350      profiler.start(self.log_dir)
351      self._is_profiling = True
352
353  def on_train_batch_end(self, batch, logs=None):
354    return self.on_batch_end(batch, logs)
355
356  def on_test_begin(self, logs=None):
357    pass
358
359  def on_test_end(self, logs=None):
360    pass
361
362  def on_batch_end(self, batch, logs=None):
363    """Writes scalar summaries for metrics on every training batch.
364
365    Performs profiling if current batch is in profiler_batches.
366    """
367    # Don't output batch_size and batch number as TensorBoard summaries
368    logs = logs or {}
369    self._samples_seen += logs.get('size', 1)
370    samples_seen_since = self._samples_seen - self._samples_seen_at_last_write
371    if self.update_freq != 'epoch' and samples_seen_since >= self.update_freq:
372      batch_logs = {('batch_' + k): v
373                    for k, v in logs.items()
374                    if k not in ['batch', 'size', 'num_steps']}
375      self._write_custom_summaries(self._total_batches_seen, batch_logs)
376      self._samples_seen_at_last_write = self._samples_seen
377    self._total_batches_seen += 1
378
379    if self._is_profiling:
380      profiler.stop()
381      self._is_profiling = False
382
383  def on_train_begin(self, logs=None):
384    pass
385
386  def on_epoch_begin(self, epoch, logs=None):
387    """Add histogram op to Model eval_function callbacks, reset batch count."""
388
389    # check if histogram summary should be run for this epoch
390    if self.histogram_freq and epoch % self.histogram_freq == 0:
391      self._epoch = epoch
392      # pylint: disable=protected-access
393      # add the histogram summary op if it should run this epoch
394      self.model._make_test_function()
395      if self.merged not in self.model.test_function.fetches:
396        self.model.test_function.fetches.append(self.merged)
397        self.model.test_function.fetch_callbacks[
398            self.merged] = self._fetch_callback
399      # pylint: enable=protected-access
400
401  def on_epoch_end(self, epoch, logs=None):
402    """Checks if summary ops should run next epoch, logs scalar summaries."""
403
404    # don't output batch_size and
405    # batch number as TensorBoard summaries
406    logs = {('epoch_' + k): v
407            for k, v in logs.items()
408            if k not in ['batch', 'size', 'num_steps']}
409    if self.update_freq == 'epoch':
410      step = epoch
411    else:
412      step = self._samples_seen
413    self._write_custom_summaries(step, logs)
414
415    # pop the histogram summary op after each epoch
416    if self.histogram_freq:
417      # pylint: disable=protected-access
418      if self.merged in self.model.test_function.fetches:
419        self.model.test_function.fetches.remove(self.merged)
420      if self.merged in self.model.test_function.fetch_callbacks:
421        self.model.test_function.fetch_callbacks.pop(self.merged)
422      # pylint: enable=protected-access
423
424    if self.embeddings_data is None and self.embeddings_freq:
425      raise ValueError('To visualize embeddings, embeddings_data must '
426                       'be provided.')
427
428    if self.embeddings_freq and self.embeddings_data is not None:
429      if epoch % self.embeddings_freq == 0:
430        # We need a second forward-pass here because we're passing
431        # the `embeddings_data` explicitly. This design allows to pass
432        # arbitrary data as `embeddings_data` and results from the fact
433        # that we need to know the size of the `tf.Variable`s which
434        # hold the embeddings in `set_model`. At this point, however,
435        # the `validation_data` is not yet set.
436
437        embeddings_data = self.embeddings_data
438        n_samples = embeddings_data[0].shape[0]
439        i = 0
440        sess = K.get_session()
441        while i < n_samples:
442          step = min(self.batch_size, n_samples - i)
443          batch = slice(i, i + step)
444
445          if isinstance(self.model.input, list):
446            feed_dict = {
447                model_input: embeddings_data[idx][batch]
448                for idx, model_input in enumerate(self.model.input)
449            }
450          else:
451            feed_dict = {self.model.input: embeddings_data[0][batch]}
452
453          feed_dict.update({self.batch_id: i, self.step: step})
454
455          if not isinstance(K.learning_phase(), int):
456            feed_dict[K.learning_phase()] = False
457
458          sess.run(self.assign_embeddings, feed_dict=feed_dict)
459          self.saver.save(sess,
460                          os.path.join(self.log_dir, 'keras_embedding.ckpt'),
461                          epoch)
462
463          i += self.batch_size
464
465  def on_train_end(self, logs=None):
466    if self._is_profiling:
467      profiler.stop()
468      self._is_profiling = False
469    self.writer.close()
470