1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Training-related utilities."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import abc
21import collections
22from collections import OrderedDict
23
24import numpy as np
25import six
26
27from tensorflow.python import tf2
28from tensorflow.python.data.experimental.ops import cardinality
29from tensorflow.python.data.ops import dataset_ops
30from tensorflow.python.data.ops import iterator_ops
31from tensorflow.python.data.ops import readers
32from tensorflow.python.eager import context
33from tensorflow.python.framework import constant_op
34from tensorflow.python.framework import dtypes
35from tensorflow.python.framework import errors
36from tensorflow.python.framework import ops
37from tensorflow.python.framework import tensor_shape
38from tensorflow.python.framework import tensor_util
39from tensorflow.python.keras import backend as K
40from tensorflow.python.keras import callbacks as cbks
41from tensorflow.python.keras import losses
42from tensorflow.python.keras import metrics as metrics_module
43from tensorflow.python.keras.utils import generic_utils
44from tensorflow.python.keras.utils.losses_utils import squeeze_or_expand_dimensions
45from tensorflow.python.ops import array_ops
46from tensorflow.python.ops import math_ops
47from tensorflow.python.platform import tf_logging as logging
48from tensorflow.python.util import nest
49
50
@six.add_metaclass(abc.ABCMeta)
class Aggregator(object):
  """Interface for folding the per-batch outputs of a loop into one result.

  Concrete subclasses provide `create`, `aggregate` and `finalize`; the value
  being accumulated is kept on `results`.

  Attributes:
    use_steps: Whether the loop is using `step` or `batch_size`.
    num_samples_or_steps: Either `batch_size*num_batches` or `steps`.
    results: The accumulated value. Initialized to an empty list.
  """

  def __init__(self, use_steps, num_samples_or_steps):
    """Stores the loop configuration and starts with empty `results`.

    Arguments:
      use_steps: Whether the loop is using `step` or `batch_size`.
      num_samples_or_steps: Either `batch_size*num_batches` or `steps`.
    """
    self.results = []
    self.num_samples_or_steps = num_samples_or_steps
    self.use_steps = use_steps

  @abc.abstractmethod
  def create(self, batch_outs):
    """Initializes `results` based on the outputs of the first batch.

    Arguments:
      batch_outs: A list of batch-level outputs.
    """
    raise NotImplementedError('Must be implemented in subclasses.')

  @abc.abstractmethod
  def aggregate(self, batch_outs, batch_start=None, batch_end=None):
    """Folds the outputs of one batch into `results`.

    Arguments:
      batch_outs: A list of batch-level outputs.
      batch_start: The start index of this batch. Always `None` if `use_steps`
        is `True`.
      batch_end: The end index of this batch. Always `None` if `use_steps` is
        `True`.
    """
    raise NotImplementedError('Must be implemented in subclasses.')

  @abc.abstractmethod
  def finalize(self):
    """Turns the accumulated `results` into the value to be returned."""
    raise NotImplementedError('Must be implemented in subclasses.')
92
93
class MetricsAggregator(Aggregator):
  """Aggregator for a loss value (index 0) followed by metric values."""

  def create(self, batch_outs):
    """Resets `results` to one `0.` per entry of `batch_outs`."""
    self.results = [0. for _ in batch_outs]

  def aggregate(self, batch_outs, batch_start=None, batch_end=None):
    """Adds the batch loss to the running total and refreshes the metrics.

    Arguments:
      batch_outs: Batch-level outputs; entry 0 is treated as the loss and the
        remaining entries as metric values.
      batch_start: Start index of the batch; only read when `use_steps` is
        falsy.
      batch_end: End index of the batch; only read when `use_steps` is falsy.
    """
    batch_loss = batch_outs[0]
    if not self.use_steps:
      # Sample mode: weight the batch loss by the number of samples it covers
      # so that `finalize` can divide by the total sample count.
      batch_loss = batch_loss * (batch_end - batch_start)
    self.results[0] += batch_loss
    # Metrics (always stateful, just grab current values.)
    self.results[1:] = batch_outs[1:]

  def finalize(self):
    """Divides the accumulated loss by `num_samples_or_steps`.

    Raises:
      ValueError: If `results` is empty.
    """
    if not self.results:
      raise ValueError('Empty training data.')
    self.results[0] /= self.num_samples_or_steps
113
114
class OutputsAggregator(Aggregator):
  """Aggregator that joins batch outputs along the first axis."""

  def create(self, batch_outs):
    """Adds one accumulator per entry of `batch_outs` to `results`.

    In step mode each accumulator is a list that collects batches; otherwise
    it is a zero-filled NumPy array with `num_samples_or_steps` rows and the
    trailing shape and dtype of the corresponding batch output.

    Arguments:
      batch_outs: A list of batch-level outputs.
    """
    if self.use_steps:
      # Batch sizes are unknown in step mode, so nothing can be pre-allocated;
      # batches are collected here and concatenated in `finalize`.
      self.results.extend([] for _ in batch_outs)
      return
    for batch_out in batch_outs:
      full_shape = (self.num_samples_or_steps,) + batch_out.shape[1:]
      self.results.append(np.zeros(full_shape, dtype=batch_out.dtype))

  def aggregate(self, batch_outs, batch_start=None, batch_end=None):
    """Stores the outputs of one batch in the matching accumulators.

    Arguments:
      batch_outs: A list of batch-level outputs.
      batch_start: First row written in the pre-allocated arrays; unused in
        step mode.
      batch_end: End (exclusive) of the rows written; unused in step mode.
    """
    for position, batch_out in enumerate(batch_outs):
      if self.use_steps:
        self.results[position].append(batch_out)
      else:
        self.results[position][batch_start:batch_end] = batch_out

  def finalize(self):
    """Concatenates the collected batches in step mode; otherwise a no-op."""
    if not self.use_steps:
      return
    self.results = [np.concatenate(chunks, axis=0) for chunks in self.results]
141
142
def get_progbar(model, count_mode):
  """Creates a `ProgbarLogger` callback for `model`.

  Arguments:
    model: Object that may expose `metrics_names`. When it does, every name
      after the first one (the loss) is passed as a stateful metric.
    count_mode: Forwarded as the first argument of `cbks.ProgbarLogger`.

  Returns:
    The `cbks.ProgbarLogger` instance.
  """
  if hasattr(model, 'metrics_names'):
    # The first entry is `loss`, which is not a stateful metric.
    stateful_names = model.metrics_names[1:]
  else:
    stateful_names = None
  return cbks.ProgbarLogger(count_mode, stateful_metrics=stateful_names)
149
150
def slice_arrays(arrays, indices, contiguous=True):
  """Extracts a batch from the given arrays (works around eager tensors).

  Eager tensors slice like symbolic TF tensors rather than like Numpy arrays,
  so `generic_utils.slice_arrays` cannot be used for them. When any of the
  arrays is a tensor the batch is built with plain range slicing (contiguous
  indices) or by concatenating one-row slices (non-contiguous indices), which
  has a performance cost.

  Arguments:
    arrays: Single array or list of arrays.
    indices: List of indices in the array that should be included in the output
      batch.
    contiguous: Boolean flag indicating whether the indices are contiguous.

  Returns:
    Slice of data (either single array or list of arrays).
  """
  is_single = not isinstance(arrays, list)
  array_list = [arrays] if is_single else arrays

  if not any(tensor_util.is_tensor(a) for a in array_list):
    batches = generic_utils.slice_arrays(array_list, indices)
  elif contiguous:
    # Only the first and last index matter for a contiguous range.
    first, last = indices[0], indices[-1]
    batches = [a[first:last + 1] for a in array_list]
  else:
    batches = [
        array_ops.concat([a[i:i + 1] for i in indices], axis=0)
        for a in array_list
    ]

  return batches[0] if is_single else batches
185
186
def check_num_samples(ins, batch_size=None, steps=None, steps_name='steps'):
  """Works out how many samples are available for training or evaluation.

  The sample count is undefined when running with `steps`; `None` is returned
  in that case.

  Arguments:
      ins: List of tensors to be fed to the Keras function.
      batch_size: Integer batch size or `None` if not defined.
      steps: Total number of steps (batches of samples) before declaring
        `_predict_loop` finished. Ignored with the default value of `None`.
      steps_name: The public API's parameter name for `steps`.

  Raises:
      ValueError: when both `steps` and `batch_size` are set, because they
      are mutually exclusive. `check_steps_argument` (defined elsewhere in
      this module) is also called and may raise on its own.

  Returns:
      `None` when `check_steps_argument` returns a truthy value. Otherwise the
      size of the first dimension of `ins[0]` as an int, or `None` if
      `ins[0]` has no `shape` attribute.
  """
  both_given = steps is not None and batch_size is not None
  if both_given:
    raise ValueError('If ' + steps_name +
                     ' is set, the `batch_size` must be None.')
  if check_steps_argument(ins, steps, steps_name):
    return None
  first_input = ins[0]
  if not hasattr(first_input, 'shape'):
    return None  # Edge case where ins == [static_learning_phase]
  return int(first_input.shape[0])
