1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Non-core ops for LabeledTensor."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import collections
21import types
22
23import numpy as np
24from six import string_types
25
26from tensorflow.contrib.labeled_tensor.python.ops import _typecheck as tc
27from tensorflow.contrib.labeled_tensor.python.ops import core
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import ops
30from tensorflow.python.ops import array_ops
31from tensorflow.python.ops import functional_ops
32from tensorflow.python.ops import map_fn as map_fn_lib
33from tensorflow.python.ops import math_ops
34from tensorflow.python.ops import numerics
35from tensorflow.python.ops import random_ops
36from tensorflow.python.training import input  # pylint: disable=redefined-builtin
37
38
@tc.returns(core.LabeledTensor)
@tc.accepts(core.LabeledTensor, ops.Tensor, core.Axis,
            tc.Optional(string_types))
def _gather_1d_on_axis(labeled_tensor, indexer, axis, name=None):
  """Gather elements along a single axis with a 1D integer indexer.

  Args:
    labeled_tensor: The input tensor.
    indexer: Integer `Tensor` of indices into the gathered dimension.
    axis: Replacement `Axis` (carrying the selected labels) for the gathered
      dimension.
    name: Optional op name.

  Returns:
    A `LabeledTensor` holding the gathered values, with axes in the same
    order as the input.
  """
  with ops.name_scope(name, 'lt_take', [labeled_tensor]) as scope:
    # Rotate the target axis to the front so tf.gather indexes dimension 0.
    remaining_axes = list(labeled_tensor.axes.remove(axis.name).values())
    front_axes = core.Axes([axis] + remaining_axes)
    rotated = core.transpose(labeled_tensor, front_axes.keys())
    gathered = core.LabeledTensor(
        array_ops.gather(rotated.tensor, indexer), front_axes)
    # Restore the caller's original axis order.
    return core.transpose(gathered, labeled_tensor.axes.keys(), name=scope)
50
51
@tc.returns(core.LabeledTensor)
@tc.accepts(core.LabeledTensorLike,
            tc.Mapping(string_types,
                       tc.Union(slice, collections.Hashable, list)),
            tc.Optional(string_types))
def select(labeled_tensor, selection, name=None):
  """Slice out a subset of the tensor.

  Args:
    labeled_tensor: The input tensor.
    selection: A dictionary mapping an axis name to a scalar, slice or list of
      values to select. Currently supports two types of selections:
        (a) Any number of scalar and/or slice selections.
        (b) Exactly one list selection, without any scalars or slices.
    name: Optional op name.

  Returns:
    The selection as a `LabeledTensor`.

  Raises:
    ValueError: If the tensor doesn't have an axis in the selection or if
      that axis lacks labels.
    KeyError: If any labels in a selection are not found in the original axis.
    NotImplementedError: If you attempt to combine a list selection with
      scalar selection or another list selection.
  """
  with ops.name_scope(name, 'lt_select', [labeled_tensor]) as scope:
    labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)

    slices = {}
    indexers = {}
    for axis_name, sel in selection.items():
      if axis_name not in labeled_tensor.axes:
        raise ValueError(
            'The tensor does not have an axis named %s. Its axes are: %r' %
            (axis_name, labeled_tensor.axes.keys()))
      axis = labeled_tensor.axes[axis_name]
      if axis.labels is None:
        raise ValueError(
            'The axis named %s does not have labels. The axis is: %r' %
            (axis_name, axis))

      if isinstance(sel, slice):
        # TODO(shoyer): consider deprecating using slices in favor of lists
        start = None if sel.start is None else axis.index(sel.start)
        # For now, follow the pandas convention of making labeled slices
        # inclusive of both bounds.
        stop = None if sel.stop is None else axis.index(sel.stop) + 1

        if sel.step is not None:
          raise NotImplementedError('slicing with a step is not yet supported')

        slices[axis_name] = slice(start, stop)

      # Needs to be after checking for slices, since slice objects claim to be
      # instances of collections.Hashable but hash() on them fails.
      elif isinstance(sel, collections.Hashable):
        slices[axis_name] = axis.index(sel)

      elif isinstance(sel, list):
        if indexers:
          raise NotImplementedError(
              'select does not yet support more than one list selection at '
              'the same time')
        index_list = [axis.index(v) for v in sel]
        indexers[axis_name] = ops.convert_to_tensor(
            index_list, dtype=dtypes.int64)

      else:
        # If type checking is working properly, this shouldn't be possible.
        raise TypeError('cannot handle arbitrary types')

    if indexers and slices:
      raise NotImplementedError(
          'select does not yet support combined scalar and list selection')

    # For now, handle array selection separately, because tf.gather_nd does
    # not support gradients yet. Later, using gather_nd will let us combine
    # these paths.
    if indexers:
      (axis_name, indexer), = indexers.items()
      axis = core.Axis(axis_name, selection[axis_name])
      return _gather_1d_on_axis(labeled_tensor, indexer, axis, name=scope)
    else:
      return core.slice_function(labeled_tensor, slices, name=scope)
143
144
@tc.returns(core.LabeledTensor)
@tc.accepts(
    tc.Collection(core.LabeledTensorLike), string_types,
    tc.Optional(string_types))
def concat(labeled_tensors, axis_name, name=None):
  """Concatenate tensors along a dimension.

  See tf.concat.

  Args:
    labeled_tensors: A list of input LabeledTensors.
    axis_name: The name of the axis along which to concatenate.
    name: Optional op name.

  Returns:
    The concatenated tensor.
    The coordinate labels for the concatenation dimension are also concatenated,
    if they are available for every tensor.

  Raises:
    ValueError: If fewer than one tensor inputs is provided, if the tensors
      have incompatible axes, or if `axis_name` isn't the name of an axis.
  """
  with ops.name_scope(name, 'lt_concat', labeled_tensors) as scope:
    labeled_tensors = [
        core.convert_to_labeled_tensor(lt) for lt in labeled_tensors
    ]

    if len(labeled_tensors) < 1:
      raise ValueError('concat expects at least 1 tensor, but received %s' %
                       labeled_tensors)

    # Every input must carry the same axes as the first one, apart from the
    # concatenation axis itself.
    axes_0 = labeled_tensors[0].axes
    axis_names = list(axes_0.keys())

    if axis_name not in axis_names:
      raise ValueError('%s not in %s' % (axis_name, axis_names))

    shared_axes = axes_0.remove(axis_name)

    for labeled_tensor in labeled_tensors[1:]:
      current_shared_axes = labeled_tensor.axes.remove(axis_name)
      if current_shared_axes != shared_axes:
        # TODO(shoyer): add more specific checks about what went wrong,
        # including raising AxisOrderError when appropriate
        raise ValueError('Mismatched shared axes: the first tensor '
                         'had axes %r but this tensor has axes %r.' %
                         (shared_axes, current_shared_axes))

    # Accumulate the axis labels, if they're available.
    concat_axis = core.concat_axes(
        [lt.axes[axis_name] for lt in labeled_tensors])
    tensors = [lt.tensor for lt in labeled_tensors]

    concat_dimension = axis_names.index(axis_name)
    concat_tensor = array_ops.concat(tensors, concat_dimension, name=scope)
    values = list(axes_0.values())
    concat_axes = (values[:concat_dimension] + [concat_axis] +
                   values[concat_dimension + 1:])

    return core.LabeledTensor(concat_tensor, concat_axes)
