1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
17 #define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
18 
19 #include <stack>
20 
21 #include "absl/types/span.h"
22 #include "absl/types/variant.h"
23 #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
24 #include "tensorflow/compiler/tf2xla/xla_argument.h"
25 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
26 #include "tensorflow/compiler/tf2xla/xla_expression.h"
27 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
28 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
29 #include "tensorflow/compiler/xla/client/local_client.h"
30 #include "tensorflow/compiler/xla/client/xla_builder.h"
31 #include "tensorflow/compiler/xla/client/xla_computation.h"
32 #include "tensorflow/compiler/xla/status_macros.h"
33 #include "tensorflow/core/common_runtime/device.h"
34 #include "tensorflow/core/common_runtime/device_mgr.h"
35 #include "tensorflow/core/common_runtime/function.h"
36 #include "tensorflow/core/framework/function.h"
37 #include "tensorflow/core/lib/core/errors.h"
38 #include "tensorflow/core/platform/env.h"
39 #include "tensorflow/core/platform/mutex.h"
40 #include "tensorflow/core/platform/notification.h"
41 #include "tensorflow/core/platform/thread_annotations.h"
42 #include "tensorflow/core/public/version.h"
43 
44 namespace tensorflow {
45 
46 class XlaContext;
47 
48 // The XlaCompiler class is responsible for compilation of a self-contained
49 // subgraph of a TensorFlow computation using the XLA linear algebra runtime.
50 // It does a symbolic execution of the graph starting from specific input
51 // shapes, using a JIT device to convert operators into XLA computations.
52 //
53 // XlaCompiler is typically invoked from an `XlaLaunch` operator once the
54 // shapes of all input parameters to the computation are known. This is
55 // because the symbolic execution requires known shapes for all operations.
56 //
// XlaCompiler compiles TensorFlow graphs that receive inputs via _Arg nodes
// and return outputs via _Retval nodes.
59 //
60 // The XlaCompiler requires one Argument struct for each _Arg index, that
61 // describes each argument. Arguments can be compile-time constants
62 // (kind kConstant), run-time parameters (kind kParameter), or resources
63 // (kind kResource).
64 //
65 // Only kParameter and initialized kResource arguments become runtime parameters
66 // to the generated XLA computation.
67 //
68 // The run-time outputs of the XLA computation are arranged in the following
69 // order:
70 //   +------------------+-----------------------------------------+
71 //   |  _Retval values  |  Updated values of kResource arguments  |
72 //   +------------------+-----------------------------------------+
73 // _Retval values are ordered by _Retval index, whereas kResource values are
74 // ordered by the original _Arg position of the variable.
75 //
// If a shape representation function is provided as part of
// XlaCompiler::Options, kParameter arguments and return values to an
// entry computation will be reshaped in accordance with the shape function.
79 // Arguments and return values to a non-entry computation are not reshaped.
80 // Variable resource arguments are passed and returned in reshaped form, even
81 // for non-entry computations. This feature allows TensorFlow to keep on-device
82 // tensors with a different shape to their representation inside the XLA
83 // computation.
84 //
// In computation outputs, updated kResource values are placed at the end. When
86 // emitting While loop bodies, we must ensure that the loop body has
87 // identical input and output signatures. By passing variable values
88 // at the end of the argument list and using the
// `return_updated_values_for_all_resources` option, we can ensure that the
90 // input and output values of resources appear at the same positions.
91 //
92 // Resources are passed as parameters or returned as resource updates in
93 // "packed" form.
94 // kStack resources are packed as (array, size of stack) XLA tuples.
95 // kTensorArray resources without gradients are packed as the array that
96 // backs the TensorArray. If gradients are present (`tensor_array_gradients`),
97 // the packed representation is a (array, gradient0, gradient1, ...) tuple,
98 // where gradient_k is the value of the k-th gradient in the
99 // `tensor_array_gradients` ordered set.
100 class XlaCompiler {
101  public:
102   using Argument = ::tensorflow::XlaArgument;
103 
104   // Options pertaining to an individual call to CompileGraph() or
105   // CompileFunction().
106   struct CompileOptions {
107     // If `use_tuple_arg` is true, a single tuple parameter will be used for all
108     // arguments; if false, each argument gets its own parameter.
109     bool use_tuple_arg = false;
110 
111     // If 'return_updated_values_for_all_resources' is true, then updated
112     // values of all resource arguments will be included in the
113     // 'resource_updates' of the computation, even if the resource was not
114     // modified by the computation. Used when compiling loop bodies to ensure
115     // the input and output signatures match.
116     bool return_updated_values_for_all_resources = false;
117 
118     // If 'always_return_tuple' is true, then the output of a computation will
119     // always be a tuple. Otherwise, a single-element output will not be wrapped
120     // in a tuple.
121     bool always_return_tuple = true;
122 
123     // True when compiling the entry computation, false for subcomputations
124     // (while, call, etc.)
125     bool is_entry_computation = true;
126 
127     // True when we should add XLA input & output to the graph/function.
128     bool add_token_input_output = false;
129 
130     // Resource updates are converted into input / output of xla. The two
131     // buffers are aliased with other if this option is true.
132     bool alias_resource_update = false;
133   };
134 
135   using OutputDescription = ::tensorflow::XlaOutputDescription;
136 
137   using ResourceUpdate = ::tensorflow::XlaResourceUpdate;
138 
139   using CompilationResult = ::tensorflow::XlaCompilationResult;
140 
141   typedef std::function<xla::StatusOr<xla::Shape>(const TensorShape&, DataType,
142                                                   bool)>
143       ShapeRepresentationFn;
144   struct Options {
145     // Name of the compilation device to use. It must be set by the caller.
146     // The default empty value is invalid.
147     DeviceType device_type = DeviceType("");
148 
149     // The device to use during compilation to execute instructions on, for
150     // example for auto-tuning.
151     // Valid values are defined by `xla::Backend::devices_ordinal_supported()`.
152     // -1 indicates the default device should be used.
153     int device_ordinal = -1;
154 
155     xla::Client* client = nullptr;
156 
157     // Function library in which to find function definitions. Must be non-null.
158     const FunctionLibraryDefinition* flib_def = nullptr;
159 
160     // The graph def version to be compiled.
161     int graph_def_version = TF_GRAPH_DEF_VERSION;
162 
163     // If 'allow_cpu_custom_calls' is true, kernels may make use of CustomCall()
164     // for CPU.
165     bool allow_cpu_custom_calls = false;
166 
167     // If both this and 'allow_cpu_custom_calls' are true then tf.fake_quant_*
168     // ops will be emitted as custom calls to a 'fake_quant_with_min_max_vars'
169     // function accepting the input, min, max, num_bits, and narrow_range values
170     // as runtime arguments.
171     bool custom_fake_quant_op_calls = false;
172 
173     // If set, the XLA representation of variables represented to XLA as the
174     // shape given by this shape function. Variables are reshaped to this shape
175     // on write, and reshaped to their original shape on read.
176     ShapeRepresentationFn shape_representation_fn;
177 
178     // If not nullptr, populate_resource_manager is called with the
179     // compilation device's resource manager when the compilation
180     // device is created, and can be used to create metadata objects
181     // that can be accessed by XLA op kernels.
182     std::function<Status(ResourceMgr*)>* populate_resource_manager = nullptr;
183 
184     // If not nullptr, this memory allocator can be used by the compiler for
185     // temporary allocations it might want to make during compilation.
186     //
187     // For example, the compiler may want to try out different algorithms and
188     // choose the fastest one, and it might run those algorithms over buffers
189     // created using this allocator.
190     //
191     // The compiler can function correctly without an explicit allocator given
192     // here, but on some devices (notably, GPUs), TensorFlow tends to eagerly
193     // allocate most or all available memory on the device, leaving none for the
194     // compiler to access, unless it can use TensorFlow's allocator.
195     se::DeviceMemoryAllocator* device_allocator = nullptr;
196 
197     // Alias input and output buffers for parameters that are passed-through XLA
198     // modules without being changed.
199     bool alias_passthrough_params = false;
200 
201     // Enable detailed logging of compilation metadata.
202     bool detailed_logging = true;
203   };
204 
205   explicit XlaCompiler(Options options);
206 
207   ~XlaCompiler();
208 
209   // Helper function to populate an XlaCompiler::Argument from XlaResource.
210   static void PopulateArgumentFromResource(const XlaResource& resource,
211                                            Argument* arg);
212 
213   Status CompileFunction(const CompileOptions& options,
214                          const NameAttrList& fn_name_attrs,
215                          absl::Span<const Argument> args,
216                          CompilationResult* result);
217 
218   // Compiles a tensorflow::Graph into an xla::XlaComputation.
219   // Similar to CompileFunction, but takes a Graph as input rather than a
220   // function.
221   Status CompileGraph(
222       const CompileOptions& options, string const& name,
223       std::unique_ptr<Graph> graph, absl::Span<const Argument> args,
224       CompilationResult* result);
225 
226   // Returns the shape of the XLA parameter for an argument 'arg'.
227   // See the class comment for more details about the argument passing
228   // convention.
229   Status XLAShapeForArgument(
230       const Argument& arg, bool is_entry_computation,
231       const absl::optional<xla::HloSharding>& arg_sharding,
232       xla::Shape* xla_shape) const;
233 
234   // Retrieves the channel handle associated with `key`. Allocates
235   // a new channel handle if none exists.
236   // Channel handles can be used to communicate between different
237   // computations. Computations that communicate should be compiled with the
238   // same XlaCompiler.
239   Status GetChannelHandle(const string& key, xla::ChannelHandle* channel);
240 
241   // Retrieves the host-to-device channel handle associated with `key`.
242   // Allocates a new channel handle if none exists.
243   Status GetHostToDeviceChannelHandle(const string& key,
244                                       xla::ChannelHandle* channel);
245 
246   // Retrieves the device-to-host channel handle associated with `key`.
247   // Allocates a new channel handle if none exists.
248   Status GetDeviceToHostChannelHandle(const string& key,
249                                       xla::ChannelHandle* channel);
250 
251   // Sets the shapes and types for the device to host transfer associated with
252   // 'key'.
253   Status SetDeviceToHostMetadata(const string& key,
254                                  absl::Span<const DataType> types,
255                                  absl::Span<const TensorShape> shapes);
256 
257   // Gets the shapes the device to host transfer associated with 'key'.
258   Status GetDeviceToHostShapes(const string& key,
259                                std::vector<TensorShape>* shapes) const;
260 
261   // Sets the shapes and types for the host to device transfer associated with
262   // 'key'.
263   Status SetHostToDeviceMetadata(const string& key,
264                                  absl::Span<const DataType> types,
265                                  absl::Span<const TensorShape> shapes);
266 
267   // In order to avoid deadlocks from dependencies in host computations, it can
268   // be necessary to enforce a partial order on the execution of HostCompute
269   // Ops. In particular it may be necessary to constrain the SendToHost for one
270   // HostCompute to run before blocking on the RecvAtHost for another
271   // HostCompute. The compiler maintains a mapping from 'host_compute_name' to
272   // handle, where the handle is an 'output' of the HostCompute Op corresponding
273   // to 'host_compute_name'. Another HostCompute Op that needs to be sequenced
274   // later can add the handle as an 'input' to enforce the constraints.
275   // 'host_compute_name' can be any string the client wishes to use to identify
276   // a given HostCompute Op as long as the names are unique within the
277   // compilation.
278   Status GetHostComputeControlDependency(const string& host_compute_name,
279                                          xla::XlaOp* handle);
280   Status SetHostComputeControlDependency(const string& host_compute_name,
281                                          const xla::XlaOp& handle);
282 
options()283   const Options& options() const { return options_; }
client()284   xla::Client* client() const { return options_.client; }
flib_runtime()285   FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; }
286 
287   void PushNodeTokenMapping();
288   Status PopNodeTokenMapping();
289   Status SetNodeToken(const string& node_name, const xla::XlaOp& op);
290   xla::StatusOr<xla::XlaOp> GetNodeToken(const string& node_name);
291 
292   // Sets the function body `fbody` to the one registered as `function`.
293   Status FindFunctionBody(const NameAttrList& function,
294                           const FunctionBody** fbody,
295                           const ConfigProto** config_proto = nullptr);
296 
297  private:
298   // Returns the optimized graph object in this function body.
299   std::unique_ptr<Graph> GetGraph(const FunctionBody* fbody);
300 
301   // Builds XLA computations for each of the arguments to the computation.
302   // `args` are the arguments to the computation.
303   Status BuildArguments(const Graph& graph,
304                         const std::vector<XlaCompiler::Argument>& args,
305                         bool use_tuple_arg, xla::XlaBuilder* builder,
306                         XlaContext* context,
307                         const std::map<int, xla::OpSharding>& arg_shardings,
308                         std::vector<XlaExpression>* arg_expressions,
309                         std::vector<int>* input_to_args,
310                         std::vector<xla::Shape>* input_shapes,
311                         bool is_entry_computation);
312 
313   // Graph compiler needs to know how to get an optimized graph from a function
314   // body.
315   friend class GraphCompiler;
316   friend class XlaCompilerTest;
317 
318   Options options_;
319 
320   // Status set to non-OK in the constructor if initialization fails.
321   Status initialization_status_;
322 
323   // Returns the next step sequence number.
324   int64 NextStepId();
325 
326   // Internal sequence number for steps executed on the compilation device.
327   int64 next_step_id_;
328 
329   XlaCompilationDevice* device_;  // Owned by device_mgr_
330   StaticDeviceMgr device_mgr_;
331 
332   // To avoid copying the client's function library, use a local function
333   // library and runtime for functions created as part of the functionalize
334   // control flow transformation.
335   std::unique_ptr<FunctionLibraryDefinition> local_flib_def_;
336   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
337   std::unique_ptr<ProcessFunctionLibraryRuntime> local_pflr_;
338 
339   FunctionLibraryRuntime* local_flib_runtime_;  // owned by local_pflr_.
340   FunctionLibraryRuntime* flib_runtime_;        // owned by pflr_.
341 
342   struct SignatureHash {
343     uint64 operator()(
344         const std::pair<string, std::vector<Argument>>& signature) const;
345   };
346 
347   std::unordered_map<std::pair<string, std::vector<Argument>>,
348                      CompilationResult, SignatureHash>
349       cache_;
350 
351   std::unordered_map<string, xla::ChannelHandle> channels_;
352 
353   std::unordered_map<string, tf2xla::HostTransferMetadata> host_compute_sends_;
354   std::unordered_map<string, tf2xla::HostTransferMetadata> host_compute_recvs_;
355 
356   std::unordered_map<string, xla::XlaOp> host_compute_control_output_;
357 
358   // This is used to store <node name, token output> mapping. Side-effecting
359   // ops call SetNodeToken() to record its token output, so later side-effecting
360   // ops can use GetNodeToken() to get it and use it as token input.
361   //
362   // It's a stack because we need a mapping like this for each level of nested
363   // CompileGraph() call. In CompileGraph(), we will push a new mapping to the
364   // stack, and pop the mapping before returning.
365   std::stack<std::map<string, xla::XlaOp>> node_token_mapping_stack_;
366 
367   TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
368 };
369 
370 
371 }  // namespace tensorflow
372 
373 #endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
374