# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Core classes and core ops for LabeledTensor.

Core ops are ops which will eventually be called by LabeledTensor methods,
and ops which a core op depends upon.
For example, `add` is a core op because we'll eventually support the `+`
operator.
Non-core ops should go in `ops.py`.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import contextlib
import numbers
import types

import numpy as np
from six import binary_type
from six import string_types
from six import text_type
from six.moves import range  # pylint: disable=redefined-builtin

from tensorflow.contrib.labeled_tensor.python.ops import _typecheck as tc
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops

try:
  # Python 3.3+ keeps the abstract base classes in collections.abc; the
  # deprecated aliases in the top-level collections module were removed in
  # Python 3.10.
  from collections import abc as collections_abc
except ImportError:  # Python 2: the ABCs live directly in collections.
  collections_abc = collections

# pylint: disable=invalid-name

# Types coercible to Axis.labels
# We use this instead of collections.Sequence to exclude strings.
49LabelsLike = tc.Union(np.ndarray, range, list, tuple) 50 51# Types coercible to a tf.Dimension 52DimensionLike = tc.Optional(tc.Union(tensor_shape.Dimension, int)) 53 54# Types usable for axis values 55AxisValue = tc.Union(LabelsLike, DimensionLike) 56 57# Valid scalar values for TensorFlow 58Scalar = tc.Union(numbers.Number, bool, binary_type, text_type) 59 60# pylint: enable=invalid-name 61 62 63class Axis(object): 64 """Size and label information for an axis. 65 66 Axis contains either a tf.Dimension indicating the size of an axis, 67 or a tuple of tick labels for the axis. 68 69 If tick labels are provided, they must be unique. 70 """ 71 72 @tc.accepts(object, string_types, AxisValue) 73 def __init__(self, name, value): 74 """Construct an Axis. 75 76 Args: 77 name: Name of the axis. 78 value: Either None, an int or tf.Dimension giving the size of the axis, 79 or a sequence that is not a string additionally providing coordinate 80 (tick) labels. 81 82 Raises: 83 ValueError: If the user provides labels with duplicate values. 84 """ 85 if isinstance(value, tensor_shape.Dimension): 86 dimension = value 87 labels = None 88 elif isinstance(value, int) or value is None: 89 dimension = tensor_shape.Dimension(value) 90 labels = None 91 else: 92 dimension = tensor_shape.Dimension(len(value)) 93 labels = tuple(value) 94 95 if dimension.value == 0: 96 # Treat a zero-length axis as if it has labels. 
97 labels = () 98 99 if labels is not None: 100 index = dict(zip(labels, range(len(labels)))) 101 if len(index) != len(labels): 102 raise ValueError('Tick labels must be unique, but got {}' 103 .format(labels)) 104 else: 105 index = None 106 107 self._name = name # type: string_types 108 self._dimension = dimension # type: tensor_shape.Dimension 109 self._labels = labels # type: Optional[tuple] 110 self._index = index # type: Optional[Dict[Any, int]] 111 112 @property 113 @tc.returns(string_types) 114 def name(self): 115 return self._name 116 117 @tc.returns(string_types) 118 def __repr__(self): 119 # Axis('x', Dimension(2)) 120 # TODO(shoyer): make very long reprs more succint? 121 return "%s('%s', %r)" % (type(self).__name__, self.name, self.value) 122 123 @tc.returns(bool) 124 def __eq__(self, other): 125 return (isinstance(other, Axis) and self.name == other.name and 126 self.size == other.size and self.labels == other.labels) 127 128 def __hash__(self): 129 return hash((self.name, self.size, self.labels)) 130 131 @tc.returns(bool) 132 def __ne__(self, other): 133 return not self == other 134 135 @tc.returns(int) 136 def __len__(self): 137 size = self.size 138 if size is None: 139 raise ValueError('axis %r has unknown length' % self.name) 140 return size 141 142 @property 143 @tc.returns(tc.Optional(tensor_shape.Dimension)) 144 def dimension(self): 145 return self._dimension 146 147 @property 148 @tc.returns(tc.Optional(int)) 149 def size(self): 150 return self._dimension.value 151 152 @property 153 @tc.returns(tc.Union(tuple, tensor_shape.Dimension)) 154 def value(self): 155 """Returns the tf.Dimension or tuple specifying axis ticks.""" 156 if self.labels is None: 157 return self.dimension 158 else: 159 return self.labels 160 161 @property 162 @tc.returns(tc.Optional(tuple)) 163 def labels(self): 164 """Returns the tuple containing coordinate labels, else None.""" 165 return self._labels 166 167 def index(self, value): 168 """Returns the integer position of 
the given tick label.""" 169 if self._index is None: 170 raise ValueError('Axis does not have tick labels') 171 return self._index[value] 172 173 174# tc class for anything that can be coerced into an Axis 175# pylint: disable=invalid-name 176AxisLike = tc.Union(Axis, tc.Tuple(string_types, AxisValue)) 177# pylint: enable=invalid-name 178 179 180@tc.returns(Axis) 181@tc.accepts(AxisLike) 182def as_axis(axis_data): 183 """Convert an AxisLike object into an Axis. 184 185 Args: 186 axis_data: Axis object or tuple (axis_name, axis_value) describing an axis. 187 188 Returns: 189 Axis object. This may be the original object if axis_data is an Axis. 190 """ 191 if isinstance(axis_data, Axis): 192 axis = axis_data 193 else: 194 axis = Axis(*axis_data) 195 return axis 196 197 198class Axes(collections.Mapping): 199 """Axis names and indices for a tensor. 200 201 It is an ordered mapping, with keys given by axis name and values given 202 by Axis objects. Duplicate axis names are not allowed. 203 """ 204 205 @tc.accepts(object, tc.List(AxisLike)) 206 def __init__(self, axes): 207 """Construct an Axes. 208 209 Args: 210 axes: A list of Axis objects or (axis_name, axis_value) tuples. 211 212 Raises: 213 ValueError: If the user provides empty or duplicate axis names. 
214 """ 215 self._axes = collections.OrderedDict() 216 217 for axis_data in axes: 218 axis = as_axis(axis_data) 219 220 name = axis.name 221 if name in self._axes: 222 raise ValueError('Duplicate axis name: %s' % name) 223 224 self._axes[name] = axis 225 226 def __iter__(self): 227 return iter(self._axes) 228 229 @tc.returns(string_types) 230 def __repr__(self): 231 # Axes([('x', Dimension(2)), 232 # ('y', ['a', 'b', 'c']), 233 # ('z', Dimension(4))]) 234 cls_name = type(self).__name__ 235 values = ["('%s', %r)" % (v.name, v.value) for v in self._axes.values()] 236 values_repr = (',\n' + ' ' * len(cls_name + '([')).join(values) 237 return '%s([%s])' % (cls_name, values_repr) 238 239 @tc.returns(Axis) 240 @tc.accepts(object, string_types) 241 def __getitem__(self, name): 242 return self._axes[name] 243 244 @tc.returns(bool) 245 def __contains__(self, name): 246 return name in self._axes 247 248 @tc.returns(int) 249 def __len__(self): 250 return len(self._axes) 251 252 def __hash__(self): 253 return hash(tuple(self.items())) 254 255 @tc.accepts(object, string_types) 256 def remove(self, axis_name): 257 """Creates a new Axes object without the given axis.""" 258 if axis_name not in self: 259 raise KeyError(axis_name) 260 remaining_axes = [axis for axis in self.values() if axis.name != axis_name] 261 return Axes(remaining_axes) 262 263 264class LabeledTensor(object): 265 """A tensor with annotated axes. 266 267 It has the following invariants: 268 1) The dimensionality of the tensor is equal to the number of elements 269 in axes. 270 2) The number of coordinate values in the ith dimension is equal to the 271 size of the tensor in the ith dimension. 272 273 Attributes: 274 tensor: tf.Tensor containing the data. 275 axes: lt.Axes containing axis names and coordinate labels. 276 """ 277 278 @tc.accepts(object, ops.Tensor, 279 tc.Union(Axes, tc.Collection(tc.Union(string_types, AxisLike)))) 280 def __init__(self, tensor, axes): 281 """Construct a LabeledTensor. 
282 283 Args: 284 tensor: The underlying tensor containing the data. 285 axes: An Axes object, or a collection of strings, Axis objects or tuples 286 of (name, value) pairs indicating the axes. 287 288 Raises: 289 ValueError: If the provided axes do not satisfy the class invariants. 290 """ 291 self._tensor = tensor 292 shape = tensor.get_shape() 293 294 if isinstance(axes, Axes): 295 unvalidated_axes = axes 296 else: 297 mutable_axes = [] 298 299 for position, axis_like in enumerate(axes): 300 if isinstance(axis_like, string_types): 301 # The coordinates for this axes are unlabeled. 302 # Infer the size of the axis. 303 value = shape[position] 304 axis_like = (axis_like, value) 305 306 mutable_axes.append(axis_like) 307 308 # Construct the Axis object, which will additionally validate the contents 309 # of the object. 310 unvalidated_axes = Axes(mutable_axes) 311 312 # Check our invariants. 313 314 # First, the rank of the tensor must be equal to the number of axes. 315 if len(shape) != len(unvalidated_axes): 316 raise ValueError('Tensor rank was not equal to the number of axes: %r, %r' 317 % (shape, unvalidated_axes)) 318 319 # Second, the size of each tensor dimension must match the size of the 320 # corresponding indices. 
321 for (d, axis) in zip(shape, unvalidated_axes.values()): 322 if d != axis.size: 323 raise ValueError( 324 'Provided axis size %d does not match tensor dimension size %d' 325 'in tensor %r' % (axis.size, d, tensor)) 326 327 self._axes = unvalidated_axes 328 329 def __repr__(self): 330 # <LabeledTensor 'foo' shape=(2, 3, 4) dtype=float32 331 # axes=[('x', Dimension(2)), 332 # ('y', ('a', 'b', 'c'), 333 # ('z', Dimension(4))]> 334 axes = ["('%s', %r)" % (v.name, v.value) for v in self.axes.values()] 335 axes_repr = (',\n' + ' ' * len(' axes=[')).join(axes) 336 return ("<%s '%s' shape=%s dtype=%s\n axes=[%s]>" % 337 (type(self).__name__, self.tensor.name, self.tensor.get_shape(), 338 self.tensor.dtype.name, axes_repr)) 339 340 @property 341 def tensor(self): 342 return self._tensor 343 344 def _as_graph_element(self): 345 """Support tf.Graph.as_graph_element on LabeledTensor objects. 346 347 This allows operations such as tf.name_scope to take labeled tensors. 348 349 Returns: 350 self.tensor 351 """ 352 return self.tensor 353 354 @property 355 def axes(self): 356 return self._axes 357 358 # properties/methods directly borrowed from tf.Tensor: 359 360 @property 361 def dtype(self): 362 return self._tensor.dtype 363 364 @property 365 def shape(self): 366 return self._tensor.shape 367 368 @property 369 def name(self): 370 return self._tensor.name 371 372 def get_shape(self): 373 """Returns the TensorShape that represents the shape of this tensor. 374 375 See tf.Tensor.get_shape(). 376 377 Returns: 378 A TensorShape representing the shape of this tensor. 379 """ 380 return self._tensor.get_shape() 381 382 # TODO(shoyer): consider how/if to implement .eval(). Maybe it should return 383 # an xarray.DataArray? 384 385 def __getitem__(self, key): 386 # This should work exactly like tf.Tensor.__getitem__, except it preserves 387 # labels. 
388 if not isinstance(key, tuple): 389 key = (key,) 390 if len(key) != len(self.axes): 391 raise ValueError('indexer %r must have the same length as the Tensor ' 392 'rank (%r)' % (key, len(self.axes))) 393 selection = {a: k for a, k in zip(self.axes.keys(), key)} 394 return slice_function(self, selection) 395 396 # special methods for overloading arithmetic operations: 397 398 def __abs__(self): 399 return abs_function(self) 400 401 def __neg__(self): 402 return neg(self) 403 404 def __pos__(self): 405 return self 406 407 def __add__(self, other): 408 return add(self, other) 409 410 def __radd__(self, other): 411 return add(other, self) 412 413 def __sub__(self, other): 414 return sub(self, other) 415 416 def __rsub__(self, other): 417 return sub(other, self) 418 419 def __mul__(self, other): 420 return mul(self, other) 421 422 def __rmul__(self, other): 423 return mul(other, self) 424 425 def __truediv__(self, other): 426 return div(self, other) 427 428 __div__ = __truediv__ 429 430 def __rtruediv__(self, other): 431 return div(other, self) 432 433 __rdiv__ = __rtruediv__ 434 435 def __mod__(self, other): 436 return mod(self, other) 437 438 def __rmod__(self, other): 439 return mod(other, self) 440 441 def __pow__(self, other): 442 return pow_function(self, other) 443 444 def __rpow__(self, other): 445 return pow_function(other, self) 446 447 # logical operations: 448 449 def __invert__(self): 450 return logical_not(self) 451 452 def __and__(self, other): 453 return logical_and(self, other) 454 455 def __or__(self, other): 456 return logical_or(self, other) 457 458 def __xor__(self, other): 459 return logical_xor(self, other) 460 461 # boolean operations: 462 463 def __lt__(self, other): 464 return less(self, other) 465 466 def __le__(self, other): 467 return less_equal(self, other) 468 469 def __gt__(self, other): 470 return greater(self, other) 471 472 def __ge__(self, other): 473 return greater_equal(self, other) 474 475 def __eq__(self, other): 476 # for 
consistency with tf.Tensor 477 if not isinstance(other, LabeledTensor): 478 return False 479 480 return self.tensor == other.tensor and self.axes == other.axes 481 482 def __ne__(self, other): 483 return not self == other 484 485 def __hash__(self): 486 return hash((self.tensor, self.axes)) 487 488 489# typecheck type abbreviations: 490# abbreviations for third-party types with very long reprs 491tc.register_type_abbreviation(tensor_shape.Dimension, 'tensorflow.Dimension') 492tc.register_type_abbreviation(ops.Tensor, 'tensorflow.Tensor') 493tc.register_type_abbreviation(dtypes.DType, 'tensorflow.DType') 494# core LabeledTensor types 495tc.register_type_abbreviation(Axis, 'labeled_tensor.Axis') 496tc.register_type_abbreviation(Axes, 'labeled_tensor.Axes') 497tc.register_type_abbreviation(LabeledTensor, 'labeled_tensor.LabeledTensor') 498 499 500@tc.returns(ops.Tensor) 501@tc.accepts(LabeledTensor) 502def _convert_labeled_tensor_to_tensor(value, *args, **kwargs): 503 # call ops.convert_to_tensor to handle optional arguments appropriately 504 return ops.internal_convert_to_tensor(value.tensor, *args, **kwargs) 505 506 507ops.register_tensor_conversion_function(LabeledTensor, 508 _convert_labeled_tensor_to_tensor) 509 510# tc class for anything that can be coerced into a LabeledTensor 511# pylint: disable=invalid-name 512LabeledTensorLike = tc.Union(LabeledTensor, ops.Tensor, np.ndarray, Scalar) 513# pylint: enable=invalid-name 514 515 516@tc.returns(LabeledTensor) 517@tc.accepts(LabeledTensorLike, object, tc.Optional(string_types)) 518def convert_to_labeled_tensor(value, dtype=None, name=None): 519 """Converts the given `value` to a `LabeledTensor`. 520 521 This function accepts `LabeledTensor` objects, 0-dimensional `Tensor` objects 522 and numpy arrays, and Python scalars. Higher dimensional unlabeled tensors 523 must use the `LabeledTensor` constructor explicitly. 524 525 Args: 526 value: Object to convert. 527 dtype: Optional element type for the returned tensor. 
If missing, the type 528 is inferred from the type of value. 529 name: Optional name to use if a new Tensor is created. 530 531 Returns: 532 `value` converted into a `LabeledTensor` object. 533 534 Raises: 535 ValueError: If the output would have rank>0 but the input was not already a 536 `LabeledTensor`. 537 """ 538 # TODO(shoyer): consider extending to accept xarray.DataArray as input. 539 if isinstance(value, LabeledTensor): 540 axes = value.axes.values() 541 value = value.tensor 542 else: 543 axes = [] 544 545 # We call convert_to_tensor even for LabeledTensor input because it also 546 # checks to make sure the dtype argument is compatible. 547 tensor = ops.convert_to_tensor(value, dtype=dtype, name=name) 548 if len(tensor.get_shape()) != len(axes): 549 raise ValueError('cannot automatically convert unlabeled arrays or tensors ' 550 'with rank>0 into LabeledTensors: %r' % value) 551 return LabeledTensor(tensor, axes) 552 553 554@tc.returns(Axis) 555@tc.accepts(tc.Collection(Axis)) 556def concat_axes(axes): 557 """Concatenate a list of Axes. 558 559 Args: 560 axes: A collection of Axis objects. 561 562 Returns: 563 The concatenation of the axes. 564 If all axes have labels, the result has the concatenation of the labels. 565 Else, the result has no labels, and its size is the sum of the sizes 566 of the axes. 567 568 Raises: 569 ValueError: If `others` is not a collection of Axes or if it is empty. 
570 """ 571 if not axes: 572 raise ValueError('axes must not be empty') 573 for a in axes: 574 if not isinstance(a, Axis): 575 raise ValueError('Expected an Axis, but got %r of type %r' % (a, type(a))) 576 577 names = set(a.name for a in axes) 578 if len(names) > 1: 579 raise ValueError('axes do not all have the same name: %r' % names) 580 name, = names 581 582 all_have_labels = all(a.labels is not None for a in axes) 583 any_has_unknown_size = any(a.size is None for a in axes) 584 585 if all_have_labels: 586 value = tuple(label for a in axes for label in a.labels) 587 elif any_has_unknown_size: 588 value = None 589 else: 590 value = sum(len(a) for a in axes) 591 return Axis(name, value) 592 593 594@tc.returns(LabeledTensor) 595@tc.accepts(LabeledTensorLike, tc.Optional(string_types)) 596def identity(labeled_tensor, name=None): 597 """The identity op. 598 599 See tf.identity. 600 601 Args: 602 labeled_tensor: The input tensor. 603 name: Optional op name. 604 605 Returns: 606 The tensor. 607 """ 608 with ops.name_scope(name, 'lt_identity', [labeled_tensor]) as scope: 609 labeled_tensor = convert_to_labeled_tensor(labeled_tensor) 610 return LabeledTensor( 611 array_ops.identity( 612 labeled_tensor.tensor, name=scope), 613 labeled_tensor.axes) 614 615 616# We don't call this slice because that shadows a built-in. Instead, we alias 617# this to lt.slice in __init__.py. 618@tc.returns(LabeledTensor) 619@tc.accepts(LabeledTensorLike, 620 tc.Mapping(string_types, tc.Union(int, slice)), 621 tc.Optional(string_types)) 622def slice_function(labeled_tensor, selection, name=None): 623 """Slice out a subset of the tensor. 624 625 This is an analog of tf.slice. 626 For example: 627 >>> tensor = tf.reshape(tf.range(0, 6), [3, 2]) 628 >>> labeled_tensor = lt.LabeledTensor(tensor, ['a', ('b', ['foo', 'bar'])]) 629 >>> lt.slice(labeled_tensor, {'a': slice(0, 2), 'b': 1}) 630 <LabeledTensor 'lt_slice:...' 
shape=(2,) dtype=int32 631 axes=[('a', Dimension(2))]> 632 633 Args: 634 labeled_tensor: The input tensor. 635 selection: A dictionary of type str -> Union(int, slice of int) mapping 636 axis names to sub-selections. 637 name: Optional op name. 638 639 Returns: 640 The slice as a `LabeledTensor`. 641 """ 642 with ops.name_scope(name, 'lt_slice', [labeled_tensor]) as scope: 643 labeled_tensor = convert_to_labeled_tensor(labeled_tensor) 644 645 slices = [] 646 647 for axis_name in labeled_tensor.axes: 648 if axis_name not in selection: 649 # We're not sub-selecting this axis, so use the full slice. 650 slices.append(slice(None)) 651 else: 652 slices.append(selection[axis_name]) 653 654 sliced_tensor = labeled_tensor.tensor[tuple(slices)] 655 656 sliced_axes = [] 657 for axis, s in zip(labeled_tensor.axes.values(), slices): 658 # We sub-select this axis's index with the slice s. 659 660 # `s` is either an int or a proper slice. 661 if isinstance(s, slice): 662 if axis.labels is None: 663 # We're not tracking coordinate names for this axis. 664 sliced_axes.append(axis.name) 665 else: 666 sliced_axes.append((axis.name, axis.labels[s])) 667 else: 668 # If the slice is an int this dimension now has size 1, so we remove it. 669 assert isinstance(s, int) 670 671 return LabeledTensor( 672 array_ops.identity( 673 sliced_tensor, name=scope), sliced_axes) 674 675 676@tc.returns(LabeledTensor) 677@tc.accepts(LabeledTensorLike, 678 tc.Optional(tc.Collection(string_types)), tc.Optional(string_types)) 679def transpose(labeled_tensor, axis_order=None, name=None): 680 """Permute a tensor's axes. 681 682 See tf.transpose. 683 684 Args: 685 labeled_tensor: The input tensor. 686 axis_order: Optional desired axis order, as a list of names. By default, the 687 order of axes is reversed. 688 name: Optional op name. 689 690 Returns: 691 The permuted tensor. 692 693 Raises: 694 ValueError: If axis_order isn't a permutation of the existing axes. 
695 """ 696 with ops.name_scope(name, 'lt_transpose', [labeled_tensor]) as scope: 697 labeled_tensor = convert_to_labeled_tensor(labeled_tensor) 698 699 original_order = list(labeled_tensor.axes.keys()) 700 if axis_order is None: 701 axis_order = list(reversed(original_order)) 702 elif sorted(axis_order) != sorted(original_order): 703 raise ValueError( 704 'The new axis order must have the same names as the original axes, ' 705 'but the new order is %r while the original order is %r' % 706 (axis_order, original_order)) 707 708 axis_names = list(labeled_tensor.axes.keys()) 709 permutation = [axis_names.index(n) for n in axis_order] 710 711 # Note: TensorFlow doesn't copy data for the identity transpose. 712 transpose_tensor = array_ops.transpose( 713 labeled_tensor.tensor, permutation, name=scope) 714 715 permuted_axes = [labeled_tensor.axes[n] for n in axis_order] 716 717 return LabeledTensor(transpose_tensor, permuted_axes) 718 719 720@tc.returns(LabeledTensor) 721@tc.accepts( 722 LabeledTensorLike, 723 tc.Collection( 724 tc.Union(string_types, tc.Tuple(string_types, collections.Hashable))), 725 tc.Optional(string_types)) 726def expand_dims(labeled_tensor, axes, name=None): 727 """Insert dimensions of size 1. 728 729 See tf.expand_dims. 730 731 Args: 732 labeled_tensor: The input tensor. 733 axes: The desired axis names as strings or tuples of (name, label), 734 where `label` is the coordinate name for the new dimension `name`. 735 These must include the existing axis names, and the existing names must 736 appear in the same order in this list as they do in the input tensor. 737 name: Optional op name. 738 739 Returns: 740 A tensor with an axis for each axis in axes. 741 New axes are created with size 1 and do not have labeled coordinates. 742 743 Raises: 744 AxisOrderError: If axis names don't appear in the same order in axes 745 and the labeled tensor. 
746 """ 747 with ops.name_scope(name, 'lt_expand_dims', [labeled_tensor]) as scope: 748 labeled_tensor = convert_to_labeled_tensor(labeled_tensor) 749 750 axis_names = [a if isinstance(a, string_types) else a[0] for a in axes] 751 check_axis_order(labeled_tensor, axis_names) 752 753 reshaped_axes = [] 754 shape = [] 755 for axis_spec in axes: 756 if axis_spec in labeled_tensor.axes: 757 axis = labeled_tensor.axes[axis_spec] 758 reshaped_axes.append(axis) 759 shape.append(-1 if axis.size is None else axis.size) 760 else: 761 if isinstance(axis_spec, string_types): 762 reshaped_axes.append((axis_spec, 1)) 763 else: 764 (name, label) = axis_spec 765 reshaped_axes.append((name, (label,))) 766 767 shape.append(1) 768 769 reshaped_tensor = array_ops.reshape( 770 labeled_tensor.tensor, shape, name=scope) 771 772 return LabeledTensor(reshaped_tensor, reshaped_axes) 773 774 775# This should only be added to a graph collection once. 776_AXIS_ORDER_KEY = ('__axis_order',) 777 778 779@tc.returns(tc.Optional(tc.List(string_types))) 780def get_axis_order(): 781 """Get the axis_order set by any containing axis_order_scope. 782 783 Returns: 784 List of strings giving an order to use for axis names, or None, if no axis 785 order is set. 786 """ 787 # By storing axis_order in the graph, we can ensure that axis_order_scope is 788 # thread-safe. 789 axis_order_list = ops.get_collection(_AXIS_ORDER_KEY) 790 if axis_order_list: 791 axis_order, = axis_order_list 792 else: 793 axis_order = None 794 return axis_order 795 796 797@tc.accepts(tc.Optional(tc.List(string_types))) 798def _set_axis_order(axis_order): 799 axis_order_list = ops.get_collection_ref(_AXIS_ORDER_KEY) 800 if axis_order_list: 801 axis_order_list[0] = axis_order 802 else: 803 axis_order_list.append(axis_order) 804 805 806@contextlib.contextmanager 807@tc.accepts(tc.Optional(tc.List(string_types))) 808def axis_order_scope(axis_order=None): 809 """Set axis order for the result of broadcasting operations within a scope. 
810 811 This allows you to ensure that tensors resulting from arithmetic have a 812 predictable axis order. 813 814 Example usage: 815 816 with lt.axis_order_scope(['x', 'y', 'z']): 817 # result is guaranteed to have the correct axis order 818 result = w + b 819 820 You can nest scopes, in which case only the inner-most scope applies, e.g., 821 822 with lt.axis_order(['x', 'y', 'z']): 823 with lt.axis_order(): 824 result = w + b # uses the default (left-most) axis ordering 825 826 Args: 827 axis_order: optional list of strings providing axis names. By default, 828 creates a scope without axis order. 829 830 Yields: 831 The provided axis_order or `None`. 832 """ 833 original_axis_order = get_axis_order() 834 _set_axis_order(axis_order) 835 try: 836 yield axis_order 837 finally: 838 _set_axis_order(original_axis_order) 839 840 841@tc.returns(tc.List(string_types)) 842def _get_valid_axis_order(): 843 axis_order = get_axis_order() 844 if axis_order is None: 845 raise AxisOrderError('an explicit axis order must be provided with the ' 846 'axis_order argument or by using an axis_order_scope') 847 return axis_order 848 849 850class AxisOrderError(ValueError): 851 """Error class for cases where there is no valid axis order.""" 852 853 854# TODO(shoyer): should this function accept a list of labeled tensors instead? 855@tc.returns(type(None)) 856@tc.accepts(LabeledTensorLike, tc.Optional(tc.Collection(string_types))) 857def check_axis_order(labeled_tensor, axis_order=None): 858 """Verify that the given tensor has a consistent axis order. 859 860 Args: 861 labeled_tensor: The input tensor. All axes on this tensor must appear in 862 axis_order. 863 axis_order: Optional desired axis order, as a list of names. If not 864 provided, defaults to the current axis_order_scope (if set). 865 866 Raises: 867 AxisOrderError: If the axis_order is unavailable, inconsistent or does not 868 include all existing axes. 
869 """ 870 labeled_tensor = convert_to_labeled_tensor(labeled_tensor) 871 872 if axis_order is None: 873 axis_order = _get_valid_axis_order() 874 875 relevant_axis_order = [a for a in axis_order if a in labeled_tensor.axes] 876 877 if len(relevant_axis_order) < len(labeled_tensor.axes): 878 raise AxisOrderError( 879 'not all axis names appear in the required axis order %r: %r' % 880 (axis_order, labeled_tensor)) 881 882 if relevant_axis_order != list(labeled_tensor.axes): 883 raise AxisOrderError( 884 'axes on a labeled tensor do not appear in the same order as the ' 885 'required axis order %r: %r' % (axis_order, labeled_tensor)) 886 887 888@tc.returns(LabeledTensor) 889@tc.accepts(LabeledTensorLike, 890 tc.Optional(tc.Collection(string_types)), tc.Optional(string_types)) 891def impose_axis_order(labeled_tensor, axis_order=None, name=None): 892 """Impose desired axis order on a labeled tensor. 893 894 Args: 895 labeled_tensor: The input tensor. 896 axis_order: Optional desired axis order, as a list of names. If not 897 provided, defaults to the current axis_order_scope (if set). 898 name: Optional op name. 899 900 Returns: 901 Labeled tensor with possibly transposed axes. 902 903 Raises: 904 AxisOrderError: If no axis_order is provided or axis_order does not contain 905 all axes on the input tensor. 906 """ 907 with ops.name_scope(name, 'lt_impose_axis_order', [labeled_tensor]) as scope: 908 labeled_tensor = convert_to_labeled_tensor(labeled_tensor) 909 910 if axis_order is None: 911 axis_order = _get_valid_axis_order() 912 913 relevant_axis_order = [a for a in axis_order if a in labeled_tensor.axes] 914 915 return transpose(labeled_tensor, relevant_axis_order, name=scope) 916 917 918@tc.returns(tc.Optional(list)) 919@tc.accepts(list, list) 920def _find_consistent_ordering(a, b): 921 """Find the left-most consistent ordering between two lists of unique items. 
922 923 A consistent ordering combines all elements in both a and b while keeping all 924 elements in their original order in both inputs. The left-most consistent 925 ordering orders elements from `a` not found in `b` before elements in `b` not 926 found in `a`. 927 928 For example, given ['x', 'z'] and ['y', 'z'], both ['x', 'y', 'z'] and ['y', 929 'x', 'z'] are consistent orderings because each of the inputs appears in 930 each consistent ordering in the same order, and ['x', 'y', 'z'] is the 931 left-most, because 'x' appears only in `a` and 'y' appears only in `b`. In 932 contrast, there is no consistent ordering between ['x', 'y'] and ['y', 'x']. 933 934 Args: 935 a: list with unique elements. 936 b: list with unique elements. 937 938 Returns: 939 List containing all elements in either a or b, or None, if no consistent 940 ordering exists. 941 """ 942 a_set = set(a) 943 b_set = set(b) 944 i = 0 945 j = 0 946 ordering = [] 947 while i < len(a) and j < len(b): 948 if a[i] not in b_set: 949 ordering.append(a[i]) 950 i += 1 951 elif b[j] not in a_set: 952 ordering.append(b[j]) 953 j += 1 954 elif a[i] == b[j]: 955 ordering.append(a[i]) 956 i += 1 957 j += 1 958 else: 959 return None 960 961 ordering.extend(a[i:]) 962 ordering.extend(b[j:]) 963 964 return ordering 965 966 967@tc.returns(LabeledTensor, LabeledTensor, Axes) 968@tc.accepts(LabeledTensorLike, LabeledTensorLike, tc.Optional(string_types)) 969def align(labeled_tensor_0, labeled_tensor_1, name=None): 970 """Align the axes of two tensors so they may be broadcast to each other. 971 972 Axes are ordered by the current axis order scope, if present, or by the left- 973 most consistent ordering. An exception is raised if it is impossible to align 974 the tensors without a transpose (align never copies the input data). 
975 976 Example usage: 977 978 >>> a = lt.LabeledTensor(tf.ones((2, 4)), ['x', 'z']) 979 >>> b = lt.LabeledTensor(tf.ones((3, 4)), ['y', 'z']) 980 >>> a2, b2, axes = lt.align(a, b) 981 >>> a2 982 <LabeledTensor 'lt_align_1/lt_align_1/0:...' shape=(2, 1, 4) dtype=float32 983 axes=[('x', Dimension(2)), 984 ('y', Dimension(1)), 985 ('z', Dimension(4))]> 986 >>> b2 987 <LabeledTensor 'lt_align_1/lt_align_1/1:...' shape=(1, 3, 4) dtype=float32 988 axes=[('x', Dimension(1)), 989 ('y', Dimension(3)), 990 ('z', Dimension(4))]> 991 >>> axes 992 Axes([('x', Dimension(2)), 993 ('y', Dimension(3)), 994 ('z', Dimension(4))]) 995 996 Args: 997 labeled_tensor_0: An input tensor. 998 labeled_tensor_1: An input tensor. 999 name: Optional op name. 1000 1001 Returns: 1002 The aligned tensors and the axes the resulting tensor would have if the two 1003 aligned tensors were broadcast to each other. The aligned tensors have the 1004 same rank but not necessarily the same shape, with axes in the same order. 1005 1006 Raises: 1007 ValueError: If axes with the same name on the inputs are not equal. 1008 AxisOrderError: If there is no way to reshape the input tensors into the 1009 output without a transpose. 
1010 """ 1011 with ops.name_scope(name, 'lt_align', 1012 [labeled_tensor_0, labeled_tensor_1]) as scope: 1013 1014 labeled_tensor_0 = convert_to_labeled_tensor(labeled_tensor_0) 1015 labeled_tensor_1 = convert_to_labeled_tensor(labeled_tensor_1) 1016 1017 axes_0 = labeled_tensor_0.axes 1018 axes_1 = labeled_tensor_1.axes 1019 for axis_name in axes_0: 1020 if axis_name in axes_1: 1021 if axes_0[axis_name] != axes_1[axis_name]: 1022 raise ValueError('Mismatched %r axis on input tensors: %r and %r' % 1023 (axis_name, axes_0[axis_name], axes_1[axis_name])) 1024 1025 axis_scope_order = get_axis_order() 1026 if axis_scope_order is not None: 1027 # we are in an axis_order_scope 1028 axis_names_set = set(axes_0) | set(axes_1) 1029 new_axis_names = [a for a in axis_scope_order if a in axis_names_set] 1030 1031 check_axis_order(labeled_tensor_0, axis_scope_order) 1032 check_axis_order(labeled_tensor_1, axis_scope_order) 1033 1034 else: 1035 # attempt to find a consistent ordering 1036 new_axis_names = _find_consistent_ordering(list(axes_0), list(axes_1)) 1037 if new_axis_names is None: 1038 raise AxisOrderError( 1039 'No consistent axis order allows for aligning tensors with axis ' 1040 'orders %r and %r without copying data. Use transpose or ' 1041 'impose_axis_order to reorder axes on one of more of the inputs.' 
% 1042 (axes_0.keys(), axes_1.keys())) 1043 1044 labeled_tensor_0 = expand_dims( 1045 labeled_tensor_0, new_axis_names, name=scope + '0') 1046 labeled_tensor_1 = expand_dims( 1047 labeled_tensor_1, new_axis_names, name=scope + '1') 1048 1049 broadcast_axes = [] 1050 for axis_name in new_axis_names: 1051 if axis_name in axes_0: 1052 broadcast_axes.append(axes_0[axis_name]) 1053 else: 1054 broadcast_axes.append(axes_1[axis_name]) 1055 1056 return labeled_tensor_0, labeled_tensor_1, Axes(broadcast_axes) 1057 1058 1059@tc.returns(types.FunctionType) 1060@tc.accepts(string_types, collections.Callable) 1061def define_unary_op(op_name, elementwise_function): 1062 """Define a unary operation for labeled tensors. 1063 1064 Args: 1065 op_name: string name of the TensorFlow op. 1066 elementwise_function: function to call to evaluate the op on a single 1067 tf.Tensor object. This function must accept two arguments: a tf.Tensor 1068 object, and an optional `name`. 1069 1070 Returns: 1071 Function defining the given op that acts on LabeledTensors. 1072 """ 1073 1074 default_name = 'lt_%s' % op_name 1075 1076 @tc.returns(LabeledTensor) 1077 @tc.accepts(LabeledTensorLike, tc.Optional(string_types)) 1078 def op(labeled_tensor, name=None): 1079 """LabeledTensor version of `tf.{op_name}`. 1080 1081 See `tf.{op_name}` for full details. 1082 1083 Args: 1084 labeled_tensor: Input tensor. 1085 name: Optional op name. 1086 1087 Returns: 1088 A LabeledTensor with result of applying `tf.{op_name}` elementwise. 
1089 """ 1090 with ops.name_scope(name, default_name, [labeled_tensor]) as scope: 1091 labeled_tensor = convert_to_labeled_tensor(labeled_tensor) 1092 result_tensor = elementwise_function(labeled_tensor.tensor, name=scope) 1093 return LabeledTensor(result_tensor, labeled_tensor.axes) 1094 1095 op.__doc__ = op.__doc__.format(op_name=op_name) 1096 op.__name__ = op_name 1097 1098 return op 1099 1100 1101abs_function = define_unary_op('abs', math_ops.abs) 1102neg = define_unary_op('neg', math_ops.negative) 1103sign = define_unary_op('sign', math_ops.sign) 1104reciprocal = define_unary_op('reciprocal', math_ops.reciprocal) 1105square = define_unary_op('square', math_ops.square) 1106round_function = define_unary_op('round', math_ops.round) 1107sqrt = define_unary_op('sqrt', math_ops.sqrt) 1108rsqrt = define_unary_op('rsqrt', math_ops.rsqrt) 1109exp = define_unary_op('exp', math_ops.exp) 1110log = define_unary_op('log', math_ops.log) 1111ceil = define_unary_op('ceil', math_ops.ceil) 1112floor = define_unary_op('floor', math_ops.floor) 1113cos = define_unary_op('cos', math_ops.cos) 1114sin = define_unary_op('sin', math_ops.sin) 1115tan = define_unary_op('tan', math_ops.tan) 1116acos = define_unary_op('acos', math_ops.acos) 1117asin = define_unary_op('asin', math_ops.asin) 1118atan = define_unary_op('atan', math_ops.atan) 1119lgamma = define_unary_op('lgamma', math_ops.lgamma) 1120digamma = define_unary_op('digamma', math_ops.digamma) 1121erf = define_unary_op('erf', math_ops.erf) 1122erfc = define_unary_op('erfc', math_ops.erfc) 1123logical_not = define_unary_op('logical_not', math_ops.logical_not) 1124tanh = define_unary_op('tanh', math_ops.tanh) 1125sigmoid = define_unary_op('sigmoid', math_ops.sigmoid) 1126 1127 1128@tc.returns(types.FunctionType) 1129@tc.accepts(string_types, collections.Callable) 1130def define_binary_op(op_name, elementwise_function): 1131 """Define a binary operation that broadcasts labeled tensors. 
1132 1133 Args: 1134 op_name: string name of the TensorFlow op. 1135 elementwise_function: function to call to evaluate the op on tf.Tensor 1136 objects. This function must accept three arguments: two tf.Tensor objects, 1137 and an optional `name`. 1138 1139 Returns: 1140 Function defining the given op that acts on LabeledTensors. 1141 """ 1142 1143 default_name = 'lt_%s' % op_name 1144 1145 @tc.returns(LabeledTensor) 1146 @tc.accepts(LabeledTensorLike, LabeledTensorLike, tc.Optional(string_types)) 1147 def op(labeled_tensor_0, labeled_tensor_1, name=None): 1148 """LabeledTensor version of `tf.{op_name}` with label based alignment. 1149 1150 See `tf.{op_name}` for full details. 1151 1152 Args: 1153 labeled_tensor_0: Input tensor. 1154 labeled_tensor_1: Input tensor. 1155 name: Optional op name. 1156 1157 Returns: 1158 A LabeledTensor with result of applying `tf.{op_name}` elementwise. 1159 """ 1160 with ops.name_scope(name, default_name, 1161 [labeled_tensor_0, labeled_tensor_1]) as scope: 1162 1163 align_0, align_1, broadcast_axes = align(labeled_tensor_0, 1164 labeled_tensor_1) 1165 1166 tensor = elementwise_function(align_0.tensor, align_1.tensor, name=scope) 1167 1168 return LabeledTensor(tensor, broadcast_axes) 1169 1170 op.__doc__ = op.__doc__.format(op_name=op_name) 1171 op.__name__ = op_name 1172 1173 return op 1174 1175 1176add = define_binary_op('add', math_ops.add) 1177sub = define_binary_op('sub', math_ops.subtract) 1178mul = define_binary_op('mul', math_ops.multiply) 1179div = define_binary_op('div', math_ops.div) 1180mod = define_binary_op('mod', math_ops.mod) 1181pow_function = define_binary_op('pow', math_ops.pow) 1182 1183equal = define_binary_op('equal', math_ops.equal) 1184greater = define_binary_op('greater', math_ops.greater) 1185greater_equal = define_binary_op('greater_equal', math_ops.greater_equal) 1186not_equal = define_binary_op('not_equal', math_ops.not_equal) 1187less = define_binary_op('less', math_ops.less) 1188less_equal = 
define_binary_op('less_equal', math_ops.less_equal) 1189logical_and = define_binary_op('logical_and', math_ops.logical_and) 1190logical_or = define_binary_op('logical_or', math_ops.logical_or) 1191logical_xor = define_binary_op('logical_xor', math_ops.logical_xor) 1192 1193maximum = define_binary_op('maximum', math_ops.maximum) 1194minimum = define_binary_op('minimum', math_ops.minimum) 1195squared_difference = define_binary_op('squared_difference', 1196 math_ops.squared_difference) 1197igamma = define_binary_op('igamma', math_ops.igamma) 1198igammac = define_binary_op('igammac', math_ops.igammac) 1199zeta = define_binary_op('zeta', math_ops.zeta) 1200polygamma = define_binary_op('polygamma', math_ops.polygamma) 1201