1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""TPU Strategy."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import copy
23
24from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
25from tensorflow.python.distribute import device_util
26from tensorflow.python.distribute import distribute_lib
27from tensorflow.python.distribute import input_lib
28from tensorflow.python.distribute import numpy_dataset
29from tensorflow.python.distribute import reduce_util
30from tensorflow.python.distribute import values
31from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver
32from tensorflow.python.eager import context
33from tensorflow.python.eager import tape
34from tensorflow.python.framework import constant_op
35from tensorflow.python.framework import device as tf_device
36from tensorflow.python.framework import dtypes
37from tensorflow.python.framework import ops
38from tensorflow.python.framework import tensor_util
39from tensorflow.python.ops import array_ops
40from tensorflow.python.ops import control_flow_ops
41from tensorflow.python.ops import math_ops
42from tensorflow.python.ops import variable_scope as vs
43from tensorflow.python.tpu import device_assignment as device_assignment_lib
44from tensorflow.python.tpu import tpu
45from tensorflow.python.tpu import tpu_strategy_util
46from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
47from tensorflow.python.tpu import training_loop
48from tensorflow.python.tpu.ops import tpu_ops
49from tensorflow.python.util import nest
50from tensorflow.python.util.tf_export import tf_export
51
52
def get_tpu_system_metadata(tpu_cluster_resolver):
  """Queries TPU system metadata through the given cluster resolver.

  Args:
    tpu_cluster_resolver: Object providing `master()` and `cluster_spec()`;
      callers in this file pass a `TPUClusterResolver`.

  Returns:
    Whatever `tpu_system_metadata_lib._query_tpu_system_metadata` returns for
    the resolver's master. The query is issued with `query_topology=False`.
  """
  master_address = tpu_cluster_resolver.master()

  # The cluster spec may be empty/None, in which case no cluster definition is
  # forwarded to the metadata query.
  spec = tpu_cluster_resolver.cluster_spec()
  if spec:
    cluster_definition = spec.as_cluster_def()
  else:
    cluster_definition = None

  # pylint: disable=protected-access
  return tpu_system_metadata_lib._query_tpu_system_metadata(
      master_address,
      cluster_def=cluster_definition,
      query_topology=False)
67
68
69# TODO(jhseu): Deduplicate with MirroredStrategy?
def _create_tpu_mirrored_variable(
    strategy, device_map, logical_device, real_mirrored_creator,
    *args, **kwargs):
  """Creates a `TPUMirroredVariable` wrapping one variable per device.

  Args:
    strategy: Strategy object handed to the `TPUMirroredVariable` constructor.
    device_map: Device map used to resolve `logical_device` into the actual
      devices and handed to the `TPUMirroredVariable` constructor.
    logical_device: Logical device id the variable belongs to.
    real_mirrored_creator: Callable invoked as
      `real_mirrored_creator(devices, *args, **kwargs)`; it must return the
      list of per-device component variables.
    *args: Positional arguments forwarded to `real_mirrored_creator`.
    **kwargs: Keyword arguments forwarded to `real_mirrored_creator` after
      "collections" has been replaced by an empty list and "aggregation" and
      "caching_device" have been removed.

  Returns:
    The newly created `values.TPUMirroredVariable`.

  Raises:
    ValueError: If the requested aggregation mode is not one of NONE, SUM,
      MEAN or ONLY_FIRST_REPLICA.
  """
  # Figure out what collections this variable should be added to.
  # We'll add the TPUMirroredVariable to those collections instead.
  var_collections = kwargs.pop("collections", None)
  if var_collections is None:
    var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  else:
    # Work on a copy: TRAINABLE_VARIABLES may be appended below and the
    # caller's own list must not be modified as a side effect.
    var_collections = list(var_collections)
  kwargs["collections"] = []

  # TODO(jhseu): Should we have different behavior for different
  # synchronization settings?

  # Get aggregation value
  # TODO(jhseu): Support aggregation in a replica context.
  aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
  if aggregation not in [
      vs.VariableAggregation.NONE,
      vs.VariableAggregation.SUM,
      vs.VariableAggregation.MEAN,
      vs.VariableAggregation.ONLY_FIRST_REPLICA,
  ]:
    # "name" is optional; use .get() so that a missing name does not turn this
    # ValueError into a KeyError.
    raise ValueError("Invalid variable aggregation mode: {} for variable: {}"
                     .format(aggregation, kwargs.get("name")))

  # Ignore user-specified caching device, not needed for mirrored variables.
  kwargs.pop("caching_device", None)

  # TODO(josh11b,apassos): It would be better if variable initialization
  # was never recorded on the tape instead of having to do this manually
  # here.
  with tape.stop_recording():
    devices = device_map.logical_to_actual_devices(logical_device)
    value_list = real_mirrored_creator(devices, *args, **kwargs)
    result = values.TPUMirroredVariable(
        strategy, device_map, value_list, aggregation,
        logical_device=logical_device)

  if not (context.executing_eagerly() or ops.inside_function()):
    g = ops.get_default_graph()
    # If "trainable" is True, next_creator() will add the member variables
    # to the TRAINABLE_VARIABLES collection, so we manually remove
    # them and replace with the MirroredVariable. We can't set
    # "trainable" to False for next_creator() since that causes functions
    # like implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      trainable = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
      for component in value_list:
        # Remove by identity and tolerate a component that was never added to
        # the collection, instead of failing with ValueError from
        # list.remove().
        for index, candidate in enumerate(trainable):
          if candidate is component:
            del trainable[index]
            break
    g.add_to_collections(var_collections, result)
  return result
