1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Training-related utilities."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import abc
21import atexit
22import collections
23import functools
24import multiprocessing.pool
25import threading
26import time
27
28import numpy as np
29import six
30from six.moves import zip  # pylint: disable=redefined-builtin
31
32from tensorflow.core.framework import graph_pb2
33from tensorflow.python import tf2
34from tensorflow.python.data.experimental.ops import cardinality
35from tensorflow.python.data.experimental.ops import distribute_options
36from tensorflow.python.data.ops import dataset_ops
37from tensorflow.python.data.ops import iterator_ops
38from tensorflow.python.eager import context
39from tensorflow.python.framework import composite_tensor
40from tensorflow.python.framework import dtypes
41from tensorflow.python.framework import errors
42from tensorflow.python.framework import ops
43from tensorflow.python.framework import smart_cond
44from tensorflow.python.framework import sparse_tensor
45from tensorflow.python.framework import tensor_spec
46from tensorflow.python.framework import tensor_util
47from tensorflow.python.keras import backend as K
48from tensorflow.python.keras import callbacks as cbks
49from tensorflow.python.keras import losses
50from tensorflow.python.keras import metrics as metrics_module
51from tensorflow.python.keras.utils import data_utils
52from tensorflow.python.keras.utils import generic_utils
53from tensorflow.python.keras.utils import losses_utils
54from tensorflow.python.keras.utils import tf_inspect
55from tensorflow.python.ops import array_ops
56from tensorflow.python.ops import gen_array_ops
57from tensorflow.python.ops import math_ops
58from tensorflow.python.ops import sparse_ops
59from tensorflow.python.ops.ragged import ragged_tensor
60from tensorflow.python.ops.ragged import ragged_tensor_value
61from tensorflow.python.platform import tf_logging as logging
62from tensorflow.python.util import nest
63
64
def is_composite_or_composite_value(tensor):
  """Checks whether `tensor` is a composite tensor or a composite value.

  Args:
    tensor: Any object.

  Returns:
    True if `tensor` is a `CompositeTensor`, a `SparseTensorValue`, or a
    `RaggedTensorValue`; False otherwise.
  """
  # TODO(b/125094323): This should be isinstance(CompositeTensor) or
  # isinstance(CompositeTensorValue) once we support that.
  accepted_types = (composite_tensor.CompositeTensor,
                    sparse_tensor.SparseTensorValue,
                    ragged_tensor_value.RaggedTensorValue)
  return isinstance(tensor, accepted_types)
73
74
@six.add_metaclass(abc.ABCMeta)
class Aggregator(object):
  """Abstract base for objects that combine batch-level outputs of a loop.

  Subclasses implement three phases:
  - `create` builds the initial results from the first batch outputs.
  - `aggregate` folds batch-level outputs into the running results.
  - `finalize` prepares the results that are returned.

  Attributes:
    use_steps: Whether the loop is driven by `steps` rather than
      `batch_size`.
    num_samples: Total number of samples: `batch_size * num_batches`.
    steps: Total number of steps.
    batch_size: Batch size. It is used for validation checks between inputs
      and outputs.
    results: What to return at the end of the aggregation loop. Starts out as
      an empty list.
  """

  def __init__(self, use_steps, num_samples=None, steps=None, batch_size=None):
    self.results = []
    self.batch_size = batch_size
    self.steps = steps
    self.num_samples = num_samples
    self.use_steps = use_steps

  @abc.abstractmethod
  def create(self, batch_outs):
    """Builds the initial `results` from the outputs of the first batch.

    Args:
      batch_outs: A list of batch-level outputs.
    """
    raise NotImplementedError('Must be implemented in subclasses.')

  @abc.abstractmethod
  def aggregate(self, batch_outs, batch_start=None, batch_end=None):
    """Folds one batch's outputs into the running `results`.

    Args:
      batch_outs: A list of batch-level outputs.
      batch_start: Index of the first sample in this batch. Always `None`
        when `use_steps` is `True`.
      batch_end: Index one past the last sample in this batch. Always `None`
        when `use_steps` is `True`.
    """
    raise NotImplementedError('Must be implemented in subclasses.')

  @abc.abstractmethod
  def finalize(self):
    """Turns the running `results` into the value handed back to the caller."""
    raise NotImplementedError('Must be implemented in subclasses.')
121
122
class MetricsAggregator(Aggregator):
  """Aggregator for the loss and metric values of a loop.

  `results[0]` accumulates the loss, either summed per step or weighted by the
  number of samples in each batch. `finalize` divides it by the total count.
  The remaining entries are metric values, which are stateful, so the latest
  batch's values are simply kept.

  Attributes:
    use_steps: Whether the loop is using `step` or `batch_size`.
    num_samples: Total number of samples: `batch_size*num_batches`.
    steps: Total number of steps, ie number of times to iterate over a dataset
      to cover all samples.
  """

  def __init__(self, use_steps, num_samples=None, steps=None):
    super(MetricsAggregator, self).__init__(
        use_steps=use_steps, num_samples=num_samples, steps=steps,
        batch_size=None)

  def create(self, batch_outs):
    # One zero-initialised slot per output (loss first, then metrics).
    self.results = [0.0] * len(batch_outs)

  def aggregate(self, batch_outs, batch_start=None, batch_end=None):
    batch_loss = batch_outs[0]
    if not self.use_steps:
      # Weight the loss by the number of samples in this batch.
      batch_loss = batch_loss * (batch_end - batch_start)
    self.results[0] += batch_loss
    # Metrics are stateful, so the most recent values already cover all
    # batches seen so far.
    self.results[1:] = batch_outs[1:]

  def finalize(self):
    if not self.results:
      raise ValueError('Empty training data.')
    total_count = self.num_samples or self.steps
    self.results[0] /= total_count
156
157
def _append_sparse_tensor_value(target, to_append):
  """Concatenates two SparseTensorValues along axis 0.

  The indices of `to_append` are shifted by `target`'s leading dense dimension
  so that its rows come after `target`'s rows. Neither input is modified.

  Args:
    target: The SparseTensorValue whose rows come first.
    to_append: The SparseTensorValue whose rows are appended after `target`.

  Returns:
    A new SparseTensorValue. Its leading dense dimension is the sum of the
    inputs' leading dimensions; the remaining dimensions are `target`'s.

  Raises:
    RuntimeError: if the inputs differ in rank or in any non-leading dense
      dimension.
  """
  # Make sure the sparse tensors are of the same size (except for the 0th dim).
  if len(target.dense_shape) != len(to_append.dense_shape):
    raise RuntimeError(
        'Unable to concatenate %s and %s. The inner dense shapes do not '
        'have the same number of dimensions (%s vs %s)' %
        (target, to_append, target.dense_shape, to_append.dense_shape))

  # Compare as tuples. `dense_shape` may be a NumPy array, where `!=` is
  # elementwise and its result cannot be used in an `if` once it has more
  # than one element.
  if tuple(target.dense_shape[1:]) != tuple(to_append.dense_shape[1:]):
    raise RuntimeError(
        'Unable to concatenate %s and %s. The inner dense shapes do not '
        'match inner dimensions (%s vs %s)' %
        (target, to_append, target.dense_shape[1:], to_append.dense_shape[1:]))

  base_dim0_value = target.dense_shape[0]

  # Work on a copy. Iterating a 2-D index array yields row views, so shifting
  # rows in place would silently rewrite the caller's `to_append.indices`.
  shifted_indices = np.array(to_append.indices, copy=True)
  if shifted_indices.size == 0:
    # Nothing to append; keep target's indices (and their dtype) unchanged.
    new_indices = target.indices
  else:
    # Offset the batch (0th) coordinate of every appended index in one
    # vectorized step, e.g. [0, 0, 1] appended after 3 rows becomes [3, 0, 1].
    shifted_indices[:, 0] += base_dim0_value
    new_indices = np.concatenate((target.indices, shifted_indices), axis=0)

  # Values stay in the same order as the indices above.
  new_values = np.concatenate((target.values, to_append.values), axis=0)

  # Concatenating along axis 0 adds the leading dimensions. The largest
  # appended row index would undercount whenever trailing rows of
  # `to_append` hold no entries.
  new_dense_shape = list(target.dense_shape)
  new_dense_shape[0] = base_dim0_value + to_append.dense_shape[0]
  new_dense_shape = tuple(new_dense_shape)

  return sparse_tensor.SparseTensorValue(
      indices=new_indices, values=new_values, dense_shape=new_dense_shape)
200
201
def _append_ragged_tensor_value(target, to_append):
  """Concatenates two RaggedTensorValues along the outermost axis.

  Nested ragged values are handled by recursing into `values`. Flat values are
  joined with `np.concatenate`.

  Args:
    target: The RaggedTensorValue whose rows come first.
    to_append: The RaggedTensorValue whose rows are appended.

  Returns:
    A new RaggedTensorValue.

  Raises:
    RuntimeError: if the ranks or the non-leading dimensions differ.
  """
  # The two values must agree in everything except the outermost dimension.
  ranks_match = len(target.shape) == len(to_append.shape)
  if not ranks_match or target.shape[1:] != to_append.shape[1:]:
    raise RuntimeError('Unable to concatenate %s and %s' % (target, to_append))

  # Drop the leading 0 of the appended splits and offset them by the number of
  # values already held by `target`.
  offset = target.row_splits[-1]
  merged_splits = np.append(target.row_splits, to_append.row_splits[1:] + offset)

  if isinstance(target.values, ragged_tensor_value.RaggedTensorValue):
    merged_values = _append_ragged_tensor_value(target.values, to_append.values)
  else:
    merged_values = np.concatenate((target.values, to_append.values), axis=0)

  return ragged_tensor_value.RaggedTensorValue(merged_values, merged_splits)
219
220
def _append_composite_tensor(target, to_append):
  """Concatenates two composite tensors (or their value objects) on axis 0.

  Batching inside a fit/evaluate/predict call requires aggregating
  CompositeTensors. The CT API does not make this easy, especially in V1 mode,
  where we work with CompositeTensor Value objects that have no connection to
  the CompositeTensors that produced them.

  Args:
    target: CompositeTensor or CompositeTensor value object that will be
      appended to.
    to_append: CompositeTensor or CompositeTensor value object to append to
      `target`. Must be of exactly the same type as `target`.

  Returns:
    A CompositeTensor or CompositeTensor value object.

  Raises:
    RuntimeError: if concatenation is not possible.
  """
  if type(target) is not type(to_append):
    raise RuntimeError('Unable to concatenate %s and %s' %
                       (type(target), type(to_append)))

  # TODO(b/125094323): This should be replaced by a simple call to
  # target.append() that should work on all of the below classes.

  # A live CompositeTensor only shows up here in Eager mode (otherwise it would
  # already have been evaluated to a value object), so it is safe to call the
  # TF concat ops on it directly. Value objects are merged by hand, outside of
  # the graph.
  if isinstance(target, sparse_tensor.SparseTensor):
    # tf.concat does not accept sparse inputs; use the sparse variant.
    return sparse_ops.sparse_concat(sp_inputs=[target, to_append], axis=0)
  if isinstance(target, ragged_tensor.RaggedTensor):
    return array_ops.concat([target, to_append], axis=0)
  if isinstance(target, sparse_tensor.SparseTensorValue):
    return _append_sparse_tensor_value(target, to_append)
  if isinstance(target, ragged_tensor_value.RaggedTensorValue):
    return _append_ragged_tensor_value(target, to_append)
  raise RuntimeError('Attempted to concatenate unsupported object %s.' %
                     type(target))
