/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <memory>
#include <numeric>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/matmul_bcast.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tooling_util.h"
29
30 namespace toco {
31 namespace {
32
ToInlinedVector(const std::vector<int> & vec)33 absl::InlinedVector<int64, 4> ToInlinedVector(const std::vector<int>& vec) {
34 return absl::InlinedVector<int64, 4>(vec.begin(), vec.end());
35 }
36
SliceInput(const std::string & input,const std::string & base_name,const std::string & input_name,const int batch_size,const Array & input_array,Model * model,std::vector<std::unique_ptr<Operator>>::iterator * tail_it)37 std::vector<std::string> SliceInput(
38 const std::string& input, const std::string& base_name,
39 const std::string& input_name, const int batch_size,
40 const Array& input_array, Model* model,
41 std::vector<std::unique_ptr<Operator>>::iterator* tail_it) {
42 int rank = input_array.shape().dimensions_count();
43 int num_rows = input_array.shape().dims(rank - 2);
44 int num_cols = input_array.shape().dims(rank - 1);
45 // Reshape to rank-3 Tensor with first dimension as the batch size.
46 auto* reshape_op = new TensorFlowReshapeOperator;
47 reshape_op->inputs = {
48 input,
49 CreateInt32Array(model, absl::StrCat(base_name, "/reshape_a/shape"),
50 {batch_size, num_rows, num_cols})};
51 reshape_op->outputs = {AvailableArrayName(
52 *model, absl::StrCat(base_name, "/reshape_", input_name, "/reshape"))};
53 auto& reshape_op_output = model->GetOrCreateArray(reshape_op->outputs[0]);
54 reshape_op_output.data_type = input_array.data_type;
55 *tail_it = model->operators.emplace(*tail_it, reshape_op) + 1;
56
57 // Slice along each batch index and remember the slice output for future use.
58 std::vector<std::string> slice_outputs;
59 for (int batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
60 std::string batch_name =
61 absl::StrCat(base_name, "_b", batch_idx, "/slice_", input_name);
62 auto* slice_op = new SliceOperator;
63 slice_op->inputs = {
64 reshape_op->outputs[0],
65 CreateInt32Array(model, absl::StrCat(batch_name, "/slice/begin"),
66 {batch_idx, 0, 0}),
67 CreateInt32Array(model, absl::StrCat(batch_name, "/slice/size"),
68 {1, num_rows, num_cols})};
69 slice_op->outputs = {
70 AvailableArrayName(*model, absl::StrCat(batch_name, "/slice"))};
71 auto& slice_op_output = model->GetOrCreateArray(slice_op->outputs[0]);
72 slice_op_output.data_type = input_array.data_type;
73 *tail_it = model->operators.emplace(*tail_it, slice_op) + 1;
74
75 // Reshape to rank-2: [1, num_rows, num_cols] -> [num_rows, num_cols].
76 auto* slice_reshape_op = new TensorFlowReshapeOperator;
77 slice_reshape_op->inputs = {
78 slice_op->outputs[0],
79 CreateInt32Array(model, absl::StrCat(batch_name, "/reshape/shape"),
80 {num_rows, num_cols})};
81 slice_reshape_op->outputs = {
82 AvailableArrayName(*model, absl::StrCat(batch_name, "/reshape"))};
83 auto& slice_reshape_op_output =
84 model->GetOrCreateArray(slice_reshape_op->outputs[0]);
85 slice_reshape_op_output.data_type = input_array.data_type;
86 *tail_it = model->operators.emplace(*tail_it, slice_reshape_op) + 1;
87
88 slice_outputs.push_back(slice_reshape_op->outputs[0]);
89 }
90 return slice_outputs;
91 }
92
GetTransposePerm(const Array & input_array)93 std::vector<int32> GetTransposePerm(const Array& input_array) {
94 const int32 dims = input_array.shape().dimensions_count();
95 std::vector<int32> perm_array_val(dims);
96 for (int i = 0; i < dims; ++i) {
97 perm_array_val[i] = i;
98 }
99 perm_array_val[dims - 2] = dims - 1;
100 perm_array_val[dims - 1] = dims - 2;
101 return perm_array_val;
102 }
103
GetTransposeShape(const Shape & input_shape,const std::vector<int32> & perm_array_val)104 std::vector<int32> GetTransposeShape(const Shape& input_shape,
105 const std::vector<int32>& perm_array_val) {
106 const int32 dims = input_shape.dimensions_count();
107 std::vector<int32> output_shape(dims);
108 for (int i = 0; i < dims; ++i) {
109 output_shape[i] = input_shape.dims(perm_array_val[i]);
110 }
111 return output_shape;
112 }
113
TransposeInput(const std::string & input,Model * model)114 TransposeOperator* TransposeInput(const std::string& input, Model* model) {
115 const auto& input_array = model->GetArray(input);
116 const auto perm_array = GetTransposePerm(input_array);
117 const std::string perm_array_name = CreateInt32Array(
118 model, AvailableArrayName(*model, input + "/transpose/perm"), perm_array);
119 auto* transpose_op = new TransposeOperator;
120 transpose_op->inputs = {input, perm_array_name};
121 transpose_op->outputs = {AvailableArrayName(*model, input + "/transpose")};
122 auto& transpose_array = model->GetOrCreateArray(transpose_op->outputs[0]);
123 *transpose_array.mutable_shape()->mutable_dims() =
124 GetTransposeShape(input_array.shape(), perm_array);
125 model->GetOrCreateArray(transpose_op->outputs[0]);
126 return transpose_op;
127 }
128
129 } // namespace
130
131 // Unrolls a BatchMatMul on the batch dimension.
132 // We need to slice each batch out of the inputs, matmul them individually, then
133 // stack them all back together at the end.
Run(Model * model,std::size_t op_index,bool * modified)134 ::tensorflow::Status UnrollBatchMatMul::Run(Model* model, std::size_t op_index,
135 bool* modified) {
136 *modified = false;
137 auto batch_op_it = model->operators.begin() + op_index;
138 if (batch_op_it->get()->type != OperatorType::kBatchMatMul) {
139 return ::tensorflow::Status::OK();
140 }
141 const auto* batch_op =
142 static_cast<const BatchMatMulOperator*>(batch_op_it->get());
143 auto& tail_it = batch_op_it;
144
145 std::string input_lhs = batch_op->inputs[0];
146 std::string input_rhs = batch_op->inputs[1];
147 const auto& input_lhs_array = model->GetArray(input_lhs);
148 const auto& input_rhs_array = model->GetArray(input_rhs);
149 if (!input_lhs_array.has_shape() || !input_rhs_array.has_shape())
150 return ::tensorflow::Status::OK();
151
152 // Transpose LHS input if necessary.
153 if (batch_op->adj_x) {
154 TransposeOperator* transpose_op = TransposeInput(input_lhs, model);
155 tail_it = model->operators.emplace(tail_it, transpose_op) + 1;
156 input_lhs = transpose_op->outputs[0];
157 }
158 const auto& input_array_a = model->GetArray(input_lhs);
159
160 // Transpose RHS input if necessary.
161 if (batch_op->adj_y) {
162 TransposeOperator* transpose_op = TransposeInput(input_rhs, model);
163 tail_it = model->operators.emplace(tail_it, transpose_op) + 1;
164 input_rhs = transpose_op->outputs[0];
165 }
166 const auto& input_array_b = model->GetArray(input_rhs);
167
168 // Ensure that input ranks are at least 2 and batch shapes are broadcastable.
169 const int dims_a = input_array_a.shape().dimensions_count();
170 const int dims_b = input_array_b.shape().dimensions_count();
171 CHECK_GE(dims_a, 2) << "First input must have rank >= 2";
172 CHECK_GE(dims_b, 2) << "Second input must have rank >= 2";
173
174 ::tensorflow::MatMulBCast bcast(
175 ToInlinedVector(input_array_a.shape().dims()),
176 ToInlinedVector(input_array_b.shape().dims()));
177 CHECK(bcast.IsValid()) << "Input batch dimensions must be broadcastable";
178
179 CHECK_EQ(input_array_a.shape().dims(dims_a - 1),
180 input_array_b.shape().dims(dims_b - 2))
181 << "Input dimensions must be compatible for multiplication. shape a = ["
182 << absl::StrJoin(input_array_a.shape().dims(), ", ") << "], shape b = ["
183 << absl::StrJoin(input_array_b.shape().dims(), ", ") << "]";
184
185 if (dims_a == 2 && dims_b == 2) {
186 // This is really just a MatMul.
187 AddMessageF("Replacing non-batch BatchMatMul %s by a MatMul operator",
188 LogName(*batch_op));
189 auto* matmul_op = new TensorFlowMatMulOperator;
190 matmul_op->inputs = {input_lhs, input_rhs};
191 matmul_op->outputs = batch_op->outputs;
192 model->operators.emplace(tail_it, matmul_op);
193 DeleteOpAndArrays(model, batch_op);
194 *modified = true;
195 return ::tensorflow::Status::OK();
196 }
197 AddMessageF("Unrolling BatchMatMul %s %d times", LogName(*batch_op),
198 bcast.output_batch_size());
199 std::string base_name = std::string(batch_op->outputs[0]);
200
201 // Compute slices for each batch in the LHS and RHS.
202 std::vector<std::string> slice_a_outputs =
203 SliceInput(input_lhs, base_name, "a", bcast.x_batch_size(), input_array_a,
204 model, &tail_it);
205 std::vector<std::string> slice_b_outputs =
206 SliceInput(input_rhs, base_name, "b", bcast.y_batch_size(), input_array_b,
207 model, &tail_it);
208
209 // Compute (single batch) MatMul for each output batch. The MatMul outputs are
210 // then packed together into one output Tensor.
211 std::vector<std::string> pack_inputs;
212 for (int batch_idx = 0; batch_idx < bcast.output_batch_size(); ++batch_idx) {
213 std::string batch_name =
214 absl::StrCat(batch_op->outputs[0], "_b", batch_idx);
215 const int a_batch_idx = bcast.IsBroadcastingRequired()
216 ? bcast.x_batch_indices()[batch_idx]
217 : batch_idx;
218 const int b_batch_idx = bcast.IsBroadcastingRequired()
219 ? bcast.y_batch_indices()[batch_idx]
220 : batch_idx;
221 auto* matmul_op = new TensorFlowMatMulOperator;
222 matmul_op->inputs = {slice_a_outputs[a_batch_idx],
223 slice_b_outputs[b_batch_idx]};
224 matmul_op->outputs = {AvailableArrayName(*model, batch_name)};
225 auto& matmul_op_output = model->GetOrCreateArray(matmul_op->outputs[0]);
226 matmul_op_output.data_type = input_array_a.data_type;
227 tail_it = model->operators.emplace(tail_it, matmul_op) + 1;
228
229 // Add to stack.
230 pack_inputs.push_back(matmul_op->outputs[0]);
231 }
232
233 // Combine the result of each individual MatMul into a rank-3 Tensor.
234 auto* pack_op = new PackOperator;
235 pack_op->inputs = pack_inputs;
236 pack_op->outputs = {AvailableArrayName(*model, base_name + "/pack")};
237 auto& pack_op_output = model->GetOrCreateArray(pack_op->outputs[0]);
238 pack_op_output.data_type = input_array_a.data_type;
239 pack_op->axis = 0;
240 pack_op->values_count = pack_inputs.size();
241 tail_it = model->operators.emplace(tail_it, pack_op) + 1;
242
243 // Reshape the rank-3 Tensor into the correct output shape.
244 const auto& result_batch_shape = bcast.output_batch_shape().dim_sizes();
245 std::vector<int> result_shape;
246 // Explicitly cast 64-bit sizes to int in order to avoid MSVC warnings.
247 std::transform(result_batch_shape.begin(), result_batch_shape.end(),
248 std::back_inserter(result_shape),
249 [](const int64 dim) { return static_cast<int>(dim); });
250 result_shape.push_back(input_array_a.shape().dims(dims_a - 2));
251 result_shape.push_back(input_array_b.shape().dims(dims_b - 1));
252
253 auto* reshape_result_op = new TensorFlowReshapeOperator;
254 reshape_result_op->inputs = {
255 pack_op->outputs[0],
256 CreateInt32Array(model, base_name + "/reshape_out/shape", result_shape)};
257 reshape_result_op->outputs = {batch_op->outputs[0]};
258 model->operators.emplace(tail_it, reshape_result_op);
259
260 DeleteOpAndArrays(model, batch_op);
261 *modified = true;
262 return ::tensorflow::Status::OK();
263 }
264
265 } // namespace toco
266