122
123
@tf_export("distribute.experimental.TPUStrategy")
class TPUStrategy(distribute_lib.DistributionStrategy):
  """Distribution strategy that places replicas on TPU devices.

  All strategy behaviour is delegated to a `TPUExtended` instance, which is
  created in the constructor.
  """

  def __init__(self,
               tpu_cluster_resolver=None,
               steps_per_run=None,
               device_assignment=None):
    """Builds the strategy together with its `TPUExtended` implementation.

    Args:
      tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
          which provides information about the TPU cluster.
      steps_per_run: Number of steps to run on device before returning to the
          host. Note that this can have side-effects on performance, hooks,
          metrics, summaries etc.
          This parameter is only used when Distribution Strategy is used with
          estimator or keras.
      device_assignment: Optional `tf.contrib.tpu.DeviceAssignment` to specify
          the placement of replicas on the TPU cluster. Currently only supports
          the usecase of using a single core within a TPU cluster.
    """
    extended = TPUExtended(
        self, tpu_cluster_resolver, steps_per_run, device_assignment)
    super(TPUStrategy, self).__init__(extended)

  @property
  def steps_per_run(self):
    """DEPRECATED: use .extended.steps_per_run instead."""
    return self._extended.steps_per_run

  # TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this
  # can use the default implementation.
  # This implementation runs a single step. It does not use infeed or outfeed.
  def experimental_run_v2(self, fn, args=(), kwargs=None):
    """See base class."""
    eager_outside_function = (
        context.executing_eagerly() and not ops.inside_function())
    if eager_outside_function:
      raise NotImplementedError(
          "Eager mode not supported in TPUStrategy outside TF functions.")

    kwargs = {} if kwargs is None else kwargs

    # One-element holder so the nested function can hand the structure of
    # `fn`'s return value back to this scope.
    fn_output = [None]

    def replicated_fn(replica_id, replica_args, replica_kwargs):
      """Runs `fn` under a replica context that carries `replica_id`."""
      with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id):
        fn_output[0] = fn(*replica_args, **replica_kwargs)
      return fn_output[0]

    # One entry per replica: [replica id tensor, args, kwargs].
    replicate_inputs = [
        [constant_op.constant(replica, dtype=dtypes.int32),
         values.select_replica(replica, args),
         values.select_replica(replica, kwargs)]
        for replica in range(self.num_replicas_in_sync)]

    with self.scope():
      raw_outputs = tpu.replicate(replicated_fn, replicate_inputs)

    # Workaround for `tpu.replicate` behaviour when single `Tensor` returned:
    # re-pack every replica's flattened outputs into the structure that `fn`
    # produced.
    structure = fn_output[0]
    per_replica_outputs = []
    for replica_output in raw_outputs:
      per_replica_outputs.append(
          nest.pack_sequence_as(structure, nest.flatten(replica_output)))

    # pylint: disable=protected-access
    return values.regroup(self.extended._device_map, per_replica_outputs)