269
270
class ConcatAggregator(Aggregator):
  """Combine tensor-likes which cannot be merged on the fly.

  This class expects to aggregate a single tensor-like rather than a nested
  structure of tensor-likes. Batches are buffered in `results` and only joined
  along axis 0 in `finalize`.

  Attributes:
    composite: Whether the element passed to `create` was a CompositeTensor or
      CompositeTensor value object. `None` until `create` is called.
  """

  def __init__(self, batch_size):
    self.composite = None
    super(ConcatAggregator, self).__init__(
        use_steps=True, num_samples=None, steps=None, batch_size=batch_size)

  def create(self, batch_element):
    self.composite = is_composite_or_composite_value(batch_element)

  def aggregate(self, batch_element, batch_start=None, batch_end=None):
    """Buffers one batch after checking it against `batch_size`.

    Args:
      batch_element: The batch-level output (array or composite value).
      batch_start: Ignored.
      batch_end: Ignored.

    Raises:
      ValueError: if the batch's leading dimension exceeds `batch_size`.
    """
    # A SparseTensorValue keeps its shape in `dense_shape`, not `shape`.
    # `get_composite_shape` handles that case and returns `.shape` for
    # everything else.
    shape = get_composite_shape(batch_element)

    # TODO(psv): Add num_samples check here to detect when output batch
    # #samples is < batch size and != input batch #samples.
    if self.batch_size and self.batch_size < shape[0]:
      raise ValueError(
          'Mismatch between expected batch size and model output batch size. '
          'Output shape = {}, expected output shape = shape {}'.format(
              shape, (self.batch_size,) + tuple(shape[1:])))
    self.results.append(batch_element)

  def finalize(self):
    """Joins the buffered batches into `results`.

    A single buffered batch is kept as-is, so no copy is made. Composite
    values are appended pairwise with `_append_composite_tensor`. Anything
    else is joined with `np.concatenate` along axis 0.
    """
    # Special case of single batch inference which skips a copy.
    if len(self.results) == 1:
      self.results = self.results[0]

    elif self.composite:
      # TODO(taylorrobie): efficiently concatenate.
      results = self.results[0]
      for r in self.results[1:]:
        results = _append_composite_tensor(results, r)
      self.results = results

    else:
      self.results = np.concatenate(self.results, axis=0)
312
313
_COPY_THREADS = 4  # Number of worker threads in the shared copy pool.
_COPY_POOL = None  # Created lazily by `get_copy_pool`.


def get_copy_pool():
  """Returns the shared thread pool used for copying arrays.

  Building a pool costs about 2ms, so a single pool is created on first use
  and reused, rather than creating one per SliceAggregator.

  Returns:
    The global copy threadpool.
  """
  global _COPY_POOL
  if _COPY_POOL is not None:
    return _COPY_POOL
  _COPY_POOL = multiprocessing.pool.ThreadPool(_COPY_THREADS)
  # Close the pool (stop accepting work) when the interpreter exits.
  atexit.register(_COPY_POOL.close)
  return _COPY_POOL
332
333
class SliceAggregator(Aggregator):
  """Combine arrays where the final size is known.

  This class expects to aggregate a single tensor-like rather than a nested
  structure of tensor-likes. `create` allocates one output buffer with
  `num_samples` rows, and every batch is written into its
  `[batch_start:batch_end]` slice of that buffer.

  Threads are well suited to NumPy copies: the heavy lifting happens in C
  without holding the GIL. Writes into the shared buffer need no lock,
  because result aggregation guarantees that either the slices are disjoint
  or the aggregator raises in `finalize`. Since aggregation runs along the
  slowest-varying dimension, each batch writes one contiguous block of
  memory, which further limits contention.

  Handing a copy to a worker thread has some scheduling and context-switching
  cost. Below a certain size it is cheaper to assign in the calling thread.
  The best threshold varies by system, but timing is not very sensitive to
  it, so 2 ** 14 elements is used as a reasonable default.
  """

  _BINARY_SIZE_THRESHOLD = 2 ** 14
  _MAX_COPY_SECONDS = 300

  def __init__(self, num_samples, batch_size):
    self._errors = []
    self._async_copies = []
    self._pool = get_copy_pool()
    super(SliceAggregator, self).__init__(
        use_steps=False,
        num_samples=num_samples,
        steps=None,
        batch_size=batch_size)

  def create(self, batch_element):
    # An uninitialised NumPy allocation is effectively free, so there is no
    # point in pipelining it.
    self.results = np.empty(
        shape=(self.num_samples,) + batch_element.shape[1:],
        dtype=batch_element.dtype)

  def aggregate(self, batch_element, batch_start, batch_end):
    # Surface a failure from an earlier background copy as soon as possible.
    if self._errors:
      first_error = self._errors[0]
      six.reraise(type(first_error), first_error)

    # A single batch that spans every sample is adopted directly, no copy.
    if batch_end - batch_start == self.num_samples:
      if batch_element.shape[0] != self.num_samples:
        raise ValueError(
            'Mismatch between expected batch size and model output batch size. '
            'Output shape = {}, expected output shape = shape {}'.format(
                batch_element.shape, self.results.shape))
      self.results = batch_element
      return

    # The threshold is approximate, so element count (not bytes) is enough.
    if np.prod(batch_element.shape) < self._BINARY_SIZE_THRESHOLD:
      self.results[batch_start:batch_end] = batch_element
      return

    copy_done = threading.Event()
    self._pool.apply_async(
        self._slice_assign,
        args=(batch_element, batch_start, batch_end, copy_done))
    self._async_copies.append(copy_done)

  def _slice_assign(self, batch_element, batch_start, batch_end, is_finished):
    """Copies `batch_element` into `self.results[batch_start:batch_end]`.

    Intended to run on a pool thread. Any exception is stored in
    `self._errors` for the main thread to re-raise, and `is_finished` is set
    whether or not the copy succeeded.
    """
    try:
      self.results[batch_start:batch_end] = batch_element

    except Exception as e:  # pylint: disable=broad-except
      # Exceptions raised in a worker thread do not propagate to the main
      # thread, so catch broadly here and keep the error to re-raise later.
      self._errors.append(e)

    finally:
      is_finished.set()

  def finalize(self):
    # All pending copies share one overall time budget.
    deadline = time.time() + self._MAX_COPY_SECONDS
    for copy_done in self._async_copies:
      remaining = max(0., deadline - time.time())
      if not copy_done.wait(remaining):
        raise ValueError('Timed out waiting for copy to complete.')

    if self._errors:
      first_error = self._errors[0]
      six.reraise(first_error.__class__, first_error)
429
430
class OutputsAggregator(Aggregator):
  """Aggregator that concatenates (possibly nested) model outputs.

  Each leaf of the output structure gets its own child aggregator:
  - Composite values get a `ConcatAggregator`.
  - NumPy arrays get a `ConcatAggregator` when `use_steps` is set, and a
    `SliceAggregator` otherwise.

  `finalize` repacks the children's results into the original structure.
  """

  _structure = None

  def create(self, batch_outs):
    # SparseTensorValue is a namedtuple, which nest would otherwise flatten;
    # stop the traversal at composite values so they stay whole leaves.
    self._structure = nest.get_traverse_shallow_structure(
        lambda x: not is_composite_or_composite_value(x), batch_outs)

    for leaf in nest.flatten_up_to(self._structure, batch_outs):
      if is_composite_or_composite_value(leaf):
        # Composite tensors and their Value objects cannot be preallocated
        # into an array, so buffer them and concatenate at the end.
        child = ConcatAggregator(self.batch_size)
      elif isinstance(leaf, np.ndarray):
        if self.use_steps:
          child = ConcatAggregator(self.batch_size)
        else:
          child = SliceAggregator(self.num_samples, self.batch_size)
      else:
        # Neither an ndarray nor a composite: fail fast instead of trying to
        # concatenate it later.
        raise RuntimeError('Attempted to aggregate unsupported object {}.'
                           .format(leaf))

      self.results.append(child)
      child.create(leaf)

  def aggregate(self, batch_outs, batch_start=None, batch_end=None):
    leaves = nest.flatten_up_to(self._structure, batch_outs)
    for child, leaf in zip(self.results, leaves):
      child.aggregate(leaf, batch_start, batch_end)

  def finalize(self):
    for child in self.results:
      child.finalize()
    flat_results = [child.results for child in self.results]
    self.results = nest.pack_sequence_as(self._structure, flat_results)
471
472
def get_progbar(model, count_mode, include_metrics=True):
  """Builds a `ProgbarLogger` callback for `model`.

  Args:
    model: The model being trained/evaluated. Only its optional
      `metrics_names` attribute is read.
    count_mode: Passed through to `ProgbarLogger`.
    include_metrics: Whether to pass the model's metric names, minus the
      first entry (the loss), as the logger's stateful metrics.

  Returns:
    A `ProgbarLogger` instance.
  """
  stateful_metric_names = None
  if include_metrics:
    metric_names = getattr(model, 'metrics_names', None)
    # The first name is `loss`, which is not a stateful metric.
    stateful_metric_names = metric_names[1:] if metric_names else metric_names
  return cbks.ProgbarLogger(count_mode, stateful_metrics=stateful_metric_names)
482
483
def check_num_samples(ins, batch_size=None, steps=None, steps_name='steps'):
  """Determines the number of samples provided for training and evaluation.

  The sample count is undefined when running with `steps`; in that case
  `None` is returned.

  Args:
      ins: List of tensors to be fed to the Keras function.
      batch_size: Integer batch size or `None` if not defined.
      steps: Total number of steps (batches of samples) before declaring
        `_predict_loop` finished. Ignored with the default value of `None`.
      steps_name: The public API's parameter name for `steps`.

  Raises:
      ValueError: when both `steps` and `batch_size` are set, since they are
        mutually exclusive. Errors raised by `check_steps_argument` also
        propagate (presumably when `steps` is required but missing).

  Returns:
      `None` when `check_steps_argument` reports a steps-driven loop, or when
      the first input has no `shape` attribute. Otherwise, the size of the
      first dimension of the first input, as an `int`.
  """
  if steps is not None and batch_size is not None:
    raise ValueError('If ' + steps_name +
                     ' is set, the `batch_size` must be None.')
  if check_steps_argument(ins, steps, steps_name):
    return None

  first_input = ins[0]
  if not hasattr(first_input, 'shape'):
    return None  # Edge case where ins == [static_learning_phase]
  return int(first_input.shape[0])