209
210
211# TODO(shoyer): rename pack/unpack to stack/unstack
212
213
@tc.returns(core.LabeledTensor)
@tc.accepts(
    tc.Collection(core.LabeledTensorLike),
    tc.Union(string_types, core.AxisLike), int, tc.Optional(string_types))
def pack(labeled_tensors, new_axis, axis_position=0, name=None):
  """Pack tensors along a new axis.

  See tf.pack.

  Args:
    labeled_tensors: The input tensors, which must have identical axes.
    new_axis: The name of the new axis, or a tuple containing the name
      and coordinate labels.
    axis_position: Optional integer position at which to insert the new axis.
    name: Optional op name.

  Returns:
    The packed tensors as a single LabeledTensor, with `new_axis` in the given
    `axis_position`.

  Raises:
    ValueError: If fewer than one input tensor is provided, or if the tensors
      don't have identical axes.
  """
  with ops.name_scope(name, 'lt_pack', labeled_tensors) as scope:
    labeled_tensors = [
        core.convert_to_labeled_tensor(lt) for lt in labeled_tensors
    ]

    if len(labeled_tensors) < 1:
      # Message phrasing matches the equivalent check in `concat`.
      raise ValueError('pack expects at least 1 tensor, but received %s' %
                       labeled_tensors)

    # Unlike `concat`, all inputs must agree on every axis (names, order
    # and labels), since an entirely new axis is introduced.
    axes_0 = labeled_tensors[0].axes
    for t in labeled_tensors:
      if t.axes != axes_0:
        raise ValueError('Non-identical axes. Expected %s but got %s' %
                         (axes_0, t.axes))

    pack_op = array_ops.stack(
        [t.tensor for t in labeled_tensors], axis=axis_position, name=scope)
    axes = list(axes_0.values())
    axes.insert(axis_position, new_axis)
    return core.LabeledTensor(pack_op, axes)
258
259
@tc.returns(tc.List(core.LabeledTensor))
@tc.accepts(core.LabeledTensorLike,
            tc.Optional(string_types), tc.Optional(string_types))
def unpack(labeled_tensor, axis_name=None, name=None):
  """Unpack the tensor.

  See tf.unpack.

  Args:
    labeled_tensor: The input tensor.
    axis_name: Optional name of axis to unpack. By default, the first axis is
      used.
    name: Optional op name.

  Returns:
    The list of unpacked LabeledTensors.

  Raises:
    ValueError: If `axis_name` is not an axis on the input.
  """
  with ops.name_scope(name, 'lt_unpack', [labeled_tensor]) as scope:
    labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)

    axis_names = list(labeled_tensor.axes.keys())
    # Default to unpacking along the leading axis.
    if axis_name is None:
      axis_name = axis_names[0]

    if axis_name not in axis_names:
      raise ValueError('%s not in %s' % (axis_name, axis_names))
    unpack_dim = axis_names.index(axis_name)

    unpack_ops = array_ops.unstack(
        labeled_tensor.tensor, axis=unpack_dim, name=scope)
    # Each output keeps every axis of the input except the unpacked one.
    remaining_axes = [
        axis for i, axis in enumerate(labeled_tensor.axes.values())
        if i != unpack_dim
    ]
    return [core.LabeledTensor(t, remaining_axes) for t in unpack_ops]
294
295
@tc.returns(core.LabeledTensor)
@tc.accepts(core.LabeledTensorLike,
            tc.Collection(string_types),
            tc.Collection(tc.Union(string_types, core.AxisLike)),
            tc.Optional(string_types))
def reshape(labeled_tensor, existing_axes, new_axes, name=None):
  """Reshape specific axes of a LabeledTensor.

  Non-indicated axes remain in their original locations.

  Args:
    labeled_tensor: The input tensor.
    existing_axes: List of axis names found on the input tensor. These must
      appear sequentially in the list of axis names on the input. In other
      words, they must be a valid slice of `list(labeled_tensor.axes.keys())`.
    new_axes: List of strings, tuples of (axis_name, axis_value) or Axis objects
      providing new axes with which to replace `existing_axes` in the reshaped
      result. At most one element of `new_axes` may be a string, indicating an
      axis with unknown size.
    name: Optional op name.

  Returns:
    The reshaped LabeledTensor.

  Raises:
    ValueError: If `existing_axes` are not all axes on the input, or if more
     than one of `new_axes` has unknown size.
    AxisOrderError: If `existing_axes` are not a slice of axis names on the
      input.
  """
  with ops.name_scope(name, 'lt_reshape', [labeled_tensor]) as scope:
    labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)

    original_axis_names = list(labeled_tensor.axes.keys())
    existing_axes = list(existing_axes)
    if not set(existing_axes) <= set(original_axis_names):
      raise ValueError('existing_axes %r are not contained in the set of axis '
                       'names %r on the input labeled tensor' %
                       (existing_axes, original_axis_names))

    # [start, stop) is the contiguous span of input dimensions being replaced.
    start = original_axis_names.index(existing_axes[0])
    stop = original_axis_names.index(existing_axes[-1]) + 1

    if existing_axes != original_axis_names[start:stop]:
      # We could support existing_axes that aren't a slice by using transpose,
      # but that could lead to unpredictable performance consequences because
      # transposes are not free in TensorFlow. If we did transpose
      # automatically, the user might never realize that their data is being
      # produced with the wrong order. (The later will occur with some frequency
      # because of how broadcasting automatically choose axis order.)
      # So for now we've taken the strict approach.
      raise core.AxisOrderError(
          'existing_axes %r are not a slice of axis names %r on the input '
          'labeled tensor. Use `transpose` or `impose_axis_order` to reorder '
          'axes on the input explicitly.' %
          (existing_axes, original_axis_names))

    # tf.reshape can infer at most one dimension (via -1), so at most one of
    # the new axes may be a bare string (name only, unknown size).
    if sum(isinstance(axis, string_types) for axis in new_axes) > 1:
      raise ValueError(
          'at most one axis in new_axes can have unknown size. All other '
          'axes must have an indicated integer size or labels: %r' % new_axes)

    original_values = list(labeled_tensor.axes.values())
    # Axes with statically unknown size map to -1 so tf.reshape infers them.
    axis_size = lambda axis: -1 if axis.size is None else axis.size
    # Build the target shape: untouched leading dims, then the replacement
    # axes, then untouched trailing dims.
    shape = [axis_size(axis) for axis in original_values[:start]]
    for axis_ref in new_axes:
      if isinstance(axis_ref, string_types):
        # A bare name means a new axis of unknown size; let tf.reshape infer.
        shape.append(-1)
      else:
        axis = core.as_axis(axis_ref)
        shape.append(axis_size(axis))
    shape.extend(axis_size(axis) for axis in original_values[stop:])

    reshaped_tensor = array_ops.reshape(
        labeled_tensor.tensor, shape, name=scope)
    axes = original_values[:start] + list(new_axes) + original_values[stop:]
    return core.LabeledTensor(reshaped_tensor, axes)
