1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/runtime/types.h"
#include "tensorflow/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"
25 
26 namespace toco {
27 
Run(Model * model,std::size_t op_index,bool * modified)28 ::tensorflow::Status ResolveBatchNormalization::Run(Model* model,
29                                                     std::size_t op_index,
30                                                     bool* modified) {
31   *modified = false;
32   auto bn_it = model->operators.begin() + op_index;
33   if (bn_it->get()->type != OperatorType::kBatchNormalization) {
34     return ::tensorflow::Status::OK();
35   }
36   const auto* bn_op =
37       static_cast<const BatchNormalizationOperator*>(bn_it->get());
38 
39   auto& mean_array = model->GetArray(bn_op->inputs[1]);
40   const auto& multiplier_array = model->GetArray(bn_op->inputs[2]);
41   const auto& offset_array = model->GetArray(bn_op->inputs[3]);
42 
43   // This graph transformation needs to address constant buffers below, so
44   // we need to exit early if these buffers don't exist yet (i.e. if the params
45   // haven't yet been resolved as constants) and will process it once they have.
46   if (!mean_array.buffer || !multiplier_array.buffer || !offset_array.buffer) {
47     return ::tensorflow::Status::OK();
48   }
49 
50   CHECK(IsConstantParameterArray(*model, bn_op->inputs[1]) &&
51         IsConstantParameterArray(*model, bn_op->inputs[2]) &&
52         IsConstantParameterArray(*model, bn_op->inputs[3]))
53       << "Batch normalization resolution requires that mean, multiplier and "
54          "offset arrays be constant.";
55 
56   // We should only have *float* BatchNormalizations... let's guard this
57   // assumption by CHECK's.
58   CHECK(mean_array.data_type == ArrayDataType::kFloat);
59   CHECK(multiplier_array.data_type == ArrayDataType::kFloat);
60   CHECK(offset_array.data_type == ArrayDataType::kFloat);
61 
62   // Create the new Mul, Add operators
63   auto* mul_op = new MulOperator;
64   auto* add_op = new AddOperator;
65   const std::string mul_name =
66       AvailableArrayName(*model, bn_op->outputs[0] + "_mul");
67   const std::string add_name =
68       AvailableArrayName(*model, bn_op->outputs[0] + "_add");
69   const std::string mul_param_name =
70       AvailableArrayName(*model, mul_name + "_param");
71   const std::string add_param_name =
72       AvailableArrayName(*model, add_name + "_param");
73   mul_op->inputs = {bn_op->inputs[0], mul_param_name};
74   mul_op->outputs = {mul_name};
75   add_op->inputs = {mul_name, add_param_name};
76   add_op->outputs = {bn_op->outputs[0]};
77   AddMessageF("Splitting %s into %s and %s", LogName(*bn_op), LogName(*mul_op),
78               LogName(*add_op));
79 
80   // Create the intermediate activation array (output of mul, input of add)
81   auto& intermediate_array = model->GetOrCreateArray(mul_op->outputs[0]);
82   intermediate_array.data_type = model->GetArray(bn_op->inputs[0]).data_type;
83 
84   // Insert the new operators in the graph
85   auto add_it = model->operators.emplace(bn_it, add_op);
86   auto mul_it = model->operators.emplace(add_it, mul_op);
87   // update invalidated iterators.
88   DCHECK_EQ(mul_it->get(), mul_op);
89   add_it = mul_it + 1;
90   DCHECK_EQ(add_it->get(), add_op);
91   bn_it = add_it + 1;
92   DCHECK_EQ(bn_it->get(), bn_op);
93 
94   // Create the new param arrays
95   auto& mean_shape = *mean_array.mutable_shape();
96   const auto& multiplier_shape = multiplier_array.shape();
97   const auto& offset_shape = offset_array.shape();
98   if (mean_shape.dims().empty()) {
99     *mean_shape.mutable_dims() = multiplier_shape.dims();
100     auto& data = mean_array.GetMutableBuffer<ArrayDataType::kFloat>().data;
101     CHECK_EQ(data.size(), 1);
102     data.resize(RequiredBufferSizeForShape(mean_shape), data[0]);
103   }
104   CHECK(mean_shape.dims() == multiplier_shape.dims());
105   CHECK(mean_shape.dims() == offset_shape.dims());
106   const auto& param_shape = mean_shape;
107   const int buffer_size = RequiredBufferSizeForShape(param_shape);
108   auto& mul_param_array = model->GetOrCreateArray(mul_param_name);
109   auto& add_param_array = model->GetOrCreateArray(add_param_name);
110   DropMinMax(model, mul_param_name);
111   DropMinMax(model, add_param_name);
112   mul_param_array.copy_shape(param_shape);
113   add_param_array.copy_shape(param_shape);
114   mul_param_array.data_type = ArrayDataType::kFloat;
115   add_param_array.data_type = ArrayDataType::kFloat;
116   auto& mul_float_data =
117       mul_param_array.GetMutableBuffer<ArrayDataType::kFloat>().data;
118   auto& add_float_data =
119       add_param_array.GetMutableBuffer<ArrayDataType::kFloat>().data;
120   mul_float_data.resize(buffer_size);
121   add_float_data.resize(buffer_size);
122   const auto& mean_float_data =
123       mean_array.GetBuffer<ArrayDataType::kFloat>().data;
124   const auto& multiplier_float_data =
125       multiplier_array.GetBuffer<ArrayDataType::kFloat>().data;
126   const auto& offset_float_data =
127       offset_array.GetBuffer<ArrayDataType::kFloat>().data;
128   size_t buffer_size_for_compare = buffer_size;
129   CHECK(mul_float_data.size() == buffer_size_for_compare);
130   CHECK(add_float_data.size() == buffer_size_for_compare);
131   CHECK(mean_float_data.size() == buffer_size_for_compare);
132   CHECK(multiplier_float_data.size() == buffer_size_for_compare);
133   CHECK(offset_float_data.size() == buffer_size_for_compare);
134 
135   for (int i = 0; i < buffer_size; i++) {
136     mul_float_data[i] = multiplier_float_data[i];
137     add_float_data[i] =
138         offset_float_data[i] - mean_float_data[i] * multiplier_float_data[i];
139   }
140 
141   DeleteOpAndArrays(model, bn_op);
142 
143   *modified = true;
144   return ::tensorflow::Status::OK();
145 }
146 
147 }  // namespace toco
148