220
221
def standardize_single_array(x, expected_shape=None):
  """Turns data of shape (x,) into (x, 1) unless a rank-1 shape is expected.

  Arguments:
    x: Array or tensor exposing `shape`, or `None`.
    expected_shape: Optional expected shape. If it has exactly one dimension,
      `x` is left untouched.

  Returns:
    `None` if `x` is `None`; `x` with a trailing axis added when it is rank 1
    and `expected_shape` is `None` or not of length 1; otherwise `x` itself.
  """
  if x is None:
    return None

  # Only rank-1 data is ever expanded.
  if x.shape is None or len(x.shape) != 1:
    return x
  # The consumer explicitly wants rank-1 data.
  if expected_shape is not None and len(expected_shape) == 1:
    return x

  if tensor_util.is_tensor(x):
    return array_ops.expand_dims(x, axis=1)
  return np.expand_dims(x, 1)
234
235
def standardize_input_data(data,
                           names,
                           shapes=None,
                           check_batch_axis=True,
                           exception_prefix=''):
  """Normalizes inputs and targets provided by users.

  Users may pass data as a list of arrays, dictionary of arrays,
  or as a single array. We normalize this to an ordered list of
  arrays (same order as `names`), while checking that the provided
  arrays have shapes that match the network's expectations.

  Arguments:
      data: User-provided input data (polymorphic).
      names: List of expected array names.
      shapes: Optional list of expected array shapes.
      check_batch_axis: Boolean; whether to check that the batch axis of the
        arrays matches the expected value found in `shapes`.
      exception_prefix: String prefix used for exception formatting.

  Returns:
      List of standardized input arrays (one array per model input).
      An empty list when `names` is empty, and a list holding one `None`
      per name when `data` is `None`.

  Raises:
      ValueError: in case of improperly formatted user-provided data.
  """
  # No inputs/targets are expected: only empty or dict data is tolerated.
  if not names:
    if (data is not None and hasattr(data, '__len__') and len(data) and
        not isinstance(data, dict)):
      raise ValueError(
          'Error when checking model ' + exception_prefix + ': '
          'expected no data, but got:', data)
    return []
  if data is None:
    return [None for _ in range(len(names))]

  # Step 1: turn the polymorphic `data` into a flat list. Objects whose class
  # is named 'DataFrame' are replaced by their `.values` attribute.
  if isinstance(data, dict):
    # Dict entries are picked in the order given by `names`; extra keys are
    # ignored and a missing key is reported as a ValueError.
    try:
      data = [
          data[x].values
          if data[x].__class__.__name__ == 'DataFrame' else data[x]
          for x in names
      ]
    except KeyError as e:
      raise ValueError('No data provided for "' + e.args[0] + '". Need data '
                       'for each key in: ' + str(names))
  elif isinstance(data, (list, tuple)):
    # NOTE(review): `data[0]` is read unconditionally, so an empty list or
    # tuple raises IndexError here -- confirm callers never pass one.
    if isinstance(data[0], (list, tuple)):
      # A sequence of sequences: each inner sequence becomes one array.
      data = [np.asarray(d) for d in data]
    elif len(names) == 1 and isinstance(data[0], (float, int)):
      # A flat sequence of numbers for a single expected array.
      data = [np.asarray(data)]
    else:
      data = [
          x.values if x.__class__.__name__ == 'DataFrame' else x for x in data
      ]
  else:
    # A single object: wrap it in a one-element list.
    data = data.values if data.__class__.__name__ == 'DataFrame' else data
    data = [data]

  # Step 2: expand rank-1 entries to rank 2 where the expected shape allows.
  # With `shapes`, `zip` stops at the shorter of `data` and `shapes`.
  if shapes is not None:
    data = [
        standardize_single_array(x, shape) for (x, shape) in zip(data, shapes)
    ]
  else:
    data = [standardize_single_array(x) for x in data]

  # Step 3: the number of arrays must match the number of expected names.
  if len(data) != len(names):
    if data and hasattr(data[0], 'shape'):
      raise ValueError('Error when checking model ' + exception_prefix +
                       ': the list of Numpy arrays that you are passing to '
                       'your model is not the size the model expected. '
                       'Expected to see ' + str(len(names)) + ' array(s), '
                       'but instead got the following list of ' +
                       str(len(data)) + ' arrays: ' + str(data)[:200] + '...')
    elif len(names) > 1:
      raise ValueError('Error when checking model ' + exception_prefix +
                       ': you are passing a list as input to your model, '
                       'but the model expects a list of ' + str(len(names)) +
                       ' Numpy arrays instead. The list you passed was: ' +
                       str(data)[:200])
    # NOTE(review): this branch looks unreachable. At this point `names` has
    # exactly one element (it is non-empty and not longer than 1), so
    # `len(data) == 1` would contradict `len(data) != len(names)` -- confirm.
    elif len(data) == 1 and not hasattr(data[0], 'shape'):
      raise TypeError('Error when checking model ' + exception_prefix +
                      ': data should be a Numpy array, or list/dict of '
                      'Numpy arrays. Found: ' + str(data)[:200] + '...')
    elif len(names) == 1:
      # One array expected: pack the whole list into a single array.
      data = [np.asarray(data)]

  # Check shapes compatibility.
  if shapes:
    for i in range(len(names)):
      if shapes[i] is not None:
        if tensor_util.is_tensor(data[i]):
          tensorshape = data[i].get_shape()
          # Skip tensors whose shape object is falsy; there is nothing to
          # compare against.
          if not tensorshape:
            continue
          data_shape = tuple(tensorshape.as_list())
        else:
          data_shape = data[i].shape
        shape = shapes[i]
        # The rank must match before individual dimensions are compared.
        if len(data_shape) != len(shape):
          raise ValueError('Error when checking ' + exception_prefix +
                           ': expected ' + names[i] + ' to have ' +
                           str(len(shape)) + ' dimensions, but got array '
                           'with shape ' + str(data_shape))
        # Optionally leave the first (batch) axis out of the comparison.
        if not check_batch_axis:
          data_shape = data_shape[1:]
          shape = shape[1:]
        # A `None` dimension on either side matches anything.
        for dim, ref_dim in zip(data_shape, shape):
          if ref_dim != dim and ref_dim is not None and dim is not None:
            raise ValueError('Error when checking ' + exception_prefix +
                             ': expected ' + names[i] + ' to have shape ' +
                             str(shape) + ' but got array with shape ' +
                             str(data_shape))
  return data
349
350
def standardize_sample_or_class_weights(x_weight, output_names, weight_type):
  """Maps `sample_weight` or `class_weight` to model outputs.

  Arguments:
      x_weight: User-provided `sample_weight` or `class_weight` argument.
      output_names: List of output names (strings) in the model.
      weight_type: A string used purely for exception printing.

  Returns:
      A list of `sample_weight` or `class_weight` where there are exactly
          one element per model output. For a single-output model, a value
          that is neither a one-element list nor a dict keyed by the output
          name is wrapped as-is in a one-element list.

  Raises:
      ValueError: If the model has several outputs and `x_weight` is a list
          whose length differs from the number of outputs.
      TypeError: If the model has several outputs and `x_weight` is neither
          a list nor a dict.
  """
  # Nothing provided (None or an empty list): no weighting for any output.
  if x_weight is None or (isinstance(x_weight, list) and not x_weight):
    return [None for _ in output_names]

  if len(output_names) == 1:
    only_name = output_names[0]
    if isinstance(x_weight, list) and len(x_weight) == 1:
      return x_weight
    if isinstance(x_weight, dict) and only_name in x_weight:
      return [x_weight[only_name]]
    # Anything else is taken to be the weight for the single output itself
    # (this includes dicts that are not keyed by the output name).
    return [x_weight]

  if isinstance(x_weight, list):
    if len(x_weight) != len(output_names):
      raise ValueError('Provided `' + weight_type + '` was a list of ' +
                       str(len(x_weight)) + ' elements, but the model has ' +
                       str(len(output_names)) + ' outputs. '
                       'You should provide one `' + weight_type + '` '
                       'array per model output.')
    return x_weight

  if isinstance(x_weight, dict):
    # Outputs without an entry in the dict get `None`.
    return [x_weight.get(name) for name in output_names]

  raise TypeError('The model has multiple outputs, so `' + weight_type + '` '
                  'should be either a list or a dict. '
                  'Provided `' + weight_type + '` type not understood: ' +
                  str(x_weight))