190
191
class TPUExtended(distribute_lib.DistributionStrategyExtended):
  """Implementation of TPUStrategy."""

  def __init__(self,
               container_strategy,
               tpu_cluster_resolver=None,
               steps_per_run=None,
               device_assignment=None):
    """Queries TPU metadata and sets up devices, device map and input workers.

    Args:
      container_strategy: Strategy object owning this instance; forwarded to
        the base class constructor.
      tpu_cluster_resolver: Cluster resolver used to query the TPU system
        metadata and the first TPU host device. Defaults to
        `TPUClusterResolver("")`.
      steps_per_run: Stored as `self.steps_per_run`. Defaults to 1.
      device_assignment: Optional `DeviceAssignment`. Only an assignment with
        one replica and one core per replica, located at `[0, 0, 0]`, is
        accepted.

    Raises:
      ValueError: If `device_assignment` is given but is not the supported
        single-core, single-replica assignment.
    """
    super(TPUExtended, self).__init__(container_strategy)

    if tpu_cluster_resolver is None:
      tpu_cluster_resolver = TPUClusterResolver("")

    if steps_per_run is None:
      # TODO(frankchn): Warn when we are being used by DS/Keras and this is
      # not specified.
      steps_per_run = 1

    self._tpu_cluster_resolver = tpu_cluster_resolver
    self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver)
    self._device_assignment = device_assignment

    # Device assignment is currently only supported for 1 core case.
    # Test against None explicitly, like the properties of this class do, so
    # that validation does not depend on the truthiness of the object.
    if self._device_assignment is not None:
      assignment = self._device_assignment
      assert isinstance(assignment, device_assignment_lib.DeviceAssignment)
      # The three conditions are evaluated in order and short-circuit.
      if (assignment.num_replicas != 1 or
          assignment.num_cores_per_replica != 1 or
          not all(assignment.core_assignment[0][0] == [0, 0, 0])):
        raise ValueError("Device assignment is only supported for a single "
                         "core single replica case currently.")

    # TODO(jhseu): Switch to DeviceAssignment to support pods and model
    # parallelism.
    self._tpu_devices = [d.name for d in self._tpu_metadata.devices
                         if "device:TPU:" in d.name]

    self._host_device = tpu_strategy_util.get_first_tpu_host_device(
        self._tpu_cluster_resolver)

    # Only create variables for the number of replicas we're running.
    self._tpu_devices = self._tpu_devices[:self._num_replicas_in_sync]
    self._device_map = values.ReplicaDeviceMap(self._tpu_devices)

    # Preload the data onto the TPUs.
    # Group the TPU devices by their host CPU device, keeping first-seen order.
    input_worker_devices = collections.OrderedDict()
    for tpu_device in self._tpu_devices:
      host_device = _get_host_for_device(tpu_device)
      input_worker_devices.setdefault(host_device, [])
      input_worker_devices[host_device].append(tpu_device)
    self._input_workers = input_lib.InputWorkers(
        self._device_map, tuple(input_worker_devices.items()))

    # TODO(sourabhbajaj): Remove this once performance of running one step
    # at a time is comparable to multiple steps.
    self.steps_per_run = steps_per_run
    self._require_static_shapes = True

  def _validate_colocate_with_variable(self, colocate_with_variable):
    """Delegates to `values.validate_colocate_tpu_variable`."""
    values.validate_colocate_tpu_variable(colocate_with_variable, self)

  def _make_dataset_iterator(self, dataset):
    """Make iterators for each of the TPU hosts."""
    return input_lib.DatasetIterator(dataset, self._input_workers,
                                     self._num_replicas_in_sync)

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    """Builds an `InputFunctionIterator` with one input context per worker.

    Args:
      input_fn: Input function handed to `input_lib.InputFunctionIterator`.
      replication_mode: Not used by this implementation.

    Returns:
      An `input_lib.InputFunctionIterator`.
    """
    del replication_mode  # Unused.
    num_workers = self._input_workers.num_workers
    input_contexts = [
        distribute_lib.InputContext(
            num_input_pipelines=num_workers,
            input_pipeline_id=pipeline_id,
            num_replicas_in_sync=self._num_replicas_in_sync)
        for pipeline_id in range(num_workers)]
    return input_lib.InputFunctionIterator(
        input_fn, self._input_workers, input_contexts)

  def _experimental_make_numpy_dataset(self, numpy_input, session):
    """Creates a numpy dataset placed on the TPU host device."""
    return numpy_dataset.one_host_numpy_dataset(
        numpy_input, numpy_dataset.SingleDevice(self._host_device),
        session)

  # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
  # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have
  # a mechanism to infer the outputs of `fn`. Pending b/110550782.
  def _experimental_run_steps_on_iterator(
      self, fn, multi_worker_iterator, iterations, initial_loop_values=None):
    """Builds the ops that run `fn` on every replica for the given iterator.

    Args:
      fn: Callable invoked as `fn(ctx, inputs)` inside `tpu.replicate`, where
        `ctx` is an `input_lib.MultiStepContext`.
      multi_worker_iterator: Iterator providing `output_shapes` and
        `get_next()`.
      iterations: Number of iterations handed to `training_loop.repeat`. Only
        used when `self.steps_per_run` is not 1.
      initial_loop_values: Optional structure of initial loop values. It is
        flattened and repeated once per replica.

    Returns:
      The `MultiStepContext` with `run_op` and the last step outputs set.

    Raises:
      ValueError: If any of the iterator's output shapes is not fully defined.
    """
    output_shapes = multi_worker_iterator.output_shapes
    shapes = nest.flatten(output_shapes)
    if any(not s.is_fully_defined() for s in shapes):
      raise ValueError(
          "TPU currently requires fully defined shapes. Either use "
          "set_shape() on the input tensors or use "
          "dataset.batch(..., drop_remainder=True).")

    # Wrap `fn` for repeat.
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)
    ctx = input_lib.MultiStepContext()

    def run_fn(inputs):
      """Single step on the TPU device."""
      fn_result = fn(ctx, inputs)
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      if flat_last_step_outputs:
        with ops.control_dependencies([fn_result]):
          return [array_ops.identity(f) for f in flat_last_step_outputs]
      else:
        return fn_result

    def rewrite_fn(*args):
      """The rewritten step fn running on TPU."""
      del args

      per_replica_inputs = multi_worker_iterator.get_next()
      replicate_inputs = []
      for replica_id in range(self._num_replicas_in_sync):
        # The lambda is consumed within the same loop iteration, so capturing
        # the loop variable is safe here.
        select_replica = lambda x: values.select_replica(replica_id, x)  # pylint: disable=cell-var-from-loop
        replicate_inputs.append((nest.map_structure(
            select_replica, per_replica_inputs),))

      replicate_outputs = tpu.replicate(run_fn, replicate_inputs)

      # If run_fn has tensor outputs, tpu.replicate returns a list of list. We
      # will flatten it in this case. If run_fn has no tensor outputs,
      # tpu.replicate returns a list of no_ops, we will keep the output as it
      # is.
      if isinstance(replicate_outputs[0], list):
        replicate_outputs = nest.flatten(replicate_outputs)

      return replicate_outputs

    # TODO(sourabhbajaj): The input to while loop should be based on the
    # output type of the step_fn
    assert isinstance(initial_loop_values, list)
    initial_loop_values = initial_loop_values * self._num_replicas_in_sync

    # We capture the control_flow_context at this point, before we run `fn`
    # inside a while_loop and TPU replicate context. This is useful in cases
    # where we might need to exit these contexts and get back to the outer
    # context to do some things, for e.g. create an op which should be
    # evaluated only once at the end of the loop on the host. One such usage
    # is in creating metrics' value op.
    self._outer_control_flow_context = (
        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access
    try:
      # Put the while loop op on TPU host 0.
      with ops.device(self._host_device):
        # NOTE(review): the single-step path is selected by `steps_per_run`,
        # not by `iterations` -- confirm that callers keep the two in sync.
        if self.steps_per_run == 1:
          replicate_outputs = rewrite_fn()
        else:
          replicate_outputs = training_loop.repeat(iterations, rewrite_fn,
                                                   initial_loop_values)
    finally:
      # Always drop the captured context, also when graph construction fails,
      # so that no stale context is left behind on the strategy.
      del self._outer_control_flow_context

    ctx.run_op = control_flow_ops.group(replicate_outputs)

    if isinstance(replicate_outputs, list):
      # Filter out any ops from the outputs, typically this would be the case
      # when there were no tensor outputs.
      last_step_tensor_outputs = [
          x for x in replicate_outputs if not isinstance(x, ops.Operation)
      ]

      # Outputs are currently of the structure (flattened)
      # [output0_device0, output1_device0, output2_device0,
      #  output0_device1, output1_device1, output2_device1,
      #  ...]
      # Convert this to the following structure instead: (grouped by output)
      # [[output0_device0, output0_device1],
      #  [output1_device0, output1_device1],
      #  [output2_device0, output2_device1]]
      output_num = len(last_step_tensor_outputs) // self._num_replicas_in_sync
      last_step_tensor_outputs = [
          last_step_tensor_outputs[i::output_num] for i in range(output_num)
      ]
    else:
      # no tensors returned.
      last_step_tensor_outputs = []

    _set_last_step_outputs(ctx, last_step_tensor_outputs)
    return ctx

  def _call_for_each_replica(self, fn, args, kwargs):
    """Calls `fn(*args, **kwargs)` once inside a `_TPUReplicaContext`."""
    # TODO(jhseu): Consider making it so call_for_each_replica implies that
    # we're in a tpu.rewrite(), and update TPUMirroredVariable accordingly.
    with _TPUReplicaContext(self._container_strategy()):
      return fn(*args, **kwargs)

  def _experimental_initialize_system(self):
    """Experimental method added to be used by Estimator.

    This is a private method only to be used by Estimator. Other frameworks
    should directly be calling `tf.contrib.distribute.initialize_tpu_system`
    """
    tpu_strategy_util.initialize_tpu_system(self._tpu_cluster_resolver)

  def _create_variable(self, next_creator, *args, **kwargs):
    """Create a TPUMirroredVariable. See `DistributionStrategy.scope`."""
    colocate_with = kwargs.pop("colocate_with", None)
    if colocate_with is None:
      device_map = self._device_map
      logical_device = 0  # TODO(josh11b): Get logical device from scope here.
    elif isinstance(colocate_with, numpy_dataset.SingleDevice):
      # Single-device colocation: create a plain variable on that device
      # instead of a mirrored one.
      with ops.device(colocate_with.device):
        return next_creator(*args, **kwargs)
    else:
      device_map = colocate_with.device_map
      logical_device = colocate_with.logical_device

    def _real_mirrored_creator(devices, *args, **kwargs):  # pylint: disable=g-missing-docstring
      value_list = []
      for i, d in enumerate(devices):
        with ops.device(d):
          if i > 0:
            # Give replicas meaningful distinct names:
            var0name = value_list[0].name.split(":")[0]
            # We append a / to variable names created on replicas with id > 0 to
            # ensure that we ignore the name scope and instead use the given
            # name as the absolute name of the variable.
            kwargs["name"] = "%s/replica_%d/" % (var0name, i)
            # Initialize replicas with the same value:
            if context.executing_eagerly() or ops.inside_function():
              with ops.init_scope():
                kwargs["initial_value"] = array_ops.identity(
                    value_list[0].value())
            else:
              # `device=d` binds the current device as a default argument so
              # the function does not depend on the loop variable later on.
              def initial_value_fn(device=d):
                with ops.device(device):
                  return array_ops.identity(value_list[0].initial_value)
              kwargs["initial_value"] = initial_value_fn
          with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
            v = next_creator(*args, **kwargs)
          assert not isinstance(v, values.TPUMirroredVariable)
          value_list.append(v)
      return value_list

    return _create_tpu_mirrored_variable(
        self._container_strategy(), device_map, logical_device,
        _real_mirrored_creator, *args, **kwargs)

  def _reduce_to(self, reduce_op, value, destinations):
    """Reduces `value` with `reduce_op` and places it on `destinations`.

    Inside a TPU context the reduction is a cross-replica sum (scaled first
    for MEAN). Outside of it, distributed values are summed on the TPU host
    device and copied to the destination device if that is a different one.

    Args:
      reduce_op: A `reduce_util.ReduceOp`.
      value: The value to reduce.
      destinations: Where to place the result; must resolve to exactly one
        device when `value` is a `values.DistributedValues`.

    Returns:
      The reduced value.

    Raises:
      NotImplementedError: Inside a TPU context, if `reduce_op` is neither
        SUM nor MEAN.
      ValueError: If `destinations` resolves to more than one device.
    """
    if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
      if reduce_op == reduce_util.ReduceOp.MEAN:
        # TODO(jhseu):  Revisit once we support model-parallelism.
        value *= (1. / self._num_replicas_in_sync)
      elif reduce_op != reduce_util.ReduceOp.SUM:
        raise NotImplementedError(
            "Currently only support sum & mean in TPUStrategy.")
      return tpu_ops.cross_replica_sum(value)

    if not isinstance(value, values.DistributedValues):
      # This function handles reducing values that are not PerReplica or
      # Mirrored values. For example, the same value could be present on all
      # replicas in which case `value` would be a single value or value could
      # be 0.
      return cross_device_ops_lib.reduce_non_distributed_value(
          reduce_op, self._device_map, value, destinations)

    devices = cross_device_ops_lib.get_devices_from(destinations)
    if len(devices) != 1:
      raise ValueError("Multiple devices are not supported for TPUStrategy")

    # Always performs the reduction on the TPU host.
    with ops.device(self._host_device):
      output = math_ops.add_n(value.values)
      if reduce_op == reduce_util.ReduceOp.MEAN:
        output *= (1. / len(value.values))

    # If necessary, copy to requested destination.
    dest_canonical = device_util.canonicalize(devices[0])
    host_canonical = device_util.canonicalize(self._host_device)

    if dest_canonical != host_canonical:
      with ops.device(devices[0]):
        output = array_ops.identity(output)

    return output

  def _update(self, var, fn, args, kwargs, group):
    """Applies `fn` to `var`, per component when outside a TPU context.

    Args:
      var: The `values.TPUMirroredVariable` to update.
      fn: Update function. It is called as `fn(var, *args, **kwargs)` inside a
        TPU context, and once per component variable otherwise.
      args: Positional arguments for `fn`.
      kwargs: Keyword arguments for `fn`.
      group: Inside a TPU context, whether to return the result of `fn`
        directly (True) or wrapped in a one-element tuple (False). Outside of
        it the flag is forwarded to `values.update_regroup`.

    Returns:
      The result of the update(s) as described for `group`.
    """
    assert isinstance(var, values.TPUMirroredVariable)
    if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
      if group:
        return fn(var, *args, **kwargs)
      else:
        return (fn(var, *args, **kwargs),)

    # Otherwise, we revert to MirroredStrategy behavior and update each variable
    # directly.
    updates = []
    for i, (d, v) in enumerate(zip(var.devices, var.values)):
      name = "update_%d" % i
      with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name):
        # If args and kwargs are not mirrored, the value is returned as is.
        updates.append(fn(v,
                          *values.select_device_mirrored(d, args),
                          **values.select_device_mirrored(d, kwargs)))
    return values.update_regroup(self, self._device_map, updates, group)

  def read_var(self, var):
    """Returns `var.read_value()` for a `TPUMirroredVariable`."""
    assert isinstance(var, values.TPUMirroredVariable)
    return var.read_value()

  def _local_results(self, val):
    """Returns the per-device components of `val` as a tuple-like sequence."""
    if isinstance(val, values.DistributedValues):
      # Return in a deterministic order.
      return tuple(val.get(device=d) for d in sorted(val.devices))
    elif isinstance(val, list):
      # TODO(josh11b): We need to remove this case; per device values should
      # be represented using a PerReplica wrapper instead of a list with
      # one entry per device.
      return tuple(val)
    elif isinstance(val, values.TPUMirroredVariable):
      # pylint: disable=protected-access
      if values._enclosing_tpu_context() is not None:
        return (val,)
      return val.values
    return (val,)

  def value_container(self, value):
    """Returns `value` unchanged."""
    return value

  def _broadcast_to(self, tensor, destinations):
    """Returns `tensor` unchanged; `destinations` is ignored."""
    del destinations
    return tensor

  @property
  def num_hosts(self):
    """Number of hosts, from TPU metadata or from the device assignment."""
    if self._device_assignment is None:
      return self._tpu_metadata.num_hosts

    # Count the distinct host devices used by the assigned replicas.
    return len({self._device_assignment.host_device(r)
                for r in range(self._device_assignment.num_replicas)})

  @property
  def num_replicas_per_host(self):
    """Number of replicas per host (see the TODO for the limitations)."""
    if self._device_assignment is None:
      return self._tpu_metadata.num_of_cores_per_host

    # TODO(sourabhbajaj): Remove this method once we use inputs and remove
    # infeed, as the computation of num_replicas_per_host is not a constant
    # when using device_assignment. This is a temporary workaround to support
    # StatefulRNN as everything is 1 in that case.
    # This method needs to take host_id as input for correct computation.
    max_models_per_host = (self._tpu_metadata.num_of_cores_per_host //
                           self._device_assignment.num_cores_per_replica)
    models_per_host = min(self._device_assignment.num_replicas,
                          max_models_per_host)
    return models_per_host * self._device_assignment.num_cores_per_replica

  @property
  def _num_replicas_in_sync(self):
    """Core count from TPU metadata, or replicas * cores from the assignment."""
    if self._device_assignment is None:
      return self._tpu_metadata.num_cores
    return (self._device_assignment.num_replicas *
            self._device_assignment.num_cores_per_replica)

  @property
  def experimental_between_graph(self):
    """Always False for this strategy."""
    return False

  @property
  def experimental_should_init(self):
    """Always True for this strategy."""
    return True

  @property
  def should_checkpoint(self):
    """Always True for this strategy."""
    return True

  @property
  def should_save_summary(self):
    """Always True for this strategy."""
    return True

  @property
  def worker_devices(self):
    """The TPU devices used by this strategy."""
    return self._tpu_devices

  @property
  def parameter_devices(self):
    """The TPU devices used by this strategy (same as `worker_devices`)."""
    return self._tpu_devices

  def non_slot_devices(self, var_list):
    """Returns the TPU host device; `var_list` is not used."""
    return self._host_device

  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    """Runs `fn(*args, **kwargs)` on the TPU host device.

    Args:
      colocate_with: Ignored.
      fn: Function to run.
      args: Positional arguments for `fn`.
      kwargs: Keyword arguments for `fn`.
      group: If True the result of `fn` is returned as is; otherwise
        `_local_results` is mapped over the result structure.

    Returns:
      The (possibly unwrapped) result of `fn`.
    """
    del colocate_with
    with ops.device(self._host_device), distribute_lib.UpdateContext(
        self._host_device):
      result = fn(*args, **kwargs)
      if group:
        return result
      else:
        return nest.map_structure(self._local_results, result)

  def _configure(self,
                 session_config=None,
                 cluster_spec=None,
                 task_type=None,
                 task_id=None):
    """Updates `session_config` in place; the other arguments are ignored."""
    del cluster_spec, task_type, task_id
    if session_config:
      session_config.CopyFrom(self._update_config_proto(session_config))

  def _update_config_proto(self, config_proto):
    """Returns a deep copy of `config_proto` adjusted for this strategy.

    The copy has `isolate_session_state` set to True and, if the cluster
    resolver provides a cluster spec, `cluster_def` set from that spec.
    """
    updated_config = copy.deepcopy(config_proto)
    updated_config.isolate_session_state = True
    cluster_spec = self._tpu_cluster_resolver.cluster_spec()
    if cluster_spec:
      updated_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
    return updated_config

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """`make_dataset_iterator` and `make_numpy_iterator` use global batch size.

    `make_input_fn_iterator` assumes per-replica batching.

    Returns:
      Boolean.
    """
    return True
