1 // Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // ==============================================================================
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_TPU_DRIVER_H_
17 #define TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_TPU_DRIVER_H_
18 
19 #include <complex>
20 #include <cstdint>
21 #include <functional>
22 #include <memory>
23 #include <string>
24 #include <vector>
25 
26 #include "absl/container/flat_hash_map.h"
27 #include "absl/container/inlined_vector.h"
28 #include "absl/synchronization/mutex.h"
29 #include "absl/types/optional.h"
30 #include "absl/types/span.h"
31 #include "tensorflow/compiler/xla/python/tpu_driver/platform/external/compat.h"
32 #include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h"
33 #include "tensorflow/compiler/xla/service/hlo.pb.h"
34 #include "tensorflow/compiler/xla/status.h"
35 #include "tensorflow/compiler/xla/statusor.h"
36 #include "tensorflow/compiler/xla/xla_data.pb.h"
37 #include "tensorflow/core/platform/logging.h"
38 
39 // This API is EXPERIMENTAL and under active development. It is subject to
40 // change without notice.
41 
42 namespace tpu_driver {
43 
// Returns the size in bytes of a buffer holding the given shape in ordinary
// (row-major, un-linearized) host layout. Defined in the corresponding .cc.
int64_t ComputeBytesFromShape(const xla::ShapeProto& shape);
45 
// Represents the deferred completion of a scheduled operation.
//
// Events may be blocked on, or used as `wait_for` arguments to enforce
// inter-operation dependencies.
class Event {
 public:
  virtual ~Event() {}

  // Blocks until the event completes and returns the result status.
  virtual xla::Status Await() = 0;

  // Blocks for at most `duration`. Returns an empty optional if the wait
  // times out before the event completes; otherwise returns the result
  // status, as Await() would.
  virtual absl::optional<xla::Status> AwaitWithTimeout(
      absl::Duration duration) = 0;

  // Registers `callback` to be invoked with the completion status once the
  // event finishes. If the event is already done, the callback is called
  // immediately.
  virtual void AddCallback(std::function<void(xla::Status)> callback) = 0;
};
63 
// Represents a device memory allocation.
class BufferHandle {
 public:
  virtual ~BufferHandle() {}

  // This event completes after the device memory is actually allocated.
  //
  // Methods that take a buffer handle, such as ExecuteProgram and Transfer*,
  // automatically add this event as a dependency.
  virtual std::shared_ptr<Event> OnReady() = 0;

  // Size of the allocation in bytes.
  virtual int64_t size_in_bytes() = 0;

  // Shape of the buffer, if known. May be empty — e.g., presumably, for
  // buffers allocated from a raw byte count rather than a ShapeProto
  // (see the two TpuDriver::Allocate overloads).
  virtual absl::optional<xla::ShapeProto> shape() = 0;
};
78 
// Represents a compiled program on the host.
class CompiledProgramHandle {
 public:
  virtual ~CompiledProgramHandle() {}

  // This Event completes after the program is actually compiled on the host.
  //
  // Methods that take a compiled program handle, including LoadProgram,
  // automatically add this event as a dependency.
  virtual std::shared_ptr<Event> OnReady() = 0;

  // Size of the compiled program in bytes. The default implementation
  // crashes (LOG(FATAL)); subclasses that support this query must override.
  virtual int64_t size_in_bytes() {
    LOG(FATAL) << "Unimplemented.";
    return 0;
  }

  // Returns the shape of the compiled program. Blocks until compile completes.
  virtual xla::Status program_shape(xla::ProgramShapeProto* program_shape) = 0;
};
98 
// Represents a program loaded on the device.
class LoadedProgramHandle {
 public:
  virtual ~LoadedProgramHandle() {}

  // This Event completes after the program is actually loaded on the device.
  //
  // Methods that take a loaded program handle, including ExecuteProgram and
  // UnloadProgram, automatically add this event as a dependency.
  virtual std::shared_ptr<Event> OnReady() = 0;

  // Size of the loaded program in bytes. The default implementation
  // crashes (LOG(FATAL)); subclasses that support this query must override.
  virtual int64_t size_in_bytes() {
    LOG(FATAL) << "Unimplemented.";
    return 0;
  }
};
115 
// A TpuLinearizer manages the linearization and delinearization of user buffers
// in the TPU driver. This interface is not yet implemented.
class TpuLinearizer {
 public:
  virtual ~TpuLinearizer() {}

  // Convenience forwarder to the namespace-level ComputeBytesFromShape():
  // size of `shape` in ordinary (un-linearized) host layout.
  int64_t ComputeBytesFromShape(const xla::ShapeProto& shape) {
    return ::tpu_driver::ComputeBytesFromShape(shape);
  }

  // Returns the number of bytes `shape` occupies in the linearized layout,
  // which presumably may differ from ComputeBytesFromShape() — confirm
  // against a concrete implementation.
  virtual int64_t ComputeLinearizedBytesFromShape(
      const xla::ShapeProto& shape) = 0;

  // Converts `src` into the linearized layout at `dst` (LinearizeShape),
  // and back from linearized to host layout (DelinearizeShape).
  // NOTE(review): aliasing/overlap requirements for dst/src are not stated
  // here — verify with the implementation before passing overlapping buffers.
  virtual xla::Status LinearizeShape(void* dst, const void* src,
                                     const xla::ShapeProto& shape) = 0;
  virtual xla::Status DelinearizeShape(void* dst, const void* src,
                                       const xla::ShapeProto& shape) = 0;
};
133 
// A TpuDriver manages a set of operations scheduled to run on a TPU system.
//
// By default, two independently scheduled operations may execute in any order.
// Ordering can be imposed in one of two ways:
//
// 1. Users can specify event dependencies via the `wait_for` argument.
// 2. Operations using buffer or program handles implicitly wait for the handles
//    to become ready before executing.
//
// For returned handle objects, the user is responsible for calling the release
// methods (Deallocate, UnloadProgram, etc.) that consume the given unique_ptr
// arguments and free up device resources. For returned event objects, there is
// no release method; the user can let them go out of scope naturally. As soon
// as those methods accepting plain-pointer arguments return, the user can let
// the corresponding smart-pointer objects be released or go out of scope,
// regardless of whether the scheduled device operations have started execution.
class TpuDriver {
 public:
  virtual ~TpuDriver() {}

  // Fills `system_info` with a description of the TPU system this driver
  // manages (topology, chips, etc. — see the SystemInfo proto).
  virtual void QuerySystemInfo(SystemInfo* system_info) = 0;

  // Synchronous. Reset the state of the TPU driver. After Reset(), this TPU
  // driver object is no longer usable. Users must destroy this object and
  // create a new one.
  //
  // All running programs will be terminated and all allocations reset. All
  // events and buffer handles created prior to Reset() will be invalid, and any
  // use will result in undefined behavior.
  virtual xla::Status Reset() = 0;

  // Allocates `num_bytes` of device memory in `region` on core `core_id`.
  virtual std::unique_ptr<BufferHandle> Allocate(
      int32_t core_id, MemoryRegion region, int64_t num_bytes,
      absl::Span<Event* const> wait_for) = 0;
  // Allocates device memory sized to hold `shape` in `region` on `core_id`.
  virtual std::unique_ptr<BufferHandle> Allocate(
      int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape,
      absl::Span<Event* const> wait_for) = 0;

  // Allocate a buffer representing a tuple of `children` buffers.
  //
  // The returned tuple buffer handle does not manage the memory of `children`:
  // all `children` buffer handles must outlive the last usage of this tuple
  // buffer handle. One way to guarantee that is to deallocate the tuple buffer
  // handle before deallocating any buffer handle in `children`.
  //
  // All `children` buffers must exist in the same `core_id` and `region`.
  // If `children` is empty, a zero-sized tuple will be allocated in `region`.
  virtual std::unique_ptr<BufferHandle> AllocateTuple(
      int32_t core_id, MemoryRegion region,
      absl::Span<BufferHandle* const> children,
      absl::Span<Event* const> wait_for) = 0;

  // Schedules release of the device memory owned by `handle`, consuming the
  // handle. The returned event completes when the deallocation is done.
  virtual std::shared_ptr<Event> Deallocate(
      std::unique_ptr<BufferHandle> handle,
      absl::Span<Event* const> wait_for) = 0;

  /* For buffers declared with an xla::ShapeProto rather than a raw size,
   * `src` must be laid out in consecutive row-major format for ingestion, and
   * each element must take up the number of bytes specified by the type.
   *
   * For example, for a [3,3,3] tensor with a Float32 type, the memory layout
   * would be as follows:
   *
   * [0,0,0], [0,0,1], [0,0,2], [0,1,0], [0,1,1], ..., [0,2,2], [1,0,0], ...
   * [1,2,2], [2,0,0], ..., [2,2,2],
   *
   * and the entire buffer will be 108 bytes (27 elements x 4 bytes).
   *
   * See
   * https://eli.thegreenplace.net/2015/memory-layout-of-multi-dimensional-arrays
   * for a more detailed description.
   *
   * `TransferFromDevice` will write out the shape back in this order as well.
   */
  virtual std::shared_ptr<Event> TransferToDevice(
      const void* src, BufferHandle* dst,
      absl::Span<Event* const> wait_for) = 0;
  virtual std::shared_ptr<Event> TransferFromDevice(
      const BufferHandle* src, void* dst,
      absl::Span<Event* const> wait_for) = 0;

  // Schedules a device-to-device copy from `src` into `dst`.
  virtual std::shared_ptr<Event> TransferFromDeviceToDevice(
      const BufferHandle* src, BufferHandle* dst,
      absl::Span<Event* const> wait_for) = 0;

  // Compiles the HLO program in `source` for `num_replicas` replicas on the
  // host. Compilation is asynchronous; see CompiledProgramHandle::OnReady().
  virtual std::unique_ptr<CompiledProgramHandle> CompileProgram(
      const xla::HloProto& source, int32_t num_replicas,
      absl::Span<Event* const> wait_for) = 0;
  // Loads a compiled program onto core `core_id`. `handle` is borrowed, not
  // consumed; it must remain valid per the class-level lifetime rules.
  virtual std::unique_ptr<LoadedProgramHandle> LoadProgram(
      int32_t core_id, const CompiledProgramHandle* handle,
      absl::Span<Event* const> wait_for) = 0;
  // Schedules unloading of `handle` from the device, consuming the handle.
  virtual std::shared_ptr<Event> UnloadProgram(
      std::unique_ptr<LoadedProgramHandle> handle,
      absl::Span<Event* const> wait_for) = 0;
  // Schedules execution of a loaded program with the given input/output
  // buffers and device assignment. The returned event completes with the
  // execution's status.
  virtual std::shared_ptr<Event> ExecuteProgram(
      LoadedProgramHandle* program, absl::Span<BufferHandle* const> inputs,
      absl::Span<BufferHandle* const> outputs,
      const xla::DeviceAssignmentProto& device_assignment,
      absl::Span<Event* const> wait_for) = 0;

  // Returns a linearizer for this driver, or nullptr if the driver does not
  // provide one (the default).
  virtual std::unique_ptr<TpuLinearizer> GetLinearizer() { return nullptr; }
};
234 
// Static registry mapping string prefixes to TpuDriver factory functions.
// Drivers self-register at static-initialization time via the
// REGISTER_TPU_DRIVER macro below.
class TpuDriverRegistry {
 public:
  // Creates a TpuDriver for `config` by invoking a previously registered
  // factory. Presumably dispatches on a prefix of an address in `config` —
  // see the implementation for the exact matching rule.
  static xla::StatusOr<std::unique_ptr<TpuDriver>> Open(
      const TpuDriverConfig& config);

  // Associates `prefix` with `creator`. The int return value exists only so
  // the call can initialize a static variable (see REGISTER_TPU_DRIVER).
  static int RegisterDriver(
      const std::string& prefix,
      const std::function<xla::StatusOr<std::unique_ptr<TpuDriver>>(
          const TpuDriverConfig&)>& creator);
};
244 
// Registers a TpuDriver factory `fn` under `prefix` at static-initialization
// time by defining a uniquely named dummy static int whose initializer calls
// TpuDriverRegistry::RegisterDriver.
// NOTE(review): because `ctr` is used with `##`, `__COUNTER__` is pasted
// unexpanded through this single indirection level — fine for one
// registration per translation unit; confirm before adding a second.
#define REGISTER_TPU_DRIVER(prefix, fn) \
  REGISTER_TPU_DRIVER_HELPER(__COUNTER__, prefix, fn)
#define REGISTER_TPU_DRIVER_HELPER(ctr, prefix, fn)   \
  static int register_tpu_driver_count_unused_##ctr = \
      ::tpu_driver::TpuDriverRegistry::RegisterDriver(prefix, fn);
250 
251 }  // namespace tpu_driver
252 
253 #endif  // TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_TPU_DRIVER_H_
254