518
519
def standardize_single_array(x, expected_shape=None):
  """Expands rank-1 data of shape (n,) to (n, 1).

  No expansion happens when `expected_shape` has length 1. `None` and
  composite tensors/values are returned unchanged.

  Args:
    x: The data to standardize (array, tensor, composite, or `None`).
    expected_shape: Optional expected shape of the data.

  Returns:
    `x`, possibly with a trailing axis of size 1 added.

  Raises:
    ValueError: if `x` is a Python `int`.
  """
  if x is None or is_composite_or_composite_value(x):
    return x

  if isinstance(x, int):
    raise ValueError(
        'Expected an array data type but received an integer: {}'.format(x))

  # Leave `x` alone unless it is rank 1 and a rank-1 result was not requested.
  if (x.shape is None or len(x.shape) != 1 or
      (expected_shape is not None and len(expected_shape) == 1)):
    return x

  if tensor_util.is_tf_type(x):
    return array_ops.expand_dims(x, axis=1)
  return np.expand_dims(x, 1)
539
540
def get_composite_shape(tensor):
  """Returns the shape of the passed composite tensor or composite value.

  Args:
    tensor: A composite tensor or composite tensor value object.

  Returns:
    `tensor.dense_shape` for a `SparseTensorValue`, `tensor.shape` otherwise.
  """
  # SparseTensorValues expose their shape through a 'dense_shape' attribute.
  is_sparse_value = isinstance(tensor, sparse_tensor.SparseTensorValue)
  return tensor.dense_shape if is_sparse_value else tensor.shape
548
549
def standardize_input_data(data,
                           names,
                           shapes=None,
                           check_batch_axis=True,
                           exception_prefix=''):
  """Normalizes inputs and targets provided by users.

  Users may pass data as a list of arrays, dictionary of arrays,
  or as a single array. We normalize this to an ordered list of
  arrays (same order as `names`), while checking that the provided
  arrays have shapes that match the network's expectations. Objects whose
  class is named `DataFrame` are converted through their `.values` attribute.

  Args:
      data: User-provided input data (polymorphic).
      names: List of expected array names.
      shapes: Optional list of expected array shapes.
      check_batch_axis: Boolean; whether to check that the batch axis of the
        arrays matches the expected value found in `shapes`.
      exception_prefix: String prefix used for exception formatting.

  Returns:
      List of standardized input arrays (one array per model input).

  Raises:
      ValueError: in case of improperly formatted user-provided data.
      TypeError: if a single object without a `shape` is passed where Numpy
        arrays are expected.
  """
  try:
    data_len = len(data)
  except TypeError:
    # For instance if data is `None` or a symbolic Tensor.
    data_len = None

  if not names:
    if data_len and not isinstance(data, dict):
      # Build a single message string. Passing `data` as a second argument
      # would turn the exception message into a tuple.
      raise ValueError('Error when checking model ' + exception_prefix +
                       ': expected no data, but got: ' + str(data)[:200])
    return []
  if data is None:
    return [None for _ in range(len(names))]

  if isinstance(data, dict):
    try:
      data = [
          data[x].values
          if data[x].__class__.__name__ == 'DataFrame' else data[x]
          for x in names
      ]
    except KeyError as e:
      raise ValueError('No data provided for "' + str(e.args[0]) +
                       '". Need data for each key in: ' + str(names)) from e
  elif isinstance(data, (list, tuple)):
    # Only peek at the first element when there is one, so an empty list or
    # tuple reaches the length check below instead of raising IndexError.
    first = data[0] if data else None
    if isinstance(first, (list, tuple)):
      data = [np.asarray(d) for d in data]
    elif len(names) == 1 and isinstance(first, (float, int)):
      data = [np.asarray(data)]
    else:
      data = [
          x.values if x.__class__.__name__ == 'DataFrame' else x for x in data
      ]
  else:
    data = data.values if data.__class__.__name__ == 'DataFrame' else data
    data = [data]

  if shapes is not None:
    data = [
        standardize_single_array(x, shape) for (x, shape) in zip(data, shapes)
    ]
  else:
    data = [standardize_single_array(x) for x in data]

  if len(data) != len(names):
    if data and hasattr(data[0], 'shape'):
      raise ValueError('Error when checking model ' + exception_prefix +
                       ': the list of Numpy arrays that you are passing to '
                       'your model is not the size the model expected. '
                       'Expected to see ' + str(len(names)) + ' array(s), ' +
                       'for inputs ' + str(names) + ' but instead got the '
                       'following list of ' + str(len(data)) + ' arrays: ' +
                       str(data)[:200] + '...')
    elif len(names) > 1:
      raise ValueError('Error when checking model ' + exception_prefix +
                       ': you are passing a list as input to your model, '
                       'but the model expects a list of ' + str(len(names)) +
                       ' Numpy arrays instead. The list you passed was: ' +
                       str(data)[:200])
    elif len(data) == 1 and not hasattr(data[0], 'shape'):
      raise TypeError('Error when checking model ' + exception_prefix +
                      ': data should be a Numpy array, or list/dict of '
                      'Numpy arrays. Found: ' + str(data)[:200] + '...')
    elif len(names) == 1:
      data = [np.asarray(data)]

  # Check shapes compatibility.
  if shapes:
    for i in range(len(names)):
      if shapes[i] is not None:
        if tensor_util.is_tf_type(data[i]):
          tensorshape = data[i].shape
          if not tensorshape:
            continue
          data_shape = tuple(tensorshape.as_list())
        elif is_composite_or_composite_value(data[i]):
          composite_shape = get_composite_shape(data[i])
          if hasattr(composite_shape, 'as_list'):
            data_shape = tuple(composite_shape.as_list())
          else:
            # Value objects (e.g. a SparseTensorValue's `dense_shape`) may
            # be plain sequences or arrays without `as_list`. Normalize them
            # to a tuple of Python ints, keeping `None` for unknown dims.
            data_shape = tuple(
                None if dim is None else int(dim) for dim in composite_shape)
        else:
          data_shape = data[i].shape

        shape = shapes[i]
        if len(data_shape) != len(shape):
          raise ValueError('Error when checking ' + exception_prefix +
                           ': expected ' + names[i] + ' to have ' +
                           str(len(shape)) + ' dimensions, but got array '
                           'with shape ' + str(data_shape))
        if not check_batch_axis:
          data_shape = data_shape[1:]
          shape = shape[1:]
        for dim, ref_dim in zip(data_shape, shape):
          if ref_dim != dim and ref_dim is not None and dim is not None:
            raise ValueError('Error when checking ' + exception_prefix +
                             ': expected ' + names[i] + ' to have shape ' +
                             str(shape) + ' but got array with shape ' +
                             str(data_shape))
  return data
674
675
def standardize_sample_or_class_weights(x_weight, output_names, weight_type):
  """Maps `sample_weight` or `class_weight` to model outputs.

  Args:
      x_weight: User-provided `sample_weight` or `class_weight` argument.
      output_names: List of output names (strings) in the model.
      weight_type: A string used purely for exception printing.

  Returns:
      A list of `sample_weight` or `class_weight` with exactly one element
      per model output.

  Raises:
      ValueError: If the model has multiple outputs and `x_weight` is a list
        or tuple whose length differs from the number of outputs. Errors from
        `generic_utils.check_for_unexpected_keys` on mapping input also
        propagate.
      TypeError: If the model has multiple outputs and `x_weight` is neither a
        list/tuple nor a mapping.
  """
  # `None` and an empty list/tuple both mean "no weights for any output".
  if x_weight is None or (isinstance(x_weight, (list, tuple)) and
                          len(x_weight) == 0):  # pylint: disable=g-explicit-length-test
    return [None for _ in output_names]
  if len(output_names) == 1:
    if isinstance(x_weight, (list, tuple)) and len(x_weight) == 1:
      return x_weight
    if isinstance(x_weight, dict) and output_names[0] in x_weight:
      return [x_weight[output_names[0]]]
    else:
      # E.g. a class_weight dict keyed by class index applies to the single
      # output as a whole.
      return [x_weight]
  if isinstance(x_weight, (list, tuple)):
    if len(x_weight) != len(output_names):
      raise ValueError('Provided `' + weight_type + '` was a list of ' +
                       str(len(x_weight)) + ' elements, but the model has ' +
                       str(len(output_names)) + ' outputs. '
                       'You should provide one `' + weight_type + '` '
                       'array per model output.')
    return x_weight
  if isinstance(x_weight, collections.abc.Mapping):
    generic_utils.check_for_unexpected_keys(weight_type, x_weight, output_names)
    x_weights = []
    for name in output_names:
      x_weights.append(x_weight.get(name))
    return x_weights
  else:
    raise TypeError('The model has multiple outputs, so `' + weight_type + '` '
                    'should be either a list or a dict. '
                    'Provided `' + weight_type + '` type not understood: ' +
                    str(x_weight))
720
721
def standardize_class_weights(class_weight, output_names):
  """Maps a user-provided `class_weight` to one entry per model output.

  Delegates to `standardize_sample_or_class_weights`, labelling any error
  messages with 'class_weight'.
  """
  return standardize_sample_or_class_weights(
      x_weight=class_weight,
      output_names=output_names,
      weight_type='class_weight')
725
726
def standardize_sample_weights(sample_weight, output_names):
  """Maps a user-provided `sample_weight` to one entry per model output.

  Delegates to `standardize_sample_or_class_weights`, labelling any error
  messages with 'sample_weight'.
  """
  return standardize_sample_or_class_weights(
      x_weight=sample_weight,
      output_names=output_names,
      weight_type='sample_weight')
