1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <memory>
16 #include <numeric>
17 #include <string>
18 #include <unordered_map>
19 #include <vector>
20 
21 #include "absl/container/inlined_vector.h"
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/core/platform/logging.h"
25 #include "tensorflow/core/util/matmul_bcast.h"
26 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
27 #include "tensorflow/lite/toco/model.h"
28 #include "tensorflow/lite/toco/tooling_util.h"
29 
30 namespace toco {
31 namespace {
32 
ToInlinedVector(const std::vector<int> & vec)33 absl::InlinedVector<int64, 4> ToInlinedVector(const std::vector<int>& vec) {
34   return absl::InlinedVector<int64, 4>(vec.begin(), vec.end());
35 }
36 
SliceInput(const std::string & input,const std::string & base_name,const std::string & input_name,const int batch_size,const Array & input_array,Model * model,std::vector<std::unique_ptr<Operator>>::iterator * tail_it)37 std::vector<std::string> SliceInput(
38     const std::string& input, const std::string& base_name,
39     const std::string& input_name, const int batch_size,
40     const Array& input_array, Model* model,
41     std::vector<std::unique_ptr<Operator>>::iterator* tail_it) {
42   int rank = input_array.shape().dimensions_count();
43   int num_rows = input_array.shape().dims(rank - 2);
44   int num_cols = input_array.shape().dims(rank - 1);
45   // Reshape to rank-3 Tensor with first dimension as the batch size.
46   auto* reshape_op = new TensorFlowReshapeOperator;
47   reshape_op->inputs = {
48       input,
49       CreateInt32Array(model, absl::StrCat(base_name, "/reshape_a/shape"),
50                        {batch_size, num_rows, num_cols})};
51   reshape_op->outputs = {AvailableArrayName(
52       *model, absl::StrCat(base_name, "/reshape_", input_name, "/reshape"))};
53   auto& reshape_op_output = model->GetOrCreateArray(reshape_op->outputs[0]);
54   reshape_op_output.data_type = input_array.data_type;
55   *tail_it = model->operators.emplace(*tail_it, reshape_op) + 1;
56 
57   // Slice along each batch index and remember the slice output for future use.
58   std::vector<std::string> slice_outputs;
59   for (int batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
60     std::string batch_name =
61         absl::StrCat(base_name, "_b", batch_idx, "/slice_", input_name);
62     auto* slice_op = new SliceOperator;
63     slice_op->inputs = {
64         reshape_op->outputs[0],
65         CreateInt32Array(model, absl::StrCat(batch_name, "/slice/begin"),
66                          {batch_idx, 0, 0}),
67         CreateInt32Array(model, absl::StrCat(batch_name, "/slice/size"),
68                          {1, num_rows, num_cols})};
69     slice_op->outputs = {
70         AvailableArrayName(*model, absl::StrCat(batch_name, "/slice"))};
71     auto& slice_op_output = model->GetOrCreateArray(slice_op->outputs[0]);
72     slice_op_output.data_type = input_array.data_type;
73     *tail_it = model->operators.emplace(*tail_it, slice_op) + 1;
74 
75     // Reshape to rank-2: [1, num_rows, num_cols] -> [num_rows, num_cols].
76     auto* slice_reshape_op = new TensorFlowReshapeOperator;
77     slice_reshape_op->inputs = {
78         slice_op->outputs[0],
79         CreateInt32Array(model, absl::StrCat(batch_name, "/reshape/shape"),
80                          {num_rows, num_cols})};
81     slice_reshape_op->outputs = {
82         AvailableArrayName(*model, absl::StrCat(batch_name, "/reshape"))};
83     auto& slice_reshape_op_output =
84         model->GetOrCreateArray(slice_reshape_op->outputs[0]);
85     slice_reshape_op_output.data_type = input_array.data_type;
86     *tail_it = model->operators.emplace(*tail_it, slice_reshape_op) + 1;
87 
88     slice_outputs.push_back(slice_reshape_op->outputs[0]);
89   }
90   return slice_outputs;
91 }
92 
GetTransposePerm(const Array & input_array)93 std::vector<int32> GetTransposePerm(const Array& input_array) {
94   const int32 dims = input_array.shape().dimensions_count();
95   std::vector<int32> perm_array_val(dims);
96   for (int i = 0; i < dims; ++i) {
97     perm_array_val[i] = i;
98   }
99   perm_array_val[dims - 2] = dims - 1;
100   perm_array_val[dims - 1] = dims - 2;
101   return perm_array_val;
102 }
103 
GetTransposeShape(const Shape & input_shape,const std::vector<int32> & perm_array_val)104 std::vector<int32> GetTransposeShape(const Shape& input_shape,
105                                      const std::vector<int32>& perm_array_val) {
106   const int32 dims = input_shape.dimensions_count();
107   std::vector<int32> output_shape(dims);
108   for (int i = 0; i < dims; ++i) {
109     output_shape[i] = input_shape.dims(perm_array_val[i]);
110   }
111   return output_shape;
112 }
113 
TransposeInput(const std::string & input,Model * model)114 TransposeOperator* TransposeInput(const std::string& input, Model* model) {
115   const auto& input_array = model->GetArray(input);
116   const auto perm_array = GetTransposePerm(input_array);
117   const std::string perm_array_name = CreateInt32Array(
118       model, AvailableArrayName(*model, input + "/transpose/perm"), perm_array);
119   auto* transpose_op = new TransposeOperator;
120   transpose_op->inputs = {input, perm_array_name};
121   transpose_op->outputs = {AvailableArrayName(*model, input + "/transpose")};
122   auto& transpose_array = model->GetOrCreateArray(transpose_op->outputs[0]);
123   *transpose_array.mutable_shape()->mutable_dims() =
124       GetTransposeShape(input_array.shape(), perm_array);
125   model->GetOrCreateArray(transpose_op->outputs[0]);
126   return transpose_op;
127 }
128 
129 }  // namespace
130 
131 // Unrolls a BatchMatMul on the batch dimension.
132 // We need to slice each batch out of the inputs, matmul them individually, then
133 // stack them all back together at the end.
Run(Model * model,std::size_t op_index,bool * modified)134 ::tensorflow::Status UnrollBatchMatMul::Run(Model* model, std::size_t op_index,
135                                             bool* modified) {
136   *modified = false;
137   auto batch_op_it = model->operators.begin() + op_index;
138   if (batch_op_it->get()->type != OperatorType::kBatchMatMul) {
139     return ::tensorflow::Status::OK();
140   }
141   const auto* batch_op =
142       static_cast<const BatchMatMulOperator*>(batch_op_it->get());
143   auto& tail_it = batch_op_it;
144 
145   std::string input_lhs = batch_op->inputs[0];
146   std::string input_rhs = batch_op->inputs[1];
147   const auto& input_lhs_array = model->GetArray(input_lhs);
148   const auto& input_rhs_array = model->GetArray(input_rhs);
149   if (!input_lhs_array.has_shape() || !input_rhs_array.has_shape())
150     return ::tensorflow::Status::OK();
151 
152   // Transpose LHS input if necessary.
153   if (batch_op->adj_x) {
154     TransposeOperator* transpose_op = TransposeInput(input_lhs, model);
155     tail_it = model->operators.emplace(tail_it, transpose_op) + 1;
156     input_lhs = transpose_op->outputs[0];
157   }
158   const auto& input_array_a = model->GetArray(input_lhs);
159 
160   // Transpose RHS input if necessary.
161   if (batch_op->adj_y) {
162     TransposeOperator* transpose_op = TransposeInput(input_rhs, model);
163     tail_it = model->operators.emplace(tail_it, transpose_op) + 1;
164     input_rhs = transpose_op->outputs[0];
165   }
166   const auto& input_array_b = model->GetArray(input_rhs);
167 
168   // Ensure that input ranks are at least 2 and batch shapes are broadcastable.
169   const int dims_a = input_array_a.shape().dimensions_count();
170   const int dims_b = input_array_b.shape().dimensions_count();
171   CHECK_GE(dims_a, 2) << "First input must have rank >= 2";
172   CHECK_GE(dims_b, 2) << "Second input must have rank >= 2";
173 
174   ::tensorflow::MatMulBCast bcast(
175       ToInlinedVector(input_array_a.shape().dims()),
176       ToInlinedVector(input_array_b.shape().dims()));
177   CHECK(bcast.IsValid()) << "Input batch dimensions must be broadcastable";
178 
179   CHECK_EQ(input_array_a.shape().dims(dims_a - 1),
180            input_array_b.shape().dims(dims_b - 2))
181       << "Input dimensions must be compatible for multiplication. shape a = ["
182       << absl::StrJoin(input_array_a.shape().dims(), ", ") << "], shape b = ["
183       << absl::StrJoin(input_array_b.shape().dims(), ", ") << "]";
184 
185   if (dims_a == 2 && dims_b == 2) {
186     // This is really just a MatMul.
187     AddMessageF("Replacing non-batch BatchMatMul %s by a MatMul operator",
188                 LogName(*batch_op));
189     auto* matmul_op = new TensorFlowMatMulOperator;
190     matmul_op->inputs = {input_lhs, input_rhs};
191     matmul_op->outputs = batch_op->outputs;
192     model->operators.emplace(tail_it, matmul_op);
193     DeleteOpAndArrays(model, batch_op);
194     *modified = true;
195     return ::tensorflow::Status::OK();
196   }
197   AddMessageF("Unrolling BatchMatMul %s %d times", LogName(*batch_op),
198               bcast.output_batch_size());
199   std::string base_name = std::string(batch_op->outputs[0]);
200 
201   // Compute slices for each batch in the LHS and RHS.
202   std::vector<std::string> slice_a_outputs =
203       SliceInput(input_lhs, base_name, "a", bcast.x_batch_size(), input_array_a,
204                  model, &tail_it);
205   std::vector<std::string> slice_b_outputs =
206       SliceInput(input_rhs, base_name, "b", bcast.y_batch_size(), input_array_b,
207                  model, &tail_it);
208 
209   // Compute (single batch) MatMul for each output batch. The MatMul outputs are
210   // then packed together into one output Tensor.
211   std::vector<std::string> pack_inputs;
212   for (int batch_idx = 0; batch_idx < bcast.output_batch_size(); ++batch_idx) {
213     std::string batch_name =
214         absl::StrCat(batch_op->outputs[0], "_b", batch_idx);
215     const int a_batch_idx = bcast.IsBroadcastingRequired()
216                                 ? bcast.x_batch_indices()[batch_idx]
217                                 : batch_idx;
218     const int b_batch_idx = bcast.IsBroadcastingRequired()
219                                 ? bcast.y_batch_indices()[batch_idx]
220                                 : batch_idx;
221     auto* matmul_op = new TensorFlowMatMulOperator;
222     matmul_op->inputs = {slice_a_outputs[a_batch_idx],
223                          slice_b_outputs[b_batch_idx]};
224     matmul_op->outputs = {AvailableArrayName(*model, batch_name)};
225     auto& matmul_op_output = model->GetOrCreateArray(matmul_op->outputs[0]);
226     matmul_op_output.data_type = input_array_a.data_type;
227     tail_it = model->operators.emplace(tail_it, matmul_op) + 1;
228 
229     // Add to stack.
230     pack_inputs.push_back(matmul_op->outputs[0]);
231   }
232 
233   // Combine the result of each individual MatMul into a rank-3 Tensor.
234   auto* pack_op = new PackOperator;
235   pack_op->inputs = pack_inputs;
236   pack_op->outputs = {AvailableArrayName(*model, base_name + "/pack")};
237   auto& pack_op_output = model->GetOrCreateArray(pack_op->outputs[0]);
238   pack_op_output.data_type = input_array_a.data_type;
239   pack_op->axis = 0;
240   pack_op->values_count = pack_inputs.size();
241   tail_it = model->operators.emplace(tail_it, pack_op) + 1;
242 
243   // Reshape the rank-3 Tensor into the correct output shape.
244   const auto& result_batch_shape = bcast.output_batch_shape().dim_sizes();
245   std::vector<int> result_shape;
246   // Explicitly cast 64-bit sizes to int in order to avoid MSVC warnings.
247   std::transform(result_batch_shape.begin(), result_batch_shape.end(),
248                  std::back_inserter(result_shape),
249                  [](const int64 dim) { return static_cast<int>(dim); });
250   result_shape.push_back(input_array_a.shape().dims(dims_a - 2));
251   result_shape.push_back(input_array_b.shape().dims(dims_b - 1));
252 
253   auto* reshape_result_op = new TensorFlowReshapeOperator;
254   reshape_result_op->inputs = {
255       pack_op->outputs[0],
256       CreateInt32Array(model, base_name + "/reshape_out/shape", result_shape)};
257   reshape_result_op->outputs = {batch_op->outputs[0]};
258   model->operators.emplace(tail_it, reshape_result_op);
259 
260   DeleteOpAndArrays(model, batch_op);
261   *modified = true;
262   return ::tensorflow::Status::OK();
263 }
264 
265 }  // namespace toco
266