// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ==============================================================================

#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_TPU_DRIVER_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_TPU_DRIVER_H_

#include <complex>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/python/tpu_driver/platform/external/compat.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"

// This API is EXPERIMENTAL and under active development. It is subject to
// change without notice.

namespace tpu_driver {

// Returns the number of bytes needed to hold a value of the given shape.
// (Implementation lives in the corresponding .cc file; see also
// TpuLinearizer::ComputeBytesFromShape, which forwards to this function.)
int64_t ComputeBytesFromShape(const xla::ShapeProto& shape);

// Represents the deferred completion of a scheduled operation.
//
// Events may be blocked on, or used as `wait_for` arguments to enforce
// inter-operation dependencies.
class Event {
 public:
  virtual ~Event() {}

  // Blocks until the event completes and returns the result status.
  virtual xla::Status Await() = 0;

  // Blocks for at most `duration`.
  // Returns an empty result if the wait times out.
  virtual absl::optional<xla::Status> AwaitWithTimeout(
      absl::Duration duration) = 0;

  // Registers `callback` to be invoked with the event's completion status.
  // If the event is already done, the callback is called immediately.
  virtual void AddCallback(std::function<void(xla::Status)> callback) = 0;
};

// Represents a device memory allocation.
class BufferHandle {
 public:
  virtual ~BufferHandle() {}

  // This event completes after the device memory is actually allocated.
  //
  // Methods that take a buffer handle, such as ExecuteProgram and Transfer*,
  // automatically add this event as a dependency.
  virtual std::shared_ptr<Event> OnReady() = 0;

  // Size of the underlying allocation, in bytes.
  virtual int64_t size_in_bytes() = 0;

  // Shape of the buffer, if one was provided at allocation time; empty for
  // raw byte-sized allocations.
  virtual absl::optional<xla::ShapeProto> shape() = 0;
};

// Represents a compiled program on the host.
class CompiledProgramHandle {
 public:
  virtual ~CompiledProgramHandle() {}

  // This Event completes after the program is actually compiled on the host.
  //
  // Methods that take a compiled program handle, including LoadProgram,
  // automatically add this event as a dependency.
  virtual std::shared_ptr<Event> OnReady() = 0;

  // Optional for implementations: the default aborts via LOG(FATAL).
  virtual int64_t size_in_bytes() {
    LOG(FATAL) << "Unimplemented.";
    return 0;
  }

  // Returns the shape of the compiled program. Blocks until compile completes.
  virtual xla::Status program_shape(xla::ProgramShapeProto* program_shape) = 0;
};

// Represents a program loaded on the device.
class LoadedProgramHandle {
 public:
  virtual ~LoadedProgramHandle() {}

  // This Event completes after the program is actually loaded on the device.
  //
  // Methods that take a loaded program handle, including ExecuteProgram and
  // UnloadProgram, automatically add this event as a dependency.
  virtual std::shared_ptr<Event> OnReady() = 0;

  // Optional for implementations: the default aborts via LOG(FATAL).
  virtual int64_t size_in_bytes() {
    LOG(FATAL) << "Unimplemented.";
    return 0;
  }
};

// A TpuLinearizer manages the linearization and delinearization of user buffers
// in the TPU driver. This interface is not yet implemented.
class TpuLinearizer {
 public:
  virtual ~TpuLinearizer() {}

  // Convenience forwarder to the namespace-level ComputeBytesFromShape.
  int64_t ComputeBytesFromShape(const xla::ShapeProto& shape) {
    return ::tpu_driver::ComputeBytesFromShape(shape);
  }

  // Bytes required to hold `shape` in its linearized (device) form, which may
  // differ from the plain row-major size returned by ComputeBytesFromShape.
  virtual int64_t ComputeLinearizedBytesFromShape(
      const xla::ShapeProto& shape) = 0;

  // Converts `src` into linearized form at `dst` (LinearizeShape), or back
  // (DelinearizeShape). NOTE(review): exact layout contracts are not visible
  // in this header — presumably the user side matches the row-major layout
  // documented on TransferToDevice; confirm against an implementation.
  virtual xla::Status LinearizeShape(void* dst, const void* src,
                                     const xla::ShapeProto& shape) = 0;
  virtual xla::Status DelinearizeShape(void* dst, const void* src,
                                       const xla::ShapeProto& shape) = 0;
};

// A TpuDriver manages a set of operations scheduled to run on a TPU system.
//
// By default, two independently scheduled operations may execute in any order.
// Ordering can be imposed in one of two ways:
//
// 1. Users can specify event dependencies via the `wait_for` argument.
// 2. Operations using buffer or program handles implicitly wait for the handles
//    to become ready before executing.
//
// For returned handle objects, the user is responsible for calling the release
// methods (Deallocate, UnloadProgram, etc.) that consume the given unique_ptr
// arguments and free up device resources.
// For returned event objects, there is
// no release method; the user can let them go out of scope naturally. As soon
// as those methods accepting plain-pointer arguments return, the user can let
// the corresponding smart-pointer objects be released or go out of scope,
// regardless of whether the scheduled device operations have started execution.
class TpuDriver {
 public:
  virtual ~TpuDriver() {}

  // Fills `system_info` with a description of the attached TPU system.
  virtual void QuerySystemInfo(SystemInfo* system_info) = 0;

  // Synchronous. Reset the state of the TPU driver. After Reset(), this TPU
  // driver object is no longer usable. Users must destroy this object and
  // create a new one.
  //
  // All running programs will be terminated and all allocations reset. All
  // events and buffer handles created prior to Reset() will be invalid, and any
  // use will result in undefined behavior.
  virtual xla::Status Reset() = 0;

