1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Non-core ops for LabeledTensor.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import collections 21import types 22 23import numpy as np 24from six import string_types 25 26from tensorflow.contrib.labeled_tensor.python.ops import _typecheck as tc 27from tensorflow.contrib.labeled_tensor.python.ops import core 28from tensorflow.python.framework import dtypes 29from tensorflow.python.framework import ops 30from tensorflow.python.ops import array_ops 31from tensorflow.python.ops import functional_ops 32from tensorflow.python.ops import map_fn as map_fn_lib 33from tensorflow.python.ops import math_ops 34from tensorflow.python.ops import numerics 35from tensorflow.python.ops import random_ops 36from tensorflow.python.training import input # pylint: disable=redefined-builtin 37 38 39@tc.returns(core.LabeledTensor) 40@tc.accepts(core.LabeledTensor, ops.Tensor, core.Axis, 41 tc.Optional(string_types)) 42def _gather_1d_on_axis(labeled_tensor, indexer, axis, name=None): 43 with ops.name_scope(name, 'lt_take', [labeled_tensor]) as scope: 44 temp_axes = core.Axes([axis] + list( 45 labeled_tensor.axes.remove(axis.name).values())) 46 transposed = core.transpose(labeled_tensor, temp_axes.keys()) 47 indexed = core.LabeledTensor( 48 array_ops.gather(transposed.tensor, indexer), temp_axes) 49 return core.transpose(indexed, labeled_tensor.axes.keys(), name=scope) 50 51 52@tc.returns(core.LabeledTensor) 53@tc.accepts(core.LabeledTensorLike, 54 tc.Mapping(string_types, 55 tc.Union(slice, collections.Hashable, list)), 56 tc.Optional(string_types)) 57def select(labeled_tensor, selection, name=None): 58 """Slice out a subset of the tensor. 59 60 Args: 61 labeled_tensor: The input tensor. 62 selection: A dictionary mapping an axis name to a scalar, slice or list of 63 values to select. Currently supports two types of selections: 64 (a) Any number of scalar and/or slice selections. 65 (b) Exactly one list selection, without any scalars or slices. 66 name: Optional op name. 67 68 Returns: 69 The selection as a `LabeledTensor`. 70 71 Raises: 72 ValueError: If the tensor doesn't have an axis in the selection or if 73 that axis lacks labels. 74 KeyError: If any labels in a selection are not found in the original axis. 75 NotImplementedError: If you attempt to combine a list selection with 76 scalar selection or another list selection. 77 """ 78 with ops.name_scope(name, 'lt_select', [labeled_tensor]) as scope: 79 labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) 80 81 slices = {} 82 indexers = {} 83 for axis_name, value in selection.items(): 84 if axis_name not in labeled_tensor.axes: 85 raise ValueError( 86 'The tensor does not have an axis named %s. Its axes are: %r' % 87 (axis_name, labeled_tensor.axes.keys())) 88 axis = labeled_tensor.axes[axis_name] 89 if axis.labels is None: 90 raise ValueError( 91 'The axis named %s does not have labels. The axis is: %r' % 92 (axis_name, axis)) 93 94 if isinstance(value, slice): 95 # TODO(shoyer): consider deprecating using slices in favor of lists 96 if value.start is None: 97 start = None 98 else: 99 start = axis.index(value.start) 100 101 if value.stop is None: 102 stop = None 103 else: 104 # For now, follow the pandas convention of making labeled slices 105 # inclusive of both bounds. 106 stop = axis.index(value.stop) + 1 107 108 if value.step is not None: 109 raise NotImplementedError('slicing with a step is not yet supported') 110 111 slices[axis_name] = slice(start, stop) 112 113 # Needs to be after checking for slices, since slice objects claim to be 114 # instances of collections.Hashable but hash() on them fails. 115 elif isinstance(value, collections.Hashable): 116 slices[axis_name] = axis.index(value) 117 118 elif isinstance(value, list): 119 if indexers: 120 raise NotImplementedError( 121 'select does not yet support more than one list selection at ' 122 'the same time') 123 indexer = [axis.index(v) for v in value] 124 indexers[axis_name] = ops.convert_to_tensor(indexer, dtype=dtypes.int64) 125 126 else: 127 # If type checking is working properly, this shouldn't be possible. 128 raise TypeError('cannot handle arbitrary types') 129 130 if indexers and slices: 131 raise NotImplementedError( 132 'select does not yet support combined scalar and list selection') 133 134 # For now, handle array selection separately, because tf.gather_nd does 135 # not support gradients yet. Later, using gather_nd will let us combine 136 # these paths. 137 if indexers: 138 (axis_name, indexer), = indexers.items() 139 axis = core.Axis(axis_name, selection[axis_name]) 140 return _gather_1d_on_axis(labeled_tensor, indexer, axis, name=scope) 141 else: 142 return core.slice_function(labeled_tensor, slices, name=scope) 143 144 145@tc.returns(core.LabeledTensor) 146@tc.accepts( 147 tc.Collection(core.LabeledTensorLike), string_types, 148 tc.Optional(string_types)) 149def concat(labeled_tensors, axis_name, name=None): 150 """Concatenate tensors along a dimension. 151 152 See tf.concat. 153 154 Args: 155 labeled_tensors: A list of input LabeledTensors. 156 axis_name: The name of the axis along which to concatenate. 157 name: Optional op name. 158 159 Returns: 160 The concatenated tensor. 161 The coordinate labels for the concatenation dimension are also concatenated, 162 if they are available for every tensor. 163 164 Raises: 165 ValueError: If fewer than one tensor inputs is provided, if the tensors 166 have incompatible axes, or if `axis_name` isn't the name of an axis. 167 """ 168 with ops.name_scope(name, 'lt_concat', labeled_tensors) as scope: 169 labeled_tensors = [ 170 core.convert_to_labeled_tensor(lt) for lt in labeled_tensors 171 ] 172 173 if len(labeled_tensors) < 1: 174 raise ValueError('concat expects at least 1 tensor, but received %s' % 175 labeled_tensors) 176 177 # All tensors must have these axes. 178 axes_0 = labeled_tensors[0].axes 179 axis_names = list(axes_0.keys()) 180 181 if axis_name not in axis_names: 182 raise ValueError('%s not in %s' % (axis_name, axis_names)) 183 184 shared_axes = axes_0.remove(axis_name) 185 186 tensors = [labeled_tensors[0].tensor] 187 concat_axis_list = [axes_0[axis_name]] 188 for labeled_tensor in labeled_tensors[1:]: 189 current_shared_axes = labeled_tensor.axes.remove(axis_name) 190 if current_shared_axes != shared_axes: 191 # TODO(shoyer): add more specific checks about what went wrong, 192 # including raising AxisOrderError when appropriate 193 raise ValueError('Mismatched shared axes: the first tensor ' 194 'had axes %r but this tensor has axes %r.' % 195 (shared_axes, current_shared_axes)) 196 197 # Accumulate the axis labels, if they're available. 198 concat_axis_list.append(labeled_tensor.axes[axis_name]) 199 tensors.append(labeled_tensor.tensor) 200 201 concat_axis = core.concat_axes(concat_axis_list) 202 concat_dimension = axis_names.index(axis_name) 203 concat_tensor = array_ops.concat(tensors, concat_dimension, name=scope) 204 values = list(axes_0.values()) 205 concat_axes = (values[:concat_dimension] + [concat_axis] + 206 values[concat_dimension + 1:]) 207 208 return core.LabeledTensor(concat_tensor, concat_axes) 209 210 211# TODO(shoyer): rename pack/unpack to stack/unstack 212 213 214@tc.returns(core.LabeledTensor) 215@tc.accepts( 216 tc.Collection(core.LabeledTensorLike), 217 tc.Union(string_types, core.AxisLike), int, tc.Optional(string_types)) 218def pack(labeled_tensors, new_axis, axis_position=0, name=None): 219 """Pack tensors along a new axis. 220 221 See tf.pack. 222 223 Args: 224 labeled_tensors: The input tensors, which must have identical axes. 225 new_axis: The name of the new axis, or a tuple containing the name 226 and coordinate labels. 227 axis_position: Optional integer position at which to insert the new axis. 228 name: Optional op name. 229 230 Returns: 231 The packed tensors as a single LabeledTensor, with `new_axis` in the given 232 `axis_position`. 233 234 Raises: 235 ValueError: If fewer than one input tensors is provided, or if the tensors 236 don't have identical axes. 237 """ 238 with ops.name_scope(name, 'lt_pack', labeled_tensors) as scope: 239 labeled_tensors = [ 240 core.convert_to_labeled_tensor(lt) for lt in labeled_tensors 241 ] 242 243 if len(labeled_tensors) < 1: 244 raise ValueError('pack expects at least 1 tensors, but received %s' % 245 labeled_tensors) 246 247 axes_0 = labeled_tensors[0].axes 248 for t in labeled_tensors: 249 if t.axes != axes_0: 250 raise ValueError('Non-identical axes. Expected %s but got %s' % 251 (axes_0, t.axes)) 252 253 pack_op = array_ops.stack( 254 [t.tensor for t in labeled_tensors], axis=axis_position, name=scope) 255 axes = list(axes_0.values()) 256 axes.insert(axis_position, new_axis) 257 return core.LabeledTensor(pack_op, axes) 258 259 260@tc.returns(tc.List(core.LabeledTensor)) 261@tc.accepts(core.LabeledTensorLike, 262 tc.Optional(string_types), tc.Optional(string_types)) 263def unpack(labeled_tensor, axis_name=None, name=None): 264 """Unpack the tensor. 265 266 See tf.unpack. 267 268 Args: 269 labeled_tensor: The input tensor. 270 axis_name: Optional name of axis to unpack. By default, the first axis is 271 used. 272 name: Optional op name. 273 274 Returns: 275 The list of unpacked LabeledTensors. 276 277 Raises: 278 ValueError: If `axis_name` is not an axis on the input. 279 """ 280 with ops.name_scope(name, 'lt_unpack', [labeled_tensor]) as scope: 281 labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) 282 283 axis_names = list(labeled_tensor.axes.keys()) 284 if axis_name is None: 285 axis_name = axis_names[0] 286 287 if axis_name not in axis_names: 288 raise ValueError('%s not in %s' % (axis_name, axis_names)) 289 axis = axis_names.index(axis_name) 290 291 unpack_ops = array_ops.unstack(labeled_tensor.tensor, axis=axis, name=scope) 292 axes = [a for i, a in enumerate(labeled_tensor.axes.values()) if i != axis] 293 return [core.LabeledTensor(t, axes) for t in unpack_ops] 294 295 296@tc.returns(core.LabeledTensor) 297@tc.accepts(core.LabeledTensorLike, 298 tc.Collection(string_types), 299 tc.Collection(tc.Union(string_types, core.AxisLike)), 300 tc.Optional(string_types)) 301def reshape(labeled_tensor, existing_axes, new_axes, name=None): 302 """Reshape specific axes of a LabeledTensor. 303 304 Non-indicated axes remain in their original locations. 305 306 Args: 307 labeled_tensor: The input tensor. 308 existing_axes: List of axis names found on the input tensor. These must 309 appear sequentially in the list of axis names on the input. In other 310 words, they must be a valid slice of `list(labeled_tensor.axes.keys())`. 311 new_axes: List of strings, tuples of (axis_name, axis_value) or Axis objects 312 providing new axes with which to replace `existing_axes` in the reshaped 313 result. At most one element of `new_axes` may be a string, indicating an 314 axis with unknown size. 315 name: Optional op name. 316 317 Returns: 318 The reshaped LabeledTensor. 319 320 Raises: 321 ValueError: If `existing_axes` are not all axes on the input, or if more 322 than one of `new_axes` has unknown size. 323 AxisOrderError: If `existing_axes` are not a slice of axis names on the 324 input. 325 """ 326 with ops.name_scope(name, 'lt_reshape', [labeled_tensor]) as scope: 327 labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) 328 329 original_axis_names = list(labeled_tensor.axes.keys()) 330 existing_axes = list(existing_axes) 331 if not set(existing_axes) <= set(original_axis_names): 332 raise ValueError('existing_axes %r are not contained in the set of axis ' 333 'names %r on the input labeled tensor' % 334 (existing_axes, original_axis_names)) 335 336 start = original_axis_names.index(existing_axes[0]) 337 stop = original_axis_names.index(existing_axes[-1]) + 1 338 339 if existing_axes != original_axis_names[start:stop]: 340 # We could support existing_axes that aren't a slice by using transpose, 341 # but that could lead to unpredictable performance consequences because 342 # transposes are not free in TensorFlow. If we did transpose 343 # automatically, the user might never realize that their data is being 344 # produced with the wrong order. (The later will occur with some frequency 345 # because of how broadcasting automatically choose axis order.) 346 # So for now we've taken the strict approach. 347 raise core.AxisOrderError( 348 'existing_axes %r are not a slice of axis names %r on the input ' 349 'labeled tensor. Use `transpose` or `impose_axis_order` to reorder ' 350 'axes on the input explicitly.' % 351 (existing_axes, original_axis_names)) 352 353 if sum(isinstance(axis, string_types) for axis in new_axes) > 1: 354 raise ValueError( 355 'at most one axis in new_axes can have unknown size. All other ' 356 'axes must have an indicated integer size or labels: %r' % new_axes) 357 358 original_values = list(labeled_tensor.axes.values()) 359 axis_size = lambda axis: -1 if axis.size is None else axis.size 360 shape = [axis_size(axis) for axis in original_values[:start]] 361 for axis_ref in new_axes: 362 if isinstance(axis_ref, string_types): 363 shape.append(-1) 364 else: 365 axis = core.as_axis(axis_ref) 366 shape.append(axis_size(axis)) 367 shape.extend(axis_size(axis) for axis in original_values[stop:]) 368 369 reshaped_tensor = array_ops.reshape( 370 labeled_tensor.tensor, shape, name=scope) 371 axes = original_values[:start] + list(new_axes) + original_values[stop:] 372 return core.LabeledTensor(reshaped_tensor, axes) 373 374 375@tc.returns(core.LabeledTensor) 376@tc.accepts(core.LabeledTensorLike, string_types, string_types, 377 tc.Optional(string_types)) 378def rename_axis(labeled_tensor, existing_name, new_name, name=None): 379 """Rename an axis of LabeledTensor. 380 381 Args: 382 labeled_tensor: The input tensor. 383 existing_name: Name for an existing axis on the input. 384 new_name: Desired replacement name. 385 name: Optional op name. 386 387 Returns: 388 LabeledTensor with renamed axis. 389 390 Raises: 391 ValueError: If `existing_name` is not an axis on the input. 392 """ 393 with ops.name_scope(name, 'lt_rename_axis', [labeled_tensor]) as scope: 394 if existing_name not in labeled_tensor.axes: 395 raise ValueError('existing_name %r are not contained in the set of axis ' 396 'names %r on the input labeled tensor' % 397 (existing_name, labeled_tensor.axes.keys())) 398 new_axis = core.Axis(new_name, labeled_tensor.axes[existing_name].value) 399 return reshape(labeled_tensor, [existing_name], [new_axis], name=scope) 400 401 402@tc.returns(tc.List(core.LabeledTensor)) 403@tc.accepts(string_types, collections.Callable, int, bool, 404 tc.Collection(core.LabeledTensorLike), bool, 405 tc.Optional(string_types)) 406def _batch_helper(default_name, 407 batch_fn, 408 batch_size, 409 enqueue_many, 410 labeled_tensors, 411 allow_smaller_final_batch, 412 name=None): 413 with ops.name_scope(name, default_name, labeled_tensors) as scope: 414 labeled_tensors = [ 415 core.convert_to_labeled_tensor(lt) for lt in labeled_tensors 416 ] 417 418 batch_ops = batch_fn([t.tensor for t in labeled_tensors], scope) 419 # TODO(shoyer): Remove this when they sanitize the TF API. 420 if not isinstance(batch_ops, list): 421 assert isinstance(batch_ops, ops.Tensor) 422 batch_ops = [batch_ops] 423 424 if allow_smaller_final_batch: 425 batch_size = None 426 427 @tc.returns(core.Axes) 428 @tc.accepts(core.Axes) 429 def output_axes(axes): 430 if enqueue_many: 431 if 'batch' not in axes or list(axes.keys()).index('batch') != 0: 432 raise ValueError( 433 'When enqueue_many is True, input tensors must have an axis ' 434 'called "batch" as their first dimension, ' 435 'but axes were %s' % axes) 436 culled_axes = axes.remove('batch') 437 return core.Axes([('batch', batch_size)] + list(culled_axes.values())) 438 else: 439 return core.Axes([('batch', batch_size)] + list(axes.values())) 440 441 output_labeled_tensors = [] 442 for i, tensor in enumerate(batch_ops): 443 axes = output_axes(labeled_tensors[i].axes) 444 output_labeled_tensors.append(core.LabeledTensor(tensor, axes)) 445 446 return output_labeled_tensors 447 448 449@tc.returns(tc.List(core.LabeledTensor)) 450@tc.accepts( 451 tc.Collection(core.LabeledTensorLike), int, int, int, bool, bool, 452 tc.Optional(string_types)) 453def batch(labeled_tensors, 454 batch_size, 455 num_threads=1, 456 capacity=32, 457 enqueue_many=False, 458 allow_smaller_final_batch=False, 459 name=None): 460 """Rebatch a tensor. 461 462 See tf.batch. 463 464 Args: 465 labeled_tensors: The input tensors. 466 batch_size: The output batch size. 467 num_threads: See tf.batch. 468 capacity: See tf.batch. 469 enqueue_many: If true, the input tensors must contain a 'batch' axis as 470 their first axis. 471 If false, the input tensors must not contain a 'batch' axis. 472 See tf.batch. 473 allow_smaller_final_batch: See tf.batch. 474 name: Optional op name. 475 476 Returns: 477 The rebatched tensors. 478 If enqueue_many is false, the output tensors will have a new 'batch' axis 479 as their first axis. 480 481 Raises: 482 ValueError: If enqueue_many is True and the first axis of the tensors 483 isn't "batch". 484 """ 485 486 def fn(tensors, scope): 487 return input.batch( 488 tensors, 489 batch_size=batch_size, 490 num_threads=num_threads, 491 capacity=capacity, 492 enqueue_many=enqueue_many, 493 allow_smaller_final_batch=allow_smaller_final_batch, 494 name=scope) 495 496 return _batch_helper('lt_batch', fn, batch_size, enqueue_many, 497 labeled_tensors, allow_smaller_final_batch, name) 498 499 500@tc.returns(tc.List(core.LabeledTensor)) 501@tc.accepts( 502 tc.Collection(core.LabeledTensorLike), int, int, int, bool, int, 503 tc.Optional(int), bool, tc.Optional(string_types)) 504def shuffle_batch(labeled_tensors, 505 batch_size, 506 num_threads=1, 507 capacity=32, 508 enqueue_many=False, 509 min_after_dequeue=0, 510 seed=None, 511 allow_smaller_final_batch=False, 512 name=None): 513 """Rebatch a tensor, with shuffling. 514 515 See tf.batch. 516 517 Args: 518 labeled_tensors: The input tensors. 519 batch_size: The output batch size. 520 num_threads: See tf.batch. 521 capacity: See tf.batch. 522 enqueue_many: If true, the input tensors must contain a 'batch' axis as 523 their first axis. 524 If false, the input tensors must not contain a 'batch' axis. 525 See tf.batch. 526 min_after_dequeue: Minimum number of elements in the queue after a dequeue, 527 used to ensure mixing. 528 seed: Optional random seed. 529 allow_smaller_final_batch: See tf.batch. 530 name: Optional op name. 531 532 Returns: 533 The rebatched tensors. 534 If enqueue_many is false, the output tensors will have a new 'batch' axis 535 as their first axis. 536 537 Raises: 538 ValueError: If enqueue_many is True and the first axis of the tensors 539 isn't "batch". 540 """ 541 542 def fn(tensors, scope): 543 return input.shuffle_batch( 544 tensors, 545 batch_size=batch_size, 546 num_threads=num_threads, 547 capacity=capacity, 548 enqueue_many=enqueue_many, 549 min_after_dequeue=min_after_dequeue, 550 seed=seed, 551 allow_smaller_final_batch=allow_smaller_final_batch, 552 name=scope) 553 554 return _batch_helper('lt_shuffle_batch', fn, batch_size, enqueue_many, 555 labeled_tensors, allow_smaller_final_batch, name) 556 557 558@tc.returns(core.LabeledTensor) 559@tc.accepts(core.LabeledTensorLike, 560 tc.Mapping(string_types, int), 561 tc.Optional(int), tc.Optional(string_types)) 562def random_crop(labeled_tensor, shape_map, seed=None, name=None): 563 """Randomly crops a tensor to a given size. 564 565 See tf.random_crop. 566 567 Args: 568 labeled_tensor: The input tensor. 569 shape_map: A dictionary mapping axis names to the size of the random crop 570 for that dimension. 571 seed: An optional random seed. 572 name: An optional op name. 573 574 Returns: 575 A tensor of the same rank as `labeled_tensor`, cropped randomly in the 576 selected dimensions. 577 578 Raises: 579 ValueError: If the shape map contains an axis name not in the input tensor. 580 """ 581 with ops.name_scope(name, 'lt_random_crop', [labeled_tensor]) as scope: 582 labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) 583 584 for axis_name in shape_map: 585 if axis_name not in labeled_tensor.axes: 586 raise ValueError('Selection axis %s not in axes %s' % 587 (axis_name, labeled_tensor.axes)) 588 589 shape = [] 590 axes = [] 591 for axis in labeled_tensor.axes.values(): 592 if axis.name in shape_map: 593 size = shape_map[axis.name] 594 shape.append(size) 595 # We lose labels for the axes we crop, leaving just the size. 596 axes.append((axis.name, size)) 597 else: 598 shape.append(len(axis)) 599 axes.append(axis) 600 601 crop_op = random_ops.random_crop( 602 labeled_tensor.tensor, shape, seed=seed, name=scope) 603 604 return core.LabeledTensor(crop_op, axes) 605 606 607# TODO(shoyer): Allow the user to select the axis over which to map. 608@tc.returns(core.LabeledTensor) 609@tc.accepts(collections.Callable, core.LabeledTensorLike, 610 tc.Optional(string_types)) 611def map_fn(fn, labeled_tensor, name=None): 612 """Map on the list of tensors unpacked from labeled_tensor. 613 614 See tf.map_fn. 615 616 Args: 617 fn: The function to apply to each unpacked LabeledTensor. 618 It should have type LabeledTensor -> LabeledTensor. 619 labeled_tensor: The input tensor. 620 name: Optional op name. 621 622 Returns: 623 A tensor that packs the results of applying fn to the list of tensors 624 unpacked from labeled_tensor. 625 """ 626 with ops.name_scope(name, 'lt_map_fn', [labeled_tensor]) as scope: 627 labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) 628 629 unpack_lts = unpack(labeled_tensor) 630 631 # TODO(ericmc): Fix this upstream. 632 if labeled_tensor.dtype == dtypes.string: 633 # We must construct the full graph here, because map_fn_lib.map_fn 634 # doesn't work for string-valued tensors. 635 # Constructing the full graph may be slow. 636 map_lts = [fn(t) for t in unpack_lts] 637 return pack(map_lts, list(labeled_tensor.axes.values())[0], name=scope) 638 else: 639 # Figure out what the axis labels should be, but use tf.map_fn to 640 # construct the graph because it's efficient. 641 # It may be slow to construct the full graph, so we infer the labels from 642 # the first element. 643 # TODO(ericmc): This builds a subgraph which then gets thrown away. 644 # Find a more elegant solution. 645 first_map_lt = fn(unpack_lts[0]) 646 final_axes = list(labeled_tensor.axes.values())[:1] + list( 647 first_map_lt.axes.values()) 648 649 @tc.returns(ops.Tensor) 650 @tc.accepts(ops.Tensor) 651 def tf_fn(tensor): 652 original_axes = list(labeled_tensor.axes.values())[1:] 653 tensor_lt = core.LabeledTensor(tensor, original_axes) 654 return fn(tensor_lt).tensor 655 656 map_op = map_fn_lib.map_fn( 657 tf_fn, labeled_tensor.tensor, dtype=first_map_lt.dtype) 658 map_lt = core.LabeledTensor(map_op, final_axes) 659 660 return core.identity(map_lt, name=scope) 661 662 663@tc.returns(core.LabeledTensor) 664@tc.accepts(collections.Callable, core.LabeledTensorLike, 665 core.LabeledTensorLike, tc.Optional(string_types)) 666def foldl(fn, labeled_tensor, initial_value, name=None): 667 """Left fold on the list of tensors unpacked from labeled_tensor. 668 669 See tf.foldl. 670 671 Args: 672 fn: The function to apply to each unpacked LabeledTensor. 673 It should have type (LabeledTensor, LabeledTensor) -> LabeledTensor. 674 Its arguments are (accumulated_value, next_value). 675 labeled_tensor: The input tensor. 676 initial_value: The initial value of the accumulator. 677 name: Optional op name. 678 679 Returns: 680 The accumulated value. 681 """ 682 with ops.name_scope(name, 'lt_foldl', 683 [labeled_tensor, initial_value]) as scope: 684 labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) 685 initial_value = core.convert_to_labeled_tensor(initial_value) 686 687 @tc.returns(ops.Tensor) 688 @tc.accepts(ops.Tensor, ops.Tensor) 689 def tf_fn(accumulator, next_element): 690 accumulator_lt = core.LabeledTensor(accumulator, initial_value.axes) 691 next_element_lt = core.LabeledTensor( 692 next_element, list(labeled_tensor.axes.values())[1:]) 693 return fn(accumulator_lt, next_element_lt).tensor 694 695 foldl_op = functional_ops.foldl( 696 tf_fn, labeled_tensor.tensor, initializer=initial_value.tensor) 697 foldl_lt = core.LabeledTensor(foldl_op, initial_value.axes) 698 699 return core.identity(foldl_lt, name=scope) 700 701 702@tc.returns(core.LabeledTensor) 703@tc.accepts(core.LabeledTensorLike, 704 tc.Optional(tc.Collection(string_types)), tc.Optional(string_types)) 705def squeeze(labeled_tensor, axis_names=None, name=None): 706 """Remove size-1 dimensions. 707 708 See tf.squeeze. 709 710 Args: 711 labeled_tensor: The input tensor. 712 axis_names: The names of the dimensions to remove, or None to remove 713 all size-1 dimensions. 714 name: Optional op name. 715 716 Returns: 717 A tensor with the specified dimensions removed. 718 719 Raises: 720 ValueError: If the named axes are not in the tensor, or if they are 721 not size-1. 722 """ 723 with ops.name_scope(name, 'lt_squeeze', [labeled_tensor]) as scope: 724 labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) 725 726 if axis_names is None: 727 axis_names = [a.name for a in labeled_tensor.axes.values() if len(a) == 1] 728 729 for axis_name in axis_names: 730 if axis_name not in labeled_tensor.axes: 731 raise ValueError('axis %s is not in tensor axes %s' % 732 (axis_name, labeled_tensor.axes)) 733 elif len(labeled_tensor.axes[axis_name]) != 1: 734 raise ValueError( 735 'cannot squeeze axis with size greater than 1: (%s, %s)' % 736 (axis_name, labeled_tensor.axes[axis_name])) 737 738 squeeze_dimensions = [] 739 axes = [] 740 for i, axis in enumerate(labeled_tensor.axes.values()): 741 if axis.name in axis_names: 742 squeeze_dimensions.append(i) 743 else: 744 axes.append(axis) 745 746 if squeeze_dimensions: 747 squeeze_op = array_ops.squeeze( 748 labeled_tensor.tensor, squeeze_dimensions, name=scope) 749 else: 750 squeeze_op = array_ops.identity(labeled_tensor.tensor, name=scope) 751 752 return core.LabeledTensor(squeeze_op, axes) 753 754 755# pylint: disable=invalid-name 756ReduceAxis = tc.Union(string_types, 757 tc.Tuple(string_types, collections.Hashable)) 758ReduceAxes = tc.Optional(tc.Union(ReduceAxis, tc.Collection(ReduceAxis))) 759# pylint: enable=invalid-name 760 761 762@tc.returns(core.LabeledTensor) 763@tc.accepts(core.LabeledTensorLike, core.LabeledTensorLike, 764 tc.Optional(string_types)) 765def matmul(a, b, name=None): 766 """Matrix multiply two tensors with rank 1 or 2. 767 768 If both tensors have rank 2, a matrix-matrix product is performed. 769 If one tensor has rank 1 and the other has rank 2, then a matrix-vector 770 product is performed. 771 If both tensors have rank 1, then a vector dot-product is performed. 772 (This behavior matches that of `numpy.dot`.) 773 774 Both tensors must share exactly one dimension in common, which is the 775 dimension the operation is summed along. The inputs will be automatically 776 transposed if necessary as part of the matmul op. 777 778 We intend to eventually support `matmul` on higher rank input, and also 779 eventually support summing over any number shared dimensions (via an `axis` 780 argument), but neither of these features has been implemented yet. 781 782 Args: 783 a: First LabeledTensor. 784 b: Second LabeledTensor. 785 name: Optional op name. 786 787 Returns: 788 LabeledTensor with the result of matrix multiplication. Axes are ordered by 789 the current axis_order_scope, if set, or in or order of appearance on the 790 inputs. 791 792 Raises: 793 NotImplementedError: If inputs have rank >2 or share multiple axes. 794 ValueError: If the inputs have rank 0 or do not share any axes. 795 """ 796 with ops.name_scope(name, 'lt_matmul', [a, b]) as scope: 797 798 a = core.convert_to_labeled_tensor(a) 799 b = core.convert_to_labeled_tensor(b) 800 801 if len(a.axes) > 2 or len(b.axes) > 2: 802 # We could pass batched inputs to tf.matmul to make this work, but we 803 # would also need to use tf.tile and/or tf.transpose. These are more 804 # expensive than doing reshapes, so it's not clear if it's a good idea to 805 # do this automatically. 806 raise NotImplementedError( 807 'matmul currently requires inputs with rank 2 or less, but ' 808 'inputs have ranks %r and %r' % (len(a.axes), len(b.axes))) 809 810 if not a.axes or not b.axes: 811 raise ValueError( 812 'matmul currently requires inputs with at least rank 1, but ' 813 'inputs have ranks %r and %r' % (len(a.axes), len(b.axes))) 814 815 shared_axes = set(a.axes) & set(b.axes) 816 if len(shared_axes) > 1: 817 raise NotImplementedError( 818 'matmul does not yet support summing over multiple shared axes: %r. ' 819 'Use transpose and reshape to create a single shared axis to sum ' 820 'over.' % shared_axes) 821 if not shared_axes: 822 raise ValueError('there must have exactly one axis in common between ' 823 'input to matmul: %r, %r' % 824 (a.axes.keys(), b.axes.keys())) 825 shared_axis, = shared_axes 826 827 if a.axes[shared_axis] != b.axes[shared_axis]: 828 raise ValueError('axis %r does not match on input arguments: %r vs %r' % 829 (shared_axis, a.axes[shared_axis].value, 830 b.axes[shared_axis].value)) 831 832 result_axes = [] 833 for axes in [a.axes, b.axes]: 834 for axis in axes.values(): 835 if axis.name != shared_axis: 836 result_axes.append(axis) 837 838 axis_scope_order = core.get_axis_order() 839 if axis_scope_order is not None: 840 result_axis_names = [axis.name for axis in result_axes] 841 new_axis_names = [ 842 name for name in axis_scope_order if name in result_axis_names 843 ] 844 if new_axis_names != result_axis_names: 845 # switch a and b 846 b, a = a, b 847 # result_axes is a list of length 1 or 2 848 result_axes = result_axes[::-1] 849 850 squeeze_dims = [] 851 852 if len(a.axes) == 1: 853 a_tensor = array_ops.reshape(a.tensor, (1, -1)) 854 squeeze_dims.append(0) 855 transpose_a = False 856 else: 857 a_tensor = a.tensor 858 transpose_a = list(a.axes.keys()).index(shared_axis) == 0 859 860 if len(b.axes) == 1: 861 b_tensor = array_ops.reshape(b.tensor, (-1, 1)) 862 squeeze_dims.append(1) 863 transpose_b = False 864 else: 865 b_tensor = b.tensor 866 transpose_b = list(b.axes.keys()).index(shared_axis) == 1 867 868 result_op = math_ops.matmul( 869 a_tensor, b_tensor, transpose_a=transpose_a, transpose_b=transpose_b) 870 871 if squeeze_dims: 872 result_op = array_ops.squeeze(result_op, squeeze_dims) 873 result_op = array_ops.identity(result_op, name=scope) 874 875 return core.LabeledTensor(result_op, result_axes) 876 877 878@tc.returns(types.FunctionType) 879@tc.accepts(string_types, collections.Callable) 880def define_reduce_op(op_name, reduce_fn): 881 """Define a reduction op for labeled tensors. 882 883 Args: 884 op_name: string name of the TensorFlow op. 885 reduce_fn: function to call to evaluate the op on a tf.Tensor. 886 887 Returns: 888 Function defining the given reduction op that acts on a LabeledTensor. 889 """ 890 891 default_name = 'lt_%s' % op_name 892 893 @tc.returns(core.LabeledTensor) 894 @tc.accepts(core.LabeledTensorLike, ReduceAxes, tc.Optional(string_types)) 895 def op(labeled_tensor, axes=None, name=None): 896 """Computes the given reduction across the given axes of a LabeledTensor. 897 898 See `tf.{op_name}` for full details. 899 900 Args: 901 labeled_tensor: The input tensor. 902 axes: A set of axes or None. 903 If None, all axes will be reduced. 904 Axes must all be strings, in which case those dimensions will be 905 removed, or pairs of (name, None) or (name, label), in which case those 906 dimensions will be kept. 907 name: Optional op name. 908 909 Returns: 910 The reduced LabeledTensor. 911 912 Raises: 913 ValueError: if any of the axes to reduce over are not found on 914 `labeled_tensor`. 915 """ 916 with ops.name_scope(name, default_name, [labeled_tensor]) as scope: 917 labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) 918 919 if axes is None: 920 axes = labeled_tensor.axes.keys() 921 922 if isinstance(axes, (string_types, tuple)): 923 axes = [axes] 924 925 reduction_axes = {} 926 axes_to_squeeze = [] 927 for a in axes: 928 if isinstance(a, string_types): 929 # We squeeze out this axis. 930 reduction_axes[a] = a 931 axes_to_squeeze.append(a) 932 else: 933 # We keep this axis, with the user-provided labels. 934 (axis_name, label) = a 935 if label is not None: 936 # The input was a single label, so make it a list so it can be 937 # turned into an Axis. 938 label = [label] 939 reduction_axes[axis_name] = (axis_name, label) 940 941 for axis_name in reduction_axes: 942 if axis_name not in labeled_tensor.axes: 943 raise ValueError('Axis %s not in axes %s' % 944 (axis_name, labeled_tensor.axes)) 945 946 intermediate_axes = [] 947 reduction_dimensions = [] 948 for i, axis in enumerate(labeled_tensor.axes.values()): 949 if axis.name in reduction_axes: 950 intermediate_axes.append(reduction_axes[axis.name]) 951 reduction_dimensions.append(i) 952 else: 953 intermediate_axes.append(axis) 954 955 reduce_op = reduce_fn( 956 labeled_tensor.tensor, reduction_dimensions, keepdims=True) 957 reduce_lt = core.LabeledTensor(reduce_op, intermediate_axes) 958 959 return squeeze(reduce_lt, axes_to_squeeze, name=scope) 960 961 op.__doc__ = op.__doc__.format(op_name=op_name) 962 op.__name__ = op_name 963 964 return op 965 966 967reduce_all = define_reduce_op('reduce_all', math_ops.reduce_all) 968reduce_any = define_reduce_op('reduce_any', math_ops.reduce_any) 969reduce_logsumexp = define_reduce_op('reduce_logsumexp', 970 math_ops.reduce_logsumexp) 971reduce_max = define_reduce_op('reduce_max', math_ops.reduce_max) 972reduce_mean = define_reduce_op('reduce_mean', math_ops.reduce_mean) 973reduce_min = define_reduce_op('reduce_min', math_ops.reduce_min) 974reduce_prod = define_reduce_op('reduce_prod', math_ops.reduce_prod) 975reduce_sum = define_reduce_op('reduce_sum', math_ops.reduce_sum) 976 977 978@tc.returns(core.LabeledTensor) 979@tc.accepts(core.LabeledTensorLike, 980 tc.Mapping(str, tc.Union(int, ops.Tensor)), 981 tc.Optional(string_types)) 982def tile(labeled_tensor, multiples, name=None): 983 """Constructs a tensor by tiling a given tensor. 984 985 Only axes without tick-labels can be tiled. (Otherwise, axis labels on tiled 986 tensors would no longer be unique.) 987 988 See lt.tile. 989 990 Args: 991 labeled_tensor: The input tensor. 992 multiples: A mapping where the keys are axis names and the values are the 993 integer number of times to tile along that axis. Only axes with a multiple 994 different than 1 need be included. 995 name: Optional op name. 996 997 Returns: 998 A tensor with the indicated axes tiled. 999 1000 Raises: 1001 ValueError: If the tiled axes are not axes in the input tensor, or if any 1002 axes in multiples have tick labels. 1003 """ 1004 with ops.name_scope(name, 'lt_tile', [labeled_tensor]) as scope: 1005 labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) 1006 1007 if not set(multiples.keys()) <= set(labeled_tensor.axes.keys()): 1008 raise ValueError('tile axes %r are not contained in the set of axis ' 1009 'names %r on the input labeled tensor' % 1010 (multiples.keys(), labeled_tensor.axes)) 1011 1012 labeled_axes = [ 1013 name for name in multiples 1014 if labeled_tensor.axes[name].labels is not None 1015 ] 1016 if labeled_axes: 1017 raise ValueError('cannot tile axes with tick labels: %r' % labeled_axes) 1018 1019 multiples_list = [multiples.get(name, 1) for name in labeled_tensor.axes] 1020 tile_op = array_ops.tile(labeled_tensor.tensor, multiples_list, name=scope) 1021 1022 new_axes = [ 1023 axis.name if axis.labels is None else axis 1024 for axis in labeled_tensor.axes.values() 1025 ] 1026 return core.LabeledTensor(tile_op, new_axes) 1027 1028 1029@tc.returns(core.LabeledTensor) 1030@tc.accepts(core.LabeledTensorLike, 1031 tc.Mapping(str, tc.Tuple(core.AxisValue, core.AxisValue)), 1032 string_types, tc.Optional(string_types)) 1033def pad(labeled_tensor, paddings, mode='CONSTANT', name=None): 1034 """Pads a tensor. 1035 1036 See tf.pad. 1037 1038 Args: 1039 labeled_tensor: The input tensor. 1040 paddings: A mapping where the keys are axis names and the values are 1041 tuples where the first element is the padding to insert at the beginning 1042 of the axis and the second is the padding to insert at the end of the 1043 axis. 1044 mode: One of "CONSTANT", "REFLECT", or "SYMMETRIC". 1045 name: Optional op name. 1046 1047 Returns: 1048 A tensor with the indicated axes padded, optionally with those axes extended 1049 with the provided labels. 1050 1051 Raises: 1052 ValueError: If the padded axes are not axes in the input tensor. 1053 """ 1054 with ops.name_scope(name, 'lt_pad', [labeled_tensor]) as scope: 1055 labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) 1056 1057 if not set(paddings.keys()) <= set(labeled_tensor.axes.keys()): 1058 raise ValueError('pad axes %r are not contained in the set of axis ' 1059 'names %r on the input labeled tensor' % 1060 (paddings.keys(), labeled_tensor.axes)) 1061 1062 new_axes = [] 1063 padding_pairs = [] 1064 for name, axis in labeled_tensor.axes.items(): 1065 if name in paddings: 1066 padding_before, padding_after = paddings[name] 1067 axis_before = core.Axis(name, padding_before) 1068 axis_after = core.Axis(name, padding_after) 1069 new_axes.append(core.concat_axes([axis_before, axis, axis_after])) 1070 padding_pairs.append((len(axis_before), len(axis_after))) 1071 else: 1072 new_axes.append(axis) 1073 padding_pairs.append((0, 0)) 1074 1075 pad_op = array_ops.pad(labeled_tensor.tensor, 1076 padding_pairs, 1077 mode, 1078 name=scope) 1079 1080 return core.LabeledTensor(pad_op, new_axes) 1081 1082 1083@tc.returns(core.LabeledTensor) 1084@tc.accepts( 1085 tc.Union(np.ndarray, list, tuple, core.Scalar), 1086 tc.Optional(dtypes.DType), 1087 tc.Optional( 1088 tc.Union(core.Axes, tc.Collection( 1089 tc.Union(string_types, core.AxisLike)))), tc.Optional(string_types)) 1090def constant(value, dtype=None, axes=None, name=None): 1091 """Creates a constant tensor. 1092 1093 If `axes` includes any strings, shape is inferred from `value`. Otherwise, 1094 the sizes of the given `axes` are used to set `shape` for `tf.constant`. 1095 1096 See tf.constant for more details. 1097 1098 Args: 1099 value: The input tensor. 1100 dtype: The type of the returned tensor. 1101 axes: Optional Axes, list of strings or list of objects coercible to Axis 1102 objects. By default, axes are assumed to be an empty list (i.e., `value` 1103 is treated as a scalar). 1104 name: Optional op name. 1105 1106 Returns: 1107 The tensor with elements set to zero. 1108 """ 1109 with ops.name_scope(name, 'lt_constant', [value]) as scope: 1110 1111 if axes is None: 1112 axes = [] 1113 1114 if isinstance(axes, core.Axes): 1115 axes = axes.values() 1116 1117 if any(isinstance(ax, string_types) for ax in axes): 1118 # need to infer shape 1119 shape = None 1120 else: 1121 # axes already indicate shape 1122 axes = [core.as_axis(a) for a in axes] 1123 shape = [a.size for a in axes] 1124 1125 op = array_ops.constant(value, dtype=dtype, shape=shape, name=scope) 1126 return core.LabeledTensor(op, axes) 1127 1128 1129@tc.returns(core.LabeledTensor) 1130@tc.accepts(core.LabeledTensorLike, 1131 tc.Optional(dtypes.DType), tc.Optional(string_types)) 1132def zeros_like(labeled_tensor, dtype=None, name=None): 1133 """Creates an identical tensor with all elements set to zero. 1134 1135 Args: 1136 labeled_tensor: The input tensor. 1137 dtype: The type of the returned tensor. 1138 name: Optional op name. 1139 1140 Returns: 1141 The tensor with elements set to zero. 1142 """ 1143 with ops.name_scope(name, 'lt_zeros_like', [labeled_tensor]) as scope: 1144 labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) 1145 op = array_ops.zeros_like(labeled_tensor.tensor, dtype=dtype, name=scope) 1146 return core.LabeledTensor(op, labeled_tensor.axes) 1147 1148 1149@tc.returns(core.LabeledTensor) 1150@tc.accepts(core.LabeledTensorLike, 1151 tc.Optional(dtypes.DType), tc.Optional(string_types)) 1152def ones_like(labeled_tensor, dtype=None, name=None): 1153 """Creates an identical tensor with all elements set to one. 1154 1155 Args: 1156 labeled_tensor: The input tensor. 1157 dtype: The type of the returned tensor. 1158 name: Optional op name. 1159 1160 Returns: 1161 The tensor with elements set to one. 1162 """ 1163 with ops.name_scope(name, 'lt_ones_like', [labeled_tensor]) as scope: 1164 labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) 1165 op = array_ops.ones_like(labeled_tensor.tensor, dtype=dtype, name=scope) 1166 return core.LabeledTensor(op, labeled_tensor.axes) 1167 1168 1169@tc.returns(core.LabeledTensor) 1170@tc.accepts(core.LabeledTensorLike, 1171 tc.Optional(dtypes.DType), tc.Optional(string_types)) 1172def cast(labeled_tensor, dtype=None, name=None): 1173 """Casts a labeled tensor to a new type. 1174 1175 Args: 1176 labeled_tensor: The input tensor. 1177 dtype: The type of the returned tensor. 1178 name: Optional op name. 1179 1180 Returns: 1181 A labeled tensor with the new dtype. 1182 """ 1183 with ops.name_scope(name, 'lt_cast', [labeled_tensor]) as scope: 1184 labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) 1185 op = math_ops.cast(labeled_tensor.tensor, dtype=dtype, name=scope) 1186 return core.LabeledTensor(op, labeled_tensor.axes) 1187 1188 1189@tc.returns(core.LabeledTensor) 1190@tc.accepts(core.LabeledTensorLike, string_types, tc.Optional(string_types)) 1191def verify_tensor_all_finite(labeled_tensor, message, name=None): 1192 """Asserts a tensor doesn't contain NaNs or Infs. 1193 1194 See tf.verify_tensor_all_finite. 1195 1196 Args: 1197 labeled_tensor: The input tensor. 1198 message: Message to log on failure. 1199 name: Optional op name. 1200 1201 Returns: 1202 The input tensor. 1203 """ 1204 with ops.name_scope(name, 'lt_verify_tensor_all_finite', 1205 [labeled_tensor]) as scope: 1206 labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) 1207 op = numerics.verify_tensor_all_finite( 1208 labeled_tensor.tensor, msg=message, name=scope) 1209 return core.LabeledTensor(op, labeled_tensor.axes) 1210 1211 1212@tc.returns(core.LabeledTensor) 1213@tc.accepts(core.LabeledTensorLike, core.LabeledTensorLike, 1214 tc.Optional(string_types)) 1215def boolean_mask(labeled_tensor, mask, name=None): 1216 """Apply a boolean mask to a labeled tensor. 1217 1218 Unlike `tf.boolean_mask`, this currently only works on 1-dimensional masks. 1219 The mask is applied to the first axis of `labeled_tensor`. Labels on the first 1220 axis are removed, because True indices in `mask` may not be known dynamically. 1221 1222 Args: 1223 labeled_tensor: The input tensor. 1224 mask: The type of the returned tensor. 1225 name: Optional op name. 1226 1227 Returns: 1228 The masked labeled tensor. 1229 1230 Raises: 1231 ValueError: if the first axis of the mask 1232 """ 1233 with ops.name_scope(name, 'lt_boolean_mask', [labeled_tensor, mask]) as scope: 1234 labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor) 1235 mask = core.convert_to_labeled_tensor(mask) 1236 1237 if len(mask.axes) > 1: 1238 raise NotImplementedError( 1239 "LabeledTensor's boolean_mask currently only supports 1D masks") 1240 mask_axis = list(mask.axes.values())[0] 1241 lt_axis = list(labeled_tensor.axes.values())[0] 1242 if mask_axis != lt_axis: 1243 raise ValueError('the first axis of the labeled tensor and the mask ' 1244 'are not equal:\n%r\n%r' % (lt_axis, mask_axis)) 1245 op = array_ops.boolean_mask(labeled_tensor.tensor, mask.tensor, name=scope) 1246 # TODO(shoyer): attempt to infer labels for the masked values, by calling 1247 # tf.contrib.util.constant_value on the mask? 1248 axes = [lt_axis.name] + list(labeled_tensor.axes.values())[1:] 1249 return core.LabeledTensor(op, axes) 1250 1251 1252@tc.returns(core.LabeledTensor) 1253@tc.accepts(core.LabeledTensorLike, core.LabeledTensorLike, 1254 core.LabeledTensorLike, tc.Optional(string_types)) 1255def where(condition, x, y, name=None): 1256 """Return elements from x or y depending on condition. 1257 1258 See `tf.where` for more details. This function currently only implements the 1259 three argument version of where. 1260 1261 Args: 1262 condition: LabeledTensor of type `bool`. 1263 x: LabeledTensor for values where condition is true. 1264 y: LabeledTensor for values where condition is false. 1265 name: Optional op name. 1266 1267 Returns: 1268 The labeled tensor with values according to condition. 1269 1270 Raises: 1271 ValueError: if `x` and `y` have different axes, or if the axes of `x` do not 1272 start with the axes of `condition`. 1273 """ 1274 with ops.name_scope(name, 'lt_where', [condition, x, y]) as scope: 1275 condition = core.convert_to_labeled_tensor(condition) 1276 x = core.convert_to_labeled_tensor(x) 1277 y = core.convert_to_labeled_tensor(y) 1278 1279 if not condition.axes == x.axes == y.axes: 1280 raise ValueError('all inputs to `where` must have equal axes') 1281 1282 op = array_ops.where(condition.tensor, x.tensor, y.tensor, name=scope) 1283 return core.LabeledTensor(op, x.axes) 1284