1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Functional tests for ops used with embeddings."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import itertools
22import math
23
24import numpy as np
25from six.moves import xrange  # pylint: disable=redefined-builtin
26
27from tensorflow.python.framework import constant_op
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import sparse_tensor
31from tensorflow.python.framework import test_util
32from tensorflow.python.ops import array_ops
33from tensorflow.python.ops import data_flow_ops
34from tensorflow.python.ops import embedding_ops
35from tensorflow.python.ops import gradient_checker
36from tensorflow.python.ops import init_ops
37from tensorflow.python.ops import linalg_ops
38from tensorflow.python.ops import math_ops
39from tensorflow.python.ops import partitioned_variables
40from tensorflow.python.ops import state_ops
41from tensorflow.python.ops import variable_scope
42from tensorflow.python.ops import variables
43from tensorflow.python.platform import test
44from tensorflow.python.platform import tf_logging
45from tensorflow.python.util import compat
46
47
def _AsLong(array):
  """Returns a list holding every element of `array` converted to `int`.

  Handy for turning numpy integer arrays (e.g. randomly drawn shapes) into
  plain Python int lists before handing them to TensorFlow.
  """
  return list(map(int, array))
51
52
class ScatterAddSubTest(test.TestCase):
  """Checks `state_ops.scatter_add` / `scatter_sub` against numpy updates."""

  def _TestCase(self, shape, indices, scatter_op=state_ops.scatter_add):
    """Run a random test case with the given shape and indices.

    Args:
      shape: Shape of the parameters array, as a list of ints.
      indices: One-dimensional array of ints, the indices along the first
               dimension of the parameters to update. Indices may repeat.
      scatter_op: `state_ops.scatter_add` or `state_ops.scatter_sub`.
    """
    # NOTE(review): presumably re-runs the base-class setUp so that every
    # sub-case starts from fresh test state; kept as-is.
    super(ScatterAddSubTest, self).setUp()
    with self.cached_session(use_gpu=False):
      # Create a random parameter array of given shape
      p_init = np.random.rand(*shape).astype("f")
      # Create the shape of the update array. All dimensions except the first
      # match the parameter array; the first dimension equals the # of indices.
      vals_shape = [len(indices)] + shape[1:]
      vals_init = np.random.rand(*vals_shape).astype("f")
      v_i = [float(x) for x in vals_init.ravel()]
      p = variables.Variable(p_init)
      vals = constant_op.constant(v_i, shape=vals_shape, name="vals")
      ind = constant_op.constant(indices, dtype=dtypes.int32)
      p2 = scatter_op(p, ind, vals, name="updated_p")
      # p = init
      variables.global_variables_initializer().run()
      # p += vals (or p -= vals for scatter_sub)
      result = self.evaluate(p2)
    # Compute the expected 'p' with numpy. Both arrays are viewed as 2-D
    # (first dimension x everything else) so one loop handles any rank; rows
    # are updated one index at a time so repeated indices accumulate.
    expected_rows = p_init.reshape(shape[0], -1)
    update_rows = vals_init.reshape(vals_shape[0], -1)
    for i, row in enumerate(indices):
      if scatter_op == state_ops.scatter_add:
        expected_rows[row, :] += update_rows[i, :]
      else:
        expected_rows[row, :] -= update_rows[i, :]
    # Compare the reshaped expectation directly instead of relying on the
    # reshape above having been a view into `p_init`.
    self.assertAllEqual(expected_rows.reshape(shape), result)

  @test_util.run_deprecated_v1
  def testNoRepetitions(self):
    self._TestCase([2, 2], [1])
    self._TestCase([4, 4, 4], [2, 0])
    self._TestCase([43, 20, 10, 10], [42, 5, 6, 1, 3, 5, 7, 9])

  @test_util.run_deprecated_v1
  def testWithRepetitions(self):
    self._TestCase([2, 2], [1, 1])
    self._TestCase([5, 3, 9, 5], [2, 0, 4, 1, 3, 1, 4, 0, 4, 3])
    self._TestCase([32, 4, 4], [31] * 8)

  @test_util.run_deprecated_v1
  def testRandom(self):
    # Random shapes of rank 4, random indices
    for _ in range(5):
      shape = np.random.randint(1, 20, size=4)
      indices = np.random.randint(shape[0], size=2 * shape[0])
      self._TestCase(_AsLong(list(shape)), list(indices))

  @test_util.run_deprecated_v1
  def testSubRandom(self):
    # Random shapes of rank 4, random indices
    for _ in range(5):
      shape = np.random.randint(1, 20, size=4)
      indices = np.random.randint(shape[0], size=2 * shape[0])
      self._TestCase(_AsLong(list(shape)), list(indices), state_ops.scatter_sub)

  @test_util.run_deprecated_v1
  def testWrongShape(self):
    # Indices and values mismatch.
    var = variables.Variable(
        array_ops.zeros(shape=[1024, 64, 64], dtype=dtypes.float32))
    indices = array_ops.placeholder(dtypes.int32, shape=[32])
    values = array_ops.placeholder(dtypes.float32, shape=[33, 64, 64])
    with self.assertRaises(ValueError):
      state_ops.scatter_add(var, indices, values)

    # Var and values mismatch.
    values = array_ops.placeholder(dtypes.float32, shape=[32, 64, 63])
    with self.assertRaises(ValueError):
      state_ops.scatter_add(var, indices, values)
133
134
def _PName(param_id):
  """Returns the op name used for parameter shard `param_id`, e.g. "p3"."""
  return "p{}".format(param_id)
137
138
def _EmbeddingParams(num_shards,
                     vocab_size,
                     dtype=dtypes.float32,
                     shape=None,
                     use_shapeless_placeholder=False):
  """Builds `num_shards` parameter tensors that together cover a vocabulary.

  Shard `i` is named `_PName(i)`. Its first dimension holds
  `vocab_size // num_shards` rows, and the first `vocab_size % num_shards`
  shards get one extra row. The remaining dimensions are `shape`, which
  defaults to `[10]` when falsy.

  Args:
    num_shards: number of shards to create.
    vocab_size: total number of rows across all shards.
    dtype: dtype of the created tensors.
    shape: per-row embedding shape (a list); `[10]` when falsy.
    use_shapeless_placeholder: if True, each shard is a placeholder with no
      static shape; otherwise it is a constant filled with 1.0.

  Returns:
    A tuple `(p, params, feed_dict)`:
      p: list of the shard tensors.
      params: dict mapping `_PName(i) + ":0"` to the random numpy value
        (drawn in [1, 2)) chosen for shard `i`.
      feed_dict: dict mapping each shard tensor's actual name to that same
        numpy value, ready to be fed to the session.
  """
  embedding_shape = shape if shape else [10]
  base_rows, extra_rows = divmod(vocab_size, num_shards)
  # float32 params get float32 values; every other dtype gets float64 values.
  np_type = "f" if dtype == dtypes.float32 else "d"

  shards = []
  values_by_name = {}
  feeds = {}
  for shard_index in range(num_shards):
    # Excess rows go to the first shards, one each.
    rows = base_rows + 1 if shard_index < extra_rows else base_rows
    shard_shape = [rows] + embedding_shape
    name = _PName(shard_index)

    if use_shapeless_placeholder:
      tensor = array_ops.placeholder(dtype, shape=None, name=name)
    else:
      tensor = constant_op.constant(
          1.0, shape=shard_shape, dtype=dtype, name=name)
    shards.append(tensor)

    value = np.random.rand(*shard_shape).astype(np_type) + 1
    # `values_by_name` uses the canonical name; `feeds` uses whatever name
    # the graph actually assigned to the tensor.
    values_by_name[name + ":0"] = value
    feeds[tensor.name] = value
  return shards, values_by_name, feeds