730
731
def check_array_lengths(inputs, targets, weights=None):
  """Does user input validation for numpy arrays.

  The following must agree on their leading (sample) dimension:
  - all inputs with each other;
  - all targets with each other;
  - all sample weights with each other;
  - inputs with targets;
  - targets with sample weights.

  `None` entries, tensors and composite tensors are skipped.

  Args:
      inputs: list of Numpy arrays of inputs.
      targets: list of Numpy arrays of targets.
      weights: list of Numpy arrays of sample weights.

  Raises:
      ValueError: in case of incorrectly formatted data.
  """

  def is_tensor_or_composite_tensor(x):
    return tensor_util.is_tf_type(x) or is_composite_or_composite_value(x)

  def set_of_lengths(x):
    # Distinct leading-dimension sizes among the checkable arrays in `x`;
    # an empty set when `x` is `None`.
    if x is None:
      return set()
    return set(
        y.shape[0]
        for y in x
        if y is not None and not is_tensor_or_composite_tensor(y))

  def shapes_of(x):
    # Only used to build error messages; `None` entries have no `.shape`.
    return [None if y is None else y.shape for y in x]

  set_x = set_of_lengths(inputs)
  set_y = set_of_lengths(targets)
  set_w = set_of_lengths(weights)
  if len(set_x) > 1:
    raise ValueError('All input arrays (x) should have '
                     'the same number of samples. Got array shapes: ' +
                     str(shapes_of(inputs)))
  if len(set_y) > 1:
    raise ValueError('All target arrays (y) should have '
                     'the same number of samples. Got array shapes: ' +
                     str(shapes_of(targets)))
  if set_x and set_y and list(set_x)[0] != list(set_y)[0]:
    raise ValueError('Input arrays should have '
                     'the same number of samples as target arrays. '
                     'Found ' + str(list(set_x)[0]) + ' input samples '
                     'and ' + str(list(set_y)[0]) + ' target samples.')
  if len(set_w) > 1:
    raise ValueError('All sample_weight arrays should have '
                     'the same number of samples. Got array shapes: ' +
                     str(shapes_of(weights)))
  if set_y and set_w and list(set_y)[0] != list(set_w)[0]:
    raise ValueError('Sample_weight arrays should have '
                     'the same number of samples as target arrays. Got ' +
                     str(list(set_y)[0]) + ' target samples and ' +
                     str(list(set_w)[0]) + ' sample_weight samples.')
784
785
def check_loss_and_target_compatibility(targets, loss_fns, output_shapes):
  """Validates that each target array suits the loss applied to its output.

  This check exists purely for UX purposes, to surface clearer error messages.
  Pairs whose target is `None` or a TF tensor, or whose loss is `None`, are
  skipped.

  Args:
      targets: list of Numpy arrays of targets.
      loss_fns: list of loss functions.
      output_shapes: list of shapes of model outputs.

  Raises:
      ValueError: if a loss function or target array
          is incompatible with an output.
  """
  # Losses that require the target to have the same shape as the output.
  shape_matching_fns = {
      losses.mean_squared_error,
      losses.binary_crossentropy,
      losses.categorical_crossentropy,
  }
  shape_matching_classes = (losses.MeanSquaredError,
                            losses.BinaryCrossentropy,
                            losses.CategoricalCrossentropy)

  for target, loss, out_shape in zip(targets, loss_fns, output_shapes):
    if target is None or loss is None or tensor_util.is_tf_type(target):
      continue

    # A trailing dimension of 1 suggests integer labels rather than one-hot.
    if losses.is_categorical_crossentropy(loss) and target.shape[-1] == 1:
      raise ValueError('You are passing a target array of shape ' +
                       str(target.shape) +
                       ' while using as loss `categorical_crossentropy`. '
                       '`categorical_crossentropy` expects '
                       'targets to be binary matrices (1s and 0s) '
                       'of shape (samples, classes). '
                       'If your targets are integer classes, '
                       'you can convert them to the expected format via:\n'
                       '```\n'
                       'from keras.utils import to_categorical\n'
                       'y_binary = to_categorical(y_int)\n'
                       '```\n'
                       '\n'
                       'Alternatively, you can use the loss function '
                       '`sparse_categorical_crossentropy` instead, '
                       'which does expect integer targets.')

    wrapped = isinstance(loss, losses.LossFunctionWrapper)
    needs_same_shape = isinstance(loss, shape_matching_classes) or (
        wrapped and loss.fn in shape_matching_fns)
    if not needs_same_shape:
      continue

    # Compare non-batch dimensions; unknown output dims are not checked.
    for target_dim, out_dim in zip(target.shape[1:], out_shape[1:]):
      if out_dim is None or target_dim == out_dim:
        continue
      display_name = loss.name
      if display_name is None:
        display_name = (loss.fn if wrapped else type(loss)).__name__
      raise ValueError('A target array with shape ' + str(target.shape) +
                       ' was passed for an output of shape ' +
                       str(out_shape) +
                       ' while using as loss `' + display_name + '`. '
                       'This loss expects targets to have the same shape '
                       'as the output.')
843
844
def collect_per_output_metric_info(metrics,
                                   output_names,
                                   output_shapes,
                                   loss_fns,
                                   is_weighted=False):
  """Maps metric names and functions to model outputs.

  Args:
      metrics: a list or a list of lists or a dict of metric functions.
      output_names: a list of the names (strings) of model outputs.
      output_shapes: a list of the shapes of model outputs.
      loss_fns: a list of the loss functions corresponding to the model outputs.
      is_weighted: Boolean indicating whether the given metrics are weighted.

  Returns:
      A list (one entry per model output) of dicts.
      For instance, if the model has 2 outputs, and for the first output
      we want to compute "binary_accuracy" and "binary_crossentropy",
      and just "binary_accuracy" for the second output,
      the list would look like: `[{
          'acc': binary_accuracy(),
          'ce': binary_crossentropy(),
        }, {
          'acc': binary_accuracy(),
        }]`

  Raises:
      TypeError: if an incorrect type is passed for the `metrics` argument.
      ValueError: if `metrics` is a list of lists whose length differs from
          the number of model outputs.
  """
  if not metrics:
    return [{} for _ in output_names]

  if isinstance(metrics, list):
    any_sub_list = any(isinstance(m, list) for m in metrics)
    if any_sub_list:
      if len(metrics) != len(output_names):
        raise ValueError('When passing a list of lists as `metrics`, '
                         'it should have one entry per model output. '
                         'The model has ' + str(len(output_names)) +
                         ' outputs, but you passed metrics=' + str(metrics))
      # User has provided a list of len = len(outputs).
      nested_metrics = [generic_utils.to_list(m) for m in metrics]
    else:
      # A single flat list applies every metric to every output. With several
      # outputs, each output gets its own `clone_metric` copy of each metric.
      if len(output_names) > 1:
        nested_metrics = []
        for _ in output_names:
          nested_metrics.append(
              [metrics_module.clone_metric(m) for m in metrics])
      else:
        nested_metrics = [metrics]
  elif isinstance(metrics, collections.abc.Mapping):
    generic_utils.check_for_unexpected_keys('metrics', metrics, output_names)
    nested_metrics = []
    for name in output_names:
      output_metrics = generic_utils.to_list(metrics.get(name, []))
      nested_metrics.append(output_metrics)
  else:
    raise TypeError('Type of `metrics` argument not understood. '
                    'Expected a list or dictionary, found: ' + str(metrics))

  per_output_metrics = []
  # Use a distinct loop name so the `metrics` argument is not shadowed.
  for i, metrics_for_output in enumerate(nested_metrics):
    metrics_dict = collections.OrderedDict()
    for metric in metrics_for_output:
      metric_name = get_metric_name(metric, is_weighted)
      metric_fn = get_metric_function(
          metric, output_shape=output_shapes[i], loss_fn=loss_fns[i])

      # If the metric function is not stateful, we create a stateful version.
      if not isinstance(metric_fn, metrics_module.Metric):
        metric_fn = metrics_module.MeanMetricWrapper(
            metric_fn, name=metric_name)
      metrics_dict[metric_name] = metric_fn
    per_output_metrics.append(metrics_dict)

  return per_output_metrics
922
923
def batch_shuffle(index_array, batch_size):
  """Shuffles an array in a batch-wise fashion.

  Useful for shuffling HDF5 arrays
  (where one cannot access arbitrary indices).

  The leading `len(index_array) // batch_size` full batches are permuted as
  whole units, and the order inside each batch is kept. A trailing partial
  batch, if any, stays at the end.

  Args:
      index_array: 1-D numpy array of indices to be shuffled.
      batch_size: integer.

  Returns:
      The `index_array` array, shuffled in a batch-wise fashion.
  """
  # Integer division avoids the float round-trip of `int(a / b)`, which can
  # be off by one for very large arrays.
  batch_count = len(index_array) // batch_size
  split = batch_count * batch_size
  # To reshape we need to be cleanly divisible by batch size, so stash the
  # trailing partial batch and re-append it after shuffling.
  last_batch = index_array[split:]
  # NOTE(review): slicing/reshape can return views, so the shuffle below may
  # also permute the caller's array in place.
  full_batches = index_array[:split].reshape((batch_count, batch_size))
  np.random.shuffle(full_batches)  # Permutes rows, i.e. whole batches.
  return np.append(full_batches.flatten(), last_batch)
