1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Declares the XlaInterpreterExecutor class, which is a CPU-only implementation
17 // of the StreamExecutor interface. For now, this is used for testing and to
18 // examine the performance of host-based StreamExecutor code.
19 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTOR_H_
20 #define TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTOR_H_
21 
22 #include <functional>
23 #include <memory>
24 
25 #include "absl/types/span.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28 #include "tensorflow/core/platform/types.h"
29 #include "tensorflow/stream_executor/blas.h"
30 #include "tensorflow/stream_executor/device_description.h"
31 #include "tensorflow/stream_executor/device_memory.h"
32 #include "tensorflow/stream_executor/device_options.h"
33 #include "tensorflow/stream_executor/event.h"
34 #include "tensorflow/stream_executor/host/host_stream.h"
35 #include "tensorflow/stream_executor/host/host_timer.h"
36 #include "tensorflow/stream_executor/kernel.h"
37 #include "tensorflow/stream_executor/kernel_spec.h"
38 #include "tensorflow/stream_executor/launch_dim.h"
39 #include "tensorflow/stream_executor/plugin.h"
40 #include "tensorflow/stream_executor/rng.h"
41 #include "tensorflow/stream_executor/shared_memory_config.h"
42 #include "tensorflow/stream_executor/stream.h"
43 #include "tensorflow/stream_executor/stream_executor.h"
44 #include "tensorflow/stream_executor/stream_executor_internal.h"
45 #include "tensorflow/stream_executor/timer.h"
46 
47 namespace stream_executor {
48 namespace interpreter {
49 
// Span over const DeviceMemoryBase elements. Not referenced anywhere else in
// the visible part of this header.
using Args = absl::Span<const DeviceMemoryBase>;
51 
52 class XlaInterpreterExecutor : public internal::StreamExecutorInterface {
53  public:
54   explicit XlaInterpreterExecutor(const PluginConfig &plugin_config);
55   ~XlaInterpreterExecutor() override;
56 
Init(int device_ordinal,DeviceOptions device_options)57   port::Status Init(int device_ordinal, DeviceOptions device_options) override {
58     return port::Status::OK();
59   }
60 
GetKernel(const MultiKernelLoaderSpec & spec,KernelBase * kernel)61   bool GetKernel(const MultiKernelLoaderSpec &spec,
62                  KernelBase *kernel) override {
63     return false;
64   }
Launch(Stream * stream,const ThreadDim & thread_dims,const BlockDim & block_dims,const KernelBase & kernel,const KernelArgsArrayBase & args)65   bool Launch(Stream *stream, const ThreadDim &thread_dims,
66               const BlockDim &block_dims, const KernelBase &kernel,
67               const KernelArgsArrayBase &args) override {
68     return false;
69   }
70 
71   void *Allocate(uint64 size) override;
72   void *AllocateSubBuffer(DeviceMemoryBase *mem, uint64 offset_bytes,
73                           uint64 size_bytes) override;
74   void Deallocate(DeviceMemoryBase *mem) override;
75 
HostMemoryAllocate(uint64 size)76   void *HostMemoryAllocate(uint64 size) override { return new char[size]; }
HostMemoryDeallocate(void * mem)77   void HostMemoryDeallocate(void *mem) override {
78     delete[] static_cast<char *>(mem);
79   }
HostMemoryRegister(void * mem,uint64 size)80   bool HostMemoryRegister(void *mem, uint64 size) override { return true; }
HostMemoryUnregister(void * mem)81   bool HostMemoryUnregister(void *mem) override { return true; }
82 
83   bool Memcpy(Stream *stream, void *host_dst, const DeviceMemoryBase &pop_src,
84               uint64 size) override;
85   bool Memcpy(Stream *stream, DeviceMemoryBase *pop_dst, const void *host_src,
86               uint64 size) override;
MemcpyDeviceToDevice(Stream * stream,DeviceMemoryBase * pop_dst,const DeviceMemoryBase & host_src,uint64 size)87   bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *pop_dst,
88                             const DeviceMemoryBase &host_src,
89                             uint64 size) override {
90     return false;
91   }
92 
MemZero(Stream * stream,DeviceMemoryBase * location,uint64 size)93   bool MemZero(Stream *stream, DeviceMemoryBase *location,
94                uint64 size) override {
95     return false;
96   }
Memset(Stream * stream,DeviceMemoryBase * location,uint8 pattern,uint64 size)97   bool Memset(Stream *stream, DeviceMemoryBase *location, uint8 pattern,
98               uint64 size) override {
99     return false;
100   }
Memset32(Stream * stream,DeviceMemoryBase * location,uint32 pattern,uint64 size)101   bool Memset32(Stream *stream, DeviceMemoryBase *location, uint32 pattern,
102                 uint64 size) override {
103     return false;
104   }
105 
106   // No "synchronize all activity" implemented for this platform at the moment.
SynchronizeAllActivity()107   bool SynchronizeAllActivity() override { return true; }
SynchronousMemZero(DeviceMemoryBase * location,uint64 size)108   bool SynchronousMemZero(DeviceMemoryBase *location, uint64 size) override {
109     return false;
110   }
111 
SynchronousMemSet(DeviceMemoryBase * location,int value,uint64 size)112   bool SynchronousMemSet(DeviceMemoryBase *location, int value,
113                          uint64 size) override {
114     return false;
115   }
116 
117   port::Status SynchronousMemcpy(DeviceMemoryBase *pop_dst,
118                                  const void *host_src, uint64 size) override;
119   port::Status SynchronousMemcpy(void *host_dst,
120                                  const DeviceMemoryBase &pop_src,
121                                  uint64 size) override;
SynchronousMemcpyDeviceToDevice(DeviceMemoryBase * pop_dst,const DeviceMemoryBase & pop_src,uint64 size)122   port::Status SynchronousMemcpyDeviceToDevice(DeviceMemoryBase *pop_dst,
123                                                const DeviceMemoryBase &pop_src,
124                                                uint64 size) override {
125     return port::Status{port::error::UNIMPLEMENTED, ""};
126   }
127 
128   bool HostCallback(Stream *stream,
129                     std::function<port::Status()> callback) override;
130 
AllocateEvent(Event * event)131   port::Status AllocateEvent(Event *event) override {
132     return port::Status{port::error::UNIMPLEMENTED, ""};
133   }
134 
DeallocateEvent(Event * event)135   port::Status DeallocateEvent(Event *event) override {
136     return port::Status{port::error::UNIMPLEMENTED, ""};
137   }
138 
RecordEvent(Stream * stream,Event * event)139   port::Status RecordEvent(Stream *stream, Event *event) override {
140     return port::Status{port::error::UNIMPLEMENTED, ""};
141   }
142 
WaitForEvent(Stream * stream,Event * event)143   port::Status WaitForEvent(Stream *stream, Event *event) override {
144     return port::Status{port::error::UNIMPLEMENTED, ""};
145   }
146 
PollForEventStatus(Event * event)147   Event::Status PollForEventStatus(Event *event) override {
148     return Event::Status::kError;
149   }
150 
AllocateStream(Stream * stream)151   bool AllocateStream(Stream *stream) override { return true; }
DeallocateStream(Stream * stream)152   void DeallocateStream(Stream *stream) override {}
153   bool CreateStreamDependency(Stream *dependent, Stream *other) override;
154 
AllocateTimer(Timer * timer)155   bool AllocateTimer(Timer *timer) override { return true; }
DeallocateTimer(Timer * timer)156   void DeallocateTimer(Timer *timer) override {}
157   bool StartTimer(Stream *stream, Timer *timer) override;
158   bool StopTimer(Stream *stream, Timer *timer) override;
159 
160   port::Status BlockHostUntilDone(Stream *stream) override;
161 
PlatformDeviceCount()162   int PlatformDeviceCount() override { return 1; }
163 
DeviceMemoryUsage(int64 * free,int64 * total)164   bool DeviceMemoryUsage(int64 *free, int64 *total) const override {
165     return false;
166   }
167 
168   DeviceDescription *PopulateDeviceDescription() const override;
169 
EnablePeerAccessTo(StreamExecutorInterface * other)170   port::Status EnablePeerAccessTo(StreamExecutorInterface *other) override {
171     return port::Status::OK();
172   }
173 
CanEnablePeerAccessTo(StreamExecutorInterface * other)174   bool CanEnablePeerAccessTo(StreamExecutorInterface *other) override {
175     return true;
176   }
177 
GetDeviceSharedMemoryConfig()178   SharedMemoryConfig GetDeviceSharedMemoryConfig() override {
179     return SharedMemoryConfig::kDefault;
180   }
181 
SetDeviceSharedMemoryConfig(SharedMemoryConfig config)182   port::Status SetDeviceSharedMemoryConfig(SharedMemoryConfig config) override {
183     return port::Status{port::error::UNIMPLEMENTED,
184                         "Shared memory not supported"};
185   }
186 
CreateEventImplementation()187   std::unique_ptr<internal::EventInterface> CreateEventImplementation()
188       override {
189     return nullptr;
190   }
191 
CreateKernelImplementation()192   std::unique_ptr<internal::KernelInterface> CreateKernelImplementation()
193       override {
194     return nullptr;
195   }
196 
GetStreamImplementation()197   std::unique_ptr<internal::StreamInterface> GetStreamImplementation()
198       override {
199     return std::unique_ptr<internal::StreamInterface>(new host::HostStream());
200   }
201 
GetTimerImplementation()202   std::unique_ptr<internal::TimerInterface> GetTimerImplementation() override {
203     return std::unique_ptr<internal::TimerInterface>(new host::HostTimer());
204   }
205 
206  private:
207   DeviceMemoryBase AllocateSingleOutput(const xla::Shape &shape);
208 
209   port::StatusOr<DeviceMemoryBase> AllocateOutputBuffer(
210       const xla::Shape &shape);
211 
212   const PluginConfig plugin_config_;
213 };
214 
215 }  // namespace interpreter
216 }  // namespace stream_executor
217 
218 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTOR_H_
219