1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================= 15 16"""Operations for TPUs.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22from tensorflow.python.framework import dtypes 23from tensorflow.python.framework import ops 24from tensorflow.python.ops import array_ops 25# pylint: disable=wildcard-import,unused-import 26from tensorflow.python.ops import gen_tpu_ops 27from tensorflow.python.ops.gen_tpu_ops import * 28# pylint: enable=wildcard-import,unused-import 29from tensorflow.python.platform import tf_logging as logging 30from tensorflow.python.tpu import tpu_function 31 32 33def _create_default_group_assignment(): 34 num_shards = tpu_function.get_tpu_context().number_of_shards 35 if num_shards is None: 36 logging.warning( 37 "cross_replica_sum should be used within a tpu_shard_context, but " 38 "got unset number_of_shards. Assuming 1.") 39 num_shards = 1 40 group_assignment = [list(range(num_shards))] 41 return group_assignment 42 43 44def all_to_all(x, 45 concat_dimension, 46 split_dimension, 47 split_count, 48 group_assignment=None, 49 name=None): 50 """Exchange data across TPU replicas. 51 52 Args: 53 x: The local tensor. 54 concat_dimension: The dimension number to concatenate. 55 split_dimension: The dimension number to split. 56 split_count: The number of splits, this number must equal to the sub-group 57 size(group_assignment.get_shape()[1]) 58 group_assignment: Optional 2d int32 lists with shape [num_groups, 59 num_replicas_per_group]. `group_assignment[i]` represents the replica 60 ids in the ith subgroup. 61 name: Optional op name. 62 63 Returns: 64 A `Tensor` which is concatenated by data from different replicas. 65 """ 66 if group_assignment is None: 67 group_assignment = _create_default_group_assignment() 68 return gen_tpu_ops.all_to_all( 69 x, 70 group_assignment, 71 concat_dimension=concat_dimension, 72 split_dimension=split_dimension, 73 split_count=split_count, 74 name=name) 75 76 77@ops.RegisterGradient("AllToAll") 78def _all_to_all_grad(op, grad): 79 # The gradient of a all-to-all is also a all-to-all but the 80 # split_dimension and concat_dimension is swapped. 81 # The graident with respect to group_assignment is None. 82 return [ 83 gen_tpu_ops.all_to_all( 84 grad, 85 op.inputs[1], 86 concat_dimension=op.get_attr("split_dimension"), 87 split_dimension=op.get_attr("concat_dimension"), 88 split_count=op.get_attr("split_count")), None 89 ] 90 91 92def cross_replica_sum(x, group_assignment=None, name=None): 93 """Sum the input tensor across replicas according to group_assignment. 94 95 Args: 96 x: The local tensor to the sum. 97 group_assignment: Optional 2d int32 lists with shape [num_groups, 98 num_replicas_per_group]. `group_assignment[i]` represents the replica 99 ids in the ith subgroup. 100 name: Optional op name. 101 102 Returns: 103 A `Tensor` which is summed across replicas. 104 """ 105 if group_assignment is None: 106 group_assignment = _create_default_group_assignment() 107 108 return gen_tpu_ops.cross_replica_sum(x, group_assignment, name=name) 109 110 111def collective_permute(x, source_target_pairs, name=None): 112 """Permute the input tensor across replicas given source_target_pairs. 113 114 For each source_target_pair <a, b>, we send replica a's input to replica b. 115 Each replica id must only appear once in the source column. Also it must 116 only appear once in the target column. 117 For the replica id not in the target column, this op returns a zero tensor 118 with the same shape and dtype of the input x. 119 120 For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing 121 source_target_pairs=`[[0,1],[1,2],[2,3]]` gets the outputs: 122 `[0, A, B, C]`. 123 124 Args: 125 x: The local tensor to be permuted. 126 source_target_pairs: 2d int lists with shape [num_pairs, 2]. 127 source_target_pairs[i][0] represents the source replica id and 128 source_target_pairs[i][1] represents the target replica id. 129 name: Optional op name. 130 131 Returns: 132 A `Tensor` which is permuted. 133 """ 134 return gen_tpu_ops.collective_permute(x, source_target_pairs, name=name) 135 136 137@ops.RegisterGradient("CollectivePermute") 138def _collective_permute_grad(op, grad): 139 # The gradient of a collective permute operation is also a collective 140 # permute, but with source/target pairs reversed. The gradient with respect 141 # to input argument `source_target_pairs` is `None`. 142 source_target_pairs = op.inputs[1][:, ::-1] 143 return [gen_tpu_ops.collective_permute(grad, source_target_pairs), None] 144 145 146@ops.RegisterGradient("CrossReplicaSum") 147def _cross_replica_sum_grad(op, grad): 148 # The gradient of a cross replica sum is also a cross-replica sum. 149 # The gradient with respect to group_assignment is None. 150 return [gen_tpu_ops.cross_replica_sum(grad, op.inputs[1]), None] 151 152 153# This extra type checking exists to give a more helpful error message in 154# the common case that uint8 and int64 values are infed. Remove when both 155# types are supported. 156 157_SUPPORTED_INFEED_DTYPES = set([ 158 dtypes.bool, dtypes.int32, dtypes.int64, dtypes.bfloat16, dtypes.float32, 159 dtypes.complex64, dtypes.uint32 160]) 161 162 163@ops.RegisterGradient("TPUEmbeddingActivations") 164def _embedding_activations_grad(activations_op, grad_wrt_activations): 165 """Saves the gradient of embedding activations ops in a graph collection.""" 166 g = ops.get_default_graph() 167 table_id = activations_op.get_attr("table_id") 168 lookup_id = activations_op.get_attr("lookup_id") 169 table_gradients = g.get_collection_ref( 170 "tpu_embedding_gradients_table_%d" % table_id) 171 172 if not table_gradients: 173 raise RuntimeError( 174 "Gradients for TPUEmbedding have been generated in non-training mode." 175 "This is not expected. Consider putting your Optimizer.minimize code " 176 "behind the training mode condition check. For Estimator, you can " 177 "do \n\n" 178 " if mode == tf.estimator.ModeKeys.TRAIN:\n" 179 " train_op = opt.minimize(loss)\n" 180 "\n") 181 182 table_gradients[lookup_id] = array_ops.identity(grad_wrt_activations) 183 return [ 184 # RegisterGradient requires that value be returned for all inputs. Since 185 # the first argument (tpu_gradient_variable_{table_name}) has shape [1], 186 # we will return zeros(shape=[1]). The actual gradient w.r.t. the 187 # embedding activations (grad_wrt_activations) has the same shape as the 188 # activations returned by embedding_activations. 189 array_ops.zeros(arg.shape, dtype=dtypes.float32) 190 for arg in activations_op.inputs 191 ] 192 193 194def infeed_dequeue(dtype, shape, name=None): 195 """A placeholder op for a value that will be fed into the computation. 196 197 Args: 198 dtype: A `tf.DType`. The type of elements in the tensor. 199 shape: A `tf.TensorShape` or list of `ints`. The shape of the tensor. 200 name: A name for the operation (optional). 201 202 Returns: 203 A `Tensor` of type `dtype`. 204 A tensor that will be provided using the infeed mechanism. 205 206 Raises: 207 TypeError: If 'dtype` is not a supported infeed type. 208 """ 209 if dtype not in _SUPPORTED_INFEED_DTYPES: 210 raise TypeError( 211 "{} is not a supported TPU infeed type. Supported types are: " 212 "{}".format(dtype, list(_SUPPORTED_INFEED_DTYPES))) 213 214 return gen_tpu_ops.infeed_dequeue(dtype, shape, name=name) 215 216 217# pylint: disable=redefined-outer-name 218def infeed_dequeue_tuple(dtypes, shapes, name=None): 219 """A placeholder op for values fed into the TPU simultaneously as a tuple. 220 221 Args: 222 dtypes: A list of `tf.DType`s that has length `>= 1`. 223 The element types of each element in `outputs`. 224 shapes: A list of shapes (each a `tf.TensorShape` or list of `ints`). 225 The shapes of each tensor in `outputs`. 226 name: A name for the operation (optional). 227 228 Returns: 229 A list of `Tensor` objects of type `dtypes`. 230 A list of tensors that will be provided using the infeed mechanism. 231 232 Raises: 233 TypeError: If a type in 'dtypes` is not a supported infeed type. 234 """ 235 for dtype in dtypes: 236 if dtype not in _SUPPORTED_INFEED_DTYPES: 237 raise TypeError( 238 "{} is not a supported TPU infeed type. Supported types are: " 239 "{}".format(dtype, list(_SUPPORTED_INFEED_DTYPES))) 240 return gen_tpu_ops.infeed_dequeue_tuple(dtypes, shapes, name=name) 241# pylint: enable=redefined-outer-name 242 243 244# pylint: disable=protected-access 245def send_tpu_embedding_gradients(inputs, 246 config, 247 learning_rates=None, 248 name=None): 249 """A placeholder op for feeding per-sample gradients to the embedding layer. 250 251 Args: 252 inputs: A TensorList of gradients with which to update embedding tables. 253 This argument has the same length and shapes as the return value of 254 RecvTPUEmbeddingActivations, but contains gradients of the model's 255 loss with respect to the embedding activations. The embedding tables 256 are updated from these gradients via the optimizers specified in the 257 TPU embedding configuration given to tpu.initialize_system. 258 config: Serialized TPUEmbeddingConfiguration proto. 259 learning_rates: A TensorList of float32 scalars, one for each dynamic 260 learning rate tag: see the comments in 261 //third_party/tensorflow/core/protobuf/tpu/ 262 optimization_parameters.proto. 263 Multiple tables can share the same dynamic learning rate tag as 264 specified in the configuration. If the learning rates for all tables 265 are constant, this list should be empty. 266 name: A name for the operation (optional). 267 268 Returns: 269 A SendTPUEmbeddingGradients operation. 270 """ 271 if learning_rates is None: 272 learning_rates = [] 273 return gen_tpu_ops.send_tpu_embedding_gradients( 274 inputs=inputs, learning_rates=learning_rates, config=config, name=name) 275 276 277send_tpu_embedding_gradients.__doc__ = ( 278 gen_tpu_ops.send_tpu_embedding_gradients.__doc__) 279 280 281# pylint: disable=protected-access 282def enqueue_tpu_embedding_integer_batch(batch, 283 device_ordinal, 284 mode_override=None, 285 name=None): 286 """A placeholder op for enqueueing embedding IDs to the TPU. 287 288 Args: 289 batch: A list of 1D tensors, one for each embedding table, containing the 290 indices into the tables. 291 device_ordinal: The TPU device to use. Should be >= 0 and less than the 292 number of TPU cores in the task on which the node is placed. 293 mode_override: A string input that overrides the mode specified in the 294 TPUEmbeddingConfiguration. Supported values are {'unspecified', 295 'inference', 'training', 'backward_pass_only'}. When set to 296 'unspecified', the mode set in TPUEmbeddingConfiguration is used, 297 otherwise mode_override is used (optional). 298 name: A name for the operation (optional). 299 300 Returns: 301 An EnqueueTPUEmbeddingIntegerBatch operation. 302 """ 303 if mode_override is None: 304 mode_override = "unspecified" 305 return gen_tpu_ops.enqueue_tpu_embedding_integer_batch( 306 batch=batch, 307 device_ordinal=device_ordinal, 308 mode_override=mode_override, 309 name=name) 310 311 312enqueue_tpu_embedding_integer_batch.__doc__ = ( 313 gen_tpu_ops.enqueue_tpu_embedding_integer_batch.__doc__) 314 315 316# pylint: disable=protected-access 317def enqueue_tpu_embedding_sparse_batch(sample_indices, 318 embedding_indices, 319 aggregation_weights, 320 device_ordinal, 321 combiners=None, 322 mode_override=None, 323 name=None): 324 """A placeholder op for enqueueing embedding IDs to the TPU. 325 326 Args: 327 sample_indices: A list of rank 1 Tensors specifying the training example 328 and feature to which the corresponding embedding_indices and 329 aggregation_weights values belong. sample_indices[i] must equal b * nf + 330 f, where nf is the number of features from the corresponding table, f is 331 in [0, nf), and b is in [0, batch size). Both int32 and int64 are allowed, 332 and will be converted to int32 internally. 333 embedding_indices: A list of rank 1 Tensors, indices into the embedding 334 tables. Both int32 and int64 are allowed and will be converted to int32 335 internally. 336 aggregation_weights: A list of rank 1 Tensors containing per sample -- 337 i.e. per (training example, feature) -- aggregation weights. Both float32 338 and float64 are allowed and will be converted to float32 internally. 339 device_ordinal: The TPU device to use. Should be >= 0 and less than the 340 number of TPU cores in the task on which the node is placed. 341 combiners: A list of string scalars, one for each embedding table that 342 specify how to normalize the embedding activations after weighted 343 summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is 344 invalid to have the sum of the weights be 0 for 'mean' or the sum of the 345 squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default 346 is to use 'sum' for all tables (optional). 347 mode_override: A string input that overrides the mode specified in the 348 TPUEmbeddingConfiguration. Supported values are {'unspecified', 349 'inference', 'training', 'backward_pass_only'}. When set to 350 'unspecified', the mode set in TPUEmbeddingConfiguration is used, 351 otherwise mode_override is used (optional). 352 name: A name for the operation (optional). 353 354 Returns: 355 An EnqueueTPUEmbeddingSparseBatch operation. 356 """ 357 if mode_override is None: 358 mode_override = "unspecified" 359 return gen_tpu_ops.enqueue_tpu_embedding_sparse_batch( 360 sample_indices=sample_indices, 361 embedding_indices=embedding_indices, 362 aggregation_weights=aggregation_weights, 363 device_ordinal=device_ordinal, 364 combiners=combiners, 365 mode_override=mode_override, 366 name=name) 367 368 369enqueue_tpu_embedding_sparse_batch.__doc__ = ( 370 gen_tpu_ops.enqueue_tpu_embedding_sparse_batch.__doc__) 371 372 373# pylint: disable=protected-access 374def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices, 375 embedding_indices, 376 aggregation_weights, 377 table_ids, 378 device_ordinal, 379 combiners=None, 380 mode_override=None, 381 name=None): 382 """A placeholder op for enqueueing embedding IDs to the TPU. 383 384 Args: 385 sample_indices: A list of rank 2 Tensors specifying the training example 386 to which the corresponding embedding_indices and aggregation_weights 387 values belong. It corresponds to sp_ids.indices in 388 embedding_lookup_sparse(). If the size of its first dimension is 0, we 389 assume each embedding_indices belongs to a different sample. Both int32 390 and int64 are allowed and will be converted to int32 internally. 391 embedding_indices: A list of rank 1 Tensors, indices into the embedding 392 tables. It corresponds to sp_ids.values in embedding_lookup_sparse(). Both 393 int32 and int64 are allowed and will be converted to int32 internally. 394 aggregation_weights: A list of rank 1 Tensors containing per training 395 example aggregation weights. It corresponds to sp_weights.values in 396 embedding_lookup_sparse(). If the size of its first dimension is 0, we 397 assume all weights are 1. Both float32 and float64 are allowed and will 398 be converted to float32 internally. 399 table_ids: A list of integers specifying the identifier of the embedding 400 table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to 401 lookup the corresponding input. The ith input is looked up using 402 table_ids[i]. The size of the table_ids list must be equal to that of 403 sample_indices, embedding_indices and aggregation_weights. 404 device_ordinal: The TPU device to use. Should be >= 0 and less than the 405 number of TPU cores in the task on which the node is placed. 406 combiners: A list of string scalars, one for each embedding table that 407 specify how to normalize the embedding activations after weighted 408 summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is 409 invalid to have the sum of the weights be 0 for 'mean' or the sum of the 410 squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default 411 is to use 'sum' for all tables (optional). 412 mode_override: A string input that overrides the mode specified in the 413 TPUEmbeddingConfiguration. Supported values are {'unspecified', 414 'inference', 'training', 'backward_pass_only'}. When set to 415 'unspecified', the mode set in TPUEmbeddingConfiguration is used, 416 otherwise mode_override is used (optional). 417 name: A name for the operation (optional). 418 419 Returns: 420 An EnqueueTPUEmbeddingSparseTensorBatch operation. 421 """ 422 if mode_override is None: 423 mode_override = "unspecified" 424 return gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch( 425 sample_indices=sample_indices, 426 embedding_indices=embedding_indices, 427 aggregation_weights=aggregation_weights, 428 table_ids=table_ids, 429 device_ordinal=device_ordinal, 430 combiners=combiners, 431 mode_override=mode_override, 432 name=name) 433 434 435enqueue_tpu_embedding_sparse_tensor_batch.__doc__ = ( 436 gen_tpu_ops.enqueue_tpu_embedding_sparse_tensor_batch.__doc__) 437