1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities related to distributed training."""
16# pylint:disable=protected-access
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22
23import numpy as np
24
25from tensorflow.python.data.ops import dataset_ops
26from tensorflow.python.data.ops import iterator_ops
27from tensorflow.python.distribute import distribute_coordinator_context as dc_context
28from tensorflow.python.distribute import multi_worker_util
29from tensorflow.python.distribute import reduce_util
30from tensorflow.python.eager import context
31from tensorflow.python.eager import def_function
32from tensorflow.python.framework import dtypes
33from tensorflow.python.framework import ops
34from tensorflow.python.framework import sparse_tensor
35from tensorflow.python.framework import tensor_util
36from tensorflow.python.keras import backend as K
37from tensorflow.python.keras import callbacks
38from tensorflow.python.keras import metrics as metrics_module
39from tensorflow.python.keras import optimizers
40from tensorflow.python.keras.distribute import distributed_training_utils as dist_utils
41from tensorflow.python.keras.engine import training_utils_v1
42from tensorflow.python.keras.optimizer_v2 import optimizer_v2
43from tensorflow.python.keras.utils import tf_contextlib
44from tensorflow.python.keras.utils.mode_keys import ModeKeys
45from tensorflow.python.ops import array_ops
46from tensorflow.python.ops import control_flow_ops
47from tensorflow.python.ops import math_ops
48from tensorflow.python.ops import sparse_ops
49from tensorflow.python.ops import variables
50from tensorflow.python.ops.ragged import ragged_tensor
51from tensorflow.python.platform import tf_logging as logging
52from tensorflow.python.util import nest
53
54
def set_weights(distribution_strategy, dist_model, weights):
  """Copies the original model's weights onto the replicated model.

  `weights` is consumed layer by layer, following the order of
  `dist_model.layers`: every layer takes as many consecutive entries as it has
  weights. When executing eagerly (outside functions) each assignment happens
  immediately. Otherwise the assign ops are unwrapped through the strategy,
  collected, and run together in the Keras session.

  Args:
    distribution_strategy: DistributionStrategy used to distribute training
        and validation.
    dist_model: The replicated models on the different devices.
    weights: The weights of the original model.
  """
  run_eagerly = ops.executing_eagerly_outside_functions()
  pending_assigns = []
  offset = 0
  for layer in dist_model.layers:
    layer_vars = layer.weights
    count = len(layer_vars)
    # `zip` stops at the shorter sequence, so a short `weights` list simply
    # leaves the trailing variables untouched.
    for variable, value in zip(layer_vars, weights[offset:offset + count]):
      assign = variable.assign(value)
      if not run_eagerly:
        pending_assigns.append(distribution_strategy.unwrap(assign))
    offset += count

  if not run_eagerly:
    K.get_session(pending_assigns).run(pending_assigns)
81
82
def unwrap_values(distribution_strategy, grouped_inputs, grouped_outputs,
                  grouped_updates=None, grouped_session_args=None,
                  with_loss_tensor=False):
  """Flattens the PerReplica inputs, outputs, updates and session args.

  Inputs, updates and the `feed_dict` / `fetches` session args are flattened
  with `flatten_per_replica_values`. Outputs go through `unwrap_outputs`, which
  additionally reduces the loss across replicas when `with_loss_tensor` is
  True. The results are used to construct the main train function.

  Args:
    distribution_strategy: DistributionStrategy used to distribute training and
        validation.
    grouped_inputs: PerReplica inputs returned from the train or test function
        that we ran on each device.
    grouped_outputs: PerReplica outputs returned from the train or test function
        that we ran on each device.
    grouped_updates: PerReplica updates returned from the train or test function
        that we ran on each device.
    grouped_session_args: PerReplica session args returned from the train or
        test function that we ran on each device.
    with_loss_tensor: Boolean that indicates if we need to add the reduced loss
        tensor as one of the outputs.

  Returns:
    A tuple `(all_inputs, all_outputs, all_updates, all_session_args)`.
    `all_updates` is None when `grouped_updates` is empty or None.
    `all_session_args` is a dict that only holds the `feed_dict` / `fetches`
    entries that were present and non-empty in `grouped_session_args`.
  """
  all_inputs = flatten_per_replica_values(
      distribution_strategy, grouped_inputs)
  all_outputs = unwrap_outputs(
      distribution_strategy, grouped_outputs, with_loss_tensor)
  all_updates = (
      flatten_per_replica_values(distribution_strategy, grouped_updates)
      if grouped_updates else None)

  all_session_args = {}
  if grouped_session_args:
    for key in ('feed_dict', 'fetches'):
      grouped_value = grouped_session_args.get(key)
      if grouped_value:
        all_session_args[key] = flatten_per_replica_values(
            distribution_strategy, grouped_value)

  # TODO(priyag): Return only non empty/None values
  return all_inputs, all_outputs, all_updates, all_session_args
138
139
def unwrap_output_dict(strategy, grouped_outputs, mode):
  """Unwraps the PerReplica outputs of a fit/eval/predict step.

  In predict mode `grouped_outputs` has the same structure as the model output
  and is simply flattened. In fit/eval mode it is a dict with the keys
  `total_loss`, `output_losses`, `metrics` and `batch_size`: the total loss and
  the batch size are sum-reduced across replicas, the other two entries are
  flattened.

  Args:
    strategy: DistributionStrategy used to run the step.
    grouped_outputs: PerReplica outputs of the step (a dict for fit/eval).
    mode: One of the `ModeKeys` values.

  Returns:
    A flat list of values in predict mode, otherwise a dict with the keys
    `total_loss` (a one-element list), `output_losses`, `metrics` and
    `batch_size`.
  """
  if mode == ModeKeys.PREDICT:
    return flatten_per_replica_values(strategy, grouped_outputs)

  def _sum_across_replicas(value):
    return strategy.reduce(reduce_util.ReduceOp.SUM, value, axis=None)

  total_loss = _sum_across_replicas(grouped_outputs['total_loss'][0])
  output_losses = flatten_per_replica_values(
      strategy, grouped_outputs['output_losses'])
  metric_values = flatten_per_replica_values(
      strategy, grouped_outputs['metrics'])
  batch_size = _sum_across_replicas(grouped_outputs['batch_size'])

  if (K.is_tpu_strategy(strategy) and
      ops.executing_eagerly_outside_functions()):
    # On TPU all replicas produce the same output, so keep only one entry out
    # of every `num_replicas_in_sync`.
    # This is restricted to eager mode for now: the function is shared by the
    # graph and eager paths, and the graph path does not use experimental_run
    # yet, so the restriction can go once the graph path converges.
    stride = strategy.num_replicas_in_sync
    output_losses = output_losses[::stride]
    metric_values = metric_values[::stride]

  return {
      'total_loss': [total_loss],
      'output_losses': output_losses,
      'metrics': metric_values,
      'batch_size': batch_size,
  }
170
171
def unwrap_outputs(distribution_strategy, grouped_outputs,
                   with_loss_tensor=False):
  """Unwraps the PerReplica outputs into a flat list of values.

  Without a loss tensor the outputs are flattened as they are. With
  `with_loss_tensor=True` the first output is treated as the loss and is
  sum-reduced across replicas into a single tensor, which is placed in front
  of the flattened remaining outputs.

  Args:
    distribution_strategy: DistributionStrategy used to distribute training and
        validation.
    grouped_outputs: PerReplica outputs returned from the train or test function
        that we ran on each device.
    with_loss_tensor: Boolean that indicates if we need to add the reduced loss
        tensor as one of the outputs.

  Returns:
    Values of each of the PerReplica outputs.
  """
  if not with_loss_tensor:
    return flatten_per_replica_values(distribution_strategy, grouped_outputs)

  if isinstance(grouped_outputs, list):
    outputs = grouped_outputs
  else:
    outputs = [grouped_outputs]
  per_replica_loss, remaining_outputs = outputs[0], outputs[1:]

  # Reduce the loss tensor before adding it to the list of fetches.
  loss = distribution_strategy.reduce(
      reduce_util.ReduceOp.SUM, per_replica_loss, axis=None)
  all_outputs = flatten_per_replica_values(
      distribution_strategy, remaining_outputs)

  tpu_in_eager = (K.is_tpu_strategy(distribution_strategy) and
                  ops.executing_eagerly_outside_functions())
  if tpu_in_eager:
    # On TPU all replicas produce the same output, so keep only one entry out
    # of every `num_replicas_in_sync`.
    # This is restricted to eager mode for now: the function is shared by the
    # graph and eager paths, and the graph path does not use experimental_run
    # yet, so the restriction can go once the graph path converges.
    all_outputs = all_outputs[::distribution_strategy.num_replicas_in_sync]
  return [loss] + all_outputs
214
215
def flatten_per_replica_values(distribution_strategy, per_replica_values):
  """Unwraps and flattens a nest of PerReplica parameters.

  The nest is flattened first, and every resulting entry is then unwrapped
  through the strategy into its per-device values. All of those values are
  concatenated in order into one flat list.

  Args:
    distribution_strategy: DistributionStrategy used to distribute training and
      validation.
    per_replica_values: List of PerReplica object or a single PerReplica object.

  Returns:
    List of values of all the PerReplica objects.
  """
  flat_values = []
  for per_replica in nest.flatten(per_replica_values):
    flat_values.extend(distribution_strategy.unwrap(per_replica))
  return flat_values
238
239
def validate_callbacks(input_callbacks, optimizer):
  """Validate whether given callbacks are supported by DistributionStrategy.

  Besides validating, this resets `write_grads` to `False` (and logs a warning)
  on any TensorBoard callback that has it enabled.

  Args:
    input_callbacks: List of callbacks passed by the user to fit. May be None
        or empty, in which case nothing is checked.
    optimizer: Optimizer instance used to train the model.

  Raises:
    ValueError: If a `LearningRateScheduler` or `ReduceLROnPlateau` callback is
        passed while `optimizer` is not a Keras `OptimizerV2` instance.
  """
  if not input_callbacks:
    return

  lr_callback_types = (callbacks.LearningRateScheduler,
                       callbacks.ReduceLROnPlateau)
  for callback in input_callbacks:
    if (isinstance(callback, lr_callback_types) and
        not isinstance(optimizer, optimizer_v2.OptimizerV2)):
      raise ValueError('You must specify a Keras Optimizer V2 when using '
                       '%s callback with DistributionStrategy.' % callback)

    # Users of the TensorBoard callback cannot use the features that involve
    # accessing model attributes and running ops.
    if (isinstance(callback, callbacks.TensorBoard) and
        getattr(callback, 'write_grads', False)):
      logging.warning(
          UserWarning(
              '`write_grads` in the TensorBoard callback is not supported '
              'when using DistributionStrategy. Setting `write_grads` '
              'to `False`.'))
      callback.write_grads = False
273
274
def validate_distributed_dataset_inputs(distribution_strategy, x, y,
                                        sample_weights=None):
  """Validate all the components of a DistributedValue Dataset input.

  `x` is always validated; `y` and `sample_weights` are validated only when
  they are not None. Validation is delegated to `validate_per_replica_inputs`.

  Args:
    distribution_strategy: The current DistributionStrategy used to call
        `fit`/`evaluate`.
    x: Input Dataset DistributedValue object. For example, when we use
        `MirroredStrategy` this is a PerReplica object with a tensor for each
        device set in the dict. x can also be a tuple or dict. The keys of the
        dict should match the names of the input layers of the model.
    y: Target Dataset DistributedValue object. For example, when we use
        `MirroredStrategy` this is a PerReplica object with a tensor for each
        device set in the dict. y can also be a tuple or dict. The keys of the
        dict should match the names of the output layers of the model.
    sample_weights: Sample weights Dataset DistributedValue object. For example,
        when we use `MirroredStrategy` this is a PerReplica object with a tensor
        for each device set in the dict.

  Returns:
    A tuple `(x_values_list, y_values_list, sample_weights_list)` with the
    unwrapped values of each argument. The last two are None when the
    corresponding argument is None.

  Raises:
    ValueError: If any element is not a tensor, or if the per-replica values
        of an element have mismatching dtypes (or, in graph mode, shapes).
  """
  def _validate_if_present(values):
    if values is None:
      return None
    return validate_per_replica_inputs(distribution_strategy, values)

  x_values_list = validate_per_replica_inputs(distribution_strategy, x)
  y_values_list = _validate_if_present(y)
  sample_weights_list = _validate_if_present(sample_weights)

  # Hand back the unwrapped values so callers need not `unwrap` again.
  return x_values_list, y_values_list, sample_weights_list
323
324
def validate_per_replica_inputs(distribution_strategy, x):
  """Validates PerReplica dataset input list.

  `x` is flattened (expanding composites) and each entry is unwrapped into its
  per-replica values. Every value must be a tensor, all values of one entry
  must share a dtype, and, when not executing eagerly, a shape as well.

  Args:
    distribution_strategy: The current DistributionStrategy used to call
      `fit`, `evaluate` and `predict`.
    x: A list of PerReplica objects that represent the input or
      target values.

  Returns:
    List containing the first element of each of the PerReplica objects in
    the input list.

  Raises:
    ValueError: If any unwrapped value is not a tensor, or if the values of
      one PerReplica object disagree in dtype (or in shape, in graph mode).
  """
  first_replica_values = []
  for per_replica in nest.flatten(x, expand_composites=True):
    replica_values = distribution_strategy.unwrap(per_replica)
    # At this point every unwrapped value should be a tensor.
    for value in replica_values:
      if not tensor_util.is_tf_type(value):
        raise ValueError('Dataset input to the model should be tensors instead '
                         'they are of type {}'.format(type(value)))

    # Shapes are only compared in graph mode; dtypes are always compared.
    if not context.executing_eagerly():
      validate_all_tensor_shapes(per_replica, replica_values)
    validate_all_tensor_types(per_replica, replica_values)

    first_replica_values.append(replica_values[0])
  return first_replica_values
360
361
def validate_all_tensor_types(x, x_values):
  """Checks that all values in `x_values` share the dtype of the first one.

  Args:
    x: The distributed input the values belong to; only used in the error
      message.
    x_values: Non-empty sequence of objects exposing a `dtype` attribute.

  Raises:
    ValueError: If any value's dtype differs from the first value's dtype.
  """
  reference_dtype = x_values[0].dtype
  if any(reference_dtype != value.dtype for value in x_values[1:]):
    raise ValueError('Input tensor dtypes do not match for distributed tensor'
                     ' inputs {}'.format(x))
368
369
def validate_all_tensor_shapes(x, x_values):
  """Checks that all values in `x_values` share the shape of the first one.

  Shapes are compared through `shape.as_list()`.

  Args:
    x: The distributed input the values belong to; only used in the error
      message.
    x_values: Non-empty sequence of objects whose `shape` has `as_list()`.

  Raises:
    ValueError: If any value's shape differs from the first value's shape.
  """
  reference_shape = x_values[0].shape.as_list()
  for value in x_values[1:]:
    if reference_shape != value.shape.as_list():
      raise ValueError('Input tensor shapes do not match for distributed tensor'
                       ' inputs {}'.format(x))
377
378
def _wait_for_variable_initialization(session):
  """Blocks until all untracked variables of the Keras graph are initialized.

  Variables of the current Keras graph that do not yet carry a truthy
  `_keras_initialized` attribute are polled through `session` until every one
  of them reports as initialized. A variable is flagged with
  `_keras_initialized = True` only once it has actually been reported as
  initialized, so an interrupted wait does not leave uninitialized variables
  marked as done.

  Note that this is a busy-wait: there is no sleep between polls.

  Args:
    session: Session used to run the initialization checks.
  """
  candidate_vars = [
      v for v in K._get_variables(K.get_graph())  # pylint: disable=protected-access
      if not getattr(v, '_keras_initialized', False)
  ]
  if not candidate_vars:
    return

  # Build the check ops once, outside the polling loop, so that waiting does
  # not keep creating new ops on every iteration.
  pending = [(v, variables.is_variable_initialized(v)) for v in candidate_vars]
  while pending:
    flags = session.run([check_op for _, check_op in pending])
    still_pending = []
    for flag, (v, check_op) in zip(flags, pending):
      if flag:
        v._keras_initialized = True  # pylint: disable=protected-access
      else:
        still_pending.append((v, check_op))
    # Only variables that are still uninitialized get polled again.
    pending = still_pending
400
401
def init_restore_or_wait_for_variables():
  """Initialize or restore variables or wait for variables to be initialized.

  Variables are initialized in the Keras session when there is no worker
  context, or when `multi_worker_util.should_load_checkpoint()` is true. In
  every other case this waits until the variables have been initialized.
  """
  session = K._get_session()  # pylint: disable=protected-access
  if (multi_worker_util.has_worker_context() and
      not multi_worker_util.should_load_checkpoint()):
    _wait_for_variable_initialization(session)
    return

  # TODO(yuefengz): if checkpoints exist, restore from checkpoint.
  K._initialize_variables(session)  # pylint: disable=protected-access
411
412
def validate_inputs(x, y):
  """Validate inputs when using DistributionStrategy.

  Args:
    x: Model Inputs.
    y: Model Targets.

  Raises:
    ValueError: If `x` or `y` is an `Iterator`; a `tf.data.Dataset` or a numpy
      array has to be passed instead.
  """
  if any(isinstance(value, iterator_ops.Iterator) for value in (x, y)):
    raise ValueError('`DistributionStrategy` does not support inputs of type '
                     'Iterator. You must pass a `tf.data.Dataset` object or a '
                     'numpy array as input.')
429
430
def is_dataset_shape_fully_defined(dataset):
  """Returns whether every output shape of `dataset` is fully defined.

  Args:
    dataset: Dataset whose (legacy) output shapes are inspected.

  Returns:
    True if all flattened output shapes are fully defined, False if at least
    one of them has an unknown dimension.
  """
  output_shapes = nest.flatten(dataset_ops.get_legacy_output_shapes(dataset))
  return all(shape.is_fully_defined() for shape in output_shapes)
436
437
def process_batch_and_step_size(strategy,
                                inputs,
                                batch_size,
                                steps_per_epoch,
                                mode,
                                validation_split=0.):
  """Process the batch size and step size based on input and dist strategy.

  Only numpy input is processed: when the first flattened element of `inputs`
  is not an `np.ndarray`, `batch_size` and `steps_per_epoch` are returned
  unchanged.

  Args:
    strategy: The DistributionStrategy in use.
    inputs: Model inputs (possibly nested).
    batch_size: The specified batch size, or None.
    steps_per_epoch: The specified number of steps, or None.
    mode: ModeKey forwarded to `get_input_params`.
    validation_split: Fraction of the samples held out for validation. Only
      values strictly between 0 and 1 reduce the sample count.

  Returns:
    A tuple `(batch_size, steps_per_epoch)`.
  """
  first_x_value = nest.flatten(inputs)[0]
  if not isinstance(first_x_value, np.ndarray):
    return batch_size, steps_per_epoch

  num_samples = first_x_value.shape[0]
  if validation_split and 0. < validation_split < 1.:
    num_samples = int(num_samples * (1 - validation_split))

  # Until support for partial batch is implemented across all functions and
  # distribution strategies, `mode` is passed along to selectively relax the
  # constraint to consume all the training samples.
  steps_per_epoch, batch_size = get_input_params(
      strategy, num_samples, steps_per_epoch, batch_size, mode=mode)
  return batch_size, steps_per_epoch
456
457
def get_input_params(distribution_strategy,
                     num_samples,
                     steps,
                     batch_size,
                     mode=None):
  """Calculate the number of batches and steps/steps_per_epoch.

  The global batch size is taken from `batch_size` when it is given (scaled by
  the number of replicas for strategies that use a per-replica batch size).
  Otherwise it is `min(num_samples, 32)` when `steps` is None as well, or
  `num_samples // steps` when only `steps` is given. A missing `steps` is then
  derived from the number of samples and the global batch size.

  Args:
    distribution_strategy: The DistributionStrategy used to compile the model.
    num_samples: The number of samples from which we determine the batch size
      and steps.
    steps:  The specified number of steps.
    batch_size: The specified batch_size.
    mode: ModeKey representing whether input will be used for training,
      evaluation, or prediction. This is used to relax the constraints on
      consuming all the training samples to keep compatibility till we support
      partial batches.

  Returns:
    steps: The steps or steps_per_epoch argument depending on if a user is
        calling `fit`, `evaluate` or `predict`. When it is derived with
        partial batches allowed it is the result of
        `np.ceil(...).astype(int)`.
    batch_size: The batch size to be used in model iterations: the per-replica
        batch size for strategies without global batch size support, the
        global batch size otherwise.

  Raises:
    ValueError: If partial batches are not allowed and `num_samples` is not
      divisible by the global batch size; if only `steps` is given and it does
      not divide `num_samples`; if there are fewer samples than `batch_size`
      and `steps` require; or if the global batch size cannot be sharded
      evenly across the replicas.
  """
  # TODO(b/118776054): Use global batch size for Keras/DS support.
  # Currently this is only supported in TPUStrategy and CoreMirroredStrategy.
  use_per_replica_batch = not dist_utils.global_batch_size_supported(
      distribution_strategy)

  # TODO(b/128995245): In eager mode, uneven batch sizes are allowed except for
  # `fit()` on TPUStrategy.
  # In graph mode, the zero batch case in batch norm is not handled due to
  # XLA-GPU regression. There, uneven batch sizes are accepted for training,
  # and for `test()` / `predict()` only on TPUStrategy.
  if context.executing_eagerly():
    allow_partial_batch = not (
        mode == ModeKeys.TRAIN and K.is_tpu_strategy(distribution_strategy))
  else:
    allow_partial_batch = (
        mode == ModeKeys.TRAIN or
        (mode in (ModeKeys.PREDICT, ModeKeys.TEST) and
         K.is_tpu_strategy(distribution_strategy)))

  # Step 1: settle on the global batch size.
  if batch_size is not None:
    # A user-provided batch size is per-replica or global depending on the
    # strategy; normalize it to a global one.
    global_batch_size = batch_size
    if use_per_replica_batch:
      global_batch_size *= distribution_strategy.num_replicas_in_sync
  elif steps is None:
    # Neither the batch size nor the number of steps is set: use the minimum
    # of the number of samples and 32. 32 is chosen to provide backward
    # compatibility.
    global_batch_size = min(num_samples, 32)
  else:
    # Only the number of steps is set: derive the batch size from it.
    if num_samples % steps:
      raise ValueError('The number of samples %s is not divisible by '
                       'steps %s. Please change the number of steps to a '
                       'value that can consume all the samples' % (
                           num_samples, steps))
    global_batch_size = num_samples // steps

  # Step 2: derive the steps, or check that there are enough samples.
  if steps is None:
    if allow_partial_batch:
      steps = np.ceil(num_samples / global_batch_size).astype(int)
    else:
      if num_samples % global_batch_size:
        raise ValueError('The number of samples %s is not divisible by '
                         'batch size %s.' % (num_samples, global_batch_size))
      steps = num_samples // global_batch_size
  elif batch_size is not None:
    if allow_partial_batch:
      # Only the last batch may be partial, and it needs at least one sample.
      min_num_samples = (
          global_batch_size * (steps - 1) + 1 if steps > 1 else 0)
    else:
      min_num_samples = global_batch_size * steps

    if num_samples < min_num_samples:
      raise ValueError('Number of samples %s is less than samples required '
                       'for specified batch_size %s and steps %s' % (
                           num_samples, global_batch_size, steps))

  # Step 3: return the per replica or global batch size based on the strategy.
  if not use_per_replica_batch:
    return steps, global_batch_size

  num_replicas = distribution_strategy.num_replicas_in_sync
  if global_batch_size % num_replicas:
    raise ValueError(
        'The batch size (%s) could not be sharded evenly across the sync '
        'replicas (%s) in the distribution strategy.' % (
            global_batch_size, distribution_strategy.num_replicas_in_sync))
  return steps, global_batch_size // num_replicas
562
563
def get_batch_dimension(iterator):
  """Returns the leading dimension of the iterator's first output shape.

  Args:
    iterator: Object whose (legacy) output shapes are inspected.

  Returns:
    The first entry of the first output shape's `dims`, or None when `dims` is
    empty or None.
  """
  # The batch size is read from the first element only, as it should be the
  # same for all of them.
  first_shape = nest.flatten(
      dataset_ops.get_legacy_output_shapes(iterator))[0]
  dims = first_shape.dims
  if not dims:
    return None
  return dims[0]
570
571
def get_iterator(dataset, distribution_strategy):
  """Creates a dataset iterator under the strategy scope and initializes it.

  Args:
    dataset: Dataset to iterate over.
    distribution_strategy: DistributionStrategy that creates the iterator.

  Returns:
    The initialized iterator.
  """
  with distribution_strategy.scope():
    dataset_iterator = distribution_strategy.make_dataset_iterator(dataset)
  initialize_iterator(dataset_iterator, distribution_strategy)
  return dataset_iterator
577
578
def initialize_iterator(iterator, distribution_strategy):
  """Groups the iterator's initializer and, in graph mode, runs it.

  The grouped init op is always created under the strategy scope; it is only
  run through the Keras session when not executing eagerly.

  Args:
    iterator: Iterator exposing an `initializer`.
    distribution_strategy: DistributionStrategy whose scope is entered.
  """
  with distribution_strategy.scope():
    init_op = control_flow_ops.group(iterator.initializer)
    if context.executing_eagerly():
      return
    K.get_session((init_op,)).run(init_op)
584
585
def _get_input_from_iterator(iterator, model):
  """Get elements from the iterator and verify the input shape and type.

  The next element is interpreted by its number of flattened components: as
  `x` alone when it matches the number of model inputs, as `(x, y)` when it
  matches inputs plus outputs, and as `(x, y, sample_weights)` otherwise.

  Args:
    iterator: Iterator to pull the next element from.
    model: Model whose `inputs`/`outputs` determine how the element is split.

  Returns:
    A tuple `(x, y, sample_weights)`; `y` and `sample_weights` are None when
    the element does not carry them.
  """
  next_element = iterator.get_next()

  # `len(nest.flatten(x))` does not count empty elements such as {}:
  # len(nest.flatten([[0,1,2], {}])) is 3 and not 4. The `next_element` is
  # going to get flattened in `_prepare_feed_values` to work around that, and
  # empty elements are filtered out as part of that flattening.
  num_components = len(nest.flatten(next_element))
  num_inputs = len(model.inputs)

  y = None
  sample_weights = None
  if num_components == num_inputs:
    x = next_element
  elif num_components == num_inputs + len(model.outputs):
    x, y = next_element
  else:
    x, y, sample_weights = next_element

  # Validate that all the elements in x and y are of the same type and shape.
  validate_distributed_dataset_inputs(
      model._distribution_strategy, x, y, sample_weights)
  return x, y, sample_weights
609
610
def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
  """Prepare feed values to the model execution function.

  Args:
    model: Model to prepare feed values for.
    inputs: Iterator to pull the model inputs from; its `get_next()` is called
      and the result is split into inputs, targets and sample weights.
    targets: Ignored; overwritten by the targets pulled from `inputs`.
    sample_weights: Ignored; overwritten by the sample weights pulled from
      `inputs`.
    mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.

  Returns:
    Feed values for the model in the given mode, as a tuple
    `(inputs, targets, sample_weights)`. In predict mode the targets and sample
    weights are empty lists.

  Raises:
    ValueError: If sample weights are used with a TPU strategy.
    NotImplementedError: If sample weights are used while distributing by
      cloning in eager mode without a compile-time distribution strategy.
  """
  strategy = model._distribution_strategy
  inputs, targets, sample_weights = _get_input_from_iterator(inputs, model)
  if K.is_tpu_strategy(strategy) and sample_weights is not None:
    raise ValueError('TPUStrategy does not support sample weights.')

  # Dict inputs are flattened in the order of the input layers, so that the
  # data is fed into the input layers in the correct order.
  if isinstance(inputs, dict):
    inputs = [inputs[key] for key in model._feed_input_names]

  by_cloning = is_distributing_by_cloning(model)
  if by_cloning:
    inputs = flatten_per_replica_values(strategy, inputs)
    targets = flatten_per_replica_values(strategy, targets)
    # Expand 1-dimensional inputs.
    # TODO(b/124535720): Remove once this standardize data logic is shared with
    # main flow.
    inputs, targets = nest.map_structure(
        training_utils_v1.standardize_single_array, (inputs, targets))
  else:
    inputs = training_utils_v1.ModelInputs(inputs).as_list()

  if mode == ModeKeys.PREDICT:
    return inputs, [], []

  if sample_weights is not None and by_cloning:
    if context.executing_eagerly() and not model._compile_distribution:
      raise NotImplementedError('`sample_weight` is not supported when using '
                                'tf.distribute.Strategy in eager mode and '
                                'cloning=True.')
    sample_weights = flatten_per_replica_values(strategy, sample_weights)

  return inputs, targets, sample_weights
658
659
def is_distributing_by_cloning(model):
  """Decide whether this model is going to be distributed via cloning.

  We are going to distribute the model by cloning in graph mode. When
  executing eagerly, a TPU strategy never clones, and any other strategy clones
  only if the model was compiled with a distribution strategy.

  Args:
    model: Keras model to distribute.

  Returns:
    True if the `model` is going to be distributed using cloning and False
    otherwise.
  """
  # `executing_eagerly` has to be called: the bare function object is always
  # truthy, which made this branch fire for TPU strategies in graph mode too.
  if (K.is_tpu_strategy(model._distribution_strategy) and
      context.executing_eagerly()):  # b/137580852
    return False
  if ops.executing_eagerly_outside_functions():
    return bool(model._compile_distribution)
  return True
678
679
def _custom_compile_for_predict(model):
  """Custom compile for TPU predict mode.

  Marks a built model as compiled and clears its total loss and its cached
  train/test/predict functions. A model that is not built is left untouched.

  Args:
    model: Model to mark as compiled.
  """
  # An unbuilt model is not compilable because it does not know its number of
  # inputs and outputs, nor their shapes and names. It gets compiled after the
  # first time the model is called on training data.
  if model.built:
    model._is_compiled = True
    for attribute in ('total_loss', 'train_function', 'test_function',
                      'predict_function'):
      setattr(model, attribute, None)
692
693
def _build_network_on_replica(model, mode, inputs=None, targets=None):
  """Build an updated model on replicas.

  We create a new Keras model while sharing the variables from the old graph.
  Building a new sub-graph is required since the original keras model creates
  placeholders for the input and the output that are not accessible till we
  call iterator.get_next() inside the step_fn for `fit`/`evaluate`/`predict`.

  The sharing of weights and layers between the old and the new model guarantee
  that we're using Strategy variables and any updates on either model are
  reflected correctly in callbacks and loop iterations.

  We need to make sure we share the optimizers between the old and the new model
  as well so that optimizer state is not lost if the user is running fit
  multiple times.

  Args:
    model: Model to be replicated across Replicas
    mode: Which of fit/eval/predict is building the distributed network
    inputs: Input variables to be passed to the model
    targets: Target tensor to be passed to model.compile

  Returns:
    A new model with shared layers with the old model.
  """
  # Need to do imports here since we run into a circular dependency error.
  from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top
  from tensorflow.python.keras.engine import sequential  # pylint: disable=g-import-not-at-top

  # The internal clone functions are used so that `share_weights` does not
  # have to be part of the public API.
  is_sequential = isinstance(model, sequential.Sequential)
  if is_sequential:
    clone_fn = models._clone_sequential_model
  else:
    clone_fn = models._clone_functional_model
  updated_model = clone_fn(
      model, input_tensors=inputs, layer_fn=models.share_weights)
  if not is_sequential:
    # Callable losses added directly to a functional Model need to be carried
    # over here.
    updated_model._callable_losses = model._callable_losses

  # Recast all low precision outputs back to float32 since only the inputs
  # were cast to bfloat16 and not the targets. This preserves precision when
  # calculating the loss value.
  updated_model.outputs = [
      math_ops.cast(output, dtypes.float32)
      if output.dtype == dtypes.bfloat16 else output
      for output in updated_model.outputs
  ]

  if isinstance(targets, tuple):
    targets = nest.flatten(targets)

  if mode == ModeKeys.PREDICT and inputs is not None:  # TPU predict case
    _custom_compile_for_predict(updated_model)
    return updated_model

  updated_model.compile(
      model.optimizer,
      model.loss,
      metrics=metrics_module.clone_metrics(model._compile_metrics),
      loss_weights=model.loss_weights,
      sample_weight_mode=model.sample_weight_mode,
      weighted_metrics=metrics_module.clone_metrics(
          model._compile_weighted_metrics),
      target_tensors=targets)
  return updated_model
762
763
def _build_distributed_network(model, strategy, mode, inputs=None,
                               targets=None):
  """Builds a weight-sharing model on each replica and caches the result.

  `_build_network_on_replica` is run through
  `strategy.extended.call_for_each_replica` inside the Keras graph and the
  strategy scope; what it returns is stored on `model` for `mode` via
  `set_distributed_model`.

  Args:
    model: Model to build the per-replica models from.
    strategy: Distribution strategy supplying the scope and the replicas.
    mode: Mode the distributed model is built and cached for.
    inputs: Forwarded to `_build_network_on_replica`.
    targets: Forwarded to `_build_network_on_replica`.
  """
  with K.get_graph().as_default():
    with strategy.scope():
      per_replica_models = strategy.extended.call_for_each_replica(
          _build_network_on_replica,
          args=(model, mode, inputs, targets))
      set_distributed_model(model, mode, per_replica_models)
772
773
def _clone_and_build_model(model, mode, inputs=None, targets=None):
  """Clones `model` and compiles the clone with the original's settings.

  Args:
    model: Model to clone.
    mode: Mode the clone is built for; together with `inputs` it decides
      whether the clone gets the custom predict compile or a regular compile.
    inputs: Input tensors handed to `models.clone_model`.
    targets: Target tensors handed to `compile`; a tuple is flattened first.

  Returns:
    The cloned model.
  """
  # Imported lazily: a module-level import runs into a circular dependency
  # error.
  from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top
  cloned_model = models.clone_model(model, input_tensors=inputs)

  # A `TFOptimizer` is handed over as is; any other optimizer is re-created
  # from its own config.
  optimizer = model.optimizer
  if not isinstance(optimizer, optimizers.TFOptimizer):
    optimizer = optimizer.__class__.from_config(optimizer.get_config())

  # Only the inputs were cast to bfloat16, not the targets, so bfloat16
  # outputs are cast back to float32 to preserve precision when the loss
  # value is calculated.
  cloned_model.outputs = [
      math_ops.cast(out, dtypes.float32)
      if out.dtype == dtypes.bfloat16 else out
      for out in cloned_model.outputs
  ]

  if isinstance(targets, tuple):
    targets = nest.flatten(targets)

  if mode == ModeKeys.PREDICT and inputs is not None:  # TPU predict case
    _custom_compile_for_predict(cloned_model)
    return cloned_model

  cloned_model.compile(
      optimizer,
      model.loss,
      metrics=metrics_module.clone_metrics(model._compile_metrics),
      loss_weights=model.loss_weights,
      sample_weight_mode=model.sample_weight_mode,
      weighted_metrics=metrics_module.clone_metrics(
          model._compile_weighted_metrics),
      target_tensors=targets)
  return cloned_model
814
815
def clone_model_on_replicas(model, strategy, mode, inputs=None, targets=None):
  """Clones `model` on each replica and caches the result for `mode`.

  `_clone_and_build_model` is run through
  `strategy.extended.call_for_each_replica` inside the Keras graph and the
  strategy scope, and the result is stored via `set_distributed_model`. In
  train mode the result is additionally passed to
  `model._make_callback_model`.

  Args:
    model: Model to clone.
    strategy: Distribution strategy supplying the scope and the replicas.
    mode: Mode the distributed model is built and cached for.
    inputs: Forwarded to `_clone_and_build_model`.
    targets: Forwarded to `_clone_and_build_model`.
  """
  with K.get_graph().as_default():
    with strategy.scope():
      replica_models = strategy.extended.call_for_each_replica(
          _clone_and_build_model, args=(model, mode, inputs, targets))
      set_distributed_model(model, mode, replica_models)

  if mode == ModeKeys.TRAIN:
    model._make_callback_model(replica_models)
824
825
def _make_execution_function(model, mode):
  """Makes or reuses function to run one step of distributed model execution.

  Models distributed by cloning are delegated to
  `_make_execution_function_with_cloning`. Otherwise the function cached on
  `model` for `mode` is returned when there is one; if not, a new function is
  built, cached and returned.

  Args:
    model: Model to run.
    mode: Mode the execution function is made for.

  Returns:
    The execution function for `mode`.
  """
  if is_distributing_by_cloning(model):
    return _make_execution_function_with_cloning(model, mode)

  execution_function = get_distributed_function(model, mode)
  if not execution_function:
    execution_function = _make_execution_function_without_cloning(model, mode)
    set_distributed_function(model, mode, execution_function)
  return execution_function
838
839
def _make_execution_function_without_cloning(model, mode):
  """Creates a function to run one step of distributed model execution.

  Args:
    model: Model whose `_distribution_strategy` runs the step.
    mode: Mode selecting the per-replica step; any mode other than predict
      unwraps the outputs with `with_loss_tensor=True`.

  Returns:
    A function taking `input_fn`. When `model.run_eagerly` is set it is the
    plain distributed step; otherwise the step is wrapped in
    `def_function.function` and each of its outputs is converted with
    `.numpy()`.
  """
  strategy = model._distribution_strategy

  with strategy.scope():
    replica_step = _make_replica_execution_function(model, mode)

    def distributed_function(input_fn):
      """A single step of the distributed execution across replicas."""
      x, y, sample_weights = input_fn()
      # `Model.{train,test,predict}_on_batch` is called on every replica with
      # PerReplicas as arguments. Inside that call each PerReplica object
      # returns the value for its replica; the outputs are PerReplicas too.
      per_replica_outputs = strategy.run(
          replica_step, args=(x, y, sample_weights))
      # Reduce or pick the values to return out of the PerReplica outputs.
      return unwrap_outputs(
          strategy,
          per_replica_outputs,
          with_loss_tensor=(mode != ModeKeys.PREDICT))

    if model.run_eagerly:
      return distributed_function

    traced_function = def_function.function(distributed_function)

    def execution_function(input_fn):
      # `numpy` translates Tensors to values in Eager mode.
      return [out.numpy() for out in traced_function(input_fn)]

    return execution_function
869
870
def _make_replica_execution_function(model, mode):
  """Returns the per-replica callable that runs one step for `mode`.

  Args:
    model: Model providing `train_on_batch`, `test_on_batch` and
      `predict_on_batch`.
    mode: Mode selecting which of the batch methods is used.

  Returns:
    A callable accepting `(x, y, sample_weights)`.
  """

  def predict_on_batch(x, y=None, sample_weights=None):
    # Accepts the same arguments as the other batch methods, but only the
    # inputs are passed on.
    del y, sample_weights
    return model.predict_on_batch(x)

  if mode == ModeKeys.TRAIN:
    step_fn = model.train_on_batch
  elif mode == ModeKeys.TEST:
    step_fn = model.test_on_batch
  else:
    step_fn = predict_on_batch

  if mode != ModeKeys.PREDICT:
    # `reset_metrics` is set to False to maintain stateful metrics across
    # batch-level calls.
    return functools.partial(step_fn, reset_metrics=False)
  return step_fn
891
892
def _make_replicated_models_with_cloning(model, mode):
  """Builds the distributed model for `mode` on each replica.

  When `model._compile_distribution` is set the model is cloned via
  `clone_model_on_replicas`; otherwise `_build_distributed_network` is used.

  Args:
    model: Model to replicate.
    mode: Mode the distributed model is created for.
  """
  strategy = model._distribution_strategy
  if model._compile_distribution:
    build_fn = clone_model_on_replicas
  else:
    build_fn = _build_distributed_network
  build_fn(model, strategy, mode)
902
903
def _make_execution_function_with_cloning(model, mode):
  """Clones or re-uses models to run one step of distributed model execution.

  The execution function is cached on the distributed model for `mode` under
  the `_distributed_function` attribute, since creating distributed models
  and execution functions is expensive. The cached function is reused unless
  the distributed model's `_recompile_exec_function` flag is set, which
  signals that the sample_weight_mode was updated and the metrics have to be
  recompiled.

  Args:
    model: Model whose distributed model for `mode` is looked up or built.
    mode: Mode the execution function is made for.

  Returns:
    The execution function for `mode`.
  """
  distributed_model = get_distributed_model(model, mode)
  # TODO(b/134069401): Create a cache for the distributed model and exec
  # function that incorporates additional attributes to be part of the cache key
  # than just the mode.
  if distributed_model:
    # The lookup uses the same attribute name that is written at the end of
    # this function; checking a different name would make the cache never hit.
    cached_function = getattr(distributed_model, '_distributed_function', None)
    if (cached_function is not None and
        not getattr(distributed_model, '_recompile_exec_function', False)):
      return cached_function

  if not distributed_model:
    _make_replicated_models_with_cloning(model, mode)
    distributed_model = get_distributed_model(model, mode)
  assert distributed_model

  # Also create an execution function on that distributed model.
  if context.executing_eagerly():
    distributed_function = _make_eager_execution_function(model, mode)
  else:
    distributed_function = _make_graph_execution_function(model, mode)

  # We cache the distributed execution function on the model since creating
  # distributed models and execution functions are expensive.
  distributed_model._distributed_function = distributed_function
  distributed_model._recompile_exec_function = False
  return distributed_function
936
937
def _make_graph_execution_function(model, mode):
  """Makes function to run one step of distributed model in graph mode.

  Args:
    model: Model whose distributed model for `mode` supplies the per-replica
      execution functions.
    mode: Mode the function is made for; it is also part of the name given to
      the resulting function.

  Returns:
    A `K.function` built from the inputs, outputs, updates and session
    arguments gathered from all replicas.
  """

  def _per_replica_function(replica_model):
    step_fn = replica_model._make_execution_function(mode)
    return (step_fn.inputs, step_fn.outputs, step_fn.updates_op,
            step_fn.session_kwargs)

  strategy = model._distribution_strategy
  with strategy.scope():
    # Calling `_per_replica_function` creates the ops on each of the devices.
    grouped = strategy.extended.call_for_each_replica(
        _per_replica_function, args=(get_distributed_model(model, mode),))
    (grouped_inputs, grouped_outputs, grouped_updates,
     grouped_session_args) = grouped

    # Initialize the variables in the replicated model. This is necessary for
    # multi-worker training because on some workers, initialization is not
    # needed. This method does initialization or waiting for initialization
    # according to the context object of distribute coordinator.
    init_restore_or_wait_for_variables()

    # Unwrapping the per device values returned from `call_for_each_replica`
    # gives lists of values that can be used to construct a new function that
    # is composed of update ops on all the devices over which the model is
    # distributed.
    unwrapped = unwrap_values(
        strategy,
        grouped_inputs,
        grouped_outputs,
        grouped_updates,
        grouped_session_args,
        with_loss_tensor=(mode != ModeKeys.PREDICT))
    all_inputs, all_outputs, all_updates, all_session_args = unwrapped

    return K.function(
        all_inputs,
        all_outputs,
        updates=all_updates,
        name='distributed_{}_function'.format(mode),
        **all_session_args)
977
978
def _make_eager_execution_function(model, mode):
  """Makes function to run one step of distributed model eager execution.

  Args:
    model: Model whose distributed model for `mode` supplies the per-replica
      execution functions.
    mode: Mode the function is made for; it is also part of the name given to
      the resulting function.

  Returns:
    A `K.function` built from the inputs and outputs gathered from all
    replicas.
  """

  def _per_replica_function(replica_model):
    step_fn = replica_model._make_execution_function(mode)
    return (step_fn.inputs, step_fn.outputs)

  # NOTE(priyag): Try creating a new FuncGraph within DS scope instead of using
  # the global one.
  strategy = model._distribution_strategy
  global_graph = K.get_graph()

  with global_graph.as_default():
    with strategy.scope():
      # First the relevant portions of the model are gathered across all
      # replicas. `K._scratch_graph(global_graph)` signals to Keras that it
      # should not lift to a separate graph when creating the per-replica
      # functions.
      with K._scratch_graph(global_graph):
        # Calling `_per_replica_function` creates the ops on each of the
        # devices.
        grouped_inputs, grouped_outputs = (
            strategy.extended.call_for_each_replica(
                _per_replica_function,
                args=(get_distributed_model(model, mode),)))

        # Unwrapping the per device values returned from
        # `call_for_each_replica` gives lists of values that can be used to
        # construct a new function that is composed of inputs/outputs on all
        # the devices over which the model is distributed.
        (all_inputs, all_outputs, _unused_updates,
         _unused_session_args) = unwrap_values(
             strategy,
             grouped_inputs,
             grouped_outputs,
             with_loss_tensor=(mode != ModeKeys.PREDICT))

      # Finally, a joint Keras function is created; this one will be created
      # in a separate FuncGraph.
      return K.function(
          all_inputs,
          all_outputs,
          name='eager_distributed_{}_function'.format(mode))
1017
1018
def _copy_weights_to_distributed_model(original_model, mode):
  """Copies weights from original model to distributed models.

  Nothing is copied when the original model has no distribution strategy.
  Otherwise the weights of `original_model` are passed to `set_weights`
  together with the first model that `strategy.unwrap` returns for the
  distributed model of `mode`.

  Args:
    original_model: Model the weights are read from.
    mode: Mode whose distributed model receives the weights.
  """
  strategy = original_model._distribution_strategy
  distributed_model = get_distributed_model(original_model, mode)
  if not strategy:
    return

  source_weights = original_model.get_weights()
  target_model = strategy.unwrap(distributed_model)[0]
  set_weights(strategy, target_model, source_weights)
1029
1030
def _copy_weights_to_original_model(model, mode):
  """Copies weights from first distributed model back to original model.

  This only happens when `model` has a distribution strategy and `mode` is
  the train mode; in every other case the call does nothing.

  Args:
    model: Original model that receives the weights.
    mode: Mode whose distributed model the weights are read from.
  """
  if not (model._distribution_strategy and mode == ModeKeys.TRAIN):
    return

  distributed_model = get_distributed_model(model, mode)
  first_model = model._distribution_strategy.unwrap(distributed_model)[0]
  model.set_weights(first_model.get_weights())
1038
1039
def _per_replica_aggregate_batch(strategy, batch_outs, model, mode):
  """Aggregates the per-replica batch-level outputs from a distributed step.

  Aggregation only takes place in predict mode with a strategy. `batch_outs`
  is then read in consecutive groups of `strategy.num_replicas_in_sync`
  entries, one group per entry of `model.outputs`, and every group is
  flattened and joined with `concat_along_batch_dimension`.

  Args:
    strategy: Distribution strategy, or None.
    batch_outs: Outputs of one distributed step.
    model: Model whose `outputs` determine how many groups are read.
    mode: Mode the step was run in.

  Returns:
    A list with one aggregated entry per model output in predict mode with a
    strategy; `batch_outs` itself otherwise.
  """
  if strategy is None or mode != ModeKeys.PREDICT:
    return batch_outs

  num_replicas = strategy.num_replicas_in_sync
  total_batch_outs = []
  for output_index, _ in enumerate(model.outputs):
    start = output_index * num_replicas
    replica_outs = batch_outs[start:start + num_replicas]
    total_batch_outs.append(
        concat_along_batch_dimension(nest.flatten(replica_outs)))
  return total_batch_outs
1051
1052
def _reset_metrics(model):
  """Resets metrics on the first distributed model of every mode.

  Does nothing when `model` has no distribution strategy. Modes for which no
  distributed model is cached are skipped.

  Args:
    model: Model whose cached distributed models are looked up.
  """
  strategy = model._distribution_strategy
  if not strategy:
    return

  for mode in (ModeKeys.TRAIN, ModeKeys.TEST, ModeKeys.PREDICT):
    distributed_model = get_distributed_model(model, mode)
    if not distributed_model:
      continue
    strategy.unwrap(distributed_model)[0].reset_metrics()
1060
1061
def get_distributed_model(model, mode):
  """Returns the distributed model cached on `model` for `mode`, or None."""
  cache_key = _generate_cache_key(mode)
  return model._distributed_model_cache.get(cache_key)
1065
1066
def set_distributed_model(model, mode, distributed_model):
  """Stores `distributed_model` in the model cache of `model` for `mode`."""
  cache_key = _generate_cache_key(mode)
  model._distributed_model_cache.update({cache_key: distributed_model})
1070
1071
def get_distributed_function(model, mode):
  """Returns the function cached on `model` for `mode`, or None."""
  cache_key = _generate_cache_key(mode)
  return model._distributed_function_cache.get(cache_key)
1075
1076
def set_distributed_function(model, mode, distributed_function):
  """Stores `distributed_function` in the function cache for `mode`."""
  cache_key = _generate_cache_key(mode)
  model._distributed_function_cache.update({cache_key: distributed_function})
1080
1081
def _generate_cache_key(mode):
  """Returns the cache key for `mode`, which is the hash of `mode`."""
  return hash(mode)
1085
1086
@tf_contextlib.contextmanager
def distributed_scope(strategy, learning_phase):
  """Context manager entering the strategy scope and a learning phase scope.

  Args:
    strategy: Distribution strategy whose `scope()` is entered first.
    learning_phase: Value passed to `K.learning_phase_scope`.

  Yields:
    Nothing; the body runs with both scopes active.
  """
  with strategy.scope():
    with K.learning_phase_scope(learning_phase):
      yield
1091
1092
def is_current_worker_chief():
  """Returns the `is_chief` attribute of the current worker context."""
  worker_context = dc_context.get_current_worker_context()
  return worker_context.is_chief
1095
1096
def filter_distributed_callbacks(callbacks_list, model):
  """Filter Callbacks based on the worker context when running multi-worker.

  A warning is logged when none of the callbacks is a `ModelCheckpoint`.

  Args:
    callbacks_list: A list of `Callback` instances. A falsy value such as
      `None` is treated as an empty list.
    model: Keras model instance.

  Returns:
    The list of `Callback` instances that should be run on this worker: all
    of them on the chief worker, and only those whose `_chief_worker_only`
    attribute is falsy on any other worker.

  Raises:
    ValueError: If `model` is not in multi-worker mode.
  """
  if not model._in_multi_worker_mode():
    raise ValueError(
        'filter_distributed_callbacks() should only be called when Keras '
        'is in multi worker mode.')

  # From here on `callbacks_list` is always a list, so no later `None` check
  # is needed.
  callbacks_list = callbacks_list or []
  if not any(
      isinstance(c, callbacks.ModelCheckpoint) for c in callbacks_list):
    # TODO(rchao): Consider providing a ModelCheckpoint here if the user
    # fails to (possibly with tempfile directory).
    logging.warning('ModelCheckpoint callback is not provided. '
                    'Workers will need to restart training if any fails.')

  if is_current_worker_chief():
    return callbacks_list

  # Some Callbacks should only run on the chief worker.
  return [
      callback for callback in callbacks_list
      if not callback._chief_worker_only  # pylint: disable=protected-access
  ]
1129
1130
def _update_sample_weight_modes(model, mode, sample_weights):
  """Update sample_weight_mode of the distributed model.

  Only applies to models distributed by cloning. The distributed model for
  `mode` is created when it does not exist yet, and its
  `_recompile_exec_function` flag is set to whether any training endpoint of
  `model` reports a sample weights mismatch.

  Args:
    model: Original model.
    mode: Mode whose distributed model is updated.
    sample_weights: A tuple of 1 list where the number of elements in the list
      is equal to the number of replicas in sync. The per-replica models are
      only updated when that list is non-empty and contains no None.
  """
  if not is_distributing_by_cloning(model):
    return

  distributed_model = get_distributed_model(model, mode)
  if not distributed_model:
    _make_replicated_models_with_cloning(model, mode)
    distributed_model = get_distributed_model(model, mode)

  mismatches = [
      endpoint.sample_weights_mismatch()
      for endpoint in model._training_endpoints
  ]
  distributed_model._recompile_exec_function = any(mismatches)

  if not sample_weights:
    return

  replica_models = flatten_per_replica_values(
      model._distribution_strategy, distributed_model)
  per_replica_weights = sample_weights[0]
  if per_replica_weights and None not in per_replica_weights:
    for replica_model, weight in zip(replica_models, per_replica_weights):
      replica_model._update_sample_weight_modes(sample_weights=[weight])
1150
1151
def concat_along_batch_dimension(outputs):
  """Concats prediction outputs along the batch dimension.

  The type of the first element decides how all of `outputs` are joined:
  `sparse_ops.sparse_concat_v2` for a `SparseTensor`, `array_ops.concat` for
  a `RaggedTensor`, and `np.concatenate` for anything else.

  Args:
    outputs: Non-empty sequence of prediction outputs.

  Returns:
    The concatenated outputs.
  """
  first_output = outputs[0]
  if isinstance(first_output, sparse_tensor.SparseTensor):
    return sparse_ops.sparse_concat_v2(axis=0, sp_inputs=outputs)
  if isinstance(first_output, ragged_tensor.RaggedTensor):
    return array_ops.concat(outputs, axis=0)
  return np.concatenate(outputs)
1159