167
168
def _EmbeddingParamsAsPartitionedVariable(num_shards,
                                          vocab_size,
                                          dtype=dtypes.float32,
                                          shape=None,
                                          use_resource=False):
  """Builds shard tensors plus an equivalent partitioned variable "p".

  Args:
    num_shards: number of shards, also the partitioner's `max_partitions`.
    vocab_size: total number of rows.
    dtype: dtype of the shard tensors built by `_EmbeddingParams`.
    shape: per-row embedding shape (a list); `[10]` when falsy.
    use_resource: forwarded to `get_variable`.

  Returns:
    A tuple `(p, partitioned_variable, params, feed_dict)`. `p`, `params` and
    `feed_dict` come from `_EmbeddingParams`. `partitioned_variable` has shape
    `[vocab_size] + shape` and is initialized from the shard values
    concatenated along axis 0.
  """
  p, params, feed_dict = _EmbeddingParams(
      num_shards, vocab_size, dtype=dtype, shape=shape)
  shape = shape or [10]
  # Read the shard values through the canonical `_PName(i) + ":0"` keys that
  # `_EmbeddingParams` writes. The shard tensors' own names (`p_i.name`) can
  # differ from these keys if the graph uniquifies op names, and a name-based
  # lookup would then raise KeyError.
  shard_values = [params[_PName(i) + ":0"] for i in range(num_shards)]
  partitioned_variable = variable_scope.get_variable(
      "p",
      shape=[vocab_size] + shape,
      initializer=array_ops.concat(shard_values, 0),
      partitioner=partitioned_variables.min_max_variable_partitioner(
          max_partitions=num_shards, min_slice_size=1),
      use_resource=use_resource)
  return p, partitioned_variable, params, feed_dict
185
186
def _EmbeddingResult(params,
                     id_vals,
                     num_shards,
                     vocab_size,
                     partition_strategy="mod",
                     weight_vals=None):
  """Computes the expected (weighted) embedding lookup with numpy.

  Args:
    params: dict mapping `_PName(shard) + ":0"` to that shard's numpy value,
      as produced by `_EmbeddingParams`.
    id_vals: iterable whose entries are either a single integer id or an
      iterable of ids whose embeddings are aggregated together.
    num_shards: number of shards the vocabulary is split across.
    vocab_size: total number of ids; only used by the "div" strategy.
    partition_strategy: "mod" or "div"; selects how an id maps to a
      (shard, row) pair.
    weight_vals: optional weights with the same structure as `id_vals`.
      Defaults to all ones.

  Returns:
    A `(values, weights, weights_squared)` tuple of float32 arrays holding,
    for each entry of `id_vals`, the weighted sum of embedding rows, the sum
    of weights and the sum of squared weights.

  Raises:
    ValueError: if `partition_strategy` is neither "mod" nor "div".
  """
  # Raise explicitly: an `assert` would be stripped under `python -O`.
  if partition_strategy not in ("mod", "div"):
    raise ValueError("Unknown partition_strategy: %r" % (partition_strategy,))

  if weight_vals is None:
    weight_vals = np.copy(id_vals)
    weight_vals.fill(1)

  if partition_strategy == "mod":

    def _Locate(i):
      """Returns (shard, row) for id `i`: ids are dealt round-robin."""
      return i % num_shards, i // num_shards

  else:
    # "div": the first `extras` shards each hold `ids_per_partition + 1`
    # consecutive ids; the remaining shards hold `ids_per_partition` each.
    # Computed once here instead of once per id.
    ids_per_partition, extras = divmod(vocab_size, num_shards)
    threshold = extras * (ids_per_partition + 1)

    def _Locate(i):
      """Returns (shard, row) for id `i` under contiguous partitioning."""
      if i < threshold:
        return divmod(i, ids_per_partition + 1)
      return (extras + (i - threshold) // ids_per_partition,
              (i - threshold) % ids_per_partition)

  values = []
  weights = []
  weights_squared = []
  for ids, wts in zip(id_vals, weight_vals):
    # A scalar id is treated as a group of one.
    if isinstance(ids, compat.integral_types):
      ids = [ids]
      wts = [wts]
    value_aggregation = None
    weight_aggregation = None
    squared_weight_aggregation = None
    for i, weight_value in zip(ids, wts):
      partition, offset = _Locate(i)
      # The product is a fresh array, so the in-place `+=` below never
      # mutates the arrays stored in `params`.
      val = params[_PName(partition) + ":0"][offset, :] * weight_value
      if value_aggregation is None:
        value_aggregation = val
        weight_aggregation = weight_value
        squared_weight_aggregation = weight_value * weight_value
      else:
        value_aggregation += val
        weight_aggregation += weight_value
        squared_weight_aggregation += weight_value * weight_value
    values.append(value_aggregation)
    weights.append(weight_aggregation)
    weights_squared.append(squared_weight_aggregation)
  values = np.array(values).astype(np.float32)
  weights = np.array(weights).astype(np.float32)
  weights_squared = np.array(weights_squared).astype(np.float32)
  return values, weights, weights_squared
