1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
17 #define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
18 
19 #include <stack>
20 
21 #include "absl/types/span.h"
22 #include "absl/types/variant.h"
23 #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
24 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
25 #include "tensorflow/compiler/tf2xla/xla_expression.h"
26 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
27 #include "tensorflow/compiler/xla/client/local_client.h"
28 #include "tensorflow/compiler/xla/client/xla_builder.h"
29 #include "tensorflow/compiler/xla/client/xla_computation.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/core/common_runtime/device.h"
32 #include "tensorflow/core/common_runtime/device_mgr.h"
33 #include "tensorflow/core/common_runtime/function.h"
34 #include "tensorflow/core/framework/function.h"
35 #include "tensorflow/core/lib/core/errors.h"
36 #include "tensorflow/core/platform/env.h"
37 #include "tensorflow/core/platform/mutex.h"
38 #include "tensorflow/core/platform/notification.h"
39 #include "tensorflow/core/platform/thread_annotations.h"
40 #include "tensorflow/core/public/version.h"
41 
42 namespace tensorflow {
43 
// Forward declaration: this header only names XlaContext through a pointer
// (see XlaCompiler::BuildArguments), so the full definition is not required.
class XlaContext;
45 
46 // The XlaCompiler class is responsible for compilation of a self-contained
47 // subgraph of a TensorFlow computation using the XLA linear algebra runtime.
48 // It does a symbolic execution of the graph starting from specific input
49 // shapes, using a JIT device to convert operators into XLA computations.
50 //
51 // XlaCompiler is typically invoked from an `XlaLaunch` operator once the
52 // shapes of all input parameters to the computation are known. This is
53 // because the symbolic execution requires known shapes for all operations.
54 //
55 // XlaCompiler compiles Tensorflow graphs that received inputs via _Arg nodes,
56 // and return outputs via _Retval nodes.
57 //
58 // The XlaCompiler requires one Argument struct for each _Arg index, that
59 // describes each argument. Arguments can be compile-time constants
60 // (kind kConstant), run-time parameters (kind kParameter), or resources
61 // (kind kResource).
62 //
63 // Only kParameter and initialized kResource arguments become runtime parameters
64 // to the generated XLA computation.
65 //
66 // The run-time outputs of the XLA computation are arranged in the following
67 // order:
68 //   +------------------+-----------------------------------------+
69 //   |  _Retval values  |  Updated values of kResource arguments  |
70 //   +------------------+-----------------------------------------+
71 // _Retval values are ordered by _Retval index, whereas kResource values are
72 // ordered by the original _Arg position of the variable.
73 //
74 // If a shape representation function is provided as part of
75 // XlaCompiler::CompileOptions, kParameter arguments and return values to an
76 // entry computation will be reshaped in accordance to the shape function.
77 // Arguments and return values to a non-entry computation are not reshaped.
78 // Variable resource arguments are passed and returned in reshaped form, even
79 // for non-entry computations. This feature allows TensorFlow to keep on-device
80 // tensors with a different shape to their representation inside the XLA
81 // computation.
82 //
83 // In computation outputs, updated kResource values are placed the end. When
84 // emitting While loop bodies, we must ensure that the loop body has
85 // identical input and output signatures. By passing variable values
86 // at the end of the argument list and using the
87 // `return_updated_values_for_all_variables` option, we can ensure that the
88 // input and output values of resources appear at the same positions.
89 //
90 // Resources are passed as parameters or returned as resource updates in
91 // "packed" form.
92 // kStack resources are packed as (array, size of stack) XLA tuples.
93 // kTensorArray resources without gradients are packed as the array that
94 // backs the TensorArray. If gradients are present (`tensor_array_gradients`),
95 // the packed representation is a (array, gradient0, gradient1, ...) tuple,
96 // where gradient_k is the value of the k-th gradient in the
97 // `tensor_array_gradients` ordered set.
98 class XlaCompiler {
99  public:
100   // Describes how to derive the value of each _Arg node in the graph/function
101   // being compiled. There must be one Argument for each _Arg index.
102   struct Argument {
103     enum Kind {
104       // Default value; not a valid kind.
105       kInvalid,
106 
107       // Argument is a compile-time constant. No associated runtime parameter.
108       kConstant,
109 
110       // Argument is a Variable, TensorArray, or Stack resource. Has an
111       // associated runtime parameter iff `initialized` is true.
112       kResource,
113 
114       // Argument is a run-time parameter.
115       kParameter,
116 
117       // Argument is an XLA token.
118       kToken,
119     };
120 
121     Kind kind = kInvalid;
122 
123     // The type of the argument. If the argument is a resource, this
124     // is the type of the variable's value, not DT_RESOURCE.
125     DataType type = DT_INVALID;
126 
127     // The shape of the argument. For:
128     // * a parameter: the shape of the parameter. We allow setting the xla shape
129     //   if known. This helps avoid conversions to and from TensorShape.
130     // * a constant: ignored; the shape given by constant_value is used
131     //     instead.
132     // * an uninitialized resource: ignored. We don't yet know the shape of an
133     //     uninitialized resource (otherwise we would have initialized it!)
134     // * an initialized variable: the shape of the variable's value.
135     // * an initialized TensorArray or Stack resource: the shape of an entry in
136     //   the TensorArray/Stack. Note this is the size of a single entry, not the
137     //   XLA data structure that represents the complete stack/array.
138     absl::variant<TensorShape, xla::Shape> shape;
139 
140     // The value of the argument, if it is a compile-time constant. Must be a
141     // host-memory tensor.
142     Tensor constant_value;
143 
144     // The name of this argument, used for debugging.
145     string name;
146 
147     // For a kResource, what kind of resource is it?
148     XlaResource::Kind resource_kind = XlaResource::kInvalid;
149 
150     // For a kResource, has this resource been initialized?
151     bool initialized = false;
152 
153     // For a TensorArray or Stack resource, what is the array's declared size?
154     // (Used for lazy initialization.)
155     int64 max_array_size = -1;
156 
157     // TensorArray resource parameters are passed as (array, gradient array 0,
158     // ..., gradient array k), where the gradient arrays are in the same order
159     // as `tensor_array_gradients`.
160     std::set<string> tensor_array_gradients;
161 
162     // dynamic dims to arg number map. Empty if no dynamic shapes.
163     std::map<int32, int32> dynamic_dim_to_arg_num_map;
164     bool is_pad_arg = false;
165 
166     bool operator==(const Argument& other) const;
167 
168     // Returns a human-readable summary of the argument.
169     string HumanString() const;
170 
171     // Returns the dimension sizes for either TensorShape or xla::Shape.
172     std::vector<int64> DimensionSizes() const;
173 
174     // Returns the human-readable string for either TensorShape or xla::Shape.
175     string ShapeHumanString() const;
176   };
177 
178   // Options pertaining to an individual call to CompileGraph() or
179   // CompileFunction().
180   struct CompileOptions {
181     // If `use_tuple_arg` is true, a single tuple parameter will be used for all
182     // arguments; if false, each argument gets its own parameter.
183     bool use_tuple_arg = false;
184 
185     // If 'return_updated_values_for_all_resources' is true, then updated
186     // values of all resource arguments will be included in the
187     // 'resource_updates' of the computation, even if the resource was not
188     // modified by the computation. Used when compiling loop bodies to ensure
189     // the input and output signatures match.
190     bool return_updated_values_for_all_resources = false;
191 
192     // If 'resolve_compile_time_constants' is true, then outputs of a
193     // computation that are known to be compile-time constants will be returned
194     // as Tensors at compile-time, rather than as run-time outputs of the
195     // computation.
196     bool resolve_compile_time_constants = true;
197 
198     // If 'always_return_tuple' is true, then the output of a computation will
199     // always be a tuple. Otherwise, a single-element output will not be wrapped
200     // in a tuple.
201     bool always_return_tuple = true;
202 
203     // True when compiling the entry computation, false for subcomputations
204     // (while, call, etc.)
205     bool is_entry_computation = true;
206 
207     // True when we should add XLA input & output to the graph/function.
208     bool add_token_input_output = false;
209   };
210 
211   struct OutputDescription {
212     // Type and shape of the output. The shape is the unflattened shape.
213     // When `type` is DT_RESOURCE, `shape` is the shape of the resource
214     // variable's value.
215     DataType type;
216     TensorShape shape;
217 
218     // Constant output value, if known to be constant at JIT compilation time.
219     // 'Tensor' is in host memory.
220     bool is_constant = false;
221     Tensor constant_value;
222 
223     // When this output is a resource, i.e. `type == DT_RESOURCE`, this is
224     // the index of the input that contains the resource.
225     int input_index;
226   };
227 
228   // Describes a variable write side effect of the computation.
229   struct ResourceUpdate {
230     // Index of the input that contains the variable resource to write to.
231     int input_index;
232 
233     // Type and shape of the tensor to be written back.
234     // The `shape` field has the same meaning as the Argument::shape field.
235     DataType type;
236     TensorShape shape;
237 
238     // Was the value of the variable modified by the computation?
239     // (Always true, unless `return_updated_values_for_all_resources` is true.)
240     bool modified;
241 
242     // If the resource is a TensorArray, the set of gradients read or written.
243     std::set<string> tensor_array_gradients_accessed;
244   };
245 
246   struct CompilationResult {
247     // Vector that maps from the parameters of the XLA computation to their
248     // original argument positions. To handle compile-time constant inputs, the
249     // parameters to the XLA computation may be a subset of the original
250     // arguments. The relative ordering of parameters are maintained.
251     std::vector<int> input_mapping;
252 
253     // Input shapes of the computation. If we are flattening inputs, these are
254     // the flattened shapes.
255     std::vector<xla::Shape> xla_input_shapes;
256 
257     // Output shape in XLA format. The output shape is always a tuple. If we
258     // are flattening outputs, these are the flattened shapes.
259     xla::Shape xla_output_shape;
260 
261     // TensorFlow shapes of outputs, together with the values of any
262     // constant arguments. Vector indexed by Tensorflow _Retval number,
263     // containing both constant and non-constant results.
264     std::vector<OutputDescription> outputs;
265 
266     // TensorFlow shapes and types of sends/recvs from HostCompute Ops to their
267     // matching RecvAtHost/SendFromHost Ops in the outer graph.
268     tf2xla::HostComputeMetadata host_compute_metadata;
269 
270     // Resources whose values were updated by the computation, ordered
271     // by return value position (which is the same as the order the resources
272     // were passed as arguments). Resource updates follow the non-constant
273     // results in the outputs of XLA computation.
274     std::vector<ResourceUpdate> resource_updates;
275 
276     // The XLA computation built from the tensorflow subgraph.
277     std::shared_ptr<xla::XlaComputation> computation;
278   };
279 
280   typedef std::function<xla::StatusOr<xla::Shape>(const TensorShape&, DataType)>
281       ShapeRepresentationFn;
282   struct Options {
283     // Name of the compilation device to use. It must be set by the caller.
284     // The default empty value is invalid.
285     DeviceType device_type = DeviceType("");
286 
287     // The device to use during compilation to execute instructions on, for
288     // example for auto-tuning.
289     // Valid values are defined by `xla::Backend::devices_ordinal_supported()`.
290     // -1 indicates the default device should be used.
291     int device_ordinal = -1;
292 
293     xla::Client* client = nullptr;
294 
295     // Function library in which to find function definitions. Must be non-null.
296     const FunctionLibraryDefinition* flib_def = nullptr;
297 
298     // The graph def version to be compiled.
299     int graph_def_version = TF_GRAPH_DEF_VERSION;
300 
301     // If 'allow_cpu_custom_calls' is true, kernels may make use of CustomCall()
302     // for CPU.
303     bool allow_cpu_custom_calls = false;
304 
305     // If set, the XLA representation of variables represented to XLA as the
306     // shape given by this shape function. Variables are reshaped to this shape
307     // on write, and reshaped to their original shape on read.
308     ShapeRepresentationFn shape_representation_fn;
309 
310     // If not nullptr, populate_resource_manager is called with the
311     // compilation device's resource manager when the compilation
312     // device is created, and can be used to create metadata objects
313     // that can be accessed by XLA op kernels.
314     std::function<Status(ResourceMgr*)>* populate_resource_manager = nullptr;
315 
316     // If not nullptr, this memory allocator can be used by the compiler for
317     // temporary allocations it might want to make during compilation.
318     //
319     // For example, the compiler may want to try out different algorithms and
320     // choose the fastest one, and it might run those algorithms over buffers
321     // created using this allocator.
322     //
323     // The compiler can function correctly without an explicit allocator given
324     // here, but on some devices (notably, GPUs), TensorFlow tends to eagerly
325     // allocate most or all available memory on the device, leaving none for the
326     // compiler to access, unless it can use TensorFlow's allocator.
327     xla::DeviceMemoryAllocator* device_allocator = nullptr;
328   };
329 
330   explicit XlaCompiler(Options options);
331 
332   ~XlaCompiler();
333 
334   Status CompileFunction(const CompileOptions& options,
335                          const NameAttrList& fn_name_attrs,
336                          absl::Span<const Argument> args,
337                          CompilationResult* result);
338 
339   // Compiles a tensorflow::Graph into an xla::XlaComputation.
340   // Similar to CompileFunction, but takes a Graph as input rather than a
341   // function.
342   Status CompileGraph(
343       const CompileOptions& options, string const& name,
344       std::unique_ptr<Graph> graph, absl::Span<const Argument> args,
345       absl::Span<const xla::XlaBuilder::InputOutputAlias> user_aliases,
346       CompilationResult* result);
347 
348   // Compiles a single Op, given by `node_def`, into an
349   // xla::XlaComputation. Similar to CompileFunction but takes a single Op as
350   // input.
351   Status CompileSingleOp(const CompileOptions& options, const NodeDef& node_def,
352                          absl::Span<const Argument> args,
353                          absl::Span<const DataType> result_types,
354                          CompilationResult* result);
355 
356   // Returns the shape of the XLA parameter for an argument 'arg'.
357   // See the class comment for more details about the argument passing
358   // convention.
359   Status XLAShapeForArgument(const Argument& arg, bool is_entry_computation,
360                              xla::Shape* xla_shape) const;
361 
362   // Retrieves the channel handle associated with `key`. Allocates
363   // a new channel handle if none exists.
364   // Channel handles can be used to communicate between different
365   // computations. Computations that communicate should be compiled with the
366   // same XlaCompiler.
367   Status GetChannelHandle(const string& key, xla::ChannelHandle* channel);
368 
369   // Retrieves the host-to-device channel handle associated with `key`.
370   // Allocates a new channel handle if none exists.
371   Status GetHostToDeviceChannelHandle(const string& key,
372                                       xla::ChannelHandle* channel);
373 
374   // Retrieves the device-to-host channel handle associated with `key`.
375   // Allocates a new channel handle if none exists.
376   Status GetDeviceToHostChannelHandle(const string& key,
377                                       xla::ChannelHandle* channel);
378 
379   // Sets the shapes and types for the device to host transfer associated with
380   // 'key'.
381   Status SetDeviceToHostMetadata(const string& key,
382                                  absl::Span<const DataType> types,
383                                  absl::Span<const TensorShape> shapes);
384 
385   // Gets the shapes the device to host transfer associated with 'key'.
386   Status GetDeviceToHostShapes(const string& key,
387                                std::vector<TensorShape>* shapes) const;
388 
389   // Sets the shapes and types for the host to device transfer associated with
390   // 'key'.
391   Status SetHostToDeviceMetadata(const string& key,
392                                  absl::Span<const DataType> types,
393                                  absl::Span<const TensorShape> shapes);
394 
395   // In order to avoid deadlocks from dependencies in host computations, it can
396   // be necessary to enforce a partial order on the execution of HostCompute
397   // Ops. In particular it may be necessary to constrain the SendToHost for one
398   // HostCompute to run before blocking on the RecvAtHost for another
399   // HostCompute. The compiler maintains a mapping from 'host_compute_name' to
400   // handle, where the handle is an 'output' of the HostCompute Op corresponding
401   // to 'host_compute_name'. Another HostCompute Op that needs to be sequenced
402   // later can add the handle as an 'input' to enforce the constraints.
403   // 'host_compute_name' can be any string the client wishes to use to identify
404   // a given HostCompute Op as long as the names are unique within the
405   // compilation.
406   Status GetHostComputeControlDependency(const string& host_compute_name,
407                                          xla::XlaOp* handle);
408   Status SetHostComputeControlDependency(const string& host_compute_name,
409                                          const xla::XlaOp& handle);
410 
options()411   const Options& options() const { return options_; }
client()412   xla::Client* client() const { return options_.client; }
flib_runtime()413   FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; }
414 
415   void PushNodeTokenMapping();
416   Status PopNodeTokenMapping();
417   Status SetNodeToken(const string& node_name, const xla::XlaOp& op);
418   xla::StatusOr<xla::XlaOp> GetNodeToken(const string& node_name);
419 
420   // Sets the function body `fbody` to the one registered as `function`.
421   Status FindFunctionBody(const NameAttrList& function,
422                           const FunctionBody** fbody);
423 
424  private:
425   // Returns the optimized graph object in this function body.
426   std::unique_ptr<Graph> GetGraph(const FunctionBody* fbody);
427 
428   // Builds XLA computations for each of the arguments to the computation.
429   // `args` are the arguments to the computation.
430   Status BuildArguments(const Graph& graph,
431                         const std::vector<XlaCompiler::Argument>& args,
432                         bool use_tuple_arg, xla::XlaBuilder* builder,
433                         XlaContext* context,
434                         const std::map<int, int>& arg_cores,
435                         std::vector<XlaExpression>* arg_expressions,
436                         std::vector<int>* input_to_args,
437                         std::vector<xla::Shape>* input_shapes,
438                         bool is_entry_computation);
439 
440   // Graph compiler needs to know how to get an optimized graph from a function
441   // body.
442   friend class GraphCompiler;
443   friend class XlaCompilerTest;
444 
445   Options options_;
446 
447   // Status set to non-OK in the constructor if initialization fails.
448   Status initialization_status_;
449 
450   // Returns the next step sequence number.
451   int64 NextStepId();
452 
453   // Internal sequence number for steps executed on the compilation device.
454   int64 next_step_id_;
455 
456   XlaCompilationDevice* device_;  // Owned by device_mgr_
457   DeviceMgr device_mgr_;
458 
459   // To avoid copying the client's function library, use a local function
460   // library and runtime for functions created as part of the functionalize
461   // control flow transformation.
462   std::unique_ptr<FunctionLibraryDefinition> local_flib_def_;
463   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
464   std::unique_ptr<ProcessFunctionLibraryRuntime> local_pflr_;
465 
466   FunctionLibraryRuntime* local_flib_runtime_;  // owned by local_pflr_.
467   FunctionLibraryRuntime* flib_runtime_;        // owned by pflr_.
468 
469   struct SignatureHash {
470     uint64 operator()(
471         const std::pair<string, std::vector<Argument>>& signature) const;
472   };
473 
474   std::unordered_map<std::pair<string, std::vector<Argument>>,
475                      CompilationResult, SignatureHash>
476       cache_;
477 
478   std::unordered_map<string, xla::ChannelHandle> channels_;
479 
480   std::unordered_map<string, tf2xla::HostTransferMetadata> host_compute_sends_;
481   std::unordered_map<string, tf2xla::HostTransferMetadata> host_compute_recvs_;
482 
483   std::unordered_map<string, xla::XlaOp> host_compute_control_output_;
484 
485   // This is used to store <node name, token output> mapping. Side-effecting
486   // ops call SetNodeToken() to record its token output, so later side-effecting
487   // ops can use GetNodeToken() to get it and use it as token input.
488   //
489   // It's a stack because we need a mapping like this for each level of nested
490   // CompileGraph() call. In CompileGraph(), we will push a new mapping to the
491   // stack, and pop the mapping before returning.
492   std::stack<std::map<string, xla::XlaOp>> node_token_mapping_stack_;
493 
494   TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
495 };
496 
497 }  // namespace tensorflow
498 
499 #endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
500