373
374
@tc.returns(core.LabeledTensor)
@tc.accepts(core.LabeledTensorLike, string_types, string_types,
            tc.Optional(string_types))
def rename_axis(labeled_tensor, existing_name, new_name, name=None):
  """Rename an axis of LabeledTensor.

  Args:
    labeled_tensor: The input tensor.
    existing_name: Name for an existing axis on the input.
    new_name: Desired replacement name.
    name: Optional op name.

  Returns:
    LabeledTensor with renamed axis.

  Raises:
    ValueError: If `existing_name` is not an axis on the input.
  """
  with ops.name_scope(name, 'lt_rename_axis', [labeled_tensor]) as scope:
    # Convert LabeledTensorLike inputs first, as every other op in this
    # module does; otherwise accessing `.axes` below would fail for inputs
    # that are merely convertible to a LabeledTensor.
    labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
    if existing_name not in labeled_tensor.axes:
      raise ValueError('existing_name %r is not contained in the set of axis '
                       'names %r on the input labeled tensor' %
                       (existing_name, labeled_tensor.axes.keys()))
    # Delegate to `reshape` with a single axis, which preserves the axis
    # size/labels under the new name.
    new_axis = core.Axis(new_name, labeled_tensor.axes[existing_name].value)
    return reshape(labeled_tensor, [existing_name], [new_axis], name=scope)
400
401
@tc.returns(tc.List(core.LabeledTensor))
@tc.accepts(string_types, collections.Callable, int, bool,
            tc.Collection(core.LabeledTensorLike), bool,
            tc.Optional(string_types))
def _batch_helper(default_name,
                  batch_fn,
                  batch_size,
                  enqueue_many,
                  labeled_tensors,
                  allow_smaller_final_batch,
                  name=None):
  """Shared implementation of `batch` and `shuffle_batch`.

  Applies `batch_fn` to the unwrapped tensors, then re-attaches axes with a
  'batch' axis of size `batch_size` as the first dimension of each output.

  Args:
    default_name: Default name for the op scope.
    batch_fn: Callable taking (list of Tensors, scope name) and returning the
      batched Tensor or list of Tensors (e.g. a tf.train.batch wrapper).
    batch_size: The output batch size.
    enqueue_many: If True, each input must already have a leading 'batch'
      axis, which is replaced in the outputs.
    labeled_tensors: The input tensors.
    allow_smaller_final_batch: If True, the output batch size is statically
      unknown, since the final batch may be smaller than `batch_size`.
    name: Optional op name.

  Returns:
    List of batched LabeledTensors, one per input tensor.

  Raises:
    ValueError: If `enqueue_many` is True and an input's first axis is not
      named 'batch'.
  """
  with ops.name_scope(name, default_name, labeled_tensors) as scope:
    labeled_tensors = [
        core.convert_to_labeled_tensor(lt) for lt in labeled_tensors
    ]

    batch_ops = batch_fn([t.tensor for t in labeled_tensors], scope)
    # TODO(shoyer): Remove this when they sanitize the TF API.
    if not isinstance(batch_ops, list):
      assert isinstance(batch_ops, ops.Tensor)
      batch_ops = [batch_ops]

    if allow_smaller_final_batch:
      # The final batch may be short, so leave the batch size unknown (None).
      batch_size = None

    @tc.returns(core.Axes)
    @tc.accepts(core.Axes)
    def output_axes(axes):
      # Compute output axes for one input: a 'batch' axis of size
      # `batch_size` followed by the input's remaining axes.
      if enqueue_many:
        if 'batch' not in axes or list(axes.keys()).index('batch') != 0:
          raise ValueError(
              'When enqueue_many is True, input tensors must have an axis '
              'called "batch" as their first dimension, '
              'but axes were %s' % axes)
        culled_axes = axes.remove('batch')
        return core.Axes([('batch', batch_size)] + list(culled_axes.values()))
      else:
        return core.Axes([('batch', batch_size)] + list(axes.values()))

    output_labeled_tensors = []
    for i, tensor in enumerate(batch_ops):
      axes = output_axes(labeled_tensors[i].axes)
      output_labeled_tensors.append(core.LabeledTensor(tensor, axes))

    return output_labeled_tensors
447
448
@tc.returns(tc.List(core.LabeledTensor))
@tc.accepts(
    tc.Collection(core.LabeledTensorLike), int, int, int, bool, bool,
    tc.Optional(string_types))
