1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/tensor_shape.h"
17 
18 #include "tensorflow/core/framework/bounds_check.h"
19 #include "tensorflow/core/framework/tensor_shape.pb.h"
20 #include "tensorflow/core/lib/core/errors.h"
21 #include "tensorflow/core/lib/strings/str_util.h"
22 #include "tensorflow/core/lib/strings/strcat.h"
23 #include "tensorflow/core/platform/logging.h"
24 #include "tensorflow/core/util/overflow.h"
25 
26 namespace tensorflow {
27 
28 // TensorShape and PartialTensorShape should have no fields beyond
29 // TensorShapeRep.  In particular, their sizes should be the same.
30 static_assert(sizeof(TensorShapeRep) == sizeof(TensorShape),
31               "TensorShape must have no fields beyond TensorShapeRep");
32 static_assert(sizeof(TensorShapeRep) == sizeof(PartialTensorShape),
33               "PartialTensorShape must have no fields beyond TensorShapeRep");
34 
35 template <class Shape>
AppendTo(const TensorShapeBase<Shape> & s,gtl::InlinedVector<int64,8> * vals)36 static void AppendTo(const TensorShapeBase<Shape>& s,
37                      gtl::InlinedVector<int64, 8>* vals) {
38   for (auto dim : s) {
39     vals->push_back(dim.size);
40   }
41 }
42 
CheckDimsEqual(int NDIMS) const43 void TensorShape::CheckDimsEqual(int NDIMS) const {
44   CHECK_EQ(NDIMS, dims()) << "Asking for tensor of " << NDIMS << " dimensions"
45                           << " from a tensor of " << dims() << " dimensions";
46 }
47 
CheckDimsAtLeast(int NDIMS) const48 void TensorShape::CheckDimsAtLeast(int NDIMS) const {
49   CHECK_GE(NDIMS, dims()) << "Asking for tensor of at least " << NDIMS
50                           << " dimensions from a tensor of " << dims()
51                           << " dimensions";
52 }
53 
54 template <class Shape>
IsValid(const TensorShapeProto & proto)55 bool TensorShapeBase<Shape>::IsValid(const TensorShapeProto& proto) {
56   // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
57   // unknown_shape() set, and it seems hard to remove this without backwards
58   // compatibility issues.
59   if (kIsPartial && proto.unknown_rank()) return proto.dim_size() == 0;
60   int64 num_elements = 1;
61   if (proto.dim().size() > MaxDimensions()) return false;
62   for (const auto& d : proto.dim()) {
63     if (d.size() < (kIsPartial ? -1 : 0)) return false;
64     if (d.size() == -1) {
65       num_elements = -1;
66     } else if (!kIsPartial || num_elements >= 0) {
67       num_elements = MultiplyWithoutOverflow(num_elements, d.size());
68       if (num_elements < 0) return false;
69     }
70   }
71   return true;
72 }
73 
74 template <class Shape>
IsValidShape(const TensorShapeProto & proto)75 Status TensorShapeBase<Shape>::IsValidShape(const TensorShapeProto& proto) {
76   // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
77   // unknown_shape() set, and it seems hard to remove this without backwards
78   // compatibility issues.
79   if (kIsPartial && proto.unknown_rank()) {
80     if (proto.dim_size() > 0) {
81       return errors::InvalidArgument(
82           "An unknown shape must not have any dimensions set.");
83     }
84     return Status::OK();
85   }
86   int64 num_elements = 1;
87   if (proto.dim().size() > MaxDimensions()) {
88     return errors::InvalidArgument("Shape ", DebugString(proto),
89                                    " has too many dimensions");
90   }
91   for (const auto& d : proto.dim()) {
92     if (d.size() < (kIsPartial ? -1 : 0)) {
93       if (kIsPartial) {
94         return errors::InvalidArgument(
95             "Shape ", DebugString(proto),
96             " has dimensions with values below -1 (where -1 means unknown)");
97       } else {
98         return errors::InvalidArgument("Shape ", DebugString(proto),
99                                        " is not fully defined");
100       }
101     }
102     if (d.size() == -1) {
103       num_elements = -1;
104     } else if (!kIsPartial || num_elements >= 0) {
105       num_elements = MultiplyWithoutOverflow(num_elements, d.size());
106       if (num_elements < 0) {
107         return errors::InvalidArgument(
108             "Shape ", DebugString(proto),
109             " is too large (more than 2**63 - 1 entries)");
110       }
111     }
112   }
113   return Status::OK();
114 }
115 
116 template <class Shape>
TensorShapeBase(const TensorShapeProto & proto)117 TensorShapeBase<Shape>::TensorShapeBase(const TensorShapeProto& proto) {
118   set_tag(REP16);
119   set_data_type(DT_INVALID);
120   // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
121   // unknown_shape() set, and it seems hard to remove this without backwards
122   // compatibility issues.
123   if (kIsPartial && proto.unknown_rank()) {
124     set_ndims_byte(kUnknownRank);
125     set_num_elements(-1);
126   } else {
127     set_ndims_byte(0);
128     set_num_elements(1);
129     for (const auto& d : proto.dim()) {
130       AddDim(d.size());
131     }
132   }
133 }
134 
135 template <class Shape>
TensorShapeBase(gtl::ArraySlice<int64> dim_sizes)136 TensorShapeBase<Shape>::TensorShapeBase(gtl::ArraySlice<int64> dim_sizes) {
137   set_tag(REP16);
138   set_data_type(DT_INVALID);
139   InitDims(dim_sizes);
140 }
141 
142 // Returns true iff partial is true and val is < 0.
143 // REQUIRES: val < kMaxRep16
144 // REQUIRES: partial || val >= 0
Set16(bool partial,uint16 * dst,int dim,int64 val)145 static inline bool Set16(bool partial, uint16* dst, int dim, int64 val) {
146   if (partial) {
147     if (val < 0) {
148       dst[dim] = std::numeric_limits<uint16>::max();
149       return true;
150     }
151   } else {
152     CHECK_GE(val, 0);
153   }
154   dst[dim] = val;
155   return false;
156 }
157 
158 template <class Shape>
InitDims(gtl::ArraySlice<int64> dim_sizes)159 void TensorShapeBase<Shape>::InitDims(gtl::ArraySlice<int64> dim_sizes) {
160   DCHECK_EQ(tag(), REP16);
161 
162   // Allow sizes that are under kint64max^0.25 so that 4-way multiplication
163   // below cannot overflow.
164   static const uint64 kMaxSmall = 0xd744;
165   static_assert(kMaxSmall * kMaxSmall * kMaxSmall * kMaxSmall <= kint64max,
166                 "bad overflow check");
167   bool large_size = false;
168   for (auto s : dim_sizes) {
169     if (s > kMaxSmall) {
170       large_size = true;
171       break;
172     }
173   }
174 
175   if (!large_size) {
176     // Every size fits in 16 bits; use fast-paths for dims in {1,2,3,4}.
177     uint16* dst = as16()->dims_;
178     switch (dim_sizes.size()) {
179       case 1: {
180         set_ndims_byte(1);
181         const int64 size = dim_sizes[0];
182         const bool neg = Set16(kIsPartial, dst, 0, size);
183         set_num_elements(neg ? -1 : size);
184         return;
185       }
186       case 2: {
187         set_ndims_byte(2);
188         const int64 size0 = dim_sizes[0];
189         const int64 size1 = dim_sizes[1];
190         bool neg = Set16(kIsPartial, dst, 0, size0);
191         neg |= Set16(kIsPartial, dst, 1, size1);
192         set_num_elements(neg ? -1 : (size0 * size1));
193         return;
194       }
195       case 3: {
196         set_ndims_byte(3);
197         const int64 size0 = dim_sizes[0];
198         const int64 size1 = dim_sizes[1];
199         const int64 size2 = dim_sizes[2];
200         bool neg = Set16(kIsPartial, dst, 0, size0);
201         neg |= Set16(kIsPartial, dst, 1, size1);
202         neg |= Set16(kIsPartial, dst, 2, size2);
203         set_num_elements(neg ? -1 : (size0 * size1 * size2));
204         return;
205       }
206       case 4: {
207         set_ndims_byte(4);
208         const int64 size0 = dim_sizes[0];
209         const int64 size1 = dim_sizes[1];
210         const int64 size2 = dim_sizes[2];
211         const int64 size3 = dim_sizes[3];
212         bool neg = Set16(kIsPartial, dst, 0, size0);
213         neg |= Set16(kIsPartial, dst, 1, size1);
214         neg |= Set16(kIsPartial, dst, 2, size2);
215         neg |= Set16(kIsPartial, dst, 3, size3);
216         set_num_elements(neg ? -1 : (size0 * size1 * size2 * size3));
217         return;
218       }
219     }
220   }
221 
222   set_ndims_byte(0);
223   set_num_elements(1);
224   for (int64 s : dim_sizes) {
225     AddDim(internal::SubtleMustCopy(s));
226   }
227 }
228 
229 template <class Shape>
TensorShapeBase()230 TensorShapeBase<Shape>::TensorShapeBase() {
231   set_tag(REP16);
232   set_data_type(DT_INVALID);
233   if (kIsPartial) {
234     set_ndims_byte(kUnknownRank);
235     set_num_elements(-1);
236   } else {
237     set_ndims_byte(0);
238     set_num_elements(1);
239   }
240 }
241 
DestructorOutOfLine()242 void TensorShapeRep::DestructorOutOfLine() {
243   DCHECK(tag() == REP_OUT_OF_LINE);
244   delete as64()->dims_;
245 }
246 
SlowCopyFrom(const TensorShapeRep & b)247 void TensorShapeRep::SlowCopyFrom(const TensorShapeRep& b) {
248   if (b.tag() != REP_OUT_OF_LINE) {
249     if (tag() == REP_OUT_OF_LINE) {
250       delete as64()->dims_;
251     }
252     memcpy(buf(), b.buf(), sizeof(u_.buf));
253     // memcpy above implicitly also does:
254     //   set_tag(b.tag());
255     //   set_ndims_byte(b.ndims_byte());
256     //   set_data_type(b.data_type());
257   } else {
258     DCHECK_EQ(b.tag(), REP_OUT_OF_LINE);
259     set_ndims_byte(b.ndims_byte());
260     set_data_type(b.data_type());
261     if (tag() == REP_OUT_OF_LINE) {
262       // vector already allocated
263       *(as64()->dims_) = *(b.as64()->dims_);
264     } else {
265       set_tag(REP_OUT_OF_LINE);
266       as64()->dims_ = new gtl::InlinedVector<int64, 4>(*(b.as64()->dims_));
267     }
268   }
269 }
270 
271 template <class Shape>
dim_size(int d) const272 int64 TensorShapeBase<Shape>::dim_size(int d) const {
273   if (unknown_rank()) return -1;
274   DCHECK_GE(d, 0);
275   DCHECK_LT(d, dims());
276   if (tag() == REP16) {
277     uint16 dim = as16()->dims_[d];
278     if (kIsPartial && dim == kUnknownRep16) return -1;
279     return dim;
280   } else if (tag() == REP32) {
281     uint32 dim = as32()->dims_[d];
282     if (kIsPartial && dim == kUnknownRep32) return -1;
283     return dim;
284   } else {
285     return (*as64()->dims_)[d];
286   }
287 }
288 
Clear()289 void TensorShapeRep::Clear() {
290   ClearAllButDataType();
291   set_data_type(DT_INVALID);
292 }
293 
ClearAllButDataType()294 void TensorShapeRep::ClearAllButDataType() {
295   if (tag() == REP_OUT_OF_LINE) {
296     delete as64()->dims_;
297   }
298   set_tag(REP16);
299   set_ndims_byte(0);
300   // Leaves data_type alone
301   set_num_elements(1);
302 }
303 
304 template <class Shape>
RecomputeNumElements()305 void TensorShapeBase<Shape>::RecomputeNumElements() {
306   if (unknown_rank()) {
307     set_num_elements(-1);
308     return;
309   }
310   int64 n = 1;
311   for (auto dim : *this) {
312     if (kIsPartial && dim.size < 0) {
313       n = -1;
314       break;
315     }
316     n = MultiplyWithoutOverflow(n, dim.size);
317     CHECK_LE(0, n);
318   }
319   set_num_elements(n);
320 }
321 
322 template <class Shape>
AddDim(int64 size)323 void TensorShapeBase<Shape>::AddDim(int64 size) {
324   if (!kIsPartial) CHECK_GE(size, 0);
325   if (unknown_rank()) return;
326   CHECK_LT(ndims_byte(), MaxDimensions()) << "Too many dimensions in tensor";
327   int64 new_num_elements;
328   if (kIsPartial && (num_elements() < 0 || size < 0)) {
329     new_num_elements = -1;
330   } else {
331     new_num_elements = MultiplyWithoutOverflow(num_elements(), size);
332     CHECK_LE(0, new_num_elements);
333   }
334   UnsafeAddDim(size, new_num_elements);
335 }
336 
337 template <class Shape>
UnsafeAddDim(int64 size,int64 new_num_elements)338 void TensorShapeBase<Shape>::UnsafeAddDim(int64 size, int64 new_num_elements) {
339   const int nd = ndims_byte();
340   if (tag() == REP16 && nd < 6 && size < kMaxRep16) {
341     as16()->dims_[nd] =
342         kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size);
343   } else if (tag() == REP32 && nd < 3 && size < kMaxRep32) {
344     as32()->dims_[nd] =
345         kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size);
346   } else if (tag() == REP_OUT_OF_LINE) {
347     as64()->dims_->push_back(size);
348   } else {
349     // Need to change representation
350     gtl::InlinedVector<int64, 8> vals;
351     AppendTo(*this, &vals);
352     vals.push_back(size);
353     // We know we can't be REP16.  See if we have a small enough
354     // number of dimensions and each dimension's size is small enough
355     // to allow REP32.
356     bool can_be_rep32 = (vals.size() <= 3);
357     if (can_be_rep32) {
358       for (size_t i = 0; i < vals.size(); i++) {
359         if (vals[i] >= kMaxRep32) {
360           can_be_rep32 = false;
361           break;
362         }
363       }
364     }
365     if (can_be_rep32) {
366       set_tag(REP32);
367       for (size_t d = 0; d < vals.size(); d++) {
368         as32()->dims_[d] = kIsPartial && vals[d] < 0
369                                ? kUnknownRep32
370                                : static_cast<uint32>(vals[d]);
371       }
372     } else {
373       set_tag(REP_OUT_OF_LINE);
374       as64()->dims_ =
375           new gtl::InlinedVector<int64, 4>(vals.begin(), vals.end());
376     }
377   }
378   set_ndims_byte(nd + 1);
379   set_num_elements(new_num_elements);
380 }
381 
382 template <class Shape>
AppendShape(const TensorShapeBase & shape)383 void TensorShapeBase<Shape>::AppendShape(const TensorShapeBase& shape) {
384   for (auto d : shape) AddDim(d.size);
385 }
386 
387 template <class Shape>
InsertDim(int d,int64 size)388 void TensorShapeBase<Shape>::InsertDim(int d, int64 size) {
389   CHECK_GE(d, 0);
390   CHECK_LE(d, dims());
391   if (!kIsPartial) CHECK_GE(size, 0);
392   CHECK_LT(dims(), MaxDimensions());
393   gtl::InlinedVector<int64, 8> vals;
394   AppendTo(*this, &vals);
395   vals.insert(vals.begin() + d, size);
396   ClearAllButDataType();
397   for (auto dval : vals) {
398     AddDim(dval);
399   }
400 }
401 
402 template <class Shape>
dim_sizes() const403 gtl::InlinedVector<int64, 4> TensorShapeBase<Shape>::dim_sizes() const {
404   gtl::InlinedVector<int64, 4> result;
405   for (auto dim : *this) {
406     result.push_back(dim.size);
407   }
408   return result;
409 }
410 
411 template <class Shape>
set_dim(int d,int64 size)412 void TensorShapeBase<Shape>::set_dim(int d, int64 size) {
413   CHECK_GE(d, 0);
414   CHECK_LT(d, dims());
415   CHECK_GE(size, 0);
416   if (tag() == REP16 && size < kMaxRep16) {
417     as16()->dims_[d] =
418         kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size);
419   } else if (tag() == REP32 && size < kMaxRep32) {
420     as32()->dims_[d] =
421         kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size);
422   } else if (tag() == REP_OUT_OF_LINE) {
423     (*as64()->dims_)[d] = size;
424   } else {
425     // Must upgrade
426     gtl::InlinedVector<int64, 8> vals;
427     AppendTo(*this, &vals);
428     vals[d] = size;
429     ClearAllButDataType();
430     for (auto dval : vals) {
431       AddDim(dval);
432     }
433   }
434   RecomputeNumElements();
435 }
436 
437 template <class Shape>
RemoveDimRange(int begin,int end)438 void TensorShapeBase<Shape>::RemoveDimRange(int begin, int end) {
439   if (unknown_rank()) return;
440   begin = begin < 0 ? dims() + begin + 1 : begin;
441   end = end < 0 ? dims() + end + 1 : end;
442   CHECK_GE(begin, 0);
443   CHECK_LE(begin, dims());
444   CHECK_GE(end, 0);
445   CHECK_LE(end, dims());
446   if (begin >= end) return;
447   gtl::InlinedVector<int64, 8> vals;
448   AppendTo(*this, &vals);
449   vals.erase(vals.begin() + begin, vals.begin() + end);
450   ClearAllButDataType();
451   for (auto dval : vals) {
452     AddDim(dval);
453   }
454   RecomputeNumElements();
455 }
456 
IsSameSize(const TensorShape & b) const457 bool TensorShape::IsSameSize(const TensorShape& b) const {
458   if (b.dims() != dims()) return false;
459   for (int d = 0; d < dims(); d++) {
460     if (dim_size(d) != b.dim_size(d)) return false;
461   }
462   return true;
463 }
464 
465 template <class Shape>
AsProto(TensorShapeProto * proto) const466 void TensorShapeBase<Shape>::AsProto(TensorShapeProto* proto) const {
467   proto->Clear();
468   if (unknown_rank()) {
469     proto->set_unknown_rank(true);
470   } else {
471     for (int i = 0; i < dims(); i++) {
472       proto->add_dim()->set_size(dim_size(i));
473     }
474   }
475 }
476 
DumpRep() const477 void TensorShapeRep::DumpRep() const {
478 #if 0
479   fprintf(stderr, "Rep: %d %d dims\n", tag(), dims());
480   if (tag() == REP16) {
481     fprintf(stderr, "REP16 NDIMS: %d\n", ndims_byte());
482     for (int i = 0; i < ndims_byte(); i++) {
483       fprintf(stderr, "dim %d: %d\n", i, as16()->dims_[i]);
484     }
485   } else if (tag_ == REP32) {
486     fprintf(stderr, "REP32 NDIMS: %d\n", ndims_);
487     for (int i = 0; i < ndims_byte(); i++) {
488       fprintf(stderr, "dim %d: %d\n", i, as32()->dims_[i]);
489     }
490   } else if (tag_ == REP_OUT_OF_LINE) {
491     fprintf(stderr, "REP_OUT_OF_LINE NDIMS: %d %p\n", ndims_, as16()->dims_);
492     for (int i = 0; i < ndims_byte(); i++) {
493       fprintf(stderr, "dim %d: %lld\n", i, (*as64()->dims_)[i]);
494     }
495   }
496 #endif
497 }
498 
499 template <class Shape>
begin() const500 TensorShapeIter<Shape> TensorShapeBase<Shape>::begin() const {
501   return TensorShapeIter<Shape>(static_cast<const Shape*>(this), 0);
502 }
503 
504 template <class Shape>
end() const505 TensorShapeIter<Shape> TensorShapeBase<Shape>::end() const {
506   CHECK(!unknown_rank());
507   return TensorShapeIter<Shape>(static_cast<const Shape*>(this), dims());
508 }
509 
DebugString() const510 string TensorShapeRep::DebugString() const {
511   const auto& shape = *static_cast<const PartialTensorShape*>(this);
512   if (shape.unknown_rank()) return "<unknown>";
513   string s = "[";
514   for (int i = 0; i < shape.dims(); i++) {
515     if (i > 0) strings::StrAppend(&s, ",");
516     int64 dim = shape.dim_size(i);
517     if (dim < 0) {
518       strings::StrAppend(&s, "?");
519     } else {
520       strings::StrAppend(&s, dim);
521     }
522   }
523   strings::StrAppend(&s, "]");
524   return s;
525 }
526 
DebugString(const TensorShapeProto & proto)527 string TensorShapeRep::DebugString(const TensorShapeProto& proto) {
528   string s;
529   if (proto.unknown_rank()) {
530     strings::StrAppend(&s, "<unknown>");
531     if (proto.dim_size() == 0) return s;
532   }
533   strings::StrAppend(&s, "[");
534   bool first = true;
535   for (const auto& d : proto.dim()) {
536     if (!first) strings::StrAppend(&s, ",");
537     if (d.size() == -1) {
538       strings::StrAppend(&s, "?");
539     } else {
540       strings::StrAppend(&s, d.size());
541     }
542     first = false;
543   }
544   strings::StrAppend(&s, "]");
545   return s;
546 }
547 
StartsWith(const TensorShape & shape,const TensorShape & prefix)548 bool TensorShapeUtils::StartsWith(const TensorShape& shape,
549                                   const TensorShape& prefix) {
550   if (shape.dims() < prefix.dims()) return false;
551   for (int i = 0; i < prefix.dims(); ++i) {
552     if (shape.dim_size(i) != prefix.dim_size(i)) return false;
553   }
554   return true;
555 }
556 
EndsWith(const TensorShape & shape,const TensorShape & suffix)557 bool TensorShapeUtils::EndsWith(const TensorShape& shape,
558                                 const TensorShape& suffix) {
559   const int suffix_size = suffix.dims();
560   if (shape.dims() < suffix_size) return false;
561   for (int i = 0; i < suffix_size; ++i) {
562     if (shape.dim_size(shape.dims() - suffix_size + i) != suffix.dim_size(i)) {
563       return false;
564     }
565   }
566   return true;
567 }
568 
569 template <typename T, class Shape>
MakeShapeHelper(const T * dims,int64 n,Shape * out)570 Status MakeShapeHelper(const T* dims, int64 n, Shape* out) {
571   out->Clear();
572   if (n > TensorShape::MaxDimensions()) {
573     return errors::InvalidArgument("Too many dimensions");
574   }
575   if (n < 0) {
576     return errors::InvalidArgument("Negative number of dimensions ", n);
577   }
578   for (int64 i = 0; i < n; ++i) {
579     T dim = internal::SubtleMustCopy(dims[i]);
580     int64 new_num_elements;
581     if (dim < 0) {
582       if (!out->kIsPartial) {
583         return errors::InvalidArgument("Dimension ", dim, " must be >= 0");
584       }
585       if (dim < -1) {
586         return errors::InvalidArgument("Dimension ", dim, " must be >= -1");
587       }
588       dim = -1;
589       new_num_elements = -1;
590     } else if (out->num_elements() < 0) {
591       new_num_elements = -1;
592     } else {
593       new_num_elements = MultiplyWithoutOverflow(out->num_elements(), dim);
594       if (TF_PREDICT_FALSE(new_num_elements < 0)) {
595         TensorShapeProto proto;
596         for (int64 j = 0; j < n; ++j) {
597           proto.add_dim()->set_size(dim);
598         }
599         return errors::InvalidArgument(
600             "Shape ", TensorShape::DebugString(proto),
601             " would have more than 2**63 - 1 elements");
602       }
603     }
604     out->UnsafeAddDim(dim, new_num_elements);
605   }
606   return Status::OK();
607 }
608 
609 #define MAKE_SHAPE(T, Shape)                                                 \
610   Status TensorShapeUtils::MakeShape(const T* dims, int64 n, Shape* out) {   \
611     return MakeShapeHelper(dims, n, out);                                    \
612   }                                                                          \
613   Status TensorShapeUtils::MakeShape(gtl::ArraySlice<T> shape, Shape* out) { \
614     return MakeShapeHelper(shape.data(), shape.size(), out);                 \
615   }
MAKE_SHAPE(int32,TensorShape)616 MAKE_SHAPE(int32, TensorShape)
617 MAKE_SHAPE(int64, TensorShape)
618 MAKE_SHAPE(int32, PartialTensorShape)
619 MAKE_SHAPE(int64, PartialTensorShape)
620 #undef MAKE_SHAPE
621 
622 string TensorShapeUtils::ShapeListString(
623     const gtl::ArraySlice<TensorShape>& shapes) {
624   string result = "[";
625   bool first = true;
626   for (const TensorShape& shape : shapes) {
627     strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString());
628     first = false;
629   }
630   strings::StrAppend(&result, "]");
631   return result;
632 }
633 
Concatenate(int64 size) const634 PartialTensorShape PartialTensorShape::Concatenate(int64 size) const {
635   PartialTensorShape out = *this;
636   out.AddDim(size);
637   return out;
638 }
639 
// Returns this shape followed by the dimensions of `shape`. If either has
// unknown rank, the result has unknown rank.
PartialTensorShape PartialTensorShape::Concatenate(
    const PartialTensorShape& shape) const {
  if (unknown_rank() || shape.unknown_rank()) return PartialTensorShape();
  PartialTensorShape result(*this);
  result.AppendShape(shape);
  return result;
}
649 
MergeWith(const PartialTensorShape & shape,PartialTensorShape * result) const650 Status PartialTensorShape::MergeWith(const PartialTensorShape& shape,
651                                      PartialTensorShape* result) const {
652   if (unknown_rank()) {
653     *result = shape;
654     return Status::OK();
655   }
656   if (shape.unknown_rank()) {
657     *result = *this;
658     return Status::OK();
659   }
660   const int dims_ = dims();
661   if (dims_ != shape.dims()) {
662     return errors::InvalidArgument(
663         "PartialTensorShape: Incompatible ranks during merge: ", dims_, " vs. ",
664         shape.dims());
665   }
666   CHECK(result != this);
667   result->Clear();
668   for (int i = 0; i < dims_; ++i) {
669     const int64 dim0 = dim_size(i);
670     const int64 dim1 = shape.dim_size(i);
671     if (dim0 >= 0 && dim1 >= 0 && dim0 != dim1) {
672       return errors::InvalidArgument(
673           "PartialTensorShape: Incompatible shapes during merge: ",
674           DebugString(), " vs. ", shape.DebugString());
675     }
676     result->AddDim(dim0 >= 0 ? dim0 : dim1);
677   }
678   return Status::OK();
679 }
680 
AsTensorShape(TensorShape * shape) const681 bool PartialTensorShape::AsTensorShape(TensorShape* shape) const {
682   if (IsFullyDefined()) {
683     const TensorShapeRep* rep = this;
684     *shape = *static_cast<const TensorShape*>(rep);
685     return true;
686   }
687   return false;
688 }
689 
IsIdenticalTo(const PartialTensorShape & shape) const690 bool PartialTensorShape::IsIdenticalTo(const PartialTensorShape& shape) const {
691   if (unknown_rank() || shape.unknown_rank()) {
692     return unknown_rank() == shape.unknown_rank();
693   }
694   if (dims() != shape.dims()) return false;
695   for (int i = 0; i < dims(); i++) {
696     if (dim_size(i) != shape.dim_size(i)) return false;
697   }
698   return true;
699 }
700 
IsCompatibleWith(const PartialTensorShape & shape) const701 bool PartialTensorShape::IsCompatibleWith(
702     const PartialTensorShape& shape) const {
703   if (unknown_rank() || shape.unknown_rank()) return true;
704   if (dims() != shape.dims()) return false;
705   for (int i = 0; i < dims(); i++) {
706     const int64 dim0 = dim_size(i);
707     const int64 dim1 = shape.dim_size(i);
708     if (dim0 >= 0 && dim1 >= 0 && dim0 != dim1) return false;
709   }
710   return true;
711 }
712 
PartialShapeListString(const gtl::ArraySlice<PartialTensorShape> & shapes)713 string PartialTensorShapeUtils::PartialShapeListString(
714     const gtl::ArraySlice<PartialTensorShape>& shapes) {
715   string result = "[";
716   bool first = true;
717   for (const PartialTensorShape& shape : shapes) {
718     strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString());
719     first = false;
720   }
721   strings::StrAppend(&result, "]");
722   return result;
723 }
724 
AreCompatible(const gtl::ArraySlice<PartialTensorShape> & shapes0,const gtl::ArraySlice<PartialTensorShape> & shapes1)725 bool PartialTensorShapeUtils::AreCompatible(
726     const gtl::ArraySlice<PartialTensorShape>& shapes0,
727     const gtl::ArraySlice<PartialTensorShape>& shapes1) {
728   if (shapes0.size() == shapes1.size()) {
729     for (size_t i = 0; i < shapes0.size(); ++i) {
730       if (!shapes0[i].IsCompatibleWith(shapes1[i])) {
731         return false;
732       }
733     }
734     return true;
735   } else {
736     return false;
737   }
738 }
739 
AreIdentical(const gtl::ArraySlice<PartialTensorShape> & shapes0,const gtl::ArraySlice<PartialTensorShape> & shapes1)740 bool PartialTensorShapeUtils::AreIdentical(
741     const gtl::ArraySlice<PartialTensorShape>& shapes0,
742     const gtl::ArraySlice<PartialTensorShape>& shapes1) {
743   if (shapes0.size() == shapes1.size()) {
744     for (size_t i = 0; i < shapes0.size(); ++i) {
745       if (!shapes0[i].IsIdenticalTo(shapes1[i])) {
746         return false;
747       }
748     }
749     return true;
750   } else {
751     return false;
752   }
753 }
754 
NumElements(gtl::ArraySlice<int64> shape,int64 * num_elements)755 Status TensorShapeUtils::NumElements(gtl::ArraySlice<int64> shape,
756                                      int64* num_elements) {
757   int64 n = 1;
758   for (auto dim : shape) {
759     n = MultiplyWithoutOverflow(n, dim);
760     if (n < 0) {
761       return errors::InvalidArgument("Can't compute total size of shape [",
762                                      str_util::Join(shape, ","),
763                                      "]; product would overflow int64");
764     }
765   }
766   *num_elements = n;
767   return Status::OK();
768 }
769 
770 template class TensorShapeBase<TensorShape>;
771 template class TensorShapeBase<PartialTensorShape>;
772 
773 }  // namespace tensorflow
774