1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Helper classes for tensor shape inference."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.core.framework import tensor_shape_pb2
21from tensorflow.python import tf2
22from tensorflow.python.framework import dtypes
23from tensorflow.python.util import compat
24from tensorflow.python.util.tf_export import tf_export
25
26
# Tri-state switch for TensorShape's V2 behavior: `None` defers to
# `tf2.enabled()`, while `enable_v2_tensorshape()` / `disable_v2_tensorshape()`
# force it to True / False respectively.
_TENSORSHAPE_V2_OVERRIDE = None
28
29
@tf_export(v1=["enable_v2_tensorshape"])
def enable_v2_tensorshape():
  """In TensorFlow 2.0, iterating over a TensorShape instance returns values.

  This enables the new behavior. The setting is a module-level override that
  takes precedence over `tf2.enabled()` until `disable_v2_tensorshape()` is
  called.

  Concretely, `tensor_shape[i]` returns a Dimension instance in V1, but
  in V2 it returns either an integer, or None.

  Examples:

  ```
  #######################
  # If you had this in V1:
  value = tensor_shape[i].value

  # Do this in V2 instead:
  value = tensor_shape[i]

  #######################
  # If you had this in V1:
  for dim in tensor_shape:
    value = dim.value
    print(value)

  # Do this in V2 instead:
  for value in tensor_shape:
    print(value)

  #######################
  # If you had this in V1:
  dim = tensor_shape[i]
  dim.assert_is_compatible_with(other_shape)  # or using any other shape method

  # Do this in V2 instead:
  if tensor_shape.rank is None:
    dim = Dimension(None)
  else:
    dim = tensor_shape.dims[i]
  dim.assert_is_compatible_with(other_shape)  # or using any other shape method

  # The V2 suggestion above is more explicit, which will save you from
  # the following trap (present in V1):
  # you might do in-place modifications to `dim` and expect them to be reflected
  # in `tensor_shape[i]`, but they would not be.
  ```
  """
  global _TENSORSHAPE_V2_OVERRIDE  # pylint: disable=invalid-name
  _TENSORSHAPE_V2_OVERRIDE = True
79
80
@tf_export(v1=["disable_v2_tensorshape"])
def disable_v2_tensorshape():
  """Disables the V2 TensorShape behavior and reverts to V1 behavior.

  After this call, indexing or iterating a `TensorShape` yields `Dimension`
  objects regardless of `tf2.enabled()`. See docstring for
  `enable_v2_tensorshape` for details about the new behavior.
  """
  # Rebind the module-level override; the `global` declaration is required
  # because this is an assignment, not a read.
  global _TENSORSHAPE_V2_OVERRIDE  # pylint: disable=invalid-name
  _TENSORSHAPE_V2_OVERRIDE = False
89
90
@tf_export("compat.dimension_value",
           v1=["dimension_value", "compat.dimension_value"])
def dimension_value(dimension):
  """Compatibility utility required to allow for both V1 and V2 behavior in TF.

  Until the release of TF 2.0, we need the legacy behavior of `TensorShape` to
  coexist with the new behavior. This utility is a bridge between the two.

  When accessing the value of a TensorShape dimension,
  use this utility, like this:

  ```
  # If you had this in your V1 code:
  value = tensor_shape[i].value

  # Use `dimension_value` as direct replacement compatible with both V1 & V2:
  value = dimension_value(tensor_shape[i])

  # This would be the V2 equivalent:
  value = tensor_shape[i]  # Warning: this will return the dim value in V2!
  ```

  Arguments:
    dimension: Either a `Dimension` instance, an integer, or None.

  Returns:
    A plain value, i.e. an integer or None. Non-`Dimension` inputs are
    returned unchanged.
  """
  # Unwrap Dimension objects; everything else is already a plain value.
  return dimension.value if isinstance(dimension, Dimension) else dimension
122
123
@tf_export("compat.dimension_at_index",
           v1=["dimension_at_index", "compat.dimension_at_index"])
