1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=g-import-not-at-top
16"""Callbacks: utilities called at certain points during model training.
17"""
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import collections
23import copy
24import csv
25import io
26import json
27import os
28import time
29
30import numpy as np
31import six
32
33from tensorflow.python.data.ops import iterator_ops
34from tensorflow.python.eager import context
35from tensorflow.python.framework import ops
36from tensorflow.python.keras import backend as K
37from tensorflow.python.keras.utils.data_utils import Sequence
38from tensorflow.python.keras.utils.generic_utils import Progbar
39from tensorflow.python.keras.utils.mode_keys import ModeKeys
40from tensorflow.python.ops import array_ops
41from tensorflow.python.ops import summary_ops_v2
42from tensorflow.python.platform import tf_logging as logging
43from tensorflow.python.util.tf_export import keras_export
44
45try:
46  import requests
47except ImportError:
48  requests = None
49
50
def configure_callbacks(callbacks,
                        model,
                        do_validation=False,
                        batch_size=None,
                        epochs=None,
                        steps_per_epoch=None,
                        samples=None,
                        verbose=1,
                        count_mode='steps',
                        mode=ModeKeys.TRAIN):
  """Configures callbacks for use in various training loops.

  In TRAIN mode, a `BaseLogger` is prepended and the model's new `History`
  appended to the user callbacks. A `ProgbarLogger` is also appended when
  `verbose` is truthy.

  Arguments:
      callbacks: List (or any other iterable) of Callbacks, or `None`. An
        already-configured `CallbackList` is returned unchanged.
      model: Model being trained. In TRAIN mode its `history` attribute is
        replaced by a fresh `History` callback.
      do_validation: Whether or not validation loop will be run.
      batch_size: Number of samples per batch.
      epochs: Number of epoch to train.
      steps_per_epoch: Number of batches to run per training epoch.
      samples: Number of training samples.
      verbose: int, 0 or 1. Keras logging verbosity to pass to ProgbarLogger.
      count_mode: One of 'steps' or 'samples'. Per-batch or per-sample count.
      mode: String. One of ModeKeys.TRAIN, ModeKeys.TEST, or ModeKeys.PREDICT.
        Which loop mode to configure callbacks for.

  Returns:
      Instance of CallbackList used to control all Callbacks.
  """
  # Check if callbacks have already been configured.
  if isinstance(callbacks, CallbackList):
    return callbacks

  # Normalise to a list so tuples and other iterables can be concatenated
  # with the built-in callbacks below.
  callbacks = list(callbacks) if callbacks else []

  # Add additional callbacks during training.
  if mode == ModeKeys.TRAIN:
    model.history = History()
    callbacks = [BaseLogger()] + callbacks + [model.history]
    if verbose:
      callbacks.append(ProgbarLogger(count_mode))
  callback_list = CallbackList(callbacks)

  # Set callback model
  callback_model = model._get_callback_model()  # pylint: disable=protected-access
  callback_list.set_model(callback_model)

  set_callback_parameters(
      callback_list,
      model,
      do_validation=do_validation,
      batch_size=batch_size,
      epochs=epochs,
      steps_per_epoch=steps_per_epoch,
      samples=samples,
      verbose=verbose,
      mode=mode)

  callback_list.model.stop_training = False
  return callback_list
111
112
def set_callback_parameters(callback_list,
                            model,
                            do_validation=False,
                            batch_size=None,
                            epochs=None,
                            steps_per_epoch=None,
                            samples=None,
                            verbose=1,
                            mode=ModeKeys.TRAIN):
  """Sets callback parameters.

  Also tells every `BaseLogger` / `ProgbarLogger` in `callback_list` which
  metrics are stateful (reported as-is rather than averaged). This only
  happens when the model already exposes `metrics_names`.

  Arguments:
      callback_list: CallbackList instance.
      model: Model being trained.
      do_validation: Whether or not validation loop will be run.
      batch_size: Number of samples per batch.
      epochs: Number of epoch to train.
      steps_per_epoch: Number of batches to run per training epoch.
      samples: Number of training samples.
      verbose: int, 0 or 1. Keras logging verbosity to pass to ProgbarLogger.
      mode: String. One of ModeKeys.TRAIN, ModeKeys.TEST, or ModeKeys.PREDICT.
        Which loop mode to configure callbacks for.
  """
  # When we have deferred build scenario with iterator input, we will compile
  # when we standardize first batch of data, so `metrics_names` may not exist
  # yet. Check once and guard every access consistently.
  has_metrics_names = hasattr(model, 'metrics_names')

  if has_metrics_names:
    for cbk in callback_list:
      if isinstance(cbk, (BaseLogger, ProgbarLogger)):
        # Exclude `loss` (the first entry); every other metric is treated as
        # stateful. Store a fresh set, matching the type both loggers build
        # in their constructors.
        cbk.stateful_metrics = set(model.metrics_names[1:])

  # Set callback parameters
  callback_metrics = []
  if mode != ModeKeys.PREDICT and has_metrics_names:
    callback_metrics = copy.copy(model.metrics_names)
    if do_validation:
      callback_metrics += ['val_' + n for n in model.metrics_names]
  callback_params = {
      'batch_size': batch_size,
      'epochs': epochs,
      'steps': steps_per_epoch,
      'samples': samples,
      'verbose': verbose,
      'do_validation': do_validation,
      'metrics': callback_metrics,
  }
  callback_list.set_params(callback_params)
158
159
160def _is_generator_like(data):
161  """Checks if data is a generator, Sequence, or Iterator."""
162  return (hasattr(data, 'next') or hasattr(data, '__next__') or isinstance(
163      data, (Sequence, iterator_ops.Iterator, iterator_ops.EagerIterator)))
164
165
def make_logs(model, logs, outputs, mode, prefix=''):
  """Computes logs for sending to `on_batch_end` methods.

  In TRAIN/TEST mode each entry of `outputs` is stored under
  `prefix + <metric name>`, pairing them with `model.metrics_names`. This is
  skipped when the model has no `metrics_names`. In any other mode, `outputs`
  is stored whole under the `'outputs'` key. `logs` is updated in place and
  also returned.
  """
  if mode not in {ModeKeys.TRAIN, ModeKeys.TEST}:
    logs['outputs'] = outputs
    return logs

  if hasattr(model, 'metrics_names'):
    for metric_name, value in zip(model.metrics_names, outputs):
      logs[prefix + metric_name] = value
  return logs