393
394
def standardize_class_weights(class_weight, output_names):
  """Maps a `class_weight` argument to one entry per model output.

  Delegates to `standardize_sample_or_class_weights`, using 'class_weight'
  as the name shown in error messages.

  Arguments:
      class_weight: User-provided `class_weight` argument.
      output_names: List of output names (strings) in the model.

  Returns:
      Whatever `standardize_sample_or_class_weights` returns.
  """
  weight_type = 'class_weight'
  return standardize_sample_or_class_weights(class_weight, output_names,
                                             weight_type)
398
399
def standardize_sample_weights(sample_weight, output_names):
  """Maps a `sample_weight` argument to one entry per model output.

  Delegates to `standardize_sample_or_class_weights`, using 'sample_weight'
  as the name shown in error messages.

  Arguments:
      sample_weight: User-provided `sample_weight` argument.
      output_names: List of output names (strings) in the model.

  Returns:
      Whatever `standardize_sample_or_class_weights` returns.
  """
  weight_type = 'sample_weight'
  return standardize_sample_or_class_weights(sample_weight, output_names,
                                             weight_type)
403
404
def check_array_lengths(inputs, targets, weights=None):
  """Does user input validation for numpy arrays.

  Only entries that are neither `None` nor tensors take part in the checks;
  their first dimension is taken as the number of samples.

  Arguments:
      inputs: list of Numpy arrays of inputs.
      targets: list of Numpy arrays of targets.
      weights: list of Numpy arrays of sample weights.

  Raises:
      ValueError: in case of incorrectly formatted data, i.e. when the
          inputs, the targets or the sample weights disagree among
          themselves on the number of samples, or when the targets disagree
          with the inputs or with the sample weights.
  """

  def set_of_lengths(x):
    # Distinct first-dimension sizes found in `x`; empty set for `None`.
    if x is None:
      return set()
    return {
        y.shape[0]
        for y in x
        if y is not None and not tensor_util.is_tensor(y)
    }

  def shapes_of(arrays):
    # Shapes for error messages; `None` entries are reported as `None`
    # instead of failing on the missing `shape` attribute.
    return str([getattr(a, 'shape', None) for a in arrays])

  set_x = set_of_lengths(inputs)
  set_y = set_of_lengths(targets)
  set_w = set_of_lengths(weights)
  if len(set_x) > 1:
    raise ValueError('All input arrays (x) should have '
                     'the same number of samples. Got array shapes: ' +
                     shapes_of(inputs))
  if len(set_y) > 1:
    raise ValueError('All target arrays (y) should have '
                     'the same number of samples. Got array shapes: ' +
                     shapes_of(targets))
  # From here on each non-empty set holds exactly one length.
  if set_x and set_y:
    num_x, num_y = next(iter(set_x)), next(iter(set_y))
    if num_x != num_y:
      raise ValueError('Input arrays should have '
                       'the same number of samples as target arrays. '
                       'Found ' + str(num_x) + ' input samples '
                       'and ' + str(num_y) + ' target samples.')
  if len(set_w) > 1:
    raise ValueError('All sample_weight arrays should have '
                     'the same number of samples. Got array shapes: ' +
                     shapes_of(weights))
  if set_y and set_w:
    num_y, num_w = next(iter(set_y)), next(iter(set_w))
    if num_y != num_w:
      raise ValueError('Sample_weight arrays should have '
                       'the same number of samples as target arrays. Got ' +
                       str(num_y) + ' target samples and ' +
                       str(num_w) + ' sample_weight samples.')
454
455
def check_loss_and_target_compatibility(targets, loss_fns, output_shapes):
  """Validates that targets are compatible with their loss functions.

  The goal is to catch incorrect use of loss functions early. This check is
  purely for UX purposes. Entries whose target is `None` or a tensor, or whose
  loss is `None`, are skipped.

  Arguments:
      targets: list of Numpy arrays of targets.
      loss_fns: list of loss functions.
      output_shapes: list of shapes of model outputs.

  Raises:
      ValueError: if a loss function or target array
          is incompatible with an output.
  """
  # Losses for which every non-batch dimension of the target must match the
  # output, either as loss classes or as functions held by a wrapper.
  shape_sensitive_classes = (losses.MeanSquaredError,
                             losses.BinaryCrossentropy,
                             losses.CategoricalCrossentropy)
  shape_sensitive_fns = {
      losses.mean_squared_error, losses.binary_crossentropy,
      losses.categorical_crossentropy
  }

  for target, loss_fn, out_shape in zip(targets, loss_fns, output_shapes):
    if target is None or loss_fn is None or tensor_util.is_tensor(target):
      continue

    # A last dimension of size 1 is rejected for categorical crossentropy.
    if losses.is_categorical_crossentropy(loss_fn) and target.shape[-1] == 1:
      raise ValueError('You are passing a target array of shape ' +
                       str(target.shape) +
                       ' while using as loss `categorical_crossentropy`. '
                       '`categorical_crossentropy` expects '
                       'targets to be binary matrices (1s and 0s) '
                       'of shape (samples, classes). '
                       'If your targets are integer classes, '
                       'you can convert them to the expected format via:\n'
                       '```\n'
                       'from keras.utils import to_categorical\n'
                       'y_binary = to_categorical(y_int)\n'
                       '```\n'
                       '\n'
                       'Alternatively, you can use the loss function '
                       '`sparse_categorical_crossentropy` instead, '
                       'which does expect integer targets.')

    wrapped = isinstance(loss_fn, losses.LossFunctionWrapper)
    shape_sensitive = (
        isinstance(loss_fn, shape_sensitive_classes) or
        (wrapped and (loss_fn.fn in shape_sensitive_fns)))
    if not shape_sensitive:
      continue

    # Compare every dimension after the first one; unknown (`None`) output
    # dimensions are not checked.
    for target_dim, out_dim in zip(target.shape[1:], out_shape[1:]):
      if out_dim is not None and target_dim != out_dim:
        loss_name = loss_fn.name
        if loss_name is None:
          # Fall back to the name of the wrapped function or of the class.
          named_obj = loss_fn.fn if wrapped else type(loss_fn)
          loss_name = named_obj.__name__
        raise ValueError('A target array with shape ' + str(target.shape) +
                         ' was passed for an output of shape ' +
                         str(out_shape) +
                         ' while using as loss `' + loss_name + '`. '
                         'This loss expects targets to have the same shape '
                         'as the output.')
513
514
def collect_per_output_metric_info(metrics,
                                   output_names,
                                   output_shapes,
                                   loss_fns,
                                   is_weighted=False):
  """Associates metric names and functions with each model output.

  Arguments:
      metrics: a list or a list of lists or a dict of metric functions.
      output_names: a list of the names (strings) of model outputs.
      output_shapes: a list of the shapes (strings) of model outputs.
      loss_fns: a list of the loss functions corresponding to the model outputs.
      is_weighted: Boolean indicating whether the given metrics are weighted.

  Returns:
      A list (one entry per model output) of dicts.
      For instance, if the model has 2 outputs, and for the first output
      we want to compute "binary_accuracy" and "binary_crossentropy",
      and just "binary_accuracy" for the second output,
      the list would look like: `[{
          'acc': binary_accuracy(),
          'ce': binary_crossentropy(),
        }, {
          'acc': binary_accuracy(),
        }]`
      The dicts are ordered. When `metrics` is falsy, each dict is empty.

  Raises:
      TypeError: if an incorrect type is passed for the `metrics` argument.
      ValueError: if `metrics` contains sub-lists but does not have one entry
          per model output.
  """
  if not metrics:
    return [{} for _ in output_names]

  # Step 1: bring `metrics` into the form "one list of metrics per output".
  if isinstance(metrics, list):
    if any(isinstance(m, list) for m in metrics):
      if len(metrics) != len(output_names):
        raise ValueError('When passing a list of lists as `metrics`, '
                         'it should have one entry per model output. '
                         'The model has ' + str(len(output_names)) +
                         ' outputs, but you passed metrics=' + str(metrics))
      # User has provided a list of len = len(outputs).
      metrics_per_output = [generic_utils.to_list(m) for m in metrics]
    elif len(output_names) > 1:
      # A flat list applies to all outputs; every output gets its own clones.
      metrics_per_output = [
          [metrics_module.clone_metric(m) for m in metrics]
          for _ in output_names
      ]
    else:
      metrics_per_output = [metrics]
  elif isinstance(metrics, dict):
    # Outputs that are not mentioned in the dict get no metrics.
    metrics_per_output = [
        generic_utils.to_list(metrics.get(name, [])) for name in output_names
    ]
  else:
    raise TypeError('Type of `metrics` argument not understood. '
                    'Expected a list or dictionary, found: ' + str(metrics))

  # Step 2: resolve each metric to a name and a `Metric` object.
  per_output_metrics = []
  for output_index, output_metrics in enumerate(metrics_per_output):
    named_metrics = OrderedDict()
    for metric in output_metrics:
      name = get_metric_name(metric, is_weighted)
      resolved = get_metric_function(
          metric,
          output_shape=output_shapes[output_index],
          loss_fn=loss_fns[output_index])

      # If the metric function is not stateful, we create a stateful version.
      if not isinstance(resolved, metrics_module.Metric):
        resolved = metrics_module.MeanMetricWrapper(resolved, name=name)
      named_metrics[name] = resolved
    per_output_metrics.append(named_metrics)

  return per_output_metrics