def dimension_at_index(shape, index):
  """Compatibility utility required to allow for both V1 and V2 behavior in TF.

  Until the release of TF 2.0, we need the legacy behavior of `TensorShape` to
  coexist with the new behavior. This utility is a bridge between the two.

  If you want to retrieve the Dimension instance corresponding to a certain
  index in a TensorShape instance, use this utility, like this:

  ```
  # If you had this in your V1 code:
  dim = tensor_shape[i]

  # Use `dimension_at_index` as direct replacement compatible with both V1 & V2:
  dim = dimension_at_index(tensor_shape, i)

  # Another possibility would be this, but WARNING: it only works if the
  # tensor_shape instance has a defined rank.
  dim = tensor_shape.dims[i]  # `dims` may be None if the rank is undefined!

  # In native V2 code, we recommend instead being more explicit:
  if tensor_shape.rank is None:
    dim = Dimension(None)
  else:
    dim = tensor_shape.dims[i]

  # Being more explicit will save you from the following trap (present in V1):
  # you might do in-place modifications to `dim` and expect them to be reflected
  # in `tensor_shape[i]`, but they would not be (as the Dimension object was
  # instantiated on the fly).
  ```

  Arguments:
    shape: A TensorShape instance.
    index: An integer index.

  Returns:
    A `Dimension` object. If the rank of `shape` is unknown this is a new
    `Dimension(None)`, whatever `index` is.

  Raises:
    TypeError: If `shape` is not a `TensorShape`.
    IndexError: If the rank is known and `index` is out of range.
  """
  # An explicit check rather than `assert`, so validation survives `python -O`.
  if not isinstance(shape, TensorShape):
    raise TypeError("Expected a TensorShape, got %r" % (shape,))
  if shape.rank is None:
    return Dimension(None)
  return shape.dims[index]