622
623
class _TPUReplicaContext(distribute_lib.ReplicaContext):
  """Replica context used while running under `TPUStrategy`."""

  # TODO(sourabhbajaj): Call for each replica should be updating this.
  # TODO(b/118385803): Always properly initialize replica_id.
  def __init__(self, strategy, replica_id_in_sync_group=None):
    """Initializes the context, using a constant replica id 0 if none is given.

    Args:
      strategy: The strategy this context belongs to.
      replica_id_in_sync_group: Optional replica id. When omitted, an int32
        constant tensor with value 0 is used.
    """
    replica_id = replica_id_in_sync_group
    if replica_id is None:
      replica_id = constant_op.constant(0, dtypes.int32)
    super(_TPUReplicaContext, self).__init__(
        strategy, replica_id_in_sync_group=replica_id)

  @property
  def devices(self):
    """One-element tuple with the device of the current replica."""
    distribute_lib.require_replica_context(self)
    static_replica_id = tensor_util.constant_value(
        self._replica_id_in_sync_group)

    if static_replica_id is not None:
      worker_devices = self._strategy.extended.worker_devices
      return (worker_devices[static_replica_id],)

    # Non-constant `Tensor` inside `tpu.replicate`.
    # TODO(cjfj): Return other devices when model parallelism is supported.
    return (tpu.core(0),)
