1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_KERNEL_AND_DEVICE_H_
17 #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_KERNEL_AND_DEVICE_H_
18 
19 // Support for eager execution of TensorFlow kernels.
20 
21 #include <memory>
22 #include <unordered_map>
23 
24 // clang-format off
25 // Required for IS_MOBILE_PLATFORM
26 #include "absl/memory/memory.h"
27 #include "tensorflow/core/platform/platform.h"
28 // clang-format on
29 
30 #include "absl/container/flat_hash_map.h"
31 #include "absl/types/optional.h"
32 #include "tensorflow/core/common_runtime/device.h"
33 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
34 #include "tensorflow/core/framework/cancellation.h"
35 #include "tensorflow/core/framework/collective.h"
36 #include "tensorflow/core/framework/node_def.pb.h"
37 #include "tensorflow/core/framework/op_kernel.h"
38 #include "tensorflow/core/framework/types.h"
39 #include "tensorflow/core/lib/core/errors.h"
40 #include "tensorflow/core/lib/core/status.h"
41 #include "tensorflow/core/lib/gtl/inlined_vector.h"
42 #include "tensorflow/core/platform/fingerprint.h"
43 #include "tensorflow/core/util/managed_stack_trace.h"
44 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
45 #if !defined(IS_MOBILE_PLATFORM)
46 #include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
47 #endif  // IS_MOBILE_PLATFORM
48 
49 namespace tensorflow {
50 
// Forward declarations of the function-library runtime classes that the
// kernel wrappers declared in this header hold pointers to.
class FunctionLibraryRuntime;
class ProcessFunctionLibraryRuntime;

// The literal string "_OutputsOnOpDevice".
// NOTE(review): presumably the attribute name behind KernelAndDeviceFunc's
// `outputs_on_op_device` constructor flag -- confirm at the use sites.
// `constexpr` already makes the pointer itself const, so the declared type is
// `const char* const`.
static constexpr const char* kOutputsOnOpDevice = "_OutputsOnOpDevice";
55 
56 struct EagerRemoteFunctionParams {
57   int64 op_id;
58   // Set when this function is a component function.
59   absl::optional<int64> step_id = absl::nullopt;
60 };
61 
62 class EagerKernelArgs : public FunctionArgsInterface {
63  public:
EagerKernelArgs()64   EagerKernelArgs() {}
65 
EagerKernelArgs(int count)66   explicit EagerKernelArgs(int count) : tensor_args_(count) {}
67 
EagerKernelArgs(gtl::InlinedVector<TensorValue,4> && tensor_args)68   explicit EagerKernelArgs(gtl::InlinedVector<TensorValue, 4>&& tensor_args)
69       : tensor_args_(std::move(tensor_args)) {}
70 
~EagerKernelArgs()71   ~EagerKernelArgs() override{};
72 
HasRemoteOrPackedInputs()73   bool HasRemoteOrPackedInputs() const override { return false; };
MutableInput(int i)74   TensorValue* MutableInput(int i) { return &tensor_args_[i]; }
75 
76   Status GetLocalArg(const FunctionArgIndex& index, Tensor* val) const override;
77 
78   std::vector<Tensor> GetLocalTensors() const override;
79 
GetTensorValues()80   const gtl::InlinedVector<TensorValue, 4>* GetTensorValues() const {
81     return &tensor_args_;
82   }
83 
84  protected:
85   gtl::InlinedVector<TensorValue, 4> tensor_args_;
86 };
87 
88 typedef absl::variant<Tensor, TensorShape> EagerKernelRet;
89 
90 // KernelAndDevice encapsulates the logic needed to run a computation eagerly.
91 // The computation can be a single instantiated kernel (implemented by
92 // KernelAndDeviceOp below) or a multi-device function (implemented by
93 // KernelAndDeviceFunc below).
94 //
95 // Also see:
96 // https://www.tensorflow.org/code/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
97 // and
98 // https://www.tensorflow.org/code/tensorflow/core/kernels/ops_testutil.h
99 class KernelAndDevice : public core::RefCounted {
100  public:
101   // Populates this with a kernel appropriate for 'ndef'.
102   //
103   // The provided FunctionLibraryRuntime MUST outlive all calls to
104   // Run() on the returned KernelAndDevice.
105   virtual Status Init(const bool log_device_placement, const NodeDef& ndef,
106                       GraphCollector* graph_collector) = 0;
107 
108   // Non-multi-device functions are run using regular CallOp and look like
109   // primitive operations from KernelAndDevice perspective.
110   // `flr` can be nullptr if the operation is not run on any specific device
111   // (currently can happen only for multi-device functions).
KernelAndDevice(FunctionLibraryRuntime * flr,std::function<void (std::function<void ()>)> * runner,std::unique_ptr<CollectiveExecutor::Handle> collective_executor,Device * host_cpu_device)112   KernelAndDevice(
113       FunctionLibraryRuntime* flr,
114       std::function<void(std::function<void()>)>* runner,
115       std::unique_ptr<CollectiveExecutor::Handle> collective_executor,
116       Device* host_cpu_device)
117       : device_(flr == nullptr ? nullptr : flr->device()),
118         host_cpu_device_(host_cpu_device),
119         flr_(flr),
120         collective_executor_(std::move(collective_executor)),
121         runner_(runner) {}
122 
123   // Not thread safe.
~KernelAndDevice()124   ~KernelAndDevice() override {}
125 
IsFunction()126   virtual bool IsFunction() { return false; }
127 
IsCrossProcess()128   virtual bool IsCrossProcess() { return false; }
129 
130   // TODO(ashankar): Handle list-valued inputs.
131   virtual Status Run(
132       ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
133       std::vector<EagerKernelRet>* outputs,
134       CancellationManager* cancellation_manager,
135       const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
136       const absl::optional<ManagedStackTrace>& stack_trace) = 0;
137 
138   // Execute kernel asynchronously when applicable. Different from `Run` which
139   // blocks the caller thread and waits for the execution of the op/function,
140   // `RunAsync` could return before finishing the execution. The `done` callback
141   // will be triggered once the op/function execution finishes.
142   // Currently, calling RunAsync on ops might not honor the asynchronicity when
143   // it is called on an instance with only sync implementation, execute the
144   // kernel synchronously and then call the callback with the return status
145   // from sync execution.
146   virtual void RunAsync(
147       ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
148       std::vector<EagerKernelRet>* outputs,
149       CancellationManager* cancellation_manager,
150       const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
151       StatusCallback done) = 0;
152 
153   virtual Device* InputDevice(int i) const = 0;
154   virtual Device* OutputDevice(int idx) const = 0;
155   // If idx'th output is a resource, returns the device backing the resource.
156   // Else, returns nullptr.
157   virtual Device* OutputResourceDevice(int idx) const = 0;
158 
159   // Returns the kernel that will be used to run this.
160   // Returns nullptr if this will be run using function library runtime.
161   virtual const OpKernel* kernel() const = 0;
162 
163   // Returns the device on which this kernel will run. In the case of
164   // multi-device functions, this is the default device that is passed to the
165   // placer but actual computation can happen on a different set of devices.
166   // Also, outputs can be produced on devices different from what this method
167   // returns.
device()168   Device* device() const { return device_; }
169 
170   virtual const DataTypeVector& input_dtypes() const = 0;
171   virtual const DataTypeVector& output_dtypes() const = 0;
172 
173   virtual int num_inputs() const = 0;
174   virtual int num_outputs() const = 0;
175   virtual const string& name() const = 0;
176 
177  protected:
178   std::function<void(std::function<void()>)>* get_runner() const;
179 
180   Device* const device_;               // can be null
181   Device* const host_cpu_device_;      // non-null
182   FunctionLibraryRuntime* const flr_;  // can be null
183   const std::unique_ptr<CollectiveExecutor::Handle> collective_executor_;
184 
185  private:
186   std::function<void(std::function<void()>)>* const runner_;  // can be null
187 };
188 
189 // Represents an op kernel and the device it will be run on.
190 class KernelAndDeviceOp final : public KernelAndDevice {
191  public:
KernelAndDeviceOp(tensorflow::Rendezvous * rendezvous,bool log_memory,FunctionLibraryRuntime * flr,std::function<void (std::function<void ()>)> * runner,std::unique_ptr<CollectiveExecutor::Handle> collective_executor,Device * host_cpu_device)192   KernelAndDeviceOp(
193       tensorflow::Rendezvous* rendezvous, bool log_memory,
194       FunctionLibraryRuntime* flr,
195       std::function<void(std::function<void()>)>* runner,
196       std::unique_ptr<CollectiveExecutor::Handle> collective_executor,
197       Device* host_cpu_device)
198       : KernelAndDevice(flr, runner, std::move(collective_executor),
199                         host_cpu_device),
200         rendezvous_(rendezvous),
201         log_memory_(log_memory) {}
202 
~KernelAndDeviceOp()203   ~KernelAndDeviceOp() override {}
204 
205   Status Init(const bool log_device_placement, const NodeDef& ndef,
206               GraphCollector* graph_collector) override;
207 
208   Status Run(
209       ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
210       std::vector<EagerKernelRet>* outputs,
211       CancellationManager* cancellation_manager,
212       const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
213       const absl::optional<ManagedStackTrace>& stack_trace) override;
214 
RunAsync(ScopedStepContainer * step_container,const EagerKernelArgs & inputs,std::vector<EagerKernelRet> * outputs,CancellationManager * cancellation_manager,const absl::optional<EagerRemoteFunctionParams> & remote_func_params,StatusCallback done)215   void RunAsync(
216       ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
217       std::vector<EagerKernelRet>* outputs,
218       CancellationManager* cancellation_manager,
219       const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
220       StatusCallback done) override {
221     // Trivial async implementation on top of the sync version
222     done(Run(step_container, inputs, outputs, cancellation_manager,
223              remote_func_params, {}));
224   }
225 
kernel()226   const OpKernel* kernel() const override { return kernel_.get(); }
227 
228   Device* InputDevice(int i) const override;
229   Device* OutputDevice(int idx) const override;
230   Device* OutputResourceDevice(int idx) const override;
231 
input_dtypes()232   const DataTypeVector& input_dtypes() const override {
233     return kernel_->input_types();
234   }
output_dtypes()235   const DataTypeVector& output_dtypes() const override {
236     return kernel_->output_types();
237   }
num_inputs()238   int num_inputs() const override { return kernel_->num_inputs(); }
num_outputs()239   int num_outputs() const override { return kernel_->num_outputs(); }
name()240   const string& name() const override { return kernel_->name(); }
241 
242  private:
243   std::unique_ptr<OpKernel> kernel_;
244   gtl::InlinedVector<AllocatorAttributes, 4> input_alloc_attrs_;
245   std::vector<Device*> input_devices_;
246   gtl::InlinedVector<AllocatorAttributes, 1> output_alloc_attrs_;
247   Rendezvous* const rendezvous_;
248   checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_;
249   const bool log_memory_;
250 };
251 
252 // Represents a multi-device function. Functions can also be run using
253 // various function-calling kernels including CallOp and PartitionedCallOp.
254 // In such cases, KernelAndDeviceOp is used.
255 class KernelAndDeviceFunc : public KernelAndDevice {
256  public:
257   // `flr` can be nullptr.
258   // `pflr` must not be nullptr.
259   // `host_cpu_device` must not be nullptr.
KernelAndDeviceFunc(FunctionLibraryRuntime * flr,ProcessFunctionLibraryRuntime * pflr,std::vector<Device * > input_devices,absl::flat_hash_map<string,const std::vector<string> * > composite_devices,std::unordered_map<int,DtypeAndPartialTensorShape> input_resource_dtypes_and_shapes,std::function<void (std::function<void ()>)> * runner,std::unique_ptr<CollectiveExecutor::Handle> collective_executor,Device * host_cpu_device,const string & name,const bool outputs_on_op_device,std::function<Rendezvous * (const int64)> rendezvous_creator,std::function<int64 ()> get_op_id)260   KernelAndDeviceFunc(
261       FunctionLibraryRuntime* flr, ProcessFunctionLibraryRuntime* pflr,
262       std::vector<Device*> input_devices,
263       absl::flat_hash_map<string, const std::vector<string>*> composite_devices,
264       std::unordered_map<int, DtypeAndPartialTensorShape>
265           input_resource_dtypes_and_shapes,
266       std::function<void(std::function<void()>)>* runner,
267       std::unique_ptr<CollectiveExecutor::Handle> collective_executor,
268       Device* host_cpu_device, const string& name,
269       const bool outputs_on_op_device,
270       std::function<Rendezvous*(const int64)> rendezvous_creator,
271       std::function<int64()> get_op_id)
272       : KernelAndDevice(flr, runner, std::move(collective_executor),
273                         host_cpu_device),
274         pflr_(pflr),
275         handle_(kInvalidHandle),
276         outputs_on_op_device_(outputs_on_op_device),
277         input_devices_(std::move(input_devices)),
278         composite_devices_(std::move(composite_devices)),
279         input_resource_dtypes_and_shapes_(
280             std::move(input_resource_dtypes_and_shapes)),
281         name_(name),
282         rendezvous_creator_(std::move(rendezvous_creator)),
283         get_op_id_(std::move(get_op_id)) {}
284 
285   ~KernelAndDeviceFunc() override;
286 
IsFunction()287   bool IsFunction() override { return true; };
288 
IsCrossProcess()289   bool IsCrossProcess() override { return is_cross_process_; }
290 
291   Status InstantiateFunc(const bool log_device_placement, const NodeDef& ndef,
292                          GraphCollector* graph_collector);
293 
294   Status Init(const bool log_device_placement, const NodeDef& ndef,
295               GraphCollector* graph_collector) override;
296 
297   Status Run(
298       ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
299       std::vector<EagerKernelRet>* outputs,
300       CancellationManager* cancellation_manager,
301       const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
302       const absl::optional<ManagedStackTrace>& stack_trace) override;
303 
304   void RunAsync(
305       ScopedStepContainer* step_container, const EagerKernelArgs& inputs,
306       std::vector<EagerKernelRet>* outputs,
307       CancellationManager* cancellation_manager,
308       const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
309       StatusCallback done) override;
310 
kernel()311   const OpKernel* kernel() const override { return nullptr; }
312 
313   Device* InputDevice(int i) const override;
314   Device* OutputDevice(int idx) const override;
315   Device* OutputResourceDevice(int idx) const override;
316 
input_dtypes()317   const DataTypeVector& input_dtypes() const override { return input_dtypes_; }
output_dtypes()318   const DataTypeVector& output_dtypes() const override {
319     return output_dtypes_;
320   }
num_inputs()321   int num_inputs() const override { return input_dtypes_.size(); }
num_outputs()322   int num_outputs() const override { return output_dtypes_.size(); }
name()323   const string& name() const override { return name_; };
324 
325  private:
326   ProcessFunctionLibraryRuntime* const pflr_;  // non-null
327   FunctionLibraryRuntime::Handle handle_;
328   // Indicates whether the function needs to execute cross process.
329   bool is_cross_process_;
330 
331   // If true, function outputs are explicitly assigned to the default device;
332   // if false, the output devices are inferred by pflr_.
333   bool outputs_on_op_device_;
334 
335   // CPU devices are null. Resource handles' devices are actual backing
336   // devices.
337   std::vector<Device*> output_devices_;
338   // CPU devices are not null. Resource handles' devices are actual backing
339   // devices.
340   std::vector<Device*> input_devices_;
341   // Maps from a CompositeDevice name to a list of physical device names.
342   absl::flat_hash_map<string, const std::vector<string>*> composite_devices_;
343   std::unordered_map<int, DtypeAndPartialTensorShape>
344       input_resource_dtypes_and_shapes_;
345 
346   DataTypeVector input_dtypes_;
347   DataTypeVector output_dtypes_;
348   string name_;
349 
350   std::function<Rendezvous*(const int64)> rendezvous_creator_;
351   std::function<int64()> get_op_id_;
352 };
353 
354 }  // namespace tensorflow
355 
356 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_KERNEL_AND_DEVICE_H_
357