175
176
class CallbackList(object):
  """Container abstracting a list of callbacks.

  Arguments:
      callbacks: List of `Callback` instances.
      queue_length: Queue length for keeping
          running statistics over callback execution time.
  """

  def __init__(self, callbacks=None, queue_length=10):
    self.callbacks = list(callbacks or [])
    self.queue_length = queue_length
    self.params = {}
    self.model = None
    # Wall-clock time at which the current batch started. Stays `None` until
    # the first `begin` batch hook, so an `end` hook without a matching
    # `begin` does not crash.
    self._t_enter_batch = None
    self._reset_batch_timing()

  def _reset_batch_timing(self):
    """Clears the per-hook timing statistics used for slow-callback warnings."""
    self._delta_t_batch = 0.
    self._delta_ts = collections.defaultdict(
        lambda: collections.deque([], maxlen=self.queue_length))

  def append(self, callback):
    self.callbacks.append(callback)

  def set_params(self, params):
    self.params = params
    for callback in self.callbacks:
      callback.set_params(params)

  def set_model(self, model):
    self.model = model
    for callback in self.callbacks:
      callback.set_model(model)

  def _call_batch_hook(self, mode, hook, batch, logs=None):
    """Helper function for all batch_{begin | end} methods.

    Dispatches to `on_{mode}_batch_{hook}` on every callback. It also warns
    when the median time spent in that hook (over the last `queue_length`
    calls) exceeds 0.1s and 95% of the last measured batch duration.
    """
    if not self.callbacks:
      return
    hook_name = 'on_{mode}_batch_{hook}'.format(mode=mode, hook=hook)
    if hook == 'begin':
      self._t_enter_batch = time.time()
    if hook == 'end' and self._t_enter_batch is not None:
      # Batch is ending, calculate batch time.
      self._delta_t_batch = time.time() - self._t_enter_batch

    logs = logs or {}
    t_before_callbacks = time.time()
    for callback in self.callbacks:
      batch_hook = getattr(callback, hook_name)
      batch_hook(batch, logs)
    self._delta_ts[hook_name].append(time.time() - t_before_callbacks)

    delta_t_median = np.median(self._delta_ts[hook_name])
    if (self._delta_t_batch > 0. and
        delta_t_median > 0.95 * self._delta_t_batch and delta_t_median > 0.1):
      logging.warning(
          'Method (%s) is slow compared '
          'to the batch update (%f). Check your callbacks.', hook_name,
          delta_t_median)

  def _call_begin_hook(self, mode):
    """Helper function for on_{train|test|predict}_begin methods."""
    if mode == ModeKeys.TRAIN:
      self.on_train_begin()
    elif mode == ModeKeys.TEST:
      self.on_test_begin()
    else:
      self.on_predict_begin()

  def _call_end_hook(self, mode):
    """Helper function for on_{train|test|predict}_end methods."""
    if mode == ModeKeys.TRAIN:
      self.on_train_end()
    elif mode == ModeKeys.TEST:
      self.on_test_end()
    else:
      self.on_predict_end()

  def on_batch_begin(self, batch, logs=None):
    """Backwards-compatible alias for `on_train_batch_begin`."""
    self._call_batch_hook(ModeKeys.TRAIN, 'begin', batch, logs=logs)

  def on_batch_end(self, batch, logs=None):
    """Backwards-compatible alias for `on_train_batch_end`."""
    self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)

  def on_epoch_begin(self, epoch, logs=None):
    """Calls the `on_epoch_begin` methods of its callbacks.

    This function should only be called during TRAIN mode.

    Arguments:
        epoch: integer, index of epoch.
        logs: dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """
    logs = logs or {}
    for callback in self.callbacks:
      callback.on_epoch_begin(epoch, logs)
    self._reset_batch_timing()

  def on_epoch_end(self, epoch, logs=None):
    """Calls the `on_epoch_end` methods of its callbacks.

    This function should only be called during TRAIN mode.

    Arguments:
        epoch: integer, index of epoch.
        logs: dict, metric results for this training epoch, and for the
          validation epoch if validation is performed. Validation result keys
          are prefixed with `val_`.
    """
    logs = logs or {}
    for callback in self.callbacks:
      callback.on_epoch_end(epoch, logs)

  def on_train_batch_begin(self, batch, logs=None):
    """Calls the `on_train_batch_begin` methods of its callbacks.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Has keys `batch` and `size` representing the current batch
          number and the size of the batch.
    """
    self._call_batch_hook(ModeKeys.TRAIN, 'begin', batch, logs=logs)

  def on_train_batch_end(self, batch, logs=None):
    """Calls the `on_train_batch_end` methods of its callbacks.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Metric results for this batch.
    """
    self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)

  def on_test_batch_begin(self, batch, logs=None):
    """Calls the `on_test_batch_begin` methods of its callbacks.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Has keys `batch` and `size` representing the current batch
          number and the size of the batch.
    """
    self._call_batch_hook(ModeKeys.TEST, 'begin', batch, logs=logs)

  def on_test_batch_end(self, batch, logs=None):
    """Calls the `on_test_batch_end` methods of its callbacks.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Metric results for this batch.
    """
    self._call_batch_hook(ModeKeys.TEST, 'end', batch, logs=logs)

  def on_predict_batch_begin(self, batch, logs=None):
    """Calls the `on_predict_batch_begin` methods of its callbacks.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Has keys `batch` and `size` representing the current batch
          number and the size of the batch.
    """
    self._call_batch_hook(ModeKeys.PREDICT, 'begin', batch, logs=logs)

  def on_predict_batch_end(self, batch, logs=None):
    """Calls the `on_predict_batch_end` methods of its callbacks.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Metric results for this batch.
    """
    self._call_batch_hook(ModeKeys.PREDICT, 'end', batch, logs=logs)

  def on_train_begin(self, logs=None):
    """Calls the `on_train_begin` methods of its callbacks.

    Arguments:
        logs: dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """
    for callback in self.callbacks:
      callback.on_train_begin(logs)

  def on_train_end(self, logs=None):
    """Calls the `on_train_end` methods of its callbacks.

    Arguments:
        logs: dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """
    for callback in self.callbacks:
      callback.on_train_end(logs)

  def on_test_begin(self, logs=None):
    """Calls the `on_test_begin` methods of its callbacks.

    Arguments:
        logs: dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """
    for callback in self.callbacks:
      callback.on_test_begin(logs)

  def on_test_end(self, logs=None):
    """Calls the `on_test_end` methods of its callbacks.

    Arguments:
        logs: dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """
    for callback in self.callbacks:
      callback.on_test_end(logs)

  def on_predict_begin(self, logs=None):
    """Calls the `on_predict_begin` methods of its callbacks.

    Arguments:
        logs: dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """
    for callback in self.callbacks:
      callback.on_predict_begin(logs)

  def on_predict_end(self, logs=None):
    """Calls the `on_predict_end` methods of its callbacks.

    Arguments:
        logs: dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """
    for callback in self.callbacks:
      callback.on_predict_end(logs)

  def __iter__(self):
    return iter(self.callbacks)
