1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file defines helper routines for the XLA device.
17 
18 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_
19 #define TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_
20 
#include <functional>
#include <memory>
#include <set>
#include <vector>

#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/core/framework/tensor.h"
26 
27 namespace tensorflow {
28 
29 // Helper methods for building XLA computations.
30 class XlaHelpers {
31  public:
32   // Returns a handle representing the zero value of a scalar
33   // element of data_type.
34   static xla::XlaOp Zero(xla::XlaBuilder* b, DataType data_type);
35 
36   // Returns a handle representing the one value of a scalar
37   // element of data_type.
38   static xla::XlaOp One(xla::XlaBuilder* b, DataType data_type);
39 
40   // Returns a handle representing the given value of an integer scalar
41   // element of data_type.
42   // Note that unlike One and Zero, does not work on boolean types.
43   static xla::XlaOp IntegerLiteral(xla::XlaBuilder* b, DataType data_type,
44                                    int64 value);
45 
46   // Returns a handle representing the given value of a floating-point scalar
47   // element of data_type.
48   static xla::XlaOp FloatLiteral(xla::XlaBuilder* b, DataType data_type,
49                                  double value);
50 
51   // Reshapes literal 'input' to have 'shape'. Both the original shape and
52   // 'shape' must contain the same number of elements.
53   static Status ReshapeLiteral(const xla::Literal& input,
54                                absl::Span<const int64> shape,
55                                xla::Literal* output);
56 
57   // Converts `indices` into a one-hot representation. `depth` is the size
58   // of the new axis to add. `axis` is the position at which to add the new
59   // axis. `indices_shape` is the shape of `indices`. `on_value` and
60   // `off_value` represent the values to use for the on and off positions,
61   // respectively.
62   static Status OneHot(xla::XlaBuilder* builder, int64 depth, int axis,
63                        DataType index_type, const TensorShape& indices_shape,
64                        const xla::XlaOp& indices, const xla::XlaOp& on_value,
65                        const xla::XlaOp& off_value, xla::XlaOp* one_hot);
66 
67   // Certain DataTypes should use increased precision DataTypes when performing
68   // reductions.  This function remaps a given DataType to a higher precision
69   // DataType if needed.
70   static DataType SumAccumulationType(const DataType& dtype);
71 
72   // A helper for creating a ConvertElementType xla op given a DataType rather
73   // than the xla::PrimitiveType.
74   static xla::XlaOp ConvertElementType(const xla::XlaOp& operand,
75                                        const DataType new_element_type);
76 
77   typedef std::function<xla::StatusOr<xla::Shape>(const TensorShape&, DataType,
78                                                   bool)>
79       ShapeRepresentationFn;
80 };
81 
// Creates an identity shape representation function: presumably one that maps
// a (TensorShape, DataType) pair to the equivalent xla::Shape without layout
// or type changes — confirm against the definition in the .cc file.
XlaHelpers::ShapeRepresentationFn IdentityShapeRepresentationFn();
84 
// Rewrites the layout of xla_shape if there is tiled sharding.
//
// `sharding`: optional HLO sharding annotation; when absent, presumably
//   `*xla_shape` is left unchanged — confirm against the implementation.
// `use_fast_memory`: forwarded to `shape_representation_fn` as its bool
//   argument (see XlaHelpers::ShapeRepresentationFn).
// `xla_shape`: in/out; the shape whose layout is rewritten.
// Returns non-OK status on failure.
Status RewriteLayoutWithShardedShape(
    const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory,
    XlaHelpers::ShapeRepresentationFn shape_representation_fn,
    xla::Shape* xla_shape);
90 
// Adds reshapes to fix the layout of an output, if a shape_representation_fn or
// sharding is present.
//
// `original`: the op producing the output; `original_shape` is its XLA shape.
// `sharding`: optional OpSharding proto for the output.
// `fast_mem`: forwarded to `shape_representation_fn` as its bool argument.
// Returns the (possibly reshaped) op, or an error status.
xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding(
    xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape,
    XlaHelpers::ShapeRepresentationFn shape_representation_fn,
    absl::optional<xla::OpSharding> sharding, bool fast_mem);
97 
98 struct XlaOutputDescription {
99   // Type and shape of the output. The shape is the unflattened shape.
100   // When `type` is DT_RESOURCE, `shape` is the shape of the resource
101   // variable's value.
102   DataType type;
103   TensorShape shape;
104 
105   // Constant output value, if known to be constant at JIT compilation time.
106   // 'Tensor' is in host memory.
107   bool is_constant = false;
108   Tensor constant_value;
109 
110   // When this output is a resource, i.e. `type == DT_RESOURCE`, this is
111   // the index of the input that contains the resource.
112   int input_index;
113 
114   // Whether this output is a TensorList.
115   bool is_tensor_list = false;
116 };
117 
// Describes a variable write side effect of the computation.
struct XlaResourceUpdate {
  // Index of the input that contains the variable resource to write to.
  // NOTE(review): left uninitialized here — reading it before assignment is
  // indeterminate; callers are presumably expected to always set it.
  int input_index;

  // Type and shape of the tensor to be written back.
  // The `shape` field has the same meaning as the Argument::shape field.
  DataType type;
  TensorShape shape;

  // Was the value of the variable modified by the computation?
  // (Always true, unless `return_updated_values_for_all_resources` is true.)
  // NOTE(review): also uninitialized — see `input_index` above.
  bool modified;

  // If the resource is a TensorArray, the set of gradients read or written.
  std::set<string> tensor_array_gradients_accessed;
};
135 
// The result of compiling a TensorFlow subgraph to an XLA computation:
// the computation itself plus the metadata needed to feed it inputs and
// interpret its outputs.
struct XlaCompilationResult {
  // Vector that maps from the parameters of the XLA computation to their
  // original argument positions. To handle compile-time constant inputs, the
  // parameters to the XLA computation may be a subset of the original
  // arguments. The relative ordering of parameters are maintained.
  std::vector<int> input_mapping;

  // Input shapes of the computation. If we are flattening inputs, these are
  // the flattened shapes.
  std::vector<xla::Shape> xla_input_shapes;

  // Output shape in XLA format. The output shape is always a tuple. If we
  // are flattening outputs, these are the flattened shapes.
  xla::Shape xla_output_shape;

  // TensorFlow shapes of outputs, together with the values of any
  // constant arguments. Vector indexed by Tensorflow _Retval number,
  // containing both constant and non-constant results.
  std::vector<XlaOutputDescription> outputs;

  // TensorFlow shapes and types of sends/recvs from HostCompute Ops to their
  // matching RecvAtHost/SendFromHost Ops in the outer graph.
  tf2xla::HostComputeMetadata host_compute_metadata;

  // Resources whose values were updated by the computation, ordered
  // by return value position (which is the same as the order the resources
  // were passed as arguments). Resource updates follow the non-constant
  // results in the outputs of XLA computation.
  std::vector<XlaResourceUpdate> resource_updates;

  // The XLA computation built from the tensorflow subgraph.
  // shared_ptr because compilation results may be cached and shared across
  // callers — presumably; confirm against the compilation cache.
  std::shared_ptr<xla::XlaComputation> computation;
};
169 
170 }  // end namespace tensorflow
171 
172 #endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_
173