1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <algorithm>
16 #include <iterator>
17 #include <memory>
18 #include <numeric>
19 #include <string>
20 #include <unordered_map>
21 #include <vector>
22 
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/core/platform/logging.h"
25 #include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
26 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
27 #include "tensorflow/lite/toco/model.h"
28 #include "tensorflow/lite/toco/tooling_util.h"
29 
30 namespace toco {
31 
32 namespace {
33 
void ComputeConvSizes(const Shape& input_shape, int output_depth, int kwidth,
                      int kheight, int stride_width, int stride_height,
                      int dilation_width_factor, int dilation_height_factor,
                      PaddingType padding_type, Shape* output_shape,
                      FixedPadding* fixed_padding) {
39   const int input_width = input_shape.dims(2);
40   const int input_height = input_shape.dims(1);
41   const int batch = input_shape.dims(0);
42 
43   CHECK_GE(input_width, 1);
44   CHECK_GE(input_height, 1);
45   CHECK_GE(batch, 1);
46   CHECK_GE(kwidth, 1);
47   CHECK_GE(kheight, 1);
48   CHECK_GE(stride_width, 1);
49   CHECK_GE(stride_height, 1);
50   CHECK_GE(dilation_width_factor, 1);
51   CHECK_GE(dilation_height_factor, 1);
52 
53   int dilated_kwidth = dilation_width_factor * (kwidth - 1) + 1;
54   int dilated_kheight = dilation_height_factor * (kheight - 1) + 1;
55 
56   int output_height = 0;
57   int output_width = 0;
58   if (padding_type == PaddingType::kValid) {
59     output_height =
60         (input_height + stride_height - dilated_kheight) / stride_height;
61     output_width = (input_width + stride_width - dilated_kwidth) / stride_width;
62   } else if (padding_type == PaddingType::kSame) {
63     output_height = (input_height + stride_height - 1) / stride_height;
64     output_width = (input_width + stride_width - 1) / stride_width;
65   } else {
66     LOG(FATAL) << "Only supporting SAME or VALID padding";
67   }
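  // SAME computes ceil(input / stride); VALID computes
  // ceil((input - dilated_kernel + 1) / stride). For example, with input size
  // 5, kernel size 3, stride 1 and dilation 1: VALID gives (5 + 1 - 3) / 1 = 3
  // and SAME gives (5 + 1 - 1) / 1 = 5.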
68 
69   fixed_padding->height = std::max(0, ((output_height - 1) * stride_height +
70                                        dilated_kheight - input_height) /
71                                           2);
72   fixed_padding->width = std::max(
73       0,
74       ((output_width - 1) * stride_width + dilated_kwidth - input_width) / 2);
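  // The fixed padding is half of the total padding implied by the output size.
  // Continuing the SAME example above: ((5 - 1) * 1 + 3 - 5) / 2 = 1.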
75 
76   // Actually had to debug a situation where those were negative due to bad
77   // propagation of placeholder -1 sizes in TensorFlowReshape.
78   CHECK_GT(output_width, 0);
79   CHECK_GT(output_height, 0);
80   output_shape->ReplaceDims({batch, output_height, output_width, output_depth});
81 }
82 
void ComputeBinaryOperatorOutputSize(const Shape& input_shape_x,
                                     const Shape& input_shape_y,
                                     Array* output_array) {
86   // This matches the code in BroadcastBinaryOpShapeFn from tensorflow.
87   // It zips together the two input shapes and pads with 1 to make them the
88   // same length. For each dimension we broadcast if either dimension is 1 and
89   // otherwise expect them to match.
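  // For example, broadcasting shapes [2, 3, 1, 5] and [3, 4, 5] produces the
  // output shape [2, 3, 4, 5].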
90   int rank_x = input_shape_x.dimensions_count();
91   int rank_y = input_shape_y.dimensions_count();
92   int rank_out = std::max(rank_x, rank_y);
93   std::vector<int>* dims_out = output_array->mutable_shape()->mutable_dims();
94   dims_out->clear();
95   dims_out->reserve(rank_out);
96   for (int i = 0; i < rank_out; ++i) {
97     int dim_x = i < (rank_out - rank_x)
98                     ? 1
99                     : input_shape_x.dims(i - (rank_out - rank_x));
100     bool dim_y_is_one = i < (rank_out - rank_y);
101     int dim_y = dim_y_is_one ? 1 : input_shape_y.dims(i - (rank_out - rank_y));
102     if (dim_x == -1 || dim_y == -1) {
103       // One or both dimensions is unknown.
104       QCHECK(false) << "Shapes must be specified";
    } else if (dim_x == 1 || dim_y == 1) {
      // One of the dimensions is 1; broadcast it to the other.
      if (dim_x == 1 && !dim_y_is_one) {
        // dim_x is 1, so the output takes dim_y.
        dims_out->push_back(dim_y);
      } else {
        // dim_y is 1, so the output takes dim_x.
        DCHECK_EQ(dim_y, 1);
        dims_out->push_back(dim_x);
114       }
115     } else {
116       // Expect the dimensions to match.
117       CHECK_EQ(dim_x, dim_y) << "Dimensions must match";
118       dims_out->push_back(dim_x);
119     }
120   }
121   CHECK(output_array->has_shape());
122 }
123 
void ProcessConvOperator(Model* model, ConvOperator* op) {
125   const auto& input_array = model->GetArray(op->inputs[0]);
126   // Yield until input dims have been resolved.
127   if (!input_array.has_shape()) {
128     return;
129   }
130   const auto& input_shape = input_array.shape();
131   CHECK(input_shape.dimensions_count() == 4)
132       << "Conv ops require 4D inputs. Input array \"" << op->inputs[0]
133       << "\" is " << input_shape.dimensions_count() << "D.";
134 
135   const auto& weights_array = model->GetArray(op->inputs[1]);
136   // Yield until weights dims have been resolved.
137   if (!weights_array.has_shape()) {
138     return;
139   }
140   const auto& weights_shape = weights_array.shape();
141   CHECK_EQ(weights_shape.dimensions_count(), 4);
142 
143   auto& output_array = model->GetArray(op->outputs[0]);
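  // The weights dims are interpreted here as
  // {output_depth, kheight, kwidth, input_depth}.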
144   const int output_depth = weights_shape.dims(0);
145   const int kheight = weights_shape.dims(1);
146   const int kwidth = weights_shape.dims(2);
147   ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width,
148                    op->stride_height, op->dilation_width_factor,
149                    op->dilation_height_factor, op->padding.type,
150                    output_array.mutable_shape(),
151                    &op->padding.GetOrCreateFixedPadding());
152   CHECK_EQ(output_array.shape().dimensions_count(), 4);
153 
154   // Set im2col array dimensions if there is one.
155   if (op->outputs.size() == 2) {
156     const auto& output_shape = output_array.shape();
157     const int input_depth = weights_shape.dims(3);
158     auto& im2col_array = model->GetArray(op->outputs[1]);
159     im2col_array.copy_shape(Shape{output_shape.dims(0), output_shape.dims(1),
160                                   output_shape.dims(2),
161                                   input_depth * kheight * kwidth});
162   }
163 }
164 
void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) {
  // TransposeConv is unique in that it is specifically given the output shape
  // as a 1D array on its first input. In theory, then, resolving the output
  // shape is as easy as waiting for this input to be resolved. However, we
  // also have to calculate the padding, which requires the weights shape. So
  // we might as well calculate the output shape and ensure it matches the
  // specified one.
172 
173   // SPECIFIED OUTPUT SHAPE
174   // The below is the specified, or prescribed output shape, _given_ to the
175   // operator as an input.
176   auto& specified_output_shape_array =
177       model->GetArray(op->inputs[TransposeConvOperator::OUTPUT_SHAPE]);
178   if (!specified_output_shape_array.has_shape() ||
179       !specified_output_shape_array.buffer) {
180     // Yield until the specified output shape is resolved as a constant
181     return;
182   }
183 
184   CHECK(specified_output_shape_array.data_type == ArrayDataType::kInt32)
185       << "TransposeConv input_dims must be int32";
186 
187   CHECK(specified_output_shape_array.shape().dimensions_count() == 1 &&
188         specified_output_shape_array.shape().dims(0) == 4)
      << "TransposeConv requires a 1D, 4-element array on its 0th input "
190          "specifying the output shape. \""
191       << op->inputs[TransposeConvOperator::OUTPUT_SHAPE] << "\" had shape "
192       << toco::ShapeToString(specified_output_shape_array.shape());
193 
194   // COMPUTE PADDING
195   // We require the weights shape to calculate padding.
196   const auto& weights_array =
197       model->GetArray(op->inputs[TransposeConvOperator::WEIGHTS]);
198   if (!weights_array.has_shape()) {
199     // Yield until weights dims have been resolved.
200     return;
201   }
202   const auto& weights_shape = weights_array.shape();
203   CHECK_EQ(weights_shape.dimensions_count(), 4)
204       << "TransposeConv weights must have 4 input dimensions. Input weights \""
205       << op->inputs[TransposeConvOperator::WEIGHTS] << "\" had shape "
206       << toco::ShapeToString(weights_shape) << ".";
207 
208   // Compute padding
209   const int kheight = weights_shape.dims(1);
210   const int kwidth = weights_shape.dims(2);
211   op->padding.GetOrCreateFixedPadding();
212   if (op->padding.type == PaddingType::kValid) {
213     op->padding.fixed->height = 0;
214     op->padding.fixed->width = 0;
215   } else if (op->padding.type == PaddingType::kSame) {
216     op->padding.fixed->height = (kheight - 1) / 2;
217     op->padding.fixed->width = (kwidth - 1) / 2;
218   } else {
219     LOG(FATAL) << "TransposeConv only supports SAME or VALID padding";
220   }
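  // For example, a 3x3 kernel with SAME padding yields a fixed padding of
  // (3 - 1) / 2 = 1 on each of height and width.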
221 
222   // VALIDATE some dimensions and set the output shape.
223   const auto& input_array =
224       model->GetArray(op->inputs[TransposeConvOperator::DATA_INPUT]);
225   if (!input_array.has_shape()) {
226     // Yield until input dims have been resolved.
227     return;
228   }
229   const auto& input_shape = input_array.shape();
230   CHECK_EQ(input_shape.dimensions_count(), 4)
231       << "TransposeConv input shape must have 4 dimensions. Input \""
232       << op->inputs[TransposeConvOperator::WEIGHTS] << "\" had shape "
233       << toco::ShapeToString(weights_shape) << ".";
234   CHECK_EQ(input_shape.dims(3), weights_shape.dims(3))
235       << "Input shape depth and weight depth do not agree";
236 
237   // Set the output shape according to the specified output shape.
238   std::vector<int32> const& specified_output_shape =
239       specified_output_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
240   auto& output_array = model->GetArray(op->outputs[0]);
241   *(output_array.mutable_shape()->mutable_dims()) = specified_output_shape;
242 
243   // Set im2col array dimensions if there is one.
244   if (op->outputs.size() == 2) {
245     const int input_depth = weights_shape.dims(3);
246     auto& im2col_array = model->GetArray(op->outputs[1]);
247     im2col_array.copy_shape(
248         Shape{specified_output_shape[0], specified_output_shape[1],
249               specified_output_shape[2], input_depth * kheight * kwidth});
250   }
251 }
252 
void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) {
254   const auto& input_array = model->GetArray(op->inputs[0]);
255   // Yield until input dims have been resolved.
256   if (!input_array.has_shape()) {
257     return;
258   }
259   const auto& input_shape = input_array.shape();
260   CHECK_EQ(input_shape.dimensions_count(), 4);
261 
262   const auto& weights_array = model->GetArray(op->inputs[1]);
263   // Yield until weights dims have been resolved.
264   if (!weights_array.has_shape()) {
265     return;
266   }
267   const auto& weights_shape = weights_array.shape();
268   CHECK_EQ(weights_shape.dimensions_count(), 4);
269 
270   const string& output_name = op->outputs[0];
271   const int input_depth = input_shape.dims(3);
272   const int output_depth = weights_shape.dims(3);
  // TensorFlow doesn't define the depth_multiplier value on DepthwiseConv ops;
  // instead, it has to be inferred from the weights dims. However, once we are
  // here, the weights dims have already been converted to our own internal
  // format, where the multiplier is no longer readily apparent. So instead we
  // get it as the quotient of the output and input depths. We only do that
  // when depth_multiplier still has its default value of zero; any other
  // value is validated by the CHECK below.
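  // For example, an input depth of 8 with an output depth of 24 implies a
  // depth_multiplier of 3.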
280   if (!op->depth_multiplier) {
281     op->depth_multiplier = output_depth / input_depth;
282   }
283   CHECK_EQ(output_depth, input_depth * op->depth_multiplier)
284       << "input/output depths and depth_multiplier don't match";
285 
286   const int kheight = weights_shape.dims(1);
287   const int kwidth = weights_shape.dims(2);
288   ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width,
289                    op->stride_height, op->dilation_width_factor,
290                    op->dilation_height_factor, op->padding.type,
291                    model->GetArray(output_name).mutable_shape(),
292                    &op->padding.GetOrCreateFixedPadding());
293 }
294 
void ProcessDepthToSpaceOperator(Model* model, DepthToSpaceOperator* op) {
296   const auto& input_array = model->GetArray(op->inputs[0]);
297   // Yield until input dims have been resolved.
298   if (!input_array.has_shape()) {
299     return;
300   }
301   const auto& input_shape = input_array.shape();
302   CHECK_EQ(input_shape.dimensions_count(), 4);
303 
304   const string& output_name = op->outputs[0];
305   const int block_size = op->block_size;
306   CHECK_NE(block_size, 0) << "Invalid block_size in " << output_name;
307   const int batch = input_shape.dims(0);
308   const int height = input_shape.dims(1);
309   const int width = input_shape.dims(2);
310   const int depth = input_shape.dims(3);
311   QCHECK_EQ(depth % (block_size * block_size), 0);
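  // For example, block_size 2 maps an input of shape [1, 4, 4, 16] to an
  // output of shape [1, 8, 8, 4].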
312 
313   model->GetArray(output_name)
314       .copy_shape(Shape({batch, height * block_size, width * block_size,
315                          depth / block_size / block_size}));
316 }
317 
void ProcessSpaceToDepthOperator(Model* model, SpaceToDepthOperator* op) {
319   const auto& input_array = model->GetArray(op->inputs[0]);
320   // Yield until input dims have been resolved.
321   if (!input_array.has_shape()) {
322     return;
323   }
324   const auto& input_shape = input_array.shape();
325   CHECK_EQ(input_shape.dimensions_count(), 4);
326 
327   const string& output_name = op->outputs[0];
328   const int block_size = op->block_size;
329   CHECK_NE(block_size, 0) << "Invalid block_size in " << output_name;
330   const int batch = input_shape.dims(0);
331   const int height = input_shape.dims(1);
332   const int width = input_shape.dims(2);
333   const int depth = input_shape.dims(3);
334   QCHECK_EQ(width % block_size, 0);
335   QCHECK_EQ(height % block_size, 0);
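  // For example, block_size 2 maps an input of shape [1, 8, 8, 4] to an
  // output of shape [1, 4, 4, 16].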
336 
337   model->GetArray(output_name)
338       .copy_shape(Shape({batch, height / block_size, width / block_size,
339                          depth * block_size * block_size}));
340 }
341 
void ProcessOpWithShapeInput(Model* model, Operator* op) {
343   CHECK_EQ(op->outputs.size(), 1);
344   auto& output_array = model->GetArray(op->outputs[0]);
345   if (output_array.has_shape()) {
346     // We have already run
347     return;
348   }
349 
350   auto& dims_array = model->GetArray(op->inputs[0]);
351   if (!dims_array.has_shape()) {
    // Yield until the dims shape has been resolved.
353     return;
354   }
355   if (!dims_array.buffer) {
356     // Yield until the dims are constant
357     return;
358   }
359   CHECK(dims_array.data_type == ArrayDataType::kInt32) << "dims must be int32";
360   CHECK_LE(RequiredBufferSizeForShape(dims_array.shape()), 4)
361       << "dims vector can be no larger than 4 values";
362 
363   std::vector<int32> const& dims =
364       dims_array.GetBuffer<ArrayDataType::kInt32>().data;
365   *(output_array.mutable_shape()->mutable_dims()) = dims;
366 }
367 
void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) {
369   const auto& input_array = model->GetArray(op->inputs[0]);
370   // Yield until input dims have been resolved.
371   if (!input_array.has_shape()) {
372     return;
373   }
374   const auto& input_shape = input_array.shape();
375   CHECK_GE(input_shape.dimensions_count(), 1);
376 
377   const auto& weights_array = model->GetArray(op->inputs[1]);
378   // Yield until weights dims have been resolved.
379   if (!weights_array.has_shape()) {
380     return;
381   }
382   const auto& weights_shape = weights_array.shape();
383 
384   const int weights_output_depth = weights_shape.dims(0);
385   CHECK_EQ(weights_shape.dimensions_count(), 2);
386 
387   const int input_overall_size = RequiredBufferSizeForShape(input_shape);
388   const int matmul_repeats = input_overall_size / weights_shape.dims(1);
389   CHECK_EQ(matmul_repeats * weights_shape.dims(1), input_overall_size);
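  // The input is flattened into a matrix of shape {matmul_repeats, dims(1) of
  // the weights}. For example, an input of shape [2, 10, 10, 3] with weights
  // of shape [16, 300] yields an output of shape [2, 16].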
390 
391   auto& output_array = model->GetArray(op->outputs[0]);
392   output_array.copy_shape(Shape({matmul_repeats, weights_output_depth}));
393 }
394 
void ProcessTensorFlowReshapeOperator(Model* model,
                                      TensorFlowReshapeOperator* op) {
397   auto& output_array = model->GetArray(op->outputs[0]);
398   if (output_array.has_shape()) {
399     // We have already run
400     return;
401   }
402 
403   const auto& input_array = model->GetArray(op->inputs[0]);
404   if (!input_array.has_shape()) {
405     // Yield until input dims have been resolved.
406     return;
407   }
408   const auto& input_shape = input_array.shape();
409 
410   auto& shape_array = model->GetArray(op->inputs[1]);
411   if (!shape_array.has_shape()) {
    // Yield until the shape of target_shape has been resolved.
413     return;
414   }
415   if (!shape_array.buffer) {
416     // Yield until the target_shape is constant
417     return;
418   }
419   CHECK(shape_array.data_type == ArrayDataType::kInt32)
420       << "Reshape dims must be int32";
421 
422   // shape_data is the raw array of ints describing the shape
423   // in the TensorFlow node. We intentionally make a copy here, rather than
424   // modify wildcards in-place below, because in some graphs, the same shape
425   // array with a wildcard may be referenced from multiple Reshape nodes, where
  // the wildcard needs to be resolved to distinct values.
427   std::vector<int32> shape_data =
428       shape_array.GetBuffer<ArrayDataType::kInt32>().data;
429   // The Reshape shape may have a wildcard dim, encoded as -1.
430   bool has_wildcard = false;
431   int wildcard_index = 0;
432   int product_non_wildcard_dims = 1;
433   for (int i = 0; i < shape_data.size(); i++) {
434     if (shape_data[i] == -1) {
435       CHECK(!has_wildcard);
436       has_wildcard = true;
437       wildcard_index = i;
438     } else {
439       product_non_wildcard_dims *= shape_data[i];
440     }
441   }
442 
443   const int input_flat_size = RequiredBufferSizeForShape(input_shape);
444   if (has_wildcard) {
445     CHECK_GE(input_flat_size, product_non_wildcard_dims)
446         << "Array not large enough to fill the requested dimensions for "
447            "Reshape op with output \""
448         << op->outputs[0] << "\". Are your input shapes correct?";
449     shape_data[wildcard_index] = input_flat_size / product_non_wildcard_dims;
450   }
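  // For example, an input with 24 elements reshaped to [2, -1, 4] resolves the
  // wildcard to 24 / (2 * 4) = 3, giving an output shape of [2, 3, 4].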
451 
452   if (shape_data.size() == 1 && shape_data[0] == 0) {
453     // We have reshaped a scalar, so preserve as a scalar.
454     shape_data.clear();
455   }
456 
457   auto& output_shape = *output_array.mutable_shape();
458   *output_shape.mutable_dims() = shape_data;
459   CHECK_EQ(input_flat_size, RequiredBufferSizeForShape(output_shape))
460       << "Input cannot be reshaped to requested dimensions for Reshape op with "
461          "output \""
462       << op->outputs[0] << "\". Are your input shapes correct?";
463 }
464 
void ProcessSimpleOperator(Model* model, Operator* op, int input_index) {
466   const auto& input_array = model->GetArray(op->inputs[input_index]);
467   // Yield until input dims have been resolved.
468   if (!input_array.has_shape()) {
469     return;
470   }
471 
472   const string& output_name = op->outputs[0];
473   auto& output_array = model->GetArray(output_name);
474   if (output_array.has_shape()) {
475     return;
476   }
477 
478   output_array.copy_shape(input_array.shape());
479 }
480 
void ProcessSimpleBinaryOperator(Model* model, Operator* op) {
482   CHECK_EQ(op->inputs.size(), 2);
483   const auto& input0_array = model->GetArray(op->inputs[0]);
484   const auto& input1_array = model->GetArray(op->inputs[1]);
485   // Yield until input dims have been resolved.
486   if (!input0_array.has_shape() || !input1_array.has_shape()) {
487     return;
488   }
489   const string& output_name = op->outputs[0];
490   auto& output_array = model->GetArray(output_name);
491   ComputeBinaryOperatorOutputSize(input0_array.shape(), input1_array.shape(),
492                                   &output_array);
493 }
494 
void ProcessSelectOperator(Model* model, SelectOperator* op) {
496   // Yield until all input dims have been resolved.
497   for (const auto& input : op->inputs) {
498     const auto& input_array = model->GetArray(input);
499     if (!input_array.has_shape()) {
500       return;
501     }
502   }
503 
  // Select's output shape matches that of its second and third inputs.
505   const auto& input1_array = model->GetArray(op->inputs[1]);
506   auto& output_array = model->GetArray(op->outputs[0]);
507   output_array.copy_shape(input1_array.shape());
508 }
509 
void ProcessAddNOperator(Model* model, Operator* op) {
511   // Yield until all input dims have been resolved.
512   //
513   // TODO(myenik): Since AddN does not support broadcasting, maybe we could
514   // actually use this to improve shape propagation by propagating the shape of
515   // one input to all other inputs once it is resolved instead of just the
516   // output, since all inputs must be the same size and shape for a well-formed
517   // graph.
518   for (const auto& input : op->inputs) {
519     const auto& input_array = model->GetArray(input);
520     if (!input_array.has_shape()) {
521       return;
522     }
523   }
524 
525   // AddN does not support broadcasting, all inputs must be the same shape, so
526   // we just take the first input shape and apply it to the output.
527   const auto& input0_array = model->GetArray(op->inputs[0]);
528   auto& output_array = model->GetArray(op->outputs[0]);
529   output_array.copy_shape(input0_array.shape());
530 }
531 
bool KeepDims(const Operator& op) {
533   switch (op.type) {
534     case OperatorType::kReduceMin:  //  Reduction Min
535       return static_cast<const TensorFlowMinOperator&>(op).keep_dims;
536     case OperatorType::kReduceMax:  //  Reduction Max
537       return static_cast<const TensorFlowMaxOperator&>(op).keep_dims;
538     case OperatorType::kSum:
539       return static_cast<const TensorFlowSumOperator&>(op).keep_dims;
540     case OperatorType::kReduceProd:
541       return static_cast<const TensorFlowProdOperator&>(op).keep_dims;
542     case OperatorType::kMean:
543       return static_cast<const MeanOperator&>(op).keep_dims;
544     case OperatorType::kAny:
545       return static_cast<const TensorFlowAnyOperator&>(op).keep_dims;
546     default:
547       LOG(FATAL) << "Not a reduction operator!";
548       return false;
549   }
550 }
551 
void ProcessTensorFlowReductionOperator(Model* model, Operator* op) {
553   CHECK_LE(op->inputs.size(), 2);
554   auto& output_array = model->GetArray(op->outputs[0]);
555   if (output_array.has_shape()) {
556     return;
557   }
558   const auto& input_array = model->GetArray(op->inputs[0]);
559   if (!input_array.has_shape()) {
560     return;
561   }
562   const auto& input_shape = input_array.shape();
563   const bool keep_dims = KeepDims(*op);
564   if (op->inputs.size() == 2) {
565     // There is a reduction_indices input.
566     const auto& reduction_indices_array = model->GetArray(op->inputs[1]);
567     if (!reduction_indices_array.buffer) {
568       return;
569     }
570     CHECK(reduction_indices_array.buffer->type == ArrayDataType::kInt32);
571 
572     int input_rank = input_shape.dimensions_count();
573     std::set<int32> true_indices;
574     const auto& reduction_indices =
575         reduction_indices_array.GetBuffer<ArrayDataType::kInt32>().data;
576     for (int i = 0; i < reduction_indices.size(); ++i) {
577       const int32 reduction_index = reduction_indices[i];
578       if (reduction_index < -input_rank || reduction_index >= input_rank) {
579         CHECK(false) << "Invalid reduction dimension " << reduction_index
580                      << " for input with " << input_rank << " dimensions";
581       }
582       int32 wrapped_index = reduction_index;
583       if (wrapped_index < 0) {
584         wrapped_index += input_rank;
585       }
586       true_indices.insert(wrapped_index);
587     }
588 
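    // Build the output shape by dropping each reduced dimension, or keeping
    // it as 1 when keep_dims is set. For example, reducing [2, 3, 4] over
    // axis 1 yields [2, 4], or [2, 1, 4] with keep_dims.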
589     auto* mutable_dims = output_array.mutable_shape()->mutable_dims();
590     mutable_dims->clear();
591     for (int i = 0; i < input_rank; ++i) {
592       if (true_indices.count(i) > 0) {
593         if (keep_dims) {
594           mutable_dims->emplace_back(1);
595         }
596       } else {
597         mutable_dims->emplace_back(input_shape.dims(i));
598       }
599     }
600   } else {
601     // No reduction_indices means complete reduction to a single scalar.
602     if (keep_dims) {
603       output_array.copy_shape(input_shape);
604     } else {
605       output_array.copy_shape(Shape({}));
606     }
607   }
608 }
609 
void ProcessSliceOperator(Model* model, SliceOperator* op) {
611   CHECK_EQ(op->inputs.size(), 3);
612   CHECK_EQ(op->outputs.size(), 1);
613 
614   // Yield until the Slice params have been resolved.
615   if (op->begin.empty()) return;
616 
617   // Yield until input dims have been resolved.
618   const auto& input_array = model->GetArray(op->inputs[0]);
619   if (!input_array.has_shape()) return;
620   const Shape& input_shape = input_array.shape();
621 
622   auto& output_array = model->GetArray(op->outputs[0]);
623   if (output_array.has_shape()) return;
624 
625   CHECK_EQ(input_shape.dims().size(), op->size.size());
626   CHECK_EQ(op->begin.size(), op->size.size());
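  // A size of -1 means "through the end of that dimension". For example, an
  // input of shape [5, 10] with begin [1, 2] and size [3, -1] produces an
  // output of shape [3, 8].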
627 
628   std::vector<int> output_dims;
629   for (int i = 0; i < op->begin.size(); ++i) {
630     int size = op->size[i];
631     if (size == -1) {
632       size = input_array.shape().dims(i) - op->begin[i];
633     }
634     output_dims.push_back(size);
635   }
636 
637   *output_array.mutable_shape()->mutable_dims() = output_dims;
638 }
639 
void ProcessReorderAxesOperator(Model* model, ReorderAxesOperator* op) {
641   const string& input_name = op->inputs[0];
642   const auto& input_array = model->GetArray(input_name);
643   // Yield until input dims have been resolved.
644   if (!input_array.has_shape()) {
645     return;
646   }
647   const auto& input_shape = input_array.shape();
648   const string& output_name = op->outputs[0];
649   Shape* output_shape = model->GetArray(output_name).mutable_shape();
650   ShuffleDims(input_shape, op->input_axes_order, op->output_axes_order,
651               output_shape);
652 }
653 
void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) {
655   // Yield until input dims have been resolved.
656   for (const auto& input_name : op->inputs) {
657     auto& input_array = model->GetArray(input_name);
658     if (!input_array.has_shape()) {
659       return;
660     }
661   }
662   auto& output_array = model->GetArray(op->outputs[0]);
663   // Use first non-empty input as basis for output dimensions.
664   for (const auto& input_name : op->inputs) {
665     const auto& input_array = model->GetArray(input_name);
666     if (input_array.shape().dimensions_count() > 0) {
667       output_array.copy_shape(input_array.shape());
668       // Negative axis means the count starts at the back of the dims().
669       if (op->axis < 0) op->axis += input_array.shape().dims().size();
670       break;
671     }
672   }
  // Determine the concat size, and enforce that all inputs have
  // the same dimensions count.
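  // For example, concatenating shapes [2, 3, 5] and [2, 4, 5] along axis 1
  // yields an output shape of [2, 7, 5].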
675   int concat_size = 0;
676   for (const auto& input_name : op->inputs) {
677     auto& input_array = model->GetArray(input_name);
678     CHECK(input_array.has_shape());
679     if (input_array.shape().dimensions_count() == 0) {
680       continue;
681     }
682     CHECK_EQ(input_array.shape().dimensions_count(),
683              output_array.shape().dimensions_count());
684     const std::vector<int>& input_dims = input_array.shape().dims();
685     CHECK_LT(op->axis, input_dims.size());
686     concat_size += input_dims[op->axis];
687   }
688   // Write out the concat_size on the output array shape.
689   auto& output_shape = *output_array.mutable_shape();
690   auto& output_dims = *output_shape.mutable_dims();
691   CHECK_LT(op->axis, output_shape.dimensions_count());
692   output_dims[op->axis] = concat_size;
693 }
694 
void ProcessRangeOperator(Model* model, RangeOperator* op) {
696   CHECK_EQ(op->inputs.size(), 3);
697   const auto& start_array = model->GetArray(op->inputs[0]);
698   if (!start_array.has_shape()) {
699     // Yield until input dims have been resolved.
700     return;
701   }
702   const auto& limit_array = model->GetArray(op->inputs[1]);
703   if (!limit_array.has_shape()) {
704     return;
705   }
706   const auto& delta_array = model->GetArray(op->inputs[2]);
707   if (!delta_array.has_shape()) {
708     return;
709   }
710 
711   if (!IsConstantParameterArray(*model, op->inputs[0])) {
712     // Yield until inputs are constant.
713     return;
714   }
715   if (!IsConstantParameterArray(*model, op->inputs[1])) {
716     return;
717   }
718   if (!IsConstantParameterArray(*model, op->inputs[2])) {
719     return;
720   }
721 
722   CHECK(start_array.data_type == ArrayDataType::kInt32)
723       << "Range op inputs must be int32.";
724   CHECK(limit_array.data_type == ArrayDataType::kInt32)
725       << "Range op inputs must be int32.";
726   CHECK(delta_array.data_type == ArrayDataType::kInt32)
727       << "Range op inputs must be int32.";
728   CHECK_EQ(RequiredBufferSizeForShape(start_array.shape()), 1)
729       << "Range op inputs must be scalar.";
730   CHECK_EQ(RequiredBufferSizeForShape(limit_array.shape()), 1)
731       << "Range op inputs must be scalar.";
732   CHECK_EQ(RequiredBufferSizeForShape(delta_array.shape()), 1)
733       << "Range op inputs must be scalar.";
734   int size = floor((limit_array.GetBuffer<ArrayDataType::kInt32>().data[0] -
735                     start_array.GetBuffer<ArrayDataType::kInt32>().data[0]) /
736                    delta_array.GetBuffer<ArrayDataType::kInt32>().data[0]);
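  // For example, start 3, limit 18 and delta 5 give a size of 3, i.e. the
  // values {3, 8, 13}.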
737 
738   // Only set the output shape. Contents are set by ResolveConstantRange.
739   CHECK_EQ(op->outputs.size(), 1);
740   auto& output_array = model->GetArray(op->outputs[0]);
741   Shape* output_shape = output_array.mutable_shape();
742   output_shape->ReplaceDims({size});
743 }
744 
void ProcessTensorFlowSplitOperator(Model* model, TensorFlowSplitOperator* op) {
746   CHECK_EQ(op->inputs.size(), 2);
747   const string& input_name = op->inputs[1];
748   const auto& input_array = model->GetArray(input_name);
749   // Yield until input dims have been resolved.
750   if (!input_array.has_shape()) {
751     return;
752   }
753   const Shape& input_shape = input_array.shape();
754 
755   // Yield until axis is constant.
756   if (!IsConstantParameterArray(*model, op->inputs[0])) {
757     return;
758   }
759 
760   const auto& axis_array = model->GetArray(op->inputs[0]);
761 
762   // Yield until axis dims have been resolved.
763   if (!axis_array.has_shape()) {
764     return;
765   }
766 
767   CHECK(axis_array.data_type == ArrayDataType::kInt32)
768       << "Axis array must be int32.";
769   CHECK_EQ(RequiredBufferSizeForShape(axis_array.shape()), 1)
770       << "Axis array must be scalar.";
771 
772   int axis = axis_array.GetBuffer<ArrayDataType::kInt32>().data[0];
773   if (axis < 0) {
774     axis += input_shape.dimensions_count();
775   }
776 
777   const int split_dim = input_shape.dims(axis);
778   CHECK_EQ(split_dim % op->num_split, 0);
779   const int split_depth = split_dim / op->num_split;
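  // For example, splitting an input of shape [2, 12] along axis 1 with
  // num_split 3 gives each output the shape [2, 4].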
780 
781   Shape output_shape = input_shape;
782   (*output_shape.mutable_dims())[axis] = split_depth;
783 
784   CHECK_EQ(op->outputs.size(), op->num_split);
785   for (const auto& output : op->outputs) {
786     model->GetArray(output).copy_shape(output_shape);
787   }
788 }
789 
void ProcessTensorFlowSplitVOperator(Model* model,
                                     TensorFlowSplitVOperator* op) {
792   CHECK_EQ(op->inputs.size(), 3);
793 
794   const auto& input_array = model->GetArray(op->inputs[0]);
795   // Yield until input dims have been resolved.
796   if (!input_array.has_shape()) {
797     return;
798   }
799   const Shape& input_shape = input_array.shape();
800 
801   // Yield until size_splits is constant.
802   if (!IsConstantParameterArray(*model, op->inputs[1])) {
803     return;
804   }
805   const auto& size_array = model->GetArray(op->inputs[1]);
806   // Yield until size_splits dims have been resolved.
807   if (!size_array.has_shape()) {
808     return;
809   }
810   const Shape& size_shape = size_array.shape();
811 
812   CHECK(size_array.data_type == ArrayDataType::kInt32 ||
813         size_array.data_type == ArrayDataType::kInt64)
      << "size_splits must be int32 or int64";
815   CHECK_EQ(size_shape.dimensions_count(), 1) << "size_splits must be 1-D";
816 
817   std::vector<int64> size_splits_vector;
818   if (size_array.data_type == ArrayDataType::kInt32) {
819     for (const auto each_size :
820          size_array.GetBuffer<ArrayDataType::kInt32>().data) {
821       size_splits_vector.push_back(each_size);
822     }
823   } else {
824     size_splits_vector = size_array.GetBuffer<ArrayDataType::kInt64>().data;
825   }
826 
827   // Yield until axis is constant.
828   if (!IsConstantParameterArray(*model, op->inputs[2])) {
829     return;
830   }
831   const auto& axis_array = model->GetArray(op->inputs[2]);
832   // Yield until axis dims have been resolved.
833   if (!axis_array.has_shape()) {
834     return;
835   }
836 
837   CHECK(axis_array.data_type == ArrayDataType::kInt32)
838       << "Axis array must be int32.";
839   CHECK_EQ(RequiredBufferSizeForShape(axis_array.shape()), 1)
840       << "Axis array must be scalar.";
841 
842   int axis = axis_array.GetBuffer<ArrayDataType::kInt32>().data[0];
843   if (axis < 0) {
844     axis += input_shape.dimensions_count();
845   }
846 
847   CHECK_EQ(op->num_split, size_splits_vector.size());
848 
849   int64_t minus_one_count = 0, size_splits_sum = 0;
850   for (auto size : size_splits_vector) {
851     if (size == -1) {
852       ++minus_one_count;
853     } else {
854       size_splits_sum += size;
855     }
856   }
857 
858   const int input_size = input_shape.dims(axis);
859 
860   CHECK_LE(minus_one_count, 1) << "size_splits can contain at most one -1.";
861 
862   if (minus_one_count == 1) {
863     CHECK_LE(size_splits_sum, input_size);
864     auto iter =
865         std::find(size_splits_vector.begin(), size_splits_vector.end(), -1);
866     *iter = input_size - size_splits_sum;
867   } else {
868     CHECK_EQ(size_splits_sum, input_size);
869   }
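  // For example, an input of size 10 along the split axis with size_splits
  // [3, -1, 2] resolves the -1 to 10 - (3 + 2) = 5.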
870 
871   CHECK_EQ(op->outputs.size(), op->num_split);
872 
873   for (int i = 0; i < op->outputs.size(); ++i) {
874     const auto& output = op->outputs[i];
875     Shape output_shape = input_shape;
876     (*output_shape.mutable_dims())[axis] = size_splits_vector.at(i);
877     model->GetArray(output).copy_shape(output_shape);
878   }
879 }
880 
void ProcessAveragePoolOperator(Model* model, AveragePoolOperator* op) {
882   const string& input_name = op->inputs[0];
883   const auto& input_array = model->GetArray(input_name);
884   // Yield until input dims have been resolved.
885   if (!input_array.has_shape()) {
886     return;
887   }
888   const auto& input_shape = input_array.shape();
889   CHECK_EQ(input_shape.dimensions_count(), 4);
890   const string& output_name = op->outputs[0];
891   const int output_depth = input_shape.dims(3);
892   ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
893                    op->stride_width, op->stride_height, 1, 1, op->padding.type,
894                    model->GetArray(output_name).mutable_shape(),
895                    &op->padding.GetOrCreateFixedPadding());
896 }
897 
void ProcessMaxPoolOperator(Model* model, MaxPoolOperator* op) {
899   const string& input_name = op->inputs[0];
900   const auto& input_array = model->GetArray(input_name);
901   // Yield until input dims have been resolved.
902   if (!input_array.has_shape()) {
903     return;
904   }
905   const auto& input_shape = input_array.shape();
906   CHECK_EQ(input_shape.dimensions_count(), 4);
907   const string& output_name = op->outputs[0];
908   const int output_depth = input_shape.dims(3);
909   ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
910                    op->stride_width, op->stride_height, 1, 1, op->padding.type,
911                    model->GetArray(output_name).mutable_shape(),
912                    &op->padding.GetOrCreateFixedPadding());
913 }
914 
void ProcessL2PoolOperator(Model* model, L2PoolOperator* op) {
916   const string& input_name = op->inputs[0];
917   const auto& input_array = model->GetArray(input_name);
918   // Yield until input dims have been resolved.
919   if (!input_array.has_shape()) {
920     return;
921   }
922   const auto& input_shape = input_array.shape();
923   if (input_shape.dimensions_count() < 4) {
924     LOG(FATAL) << "missing dimensions for " << input_name;
925   }
926   const string& output_name = op->outputs[0];
927   const int output_depth = input_shape.dims(3);
928   ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
929                    op->stride_width, op->stride_height, 1, 1, op->padding.type,
930                    model->GetArray(output_name).mutable_shape(),
931                    &op->padding.GetOrCreateFixedPadding());
932 }
933 
void ProcessResizeBilinearOperator(Model* model, ResizeBilinearOperator* op) {
935   CHECK_EQ(op->inputs.size(), 2);
936   CHECK_EQ(op->outputs.size(), 1);
937 
938   if (!model->GetArray(op->inputs[0]).has_shape() ||
939       !model->GetArray(op->inputs[1]).has_shape()) {
940     return;
941   }
942   const auto& input_data_shape = model->GetArray(op->inputs[0]).shape();
943 
944   const string& output_size_name = op->inputs[1];
945   const auto& output_size_array = model->GetArray(output_size_name);
946   CHECK(output_size_array.data_type == ArrayDataType::kInt32);
947   CHECK(output_size_array.has_shape());
948   const auto& output_size_shape = output_size_array.shape();
949   CHECK_EQ(output_size_shape.dimensions_count(), 1);
950   CHECK_EQ(output_size_shape.dims(0), 2);
951   if (!output_size_array.buffer) {
952     return;
953   }
954   std::vector<int32> output_shape =
955       output_size_array.GetBuffer<ArrayDataType::kInt32>().data;
956   model->GetArray(op->outputs[0])
957       .copy_shape(Shape({input_data_shape.dims(0), output_shape[0],
958                          output_shape[1], input_data_shape.dims(3)}));
959 }
960 
void ProcessResizeNearestNeighborOperator(Model* model,
                                          ResizeNearestNeighborOperator* op) {
963   CHECK_EQ(op->inputs.size(), 2);
964   CHECK_EQ(op->outputs.size(), 1);
965 
966   if (!model->GetArray(op->inputs[0]).has_shape() ||
967       !model->GetArray(op->inputs[1]).has_shape()) {
968     return;
969   }
970   const auto& input_data_shape = model->GetArray(op->inputs[0]).shape();
971 
972   const string& output_size_name = op->inputs[1];
973   const auto& output_size_array = model->GetArray(output_size_name);
974   CHECK(output_size_array.data_type == ArrayDataType::kInt32);
975   CHECK(output_size_array.has_shape());
976   const auto& output_size_shape = output_size_array.shape();
977   CHECK_EQ(output_size_shape.dimensions_count(), 1);
978   CHECK_EQ(output_size_shape.dims(0), 2);
979   if (!output_size_array.buffer) {
980     return;
981   }
982   std::vector<int32> output_shape =
983       output_size_array.GetBuffer<ArrayDataType::kInt32>().data;
984   model->GetArray(op->outputs[0])
985       .copy_shape(Shape({input_data_shape.dims(0), output_shape[0],
986                          output_shape[1], input_data_shape.dims(3)}));
987 }
988 
void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
  // Only required for the compact LstmCell form, which has the default
  // NUM_INPUTS number of inputs.
991   if (op->inputs.size() != LstmCellOperator::NUM_INPUTS) return;
992 
993   const auto& input_array =
994       model->GetArray(op->inputs[LstmCellOperator::DATA_INPUT]);
995   // Yield until all input dims have been resolved.
996   if (!input_array.has_shape()) {
997     return;
998   }
999   const auto& input_shape = input_array.shape();
1000   CHECK_GE(input_shape.dimensions_count(), 2);
1001 
1002   const auto& prev_activ_array =
1003       model->GetArray(op->inputs[LstmCellOperator::PREV_ACTIV_INPUT]);
1004   // Yield until all input dims have been resolved.
1005   if (!prev_activ_array.has_shape()) {
1006     return;
1007   }
1008   const auto& prev_activ_shape = prev_activ_array.shape();
1009   CHECK_GE(prev_activ_shape.dimensions_count(), 2);
1010 
1011   const auto& weights_array =
1012       model->GetArray(op->inputs[LstmCellOperator::WEIGHTS_INPUT]);
1013   // Yield until weights dims have been resolved.
1014   if (!weights_array.has_shape()) {
1015     return;
1016   }
1017   const auto& weights_shape = weights_array.shape();
1018   CHECK_EQ(weights_shape.dimensions_count(), 2);
1019 
1020   const auto& bias_array =
1021       model->GetArray(op->inputs[LstmCellOperator::BIASES_INPUT]);
1022   // Yield until bias dims have been resolved.
1023   if (!bias_array.has_shape()) {
1024     return;
1025   }
1026   const auto& bias_shape = bias_array.shape();
1027   CHECK_GE(bias_shape.dimensions_count(), 1);
1028 
1029   const auto& prev_state_array =
1030       model->GetArray(op->inputs[LstmCellOperator::PREV_STATE_INPUT]);
1031   // Yield until all input dims have been resolved.
1032   if (!prev_state_array.has_shape()) {
1033     return;
1034   }
1035   const auto& prev_state_shape = prev_state_array.shape();
1036   CHECK_GE(prev_state_shape.dimensions_count(), 2);
1037 
1038   const int fc_output_depth = weights_shape.dims(0);
1039   CHECK_EQ(fc_output_depth, bias_shape.dims(0));
1040   CHECK_EQ(fc_output_depth % 4, 0);
1041   const int depth = fc_output_depth / 4;
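  // The fused fully-connected output feeds the four LSTM gates, hence the
  // factor of 4 between fc_output_depth and the cell depth.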
1042 
1043   const int input_depth = input_shape.dims(input_shape.dimensions_count() - 1);
1044   const int fc_input_depth = weights_shape.dims(1);
1045   CHECK_EQ(input_depth + depth, fc_input_depth);
1046   Shape output_shape(input_shape);
1047   (*output_shape.mutable_dims())[output_shape.dimensions_count() - 1] = depth;
1048 
1049   // Set output dimensions
1050   model->GetArray(op->outputs[LstmCellOperator::STATE_OUTPUT])
1051       .copy_shape(output_shape);
1052   model->GetArray(op->outputs[LstmCellOperator::ACTIV_OUTPUT])
1053       .copy_shape(output_shape);
1054 
1055   Shape concat_temp_shape(input_shape);
1056   (*concat_temp_shape
1057         .mutable_dims())[concat_temp_shape.dimensions_count() - 1] =
1058       fc_input_depth;
1059   model->GetArray(op->outputs[LstmCellOperator::CONCAT_TEMP])
1060       .copy_shape(concat_temp_shape);
1061 
1062   Shape activ_temp_shape(input_shape);
1063   (*activ_temp_shape.mutable_dims())[activ_temp_shape.dimensions_count() - 1] =
1064       fc_output_depth;
1065   model->GetArray(op->outputs[LstmCellOperator::ACTIV_TEMP])
1066       .copy_shape(activ_temp_shape);
1067 }
1068 
void ProcessUnidirectionalSequenceLstmOperator(
    Model* model, UnidirectionalSequenceLstmOperator* op) {
1071   auto& output_array = model->GetArray(op->outputs[0]);
1072   if (output_array.has_shape()) {
1073     // Shape already propagated
1074     return;
1075   }
1076 
1077   if (output_array.data_type == ArrayDataType::kNone) {
1078     // Yield until the output type has been set by PropagateArrayDataTypes
1079     return;
1080   }
1081 
1082   // TODO(renjieliu): check the inputs, as well as all kinds of weights.
1083   const auto& input_array = model->GetArray(op->inputs[0]);
1084 
1085   constexpr int kInputActivationStateTensor = 18;
1086   constexpr int kInputCellStateTensor = 19;
1087 
  // The TFLite interpreter does not support an array that is both variable and
  // contains a buffer (see b/115961645 for more discussion).
  // The following block removes the buffer from the array to work around that
  // restriction; as a consequence, downstream applications should not read the
  // LSTM state as an input to other operations.
1093   model->GetArray(op->inputs[kInputActivationStateTensor]).buffer.reset();
1094   model->GetArray(op->inputs[kInputCellStateTensor]).buffer.reset();
1095 
1096   // Yield until input dims have been resolved.
1097   if (!input_array.has_shape()) {
1098     return;
1099   }
1100   const auto& input_shape = input_array.shape();
1101   const int batch_size = input_shape.dims(1);
1102   const int timestamp = input_shape.dims(0);
1103 
1104   const auto& recurrent_to_output_weights_array =
1105       model->GetArray(op->inputs[8]);
1106   // Yield until input dims have been resolved.
1107   if (!recurrent_to_output_weights_array.has_shape()) {
1108     return;
1109   }
1110 
1111   const auto& output_weights_shape = recurrent_to_output_weights_array.shape();
1112   const int output_size = output_weights_shape.dims(1);
1113 
1114   Shape* output_shape = output_array.mutable_shape();
1115   output_shape->ReplaceDims({timestamp, batch_size, output_size});
1116 }
1117 
void ProcessUnidirectionalSequenceRnnOperator(
    Model* model, UnidirectionalSequenceRnnOperator* op) {
1120   auto& output_array = model->GetArray(op->outputs[0]);
1121   if (output_array.has_shape()) {
1122     // Shape already propagated.
1123     return;
1124   }
1125 
1126   if (output_array.data_type == ArrayDataType::kNone) {
1127     // Yield until the output type has been set by PropagateArrayDataTypes
1128     return;
1129   }
1130 
1131   constexpr int kHiddenStateTensor = 4;
  // The TFLite interpreter does not support an array that is both variable and
  // contains a buffer (see b/115961645 for more discussion).
  // The following block removes the buffer from the array to work around that
  // restriction; as a consequence, downstream applications should not read the
  // hidden state as an input to other operations.
1137   model->GetArray(op->inputs[kHiddenStateTensor]).buffer.reset();
1138 
1139   // TODO(renjieliu): check the inputs, as well as all kinds of weights.
1140   const auto& input_array = model->GetArray(op->inputs[0]);
1141   // Yield until input dims have been resolved.
1142   if (!input_array.has_shape()) {
1143     return;
1144   }
1145   const auto& input_shape = input_array.shape();
1146   const int batch_size = input_shape.dims(1);
1147   const int timestamp = input_shape.dims(0);
1148 
1149   const auto& bias_array = model->GetArray(op->inputs[3]);
1150   // Yield until input dims have been resolved.
1151   if (!bias_array.has_shape()) {
1152     return;
1153   }
1154 
1155   const auto& bias_shape = bias_array.shape();
1156   const int output_size = bias_shape.dims(0);
1157 
1158   Shape* output_shape = output_array.mutable_shape();
1159   output_shape->ReplaceDims({timestamp, batch_size, output_size});
1160 }
1161 
void ProcessBidirectionalSequenceLstmOperator(
    Model* model, BidirectionalSequenceLstmOperator* op) {
1164   // We assume time major.
1165   auto& fw_output_array = model->GetArray(op->outputs[0]);
1166   auto& bw_output_array = model->GetArray(op->outputs[1]);
1167   if (fw_output_array.has_shape()) {
1168     // Shape already propagated
1169     return;
1170   }
1171 
1172   if (fw_output_array.data_type == ArrayDataType::kNone) {
1173     // Yield until the output type has been set by PropagateArrayDataTypes
1174     return;
1175   }
1176 
1177   // TODO(renjieliu): check the inputs, as well as all kinds of weights.
1178   const auto& input_array = model->GetArray(op->inputs[0]);
1179   // Yield until input dims have been resolved.
1180   if (!input_array.has_shape()) {
1181     return;
1182   }
1183   const auto& input_shape = input_array.shape();
1184   const int batch_size = input_shape.dims(1);
1185   const int timestamp = input_shape.dims(0);
1186 
1187   constexpr int kBwRecurrentToOutputWeightsTensor = 25;
1188   const auto& recurrent_to_output_weights_array =
1189       model->GetArray(op->inputs[kBwRecurrentToOutputWeightsTensor]);
1190   // Yield until input dims have been resolved.
1191   if (!recurrent_to_output_weights_array.has_shape()) {
1192     return;
1193   }
1194 
1195   constexpr int kFwInputActivationStateTensor = 35;
1196   constexpr int kFwInputCellStateTensor = 36;
1197   constexpr int kBwInputActivationStateTensor = 37;
1198   constexpr int kBwInputCellStateTensor = 38;
  // b/115961645: This is a hack to work around that issue.
1200   model->GetArray(op->inputs[kFwInputActivationStateTensor]).buffer.reset();
1201   model->GetArray(op->inputs[kFwInputCellStateTensor]).buffer.reset();
1202   model->GetArray(op->inputs[kBwInputActivationStateTensor]).buffer.reset();
1203   model->GetArray(op->inputs[kBwInputCellStateTensor]).buffer.reset();
1204 
1205   const auto& output_weights_shape = recurrent_to_output_weights_array.shape();
1206   const int output_size = output_weights_shape.dims(1);
1207 
1208   Shape* fw_output_shape = fw_output_array.mutable_shape();
1209   if (op->merge_outputs) {
1210     fw_output_shape->ReplaceDims({timestamp, batch_size, 2 * output_size});
1211   } else {
1212     fw_output_shape->ReplaceDims({timestamp, batch_size, output_size});
1213     Shape* bw_output_shape = bw_output_array.mutable_shape();
1214     bw_output_shape->ReplaceDims({timestamp, batch_size, output_size});
1215   }
1216 }
1217 
void ProcessBidirectionalSequenceRnnOperator(
    Model* model, BidirectionalSequenceRnnOperator* op) {
1220   // We assume time major.
1221   auto& fw_output_array = model->GetArray(op->outputs[0]);
1222   auto& bw_output_array = model->GetArray(op->outputs[1]);
1223   if (fw_output_array.has_shape()) {
1224     // Shape already propagated
1225     return;
1226   }
1227 
1228   if (fw_output_array.data_type == ArrayDataType::kNone) {
1229     // Yield until the output type has been set by PropagateArrayDataTypes
1230     return;
1231   }
1232 
1233   // TODO(renjieliu): check the inputs, as well as all kinds of weights.
1234   const auto& input_array = model->GetArray(op->inputs[0]);
1235   // Yield until input dims have been resolved.
1236   if (!input_array.has_shape()) {
1237     return;
1238   }
1239   const auto& input_shape = input_array.shape();
1240   const int batch_size = input_shape.dims(1);
1241   const int timestamp = input_shape.dims(0);
1242 
1243   constexpr int kFwWeightsTensor = 1;
1244   const auto& forward_weights_array =
1245       model->GetArray(op->inputs[kFwWeightsTensor]);
1246   // Yield until input dims have been resolved.
1247   if (!forward_weights_array.has_shape()) {
1248     return;
1249   }
1250 
1251   constexpr int kFwHiddenStateTensor = 4;
1252   constexpr int kBwHiddenStateTensor = 8;
  // b/115961645: This is a hack to work around that issue.
1254   model->GetArray(op->inputs[kFwHiddenStateTensor]).buffer.reset();
1255   model->GetArray(op->inputs[kBwHiddenStateTensor]).buffer.reset();
1256 
1257   const auto& output_weights_shape = forward_weights_array.shape();
1258   const int output_size = output_weights_shape.dims(0);
1259 
1260   Shape* fw_output_shape = fw_output_array.mutable_shape();
1261   if (op->merge_outputs) {
1262     fw_output_shape->ReplaceDims({timestamp, batch_size, 2 * output_size});
1263   } else {
1264     fw_output_shape->ReplaceDims({timestamp, batch_size, output_size});
1265     Shape* bw_output_shape = bw_output_array.mutable_shape();
1266     bw_output_shape->ReplaceDims({timestamp, batch_size, output_size});
1267   }
1268 }
1269 
void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) {
1271   const auto& input_array = model->GetArray(op->inputs[0]);
1272   // Yield until input dims have been resolved.
1273   if (!input_array.has_shape()) {
1274     return;
1275   }
1276   const auto& input_shape = input_array.shape();
1277   // This method only handles input dimensions of 4.
1278   if (input_shape.dimensions_count() != 4) {
1279     return;
1280   }
1281   const auto input_height = input_shape.dims(1);
1282   const auto input_width = input_shape.dims(2);
1283 
1284   const auto& block_shape_array = model->GetArray(op->inputs[1]);
1285   const auto& paddings_array = model->GetArray(op->inputs[2]);
1286   const auto& block_shape_array_shape = block_shape_array.shape();
1287   const auto& paddings_array_shape = paddings_array.shape();
1288   QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1);
1289   QCHECK_EQ(paddings_array_shape.dimensions_count(), 2);
1290 
1291   // We only support two dimensions.
1292   QCHECK_EQ(block_shape_array_shape.dims(0), 2);
1293   if (!block_shape_array.buffer) {
1294     return;
1295   }
1296   QCHECK(block_shape_array.data_type == ArrayDataType::kInt32);
1297   const auto& block_shape_data =
1298       block_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
1299   auto block_height = block_shape_data[0];
1300   auto block_width = block_shape_data[1];
1301 
1302   QCHECK_EQ(paddings_array_shape.dims(0), 2);  // Number of block dimensions
1303   QCHECK_EQ(paddings_array_shape.dims(1), 2);  // Two parameters per dimension.
1304   if (!paddings_array.buffer) {
1305     return;
1306   }
1307   QCHECK(paddings_array.data_type == ArrayDataType::kInt32);
1308   const auto& paddings_data =
1309       paddings_array.GetBuffer<ArrayDataType::kInt32>().data;
1310   int height_with_paddings = input_height + paddings_data[0] + paddings_data[1];
1311   int width_with_paddings = input_width + paddings_data[2] + paddings_data[3];
1312   QCHECK_EQ(height_with_paddings % block_height, 0);
1313   QCHECK_EQ(width_with_paddings % block_width, 0);
1314   int output_height = height_with_paddings / block_height;
1315   int output_width = width_with_paddings / block_width;
1316 
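  // Output: the batch dimension grows by the block area (block_height *
  // block_width); each spatial dimension is the padded size divided by its
  // block size; depth is unchanged.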
1317   model->GetArray(op->outputs[0])
1318       .copy_shape(Shape({input_shape.dims(0) * block_height * block_width,
1319                          output_height, output_width, input_shape.dims(3)}));
1320 }
1321 
1322 void ProcessBatchToSpaceNDOperator(Model* model, BatchToSpaceNDOperator* op) {
1323   const auto& input_array = model->GetArray(op->inputs[0]);
1324   // Yield until input dims have been resolved.
1325   if (!input_array.has_shape()) {
1326     return;
1327   }
1328   const auto& input_shape = input_array.shape();
1329   CHECK_EQ(input_shape.dimensions_count(), 4);
1330   const auto input_height = input_shape.dims(1);
1331   const auto input_width = input_shape.dims(2);
1332 
1333   const auto& block_shape_array = model->GetArray(op->inputs[1]);
1334   const auto& crops_array = model->GetArray(op->inputs[2]);
1335   const auto& block_shape_array_shape = block_shape_array.shape();
1336   const auto& crops_array_shape = crops_array.shape();
1337   QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1);
1338   QCHECK_EQ(crops_array_shape.dimensions_count(), 2);
1339 
1340   // We only support two block dimensions.
1341   QCHECK_EQ(block_shape_array_shape.dims(0), 2);
1342   if (!block_shape_array.buffer) {
1343     return;
1344   }
1345   QCHECK(block_shape_array.data_type == ArrayDataType::kInt32);
1346   const auto& block_shape_data =
1347       block_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
1348   auto block_height = block_shape_data[0];
1349   auto block_width = block_shape_data[1];
1350 
1351   QCHECK_EQ(crops_array_shape.dims(0), 2);  // Number of block dimensions
1352   QCHECK_EQ(crops_array_shape.dims(1), 2);  // Two parameters per dimension.
1353   if (!crops_array.buffer) {
1354     return;
1355   }
1356   QCHECK(crops_array.data_type == ArrayDataType::kInt32);
1357   const auto& crops_data = crops_array.GetBuffer<ArrayDataType::kInt32>().data;
1358   const int crops_top = crops_data[0];
1359   const int crops_bottom = crops_data[1];
1360   const int crops_left = crops_data[2];
1361   const int crops_right = crops_data[3];
1362   const int output_height =
1363       input_height * block_height - crops_top - crops_bottom;
1364   const int output_width = input_width * block_width - crops_left - crops_right;
1365   QCHECK_EQ(input_shape.dims(0) % (block_height * block_width), 0);
1366 
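  // Output: the batch dimension shrinks by the block area; each spatial
  // dimension is expanded by its block size and then reduced by the crops;
  // depth is unchanged.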
1367   model->GetArray(op->outputs[0])
1368       .copy_shape(Shape({input_shape.dims(0) / (block_height * block_width),
1369                          output_height, output_width, input_shape.dims(3)}));
1370 }
1371 
1372 void ProcessGatherOperator(Model* model, GatherOperator* op) {
1373   const auto& input_array = model->GetArray(op->inputs[0]);
1374   const auto& indices_array = model->GetArray(op->inputs[1]);
1375   auto& output_array = model->GetArray(op->outputs[0]);
1376 
1377   // Bail if we already know the output shape.
1378   if (output_array.has_shape()) {
1379     return;
1380   }
1381 
1382   // Yield until input dims have been resolved.
1383   if (!input_array.has_shape() || !indices_array.has_shape()) {
1384     return;
1385   }
1386 
1387   // Yield until the axis has been resolved.
1388   if (!op->axis) {
1389     return;
1390   }
1391   int axis = op->axis.value();
1392 
1393   const auto& input_shape = input_array.shape();
1394   const auto& indices_shape = indices_array.shape();
1395   QCHECK_GE(input_shape.dimensions_count(), 1);
1396   op->input_rank = input_shape.dimensions_count();
1397   QCHECK_LT(axis, op->input_rank);
1398 
1399   // Copy the input dimensions to the output, except for the axis dimension,
1400   // which is replaced by the dimensions of indices_shape.
1401   auto output_dims = output_array.mutable_shape()->mutable_dims();
1402   for (int dim = 0; dim < axis; ++dim) {
1403     output_dims->push_back(input_shape.dims(dim));
1404   }
1405   for (int dim = 0; dim < indices_shape.dimensions_count(); ++dim) {
1406     output_dims->push_back(indices_shape.dims(dim));
1407   }
1408   for (int dim = axis + 1; dim < input_shape.dimensions_count(); ++dim) {
1409     output_dims->push_back(input_shape.dims(dim));
1410   }
1411 }
1412 
1413 void ProcessGatherNdOperator(Model* model, GatherNdOperator* op) {
1414   const auto& input_array = model->GetArray(op->inputs[0]);
1415   const auto& indices_array = model->GetArray(op->inputs[1]);
1416   auto& output_array = model->GetArray(op->outputs[0]);
1417 
1418   // Bail if we already know the output shape.
1419   if (output_array.has_shape()) {
1420     return;
1421   }
1422 
1423   // Yield until input dims have been resolved.
1424   if (!input_array.has_shape() || !indices_array.has_shape()) {
1425     return;
1426   }
1427 
1428   const auto& input_shape = input_array.shape();
1429   const auto& indices_shape = indices_array.shape();
1430   QCHECK_GE(input_shape.dimensions_count(), 1);
1431   QCHECK_GE(indices_shape.dimensions_count(), 1);
1432   const int indices_nd =
1433       indices_shape.dims(indices_shape.dimensions_count() - 1);
1434   QCHECK_LE(indices_nd, input_shape.dimensions_count());
1435 
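  // The output shape is indices.shape[:-1] followed by input.shape[indices_nd:].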
1436   auto output_dims = output_array.mutable_shape()->mutable_dims();
1437   for (int dim = 0; dim < indices_shape.dimensions_count() - 1; ++dim) {
1438     output_dims->push_back(indices_shape.dims(dim));
1439   }
1440   for (int dim = indices_nd; dim < input_shape.dimensions_count(); ++dim) {
1441     output_dims->push_back(input_shape.dims(dim));
1442   }
1443 }
1444 
1445 void ProcessTopkV2Operator(Model* model, TopKV2Operator* op) {
1446   const auto& input_values = model->GetArray(op->inputs[0]);
1447   const auto& input_k = model->GetArray(op->inputs[1]);
1448   auto& output_values = model->GetArray(op->outputs[0]);
1449   auto& output_indexes = model->GetArray(op->outputs[1]);
1450 
1451   // Bail if we already know the output shape.
1452   if (output_indexes.has_shape()) {
1453     QCHECK(output_values.has_shape());
1454     return;
1455   }
1456 
1457   // Yield until input dims have been resolved.
1458   if (!input_values.has_shape() || !input_k.has_shape()) {
1459     return;
1460   }
1461 
1462   // If k is a constant, we can compute the last output dimension; otherwise it
1463   // remains unknown.
1464   if (input_k.buffer) {
1465     const auto& input_values_shape = input_values.shape();
1466     auto output_indexes_dims = output_indexes.mutable_shape()->mutable_dims();
1467     auto output_values_dims = output_values.mutable_shape()->mutable_dims();
1468     for (int dim = 0; dim < input_values_shape.dimensions_count() - 1; dim++) {
1469       output_indexes_dims->push_back(input_values_shape.dims(dim));
1470       output_values_dims->push_back(input_values_shape.dims(dim));
1471     }
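    // The last dimension of both outputs is k.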
1472     const int32_t k_value = input_k.GetBuffer<ArrayDataType::kInt32>().data[0];
1473     output_indexes_dims->push_back(k_value);
1474     output_values_dims->push_back(k_value);
1475   }
1476 }
1477 
1478 void ProcessPadOperator(Model* model, PadOperator* op) {
1479   CHECK_EQ(op->inputs.size(), 2);
1480   CHECK_EQ(op->outputs.size(), 1);
1481 
1482   const auto& input_array = model->GetArray(op->inputs[0]);
1483 
1484   // Yield until input dims have been resolved.
1485   if (!input_array.has_shape()) return;
1486 
1487   if (op->left_padding.empty()) return;
1488   CHECK_EQ(op->left_padding.size(), op->right_padding.size());
1489 
1490   auto& output_array = model->GetArray(op->outputs[0]);
1491   if (output_array.has_shape()) return;
1492 
1493   Shape output_shape = input_array.shape();
1494   std::vector<int>& dims = *output_shape.mutable_dims();
1495   CHECK_EQ(op->left_padding.size(), dims.size());
1496 
1497   for (int i = 0; i < op->left_padding.size(); ++i) {
1498     dims[i] += op->left_padding[i] + op->right_padding[i];
1499   }
1500 
1501   output_array.copy_shape(output_shape);
1502 }
1503 
1504 void ProcessPadV2Operator(Model* model, PadV2Operator* op) {
1505   CHECK_EQ(op->inputs.size(), 3);
1506   CHECK_EQ(op->outputs.size(), 1);
1507 
1508   const auto& input_array = model->GetArray(op->inputs[0]);
1509 
1510   // Yield until input dims have been resolved.
1511   if (!input_array.has_shape()) return;
1512 
1513   if (op->left_padding.empty()) return;
1514   CHECK_EQ(op->left_padding.size(), op->right_padding.size());
1515 
1516   auto& output_array = model->GetArray(op->outputs[0]);
1517   if (output_array.has_shape()) return;
1518 
1519   Shape output_shape = input_array.shape();
1520   std::vector<int>& dims = *output_shape.mutable_dims();
1521   CHECK_EQ(op->left_padding.size(), dims.size());
1522 
1523   for (int i = 0; i < op->left_padding.size(); ++i) {
1524     dims[i] += op->left_padding[i] + op->right_padding[i];
1525   }
1526 
1527   output_array.copy_shape(output_shape);
1528 }
1529 
1530 void ProcessRankOperator(Model* model, TensorFlowRankOperator* op) {
1531   CHECK_GE(op->inputs.size(), 1);
1532   CHECK_EQ(op->outputs.size(), 1);
1533   auto& output_array = model->GetArray(op->outputs[0]);
1534   if (output_array.has_shape()) {
1535     // Shape already propagated
1536     return;
1537   }
1538 
1539   if (output_array.data_type == ArrayDataType::kNone) {
1540     // Yield until the output type has been set by PropagateArrayDataTypes
1541     return;
1542   }
1543 
1544   const auto& input_array = model->GetArray(op->inputs[0]);
1545   if (!input_array.has_shape()) {
1546     // Yield until input dims have been resolved.
1547     return;
1548   }
1549 
1550   // Only set the output shape. Array contents are set by
1551   // ResolveConstantShapeOrRank.
1552   Shape* output_shape = output_array.mutable_shape();
1553   output_shape->ReplaceDims({});
1554 }
1555 
1556 void ProcessShapeOperator(Model* model, TensorFlowShapeOperator* op) {
1557   CHECK_GE(op->inputs.size(), 1);
1558   CHECK_EQ(op->outputs.size(), 1);
1559   auto& output_array = model->GetArray(op->outputs[0]);
1560   if (output_array.has_shape()) {
1561     // Shape already propagated
1562     return;
1563   }
1564 
1565   if (output_array.data_type == ArrayDataType::kNone) {
1566     // Yield until the output type has been set by PropagateArrayDataTypes
1567     return;
1568   }
1569 
1570   const auto& input_array = model->GetArray(op->inputs[0]);
1571   if (!input_array.has_shape()) {
1572     // Yield until input dims have been resolved.
1573     return;
1574   }
1575 
1576   // Only set the output shape. Array contents are set by
1577   // ResolveConstantShapeOrRank.
1578   Shape* output_shape = output_array.mutable_shape();
1579   output_shape->ReplaceDims({input_array.shape().dimensions_count()});
1580 }
1581 
1582 void ProcessPackOperator(Model* model, PackOperator* op) {
1583   CHECK_GE(op->inputs.size(), 1);
1584   CHECK_EQ(op->outputs.size(), 1);
1585   auto& output_array = model->GetArray(op->outputs[0]);
1586   if (output_array.has_shape()) {
1587     // Shape already propagated
1588     return;
1589   }
1590 
1591   std::unique_ptr<Shape> packed_shape;
1592   for (const auto& input : op->inputs) {
1593     const auto& input_array = model->GetArray(input);
1594     if (!input_array.has_shape()) {
1595       // Yield until all input dims have been resolved.
1596       return;
1597     }
1598 
1599     Shape shape = input_array.shape();
1600     if (!packed_shape) {
1601       packed_shape.reset(new Shape(shape));
1602     } else {
1603       CHECK(*packed_shape == shape) << "All input arrays to Pack operators "
1604                                        "must have the same shape. Input \""
1605                                     << input << "\" is different.";
1606     }
1607   }
1608 
1609   int axis = op->axis;
1610   if (axis < 0) {
1611     // Handle negative axis
1612     axis += packed_shape->dims().size() + 1;
1613   }
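  // Insert a new dimension, whose size is the number of packed inputs, at the
  // resolved axis.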
1614   packed_shape->mutable_dims()->insert(
1615       packed_shape->mutable_dims()->begin() + axis, op->inputs.size());
1616   output_array.copy_shape(*packed_shape);
1617 }
1618 
1619 void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) {
1620   CHECK_GE(op->inputs.size(), 1);
1621   CHECK_EQ(op->outputs.size(), 1);
1622   auto& output_array = model->GetArray(op->outputs[0]);
1623   if (output_array.has_shape()) {
1624     // Shape already propagated
1625     return;
1626   }
1627 
1628   if (op->start_indices.empty() || op->stop_indices.empty() ||
1629       op->strides.empty()) {
1630     // ResolveStridedSliceAttributes has not run yet.
1631     return;
1632   }
1633 
1634   const auto& input_array = model->GetArray(op->inputs[0]);
1635   if (!input_array.has_shape()) {
1636     // Yield until input dims have been resolved.
1637     return;
1638   }
1639 
1640   if (op->ellipsis_mask != 0) {
1641     // Something like LOG_FIRST_N(WARNING, 10) would be preferable to reduce
1642     // log noise. However, the TensorFlow logging library does not appear to
1643     // support this.
1644     LOG(WARNING) << "Skipping StridedSlice op with output \"" << op->outputs[0]
1645                  << "\". ellipsis_mask is not supported (mask="
1646                  << op->ellipsis_mask << ")";
1647     return;
1648   }
1649   if (op->new_axis_mask != 0) {
1650     LOG(WARNING) << "Skipping StridedSlice op with output \"" << op->outputs[0]
1651                  << "\". new_axis_mask is not supported (mask="
1652                  << op->new_axis_mask << ")";
1653     return;
1654   }
1655 
1656   int num_input_axes = input_array.shape().dimensions_count();
1657   CHECK_LE(op->start_indices.size(), num_input_axes)
1658       << "StridedSlice op with output \"" << op->outputs[0]
1659       << "\", requires no more than " << num_input_axes << " start indices";
1660   CHECK_LE(op->stop_indices.size(), num_input_axes)
1661       << "StridedSlice op with output \"" << op->outputs[0]
1662       << "\", requires no more than " << num_input_axes << " stop indices";
1663   CHECK_LE(op->strides.size(), num_input_axes)
1664       << "StridedSlice op with output \"" << op->outputs[0]
1665       << "\", requires no more than " << num_input_axes << " strides";
1666   for (int i = 0; i < op->strides.size(); i++) {
1667     CHECK_NE(op->strides[i], 0) << "Strides must be non-zero. Axis " << i
1668                                 << " has stride=" << op->strides[i] << ".";
1669   }
1670 
1671   // Create output shape
1672   std::vector<int>* dims = output_array.mutable_shape()->mutable_dims();
1673 
1674   // Compute output shape
1675   for (int axis = 0; axis < num_input_axes; ++axis) {
1676     const auto strided_slice_params =
1677         tflite::strided_slice::BuildStridedSliceParams(
1678             op->begin_mask, op->end_mask, op->shrink_axis_mask,
1679             op->start_indices, op->stop_indices, op->strides);
1680     int start_index = tflite::strided_slice::StartForAxis(
1681         strided_slice_params, ToRuntimeShape(input_array.shape()), axis);
1682     int stop_index = tflite::strided_slice::StopForAxis(
1683         strided_slice_params, ToRuntimeShape(input_array.shape()), axis,
1684         start_index);
1685 
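    // Each output dimension is ceil((stop - start) / stride) for its axis.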
1686     int dim_size =
1687         ceil(static_cast<float>(stop_index - start_index) / op->strides[axis]);
1688 
1689     CHECK_GT(dim_size, 0)
1690         << "Output size for an axis must be greater than 0. Axis " << axis
1691         << " computes to size " << dim_size
1692         << " for StridedSlice op with output \"" << op->outputs[0] << "\".";
1693     if (op->shrink_axis_mask & (1 << axis)) {
1694       CHECK_EQ(dim_size, 1)
1695           << "Output size for an axis must compute to 1 when shrinking an "
1696              "axis. Axis "
1697           << axis << " computes to size " << dim_size
1698           << " for StridedSlice op with output \"" << op->outputs[0] << "\".";
1699     } else {
1700       dims->push_back(dim_size);
1701     }
1702   }
1703 }
1704 
1705 void ProcessSqueezeOperator(Model* model, SqueezeOperator* op) {
1706   CHECK_EQ(op->inputs.size(), 1);
1707   CHECK_EQ(op->outputs.size(), 1);
1708 
1709   const auto& input_array = model->GetArray(op->inputs[0]);
1710 
1711   // Yield until input dims have been resolved.
1712   if (!input_array.has_shape()) return;
1713 
1714   auto& output_array = model->GetArray(op->outputs[0]);
1715   if (output_array.has_shape()) return;
1716 
1717   const std::vector<int>& input_dims = input_array.shape().dims();
1718   std::vector<int> output_dims;
1719 
1720   std::vector<int> squeeze_dims;
1721   const int input_num_dims = input_dims.size();
1722   for (int i : op->squeeze_dims) {
1723     squeeze_dims.push_back(i < 0 ? i + input_num_dims : i);
1724   }
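  // A dimension is dropped only if it has size 1 and either squeeze_dims is
  // empty (squeeze all size-1 dims) or the dimension is explicitly listed.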
1725   for (int i = 0; i < input_num_dims; ++i) {
1726     if (input_dims[i] != 1 ||
1727         (!squeeze_dims.empty() &&
1728          std::find(squeeze_dims.begin(), squeeze_dims.end(), i) ==
1729              squeeze_dims.end())) {
1730       output_dims.push_back(input_dims[i]);
1731     }
1732   }
1733   *output_array.mutable_shape()->mutable_dims() = output_dims;
1734 }
1735 
1736 void ProcessSvdfOperator(Model* model, SvdfOperator* op) {
1737   CHECK(op->inputs.size() == 3 || op->inputs.size() == 4);
1738   const auto& input_array = model->GetArray(op->inputs[0]);
1739   if (!input_array.has_shape()) return;
1740 
1741   auto& weights_feature_array = model->GetArray(op->inputs[1]);
1742   if (!weights_feature_array.has_shape()) return;
1743 
1744   const auto& weights_time_array = model->GetArray(op->inputs[2]);
1745   if (!weights_time_array.has_shape()) return;
1746 
1747   const bool has_bias = (op->inputs.size() == 4);
1748   if (has_bias) {
1749     const auto& bias_array = model->GetArray(op->inputs[3]);
1750     if (!bias_array.has_shape()) return;
1751   }
1752 
1753   const int batch_size = input_array.shape().dims()[0];
1754   const int num_units = weights_feature_array.shape().dims()[0];
1755   const int memory_size = weights_time_array.shape().dims()[1];
1756 
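  // The state caches memory_size values per unit for each batch entry; the
  // output has num_units values per batch entry.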
1757   auto& state_array = model->GetArray(op->outputs[0]);
1758   state_array.mutable_shape()->ReplaceDims(
1759       {batch_size, memory_size * num_units});
1760 
1761   auto& output_array = model->GetArray(op->outputs[1]);
1762   output_array.mutable_shape()->ReplaceDims({batch_size, num_units});
1763 }
1764 
1765 void ProcessTransposeOperator(Model* model, TransposeOperator* op) {
1766   auto& output_array = model->GetArray(op->outputs[0]);
1767   if (output_array.has_shape()) {
1768     // We have already run
1769     return;
1770   }
1771 
1772   const auto& input_array = model->GetArray(op->inputs[0]);
1773   if (!input_array.has_shape()) {
1774     // Yield until input dims have been resolved.
1775     return;
1776   }
1777   const auto& input_shape = input_array.shape();
1778 
1779   auto& perm_array = model->GetArray(op->inputs[1]);
1780   if (!perm_array.has_shape()) {
1781     // Yield until the permutation shape has been resolved.
1782     return;
1783   }
1784   if (!perm_array.buffer) {
1785     // Yield until the permutation is constant
1786     return;
1787   }
1788   CHECK(perm_array.data_type == ArrayDataType::kInt32)
1789       << "Transpose permutation input must be int32";
1790 
1791   std::vector<int32> const& perm =
1792       perm_array.GetBuffer<ArrayDataType::kInt32>().data;
1793   CHECK_EQ(perm.size(), input_shape.dimensions_count())
1794       << "Transpose permutation input " << op->inputs[1]
1795       << " must be same length as input dimensions";
1796   std::vector<int>* output_dims = output_array.mutable_shape()->mutable_dims();
1797   for (int i = 0; i < perm.size(); i++) {
1798     int axis = perm[i];
1799     CHECK_GE(axis, 0);
1800     CHECK_LT(axis, input_shape.dimensions_count());
1801     output_dims->push_back(input_shape.dims(axis));
1802   }
1803 }
1804 
1805 template <typename Op>
1806 void ProcessArgMinMaxOperator(Model* model, Op* op) {
1807   CHECK_EQ(op->inputs.size(), 2);
1808   const auto& input_array = model->GetArray(op->inputs[0]);
1809   // Yield until input dims have been resolved.
1810   if (!input_array.has_shape()) {
1811     return;
1812   }
1813 
1814   const Array& axis_array = model->GetArray(op->inputs[1]);
1815   // Yield until input axis array shape has been resolved.
1816   if (!axis_array.has_shape()) {
1817     return;
1818   }
1819 
1820   const std::vector<int>& input_dims = input_array.shape().dims();
1821 
1822   CHECK(axis_array.data_type == ArrayDataType::kInt32 ||
1823         axis_array.data_type == ArrayDataType::kInt64)
1824       << "axis_array must be int32, int64";
1825 
1826   CHECK_EQ(RequiredBufferSizeForShape(axis_array.shape()), 1)
1827       << "Axis array must be scalar.";
1828 
1829   int64 axis;
1830   if (axis_array.data_type == ArrayDataType::kInt32) {
1831     axis = axis_array.GetBuffer<ArrayDataType::kInt32>().data[0];
1832   } else {
1833     axis = axis_array.GetBuffer<ArrayDataType::kInt64>().data[0];
1834   }
1835 
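  // The output keeps all input dimensions except the reduced axis.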
1836   std::vector<int> output_dims;
1837 
1838   output_dims.reserve(input_dims.size() - 1);
1839   for (int i = 0; i < input_dims.size(); ++i) {
1840     if (i != axis) {
1841       output_dims.push_back(input_dims[i]);
1842     }
1843   }
1844 
1845   const string& output_name = op->outputs[0];
1846   auto& output_array = model->GetArray(output_name);
1847   if (output_array.has_shape()) {
1848     return;
1849   }
1850   *output_array.mutable_shape()->mutable_dims() = output_dims;
1851 }
1852 
1853 void ProcessSparseToDenseOperator(Model* model, SparseToDenseOperator* op) {
1854   CHECK_EQ(op->inputs.size(), 4);
1855 
1856   const Array& output_shape_array = model->GetArray(op->inputs[1]);
1857   if (!output_shape_array.has_shape()) return;
1858   CHECK_EQ(output_shape_array.shape().dimensions_count(), 1);
1859 
1860   // Output should not go over four dimensions.
1861   CHECK_LE(output_shape_array.shape().dims(0), 4);
1862 
1863   const string& output_name = op->outputs[0];
1864   Array& output_array = model->GetArray(output_name);
1865   if (output_array.has_shape()) return;
1866 
1867   CHECK(output_shape_array.data_type == ArrayDataType::kInt32 ||
1868         output_shape_array.data_type == ArrayDataType::kInt64);
1869   if (output_shape_array.data_type == ArrayDataType::kInt32) {
1870     *output_array.mutable_shape()->mutable_dims() =
1871         output_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
1872   } else {
1873     const std::vector<int64>& output_shape_data =
1874         output_shape_array.GetBuffer<ArrayDataType::kInt64>().data;
1875     std::copy(
1876         output_shape_data.begin(), output_shape_data.end(),
1877         std::back_inserter(*output_array.mutable_shape()->mutable_dims()));
1878   }
1879 }
1880 
1881 void ProcessTileOperator(Model* model, TensorFlowTileOperator* op) {
1882   CHECK_EQ(op->inputs.size(), 2);
1883   CHECK_EQ(op->outputs.size(), 1);
1884 
1885   auto& output_array = model->GetArray(op->outputs[0]);
1886   if (output_array.has_shape()) {
1887     // We have already run.
1888     return;
1889   }
1890 
1891   const auto& input_array = model->GetArray(op->inputs[0]);
1892   if (!input_array.has_shape()) {
1893     // Yield until input dims have been resolved.
1894     return;
1895   }
1896   const auto& input_shape = input_array.shape();
1897 
1898   auto& multiples_array = model->GetArray(op->inputs[1]);
1899   if (!multiples_array.has_shape()) {
1900     // Yield until the multiples shape has been resolved.
1901     return;
1902   }
1903   if (!multiples_array.buffer) {
1904     // Yield until the multiples input is constant.
1905     return;
1906   }
1907   CHECK(multiples_array.data_type == ArrayDataType::kInt32)
1908       << "Tile multiples input must be int32";
1909 
1910   std::vector<int32> const& multiples =
1911       multiples_array.GetBuffer<ArrayDataType::kInt32>().data;
1912   CHECK_EQ(multiples.size(), input_shape.dimensions_count())
1913       << "Tile multiples input " << op->inputs[1]
1914       << " must be same length as input dimensions";
1915 
1916   auto* mutable_dims = output_array.mutable_shape()->mutable_dims();
1917   mutable_dims->resize(multiples.size());
1918   for (int i = 0; i < mutable_dims->size(); ++i) {
1919     (*mutable_dims)[i] = input_shape.dims(i) * multiples[i];
1920   }
1921 }
1922 
1923 void ProcessOneHotOperator(Model* model, OneHotOperator* op) {
1924   CHECK_EQ(op->inputs.size(), 4);
1925   CHECK_EQ(op->outputs.size(), 1);
1926   auto& output_array = model->GetArray(op->outputs[0]);
1927   if (output_array.has_shape()) {
1928     // Shape already propagated
1929     return;
1930   }
1931 
1932   // Yield until indices dims have been resolved.
1933   const auto& indices_array =
1934       model->GetArray(op->inputs[OneHotOperator::INDICES_INPUT]);
1935   if (!indices_array.has_shape()) {
1936     return;
1937   }
1938 
1939   // Yield until depth is constant and dims have been resolved.
1940   if (!IsConstantParameterArray(*model,
1941                                 op->inputs[OneHotOperator::DEPTH_INPUT])) {
1942     return;
1943   }
1944   const auto& depth_array =
1945       model->GetArray(op->inputs[OneHotOperator::DEPTH_INPUT]);
1946   if (!depth_array.has_shape()) {
1947     return;
1948   }
1949 
1950   CHECK(depth_array.data_type == ArrayDataType::kInt32)
1951       << "Depth array must be int32.";
1952   CHECK_EQ(RequiredBufferSizeForShape(depth_array.shape()), 1)
1953       << "Depth array must be scalar.";
1954 
1955   const int depth = depth_array.GetBuffer<ArrayDataType::kInt32>().data[0];
1956   CHECK_GE(depth, 0) << "Depth must be non-negative.";
1957 
1958   const int indices_dims = indices_array.shape().dimensions_count();
1959   const int output_dims = indices_dims + 1;
1960   const int axis = op->axis == -1 ? indices_dims : op->axis;
1961   CHECK_GE(axis, 0) << "Resolved axis must be non-negative.";
1962 
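  // The output shape is the indices shape with a new dimension of size `depth`
  // inserted at `axis`.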
1963   auto* mutable_dims = output_array.mutable_shape()->mutable_dims();
1964   mutable_dims->resize(output_dims);
1965   for (int i = 0; i < output_dims; ++i) {
1966     int dim = 0;
1967     if (i < axis) {
1968       dim = indices_array.shape().dims(i);
1969     } else if (i == axis) {
1970       dim = depth;
1971     } else {
1972       dim = indices_array.shape().dims(i - 1);
1973     }
1974     (*mutable_dims)[i] = dim;
1975   }
1976 }
1977 
1978 void ProcessUnpackOperator(Model* model, UnpackOperator* op) {
1979   CHECK_EQ(op->inputs.size(), 1);
1980   const auto& input_array = model->GetArray(op->inputs[0]);
1981   // Yield until input dims have been resolved.
1982   if (!input_array.has_shape()) {
1983     return;
1984   }
1985 
1986   const std::vector<int>& input_dims = input_array.shape().dims();
1987   std::vector<int> output_dims;
1988 
1989   output_dims.reserve(input_dims.size() - 1);
1990   for (int i = 0; i < input_dims.size(); ++i) {
1991     if (i != op->axis) {
1992       output_dims.push_back(input_dims[i]);
1993     }
1994   }
1995   for (const string& output_name : op->outputs) {
1996     auto& output_array = model->GetArray(output_name);
1997     if (output_array.has_shape()) {
1998       return;
1999     }
2000     *output_array.mutable_shape()->mutable_dims() = output_dims;
2001   }
2002 }
2003 
2004 void ProcessMirrorPadOperator(Model* model, MirrorPadOperator* op) {
2005   CHECK_EQ(op->inputs.size(), 2);
2006   const auto& input_array = model->GetArray(op->inputs[0]);
2007   const auto& padding_matrix = model->GetArray(op->inputs[1]);
2008 
2009   // Yield until input dims have been resolved.
2010   if (!input_array.has_shape()) {
2011     return;
2012   }
2013 
2014   auto& output_array = model->GetArray(op->outputs[0]);
2015   // If the output shape is already computed or the padding matrix is not
2016   // constant, then return.
2017   if (output_array.has_shape() ||
2018       !IsConstantParameterArray(*model, op->inputs[1])) {
2019     return;
2020   }
2021   Shape output_shape = input_array.shape();
2022   std::vector<int>& dims = *output_shape.mutable_dims();
2023 
2024   std::vector<int64_t> padding;
2025   if (padding_matrix.data_type == ArrayDataType::kInt32) {
2026     const auto& data = padding_matrix.GetBuffer<ArrayDataType::kInt32>().data;
2027     for (auto elem : data) {
2028       padding.push_back(static_cast<int64_t>(elem));
2029     }
2030   } else if (padding_matrix.data_type == ArrayDataType::kInt64) {
2031     const auto& data = padding_matrix.GetBuffer<ArrayDataType::kInt64>().data;
2032     for (auto elem : data) {
2033       padding.push_back(elem);
2034     }
2035   } else {
2036     CHECK(padding_matrix.data_type == ArrayDataType::kInt64 ||
2037           padding_matrix.data_type == ArrayDataType::kInt32);
2038   }
2039   CHECK_EQ(padding_matrix.shape().dimensions_count(), 2);
2040   CHECK_EQ(input_array.shape().dimensions_count(),
2041            padding_matrix.shape().dims(0));
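  // Each row of the padding matrix holds the (before, after) padding amounts
  // for the corresponding input dimension.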
2042   for (int i = 0; i < input_array.shape().dimensions_count(); ++i) {
2043     dims[i] += padding[i * 2] + padding[i * 2 + 1];
2044   }
2045 
2046   output_array.copy_shape(output_shape);
2047 }
2048 
2049 void ProcessUniqueOperator(Model* model, UniqueOperator* op) {
2050   const auto& input_array = model->GetArray(op->inputs[0]);
2051   // We have two outputs: the index tensor has the same shape as the input
2052   // array, while the shape of the unique-values tensor is unknown until runtime.
2053   CHECK_EQ(op->outputs.size(), 2);
2054   auto& idx_output_array = model->GetArray(op->outputs[1]);
2055 
2056   // Yield until input dims have been resolved, or output already computed
2057   if (!input_array.has_shape() || idx_output_array.has_shape()) {
2058     return;
2059   }
2060   idx_output_array.copy_shape(input_array.shape());
2061 }
2062 
2063 }  // namespace
2064 
2065 ::tensorflow::Status PropagateFixedSizes::Run(Model* model,
2066                                               std::size_t op_index,
2067                                               bool* modified) {
2068   *modified = false;
2069   auto it = model->operators.begin() + op_index;
2070   auto* op = it->get();
2071   std::unordered_map<string, std::vector<int>> old_output_dims;
2072   for (const auto& output : op->outputs) {
2073     if (model->GetArray(output).has_shape()) {
2074       old_output_dims[output] = model->GetArray(output).shape().dims();
2075     }
2076   }
2077 
2078   switch (op->type) {
2079     case OperatorType::kAbs:
2080     case OperatorType::kBatchNormalization:
2081     case OperatorType::kL2Normalization:
2082     case OperatorType::kDequantize:
2083     case OperatorType::kElu:
2084     case OperatorType::kRelu:
2085     case OperatorType::kRelu1:
2086     case OperatorType::kRelu6:
2087     case OperatorType::kPRelu:
2088     case OperatorType::kLeakyRelu:
2089     case OperatorType::kSoftmax:
2090     case OperatorType::kLogSoftmax:
2091     case OperatorType::kLog:
2092     case OperatorType::kLogistic:
2093     case OperatorType::kTanh:
2094     case OperatorType::kLocalResponseNormalization:
2095     case OperatorType::kIdentity:
2096     case OperatorType::kFakeQuant:
2097     case OperatorType::kNeg:
2098     case OperatorType::kRsqrt:
2099     case OperatorType::kSqrt:
2100     case OperatorType::kSquare:
2101     case OperatorType::kAll:
2102     case OperatorType::kAssert:
2103     case OperatorType::kCast:
2104     case OperatorType::kFloor:
2105     case OperatorType::kCeil:
2106     case OperatorType::kExp:
2107     case OperatorType::kSin:
2108     case OperatorType::kCos:
2109     case OperatorType::kLogicalAnd:
2110     case OperatorType::kLogicalNot:
2111     case OperatorType::kLogicalOr:
2112     case OperatorType::kZerosLike:
2113     case OperatorType::kReverseV2:
2114     case OperatorType::kReverseSequence:
2115       ProcessSimpleOperator(model, op, 0);
2116       break;
2117     case OperatorType::kGather:
2118       ProcessGatherOperator(model, static_cast<GatherOperator*>(op));
2119       break;
2120     case OperatorType::kGatherNd:
2121       ProcessGatherNdOperator(model, static_cast<GatherNdOperator*>(op));
2122       break;
2123     case OperatorType::kTopK_V2:
2124       ProcessTopkV2Operator(model, static_cast<TopKV2Operator*>(op));
2125       break;
2126     case OperatorType::kAdd:
2127     case OperatorType::kSub:
2128     case OperatorType::kMul:
2129     case OperatorType::kDiv:
2130     case OperatorType::kFloorDiv:
2131     case OperatorType::kFloorMod:
2132     case OperatorType::kLess:
2133     case OperatorType::kLessEqual:
2134     case OperatorType::kGreater:
2135     case OperatorType::kMaximum:  //  Element-wise Maximum
2136     case OperatorType::kMinimum:  //  Element-wise Minimum
2137     case OperatorType::kGreaterEqual:
2138     case OperatorType::kEqual:
2139     case OperatorType::kNotEqual:
2140     case OperatorType::kPow:
2141     case OperatorType::kSquaredDifference:
2142       ProcessSimpleBinaryOperator(model, op);
2143       break;
2144     case OperatorType::kAddN:
2145       ProcessAddNOperator(model, op);
2146       break;
2147     case OperatorType::kConv:
2148       ProcessConvOperator(model, static_cast<ConvOperator*>(op));
2149       break;
2150     case OperatorType::kTransposeConv:
2151       ProcessTransposeConvOperator(model,
2152                                    static_cast<TransposeConvOperator*>(op));
2153       break;
2154     case OperatorType::kDepthwiseConv:
2155       ProcessDepthwiseConvOperator(model,
2156                                    static_cast<DepthwiseConvOperator*>(op));
2157       break;
2158     case OperatorType::kDepthToSpace:
2159       ProcessDepthToSpaceOperator(model,
2160                                   static_cast<DepthToSpaceOperator*>(op));
2161       break;
2162     case OperatorType::kSpaceToDepth:
2163       ProcessSpaceToDepthOperator(model,
2164                                   static_cast<SpaceToDepthOperator*>(op));
2165       break;
2166     case OperatorType::kFill:
2167       CHECK_EQ(op->inputs.size(), 2);
2168       ProcessOpWithShapeInput(model, op);
2169       break;
2170     case OperatorType::kFullyConnected:
2171       ProcessFullyConnectedOperator(model,
2172                                     static_cast<FullyConnectedOperator*>(op));
2173       break;
2174     case OperatorType::kReshape:
2175       ProcessTensorFlowReshapeOperator(
2176           model, static_cast<TensorFlowReshapeOperator*>(op));
2177       break;
2178     case OperatorType::kAveragePool:
2179       ProcessAveragePoolOperator(model, static_cast<AveragePoolOperator*>(op));
2180       break;
2181     case OperatorType::kMaxPool:
2182       ProcessMaxPoolOperator(model, static_cast<MaxPoolOperator*>(op));
2183       break;
2184     case OperatorType::kL2Pool:
2185       ProcessL2PoolOperator(model, static_cast<L2PoolOperator*>(op));
2186       break;
2187     case OperatorType::kReduceMin:  //  Reduction Min
2188     case OperatorType::kReduceMax:  //  Reduction Max
2189     case OperatorType::kSum:
2190     case OperatorType::kReduceProd:
2191     case OperatorType::kMean:
2192     case OperatorType::kAny:
2193       ProcessTensorFlowReductionOperator(model, op);
2194       break;
2195     case OperatorType::kSelect:
2196       ProcessSelectOperator(model, static_cast<SelectOperator*>(op));
2197       break;
2198     case OperatorType::kSlice:
2199       ProcessSliceOperator(model, static_cast<SliceOperator*>(op));
2200       break;
2201 
2202     case OperatorType::kSwitch:
2203       // We can't know the sizes of the outputs until we have resolved the
2204       // predicate, and once we have resolved the predicate, the whole
2205       // Switch node will get resolved away.
2206       // See ResolveTensorFlowSwitch.
2207       break;
2208     case OperatorType::kMerge:
2209       // No need to bother resolving TensorFlow Merge ops: other graph
2210       // transformations will remove them anyway.
2211       // See ResolveTensorFlowMerge.
2212       break;
2213     case OperatorType::kSplit:
2214       ProcessTensorFlowSplitOperator(model,
2215                                      static_cast<TensorFlowSplitOperator*>(op));
2216       break;
2217     case OperatorType::kSplitV:
2218       ProcessTensorFlowSplitVOperator(
2219           model, static_cast<TensorFlowSplitVOperator*>(op));
2220       break;
2221     case OperatorType::kSqueeze:
2222       ProcessSqueezeOperator(model, static_cast<SqueezeOperator*>(op));
2223       break;
2224     case OperatorType::kConcat:
2225     case OperatorType::kConcatV2:
2226       // Unimplemented, hopefully another graph transformation will
2227       // drop it or rewrite it. Concretely, either ResolveTensorFlowConcat
2228       // will resolve this node to a DepthConcatenation, or else we have
2229       // a more general non-depth concatenation that will hopefully be dropped,
2230       // or else at the moment we will abort.
2231       break;
2232     case OperatorType::kExpandDims:
2233       // Yield until ExpandDims is converted to Reshape
2234       break;
2235     case OperatorType::kRange:
2236       ProcessRangeOperator(model, static_cast<RangeOperator*>(op));
2237       break;
2238     case OperatorType::kRank:
2239       ProcessRankOperator(model, static_cast<TensorFlowRankOperator*>(op));
2240       break;
2241     case OperatorType::kShape:
2242       ProcessShapeOperator(model, static_cast<TensorFlowShapeOperator*>(op));
2243       break;
2244     case OperatorType::kPack:
2245       ProcessPackOperator(model, static_cast<PackOperator*>(op));
2246       break;
2247     case OperatorType::kReorderAxes:
2248       ProcessReorderAxesOperator(model, static_cast<ReorderAxesOperator*>(op));
2249       break;
2250     case OperatorType::kConcatenation:
2251       ProcessConcatenationOperator(model,
2252                                    static_cast<ConcatenationOperator*>(op));
2253       break;
2254     case OperatorType::kResizeBilinear:
2255       ProcessResizeBilinearOperator(model,
2256                                     static_cast<ResizeBilinearOperator*>(op));
2257       break;
2258     case OperatorType::kResizeNearestNeighbor:
2259       ProcessResizeNearestNeighborOperator(
2260           model, static_cast<ResizeNearestNeighborOperator*>(op));
2261       break;
2262     case OperatorType::kUnidirectionalSequenceLstm:
2263       ProcessUnidirectionalSequenceLstmOperator(
2264           model, static_cast<UnidirectionalSequenceLstmOperator*>(op));
2265       break;
2266     case OperatorType::kUnidirectionalSequenceRnn:
2267       ProcessUnidirectionalSequenceRnnOperator(
2268           model, static_cast<UnidirectionalSequenceRnnOperator*>(op));
2269       break;
2270     case OperatorType::kBidirectionalSequenceLstm:
2271       ProcessBidirectionalSequenceLstmOperator(
2272           model, static_cast<BidirectionalSequenceLstmOperator*>(op));
2273       break;
2274     case OperatorType::kBidirectionalSequenceRnn:
2275       ProcessBidirectionalSequenceRnnOperator(
2276           model, static_cast<BidirectionalSequenceRnnOperator*>(op));
2277       break;
2278     case OperatorType::kLstmCell:
2279       ProcessLstmCellOperator(model, static_cast<LstmCellOperator*>(op));
2280       break;
2281     case OperatorType::kBatchMatMul:
2282     case OperatorType::kMatMul:
2283       // MatMul operators are converted to FullyConnected, after which their
2284       // shapes are propagated.
2285       break;
2286     case OperatorType::kSpaceToBatchND:
2287       ProcessSpaceToBatchNDOperator(model,
2288                                     static_cast<SpaceToBatchNDOperator*>(op));
2289       break;
2290     case OperatorType::kBatchToSpaceND:
2291       ProcessBatchToSpaceNDOperator(model,
2292                                     static_cast<BatchToSpaceNDOperator*>(op));
2293       break;
2294     case OperatorType::kPad:
2295       ProcessPadOperator(model, static_cast<PadOperator*>(op));
2296       break;
2297     case OperatorType::kPadV2:
2298       ProcessPadV2Operator(model, static_cast<PadV2Operator*>(op));
2299       break;
2300     case OperatorType::kStridedSlice:
2301       ProcessStridedSliceOperator(model,
2302                                   static_cast<StridedSliceOperator*>(op));
2303       break;
2304     case OperatorType::kArgMax:
2305       ProcessArgMinMaxOperator<ArgMaxOperator>(
2306           model, static_cast<ArgMaxOperator*>(op));
2307       break;
2308     case OperatorType::kArgMin:
2309       ProcessArgMinMaxOperator<ArgMinOperator>(
2310           model, static_cast<ArgMinOperator*>(op));
2311       break;
2312     case OperatorType::kUnsupported: {
2313       const auto* unsupported_op =
2314           static_cast<TensorFlowUnsupportedOperator*>(op);
2315       // The attribute may be left unspecified; ignore it in that case.
2316       if (unsupported_op->output_shapes.size() < op->outputs.size()) {
2317         return ::tensorflow::Status::OK();
2318       }
2319       for (int i = 0; i < op->outputs.size(); ++i) {
2320         const string& output = op->outputs[i];
2321         model->GetArray(output).copy_shape(unsupported_op->output_shapes.at(i));
2322       }
2323       break;
2324     }
2325     case OperatorType::kSvdf:
2326       ProcessSvdfOperator(model, static_cast<SvdfOperator*>(op));
2327       break;
2328     case OperatorType::kTranspose:
2329       ProcessTransposeOperator(model, static_cast<TransposeOperator*>(op));
2330       break;
2331     case OperatorType::kDynamicPartition:
2332     case OperatorType::kDynamicStitch:
2333       // DynamicPartition/DynamicStitch are currently only supported for
2334       // transforms that remove them, so we avoid propagating shapes through
2335       // them and let things settle once they've been removed.
2336       break;
2337     case OperatorType::kRandomUniform:
2338       CHECK_EQ(op->inputs.size(), 1);
2339       ProcessOpWithShapeInput(model, op);
2340       break;
2341     case OperatorType::kSparseToDense:
2342       ProcessSparseToDenseOperator(model,
2343                                    static_cast<SparseToDenseOperator*>(op));
2344       break;
2345     case OperatorType::kTile:
2346       ProcessTileOperator(model, static_cast<TensorFlowTileOperator*>(op));
2347       break;
2349     case OperatorType::kOneHot:
2350       ProcessOneHotOperator(model, static_cast<OneHotOperator*>(op));
2351       break;
2352     case OperatorType::kUnpack:
2353       ProcessUnpackOperator(model, static_cast<UnpackOperator*>(op));
2354       break;
2355     case OperatorType::kMirrorPad:
2356       ProcessMirrorPadOperator(model, static_cast<MirrorPadOperator*>(op));
2357       break;
2358     case OperatorType::kUnique:
2359       ProcessUniqueOperator(model, static_cast<UniqueOperator*>(op));
2360       break;
2361     case OperatorType::kWhere:
2362       // The size of the output can only be known after evaluating the cond
2363       // tensor. Ignore shape propagation here and defer that to the
2364       // interpreter.
2365       break;
2366     default:
2367       // Unimplemented, another graph transformation should drop it.
2368       LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
2369   }
2370 
2371   // Set *modified if any output dim changed; leave it false if none changed.
2372   // Assumption: no transformation clears an output shape; they only add shapes.
2373   for (const auto& output : op->outputs) {
2374     if (model->GetArray(output).has_shape() &&
2375         (old_output_dims[output] != model->GetArray(output).shape().dims())) {
2376       AddMessageF("Set shape of %s to [%s]", output,
2377                   absl::StrJoin(model->GetArray(output).shape().dims(), ","));
2378       *modified = true;
2379       return ::tensorflow::Status::OK();
2380     }
2381   }
2382   return ::tensorflow::Status::OK();
2383 }
2384 
2385 }  // namespace toco
2386