646
647
def _get_host_for_device(device):
  """Returns the CPU:0 device string with the job/replica/task of `device`.

  Args:
    device: Device string parseable by `tf_device.DeviceSpec.from_string`.

  Returns:
    String form of a `DeviceSpec` that copies job, replica and task from
    `device` and uses device type "CPU" with index 0.
  """
  parsed = tf_device.DeviceSpec.from_string(device)
  host_spec = tf_device.DeviceSpec(
      job=parsed.job,
      replica=parsed.replica,
      task=parsed.task,
      device_type="CPU",
      device_index=0)
  return host_spec.to_string()
653
654
def _set_last_step_outputs(ctx, last_step_tensor_outputs):
  """Sets the last step outputs on the given context.

  Args:
    ctx: Context object providing `last_step_outputs`,
      `_last_step_outputs_reduce_ops` and `_set_last_step_outputs`.
    last_step_tensor_outputs: Flat sequence of outputs that is packed back
      into the structure of `ctx.last_step_outputs`.
  """
  # pylint: disable=protected-access
  # Restore the original structure of `last_step_outputs`.
  packed_outputs = nest.pack_sequence_as(
      ctx.last_step_outputs, last_step_tensor_outputs)

  for output_name, reduce_op in ctx._last_step_outputs_reduce_ops.items():
    per_replica_values = packed_outputs[output_name]
    # Without a reduce op the full list of values is kept.
    # TODO(josh11b): If reduce_op is NONE, we should return a PerReplica
    # value.
    if reduce_op is None:
      continue
    # For outputs that have already been reduced, take the first value from
    # the list as each value should be the same.
    # TODO(priyag): Should this return the element or a list with 1 element
    packed_outputs[output_name] = per_replica_values[0]

  ctx._set_last_step_outputs(packed_outputs)
  # pylint: enable=protected-access
673