1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import numpy as np 21from six.moves import range # pylint: disable=redefined-builtin 22 23from tensorflow.contrib.labeled_tensor.python.ops import core 24from tensorflow.contrib.labeled_tensor.python.ops import ops 25from tensorflow.contrib.labeled_tensor.python.ops import test_util 26from tensorflow.python.framework import constant_op 27from tensorflow.python.framework import dtypes 28from tensorflow.python.framework import errors_impl 29from tensorflow.python.ops import array_ops 30from tensorflow.python.ops import math_ops 31from tensorflow.python.ops import string_ops 32from tensorflow.python.platform import test as test_lib 33 34 35class Base(test_util.Base): 36 37 def setUp(self): 38 super(Base, self).setUp() 39 40 self.x_size = 7 41 self.channel_size = 3 42 self.z_size = 4 43 self.probs_size = 11 44 45 tensor = math_ops.range(0, self.x_size * self.channel_size * self.z_size * 46 self.probs_size) 47 tensor = array_ops.reshape( 48 tensor, [self.x_size, self.channel_size, self.z_size, self.probs_size]) 49 a0 = ('x', range(self.x_size)) 50 a1 = ('channel', ['red', 'green', 'blue']) 51 a2 = 'z' 52 a3 = ('probs', np.linspace(0.0, 1.0, self.probs_size)) 53 54 self.tensor = tensor 55 self.a0 = a0 56 self.a1 = a1 57 self.a2 = a2 58 self.a2_resolved = ('z', self.z_size) 59 self.a3 = a3 60 self.original_lt = core.LabeledTensor(tensor, [a0, a1, a2, a3]) 61 62 self.x_probs_lt = core.slice_function(self.original_lt, {'z': 0}) 63 self.x_probs_lt = ops.select(self.x_probs_lt, {'channel': 'red'}) 64 self.channel_probs_lt = core.slice_function(self.original_lt, 65 {'x': 3, 66 'z': 0}) 67 68 69class SelectTest(Base): 70 71 def test_name(self): 72 select_lt = ops.select(self.original_lt, {'channel': 'green'}) 73 self.assertIn('lt_select', select_lt.name) 74 75 def test_scalar(self): 76 select_lt = ops.select(self.original_lt, {'channel': 'green'}) 77 golden_lt = core.LabeledTensor(self.tensor[:, 1, :, :], 78 [self.a0, self.a2, self.a3]) 79 self.assertLabeledTensorsEqual(select_lt, golden_lt) 80 81 def test_slice(self): 82 select_lt = ops.select(self.original_lt, {'channel': slice('red', 'green')}) 83 a1_sliced = ('channel', ['red', 'green']) 84 golden_lt = core.LabeledTensor(self.tensor[:, :2, :, :], 85 [self.a0, a1_sliced, self.a2, self.a3]) 86 self.assertLabeledTensorsEqual(select_lt, golden_lt) 87 88 def test_slices(self): 89 select_lt = ops.select(self.original_lt, 90 {'x': slice(1, 4), 91 'channel': slice('green', None)}) 92 93 a0_sliced = ('x', range(1, 5)) 94 a1_sliced = ('channel', ['green', 'blue']) 95 golden_lt = core.LabeledTensor(self.tensor[1:5, 1:, :, :], 96 [a0_sliced, a1_sliced, self.a2, self.a3]) 97 self.assertLabeledTensorsEqual(select_lt, golden_lt) 98 99 def test_list(self): 100 select_lt = ops.select(self.original_lt, {'channel': ['red', 'green']}) 101 a1_sliced = ('channel', ['red', 'green']) 102 golden_lt = core.LabeledTensor(self.tensor[:, :2, :, :], 103 [self.a0, a1_sliced, self.a2, self.a3]) 104 self.assertLabeledTensorsEqual(select_lt, golden_lt) 105 106 def test_list_one_item(self): 107 select_lt = ops.select(self.original_lt, {'channel': ['red']}) 108 a1_sliced = ('channel', ['red']) 109 golden_lt = core.LabeledTensor(self.tensor[:, :1, :, :], 110 [self.a0, a1_sliced, self.a2, self.a3]) 111 self.assertLabeledTensorsEqual(select_lt, golden_lt) 112 113 def test_list_zero_items(self): 114 select_lt = ops.select(self.original_lt, {'channel': []}) 115 golden_lt = core.LabeledTensor(self.tensor[:, :0, :, :], 116 [self.a0, 'channel', self.a2, self.a3]) 117 self.assertLabeledTensorsEqual(select_lt, golden_lt) 118 119 def test_scalars(self): 120 select_lt = ops.select(self.original_lt, {'x': 1, 'channel': 'green'}) 121 golden_lt = core.LabeledTensor(self.tensor[1, 1, :, :], [self.a2, self.a3]) 122 self.assertLabeledTensorsEqual(select_lt, golden_lt) 123 124 def test_tuple(self): 125 original_lt = core.LabeledTensor(constant_op.constant([5, 6]), 126 [('x', [(1, 2), (3, 4)])]) 127 select_lt = ops.select(original_lt, {'x': (1, 2)}) 128 golden_lt = core.LabeledTensor(constant_op.constant(5), []) 129 self.assertLabeledTensorsEqual(select_lt, golden_lt) 130 131 def test_invalid_input(self): 132 with self.assertRaises(ValueError): 133 ops.select(self.original_lt, {'foo': 1}) 134 with self.assertRaises(ValueError): 135 ops.select(self.original_lt, {'z': 1}) 136 with self.assertRaises(KeyError): 137 ops.select(self.original_lt, {'channel': 'purple'}) 138 with self.assertRaises(KeyError): 139 ops.select(self.original_lt, {'channel': ['red', 'purple']}) 140 with self.assertRaises(NotImplementedError): 141 ops.select(self.original_lt, {'channel': ['red'], 'x': [1]}) 142 with self.assertRaises(NotImplementedError): 143 ops.select(self.original_lt, {'channel': ['red'], 'x': 1}) 144 with self.assertRaises(NotImplementedError): 145 ops.select(self.original_lt, {'channel': slice('red', 'green', 2)}) 146 147 148class ConcatTest(Base): 149 150 def setUp(self): 151 super(ConcatTest, self).setUp() 152 153 self.red_lt = ops.select(self.original_lt, {'channel': ['red']}) 154 self.green_lt = ops.select(self.original_lt, {'channel': ['green']}) 155 self.blue_lt = ops.select(self.original_lt, {'channel': ['blue']}) 156 157 def test_name(self): 158 concat_lt = ops.concat([self.red_lt, self.blue_lt], 'channel') 159 self.assertIn('lt_concat', concat_lt.name) 160 161 def test(self): 162 concat_lt = ops.concat([self.red_lt, self.green_lt], 'channel') 163 golden_lt = ops.select(self.original_lt, {'channel': ['red', 'green']}) 164 165 self.assertLabeledTensorsEqual(concat_lt, golden_lt) 166 167 def test_transposed(self): 168 green_transposed = core.transpose(self.green_lt, 169 ['probs', 'channel', 'z', 'x']) 170 with self.assertRaises(ValueError): 171 ops.concat([self.red_lt, green_transposed], 'channel') 172 173 def test_invalid_input(self): 174 with self.assertRaises(ValueError): 175 ops.concat([], 'channel') 176 with self.assertRaises(ValueError): 177 ops.concat([self.red_lt, self.red_lt], 'channel') 178 with self.assertRaises(ValueError): 179 ops.concat([self.red_lt, self.red_lt], 'foo') 180 181 182class PackTest(Base): 183 184 def test_name(self): 185 pack_lt = ops.pack([self.original_lt, self.original_lt], 'batch') 186 self.assertIn('lt_pack', pack_lt.name) 187 188 def test(self): 189 pack_lt = ops.pack([self.original_lt, self.original_lt], 'batch') 190 golden_lt = core.LabeledTensor( 191 array_ops.stack([self.original_lt.tensor, self.original_lt.tensor]), 192 ['batch', self.a0, self.a1, self.a2, self.a3]) 193 194 self.assertLabeledTensorsEqual(pack_lt, golden_lt) 195 196 def test_axis(self): 197 pack_lt = ops.pack( 198 [self.original_lt, self.original_lt], new_axis='batch', axis_position=4) 199 golden_lt = core.LabeledTensor( 200 array_ops.stack( 201 [self.original_lt.tensor, self.original_lt.tensor], axis=4), 202 [self.a0, self.a1, self.a2, self.a3, 'batch']) 203 204 self.assertLabeledTensorsEqual(pack_lt, golden_lt) 205 206 def test_invalid_input(self): 207 with self.assertRaises(ValueError): 208 ops.pack([self.original_lt, self.original_lt], 'channel') 209 210 211class UnpackTest(Base): 212 213 def test_name(self): 214 unpack_lts = ops.unpack(self.original_lt) 215 for t in unpack_lts: 216 self.assertIn('lt_unpack', t.name) 217 218 def test(self): 219 unpack_lt = ops.unpack(self.original_lt)[0] 220 golden_lt = core.LabeledTensor( 221 array_ops.unstack(self.original_lt.tensor)[0], 222 [self.a1, self.a2, self.a3]) 223 224 self.assertLabeledTensorsEqual(unpack_lt, golden_lt) 225 226 def test_axis(self): 227 unpack_lt = ops.unpack(self.original_lt, axis_name='z')[0] 228 golden_lt = core.LabeledTensor( 229 array_ops.unstack( 230 self.original_lt.tensor, axis=2)[0], [self.a0, self.a1, self.a3]) 231 232 self.assertLabeledTensorsEqual(unpack_lt, golden_lt) 233 234 def test_invalid_input(self): 235 with self.assertRaises(ValueError): 236 ops.unpack(self.original_lt, axis_name='not_found') 237 238 239class ReshapeTest(Base): 240 241 def test_name(self): 242 reshape_lt = ops.reshape(self.original_lt, ['channel'], ['foo']) 243 self.assertIn('lt_reshape', reshape_lt.name) 244 245 def test_identity(self): 246 reshape_lt = ops.reshape(self.original_lt, 247 self.original_lt.axes.keys(), 248 self.original_lt.axes.values()) 249 self.assertLabeledTensorsEqual(reshape_lt, self.original_lt) 250 251 def test_known_size(self): 252 new_dim_size = self.channel_size * self.z_size * self.probs_size 253 reshape_lt = ops.reshape(self.original_lt, ['channel', 'z', 'probs'], 254 [('new_dim', new_dim_size)]) 255 golden_lt = core.LabeledTensor( 256 array_ops.reshape(self.original_lt.tensor, [self.x_size, -1]), 257 [self.original_lt.axes['x'], 'new_dim']) 258 self.assertLabeledTensorsEqual(reshape_lt, golden_lt) 259 260 def test_unknown_size(self): 261 reshape_lt = ops.reshape(self.original_lt, ['channel', 'z', 'probs'], 262 ['new_dim']) 263 golden_lt = core.LabeledTensor( 264 array_ops.reshape(self.original_lt.tensor, [self.x_size, -1]), 265 [self.original_lt.axes['x'], 'new_dim']) 266 self.assertLabeledTensorsEqual(reshape_lt, golden_lt) 267 268 def test_unknown_dimension(self): 269 orig_lt = core.LabeledTensor( 270 array_ops.placeholder(dtypes.float32, [None]), ['x']) 271 reshape_lt = ops.reshape(orig_lt, ['x'], ['y', ('z', 1)]) 272 self.assertEqual(reshape_lt.axes, core.Axes([('y', None), ('z', 1)])) 273 with self.cached_session() as sess: 274 result = sess.run(reshape_lt, feed_dict={orig_lt.tensor: [1, 2]}) 275 np.testing.assert_array_equal(result, [[1], [2]]) 276 277 def test_with_labels(self): 278 new_dim_size = self.channel_size * self.z_size * self.probs_size 279 reshape_lt = ops.reshape(self.original_lt, ['channel', 'z', 'probs'], 280 [('new_dim', range(new_dim_size))]) 281 golden_lt = core.LabeledTensor( 282 array_ops.reshape(self.original_lt.tensor, [self.x_size, -1]), 283 [self.original_lt.axes['x'], ('new_dim', range(new_dim_size))]) 284 self.assertLabeledTensorsEqual(reshape_lt, golden_lt) 285 286 def test_invalid_input(self): 287 with self.assertRaisesRegexp(ValueError, 'not contained in the set'): 288 ops.reshape(self.original_lt, ['foo'], ['bar']) 289 with self.assertRaisesRegexp(core.AxisOrderError, 290 'not a slice of axis names'): 291 ops.reshape(self.original_lt, ['probs', 'z'], ['bar']) 292 with self.assertRaisesRegexp(ValueError, 'at most one axis in new_axes'): 293 ops.reshape(self.original_lt, ['probs'], ['foo', 'bar']) 294 295 296class RenameAxisTest(Base): 297 298 def test_name(self): 299 rename_axis_lt = ops.rename_axis(self.original_lt, 'channel', 'foo') 300 self.assertIn('lt_rename_axis', rename_axis_lt.name) 301 302 def test_identity(self): 303 rename_axis_lt = ops.rename_axis(self.original_lt, 'channel', 'channel') 304 self.assertLabeledTensorsEqual(rename_axis_lt, self.original_lt) 305 306 def test_new_name(self): 307 rename_axis_lt = ops.rename_axis(self.original_lt, 'channel', 'foo') 308 expected_axes = [(name if name != 'channel' else 'foo', axis.value) 309 for name, axis in self.original_lt.axes.items()] 310 expected_lt = core.LabeledTensor(self.original_lt.tensor, expected_axes) 311 self.assertLabeledTensorsEqual(rename_axis_lt, expected_lt) 312 313 def test_invalid_input(self): 314 with self.assertRaisesRegexp(ValueError, 'not contained in the set'): 315 ops.rename_axis(self.original_lt, 'foo', 'bar') 316 317 318class BatchTest(Base): 319 320 def setUp(self): 321 super(BatchTest, self).setUp() 322 323 tensors = [] 324 for i in range(10): 325 offset_lt = core.LabeledTensor(constant_op.constant(i), []) 326 tensors.append(core.add(self.original_lt, offset_lt)) 327 self.pack_lt = ops.pack(tensors, 'batch') 328 329 def test_name(self): 330 batch_ops = ops.batch( 331 [self.pack_lt, self.pack_lt], batch_size=2, enqueue_many=True) 332 for bo in batch_ops: 333 self.assertIn('lt_batch', bo.name) 334 335 def test_enqueue_many(self): 336 [batch_2_op] = ops.batch([self.pack_lt], batch_size=2, enqueue_many=True) 337 self.assertEqual(len(batch_2_op.axes['batch']), 2) 338 339 [batch_10_op] = ops.batch([batch_2_op], batch_size=10, enqueue_many=True) 340 341 self.assertLabeledTensorsEqual(self.pack_lt, batch_10_op) 342 343 def test_no_enqueue_many(self): 344 [batch_2_op] = ops.batch([self.original_lt], batch_size=2) 345 self.assertEqual(len(batch_2_op.axes['batch']), 2) 346 347 [batch_10_op] = ops.batch([batch_2_op], batch_size=10, enqueue_many=True) 348 349 self.assertLabeledTensorsEqual( 350 ops.pack(10 * [self.original_lt], 'batch'), batch_10_op) 351 352 def test_invalid_input(self): 353 with self.assertRaises(ValueError): 354 ops.batch([self.original_lt], 3, enqueue_many=True) 355 356 def test_allow_smaller_final_batch(self): 357 [batch_2_op] = ops.batch( 358 [self.original_lt], batch_size=2, allow_smaller_final_batch=True) 359 self.assertEqual(batch_2_op.axes['batch'].size, None) 360 361 362class ShuffleBatchTest(Base): 363 364 def setUp(self): 365 super(ShuffleBatchTest, self).setUp() 366 367 tensors = [] 368 for i in range(10): 369 offset_lt = core.LabeledTensor(constant_op.constant(i), []) 370 tensors.append(core.add(self.original_lt, offset_lt)) 371 self.pack_lt = ops.pack(tensors, 'batch') 372 373 def test_name(self): 374 batch_lts = ops.shuffle_batch( 375 [self.pack_lt, self.pack_lt], batch_size=2, enqueue_many=True) 376 for blt in batch_lts: 377 self.assertIn('lt_shuffle_batch', blt.name) 378 379 def test_enqueue_many(self): 380 [batch_2_lt] = ops.shuffle_batch( 381 [self.pack_lt], 382 batch_size=2, 383 enqueue_many=True, 384 min_after_dequeue=8, 385 seed=0) 386 self.assertEqual(len(batch_2_lt.axes['batch']), 2) 387 388 [batch_10_lt] = ops.batch([batch_2_lt], batch_size=10, enqueue_many=True) 389 390 self.assertEqual(batch_10_lt.axes, self.pack_lt.axes) 391 [batch_10, pack] = self.eval([batch_10_lt.tensor, self.pack_lt.tensor]) 392 self.assertFalse((batch_10 == pack).all()) 393 394 def test_allow_smaller_final_batch(self): 395 [batch_2_op] = ops.shuffle_batch( 396 [self.original_lt], batch_size=2, allow_smaller_final_batch=True) 397 self.assertEqual(batch_2_op.axes['batch'].size, None) 398 399 400class RandomCropTest(Base): 401 402 def test_name(self): 403 crop_lt = ops.random_crop(self.original_lt, {'probs': 3}) 404 self.assertIn('lt_random_crop', crop_lt.name) 405 406 def test_single(self): 407 crop_lt = ops.random_crop(self.original_lt, {'probs': 3}) 408 409 self.assertEqual( 410 core.Axes([self.a0, self.a1, self.a2_resolved, ('probs', 3)]), 411 crop_lt.axes) 412 413 def test_double(self): 414 crop_lt = ops.random_crop(self.original_lt, {'probs': 3, 'channel': 2}) 415 416 self.assertEqual( 417 core.Axes([self.a0, ('channel', 2), self.a2_resolved, ('probs', 3)]), 418 crop_lt.axes) 419 420 def test_size1(self): 421 crop_lt = ops.random_crop(self.original_lt, {'probs': 1}) 422 423 self.assertEqual( 424 core.Axes([self.a0, self.a1, self.a2_resolved, ('probs', 1)]), 425 crop_lt.axes) 426 427 def test_different_seeds(self): 428 crop_0_lt = ops.random_crop( 429 self.original_lt, {'probs': 3, 430 'channel': 2}, seed=0) 431 crop_1_lt = ops.random_crop( 432 self.original_lt, {'probs': 3, 433 'channel': 2}, seed=1) 434 435 self.assertEqual(crop_0_lt.axes, crop_1_lt.axes) 436 [crop_0, crop_1] = self.eval([crop_0_lt.tensor, crop_1_lt.tensor]) 437 self.assertFalse((crop_0 == crop_1).all()) 438 439 def test_identical_seeds(self): 440 crop_0_lt = ops.random_crop( 441 self.original_lt, {'probs': 3, 442 'channel': 2}, seed=0) 443 crop_1_lt = ops.random_crop( 444 self.original_lt, {'probs': 3, 445 'channel': 2}, seed=0) 446 447 self.assertLabeledTensorsEqual(crop_0_lt, crop_1_lt) 448 449 def test_crop_idempotent(self): 450 crop_0_lt = ops.random_crop( 451 self.original_lt, {'probs': 3, 452 'channel': 2}, seed=0) 453 crop_1_lt = ops.random_crop(crop_0_lt, {'probs': 3, 'channel': 2}, seed=1) 454 455 self.assertLabeledTensorsEqual(crop_0_lt, crop_1_lt) 456 457 def test_invalid_input(self): 458 with self.assertRaises(ValueError): 459 ops.random_crop(self.original_lt, {'foobar': 2}) 460 461 462class MapFnTest(Base): 463 464 def test_name(self): 465 map_lt = ops.map_fn(core.identity, self.original_lt) 466 self.assertIn('lt_map_fn', map_lt.name) 467 468 def test_identity(self): 469 map_lt = ops.map_fn(core.identity, self.original_lt) 470 self.assertLabeledTensorsEqual(map_lt, self.original_lt) 471 472 def test_callable_object(self): 473 474 class Identity(object): 475 476 def __call__(self, other): 477 return other 478 479 map_lt = ops.map_fn(Identity(), self.original_lt) 480 self.assertLabeledTensorsEqual(map_lt, self.original_lt) 481 482 def test_slice(self): 483 map_lt = ops.map_fn(lambda t: core.slice_function(t, {'channel': 1}), 484 self.original_lt) 485 slice_lt = core.slice_function(self.original_lt, {'channel': 1}) 486 self.assertLabeledTensorsEqual(map_lt, slice_lt) 487 488 def test_string(self): 489 490 def fn(entry_lt): 491 op = string_ops.string_join([entry_lt, 'world']) 492 return core.LabeledTensor(op, []) 493 494 tensor_lt = ops.constant(['hi', 'bye'], axes=['batch']) 495 map_lt = ops.map_fn(fn, tensor_lt) 496 golden_lt = ops.constant(['hiworld', 'byeworld'], axes=['batch']) 497 498 self.assertLabeledTensorsEqual(map_lt, golden_lt) 499 500 501class FoldlTest(Base): 502 503 def test_name(self): 504 foldl_lt = ops.foldl(core.add, self.original_lt, 505 core.slice_function(self.original_lt, {'x': 0})) 506 self.assertIn('lt_foldl', foldl_lt.name) 507 508 def test_sum(self): 509 initializer_lt = ops.constant([0, 10], axes=['y']) 510 tensor_lt = ops.constant([[1, 2], [3, 4], [5, 6]], axes=['x', 'y']) 511 foldl_lt = ops.foldl(core.add, tensor_lt, initializer_lt) 512 golden_lt = ops.constant([9, 22], axes=['y']) 513 self.assertLabeledTensorsEqual(foldl_lt, golden_lt) 514 515 516class SqueezeTest(Base): 517 518 def setUp(self): 519 super(SqueezeTest, self).setUp() 520 521 self.squeezable_lt = core.slice_function( 522 self.original_lt, {'channel': slice(0, 1), 523 'probs': slice(0, 1)}) 524 525 def test_name(self): 526 squeeze_lt = ops.squeeze(self.squeezable_lt) 527 self.assertIn('lt_squeeze', squeeze_lt.name) 528 529 def test_none(self): 530 none_lt = ops.squeeze(self.squeezable_lt, None) 531 axes_lt = ops.squeeze(self.squeezable_lt, ['channel', 'probs']) 532 self.assertLabeledTensorsEqual(none_lt, axes_lt) 533 534 def test(self): 535 squeeze_lt = ops.squeeze(self.squeezable_lt, ['probs']) 536 golden_lt = core.slice_function(self.squeezable_lt, {'probs': 0}) 537 self.assertLabeledTensorsEqual(squeeze_lt, golden_lt) 538 539 def test_invalid_input(self): 540 with self.assertRaises(ValueError): 541 ops.squeeze(self.original_lt, ['channel']) 542 with self.assertRaises(ValueError): 543 ops.squeeze(self.squeezable_lt, ['foo']) 544 545 546class MatMulTest(Base): 547 548 def test_name(self): 549 x_lt = core.LabeledTensor(array_ops.ones((3,)), ['x']) 550 matmul_lt = ops.matmul(x_lt, x_lt) 551 self.assertIn('lt_matmul', matmul_lt.name) 552 553 def test_vector_vector(self): 554 x_lt = core.LabeledTensor(math_ops.range(3), ['x']) 555 matmul_lt = ops.matmul(x_lt, x_lt) 556 golden_lt = core.convert_to_labeled_tensor(5) 557 self.assertLabeledTensorsEqual(matmul_lt, golden_lt) 558 559 def test_matrix_vector(self): 560 xy_lt = core.LabeledTensor( 561 array_ops.reshape(math_ops.range(6), (2, 3)), ['x', 'y']) 562 y_lt = core.LabeledTensor(math_ops.range(3), ['y']) 563 564 matmul_lt = ops.matmul(xy_lt, y_lt) 565 golden_lt = core.LabeledTensor( 566 math_ops.matmul(xy_lt.tensor, array_ops.reshape(y_lt.tensor, 567 (-1, 1)))[:, 0], ['x']) 568 self.assertLabeledTensorsEqual(matmul_lt, golden_lt) 569 570 matmul_lt = ops.matmul(y_lt, xy_lt) 571 self.assertLabeledTensorsEqual(matmul_lt, golden_lt) 572 573 def test_matrix_matrix(self): 574 xy_lt = core.LabeledTensor( 575 array_ops.reshape(math_ops.range(6), (2, 3)), ['x', 'y']) 576 yz_lt = core.LabeledTensor( 577 array_ops.reshape(math_ops.range(12), (3, 4)), ['y', 'z']) 578 579 matmul_lt = ops.matmul(xy_lt, yz_lt) 580 golden_lt = core.LabeledTensor( 581 math_ops.matmul(xy_lt.tensor, yz_lt.tensor), ['x', 'z']) 582 self.assertLabeledTensorsEqual(matmul_lt, golden_lt) 583 584 transpose = lambda x: core.transpose(x, list(x.axes.keys())[::-1]) 585 586 matmul_lt = ops.matmul(xy_lt, transpose(yz_lt)) 587 self.assertLabeledTensorsEqual(matmul_lt, golden_lt) 588 589 matmul_lt = ops.matmul(transpose(xy_lt), yz_lt) 590 self.assertLabeledTensorsEqual(matmul_lt, golden_lt) 591 592 matmul_lt = ops.matmul(transpose(xy_lt), transpose(yz_lt)) 593 self.assertLabeledTensorsEqual(matmul_lt, golden_lt) 594 595 matmul_lt = ops.matmul(yz_lt, xy_lt) 596 self.assertLabeledTensorsEqual(matmul_lt, transpose(golden_lt)) 597 598 def test_matrix_matrix_axis_order(self): 599 xy_lt = core.LabeledTensor( 600 array_ops.reshape(math_ops.range(6), (2, 3)), ['x', 'y']) 601 yz_lt = core.LabeledTensor( 602 array_ops.reshape(math_ops.range(12), (3, 4)), ['y', 'z']) 603 604 golden_lt = core.LabeledTensor( 605 math_ops.matmul(xy_lt.tensor, yz_lt.tensor), ['x', 'z']) 606 607 with core.axis_order_scope(['x', 'y', 'z']): 608 609 matmul_lt = ops.matmul(xy_lt, yz_lt) 610 self.assertLabeledTensorsEqual(matmul_lt, golden_lt) 611 612 matmul_lt = ops.matmul(yz_lt, xy_lt) 613 self.assertLabeledTensorsEqual(matmul_lt, golden_lt) 614 615 def test_invalid(self): 616 scalar_lt = core.LabeledTensor(array_ops.ones(()), []) 617 x_lt = core.LabeledTensor(array_ops.ones((2,)), ['x']) 618 x2_lt = core.LabeledTensor(array_ops.ones((3,)), ['x']) 619 y_lt = core.LabeledTensor(array_ops.ones((3,)), ['y']) 620 xy_lt = core.LabeledTensor(array_ops.ones((2, 3)), ['x', 'y']) 621 xyz_lt = core.LabeledTensor(array_ops.ones((2, 3, 1)), ['x', 'y', 'z']) 622 623 with self.assertRaisesRegexp(ValueError, 'inputs with at least rank'): 624 ops.matmul(x_lt, scalar_lt) 625 626 with self.assertRaises(NotImplementedError): 627 ops.matmul(x_lt, xyz_lt) 628 629 with self.assertRaisesRegexp(ValueError, 'exactly one axis in common'): 630 ops.matmul(x_lt, y_lt) 631 632 with self.assertRaises(NotImplementedError): 633 ops.matmul(xy_lt, xy_lt) 634 635 with self.assertRaisesRegexp(ValueError, 'does not match'): 636 ops.matmul(x_lt, x2_lt) 637 638 639class ReduceSumTest(Base): 640 641 def test_name(self): 642 sum_lt = ops.reduce_sum(self.original_lt, {'channel'}) 643 self.assertIn('lt_reduce_sum', sum_lt.name) 644 645 def test_drop_axis(self): 646 sum_lt = ops.reduce_sum(self.original_lt, {'channel'}) 647 golden_lt = core.LabeledTensor( 648 math_ops.reduce_sum(self.original_lt.tensor, 1), 649 [self.a0, self.a2, self.a3]) 650 self.assertLabeledTensorsEqual(sum_lt, golden_lt) 651 652 def test_drop_scalar_axis(self): 653 sum_lt = ops.reduce_sum(self.original_lt, 'channel') 654 golden_lt = core.LabeledTensor( 655 math_ops.reduce_sum(self.original_lt.tensor, 1), 656 [self.a0, self.a2, self.a3]) 657 self.assertLabeledTensorsEqual(sum_lt, golden_lt) 658 659 def test_keep_axis(self): 660 sum_lt = ops.reduce_sum(self.original_lt, {('channel', 'hihowareyou')}) 661 golden_lt = core.LabeledTensor( 662 math_ops.reduce_sum( 663 self.original_lt.tensor, 1, keepdims=True), 664 [self.a0, ('channel', ['hihowareyou']), self.a2, self.a3]) 665 self.assertLabeledTensorsEqual(sum_lt, golden_lt) 666 667 def test_keep_scalar_axis(self): 668 sum_lt = ops.reduce_sum(self.original_lt, ('channel', 'hihowareyou')) 669 golden_lt = core.LabeledTensor( 670 math_ops.reduce_sum( 671 self.original_lt.tensor, 1, keepdims=True), 672 [self.a0, ('channel', ['hihowareyou']), self.a2, self.a3]) 673 self.assertLabeledTensorsEqual(sum_lt, golden_lt) 674 675 def test_scalar(self): 676 scalar_lt = core.LabeledTensor(constant_op.constant(42), []) 677 reduce_lt = ops.reduce_sum(scalar_lt, []) 678 self.assertLabeledTensorsEqual(reduce_lt, scalar_lt) 679 680 def test_empty_list(self): 681 reduce_lt = ops.reduce_sum(self.original_lt, []) 682 self.assertLabeledTensorsEqual(reduce_lt, self.original_lt) 683 684 def test_none(self): 685 sum_lt = ops.reduce_sum(self.original_lt) 686 golden_lt = core.LabeledTensor( 687 math_ops.reduce_sum(self.original_lt.tensor), []) 688 self.assertLabeledTensorsEqual(sum_lt, golden_lt) 689 690 def test_function_docstring_and_name(self): 691 self.assertIn('tf.reduce_sum', ops.reduce_sum.__doc__) 692 self.assertEqual('reduce_sum', ops.reduce_sum.__name__) 693 694 695class ReduceMeanTest(Base): 696 697 def test_name(self): 698 actual_lt = ops.reduce_mean(self.original_lt, {'channel'}) 699 self.assertIn('lt_reduce_mean', actual_lt.name) 700 701 def test(self): 702 actual_lt = ops.reduce_mean(self.original_lt, {'channel'}) 703 golden_lt = core.LabeledTensor( 704 math_ops.reduce_mean(self.original_lt.tensor, 1), 705 [self.a0, self.a2, self.a3]) 706 self.assertLabeledTensorsEqual(actual_lt, golden_lt) 707 708 709class ReduceProdTest(Base): 710 711 def test_name(self): 712 result_lt = ops.reduce_prod(self.original_lt, {'channel'}) 713 self.assertIn('lt_reduce_prod', result_lt.name) 714 715 def test(self): 716 result_lt = ops.reduce_prod(self.original_lt, {'channel'}) 717 golden_lt = core.LabeledTensor( 718 math_ops.reduce_prod(self.original_lt.tensor, 1), 719 [self.a0, self.a2, self.a3]) 720 self.assertLabeledTensorsEqual(result_lt, golden_lt) 721 722 723class ReduceMinTest(Base): 724 725 def test_name(self): 726 result_lt = ops.reduce_min(self.original_lt, {'channel'}) 727 self.assertIn('lt_reduce_min', result_lt.name) 728 729 def test(self): 730 result_lt = ops.reduce_min(self.original_lt, {'channel'}) 731 golden_lt = core.LabeledTensor( 732 math_ops.reduce_min(self.original_lt.tensor, 1), 733 [self.a0, self.a2, self.a3]) 734 self.assertLabeledTensorsEqual(result_lt, golden_lt) 735 736 737class ReduceMaxTest(Base): 738 739 def test_name(self): 740 result_lt = ops.reduce_max(self.original_lt, {'channel'}) 741 self.assertIn('lt_reduce_max', result_lt.name) 742 743 def test(self): 744 result_lt = ops.reduce_max(self.original_lt, {'channel'}) 745 golden_lt = core.LabeledTensor( 746 math_ops.reduce_max(self.original_lt.tensor, 1), 747 [self.a0, self.a2, self.a3]) 748 self.assertLabeledTensorsEqual(result_lt, golden_lt) 749 750 751class BaseReduceBoolean(Base): 752 753 def setUp(self): 754 super(BaseReduceBoolean, self).setUp() 755 self.bool_tensor = math_ops.cast(self.original_lt.tensor > 5, dtypes.bool) 756 self.bool_lt = core.LabeledTensor(self.bool_tensor, self.original_lt.axes) 757 758 759class ReduceAllTest(BaseReduceBoolean): 760 761 def test_name(self): 762 result_lt = ops.reduce_all(self.bool_lt, {'channel'}) 763 self.assertIn('lt_reduce_all', result_lt.name) 764 765 def test(self): 766 result_lt = ops.reduce_all(self.bool_lt, {'channel'}) 767 golden_lt = core.LabeledTensor( 768 math_ops.reduce_all(self.bool_tensor, 1), [self.a0, self.a2, self.a3]) 769 self.assertLabeledTensorsEqual(result_lt, golden_lt) 770 771 772class ReduceAnyTest(BaseReduceBoolean): 773 774 def test_name(self): 775 result_lt = ops.reduce_any(self.bool_lt, {'channel'}) 776 self.assertIn('lt_reduce_any', result_lt.name) 777 778 def test(self): 779 result_lt = ops.reduce_any(self.bool_lt, {'channel'}) 780 golden_lt = core.LabeledTensor( 781 math_ops.reduce_any(self.bool_tensor, 1), [self.a0, self.a2, self.a3]) 782 self.assertLabeledTensorsEqual(result_lt, golden_lt) 783 784 785class TileTest(Base): 786 787 def test_name(self): 788 tile_lt = ops.tile(self.original_lt, {'z': 2}) 789 self.assertIn('lt_tile', tile_lt.name) 790 791 def test(self): 792 for multiple in [2, constant_op.constant(2)]: 793 tile_lt = ops.tile(self.original_lt, {'z': multiple}) 794 golden_op = array_ops.tile(self.original_lt.tensor, [1, 1, multiple, 1]) 795 golden_axes = [ 796 'z' if axis.name == 'z' else axis 797 for axis in self.original_lt.axes.values() 798 ] 799 golden_lt = core.LabeledTensor(golden_op, golden_axes) 800 self.assertLabeledTensorsEqual(tile_lt, golden_lt) 801 802 def test_invalid_input(self): 803 with self.assertRaisesRegexp(ValueError, 'are not contained in the set'): 804 ops.tile(self.original_lt, {'foo': 5}) 805 with self.assertRaisesRegexp(ValueError, 'axes with tick labels'): 806 ops.tile(self.original_lt, {'x': 5}) 807 808 809class PadTest(Base): 810 811 def test_name(self): 812 pad_lt = ops.pad(self.original_lt, 813 {'x': (1, 1), 814 'channel': ([], ['alpha'])}) 815 self.assertIn('lt_pad', pad_lt.name) 816 817 def test(self): 818 pad_lt = ops.pad(self.original_lt, 819 {'x': (1, 1), 820 'channel': ([], ['alpha'])}) 821 822 golden_op = array_ops.pad(self.original_lt.tensor, [[1, 1], [0, 1], [0, 0], 823 [0, 0]]) 824 golden_axes = [('x', self.x_size + 2), 825 ('channel', ['red', 'green', 'blue', 'alpha']), self.a2, 826 self.a3] 827 golden_lt = core.LabeledTensor(golden_op, golden_axes) 828 self.assertLabeledTensorsEqual(pad_lt, golden_lt) 829 830 def test_invalid_input(self): 831 with self.assertRaisesRegexp(ValueError, 'are not contained in the set'): 832 ops.pad(self.original_lt, {'foo': (1, 1), 'channel': ([], ['alpha'])}) 833 834 835class ConstantTest(Base): 836 837 def test_name(self): 838 constant_lt = ops.constant(1) 839 self.assertIn('lt_constant', constant_lt.name) 840 841 def test_scalar(self): 842 constant_lt = ops.constant(1) 843 golden_lt = core.LabeledTensor(constant_op.constant(1), []) 844 self.assertLabeledTensorsEqual(constant_lt, golden_lt) 845 846 def test_infer_shape(self): 847 constant_lt = ops.constant([1, 2], axes=['x']) 848 golden_lt = core.LabeledTensor(constant_op.constant([1, 2]), ['x']) 849 self.assertLabeledTensorsEqual(constant_lt, golden_lt) 850 851 def test_specify_shape(self): 852 constant_lt = ops.constant(1, axes=[('x', 3)]) 853 golden_lt = core.LabeledTensor(constant_op.constant(1, shape=(3,)), ['x']) 854 self.assertLabeledTensorsEqual(constant_lt, golden_lt) 855 856 def test_existing_axes(self): 857 golden_lt = core.LabeledTensor(constant_op.constant([1, 2]), ['x']) 858 constant_lt = ops.constant([1, 2], axes=golden_lt.axes) 859 self.assertLabeledTensorsEqual(constant_lt, golden_lt) 860 861 862class ZerosLikeTest(Base): 863 864 def test_name(self): 865 like_lt = ops.zeros_like(self.original_lt) 866 self.assertIn('lt_zeros_like', like_lt.name) 867 868 def test(self): 869 like_lt = ops.zeros_like(self.original_lt) 870 golden_lt = core.LabeledTensor( 871 array_ops.zeros_like(self.original_lt.tensor), self.original_lt.axes) 872 self.assertLabeledTensorsEqual(like_lt, golden_lt) 873 874 875class OnesLikeTest(Base): 876 877 def test_name(self): 878 like_lt = ops.ones_like(self.original_lt) 879 self.assertIn('lt_ones_like', like_lt.name) 880 881 def test(self): 882 like_lt = ops.ones_like(self.original_lt) 883 golden_lt = core.LabeledTensor( 884 array_ops.ones_like(self.original_lt.tensor), self.original_lt.axes) 885 self.assertLabeledTensorsEqual(like_lt, golden_lt) 886 887 888class CastTest(Base): 889 890 def test_name(self): 891 cast_lt = ops.cast(self.original_lt, dtypes.float16) 892 self.assertIn('lt_cast', cast_lt.name) 893 894 def test(self): 895 cast_lt = ops.cast(self.original_lt, dtypes.float16) 896 golden_lt = core.LabeledTensor( 897 math_ops.cast(self.original_lt.tensor, dtypes.float16), 898 self.original_lt.axes) 899 self.assertLabeledTensorsEqual(cast_lt, golden_lt) 900 901 902class VerifyTensorAllFiniteTest(Base): 903 904 def setUp(self): 905 super(VerifyTensorAllFiniteTest, self).setUp() 906 907 self.finite_lt = core.LabeledTensor(constant_op.constant(42.0), []) 908 self.nan_lt = core.LabeledTensor(constant_op.constant(np.nan), []) 909 910 self.checked_finite_lt = ops.verify_tensor_all_finite(self.finite_lt, '') 911 self.checked_nan_lt = ops.verify_tensor_all_finite(self.nan_lt, '') 912 913 def test_name(self): 914 self.assertIn('lt_verify_tensor_all_finite', self.checked_finite_lt.name) 915 self.assertIn('lt_verify_tensor_all_finite', self.checked_nan_lt.name) 916 917 def test_finite(self): 918 self.assertLabeledTensorsEqual(self.finite_lt, self.checked_finite_lt) 919 920 def test_nan(self): 921 with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 922 'Tensor had NaN values'): 923 self.eval([self.checked_nan_lt]) 924 925 926class BooleanMaskTest(Base): 927 928 def test_name(self): 929 mask = core.LabeledTensor(math_ops.range(7) > 3, [self.a0]) 930 masked_lt = ops.boolean_mask(self.original_lt, mask) 931 self.assertIn('lt_boolean_mask', masked_lt.name) 932 933 def test(self): 934 mask = core.LabeledTensor(math_ops.range(7) > 3, [self.a0]) 935 masked_lt = ops.boolean_mask(self.original_lt, mask) 936 golden_lt = core.LabeledTensor( 937 array_ops.boolean_mask(self.original_lt.tensor, mask.tensor), 938 ['x', self.a1, self.a2, self.a3]) 939 self.assertLabeledTensorsEqual(masked_lt, golden_lt) 940 941 def test_invalid_rank(self): 942 mask = core.LabeledTensor(array_ops.ones((7, 3)) > 3, [self.a0, self.a1]) 943 with self.assertRaises(NotImplementedError): 944 ops.boolean_mask(self.original_lt, mask) 945 946 def test_mismatched_axis(self): 947 mask = core.LabeledTensor(math_ops.range(7) > 3, ['foo']) 948 with self.assertRaisesRegexp(ValueError, 'not equal'): 949 ops.boolean_mask(self.original_lt, mask) 950 951 952class WhereTest(Base): 953 954 def test_name(self): 955 condition = core.LabeledTensor(math_ops.range(5) < 3, ['x']) 956 where_lt = ops.where(condition, condition, condition) 957 self.assertIn('lt_where', where_lt.name) 958 959 def test(self): 960 condition = core.LabeledTensor(math_ops.range(5) < 3, ['x']) 961 x = core.LabeledTensor(array_ops.ones(5), ['x']) 962 y = core.LabeledTensor(array_ops.zeros(5), ['x']) 963 where_lt = ops.where(condition, x, y) 964 965 golden_lt = core.LabeledTensor( 966 array_ops.concat([array_ops.ones(3), array_ops.zeros(2)], 0), ['x']) 967 self.assertLabeledTensorsEqual(where_lt, golden_lt) 968 969 def test_mismatched_axes(self): 970 condition = core.LabeledTensor(math_ops.range(5) < 3, ['x']) 971 with self.assertRaisesRegexp(ValueError, 'equal axes'): 972 ops.where(condition, condition[:3], condition) 973 with self.assertRaisesRegexp(ValueError, 'equal axes'): 974 ops.where(condition, condition, condition[:3]) 975 976 977if __name__ == '__main__': 978 test_lib.main() 979