1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <memory>
16 #include <string>
17 #include <unordered_map>
18 #include <vector>
19
20 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
21 #include "tensorflow/lite/toco/model.h"
22 #include "tensorflow/lite/toco/tooling_util.h"
23 #include "tensorflow/core/platform/logging.h"
24
25 namespace toco {
26
Run(Model * model,std::size_t op_index,bool * modified)27 ::tensorflow::Status ConvertPureConvToDepthwise::Run(Model* model,
28 std::size_t op_index,
29 bool* modified) {
30 *modified = false;
31 auto conv_it = model->operators.begin() + op_index;
32 if (conv_it->get()->type != OperatorType::kConv) {
33 return ::tensorflow::Status::OK();
34 }
35 const auto* conv_op = static_cast<ConvOperator*>(conv_it->get());
36 if (conv_op->stride_width != conv_op->stride_height) {
37 return ::tensorflow::Status::OK();
38 }
39 if ((conv_op->dilation_width_factor != 1) ||
40 (conv_op->dilation_height_factor != 1)) {
41 // Depthwise conv does not support dilation
42 return ::tensorflow::Status::OK();
43 }
44 auto& input_array = model->GetArray(conv_op->inputs[0]);
45 if (!input_array.has_shape()) {
46 // Shapes not propagated yet
47 return ::tensorflow::Status::OK();
48 }
49 if (input_array.shape().dims(3) != 1) {
50 // Not a pure convolution: Conv does accumulation across the depth
51 // dimension.
52 return ::tensorflow::Status::OK();
53 }
54
55 const auto& weights_name = conv_op->inputs[1];
56 if (CountOpsWithInput(*model, weights_name) > 1) {
57 // TODO(yunluli): Come up with a way to do the weights shuffling only once.
58 AddMessageF(
59 "Not changing %s to DepthwiseConv because the weights is consumed by "
60 "another op.",
61 LogName(*conv_op));
62 return ::tensorflow::Status::OK();
63 }
64 auto& weights_array = model->GetArray(weights_name);
65 if (!weights_array.buffer) {
66 // Yield until the weights are resolved as a constant array.
67 return ::tensorflow::Status::OK();
68 }
69 if (weights_array.data_type != ArrayDataType::kFloat) {
70 return ::tensorflow::Status::OK();
71 }
72 // At this point we know we have a pure conv. Rewrite it as DepthwiseConv.
73 AddMessageF(
74 "%s is purely convolutional (input/weights depth is 1), replacing it by "
75 "a DepthwiseConv.",
76 LogName(*conv_op));
77 auto* depthwiseconv_op = new DepthwiseConvOperator;
78 // Conv and DepthwiseConv take the same inputs
79 depthwiseconv_op->inputs = conv_op->inputs;
80 // Conv may have a 2nd output for im2col
81 depthwiseconv_op->outputs = {conv_op->outputs[0]};
82 if (conv_op->outputs.size() > 1) {
83 // delete the im2col array.
84 model->EraseArray(conv_op->outputs[1]);
85 }
86 depthwiseconv_op->fused_activation_function =
87 conv_op->fused_activation_function;
88 // Let PropagateFixedSizes recompute fixed padding, just in case some day it
89 // may be different for Conv vs DepthwiseConv.
90 depthwiseconv_op->padding.type = conv_op->padding.type;
91 depthwiseconv_op->stride_height = conv_op->stride_height;
92 depthwiseconv_op->stride_width = conv_op->stride_width;
93 depthwiseconv_op->depth_multiplier = weights_array.shape().dims(0);
94 // Replace the operator in the graph.
95 const auto depthwiseconv_it =
96 model->operators.emplace(conv_it, depthwiseconv_op);
97 conv_it = depthwiseconv_it + 1;
98 CHECK_EQ(conv_it->get(), conv_op);
99 model->operators.erase(conv_it);
100 // Shuffle the weights.
101 const auto& weights_shape = weights_array.shape();
102 auto& weights_buffer =
103 weights_array.GetMutableBuffer<ArrayDataType::kFloat>();
104 const std::vector<float>& conv_weights_data = weights_buffer.data;
105 std::vector<float> depthwise_conv_weights_data(conv_weights_data.size());
106 const int depth = weights_shape.dims(0);
107 const int width = weights_shape.dims(1);
108 const int height = weights_shape.dims(2);
109 const int width_height = width * height;
110 for (int c = 0; c < depth; c++) {
111 for (int xy = 0; xy < width_height; xy++) {
112 depthwise_conv_weights_data[c + depth * xy] =
113 conv_weights_data[xy + width_height * c];
114 }
115 }
116 *weights_array.mutable_shape()->mutable_dims() = {1, width, height, depth};
117 weights_buffer.data = depthwise_conv_weights_data;
118 *modified = true;
119 return ::tensorflow::Status::OK();
120 }
121
122 } // namespace toco
123