591
592
def batch_shuffle(index_array, batch_size):
  """Shuffles whole batches of an index array, keeping each batch intact.

  Useful for shuffling HDF5 arrays
  (where one cannot access arbitrary indices).

  The leading `batch_count * batch_size` entries are viewed as rows of
  `batch_size` entries and the rows are shuffled with `np.random.shuffle`.
  Entries that do not fill a complete batch keep their place at the end.

  NOTE(review): the rows are shuffled in place; if the reshape is a view of
  `index_array`, the caller's array is reordered too -- confirm whether
  callers rely on that.

  Arguments:
      index_array: array of indices to be shuffled.
      batch_size: integer.

  Returns:
      The `index_array` array, shuffled in a batch-wise fashion.
  """
  num_full_batches = int(len(index_array) / batch_size)
  cutoff = num_full_batches * batch_size
  # Reshaping needs a length that divides cleanly by the batch size, so the
  # leftover tail is set aside and re-appended after shuffling.
  leftover = index_array[cutoff:]
  batches = index_array[:cutoff].reshape((num_full_batches, batch_size))
  np.random.shuffle(batches)
  return np.append(batches.flatten(), leftover)
615
616
def standardize_weights(y,
                        sample_weight=None,
                        class_weight=None,
                        sample_weight_mode=None):
  """Performs sample weight validation and standardization.

  Everything gets normalized to a single sample-wise (or timestep-wise)
  weight array. If both `sample_weight` and `class_weight` are provided,
  the weights are multiplied.

  Arguments:
      y: Numpy array of model targets to be weighted.
      sample_weight: User-provided `sample_weight` argument. A tuple is
        replaced by its first element.
      class_weight: User-provided `class_weight` argument. Only used when it
        is a dict mapping class to weight.
      sample_weight_mode: One of `None` or `"temporal"`. `"temporal"` indicated
        that we expect 2D weight data that will be applied to the last 2
        dimensions of the targets (i.e. we are weighting timesteps, not
        samples).

  Returns:
      A numpy array of target weights, one entry per sample to weight:
      the product of the per-sample class weights and `sample_weight` when
      both are available, otherwise whichever one is available, otherwise
      `None`.

  Raises:
      ValueError: In case of invalid user-provided arguments.
  """
  # Iterator may return sample_weight as 1-tuple
  if isinstance(sample_weight, tuple):
    sample_weight = sample_weight[0]

  # Rank checks that depend on the weighting mode.
  if sample_weight_mode is not None:
    if sample_weight_mode != 'temporal':
      raise ValueError('`sample_weight_mode` '
                       'should be None or "temporal". '
                       'Found: ' + str(sample_weight_mode))
    if len(y.shape) < 3:
      raise ValueError('Found a sample_weight array for '
                       'an input with shape ' + str(y.shape) + '. '
                       'Timestep-wise sample weighting (use of '
                       'sample_weight_mode="temporal") is restricted to '
                       'outputs that are at least 3D, i.e. that have '
                       'a time dimension.')
    if sample_weight is not None and len(sample_weight.shape) != 2:
      raise ValueError('Found a sample_weight array with shape ' +
                       str(sample_weight.shape) + '. '
                       'In order to use timestep-wise sample weighting, '
                       'you should pass a 2D sample_weight array.')
  else:
    if sample_weight is not None and len(sample_weight.shape) != 1:
      raise ValueError('Found a sample_weight array with shape ' +
                       str(sample_weight.shape) + '. '
                       'In order to use timestep-wise sample weights, '
                       'you should specify '
                       'sample_weight_mode="temporal" '
                       'in compile(). If you just mean to use '
                       'sample-wise weights, make sure your '
                       'sample_weight array is 1D.')

  # Checks of the sample weights against the targets.
  if sample_weight is not None:
    if len(sample_weight.shape) > len(y.shape):
      raise ValueError('Found a sample_weight with shape ' +
                       str(sample_weight.shape) + '. '
                       'Expected sample_weight with rank '
                       'less than or equal to ' + str(len(y.shape)))

    # For non-tensor weights the weight shape has to equal the leading
    # dimensions of the target shape.
    if (not tensor_util.is_tensor(sample_weight) and
        y.shape[:sample_weight.ndim] != sample_weight.shape):
      raise ValueError('Found a sample_weight array with shape ' +
                       str(sample_weight.shape) + ' for an input with shape ' +
                       str(y.shape) + '. '
                       'sample_weight cannot be broadcast.')

  # Class weights applied per-sample.
  class_sample_weight = None
  if isinstance(class_weight, dict):
    if len(y.shape) > 2:
      raise ValueError('`class_weight` not supported for '
                       '3+ dimensional targets.')

    # Reduce the targets to one class per sample.
    if len(y.shape) == 2:
      if y.shape[1] > 1:
        # Several columns: the class is the index of the largest entry.
        y_classes = np.argmax(y, axis=1)
      elif y.shape[1] == 1:
        # A single column already holds the class.
        y_classes = np.reshape(y, y.shape[0])
      else:
        # Without this branch `y_classes` would be left unassigned.
        raise ValueError('`class_weight` cannot be applied to targets '
                         'with shape ' + str(y.shape) + '.')
    else:
      y_classes = y

    class_sample_weight = np.asarray(
        [class_weight[cls] for cls in y_classes if cls in class_weight])

    # Classes without a weight were dropped above, so a shorter result
    # means that some class in the data has no entry in `class_weight`.
    if len(class_sample_weight) != len(y_classes):
      # subtract the sets to pick all missing classes
      existing_classes = set(y_classes)
      existing_class_weight = set(class_weight.keys())
      raise ValueError(
          '`class_weight` must contain all classes in the data.'
          ' The classes %s exist in the data but not in '
          '`class_weight`.' % (existing_classes - existing_class_weight))

  if class_sample_weight is not None and sample_weight is not None:
    # Multiply weights if both are provided.
    return class_sample_weight * sample_weight
  if sample_weight is not None:
    return sample_weight
  if class_sample_weight is not None:
    return class_sample_weight
  return None
722
723
def has_symbolic_tensors(ls):
  """Returns `has_tensors(ls)`, or `False` when executing eagerly.

  Arguments:
    ls: Value passed on to `has_tensors`.
  """
  # `has_tensors` is never consulted when executing eagerly.
  return not context.executing_eagerly() and has_tensors(ls)
728
729
def has_tensors(ls):
  """Checks whether `ls` is, or directly contains, a tensor.

  Arguments:
    ls: A list or tuple (its elements are checked), a dict (its values are
      checked), or any other object (checked itself). Nested containers are
      not searched.

  Returns:
    The result of the `tensor_util.is_tensor` check(s).
  """
  if isinstance(ls, dict):
    candidates = six.itervalues(ls)
  elif isinstance(ls, (list, tuple)):
    candidates = ls
  else:
    return tensor_util.is_tensor(ls)
  return any(tensor_util.is_tensor(item) for item in candidates)
736
737
def get_metric_name(metric, weighted=False):
  """Resolves the name to use for the given metric input.

  Arguments:
    metric: Metric function name or reference.
    weighted: Boolean indicating if the given metric is weighted. Only used
      when `tf2.enabled()` is false, where it adds a 'weighted_' prefix.

  Returns:
      The metric name.
  """
  if tf2.enabled():
    # We keep the string that the user has set in compile as the metric name.
    if isinstance(metric, six.string_types):
      return metric

    metric = metrics_module.get(metric)
    if hasattr(metric, 'name'):
      return metric.name
    return metric.__name__

  if metric in ('accuracy', 'acc'):
    suffix = 'acc'
  elif metric in ('crossentropy', 'ce'):
    suffix = 'ce'
  else:
    metric_fn = metrics_module.get(metric)
    # Prefer the `name` attribute and fall back to `__name__`.
    if hasattr(metric_fn, 'name'):
      suffix = metric_fn.name
    else:
      suffix = metric_fn.__name__
  prefix = 'weighted_' if weighted else ''
  return prefix + suffix
