1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for Keras callbacks."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import csv
23import os
24import re
25import shutil
26import sys
27import threading
28import unittest
29
30from absl.testing import parameterized
31import numpy as np
32
33from tensorflow.python import keras
34from tensorflow.python.data.ops import dataset_ops
35from tensorflow.python.framework import random_seed
36from tensorflow.python.keras import keras_parameterized
37from tensorflow.python.keras import testing_utils
38from tensorflow.python.ops import array_ops
39from tensorflow.python.platform import test
40from tensorflow.python.platform import tf_logging as logging
41from tensorflow.python.summary import summary_iterator
42from tensorflow.python.training import adam
43
44try:
45  import h5py  # pylint:disable=g-import-not-at-top
46except ImportError:
47  h5py = None
48
49try:
50  import requests  # pylint:disable=g-import-not-at-top
51except ImportError:
52  requests = None
53
54
55TRAIN_SAMPLES = 10
56TEST_SAMPLES = 10
57NUM_CLASSES = 2
58INPUT_DIM = 3
59NUM_HIDDEN = 5
60BATCH_SIZE = 5
61
62
63class Counter(keras.callbacks.Callback):
64  """Counts the number of times each callback method was run.
65
66  Attributes:
67    method_counts: dict. Contains the counts of time  each callback method was
68      run.
69  """
70
71  def __init__(self):
72    self.method_counts = collections.defaultdict(int)
73    methods_to_count = [
74        'on_batch_begin', 'on_batch_end', 'on_epoch_begin', 'on_epoch_end',
75        'on_predict_batch_begin', 'on_predict_batch_end', 'on_predict_begin',
76        'on_predict_end', 'on_test_batch_begin', 'on_test_batch_end',
77        'on_test_begin', 'on_test_end', 'on_train_batch_begin',
78        'on_train_batch_end', 'on_train_begin', 'on_train_end'
79    ]
80    for method_name in methods_to_count:
81      setattr(self, method_name,
82              self.wrap_with_counts(method_name, getattr(self, method_name)))
83
84  def wrap_with_counts(self, method_name, method):
85
86    def _call_and_count(*args, **kwargs):
87      self.method_counts[method_name] += 1
88      return method(*args, **kwargs)
89
90    return _call_and_count
91
92
93def _get_numpy():
94  return np.ones((10, 10)), np.ones((10, 1))
95
96
97def _get_sequence():
98
99  class MySequence(keras.utils.data_utils.Sequence):
100
101    def __getitem__(self, _):
102      return np.ones((2, 10)), np.ones((2, 1))
103
104    def __len__(self):
105      return 5
106
107  return MySequence(), None
108
109
110@keras_parameterized.run_with_all_model_types
111@keras_parameterized.run_all_keras_modes
112class CallbackCountsTest(keras_parameterized.TestCase):
113
114  def _check_counts(self, counter, expected_counts):
115    """Checks that the counts registered by `counter` are those expected."""
116    for method_name, expected_count in expected_counts.items():
117      self.assertEqual(
118          counter.method_counts[method_name],
119          expected_count,
120          msg='For method {}: expected {}, got: {}'.format(
121              method_name, expected_count, counter.method_counts[method_name]))
122
123  def _get_model(self):
124    layers = [
125        keras.layers.Dense(10, activation='relu'),
126        keras.layers.Dense(1, activation='sigmoid')
127    ]
128    model = testing_utils.get_model_from_layers(layers, input_shape=(10,))
129    model.compile(
130        adam.AdamOptimizer(0.001),
131        'binary_crossentropy',
132        run_eagerly=testing_utils.should_run_eagerly())
133    return model
134
135  @parameterized.named_parameters(('with_numpy', _get_numpy()),
136                                  ('with_sequence', _get_sequence()))
137  def test_callback_hooks_are_called_in_fit(self, data):
138    x, y = data
139    val_x, val_y = np.ones((4, 10)), np.ones((4, 1))
140
141    model = self._get_model()
142    counter = Counter()
143    model.fit(
144        x,
145        y,
146        validation_data=(val_x, val_y),
147        batch_size=2,
148        epochs=5,
149        callbacks=[counter])
150
151    self._check_counts(
152        counter, {
153            'on_batch_begin': 25,
154            'on_batch_end': 25,
155            'on_epoch_begin': 5,
156            'on_epoch_end': 5,
157            'on_predict_batch_begin': 0,
158            'on_predict_batch_end': 0,
159            'on_predict_begin': 0,
160            'on_predict_end': 0,
161            'on_test_batch_begin': 10,
162            'on_test_batch_end': 10,
163            'on_test_begin': 5,
164            'on_test_end': 5,
165            'on_train_batch_begin': 25,
166            'on_train_batch_end': 25,
167            'on_train_begin': 1,
168            'on_train_end': 1
169        })
170
171  @parameterized.named_parameters(('with_numpy', _get_numpy()),
172                                  ('with_sequence', _get_sequence()))
173  def test_callback_hooks_are_called_in_evaluate(self, data):
174    x, y = data
175
176    model = self._get_model()
177    counter = Counter()
178    model.evaluate(x, y, batch_size=2, callbacks=[counter])
179    self._check_counts(
180        counter, {
181            'on_test_batch_begin': 5,
182            'on_test_batch_end': 5,
183            'on_test_begin': 1,
184            'on_test_end': 1
185        })
186
187  @parameterized.named_parameters(('with_numpy', _get_numpy()),
188                                  ('with_sequence', _get_sequence()))
189  def test_callback_hooks_are_called_in_predict(self, data):
190    x = data[0]
191
192    model = self._get_model()
193    counter = Counter()
194    model.predict(x, batch_size=2, callbacks=[counter])
195    self._check_counts(
196        counter, {
197            'on_predict_batch_begin': 5,
198            'on_predict_batch_end': 5,
199            'on_predict_begin': 1,
200            'on_predict_end': 1
201        })
202
203  def test_callback_list_methods(self):
204    counter = Counter()
205    callback_list = keras.callbacks.CallbackList([counter])
206
207    batch = 0
208    callback_list.on_test_batch_begin(batch)
209    callback_list.on_test_batch_end(batch)
210    callback_list.on_predict_batch_begin(batch)
211    callback_list.on_predict_batch_end(batch)
212
213    self._check_counts(
214        counter, {
215            'on_test_batch_begin': 1,
216            'on_test_batch_end': 1,
217            'on_predict_batch_begin': 1,
218            'on_predict_batch_end': 1
219        })
220
221
222class KerasCallbacksTest(keras_parameterized.TestCase):
223
224  def _get_model(self, input_shape=None):
225    layers = [
226        keras.layers.Dense(3, activation='relu'),
227        keras.layers.Dense(2, activation='softmax')
228    ]
229    model = testing_utils.get_model_from_layers(layers, input_shape=input_shape)
230    model.compile(
231        loss='mse',
232        optimizer='rmsprop',
233        metrics=[keras.metrics.CategoricalAccuracy(name='my_acc')],
234        run_eagerly=testing_utils.should_run_eagerly())
235    return model
236
237  @keras_parameterized.run_with_all_model_types
238  @keras_parameterized.run_all_keras_modes
239  def test_progbar_logging(self):
240    model = self._get_model(input_shape=(3,))
241
242    x = array_ops.ones((50, 3))
243    y = array_ops.zeros((50, 2))
244    dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(10)
245    expected_log = r'(.*- loss:.*- my_acc:.*)+'
246
247    with self.captureWritesToStream(sys.stdout) as printed:
248      model.fit(dataset, epochs=2, steps_per_epoch=10)
249      self.assertRegexpMatches(printed.contents(), expected_log)
250
251  @keras_parameterized.run_with_all_model_types(exclude_models='functional')
252  @keras_parameterized.run_all_keras_modes
253  def test_progbar_logging_deferred_model_build(self):
254    model = self._get_model()
255    self.assertFalse(model.built)
256
257    x = array_ops.ones((50, 3))
258    y = array_ops.zeros((50, 2))
259    dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(10)
260    expected_log = r'(.*- loss:.*- my_acc:.*)+'
261
262    with self.captureWritesToStream(sys.stdout) as printed:
263      model.fit(dataset, epochs=2, steps_per_epoch=10)
264      self.assertRegexpMatches(printed.contents(), expected_log)
265
266  @keras_parameterized.run_with_all_model_types
267  def test_ModelCheckpoint(self):
268    if h5py is None:
269      return  # Skip test if models cannot be saved.
270
271    layers = [
272        keras.layers.Dense(NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'),
273        keras.layers.Dense(NUM_CLASSES, activation='softmax')
274    ]
275    model = testing_utils.get_model_from_layers(layers, input_shape=(10,))
276    model.compile(
277        loss='categorical_crossentropy', optimizer='rmsprop', metrics=['acc'])
278
279    temp_dir = self.get_temp_dir()
280    self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
281
282    filepath = os.path.join(temp_dir, 'checkpoint')
283    (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
284        train_samples=TRAIN_SAMPLES,
285        test_samples=TEST_SAMPLES,
286        input_shape=(INPUT_DIM,),
287        num_classes=NUM_CLASSES)
288    y_test = keras.utils.to_categorical(y_test)
289    y_train = keras.utils.to_categorical(y_train)
290    # case 1
291    monitor = 'val_loss'
292    save_best_only = False
293    mode = 'auto'
294
295    model = keras.models.Sequential()
296    model.add(
297        keras.layers.Dense(
298            NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'))
299    model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax'))
300    model.compile(
301        loss='categorical_crossentropy', optimizer='rmsprop', metrics=['acc'])
302
303    cbks = [
304        keras.callbacks.ModelCheckpoint(
305            filepath,
306            monitor=monitor,
307            save_best_only=save_best_only,
308            mode=mode)
309    ]
310    model.fit(
311        x_train,
312        y_train,
313        batch_size=BATCH_SIZE,
314        validation_data=(x_test, y_test),
315        callbacks=cbks,
316        epochs=1,
317        verbose=0)
318    assert os.path.exists(filepath)
319    os.remove(filepath)
320
321    # case 2
322    mode = 'min'
323    cbks = [
324        keras.callbacks.ModelCheckpoint(
325            filepath,
326            monitor=monitor,
327            save_best_only=save_best_only,
328            mode=mode)
329    ]
330    model.fit(
331        x_train,
332        y_train,
333        batch_size=BATCH_SIZE,
334        validation_data=(x_test, y_test),
335        callbacks=cbks,
336        epochs=1,
337        verbose=0)
338    assert os.path.exists(filepath)
339    os.remove(filepath)
340
341    # case 3
342    mode = 'max'
343    monitor = 'val_acc'
344    cbks = [
345        keras.callbacks.ModelCheckpoint(
346            filepath,
347            monitor=monitor,
348            save_best_only=save_best_only,
349            mode=mode)
350    ]
351    model.fit(
352        x_train,
353        y_train,
354        batch_size=BATCH_SIZE,
355        validation_data=(x_test, y_test),
356        callbacks=cbks,
357        epochs=1,
358        verbose=0)
359    assert os.path.exists(filepath)
360    os.remove(filepath)
361
362    # case 4
363    save_best_only = True
364    cbks = [
365        keras.callbacks.ModelCheckpoint(
366            filepath,
367            monitor=monitor,
368            save_best_only=save_best_only,
369            mode=mode)
370    ]
371    model.fit(
372        x_train,
373        y_train,
374        batch_size=BATCH_SIZE,
375        validation_data=(x_test, y_test),
376        callbacks=cbks,
377        epochs=1,
378        verbose=0)
379    assert os.path.exists(filepath)
380    os.remove(filepath)
381
382    # Case: metric not available.
383    cbks = [
384        keras.callbacks.ModelCheckpoint(
385            filepath,
386            monitor='unknown',
387            save_best_only=True)
388    ]
389    model.fit(
390        x_train,
391        y_train,
392        batch_size=BATCH_SIZE,
393        validation_data=(x_test, y_test),
394        callbacks=cbks,
395        epochs=1,
396        verbose=0)
397    # File won't be written.
398    assert not os.path.exists(filepath)
399
400    # case 5
401    save_best_only = False
402    period = 2
403    mode = 'auto'
404
405    filepath = os.path.join(temp_dir, 'checkpoint.{epoch:02d}.h5')
406    cbks = [
407        keras.callbacks.ModelCheckpoint(
408            filepath,
409            monitor=monitor,
410            save_best_only=save_best_only,
411            mode=mode,
412            period=period)
413    ]
414    model.fit(
415        x_train,
416        y_train,
417        batch_size=BATCH_SIZE,
418        validation_data=(x_test, y_test),
419        callbacks=cbks,
420        epochs=4,
421        verbose=1)
422    assert os.path.exists(filepath.format(epoch=2))
423    assert os.path.exists(filepath.format(epoch=4))
424    os.remove(filepath.format(epoch=2))
425    os.remove(filepath.format(epoch=4))
426    assert not os.path.exists(filepath.format(epoch=1))
427    assert not os.path.exists(filepath.format(epoch=3))
428
429    # Invalid use: this will raise a warning but not an Exception.
430    keras.callbacks.ModelCheckpoint(
431        filepath,
432        monitor=monitor,
433        save_best_only=save_best_only,
434        mode='unknown')
435
436  def test_EarlyStopping(self):
437    with self.cached_session():
438      np.random.seed(123)
439      (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
440          train_samples=TRAIN_SAMPLES,
441          test_samples=TEST_SAMPLES,
442          input_shape=(INPUT_DIM,),
443          num_classes=NUM_CLASSES)
444      y_test = keras.utils.to_categorical(y_test)
445      y_train = keras.utils.to_categorical(y_train)
446      model = testing_utils.get_small_sequential_mlp(
447          num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM)
448      model.compile(
449          loss='categorical_crossentropy', optimizer='rmsprop', metrics=['acc'])
450
451      cases = [
452          ('max', 'val_acc'),
453          ('min', 'val_loss'),
454          ('auto', 'val_acc'),
455          ('auto', 'loss'),
456          ('unknown', 'unknown')
457      ]
458      for mode, monitor in cases:
459        patience = 0
460        cbks = [
461            keras.callbacks.EarlyStopping(
462                patience=patience, monitor=monitor, mode=mode)
463        ]
464        model.fit(
465            x_train,
466            y_train,
467            batch_size=BATCH_SIZE,
468            validation_data=(x_test, y_test),
469            callbacks=cbks,
470            epochs=5,
471            verbose=0)
472
473  def test_EarlyStopping_reuse(self):
474    with self.cached_session():
475      np.random.seed(1337)
476      patience = 3
477      data = np.random.random((100, 1))
478      labels = np.where(data > 0.5, 1, 0)
479      model = keras.models.Sequential((keras.layers.Dense(
480          1, input_dim=1, activation='relu'), keras.layers.Dense(
481              1, activation='sigmoid'),))
482      model.compile(
483          optimizer='sgd', loss='binary_crossentropy', metrics=['accuracy'])
484      weights = model.get_weights()
485
486      stopper = keras.callbacks.EarlyStopping(monitor='acc', patience=patience)
487      hist = model.fit(data, labels, callbacks=[stopper], verbose=0, epochs=20)
488      assert len(hist.epoch) >= patience
489
490      # This should allow training to go for at least `patience` epochs
491      model.set_weights(weights)
492      hist = model.fit(data, labels, callbacks=[stopper], verbose=0, epochs=20)
493      assert len(hist.epoch) >= patience
494
495  def test_EarlyStopping_with_baseline(self):
496    with self.cached_session():
497      np.random.seed(1337)
498      baseline = 0.5
499      (data, labels), _ = testing_utils.get_test_data(
500          train_samples=100,
501          test_samples=50,
502          input_shape=(1,),
503          num_classes=NUM_CLASSES)
504      model = testing_utils.get_small_sequential_mlp(
505          num_hidden=1, num_classes=1, input_dim=1)
506      model.compile(
507          optimizer='sgd', loss='binary_crossentropy', metrics=['acc'])
508
509      stopper = keras.callbacks.EarlyStopping(monitor='acc',
510                                              baseline=baseline)
511      hist = model.fit(data, labels, callbacks=[stopper], verbose=0, epochs=20)
512      assert len(hist.epoch) == 1
513
514      patience = 3
515      stopper = keras.callbacks.EarlyStopping(monitor='acc',
516                                              patience=patience,
517                                              baseline=baseline)
518      hist = model.fit(data, labels, callbacks=[stopper], verbose=0, epochs=20)
519      assert len(hist.epoch) >= patience
520
521  def test_EarlyStopping_final_weights_when_restoring_model_weights(self):
522
523    class DummyModel(object):
524
525      def __init__(self):
526        self.stop_training = False
527        self.weights = -1
528
529      def get_weights(self):
530        return self.weights
531
532      def set_weights(self, weights):
533        self.weights = weights
534
535      def set_weight_to_epoch(self, epoch):
536        self.weights = epoch
537
538    early_stop = keras.callbacks.EarlyStopping(monitor='val_loss',
539                                               patience=2,
540                                               restore_best_weights=True)
541    early_stop.model = DummyModel()
542    losses = [0.2, 0.15, 0.1, 0.11, 0.12]
543    # The best configuration is in the epoch 2 (loss = 0.1000).
544    epochs_trained = 0
545    early_stop.on_train_begin()
546    for epoch in range(len(losses)):
547      epochs_trained += 1
548      early_stop.model.set_weight_to_epoch(epoch=epoch)
549      early_stop.on_epoch_end(epoch, logs={'val_loss': losses[epoch]})
550      if early_stop.model.stop_training:
551        break
552    # The best configuration is in epoch 2 (loss = 0.1000),
553    # and while patience = 2, we're restoring the best weights,
554    # so we end up at the epoch with the best weights, i.e. epoch 2
555    self.assertEqual(early_stop.model.get_weights(), 2)
556
557  def test_RemoteMonitor(self):
558    if requests is None:
559      return
560
561    monitor = keras.callbacks.RemoteMonitor()
562    # This will raise a warning since the default address in unreachable:
563    monitor.on_epoch_end(0, logs={'loss': 0.})
564
565  def test_LearningRateScheduler(self):
566    with self.cached_session():
567      np.random.seed(1337)
568      (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
569          train_samples=TRAIN_SAMPLES,
570          test_samples=TEST_SAMPLES,
571          input_shape=(INPUT_DIM,),
572          num_classes=NUM_CLASSES)
573      y_test = keras.utils.to_categorical(y_test)
574      y_train = keras.utils.to_categorical(y_train)
575      model = testing_utils.get_small_sequential_mlp(
576          num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM)
577      model.compile(
578          loss='categorical_crossentropy',
579          optimizer='sgd',
580          metrics=['accuracy'])
581
582      cbks = [keras.callbacks.LearningRateScheduler(lambda x: 1. / (1. + x))]
583      model.fit(
584          x_train,
585          y_train,
586          batch_size=BATCH_SIZE,
587          validation_data=(x_test, y_test),
588          callbacks=cbks,
589          epochs=5,
590          verbose=0)
591      assert (
592          float(keras.backend.get_value(
593              model.optimizer.lr)) - 0.2) < keras.backend.epsilon()
594
595      cbks = [keras.callbacks.LearningRateScheduler(lambda x, lr: lr / 2)]
596      model.compile(
597          loss='categorical_crossentropy',
598          optimizer='sgd',
599          metrics=['accuracy'])
600      model.fit(
601          x_train,
602          y_train,
603          batch_size=BATCH_SIZE,
604          validation_data=(x_test, y_test),
605          callbacks=cbks,
606          epochs=2,
607          verbose=0)
608      assert (
609          float(keras.backend.get_value(
610              model.optimizer.lr)) - 0.01 / 4) < keras.backend.epsilon()
611
612  def test_ReduceLROnPlateau(self):
613    with self.cached_session():
614      np.random.seed(1337)
615      (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
616          train_samples=TRAIN_SAMPLES,
617          test_samples=TEST_SAMPLES,
618          input_shape=(INPUT_DIM,),
619          num_classes=NUM_CLASSES)
620      y_test = keras.utils.to_categorical(y_test)
621      y_train = keras.utils.to_categorical(y_train)
622
623      def make_model():
624        random_seed.set_random_seed(1234)
625        np.random.seed(1337)
626        model = testing_utils.get_small_sequential_mlp(
627            num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM)
628        model.compile(
629            loss='categorical_crossentropy',
630            optimizer=keras.optimizers.SGD(lr=0.1))
631        return model
632
633      # TODO(psv): Make sure the callback works correctly when min_delta is
634      # set as 0. Test fails when the order of this callback and assertion is
635      # interchanged.
636      model = make_model()
637      cbks = [
638          keras.callbacks.ReduceLROnPlateau(
639              monitor='val_loss',
640              factor=0.1,
641              min_delta=0,
642              patience=1,
643              cooldown=5)
644      ]
645      model.fit(
646          x_train,
647          y_train,
648          batch_size=BATCH_SIZE,
649          validation_data=(x_test, y_test),
650          callbacks=cbks,
651          epochs=2,
652          verbose=0)
653      self.assertAllClose(
654          float(keras.backend.get_value(model.optimizer.lr)), 0.1, atol=1e-4)
655
656      model = make_model()
657      # This should reduce the LR after the first epoch (due to high epsilon).
658      cbks = [
659          keras.callbacks.ReduceLROnPlateau(
660              monitor='val_loss',
661              factor=0.1,
662              min_delta=10,
663              patience=1,
664              cooldown=5)
665      ]
666      model.fit(
667          x_train,
668          y_train,
669          batch_size=BATCH_SIZE,
670          validation_data=(x_test, y_test),
671          callbacks=cbks,
672          epochs=2,
673          verbose=2)
674      self.assertAllClose(
675          float(keras.backend.get_value(model.optimizer.lr)), 0.01, atol=1e-4)
676
677  def test_ReduceLROnPlateau_patience(self):
678
679    class DummyOptimizer(object):
680
681      def __init__(self):
682        self.lr = keras.backend.variable(1.0)
683
684    class DummyModel(object):
685
686      def __init__(self):
687        self.optimizer = DummyOptimizer()
688
689    reduce_on_plateau = keras.callbacks.ReduceLROnPlateau(
690        monitor='val_loss', patience=2)
691    reduce_on_plateau.model = DummyModel()
692
693    losses = [0.0860, 0.1096, 0.1040]
694    lrs = []
695
696    for epoch in range(len(losses)):
697      reduce_on_plateau.on_epoch_end(epoch, logs={'val_loss': losses[epoch]})
698      lrs.append(keras.backend.get_value(reduce_on_plateau.model.optimizer.lr))
699
700    # The learning rates should be 1.0 except the last one
701    for lr in lrs[:-1]:
702      self.assertEqual(lr, 1.0)
703    self.assertLess(lrs[-1], 1.0)
704
705  def test_ReduceLROnPlateau_backwards_compatibility(self):
706    with test.mock.patch.object(logging, 'warning') as mock_log:
707      reduce_on_plateau = keras.callbacks.ReduceLROnPlateau(epsilon=1e-13)
708      self.assertRegexpMatches(
709          str(mock_log.call_args), '`epsilon` argument is deprecated')
710    self.assertFalse(hasattr(reduce_on_plateau, 'epsilon'))
711    self.assertTrue(hasattr(reduce_on_plateau, 'min_delta'))
712    self.assertEqual(reduce_on_plateau.min_delta, 1e-13)
713
714  def test_CSVLogger(self):
715    with self.cached_session():
716      np.random.seed(1337)
717      temp_dir = self.get_temp_dir()
718      self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
719      filepath = os.path.join(temp_dir, 'log.tsv')
720
721      sep = '\t'
722      (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
723          train_samples=TRAIN_SAMPLES,
724          test_samples=TEST_SAMPLES,
725          input_shape=(INPUT_DIM,),
726          num_classes=NUM_CLASSES)
727      y_test = keras.utils.to_categorical(y_test)
728      y_train = keras.utils.to_categorical(y_train)
729
730      def make_model():
731        np.random.seed(1337)
732        model = testing_utils.get_small_sequential_mlp(
733            num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM)
734        model.compile(
735            loss='categorical_crossentropy',
736            optimizer=keras.optimizers.SGD(lr=0.1),
737            metrics=['accuracy'])
738        return model
739
740      # case 1, create new file with defined separator
741      model = make_model()
742      cbks = [keras.callbacks.CSVLogger(filepath, separator=sep)]
743      model.fit(
744          x_train,
745          y_train,
746          batch_size=BATCH_SIZE,
747          validation_data=(x_test, y_test),
748          callbacks=cbks,
749          epochs=1,
750          verbose=0)
751
752      assert os.path.exists(filepath)
753      with open(filepath) as csvfile:
754        dialect = csv.Sniffer().sniff(csvfile.read())
755      assert dialect.delimiter == sep
756      del model
757      del cbks
758
759      # case 2, append data to existing file, skip header
760      model = make_model()
761      cbks = [keras.callbacks.CSVLogger(filepath, separator=sep, append=True)]
762      model.fit(
763          x_train,
764          y_train,
765          batch_size=BATCH_SIZE,
766          validation_data=(x_test, y_test),
767          callbacks=cbks,
768          epochs=1,
769          verbose=0)
770
771      # case 3, reuse of CSVLogger object
772      model.fit(
773          x_train,
774          y_train,
775          batch_size=BATCH_SIZE,
776          validation_data=(x_test, y_test),
777          callbacks=cbks,
778          epochs=2,
779          verbose=0)
780
781      with open(filepath) as csvfile:
782        list_lines = csvfile.readlines()
783        for line in list_lines:
784          assert line.count(sep) == 4
785        assert len(list_lines) == 5
786        output = ' '.join(list_lines)
787        assert len(re.findall('epoch', output)) == 1
788
789      os.remove(filepath)
790
791  def test_stop_training_csv(self):
792    # Test that using the CSVLogger callback with the TerminateOnNaN callback
793    # does not result in invalid CSVs.
794    np.random.seed(1337)
795    tmpdir = self.get_temp_dir()
796    self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)
797
798    with self.cached_session():
799      fp = os.path.join(tmpdir, 'test.csv')
800      (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
801          train_samples=TRAIN_SAMPLES,
802          test_samples=TEST_SAMPLES,
803          input_shape=(INPUT_DIM,),
804          num_classes=NUM_CLASSES)
805
806      y_test = keras.utils.to_categorical(y_test)
807      y_train = keras.utils.to_categorical(y_train)
808      cbks = [keras.callbacks.TerminateOnNaN(), keras.callbacks.CSVLogger(fp)]
809      model = keras.models.Sequential()
810      for _ in range(5):
811        model.add(keras.layers.Dense(2, input_dim=INPUT_DIM, activation='relu'))
812      model.add(keras.layers.Dense(NUM_CLASSES, activation='linear'))
813      model.compile(loss='mean_squared_error',
814                    optimizer='rmsprop')
815
816      def data_generator():
817        i = 0
818        max_batch_index = len(x_train) // BATCH_SIZE
819        tot = 0
820        while 1:
821          if tot > 3 * len(x_train):
822            yield (np.ones([BATCH_SIZE, INPUT_DIM]) * np.nan,
823                   np.ones([BATCH_SIZE, NUM_CLASSES]) * np.nan)
824          else:
825            yield (x_train[i * BATCH_SIZE: (i + 1) * BATCH_SIZE],
826                   y_train[i * BATCH_SIZE: (i + 1) * BATCH_SIZE])
827          i += 1
828          tot += 1
829          i %= max_batch_index
830
831      history = model.fit_generator(data_generator(),
832                                    len(x_train) // BATCH_SIZE,
833                                    validation_data=(x_test, y_test),
834                                    callbacks=cbks,
835                                    epochs=20)
836      loss = history.history['loss']
837      assert len(loss) > 1
838      assert loss[-1] == np.inf or np.isnan(loss[-1])
839
840      values = []
841      with open(fp) as f:
842        for x in csv.reader(f):
843          # In windows, due to \r\n line ends we may end up reading empty lines
844          # after each line. Skip empty lines.
845          if x:
846            values.append(x)
847      assert 'nan' in values[-1], 'The last epoch was not logged.'
848
849  def test_TerminateOnNaN(self):
850    with self.cached_session():
851      np.random.seed(1337)
852      (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
853          train_samples=TRAIN_SAMPLES,
854          test_samples=TEST_SAMPLES,
855          input_shape=(INPUT_DIM,),
856          num_classes=NUM_CLASSES)
857
858      y_test = keras.utils.to_categorical(y_test)
859      y_train = keras.utils.to_categorical(y_train)
860      cbks = [keras.callbacks.TerminateOnNaN()]
861      model = keras.models.Sequential()
862      initializer = keras.initializers.Constant(value=1e5)
863      for _ in range(5):
864        model.add(
865            keras.layers.Dense(
866                2,
867                input_dim=INPUT_DIM,
868                activation='relu',
869                kernel_initializer=initializer))
870      model.add(keras.layers.Dense(NUM_CLASSES))
871      model.compile(loss='mean_squared_error', optimizer='rmsprop')
872
873      history = model.fit(
874          x_train,
875          y_train,
876          batch_size=BATCH_SIZE,
877          validation_data=(x_test, y_test),
878          callbacks=cbks,
879          epochs=20)
880      loss = history.history['loss']
881      self.assertEqual(len(loss), 1)
882      self.assertEqual(loss[0], np.inf)
883
884  @unittest.skipIf(
885      os.name == 'nt',
886      'use_multiprocessing=True does not work on windows properly.')
887  def test_LambdaCallback(self):
888    with self.cached_session():
889      np.random.seed(1337)
890      (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
891          train_samples=TRAIN_SAMPLES,
892          test_samples=TEST_SAMPLES,
893          input_shape=(INPUT_DIM,),
894          num_classes=NUM_CLASSES)
895      y_test = keras.utils.to_categorical(y_test)
896      y_train = keras.utils.to_categorical(y_train)
897      model = keras.models.Sequential()
898      model.add(
899          keras.layers.Dense(
900              NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'))
901      model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax'))
902      model.compile(
903          loss='categorical_crossentropy',
904          optimizer='sgd',
905          metrics=['accuracy'])
906
907      # Start an arbitrary process that should run during model
908      # training and be terminated after training has completed.
909      e = threading.Event()
910
911      def target():
912        e.wait()
913
914      t = threading.Thread(target=target)
915      t.start()
916      cleanup_callback = keras.callbacks.LambdaCallback(
917          on_train_end=lambda logs: e.set())
918
919      cbks = [cleanup_callback]
920      model.fit(
921          x_train,
922          y_train,
923          batch_size=BATCH_SIZE,
924          validation_data=(x_test, y_test),
925          callbacks=cbks,
926          epochs=5,
927          verbose=0)
928      t.join()
929      assert not t.is_alive()
930
931  def test_RemoteMonitorWithJsonPayload(self):
932    if requests is None:
933      self.skipTest('`requests` required to run this test')
934    with self.cached_session():
935      (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
936          train_samples=TRAIN_SAMPLES,
937          test_samples=TEST_SAMPLES,
938          input_shape=(INPUT_DIM,),
939          num_classes=NUM_CLASSES)
940      y_test = keras.utils.np_utils.to_categorical(y_test)
941      y_train = keras.utils.np_utils.to_categorical(y_train)
942      model = keras.models.Sequential()
943      model.add(
944          keras.layers.Dense(
945              NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'))
946      model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax'))
947      model.compile(
948          loss='categorical_crossentropy',
949          optimizer='rmsprop',
950          metrics=['accuracy'])
951      cbks = [keras.callbacks.RemoteMonitor(send_as_json=True)]
952
953      with test.mock.patch.object(requests, 'post'):
954        model.fit(
955            x_train,
956            y_train,
957            batch_size=BATCH_SIZE,
958            validation_data=(x_test, y_test),
959            callbacks=cbks,
960            epochs=1)
961
962
963# A summary that was emitted during a test. Fields:
964#   logdir: str. The logdir of the FileWriter to which the summary was
965#     written.
966#   tag: str. The name of the summary.
967_ObservedSummary = collections.namedtuple('_ObservedSummary', ('logdir', 'tag'))
968
969
970class _SummaryFile(object):
971  """A record of summary tags and the files to which they were written.
972
973  Fields `scalars`, `images`, `histograms`, and `tensors` are sets
974  containing `_ObservedSummary` values.
975  """
976
977  def __init__(self):
978    self.scalars = set()
979    self.images = set()
980    self.histograms = set()
981    self.tensors = set()
982
983
984def list_summaries(logdir):
985  """Read all summaries under the logdir into a `_SummaryFile`.
986
987  Args:
988    logdir: A path to a directory that contains zero or more event
989      files, either as direct children or in transitive subdirectories.
990      Summaries in these events must only contain old-style scalars,
991      images, and histograms. Non-summary events, like `graph_def`s, are
992      ignored.
993
994  Returns:
995    A `_SummaryFile` object reflecting all summaries written to any
996    event files in the logdir or any of its descendant directories.
997
998  Raises:
999    ValueError: If an event file contains an summary of unexpected kind.
1000  """
1001  result = _SummaryFile()
1002  for (dirpath, dirnames, filenames) in os.walk(logdir):
1003    del dirnames  # unused
1004    for filename in filenames:
1005      if not filename.startswith('events.out.'):
1006        continue
1007      path = os.path.join(dirpath, filename)
1008      for event in summary_iterator.summary_iterator(path):
1009        if not event.summary:  # (e.g., it's a `graph_def` event)
1010          continue
1011        for value in event.summary.value:
1012          tag = value.tag
1013          # Case on the `value` rather than the summary metadata because
1014          # the Keras callback uses `summary_ops_v2` to emit old-style
1015          # summaries. See b/124535134.
1016          kind = value.WhichOneof('value')
1017          container = {
1018              'simple_value': result.scalars,
1019              'image': result.images,
1020              'histo': result.histograms,
1021              'tensor': result.tensors,
1022          }.get(kind)
1023          if container is None:
1024            raise ValueError(
1025                'Unexpected summary kind %r in event file %s:\n%r'
1026                % (kind, path, event))
1027          container.add(_ObservedSummary(logdir=dirpath, tag=tag))
1028  return result
1029
1030
1031@keras_parameterized.run_with_all_model_types
1032@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
1033class TestTensorBoardV2(keras_parameterized.TestCase):
1034
1035  def setUp(self):
1036    super(TestTensorBoardV2, self).setUp()
1037    self.logdir = os.path.join(self.get_temp_dir(), 'tb')
1038    self.train_dir = os.path.join(self.logdir, 'train')
1039    self.validation_dir = os.path.join(self.logdir, 'validation')
1040
1041  def _get_model(self):
1042    layers = [
1043        keras.layers.Conv2D(8, (3, 3)),
1044        keras.layers.Flatten(),
1045        keras.layers.Dense(1)
1046    ]
1047    model = testing_utils.get_model_from_layers(layers, input_shape=(10, 10, 1))
1048    model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
1049    return model
1050
1051  def test_TensorBoard_default_logdir(self):
1052    """Regression test for cross-platform pathsep in default logdir."""
1053    os.chdir(self.get_temp_dir())
1054
1055    model = self._get_model()
1056    x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
1057    tb_cbk = keras.callbacks.TensorBoard()  # no logdir specified
1058
1059    model.fit(
1060        x,
1061        y,
1062        batch_size=2,
1063        epochs=2,
1064        validation_data=(x, y),
1065        callbacks=[tb_cbk])
1066
1067    summary_file = list_summaries(logdir='.')
1068    train_dir = os.path.join('.', 'logs', 'train')
1069    validation_dir = os.path.join('.', 'logs', 'validation')
1070    self.assertEqual(
1071        summary_file.scalars, {
1072            _ObservedSummary(logdir=train_dir, tag='epoch_loss'),
1073            _ObservedSummary(logdir=validation_dir, tag='epoch_loss'),
1074        })
1075
1076  def test_TensorBoard_basic(self):
1077    model = self._get_model()
1078    x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
1079    tb_cbk = keras.callbacks.TensorBoard(self.logdir)
1080
1081    model.fit(
1082        x,
1083        y,
1084        batch_size=2,
1085        epochs=2,
1086        validation_data=(x, y),
1087        callbacks=[tb_cbk])
1088
1089    summary_file = list_summaries(self.logdir)
1090    self.assertEqual(
1091        summary_file.scalars, {
1092            _ObservedSummary(logdir=self.train_dir, tag='epoch_loss'),
1093            _ObservedSummary(logdir=self.validation_dir, tag='epoch_loss'),
1094        })
1095
1096  def test_TensorBoard_across_invocations(self):
1097    """Regression test for summary writer resource use-after-free.
1098
1099    See: <https://github.com/tensorflow/tensorflow/issues/25707>
1100    """
1101    model = self._get_model()
1102    x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
1103    tb_cbk = keras.callbacks.TensorBoard(self.logdir)
1104
1105    for _ in (1, 2):
1106      model.fit(
1107          x,
1108          y,
1109          batch_size=2,
1110          epochs=2,
1111          validation_data=(x, y),
1112          callbacks=[tb_cbk])
1113
1114    summary_file = list_summaries(self.logdir)
1115    self.assertEqual(
1116        summary_file.scalars, {
1117            _ObservedSummary(logdir=self.train_dir, tag='epoch_loss'),
1118            _ObservedSummary(logdir=self.validation_dir, tag='epoch_loss'),
1119        })
1120
1121  def test_TensorBoard_no_spurious_event_files(self):
1122    model = self._get_model()
1123    x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
1124    tb_cbk = keras.callbacks.TensorBoard(self.logdir)
1125
1126    model.fit(
1127        x,
1128        y,
1129        batch_size=2,
1130        epochs=2,
1131        callbacks=[tb_cbk])
1132
1133    events_file_run_basenames = set()
1134    for (dirpath, dirnames, filenames) in os.walk(self.logdir):
1135      del dirnames  # unused
1136      if any(fn.startswith('events.out.') for fn in filenames):
1137        events_file_run_basenames.add(os.path.basename(dirpath))
1138    self.assertEqual(events_file_run_basenames, {'train'})
1139
1140  def test_TensorBoard_batch_metrics(self):
1141    model = self._get_model()
1142    x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
1143    tb_cbk = keras.callbacks.TensorBoard(self.logdir, update_freq=1)
1144
1145    model.fit(
1146        x,
1147        y,
1148        batch_size=2,
1149        epochs=2,
1150        validation_data=(x, y),
1151        callbacks=[tb_cbk])
1152
1153    summary_file = list_summaries(self.logdir)
1154    self.assertEqual(
1155        summary_file.scalars,
1156        {
1157            _ObservedSummary(logdir=self.train_dir, tag='batch_loss'),
1158            _ObservedSummary(logdir=self.train_dir, tag='epoch_loss'),
1159            _ObservedSummary(logdir=self.validation_dir, tag='epoch_loss'),
1160        },
1161    )
1162
1163  def test_TensorBoard_weight_histograms(self):
1164    model = self._get_model()
1165    x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
1166    tb_cbk = keras.callbacks.TensorBoard(self.logdir, histogram_freq=1)
1167    model_type = testing_utils.get_model_type()
1168
1169    model.fit(
1170        x,
1171        y,
1172        batch_size=2,
1173        epochs=2,
1174        validation_data=(x, y),
1175        callbacks=[tb_cbk])
1176    summary_file = list_summaries(self.logdir)
1177
1178    self.assertEqual(
1179        summary_file.scalars,
1180        {
1181            _ObservedSummary(logdir=self.train_dir, tag='epoch_loss'),
1182            _ObservedSummary(logdir=self.validation_dir, tag='epoch_loss'),
1183        },
1184    )
1185    self.assertEqual(
1186        self._strip_layer_names(summary_file.histograms, model_type),
1187        {
1188            _ObservedSummary(logdir=self.train_dir, tag='bias_0'),
1189            _ObservedSummary(logdir=self.train_dir, tag='kernel_0'),
1190        },
1191    )
1192
1193  def test_TensorBoard_weight_images(self):
1194    model = self._get_model()
1195    x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
1196    tb_cbk = keras.callbacks.TensorBoard(
1197        self.logdir, histogram_freq=1, write_images=True)
1198    model_type = testing_utils.get_model_type()
1199
1200    model.fit(
1201        x,
1202        y,
1203        batch_size=2,
1204        epochs=2,
1205        validation_data=(x, y),
1206        callbacks=[tb_cbk])
1207    summary_file = list_summaries(self.logdir)
1208
1209    self.assertEqual(
1210        summary_file.scalars,
1211        {
1212            _ObservedSummary(logdir=self.train_dir, tag='epoch_loss'),
1213            _ObservedSummary(logdir=self.validation_dir, tag='epoch_loss'),
1214        },
1215    )
1216    self.assertEqual(
1217        self._strip_layer_names(summary_file.histograms, model_type),
1218        {
1219            _ObservedSummary(logdir=self.train_dir, tag='bias_0'),
1220            _ObservedSummary(logdir=self.train_dir, tag='kernel_0'),
1221        },
1222    )
1223    self.assertEqual(
1224        self._strip_layer_names(summary_file.images, model_type),
1225        {
1226            _ObservedSummary(logdir=self.train_dir, tag='bias_0/image/0'),
1227            _ObservedSummary(logdir=self.train_dir, tag='kernel_0/image/0'),
1228            _ObservedSummary(logdir=self.train_dir, tag='kernel_0/image/1'),
1229            _ObservedSummary(logdir=self.train_dir, tag='kernel_0/image/2'),
1230        },
1231    )
1232
1233  def _strip_layer_names(self, summaries, model_type):
1234    """Deduplicate summary names modulo layer prefix.
1235
1236    This removes the first slash-component of each tag name: for
1237    instance, "foo/bar/baz" becomes "bar/baz".
1238
1239    Args:
1240      summaries: A `set` of `_ObservedSummary` values.
1241      model_type: The model type currently being tested.
1242
1243    Returns:
1244      A new `set` of `_ObservedSummary` values with layer prefixes
1245      removed.
1246    """
1247    result = set()
1248    for summary in summaries:
1249      if '/' not in summary.tag:
1250        raise ValueError('tag has no layer name: %r' % summary.tag)
1251      start_from = 2 if 'subclass' in model_type else 1
1252      new_tag = '/'.join(summary.tag.split('/')[start_from:])
1253      result.add(summary._replace(tag=new_tag))
1254    return result
1255
1256  def test_TensorBoard_invalid_argument(self):
1257    with self.assertRaisesRegexp(ValueError, 'Unrecognized arguments'):
1258      keras.callbacks.TensorBoard(wwrite_images=True)
1259
1260
1261# Note that this test specifies model_type explicitly.
1262@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
1263class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
1264
1265  def setUp(self):
1266    super(TestTensorBoardV2NonParameterizedTest, self).setUp()
1267    self.logdir = os.path.join(self.get_temp_dir(), 'tb')
1268    self.train_dir = os.path.join(self.logdir, 'train')
1269    self.validation_dir = os.path.join(self.logdir, 'validation')
1270
1271  def _get_seq_model(self):
1272    model = keras.models.Sequential([
1273        keras.layers.Conv2D(8, (3, 3), input_shape=(10, 10, 1)),
1274        keras.layers.Flatten(),
1275        keras.layers.Dense(1),
1276    ])
1277    model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
1278    return model
1279
1280  def fitModelAndAssertKerasModelWritten(self, model):
1281    x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
1282    tb_cbk = keras.callbacks.TensorBoard(self.logdir,
1283                                         write_graph=True,
1284                                         profile_batch=0)
1285    model.fit(
1286        x,
1287        y,
1288        batch_size=2,
1289        epochs=2,
1290        validation_data=(x, y),
1291        callbacks=[tb_cbk])
1292    summary_file = list_summaries(self.logdir)
1293    self.assertEqual(
1294        summary_file.tensors,
1295        {
1296            _ObservedSummary(logdir=self.train_dir, tag='keras'),
1297        },
1298    )
1299
1300  def test_TensorBoard_writeSequentialModel_noInputShape(self):
1301    model = keras.models.Sequential([
1302        keras.layers.Conv2D(8, (3, 3)),
1303        keras.layers.Flatten(),
1304        keras.layers.Dense(1),
1305    ])
1306    model.compile('sgd', 'mse', run_eagerly=False)
1307    self.fitModelAndAssertKerasModelWritten(model)
1308
1309  def test_TensorBoard_writeSequentialModel_withInputShape(self):
1310    model = keras.models.Sequential([
1311        keras.layers.Conv2D(8, (3, 3), input_shape=(10, 10, 1)),
1312        keras.layers.Flatten(),
1313        keras.layers.Dense(1),
1314    ])
1315    model.compile('sgd', 'mse', run_eagerly=False)
1316    self.fitModelAndAssertKerasModelWritten(model)
1317
1318  def test_TensoriBoard_writeModel(self):
1319    inputs = keras.layers.Input([10, 10, 1])
1320    x = keras.layers.Conv2D(8, (3, 3), activation='relu')(inputs)
1321    x = keras.layers.Flatten()(x)
1322    x = keras.layers.Dense(1)(x)
1323    model = keras.models.Model(inputs=inputs, outputs=[x])
1324    model.compile('sgd', 'mse', run_eagerly=False)
1325    self.fitModelAndAssertKerasModelWritten(model)
1326
1327  def test_TensorBoard_autoTrace(self):
1328    model = self._get_seq_model()
1329    x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
1330    tb_cbk = keras.callbacks.TensorBoard(
1331        self.logdir, histogram_freq=1, profile_batch=1, write_graph=False)
1332
1333    model.fit(
1334        x,
1335        y,
1336        batch_size=2,
1337        epochs=2,
1338        validation_data=(x, y),
1339        callbacks=[tb_cbk])
1340    summary_file = list_summaries(self.logdir)
1341
1342    self.assertEqual(
1343        summary_file.tensors,
1344        {
1345            _ObservedSummary(logdir=self.train_dir, tag=u'batch_1'),
1346        },
1347    )
1348
1349  def test_TensorBoard_autoTrace_tagNameWithBatchNum(self):
1350    model = self._get_seq_model()
1351    x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
1352    tb_cbk = keras.callbacks.TensorBoard(
1353        self.logdir, histogram_freq=1, profile_batch=2, write_graph=False)
1354
1355    model.fit(
1356        x,
1357        y,
1358        batch_size=2,
1359        epochs=2,
1360        validation_data=(x, y),
1361        callbacks=[tb_cbk])
1362    summary_file = list_summaries(self.logdir)
1363
1364    self.assertEqual(
1365        summary_file.tensors,
1366        {
1367            _ObservedSummary(logdir=self.train_dir, tag=u'batch_2'),
1368        },
1369    )
1370
1371  def test_TensorBoard_autoTrace_profile_batch_largerThanBatchCount(self):
1372    model = self._get_seq_model()
1373    x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
1374    tb_cbk = keras.callbacks.TensorBoard(
1375        self.logdir, histogram_freq=1, profile_batch=10000, write_graph=False)
1376
1377    model.fit(
1378        x,
1379        y,
1380        batch_size=2,
1381        epochs=2,
1382        validation_data=(x, y),
1383        callbacks=[tb_cbk])
1384    summary_file = list_summaries(self.logdir)
1385
1386    # Enabled trace only on the 10000th batch, thus it should be empty.
1387    self.assertEmpty(summary_file.tensors)
1388
1389
1390if __name__ == '__main__':
1391  test.main()
1392