# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operations for embeddings."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
# Imports gradient definitions.
from tensorflow.python.ops import data_flow_grad  # pylint: disable=unused-import
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export


def _gather(params, ids, name=None):
  """Helper function for _embedding_lookup_and_transform.

  This function gathers embeddings from a single tensor. The gather deals with
  resource variables specially.

  Args:
    params: A `Tensor` of embeddings.
    ids: A `Tensor` indexing the embeddings to be retrieved from `params`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` with the same type as `params`.
  """
  if isinstance(params, resource_variable_ops.ResourceVariable):
    return params.sparse_read(ids, name=name)
  else:
    return array_ops.gather(params, ids, name=name)


def _clip(params, ids, max_norm):
  """Helper function for _embedding_lookup_and_transform.

  This function optionally clips embeddings to an l2-norm of max_norm.

  Args:
    params: A `Tensor` of embeddings retrieved by `_gather`.
    ids: The `ids` argument that was passed to `_gather`.
    max_norm: If provided, the embeddings are l2-normalized to the value of
      max_norm.

  Returns:
    A `Tensor` with the same type as `params`.
  """

  def _rank(x):
    """Helper function to retrieve the rank of a tensor.

    Args:
      x: Something convertible to `Tensor`.

    Returns:
      Either a pair `(rank, True)` where `rank` is an integer or a pair
      `(rank, False)` where `rank` is an integer `Tensor`. In either case,
      `rank` is the rank of `x`.
83 """ 84 rank = ops.convert_to_tensor(x).get_shape().ndims 85 if rank: 86 return rank, True 87 else: 88 return array_ops.rank(x), False 89 90 if max_norm is None: 91 return params 92 ids_rank, ids_static = _rank(ids) 93 params_rank, params_static = _rank(params) 94 return clip_ops.clip_by_norm( 95 params, 96 max_norm, 97 axes=(list(range(ids_rank, params_rank)) 98 if ids_static and params_static 99 else math_ops.range(ids_rank, params_rank))) 100 101 102def _embedding_lookup_and_transform(params, 103 ids, 104 partition_strategy="mod", 105 name=None, 106 max_norm=None, 107 transform_fn=None): 108 """Helper function for embedding_lookup and _compute_sampled_logits. 109 110 This function is a generalization of embedding_lookup that optionally 111 applies a caller-specified transformation to each embedding. This is 112 done through the `transform_fn` argument. If provided, the function is 113 applied to each partitioned tensor of retrieved embeddings, colocated 114 with the embeddings. This function will be called with a single `Tensor` 115 argument of the same type as the `params` tensor and should return a 116 `Tensor`. The shape of the argument will be the same as `params` except 117 for the size of the first dimension. The first dimension of the result's 118 shape must be the same size as the argument's. 119 120 Args: 121 params: See embedding_lookup. 122 ids: See embedding_lookup. 123 partition_strategy: See embedding_lookup. 124 name: See embedding_lookup. 125 max_norm: See embedding_lookup. 126 transform_fn: An optional function to apply to each retrieved embedding. 127 If max_norm is provided, transform_fn is applied to the norm-limited 128 embeddings. 129 130 Returns: 131 See embedding_lookup for details. 132 Raises: 133 ValueError: If `params` is empty. 134 """ 135 if params is None or params in ((), []): 136 raise ValueError("Need at least one param") 137 if isinstance(params, variables.PartitionedVariable): 138 params = list(params) # Iterate to get the underlying Variables. 139 if not isinstance(params, list): 140 params = [params] 141 142 with ops.name_scope(name, "embedding_lookup", params + [ids]) as name: 143 np = len(params) # Number of partitions 144 # Preserve the resource variable status to avoid accidental dense reads. 145 if not any( 146 isinstance(p, resource_variable_ops.ResourceVariable) for p in params): 147 params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params") 148 ids = ops.convert_to_tensor(ids, name="ids") 149 if np == 1 and (not transform_fn or ids.get_shape().ndims == 1): 150 with ops.colocate_with(params[0]): 151 result = _clip(_gather(params[0], ids, name=name), ids, max_norm) 152 if transform_fn: 153 result = transform_fn(result) 154 return result 155 else: 156 # Flatten the ids. There are two cases where we need to do this. 157 # - There is more than one params tensor. 158 # - There is a transform_fn and ids is not statically known to be 1-D. 159 # We must flatten in this case because transform_fn expects a flat 160 # tensor of embeddings. 161 flat_ids = array_ops.reshape(ids, [-1]) 162 original_indices = math_ops.range(array_ops.size(flat_ids)) 163 164 # Create p_assignments and set new_ids depending on the strategy. 165 if partition_strategy == "mod": 166 p_assignments = flat_ids % np 167 new_ids = flat_ids // np 168 elif partition_strategy == "div": 169 # Compute num_total_ids as the sum of dim-0 of params, then assign to 170 # partitions based on a constant number of ids per partition. 
        # Optimize if we already know the full shape statically.
        dim_0_size = params[0].get_shape()[0]
        for p in xrange(1, np):
          dim_0_size += params[p].get_shape()[0]
        if dim_0_size.value:
          num_total_ids = constant_op.constant(dim_0_size.value, flat_ids.dtype)
        else:
          dim_0_sizes = []
          for p in xrange(np):
            if params[p].get_shape()[0].value is not None:
              dim_0_sizes.append(params[p].get_shape()[0].value)
            else:
              with ops.colocate_with(params[p]):
                dim_0_sizes.append(array_ops.shape(params[p])[0])
          num_total_ids = math_ops.reduce_sum(
              math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
        ids_per_partition = num_total_ids // np
        extras = num_total_ids % np

        p_assignments = math_ops.maximum(
            flat_ids // (ids_per_partition + 1),
            (flat_ids - extras) // ids_per_partition)

        # Emulate a conditional using a boolean indicator tensor
        new_ids = array_ops.where(p_assignments < extras,
                                  flat_ids % (ids_per_partition + 1),
                                  (flat_ids - extras) % ids_per_partition)
      else:
        raise ValueError("Unrecognized partition strategy: " +
                         partition_strategy)

      # Cast partition assignments to int32 for use in dynamic_partition.
      # There really should not be more than 2^32 partitions.
      p_assignments = math_ops.cast(p_assignments, dtypes.int32)
      # Partition list of ids based on assignments into np separate lists
      gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
      # Similarly, partition the original indices.
      pindices = data_flow_ops.dynamic_partition(original_indices,
                                                 p_assignments, np)
      # Do np separate lookups, finding embeddings for plist[p] in params[p]
      partitioned_result = []
      for p in xrange(np):
        pids = gather_ids[p]
        with ops.colocate_with(params[p]):
          result = _gather(params[p], pids)
          if transform_fn:
            # If transform_fn is provided, the clip_by_norm precedes
            # the transform and hence must be co-located. See below
            # for the counterpart if transform_fn is not provided.
            result = transform_fn(_clip(result, pids, max_norm))
        partitioned_result.append(result)
      # Stitch these back together
      ret = data_flow_ops.parallel_dynamic_stitch(
          pindices, partitioned_result, name=name)

      # Determine the static element shape.
      if transform_fn is None:
        element_shape_s = params[0].get_shape()[1:]
        for p in params[1:]:
          element_shape_s = element_shape_s.merge_with(p.get_shape()[1:])
      else:
        element_shape_s = ret.get_shape()[1:]

      # Compute the dynamic element shape.
      if element_shape_s.is_fully_defined():
        element_shape_d = element_shape_s
      elif transform_fn is None:
        # It's important that we compute params[0].shape on the right device
        # to avoid data motion.
        with ops.colocate_with(params[0]):
          params_shape = array_ops.shape(params[0])
        element_shape_d = params_shape[1:]
      else:
        element_shape_d = array_ops.shape(ret)[1:]

      # Reshape to reverse the flattening of ids.
      ret = array_ops.reshape(ret,
                              array_ops.concat(
                                  [array_ops.shape(ids), element_shape_d], 0))

      # Normally the reshape is sufficient, but setting shape explicitly
      # teaches shape inference that params[1:].get_shape() matters
      # (in the case that transform_fn is None).
      ret.set_shape(ids.get_shape().concatenate(element_shape_s))
      if not transform_fn:
        # If transform_fn was provided, the clip_by_norm was done above.
        ret = _clip(ret, ids, max_norm)
      return ret


@tf_export("nn.embedding_lookup")
def embedding_lookup(
    params,
    ids,
    partition_strategy="mod",
    name=None,
    validate_indices=True,  # pylint: disable=unused-argument
    max_norm=None):
  """Looks up `ids` in a list of embedding tensors.

  This function is used to perform parallel lookups on the list of
  tensors in `params`. It is a generalization of
  @{tf.gather}, where `params` is
  interpreted as a partitioning of a large embedding tensor. `params` may be
  a `PartitionedVariable` as returned by using `tf.get_variable()` with a
  partitioner.

  If `len(params) > 1`, each element `id` of `ids` is partitioned between
  the elements of `params` according to the `partition_strategy`.
  In all strategies, if the id space does not evenly divide the number of
  partitions, each of the first `(max_id + 1) % len(params)` partitions will
  be assigned one more id.

  If `partition_strategy` is `"mod"`, we assign each id to partition
  `p = id % len(params)`. For instance,
  13 ids are split across 5 partitions as:
  `[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]`

  If `partition_strategy` is `"div"`, we assign ids to partitions in a
  contiguous manner. In this case, 13 ids are split across 5 partitions as:
  `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`

  The results of the lookup are concatenated into a dense
  tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.

  Args:
    params: A single tensor representing the complete embedding tensor,
      or a list of P tensors all of same shape except for the first dimension,
      representing sharded embedding tensors. Alternatively, a
      `PartitionedVariable`, created by partitioning along dimension 0. Each
      element must be appropriately sized for the given `partition_strategy`.
    ids: A `Tensor` with type `int32` or `int64` containing the ids to be looked
      up in `params`.
    partition_strategy: A string specifying the partitioning strategy, relevant
      if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
      is `"mod"`.
    name: A name for the operation (optional).
    validate_indices: DEPRECATED. If this operation is assigned to CPU, values
      in `indices` are always validated to be within range. If assigned to GPU,
      out-of-bound indices result in safe but unspecified behavior, which may
      include raising an error.
    max_norm: If provided, embedding values are l2-normalized to the value of
      max_norm.

  Returns:
    A `Tensor` with the same type as the tensors in `params`.

  Raises:
    ValueError: If `params` is empty.
  """
  return _embedding_lookup_and_transform(
      params=params,
      ids=ids,
      partition_strategy=partition_strategy,
      name=name,
      max_norm=max_norm,
      transform_fn=None)


@tf_export("nn.embedding_lookup_sparse")
def embedding_lookup_sparse(params,
                            sp_ids,
                            sp_weights,
                            partition_strategy="mod",
                            name=None,
                            combiner=None,
                            max_norm=None):
  """Computes embeddings for the given ids and weights.

  This op assumes that there is at least one id for each row in the dense tensor
  represented by sp_ids (i.e. there are no rows with empty features), and that
  all the indices of sp_ids are in canonical row-major order.

  It also assumes that all id values lie in the range [0, p0), where p0
  is the sum of the size of params along dimension 0.

  Args:
    params: A single tensor representing the complete embedding tensor,
      or a list of P tensors all of same shape except for the first dimension,
      representing sharded embedding tensors. Alternatively, a
      `PartitionedVariable`, created by partitioning along dimension 0. Each
      element must be appropriately sized for the given `partition_strategy`.
    sp_ids: N x M SparseTensor of int64 ids (typically from FeatureValueToId),
      where N is typically batch size and M is arbitrary.
    sp_weights: either a SparseTensor of float / double weights, or None to
      indicate all weights should be taken to be 1. If specified, sp_weights
      must have exactly the same shape and indices as sp_ids.
    partition_strategy: A string specifying the partitioning strategy, relevant
      if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
      is `"mod"`. See `tf.nn.embedding_lookup` for more details.
    name: Optional name for the op.
    combiner: A string specifying the reduction op. Currently "mean", "sqrtn"
      and "sum" are supported.
      "sum" computes the weighted sum of the embedding results for each row.
      "mean" is the weighted sum divided by the total weight.
      "sqrtn" is the weighted sum divided by the square root of the sum of the
      squares of the weights.
    max_norm: If provided, each embedding is normalized to have l2 norm equal
      to max_norm before combining.

  Returns:
    A dense tensor representing the combined embeddings for the
    sparse ids. For each row in the dense tensor represented by sp_ids, the op
    looks up the embeddings for all ids in that row, multiplies them by the
    corresponding weight, and combines these embeddings as specified.

    In other words, if

      shape(combined params) = [p0, p1, ..., pm]

    and

      shape(sp_ids) = shape(sp_weights) = [d0, d1, ..., dn]

    then

      shape(output) = [d0, d1, ..., dn-1, p1, ..., pm].

    For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are

      [0, 0]: id 1, weight 2.0
      [0, 1]: id 3, weight 0.5
      [1, 0]: id 0, weight 1.0
      [2, 3]: id 1, weight 3.0

    with `combiner`="mean", then the output will be a 3x20 matrix where

      output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
      output[1, :] = params[0, :] * 1.0
      output[2, :] = params[1, :] * 3.0

  Raises:
    TypeError: If sp_ids is not a SparseTensor, or if sp_weights is neither
      None nor SparseTensor.
    ValueError: If combiner is not one of {"mean", "sqrtn", "sum"}.
  """
  if combiner is None:
    logging.warn("The default value of combiner will change from \"mean\" "
                 "to \"sqrtn\" after 2016/11/01.")
    combiner = "mean"
  if combiner not in ("mean", "sqrtn", "sum"):
    raise ValueError("combiner must be one of 'mean', 'sqrtn' or 'sum'")
  if isinstance(params, variables.PartitionedVariable):
    params = list(params)  # Iterate to get the underlying Variables.
  if not isinstance(params, list):
    params = [params]
  if not isinstance(sp_ids, sparse_tensor.SparseTensor):
    raise TypeError("sp_ids must be SparseTensor")
  ignore_weights = sp_weights is None
  if not ignore_weights:
    if not isinstance(sp_weights, sparse_tensor.SparseTensor):
      raise TypeError("sp_weights must be either None or SparseTensor")
    sp_ids.values.get_shape().assert_is_compatible_with(
        sp_weights.values.get_shape())
    sp_ids.indices.get_shape().assert_is_compatible_with(
        sp_weights.indices.get_shape())
    sp_ids.dense_shape.get_shape().assert_is_compatible_with(
        sp_weights.dense_shape.get_shape())
    # TODO(yleon): Add enhanced node assertions to verify that sp_ids and
    # sp_weights have equal indices and shapes.

  with ops.name_scope(name, "embedding_lookup_sparse",
                      params + [sp_ids]) as name:
    segment_ids = sp_ids.indices[:, 0]
    if segment_ids.dtype != dtypes.int32:
      segment_ids = math_ops.cast(segment_ids, dtypes.int32)

    ids = sp_ids.values
    if ignore_weights:
      ids, idx = array_ops.unique(ids)
    else:
      idx = None

    embeddings = embedding_lookup(
        params, ids, partition_strategy=partition_strategy, max_norm=max_norm)
    if not ignore_weights:
      weights = sp_weights.values
      if weights.dtype != embeddings.dtype:
        weights = math_ops.cast(weights, embeddings.dtype)

      # Reshape weights to allow broadcast
      ones = array_ops.fill(
          array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0), 1)
      bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones],
                                             0)

      orig_weights_shape = weights.get_shape()
      weights = array_ops.reshape(weights, bcast_weights_shape)

      # Set the weight shape, since after reshaping to bcast_weights_shape,
      # the shape becomes None.
      if embeddings.get_shape().ndims is not None:
        weights.set_shape(
            orig_weights_shape.concatenate(
                [1 for _ in range(embeddings.get_shape().ndims - 1)]))

      embeddings *= weights

      if combiner == "sum":
        embeddings = math_ops.segment_sum(embeddings, segment_ids, name=name)
      elif combiner == "mean":
        embeddings = math_ops.segment_sum(embeddings, segment_ids)
        weight_sum = math_ops.segment_sum(weights, segment_ids)
        embeddings = math_ops.div(embeddings, weight_sum, name=name)
      elif combiner == "sqrtn":
        embeddings = math_ops.segment_sum(embeddings, segment_ids)
        weights_squared = math_ops.pow(weights, 2)
        weight_sum = math_ops.segment_sum(weights_squared, segment_ids)
        weight_sum_sqrt = math_ops.sqrt(weight_sum)
        embeddings = math_ops.div(embeddings, weight_sum_sqrt, name=name)
      else:
        assert False, "Unrecognized combiner"
    else:
      assert idx is not None
      if combiner == "sum":
        embeddings = math_ops.sparse_segment_sum(
            embeddings, idx, segment_ids, name=name)
      elif combiner == "mean":
        embeddings = math_ops.sparse_segment_mean(
            embeddings, idx, segment_ids, name=name)
      elif combiner == "sqrtn":
        embeddings = math_ops.sparse_segment_sqrt_n(
            embeddings, idx, segment_ids, name=name)
      else:
        assert False, "Unrecognized combiner"

    return embeddings
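

# -----------------------------------------------------------------------------
# Illustrative usage: a minimal sketch, not part of the library API. It mirrors
# the 10x20 `combiner="mean"` example in the embedding_lookup_sparse docstring
# above. The names `_params`, `_sp_ids`, and `_sp_weights` are hypothetical and
# exist only for this demonstration; it assumes a TF 1.x graph-mode session.
if __name__ == "__main__":
  from tensorflow.python.client import session as _session_lib

  # A 10 x 20 embedding matrix where row i is filled with the value i, so the
  # combined rows below are easy to verify by eye.
  _params = constant_op.constant([[float(i)] * 20 for i in range(10)])
  _indices = [[0, 0], [0, 1], [1, 0], [2, 3]]
  _sp_ids = sparse_tensor.SparseTensor(
      indices=_indices,
      values=constant_op.constant([1, 3, 0, 1], dtype=dtypes.int64),
      dense_shape=[3, 4])
  _sp_weights = sparse_tensor.SparseTensor(
      indices=_indices,
      values=constant_op.constant([2.0, 0.5, 1.0, 3.0]),
      dense_shape=[3, 4])
  # Row 0: (params[1] * 2.0 + params[3] * 0.5) / 2.5; row 1: params[0];
  # row 2: params[1].
  _combined = embedding_lookup_sparse(
      _params, _sp_ids, _sp_weights, combiner="mean")
  with _session_lib.Session() as _sess:
    print(_sess.run(_combined))  # A 3 x 20 matrix of weighted-mean embeddings.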