1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15
16"""Operations for TPUs."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.ops import array_ops
25# pylint: disable=wildcard-import,unused-import
26from tensorflow.python.ops import gen_tpu_ops
27from tensorflow.python.ops.gen_tpu_ops import *
28# pylint: enable=wildcard-import,unused-import
29from tensorflow.python.platform import tf_logging as logging
30from tensorflow.python.tpu import tpu_function
31
32
def _create_default_group_assignment():
  """Builds a group assignment that places every replica in one group.

  The replica count is read from the current TPU context's
  `number_of_shards`. When that value is unset, a warning is logged and a
  single replica is assumed.

  Returns:
    A one-element list holding `[0, 1, ..., num_replicas - 1]`.
  """
  context = tpu_function.get_tpu_context()
  num_replicas = context.number_of_shards
  if num_replicas is None:
    # NOTE(review): this warning mentions cross_replica_sum, but all_to_all
    # also relies on this helper; consider a more generic message.
    logging.warning(
        "cross_replica_sum should be used within a tpu_shard_context, but "
        "got unset number_of_shards. Assuming 1.")
    num_replicas = 1
  single_group = list(range(num_replicas))
  return [single_group]
42
43
def all_to_all(x,
               concat_dimension,
               split_dimension,
               split_count,
               group_assignment=None,
               name=None):
  """Exchange data across TPU replicas.

  Args:
    x: The local tensor.
    concat_dimension: The dimension along which received pieces are
      concatenated.
    split_dimension: The dimension along which `x` is split before sending.
    split_count: The number of splits. It must equal the number of replicas
      in each group, i.e. the second dimension of `group_assignment`.
    group_assignment: Optional 2d int32 list with shape [num_groups,
      num_replicas_per_group]. `group_assignment[i]` holds the replica ids of
      the i-th subgroup. Defaults to a single group containing every replica.
    name: Optional op name.

  Returns:
    A `Tensor` which is concatenated by data from different replicas.
  """
  groups = (_create_default_group_assignment()
            if group_assignment is None else group_assignment)
  return gen_tpu_ops.all_to_all(
      x,
      groups,
      concat_dimension=concat_dimension,
      split_dimension=split_dimension,
      split_count=split_count,
      name=name)
75
76
@ops.RegisterGradient("AllToAll")
def _all_to_all_grad(op, grad):
  """Gradient function registered for `AllToAll`.

  The gradient of an all-to-all is another all-to-all over the same group
  assignment (`op.inputs[1]`), with the split and concat dimensions swapped.
  The group assignment input receives no gradient (`None`).
  """
  # Swap the roles of the two dimensions to route the gradient back.
  grad_concat_dimension = op.get_attr("split_dimension")
  grad_split_dimension = op.get_attr("concat_dimension")
  grad_x = gen_tpu_ops.all_to_all(
      grad,
      op.inputs[1],
      concat_dimension=grad_concat_dimension,
      split_dimension=grad_split_dimension,
      split_count=op.get_attr("split_count"))
  return [grad_x, None]
90
91
def cross_replica_sum(x, group_assignment=None, name=None):
  """Sum the input tensor across replicas according to group_assignment.

  Args:
    x: The local tensor to be summed.
    group_assignment: Optional 2d int32 list with shape [num_groups,
      num_replicas_per_group]. `group_assignment[i]` holds the replica ids of
      the i-th subgroup. Defaults to a single group containing every replica.
    name: Optional op name.

  Returns:
    A `Tensor` which is summed across replicas.
  """
  groups = group_assignment
  if groups is None:
    groups = _create_default_group_assignment()
  return gen_tpu_ops.cross_replica_sum(x, groups, name=name)
109
110
def collective_permute(x, source_target_pairs, name=None):
  """Permute the input tensor across replicas given source_target_pairs.

  For each source_target_pair <a, b>, replica a's input is sent to replica b.
  A replica id may appear at most once in the source column and at most once
  in the target column. A replica whose id does not appear in the target
  column receives a zero tensor with the same shape and dtype as `x`.

  For example, with 4 TPU instances `[A, B, C, D]`, passing
  source_target_pairs=`[[0,1],[1,2],[2,3]]` produces the outputs
  `[0, A, B, C]`.

  Args:
    x: The local tensor to be permuted.
    source_target_pairs: 2d int list with shape [num_pairs, 2].
      `source_target_pairs[i][0]` is the source replica id and
      `source_target_pairs[i][1]` is the target replica id.
    name: Optional op name.

  Returns:
    A `Tensor` which is permuted.
  """
  permute = gen_tpu_ops.collective_permute
  return permute(x, source_target_pairs, name=name)
135
136
@ops.RegisterGradient("CollectivePermute")
def _collective_permute_grad(op, grad):
  """Gradient function registered for `CollectivePermute`.

  The gradient is itself a collective permute with every (source, target)
  pair reversed. The `source_target_pairs` input receives no gradient.
  """
  forward_pairs = op.inputs[1]
  # Reverse the column order so each pair reads (target, source).
  backward_pairs = forward_pairs[:, ::-1]
  permuted_grad = gen_tpu_ops.collective_permute(grad, backward_pairs)
  return [permuted_grad, None]
144
145
@ops.RegisterGradient("CrossReplicaSum")
def _cross_replica_sum_grad(op, grad):
  """Gradient function registered for `CrossReplicaSum`.

  The gradient of a cross-replica sum is a cross-replica sum of the incoming
  gradient over the same group assignment (`op.inputs[1]`). The group
  assignment input receives no gradient.
  """
  group_assignment = op.inputs[1]
  summed_grad = gen_tpu_ops.cross_replica_sum(grad, group_assignment)
  return [summed_grad, None]
151
152
# Allow-list of dtypes accepted by `infeed_dequeue` and `infeed_dequeue_tuple`.
# Checking it in Python gives a more helpful error message when an unsupported
# type (for example uint8) is infed. Extend this set as more types become
# supported.
_SUPPORTED_INFEED_DTYPES = {
    dtypes.bool,
    dtypes.int32,
    dtypes.int64,
    dtypes.bfloat16,
    dtypes.float32,
    dtypes.complex64,
    dtypes.uint32,
}
161
162
@ops.RegisterGradient("TPUEmbeddingActivations")
def _embedding_activations_grad(activations_op, grad_wrt_activations):
  """Saves the gradient of embedding activations ops in a graph collection.

  Instead of propagating `grad_wrt_activations` to the op's inputs, this
  function stores an identity of it at index `lookup_id` of the default
  graph's collection named "tpu_embedding_gradients_table_<table_id>". It
  then returns zeros for every input of the op.

  Args:
    activations_op: The `TPUEmbeddingActivations` op being differentiated.
      Its `table_id` attr selects the collection and its `lookup_id` attr
      selects the slot within that collection.
    grad_wrt_activations: Gradient of the loss with respect to the op's
      output.

  Returns:
    A list with one float32 zeros tensor per input of `activations_op`. Each
    tensor has that input's static shape.

  Raises:
    RuntimeError: If the collection for `table_id` is empty. This code treats
      that as gradients being generated in non-training mode.
  """
  g = ops.get_default_graph()
  table_id = activations_op.get_attr("table_id")
  lookup_id = activations_op.get_attr("lookup_id")
  table_gradients = g.get_collection_ref(
      "tpu_embedding_gradients_table_%d" % table_id)

  # An empty collection is reported as gradients being built outside of
  # training mode.
  if not table_gradients:
    # The trailing space after "mode." keeps the two sentences from running
    # together ("mode.This") in the concatenated message.
    raise RuntimeError(
        "Gradients for TPUEmbedding have been generated in non-training mode. "
        "This is not expected. Consider putting your Optimizer.minimize code "
        "behind the training mode condition check. For Estimator, you can "
        "do \n\n"
        "    if mode == tf.estimator.ModeKeys.TRAIN:\n"
        "        train_op = opt.minimize(loss)\n"
        "\n")

  # Index assignment requires the collection to already hold at least
  # `lookup_id + 1` entries; presumably they are pre-populated by the
  # embedding setup code -- TODO confirm.
  table_gradients[lookup_id] = array_ops.identity(grad_wrt_activations)
  return [
      # RegisterGradient requires a value for every input. Since the first
      # argument (tpu_gradient_variable_{table_name}) has shape [1], we return
      # zeros(shape=[1]) for it. The actual gradient w.r.t. the embedding
      # activations (grad_wrt_activations) has the same shape as the
      # activations returned by embedding_activations.
      array_ops.zeros(arg.shape, dtype=dtypes.float32)
      for arg in activations_op.inputs
  ]
192
193
def infeed_dequeue(dtype, shape, name=None):
  """A placeholder op for a value that will be fed into the computation.

  Args:
    dtype: A `tf.DType`. The type of elements in the tensor. It must be one of
      the supported TPU infeed types.
    shape: A `tf.TensorShape` or list of `ints`. The shape of the tensor.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `dtype` whose value is provided through the infeed
    mechanism.

  Raises:
    TypeError: If `dtype` is not a supported infeed type.
  """
  if dtype in _SUPPORTED_INFEED_DTYPES:
    return gen_tpu_ops.infeed_dequeue(dtype, shape, name=name)
  raise TypeError(
      "{} is not a supported TPU infeed type. Supported types are: "
      "{}".format(dtype, list(_SUPPORTED_INFEED_DTYPES)))
215
216
# pylint: disable=redefined-outer-name
def infeed_dequeue_tuple(dtypes, shapes, name=None):
  """A placeholder op for values fed into the TPU simultaneously as a tuple.

  Args:
    dtypes: A list of `tf.DType`s with length `>= 1`, giving the element type
      of each output. Every entry must be a supported TPU infeed type.
    shapes: A list of shapes (each a `tf.TensorShape` or list of `ints`),
      giving the shape of each output.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects, one per entry of `dtypes`, whose values are
    provided through the infeed mechanism.

  Raises:
    TypeError: If any entry of `dtypes` is not a supported infeed type. The
      first offending entry is reported.
  """
  supported = _SUPPORTED_INFEED_DTYPES
  for requested in dtypes:
    if requested in supported:
      continue
    raise TypeError(
        "{} is not a supported TPU infeed type. Supported types are: "
        "{}".format(requested, list(supported)))
  return gen_tpu_ops.infeed_dequeue_tuple(dtypes, shapes, name=name)
# pylint: enable=redefined-outer-name
242
243
# pylint: disable=protected-access
def send_tpu_embedding_gradients(inputs,
                                 config,
                                 learning_rates=None,
                                 name=None):
  """A placeholder op for feeding per-sample gradients to the embedding layer.

  Args:
    inputs: A TensorList of gradients used to update the embedding tables.
      It has the same length and shapes as the return value of
      RecvTPUEmbeddingActivations, but holds gradients of the model's loss
      with respect to the embedding activations. The tables are updated from
      these gradients by the optimizers specified in the TPU embedding
      configuration given to tpu.initialize_system.
    config: Serialized TPUEmbeddingConfiguration proto.
    learning_rates: A TensorList of float32 scalars, one for each dynamic
      learning rate tag; see the comments in
      tensorflow/core/protobuf/tpu/optimization_parameters.proto. Multiple
      tables can share the same dynamic learning rate tag as specified in the
      configuration. If the learning rates for all tables are constant, this
      list should be empty. Defaults to an empty list when None.
    name: A name for the operation (optional).

  Returns:
    A SendTPUEmbeddingGradients operation.
  """
  rates = [] if learning_rates is None else learning_rates
  return gen_tpu_ops.send_tpu_embedding_gradients(
      inputs=inputs, learning_rates=rates, config=config, name=name)


# The generated op's docstring replaces the hand-written one above at import.
send_tpu_embedding_gradients.__doc__ = (
    gen_tpu_ops.send_tpu_embedding_gradients.__doc__)
279
280
# pylint: disable=protected-access
def enqueue_tpu_embedding_integer_batch(batch,
                                        device_ordinal,
                                        mode_override=None,
                                        name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    batch: A list of 1D tensors, one per embedding table, holding the indices
      into the tables.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    mode_override: A string that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'training', 'backward_pass_only'}. With 'unspecified' (the
      value used when this argument is None), the mode from the
      TPUEmbeddingConfiguration applies; otherwise mode_override is used.
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingIntegerBatch operation.
  """
  op_kwargs = {
      "batch": batch,
      "device_ordinal": device_ordinal,
      "mode_override": (
          "unspecified" if mode_override is None else mode_override),
      "name": name,
  }
  return gen_tpu_ops.enqueue_tpu_embedding_integer_batch(**op_kwargs)


