/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cmath>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/runtime/types.h"
#include "tensorflow/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"

namespace toco {

namespace {

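// Returns a vector whose i-th element is true iff a[i] > b[i].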
std::vector<bool> VectorGreaterThan(const std::vector<int>& a,
                                    const std::vector<int>& b) {
  DCHECK_EQ(a.size(), b.size());
  const int size = a.size();
  std::vector<bool> result(size);
  for (int i = 0; i < size; i++) {
    result[i] = a[i] > b[i];
  }
  return result;
}

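// For each index i, copies (input_a[i], input_b[i]) into
// (output_a[i], output_b[i]) when selector[i] is true, and swaps the pair
// otherwise.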
void PairwiseVectorSelect(const std::vector<bool>& selector,
                          const std::vector<int>& input_a,
                          const std::vector<int>& input_b,
                          std::vector<int>* output_a,
                          std::vector<int>* output_b) {
  DCHECK_EQ(input_a.size(), input_b.size());
  DCHECK_EQ(output_a->size(), output_b->size());
  DCHECK_EQ(input_a.size(), output_a->size());
  DCHECK_EQ(selector.size(), input_a.size());
  const int size = input_a.size();
  for (int i = 0; i < size; i++) {
    if (selector[i]) {
      (*output_a)[i] = input_a[i];
      (*output_b)[i] = input_b[i];
    } else {
      (*output_a)[i] = input_b[i];
      (*output_b)[i] = input_a[i];
    }
  }
}

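// Evaluates the given binary operator on its two constant inputs, with
// numpy-style broadcasting of the smaller operand, and writes the result
// into the output array's newly created buffer.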
template <ArrayDataType InputsDataType, ArrayDataType OutputDataType>
void EvaluateBinaryOperatorOnConstantInputs(Model* model,
                                            const Operator* binary_op) {
  CHECK(IsConstantParameterArray(*model, binary_op->inputs[0]));
  CHECK(IsConstantParameterArray(*model, binary_op->inputs[1]));
  CHECK(binary_op->fused_activation_function ==
        FusedActivationFunctionType::kNone);
  const auto& input0_array = model->GetArray(binary_op->inputs[0]);
  const auto& input1_array = model->GetArray(binary_op->inputs[1]);
  const auto& output_name = binary_op->outputs[0];
  auto& output_array = model->GetArray(output_name);
  CHECK(input0_array.data_type == InputsDataType);
  CHECK(input1_array.data_type == InputsDataType);
  CHECK(output_array.data_type == OutputDataType);

  // We have already tested above for the existence of input buffers
  // (synonymous with being a constant parameter).
  CHECK(input0_array.buffer);
  CHECK(input1_array.buffer);
  // On the other hand, the output should not already have a buffer.
  CHECK(!output_array.buffer);

  const auto& input0_data = input0_array.GetBuffer<InputsDataType>().data;
  const auto& input1_data = input1_array.GetBuffer<InputsDataType>().data;
  // Create the buffer on the output array, effectively turning it into
  // a constant parameter.

  const Shape& output_shape = output_array.shape();
  auto& output_data = output_array.GetMutableBuffer<OutputDataType>().data;
  const int output_buffer_size = RequiredBufferSizeForShape(output_shape);
  output_data.resize(output_buffer_size);
  const int dims_count = output_shape.dimensions_count();

  // It will be convenient here to have copies of the operands' shapes
  // extended to match the number of dimensions of the output shape.
  Shape input0_shape = input0_array.shape();
  Shape input1_shape = input1_array.shape();
  ExtendShape(&input0_shape, dims_count);
  ExtendShape(&input1_shape, dims_count);
  // Now we may still have operands of different sizes, which would indicate
  // that we have to "broadcast" the smaller dimension. We do this using a
  // vector of booleans indicating which input is the larger in each
  // dimension.
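  // For example, with input0_shape = {2, 1, 5} and input1_shape = {2, 4, 5},
  // input0_larger is {false, false, false}, so big_sizes becomes {2, 4, 5}
  // (which must match the output shape) and small_sizes becomes {2, 1, 5};
  // in the per-element loop below, input0 is then indexed modulo small_sizes,
  // broadcasting it along its size-1 dimension.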
  CHECK_EQ(input0_shape.dimensions_count(), input1_shape.dimensions_count());
  CHECK_EQ(input0_shape.dimensions_count(), dims_count);
  const std::vector<bool> input0_larger =
      VectorGreaterThan(input0_shape.dims(), input1_shape.dims());

  std::vector<int> big_sizes(dims_count);
  std::vector<int> small_sizes(dims_count);
  PairwiseVectorSelect(input0_larger, input0_shape.dims(), input1_shape.dims(),
                       &big_sizes, &small_sizes);

  // The output should already be correctly sized to match the big dimensions.
  for (int i = 0; i < dims_count; i++) {
    CHECK_EQ(output_shape.dims(i), big_sizes[i]);
  }

  std::vector<int> input0_indices(dims_count);
  std::vector<int> input1_indices(dims_count);
  std::vector<int> modulo_indices(dims_count);

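  // For each output element, recover its multi-dimensional indices, then map
  // them to indices into each input: the larger input in a given dimension
  // uses the output index directly, while the smaller one uses the output
  // index modulo its own size in that dimension (broadcasting).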
  for (int k = 0; k < output_buffer_size; k++) {
    const std::vector<int> output_indices = ReverseOffset(output_shape, k);
    for (int i = 0; i < dims_count; i++) {
      modulo_indices[i] = output_indices[i] % small_sizes[i];
    }
    PairwiseVectorSelect(input0_larger, output_indices, modulo_indices,
                         &input0_indices, &input1_indices);
    const auto val0 = input0_data[Offset(input0_shape, input0_indices)];
    const auto val1 = input1_data[Offset(input1_shape, input1_indices)];

    DataType<OutputDataType> outval;
    if (binary_op->type == OperatorType::kAdd) {
      outval = val0 + val1;
    } else if (binary_op->type == OperatorType::kMul) {
      outval = val0 * val1;
    } else if (binary_op->type == OperatorType::kSub) {
      outval = val0 - val1;
    } else if (binary_op->type == OperatorType::kDiv) {
      outval = val0 / val1;
    } else if (binary_op->type == OperatorType::kFloorDiv) {
      outval = std::floor(val0 / val1);
    } else if (binary_op->type == OperatorType::kFloorMod) {
      outval = val0 - (std::floor(val0 / val1) * val1);
    } else if (binary_op->type == OperatorType::kMinimum) {
      outval = std::min(val0, val1);
    } else if (binary_op->type == OperatorType::kMaximum) {
      outval = std::max(val0, val1);
    } else if (binary_op->type == OperatorType::kLess) {
      outval = val0 < val1;
    } else if (binary_op->type == OperatorType::kLessEqual) {
      outval = val0 <= val1;
    } else if (binary_op->type == OperatorType::kGreater) {
      outval = val0 > val1;
    } else if (binary_op->type == OperatorType::kGreaterEqual) {
      outval = val0 >= val1;
    } else {
      LOG(FATAL) << "Unhandled binary operator type";
    }
    output_data[Offset(output_shape, output_indices)] = outval;
  }
}

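// Non-templated entry point: dispatches on the runtime input/output data
// types to the templated implementation above.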
void EvaluateBinaryOperatorOnConstantInputs(Model* model,
                                            const Operator* binary_op) {
  const auto inputs_data_type = model->GetArray(binary_op->inputs[0]).data_type;
  const auto output_data_type =
      model->GetArray(binary_op->outputs[0]).data_type;
#define TOCO_HANDLE_CASE(InputsDataType, OutputDataType)                    \
  if (inputs_data_type == InputsDataType &&                                 \
      output_data_type == OutputDataType) {                                 \
    EvaluateBinaryOperatorOnConstantInputs<InputsDataType, OutputDataType>( \
        model, binary_op);                                                  \
    return;                                                                 \
  }
  TOCO_HANDLE_CASE(ArrayDataType::kFloat, ArrayDataType::kFloat)
  TOCO_HANDLE_CASE(ArrayDataType::kFloat, ArrayDataType::kBool)
  TOCO_HANDLE_CASE(ArrayDataType::kInt32, ArrayDataType::kInt32)
  TOCO_HANDLE_CASE(ArrayDataType::kInt32, ArrayDataType::kBool)
  TOCO_HANDLE_CASE(ArrayDataType::kInt64, ArrayDataType::kInt64)
  TOCO_HANDLE_CASE(ArrayDataType::kInt64, ArrayDataType::kBool)
  LOG(FATAL) << "Unimplemented: don't know how to resolve a constant "
             << "binary operator for these data types.";
#undef TOCO_HANDLE_CASE
}
}  // namespace

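// Replaces a binary operator whose two inputs are constant arrays with the
// equivalent constant output array, then removes the operator (and any input
// arrays that are no longer used) from the model.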
::tensorflow::Status ResolveConstantBinaryOperator::Run(Model* model,
                                                        std::size_t op_index,
                                                        bool* modified) {
  *modified = false;
  const auto binary_it = model->operators.begin() + op_index;
  const auto* binary_op = binary_it->get();
  // Test for binary ops of types that we know how to resolve.
  if (binary_op->type != OperatorType::kAdd &&
      binary_op->type != OperatorType::kMul &&
      binary_op->type != OperatorType::kSub &&
      binary_op->type != OperatorType::kDiv &&
      binary_op->type != OperatorType::kFloorDiv &&
      binary_op->type != OperatorType::kFloorMod &&
      binary_op->type != OperatorType::kMinimum &&
      binary_op->type != OperatorType::kMaximum &&
      binary_op->type != OperatorType::kLess &&
      binary_op->type != OperatorType::kLessEqual &&
      binary_op->type != OperatorType::kGreater &&
      binary_op->type != OperatorType::kGreaterEqual) {
    return ::tensorflow::Status::OK();
  }
  CHECK_EQ(binary_op->inputs.size(), 2);

  const auto& input0_array = model->GetArray(binary_op->inputs[0]);
  const auto& input1_array = model->GetArray(binary_op->inputs[1]);
  // Check if both inputs are constant parameters.
  if (!input0_array.buffer || !input1_array.buffer) {
    return ::tensorflow::Status::OK();
  }

  auto& output_array = model->GetArray(binary_op->outputs[0]);
  // Yield until the output array dims have been resolved.
  if (!output_array.has_shape()) {
    return ::tensorflow::Status::OK();
  }

  // For now, we don't handle fused activation functions: the idea is that
  // this constant propagation should run before activation functions get
  // fused.
  if (binary_op->fused_activation_function !=
      FusedActivationFunctionType::kNone) {
    AddMessageF(
        "Not resolving constant %s because it has a fused activation function",
        LogName(*binary_op));
    return ::tensorflow::Status::OK();
  }

  // Check that input data types agree.
  CHECK(input0_array.data_type == input1_array.data_type)
      << "Dissimilar data types given to op outputting \""
      << binary_op->outputs[0] << "\". 0:\"" << binary_op->inputs[0] << "\"("
      << static_cast<int>(input0_array.data_type) << ")   1:\""
      << binary_op->inputs[1] << "\"("
      << static_cast<int>(input1_array.data_type) << ").";

  // Do the actual constant propagation.
  EvaluateBinaryOperatorOnConstantInputs(model, binary_op);

  // Remove the binary operator and, if no longer used, its input arrays.
  if (CountOpsWithInput(*model, binary_op->inputs[0]) == 1) {
    model->EraseArray(binary_op->inputs[0]);
  }
  if (CountOpsWithInput(*model, binary_op->inputs[1]) == 1) {
    model->EraseArray(binary_op->inputs[1]);
  }
  AddMessageF("Resolved constant %s to the equivalent constant array",
              LogName(*binary_op));
  model->operators.erase(binary_it);
  *modified = true;
  return ::tensorflow::Status::OK();
}

}  // namespace toco