1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Core classes and core ops for LabeledTensor.
16
17Core ops are ops which will eventually be called by LabeledTensor methods,
18and ops which a core op depends upon.
19For example, `add` is a core op because we'll eventually support the `+`
20operator.
21Non-core ops should go in `ops.py`.
22"""
23from __future__ import absolute_import
24from __future__ import division
25from __future__ import print_function
26
import collections
import contextlib
import numbers
import types

import numpy as np
from six import binary_type
from six import string_types
from six import text_type
from six.moves import range  # pylint: disable=redefined-builtin

from tensorflow.contrib.labeled_tensor.python.ops import _typecheck as tc
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops

try:
  # Python 3.3+: the container ABCs live in collections.abc. The aliases on the
  # collections module itself were removed in Python 3.10.
  import collections.abc as collections_abc  # pylint: disable=g-import-not-at-top
except ImportError:
  # Python 2: the ABCs are only available directly from collections.
  collections_abc = collections
44
# pylint: disable=invalid-name

# Anything that can be coerced into Axis.labels. A plain sequence check is not
# used because it would also admit strings, which must not be treated as
# labels.
LabelsLike = tc.Union(
    np.ndarray,
    range,
    list,
    tuple,
)

# Anything that can be coerced into a tf.Dimension; None stands for an axis
# whose size is unknown.
DimensionLike = tc.Optional(
    tc.Union(tensor_shape.Dimension, int))

# Anything accepted as the value of an axis: tick labels or a size.
AxisValue = tc.Union(
    LabelsLike,
    DimensionLike,
)

# Python scalar types accepted as TensorFlow values.
Scalar = tc.Union(
    numbers.Number,
    bool,
    binary_type,
    text_type,
)

# pylint: enable=invalid-name
61
62
class Axis(object):
  """Size and optional tick-label information for one named axis.

  An Axis holds either a tf.Dimension describing the length of the axis, or a
  tuple of tick labels (the size is then the number of labels). Tick labels,
  when present, must be unique.
  """

  @tc.accepts(object, string_types, AxisValue)
  def __init__(self, name, value):
    """Construct an Axis.

    Args:
      name: Name of the axis.
      value: None, an int or a tf.Dimension giving the size of the axis, or a
        non-string sequence whose elements become the coordinate (tick)
        labels.

    Raises:
      ValueError: If the provided labels contain duplicates.
    """
    labels = None
    if isinstance(value, tensor_shape.Dimension):
      dimension = value
    elif value is None or isinstance(value, int):
      dimension = tensor_shape.Dimension(value)
    else:
      dimension = tensor_shape.Dimension(len(value))
      labels = tuple(value)

    # A zero-length axis is treated as carrying an empty tuple of labels.
    if dimension.value == 0:
      labels = ()

    index = None
    if labels is not None:
      # Maps each tick label to its integer position along the axis.
      index = {label: position for position, label in enumerate(labels)}
      # Fewer dict entries than labels means some label was repeated.
      if len(index) != len(labels):
        raise ValueError('Tick labels must be unique, but got {}'
                         .format(labels))

    self._name = name  # type: string_types
    self._dimension = dimension  # type: tensor_shape.Dimension
    self._labels = labels  # type: Optional[tuple]
    self._index = index  # type: Optional[Dict[Any, int]]

  @property
  @tc.returns(string_types)
  def name(self):
    return self._name

  @tc.returns(string_types)
  def __repr__(self):
    # Example: Axis('x', Dimension(2))
    # TODO(shoyer): make very long reprs more succinct?
    return "{}('{}', {!r})".format(type(self).__name__, self.name, self.value)

  @tc.returns(bool)
  def __eq__(self, other):
    if not isinstance(other, Axis):
      return False
    return (self.name == other.name and self.size == other.size and
            self.labels == other.labels)

  def __hash__(self):
    return hash((self.name, self.size, self.labels))

  @tc.returns(bool)
  def __ne__(self, other):
    return not (self == other)

  @tc.returns(int)
  def __len__(self):
    size = self.size
    if size is not None:
      return size
    raise ValueError('axis %r has unknown length' % self.name)

  @property
  @tc.returns(tc.Optional(tensor_shape.Dimension))
  def dimension(self):
    return self._dimension

  @property
  @tc.returns(tc.Optional(int))
  def size(self):
    return self._dimension.value

  @property
  @tc.returns(tc.Union(tuple, tensor_shape.Dimension))
  def value(self):
    """Returns the tick-label tuple if there is one, else the tf.Dimension."""
    labels = self.labels
    return self.dimension if labels is None else labels

  @property
  @tc.returns(tc.Optional(tuple))
  def labels(self):
    """Returns the tuple of coordinate labels, or None if unlabeled."""
    return self._labels

  def index(self, value):
    """Returns the integer position of the given tick label.

    Raises:
      ValueError: If this axis has no tick labels.
      KeyError: If `value` is not one of the tick labels.
    """
    lookup = self._index
    if lookup is None:
      raise ValueError('Axis does not have tick labels')
    return lookup[value]
172
173
# Type-check class for anything as_axis() can turn into an Axis: an existing
# Axis, or a (name, value) pair.
# pylint: disable=invalid-name
AxisLike = tc.Union(
    Axis,
    tc.Tuple(string_types, AxisValue),
)
# pylint: enable=invalid-name
178
179
@tc.returns(Axis)
@tc.accepts(AxisLike)
def as_axis(axis_data):
  """Coerce an AxisLike value into an Axis.

  Args:
    axis_data: Either an Axis, which is returned unchanged, or an
      (axis_name, axis_value) tuple used as the Axis constructor arguments.

  Returns:
    An Axis object (the very same object when `axis_data` is already an Axis).
  """
  return axis_data if isinstance(axis_data, Axis) else Axis(*axis_data)
196
197
class Axes(collections_abc.Mapping):
  """Axis names and indices for a tensor.

  It is an ordered mapping, with keys given by axis name and values given
  by Axis objects. Duplicate axis names are not allowed.

  Iteration follows the order in which the axes were supplied. Equality is
  inherited from the Mapping ABC, which compares the (name, Axis) items
  without regard to order, so hashing ignores order as well.
  """

  @tc.accepts(object, tc.List(AxisLike))
  def __init__(self, axes):
    """Construct an Axes.

    Args:
      axes: A list of Axis objects or (axis_name, axis_value) tuples.

    Raises:
      ValueError: If the user provides duplicate axis names.
    """
    self._axes = collections.OrderedDict()

    for axis_data in axes:
      axis = as_axis(axis_data)

      name = axis.name
      if name in self._axes:
        raise ValueError('Duplicate axis name: %s' % name)

      self._axes[name] = axis

  def __iter__(self):
    return iter(self._axes)

  @tc.returns(string_types)
  def __repr__(self):
    # Axes([('x', Dimension(2)),
    #       ('y', ('a', 'b', 'c')),
    #       ('z', Dimension(4))])
    cls_name = type(self).__name__
    values = ["('%s', %r)" % (v.name, v.value) for v in self._axes.values()]
    values_repr = (',\n' + ' ' * len(cls_name + '([')).join(values)
    return '%s([%s])' % (cls_name, values_repr)

  @tc.returns(Axis)
  @tc.accepts(object, string_types)
  def __getitem__(self, name):
    return self._axes[name]

  @tc.returns(bool)
  def __contains__(self, name):
    return name in self._axes

  @tc.returns(int)
  def __len__(self):
    return len(self._axes)

  def __hash__(self):
    # Objects that compare equal must hash equally. Mapping.__eq__ ignores
    # item order, so hash the items as an unordered frozenset; hashing an
    # ordered tuple gave Axes([x, y]) == Axes([y, x]) different hashes.
    return hash(frozenset(self.items()))

  @tc.accepts(object, string_types)
  def remove(self, axis_name):
    """Creates a new Axes object without the given axis.

    Args:
      axis_name: Name of the axis to drop.

    Returns:
      A new Axes with the remaining axes in their original order.

    Raises:
      KeyError: If `axis_name` is not one of these axes.
    """
    if axis_name not in self:
      raise KeyError(axis_name)
    remaining_axes = [axis for axis in self.values() if axis.name != axis_name]
    return Axes(remaining_axes)
262
263
class LabeledTensor(object):
  """A tensor with annotated axes.

  It has the following invariants:
    1) The dimensionality of the tensor is equal to the number of elements
    in axes.
    2) The number of coordinate values in the ith dimension is equal to the
    size of the tensor in the ith dimension.

  Attributes:
    tensor: tf.Tensor containing the data.
    axes: lt.Axes containing axis names and coordinate labels.
  """

  @tc.accepts(object, ops.Tensor,
              tc.Union(Axes, tc.Collection(tc.Union(string_types, AxisLike))))
  def __init__(self, tensor, axes):
    """Construct a LabeledTensor.

    Args:
      tensor: The underlying tensor containing the data.
      axes: An Axes object, or a collection of strings, Axis objects or tuples
        of (name, value) pairs indicating the axes. A bare string names an
        unlabeled axis whose size is taken from the tensor's static shape.

    Raises:
      ValueError: If the provided axes do not satisfy the class invariants.
    """
    self._tensor = tensor
    shape = tensor.get_shape()

    if isinstance(axes, Axes):
      unvalidated_axes = axes
    else:
      # Static rank of the tensor, or None when the rank is unknown.
      rank = shape.ndims
      mutable_axes = []

      for position, axis_like in enumerate(axes):
        if isinstance(axis_like, string_types):
          # The coordinates for this axis are unlabeled, so infer its size
          # from the tensor's static shape. If more axes than dimensions were
          # given, shape[position] would raise IndexError; use an unknown
          # size instead so the rank check below raises the documented
          # ValueError.
          if rank is not None and position >= rank:
            value = None
          else:
            value = shape[position]
          axis_like = (axis_like, value)

        mutable_axes.append(axis_like)

      # Constructing the Axes object additionally validates its contents
      # (for example, it rejects duplicate axis names).
      unvalidated_axes = Axes(mutable_axes)

    # Check our invariants.

    # First, the rank of the tensor must be equal to the number of axes.
    if len(shape) != len(unvalidated_axes):
      raise ValueError('Tensor rank was not equal to the number of axes: %r, %r'
                       % (shape, unvalidated_axes))

    # Second, the size of each tensor dimension must match the size of the
    # corresponding axis.
    for (d, axis) in zip(shape, unvalidated_axes.values()):
      if d != axis.size:
        raise ValueError(
            'Provided axis size %s does not match tensor dimension size %s '
            'in tensor %r' % (axis.size, d, tensor))

    self._axes = unvalidated_axes

  def __repr__(self):
    # <LabeledTensor 'foo' shape=(2, 3, 4) dtype=float32
    #  axes=[('x', Dimension(2)),
    #        ('y', ('a', 'b', 'c')),
    #        ('z', Dimension(4))]>
    axes = ["('%s', %r)" % (v.name, v.value) for v in self.axes.values()]
    axes_repr = (',\n' + ' ' * len(' axes=[')).join(axes)
    return ("<%s '%s' shape=%s dtype=%s\n axes=[%s]>" %
            (type(self).__name__, self.tensor.name, self.tensor.get_shape(),
             self.tensor.dtype.name, axes_repr))

  @property
  def tensor(self):
    return self._tensor

  def _as_graph_element(self):
    """Support tf.Graph.as_graph_element on LabeledTensor objects.

    This allows operations such as tf.name_scope to take labeled tensors.

    Returns:
      self.tensor
    """
    return self.tensor

  @property
  def axes(self):
    return self._axes

  # properties/methods directly borrowed from tf.Tensor:

  @property
  def dtype(self):
    return self._tensor.dtype

  @property
  def shape(self):
    return self._tensor.shape

  @property
  def name(self):
    return self._tensor.name

  def get_shape(self):
    """Returns the TensorShape that represents the shape of this tensor.

    See tf.Tensor.get_shape().

    Returns:
      A TensorShape representing the shape of this tensor.
    """
    return self._tensor.get_shape()

  # TODO(shoyer): consider how/if to implement .eval(). Maybe it should return
  # an xarray.DataArray?

  def __getitem__(self, key):
    # This should work exactly like tf.Tensor.__getitem__, except it preserves
    # labels. One indexer is required per axis.
    if not isinstance(key, tuple):
      key = (key,)
    if len(key) != len(self.axes):
      raise ValueError('indexer %r must have the same length as the Tensor '
                       'rank (%r)' % (key, len(self.axes)))
    selection = {a: k for a, k in zip(self.axes.keys(), key)}
    return slice_function(self, selection)

  # special methods for overloading arithmetic operations:

  def __abs__(self):
    return abs_function(self)

  def __neg__(self):
    return neg(self)

  def __pos__(self):
    return self

  def __add__(self, other):
    return add(self, other)

  def __radd__(self, other):
    return add(other, self)

  def __sub__(self, other):
    return sub(self, other)

  def __rsub__(self, other):
    return sub(other, self)

  def __mul__(self, other):
    return mul(self, other)

  def __rmul__(self, other):
    return mul(other, self)

  def __truediv__(self, other):
    return div(self, other)

  __div__ = __truediv__

  def __rtruediv__(self, other):
    return div(other, self)

  __rdiv__ = __rtruediv__

  def __mod__(self, other):
    return mod(self, other)

  def __rmod__(self, other):
    return mod(other, self)

  def __pow__(self, other):
    return pow_function(self, other)

  def __rpow__(self, other):
    return pow_function(other, self)

  # logical operations:

  def __invert__(self):
    return logical_not(self)

  def __and__(self, other):
    return logical_and(self, other)

  def __or__(self, other):
    return logical_or(self, other)

  def __xor__(self, other):
    return logical_xor(self, other)

  # boolean operations:

  def __lt__(self, other):
    return less(self, other)

  def __le__(self, other):
    return less_equal(self, other)

  def __gt__(self, other):
    return greater(self, other)

  def __ge__(self, other):
    return greater_equal(self, other)

  def __eq__(self, other):
    # for consistency with tf.Tensor
    if not isinstance(other, LabeledTensor):
      return False

    return self.tensor == other.tensor and self.axes == other.axes

  def __ne__(self, other):
    return not self == other

  def __hash__(self):
    return hash((self.tensor, self.axes))
487
488
# typecheck type abbreviations: short display names used by the type checker,
# registered in this order.
for _abbreviated_type, _abbreviation in [
    # third-party types with very long reprs
    (tensor_shape.Dimension, 'tensorflow.Dimension'),
    (ops.Tensor, 'tensorflow.Tensor'),
    (dtypes.DType, 'tensorflow.DType'),
    # core LabeledTensor types
    (Axis, 'labeled_tensor.Axis'),
    (Axes, 'labeled_tensor.Axes'),
    (LabeledTensor, 'labeled_tensor.LabeledTensor'),
]:
  tc.register_type_abbreviation(_abbreviated_type, _abbreviation)
# Keep the loop variables out of the module namespace.
del _abbreviated_type, _abbreviation
498
499
@tc.returns(ops.Tensor)
@tc.accepts(LabeledTensor)
def _convert_labeled_tensor_to_tensor(value, *args, **kwargs):
  """Tensor conversion function registered for LabeledTensor.

  Unwraps the LabeledTensor and passes its underlying tensor to
  ops.internal_convert_to_tensor, forwarding any extra conversion arguments
  (e.g. dtype/name) unchanged so they are handled there.
  """
  underlying_tensor = value.tensor
  return ops.internal_convert_to_tensor(underlying_tensor, *args, **kwargs)


# Register the hook above so TensorFlow's tensor conversion can unwrap
# LabeledTensor values.
ops.register_tensor_conversion_function(
    LabeledTensor, _convert_labeled_tensor_to_tensor)
509
# Type-check class for the inputs convert_to_labeled_tensor() accepts:
# labeled tensors, plain tensors, numpy arrays and Python scalars.
# pylint: disable=invalid-name
LabeledTensorLike = tc.Union(
    LabeledTensor,
    ops.Tensor,
    np.ndarray,
    Scalar,
)
# pylint: enable=invalid-name
514
515
@tc.returns(LabeledTensor)
@tc.accepts(LabeledTensorLike, object, tc.Optional(string_types))
def convert_to_labeled_tensor(value, dtype=None, name=None):
  """Converts the given `value` to a `LabeledTensor`.

  Accepts `LabeledTensor` objects (their axes are kept), plus 0-dimensional
  `Tensor` objects, numpy arrays and Python scalars. Unlabeled inputs of
  higher rank must be wrapped with the `LabeledTensor` constructor explicitly.

  Args:
    value: Object to convert.
    dtype: Optional element type for the returned tensor. If missing, the type
      is inferred from the type of value.
    name: Optional name to use if a new Tensor is created.

  Returns:
    `value` converted into a `LabeledTensor` object.

  Raises:
    ValueError: If the output would have rank>0 but the input was not already a
      `LabeledTensor`.
  """
  # TODO(shoyer): consider extending to accept xarray.DataArray as input.
  if isinstance(value, LabeledTensor):
    known_axes = value.axes.values()
    raw_value = value.tensor
  else:
    known_axes = []
    raw_value = value

  # convert_to_tensor also runs for LabeledTensor input so that a requested
  # dtype is checked against the existing tensor.
  tensor = ops.convert_to_tensor(raw_value, dtype=dtype, name=name)
  if len(tensor.get_shape()) != len(known_axes):
    raise ValueError('cannot automatically convert unlabeled arrays or tensors '
                     'with rank>0 into LabeledTensors: %r' % raw_value)
  return LabeledTensor(tensor, known_axes)
552
553
@tc.returns(Axis)
@tc.accepts(tc.Collection(Axis))
def concat_axes(axes):
  """Concatenate a collection of Axis objects that share one name.

  Args:
    axes: A non-empty collection of Axis objects, all with the same name.

  Returns:
    A single Axis with that name. If every input axis has labels, the
    result's labels are the inputs' labels concatenated in order. Otherwise
    the result is unlabeled: its size is unknown if any input size is
    unknown, else the sum of the input sizes.

  Raises:
    ValueError: If `axes` is empty, contains something other than an Axis, or
      the axes do not all have the same name. The Axis constructor also raises
      ValueError if the concatenated labels contain duplicates.
  """
  if not axes:
    raise ValueError('axes must not be empty')
  for candidate in axes:
    if not isinstance(candidate, Axis):
      raise ValueError('Expected an Axis, but got %r of type %r' %
                       (candidate, type(candidate)))

  distinct_names = set(ax.name for ax in axes)
  if len(distinct_names) > 1:
    raise ValueError('axes do not all have the same name: %r' % distinct_names)
  (shared_name,) = distinct_names

  if all(ax.labels is not None for ax in axes):
    return Axis(shared_name,
                tuple(label for ax in axes for label in ax.labels))
  if any(ax.size is None for ax in axes):
    return Axis(shared_name, None)
  return Axis(shared_name, sum(len(ax) for ax in axes))
592
593
@tc.returns(LabeledTensor)
@tc.accepts(LabeledTensorLike, tc.Optional(string_types))
def identity(labeled_tensor, name=None):
  """The identity op, preserving axes.

  See tf.identity.

  Args:
    labeled_tensor: The input tensor.
    name: Optional op name.

  Returns:
    A new LabeledTensor wrapping an identity of the input tensor, with the
    same axes.
  """
  with ops.name_scope(name, 'lt_identity', [labeled_tensor]) as scope:
    converted = convert_to_labeled_tensor(labeled_tensor)
    copied = array_ops.identity(converted.tensor, name=scope)
    return LabeledTensor(copied, converted.axes)
614
615
# Named slice_function so as not to shadow the `slice` built-in; it is
# aliased to lt.slice in __init__.py.
@tc.returns(LabeledTensor)
@tc.accepts(LabeledTensorLike,
            tc.Mapping(string_types, tc.Union(int, slice)),
            tc.Optional(string_types))
def slice_function(labeled_tensor, selection, name=None):
  """Slice out a subset of the tensor.

  This is an analog of tf.slice.
  For example:
  >>> tensor = tf.reshape(tf.range(0, 6), [3, 2])
  >>> labeled_tensor = lt.LabeledTensor(tensor, ['a', ('b', ['foo', 'bar'])])
  >>> lt.slice(labeled_tensor, {'a': slice(0, 2), 'b': 1})
  <LabeledTensor 'lt_slice:...' shape=(2,) dtype=int32
   axes=[('a', Dimension(2))]>

  Args:
    labeled_tensor: The input tensor.
    selection: A dictionary of type str -> Union(int, slice of int) mapping
      axis names to sub-selections. Axes not mentioned are kept whole; an int
      selection removes that axis from the result.
    name: Optional op name.

  Returns:
    The slice as a `LabeledTensor`.
  """
  with ops.name_scope(name, 'lt_slice', [labeled_tensor]) as scope:
    labeled_tensor = convert_to_labeled_tensor(labeled_tensor)
    axes = labeled_tensor.axes

    # One indexer per axis, in axis order; unselected axes use a full slice.
    indexers = [
        selection[axis_name] if axis_name in selection else slice(None)
        for axis_name in axes
    ]

    sliced_tensor = labeled_tensor.tensor[tuple(indexers)]

    kept_axes = []
    for axis, indexer in zip(axes.values(), indexers):
      if not isinstance(indexer, slice):
        # An integer indexer drops this dimension, so no axis is kept for it.
        assert isinstance(indexer, int)
        continue
      if axis.labels is None:
        # Unlabeled: pass only the name and let LabeledTensor infer the size
        # from the sliced tensor's static shape.
        kept_axes.append(axis.name)
      else:
        kept_axes.append((axis.name, axis.labels[indexer]))

    return LabeledTensor(
        array_ops.identity(sliced_tensor, name=scope), kept_axes)
674
675
@tc.returns(LabeledTensor)
@tc.accepts(LabeledTensorLike,
            tc.Optional(tc.Collection(string_types)), tc.Optional(string_types))
def transpose(labeled_tensor, axis_order=None, name=None):
  """Permute a tensor's axes.

  See tf.transpose.

  Args:
    labeled_tensor: The input tensor.
    axis_order: Optional desired axis order, as a list of names. By default, the
      order of axes is reversed.
    name: Optional op name.

  Returns:
    The permuted tensor.

  Raises:
    ValueError: If axis_order isn't a permutation of the existing axes.
  """
  with ops.name_scope(name, 'lt_transpose', [labeled_tensor]) as scope:
    labeled_tensor = convert_to_labeled_tensor(labeled_tensor)

    original_order = list(labeled_tensor.axes.keys())
    if axis_order is None:
      axis_order = original_order[::-1]
    elif sorted(axis_order) != sorted(original_order):
      raise ValueError(
          'The new axis order must have the same names as the original axes, '
          'but the new order is %r while the original order is %r' %
          (axis_order, original_order))

    # Current dimension index of each axis name (names are unique in Axes).
    position_of = {axis_name: i for i, axis_name in enumerate(original_order)}
    permutation = [position_of[axis_name] for axis_name in axis_order]

    # Note: TensorFlow doesn't copy data for the identity transpose.
    transposed = array_ops.transpose(
        labeled_tensor.tensor, permutation, name=scope)

    return LabeledTensor(
        transposed, [labeled_tensor.axes[n] for n in axis_order])
718
719
@tc.returns(LabeledTensor)
@tc.accepts(
    LabeledTensorLike,
    tc.Collection(
        tc.Union(string_types,
                 tc.Tuple(string_types, collections_abc.Hashable))),
    tc.Optional(string_types))
def expand_dims(labeled_tensor, axes, name=None):
  """Insert dimensions of size 1.

  See tf.expand_dims.

  Args:
    labeled_tensor: The input tensor.
    axes: The desired axis names as strings or tuples of (name, label),
      where `label` is the coordinate name for the new dimension `name`.
      These must include the existing axis names, and the existing names must
      appear in the same order in this list as they do in the input tensor.
    name: Optional op name.

  Returns:
    A tensor with an axis for each axis in axes.
    New axes are created with size 1; a new axis given as (name, label) gets
    the single coordinate label `label`, otherwise it is unlabeled.

  Raises:
    AxisOrderError: If axis names don't appear in the same order in axes
      and the labeled tensor.
  """
  with ops.name_scope(name, 'lt_expand_dims', [labeled_tensor]) as scope:
    labeled_tensor = convert_to_labeled_tensor(labeled_tensor)

    axis_names = [a if isinstance(a, string_types) else a[0] for a in axes]
    check_axis_order(labeled_tensor, axis_names)

    reshaped_axes = []
    shape = []
    for axis_spec in axes:
      if axis_spec in labeled_tensor.axes:
        # Existing axis: keep it, using -1 for an unknown static size.
        axis = labeled_tensor.axes[axis_spec]
        reshaped_axes.append(axis)
        shape.append(-1 if axis.size is None else axis.size)
      else:
        if isinstance(axis_spec, string_types):
          reshaped_axes.append((axis_spec, 1))
        else:
          # Distinct local name so the `name` argument is not shadowed.
          (new_axis_name, label) = axis_spec
          reshaped_axes.append((new_axis_name, (label,)))

        shape.append(1)

    reshaped_tensor = array_ops.reshape(
        labeled_tensor.tensor, shape, name=scope)

    return LabeledTensor(reshaped_tensor, reshaped_axes)
773
774
# Graph collection key holding the active axis order; the collection is meant
# to contain at most one entry (see _set_axis_order).
_AXIS_ORDER_KEY = ('__axis_order',)
777
778
@tc.returns(tc.Optional(tc.List(string_types)))
def get_axis_order():
  """Get the axis_order set by any containing axis_order_scope.

  Returns:
    List of strings giving an order to use for axis names, or None, if no axis
    order is set.
  """
  # The order is kept in a graph collection (rather than a module-level
  # global); the original design relies on this for axis_order_scope to be
  # thread-safe.
  stored_orders = ops.get_collection(_AXIS_ORDER_KEY)
  if not stored_orders:
    return None
  (axis_order,) = stored_orders
  return axis_order
795
796
@tc.accepts(tc.Optional(tc.List(string_types)))
def _set_axis_order(axis_order):
  """Store `axis_order` as the single entry of the axis-order collection."""
  stored_orders = ops.get_collection_ref(_AXIS_ORDER_KEY)
  if not stored_orders:
    stored_orders.append(axis_order)
  else:
    stored_orders[0] = axis_order
804
805
@contextlib.contextmanager
@tc.accepts(tc.Optional(tc.List(string_types)))
def axis_order_scope(axis_order=None):
  """Set axis order for the result of broadcasting operations within a scope.

  This allows you to ensure that tensors resulting from arithmetic have a
  predictable axis order.

  Example usage:

    with lt.axis_order_scope(['x', 'y', 'z']):
      # result is guaranteed to have the correct axis order
      result = w + b

  You can nest scopes, in which case only the inner-most scope applies, e.g.,

    with lt.axis_order_scope(['x', 'y', 'z']):
      with lt.axis_order_scope():
        result = w + b  # uses the default (left-most) axis ordering

  The previously active axis order is restored on exit, including when the
  body raises.

  Args:
    axis_order: optional list of strings providing axis names. By default,
      creates a scope without axis order.

  Yields:
    The provided axis_order or `None`.
  """
  previous_order = get_axis_order()
  _set_axis_order(axis_order)
  try:
    yield axis_order
  finally:
    _set_axis_order(previous_order)
839
840
@tc.returns(tc.List(string_types))
def _get_valid_axis_order():
  """Return the scoped axis order, raising AxisOrderError if none is set."""
  scoped_order = get_axis_order()
  if scoped_order is not None:
    return scoped_order
  raise AxisOrderError('an explicit axis order must be provided with the '
                       'axis_order argument or by using an axis_order_scope')
848
849
class AxisOrderError(ValueError):
  """Raised when no valid axis order is available or an order is violated.

  Subclasses ValueError, so callers catching ValueError also catch it.
  """
852
853
# TODO(shoyer): should this function accept a list of labeled tensors instead?
@tc.returns(type(None))
@tc.accepts(LabeledTensorLike, tc.Optional(tc.Collection(string_types)))
def check_axis_order(labeled_tensor, axis_order=None):
  """Verify that the given tensor has a consistent axis order.

  Args:
    labeled_tensor: The input tensor. All axes on this tensor must appear in
      axis_order.
    axis_order: Optional desired axis order, as a list of names. If not
      provided, defaults to the current axis_order_scope (if set).

  Raises:
    AxisOrderError: If the axis_order is unavailable, inconsistent or does not
      include all existing axes.
  """
  labeled_tensor = convert_to_labeled_tensor(labeled_tensor)

  if axis_order is None:
    axis_order = _get_valid_axis_order()

  # Materialize once so membership tests and iteration see the same items.
  order = list(axis_order)
  existing_axes = list(labeled_tensor.axes)

  # Test membership directly: counting matches (the previous approach) let a
  # duplicated name in axis_order hide a missing axis, which then surfaced as
  # a misleading "not in the same order" error.
  if any(axis_name not in order for axis_name in existing_axes):
    raise AxisOrderError(
        'not all axis names appear in the required axis order %r: %r' %
        (axis_order, labeled_tensor))

  relevant_axis_order = [a for a in order if a in labeled_tensor.axes]

  if relevant_axis_order != existing_axes:
    raise AxisOrderError(
        'axes on a labeled tensor do not appear in the same order as the '
        'required axis order %r: %r' % (axis_order, labeled_tensor))
886
887
@tc.returns(LabeledTensor)
@tc.accepts(LabeledTensorLike,
            tc.Optional(tc.Collection(string_types)), tc.Optional(string_types))
def impose_axis_order(labeled_tensor, axis_order=None, name=None):
  """Impose desired axis order on a labeled tensor.

  Args:
    labeled_tensor: The input tensor.
    axis_order: Optional desired axis order, as a list of names. If not
      provided, defaults to the current axis_order_scope (if set). Names not
      present on the tensor are ignored.
    name: Optional op name.

  Returns:
    Labeled tensor with possibly transposed axes.

  Raises:
    AxisOrderError: If no axis_order is provided or axis_order does not contain
      all axes on the input tensor.
  """
  with ops.name_scope(name, 'lt_impose_axis_order', [labeled_tensor]) as scope:
    labeled_tensor = convert_to_labeled_tensor(labeled_tensor)

    if axis_order is None:
      axis_order = _get_valid_axis_order()
    axis_order = list(axis_order)

    # Without this check a missing axis was silently dropped from the order
    # and transpose() raised a plain ValueError instead of the documented
    # AxisOrderError (a ValueError subclass, so existing handlers still work).
    missing_axes = [a for a in labeled_tensor.axes if a not in axis_order]
    if missing_axes:
      raise AxisOrderError(
          'not all axis names appear in the required axis order %r: %r' %
          (axis_order, labeled_tensor))

    relevant_axis_order = [a for a in axis_order if a in labeled_tensor.axes]

    return transpose(labeled_tensor, relevant_axis_order, name=scope)
916
917
@tc.returns(tc.Optional(list))
@tc.accepts(list, list)
def _find_consistent_ordering(a, b):
  """Find the left-most consistent ordering between two lists of unique items.

  A consistent ordering combines all elements in both a and b while keeping all
  elements in their original order in both inputs. The left-most consistent
  ordering orders elements from `a` not found in `b` before elements in `b` not
  found in `a`.

  For example, given ['x', 'z'] and ['y', 'z'], both ['x', 'y', 'z'] and ['y',
  'x', 'z'] are consistent orderings because each of the inputs appears in
  each consistent ordering in the same order, and ['x', 'y', 'z'] is the
  left-most, because 'x' appears only in `a` and 'y' appears only in `b`. In
  contrast, there is no consistent ordering between ['x', 'y'] and ['y', 'x'].

  Args:
    a: list with unique elements.
    b: list with unique elements.

  Returns:
    List containing all elements in either a or b, or None, if no consistent
    ordering exists.
  """
  members_a, members_b = set(a), set(b)
  pos_a = pos_b = 0
  merged = []
  # Two-pointer merge: emit elements exclusive to one list as soon as they
  # come up (preferring `a`), and shared elements only when both heads agree.
  while pos_a < len(a) and pos_b < len(b):
    head_a, head_b = a[pos_a], b[pos_b]
    if head_a not in members_b:
      merged.append(head_a)
      pos_a += 1
    elif head_b not in members_a:
      merged.append(head_b)
      pos_b += 1
    elif head_a == head_b:
      merged.append(head_a)
      pos_a += 1
      pos_b += 1
    else:
      # Both heads are shared but differ: the lists disagree on order.
      return None

  # At most one list still has elements left; append them unchanged.
  merged.extend(a[pos_a:])
  merged.extend(b[pos_b:])
  return merged
965
966
@tc.returns(LabeledTensor, LabeledTensor, Axes)
@tc.accepts(LabeledTensorLike, LabeledTensorLike, tc.Optional(string_types))
def align(labeled_tensor_0, labeled_tensor_1, name=None):
  """Align the axes of two tensors so they may be broadcast to each other.

  Axes are ordered by the current axis order scope, if present, or by the left-
  most consistent ordering. An exception is raised if it is impossible to align
  the tensors without a transpose (align never copies the input data).

  Example usage:

    >>> a = lt.LabeledTensor(tf.ones((2, 4)), ['x', 'z'])
    >>> b = lt.LabeledTensor(tf.ones((3, 4)), ['y', 'z'])
    >>> a2, b2, axes = lt.align(a, b)
    >>> a2
    <LabeledTensor 'lt_align_1/lt_align_1/0:...' shape=(2, 1, 4) dtype=float32
     axes=[('x', Dimension(2)),
           ('y', Dimension(1)),
           ('z', Dimension(4))]>
    >>> b2
    <LabeledTensor 'lt_align_1/lt_align_1/1:...' shape=(1, 3, 4) dtype=float32
     axes=[('x', Dimension(1)),
           ('y', Dimension(3)),
           ('z', Dimension(4))]>
    >>> axes
    Axes([('x', Dimension(2)),
          ('y', Dimension(3)),
          ('z', Dimension(4))])

  Args:
    labeled_tensor_0: An input tensor.
    labeled_tensor_1: An input tensor.
    name: Optional op name.

  Returns:
    The aligned tensors and the axes the resulting tensor would have if the two
    aligned tensors were broadcast to each other. The aligned tensors have the
    same rank but not necessarily the same shape, with axes in the same order.

  Raises:
    ValueError: If axes with the same name on the inputs are not equal.
    AxisOrderError: If there is no way to reshape the input tensors into the
      output without a transpose.
  """
  with ops.name_scope(name, 'lt_align',
                      [labeled_tensor_0, labeled_tensor_1]) as scope:

    labeled_tensor_0 = convert_to_labeled_tensor(labeled_tensor_0)
    labeled_tensor_1 = convert_to_labeled_tensor(labeled_tensor_1)

    axes_0 = labeled_tensor_0.axes
    axes_1 = labeled_tensor_1.axes

    # Axes that share a name must compare equal on both inputs; otherwise
    # matching them up by name would be ambiguous.
    for axis_name in axes_0:
      if axis_name in axes_1 and axes_0[axis_name] != axes_1[axis_name]:
        raise ValueError('Mismatched %r axis on input tensors: %r and %r' %
                         (axis_name, axes_0[axis_name], axes_1[axis_name]))

    axis_scope_order = get_axis_order()
    if axis_scope_order is not None:
      # Inside an axis_order_scope: the scope dictates the output order, and
      # both inputs are checked against that order.
      axis_names_set = set(axes_0) | set(axes_1)
      new_axis_names = [a for a in axis_scope_order if a in axis_names_set]

      check_axis_order(labeled_tensor_0, axis_scope_order)
      check_axis_order(labeled_tensor_1, axis_scope_order)

    else:
      # No scope: merge the two axis orders without reordering either input.
      new_axis_names = _find_consistent_ordering(list(axes_0), list(axes_1))
      if new_axis_names is None:
        # list(...) keeps the message readable on Python 3, where keys()
        # would render as an opaque KeysView.
        raise AxisOrderError(
            'No consistent axis order allows for aligning tensors with axis '
            'orders %r and %r without copying data. Use transpose or '
            'impose_axis_order to reorder axes on one or more of the inputs.' %
            (list(axes_0), list(axes_1)))

    # Insert the missing axes into each input, following new_axis_names.
    labeled_tensor_0 = expand_dims(
        labeled_tensor_0, new_axis_names, name=scope + '0')
    labeled_tensor_1 = expand_dims(
        labeled_tensor_1, new_axis_names, name=scope + '1')

    # Each broadcast axis comes from whichever input originally had it,
    # preferring the first input for shared axes (already checked equal above).
    broadcast_axes = [
        axes_0[axis_name] if axis_name in axes_0 else axes_1[axis_name]
        for axis_name in new_axis_names
    ]

    return labeled_tensor_0, labeled_tensor_1, Axes(broadcast_axes)
1057
1058
@tc.returns(types.FunctionType)
@tc.accepts(string_types, collections.Callable)
def define_unary_op(op_name, elementwise_function):
  """Define a unary operation for labeled tensors.

  The returned op converts its input with `convert_to_labeled_tensor`. It
  applies `elementwise_function` to the underlying tf.Tensor and wraps the
  result in a LabeledTensor with the same axes as the input.

  Args:
    op_name: string name of the TensorFlow op. Used for the default name
      scope (`lt_<op_name>`), the returned function's `__name__` and its
      docstring.
    elementwise_function: function to call to evaluate the op on a single
      tf.Tensor object. This function must accept two arguments: a tf.Tensor
      object, and an optional `name`.

  Returns:
    Function defining the given op that acts on LabeledTensors.
  """

  default_name = 'lt_%s' % op_name

  @tc.returns(LabeledTensor)
  @tc.accepts(LabeledTensorLike, tc.Optional(string_types))
  def op(labeled_tensor, name=None):
    """LabeledTensor version of `tf.{op_name}`.

    See `tf.{op_name}` for full details.

    Args:
      labeled_tensor: Input tensor.
      name: Optional op name.

    Returns:
      A LabeledTensor with result of applying `tf.{op_name}` elementwise.
    """
    with ops.name_scope(name, default_name, [labeled_tensor]) as scope:
      labeled_tensor = convert_to_labeled_tensor(labeled_tensor)
      result_tensor = elementwise_function(labeled_tensor.tensor, name=scope)
      return LabeledTensor(result_tensor, labeled_tensor.axes)

  # Under `python -OO` docstrings are stripped and `__doc__` is None, so only
  # fill in the `{op_name}` template when a docstring is actually present.
  if op.__doc__ is not None:
    op.__doc__ = op.__doc__.format(op_name=op_name)
  op.__name__ = op_name

  return op
1099
1100
# Elementwise unary ops on LabeledTensors. Each one wraps the named math_ops
# function, and its result keeps the input's axes (see define_unary_op).
# The `_function` suffix on `abs_function` / `round_function` presumably
# avoids shadowing the Python builtins `abs` / `round`.
abs_function = define_unary_op('abs', math_ops.abs)
# Exposed as `neg`, but backed by math_ops.negative.
neg = define_unary_op('neg', math_ops.negative)
sign = define_unary_op('sign', math_ops.sign)
reciprocal = define_unary_op('reciprocal', math_ops.reciprocal)
square = define_unary_op('square', math_ops.square)
round_function = define_unary_op('round', math_ops.round)
sqrt = define_unary_op('sqrt', math_ops.sqrt)
rsqrt = define_unary_op('rsqrt', math_ops.rsqrt)
exp = define_unary_op('exp', math_ops.exp)
log = define_unary_op('log', math_ops.log)
ceil = define_unary_op('ceil', math_ops.ceil)
floor = define_unary_op('floor', math_ops.floor)
cos = define_unary_op('cos', math_ops.cos)
sin = define_unary_op('sin', math_ops.sin)
tan = define_unary_op('tan', math_ops.tan)
acos = define_unary_op('acos', math_ops.acos)
asin = define_unary_op('asin', math_ops.asin)
atan = define_unary_op('atan', math_ops.atan)
lgamma = define_unary_op('lgamma', math_ops.lgamma)
digamma = define_unary_op('digamma', math_ops.digamma)
erf = define_unary_op('erf', math_ops.erf)
erfc = define_unary_op('erfc', math_ops.erfc)
logical_not = define_unary_op('logical_not', math_ops.logical_not)
tanh = define_unary_op('tanh', math_ops.tanh)
sigmoid = define_unary_op('sigmoid', math_ops.sigmoid)
1126
1127
@tc.returns(types.FunctionType)
@tc.accepts(string_types, collections.Callable)
def define_binary_op(op_name, elementwise_function):
  """Define a binary operation that broadcasts labeled tensors.

  The returned op first aligns its two inputs by axis name with `align`. It
  then applies `elementwise_function` to the aligned tf.Tensors and labels the
  result with the broadcast axes that `align` reports.

  Args:
    op_name: string name of the TensorFlow op. Used for the default name
      scope (`lt_<op_name>`), the returned function's `__name__` and its
      docstring.
    elementwise_function: function to call to evaluate the op on tf.Tensor
      objects. This function must accept three arguments: two tf.Tensor objects,
      and an optional `name`.

  Returns:
    Function defining the given op that acts on LabeledTensors.
  """

  default_name = 'lt_%s' % op_name

  @tc.returns(LabeledTensor)
  @tc.accepts(LabeledTensorLike, LabeledTensorLike, tc.Optional(string_types))
  def op(labeled_tensor_0, labeled_tensor_1, name=None):
    """LabeledTensor version of `tf.{op_name}` with label based alignment.

    See `tf.{op_name}` for full details.

    Args:
      labeled_tensor_0: Input tensor.
      labeled_tensor_1: Input tensor.
      name: Optional op name.

    Returns:
      A LabeledTensor with result of applying `tf.{op_name}` elementwise.
    """
    with ops.name_scope(name, default_name,
                        [labeled_tensor_0, labeled_tensor_1]) as scope:

      align_0, align_1, broadcast_axes = align(labeled_tensor_0,
                                               labeled_tensor_1)

      tensor = elementwise_function(align_0.tensor, align_1.tensor, name=scope)

      return LabeledTensor(tensor, broadcast_axes)

  # Under `python -OO` docstrings are stripped and `__doc__` is None, so only
  # fill in the `{op_name}` template when a docstring is actually present.
  if op.__doc__ is not None:
    op.__doc__ = op.__doc__.format(op_name=op_name)
  op.__name__ = op_name

  return op
1174
1175
# Elementwise binary ops on LabeledTensors. Each op aligns its two inputs by
# axis name via `align` before calling the wrapped math_ops function, so the
# operands broadcast by label rather than by position (see define_binary_op).

# Arithmetic. `sub` and `mul` wrap math_ops.subtract and math_ops.multiply.
# The `_function` suffix on `pow_function` presumably avoids shadowing the
# Python builtin `pow`.
add = define_binary_op('add', math_ops.add)
sub = define_binary_op('sub', math_ops.subtract)
mul = define_binary_op('mul', math_ops.multiply)
div = define_binary_op('div', math_ops.div)
mod = define_binary_op('mod', math_ops.mod)
pow_function = define_binary_op('pow', math_ops.pow)

# Comparisons and logical connectives.
equal = define_binary_op('equal', math_ops.equal)
greater = define_binary_op('greater', math_ops.greater)
greater_equal = define_binary_op('greater_equal', math_ops.greater_equal)
not_equal = define_binary_op('not_equal', math_ops.not_equal)
less = define_binary_op('less', math_ops.less)
less_equal = define_binary_op('less_equal', math_ops.less_equal)
logical_and = define_binary_op('logical_and', math_ops.logical_and)
logical_or = define_binary_op('logical_or', math_ops.logical_or)
logical_xor = define_binary_op('logical_xor', math_ops.logical_xor)

# Other two-argument elementwise functions.
maximum = define_binary_op('maximum', math_ops.maximum)
minimum = define_binary_op('minimum', math_ops.minimum)
squared_difference = define_binary_op('squared_difference',
                                      math_ops.squared_difference)
igamma = define_binary_op('igamma', math_ops.igamma)
igammac = define_binary_op('igammac', math_ops.igammac)
zeta = define_binary_op('zeta', math_ops.zeta)
polygamma = define_binary_op('polygamma', math_ops.polygamma)
1201