1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file defines helper routines for XLA compilation.
17 
18 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
19 #include "tensorflow/compiler/tf2xla/lib/util.h"
20 
21 #include "absl/types/span.h"
22 #include "tensorflow/compiler/tf2xla/literal_util.h"
23 #include "tensorflow/compiler/tf2xla/shape_util.h"
24 #include "tensorflow/compiler/tf2xla/type_util.h"
25 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
26 #include "tensorflow/compiler/xla/client/lib/constants.h"
27 #include "tensorflow/compiler/xla/client/xla_builder.h"
28 #include "tensorflow/compiler/xla/client/xla_computation.h"
29 #include "tensorflow/compiler/xla/types.h"
30 #include "tensorflow/core/framework/tensor.h"
31 #include "tensorflow/core/lib/core/status.h"
32 
33 namespace tensorflow {
34 
Zero(xla::XlaBuilder * b,DataType data_type)35 xla::XlaOp XlaHelpers::Zero(xla::XlaBuilder* b, DataType data_type) {
36   xla::PrimitiveType type;
37   TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
38   return xla::ConstantLiteral(b, xla::LiteralUtil::Zero(type));
39 }
40 
One(xla::XlaBuilder * b,DataType data_type)41 xla::XlaOp XlaHelpers::One(xla::XlaBuilder* b, DataType data_type) {
42   xla::PrimitiveType type;
43   TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
44   return xla::ConstantLiteral(b, xla::LiteralUtil::One(type));
45 }
46 
IntegerLiteral(xla::XlaBuilder * b,DataType data_type,int64 value)47 xla::XlaOp XlaHelpers::IntegerLiteral(xla::XlaBuilder* b, DataType data_type,
48                                       int64 value) {
49   xla::PrimitiveType type;
50   TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
51   return ::tensorflow::IntegerLiteral(b, type, value);
52 }
53 
FloatLiteral(xla::XlaBuilder * b,DataType data_type,double value)54 xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type,
55                                     double value) {
56   xla::PrimitiveType type;
57   TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
58   return ::tensorflow::FloatLiteral(b, type, value);
59 }
60 
ReshapeLiteral(const xla::Literal & input,absl::Span<const int64> dimensions,xla::Literal * output)61 /* static */ Status XlaHelpers::ReshapeLiteral(
62     const xla::Literal& input, absl::Span<const int64> dimensions,
63     xla::Literal* output) {
64   if (input.shape().IsTuple()) {
65     return errors::InvalidArgument("ReshapeLiteral does not support tuples.");
66   }
67   xla::Shape shape =
68       xla::ShapeUtil::MakeShape(input.shape().element_type(), dimensions);
69   int64 elements_before = xla::ShapeUtil::ElementsIn(input.shape());
70   int64 elements_after = xla::ShapeUtil::ElementsIn(shape);
71   if (elements_before != elements_after) {
72     return errors::InvalidArgument(
73         "Shapes before and after ReshapeLiteral have different numbers of "
74         "elements.");
75   }
76 
77   *output = input.Clone();
78   output->mutable_shape_do_not_use()->Swap(&shape);
79   return Status::OK();
80 }
81 
OneHot(xla::XlaBuilder * builder,int64 depth,int axis,DataType index_type,const TensorShape & indices_shape,const xla::XlaOp & indices,const xla::XlaOp & on_value,const xla::XlaOp & off_value,xla::XlaOp * one_hot)82 Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis,
83                           DataType index_type, const TensorShape& indices_shape,
84                           const xla::XlaOp& indices, const xla::XlaOp& on_value,
85                           const xla::XlaOp& off_value, xla::XlaOp* one_hot) {
86   // Broadcast the linspace constant across the indices along the new axis,
87   // and test equality at each position.
88   std::vector<int64> broadcast_dims(indices_shape.dims());
89   std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
90   std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
91 
92   TensorShape output_shape = indices_shape;
93   output_shape.InsertDim(axis, depth);
94   xla::Shape iota_shape;
95   TF_RETURN_IF_ERROR(
96       TensorShapeToXLAShape(index_type, output_shape, &iota_shape));
97 
98   // Selects the user-provided off_value and on_value values.
99   *one_hot = xla::Select(
100       xla::Eq(indices, xla::Iota(builder, iota_shape, axis), broadcast_dims),
101       xla::Broadcast(on_value, output_shape.dim_sizes()),
102       xla::Broadcast(off_value, output_shape.dim_sizes()));
103   return Status::OK();
104 }
105 
// Returns the data type in which sums over values of `dtype` should be
// accumulated. Types not listed below are returned unchanged.
DataType XlaHelpers::SumAccumulationType(const DataType& dtype) {
  switch (dtype) {
    // 16 bit floating point sums are accumulated in 32 bits to reduce the
    // precision loss from repeated floating point additions.
    case DT_BFLOAT16:
    case DT_HALF:
      return DT_FLOAT;
    // Small integer types are widened to 32 bits to avoid overflow.
    case DT_INT8:
    case DT_INT16:
      return DT_INT32;
    case DT_UINT8:
    case DT_UINT16:
      return DT_UINT32;
    default:
      return dtype;
  }
}
121 
// Converts `operand` to the XLA primitive type corresponding to the TensorFlow
// type `new_element_type` by delegating to xla::ConvertElementType. The type
// lookup is wrapped in TF_CHECK_OK.
xla::XlaOp XlaHelpers::ConvertElementType(const xla::XlaOp& operand,
                                          const DataType new_element_type) {
  xla::PrimitiveType target_type;
  TF_CHECK_OK(DataTypeToPrimitiveType(new_element_type, &target_type));
  xla::XlaOp converted = xla::ConvertElementType(operand, target_type);
  return converted;
}
128 
IdentityShapeRepresentationFn()129 XlaHelpers::ShapeRepresentationFn IdentityShapeRepresentationFn() {
130   return [](const TensorShape& shape, DataType dtype,
131             bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
132     xla::Shape xla_shape;
133     TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
134     return xla_shape;
135   };
136 }
137 
138 // Rewrites the layout of xla_shape if there is tiled sharding.
RewriteLayoutWithShardedShape(const absl::optional<xla::HloSharding> & sharding,bool use_fast_memory,XlaHelpers::ShapeRepresentationFn shape_representation_fn,xla::Shape * xla_shape)139 Status RewriteLayoutWithShardedShape(
140     const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory,
141     XlaHelpers::ShapeRepresentationFn shape_representation_fn,
142     xla::Shape* xla_shape) {
143   if (sharding && !sharding->IsTileMaximal() && !sharding->IsManual()) {
144     // After sharding, per core shape might have different layout. For example,
145     // before sharding, a shape [128, 128] will be assigned default
146     // minor-to-major {1, 0}. But after we shard this shape to [128, 64] * 2,
147     // the sharded shapes will have minor-to-major {0, 1}.
148     //
149     // As a result, for sharded shapes, we set their layout to per core shape's
150     // layout.
151     //
152     // TODO(endlessroad): for variable input & update, we might have
153     // different layouts which will prevent input output aliasing and
154     // increase memory usage. Investigate such cases.
155     int64 device = *sharding->tile_assignment().begin();
156     std::vector<int64> offset =
157         sharding->TileOffsetForDevice(*xla_shape, device);
158     std::vector<int64> limit = sharding->TileLimitForDevice(*xla_shape, device);
159     std::vector<int64> dimensions(xla_shape->rank());
160     for (int64 i = 0; i < xla_shape->rank(); ++i) {
161       dimensions[i] = limit[i] - offset[i];
162     }
163     xla::Shape per_device_xla_shape =
164         xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions);
165     TensorShape per_device_tensor_shape;
166     TF_RETURN_IF_ERROR(
167         XLAShapeToTensorShape(per_device_xla_shape, &per_device_tensor_shape));
168     TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
169                                             xla_shape->element_type()));
170     TF_ASSIGN_OR_RETURN(per_device_xla_shape,
171                         shape_representation_fn(per_device_tensor_shape, dtype,
172                                                 use_fast_memory));
173     *xla_shape->mutable_layout() = per_device_xla_shape.layout();
174   }
175   return Status::OK();
176 }
177 
178 // There is a shape_representation_fn or sharding for an output, this function
179 // uses a reshape to fix the layout.
ReshapeWithCorrectRepresentationAndSharding(xla::XlaBuilder * builder,xla::XlaOp original,xla::Shape original_shape,XlaHelpers::ShapeRepresentationFn shape_representation_fn,absl::optional<xla::OpSharding> sharding,bool fast_mem)180 xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding(
181     xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape,
182     XlaHelpers::ShapeRepresentationFn shape_representation_fn,
183     absl::optional<xla::OpSharding> sharding, bool fast_mem) {
184   if (original_shape.IsTuple()) {
185     std::vector<xla::XlaOp> elements;
186     for (int64 i = 0; i < original_shape.tuple_shapes_size(); ++i) {
187       auto subsharding = sharding ? sharding->tuple_shardings(i) : sharding;
188       TF_ASSIGN_OR_RETURN(auto element,
189                           ReshapeWithCorrectRepresentationAndSharding(
190                               builder, xla::GetTupleElement(original, i),
191                               original_shape.tuple_shapes(i),
192                               shape_representation_fn, subsharding, fast_mem));
193       elements.push_back(element);
194     }
195     return xla::Tuple(builder, elements);
196   }
197   if (!original_shape.IsArray()) return original;
198   TensorShape shape;
199   TF_RETURN_IF_ERROR(XLAShapeToTensorShape(original_shape, &shape));
200   TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
201                                           original_shape.element_type()));
202   TF_ASSIGN_OR_RETURN(auto to_shape,
203                       shape_representation_fn(shape, dtype, fast_mem));
204   if (sharding) {
205     TF_ASSIGN_OR_RETURN(auto hlo_sharding,
206                         xla::HloSharding::FromProto(*sharding));
207     TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
208         hlo_sharding, fast_mem, shape_representation_fn, &to_shape));
209   }
210   if (xla::ShapeUtil::Compatible(original_shape, to_shape)) {
211     for (int64 i = 0; i < original_shape.rank(); ++i) {
212       to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i));
213     }
214   }
215   return xla::Reshape(to_shape, original);
216 }
217 
218 }  // end namespace tensorflow
219