771
772
def get_metric_function(metric, output_shape=None, loss_fn=None):
  """Returns the metric function corresponding to the given metric input.

  Any metric other than the aliases `accuracy`/`acc`/`crossentropy`/`ce` is
  resolved through `metrics_module.get`. For the aliases, the binary, sparse
  categorical or categorical variant is picked from the last dimension of
  `output_shape` and the kind of loss in use.

  Arguments:
      metric: Metric function name or reference.
      output_shape: The shape of the output that this metric will be calculated
        for. It is indexed with `[-1]` whenever `metric` is one of the aliases.
      loss_fn: The loss function used.

  Returns:
      The metric function.
  """
  if metric not in ['accuracy', 'acc', 'crossentropy', 'ce']:
    return metrics_module.get(metric)

  is_wrapper = isinstance(loss_fn, losses.LossFunctionWrapper)

  is_sparse_categorical_crossentropy = (
      isinstance(loss_fn, losses.SparseCategoricalCrossentropy) or
      (is_wrapper and loss_fn.fn == losses.sparse_categorical_crossentropy))

  is_binary_crossentropy = (
      isinstance(loss_fn, losses.BinaryCrossentropy) or
      (is_wrapper and loss_fn.fn == losses.binary_crossentropy))

  # Candidate (binary, sparse categorical, categorical) functions for the
  # requested metric family.
  if metric in ['accuracy', 'acc']:
    binary_fn = metrics_module.binary_accuracy
    sparse_fn = metrics_module.sparse_categorical_accuracy
    categorical_fn = metrics_module.categorical_accuracy
  else:
    binary_fn = metrics_module.binary_crossentropy
    sparse_fn = metrics_module.sparse_categorical_crossentropy
    categorical_fn = metrics_module.categorical_crossentropy

  if output_shape[-1] == 1 or is_binary_crossentropy:
    return binary_fn
  if is_sparse_categorical_crossentropy:
    return sparse_fn
  # The last output dimension is not 1 here, so the output is treated as
  # categorical. The sparse variant is only chosen when the loss is
  # explicitly a sparse categorical crossentropy.
  return categorical_fn
813
814
def call_metric_function(metric_fn, y_true, y_pred, weights=None, mask=None):
  """Invokes metric function and returns the metric result tensor.

  Arguments:
    metric_fn: Callable invoked as `metric_fn(y_true, y_pred, sample_weight=)`.
    y_true: Target values, forwarded to `metric_fn`.
    y_pred: Predicted values, forwarded to `metric_fn`; its dtype is used to
      cast `mask`.
    weights: Optional sample weights.
    mask: Optional mask. When given it is cast to `y_pred.dtype` and either
      used as the sample weight (no `weights`) or multiplied into `weights`.

  Returns:
    Whatever `metric_fn` returns.
  """
  if mask is not None:
    mask = math_ops.cast(mask, y_pred.dtype)
    if weights is None:
      # Without explicit weights the mask itself acts as the sample weight.
      weights = mask
    else:
      # Align the dimensions of weights and mask before combining them.
      mask, _, weights = squeeze_or_expand_dimensions(mask, None, weights)
      weights *= mask
  return metric_fn(y_true, y_pred, sample_weight=weights)
829
830
def get_loss_function(loss):
  """Returns the loss function corresponding to the given loss input.

  Arguments:
    loss: `None`, a `losses.Loss` instance, a mapping holding a serialized
      loss configuration, a callable, or any identifier accepted by
      `losses.get`.

  Returns:
    `loss` unchanged when it is `None`, a `losses.Loss` instance, or a
    callable without a `__name__` attribute; otherwise the function resolved
    by `losses.get`, wrapped in a `losses.LossFunctionWrapper`.
  """
  # The ABC aliases on `collections` were removed in Python 3.10; fall back
  # to `collections` itself where `collections.abc` does not exist (Python 2).
  try:
    from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top
  except ImportError:
    collections_abc = collections

  if loss is None or isinstance(loss, losses.Loss):
    return loss

  # Deserialize loss configuration, if needed.
  if isinstance(loss, collections_abc.Mapping):
    loss = losses.get(loss)

  # Custom callable class.
  if callable(loss) and not hasattr(loss, '__name__'):
    return loss

  # Wrap loss function with signature `(y_true, y_pred, **kwargs)`
  # in `LossFunctionWrapper` class.
  loss_fn = losses.get(loss)
  return losses.LossFunctionWrapper(loss_fn, name=loss_fn.__name__)
848
849
def validate_dataset_input(x, y, sample_weight, validation_split=None):
  """Validates user input arguments when a dataset iterator is passed.

  Arguments:
    x: Input data. A `tf.data` dataset or iterator. Only used to format the
      error messages.
    y: Target data. It could be either Numpy array(s) or TensorFlow tensor(s).
      Expected to be `None` when `x` is a dataset iterator.
    sample_weight: An optional sample-weight array passed by the user to weight
      the importance of each sample in `x`. Expected to be `None` when `x` is a
      dataset iterator.
    validation_split: Float between 0 and 1. Fraction of the training data to be
      used as validation data. Expected to be `None` (or `0.0`) when `x` is a
      dataset iterator.

  Raises:
    ValueError: if argument `y` or `sample_weight` or a non-zero
        `validation_split` is provided by the user.
  """
  if y is not None:
    raise ValueError('You passed a dataset or dataset iterator (%s) as '
                     'input `x` to your model. In that case, you should '
                     'not specify a target (`y`) argument, since the dataset '
                     'or dataset iterator generates both input data and '
                     'target data. '
                     'Received: %s' % (x, y))
  if sample_weight is not None:
    # Each fragment ends with a space so the joined message reads correctly.
    raise ValueError('`sample_weight` argument is not supported when input '
                     '`x` is a dataset or a dataset iterator. Instead, you '
                     'can provide sample_weight as the third element of your '
                     'dataset, i.e. (inputs, targets, sample_weight). '
                     'Received: x=%s, sample_weight=%s' % (x, sample_weight))
  if validation_split is not None and validation_split != 0.0:
    raise ValueError(
        '`validation_split` argument is not supported when '
        'input `x` is a dataset or a dataset iterator. '
        'Received: x=%s, validation_split=%f' % (x, validation_split))
886
887
def check_generator_arguments(y=None, sample_weight=None,
                              validation_split=None):
  """Validates arguments passed when using a generator.

  Arguments:
    y: Must be `None`; targets come from the generator itself.
    sample_weight: Must be `None`; sample weights come from the generator.
    validation_split: Must be falsy (`None` or `0`).

  Raises:
    ValueError: if `y` or `sample_weight` is not `None`, or if
      `validation_split` is truthy.
  """
  if y is not None:
    raise ValueError('`y` argument is not supported when data is '
                     'a generator or Sequence instance. Instead pass targets'
                     ' as the second element of the generator.')
  if sample_weight is not None:
    raise ValueError('`sample_weight` argument is not supported when data is '
                     'a generator or Sequence instance. Instead pass sample'
                     ' weights as the third element of the generator.')
  if validation_split:
    raise ValueError('If your data is in the form of a Python generator, '
                     'you cannot use `validation_split`.')
902
903
def check_steps_argument(input_data, steps, steps_name):
  """Validates `steps` argument based on input data's type.

  `steps` is required when the input data is any of:
    1. a dataset iterator,
    2. `None` or an empty list (the model was built on top of symbolic
       tensors and needs no input data),
    3. a symbolic tensor.

  Arguments:
      input_data: Input data. Can be Numpy array(s) or TensorFlow tensor(s) or
        tf.data.Dataset iterator or `None`.
      steps: Integer or `None`. Total number of steps (batches of samples) to
        execute.
      steps_name: The public API's parameter name for `steps`. Only used in
        the error message.

  Returns:
    boolean, True if `steps` argument is required, else False.

  Raises:
      ValueError: if `steps` argument is required for given input data type
        but not provided.
  """
  # TODO(fchollet): allow datasets with steps=None if cardinality is known.
  iterator_types = (iterator_ops.Iterator, iterator_ops.EagerIterator)
  is_x_iterator = isinstance(input_data, iterator_types)
  is_empty_list = isinstance(input_data, list) and not input_data

  steps_required = (input_data is None or is_x_iterator or
                    has_symbolic_tensors(input_data) or is_empty_list)
  if not steps_required:
    return False

  if steps is None:
    if is_x_iterator:
      input_type_str = 'a Dataset iterator'
    else:
      input_type_str = 'data tensors'
    raise ValueError('When using {input_type} as input to a model, you should'
                     ' specify the `{steps_name}` argument.'.format(
                         input_type=input_type_str, steps_name=steps_name))
  return True
939
940
def cast_single_tensor(x):
  """Casts `x` to the Keras default float type if it is a floating tensor.

  Arguments:
    x: Any value.

  Returns:
    `x` cast to `K.floatx()` when it is a tensor with a floating dtype;
    otherwise `x` unchanged.
  """
  if not tensor_util.is_tensor(x):
    return x
  if not x.dtype.is_floating:
    return x
  return math_ops.cast(x, dtype=K.floatx())
