1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import numpy as np
21from six.moves import range  # pylint: disable=redefined-builtin
22
23from tensorflow.contrib.labeled_tensor.python.ops import core
24from tensorflow.contrib.labeled_tensor.python.ops import ops
25from tensorflow.contrib.labeled_tensor.python.ops import test_util
26from tensorflow.python.framework import constant_op
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import errors_impl
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import math_ops
31from tensorflow.python.ops import string_ops
32from tensorflow.python.platform import test as test_lib
33
34
class Base(test_util.Base):
  """Common fixture for the labeled_tensor ops tests.

  Builds `original_lt`, a 4-D LabeledTensor with axes x, channel, z and probs
  whose values are 0, 1, 2, ... laid out by reshaping a range, plus a couple
  of lower-rank views derived from it.
  """

  def setUp(self):
    super(Base, self).setUp()

    # Axis sizes, listed in the axis order of `original_lt`.
    self.x_size = 7
    self.channel_size = 3
    self.z_size = 4
    self.probs_size = 11

    element_count = (
        self.x_size * self.channel_size * self.z_size * self.probs_size)
    flat_values = math_ops.range(0, element_count)
    self.tensor = array_ops.reshape(
        flat_values,
        [self.x_size, self.channel_size, self.z_size, self.probs_size])

    # 'x' and 'channel' carry tick labels, 'z' is a bare axis name (its size
    # comes from the tensor) and 'probs' is labelled with evenly spaced floats.
    self.a0 = ('x', range(self.x_size))
    self.a1 = ('channel', ['red', 'green', 'blue'])
    self.a2 = 'z'
    self.a2_resolved = ('z', self.z_size)
    self.a3 = ('probs', np.linspace(0.0, 1.0, self.probs_size))
    self.original_lt = core.LabeledTensor(
        self.tensor, [self.a0, self.a1, self.a2, self.a3])

    # View at z=0, channel='red'.
    first_z_lt = core.slice_function(self.original_lt, {'z': 0})
    self.x_probs_lt = ops.select(first_z_lt, {'channel': 'red'})
    # View at x=3, z=0.
    self.channel_probs_lt = core.slice_function(self.original_lt, {
        'x': 3,
        'z': 0
    })
67
68
class SelectTest(Base):
  """Tests for ops.select."""

  def test_name(self):
    result_lt = ops.select(self.original_lt, {'channel': 'green'})
    self.assertIn('lt_select', result_lt.name)

  def test_scalar(self):
    # Selecting a single label removes the 'channel' axis.
    result_lt = ops.select(self.original_lt, {'channel': 'green'})
    expected_lt = core.LabeledTensor(self.tensor[:, 1, :, :],
                                     [self.a0, self.a2, self.a3])
    self.assertLabeledTensorsEqual(result_lt, expected_lt)

  def test_slice(self):
    # Label slices include their stop label.
    result_lt = ops.select(self.original_lt,
                           {'channel': slice('red', 'green')})
    kept_channels = ('channel', ['red', 'green'])
    expected_lt = core.LabeledTensor(self.tensor[:, :2, :, :],
                                     [self.a0, kept_channels, self.a2, self.a3])
    self.assertLabeledTensorsEqual(result_lt, expected_lt)

  def test_slices(self):
    selection = {'x': slice(1, 4), 'channel': slice('green', None)}
    result_lt = ops.select(self.original_lt, selection)

    # x labels 1..4 inclusive, i.e. positions 1:5.
    kept_x = ('x', range(1, 5))
    kept_channels = ('channel', ['green', 'blue'])
    expected_lt = core.LabeledTensor(self.tensor[1:5, 1:, :, :],
                                     [kept_x, kept_channels, self.a2, self.a3])
    self.assertLabeledTensorsEqual(result_lt, expected_lt)

  def test_list(self):
    result_lt = ops.select(self.original_lt, {'channel': ['red', 'green']})
    expected_lt = core.LabeledTensor(
        self.tensor[:, :2, :, :],
        [self.a0, ('channel', ['red', 'green']), self.a2, self.a3])
    self.assertLabeledTensorsEqual(result_lt, expected_lt)

  def test_list_one_item(self):
    # A one-element list keeps the axis (unlike a scalar label).
    result_lt = ops.select(self.original_lt, {'channel': ['red']})
    expected_lt = core.LabeledTensor(
        self.tensor[:, :1, :, :],
        [self.a0, ('channel', ['red']), self.a2, self.a3])
    self.assertLabeledTensorsEqual(result_lt, expected_lt)

  def test_list_zero_items(self):
    # An empty list yields a zero-length 'channel' axis without labels.
    result_lt = ops.select(self.original_lt, {'channel': []})
    expected_lt = core.LabeledTensor(self.tensor[:, :0, :, :],
                                     [self.a0, 'channel', self.a2, self.a3])
    self.assertLabeledTensorsEqual(result_lt, expected_lt)

  def test_scalars(self):
    result_lt = ops.select(self.original_lt, {'x': 1, 'channel': 'green'})
    expected_lt = core.LabeledTensor(self.tensor[1, 1, :, :],
                                     [self.a2, self.a3])
    self.assertLabeledTensorsEqual(result_lt, expected_lt)

  def test_tuple(self):
    # Tuples can be tick labels and are selected like scalar labels.
    pairs_lt = core.LabeledTensor(
        constant_op.constant([5, 6]), [('x', [(1, 2), (3, 4)])])
    result_lt = ops.select(pairs_lt, {'x': (1, 2)})
    expected_lt = core.LabeledTensor(constant_op.constant(5), [])
    self.assertLabeledTensorsEqual(result_lt, expected_lt)

  def test_invalid_input(self):
    failing_cases = [
        (ValueError, {'foo': 1}),  # 'foo' is not an axis
        (ValueError, {'z': 1}),  # 'z' has no tick labels
        (KeyError, {'channel': 'purple'}),  # 'purple' is not a label
        (KeyError, {'channel': ['red', 'purple']}),
        (NotImplementedError, {'channel': ['red'], 'x': [1]}),
        (NotImplementedError, {'channel': ['red'], 'x': 1}),
        (NotImplementedError, {'channel': slice('red', 'green', 2)}),
    ]
    for error_type, selection in failing_cases:
      with self.assertRaises(error_type):
        ops.select(self.original_lt, selection)