def batch(labeled_tensors,
          batch_size,
          num_threads=1,
          capacity=32,
          enqueue_many=False,
          allow_smaller_final_batch=False,
          name=None):
  """Rebatch a tensor.

  See tf.batch.

  Args:
    labeled_tensors: The input tensors.
    batch_size: The output batch size.
    num_threads: See tf.batch.
    capacity: See tf.batch.
    enqueue_many: If true, the input tensors must contain a 'batch' axis as
      their first axis.
      If false, the input tensors must not contain a 'batch' axis.
      See tf.batch.
    allow_smaller_final_batch: See tf.batch.
    name: Optional op name.

  Returns:
    The rebatched tensors.
    If enqueue_many is false, the output tensors will have a new 'batch' axis
      as their first axis.

  Raises:
    ValueError: If enqueue_many is True and the first axis of the tensors
      isn't "batch".
  """

  # Adapter passed to _batch_helper: batches the raw tensors with
  # tf.train.batch under the enclosing name scope.
  def _tf_batch(tensors, scope):
    return input.batch(
        tensors,
        batch_size=batch_size,
        num_threads=num_threads,
        capacity=capacity,
        enqueue_many=enqueue_many,
        allow_smaller_final_batch=allow_smaller_final_batch,
        name=scope)

  return _batch_helper('lt_batch', _tf_batch, batch_size, enqueue_many,
                       labeled_tensors, allow_smaller_final_batch, name)
498
499
@tc.returns(tc.List(core.LabeledTensor))
@tc.accepts(
    tc.Collection(core.LabeledTensorLike), int, int, int, bool, int,
    tc.Optional(int), bool, tc.Optional(string_types))
def shuffle_batch(labeled_tensors,
                  batch_size,
                  num_threads=1,
                  capacity=32,
                  enqueue_many=False,
                  min_after_dequeue=0,
                  seed=None,
                  allow_smaller_final_batch=False,
                  name=None):
  """Rebatch a tensor, with shuffling.

  See tf.batch.

  Args:
    labeled_tensors: The input tensors.
    batch_size: The output batch size.
    num_threads: See tf.batch.
    capacity: See tf.batch.
    enqueue_many: If true, the input tensors must contain a 'batch' axis as
      their first axis.
      If false, the input tensors must not contain a 'batch' axis.
      See tf.batch.
    min_after_dequeue: Minimum number of elements in the queue after a dequeue,
      used to ensure mixing.
    seed: Optional random seed.
    allow_smaller_final_batch: See tf.batch.
    name: Optional op name.

  Returns:
    The rebatched tensors.
    If enqueue_many is false, the output tensors will have a new 'batch' axis
      as their first axis.

  Raises:
    ValueError: If enqueue_many is True and the first axis of the tensors
      isn't "batch".
  """

  # Adapter passed to _batch_helper: batches the raw tensors with
  # tf.train.shuffle_batch under the enclosing name scope.
  def _tf_shuffle_batch(tensors, scope):
    return input.shuffle_batch(
        tensors,
        batch_size=batch_size,
        num_threads=num_threads,
        capacity=capacity,
        enqueue_many=enqueue_many,
        min_after_dequeue=min_after_dequeue,
        seed=seed,
        allow_smaller_final_batch=allow_smaller_final_batch,
        name=scope)

  return _batch_helper('lt_shuffle_batch', _tf_shuffle_batch, batch_size,
                       enqueue_many, labeled_tensors,
                       allow_smaller_final_batch, name)
556
557
@tc.returns(core.LabeledTensor)
@tc.accepts(core.LabeledTensorLike,
            tc.Mapping(string_types, int),
            tc.Optional(int), tc.Optional(string_types))
def random_crop(labeled_tensor, shape_map, seed=None, name=None):
  """Randomly crops a tensor to a given size.

  See tf.random_crop.

  Args:
    labeled_tensor: The input tensor.
    shape_map: A dictionary mapping axis names to the size of the random crop
      for that dimension.
    seed: An optional random seed.
    name: An optional op name.

  Returns:
    A tensor of the same rank as `labeled_tensor`, cropped randomly in the
    selected dimensions.

  Raises:
    ValueError: If the shape map contains an axis name not in the input tensor.
  """
  with ops.name_scope(name, 'lt_random_crop', [labeled_tensor]) as scope:
    labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)

    # Every requested axis must exist on the input.
    for axis_name in shape_map:
      if axis_name not in labeled_tensor.axes:
        raise ValueError('Selection axis %s not in axes %s' %
                         (axis_name, labeled_tensor.axes))

    crop_shape = []
    result_axes = []
    for axis in labeled_tensor.axes.values():
      if axis.name not in shape_map:
        # Axes not being cropped keep their full size and labels.
        crop_shape.append(len(axis))
        result_axes.append(axis)
      else:
        cropped_size = shape_map[axis.name]
        crop_shape.append(cropped_size)
        # We lose labels for the axes we crop, leaving just the size.
        result_axes.append((axis.name, cropped_size))

    crop_op = random_ops.random_crop(
        labeled_tensor.tensor, crop_shape, seed=seed, name=scope)

    return core.LabeledTensor(crop_op, result_axes)
605
606
# TODO(shoyer): Allow the user to select the axis over which to map.
@tc.returns(core.LabeledTensor)
@tc.accepts(collections.Callable, core.LabeledTensorLike,
            tc.Optional(string_types))
def map_fn(fn, labeled_tensor, name=None):
  """Map on the list of tensors unpacked from labeled_tensor.

  See tf.map_fn.

  Args:
    fn: The function to apply to each unpacked LabeledTensor.
      It should have type LabeledTensor -> LabeledTensor.
    labeled_tensor: The input tensor.
    name: Optional op name.

  Returns:
    A tensor that packs the results of applying fn to the list of tensors
    unpacked from labeled_tensor.
  """
  with ops.name_scope(name, 'lt_map_fn', [labeled_tensor]) as scope:
    labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)

    # Mapping is always over the first axis (see TODO above).
    unpack_lts = unpack(labeled_tensor)

    # TODO(ericmc): Fix this upstream.
    if labeled_tensor.dtype == dtypes.string:
      # We must construct the full graph here, because map_fn_lib.map_fn
      # doesn't work for string-valued tensors.
      # Constructing the full graph may be slow.
      map_lts = [fn(t) for t in unpack_lts]
      return pack(map_lts, list(labeled_tensor.axes.values())[0], name=scope)
    else:
      # Figure out what the axis labels should be, but use tf.map_fn to
      # construct the graph because it's efficient.
      # It may be slow to construct the full graph, so we infer the labels from
      # the first element.
      # TODO(ericmc): This builds a subgraph which then gets thrown away.
      # Find a more elegant solution.
      first_map_lt = fn(unpack_lts[0])
      # Output axes: the mapped-over axis of the input, followed by the axes
      # of fn's output (inferred from the first element).
      final_axes = list(labeled_tensor.axes.values())[:1] + list(
          first_map_lt.axes.values())

      @tc.returns(ops.Tensor)
      @tc.accepts(ops.Tensor)
      def tf_fn(tensor):
        # Re-wrap each raw slice with the input's non-mapped axes so `fn`
        # receives a proper LabeledTensor.
        original_axes = list(labeled_tensor.axes.values())[1:]
        tensor_lt = core.LabeledTensor(tensor, original_axes)
        return fn(tensor_lt).tensor

      map_op = map_fn_lib.map_fn(
          tf_fn, labeled_tensor.tensor, dtype=first_map_lt.dtype)
      map_lt = core.LabeledTensor(map_op, final_axes)

      return core.identity(map_lt, name=scope)