945
946
def cast_if_floating_dtype(x):
  """Casts the given data tensors to the default floating point type.

  Each element of the structure is passed through `cast_single_tensor`, so
  only tensors that already have a floating point dtype are cast.

  Args:
    x: tensor or list/tuple of tensors.

  Returns:
    Converted input.

  Raises:
    RuntimeError: if `has_tensors(x)` is false.
  """
  if has_tensors(x):
    return nest.map_structure(cast_single_tensor, x)
  raise RuntimeError(
      'Please provide tensors for casting, got: {x}'.format(x=x))
965
966
def get_output_sample_weight_and_mode(skip_target_weighing_indices,
                                      sample_weight_mode, output_name,
                                      output_index):
  """Returns the sample weight and weight mode for a single output.

  Arguments:
    skip_target_weighing_indices: Output indices that get no sample weight.
    sample_weight_mode: `'temporal'` selects a 2D weight; any other value
      selects a 1D weight.
    output_name: Output name, used as prefix of the placeholder name.
    output_index: Index of the output being processed.

  Returns:
    A `(weight, mode)` pair. Both are `None` for skipped outputs. `weight` is
    `None` when executing eagerly, otherwise a placeholder defaulting to 1.
    `mode` is `'temporal'` or `None`.
  """
  if output_index in skip_target_weighing_indices:
    return None, None

  is_temporal = sample_weight_mode == 'temporal'
  mode = 'temporal' if is_temporal else None
  if context.executing_eagerly():
    return None, mode

  default_value = [[1.]] if is_temporal else [1.]
  shape = [None, None] if is_temporal else [None]
  weight = array_ops.placeholder_with_default(
      constant_op.constant(default_value, dtype=K.floatx()),
      shape=shape,
      name=output_name + '_sample_weights')
  return weight, mode
990
991
def prepare_sample_weights(output_names, sample_weight_mode,
                           skip_target_weighing_indices):
  """Prepares sample weights for the model.

  Args:
    output_names: List of model output names.
    sample_weight_mode: sample weight mode user input passed from compile API.
      May be a dict keyed by output name, a list with one entry per output,
      or a single value applied to every output.
    skip_target_weighing_indices: Indices of output for which sample weights
      should be skipped.

  Returns:
    A pair of list of sample weights and sample weight modes
      (one for each output).

  Raises:
    ValueError: In case of invalid `sample_weight_mode` input: a dict with
      keys that are not output names, a dict missing a non-skipped output,
      or a list whose length differs from the number of outputs.
  """
  # Resolve the per-output mode first so that all validation happens before
  # any sample weight is created.
  if isinstance(sample_weight_mode, dict):
    unknown_output = set(sample_weight_mode.keys()) - set(output_names)
    if unknown_output:
      # Format the set explicitly: concatenating it to a `str` raises
      # `TypeError`, which would mask the intended `ValueError`.
      raise ValueError('Unknown entry in '
                       'sample_weight_mode dictionary: "' +
                       str(sorted(unknown_output, key=str)) +
                       '". Only expected the following keys: ' +
                       str(output_names))
    for i, name in enumerate(output_names):
      if (i not in skip_target_weighing_indices and
          name not in sample_weight_mode):
        raise ValueError('Output missing from sample_weight_modes dictionary')
    modes = [sample_weight_mode.get(name) for name in output_names]
  elif isinstance(sample_weight_mode, list):
    if len(sample_weight_mode) != len(output_names):
      raise ValueError('When passing a list as sample_weight_mode, '
                       'it should have one entry per model output. '
                       'The model has ' + str(len(output_names)) +
                       ' outputs, but you passed ' +
                       str(len(sample_weight_mode)) + ' sample_weight_modes')
    modes = sample_weight_mode
  else:
    modes = [sample_weight_mode for _ in output_names]

  sample_weights = []
  sample_weight_modes = []
  for i, (name, mode_input) in enumerate(zip(output_names, modes)):
    weight, mode = get_output_sample_weight_and_mode(
        skip_target_weighing_indices, mode_input, name, i)
    sample_weights.append(weight)
    sample_weight_modes.append(mode)
  return sample_weights, sample_weight_modes
1045
1046
def prepare_loss_functions(loss, output_names):
  """Converts loss to a list of loss functions.

  Arguments:
      loss: String (name of objective function), objective function or
        `tf.losses.Loss` instance. See `tf.losses`. If the model has multiple
        outputs, you can use a different loss on each output by passing a
        dictionary or a list of losses. The loss value that will be minimized by
        the model will then be the sum of all individual losses.
      output_names: List of model output names.

  Returns:
      A list of loss objective functions. For a dictionary `loss`, outputs
      missing from it get `get_loss_function(None)` and a warning is logged.

  Raises:
      ValueError: If loss is a dict with keys not in model output names,
          or if loss is a list with len not equal to model outputs.
  """
  # The ABC aliases on `collections` were removed in Python 3.10; fall back
  # to `collections` itself where `collections.abc` does not exist (Python 2).
  try:
    from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top
  except ImportError:
    collections_abc = collections

  if isinstance(loss, collections_abc.Mapping):
    for name in loss:
      if name not in output_names:
        raise ValueError('Unknown entry in loss dictionary: {}. Only expected '
                         'following keys: {}'.format(name, output_names))
    loss_functions = []
    for name in output_names:
      if name not in loss:
        logging.warning(
            'Output {0} missing from loss dictionary. We assume '
            'this was done on purpose. The fit and evaluate APIs will not be '
            'expecting any data to be passed to {0}.'.format(name))
      loss_functions.append(get_loss_function(loss.get(name, None)))
  elif isinstance(loss, six.string_types):
    # Strings are sequences too, so they must be handled before the
    # `Sequence` branch below.
    loss_functions = [get_loss_function(loss) for _ in output_names]
  elif isinstance(loss, collections_abc.Sequence):
    if len(loss) != len(output_names):
      raise ValueError('When passing a list as loss, it should have one entry '
                       'per model outputs. The model has {} outputs, but you '
                       'passed loss={}'.format(len(output_names), loss))
    loss_functions = nest.map_structure(get_loss_function, loss)
  else:
    loss_functions = [get_loss_function(loss) for _ in output_names]

  return loss_functions
1090
1091
def prepare_loss_weights(output_names, loss_weights=None):
  """Converts loss weights to a list of loss weights.

  Arguments:
      output_names: List of model output names.
      loss_weights: Optional list or dictionary specifying scalar coefficients
        (Python floats) to weight the loss contributions of different model
        outputs. The loss value that will be minimized by the model will then be
        the *weighted sum* of all individual losses, weighted by the
        `loss_weights` coefficients. If a list, it is expected to have a 1:1
        mapping to the model's outputs. If a dict, it is expected to map
        output names (strings) to scalar coefficients; outputs missing from
        the dict get a weight of 1.

  Returns:
      A list of loss weights of python floats. When `loss_weights` is a list
      of the right length, that same list object is returned.

  Raises:
      ValueError: If loss weight is a dict with key not in model output names,
          or if loss is a list with len not equal to model outputs.
      TypeError: If `loss_weights` is neither `None`, a dict nor a list.
  """
  if loss_weights is None:
    return [1.] * len(output_names)

  if isinstance(loss_weights, dict):
    for name in loss_weights:
      if name not in output_names:
        raise ValueError('Unknown entry in loss_weights dictionary: {}. '
                         'Only expected the following keys: {}'.format(
                             name, output_names))
    return [loss_weights.get(name, 1.) for name in output_names]

  if isinstance(loss_weights, list):
    if len(loss_weights) != len(output_names):
      raise ValueError('When passing a list as loss_weights, '
                       'it should have one entry per model output. '
                       'The model has ' + str(len(output_names)) +
                       ' outputs, but you passed loss_weights=' +
                       str(loss_weights))
    return loss_weights

  # The accepted containers are a list or a dict, not "a list of dicts".
  raise TypeError('Could not interpret loss_weights argument: ' +
                  str(loss_weights) + ' - expected a list or a dict.')
1134
1135
# TODO(rohanj): This is a hack to avoid depending on feature_column, which
# would create a cyclical dependency. Figure out a cleaner solution.
def is_feature_layer(layer):
  """Returns whether `layer` is a FeatureLayer or not.

  The check reads the `_is_feature_layer` attribute rather than testing the
  layer's type.

  Arguments:
    layer: Object to inspect.

  Returns:
    The value of `layer._is_feature_layer`, or `False` if the attribute is
    absent.
  """
  try:
    return layer._is_feature_layer  # pylint: disable=protected-access
  except AttributeError:
    return False