946
947
def standardize_weights(y,
                        sample_weight=None,
                        class_weight=None,
                        sample_weight_mode=None):
  """Performs sample weight validation and standardization.

  Everything gets normalized to a single sample-wise (or timestep-wise)
  weight array. If both `sample_weight` and `class_weight` are provided,
  the weights are multiplied.

  Args:
      y: Numpy array or Tensor of model targets to be weighted.
      sample_weight: User-provided `sample_weight` argument. A tuple is
        replaced by its first element.
      class_weight: User-provided `class_weight` argument. Only a dict is
        applied; any other value is ignored.
      sample_weight_mode: One of `None`, `"samplewise"` or `"temporal"`.
        `"temporal"` indicates that we expect 2D weight data that will be
        applied to the first 2 dimensions of the targets (i.e. we are
        weighting timesteps, not samples).

  Returns:
      The combined weights (numpy array or Tensor), or `None` when neither
      `sample_weight` nor a dict `class_weight` was provided.

  Raises:
      ValueError: In case of invalid user-provided arguments.
  """
  # Iterator may return sample_weight as 1-tuple
  if isinstance(sample_weight, tuple):
    sample_weight = sample_weight[0]
  if sample_weight_mode is not None and sample_weight_mode != 'samplewise':
    if sample_weight_mode != 'temporal':
      raise ValueError('"sample_weight_mode" '
                       'should be None or "temporal". '
                       'Found: ' + str(sample_weight_mode))
    if len(y.shape) < 3:
      raise ValueError('Found a sample_weight array for '
                       'an input with shape ' + str(y.shape) + '. '
                       'Timestep-wise sample weighting (use of '
                       'sample_weight_mode="temporal") is restricted to '
                       'outputs that are at least 3D, i.e. that have '
                       'a time dimension.')
    if sample_weight is not None and len(sample_weight.shape) != 2:
      raise ValueError('Found a sample_weight array with shape ' +
                       str(sample_weight.shape) + '. '
                       'In order to use timestep-wise sample weighting, '
                       'you should pass a 2D sample_weight array.')
  else:
    if sample_weight is not None and len(sample_weight.shape) != 1:
      raise ValueError(
          'Found a sample_weight array with shape {}. In order to '
          'use timestep-wise sample weights, you should specify '
          'sample_weight_mode="temporal" in compile(); found "{}" '
          'instead. If you just mean to use sample-wise weights, '
          'make sure your sample_weight array is 1D.'.format(
              sample_weight.shape, sample_weight_mode))

  if sample_weight is not None:
    if len(sample_weight.shape) > len(y.shape):
      raise ValueError('Found a sample_weight with shape ' +
                       str(sample_weight.shape) + '. '
                       'Expected sample_weight with rank '
                       'less than or equal to ' + str(len(y.shape)))

    # The weights must match the leading dimensions of the targets.
    if (not tensor_util.is_tf_type(sample_weight) and
        y.shape[:sample_weight.ndim] != sample_weight.shape):
      raise ValueError('Found a sample_weight array with shape ' +
                       str(sample_weight.shape) + ' for an input with shape ' +
                       str(y.shape) + '. '
                       'sample_weight cannot be broadcast.')

  # Class weights applied per-sample.
  class_sample_weight = None
  if isinstance(class_weight, dict):
    if len(y.shape) > 2:
      raise ValueError('`class_weight` not supported for '
                       '3+ dimensional targets.')

    if tensor_util.is_tf_type(y):
      # Few classes are expected, so densifying is reasonable. Class ids that
      # are missing from `class_weight` map to NaN.
      keys = np.array(sorted(class_weight.keys()))
      values = np.array([class_weight[i] for i in keys])
      weight_vector = np.zeros(np.max(keys) + 1)
      weight_vector[:] = np.nan
      weight_vector[keys] = values

      # 2D targets with more than one column are treated as one-hot (argmax);
      # everything else is flattened into integer class ids.
      y_classes = smart_cond.smart_cond(
          len(y.shape.as_list()) == 2 and K.shape(y)[1] > 1,
          lambda: K.argmax(y, axis=1),
          lambda: math_ops.cast(K.reshape(y, (-1,)), dtypes.int64))
      class_sample_weight = array_ops.gather(weight_vector, y_classes)
      # Keep the checked tensor in the data path; a discarded op result has
      # no effect in graph mode, so NaNs would otherwise slip through.
      class_sample_weight = gen_array_ops.check_numerics(
          class_sample_weight,
          'Invalid classes or class weights detected. NaN values indicate that '
          'an appropriate class weight could not be determined.')
      class_sample_weight = math_ops.cast(class_sample_weight, K.floatx())
      if sample_weight is not None:
        sample_weight = math_ops.cast(
            ops.convert_to_tensor_v2_with_dispatch(sample_weight), K.floatx())
    else:
      y_classes = y
      if len(y.shape) == 2:
        if y.shape[1] > 1:
          y_classes = np.argmax(y, axis=1)
        elif y.shape[1] == 1:
          y_classes = np.reshape(y, y.shape[0])

      class_sample_weight = np.asarray(
          [class_weight[cls] for cls in y_classes if cls in class_weight])

      if len(class_sample_weight) != len(y_classes):
        # subtract the sets to pick all missing classes
        existing_classes = set(y_classes)
        existing_class_weight = set(class_weight.keys())
        raise ValueError(
            '`class_weight` must contain all classes in the data.'
            ' The classes %s exist in the data but not in '
            '`class_weight`.' % (existing_classes - existing_class_weight))

  if class_sample_weight is not None and sample_weight is not None:
    # Multiply weights if both are provided.
    return class_sample_weight * sample_weight
  if sample_weight is not None:
    return sample_weight
  if class_sample_weight is not None:
    return class_sample_weight
  return None
1073
1074
def has_symbolic_tensors(ls):
  """Returns whether `ls` holds tensors outside of eager execution.

  Always False while executing eagerly; otherwise the result of
  `has_tensors(ls)`.
  """
  eager = context.executing_eagerly()
  return has_tensors(ls) if not eager else False
1079
1080
def has_tensors(ls):
  """Returns true if `ls` contains tensors.

  `ls` may be a single value, a list/tuple, or a dict. For a dict, only its
  values are inspected. Containers are checked one level deep. `RaggedTensor`
  values are not counted.
  """
  # Note: at some point in time ragged tensors didn't count as tensors, so this
  # returned false for ragged tensors. Making this return true fails some tests
  # which would then require a steps_per_epoch argument.
  def _is_non_ragged_tensor(value):
    return (tensor_util.is_tf_type(value) and
            not isinstance(value, ragged_tensor.RaggedTensor))

  if isinstance(ls, (list, tuple)):
    candidates = ls
  elif isinstance(ls, dict):
    candidates = ls.values()
  else:
    return _is_non_ragged_tensor(ls)
  return any(_is_non_ragged_tensor(v) for v in candidates)
1097
1098
def get_metric_name(metric, weighted=False):
  """Returns the name corresponding to the given metric input.

  Args:
    metric: Metric function name or reference.
    weighted: Boolean indicating if the given metric is weighted. Only used
      when TF2 behavior is disabled, where it adds a `'weighted_'` prefix.

  Returns:
      The metric name.
  """
  if tf2.enabled():
    # We keep the string that the user has set in compile as the metric name.
    if isinstance(metric, six.string_types):
      return metric
    resolved = metrics_module.get(metric)
    if hasattr(resolved, 'name'):
      return resolved.name
    return resolved.__name__

  # TF1 path: shorthands collapse to 'acc' / 'ce'; other metrics use the name
  # of the resolved metric object or function.
  prefix = 'weighted_' if weighted else ''
  if metric in ('accuracy', 'acc'):
    return prefix + 'acc'
  if metric in ('crossentropy', 'ce'):
    return prefix + 'ce'
  resolved = metrics_module.get(metric)
  if hasattr(resolved, 'name'):
    return prefix + resolved.name
  return prefix + resolved.__name__
1132
1133
def get_metric_function(metric, output_shape=None, loss_fn=None):
  """Returns the metric function corresponding to the given metric input.

  The shorthands `'accuracy'`/`'acc'` and `'crossentropy'`/`'ce'` each map to
  one of three variants:

  * binary, when the output's last dimension is 1 or the loss is binary
    crossentropy;
  * sparse-categorical, when the loss is sparse categorical crossentropy;
  * categorical, otherwise.

  Any other input is resolved with `metrics_module.get`.

  Args:
      metric: Metric function name or reference.
      output_shape: The shape of the output that this metric will be calculated
        for. Only indexed for the shorthand names.
      loss_fn: The loss function used.

  Returns:
      The metric function.
  """
  if metric not in ['accuracy', 'acc', 'crossentropy', 'ce']:
    return metrics_module.get(metric)

  wraps_fn = isinstance(loss_fn, losses.LossFunctionWrapper)
  uses_sparse_cce = (
      isinstance(loss_fn, losses.SparseCategoricalCrossentropy) or
      (wraps_fn and loss_fn.fn == losses.sparse_categorical_crossentropy))
  uses_binary_ce = (
      isinstance(loss_fn, losses.BinaryCrossentropy) or
      (wraps_fn and loss_fn.fn == losses.binary_crossentropy))

  wants_accuracy = metric in ['accuracy', 'acc']
  if output_shape[-1] == 1 or uses_binary_ce:
    if wants_accuracy:
      return metrics_module.binary_accuracy
    return metrics_module.binary_crossentropy
  # If the output_shape[-1] is not 1, then we know output is `categorical`.
  # We assume it is sparse categorical only if loss is explicitly given
  # as sparse categorical crossentropy loss.
  if uses_sparse_cce:
    if wants_accuracy:
      return metrics_module.sparse_categorical_accuracy
    return metrics_module.sparse_categorical_crossentropy
  if wants_accuracy:
    return metrics_module.categorical_accuracy
  return metrics_module.categorical_crossentropy
1174
1175
def call_metric_function(metric_fn,
                         y_true,
                         y_pred=None,
                         weights=None,
                         mask=None):
  """Invokes metric function and returns the metric result tensor.

  If `mask` is given, it is cast to `y_pred.dtype`. It then becomes the sample
  weights when `weights` is None. Otherwise it is multiplied into `weights`,
  after `losses_utils.squeeze_or_expand_dimensions` aligns their ranks.
  """
  if mask is not None:
    mask = math_ops.cast(mask, y_pred.dtype)
    if weights is not None:
      # Update dimensions of weights to match with mask.
      weights = math_ops.cast(weights, dtype=y_pred.dtype)
      mask, _, weights = losses_utils.squeeze_or_expand_dimensions(
          mask, sample_weight=weights)
      weights *= mask
    else:
      # Use mask as sample weight.
      weights = mask

  if y_pred is None:
    # `Mean` metric only takes a single value.
    return metric_fn(y_true, sample_weight=weights)
  return metric_fn(y_true, y_pred, sample_weight=weights)
1198
1199
def get_loss_function(loss):
  """Returns the loss corresponding to the loss input in `compile` API.

  Args:
    loss: `None`, a `losses.Loss` instance, a loss config mapping, a loss
      identifier accepted by `losses.get`, or a callable.

  Returns:
    The input itself when it is `None`, a `Loss` instance, or a callable
    without `__name__`, possibly after a mapping is deserialized with
    `losses.get`. Anything else is resolved with `losses.get` and wrapped in a
    `LossFunctionWrapper` that uses `SUM_OVER_BATCH_SIZE` reduction.

  Raises:
    ValueError: if `loss` is an uninstantiated `losses.Loss` subclass.
  """
  if loss is None or isinstance(loss, losses.Loss):
    return loss

  if tf_inspect.isclass(loss) and issubclass(loss, losses.Loss):
    # It is not safe to assume that the loss takes no constructor arguments.
    raise ValueError(
        'Received uninstantiated Loss class: {}\nPlease call loss classes '
        'before passing them to Model.compile.'.format(loss))

  # Deserialize loss configuration, if needed.
  if isinstance(loss, collections.abc.Mapping):
    loss = losses.get(loss)

  # Custom callable class.
  if callable(loss) and not hasattr(loss, '__name__'):
    return loss

  # Wrap loss function with signature `(y_true, y_pred, **kwargs)`
  # in `LossFunctionWrapper` class.
  loss_fn = losses.get(loss)

  # For losses which are given as strings/functions in the compile API,
  # we always set the loss reduction type to be `SUM_OVER_BATCH_SIZE`
  # (both in distribution strategy context and otherwise).
  return losses.LossFunctionWrapper(
      loss_fn,
      name=loss_fn.__name__,
      reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE)