411
412
@keras_export('keras.callbacks.Callback')
class Callback(object):
  """Abstract base class for Keras callbacks.

  Subclass this and override the hooks you need. The default hooks do
  nothing, except that `on_train_batch_begin` / `on_train_batch_end` forward
  to the legacy `on_batch_begin` / `on_batch_end` hooks.

  Attributes:
      params: dict. Training parameters
          (eg. verbosity, batch size, number of epochs...).
      model: instance of `keras.models.Model`.
          Reference of the model being trained.

  Every hook receives a `logs` dictionary with the quantities relevant to
  the current batch or epoch. Currently, the `.fit()` method of the `Model`
  class passes the following quantities to its callbacks:

      on_epoch_end: logs include `acc` and `loss`, and
          optionally include `val_loss`
          (if validation is enabled in `fit`), and `val_acc`
          (if validation and accuracy monitoring are enabled).
      on_batch_begin: logs include `size`,
          the number of samples in the current batch.
      on_batch_end: logs include `loss`, and optionally `acc`
          (if accuracy monitoring is enabled).
  """

  def __init__(self):
    self.model = None
    self.validation_data = None
    # Whether this Callback should only run on the chief worker in a
    # Multi-Worker setting.
    # TODO(omalleyt): Make this attr public once solution is stable.
    self._chief_worker_only = None

  def set_params(self, params):
    """Stores the training-parameter dict on `self.params`."""
    self.params = params

  def set_model(self, model):
    """Stores the model being trained on `self.model`."""
    self.model = model

  def on_batch_begin(self, batch, logs=None):
    """Legacy hook; invoked by the default `on_train_batch_begin`."""

  def on_batch_end(self, batch, logs=None):
    """Legacy hook; invoked by the default `on_train_batch_end`."""

  def on_epoch_begin(self, epoch, logs=None):
    """Hook run when an epoch starts (TRAIN mode only).

    Override in a subclass to act at the start of each epoch.

    Arguments:
        epoch: integer, index of epoch.
        logs: dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """

  def on_epoch_end(self, epoch, logs=None):
    """Hook run when an epoch finishes (TRAIN mode only).

    Override in a subclass to act at the end of each epoch.

    Arguments:
        epoch: integer, index of epoch.
        logs: dict, metric results for this training epoch, and for the
          validation epoch if validation is performed. Validation result keys
          are prefixed with `val_`.
    """

  def on_train_batch_begin(self, batch, logs=None):
    """Hook run before each training batch inside `fit` methods.

    Override in a subclass as needed. By default it delegates to
    `on_batch_begin` for backwards compatibility.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Has keys `batch` and `size` representing the current batch
          number and the size of the batch.
    """
    self.on_batch_begin(batch, logs=logs)

  def on_train_batch_end(self, batch, logs=None):
    """Hook run after each training batch inside `fit` methods.

    Override in a subclass as needed. By default it delegates to
    `on_batch_end` for backwards compatibility.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Metric results for this batch.
    """
    self.on_batch_end(batch, logs=logs)

  def on_test_batch_begin(self, batch, logs=None):
    """Hook run before each batch in `evaluate` methods.

    Also runs before every validation batch inside `fit` methods when
    validation data is provided. Override in a subclass as needed.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Has keys `batch` and `size` representing the current batch
          number and the size of the batch.
    """

  def on_test_batch_end(self, batch, logs=None):
    """Hook run after each batch in `evaluate` methods.

    Also runs after every validation batch inside `fit` methods when
    validation data is provided. Override in a subclass as needed.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Metric results for this batch.
    """

  def on_predict_batch_begin(self, batch, logs=None):
    """Hook run before each batch in `predict` methods.

    Override in a subclass as needed.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Has keys `batch` and `size` representing the current batch
          number and the size of the batch.
    """

  def on_predict_batch_end(self, batch, logs=None):
    """Hook run after each batch in `predict` methods.

    Override in a subclass as needed.

    Arguments:
        batch: integer, index of batch within the current epoch.
        logs: dict. Metric results for this batch.
    """

  def on_train_begin(self, logs=None):
    """Hook run once when training starts.

    Override in a subclass as needed.

    Arguments:
        logs: dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """

  def on_train_end(self, logs=None):
    """Hook run once when training finishes.

    Override in a subclass as needed.

    Arguments:
        logs: dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """

  def on_test_begin(self, logs=None):
    """Hook run when evaluation or validation starts.

    Override in a subclass as needed.

    Arguments:
        logs: dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """

  def on_test_end(self, logs=None):
    """Hook run when evaluation or validation finishes.

    Override in a subclass as needed.

    Arguments:
        logs: dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """

  def on_predict_begin(self, logs=None):
    """Hook run when prediction starts.

    Override in a subclass as needed.

    Arguments:
        logs: dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """

  def on_predict_end(self, logs=None):
    """Hook run when prediction finishes.

    Override in a subclass as needed.

    Arguments:
        logs: dict. Currently no data is passed to this argument for this method
          but that may change in the future.
    """
618
619
@keras_export('keras.callbacks.BaseLogger')
class BaseLogger(Callback):
  """Callback that accumulates epoch averages of metrics.

  This callback is automatically applied to every Keras model.

  Arguments:
      stateful_metrics: Iterable of string names of metrics that
          should *not* be averaged over an epoch.
          Metrics in this list will be logged as-is in `on_epoch_end`.
          All others will be averaged in `on_epoch_end`.
  """

  def __init__(self, stateful_metrics=None):
    super(BaseLogger, self).__init__()
    self.stateful_metrics = set(stateful_metrics or [])

  def on_epoch_begin(self, epoch, logs=None):
    # Per-epoch accumulators: number of samples seen and running totals.
    self.seen = 0
    self.totals = {}

  def on_batch_end(self, batch, logs=None):
    logs = logs or {}
    batch_size = logs.get('size', 0)
    # In case of distribution strategy we can potentially run multiple steps
    # at the same time, we should account for that in the `seen` calculation.
    num_steps = logs.get('num_steps', 1)
    self.seen += batch_size * num_steps

    for k, v in logs.items():
      if k in self.stateful_metrics:
        # Stateful metrics already hold a running value; keep the latest.
        self.totals[k] = v
      else:
        # NOTE(review): totals are weighted by `batch_size` only, while
        # `seen` also counts `num_steps`; confirm this matches how
        # multi-step outputs are aggregated by the caller.
        if k in self.totals:
          self.totals[k] += v * batch_size
        else:
          self.totals[k] = v * batch_size

  def on_epoch_end(self, epoch, logs=None):
    """Writes the epoch values of `params['metrics']` into `logs` in place.

    Stateful metrics are copied as-is. All other metrics are written as a
    sample-weighted average. That average is skipped when no samples were
    counted this epoch, since dividing by zero would give an error or NaN.
    """
    if logs is None:
      return
    for k in self.params['metrics']:
      if k not in self.totals:
        continue
      # Make value available to next callbacks.
      if k in self.stateful_metrics:
        logs[k] = self.totals[k]
      elif self.seen:
        logs[k] = self.totals[k] / self.seen
667
668
@keras_export('keras.callbacks.TerminateOnNaN')
class TerminateOnNaN(Callback):
  """Callback that terminates training when a NaN loss is encountered.
  """

  def on_batch_end(self, batch, logs=None):
    """Sets `model.stop_training` once the batch loss is NaN or infinite."""
    batch_loss = (logs or {}).get('loss')
    if batch_loss is None:
      return
    loss_is_invalid = np.isnan(batch_loss) or np.isinf(batch_loss)
    if loss_is_invalid:
      print('Batch %d: Invalid loss, terminating training' % (batch))
      self.model.stop_training = True
681
682
@keras_export('keras.callbacks.ProgbarLogger')
class ProgbarLogger(Callback):
  """Callback that prints metrics to stdout.

  Arguments:
      count_mode: One of "steps" or "samples".
          Whether the progress bar should
          count samples seen or steps (batches) seen.
      stateful_metrics: Iterable of string names of metrics that
          should *not* be averaged over an epoch.
          Metrics in this list will be logged as-is.
          All others will be averaged over time (e.g. loss, etc).

  Raises:
      ValueError: In case of invalid `count_mode`.
  """

  def __init__(self, count_mode='samples', stateful_metrics=None):
    super(ProgbarLogger, self).__init__()
    if count_mode == 'samples':
      self.use_steps = False
    elif count_mode == 'steps':
      self.use_steps = True
    else:
      raise ValueError('Unknown `count_mode`: ' + str(count_mode))
    self.stateful_metrics = set(stateful_metrics or [])

  def on_train_begin(self, logs=None):
    self.verbose = self.params['verbose']
    self.epochs = self.params['epochs']

  def on_epoch_begin(self, epoch, logs=None):
    self.seen = 0
    # Reset here as well as per batch, so `on_epoch_end` never reuses values
    # left over from a previous epoch (e.g. when an epoch runs no batches).
    self.log_values = []
    if self.use_steps:
      self.target = self.params['steps']
    else:
      self.target = self.params['samples']

    # `epochs` may be None when not configured; comparing None with an int
    # would raise on Python 3.
    if self.verbose and self.epochs is not None and self.epochs > 1:
      print('Epoch %d/%d' % (epoch + 1, self.epochs))
    self.progbar = Progbar(
        target=self.target,
        verbose=self.verbose,
        stateful_metrics=self.stateful_metrics,
        unit_name='step' if self.use_steps else 'sample')

  def on_batch_begin(self, batch, logs=None):
    self.log_values = []

  def on_batch_end(self, batch, logs=None):
    logs = logs or {}
    batch_size = logs.get('size', 0)
    # In case of distribution strategy we can potentially run multiple steps
    # at the same time, we should account for that in the `seen` calculation.
    num_steps = logs.get('num_steps', 1)
    if self.use_steps:
      self.seen += num_steps
    else:
      self.seen += batch_size * num_steps

    for k in self.params['metrics']:
      if k in logs:
        self.log_values.append((k, logs[k]))

    # Skip progbar update for the last batch;
    # will be handled by on_epoch_end.
    if self.verbose and (self.target is None or self.seen < self.target):
      self.progbar.update(self.seen, self.log_values)

  def on_epoch_end(self, epoch, logs=None):
    # Append the epoch-level values (e.g. `val_*`) to the last batch's
    # values and render the final progress-bar update for the epoch.
    logs = logs or {}
    for k in self.params['metrics']:
      if k in logs:
        self.log_values.append((k, logs[k]))
    if self.verbose:
      self.progbar.update(self.seen, self.log_values)