146
147
class ConcatTest(Base):
  """Tests for ops.concat."""

  def setUp(self):
    super(ConcatTest, self).setUp()

    # One-channel slices that keep a length-1 'channel' axis.
    self.red_lt, self.green_lt, self.blue_lt = [
        ops.select(self.original_lt, {'channel': [color]})
        for color in ['red', 'green', 'blue']
    ]

  def test_name(self):
    joined_lt = ops.concat([self.red_lt, self.blue_lt], 'channel')
    self.assertIn('lt_concat', joined_lt.name)

  def test(self):
    joined_lt = ops.concat([self.red_lt, self.green_lt], 'channel')
    expected_lt = ops.select(self.original_lt, {'channel': ['red', 'green']})
    self.assertLabeledTensorsEqual(joined_lt, expected_lt)

  def test_transposed(self):
    # concat rejects inputs whose axis order differs.
    green_reordered = core.transpose(self.green_lt,
                                     ['probs', 'channel', 'z', 'x'])
    with self.assertRaises(ValueError):
      ops.concat([self.red_lt, green_reordered], 'channel')

  def test_invalid_input(self):
    rejected_calls = [
        ([], 'channel'),  # nothing to concatenate
        ([self.red_lt, self.red_lt], 'channel'),  # both inputs labelled 'red'
        ([self.red_lt, self.red_lt], 'foo'),  # 'foo' is not an axis
    ]
    for inputs, axis_name in rejected_calls:
      with self.assertRaises(ValueError):
        ops.concat(inputs, axis_name)
180
181
class PackTest(Base):
  """Tests for ops.pack."""

  def test_name(self):
    stacked_lt = ops.pack([self.original_lt] * 2, 'batch')
    self.assertIn('lt_pack', stacked_lt.name)

  def test(self):
    stacked_lt = ops.pack([self.original_lt] * 2, 'batch')
    # The new unlabelled axis is placed first by default.
    expected_lt = core.LabeledTensor(
        array_ops.stack([self.original_lt.tensor] * 2),
        ['batch', self.a0, self.a1, self.a2, self.a3])
    self.assertLabeledTensorsEqual(stacked_lt, expected_lt)

  def test_axis(self):
    stacked_lt = ops.pack([self.original_lt] * 2,
                          new_axis='batch',
                          axis_position=4)
    expected_lt = core.LabeledTensor(
        array_ops.stack([self.original_lt.tensor] * 2, axis=4),
        [self.a0, self.a1, self.a2, self.a3, 'batch'])
    self.assertLabeledTensorsEqual(stacked_lt, expected_lt)

  def test_invalid_input(self):
    # 'channel' is already an axis of the inputs.
    with self.assertRaises(ValueError):
      ops.pack([self.original_lt] * 2, 'channel')
209
210
class UnpackTest(Base):
  """Tests for ops.unpack."""

  def test_name(self):
    for piece_lt in ops.unpack(self.original_lt):
      self.assertIn('lt_unpack', piece_lt.name)

  def test(self):
    # Without axis_name the first axis ('x') is unpacked.
    first_lt = ops.unpack(self.original_lt)[0]
    expected_lt = core.LabeledTensor(
        array_ops.unstack(self.original_lt.tensor)[0],
        [self.a1, self.a2, self.a3])
    self.assertLabeledTensorsEqual(first_lt, expected_lt)

  def test_axis(self):
    first_lt = ops.unpack(self.original_lt, axis_name='z')[0]
    expected_lt = core.LabeledTensor(
        array_ops.unstack(self.original_lt.tensor, axis=2)[0],
        [self.a0, self.a1, self.a3])
    self.assertLabeledTensorsEqual(first_lt, expected_lt)

  def test_invalid_input(self):
    with self.assertRaises(ValueError):
      ops.unpack(self.original_lt, axis_name='not_found')
