1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/pjrt/tpu_client.h"
17 
18 #include <memory>
19 #include <vector>
20 
21 #include "absl/container/inlined_vector.h"
22 #include "absl/memory/memory.h"
23 #include "absl/status/status.h"
24 #include "tensorflow/compiler/xla/client/client_library.h"
25 #include "tensorflow/compiler/xla/pjrt/local_device_state.h"
26 #include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
27 #include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
28 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
29 #include "tensorflow/compiler/xla/service/tpu_computation_placer.h"
30 #include "tensorflow/compiler/xla/shape.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/status.h"
33 #include "tensorflow/compiler/xla/util.h"
34 #include "tensorflow/core/platform/casts.h"
35 #include "tensorflow/core/platform/errors.h"
36 #include "tensorflow/stream_executor/device_memory.h"
37 #include "tensorflow/stream_executor/lib/statusor.h"
38 #include "tensorflow/stream_executor/stream.h"
39 #include "tensorflow/stream_executor/tpu/tpu_executable_interface.h"
40 #include "tensorflow/stream_executor/tpu/tpu_executor_interface.h"
41 #include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
42 #include "tensorflow/stream_executor/tpu/tpu_stream.h"
43 
44 namespace tf_tpu = tensorflow::tpu;
45 
46 namespace xla {
47 namespace {
48 
49 class TpuDeviceState : public LocalDeviceState {
50  public:
51   TpuDeviceState(se::StreamExecutor* executor, LocalClient* client,
52                  bool asynchronous);
53 
54   Status ThenMemcpyDeviceToDevice(se::Stream* transfer_stream,
55                                   se::Stream* dst_stream,
56                                   se::DeviceMemoryBase src_buffer,
57                                   se::DeviceMemoryBase dst_buffer) override;
58 };
59 
TpuDeviceState(se::StreamExecutor * executor,LocalClient * client,bool asynchronous)60 TpuDeviceState::TpuDeviceState(se::StreamExecutor* executor,
61                                LocalClient* client, bool asynchronous)
62     : LocalDeviceState(executor, client, LocalDeviceState::kAsynchronous,
63                        asynchronous,
64                        /*allow_event_reuse=*/false) {}
65 
ThenMemcpyDeviceToDevice(se::Stream * transfer_stream,se::Stream * dst_stream,se::DeviceMemoryBase src_buffer,se::DeviceMemoryBase dst_buffer)66 Status TpuDeviceState::ThenMemcpyDeviceToDevice(
67     se::Stream* transfer_stream, se::Stream* dst_stream,
68     se::DeviceMemoryBase src_buffer, se::DeviceMemoryBase dst_buffer) {
69   auto* transfer_tpu_stream = tensorflow::down_cast<tf_tpu::TpuStream*>(
70       transfer_stream->implementation());
71   TF_RETURN_IF_ERROR(transfer_tpu_stream->EnqueueOnTpuDeviceSendRecvLocal(
72       src_buffer, dst_buffer));
73   return Status::OK();
74 }
75 
76 class PjRtTpuClient : public PjRtStreamExecutorClient {
77  public:
78   PjRtTpuClient(LocalClient* client,
79                 std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,
80                 int task_id);
81 
82   StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
83       int num_replicas, int num_partitions) const override;
84 
EnqueueD2DTransfersOnSrcStream() const85   bool EnqueueD2DTransfersOnSrcStream() const override { return false; }
86 
87   StatusOr<absl::optional<std::string>> ExecutableFingerprint(
88       const PjRtExecutable& executable) const override;
89 };
90 
PjRtTpuClient(LocalClient * client,std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices,int task_id)91 PjRtTpuClient::PjRtTpuClient(
92     LocalClient* client,
93     std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int task_id)
94     : PjRtStreamExecutorClient(kTpuName, client, std::move(devices), task_id,
95                                /*allocator=*/nullptr,
96                                /*host_memory_allocator=*/nullptr,
97                                /*should_stage_host_to_device_transfers=*/false,
98                                /*gpu_run_options=*/nullptr) {}
99 
GetDefaultDeviceAssignment(int num_replicas,int num_partitions) const100 StatusOr<DeviceAssignment> PjRtTpuClient::GetDefaultDeviceAssignment(
101     int num_replicas, int num_partitions) const {
102   tf_tpu::TpuPlatformInterface* platform =
103       tf_tpu::TpuPlatformInterface::GetRegisteredPlatform();
104   tf_tpu::TpuHostLocationExternal host = platform->GetTpuHostLocation();
105   int num_local_devices = host.Cores(kTensorCore).size();
106   if (num_replicas * num_partitions <= num_local_devices) {
107     return tf_tpu::TpuComputationPlacer::AssignLocalDevices(host, num_replicas,
108                                                             num_partitions);
109   }
110   // Fallback to default global device assignment if we can't run locally.
111   return PjRtStreamExecutorClient::GetDefaultDeviceAssignment(num_replicas,
112                                                               num_partitions);
113 }
114 
ExecutableFingerprint(const PjRtExecutable & executable) const115 StatusOr<absl::optional<std::string>> PjRtTpuClient::ExecutableFingerprint(
116     const PjRtExecutable& executable) const {
117   if (executable.client() != this) {
118     return InvalidArgument(
119         "Passed executable from different client (platform '%s') to "
120         "PjRtTpuClient::ExecutableFingerprint",
121         executable.client()->platform_name());
122   }
123   if (executable.num_partitions() > 1) {
124     LOG(INFO) << "ExecutableFingerprint not fully implemented for MPMD "
125                  "executables, fingerprint may not be unique.";
126   }
127   xla::TpuExecutableInterface* tpu_executable =
128       tensorflow::down_cast<xla::TpuExecutableInterface*>(
129           tensorflow::down_cast<const PjRtStreamExecutorExecutable*>(
130               &executable)
131               ->executables()[0]
132               ->executable());
133   return absl::optional<std::string>(tpu_executable->fingerprint());
134 }
135 
GetTpuDevices(LocalClient * client,std::vector<std::unique_ptr<LocalDeviceState>> local_device_states)136 StatusOr<std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>> GetTpuDevices(
137     LocalClient* client,
138     std::vector<std::unique_ptr<LocalDeviceState>> local_device_states) {
139   std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
140   tf_tpu::TpuTopologyExternal topology =
141       tf_tpu::TpuPlatformInterface::GetRegisteredPlatform()->topology();
142 
143   std::map<int, int> core_id_to_device_ordinal;
144   for (int i = 0; i < client->device_count(); ++i) {
145     se::StreamExecutor* executor =
146         client->backend().stream_executor(i).ValueOrDie();
147     tf_tpu::TpuExecutorInterface* tpu_executor =
148         tensorflow::down_cast<tf_tpu::TpuExecutorInterface*>(
149             executor->implementation());
150     core_id_to_device_ordinal[tpu_executor->GetCoreLocationExternal().Id()] = i;
151   }
152 
153   for (const tf_tpu::TpuCoreLocationExternal& core :
154        topology.cores(TpuCoreTypeEnum::kTensorCore)) {
155     auto it = core_id_to_device_ordinal.find(core.Id());
156     int device_ordinal =
157         (it != core_id_to_device_ordinal.end()) ? it->second : -1;
158     int task_id = topology.IdForHost(core.host_coordinates());
159     const tf_tpu::TpuDimensionsExternal coords = core.chip_coordinates();
160     std::array<int, 3> coords_array = {coords.x, coords.y, coords.z};
161     std::unique_ptr<LocalDeviceState> local_device_state;
162     if (device_ordinal >= 0) {
163       local_device_state = std::move(local_device_states[device_ordinal]);
164     }
165     auto device = absl::make_unique<PjRtTpuDevice>(
166         core, std::move(local_device_state), task_id, coords_array,
167         std::string(tf_tpu::TpuVersionEnumToString(topology.version())));
168     devices.push_back(std::move(device));
169   }
170   return devices;
171 }
172 
173 }  // namespace
174 
GetTpuClient(bool asynchronous,absl::Duration init_retry_timeout)175 StatusOr<std::shared_ptr<PjRtClient>> GetTpuClient(
176     bool asynchronous, absl::Duration init_retry_timeout) {
177   tf_tpu::TpuPlatformInterface* platform =
178       tf_tpu::TpuPlatformInterface::GetRegisteredPlatform(
179           /*initialize_platform=*/true, /*num_tries=*/1);
180   if (platform == nullptr) {
181     return InvalidArgument("TpuPlatform is not available.");
182   }
183   // NOTE: We retry in a loop since some pod failures are transient (e.g. some
184   // RPCs may timeout waiting for other hosts to come up, but will succeed
185   // at a later point if retried).
186   auto start = absl::Now();
187   // TODO(b/165870356): TpuPlatform::Initialized() always returns true!
188   auto status = platform->Initialize({});
189   while (!platform->Initialized()) {
190     status = platform->Initialize({});
191     if (!status.ok()) {
192       LOG(ERROR) << "Platform initialization failed: " << status;
193       if ((absl::Now() - start) >= init_retry_timeout) {
194         return status;
195       }
196     }
197   }
198   if (platform->VisibleDeviceCount() <= 0) {
199     return InvalidArgument("No TPU devices found.");
200   }
201   LocalClientOptions options;
202   options.set_platform(platform);
203   TF_ASSIGN_OR_RETURN(LocalClient * client,
204                       ClientLibrary::GetOrCreateLocalClient(options));
205 
206   std::vector<std::unique_ptr<LocalDeviceState>> local_device_states;
207   local_device_states.reserve(client->device_count());
208   for (int i = 0; i < client->device_count(); ++i) {
209     se::StreamExecutor* executor =
210         client->backend().stream_executor(i).ValueOrDie();
211     local_device_states.push_back(
212         absl::make_unique<TpuDeviceState>(executor, client, asynchronous));
213   }
214 
215   TF_ASSIGN_OR_RETURN(auto devices,
216                       GetTpuDevices(client, std::move(local_device_states)));
217   int task_id = platform->GetTpuHostLocation().Id();
218 
219   return std::shared_ptr<PjRtClient>(
220       absl::make_unique<PjRtTpuClient>(client, std::move(devices), task_id));
221 }
222 
223 }  // namespace xla
224