1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Helper classes for tensor shape inference."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.core.framework import tensor_shape_pb2
21from tensorflow.python.util import compat
22from tensorflow.python.util.tf_export import tf_export
23
24
@tf_export("Dimension")
class Dimension(object):
  """Represents the value of one dimension in a TensorShape.

  A `Dimension` wraps either a non-negative integer (a known size) or `None`
  (an unknown size). Arithmetic involving an unknown `Dimension` produces an
  unknown `Dimension`, and ordering/equality comparisons involving one
  produce `None`.
  """

  def __init__(self, value):
    """Creates a new Dimension with the given value.

    Args:
      value: `None` for an unknown dimension, or a value that `int()` converts
        to a non-negative integer. Non-string values must convert losslessly
        (e.g. `2.0` is accepted, `2.5` is not); strings and bytes are passed
        to `int()` without that check.

    Raises:
      ValueError: If a non-string `value` is not equal to `int(value)`, or if
        the resulting integer is negative. Errors raised by `int()` itself
        propagate unchanged.
    """
    if value is None:
      self._value = None
    else:
      self._value = int(value)
      # Strings are exempt from the lossless check: "3" != 3 even though
      # int("3") is a well-defined conversion.
      if (not isinstance(value, compat.bytes_or_text_types) and
          self._value != value):
        raise ValueError("Ambiguous dimension: %s" % value)
      if self._value < 0:
        raise ValueError("Dimension %d must be >= 0" % self._value)

  def __repr__(self):
    return "Dimension(%s)" % repr(self._value)

  def __str__(self):
    value = self._value
    return "?" if value is None else str(value)

  def __eq__(self, other):
    """Returns true if `other` has the same known value as this Dimension.

    Returns:
      `None` if either value is unknown, otherwise whether the values are
      equal. Returns `NotImplemented` if `other` cannot be converted to a
      `Dimension`.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is None or other.value is None:
      return None
    return self._value == other.value

  def __ne__(self, other):
    """Returns true if `other` has a different known value from `self`.

    Returns:
      `None` if either value is unknown, otherwise whether the values differ.
      Returns `NotImplemented` if `other` cannot be converted to a
      `Dimension`.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is None or other.value is None:
      return None
    return self._value != other.value

  def __int__(self):
    # For an unknown dimension this returns None, which `int()` rejects.
    return self._value

  # This is needed for Windows.
  # See https://github.com/tensorflow/tensorflow/pull/9780
  def __long__(self):
    return self._value

  def __index__(self):
    # Allow use in Python 3 range
    return self._value

  @property
  def value(self):
    """The value of this dimension, or None if it is unknown."""
    return self._value

  def is_compatible_with(self, other):
    """Returns true if `other` is compatible with this Dimension.

    Two known Dimensions are compatible if they have the same value.
    An unknown Dimension is compatible with all other Dimensions.

    Args:
      other: Another Dimension.

    Returns:
      True if this Dimension and `other` are compatible.
    """
    other = as_dimension(other)
    return (self._value is None or other.value is None or
            self._value == other.value)

  def assert_is_compatible_with(self, other):
    """Raises an exception if `other` is not compatible with this Dimension.

    Args:
      other: Another Dimension.

    Raises:
      ValueError: If `self` and `other` are not compatible (see
        is_compatible_with).
    """
    if not self.is_compatible_with(other):
      raise ValueError("Dimensions %s and %s are not compatible" % (self,
                                                                    other))

  def merge_with(self, other):
    """Returns a Dimension that combines the information in `self` and `other`.

    Dimensions are combined as follows:

    ```python
    tf.Dimension(n)   .merge_with(tf.Dimension(n))    == tf.Dimension(n)
    tf.Dimension(n)   .merge_with(tf.Dimension(None)) == tf.Dimension(n)
    tf.Dimension(None).merge_with(tf.Dimension(n))    == tf.Dimension(n)
    tf.Dimension(None).merge_with(tf.Dimension(None)) == tf.Dimension(None)
    tf.Dimension(n)   .merge_with(tf.Dimension(m))  # raises ValueError for n != m
    ```

    Args:
      other: Another Dimension.

    Returns:
      A Dimension containing the combined information of `self` and
      `other`.

    Raises:
      ValueError: If `self` and `other` are not compatible (see
        is_compatible_with).
    """
    other = as_dimension(other)
    self.assert_is_compatible_with(other)
    if self._value is None:
      return Dimension(other.value)
    else:
      return Dimension(self._value)

  def __add__(self, other):
    """Returns the sum of `self` and `other`.

    Dimensions are summed as follows:

    ```python
    tf.Dimension(m)    + tf.Dimension(n)    == tf.Dimension(m + n)
    tf.Dimension(m)    + tf.Dimension(None) == tf.Dimension(None)
    tf.Dimension(None) + tf.Dimension(n)    == tf.Dimension(None)
    tf.Dimension(None) + tf.Dimension(None) == tf.Dimension(None)
    ```

    Args:
      other: Another Dimension.

    Returns:
      A Dimension whose value is the sum of `self` and `other`.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value + other.value)

  def __radd__(self, other):
    """Returns the sum of `other` and `self`.

    Called for expressions such as `2 + tf.Dimension(3)`, where the left
    operand is not a `Dimension`. Addition is commutative, so this delegates
    to `__add__`.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the sum of `other` and `self`.
    """
    return self + other

  def __sub__(self, other):
    """Returns the subtraction of `other` from `self`.

    Dimensions are subtracted as follows:

    ```python
    tf.Dimension(m)    - tf.Dimension(n)    == tf.Dimension(m - n)
    tf.Dimension(m)    - tf.Dimension(None) == tf.Dimension(None)
    tf.Dimension(None) - tf.Dimension(n)    == tf.Dimension(None)
    tf.Dimension(None) - tf.Dimension(None) == tf.Dimension(None)
    ```

    Args:
      other: Another Dimension.

    Returns:
      A Dimension whose value is `self` minus `other`.

    Raises:
      ValueError: If both values are known and the difference is negative.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value - other.value)

  def __rsub__(self, other):
    """Returns the subtraction of `self` from `other`.

    Called for expressions such as `5 - tf.Dimension(3)`, where the left
    operand is not a `Dimension`.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is `other` minus `self`.

    Raises:
      ValueError: If both values are known and the difference is negative.
    """
    return as_dimension(other) - self

  def __mul__(self, other):
    """Returns the product of `self` and `other`.

    Dimensions are multiplied as follows:

    ```python
    tf.Dimension(m)    * tf.Dimension(n)    == tf.Dimension(m * n)
    tf.Dimension(m)    * tf.Dimension(None) == tf.Dimension(None)
    tf.Dimension(None) * tf.Dimension(n)    == tf.Dimension(None)
    tf.Dimension(None) * tf.Dimension(None) == tf.Dimension(None)
    ```

    Args:
      other: Another Dimension.

    Returns:
      A Dimension whose value is the product of `self` and `other`.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value * other.value)

  def __rmul__(self, other):
    """Returns the product of `other` and `self`.

    Called for expressions such as `2 * tf.Dimension(3)`, where the left
    operand is not a `Dimension`. Multiplication is commutative, so this
    delegates to `__mul__`.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the product of `other` and `self`.
    """
    return self * other

  def __floordiv__(self, other):
    """Returns the quotient of `self` and `other` rounded down.

    Dimensions are divided as follows:

    ```python
    tf.Dimension(m)    // tf.Dimension(n)    == tf.Dimension(m // n)
    tf.Dimension(m)    // tf.Dimension(None) == tf.Dimension(None)
    tf.Dimension(None) // tf.Dimension(n)    == tf.Dimension(None)
    tf.Dimension(None) // tf.Dimension(None) == tf.Dimension(None)
    ```

    Args:
      other: Another `Dimension`.

    Returns:
      A `Dimension` whose value is the integer quotient of `self` and `other`.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value // other.value)

  def __rfloordiv__(self, other):
    """Returns the quotient of `other` and `self` rounded down.

    Called for expressions such as `10 // tf.Dimension(3)`, where the left
    operand is not a `Dimension`.

    Args:
      other: Another `Dimension`, or a value accepted by `as_dimension`.

    Returns:
      A `Dimension` whose value is the integer quotient of `other` and `self`.
    """
    return as_dimension(other) // self

  def __div__(self, other):
    """DEPRECATED: Use `__floordiv__` via `x // y` instead.

    This function exists only for backwards compatibility purposes; new code
    should use `__floordiv__` via the syntax `x // y`.  Using `x // y`
    communicates clearly that the result rounds down, and is forward compatible
    to Python 3.

    Args:
      other: Another `Dimension`.

    Returns:
      A `Dimension` whose value is the integer quotient of `self` and `other`.
    """
    return self // other

  def __rdiv__(self, other):
    """DEPRECATED: Use `__rfloordiv__` via `x // y` instead.

    Python 2 counterpart of `__div__` for a non-`Dimension` left operand.

    Args:
      other: Another `Dimension`, or a value accepted by `as_dimension`.

    Returns:
      A `Dimension` whose value is the integer quotient of `other` and `self`.
    """
    return other // self

  def __mod__(self, other):
    """Returns `self` modulo `other`.

    Dimension moduli are computed as follows:

    ```python
    tf.Dimension(m)    % tf.Dimension(n)    == tf.Dimension(m % n)
    tf.Dimension(m)    % tf.Dimension(None) == tf.Dimension(None)
    tf.Dimension(None) % tf.Dimension(n)    == tf.Dimension(None)
    tf.Dimension(None) % tf.Dimension(None) == tf.Dimension(None)
    ```

    Args:
      other: Another Dimension.

    Returns:
      A Dimension whose value is `self` modulo `other`.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value % other.value)

  def __rmod__(self, other):
    """Returns `other` modulo `self`.

    Called for expressions such as `10 % tf.Dimension(3)`, where the left
    operand is not a `Dimension`.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is `other` modulo `self`.
    """
    return as_dimension(other) % self

  def __lt__(self, other):
    """Returns True if `self` is known to be less than `other`.

    Dimensions are compared as follows:

    ```python
    (tf.Dimension(m)    < tf.Dimension(n))    == (m < n)
    (tf.Dimension(m)    < tf.Dimension(None)) == None
    (tf.Dimension(None) < tf.Dimension(n))    == None
    (tf.Dimension(None) < tf.Dimension(None)) == None
    ```

    Args:
      other: Another Dimension.

    Returns:
      The value of `self.value < other.value` if both are known, otherwise
      None.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return None
    else:
      return self._value < other.value

  def __le__(self, other):
    """Returns True if `self` is known to be less than or equal to `other`.

    Dimensions are compared as follows:

    ```python
    (tf.Dimension(m)    <= tf.Dimension(n))    == (m <= n)
    (tf.Dimension(m)    <= tf.Dimension(None)) == None
    (tf.Dimension(None) <= tf.Dimension(n))    == None
    (tf.Dimension(None) <= tf.Dimension(None)) == None
    ```

    Args:
      other: Another Dimension.

    Returns:
      The value of `self.value <= other.value` if both are known, otherwise
      None.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return None
    else:
      return self._value <= other.value

  def __gt__(self, other):
    """Returns True if `self` is known to be greater than `other`.

    Dimensions are compared as follows:

    ```python
    (tf.Dimension(m)    > tf.Dimension(n))    == (m > n)
    (tf.Dimension(m)    > tf.Dimension(None)) == None
    (tf.Dimension(None) > tf.Dimension(n))    == None
    (tf.Dimension(None) > tf.Dimension(None)) == None
    ```

    Args:
      other: Another Dimension.

    Returns:
      The value of `self.value > other.value` if both are known, otherwise
      None.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return None
    else:
      return self._value > other.value

  def __ge__(self, other):
    """Returns True if `self` is known to be greater than or equal to `other`.

    Dimensions are compared as follows:

    ```python
    (tf.Dimension(m)    >= tf.Dimension(n))    == (m >= n)
    (tf.Dimension(m)    >= tf.Dimension(None)) == None
    (tf.Dimension(None) >= tf.Dimension(n))    == None
    (tf.Dimension(None) >= tf.Dimension(None)) == None
    ```

    Args:
      other: Another Dimension.

    Returns:
      The value of `self.value >= other.value` if both are known, otherwise
      None.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return None
    else:
      return self._value >= other.value
381
382
def as_dimension(value):
  """Coerces `value` into a `Dimension`.

  An existing `Dimension` is passed through unchanged (the same object is
  returned). Any other value, including `None` and plain integers, is handed
  to the `Dimension` constructor, so `None` yields an unknown dimension.

  Args:
    value: A `Dimension`, `None`, or any value accepted by `Dimension(...)`.

  Returns:
    A `Dimension` representing `value`.
  """
  return value if isinstance(value, Dimension) else Dimension(value)
400
401
@tf_export("TensorShape")
class TensorShape(object):
  """Represents the shape of a `Tensor`.

  A `TensorShape` represents a possibly-partial shape specification for a
  `Tensor`. It may be one of the following:

  * *Fully-known shape:* has a known number of dimensions and a known size
    for each dimension. e.g. `TensorShape([16, 256])`
  * *Partially-known shape:* has a known number of dimensions, and an unknown
    size for one or more dimension. e.g. `TensorShape([None, 256])`
  * *Unknown shape:* has an unknown number of dimensions, and an unknown
    size in all dimensions. e.g. `TensorShape(None)`

  If a tensor is produced by an operation of type `"Foo"`, its shape
  may be inferred if there is a registered shape function for
  `"Foo"`. See @{$adding_an_op#shape-functions-in-c$`Shape functions in C++`}
  for details of shape functions and how to register them. Alternatively,
  the shape may be set explicitly using @{tf.Tensor.set_shape}.
  """

  def __init__(self, dims):
    """Creates a new TensorShape with the given dimensions.

    Args:
      dims: A list of Dimensions, or None if the shape is unspecified.
        A `TensorShapeProto` or another `TensorShape` is also accepted; any
        iterable is converted element-wise with `as_dimension`.
        DEPRECATED: A single integer is treated as a singleton list.

    Raises:
      TypeError: If dims cannot be converted to a list of dimensions, or if
        dims is a string.
    """
    # TODO(irving): Eliminate the single integer special case.
    if dims is None:
      self._dims = None
    elif isinstance(dims, compat.bytes_or_text_types):
      raise TypeError("A string has ambiguous TensorShape, please wrap in a "
                      "list or convert to an int: %s" % dims)
    elif isinstance(dims, tensor_shape_pb2.TensorShapeProto):
      if dims.unknown_rank:
        self._dims = None
      else:
        self._dims = [
            # Protos store variable-size dimensions as -1
            as_dimension(dim.size if dim.size != -1 else None)
            for dim in dims.dim
        ]
    elif isinstance(dims, TensorShape):
      # Shares the other shape's dimension list; neither shape mutates it.
      self._dims = dims.dims
    else:
      try:
        dims_iter = iter(dims)
      except TypeError:
        # Treat as a singleton dimension
        self._dims = [as_dimension(dims)]
      else:
        # Got a list of dimensions
        self._dims = [as_dimension(d) for d in dims_iter]

  def __repr__(self):
    return "TensorShape(%r)" % self._dims

  def __str__(self):
    if self.ndims is None:
      return "<unknown>"
    elif self.ndims == 1:
      return "(%s,)" % self._dims[0]
    else:
      return "(%s)" % ", ".join(str(d) for d in self._dims)

  @property
  def dims(self):
    """Returns a list of Dimensions, or None if the shape is unspecified."""
    return self._dims

  @property
  def ndims(self):
    """Returns the rank of this shape, or None if it is unspecified."""
    if self._dims is None:
      return None
    else:
      return len(self._dims)

  def __len__(self):
    """Returns the rank of this shape, or raises ValueError if unspecified."""
    if self._dims is None:
      raise ValueError("Cannot take the length of Shape with unknown rank.")
    return len(self._dims)

  def __bool__(self):
    """Returns True if this shape contains non-zero information."""
    return self._dims is not None

  # Python 3 wants __bool__, Python 2.7 wants __nonzero__
  __nonzero__ = __bool__

  def __iter__(self):
    """Returns `self.dims` if the rank is known, otherwise raises ValueError."""
    if self._dims is None:
      raise ValueError("Cannot iterate over a shape with unknown rank.")
    else:
      return iter(self._dims)

  def __getitem__(self, key):
    """Returns the value of a dimension or a shape, depending on the key.

    Args:
      key: If `key` is an integer, returns the dimension at that index;
        otherwise if `key` is a slice, returns a TensorShape whose
        dimensions are those selected by the slice from `self`.

    Returns:
      A dimension if `key` is an integer, or a `TensorShape` if `key` is a
      slice. If `self` has unknown rank, an integer key yields an unknown
      `Dimension`; a slice yields an unknown-rank shape when its stop is
      `None` or its start/stop is negative, and otherwise a shape of rank
      `stop - start` with unknown dimensions.

    Raises:
      ValueError: If `self` has unknown rank and `key` is a slice with a
        step.
    """
    if self._dims is not None:
      if isinstance(key, slice):
        return TensorShape(self._dims[key])
      else:
        return self._dims[key]
    else:
      if isinstance(key, slice):
        start = key.start if key.start is not None else 0
        stop = key.stop

        if key.step is not None:
          # TODO(mrry): Handle these maybe.
          raise ValueError("Steps are not yet handled")
        if stop is None:
          # NOTE(mrry): This implies that TensorShape(None) is compatible with
          # TensorShape(None)[1:], which is obviously not true. It would be
          # possible to track the number of dimensions symbolically,
          # and perhaps we should do that.
          return unknown_shape()
        elif start < 0 or stop < 0:
          # TODO(mrry): Handle this better, as it will be useful for handling
          # suffixes of otherwise unknown shapes.
          return unknown_shape()
        else:
          # NOTE(review): this assumes the true rank is at least `stop`; a
          # shorter tensor would yield fewer than `stop - start` dimensions.
          return unknown_shape(ndims=stop - start)
      else:
        return Dimension(None)

  def num_elements(self):
    """Returns the total number of elements, or None for incomplete shapes."""
    if self.is_fully_defined():
      size = 1
      for dim in self._dims:
        size *= dim.value
      return size
    else:
      return None

  def merge_with(self, other):
    """Returns a `TensorShape` combining the information in `self` and `other`.

    The dimensions in `self` and `other` are merged elementwise,
    according to the rules defined for `Dimension.merge_with()`.

    Args:
      other: Another `TensorShape`.

    Returns:
      A `TensorShape` containing the combined information of `self` and
      `other`.

    Raises:
      ValueError: If `self` and `other` are not compatible.
    """
    other = as_shape(other)
    if self._dims is None:
      return other
    else:
      try:
        self.assert_same_rank(other)
        new_dims = []
        for i, dim in enumerate(self._dims):
          # other[i] is an unknown Dimension when `other` has unknown rank.
          new_dims.append(dim.merge_with(other[i]))
        return TensorShape(new_dims)
      except ValueError:
        raise ValueError("Shapes %s and %s are not compatible" % (self, other))

  def concatenate(self, other):
    """Returns the concatenation of the dimension in `self` and `other`.

    *N.B.* If either `self` or `other` is completely unknown,
    concatenation will discard information about the other shape. In
    future, we might support concatenation that preserves this
    information for use with slicing.

    Args:
      other: Another `TensorShape`.

    Returns:
      A `TensorShape` whose dimensions are the concatenation of the
      dimensions in `self` and `other`.
    """
    # TODO(mrry): Handle the case where we concatenate a known shape with a
    # completely unknown shape, so that we can use the partial information.
    other = as_shape(other)
    if self._dims is None or other.dims is None:
      return unknown_shape()
    else:
      return TensorShape(self._dims + other.dims)

  def assert_same_rank(self, other):
    """Raises an exception if `self` and `other` do not have compatible ranks.

    Args:
      other: Another `TensorShape`.

    Raises:
      ValueError: If `self` and `other` do not represent shapes with the
        same rank.
    """
    other = as_shape(other)
    if self.ndims is not None and other.ndims is not None:
      if self.ndims != other.ndims:
        raise ValueError("Shapes %s and %s must have the same rank" % (self,
                                                                       other))

  def assert_has_rank(self, rank):
    """Raises an exception if `self` is not compatible with the given `rank`.

    Args:
      rank: An integer.

    Raises:
      ValueError: If `self` does not represent a shape with the given `rank`.
    """
    if self.ndims not in (None, rank):
      raise ValueError("Shape %s must have rank %d" % (self, rank))

  def with_rank(self, rank):
    """Returns a shape based on `self` with the given rank.

    This method promotes a completely unknown shape to one with a
    known rank.

    Args:
      rank: An integer.

    Returns:
      A shape that is at least as specific as `self` with the given rank.

    Raises:
      ValueError: If `self` does not represent a shape with the given `rank`.
    """
    try:
      return self.merge_with(unknown_shape(ndims=rank))
    except ValueError:
      raise ValueError("Shape %s must have rank %d" % (self, rank))

  def with_rank_at_least(self, rank):
    """Returns a shape based on `self` with at least the given rank.

    Args:
      rank: An integer.

    Returns:
      A shape that is at least as specific as `self` with at least the given
      rank.

    Raises:
      ValueError: If `self` does not represent a shape with at least the given
        `rank`.
    """
    if self.ndims is not None and self.ndims < rank:
      raise ValueError("Shape %s must have rank at least %d" % (self, rank))
    else:
      return self

  def with_rank_at_most(self, rank):
    """Returns a shape based on `self` with at most the given rank.

    Args:
      rank: An integer.

    Returns:
      A shape that is at least as specific as `self` with at most the given
      rank.

    Raises:
      ValueError: If `self` does not represent a shape with at most the given
        `rank`.
    """
    if self.ndims is not None and self.ndims > rank:
      raise ValueError("Shape %s must have rank at most %d" % (self, rank))
    else:
      return self

  def is_compatible_with(self, other):
    """Returns True iff `self` is compatible with `other`.

    Two possibly-partially-defined shapes are compatible if there
    exists a fully-defined shape that both shapes can represent. Thus,
    compatibility allows the shape inference code to reason about
    partially-defined shapes. For example:

    * TensorShape(None) is compatible with all shapes.

    * TensorShape([None, None]) is compatible with all two-dimensional
      shapes, such as TensorShape([32, 784]), and also TensorShape(None). It is
      not compatible with, for example, TensorShape([None]) or
      TensorShape([None, None, None]).

    * TensorShape([32, None]) is compatible with all two-dimensional shapes
      with size 32 in the 0th dimension, and also TensorShape([None, None])
      and TensorShape(None). It is not compatible with, for example,
      TensorShape([32]), TensorShape([32, None, 1]) or TensorShape([64, None]).

    * TensorShape([32, 784]) is compatible with itself, and also
      TensorShape([32, None]), TensorShape([None, 784]), TensorShape([None,
      None]) and TensorShape(None). It is not compatible with, for example,
      TensorShape([32, 1, 784]) or TensorShape([None]).

    The compatibility relation is reflexive and symmetric, but not
    transitive. For example, TensorShape([32, 784]) is compatible with
    TensorShape(None), and TensorShape(None) is compatible with
    TensorShape([4, 4]), but TensorShape([32, 784]) is not compatible with
    TensorShape([4, 4]).

    Args:
      other: Another TensorShape.

    Returns:
      True iff `self` is compatible with `other`.

    """
    other = as_shape(other)
    if self._dims is not None and other.dims is not None:
      if self.ndims != other.ndims:
        return False
      for x_dim, y_dim in zip(self._dims, other.dims):
        if not x_dim.is_compatible_with(y_dim):
          return False
    return True

  def assert_is_compatible_with(self, other):
    """Raises exception if `self` and `other` do not represent the same shape.

    This method can be used to assert that there exists a shape that both
    `self` and `other` represent.

    Args:
      other: Another TensorShape.

    Raises:
      ValueError: If `self` and `other` do not represent the same shape.
    """
    if not self.is_compatible_with(other):
      raise ValueError("Shapes %s and %s are incompatible" % (self, other))

  def most_specific_compatible_shape(self, other):
    """Returns the most specific TensorShape compatible with `self` and `other`.

    * TensorShape([None, 1]) is the most specific TensorShape compatible with
      both TensorShape([2, 1]) and TensorShape([5, 1]). Note that
      TensorShape(None) is also compatible with above mentioned TensorShapes.

    * TensorShape([1, 2, 3]) is the most specific TensorShape compatible with
      both TensorShape([1, 2, 3]) and TensorShape([1, 2, 3]). There are more
      less specific TensorShapes compatible with above mentioned TensorShapes,
      e.g. TensorShape([1, 2, None]), TensorShape(None).

    Args:
      other: Another `TensorShape`.

    Returns:
      A `TensorShape` which is the most specific compatible shape of `self`
      and `other`.
    """

    other = as_shape(other)
    if self._dims is None or other.dims is None or self.ndims != other.ndims:
      return unknown_shape()

    # Dimension.__eq__ returns None (falsy) when either side is unknown, so
    # only dimensions that are known and equal in both shapes are kept.
    dims = [d1 if d1 == d2 else Dimension(None)
            for d1, d2 in zip(self._dims, other.dims)]
    return TensorShape(dims)

  def is_fully_defined(self):
    """Returns True iff `self` is fully defined in every dimension."""
    return (self._dims is not None and all(dim.value is not None
                                           for dim in self._dims))

  def assert_is_fully_defined(self):
    """Raises an exception if `self` is not fully defined in every dimension.

    Raises:
      ValueError: If `self` does not have a known value for every dimension.
    """
    if not self.is_fully_defined():
      raise ValueError("Shape %s is not fully defined" % self)

  def as_list(self):
    """Returns a list of integers or `None` for each dimension.

    Returns:
      A list of integers or `None` for each dimension.

    Raises:
      ValueError: If `self` is an unknown shape with an unknown rank.
    """
    if self._dims is None:
      raise ValueError("as_list() is not defined on an unknown TensorShape.")
    return [dim.value for dim in self._dims]

  def as_proto(self):
    """Returns this shape as a `TensorShapeProto`.

    Unknown dimensions are encoded with size -1; an unknown rank sets
    `unknown_rank=True`.
    """
    if self._dims is None:
      return tensor_shape_pb2.TensorShapeProto(unknown_rank=True)
    else:
      return tensor_shape_pb2.TensorShapeProto(dim=[
          tensor_shape_pb2.TensorShapeProto.Dim(size=-1
                                                if d.value is None else d.value)
          for d in self._dims
      ])

  def __eq__(self, other):
    """Returns True if `self` is equivalent to `other`.

    Returns `NotImplemented` if `other` cannot be converted to a
    `TensorShape`, mirroring `Dimension.__eq__`.
    """
    try:
      other = as_shape(other)
    except (TypeError, ValueError):
      # Conversion raises ValueError for e.g. negative or non-integral sizes.
      return NotImplemented
    return self._dims == other.dims

  def __ne__(self, other):
    """Returns True if `self` is known to be different from `other`.

    Raises:
      ValueError: If either shape has unknown rank.
    """
    try:
      other = as_shape(other)
    except (TypeError, ValueError):
      # Conversion raises ValueError for e.g. negative or non-integral sizes.
      return NotImplemented
    if self.ndims is None or other.ndims is None:
      raise ValueError("The inequality of unknown TensorShapes is undefined.")
    if self.ndims != other.ndims:
      return True
    return self._dims != other.dims
845
846
def as_shape(shape):
  """Coerces `shape` into a `TensorShape`.

  A `TensorShape` argument is returned as the very same object; anything else
  is passed to the `TensorShape` constructor.

  Args:
    shape: A `TensorShape`, or any value accepted by `TensorShape(...)`.

  Returns:
    A `TensorShape` representing `shape`.
  """
  if not isinstance(shape, TensorShape):
    shape = TensorShape(shape)
  return shape
853
854
def unknown_shape(ndims=None):
  """Builds a `TensorShape` with no known dimension sizes.

  Args:
    ndims: (Optional) The rank of the shape. When omitted (or `None`), the
      rank itself is unknown.

  Returns:
    A `TensorShape` of unknown rank if `ndims` is `None`, otherwise a shape
    with `ndims` unknown dimensions.
  """
  if ndims is None:
    dims = None
  else:
    # A single unknown Dimension object repeated `ndims` times.
    dims = [Dimension(None)] * ndims
  return TensorShape(dims)
868
869
def scalar():
  """Builds the rank-0 `TensorShape` (a shape with no dimensions)."""
  return TensorShape(())
873
874
def vector(length):
  """Builds a rank-1 `TensorShape`.

  Args:
    length: Size of the single dimension, or None when it is unknown.

  Returns:
    A `TensorShape` with exactly one dimension of size `length`.
  """
  return TensorShape((length,))
885
886
def matrix(rows, cols):
  """Builds a rank-2 `TensorShape`.

  Args:
    rows: Size of the first dimension, or None when it is unknown.
    cols: Size of the second dimension, or None when it is unknown.

  Returns:
    A `TensorShape` with dimensions `(rows, cols)`.
  """
  return TensorShape((rows, cols))
898