237
238
class ReshapeTest(Base):
  """Tests for ops.reshape."""

  def _flattened_expected(self, new_axis):
    """Returns original_lt with every axis after 'x' merged into `new_axis`."""
    return core.LabeledTensor(
        array_ops.reshape(self.original_lt.tensor, [self.x_size, -1]),
        [self.original_lt.axes['x'], new_axis])

  def test_name(self):
    reshaped_lt = ops.reshape(self.original_lt, ['channel'], ['foo'])
    self.assertIn('lt_reshape', reshaped_lt.name)

  def test_identity(self):
    all_axes = self.original_lt.axes
    reshaped_lt = ops.reshape(self.original_lt, all_axes.keys(),
                              all_axes.values())
    self.assertLabeledTensorsEqual(reshaped_lt, self.original_lt)

  def test_known_size(self):
    merged_size = self.channel_size * self.z_size * self.probs_size
    reshaped_lt = ops.reshape(self.original_lt, ['channel', 'z', 'probs'],
                              [('new_dim', merged_size)])
    self.assertLabeledTensorsEqual(reshaped_lt,
                                   self._flattened_expected('new_dim'))

  def test_unknown_size(self):
    reshaped_lt = ops.reshape(self.original_lt, ['channel', 'z', 'probs'],
                              ['new_dim'])
    self.assertLabeledTensorsEqual(reshaped_lt,
                                   self._flattened_expected('new_dim'))

  def test_unknown_dimension(self):
    input_lt = core.LabeledTensor(
        array_ops.placeholder(dtypes.float32, [None]), ['x'])
    reshaped_lt = ops.reshape(input_lt, ['x'], ['y', ('z', 1)])
    self.assertEqual(reshaped_lt.axes, core.Axes([('y', None), ('z', 1)]))
    with self.cached_session() as sess:
      output = sess.run(reshaped_lt, feed_dict={input_lt.tensor: [1, 2]})
      np.testing.assert_array_equal(output, [[1], [2]])

  def test_with_labels(self):
    merged_size = self.channel_size * self.z_size * self.probs_size
    reshaped_lt = ops.reshape(self.original_lt, ['channel', 'z', 'probs'],
                              [('new_dim', range(merged_size))])
    self.assertLabeledTensorsEqual(
        reshaped_lt,
        self._flattened_expected(('new_dim', range(merged_size))))

  def test_invalid_input(self):
    # (error type, message regexp, axes to replace, replacement axes)
    failing_cases = [
        (ValueError, 'not contained in the set', ['foo'], ['bar']),
        (core.AxisOrderError, 'not a slice of axis names', ['probs', 'z'],
         ['bar']),
        (ValueError, 'at most one axis in new_axes', ['probs'],
         ['foo', 'bar']),
    ]
    for error_type, message, old_axes, new_axes in failing_cases:
      with self.assertRaisesRegexp(error_type, message):
        ops.reshape(self.original_lt, old_axes, new_axes)
294
295
class RenameAxisTest(Base):
  """Tests for ops.rename_axis."""

  def test_name(self):
    renamed_lt = ops.rename_axis(self.original_lt, 'channel', 'foo')
    self.assertIn('lt_rename_axis', renamed_lt.name)

  def test_identity(self):
    renamed_lt = ops.rename_axis(self.original_lt, 'channel', 'channel')
    self.assertLabeledTensorsEqual(renamed_lt, self.original_lt)

  def test_new_name(self):
    renamed_lt = ops.rename_axis(self.original_lt, 'channel', 'foo')
    # Same data and axis values; only the 'channel' name is replaced.
    expected_axes = [('foo' if axis_name == 'channel' else axis_name,
                      axis.value)
                     for axis_name, axis in self.original_lt.axes.items()]
    expected_lt = core.LabeledTensor(self.original_lt.tensor, expected_axes)
    self.assertLabeledTensorsEqual(renamed_lt, expected_lt)

  def test_invalid_input(self):
    with self.assertRaisesRegexp(ValueError, 'not contained in the set'):
      ops.rename_axis(self.original_lt, 'foo', 'bar')
316
317
class BatchTest(Base):
  """Tests for ops.batch."""

  def setUp(self):
    super(BatchTest, self).setUp()

    # Ten copies of original_lt shifted by 0..9, stacked along 'batch'.
    self.pack_lt = ops.pack([
        core.add(self.original_lt,
                 core.LabeledTensor(constant_op.constant(offset), []))
        for offset in range(10)
    ], 'batch')

  def test_name(self):
    batched_lts = ops.batch([self.pack_lt, self.pack_lt],
                            batch_size=2,
                            enqueue_many=True)
    for batched_lt in batched_lts:
      self.assertIn('lt_batch', batched_lt.name)

  def test_enqueue_many(self):
    [pairs_lt] = ops.batch([self.pack_lt], batch_size=2, enqueue_many=True)
    self.assertEqual(len(pairs_lt.axes['batch']), 2)

    # Re-batching the pairs into groups of ten must reproduce the stack.
    [tens_lt] = ops.batch([pairs_lt], batch_size=10, enqueue_many=True)
    self.assertLabeledTensorsEqual(self.pack_lt, tens_lt)

  def test_no_enqueue_many(self):
    [pairs_lt] = ops.batch([self.original_lt], batch_size=2)
    self.assertEqual(len(pairs_lt.axes['batch']), 2)

    [tens_lt] = ops.batch([pairs_lt], batch_size=10, enqueue_many=True)
    self.assertLabeledTensorsEqual(
        ops.pack(10 * [self.original_lt], 'batch'), tens_lt)

  def test_invalid_input(self):
    # Presumably rejected because original_lt has no leading 'batch' axis
    # to enqueue many examples from.
    with self.assertRaises(ValueError):
      ops.batch([self.original_lt], 3, enqueue_many=True)

  def test_allow_smaller_final_batch(self):
    [batched_lt] = ops.batch([self.original_lt],
                             batch_size=2,
                             allow_smaller_final_batch=True)
    # A possibly-short final batch leaves the batch size unknown.
    self.assertEqual(batched_lt.axes['batch'].size, None)
360
361
class ShuffleBatchTest(Base):
  """Tests for ops.shuffle_batch."""

  def setUp(self):
    super(ShuffleBatchTest, self).setUp()

    # Ten copies of original_lt shifted by 0..9, stacked along 'batch'.
    self.pack_lt = ops.pack([
        core.add(self.original_lt,
                 core.LabeledTensor(constant_op.constant(offset), []))
        for offset in range(10)
    ], 'batch')

  def test_name(self):
    shuffled_lts = ops.shuffle_batch([self.pack_lt, self.pack_lt],
                                     batch_size=2,
                                     enqueue_many=True)
    for shuffled_lt in shuffled_lts:
      self.assertIn('lt_shuffle_batch', shuffled_lt.name)

  def test_enqueue_many(self):
    [pairs_lt] = ops.shuffle_batch([self.pack_lt],
                                   batch_size=2,
                                   enqueue_many=True,
                                   min_after_dequeue=8,
                                   seed=0)
    self.assertEqual(len(pairs_lt.axes['batch']), 2)

    [tens_lt] = ops.batch([pairs_lt], batch_size=10, enqueue_many=True)

    # Same axes as the unshuffled stack, but the rows come out reordered.
    self.assertEqual(tens_lt.axes, self.pack_lt.axes)
    tens, packed = self.eval([tens_lt.tensor, self.pack_lt.tensor])
    self.assertFalse((tens == packed).all())

  def test_allow_smaller_final_batch(self):
    [shuffled_lt] = ops.shuffle_batch([self.original_lt],
                                      batch_size=2,
                                      allow_smaller_final_batch=True)
    self.assertEqual(shuffled_lt.axes['batch'].size, None)