242
243
class EmbeddingLookupTest(test.TestCase):
  """Tests for `embedding_ops.embedding_lookup`.

  Lookup results are compared against numpy references (`_EmbeddingResult`)
  or against `array_ops.gather`. The tests cover sharded and unsharded params,
  "mod" and "div" partitioning, int32 and int64 ids, partitioned variables,
  `max_norm` clipping and gradients.
  """

  # Looks up [0, 0] in a parameter matrix sharded 2 ways. Both ids fall in the
  # first shard, so the lookup into the second shard is empty; this exercises
  # the handling of a shard that receives no ids.
  @test_util.run_deprecated_v1
  def testSimpleSharded(self):
    with self.cached_session():
      num_shards = 2
      vocab_size = 4
      p, params, feed_dict = _EmbeddingParams(num_shards, vocab_size)

      id_vals = np.array([0, 0])
      ids = constant_op.constant(list(id_vals), dtype=dtypes.int32)
      tf_logging.vlog(1, "Construct ids %s", ids.get_shape())
      embedding = embedding_ops.embedding_lookup(p, ids)

      tf_result = embedding.eval(feed_dict=feed_dict)
    np_result, _, _ = _EmbeddingResult(params, id_vals, num_shards, vocab_size)
    self.assertAllEqual(np_result, tf_result)
    self.assertShapeEqual(np_result, embedding)

  @test_util.run_deprecated_v1
  def testMaxNorm(self):
    """A single row of norm 2 is clipped down to `max_norm=1.0`."""
    with self.cached_session():
      embeddings = constant_op.constant([[2.0]])

      ids = constant_op.constant([0], dtype=dtypes.int32)
      embedding = embedding_ops.embedding_lookup(
          [embeddings], ids, max_norm=1.0)

      self.assertAllEqual(embedding.eval(), [[1.0]])

  @test_util.run_deprecated_v1
  def testMaxNormNontrivial(self):
    """Rows whose L2 norm exceeds `max_norm` are rescaled to that norm."""
    with self.cached_session():
      embeddings = constant_op.constant([[2.0, 4.0], [3.0, 1.0]])

      ids = constant_op.constant([0, 1], dtype=dtypes.int32)
      embedding = embedding_ops.embedding_lookup(
          [embeddings], ids, max_norm=2.0)

      norms = math_ops.sqrt(
          math_ops.reduce_sum(embeddings * embeddings, axis=1))
      normalized = embeddings / array_ops.stack([norms, norms], axis=1)
      # The reference normalizes through an explicit sqrt, while the lookup
      # clips internally. The sqrt-precision caveat noted in
      # testHigherRankMaxNorm applies, so compare approximately.
      self.assertAllClose(embedding.eval(), 2 * self.evaluate(normalized))

  @test_util.run_deprecated_v1
  def testSimpleShardedPartitionedVariable(self):
    """Looks up [0, 0] in a PartitionedVariable split 2 ways."""
    with self.cached_session():
      num_shards = 2
      vocab_size = 4
      _, p_variable, params, feed_dict = _EmbeddingParamsAsPartitionedVariable(
          num_shards, vocab_size)

      id_vals = np.array([0, 0])
      ids = constant_op.constant(list(id_vals), dtype=dtypes.int32)
      tf_logging.vlog(1, "Construct ids %s", ids.get_shape())
      embedding = embedding_ops.embedding_lookup(p_variable, ids)
      variables.global_variables_initializer().run()
      # Use the canonical keys that `_EmbeddingParams` writes into `params`.
      params_values = [params[_PName(i) + ":0"] for i in range(num_shards)]
      # Test that the PartitionedVariable components equal the shard values.
      p_var_val = self.evaluate(list(p_variable))
      # Actual test
      tf_result = embedding.eval(feed_dict=feed_dict)
    np_result, _, _ = _EmbeddingResult(params, id_vals, num_shards, vocab_size)
    self.assertAllEqual(params_values, p_var_val)
    self.assertAllEqual(np_result, tf_result)
    self.assertShapeEqual(np_result, embedding)

  @test_util.run_deprecated_v1
  def testSimpleShardedPartitionedResourceVariable(self):
    """Looks up [0, 0] in a 2-way PartitionedVariable of resource variables."""
    with self.cached_session():
      num_shards = 2
      vocab_size = 4
      _, p_variable, params, _ = _EmbeddingParamsAsPartitionedVariable(
          num_shards, vocab_size, use_resource=True)

      id_vals = np.array([0, 0])
      ids = constant_op.constant(list(id_vals), dtype=dtypes.int32)
      tf_logging.vlog(1, "Construct ids %s", ids.get_shape())
      embedding = embedding_ops.embedding_lookup(p_variable, ids)
      variables.global_variables_initializer().run()
      # Use the canonical keys that `_EmbeddingParams` writes into `params`.
      params_values = [params[_PName(i) + ":0"] for i in range(num_shards)]
      # Test that the PartitionedVariable components equal the shard values.
      p_var_val = self.evaluate(list(p_variable))
      # Actual test
      tf_result = self.evaluate(embedding)
    np_result, _, _ = _EmbeddingResult(params, id_vals, num_shards, vocab_size)
    self.assertAllEqual(params_values, p_var_val)
    self.assertAllEqual(np_result, tf_result)
    self.assertShapeEqual(np_result, embedding)

  @test_util.run_deprecated_v1
  def testShardedModPartitioningInt32Ids(self):
    with self.cached_session():
      num_shards = 5
      vocab_size = 13
      # Embedding dimensions is 10. The vocab_size x 10 embedding
      # parameters are spread in num_shards matrices, so the first
      # 3 shards are 3 x 10 and the last 2 shards are 2 x 10.
      p, params, feed_dict = _EmbeddingParams(num_shards, vocab_size)

      num_vals = 30
      # Fetch num_vals embeddings for random word ids. Since
      # num_vals > vocab_size, this ought to have repetitions, so
      # will test that aspect.
      id_vals = np.random.randint(vocab_size, size=num_vals)
      ids = constant_op.constant(list(id_vals), dtype=dtypes.int32)

      embedding = embedding_ops.embedding_lookup(p, ids)
      tf_result = embedding.eval(feed_dict=feed_dict)
    np_result, _, _ = _EmbeddingResult(params, id_vals, num_shards, vocab_size)
    self.assertAllEqual(np_result, tf_result)
    self.assertShapeEqual(np_result, embedding)

  @test_util.run_deprecated_v1
  def testShardedModPartitioningInt64Ids(self):
    with self.cached_session():
      num_shards = 5
      vocab_size = 13
      # Embedding dimensions is 10. The vocab_size x 10 embedding
      # parameters are spread in num_shards matrices, so the first
      # 3 shards are 3 x 10 and the last 2 shards are 2 x 10.
      p, params, feed_dict = _EmbeddingParams(num_shards, vocab_size)

      num_vals = 30
      # Fetch num_vals embeddings for random word ids. Since
      # num_vals > vocab_size, this ought to have repetitions, so
      # will test that aspect.
      id_vals = np.random.randint(vocab_size, size=num_vals)
      ids = constant_op.constant(list(id_vals), dtype=dtypes.int64)

      embedding = embedding_ops.embedding_lookup(p, ids)
      tf_result = embedding.eval(feed_dict=feed_dict)
    np_result, _, _ = _EmbeddingResult(params, id_vals, num_shards, vocab_size)
    self.assertAllEqual(np_result, tf_result)
    self.assertShapeEqual(np_result, embedding)

  @test_util.run_deprecated_v1
  def testShardedDivPartitioningInt32Ids(self):
    with self.cached_session():
      num_shards = 5
      vocab_size = 13
      # Embedding dimensions is 10. The vocab_size x 10 embedding
      # parameters are spread in num_shards matrices, so the first
      # 3 shards are 3 x 10 and the last 2 shards are 2 x 10.
      p, params, feed_dict = _EmbeddingParams(num_shards, vocab_size)

      num_vals = 30
      # Fetch num_vals embeddings for random word ids. Since
      # num_vals > vocab_size, this ought to have repetitions, so
      # will test that aspect.
      id_vals = np.random.randint(vocab_size, size=num_vals)
      ids = constant_op.constant(list(id_vals), dtype=dtypes.int32)

      embedding = embedding_ops.embedding_lookup(
          p, ids, partition_strategy="div")
      tf_result = embedding.eval(feed_dict=feed_dict)
    np_result, _, _ = _EmbeddingResult(
        params, id_vals, num_shards, vocab_size, partition_strategy="div")
    self.assertAllEqual(np_result, tf_result)
    self.assertShapeEqual(np_result, embedding)

  @test_util.run_deprecated_v1
  def testShardedDivPartitioningInt32IdsPartitionedVariable(self):
    with self.cached_session():
      num_shards = 5
      vocab_size = 13
      # Embedding dimensions is 10. The vocab_size x 10 embedding
      # parameters are spread in num_shards matrices, so the first
      # 3 shards are 3 x 10 and the last 2 shards are 2 x 10.
      _, p_variable, params, feed_dict = _EmbeddingParamsAsPartitionedVariable(
          num_shards, vocab_size)

      num_vals = 30
      # Fetch num_vals embeddings for random word ids. Since
      # num_vals > vocab_size, this ought to have repetitions, so
      # will test that aspect.
      id_vals = np.random.randint(vocab_size, size=num_vals)
      ids = constant_op.constant(list(id_vals), dtype=dtypes.int32)
      variables.global_variables_initializer().run()
      embedding = embedding_ops.embedding_lookup(
          p_variable, ids, partition_strategy="div")
      tf_result = embedding.eval(feed_dict=feed_dict)
    np_result, _, _ = _EmbeddingResult(
        params, id_vals, num_shards, vocab_size, partition_strategy="div")
    self.assertAllEqual(np_result, tf_result)
    self.assertShapeEqual(np_result, embedding)

  @test_util.run_deprecated_v1
  def testShardedDivPartitioningInt64Ids(self):
    with self.cached_session():
      num_shards = 5
      vocab_size = 13
      # Embedding dimensions is 10. The vocab_size x 10 embedding
      # parameters are spread in num_shards matrices, so the first
      # 3 shards are 3 x 10 and the last 2 shards are 2 x 10.
      p, params, feed_dict = _EmbeddingParams(num_shards, vocab_size)

      num_vals = 30
      # Fetch num_vals embeddings for random word ids. Since
      # num_vals > vocab_size, this ought to have repetitions, so
      # will test that aspect.
      id_vals = np.random.randint(vocab_size, size=num_vals)
      ids = constant_op.constant(list(id_vals), dtype=dtypes.int64)

      embedding = embedding_ops.embedding_lookup(
          p, ids, partition_strategy="div")
      tf_result = embedding.eval(feed_dict=feed_dict)
    np_result, _, _ = _EmbeddingResult(
        params, id_vals, num_shards, vocab_size, partition_strategy="div")
    self.assertAllEqual(np_result, tf_result)
    self.assertShapeEqual(np_result, embedding)

  @test_util.run_deprecated_v1
  def testShardedDivPartitioningUnknownParamShape(self):
    with self.cached_session():
      num_shards = 5
      vocab_size = 13
      # Embedding dimensions is 10. The vocab_size x 10 embedding
      # parameters are spread in num_shards matrices, so the first
      # 3 shards are 3 x 10 and the last 2 shards are 2 x 10.

      # We clear parameter shapes, to test when shape is not statically known.
      p, params, feed_dict = _EmbeddingParams(
          num_shards, vocab_size, use_shapeless_placeholder=True)

      num_vals = 30
      # Fetch num_vals embeddings for random word ids. Since
      # num_vals > vocab_size, this ought to have repetitions, so
      # will test that aspect.
      id_vals = np.random.randint(vocab_size, size=num_vals)
      ids = constant_op.constant(list(id_vals), dtype=dtypes.int64)

      embedding = embedding_ops.embedding_lookup(
          p, ids, partition_strategy="div")
      tf_result = embedding.eval(feed_dict=feed_dict)
    np_result, _, _ = _EmbeddingResult(
        params, id_vals, num_shards, vocab_size, partition_strategy="div")
    self.assertAllEqual(np_result, tf_result)

  @test_util.run_deprecated_v1
  def testGradientsEmbeddingLookup(self):
    """Numerically checks gradients w.r.t. the shards for several ids shapes."""
    vocab_size = 9
    num_ids = 10
    id_vals = list(np.random.randint(vocab_size, size=num_ids))
    tf_logging.vlog(1, id_vals)
    for ids_shape in [(10,), (2, 5)]:
      for num_shards in [1, 3]:
        with self.cached_session():
          ids = constant_op.constant(
              id_vals, shape=ids_shape, dtype=dtypes.int32)
          x, params, _ = _EmbeddingParams(num_shards, vocab_size, shape=[2])
          y = embedding_ops.embedding_lookup(x, ids)
          y_shape = ids_shape + tuple(params[_PName(0) + ":0"].shape[1:])
          x_name = [_PName(i) for i in range(num_shards)]
          x_init_value = [params[x_n + ":0"] for x_n in x_name]
          x_shape = [i.shape for i in x_init_value]
          err = gradient_checker.compute_gradient_error(
              x, x_shape, y, y_shape, x_init_value=x_init_value)
        self.assertLess(err, 1e-4)

  @test_util.run_deprecated_v1
  def testGradientsEmbeddingLookupWithComputedParams(self):
    """Gradients flow through params that are computed (squared) tensors."""
    vocab_size = 9
    num_ids = 5
    id_vals = list(np.random.randint(vocab_size, size=num_ids))
    tf_logging.vlog(1, id_vals)
    for num_shards in [1, 3]:
      with self.cached_session():
        ids = constant_op.constant(id_vals, dtype=dtypes.int32)
        x, params, _ = _EmbeddingParams(num_shards, vocab_size, shape=[2])
        # This will force a conversion from IndexedSlices to Tensor.
        x_squared = [math_ops.square(elem) for elem in x]
        y = embedding_ops.embedding_lookup(x_squared, ids)
        y_shape = [num_ids] + list(params[_PName(0) + ":0"].shape[1:])
        x_name = [_PName(i) for i in range(num_shards)]
        x_init_value = [params[x_n + ":0"] for x_n in x_name]
        x_shape = [i.shape for i in x_init_value]
        err = gradient_checker.compute_gradient_error(
            x, x_shape, y, y_shape, x_init_value=x_init_value)
      self.assertLess(err, 1e-3)

  def testConstructionNonSharded(self):
    """Graph construction succeeds for a single (unsharded) variable."""
    with ops.Graph().as_default():
      p = variables.Variable(
          array_ops.zeros(shape=[100, 100], dtype=dtypes.float32))
      ids = constant_op.constant([0, 1, 1, 7], dtype=dtypes.int32)
      embedding_ops.embedding_lookup([p], ids)

  def testConstructionSharded(self):
    """Graph construction succeeds for a variable list sharded 2 ways."""
    with ops.Graph().as_default():
      p = []
      for _ in range(2):
        p += [
            variables.Variable(
                array_ops.zeros(shape=[100, 100], dtype=dtypes.float32))
        ]
      # Built once, after all shards exist (it was previously rebuilt on every
      # loop iteration).
      ids = constant_op.constant([0, 1, 1, 17], dtype=dtypes.int32)
      embedding_ops.embedding_lookup(p, ids)

  @test_util.run_deprecated_v1
  def testHigherRank(self):
    """Lookups with rank > 1 ids match `gather`, sharded or not."""
    np.random.seed(8)
    with self.cached_session():
      for params_shape in (12,), (6, 3):
        params = np.random.randn(*params_shape)
        for ids_shape in (3, 2), (4, 3):
          ids = np.random.randint(
              params.shape[0], size=np.prod(ids_shape)).reshape(ids_shape)
          # Compare nonsharded to gather
          simple = embedding_ops.embedding_lookup(params, ids).eval()
          self.assertAllEqual(simple, array_ops.gather(params, ids).eval())
          # Run a few random sharded versions
          for procs in 1, 2, 3:
            stride = procs * math_ops.range(params.shape[0] // procs)
            split_params = [
                array_ops.gather(params, stride + p) for p in xrange(procs)
            ]
            sharded = embedding_ops.embedding_lookup(split_params, ids).eval()
            self.assertAllEqual(simple, sharded)

  @test_util.run_deprecated_v1
  def testHigherRankMaxNorm(self):
    np.random.seed(8)
    with self.cached_session():
      for params_shape in (12,), (6, 3), (6, 2, 3):
        # Test embedding rank 0, 1, 2.
        # Note: the first dimension must be a common multiple of procs below.
        params = 2 * np.ones(params_shape)
        params_norm = params / np.sqrt(
            np.sum(
                params * params, tuple(range(params.ndim)[1:]), keepdims=True))
        # `(3,)` is a 1-tuple; a bare `(3)` would just be the int 3.
        for ids_shape in (), (3,), (4, 3), (2, 3, 4):
          ids = np.random.randint(
              params.shape[0], size=np.prod(ids_shape,
                                            dtype=np.int64)).reshape(ids_shape)
          # Compare nonsharded to gather
          simple = embedding_ops.embedding_lookup(
              params, ids, max_norm=1.0).eval()
          # assertAllClose is used here as different implementations of sqrt may
          # be used to compute each of the values being compared.  For example,
          # on AVX512 builds the embedding operation makes use of Eigen's fast
          # vectorized square root algorithm for doubles.  These different
          # implementations of sqrt are not guaranteed to produce exactly the
          # same results. Therefore, an exact comparison cannot be made.
          self.assertAllClose(simple, array_ops.gather(params_norm, ids).eval())
          # Run a few different sharded versions.
          for procs in 1, 2, 3:
            stride = procs * math_ops.range(params.shape[0] // procs)
            split_params = [
                array_ops.gather(params, stride + p) for p in xrange(procs)
            ]
            sharded = embedding_ops.embedding_lookup(
                split_params, ids, max_norm=1.0).eval()
            self.assertAllEqual(simple, sharded)

  @test_util.run_deprecated_v1
  def testTransform(self):
    # This tests all combinations of:
    #   - ids rank 0, 1, >1
    #   - params sharded/unsharded
    # It always applies max_norm.
    np.random.seed(8)
    l2_norm = 2.
    with self.cached_session():
      # Param values are in [l2_norm, l2_norm+1) so it will always clip.
      params = np.random.rand(6, 3) + l2_norm
      params_norm = l2_norm * params / np.sqrt(
          np.sum(params * params, axis=1, keepdims=True))
      # Compute the norm of each embedding. This will change the embedding
      # rank to 0.
      params_norm = np.linalg.norm(params_norm, axis=1)
      transform = lambda x: linalg_ops.norm(x, axis=1)
      # `(3,)` is a 1-tuple; a bare `(3)` would just be the int 3.
      for ids_shape in (), (3,), (4, 3), (2, 3, 4):
        # Test ids rank 0, 1, 2, 3.
        ids = np.random.randint(
            params.shape[0], size=np.prod(ids_shape,
                                          dtype=np.int64)).reshape(ids_shape)
        # Compare nonsharded to gather.
        simple = embedding_ops._embedding_lookup_and_transform(
            params, ids, max_norm=l2_norm, transform_fn=transform).eval()
        self.assertAllClose(simple, array_ops.gather(params_norm, ids).eval())
        # Run a few different sharded versions.
        for procs in 1, 2, 3:
          stride = procs * math_ops.range(params.shape[0] // procs)
          split_params = [
              array_ops.gather(params, stride + p) for p in xrange(procs)
          ]
          sharded = embedding_ops._embedding_lookup_and_transform(
              split_params, ids, max_norm=l2_norm,
              transform_fn=transform).eval()
          # assertAllClose is used here as different implementations of sqrt may
          # be used to compute each of the values being compared.  For example,
          # on AVX512 builds the embedding operation makes use of Eigen's fast
          # vectorized square root algorithm for doubles.  These different
          # implementations of sqrt are not guaranteed to produce exactly the
          # same results. Therefore, an exact comparison cannot be made.
          self.assertAllClose(simple, sharded)
646
647
648class EmbeddingLookupSparseTest(test.TestCase):
649
650  def _RandomIdsAndWeights(self, batch_size, vocab_size):
651    max_val_per_entry = 6
652    vals_per_batch_entry = np.random.randint(
653        1, max_val_per_entry, size=batch_size)
654    num_vals = np.sum(vals_per_batch_entry)
655
656    ids = np.random.randint(vocab_size, size=num_vals)
657    weights = 1 + np.random.rand(num_vals)
658
659    indices = []
660    for batch_entry, num_val in enumerate(vals_per_batch_entry):
661      for val_index in range(num_val):
662        indices.append([batch_entry, val_index])
663
664    shape = [batch_size, max_val_per_entry]
665
666    sp_ids = sparse_tensor.SparseTensor(
667        constant_op.constant(indices, dtypes.int64),
668        constant_op.constant(ids, dtypes.int32),
669        constant_op.constant(shape, dtypes.int64))
670    sp_weights = sparse_tensor.SparseTensor(
671        constant_op.constant(indices, dtypes.int64),
672        constant_op.constant(weights, dtypes.float32),
673        constant_op.constant(shape, dtypes.int64))
674
675    return sp_ids, sp_weights, ids, weights, vals_per_batch_entry
676
677  def _GroupByBatchEntry(self, vals, vals_per_batch_entry):
678    grouped_vals = []
679    index = 0
680    for num_val in vals_per_batch_entry:
681      grouped_vals.append(list(vals[index:(index + num_val)]))
682      index += num_val
683    return grouped_vals
684
685  @test_util.run_deprecated_v1
686  def testEmbeddingLookupSparse(self):
687    vocab_size = 13
688    batch_size = 10
689    param_shape = [2, 5]
690    expected_lookup_result_shape = [None] + param_shape
691
692    sp_ids, sp_weights, ids, weights, vals_per_batch_entry = (
693        self._RandomIdsAndWeights(batch_size, vocab_size))
694
695    grouped_ids = self._GroupByBatchEntry(ids, vals_per_batch_entry)
696    grouped_weights = self._GroupByBatchEntry(weights, vals_per_batch_entry)
697    grouped_ignored_weights = self._GroupByBatchEntry(
698        np.ones(np.sum(vals_per_batch_entry)), vals_per_batch_entry)
699
700    for num_shards, combiner, dtype, ignore_weights in itertools.product(
701        [1, 5], ["sum", "mean", "sqrtn"],
702        [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64],
703        [True, False]):
704
705      with self.cached_session():
706        p, params, feed_dict = _EmbeddingParams(
707            num_shards, vocab_size, shape=param_shape, dtype=dtype)
708        embedding_sum = embedding_ops.embedding_lookup_sparse(
709            p,
710            sp_ids,
711            None if ignore_weights else sp_weights,
712            combiner=combiner)
713
714        self.assertEqual(embedding_sum.get_shape().as_list(),
715                         expected_lookup_result_shape)
716        if dtype in (dtypes.float16, dtypes.bfloat16):
717          self.assertEqual(embedding_sum.dtype, dtypes.float32)
718        else:
719          self.assertEqual(embedding_sum.dtype, dtype)
720
721        tf_embedding_sum = embedding_sum.eval(feed_dict=feed_dict)
722
723        np_embedding_sum, np_weight_sum, np_weight_sq_sum = _EmbeddingResult(
724            params,
725            grouped_ids,
726            num_shards,
727            vocab_size,
728            weight_vals=grouped_ignored_weights
729            if ignore_weights else grouped_weights)
730        if combiner == "mean":
731          np_embedding_sum /= np.reshape(np_weight_sum, (batch_size, 1, 1))
732        if combiner == "sqrtn":
733          np_embedding_sum /= np.reshape(
734              np.sqrt(np_weight_sq_sum), (batch_size, 1, 1))
735
736        rtol = 1e-6
737        if dtype == dtypes.bfloat16:
738          rtol = 1e-2
739        elif dtype == dtypes.float16:
740          rtol = 1e-3
741        atol = rtol
742        self.assertAllClose(np_embedding_sum, tf_embedding_sum, rtol, atol)
743
744  @test_util.run_deprecated_v1
745  def testGradientsEmbeddingLookupSparse(self):
746    vocab_size = 12
747    batch_size = 4
748    param_shape = [2, 3]
749    sp_ids, sp_weights, _, _, _ = (self._RandomIdsAndWeights(
750        batch_size, vocab_size))
751
752    for num_shards, combiner, dtype, ignore_weights in itertools.product(
753        [1, 3], ["sum", "mean", "sqrtn"], [dtypes.float32,
754                                           dtypes.float64], [True, False]):
755      with self.cached_session():
756        x, params, _ = _EmbeddingParams(
757            num_shards, vocab_size, shape=param_shape, dtype=dtype)
758
759        y = embedding_ops.embedding_lookup_sparse(
760            x,
761            sp_ids,
762            None if ignore_weights else sp_weights,
763            combiner=combiner)
764        x_name = [_PName(i) for i in range(num_shards)]
765        x_init_value = [params[x_n + ":0"] for x_n in x_name]
766        x_shape = [i.shape for i in x_init_value]
767        y_shape = [batch_size] + list(params[_PName(0) + ":0"].shape[1:])
768        err = gradient_checker.compute_gradient_error(
769            x, x_shape, y, y_shape, x_init_value=x_init_value)
770      self.assertLess(err, 1e-5 if dtype == dtypes.float64 else 2e-3)
771
772  @test_util.run_deprecated_v1
773  def testIncompatibleShapes(self):
774    with self.cached_session():
775      x, _, _ = _EmbeddingParams(1, 10, dtype=dtypes.float32)
776      sp_ids = sparse_tensor.SparseTensor(
777          constant_op.constant([[0, 0], [0, 1], [1, 0]], dtypes.int64),
778          constant_op.constant([0, 1, 2], dtypes.int32),
779          constant_op.constant([2, 2], dtypes.int64))
780      sp_weights = sparse_tensor.SparseTensor(
781          constant_op.constant([[0, 0], [0, 1]], dtypes.int64),
782          constant_op.constant([12.0, 5.0], dtypes.float32),
783          constant_op.constant([1, 2], dtypes.int64))
784
785      with self.assertRaises(ValueError):
786        embedding_ops.embedding_lookup_sparse(
787            x, sp_ids, sp_weights, combiner="mean")
788
789
class SafeEmbeddingLookupSparseTest(test.TestCase):
  """Tests for `safe_embedding_lookup_sparse` and `safe_embedding_lookup_sparse_v2`.

  The id/weight fixtures include invalid (-1) ids, rows with no ids, and rows
  whose weights are all <= 0. The expected results therefore cover pruning,
  zero-vector output and the `default_id` fallback.
  """

  def _random_weights(self, vocab_size=4, embed_dim=4, num_shards=1):
    """Creates, initializes and evaluates a partitioned embedding variable.

    Must be called while a default session is active: it runs each
    partition's initializer and calls `eval()` on it.

    Args:
      vocab_size: number of embedding rows; must be positive.
      embed_dim: embedding width; must be positive.
      num_shards: number of partitions; must be in [1, vocab_size].

    Returns:
      A list with one NumPy array per partition of the
      `[vocab_size, embed_dim]` variable named "embedding_weights".
    """
    assert vocab_size > 0
    assert embed_dim > 0
    assert num_shards > 0
    assert num_shards <= vocab_size

    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1.0 / math.sqrt(vocab_size), dtype=dtypes.float32)
    embedding_weights = list(variable_scope.get_variable(
        name="embedding_weights",
        shape=[vocab_size, embed_dim],
        partitioner=partitioned_variables.fixed_size_partitioner(num_shards),
        initializer=initializer))
    for w in embedding_weights:
      w.initializer.run()
    embedding_weights = [w.eval() for w in embedding_weights]
    return embedding_weights

  def _sparse_ids_and_weights(self, indices, ids, weights, shape):
    """Builds an int64 id SparseTensor and a float32 weight SparseTensor.

    Both tensors share the same `indices` and dense `shape`.

    Returns:
      A `(sparse_ids, sparse_weights)` tuple.
    """
    sparse_ids = sparse_tensor.SparseTensor(
        constant_op.constant(indices, dtypes.int64),
        constant_op.constant(ids, dtypes.int64),
        constant_op.constant(shape, dtypes.int64))

    sparse_weights = sparse_tensor.SparseTensor(
        constant_op.constant(indices, dtypes.int64),
        constant_op.constant(weights, dtypes.float32),
        constant_op.constant(shape, dtypes.int64))

    return sparse_ids, sparse_weights

  def _ids_and_weights_2d(self):
    """Returns 2-D `(sparse_ids, sparse_weights)` with dense shape [5, 4]."""
    # Each row demonstrates a test case:
    #   Row 0: multiple valid ids, 1 invalid id, weighted mean
    #   Row 1: all ids are invalid (leaving no valid ids after pruning)
    #   Row 2: no ids to begin with
    #   Row 3: single id
    #   Row 4: all ids have <=0 weight
    return self._sparse_ids_and_weights(
        indices=[[0, 0], [0, 1], [0, 2], [1, 0], [3, 0], [4, 0], [4, 1]],
        ids=[0, 1, -1, -1, 2, 0, 1],
        weights=[1.0, 2.0, 1.0, 1.0, 3.0, 0.0, -0.5],
        shape=[5, 4])

  def _ids_and_weights_3d(self):
    """Returns 3-D `(sparse_ids, sparse_weights)` with dense shape [2, 3, 4]."""
    # Each (2-D) index demonstrates a test case:
    #   Index 0, 0: multiple valid ids, 1 invalid id, weighted mean
    #   Index 0, 1: all ids are invalid (leaving no valid ids after pruning)
    #   Index 0, 2: no ids to begin with
    #   Index 1, 0: single id
    #   Index 1, 1: all ids have <=0 weight
    #   Index 1, 2: no ids to begin with
    return self._sparse_ids_and_weights(
        indices=[[0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 1, 0], [1, 0, 0],
                 [1, 1, 0], [1, 1, 1]],
        ids=[0, 1, -1, -1, 2, 0, 1],
        weights=[1.0, 2.0, 1.0, 1.0, 3.0, 0.0, -0.5],
        shape=[2, 3, 4])

  def _assert_inconsistent_weights_raise(self, ids_and_weights_fn):
    """Checks that dtype-inconsistent inputs to the v1 lookup are rejected.

    Args:
      ids_and_weights_fn: a fixture such as `self._ids_and_weights_2d`,
        returning `(sparse_ids, sparse_weights)`.
    """
    with self.cached_session():
      embedding_weights = self._random_weights(num_shards=3)
      sparse_ids, sparse_weights = ids_and_weights_fn()

      # Shard 1 becomes float64 while the others stay float32: expect TypeError.
      embedding_weights[1] = embedding_weights[1].astype(np.float64)
      self.assertRaises(TypeError, embedding_ops.safe_embedding_lookup_sparse,
                        embedding_weights, sparse_ids)
      # All shards float64 against float32 sparse_weights: expect ValueError
      # (presumably a params/weights dtype mismatch; confirm in embedding_ops).
      embedding_weights = [
          constant_op.constant(w, dtype=dtypes.float64)
          for w in embedding_weights
      ]
      self.assertRaises(ValueError, embedding_ops.safe_embedding_lookup_sparse,
                        embedding_weights, sparse_ids, sparse_weights)

  @test_util.run_deprecated_v1
  def test_safe_embedding_lookup_sparse_return_zero_vector(self):
    """Rows with no usable id/weight produce zero vectors by default."""
    with self.cached_session():
      embedding_weights = self._random_weights()
      sparse_ids, sparse_weights = self._ids_and_weights_2d()

      embedding_lookup_result = (
          embedding_ops.safe_embedding_lookup_sparse_v2(
              embedding_weights, sparse_ids, sparse_weights).eval())

      self.assertAllClose(
          embedding_lookup_result,
          [(1.0 * embedding_weights[0][0] + 2.0 * embedding_weights[0][1]) /
           3.0, [0] * 4, [0] * 4, embedding_weights[0][2], [0] * 4])

  @test_util.run_deprecated_v1
  def test_safe_embedding_lookup_sparse_return_special_vector(self):
    """Rows with no usable id/weight produce the `default_id` embedding."""
    with self.cached_session():
      embedding_weights = self._random_weights()
      sparse_ids, sparse_weights = self._ids_and_weights_2d()

      embedding_lookup_result = (
          embedding_ops.safe_embedding_lookup_sparse_v2(
              embedding_weights, sparse_ids, sparse_weights,
              default_id=3).eval())

      self.assertAllClose(
          embedding_lookup_result,
          [(1.0 * embedding_weights[0][0] + 2.0 * embedding_weights[0][1]) /
           3.0, embedding_weights[0][3], embedding_weights[0][3],
           embedding_weights[0][2], embedding_weights[0][3]])

  @test_util.run_deprecated_v1
  def test_safe_embedding_lookup_sparse_no_weights(self):
    """Without weights, valid ids in each row are averaged uniformly."""
    with self.cached_session():
      embedding_weights = self._random_weights()
      sparse_ids, _ = self._ids_and_weights_2d()

      embedding_lookup_result = (
          embedding_ops.safe_embedding_lookup_sparse_v2(
              embedding_weights, sparse_ids, None).eval())

      self.assertAllClose(
          embedding_lookup_result,
          [(embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4,
           [0] * 4, embedding_weights[0][2], (
               embedding_weights[0][0] + embedding_weights[0][1]) / 2.0])

  @test_util.run_deprecated_v1
  def test_safe_embedding_lookup_sparse_partitioned(self):
    """Unweighted lookup over 3 shards matches the unsharded expectation."""
    with self.cached_session():
      embedding_weights = self._random_weights(num_shards=3)
      sparse_ids, _ = self._ids_and_weights_2d()

      embedding_lookup_result = (
          embedding_ops.safe_embedding_lookup_sparse_v2(
              embedding_weights, sparse_ids, None).eval())

      # Flatten the shards back into a single list of embedding rows.
      embedding_weights = list(itertools.chain(*embedding_weights))
      self.assertAllClose(embedding_lookup_result,
                          [(embedding_weights[0] + embedding_weights[1]) / 2.0,
                           [0] * 4, [0] * 4, embedding_weights[2],
                           (embedding_weights[0] + embedding_weights[1]) / 2.0])

  @test_util.run_deprecated_v1
  def test_safe_embedding_lookup_sparse_partitioned_inconsistent_weights(self):
    self._assert_inconsistent_weights_raise(self._ids_and_weights_2d)

  @test_util.run_deprecated_v1
  def test_safe_embedding_lookup_sparse_3d_return_zero_vector(self):
    """3-D ids: empty/pruned positions produce zero vectors by default."""
    with self.cached_session():
      embedding_weights = self._random_weights()
      sparse_ids, sparse_weights = self._ids_and_weights_3d()

      embedding_lookup_result = (
          embedding_ops.safe_embedding_lookup_sparse_v2(
              embedding_weights, sparse_ids, sparse_weights).eval())

      self.assertAllClose(embedding_lookup_result, [[
          (1.0 * embedding_weights[0][0] + 2.0 * embedding_weights[0][1]) / 3.0,
          [0] * 4, [0] * 4
      ], [embedding_weights[0][2], [0] * 4, [0] * 4]])

  @test_util.run_deprecated_v1
  def test_safe_embedding_lookup_sparse_3d_return_special_vector(self):
    """3-D ids: empty/pruned positions produce the `default_id` embedding."""
    with self.cached_session():
      embedding_weights = self._random_weights()
      sparse_ids, sparse_weights = self._ids_and_weights_3d()

      embedding_lookup_result = (
          embedding_ops.safe_embedding_lookup_sparse_v2(
              embedding_weights, sparse_ids, sparse_weights,
              default_id=3).eval())

      self.assertAllClose(
          embedding_lookup_result,
          [[(1.0 * embedding_weights[0][0] + 2.0 * embedding_weights[0][1]) /
            3.0, embedding_weights[0][3], embedding_weights[0][3]], [
                embedding_weights[0][2], embedding_weights[0][3],
                embedding_weights[0][3]
            ]])

  @test_util.run_deprecated_v1
  def test_safe_embedding_lookup_sparse_3d_no_weights(self):
    """3-D ids without weights: valid ids are averaged uniformly."""
    with self.cached_session():
      embedding_weights = self._random_weights()
      sparse_ids, _ = self._ids_and_weights_3d()

      embedding_lookup_result = (
          embedding_ops.safe_embedding_lookup_sparse_v2(
              embedding_weights, sparse_ids, None).eval())

      self.assertAllClose(embedding_lookup_result, [[(
          embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4, [
              0
          ] * 4], [
              embedding_weights[0][2],
              (embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4
          ]])

  @test_util.run_deprecated_v1
  def test_safe_embedding_lookup_sparse_3d_partitioned(self):
    """3-D unweighted lookup over 3 shards matches the unsharded expectation."""
    with self.cached_session():
      embedding_weights = self._random_weights(num_shards=3)
      sparse_ids, _ = self._ids_and_weights_3d()

      embedding_lookup_result = (
          embedding_ops.safe_embedding_lookup_sparse_v2(
              embedding_weights, sparse_ids, None).eval())

      # Flatten the shards back into a single list of embedding rows.
      embedding_weights = list(itertools.chain(*embedding_weights))
      self.assertAllClose(embedding_lookup_result, [[
          (embedding_weights[0] + embedding_weights[1]) / 2.0, [0] * 4, [0] * 4
      ], [
          embedding_weights[2],
          (embedding_weights[0] + embedding_weights[1]) / 2.0, [0] * 4
      ]])

  @test_util.run_deprecated_v1
  def test_safe_embedding_lookup_sparse_3d_partitioned_inconsistent_weights(
      self):
    self._assert_inconsistent_weights_raise(self._ids_and_weights_3d)
1026
1027
class DynamicStitchOpTest(test.TestCase):
  """Tests for `data_flow_ops.dynamic_stitch`."""

  def _assert_stitched(self, indices, values, expected, use_gpu):
    """Checks that `dynamic_stitch(indices, values)` evaluates to `expected`.

    Tensors are created inside the session context, so they pick up its
    device placement.

    Args:
      indices: list of Python index lists, one per input partition.
      values: list of Python value lists, parallel to `indices`.
      expected: the expected stitched output.
      use_gpu: forwarded to `self.session`.
    """
    with self.session(use_gpu=use_gpu):
      index_tensors = [ops.convert_to_tensor(i) for i in indices]
      value_tensors = [ops.convert_to_tensor(v) for v in values]
      self.assertAllEqual(
          data_flow_ops.dynamic_stitch(index_tensors, value_tensors).eval(),
          expected)

  def _assert_overlapping_stitch(self, use_gpu):
    """Index 2 appears in both partitions; the later value (1) must win."""
    # NOTE(review): testCint32* and testInt32* were byte-identical copies that
    # both exercise these int32 inputs; the names suggest distinct cases were
    # intended -- confirm.
    self._assert_stitched(
        indices=[[0, 1, 2], [2, 3]],
        values=[[12, 23, 34], [1, 2]],
        expected=[12, 23, 1, 2],
        use_gpu=use_gpu)

  @test_util.run_deprecated_v1
  def testCint32Cpu(self):
    self._assert_overlapping_stitch(use_gpu=False)

  @test_util.run_deprecated_v1
  def testCint32Gpu(self):
    self._assert_overlapping_stitch(use_gpu=True)

  @test_util.run_deprecated_v1
  def testInt32Cpu(self):
    self._assert_overlapping_stitch(use_gpu=False)

  @test_util.run_deprecated_v1
  def testInt32Gpu(self):
    self._assert_overlapping_stitch(use_gpu=True)

  @test_util.run_deprecated_v1
  def testSumGradArgs(self):
    """Later partitions overwrite earlier ones at indices 2 and 3."""
    self._assert_stitched(
        indices=[[0, 1, 2, 3], [2, 3]],
        values=[[2, 3, 5, 7], [1, 1]],
        expected=[2, 3, 1, 1],
        use_gpu=False)

  # We expect that the values are merged in order.
  @test_util.run_deprecated_v1
  def testStitchOrder(self):
    """With 10 partitions sharing all indices, the last partition wins."""
    with self.cached_session():
      indices = []
      np_values = []
      values = []
      for _ in range(10):
        indices.extend([ops.convert_to_tensor(np.arange(100).astype(np.int32))])
        np_values.extend([np.random.uniform(size=100)])
        values.extend([ops.convert_to_tensor(np_values[-1])])
      stitched = data_flow_ops.dynamic_stitch(indices, values).eval()
    self.assertAllEqual(np_values[-1], stitched)
1113
1114
class ParallelDynamicStitchOpTest(test.TestCase):
  """CPU tests for `data_flow_ops.parallel_dynamic_stitch`."""

  def _expect_parallel_stitch(self, index_lists, value_lists, expected):
    """Runs parallel_dynamic_stitch on CPU and compares it to `expected`.

    Args:
      index_lists: list of Python index lists, one per input partition.
      value_lists: list of Python value lists, parallel to `index_lists`.
      expected: the expected stitched output.
    """
    with self.session(use_gpu=False):
      stitched = data_flow_ops.parallel_dynamic_stitch(
          [ops.convert_to_tensor(x) for x in index_lists],
          [ops.convert_to_tensor(x) for x in value_lists])
      self.assertAllEqual(stitched.eval(), expected)

  @test_util.run_deprecated_v1
  def testCint32Cpu(self):
    self._expect_parallel_stitch(
        index_lists=[[0, 1, 4, 6], [2, 3, 5]],
        value_lists=[[12, 23, 34, 45], [1, 2, 3]],
        expected=[12, 23, 1, 2, 34, 3, 45])

  @test_util.run_deprecated_v1
  def testInt32Cpu(self):
    self._expect_parallel_stitch(
        index_lists=[[0, 1, 5, 6, 7], [2, 4, 3]],
        value_lists=[[12, 23, 34, 45, 56], [1, 3, 2]],
        expected=[12, 23, 1, 2, 3, 34, 45, 56])

  @test_util.run_deprecated_v1
  def testSimple(self):
    self._expect_parallel_stitch(
        index_lists=[[0, 1], [2, 3]],
        value_lists=[[2, 3], [1, 1]],
        expected=[2, 3, 1, 1])
1155
1156
if __name__ == "__main__":
  # Run this module's test cases through the TensorFlow test runner
  # (`tensorflow.python.platform.test`).
  test.main()
1159