661
662
@tc.returns(core.LabeledTensor)
@tc.accepts(collections.Callable, core.LabeledTensorLike,
            core.LabeledTensorLike, tc.Optional(string_types))
def foldl(fn, labeled_tensor, initial_value, name=None):
  """Left fold on the list of tensors unpacked from labeled_tensor.

  See tf.foldl.

  Args:
    fn: The function to apply to each unpacked LabeledTensor.
      It should have type (LabeledTensor, LabeledTensor) -> LabeledTensor.
      Its arguments are (accumulated_value, next_value).
    labeled_tensor: The input tensor.
    initial_value: The initial value of the accumulator.
    name: Optional op name.

  Returns:
    The accumulated value.
  """
  with ops.name_scope(name, 'lt_foldl',
                      [labeled_tensor, initial_value]) as scope:
    labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
    initial_value = core.convert_to_labeled_tensor(initial_value)

    @tc.returns(ops.Tensor)
    @tc.accepts(ops.Tensor, ops.Tensor)
    def tf_fn(running, element):
      # Re-wrap the raw tensors with their known axes before calling `fn`:
      # the accumulator carries the initial value's axes, each element
      # carries the input's axes minus the folded (first) one.
      running_lt = core.LabeledTensor(running, initial_value.axes)
      element_axes = list(labeled_tensor.axes.values())[1:]
      element_lt = core.LabeledTensor(element, element_axes)
      return fn(running_lt, element_lt).tensor

    result_op = functional_ops.foldl(
        tf_fn, labeled_tensor.tensor, initializer=initial_value.tensor)
    result_lt = core.LabeledTensor(result_op, initial_value.axes)

    return core.identity(result_lt, name=scope)
700
701
@tc.returns(core.LabeledTensor)
@tc.accepts(core.LabeledTensorLike,
            tc.Optional(tc.Collection(string_types)), tc.Optional(string_types))
def squeeze(labeled_tensor, axis_names=None, name=None):
  """Remove size-1 dimensions.

  See tf.squeeze.

  Args:
    labeled_tensor: The input tensor.
    axis_names: The names of the dimensions to remove, or None to remove
      all size-1 dimensions.
    name: Optional op name.

  Returns:
    A tensor with the specified dimensions removed.

  Raises:
    ValueError: If the named axes are not in the tensor, or if they are
      not size-1.
  """
  with ops.name_scope(name, 'lt_squeeze', [labeled_tensor]) as scope:
    labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)

    if axis_names is None:
      # Default: squeeze every axis whose size is exactly one.
      axis_names = [
          axis.name for axis in labeled_tensor.axes.values() if len(axis) == 1
      ]

    # Validate the requested axes before building any ops.
    for axis_name in axis_names:
      if axis_name not in labeled_tensor.axes:
        raise ValueError('axis %s is not in tensor axes %s' %
                         (axis_name, labeled_tensor.axes))
      elif len(labeled_tensor.axes[axis_name]) != 1:
        raise ValueError(
            'cannot squeeze axis with size greater than 1: (%s, %s)' %
            (axis_name, labeled_tensor.axes[axis_name]))

    dims_to_drop = []
    kept_axes = []
    for i, axis in enumerate(labeled_tensor.axes.values()):
      if axis.name in axis_names:
        dims_to_drop.append(i)
      else:
        kept_axes.append(axis)

    if dims_to_drop:
      squeeze_op = array_ops.squeeze(
          labeled_tensor.tensor, dims_to_drop, name=scope)
    else:
      # Nothing to squeeze; emit an identity so the name scope is still used.
      squeeze_op = array_ops.identity(labeled_tensor.tensor, name=scope)

    return core.LabeledTensor(squeeze_op, kept_axes)
753
754
# pylint: disable=invalid-name
# Typecheck alias: a reduction axis may be given as an axis name, or as a
# (name, label) tuple — presumably selecting a single coordinate on that
# axis; confirm against the reduction ops that consume this alias.
ReduceAxis = tc.Union(string_types,
                      tc.Tuple(string_types, collections.Hashable))
