1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // The compiler API is used by the XLA service to generate executables that 17 // run on a given platform. This is a registry and abstract interface, for 18 // pluggability by the various platforms. 19 20 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_ 21 #define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_ 22 23 #include <functional> 24 #include <map> 25 #include <memory> 26 #include <string> 27 #include <vector> 28 29 #include "absl/types/span.h" 30 #include "tensorflow/compiler/xla/service/buffer_assignment.h" 31 #include "tensorflow/compiler/xla/service/buffer_value.h" 32 #include "tensorflow/compiler/xla/service/computation_placer.h" 33 #include "tensorflow/compiler/xla/service/executable.h" 34 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 35 #include "tensorflow/compiler/xla/service/hlo_module.h" 36 #include "tensorflow/compiler/xla/service/hlo_module_config.h" 37 #include "tensorflow/compiler/xla/service/hlo_module_group.h" 38 #include "tensorflow/compiler/xla/service/logical_buffer.h" 39 #include "tensorflow/compiler/xla/statusor.h" 40 #include "tensorflow/compiler/xla/types.h" 41 #include "tensorflow/core/platform/mutex.h" 42 #include "tensorflow/core/platform/protobuf.h" 43 #include "tensorflow/core/platform/stream_executor_no_cuda.h" 44 #include "tensorflow/core/platform/thread_annotations.h" 45 #include "tensorflow/core/platform/threadpool.h" 46 47 namespace xla { 48 49 // The following types are used for ahead of time compilation. 50 51 // Contains the object file data created as a result of ahead-of-time 52 // computation. 53 using ObjectFileData = std::vector<char>; 54 55 // Abstract superclass describing the result of an ahead-of-time compilation. 56 class AotCompilationResult { 57 public: 58 AotCompilationResult(const AotCompilationResult&) = delete; 59 AotCompilationResult& operator=(AotCompilationResult const&) = delete; 60 61 virtual ~AotCompilationResult() = default; 62 63 protected: 64 AotCompilationResult() = default; 65 }; 66 67 // Abstract superclass describing options to an ahead-of-time compilation. 68 class AotCompilationOptions { 69 public: 70 AotCompilationOptions(const AotCompilationOptions&) = delete; 71 AotCompilationOptions& operator=(AotCompilationOptions const&) = delete; 72 73 virtual ~AotCompilationOptions() = default; 74 75 // Returns the ID of the platform to which these options apply. 76 virtual se::Platform::Id PlatformId() const = 0; 77 replica_count()78 virtual int64 replica_count() const { return 0; } num_cores()79 virtual int64 num_cores() const { return 0; } broadcast_replicated_params()80 virtual bool broadcast_replicated_params() const { return false; } use_spmd_partitioning()81 virtual bool use_spmd_partitioning() const { return false; } deduplicate_hlo()82 virtual bool deduplicate_hlo() const { return false; } 83 84 // Optional allocator that may be used for allocating temp space on the device 85 // during compilation. device_allocator()86 se::DeviceMemoryAllocator* device_allocator() const { 87 return device_allocator_; 88 } set_device_allocator(se::DeviceMemoryAllocator * device_allocator)89 void set_device_allocator(se::DeviceMemoryAllocator* device_allocator) { 90 device_allocator_ = device_allocator; 91 } 92 debug_options()93 const DebugOptions& debug_options() const { return debug_options_; } mutable_debug_options()94 DebugOptions* mutable_debug_options() { return &debug_options_; } 95 has_static_device_assignment()96 bool has_static_device_assignment() const { 97 return static_device_assignment_.has_value(); 98 } static_device_assignment()99 const DeviceAssignment& static_device_assignment() const { 100 CHECK(static_device_assignment_.has_value()); 101 return *static_device_assignment_; 102 } set_static_device_assignment(const DeviceAssignment & device_assignment)103 void set_static_device_assignment(const DeviceAssignment& device_assignment) { 104 static_device_assignment_ = device_assignment; 105 } 106 fusion_config_collection()107 FusionConfigCollection fusion_config_collection() const { 108 return fusion_config_collection_; 109 } set_fusion_config_collection(FusionConfigCollection fusion_config_collection)110 void set_fusion_config_collection( 111 FusionConfigCollection fusion_config_collection) { 112 fusion_config_collection_ = fusion_config_collection; 113 } 114 fusion_config()115 const std::vector<std::vector<bool>>& fusion_config() const { 116 return fusion_config_; 117 } set_fusion_config(const std::vector<std::vector<bool>> & fusion_config)118 void set_fusion_config(const std::vector<std::vector<bool>>& fusion_config) { 119 fusion_config_ = fusion_config; 120 } 121 122 protected: 123 AotCompilationOptions(); 124 125 private: 126 se::DeviceMemoryAllocator* device_allocator_ = nullptr; 127 DebugOptions debug_options_; 128 absl::optional<DeviceAssignment> static_device_assignment_; 129 std::vector<std::vector<bool>> fusion_config_; 130 FusionConfigCollection fusion_config_collection_ = 131 FusionConfigCollection::kOff; 132 }; 133 134 // Abstract superclass describing metadata produced during ahead-of-time 135 // compilation. 136 class AotCompilationMetadata { 137 public: 138 AotCompilationMetadata(const AotCompilationMetadata&) = delete; 139 AotCompilationMetadata& operator=(AotCompilationMetadata const&) = delete; ToString()140 virtual std::string ToString() const { return ""; } 141 virtual ~AotCompilationMetadata() = default; 142 143 protected: 144 AotCompilationMetadata() = default; 145 }; 146 147 // Abstract compiler interface that is subclassed for compilation on a 148 // particular platform. 149 // 150 // The compiler ties together high level optimization (HLO) and low level 151 // optimization (LLO) / codegen (CG) to generate efficient executables for the 152 // target platform. 153 // 154 // The platform-based compiler singletons are registered via module initializers 155 // in their corresponding XLA compiler libraries, and are registered via the 156 // RegisterCompilerFactory API below. 157 // 158 // Thread-safety: subclasses of Compiler must be thread-safe, as multiple 159 // XLA clients may be requesting compilation concurrently for a given 160 // platform. 161 class Compiler { 162 public: 163 struct CompileOptions { 164 // If device_allocator is not null, the compiler may use it to allocate temp 165 // space on the device for use during compilation. For example, the 166 // compiler may allocate buffers on the device and then run variants of a 167 // given algorithm over those buffers, to see which variant is fastest. Any 168 // space allocated will be deallocated before the compilation returns. 169 se::DeviceMemoryAllocator* device_allocator = nullptr; 170 171 // An optional thread pool for parallel compilation. 172 tensorflow::thread::ThreadPool* thread_pool = nullptr; 173 }; 174 ~Compiler()175 virtual ~Compiler() {} 176 177 // Returns the ID of the platform that this compiler targets. 178 virtual se::Platform::Id PlatformId() const = 0; 179 180 // Runs Hlo passes to optimize the given Hlo module, returns the optimized 181 // module. 182 virtual StatusOr<std::unique_ptr<HloModule>> RunHloPasses( 183 std::unique_ptr<HloModule> module, se::StreamExecutor* executor, 184 const CompileOptions& options) = 0; RunHloPasses(std::unique_ptr<HloModule> module,se::StreamExecutor * executor,se::DeviceMemoryAllocator * device_allocator)185 StatusOr<std::unique_ptr<HloModule>> RunHloPasses( 186 std::unique_ptr<HloModule> module, se::StreamExecutor* executor, 187 se::DeviceMemoryAllocator* device_allocator) { 188 return RunHloPasses(std::move(module), executor, 189 CompileOptions{device_allocator}); 190 } 191 192 // Runs HLO passes to optimize the given HloModule, perform scheduling and 193 // buffer assignment, returns the optimized module and the buffer assignments. 194 // This interface is intentionally narrow. 195 virtual StatusOr< 196 std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>> RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> module,se::StreamExecutor * executor,bool optimize,const CompileOptions & options)197 RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> module, 198 se::StreamExecutor* executor, bool optimize, 199 const CompileOptions& options) { 200 return Unimplemented("This compiler does not support this method"); 201 } 202 203 // Compiles the HLO module for execution on a device given by the executor, 204 // and returns an executable object or an error status. No HLO passes are 205 // applied to module. Generally a module should be passed through RunHloPasses 206 // prior to calling this method because some HLO passes are required for 207 // correctness. Takes ownership of the HLO module. 208 // 209 // The compiler may optionally specialize to the individual device 210 // (not just type of device) indicated by the executor. 211 virtual StatusOr<std::unique_ptr<Executable>> RunBackend( 212 std::unique_ptr<HloModule> module, se::StreamExecutor* executor, 213 const CompileOptions& options) = 0; RunBackend(std::unique_ptr<HloModule> module,se::StreamExecutor * executor,se::DeviceMemoryAllocator * device_allocator)214 StatusOr<std::unique_ptr<Executable>> RunBackend( 215 std::unique_ptr<HloModule> module, se::StreamExecutor* executor, 216 se::DeviceMemoryAllocator* device_allocator) { 217 return RunBackend(std::move(module), executor, 218 CompileOptions{device_allocator}); 219 } 220 221 // Compiles a set of HLO modules that can run in parallel, potentially 222 // communicating data between the modules, and returns a corresponding 223 // sequence of executable objects. 224 // 225 // TODO(b/68666782): Remove this method after adding support for multiple 226 // modules to RunHloPasses and RunBackends. 227 virtual StatusOr<std::vector<std::unique_ptr<Executable>>> Compile( 228 std::unique_ptr<HloModuleGroup> module_group, 229 std::vector<std::vector<se::StreamExecutor*>> stream_exec, 230 const CompileOptions& options) = 0; Compile(std::unique_ptr<HloModuleGroup> module_group,std::vector<std::vector<se::StreamExecutor * >> stream_exec,se::DeviceMemoryAllocator * device_allocator)231 StatusOr<std::vector<std::unique_ptr<Executable>>> Compile( 232 std::unique_ptr<HloModuleGroup> module_group, 233 std::vector<std::vector<se::StreamExecutor*>> stream_exec, 234 se::DeviceMemoryAllocator* device_allocator) { 235 return Compile(std::move(module_group), stream_exec, 236 CompileOptions{device_allocator}); 237 } 238 239 // Returns the backend configurations that the backend will consider for the 240 // given HLO. Returns no configurations if the backend does not support 241 // configurations for the given HLO. 242 // 243 // The stream executor is passed in to provide information about the hardware 244 // that the backend configurations would be targeting. 245 virtual std::vector<std::unique_ptr<tensorflow::protobuf::Message>> 246 ComputeBackendConfigs(const HloInstruction& hlo, 247 se::StreamExecutor* executor) const; 248 249 // Returns the backend configuration that the backend chooses by default for 250 // the given HLO. Returns no configuration if the backend does not support 251 // configurations for the given HLO. 252 // 253 // The stream executor is passed in to provide information about the hardware 254 // that the backend configurations would be targeting. 255 virtual std::unique_ptr<tensorflow::protobuf::Message> 256 ComputeDefaultBackendConfig(const HloInstruction& hlo, 257 se::StreamExecutor* executor) const; 258 259 // Compiles the HLO module group for ahead-of-time execution. This is 260 // intended for use in static compilation. 261 virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>> 262 CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group, 263 const AotCompilationOptions& options) = 0; 264 265 // Similar to CompileAheadOfTime above but AotCompilationMetadata 266 // has an argument that can be populated during compilation. 267 virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>> 268 CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group, 269 const AotCompilationOptions& options, 270 std::unique_ptr<AotCompilationMetadata>* metadata); 271 272 ///// 273 // The Compiler class also serves as a point to register compiler objects 274 // for the various platforms. 275 276 using CompilerFactory = std::function<std::unique_ptr<Compiler>()>; 277 278 // Registers the compiler singleton for the platform. This is assumed to 279 // be a singleton, so no ownership is transferred. 280 // 281 // Precondition: a platform kind must not be registered more than once. 282 static void RegisterCompilerFactory(se::Platform::Id platform_id, 283 CompilerFactory compiler_factory); 284 285 // Returns the compiler singleton pointer if it is available for the given 286 // platform, or an error status if it is not. 287 static StatusOr<Compiler*> GetForPlatform(const se::Platform* platform); 288 289 // Returns a function that computes the size in bytes of the logical 290 // buffer that contains a shape. 291 virtual HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const = 0; 292 293 // Returns a function that computes the size in bytes of a given 294 // logical buffer. BufferSizeBytesFunction()295 std::function<int64(const BufferValue&)> BufferSizeBytesFunction() { 296 HloCostAnalysis::ShapeSizeFunction shape_size = ShapeSizeBytesFunction(); 297 return [shape_size](const BufferValue& buffer) { 298 return shape_size(buffer.shape()); 299 }; 300 } 301 302 private: 303 // Mutex that guards the platform-compiler map. 304 static tensorflow::mutex platform_compiler_mutex_; 305 306 // Map from platform kind to compiler factory. 307 static std::map<se::Platform::Id, CompilerFactory>* 308 GetPlatformCompilerFactories(); 309 310 // Map from platform kind to compiler instance, if we made one already (based 311 // on the factories above). 312 static std::map<se::Platform::Id, std::unique_ptr<Compiler>>* 313 GetPlatformCompilers(); 314 }; 315 316 } // namespace xla 317 318 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_ 319