1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Declares the XlaInterpreterExecutor class, which is a CPU-only implementation 17 // of the StreamExecutor interface. For now, this is used for testing and to 18 // examine the performance of host-based StreamExecutor code. 19 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTOR_H_ 20 #define TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTOR_H_ 21 22 #include <functional> 23 #include <memory> 24 25 #include "absl/types/span.h" 26 #include "tensorflow/compiler/xla/shape_util.h" 27 #include "tensorflow/compiler/xla/xla_data.pb.h" 28 #include "tensorflow/core/platform/types.h" 29 #include "tensorflow/stream_executor/blas.h" 30 #include "tensorflow/stream_executor/device_description.h" 31 #include "tensorflow/stream_executor/device_memory.h" 32 #include "tensorflow/stream_executor/device_options.h" 33 #include "tensorflow/stream_executor/event.h" 34 #include "tensorflow/stream_executor/host/host_stream.h" 35 #include "tensorflow/stream_executor/host/host_timer.h" 36 #include "tensorflow/stream_executor/kernel.h" 37 #include "tensorflow/stream_executor/kernel_spec.h" 38 #include "tensorflow/stream_executor/launch_dim.h" 39 #include "tensorflow/stream_executor/plugin.h" 40 #include "tensorflow/stream_executor/rng.h" 41 #include "tensorflow/stream_executor/shared_memory_config.h" 42 #include "tensorflow/stream_executor/stream.h" 43 #include "tensorflow/stream_executor/stream_executor.h" 44 #include "tensorflow/stream_executor/stream_executor_internal.h" 45 #include "tensorflow/stream_executor/timer.h" 46 47 namespace stream_executor { 48 namespace interpreter { 49 50 using Args = absl::Span<const DeviceMemoryBase>; 51 52 class XlaInterpreterExecutor : public internal::StreamExecutorInterface { 53 public: 54 explicit XlaInterpreterExecutor(const PluginConfig &plugin_config); 55 ~XlaInterpreterExecutor() override; 56 Init(int device_ordinal,DeviceOptions device_options)57 port::Status Init(int device_ordinal, DeviceOptions device_options) override { 58 return port::Status::OK(); 59 } 60 GetKernel(const MultiKernelLoaderSpec & spec,KernelBase * kernel)61 bool GetKernel(const MultiKernelLoaderSpec &spec, 62 KernelBase *kernel) override { 63 return false; 64 } Launch(Stream * stream,const ThreadDim & thread_dims,const BlockDim & block_dims,const KernelBase & kernel,const KernelArgsArrayBase & args)65 bool Launch(Stream *stream, const ThreadDim &thread_dims, 66 const BlockDim &block_dims, const KernelBase &kernel, 67 const KernelArgsArrayBase &args) override { 68 return false; 69 } 70 71 void *Allocate(uint64 size) override; 72 void *AllocateSubBuffer(DeviceMemoryBase *mem, uint64 offset_bytes, 73 uint64 size_bytes) override; 74 void Deallocate(DeviceMemoryBase *mem) override; 75 HostMemoryAllocate(uint64 size)76 void *HostMemoryAllocate(uint64 size) override { return new char[size]; } HostMemoryDeallocate(void * mem)77 void HostMemoryDeallocate(void *mem) override { 78 delete[] static_cast<char *>(mem); 79 } HostMemoryRegister(void * mem,uint64 size)80 bool HostMemoryRegister(void *mem, uint64 size) override { return true; } HostMemoryUnregister(void * mem)81 bool HostMemoryUnregister(void *mem) override { return true; } 82 83 bool Memcpy(Stream *stream, void *host_dst, const DeviceMemoryBase &pop_src, 84 uint64 size) override; 85 bool Memcpy(Stream *stream, DeviceMemoryBase *pop_dst, const void *host_src, 86 uint64 size) override; MemcpyDeviceToDevice(Stream * stream,DeviceMemoryBase * pop_dst,const DeviceMemoryBase & host_src,uint64 size)87 bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *pop_dst, 88 const DeviceMemoryBase &host_src, 89 uint64 size) override { 90 return false; 91 } 92 MemZero(Stream * stream,DeviceMemoryBase * location,uint64 size)93 bool MemZero(Stream *stream, DeviceMemoryBase *location, 94 uint64 size) override { 95 return false; 96 } Memset(Stream * stream,DeviceMemoryBase * location,uint8 pattern,uint64 size)97 bool Memset(Stream *stream, DeviceMemoryBase *location, uint8 pattern, 98 uint64 size) override { 99 return false; 100 } Memset32(Stream * stream,DeviceMemoryBase * location,uint32 pattern,uint64 size)101 bool Memset32(Stream *stream, DeviceMemoryBase *location, uint32 pattern, 102 uint64 size) override { 103 return false; 104 } 105 106 // No "synchronize all activity" implemented for this platform at the moment. SynchronizeAllActivity()107 bool SynchronizeAllActivity() override { return true; } SynchronousMemZero(DeviceMemoryBase * location,uint64 size)108 bool SynchronousMemZero(DeviceMemoryBase *location, uint64 size) override { 109 return false; 110 } 111 SynchronousMemSet(DeviceMemoryBase * location,int value,uint64 size)112 bool SynchronousMemSet(DeviceMemoryBase *location, int value, 113 uint64 size) override { 114 return false; 115 } 116 117 port::Status SynchronousMemcpy(DeviceMemoryBase *pop_dst, 118 const void *host_src, uint64 size) override; 119 port::Status SynchronousMemcpy(void *host_dst, 120 const DeviceMemoryBase &pop_src, 121 uint64 size) override; SynchronousMemcpyDeviceToDevice(DeviceMemoryBase * pop_dst,const DeviceMemoryBase & pop_src,uint64 size)122 port::Status SynchronousMemcpyDeviceToDevice(DeviceMemoryBase *pop_dst, 123 const DeviceMemoryBase &pop_src, 124 uint64 size) override { 125 return port::Status{port::error::UNIMPLEMENTED, ""}; 126 } 127 128 bool HostCallback(Stream *stream, 129 std::function<port::Status()> callback) override; 130 AllocateEvent(Event * event)131 port::Status AllocateEvent(Event *event) override { 132 return port::Status{port::error::UNIMPLEMENTED, ""}; 133 } 134 DeallocateEvent(Event * event)135 port::Status DeallocateEvent(Event *event) override { 136 return port::Status{port::error::UNIMPLEMENTED, ""}; 137 } 138 RecordEvent(Stream * stream,Event * event)139 port::Status RecordEvent(Stream *stream, Event *event) override { 140 return port::Status{port::error::UNIMPLEMENTED, ""}; 141 } 142 WaitForEvent(Stream * stream,Event * event)143 port::Status WaitForEvent(Stream *stream, Event *event) override { 144 return port::Status{port::error::UNIMPLEMENTED, ""}; 145 } 146 PollForEventStatus(Event * event)147 Event::Status PollForEventStatus(Event *event) override { 148 return Event::Status::kError; 149 } 150 AllocateStream(Stream * stream)151 bool AllocateStream(Stream *stream) override { return true; } DeallocateStream(Stream * stream)152 void DeallocateStream(Stream *stream) override {} 153 bool CreateStreamDependency(Stream *dependent, Stream *other) override; 154 AllocateTimer(Timer * timer)155 bool AllocateTimer(Timer *timer) override { return true; } DeallocateTimer(Timer * timer)156 void DeallocateTimer(Timer *timer) override {} 157 bool StartTimer(Stream *stream, Timer *timer) override; 158 bool StopTimer(Stream *stream, Timer *timer) override; 159 160 port::Status BlockHostUntilDone(Stream *stream) override; 161 PlatformDeviceCount()162 int PlatformDeviceCount() override { return 1; } 163 DeviceMemoryUsage(int64 * free,int64 * total)164 bool DeviceMemoryUsage(int64 *free, int64 *total) const override { 165 return false; 166 } 167 168 DeviceDescription *PopulateDeviceDescription() const override; 169 EnablePeerAccessTo(StreamExecutorInterface * other)170 port::Status EnablePeerAccessTo(StreamExecutorInterface *other) override { 171 return port::Status::OK(); 172 } 173 CanEnablePeerAccessTo(StreamExecutorInterface * other)174 bool CanEnablePeerAccessTo(StreamExecutorInterface *other) override { 175 return true; 176 } 177 GetDeviceSharedMemoryConfig()178 SharedMemoryConfig GetDeviceSharedMemoryConfig() override { 179 return SharedMemoryConfig::kDefault; 180 } 181 SetDeviceSharedMemoryConfig(SharedMemoryConfig config)182 port::Status SetDeviceSharedMemoryConfig(SharedMemoryConfig config) override { 183 return port::Status{port::error::UNIMPLEMENTED, 184 "Shared memory not supported"}; 185 } 186 CreateEventImplementation()187 std::unique_ptr<internal::EventInterface> CreateEventImplementation() 188 override { 189 return nullptr; 190 } 191 CreateKernelImplementation()192 std::unique_ptr<internal::KernelInterface> CreateKernelImplementation() 193 override { 194 return nullptr; 195 } 196 GetStreamImplementation()197 std::unique_ptr<internal::StreamInterface> GetStreamImplementation() 198 override { 199 return std::unique_ptr<internal::StreamInterface>(new host::HostStream()); 200 } 201 GetTimerImplementation()202 std::unique_ptr<internal::TimerInterface> GetTimerImplementation() override { 203 return std::unique_ptr<internal::TimerInterface>(new host::HostTimer()); 204 } 205 206 private: 207 DeviceMemoryBase AllocateSingleOutput(const xla::Shape &shape); 208 209 port::StatusOr<DeviceMemoryBase> AllocateOutputBuffer( 210 const xla::Shape &shape); 211 212 const PluginConfig plugin_config_; 213 }; 214 215 } // namespace interpreter 216 } // namespace stream_executor 217 218 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTOR_H_ 219