760
761
@keras_export('keras.callbacks.History')
class History(Callback):
  """Callback that records events into a `History` object.

  This callback is automatically applied to
  every Keras model. The `History` object
  gets returned by the `fit` method of models.
  """

  def on_train_begin(self, logs=None):
    # Start from empty records at the beginning of each training run.
    self.history = {}
    self.epoch = []

  def on_epoch_end(self, epoch, logs=None):
    # Record the epoch index, then append each logged value to its series.
    self.epoch.append(epoch)
    for metric_name, metric_value in (logs or {}).items():
      self.history.setdefault(metric_name, []).append(metric_value)
780
781
@keras_export('keras.callbacks.ModelCheckpoint')
class ModelCheckpoint(Callback):
  """Save the model after every epoch.

  `filepath` can contain named formatting options,
  which will be filled with the value of `epoch` and
  keys in `logs` (passed in `on_epoch_end`).

  For example: if `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`,
  then the model checkpoints will be saved with the epoch number and
  the validation loss in the filename.

  Arguments:
      filepath: string, path to save the model file.
      monitor: quantity to monitor.
      verbose: verbosity mode, 0 or 1.
      save_best_only: if `save_best_only=True`, a checkpoint is only written
          when the monitored quantity improves, so the latest best model
          according to the quantity monitored will not be overwritten.
      mode: one of {auto, min, max}.
          If `save_best_only=True`, the decision
          to overwrite the current save file is made
          based on either the maximization or the
          minimization of the monitored quantity. For `val_acc`,
          this should be `max`, for `val_loss` this should
          be `min`, etc. In `auto` mode, the direction is
          automatically inferred from the name of the monitored quantity.
      save_weights_only: if True, then only the model's weights will be
          saved (`model.save_weights(filepath)`), else the full model
          is saved (`model.save(filepath)`).
      period: Interval (number of epochs) between checkpoints.
  """

  def __init__(self,
               filepath,
               monitor='val_loss',
               verbose=0,
               save_best_only=False,
               save_weights_only=False,
               mode='auto',
               period=1):
    super(ModelCheckpoint, self).__init__()
    self.monitor = monitor
    self.verbose = verbose
    self.filepath = filepath
    self.save_best_only = save_best_only
    self.save_weights_only = save_weights_only
    self.period = period
    self.epochs_since_last_save = 0

    if mode not in ['auto', 'min', 'max']:
      logging.warning('ModelCheckpoint mode %s is unknown, '
                      'fallback to auto mode.', mode)
      mode = 'auto'

    # `np.inf` (lower-case) works on all NumPy versions; the `np.Inf` alias
    # was removed in NumPy 2.0.
    if mode == 'min':
      self.monitor_op = np.less
      self.best = np.inf
    elif mode == 'max':
      self.monitor_op = np.greater
      self.best = -np.inf
    else:
      # Accuracy-like quantities improve upwards; everything else downwards.
      if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
        self.monitor_op = np.greater
        self.best = -np.inf
      else:
        self.monitor_op = np.less
        self.best = np.inf

    # Only the chief worker writes model checkpoints.
    self._chief_worker_only = True

  def set_model(self, model):
    self.model = model
    # Subclassed models can only save weights, so force weights-only saving
    # for anything that is neither a graph network nor a Sequential.
    # Use name matching rather than `isinstance` to avoid circular dependencies.
    if (not self.save_weights_only and
        not model._is_graph_network and  # pylint: disable=protected-access
        model.__class__.__name__ != 'Sequential'):
      self.save_weights_only = True

  def _write_checkpoint(self, filepath):
    """Saves the weights or the full model to `filepath`, overwriting it."""
    if self.save_weights_only:
      self.model.save_weights(filepath, overwrite=True)
    else:
      self.model.save(filepath, overwrite=True)

  def on_epoch_end(self, epoch, logs=None):
    """Writes a checkpoint every `period` epochs, honouring `save_best_only`."""
    logs = logs or {}
    self.epochs_since_last_save += 1
    if self.epochs_since_last_save < self.period:
      return
    self.epochs_since_last_save = 0
    filepath = self.filepath.format(epoch=epoch + 1, **logs)

    if not self.save_best_only:
      if self.verbose > 0:
        print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
      self._write_checkpoint(filepath)
      return

    current = logs.get(self.monitor)
    if current is None:
      logging.warning('Can save best model only with %s available, '
                      'skipping.', self.monitor)
      return

    if self.monitor_op(current, self.best):
      if self.verbose > 0:
        print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
              ' saving model to %s' % (epoch + 1, self.monitor, self.best,
                                       current, filepath))
      self.best = current
      self._write_checkpoint(filepath)
    elif self.verbose > 0:
      print('\nEpoch %05d: %s did not improve from %0.5f' %
            (epoch + 1, self.monitor, self.best))