# One ReduceAxis, a collection of them, or None.
ReduceAxes = tc.Optional(tc.Union(ReduceAxis, tc.Collection(ReduceAxis)))
# pylint: enable=invalid-name
760
761
762@tc.returns(core.LabeledTensor)
763@tc.accepts(core.LabeledTensorLike, core.LabeledTensorLike,
764            tc.Optional(string_types))
def matmul(a, b, name=None):
  """Matrix multiply two tensors with rank 1 or 2.

  If both tensors have rank 2, a matrix-matrix product is performed.
  If one tensor has rank 1 and the other has rank 2, then a matrix-vector
  product is performed.
  If both tensors have rank 1, then a vector dot-product is performed.
  (This behavior matches that of `numpy.dot`.)

  Both tensors must share exactly one dimension in common, which is the
  dimension the operation is summed along. The inputs will be automatically
  transposed if necessary as part of the matmul op.

  We intend to eventually support `matmul` on higher rank input, and also
  eventually support summing over any number shared dimensions (via an `axis`
  argument), but neither of these features has been implemented yet.

  Args:
    a: First LabeledTensor.
    b: Second LabeledTensor.
    name: Optional op name.

  Returns:
    LabeledTensor with the result of matrix multiplication. Axes are ordered by
    the current axis_order_scope, if set, or in order of appearance on the
    inputs.

  Raises:
    NotImplementedError: If inputs have rank >2 or share multiple axes.
    ValueError: If the inputs have rank 0 or do not share any axes.
  """
  with ops.name_scope(name, 'lt_matmul', [a, b]) as scope:

    a = core.convert_to_labeled_tensor(a)
    b = core.convert_to_labeled_tensor(b)

    if len(a.axes) > 2 or len(b.axes) > 2:
      # We could pass batched inputs to tf.matmul to make this work, but we
      # would also need to use tf.tile and/or tf.transpose. These are more
      # expensive than doing reshapes, so it's not clear if it's a good idea to
      # do this automatically.
      raise NotImplementedError(
          'matmul currently requires inputs with rank 2 or less, but '
          'inputs have ranks %r and %r' % (len(a.axes), len(b.axes)))

    if not a.axes or not b.axes:
      raise ValueError(
          'matmul currently requires inputs with at least rank 1, but '
          'inputs have ranks %r and %r' % (len(a.axes), len(b.axes)))

    shared_axes = set(a.axes) & set(b.axes)
    if len(shared_axes) > 1:
      raise NotImplementedError(
          'matmul does not yet support summing over multiple shared axes: %r. '
          'Use transpose and reshape to create a single shared axis to sum '
          'over.' % shared_axes)
    if not shared_axes:
      raise ValueError('there must be exactly one axis in common between '
                       'inputs to matmul: %r, %r' %
                       (a.axes.keys(), b.axes.keys()))
    shared_axis, = shared_axes

    if a.axes[shared_axis] != b.axes[shared_axis]:
      raise ValueError('axis %r does not match on input arguments: %r vs %r' %
                       (shared_axis, a.axes[shared_axis].value,
                        b.axes[shared_axis].value))

    # The result keeps every non-shared axis, in order of appearance on the
    # inputs: a's axes first, then b's.
    result_axes = []
    for axes in [a.axes, b.axes]:
      for axis in axes.values():
        if axis.name != shared_axis:
          result_axes.append(axis)

    axis_scope_order = core.get_axis_order()
    if axis_scope_order is not None:
      result_axis_names = [axis.name for axis in result_axes]
      new_axis_names = [
          name for name in axis_scope_order if name in result_axis_names
      ]
      if new_axis_names != result_axis_names:
        # switch a and b
        b, a = a, b
        # result_axes is a list of length 1 or 2
        result_axes = result_axes[::-1]

    squeeze_dims = []

    # Rank-1 inputs are promoted to rank 2 (a row vector for a, a column
    # vector for b) so we can always call tf.matmul, then squeezed back out.
    if len(a.axes) == 1:
      a_tensor = array_ops.reshape(a.tensor, (1, -1))
      squeeze_dims.append(0)
      transpose_a = False
    else:
      a_tensor = a.tensor
      # Transpose if the shared (summed) axis is not already a's last axis.
      transpose_a = list(a.axes.keys()).index(shared_axis) == 0

    if len(b.axes) == 1:
      b_tensor = array_ops.reshape(b.tensor, (-1, 1))
      squeeze_dims.append(1)
      transpose_b = False
    else:
      b_tensor = b.tensor
      # Transpose if the shared (summed) axis is not already b's first axis.
      transpose_b = list(b.axes.keys()).index(shared_axis) == 1

    result_op = math_ops.matmul(
        a_tensor, b_tensor, transpose_a=transpose_a, transpose_b=transpose_b)

    if squeeze_dims:
      result_op = array_ops.squeeze(result_op, squeeze_dims)
    result_op = array_ops.identity(result_op, name=scope)

    return core.LabeledTensor(result_op, result_axes)
876
877
@tc.returns(types.FunctionType)
@tc.accepts(string_types, collections.Callable)
def define_reduce_op(op_name, reduce_fn):
  """Create a LabeledTensor reduction op from a tf.Tensor reduction function.

  Args:
    op_name: string name of the TensorFlow op.
    reduce_fn: function to call to evaluate the op on a tf.Tensor.

  Returns:
    Function defining the given reduction op that acts on a LabeledTensor.
  """

  default_name = 'lt_%s' % op_name

  @tc.returns(core.LabeledTensor)
  @tc.accepts(core.LabeledTensorLike, ReduceAxes, tc.Optional(string_types))
  def op(labeled_tensor, axes=None, name=None):
    """Computes the given reduction across the given axes of a LabeledTensor.

    See `tf.{op_name}` for full details.

    Args:
      labeled_tensor: The input tensor.
      axes: A set of axes or None.
        If None, all axes will be reduced.
        Axes must all be strings, in which case those dimensions will be
        removed, or pairs of (name, None) or (name, label), in which case those
        dimensions will be kept.
      name: Optional op name.

    Returns:
      The reduced LabeledTensor.

    Raises:
      ValueError: if any of the axes to reduce over are not found on
        `labeled_tensor`.
    """
    with ops.name_scope(name, default_name, [labeled_tensor]) as scope:
      labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)

      # No axes given means reduce over every axis.
      axes = labeled_tensor.axes.keys() if axes is None else axes
      # A lone string or a single (name, label) pair denotes one axis.
      if isinstance(axes, (string_types, tuple)):
        axes = [axes]

      # Maps each reduced axis name to its replacement axis spec: the bare
      # name when the dimension is squeezed away, or a (name, labels) pair
      # when the dimension is kept with user-provided labels.
      replacement_for = {}
      squeeze_names = []
      for item in axes:
        if isinstance(item, string_types):
          replacement_for[item] = item
          squeeze_names.append(item)
        else:
          axis_name, label = item
          # A single label must be wrapped in a list so it can become an Axis.
          replacement_for[axis_name] = (
              axis_name, None if label is None else [label])

      for axis_name in replacement_for:
        if axis_name not in labeled_tensor.axes:
          raise ValueError('Axis %s not in axes %s' %
                           (axis_name, labeled_tensor.axes))

      kept_axes = []
      reduce_dims = []
      for dim, axis in enumerate(labeled_tensor.axes.values()):
        if axis.name in replacement_for:
          kept_axes.append(replacement_for[axis.name])
          reduce_dims.append(dim)
        else:
          kept_axes.append(axis)

      # Reduce with keepdims so axis bookkeeping stays aligned, then squeeze
      # out the dimensions that were requested by bare name.
      reduced = reduce_fn(labeled_tensor.tensor, reduce_dims, keepdims=True)
      reduced_lt = core.LabeledTensor(reduced, kept_axes)

      return squeeze(reduced_lt, squeeze_names, name=scope)

  op.__doc__ = op.__doc__.format(op_name=op_name)
  op.__name__ = op_name

  return op