1141
1142
def is_eager_dataset_or_iterator(data):
  """Returns whether `data` is a dataset or eager iterator in eager mode.

  Arguments:
    data: Object to inspect.

  Returns:
    True only when executing eagerly and `data` is a `DatasetV1`,
    `DatasetV2` or `EagerIterator` instance.
  """
  if not context.executing_eagerly():
    return False
  eager_input_types = (dataset_ops.DatasetV1, dataset_ops.DatasetV2,
                       iterator_ops.EagerIterator)
  return isinstance(data, eager_input_types)
1147
1148
1149# pylint: disable=protected-access
def assert_not_batched(dataset):
  """Asserts that `dataset` is not batched.

  The algorithm used by this method is sound but not complete. In other words,
  if the method fails to establish the assertion, it does not mean the dataset
  is batched.

  Example usage:
  ```python
  try:
    assert_not_batched(dataset)
    # safe to assume `dataset` is not batched here
  except ValueError:
    # make no assumptions about `dataset`
  ```

  Args:
    dataset: The dataset to analyze.

  Raises:
    ValueError: If the method cannot establish the assertion.
  """
  # Look through the V1 adapter at the dataset it wraps.
  if isinstance(dataset, dataset_ops.DatasetV1Adapter):
    return assert_not_batched(dataset._dataset)

  # Dataset types accepted by this check; the inputs of an accepted dataset
  # are then checked recursively.
  whitelisted_types = (
      dataset_ops._OptionsDataset,
      dataset_ops.ConcatenateDataset,
      dataset_ops.CacheDataset,
      dataset_ops.FilterDataset,
      dataset_ops.MapDataset,
      dataset_ops.ParallelMapDataset,
      dataset_ops.PrefetchDataset,
      dataset_ops.RangeDataset,
      dataset_ops.RepeatDataset,
      dataset_ops.ShuffleDataset,
      dataset_ops.SkipDataset,
      dataset_ops.SparseTensorSliceDataset,
      dataset_ops.TakeDataset,
      dataset_ops.TensorDataset,
      dataset_ops.TensorSliceDataset,
      dataset_ops.ZipDataset,
      readers.FixedLengthRecordDatasetV2,
      readers.TextLineDatasetV2,
      readers.TFRecordDatasetV2,
  )
  if not isinstance(dataset, whitelisted_types):
    raise ValueError('Could not assert that dataset is not batched.')
  for input_dataset in dataset._inputs():
    assert_not_batched(input_dataset)
1202
1203
1204# pylint: disable=protected-access
def assert_not_shuffled(dataset):
  """Asserts that `dataset` is not shuffled.

  The algorithm used by this method is sound but not complete. In other words,
  if the method fails to establish the assertion, it does not mean the dataset
  is shuffled.

  Example usage:
  ```python
  try:
    assert_not_shuffled(dataset)
    # safe to assume `dataset` is not shuffled here
  except ValueError:
    # make no assumptions about `dataset`
  ```

  Args:
    dataset: The dataset to analyze.

  Raises:
    ValueError: If the method cannot establish the assertion.
  """
  # Look through the V1 adapter at the dataset it wraps.
  if isinstance(dataset, dataset_ops.DatasetV1Adapter):
    return assert_not_shuffled(dataset._dataset)

  # Dataset types accepted by this check; the inputs of an accepted dataset
  # are then checked recursively.
  whitelisted_types = (
      dataset_ops._OptionsDataset,
      dataset_ops.BatchDataset,
      dataset_ops.ConcatenateDataset,
      dataset_ops.CacheDataset,
      dataset_ops.FilterDataset,
      dataset_ops.MapDataset,
      dataset_ops.PaddedBatchDataset,
      dataset_ops.ParallelMapDataset,
      dataset_ops.PrefetchDataset,
      dataset_ops.RangeDataset,
      dataset_ops.RepeatDataset,
      dataset_ops.SkipDataset,
      dataset_ops.SparseTensorSliceDataset,
      dataset_ops.TakeDataset,
      dataset_ops.TensorDataset,
      dataset_ops.TensorSliceDataset,
      dataset_ops.WindowDataset,
      dataset_ops.ZipDataset,
      readers.FixedLengthRecordDatasetV2,
      readers.TextLineDatasetV2,
      readers.TFRecordDatasetV2,
  )
  if not isinstance(dataset, whitelisted_types):
    raise ValueError('Could not assert that dataset is not shuffled.')
  for input_dataset in dataset._inputs():
    assert_not_shuffled(input_dataset)
1259
1260
def verify_dataset_shuffled(x):
  """Warns when the dataset can be shown not to be shuffled.

  If `assert_not_shuffled(x)` succeeds, a warning is logged asking the user
  to call `shuffle()`. If it raises `ValueError` the dataset may or may not
  be shuffled, and nothing is reported.

  Args:
    x: Dataset passed as an input to the model.

  Raises:
    AssertionError: if `x` is not a `dataset_ops.DatasetV2` instance.
  """
  assert isinstance(x, dataset_ops.DatasetV2)
  try:
    assert_not_shuffled(x)
  except ValueError:
    # Dataset may or may not be shuffled.
    return
  logging.warning('Expected a shuffled dataset but input dataset `x` is '
                  'not shuffled. Please invoke `shuffle()` on input dataset.')
1279
1280
def is_dataset_or_iterator(data):
  """Returns whether `data` is a dataset or a dataset iterator.

  Arguments:
    data: Object to inspect.

  Returns:
    True if `data` is a `DatasetV1`, `DatasetV2`, `EagerIterator` or
    `Iterator` instance.
  """
  dataset_and_iterator_types = (
      dataset_ops.DatasetV1,
      dataset_ops.DatasetV2,
      iterator_ops.EagerIterator,
      iterator_ops.Iterator,
  )
  return isinstance(data, dataset_and_iterator_types)
1284
1285
def get_iterator(dataset):
  """Creates an initializable iterator for `dataset` and initializes it.

  Arguments:
    dataset: Dataset to iterate over.

  Returns:
    The iterator built by `dataset_ops.make_initializable_iterator`, after it
    has been passed to `initialize_iterator`.
  """
  dataset_iterator = dataset_ops.make_initializable_iterator(dataset)
  initialize_iterator(dataset_iterator)
  return dataset_iterator
1291
1292
def initialize_iterator(iterator):
  """Runs the iterator's initializer unless executing eagerly.

  Arguments:
    iterator: Object exposing an `initializer` attribute. The attribute is
      read in both execution modes.
  """
  initializer = iterator.initializer
  if context.executing_eagerly():
    return
  K.get_session((initializer,)).run(initializer)
1297
1298
def extract_tensors_from_dataset(dataset):
  """Extract a tuple of tensors `inputs, targets, sample_weight` from a dataset.

  An iterator is created and initialized via `get_iterator`, and its next
  element is unpacked with `unpack_iterator_input`.

  Arguments:
    dataset: Dataset instance.

  Returns:
    Tuple of tensors `x, y, weights`. `y` and `weights` entry may be None.
  """
  x, y, weights = unpack_iterator_input(get_iterator(dataset))
  return x, y, weights
1311
1312
def unpack_iterator_input(iterator):
  """Convert a dataset iterator to a tuple of tensors `x, y, sample_weights`.

  Arguments:
    iterator: Instance of a dataset iterator; its `get_next()` is called once.

  Returns:
    Tuple of tensors `x, y, weights`. `y` and `weights` entry may be None.
    A next element that is not a list or tuple is returned as `x` with `y`
    and `weights` set to `None`.

  Raises:
    RuntimeError: if `get_next()` raises `errors.OutOfRangeError`.
    ValueError: if the next element is a list or tuple whose length is
      neither 2 nor 3.
  """
  try:
    next_element = iterator.get_next()
  except errors.OutOfRangeError:
    raise RuntimeError('Your dataset iterator ran out of data; '
                       'Make sure that your dataset can generate '
                       'required number of samples.')

  if not isinstance(next_element, (list, tuple)):
    return next_element, None, None

  if len(next_element) not in (2, 3):
    # Wrap in a 1-tuple: a bare tuple operand would be unpacked by `%` and
    # raise `TypeError` instead of the intended `ValueError`.
    raise ValueError(
        'Please provide model inputs as a list or tuple of 2 or 3 '
        'elements: (input, target) or (input, target, sample_weights) '
        'Received %s' % (next_element,))

  if len(next_element) == 2:
    x, y = next_element
    weights = None
  else:
    x, y, weights = next_element
  return x, y, weights
