1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_
17 #define TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_
18 
19 #include "tensorflow/core/common_runtime/device.h"
20 #include "tensorflow/core/util/device_name_utils.h"
21 
22 namespace tensorflow {
23 
24 // Wraps a device with a new name, delegating work to the wrapped device.
25 //
26 // This class is used to wrap local devices when using clusterspec propagation
27 // where the name of a particular device may change in the context of a given
28 // session.
29 class RenamedDevice : public Device {
30  public:
31   static std::unique_ptr<Device> NewRenamedDevice(const string& new_base,
32                                                   Device* underlying,
33                                                   bool owns_underlying,
34                                                   bool isolate_session_state);
35 
36   ~RenamedDevice() override;
37 
38   // Below are virtual methods defined on DeviceBase
RequiresRecordingAccessedTensors()39   bool RequiresRecordingAccessedTensors() const override {
40     return underlying_->RequiresRecordingAccessedTensors();
41   }
42 
UnderlyingDevice()43   const DeviceBase* UnderlyingDevice() const override {
44     return underlying_->UnderlyingDevice();
45   }
UnderlyingDevice()46   DeviceBase* UnderlyingDevice() override {
47     return underlying_->UnderlyingDevice();
48   }
49 
tensorflow_cpu_worker_threads()50   const CpuWorkerThreads* tensorflow_cpu_worker_threads() const override {
51     return underlying_->tensorflow_cpu_worker_threads();
52   }
53 
tensorflow_gpu_device_info()54   const GpuDeviceInfo* tensorflow_gpu_device_info() const override {
55     return underlying_->tensorflow_gpu_device_info();
56   }
57 
GetAllocator(AllocatorAttributes attr)58   Allocator* GetAllocator(AllocatorAttributes attr) override {
59     return underlying_->GetAllocator(attr);
60   }
61 
GetScopedAllocator(AllocatorAttributes attr,int64 step_id)62   Allocator* GetScopedAllocator(AllocatorAttributes attr,
63                                 int64 step_id) override {
64     return underlying_->GetScopedAllocator(attr, step_id);
65   }
66 
GetScopedAllocatorMgr()67   ScopedAllocatorMgr* GetScopedAllocatorMgr() const override {
68     return underlying_->GetScopedAllocatorMgr();
69   }
70 
eigen_cpu_device()71   const Eigen::ThreadPoolDevice* eigen_cpu_device() override {
72     return underlying_->eigen_cpu_device();
73   }
74 
75 #ifdef TENSORFLOW_USE_SYCL
eigen_sycl_device()76   const Eigen::SyclDevice* eigen_sycl_device() const override {
77     return underlying_->eigen_sycl_device();
78   }
79 #endif
80 
MakeGpuDevice()81   PerOpGpuDevice* MakeGpuDevice() override {
82     return underlying_->MakeGpuDevice();
83   }
84 
ReinitializeGpuDevice(OpKernelContext * context,PerOpGpuDevice * device,DeviceContext * dc,Allocator * allocator)85   Status ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device,
86                                DeviceContext* dc,
87                                Allocator* allocator) override {
88     return underlying_->ReinitializeGpuDevice(context, device, dc, allocator);
89   }
90 
MakeTensorFromProto(const TensorProto & tensor_proto,const AllocatorAttributes alloc_attrs,Tensor * tensor)91   Status MakeTensorFromProto(const TensorProto& tensor_proto,
92                              const AllocatorAttributes alloc_attrs,
93                              Tensor* tensor) override {
94     return underlying_->MakeTensorFromProto(tensor_proto, alloc_attrs, tensor);
95   }
96 
97   // Below are virtual methods defined on Device
98 
Compute(OpKernel * op_kernel,OpKernelContext * context)99   void Compute(OpKernel* op_kernel, OpKernelContext* context) override {
100     underlying_->Compute(op_kernel, context);
101   }
102 
ComputeAsync(AsyncOpKernel * op_kernel,OpKernelContext * context,AsyncOpKernel::DoneCallback done)103   void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
104                     AsyncOpKernel::DoneCallback done) override {
105     underlying_->ComputeAsync(op_kernel, context, std::move(done));
106   }
107 
ConsumeListOfAccessedTensors(DeviceContext * context,const TensorReferenceVector & tensors)108   void ConsumeListOfAccessedTensors(
109       DeviceContext* context, const TensorReferenceVector& tensors) override {
110     underlying_->ConsumeListOfAccessedTensors(context, tensors);
111   }
112 
Sync()113   Status Sync() override { return underlying_->Sync(); }
114 
MaybeRewriteGraph(std::unique_ptr<Graph> * graph)115   Status MaybeRewriteGraph(std::unique_ptr<Graph>* graph) override {
116     return underlying_->MaybeRewriteGraph(graph);
117   }
118 
FillContextMap(const Graph * graph,DeviceContextMap * device_context_map)119   Status FillContextMap(const Graph* graph,
120                         DeviceContextMap* device_context_map) override {
121     return underlying_->FillContextMap(graph, device_context_map);
122   }
123 
124   // Returns the resource manager associated w/ this device.
resource_manager()125   ResourceMgr* resource_manager() override {
126     if (isolate_session_state_) {
127       return Device::resource_manager();
128     } else {
129       return underlying_->resource_manager();
130     }
131   }
132 
133  private:
134   RenamedDevice(Device* underlying, const DeviceAttributes& attributes,
135                 bool owns_underlying, bool isolate_session_state);
136   Device* const underlying_;
137   const bool owns_underlying_;
138   const bool isolate_session_state_;
139 };
140 
141 }  // namespace tensorflow
142 
143 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_
144