1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/shape_inference.h"
16 
17 #include "tensorflow/core/framework/bounds_check.h"
18 #include "tensorflow/core/framework/node_def.pb_text.h"
19 #include "tensorflow/core/framework/partial_tensor_shape.h"
20 #include "tensorflow/core/framework/tensor_shape.pb.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/strings/numbers.h"
23 #include "tensorflow/core/lib/strings/scanner.h"
24 #include "tensorflow/core/lib/strings/str_util.h"
25 
26 namespace tensorflow {
27 namespace shape_inference {
28 
// Out-of-line definitions of the static constexpr members declared in the
// InferenceContext class. Before C++17 such members need a namespace-scope
// definition (without an initializer) whenever they are ODR-used.
constexpr int32 InferenceContext::kUnknownRank;
constexpr int64 InferenceContext::kUnknownDim;
31 
InferenceContext(int graph_def_version,const NodeDef * node_def,const OpDef & op_def,const std::vector<TensorShapeProto> & input_shapes,const std::vector<const Tensor * > & input_tensors,const std::vector<TensorShapeProto> & input_tensors_as_shapes,const std::vector<std::unique_ptr<std::vector<std::pair<TensorShapeProto,DataType>>>> & input_handle_shapes_and_types)32 InferenceContext::InferenceContext(
33     int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
34     const std::vector<TensorShapeProto>& input_shapes,
35     const std::vector<const Tensor*>& input_tensors,
36     const std::vector<TensorShapeProto>& input_tensors_as_shapes,
37     const std::vector<
38         std::unique_ptr<std::vector<std::pair<TensorShapeProto, DataType>>>>&
39         input_handle_shapes_and_types)
40     : graph_def_version_(graph_def_version),
41       node_def_(CHECK_NOTNULL(node_def)) {
42   std::vector<ShapeHandle> input_tensors_as_shape_handles;
43   input_tensors_as_shape_handles.reserve(input_tensors_as_shapes.size());
44   for (const TensorShapeProto& p : input_tensors_as_shapes) {
45     ShapeHandle shape;
46     construction_status_.Update(MakeShapeFromShapeProto(p, &shape));
47     if (!construction_status_.ok()) {
48       return;
49     }
50     input_tensors_as_shape_handles.push_back(shape);
51   }
52   PreInputInit(op_def, input_tensors, input_tensors_as_shape_handles);
53   if (!construction_status_.ok()) return;
54   inputs_.reserve(input_shapes.size());
55   for (const TensorShapeProto& p : input_shapes) {
56     ShapeHandle shape;
57     construction_status_.Update(MakeShapeFromShapeProto(p, &shape));
58     if (!construction_status_.ok()) {
59       return;
60     }
61     inputs_.push_back(shape);
62   }
63 
64   std::vector<std::unique_ptr<std::vector<ShapeAndType>>> handle_data(
65       input_shapes.size());
66   for (int i = 0; i < input_handle_shapes_and_types.size(); ++i) {
67     const auto& v = input_handle_shapes_and_types[i];
68     if (v == nullptr) {
69       continue;
70     }
71     handle_data[i].reset(new std::vector<ShapeAndType>(v->size()));
72     auto& new_v = *handle_data[i];
73     for (int j = 0; j < v->size(); ++j) {
74       const auto& p = (*v)[j];
75       construction_status_.Update(
76           MakeShapeFromShapeProto(p.first, &new_v[j].shape));
77       if (!construction_status_.ok()) {
78         return;
79       }
80       new_v[j].dtype = p.second;
81     }
82   }
83   PostInputInit(std::move(handle_data));
84 }
85 
86 // Same as above, but with PartialTensorShape instead of TensorShapeProto
InferenceContext(int graph_def_version,const NodeDef * node_def,const OpDef & op_def,const std::vector<PartialTensorShape> & input_shapes,const std::vector<const Tensor * > & input_tensors,const std::vector<PartialTensorShape> & input_tensors_as_shapes,const std::vector<std::unique_ptr<std::vector<std::pair<PartialTensorShape,DataType>>>> & input_handle_shapes_and_types)87 InferenceContext::InferenceContext(
88     int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
89     const std::vector<PartialTensorShape>& input_shapes,
90     const std::vector<const Tensor*>& input_tensors,
91     const std::vector<PartialTensorShape>& input_tensors_as_shapes,
92     const std::vector<
93         std::unique_ptr<std::vector<std::pair<PartialTensorShape, DataType>>>>&
94         input_handle_shapes_and_types)
95     : graph_def_version_(graph_def_version),
96       node_def_(CHECK_NOTNULL(node_def)) {
97   std::vector<ShapeHandle> input_tensors_as_shape_handles;
98   input_tensors_as_shape_handles.reserve(input_tensors_as_shapes.size());
99   for (const PartialTensorShape& p : input_tensors_as_shapes) {
100     ShapeHandle shape;
101     construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape));
102     if (!construction_status_.ok()) {
103       return;
104     }
105     input_tensors_as_shape_handles.push_back(shape);
106   }
107   PreInputInit(op_def, input_tensors, input_tensors_as_shape_handles);
108   if (!construction_status_.ok()) return;
109   inputs_.reserve(input_shapes.size());
110   for (const PartialTensorShape& p : input_shapes) {
111     ShapeHandle shape;
112     construction_status_.Update(MakeShapeFromPartialTensorShape(p, &shape));
113     if (!construction_status_.ok()) {
114       return;
115     }
116     inputs_.push_back(shape);
117   }
118   std::vector<std::unique_ptr<std::vector<ShapeAndType>>> handle_data(
119       input_shapes.size());
120   for (int i = 0; i < input_handle_shapes_and_types.size(); ++i) {
121     const auto& v = input_handle_shapes_and_types[i];
122     if (v == nullptr) {
123       continue;
124     }
125     handle_data[i].reset(new std::vector<ShapeAndType>(v->size()));
126     auto& new_v = *handle_data[i];
127     for (int j = 0; j < v->size(); ++j) {
128       const auto& p = (*v)[j];
129       construction_status_.Update(
130           MakeShapeFromPartialTensorShape(p.first, &new_v[j].shape));
131       if (!construction_status_.ok()) {
132         return;
133       }
134       new_v[j].dtype = p.second;
135     }
136   }
137   PostInputInit(std::move(handle_data));
138 }
139 
InferenceContext(int graph_def_version,const NodeDef * node_def,const OpDef & op_def,const std::vector<ShapeHandle> & input_shapes,const std::vector<const Tensor * > & input_tensors,const std::vector<ShapeHandle> & input_tensors_as_shapes,std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_shapes_and_types)140 InferenceContext::InferenceContext(
141     int graph_def_version, const NodeDef* node_def, const OpDef& op_def,
142     const std::vector<ShapeHandle>& input_shapes,
143     const std::vector<const Tensor*>& input_tensors,
144     const std::vector<ShapeHandle>& input_tensors_as_shapes,
145     std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
146         input_handle_shapes_and_types)
147     : graph_def_version_(graph_def_version),
148       node_def_(CHECK_NOTNULL(node_def)) {
149   PreInputInit(op_def, input_tensors, input_tensors_as_shapes);
150   if (!construction_status_.ok()) return;
151   inputs_ = input_shapes;
152 
153   PostInputInit(std::move(input_handle_shapes_and_types));
154 }
155 
~InferenceContext()156 InferenceContext::~InferenceContext() {}
157 
Run(const std::function<Status (shape_inference::InferenceContext * c)> & fn)158 Status InferenceContext::Run(
159     const std::function<Status(shape_inference::InferenceContext* c)>& fn) {
160   ForgetMerges();
161   Status s = fn(this);
162   if (!s.ok()) {
163     ForgetMerges();
164     return AttachContext(s);
165   }
166 #ifndef NDEBUG
167   for (int i = 0; i < num_outputs(); ++i) {
168     DCHECK(output(i).IsSet())
169         << i << " for " << node_def_->name() << " of type " << node_def_->op();
170   }
171 #endif  // NDEBUG
172   return s;
173 }
174 
set_output(StringPiece output_name,const std::vector<ShapeHandle> & shapes)175 Status InferenceContext::set_output(StringPiece output_name,
176                                     const std::vector<ShapeHandle>& shapes) {
177   auto result = output_name_map_.find(output_name);
178   if (result == output_name_map_.end()) {
179     return errors::InvalidArgument("Unknown output name: ", output_name);
180   } else {
181     const int start = result->second.first;
182     const int size = result->second.second - start;
183     if (size != shapes.size()) {
184       return errors::InvalidArgument("Must have exactly ", shapes.size(),
185                                      " shapes.");
186     }
187     for (int i = 0; i < size; ++i) {
188       outputs_[i + start] = shapes[i];
189     }
190   }
191   return Status::OK();
192 }
193 
input(StringPiece input_name,std::vector<ShapeHandle> * output) const194 Status InferenceContext::input(StringPiece input_name,
195                                std::vector<ShapeHandle>* output) const {
196   const auto result = input_name_map_.find(input_name);
197   if (result == input_name_map_.end()) {
198     return errors::InvalidArgument("Unknown input name: ", input_name);
199   } else {
200     output->clear();
201     for (int i = result->second.first; i < result->second.second; ++i) {
202       output->push_back(inputs_[i]);
203     }
204   }
205   return Status::OK();
206 }
207 
output(StringPiece output_name,std::vector<ShapeHandle> * output) const208 Status InferenceContext::output(StringPiece output_name,
209                                 std::vector<ShapeHandle>* output) const {
210   const auto result = output_name_map_.find(output_name);
211   if (result == output_name_map_.end()) {
212     return errors::InvalidArgument("Unknown output name: ", output_name);
213   } else {
214     output->clear();
215     for (int i = result->second.first; i < result->second.second; ++i) {
216       output->push_back(outputs_[i]);
217     }
218   }
219   return Status::OK();
220 }
221 
op() const222 string InferenceContext::op() const { return node_def_->op(); }
223 
PreInputInit(const OpDef & op_def,const std::vector<const Tensor * > & input_tensors,const std::vector<ShapeHandle> & input_tensors_as_shapes)224 void InferenceContext::PreInputInit(
225     const OpDef& op_def, const std::vector<const Tensor*>& input_tensors,
226     const std::vector<ShapeHandle>& input_tensors_as_shapes) {
227   input_tensors_ = input_tensors;
228   input_tensors_as_shapes_ = input_tensors_as_shapes;
229 
230   construction_status_ = NameRangesForNode(*node_def_, op_def, &input_name_map_,
231                                            &output_name_map_);
232   if (!construction_status_.ok()) return;
233 
234   int num_outputs = 0;
235   for (const auto& e : output_name_map_) {
236     num_outputs = std::max(num_outputs, e.second.second);
237   }
238   outputs_.assign(num_outputs, nullptr);
239   output_handle_shapes_and_types_.resize(num_outputs);
240 }
241 
ExpandOutputs(int new_output_size)242 Status InferenceContext::ExpandOutputs(int new_output_size) {
243   if (new_output_size < outputs_.size()) {
244     return errors::InvalidArgument("Trying to reduce number of outputs of op.");
245   }
246   outputs_.resize(new_output_size, nullptr);
247   output_handle_shapes_and_types_.resize(new_output_size);
248   return Status::OK();
249 }
250 
PostInputInit(std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_data)251 void InferenceContext::PostInputInit(
252     std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_data) {
253   int num_inputs_from_node_def = 0;
254   for (const auto& e : input_name_map_) {
255     num_inputs_from_node_def =
256         std::max(num_inputs_from_node_def, e.second.second);
257   }
258 
259   // Allow passing empty shapes/dtypes to avoid changing every single test.
260   if (input_handle_data.empty()) {
261     input_handle_shapes_and_types_.resize(inputs_.size());
262   } else {
263     if (input_handle_data.size() != inputs_.size()) {
264       construction_status_ = errors::InvalidArgument(
265           "Wrong number of handle shapes passed; expected ", inputs_.size(),
266           " got ", input_handle_data.size());
267       return;
268     }
269     input_handle_shapes_and_types_ = std::move(input_handle_data);
270   }
271 
272   if (inputs_.size() != num_inputs_from_node_def) {
273     construction_status_ = errors::InvalidArgument(
274         "Wrong number of inputs passed: ", inputs_.size(), " while ",
275         num_inputs_from_node_def, " expected based on NodeDef");
276     return;
277   }
278 
279   CHECK_LE(input_tensors_.size(), inputs_.size());
280   input_tensors_.resize(inputs_.size());
281   requested_input_tensor_.resize(inputs_.size());
282   requested_input_tensor_as_partial_shape_.resize(inputs_.size());
283 }
284 
ShapeHandleToProto(ShapeHandle handle,TensorShapeProto * proto)285 void InferenceContext::ShapeHandleToProto(ShapeHandle handle,
286                                           TensorShapeProto* proto) {
287   if (!RankKnown(handle)) {
288     proto->set_unknown_rank(true);
289     return;
290   }
291 
292   for (int32 i = 0; i < Rank(handle); ++i) {
293     DimensionHandle dim = Dim(handle, i);
294     auto* dim_shape = proto->add_dim();
295     if (ValueKnown(dim)) {
296       dim_shape->set_size(Value(dim));
297     } else {
298       dim_shape->set_size(-1);
299     }
300   }
301 }
302 
FullyDefined(ShapeHandle s)303 bool InferenceContext::FullyDefined(ShapeHandle s) {
304   if (!RankKnown(s)) return false;
305   for (int i = 0; i < Rank(s); ++i) {
306     if (!ValueKnown(Dim(s, i))) return false;
307   }
308   return true;
309 }
310 
NumElements(ShapeHandle s)311 DimensionHandle InferenceContext::NumElements(ShapeHandle s) {
312   const auto rank = Rank(s);
313   if (rank == kUnknownRank) return UnknownDim();
314   bool found_unknown = false;
315   int64 size = 1;
316   for (int i = 0; i < rank; ++i) {
317     int64 dim_val = Value(Dim(s, i));
318     if (dim_val == kUnknownDim) {
319       found_unknown = true;
320     } else if (dim_val == 0) {
321       return MakeDim(0);
322     } else {
323       size *= dim_val;
324     }
325   }
326   if (found_unknown) {
327     return UnknownDim();
328   } else {
329     return MakeDim(size);
330   }
331 }
332 
DebugString(ShapeHandle s)333 string InferenceContext::DebugString(ShapeHandle s) {
334   if (RankKnown(s)) {
335     std::vector<string> vals;
336     for (auto d : s->dims_) vals.push_back(DebugString(d));
337     return strings::StrCat("[", str_util::Join(vals, ","), "]");
338   } else {
339     return "?";
340   }
341 }
342 
DebugString(DimensionHandle d)343 string InferenceContext::DebugString(DimensionHandle d) {
344   return ValueKnown(d) ? strings::StrCat(Value(d)) : "?";
345 }
346 
DebugString() const347 string InferenceContext::DebugString() const {
348   return strings::StrCat("InferenceContext for node: ",
349                          ProtoDebugString(*node_def_));
350 }
351 
DebugString(const ShapeAndType & shape_and_type)352 string InferenceContext::DebugString(const ShapeAndType& shape_and_type) {
353   return strings::StrCat(DebugString(shape_and_type.shape), ":",
354                          DataTypeString(shape_and_type.dtype));
355 }
356 
DebugString(gtl::ArraySlice<ShapeAndType> shape_and_types)357 string InferenceContext::DebugString(
358     gtl::ArraySlice<ShapeAndType> shape_and_types) {
359   std::vector<string> pieces;
360   for (const ShapeAndType& s : shape_and_types) {
361     pieces.push_back(DebugString(s));
362   }
363   return strings::StrCat("[", str_util::Join(pieces, ","), "]");
364 }
365 
WithRank(ShapeHandle shape,int64 rank,ShapeHandle * out)366 Status InferenceContext::WithRank(ShapeHandle shape, int64 rank,
367                                   ShapeHandle* out) {
368   if (rank > kint32max) {
369     return errors::InvalidArgument("Rank cannot exceed kint32max");
370   }
371   const int32 existing = Rank(shape);
372   if (existing == rank) {
373     *out = shape;
374     return Status::OK();
375   }
376   if (existing == kUnknownRank) {
377     std::vector<DimensionHandle> dims;
378     dims.reserve(rank);
379     for (int i = 0; i < rank; ++i) {
380       dims.push_back(UnknownDim());
381     }
382     ShapeHandle shp = shape_manager_.MakeShape(dims);
383     return Merge(shape, shp, out);
384   }
385   *out = nullptr;
386 
387   return errors::InvalidArgument("Shape must be rank ", rank, " but is rank ",
388                                  existing);
389 }
390 
WithRankAtLeast(ShapeHandle shape,int64 rank,ShapeHandle * out)391 Status InferenceContext::WithRankAtLeast(ShapeHandle shape, int64 rank,
392                                          ShapeHandle* out) {
393   if (rank > kint32max) {
394     return errors::InvalidArgument("Rank cannot exceed kint32max");
395   }
396   const int32 existing = Rank(shape);
397   if (existing >= rank || existing == kUnknownRank) {
398     *out = shape;
399     return Status::OK();
400   }
401   *out = nullptr;
402   return errors::InvalidArgument("Shape must be at least rank ", rank,
403                                  " but is rank ", existing);
404 }
405 
WithRankAtMost(ShapeHandle shape,int64 rank,ShapeHandle * out)406 Status InferenceContext::WithRankAtMost(ShapeHandle shape, int64 rank,
407                                         ShapeHandle* out) {
408   if (rank > kint32max) {
409     return errors::InvalidArgument("Rank cannot exceed kint32max");
410   }
411   const int32 existing = Rank(shape);
412   if (existing <= rank || existing == kUnknownRank) {
413     *out = shape;
414     return Status::OK();
415   }
416   *out = nullptr;
417   return errors::InvalidArgument("Shape must be at most rank ", rank,
418                                  " but is rank ", existing);
419 }
420 
WithValue(DimensionHandle dim,int64 value,DimensionHandle * out)421 Status InferenceContext::WithValue(DimensionHandle dim, int64 value,
422                                    DimensionHandle* out) {
423   const int64 existing = Value(dim);
424   if (existing == value) {
425     *out = dim;
426     return Status::OK();
427   }
428   if (existing == kUnknownDim) {
429     DimensionHandle d = MakeDim(value);
430     return Merge(dim, d, out);
431   }
432   *out = nullptr;
433   return errors::InvalidArgument("Dimension must be ", value, " but is ",
434                                  existing);
435 }
436 
Relax(DimensionHandle d_old,DimensionHandle d_new,DimensionHandle * out)437 void InferenceContext::Relax(DimensionHandle d_old, DimensionHandle d_new,
438                              DimensionHandle* out) {
439   if (d_old.SameHandle(d_new)) {
440     *out = d_old;
441   } else if (!ValueKnown(d_old) && !ValueKnown(d_new)) {
442     // The node will be fed by the dimension d_new instead of d_old: any
443     // equality assertion between d_old and other input dimension on this node
444     // may not be true anymore, so forget them all.
445     ForgetMerges();
446     // Return the new shape handle to force the relaxation to propagate to the
447     // fanout of the context.
448     *out = d_new;
449   } else if (!ValueKnown(d_new)) {
450     ForgetMerges();
451     *out = d_new;
452   } else if (Value(d_old) == Value(d_new)) {
453     // Return the old shape handle. This will stop the relaxation in the fanout
454     // of the context.
455     *out = d_old;
456   } else {
457     // Return a new handle that encodes a different unknown dim.
458     ForgetMerges();
459     *out = UnknownDim();
460   }
461 }
462 
Merge(DimensionHandle d0,DimensionHandle d1,DimensionHandle * out)463 Status InferenceContext::Merge(DimensionHandle d0, DimensionHandle d1,
464                                DimensionHandle* out) {
465   if (d0.SameHandle(d1)) {
466     *out = d0;
467     return Status::OK();
468   } else if (!ValueKnown(d1)) {
469     *out = d0;
470     merged_dims_.emplace_back(d0, d1);
471     return Status::OK();
472   } else if (!ValueKnown(d0)) {
473     *out = d1;
474     merged_dims_.emplace_back(d0, d1);
475     return Status::OK();
476   } else if (Value(d0) == Value(d1)) {
477     *out = d0;
478     return Status::OK();
479   } else {
480     *out = nullptr;
481     return errors::InvalidArgument("Dimensions must be equal, but are ",
482                                    Value(d0), " and ", Value(d1));
483   }
484 }
485 
MergePrefix(ShapeHandle s,ShapeHandle prefix,ShapeHandle * s_out,ShapeHandle * prefix_out)486 Status InferenceContext::MergePrefix(ShapeHandle s, ShapeHandle prefix,
487                                      ShapeHandle* s_out,
488                                      ShapeHandle* prefix_out) {
489   *s_out = *prefix_out = nullptr;
490   if (!RankKnown(prefix) || !RankKnown(s)) {
491     *s_out = s;
492     *prefix_out = prefix;
493     return Status::OK();
494   }
495   const int32 rank = Rank(prefix);
496   TF_RETURN_IF_ERROR(WithRankAtLeast(s, rank, &s));
497 
498   // Merge the prefix dims and create the new output shapes.
499   const int32 rank_s = Rank(s);
500   std::vector<DimensionHandle> dims;
501   dims.reserve(std::max(rank, rank_s));
502   dims.resize(rank);
503   for (int i = 0; i < rank; ++i) {
504     TF_RETURN_IF_ERROR(Merge(Dim(s, i), Dim(prefix, i), &dims[i]));
505   }
506   *prefix_out = MakeShape(dims);
507   for (int i = rank; i < rank_s; ++i) dims.push_back(Dim(s, i));
508   *s_out = MakeShape(dims);
509   return Status::OK();
510 }
511 
Relax(ShapeHandle s_old,ShapeHandle s_new,ShapeHandle * out)512 void InferenceContext::Relax(ShapeHandle s_old, ShapeHandle s_new,
513                              ShapeHandle* out) {
514   if (s_old.SameHandle(s_new)) {
515     *out = s_old;
516     return;
517   } else if (!RankKnown(s_new) || !s_old.IsSet()) {
518     ForgetMerges();
519     *out = s_new;
520     return;
521   }
522 
523   const int32 rank = Rank(s_old);
524   if (rank != Rank(s_new)) {
525     ForgetMerges();
526     *out = UnknownShape();
527     return;
528   }
529 
530   bool return_s_old = true;
531   for (int i = 0; i < rank; ++i) {
532     auto d0 = Dim(s_old, i);
533     auto d1 = Dim(s_new, i);
534     if (d0.SameHandle(d1)) continue;
535 
536     auto v0 = Value(d0);
537     auto v1 = Value(d1);
538     if (v0 == kUnknownDim || v1 == kUnknownDim || v0 != v1) {
539       return_s_old = false;
540       break;
541     }
542   }
543   if (return_s_old) {
544     *out = s_old;
545     return;
546   }
547 
548   // Relax dims.
549   std::vector<DimensionHandle> dims(rank);
550   for (int i = 0; i < rank; ++i) {
551     Relax(Dim(s_old, i), Dim(s_new, i), &dims[i]);
552   }
553   ForgetMerges();
554   *out = MakeShape(dims);
555 }
556 
Merge(ShapeHandle s0,ShapeHandle s1,ShapeHandle * out)557 Status InferenceContext::Merge(ShapeHandle s0, ShapeHandle s1,
558                                ShapeHandle* out) {
559   if (s0.SameHandle(s1)) {
560     *out = s0;
561     return Status::OK();
562   } else if (!RankKnown(s1)) {
563     *out = s0;
564     merged_shapes_.emplace_back(s0, s1);
565     return Status::OK();
566   } else if (!RankKnown(s0)) {
567     *out = s1;
568     merged_shapes_.emplace_back(s0, s1);
569     return Status::OK();
570   }
571 
572   const int32 rank = Rank(s0);
573   if (rank != Rank(s1)) {
574     *out = nullptr;
575     return errors::InvalidArgument("Shapes must be equal rank, but are ", rank,
576                                    " and ", Rank(s1));
577   }
578 
579   bool return_s0 = true;
580   bool return_s1 = true;
581   for (int i = 0; i < rank; ++i) {
582     auto d0 = Dim(s0, i);
583     auto d1 = Dim(s1, i);
584     if (d0.SameHandle(d1)) continue;
585 
586     auto v0 = Value(d0);
587     auto v1 = Value(d1);
588     if (v0 == kUnknownDim) {
589       if (v1 != kUnknownDim) {
590         return_s0 = false;
591       }
592     } else if (v1 == kUnknownDim) {
593       return_s1 = false;
594     } else if (v0 != v1) {
595       *out = nullptr;
596       return errors::InvalidArgument(
597           "Dimension ", i, " in both shapes must be equal, but are ", Value(d0),
598           " and ", Value(d1), ". Shapes are ", DebugString(s0), " and ",
599           DebugString(s1), ".");
600     }
601   }
602 
603   merged_shapes_.emplace_back(s0, s1);
604 
605   if (return_s0 || return_s1) {
606     *out = return_s0 ? s0 : s1;
607     return Status::OK();
608   }
609 
610   // Merge dims.
611   std::vector<DimensionHandle> dims(rank, nullptr);
612   for (int i = 0; i < rank; ++i) {
613     // Invariant for merge was checked earlier, so CHECK is ok.
614     TF_CHECK_OK(Merge(Dim(s0, i), Dim(s1, i), &dims[i]));
615   }
616 
617   Status s = ReturnCreatedShape(dims, out);
618   if (s.ok()) {
619     // Merge the new shape with s0. Since s0 and s1 are merged, this implies
620     // that s1 and out are also merged.
621     merged_shapes_.emplace_back(s0, *out);
622   }
623   return s;
624 }
625 
Subshape(ShapeHandle s,int64 start,ShapeHandle * out)626 Status InferenceContext::Subshape(ShapeHandle s, int64 start,
627                                   ShapeHandle* out) {
628   return Subshape(s, start, std::numeric_limits<int64>::max() /* end */, out);
629 }
630 
Subshape(ShapeHandle s,int64 start,int64 end,ShapeHandle * out)631 Status InferenceContext::Subshape(ShapeHandle s, int64 start, int64 end,
632                                   ShapeHandle* out) {
633   return Subshape(s, start, end, 1 /* stride */, out);
634 }
635 
Subshape(ShapeHandle s,int64 start,int64 end,int64 stride,ShapeHandle * out)636 Status InferenceContext::Subshape(ShapeHandle s, int64 start, int64 end,
637                                   int64 stride, ShapeHandle* out) {
638   int64 start_in = start;
639   int64 end_in = end;
640 
641   const int32 rank = Rank(s);
642   if (start == 0 && stride == 1 &&
643       ((RankKnown(s) && end >= rank) ||
644        end == std::numeric_limits<int64>::max())) {
645     *out = s;
646     return Status::OK();
647   }
648   if (!RankKnown(s)) {
649     return ReturnUnknownShape(out);
650   }
651 
652   if (start > rank) start = rank;
653   if (end > rank) end = rank;
654 
655   if (stride < 0 && start == rank) --start;
656 
657   if (start < 0) {
658     start = rank + start;
659     if (start < 0) {
660       *out = nullptr;
661       return errors::InvalidArgument("Subshape start out of bounds: ", start_in,
662                                      ", for shape with rank ", rank);
663     }
664   }
665 
666   if (end < 0) {
667     end = rank + end;
668     if (end < 0) {
669       *out = nullptr;
670       return errors::InvalidArgument("Subshape end out of bounds: ", end_in,
671                                      ", for shape with rank ", rank);
672     }
673   }
674   if (stride > 0 && start > end) {
675     *out = nullptr;
676     return errors::InvalidArgument(
677         "Subshape must have computed start <= end, but is ", start, " and ",
678         end, " (computed from start ", start_in, " and end ", end_in,
679         " over shape with rank ", rank, ")");
680   } else if (stride < 0 && start < end) {
681     *out = nullptr;
682     return errors::InvalidArgument(
683         "Subshape must have computed start >= end since stride is negative, "
684         "but is ",
685         start, " and ", end, " (computed from start ", start_in, " and end ",
686         end_in, " over shape with rank ", rank, " and stride", stride, ")");
687   }
688 
689   std::vector<DimensionHandle> dims;
690   for (int i = start; stride > 0 ? i < end : i > end; i += stride) {
691     dims.push_back(Dim(s, i));
692   }
693   return ReturnCreatedShape(dims, out);
694 }
695 
Concatenate(ShapeHandle s1,ShapeHandle s2,ShapeHandle * out)696 Status InferenceContext::Concatenate(ShapeHandle s1, ShapeHandle s2,
697                                      ShapeHandle* out) {
698   if (!RankKnown(s1) || !RankKnown(s2)) {
699     return ReturnUnknownShape(out);
700   }
701   const int32 s1_rank = Rank(s1);
702   const int32 s2_rank = Rank(s2);
703   const int32 rank = s1_rank + s2_rank;
704   std::vector<DimensionHandle> dims;
705   dims.reserve(rank);
706   for (int i = 0; i < s1_rank; ++i) dims.push_back(Dim(s1, i));
707   for (int i = 0; i < s2_rank; ++i) dims.push_back(Dim(s2, i));
708   return ReturnCreatedShape(dims, out);
709 }
710 
ReplaceDim(ShapeHandle s,int64 dim_index_in,DimensionHandle new_dim,ShapeHandle * out)711 Status InferenceContext::ReplaceDim(ShapeHandle s, int64 dim_index_in,
712                                     DimensionHandle new_dim, ShapeHandle* out) {
713   if (!RankKnown(s)) {
714     return ReturnUnknownShape(out);
715   }
716   int64 dim_index = dim_index_in;
717   if (dim_index < 0) {
718     dim_index = s->dims_.size() + dim_index;
719   }
720   if (!FastBoundsCheck(dim_index, s->dims_.size())) {
721     *out = nullptr;
722     return errors::InvalidArgument("Out of range dim_index ", dim_index_in,
723                                    " for shape with ", s->dims_.size(),
724                                    " dimensions");
725   }
726   std::vector<DimensionHandle> dims(s->dims_);
727   dims[dim_index] = new_dim;
728   return ReturnCreatedShape(dims, out);
729 }
730 
MakeShape(const std::vector<DimensionHandle> & dims)731 ShapeHandle InferenceContext::MakeShape(
732     const std::vector<DimensionHandle>& dims) {
733   return shape_manager_.MakeShape(dims);
734 }
735 
MakeShape(std::initializer_list<DimensionOrConstant> dims)736 ShapeHandle InferenceContext::MakeShape(
737     std::initializer_list<DimensionOrConstant> dims) {
738   std::vector<DimensionHandle> dims_actual;
739   dims_actual.reserve(dims.size());
740   for (const DimensionOrConstant& d : dims) {
741     dims_actual.push_back(MakeDim(d));
742   }
743 
744   return shape_manager_.MakeShape(dims_actual);
745 }
746 
UnknownShape()747 ShapeHandle InferenceContext::UnknownShape() {
748   return shape_manager_.UnknownShape();
749 }
750 
UnknownShapeOfRank(int64 rank)751 ShapeHandle InferenceContext::UnknownShapeOfRank(int64 rank) {
752   CHECK_LE(rank, kint32max) << "rank must be less than kint32max";
753   if (rank == kUnknownRank) {
754     return UnknownShape();
755   }
756   CHECK_GE(rank, 0) << "rank must not be negative";
757   std::vector<DimensionHandle> dims(rank);
758   for (int32 i = 0; i < rank; ++i) {
759     dims[i] = UnknownDim();
760   }
761   return MakeShape(dims);
762 }
763 
Scalar()764 ShapeHandle InferenceContext::Scalar() { return MakeShape({}); }
765 
Vector(DimensionOrConstant dim)766 ShapeHandle InferenceContext::Vector(DimensionOrConstant dim) {
767   return MakeShape({dim});
768 }
769 
Matrix(DimensionOrConstant dim1,DimensionOrConstant dim2)770 ShapeHandle InferenceContext::Matrix(DimensionOrConstant dim1,
771                                      DimensionOrConstant dim2) {
772   return MakeShape({dim1, dim2});
773 }
774 
MakeShapeFromShapeTensorTreatScalarAsUnknownShape(int input_idx,ShapeHandle * out)775 Status InferenceContext::MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
776     int input_idx, ShapeHandle* out) {
777   ShapeHandle input_shape;
778   TF_RETURN_IF_ERROR(WithRankAtMost(input(input_idx), 1, &input_shape));
779 
780   requested_input_tensor_as_partial_shape_[input_idx] = true;
781   if (input_idx < input_tensors_as_shapes_.size() &&
782       input_tensors_as_shapes_[input_idx].IsSet() &&
783       RankKnown(input_tensors_as_shapes_[input_idx])) {
784     *out = input_tensors_as_shapes_[input_idx];
785     return Status::OK();
786   }
787 
788   return InternalMakeShapeFromTensor(
789       true /* treat_unknown_scalar_tensor_as_unknown_shape */,
790       input_tensor(input_idx), input_shape, out);
791 }
792 
MakeShapeFromShapeTensor(int input_idx,ShapeHandle * out)793 Status InferenceContext::MakeShapeFromShapeTensor(int input_idx,
794                                                   ShapeHandle* out) {
795   ShapeHandle input_shape;
796   TF_RETURN_IF_ERROR(WithRank(input(input_idx), 1, &input_shape));
797 
798   requested_input_tensor_as_partial_shape_[input_idx] = true;
799   if (input_idx < input_tensors_as_shapes_.size() &&
800       input_tensors_as_shapes_[input_idx].IsSet() &&
801       RankKnown(input_tensors_as_shapes_[input_idx])) {
802     *out = input_tensors_as_shapes_[input_idx];
803     return Status::OK();
804   }
805 
806   return InternalMakeShapeFromTensor(
807       false /* treat_unknown_scalar_tensor_as_unknown_shape */,
808       input_tensor(input_idx), input_shape, out);
809 }
810 
MakeShapeFromTensor(const Tensor * t,ShapeHandle tensor_shape,ShapeHandle * out)811 Status InferenceContext::MakeShapeFromTensor(const Tensor* t,
812                                              ShapeHandle tensor_shape,
813                                              ShapeHandle* out) {
814   return InternalMakeShapeFromTensor(
815       false /* treat_unknown_scalar_tensor_as_unknown_shape */, t, tensor_shape,
816       out);
817 }
818 
InternalMakeShapeFromTensor(bool treat_unknown_scalar_tensor_as_unknown_shape,const Tensor * t,ShapeHandle tensor_shape,ShapeHandle * out)819 Status InferenceContext::InternalMakeShapeFromTensor(
820     bool treat_unknown_scalar_tensor_as_unknown_shape, const Tensor* t,
821     ShapeHandle tensor_shape, ShapeHandle* out) {
822   // Only callers who have set
823   if (!treat_unknown_scalar_tensor_as_unknown_shape) {
824     TF_RETURN_IF_ERROR(WithRank(tensor_shape, 1, &tensor_shape));
825   }
826   if (t == nullptr) {
827     // This is guarded by the check above.
828     if (Rank(tensor_shape) == 0) {
829       return ReturnUnknownShape(out);
830     }
831     // Shape tensor is not known, but if the shape of the shape tensor is then
832     // the right number of unknown dims can be created.
833     DimensionHandle shape_dim = Dim(tensor_shape, 0);
834     if (!ValueKnown(shape_dim)) {
835       return ReturnUnknownShape(out);
836     }
837     const auto num_dims = Value(shape_dim);
838     std::vector<DimensionHandle> dims;
839     dims.reserve(num_dims);
840     for (int i = 0; i < num_dims; i++) dims.push_back(UnknownDim());
841     return ReturnCreatedShape(dims, out);
842   }
843 
844   if (t->shape().dims() == 0) {
845     if (t->dtype() == DataType::DT_INT32) {
846       auto flat_t = t->scalar<int32>();
847       if (flat_t() != -1) {
848         *out = nullptr;
849         return errors::InvalidArgument(
850             "Input tensor must be rank 1, or if its rank 0 it must have value "
851             "-1 "
852             "(representing an unknown shape).  Saw value: ",
853             flat_t());
854       }
855       return ReturnUnknownShape(out);
856     } else if (t->dtype() == DataType::DT_INT64) {
857       auto flat_t = t->scalar<int64>();
858       if (flat_t() != -1) {
859         *out = nullptr;
860         return errors::InvalidArgument(
861             "Input tensor must be rank 1, or if its rank 0 it must have value "
862             "-1 "
863             "(representing an unknown shape).  Saw value: ",
864             flat_t());
865       }
866       return ReturnUnknownShape(out);
867     } else {
868       *out = nullptr;
869       return errors::InvalidArgument(
870           "Input tensor must be int32 or int64, but was ",
871           DataTypeString(t->dtype()));
872     }
873   }
874 
875   if (t->shape().dims() != 1) {
876     *out = nullptr;
877     return errors::InvalidArgument(
878         "Input tensor must be rank 1, but was rank ", t->shape().dims(), ".",
879         ((t->shape().dims() == 0)
880              ? "If it is rank 0 rank 0 it must have statically known value -1 "
881                "(representing an unknown shape). "
882              : " "),
883         "Saw tensor shape ", t->shape().DebugString());
884   }
885   std::vector<DimensionHandle> dims;
886   if (t->dtype() == DataType::DT_INT32) {
887     auto flat_t = t->flat<int32>();
888     for (int i = 0; i < flat_t.size(); ++i) {
889       const int32 val = flat_t(i);
890       if (val < -1) {
891         return errors::InvalidArgument(
892             "Invalid value in tensor used for shape: ", val);
893       }
894       // -1 will become an unknown dim.
895       dims.push_back(MakeDim(val));
896     }
897   } else if (t->dtype() == DataType::DT_INT64) {
898     auto flat_t = t->flat<int64>();
899     for (int i = 0; i < flat_t.size(); ++i) {
900       const int64 val = flat_t(i);
901       if (val < -1) {
902         return errors::InvalidArgument(
903             "Invalid value in tensor used for shape: ", val);
904       }
905       // -1 will become an unknown dim.
906       dims.push_back(MakeDim(val));
907     }
908   } else {
909     *out = nullptr;
910     return errors::InvalidArgument(
911         "Input tensor must be int32 or int64, but was ",
912         DataTypeString(t->dtype()));
913   }
914 
915   return ReturnCreatedShape(dims, out);
916 }
917 
MakeShapeFromPartialTensorShape(const PartialTensorShape & partial_shape,ShapeHandle * out)918 Status InferenceContext::MakeShapeFromPartialTensorShape(
919     const PartialTensorShape& partial_shape, ShapeHandle* out) {
920   *out = nullptr;
921   if (partial_shape.dims() == -1) {
922     return ReturnUnknownShape(out);
923   }
924   const int num_dims = partial_shape.dims();
925   std::vector<DimensionHandle> dims(num_dims);
926   for (int i = 0; i < num_dims; ++i) {
927     // -1 is unknown in PartialTensorShape and in InferenceContext, so this size
928     // can be passed directly to MakeDim.
929     dims[i] = MakeDim(partial_shape.dim_size(i));
930   }
931   return ReturnCreatedShape(dims, out);
932 }
933 
MakeShapeFromTensorShape(const TensorShape & shape,ShapeHandle * out)934 Status InferenceContext::MakeShapeFromTensorShape(const TensorShape& shape,
935                                                   ShapeHandle* out) {
936   return MakeShapeFromPartialTensorShape(PartialTensorShape(shape.dim_sizes()),
937                                          out);
938 }
939 
MakeShapeFromShapeProto(const TensorShapeProto & proto,ShapeHandle * out)940 Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto,
941                                                  ShapeHandle* out) {
942   *out = nullptr;
943   TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(proto));
944   PartialTensorShape partial_shape(proto);
945   return MakeShapeFromPartialTensorShape(partial_shape, out);
946 }
947 
GetScalarFromTensor(const Tensor * t,int64 * val)948 Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64* val) {
949   // Caller must ensure that <t> is not NULL.
950   const int rank = t->dims();
951   if (rank != 0) {
952     return errors::InvalidArgument("Input must be scalar but has rank ", rank);
953   }
954 
955   if (t->dtype() == DT_INT32) {
956     *val = t->scalar<int32>()();
957     return Status::OK();
958   } else if (t->dtype() == DT_INT64) {
959     *val = t->scalar<int64>()();
960     return Status::OK();
961   } else {
962     return errors::InvalidArgument("Scalar input must be int32 or int64.");
963   }
964 }
965 
966 // Returns a new dimension whose value is given by a scalar input tensor.
MakeDimForScalarInput(int idx,DimensionHandle * out)967 Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) {
968   int64 val;
969   const Tensor* t = input_tensor(idx);
970   if (t == nullptr) {
971     *out = UnknownDim();
972     return Status::OK();
973   }
974   TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val));
975   if (val < 0) {
976     return errors::InvalidArgument("Dimension size, given by scalar input ",
977                                    idx, ", must be non-negative but is ", val);
978   }
979   *out = MakeDim(val);
980   return Status::OK();
981 }
982 
MakeDimForScalarInputWithNegativeIndexing(int idx,int input_rank,DimensionHandle * out)983 Status InferenceContext::MakeDimForScalarInputWithNegativeIndexing(
984     int idx, int input_rank, DimensionHandle* out) {
985   int64 val;
986   const Tensor* t = input_tensor(idx);
987   if (t == nullptr) {
988     *out = UnknownDim();
989     return Status::OK();
990   }
991   TF_RETURN_IF_ERROR(GetScalarFromTensor(t, &val));
992   if (val < 0) {
993     if (input_rank < 0) {
994       *out = UnknownDim();
995       return Status::OK();
996     } else if (val + input_rank < 0) {
997       return errors::InvalidArgument("Dimension size, given by scalar input ",
998                                      val, " must be in range [-", input_rank,
999                                      ", ", input_rank, ")");
1000     } else {
1001       val += input_rank;
1002     }
1003   } else if (input_rank >= 0 && val >= input_rank) {
1004     return errors::InvalidArgument("Dimension size, given by scalar input ",
1005                                    val, " must be in range [-", input_rank,
1006                                    ", ", input_rank, ")");
1007   }
1008   *out = MakeDim(val);
1009   return Status::OK();
1010 }
1011 
Divide(DimensionHandle dividend,DimensionOrConstant divisor,bool evenly_divisible,DimensionHandle * out)1012 Status InferenceContext::Divide(DimensionHandle dividend,
1013                                 DimensionOrConstant divisor,
1014                                 bool evenly_divisible, DimensionHandle* out) {
1015   const int64 divisor_value = Value(divisor);
1016   if (divisor_value == 1) {
1017     *out = dividend;
1018   } else if (!ValueKnown(dividend) ||
1019              (divisor.dim.IsSet() && !ValueKnown(divisor.dim))) {
1020     *out = UnknownDim();
1021   } else {
1022     const int64 v = Value(dividend);
1023     if (divisor_value <= 0) {
1024       return errors::InvalidArgument("Divisor must be positive but is ",
1025                                      divisor_value);
1026     }
1027     if (evenly_divisible && (v % divisor_value) != 0) {
1028       return errors::InvalidArgument(
1029           "Dimension size must be evenly divisible by ", divisor_value,
1030           " but is ", v);
1031     }
1032     *out = MakeDim(v / divisor_value);
1033   }
1034   return Status::OK();
1035 }
1036 
Add(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1037 Status InferenceContext::Add(DimensionHandle first, DimensionOrConstant second,
1038                              DimensionHandle* out) {
1039   const int64 first_value = Value(first);
1040   const int64 second_value = Value(second);
1041   // Special cases.
1042   if (first_value == 0) {
1043     *out = MakeDim(second);
1044   } else if (second_value == 0) {
1045     *out = first;
1046   } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
1047     *out = UnknownDim();
1048   } else {
1049     // Invariant: Both values are known and positive. Still in run-time we can
1050     // get pair of values which cannot be store in output. Check below will
1051     // report error. We still need to avoid undefined behavior of signed
1052     // overflow and use unsigned addition.
1053     const int64 sum = static_cast<uint64>(first_value) + second_value;
1054     if (sum < 0) {
1055       return errors::InvalidArgument("Dimension size overflow from adding ",
1056                                      first_value, " and ", second_value);
1057     }
1058     *out = MakeDim(sum);
1059   }
1060   return Status::OK();
1061 }
1062 
Subtract(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1063 Status InferenceContext::Subtract(DimensionHandle first,
1064                                   DimensionOrConstant second,
1065                                   DimensionHandle* out) {
1066   const int64 first_value = Value(first);
1067   const int64 second_value = Value(second);
1068   // Special cases.
1069   if (second_value == 0) {
1070     *out = first;
1071   } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
1072     *out = UnknownDim();
1073   } else {
1074     // Invariant: Both values are known, first_value is non-negative, and
1075     // second_value is positive.
1076     if (first_value < second_value) {
1077       return errors::InvalidArgument(
1078           "Negative dimension size caused by subtracting ", second_value,
1079           " from ", first_value);
1080     }
1081     *out = MakeDim(first_value - second_value);
1082   }
1083   return Status::OK();
1084 }
1085 
Multiply(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1086 Status InferenceContext::Multiply(DimensionHandle first,
1087                                   DimensionOrConstant second,
1088                                   DimensionHandle* out) {
1089   const int64 first_value = Value(first);
1090   const int64 second_value = Value(second);
1091   // Special cases.
1092   if (first_value == 0) {
1093     *out = first;
1094   } else if (second_value == 0) {
1095     *out = MakeDim(second);
1096   } else if (first_value == 1) {
1097     *out = MakeDim(second);
1098   } else if (second_value == 1) {
1099     *out = first;
1100   } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
1101     *out = UnknownDim();
1102   } else {
1103     // Invariant: Both values are known and greater than 1.
1104     const int64 product = first_value * second_value;
1105     if (product < 0) {
1106       return errors::InvalidArgument(
1107           "Negative dimension size caused by overflow when multiplying ",
1108           first_value, " and ", second_value);
1109     }
1110     *out = MakeDim(product);
1111   }
1112   return Status::OK();
1113 }
1114 
Min(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1115 Status InferenceContext::Min(DimensionHandle first, DimensionOrConstant second,
1116                              DimensionHandle* out) {
1117   const int64 first_value = Value(first);
1118   const int64 second_value = Value(second);
1119   if (first_value == 0) {
1120     *out = first;
1121   } else if (second_value == 0) {
1122     *out = MakeDim(second);
1123   } else if (first_value == kUnknownDim || second_value == kUnknownDim) {
1124     *out = UnknownDim();
1125   } else {
1126     if (first_value <= second_value) {
1127       *out = first;
1128     } else {
1129       *out = MakeDim(second);
1130     }
1131   }
1132   return Status::OK();
1133 }
1134 
Max(DimensionHandle first,DimensionOrConstant second,DimensionHandle * out)1135 Status InferenceContext::Max(DimensionHandle first, DimensionOrConstant second,
1136                              DimensionHandle* out) {
1137   const int64 first_value = Value(first);
1138   const int64 second_value = Value(second);
1139   if (first_value == kUnknownDim || second_value == kUnknownDim) {
1140     *out = UnknownDim();
1141   } else {
1142     if (first_value >= second_value) {
1143       *out = first;
1144     } else {
1145       *out = MakeDim(second);
1146     }
1147   }
1148   return Status::OK();
1149 }
1150 
// Returns a Status with the same code as `status`. Its message is extended
// with the node name and op, every input shape, and any input tensors or
// partial shapes that shape inference asked for.
Status InferenceContext::AttachContext(const Status& status) {
  std::vector<string> shape_strs;
  shape_strs.reserve(inputs_.size());
  for (const ShapeHandle& shape : inputs_) {
    shape_strs.emplace_back(DebugString(shape));
  }

  // Only inputs whose values were requested are reported. For a given input,
  // a rank-known partial shape takes precedence over the tensor value.
  std::vector<string> tensor_strs;
  std::vector<string> partial_shape_strs;
  partial_shape_strs.reserve(inputs_.size());
  for (int i = 0; i < inputs_.size(); ++i) {
    const bool report_partial_shape =
        requested_input_tensor_as_partial_shape_[i] &&
        i < input_tensors_as_shapes_.size() &&
        input_tensors_as_shapes_[i].IsSet() &&
        RankKnown(input_tensors_as_shapes_[i]);
    if (report_partial_shape) {
      partial_shape_strs.push_back(strings::StrCat(
          "input[", i, "] = ", DebugString(input_tensors_as_shapes_[i])));
      continue;
    }

    const bool report_tensor = requested_input_tensor_[i] &&
                               i < input_tensors_.size() &&
                               input_tensors_[i] != nullptr;
    if (report_tensor) {
      tensor_strs.push_back(strings::StrCat(
          "input[", i, "] = <",
          input_tensors_[i]->SummarizeValue(256 /* max_values */), ">"));
    }
  }

  string context = strings::StrCat(
      " for '", node_def_->name(), "' (op: '", node_def_->op(),
      "') with input shapes: ", str_util::Join(shape_strs, ", "));
  if (!tensor_strs.empty()) {
    strings::StrAppend(&context, " and with computed input tensors: ",
                       str_util::Join(tensor_strs, ", "));
  }
  if (!partial_shape_strs.empty()) {
    strings::StrAppend(&context,
                       " and with input tensors computed as partial shapes: ",
                       str_util::Join(partial_shape_strs, ","));
  }
  strings::StrAppend(&context, ".");

  const string message = strings::StrCat(status.error_message(), context);
  return Status(status.code(), message);
}
1194 
MergeHandleShapesAndTypes(const std::vector<ShapeAndType> & shapes_and_types,std::vector<ShapeAndType> * to_update)1195 bool InferenceContext::MergeHandleShapesAndTypes(
1196     const std::vector<ShapeAndType>& shapes_and_types,
1197     std::vector<ShapeAndType>* to_update) {
1198   if (shapes_and_types.size() != to_update->size()) {
1199     return false;
1200   }
1201   std::vector<ShapeAndType> new_values(shapes_and_types.size());
1202   bool refined = false;
1203   for (int i = 0; i < shapes_and_types.size(); ++i) {
1204     const ShapeAndType& existing = (*to_update)[i];
1205     if (shapes_and_types[i].dtype == existing.dtype) {
1206       new_values[i].dtype = existing.dtype;
1207     } else {
1208       if (existing.dtype != DT_INVALID) {
1209         return false;
1210       } else {
1211         new_values[i].dtype = shapes_and_types[i].dtype;
1212         refined = true;
1213       }
1214     }
1215     if (!Merge(existing.shape, shapes_and_types[i].shape, &new_values[i].shape)
1216              .ok()) {
1217       // merge failed, ignore the new value.
1218       new_values[i].shape = existing.shape;
1219     }
1220     if (!existing.shape.SameHandle(new_values[i].shape)) {
1221       refined = true;
1222     }
1223   }
1224   if (!refined) {
1225     return false;
1226   }
1227   for (int i = 0; i < new_values.size(); ++i) {
1228     (*to_update)[i] = new_values[i];
1229   }
1230   return true;
1231 }
1232 
MergeOutputHandleShapesAndTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1233 bool InferenceContext::MergeOutputHandleShapesAndTypes(
1234     int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1235   if (output_handle_shapes_and_types_[idx] == nullptr) {
1236     output_handle_shapes_and_types_[idx].reset(
1237         new std::vector<ShapeAndType>(shapes_and_types));
1238     return true;
1239   }
1240   return MergeHandleShapesAndTypes(shapes_and_types,
1241                                    output_handle_shapes_and_types_[idx].get());
1242 }
1243 
MergeInputHandleShapesAndTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1244 bool InferenceContext::MergeInputHandleShapesAndTypes(
1245     int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1246   if (input_handle_shapes_and_types_[idx] == nullptr) {
1247     input_handle_shapes_and_types_[idx].reset(
1248         new std::vector<ShapeAndType>(shapes_and_types));
1249     return true;
1250   }
1251   return MergeHandleShapesAndTypes(shapes_and_types,
1252                                    input_handle_shapes_and_types_[idx].get());
1253 }
1254 
RelaxHandleShapesAndMergeTypes(const std::vector<ShapeAndType> & shapes_and_types,std::vector<ShapeAndType> * to_update)1255 bool InferenceContext::RelaxHandleShapesAndMergeTypes(
1256     const std::vector<ShapeAndType>& shapes_and_types,
1257     std::vector<ShapeAndType>* to_update) {
1258   if (shapes_and_types.size() != to_update->size()) {
1259     return false;
1260   }
1261   std::vector<ShapeAndType> new_values(shapes_and_types.size());
1262   for (int i = 0; i < shapes_and_types.size(); ++i) {
1263     const ShapeAndType& existing = (*to_update)[i];
1264     if (shapes_and_types[i].dtype == existing.dtype) {
1265       new_values[i].dtype = existing.dtype;
1266     } else {
1267       if (existing.dtype != DT_INVALID) {
1268         return false;
1269       } else {
1270         new_values[i].dtype = shapes_and_types[i].dtype;
1271       }
1272     }
1273     Relax(existing.shape, shapes_and_types[i].shape, &new_values[i].shape);
1274   }
1275   to_update->swap(new_values);
1276   return true;
1277 }
1278 
RelaxOutputHandleShapesAndMergeTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1279 bool InferenceContext::RelaxOutputHandleShapesAndMergeTypes(
1280     int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1281   if (output_handle_shapes_and_types_[idx] == nullptr) {
1282     output_handle_shapes_and_types_[idx].reset(
1283         new std::vector<ShapeAndType>(shapes_and_types));
1284     return true;
1285   }
1286   return RelaxHandleShapesAndMergeTypes(
1287       shapes_and_types, output_handle_shapes_and_types_[idx].get());
1288 }
1289 
RelaxInputHandleShapesAndMergeTypes(int idx,const std::vector<ShapeAndType> & shapes_and_types)1290 bool InferenceContext::RelaxInputHandleShapesAndMergeTypes(
1291     int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1292   if (input_handle_shapes_and_types_[idx] == nullptr) {
1293     input_handle_shapes_and_types_[idx].reset(
1294         new std::vector<ShapeAndType>(shapes_and_types));
1295     return true;
1296   }
1297   return RelaxHandleShapesAndMergeTypes(
1298       shapes_and_types, input_handle_shapes_and_types_[idx].get());
1299 }
1300 
1301 // -----------------------------------------------------------------------------
1302 // ShapeManager
1303 // -----------------------------------------------------------------------------
ShapeManager()1304 InferenceContext::ShapeManager::ShapeManager() {}
~ShapeManager()1305 InferenceContext::ShapeManager::~ShapeManager() {
1306   for (auto* s : all_shapes_) delete s;
1307   for (auto* d : all_dims_) delete d;
1308 }
1309 
MakeShape(const std::vector<DimensionHandle> & dims)1310 ShapeHandle InferenceContext::ShapeManager::MakeShape(
1311     const std::vector<DimensionHandle>& dims) {
1312   all_shapes_.push_back(new Shape(dims));
1313   return all_shapes_.back();
1314 }
1315 
UnknownShape()1316 ShapeHandle InferenceContext::ShapeManager::UnknownShape() {
1317   all_shapes_.push_back(new Shape());
1318   return all_shapes_.back();
1319 }
1320 
1321 }  // namespace shape_inference
1322 }  // namespace tensorflow
1323