398
399
class RandomCropTest(Base):
  """Tests for ops.random_crop."""

  def _crop_probs_and_channel(self, labeled_tensor, seed):
    """Randomly crops 'probs' to 3 and 'channel' to 2 using `seed`."""
    return ops.random_crop(
        labeled_tensor, {'probs': 3, 'channel': 2}, seed=seed)

  def test_name(self):
    cropped_lt = ops.random_crop(self.original_lt, {'probs': 3})
    self.assertIn('lt_random_crop', cropped_lt.name)

  def test_single(self):
    cropped_lt = ops.random_crop(self.original_lt, {'probs': 3})
    # The cropped axis keeps only its new size, not its tick labels.
    expected_axes = core.Axes(
        [self.a0, self.a1, self.a2_resolved, ('probs', 3)])
    self.assertEqual(expected_axes, cropped_lt.axes)

  def test_double(self):
    cropped_lt = ops.random_crop(self.original_lt, {'probs': 3, 'channel': 2})
    expected_axes = core.Axes(
        [self.a0, ('channel', 2), self.a2_resolved, ('probs', 3)])
    self.assertEqual(expected_axes, cropped_lt.axes)

  def test_size1(self):
    cropped_lt = ops.random_crop(self.original_lt, {'probs': 1})
    expected_axes = core.Axes(
        [self.a0, self.a1, self.a2_resolved, ('probs', 1)])
    self.assertEqual(expected_axes, cropped_lt.axes)

  def test_different_seeds(self):
    first_lt = self._crop_probs_and_channel(self.original_lt, seed=0)
    second_lt = self._crop_probs_and_channel(self.original_lt, seed=1)

    self.assertEqual(first_lt.axes, second_lt.axes)
    first, second = self.eval([first_lt.tensor, second_lt.tensor])
    self.assertFalse((first == second).all())

  def test_identical_seeds(self):
    first_lt = self._crop_probs_and_channel(self.original_lt, seed=0)
    second_lt = self._crop_probs_and_channel(self.original_lt, seed=0)
    self.assertLabeledTensorsEqual(first_lt, second_lt)

  def test_crop_idempotent(self):
    # Cropping to the size the tensor already has leaves it unchanged.
    first_lt = self._crop_probs_and_channel(self.original_lt, seed=0)
    second_lt = self._crop_probs_and_channel(first_lt, seed=1)
    self.assertLabeledTensorsEqual(first_lt, second_lt)

  def test_invalid_input(self):
    with self.assertRaises(ValueError):
      ops.random_crop(self.original_lt, {'foobar': 2})
460
461
class MapFnTest(Base):
  """Tests for ops.map_fn."""

  def test_name(self):
    mapped_lt = ops.map_fn(core.identity, self.original_lt)
    self.assertIn('lt_map_fn', mapped_lt.name)

  def test_identity(self):
    mapped_lt = ops.map_fn(core.identity, self.original_lt)
    self.assertLabeledTensorsEqual(mapped_lt, self.original_lt)

  def test_callable_object(self):
    # Callable instances are accepted as well as plain functions.

    class PassThrough(object):

      def __call__(self, other):
        return other

    mapped_lt = ops.map_fn(PassThrough(), self.original_lt)
    self.assertLabeledTensorsEqual(mapped_lt, self.original_lt)

  def test_slice(self):

    def take_channel_1(entry_lt):
      return core.slice_function(entry_lt, {'channel': 1})

    mapped_lt = ops.map_fn(take_channel_1, self.original_lt)
    expected_lt = core.slice_function(self.original_lt, {'channel': 1})
    self.assertLabeledTensorsEqual(mapped_lt, expected_lt)

  def test_string(self):

    def append_world(entry_lt):
      joined = string_ops.string_join([entry_lt, 'world'])
      return core.LabeledTensor(joined, [])

    words_lt = ops.constant(['hi', 'bye'], axes=['batch'])
    mapped_lt = ops.map_fn(append_world, words_lt)
    expected_lt = ops.constant(['hiworld', 'byeworld'], axes=['batch'])
    self.assertLabeledTensorsEqual(mapped_lt, expected_lt)
499
500
class FoldlTest(Base):
  """Tests for ops.foldl."""

  def test_name(self):
    initial_lt = core.slice_function(self.original_lt, {'x': 0})
    folded_lt = ops.foldl(core.add, self.original_lt, initial_lt)
    self.assertIn('lt_foldl', folded_lt.name)

  def test_sum(self):
    start_lt = ops.constant([0, 10], axes=['y'])
    rows_lt = ops.constant([[1, 2], [3, 4], [5, 6]], axes=['x', 'y'])
    folded_lt = ops.foldl(core.add, rows_lt, start_lt)
    # [0, 10] + [1, 2] + [3, 4] + [5, 6] == [9, 22]
    expected_lt = ops.constant([9, 22], axes=['y'])
    self.assertLabeledTensorsEqual(folded_lt, expected_lt)
