1 // Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 
16 #include "absl/container/btree_map.h"
17 #include "absl/container/flat_hash_map.h"
18 #include "absl/container/flat_hash_set.h"
19 #include "absl/strings/str_split.h"
20 #include "absl/synchronization/mutex.h"
21 #include "tensorflow/compiler/xla/pjrt/semaphore.h"
22 #include "tensorflow/compiler/xla/pjrt/worker_thread.h"
23 #include "tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.h"
24 #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
25 #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h"
26 #include "tensorflow/core/platform/env.h"
27 #include "tensorflow/core/platform/errors.h"
28 #include "tensorflow/core/protobuf/error_codes.pb.h"
29 
30 namespace tpu_driver {
31 namespace {
32 
// Evaluates CheckHandleExists(container, target_op_id, operation_id) and, if
// the result is non-null, returns it from the enclosing function or lambda;
// otherwise does nothing. (CheckHandleExists() is defined elsewhere and
// presumably yields an error event when `target_op_id` is missing from
// `container` -- confirm against its definition.)
//
// The do { } while (0) wrapper makes the macro a single statement, so it is
// safe in an unbraced if/else and requires a trailing semicolon. The local has
// a distinctive name so it cannot shadow a caller variable used in the macro
// arguments (e.g. an argument named `p`).
#define CHECK_EXISTS_OR_RETURN(container, target_op_id, operation_id)   \
  do {                                                                  \
    auto check_exists_result =                                          \
        CheckHandleExists(container, target_op_id, operation_id);       \
    if (check_exists_result != nullptr) return check_exists_result;     \
  } while (0)
38 
39 using xla::Status;
40 using xla::WorkerThread;
41 
// Scheme prefix of a pod worker string. The PodTpuDriver constructor strips it
// and splits the remainder on ',' into individual worker addresses.
constexpr char kPodTpuDriverPrefix[] = "grpc+pod://";

// Forward declaration: the event and handle classes below keep a raw
// PodTpuDriver* back-pointer.
class PodTpuDriver;
45 
46 class PodEvent : public Event {
47  public:
PodEvent(PodTpuDriver * driver,int64_t operation_id)48   explicit PodEvent(PodTpuDriver* driver, int64_t operation_id)
49       : driver_(driver), operation_id_(operation_id) {}
operation_id() const50   int64_t operation_id() const { return operation_id_; }
51 
52   xla::Status Await() override;
53 
54   absl::optional<xla::Status> AwaitWithTimeout(
55       absl::Duration duration) override;
56 
57   void AddCallback(std::function<void(Status)> callback) override;
58 
59  private:
60   PodTpuDriver* driver_;
61   const int64_t operation_id_;
62 };
63 
64 class ErrorEvent : public PodEvent {
65  public:
ErrorEvent(PodTpuDriver * driver,int64_t operation_id,Status status)66   explicit ErrorEvent(PodTpuDriver* driver, int64_t operation_id, Status status)
67       : PodEvent(driver, operation_id) {
68     status_ = status;
69   }
70 
Await()71   xla::Status Await() override { return status_; }
AwaitWithTimeout(absl::Duration duration)72   absl::optional<xla::Status> AwaitWithTimeout(
73       absl::Duration duration) override {
74     return status_;
75   }
AddCallback(std::function<void (Status)> callback)76   void AddCallback(std::function<void(Status)> callback) override {
77     callback(status_);
78   }
79 
80  private:
81   Status status_;
82 };
83 
84 class CombinedEvent : public PodEvent {
85  public:
CombinedEvent(PodTpuDriver * driver,int64_t operation_id,std::vector<std::shared_ptr<Event>> events)86   explicit CombinedEvent(PodTpuDriver* driver, int64_t operation_id,
87                          std::vector<std::shared_ptr<Event>> events)
88       : PodEvent(driver, operation_id), events_(events) {
89     for (auto& event : events_) {
90       event->AddCallback([this](Status s) { IncrementAndCheckComplete(s); });
91     }
92   }
93 
Await()94   xla::Status Await() override {
95     for (auto& event : events_) {
96       TF_RETURN_IF_ERROR(event->Await());
97     }
98     return Status::OK();
99   }
100 
AwaitWithTimeout(absl::Duration duration)101   absl::optional<xla::Status> AwaitWithTimeout(
102       absl::Duration duration) override {
103     for (auto& event : events_) {
104       auto start_time = absl::Now();
105       auto status = event->AwaitWithTimeout(duration);
106       duration -= absl::Now() - start_time;
107       if (status == absl::nullopt) {
108         return absl::nullopt;
109       } else {
110         TF_RETURN_IF_ERROR(status.value());
111       }
112     }
113     return Status::OK();
114   }
115 
AddCallback(std::function<void (Status)> callback)116   void AddCallback(std::function<void(Status)> callback)
117       TF_LOCKS_EXCLUDED(mu_) override {
118     bool all_events_completed = false;
119     {
120       absl::MutexLock l(&mu_);
121       all_events_completed = events_completed_ == events_.size();
122     }
123     if (all_events_completed) {
124       callback(event_status_);
125     } else {
126       absl::MutexLock l(&mu_);
127       callbacks_.push_back(std::move(callback));
128     }
129   }
130 
131  private:
IncrementAndCheckComplete(Status s)132   void IncrementAndCheckComplete(Status s) TF_LOCKS_EXCLUDED(mu_) {
133     std::vector<std::function<void(Status)>> callbacks;
134     {
135       absl::MutexLock l(&mu_);
136 
137       event_status_ = s;
138       events_completed_++;
139       if (events_completed_ == events_.size()) {
140         // Copy callbacks to a temporary to be invoked outside the mutex.
141         callbacks.assign(callbacks_.begin(), callbacks_.end());
142         callbacks_.clear();
143       } else {
144         return;
145       }
146     }
147 
148     for (const auto& callback : callbacks) {
149       callback(event_status_);
150     }
151   }
152 
153   absl::Mutex mu_;
154   std::vector<std::shared_ptr<Event>> events_;
155   std::vector<std::function<void(Status)>> callbacks_ ABSL_GUARDED_BY(mu_);
156   int64_t events_completed_ ABSL_GUARDED_BY(mu_) = 0;
157   Status event_status_;
158 };
159 
160 class PodBufferHandle : public BufferHandle {
161  public:
PodBufferHandle(PodTpuDriver * driver,int64_t operation_id,int64_t size_in_bytes,absl::optional<xla::ShapeProto> shape,int64_t core_id)162   explicit PodBufferHandle(PodTpuDriver* driver, int64_t operation_id,
163                            int64_t size_in_bytes,
164                            absl::optional<xla::ShapeProto> shape,
165                            int64_t core_id)
166       : driver_(driver),
167         operation_id_(operation_id),
168         size_in_bytes_(size_in_bytes),
169         shape_(shape),
170         event_(std::make_shared<PodEvent>(driver_, operation_id_)),
171         core_id_(core_id) {}
172 
OnReady()173   std::shared_ptr<Event> OnReady() override { return event_; }
size_in_bytes()174   int64_t size_in_bytes() override { return size_in_bytes_; }
shape()175   absl::optional<xla::ShapeProto> shape() override { return shape_; }
176 
operation_id() const177   int64_t operation_id() const { return operation_id_; }
core_id() const178   int64_t core_id() const { return core_id_; }
179 
180  private:
181   PodTpuDriver* driver_;
182   const int64_t operation_id_;
183   const int64_t size_in_bytes_;
184   const absl::optional<xla::ShapeProto> shape_;
185   std::shared_ptr<PodEvent> event_;
186   const int64_t core_id_;
187 };
188 
189 class PodCompiledProgramHandle : public CompiledProgramHandle {
190  public:
PodCompiledProgramHandle(PodTpuDriver * driver,int64_t operation_id)191   explicit PodCompiledProgramHandle(PodTpuDriver* driver, int64_t operation_id)
192       : driver_(driver),
193         operation_id_(operation_id),
194         event_(std::make_shared<PodEvent>(driver_, operation_id_)) {}
195 
OnReady()196   std::shared_ptr<Event> OnReady() override { return event_; }
197 
198   xla::Status program_shape(xla::ProgramShapeProto* program_shape) override;
199 
operation_id() const200   int64_t operation_id() const { return operation_id_; }
201 
202  private:
203   PodTpuDriver* driver_;
204   const int64_t operation_id_;
205   std::shared_ptr<PodEvent> event_;
206 };
207 
208 class PodLoadedProgramHandle : public LoadedProgramHandle {
209  public:
PodLoadedProgramHandle(PodTpuDriver * driver,int64_t operation_id,int64_t core_id)210   explicit PodLoadedProgramHandle(PodTpuDriver* driver, int64_t operation_id,
211                                   int64_t core_id)
212       : driver_(driver),
213         operation_id_(operation_id),
214         core_id_(core_id),
215         event_(std::make_shared<PodEvent>(driver_, operation_id_)) {}
216 
OnReady()217   std::shared_ptr<Event> OnReady() override { return event_; }
218 
operation_id() const219   int64_t operation_id() const { return operation_id_; }
core_id() const220   int64_t core_id() const { return core_id_; }
221 
222  private:
223   PodTpuDriver* driver_;
224   const int64_t operation_id_;
225   const int64_t core_id_;
226   std::shared_ptr<PodEvent> event_;
227 };
228 
229 struct EventInFlight {
EventInFlighttpu_driver::__anon2989b0910111::EventInFlight230   EventInFlight()
231       : underlying_event(nullptr),
232         create_fn(nullptr),
233         incomplete_deps(),
234         callbacks() {}
235 
236   std::shared_ptr<Event> underlying_event;
237   std::function<std::shared_ptr<Event>(void)> create_fn;
238 
239   absl::flat_hash_set<int64_t> incomplete_deps;
240   std::vector<std::function<void(Status)>> callbacks;
241 };
242 
243 class PodTpuDriver : public TpuDriver {
244  public:
PodTpuDriver(const TpuDriverConfig & config,std::shared_ptr<::grpc::ChannelCredentials> creds)245   explicit PodTpuDriver(const TpuDriverConfig& config,
246                         std::shared_ptr<::grpc::ChannelCredentials> creds)
247       : config_(config),
248         creds_(creds),
249         event_thread_(tensorflow::Env::Default(), "grpc_pod_event_thread") {
250     std::vector<std::string> workers = absl::StrSplit(
251         absl::StripPrefix(config.worker(), kPodTpuDriverPrefix), ',');
252 
253     int worker_count = 0;
254 
255     // Flag for environments where local core # == all cores in TPU system #,
256     // which means that we are connecting to separate TPU systems or we are in
257     // a test environment.
258     bool in_local_core_environment = false;
259 
260     for (const auto& worker : workers) {
261       TpuDriverConfig worker_config(config_);
262       *(worker_config.mutable_worker()) = absl::StrCat("grpc://", worker);
263       auto tpu_driver =
264           CreateGrpcTpuDriver(worker_config, creds_).ConsumeValueOrDie();
265 
266       SystemInfo driver_info;
267       tpu_driver->QuerySystemInfo(&driver_info);
268 
269       if (driver_info.core_count() == driver_info.local_core_size()) {
270         drivers_.insert({worker_count, std::move(tpu_driver)});
271         in_local_core_environment = true;
272       } else {
273         drivers_.insert({driver_info.host_id(), std::move(tpu_driver)});
274       }
275 
276       worker_count++;
277     }
278 
279     absl::flat_hash_set<std::tuple<int, int, int>> processed_chips;
280 
281     for (int driver_num = 0; driver_num < workers.size(); ++driver_num) {
282       SystemInfo driver_info;
283       drivers_[driver_num]->QuerySystemInfo(&driver_info);
284 
285       for (const auto& tpu_chip : driver_info.tpu_chip()) {
286         std::tuple<int, int, int> coord{tpu_chip.chip_coord().x(),
287                                         tpu_chip.chip_coord().y(),
288                                         tpu_chip.chip_coord().z()};
289         // We only want to add chips that we have not seen before if we are in a
290         // TPU pod slice, or we are only seeing local cores (e.g. we are
291         // connected to individual TPUs or we are in a test environment).
292         if (!processed_chips.contains(coord) ||
293             driver_info.core_count() == driver_info.local_core_size()) {
294           *(pod_info_.add_tpu_chip()) = tpu_chip;
295           processed_chips.insert(coord);
296         }
297       }
298 
299       *(pod_info_.mutable_cpu()) = driver_info.cpu();
300     }
301 
302     // Process all the unique chips that we have seen.
303     int core_count = 0;
304     for (auto& tpu_chip : *pod_info_.mutable_tpu_chip()) {
305       for (auto& tpu_core : *tpu_chip.mutable_core()) {
306         int current_core = tpu_core.id();
307         if (in_local_core_environment) {
308           current_core = core_count;
309         }
310 
311         core_to_driver_.insert(
312             {current_core, drivers_[tpu_chip.host_id()].get()});
313         core_to_driver_id_.insert({current_core, tpu_chip.host_id()});
314         core_to_driver_core_.insert({current_core, tpu_core.id()});
315 
316         tpu_core.set_id(current_core);
317         tpu_core.set_core_on_host_index(current_core);
318         *(pod_info_.add_local_core()) = tpu_core;
319 
320         core_count++;
321       }
322 
323       // We are setting host_id to zero because we want this to look like one
324       // host with many cores from the perspective of tpu_client.cc.
325       tpu_chip.set_host_id(0);
326     }
327 
328     pod_info_.set_chip_count(pod_info_.tpu_chip_size());
329     pod_info_.set_core_count(pod_info_.local_core_size());
330 
331     // We want this to look like one host with many TPU chips/cores connected.
332     pod_info_.set_host_count(1);
333     pod_info_.set_host_id(0);
334   }
335 
~PodTpuDriver()336   ~PodTpuDriver() override {
337     // TODO(frankchn): Unload all handles, and wait for all events to finish.
338   }
339 
QuerySystemInfo(SystemInfo * system_info)340   void QuerySystemInfo(SystemInfo* system_info) override {
341     *system_info = pod_info_;
342   }
343 
Reset()344   xla::Status Reset() override {
345     for (auto& driver : drivers_) {
346       TF_RETURN_IF_ERROR(driver.second->Reset());
347     }
348     return xla::Status::OK();
349   }
350 
Allocate(int32_t core_id,MemoryRegion region,int64_t num_bytes,absl::Span<Event * const> wait_for)351   std::unique_ptr<BufferHandle> Allocate(
352       int32_t core_id, MemoryRegion region, int64_t num_bytes,
353       absl::Span<Event* const> wait_for) override {
354     int64_t operation_id = GetOperationId();
355     auto deps = GetDependencyOperationIds(wait_for);
356 
357     ScheduleRequest(
358         operation_id,
359         [this, core_id, region, num_bytes,
360          operation_id]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
361           underlying_buffers_.insert(
362               {operation_id,
363                core_to_driver_[core_id]->Allocate(core_to_driver_core_[core_id],
364                                                   region, num_bytes, {})});
365           return underlying_buffers_[operation_id]->OnReady();
366         },
367         deps);
368 
369     return absl::make_unique<PodBufferHandle>(this, operation_id, num_bytes,
370                                               absl::nullopt, core_id);
371   }
372 
Allocate(int32_t core_id,MemoryRegion region,const xla::ShapeProto & shape,absl::Span<Event * const> wait_for)373   std::unique_ptr<BufferHandle> Allocate(
374       int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape,
375       absl::Span<Event* const> wait_for) override {
376     int64_t operation_id = GetOperationId();
377     auto deps = GetDependencyOperationIds(wait_for);
378 
379     ScheduleRequest(
380         operation_id,
381         [this, core_id, region, shape,
382          operation_id]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
383           underlying_buffers_.insert(
384               {operation_id,
385                core_to_driver_[core_id]->Allocate(core_to_driver_core_[core_id],
386                                                   region, shape, {})});
387           return underlying_buffers_[operation_id]->OnReady();
388         },
389         deps);
390 
391     return absl::make_unique<PodBufferHandle>(
392         this, operation_id, ComputeBytesFromShape(shape), shape, core_id);
393   }
394 
AllocateTuple(int32_t core_id,MemoryRegion region,absl::Span<BufferHandle * const> children,absl::Span<Event * const> wait_for)395   std::unique_ptr<BufferHandle> AllocateTuple(
396       int32_t core_id, MemoryRegion region,
397       absl::Span<BufferHandle* const> children,
398       absl::Span<Event* const> wait_for) override {
399     int64_t operation_id = GetOperationId();
400     auto deps = GetDependencyOperationIds(wait_for);
401 
402     std::vector<int64_t> children_ids;
403     for (int i = 0; i < children.size(); ++i) {
404       auto child_op_id =
405           static_cast<PodBufferHandle* const>(children[i])->operation_id();
406       deps.insert(child_op_id);
407       children_ids.push_back(child_op_id);
408     }
409 
410     ScheduleRequest(
411         operation_id,
412         [this, core_id, region, children_ids,
413          operation_id]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_)
414             -> std::shared_ptr<Event> {
415           std::vector<BufferHandle*> child_buffers;
416           child_buffers.reserve(children_ids.size());
417           for (int i = 0; i < children_ids.size(); ++i) {
418             CHECK_EXISTS_OR_RETURN(underlying_buffers_, children_ids[i],
419                                    operation_id);
420             child_buffers.push_back(underlying_buffers_[children_ids[i]].get());
421           }
422 
423           underlying_buffers_.insert(
424               {operation_id,
425                core_to_driver_[core_id]->AllocateTuple(
426                    core_to_driver_core_[core_id], region, child_buffers, {})});
427           return underlying_buffers_[operation_id]->OnReady();
428         },
429         deps);
430 
431     return absl::make_unique<PodBufferHandle>(this, operation_id, 0,
432                                               absl::nullopt, core_id);
433   }
434 
Deallocate(std::unique_ptr<BufferHandle> handle,absl::Span<Event * const> wait_for)435   std::shared_ptr<Event> Deallocate(
436       std::unique_ptr<BufferHandle> handle,
437       absl::Span<Event* const> wait_for) override {
438     int64_t operation_id = GetOperationId();
439     auto deps = GetDependencyOperationIds(wait_for);
440     deps.insert(static_cast<PodBufferHandle*>(handle.get())->operation_id());
441 
442     auto op_id = static_cast<PodBufferHandle*>(handle.get())->operation_id();
443     auto core_id = static_cast<PodBufferHandle*>(handle.get())->core_id();
444 
445     ScheduleRequest(
446         operation_id,
447         [this, operation_id, op_id,
448          core_id]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr<Event> {
449           CHECK_EXISTS_OR_RETURN(underlying_buffers_, op_id, operation_id);
450 
451           auto buf_iter = underlying_buffers_.find(op_id);
452           auto underlying_hn = std::move(buf_iter->second);
453           underlying_buffers_.erase(buf_iter);
454 
455           return core_to_driver_[core_id]->Deallocate(std::move(underlying_hn),
456                                                       {});
457         },
458         deps);
459 
460     return std::make_shared<PodEvent>(this, operation_id);
461   }
462 
TransferToDevice(const void * src,BufferHandle * dst,absl::Span<Event * const> wait_for)463   std::shared_ptr<Event> TransferToDevice(
464       const void* src, BufferHandle* dst,
465       absl::Span<Event* const> wait_for) override {
466     int64_t operation_id = GetOperationId();
467     auto deps = GetDependencyOperationIds(wait_for);
468     deps.insert(static_cast<PodBufferHandle*>(dst)->operation_id());
469 
470     auto op_id = static_cast<PodBufferHandle*>(dst)->operation_id();
471     auto core_id = static_cast<PodBufferHandle*>(dst)->core_id();
472 
473     ScheduleRequest(
474         operation_id,
475         [this, src, operation_id, op_id,
476          core_id]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr<Event> {
477           CHECK_EXISTS_OR_RETURN(underlying_buffers_, op_id, operation_id);
478 
479           auto buf_iter = underlying_buffers_.find(op_id);
480           return core_to_driver_[core_id]->TransferToDevice(
481               src, buf_iter->second.get(), {});
482         },
483         deps);
484 
485     return std::make_shared<PodEvent>(this, operation_id);
486   }
487 
TransferFromDevice(const BufferHandle * src,void * dst,absl::Span<Event * const> wait_for)488   std::shared_ptr<Event> TransferFromDevice(
489       const BufferHandle* src, void* dst,
490       absl::Span<Event* const> wait_for) override {
491     int64_t operation_id = GetOperationId();
492     auto deps = GetDependencyOperationIds(wait_for);
493     deps.insert(static_cast<const PodBufferHandle*>(src)->operation_id());
494 
495     auto op_id = static_cast<const PodBufferHandle*>(src)->operation_id();
496     auto core_id = static_cast<const PodBufferHandle*>(src)->core_id();
497 
498     ScheduleRequest(
499         operation_id,
500         [this, dst, operation_id, op_id,
501          core_id]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr<Event> {
502           CHECK_EXISTS_OR_RETURN(underlying_buffers_, op_id, operation_id);
503           auto buf_iter = underlying_buffers_.find(op_id);
504           return core_to_driver_[core_id]->TransferFromDevice(
505               buf_iter->second.get(), dst, {});
506         },
507         deps);
508 
509     return std::make_shared<PodEvent>(this, operation_id);
510   }
511 
TransferFromDeviceToDevice(const BufferHandle * src,BufferHandle * dst,absl::Span<Event * const> wait_for)512   std::shared_ptr<Event> TransferFromDeviceToDevice(
513       const BufferHandle* src, BufferHandle* dst,
514       absl::Span<Event* const> wait_for) override {
515     auto src_core_id = static_cast<const PodBufferHandle*>(src)->core_id();
516     auto dst_core_id = static_cast<PodBufferHandle*>(dst)->core_id();
517 
518     auto src_driver_id = core_to_driver_id_[src_core_id];
519     auto dst_driver_id = core_to_driver_id_[dst_core_id];
520 
521     if (src_driver_id == dst_driver_id) {
522       // They are in the same host, we can schedule it normally
523       int64_t operation_id = GetOperationId();
524       auto deps = GetDependencyOperationIds(wait_for);
525       deps.insert(static_cast<const PodBufferHandle*>(src)->operation_id());
526       deps.insert(static_cast<PodBufferHandle*>(dst)->operation_id());
527 
528       auto src_op_id = static_cast<const PodBufferHandle*>(src)->operation_id();
529       auto dst_op_id = static_cast<PodBufferHandle*>(dst)->operation_id();
530 
531       ScheduleRequest(
532           operation_id,
533           [this, operation_id, src_op_id, dst_op_id, dst_core_id]()
534               TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr<Event> {
535                 CHECK_EXISTS_OR_RETURN(underlying_buffers_, src_op_id,
536                                        operation_id);
537                 CHECK_EXISTS_OR_RETURN(underlying_buffers_, dst_op_id,
538                                        operation_id);
539 
540                 auto src_iter = underlying_buffers_.find(src_op_id);
541                 auto dst_iter = underlying_buffers_.find(dst_op_id);
542                 return core_to_driver_[dst_core_id]->TransferFromDeviceToDevice(
543                     src_iter->second.get(), dst_iter->second.get(), {});
544               },
545           deps);
546       return std::make_shared<PodEvent>(this, operation_id);
547     } else {
548       // src and dst are on different hosts, we have to bounce through us.
549       auto dst_size = dst->size_in_bytes();
550       char* host_buf = new char[dst_size];
551 
552       auto src_event = TransferFromDevice(src, host_buf, wait_for);
553       auto dst_event = TransferToDevice(host_buf, dst, {src_event.get()});
554       dst_event->AddCallback(
555           [src_event, host_buf](xla::Status status) { delete[] host_buf; });
556       return dst_event;
557     }
558   }
559 
CompileProgram(const xla::HloProto & source,int32_t num_replicas,absl::Span<Event * const> wait_for)560   std::unique_ptr<CompiledProgramHandle> CompileProgram(
561       const xla::HloProto& source, int32_t num_replicas,
562       absl::Span<Event* const> wait_for) override {
563     int64_t operation_id = GetOperationId();
564     auto deps = GetDependencyOperationIds(wait_for);
565 
566     ScheduleRequest(
567         operation_id,
568         [this, operation_id, source,
569          num_replicas]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
570           auto cph_iterator =
571               underlying_cph_
572                   .insert(
573                       {operation_id,
574                        std::vector<std::unique_ptr<CompiledProgramHandle>>()})
575                   .first;
576 
577           std::vector<std::shared_ptr<Event>> collected_events;
578           for (int i = 0; i < drivers_.size(); ++i) {
579             auto current_cph =
580                 drivers_[i]->CompileProgram(source, num_replicas, {});
581             cph_iterator->second.push_back(std::move(current_cph));
582             collected_events.push_back(cph_iterator->second[i]->OnReady());
583           }
584           return std::make_shared<CombinedEvent>(this, operation_id,
585                                                  collected_events);
586         },
587         deps);
588 
589     return absl::make_unique<PodCompiledProgramHandle>(this, operation_id);
590   }
591 
LoadProgram(int32_t core_id,const CompiledProgramHandle * handle,absl::Span<Event * const> wait_for)592   std::unique_ptr<LoadedProgramHandle> LoadProgram(
593       int32_t core_id, const CompiledProgramHandle* handle,
594       absl::Span<Event* const> wait_for) override {
595     int64_t operation_id = GetOperationId();
596     auto deps = GetDependencyOperationIds(wait_for);
597     deps.insert(
598         static_cast<const PodCompiledProgramHandle*>(handle)->operation_id());
599     auto cph_op_id =
600         static_cast<const PodCompiledProgramHandle*>(handle)->operation_id();
601 
602     ScheduleRequest(
603         operation_id,
604         [this, operation_id, cph_op_id,
605          core_id]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr<Event> {
606           CHECK_EXISTS_OR_RETURN(underlying_cph_, cph_op_id, operation_id);
607           auto cph_iter = underlying_cph_.find(cph_op_id);
608 
609           underlying_lph_.insert(
610               {operation_id,
611                core_to_driver_[core_id]->LoadProgram(
612                    core_to_driver_core_[core_id],
613                    cph_iter->second[core_to_driver_id_[core_id]].get(), {})});
614 
615           return underlying_lph_[operation_id]->OnReady();
616         },
617         deps);
618 
619     return absl::make_unique<PodLoadedProgramHandle>(this, operation_id,
620                                                      core_id);
621   }
622 
UnloadProgram(std::unique_ptr<LoadedProgramHandle> handle,absl::Span<Event * const> wait_for)623   std::shared_ptr<Event> UnloadProgram(
624       std::unique_ptr<LoadedProgramHandle> handle,
625       absl::Span<Event* const> wait_for) override {
626     int64_t operation_id = GetOperationId();
627     auto deps = GetDependencyOperationIds(wait_for);
628     deps.insert(
629         static_cast<PodLoadedProgramHandle*>(handle.get())->operation_id());
630     auto op_id =
631         static_cast<PodLoadedProgramHandle*>(handle.get())->operation_id();
632     auto core_id =
633         static_cast<PodLoadedProgramHandle*>(handle.get())->core_id();
634 
635     ScheduleRequest(
636         operation_id,
637         [this, operation_id, op_id,
638          core_id]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) -> std::shared_ptr<Event> {
639           CHECK_EXISTS_OR_RETURN(underlying_lph_, op_id, operation_id);
640           auto lph_iter = underlying_lph_.find(op_id);
641           auto event = core_to_driver_[core_id]->UnloadProgram(
642               std::move(lph_iter->second), {});
643           underlying_lph_.erase(lph_iter);
644 
645           return event;
646         },
647         deps);
648 
649     return std::make_shared<PodEvent>(this, operation_id);
650   }
651 
ExecuteProgram(LoadedProgramHandle * program,absl::Span<BufferHandle * const> inputs,absl::Span<BufferHandle * const> outputs,const xla::DeviceAssignmentProto & device_assignment,absl::Span<Event * const> wait_for)652   std::shared_ptr<Event> ExecuteProgram(
653       LoadedProgramHandle* program, absl::Span<BufferHandle* const> inputs,
654       absl::Span<BufferHandle* const> outputs,
655       const xla::DeviceAssignmentProto& device_assignment,
656       absl::Span<Event* const> wait_for) override {
657     int64_t operation_id = GetOperationId();
658 
659     auto deps = GetDependencyOperationIds(wait_for);
660     deps.insert(static_cast<PodLoadedProgramHandle*>(program)->operation_id());
661 
662     auto op_id = static_cast<PodLoadedProgramHandle*>(program)->operation_id();
663     auto core_id = static_cast<PodLoadedProgramHandle*>(program)->core_id();
664 
665     std::vector<int64_t> input_op_ids;
666     std::vector<int64_t> output_op_ids;
667 
668     for (auto* input : inputs) {
669       auto input_dep =
670           static_cast<PodBufferHandle* const>(input)->operation_id();
671       input_op_ids.push_back(input_dep);
672       deps.insert(input_dep);
673     }
674     for (auto* output : outputs) {
675       auto output_dep =
676           static_cast<PodBufferHandle* const>(output)->operation_id();
677       output_op_ids.push_back(output_dep);
678       deps.insert(output_dep);
679     }
680 
681     ScheduleRequest(
682         operation_id,
683         [this, operation_id, core_id, op_id, input_op_ids, output_op_ids,
684          device_assignment]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_)
685             -> std::shared_ptr<Event> {
686           std::vector<BufferHandle*> underlying_inputs;
687           std::vector<BufferHandle*> underlying_outputs;
688 
689           underlying_inputs.reserve(input_op_ids.size());
690           for (auto input_op_id : input_op_ids) {
691             CHECK_EXISTS_OR_RETURN(underlying_buffers_, input_op_id,
692                                    operation_id);
693             underlying_inputs.push_back(underlying_buffers_[input_op_id].get());
694           }
695           underlying_outputs.reserve(output_op_ids.size());
696           for (auto output_op_id : output_op_ids) {
697             CHECK_EXISTS_OR_RETURN(underlying_buffers_, output_op_id,
698                                    operation_id);
699             underlying_outputs.push_back(
700                 underlying_buffers_[output_op_id].get());
701           }
702 
703           CHECK_EXISTS_OR_RETURN(underlying_lph_, op_id, operation_id);
704           LoadedProgramHandle* handle = underlying_lph_[op_id].get();
705           return core_to_driver_[core_id]->ExecuteProgram(
706               handle, underlying_inputs, underlying_outputs, device_assignment,
707               {});
708         },
709         deps);
710 
711     return std::make_shared<PodEvent>(this, operation_id);
712   }
713 
GetLinearizer()714   std::unique_ptr<TpuLinearizer> GetLinearizer() override {
715     return drivers_[0]->GetLinearizer();
716   }
717 
718   // Helper methods for Event scheduling
719 
WaitForEvent(int64_t event_id,absl::Duration duration)720   absl::optional<Status> WaitForEvent(int64_t event_id, absl::Duration duration)
721       TF_LOCKS_EXCLUDED(mu_) {
722     std::shared_ptr<Event> underlying_event;
723 
724     {
725       absl::MutexLock l(&mu_);
726       auto event = events_.find(event_id);
727 
728       if (event == events_.end()) {
729         auto event_status = abnormal_event_status_.find(event_id);
730         if (event_status == abnormal_event_status_.end()) {
731           return Status::OK();
732         } else {
733           return event_status->second;
734         }
735       }
736 
737       auto done = [this, event_id]() {
738         mu_.AssertHeld();
739         // The event was either completed and erased from the map or we have
740         // an underlying event available to us.
741         return events_.count(event_id) == 0 ||
742                (events_[event_id]->underlying_event != nullptr &&
743                 events_[event_id]->underlying_event.use_count() != 0);
744       };
745 
746       auto status = mu_.AwaitWithTimeout(absl::Condition(&done), duration);
747       if (!status) {
748         return absl::nullopt;
749       }
750 
751       if (events_.count(event_id) > 0) {
752         underlying_event = events_[event_id]->underlying_event;
753       } else {
754         underlying_event = nullptr;
755       }
756     }
757 
758     // Wait for the underlying event without holding on to the event_lock_, or
759     // else incoming events will not be processed.
760     if (underlying_event != nullptr) {
761       return underlying_event->AwaitWithTimeout(duration);
762     } else {
763       absl::MutexLock l(&mu_);
764       auto event_status = abnormal_event_status_.find(event_id);
765       if (event_status == abnormal_event_status_.end()) {
766         return Status::OK();
767       } else {
768         return event_status->second;
769       }
770     }
771   }
772 
AddCallbackForEvent(int64_t event_id,std::function<void (Status)> fn)773   void AddCallbackForEvent(int64_t event_id, std::function<void(Status)> fn)
774       TF_LOCKS_EXCLUDED(mu_) {
775     absl::MutexLock l(&mu_);
776     auto event = events_.find(event_id);
777 
778     if (event == events_.end()) {
779       auto event_status = abnormal_event_status_.find(event_id);
780       if (event_status == abnormal_event_status_.end()) {
781         fn(Status::OK());
782       } else {
783         fn(event_status->second);
784       }
785     } else {
786       if (event->second->underlying_event != nullptr &&
787           event->second->underlying_event.use_count() != 0) {
788         event->second->underlying_event->AddCallback(fn);
789       } else {
790         event->second->callbacks.push_back(std::move(fn));
791       }
792     }
793   }
794 
GetCompiledProgramShape(int64_t op_id,xla::ProgramShapeProto * program_shape)795   xla::Status GetCompiledProgramShape(int64_t op_id,
796                                       xla::ProgramShapeProto* program_shape)
797       TF_LOCKS_EXCLUDED(mu_) {
798     absl::MutexLock l(&mu_);
799 
800     auto done = [this, op_id]() {
801       mu_.AssertHeld();
802       return underlying_cph_.contains(op_id);
803     };
804     mu_.Await(absl::Condition(&done));
805 
806     return underlying_cph_[op_id][0]->program_shape(program_shape);
807   }
808 
 private:
  // NOTE(review): held by reference, so the TpuDriverConfig this driver is
  // built from must outlive it — confirm at construction sites.
  const TpuDriverConfig& config_;
  // gRPC channel credentials this driver was constructed with.
  std::shared_ptr<::grpc::ChannelCredentials> creds_;

  // Underlying per-host drivers, plus lookup tables from a pod-wide core id
  // to (presumably) the key of its driver in drivers_, that driver's pointer,
  // and the core's id within that driver — populated outside this view.
  // NOTE(review): these maps are read without holding mu_.
  absl::flat_hash_map<int32_t, std::unique_ptr<TpuDriver>> drivers_;
  absl::flat_hash_map<int32_t, int32_t> core_to_driver_id_;
  absl::flat_hash_map<int32_t, TpuDriver*> core_to_driver_;
  absl::flat_hash_map<int32_t, int32_t> core_to_driver_core_;
  SystemInfo pod_info_;

  // Guards the underlying handle tables and the event bookkeeping below.
  absl::Mutex mu_;

  // Underlying handles keyed by the pod-level operation id that produced
  // them. underlying_cph_ stores a vector of compiled program handles per
  // operation; GetCompiledProgramShape consults only element [0].
  absl::flat_hash_map<int64_t, std::unique_ptr<BufferHandle>>
      underlying_buffers_ ABSL_GUARDED_BY(mu_);
  absl::flat_hash_map<int64_t,
                      std::vector<std::unique_ptr<CompiledProgramHandle>>>
      underlying_cph_ ABSL_GUARDED_BY(mu_);
  absl::flat_hash_map<int64_t, std::unique_ptr<LoadedProgramHandle>>
      underlying_lph_ ABSL_GUARDED_BY(mu_);

  // Pod operations that have not finished yet, ordered by operation id.
  // EventCompleted erases an entry once its underlying event completes.
  absl::btree_map<int64_t, std::unique_ptr<EventInFlight>> events_
      ABSL_GUARDED_BY(mu_);
  // Non-OK final statuses of finished operations. An operation found in
  // neither events_ nor this map is reported as Status::OK().
  absl::flat_hash_map<int64_t, Status> abnormal_event_status_
      ABSL_GUARDED_BY(mu_);

  // Source of pod-level operation ids; see GetOperationId.
  std::atomic<int64_t> operation_id_counter_{0};

  // Worker thread on which ScheduleRequest's completion callbacks run
  // EventCompleted.
  WorkerThread event_thread_;
