1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"
17 
18 #include "llvm/ADT/ArrayRef.h"
19 #include "llvm/ADT/None.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/ADT/StringRef.h"
22 #include "llvm/Support/Casting.h"
23 #include "llvm/Support/raw_ostream.h"
24 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
25 #include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
26 #include "mlir/IR/Attributes.h"  // from @llvm-project
27 #include "mlir/IR/Builders.h"  // from @llvm-project
28 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
29 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
30 #include "mlir/IR/Identifier.h"  // from @llvm-project
31 #include "mlir/IR/Location.h"  // from @llvm-project
32 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
33 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
34 #include "mlir/IR/Operation.h"  // from @llvm-project
35 #include "mlir/IR/Types.h"  // from @llvm-project
36 #include "mlir/IR/Value.h"  // from @llvm-project
37 #include "mlir/Support/LLVM.h"  // from @llvm-project
38 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
39 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
40 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
41 
42 namespace mlir {
43 namespace TFL {
44 
45 namespace {
46 
CreateI32SplatConst(OpBuilder * builder,ArrayRef<int64_t> shape,int32_t val,mlir::Location location)47 Value CreateI32SplatConst(OpBuilder* builder, ArrayRef<int64_t> shape,
48                           int32_t val, mlir::Location location) {
49   auto type = RankedTensorType::get(shape, builder->getIntegerType(32));
50   auto attr = DenseElementsAttr::get(type, val);
51   return builder->create<ConstantOp>(location, type, attr);
52 }
53 
CreateF32SplatConst(OpBuilder * builder,ArrayRef<int64_t> shape,float val,mlir::Location location)54 Value CreateF32SplatConst(OpBuilder* builder, ArrayRef<int64_t> shape,
55                           float val, mlir::Location location) {
56   auto type = RankedTensorType::get(shape, builder->getF32Type());
57   auto attr = DenseElementsAttr::get(type, val);
58   return builder->create<ConstantOp>(location, type, attr);
59 }
60 
CreatTfF32ConstOp(OpBuilder * builder,ArrayRef<int64_t> shape,float val,mlir::Location location)61 Value CreatTfF32ConstOp(OpBuilder* builder, ArrayRef<int64_t> shape, float val,
62                         mlir::Location location) {
63   auto type = RankedTensorType::get(shape, builder->getF32Type());
64   auto ele_type = RankedTensorType::get({1}, builder->getF32Type());
65   auto attr = DenseElementsAttr::get(ele_type, val);
66   return builder->create<TF::ConstOp>(location, type, attr);
67 }
68 
CreateI64DenseConst(OpBuilder * builder,ArrayRef<int64_t> shape,ArrayRef<int64_t> values,mlir::Location location)69 Value CreateI64DenseConst(OpBuilder* builder, ArrayRef<int64_t> shape,
70                           ArrayRef<int64_t> values, mlir::Location location) {
71   auto type = RankedTensorType::get(static_cast<int>(shape.size()),
72                                     builder->getIntegerType(64));
73   auto attr = DenseElementsAttr::get(type, values);
74   return builder->create<ConstantOp>(location, type, attr);
75 }
76 
CreateI32DenseConst(OpBuilder * builder,ArrayRef<int32_t> values,mlir::Location location)77 Value CreateI32DenseConst(OpBuilder* builder, ArrayRef<int32_t> values,
78                           mlir::Location location) {
79   auto type = RankedTensorType::get(static_cast<int>(values.size()),
80                                     builder->getIntegerType(32));
81   auto attr = DenseElementsAttr::get(type, values);
82   return builder->create<ConstantOp>(location, type, attr);
83 }
84 
CreateNoneValue(OpBuilder * builder,mlir::Location location)85 Value CreateNoneValue(OpBuilder* builder, mlir::Location location) {
86   return builder->create<mlir::ConstantOp>(location, builder->getNoneType(),
87                                            builder->getUnitAttr());
88 }
89 
Transpose(OpBuilder * builder,Value value_to_transpose,SmallVector<int32_t,4> perm,RankedTensorType original_type,mlir::Location location)90 Value Transpose(OpBuilder* builder, Value value_to_transpose,
91                 SmallVector<int32_t, 4> perm, RankedTensorType original_type,
92                 mlir::Location location) {
93   // Create a constant op for transpose permutation.
94   auto perm_op = CreateI32DenseConst(builder, perm, location);
95 
96   // Create tensor type for the transpose result.
97   auto transpose_type = original_type;
98   auto transpose_shape =
99       llvm::to_vector<8>(llvm::map_range(perm, [transpose_type](int32_t dim) {
100         return transpose_type.getDimSize(dim);
101       }));
102   auto elem_type = transpose_type.getElementType();
103   auto result_type = RankedTensorType::get(transpose_shape, elem_type);
104 
105   return builder->create<TF::TransposeOp>(location, result_type,
106                                           value_to_transpose, perm_op);
107 }
108 
Transpose2D(OpBuilder * builder,Value value_to_transpose,RankedTensorType type,mlir::Location location)109 Value Transpose2D(OpBuilder* builder, Value value_to_transpose,
110                   RankedTensorType type, mlir::Location location) {
111   // Create a constant op for transpose permutation.
112   SmallVector<int32_t, 4> perm = {1, 0};
113   return Transpose(builder, value_to_transpose, perm, type, location);
114 }
115 
Reverse(OpBuilder * builder,Value value_to_reverse,int axis,RankedTensorType type,mlir::Location location)116 Value Reverse(OpBuilder* builder, Value value_to_reverse, int axis,
117               RankedTensorType type, mlir::Location location) {
118   auto axis_op = CreateI32SplatConst(builder, {1}, axis, location);
119   // The result type will be the same as the input.
120   return builder->create<TF::ReverseV2Op>(location, type, value_to_reverse,
121                                           axis_op);
122 }
123 
GetRankedTensorShape(Value value)124 ArrayRef<int64_t> GetRankedTensorShape(Value value) {
125   return value.getType().cast<RankedTensorType>().getShape();
126 }
127 
SliceRankedTensor(OpBuilder * builder,Value input,ArrayRef<int64_t> begin_shape,ArrayRef<int64_t> begin_values,ArrayRef<int64_t> size_shape,ArrayRef<int64_t> size_values,mlir::Location location)128 Value SliceRankedTensor(OpBuilder* builder, Value input,
129                         ArrayRef<int64_t> begin_shape,
130                         ArrayRef<int64_t> begin_values,
131                         ArrayRef<int64_t> size_shape,
132                         ArrayRef<int64_t> size_values,
133                         mlir::Location location) {
134   // If the size of the tensor to be sliced from the input overflows
135   // the input tensor's dimensions, return 0-valued tensor of the requested
136   // shape.
137   ArrayRef<int64_t> input_shape = GetRankedTensorShape(input);
138   for (int i = 0, end = input_shape.size(); i < end; i++) {
139     if (begin_values[i] < 0 ||
140         (begin_values[i] + size_values[i] > input_shape[i])) {
141       return CreateF32SplatConst(builder, size_shape, 0, location);
142     }
143   }
144 
145   // Create a dense constant op for slice's begin
146   auto slice_i2c_begin =
147       CreateI64DenseConst(builder, begin_shape, begin_values, location);
148 
149   // Create a dense constant op for slice's size
150   auto slice_i2c_size =
151       CreateI64DenseConst(builder, size_shape, size_values, location);
152 
153   return builder->create<TF::SliceOp>(
154       location,
155       RankedTensorType::get(
156           size_values,
157           input.getType().cast<RankedTensorType>().getElementType()),
158       input, slice_i2c_begin, slice_i2c_size);
159 }
160 
CreateStridedSliceOp(mlir::Location loc,ArrayRef<int64_t> output_shape,Value input,ArrayRef<int32_t> begin,ArrayRef<int32_t> end,ArrayRef<int32_t> strides,int64_t begin_mask,int64_t end_mask,int64_t ellipsis_mask,int64_t new_axis_mask,int64_t shrink_axis_mask,OpBuilder * builder)161 Value CreateStridedSliceOp(mlir::Location loc, ArrayRef<int64_t> output_shape,
162                            Value input, ArrayRef<int32_t> begin,
163                            ArrayRef<int32_t> end, ArrayRef<int32_t> strides,
164                            int64_t begin_mask, int64_t end_mask,
165                            int64_t ellipsis_mask, int64_t new_axis_mask,
166                            int64_t shrink_axis_mask, OpBuilder* builder) {
167   auto output_type = RankedTensorType::get(
168       output_shape, input.getType().cast<RankedTensorType>().getElementType());
169   auto begin_tensor = CreateI32DenseConst(builder, begin, loc);
170   auto end_tensor = CreateI32DenseConst(builder, end, loc);
171   auto strides_tensor = CreateI32DenseConst(builder, strides, loc);
172 
173   return builder->create<TF::StridedSliceOp>(
174       loc, output_type, input, begin_tensor, end_tensor, strides_tensor,
175       builder->getI64IntegerAttr(begin_mask),
176       builder->getI64IntegerAttr(end_mask),
177       builder->getI64IntegerAttr(ellipsis_mask),
178       builder->getI64IntegerAttr(new_axis_mask),
179       builder->getI64IntegerAttr(shrink_axis_mask));
180 }
181 
182 }  // namespace
183 
SetWeightForInputToCellGate()184 void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToCellGate() {
185   SmallVector<int64_t, 2> begin_i2c_values = {0, 0};
186   input2cell_ = SliceRankedTensor(
187       &builder_, weight_transposed_, weight_slice_shape_, begin_i2c_values,
188       weight_slice_shape_, weight_slice_size_input_values_,
189       fused_func_op_.getLoc());
190 }
191 
SetWeightForInputToInputGate()192 void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToInputGate() {
193   SmallVector<int64_t, 2> begin_i2i_values = {n_cell_, 0};
194   input2input_ = couple_input_forget_gates_
195                      ? none_
196                      : SliceRankedTensor(&builder_, weight_transposed_,
197                                          weight_slice_shape_, begin_i2i_values,
198                                          weight_slice_shape_,
199                                          weight_slice_size_input_values_,
200                                          fused_func_op_.getLoc());
201 }
202 
SetWeightForInputToForgetGate()203 void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToForgetGate() {
204   int input_forget_start = couple_input_forget_gates_ ? n_cell_ : 2 * n_cell_;
205   SmallVector<int64_t, 2> begin_i2f_values = {input_forget_start, 0};
206   input2forget_ = SliceRankedTensor(
207       &builder_, weight_transposed_, weight_slice_shape_, begin_i2f_values,
208       weight_slice_shape_, weight_slice_size_input_values_,
209       fused_func_op_.getLoc());
210 }
211 
SetWeightForInputToOutputGate()212 void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToOutputGate() {
213   int input_output_start =
214       couple_input_forget_gates_ ? 2 * n_cell_ : 3 * n_cell_;
215   SmallVector<int64_t, 2> begin_i2o_values = {input_output_start, 0};
216   input2output_ = SliceRankedTensor(
217       &builder_, weight_transposed_, weight_slice_shape_, begin_i2o_values,
218       weight_slice_shape_, weight_slice_size_input_values_,
219       fused_func_op_.getLoc());
220 }
221 
SetWeightForRecurrentToCellGate()222 void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToCellGate() {
223   SmallVector<int64_t, 2> begin_rec2c_values = {0, n_input_};
224   rec2cell_ = SliceRankedTensor(
225       &builder_, weight_transposed_, weight_slice_shape_, begin_rec2c_values,
226       weight_slice_shape_, weight_slice_size_recurrent_values_,
227       fused_func_op_.getLoc());
228 }
229 
SetWeightForRecurrentToInputGate()230 void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToInputGate() {
231   SmallVector<int64_t, 2> begin_rec2i_values = {n_cell_, n_input_};
232   rec2input_ = couple_input_forget_gates_
233                    ? none_
234                    : SliceRankedTensor(&builder_, weight_transposed_,
235                                        weight_slice_shape_, begin_rec2i_values,
236                                        weight_slice_shape_,
237                                        weight_slice_size_recurrent_values_,
238                                        fused_func_op_.getLoc());
239 }
240 
SetWeightForRecurrentToForgetGate()241 void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToForgetGate() {
242   int rec_forget_start = couple_input_forget_gates_ ? n_cell_ : 2 * n_cell_;
243   SmallVector<int64_t, 2> begin_rec2f_values = {rec_forget_start, n_input_};
244   rec2forget_ = SliceRankedTensor(
245       &builder_, weight_transposed_, weight_slice_shape_, begin_rec2f_values,
246       weight_slice_shape_, weight_slice_size_recurrent_values_,
247       fused_func_op_.getLoc());
248 }
249 
SetWeightForRecurrentToOutputGate()250 void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToOutputGate() {
251   int rec_output_start = couple_input_forget_gates_ ? 2 * n_cell_ : 3 * n_cell_;
252   SmallVector<int64_t, 2> begin_rec2o_values = {rec_output_start, n_input_};
253   rec2output_ = SliceRankedTensor(
254       &builder_, weight_transposed_, weight_slice_shape_, begin_rec2o_values,
255       weight_slice_shape_, weight_slice_size_recurrent_values_,
256       fused_func_op_.getLoc());
257 }
258 
SetBiasToCellGate()259 void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToCellGate() {
260   SmallVector<int64_t, 1> begin_bias2c_values = {0};
261   bias2cell_ = SliceRankedTensor(&builder_, bias_, bias_slice_shape_,
262                                  begin_bias2c_values, bias_slice_shape_,
263                                  bias_size_values_, fused_func_op_.getLoc());
264 }
265 
SetBiasToInputGate()266 void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToInputGate() {
267   SmallVector<int64_t, 1> begin_bias2i_values = {n_cell_};
268   bias2input_ =
269       couple_input_forget_gates_
270           ? none_
271           : SliceRankedTensor(&builder_, bias_, bias_slice_shape_,
272                               begin_bias2i_values, bias_slice_shape_,
273                               bias_size_values_, fused_func_op_.getLoc());
274 }
275 
SetBiasToForgetGate()276 void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToForgetGate() {
277   int bias_forget_start = couple_input_forget_gates_ ? n_cell_ : 2 * n_cell_;
278   SmallVector<int64_t, 1> begin_bias2f_values = {bias_forget_start};
279   bias2forget_ = SliceRankedTensor(&builder_, bias_, bias_slice_shape_,
280                                    begin_bias2f_values, bias_slice_shape_,
281                                    bias_size_values_, fused_func_op_.getLoc());
282 }
283 
SetBiasToOutputGate()284 void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToOutputGate() {
285   int bias_output_start =
286       couple_input_forget_gates_ ? 2 * n_cell_ : 3 * n_cell_;
287   SmallVector<int64_t, 1> begin_bias2o_values = {bias_output_start};
288   bias2output_ = SliceRankedTensor(&builder_, bias_, bias_slice_shape_,
289                                    begin_bias2o_values, bias_slice_shape_,
290                                    bias_size_values_, fused_func_op_.getLoc());
291 }
292 
SetProjection()293 void ConvertLSTMCellSimpleToFusedLSTM::SetProjection() {
294   SmallVector<int64_t, 2> projection_slice_shape = {
295       1, num_cols_projection_transposed_};
296   SmallVector<int64_t, 2> projection_slice_size_values = {n_output_, n_cell_};
297   SmallVector<int64_t, 2> projection_slice_begin_values = {0, 0};
298   proj_weight_ =
299       !projection_
300           ? none_
301           : SliceRankedTensor(
302                 &builder_, projection_transposed_, projection_slice_shape,
303                 projection_slice_begin_values, projection_slice_shape,
304                 projection_slice_size_values, fused_func_op_.getLoc());
305 }
306 
SetProjectionBias()307 void ConvertLSTMCellSimpleToFusedLSTM::SetProjectionBias() {
308   proj_bias_ = !projection_type_
309                    ? none_
310                    : CreateF32SplatConst(&builder_, {n_output_}, 0,
311                                          fused_func_op_.getLoc());
312 }
313 
SetInputActivationState()314 void ConvertLSTMCellSimpleToFusedLSTM::SetInputActivationState() {
315   input_activation_state_ = CreateF32SplatConst(&builder_, {1, n_output_}, 0,
316                                                 fused_func_op_.getLoc());
317 }
318 
SetInputCellState()319 void ConvertLSTMCellSimpleToFusedLSTM::SetInputCellState() {
320   input_cell_state_ =
321       CreateF32SplatConst(&builder_, {1, n_cell_}, 0, fused_func_op_.getLoc());
322 }
323 
SetCellLayerNormCoefficients()324 void ConvertLSTMCellSimpleToFusedLSTM::SetCellLayerNormCoefficients() {
325   cell_layer_norm_coefficients_ = none_;
326 }
327 
SetInputLayerNormCoefficients()328 void ConvertLSTMCellSimpleToFusedLSTM::SetInputLayerNormCoefficients() {
329   input_layer_norm_coefficients_ = none_;
330 }
331 
SetForgetLayerNormCoefficients()332 void ConvertLSTMCellSimpleToFusedLSTM::SetForgetLayerNormCoefficients() {
333   forget_layer_norm_coefficients_ = none_;
334 }
SetOutputLayerNormCoefficients()335 void ConvertLSTMCellSimpleToFusedLSTM::SetOutputLayerNormCoefficients() {
336   output_layer_norm_coefficients_ = none_;
337 }
338 
GenerateFusedOpOperands()339 void ConvertLSTMCellSimpleToFusedLSTM::GenerateFusedOpOperands() {
340   // Transpose both weight and projection.
341   weight_transposed_ =
342       Transpose2D(&builder_, weight_, weight_type_, fused_func_op_.getLoc());
343   projection_transposed_ = Transpose2D(&builder_, projection_, projection_type_,
344                                        fused_func_op_.getLoc());
345 
346   none_ = CreateNoneValue(&builder_, fused_func_op_.getLoc());
347   // Extract input to cifg gates via slicing the weight tensor
348   SetWeightForInputToCellGate();
349   SetWeightForInputToInputGate();
350   SetWeightForInputToForgetGate();
351   SetWeightForInputToOutputGate();
352 
353   // Extract recurrent to cifg gates via slicing the weight tensor
354   SetWeightForRecurrentToCellGate();
355   SetWeightForRecurrentToInputGate();
356   SetWeightForRecurrentToForgetGate();
357   SetWeightForRecurrentToOutputGate();
358 
359   // Extract bias to cifg gates via slicing the bias tensor
360   SetBiasToCellGate();
361   SetBiasToInputGate();
362   SetBiasToForgetGate();
363   SetBiasToOutputGate();
364 
365   // Extract projection and set an empty projection bias
366   SetProjection();
367   SetProjectionBias();
368 
369   // Set the variable tensors
370   SetInputActivationState();
371   SetInputCellState();
372 
373   // Extract the layer norm coefficients
374   SetCellLayerNormCoefficients();
375   SetInputLayerNormCoefficients();
376   SetForgetLayerNormCoefficients();
377   SetOutputLayerNormCoefficients();
378 }
379 
UpdateFuncSignature()380 void ConvertLSTMCellSimpleToFusedLSTM::UpdateFuncSignature() {
381   // https://github.com/tensorflow/community/pull/113
382   SmallVector<int64_t, 2> output_shape{1, -1};
383   auto input_types = fused_func_op_.getType().getInputs();
384   auto output_type = mlir::RankedTensorType::get(
385       output_shape, input_.getType().cast<RankedTensorType>().getElementType());
386   fused_func_op_.setType(mlir::FunctionType::get(fused_func_op_.getContext(),
387                                                  input_types, output_type));
388 }
389 
RewriteFunc()390 LogicalResult ConvertLSTMCellSimpleToFusedLSTM::RewriteFunc() {
391   LogicalResult result = Initialize();
392   if (failed(result)) {
393     return result;
394   }
395 
396   // Update the func signature, based on output shape.
397   // The func will ultimately return the output of the fused
398   // LSTM op.
399   UpdateFuncSignature();
400 
401   // Transform the weights, projection, bias and layer norm coefficients
402   // to generate operands for the TFL fused LSTM op.
403   GenerateFusedOpOperands();
404 
405   // Create the fused LSTM op.
406   SmallVector<int64_t, 2> output_shape = {1, n_output_};
407   auto result_type = mlir::RankedTensorType::get(
408       output_shape, input_.getType().cast<RankedTensorType>().getElementType());
409   lstm_ = builder_.create<mlir::TFL::LSTMOp>(
410       fused_func_op_.getLoc(), result_type, input_, input2input_, input2forget_,
411       input2cell_, input2output_, rec2input_, rec2forget_, rec2cell_,
412       rec2output_, /*cell_to_input_weights*/ none_,
413       /*cell_to_forget_weights*/ none_,
414       /*cell_to_output_weights*/ none_, bias2input_, bias2forget_, bias2cell_,
415       bias2output_, proj_weight_, proj_bias_, input_activation_state_,
416       input_cell_state_, input_layer_norm_coefficients_,
417       forget_layer_norm_coefficients_, cell_layer_norm_coefficients_,
418       output_layer_norm_coefficients_, builder_.getStringAttr("TANH"),
419       builder_.getF32FloatAttr(10.0), builder_.getF32FloatAttr(0.0),
420       builder_.getStringAttr("FULL"),
421       /*input_to_input_intermediate=*/mlir::TypeAttr(),
422       /*input_to_forget_intermediate=*/mlir::TypeAttr(),
423       /*input_to_cell_intermediate=*/mlir::TypeAttr(),
424       /*input_to_output_intermediate=*/mlir::TypeAttr(),
425       /*effective_hidden_scale_intermediate=*/mlir::TypeAttr());
426 
427   // Cast the static shaped lstm result to FuncOp's signature -
428   // Ranked but unknown 2nd dimension to support stacking these.
429   SmallVector<int64_t, 2> func_output_shape = {1, -1};
430   auto func_result_type = mlir::RankedTensorType::get(
431       func_output_shape,
432       input_.getType().cast<RankedTensorType>().getElementType());
433 
434   auto tensor_cast = builder_.create<mlir::tensor::CastOp>(
435       fused_func_op_.getLoc(), func_result_type, lstm_.getResult());
436   builder_.create<mlir::ReturnOp>(fused_func_op_.getLoc(),
437                                   tensor_cast.getResult());
438   return success();
439 }
440 
InitializeFromFuncAttributes()441 LogicalResult ConvertLSTMCellSimpleToFusedLSTM::InitializeFromFuncAttributes() {
442   auto attr = fused_func_op_->getAttrOfType<StringAttr>(kTFImplements);
443   if (!attr) {
444     return fused_func_op_.emitError()
445            << "Invalid function attribute, expected " << kTFImplements
446            << " attribute "
447               "not found";
448   }
449 
450   // TODO(ashwinm, b/144775479): Make these NamedAttribute on TF import
451   // once tf.function can support this.
452   llvm::SmallVector<llvm::StringRef, 4> attr_tokens;
453   attr.getValue().split(attr_tokens, ",");
454   if (attr_tokens.empty()) {
455     return fused_func_op_.emitError()
456            << kTFImplements << " attribute should be set";
457   }
458 
459   // Check if the interface matches.
460   if (GetCompositeOpName().str() != attr_tokens[0]) {
461     return fused_func_op_.emitError()
462            << "Unexpected interface for the composite op. Expected: "
463            << GetCompositeOpName() << " Actual: " << attr_tokens[0];
464   }
465 
466   // Extract other interface attributes, for now cifg.
467   couple_input_forget_gates_ =
468       std::find(attr_tokens.begin() + 1, attr_tokens.end(),
469                 kCoupleInputForgetGates) != attr_tokens.end();
470 
471   return success();
472 }
473 
Initialize()474 LogicalResult ConvertLSTMCellSimpleToFusedLSTM::Initialize() {
475   if (failed(InitializeFromFuncAttributes())) {
476     return fused_func_op_.emitError()
477            << "Expected function attributes were not set on the function "
478               "encapsulating the composite op";
479   }
480 
481   num_gates_ = couple_input_forget_gates_ ? 3 : 4;
482 
483   input_ = fused_func_op_.getArgument(0);
484   bias_ = fused_func_op_.getArgument(2);
485 
486   weight_ = fused_func_op_.getArgument(1);
487   weight_type_ = weight_.getType().cast<RankedTensorType>();
488 
489   if (weight_type_.getRank() != 2) {
490     return fused_func_op_.emitError() << "The weight tensor was not of rank 2";
491   }
492 
493   if (weight_type_.getDimSize(1) % num_gates_ != 0) {
494     return fused_func_op_.emitError()
495            << "Invalid dimension 1 of weight tensor, "
496               "should be divisible by the number of gates";
497   }
498   n_cell_ = weight_type_.getDimSize(1) / num_gates_;
499 
500   projection_ = fused_func_op_.getArgument(3);
501   projection_type_ = projection_.getType().cast<RankedTensorType>();
502   if (projection_type_.getRank() != 2) {
503     n_output_ = n_cell_;
504   } else {
505     n_output_ = projection_type_.getDimSize(1);
506   }
507   n_input_ = weight_type_.getDimSize(0) - n_output_;
508   num_cols_weight_transposed_ = weight_type_.getDimSize(0);
509   num_cols_projection_transposed_ = projection_type_.getDimSize(0);
510 
511   bias_slice_shape_ = {n_cell_};
512   bias_size_values_ = {n_cell_};
513   weight_slice_shape_ = {1, num_cols_weight_transposed_};
514   weight_slice_size_input_values_ = {n_cell_, n_input_};
515   weight_slice_size_recurrent_values_ = {n_cell_, n_output_};
516 
517   return success();
518 }
519 
Initialize()520 LogicalResult ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::Initialize() {
521   if (failed(ConvertLSTMCellSimpleToFusedLSTM::Initialize())) {
522     return fused_func_op_.emitError()
523            << "Specified LayerNormalizedLSTMCellSimple was not of the expected "
524               "interface and cannot not be converted to the fused LSTM op";
525   }
526 
527   layer_norm_scale_ = fused_func_op_.getArgument(4);
528   layer_norm_scale_type_ = layer_norm_scale_.getType().cast<RankedTensorType>();
529   if (layer_norm_scale_type_.getRank() != 1) {
530     return fused_func_op_.emitError()
531            << "The layer_norm_scale tensor was not of rank 1";
532   }
533   layer_norm_slice_shape_ = {n_cell_};
534   layer_norm_size_values_ = {n_cell_};
535 
536   return success();
537 }
538 
539 void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::
SetCellLayerNormCoefficients()540     SetCellLayerNormCoefficients() {
541   SmallVector<int64_t, 1> begin_cell_layer_norm_values = {0};
542   cell_layer_norm_coefficients_ =
543       SliceRankedTensor(&builder_, layer_norm_scale_, layer_norm_slice_shape_,
544                         begin_cell_layer_norm_values, layer_norm_slice_shape_,
545                         layer_norm_size_values_, fused_func_op_.getLoc());
546 }
547 
548 void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::
SetInputLayerNormCoefficients()549     SetInputLayerNormCoefficients() {
550   SmallVector<int64_t, 1> begin_input_layer_norm_values = {n_cell_};
551   input_layer_norm_coefficients_ =
552       couple_input_forget_gates_
553           ? none_
554           : SliceRankedTensor(
555                 &builder_, layer_norm_scale_, layer_norm_slice_shape_,
556                 begin_input_layer_norm_values, layer_norm_slice_shape_,
557                 layer_norm_size_values_, fused_func_op_.getLoc());
558 }
559 
560 void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::
SetForgetLayerNormCoefficients()561     SetForgetLayerNormCoefficients() {
562   SmallVector<int64_t, 1> begin_forget_layer_norm_values = {2 * n_cell_};
563   forget_layer_norm_coefficients_ =
564       SliceRankedTensor(&builder_, layer_norm_scale_, layer_norm_slice_shape_,
565                         begin_forget_layer_norm_values, layer_norm_slice_shape_,
566                         layer_norm_size_values_, fused_func_op_.getLoc());
567 }
568 
569 void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::
SetOutputLayerNormCoefficients()570     SetOutputLayerNormCoefficients() {
571   SmallVector<int64_t, 1> begin_output_layer_norm_values = {3 * n_cell_};
572   output_layer_norm_coefficients_ =
573       SliceRankedTensor(&builder_, layer_norm_scale_, layer_norm_slice_shape_,
574                         begin_output_layer_norm_values, layer_norm_slice_shape_,
575                         layer_norm_size_values_, fused_func_op_.getLoc());
576 }
577 
Create1DConstantOp(const std::vector<int> & value,Location loc,OpBuilder * builder)578 TF::ConstOp Create1DConstantOp(const std::vector<int>& value, Location loc,
579                                OpBuilder* builder) {
580   auto type =
581       mlir::RankedTensorType::get(value.size(), builder->getIntegerType(32));
582   auto dense_values = mlir::DenseIntElementsAttr::get(type, value);
583   return builder->create<TF::ConstOp>(loc, dense_values);
584 }
585 
CreateScalarConstantOp(int value,Location loc,OpBuilder * builder)586 TF::ConstOp CreateScalarConstantOp(int value, Location loc,
587                                    OpBuilder* builder) {
588   return builder->create<TF::ConstOp>(loc, builder->getI32IntegerAttr(value));
589 }
590 
CreateEqualSizeSplitVOp(Value input,int axis,int splits,Location loc,OpBuilder * builder,Operation ** result)591 LogicalResult CreateEqualSizeSplitVOp(Value input, int axis, int splits,
592                                       Location loc, OpBuilder* builder,
593                                       Operation** result) {
594   auto input_type = input.getType().cast<RankedTensorType>();
595   SmallVector<int64_t, 4> output_shape;
596   int size_of_splits;
597   if (input_type.getRank() < axis || axis < 0) return failure();
598   for (int i = 0; i < input_type.getRank(); ++i) {
599     int dim = input_type.getDimSize(i);
600     if (i == axis) {
601       if (dim % splits != 0) {
602         return failure();
603       }
604       size_of_splits = dim / splits;
605       output_shape.push_back(size_of_splits);
606     } else {
607       output_shape.push_back(dim);
608     }
609   }
610 
611   SmallVector<mlir::Type, 4> output_types;
612   for (int i = 0; i < splits; ++i) {
613     output_types.push_back(
614         mlir::RankedTensorType::get(output_shape, input_type.getElementType()));
615   }
616   auto size_of_splits_op = Create1DConstantOp(
617       {size_of_splits, size_of_splits, size_of_splits, size_of_splits}, loc,
618       builder);
619 
620   auto axis_op = CreateScalarConstantOp(axis, loc, builder);
621   *result = builder->create<TF::SplitVOp>(loc, output_types, input,
622                                           size_of_splits_op.getResult(),
623                                           axis_op.getResult());
624   return success();
625 }
626 
627 // TODO(b/147436982): Consider refactor this to be more general.
ConvertKerasLSTMLayer(mlir::FuncOp func_op,OpBuilder * builder)628 LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder) {
629   // For argument order, please check out standard_lstm under
630   // tensorflow/python/keras/layers/recurrent_v2.py
631   Value input = func_op.getArgument(0);
632   Value output_init_state = func_op.getArgument(1);
633   Value hidden_init_state = func_op.getArgument(2);
634   Value weight_kernel = func_op.getArgument(3);
635   Value recurrent_kernel = func_op.getArgument(4);
636   Value bias = func_op.getArgument(5);
637 
638   // The func op should have 5 outputs.
639   if (func_op.getNumResults() != 5) return failure();
640 
641   // TFL lstm only supports time-majored inputs, so if it's not time-majored,
642   // we will transpose the inputs and outputs.
643   auto time_major_attr = func_op->getAttrOfType<BoolAttr>("tf.time_major");
644   if (time_major_attr == nullptr) return failure();
645 
646   bool time_majored = time_major_attr.getValue();
647   auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
648   if (!input_type) {
649     func_op.emitError() << "Input type is not a ranked tensor type";
650     return failure();
651   }
652 
653   auto final_inputs = input;
654   auto final_input_type = input_type;
655 
656   // Handle go_backwards:
657   // LSTM in Keras semantic will reverse the input sequence if it's go_backwards
658   auto go_backwards_attr = func_op->getAttrOfType<BoolAttr>("tf.go_backwards");
659 
660   if (go_backwards_attr != nullptr && go_backwards_attr.getValue()) {
661     int time_dim = time_majored ? 0 : 1;
662     final_inputs = Reverse(builder, final_inputs, time_dim, final_input_type,
663                            func_op.getLoc());
664   }
665 
666   int batch = time_majored ? final_input_type.getDimSize(1)
667                            : final_input_type.getDimSize(0);
668   int time = time_majored ? final_input_type.getDimSize(0)
669                           : final_input_type.getDimSize(1);
670 
671   // Setup correct weights.
672   RankedTensorType weight_type =
673       weight_kernel.getType().cast<RankedTensorType>();
674   if (weight_type.getRank() != 2)
675     return func_op.emitError() << "The weight should be rank of 2";
676 
677   Value transposed_weight_kernel =
678       Transpose2D(builder, weight_kernel, weight_type, func_op.getLoc());
679 
680   RankedTensorType recurrent_kernel_type =
681       recurrent_kernel.getType().cast<RankedTensorType>();
682   const int n_output = recurrent_kernel_type.getDimSize(0);
683 
684   Value transpose_recurrent_kernel = Transpose2D(
685       builder, recurrent_kernel, recurrent_kernel_type, func_op.getLoc());
686 
687   // Splits the weights into 4: i, f, c, o.
688   const int splits = 4;
689 
690   Operation* weights_array;
691   if (failed(CreateEqualSizeSplitVOp(transposed_weight_kernel, 0, splits,
692                                      func_op.getLoc(), builder,
693                                      &weights_array)))
694     return failure();
695 
696   // Splits the recurrent_weights into 4:
697   Operation* recurrent_weights_array;
698   if (failed(CreateEqualSizeSplitVOp(transpose_recurrent_kernel, 0, splits,
699                                      func_op.getLoc(), builder,
700                                      &recurrent_weights_array)))
701     return failure();
702 
703   // Splits the bias into 4:
704   Operation* bias_array;
705   if (failed(CreateEqualSizeSplitVOp(bias, 0, splits, func_op.getLoc(), builder,
706                                      &bias_array)))
707     return failure();
708 
709   // Build the lstm op.
710   SmallVector<int64_t, 3> output_shape;
711   if (time_majored) {
712     output_shape = {time, batch, n_output};
713   } else {
714     output_shape = {batch, time, n_output};
715   }
716   auto result_type = mlir::RankedTensorType::get(
717       output_shape,
718       final_inputs.getType().cast<RankedTensorType>().getElementType());
719 
720   Value none = builder->create<mlir::ConstantOp>(
721       func_op.getLoc(), builder->getNoneType(), builder->getUnitAttr());
722   auto lstm = builder->create<mlir::TFL::UnidirectionalSequenceLSTMOp>(
723       func_op.getLoc(), result_type, /*input=*/final_inputs,
724       /*input_to_input_weights=*/weights_array->getResult(0),
725       /*input_to_forget_weights=*/weights_array->getResult(1),
726       /*input_to_cell_weights=*/weights_array->getResult(2),
727       /*input_to_output_weights=*/weights_array->getResult(3),
728       /*recurrent_to_input_weights=*/recurrent_weights_array->getResult(0),
729       /*recurrent_to_forget_weights=*/recurrent_weights_array->getResult(1),
730       /*recurrent_to_cell_weights=*/recurrent_weights_array->getResult(2),
731       /*recurrent_to_output_weights=*/recurrent_weights_array->getResult(3),
732       /*cell_to_input_weights=*/none,
733       /*cell_to_forget_weights=*/none,
734       /*cell_to_output_weights=*/none,
735       /*input_gate_bias=*/bias_array->getResult(0),
736       /*forget_gate_bias=*/bias_array->getResult(1),
737       /*cell_bias=*/bias_array->getResult(2),
738       /*output_gate_bias=*/bias_array->getResult(3),
739       /*projection_weights=*/none,
740       /*projection_bias=*/none,
741       /*input_activation_state=*/output_init_state,
742       /*input_cell_state=*/hidden_init_state,
743       /*input_layer_norm_coefficients=*/none,
744       /*forget_layer_norm_coefficients=*/none,
745       /*cell_layer_norm_coefficients=*/none,
746       /*output_layer_norm_coefficients=*/none, builder->getStringAttr("TANH"),
747       builder->getF32FloatAttr(10.0), builder->getF32FloatAttr(0.0),
748       builder->getBoolAttr(time_majored),
749       /*input_to_input_intermediate=*/mlir::TypeAttr(),
750       /*input_to_forget_intermediate=*/mlir::TypeAttr(),
751       /*input_to_cell_intermediate=*/mlir::TypeAttr(),
752       /*input_to_output_intermediate=*/mlir::TypeAttr(),
753       /*effective_hidden_scale_intermediate=*/mlir::TypeAttr());
754 
755   auto final_output_full_sequences = lstm.getResult();
756 
757   // Populate the last output: last output is sliced from the full sequences.
758   // If time_major: last_output = outputs[-1, :, :]
759   // else: last_output = outputs[:, -1, :]
760   //
761   // As we are creating the strided_slice op, we need to populate the following
762   // fields:
763   // end: should always be (0, 0, 0)
764   // strides: should always be (1, 1, 1)
765   // begin: should be (0, -1, 0) or (-1, 0, 0) if it's time-majored.
766   // new_axis_mask: should always be 0.
767   // ellipsis_mask: should always be 0.
768   // begin_mask & end_mask: should be 0b101 = 5 or 0b110 = 4 if it's
769   // time-majored. shrink_axis_mask: should be 0b010 = 2 or 0b001 = 1 if it's
770   // time-majored.
771   SmallVector<int64_t, 2> last_output_shape({batch, n_output});
772 
773   SmallVector<int32_t, 3> end({0, 0, 0});
774   SmallVector<int32_t, 3> strides({1, 1, 1});
775   SmallVector<int32_t, 3> begin;
776 
777   int64_t new_axis_mask = 0;
778   int64_t ellipsis_mask = 0;
779   int64_t begin_mask;
780   int64_t end_mask;
781   int64_t shrink_axis_mask;
782   if (time_majored) {
783     begin_mask = 6;
784     end_mask = 6;
785     shrink_axis_mask = 1;
786     begin = {-1, 0, 0};
787   } else {
788     begin_mask = 5;
789     end_mask = 5;
790     shrink_axis_mask = 2;
791     begin = {0, -1, 0};
792   }
793 
794   auto last_output = CreateStridedSliceOp(
795       func_op.getLoc(), last_output_shape, final_output_full_sequences, begin,
796       end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask,
797       shrink_axis_mask, builder);
798 
799   SmallVector<Value, 5> outputs;
800   SmallVector<Type, 5> output_types;
801 
802   // Due to the existence of the while loop, the timestamp may be unknown
803   // for the signature, for us, since we know the inputs, we can infer the time
804   // steps.
805 
806   // Last output.
807   outputs.push_back(last_output);
808   output_types.push_back(last_output.getType());
809 
810   // Full sequences.
811   outputs.push_back(final_output_full_sequences);
812   output_types.push_back(final_output_full_sequences.getType());
813 
814   // All the rest: states, device.
815   for (int i = 2; i < 5; ++i) {
816     auto result_type =
817         func_op.getCallableResults()[i].dyn_cast<RankedTensorType>();
818     outputs.push_back(CreatTfF32ConstOp(builder, result_type.getShape(), 0.0f,
819                                         func_op.getLoc()));
820     output_types.push_back(result_type);
821   }
822 
823   // Update function signatures.
824   func_op.setType(mlir::FunctionType::get(
825       func_op.getContext(), func_op.getType().getInputs(), output_types));
826 
827   builder->create<mlir::ReturnOp>(func_op.getLoc(), outputs);
828   return success();
829 }
830 
831 }  // namespace TFL
832 }  // namespace mlir
833