514
515
class SqueezeTest(Base):
  """Tests for ops.squeeze."""

  def setUp(self):
    super(SqueezeTest, self).setUp()

    # Keep the 'channel' and 'probs' axes but shrink both to length 1.
    length_one = {'channel': slice(0, 1), 'probs': slice(0, 1)}
    self.squeezable_lt = core.slice_function(self.original_lt, length_one)

  def test_name(self):
    squeezed_lt = ops.squeeze(self.squeezable_lt)
    self.assertIn('lt_squeeze', squeezed_lt.name)

  def test_none(self):
    # None squeezes every length-1 axis, here 'channel' and 'probs'.
    all_lt = ops.squeeze(self.squeezable_lt, None)
    explicit_lt = ops.squeeze(self.squeezable_lt, ['channel', 'probs'])
    self.assertLabeledTensorsEqual(all_lt, explicit_lt)

  def test(self):
    squeezed_lt = ops.squeeze(self.squeezable_lt, ['probs'])
    expected_lt = core.slice_function(self.squeezable_lt, {'probs': 0})
    self.assertLabeledTensorsEqual(squeezed_lt, expected_lt)

  def test_invalid_input(self):
    # 'channel' has length 3 in original_lt; 'foo' is not an axis.
    rejected_calls = [(self.original_lt, ['channel']),
                      (self.squeezable_lt, ['foo'])]
    for source_lt, axis_names in rejected_calls:
      with self.assertRaises(ValueError):
        ops.squeeze(source_lt, axis_names)
544
545
class MatMulTest(Base):
  """Tests for ops.matmul."""

  def test_name(self):
    ones_lt = core.LabeledTensor(array_ops.ones((3,)), ['x'])
    product_lt = ops.matmul(ones_lt, ones_lt)
    self.assertIn('lt_matmul', product_lt.name)

  def test_vector_vector(self):
    vector_lt = core.LabeledTensor(math_ops.range(3), ['x'])
    product_lt = ops.matmul(vector_lt, vector_lt)
    # 0*0 + 1*1 + 2*2 == 5
    expected_lt = core.convert_to_labeled_tensor(5)
    self.assertLabeledTensorsEqual(product_lt, expected_lt)

  def test_matrix_vector(self):
    xy_lt = core.LabeledTensor(
        array_ops.reshape(math_ops.range(6), (2, 3)), ['x', 'y'])
    y_lt = core.LabeledTensor(math_ops.range(3), ['y'])

    product_lt = ops.matmul(xy_lt, y_lt)
    y_column = array_ops.reshape(y_lt.tensor, (-1, 1))
    expected_lt = core.LabeledTensor(
        math_ops.matmul(xy_lt.tensor, y_column)[:, 0], ['x'])
    self.assertLabeledTensorsEqual(product_lt, expected_lt)

    # Operand order does not change the 'x'-indexed result.
    self.assertLabeledTensorsEqual(ops.matmul(y_lt, xy_lt), expected_lt)

  def test_matrix_matrix(self):
    xy_lt = core.LabeledTensor(
        array_ops.reshape(math_ops.range(6), (2, 3)), ['x', 'y'])
    yz_lt = core.LabeledTensor(
        array_ops.reshape(math_ops.range(12), (3, 4)), ['y', 'z'])

    product_lt = ops.matmul(xy_lt, yz_lt)
    expected_lt = core.LabeledTensor(
        math_ops.matmul(xy_lt.tensor, yz_lt.tensor), ['x', 'z'])
    self.assertLabeledTensorsEqual(product_lt, expected_lt)

    def reverse_axes(labeled_tensor):
      return core.transpose(labeled_tensor,
                            list(labeled_tensor.axes.keys())[::-1])

    # The axis order of either input does not affect the (x, z) result.
    self.assertLabeledTensorsEqual(
        ops.matmul(xy_lt, reverse_axes(yz_lt)), expected_lt)
    self.assertLabeledTensorsEqual(
        ops.matmul(reverse_axes(xy_lt), yz_lt), expected_lt)
    self.assertLabeledTensorsEqual(
        ops.matmul(reverse_axes(xy_lt), reverse_axes(yz_lt)), expected_lt)

    # Swapping the operands swaps the output axes.
    self.assertLabeledTensorsEqual(
        ops.matmul(yz_lt, xy_lt), reverse_axes(expected_lt))

  def test_matrix_matrix_axis_order(self):
    xy_lt = core.LabeledTensor(
        array_ops.reshape(math_ops.range(6), (2, 3)), ['x', 'y'])
    yz_lt = core.LabeledTensor(
        array_ops.reshape(math_ops.range(12), (3, 4)), ['y', 'z'])

    expected_lt = core.LabeledTensor(
        math_ops.matmul(xy_lt.tensor, yz_lt.tensor), ['x', 'z'])

    # Inside an axis_order_scope the output follows the scope's order for
    # either operand order.
    with core.axis_order_scope(['x', 'y', 'z']):
      self.assertLabeledTensorsEqual(ops.matmul(xy_lt, yz_lt), expected_lt)
      self.assertLabeledTensorsEqual(ops.matmul(yz_lt, xy_lt), expected_lt)

  def test_invalid(self):
    scalar_lt = core.LabeledTensor(array_ops.ones(()), [])
    x_lt = core.LabeledTensor(array_ops.ones((2,)), ['x'])
    x2_lt = core.LabeledTensor(array_ops.ones((3,)), ['x'])
    y_lt = core.LabeledTensor(array_ops.ones((3,)), ['y'])
    xy_lt = core.LabeledTensor(array_ops.ones((2, 3)), ['x', 'y'])
    xyz_lt = core.LabeledTensor(array_ops.ones((2, 3, 1)), ['x', 'y', 'z'])

    # (error type, message regexp or None, left operand, right operand)
    failing_cases = [
        (ValueError, 'inputs with at least rank', x_lt, scalar_lt),
        (NotImplementedError, None, x_lt, xyz_lt),
        (ValueError, 'exactly one axis in common', x_lt, y_lt),
        (NotImplementedError, None, xy_lt, xy_lt),
        (ValueError, 'does not match', x_lt, x2_lt),
    ]
    for error_type, pattern, left_lt, right_lt in failing_cases:
      if pattern is None:
        expectation = self.assertRaises(error_type)
      else:
        expectation = self.assertRaisesRegexp(error_type, pattern)
      with expectation:
        ops.matmul(left_lt, right_lt)
