1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <cstdint>
16 #include <memory>
17 #include <string>
18 #include <utility>
19 #include <vector>
20 
21 #include "absl/container/inlined_vector.h"
22 #include "absl/memory/memory.h"
23 #include "absl/strings/string_view.h"
24 #include "llvm/ADT/DenseSet.h"
25 #include "llvm/ADT/Optional.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
29 #include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
30 #include "mlir/IR/Builders.h"  // from @llvm-project
31 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
32 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
33 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
34 #include "mlir/IR/Location.h"  // from @llvm-project
35 #include "mlir/IR/Operation.h"  // from @llvm-project
36 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
37 #include "mlir/IR/Types.h"  // from @llvm-project
38 #include "mlir/IR/Value.h"  // from @llvm-project
39 #include "mlir/Pass/Pass.h"  // from @llvm-project
40 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
41 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
42 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
43 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
44 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
45 #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
46 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
47 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
48 #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
49 #include "tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h"
50 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
51 #include "tensorflow/compiler/tf2xla/xla_context.h"
52 #include "tensorflow/compiler/tf2xla/xla_expression.h"
53 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
54 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
55 #include "tensorflow/compiler/xla/client/xla_builder.h"
56 #include "tensorflow/core/common_runtime/device.h"
57 #include "tensorflow/core/common_runtime/device_factory.h"
58 #include "tensorflow/core/common_runtime/device_mgr.h"
59 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
60 #include "tensorflow/core/framework/allocator.h"
61 #include "tensorflow/core/framework/function.h"
62 #include "tensorflow/core/framework/function.pb.h"
63 #include "tensorflow/core/framework/node_properties.h"
64 #include "tensorflow/core/framework/op.h"
65 #include "tensorflow/core/framework/op_kernel.h"
66 #include "tensorflow/core/framework/resource_mgr.h"
67 #include "tensorflow/core/framework/tensor.h"
68 #include "tensorflow/core/framework/types.h"
69 #include "tensorflow/core/framework/types.pb.h"
70 #include "tensorflow/core/platform/env.h"
71 #include "tensorflow/core/platform/status.h"
72 #include "tensorflow/core/protobuf/config.pb.h"
73 #include "tensorflow/core/public/session_options.h"
74 #include "tensorflow/stream_executor/lib/statusor.h"
75 #include "tensorflow/stream_executor/stream_executor.h"
76 
77 namespace mlir {
78 namespace mhlo {
79 
IsOpAllowedTf2XlaFallback(Operation * op)80 bool IsOpAllowedTf2XlaFallback(Operation* op) {
81   // Allowlisted TensorFlow ops are known to have well behaved tf2xla kernels
82   // building valid MLIR using MlirHloBuilder.
83   // TODO(hinsu): Drop explicit allowlist when MLIR based bridge is enabled for
84   // all tf2xla kernels.
85   // clang-format off
86 
87   static llvm::SmallDenseSet<mlir::TypeID, 512> ops = {
88     TypeID::get<TF::AbsOp>(),
89     TypeID::get<TF::AcoshOp>(),
90     TypeID::get<TF::AcosOp>(),
91     TypeID::get<TF::AddNOp>(),
92     TypeID::get<TF::AddV2Op>(),
93     TypeID::get<TF::AngleOp>(),
94     TypeID::get<TF::AdjustContrastv2Op>(),
95     TypeID::get<TF::AdjustHueOp>(),
96     TypeID::get<TF::AdjustSaturationOp>(),
97     TypeID::get<TF::ApproximateEqualOp>(),
98     TypeID::get<TF::ArgMaxOp>(),
99     TypeID::get<TF::ArgMinOp>(),
100     TypeID::get<TF::AsinhOp>(),
101     TypeID::get<TF::AsinOp>(),
102     TypeID::get<TF::Atan2Op>(),
103     TypeID::get<TF::AtanhOp>(),
104     TypeID::get<TF::AtanOp>(),
105     TypeID::get<TF::BatchMatMulV2Op>(),
106     TypeID::get<TF::BatchToSpaceOp>(),
107     TypeID::get<TF::BesselI0eOp>(),
108     TypeID::get<TF::BesselI1eOp>(),
109     TypeID::get<TF::BetaincOp>(),
110     TypeID::get<TF::BiasAddGradOp>(),
111     TypeID::get<TF::BiasAddOp>(),
112     TypeID::get<TF::BitwiseAndOp>(),
113     TypeID::get<TF::BitwiseOrOp>(),
114     TypeID::get<TF::BitwiseXorOp>(),
115     TypeID::get<TF::BucketizeOp>(),
116     TypeID::get<TF::CastOp>(),
117     TypeID::get<TF::ClipByValueOp>(),
118     TypeID::get<TF::CholeskyOp>(),
119     TypeID::get<TF::ComplexAbsOp>(),
120     TypeID::get<TF::ConjugateTransposeOp>(),
121     TypeID::get<TF::CoshOp>(),
122     TypeID::get<TF::CrossOp>(),
123     TypeID::get<TF::DataFormatDimMapOp>(),
124     TypeID::get<TF::DataFormatVecPermuteOp>(),
125     TypeID::get<TF::DepthToSpaceOp>(),
126     TypeID::get<TF::DepthwiseConv2dNativeBackpropFilterOp>(),
127     TypeID::get<TF::DepthwiseConv2dNativeBackpropInputOp>(),
128     TypeID::get<TF::DiagOp>(),
129     TypeID::get<TF::DigammaOp>(),
130     TypeID::get<TF::DivNoNanOp>(),
131     TypeID::get<TF::EluGradOp>(),
132     TypeID::get<TF::EluOp>(),
133     TypeID::get<TF::EqualOp>(),
134     TypeID::get<TF::ErfcOp>(),
135     TypeID::get<TF::ErfinvOp>(),
136     TypeID::get<TF::ErfOp>(),
137     TypeID::get<TF::ExtractImagePatchesOp>(),
138     TypeID::get<TF::FFT2DOp>(),
139     TypeID::get<TF::FFT3DOp>(),
140     TypeID::get<TF::FFTOp>(),
141     TypeID::get<TF::FakeParamOp>(),
142     TypeID::get<TF::FakeQuantWithMinMaxArgsGradientOp>(),
143     TypeID::get<TF::FakeQuantWithMinMaxVarsGradientOp>(),
144     TypeID::get<TF::FloorDivOp>(),
145     TypeID::get<TF::FloorModOp>(),
146     TypeID::get<TF::GatherNdOp>(),
147     TypeID::get<TF::GreaterEqualOp>(),
148     TypeID::get<TF::GreaterOp>(),
149     TypeID::get<TF::HSVToRGBOp>(),
150     TypeID::get<TF::IFFT2DOp>(),
151     TypeID::get<TF::IFFT3DOp>(),
152     TypeID::get<TF::IRFFT2DOp>(),
153     TypeID::get<TF::IRFFT3DOp>(),
154     TypeID::get<TF::IgammaOp>(),
155     TypeID::get<TF::IgammacOp>(),
156     TypeID::get<TF::IgammaGradAOp>(),
157     TypeID::get<TF::InplaceAddOp>(),
158     TypeID::get<TF::InTopKV2Op>(),
159     TypeID::get<TF::InvertOp>(),
160     TypeID::get<TF::InvOp>(),
161     TypeID::get<TF::KthOrderStatisticOp>(),
162     TypeID::get<TF::LRNOp>(),
163     TypeID::get<TF::LRNGradOp>(),
164     TypeID::get<TF::LeakyReluGradOp>(),
165     TypeID::get<TF::LeakyReluOp>(),
166     TypeID::get<TF::LeftShiftOp>(),
167     TypeID::get<TF::LessEqualOp>(),
168     TypeID::get<TF::LessOp>(),
169     TypeID::get<TF::ListDiffOp>(),
170     TypeID::get<TF::LogicalAndOp>(),
171     TypeID::get<TF::LogicalNotOp>(),
172     TypeID::get<TF::LogicalOrOp>(),
173     TypeID::get<TF::LogOp>(),
174     TypeID::get<TF::LowerBoundOp>(),
175     TypeID::get<TF::MakeUniqueOp>(),
176     TypeID::get<TF::MatMulOp>(),
177     TypeID::get<TF::MatrixDiagV3Op>(),
178     TypeID::get<TF::MatrixInverseOp>(),
179     TypeID::get<TF::MatrixSetDiagV3Op>(),
180     TypeID::get<TF::MatrixSolveOp>(),
181     TypeID::get<TF::MatrixTriangularSolveOp>(),
182     TypeID::get<TF::MaxPool3DGradGradOp>(),
183     TypeID::get<TF::MaxPoolGradGradOp>(),
184     TypeID::get<TF::MirrorPadOp>(),
185     TypeID::get<TF::MirrorPadGradOp>(),
186     TypeID::get<TF::MulOp>(),
187     TypeID::get<TF::MultinomialOp>(),
188     TypeID::get<TF::NdtriOp>(),
189     TypeID::get<TF::NegOp>(),
190     TypeID::get<TF::NextAfterOp>(),
191     TypeID::get<TF::NonMaxSuppressionV4Op>(),
192     TypeID::get<TF::NotEqualOp>(),
193     TypeID::get<TF::PadOp>(),
194     TypeID::get<TF::ParameterizedTruncatedNormalOp>(),
195     TypeID::get<TF::PlaceholderWithDefaultOp>(),
196     TypeID::get<TF::PolygammaOp>(),
197     TypeID::get<TF::PopulationCountOp>(),
198     TypeID::get<TF::PowOp>(),
199     // TODO(hinsu): Canonicalize QuantizeAndDequantize and
200     // QuantizeAndDequantizeV2 to QuantizeAndDequantizeV3 by converting
201     // attributes to operands.
202     TypeID::get<TF::QuantizeAndDequantizeOp>(),
203     TypeID::get<TF::QuantizeAndDequantizeV2Op>(),
204     TypeID::get<TF::QuantizeAndDequantizeV3Op>(),
205     TypeID::get<TF::RFFT2DOp>(),
206     TypeID::get<TF::RFFT3DOp>(),
207     TypeID::get<TF::RGBToHSVOp>(),
208     TypeID::get<TF::RandomUniformIntOp>(),
209     TypeID::get<TF::RealDivOp>(),
210     TypeID::get<TF::ReciprocalGradOp>(),
211     TypeID::get<TF::Relu6GradOp>(),
212     TypeID::get<TF::ResizeBilinearOp>(),
213     TypeID::get<TF::ResizeBilinearGradOp>(),
214     TypeID::get<TF::ResizeNearestNeighborOp>(),
215     TypeID::get<TF::ResizeNearestNeighborGradOp>(),
216     TypeID::get<TF::ReverseSequenceOp>(),
217     TypeID::get<TF::RightShiftOp>(),
218     TypeID::get<TF::RintOp>(),
219     TypeID::get<TF::RollOp>(),
220     TypeID::get<TF::RoundOp>(),
221     TypeID::get<TF::SelectV2Op>(),
222     TypeID::get<TF::SelfAdjointEigV2Op>(),
223     TypeID::get<TF::SeluGradOp>(),
224     TypeID::get<TF::SeluOp>(),
225     TypeID::get<TF::SigmoidGradOp>(),
226     TypeID::get<TF::SinhOp>(),
227     TypeID::get<TF::SinOp>(),
228     TypeID::get<TF::SoftplusGradOp>(),
229     TypeID::get<TF::SoftsignGradOp>(),
230     TypeID::get<TF::SoftsignOp>(),
231     TypeID::get<TF::SpaceToBatchNDOp>(),
232     TypeID::get<TF::SpaceToBatchOp>(),
233     TypeID::get<TF::SpaceToDepthOp>(),
234     TypeID::get<TF::SparseToDenseOp>(),
235     TypeID::get<TF::SquareOp>(),
236     TypeID::get<TF::StatelessMultinomialOp>(),
237     TypeID::get<TF::StatelessRandomGetAlgOp>(),
238     TypeID::get<TF::StatelessRandomGetKeyCounterOp>(),
239     TypeID::get<TF::StatelessRandomGetKeyCounterAlgOp>(),
240     TypeID::get<TF::StatelessRandomNormalOp>(),
241     TypeID::get<TF::StatelessRandomNormalV2Op>(),
242     TypeID::get<TF::StatelessRandomUniformOp>(),
243     TypeID::get<TF::StatelessRandomUniformFullIntOp>(),
244     TypeID::get<TF::StatelessRandomUniformFullIntV2Op>(),
245     TypeID::get<TF::StatelessRandomUniformV2Op>(),
246     TypeID::get<TF::StatelessRandomUniformIntOp>(),
247     TypeID::get<TF::StatelessRandomUniformIntV2Op>(),
248     TypeID::get<TF::StatelessTruncatedNormalOp>(),
249     TypeID::get<TF::StatelessTruncatedNormalV2Op>(),
250     TypeID::get<TF::SubOp>(),
251     TypeID::get<TF::SvdOp>(),
252     TypeID::get<TF::TanOp>(),
253     TypeID::get<TF::TensorScatterAddOp>(),
254     TypeID::get<TF::TensorScatterSubOp>(),
255     TypeID::get<TF::TPUEmbeddingActivationsOp>(),
256     TypeID::get<TF::TopKUniqueOp>(),
257     TypeID::get<TF::TopKWithUniqueOp>(),
258     TypeID::get<TF::TransposeOp>(),
259     TypeID::get<TF::TridiagonalSolveOp>(),
260     TypeID::get<TF::TruncateDivOp>(),
261     TypeID::get<TF::TruncatedNormalOp>(),
262     TypeID::get<TF::TruncateModOp>(),
263     TypeID::get<TF::UnpackOp>(),
264     TypeID::get<TF::UpperBoundOp>(),
265     TypeID::get<TF::XlaBroadcastHelperOp>(),
266     TypeID::get<TF::XlaConvOp>(),
267     TypeID::get<TF::XlaDotOp>(),
268     TypeID::get<TF::XlaDynamicSliceOp>(),
269     TypeID::get<TF::XlaDynamicUpdateSliceOp>(),
270     TypeID::get<TF::XlaEinsumOp>(),
271     TypeID::get<TF::XlaKeyValueSortOp>(),
272     TypeID::get<TF::XlaPadOp>(),
273     TypeID::get<TF::XlaSetDynamicDimensionSizeOp>(),
274     TypeID::get<TF::XlaSortOp>(),
275     TypeID::get<TF::XlaSvdOp>(),
276     TypeID::get<TF::ZetaOp>()
277   };
278   // clang-format on
279 
280   auto* abstractOp = op->getAbstractOperation();
281   if (!abstractOp) return false;
282   return ops.count(abstractOp->typeID);
283 }
284 
285 namespace {
286 
287 template <typename T, size_t N>
288 using InlinedVector = tensorflow::gtl::InlinedVector<T, N>;  // non-absl ok
289 
CreateDeviceMgr(const std::string & device_type)290 static std::unique_ptr<tensorflow::StaticDeviceMgr> CreateDeviceMgr(
291     const std::string& device_type) {
292   // Register compilation kernels for all registered XLA backends.
293   tensorflow::XlaOpRegistry::RegisterCompilationKernels();
294 
295   auto device = absl::make_unique<tensorflow::XlaCompilationDevice>(
296       tensorflow::SessionOptions(), tensorflow::DeviceType(device_type));
297   return absl::make_unique<tensorflow::StaticDeviceMgr>(std::move(device));
298 }
299 
300 class Tf2XlaRewriter {
301  public:
RewriteOp(Operation * op,PatternRewriter & rewriter,const std::string & device_type)302   static LogicalResult RewriteOp(Operation* op, PatternRewriter& rewriter,
303                                  const std::string& device_type) {
304     Tf2XlaRewriter tf2xla_rewriter(op, rewriter, device_type);
305     return tf2xla_rewriter.LegalizeOp();
306   }
307 
308  private:
Tf2XlaRewriter(Operation * op,PatternRewriter & rewriter,const std::string & device_type)309   Tf2XlaRewriter(Operation* op, PatternRewriter& rewriter,
310                  const std::string& device_type)
311       : op_(op),
312         device_type_(device_type),
313         rewriter_(rewriter),
314         hlo_builder_(op->getName().getStringRef().str(), rewriter_,
315                      op->getLoc()),
316         context_(nullptr) {}
317 
~Tf2XlaRewriter()318   ~Tf2XlaRewriter() {
319     if (context_) context_->Unref();
320   }
321 
322   // Prepares OpKernelContext params common to all the ops.
323   // Emits an error on failure.
324   LogicalResult PrepareParams();
325 
326   // Tries to legalize the specified TensorFlow op, if supported.
327   //
328   // Emits an error and returns failure if an error is encountered during
329   // conversion. Note that success return value doesn't mean successful
330   // legalization.
331   LogicalResult LegalizeOp();
332 
333   // Converts the given operand to expression of kind kConstant or kXlaOp.
334   // Emits a remark and returns expression of kind kInvalid on failure.
335   tensorflow::XlaExpression GetExprForOperand(Value operand, Operation* op);
336 
337   Operation* op_;
338   std::string device_type_;
339 
340   PatternRewriter& rewriter_;
341   ::xla::MlirHloBuilder hlo_builder_;
342   tensorflow::OpOrArgLocNameMapper name_mapper_;
343 
344   tensorflow::XlaContext* context_;  // Ref-counted.
345 
346   std::unique_ptr<tensorflow::StaticDeviceMgr> device_mgr_;
347   tensorflow::Device* device_;  // Owned by device_mgr_;
348   std::unique_ptr<tensorflow::ScopedStepContainer> step_container_;
349   std::unique_ptr<tensorflow::FunctionLibraryDefinition> flib_def_;
350   std::unique_ptr<tensorflow::ProcessFunctionLibraryRuntime> pflr_;
351   tensorflow::OpKernelContext::Params params_;
352 };
353 
PrepareParams()354 LogicalResult Tf2XlaRewriter::PrepareParams() {
355   // XlaCompiler within the context is only used by the functional ops to
356   // compile functions. We are not handling those at the moment so XlaCompiler
357   // is not required.
358   context_ = new tensorflow::XlaContext(/*compiler=*/nullptr, &hlo_builder_,
359                                         /*graph=*/nullptr);
360   context_->Ref();
361 
362   device_mgr_ = CreateDeviceMgr(device_type_);
363   if (!device_mgr_) return failure();
364 
365   // Type of params_.device is DeviceBase* so store it as Device* to access
366   // derived class method.
367   device_ = device_mgr_->ListDevices().front();
368   params_.device = device_;
369   params_.resource_manager = device_->resource_manager();
370 
371   // Resources are cleared at the time of device manager destruction so pass
372   // no-op cleanup function.
373   auto cleanup = [](const std::string& name) {};
374   // Use step_id zero as we only have a single context concurrently and
375   // concurrently running each of the MLIR functions create a new device.
376   step_container_ = absl::make_unique<tensorflow::ScopedStepContainer>(
377       /*step_id=*/0, cleanup);
378   tensorflow::Status status = step_container_->Create(
379       device_->resource_manager(),
380       tensorflow::XlaContext::kXlaContextResourceName, context_);
381   if (!status.ok()) {
382     return emitError(op_->getLoc())
383            << "failed to create XlaContext resource: " << status.ToString();
384   }
385   params_.step_container = step_container_.get();
386 
387   tensorflow::StatusOr<int64_t> version_or =
388       tensorflow::GetTfGraphProducerVersion(
389           op_->getParentOfType<mlir::ModuleOp>());
390   if (!version_or.ok()) {
391     return emitError(op_->getLoc()) << version_or.status().ToString();
392   }
393 
394   flib_def_ = absl::make_unique<tensorflow::FunctionLibraryDefinition>(
395       tensorflow::OpRegistry::Global(), tensorflow::FunctionDefLibrary());
396   pflr_ = absl::make_unique<tensorflow::ProcessFunctionLibraryRuntime>(
397       device_mgr_.get(), tensorflow::Env::Default(), /*config=*/nullptr,
398       version_or.ValueOrDie(), flib_def_.get(), tensorflow::OptimizerOptions());
399   params_.function_library = pflr_->GetFLR(device_->name());
400   return success();
401 }
402 
LegalizeOp()403 LogicalResult Tf2XlaRewriter::LegalizeOp() {
404   // Only static shaped operands are supported in XLA builders for now.
405   for (Type ty : op_->getOperandTypes()) {
406     auto ranked_ty = ty.dyn_cast<ShapedType>();
407     if (!ranked_ty || !ranked_ty.hasStaticShape()) {
408       return op_->emitRemark()
409              << "lowering requires static shaped tensor operands";
410     }
411   }
412 
413   for (const auto& attr : op_->getAttrs()) {
414     if (attr.second.isa<SymbolRefAttr>()) {
415       return op_->emitRemark()
416              << "ops with symbol references are not supported";
417     }
418   }
419 
420   auto nodedef_or = tensorflow::ConvertTFDialectOpToNodeDef(
421       op_, name_mapper_.GetUniqueName(op_), /*ignore_unregistered_attrs=*/true);
422   if (!nodedef_or.ok()) {
423     return op_->emitRemark() << "failed to convert op to NodeDef: "
424                              << nodedef_or.status().ToString();
425   }
426 
427   if (failed(PrepareParams())) return failure();
428 
429   std::shared_ptr<const tensorflow::NodeProperties> props;
430   tensorflow::Status status = tensorflow::NodeProperties::CreateFromNodeDef(
431       *nodedef_or.ValueOrDie(),
432       params_.function_library->GetFunctionLibraryDefinition(), &props);
433   if (!status.ok()) {
434     return op_->emitRemark()
435            << "failed to create NodeProperties: " << status.ToString();
436   }
437   tensorflow::OpKernel* op_kernel_raw;
438   status = params_.function_library->CreateKernel(props, &op_kernel_raw);
439   if (!status.ok()) {
440     return op_->emitRemark()
441            << "failed to create tf2xla kernel: " << status.ToString();
442   }
443   // Transfer ownership of the kernel to a local smart pointer.
444   auto op_kernel = absl::WrapUnique(op_kernel_raw);
445 
446   std::vector<int> required_constants;
447   status = tensorflow::XlaOpRegistry::CompileTimeConstantInputs(
448       *op_kernel, &required_constants);
449   if (!status.ok()) {
450     return op_->emitRemark()
451            << "failed to compute required constants: " << status.ToString();
452   }
453   llvm::SmallDenseSet<int, 4> required_consts;
454   required_consts.insert(required_constants.begin(), required_constants.end());
455 
456   // TensorValue in inputs are backed by tensors which in turn depend on
457   // expressions. So, pre-allocate them to the required size.
458   InlinedVector<tensorflow::XlaExpression, 4> expressions;
459   InlinedVector<tensorflow::Tensor, 4> tensors;
460   InlinedVector<tensorflow::TensorValue, 4> inputs;
461   expressions.reserve(op_->getNumOperands());
462   tensors.reserve(op_->getNumOperands());
463   inputs.reserve(op_->getNumOperands());
464 
465   // Prepare the list of Tensor inputs for the kernel.
466   for (auto it : llvm::enumerate(op_->getOperands())) {
467     Value operand = it.value();
468     size_t idx = it.index();
469 
470     tensorflow::XlaExpression expr = GetExprForOperand(operand, op_);
471     tensorflow::XlaExpression::Kind kind = expr.kind();
472     if (kind == tensorflow::XlaExpression::Kind::kInvalid) return failure();
473     if (required_consts.count(idx) &&
474         kind != tensorflow::XlaExpression::Kind::kConstant) {
475       return op_->emitRemark()
476              << "lowering requires operand #" << idx << " to be a constant";
477     }
478     expressions.push_back(expr);
479 
480     if (!tensorflow::DataTypeCanUseMemcpy(expr.dtype())) {
481       return op_->emitRemark()
482              << "skipping legalization due to unsupported type "
483              << operand.getType();
484     }
485 
486     auto shape_or = expr.GetShape();
487     if (!shape_or.ok()) {
488       return op_->emitRemark()
489              << "failed to get shape for expression. " << expr.HumanString();
490     }
491 
492     tensors.emplace_back(
493         device_->GetAllocator(tensorflow::AllocatorAttributes()), expr.dtype(),
494         shape_or.ValueOrDie());
495     tensorflow::Tensor& tensor = tensors.back();
496     tensorflow::XlaExpression::AssignExpressionToTensor(expr, &tensor);
497     inputs.emplace_back(&tensor);
498   }
499 
500   params_.inputs = &inputs;
501   params_.op_kernel = op_kernel.get();
502   llvm::SmallVector<tensorflow::AllocatorAttributes, 4> output_attr(
503       op_->getNumResults());
504   params_.output_attr_array = output_attr.data();
505 
506   hlo_builder_.setInsertionPoint(op_);
507   hlo_builder_.SetLocation(op_->getLoc());
508 
509   // Execute the kernel.
510   tensorflow::OpKernelContext op_context(&params_, op_->getNumResults());
511   device_->Compute(params_.op_kernel, &op_context);
512   if (!op_context.status().ok()) {
513     return op_->emitRemark()
514            << "compilation to HLO failed: " << op_context.status().ToString();
515   }
516 
517   // Replace uses of old results using the corresponding value after the
518   // lowering.
519   llvm::SmallVector<Value, 2> values;
520   values.reserve(op_->getNumResults());
521   for (int i = 0, e = op_->getNumResults(); i < e; i++) {
522     tensorflow::Tensor* output = op_context.mutable_output(i);
523     const tensorflow::XlaExpression* expr =
524         tensorflow::XlaExpression::CastExpressionFromTensor(*output);
525     if (expr->kind() != tensorflow::XlaExpression::Kind::kXlaOp)
526       return op_->emitError(
527           "expects XlaExpression of kind kXlaOp in compiled output");
528     auto value = hlo_builder_.GetValue(expr->handle());
529     mlir::OpResult old_result = op_->getResult(i);
530     if (value.getType() != old_result.getType()) {
531       value = hlo_builder_.create<mlir::tensor::CastOp>(old_result.getType(),
532                                                         value);
533     }
534     values.push_back(value);
535   }
536   rewriter_.replaceOp(op_, values);
537   return success();
538 }
539 
GetExprForOperand(Value operand,Operation * op)540 tensorflow::XlaExpression Tf2XlaRewriter::GetExprForOperand(Value operand,
541                                                             Operation* op) {
542   ElementsAttr const_attr;
543   auto defining_op = operand.getDefiningOp();
544   if (defining_op && matchPattern(defining_op, m_Constant(&const_attr))) {
545     tensorflow::Tensor tensor;
546     auto status = tensorflow::ConvertToTensor(const_attr, &tensor);
547     if (!status.ok()) {
548       op->emitRemark() << "skipping legalization due to failed const conversion"
549                        << status.ToString();
550       return tensorflow::XlaExpression::Invalid();
551     }
552     return tensorflow::XlaExpression::Constant(tensor);
553   }
554 
555   // Skip this op if XLA doesn't support this operand type.
556   auto xla_op_or = hlo_builder_.MakeXlaOp(operand);
557   if (!xla_op_or.ok()) {
558     op->emitRemark() << "skipping legalization due to "
559                      << xla_op_or.status().ToString();
560     return tensorflow::XlaExpression::Invalid();
561   }
562   ::xla::XlaOp xla_op = xla_op_or.ValueOrDie();
563 
564   tensorflow::DataType dtype;
565   auto status = tensorflow::ConvertToDataType(operand.getType(), &dtype);
566   if (!status.ok()) {
567     op->emitRemark() << "skipping legalization due to " << status.ToString();
568     return tensorflow::XlaExpression::Invalid();
569   }
570   return tensorflow::XlaExpression::XlaOp(xla_op, dtype);
571 }
572 
573 class Tf2XlaRewritePattern : public RewritePattern {
574  public:
Tf2XlaRewritePattern(const std::string & device_type)575   explicit Tf2XlaRewritePattern(const std::string& device_type)
576       : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()),
577         device_type_(device_type) {}
578 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const579   LogicalResult matchAndRewrite(Operation* op,
580                                 PatternRewriter& rewriter) const override {
581     if (!IsOpAllowedTf2XlaFallback(op)) return failure();
582     return Tf2XlaRewriter::RewriteOp(op, rewriter, device_type_);
583   }
584 
585  private:
586   std::string device_type_;
587 };
588 
589 class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
590  public:
591   LegalizeTF() = default;
592 
LegalizeTF(llvm::StringRef device_type)593   explicit LegalizeTF(llvm::StringRef device_type) {
594     device_type_ = device_type.str();
595   }
596 
LegalizeTF(const LegalizeTF &)597   LegalizeTF(const LegalizeTF&) {}
598 
runOnFunction()599   void runOnFunction() override {
600     OwningRewritePatternList patterns;
601     patterns.insert<Tf2XlaRewritePattern>(device_type_);
602     if (failed(
603             applyPatternsAndFoldGreedily(getFunction(), std::move(patterns))))
604       signalPassFailure();
605   }
606 
607  private:
608   // TODO(hinsu): Support finer grained device type assignment instead of a
609   // global device type for all TensorFlow ops.
610   Option<std::string> device_type_{
611       *this, "device-type",
612       llvm::cl::desc("XLA device type for execution of TensorFlow ops.")};
613 };
614 
// Registers the LegalizeTF pass under the "xla-legalize-tf-with-tf2xla"
// argument.
static PassRegistration<LegalizeTF> pass(
    "xla-legalize-tf-with-tf2xla",
    "Legalize from TensorFlow to the HLO dialect using tf2xla kernels");
618 
619 }  // end namespace
620 
PopulateLegalizeTfWithTf2XlaPatterns(llvm::StringRef device_type,OwningRewritePatternList & patterns)621 void PopulateLegalizeTfWithTf2XlaPatterns(llvm::StringRef device_type,
622                                           OwningRewritePatternList& patterns) {
623   patterns.insert<Tf2XlaRewritePattern>(device_type.str());
624 }
625 
createLegalizeTfWithTf2XlaPass(llvm::StringRef device_type)626 std::unique_ptr<OperationPass<FuncOp>> createLegalizeTfWithTf2XlaPass(
627     llvm::StringRef device_type) {
628   return std::make_unique<LegalizeTF>(device_type);
629 }
630 
631 }  // end namespace mhlo
632 }  // end namespace mlir
633