1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Helper classes for tensor shape inference."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import functools
21import operator
22import six
23
24from tensorflow.core.framework import tensor_shape_pb2
25from tensorflow.python import tf2
26from tensorflow.python.eager import monitoring
27from tensorflow.python.util.tf_export import tf_export
28
# Process-wide override for the TensorShape V2 behavior. `None` means "not set
# explicitly", in which case `TensorShape._v2_behavior` falls back to
# `tf2.enabled()`. `enable_v2_tensorshape()` sets it to True and
# `disable_v2_tensorshape()` sets it to False.
_TENSORSHAPE_V2_OVERRIDE = None

# Boolean usage gauge: `enable_v2_tensorshape()` sets it to True and
# `disable_v2_tensorshape()` sets it to False.
_api_usage_gauge = monitoring.BoolGauge(
    "/tensorflow/api/v2_tensorshape",
    "Whether tensor_shape.enable_v2_tensorshape() is called.")
34
35
@tf_export(v1=["enable_v2_tensorshape"])
def enable_v2_tensorshape():
  """In TensorFlow 2.0, iterating over a TensorShape instance returns values.

  This enables the new behavior.

  Concretely, `tensor_shape[i]` returned a Dimension instance in V1, but
  in V2 it returns either an integer, or None.

  Calling this function also records the call in the
  `/tensorflow/api/v2_tensorshape` usage gauge.

  Examples:

  ```
  #######################
  # If you had this in V1:
  value = tensor_shape[i].value

  # Do this in V2 instead:
  value = tensor_shape[i]

  #######################
  # If you had this in V1:
  for dim in tensor_shape:
    value = dim.value
    print(value)

  # Do this in V2 instead:
  for value in tensor_shape:
    print(value)

  #######################
  # If you had this in V1:
  dim = tensor_shape[i]
  dim.assert_is_compatible_with(other_shape)  # or using any other shape method

  # Do this in V2 instead:
  if tensor_shape.rank is None:
    dim = Dimension(None)
  else:
    dim = tensor_shape.dims[i]
  dim.assert_is_compatible_with(other_shape)  # or using any other shape method

  # The V2 suggestion above is more explicit, which will save you from
  # the following trap (present in V1):
  # you might do in-place modifications to `dim` and expect them to be reflected
  # in `tensor_shape[i]`, but they would not be.
  ```
  """
  global _TENSORSHAPE_V2_OVERRIDE  # pylint: disable=invalid-name
  # An explicit True takes precedence over `tf2.enabled()` in
  # `TensorShape._v2_behavior`.
  _TENSORSHAPE_V2_OVERRIDE = True
  _api_usage_gauge.get_cell().set(True)
86
87
@tf_export(v1=["disable_v2_tensorshape"])
def disable_v2_tensorshape():
  """Reverts `TensorShape` to the legacy V1 behavior.

  Once called, indexing or iterating a `TensorShape` yields `Dimension`
  objects again instead of plain integers / None, regardless of whether TF2
  behavior is otherwise enabled. See the docstring of `enable_v2_tensorshape`
  for a description of both behaviors.
  """
  global _TENSORSHAPE_V2_OVERRIDE  # pylint: disable=invalid-name
  # An explicit False takes precedence over `tf2.enabled()`.
  _TENSORSHAPE_V2_OVERRIDE = False
  usage_cell = _api_usage_gauge.get_cell()
  usage_cell.set(False)
97
98
@tf_export(
    "compat.dimension_value", v1=["dimension_value", "compat.dimension_value"])
def dimension_value(dimension):
  """Returns the plain value of a dimension under both V1 and V2 behavior.

  Until the release of TF 2.0, the legacy behavior of `TensorShape` (indexing
  yields `Dimension` objects) has to coexist with the new behavior (indexing
  yields integers or None). This utility bridges the two: it accepts either
  form and always hands back the plain value.

  ```
  # If you had this in your V1 code:
  value = tensor_shape[i].value

  # Use `dimension_value` as direct replacement compatible with both V1 & V2:
  value = dimension_value(tensor_shape[i])

  # This would be the V2 equivalent:
  value = tensor_shape[i]  # Warning: this will return the dim value in V2!
  ```

  Args:
    dimension: Either a `Dimension` instance, an integer, or None.

  Returns:
    `dimension.value` when `dimension` is a `Dimension`; otherwise
    `dimension` itself, unchanged (an integer or None).
  """
  if not isinstance(dimension, Dimension):
    # Already a plain (V2-style) value.
    return dimension
  return dimension.value
130
131
@tf_export(
    "compat.dimension_at_index",
    v1=["dimension_at_index", "compat.dimension_at_index"])
