/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/tensor_shape.h"

#include <limits>

#include "absl/strings/str_join.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/overflow.h"

namespace tensorflow {

// TensorShape and PartialTensorShape should have no fields beyond
// TensorShapeRep.  In particular, their sizes should be the same.
static_assert(sizeof(TensorShapeRep) == sizeof(TensorShape),
              "TensorShape must have no fields beyond TensorShapeRep");
static_assert(sizeof(TensorShapeRep) == sizeof(PartialTensorShape),
              "PartialTensorShape must have no fields beyond TensorShapeRep");

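// Appends the dimension sizes of `s` to `*vals`.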
template <class Shape>
static void AppendTo(const TensorShapeBase<Shape>& s,
                     gtl::InlinedVector<int64, 8>* vals) {
  for (auto dim : s) {
    vals->push_back(dim.size);
  }
}

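// CHECK-fail helpers for callers that request a view of the shape with a
// specific number of dimensions; they crash with a descriptive message on
// mismatch.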
void TensorShape::CheckDimsEqual(int NDIMS) const {
  CHECK_EQ(NDIMS, dims()) << "Asking for tensor of " << NDIMS << " dimensions"
                          << " from a tensor of " << dims() << " dimensions";
}

void TensorShape::CheckDimsAtLeast(int NDIMS) const {
  CHECK_GE(NDIMS, dims()) << "Asking for tensor of at least " << NDIMS
                          << " dimensions from a tensor of " << dims()
                          << " dimensions";
}

// TODO(slebedev): Consider merging IsValid implementations.
template <class Shape>
bool TensorShapeBase<Shape>::IsValid() {
  // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
  // unknown_shape() set, and it seems hard to remove this without backwards
  // compatibility issues.
  if (kIsPartial && unknown_rank()) return dims() == 0;
  int64 num_elements = 1;
  if (dims() > MaxDimensions()) return false;
  for (auto d : dim_sizes()) {
    if (d < (kIsPartial ? -1 : 0)) return false;
    if (d == -1) {
      num_elements = -1;
    } else if (!kIsPartial || num_elements >= 0) {
      num_elements = MultiplyWithoutOverflow(num_elements, d);
      if (num_elements < 0) return false;
    }
  }
  return true;
}

template <class Shape>
bool TensorShapeBase<Shape>::IsValid(const TensorShapeProto& proto) {
  // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
  // unknown_shape() set, and it seems hard to remove this without backwards
  // compatibility issues.
  if (kIsPartial && proto.unknown_rank()) return proto.dim_size() == 0;
  int64 num_elements = 1;
  if (proto.dim().size() > MaxDimensions()) return false;
  for (const auto& d : proto.dim()) {
    if (d.size() < (kIsPartial ? -1 : 0)) return false;
    if (d.size() == -1) {
      num_elements = -1;
    } else if (!kIsPartial || num_elements >= 0) {
      num_elements = MultiplyWithoutOverflow(num_elements, d.size());
      if (num_elements < 0) return false;
    }
  }
  return true;
}

template <class Shape>
Status TensorShapeBase<Shape>::IsValidShape(const TensorShapeProto& proto) {
  // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
  // unknown_shape() set, and it seems hard to remove this without backwards
  // compatibility issues.
  if (kIsPartial && proto.unknown_rank()) {
    if (proto.dim_size() > 0) {
      return errors::InvalidArgument(
          "An unknown shape must not have any dimensions set.");
    }
    return Status::OK();
  }
  int64 num_elements = 1;
  if (proto.dim().size() > MaxDimensions()) {
    return errors::InvalidArgument("Shape ", DebugString(proto),
                                   " has too many dimensions");
  }
  for (const auto& d : proto.dim()) {
    if (d.size() < (kIsPartial ? -1 : 0)) {
      if (kIsPartial) {
        return errors::InvalidArgument(
            "Shape ", DebugString(proto),
            " has dimensions with values below -1 (where -1 means unknown)");
      } else {
        return errors::InvalidArgument("Shape ", DebugString(proto),
                                       " is not fully defined");
      }
    }
    if (d.size() == -1) {
      num_elements = -1;
    } else if (!kIsPartial || num_elements >= 0) {
      num_elements = MultiplyWithoutOverflow(num_elements, d.size());
      if (num_elements < 0) {
        return errors::InvalidArgument(
            "Shape ", DebugString(proto),
            " is too large (more than 2**63 - 1 entries)");
      }
    }
  }
  return Status::OK();
}

template <class Shape>
TensorShapeBase<Shape>::TensorShapeBase(const TensorShapeProto& proto) {
  set_tag(REP16);
  set_data_type(DT_INVALID);
  // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
  // unknown_shape() set, and it seems hard to remove this without backwards
  // compatibility issues.
  if (kIsPartial && proto.unknown_rank()) {
    set_ndims_byte(kUnknownRank);
    set_num_elements(-1);
  } else {
    set_ndims_byte(0);
    set_num_elements(1);
    for (const auto& d : proto.dim()) {
      AddDim(d.size());
    }
  }
}

template <class Shape>
Status TensorShapeBase<Shape>::BuildTensorShapeBase(
    const TensorShapeProto& proto, TensorShapeBase* out) {
  out->set_tag(REP16);
  out->set_data_type(DT_INVALID);
  // NOTE(irving): Unfortunately, TensorShape allows parsing protos with
  // unknown_shape() set, and it seems hard to remove this without backwards
  // compatibility issues.
  if (kIsPartial && proto.unknown_rank()) {
    out->set_ndims_byte(kUnknownRank);
    out->set_num_elements(-1);
  } else {
    out->set_ndims_byte(0);
    out->set_num_elements(1);
    Status s = Status::OK();
    for (const auto& d : proto.dim()) {
      s = out->AddDimWithStatus(d.size());
      if (!s.ok()) {
        return s;
      }
    }
  }
  return Status::OK();
}

template <class Shape>
TensorShapeBase<Shape>::TensorShapeBase(gtl::ArraySlice<int64> dim_sizes) {
  set_tag(REP16);
  set_data_type(DT_INVALID);
  TF_CHECK_OK(InitDims(dim_sizes));
}

template <class Shape>
Status TensorShapeBase<Shape>::BuildTensorShapeBase(
    gtl::ArraySlice<int64> dim_sizes, TensorShapeBase* out) {
  out->set_tag(REP16);
  out->set_data_type(DT_INVALID);
  return out->InitDims(dim_sizes);
}

// Returns true iff partial is true and val is < 0.
// REQUIRES: val < kMaxRep16
// REQUIRES: partial || val >= 0
static inline bool Set16(bool partial, uint16* dst, int dim, int64 val) {
  if (partial) {
    if (val < 0) {
      dst[dim] = std::numeric_limits<uint16>::max();
      return true;
    }
  }
  dst[dim] = val;
  return false;
}

template <class Shape>
Status TensorShapeBase<Shape>::InitDims(gtl::ArraySlice<int64> dim_sizes) {
  DCHECK_EQ(tag(), REP16);

  // Allow sizes that are under kint64max^0.25 so that 4-way multiplication
  // below cannot overflow.
  static const int64 kMaxSmall = 0xd744;
  static_assert(kMaxSmall * kMaxSmall * kMaxSmall * kMaxSmall <= kint64max,
                "bad overflow check");
  bool large_size = false;
  for (auto s : dim_sizes) {
    if (s > kMaxSmall) {
      large_size = true;
      break;
    }
  }

  if (!kIsPartial && !large_size) {
    for (auto s : dim_sizes) {
      if (TF_PREDICT_FALSE(s < 0)) {
        return errors::Internal(
            "Expected shape dimensions to be non-negative, got ", s);
      }
    }
  }

  if (!large_size) {
    // Every size fits in 16 bits; use fast-paths for dims in {1,2,3,4}.
    uint16* dst = as16()->dims_;
    switch (dim_sizes.size()) {
      case 1: {
        set_ndims_byte(1);
        const int64 size = dim_sizes[0];
        const bool neg = Set16(kIsPartial, dst, 0, size);
        set_num_elements(neg ? -1 : size);
        return Status::OK();
      }
      case 2: {
        set_ndims_byte(2);
        const int64 size0 = dim_sizes[0];
        const int64 size1 = dim_sizes[1];
        bool neg = Set16(kIsPartial, dst, 0, size0);
        neg |= Set16(kIsPartial, dst, 1, size1);
        set_num_elements(neg ? -1 : (size0 * size1));
        return Status::OK();
      }
      case 3: {
        set_ndims_byte(3);
        const int64 size0 = dim_sizes[0];
        const int64 size1 = dim_sizes[1];
        const int64 size2 = dim_sizes[2];
        bool neg = Set16(kIsPartial, dst, 0, size0);
        neg |= Set16(kIsPartial, dst, 1, size1);
        neg |= Set16(kIsPartial, dst, 2, size2);
        set_num_elements(neg ? -1 : (size0 * size1 * size2));
        return Status::OK();
      }
      case 4: {
        set_ndims_byte(4);
        const int64 size0 = dim_sizes[0];
        const int64 size1 = dim_sizes[1];
        const int64 size2 = dim_sizes[2];
        const int64 size3 = dim_sizes[3];
        bool neg = Set16(kIsPartial, dst, 0, size0);
        neg |= Set16(kIsPartial, dst, 1, size1);
        neg |= Set16(kIsPartial, dst, 2, size2);
        neg |= Set16(kIsPartial, dst, 3, size3);
        set_num_elements(neg ? -1 : (size0 * size1 * size2 * size3));
        return Status::OK();
      }
    }
  }

  set_ndims_byte(0);
  set_num_elements(1);
  Status status = Status::OK();
  for (int64 s : dim_sizes) {
    status.Update(AddDimWithStatus(internal::SubtleMustCopy(s)));
    if (!status.ok()) {
      return status;
    }
  }

  return status;
}

template <class Shape>
TensorShapeBase<Shape>::TensorShapeBase() {
  set_tag(REP16);
  set_data_type(DT_INVALID);
  if (kIsPartial) {
    set_ndims_byte(kUnknownRank);
    set_num_elements(-1);
  } else {
    set_ndims_byte(0);
    set_num_elements(1);
  }
}

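// Slow-path destructor helper: frees the heap-allocated dimension vector used
// by the REP_OUT_OF_LINE representation.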
void TensorShapeRep::DestructorOutOfLine() {
  DCHECK(tag() == REP_OUT_OF_LINE);
  delete as64()->dims_;
}

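// Copy slow path: handles the cases where this shape and/or `b` stores its
// dimensions out of line, freeing or reusing any heap-allocated dimension
// vector as appropriate.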
void TensorShapeRep::SlowCopyFrom(const TensorShapeRep& b) {
  if (b.tag() != REP_OUT_OF_LINE) {
    if (tag() == REP_OUT_OF_LINE) {
      delete as64()->dims_;
    }
    memcpy(buf(), b.buf(), sizeof(u_.buf));
    // memcpy above implicitly also does:
    //   set_tag(b.tag());
    //   set_ndims_byte(b.ndims_byte());
    //   set_data_type(b.data_type());
  } else {
    set_ndims_byte(b.ndims_byte());
    set_data_type(b.data_type());
    if (tag() == REP_OUT_OF_LINE) {
      // vector already allocated
      *(as64()->dims_) = *(b.as64()->dims_);
    } else {
      set_tag(REP_OUT_OF_LINE);
      as64()->dims_ = new gtl::InlinedVector<int64, 4>(*(b.as64()->dims_));
    }
  }
}

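// Returns the size of dimension d, translating each representation's sentinel
// for an unknown dimension back to -1 for partial shapes.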
template <class Shape>
int64 TensorShapeBase<Shape>::dim_size(int d) const {
  if (unknown_rank()) return -1;
  DCHECK_GE(d, 0);
  DCHECK_LT(d, dims());
  if (tag() == REP16) {
    uint16 dim = as16()->dims_[d];
    if (kIsPartial && dim == kUnknownRep16) return -1;
    return dim;
  } else if (tag() == REP32) {
    uint32 dim = as32()->dims_[d];
    if (kIsPartial && dim == kUnknownRep32) return -1;
    return dim;
  } else {
    return (*as64()->dims_)[d];
  }
}

void TensorShapeRep::Clear() {
  ClearAllButDataType();
  set_data_type(DT_INVALID);
}

void TensorShapeRep::ClearAllButDataType() {
  if (tag() == REP_OUT_OF_LINE) {
    delete as64()->dims_;
  }
  set_tag(REP16);
  set_ndims_byte(0);
  // Leaves data_type alone
  set_num_elements(1);
}

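// Recomputes the cached element count from the current dimension sizes. For
// partial shapes any unknown dimension makes the count unknown (-1); overflow
// is reported as an error.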
template <class Shape>
Status TensorShapeBase<Shape>::RecomputeNumElements() {
  if (unknown_rank()) {
    set_num_elements(-1);
    return Status::OK();
  }
  int64 n = 1;
  for (auto dim : *this) {
    if (kIsPartial && dim.size < 0) {
      n = -1;
      break;
    }
    n = MultiplyWithoutOverflow(n, dim.size);
    if (TF_PREDICT_FALSE(n < 0)) {
      return errors::InvalidArgument(
          "Shape ", this->DebugString(),
          " results in overflow when computing number of elements");
    }
  }
  set_num_elements(n);
  return Status::OK();
}

template <class Shape>
void TensorShapeBase<Shape>::AddDim(int64 size) {
  if (!kIsPartial) CHECK_GE(size, 0);
  if (unknown_rank()) return;
  CHECK_LT(ndims_byte(), MaxDimensions()) << "Too many dimensions in tensor";
  int64 new_num_elements;
  if (kIsPartial && (num_elements() < 0 || size < 0)) {
    new_num_elements = -1;
  } else {
    new_num_elements = MultiplyWithoutOverflow(num_elements(), size);
    CHECK_LE(0, new_num_elements);
  }
  UnsafeAddDim(size, new_num_elements);
}

template <class Shape>
Status TensorShapeBase<Shape>::AddDimWithStatus(int64 size) {
  if (!kIsPartial) {
    if (TF_PREDICT_FALSE(size < 0)) {
      return errors::Internal("Expected a non-negative size, got ", size);
    }
  }

  if (unknown_rank()) {
    return Status::OK();
  }

  if (TF_PREDICT_FALSE(ndims_byte() >= MaxDimensions())) {
    return errors::Internal("Too many dimensions in tensor");
  }

  int64 new_num_elements;
  if (kIsPartial && (num_elements() < 0 || size < 0)) {
    new_num_elements = -1;
  } else {
    new_num_elements = MultiplyWithoutOverflow(num_elements(), size);
    if (TF_PREDICT_FALSE(new_num_elements < 0)) {
      return errors::Internal("Encountered overflow when multiplying ",
                              num_elements(), " with ", size,
                              ", result: ", new_num_elements);
    }
  }

  UnsafeAddDim(size, new_num_elements);
  return Status::OK();
}

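// Appends a dimension without re-validating `size`; the caller supplies the
// already-computed element count. Switches to a wider internal representation
// (REP32 or the out-of-line vector) when the new dimension does not fit.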
template <class Shape>
void TensorShapeBase<Shape>::UnsafeAddDim(int64 size, int64 new_num_elements) {
  const int nd = ndims_byte();
  if (tag() == REP16 && nd < 6 && size < kMaxRep16) {
    as16()->dims_[nd] =
        kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size);
  } else if (tag() == REP32 && nd < 3 && size < kMaxRep32) {
    as32()->dims_[nd] =
        kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size);
  } else if (tag() == REP_OUT_OF_LINE) {
    as64()->dims_->push_back(size);
  } else {
    // Need to change representation
    gtl::InlinedVector<int64, 8> vals;
    AppendTo(*this, &vals);
    vals.push_back(size);
    // We know we can't be REP16.  See if we have a small enough
    // number of dimensions and each dimension's size is small enough
    // to allow REP32.
    bool can_be_rep32 = (vals.size() <= 3);
    if (can_be_rep32) {
      for (size_t i = 0; i < vals.size(); i++) {
        if (vals[i] >= kMaxRep32) {
          can_be_rep32 = false;
          break;
        }
      }
    }
    if (can_be_rep32) {
      set_tag(REP32);
      for (size_t d = 0; d < vals.size(); d++) {
        as32()->dims_[d] = kIsPartial && vals[d] < 0
                               ? kUnknownRep32
                               : static_cast<uint32>(vals[d]);
      }
    } else {
      set_tag(REP_OUT_OF_LINE);
      as64()->dims_ =
          new gtl::InlinedVector<int64, 4>(vals.begin(), vals.end());
    }
  }
  set_ndims_byte(nd + 1);
  set_num_elements(new_num_elements);
}

template <class Shape>
void TensorShapeBase<Shape>::AppendShape(const TensorShapeBase& shape) {
  for (auto d : shape) AddDim(d.size);
}

template <class Shape>
Status TensorShapeBase<Shape>::AppendShapeWithStatus(
    const TensorShapeBase& shape) {
  Status s = Status::OK();
  for (auto d : shape) {
    s.Update(AddDimWithStatus(d.size));
    if (!s.ok()) {
      return s;
    }
  }
  return s;
}

template <class Shape>
void TensorShapeBase<Shape>::InsertDim(int d, int64 size) {
  CHECK_GE(d, 0);
  CHECK_LE(d, dims());
  if (!kIsPartial) CHECK_GE(size, 0);
  CHECK_LT(dims(), MaxDimensions());
  gtl::InlinedVector<int64, 8> vals;
  AppendTo(*this, &vals);
  vals.insert(vals.begin() + d, size);
  ClearAllButDataType();
  for (auto dval : vals) {
    AddDim(dval);
  }
}

template <class Shape>
Status TensorShapeBase<Shape>::InsertDimWithStatus(int d, int64 size) {
  if (!kIsPartial) {
    if (TF_PREDICT_FALSE(size < 0)) {
      return errors::Internal("Expected a non-negative size, got ", size);
    }
  }

  if (TF_PREDICT_FALSE(d < 0)) {
    return errors::Internal("The insertion index must be non-negative, got ",
                            d);
  }
  if (TF_PREDICT_FALSE(d > dims())) {
    return errors::Internal("The insertion index must be at most ", dims(),
                            ", got ", d);
  }
  if (TF_PREDICT_FALSE(dims() >= MaxDimensions())) {
    return errors::Internal("Shape has ", dims(),
                            " dimensions which is the maximum allowed");
  }

  gtl::InlinedVector<int64, 8> vals;
  AppendTo(*this, &vals);
  vals.insert(vals.begin() + d, size);
  ClearAllButDataType();

  Status s = Status::OK();
  for (auto dval : vals) {
    s.Update(AddDimWithStatus(dval));
    if (!s.ok()) {
      return s;
    }
  }
  return s;
}

template <class Shape>
gtl::InlinedVector<int64, 4> TensorShapeBase<Shape>::dim_sizes() const {
  gtl::InlinedVector<int64, 4> result;
  for (auto dim : *this) {
    result.push_back(dim.size);
  }
  return result;
}

template <class Shape>
void TensorShapeBase<Shape>::set_dim(int d, int64 size) {
  CHECK_GE(d, 0);
  CHECK_LT(d, dims());
  CHECK_GE(size, 0);
  if (tag() == REP16 && size < kMaxRep16) {
    as16()->dims_[d] =
        kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size);
  } else if (tag() == REP32 && size < kMaxRep32) {
    as32()->dims_[d] =
        kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size);
  } else if (tag() == REP_OUT_OF_LINE) {
    (*as64()->dims_)[d] = size;
  } else {
    // Must upgrade
    gtl::InlinedVector<int64, 8> vals;
    AppendTo(*this, &vals);
    vals[d] = size;
    ClearAllButDataType();
    for (auto dval : vals) {
      AddDim(dval);
    }
  }
  TF_CHECK_OK(RecomputeNumElements());
}

template <class Shape>
Status TensorShapeBase<Shape>::SetDimWithStatus(int d, int64 size) {
  if (TF_PREDICT_FALSE(d < 0)) {
    return errors::Internal("Index must be non-negative, got ", d);
  }
  if (TF_PREDICT_FALSE(d >= dims())) {
    return errors::Internal("Index must be less than ", dims(), ", got ", d);
  }
  if (TF_PREDICT_FALSE(size < 0)) {
    return errors::Internal("Expected a non-negative size, got ", size);
  }

  if (tag() == REP16 && size < kMaxRep16) {
    as16()->dims_[d] =
        kIsPartial && size < 0 ? kUnknownRep16 : static_cast<uint16>(size);
  } else if (tag() == REP32 && size < kMaxRep32) {
    as32()->dims_[d] =
        kIsPartial && size < 0 ? kUnknownRep32 : static_cast<uint32>(size);
  } else if (tag() == REP_OUT_OF_LINE) {
    (*as64()->dims_)[d] = size;
  } else {
    // Must upgrade
    gtl::InlinedVector<int64, 8> vals;
    AppendTo(*this, &vals);
    vals[d] = size;
    ClearAllButDataType();

    Status s = Status::OK();
    for (auto dval : vals) {
      s.Update(AddDimWithStatus(dval));
      if (!s.ok()) {
        return s;
      }
    }
  }

  return RecomputeNumElements();
}

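// Removes dimensions [begin, end). Negative indices are interpreted as
// dims() + index + 1 (so -1 refers to the position just past the last
// dimension).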
template <class Shape>
void TensorShapeBase<Shape>::RemoveDimRange(int begin, int end) {
  if (unknown_rank()) return;
  begin = begin < 0 ? dims() + begin + 1 : begin;
  end = end < 0 ? dims() + end + 1 : end;
  CHECK_GE(begin, 0);
  CHECK_LE(begin, dims());
  CHECK_GE(end, 0);
  CHECK_LE(end, dims());
  if (begin >= end) return;
  gtl::InlinedVector<int64, 8> vals;
  AppendTo(*this, &vals);
  vals.erase(vals.begin() + begin, vals.begin() + end);
  ClearAllButDataType();
  for (auto dval : vals) {
    AddDim(dval);
  }
  TF_CHECK_OK(RecomputeNumElements());
}

template <class Shape>
Status TensorShapeBase<Shape>::RemoveDimRangeWithStatus(int begin, int end) {
  if (unknown_rank()) {
    return Status::OK();
  }

  begin = begin < 0 ? dims() + begin + 1 : begin;
  end = end < 0 ? dims() + end + 1 : end;

  if (TF_PREDICT_FALSE(begin < 0)) {
    return errors::Internal("Start index must be non-negative, got ", begin);
  }
  if (TF_PREDICT_FALSE(begin > dims())) {
    return errors::Internal("Start index must be at most ", dims(), ", got ",
                            begin);
  }
  if (TF_PREDICT_FALSE(end < 0)) {
    return errors::Internal("End index must be non-negative, got ", end);
  }
  if (TF_PREDICT_FALSE(end > dims())) {
    return errors::Internal("End index must be at most ", dims(), ", got ",
                            end);
  }

  if (begin >= end) {
    return Status::OK();
  }

  gtl::InlinedVector<int64, 8> vals;
  AppendTo(*this, &vals);
  vals.erase(vals.begin() + begin, vals.begin() + end);
  ClearAllButDataType();

  Status s = Status::OK();
  for (auto dval : vals) {
    s.Update(AddDimWithStatus(dval));
    if (!s.ok()) {
      return s;
    }
  }

  return RecomputeNumElements();
}

bool TensorShape::IsSameSize(const TensorShape& b) const {
  if (b.dims() != dims()) return false;
  for (int d = 0; d < dims(); d++) {
    if (dim_size(d) != b.dim_size(d)) return false;
  }
  return true;
}

template <class Shape>
void TensorShapeBase<Shape>::AsProto(TensorShapeProto* proto) const {
  proto->Clear();
  if (unknown_rank()) {
    proto->set_unknown_rank(true);
  } else {
    for (int i = 0; i < dims(); i++) {
      proto->add_dim()->set_size(dim_size(i));
    }
  }
}

template <class Shape>
TensorShapeIter<Shape> TensorShapeBase<Shape>::begin() const {
  return TensorShapeIter<Shape>(static_cast<const Shape*>(this), 0);
}

template <class Shape>
TensorShapeIter<Shape> TensorShapeBase<Shape>::end() const {
  const int max_dim = unknown_rank() ? -1 : dims();
  return TensorShapeIter<Shape>(static_cast<const Shape*>(this), max_dim);
}

string TensorShapeRep::DebugString() const {
  const auto& shape = *static_cast<const PartialTensorShape*>(this);
  if (shape.unknown_rank()) return "<unknown>";
  string s = "[";
  for (int i = 0; i < shape.dims(); i++) {
    if (i > 0) strings::StrAppend(&s, ",");
    int64 dim = shape.dim_size(i);
    if (dim < 0) {
      strings::StrAppend(&s, "?");
    } else {
      strings::StrAppend(&s, dim);
    }
  }
  strings::StrAppend(&s, "]");
  return s;
}

string TensorShapeRep::DebugString(const TensorShapeProto& proto) {
  string s;
  if (proto.unknown_rank()) {
    strings::StrAppend(&s, "<unknown>");
    if (proto.dim_size() == 0) return s;
  }
  strings::StrAppend(&s, "[");
  bool first = true;
  for (const auto& d : proto.dim()) {
    if (!first) strings::StrAppend(&s, ",");
    if (d.size() == -1) {
      strings::StrAppend(&s, "?");
    } else {
      strings::StrAppend(&s, d.size());
    }
    first = false;
  }
  strings::StrAppend(&s, "]");
  return s;
}

bool TensorShapeUtils::StartsWith(const TensorShape& shape,
                                  const TensorShape& prefix) {
  if (shape.dims() < prefix.dims()) return false;
  for (int i = 0; i < prefix.dims(); ++i) {
    if (shape.dim_size(i) != prefix.dim_size(i)) return false;
  }
  return true;
}

bool TensorShapeUtils::EndsWith(const TensorShape& shape,
                                const TensorShape& suffix) {
  const int suffix_size = suffix.dims();
  if (shape.dims() < suffix_size) return false;
  for (int i = 0; i < suffix_size; ++i) {
    if (shape.dim_size(shape.dims() - suffix_size + i) != suffix.dim_size(i)) {
      return false;
    }
  }
  return true;
}

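// Shared implementation for TensorShapeUtils::MakeShape: builds `out` from a
// raw array of `n` dimension sizes, rejecting negative sizes (except -1 for
// partial shapes) and products that would overflow int64.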
template <typename T, class Shape>
Status MakeShapeHelper(const T* dims, int64 n, Shape* out) {
  out->Clear();
  if (n > TensorShape::MaxDimensions()) {
    return errors::InvalidArgument("Too many dimensions");
  }
  if (n < 0) {
    return errors::InvalidArgument("Negative number of dimensions ", n);
  }
  for (int64 i = 0; i < n; ++i) {
    T dim = internal::SubtleMustCopy(dims[i]);
    int64 new_num_elements;
    if (dim < 0) {
      if (!out->kIsPartial) {
        return errors::InvalidArgument("Dimension ", dim, " must be >= 0");
      }
      if (dim < -1) {
        return errors::InvalidArgument("Dimension ", dim, " must be >= -1");
      }
      dim = -1;
      new_num_elements = -1;
    } else if (out->num_elements() < 0) {
      new_num_elements = -1;
    } else {
      new_num_elements = MultiplyWithoutOverflow(out->num_elements(), dim);
      if (TF_PREDICT_FALSE(new_num_elements < 0)) {
        TensorShapeProto proto;
        for (int64 j = 0; j < n; ++j) {
          proto.add_dim()->set_size(internal::SubtleMustCopy(dims[j]));
        }
        return errors::InvalidArgument(
            "Shape ", TensorShape::DebugString(proto),
            " would have more than 2**63 - 1 elements");
      }
    }
    out->UnsafeAddDim(dim, new_num_elements);
  }
  return Status::OK();
}

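// Stamps out the MakeShape overloads for each supported index type and shape
// class.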
#define MAKE_SHAPE(T, Shape)                                                 \
  Status TensorShapeUtils::MakeShape(const T* dims, int64 n, Shape* out) {   \
    return MakeShapeHelper(dims, n, out);                                    \
  }                                                                          \
  Status TensorShapeUtils::MakeShape(gtl::ArraySlice<T> shape, Shape* out) { \
    return MakeShapeHelper(shape.data(), shape.size(), out);                 \
  }
MAKE_SHAPE(int32, TensorShape)
MAKE_SHAPE(int64, TensorShape)
MAKE_SHAPE(int32, PartialTensorShape)
MAKE_SHAPE(int64, PartialTensorShape)
#undef MAKE_SHAPE

string TensorShapeUtils::ShapeListString(
    const gtl::ArraySlice<TensorShape>& shapes) {
  string result = "[";
  bool first = true;
  for (const TensorShape& shape : shapes) {
    strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString());
    first = false;
  }
  strings::StrAppend(&result, "]");
  return result;
}

PartialTensorShape PartialTensorShape::Concatenate(int64 size) const {
  PartialTensorShape out = *this;
  out.AddDim(size);
  return out;
}

Status PartialTensorShape::ConcatenateWithStatus(
    int64 size, PartialTensorShape* out) const {
  *out = *this;
  return out->AddDimWithStatus(size);
}

PartialTensorShape PartialTensorShape::Concatenate(
    const PartialTensorShape& shape) const {
  if (unknown_rank() || shape.unknown_rank()) {
    return PartialTensorShape();
  }
  PartialTensorShape out = *this;
  for (auto dim : shape) out.AddDim(dim.size);
  return out;
}

Status PartialTensorShape::ConcatenateWithStatus(
    const PartialTensorShape& shape, PartialTensorShape* out) const {
  if (unknown_rank() || shape.unknown_rank()) {
    *out = PartialTensorShape();
    return Status::OK();
  }
  *out = *this;
  for (auto dim : shape) {
    Status s = out->AddDimWithStatus(dim.size);
    if (!s.ok()) return s;
  }

  return Status::OK();
}

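// Merges two partial shapes dimension by dimension: a known size fills in an
// unknown (-1) one, while two conflicting known sizes produce an error.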
Status PartialTensorShape::MergeWith(const PartialTensorShape& shape,
                                     PartialTensorShape* result) const {
  if (unknown_rank()) {
    *result = shape;
    return Status::OK();
  }
  if (shape.unknown_rank()) {
    *result = *this;
    return Status::OK();
  }
  const int dims_ = dims();
  if (dims_ != shape.dims()) {
    return errors::InvalidArgument(
        "PartialTensorShape: Incompatible ranks during merge: ", dims_, " vs. ",
        shape.dims());
  }

  if (result == this) {
    return errors::Internal(
        "PartialTensorShape::MergeWith: cannot merge shape with itself");
  }

  result->Clear();
  Status s = Status::OK();
  for (int i = 0; i < dims_; ++i) {
    const int64 dim0 = dim_size(i);
    const int64 dim1 = shape.dim_size(i);
    if (dim0 >= 0 && dim1 >= 0 && dim0 != dim1) {
      return errors::InvalidArgument(
          "PartialTensorShape: Incompatible shapes during merge: ",
          DebugString(), " vs. ", shape.DebugString());
    }
    s.Update(result->AddDimWithStatus(dim0 >= 0 ? dim0 : dim1));
    if (!s.ok()) {
      return s;
    }
  }
  return Status::OK();
}

bool PartialTensorShape::AsTensorShape(TensorShape* shape) const {
  if (IsFullyDefined()) {
    const TensorShapeRep* rep = this;
    *shape = *static_cast<const TensorShape*>(rep);
    return true;
  }
  return false;
}

bool PartialTensorShape::IsIdenticalTo(const PartialTensorShape& shape) const {
  if (unknown_rank() || shape.unknown_rank()) {
    return unknown_rank() == shape.unknown_rank();
  }
  if (dims() != shape.dims()) return false;
  for (int i = 0; i < dims(); i++) {
    if (dim_size(i) != shape.dim_size(i)) return false;
  }
  return true;
}

bool PartialTensorShape::IsCompatibleWith(
    const PartialTensorShape& shape) const {
  if (unknown_rank() || shape.unknown_rank()) return true;
  if (dims() != shape.dims()) return false;
  for (int i = 0; i < dims(); i++) {
    const int64 dim0 = dim_size(i);
    const int64 dim1 = shape.dim_size(i);
    if (dim0 >= 0 && dim1 >= 0 && dim0 != dim1) return false;
  }
  return true;
}

string PartialTensorShapeUtils::PartialShapeListString(
    const gtl::ArraySlice<PartialTensorShape>& shapes) {
  string result = "[";
  bool first = true;
  for (const PartialTensorShape& shape : shapes) {
    strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString());
    first = false;
  }
  strings::StrAppend(&result, "]");
  return result;
}

bool PartialTensorShapeUtils::AreCompatible(
    const gtl::ArraySlice<PartialTensorShape>& shapes0,
    const gtl::ArraySlice<PartialTensorShape>& shapes1) {
  if (shapes0.size() == shapes1.size()) {
    for (size_t i = 0; i < shapes0.size(); ++i) {
      if (!shapes0[i].IsCompatibleWith(shapes1[i])) {
        return false;
      }
    }
    return true;
  } else {
    return false;
  }
}

bool PartialTensorShapeUtils::AreIdentical(
    const gtl::ArraySlice<PartialTensorShape>& shapes0,
    const gtl::ArraySlice<PartialTensorShape>& shapes1) {
  if (shapes0.size() == shapes1.size()) {
    for (size_t i = 0; i < shapes0.size(); ++i) {
      if (!shapes0[i].IsIdenticalTo(shapes1[i])) {
        return false;
      }
    }
    return true;
  } else {
    return false;
  }
}

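// Computes the product of the given dimension sizes, returning an error if
// the result would overflow int64.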
Status TensorShapeUtils::NumElements(gtl::ArraySlice<int64> shape,
                                     int64* num_elements) {
  int64 n = 1;
  for (auto dim : shape) {
    n = MultiplyWithoutOverflow(n, dim);
    if (n < 0) {
      return errors::InvalidArgument("Can't compute total size of shape [",
                                     absl::StrJoin(shape, ","),
                                     "]; product would overflow int64");
    }
  }
  *num_elements = n;
  return Status::OK();
}

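// Explicit template instantiations for the two concrete shape classes.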
template class TensorShapeBase<TensorShape>;
template class TensorShapeBase<PartialTensorShape>;

}  // namespace tensorflow