  // Allocates `num_bytes` of untyped device memory in `region` on core
  // `core_id`, once all events in `wait_for` have completed.
  virtual std::unique_ptr<BufferHandle> Allocate(
      int32_t core_id, MemoryRegion region, int64_t num_bytes,
      absl::Span<Event* const> wait_for) = 0;

  // As above, but sized to hold a value of shape `shape`.
  virtual std::unique_ptr<BufferHandle> Allocate(
      int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape,
      absl::Span<Event* const> wait_for) = 0;

  // Allocate a buffer representing a tuple of `children` buffers.
  //
  // The returned tuple buffer handle does not manage the memory of `children`:
  // all `children` buffer handles must outlive the last usage of this tuple
  // buffer handle. One way to guarantee that is to deallocate the tuple buffer
  // handle before deallocating any buffer handle in `children`.
  //
  // All `children` buffers must exist in the same `core_id` and `region`.
  // If `children` is empty, a zero-sized tuple will be allocated in `region`.
  virtual std::unique_ptr<BufferHandle> AllocateTuple(
      int32_t core_id, MemoryRegion region,
      absl::Span<BufferHandle* const> children,
      absl::Span<Event* const> wait_for) = 0;

  // Schedules the release of `handle`'s device memory; consumes the handle
  // (see the class comment above on release methods).
  virtual std::shared_ptr<Event> Deallocate(
      std::unique_ptr<BufferHandle> handle,
      absl::Span<Event* const> wait_for) = 0;

  /* For buffers declared with an xla::ShapeProto rather than a raw size,
   * `src` must be laid out in consecutive row-major format for ingestion, and
   * each element must take up the number of bytes specified by the type.
   *
   * For example, for a [3,3,3] tensor with a Float32 type, the memory layout
   * would be as follows:
   *
   * [0,0,0], [0,0,1], [0,0,2], [0,1,0], [0,1,1], ..., [0,2,2], [1,0,0], ...
   * [1,2,2], [2,0,0], ..., [2,2,2],
   *
   * and the entire buffer will be 108 bytes (27 elements x 4 bytes).
   *
   * See
   * https://eli.thegreenplace.net/2015/memory-layout-of-multi-dimensional-arrays
   * for a more detailed description.
   *
   * `TransferFromDevice` will write out the shape back in this order as well.
   */
  virtual std::shared_ptr<Event> TransferToDevice(
      const void* src, BufferHandle* dst,
      absl::Span<Event* const> wait_for) = 0;
  virtual std::shared_ptr<Event> TransferFromDevice(
      const BufferHandle* src, void* dst,
      absl::Span<Event* const> wait_for) = 0;

  // Copies data from one device buffer to another.
  virtual std::shared_ptr<Event> TransferFromDeviceToDevice(
      const BufferHandle* src, BufferHandle* dst,
      absl::Span<Event* const> wait_for) = 0;

  // Compiles `source` on the host for `num_replicas` replicas.
  virtual std::unique_ptr<CompiledProgramHandle> CompileProgram(
      const xla::HloProto& source, int32_t num_replicas,
      absl::Span<Event* const> wait_for) = 0;

  // Loads a compiled program onto core `core_id`.
  virtual std::unique_ptr<LoadedProgramHandle> LoadProgram(
      int32_t core_id, const CompiledProgramHandle* handle,
      absl::Span<Event* const> wait_for) = 0;

  // Unloads `handle` from the device; consumes the handle.
  virtual std::shared_ptr<Event> UnloadProgram(
      std::unique_ptr<LoadedProgramHandle> handle,
      absl::Span<Event* const> wait_for) = 0;

  // Executes `program` with the given input and output buffers under
  // `device_assignment`.
  virtual std::shared_ptr<Event> ExecuteProgram(
      LoadedProgramHandle* program, absl::Span<BufferHandle* const> inputs,
      absl::Span<BufferHandle* const> outputs,
      const xla::DeviceAssignmentProto& device_assignment,
      absl::Span<Event* const> wait_for) = 0;

  // Returns a linearizer for this driver, or nullptr if the driver does not
  // provide one (the default).
  virtual std::unique_ptr<TpuLinearizer> GetLinearizer() { return nullptr; }
};

// Registry of TpuDriver factories, keyed by a string prefix.
class TpuDriverRegistry {
 public:
  // Creates a TpuDriver according to `config` using a registered factory.
  static xla::StatusOr<std::unique_ptr<TpuDriver>> Open(
      const TpuDriverConfig& config);
  // Registers `creator` under `prefix`. The int return value exists only so
  // that REGISTER_TPU_DRIVER can initialize a static variable at load time.
  static int RegisterDriver(
      const std::string& prefix,
      const std::function<xla::StatusOr<std::unique_ptr<TpuDriver>>(
          const TpuDriverConfig&)>& creator);
};

// Registers a driver factory during static initialization; __COUNTER__ keeps
// the generated variable names unique within a translation unit.
#define REGISTER_TPU_DRIVER(prefix, fn) \
  REGISTER_TPU_DRIVER_HELPER(__COUNTER__, prefix, fn)
#define REGISTER_TPU_DRIVER_HELPER(ctr, prefix, fn)   \
  static int register_tpu_driver_count_unused_##ctr = \
      ::tpu_driver::TpuDriverRegistry::RegisterDriver(prefix, fn);

}  // namespace tpu_driver

#endif  // TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_TPU_DRIVER_H_