965
966
# LabeledTensor reduction ops, each wrapping the corresponding tf.reduce_*
# function via define_reduce_op (axes may be given by name).
reduce_all = define_reduce_op('reduce_all', math_ops.reduce_all)
reduce_any = define_reduce_op('reduce_any', math_ops.reduce_any)
reduce_logsumexp = define_reduce_op('reduce_logsumexp',
                                    math_ops.reduce_logsumexp)
reduce_max = define_reduce_op('reduce_max', math_ops.reduce_max)
reduce_mean = define_reduce_op('reduce_mean', math_ops.reduce_mean)
reduce_min = define_reduce_op('reduce_min', math_ops.reduce_min)
reduce_prod = define_reduce_op('reduce_prod', math_ops.reduce_prod)
reduce_sum = define_reduce_op('reduce_sum', math_ops.reduce_sum)
976
977
@tc.returns(core.LabeledTensor)
@tc.accepts(core.LabeledTensorLike,
            tc.Mapping(str, tc.Union(int, ops.Tensor)),
            tc.Optional(string_types))
def tile(labeled_tensor, multiples, name=None):
  """Constructs a tensor by tiling a given tensor.

  Tiling is only permitted on axes without tick-labels, since repeating a
  labeled axis would make its labels non-unique.

  See lt.tile.

  Args:
    labeled_tensor: The input tensor.
    multiples: A mapping from axis name to the integer number of times to tile
      along that axis. Axes tiled once (multiple of 1) may be omitted.
    name: Optional op name.

  Returns:
    A tensor with the indicated axes tiled.

  Raises:
    ValueError: If the tiled axes are not axes in the input tensor, or if any
      axes in multiples have tick labels.
  """
  with ops.name_scope(name, 'lt_tile', [labeled_tensor]) as scope:
    labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)

    if set(multiples.keys()) - set(labeled_tensor.axes.keys()):
      raise ValueError('tile axes %r are not contained in the set of axis '
                       'names %r on the input labeled tensor' %
                       (multiples.keys(), labeled_tensor.axes))

    labeled_axes = [
        axis_name for axis_name in multiples
        if labeled_tensor.axes[axis_name].labels is not None
    ]
    if labeled_axes:
      raise ValueError('cannot tile axes with tick labels: %r' % labeled_axes)

    # tf.tile needs one multiple per dimension, in axis order.
    per_axis_multiples = [
        multiples.get(axis_name, 1) for axis_name in labeled_tensor.axes
    ]
    tile_op = array_ops.tile(
        labeled_tensor.tensor, per_axis_multiples, name=scope)

    # Unlabeled axes are passed through by name only (their new sizes are
    # inferred); labeled axes are untouched by tiling and keep their labels.
    result_axes = [
        axis if axis.labels is not None else axis.name
        for axis in labeled_tensor.axes.values()
    ]
    return core.LabeledTensor(tile_op, result_axes)
1027
1028
@tc.returns(core.LabeledTensor)
@tc.accepts(core.LabeledTensorLike,
            tc.Mapping(str, tc.Tuple(core.AxisValue, core.AxisValue)),
            string_types, tc.Optional(string_types))
def pad(labeled_tensor, paddings, mode='CONSTANT', name=None):
  """Pads a labeled tensor along the named axes.

  See tf.pad.

  Args:
    labeled_tensor: The input tensor.
    paddings: A mapping from axis name to a (before, after) pair giving the
      padding to insert at the beginning and at the end of that axis.
    mode: One of "CONSTANT", "REFLECT", or "SYMMETRIC".
    name: Optional op name.

  Returns:
    A tensor with the indicated axes padded, optionally with those axes extended
    with the provided labels.

  Raises:
    ValueError: If the padded axes are not axes in the input tensor.
  """
  with ops.name_scope(name, 'lt_pad', [labeled_tensor]) as scope:
    labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)

    if set(paddings.keys()) - set(labeled_tensor.axes.keys()):
      raise ValueError('pad axes %r are not contained in the set of axis '
                       'names %r on the input labeled tensor' %
                       (paddings.keys(), labeled_tensor.axes))

    new_axes = []
    padding_pairs = []
    for axis_name, axis in labeled_tensor.axes.items():
      if axis_name in paddings:
        before_value, after_value = paddings[axis_name]
        front_axis = core.Axis(axis_name, before_value)
        back_axis = core.Axis(axis_name, after_value)
        # The padded axis is the concatenation: front padding, original
        # positions, back padding.
        new_axes.append(core.concat_axes([front_axis, axis, back_axis]))
        padding_pairs.append((len(front_axis), len(back_axis)))
      else:
        new_axes.append(axis)
        padding_pairs.append((0, 0))

    pad_op = array_ops.pad(
        labeled_tensor.tensor, padding_pairs, mode, name=scope)

    return core.LabeledTensor(pad_op, new_axes)
1081
1082
@tc.returns(core.LabeledTensor)
@tc.accepts(
    tc.Union(np.ndarray, list, tuple, core.Scalar),
    tc.Optional(dtypes.DType),
    tc.Optional(
        tc.Union(core.Axes, tc.Collection(
            tc.Union(string_types, core.AxisLike)))), tc.Optional(string_types))
def constant(value, dtype=None, axes=None, name=None):
  """Creates a constant tensor.

  If `axes` includes any strings, shape is inferred from `value`. Otherwise,
  the sizes of the given `axes` are used to set `shape` for `tf.constant`.

  See tf.constant for more details.

  Args:
    value: The input tensor.
    dtype: The type of the returned tensor.
    axes: Optional Axes, list of strings or list of objects coercible to Axis
      objects. By default, axes are assumed to be an empty list (i.e., `value`
      is treated as a scalar).
    name: Optional op name.

  Returns:
    A LabeledTensor holding the given constant value, with the given axes.
  """
  with ops.name_scope(name, 'lt_constant', [value]) as scope:

    axes = [] if axes is None else axes

    if isinstance(axes, core.Axes):
      axes = list(axes.values())

    if any(isinstance(ax, string_types) for ax in axes):
      # String axes carry no size, so shape must be inferred from `value`.
      shape = None
    else:
      # All axes are sized, so they fully determine the shape.
      axes = [core.as_axis(ax) for ax in axes]
      shape = [ax.size for ax in axes]

    op = array_ops.constant(value, dtype=dtype, shape=shape, name=scope)
    return core.LabeledTensor(op, axes)
1127
1128
@tc.returns(core.LabeledTensor)
@tc.accepts(core.LabeledTensorLike,
            tc.Optional(dtypes.DType), tc.Optional(string_types))
