1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
17 
18 #include "llvm/ADT/ArrayRef.h"
19 #include "llvm/ADT/DenseMap.h"
20 #include "llvm/ADT/Optional.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/Support/Casting.h"
24 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
25 #include "mlir/IR/Attributes.h"  // from @llvm-project
26 #include "mlir/IR/Builders.h"  // from @llvm-project
27 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
28 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
29 #include "mlir/IR/Location.h"  // from @llvm-project
30 #include "mlir/IR/Operation.h"  // from @llvm-project
31 #include "mlir/Pass/Pass.h"  // from @llvm-project
32 #include "mlir/Support/LLVM.h"  // from @llvm-project
33 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
34 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
35 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
36 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
37 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
38 
39 namespace mlir {
40 namespace TF {
41 namespace collection_ops_util {
42 
CreateScalarConst(int32_t value,OpBuilder builder,Location loc)43 Value CreateScalarConst(int32_t value, OpBuilder builder, Location loc) {
44   auto attr = DenseIntElementsAttr::get(
45       RankedTensorType::get({}, builder.getI32Type()), value);
46   return builder.create<TF::ConstOp>(loc, attr);
47 }
48 
GetR1Const(ArrayRef<int64_t> r1,OpBuilder builder,Location loc,int bitwidth)49 Value GetR1Const(ArrayRef<int64_t> r1, OpBuilder builder, Location loc,
50                  int bitwidth) {
51   llvm::SmallVector<APInt, 4> values;
52   int64_t rank = r1.size();
53   values.reserve(rank);
54   for (int i = 0; i < rank; ++i) values.push_back(APInt(bitwidth, r1[i]));
55   auto result_type = RankedTensorType::get(
56       {rank}, IntegerType::get(builder.getContext(), bitwidth));
57   return builder.create<TF::ConstOp>(
58       loc, DenseElementsAttr::get(result_type, values));
59 }
60 
GetIndicesForElement(Value index,Value buffer,OpBuilder builder,Location loc)61 Value GetIndicesForElement(Value index, Value buffer, OpBuilder builder,
62                            Location loc) {
63   auto buffer_type = buffer.getType().cast<RankedTensorType>();
64   if (buffer_type.getShape().size() == 1) return index;
65   // Create a concat of index and trailing zeros.
66   llvm::SmallVector<int64_t, 8> zeros(buffer_type.getShape().size() - 1, 0);
67   auto zeros_tensor = GetR1Const(zeros, builder, loc);
68   return builder.create<TF::ConcatV2Op>(
69       loc,
70       ArrayRef<Type>{RankedTensorType::get(
71           {static_cast<int64_t>(buffer_type.getShape().size())},
72           getElementTypeOrSelf(index.getType()))},
73       ArrayRef<Value>{index, zeros_tensor, CreateScalarConst(0, builder, loc)});
74 }
75 
GetElement(Value index,Value buffer,OpBuilder builder,Location loc,bool keep_slice_shape)76 Value GetElement(Value index, Value buffer, OpBuilder builder, Location loc,
77                  bool keep_slice_shape) {
78   auto buffer_type = buffer.getType().cast<RankedTensorType>();
79   // Create a slice then reshape to remove the leading trivial dimension of
80   // size 1.
81   llvm::SmallVector<int64_t, 8> slice_size =
82       llvm::to_vector<8>(buffer_type.getShape());
83   slice_size[0] = 1;
84   auto size_const = GetR1Const(slice_size, builder, loc);
85   auto slice_type =
86       RankedTensorType::get(slice_size, buffer_type.getElementType());
87   auto slice = builder.create<TF::SliceOp>(
88       loc, ArrayRef<Type>{slice_type},
89       ArrayRef<Value>{buffer, GetIndicesForElement(index, buffer, builder, loc),
90                       size_const});
91   if (keep_slice_shape) return slice;
92   auto element_type = RankedTensorType::get(buffer_type.getShape().drop_front(),
93                                             buffer_type.getElementType());
94   auto reshape = builder.create<TF::ReshapeOp>(
95       loc, ArrayRef<Type>{element_type},
96       ArrayRef<Value>{slice,
97                       GetR1Const(element_type.getShape(), builder, loc)});
98   return reshape.output();
99 }
100 
SetElement(Value index,Value buffer,Value element,OpBuilder builder,Location loc)101 Value SetElement(Value index, Value buffer, Value element, OpBuilder builder,
102                  Location loc) {
103   auto buffer_type = buffer.getType().cast<RankedTensorType>();
104   // Reshape the element to add a leading dimension of size 1 if th element does
105   // not have that dimension, then perform a dynamic update slice.
106   auto slice_shape = llvm::to_vector<8>(buffer_type.getShape());
107   slice_shape[0] = 1;
108   auto slice_type =
109       RankedTensorType::get(slice_shape, buffer_type.getElementType());
110   auto update_slice = element;
111   if (element.getType() != slice_type) {
112     update_slice = builder.create<TF::ReshapeOp>(
113         loc, ArrayRef<Type>{slice_type},
114         ArrayRef<Value>{element, GetR1Const(slice_shape, builder, loc)});
115   }
116   return builder
117       .create<TF::XlaDynamicUpdateSliceOp>(
118           loc, ArrayRef<Type>{buffer.getType()},
119           ArrayRef<Value>{buffer, update_slice,
120                           GetIndicesForElement(index, buffer, builder, loc)})
121       .output();
122 }
123 
GetSizeType(OpBuilder builder)124 TensorType GetSizeType(OpBuilder builder) {
125   return RankedTensorType::get({1}, builder.getIntegerType(32));
126 }
127 
ReshapeScalarToSizeType(OpBuilder builder,Value scalar,Location loc)128 Value ReshapeScalarToSizeType(OpBuilder builder, Value scalar, Location loc) {
129   auto size_type = GetSizeType(builder);
130   return builder.create<TF::ReshapeOp>(
131       loc, ArrayRef<Type>{size_type},
132       ArrayRef<Value>{scalar, GetR1Const(size_type.getShape(), builder, loc)});
133 }
134 
CreateInitBufferValue(ArrayRef<int64_t> element_shape,Value max_size,Operation * op,Type element_dtype,OpBuilder builder,Value * buffer)135 LogicalResult CreateInitBufferValue(ArrayRef<int64_t> element_shape,
136                                     Value max_size, Operation* op,
137                                     Type element_dtype, OpBuilder builder,
138                                     Value* buffer) {
139   auto max_count_op = max_size.getDefiningOp();
140   if (!max_count_op) return op->emitOpError("unknown max element count");
141   auto max_count_const_op = llvm::dyn_cast<TF::ConstOp>(max_count_op);
142   if (!max_count_const_op) return op->emitOpError("unknown max element count");
143   int64_t max_size_const =
144       (*max_count_const_op.value().getValues<APInt>().begin()).getSExtValue();
145   return CreateInitBufferValue(element_shape, max_size_const, op, element_dtype,
146                                builder, buffer);
147 }
148 
CreateInitBufferValue(ArrayRef<int64_t> element_shape,int64_t max_size,Operation * op,Type element_dtype,OpBuilder builder,Value * buffer)149 LogicalResult CreateInitBufferValue(ArrayRef<int64_t> element_shape,
150                                     int64_t max_size, Operation* op,
151                                     Type element_dtype, OpBuilder builder,
152                                     Value* buffer) {
153   llvm::SmallVector<int64_t, 8> buffer_shape;
154   buffer_shape.push_back(max_size);
155   for (int64_t dim : element_shape) {
156     buffer_shape.push_back(dim);
157   }
158   auto zero = CreateScalarConst(0, builder, op->getLoc());
159   if (getElementTypeOrSelf(zero.getType()) != element_dtype) {
160     zero = builder.create<TF::CastOp>(
161         op->getLoc(), ArrayRef<Type>{RankedTensorType::get({}, element_dtype)},
162         ArrayRef<Value>{zero});
163   }
164   auto buffer_type = RankedTensorType::get(buffer_shape, element_dtype);
165   auto broadcast = builder.create<TF::BroadcastToOp>(
166       op->getLoc(), ArrayRef<Type>{buffer_type},
167       ArrayRef<Value>{zero, GetR1Const(buffer_shape, builder, op->getLoc())});
168   *buffer = broadcast.output();
169   return success();
170 }
171 
GetElementTypeFromAccess(Value collection,ModuleOp module,llvm::function_ref<llvm::Optional<Type> (Operation *)> infer_from_op)172 llvm::Optional<RankedTensorType> GetElementTypeFromAccess(
173     Value collection, ModuleOp module,
174     llvm::function_ref<llvm::Optional<Type>(Operation*)> infer_from_op) {
175   for (auto& use : collection.getUses()) {
176     if (auto while_op = llvm::dyn_cast<TF::WhileOp>(use.getOwner())) {
177       auto body = while_op.body_function();
178       assert(body);
179       auto type_from_body = GetElementTypeFromAccess(
180           body.getArgument(use.getOperandNumber()), module, infer_from_op);
181       if (type_from_body.hasValue()) return type_from_body;
182     } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(use.getOwner())) {
183       auto then_branch = if_op.then_function();
184       auto else_branch = if_op.else_function();
185       assert(then_branch && else_branch);
186       auto type_from_then = GetElementTypeFromAccess(
187           then_branch.getArgument(use.getOperandNumber() - 1), module,
188           infer_from_op);
189       if (type_from_then.hasValue()) return type_from_then;
190       auto type_from_else = GetElementTypeFromAccess(
191           else_branch.getArgument(use.getOperandNumber() - 1), module,
192           infer_from_op);
193       if (type_from_else.hasValue()) return type_from_else;
194     } else if (auto call = llvm::dyn_cast<CallOpInterface>(use.getOwner())) {
195       auto callee = dyn_cast<FuncOp>(call.resolveCallable());
196       auto type_from_callee = GetElementTypeFromAccess(
197           callee.getArgument(use.getOperandNumber()), module, infer_from_op);
198       if (type_from_callee.hasValue()) return type_from_callee;
199     } else if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(use.getOwner())) {
200       auto type_from_alias = GetElementTypeFromAccess(
201           use.getOwner()->getResult(use.getOperandNumber()), module,
202           infer_from_op);
203       if (type_from_alias.hasValue()) return type_from_alias;
204     } else if (auto type = infer_from_op(use.getOwner())) {
205       if (!type) continue;
206       auto elem_type = type->dyn_cast<RankedTensorType>();
207       if (elem_type && elem_type.hasStaticShape()) return elem_type;
208     }
209   }
210   return llvm::None;
211 }
212 
213 // Creates a ReadVariableOp on a local variable.
ReadLocalVariable(Value local_var,OpBuilder builder,Location loc)214 Value ReadLocalVariable(Value local_var, OpBuilder builder, Location loc) {
215   return builder
216       .create<TF::ReadVariableOp>(
217           loc,
218           ArrayRef<Type>{getElementTypeOrSelf(local_var.getType())
219                              .cast<TF::ResourceType>()
220                              .getSubtypes()[0]},
221           ArrayRef<Value>{local_var})
222       .value();
223 }
224 
225 // Creates an AssignVariableOp on a local variable.
WriteLocalVariable(Value local_var,Value value,OpBuilder builder,Location loc)226 TF::AssignVariableOp WriteLocalVariable(Value local_var, Value value,
227                                         OpBuilder builder, Location loc) {
228   return builder.create<TF::AssignVariableOp>(
229       loc, ArrayRef<Type>{}, ArrayRef<Value>{local_var, value});
230 }
231 
AccumulateBuffers(Value a,Value b,OpBuilder builder,Location loc)232 Value AccumulateBuffers(Value a, Value b, OpBuilder builder, Location loc) {
233   if (getElementTypeOrSelf(a.getType()) == builder.getI1Type()) {
234     return builder.create<TF::LogicalOrOp>(loc, ArrayRef<Type>{a.getType()},
235                                            ArrayRef<Value>{a, b});
236   }
237   return builder.create<TF::AddV2Op>(loc, ArrayRef<Type>{a.getType()},
238                                      ArrayRef<Value>{a, b});
239 }
240 
241 namespace {
242 
GetFirstIfIndicesAreContiguous(Value indices)243 int64_t GetFirstIfIndicesAreContiguous(Value indices) {
244   auto type = indices.getType().dyn_cast<RankedTensorType>();
245   if (!type) return -1;
246   auto indices_op = indices.getDefiningOp();
247   if (!indices_op) return -1;
248   auto const_op = llvm::dyn_cast<TF::ConstOp>(indices_op);
249   if (!const_op) return -1;
250   int64_t last_index = -1;
251   int64_t first_index = -1;
252   for (const auto& ind : const_op.value().getValues<APInt>()) {
253     if (last_index == -1) {
254       last_index = ind.getSExtValue();
255       first_index = last_index;
256       continue;
257     }
258     if (last_index + 1 != ind.getSExtValue()) return -1;
259     last_index++;
260   }
261   return first_index;
262 }
263 
264 }  // namespace
265 
GatherElements(Value indices,Value buffer,OpBuilder builder,Location loc)266 Value GatherElements(Value indices, Value buffer, OpBuilder builder,
267                      Location loc) {
268   auto buffer_type = buffer.getType().cast<RankedTensorType>();
269   auto result_shape = llvm::to_vector<8>(buffer_type.getShape());
270   result_shape[0] = indices.getType().cast<RankedTensorType>().getDimSize(0);
271   int64_t maybe_contiguous_start = GetFirstIfIndicesAreContiguous(indices);
272   if (maybe_contiguous_start >= 0) {
273     llvm::SmallVector<int64_t, 8> slice_starts(result_shape.size(), 0);
274     slice_starts[0] = maybe_contiguous_start;
275     auto slice_type =
276         RankedTensorType::get(result_shape, buffer_type.getElementType());
277     return builder.create<TF::SliceOp>(
278         loc, ArrayRef<Type>{slice_type},
279         ArrayRef<Value>{buffer, GetR1Const(slice_starts, builder, loc),
280                         GetR1Const(result_shape, builder, loc)});
281   }
282   auto result_type =
283       RankedTensorType::get(result_shape, buffer_type.getElementType());
284   return builder.create<TF::GatherV2Op>(
285       loc, ArrayRef<Type>{result_type},
286       ArrayRef<Value>{buffer, indices, CreateScalarConst(0, builder, loc)});
287 }
288 
ScatterAccumulateElements(Value indices,Value updates,Value buffer,OpBuilder builder,Location loc)289 Value ScatterAccumulateElements(Value indices, Value updates, Value buffer,
290                                 OpBuilder builder, Location loc) {
291   auto buffer_type = buffer.getType().cast<RankedTensorType>();
292   auto updates_type = updates.getType().cast<RankedTensorType>();
293   int64_t maybe_contiguous_start = GetFirstIfIndicesAreContiguous(indices);
294   if (maybe_contiguous_start == 0 && buffer_type == updates_type) {
295     return AccumulateBuffers(buffer, updates, builder, loc);
296   }
297   // We cannot simply use a TensorScatterUpdate, as it does not accumulate with
298   // the old data; it is tricky to manually add the old data either, since there
299   // could be duplicates in the index. We follow the old bridge's approach by
300   // iterating through the indices.
301   auto per_slice_shape = llvm::to_vector<8>(buffer_type.getShape());
302   per_slice_shape[0] = 1;
303   auto slice_sizes = GetR1Const(per_slice_shape, builder, loc);
304   llvm::SmallVector<int64_t, 8> starts_in_update(buffer_type.getRank(), 0);
305   for (int64_t i = 0; i < updates_type.getDimSize(0); ++i) {
306     auto index = builder.create<TF::SliceOp>(
307         loc, ArrayRef<Type>{GetSizeType(builder)},
308         ArrayRef<Value>{indices, GetR1Const({i}, builder, loc),
309                         GetR1Const({1}, builder, loc)});
310     auto old_slice =
311         GetElement(index, buffer, builder, loc, /*keep_slice_shape=*/true);
312     starts_in_update[0] = i;
313     auto update_slice_starts = GetR1Const(starts_in_update, builder, loc);
314     auto slice =
315         builder
316             .create<TF::SliceOp>(
317                 loc, ArrayRef<Type>{old_slice.getType()},
318                 ArrayRef<Value>{updates, update_slice_starts, slice_sizes})
319             .output();
320     slice = AccumulateBuffers(old_slice, slice, builder, loc);
321     buffer = SetElement(index, buffer, slice, builder, loc);
322   }
323   return buffer;
324 }
325 
326 }  // namespace collection_ops_util
327 }  // namespace TF
328 }  // namespace mlir
329