1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// A Device is something that can perform computations as part of a
// model.  Devices can be local (running computation on this machine), or
// remote (contacting a device local to another machine using an RPC to
// do the work).  Devices are registered in a DeviceSet, which is also
// responsible for the Device <-> id mapping.
21 //
22 // Device names
// * Every Device should have a unique name with the format:
//     /job:___/replica:___/task:___/device:<TYPE>:___
//   An example name would be "/job:train/replica:0/task:3/device:GPU:2".
26 // * Task numbers are within the specified replica, so there are as
27 //   many "task zeros" as replicas.
28 
29 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_H_
30 #define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_H_
31 
32 #include <memory>
33 #include <string>
34 
35 #include "tensorflow/core/framework/allocator.h"
36 #include "tensorflow/core/framework/control_flow.h"
37 #include "tensorflow/core/framework/device_attributes.pb_text.h"
38 #include "tensorflow/core/framework/device_attributes.pb.h"
39 #include "tensorflow/core/framework/device_base.h"
40 #include "tensorflow/core/framework/graph.pb.h"
41 #include "tensorflow/core/framework/op_kernel.h"
42 #include "tensorflow/core/framework/op_segment.h"
43 #include "tensorflow/core/framework/resource_mgr.h"
44 #include "tensorflow/core/framework/types.h"
45 #include "tensorflow/core/graph/graph.h"
46 #include "tensorflow/core/graph/types.h"
47 #include "tensorflow/core/lib/core/errors.h"
48 #include "tensorflow/core/lib/core/status.h"
49 #include "tensorflow/core/platform/macros.h"
50 #include "tensorflow/core/platform/types.h"
51 #include "tensorflow/core/util/device_name_utils.h"
52 
53 namespace tensorflow {
54 
55 class DeviceMgr;
56 
57 class Device : public DeviceBase {
58  public:
59   // Callback type that takes a Status and returns void.
60   typedef std::function<void(const Status&)> DoneCallback;
61 
62   Device(Env* env, const DeviceAttributes& device_attributes);
63   ~Device() override;
64 
65   // Full name of this device (see top comment).
name()66   const string& name() const override { return device_attributes_.name(); }
67 
68   // Parsed name of this device
parsed_name()69   const DeviceNameUtils::ParsedName& parsed_name() const {
70     return parsed_name_;
71   }
72 
73   // Describes what kind of device this is.  This is intended to be
74   // human-readable and not computer-parsed, except that two devices
75   // with the same device_type() are expected to perform similarly
76   // (both from a computation and communication perspective).
device_type()77   const string& device_type() const { return device_attributes_.device_type(); }
78 
79   // Returns an aggregation of device attributes.
attributes()80   const DeviceAttributes& attributes() const override {
81     return device_attributes_;
82   }
83 
84   // Performs the actual compute function.
85   //
86   // Subclasses may override this function if they wish to perform
87   // some initialization before each compute.
Compute(OpKernel * op_kernel,OpKernelContext * context)88   virtual void Compute(OpKernel* op_kernel, OpKernelContext* context) {
89     op_kernel->Compute(context);
90   }
91 
92   // Asynchronous kernel's compute.
ComputeAsync(AsyncOpKernel * op_kernel,OpKernelContext * context,AsyncOpKernel::DoneCallback done)93   virtual void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
94                             AsyncOpKernel::DoneCallback done) {
95     op_kernel->ComputeAsync(context, std::move(done));
96   }
97 
98   // Takes ownership of the references in tensors. If necessary, a
99   // device may override this method to keep a reference to the
100   // accessed tensors until the async computation has completed.
ConsumeListOfAccessedTensors(DeviceContext * context,const TensorReferenceVector & tensors)101   virtual void ConsumeListOfAccessedTensors(
102       DeviceContext* context, const TensorReferenceVector& tensors) {
103     for (const auto& ref : tensors) {
104       ref.Unref();
105     }
106   }
107 
108   // If true, and tracing is enabled, the `tracing::ScopedAnnotation()` tracing
109   // mechanism will be used instead of `tracing::ScopedActivity()`. Some devices
110   // may override this method to use annotations, which enable child activities
111   // (such as GPU kernel launches) to be related to the OpKernel invocation.
TraceUsingAnnotations()112   virtual bool TraceUsingAnnotations() const { return false; }
113 
114   // Blocks until all operations queued on the device at the time of
115   // the call have completed.  Returns any error pending on the device
116   // at completion.
117   virtual Status Sync() = 0;
118 
119   // Calls the given callback when all operations queued on the device at the
120   // time of the call have completed. The callback is passed any error pending
121   // on the device at completion.
122   // TODO(b/112409994): Consolidate these two APIs, removing the synchronous
123   // version.
124   virtual void Sync(const DoneCallback& done);
125 
126   // On session completion, the executor may call Device::Sync() depending on
127   // flag settings. Override this to return false for devices that don't allow
128   // such calls. Instead, these devices must use other mechanisms (such as
129   // num_deferred_ops) to ensure the device has finished processing necessary
130   // work at session completion. In addition, for these devices, RefreshStatus
131   // must be called at session completion to retrieve execution result status.
132   //
133   // Devices that override this function must also implement RefreshStatus.
AllowsSyncOnCompletion()134   virtual bool AllowsSyncOnCompletion() const { return true; }
135 
136   // This is used in conjunction with AllowsSyncOnCompletion to allow the
137   // executor to get execution result status at session completion.
138   //
139   // For supported devices, this call returns the underlying device stream's
140   // current status in a non-blocking way, without using blocking calls such as
141   // Stream::BlockHostUntilDone or Device::Sync. When applicable, the device
142   // status is also updated with the retrieved stream status.
RefreshStatus()143   virtual Status RefreshStatus() {
144     return errors::Unimplemented(
145         "RefreshStatus is not supported on this device.");
146   }
147 
148   // Optionally modify the device's GraphDef before execution.
149   //
150   // This method should be considered experimental and is supplied to enable
151   // prototyping of TensorFlow device implementations that need to modify
152   // the GraphDef before execution.
153   //
154   // 'graph' supplies the partition of the graph assigned to this
155   // device.
MaybeRewriteGraph(std::unique_ptr<Graph> *)156   virtual Status MaybeRewriteGraph(std::unique_ptr<Graph>* /*graph*/) {
157     return Status::OK();
158   }
159 
160   // Fill in the context map for the graph. Default behavior is to do
161   // nothing.
162   //
163   // The caller takes ownership over the DeviceContext objects given
164   // by the device.
FillContextMap(const Graph * graph,DeviceContextMap * device_context_map)165   virtual Status FillContextMap(const Graph* graph,
166                                 DeviceContextMap* device_context_map) {
167     return Status::OK();
168   }
169 
170   // Returns the op segment of this device.  The caller can reuse op
171   // kernels registered for the same session running on this device.
op_segment()172   OpSegment* op_segment() { return &op_seg_; }
173 
174   // Returns the resource manager associated w/ this device.
resource_manager()175   virtual ResourceMgr* resource_manager() { return rmgr_; }
176 
177   // Returns the device manager that owns this device, or nullptr if this Device
178   // is not owned by a device manager.
device_mgr()179   DeviceMgr* device_mgr() const { return device_mgr_; }
180 
181   // Summarizes the status of this Device, for debugging.
DebugString()182   string DebugString() const { return ProtoDebugString(device_attributes_); }
183 
184   // Assembles the parameter components into a complete DeviceAttributes value.
185   static DeviceAttributes BuildDeviceAttributes(
186       const string& name, DeviceType device, Bytes memory_limit,
187       const DeviceLocality& locality, const string& physical_device_desc);
188 
BuildDeviceAttributes(const string & name,DeviceType device,Bytes memory_limit,const DeviceLocality & locality)189   static DeviceAttributes BuildDeviceAttributes(
190       const string& name, DeviceType device, Bytes memory_limit,
191       const DeviceLocality& locality) {
192     // Pass in an empty string as physical device name.
193     return BuildDeviceAttributes(name, device, memory_limit, locality, "");
194   }
195 
196   // Clears the resource manager associated with this device.
ClearResourceMgr()197   void ClearResourceMgr() { rmgr_->Clear(); }
198 
199  protected:
DeleteResourceMgr()200   void DeleteResourceMgr() {
201     delete rmgr_;
202     rmgr_ = nullptr;
203   }
204 
205  private:
206   friend class DeviceMgr;
207 
208   // Pointer to the device manager that owns this device. Not owned.
209   DeviceMgr* device_mgr_ = nullptr;
210 
211   const DeviceAttributes device_attributes_;
212   DeviceNameUtils::ParsedName parsed_name_;
213 
214   // op_seg_ maps session handle and op name to OpKernel objects.
215   OpSegment op_seg_;
216 
217   // Resources associated w/ this device. E.g., shared variables, etc.
218   ResourceMgr* rmgr_ = nullptr;
219 
220   TF_DISALLOW_COPY_AND_ASSIGN(Device);
221 };
222 
223 }  // namespace tensorflow
224 
225 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_H_
226