# The generated op's docstring replaces the hand-written one above at import.
enqueue_tpu_embedding_integer_batch.__doc__ = (
    gen_tpu_ops.enqueue_tpu_embedding_integer_batch.__doc__)
314
315
# pylint: disable=protected-access
def enqueue_tpu_embedding_sparse_batch(sample_indices,
                                       embedding_indices,
                                       aggregation_weights,
                                       device_ordinal,
                                       combiners=None,
                                       mode_override=None,
                                       name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    sample_indices: A list of rank 1 Tensors identifying the training example
      and feature that each corresponding embedding_indices and
      aggregation_weights value belongs to. sample_indices[i] must equal
      b * nf + f, where nf is the number of features from the corresponding
      table, f is in [0, nf), and b is in [0, batch size). Both int32 and
      int64 are allowed, and will be converted to int32 internally.
    embedding_indices: A list of rank 1 Tensors of indices into the embedding
      tables. Both int32 and int64 are allowed and will be converted to int32
      internally.
    aggregation_weights: A list of rank 1 Tensors of per-sample -- i.e. per
      (training example, feature) -- aggregation weights. Both float32 and
      float64 are allowed and will be converted to float32 internally.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    combiners: A list of string scalars, one per embedding table, specifying
      how to normalize the embedding activations after weighted summation.
      Supported combiners are 'mean', 'sum', or 'sqrtn'. It is invalid for
      the sum of the weights to be 0 with 'mean', or the sum of the squared
      weights to be 0 with 'sqrtn'. If combiners isn't passed, the default is
      to use 'sum' for all tables (optional).
    mode_override: A string that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'training', 'backward_pass_only'}. With 'unspecified' (the
      value used when this argument is None), the mode from the
      TPUEmbeddingConfiguration applies; otherwise mode_override is used.
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingSparseBatch operation.
  """
  effective_mode = mode_override
  if effective_mode is None:
    effective_mode = "unspecified"
  op_kwargs = dict(
      sample_indices=sample_indices,
      embedding_indices=embedding_indices,
      aggregation_weights=aggregation_weights,
      device_ordinal=device_ordinal,
      combiners=combiners,
      mode_override=effective_mode,
      name=name)
  return gen_tpu_ops.enqueue_tpu_embedding_sparse_batch(**op_kwargs)


# The generated op's docstring replaces the hand-written one above at import.
enqueue_tpu_embedding_sparse_batch.__doc__ = (
    gen_tpu_ops.enqueue_tpu_embedding_sparse_batch.__doc__)
371
372
# pylint: disable=protected-access
def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
                                              embedding_indices,
                                              aggregation_weights,
                                              table_ids,
                                              device_ordinal,
                                              combiners=None,
                                              mode_override=None,
                                              name=None):
  """A placeholder op for enqueueing embedding IDs to the TPU.

  Args:
    sample_indices: A list of rank 2 Tensors identifying the training example
      that each corresponding embedding_indices and aggregation_weights value
      belongs to. It corresponds to sp_ids.indices in
      embedding_lookup_sparse(). If the size of its first dimension is 0,
      each embedding_indices entry is assumed to belong to a different
      sample. Both int32 and int64 are allowed and will be converted to int32
      internally.
    embedding_indices: A list of rank 1 Tensors of indices into the embedding
      tables. It corresponds to sp_ids.values in embedding_lookup_sparse().
      Both int32 and int64 are allowed and will be converted to int32
      internally.
    aggregation_weights: A list of rank 1 Tensors of per training example
      aggregation weights. It corresponds to sp_weights.values in
      embedding_lookup_sparse(). If the size of its first dimension is 0, all
      weights are assumed to be 1. Both float32 and float64 are allowed and
      will be converted to float32 internally.
    table_ids: A list of integers identifying the embedding table (offset of
      TableDescriptor in the TPUEmbeddingConfiguration) in which to look up
      the corresponding input. The i-th input is looked up using
      table_ids[i]. The length of table_ids must match that of
      sample_indices, embedding_indices and aggregation_weights.
    device_ordinal: The TPU device to use. Should be >= 0 and less than the
      number of TPU cores in the task on which the node is placed.
    combiners: A list of string scalars, one per embedding table, specifying
      how to normalize the embedding activations after weighted summation.
      Supported combiners are 'mean', 'sum', or 'sqrtn'. It is invalid for
      the sum of the weights to be 0 with 'mean', or the sum of the squared
      weights to be 0 with 'sqrtn'. If combiners isn't passed, the default is
      to use 'sum' for all tables (optional).
    mode_override: A string that overrides the mode specified in the
      TPUEmbeddingConfiguration. Supported values are {'unspecified',
      'inference', 'training', 'backward_pass_only'}. With 'unspecified' (the
      value used when this argument is None), the mode from the
      TPUEmbeddingConfiguration applies; otherwise mode_override is used.
    name: A name for the operation (optional).

  Returns:
    An EnqueueTPUEmbeddingSparseTensorBatch operation.
  """
  effective_mode = ("unspecified"
                    if mode_override is None else mode_override)
  enqueue = gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch
  return enqueue(
      sample_indices=sample_indices,
      embedding_indices=embedding_indices,
      aggregation_weights=aggregation_weights,
      table_ids=table_ids,
      device_ordinal=device_ordinal,
      combiners=combiners,
      mode_override=effective_mode,
      name=name)


# The generated op's docstring replaces the hand-written one above at import.
enqueue_tpu_embedding_sparse_tensor_batch.__doc__ = (
    gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch.__doc__)
437