895
896
@keras_export('keras.callbacks.EarlyStopping')
class EarlyStopping(Callback):
  """Stop training when a monitored quantity has stopped improving.

  Arguments:
      monitor: Quantity to be monitored.
      min_delta: Minimum change in the monitored quantity
          to qualify as an improvement, i.e. an absolute
          change of less than min_delta, will count as no
          improvement.
      patience: Number of epochs with no improvement
          after which training will be stopped.
      verbose: verbosity mode.
      mode: One of `{"auto", "min", "max"}`. In `min` mode,
          training will stop when the quantity
          monitored has stopped decreasing; in `max`
          mode it will stop when the quantity
          monitored has stopped increasing; in `auto`
          mode, the direction is automatically inferred
          from the name of the monitored quantity.
      baseline: Baseline value for the monitored quantity.
          Training will stop if the model doesn't show improvement over the
          baseline.
      restore_best_weights: Whether to restore model weights from
          the epoch with the best value of the monitored quantity
          when training is stopped early. If False, the model weights
          obtained at the last step of training are used. If no epoch
          ever improved (e.g. `baseline` was never beaten), there are no
          best weights to restore and the current weights are kept.
  """

  def __init__(self,
               monitor='val_loss',
               min_delta=0,
               patience=0,
               verbose=0,
               mode='auto',
               baseline=None,
               restore_best_weights=False):
    super(EarlyStopping, self).__init__()

    self.monitor = monitor
    self.patience = patience
    self.verbose = verbose
    self.baseline = baseline
    self.min_delta = abs(min_delta)
    self.wait = 0
    self.stopped_epoch = 0
    self.restore_best_weights = restore_best_weights
    self.best_weights = None

    if mode not in ['auto', 'min', 'max']:
      logging.warning('EarlyStopping mode %s is unknown, '
                      'fallback to auto mode.', mode)
      mode = 'auto'

    if mode == 'min':
      self.monitor_op = np.less
    elif mode == 'max':
      self.monitor_op = np.greater
    else:
      if 'acc' in self.monitor:
        self.monitor_op = np.greater
      else:
        self.monitor_op = np.less

    # `on_epoch_end` compares `current - min_delta` against `best`. Making the
    # delta negative when smaller is better means an improvement must beat
    # `best` by at least `abs(min_delta)` in either direction.
    if self.monitor_op == np.less:
      self.min_delta *= -1

  def on_train_begin(self, logs=None):
    """Resets all per-run state so the instance can be re-used across fits."""
    self.wait = 0
    self.stopped_epoch = 0
    # Drop weights captured during a previous run; otherwise a later run could
    # restore weights that belong to an earlier training session.
    self.best_weights = None
    if self.baseline is not None:
      self.best = self.baseline
    else:
      self.best = np.Inf if self.monitor_op == np.less else -np.Inf

  def on_epoch_end(self, epoch, logs=None):
    """Tracks the monitored value and stops training once patience runs out.

    Arguments:
        epoch: Integer, 0-based index of the epoch that just finished.
        logs: Dict of results for this epoch, or None.
    """
    current = self.get_monitor_value(logs)
    if current is None:
      return
    if self.monitor_op(current - self.min_delta, self.best):
      self.best = current
      self.wait = 0
      if self.restore_best_weights:
        self.best_weights = self.model.get_weights()
    else:
      self.wait += 1
      if self.wait >= self.patience:
        self.stopped_epoch = epoch
        self.model.stop_training = True
        # `best_weights` stays None when no epoch ever improved (for example
        # when `baseline` was never beaten); there is nothing to restore then.
        if self.restore_best_weights and self.best_weights is not None:
          if self.verbose > 0:
            print('Restoring model weights from the end of the best epoch.')
          self.model.set_weights(self.best_weights)

  def on_train_end(self, logs=None):
    if self.stopped_epoch > 0 and self.verbose > 0:
      print('Epoch %05d: early stopping' % (self.stopped_epoch + 1))

  def get_monitor_value(self, logs):
    """Returns `logs[self.monitor]`, or None (with a warning) if missing."""
    logs = logs or {}
    monitor_value = logs.get(self.monitor)
    if monitor_value is None:
      logging.warning('Early stopping conditioned on metric `%s` '
                      'which is not available. Available metrics are: %s',
                      self.monitor, ','.join(list(logs.keys())))
    return monitor_value
1006
1007
@keras_export('keras.callbacks.RemoteMonitor')
class RemoteMonitor(Callback):
  """Callback used to stream events to a server.

  Requires the `requests` library.
  Events are sent to `root + '/publish/epoch/end/'` by default. Calls are
  HTTP POST, with a `data` argument which is a
  JSON-encoded dictionary of event data.
  If send_as_json is set to True, the content type of the request will be
  application/json. Otherwise the serialized JSON will be sent within a form.

  NumPy scalars and arrays found in the epoch logs are converted to the
  equivalent native Python values before serialization.

  Arguments:
      root: String; root url of the target server.
      path: String; path relative to `root` to which the events will be sent.
      field: String; JSON field under which the data will be stored.
          The field is used only if the payload is sent within a form
          (i.e. send_as_json is set to False).
      headers: Dictionary; optional custom HTTP headers.
      send_as_json: Boolean; whether the request should be
          sent as application/json.
  """

  def __init__(self,
               root='http://localhost:9000',
               path='/publish/epoch/end/',
               field='data',
               headers=None,
               send_as_json=False):
    super(RemoteMonitor, self).__init__()

    self.root = root
    self.path = path
    self.field = field
    self.headers = headers
    self.send_as_json = send_as_json

  def on_epoch_end(self, epoch, logs=None):
    """POSTs the epoch index and the epoch logs to `root + path`.

    Connection failures are logged as warnings rather than raised.

    Arguments:
        epoch: Integer index of the epoch that just finished.
        logs: Dict of results for this epoch, or None.

    Raises:
        ImportError: If the `requests` library is not installed.
    """
    if requests is None:
      raise ImportError('RemoteMonitor requires the `requests` library.')
    logs = logs or {}
    send = {'epoch': epoch}
    for k, v in logs.items():
      # Metric values are frequently NumPy scalars (e.g. `np.float32`), which
      # the JSON encoder cannot serialize; unwrap them to Python values.
      if isinstance(v, np.generic):
        v = v.item()
      elif isinstance(v, np.ndarray):
        v = v.tolist()
      send[k] = v
    try:
      if self.send_as_json:
        requests.post(self.root + self.path, json=send, headers=self.headers)
      else:
        requests.post(
            self.root + self.path, {self.field: json.dumps(send)},
            headers=self.headers)
    except requests.exceptions.RequestException:
      logging.warning('Warning: could not reach RemoteMonitor '
                      'root server at %s', self.root)
1062
1063
@keras_export('keras.callbacks.LearningRateScheduler')
class LearningRateScheduler(Callback):
  """Learning rate scheduler.

  At the beginning of every epoch the optimizer's `lr` is overwritten with
  the value returned by `schedule`.

  Arguments:
      schedule: a function that takes an epoch index as input
          (integer, indexed from 0) and returns a new
          learning rate as output (float). It may also accept the current
          learning rate as a second positional argument; if calling it with
          two arguments raises `TypeError`, it is called again with only the
          epoch index.
      verbose: int. 0: quiet, 1: update messages.
  """

  def __init__(self, schedule, verbose=0):
    super(LearningRateScheduler, self).__init__()
    self.schedule = schedule
    self.verbose = verbose

  def on_epoch_begin(self, epoch, logs=None):
    """Sets the optimizer learning rate for `epoch` from `schedule`.

    Raises:
        ValueError: If the optimizer has no `lr` attribute, or if the
            schedule returns something other than a float.
    """
    optimizer = self.model.optimizer
    if not hasattr(optimizer, 'lr'):
      raise ValueError('Optimizer must have a "lr" attribute.')
    try:
      # Preferred signature: schedule(epoch, current_lr).
      current_lr = float(K.get_value(optimizer.lr))
      new_lr = self.schedule(epoch, current_lr)
    except TypeError:
      # Backwards-compatible signature: schedule(epoch).
      new_lr = self.schedule(epoch)
    accepted_types = (float, np.float32, np.float64)
    if not isinstance(new_lr, accepted_types):
      raise ValueError('The output of the "schedule" function '
                       'should be float.')
    K.set_value(optimizer.lr, new_lr)
    if self.verbose > 0:
      print('\nEpoch %05d: LearningRateScheduler reducing learning '
            'rate to %s.' % (epoch + 1, new_lr))

  def on_epoch_end(self, epoch, logs=None):
    """Records the optimizer's current learning rate under `logs['lr']`."""
    if not logs:
      logs = {}
    logs['lr'] = K.get_value(self.model.optimizer.lr)