def dimension_at_index(shape, index):
  """Compatibility utility required to allow for both V1 and V2 behavior in TF.

  Until the release of TF 2.0, we need the legacy behavior of `TensorShape` to
  coexist with the new behavior. This utility is a bridge between the two.

  If you want to retrieve the Dimension instance corresponding to a certain
  index in a TensorShape instance, use this utility, like this:

  ```
  # If you had this in your V1 code:
  dim = tensor_shape[i]

  # Use `dimension_at_index` as direct replacement compatible with both V1 & V2:
  dim = dimension_at_index(tensor_shape, i)

  # Another possibility would be this, but WARNING: it only works if the
  # tensor_shape instance has a defined rank.
  dim = tensor_shape.dims[i]  # `dims` may be None if the rank is undefined!

  # In native V2 code, we recommend instead being more explicit:
  if tensor_shape.rank is None:
    dim = Dimension(None)
  else:
    dim = tensor_shape.dims[i]

  # Being more explicit will save you from the following trap (present in V1):
  # you might do in-place modifications to `dim` and expect them to be reflected
  # in `tensor_shape[i]`, but they would not be (as the Dimension object was
  # instantiated on the fly).
  ```

  Args:
    shape: A TensorShape instance.
    index: An integer index.

  Returns:
    A dimension object. If the rank of `shape` is unknown, a fresh
    `Dimension(None)` is returned regardless of `index`.

  Raises:
    AssertionError: If `shape` is not a `TensorShape` (only when assertions
      are enabled).
    IndexError: If the rank of `shape` is known and `index` is out of range.
  """
  assert isinstance(shape, TensorShape)
  if shape.rank is None:
    return Dimension(None)
  else:
    return shape.dims[index]
