1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"

#include <cstdint>
#include <memory>
#include <numeric>
#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
28
29 namespace xla {
30
CreateScalarComputation(const string & name,PrimitiveType type,XlaBuilder * builder,XlaOpGenerator generator)31 XlaComputation CreateScalarComputation(const string& name, PrimitiveType type,
32 XlaBuilder* builder,
33 XlaOpGenerator generator) {
34 std::unique_ptr<XlaBuilder> b;
35 if (type == PRED) {
36 b = builder->CreateSubBuilder(name);
37 } else {
38 b = builder->CreateSubBuilder(
39 absl::StrCat(name, "_", PrimitiveType_Name(type)));
40 }
41
42 const Shape scalar = ShapeUtil::MakeShape(type, {});
43 auto lhs = Parameter(b.get(), 0, scalar, "lhs");
44 auto rhs = Parameter(b.get(), 1, scalar, "rhs");
45 generator(lhs, rhs);
46 return b->BuildAndNoteError();
47 }
48
CreateScalarAddComputation(PrimitiveType type,XlaBuilder * builder)49 XlaComputation CreateScalarAddComputation(PrimitiveType type,
50 XlaBuilder* builder) {
51 return CreateScalarComputation(
52 "add", type, builder, [](XlaOp lhs, XlaOp rhs) { return Add(lhs, rhs); });
53 }
54
CreateScalarMultiplyComputation(PrimitiveType type,XlaBuilder * builder)55 XlaComputation CreateScalarMultiplyComputation(PrimitiveType type,
56 XlaBuilder* builder) {
57 return CreateScalarComputation(
58 "mul", type, builder, [](XlaOp lhs, XlaOp rhs) { return Mul(lhs, rhs); });
59 }
60
CreateScalarGeComputation(PrimitiveType type,XlaBuilder * builder)61 XlaComputation CreateScalarGeComputation(PrimitiveType type,
62 XlaBuilder* builder) {
63 return CreateScalarComputation(
64 "ge", type, builder, [](XlaOp lhs, XlaOp rhs) { return Ge(lhs, rhs); });
65 }
66
CreateScalarMaxComputation(PrimitiveType type,XlaBuilder * builder)67 XlaComputation CreateScalarMaxComputation(PrimitiveType type,
68 XlaBuilder* builder) {
69 return CreateScalarComputation(
70 "max", type, builder, [](XlaOp lhs, XlaOp rhs) { return Max(lhs, rhs); });
71 }
72
CreateScalarMinComputation(PrimitiveType type,XlaBuilder * builder)73 XlaComputation CreateScalarMinComputation(PrimitiveType type,
74 XlaBuilder* builder) {
75 return CreateScalarComputation(
76 "min", type, builder, [](XlaOp lhs, XlaOp rhs) { return Min(lhs, rhs); });
77 }
78
CreateScalarAndComputation(PrimitiveType type,XlaBuilder * builder)79 XlaComputation CreateScalarAndComputation(PrimitiveType type,
80 XlaBuilder* builder) {
81 return CreateScalarComputation(
82 "and", type, builder, [](XlaOp lhs, XlaOp rhs) { return And(lhs, rhs); });
83 }
84
CreateScalarOrComputation(PrimitiveType type,XlaBuilder * builder)85 XlaComputation CreateScalarOrComputation(PrimitiveType type,
86 XlaBuilder* builder) {
87 return CreateScalarComputation(
88 "or", type, builder, [](XlaOp lhs, XlaOp rhs) { return Or(lhs, rhs); });
89 }
90
CreateScalarIdentityWithZeroComputation(PrimitiveType type,XlaBuilder * builder)91 XlaComputation CreateScalarIdentityWithZeroComputation(PrimitiveType type,
92 XlaBuilder* builder) {
93 XlaComputation reducer =
94 (primitive_util::IsIntegralType(type) || type == PRED)
95 ? CreateScalarOrComputation(type, builder)
96 : CreateScalarAddComputation(type, builder);
97 return reducer;
98 }
99
Any(XlaOp predicates)100 XlaOp Any(XlaOp predicates) {
101 XlaBuilder* builder = predicates.builder();
102 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
103 auto f = ConstantR0<bool>(builder, false);
104 XlaComputation logical_or = CreateScalarOrComputation(PRED, builder);
105 TF_ASSIGN_OR_RETURN(const Shape& predicates_shape,
106 builder->GetShape(predicates));
107 std::vector<int64> all_dimensions(predicates_shape.rank());
108 std::iota(all_dimensions.begin(), all_dimensions.end(), 0);
109 return Reduce(predicates, f, logical_or, all_dimensions);
110 });
111 }
112
113 namespace {
114
CreateMinMaxComputation(XlaBuilder * outer_builder,PrimitiveType value_type,PrimitiveType index_type,bool is_min,bool stable,bool tie_low)115 XlaComputation CreateMinMaxComputation(XlaBuilder* outer_builder,
116 PrimitiveType value_type,
117 PrimitiveType index_type, bool is_min,
118 bool stable, bool tie_low) {
119 auto sub_builder = outer_builder->CreateSubBuilder("minmax_func");
120 XlaBuilder* b = sub_builder.get();
121 XlaOp lhs_value =
122 Parameter(b, 0, ShapeUtil::MakeShape(value_type, {}), "lhs_value");
123 XlaOp lhs_index =
124 Parameter(b, 1, ShapeUtil::MakeShape(index_type, {}), "lhs_index");
125 XlaOp rhs_value =
126 Parameter(b, 2, ShapeUtil::MakeShape(value_type, {}), "rhs_value");
127 XlaOp rhs_index =
128 Parameter(b, 3, ShapeUtil::MakeShape(index_type, {}), "rhs_index");
129
130 XlaOp cmp = is_min ? Le(lhs_value, rhs_value) : Ge(lhs_value, rhs_value);
131 XlaOp max = Select(cmp, lhs_value, rhs_value);
132 XlaOp arg_max = Select(cmp, lhs_index, rhs_index);
133 if (stable) {
134 XlaOp eq = Eq(lhs_value, rhs_value);
135 XlaOp tie_id =
136 tie_low ? Min(lhs_index, rhs_index) : Max(lhs_index, rhs_index);
137 arg_max = Select(eq, tie_id, arg_max);
138 }
139 Tuple(b, {max, arg_max});
140 return b->BuildAndNoteError();
141 }
142
ArgMinMax(XlaOp input,PrimitiveType output_type,int axis,bool is_min,bool stable,bool tie_low)143 XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min,
144 bool stable, bool tie_low) {
145 XlaBuilder* builder = input.builder();
146 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
147 TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
148 XlaOp value_init_value;
149 if (is_min) {
150 value_init_value = MaxValue(builder, input_shape.element_type());
151 } else {
152 value_init_value = MinValue(builder, input_shape.element_type());
153 }
154 int64 dimension_size = input_shape.dimensions(axis);
155 auto index_type = dimension_size <= INT32_MAX ? S32 : output_type;
156 XlaOp index_init_value = Zero(builder, index_type);
157 auto iota_shape = input_shape;
158 iota_shape.set_element_type(index_type);
159 XlaOp iota = Iota(builder, iota_shape, axis);
160
161 XlaComputation reducer =
162 CreateMinMaxComputation(builder, input_shape.element_type(), index_type,
163 is_min, stable, tie_low);
164 XlaOp max_argmax = Reduce(builder, {input, iota},
165 {value_init_value, index_init_value}, reducer,
166 /*dimensions_to_reduce=*/{axis});
167 XlaOp argmax = GetTupleElement(max_argmax, 1);
168 if (index_type != output_type) {
169 argmax = ConvertElementType(argmax, output_type);
170 }
171 return argmax;
172 });
173 }
174
ArgMinMaxTwoPass(XlaOp input,PrimitiveType output_type,int axis,bool is_min,bool tie_low)175 XlaOp ArgMinMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis,
176 bool is_min, bool tie_low) {
177 XlaBuilder* builder = input.builder();
178 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
179 TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
180 XlaOp init_value;
181 XlaComputation reducer;
182 if (is_min) {
183 init_value = MaxValue(builder, input_shape.element_type());
184 reducer = CreateScalarMinComputation(input_shape.element_type(), builder);
185 } else {
186 init_value = MinValue(builder, input_shape.element_type());
187 reducer = CreateScalarMaxComputation(input_shape.element_type(), builder);
188 }
189
190 XlaOp iota = Iota(
191 builder, ShapeUtil::ChangeElementType(input_shape, output_type), axis);
192 XlaOp reduced_input = Reduce(input, init_value, reducer,
193 /*dimensions_to_reduce=*/{axis});
194 std::vector<int64> broadcast_dims(input_shape.rank() - 1);
195 std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
196 std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
197 if (tie_low) {
198 XlaOp max_idx = MaxValue(builder, output_type);
199 XlaOp select_mask = Select(Eq(input, reduced_input, broadcast_dims),
200 /*on_true=*/iota,
201 /*on_false=*/
202 max_idx);
203 return Reduce(select_mask, max_idx,
204 CreateScalarMinComputation(output_type, builder),
205 /*dimensions_to_reduce=*/{axis});
206 } else {
207 XlaOp min_idx = MinValue(builder, output_type);
208 XlaOp select_mask = Select(Eq(input, reduced_input, broadcast_dims),
209 /*on_true=*/iota,
210 /*on_false=*/
211 min_idx);
212 return Reduce(select_mask, min_idx,
213 CreateScalarMaxComputation(output_type, builder),
214 /*dimensions_to_reduce=*/{axis});
215 }
216 });
217 }
218 } // namespace
219
ArgMax(XlaOp input,PrimitiveType output_type,int axis,bool stable,bool tie_low)220 XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis, bool stable,
221 bool tie_low) {
222 return ArgMinMax(input, output_type, axis, /*is_min=*/false, stable, tie_low);
223 }
224
ArgMin(XlaOp input,PrimitiveType output_type,int axis,bool stable,bool tie_low)225 XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis, bool stable,
226 bool tie_low) {
227 return ArgMinMax(input, output_type, axis, /*is_min=*/true, stable, tie_low);
228 }
229
ArgMaxTwoPass(XlaOp input,PrimitiveType output_type,int axis,bool tie_low)230 XlaOp ArgMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis,
231 bool tie_low) {
232 return ArgMinMaxTwoPass(input, output_type, axis, /*is_min=*/false, tie_low);
233 }
234
ArgMinTwoPass(XlaOp input,PrimitiveType output_type,int axis,bool tie_low)235 XlaOp ArgMinTwoPass(XlaOp input, PrimitiveType output_type, int axis,
236 bool tie_low) {
237 return ArgMinMaxTwoPass(input, output_type, axis, /*is_min=*/true, tie_low);
238 }
239 } // namespace xla
240