def zeros_like(labeled_tensor, dtype=None, name=None):
  """Creates a zero-filled tensor with the same axes as the input.

  Args:
    labeled_tensor: The input tensor.
    dtype: The type of the returned tensor.
    name: Optional op name.

  Returns:
    A LabeledTensor of zeros sharing the input's axes.
  """
  with ops.name_scope(name, 'lt_zeros_like', [labeled_tensor]) as scope:
    input_lt = core.convert_to_labeled_tensor(labeled_tensor)
    zeros_op = array_ops.zeros_like(input_lt.tensor, dtype=dtype, name=scope)
    return core.LabeledTensor(zeros_op, input_lt.axes)
1147
1148
@tc.returns(core.LabeledTensor)
@tc.accepts(core.LabeledTensorLike,
            tc.Optional(dtypes.DType), tc.Optional(string_types))
def ones_like(labeled_tensor, dtype=None, name=None):
  """Creates a one-filled tensor with the same axes as the input.

  Args:
    labeled_tensor: The input tensor.
    dtype: The type of the returned tensor.
    name: Optional op name.

  Returns:
    A LabeledTensor of ones sharing the input's axes.
  """
  with ops.name_scope(name, 'lt_ones_like', [labeled_tensor]) as scope:
    input_lt = core.convert_to_labeled_tensor(labeled_tensor)
    ones_op = array_ops.ones_like(input_lt.tensor, dtype=dtype, name=scope)
    return core.LabeledTensor(ones_op, input_lt.axes)
1167
1168
@tc.returns(core.LabeledTensor)
@tc.accepts(core.LabeledTensorLike,
            tc.Optional(dtypes.DType), tc.Optional(string_types))
def cast(labeled_tensor, dtype=None, name=None):
  """Casts a labeled tensor to a new type, preserving its axes.

  Args:
    labeled_tensor: The input tensor.
    dtype: The type of the returned tensor.
    name: Optional op name.

  Returns:
    A labeled tensor with the new dtype.
  """
  with ops.name_scope(name, 'lt_cast', [labeled_tensor]) as scope:
    input_lt = core.convert_to_labeled_tensor(labeled_tensor)
    cast_op = math_ops.cast(input_lt.tensor, dtype=dtype, name=scope)
    return core.LabeledTensor(cast_op, input_lt.axes)
1187
1188
@tc.returns(core.LabeledTensor)
@tc.accepts(core.LabeledTensorLike, string_types, tc.Optional(string_types))
def verify_tensor_all_finite(labeled_tensor, message, name=None):
  """Asserts that a labeled tensor contains no NaNs or Infs.

  See tf.verify_tensor_all_finite.

  Args:
    labeled_tensor: The input tensor.
    message: Message to log on failure.
    name: Optional op name.

  Returns:
    The input tensor, with the finiteness check attached.
  """
  with ops.name_scope(name, 'lt_verify_tensor_all_finite',
                      [labeled_tensor]) as scope:
    input_lt = core.convert_to_labeled_tensor(labeled_tensor)
    verified = numerics.verify_tensor_all_finite(
        input_lt.tensor, msg=message, name=scope)
    return core.LabeledTensor(verified, input_lt.axes)
1210
1211
@tc.returns(core.LabeledTensor)
@tc.accepts(core.LabeledTensorLike, core.LabeledTensorLike,
            tc.Optional(string_types))
def boolean_mask(labeled_tensor, mask, name=None):
  """Apply a boolean mask to a labeled tensor.

  Unlike `tf.boolean_mask`, this currently only works on 1-dimensional masks.
  The mask is applied to the first axis of `labeled_tensor`. Labels on the first
  axis are removed, because True indices in `mask` may not be known dynamically.

  Args:
    labeled_tensor: The input tensor.
    mask: A 1D boolean LabeledTensor whose single axis matches the first axis
      of `labeled_tensor`.
    name: Optional op name.

  Returns:
    The masked labeled tensor.

  Raises:
    NotImplementedError: if the mask has rank greater than one.
    ValueError: if the axis of the mask does not equal the first axis of
      `labeled_tensor`.
  """
  with ops.name_scope(name, 'lt_boolean_mask', [labeled_tensor, mask]) as scope:
    labeled_tensor = core.convert_to_labeled_tensor(labeled_tensor)
    mask = core.convert_to_labeled_tensor(mask)

    if len(mask.axes) > 1:
      raise NotImplementedError(
          "LabeledTensor's boolean_mask currently only supports 1D masks")
    mask_axis = list(mask.axes.values())[0]
    lt_axis = list(labeled_tensor.axes.values())[0]
    if mask_axis != lt_axis:
      raise ValueError('the first axis of the labeled tensor and the mask '
                       'are not equal:\n%r\n%r' % (lt_axis, mask_axis))
    op = array_ops.boolean_mask(labeled_tensor.tensor, mask.tensor, name=scope)
    # TODO(shoyer): attempt to infer labels for the masked values, by calling
    # tf.contrib.util.constant_value on the mask?
    # The first axis keeps only its name: its size and labels depend on how
    # many mask entries are True, which may only be known at runtime.
    axes = [lt_axis.name] + list(labeled_tensor.axes.values())[1:]
    return core.LabeledTensor(op, axes)
1250
1251
@tc.returns(core.LabeledTensor)
@tc.accepts(core.LabeledTensorLike, core.LabeledTensorLike,
            core.LabeledTensorLike, tc.Optional(string_types))
def where(condition, x, y, name=None):
  """Return elements from x or y depending on condition.

  See `tf.where` for more details. This function currently only implements the
  three argument version of where.

  Args:
    condition: LabeledTensor of type `bool`.
    x: LabeledTensor for values where condition is true.
    y: LabeledTensor for values where condition is false.
    name: Optional op name.

  Returns:
    The labeled tensor with values according to condition.

  Raises:
    ValueError: if `condition`, `x` and `y` do not all have exactly the same
      axes.
  """
  with ops.name_scope(name, 'lt_where', [condition, x, y]) as scope:
    condition = core.convert_to_labeled_tensor(condition)
    x = core.convert_to_labeled_tensor(x)
    y = core.convert_to_labeled_tensor(y)

    if not condition.axes == x.axes == y.axes:
      raise ValueError('all inputs to `where` must have equal axes')

    op = array_ops.where(condition.tensor, x.tensor, y.tensor, name=scope)
    return core.LabeledTensor(op, x.axes)