1 // Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // ==============================================================================
15 
16 #include <complex>
17 #include <cstddef>
18 #include <functional>
19 #include <memory>
20 
21 #include "grpcpp/grpcpp.h"
22 #include "absl/base/thread_annotations.h"
23 #include "absl/strings/strip.h"
24 #include "absl/synchronization/mutex.h"
25 #include "absl/time/clock.h"
26 #include "absl/time/time.h"
27 #include "absl/types/span.h"
28 #include "tensorflow/compiler/xla/python/tpu_driver/event_id.h"
29 #include "tensorflow/compiler/xla/python/tpu_driver/platform/external/compat.h"
30 #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
31 #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h"
32 #include "tensorflow/compiler/xla/python/tpu_driver/tpu_service.grpc.pb.h"
33 #include "tensorflow/compiler/xla/util.h"
34 
35 namespace tpu_driver {
36 namespace {
37 
using xla::Status;

// Size limit, in serialized bytes, that StreamWriterFn compares against when
// deciding whether to start a new StreamRequest batch.
const int64_t kMaxStreamWriteSize = 10 * 1000 * 1000;
// Timeout passed to the writer thread's LockWhenWithTimeout: bounds how long
// the writer waits for its batching condition before flushing what is queued.
const absl::Duration kWriteEpochDuration = absl::Microseconds(10);

// URL scheme prefix for gRPC addresses. Not referenced in this part of the
// file; presumably used by the driver factory elsewhere — TODO confirm.
constexpr char kGrpcProtocol[] = "grpc://";

// Forward declarations: events/handles below keep raw pointers to their
// stream, and streams keep a pointer to their driver.
class GrpcTpuStream;
class GrpcTpuDriver;
47 
48 class GrpcEvent : public Event {
49  public:
GrpcEvent(EventId id,GrpcTpuStream * stream)50   explicit GrpcEvent(EventId id, GrpcTpuStream* stream)
51       : id_(id), stream_(stream) {}
52   ~GrpcEvent() override;
53 
54   xla::Status Await() override;
55   absl::optional<xla::Status> AwaitWithTimeout(
56       absl::Duration duration) override;
57   void AddCallback(std::function<void(Status)> callback) override;
58 
id() const59   EventId id() const { return id_; }
stream() const60   GrpcTpuStream* stream() const { return stream_; }
61 
62  private:
63   const EventId id_;
64   GrpcTpuStream* stream_;
65 };
66 
67 class ErrorEvent : public GrpcEvent {
68  public:
ErrorEvent(Status status)69   explicit ErrorEvent(Status status) : GrpcEvent(EventId{0, 0}, nullptr) {
70     status_ = status;
71   }
72 
Await()73   xla::Status Await() override { return status_; }
AwaitWithTimeout(absl::Duration duration)74   absl::optional<xla::Status> AwaitWithTimeout(
75       absl::Duration duration) override {
76     return status_;
77   }
AddCallback(std::function<void (Status)> callback)78   void AddCallback(std::function<void(Status)> callback) override {
79     callback(status_);
80   }
81 
82  private:
83   Status status_;
84 };
85 
86 class GrpcBufferHandle : public BufferHandle {
87  public:
GrpcBufferHandle(EventId id,std::shared_ptr<GrpcEvent> event,int64_t bytes,absl::optional<xla::ShapeProto> shape=absl::nullopt)88   explicit GrpcBufferHandle(
89       EventId id, std::shared_ptr<GrpcEvent> event, int64_t bytes,
90       absl::optional<xla::ShapeProto> shape = absl::nullopt)
91       : id_(id),
92         stream_(event->stream()),
93         event_(std::move(event)),
94         bytes_(bytes),
95         shape_(shape) {}
96 
OnReady()97   std::shared_ptr<Event> OnReady() override { return event_; }
size_in_bytes()98   int64_t size_in_bytes() override { return bytes_; }
99 
id() const100   EventId id() const { return id_; }
stream() const101   GrpcTpuStream* stream() const { return stream_; }
102 
shape()103   absl::optional<xla::ShapeProto> shape() override { return shape_; }
104 
105  private:
106   const EventId id_;
107   GrpcTpuStream* stream_;
108   std::shared_ptr<GrpcEvent> event_;
109   int64_t bytes_;
110   absl::optional<xla::ShapeProto> shape_;
111 };
112 
113 class GrpcCompiledProgramHandle : public CompiledProgramHandle {
114  public:
GrpcCompiledProgramHandle(EventId id,std::shared_ptr<GrpcEvent> event)115   explicit GrpcCompiledProgramHandle(EventId id,
116                                      std::shared_ptr<GrpcEvent> event)
117       : id_(id),
118         stream_(event->stream()),
119         event_(std::move(event)),
120         metadata_(std::make_shared<CompiledProgramMetadata>()) {}
121 
OnReady()122   std::shared_ptr<Event> OnReady() override { return event_; }
123 
id() const124   EventId id() const { return id_; }
stream() const125   GrpcTpuStream* stream() const { return stream_; }
126 
program_shape(xla::ProgramShapeProto * program_shape)127   Status program_shape(xla::ProgramShapeProto* program_shape) override {
128     auto opt_status = OnReady()->AwaitWithTimeout(absl::Hours(1));
129     if (!opt_status.has_value()) {
130       return xla::InternalError("Compile failed to finish within 1 hour.");
131     }
132 
133     Status status = opt_status.value();
134     if (!status.ok()) {
135       return status;
136     }
137     *program_shape = metadata_->program_shape();
138     return Status::OK();
139   }
140 
metadata()141   std::shared_ptr<CompiledProgramMetadata> metadata() { return metadata_; }
142 
143  private:
144   const EventId id_;
145   GrpcTpuStream* stream_;
146   std::shared_ptr<GrpcEvent> event_;
147 
148   // Using a shared pointer here because the program handle can go out of scope
149   // before we get a response back, but we want a valid location to write things
150   // into regardless.
151   std::shared_ptr<CompiledProgramMetadata> metadata_;
152 };
153 
154 class GrpcLoadedProgramHandle : public LoadedProgramHandle {
155  public:
GrpcLoadedProgramHandle(EventId id,std::shared_ptr<GrpcEvent> event)156   explicit GrpcLoadedProgramHandle(EventId id, std::shared_ptr<GrpcEvent> event)
157       : id_(id), stream_(event->stream()), event_(std::move(event)) {}
158 
OnReady()159   std::shared_ptr<Event> OnReady() override { return event_; }
160 
id() const161   EventId id() const { return id_; }
stream() const162   GrpcTpuStream* stream() const { return stream_; }
163 
164  private:
165   const EventId id_;
166   GrpcTpuStream* stream_;
167   std::shared_ptr<GrpcEvent> event_;
168 };
169 
170 class GrpcTpuStream {
171  public:
172   explicit GrpcTpuStream(int32_t id, GrpcTpuDriver* driver,
173                          std::unique_ptr<grpc::CloudTpuDriver::Stub> stub);
174   virtual ~GrpcTpuStream();
175 
176   std::unique_ptr<BufferHandle> Allocate(int32_t core_id, MemoryRegion region,
177                                          int64_t num_bytes,
178                                          absl::Span<Event* const> wait_for);
179   std::unique_ptr<BufferHandle> Allocate(int32_t core_id, MemoryRegion region,
180                                          const xla::ShapeProto& shape,
181                                          absl::Span<Event* const> wait_for);
182   std::unique_ptr<BufferHandle> AllocateTuple(
183       int32_t core_id, MemoryRegion region,
184       absl::Span<BufferHandle* const> children,
185       absl::Span<Event* const> wait_for);
186   std::shared_ptr<Event> Deallocate(std::unique_ptr<BufferHandle> handle,
187                                     absl::Span<Event* const> wait_for);
188 
189   std::shared_ptr<Event> TransferToDevice(const void* src, BufferHandle* dst,
190                                           absl::Span<Event* const> wait_for);
191   std::shared_ptr<Event> TransferFromDevice(const BufferHandle* src, void* dst,
192                                             absl::Span<Event* const> wait_for);
193 
194   std::shared_ptr<Event> TransferFromDeviceToDevice(
195       const BufferHandle* src, BufferHandle* dst,
196       absl::Span<Event* const> wait_for);
197 
198   std::unique_ptr<CompiledProgramHandle> CompileProgram(
199       const xla::HloProto& source, int32_t num_replicas,
200       absl::Span<Event* const> wait_for);
201   std::unique_ptr<LoadedProgramHandle> LoadProgram(
202       int32_t core_id, const CompiledProgramHandle* handle,
203       absl::Span<Event* const> wait_for);
204   std::shared_ptr<Event> UnloadProgram(
205       std::unique_ptr<LoadedProgramHandle> handle,
206       absl::Span<Event* const> wait_for);
207   std::shared_ptr<Event> ExecuteProgram(
208       LoadedProgramHandle* program, absl::Span<BufferHandle* const> inputs,
209       absl::Span<BufferHandle* const> outputs,
210       const xla::DeviceAssignmentProto& device_assignment,
211       absl::Span<Event* const> wait_for);
212 
213  private:
214   friend class GrpcEvent;
215   friend class GrpcTpuDriver;
216 
217   struct EventInfo {
218     bool all_deps_done = false;
219     bool done = false;     // response received
220     bool deleted = false;  // deleted by the user
221     Status status;
222     absl::InlinedVector<std::function<void(Status)>, 1> callbacks;
223     // Most events should have <= 2 requirement events.
224     absl::InlinedVector<EventId, 2> deps;
225   };
226 
227   struct TransferInfo {
TransferInfotpu_driver::__anond0091eda0111::GrpcTpuStream::TransferInfo228     explicit TransferInfo(void* dst, int64_t num_bytes)
229         : dst(dst), num_bytes(num_bytes) {}
230 
231     void* const dst;
232     const uint64_t num_bytes;
233   };
234 
235   struct CompileMetadataInfo {
CompileMetadataInfotpu_driver::__anond0091eda0111::GrpcTpuStream::CompileMetadataInfo236     explicit CompileMetadataInfo(
237         std::shared_ptr<CompiledProgramMetadata> metadata) {
238       compiled_metadata = metadata;
239     }
240     std::shared_ptr<CompiledProgramMetadata> compiled_metadata;
241   };
242 
243   // Every public method above should call this first.
244   void InitializeRequest(StreamRequest::Entry* req,
245                          absl::Span<Event* const> wait_for)
246       ABSL_LOCKS_EXCLUDED(events_mutex_);
247 
248   // The first update to an event marks it done and calls registered callbacks.
249   // All subsequent updates must have the same OK-ness as the first update.
250   // Among non-OK updates, only the first error status is remembered.
251   void UpdateEventStatus(EventId id, Status status)
252       ABSL_EXCLUSIVE_LOCKS_REQUIRED(events_mutex_);
253 
254   // To ensure callbacks are still triggered, after this is called, we do not
255   // remove the event from the event mapping until a response is received from
256   // the server.
257   void DeleteEvent(EventId id) ABSL_LOCKS_EXCLUDED(events_mutex_);
258 
259   // Wait at most `duration` for event `id` to complete. Returns the event
260   // status or an empty optional if the event does not complete in time.
261   absl::optional<Status> WaitForEvent(EventId id, absl::Duration duration)
262       ABSL_LOCKS_EXCLUDED(events_mutex_);
263 
264   void AddEventCallback(EventId id, std::function<void(Status)> callback)
265       ABSL_LOCKS_EXCLUDED(events_mutex_);
266 
AddWriteRequest(std::unique_ptr<StreamRequest::Entry> req)267   void AddWriteRequest(std::unique_ptr<StreamRequest::Entry> req) {
268     absl::MutexLock m(&request_lock_);
269     VLOG(2) << "Adding request: " << req->DebugString();
270     requests_.push_back(std::move(req));
271   }
272 
273   // Unique identifier for this stream.
274   int32_t id_;
275   // The parent driver that created this stream.
276   GrpcTpuDriver* driver_;
277 
278   std::unique_ptr<grpc::CloudTpuDriver::Stub> stub_;
279   ::grpc::ClientContext ctx_;
280   std::unique_ptr<
281       ::grpc::ClientReaderWriterInterface<StreamRequest, StreamResponse>>
282       stream_;
283 
284   absl::Mutex request_lock_;
285   std::deque<std::unique_ptr<StreamRequest::Entry>> requests_
286       ABSL_GUARDED_BY(request_lock_);
287   int64_t num_pending_requests_ ABSL_GUARDED_BY(request_lock_) = 0;
288 
289   bool shutting_down_ ABSL_GUARDED_BY(request_lock_) = false;
290 
291   void StreamWriterFn();
292   Thread writer_thread_;
293 
294   void StreamReaderFn();
295   Thread reader_thread_;
296 
297   // Map from operation ID to event information.
298   absl::Mutex events_mutex_;
299   absl::flat_hash_map<EventId, EventInfo> events_
300       ABSL_GUARDED_BY(events_mutex_);
301 
302   // Map from operation ID to transfer information.
303   // When a D2H transfer completes, received data is copied into the `dst`
304   // pointer in `TransferInfo`.
305   absl::Mutex transfers_mutex_;
306   absl::flat_hash_map<EventId, TransferInfo> transfers_
307       ABSL_GUARDED_BY(transfers_mutex_);
308 
309   absl::Mutex compiles_mutex_;
310   absl::flat_hash_map<EventId, CompileMetadataInfo> compiles_
311       ABSL_GUARDED_BY(compiles_mutex_);
312 };
313 
314 class GrpcTpuDriver : public TpuDriver {
315  public:
GrpcTpuDriver(const TpuDriverConfig & config,std::shared_ptr<::grpc::ChannelCredentials> creds,int32_t client_id)316   explicit GrpcTpuDriver(const TpuDriverConfig& config,
317                          std::shared_ptr<::grpc::ChannelCredentials> creds,
318                          int32_t client_id)
319       : config_(config), creds_(creds), client_id_(client_id) {
320     SystemInfo system_info;
321     QuerySystemInfo(&system_info);
322     for (auto& chip_info : system_info.tpu_chip()) {
323       for (auto& core_info : chip_info.core()) {
324         int32_t core_id = core_info.id();
325         // We have one stream per core, so use core ID as stream ID.
326         streams_[core_id] = AllocateStream(core_id);
327       }
328     }
329     CHECK_GT(streams_.size(), 0) << "Can't find any TPU chip in the system.";
330 
331     host_stream_ = AllocateStream(-1);
332   }
333 
~GrpcTpuDriver()334   ~GrpcTpuDriver() override {
335     if (closed_) {
336       return;
337     }
338     auto status = Close();
339     if (!status.ok()) {
340       LOG(ERROR) << status;
341     }
342   }
343 
344   void QuerySystemInfo(SystemInfo* system_info) override;
345   Status Reset() override;
346 
Allocate(int32_t core_id,MemoryRegion region,int64_t num_bytes,absl::Span<Event * const> wait_for)347   std::unique_ptr<BufferHandle> Allocate(
348       int32_t core_id, MemoryRegion region, int64_t num_bytes,
349       absl::Span<Event* const> wait_for) override {
350     return streams_[core_id]->Allocate(core_id, region, num_bytes, wait_for);
351   }
Allocate(int32_t core_id,MemoryRegion region,const xla::ShapeProto & shape,absl::Span<Event * const> wait_for)352   std::unique_ptr<BufferHandle> Allocate(
353       int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape,
354       absl::Span<Event* const> wait_for) override {
355     return streams_[core_id]->Allocate(core_id, region, shape, wait_for);
356   }
AllocateTuple(int32_t core_id,MemoryRegion region,absl::Span<BufferHandle * const> children,absl::Span<Event * const> wait_for)357   std::unique_ptr<BufferHandle> AllocateTuple(
358       int32_t core_id, MemoryRegion region,
359       absl::Span<BufferHandle* const> children,
360       absl::Span<Event* const> wait_for) override {
361     return streams_[core_id]->AllocateTuple(core_id, region, children,
362                                             wait_for);
363   }
Deallocate(std::unique_ptr<BufferHandle> handle,absl::Span<Event * const> wait_for)364   std::shared_ptr<Event> Deallocate(
365       std::unique_ptr<BufferHandle> handle,
366       absl::Span<Event* const> wait_for) override {
367     auto* stream = static_cast<GrpcBufferHandle*>(handle.get())->stream();
368     return stream->Deallocate(std::move(handle), wait_for);
369   }
370 
TransferToDevice(const void * src,BufferHandle * dst,absl::Span<Event * const> wait_for)371   std::shared_ptr<Event> TransferToDevice(
372       const void* src, BufferHandle* dst,
373       absl::Span<Event* const> wait_for) override {
374     auto* stream = static_cast<GrpcBufferHandle*>(dst)->stream();
375     return stream->TransferToDevice(src, dst, wait_for);
376   }
TransferFromDevice(const BufferHandle * src,void * dst,absl::Span<Event * const> wait_for)377   std::shared_ptr<Event> TransferFromDevice(
378       const BufferHandle* src, void* dst,
379       absl::Span<Event* const> wait_for) override {
380     auto* stream = static_cast<const GrpcBufferHandle*>(src)->stream();
381     return stream->TransferFromDevice(src, dst, wait_for);
382   }
383 
TransferFromDeviceToDevice(const BufferHandle * src,BufferHandle * dst,absl::Span<Event * const> wait_for)384   std::shared_ptr<Event> TransferFromDeviceToDevice(
385       const BufferHandle* src, BufferHandle* dst,
386       absl::Span<Event* const> wait_for) override {
387     auto* stream = static_cast<const GrpcBufferHandle*>(src)->stream();
388     return stream->TransferFromDeviceToDevice(src, dst, wait_for);
389   }
390 
CompileProgram(const xla::HloProto & source,int32_t num_replicas,absl::Span<Event * const> wait_for)391   std::unique_ptr<CompiledProgramHandle> CompileProgram(
392       const xla::HloProto& source, int32_t num_replicas,
393       absl::Span<Event* const> wait_for) override {
394     // Always compile using the first/default core's stream.
395     return streams_[0]->CompileProgram(source, num_replicas, wait_for);
396   }
LoadProgram(int32_t core_id,const CompiledProgramHandle * handle,absl::Span<Event * const> wait_for)397   std::unique_ptr<LoadedProgramHandle> LoadProgram(
398       int32_t core_id, const CompiledProgramHandle* handle,
399       absl::Span<Event* const> wait_for) override {
400     return streams_[core_id]->LoadProgram(core_id, handle, wait_for);
401   }
UnloadProgram(std::unique_ptr<LoadedProgramHandle> handle,absl::Span<Event * const> wait_for)402   std::shared_ptr<Event> UnloadProgram(
403       std::unique_ptr<LoadedProgramHandle> handle,
404       absl::Span<Event* const> wait_for) override {
405     auto* stream =
406         static_cast<const GrpcLoadedProgramHandle*>(handle.get())->stream();
407     return stream->UnloadProgram(std::move(handle), wait_for);
408   }
ExecuteProgram(LoadedProgramHandle * program,absl::Span<BufferHandle * const> inputs,absl::Span<BufferHandle * const> outputs,const xla::DeviceAssignmentProto & device_assignment,absl::Span<Event * const> wait_for)409   std::shared_ptr<Event> ExecuteProgram(
410       LoadedProgramHandle* program, absl::Span<BufferHandle* const> inputs,
411       absl::Span<BufferHandle* const> outputs,
412       const xla::DeviceAssignmentProto& device_assignment,
413       absl::Span<Event* const> wait_for) override {
414     auto* stream =
415         static_cast<const GrpcLoadedProgramHandle*>(program)->stream();
416     return stream->ExecuteProgram(program, inputs, outputs, device_assignment,
417                                   wait_for);
418   }
419 
NewOperationId()420   EventId NewOperationId() { return EventId{client_id_, ++operation_id_}; }
421 
422   static std::unique_ptr<grpc::CloudTpuDriver::Stub> CreateTpuDriverStub(
423       const TpuDriverConfig& config,
424       std::shared_ptr<::grpc::ChannelCredentials> creds);
425 
client_id() const426   uint32_t client_id() const { return client_id_; }
427 
428  private:
429   Status Close();
430   std::unique_ptr<GrpcTpuStream> AllocateStream(int32_t core_id);
431 
432   const TpuDriverConfig config_;
433   std::shared_ptr<::grpc::ChannelCredentials> creds_;
434   const uint32_t client_id_;
435   // Map from stream IDs to streams.
436   absl::flat_hash_map<int32_t, std::unique_ptr<GrpcTpuStream>> streams_;
437   std::unique_ptr<GrpcTpuStream> host_stream_;
438   // Shared by all streams.
439   std::atomic<uint64_t> operation_id_{0};
440   std::atomic<bool> closed_{false};
441 };  // namespace
442 
~GrpcEvent()443 GrpcEvent::~GrpcEvent() { stream_->DeleteEvent(id_); }
444 
Await()445 Status GrpcEvent::Await() {
446   auto opt_status = stream_->WaitForEvent(id_, absl::InfiniteDuration());
447   return opt_status.value();
448 }
449 
AwaitWithTimeout(absl::Duration duration)450 absl::optional<Status> GrpcEvent::AwaitWithTimeout(absl::Duration duration) {
451   return stream_->WaitForEvent(id_, duration);
452 }
453 
AddCallback(std::function<void (Status)> callback)454 void GrpcEvent::AddCallback(std::function<void(Status)> callback) {
455   stream_->AddEventCallback(id_, std::move(callback));
456 }
457 
// Opens the bidirectional StreamExecute RPC on `stub` and starts the writer
// and reader threads from the initializer list. Members are initialized in
// class-declaration order, so the threads only see members declared before
// them as constructed. This assumes `Thread` starts running the function
// immediately, like std::thread — TODO confirm against compat.h.
GrpcTpuStream::GrpcTpuStream(int32_t id, GrpcTpuDriver* driver,
                             std::unique_ptr<grpc::CloudTpuDriver::Stub> stub)
    : id_(id),
      driver_(driver),
      // stub_ must be set before stream_, which is created from it.
      stub_(std::move(stub)),
      stream_(stub_->StreamExecute(&ctx_)),
      writer_thread_(&GrpcTpuStream::StreamWriterFn, this),
      reader_thread_(&GrpcTpuStream::StreamReaderFn, this) {}
466 
~GrpcTpuStream()467 GrpcTpuStream::~GrpcTpuStream() {
468   {
469     absl::MutexLock lock(&request_lock_);
470     shutting_down_ = true;
471   }
472 
473   VLOG(1) << "Shutting down stream.";
474   {
475     // Mark all remaining events invalid.
476     absl::MutexLock lock(&events_mutex_);
477     for (const auto& e : events_) {
478       if (!e.second.done) {
479         LOG(ERROR) << "Resetting: " << e.first;
480         UpdateEventStatus(e.first, xla::Status(tensorflow::error::Code::ABORTED,
481                                                "Driver was closed."));
482       }
483     }
484   }
485   VLOG(1) << "Closing stream.";
486   stream_->WritesDone();
487   stream_->Finish().IgnoreError();
488   VLOG(1) << "Waiting for writer.";
489   writer_thread_.join();
490   VLOG(1) << "Waiting for reader.";
491   reader_thread_.join();
492 }
493 
InitializeRequest(StreamRequest::Entry * req,absl::Span<Event * const> wait_for)494 void GrpcTpuStream::InitializeRequest(StreamRequest::Entry* req,
495                                       absl::Span<Event* const> wait_for) {
496   auto operation_id = driver_->NewOperationId();
497   EventInfo event_info;
498 
499   req->set_operation_id(operation_id.AsInt());
500   if (wait_for.empty()) {
501     event_info.all_deps_done = true;
502   } else {
503     event_info.deps.reserve(wait_for.size());
504     for (auto* event : wait_for) {
505       auto grpc_event = static_cast<const GrpcEvent*>(event);
506       req->add_wait_for_id(grpc_event->id().AsInt());
507       event_info.deps.push_back(grpc_event->id());
508     }
509   }
510 
511   absl::MutexLock lock(&events_mutex_);
512   events_[operation_id] = event_info;
513 }
514 
UpdateEventStatus(EventId id,Status status)515 void GrpcTpuStream::UpdateEventStatus(EventId id, Status status) {
516   auto it = events_.find(id);
517 
518   // These should only happen when the server shuts down, and our local event
519   // cancellation interleaves with server responses. It should be safe to ignore
520   // the second updates in these situations.
521   if (it == events_.end()) {
522     VLOG(1) << "Received a status update: " << status
523             << ", but cannot find GrpcEvent " << id;
524     return;
525   }
526   if (it->second.done) {
527     // Done and deleted events must have already been removed.
528     CHECK(!it->second.deleted);
529     VLOG(1) << "Received a second status update: " << status.error_message()
530             << ", for GrpcEvent " << id << " already done with status: "
531             << it->second.status.error_message();
532     return;
533   }
534 
535   // This is the first time this event finishes. Remember the results and call
536   // the callbacks.
537   VLOG(1) << "Response received for GrpcEvent " << id << ". "
538           << status.ToString() << ". Firing " << it->second.callbacks.size()
539           << " callbacks.";
540   it->second.done = true;
541   it->second.status = status;
542   for (const auto& callback : it->second.callbacks) {
543     callback(status);
544   }
545 
546   // Truly remove the event if it's both done and deleted.
547   if (it->second.deleted) {
548     events_.erase(it);
549   }
550 }
551 
DeleteEvent(EventId id)552 void GrpcTpuStream::DeleteEvent(EventId id) {
553   absl::MutexLock lock(&events_mutex_);
554   auto it = events_.find(id);
555   CHECK(it != events_.end());
556   CHECK(!it->second.deleted);
557   it->second.deleted = true;
558   // Truly remove the event if it's both done and deleted.
559   if (it->second.done) {
560     events_.erase(it);
561   }
562 }
563 
WaitForEvent(EventId id,absl::Duration duration)564 absl::optional<Status> GrpcTpuStream::WaitForEvent(EventId id,
565                                                    absl::Duration duration) {
566   events_mutex_.Lock();
567   auto it = events_.find(id);
568 
569   if (it == events_.end()) {
570     // This event has already been marked as done and deleted. Assume success.
571     events_mutex_.Unlock();
572     return Status::OK();
573   }
574 
575   if (!it->second.all_deps_done) {
576     absl::InlinedVector<EventId, 2> deps = it->second.deps;
577     events_mutex_.Unlock();
578     for (auto dep : deps) {
579       // If a requirement event timed out, no point in any further waiting.
580       if (!WaitForEvent(dep, duration)) {
581         return absl::nullopt;
582       }
583     }
584     events_mutex_.Lock();
585   }
586 
587   // Set the flag here, as we're guaranteed they have all completed at this
588   // point. This helps terminate recursion on a chain of completed events as
589   // soon as possible, at this event.
590   it = events_.find(id);
591   if (it != events_.end()) {
592     it->second.all_deps_done = true;
593   }
594 
595   auto done = [this, id]() {
596     events_mutex_.AssertHeld();
597     return !events_.contains(id) || events_[id].done;
598   };
599   if (events_mutex_.AwaitWithTimeout(absl::Condition(&done), duration)) {
600     auto status = events_.contains(id) ? events_[id].status : Status::OK();
601     events_mutex_.Unlock();
602     return status;
603   }
604   events_mutex_.Unlock();
605   return absl::nullopt;
606 }
607 
AddEventCallback(EventId id,std::function<void (Status)> callback)608 void GrpcTpuStream::AddEventCallback(EventId id,
609                                      std::function<void(Status)> callback) {
610   absl::MutexLock lock(&events_mutex_);
611   auto it = events_.find(id);
612   if (it == events_.end()) {
613     callback(Status());
614     return;
615   }
616   if (it->second.done) {
617     callback(it->second.status);
618     return;
619   }
620   it->second.callbacks.push_back(std::move(callback));
621 }
622 
// absl::Condition predicate for the writer thread: true once the pending
// request counter it is handed exceeds 32.
static bool ShouldBeginWriting(int64_t* pending_requests) {
  constexpr int64_t kBatchThreshold = 32;
  const int64_t queued = *pending_requests;
  return queued > kBatchThreshold;
}
626 
StreamWriterFn()627 void GrpcTpuStream::StreamWriterFn() {
628   while (true) {
629     request_lock_.LockWhenWithTimeout(
630         absl::Condition(&ShouldBeginWriting, &num_pending_requests_),
631         kWriteEpochDuration);
632     if (shutting_down_) {
633       request_lock_.Unlock();
634       return;
635     }
636 
637     if (requests_.empty()) {
638       request_lock_.Unlock();
639       continue;
640     }
641 
642     std::vector<StreamRequest> reqs;
643     int64_t request_bytes = 0;
644     while (!requests_.empty()) {
645       StreamRequest::Entry* e = requests_.front().release();
646       requests_.pop_front();
647       const int64_t entry_bytes = e->ByteSizeLong();
648       if (reqs.empty() || request_bytes + entry_bytes > kMaxStreamWriteSize) {
649         reqs.push_back(StreamRequest());
650         request_bytes = 0;
651       }
652       VLOG(1) << "Sending request: " << EventId::FromInt(e->operation_id());
653       VLOG(2) << "Sending request: " << e->DebugString();
654       reqs.back().mutable_entry()->AddAllocated(e);
655     }
656     num_pending_requests_ = 0;
657     request_lock_.Unlock();
658 
659     for (const auto& r : reqs) {
660       TraceMe activity("GrpcTpuStream::Send ");
661       ::grpc::WriteOptions opts;
662       opts.set_no_compression().clear_buffer_hint();
663       stream_->Write(r, opts);
664     }
665   }
666 }
667 
StreamReaderFn()668 void GrpcTpuStream::StreamReaderFn() {
669   StreamResponse resp;
670   while (stream_->Read(&resp)) {
671     VLOG(2) << "Received response: " << resp.DebugString();
672     for (const StreamResponse::Entry& entry : resp.entry()) {
673       EventId event_id = EventId::FromInt(entry.operation_id());
674       VLOG(1) << "Received response for: " << event_id;
675 
676       TraceMe activity("GrpcTpuStream::RequestComplete");
677       if (entry.has_transfer_from()) {
678         TraceMe activity("GrpcTpuStream::TransferFromComplete");
679         absl::MutexLock lock(&transfers_mutex_);
680         auto it = transfers_.find(event_id);
681         CHECK(it != transfers_.end());
682         VLOG(1) << "Copying: " << it->second.num_bytes << " to position "
683                 << it->second.dst;
684         if (entry.transfer_from().data().size() != it->second.num_bytes) {
685           absl::MutexLock lock(&events_mutex_);
686           UpdateEventStatus(
687               event_id,
688               Status(
689                   tensorflow::error::Code::DATA_LOSS,
690                   absl::StrCat("Expected ", it->second.num_bytes, " received ",
691                                entry.transfer_from().data().size())));
692           continue;
693         }
694         memcpy(it->second.dst, entry.transfer_from().data().data(),
695                it->second.num_bytes);
696       }
697 
698       if (entry.has_compile()) {
699         TraceMe activity("GrpcTpuStream::CompileComplete");
700         absl::MutexLock lock(&compiles_mutex_);
701         auto it = compiles_.find(event_id);
702         CHECK(it != compiles_.end());
703         *it->second.compiled_metadata = entry.compile().metadata();
704       }
705 
706       absl::MutexLock lock(&events_mutex_);
707       if (entry.status().code() != tensorflow::error::Code::OK) {
708         UpdateEventStatus(
709             event_id,
710             Status(static_cast<tensorflow::error::Code>(entry.status().code()),
711                    entry.status().message()));
712       } else {
713         UpdateEventStatus(event_id, Status::OK());
714       }
715     }
716   }
717 }
718 
Allocate(int32_t core_id,MemoryRegion region,int64_t num_bytes,absl::Span<Event * const> wait_for)719 std::unique_ptr<BufferHandle> GrpcTpuStream::Allocate(
720     int32_t core_id, MemoryRegion region, int64_t num_bytes,
721     absl::Span<Event* const> wait_for) {
722   auto req = absl::make_unique<StreamRequest::Entry>();
723   InitializeRequest(req.get(), wait_for);
724   TraceMe activity("GrpcTpuStream::Allocate(num_bytes)");
725   req->mutable_alloc()->set_core_id(core_id);
726   req->mutable_alloc()->set_region(region);
727   req->mutable_alloc()->set_num_bytes(num_bytes);
728   auto event =
729       std::make_shared<GrpcEvent>(EventId::FromInt(req->operation_id()), this);
730   AddWriteRequest(std::move(req));
731   return absl::make_unique<GrpcBufferHandle>(event->id(), std::move(event),
732                                              num_bytes);
733 }
734 
Allocate(int32_t core_id,MemoryRegion region,const xla::ShapeProto & shape,absl::Span<Event * const> wait_for)735 std::unique_ptr<BufferHandle> GrpcTpuStream::Allocate(
736     int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape,
737     absl::Span<Event* const> wait_for) {
738   auto req = absl::make_unique<StreamRequest::Entry>();
739   InitializeRequest(req.get(), wait_for);
740   TraceMe activity("GrpcTpuStream::Allocate(shape)");
741   req->mutable_alloc()->set_core_id(core_id);
742   req->mutable_alloc()->set_region(region);
743   *req->mutable_alloc()->mutable_shape() = shape;
744   auto event =
745       std::make_shared<GrpcEvent>(EventId::FromInt(req->operation_id()), this);
746   AddWriteRequest(std::move(req));
747   return absl::make_unique<GrpcBufferHandle>(
748       event->id(), std::move(event), ComputeBytesFromShape(shape), shape);
749 }
750 
AllocateTuple(int32_t core_id,MemoryRegion region,absl::Span<BufferHandle * const> children,absl::Span<Event * const> wait_for)751 std::unique_ptr<BufferHandle> GrpcTpuStream::AllocateTuple(
752     int32_t core_id, MemoryRegion region,
753     absl::Span<BufferHandle* const> children,
754     absl::Span<Event* const> wait_for) {
755   auto req = absl::make_unique<StreamRequest::Entry>();
756   InitializeRequest(req.get(), wait_for);
757   TraceMe activity("GrpcTpuStream::AllocateTuple");
758   req->mutable_alloc_tuple()->set_core_id(core_id);
759   req->mutable_alloc_tuple()->set_region(region);
760   for (auto child : children) {
761     auto grpc_child = static_cast<GrpcBufferHandle*>(child);
762     req->mutable_alloc_tuple()->add_children(grpc_child->id().AsInt());
763   }
764   auto event =
765       std::make_shared<GrpcEvent>(EventId::FromInt(req->operation_id()), this);
766   AddWriteRequest(std::move(req));
767   return absl::make_unique<GrpcBufferHandle>(event->id(), std::move(event), 0);
768 }
769 
Deallocate(std::unique_ptr<BufferHandle> handle,absl::Span<Event * const> wait_for)770 std::shared_ptr<Event> GrpcTpuStream::Deallocate(
771     std::unique_ptr<BufferHandle> handle, absl::Span<Event* const> wait_for) {
772   auto req = absl::make_unique<StreamRequest::Entry>();
773   InitializeRequest(req.get(), wait_for);
774   TraceMe activity("GrpcTpuStream::Deallocate");
775   auto grpc_handle = static_cast<GrpcBufferHandle*>(handle.get());
776   req->mutable_dealloc()->set_handle(grpc_handle->id().AsInt());
777   auto event =
778       std::make_shared<GrpcEvent>(EventId::FromInt(req->operation_id()), this);
779   AddWriteRequest(std::move(req));
780   return event;
781 }
782 
TransferToDevice(const void * src,BufferHandle * dst,absl::Span<Event * const> wait_for)783 std::shared_ptr<Event> GrpcTpuStream::TransferToDevice(
784     const void* src, BufferHandle* dst, absl::Span<Event* const> wait_for) {
785   auto req = absl::make_unique<StreamRequest::Entry>();
786   InitializeRequest(req.get(), wait_for);
787   TraceMe activity("GrpcTpuStream::TransferToDevice");
788   req->mutable_transfer_to()->mutable_data()->assign(
789       static_cast<const char*>(src), dst->size_in_bytes());
790   req->mutable_transfer_to()->set_target_handle(
791       static_cast<GrpcBufferHandle*>(dst)->id().AsInt());
792   auto event =
793       std::make_shared<GrpcEvent>(EventId::FromInt(req->operation_id()), this);
794   AddWriteRequest(std::move(req));
795   return event;
796 }
797 
TransferFromDevice(const BufferHandle * src,void * dst,absl::Span<Event * const> wait_for)798 std::shared_ptr<Event> GrpcTpuStream::TransferFromDevice(
799     const BufferHandle* src, void* dst, absl::Span<Event* const> wait_for) {
800   auto req = absl::make_unique<StreamRequest::Entry>();
801   InitializeRequest(req.get(), wait_for);
802   TraceMe activity("GrpcTpuStream::TransferFromDevice");
803   req->mutable_transfer_from()->set_source_handle(
804       static_cast<const GrpcBufferHandle*>(src)->id().AsInt());
805   EventId event_id = EventId::FromInt(req->operation_id());
806   {
807     absl::MutexLock lock(&transfers_mutex_);
808     TransferInfo info(dst, const_cast<BufferHandle*>(src)->size_in_bytes());
809     transfers_.insert(std::make_pair(event_id, info));
810   }
811   auto event = std::make_shared<GrpcEvent>(event_id, this);
812   AddWriteRequest(std::move(req));
813   return event;
814 }
815 
TransferFromDeviceToDevice(const BufferHandle * src,BufferHandle * dst,absl::Span<Event * const> wait_for)816 std::shared_ptr<Event> GrpcTpuStream::TransferFromDeviceToDevice(
817     const BufferHandle* src, BufferHandle* dst,
818     absl::Span<Event* const> wait_for) {
819   auto req = absl::make_unique<StreamRequest::Entry>();
820   InitializeRequest(req.get(), wait_for);
821   TraceMe activity([&req] {
822     return absl::StrCat("GrpcTpuStream::TransferFromDeviceToDevice",
823                         req->operation_id());
824   });
825 
826   req->mutable_transfer_from_to()->set_source_handle(
827       static_cast<const GrpcBufferHandle*>(src)->id().AsInt());
828   req->mutable_transfer_from_to()->set_target_handle(
829       static_cast<const GrpcBufferHandle*>(dst)->id().AsInt());
830   EventId event_id = EventId::FromInt(req->operation_id());
831   auto event = std::make_shared<GrpcEvent>(event_id, this);
832   AddWriteRequest(std::move(req));
833   return event;
834 }
835 
CompileProgram(const xla::HloProto & source,int32_t num_replicas,absl::Span<Event * const> wait_for)836 std::unique_ptr<CompiledProgramHandle> GrpcTpuStream::CompileProgram(
837     const xla::HloProto& source, int32_t num_replicas,
838     absl::Span<Event* const> wait_for) {
839   auto req = absl::make_unique<StreamRequest::Entry>();
840   InitializeRequest(req.get(), wait_for);
841   TraceMe activity("GrpcTpuStream::CompileProgram");
842   *req->mutable_compile()->mutable_hlo_program() = source;
843   req->mutable_compile()->set_num_replicas(num_replicas);
844   EventId event_id = EventId::FromInt(req->operation_id());
845 
846   auto event =
847       std::make_shared<GrpcEvent>(EventId::FromInt(req->operation_id()), this);
848 
849   auto handle = absl::make_unique<GrpcCompiledProgramHandle>(event->id(),
850                                                              std::move(event));
851   {
852     absl::MutexLock lock(&compiles_mutex_);
853     CompileMetadataInfo info(handle->metadata());
854     compiles_.insert(std::make_pair(event_id, info));
855   }
856 
857   AddWriteRequest(std::move(req));
858   return std::move(handle);
859 }
860 
LoadProgram(int32_t core_id,const CompiledProgramHandle * handle,absl::Span<Event * const> wait_for)861 std::unique_ptr<LoadedProgramHandle> GrpcTpuStream::LoadProgram(
862     int32_t core_id, const CompiledProgramHandle* handle,
863     absl::Span<Event* const> wait_for) {
864   auto req = absl::make_unique<StreamRequest::Entry>();
865   InitializeRequest(req.get(), wait_for);
866   TraceMe activity("GrpcTpuStream::LoadProgram");
867   req->mutable_load()->set_core_id(core_id);
868   auto grpc_handle = static_cast<const GrpcCompiledProgramHandle*>(handle);
869   if (grpc_handle->id().client_id != driver_->client_id()) {
870     auto event = std::make_shared<ErrorEvent>(
871         xla::InvalidArgument("Invalid program handle (wrong client id). Did "
872                              "you restart the server or use a stale handle?"));
873     return absl::make_unique<GrpcLoadedProgramHandle>(event->id(),
874                                                       std::move(event));
875   }
876   req->mutable_load()->set_compiled_program_handle(grpc_handle->id().AsInt());
877   auto event =
878       std::make_shared<GrpcEvent>(EventId::FromInt(req->operation_id()), this);
879   AddWriteRequest(std::move(req));
880   return absl::make_unique<GrpcLoadedProgramHandle>(event->id(),
881                                                     std::move(event));
882 }
883 
UnloadProgram(std::unique_ptr<LoadedProgramHandle> handle,absl::Span<Event * const> wait_for)884 std::shared_ptr<Event> GrpcTpuStream::UnloadProgram(
885     std::unique_ptr<LoadedProgramHandle> handle,
886     absl::Span<Event* const> wait_for) {
887   auto req = absl::make_unique<StreamRequest::Entry>();
888   InitializeRequest(req.get(), wait_for);
889   TraceMe activity("GrpcTpuStream::UnloadProgram");
890   req->mutable_unload()->set_loaded_program_handle(
891       static_cast<GrpcLoadedProgramHandle*>(handle.get())->id().AsInt());
892   auto event =
893       std::make_shared<GrpcEvent>(EventId::FromInt(req->operation_id()), this);
894   AddWriteRequest(std::move(req));
895   return event;
896 }
897 
ExecuteProgram(LoadedProgramHandle * program,absl::Span<BufferHandle * const> inputs,absl::Span<BufferHandle * const> outputs,const xla::DeviceAssignmentProto & device_assignment,absl::Span<Event * const> wait_for)898 std::shared_ptr<Event> GrpcTpuStream::ExecuteProgram(
899     LoadedProgramHandle* program, absl::Span<BufferHandle* const> inputs,
900     absl::Span<BufferHandle* const> outputs,
901     const xla::DeviceAssignmentProto& device_assignment,
902     absl::Span<Event* const> wait_for) {
903   auto req = absl::make_unique<StreamRequest::Entry>();
904   InitializeRequest(req.get(), wait_for);
905   auto program_handle = static_cast<GrpcLoadedProgramHandle*>(program);
906   if (program_handle->id().client_id != driver_->client_id()) {
907     return std::make_shared<ErrorEvent>(
908         xla::InvalidArgument("Invalid program handle (wrong client id). Did "
909                              "you restart the server or use a stale handle?"));
910   }
911 
912   req->mutable_execute()->set_loaded_program_handle(
913       program_handle->id().AsInt());
914 
915   for (BufferHandle* input : inputs) {
916     auto* grpc_handle = static_cast<GrpcBufferHandle*>(input);
917     if (grpc_handle->id().client_id != driver_->client_id()) {
918       return std::make_shared<ErrorEvent>(xla::InvalidArgument(
919           "Invalid input buffer (wrong client id). Did you restart the server "
920           "or use a stale handle?"));
921     }
922     req->mutable_execute()->add_input_handle(grpc_handle->id().AsInt());
923   }
924 
925   for (BufferHandle* output : outputs) {
926     auto* grpc_handle = static_cast<GrpcBufferHandle*>(output);
927     if (grpc_handle->id().client_id != driver_->client_id()) {
928       return std::make_shared<ErrorEvent>(xla::InvalidArgument(
929           "Invalid output buffer (wrong client id). Did you restart the server "
930           "or use a stale handle?"));
931     }
932     req->mutable_execute()->add_output_handle(
933         static_cast<GrpcBufferHandle*>(output)->id().AsInt());
934   }
935   // Only pass along device_assignment if it's not default constructed.
936   if (!(device_assignment.replica_count() == 0 &&
937         device_assignment.computation_count() == 0)) {
938     *req->mutable_execute()->mutable_device_assignment() = device_assignment;
939   }
940   auto event =
941       std::make_shared<GrpcEvent>(EventId::FromInt(req->operation_id()), this);
942   AddWriteRequest(std::move(req));
943   return event;
944 }
945 
946 /*static*/ std::unique_ptr<grpc::CloudTpuDriver::Stub>
CreateTpuDriverStub(const TpuDriverConfig & config,std::shared_ptr<::grpc::ChannelCredentials> creds)947 GrpcTpuDriver::CreateTpuDriverStub(
948     const TpuDriverConfig& config,
949     std::shared_ptr<::grpc::ChannelCredentials> creds) {
950   ::grpc::ChannelArguments args;
951   args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
952   args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
953 
954   // Send at least 20 keep-alives before giving up.
955   int keepalive_timeout_ms = config.grpc().keepalive_timeout_secs() * 1000;
956   int keepalive_interval_ms = keepalive_timeout_ms / 20;
957 
958   grpc_arg client_arg_vals[] = {
959       {.type = GRPC_ARG_INTEGER,
960        .key = const_cast<char*>(
961            GRPC_ARG_HTTP2_MIN_RECV_PING_INTERVAL_WITHOUT_DATA_MS),
962        .value = {.integer = keepalive_interval_ms}},
963       {.type = GRPC_ARG_INTEGER,
964        .key = const_cast<char*>(GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA),
965        .value = {.integer = 0}},  // unlimited
966       {.type = GRPC_ARG_INTEGER,
967        .key = const_cast<char*>(GRPC_ARG_KEEPALIVE_TIME_MS),
968        .value = {.integer = keepalive_interval_ms}},
969       {.type = GRPC_ARG_INTEGER,
970        .key = const_cast<char*>(GRPC_ARG_KEEPALIVE_TIMEOUT_MS),
971        .value = {.integer = keepalive_timeout_ms}},
972       {.type = GRPC_ARG_INTEGER,
973        .key = const_cast<char*>(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS),
974        .value = {.integer = 1}},
975       {.type = GRPC_ARG_INTEGER,
976        .key = const_cast<char*>(GRPC_ARG_HTTP2_WRITE_BUFFER_SIZE),
977        .value = {.integer = 64 * 1000 * 1000}}};
978 
979   grpc_channel_args client_args = {.num_args = 6, .args = client_arg_vals};
980   args.SetChannelArgs(&client_args);
981 
982   // strips out 'grpc://'
983   auto worker_addr = absl::StripPrefix(config.worker(), kGrpcProtocol);
984   std::shared_ptr<::grpc::Channel> channel =
985       ::grpc::CreateCustomChannel(std::string(worker_addr), creds, args);
986   return grpc::CloudTpuDriver::NewStub(channel);
987 }
988 
AllocateStream(int32_t id)989 std::unique_ptr<GrpcTpuStream> GrpcTpuDriver::AllocateStream(int32_t id) {
990   auto stub = CreateTpuDriverStub(config_, creds_);
991   ::grpc::ClientContext ctx;
992   ctx.set_fail_fast(false);
993   ctx.set_deadline(std::chrono::system_clock::now() + std::chrono::seconds(10));
994   return absl::make_unique<GrpcTpuStream>(id, this, std::move(stub));
995 }
996 
QuerySystemInfo(SystemInfo * system_info)997 void GrpcTpuDriver::QuerySystemInfo(SystemInfo* system_info) {
998   auto stub = CreateTpuDriverStub(config_, creds_);
999   ::grpc::ClientContext ctx;
1000   ctx.set_fail_fast(false);
1001   ctx.set_deadline(std::chrono::system_clock::now() + std::chrono::seconds(10));
1002 
1003   QuerySystemInfoRequest req;
1004   QuerySystemInfoResponse resp;
1005   ::grpc::Status status = stub->QuerySystemInfo(&ctx, req, &resp);
1006   if (!status.ok()) {
1007     LOG(ERROR) << "QuerySystemInfo request failed: " << status.error_code()
1008                << ": " << status.error_message() << ": "
1009                << status.error_details();
1010     return;
1011   }
1012   *system_info = resp.system_info();
1013 }
1014 
Reset()1015 Status GrpcTpuDriver::Reset() {
1016   auto stub = CreateTpuDriverStub(config_, creds_);
1017   ::grpc::ClientContext ctx;
1018   ctx.set_fail_fast(false);
1019   ctx.set_deadline(std::chrono::system_clock::now() + std::chrono::seconds(10));
1020   ResetRequest req;
1021   ResetResponse resp;
1022   ::grpc::Status status = stub->Reset(&ctx, req, &resp);
1023   if (!status.ok()) {
1024     LOG(ERROR) << "Failed to reset the gRPC driver: " << status.error_code()
1025                << ": " << status.error_message() << ": "
1026                << status.error_details();
1027     return xla::Status(tensorflow::error::Code(status.error_code()),
1028                        absl::StrCat("Failed to reset TPU driver. Error was: ",
1029                                     status.error_message(),
1030                                     ". Details: ", status.error_details()));
1031   }
1032   streams_.clear();
1033   host_stream_.reset();
1034   return Close();
1035 }
1036 
Close()1037 Status GrpcTpuDriver::Close() {
1038   auto stub = CreateTpuDriverStub(config_, creds_);
1039   ::grpc::ClientContext ctx;
1040   ctx.set_fail_fast(false);
1041   ctx.set_deadline(std::chrono::system_clock::now() + std::chrono::seconds(10));
1042   CloseRequest req;
1043   req.set_client_id(client_id_);
1044   CloseResponse resp;
1045   ::grpc::Status status = stub->Close(&ctx, req, &resp);
1046   if (!status.ok()) {
1047     return xla::Status(tensorflow::error::Code(status.error_code()),
1048                        absl::StrCat("Failed to close TPU driver. Error was: ",
1049                                     status.error_message(),
1050                                     ". Details: ", status.error_details()));
1051   }
1052   closed_ = true;
1053   return Status::OK();
1054 }
1055 }  // namespace
1056 
CreateGrpcTpuDriver(const TpuDriverConfig & config,std::shared_ptr<::grpc::ChannelCredentials> creds)1057 xla::StatusOr<std::unique_ptr<TpuDriver>> CreateGrpcTpuDriver(
1058     const TpuDriverConfig& config,
1059     std::shared_ptr<::grpc::ChannelCredentials> creds) {
1060   auto stub = GrpcTpuDriver::CreateTpuDriverStub(config, creds);
1061   ::grpc::ClientContext ctx;
1062   ctx.set_fail_fast(false);
1063   ctx.set_deadline(
1064       std::chrono::system_clock::now() +
1065       std::chrono::seconds(config.grpc().connection_timeout_secs()));
1066   OpenRequest req;
1067   OpenResponse resp;
1068   ::grpc::Status status = stub->Open(&ctx, req, &resp);
1069   if (!status.ok()) {
1070     LOG(ERROR) << "Failed to open the gRPC driver: " << status.error_code()
1071                << ": " << status.error_message() << ": "
1072                << status.error_details();
1073     return xla::Status(
1074         tensorflow::error::Code(status.error_code()),
1075         absl::StrCat(
1076             "Failed to connect to remote server at address: ", config.worker(),
1077             ". Error from gRPC: ", status.error_message(),
1078             ". Details: ", status.error_details()));
1079   }
1080   return std::unique_ptr<TpuDriver>(
1081       new GrpcTpuDriver(config, creds, resp.client_id()));
1082 }
1083 
1084 REGISTER_TPU_DRIVER(
1085     "grpc://",
1086     [](const TpuDriverConfig& config)
__anond0091eda0402(const TpuDriverConfig& config) 1087         -> xla::StatusOr<std::unique_ptr<TpuDriver>> {
1088       if (absl::StartsWith(config.worker(), "grpc://localhost")) {
1089         LOG(INFO) << "Using local credentials for localhost: connection.";
1090         return CreateGrpcTpuDriver(
1091             config, ::grpc::experimental::LocalCredentials(LOCAL_TCP));
1092       } else {
1093         return CreateGrpcTpuDriver(config,
1094                                    ::grpc::InsecureChannelCredentials());
1095       }
1096     });
1097 
1098 }  // namespace tpu_driver
1099