1230
1231
def validate_dataset_input(x, y, sample_weight, validation_split=None):
  """Validates user input arguments when a dataset iterator is passed.

  Args:
    x: Input data. A `tf.data` dataset or iterator.
    y: Target data. It could be either Numpy array(s) or TensorFlow tensor(s).
      Expected to be `None` when `x` is a dataset iterator.
    sample_weight: An optional sample-weight array passed by the user to weight
      the importance of each sample in `x`. Expected to be `None` when `x` is a
      dataset iterator.
    validation_split: Float between 0 and 1. Fraction of the training data to be
      used as validation data. Expected to be `None` (or 0.0) when `x` is a
      dataset iterator.

  Raises:
    ValueError: if argument `y` or `sample_weight` or `validation_split` are
        provided by user.
  """
  if y is not None:
    raise ValueError('You passed a dataset or dataset iterator (%s) as '
                     'input `x` to your model. In that case, you should '
                     'not specify a target (`y`) argument, since the dataset '
                     'or dataset iterator generates both input data and '
                     'target data. '
                     'Received: %s' % (x, y))
  if sample_weight is not None:
    raise ValueError('`sample_weight` argument is not supported when input '
                     '`x` is a dataset or a dataset iterator. Instead, you '
                     'can provide sample_weight as the third element of your '
                     'dataset, i.e. (inputs, targets, sample_weight). '
                     'Received: x=%s, sample_weight=%s' % (x, sample_weight))
  if validation_split is not None and validation_split != 0.0:
    raise ValueError(
        '`validation_split` argument is not supported when '
        'input `x` is a dataset or a dataset iterator. '
        'Received: x=%s, validation_split=%f' % (x, validation_split))
1268
1269
def validate_input_types(inp, orig_inp, allow_dict=True, field_name='inputs'):
  """Helper function to validate either inputs or targets.

  Accepts a numpy array or TF tensor, or a list/tuple made only of those. A
  dict is also accepted, with no check on its values, when `allow_dict` is
  True.

  Args:
    inp: The (possibly preprocessed) data to check.
    orig_inp: The user's original value, shown in error messages.
    allow_dict: Whether a dict is an acceptable value.
    field_name: Name of the argument, used in error messages.

  Raises:
    ValueError: if `inp` is not of an accepted type.
  """
  def _is_array_like(value):
    return isinstance(value, np.ndarray) or tensor_util.is_tf_type(value)

  if isinstance(inp, (list, tuple)):
    if all(_is_array_like(v) for v in inp):
      return
    raise ValueError(
        'Please provide as model inputs either a single array or a list of '
        'arrays. You passed: {}={}'.format(field_name, str(orig_inp)))

  if isinstance(inp, dict):
    if allow_dict:
      return
    raise ValueError(
        'You cannot pass a dictionary as model {}.'.format(field_name))

  if not _is_array_like(inp):
    raise ValueError(
        'Please provide as model inputs either a single array or a list of '
        'arrays. You passed: {}={}'.format(field_name, orig_inp))
1286
1287
def check_generator_arguments(y=None, sample_weight=None,
                              validation_split=None):
  """Validates arguments passed when using a generator.

  With a generator or `Sequence`, targets and sample weights must come from
  the generator itself, and `validation_split` cannot be used.

  Args:
    y: Must be `None`.
    sample_weight: Must be `None`.
    validation_split: Must be falsy (e.g. `None` or 0).

  Raises:
    ValueError: if any of the above is provided.
  """
  if y is not None:
    raise ValueError('`y` argument is not supported when data is '
                     'a generator or Sequence instance. Instead pass targets'
                     ' as the second element of the generator.')
  if sample_weight is not None:
    raise ValueError('`sample_weight` argument is not supported when data is '
                     'a generator or Sequence instance. Instead pass sample'
                     ' weights as the third element of the generator.')
  if validation_split:
    raise ValueError('If your data is in the form of a Python generator, '
                     'you cannot use `validation_split`.')
1302
1303
def check_steps_argument(input_data, steps, steps_name):
  """Validates `steps` argument based on input data's type.

  `steps` must be provided when:

  1. the input data passed is an iterator;
  2. the model was built on top of symbolic tensors, input data is not
     required and is `None` (or an empty list);
  3. the input data passed is a symbolic tensor.

  Args:
      input_data: Input data. Can be Numpy array(s) or TensorFlow tensor(s) or
        tf.data.Dataset iterator or `None`.
      steps: Integer or `None`. Total number of steps (batches of samples) to
        execute.
      steps_name: The public API's parameter name for `steps`.

  Returns:
    boolean, True if `steps` argument is required, else False. Datasets also
    return True. For array inputs that were given `steps`, a warning is logged.

  Raises:
      ValueError: if `steps` argument is required for given input data type
        but not provided.
  """
  is_x_iterator = isinstance(
      input_data, (iterator_ops.Iterator, iterator_ops.IteratorBase))
  steps_required = (
      input_data is None or is_x_iterator or
      has_symbolic_tensors(input_data) or
      (isinstance(input_data, list) and not input_data))

  if steps_required:
    if steps is not None:
      return True
    input_type_str = 'a Dataset iterator' if is_x_iterator else 'data tensors'
    raise ValueError('When using {input_type} as input to a model, you should'
                     ' specify the `{steps_name}` argument.'.format(
                         input_type=input_type_str, steps_name=steps_name))

  if isinstance(input_data, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)):
    return True

  if steps is None:
    return False

  array_types = (np.ndarray, list, tuple)
  passed_arrays = isinstance(input_data, array_types) or (
      isinstance(input_data, dict) and
      any(isinstance(v, array_types) for v in input_data.values()))
  if passed_arrays:
    logging.warning('When passing input data as arrays, do not specify '
                    '`steps_per_epoch`/`steps` argument. '
                    'Please use `batch_size` instead.')
  return False
1350
1351
def cast_single_tensor(x, dtype=None):
  """Casts `x` to `dtype` (default `K.floatx()`) if it is floating point.

  A numpy array is first converted to a tensor. Non-floating inputs are
  returned without casting, and numpy arrays still come back as tensors.
  """
  target_dtype = dtype or K.floatx()
  tensor = (ops.convert_to_tensor_v2_with_dispatch(x)
            if isinstance(x, np.ndarray) else x)
  if not tensor.dtype.is_floating:
    return tensor
  return math_ops.cast(tensor, dtype=target_dtype)
1359
1360
def cast_if_floating_dtype_and_mismatch(targets, outputs):
  """Returns target data tensors using correct datatype.

  Checks that each target and output pair are the same datatype. If not, casts
  the target to the output's datatype. The cast happens only for floating
  targets, via `cast_single_tensor`. Numpy targets are converted to tensors.

  Args:
    targets: tensor or list of targets.
    outputs: tensor or list of outputs.

  Returns:
    Targets in appropriate datatype.
  """
  if tensor_util.is_tf_type(targets):
    # There is one target, so output[0] should be the only output.
    return cast_single_tensor(targets, dtype=outputs[0].dtype)

  def _align(target, out):
    if isinstance(target, np.ndarray):
      target = ops.convert_to_tensor_v2_with_dispatch(target)
    if target.dtype == out.dtype:
      return target
    return cast_single_tensor(target, dtype=out.dtype)

  return [_align(target, out) for target, out in zip(targets, outputs)]
1386
1387
def cast_if_floating_dtype(x, dtype=None):
  """Casts the floating-point tensors in `x` to `dtype`.

  Each leaf of `x` goes through `cast_single_tensor`. Only floating leaves are
  cast, to `dtype`, or to `K.floatx()` when `dtype` is None.

  Args:
    x: tensor or list/tuple of tensors.
    dtype: The dtype to which Tensors should be cast.

  Returns:
    Converted input.
  """
  cast_leaf = functools.partial(cast_single_tensor, dtype=dtype)
  return nest.map_structure(cast_leaf, x)
1401
1402
def cast_to_model_input_dtypes(x, model):
  """Casts the given data tensors to the dtypes of the model inputs.

  Args:
    x: tensor or list/tuple of tensors, with the same structure as
      `model.inputs`.
    model: The model.

  Returns:
    Converted input. Each tensor is cast to the dtype of the corresponding
    entry in `model.inputs`.
  """
  target_dtypes = nest.map_structure(lambda inp: inp.dtype, model.inputs)
  return nest.map_structure(math_ops.cast, x, target_dtypes)
1416
1417
def prepare_sample_weight_modes(training_endpoints, sample_weight_mode):
  """Prepares sample weight modes for the model.

  Sets `sample_weight_mode` on every endpoint whose
  `should_skip_target_weights()` is False. `sample_weight_mode` can take
  three forms:

  * a mapping keyed by output name;
  * a list/tuple with one entry per endpoint;
  * a single value shared by all endpoints.

  Args:
    training_endpoints: List of model _TrainingEndpoints.
    sample_weight_mode: sample weight mode user input passed from compile API.

  Raises:
    ValueError: In case of invalid `sample_weight_mode` input.
  """

  if isinstance(sample_weight_mode, collections.abc.Mapping):
    generic_utils.check_for_unexpected_keys(
        'sample_weight_mode', sample_weight_mode,
        [e.output_name for e in training_endpoints])

    for end_point in training_endpoints:
      if not end_point.should_skip_target_weights():
        if end_point.output_name not in sample_weight_mode:
          raise ValueError('Output ' + end_point.output_name +
                           ' missing from `sample_weight_mode` dictionary')
        else:
          end_point.sample_weight_mode = sample_weight_mode.get(
              end_point.output_name)
  elif isinstance(sample_weight_mode, (list, tuple)):
    if len(sample_weight_mode) != len(training_endpoints):
      raise ValueError('When passing a list as sample_weight_mode, '
                       'it should have one entry per model output. '
                       'The model has ' + str(len(training_endpoints)) +
                       ' outputs, but you passed ' +
                       str(len(sample_weight_mode)) + ' sample_weight_modes.')
    for mode, endpoint in zip(sample_weight_mode, training_endpoints):
      if not endpoint.should_skip_target_weights():
        endpoint.sample_weight_mode = mode
  else:
    for endpoint in training_endpoints:
      if not endpoint.should_skip_target_weights():
        endpoint.sample_weight_mode = sample_weight_mode
1456
1457
def prepare_loss_functions(loss, output_names):
  """Converts loss to a list of loss functions.

  Args:
      loss: String (name of objective function), objective function or
        `tf.losses.Loss` instance. See `tf.losses`. If the model has multiple
        outputs, you can use a different loss on each output by passing a
        dictionary or a list of losses. The loss value that will be minimized by
        the model will then be the sum of all individual losses.
      output_names: List of model output names.

  Returns:
      One resolved loss (see `get_loss_function`) per model output. When
      `loss` is a non-string sequence, the result comes from
      `nest.map_structure` over it.

  Raises:
      ValueError: If loss is a dict with keys not in model output names,
          or if loss is a list with len not equal to model outputs.
  """
  if isinstance(loss, collections.abc.Mapping):
    generic_utils.check_for_unexpected_keys('loss', loss, output_names)

    def _loss_for(name):
      # Outputs absent from the dict get a `None` loss, plus a warning.
      if name not in loss:
        logging.warning(
            'Output {0} missing from loss dictionary. We assume '
            'this was done on purpose. The fit and evaluate APIs will not be '
            'expecting any data to be passed to {0}.'.format(name))
      return get_loss_function(loss.get(name, None))

    return [_loss_for(name) for name in output_names]

  # Strings are Sequences too, so they must be handled before the list case.
  if isinstance(loss, six.string_types):
    return [get_loss_function(loss) for _ in output_names]

  if isinstance(loss, collections.abc.Sequence):
    if len(loss) != len(output_names):
      raise ValueError('When passing a list as loss, it should have one entry '
                       'per model outputs. The model has {} outputs, but you '
                       'passed loss={}'.format(len(output_names), loss))
    return nest.map_structure(get_loss_function, loss)

  # A single loss object/function is resolved once per output.
  return [get_loss_function(loss) for _ in range(len(output_names))]
