1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
19
20 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
21 #include "tensorflow/lite/toco/model.h"
22 #include "tensorflow/lite/toco/runtime/types.h"
23 #include "tensorflow/lite/toco/tooling_util.h"
24 #include "tensorflow/core/platform/logging.h"
25
26 namespace toco {
27
Run(Model * model,std::size_t op_index,bool * modified)28 ::tensorflow::Status ResolveBatchNormalization::Run(Model* model,
29 std::size_t op_index,
30 bool* modified) {
31 *modified = false;
32 auto bn_it = model->operators.begin() + op_index;
33 if (bn_it->get()->type != OperatorType::kBatchNormalization) {
34 return ::tensorflow::Status::OK();
35 }
36 const auto* bn_op =
37 static_cast<const BatchNormalizationOperator*>(bn_it->get());
38
39 auto& mean_array = model->GetArray(bn_op->inputs[1]);
40 const auto& multiplier_array = model->GetArray(bn_op->inputs[2]);
41 const auto& offset_array = model->GetArray(bn_op->inputs[3]);
42
43 // This graph transformation needs to address constant buffers below, so
44 // we need to exit early if these buffers don't exist yet (i.e. if the params
45 // haven't yet been resolved as constants) and will process it once they have.
46 if (!mean_array.buffer || !multiplier_array.buffer || !offset_array.buffer) {
47 return ::tensorflow::Status::OK();
48 }
49
50 CHECK(IsConstantParameterArray(*model, bn_op->inputs[1]) &&
51 IsConstantParameterArray(*model, bn_op->inputs[2]) &&
52 IsConstantParameterArray(*model, bn_op->inputs[3]))
53 << "Batch normalization resolution requires that mean, multiplier and "
54 "offset arrays be constant.";
55
56 // We should only have *float* BatchNormalizations... let's guard this
57 // assumption by CHECK's.
58 CHECK(mean_array.data_type == ArrayDataType::kFloat);
59 CHECK(multiplier_array.data_type == ArrayDataType::kFloat);
60 CHECK(offset_array.data_type == ArrayDataType::kFloat);
61
62 // Create the new Mul, Add operators
63 auto* mul_op = new MulOperator;
64 auto* add_op = new AddOperator;
65 const std::string mul_name =
66 AvailableArrayName(*model, bn_op->outputs[0] + "_mul");
67 const std::string add_name =
68 AvailableArrayName(*model, bn_op->outputs[0] + "_add");
69 const std::string mul_param_name =
70 AvailableArrayName(*model, mul_name + "_param");
71 const std::string add_param_name =
72 AvailableArrayName(*model, add_name + "_param");
73 mul_op->inputs = {bn_op->inputs[0], mul_param_name};
74 mul_op->outputs = {mul_name};
75 add_op->inputs = {mul_name, add_param_name};
76 add_op->outputs = {bn_op->outputs[0]};
77 AddMessageF("Splitting %s into %s and %s", LogName(*bn_op), LogName(*mul_op),
78 LogName(*add_op));
79
80 // Create the intermediate activation array (output of mul, input of add)
81 auto& intermediate_array = model->GetOrCreateArray(mul_op->outputs[0]);
82 intermediate_array.data_type = model->GetArray(bn_op->inputs[0]).data_type;
83
84 // Insert the new operators in the graph
85 auto add_it = model->operators.emplace(bn_it, add_op);
86 auto mul_it = model->operators.emplace(add_it, mul_op);
87 // update invalidated iterators.
88 DCHECK_EQ(mul_it->get(), mul_op);
89 add_it = mul_it + 1;
90 DCHECK_EQ(add_it->get(), add_op);
91 bn_it = add_it + 1;
92 DCHECK_EQ(bn_it->get(), bn_op);
93
94 // Create the new param arrays
95 auto& mean_shape = *mean_array.mutable_shape();
96 const auto& multiplier_shape = multiplier_array.shape();
97 const auto& offset_shape = offset_array.shape();
98 if (mean_shape.dims().empty()) {
99 *mean_shape.mutable_dims() = multiplier_shape.dims();
100 auto& data = mean_array.GetMutableBuffer<ArrayDataType::kFloat>().data;
101 CHECK_EQ(data.size(), 1);
102 data.resize(RequiredBufferSizeForShape(mean_shape), data[0]);
103 }
104 CHECK(mean_shape.dims() == multiplier_shape.dims());
105 CHECK(mean_shape.dims() == offset_shape.dims());
106 const auto& param_shape = mean_shape;
107 const int buffer_size = RequiredBufferSizeForShape(param_shape);
108 auto& mul_param_array = model->GetOrCreateArray(mul_param_name);
109 auto& add_param_array = model->GetOrCreateArray(add_param_name);
110 DropMinMax(model, mul_param_name);
111 DropMinMax(model, add_param_name);
112 mul_param_array.copy_shape(param_shape);
113 add_param_array.copy_shape(param_shape);
114 mul_param_array.data_type = ArrayDataType::kFloat;
115 add_param_array.data_type = ArrayDataType::kFloat;
116 auto& mul_float_data =
117 mul_param_array.GetMutableBuffer<ArrayDataType::kFloat>().data;
118 auto& add_float_data =
119 add_param_array.GetMutableBuffer<ArrayDataType::kFloat>().data;
120 mul_float_data.resize(buffer_size);
121 add_float_data.resize(buffer_size);
122 const auto& mean_float_data =
123 mean_array.GetBuffer<ArrayDataType::kFloat>().data;
124 const auto& multiplier_float_data =
125 multiplier_array.GetBuffer<ArrayDataType::kFloat>().data;
126 const auto& offset_float_data =
127 offset_array.GetBuffer<ArrayDataType::kFloat>().data;
128 size_t buffer_size_for_compare = buffer_size;
129 CHECK(mul_float_data.size() == buffer_size_for_compare);
130 CHECK(add_float_data.size() == buffer_size_for_compare);
131 CHECK(mean_float_data.size() == buffer_size_for_compare);
132 CHECK(multiplier_float_data.size() == buffer_size_for_compare);
133 CHECK(offset_float_data.size() == buffer_size_for_compare);
134
135 for (int i = 0; i < buffer_size; i++) {
136 mul_float_data[i] = multiplier_float_data[i];
137 add_float_data[i] =
138 offset_float_data[i] - mean_float_data[i] * multiplier_float_data[i];
139 }
140
141 DeleteOpAndArrays(model, bn_op);
142
143 *modified = true;
144 return ::tensorflow::Status::OK();
145 }
146
147 } // namespace toco
148