837 
GetOperationId()838   int64_t GetOperationId() { return operation_id_counter_++; }
839 
GetDependencyOperationIds(absl::Span<Event * const> wait_for)840   absl::flat_hash_set<int64_t> GetDependencyOperationIds(
841       absl::Span<Event* const> wait_for) {
842     absl::flat_hash_set<int64_t> deps;
843     for (auto* event : wait_for) {
844       deps.insert(static_cast<PodEvent* const>(event)->operation_id());
845     }
846     return deps;
847   }
848 
849   // EventCompleted is executed on the event_thread_ worker thread. We want
850   // to propagate the fact that the event is completed to any subsequent events
851   // that might depend on this event.
EventCompleted(int64_t event_id,Status status)852   void EventCompleted(int64_t event_id, Status status) TF_LOCKS_EXCLUDED(mu_) {
853     absl::MutexLock l(&mu_);
854 
855     absl::btree_map<int64_t, std::unique_ptr<EventInFlight>>::iterator
856         curr_event;
857     if (!status.ok()) abnormal_event_status_.insert({event_id, status});
858     curr_event = events_.find(event_id);
859 
860     DCHECK(curr_event->second->callbacks.empty());
861     DCHECK(curr_event->second->incomplete_deps.empty());
862 
863     for (auto& event : events_) {
864       event.second->incomplete_deps.erase(event_id);
865       // The if statement conditions on both
866       //  - all previous events have completed (incomplete_deps.empty())
867       //  - the op creating this event has not been called yet
868       //    (event.second.create_fn != nullptr)
869       // We call the create_fn that creates the event and adds any relevant
870       // callbacks to the actual event, before setting create_fn to nullptr
871       // to indicate that it has already been called
872       if (event.second->incomplete_deps.empty() &&
873           event.second->create_fn != nullptr) {
874         // We were the last unfilled dependency, all other dependencies are
875         // filled. We can now fire the create function.
876         event.second->underlying_event = event.second->create_fn();
877         for (auto& fn : event.second->callbacks) {
878           event.second->underlying_event->AddCallback(std::move(fn));
879         }
880         event.second->callbacks.clear();
881         event.second->create_fn = nullptr;
882       }
883     }
884 
885     // We erase the current event to signal that it has finished.
886     events_.erase(curr_event);
887   }
888 
ScheduleRequest(int64_t operation_id,std::function<std::shared_ptr<Event> (void)> fn,const absl::flat_hash_set<int64_t> & deps)889   void ScheduleRequest(int64_t operation_id,
890                        std::function<std::shared_ptr<Event>(void)> fn,
891                        const absl::flat_hash_set<int64_t>& deps)
892       TF_LOCKS_EXCLUDED(mu_) {
893     absl::MutexLock l(&mu_);
894     absl::btree_map<int64_t, std::unique_ptr<EventInFlight>>::iterator event;
895     absl::flat_hash_set<int64_t> incomplete_deps;
896 
897     event = events_.insert({operation_id, absl::make_unique<EventInFlight>()})
898                 .first;
899     for (const auto& dep : deps) {
900       if (events_.count(dep) > 0) incomplete_deps.insert(dep);
901     }
902 
903     if (incomplete_deps.empty()) {
904       // All dependencies have been fulfilled, we execute the request
905       // immediately and add a callback to inform our event fulfilled thread
906       // when it is done.
907       event->second->create_fn = nullptr;
908       event->second->underlying_event = fn();
909       event->second->underlying_event->AddCallback(
910           [this, operation_id](Status status) {
911             event_thread_.Schedule([this, operation_id, status]() {
912               EventCompleted(operation_id, status);
913             });
914           });
915     } else {
916       // There are some dependencies that are not yet fulfilled. We attach
917       // the request to the event, and will execute it in the EventFulfilled
918       // worker thread when all its dependencies are fulfilled.
919       event->second->create_fn = std::move(fn);
920       event->second->incomplete_deps = std::move(incomplete_deps);
921       event->second->callbacks.push_back([this, operation_id](Status status) {
922         event_thread_.Schedule([this, operation_id, status]() {
923           EventCompleted(operation_id, status);
924         });
925       });
926     }
927   }
928 
929   template <typename T>
CheckHandleExists(absl::flat_hash_map<int64_t,T> & container,int64_t target_op_id,int64_t operation_id)930   std::shared_ptr<Event> CheckHandleExists(
931       absl::flat_hash_map<int64_t, T>& container, int64_t target_op_id,
932       int64_t operation_id) {
933     if (container.count(target_op_id) == 0) {
934       return std::make_shared<ErrorEvent>(
935           this, operation_id,
936           tensorflow::errors::InvalidArgument("Handle ", target_op_id,
937                                               " does not exist."));
938     }
939     return nullptr;
940   }
941 };
942 
Await()943 xla::Status PodEvent::Await() {
944   return driver_->WaitForEvent(operation_id_, absl::InfiniteDuration()).value();
945 }
946 
AwaitWithTimeout(absl::Duration duration)947 absl::optional<xla::Status> PodEvent::AwaitWithTimeout(
948     absl::Duration duration) {
949   return driver_->WaitForEvent(operation_id_, duration);
950 }
951 
AddCallback(std::function<void (Status)> callback)952 void PodEvent::AddCallback(std::function<void(Status)> callback) {
953   driver_->AddCallbackForEvent(operation_id_, std::move(callback));
954 }
955 
CreatePodTpuDriver(const TpuDriverConfig & config,std::shared_ptr<::grpc::ChannelCredentials> creds)956 xla::StatusOr<std::unique_ptr<TpuDriver>> CreatePodTpuDriver(
957     const TpuDriverConfig& config,
958     std::shared_ptr<::grpc::ChannelCredentials> creds) {
959   return std::unique_ptr<TpuDriver>(new PodTpuDriver(config, creds));
960 }
961 
program_shape(xla::ProgramShapeProto * program_shape)962 xla::Status PodCompiledProgramHandle::program_shape(
963     xla::ProgramShapeProto* program_shape) {
964   return driver_->GetCompiledProgramShape(operation_id(), program_shape);
965 }
966 
967 }  // namespace
968 
969 REGISTER_TPU_DRIVER(kPodTpuDriverPrefix,
970                     [](const TpuDriverConfig& config)
__anon2989b0911202(const TpuDriverConfig& config) 971                         -> xla::StatusOr<std::unique_ptr<TpuDriver>> {
972                       return CreatePodTpuDriver(
973                           config,
974                           ::grpc::InsecureChannelCredentials());  // NOLINT
975                     });
976 
977 }  // namespace tpu_driver
978