1498
1499
def prepare_loss_weights(training_endpoints, loss_weights=None):
  """Converts loss weights to a list of loss weights.

  The resulting loss weights are stored on each training endpoint as its
  `loss_weight` attribute.

  Args:
      training_endpoints: List of model training endpoints. Each endpoint's
        `output_name` is read (for dict lookups) and its `loss_weight` is set.
      loss_weights: Optional list, tuple or dictionary specifying scalar
        coefficients (Python floats) to weight the loss contributions of
        different model outputs. The loss value that will be minimized by the
        model will then be the *weighted sum* of all individual losses,
        weighted by the `loss_weights` coefficients. If a list or tuple, it is
        expected to have a 1:1 mapping to the model's outputs. If a dict (any
        `Mapping`), it is expected to map output names (strings) to scalar
        coefficients; outputs missing from it get a weight of 1. If None,
        every output gets a weight of 1.

  Raises:
      ValueError: If `loss_weights` is a dict with a key not in the model
          output names (reported by `generic_utils.check_for_unexpected_keys`),
          or a list/tuple whose length differs from the number of outputs.
      TypeError: If `loss_weights` is not None, a dict, a list or a tuple.
  """
  if loss_weights is None:
    for e in training_endpoints:
      e.loss_weight = 1.
  elif isinstance(loss_weights, collections.abc.Mapping):
    generic_utils.check_for_unexpected_keys(
        'loss_weights', loss_weights,
        [e.output_name for e in training_endpoints])
    for e in training_endpoints:
      e.loss_weight = loss_weights.get(e.output_name, 1.)
  elif isinstance(loss_weights, (list, tuple)):
    # Tuples map positionally onto the outputs exactly like lists do.
    if len(loss_weights) != len(training_endpoints):
      raise ValueError('When passing a list as loss_weights, '
                       'it should have one entry per model output. '
                       'The model has ' + str(len(training_endpoints)) +
                       ' outputs, but you passed loss_weights=' +
                       str(loss_weights))
    for w, e in zip(loss_weights, training_endpoints):
      e.loss_weight = w
  else:
    raise TypeError('Could not interpret loss_weights argument: ' +
                    str(loss_weights) + ' - expected a list or dict.')
1540
1541
# TODO(rohanj): This is a hack to avoid depending on feature_column, which
# would create a cyclical dependency. Figure out a cleaner solution.
def is_feature_layer(layer):
  """Returns whether `layer` is a FeatureLayer or not.

  Reads the `_is_feature_layer` marker attribute instead of doing an
  `isinstance` check, so this module need not import feature_column. Objects
  without the attribute are treated as non-feature layers.
  """
  try:
    return layer._is_feature_layer  # pylint: disable=protected-access
  except AttributeError:
    return False
1547
1548
def is_eager_dataset_or_iterator(data):
  """Returns True iff eager execution is on and `data` is a dataset/iterator.

  Accepted types are `DatasetV1`, `DatasetV2` and `IteratorBase`.
  """
  if not context.executing_eagerly():
    return False
  eager_types = (dataset_ops.DatasetV1, dataset_ops.DatasetV2,
                 iterator_ops.IteratorBase)
  return isinstance(data, eager_types)
1553
1554
1555# pylint: disable=protected-access
def get_dataset_graph_def(dataset):
  """Returns the `GraphDef` proto parsed from `dataset`'s serialized graph.

  In eager mode the serialized-graph tensor is read with `.numpy()`; in graph
  mode it is evaluated through `K.get_value`.
  """
  eager = context.executing_eagerly()
  serialized = dataset._as_serialized_graph()
  raw_bytes = serialized.numpy() if eager else K.get_value(serialized)
  return graph_pb2.GraphDef.FromString(raw_bytes)
1562
1563
def verify_dataset_shuffled(x):
  """Verifies that the dataset is shuffled.

  Looks for a shuffling op both among the top-level nodes of the dataset's
  serialized graph and inside its function library (where transformations
  nested in e.g. `ds.interleave` or `ds.flat_map` live). Logs a warning when
  none is found.

  Args:
    x: Dataset passed as an input to the model.

  Returns:
    boolean, whether the input dataset is shuffled or not.
  """
  assert isinstance(x, dataset_ops.DatasetV2)
  graph_def = get_dataset_graph_def(x)
  # Prefix match so versioned op names (e.g. a `...V2` suffix) also count.
  # The fused shuffle-and-repeat op (`ShuffleAndRepeatDataset`, emitted by
  # `tf.data.experimental.shuffle_and_repeat` in some TF versions) shuffles
  # too, but does not start with 'ShuffleDataset'.
  shuffle_op_prefixes = ('ShuffleDataset', 'ShuffleAndRepeatDataset')

  if any(node.op.startswith(shuffle_op_prefixes) for node in graph_def.node):
    return True
  # Also check graph_def.library.function for ds.interleave or ds.flat_map
  if any(node.op.startswith(shuffle_op_prefixes)
         for function in graph_def.library.function
         for node in function.node_def):
    return True
  logging.warning('Expected a shuffled dataset but input dataset `x` is '
                  'not shuffled. Please invoke `shuffle()` on input dataset.')
  return False
1586
1587
def is_dataset_or_iterator(data):
  """Returns whether `data` is a tf.data dataset (V1/V2) or an iterator."""
  accepted_types = (dataset_ops.DatasetV1,
                    dataset_ops.DatasetV2,
                    iterator_ops.Iterator,
                    iterator_ops.IteratorBase)
  return isinstance(data, accepted_types)
1591
1592
def get_iterator(dataset):
  """Create and initialize an iterator from a dataset.

  Uses a one-shot iterator in eager mode and an initializable iterator in
  graph mode, then runs `initialize_iterator` on it before returning it.
  """
  if context.executing_eagerly():
    make_iterator = dataset_ops.make_one_shot_iterator
  else:
    make_iterator = dataset_ops.make_initializable_iterator
  new_iterator = make_iterator(dataset)
  initialize_iterator(new_iterator)
  return new_iterator
1601
1602
def initialize_iterator(iterator):
  """Runs `iterator.initializer` in the Keras session; no-op in eager mode."""
  if context.executing_eagerly():
    return
  initializer = iterator.initializer
  session = K.get_session((initializer,))
  session.run(initializer)
1607
1608
def extract_tensors_from_dataset(dataset):
  """Extract a tuple of tensors `inputs, targets, sample_weight` from a dataset.

  Builds (and initializes) an iterator via `get_iterator` and unpacks its next
  element via `unpack_iterator_input`.

  Args:
    dataset: Dataset instance.

  Returns:
    Tuple of tensors `x, y, weights`. `y` and `weights` entry may be None.
  """
  return unpack_iterator_input(get_iterator(dataset))
1621
1622
def unpack_iterator_input(iterator):
  """Convert a dataset iterator to a tuple of tensors `x, y, sample_weights`.

  Args:
    iterator: Instance of a dataset iterator; only its `get_next()` is used.

  Returns:
    Tuple of tensors `x, y, weights`. A list/tuple element of length 2 gives
    `weights=None`; a non-list/tuple element is returned as `x` with `y` and
    `weights` set to None.

  Raises:
    RuntimeError: If `get_next()` raises `errors.OutOfRangeError`.
    ValueError: If the next element is a list/tuple whose length is not 2
      or 3.
  """
  try:
    next_element = iterator.get_next()
  except errors.OutOfRangeError:
    raise RuntimeError('Your dataset iterator ran out of data; '
                       'Make sure that your dataset can generate '
                       'required number of samples.')

  if not isinstance(next_element, (list, tuple)):
    return next_element, None, None

  if len(next_element) not in (2, 3):
    # Wrap in a 1-tuple: with a bare tuple operand, `%` would consume the
    # tuple's items as separate format arguments and raise TypeError.
    raise ValueError(
        'Please provide model inputs as a list or tuple of 2 or 3 '
        'elements: (input, target) or (input, target, sample_weights). '
        'Received %s' % (next_element,))
  if len(next_element) == 2:
    x, y = next_element
    return x, y, None
  x, y, weights = next_element
  return x, y, weights
