1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
17 
18 #include <string>
19 
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/compiler/xla/client/lib/constants.h"
22 #include "tensorflow/compiler/xla/client/xla_builder.h"
23 #include "tensorflow/compiler/xla/client/xla_computation.h"
24 #include "tensorflow/compiler/xla/shape_util.h"
25 #include "tensorflow/compiler/xla/status_macros.h"
26 #include "tensorflow/compiler/xla/types.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28 
29 namespace xla {
30 namespace {
31 
32 using XlaOpGenerator = XlaOp (*)(XlaBuilder*, const XlaOp&, const XlaOp&);
33 
CreateScalarComputation(const string & name,PrimitiveType type,XlaBuilder * builder,XlaOpGenerator generator)34 XlaComputation CreateScalarComputation(const string& name, PrimitiveType type,
35                                        XlaBuilder* builder,
36                                        XlaOpGenerator generator) {
37   std::unique_ptr<XlaBuilder> b;
38   if (type == PRED) {
39     b = builder->CreateSubBuilder(name);
40   } else {
41     b = builder->CreateSubBuilder(
42         absl::StrCat(name, "_", PrimitiveType_Name(type)));
43   }
44 
45   const Shape scalar = ShapeUtil::MakeShape(type, {});
46   auto lhs = Parameter(b.get(), 0, scalar, "lhs");
47   auto rhs = Parameter(b.get(), 1, scalar, "rhs");
48   generator(b.get(), lhs, rhs);
49   return b->BuildAndNoteError();
50 }
51 
52 }  // namespace
53 
CreateScalarAddComputation(PrimitiveType type,XlaBuilder * builder)54 XlaComputation CreateScalarAddComputation(PrimitiveType type,
55                                           XlaBuilder* builder) {
56   return CreateScalarComputation(
57       "add", type, builder,
58       [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
59         return Add(lhs, rhs);
60       });
61 }
62 
CreateScalarMultiplyComputation(PrimitiveType type,XlaBuilder * builder)63 XlaComputation CreateScalarMultiplyComputation(PrimitiveType type,
64                                                XlaBuilder* builder) {
65   return CreateScalarComputation(
66       "mul", type, builder,
67       [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
68         return Mul(lhs, rhs);
69       });
70 }
71 
CreateScalarGeComputation(PrimitiveType type,XlaBuilder * builder)72 XlaComputation CreateScalarGeComputation(PrimitiveType type,
73                                          XlaBuilder* builder) {
74   return CreateScalarComputation("ge", type, builder,
75                                  [](XlaBuilder* b, const XlaOp& lhs,
76                                     const XlaOp& rhs) { return Ge(lhs, rhs); });
77 }
78 
CreateScalarMaxComputation(PrimitiveType type,XlaBuilder * builder)79 XlaComputation CreateScalarMaxComputation(PrimitiveType type,
80                                           XlaBuilder* builder) {
81   return CreateScalarComputation(
82       "max", type, builder,
83       [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
84         return Max(lhs, rhs);
85       });
86 }
87 
CreateScalarMinComputation(PrimitiveType type,XlaBuilder * builder)88 XlaComputation CreateScalarMinComputation(PrimitiveType type,
89                                           XlaBuilder* builder) {
90   return CreateScalarComputation(
91       "min", type, builder,
92       [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
93         return Min(lhs, rhs);
94       });
95 }
96 
CreateScalarAndComputation(PrimitiveType type,XlaBuilder * builder)97 XlaComputation CreateScalarAndComputation(PrimitiveType type,
98                                           XlaBuilder* builder) {
99   return CreateScalarComputation(
100       "and", type, builder,
101       [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
102         return And(lhs, rhs);
103       });
104 }
105 
CreateScalarOrComputation(PrimitiveType type,XlaBuilder * builder)106 XlaComputation CreateScalarOrComputation(PrimitiveType type,
107                                          XlaBuilder* builder) {
108   return CreateScalarComputation("or", type, builder,
109                                  [](XlaBuilder* b, const XlaOp& lhs,
110                                     const XlaOp& rhs) { return Or(lhs, rhs); });
111 }
112 
Any(XlaOp predicates)113 XlaOp Any(XlaOp predicates) {
114   XlaBuilder* builder = predicates.builder();
115   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
116     auto f = ConstantR0<bool>(builder, false);
117     XlaComputation logical_or = CreateScalarOrComputation(PRED, builder);
118     TF_ASSIGN_OR_RETURN(const Shape& predicates_shape,
119                         builder->GetShape(predicates));
120     std::vector<int64> all_dimensions(predicates_shape.rank());
121     std::iota(all_dimensions.begin(), all_dimensions.end(), 0);
122     return Reduce(predicates, f, logical_or, all_dimensions);
123   });
124 }
125 
126 namespace {
127 
ArgMinMax(XlaOp input,PrimitiveType output_type,int axis,bool is_min)128 XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min) {
129   XlaBuilder* builder = input.builder();
130   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
131     TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
132     XlaOp init_value;
133     XlaComputation reducer;
134     if (is_min) {
135       init_value = MaxValue(builder, input_shape.element_type());
136       reducer = CreateScalarMinComputation(input_shape.element_type(), builder);
137     } else {
138       init_value = MinValue(builder, input_shape.element_type());
139       reducer = CreateScalarMaxComputation(input_shape.element_type(), builder);
140     }
141 
142     XlaOp input_max = Reduce(input, init_value, reducer,
143                              /*dimensions_to_reduce=*/{axis});
144     std::vector<int64> broadcast_dims(input_shape.rank() - 1);
145     std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
146     std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
147     // Compute a mask that has 1s for elements equal to the maximum.
148     XlaOp partial_mask =
149         ConvertElementType(Eq(input, input_max, broadcast_dims), output_type);
150 
151     // In order to make identity elements for a bitwise And, we:
152     //   Left shift the 1 to the leftmost bit, yielding 0x10...0
153     //   Arithmetic right shift the 1 back to the rightmost bit, yielding
154     //   0xFF...F
155     int32 bits_in_type =
156         ShapeUtil::ByteSizeOfPrimitiveType(output_type) * 8 - 1;
157     XlaOp shift_amount = ConstantR0WithType(builder, output_type, bits_in_type);
158     XlaOp full_mask = ShiftRightArithmetic(
159         ShiftLeft(partial_mask, shift_amount), shift_amount);
160 
161     // And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its
162     // index.
163 
164     const int64 axis_size = ShapeUtil::GetDimension(input_shape, axis);
165     XlaOp iota = Iota(builder, output_type, axis_size);
166     XlaOp product = And(full_mask, iota, /*broadcast_dimensions=*/{axis});
167 
168     // If there are multiple maximum elements, choose the one with the highest
169     // index.
170     return Reduce(product, MinValue(builder, output_type),
171                   CreateScalarMaxComputation(output_type, builder),
172                   /*dimensions_to_reduce=*/{axis});
173   });
174 }
175 
176 }  // namespace
177 
ArgMax(XlaOp input,PrimitiveType output_type,int axis)178 XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis) {
179   return ArgMinMax(input, output_type, axis, /*is_min=*/false);
180 }
181 
ArgMin(XlaOp input,PrimitiveType output_type,int axis)182 XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis) {
183   return ArgMinMax(input, output_type, axis, /*is_min=*/true);
184 }
185 
186 }  // namespace xla
187