179
180
@tf_export(v1=["Dimension"])
class Dimension(object):
  """Represents the value of one dimension in a TensorShape.

  A `Dimension` holds either a known, non-negative integer size or `None` for
  an unknown size. The arithmetic operators `+`, `-`, `*`, `//` and `%` return
  `Dimension(None)` whenever either operand is unknown, and the comparison
  operators return `None` in that case. True division (`/`) is deliberately
  unsupported and raises a `TypeError` that points at `//`.
  """

  __slots__ = ["_value"]

  def __init__(self, value):
    """Creates a new Dimension with the given value.

    Args:
      value: A non-negative integer, `None` (unknown size), another
        `Dimension`, or any object implementing `__index__`.

    Raises:
      TypeError: If `value` is not one of the accepted types.
      ValueError: If `value` is negative.
    """
    if isinstance(value, int):  # Most common case.
      if value < 0:
        raise ValueError("Dimension %d must be >= 0" % value)
      self._value = value
    elif value is None:
      self._value = None
    elif isinstance(value, Dimension):
      self._value = value._value
    else:
      try:
        # int(...) compensates for the int/long dichotomy on Python 2.X.
        # TODO(b/143206389): Remove once we fully migrate to 3.X.
        self._value = int(value.__index__())
      except AttributeError:
        six.raise_from(
            TypeError("Dimension value must be integer or None or have "
                      "an __index__ method, got value '{0!r}' with type '{1!r}'"
                      .format(value, type(value))), None)
      if self._value < 0:
        raise ValueError("Dimension %d must be >= 0" % self._value)

  def __repr__(self):
    return "Dimension(%s)" % repr(self._value)

  def __str__(self):
    value = self._value
    return "?" if value is None else str(value)

  def __eq__(self, other):
    """Returns true if `other` has the same known value as this Dimension.

    Returns `None` (not `False`) if either value is unknown, and
    `NotImplemented` if `other` cannot be converted to a `Dimension`.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is None or other.value is None:
      return None
    return self._value == other.value

  def __ne__(self, other):
    """Returns true if `other` has a different known value from `self`.

    Returns `None` (not `False`) if either value is unknown, and
    `NotImplemented` if `other` cannot be converted to a `Dimension`.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is None or other.value is None:
      return None
    return self._value != other.value

  def __bool__(self):
    """Equivalent to `bool(self.value)`."""
    return bool(self._value)

  # Python 3 wants __bool__, Python 2.7 wants __nonzero__. Without this alias,
  # Python 2 would treat every Dimension (including Dimension(0)) as truthy.
  __nonzero__ = __bool__

  def __int__(self):
    return self._value

  # This is needed for Windows.
  # See https://github.com/tensorflow/tensorflow/pull/9780
  def __long__(self):
    return self._value

  def __index__(self):
    # Allow use in Python 3 range
    return self._value

  @property
  def value(self):
    """The value of this dimension, or None if it is unknown."""
    return self._value

  def is_compatible_with(self, other):
    """Returns true if `other` is compatible with this Dimension.

    Two known Dimensions are compatible if they have the same value.
    An unknown Dimension is compatible with all other Dimensions.

    Args:
      other: Another Dimension.

    Returns:
      True if this Dimension and `other` are compatible.
    """
    other = as_dimension(other)
    return (self._value is None or other.value is None or
            self._value == other.value)

  def assert_is_compatible_with(self, other):
    """Raises an exception if `other` is not compatible with this Dimension.

    Args:
      other: Another Dimension.

    Raises:
      ValueError: If `self` and `other` are not compatible (see
        is_compatible_with).
    """
    if not self.is_compatible_with(other):
      raise ValueError("Dimensions %s and %s are not compatible" %
                       (self, other))

  def merge_with(self, other):
    """Returns a Dimension that combines the information in `self` and `other`.

    Dimensions are combined as follows:

    ```python
    tf.compat.v1.Dimension(n)   .merge_with(tf.compat.v1.Dimension(n))     ==
    tf.compat.v1.Dimension(n)
    tf.compat.v1.Dimension(n)   .merge_with(tf.compat.v1.Dimension(None))  ==
    tf.compat.v1.Dimension(n)
    tf.compat.v1.Dimension(None).merge_with(tf.compat.v1.Dimension(n))     ==
    tf.compat.v1.Dimension(n)
    # equivalent to tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None).merge_with(tf.compat.v1.Dimension(None))

    # raises ValueError for n != m
    tf.compat.v1.Dimension(n)   .merge_with(tf.compat.v1.Dimension(m))
    ```

    Args:
      other: Another Dimension.

    Returns:
      A Dimension containing the combined information of `self` and
      `other`.

    Raises:
      ValueError: If `self` and `other` are not compatible (see
        is_compatible_with).
    """
    other = as_dimension(other)
    self.assert_is_compatible_with(other)
    if self._value is None:
      return Dimension(other.value)
    else:
      return Dimension(self._value)

  def __add__(self, other):
    """Returns the sum of `self` and `other`.

    Dimensions are summed as follows:

    ```python
    tf.compat.v1.Dimension(m)    + tf.compat.v1.Dimension(n)     ==
    tf.compat.v1.Dimension(m + n)
    tf.compat.v1.Dimension(m)    + tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) + tf.compat.v1.Dimension(n)     # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) + tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the sum of `self` and `other`, or
      `NotImplemented` if `other` cannot be converted to a `Dimension`.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      # Let Python try `other.__radd__` instead.
      return NotImplemented
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value + other.value)

  def __radd__(self, other):
    """Returns the sum of `other` and `self`.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the sum of `self` and `other`.
    """
    return self + other

  def __sub__(self, other):
    """Returns the subtraction of `other` from `self`.

    Dimensions are subtracted as follows:

    ```python
    tf.compat.v1.Dimension(m)    - tf.compat.v1.Dimension(n)     ==
    tf.compat.v1.Dimension(m - n)
    tf.compat.v1.Dimension(m)    - tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) - tf.compat.v1.Dimension(n)     # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) - tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the subtraction of `other` from `self`, or
      `NotImplemented` if `other` cannot be converted to a `Dimension`.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value - other.value)

  def __rsub__(self, other):
    """Returns the subtraction of `self` from `other`.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the subtraction of `self` from `other`.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(other.value - self._value)

  def __mul__(self, other):
    """Returns the product of `self` and `other`.

    Dimensions are multiplied as follows:

    ```python
    tf.compat.v1.Dimension(m)    * tf.compat.v1.Dimension(n)     ==
    tf.compat.v1.Dimension(m * n)
    tf.compat.v1.Dimension(m)    * tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) * tf.compat.v1.Dimension(n)     # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) * tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the product of `self` and `other`, or
      `NotImplemented` if `other` cannot be converted to a `Dimension`.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented

    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value * other.value)

  def __rmul__(self, other):
    """Returns the product of `self` and `other`.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the product of `self` and `other`.
    """
    return self * other

  def __floordiv__(self, other):
    """Returns the quotient of `self` and `other` rounded down.

    Dimensions are divided as follows:

    ```python
    tf.compat.v1.Dimension(m)    // tf.compat.v1.Dimension(n)     ==
    tf.compat.v1.Dimension(m // n)
    tf.compat.v1.Dimension(m)    // tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) // tf.compat.v1.Dimension(n)     # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) // tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A `Dimension` whose value is the integer quotient of `self` and `other`,
      or `NotImplemented` if `other` cannot be converted to a `Dimension`.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value // other.value)

  def __rfloordiv__(self, other):
    """Returns the quotient of `other` and `self` rounded down.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A `Dimension` whose value is the integer quotient of `other` and `self`.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(other.value // self._value)

  def __div__(self, other):
    """DEPRECATED: Use `__floordiv__` via `x // y` instead.

    This function exists only for backwards compatibility purposes; new code
    should use `__floordiv__` via the syntax `x // y`.  Using `x // y`
    communicates clearly that the result rounds down, and is forward compatible
    to Python 3.

    Args:
      other: Another `Dimension`.

    Returns:
      A `Dimension` whose value is the integer quotient of `self` and `other`.
    """
    return self // other

  def __rdiv__(self, other):
    """Use `__floordiv__` via `x // y` instead.

    This function exists only to have a better error message. Instead of:
    `TypeError: unsupported operand type(s) for /: 'int' and 'Dimension'`,
    this function will explicitly call for usage of `//` instead.

    Args:
      other: Another `Dimension`.

    Raises:
      TypeError.
    """
    raise TypeError("unsupported operand type(s) for /: '{}' and 'Dimension', "
                    "please use // instead".format(type(other).__name__))

  def __truediv__(self, other):
    """Use `__floordiv__` via `x // y` instead.

    This function exists only to have a better error message. Instead of:
    `TypeError: unsupported operand type(s) for /: 'Dimension' and 'int'`,
    this function will explicitly call for usage of `//` instead.

    Args:
      other: Another `Dimension`.

    Raises:
      TypeError.
    """
    raise TypeError("unsupported operand type(s) for /: 'Dimension' and '{}', "
                    "please use // instead".format(type(other).__name__))

  def __rtruediv__(self, other):
    """Use `__floordiv__` via `x // y` instead.

    This function exists only to have a better error message. Instead of:
    `TypeError: unsupported operand type(s) for /: 'int' and 'Dimension'`,
    this function will explicitly call for usage of `//` instead.

    Args:
      other: Another `Dimension`.

    Raises:
      TypeError.
    """
    raise TypeError("unsupported operand type(s) for /: '{}' and 'Dimension', "
                    "please use // instead".format(type(other).__name__))

  def __mod__(self, other):
    """Returns `self` modulo `other`.

    Dimension modulo are computed as follows:

    ```python
    tf.compat.v1.Dimension(m)    % tf.compat.v1.Dimension(n)     ==
    tf.compat.v1.Dimension(m % n)
    tf.compat.v1.Dimension(m)    % tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) % tf.compat.v1.Dimension(n)     # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) % tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is `self` modulo `other`, or `NotImplemented`
      if `other` cannot be converted to a `Dimension` (consistent with
      `__add__`, `__sub__`, `__mul__` and `__floordiv__`, so that
      `other.__rmod__` gets a chance to handle the operation).
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value % other.value)

  def __rmod__(self, other):
    """Returns `other` modulo `self`.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is `other` modulo `self`.
    """
    other = as_dimension(other)
    return other % self

  def __lt__(self, other):
    """Returns True if `self` is known to be less than `other`.

    Dimensions are compared as follows:

    ```python
    (tf.compat.v1.Dimension(m)    < tf.compat.v1.Dimension(n))    == (m < n)
    (tf.compat.v1.Dimension(m)    < tf.compat.v1.Dimension(None)) == None
    (tf.compat.v1.Dimension(None) < tf.compat.v1.Dimension(n))    == None
    (tf.compat.v1.Dimension(None) < tf.compat.v1.Dimension(None)) == None
    ```

    Args:
      other: Another Dimension.

    Returns:
      The value of `self.value < other.value` if both are known, otherwise
      None.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return None
    else:
      return self._value < other.value

  def __le__(self, other):
    """Returns True if `self` is known to be less than or equal to `other`.

    Dimensions are compared as follows:

    ```python
    (tf.compat.v1.Dimension(m)    <= tf.compat.v1.Dimension(n))    == (m <= n)
    (tf.compat.v1.Dimension(m)    <= tf.compat.v1.Dimension(None)) == None
    (tf.compat.v1.Dimension(None) <= tf.compat.v1.Dimension(n))    == None
    (tf.compat.v1.Dimension(None) <= tf.compat.v1.Dimension(None)) == None
    ```

    Args:
      other: Another Dimension.

    Returns:
      The value of `self.value <= other.value` if both are known, otherwise
      None.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return None
    else:
      return self._value <= other.value

  def __gt__(self, other):
    """Returns True if `self` is known to be greater than `other`.

    Dimensions are compared as follows:

    ```python
    (tf.compat.v1.Dimension(m)    > tf.compat.v1.Dimension(n))    == (m > n)
    (tf.compat.v1.Dimension(m)    > tf.compat.v1.Dimension(None)) == None
    (tf.compat.v1.Dimension(None) > tf.compat.v1.Dimension(n))    == None
    (tf.compat.v1.Dimension(None) > tf.compat.v1.Dimension(None)) == None
    ```

    Args:
      other: Another Dimension.

    Returns:
      The value of `self.value > other.value` if both are known, otherwise
      None.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return None
    else:
      return self._value > other.value

  def __ge__(self, other):
    """Returns True if `self` is known to be greater than or equal to `other`.

    Dimensions are compared as follows:

    ```python
    (tf.compat.v1.Dimension(m)    >= tf.compat.v1.Dimension(n))    == (m >= n)
    (tf.compat.v1.Dimension(m)    >= tf.compat.v1.Dimension(None)) == None
    (tf.compat.v1.Dimension(None) >= tf.compat.v1.Dimension(n))    == None
    (tf.compat.v1.Dimension(None) >= tf.compat.v1.Dimension(None)) == None
    ```

    Args:
      other: Another Dimension.

    Returns:
      The value of `self.value >= other.value` if both are known, otherwise
      None.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return None
    else:
      return self._value >= other.value

  def __reduce__(self):
    # Pickle support: rebuild from the raw value (the class uses __slots__).
    return Dimension, (self._value,)
708
709
def as_dimension(value):
  """Coerces `value` into a `Dimension`.

  An existing `Dimension` is passed through as the very same object (no copy
  is made). Anything else -- `None`, an integer, or an object with
  `__index__` -- is handed to the `Dimension` constructor, which raises
  `TypeError` / `ValueError` for unsupported or negative values.

  Args:
    value: The value to be converted.

  Returns:
    A Dimension corresponding to the given value.
  """
  return value if isinstance(value, Dimension) else Dimension(value)
727
728
729@tf_export("TensorShape")
730class TensorShape(object):
731  """Represents the shape of a `Tensor`.
732
733  A `TensorShape` represents a possibly-partial shape specification for a
734  `Tensor`. It may be one of the following:
735
736  * *Fully-known shape:* has a known number of dimensions and a known size
737    for each dimension. e.g. `TensorShape([16, 256])`
738  * *Partially-known shape:* has a known number of dimensions, and an unknown
739    size for one or more dimension. e.g. `TensorShape([None, 256])`
740  * *Unknown shape:* has an unknown number of dimensions, and an unknown
741    size in all dimensions. e.g. `TensorShape(None)`
742
743  If a tensor is produced by an operation of type `"Foo"`, its shape
744  may be inferred if there is a registered shape function for
745  `"Foo"`. See [Shape
746  functions](https://tensorflow.org/extend/adding_an_op#shape_functions_in_c)
747  for details of shape functions and how to register them. Alternatively,
748  the shape may be set explicitly using `tf.Tensor.set_shape`.
749  """
750  __slots__ = ["_dims"]
751
752  def __init__(self, dims):
753    """Creates a new TensorShape with the given dimensions.
754
755    Args:
756      dims: A list of Dimensions, or None if the shape is unspecified.
757
758    Raises:
759      TypeError: If dims cannot be converted to a list of dimensions.
760    """
761    if isinstance(dims, (tuple, list)):  # Most common case.
762      self._dims = [Dimension(d) for d in dims]
763    elif dims is None:
764      self._dims = None
765    elif isinstance(dims, tensor_shape_pb2.TensorShapeProto):
766      if dims.unknown_rank:
767        self._dims = None
768      else:
769        self._dims = [
770            # Protos store variable-size dimensions as -1
771            as_dimension(dim.size if dim.size != -1 else None)
772            for dim in dims.dim
773        ]
774    elif isinstance(dims, TensorShape):
775      self._dims = dims.dims
776    else:
777      try:
778        dims_iter = iter(dims)
779      except TypeError:
780        # Treat as a singleton dimension
781        self._dims = [as_dimension(dims)]
782      else:
783        self._dims = []
784        for d in dims_iter:
785          try:
786            self._dims.append(as_dimension(d))
787          except TypeError as e:
788            six.raise_from(
789                TypeError(
790                    "Failed to convert '{0!r}' to a shape: '{1!r}'"
791                    "could not be converted to a dimension. A shape should "
792                    "either be single dimension (e.g. 10), or an iterable of "
793                    "dimensions (e.g. [1, 10, None])."
794                    .format(dims, d)), e)
795
796  @property
797  def _v2_behavior(self):
798    if _TENSORSHAPE_V2_OVERRIDE is None:
799      return tf2.enabled()
800    return _TENSORSHAPE_V2_OVERRIDE
801
802  def __repr__(self):
803    if self._v2_behavior:
804      if self._dims is not None:
805        return "TensorShape(%r)" % [dim.value for dim in self._dims]
806      else:
807        return "TensorShape(None)"
808    else:
809      return "TensorShape(%r)" % self._dims
810
811  def __str__(self):
812    if self.rank is None:
813      return "<unknown>"
814    elif self.rank == 1:
815      if self._v2_behavior:
816        return "(%s,)" % self._dims[0].value
817      else:
818        return "(%s,)" % self._dims[0]
819    else:
820      if self._v2_behavior:
821        return "(%s)" % ", ".join(str(d.value) for d in self._dims)
822      else:
823        return "(%s)" % ", ".join(str(d) for d in self._dims)
824
825  @property
826  def rank(self):
827    """Returns the rank of this shape, or None if it is unspecified."""
828    if self._dims is not None:
829      return len(self._dims)
830    return None
831
832  @property
833  def dims(self):
834    """Deprecated.  Returns list of dimensions for this shape.
835
836    Suggest `TensorShape.as_list` instead.
837
838    Returns:
839      A list containing `tf.compat.v1.Dimension`s, or None if the shape is
840      unspecified.
841    """
842    return self._dims
843
844  @property
845  def ndims(self):
846    """Deprecated accessor for `rank`."""
847    return self.rank
848
849  def __len__(self):
850    """Returns the rank of this shape, or raises ValueError if unspecified."""
851    if self._dims is None:
852      raise ValueError("Cannot take the length of shape with unknown rank.")
853    return len(self._dims)
854
855  def __bool__(self):
856    """Returns True if this shape contains non-zero information."""
857    return self._dims is not None
858
859  # Python 3 wants __bool__, Python 2.7 wants __nonzero__
860  __nonzero__ = __bool__
861
862  def __iter__(self):
863    """Returns `self.dims` if the rank is known, otherwise raises ValueError."""
864    if self._dims is None:
865      raise ValueError("Cannot iterate over a shape with unknown rank.")
866    else:
867      if self._v2_behavior:
868        return iter(d.value for d in self._dims)
869      else:
870        return iter(d for d in self._dims)
871
872  def __getitem__(self, key):
873    """Returns the value of a dimension or a shape, depending on the key.
874
875    Args:
876      key: If `key` is an integer, returns the dimension at that index;
877        otherwise if `key` is a slice, returns a TensorShape whose dimensions
878        are those selected by the slice from `self`.
879
880    Returns:
881      An integer if `key` is an integer, or a `TensorShape` if `key` is a
882      slice.
883
884    Raises:
885      ValueError: If `key` is a slice and `self` is completely unknown and
886        the step is set.
887    """
888    if self._dims is not None:
889      if isinstance(key, slice):
890        return TensorShape(self._dims[key])
891      else:
892        if self._v2_behavior:
893          return self._dims[key].value
894        else:
895          return self._dims[key]
896    else:
897      if isinstance(key, slice):
898        start = key.start if key.start is not None else 0
899        stop = key.stop
900
901        if key.step is not None:
902          # TODO(mrry): Handle these maybe.
903          raise ValueError("Steps are not yet handled")
904        if stop is None:
905          # NOTE(mrry): This implies that TensorShape(None) is compatible with
906          # TensorShape(None)[1:], which is obviously not true. It would be
907          # possible to track the number of dimensions symbolically,
908          # and perhaps we should do that.
909          return unknown_shape()
910        elif start < 0 or stop < 0:
911          # TODO(mrry): Handle this better, as it will be useful for handling
912          # suffixes of otherwise unknown shapes.
913          return unknown_shape()
914        else:
915          return unknown_shape(rank=stop - start)
916      else:
917        if self._v2_behavior:
918          return None
919        else:
920          return Dimension(None)
921
922  def num_elements(self):
923    """Returns the total number of elements, or none for incomplete shapes."""
924    if self.is_fully_defined():
925      return functools.reduce(operator.mul, self.as_list(), 1)
926    else:
927      return None
928
929  def merge_with(self, other):
930    """Returns a `TensorShape` combining the information in `self` and `other`.
931
932    The dimensions in `self` and `other` are merged element-wise,
933    according to the rules below:
934
935    ```python
936    Dimension(n).merge_with(Dimension(None)) == Dimension(n)
937    Dimension(None).merge_with(Dimension(n)) == Dimension(n)
938    Dimension(None).merge_with(Dimension(None)) == Dimension(None)
939    # raises ValueError for n != m
940    Dimension(n).merge_with(Dimension(m))
941    ```
942    >> ts = tf.TensorShape([1,2])
943    >> ot1 = tf.TensorShape([1,2])
944    >> ts.merge_with(ot).as_list()
945    [1,2]
946
947    >> ot2 = tf.TensorShape([1,None])
948    >> ts.merge_with(ot2).as_list()
949    [1,2]
950
951    >> ot3 = tf.TensorShape([None, None])
952    >> ot3.merge_with(ot2).as_list()
953    [1, None]
954
955    Args:
956      other: Another `TensorShape`.
957
958    Returns:
959      A `TensorShape` containing the combined information of `self` and
960      `other`.
961
962    Raises:
963      ValueError: If `self` and `other` are not compatible.
964    """
965    other = as_shape(other)
966    if self._dims is None:
967      return other
968    if other.dims is None:
969      return self
970    else:
971      try:
972        self.assert_same_rank(other)
973        new_dims = [
974            dim.merge_with(other_dim)
975            for dim, other_dim in zip(self._dims, other.dims)
976        ]
977        return TensorShape(new_dims)
978      except ValueError:
979        raise ValueError("Shapes %s and %s are not compatible" % (self, other))
980
981  def __add__(self, other):
982    return self.concatenate(other)
983
984  def __radd__(self, other):
985    if not isinstance(other, TensorShape):
986      other = TensorShape(other)
987    return other.concatenate(self)
988
989  def concatenate(self, other):
990    """Returns the concatenation of the dimension in `self` and `other`.
991
992    *N.B.* If either `self` or `other` is completely unknown,
993    concatenation will discard information about the other shape. In
994    future, we might support concatenation that preserves this
995    information for use with slicing.
996
997    Args:
998      other: Another `TensorShape`.
999
1000    Returns:
1001      A `TensorShape` whose dimensions are the concatenation of the
1002      dimensions in `self` and `other`.
1003    """
1004    # TODO(mrry): Handle the case where we concatenate a known shape with a
1005    # completely unknown shape, so that we can use the partial information.
1006    other = as_shape(other)
1007    if self._dims is None or other.dims is None:
1008      return unknown_shape()
1009    else:
1010      return TensorShape(self._dims + other.dims)
1011
1012  def assert_same_rank(self, other):
1013    """Raises an exception if `self` and `other` do not have compatible ranks.
1014
1015    Args:
1016      other: Another `TensorShape`.
1017
1018    Raises:
1019      ValueError: If `self` and `other` do not represent shapes with the
1020        same rank.
1021    """
1022    other = as_shape(other)
1023    if self.rank is not None and other.rank is not None:
1024      if self.rank != other.rank:
1025        raise ValueError("Shapes %s and %s must have the same rank" %
1026                         (self, other))
1027
1028  def assert_has_rank(self, rank):
1029    """Raises an exception if `self` is not compatible with the given `rank`.
1030
1031    Args:
1032      rank: An integer.
1033
1034    Raises:
1035      ValueError: If `self` does not represent a shape with the given `rank`.
1036    """
1037    if self.rank not in (None, rank):
1038      raise ValueError("Shape %s must have rank %d" % (self, rank))
1039
1040  def with_rank(self, rank):
1041    """Returns a shape based on `self` with the given rank.
1042
1043    This method promotes a completely unknown shape to one with a
1044    known rank.
1045
1046    Args:
1047      rank: An integer.
1048
1049    Returns:
1050      A shape that is at least as specific as `self` with the given rank.
1051
1052    Raises:
1053      ValueError: If `self` does not represent a shape with the given `rank`.
1054    """
1055    try:
1056      return self.merge_with(unknown_shape(rank=rank))
1057    except ValueError:
1058      raise ValueError("Shape %s must have rank %d" % (self, rank))
1059
1060  def with_rank_at_least(self, rank):
1061    """Returns a shape based on `self` with at least the given rank.
1062
1063    Args:
1064      rank: An integer.
1065
1066    Returns:
1067      A shape that is at least as specific as `self` with at least the given
1068      rank.
1069
1070    Raises:
1071      ValueError: If `self` does not represent a shape with at least the given
1072        `rank`.
1073    """
1074    if self.rank is not None and self.rank < rank:
1075      raise ValueError("Shape %s must have rank at least %d" % (self, rank))
1076    else:
1077      return self
1078
1079  def with_rank_at_most(self, rank):
1080    """Returns a shape based on `self` with at most the given rank.
1081
1082    Args:
1083      rank: An integer.
1084
1085    Returns:
1086      A shape that is at least as specific as `self` with at most the given
1087      rank.
1088
1089    Raises:
1090      ValueError: If `self` does not represent a shape with at most the given
1091        `rank`.
1092    """
1093    if self.rank is not None and self.rank > rank:
1094      raise ValueError("Shape %s must have rank at most %d" % (self, rank))
1095    else:
1096      return self
1097
1098  def is_compatible_with(self, other):
1099    """Returns True iff `self` is compatible with `other`.
1100
1101    Two possibly-partially-defined shapes are compatible if there
1102    exists a fully-defined shape that both shapes can represent. Thus,
1103    compatibility allows the shape inference code to reason about
1104    partially-defined shapes. For example:
1105
1106    * TensorShape(None) is compatible with all shapes.
1107
1108    * TensorShape([None, None]) is compatible with all two-dimensional
1109      shapes, such as TensorShape([32, 784]), and also TensorShape(None). It is
1110      not compatible with, for example, TensorShape([None]) or
1111      TensorShape([None, None, None]).
1112
1113    * TensorShape([32, None]) is compatible with all two-dimensional shapes
1114      with size 32 in the 0th dimension, and also TensorShape([None, None])
1115      and TensorShape(None). It is not compatible with, for example,
1116      TensorShape([32]), TensorShape([32, None, 1]) or TensorShape([64, None]).
1117
1118    * TensorShape([32, 784]) is compatible with itself, and also
1119      TensorShape([32, None]), TensorShape([None, 784]), TensorShape([None,
1120      None]) and TensorShape(None). It is not compatible with, for example,
1121      TensorShape([32, 1, 784]) or TensorShape([None]).
1122
1123    The compatibility relation is reflexive and symmetric, but not
1124    transitive. For example, TensorShape([32, 784]) is compatible with
1125    TensorShape(None), and TensorShape(None) is compatible with
1126    TensorShape([4, 4]), but TensorShape([32, 784]) is not compatible with
1127    TensorShape([4, 4]).
1128
1129    Args:
1130      other: Another TensorShape.
1131
1132    Returns:
1133      True iff `self` is compatible with `other`.
1134
1135    """
1136    other = as_shape(other)
1137    if self._dims is not None and other.dims is not None:
1138      if self.rank != other.rank:
1139        return False
1140      for x_dim, y_dim in zip(self._dims, other.dims):
1141        if not x_dim.is_compatible_with(y_dim):
1142          return False
1143    return True
1144
1145  def assert_is_compatible_with(self, other):
1146    """Raises exception if `self` and `other` do not represent the same shape.
1147
1148    This method can be used to assert that there exists a shape that both
1149    `self` and `other` represent.
1150
1151    Args:
1152      other: Another TensorShape.
1153
1154    Raises:
1155      ValueError: If `self` and `other` do not represent the same shape.
1156    """
1157    if not self.is_compatible_with(other):
1158      raise ValueError("Shapes %s and %s are incompatible" % (self, other))
1159
1160  def most_specific_compatible_shape(self, other):
1161    """Returns the most specific TensorShape compatible with `self` and `other`.
1162
1163    * TensorShape([None, 1]) is the most specific TensorShape compatible with
1164      both TensorShape([2, 1]) and TensorShape([5, 1]). Note that
1165      TensorShape(None) is also compatible with above mentioned TensorShapes.
1166
1167    * TensorShape([1, 2, 3]) is the most specific TensorShape compatible with
1168      both TensorShape([1, 2, 3]) and TensorShape([1, 2, 3]). There are more
1169      less specific TensorShapes compatible with above mentioned TensorShapes,
1170      e.g. TensorShape([1, 2, None]), TensorShape(None).
1171
1172    Args:
1173      other: Another `TensorShape`.
1174
1175    Returns:
1176      A `TensorShape` which is the most specific compatible shape of `self`
1177      and `other`.
1178    """
1179
1180    other = as_shape(other)
1181    if self._dims is None or other.dims is None or self.rank != other.rank:
1182      return unknown_shape()
1183
1184    dims = [
1185        d1 if d1 is not None and d2 is not None and d1 == d2 else None
1186        for d1, d2 in zip(self._dims, other.dims)
1187    ]
1188    return TensorShape(dims)
1189
1190  def is_fully_defined(self):
1191    """Returns True iff `self` is fully defined in every dimension."""
1192    return (self._dims is not None and
1193            all(dim.value is not None for dim in self._dims))
1194
1195  def assert_is_fully_defined(self):
1196    """Raises an exception if `self` is not fully defined in every dimension.
1197
1198    Raises:
1199      ValueError: If `self` does not have a known value for every dimension.
1200    """
1201    if not self.is_fully_defined():
1202      raise ValueError("Shape %s is not fully defined" % self)
1203
1204  def as_list(self):
1205    """Returns a list of integers or `None` for each dimension.
1206
1207    Returns:
1208      A list of integers or `None` for each dimension.
1209
1210    Raises:
1211      ValueError: If `self` is an unknown shape with an unknown rank.
1212    """
1213    if self._dims is None:
1214      raise ValueError("as_list() is not defined on an unknown TensorShape.")
1215    return [dim.value for dim in self._dims]
1216
1217  def as_proto(self):
1218    """Returns this shape as a `TensorShapeProto`."""
1219    if self._dims is None:
1220      return tensor_shape_pb2.TensorShapeProto(unknown_rank=True)
1221    else:
1222      return tensor_shape_pb2.TensorShapeProto(dim=[
1223          tensor_shape_pb2.TensorShapeProto.Dim(
1224              size=-1 if d.value is None else d.value) for d in self._dims
1225      ])
1226
1227  def __eq__(self, other):
1228    """Returns True if `self` is equivalent to `other`.
1229
1230    It first tries to convert `other` to `TensorShape`. `TypeError` is thrown
1231    when the conversion fails. Otherwise, it compares each element in the
1232    TensorShape dimensions.
1233
1234    * Two *Fully known* shapes, return True iff each element is equal.
1235    >>> t_a = tf.TensorShape([1,2])
1236    >>> a = [1, 2]
1237    >>> t_b = tf.TensorShape([1,2])
1238    >>> t_c = tf.TensorShape([1,2,3])
1239    >>> t_a.__eq__(a)
1240    True
1241    >>> t_a.__eq__(t_b)
1242    True
1243    >>> t_a.__eq__(t_c)
1244    False
1245
1246    * Two *Partially-known* shapes, return False.
1247    >>> p_a = tf.TensorShape([1,None])
1248    >>> p_b = tf.TensorShape([2,None])
1249    >>> p_a.__eq__(p_b)
1250    False
1251    >>> t_a.__eq__(p_a)
1252    False
1253
1254    * Two *Unknown shape*, return True.
1255    >>> unk_a = tf.TensorShape(None)
1256    >>> unk_b = tf.TensorShape(None)
1257    >>> unk_a.__eq__(unk_b)
1258    True
1259    >>> unk_a.__eq__(t_a)
1260    False
1261
1262    Args:
1263      other: A `TensorShape` or type that can be converted to `TensorShape`.
1264
1265    Returns:
1266      True if the dimensions are all equal.
1267
1268    Raises:
1269      TypeError if `other` can not be converted to `TensorShape`.
1270    """
1271
1272    try:
1273      other = as_shape(other)
1274    except TypeError:
1275      return NotImplemented
1276    return self._dims == other.dims
1277
1278  def __ne__(self, other):
1279    """Returns True if `self` is known to be different from `other`."""
1280    try:
1281      other = as_shape(other)
1282    except TypeError:
1283      return NotImplemented
1284    if self.rank is None or other.rank is None:
1285      raise ValueError("The inequality of unknown TensorShapes is undefined.")
1286    if self.rank != other.rank:
1287      return True
1288    return self._dims != other.dims
1289
1290  def __reduce__(self):
1291    return TensorShape, (self._dims,)
1292
1293  def __concat__(self, other):
1294    return self.concatenate(other)
1295
1296
def as_shape(shape):
  """Returns `shape` as a `TensorShape`.

  An existing `TensorShape` is returned as the same object. Any other value
  is passed to the `TensorShape` constructor.
  """
  return shape if isinstance(shape, TensorShape) else TensorShape(shape)
1303
1304
def unknown_shape(rank=None, **kwargs):
  """Builds a `TensorShape` whose dimensions are all unknown.

  Args:
    rank: (Optional) If specified, the number of dimensions in the shape.
      When omitted, the rank itself is unknown as well.
    **kwargs: For backwards compatibility. Only `ndims` is accepted, as an
      alias for `rank`, and only when `rank` is not given.

  Returns:
    An unknown TensorShape.

  Raises:
    TypeError: If any keyword argument other than a usable `ndims` remains.
  """
  if rank is None and "ndims" in kwargs:
    rank = kwargs.pop("ndims")
  if kwargs:
    raise TypeError("Unknown argument: %s" % kwargs)
  if rank is not None:
    return TensorShape([Dimension(None)] * rank)
  return TensorShape(None)
1326