1099
1100
@keras_export('keras.callbacks.TensorBoard', v1=[])
class TensorBoard(Callback):
  # pylint: disable=line-too-long
  """Enable visualizations for TensorBoard.

  TensorBoard is a visualization tool provided with TensorFlow.

  This callback logs events for TensorBoard, including:
  * Metrics summary plots
  * Training graph visualization
  * Weight histograms
  * Sampled profiling

  If you have installed TensorFlow with pip, you should be able
  to launch TensorBoard from the command line:

  ```sh
  tensorboard --logdir=path_to_your_logs
  ```

  You can find more information about TensorBoard
  [here](https://www.tensorflow.org/get_started/summaries_and_tensorboard).

  Training metrics are written under `<log_dir>/train`. Metrics whose names
  start with `val_` are written, with that prefix stripped, under
  `<log_dir>/validation`.

  Arguments:
      log_dir: the path of the directory where to save the log files to be
        parsed by TensorBoard.
      histogram_freq: frequency (in epochs) at which to compute weight
        histograms for the layers of the model. If set to 0, histograms
        won't be computed.
      write_graph: whether to visualize the graph in TensorBoard. The log file
        can become quite large when write_graph is set to True.
      write_images: whether to write model weights to visualize as image in
        TensorBoard.
      update_freq: `'batch'` or `'epoch'` or integer. When using `'batch'`,
        writes the losses and metrics to TensorBoard after each batch. The same
        applies for `'epoch'`. If using an integer, let's say `1000`, the
        callback will write the metrics and losses to TensorBoard every 1000
        samples. Note that writing too frequently to TensorBoard can slow down
        your training.
      profile_batch: Profile the batch to sample compute characteristics. By
        default, it will profile the second batch. Set profile_batch=0 to
        disable profiling. Must run in TensorFlow eager mode.
      **kwargs: Arguments accepted only for compatibility with the V1
        callback (`write_grads`, `embeddings_freq`, `embeddings_layer_names`,
        `embeddings_metadata`, `embeddings_data`, `batch_size`). They are
        ignored.

  Raises:
      ValueError: If a keyword argument other than the V1 arguments listed
        above is passed.
  """

  # pylint: enable=line-too-long

  def __init__(self,
               log_dir='logs',
               histogram_freq=0,
               write_graph=True,
               write_images=False,
               update_freq='epoch',
               profile_batch=2,
               **kwargs):
    super(TensorBoard, self).__init__()
    self._validate_kwargs(kwargs)

    self.log_dir = log_dir
    self.histogram_freq = histogram_freq
    self.write_graph = write_graph
    self.write_images = write_images
    # 'batch' is stored as 1 so that `on_batch_end` writes after every batch
    # (its threshold is measured in samples, and every batch adds >= 1).
    if update_freq == 'batch':
      self.update_freq = 1
    else:
      self.update_freq = update_freq

    # Running count of samples seen (from `logs['size']`, default 1 per batch).
    self._samples_seen = 0
    # Value of `_samples_seen` at the last batch-level metrics write.
    self._samples_seen_at_last_write = 0
    # NOTE(review): not read anywhere in this class.
    self._current_batch = 0
    # Number of training batches completed; used as the batch-level step.
    self._total_batches_seen = 0
    # NOTE(review): not read anywhere in this class.
    self._total_val_batches_seen = 0

    # A collection of file writers currently in use, to be closed when
    # training ends for this callback. Writers are keyed by the
    # directory name under the root logdir: e.g., "train" or
    # "validation".
    self._writers = {}
    self._train_run_name = 'train'
    self._validation_run_name = 'validation'

    self._profile_batch = profile_batch
    # True when a trace is running.
    self._is_tracing = False

    # TensorBoard should only write summaries on the chief when in a
    # Multi-Worker setting.
    self._chief_worker_only = True

  def _validate_kwargs(self, kwargs):
    """Warns about ignored V1 arguments and rejects unknown ones.

    Raises:
      ValueError: If `kwargs` contains a key that the V1 callback did not
        accept.
    """
    if kwargs.get('write_grads', False):
      logging.warning('`write_grads` will be ignored in TensorFlow 2.0 '
                      'for the `TensorBoard` Callback.')
    if kwargs.get('embeddings_freq', False):
      logging.warning('Embeddings will be ignored in TensorFlow 2.0 '
                      'for the `TensorBoard` Callback.')
    if kwargs.get('batch_size', False):
      logging.warning('`batch_size` is no longer needed in the '
                      '`TensorBoard` Callback and will be ignored '
                      'in TensorFlow 2.0.')

    unrecognized_kwargs = set(kwargs.keys()) - {
        'write_grads', 'embeddings_freq', 'embeddings_layer_names',
        'embeddings_metadata', 'embeddings_data', 'batch_size'
    }

    # Only allow kwargs that were supported in V1.
    if unrecognized_kwargs:
      raise ValueError('Unrecognized arguments in `TensorBoard` '
                       'Callback: ' + str(unrecognized_kwargs))

  def set_model(self, model):
    """Sets Keras model and writes graph if specified."""
    self.model = model
    with context.eager_mode():
      # Any writers from a previous model are closed; new ones are created
      # lazily by `_get_writer`.
      self._close_writers()
      if self.write_graph:
        with self._get_writer(self._train_run_name).as_default():
          with summary_ops_v2.always_record_summaries():
            # The backend graph is only written when the model does not run
            # eagerly.
            if not model.run_eagerly:
              summary_ops_v2.graph(K.get_graph())

            # The Keras model summary is written only for graph networks and
            # `Sequential` models (matched by class name).
            summary_writable = (
                self.model._is_graph_network or  # pylint: disable=protected-access
                self.model.__class__.__name__ == 'Sequential')  # pylint: disable=protected-access
            if summary_writable:
              summary_ops_v2.keras_model('keras', self.model, step=0)

  def _close_writers(self):
    """Close all remaining open file writers owned by this callback.

    If there are no such file writers, this is a no-op.
    """
    with context.eager_mode():
      for writer in six.itervalues(self._writers):
        writer.close()
      self._writers.clear()

  def _get_writer(self, writer_name):
    """Get a summary writer for the given subdirectory under the logdir.

    A writer will be created if it does not yet exist.

    Args:
      writer_name: The name of the directory for which to create or
        retrieve a writer. Should be either `self._train_run_name` or
        `self._validation_run_name`.

    Returns:
      A `SummaryWriter` object.
    """
    if writer_name not in self._writers:
      path = os.path.join(self.log_dir, writer_name)
      writer = summary_ops_v2.create_file_writer_v2(path)
      self._writers[writer_name] = writer
    return self._writers[writer_name]

  def on_train_begin(self, logs=None):
    # Profiling the very first batch requires starting the trace before any
    # batch runs. NOTE(review): unlike `_enable_trace`, this does not check
    # `context.executing_eagerly()`.
    if self._profile_batch == 1:
      summary_ops_v2.trace_on(graph=True, profiler=True)
      self._is_tracing = True

  def on_batch_end(self, batch, logs=None):
    """Writes scalar summaries for metrics on every training batch.

    Metrics are written once at least `update_freq` samples have been seen
    since the previous write (never when `update_freq == 'epoch'`). Also
    starts a profiler trace right before batch number `profile_batch`
    (1-based) and exports it after that batch completes.
    """
    # Don't output batch_size and batch number as TensorBoard summaries
    logs = logs or {}
    self._samples_seen += logs.get('size', 1)
    samples_seen_since = self._samples_seen - self._samples_seen_at_last_write
    if self.update_freq != 'epoch' and samples_seen_since >= self.update_freq:
      self._log_metrics(logs, prefix='batch_', step=self._total_batches_seen)
      self._samples_seen_at_last_write = self._samples_seen
    self._total_batches_seen += 1
    # A running trace covered the batch that just finished: export it.
    # Otherwise, start tracing when the next batch is the one to profile.
    if self._is_tracing:
      self._log_trace()
    elif (not self._is_tracing and
          self._total_batches_seen == self._profile_batch - 1):
      self._enable_trace()

  def on_epoch_end(self, epoch, logs=None):
    """Runs metrics and histogram summaries at epoch end."""
    # Epoch metrics are stepped by epoch index unless batch-level writes are
    # enabled, in which case the cumulative sample count is used instead.
    step = epoch if self.update_freq == 'epoch' else self._samples_seen
    self._log_metrics(logs, prefix='epoch_', step=step)

    if self.histogram_freq and epoch % self.histogram_freq == 0:
      self._log_weights(epoch)

  def on_train_end(self, logs=None):
    # Export a trace that is still running (e.g. training ended on the
    # profiled batch) before closing the writers.
    if self._is_tracing:
      self._log_trace()
    self._close_writers()

  def _enable_trace(self):
    """Starts a graph + profiler trace; only effective in eager mode."""
    if context.executing_eagerly():
      summary_ops_v2.trace_on(graph=True, profiler=True)
      self._is_tracing = True

  def _log_trace(self):
    """Exports the running trace to the train writer and profiler directory.

    NOTE(review): outside eager mode this is a no-op and `_is_tracing` is
    left True.
    """
    if context.executing_eagerly():
      with self._get_writer(self._train_run_name).as_default(), \
          summary_ops_v2.always_record_summaries():
        # TODO(b/126388999): Remove step info in the summary name.
        summary_ops_v2.trace_export(
            name='batch_%d' % self._total_batches_seen,
            step=self._total_batches_seen,
            profiler_outdir=os.path.join(self.log_dir, 'train'))
      self._is_tracing = False

  def _log_metrics(self, logs, prefix, step):
    """Writes metrics out as custom scalar summaries.

    Arguments:
        logs: Dict. Keys are scalar summary names, values are NumPy scalars.
        prefix: String. The prefix to apply to the scalar summary names.
        step: Int. The global step to use for TensorBoard.
    """
    if logs is None:
      logs = {}

    # Group metrics by the name of their associated file writer. Values
    # are lists of metrics, as (name, scalar_value) pairs.
    logs_by_writer = {
        self._train_run_name: [],
        self._validation_run_name: [],
    }
    validation_prefix = 'val_'
    for (name, value) in logs.items():
      if name in ('batch', 'size', 'num_steps'):
        # Scrub non-metric items.
        continue
      if name.startswith(validation_prefix):
        name = name[len(validation_prefix):]
        writer_name = self._validation_run_name
      else:
        writer_name = self._train_run_name
      name = prefix + name  # assign batch or epoch prefix
      logs_by_writer[writer_name].append((name, value))

    with context.eager_mode():
      with summary_ops_v2.always_record_summaries():
        for writer_name in logs_by_writer:
          these_logs = logs_by_writer[writer_name]
          if not these_logs:
            # Don't create a "validation" events file if we don't
            # actually have any validation data.
            continue
          writer = self._get_writer(writer_name)
          with writer.as_default():
            for (name, value) in these_logs:
              summary_ops_v2.scalar(name, value, step=step)

  def _log_weights(self, epoch):
    """Logs the weights of the Model to TensorBoard."""
    writer = self._get_writer(self._train_run_name)
    with context.eager_mode(), \
          writer.as_default(), \
          summary_ops_v2.always_record_summaries():
      for layer in self.model.layers:
        for weight in layer.weights:
          # ':' (as in 'kernel:0') is replaced in the summary tag.
          weight_name = weight.name.replace(':', '_')
          # Read the variable value outside any function-building graph.
          with ops.init_scope():
            weight = K.get_value(weight)
          summary_ops_v2.histogram(weight_name, weight, step=epoch)
          if self.write_images:
            self._log_weight_as_image(weight, weight_name, epoch)
      writer.flush()

  def _log_weight_as_image(self, weight, weight_name, epoch):
    """Logs a weight as a TensorBoard image.

    The squeezed weight is reshaped into a 4-D image batch. Weights that
    cannot be mapped to an image with 1, 3 or 4 channels are skipped.
    """
    w_img = array_ops.squeeze(weight)
    shape = K.int_shape(w_img)
    if len(shape) == 1:  # Bias case
      w_img = array_ops.reshape(w_img, [1, shape[0], 1, 1])
    elif len(shape) == 2:  # Dense layer kernel case
      # Transpose so the image is wider than it is tall.
      if shape[0] > shape[1]:
        w_img = array_ops.transpose(w_img)
        shape = K.int_shape(w_img)
      w_img = array_ops.reshape(w_img, [1, shape[0], shape[1], 1])
    elif len(shape) == 3:  # ConvNet case
      if K.image_data_format() == 'channels_last':
        # Switch to channels_first to display every kernel as a separate
        # image.
        w_img = array_ops.transpose(w_img, perm=[2, 0, 1])
        shape = K.int_shape(w_img)
      w_img = array_ops.reshape(w_img, [shape[0], shape[1], shape[2], 1])

    shape = K.int_shape(w_img)
    # Not possible to handle 3D convnets etc.
    if len(shape) == 4 and shape[-1] in [1, 3, 4]:
      summary_ops_v2.image(weight_name, w_img, step=epoch)
