1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // The compiler API is used by the XLA service to generate executables that
17 // run on a given platform. This is a registry and abstract interface, for
18 // pluggability by the various platforms.
19 
20 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_
21 #define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_
22 
23 #include <functional>
24 #include <map>
25 #include <memory>
26 #include <string>
27 #include <vector>
28 
29 #include "absl/types/span.h"
30 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
31 #include "tensorflow/compiler/xla/service/buffer_value.h"
32 #include "tensorflow/compiler/xla/service/computation_placer.h"
33 #include "tensorflow/compiler/xla/service/executable.h"
34 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
35 #include "tensorflow/compiler/xla/service/hlo_module.h"
36 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
37 #include "tensorflow/compiler/xla/service/hlo_module_group.h"
38 #include "tensorflow/compiler/xla/service/logical_buffer.h"
39 #include "tensorflow/compiler/xla/statusor.h"
40 #include "tensorflow/compiler/xla/types.h"
41 #include "tensorflow/core/platform/mutex.h"
42 #include "tensorflow/core/platform/protobuf.h"
43 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
44 #include "tensorflow/core/platform/thread_annotations.h"
45 #include "tensorflow/core/platform/threadpool.h"
46 
47 namespace xla {
48 
49 // The following types are used for ahead of time compilation.
50 
// Contains the object file data created as a result of ahead-of-time
// compilation.
53 using ObjectFileData = std::vector<char>;
54 
55 // Abstract superclass describing the result of an ahead-of-time compilation.
class AotCompilationResult {
 public:
  // Virtual so that a derived result can be destroyed through a pointer to
  // this base class.
  virtual ~AotCompilationResult() = default;

  // Results are neither copy-constructible nor copy-assignable.
  AotCompilationResult(const AotCompilationResult&) = delete;
  AotCompilationResult& operator=(const AotCompilationResult&) = delete;

 protected:
  // Only derived classes may construct a result.
  AotCompilationResult() = default;
};
66 
67 // Abstract superclass describing options to an ahead-of-time compilation.
68 class AotCompilationOptions {
69  public:
70   AotCompilationOptions(const AotCompilationOptions&) = delete;
71   AotCompilationOptions& operator=(AotCompilationOptions const&) = delete;
72 
73   virtual ~AotCompilationOptions() = default;
74 
75   // Returns the ID of the platform to which these options apply.
76   virtual se::Platform::Id PlatformId() const = 0;
77 
replica_count()78   virtual int64 replica_count() const { return 0; }
num_cores()79   virtual int64 num_cores() const { return 0; }
broadcast_replicated_params()80   virtual bool broadcast_replicated_params() const { return false; }
use_spmd_partitioning()81   virtual bool use_spmd_partitioning() const { return false; }
deduplicate_hlo()82   virtual bool deduplicate_hlo() const { return false; }
83 
84   // Optional allocator that may be used for allocating temp space on the device
85   // during compilation.
device_allocator()86   se::DeviceMemoryAllocator* device_allocator() const {
87     return device_allocator_;
88   }
set_device_allocator(se::DeviceMemoryAllocator * device_allocator)89   void set_device_allocator(se::DeviceMemoryAllocator* device_allocator) {
90     device_allocator_ = device_allocator;
91   }
92 
debug_options()93   const DebugOptions& debug_options() const { return debug_options_; }
mutable_debug_options()94   DebugOptions* mutable_debug_options() { return &debug_options_; }
95 
has_static_device_assignment()96   bool has_static_device_assignment() const {
97     return static_device_assignment_.has_value();
98   }
static_device_assignment()99   const DeviceAssignment& static_device_assignment() const {
100     CHECK(static_device_assignment_.has_value());
101     return *static_device_assignment_;
102   }
set_static_device_assignment(const DeviceAssignment & device_assignment)103   void set_static_device_assignment(const DeviceAssignment& device_assignment) {
104     static_device_assignment_ = device_assignment;
105   }
106 
fusion_config_collection()107   FusionConfigCollection fusion_config_collection() const {
108     return fusion_config_collection_;
109   }
set_fusion_config_collection(FusionConfigCollection fusion_config_collection)110   void set_fusion_config_collection(
111       FusionConfigCollection fusion_config_collection) {
112     fusion_config_collection_ = fusion_config_collection;
113   }
114 
fusion_config()115   const std::vector<std::vector<bool>>& fusion_config() const {
116     return fusion_config_;
117   }
set_fusion_config(const std::vector<std::vector<bool>> & fusion_config)118   void set_fusion_config(const std::vector<std::vector<bool>>& fusion_config) {
119     fusion_config_ = fusion_config;
120   }
121 
122  protected:
123   AotCompilationOptions();
124 
125  private:
126   se::DeviceMemoryAllocator* device_allocator_ = nullptr;
127   DebugOptions debug_options_;
128   absl::optional<DeviceAssignment> static_device_assignment_;
129   std::vector<std::vector<bool>> fusion_config_;
130   FusionConfigCollection fusion_config_collection_ =
131       FusionConfigCollection::kOff;
132 };
133 
134 // Abstract superclass describing metadata produced during ahead-of-time
135 // compilation.
class AotCompilationMetadata {
 public:
  // Metadata objects are neither copy-constructible nor copy-assignable.
  AotCompilationMetadata(const AotCompilationMetadata&) = delete;
  AotCompilationMetadata& operator=(const AotCompilationMetadata&) = delete;

  // Returns a string form of the metadata. This base implementation yields an
  // empty string; derived classes may override it.
  virtual std::string ToString() const { return std::string(); }

  virtual ~AotCompilationMetadata() = default;

 protected:
  // Only derived classes may construct metadata.
  AotCompilationMetadata() = default;
};
146 
147 // Abstract compiler interface that is subclassed for compilation on a
148 // particular platform.
149 //
150 // The compiler ties together high level optimization (HLO) and low level
151 // optimization (LLO) / codegen (CG) to generate efficient executables for the
152 // target platform.
153 //
154 // The platform-based compiler singletons are registered via module initializers
155 // in their corresponding XLA compiler libraries, and are registered via the
156 // RegisterCompilerFactory API below.
157 //
158 // Thread-safety: subclasses of Compiler must be thread-safe, as multiple
159 // XLA clients may be requesting compilation concurrently for a given
160 // platform.
161 class Compiler {
162  public:
163   struct CompileOptions {
164     // If device_allocator is not null, the compiler may use it to allocate temp
165     // space on the device for use during compilation.  For example, the
166     // compiler may allocate buffers on the device and then run variants of a
167     // given algorithm over those buffers, to see which variant is fastest.  Any
168     // space allocated will be deallocated before the compilation returns.
169     se::DeviceMemoryAllocator* device_allocator = nullptr;
170 
171     // An optional thread pool for parallel compilation.
172     tensorflow::thread::ThreadPool* thread_pool = nullptr;
173   };
174 
~Compiler()175   virtual ~Compiler() {}
176 
177   // Returns the ID of the platform that this compiler targets.
178   virtual se::Platform::Id PlatformId() const = 0;
179 
180   // Runs Hlo passes to optimize the given Hlo module, returns the optimized
181   // module.
182   virtual StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
183       std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
184       const CompileOptions& options) = 0;
RunHloPasses(std::unique_ptr<HloModule> module,se::StreamExecutor * executor,se::DeviceMemoryAllocator * device_allocator)185   StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
186       std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
187       se::DeviceMemoryAllocator* device_allocator) {
188     return RunHloPasses(std::move(module), executor,
189                         CompileOptions{device_allocator});
190   }
191 
192   // Runs HLO passes to optimize the given HloModule, perform scheduling and
193   // buffer assignment, returns the optimized module and the buffer assignments.
194   // This interface is intentionally narrow.
195   virtual StatusOr<
196       std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> module,se::StreamExecutor * executor,bool optimize,const CompileOptions & options)197   RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> module,
198                                    se::StreamExecutor* executor, bool optimize,
199                                    const CompileOptions& options) {
200     return Unimplemented("This compiler does not support this method");
201   }
202 
203   // Compiles the HLO module for execution on a device given by the executor,
204   // and returns an executable object or an error status. No HLO passes are
205   // applied to module. Generally a module should be passed through RunHloPasses
206   // prior to calling this method because some HLO passes are required for
207   // correctness. Takes ownership of the HLO module.
208   //
209   // The compiler may optionally specialize to the individual device
210   // (not just type of device) indicated by the executor.
211   virtual StatusOr<std::unique_ptr<Executable>> RunBackend(
212       std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
213       const CompileOptions& options) = 0;
RunBackend(std::unique_ptr<HloModule> module,se::StreamExecutor * executor,se::DeviceMemoryAllocator * device_allocator)214   StatusOr<std::unique_ptr<Executable>> RunBackend(
215       std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
216       se::DeviceMemoryAllocator* device_allocator) {
217     return RunBackend(std::move(module), executor,
218                       CompileOptions{device_allocator});
219   }
220 
221   // Compiles a set of HLO modules that can run in parallel, potentially
222   // communicating data between the modules, and returns a corresponding
223   // sequence of executable objects.
224   //
225   // TODO(b/68666782): Remove this method after adding support for multiple
226   // modules to RunHloPasses and RunBackends.
227   virtual StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
228       std::unique_ptr<HloModuleGroup> module_group,
229       std::vector<std::vector<se::StreamExecutor*>> stream_exec,
230       const CompileOptions& options) = 0;
Compile(std::unique_ptr<HloModuleGroup> module_group,std::vector<std::vector<se::StreamExecutor * >> stream_exec,se::DeviceMemoryAllocator * device_allocator)231   StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
232       std::unique_ptr<HloModuleGroup> module_group,
233       std::vector<std::vector<se::StreamExecutor*>> stream_exec,
234       se::DeviceMemoryAllocator* device_allocator) {
235     return Compile(std::move(module_group), stream_exec,
236                    CompileOptions{device_allocator});
237   }
238 
239   // Returns the backend configurations that the backend will consider for the
240   // given HLO. Returns no configurations if the backend does not support
241   // configurations for the given HLO.
242   //
243   // The stream executor is passed in to provide information about the hardware
244   // that the backend configurations would be targeting.
245   virtual std::vector<std::unique_ptr<tensorflow::protobuf::Message>>
246   ComputeBackendConfigs(const HloInstruction& hlo,
247                         se::StreamExecutor* executor) const;
248 
249   // Returns the backend configuration that the backend chooses by default for
250   // the given HLO. Returns no configuration if the backend does not support
251   // configurations for the given HLO.
252   //
253   // The stream executor is passed in to provide information about the hardware
254   // that the backend configurations would be targeting.
255   virtual std::unique_ptr<tensorflow::protobuf::Message>
256   ComputeDefaultBackendConfig(const HloInstruction& hlo,
257                               se::StreamExecutor* executor) const;
258 
259   // Compiles the HLO module group for ahead-of-time execution.  This is
260   // intended for use in static compilation.
261   virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
262   CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
263                      const AotCompilationOptions& options) = 0;
264 
265   // Similar to CompileAheadOfTime above but AotCompilationMetadata
266   // has an argument that can be populated during compilation.
267   virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
268   CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
269                      const AotCompilationOptions& options,
270                      std::unique_ptr<AotCompilationMetadata>* metadata);
271 
272   /////
273   // The Compiler class also serves as a point to register compiler objects
274   // for the various platforms.
275 
276   using CompilerFactory = std::function<std::unique_ptr<Compiler>()>;
277 
278   // Registers the compiler singleton for the platform. This is assumed to
279   // be a singleton, so no ownership is transferred.
280   //
281   // Precondition: a platform kind must not be registered more than once.
282   static void RegisterCompilerFactory(se::Platform::Id platform_id,
283                                       CompilerFactory compiler_factory);
284 
285   // Returns the compiler singleton pointer if it is available for the given
286   // platform, or an error status if it is not.
287   static StatusOr<Compiler*> GetForPlatform(const se::Platform* platform);
288 
289   // Returns a function that computes the size in bytes of the logical
290   // buffer that contains a shape.
291   virtual HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const = 0;
292 
293   // Returns a function that computes the size in bytes of a given
294   // logical buffer.
BufferSizeBytesFunction()295   std::function<int64(const BufferValue&)> BufferSizeBytesFunction() {
296     HloCostAnalysis::ShapeSizeFunction shape_size = ShapeSizeBytesFunction();
297     return [shape_size](const BufferValue& buffer) {
298       return shape_size(buffer.shape());
299     };
300   }
301 
302  private:
303   // Mutex that guards the platform-compiler map.
304   static tensorflow::mutex platform_compiler_mutex_;
305 
306   // Map from platform kind to compiler factory.
307   static std::map<se::Platform::Id, CompilerFactory>*
308   GetPlatformCompilerFactories();
309 
310   // Map from platform kind to compiler instance, if we made one already (based
311   // on the factories above).
312   static std::map<se::Platform::Id, std::unique_ptr<Compiler>>*
313   GetPlatformCompilers();
314 };
315 
316 }  // namespace xla
317 
318 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_
319