170
171
@tf_export(v1=["Dimension"])
class Dimension(object):
  """Represents the value of one dimension in a TensorShape.

  A `Dimension` is either known, holding a non-negative integer, or unknown,
  holding None. Arithmetic involving an unknown operand produces an unknown
  `Dimension`, and comparisons involving an unknown operand return None
  instead of a boolean.
  """

  def __init__(self, value):
    """Creates a new Dimension with the given value.

    Args:
      value: None (unknown dimension), another `Dimension` (its value is
        copied), or anything accepted by `int()`, such as an integer, an
        integral float, or a numeric string.

    Raises:
      TypeError: If `value` is a `DType`, or `int()` rejects its type.
      ValueError: If `int()` rejects the value, if the conversion would lose
        information (e.g. 2.5), or if the resulting integer is negative.
    """
    if value is None:
      self._value = None
    elif isinstance(value, Dimension):
      self._value = value.value
    elif isinstance(value, dtypes.DType):
      raise TypeError("Cannot convert %s to Dimension" % value)
    else:
      self._value = int(value)
      # Strings are exempt from the round-trip check: "3" never compares equal
      # to 3, even though int("3") is exact.
      if (not isinstance(value, compat.bytes_or_text_types) and
          self._value != value):
        raise ValueError("Ambiguous dimension: %s" % value)
      if self._value < 0:
        raise ValueError("Dimension %d must be >= 0" % self._value)

  def __repr__(self):
    return "Dimension(%s)" % repr(self._value)

  def __str__(self):
    # Unknown dimensions print as "?".
    value = self._value
    return "?" if value is None else str(value)

  def __eq__(self, other):
    """Returns true if `other` has the same known value as this Dimension.

    Returns None (not False) if either value is unknown, and NotImplemented if
    `other` cannot be converted with `as_dimension`.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is None or other.value is None:
      return None
    return self._value == other.value

  def __ne__(self, other):
    """Returns true if `other` has a different known value from `self`.

    Returns None (not False) if either value is unknown, and NotImplemented if
    `other` cannot be converted with `as_dimension`.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is None or other.value is None:
      return None
    return self._value != other.value

  # The three conversions below return None for an unknown dimension, which
  # makes Python raise TypeError (e.g. `int(Dimension(None))`).
  def __int__(self):
    return self._value

  # This is needed for Windows.
  # See https://github.com/tensorflow/tensorflow/pull/9780
  def __long__(self):
    return self._value

  def __index__(self):
    # Allow use in Python 3 range
    return self._value

  @property
  def value(self):
    """The value of this dimension, or None if it is unknown."""
    return self._value

  def is_compatible_with(self, other):
    """Returns true if `other` is compatible with this Dimension.

    Two known Dimensions are compatible if they have the same value.
    An unknown Dimension is compatible with all other Dimensions.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      True if this Dimension and `other` are compatible.
    """
    other = as_dimension(other)
    return (self._value is None or other.value is None or
            self._value == other.value)

  def assert_is_compatible_with(self, other):
    """Raises an exception if `other` is not compatible with this Dimension.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Raises:
      ValueError: If `self` and `other` are not compatible (see
        is_compatible_with).
    """
    if not self.is_compatible_with(other):
      raise ValueError("Dimensions %s and %s are not compatible" % (self,
                                                                    other))

  def merge_with(self, other):
    """Returns a Dimension that combines the information in `self` and `other`.

    Dimensions are combined as follows:

    ```python
    tf.Dimension(n)   .merge_with(tf.Dimension(n))     == tf.Dimension(n)
    tf.Dimension(n)   .merge_with(tf.Dimension(None))  == tf.Dimension(n)
    tf.Dimension(None).merge_with(tf.Dimension(n))     == tf.Dimension(n)
    # equivalent to tf.Dimension(None)
    tf.Dimension(None).merge_with(tf.Dimension(None))

    # raises ValueError for n != m
    tf.Dimension(n)   .merge_with(tf.Dimension(m))
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A new Dimension containing the combined information of `self` and
      `other`.

    Raises:
      ValueError: If `self` and `other` are not compatible (see
        is_compatible_with).
    """
    other = as_dimension(other)
    self.assert_is_compatible_with(other)
    if self._value is None:
      return Dimension(other.value)
    else:
      return Dimension(self._value)

  # Arithmetic operators return NotImplemented when `other` cannot be
  # converted by `as_dimension`, so Python can fall back to the other
  # operand's (reflected) implementation, or raise TypeError.

  def __add__(self, other):
    """Returns the sum of `self` and `other`.

    Dimensions are summed as follows:

    ```python
    tf.Dimension(m)    + tf.Dimension(n)     == tf.Dimension(m + n)
    tf.Dimension(m)    + tf.Dimension(None)  # equiv. to tf.Dimension(None)
    tf.Dimension(None) + tf.Dimension(n)     # equiv. to tf.Dimension(None)
    tf.Dimension(None) + tf.Dimension(None)  # equiv. to tf.Dimension(None)
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the sum of `self` and `other`, or
      NotImplemented if `other` cannot be converted.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value + other.value)

  def __radd__(self, other):
    """Returns the sum of `other` and `self`.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the sum of `self` and `other`, or
      NotImplemented if `other` cannot be converted.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    return self + other

  def __sub__(self, other):
    """Returns the subtraction of `other` from `self`.

    Dimensions are subtracted as follows:

    ```python
    tf.Dimension(m)    - tf.Dimension(n)     == tf.Dimension(m - n)
    tf.Dimension(m)    - tf.Dimension(None)  # equiv. to tf.Dimension(None)
    tf.Dimension(None) - tf.Dimension(n)     # equiv. to tf.Dimension(None)
    tf.Dimension(None) - tf.Dimension(None)  # equiv. to tf.Dimension(None)
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the subtraction of `other` from `self`, or
      NotImplemented if `other` cannot be converted.

    Raises:
      ValueError: If both values are known and the difference is negative.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value - other.value)

  def __rsub__(self, other):
    """Returns the subtraction of `self` from `other`.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the subtraction of `self` from `other`, or
      NotImplemented if `other` cannot be converted.

    Raises:
      ValueError: If both values are known and the difference is negative.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(other.value - self._value)

  def __mul__(self, other):
    """Returns the product of `self` and `other`.

    Dimensions are multiplied as follows:

    ```python
    tf.Dimension(m)    * tf.Dimension(n)     == tf.Dimension(m * n)
    tf.Dimension(m)    * tf.Dimension(None)  # equiv. to tf.Dimension(None)
    tf.Dimension(None) * tf.Dimension(n)     # equiv. to tf.Dimension(None)
    tf.Dimension(None) * tf.Dimension(None)  # equiv. to tf.Dimension(None)
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the product of `self` and `other`, or
      NotImplemented if `other` cannot be converted.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented

    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value * other.value)

  def __rmul__(self, other):
    """Returns the product of `self` and `other`.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the product of `self` and `other`, or
      NotImplemented if `other` cannot be converted.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    return self * other

  def __floordiv__(self, other):
    """Returns the quotient of `self` and `other` rounded down.

    Dimensions are divided as follows:

    ```python
    tf.Dimension(m)    // tf.Dimension(n)     == tf.Dimension(m // n)
    tf.Dimension(m)    // tf.Dimension(None)  # equiv. to tf.Dimension(None)
    tf.Dimension(None) // tf.Dimension(n)     # equiv. to tf.Dimension(None)
    tf.Dimension(None) // tf.Dimension(None)  # equiv. to tf.Dimension(None)
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A `Dimension` whose value is the integer quotient of `self` and `other`,
      or NotImplemented if `other` cannot be converted.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value // other.value)

  def __rfloordiv__(self, other):
    """Returns the quotient of `other` and `self` rounded down.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A `Dimension` whose value is the integer quotient of `other` and `self`,
      or NotImplemented if `other` cannot be converted.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(other.value // self._value)

  def __div__(self, other):
    """DEPRECATED: Use `__floordiv__` via `x // y` instead.

    This function exists only for backwards compatibility purposes; new code
    should use `__floordiv__` via the syntax `x // y`.  Using `x // y`
    communicates clearly that the result rounds down, and is forward compatible
    to Python 3.

    Args:
      other: Another `Dimension`.

    Returns:
      A `Dimension` whose value is the integer quotient of `self` and `other`.
    """
    return self // other

  def __rdiv__(self, other):
    """Use `__floordiv__` via `x // y` instead.

    This function exists only to have a better error message. Instead of:
    `TypeError: unsupported operand type(s) for /: 'int' and 'Dimension'`,
    this function will explicitly call for usage of `//` instead.

    Args:
      other: Another `Dimension`.

    Raises:
      TypeError.
    """
    raise TypeError("unsupported operand type(s) for /: '{}' and 'Dimension', "
                    "please use // instead".format(type(other).__name__))

  def __truediv__(self, other):
    """Use `__floordiv__` via `x // y` instead.

    This function exists only to have a better error message. Instead of:
    `TypeError: unsupported operand type(s) for /: 'Dimension' and 'int'`,
    this function will explicitly call for usage of `//` instead.

    Args:
      other: Another `Dimension`.

    Raises:
      TypeError.
    """
    raise TypeError("unsupported operand type(s) for /: 'Dimension' and '{}', "
                    "please use // instead".format(type(other).__name__))

  def __rtruediv__(self, other):
    """Use `__floordiv__` via `x // y` instead.

    This function exists only to have a better error message. Instead of:
    `TypeError: unsupported operand type(s) for /: 'int' and 'Dimension'`,
    this function will explicitly call for usage of `//` instead.

    Args:
      other: Another `Dimension`.

    Raises:
      TypeError.
    """
    raise TypeError("unsupported operand type(s) for /: '{}' and 'Dimension', "
                    "please use // instead".format(type(other).__name__))

  def __mod__(self, other):
    """Returns `self` modulo `other`.

    Dimension moduli are computed as follows:

    ```python
    tf.Dimension(m)    % tf.Dimension(n)     == tf.Dimension(m % n)
    tf.Dimension(m)    % tf.Dimension(None)  # equiv. to tf.Dimension(None)
    tf.Dimension(None) % tf.Dimension(n)     # equiv. to tf.Dimension(None)
    tf.Dimension(None) % tf.Dimension(None)  # equiv. to tf.Dimension(None)
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is `self` modulo `other`, or NotImplemented if
      `other` cannot be converted.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value % other.value)

  def __rmod__(self, other):
    """Returns `other` modulo `self`.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is `other` modulo `self`, or NotImplemented if
      `other` cannot be converted.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    return other % self

  def __lt__(self, other):
    """Returns True if `self` is known to be less than `other`.

    Dimensions are compared as follows:

    ```python
    (tf.Dimension(m)    < tf.Dimension(n))    == (m < n)
    (tf.Dimension(m)    < tf.Dimension(None)) == None
    (tf.Dimension(None) < tf.Dimension(n))    == None
    (tf.Dimension(None) < tf.Dimension(None)) == None
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      The value of `self.value < other.value` if both are known, otherwise
      None.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return None
    else:
      return self._value < other.value

  def __le__(self, other):
    """Returns True if `self` is known to be less than or equal to `other`.

    Dimensions are compared as follows:

    ```python
    (tf.Dimension(m)    <= tf.Dimension(n))    == (m <= n)
    (tf.Dimension(m)    <= tf.Dimension(None)) == None
    (tf.Dimension(None) <= tf.Dimension(n))    == None
    (tf.Dimension(None) <= tf.Dimension(None)) == None
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      The value of `self.value <= other.value` if both are known, otherwise
      None.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return None
    else:
      return self._value <= other.value

  def __gt__(self, other):
    """Returns True if `self` is known to be greater than `other`.

    Dimensions are compared as follows:

    ```python
    (tf.Dimension(m)    > tf.Dimension(n))    == (m > n)
    (tf.Dimension(m)    > tf.Dimension(None)) == None
    (tf.Dimension(None) > tf.Dimension(n))    == None
    (tf.Dimension(None) > tf.Dimension(None)) == None
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      The value of `self.value > other.value` if both are known, otherwise
      None.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return None
    else:
      return self._value > other.value

  def __ge__(self, other):
    """Returns True if `self` is known to be greater than or equal to `other`.

    Dimensions are compared as follows:

    ```python
    (tf.Dimension(m)    >= tf.Dimension(n))    == (m >= n)
    (tf.Dimension(m)    >= tf.Dimension(None)) == None
    (tf.Dimension(None) >= tf.Dimension(n))    == None
    (tf.Dimension(None) >= tf.Dimension(None)) == None
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      The value of `self.value >= other.value` if both are known, otherwise
      None.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return None
    else:
      return self._value >= other.value

  def __reduce__(self):
    # Pickle as a constructor call on the plain value.
    return Dimension, (self._value,)
663
664
def as_dimension(value):
  """Coerces `value` into a `Dimension`.

  Existing `Dimension` objects pass through untouched (the same object is
  returned). Anything else, including None (which yields an unknown
  dimension) and plain integers, is handed to the `Dimension` constructor.

  Args:
    value: The value to be converted.

  Returns:
    A Dimension corresponding to the given value.

  Raises:
    TypeError, ValueError: Propagated from the `Dimension` constructor when
      `value` cannot be converted.
  """
  return value if isinstance(value, Dimension) else Dimension(value)
682
683
684@tf_export("TensorShape")
685class TensorShape(object):
686  """Represents the shape of a `Tensor`.
687
688  A `TensorShape` represents a possibly-partial shape specification for a
689  `Tensor`. It may be one of the following:
690
691  * *Fully-known shape:* has a known number of dimensions and a known size
692    for each dimension. e.g. `TensorShape([16, 256])`
693  * *Partially-known shape:* has a known number of dimensions, and an unknown
694    size for one or more dimension. e.g. `TensorShape([None, 256])`
695  * *Unknown shape:* has an unknown number of dimensions, and an unknown
696    size in all dimensions. e.g. `TensorShape(None)`
697
698  If a tensor is produced by an operation of type `"Foo"`, its shape
699  may be inferred if there is a registered shape function for
700  `"Foo"`. See [Shape
701  functions](https://tensorflow.org/extend/adding_an_op#shape_functions_in_c)
702  for details of shape functions and how to register them. Alternatively,
703  the shape may be set explicitly using `tf.Tensor.set_shape`.
704  """
705
706  def __init__(self, dims):
707    """Creates a new TensorShape with the given dimensions.
708
709    Args:
710      dims: A list of Dimensions, or None if the shape is unspecified.
711
712    Raises:
713      TypeError: If dims cannot be converted to a list of dimensions.
714    """
715    if dims is None:
716      self._dims = None
717    elif isinstance(dims, compat.bytes_or_text_types):
718      raise TypeError("A string has ambiguous TensorShape, please wrap in a "
719                      "list or convert to an int: %s" % dims)
720    elif isinstance(dims, tensor_shape_pb2.TensorShapeProto):
721      if dims.unknown_rank:
722        self._dims = None
723      else:
724        self._dims = [
725            # Protos store variable-size dimensions as -1
726            as_dimension(dim.size if dim.size != -1 else None)
727            for dim in dims.dim
728        ]
729    elif isinstance(dims, TensorShape):
730      self._dims = dims.dims
731    else:
732      try:
733        dims_iter = iter(dims)
734      except TypeError:
735        # Treat as a singleton dimension
736        self._dims = [as_dimension(dims)]
737      else:
738        # Got a list of dimensions
739        self._dims = [as_dimension(d) for d in dims_iter]
740
741  @property
742  def _v2_behavior(self):
743    if _TENSORSHAPE_V2_OVERRIDE is None:
744      return tf2.enabled()
745    return _TENSORSHAPE_V2_OVERRIDE
746
747  def __repr__(self):
748    if self._v2_behavior:
749      if self._dims is not None:
750        return "TensorShape(%r)" % [dim.value for dim in self._dims]
751      else:
752        return "TensorShape(None)"
753    else:
754      return "TensorShape(%r)" % self._dims
755
756  def __str__(self):
757    if self.rank is None:
758      return "<unknown>"
759    elif self.rank == 1:
760      if self._v2_behavior:
761        return "(%s,)" % self._dims[0].value
762      else:
763        return "(%s,)" % self._dims[0]
764    else:
765      if self._v2_behavior:
766        return "(%s)" % ", ".join(str(d.value) for d in self._dims)
767      else:
768        return "(%s)" % ", ".join(str(d) for d in self._dims)
769
770  @property
771  def rank(self):
772    """Returns the rank of this shape, or None if it is unspecified."""
773    if self._dims is not None:
774      return len(self._dims)
775    return None
776
777  @property
778  def dims(self):
779    """Returns a list of Dimensions, or None if the shape is unspecified."""
780    return self._dims
781
782  @dims.setter
783  def dims(self, dims):
784    self._dims = dims
785
786  @property
787  def ndims(self):
788    """Deprecated accessor for `rank`."""
789    return self.rank
790
791  def __len__(self):
792    """Returns the rank of this shape, or raises ValueError if unspecified."""
793    if self._dims is None:
794      raise ValueError("Cannot take the length of shape with unknown rank.")
795    return len(self._dims)
796
797  def __bool__(self):
798    """Returns True if this shape contains non-zero information."""
799    return self._dims is not None
800
801  # Python 3 wants __bool__, Python 2.7 wants __nonzero__
802  __nonzero__ = __bool__
803
804  def __iter__(self):
805    """Returns `self.dims` if the rank is known, otherwise raises ValueError."""
806    if self._dims is None:
807      raise ValueError("Cannot iterate over a shape with unknown rank.")
808    else:
809      if self._v2_behavior:
810        return iter(d.value for d in self._dims)
811      else:
812        return iter(d for d in self._dims)
813
814  def __getitem__(self, key):
815    """Returns the value of a dimension or a shape, depending on the key.
816
817    Args:
818      key: If `key` is an integer, returns the dimension at that index;
819        otherwise if `key` is a slice, returns a TensorShape whose
820        dimensions are those selected by the slice from `self`.
821
822    Returns:
823      An integer if `key` is an integer, or a `TensorShape` if `key` is a
824      slice.
825
826    Raises:
827      ValueError: If `key` is a slice and `self` is completely unknown and
828        the step is set.
829    """
830    if self._dims is not None:
831      if isinstance(key, slice):
832        return TensorShape(self._dims[key])
833      else:
834        if self._v2_behavior:
835          return self._dims[key].value
836        else:
837          return self._dims[key]
838    else:
839      if isinstance(key, slice):
840        start = key.start if key.start is not None else 0
841        stop = key.stop
842
843        if key.step is not None:
844          # TODO(mrry): Handle these maybe.
845          raise ValueError("Steps are not yet handled")
846        if stop is None:
847          # NOTE(mrry): This implies that TensorShape(None) is compatible with
848          # TensorShape(None)[1:], which is obviously not true. It would be
849          # possible to track the number of dimensions symbolically,
850          # and perhaps we should do that.
851          return unknown_shape()
852        elif start < 0 or stop < 0:
853          # TODO(mrry): Handle this better, as it will be useful for handling
854          # suffixes of otherwise unknown shapes.
855          return unknown_shape()
856        else:
857          return unknown_shape(rank=stop - start)
858      else:
859        if self._v2_behavior:
860          return None
861        else:
862          return Dimension(None)
863
864  def num_elements(self):
865    """Returns the total number of elements, or none for incomplete shapes."""
866    if self.is_fully_defined():
867      size = 1
868      for dim in self._dims:
869        size *= dim.value
870      return size
871    else:
872      return None
873
874  def merge_with(self, other):
875    """Returns a `TensorShape` combining the information in `self` and `other`.
876
877    The dimensions in `self` and `other` are merged elementwise,
878    according to the rules defined for `Dimension.merge_with()`.
879
880    Args:
881      other: Another `TensorShape`.
882
883    Returns:
884      A `TensorShape` containing the combined information of `self` and
885      `other`.
886
887    Raises:
888      ValueError: If `self` and `other` are not compatible.
889    """
890    other = as_shape(other)
891    if self._dims is None:
892      return other
893    else:
894      try:
895        self.assert_same_rank(other)
896        new_dims = []
897        for i, dim in enumerate(self._dims):
898          new_dims.append(dim.merge_with(other[i]))
899        return TensorShape(new_dims)
900      except ValueError:
901        raise ValueError("Shapes %s and %s are not compatible" % (self, other))
902
903  def concatenate(self, other):
904    """Returns the concatenation of the dimension in `self` and `other`.
905
906    *N.B.* If either `self` or `other` is completely unknown,
907    concatenation will discard information about the other shape. In
908    future, we might support concatenation that preserves this
909    information for use with slicing.
910
911    Args:
912      other: Another `TensorShape`.
913
914    Returns:
915      A `TensorShape` whose dimensions are the concatenation of the
916      dimensions in `self` and `other`.
917    """
918    # TODO(mrry): Handle the case where we concatenate a known shape with a
919    # completely unknown shape, so that we can use the partial information.
920    other = as_shape(other)
921    if self._dims is None or other.dims is None:
922      return unknown_shape()
923    else:
924      return TensorShape(self._dims + other.dims)
925
926  def assert_same_rank(self, other):
927    """Raises an exception if `self` and `other` do not have compatible ranks.
928
929    Args:
930      other: Another `TensorShape`.
931
932    Raises:
933      ValueError: If `self` and `other` do not represent shapes with the
934        same rank.
935    """
936    other = as_shape(other)
937    if self.rank is not None and other.rank is not None:
938      if self.rank != other.rank:
939        raise ValueError("Shapes %s and %s must have the same rank" % (self,
940                                                                       other))
941
942  def assert_has_rank(self, rank):
943    """Raises an exception if `self` is not compatible with the given `rank`.
944
945    Args:
946      rank: An integer.
947
948    Raises:
949      ValueError: If `self` does not represent a shape with the given `rank`.
950    """
951    if self.rank not in (None, rank):
952      raise ValueError("Shape %s must have rank %d" % (self, rank))
953
954  def with_rank(self, rank):
955    """Returns a shape based on `self` with the given rank.
956
957    This method promotes a completely unknown shape to one with a
958    known rank.
959
960    Args:
961      rank: An integer.
962
963    Returns:
964      A shape that is at least as specific as `self` with the given rank.
965
966    Raises:
967      ValueError: If `self` does not represent a shape with the given `rank`.
968    """
969    try:
970      return self.merge_with(unknown_shape(rank=rank))
971    except ValueError:
972      raise ValueError("Shape %s must have rank %d" % (self, rank))
973
974  def with_rank_at_least(self, rank):
975    """Returns a shape based on `self` with at least the given rank.
976
977    Args:
978      rank: An integer.
979
980    Returns:
981      A shape that is at least as specific as `self` with at least the given
982      rank.
983
984    Raises:
985      ValueError: If `self` does not represent a shape with at least the given
986        `rank`.
987    """
988    if self.rank is not None and self.rank < rank:
989      raise ValueError("Shape %s must have rank at least %d" % (self, rank))
990    else:
991      return self
992
993  def with_rank_at_most(self, rank):
994    """Returns a shape based on `self` with at most the given rank.
995
996    Args:
997      rank: An integer.
998
999    Returns:
1000      A shape that is at least as specific as `self` with at most the given
1001      rank.
1002
1003    Raises:
1004      ValueError: If `self` does not represent a shape with at most the given
1005        `rank`.
1006    """
1007    if self.rank is not None and self.rank > rank:
1008      raise ValueError("Shape %s must have rank at most %d" % (self, rank))
1009    else:
1010      return self
1011
1012  def is_compatible_with(self, other):
1013    """Returns True iff `self` is compatible with `other`.
1014
1015    Two possibly-partially-defined shapes are compatible if there
1016    exists a fully-defined shape that both shapes can represent. Thus,
1017    compatibility allows the shape inference code to reason about
1018    partially-defined shapes. For example:
1019
1020    * TensorShape(None) is compatible with all shapes.
1021
1022    * TensorShape([None, None]) is compatible with all two-dimensional
1023      shapes, such as TensorShape([32, 784]), and also TensorShape(None). It is
1024      not compatible with, for example, TensorShape([None]) or
1025      TensorShape([None, None, None]).
1026
1027    * TensorShape([32, None]) is compatible with all two-dimensional shapes
1028      with size 32 in the 0th dimension, and also TensorShape([None, None])
1029      and TensorShape(None). It is not compatible with, for example,
1030      TensorShape([32]), TensorShape([32, None, 1]) or TensorShape([64, None]).
1031
1032    * TensorShape([32, 784]) is compatible with itself, and also
1033      TensorShape([32, None]), TensorShape([None, 784]), TensorShape([None,
1034      None]) and TensorShape(None). It is not compatible with, for example,
1035      TensorShape([32, 1, 784]) or TensorShape([None]).
1036
1037    The compatibility relation is reflexive and symmetric, but not
1038    transitive. For example, TensorShape([32, 784]) is compatible with
1039    TensorShape(None), and TensorShape(None) is compatible with
1040    TensorShape([4, 4]), but TensorShape([32, 784]) is not compatible with
1041    TensorShape([4, 4]).
1042
1043    Args:
1044      other: Another TensorShape.
1045
1046    Returns:
1047      True iff `self` is compatible with `other`.
1048
1049    """
1050    other = as_shape(other)
1051    if self._dims is not None and other.dims is not None:
1052      if self.rank != other.rank:
1053        return False
1054      for x_dim, y_dim in zip(self._dims, other.dims):
1055        if not x_dim.is_compatible_with(y_dim):
1056          return False
1057    return True
1058
1059  def assert_is_compatible_with(self, other):
1060    """Raises exception if `self` and `other` do not represent the same shape.
1061
1062    This method can be used to assert that there exists a shape that both
1063    `self` and `other` represent.
1064
1065    Args:
1066      other: Another TensorShape.
1067
1068    Raises:
1069      ValueError: If `self` and `other` do not represent the same shape.
1070    """
1071    if not self.is_compatible_with(other):
1072      raise ValueError("Shapes %s and %s are incompatible" % (self, other))
1073
1074  def most_specific_compatible_shape(self, other):
1075    """Returns the most specific TensorShape compatible with `self` and `other`.
1076
1077    * TensorShape([None, 1]) is the most specific TensorShape compatible with
1078      both TensorShape([2, 1]) and TensorShape([5, 1]). Note that
1079      TensorShape(None) is also compatible with above mentioned TensorShapes.
1080
1081    * TensorShape([1, 2, 3]) is the most specific TensorShape compatible with
1082      both TensorShape([1, 2, 3]) and TensorShape([1, 2, 3]). There are more
1083      less specific TensorShapes compatible with above mentioned TensorShapes,
1084      e.g. TensorShape([1, 2, None]), TensorShape(None).
1085
1086    Args:
1087      other: Another `TensorShape`.
1088
1089    Returns:
1090      A `TensorShape` which is the most specific compatible shape of `self`
1091      and `other`.
1092    """
1093
1094    other = as_shape(other)
1095    if self._dims is None or other.dims is None or self.rank != other.rank:
1096      return unknown_shape()
1097
1098    dims = [(Dimension(None))] * self.rank
1099    for i, (d1, d2) in enumerate(zip(self._dims, other.dims)):
1100      if d1 is not None and d2 is not None and d1 == d2:
1101        dims[i] = d1
1102    return TensorShape(dims)
1103
1104  def is_fully_defined(self):
1105    """Returns True iff `self` is fully defined in every dimension."""
1106    return (self._dims is not None and all(dim.value is not None
1107                                           for dim in self._dims))
1108
1109  def assert_is_fully_defined(self):
1110    """Raises an exception if `self` is not fully defined in every dimension.
1111
1112    Raises:
1113      ValueError: If `self` does not have a known value for every dimension.
1114    """
1115    if not self.is_fully_defined():
1116      raise ValueError("Shape %s is not fully defined" % self)
1117
1118  def as_list(self):
1119    """Returns a list of integers or `None` for each dimension.
1120
1121    Returns:
1122      A list of integers or `None` for each dimension.
1123
1124    Raises:
1125      ValueError: If `self` is an unknown shape with an unknown rank.
1126    """
1127    if self._dims is None:
1128      raise ValueError("as_list() is not defined on an unknown TensorShape.")
1129    return [dim.value for dim in self._dims]
1130
1131  def as_proto(self):
1132    """Returns this shape as a `TensorShapeProto`."""
1133    if self._dims is None:
1134      return tensor_shape_pb2.TensorShapeProto(unknown_rank=True)
1135    else:
1136      return tensor_shape_pb2.TensorShapeProto(dim=[
1137          tensor_shape_pb2.TensorShapeProto.Dim(size=-1
1138                                                if d.value is None else d.value)
1139          for d in self._dims
1140      ])
1141
1142  def __eq__(self, other):
1143    """Returns True if `self` is equivalent to `other`."""
1144    try:
1145      other = as_shape(other)
1146    except TypeError:
1147      return NotImplemented
1148    return self._dims == other.dims
1149
1150  def __ne__(self, other):
1151    """Returns True if `self` is known to be different from `other`."""
1152    try:
1153      other = as_shape(other)
1154    except TypeError:
1155      return NotImplemented
1156    if self.rank is None or other.rank is None:
1157      raise ValueError("The inequality of unknown TensorShapes is undefined.")
1158    if self.rank != other.rank:
1159      return True
1160    return self._dims != other.dims
1161
1162  def __reduce__(self):
1163    return TensorShape, (self._dims,)
1164
1165  def __concat__(self, other):
1166    return self.concatenate(other)
1167
1168
def as_shape(shape):
  """Coerces `shape` into a `TensorShape`.

  Args:
    shape: A `TensorShape`, or any value accepted by the `TensorShape`
      constructor.

  Returns:
    `shape` itself if it is already a `TensorShape`; otherwise a new
    `TensorShape(shape)`.
  """
  if not isinstance(shape, TensorShape):
    shape = TensorShape(shape)
  return shape
1175
1176
def unknown_shape(rank=None, **kwargs):
  """Builds a TensorShape with unknown dimensions and, optionally, known rank.

  Args:
    rank: (Optional) The number of dimensions. If omitted or None, the rank
      is unknown as well.
    **kwargs: For backwards compatibility. Only `ndims`, the legacy spelling
      of `rank`, is accepted, and only when `rank` is not given.

  Returns:
    `TensorShape(None)` when no rank is known, otherwise a shape with `rank`
    unknown dimensions.

  Raises:
    TypeError: If any unrecognized keyword argument is passed.
  """
  if rank is None:
    # Honor the legacy `ndims` keyword only when `rank` was not supplied.
    # If both are given, `ndims` stays in kwargs and is rejected below.
    rank = kwargs.pop("ndims", None)
  if kwargs:
    raise TypeError("Unknown argument: %s" % kwargs)
  if rank is not None:
    return TensorShape([Dimension(None)] * rank)
  return TensorShape(None)
1198
1199
def scalar():
  """Returns the rank-0 `TensorShape`, i.e. the shape of a scalar."""
  no_dims = []
  return TensorShape(no_dims)
1203
1204
def vector(length):
  """Builds a rank-1 `TensorShape`.

  Args:
    length: Size of the single dimension, or None if it is unknown.

  Returns:
    A TensorShape equal to `TensorShape([length])`.
  """
  dims = [length]
  return TensorShape(dims)
1215
1216
def matrix(rows, cols):
  """Builds a rank-2 `TensorShape`.

  Args:
    rows: Size of the first dimension, or None if it is unknown.
    cols: Size of the second dimension, or None if it is unknown.

  Returns:
    A TensorShape equal to `TensorShape([rows, cols])`.
  """
  dims = [rows, cols]
  return TensorShape(dims)
1228