637
638
class ReduceSumTest(Base):
  """Tests for ops.reduce_sum."""

  def _expected_without_channel(self):
    """Reference result for summing the 'channel' axis away."""
    return core.LabeledTensor(
        math_ops.reduce_sum(self.original_lt.tensor, 1),
        [self.a0, self.a2, self.a3])

  def _expected_with_channel_label(self):
    """Reference result for summing 'channel' into one 'hihowareyou' tick."""
    return core.LabeledTensor(
        math_ops.reduce_sum(self.original_lt.tensor, 1, keepdims=True),
        [self.a0, ('channel', ['hihowareyou']), self.a2, self.a3])

  def test_name(self):
    total_lt = ops.reduce_sum(self.original_lt, {'channel'})
    self.assertIn('lt_reduce_sum', total_lt.name)

  def test_drop_axis(self):
    total_lt = ops.reduce_sum(self.original_lt, {'channel'})
    self.assertLabeledTensorsEqual(total_lt, self._expected_without_channel())

  def test_drop_scalar_axis(self):
    # A bare axis name behaves like a one-element set of names.
    total_lt = ops.reduce_sum(self.original_lt, 'channel')
    self.assertLabeledTensorsEqual(total_lt, self._expected_without_channel())

  def test_keep_axis(self):
    # An (axis, label) pair keeps the reduced axis with that single label.
    total_lt = ops.reduce_sum(self.original_lt, {('channel', 'hihowareyou')})
    self.assertLabeledTensorsEqual(total_lt,
                                   self._expected_with_channel_label())

  def test_keep_scalar_axis(self):
    total_lt = ops.reduce_sum(self.original_lt, ('channel', 'hihowareyou'))
    self.assertLabeledTensorsEqual(total_lt,
                                   self._expected_with_channel_label())

  def test_scalar(self):
    scalar_lt = core.LabeledTensor(constant_op.constant(42), [])
    total_lt = ops.reduce_sum(scalar_lt, [])
    self.assertLabeledTensorsEqual(total_lt, scalar_lt)

  def test_empty_list(self):
    # Reducing over no axes is a no-op.
    total_lt = ops.reduce_sum(self.original_lt, [])
    self.assertLabeledTensorsEqual(total_lt, self.original_lt)

  def test_none(self):
    # Omitting the axes reduces over all of them.
    total_lt = ops.reduce_sum(self.original_lt)
    expected_lt = core.LabeledTensor(
        math_ops.reduce_sum(self.original_lt.tensor), [])
    self.assertLabeledTensorsEqual(total_lt, expected_lt)

  def test_function_docstring_and_name(self):
    self.assertIn('tf.reduce_sum', ops.reduce_sum.__doc__)
    self.assertEqual('reduce_sum', ops.reduce_sum.__name__)
693
694
class ReduceMeanTest(Base):
  """Tests for ops.reduce_mean."""

  def test_name(self):
    mean_lt = ops.reduce_mean(self.original_lt, {'channel'})
    self.assertIn('lt_reduce_mean', mean_lt.name)

  def test(self):
    mean_lt = ops.reduce_mean(self.original_lt, {'channel'})
    reference = math_ops.reduce_mean(self.original_lt.tensor, 1)
    expected_lt = core.LabeledTensor(reference, [self.a0, self.a2, self.a3])
    self.assertLabeledTensorsEqual(mean_lt, expected_lt)
707
708
class ReduceProdTest(Base):
  """Tests for ops.reduce_prod."""

  def test_name(self):
    product_lt = ops.reduce_prod(self.original_lt, {'channel'})
    self.assertIn('lt_reduce_prod', product_lt.name)

  def test(self):
    product_lt = ops.reduce_prod(self.original_lt, {'channel'})
    reference = math_ops.reduce_prod(self.original_lt.tensor, 1)
    expected_lt = core.LabeledTensor(reference, [self.a0, self.a2, self.a3])
    self.assertLabeledTensorsEqual(product_lt, expected_lt)
721
722
class ReduceMinTest(Base):
  """Tests for ops.reduce_min."""

  def test_name(self):
    minimum_lt = ops.reduce_min(self.original_lt, {'channel'})
    self.assertIn('lt_reduce_min', minimum_lt.name)

  def test(self):
    minimum_lt = ops.reduce_min(self.original_lt, {'channel'})
    reference = math_ops.reduce_min(self.original_lt.tensor, 1)
    expected_lt = core.LabeledTensor(reference, [self.a0, self.a2, self.a3])
    self.assertLabeledTensorsEqual(minimum_lt, expected_lt)
735
736
class ReduceMaxTest(Base):
  """Tests for ops.reduce_max."""

  def test_name(self):
    maximum_lt = ops.reduce_max(self.original_lt, {'channel'})
    self.assertIn('lt_reduce_max', maximum_lt.name)

  def test(self):
    maximum_lt = ops.reduce_max(self.original_lt, {'channel'})
    reference = math_ops.reduce_max(self.original_lt.tensor, 1)
    expected_lt = core.LabeledTensor(reference, [self.a0, self.a2, self.a3])
    self.assertLabeledTensorsEqual(maximum_lt, expected_lt)
749
750
class BaseReduceBoolean(Base):
  """Fixture adding a boolean copy of original_lt for reduce_all/any tests."""

  def setUp(self):
    super(BaseReduceBoolean, self).setUp()
    # True wherever the underlying value is greater than 5.
    above_five = self.original_lt.tensor > 5
    self.bool_tensor = math_ops.cast(above_five, dtypes.bool)
    self.bool_lt = core.LabeledTensor(self.bool_tensor, self.original_lt.axes)