1397
1398
@keras_export('keras.callbacks.ReduceLROnPlateau')
class ReduceLROnPlateau(Callback):
  """Reduce learning rate when a metric has stopped improving.

  Models often benefit from reducing the learning rate by a factor
  of 2-10 once learning stagnates. This callback monitors a
  quantity and if no improvement is seen for a 'patience' number
  of epochs, the learning rate is reduced.

  Example:

  ```python
  reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2,
                                patience=5, min_lr=0.001)
  model.fit(X_train, Y_train, callbacks=[reduce_lr])
  ```

  Arguments:
      monitor: quantity to be monitored.
      factor: factor by which the learning rate will
          be reduced. new_lr = lr * factor. Must be below 1.0.
      patience: number of epochs with no improvement
          after which learning rate will be reduced.
      verbose: int. 0: quiet, 1: update messages.
      mode: one of {auto, min, max}. In `min` mode,
          lr will be reduced when the quantity
          monitored has stopped decreasing; in `max`
          mode it will be reduced when the quantity
          monitored has stopped increasing; in `auto`
          mode, the direction is automatically inferred
          from the name of the monitored quantity.
      min_delta: threshold for measuring the new optimum,
          to only focus on significant changes.
      cooldown: number of epochs to wait before resuming
          normal operation after lr has been reduced.
      min_lr: lower bound on the learning rate.
      **kwargs: the deprecated `epsilon` is accepted as an alias of
          `min_delta`.

  Raises:
      ValueError: if `factor >= 1.0`.
  """

  def __init__(self,
               monitor='val_loss',
               factor=0.1,
               patience=10,
               verbose=0,
               mode='auto',
               min_delta=1e-4,
               cooldown=0,
               min_lr=0,
               **kwargs):
    super(ReduceLROnPlateau, self).__init__()

    self.monitor = monitor
    if factor >= 1.0:
      raise ValueError('ReduceLROnPlateau does not support a factor >= 1.0.')
    if 'epsilon' in kwargs:
      # Legacy spelling of `min_delta`; it wins when both are supplied.
      min_delta = kwargs.pop('epsilon')
      logging.warning('`epsilon` argument is deprecated and '
                      'will be removed, use `min_delta` instead.')
    self.factor = factor
    self.min_lr = min_lr
    self.min_delta = min_delta
    self.patience = patience
    self.verbose = verbose
    self.cooldown = cooldown
    # Epochs left in the current cooldown and epochs without improvement;
    # `_reset` below (re)initialises both together with `best`.
    self.cooldown_counter = 0
    self.wait = 0
    self.best = 0
    self.mode = mode
    self.monitor_op = None
    self._reset()

  def _reset(self):
    """Restores the initial plateau-tracking state.

    Validates `mode` (falling back to 'auto'), picks the comparison used to
    detect an improvement, sets `best` to the worst value for that direction
    and clears the wait and cooldown counters.
    """
    if self.mode not in ['auto', 'min', 'max']:
      logging.warning('Learning Rate Plateau Reducing mode %s is unknown, '
                      'fallback to auto mode.', self.mode)
      self.mode = 'auto'
    maximize = (self.mode == 'max' or
                (self.mode == 'auto' and 'acc' in self.monitor))
    if maximize:
      self.monitor_op = lambda a, b: np.greater(a, b + self.min_delta)
      self.best = -np.Inf
    else:
      self.monitor_op = lambda a, b: np.less(a, b - self.min_delta)
      self.best = np.Inf
    self.wait = 0
    self.cooldown_counter = 0

  def on_train_begin(self, logs=None):
    self._reset()

  def on_epoch_end(self, epoch, logs=None):
    """Updates the plateau state and lowers the learning rate if needed.

    Also records the current learning rate under `logs['lr']`.
    """
    logs = logs or {}
    logs['lr'] = K.get_value(self.model.optimizer.lr)
    current = logs.get(self.monitor)
    if current is None:
      logging.warning('Reduce LR on plateau conditioned on metric `%s` '
                      'which is not available. Available metrics are: %s',
                      self.monitor, ','.join(list(logs.keys())))
      return

    if self.in_cooldown():
      self.cooldown_counter -= 1
      self.wait = 0

    if self.monitor_op(current, self.best):
      self.best = current
      self.wait = 0
      return

    if self.in_cooldown():
      return

    self.wait += 1
    if self.wait >= self.patience:
      old_lr = float(K.get_value(self.model.optimizer.lr))
      if old_lr > self.min_lr:
        new_lr = max(old_lr * self.factor, self.min_lr)
        K.set_value(self.model.optimizer.lr, new_lr)
        if self.verbose > 0:
          print('\nEpoch %05d: ReduceLROnPlateau reducing learning '
                'rate to %s.' % (epoch + 1, new_lr))
        self.cooldown_counter = self.cooldown
        self.wait = 0

  def in_cooldown(self):
    """Returns True while epochs remain in the post-reduction cooldown."""
    return self.cooldown_counter > 0
