1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/python/ops.h"
17 
18 #include <string>
19 #include <vector>
20 
21 #include "absl/types/optional.h"
22 #include "absl/types/span.h"
23 #include "pybind11/attr.h"
24 #include "pybind11/pybind11.h"
25 #include "tensorflow/compiler/xla/client/lib/comparators.h"
26 #include "tensorflow/compiler/xla/client/lib/lu_decomposition.h"
27 #include "tensorflow/compiler/xla/client/lib/math.h"
28 #include "tensorflow/compiler/xla/client/lib/qr.h"
29 #include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h"
30 #include "tensorflow/compiler/xla/client/lib/sorting.h"
31 #include "tensorflow/compiler/xla/client/lib/svd.h"
32 #include "tensorflow/compiler/xla/client/xla_builder.h"
33 #include "tensorflow/compiler/xla/client/xla_computation.h"
34 #include "tensorflow/compiler/xla/python/types.h"
35 #include "tensorflow/compiler/xla/xla_data.pb.h"
36 
37 namespace xla {
38 
39 namespace py = pybind11;
40 
BuildOpsSubmodule(py::module * m)41 void BuildOpsSubmodule(py::module* m) {
42   // ops submodule, containing free functions that add operators to an
43   // XlaBuilder.
44   py::module ops = m->def_submodule("ops", "XLA operations");
45 
46   py::enum_<TriangularSolveOptions::Transpose>(
47       ops, "TriangularSolveOptions_Transpose")
48       .value("TRANSPOSE_INVALID", TriangularSolveOptions::TRANSPOSE_INVALID)
49       .value("NO_TRANSPOSE", TriangularSolveOptions::NO_TRANSPOSE)
50       .value("TRANSPOSE", TriangularSolveOptions::TRANSPOSE)
51       .value("ADJOINT", TriangularSolveOptions::ADJOINT);
52 
53   py::enum_<RandomAlgorithm>(ops, "RandomAlgorithm")
54       .value("RNG_DEFAULT", RandomAlgorithm::RNG_DEFAULT)
55       .value("RNG_THREE_FRY", RandomAlgorithm::RNG_THREE_FRY)
56       .value("RNG_PHILOX", RandomAlgorithm::RNG_PHILOX);
57 
58   ops.def("AfterAll", &AfterAll, py::arg("builder"), py::arg("tokens"));
59   ops.def("AllGather", &AllGather, py::arg("operand"),
60           py::arg("all_gather_dimension"), py::arg("shard_count"),
61           py::arg("replica_groups") = py::list(),
62           py::arg("channel_id") = absl::nullopt,
63           py::arg("shape_with_layout") = absl::nullopt,
64           py::arg("use_global_device_ids") = absl::nullopt);
65   ops.def(
66       "AllReduce",
67       static_cast<XlaOp (*)(
68           XlaOp, const XlaComputation&, absl::Span<const ReplicaGroup>,
69           const absl::optional<ChannelHandle>&, const absl::optional<Shape>&)>(
70           &AllReduce),
71       py::arg("operand"), py::arg("computation"),
72       py::arg("replica_groups") = py::list(),
73       py::arg("channel_id") = absl::nullopt,
74       py::arg("shape_with_layout") = absl::nullopt);
75   ops.def("AllToAll", &AllToAll, py::arg("operand"), py::arg("split_dimension"),
76           py::arg("concat_dimension"), py::arg("split_count"),
77           py::arg("replica_groups") = py::list(),
78           py::arg("layout") = absl::nullopt);
79   ops.def("CollectivePermute", &CollectivePermute, py::arg("operand"),
80           py::arg("source_target_pairs"));
81   ops.def("CreateToken", &CreateToken, py::arg("builder"));
82   ops.def("CrossReplicaSum",
83           static_cast<XlaOp (*)(XlaOp, absl::Span<const ReplicaGroup>)>(
84               &CrossReplicaSum),
85           py::arg("operand"), py::arg("replica_groups") = py::list());
86   ops.def("BitcastConvertType", &BitcastConvertType, py::arg("operand"),
87           py::arg("new_element_type"));
88   ops.def("Broadcast", &Broadcast, py::arg("operand"), py::arg("sizes"));
89   ops.def("BroadcastInDim", &BroadcastInDim, py::arg("operand"),
90           py::arg("shape"), py::arg("broadcast_dimensions"));
91   ops.def("Call", &Call, py::arg("builder"), py::arg("computation"),
92           py::arg("operands"));
93   ops.def("Cholesky", &Cholesky, py::arg("a"), py::arg("lower") = true);
94   ops.def("Clamp", &Clamp, py::arg("min"), py::arg("operand"), py::arg("max"));
95   ops.def("Collapse", &Collapse, py::arg("operand"), py::arg("dimensions"));
96   ops.def("ConcatInDim", &ConcatInDim, py::arg("builder"), py::arg("operands"),
97           py::arg("dimension"));
98   ops.def("Conditional",
99           static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaComputation* const>,
100                                 absl::Span<const XlaOp>)>(&Conditional),
101           py::arg("branch_index"), py::arg("branch_computations"),
102           py::arg("branch_operands"));
103   ops.def("Conditional",
104           static_cast<XlaOp (*)(XlaOp, XlaOp, const XlaComputation&, XlaOp,
105                                 const XlaComputation&)>(&Conditional),
106           py::arg("predicate"), py::arg("true_operand"),
107           py::arg("true_computation"), py::arg("false_operand"),
108           py::arg("false_computation"));
109   ops.def("Constant", &ConstantLiteral, py::arg("builder"), py::arg("literal"));
110   ops.def("ConstantLiteral", &ConstantLiteral, py::arg("builder"),
111           py::arg("literal"));
112   ops.def("ConvGeneralDilated", &ConvGeneralDilated, py::arg("lhs"),
113           py::arg("rhs"), py::arg("window_strides"), py::arg("padding"),
114           py::arg("lhs_dilation"), py::arg("rhs_dilation"),
115           py::arg("dimension_numbers"), py::arg("feature_group_count") = 1,
116           py::arg("batch_group_count") = 1,
117           py::arg("precision_config") = nullptr,
118           py::arg("preferred_element_type") = absl::nullopt);
119   ops.def("ConvertElementType", &ConvertElementType, py::arg("operand"),
120           py::arg("new_element_type"));
121   ops.def(
122       "CustomCall",
123       [](XlaBuilder* builder, const py::bytes& call_target_name,
124          absl::Span<const XlaOp> operands, const Shape& shape,
125          const py::bytes& opaque, bool has_side_effect) -> XlaOp {
126         return CustomCall(builder, call_target_name, operands, shape, opaque,
127                           has_side_effect);
128       },
129       py::arg("builder"), py::arg("call_target_name"), py::arg("operands"),
130       py::arg("shape"), py::arg("opaque") = py::bytes(""),
131       py::arg("has_side_effect") = false);
132   ops.def(
133       "CustomCallWithLayout",
134       [](XlaBuilder* builder, const py::bytes& call_target_name,
135          absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
136          absl::Span<const Shape> operand_shapes_with_layout,
137          const py::bytes& opaque, bool has_side_effect) -> XlaOp {
138         return CustomCallWithLayout(
139             builder, call_target_name, operands, shape_with_layout,
140             operand_shapes_with_layout, opaque, has_side_effect);
141       },
142       py::arg("builder"), py::arg("call_target_name"), py::arg("operands"),
143       py::arg("shape_with_layout"), py::arg("operand_shapes_with_layout"),
144       py::arg("opaque") = py::bytes(""), py::arg("has_side_effect") = false);
145   ops.def("Dot", &Dot, py::arg("lhs"), py::arg("rhs"),
146           py::arg("precision_config") = nullptr,
147           py::arg("preferred_element_type") = absl::nullopt);
148   ops.def("DotGeneral", &DotGeneral, py::arg("lhs"), py::arg("rhs"),
149           py::arg("dimension_numbers"), py::arg("precision_config") = nullptr,
150           py::arg("preferred_element_type") = absl::nullopt);
151   ops.def("DynamicSlice",
152           static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaOp>,
153                                 absl::Span<const int64>)>(&DynamicSlice),
154           py::arg("operand"), py::arg("start_indices"), py::arg("slice_sizes"));
155   ops.def("DynamicUpdateSlice",
156           static_cast<XlaOp (*)(XlaOp, XlaOp, absl::Span<const XlaOp>)>(
157               &DynamicUpdateSlice),
158           py::arg("operand"), py::arg("update"), py::arg("start_indices"));
159 
160   ops.def("Fft", &Fft, py::arg("operand"), py::arg("fft_type"),
161           py::arg("fft_length"));
162 
163   ops.def("Gather", &Gather, py::arg("a"), py::arg("start_indices"),
164           py::arg("dimension_numbers"), py::arg("slice_sizes"),
165           py::arg("indices_are_sorted") = false);
166   ops.def("GetTupleElement", &GetTupleElement, py::arg("tuple_data"),
167           py::arg("index"));
168   ops.def("InfeedWithToken", &InfeedWithToken, py::arg("token"),
169           py::arg("shape"), py::arg("config") = "");
170   ops.def("Iota",
171           static_cast<XlaOp (*)(XlaBuilder*, const Shape&, int64)>(&Iota),
172           py::arg("builder"), py::arg("shape"), py::arg("iota_dimension"));
173   ops.def("Iota",
174           static_cast<XlaOp (*)(XlaBuilder*, PrimitiveType, int64)>(&Iota),
175           py::arg("builder"), py::arg("type"), py::arg("size"));
176   ops.def("Map", &Map, py::arg("builder"), py::arg("operands"),
177           py::arg("computation"), py::arg("dimensions"),
178           py::arg("static_operands") = py::list());
179   ops.def("NextAfter", &NextAfter, py::arg("from"), py::arg("to"));
180   ops.def("OutfeedWithToken", &OutfeedWithToken, py::arg("operand"),
181           py::arg("token"), py::arg("shape_with_layout"),
182           py::arg("outfeed_config") = "");
183   ops.def("Pad", &Pad, py::arg("operand"), py::arg("padding_value"),
184           py::arg("padding_config"));
185   ops.def("Parameter",
186           static_cast<XlaOp (*)(XlaBuilder*, int64, const Shape&,
187                                 const std::string&, const std::vector<bool>&)>(
188               &Parameter),
189           py::arg("builder"), py::arg("parameter_number"), py::arg("shape"),
190           py::arg("name") = "",
191           py::arg("replicated_at_leaf_buffers") = std::vector<bool>());
192   ops.def(
193       "QR",
194       [](XlaOp a, bool full_matrices) -> StatusOr<std::pair<XlaOp, XlaOp>> {
195         TF_ASSIGN_OR_RETURN(auto qr, QRDecomposition(a, full_matrices));
196         return std::make_pair(qr.q, qr.r);
197       },
198       py::arg("operand"), py::arg("full_matrices"));
199   ops.def(
200       "LU",
201       [](XlaOp a) -> StatusOr<std::tuple<XlaOp, XlaOp, XlaOp>> {
202         LuDecompositionResult lu = LuDecomposition(a);
203         return std::make_tuple(lu.lu, lu.pivots, lu.permutation);
204       },
205       py::arg("operand"));
206   ops.def(
207       "Eigh",
208       [](XlaOp a, bool lower, int64 max_iter,
209          float epsilon) -> std::pair<XlaOp, XlaOp> {
210         auto eigh = SelfAdjointEig(a, lower, max_iter, epsilon);
211         return std::make_pair(eigh.v, eigh.w);
212       },
213       py::arg("a"), py::arg("lower") = true, py::arg("max_iter") = 100,
214       py::arg("epsilon") = 1e-6);
215   ops.def(
216       "SVD",
217       [](XlaOp a, int64 max_iter,
218          float epsilon) -> std::tuple<XlaOp, XlaOp, XlaOp> {
219         auto svd = SVD(a, max_iter, epsilon);
220         return std::make_tuple(svd.u, svd.d, svd.v);
221       },
222       py::arg("a"), py::arg("max_iter") = 100, py::arg("epsilon") = 1e-6);
223   ops.def("Reduce",
224           static_cast<XlaOp (*)(XlaBuilder*, absl::Span<const XlaOp>,
225                                 absl::Span<const XlaOp>, const XlaComputation&,
226                                 absl::Span<const int64>)>(&Reduce),
227           py::arg("builder"), py::arg("operands"), py::arg("init_values"),
228           py::arg("computation"), py::arg("dimensions_to_reduce"));
229   ops.def("ReducePrecision", &ReducePrecision, py::arg("operand"),
230           py::arg("exponent_bits"), py::arg("mantissa_bits"));
231   ops.def("ReduceWindowWithGeneralPadding", &ReduceWindowWithGeneralPadding,
232           py::arg("operand"), py::arg("init_value"), py::arg("computation"),
233           py::arg("window_dimensions"), py::arg("window_strides"),
234           py::arg("base_dilations"), py::arg("window_dilations"),
235           py::arg("padding"));
236   ops.def("ReplicaId", &ReplicaId, py::arg("builder"));
237   ops.def("Reshape",
238           static_cast<XlaOp (*)(XlaOp, absl::Span<const int64>,
239                                 absl::Span<const int64>)>(&Reshape),
240           py::arg("operand"), py::arg("dimensions"), py::arg("new_sizes"));
241   ops.def("Reshape",
242           static_cast<XlaOp (*)(XlaOp, absl::Span<const int64>)>(&Reshape),
243           py::arg("operand"), py::arg("new_sizes"));
244   ops.def("Rev", &Rev, py::arg("operand"), py::arg("dimensions"));
245   ops.def("RngBitGenerator", &RngBitGenerator, py::arg("algorithm"),
246           py::arg("initial_state"), py::arg("shape"));
247   ops.def("RngNormal", &RngNormal, py::arg("mu"), py::arg("sigma"),
248           py::arg("shape"));
249   ops.def("RngUniform", &RngUniform, py::arg("a"), py::arg("b"),
250           py::arg("shape"));
251   ops.def("Scatter", &Scatter, py::arg("input"), py::arg("scatter_indices"),
252           py::arg("updates"), py::arg("update_computation"),
253           py::arg("dimension_numbers"), py::arg("indices_are_sorted") = false,
254           py::arg("unique_indices") = false);
255   ops.def("Select", &Select, py::arg("pred"), py::arg("on_true"),
256           py::arg("on_false"));
257   ops.def("SelectAndScatterWithGeneralPadding",
258           &SelectAndScatterWithGeneralPadding, py::arg("operand"),
259           py::arg("select"), py::arg("window_dimensions"),
260           py::arg("window_strides"), py::arg("padding"), py::arg("source"),
261           py::arg("init_value"), py::arg("scatter"));
262   ops.def("Slice", &Slice, py::arg("operand"), py::arg("start_indices"),
263           py::arg("limit_indices"), py::arg("strides"));
264   ops.def("SliceInDim", &SliceInDim, py::arg("operand"), py::arg("start_index"),
265           py::arg("limit_index"), py::arg("stride"), py::arg("dimno"));
266   ops.def(
267       "Sort",
268       [](XlaBuilder* builder, absl::Span<const XlaOp> operands,
269          absl::optional<const XlaComputation*> comparator, int64 dimension,
270          bool is_stable) -> XlaOp {
271         return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
272           std::vector<PrimitiveType> operand_types;
273           for (const auto& operand : operands) {
274             TF_ASSIGN_OR_RETURN(auto operand_shape, builder->GetShape(operand));
275             operand_types.push_back(operand_shape.element_type());
276           }
277 
278           if (comparator) {
279             return Sort(operands, **comparator, dimension, is_stable);
280           } else {
281             return Sort(operands,
282                         CreateScalarLtComputation(operand_types, builder),
283                         dimension, is_stable);
284           }
285         });
286       },
287       py::arg("builder"), py::arg("operands"),
288       py::arg("comparator") = absl::nullopt, py::arg("dimension") = -1,
289       py::arg("is_stable") = false);
290   ops.def("TopK", &TopK, py::arg("input"), py::arg("k"));
291   ops.def("Transpose", &Transpose, py::arg("operand"), py::arg("permutation"));
292   ops.def("TriangularSolve", &TriangularSolve, py::arg("a"), py::arg("b"),
293           py::arg("left_side"), py::arg("lower"), py::arg("unit_diagonal"),
294           py::arg("transpose_a"));
295   ops.def("Tuple", &Tuple, py::arg("builder"), py::arg("elements"));
296   ops.def("While", &While, py::arg("condition"), py::arg("body"),
297           py::arg("init"));
298 
299   ops.def("Igamma", &Igamma, py::arg("a"), py::arg("x"));
300   ops.def("Igammac", &Igammac, py::arg("a"), py::arg("x"));
301   ops.def("IgammaGradA", &IgammaGradA, py::arg("a"), py::arg("x"));
302   ops.def("RandomGammaGrad", &RandomGammaGrad, py::arg("a"), py::arg("x"));
303   ops.def("RegularizedIncompleteBeta", &RegularizedIncompleteBeta, py::arg("a"),
304           py::arg("b"), py::arg("x"));
305   ops.def("Zeta", &Zeta, py::arg("x"), py::arg("q"));
306 
307 #define BINARY_OP(op)                                                 \
308   ops.def(                                                            \
309       #op,                                                            \
310       [](XlaOp a, XlaOp b, absl::optional<std::vector<int64>> dims) { \
311         return dims ? op(a, b, *dims) : op(a, b);                     \
312       },                                                              \
313       py::arg("lhs"), py::arg("rhs"),                                 \
314       py::arg("broadcast_dimensions") = absl::nullopt)
315   BINARY_OP(Eq);
316   BINARY_OP(Ne);
317   BINARY_OP(Ge);
318   BINARY_OP(Gt);
319   BINARY_OP(Lt);
320   BINARY_OP(Le);
321   BINARY_OP(Add);
322   BINARY_OP(Sub);
323   BINARY_OP(Mul);
324   BINARY_OP(Div);
325   BINARY_OP(Rem);
326   BINARY_OP(Max);
327   BINARY_OP(Min);
328   BINARY_OP(And);
329   BINARY_OP(Or);
330   BINARY_OP(Xor);
331   BINARY_OP(ShiftLeft);
332   BINARY_OP(ShiftRightArithmetic);
333   BINARY_OP(ShiftRightLogical);
334   BINARY_OP(Atan2);
335   BINARY_OP(Pow);
336   BINARY_OP(Complex);
337 #undef BINARY_OP
338 
339 #define UNARY_OP(op) ops.def(#op, &op)
340   UNARY_OP(Not);
341   UNARY_OP(PopulationCount);
342   UNARY_OP(Clz);
343   UNARY_OP(Abs);
344   UNARY_OP(Exp);
345   UNARY_OP(Expm1);
346   UNARY_OP(Floor);
347   UNARY_OP(Ceil);
348   UNARY_OP(Round);
349   UNARY_OP(Log);
350   UNARY_OP(Log1p);
351   UNARY_OP(Sign);
352   UNARY_OP(Cos);
353   UNARY_OP(Sin);
354   UNARY_OP(Tanh);
355   UNARY_OP(IsFinite);
356   UNARY_OP(Neg);
357   UNARY_OP(Sqrt);
358   UNARY_OP(Rsqrt);
359   UNARY_OP(Square);
360   UNARY_OP(Reciprocal);
361   UNARY_OP(Erfc);
362   UNARY_OP(Erf);
363   UNARY_OP(ErfInv);
364   UNARY_OP(Lgamma);
365   UNARY_OP(Digamma);
366   UNARY_OP(BesselI0e);
367   UNARY_OP(BesselI1e);
368   UNARY_OP(Acos);
369   UNARY_OP(Asin);
370   UNARY_OP(Atan);
371   UNARY_OP(Tan);
372   UNARY_OP(Acosh);
373   UNARY_OP(Asinh);
374   UNARY_OP(Atanh);
375   UNARY_OP(Cosh);
376   UNARY_OP(Sinh);
377   UNARY_OP(Real);
378   UNARY_OP(Imag);
379   UNARY_OP(Conj);
380 #undef UNARY_OP
381 }
382 
383 }  // namespace xla
384