757
758
class ReduceAllTest(BaseReduceBoolean):
  """Tests for ops.reduce_all."""

  def test_name(self):
    all_lt = ops.reduce_all(self.bool_lt, {'channel'})
    self.assertIn('lt_reduce_all', all_lt.name)

  def test(self):
    all_lt = ops.reduce_all(self.bool_lt, {'channel'})
    reference = math_ops.reduce_all(self.bool_tensor, 1)
    expected_lt = core.LabeledTensor(reference, [self.a0, self.a2, self.a3])
    self.assertLabeledTensorsEqual(all_lt, expected_lt)
770
771
class ReduceAnyTest(BaseReduceBoolean):
  """Tests for ops.reduce_any."""

  def test_name(self):
    any_lt = ops.reduce_any(self.bool_lt, {'channel'})
    self.assertIn('lt_reduce_any', any_lt.name)

  def test(self):
    any_lt = ops.reduce_any(self.bool_lt, {'channel'})
    reference = math_ops.reduce_any(self.bool_tensor, 1)
    expected_lt = core.LabeledTensor(reference, [self.a0, self.a2, self.a3])
    self.assertLabeledTensorsEqual(any_lt, expected_lt)
783
784
class TileTest(Base):
  """Tests for ops.tile."""

  def test_name(self):
    tiled_lt = ops.tile(self.original_lt, {'z': 2})
    self.assertIn('lt_tile', tiled_lt.name)

  def test(self):
    # The multiple may be given as a Python int or as a tensor.
    for factor in [2, constant_op.constant(2)]:
      tiled_lt = ops.tile(self.original_lt, {'z': factor})
      expected_tensor = array_ops.tile(self.original_lt.tensor,
                                       [1, 1, factor, 1])
      # Only 'z' changes; give it as a bare name so its size is taken from
      # the tiled tensor.
      expected_axes = [
          'z' if axis.name == 'z' else axis
          for axis in self.original_lt.axes.values()
      ]
      expected_lt = core.LabeledTensor(expected_tensor, expected_axes)
      self.assertLabeledTensorsEqual(tiled_lt, expected_lt)

  def test_invalid_input(self):
    with self.assertRaisesRegexp(ValueError, 'are not contained in the set'):
      ops.tile(self.original_lt, {'foo': 5})
    # 'x' carries tick labels, which cannot be tiled.
    with self.assertRaisesRegexp(ValueError, 'axes with tick labels'):
      ops.tile(self.original_lt, {'x': 5})
807
808
class PadTest(Base):
  """Tests for ops.pad."""

  def _padding(self):
    """Pads 'x' by one on each side and appends an 'alpha' channel."""
    return {'x': (1, 1), 'channel': ([], ['alpha'])}

  def test_name(self):
    padded_lt = ops.pad(self.original_lt, self._padding())
    self.assertIn('lt_pad', padded_lt.name)

  def test(self):
    padded_lt = ops.pad(self.original_lt, self._padding())

    expected_tensor = array_ops.pad(self.original_lt.tensor,
                                    [[1, 1], [0, 1], [0, 0], [0, 0]])
    # 'x' grows by two and loses its tick labels; 'channel' gains the
    # 'alpha' label at the end.
    expected_axes = [('x', self.x_size + 2),
                     ('channel', ['red', 'green', 'blue', 'alpha']), self.a2,
                     self.a3]
    expected_lt = core.LabeledTensor(expected_tensor, expected_axes)
    self.assertLabeledTensorsEqual(padded_lt, expected_lt)

  def test_invalid_input(self):
    with self.assertRaisesRegexp(ValueError, 'are not contained in the set'):
      ops.pad(self.original_lt, {'foo': (1, 1), 'channel': ([], ['alpha'])})
833
834
class ConstantTest(Base):
  """Tests for ops.constant, checked against constant_op.constant."""

  def test_name(self):
    lt = ops.constant(1)
    self.assertIn('lt_constant', lt.name)

  def test_scalar(self):
    actual = ops.constant(1)
    expected = core.LabeledTensor(constant_op.constant(1), [])
    self.assertLabeledTensorsEqual(actual, expected)

  def test_infer_shape(self):
    # Only the axis name is supplied; its size comes from the value.
    actual = ops.constant([1, 2], axes=['x'])
    expected = core.LabeledTensor(constant_op.constant([1, 2]), ['x'])
    self.assertLabeledTensorsEqual(actual, expected)

  def test_specify_shape(self):
    # A scalar value with an explicitly sized axis yields a length-3 fill.
    actual = ops.constant(1, axes=[('x', 3)])
    expected = core.LabeledTensor(constant_op.constant(1, shape=(3,)), ['x'])
    self.assertLabeledTensorsEqual(actual, expected)

  def test_existing_axes(self):
    # Axes may be taken directly from another LabeledTensor.
    expected = core.LabeledTensor(constant_op.constant([1, 2]), ['x'])
    actual = ops.constant([1, 2], axes=expected.axes)
    self.assertLabeledTensorsEqual(actual, expected)
860
861
class ZerosLikeTest(Base):
  """Tests for ops.zeros_like."""

  def test_name(self):
    zeros = ops.zeros_like(self.original_lt)
    self.assertIn('lt_zeros_like', zeros.name)

  def test(self):
    source = self.original_lt
    zeros = ops.zeros_like(source)
    # Values match the plain-tensor op; the input's axes are kept.
    expected_tensor = array_ops.zeros_like(source.tensor)
    self.assertLabeledTensorsEqual(
        zeros, core.LabeledTensor(expected_tensor, source.axes))
