1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <memory>
16 #include <numeric>
17 #include <string>
18 #include <unordered_map>
19 #include <vector>
20 
21 #include "absl/strings/str_cat.h"
22 #include "absl/strings/str_join.h"
23 #include "tensorflow/core/platform/logging.h"
24 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
25 #include "tensorflow/lite/toco/model.h"
26 #include "tensorflow/lite/toco/tooling_util.h"
27 
28 namespace toco {
29 
30 namespace {
31 
UnrollBatchMatMul3D(const string & input_lhs,const string & input_rhs,const BatchMatMulOperator * batch_op,const std::vector<int> batch,Model * model,std::vector<std::unique_ptr<Operator>>::iterator * tail_it,std::vector<string> * pack_inputs)32 void UnrollBatchMatMul3D(
33     const string& input_lhs, const string& input_rhs,
34     const BatchMatMulOperator* batch_op, const std::vector<int> batch,
35     Model* model, std::vector<std::unique_ptr<Operator>>::iterator* tail_it,
36     std::vector<string>* pack_inputs) {
37   const std::string batch_name =
38       absl::StrCat(batch_op->outputs[0], "_b", absl::StrJoin(batch, "-"));
39   const auto& input_array_a = model->GetArray(input_lhs);
40   const auto& input_array_b = model->GetArray(input_rhs);
41   const int dims_count = input_array_a.shape().dimensions_count();
42 
43   // tf.slice(a, ...).
44   std::vector<int> begin_indices_a = batch;
45   begin_indices_a.resize(dims_count);
46   std::vector<int> slice_size_a = input_array_a.shape().dims();
47   for (int i = 0; i < batch.size(); ++i) {
48     slice_size_a[i] = 1;
49   }
50   auto* slice_a_op = new SliceOperator;
51   slice_a_op->inputs = {
52       input_lhs,
53       CreateInt32Array(model, batch_name + "/slice_a/slice/begin",
54                        begin_indices_a),
55       CreateInt32Array(model, batch_name + "/slice_a/slice/size", slice_size_a),
56   };
57   slice_a_op->outputs = {AvailableArrayName(*model, batch_name + "/slice_a")};
58   auto& slice_a_op_output = model->GetOrCreateArray(slice_a_op->outputs[0]);
59   slice_a_op_output.data_type = input_array_a.data_type;
60   *tail_it = model->operators.emplace(*tail_it, slice_a_op) + 1;
61 
62   // Reshape to remove the first dimension ([1,M,N] -> [M,N]).
63   auto* slice_a_reshape_op = new TensorFlowReshapeOperator;
64   slice_a_reshape_op->inputs = {
65       slice_a_op->outputs[0],
66       CreateInt32Array(model, batch_name + "/slice_a/reshape/shape",
67                        {-1, input_array_a.shape().dims(dims_count - 1)})};
68   slice_a_reshape_op->outputs = {
69       AvailableArrayName(*model, batch_name + "/slice_a/reshape")};
70   auto& slice_a_reshape_op_output =
71       model->GetOrCreateArray(slice_a_reshape_op->outputs[0]);
72   slice_a_reshape_op_output.data_type = input_array_a.data_type;
73   *tail_it = model->operators.emplace(*tail_it, slice_a_reshape_op) + 1;
74 
75   // tf.slice(b, ...).
76   std::vector<int> begin_indices_b = batch;
77   begin_indices_b.resize(dims_count);
78   std::vector<int> slice_size_b = input_array_b.shape().dims();
79   for (int i = 0; i < batch.size(); ++i) {
80     slice_size_b[i] = 1;
81   }
82   auto* slice_b_op = new SliceOperator;
83   slice_b_op->inputs = {
84       input_rhs,
85       CreateInt32Array(model, batch_name + "/slice_b/slice/begin",
86                        begin_indices_b),
87       CreateInt32Array(model, batch_name + "/slice_b/slice/size", slice_size_b),
88   };
89   slice_b_op->outputs = {AvailableArrayName(*model, batch_name + "/slice_b")};
90   auto& slice_b_op_output = model->GetOrCreateArray(slice_b_op->outputs[0]);
91   slice_b_op_output.data_type = input_array_b.data_type;
92   *tail_it = model->operators.emplace(*tail_it, slice_b_op) + 1;
93 
94   // Reshape to remove the first dimension ([1,M,N] -> [M,N]).
95   auto* slice_b_reshape_op = new TensorFlowReshapeOperator;
96   slice_b_reshape_op->inputs = {
97       slice_b_op->outputs[0],
98       CreateInt32Array(model, batch_name + "/slice_b/reshape/shape",
99                        {-1, input_array_b.shape().dims(dims_count - 1)})};
100   slice_b_reshape_op->outputs = {
101       AvailableArrayName(*model, batch_name + "/slice_b/reshape")};
102   auto& slice_b_reshape_op_output =
103       model->GetOrCreateArray(slice_b_reshape_op->outputs[0]);
104   slice_b_reshape_op_output.data_type = input_array_b.data_type;
105   *tail_it = model->operators.emplace(*tail_it, slice_b_reshape_op) + 1;
106 
107   // tf.matmul(slice_a, slice_b).
108   auto* matmul_op = new TensorFlowMatMulOperator;
109   matmul_op->inputs = {slice_a_reshape_op->outputs[0],
110                        slice_b_reshape_op->outputs[0]};
111   matmul_op->outputs = {AvailableArrayName(*model, batch_name)};
112   auto& matmul_op_output = model->GetOrCreateArray(matmul_op->outputs[0]);
113   matmul_op_output.data_type = input_array_a.data_type;
114   *tail_it = model->operators.emplace(*tail_it, matmul_op) + 1;
115 
116   // Add to stack.
117   pack_inputs->push_back(matmul_op->outputs[0]);
118 }
119 
UnrollBatchMatMulRecursion(const string & input_lhs,const string & input_rhs,const BatchMatMulOperator * batch_op,Model * model,std::vector<std::unique_ptr<Operator>>::iterator * tail_it,const std::vector<int> & batch_prefix)120 std::vector<string> UnrollBatchMatMulRecursion(
121     const string& input_lhs, const string& input_rhs,
122     const BatchMatMulOperator* batch_op, Model* model,
123     std::vector<std::unique_ptr<Operator>>::iterator* tail_it,
124     const std::vector<int>& batch_prefix) {
125   const auto& input_array_a = model->GetArray(input_lhs);
126   const auto& dims_vec = input_array_a.shape().dims();
127   const int current_dim_size = dims_vec[batch_prefix.size()];
128   std::vector<string> batch_pack_inputs;
129 
130   if (batch_prefix.size() + 3 == dims_vec.size()) {
131     // Base case
132     for (int batch = 0; batch < current_dim_size; ++batch) {
133       std::vector<int> new_batch_prefix = batch_prefix;
134       new_batch_prefix.emplace_back(batch);
135       UnrollBatchMatMul3D(input_lhs, input_rhs, batch_op, new_batch_prefix,
136                           model, tail_it, &batch_pack_inputs);
137     }
138   } else {
139     // Recursion
140     for (int batch = 0; batch < current_dim_size; ++batch) {
141       std::vector<int> new_batch_prefix = batch_prefix;
142       new_batch_prefix.emplace_back(batch);
143       std::vector<string> pack_inputs = UnrollBatchMatMulRecursion(
144           input_lhs, input_rhs, batch_op, model, tail_it, new_batch_prefix);
145 
146       // The pack that will join all the individual matmul results together.
147       auto* pack_op = new PackOperator;
148       std::string batch_name = absl::StrCat(
149           batch_op->outputs[0], "_b", absl::StrJoin(new_batch_prefix, "-"));
150       pack_op->inputs = pack_inputs;
151       pack_op->outputs = {AvailableArrayName(*model, batch_name + "/pack")};
152       auto& pack_op_output = model->GetOrCreateArray(pack_op->outputs[0]);
153       pack_op_output.data_type = input_array_a.data_type;
154       pack_op->axis = 0;
155       pack_op->values_count = pack_inputs.size();
156       *tail_it = model->operators.emplace(*tail_it, pack_op) + 1;
157 
158       batch_pack_inputs.push_back(pack_op->outputs[0]);
159     }
160   }
161   return batch_pack_inputs;
162 }
163 
GetTransposePerm(const Array & input_array)164 std::vector<int32> GetTransposePerm(const Array& input_array) {
165   const int32 dims = input_array.shape().dimensions_count();
166   std::vector<int32> perm_array_val(dims);
167   for (int i = 0; i < dims; ++i) {
168     perm_array_val[i] = i;
169   }
170   perm_array_val[dims - 2] = dims - 1;
171   perm_array_val[dims - 1] = dims - 2;
172   return perm_array_val;
173 }
174 
GetTransposeShape(const Shape & input_shape,const std::vector<int32> & perm_array_val)175 std::vector<int32> GetTransposeShape(const Shape& input_shape,
176                                      const std::vector<int32>& perm_array_val) {
177   const int32 dims = input_shape.dimensions_count();
178   std::vector<int32> output_shape(dims);
179   for (int i = 0; i < dims; ++i) {
180     output_shape[i] = input_shape.dims(perm_array_val[i]);
181   }
182   return output_shape;
183 }
184 
TransposeInput(const string & input,Model * model)185 TransposeOperator* TransposeInput(const string& input, Model* model) {
186   const auto& input_array = model->GetArray(input);
187   const auto perm_array = GetTransposePerm(input_array);
188   const string perm_array_name = CreateInt32Array(
189       model, AvailableArrayName(*model, input + "/transpose/perm"), perm_array);
190   auto* transpose_op = new TransposeOperator;
191   transpose_op->inputs = {input, perm_array_name};
192   transpose_op->outputs = {AvailableArrayName(*model, input + "/transpose")};
193   auto& transpose_array = model->GetOrCreateArray(transpose_op->outputs[0]);
194   *transpose_array.mutable_shape()->mutable_dims() =
195       GetTransposeShape(input_array.shape(), perm_array);
196   model->GetOrCreateArray(transpose_op->outputs[0]);
197   return transpose_op;
198 }
199 
200 }  // namespace
201 
202 // Unrolls a BatchMatMul on the batch dimension.
203 // We need to slice each batch out of the inputs, matmul them individually, then
204 // stack them all back together at the end.
205 //
206 // This transform effectively looks like:
207 //  result_slices = []
208 //  for bat in B:
209 //    slice_a = tf.reshape(tf.slice(a, [bat, 0, 0], [1, M, N]), [M, N])
210 //    slice_b = tf.reshape(tf.slice(b, [bat, 0, 0], [1, M, N]), [M, N])
211 //    slice_c = tf.matmul(slice_a, slice_b)
212 //    result_slices[bat] = slice_c
213 //  result = tf.stack(result_slices)
Run(Model * model,std::size_t op_index,bool * modified)214 ::tensorflow::Status UnrollBatchMatMul::Run(Model* model, std::size_t op_index,
215                                             bool* modified) {
216   *modified = false;
217   auto batch_op_it = model->operators.begin() + op_index;
218   if (batch_op_it->get()->type != OperatorType::kBatchMatMul) {
219     return ::tensorflow::Status::OK();
220   }
221   const auto* batch_op =
222       static_cast<const BatchMatMulOperator*>(batch_op_it->get());
223 
224   auto& tail_it = batch_op_it;
225 
226   string input_lhs = batch_op->inputs[0];
227   string input_rhs = batch_op->inputs[1];
228   const auto& input_lhs_array = model->GetArray(input_lhs);
229   const auto& input_rhs_array = model->GetArray(input_rhs);
230   if (!input_lhs_array.has_shape() || !input_rhs_array.has_shape())
231     return ::tensorflow::Status::OK();
232 
233   // Transpose LHS input if necessary.
234   if (batch_op->adj_x) {
235     TransposeOperator* transpose_op = TransposeInput(input_lhs, model);
236     tail_it = model->operators.emplace(tail_it, transpose_op) + 1;
237     input_lhs = transpose_op->outputs[0];
238   }
239   const auto& input_array_a = model->GetArray(input_lhs);
240 
241   // Transpose RHS input if necessary.
242   if (batch_op->adj_y) {
243     TransposeOperator* transpose_op = TransposeInput(input_rhs, model);
244     tail_it = model->operators.emplace(tail_it, transpose_op) + 1;
245     input_rhs = transpose_op->outputs[0];
246   }
247   const auto& input_array_b = model->GetArray(input_rhs);
248 
249   const int dims = input_array_a.shape().dimensions_count();
250   for (int i = 0; i < dims - 2; ++i) {
251     CHECK_EQ(input_array_a.shape().dims(i), input_array_b.shape().dims(i))
252         << "input array not consistent at index " << i;
253   }
254   CHECK_EQ(input_array_a.shape().dims(dims - 1),
255            input_array_b.shape().dims(dims - 2))
256       << "Input dimensions must be compatible for multipication. shape a = ["
257       << absl::StrJoin(input_array_a.shape().dims(), ", ") << "], shape b = ["
258       << absl::StrJoin(input_array_b.shape().dims(), ", ") << "]";
259 
260   if (dims == 2) {
261     // This is really just a MatMul. This likely means that someone hand-crafted
262     // a graphdef with a BatchMatMul when they really wanted a MatMul.
263     AddMessageF("Replacing non-batch BatchMatMul %s by a MatMul operator",
264                 LogName(*batch_op));
265     auto* matmul_op = new TensorFlowMatMulOperator;
266     matmul_op->inputs = {input_lhs, input_rhs};
267     matmul_op->outputs = batch_op->outputs;
268     tail_it = model->operators.emplace(tail_it, matmul_op) + 1;
269     CHECK_EQ(tail_it->get(), batch_op);
270     model->operators.erase(tail_it);
271     *modified = true;
272     return ::tensorflow::Status::OK();
273   }
274 
275   CHECK_GE(input_array_a.shape().dimensions_count(), 3)
276       << "Input arrays must have rank >= 3";
277 
278   const auto& dims_vec = input_array_a.shape().dims();
279   AddMessageF("Unrolling BatchMatMul %s %d times", LogName(*batch_op),
280               std::accumulate(dims_vec.begin(), dims_vec.end() - 2, 1,
281                               std::multiplies<int>()));
282 
283   std::vector<string> pack_inputs = UnrollBatchMatMulRecursion(
284       input_lhs, input_rhs, batch_op, model, &tail_it, {});
285   auto* pack_op = new PackOperator;
286   pack_op->inputs = pack_inputs;
287   pack_op->outputs = {batch_op->outputs[0]};
288   pack_op->axis = 0;
289   pack_op->values_count = pack_inputs.size();
290   model->operators.emplace(tail_it, pack_op);
291 
292   // Remove the old batch matmul now that we've unrolled.
293   batch_op_it = model->operators.begin();
294   for (; batch_op_it != model->operators.end(); ++batch_op_it) {
295     if (batch_op_it->get() == batch_op) {
296       break;
297     }
298   }
299   CHECK(batch_op_it != model->operators.end());
300   CHECK(batch_op_it->get() == batch_op);
301   model->operators.erase(batch_op_it);
302   *modified = true;
303   return ::tensorflow::Status::OK();
304 }
305 
306 }  // namespace toco
307