1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <algorithm>
#include <iterator>
#include <memory>
#include <numeric>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>
22
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/core/platform/logging.h"
25 #include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
26 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
27 #include "tensorflow/lite/toco/model.h"
28 #include "tensorflow/lite/toco/tooling_util.h"
29
30 namespace toco {
31
32 namespace {
33
ComputeConvSizes(const Shape & input_shape,int output_depth,int kwidth,int kheight,int stride_width,int stride_height,int dilation_width_factor,int dilation_height_factor,PaddingType padding_type,Shape * output_shape,FixedPadding * fixed_padding)34 void ComputeConvSizes(const Shape& input_shape, int output_depth, int kwidth,
35 int kheight, int stride_width, int stride_height,
36 int dilation_width_factor, int dilation_height_factor,
37 PaddingType padding_type, Shape* output_shape,
38 FixedPadding* fixed_padding) {
39 const int input_width = input_shape.dims(2);
40 const int input_height = input_shape.dims(1);
41 const int batch = input_shape.dims(0);
42
43 CHECK_GE(input_width, 1);
44 CHECK_GE(input_height, 1);
45 CHECK_GE(batch, 1);
46 CHECK_GE(kwidth, 1);
47 CHECK_GE(kheight, 1);
48 CHECK_GE(stride_width, 1);
49 CHECK_GE(stride_height, 1);
50 CHECK_GE(dilation_width_factor, 1);
51 CHECK_GE(dilation_height_factor, 1);
52
53 int dilated_kwidth = dilation_width_factor * (kwidth - 1) + 1;
54 int dilated_kheight = dilation_height_factor * (kheight - 1) + 1;
55
56 int output_height = 0;
57 int output_width = 0;
58 if (padding_type == PaddingType::kValid) {
59 output_height =
60 (input_height + stride_height - dilated_kheight) / stride_height;
61 output_width = (input_width + stride_width - dilated_kwidth) / stride_width;
62 } else if (padding_type == PaddingType::kSame) {
63 output_height = (input_height + stride_height - 1) / stride_height;
64 output_width = (input_width + stride_width - 1) / stride_width;
65 } else {
66 LOG(FATAL) << "Only supporting SAME or VALID padding";
67 }
68
69 fixed_padding->height = std::max(0, ((output_height - 1) * stride_height +
70 dilated_kheight - input_height) /
71 2);
72 fixed_padding->width = std::max(
73 0,
74 ((output_width - 1) * stride_width + dilated_kwidth - input_width) / 2);
75
76 // Actually had to debug a situation where those were negative due to bad
77 // propagation of placeholder -1 sizes in TensorFlowReshape.
78 CHECK_GT(output_width, 0);
79 CHECK_GT(output_height, 0);
80 output_shape->ReplaceDims({batch, output_height, output_width, output_depth});
81 }
82
ComputeBinaryOperatorOutputSize(const Shape & input_shape_x,const Shape & input_shape_y,Array * output_array)83 void ComputeBinaryOperatorOutputSize(const Shape& input_shape_x,
84 const Shape& input_shape_y,
85 Array* output_array) {
86 // This matches the code in BroadcastBinaryOpShapeFn from tensorflow.
87 // It zips together the two input shapes and pads with 1 to make them the
88 // same length. For each dimension we broadcast if either dimension is 1 and
89 // otherwise expect them to match.
90 int rank_x = input_shape_x.dimensions_count();
91 int rank_y = input_shape_y.dimensions_count();
92 int rank_out = std::max(rank_x, rank_y);
93 std::vector<int>* dims_out = output_array->mutable_shape()->mutable_dims();
94 dims_out->clear();
95 dims_out->reserve(rank_out);
96 for (int i = 0; i < rank_out; ++i) {
97 int dim_x = i < (rank_out - rank_x)
98 ? 1
99 : input_shape_x.dims(i - (rank_out - rank_x));
100 bool dim_y_is_one = i < (rank_out - rank_y);
101 int dim_y = dim_y_is_one ? 1 : input_shape_y.dims(i - (rank_out - rank_y));
102 if (dim_x == -1 || dim_y == -1) {
103 // One or both dimensions is unknown.
104 QCHECK(false) << "Shapes must be specified";
105 } else if (dim_x == 1 || dim_y == 1) {
106 // Broadcast one dimension to the other that is 1.
107 if (dim_x == 1 && !dim_y_is_one) {
108 // Broadcast dim_y to dim_x (1).
109 dims_out->push_back(dim_y);
110 } else {
111 // Broadcast dim_x to dim_y (1).
112 DCHECK_EQ(dim_y, 1);
113 dims_out->push_back(dim_x);
114 }
115 } else {
116 // Expect the dimensions to match.
117 CHECK_EQ(dim_x, dim_y) << "Dimensions must match";
118 dims_out->push_back(dim_x);
119 }
120 }
121 CHECK(output_array->has_shape());
122 }
123
ProcessConvOperator(Model * model,ConvOperator * op)124 void ProcessConvOperator(Model* model, ConvOperator* op) {
125 const auto& input_array = model->GetArray(op->inputs[0]);
126 // Yield until input dims have been resolved.
127 if (!input_array.has_shape()) {
128 return;
129 }
130 const auto& input_shape = input_array.shape();
131 CHECK(input_shape.dimensions_count() == 4)
132 << "Conv ops require 4D inputs. Input array \"" << op->inputs[0]
133 << "\" is " << input_shape.dimensions_count() << "D.";
134
135 const auto& weights_array = model->GetArray(op->inputs[1]);
136 // Yield until weights dims have been resolved.
137 if (!weights_array.has_shape()) {
138 return;
139 }
140 const auto& weights_shape = weights_array.shape();
141 CHECK_EQ(weights_shape.dimensions_count(), 4);
142
143 auto& output_array = model->GetArray(op->outputs[0]);
144 const int output_depth = weights_shape.dims(0);
145 const int kheight = weights_shape.dims(1);
146 const int kwidth = weights_shape.dims(2);
147 ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width,
148 op->stride_height, op->dilation_width_factor,
149 op->dilation_height_factor, op->padding.type,
150 output_array.mutable_shape(),
151 &op->padding.GetOrCreateFixedPadding());
152 CHECK_EQ(output_array.shape().dimensions_count(), 4);
153
154 // Set im2col array dimensions if there is one.
155 if (op->outputs.size() == 2) {
156 const auto& output_shape = output_array.shape();
157 const int input_depth = weights_shape.dims(3);
158 auto& im2col_array = model->GetArray(op->outputs[1]);
159 im2col_array.copy_shape(Shape{output_shape.dims(0), output_shape.dims(1),
160 output_shape.dims(2),
161 input_depth * kheight * kwidth});
162 }
163 }
164
ProcessTransposeConvOperator(Model * model,TransposeConvOperator * op)165 void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) {
166 // TransposeConv is unique in that it is specifically given the output shape
167 // as a 1D array on it's 1st input. Theoretically then, resolving the output
168 // shape is as easy as waiting for this input to be resolved. However, we also
169 // have to calculate the padding which requires the weights shape. So, we
170 // might as well calculate the output shape and ensure it matches the
171 // specified one
172
173 // SPECIFIED OUTPUT SHAPE
174 // The below is the specified, or prescribed output shape, _given_ to the
175 // operator as an input.
176 auto& specified_output_shape_array =
177 model->GetArray(op->inputs[TransposeConvOperator::OUTPUT_SHAPE]);
178 if (!specified_output_shape_array.has_shape() ||
179 !specified_output_shape_array.buffer) {
180 // Yield until the specified output shape is resolved as a constant
181 return;
182 }
183
184 CHECK(specified_output_shape_array.data_type == ArrayDataType::kInt32)
185 << "TransposeConv input_dims must be int32";
186
187 CHECK(specified_output_shape_array.shape().dimensions_count() == 1 &&
188 specified_output_shape_array.shape().dims(0) == 4)
189 << "TransposeConv requires a 1D, 4 element array on it's 0th input "
190 "specifying the output shape. \""
191 << op->inputs[TransposeConvOperator::OUTPUT_SHAPE] << "\" had shape "
192 << toco::ShapeToString(specified_output_shape_array.shape());
193
194 // COMPUTE PADDING
195 // We require the weights shape to calculate padding.
196 const auto& weights_array =
197 model->GetArray(op->inputs[TransposeConvOperator::WEIGHTS]);
198 if (!weights_array.has_shape()) {
199 // Yield until weights dims have been resolved.
200 return;
201 }
202 const auto& weights_shape = weights_array.shape();
203 CHECK_EQ(weights_shape.dimensions_count(), 4)
204 << "TransposeConv weights must have 4 input dimensions. Input weights \""
205 << op->inputs[TransposeConvOperator::WEIGHTS] << "\" had shape "
206 << toco::ShapeToString(weights_shape) << ".";
207
208 // Compute padding
209 const int kheight = weights_shape.dims(1);
210 const int kwidth = weights_shape.dims(2);
211 op->padding.GetOrCreateFixedPadding();
212 if (op->padding.type == PaddingType::kValid) {
213 op->padding.fixed->height = 0;
214 op->padding.fixed->width = 0;
215 } else if (op->padding.type == PaddingType::kSame) {
216 op->padding.fixed->height = (kheight - 1) / 2;
217 op->padding.fixed->width = (kwidth - 1) / 2;
218 } else {
219 LOG(FATAL) << "TransposeConv only supports SAME or VALID padding";
220 }
221
222 // VALIDATE some dimensions and set the output shape.
223 const auto& input_array =
224 model->GetArray(op->inputs[TransposeConvOperator::DATA_INPUT]);
225 if (!input_array.has_shape()) {
226 // Yield until input dims have been resolved.
227 return;
228 }
229 const auto& input_shape = input_array.shape();
230 CHECK_EQ(input_shape.dimensions_count(), 4)
231 << "TransposeConv input shape must have 4 dimensions. Input \""
232 << op->inputs[TransposeConvOperator::WEIGHTS] << "\" had shape "
233 << toco::ShapeToString(weights_shape) << ".";
234 CHECK_EQ(input_shape.dims(3), weights_shape.dims(3))
235 << "Input shape depth and weight depth do not agree";
236
237 // Set the output shape according to the specified output shape.
238 std::vector<int32> const& specified_output_shape =
239 specified_output_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
240 auto& output_array = model->GetArray(op->outputs[0]);
241 *(output_array.mutable_shape()->mutable_dims()) = specified_output_shape;
242
243 // Set im2col array dimensions if there is one.
244 if (op->outputs.size() == 2) {
245 const int input_depth = weights_shape.dims(3);
246 auto& im2col_array = model->GetArray(op->outputs[1]);
247 im2col_array.copy_shape(
248 Shape{specified_output_shape[0], specified_output_shape[1],
249 specified_output_shape[2], input_depth * kheight * kwidth});
250 }
251 }
252
ProcessDepthwiseConvOperator(Model * model,DepthwiseConvOperator * op)253 void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) {
254 const auto& input_array = model->GetArray(op->inputs[0]);
255 // Yield until input dims have been resolved.
256 if (!input_array.has_shape()) {
257 return;
258 }
259 const auto& input_shape = input_array.shape();
260 CHECK_EQ(input_shape.dimensions_count(), 4);
261
262 const auto& weights_array = model->GetArray(op->inputs[1]);
263 // Yield until weights dims have been resolved.
264 if (!weights_array.has_shape()) {
265 return;
266 }
267 const auto& weights_shape = weights_array.shape();
268 CHECK_EQ(weights_shape.dimensions_count(), 4);
269
270 const string& output_name = op->outputs[0];
271 const int input_depth = input_shape.dims(3);
272 const int output_depth = weights_shape.dims(3);
273 // TensorFlow doesn't define the depth_multiplier value on DepthwiseConv ops,
274 // instead it has to be inferred from the weights dims. However, once we are
275 // here, weights dims have already been converted to our own internal format,
276 // where the multiplier is no longer readily apparent. So instead we get it
277 // as the quotient of output and input depths. We only want to do that when
278 // depth_multiplier had the zero value: any other value should be checked
279 // as done by the next if() below.
280 if (!op->depth_multiplier) {
281 op->depth_multiplier = output_depth / input_depth;
282 }
283 CHECK_EQ(output_depth, input_depth * op->depth_multiplier)
284 << "input/output depths and depth_multiplier don't match";
285
286 const int kheight = weights_shape.dims(1);
287 const int kwidth = weights_shape.dims(2);
288 ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width,
289 op->stride_height, op->dilation_width_factor,
290 op->dilation_height_factor, op->padding.type,
291 model->GetArray(output_name).mutable_shape(),
292 &op->padding.GetOrCreateFixedPadding());
293 }
294
ProcessDepthToSpaceOperator(Model * model,DepthToSpaceOperator * op)295 void ProcessDepthToSpaceOperator(Model* model, DepthToSpaceOperator* op) {
296 const auto& input_array = model->GetArray(op->inputs[0]);
297 // Yield until input dims have been resolved.
298 if (!input_array.has_shape()) {
299 return;
300 }
301 const auto& input_shape = input_array.shape();
302 CHECK_EQ(input_shape.dimensions_count(), 4);
303
304 const string& output_name = op->outputs[0];
305 const int block_size = op->block_size;
306 CHECK_NE(block_size, 0) << "Invalid block_size in " << output_name;
307 const int batch = input_shape.dims(0);
308 const int height = input_shape.dims(1);
309 const int width = input_shape.dims(2);
310 const int depth = input_shape.dims(3);
311 QCHECK_EQ(depth % (block_size * block_size), 0);
312
313 model->GetArray(output_name)
314 .copy_shape(Shape({batch, height * block_size, width * block_size,
315 depth / block_size / block_size}));
316 }
317
ProcessSpaceToDepthOperator(Model * model,SpaceToDepthOperator * op)318 void ProcessSpaceToDepthOperator(Model* model, SpaceToDepthOperator* op) {
319 const auto& input_array = model->GetArray(op->inputs[0]);
320 // Yield until input dims have been resolved.
321 if (!input_array.has_shape()) {
322 return;
323 }
324 const auto& input_shape = input_array.shape();
325 CHECK_EQ(input_shape.dimensions_count(), 4);
326
327 const string& output_name = op->outputs[0];
328 const int block_size = op->block_size;
329 CHECK_NE(block_size, 0) << "Invalid block_size in " << output_name;
330 const int batch = input_shape.dims(0);
331 const int height = input_shape.dims(1);
332 const int width = input_shape.dims(2);
333 const int depth = input_shape.dims(3);
334 QCHECK_EQ(width % block_size, 0);
335 QCHECK_EQ(height % block_size, 0);
336
337 model->GetArray(output_name)
338 .copy_shape(Shape({batch, height / block_size, width / block_size,
339 depth * block_size * block_size}));
340 }
341
ProcessOpWithShapeInput(Model * model,Operator * op)342 void ProcessOpWithShapeInput(Model* model, Operator* op) {
343 CHECK_EQ(op->outputs.size(), 1);
344 auto& output_array = model->GetArray(op->outputs[0]);
345 if (output_array.has_shape()) {
346 // We have already run
347 return;
348 }
349
350 auto& dims_array = model->GetArray(op->inputs[0]);
351 if (!dims_array.has_shape()) {
352 // Yield until dims shape been resolved.
353 return;
354 }
355 if (!dims_array.buffer) {
356 // Yield until the dims are constant
357 return;
358 }
359 CHECK(dims_array.data_type == ArrayDataType::kInt32) << "dims must be int32";
360 CHECK_LE(RequiredBufferSizeForShape(dims_array.shape()), 4)
361 << "dims vector can be no larger than 4 values";
362
363 std::vector<int32> const& dims =
364 dims_array.GetBuffer<ArrayDataType::kInt32>().data;
365 *(output_array.mutable_shape()->mutable_dims()) = dims;
366 }
367
ProcessFullyConnectedOperator(Model * model,FullyConnectedOperator * op)368 void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) {
369 const auto& input_array = model->GetArray(op->inputs[0]);
370 // Yield until input dims have been resolved.
371 if (!input_array.has_shape()) {
372 return;
373 }
374 const auto& input_shape = input_array.shape();
375 CHECK_GE(input_shape.dimensions_count(), 1);
376
377 const auto& weights_array = model->GetArray(op->inputs[1]);
378 // Yield until weights dims have been resolved.
379 if (!weights_array.has_shape()) {
380 return;
381 }
382 const auto& weights_shape = weights_array.shape();
383
384 const int weights_output_depth = weights_shape.dims(0);
385 CHECK_EQ(weights_shape.dimensions_count(), 2);
386
387 const int input_overall_size = RequiredBufferSizeForShape(input_shape);
388 const int matmul_repeats = input_overall_size / weights_shape.dims(1);
389 CHECK_EQ(matmul_repeats * weights_shape.dims(1), input_overall_size);
390
391 auto& output_array = model->GetArray(op->outputs[0]);
392 output_array.copy_shape(Shape({matmul_repeats, weights_output_depth}));
393 }
394
ProcessTensorFlowReshapeOperator(Model * model,TensorFlowReshapeOperator * op)395 void ProcessTensorFlowReshapeOperator(Model* model,
396 TensorFlowReshapeOperator* op) {
397 auto& output_array = model->GetArray(op->outputs[0]);
398 if (output_array.has_shape()) {
399 // We have already run
400 return;
401 }
402
403 const auto& input_array = model->GetArray(op->inputs[0]);
404 if (!input_array.has_shape()) {
405 // Yield until input dims have been resolved.
406 return;
407 }
408 const auto& input_shape = input_array.shape();
409
410 auto& shape_array = model->GetArray(op->inputs[1]);
411 if (!shape_array.has_shape()) {
412 // Yield until target_shape shape been resolved.
413 return;
414 }
415 if (!shape_array.buffer) {
416 // Yield until the target_shape is constant
417 return;
418 }
419 CHECK(shape_array.data_type == ArrayDataType::kInt32)
420 << "Reshape dims must be int32";
421
422 // shape_data is the raw array of ints describing the shape
423 // in the TensorFlow node. We intentionally make a copy here, rather than
424 // modify wildcards in-place below, because in some graphs, the same shape
425 // array with a wildcard may be referenced from multiple Reshape nodes, where
426 // the wildcard needs to resolved to distinct values.
427 std::vector<int32> shape_data =
428 shape_array.GetBuffer<ArrayDataType::kInt32>().data;
429 // The Reshape shape may have a wildcard dim, encoded as -1.
430 bool has_wildcard = false;
431 int wildcard_index = 0;
432 int product_non_wildcard_dims = 1;
433 for (int i = 0; i < shape_data.size(); i++) {
434 if (shape_data[i] == -1) {
435 CHECK(!has_wildcard);
436 has_wildcard = true;
437 wildcard_index = i;
438 } else {
439 product_non_wildcard_dims *= shape_data[i];
440 }
441 }
442
443 const int input_flat_size = RequiredBufferSizeForShape(input_shape);
444 if (has_wildcard) {
445 CHECK_GE(input_flat_size, product_non_wildcard_dims)
446 << "Array not large enough to fill the requested dimensions for "
447 "Reshape op with output \""
448 << op->outputs[0] << "\". Are your input shapes correct?";
449 shape_data[wildcard_index] = input_flat_size / product_non_wildcard_dims;
450 }
451
452 if (shape_data.size() == 1 && shape_data[0] == 0) {
453 // We have reshaped a scalar, so preserve as a scalar.
454 shape_data.clear();
455 }
456
457 auto& output_shape = *output_array.mutable_shape();
458 *output_shape.mutable_dims() = shape_data;
459 CHECK_EQ(input_flat_size, RequiredBufferSizeForShape(output_shape))
460 << "Input cannot be reshaped to requested dimensions for Reshape op with "
461 "output \""
462 << op->outputs[0] << "\". Are your input shapes correct?";
463 }
464
ProcessSimpleOperator(Model * model,Operator * op,int input_index)465 void ProcessSimpleOperator(Model* model, Operator* op, int input_index) {
466 const auto& input_array = model->GetArray(op->inputs[input_index]);
467 // Yield until input dims have been resolved.
468 if (!input_array.has_shape()) {
469 return;
470 }
471
472 const string& output_name = op->outputs[0];
473 auto& output_array = model->GetArray(output_name);
474 if (output_array.has_shape()) {
475 return;
476 }
477
478 output_array.copy_shape(input_array.shape());
479 }
480
ProcessSimpleBinaryOperator(Model * model,Operator * op)481 void ProcessSimpleBinaryOperator(Model* model, Operator* op) {
482 CHECK_EQ(op->inputs.size(), 2);
483 const auto& input0_array = model->GetArray(op->inputs[0]);
484 const auto& input1_array = model->GetArray(op->inputs[1]);
485 // Yield until input dims have been resolved.
486 if (!input0_array.has_shape() || !input1_array.has_shape()) {
487 return;
488 }
489 const string& output_name = op->outputs[0];
490 auto& output_array = model->GetArray(output_name);
491 ComputeBinaryOperatorOutputSize(input0_array.shape(), input1_array.shape(),
492 &output_array);
493 }
494
ProcessSelectOperator(Model * model,SelectOperator * op)495 void ProcessSelectOperator(Model* model, SelectOperator* op) {
496 // Yield until all input dims have been resolved.
497 for (const auto& input : op->inputs) {
498 const auto& input_array = model->GetArray(input);
499 if (!input_array.has_shape()) {
500 return;
501 }
502 }
503
504 // Select's output matches the second and third output.
505 const auto& input1_array = model->GetArray(op->inputs[1]);
506 auto& output_array = model->GetArray(op->outputs[0]);
507 output_array.copy_shape(input1_array.shape());
508 }
509
ProcessAddNOperator(Model * model,Operator * op)510 void ProcessAddNOperator(Model* model, Operator* op) {
511 // Yield until all input dims have been resolved.
512 //
513 // TODO(myenik): Since AddN does not support broadcasting, maybe we could
514 // actually use this to improve shape propagation by propagating the shape of
515 // one input to all other inputs once it is resolved instead of just the
516 // output, since all inputs must be the same size and shape for a well-formed
517 // graph.
518 for (const auto& input : op->inputs) {
519 const auto& input_array = model->GetArray(input);
520 if (!input_array.has_shape()) {
521 return;
522 }
523 }
524
525 // AddN does not support broadcasting, all inputs must be the same shape, so
526 // we just take the first input shape and apply it to the output.
527 const auto& input0_array = model->GetArray(op->inputs[0]);
528 auto& output_array = model->GetArray(op->outputs[0]);
529 output_array.copy_shape(input0_array.shape());
530 }
531
KeepDims(const Operator & op)532 bool KeepDims(const Operator& op) {
533 switch (op.type) {
534 case OperatorType::kReduceMin: // Reduction Min
535 return static_cast<const TensorFlowMinOperator&>(op).keep_dims;
536 case OperatorType::kReduceMax: // Reduction Max
537 return static_cast<const TensorFlowMaxOperator&>(op).keep_dims;
538 case OperatorType::kSum:
539 return static_cast<const TensorFlowSumOperator&>(op).keep_dims;
540 case OperatorType::kReduceProd:
541 return static_cast<const TensorFlowProdOperator&>(op).keep_dims;
542 case OperatorType::kMean:
543 return static_cast<const MeanOperator&>(op).keep_dims;
544 case OperatorType::kAny:
545 return static_cast<const TensorFlowAnyOperator&>(op).keep_dims;
546 default:
547 LOG(FATAL) << "Not a reduction operator!";
548 return false;
549 }
550 }
551
ProcessTensorFlowReductionOperator(Model * model,Operator * op)552 void ProcessTensorFlowReductionOperator(Model* model, Operator* op) {
553 CHECK_LE(op->inputs.size(), 2);
554 auto& output_array = model->GetArray(op->outputs[0]);
555 if (output_array.has_shape()) {
556 return;
557 }
558 const auto& input_array = model->GetArray(op->inputs[0]);
559 if (!input_array.has_shape()) {
560 return;
561 }
562 const auto& input_shape = input_array.shape();
563 const bool keep_dims = KeepDims(*op);
564 if (op->inputs.size() == 2) {
565 // There is a reduction_indices input.
566 const auto& reduction_indices_array = model->GetArray(op->inputs[1]);
567 if (!reduction_indices_array.buffer) {
568 return;
569 }
570 CHECK(reduction_indices_array.buffer->type == ArrayDataType::kInt32);
571
572 int input_rank = input_shape.dimensions_count();
573 std::set<int32> true_indices;
574 const auto& reduction_indices =
575 reduction_indices_array.GetBuffer<ArrayDataType::kInt32>().data;
576 for (int i = 0; i < reduction_indices.size(); ++i) {
577 const int32 reduction_index = reduction_indices[i];
578 if (reduction_index < -input_rank || reduction_index >= input_rank) {
579 CHECK(false) << "Invalid reduction dimension " << reduction_index
580 << " for input with " << input_rank << " dimensions";
581 }
582 int32 wrapped_index = reduction_index;
583 if (wrapped_index < 0) {
584 wrapped_index += input_rank;
585 }
586 true_indices.insert(wrapped_index);
587 }
588
589 auto* mutable_dims = output_array.mutable_shape()->mutable_dims();
590 mutable_dims->clear();
591 for (int i = 0; i < input_rank; ++i) {
592 if (true_indices.count(i) > 0) {
593 if (keep_dims) {
594 mutable_dims->emplace_back(1);
595 }
596 } else {
597 mutable_dims->emplace_back(input_shape.dims(i));
598 }
599 }
600 } else {
601 // No reduction_indices means complete reduction to a single scalar.
602 if (keep_dims) {
603 output_array.copy_shape(input_shape);
604 } else {
605 output_array.copy_shape(Shape({}));
606 }
607 }
608 }
609
ProcessSliceOperator(Model * model,SliceOperator * op)610 void ProcessSliceOperator(Model* model, SliceOperator* op) {
611 CHECK_EQ(op->inputs.size(), 3);
612 CHECK_EQ(op->outputs.size(), 1);
613
614 // Yield until the Slice params have been resolved.
615 if (op->begin.empty()) return;
616
617 // Yield until input dims have been resolved.
618 const auto& input_array = model->GetArray(op->inputs[0]);
619 if (!input_array.has_shape()) return;
620 const Shape& input_shape = input_array.shape();
621
622 auto& output_array = model->GetArray(op->outputs[0]);
623 if (output_array.has_shape()) return;
624
625 CHECK_EQ(input_shape.dims().size(), op->size.size());
626 CHECK_EQ(op->begin.size(), op->size.size());
627
628 std::vector<int> output_dims;
629 for (int i = 0; i < op->begin.size(); ++i) {
630 int size = op->size[i];
631 if (size == -1) {
632 size = input_array.shape().dims(i) - op->begin[i];
633 }
634 output_dims.push_back(size);
635 }
636
637 *output_array.mutable_shape()->mutable_dims() = output_dims;
638 }
639
ProcessReorderAxesOperator(Model * model,ReorderAxesOperator * op)640 void ProcessReorderAxesOperator(Model* model, ReorderAxesOperator* op) {
641 const string& input_name = op->inputs[0];
642 const auto& input_array = model->GetArray(input_name);
643 // Yield until input dims have been resolved.
644 if (!input_array.has_shape()) {
645 return;
646 }
647 const auto& input_shape = input_array.shape();
648 const string& output_name = op->outputs[0];
649 Shape* output_shape = model->GetArray(output_name).mutable_shape();
650 ShuffleDims(input_shape, op->input_axes_order, op->output_axes_order,
651 output_shape);
652 }
653
ProcessConcatenationOperator(Model * model,ConcatenationOperator * op)654 void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) {
655 // Yield until input dims have been resolved.
656 for (const auto& input_name : op->inputs) {
657 auto& input_array = model->GetArray(input_name);
658 if (!input_array.has_shape()) {
659 return;
660 }
661 }
662 auto& output_array = model->GetArray(op->outputs[0]);
663 // Use first non-empty input as basis for output dimensions.
664 for (const auto& input_name : op->inputs) {
665 const auto& input_array = model->GetArray(input_name);
666 if (input_array.shape().dimensions_count() > 0) {
667 output_array.copy_shape(input_array.shape());
668 // Negative axis means the count starts at the back of the dims().
669 if (op->axis < 0) op->axis += input_array.shape().dims().size();
670 break;
671 }
672 }
673 // Determine the concat size, and enfore that all inputs have
674 // the same dimensions count.
675 int concat_size = 0;
676 for (const auto& input_name : op->inputs) {
677 auto& input_array = model->GetArray(input_name);
678 CHECK(input_array.has_shape());
679 if (input_array.shape().dimensions_count() == 0) {
680 continue;
681 }
682 CHECK_EQ(input_array.shape().dimensions_count(),
683 output_array.shape().dimensions_count());
684 const std::vector<int>& input_dims = input_array.shape().dims();
685 CHECK_LT(op->axis, input_dims.size());
686 concat_size += input_dims[op->axis];
687 }
688 // Write out the concat_size on the output array shape.
689 auto& output_shape = *output_array.mutable_shape();
690 auto& output_dims = *output_shape.mutable_dims();
691 CHECK_LT(op->axis, output_shape.dimensions_count());
692 output_dims[op->axis] = concat_size;
693 }
694
ProcessRangeOperator(Model * model,RangeOperator * op)695 void ProcessRangeOperator(Model* model, RangeOperator* op) {
696 CHECK_EQ(op->inputs.size(), 3);
697 const auto& start_array = model->GetArray(op->inputs[0]);
698 if (!start_array.has_shape()) {
699 // Yield until input dims have been resolved.
700 return;
701 }
702 const auto& limit_array = model->GetArray(op->inputs[1]);
703 if (!limit_array.has_shape()) {
704 return;
705 }
706 const auto& delta_array = model->GetArray(op->inputs[2]);
707 if (!delta_array.has_shape()) {
708 return;
709 }
710
711 if (!IsConstantParameterArray(*model, op->inputs[0])) {
712 // Yield until inputs are constant.
713 return;
714 }
715 if (!IsConstantParameterArray(*model, op->inputs[1])) {
716 return;
717 }
718 if (!IsConstantParameterArray(*model, op->inputs[2])) {
719 return;
720 }
721
722 CHECK(start_array.data_type == ArrayDataType::kInt32)
723 << "Range op inputs must be int32.";
724 CHECK(limit_array.data_type == ArrayDataType::kInt32)
725 << "Range op inputs must be int32.";
726 CHECK(delta_array.data_type == ArrayDataType::kInt32)
727 << "Range op inputs must be int32.";
728 CHECK_EQ(RequiredBufferSizeForShape(start_array.shape()), 1)
729 << "Range op inputs must be scalar.";
730 CHECK_EQ(RequiredBufferSizeForShape(limit_array.shape()), 1)
731 << "Range op inputs must be scalar.";
732 CHECK_EQ(RequiredBufferSizeForShape(delta_array.shape()), 1)
733 << "Range op inputs must be scalar.";
734 int size = floor((limit_array.GetBuffer<ArrayDataType::kInt32>().data[0] -
735 start_array.GetBuffer<ArrayDataType::kInt32>().data[0]) /
736 delta_array.GetBuffer<ArrayDataType::kInt32>().data[0]);
737
738 // Only set the output shape. Contents are set by ResolveConstantRange.
739 CHECK_EQ(op->outputs.size(), 1);
740 auto& output_array = model->GetArray(op->outputs[0]);
741 Shape* output_shape = output_array.mutable_shape();
742 output_shape->ReplaceDims({size});
743 }
744
ProcessTensorFlowSplitOperator(Model * model,TensorFlowSplitOperator * op)745 void ProcessTensorFlowSplitOperator(Model* model, TensorFlowSplitOperator* op) {
746 CHECK_EQ(op->inputs.size(), 2);
747 const string& input_name = op->inputs[1];
748 const auto& input_array = model->GetArray(input_name);
749 // Yield until input dims have been resolved.
750 if (!input_array.has_shape()) {
751 return;
752 }
753 const Shape& input_shape = input_array.shape();
754
755 // Yield until axis is constant.
756 if (!IsConstantParameterArray(*model, op->inputs[0])) {
757 return;
758 }
759
760 const auto& axis_array = model->GetArray(op->inputs[0]);
761
762 // Yield until axis dims have been resolved.
763 if (!axis_array.has_shape()) {
764 return;
765 }
766
767 CHECK(axis_array.data_type == ArrayDataType::kInt32)
768 << "Axis array must be int32.";
769 CHECK_EQ(RequiredBufferSizeForShape(axis_array.shape()), 1)
770 << "Axis array must be scalar.";
771
772 int axis = axis_array.GetBuffer<ArrayDataType::kInt32>().data[0];
773 if (axis < 0) {
774 axis += input_shape.dimensions_count();
775 }
776
777 const int split_dim = input_shape.dims(axis);
778 CHECK_EQ(split_dim % op->num_split, 0);
779 const int split_depth = split_dim / op->num_split;
780
781 Shape output_shape = input_shape;
782 (*output_shape.mutable_dims())[axis] = split_depth;
783
784 CHECK_EQ(op->outputs.size(), op->num_split);
785 for (const auto& output : op->outputs) {
786 model->GetArray(output).copy_shape(output_shape);
787 }
788 }
789
ProcessTensorFlowSplitVOperator(Model * model,TensorFlowSplitVOperator * op)790 void ProcessTensorFlowSplitVOperator(Model* model,
791 TensorFlowSplitVOperator* op) {
792 CHECK_EQ(op->inputs.size(), 3);
793
794 const auto& input_array = model->GetArray(op->inputs[0]);
795 // Yield until input dims have been resolved.
796 if (!input_array.has_shape()) {
797 return;
798 }
799 const Shape& input_shape = input_array.shape();
800
801 // Yield until size_splits is constant.
802 if (!IsConstantParameterArray(*model, op->inputs[1])) {
803 return;
804 }
805 const auto& size_array = model->GetArray(op->inputs[1]);
806 // Yield until size_splits dims have been resolved.
807 if (!size_array.has_shape()) {
808 return;
809 }
810 const Shape& size_shape = size_array.shape();
811
812 CHECK(size_array.data_type == ArrayDataType::kInt32 ||
813 size_array.data_type == ArrayDataType::kInt64)
814 << "size_splits must be int32, int64";
815 CHECK_EQ(size_shape.dimensions_count(), 1) << "size_splits must be 1-D";
816
817 std::vector<int64> size_splits_vector;
818 if (size_array.data_type == ArrayDataType::kInt32) {
819 for (const auto each_size :
820 size_array.GetBuffer<ArrayDataType::kInt32>().data) {
821 size_splits_vector.push_back(each_size);
822 }
823 } else {
824 size_splits_vector = size_array.GetBuffer<ArrayDataType::kInt64>().data;
825 }
826
827 // Yield until axis is constant.
828 if (!IsConstantParameterArray(*model, op->inputs[2])) {
829 return;
830 }
831 const auto& axis_array = model->GetArray(op->inputs[2]);
832 // Yield until axis dims have been resolved.
833 if (!axis_array.has_shape()) {
834 return;
835 }
836
837 CHECK(axis_array.data_type == ArrayDataType::kInt32)
838 << "Axis array must be int32.";
839 CHECK_EQ(RequiredBufferSizeForShape(axis_array.shape()), 1)
840 << "Axis array must be scalar.";
841
842 int axis = axis_array.GetBuffer<ArrayDataType::kInt32>().data[0];
843 if (axis < 0) {
844 axis += input_shape.dimensions_count();
845 }
846
847 CHECK_EQ(op->num_split, size_splits_vector.size());
848
849 int64_t minus_one_count = 0, size_splits_sum = 0;
850 for (auto size : size_splits_vector) {
851 if (size == -1) {
852 ++minus_one_count;
853 } else {
854 size_splits_sum += size;
855 }
856 }
857
858 const int input_size = input_shape.dims(axis);
859
860 CHECK_LE(minus_one_count, 1) << "size_splits can contain at most one -1.";
861
862 if (minus_one_count == 1) {
863 CHECK_LE(size_splits_sum, input_size);
864 auto iter =
865 std::find(size_splits_vector.begin(), size_splits_vector.end(), -1);
866 *iter = input_size - size_splits_sum;
867 } else {
868 CHECK_EQ(size_splits_sum, input_size);
869 }
870
871 CHECK_EQ(op->outputs.size(), op->num_split);
872
873 for (int i = 0; i < op->outputs.size(); ++i) {
874 const auto& output = op->outputs[i];
875 Shape output_shape = input_shape;
876 (*output_shape.mutable_dims())[axis] = size_splits_vector.at(i);
877 model->GetArray(output).copy_shape(output_shape);
878 }
879 }
880
ProcessAveragePoolOperator(Model * model,AveragePoolOperator * op)881 void ProcessAveragePoolOperator(Model* model, AveragePoolOperator* op) {
882 const string& input_name = op->inputs[0];
883 const auto& input_array = model->GetArray(input_name);
884 // Yield until input dims have been resolved.
885 if (!input_array.has_shape()) {
886 return;
887 }
888 const auto& input_shape = input_array.shape();
889 CHECK_EQ(input_shape.dimensions_count(), 4);
890 const string& output_name = op->outputs[0];
891 const int output_depth = input_shape.dims(3);
892 ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
893 op->stride_width, op->stride_height, 1, 1, op->padding.type,
894 model->GetArray(output_name).mutable_shape(),
895 &op->padding.GetOrCreateFixedPadding());
896 }
897
ProcessMaxPoolOperator(Model * model,MaxPoolOperator * op)898 void ProcessMaxPoolOperator(Model* model, MaxPoolOperator* op) {
899 const string& input_name = op->inputs[0];
900 const auto& input_array = model->GetArray(input_name);
901 // Yield until input dims have been resolved.
902 if (!input_array.has_shape()) {
903 return;
904 }
905 const auto& input_shape = input_array.shape();
906 CHECK_EQ(input_shape.dimensions_count(), 4);
907 const string& output_name = op->outputs[0];
908 const int output_depth = input_shape.dims(3);
909 ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
910 op->stride_width, op->stride_height, 1, 1, op->padding.type,
911 model->GetArray(output_name).mutable_shape(),
912 &op->padding.GetOrCreateFixedPadding());
913 }
914
ProcessL2PoolOperator(Model * model,L2PoolOperator * op)915 void ProcessL2PoolOperator(Model* model, L2PoolOperator* op) {
916 const string& input_name = op->inputs[0];
917 const auto& input_array = model->GetArray(input_name);
918 // Yield until input dims have been resolved.
919 if (!input_array.has_shape()) {
920 return;
921 }
922 const auto& input_shape = input_array.shape();
923 if (input_shape.dimensions_count() < 4) {
924 LOG(FATAL) << "missing dimensions for " << input_name;
925 }
926 const string& output_name = op->outputs[0];
927 const int output_depth = input_shape.dims(3);
928 ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
929 op->stride_width, op->stride_height, 1, 1, op->padding.type,
930 model->GetArray(output_name).mutable_shape(),
931 &op->padding.GetOrCreateFixedPadding());
932 }
933
ProcessResizeBilinearOperator(Model * model,ResizeBilinearOperator * op)934 void ProcessResizeBilinearOperator(Model* model, ResizeBilinearOperator* op) {
935 CHECK_EQ(op->inputs.size(), 2);
936 CHECK_EQ(op->outputs.size(), 1);
937
938 if (!model->GetArray(op->inputs[0]).has_shape() ||
939 !model->GetArray(op->inputs[1]).has_shape()) {
940 return;
941 }
942 const auto& input_data_shape = model->GetArray(op->inputs[0]).shape();
943
944 const string& output_size_name = op->inputs[1];
945 const auto& output_size_array = model->GetArray(output_size_name);
946 CHECK(output_size_array.data_type == ArrayDataType::kInt32);
947 CHECK(output_size_array.has_shape());
948 const auto& output_size_shape = output_size_array.shape();
949 CHECK_EQ(output_size_shape.dimensions_count(), 1);
950 CHECK_EQ(output_size_shape.dims(0), 2);
951 if (!output_size_array.buffer) {
952 return;
953 }
954 std::vector<int32> output_shape =
955 output_size_array.GetBuffer<ArrayDataType::kInt32>().data;
956 model->GetArray(op->outputs[0])
957 .copy_shape(Shape({input_data_shape.dims(0), output_shape[0],
958 output_shape[1], input_data_shape.dims(3)}));
959 }
960
ProcessResizeNearestNeighborOperator(Model * model,ResizeNearestNeighborOperator * op)961 void ProcessResizeNearestNeighborOperator(Model* model,
962 ResizeNearestNeighborOperator* op) {
963 CHECK_EQ(op->inputs.size(), 2);
964 CHECK_EQ(op->outputs.size(), 1);
965
966 if (!model->GetArray(op->inputs[0]).has_shape() ||
967 !model->GetArray(op->inputs[1]).has_shape()) {
968 return;
969 }
970 const auto& input_data_shape = model->GetArray(op->inputs[0]).shape();
971
972 const string& output_size_name = op->inputs[1];
973 const auto& output_size_array = model->GetArray(output_size_name);
974 CHECK(output_size_array.data_type == ArrayDataType::kInt32);
975 CHECK(output_size_array.has_shape());
976 const auto& output_size_shape = output_size_array.shape();
977 CHECK_EQ(output_size_shape.dimensions_count(), 1);
978 CHECK_EQ(output_size_shape.dims(0), 2);
979 if (!output_size_array.buffer) {
980 return;
981 }
982 std::vector<int32> output_shape =
983 output_size_array.GetBuffer<ArrayDataType::kInt32>().data;
984 model->GetArray(op->outputs[0])
985 .copy_shape(Shape({input_data_shape.dims(0), output_shape[0],
986 output_shape[1], input_data_shape.dims(3)}));
987 }
988
ProcessLstmCellOperator(Model * model,LstmCellOperator * op)989 void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
990 // Only required for compact LstmCell with default NUM_INPUTS of inputs.
991 if (op->inputs.size() != LstmCellOperator::NUM_INPUTS) return;
992
993 const auto& input_array =
994 model->GetArray(op->inputs[LstmCellOperator::DATA_INPUT]);
995 // Yield until all input dims have been resolved.
996 if (!input_array.has_shape()) {
997 return;
998 }
999 const auto& input_shape = input_array.shape();
1000 CHECK_GE(input_shape.dimensions_count(), 2);
1001
1002 const auto& prev_activ_array =
1003 model->GetArray(op->inputs[LstmCellOperator::PREV_ACTIV_INPUT]);
1004 // Yield until all input dims have been resolved.
1005 if (!prev_activ_array.has_shape()) {
1006 return;
1007 }
1008 const auto& prev_activ_shape = prev_activ_array.shape();
1009 CHECK_GE(prev_activ_shape.dimensions_count(), 2);
1010
1011 const auto& weights_array =
1012 model->GetArray(op->inputs[LstmCellOperator::WEIGHTS_INPUT]);
1013 // Yield until weights dims have been resolved.
1014 if (!weights_array.has_shape()) {
1015 return;
1016 }
1017 const auto& weights_shape = weights_array.shape();
1018 CHECK_EQ(weights_shape.dimensions_count(), 2);
1019
1020 const auto& bias_array =
1021 model->GetArray(op->inputs[LstmCellOperator::BIASES_INPUT]);
1022 // Yield until bias dims have been resolved.
1023 if (!bias_array.has_shape()) {
1024 return;
1025 }
1026 const auto& bias_shape = bias_array.shape();
1027 CHECK_GE(bias_shape.dimensions_count(), 1);
1028
1029 const auto& prev_state_array =
1030 model->GetArray(op->inputs[LstmCellOperator::PREV_STATE_INPUT]);
1031 // Yield until all input dims have been resolved.
1032 if (!prev_state_array.has_shape()) {
1033 return;
1034 }
1035 const auto& prev_state_shape = prev_state_array.shape();
1036 CHECK_GE(prev_state_shape.dimensions_count(), 2);
1037
1038 const int fc_output_depth = weights_shape.dims(0);
1039 CHECK_EQ(fc_output_depth, bias_shape.dims(0));
1040 CHECK_EQ(fc_output_depth % 4, 0);
1041 const int depth = fc_output_depth / 4;
1042
1043 const int input_depth = input_shape.dims(input_shape.dimensions_count() - 1);
1044 const int fc_input_depth = weights_shape.dims(1);
1045 CHECK_EQ(input_depth + depth, fc_input_depth);
1046 Shape output_shape(input_shape);
1047 (*output_shape.mutable_dims())[output_shape.dimensions_count() - 1] = depth;
1048
1049 // Set output dimensions
1050 model->GetArray(op->outputs[LstmCellOperator::STATE_OUTPUT])
1051 .copy_shape(output_shape);
1052 model->GetArray(op->outputs[LstmCellOperator::ACTIV_OUTPUT])
1053 .copy_shape(output_shape);
1054
1055 Shape concat_temp_shape(input_shape);
1056 (*concat_temp_shape
1057 .mutable_dims())[concat_temp_shape.dimensions_count() - 1] =
1058 fc_input_depth;
1059 model->GetArray(op->outputs[LstmCellOperator::CONCAT_TEMP])
1060 .copy_shape(concat_temp_shape);
1061
1062 Shape activ_temp_shape(input_shape);
1063 (*activ_temp_shape.mutable_dims())[activ_temp_shape.dimensions_count() - 1] =
1064 fc_output_depth;
1065 model->GetArray(op->outputs[LstmCellOperator::ACTIV_TEMP])
1066 .copy_shape(activ_temp_shape);
1067 }
1068
ProcessUnidirectionalSequenceLstmOperator(Model * model,UnidirectionalSequenceLstmOperator * op)1069 void ProcessUnidirectionalSequenceLstmOperator(
1070 Model* model, UnidirectionalSequenceLstmOperator* op) {
1071 auto& output_array = model->GetArray(op->outputs[0]);
1072 if (output_array.has_shape()) {
1073 // Shape already propagated
1074 return;
1075 }
1076
1077 if (output_array.data_type == ArrayDataType::kNone) {
1078 // Yield until the output type has been set by PropagateArrayDataTypes
1079 return;
1080 }
1081
1082 // TODO(renjieliu): check the inputs, as well as all kinds of weights.
1083 const auto& input_array = model->GetArray(op->inputs[0]);
1084
1085 constexpr int kInputActivationStateTensor = 18;
1086 constexpr int kInputCellStateTensor = 19;
1087
1088 // TFlite intepreter does not support array which is variable and contains a
1089 // buffer (see b/115961645 for more discussion).
1090 // The follow block remove buffer from the array to work around the
1091 // restriction, as a consequence, downstream applications should not
1092 // read lstm state as input to other operations.
1093 model->GetArray(op->inputs[kInputActivationStateTensor]).buffer.reset();
1094 model->GetArray(op->inputs[kInputCellStateTensor]).buffer.reset();
1095
1096 // Yield until input dims have been resolved.
1097 if (!input_array.has_shape()) {
1098 return;
1099 }
1100 const auto& input_shape = input_array.shape();
1101 const int batch_size = input_shape.dims(1);
1102 const int timestamp = input_shape.dims(0);
1103
1104 const auto& recurrent_to_output_weights_array =
1105 model->GetArray(op->inputs[8]);
1106 // Yield until input dims have been resolved.
1107 if (!recurrent_to_output_weights_array.has_shape()) {
1108 return;
1109 }
1110
1111 const auto& output_weights_shape = recurrent_to_output_weights_array.shape();
1112 const int output_size = output_weights_shape.dims(1);
1113
1114 Shape* output_shape = output_array.mutable_shape();
1115 output_shape->ReplaceDims({timestamp, batch_size, output_size});
1116 }
1117
ProcessUnidirectionalSequenceRnnOperator(Model * model,UnidirectionalSequenceRnnOperator * op)1118 void ProcessUnidirectionalSequenceRnnOperator(
1119 Model* model, UnidirectionalSequenceRnnOperator* op) {
1120 auto& output_array = model->GetArray(op->outputs[0]);
1121 if (output_array.has_shape()) {
1122 // Shape already propagated.
1123 return;
1124 }
1125
1126 if (output_array.data_type == ArrayDataType::kNone) {
1127 // Yield until the output type has been set by PropagateArrayDataTypes
1128 return;
1129 }
1130
1131 constexpr int kHiddenStateTensor = 4;
1132 // TFlite intepreter does not support array which is variable and contains a
1133 // buffer (see b/115961645 for more discussion).
1134 // The follow block remove buffer from the array to work around the
1135 // restriction, as a consequence, downstream applications should not
1136 // read lstm state as input to other operations.
1137 model->GetArray(op->inputs[kHiddenStateTensor]).buffer.reset();
1138
1139 // TODO(renjieliu): check the inputs, as well as all kinds of weights.
1140 const auto& input_array = model->GetArray(op->inputs[0]);
1141 // Yield until input dims have been resolved.
1142 if (!input_array.has_shape()) {
1143 return;
1144 }
1145 const auto& input_shape = input_array.shape();
1146 const int batch_size = input_shape.dims(1);
1147 const int timestamp = input_shape.dims(0);
1148
1149 const auto& bias_array = model->GetArray(op->inputs[3]);
1150 // Yield until input dims have been resolved.
1151 if (!bias_array.has_shape()) {
1152 return;
1153 }
1154
1155 const auto& bias_shape = bias_array.shape();
1156 const int output_size = bias_shape.dims(0);
1157
1158 Shape* output_shape = output_array.mutable_shape();
1159 output_shape->ReplaceDims({timestamp, batch_size, output_size});
1160 }
1161
ProcessBidirectionalSequenceLstmOperator(Model * model,BidirectionalSequenceLstmOperator * op)1162 void ProcessBidirectionalSequenceLstmOperator(
1163 Model* model, BidirectionalSequenceLstmOperator* op) {
1164 // We assume time major.
1165 auto& fw_output_array = model->GetArray(op->outputs[0]);
1166 auto& bw_output_array = model->GetArray(op->outputs[1]);
1167 if (fw_output_array.has_shape()) {
1168 // Shape already propagated
1169 return;
1170 }
1171
1172 if (fw_output_array.data_type == ArrayDataType::kNone) {
1173 // Yield until the output type has been set by PropagateArrayDataTypes
1174 return;
1175 }
1176
1177 // TODO(renjieliu): check the inputs, as well as all kinds of weights.
1178 const auto& input_array = model->GetArray(op->inputs[0]);
1179 // Yield until input dims have been resolved.
1180 if (!input_array.has_shape()) {
1181 return;
1182 }
1183 const auto& input_shape = input_array.shape();
1184 const int batch_size = input_shape.dims(1);
1185 const int timestamp = input_shape.dims(0);
1186
1187 constexpr int kBwRecurrentToOutputWeightsTensor = 25;
1188 const auto& recurrent_to_output_weights_array =
1189 model->GetArray(op->inputs[kBwRecurrentToOutputWeightsTensor]);
1190 // Yield until input dims have been resolved.
1191 if (!recurrent_to_output_weights_array.has_shape()) {
1192 return;
1193 }
1194
1195 constexpr int kFwInputActivationStateTensor = 35;
1196 constexpr int kFwInputCellStateTensor = 36;
1197 constexpr int kBwInputActivationStateTensor = 37;
1198 constexpr int kBwInputCellStateTensor = 38;
1199 // b(115961645): This is a hack to work around.
1200 model->GetArray(op->inputs[kFwInputActivationStateTensor]).buffer.reset();
1201 model->GetArray(op->inputs[kFwInputCellStateTensor]).buffer.reset();
1202 model->GetArray(op->inputs[kBwInputActivationStateTensor]).buffer.reset();
1203 model->GetArray(op->inputs[kBwInputCellStateTensor]).buffer.reset();
1204
1205 const auto& output_weights_shape = recurrent_to_output_weights_array.shape();
1206 const int output_size = output_weights_shape.dims(1);
1207
1208 Shape* fw_output_shape = fw_output_array.mutable_shape();
1209 if (op->merge_outputs) {
1210 fw_output_shape->ReplaceDims({timestamp, batch_size, 2 * output_size});
1211 } else {
1212 fw_output_shape->ReplaceDims({timestamp, batch_size, output_size});
1213 Shape* bw_output_shape = bw_output_array.mutable_shape();
1214 bw_output_shape->ReplaceDims({timestamp, batch_size, output_size});
1215 }
1216 }
1217
ProcessBidirectionalSequenceRnnOperator(Model * model,BidirectionalSequenceRnnOperator * op)1218 void ProcessBidirectionalSequenceRnnOperator(
1219 Model* model, BidirectionalSequenceRnnOperator* op) {
1220 // We assume time major.
1221 auto& fw_output_array = model->GetArray(op->outputs[0]);
1222 auto& bw_output_array = model->GetArray(op->outputs[1]);
1223 if (fw_output_array.has_shape()) {
1224 // Shape already propagated
1225 return;
1226 }
1227
1228 if (fw_output_array.data_type == ArrayDataType::kNone) {
1229 // Yield until the output type has been set by PropagateArrayDataTypes
1230 return;
1231 }
1232
1233 // TODO(renjieliu): check the inputs, as well as all kinds of weights.
1234 const auto& input_array = model->GetArray(op->inputs[0]);
1235 // Yield until input dims have been resolved.
1236 if (!input_array.has_shape()) {
1237 return;
1238 }
1239 const auto& input_shape = input_array.shape();
1240 const int batch_size = input_shape.dims(1);
1241 const int timestamp = input_shape.dims(0);
1242
1243 constexpr int kFwWeightsTensor = 1;
1244 const auto& forward_weights_array =
1245 model->GetArray(op->inputs[kFwWeightsTensor]);
1246 // Yield until input dims have been resolved.
1247 if (!forward_weights_array.has_shape()) {
1248 return;
1249 }
1250
1251 constexpr int kFwHiddenStateTensor = 4;
1252 constexpr int kBwHiddenStateTensor = 8;
1253 // b(115961645): This is a hack to work around.
1254 model->GetArray(op->inputs[kFwHiddenStateTensor]).buffer.reset();
1255 model->GetArray(op->inputs[kBwHiddenStateTensor]).buffer.reset();
1256
1257 const auto& output_weights_shape = forward_weights_array.shape();
1258 const int output_size = output_weights_shape.dims(0);
1259
1260 Shape* fw_output_shape = fw_output_array.mutable_shape();
1261 if (op->merge_outputs) {
1262 fw_output_shape->ReplaceDims({timestamp, batch_size, 2 * output_size});
1263 } else {
1264 fw_output_shape->ReplaceDims({timestamp, batch_size, output_size});
1265 Shape* bw_output_shape = bw_output_array.mutable_shape();
1266 bw_output_shape->ReplaceDims({timestamp, batch_size, output_size});
1267 }
1268 }
1269
ProcessSpaceToBatchNDOperator(Model * model,SpaceToBatchNDOperator * op)1270 void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) {
1271 const auto& input_array = model->GetArray(op->inputs[0]);
1272 // Yield until input dims have been resolved.
1273 if (!input_array.has_shape()) {
1274 return;
1275 }
1276 const auto& input_shape = input_array.shape();
1277 // This method only handles input dimensions of 4.
1278 if (input_shape.dimensions_count() != 4) {
1279 return;
1280 }
1281 const auto input_height = input_shape.dims(1);
1282 const auto input_width = input_shape.dims(2);
1283
1284 const auto& block_shape_array = model->GetArray(op->inputs[1]);
1285 const auto& paddings_array = model->GetArray(op->inputs[2]);
1286 const auto& block_shape_array_shape = block_shape_array.shape();
1287 const auto& paddings_array_shape = paddings_array.shape();
1288 QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1);
1289 QCHECK_EQ(paddings_array_shape.dimensions_count(), 2);
1290
1291 // We only support two dimensions.
1292 QCHECK_EQ(block_shape_array_shape.dims(0), 2);
1293 if (!block_shape_array.buffer) {
1294 return;
1295 }
1296 QCHECK(block_shape_array.data_type == ArrayDataType::kInt32);
1297 const auto& block_shape_data =
1298 block_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
1299 auto block_height = block_shape_data[0];
1300 auto block_width = block_shape_data[1];
1301
1302 QCHECK_EQ(paddings_array_shape.dims(0), 2); // Number of block dimensions
1303 QCHECK_EQ(paddings_array_shape.dims(1), 2); // Two parameters per dimension.
1304 if (!paddings_array.buffer) {
1305 return;
1306 }
1307 QCHECK(paddings_array.data_type == ArrayDataType::kInt32);
1308 const auto& paddings_data =
1309 paddings_array.GetBuffer<ArrayDataType::kInt32>().data;
1310 int height_with_paddings = input_height + paddings_data[0] + paddings_data[1];
1311 int width_with_paddings = input_width + paddings_data[2] + paddings_data[3];
1312 QCHECK_EQ(height_with_paddings % block_height, 0);
1313 QCHECK_EQ(width_with_paddings % block_width, 0);
1314 int output_height = height_with_paddings / block_height;
1315 int output_width = width_with_paddings / block_width;
1316
1317 model->GetArray(op->outputs[0])
1318 .copy_shape(Shape({input_shape.dims(0) * block_height * block_width,
1319 output_height, output_width, input_shape.dims(3)}));
1320 }
1321
ProcessBatchToSpaceNDOperator(Model * model,BatchToSpaceNDOperator * op)1322 void ProcessBatchToSpaceNDOperator(Model* model, BatchToSpaceNDOperator* op) {
1323 const auto& input_array = model->GetArray(op->inputs[0]);
1324 // Yield until input dims have been resolved.
1325 if (!input_array.has_shape()) {
1326 return;
1327 }
1328 const auto& input_shape = input_array.shape();
1329 CHECK_EQ(input_shape.dimensions_count(), 4);
1330 const auto input_height = input_shape.dims(1);
1331 const auto input_width = input_shape.dims(2);
1332
1333 const auto& block_shape_array = model->GetArray(op->inputs[1]);
1334 const auto& crops_array = model->GetArray(op->inputs[2]);
1335 const auto& block_shape_array_shape = block_shape_array.shape();
1336 const auto& crops_array_shape = crops_array.shape();
1337 QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1);
1338 QCHECK_EQ(crops_array_shape.dimensions_count(), 2);
1339
1340 // We only support two dimensions.
1341 QCHECK_EQ(block_shape_array_shape.dims(0), 2);
1342 if (!block_shape_array.buffer) {
1343 return;
1344 }
1345 QCHECK(block_shape_array.data_type == ArrayDataType::kInt32);
1346 const auto& block_shape_data =
1347 block_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
1348 auto block_height = block_shape_data[0];
1349 auto block_width = block_shape_data[1];
1350
1351 QCHECK_EQ(crops_array_shape.dims(0), 2); // Number of block dimensions
1352 QCHECK_EQ(crops_array_shape.dims(1), 2); // Two parameters per dimension.
1353 if (!crops_array.buffer) {
1354 return;
1355 }
1356 QCHECK(crops_array.data_type == ArrayDataType::kInt32);
1357 const auto& crops_data = crops_array.GetBuffer<ArrayDataType::kInt32>().data;
1358 const int crops_top = crops_data[0];
1359 const int crops_bottom = crops_data[1];
1360 const int crops_left = crops_data[2];
1361 const int crops_right = crops_data[3];
1362 const int output_height =
1363 input_height * block_height - crops_top - crops_bottom;
1364 const int output_width = input_width * block_width - crops_left - crops_right;
1365 QCHECK_EQ(input_shape.dims(0) % (block_height * block_width), 0);
1366
1367 model->GetArray(op->outputs[0])
1368 .copy_shape(Shape({input_shape.dims(0) / (block_height * block_width),
1369 output_height, output_width, input_shape.dims(3)}));
1370 }
1371
ProcessGatherOperator(Model * model,GatherOperator * op)1372 void ProcessGatherOperator(Model* model, GatherOperator* op) {
1373 const auto& input_array = model->GetArray(op->inputs[0]);
1374 const auto& indices_array = model->GetArray(op->inputs[1]);
1375 auto& output_array = model->GetArray(op->outputs[0]);
1376
1377 // Bail if we already know the output shape.
1378 if (output_array.has_shape()) {
1379 return;
1380 }
1381
1382 // Yield until input dims have been resolved.
1383 if (!input_array.has_shape() || !indices_array.has_shape()) {
1384 return;
1385 }
1386
1387 // Yield until the axis has been resolved.
1388 if (!op->axis) {
1389 return;
1390 }
1391 int axis = op->axis.value();
1392
1393 const auto& input_shape = input_array.shape();
1394 const auto& indices_shape = indices_array.shape();
1395 QCHECK_GE(input_shape.dimensions_count(), 1);
1396 op->input_rank = input_shape.dimensions_count();
1397 QCHECK_LT(axis, op->input_rank);
1398
1399 // Copy the input dimensions to the output except for the axis dimensions
1400 // where the dimension of indices_shape is used.
1401 auto output_dims = output_array.mutable_shape()->mutable_dims();
1402 for (int dim = 0; dim < axis; ++dim) {
1403 output_dims->push_back(input_shape.dims(dim));
1404 }
1405 for (int dim = 0; dim < indices_shape.dimensions_count(); ++dim) {
1406 output_dims->push_back(indices_shape.dims(dim));
1407 }
1408 for (int dim = axis + 1; dim < input_shape.dimensions_count(); ++dim) {
1409 output_dims->push_back(input_shape.dims(dim));
1410 }
1411 }
1412
ProcessGatherNdOperator(Model * model,GatherNdOperator * op)1413 void ProcessGatherNdOperator(Model* model, GatherNdOperator* op) {
1414 const auto& input_array = model->GetArray(op->inputs[0]);
1415 const auto& indices_array = model->GetArray(op->inputs[1]);
1416 auto& output_array = model->GetArray(op->outputs[0]);
1417
1418 // Bail if we already know the output shape.
1419 if (output_array.has_shape()) {
1420 return;
1421 }
1422
1423 // Yield until input dims have been resolved.
1424 if (!input_array.has_shape() || !indices_array.has_shape()) {
1425 return;
1426 }
1427
1428 const auto& input_shape = input_array.shape();
1429 const auto& indices_shape = indices_array.shape();
1430 QCHECK_GE(input_shape.dimensions_count(), 1);
1431 QCHECK_GE(indices_shape.dimensions_count(), 1);
1432 const int indices_nd =
1433 indices_shape.dims(indices_shape.dimensions_count() - 1);
1434 QCHECK_LE(indices_nd, input_shape.dimensions_count());
1435
1436 auto output_dims = output_array.mutable_shape()->mutable_dims();
1437 for (int dim = 0; dim < indices_shape.dimensions_count() - 1; ++dim) {
1438 output_dims->push_back(indices_shape.dims(dim));
1439 }
1440 for (int dim = indices_nd; dim < input_shape.dimensions_count(); ++dim) {
1441 output_dims->push_back(input_shape.dims(dim));
1442 }
1443 }
1444
ProcessTopkV2Operator(Model * model,TopKV2Operator * op)1445 void ProcessTopkV2Operator(Model* model, TopKV2Operator* op) {
1446 const auto& input_values = model->GetArray(op->inputs[0]);
1447 const auto& input_k = model->GetArray(op->inputs[1]);
1448 auto& output_values = model->GetArray(op->outputs[0]);
1449 auto& output_indexes = model->GetArray(op->outputs[1]);
1450
1451 // Bail if we already know the output shape.
1452 if (output_indexes.has_shape()) {
1453 QCHECK(output_values.has_shape());
1454 return;
1455 }
1456
1457 // Yield until input dims have been resolved.
1458 if (!input_values.has_shape() || !input_k.has_shape()) {
1459 return;
1460 }
1461
1462 // If the value is initialized, we can specify the last dimension, otherwise
1463 // unknown.
1464 if (input_k.buffer) {
1465 const auto& input_values_shape = input_values.shape();
1466 auto output_indexes_dims = output_indexes.mutable_shape()->mutable_dims();
1467 auto output_values_dims = output_values.mutable_shape()->mutable_dims();
1468 for (int dim = 0; dim < input_values_shape.dimensions_count() - 1; dim++) {
1469 output_indexes_dims->push_back(input_values_shape.dims(dim));
1470 output_values_dims->push_back(input_values_shape.dims(dim));
1471 }
1472 const int32_t k_value = input_k.GetBuffer<ArrayDataType::kInt32>().data[0];
1473 output_indexes_dims->push_back(k_value);
1474 output_values_dims->push_back(k_value);
1475 }
1476 }
1477
ProcessPadOperator(Model * model,PadOperator * op)1478 void ProcessPadOperator(Model* model, PadOperator* op) {
1479 CHECK_EQ(op->inputs.size(), 2);
1480 CHECK_EQ(op->outputs.size(), 1);
1481
1482 const auto& input_array = model->GetArray(op->inputs[0]);
1483
1484 // Yield until input dims have been resolved.
1485 if (!input_array.has_shape()) return;
1486
1487 if (op->left_padding.empty()) return;
1488 CHECK_EQ(op->left_padding.size(), op->right_padding.size());
1489
1490 auto& output_array = model->GetArray(op->outputs[0]);
1491 if (output_array.has_shape()) return;
1492
1493 Shape output_shape = input_array.shape();
1494 std::vector<int>& dims = *output_shape.mutable_dims();
1495 CHECK_EQ(op->left_padding.size(), dims.size());
1496
1497 for (int i = 0; i < op->left_padding.size(); ++i) {
1498 dims[i] += op->left_padding[i] + op->right_padding[i];
1499 }
1500
1501 output_array.copy_shape(output_shape);
1502 }
1503
ProcessPadV2Operator(Model * model,PadV2Operator * op)1504 void ProcessPadV2Operator(Model* model, PadV2Operator* op) {
1505 CHECK_EQ(op->inputs.size(), 3);
1506 CHECK_EQ(op->outputs.size(), 1);
1507
1508 const auto& input_array = model->GetArray(op->inputs[0]);
1509
1510 // Yield until input dims have been resolved.
1511 if (!input_array.has_shape()) return;
1512
1513 if (op->left_padding.empty()) return;
1514 CHECK_EQ(op->left_padding.size(), op->right_padding.size());
1515
1516 auto& output_array = model->GetArray(op->outputs[0]);
1517 if (output_array.has_shape()) return;
1518
1519 Shape output_shape = input_array.shape();
1520 std::vector<int>& dims = *output_shape.mutable_dims();
1521 CHECK_EQ(op->left_padding.size(), dims.size());
1522
1523 for (int i = 0; i < op->left_padding.size(); ++i) {
1524 dims[i] += op->left_padding[i] + op->right_padding[i];
1525 }
1526
1527 output_array.copy_shape(output_shape);
1528 }
1529
ProcessRankOperator(Model * model,TensorFlowRankOperator * op)1530 void ProcessRankOperator(Model* model, TensorFlowRankOperator* op) {
1531 CHECK_GE(op->inputs.size(), 1);
1532 CHECK_EQ(op->outputs.size(), 1);
1533 auto& output_array = model->GetArray(op->outputs[0]);
1534 if (output_array.has_shape()) {
1535 // Shape already propagated
1536 return;
1537 }
1538
1539 if (output_array.data_type == ArrayDataType::kNone) {
1540 // Yield until the output type has been set by PropagateArrayDataTypes
1541 return;
1542 }
1543
1544 const auto& input_array = model->GetArray(op->inputs[0]);
1545 if (!input_array.has_shape()) {
1546 // Yield until input dims have been resolved.
1547 return;
1548 }
1549
1550 // Only set the output shape. Array contents are set by
1551 // ResolveConstantShapeOrRank.
1552 Shape* output_shape = output_array.mutable_shape();
1553 output_shape->ReplaceDims({});
1554 }
1555
ProcessShapeOperator(Model * model,TensorFlowShapeOperator * op)1556 void ProcessShapeOperator(Model* model, TensorFlowShapeOperator* op) {
1557 CHECK_GE(op->inputs.size(), 1);
1558 CHECK_EQ(op->outputs.size(), 1);
1559 auto& output_array = model->GetArray(op->outputs[0]);
1560 if (output_array.has_shape()) {
1561 // Shape already propagated
1562 return;
1563 }
1564
1565 if (output_array.data_type == ArrayDataType::kNone) {
1566 // Yield until the output type has been set by PropagateArrayDataTypes
1567 return;
1568 }
1569
1570 const auto& input_array = model->GetArray(op->inputs[0]);
1571 if (!input_array.has_shape()) {
1572 // Yield until input dims have been resolved.
1573 return;
1574 }
1575
1576 // Only set the output shape. Array contents are set by
1577 // ResolveConstantShapeOrRank.
1578 Shape* output_shape = output_array.mutable_shape();
1579 output_shape->ReplaceDims({input_array.shape().dimensions_count()});
1580 }
1581
ProcessPackOperator(Model * model,PackOperator * op)1582 void ProcessPackOperator(Model* model, PackOperator* op) {
1583 CHECK_GE(op->inputs.size(), 1);
1584 CHECK_EQ(op->outputs.size(), 1);
1585 auto& output_array = model->GetArray(op->outputs[0]);
1586 if (output_array.has_shape()) {
1587 // Shape already propagated
1588 return;
1589 }
1590
1591 std::unique_ptr<Shape> packed_shape;
1592 for (const auto& input : op->inputs) {
1593 const auto& input_array = model->GetArray(input);
1594 if (!input_array.has_shape()) {
1595 // Yield until all input dims have been resolved.
1596 return;
1597 }
1598
1599 Shape shape = input_array.shape();
1600 if (!packed_shape) {
1601 packed_shape.reset(new Shape(shape));
1602 } else {
1603 CHECK(*packed_shape == shape) << "All input arrays to Pack operators "
1604 "must have the same shape. Input \""
1605 << input << "\" is different.";
1606 }
1607 }
1608
1609 int axis = op->axis;
1610 if (axis < 0) {
1611 // Handle negative axis
1612 axis += packed_shape->dims().size() + 1;
1613 }
1614 packed_shape->mutable_dims()->insert(
1615 packed_shape->mutable_dims()->begin() + axis, op->inputs.size());
1616 output_array.copy_shape(*packed_shape);
1617 }
1618
ProcessStridedSliceOperator(Model * model,StridedSliceOperator * op)1619 void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) {
1620 CHECK_GE(op->inputs.size(), 1);
1621 CHECK_EQ(op->outputs.size(), 1);
1622 auto& output_array = model->GetArray(op->outputs[0]);
1623 if (output_array.has_shape()) {
1624 // Shape already propagated
1625 return;
1626 }
1627
1628 if (op->start_indices.empty() || op->stop_indices.empty() ||
1629 op->strides.empty()) {
1630 // ResolveStridedSliceAttributes has not run yet.
1631 return;
1632 }
1633
1634 const auto& input_array = model->GetArray(op->inputs[0]);
1635 if (!input_array.has_shape()) {
1636 // Yield until input dims have been resolved.
1637 return;
1638 }
1639
1640 if (op->ellipsis_mask != 0) {
1641 // Something like LOG_FIRST_N(WARNING, 10) would be prefferable to reduce
1642 // log noise. However, the TensorFlow logging library does not appear to
1643 // support this.
1644 LOG(WARNING) << "Skipping StridedSlice op with output \"" << op->outputs[0]
1645 << "\". ellipsis_mask is not supported (mask="
1646 << op->ellipsis_mask << ")";
1647 return;
1648 }
1649 if (op->new_axis_mask != 0) {
1650 LOG(WARNING) << "Skipping StridedSlice op with output \"" << op->outputs[0]
1651 << "\". new_axis_mask is not supported (mask="
1652 << op->new_axis_mask << ")";
1653 return;
1654 }
1655
1656 int num_input_axes = input_array.shape().dimensions_count();
1657 CHECK_LE(op->start_indices.size(), num_input_axes)
1658 << "StridedSlice op with output \"" << op->outputs[0]
1659 << "\", requires no more than " << num_input_axes << " start indices";
1660 CHECK_LE(op->stop_indices.size(), num_input_axes)
1661 << "StridedSlice op with output \"" << op->outputs[0]
1662 << "\", requires no more than " << num_input_axes << " stop indices";
1663 CHECK_LE(op->strides.size(), num_input_axes)
1664 << "StridedSlice op with output \"" << op->outputs[0]
1665 << "\", requires no more than " << num_input_axes << " strides";
1666 for (int i = 0; i < op->strides.size(); i++) {
1667 CHECK_NE(op->strides[i], 0) << "Strides must be non-zero. Axis " << i
1668 << " has stride=" << op->strides[i] << ".";
1669 }
1670
1671 // Create output shape
1672 std::vector<int>* dims = output_array.mutable_shape()->mutable_dims();
1673
1674 // Compute output shape
1675 for (int axis = 0; axis < num_input_axes; ++axis) {
1676 const auto strided_slice_params =
1677 tflite::strided_slice::BuildStridedSliceParams(
1678 op->begin_mask, op->end_mask, op->shrink_axis_mask,
1679 op->start_indices, op->stop_indices, op->strides);
1680 int start_index = tflite::strided_slice::StartForAxis(
1681 strided_slice_params, ToRuntimeShape(input_array.shape()), axis);
1682 int stop_index = tflite::strided_slice::StopForAxis(
1683 strided_slice_params, ToRuntimeShape(input_array.shape()), axis,
1684 start_index);
1685
1686 int dim_size =
1687 ceil(static_cast<float>(stop_index - start_index) / op->strides[axis]);
1688
1689 CHECK_GT(dim_size, 0)
1690 << "Output size for an axis must be greater than 0. Axis " << axis
1691 << " computes to size " << dim_size
1692 << " for StridedSlice op with output \"" << op->outputs[0] << "\".";
1693 if (op->shrink_axis_mask & (1 << axis)) {
1694 CHECK_EQ(dim_size, 1)
1695 << "Output size for an axis must compute to 1 when shrinking an "
1696 "axis. Axis "
1697 << axis << " computes to size " << dim_size
1698 << " for StridedSlice op with output \"" << op->outputs[0] << "\".";
1699 } else {
1700 dims->push_back(dim_size);
1701 }
1702 }
1703 }
1704
ProcessSqueezeOperator(Model * model,SqueezeOperator * op)1705 void ProcessSqueezeOperator(Model* model, SqueezeOperator* op) {
1706 CHECK_EQ(op->inputs.size(), 1);
1707 CHECK_EQ(op->outputs.size(), 1);
1708
1709 const auto& input_array = model->GetArray(op->inputs[0]);
1710
1711 // Yield until input dims have been resolved.
1712 if (!input_array.has_shape()) return;
1713
1714 auto& output_array = model->GetArray(op->outputs[0]);
1715 if (output_array.has_shape()) return;
1716
1717 const std::vector<int>& input_dims = input_array.shape().dims();
1718 std::vector<int> output_dims;
1719
1720 std::vector<int> squeeze_dims;
1721 const int input_num_dims = input_dims.size();
1722 for (int i : op->squeeze_dims) {
1723 squeeze_dims.push_back(i < 0 ? i + input_num_dims : i);
1724 }
1725 for (int i = 0; i < input_num_dims; ++i) {
1726 if (input_dims[i] != 1 ||
1727 (!squeeze_dims.empty() &&
1728 std::find(squeeze_dims.begin(), squeeze_dims.end(), i) ==
1729 squeeze_dims.end())) {
1730 output_dims.push_back(input_dims[i]);
1731 }
1732 }
1733 *output_array.mutable_shape()->mutable_dims() = output_dims;
1734 }
1735
ProcessSvdfOperator(Model * model,SvdfOperator * op)1736 void ProcessSvdfOperator(Model* model, SvdfOperator* op) {
1737 CHECK(op->inputs.size() == 3 || op->inputs.size() == 4);
1738 const auto& input_array = model->GetArray(op->inputs[0]);
1739 if (!input_array.has_shape()) return;
1740
1741 auto& weights_feature_array = model->GetArray(op->inputs[1]);
1742 if (!weights_feature_array.has_shape()) return;
1743
1744 const auto& weights_time_array = model->GetArray(op->inputs[2]);
1745 if (!weights_time_array.has_shape()) return;
1746
1747 const bool has_bias = (op->inputs.size() == 4);
1748 if (has_bias) {
1749 const auto& bias_array = model->GetArray(op->inputs[3]);
1750 if (!bias_array.has_shape()) return;
1751 }
1752
1753 const int batch_size = input_array.shape().dims()[0];
1754 const int num_units = weights_feature_array.shape().dims()[0];
1755 const int memory_size = weights_time_array.shape().dims()[1];
1756
1757 auto& state_array = model->GetArray(op->outputs[0]);
1758 state_array.mutable_shape()->ReplaceDims(
1759 {batch_size, memory_size * num_units});
1760
1761 auto& output_array = model->GetArray(op->outputs[1]);
1762 output_array.mutable_shape()->ReplaceDims({batch_size, num_units});
1763 }
1764
ProcessTransposeOperator(Model * model,TransposeOperator * op)1765 void ProcessTransposeOperator(Model* model, TransposeOperator* op) {
1766 auto& output_array = model->GetArray(op->outputs[0]);
1767 if (output_array.has_shape()) {
1768 // We have already run
1769 return;
1770 }
1771
1772 const auto& input_array = model->GetArray(op->inputs[0]);
1773 if (!input_array.has_shape()) {
1774 // Yield until input dims have been resolved.
1775 return;
1776 }
1777 const auto& input_shape = input_array.shape();
1778
1779 auto& perm_array = model->GetArray(op->inputs[1]);
1780 if (!perm_array.has_shape()) {
1781 // Yield until permutation shape been resolved.
1782 return;
1783 }
1784 if (!perm_array.buffer) {
1785 // Yield until the permutation is constant
1786 return;
1787 }
1788 CHECK(perm_array.data_type == ArrayDataType::kInt32)
1789 << "Transpose permutation input must be int32";
1790
1791 std::vector<int32> const& perm =
1792 perm_array.GetBuffer<ArrayDataType::kInt32>().data;
1793 CHECK_EQ(perm.size(), input_shape.dimensions_count())
1794 << "Transpose permutation input " << op->inputs[1]
1795 << " must be same length as input dimensions";
1796 std::vector<int>* output_dims = output_array.mutable_shape()->mutable_dims();
1797 for (int i = 0; i < perm.size(); i++) {
1798 int axis = perm[i];
1799 CHECK_GE(axis, 0);
1800 CHECK_LT(axis, input_shape.dimensions_count());
1801 output_dims->push_back(input_shape.dims(axis));
1802 }
1803 }
1804
1805 template <typename Op>
ProcessArgMinMaxOperator(Model * model,Op * op)1806 void ProcessArgMinMaxOperator(Model* model, Op* op) {
1807 CHECK_EQ(op->inputs.size(), 2);
1808 const auto& input_array = model->GetArray(op->inputs[0]);
1809 // Yield until input dims have been resolved.
1810 if (!input_array.has_shape()) {
1811 return;
1812 }
1813
1814 const Array& axis_array = model->GetArray(op->inputs[1]);
1815 // Yield until input axis array shape has been resolved.
1816 if (!axis_array.has_shape()) {
1817 return;
1818 }
1819
1820 const std::vector<int>& input_dims = input_array.shape().dims();
1821
1822 CHECK(axis_array.data_type == ArrayDataType::kInt32 ||
1823 axis_array.data_type == ArrayDataType::kInt64)
1824 << "axis_array must be int32, int64";
1825
1826 CHECK_EQ(RequiredBufferSizeForShape(axis_array.shape()), 1)
1827 << "Axis array must be scalar.";
1828
1829 int64 axis;
1830 if (axis_array.data_type == ArrayDataType::kInt32) {
1831 axis = axis_array.GetBuffer<ArrayDataType::kInt32>().data[0];
1832 } else {
1833 axis = axis_array.GetBuffer<ArrayDataType::kInt64>().data[0];
1834 }
1835
1836 std::vector<int> output_dims;
1837
1838 output_dims.reserve(input_dims.size() - 1);
1839 for (int i = 0; i < input_dims.size(); ++i) {
1840 if (i != axis) {
1841 output_dims.push_back(input_dims[i]);
1842 }
1843 }
1844
1845 const string& output_name = op->outputs[0];
1846 auto& output_array = model->GetArray(output_name);
1847 if (output_array.has_shape()) {
1848 return;
1849 }
1850 *output_array.mutable_shape()->mutable_dims() = output_dims;
1851 }
1852
ProcessSparseToDenseOperator(Model * model,SparseToDenseOperator * op)1853 void ProcessSparseToDenseOperator(Model* model, SparseToDenseOperator* op) {
1854 CHECK_EQ(op->inputs.size(), 4);
1855
1856 const Array& output_shape_array = model->GetArray(op->inputs[1]);
1857 if (!output_shape_array.has_shape()) return;
1858 CHECK_EQ(output_shape_array.shape().dimensions_count(), 1);
1859
1860 // Output should not go over four dimensions.
1861 CHECK_LE(output_shape_array.shape().dims(0), 4);
1862
1863 const string& output_name = op->outputs[0];
1864 Array& output_array = model->GetArray(output_name);
1865 if (output_array.has_shape()) return;
1866
1867 CHECK(output_shape_array.data_type == ArrayDataType::kInt32 ||
1868 output_shape_array.data_type == ArrayDataType::kInt64);
1869 if (output_shape_array.data_type == ArrayDataType::kInt32) {
1870 *output_array.mutable_shape()->mutable_dims() =
1871 output_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
1872 } else {
1873 const std::vector<int64>& output_shape_data =
1874 output_shape_array.GetBuffer<ArrayDataType::kInt64>().data;
1875 std::copy(
1876 output_shape_data.begin(), output_shape_data.end(),
1877 std::back_inserter(*output_array.mutable_shape()->mutable_dims()));
1878 }
1879 }
1880
ProcessTileOperator(Model * model,TensorFlowTileOperator * op)1881 void ProcessTileOperator(Model* model, TensorFlowTileOperator* op) {
1882 CHECK_EQ(op->inputs.size(), 2);
1883 CHECK_EQ(op->outputs.size(), 1);
1884
1885 auto& output_array = model->GetArray(op->outputs[0]);
1886 if (output_array.has_shape()) {
1887 // We have already run.
1888 return;
1889 }
1890
1891 const auto& input_array = model->GetArray(op->inputs[0]);
1892 if (!input_array.has_shape()) {
1893 // Yield until input dims have been resolved.
1894 return;
1895 }
1896 const auto& input_shape = input_array.shape();
1897
1898 auto& multiples_array = model->GetArray(op->inputs[1]);
1899 if (!multiples_array.has_shape()) {
1900 // Yield until multiples shape been resolved.
1901 return;
1902 }
1903 if (!multiples_array.buffer) {
1904 // Yield until the multiples is constant.
1905 return;
1906 }
1907 CHECK(multiples_array.data_type == ArrayDataType::kInt32)
1908 << "Tile multiples input must be int32";
1909
1910 std::vector<int32> const& multiples =
1911 multiples_array.GetBuffer<ArrayDataType::kInt32>().data;
1912 CHECK_EQ(multiples.size(), input_shape.dimensions_count())
1913 << "Tile multiples input " << op->inputs[1]
1914 << " must be same length as input dimensions";
1915
1916 auto* mutable_dims = output_array.mutable_shape()->mutable_dims();
1917 mutable_dims->resize(multiples.size());
1918 for (int i = 0; i < mutable_dims->size(); ++i) {
1919 (*mutable_dims)[i] = input_shape.dims(i) * multiples[i];
1920 }
1921 }
1922
ProcessOneHotOperator(Model * model,OneHotOperator * op)1923 void ProcessOneHotOperator(Model* model, OneHotOperator* op) {
1924 CHECK_EQ(op->inputs.size(), 4);
1925 CHECK_EQ(op->outputs.size(), 1);
1926 auto& output_array = model->GetArray(op->outputs[0]);
1927 if (output_array.has_shape()) {
1928 // Shape already propagated
1929 return;
1930 }
1931
1932 // Yield until indices dims have been resolved.
1933 const auto& indices_array =
1934 model->GetArray(op->inputs[OneHotOperator::INDICES_INPUT]);
1935 if (!indices_array.has_shape()) {
1936 return;
1937 }
1938
1939 // Yield until depth is constant and dims have been resolved.
1940 if (!IsConstantParameterArray(*model,
1941 op->inputs[OneHotOperator::DEPTH_INPUT])) {
1942 return;
1943 }
1944 const auto& depth_array =
1945 model->GetArray(op->inputs[OneHotOperator::DEPTH_INPUT]);
1946 if (!depth_array.has_shape()) {
1947 return;
1948 }
1949
1950 CHECK(depth_array.data_type == ArrayDataType::kInt32)
1951 << "Depth array must be int32.";
1952 CHECK_EQ(RequiredBufferSizeForShape(depth_array.shape()), 1)
1953 << "Depth array must be scalar.";
1954
1955 const int depth = depth_array.GetBuffer<ArrayDataType::kInt32>().data[0];
1956 CHECK_GE(depth, 0) << "Depth must be non-negative.";
1957
1958 const int indices_dims = indices_array.shape().dimensions_count();
1959 const int output_dims = indices_dims + 1;
1960 const int axis = op->axis == -1 ? indices_dims : op->axis;
1961 CHECK_GE(axis, 0) << "Resolved axis must be non-negative.";
1962
1963 auto* mutable_dims = output_array.mutable_shape()->mutable_dims();
1964 mutable_dims->resize(output_dims);
1965 for (int i = 0; i < output_dims; ++i) {
1966 int dim = 0;
1967 if (i < axis) {
1968 dim = indices_array.shape().dims(i);
1969 } else if (i == axis) {
1970 dim = depth;
1971 } else {
1972 dim = indices_array.shape().dims(i - 1);
1973 }
1974 (*mutable_dims)[i] = dim;
1975 }
1976 }
1977
ProcessUnpackOperator(Model * model,UnpackOperator * op)1978 void ProcessUnpackOperator(Model* model, UnpackOperator* op) {
1979 CHECK_EQ(op->inputs.size(), 1);
1980 const auto& input_array = model->GetArray(op->inputs[0]);
1981 // Yield until input dims have been resolved.
1982 if (!input_array.has_shape()) {
1983 return;
1984 }
1985
1986 const std::vector<int>& input_dims = input_array.shape().dims();
1987 std::vector<int> output_dims;
1988
1989 output_dims.reserve(input_dims.size() - 1);
1990 for (int i = 0; i < input_dims.size(); ++i) {
1991 if (i != op->axis) {
1992 output_dims.push_back(input_dims[i]);
1993 }
1994 }
1995 for (const string& output_name : op->outputs) {
1996 auto& output_array = model->GetArray(output_name);
1997 if (output_array.has_shape()) {
1998 return;
1999 }
2000 *output_array.mutable_shape()->mutable_dims() = output_dims;
2001 }
2002 }
2003
ProcessMirrorPadOperator(Model * model,MirrorPadOperator * op)2004 void ProcessMirrorPadOperator(Model* model, MirrorPadOperator* op) {
2005 CHECK_EQ(op->inputs.size(), 2);
2006 const auto& input_array = model->GetArray(op->inputs[0]);
2007 const auto& padding_matrix = model->GetArray(op->inputs[1]);
2008
2009 // Yield until input dims have been resolved.
2010 if (!input_array.has_shape()) {
2011 return;
2012 }
2013
2014 auto& output_array = model->GetArray(op->outputs[0]);
2015 // If output already computed or padding matrix is non
2016 // const then return.
2017 if (output_array.has_shape() ||
2018 !IsConstantParameterArray(*model, op->inputs[1])) {
2019 return;
2020 }
2021 Shape output_shape = input_array.shape();
2022 std::vector<int>& dims = *output_shape.mutable_dims();
2023
2024 std::vector<int64_t> padding;
2025 if (padding_matrix.data_type == ArrayDataType::kInt32) {
2026 const auto& data = padding_matrix.GetBuffer<ArrayDataType::kInt32>().data;
2027 for (auto elem : data) {
2028 padding.push_back(static_cast<int64_t>(elem));
2029 }
2030 } else if (padding_matrix.data_type == ArrayDataType::kInt64) {
2031 const auto& data = padding_matrix.GetBuffer<ArrayDataType::kInt64>().data;
2032 for (auto elem : data) {
2033 padding.push_back(elem);
2034 }
2035 } else {
2036 CHECK(padding_matrix.data_type == ArrayDataType::kInt64 ||
2037 padding_matrix.data_type == ArrayDataType::kInt32);
2038 }
2039 CHECK_EQ(padding_matrix.shape().dimensions_count(), 2);
2040 CHECK_EQ(input_array.shape().dimensions_count(),
2041 padding_matrix.shape().dims(0));
2042 for (int i = 0; i < input_array.shape().dimensions_count(); ++i) {
2043 dims[i] += padding[i * 2] + padding[i * 2 + 1];
2044 }
2045
2046 output_array.copy_shape(output_shape);
2047 }
2048
ProcessUniqueOperator(Model * model,UniqueOperator * op)2049 void ProcessUniqueOperator(Model* model, UniqueOperator* op) {
2050 const auto& input_array = model->GetArray(op->inputs[0]);
2051 // We have 2 outputs, the shape of the index tensor, is the same size
2052 // as the input array. The unique values tensor, is unknown until runtime.
2053 CHECK_EQ(op->outputs.size(), 2);
2054 auto& idx_output_array = model->GetArray(op->outputs[1]);
2055
2056 // Yield until input dims have been resolved, or output already computed
2057 if (!input_array.has_shape() || idx_output_array.has_shape()) {
2058 return;
2059 }
2060 idx_output_array.copy_shape(input_array.shape());
2061 }
2062
2063 } // namespace
2064
Run(Model * model,std::size_t op_index,bool * modified)2065 ::tensorflow::Status PropagateFixedSizes::Run(Model* model,
2066 std::size_t op_index,
2067 bool* modified) {
2068 *modified = false;
2069 auto it = model->operators.begin() + op_index;
2070 auto* op = it->get();
2071 std::unordered_map<string, std::vector<int>> old_output_dims;
2072 for (const auto& output : op->outputs) {
2073 if (model->GetArray(output).has_shape()) {
2074 old_output_dims[output] = model->GetArray(output).shape().dims();
2075 }
2076 }
2077
2078 switch (op->type) {
2079 case OperatorType::kAbs:
2080 case OperatorType::kBatchNormalization:
2081 case OperatorType::kL2Normalization:
2082 case OperatorType::kDequantize:
2083 case OperatorType::kElu:
2084 case OperatorType::kRelu:
2085 case OperatorType::kRelu1:
2086 case OperatorType::kRelu6:
2087 case OperatorType::kPRelu:
2088 case OperatorType::kLeakyRelu:
2089 case OperatorType::kSoftmax:
2090 case OperatorType::kLogSoftmax:
2091 case OperatorType::kLog:
2092 case OperatorType::kLogistic:
2093 case OperatorType::kTanh:
2094 case OperatorType::kLocalResponseNormalization:
2095 case OperatorType::kIdentity:
2096 case OperatorType::kFakeQuant:
2097 case OperatorType::kNeg:
2098 case OperatorType::kRsqrt:
2099 case OperatorType::kSqrt:
2100 case OperatorType::kSquare:
2101 case OperatorType::kAll:
2102 case OperatorType::kAssert:
2103 case OperatorType::kCast:
2104 case OperatorType::kFloor:
2105 case OperatorType::kCeil:
2106 case OperatorType::kExp:
2107 case OperatorType::kSin:
2108 case OperatorType::kCos:
2109 case OperatorType::kLogicalAnd:
2110 case OperatorType::kLogicalNot:
2111 case OperatorType::kLogicalOr:
2112 case OperatorType::kZerosLike:
2113 case OperatorType::kReverseV2:
2114 case OperatorType::kReverseSequence:
2115 ProcessSimpleOperator(model, op, 0);
2116 break;
2117 case OperatorType::kGather:
2118 ProcessGatherOperator(model, static_cast<GatherOperator*>(op));
2119 break;
2120 case OperatorType::kGatherNd:
2121 ProcessGatherNdOperator(model, static_cast<GatherNdOperator*>(op));
2122 break;
2123 case OperatorType::kTopK_V2:
2124 ProcessTopkV2Operator(model, static_cast<TopKV2Operator*>(op));
2125 break;
2126 case OperatorType::kAdd:
2127 case OperatorType::kSub:
2128 case OperatorType::kMul:
2129 case OperatorType::kDiv:
2130 case OperatorType::kFloorDiv:
2131 case OperatorType::kFloorMod:
2132 case OperatorType::kLess:
2133 case OperatorType::kLessEqual:
2134 case OperatorType::kGreater:
2135 case OperatorType::kMaximum: // Element-wise Maximum
2136 case OperatorType::kMinimum: // Element-wise Minimum
2137 case OperatorType::kGreaterEqual:
2138 case OperatorType::kEqual:
2139 case OperatorType::kNotEqual:
2140 case OperatorType::kPow:
2141 case OperatorType::kSquaredDifference:
2142 ProcessSimpleBinaryOperator(model, op);
2143 break;
2144 case OperatorType::kAddN:
2145 ProcessAddNOperator(model, op);
2146 break;
2147 case OperatorType::kConv:
2148 ProcessConvOperator(model, static_cast<ConvOperator*>(op));
2149 break;
2150 case OperatorType::kTransposeConv:
2151 ProcessTransposeConvOperator(model,
2152 static_cast<TransposeConvOperator*>(op));
2153 break;
2154 case OperatorType::kDepthwiseConv:
2155 ProcessDepthwiseConvOperator(model,
2156 static_cast<DepthwiseConvOperator*>(op));
2157 break;
2158 case OperatorType::kDepthToSpace:
2159 ProcessDepthToSpaceOperator(model,
2160 static_cast<DepthToSpaceOperator*>(op));
2161 break;
2162 case OperatorType::kSpaceToDepth:
2163 ProcessSpaceToDepthOperator(model,
2164 static_cast<SpaceToDepthOperator*>(op));
2165 break;
2166 case OperatorType::kFill:
2167 CHECK_EQ(op->inputs.size(), 2);
2168 ProcessOpWithShapeInput(model, op);
2169 break;
2170 case OperatorType::kFullyConnected:
2171 ProcessFullyConnectedOperator(model,
2172 static_cast<FullyConnectedOperator*>(op));
2173 break;
2174 case OperatorType::kReshape:
2175 ProcessTensorFlowReshapeOperator(
2176 model, static_cast<TensorFlowReshapeOperator*>(op));
2177 break;
2178 case OperatorType::kAveragePool:
2179 ProcessAveragePoolOperator(model, static_cast<AveragePoolOperator*>(op));
2180 break;
2181 case OperatorType::kMaxPool:
2182 ProcessMaxPoolOperator(model, static_cast<MaxPoolOperator*>(op));
2183 break;
2184 case OperatorType::kL2Pool:
2185 ProcessL2PoolOperator(model, static_cast<L2PoolOperator*>(op));
2186 break;
2187 case OperatorType::kReduceMin: // Reduction Min
2188 case OperatorType::kReduceMax: // Reduction Max
2189 case OperatorType::kSum:
2190 case OperatorType::kReduceProd:
2191 case OperatorType::kMean:
2192 case OperatorType::kAny:
2193 ProcessTensorFlowReductionOperator(model, op);
2194 break;
2195 case OperatorType::kSelect:
2196 ProcessSelectOperator(model, static_cast<SelectOperator*>(op));
2197 break;
2198 case OperatorType::kSlice:
2199 ProcessSliceOperator(model, static_cast<SliceOperator*>(op));
2200 break;
2201
2202 case OperatorType::kSwitch:
2203 // We can't know the sizes of the outputs until we have resolved the
2204 // predicate, and once we have resolved the predicate, the whole
2205 // Switch node will get resolved away.
2206 // See ResolveTensorFlowSwitch.
2207 break;
2208 case OperatorType::kMerge:
2209 // No need to bother resolving TensorFlow Merge ops: other graph
2210 // transformations will remove them anyway.
2211 // See ResolveTensorFlowMerge.
2212 break;
2213 case OperatorType::kSplit:
2214 ProcessTensorFlowSplitOperator(model,
2215 static_cast<TensorFlowSplitOperator*>(op));
2216 break;
2217 case OperatorType::kSplitV:
2218 ProcessTensorFlowSplitVOperator(
2219 model, static_cast<TensorFlowSplitVOperator*>(op));
2220 break;
2221 case OperatorType::kSqueeze:
2222 ProcessSqueezeOperator(model, static_cast<SqueezeOperator*>(op));
2223 break;
2224 case OperatorType::kConcat:
2225 case OperatorType::kConcatV2:
2226 // Unimplemented, hopefully another graph transformation will
2227 // drop it or rewrite it. Concretely, either ResolveTensorFlowConcat
2228 // will resolve this node to a DepthConcatenation, or else we have
2229 // a more general non-depth concatenation that will hopefully be dropped,
2230 // or else at the moment we will abort.
2231 break;
2232 case OperatorType::kExpandDims:
2233 // Yield until ExpandDims is converted to Reshape
2234 break;
2235 case OperatorType::kRange:
2236 ProcessRangeOperator(model, static_cast<RangeOperator*>(op));
2237 break;
2238 case OperatorType::kRank:
2239 ProcessRankOperator(model, static_cast<TensorFlowRankOperator*>(op));
2240 break;
2241 case OperatorType::kShape:
2242 ProcessShapeOperator(model, static_cast<TensorFlowShapeOperator*>(op));
2243 break;
2244 case OperatorType::kPack:
2245 ProcessPackOperator(model, static_cast<PackOperator*>(op));
2246 break;
2247 case OperatorType::kReorderAxes:
2248 ProcessReorderAxesOperator(model, static_cast<ReorderAxesOperator*>(op));
2249 break;
2250 case OperatorType::kConcatenation:
2251 ProcessConcatenationOperator(model,
2252 static_cast<ConcatenationOperator*>(op));
2253 break;
2254 case OperatorType::kResizeBilinear:
2255 ProcessResizeBilinearOperator(model,
2256 static_cast<ResizeBilinearOperator*>(op));
2257 break;
2258 case OperatorType::kResizeNearestNeighbor:
2259 ProcessResizeNearestNeighborOperator(
2260 model, static_cast<ResizeNearestNeighborOperator*>(op));
2261 break;
2262 case OperatorType::kUnidirectionalSequenceLstm:
2263 ProcessUnidirectionalSequenceLstmOperator(
2264 model, static_cast<UnidirectionalSequenceLstmOperator*>(op));
2265 break;
2266 case OperatorType::kUnidirectionalSequenceRnn:
2267 ProcessUnidirectionalSequenceRnnOperator(
2268 model, static_cast<UnidirectionalSequenceRnnOperator*>(op));
2269 break;
2270 case OperatorType::kBidirectionalSequenceLstm:
2271 ProcessBidirectionalSequenceLstmOperator(
2272 model, static_cast<BidirectionalSequenceLstmOperator*>(op));
2273 break;
2274 case OperatorType::kBidirectionalSequenceRnn:
2275 ProcessBidirectionalSequenceRnnOperator(
2276 model, static_cast<BidirectionalSequenceRnnOperator*>(op));
2277 break;
2278 case OperatorType::kLstmCell:
2279 ProcessLstmCellOperator(model, static_cast<LstmCellOperator*>(op));
2280 break;
2281 case OperatorType::kBatchMatMul:
2282 case OperatorType::kMatMul:
2283 // MatMul operators are converted to FullyConnected, after which their
2284 // shapes are propagated.
2285 break;
2286 case OperatorType::kSpaceToBatchND:
2287 ProcessSpaceToBatchNDOperator(model,
2288 static_cast<SpaceToBatchNDOperator*>(op));
2289 break;
2290 case OperatorType::kBatchToSpaceND:
2291 ProcessBatchToSpaceNDOperator(model,
2292 static_cast<BatchToSpaceNDOperator*>(op));
2293 break;
2294 case OperatorType::kPad:
2295 ProcessPadOperator(model, static_cast<PadOperator*>(op));
2296 break;
2297 case OperatorType::kPadV2:
2298 ProcessPadV2Operator(model, static_cast<PadV2Operator*>(op));
2299 break;
2300 case OperatorType::kStridedSlice:
2301 ProcessStridedSliceOperator(model,
2302 static_cast<StridedSliceOperator*>(op));
2303 break;
2304 case OperatorType::kArgMax:
2305 ProcessArgMinMaxOperator<ArgMaxOperator>(
2306 model, static_cast<ArgMaxOperator*>(op));
2307 break;
2308 case OperatorType::kArgMin:
2309 ProcessArgMinMaxOperator<ArgMinOperator>(
2310 model, static_cast<ArgMinOperator*>(op));
2311 break;
2312 case OperatorType::kUnsupported: {
2313 const auto* unsupported_op =
2314 static_cast<TensorFlowUnsupportedOperator*>(op);
2315 // Attribute can be not specified, ignore it.
2316 if (unsupported_op->output_shapes.size() < op->outputs.size()) {
2317 return ::tensorflow::Status::OK();
2318 }
2319 for (int i = 0; i < op->outputs.size(); ++i) {
2320 const string& output = op->outputs[i];
2321 model->GetArray(output).copy_shape(unsupported_op->output_shapes.at(i));
2322 }
2323 break;
2324 }
2325 case OperatorType::kSvdf:
2326 ProcessSvdfOperator(model, static_cast<SvdfOperator*>(op));
2327 break;
2328 case OperatorType::kTranspose:
2329 ProcessTransposeOperator(model, static_cast<TransposeOperator*>(op));
2330 break;
2331 case OperatorType::kDynamicPartition:
2332 case OperatorType::kDynamicStitch:
2333 // DynamicPartition/DynamicStitch are currently only supported for
2334 // transforms that remove them, so we avoid propagating shapes through
2335 // them and let things settle once they've been removed.
2336 break;
2337 case OperatorType::kRandomUniform:
2338 CHECK_EQ(op->inputs.size(), 1);
2339 ProcessOpWithShapeInput(model, op);
2340 break;
2341 case OperatorType::kSparseToDense:
2342 ProcessSparseToDenseOperator(model,
2343 static_cast<SparseToDenseOperator*>(op));
2344 break;
2345 case OperatorType::kTile:
2346 ProcessTileOperator(model, static_cast<TensorFlowTileOperator*>(op));
2347 break;
2348 break;
2349 case OperatorType::kOneHot:
2350 ProcessOneHotOperator(model, static_cast<OneHotOperator*>(op));
2351 break;
2352 case OperatorType::kUnpack:
2353 ProcessUnpackOperator(model, static_cast<UnpackOperator*>(op));
2354 break;
2355 case OperatorType::kMirrorPad:
2356 ProcessMirrorPadOperator(model, static_cast<MirrorPadOperator*>(op));
2357 break;
2358 case OperatorType::kUnique:
2359 ProcessUniqueOperator(model, static_cast<UniqueOperator*>(op));
2360 break;
2361 case OperatorType::kWhere:
2362 // The size of the output can only be known after evaluating the cond
2363 // tensor. Ignore shape propagation here and defer that to the
2364 // interpreter.
2365 break;
2366 default:
2367 // Unimplemented, another graph transformation should drop it.
2368 LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
2369 }
2370
2371 // Return true if any output dim changed, false if none changed.
2372 // Assumption: no transformation clears an output shape, they only add shapes.
2373 for (const auto& output : op->outputs) {
2374 if (model->GetArray(output).has_shape() &&
2375 (old_output_dims[output] != model->GetArray(output).shape().dims())) {
2376 AddMessageF("Set shape of %s to [%s]", output,
2377 absl::StrJoin(model->GetArray(output).shape().dims(), ","));
2378 *modified = true;
2379 return ::tensorflow::Status::OK();
2380 }
2381 }
2382 return ::tensorflow::Status::OK();
2383 }
2384
2385 } // namespace toco
2386