1345
1346
def infer_steps_for_dataset(dataset, steps, epochs=1, steps_name='steps'):
  """Infers steps_per_epoch needed to loop through a dataset.

  Arguments:
      dataset: Input data of type tf.data.Dataset.
      steps: Number of steps to draw from the dataset (may be None if unknown).
      epochs: Number of times to iterate over the dataset.
      steps_name: The string name of the steps argument, either `steps`,
        `validation_steps`, or `steps_per_epoch`. Only used for error message
        formatting.

  Returns:
    Integer or `None`. `steps` itself when it was specified; otherwise the
    dataset cardinality when it is non-negative, else `None`.

  Raises:
    ValueError: If the dataset is infinite and `steps` is `None`, or if
      `steps * epochs` exceeds a known dataset size.
  """
  assert isinstance(dataset, dataset_ops.DatasetV2)
  size = K.get_value(cardinality.cardinality(dataset))
  steps_given = steps is not None

  if not steps_given and size == cardinality.INFINITE:
    raise ValueError('When passing an infinitely repeating dataset, you '
                     'must specify the `%s` argument.' % (steps_name,))

  # A negative cardinality means the size could not be used for validation.
  size_known = size >= 0
  if size_known and steps_given and steps * epochs > size:
    if epochs > 1:
      raise ValueError('The dataset you passed contains %s batches, but you '
                       'passed `epochs=%s` and `%s=%s`, which is a total of '
                       '%s steps. We cannot draw that many steps from this '
                       'dataset. We suggest to set `%s=%s`.' %
                       (size, epochs, steps_name, steps, steps * epochs,
                        steps_name, size // epochs))
    raise ValueError('The dataset you passed contains %s batches, but you '
                     'passed `%s=%s`. We cannot draw that many steps from '
                     'this dataset. We suggest to set `%s=%s`.' %
                     (size, steps_name, steps, steps_name, size))

  if steps_given:
    return steps
  return size if size_known else None
1390
1391
class ModelInputs(object):
  """Encapsulates model inputs.

  The inputs are kept as a flat list next to a parallel list of names, which
  allows transforming them while still returning the structure (dict, list
  or single value) they were provided in.
  """

  def __init__(self, inputs):
    self._inputs = inputs
    self._is_dict = isinstance(inputs, dict)
    self._is_single_input = not isinstance(inputs, (list, tuple, dict))

    if self._is_dict:
      # Dict inputs are named by their keys, in sorted key order.
      self._input_names = sorted(inputs.keys())
      self._flattened_inputs = [inputs[key] for key in self._input_names]
    else:
      # Everything else is flattened and given generated 1-based names.
      self._flattened_inputs = nest.flatten(inputs)
      self._input_names = [
          'input_%d' % position
          for position in range(1, len(self._flattened_inputs) + 1)
      ]

  def get_input_names(self):
    """Returns keys to name inputs by.

    For a list, tuple or single entry the keys are generated as 'input_%d'.
    For a dictionary the sorted list of its keys is returned.
    """
    return self._input_names

  def get_symbolic_inputs(self, return_single_as_list=False):
    """Returns inputs to be set as self.inputs for a model.

    Lists and Python numbers are converted to NumPy arrays (1D arrays get a
    trailing axis). NumPy arrays, eager tensors and `TensorShape`s are then
    replaced by placeholders whose first dimension is `None`.

    Arguments:
      return_single_as_list: If True, a single (non-container) input is still
        returned inside a list.

    Returns:
      A dict for dict inputs, the single element for a single input (unless
      `return_single_as_list`), otherwise the flat list of inputs.
    """
    # TODO(karmel): There is a side-effect here where what you get
    # with as_list and as_dict depends on whether you have called this
    # method first, since it modifies in place.
    for index, name in enumerate(self._input_names):
      value = self._flattened_inputs[index]
      if isinstance(value, (list, float, int)):
        value = np.asarray(value)
        if value.ndim == 1:
          value = np.expand_dims(value, 1)

      if isinstance(value, (np.ndarray, ops.EagerTensor)):
        # We fix the placeholder shape except the batch size.
        # This is suboptimal, but it is the best we can do with the info
        # we have. The user should call `model._set_inputs(placeholders)`
        # to specify custom placeholders if the need arises.
        placeholder_shape = (None,) + tuple(value.shape[1:])
        placeholder_dtype = dtypes.as_dtype(value.dtype)
        if placeholder_dtype.is_floating:
          placeholder_dtype = K.floatx()
        value = K.placeholder(
            shape=placeholder_shape, name=name, dtype=placeholder_dtype)
      elif isinstance(value, tensor_shape.TensorShape):
        placeholder_shape = (None,) + tuple(value.as_list()[1:])
        value = K.placeholder(shape=placeholder_shape, name=name)

      self._flattened_inputs[index] = value

    if self._is_dict:
      return dict(zip(self._input_names, self._flattened_inputs))
    if self._is_single_input and not return_single_as_list:
      return self._flattened_inputs[0]
    return self._flattened_inputs

  def as_dict(self):
    """An iterable over a dictionary version of inputs.

    Yields:
      `(name, input)` pairs in flattened order.
    """
    for index in range(len(self._flattened_inputs)):
      name = self._input_names[index]
      yield name, self._flattened_inputs[index]

  def as_list(self):
    """Returning the inputs as a list."""
    return self._flattened_inputs
1467
1468
1469# Allow use of methods not exposed to the user.
1470# pylint: disable=protected-access
def get_input_shape_and_dtype(layer):
  """Retrieves input shape and input dtype of layer if applicable.

  Args:
    layer: Layer (or model) instance.

  Returns:
    Tuple (input_shape, input_dtype). Both could be None if the layer
      does not have a defined input shape.

  Raises:
    ValueError: in case an empty Sequential or Functional model is passed.
  """

  def _is_graph_model(candidate):
    if getattr(candidate, '_is_graph_network', False):
      return True
    return candidate.__class__.__name__ == 'Sequential'

  # For nested models, descend into the first layer of each model until a
  # non-model layer is reached; shape and dtype are read from that layer.
  # Subclassed Models may not have been built so can't be checked.
  current = layer
  while _is_graph_model(current):
    if not current.layers:
      raise ValueError('An empty Model cannot be used as a Layer.')
    current = current.layers[0]

  if not hasattr(current, '_batch_input_shape'):
    return None, None
  return current._batch_input_shape, current.dtype
1500
1501
1502# pylint: enable=protected-access
1503
1504
def get_static_batch_size(layer):
  """Gets the static batch size of a Layer.

  Arguments:
    layer: a `Layer` instance.

  Returns:
    The value of the first dimension of the layer's batch input shape, or
    `None` when `get_input_shape_and_dtype` reports no input shape.
  """
  batch_input_shape, _ = get_input_shape_and_dtype(layer)
  if batch_input_shape is None:
    return None
  batch_dimension = tensor_shape.as_dimension(batch_input_shape[0])
  return batch_dimension.value
1518
1519
def generic_output_names(outputs_list):
  """Returns generated 1-based output names, one per entry of `outputs_list`.

  Arguments:
    outputs_list: Sized collection of outputs; only its length is used.

  Returns:
    List of names `'output_1'`, `'output_2'`, ...
  """
  output_count = len(outputs_list)
  return ['output_%d' % number for number in range(1, output_count + 1)]
1522
1523
def convert_eager_tensors_to_numpy(structure):
  """Convert every EagerTensor in `structure` to NumPy.

  Arguments:
    structure: An arbitrary structure of elements to be converted to NumPy
      arrays.

  Returns:
    An identical structure with EagerTensors converted to NumPy arrays.
    Elements that are not EagerTensors are left unchanged.
  """

  def _to_numpy(value):
    return value.numpy() if isinstance(value, ops.EagerTensor) else value

  return nest.map_structure(_to_numpy, structure)
1541
1542
def should_run_validation(validation_freq, epoch):
  """Checks if validation should be run this epoch.

  Arguments:
    validation_freq: Integer or list. If an integer, specifies how many training
      epochs to run before a new validation run is performed. If a list,
      specifies the epochs on which to run validation.
    epoch: Integer, the number of the training epoch just completed.

  Returns:
    Bool, True if validation should be run.

  Raises:
    ValueError: if `validation_freq` is an Integer and less than 1, or if
    it is neither an Integer nor a Container.
  """
  # The ABC aliases on `collections` were removed in Python 3.10; fall back
  # to `collections` itself where `collections.abc` does not exist (Python 2).
  try:
    from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top
  except ImportError:
    collections_abc = collections

  # `epoch` is 0-indexed internally but 1-indexed in the public API.
  one_indexed_epoch = epoch + 1

  if isinstance(validation_freq, int):
    if validation_freq < 1:
      raise ValueError('`validation_freq` can not be less than 1.')
    return one_indexed_epoch % validation_freq == 0

  if not isinstance(validation_freq, collections_abc.Container):
    raise ValueError('`validation_freq` must be an Integer or '
                     '`collections.Container` (e.g. list, tuple, etc.)')
  return one_indexed_epoch in validation_freq
1571