1522
1523
@keras_export('keras.callbacks.CSVLogger')
class CSVLogger(Callback):
  """Callback that streams epoch results to a csv file.

  Supports all values that can be represented as a string,
  including 1D iterables such as np.ndarray.

  Example:

  ```python
  csv_logger = CSVLogger('training.log')
  model.fit(X_train, Y_train, callbacks=[csv_logger])
  ```

  Arguments:
      filename: filename of the csv file, e.g. 'run/log.csv'.
      separator: string used to separate elements in the csv file.
      append: True: append if file exists (useful for continuing
          training). False: overwrite existing file.
  """

  def __init__(self, filename, separator=',', append=False):
    self.sep = separator
    self.filename = filename
    self.append = append
    self.writer = None
    # Column names (sorted log keys), fixed by the first `on_epoch_end`.
    self.keys = None
    # Whether a header row is written; cleared in `on_train_begin` when
    # appending to a file whose first line is non-empty.
    self.append_header = True
    if six.PY2:
      self.file_flags = 'b'
      self._open_args = {}
    else:
      self.file_flags = ''
      self._open_args = {'newline': '\n'}
    super(CSVLogger, self).__init__()

  def on_train_begin(self, logs=None):
    """Opens the CSV file, in append or overwrite mode per `self.append`."""
    if self.append:
      if os.path.exists(self.filename):
        with open(self.filename, 'r' + self.file_flags) as f:
          self.append_header = not bool(len(f.readline()))
      mode = 'a'
    else:
      mode = 'w'
    self.csv_file = io.open(self.filename,
                            mode + self.file_flags,
                            **self._open_args)

  def on_epoch_end(self, epoch, logs=None):
    """Writes one CSV row with `epoch` and the logged values."""
    logs = logs or {}

    # `collections.Iterable` was removed in Python 3.10; the ABCs live in
    # `collections.abc` on Python 3 and directly in `collections` on Python 2.
    try:
      from collections import abc as collections_abc
    except ImportError:
      collections_abc = collections

    def handle_value(k):
      is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
      if isinstance(k, six.string_types):
        return k
      elif isinstance(k, collections_abc.Iterable) and not is_zero_dim_ndarray:
        return '"[%s]"' % (', '.join(map(str, k)))
      else:
        return k

    if self.keys is None:
      self.keys = sorted(logs.keys())

    if self.model.stop_training:
      # We set NA so that csv parsers do not fail for this last epoch.
      logs = {k: logs.get(k, 'NA') for k in self.keys}

    if not self.writer:

      class CustomDialect(csv.excel):
        delimiter = self.sep

      fieldnames = ['epoch'] + self.keys
      if six.PY2:
        fieldnames = [unicode(x) for x in fieldnames]

      self.writer = csv.DictWriter(
          self.csv_file,
          fieldnames=fieldnames,
          dialect=CustomDialect)
      if self.append_header:
        self.writer.writeheader()

    row_dict = collections.OrderedDict({'epoch': epoch})
    row_dict.update((key, handle_value(logs[key])) for key in self.keys)
    self.writer.writerow(row_dict)
    self.csv_file.flush()

  def on_train_end(self, logs=None):
    self.csv_file.close()
    self.writer = None
1615
1616
@keras_export('keras.callbacks.LambdaCallback')
class LambdaCallback(Callback):
  r"""Callback for creating simple, custom callbacks on-the-fly.

  This callback is constructed with anonymous functions that will be called
  at the appropriate time. Note that the callbacks expect positional
  arguments, as:

   - `on_epoch_begin` and `on_epoch_end` expect two positional arguments:
      `epoch`, `logs`
   - `on_batch_begin` and `on_batch_end` expect two positional arguments:
      `batch`, `logs`
   - `on_train_begin` and `on_train_end` expect one positional argument:
      `logs`

  Hooks that are not supplied default to functions that do nothing. Any
  extra keyword arguments are stored as attributes on the instance.

  Arguments:
      on_epoch_begin: called at the beginning of every epoch.
      on_epoch_end: called at the end of every epoch.
      on_batch_begin: called at the beginning of every batch.
      on_batch_end: called at the end of every batch.
      on_train_begin: called at the beginning of model training.
      on_train_end: called at the end of model training.

  Example:

  ```python
  # Print the batch number at the beginning of every batch.
  batch_print_callback = LambdaCallback(
      on_batch_begin=lambda batch,logs: print(batch))

  # Stream the epoch loss to a file in JSON format. The file content
  # is not well-formed JSON but rather has a JSON object per line.
  import json
  json_log = open('loss_log.json', mode='wt', buffering=1)
  json_logging_callback = LambdaCallback(
      on_epoch_end=lambda epoch, logs: json_log.write(
          json.dumps({'epoch': epoch, 'loss': logs['loss']}) + '\n'),
      on_train_end=lambda logs: json_log.close()
  )

  # Terminate some processes after having finished model training.
  processes = ...
  cleanup_callback = LambdaCallback(
      on_train_end=lambda logs: [
          p.terminate() for p in processes if p.is_alive()])

  model.fit(...,
            callbacks=[batch_print_callback,
                       json_logging_callback,
                       cleanup_callback])
  ```
  """

  def __init__(self,
               on_epoch_begin=None,
               on_epoch_end=None,
               on_batch_begin=None,
               on_batch_end=None,
               on_train_begin=None,
               on_train_end=None,
               **kwargs):
    super(LambdaCallback, self).__init__()
    self.__dict__.update(kwargs)
    # (attribute name, user-supplied hook, no-op fallback with the same
    # parameter names the original defaults used).
    hooks = (
        ('on_epoch_begin', on_epoch_begin, lambda epoch, logs: None),
        ('on_epoch_end', on_epoch_end, lambda epoch, logs: None),
        ('on_batch_begin', on_batch_begin, lambda batch, logs: None),
        ('on_batch_end', on_batch_end, lambda batch, logs: None),
        ('on_train_begin', on_train_begin, lambda logs: None),
        ('on_train_end', on_train_end, lambda logs: None),
    )
    for attr_name, user_hook, noop in hooks:
      setattr(self, attr_name, noop if user_hook is None else user_hook)
1704