873
874
class OnesLikeTest(Base):
  """Tests for ops.ones_like."""

  def test_name(self):
    ones = ops.ones_like(self.original_lt)
    self.assertIn('lt_ones_like', ones.name)

  def test(self):
    source = self.original_lt
    ones = ops.ones_like(source)
    # Values match the plain-tensor op; the input's axes are kept.
    expected_tensor = array_ops.ones_like(source.tensor)
    self.assertLabeledTensorsEqual(
        ones, core.LabeledTensor(expected_tensor, source.axes))
886
887
class CastTest(Base):
  """Tests for ops.cast to float16."""

  def test_name(self):
    casted = ops.cast(self.original_lt, dtypes.float16)
    self.assertIn('lt_cast', casted.name)

  def test(self):
    target_dtype = dtypes.float16
    source = self.original_lt
    casted = ops.cast(source, target_dtype)
    # Values match math_ops.cast on the raw tensor; axes are unchanged.
    expected_tensor = math_ops.cast(source.tensor, target_dtype)
    self.assertLabeledTensorsEqual(
        casted, core.LabeledTensor(expected_tensor, source.axes))
900
901
class VerifyTensorAllFiniteTest(Base):
  """Tests for ops.verify_tensor_all_finite on a finite and a NaN scalar."""

  def setUp(self):
    super(VerifyTensorAllFiniteTest, self).setUp()

    # Axis-less scalar inputs: one finite value and one NaN.
    self.finite_lt = core.LabeledTensor(constant_op.constant(42.0), [])
    self.nan_lt = core.LabeledTensor(constant_op.constant(np.nan), [])

    # Wrap each input in the finiteness check (empty message).
    self.checked_finite_lt = ops.verify_tensor_all_finite(self.finite_lt, '')
    self.checked_nan_lt = ops.verify_tensor_all_finite(self.nan_lt, '')

  def test_name(self):
    for checked in (self.checked_finite_lt, self.checked_nan_lt):
      self.assertIn('lt_verify_tensor_all_finite', checked.name)

  def test_finite(self):
    # A finite input should come through the check unchanged.
    self.assertLabeledTensorsEqual(self.finite_lt, self.checked_finite_lt)

  def test_nan(self):
    # Evaluating the checked NaN tensor must fail at run time.
    expected_error = errors_impl.InvalidArgumentError
    with self.assertRaisesRegexp(expected_error, 'Tensor had NaN values'):
      self.eval([self.checked_nan_lt])
924
925
class BooleanMaskTest(Base):
  """Tests for ops.boolean_mask along the 'x' axis.

  Mask sizes are derived from the fixture (self.x_size, self.channel_size)
  rather than hard-coded, so they stay aligned with original_lt's axes.
  """

  def _x_mask(self, axes):
    """Returns a rank-1 mask of length x_size that is True for indices > 3."""
    return core.LabeledTensor(math_ops.range(self.x_size) > 3, axes)

  def test_name(self):
    mask = self._x_mask([self.a0])
    masked_lt = ops.boolean_mask(self.original_lt, mask)
    self.assertIn('lt_boolean_mask', masked_lt.name)

  def test(self):
    mask = self._x_mask([self.a0])
    masked_lt = ops.boolean_mask(self.original_lt, mask)
    # Masking removes entries along 'x', so the expected 'x' axis is given by
    # name only; the remaining axes are unchanged.
    golden_lt = core.LabeledTensor(
        array_ops.boolean_mask(self.original_lt.tensor, mask.tensor),
        ['x', self.a1, self.a2, self.a3])
    self.assertLabeledTensorsEqual(masked_lt, golden_lt)

  def test_invalid_rank(self):
    # A rank-2 mask over ('x', 'channel') is expected to be rejected with
    # NotImplementedError.
    mask = core.LabeledTensor(
        array_ops.ones((self.x_size, self.channel_size)) > 3,
        [self.a0, self.a1])
    with self.assertRaises(NotImplementedError):
      ops.boolean_mask(self.original_lt, mask)

  def test_mismatched_axis(self):
    # The mask has the same length as 'x' but is labeled 'foo', so the only
    # mismatch under test is the axis name.
    mask = self._x_mask(['foo'])
    with self.assertRaisesRegexp(ValueError, 'not equal'):
      ops.boolean_mask(self.original_lt, mask)
950
951
class WhereTest(Base):
  """Tests for ops.where on standalone rank-1 tensors labeled 'x'."""

  def test_name(self):
    cond = core.LabeledTensor(math_ops.range(5) < 3, ['x'])
    result = ops.where(cond, cond, cond)
    self.assertIn('lt_where', result.name)

  def test(self):
    cond = core.LabeledTensor(math_ops.range(5) < 3, ['x'])
    if_true = core.LabeledTensor(array_ops.ones(5), ['x'])
    if_false = core.LabeledTensor(array_ops.zeros(5), ['x'])
    result = ops.where(cond, if_true, if_false)

    # cond is True for the first three of five positions, so the expected
    # output is three ones followed by two zeros.
    expected_values = array_ops.concat(
        [array_ops.ones(3), array_ops.zeros(2)], 0)
    self.assertLabeledTensorsEqual(
        result, core.LabeledTensor(expected_values, ['x']))

  def test_mismatched_axes(self):
    cond = core.LabeledTensor(math_ops.range(5) < 3, ['x'])
    # Truncating either branch to length 3 makes its axes differ from cond's.
    # The slice is taken lazily, inside each assertRaises context.
    calls = [
        lambda: ops.where(cond, cond[:3], cond),
        lambda: ops.where(cond, cond, cond[:3]),
    ]
    for call in calls:
      with self.assertRaisesRegexp(ValueError, 'equal axes'):
        call()
975
976
if __name__ == '__main__':
  # Run the test cases in this module with the TensorFlow test runner
  # (tensorflow.python.platform.test).
  test_lib.main()
979