1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15
16"""Operations for TPUs."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.ops import array_ops
25# pylint: disable=wildcard-import,unused-import
26from tensorflow.python.ops import gen_tpu_ops
27from tensorflow.python.ops.gen_tpu_ops import *
28# pylint: enable=wildcard-import,unused-import
29from tensorflow.python.platform import tf_logging as logging
30from tensorflow.python.tpu import tpu_function
31from tensorflow.python.util.tf_export import tf_export
32
33
def _create_default_group_assignment():
  """Builds a group assignment that puts every replica into a single group.

  The replica count is read from the current TPU context. If the context has
  no `number_of_shards` set, a warning is logged and one replica is assumed.

  Returns:
    A one-element list holding `[0, 1, ..., number_of_shards - 1]`.
  """
  shard_count = tpu_function.get_tpu_context().number_of_shards
  if shard_count is None:
    logging.warning(
        "cross_replica_sum should be used within a tpu_shard_context, but "
        "got unset number_of_shards. Assuming 1.")
    shard_count = 1
  all_replicas = list(range(shard_count))
  return [all_replicas]
43
44
def all_to_all(x,
               concat_dimension,
               split_dimension,
               split_count,
               group_assignment=None,
               name=None):
  """Exchanges data across TPU replicas.

  Args:
    x: The local tensor.
    concat_dimension: The dimension number to concatenate along.
    split_dimension: The dimension number to split along.
    split_count: The number of splits. It must equal the size of each
      subgroup, i.e. `group_assignment.get_shape()[1]`.
    group_assignment: Optional 2-D int32 list with shape
      [num_groups, num_replicas_per_group]; `group_assignment[i]` holds the
      replica ids of the i-th subgroup. When omitted, a single group of all
      replicas in the current TPU context is used.
    name: Optional op name.

  Returns:
    A `Tensor` concatenated from the data of the different replicas.
  """
  groups = group_assignment
  if groups is None:
    groups = _create_default_group_assignment()
  return gen_tpu_ops.all_to_all(
      x,
      groups,
      concat_dimension=concat_dimension,
      split_dimension=split_dimension,
      split_count=split_count,
      name=name)
76
77
@ops.RegisterGradient("AllToAll")
def _all_to_all_grad(op, grad):
  """Gradient function for the `AllToAll` op.

  The backward pass is itself an all-to-all in which the forward op's
  split and concat dimensions trade places. The `group_assignment` input
  gets no gradient.
  """
  forward_split = op.get_attr("split_dimension")
  forward_concat = op.get_attr("concat_dimension")
  grad_x = gen_tpu_ops.all_to_all(
      grad,
      op.inputs[1],
      concat_dimension=forward_split,
      split_dimension=forward_concat,
      split_count=op.get_attr("split_count"))
  return [grad_x, None]
91
92
@tf_export(v1=["tpu.cross_replica_sum"])
def cross_replica_sum(x, group_assignment=None, name=None):
  """Sums the input tensor across replicas according to group_assignment.

  Args:
    x: The local tensor to be summed.
    group_assignment: Optional 2-D int32 list with shape
      [num_groups, num_replicas_per_group]; `group_assignment[i]` holds the
      replica ids of the i-th subgroup. When omitted, a single group of all
      replicas in the current TPU context is used.
    name: Optional op name.

  Returns:
    A `Tensor` which is summed across replicas.
  """
  assignment = (_create_default_group_assignment()
                if group_assignment is None else group_assignment)
  return gen_tpu_ops.cross_replica_sum(x, assignment, name=name)
111
112
def collective_permute(x, source_target_pairs, name=None):
  """Permutes the input tensor across replicas given source_target_pairs.

  For every pair <a, b>, replica a's input is sent to replica b. A replica
  id may appear at most once in the source column and at most once in the
  target column. Replicas that do not appear in the target column get a
  zero tensor with the same shape and dtype as the input `x`.

  For example, with 4 TPU instances `[A, B, C, D]`, passing
  source_target_pairs=`[[0,1],[1,2],[2,3]]` yields the outputs
  `[0, A, B, C]`.

  Args:
    x: The local tensor to be permuted.
    source_target_pairs: 2-D int list with shape [num_pairs, 2].
      `source_target_pairs[i][0]` is the source replica id and
      `source_target_pairs[i][1]` is the target replica id.
    name: Optional op name.

  Returns:
    A `Tensor` which is permuted.
  """
  permuted = gen_tpu_ops.collective_permute(
      x, source_target_pairs, name=name)
  return permuted
137
138
@ops.RegisterGradient("CollectivePermute")
def _collective_permute_grad(op, grad):
  """Gradient function for the `CollectivePermute` op.

  Gradients flow back along the reversed edges: each [source, target] row of
  the forward `source_target_pairs` input is flipped to [target, source].
  The `source_target_pairs` input itself gets no gradient.
  """
  reversed_pairs = op.inputs[1][:, ::-1]
  grad_x = gen_tpu_ops.collective_permute(grad, reversed_pairs)
  return [grad_x, None]
146
147
@ops.RegisterGradient("CrossReplicaSum")
def _cross_replica_sum_grad(op, grad):
  """Gradient function for the `CrossReplicaSum` op.

  The backward pass is another cross-replica sum over the same
  `group_assignment`, which itself gets no gradient.
  """
  groups = op.inputs[1]
  grad_x = gen_tpu_ops.cross_replica_sum(grad, groups)
  return [grad_x, None]
153
154
# This extra type checking exists to give a more helpful error message in
# the common case that unsupported types (e.g. uint8) are infed. Remove it
# once those types are supported.
158
# Element types accepted by `infeed_dequeue` / `infeed_dequeue_tuple`.
_SUPPORTED_INFEED_DTYPES = {
    dtypes.bool,
    dtypes.int32,
    dtypes.int64,
    dtypes.bfloat16,
    dtypes.float32,
    dtypes.complex64,
    dtypes.uint32,
}
163
164
@ops.RegisterGradient("TPUEmbeddingActivations")
def _embedding_activations_grad(activations_op, grad_wrt_activations):
  """Saves the gradient of embedding activations ops in a graph collection.

  The incoming gradient is wrapped in an `identity` op and stored at index
  `lookup_id` of the graph collection named
  `tpu_embedding_gradients_table_<table_id>`.

  Args:
    activations_op: The `TPUEmbeddingActivations` op being differentiated.
    grad_wrt_activations: Gradient with respect to the op's activations.

  Returns:
    One float32 zeros tensor per input of `activations_op`, each shaped like
    the corresponding input.

  Raises:
    RuntimeError: If the gradient collection for this table is empty. Per the
      error text, this happens when gradients are built in non-training mode.
  """
  g = ops.get_default_graph()
  table_id = activations_op.get_attr("table_id")
  lookup_id = activations_op.get_attr("lookup_id")
  table_gradients = g.get_collection_ref(
      "tpu_embedding_gradients_table_%d" % table_id)

  if not table_gradients:
    # The space after "mode." keeps the two sentences from running together.
    raise RuntimeError(
        "Gradients for TPUEmbedding have been generated in non-training mode. "
        "This is not expected. Consider putting your Optimizer.minimize code "
        "behind the training mode condition check. For Estimator, you can "
        "do \n\n"
        "    if mode == tf.estimator.ModeKeys.TRAIN:\n"
        "        train_op = opt.minimize(loss)\n"
        "\n")

  table_gradients[lookup_id] = array_ops.identity(grad_wrt_activations)
  # RegisterGradient requires a value to be returned for every input. The
  # real gradient w.r.t. the activations was handed off through the
  # collection above, so each input gets zeros shaped like itself. For
  # example, the first input (tpu_gradient_variable_{table_name}) has
  # shape [1].
  return [
      array_ops.zeros(arg.shape, dtype=dtypes.float32)
      for arg in activations_op.inputs
  ]
194
195
def infeed_dequeue(dtype, shape, name=None):
  """A placeholder op for a value that will be fed into the computation.

  Args:
    dtype: A `tf.DType`. The type of elements in the tensor.
    shape: A `tf.TensorShape` or list of `ints`. The shape of the tensor.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `dtype`, provided through the infeed mechanism.

  Raises:
    TypeError: If `dtype` is not a supported infeed type.
  """
  if dtype in _SUPPORTED_INFEED_DTYPES:
    return gen_tpu_ops.infeed_dequeue(dtype, shape, name=name)
  supported = list(_SUPPORTED_INFEED_DTYPES)
  raise TypeError(
      "Operation '{}' has type {} which is not a supported TPU infeed type. "
      "Supported types are: {}".format(name, dtype, supported))
218
219
# pylint: disable=redefined-outer-name
def infeed_dequeue_tuple(dtypes, shapes, name=None):
  """A placeholder op for values fed into the TPU simultaneously as a tuple.

  Args:
    dtypes: A list of `tf.DType`s with length `>= 1`; the element type of
      each tensor in the result.
    shapes: A list of shapes (each a `tf.TensorShape` or list of `ints`); the
      shape of each tensor in the result.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects of types `dtypes`, provided through the infeed
    mechanism.

  Raises:
    TypeError: If any type in `dtypes` is not a supported infeed type.
  """
  for element_dtype in dtypes:
    if element_dtype in _SUPPORTED_INFEED_DTYPES:
      continue
    raise TypeError(
        "{} is not a supported TPU infeed type. Supported types are: "
        "{}".format(element_dtype, list(_SUPPORTED_INFEED_DTYPES)))
  return gen_tpu_ops.infeed_dequeue_tuple(dtypes, shapes, name=name)
# pylint: enable=redefined-outer-name
245
246
# pylint: disable=protected-access
def send_tpu_embedding_gradients(inputs,
                                 config,
                                 learning_rates=None,
                                 name=None):
  """A placeholder op for feeding per-sample gradients to the embedding layer.

  Args:
    inputs: A TensorList of gradients with which to update embedding tables.
      It has the same length and shapes as the return value of
      RecvTPUEmbeddingActivations, but holds gradients of the model's loss
      with respect to the embedding activations. The tables are updated from
      these gradients by the optimizers specified in the TPU embedding
      configuration given to tpu.initialize_system.
    config: Serialized TPUEmbeddingConfiguration proto.
    learning_rates: A TensorList of float32 scalars, one per dynamic learning
      rate tag (see the comments in
      //third_party/tensorflow/core/protobuf/tpu/optimization_parameters.proto).
      Several tables may share one dynamic learning rate tag, as specified in
      the configuration. If every table's learning rate is constant, this
      list should be empty. `None` is treated as an empty list.
    name: A name for the operation (optional).

  Returns:
    A SendTPUEmbeddingGradients operation.
  """
  rates = [] if learning_rates is None else learning_rates
  return gen_tpu_ops.send_tpu_embedding_gradients(
      inputs=inputs, learning_rates=rates, config=config, name=name)


send_tpu_embedding_gradients.__doc__ = (
    gen_tpu_ops.send_tpu_embedding_gradients.__doc__)
282
283
# pylint: disable=protected-access
def enqueue_tpu_embedding_integer_batch(batch,
                                        device_ordinal,
                                        mode_override=None,
                                        name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    batch: A list of 1D tensors, one per embedding table, holding the indices
      into the tables.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'train', 'backward_pass_only'}. With 'unspecified' (the
      value used when this is None) the mode from TPUEmbeddingConfiguration
      applies; otherwise mode_override does (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingIntegerBatch operation.
  """
  mode = "unspecified" if mode_override is None else mode_override
  return gen_tpu_ops.enqueue_tpu_embedding_integer_batch(
      batch=batch,
      device_ordinal=device_ordinal,
      mode_override=mode,
      name=name)


enqueue_tpu_embedding_integer_batch.__doc__ = (
    gen_tpu_ops.enqueue_tpu_embedding_integer_batch.__doc__)
317
318
# pylint: disable=protected-access
def enqueue_tpu_embedding_sparse_batch(sample_indices,
                                       embedding_indices,
                                       aggregation_weights,
                                       device_ordinal,
                                       combiners=None,
                                       mode_override=None,
                                       name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    sample_indices: A list of rank 1 Tensors naming the training example and
      feature that each corresponding embedding_indices and
      aggregation_weights value belongs to. sample_indices[i] must equal
      b * nf + f, where nf is the number of features from the corresponding
      table, f is in [0, nf), and b is in [0, batch size). Both int32 and
      int64 are allowed and are converted to int32 internally.
    embedding_indices: A list of rank 1 Tensors, indices into the embedding
      tables. Both int32 and int64 are allowed and are converted to int32
      internally.
    aggregation_weights: A list of rank 1 Tensors of per-sample -- i.e. per
      (training example, feature) -- aggregation weights. Both float32 and
      float64 are allowed and are converted to float32 internally.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    combiners: A list of string scalars, one per embedding table, specifying
      how embedding activations are normalized after weighted summation.
      Supported combiners are 'mean', 'sum', or 'sqrtn'. A zero sum of
      weights is invalid for 'mean', as is a zero sum of squared weights for
      'sqrtn'. If combiners isn't passed, 'sum' is used for all tables
      (optional).
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'train', 'backward_pass_only'}. With 'unspecified' (the
      value used when this is None) the mode from TPUEmbeddingConfiguration
      applies; otherwise mode_override does (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingSparseBatch operation.
  """
  mode = "unspecified" if mode_override is None else mode_override
  return gen_tpu_ops.enqueue_tpu_embedding_sparse_batch(
      sample_indices=sample_indices,
      embedding_indices=embedding_indices,
      aggregation_weights=aggregation_weights,
      device_ordinal=device_ordinal,
      combiners=combiners,
      mode_override=mode,
      name=name)


enqueue_tpu_embedding_sparse_batch.__doc__ = (
    gen_tpu_ops.enqueue_tpu_embedding_sparse_batch.__doc__)
374
375
# pylint: disable=protected-access
def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
                                              embedding_indices,
                                              aggregation_weights,
                                              table_ids,
                                              device_ordinal,
                                              max_sequence_lengths=None,
                                              num_features=None,
                                              combiners=None,
                                              mode_override=None,
                                              name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    sample_indices: A list of rank 2 Tensors naming the training example
      that each corresponding embedding_indices and aggregation_weights value
      belongs to. It corresponds to sp_ids.indices in
      embedding_lookup_sparse(). If its first dimension has size 0, each
      embedding_indices entry is assumed to belong to a different sample.
      Both int32 and int64 are allowed and are converted to int32 internally.
    embedding_indices: A list of rank 1 Tensors, indices into the embedding
      tables. It corresponds to sp_ids.values in embedding_lookup_sparse().
      Both int32 and int64 are allowed and are converted to int32 internally.
    aggregation_weights: A list of rank 1 Tensors of per-training-example
      aggregation weights. It corresponds to sp_weights.values in
      embedding_lookup_sparse(). If its first dimension has size 0, all
      weights are assumed to be 1. Both float32 and float64 are allowed and
      are converted to float32 internally.
    table_ids: A list of integers identifying the embedding table (offset of
      TableDescriptor in the TPUEmbeddingConfiguration) used to look up the
      corresponding input; the ith input is looked up with table_ids[i]. Its
      length must equal that of sample_indices, embedding_indices and
      aggregation_weights.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    max_sequence_lengths: A list of integers with the same length as
      sample_indices. A value of 0 marks the corresponding feature as a
      non-sequence feature; a value greater than 0 marks it as a sequence
      feature with that maximal length. None means a list of all zeroes.
    num_features: A list of integers with the same length as sample_indices.
      If non-empty, every entry must be at least 1. For each batch element,
      num_features rows of the input tensor are taken for embedding lookup.
      E.g., when sample_indices is empty, the embedding indices must have
      shape (batch_size*num_features).
    combiners: A list of string scalars, one per embedding table, specifying
      how embedding activations are normalized after weighted summation.
      Supported combiners are 'mean', 'sum', or 'sqrtn'. A zero sum of
      weights is invalid for 'mean', as is a zero sum of squared weights for
      'sqrtn'. If combiners isn't passed, 'sum' is used for all tables
      (optional).
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'train', 'backward_pass_only'}. With 'unspecified' (the
      value used when this is None) the mode from TPUEmbeddingConfiguration
      applies; otherwise mode_override does (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingSparseTensorBatch operation.
  """
  mode = "unspecified" if mode_override is None else mode_override
  return gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch(
      sample_indices=sample_indices,
      embedding_indices=embedding_indices,
      aggregation_weights=aggregation_weights,
      table_ids=table_ids,
      device_ordinal=device_ordinal,
      max_sequence_lengths=max_sequence_lengths,
      num_features=num_features,
      combiners=combiners,
      mode_override=mode,
      name=name)


enqueue_tpu_embedding_sparse_tensor_batch.__doc__ = (
    gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch.__doc__)
454
455
# pylint: disable=protected-access
def enqueue_tpu_embedding_ragged_tensor_batch(sample_splits,
                                              embedding_indices,
                                              aggregation_weights,
                                              table_ids,
                                              device_ordinal,
                                              max_sequence_lengths=None,
                                              num_features=None,
                                              combiners=None,
                                              mode_override=None,
                                              name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    sample_splits: A list of rank 1 Tensors specifying the break points for
      splitting embedding_indices and aggregation_weights into rows. It
      corresponds to ids.row_splits in embedding_lookup(), when ids is a
      RaggedTensor. Both int32 and int64 are allowed and will be converted to
      int32 internally.
    embedding_indices: A list of rank 1 Tensors, indices into the embedding
      tables. It corresponds to ids.values in embedding_lookup(), when ids is a
      RaggedTensor. Both int32 and int64 are allowed and will be converted to
      int32 internally.
    aggregation_weights: A list of rank 1 Tensors containing per training
      example aggregation weights. It corresponds to the values field of a
      RaggedTensor with the same row_splits as ids in embedding_lookup(), when
      ids is a RaggedTensor. Both float32 and float64 are allowed and will be
      converted to float32 internally.
    table_ids: A list of integers specifying the identifier of the embedding
      table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to
      lookup the corresponding input. The ith input is looked up using
      table_ids[i]. The size of the table_ids list must be equal to that of
      sample_splits, embedding_indices and aggregation_weights.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    max_sequence_lengths: A list of integers, the size of which is equal to
      that of sample_splits. If equal to 0, the corresponding feature is
      considered to be a non-sequence feature. If greater than 0, the
      corresponding feature is a sequence feature with the given maximal
      length. If None, then we assume a list of all zeroes.
    num_features: A list of integers, the size of which must be equal to
      that of sample_splits. If non-empty, entries in this list must be at
      least 1. For each batch element, we will take num_features rows of the
      input tensor for embedding lookup. NOTE(review): the sparse-tensor
      variant requires embedding indices of shape (batch_size*num_features)
      when its sample_indices is empty; confirm the ragged equivalent.
    combiners: A list of string scalars, one for each embedding table that
      specify how to normalize the embedding activations after weighted
      summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
      invalid to have the sum of the weights be 0 for 'mean' or the sum of the
      squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default
      is to use 'sum' for all tables (optional).
    mode_override: A string input that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'train', 'backward_pass_only'}. When set to 'unspecified'
      (the value used when this is None), the mode set in
      TPUEmbeddingConfiguration is used, otherwise mode_override is used
      (optional).
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingRaggedTensorBatch operation.
  """
  if mode_override is None:
    mode_override = "unspecified"
  return gen_tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch(
      sample_splits=sample_splits,
      embedding_indices=embedding_indices,
      aggregation_weights=aggregation_weights,
      table_ids=table_ids,
      device_ordinal=device_ordinal,
      max_sequence_lengths=max_sequence_lengths,
      combiners=combiners,
      mode_override=mode_override,
      num_features=num_features,
      name=name)


# The generated op documentation replaces the hand-written docstring above.
enqueue_tpu_embedding_ragged_tensor_batch.__doc__ = (
    gen_tpu_ops.enqueue_tpu_embedding_ragged_tensor_batch.__doc__)
534