1655
1656
def infer_steps_for_dataset(model,
                            dataset,
                            steps,
                            epochs=1,
                            steps_name='steps'):
  """Infers how many steps per epoch are needed to loop through a dataset.

  Args:
      model: Keras model instance.
      dataset: Input data of type tf.data.Dataset.
      steps: Number of steps to draw from the dataset (may be None if unknown).
      epochs: Number of times to iterate over the dataset.
      steps_name: The string name of the steps argument, either `steps`,
        `validation_steps`, or `steps_per_epoch`. Only used for error message
        formatting.

  Returns:
    Integer or `None`. The given `steps` if not None; otherwise the dataset's
    cardinality when it is known (>= 0). `None` is returned if 1) the size of
    the dataset is unknown and `steps` was not specified, or 2) this is
    multi-worker training and auto sharding is enabled.

  Raises:
    ValueError: If the dataset is infinite and `steps` is None, or if
      `steps * epochs` exceeds the dataset's known number of batches.
  """
  assert isinstance(dataset, dataset_ops.DatasetV2)

  # If the dataset would be auto-sharded, we should not infer a local
  # steps_per_epoch due to possibly imbalanced sharding between workers.
  if model._in_multi_worker_mode():
    shard_policy = dataset.options().experimental_distribute.auto_shard_policy
    if shard_policy != distribute_options.AutoShardPolicy.OFF:
      return None

  num_batches = K.get_value(cardinality.cardinality(dataset))
  if steps is None and num_batches == cardinality.INFINITE:
    raise ValueError('When passing an infinitely repeating dataset, you '
                     'must specify the `%s` argument.' % (steps_name,))

  # Negative cardinalities are sentinels (infinite / unknown size).
  size_is_known = num_batches >= 0
  if size_is_known and steps is not None:
    total_steps = steps * epochs
    if total_steps > num_batches:
      if epochs > 1:
        raise ValueError('The dataset you passed contains %s batches, but you '
                         'passed `epochs=%s` and `%s=%s`, which is a total of '
                         '%s steps. We cannot draw that many steps from this '
                         'dataset. We suggest to set `%s=%s`.' %
                         (num_batches, epochs, steps_name, steps, total_steps,
                          steps_name, num_batches // epochs))
      raise ValueError('The dataset you passed contains %s batches, but you '
                       'passed `%s=%s`. We cannot draw that many steps from '
                       'this dataset. We suggest to set `%s=%s`.' %
                       (num_batches, steps_name, steps, steps_name,
                        num_batches))

  if steps is not None:
    return steps
  return num_batches if size_is_known else None
1713
1714
class ModelInputs(object):
  """Encapsulates model inputs.

  Allows for transforming model inputs while keeping the same structure. The
  inputs (a dict, a possibly nested list/tuple, or a single value) are
  flattened into a list with one generated or key-derived name per entry.
  """

  def __init__(self, inputs):
    self._inputs = inputs
    self._is_dict = isinstance(inputs, dict)
    self._is_single_input = not isinstance(inputs, (list, tuple, dict))

    if self._is_dict:
      # Dict entries are ordered, and named, by their sorted keys.
      ordered_keys = sorted(inputs.keys())
      self._input_names = ordered_keys
      self._flattened_inputs = [inputs[key] for key in ordered_keys]
    else:
      self._flattened_inputs = nest.flatten(inputs)
      entry_count = len(self._flattened_inputs)
      self._input_names = [
          'input_%d' % position for position in range(1, entry_count + 1)
      ]

  def get_input_names(self):
    """Returns keys to name inputs by.

    Dict inputs are named by their sorted keys; list, tuple or single inputs
    get generated names of the form 'input_%d' (1-based).
    """
    return self._input_names

  def get_symbolic_inputs(self, return_single_as_list=False):
    """Returns inputs to be set as self.inputs for a model.

    Each flattened entry is replaced in place: Python lists, floats and ints
    are first converted to NumPy arrays (1-D arrays gain a trailing axis);
    NumPy arrays and `TensorSpec`s then become `K.placeholder`s with an
    unknown batch dimension (floating dtypes use `K.floatx()`). Other entries
    are kept unchanged.

    Args:
      return_single_as_list: If True and the inputs were a single value,
        return a one-element list instead of the bare value.

    Returns:
      A dict keyed by input name for dict inputs, the lone converted input for
      single inputs (unless `return_single_as_list`), otherwise a list.
    """
    # TODO(karmel): There is a side-effect here where what you get
    # with as_list and as_dict depends on whether you have called this
    # method first, since it modifies in place.
    for index, name in enumerate(self._input_names):
      value = self._flattened_inputs[index]

      if isinstance(value, (list, float, int)):
        value = np.asarray(value)
        if value.ndim == 1:
          value = np.expand_dims(value, 1)

      if isinstance(value, np.ndarray):
        # Only the batch dimension is left unknown. This is suboptimal, but it
        # is the best we can do with the info we have. The user should call
        # `model._set_inputs(placeholders)` to specify custom placeholders if
        # the need arises. Entries with no non-batch dims get shape (None, 1).
        placeholder_shape = (None,) + (tuple(value.shape[1:]) or (1,))
        placeholder_dtype = dtypes.as_dtype(value.dtype)
        if placeholder_dtype.is_floating:
          placeholder_dtype = K.floatx()
        value = K.placeholder(
            shape=placeholder_shape, name=name, dtype=placeholder_dtype)
      elif isinstance(value, tensor_spec.TensorSpec):
        placeholder_shape = (None,) + (
            tuple(value.shape.as_list()[1:]) or (1,))
        value = K.placeholder(
            shape=placeholder_shape, name=name, dtype=value.dtype)

      self._flattened_inputs[index] = value

    flat = self._flattened_inputs
    if self._is_dict:
      return {key: item for key, item in zip(self._input_names, flat)}
    if not self._is_single_input or return_single_as_list:
      return flat
    return flat[0]

  def as_dict(self):
    """Yields `(name, input)` pairs in flattened order.

    Despite its name this is a generator of pairs, not a dict.
    """
    for pair in zip(self._input_names, self._flattened_inputs):
      yield pair

  def as_list(self):
    """Returns the flattened inputs list itself (not a copy)."""
    return self._flattened_inputs
1792
1793
1794# Allow use of methods not exposed to the user.
1795# pylint: disable=protected-access
1796
1797
1798# pylint: enable=protected-access
1799
1800
def generic_output_names(outputs_list):
  """Returns the names 'output_1', ..., 'output_N' for N = len(outputs_list)."""
  output_count = len(outputs_list)
  return ['output_%d' % index for index in range(1, output_count + 1)]
1803
1804
def should_run_validation(validation_freq, epoch):
  """Checks if validation should be run this epoch.

  Args:
    validation_freq: Integer (Python `int` or NumPy integer scalar) or
      container. If an integer, specifies how many training epochs to run
      before a new validation run is performed. If a container (e.g. a list),
      specifies the 1-based epochs on which to run validation.
    epoch: Integer, the number of the training epoch just completed (0-based).

  Returns:
    Bool, True if validation should be run.

  Raises:
    ValueError: if `validation_freq` is an Integer and less than 1, or if
    it is neither an Integer nor a Container.
  """
  # `epoch` is 0-indexed internally but 1-indexed in the public API.
  one_indexed_epoch = epoch + 1

  # NumPy integer scalars are neither `int` instances nor containers, so they
  # must be accepted explicitly; normalize to `int` to return a Python bool.
  if isinstance(validation_freq, (int, np.integer)):
    validation_freq = int(validation_freq)
    if validation_freq < 1:
      raise ValueError('`validation_freq` can not be less than 1.')
    return one_indexed_epoch % validation_freq == 0

  if not isinstance(validation_freq, collections.abc.Container):
    raise ValueError('`validation_freq` must be an Integer or '
                     '`collections.abc.Container` (e.g. list, tuple, etc.)')
  return one_indexed_epoch in validation_freq
1833
1834
def split_training_and_validation_data(x, y, sample_weights, validation_split):
  """Split input data into train/eval section based on validation_split.

  The sample count is taken from the first entry of `x` (its `shape[0]` if it
  has a `shape`, else its `len`). The first `int(count * (1 - split))` samples
  go to training and the rest to validation.

  Returns:
    Tuple `(x, y, sample_weights, val_x, val_y, val_sample_weights)`.
    `val_sample_weights` is None when `sample_weights` is falsy.

  Raises:
    ValueError: If `x` contains symbolic tensors.
  """
  if has_symbolic_tensors(x):
    raise ValueError('If your data is in the form of symbolic tensors, '
                     'you cannot use `validation_split`.')
  first_input = x[0]
  if hasattr(first_input, 'shape'):
    num_samples = first_input.shape[0]
  else:
    num_samples = len(first_input)
  split_at = int(num_samples * (1. - validation_split))

  def _split(arrays):
    # Returns (head, tail) slices of `arrays` around `split_at`.
    return (generic_utils.slice_arrays(arrays, 0, split_at),
            generic_utils.slice_arrays(arrays, split_at))

  train_x, val_x = _split(x)
  train_y, val_y = _split(y)
  if sample_weights:
    train_weights, val_weights = _split(sample_weights)
  else:
    train_weights, val_weights = sample_weights, None
  return train_x, train_y, train_weights, val_x, val_y, val_weights
1856
1857
def unpack_validation_data(validation_data, raise_if_ambiguous=True):
  """Unpack validation data based on its input type.

  Datasets, dataset iterators, `data_utils.Sequence`s and any object without
  `__len__` are returned untouched as `val_x`. Other inputs (e.g. tuples of
  NumPy arrays or tensors) are unpacked into a tuple of 3: x, y and sample
  weights.

  Args:
    validation_data: dataset, dataset iterator, or numpy, tensor tuple.
    raise_if_ambiguous: boolean on whether to fail if validation_data cannot be
      parsed. Otherwise simply return validation_data, None, None and defer the
      decision to the caller.

  Returns:
    tuple of 3, (x, y, sample_weights) for numpy and tensor input. `y` and
    `sample_weights` are None when `validation_data` is returned untouched.

  Raises:
    ValueError: If `validation_data` has a length other than 2 or 3 and
      `raise_if_ambiguous` is True.
  """
  if (isinstance(validation_data, (iterator_ops.Iterator,
                                   iterator_ops.IteratorBase,
                                   dataset_ops.DatasetV2,
                                   data_utils.Sequence))
      or not hasattr(validation_data, '__len__')):
    val_x = validation_data
    val_y = None
    val_sample_weight = None
  elif len(validation_data) == 2:
    # NOTE(review): a bare array with exactly 2 samples also lands here and is
    # split row-wise into (x, y); confirm callers never pass one.
    try:
      val_x, val_y = validation_data  # pylint: disable=unpacking-non-sequence
      val_sample_weight = None
    except ValueError:
      # Reports length 2 but does not unpack into 2 items: treat it as `x`.
      val_x, val_y, val_sample_weight = validation_data, None, None
  elif len(validation_data) == 3:
    try:
      val_x, val_y, val_sample_weight = validation_data  # pylint: disable=unpacking-non-sequence
    except ValueError:
      val_x, val_y, val_sample_weight = validation_data, None, None
  else:
    if raise_if_ambiguous:
      # Wrap in a 1-tuple: with a bare tuple operand, `%` would consume the
      # tuple's items as separate format arguments and raise TypeError.
      raise ValueError(
          'When passing a `validation_data` argument, '
          'it must contain either 2 items (x_val, y_val), '
          'or 3 items (x_val, y_val, val_sample_weights), '
          'or alternatively it could be a dataset or a '
          'dataset iterator. '
          'However we received `validation_data=%s`' % (validation_data,))
    val_x, val_y, val_sample_weight = validation_data, None, None
  return val_x, val_y, val_sample_weight
1904
1905
class TrainingLoop(object):
  """Interface wrapping the training logic behind fit/evaluate/predict.

  Subclasses encapsulate how fit/eval/predict behave for different kinds of
  input data and model conditions. Every method here raises
  `NotImplementedError` and must be overridden.

  Note that TrainingLoop is stateless: it keeps no internal fields, so one
  instance can be reused with different models and inputs.
  """

  def fit(self,
          model,
          x=None,
          y=None,
          batch_size=None,
          epochs=1,
          verbose=1,
          callbacks=None,
          validation_split=0.,
          validation_data=None,
          shuffle=True,
          class_weight=None,
          sample_weight=None,
          initial_epoch=0,
          steps_per_epoch=None,
          validation_steps=None,
          validation_freq=1,
          **kwargs):
    """Train the model with the inputs and targets."""
    raise NotImplementedError

  def evaluate(self,
               model,
               x=None,
               y=None,
               batch_size=None,
               verbose=1,
               sample_weight=None,
               steps=None,
               callbacks=None,
               **kwargs):
    """Returns the loss value & metrics values for the model in test mode."""
    raise NotImplementedError

  def predict(self,
              model,
              x,
              batch_size=None,
              verbose=0,
              steps=None,
              callbacks=None,
              **kwargs):
    """Runs the model's prediction logic on `x`; must be overridden."""
    raise NotImplementedError
1959