1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_KERNEL_AND_DEVICE_H_ 17 #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_KERNEL_AND_DEVICE_H_ 18 19 // Support for eager execution of TensorFlow kernels. 20 21 #include <memory> 22 #include <unordered_map> 23 24 // clang-format off 25 // Required for IS_MOBILE_PLATFORM 26 #include "absl/memory/memory.h" 27 #include "tensorflow/core/platform/platform.h" 28 // clang-format on 29 30 #include "absl/container/flat_hash_map.h" 31 #include "absl/types/optional.h" 32 #include "tensorflow/core/common_runtime/device.h" 33 #include "tensorflow/core/common_runtime/process_function_library_runtime.h" 34 #include "tensorflow/core/framework/cancellation.h" 35 #include "tensorflow/core/framework/collective.h" 36 #include "tensorflow/core/framework/node_def.pb.h" 37 #include "tensorflow/core/framework/op_kernel.h" 38 #include "tensorflow/core/framework/types.h" 39 #include "tensorflow/core/lib/core/errors.h" 40 #include "tensorflow/core/lib/core/status.h" 41 #include "tensorflow/core/lib/gtl/inlined_vector.h" 42 #include "tensorflow/core/platform/fingerprint.h" 43 #include "tensorflow/core/util/managed_stack_trace.h" 44 #include "tensorflow/core/util/tensor_slice_reader_cache.h" 45 #if !defined(IS_MOBILE_PLATFORM) 46 #include "tensorflow/core/protobuf/remote_tensor_handle.pb.h" 47 #endif // IS_MOBILE_PLATFORM 48 49 namespace tensorflow { 50 51 static constexpr const char* const kOutputsOnOpDevice = "_OutputsOnOpDevice"; 52 53 class ProcessFunctionLibraryRuntime; 54 class FunctionLibraryRuntime; 55 56 struct EagerRemoteFunctionParams { 57 int64 op_id; 58 // Set when this function is a component function. 59 absl::optional<int64> step_id = absl::nullopt; 60 }; 61 62 class EagerKernelArgs : public FunctionArgsInterface { 63 public: EagerKernelArgs()64 EagerKernelArgs() {} 65 EagerKernelArgs(int count)66 explicit EagerKernelArgs(int count) : tensor_args_(count) {} 67 EagerKernelArgs(gtl::InlinedVector<TensorValue,4> && tensor_args)68 explicit EagerKernelArgs(gtl::InlinedVector<TensorValue, 4>&& tensor_args) 69 : tensor_args_(std::move(tensor_args)) {} 70 ~EagerKernelArgs()71 ~EagerKernelArgs() override{}; 72 HasRemoteOrPackedInputs()73 bool HasRemoteOrPackedInputs() const override { return false; }; MutableInput(int i)74 TensorValue* MutableInput(int i) { return &tensor_args_[i]; } 75 76 Status GetLocalArg(const FunctionArgIndex& index, Tensor* val) const override; 77 78 std::vector<Tensor> GetLocalTensors() const override; 79 GetTensorValues()80 const gtl::InlinedVector<TensorValue, 4>* GetTensorValues() const { 81 return &tensor_args_; 82 } 83 84 protected: 85 gtl::InlinedVector<TensorValue, 4> tensor_args_; 86 }; 87 88 typedef absl::variant<Tensor, TensorShape> EagerKernelRet; 89 90 // KernelAndDevice encapsulates the logic needed to run a computation eagerly. 91 // The computation can be a single instantiated kernel (implemented by 92 // KernelAndDeviceOp below) or a multi-device function (implemented by 93 // KernelAndDeviceFunc below). 94 // 95 // Also see: 96 // https://www.tensorflow.org/code/tensorflow/core/common_runtime/kernel_benchmark_testlib.h 97 // and 98 // https://www.tensorflow.org/code/tensorflow/core/kernels/ops_testutil.h 99 class KernelAndDevice : public core::RefCounted { 100 public: 101 // Populates this with a kernel appropriate for 'ndef'. 102 // 103 // The provided FunctionLibraryRuntime MUST outlive all calls to 104 // Run() on the returned KernelAndDevice. 105 virtual Status Init(const bool log_device_placement, const NodeDef& ndef, 106 GraphCollector* graph_collector) = 0; 107 108 // Non-multi-device functions are run using regular CallOp and look like 109 // primitive operations from KernelAndDevice perspective. 110 // `flr` can be nullptr if the operation is not run on any specific device 111 // (currently can happen only for multi-device functions). KernelAndDevice(FunctionLibraryRuntime * flr,std::function<void (std::function<void ()>)> * runner,std::unique_ptr<CollectiveExecutor::Handle> collective_executor,Device * host_cpu_device)112 KernelAndDevice( 113 FunctionLibraryRuntime* flr, 114 std::function<void(std::function<void()>)>* runner, 115 std::unique_ptr<CollectiveExecutor::Handle> collective_executor, 116 Device* host_cpu_device) 117 : device_(flr == nullptr ? nullptr : flr->device()), 118 host_cpu_device_(host_cpu_device), 119 flr_(flr), 120 collective_executor_(std::move(collective_executor)), 121 runner_(runner) {} 122 123 // Not thread safe. ~KernelAndDevice()124 ~KernelAndDevice() override {} 125 IsFunction()126 virtual bool IsFunction() { return false; } 127 IsCrossProcess()128 virtual bool IsCrossProcess() { return false; } 129 130 // TODO(ashankar): Handle list-valued inputs. 131 virtual Status Run( 132 ScopedStepContainer* step_container, const EagerKernelArgs& inputs, 133 std::vector<EagerKernelRet>* outputs, 134 CancellationManager* cancellation_manager, 135 const absl::optional<EagerRemoteFunctionParams>& remote_func_params, 136 const absl::optional<ManagedStackTrace>& stack_trace) = 0; 137 138 // Execute kernel asynchronously when applicable. Different from `Run` which 139 // blocks the caller thread and waits for the execution of the op/function, 140 // `RunAsync` could return before finishing the execution. The `done` callback 141 // will be triggered once the op/function execution finishes. 142 // Currently, calling RunAsync on ops might not honor the asynchronicity when 143 // it is called on an instance with only sync implementation, execute the 144 // kernel synchronously and then call the callback with the return status 145 // from sync execution. 146 virtual void RunAsync( 147 ScopedStepContainer* step_container, const EagerKernelArgs& inputs, 148 std::vector<EagerKernelRet>* outputs, 149 CancellationManager* cancellation_manager, 150 const absl::optional<EagerRemoteFunctionParams>& remote_func_params, 151 StatusCallback done) = 0; 152 153 virtual Device* InputDevice(int i) const = 0; 154 virtual Device* OutputDevice(int idx) const = 0; 155 // If idx'th output is a resource, returns the device backing the resource. 156 // Else, returns nullptr. 157 virtual Device* OutputResourceDevice(int idx) const = 0; 158 159 // Returns the kernel that will be used to run this. 160 // Returns nullptr if this will be run using function library runtime. 161 virtual const OpKernel* kernel() const = 0; 162 163 // Returns the device on which this kernel will run. In the case of 164 // multi-device functions, this is the default device that is passed to the 165 // placer but actual computation can happen on a different set of devices. 166 // Also, outputs can be produced on devices different from what this method 167 // returns. device()168 Device* device() const { return device_; } 169 170 virtual const DataTypeVector& input_dtypes() const = 0; 171 virtual const DataTypeVector& output_dtypes() const = 0; 172 173 virtual int num_inputs() const = 0; 174 virtual int num_outputs() const = 0; 175 virtual const string& name() const = 0; 176 177 protected: 178 std::function<void(std::function<void()>)>* get_runner() const; 179 180 Device* const device_; // can be null 181 Device* const host_cpu_device_; // non-null 182 FunctionLibraryRuntime* const flr_; // can be null 183 const std::unique_ptr<CollectiveExecutor::Handle> collective_executor_; 184 185 private: 186 std::function<void(std::function<void()>)>* const runner_; // can be null 187 }; 188 189 // Represents an op kernel and the device it will be run on. 190 class KernelAndDeviceOp final : public KernelAndDevice { 191 public: KernelAndDeviceOp(tensorflow::Rendezvous * rendezvous,bool log_memory,FunctionLibraryRuntime * flr,std::function<void (std::function<void ()>)> * runner,std::unique_ptr<CollectiveExecutor::Handle> collective_executor,Device * host_cpu_device)192 KernelAndDeviceOp( 193 tensorflow::Rendezvous* rendezvous, bool log_memory, 194 FunctionLibraryRuntime* flr, 195 std::function<void(std::function<void()>)>* runner, 196 std::unique_ptr<CollectiveExecutor::Handle> collective_executor, 197 Device* host_cpu_device) 198 : KernelAndDevice(flr, runner, std::move(collective_executor), 199 host_cpu_device), 200 rendezvous_(rendezvous), 201 log_memory_(log_memory) {} 202 ~KernelAndDeviceOp()203 ~KernelAndDeviceOp() override {} 204 205 Status Init(const bool log_device_placement, const NodeDef& ndef, 206 GraphCollector* graph_collector) override; 207 208 Status Run( 209 ScopedStepContainer* step_container, const EagerKernelArgs& inputs, 210 std::vector<EagerKernelRet>* outputs, 211 CancellationManager* cancellation_manager, 212 const absl::optional<EagerRemoteFunctionParams>& remote_func_params, 213 const absl::optional<ManagedStackTrace>& stack_trace) override; 214 RunAsync(ScopedStepContainer * step_container,const EagerKernelArgs & inputs,std::vector<EagerKernelRet> * outputs,CancellationManager * cancellation_manager,const absl::optional<EagerRemoteFunctionParams> & remote_func_params,StatusCallback done)215 void RunAsync( 216 ScopedStepContainer* step_container, const EagerKernelArgs& inputs, 217 std::vector<EagerKernelRet>* outputs, 218 CancellationManager* cancellation_manager, 219 const absl::optional<EagerRemoteFunctionParams>& remote_func_params, 220 StatusCallback done) override { 221 // Trivial async implementation on top of the sync version 222 done(Run(step_container, inputs, outputs, cancellation_manager, 223 remote_func_params, {})); 224 } 225 kernel()226 const OpKernel* kernel() const override { return kernel_.get(); } 227 228 Device* InputDevice(int i) const override; 229 Device* OutputDevice(int idx) const override; 230 Device* OutputResourceDevice(int idx) const override; 231 input_dtypes()232 const DataTypeVector& input_dtypes() const override { 233 return kernel_->input_types(); 234 } output_dtypes()235 const DataTypeVector& output_dtypes() const override { 236 return kernel_->output_types(); 237 } num_inputs()238 int num_inputs() const override { return kernel_->num_inputs(); } num_outputs()239 int num_outputs() const override { return kernel_->num_outputs(); } name()240 const string& name() const override { return kernel_->name(); } 241 242 private: 243 std::unique_ptr<OpKernel> kernel_; 244 gtl::InlinedVector<AllocatorAttributes, 4> input_alloc_attrs_; 245 std::vector<Device*> input_devices_; 246 gtl::InlinedVector<AllocatorAttributes, 1> output_alloc_attrs_; 247 Rendezvous* const rendezvous_; 248 checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_; 249 const bool log_memory_; 250 }; 251 252 // Represents a multi-device function. Functions can also be run using 253 // various function-calling kernels including CallOp and PartitionedCallOp. 254 // In such cases, KernelAndDeviceOp is used. 255 class KernelAndDeviceFunc : public KernelAndDevice { 256 public: 257 // `flr` can be nullptr. 258 // `pflr` must not be nullptr. 259 // `host_cpu_device` must not be nullptr. KernelAndDeviceFunc(FunctionLibraryRuntime * flr,ProcessFunctionLibraryRuntime * pflr,std::vector<Device * > input_devices,absl::flat_hash_map<string,const std::vector<string> * > composite_devices,std::unordered_map<int,DtypeAndPartialTensorShape> input_resource_dtypes_and_shapes,std::function<void (std::function<void ()>)> * runner,std::unique_ptr<CollectiveExecutor::Handle> collective_executor,Device * host_cpu_device,const string & name,const bool outputs_on_op_device,std::function<Rendezvous * (const int64)> rendezvous_creator,std::function<int64 ()> get_op_id)260 KernelAndDeviceFunc( 261 FunctionLibraryRuntime* flr, ProcessFunctionLibraryRuntime* pflr, 262 std::vector<Device*> input_devices, 263 absl::flat_hash_map<string, const std::vector<string>*> composite_devices, 264 std::unordered_map<int, DtypeAndPartialTensorShape> 265 input_resource_dtypes_and_shapes, 266 std::function<void(std::function<void()>)>* runner, 267 std::unique_ptr<CollectiveExecutor::Handle> collective_executor, 268 Device* host_cpu_device, const string& name, 269 const bool outputs_on_op_device, 270 std::function<Rendezvous*(const int64)> rendezvous_creator, 271 std::function<int64()> get_op_id) 272 : KernelAndDevice(flr, runner, std::move(collective_executor), 273 host_cpu_device), 274 pflr_(pflr), 275 handle_(kInvalidHandle), 276 outputs_on_op_device_(outputs_on_op_device), 277 input_devices_(std::move(input_devices)), 278 composite_devices_(std::move(composite_devices)), 279 input_resource_dtypes_and_shapes_( 280 std::move(input_resource_dtypes_and_shapes)), 281 name_(name), 282 rendezvous_creator_(std::move(rendezvous_creator)), 283 get_op_id_(std::move(get_op_id)) {} 284 285 ~KernelAndDeviceFunc() override; 286 IsFunction()287 bool IsFunction() override { return true; }; 288 IsCrossProcess()289 bool IsCrossProcess() override { return is_cross_process_; } 290 291 Status InstantiateFunc(const bool log_device_placement, const NodeDef& ndef, 292 GraphCollector* graph_collector); 293 294 Status Init(const bool log_device_placement, const NodeDef& ndef, 295 GraphCollector* graph_collector) override; 296 297 Status Run( 298 ScopedStepContainer* step_container, const EagerKernelArgs& inputs, 299 std::vector<EagerKernelRet>* outputs, 300 CancellationManager* cancellation_manager, 301 const absl::optional<EagerRemoteFunctionParams>& remote_func_params, 302 const absl::optional<ManagedStackTrace>& stack_trace) override; 303 304 void RunAsync( 305 ScopedStepContainer* step_container, const EagerKernelArgs& inputs, 306 std::vector<EagerKernelRet>* outputs, 307 CancellationManager* cancellation_manager, 308 const absl::optional<EagerRemoteFunctionParams>& remote_func_params, 309 StatusCallback done) override; 310 kernel()311 const OpKernel* kernel() const override { return nullptr; } 312 313 Device* InputDevice(int i) const override; 314 Device* OutputDevice(int idx) const override; 315 Device* OutputResourceDevice(int idx) const override; 316 input_dtypes()317 const DataTypeVector& input_dtypes() const override { return input_dtypes_; } output_dtypes()318 const DataTypeVector& output_dtypes() const override { 319 return output_dtypes_; 320 } num_inputs()321 int num_inputs() const override { return input_dtypes_.size(); } num_outputs()322 int num_outputs() const override { return output_dtypes_.size(); } name()323 const string& name() const override { return name_; }; 324 325 private: 326 ProcessFunctionLibraryRuntime* const pflr_; // non-null 327 FunctionLibraryRuntime::Handle handle_; 328 // Indicates whether the function needs to execute cross process. 329 bool is_cross_process_; 330 331 // If true, function outputs are explicitly assigned to the default device; 332 // if false, the output devices are inferred by pflr_. 333 bool outputs_on_op_device_; 334 335 // CPU devices are null. Resource handles' devices are actual backing 336 // devices. 337 std::vector<Device*> output_devices_; 338 // CPU devices are not null. Resource handles' devices are actual backing 339 // devices. 340 std::vector<Device*> input_devices_; 341 // Maps from a CompositeDevice name to a list of physical device names. 342 absl::flat_hash_map<string, const std::vector<string>*> composite_devices_; 343 std::unordered_map<int, DtypeAndPartialTensorShape> 344 input_resource_dtypes_and_shapes_; 345 346 DataTypeVector input_dtypes_; 347 DataTypeVector output_dtypes_; 348 string name_; 349 350 std::function<Rendezvous*(const int64)> rendezvous_creator_; 351 std::function<int64()> get_op_id_; 352 }; 353 354 } // namespace tensorflow 355 356 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_KERNEL_AND_DEVICE_H_ 357