1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <algorithm>
#include <cmath>
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/runtime/types.h"
#include "tensorflow/lite/toco/tooling_util.h"
26
27 namespace toco {
28
29 namespace {
30
VectorGreaterThan(const std::vector<int> & a,const std::vector<int> & b)31 std::vector<bool> VectorGreaterThan(const std::vector<int>& a,
32 const std::vector<int>& b) {
33 DCHECK_EQ(a.size(), b.size());
34 const int size = a.size();
35 std::vector<bool> result(size);
36 for (int i = 0; i < size; i++) {
37 result[i] = a[i] > b[i];
38 }
39 return result;
40 }
41
PairwiseVectorSelect(const std::vector<bool> & selector,const std::vector<int> & input_a,const std::vector<int> & input_b,std::vector<int> * output_a,std::vector<int> * output_b)42 void PairwiseVectorSelect(const std::vector<bool>& selector,
43 const std::vector<int>& input_a,
44 const std::vector<int>& input_b,
45 std::vector<int>* output_a,
46 std::vector<int>* output_b) {
47 DCHECK_EQ(input_a.size(), input_b.size());
48 DCHECK_EQ(output_a->size(), output_b->size());
49 DCHECK_EQ(input_a.size(), output_a->size());
50 DCHECK_EQ(selector.size(), input_a.size());
51 const int size = input_a.size();
52 for (int i = 0; i < size; i++) {
53 if (selector[i]) {
54 (*output_a)[i] = input_a[i];
55 (*output_b)[i] = input_b[i];
56 } else {
57 (*output_a)[i] = input_b[i];
58 (*output_b)[i] = input_a[i];
59 }
60 }
61 }
62
63 template <ArrayDataType InputsDataType, ArrayDataType OutputDataType>
EvaluateBinaryOperatorOnConstantInputs(Model * model,const Operator * binary_op)64 void EvaluateBinaryOperatorOnConstantInputs(Model* model,
65 const Operator* binary_op) {
66 CHECK(IsConstantParameterArray(*model, binary_op->inputs[0]));
67 CHECK(IsConstantParameterArray(*model, binary_op->inputs[1]));
68 CHECK(binary_op->fused_activation_function ==
69 FusedActivationFunctionType::kNone);
70 const auto& input0_array = model->GetArray(binary_op->inputs[0]);
71 const auto& input1_array = model->GetArray(binary_op->inputs[1]);
72 const auto& output_name = binary_op->outputs[0];
73 auto& output_array = model->GetArray(output_name);
74 CHECK(input0_array.data_type == InputsDataType);
75 CHECK(input1_array.data_type == InputsDataType);
76 CHECK(output_array.data_type == OutputDataType);
77
78 // We have already tested above for existence of input buffers
79 // (synonymous to being a constant param).
80 CHECK(input0_array.buffer);
81 CHECK(input1_array.buffer);
82 // On the other hand, the output should not already have a buffer.
83 CHECK(!output_array.buffer);
84
85 const auto& input0_data = input0_array.GetBuffer<InputsDataType>().data;
86 const auto& input1_data = input1_array.GetBuffer<InputsDataType>().data;
87 // Create the buffer on the output array, effectively turning it into
88 // a constant parameter
89
90 const Shape& output_shape = output_array.shape();
91 auto& output_data = output_array.GetMutableBuffer<OutputDataType>().data;
92 const int output_buffer_size = RequiredBufferSizeForShape(output_shape);
93 output_data.resize(output_buffer_size);
94 const int dims_count = output_shape.dimensions_count();
95
96 // It will be convenient here to have copies of the operands shapes
97 // extended to match the number of dimensions of the output shape.
98 Shape input0_shape = input0_array.shape();
99 Shape input1_shape = input1_array.shape();
100 ExtendShape(&input0_shape, dims_count);
101 ExtendShape(&input1_shape, dims_count);
102 // Now we may still have operands of different sizes, which would indicate
103 // that we have to "broadcast" the smaller dimension. We do this using a
104 // a vector of Booleans indicating which input is the larger in each
105 // dimension.
106 CHECK_EQ(input0_shape.dimensions_count(), input1_shape.dimensions_count());
107 CHECK_EQ(input0_shape.dimensions_count(), dims_count);
108 const std::vector<bool> input0_larger =
109 VectorGreaterThan(input0_shape.dims(), input1_shape.dims());
110
111 std::vector<int> big_sizes(dims_count);
112 std::vector<int> small_sizes(dims_count);
113 PairwiseVectorSelect(input0_larger, input0_shape.dims(), input1_shape.dims(),
114 &big_sizes, &small_sizes);
115
116 // The output should already be correctly sized to match the big dimensions.
117 for (int i = 0; i < dims_count; i++) {
118 CHECK_EQ(output_shape.dims(i), big_sizes[i]);
119 }
120
121 std::vector<int> input0_indices(dims_count);
122 std::vector<int> input1_indices(dims_count);
123 std::vector<int> modulo_indices(dims_count);
124
125 for (int k = 0; k < output_buffer_size; k++) {
126 const std::vector<int> output_indices = ReverseOffset(output_shape, k);
127 for (int i = 0; i < dims_count; i++) {
128 modulo_indices[i] = output_indices[i] % small_sizes[i];
129 }
130 PairwiseVectorSelect(input0_larger, output_indices, modulo_indices,
131 &input0_indices, &input1_indices);
132 const auto val0 = input0_data[Offset(input0_shape, input0_indices)];
133 const auto val1 = input1_data[Offset(input1_shape, input1_indices)];
134
135 DataType<OutputDataType> outval;
136 if (binary_op->type == OperatorType::kAdd) {
137 outval = val0 + val1;
138 } else if (binary_op->type == OperatorType::kMul) {
139 outval = val0 * val1;
140 } else if (binary_op->type == OperatorType::kSub) {
141 outval = val0 - val1;
142 } else if (binary_op->type == OperatorType::kDiv) {
143 outval = val0 / val1;
144 } else if (binary_op->type == OperatorType::kFloorDiv) {
145 outval = floor(val0 / val1);
146 } else if (binary_op->type == OperatorType::kFloorMod) {
147 outval = val0 - (floor(val0 / val1) * val1);
148 } else if (binary_op->type == OperatorType::kMinimum) {
149 outval = std::min(val0, val1);
150 } else if (binary_op->type == OperatorType::kMaximum) {
151 outval = std::max(val0, val1);
152 } else if (binary_op->type == OperatorType::kLess) {
153 outval = val0 < val1;
154 } else if (binary_op->type == OperatorType::kLessEqual) {
155 outval = val0 <= val1;
156 } else if (binary_op->type == OperatorType::kGreater) {
157 outval = val0 > val1;
158 } else if (binary_op->type == OperatorType::kGreaterEqual) {
159 outval = val0 >= val1;
160 } else {
161 LOG(FATAL) << "should not get here";
162 }
163 output_data[Offset(output_shape, output_indices)] = outval;
164 }
165 }
166
EvaluateBinaryOperatorOnConstantInputs(Model * model,const Operator * binary_op)167 void EvaluateBinaryOperatorOnConstantInputs(Model* model,
168 const Operator* binary_op) {
169 const auto inputs_data_type = model->GetArray(binary_op->inputs[0]).data_type;
170 const auto output_data_type =
171 model->GetArray(binary_op->outputs[0]).data_type;
172 #define TOCO_HANDLE_CASE(InputsDataType, OutputDataType) \
173 if (inputs_data_type == InputsDataType && \
174 output_data_type == OutputDataType) { \
175 EvaluateBinaryOperatorOnConstantInputs<InputsDataType, OutputDataType>( \
176 model, binary_op); \
177 return; \
178 }
179 TOCO_HANDLE_CASE(ArrayDataType::kFloat, ArrayDataType::kFloat)
180 TOCO_HANDLE_CASE(ArrayDataType::kFloat, ArrayDataType::kBool)
181 TOCO_HANDLE_CASE(ArrayDataType::kInt32, ArrayDataType::kInt32)
182 TOCO_HANDLE_CASE(ArrayDataType::kInt32, ArrayDataType::kBool)
183 TOCO_HANDLE_CASE(ArrayDataType::kInt64, ArrayDataType::kInt64)
184 TOCO_HANDLE_CASE(ArrayDataType::kInt64, ArrayDataType::kBool)
185 LOG(FATAL) << "Unimplemented: don't know how to resolve a constant "
186 << "binary operator for these data types.";
187 #undef TOCO_HANDLE_CASE
188 }
189 } // namespace
190
Run(Model * model,std::size_t op_index,bool * modified)191 ::tensorflow::Status ResolveConstantBinaryOperator::Run(Model* model,
192 std::size_t op_index,
193 bool* modified) {
194 *modified = false;
195 const auto binary_it = model->operators.begin() + op_index;
196 const auto* binary_op = binary_it->get();
197 // Test for binary ops of types that we know how to resolve
198 if (binary_op->type != OperatorType::kAdd &&
199 binary_op->type != OperatorType::kMul &&
200 binary_op->type != OperatorType::kSub &&
201 binary_op->type != OperatorType::kDiv &&
202 binary_op->type != OperatorType::kFloorDiv &&
203 binary_op->type != OperatorType::kFloorMod &&
204 binary_op->type != OperatorType::kMinimum &&
205 binary_op->type != OperatorType::kMaximum &&
206 binary_op->type != OperatorType::kLess &&
207 binary_op->type != OperatorType::kLessEqual &&
208 binary_op->type != OperatorType::kGreater &&
209 binary_op->type != OperatorType::kGreaterEqual) {
210 return ::tensorflow::Status::OK();
211 }
212 CHECK_EQ(binary_op->inputs.size(), 2);
213
214 const auto& input0_array = model->GetArray(binary_op->inputs[0]);
215 const auto& input1_array = model->GetArray(binary_op->inputs[1]);
216 // Check if both inputs are constant parameters.
217 if (!input0_array.buffer || !input1_array.buffer) {
218 return ::tensorflow::Status::OK();
219 }
220
221 auto& output_array = model->GetArray(binary_op->outputs[0]);
222 // Yield until the output array dims have been resolved.
223 if (!output_array.has_shape()) {
224 return ::tensorflow::Status::OK();
225 }
226
227 // At the moment we don't want to care about fused activation functions.
228 // The idea is that we should do the present constants-propagation before
229 // activation functions get fused.
230 if (binary_op->fused_activation_function !=
231 FusedActivationFunctionType::kNone) {
232 AddMessageF(
233 "Not resolving constant %s because it has a fused activation function",
234 LogName(*binary_op));
235 return ::tensorflow::Status::OK();
236 }
237
238 // Check that input data types agree.
239 CHECK(input0_array.data_type == input1_array.data_type)
240 << "Dissimilar data types given to op outputting \""
241 << binary_op->outputs[0] << "\". 0:\"" << binary_op->inputs[0] << "\"("
242 << static_cast<int>(input0_array.data_type) << ") 1:\""
243 << binary_op->inputs[1] << "\"("
244 << static_cast<int>(input1_array.data_type) << ").";
245
246 // Do the actual constants propagation
247 EvaluateBinaryOperatorOnConstantInputs(model, binary_op);
248
249 // Remove the binary operator and its inputs
250 if (CountOpsWithInput(*model, binary_op->inputs[0]) == 1) {
251 model->EraseArray(binary_op->inputs[0]);
252 }
253 if (CountOpsWithInput(*model, binary_op->inputs[1]) == 1) {
254 model->EraseArray(binary_op->inputs[1]);
255 }
256 AddMessageF("Resolved constant %s to the equivalent constant array",
257 LogName(*binary_op));
258 model->operators.erase(binary_it);
259 *modified = true;
260 return ::tensorflow::Status::OK();
261 }
262
263 } // namespace toco
264