1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"
24 
25 namespace toco {
26 
Run(Model * model,std::size_t op_index,bool * modified)27 ::tensorflow::Status ConvertPureConvToDepthwise::Run(Model* model,
28                                                      std::size_t op_index,
29                                                      bool* modified) {
30   *modified = false;
31   auto conv_it = model->operators.begin() + op_index;
32   if (conv_it->get()->type != OperatorType::kConv) {
33     return ::tensorflow::Status::OK();
34   }
35   const auto* conv_op = static_cast<ConvOperator*>(conv_it->get());
36   if (conv_op->stride_width != conv_op->stride_height) {
37     return ::tensorflow::Status::OK();
38   }
39   if ((conv_op->dilation_width_factor != 1) ||
40       (conv_op->dilation_height_factor != 1)) {
41     // Depthwise conv does not support dilation
42     return ::tensorflow::Status::OK();
43   }
44   auto& input_array = model->GetArray(conv_op->inputs[0]);
45   if (!input_array.has_shape()) {
46     // Shapes not propagated yet
47     return ::tensorflow::Status::OK();
48   }
49   if (input_array.shape().dims(3) != 1) {
50     // Not a pure convolution: Conv does accumulation across the depth
51     // dimension.
52     return ::tensorflow::Status::OK();
53   }
54 
55   const auto& weights_name = conv_op->inputs[1];
56   if (CountOpsWithInput(*model, weights_name) > 1) {
57     // TODO(yunluli): Come up with a way to do the weights shuffling only once.
58     AddMessageF(
59         "Not changing %s to DepthwiseConv because the weights is consumed by "
60         "another op.",
61         LogName(*conv_op));
62     return ::tensorflow::Status::OK();
63   }
64   auto& weights_array = model->GetArray(weights_name);
65   if (!weights_array.buffer) {
66     // Yield until the weights are resolved as a constant array.
67     return ::tensorflow::Status::OK();
68   }
69   if (weights_array.data_type != ArrayDataType::kFloat) {
70     return ::tensorflow::Status::OK();
71   }
72   // At this point we know we have a pure conv. Rewrite it as DepthwiseConv.
73   AddMessageF(
74       "%s is purely convolutional (input/weights depth is 1), replacing it by "
75       "a DepthwiseConv.",
76       LogName(*conv_op));
77   auto* depthwiseconv_op = new DepthwiseConvOperator;
78   // Conv and DepthwiseConv take the same inputs
79   depthwiseconv_op->inputs = conv_op->inputs;
80   // Conv may have a 2nd output for im2col
81   depthwiseconv_op->outputs = {conv_op->outputs[0]};
82   if (conv_op->outputs.size() > 1) {
83     // delete the im2col array.
84     model->EraseArray(conv_op->outputs[1]);
85   }
86   depthwiseconv_op->fused_activation_function =
87       conv_op->fused_activation_function;
88   // Let PropagateFixedSizes recompute fixed padding, just in case some day it
89   // may be different for Conv vs DepthwiseConv.
90   depthwiseconv_op->padding.type = conv_op->padding.type;
91   depthwiseconv_op->stride_height = conv_op->stride_height;
92   depthwiseconv_op->stride_width = conv_op->stride_width;
93   depthwiseconv_op->depth_multiplier = weights_array.shape().dims(0);
94   // Replace the operator in the graph.
95   const auto depthwiseconv_it =
96       model->operators.emplace(conv_it, depthwiseconv_op);
97   conv_it = depthwiseconv_it + 1;
98   CHECK_EQ(conv_it->get(), conv_op);
99   model->operators.erase(conv_it);
100   // Shuffle the weights.
101   const auto& weights_shape = weights_array.shape();
102   auto& weights_buffer =
103       weights_array.GetMutableBuffer<ArrayDataType::kFloat>();
104   const std::vector<float>& conv_weights_data = weights_buffer.data;
105   std::vector<float> depthwise_conv_weights_data(conv_weights_data.size());
106   const int depth = weights_shape.dims(0);
107   const int width = weights_shape.dims(1);
108   const int height = weights_shape.dims(2);
109   const int width_height = width * height;
110   for (int c = 0; c < depth; c++) {
111     for (int xy = 0; xy < width_height; xy++) {
112       depthwise_conv_weights_data[c + depth * xy] =
113           conv_weights_data[xy + width_height * c];
114     }
115   }
116   *weights_array.mutable_shape()->mutable_dims() = {1, width, height, depth};
117   weights_buffer.data = depthwise_conv_weights_data;
118   *modified = true;
119   return ::tensorflow::Status::OK();
120 }
121 
122 }  // namespace toco
123