# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-import-not-at-top
16"""Callbacks: utilities called at certain points during model training.
17"""
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import os
23
24import numpy as np
25
26from tensorflow.python.eager import context
27from tensorflow.python.eager import profiler
28from tensorflow.python.framework import dtypes
29from tensorflow.python.keras import backend as K
30from tensorflow.python.keras import callbacks
31from tensorflow.python.ops import array_ops
32from tensorflow.python.ops import state_ops
33from tensorflow.python.ops import summary_ops_v2
34from tensorflow.python.ops import variables
35from tensorflow.python.platform import tf_logging as logging
36from tensorflow.python.summary import summary as tf_summary
37from tensorflow.python.training import saver
38from tensorflow.python.util.tf_export import keras_export
39
40
41@keras_export(v1=['keras.callbacks.TensorBoard'])
class TensorBoard(callbacks.Callback):
  # pylint: disable=line-too-long
  """Enable visualizations for TensorBoard.

  TensorBoard is a visualization tool provided with TensorFlow.

  This callback logs events for TensorBoard, including:
  * Metrics summary plots
  * Training graph visualization
  * Activation histograms
  * Sampled profiling

  If you have installed TensorFlow with pip, you should be able
  to launch TensorBoard from the command line:

  ```sh
  tensorboard --logdir=path_to_your_logs
  ```

  You can find more information about TensorBoard
  [here](https://www.tensorflow.org/get_started/summaries_and_tensorboard).
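
  For example (a minimal usage sketch; assumes an already-compiled Keras
  `model` and training arrays `x_train`/`y_train`, which are not defined
  here):

  ```python
  # Write logs under ./logs, then launch `tensorboard --logdir=./logs`.
  tensorboard = tf.keras.callbacks.TensorBoard(log_dir='./logs')
  model.fit(x_train, y_train, epochs=5, callbacks=[tensorboard])
  ```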

  Arguments:
      log_dir: the path of the directory where to save the log files to be
        parsed by TensorBoard.
      histogram_freq: frequency (in epochs) at which to compute activation and
        weight histograms for the layers of the model. If set to 0, histograms
        won't be computed. Validation data (or split) must be specified for
        histogram visualizations.
      write_graph: whether to visualize the graph in TensorBoard. The log file
        can become quite large when write_graph is set to True.
      write_grads: whether to visualize gradient histograms in TensorBoard.
        `histogram_freq` must be greater than 0.
      batch_size: size of the batch of inputs to feed to the network for
        histogram computation.
      write_images: whether to write model weights to visualize as images in
        TensorBoard.
      embeddings_freq: frequency (in epochs) at which selected embedding layers
        will be saved. If set to 0, embeddings won't be computed. Data to be
        visualized in TensorBoard's Embedding tab must be passed as
        `embeddings_data`.
      embeddings_layer_names: a list of names of layers to keep an eye on. If
        None or an empty list, all the embedding layers will be watched.
      embeddings_metadata: a dictionary which maps layer name to a file name in
        which metadata for this embedding layer is saved. See the
        [details](https://www.tensorflow.org/how_tos/embedding_viz/#metadata_optional)
        about the metadata file format. If the same metadata file is to be
        used for all embedding layers, a single string can be passed.
      embeddings_data: data to be embedded at layers specified in
        `embeddings_layer_names`. Numpy array (if the model has a single input)
        or list of Numpy arrays (if the model has multiple inputs). Learn
        [more about embeddings](https://www.tensorflow.org/programmers_guide/embedding).
      update_freq: `'batch'` or `'epoch'` or integer. When using `'batch'`,
        writes the losses and metrics to TensorBoard after each batch. The same
        applies for `'epoch'`. If using an integer, say `1000`, the callback
        will write the metrics and losses to TensorBoard every 1000 samples.
        Note that writing too frequently to TensorBoard can slow down your
        training.
      profile_batch: Profile the batch to sample compute characteristics. By
        default, it will profile the second batch. Set profile_batch=0 to
        disable profiling.

  Raises:
      ValueError: If histogram_freq is set and no validation data is provided.

  @compatibility(eager)
  Using the `TensorBoard` callback will work when eager execution is enabled,
  with the restriction that outputting histogram summaries of weights and
  gradients is not supported. Consequently, `histogram_freq` will be ignored.
  @end_compatibility
  """

  # pylint: enable=line-too-long

  def __init__(self,
               log_dir='./logs',
               histogram_freq=0,
               batch_size=32,
               write_graph=True,
               write_grads=False,
               write_images=False,
               embeddings_freq=0,
               embeddings_layer_names=None,
               embeddings_metadata=None,
               embeddings_data=None,
               update_freq='epoch',
               profile_batch=2):
    super(TensorBoard, self).__init__()
    self.log_dir = log_dir
    self.histogram_freq = histogram_freq
    if self.histogram_freq and context.executing_eagerly():
      logging.warning(
          UserWarning('Weight and gradient histograms not supported for eager '
                      'execution, setting `histogram_freq` to `0`.'))
      self.histogram_freq = 0
    self.merged = None
    self.write_graph = write_graph
    self.write_grads = write_grads
    self.write_images = write_images
    self.batch_size = batch_size
    self._current_batch = 0
    self._total_batches_seen = 0
    self._total_val_batches_seen = 0
    self.embeddings_freq = embeddings_freq
    self.embeddings_layer_names = embeddings_layer_names
    self.embeddings_metadata = embeddings_metadata
    self.embeddings_data = embeddings_data
    if update_freq == 'batch':
      self.update_freq = 1
    else:
      self.update_freq = update_freq
    self._samples_seen = 0
    self._samples_seen_at_last_write = 0
    # TODO(fishx): Add a link to the full profiler tutorial.
    self._profile_batch = profile_batch
    # True while a profiler session is running.
    self._is_profiling = False

    # TensorBoard should only write summaries on the chief when in a
    # Multi-Worker setting.
    self._chief_worker_only = True

  def _init_writer(self, model):
    """Sets file writer."""
    if context.executing_eagerly():
      self.writer = summary_ops_v2.create_file_writer(self.log_dir)
      if not model.run_eagerly and self.write_graph:
        with self.writer.as_default():
          summary_ops_v2.graph(K.get_graph())
    elif self.write_graph:
      self.writer = tf_summary.FileWriter(self.log_dir, K.get_graph())
    else:
      self.writer = tf_summary.FileWriter(self.log_dir)

  def _make_histogram_ops(self, model):
    """Defines histogram ops when histogram_freq > 0."""
    # only make histogram summary op if it hasn't already been made
    if self.histogram_freq and self.merged is None:
      for layer in self.model.layers:
        for weight in layer.weights:
          mapped_weight_name = weight.name.replace(':', '_')
          tf_summary.histogram(mapped_weight_name, weight)
          if self.write_images:
            w_img = array_ops.squeeze(weight)
            shape = K.int_shape(w_img)
            if len(shape) == 2:  # dense layer kernel case
              if shape[0] > shape[1]:
                w_img = array_ops.transpose(w_img)
                shape = K.int_shape(w_img)
              w_img = array_ops.reshape(w_img, [1, shape[0], shape[1], 1])
            elif len(shape) == 3:  # convnet case
              if K.image_data_format() == 'channels_last':
                # switch to channels_first to display
                # every kernel as a separate image
                w_img = array_ops.transpose(w_img, perm=[2, 0, 1])
                shape = K.int_shape(w_img)
              w_img = array_ops.reshape(w_img,
                                        [shape[0], shape[1], shape[2], 1])
            elif len(shape) == 1:  # bias case
              w_img = array_ops.reshape(w_img, [1, shape[0], 1, 1])
            else:
              # not possible to handle 3D convnets etc.
              continue

            shape = K.int_shape(w_img)
            assert len(shape) == 4 and shape[-1] in [1, 3, 4]
            tf_summary.image(mapped_weight_name, w_img)

        if self.write_grads:
          for weight in layer.trainable_weights:
            mapped_weight_name = weight.name.replace(':', '_')
            grads = model.optimizer.get_gradients(model.total_loss, weight)

            def is_indexed_slices(grad):
              return type(grad).__name__ == 'IndexedSlices'

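            # Sparse gradients arrive as `IndexedSlices`; histogram summaries
            # need dense tensors, so extract the `values` component.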
            grads = [
                grad.values if is_indexed_slices(grad) else grad
                for grad in grads
            ]
            tf_summary.histogram('{}_grad'.format(mapped_weight_name), grads)

        if hasattr(layer, 'output'):
          if isinstance(layer.output, list):
            for i, output in enumerate(layer.output):
              tf_summary.histogram('{}_out_{}'.format(layer.name, i), output)
          else:
            tf_summary.histogram('{}_out'.format(layer.name), layer.output)

  def set_model(self, model):
    """Sets Keras model and creates summary ops."""

    self.model = model
    self._init_writer(model)
    # histogram summaries only enabled in graph mode
    if not context.executing_eagerly():
      self._make_histogram_ops(model)
      self.merged = tf_summary.merge_all()

    # If both embeddings_freq and embeddings_data are available, we will
    # visualize embeddings.
    if self.embeddings_freq and self.embeddings_data is not None:
      # Avoid circular dependency.
      from tensorflow.python.keras.engine import training_utils  # pylint: disable=g-import-not-at-top
      self.embeddings_data = training_utils.standardize_input_data(
          self.embeddings_data, model.input_names)

      # If embeddings_layer_names are not provided, get all of the embedding
      # layers from the model.
      embeddings_layer_names = self.embeddings_layer_names
      if not embeddings_layer_names:
        embeddings_layer_names = [
            layer.name
            for layer in self.model.layers
            if type(layer).__name__ == 'Embedding'
        ]

      self.assign_embeddings = []
      embeddings_vars = {}

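      # `batch_id` is the offset of the current batch within
      # `embeddings_data`, and `step` is that batch's size; both are fed in
      # `on_epoch_end`.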
      self.batch_id = batch_id = array_ops.placeholder(dtypes.int32)
      self.step = step = array_ops.placeholder(dtypes.int32)

      for layer in self.model.layers:
        if layer.name in embeddings_layer_names:
          embedding_input = self.model.get_layer(layer.name).output
          embedding_size = np.prod(embedding_input.shape[1:])
          embedding_input = array_ops.reshape(embedding_input,
                                              (step, int(embedding_size)))
          shape = (self.embeddings_data[0].shape[0], int(embedding_size))
          embedding = variables.Variable(
              array_ops.zeros(shape), name=layer.name + '_embedding')
          embeddings_vars[layer.name] = embedding
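          # Copy this batch's activations into the persistent embedding
          # variable at the batch offset, assembling the full matrix
          # incrementally across batches.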
          batch = state_ops.assign(embedding[batch_id:batch_id + step],
                                   embedding_input)
          self.assign_embeddings.append(batch)

      self.saver = saver.Saver(list(embeddings_vars.values()))

      # Create embeddings_metadata dictionary
      if isinstance(self.embeddings_metadata, str):
        embeddings_metadata = {
            layer_name: self.embeddings_metadata
            for layer_name in embeddings_vars.keys()
        }
      else:
        # If embeddings_metadata is already a dictionary
        embeddings_metadata = self.embeddings_metadata

      try:
        from tensorboard.plugins import projector
      except ImportError:
        raise ImportError('Failed to import TensorBoard. Please make sure that '
                          'TensorBoard integration is complete.')

      # TODO(psv): Add integration tests to test embedding visualization
      # with TensorBoard callback. We are unable to write a unit test for this
      # because TensorBoard dependency assumes TensorFlow package is installed.
      config = projector.ProjectorConfig()
      for layer_name, tensor in embeddings_vars.items():
        embedding = config.embeddings.add()
        embedding.tensor_name = tensor.name

        if (embeddings_metadata is not None and
            layer_name in embeddings_metadata):
          embedding.metadata_path = embeddings_metadata[layer_name]

      projector.visualize_embeddings(self.writer, config)

  def _fetch_callback(self, summary):
    self.writer.add_summary(summary, self._total_val_batches_seen)
    self._total_val_batches_seen += 1

  def _write_custom_summaries(self, step, logs=None):
    """Writes metrics out as custom scalar summaries.

    Arguments:
        step: the global step to use for TensorBoard.
        logs: dict. Keys are scalar summary names, values are
            NumPy scalars.
    """
    logs = logs or {}
    if context.executing_eagerly():
      # use v2 summary ops
      with self.writer.as_default(), summary_ops_v2.always_record_summaries():
        for name, value in logs.items():
          if isinstance(value, np.ndarray):
            value = value.item()
          summary_ops_v2.scalar(name, value, step=step)
    else:
      # use FileWriter from v1 summary
      for name, value in logs.items():
        if isinstance(value, np.ndarray):
          value = value.item()
        summary = tf_summary.Summary()
        summary_value = summary.value.add()
        summary_value.simple_value = value
        summary_value.tag = name
        self.writer.add_summary(summary, step)
    self.writer.flush()

  def on_batch_end(self, batch, logs=None):
    """Writes scalar summaries for metrics on every training batch.

    Performs profiling if the current batch is the one specified by
    `profile_batch`.
    """
    # Don't output batch_size and batch number as TensorBoard summaries
    logs = logs or {}
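    # `size` is the number of samples in the batch; fall back to 1 when it is
    # not reported in `logs`.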
    self._samples_seen += logs.get('size', 1)
    samples_seen_since = self._samples_seen - self._samples_seen_at_last_write
    if self.update_freq != 'epoch' and samples_seen_since >= self.update_freq:
      batch_logs = {('batch_' + k): v
                    for k, v in logs.items()
                    if k not in ['batch', 'size', 'num_steps']}
      self._write_custom_summaries(self._total_batches_seen, batch_logs)
      self._samples_seen_at_last_write = self._samples_seen
    self._total_batches_seen += 1
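    # The profiler is started at the end of the batch preceding
    # `profile_batch` and stopped once the next batch completes, so exactly
    # one batch is traced.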
    if self._is_profiling:
      profiler.save(self.log_dir, profiler.stop())
      self._is_profiling = False
    elif (not self._is_profiling and
          self._total_batches_seen == self._profile_batch - 1):
      profiler.start()
      self._is_profiling = True

  def on_train_begin(self, logs=None):
    if self._profile_batch == 1:
      profiler.start()
      self._is_profiling = True

  def on_epoch_begin(self, epoch, logs=None):
    """Adds the histogram summary op to the model's test function callbacks."""

    # check if histogram summary should be run for this epoch
    if self.histogram_freq and epoch % self.histogram_freq == 0:
      self._epoch = epoch
      # pylint: disable=protected-access
      # add the histogram summary op if it should run this epoch
      self.model._make_test_function()
      if self.merged not in self.model.test_function.fetches:
        self.model.test_function.fetches.append(self.merged)
        self.model.test_function.fetch_callbacks[
            self.merged] = self._fetch_callback
      # pylint: enable=protected-access

  def on_epoch_end(self, epoch, logs=None):
    """Checks if summary ops should run next epoch, logs scalar summaries."""

    # don't output batch_size and
    # batch number as TensorBoard summaries
    logs = {('epoch_' + k): v
            for k, v in logs.items()
            if k not in ['batch', 'size', 'num_steps']}
    if self.update_freq == 'epoch':
      step = epoch
    else:
      step = self._samples_seen
    self._write_custom_summaries(step, logs)

    # pop the histogram summary op after each epoch
    if self.histogram_freq:
      # pylint: disable=protected-access
      if self.merged in self.model.test_function.fetches:
        self.model.test_function.fetches.remove(self.merged)
      if self.merged in self.model.test_function.fetch_callbacks:
        self.model.test_function.fetch_callbacks.pop(self.merged)
      # pylint: enable=protected-access

    if self.embeddings_data is None and self.embeddings_freq:
      raise ValueError('To visualize embeddings, embeddings_data must '
                       'be provided.')

    if self.embeddings_freq and self.embeddings_data is not None:
      if epoch % self.embeddings_freq == 0:
        # We need a second forward pass here because we're passing
        # the `embeddings_data` explicitly. This design allows passing
        # arbitrary data as `embeddings_data` and results from the fact
        # that we need to know the size of the `tf.Variable`s which
        # hold the embeddings in `set_model`. At this point, however,
        # the `validation_data` is not yet set.

        embeddings_data = self.embeddings_data
        n_samples = embeddings_data[0].shape[0]
        i = 0
        sess = K.get_session()
        while i < n_samples:
          step = min(self.batch_size, n_samples - i)
          batch = slice(i, i + step)

          if isinstance(self.model.input, list):
            feed_dict = {
                model_input: embeddings_data[idx][batch]
                for idx, model_input in enumerate(self.model.input)
            }
          else:
            feed_dict = {self.model.input: embeddings_data[0][batch]}

          feed_dict.update({self.batch_id: i, self.step: step})

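          # If the learning phase is a placeholder (graph mode), feed it as
          # False so the forward pass runs in inference mode.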
          if not isinstance(K.learning_phase(), int):
            feed_dict[K.learning_phase()] = False

          sess.run(self.assign_embeddings, feed_dict=feed_dict)
          self.saver.save(sess,
                          os.path.join(self.log_dir, 'keras_embedding.ckpt'),
                          epoch)

          i += self.batch_size

  def on_train_end(self, logs=None):
    if self._is_profiling:
      profiler.save(self.log_dir, profiler.stop())
      self._is_profiling = False
    self.writer.close()