/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_
#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_

#include <atomic>
#include <memory>
#include <vector>

#include "tensorflow/compiler/jit/xla_compilation_cache.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/stream_executor_util.h"

namespace tensorflow {

// Holds some information about the platform on which an
// XlaLaunch/_XlaCompile/_XlaRun op must run.
class XlaPlatformInfo {
 public:
  XlaPlatformInfo() : device_type_("") {}
  XlaPlatformInfo(XlaPlatformInfo&&) = default;
  explicit XlaPlatformInfo(const DeviceType device_type,
                           se::Platform::Id platform_id,
                           const XlaDevice::Metadata* xla_device_metadata,
                           std::unique_ptr<XlaAllocator> xla_allocator,
                           xla::DeviceMemoryAllocator* device_allocator)
      : device_type_(device_type),
        platform_id_(platform_id),
        xla_device_metadata_(xla_device_metadata),
        xla_allocator_(std::move(xla_allocator)),
        device_allocator_(device_allocator) {
    // Exactly one of device_allocator_ and xla_allocator_ must be set.
    CHECK((device_allocator_ != nullptr) ^ (xla_allocator_.get() != nullptr));
  }

  XlaPlatformInfo& operator=(XlaPlatformInfo&& other) = default;

  bool UseMultipleStreams() const {
    return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams();
  }

  xla::DeviceMemoryAllocator* allocator() const {
    return device_allocator_ ? device_allocator_ : xla_allocator_.get();
  }

  DeviceType device_type() const { return device_type_; }

  // This is equal to xla_device_metadata()->platform()->id() if
  // xla_device_metadata() is not nullptr.
  se::Platform::Id platform_id() const { return platform_id_; }

  // This may be null if the op to which this XlaPlatformInfo belongs was not
  // placed on an XLA device.
  const XlaDevice::Metadata* xla_device_metadata() const {
    return xla_device_metadata_;
  }

  bool is_on_xla_device() const { return xla_device_metadata() != nullptr; }

 private:
  DeviceType device_type_;
  se::Platform::Id platform_id_;

  // xla_device_metadata_ lives in the tensorflow::DeviceBase in which the
  // XlaLaunch/_XlaCompile/_XlaRun op is placed and thus does not die before
  // the XlaLaunch/_XlaCompile/_XlaRun OpKernel.
  const XlaDevice::Metadata* xla_device_metadata_;

  // If the op associated with this XlaPlatformInfo is placed on an XLA device
  // then device_allocator_ is the xla::Backend's memory allocator and
  // xla_allocator_ is null. If the op is placed on a regular CPU or GPU device
  // then device_allocator_ is null and xla_allocator_ points to an appropriate
  // XlaAllocator instance.
  std::unique_ptr<XlaAllocator> xla_allocator_;
  xla::DeviceMemoryAllocator* device_allocator_;

  TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo);
};
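// Illustrative usage (a sketch only, not part of this header): how the two
// allocator modes of XlaPlatformInfo are expected to be populated. The locals
// `platform_id`, `metadata`, `backend_allocator`, and `wrapped_allocator`
// below are hypothetical.
//
//   // Op placed on an XLA device: pass the XlaDevice metadata and the
//   // backend's xla::DeviceMemoryAllocator; leave xla_allocator null.
//   XlaPlatformInfo on_xla_device(DeviceType("XLA_GPU"), platform_id,
//                                 &metadata, /*xla_allocator=*/nullptr,
//                                 backend_allocator);
//
//   // Op placed on a regular CPU/GPU device: wrap the device's TF allocator
//   // in an XlaAllocator (std::unique_ptr<XlaAllocator> wrapped_allocator)
//   // and leave device_allocator null.
//   XlaPlatformInfo on_regular_device(DeviceType(DEVICE_GPU), platform_id,
//                                     /*xla_device_metadata=*/nullptr,
//                                     std::move(wrapped_allocator),
//                                     /*device_allocator=*/nullptr);
//
// In both cases allocator() returns the single non-null allocator, matching
// the CHECK in the constructor above.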
// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp.
// The only difference is that it does not require arguments to follow
// the "constants, then regular args, then resources" order.
// It takes vectors of constant and resource arguments explicitly.
// It does not have a corresponding OpDef because it is never present
// in the GraphDef.
// Currently, it is used by the eager runtime: FunctionLibraryRuntime creates
// this kernel when asked to create a kernel for an XLA-compiled function.
class XlaLocalLaunchBase : public OpKernel {
 public:
  XlaLocalLaunchBase(OpKernelConstruction* ctx,
                     const std::vector<int>& constants,
                     const std::vector<int>& resources,
                     const NameAttrList& function);
  XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete;
  XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete;
  ~XlaLocalLaunchBase() override = default;

  void Compute(OpKernelContext* ctx) override;

 protected:
  // Indexes of compile-time constant inputs.
  const std::vector<int> constants_;
  // Indexes of resource inputs.
  const std::vector<int> resources_;

  const NameAttrList function_;
  const XlaPlatformInfo platform_info_;
};

// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
// which will be compiled and executed using XLA. The XlaLocalLaunchOp is
// responsible for handling interactions with the TensorFlow executor.
// Once all inputs are present, and their shapes are known, the op can
// use an 'XlaCompilationCache' to compile and execute code which is specific
// to the shapes of the input Tensors.
// XlaLocalLaunchOp uses xla::LocalClient::Compile() and
// xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device
// memory.
class XlaLocalLaunchOp : public XlaLocalLaunchBase {
 public:
  explicit XlaLocalLaunchOp(OpKernelConstruction* ctx);
  ~XlaLocalLaunchOp() override;

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp);
};
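// Illustrative note (a sketch only, not part of this header): the difference
// in input handling between the two launch kernels above. The index values
// below are hypothetical.
//
//   // XlaLocalLaunchOp expects its inputs to already be ordered as
//   //   [compile-time constants..., regular args..., resources...]
//   // and presumably derives the constant/resource index vectors itself
//   // from its construction context.
//
//   // XlaLocalLaunchBase is handed the positions explicitly, so inputs may
//   // appear in any order, e.g. for an eager-mode kernel:
//   std::vector<int> constants = {0, 3};  // inputs 0 and 3 are constants
//   std::vector<int> resources = {5};     // input 5 is a resource variable
//   XlaLocalLaunchBase kernel(ctx, constants, resources, function);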
class XlaCompileOp : public OpKernel {
 public:
  explicit XlaCompileOp(OpKernelConstruction* ctx);

  void Compute(OpKernelContext* ctx) override;

 private:
  // Indexes of compile-time constant inputs.
  const std::vector<int> constants_;
  // Indexes of resource inputs.
  const std::vector<int> resources_;

  const NameAttrList function_;

  XlaPlatformInfo platform_info_;

  const bool must_compile_;

  // cannot_compile_cluster_ is set to true if XLA returns an Unimplemented
  // error when compiling the cluster this _XlaCompile is supposed to compile.
  // If `cannot_compile_cluster_` is true then we avoid compiling this cluster
  // on any future calls to _XlaCompile.
  bool cannot_compile_cluster_ GUARDED_BY(cannot_compile_cluster_mu_) = false;

  mutex cannot_compile_cluster_mu_;
};

class XlaRunOp : public OpKernel {
 public:
  explicit XlaRunOp(OpKernelConstruction* ctx);

  void Compute(OpKernelContext* ctx) override;

 private:
  const XlaPlatformInfo platform_info_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_