1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/stream_executor/tpu/tpu_executor.h"

#include <memory>
#include <utility>

#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
#include "tensorflow/stream_executor/tpu/tpu_event.h"
#include "tensorflow/stream_executor/tpu/tpu_stream.h"
#include "tensorflow/stream_executor/tpu/tpu_timer.h"
25 
26 using stream_executor::DeviceMemoryBase;
27 
28 namespace tensorflow {
29 namespace tpu {
30 
31 namespace {
32 using ::stream_executor::port::Status;
33 }  // namespace
34 
~TpuExecutor()35 TpuExecutor::~TpuExecutor() {
36   tpu::ExecutorApiFn()->TpuExecutor_FreeFn(executor_);
37 }
38 
Init(int device_ordinal,::stream_executor::DeviceOptions device_options)39 Status TpuExecutor::Init(int device_ordinal,
40                          ::stream_executor::DeviceOptions device_options) {
41   StatusHelper status;
42   SE_DeviceOptions* options =
43       tpu::ExecutorApiFn()->TpuExecutor_NewDeviceOptionsFn(
44           device_options.flags());
45   tpu::ExecutorApiFn()->TpuExecutor_InitFn(executor_, device_ordinal, options,
46                                            status.c_status);
47   tpu::ExecutorApiFn()->TpuExecutor_FreeDeviceOptionsFn(options);
48   return status.status();
49 }
50 
PlatformDeviceCount()51 int TpuExecutor::PlatformDeviceCount() {
52   return tpu::ExecutorApiFn()->TpuExecutor_PlatformDeviceCountFn(executor_);
53 }
54 
SyncAndForgetFailedStreams()55 void TpuExecutor::SyncAndForgetFailedStreams() {
56   tpu::ExecutorApiFn()->TpuExecutor_SyncAndForgetFailedStreamsFn(executor_);
57 }
58 
SynchronizeAllActivity()59 bool TpuExecutor::SynchronizeAllActivity() {
60   return tpu::ExecutorApiFn()->TpuExecutor_SynchronizeAllActivityFn(executor_);
61 }
62 
BlockHostUntilDone(Stream * stream)63 Status TpuExecutor::BlockHostUntilDone(Stream* stream) {
64   StatusHelper status;
65   tpu::ExecutorApiFn()->TpuExecutor_BlockHostUntilDoneFn(
66       executor_, get_stream(stream->implementation()), status.c_status);
67   return status.status();
68 }
69 
BlockUntilDoneOrFailed()70 Status TpuExecutor::BlockUntilDoneOrFailed() {
71   StatusHelper status;
72   tpu::ExecutorApiFn()->TpuExecutor_BlockUntilDoneOrFailedFn(executor_,
73                                                              status.c_status);
74   return status.status();
75 }
76 
GetStatus(Stream * stream)77 Status TpuExecutor::GetStatus(Stream* stream) {
78   StatusHelper status;
79   tpu::ExecutorApiFn()->TpuExecutor_GetStatusFn(
80       executor_, get_stream(stream->implementation()), status.c_status);
81   return status.status();
82 }
83 
GetCoreLocationExternal() const84 tpu::TpuCoreLocationExternal TpuExecutor::GetCoreLocationExternal() const {
85   return tpu::TpuCoreLocationExternal(
86       tpu::ExecutorApiFn()->TpuExecutor_GetCoreLocationFn(executor_));
87 }
88 
AllocateStream(Stream * stream)89 bool TpuExecutor::AllocateStream(Stream* stream) {
90   return tpu::ExecutorApiFn()->TpuExecutor_AllocateStreamFn(
91       executor_, get_stream(stream->implementation()));
92 }
93 
DeallocateStream(Stream * stream)94 void TpuExecutor::DeallocateStream(Stream* stream) {
95   tpu::ExecutorApiFn()->TpuExecutor_DeallocateStreamFn(
96       executor_, get_stream(stream->implementation()));
97   tpu_platform().mutex().lock();
98   stream_map().erase(stream->implementation());
99   tpu_platform().mutex().unlock();
100 }
101 
CreateStreamDependency(Stream * dependent,Stream * other)102 bool TpuExecutor::CreateStreamDependency(Stream* dependent, Stream* other) {
103   return tpu::ExecutorApiFn()->TpuExecutor_CreateStreamDependencyFn(
104       executor_, get_stream(dependent->implementation()),
105       get_stream(other->implementation()));
106 }
107 
AllocateEvent(Event * event)108 Status TpuExecutor::AllocateEvent(Event* event) { return Status::OK(); }
109 
DeallocateEvent(Event * event)110 Status TpuExecutor::DeallocateEvent(Event* event) {
111   tpu_platform().EraseEvent(event->implementation());
112   return Status::OK();
113 }
114 
115 // AllocateTimer/DeallocateTimer have no specialization.
AllocateTimer(Timer * timer)116 bool TpuExecutor::AllocateTimer(Timer* timer) { return true; }
117 
DeallocateTimer(Timer * timer)118 void TpuExecutor::DeallocateTimer(Timer* timer) {}
119 
StartTimer(Stream * stream,::stream_executor::Timer * timer)120 bool TpuExecutor::StartTimer(Stream* stream, ::stream_executor::Timer* timer) {
121   return tpu::ExecutorApiFn()->TpuExecutor_StartTimerFn(
122       executor_, get_stream(stream->implementation()),
123       timer_map_.at(timer->implementation()));
124 }
125 
StopTimer(Stream * stream,::stream_executor::Timer * timer)126 bool TpuExecutor::StopTimer(Stream* stream, ::stream_executor::Timer* timer) {
127   return tpu::ExecutorApiFn()->TpuExecutor_StopTimerFn(
128       executor_, get_stream(stream->implementation()),
129       timer_map_.at(timer->implementation()));
130 }
131 
PollForEventStatus(stream_executor::Event * event)132 stream_executor::Event::Status TpuExecutor::PollForEventStatus(
133     stream_executor::Event* event) {
134   auto se_event = tpu_platform().LookupEvent(event->implementation());
135   return stream_executor::Event::Status(
136       tpu::ExecutorApiFn()->TpuExecutor_PollForEventStatusFn(executor_,
137                                                              se_event));
138 }
139 
RecordEvent(Stream * stream,::stream_executor::Event * event)140 Status TpuExecutor::RecordEvent(Stream* stream,
141                                 ::stream_executor::Event* event) {
142   StatusHelper status;
143   auto se_event = tpu_platform().LookupEvent(event->implementation());
144   tpu::ExecutorApiFn()->TpuExecutor_RecordEventFn(
145       executor_, get_stream(stream->implementation()), se_event,
146       status.c_status);
147   return status.status();
148 }
149 
WaitForEvent(Stream * stream,::stream_executor::Event * event)150 Status TpuExecutor::WaitForEvent(Stream* stream,
151                                  ::stream_executor::Event* event) {
152   StatusHelper status;
153   auto se_event = tpu_platform().LookupEvent(event->implementation());
154   tpu::ExecutorApiFn()->TpuExecutor_WaitForEventFn(
155       executor_, get_stream(stream->implementation()), se_event,
156       status.c_status);
157   return status.status();
158 }
159 
160 // Implementations for Timer, Stream, Event
161 // We need to map these implementations to internal equivalents -- thus we
162 // allocate the internal Timer, Stream and Event operations here, and map
163 // the implementations to the internal values. The "wrapper" interfaces are
164 // responsible for deallocating the internal value when they are destroyed.
165 
166 // Called by Timer::Timer
167 std::unique_ptr<::stream_executor::internal::TimerInterface>
GetTimerImplementation()168 TpuExecutor::GetTimerImplementation() {
169   SE_Timer* tpu_timer = tpu::ExecutorApiFn()->TpuTimer_NewFn(executor_);
170   auto ptr = absl::make_unique<TpuTimer>(tpu_timer);
171   timer_map_[ptr.get()] = tpu_timer;
172   return ptr;
173 }
174 
175 // Called by Stream::Stream
176 std::unique_ptr<::stream_executor::internal::StreamInterface>
GetStreamImplementation()177 TpuExecutor::GetStreamImplementation() {
178   SE_Stream* tpu_stream = tpu::ExecutorApiFn()->TpuStream_NewFn(executor_);
179   auto ptr = absl::make_unique<tpu::TpuStream>(tpu_stream);
180   tpu_platform().mutex().lock();
181   stream_map()[ptr.get()] = tpu_stream;
182   tpu_platform().mutex().unlock();
183   return ptr;
184 }
185 
186 // Called by Event::Event
187 std::unique_ptr<::stream_executor::internal::EventInterface>
CreateEventImplementation()188 TpuExecutor::CreateEventImplementation() {
189   SE_Event* tpu_event = tpu::ExecutorApiFn()->TpuEvent_NewFn(executor_);
190   auto ptr = absl::make_unique<TpuEvent>(tpu_event);
191   tpu_platform().InsertEvent(ptr.get(), tpu_event);
192   return ptr;
193 }
194 
Allocate(uint64 size,int64 memory_space)195 DeviceMemoryBase TpuExecutor::Allocate(uint64 size, int64 memory_space) {
196   SE_DeviceMemoryBase se_base = tpu::ExecutorApiFn()->TpuExecutor_AllocateFn(
197       executor_, size, memory_space);
198   return ApiConverter::FromC(se_base);
199 }
200 
Deallocate(const DeviceMemoryBase & memory)201 void TpuExecutor::Deallocate(const DeviceMemoryBase& memory) {
202   SE_DeviceMemoryBase se_base = ApiConverter::ToC(memory);
203   tpu::ExecutorApiFn()->TpuExecutor_DeallocateFn(executor_, &se_base);
204 }
205 
Deallocate(DeviceMemoryBase * memory)206 void TpuExecutor::Deallocate(DeviceMemoryBase* memory) {
207   SE_DeviceMemoryBase se_base = ApiConverter::ToC(*memory);
208   tpu::ExecutorApiFn()->TpuExecutor_DeallocateFn(executor_, &se_base);
209 }
210 
DeviceMemoryUsage(int64 * free,int64 * total) const211 bool TpuExecutor::DeviceMemoryUsage(int64* free, int64* total) const {
212   int64_t _free;
213   int64_t _total;
214   if (tpu::ExecutorApiFn()->TpuExecutor_DeviceMemoryUsageFn(executor_, &_free,
215                                                             &_total)) {
216     *free = _free;
217     *total = _total;
218     return true;
219   }
220   return false;
221 }
222 
223 absl::optional<stream_executor::AllocatorStats>
GetAllocatorStats()224 TpuExecutor::GetAllocatorStats() {
225   SE_AllocatorStats c_stats;
226   if (tpu::ExecutorApiFn()->TpuExecutor_GetAllocatorStatsFn(executor_,
227                                                             &c_stats)) {
228     ::stream_executor::AllocatorStats stats;
229     stats.num_allocs = c_stats.num_allocs;
230     stats.bytes_in_use = c_stats.bytes_in_use;
231     stats.peak_bytes_in_use = c_stats.peak_bytes_in_use;
232     stats.largest_alloc_size = c_stats.largest_alloc_size;
233     if (c_stats.has_bytes_limit) {
234       stats.bytes_limit = c_stats.bytes_limit;
235     }
236     stats.bytes_reserved = c_stats.bytes_reserved;
237     stats.peak_bytes_reserved = c_stats.peak_bytes_reserved;
238     if (c_stats.has_bytes_reservable_limit) {
239       stats.bytes_reservable_limit = c_stats.bytes_reservable_limit;
240     }
241     stats.largest_free_block_bytes = c_stats.largest_free_block_bytes;
242     return stats;
243   }
244   return {};
245 }
246 
WaitForInfeedReady(int32 infeed_queue_index)247 Status TpuExecutor::WaitForInfeedReady(int32 infeed_queue_index) {
248   StatusHelper status;
249   tpu::ExecutorApiFn()->TpuExecutor_WaitForInfeedReadyFn(
250       executor_, infeed_queue_index, status.c_status);
251   return status.status();
252 }
253 
WaitForOutfeedReady(int32 outfeed_queue_index)254 Status TpuExecutor::WaitForOutfeedReady(int32 outfeed_queue_index) {
255   StatusHelper status;
256   tpu::ExecutorApiFn()->TpuExecutor_WaitForOutfeedReadyFn(
257       executor_, outfeed_queue_index, status.c_status);
258   return status.status();
259 }
260 
DequeueOutfeed(int32 outfeed_queue_index,absl::Span<uint8> bytes,StatusCallback done)261 void TpuExecutor::DequeueOutfeed(int32 outfeed_queue_index,
262                                  absl::Span<uint8> bytes, StatusCallback done) {
263   StatusHelper status;
264   tpu::ExecutorApiFn()->TpuExecutor_DequeueOutfeedFn(
265       executor_, outfeed_queue_index, bytes.data(), bytes.size(),
266       status.c_status);
267   done(status.status());
268 }
269 
EnqueueInfeed(int32 infeed_queue_index,absl::Span<const uint8> bytes)270 Status TpuExecutor::EnqueueInfeed(int32 infeed_queue_index,
271                                   absl::Span<const uint8> bytes) {
272   StatusHelper status;
273   tpu::ExecutorApiFn()->TpuExecutor_EnqueueInfeedFn(
274       executor_, infeed_queue_index, bytes.data(), bytes.size(),
275       status.c_status);
276   return status.status();
277 }
278 
Memcpy(Stream * stream,void * host_dst,const::stream_executor::DeviceMemoryBase & device_src,uint64 size)279 bool TpuExecutor::Memcpy(Stream* stream, void* host_dst,
280                          const ::stream_executor::DeviceMemoryBase& device_src,
281                          uint64 size) {
282   SE_DeviceMemoryBase se_base = ApiConverter::ToC(device_src);
283   return tpu::ExecutorApiFn()->TpuExecutor_MemcpyToHostFn(
284       executor_, get_stream(stream->implementation()), host_dst, &se_base,
285       size);
286 }
287 
Memcpy(Stream * stream,::stream_executor::DeviceMemoryBase * device_dst,const void * host_src,uint64 size)288 bool TpuExecutor::Memcpy(Stream* stream,
289                          ::stream_executor::DeviceMemoryBase* device_dst,
290                          const void* host_src, uint64 size) {
291   SE_DeviceMemoryBase se_base = ApiConverter::ToC(*device_dst);
292   return tpu::ExecutorApiFn()->TpuExecutor_MemcpyFromHostFn(
293       executor_, get_stream(stream->implementation()), &se_base, host_src,
294       size);
295 }
296 
SynchronousMemcpy(::stream_executor::DeviceMemoryBase * device_dst,const void * host_src,uint64 size)297 Status TpuExecutor::SynchronousMemcpy(
298     ::stream_executor::DeviceMemoryBase* device_dst, const void* host_src,
299     uint64 size) {
300   StatusHelper status;
301   SE_DeviceMemoryBase se_base = ApiConverter::ToC(*device_dst);
302   tpu::ExecutorApiFn()->TpuExecutor_SynchronousMemcpyFromHostFn(
303       executor_, &se_base, host_src, size, status.c_status);
304   return status.status();
305 }
306 
SynchronousMemcpy(void * host_dst,const::stream_executor::DeviceMemoryBase & device_src,uint64 size)307 Status TpuExecutor::SynchronousMemcpy(
308     void* host_dst, const ::stream_executor::DeviceMemoryBase& device_src,
309     uint64 size) {
310   StatusHelper status;
311   SE_DeviceMemoryBase se_base = ApiConverter::ToC(device_src);
312   tpu::ExecutorApiFn()->TpuExecutor_SynchronousMemcpyToHostFn(
313       executor_, host_dst, &se_base, size, status.c_status);
314   return status.status();
315 }
316 
SynchronousMemcpyDeviceToDevice(::stream_executor::DeviceMemoryBase * device_dst,const::stream_executor::DeviceMemoryBase & device_src,uint64 size)317 Status TpuExecutor::SynchronousMemcpyDeviceToDevice(
318     ::stream_executor::DeviceMemoryBase* device_dst,
319     const ::stream_executor::DeviceMemoryBase& device_src, uint64 size) {
320   return ::stream_executor::port::UnimplementedError(
321       "This operation not supported on TPU");
322 }
323 
MemcpyDeviceToDevice(Stream * stream,::stream_executor::DeviceMemoryBase * gpu_dst,const::stream_executor::DeviceMemoryBase & host_src,uint64 size)324 bool TpuExecutor::MemcpyDeviceToDevice(
325     Stream* stream, ::stream_executor::DeviceMemoryBase* gpu_dst,
326     const ::stream_executor::DeviceMemoryBase& host_src, uint64 size) {
327   LOG(FATAL) << __func__ << " not supported on TpuExecutor";
328 }
329 
330 struct HostCallbackContext {
331   std::function<Status()> callback;
332 };
333 
HostCallbackTrampoline(void * ctx)334 TF_Status* HostCallbackTrampoline(void* ctx) {
335   HostCallbackContext* host_ctx = reinterpret_cast<HostCallbackContext*>(ctx);
336   Status status = host_ctx->callback();
337   TF_Status* c_status = tpu::ExecutorApiFn()->TpuStatus_CreateFn(
338       status.code(), status.error_message().c_str());
339   delete host_ctx;
340   return c_status;
341 }
342 
HostCallback(Stream * stream,std::function<Status ()> callback)343 bool TpuExecutor::HostCallback(Stream* stream,
344                                std::function<Status()> callback) {
345   HostCallbackContext* ctx = new HostCallbackContext{callback};
346   return tpu::ExecutorApiFn()->TpuExecutor_HostCallbackFn(
347       executor_, get_stream(stream->implementation()), &HostCallbackTrampoline,
348       ctx);
349 }
350 
351 TpuExecutor::StatusOr<std::unique_ptr<::stream_executor::DeviceDescription>>
CreateDeviceDescription() const352 TpuExecutor::CreateDeviceDescription() const {
353   StatusHelper status;
354   SE_DeviceDescription* description =
355       tpu::ExecutorApiFn()->TpuDeviceDescription_NewFn();
356   auto cleanup = tensorflow::gtl::MakeCleanup([description]() {
357     tpu::ExecutorApiFn()->TpuDeviceDescription_FreeFn(description);
358   });
359   tpu::ExecutorApiFn()->TpuExecutor_CreateDeviceDescriptionFn(
360       executor_, description, status.c_status);
361   if (status.status().ok()) {
362     stream_executor::internal::DeviceDescriptionBuilder builder;
363     CHECK_NE(description->device_vendor, nullptr);
364     builder.set_device_vendor(description->device_vendor);
365     builder.set_name(description->name);
366     builder.set_clock_rate_ghz(description->clock_rate_ghz);
367     builder.set_core_count(description->core_count);
368     builder.set_ecc_enabled(description->ecc_enabled);
369     builder.set_device_memory_size(description->device_memory_size);
370     builder.set_platform_version(description->platform_version);
371     return builder.Build();
372   }
373   return status.status();
374 }
375 
376 }  // namespace tpu
377 }  // namespace tensorflow
378