1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // SWIG typemaps and declarations for building, compiling, and
17 // executing XLA computations, wrapping most of what is declared in
18 // local_computation_builder.h.
19 
20 %module(threads="1") local_computation_builder
21 
22 // Keep the GIL except where explicitly specified.
23 %nothread;
24 
25 %include "tensorflow/python/platform/base.i"
26 %include "tensorflow/compiler/xla/python/xla_data.i"
27 
28 %{
29 // Must be included first
30 #include "tensorflow/python/lib/core/numpy.h"
31 
32 #include "absl/strings/str_cat.h"
33 #include "absl/strings/str_format.h"
34 #include "tensorflow/compiler/xla/literal.h"
35 #include "tensorflow/compiler/xla/shape_util.h"
36 #include "tensorflow/compiler/xla/xla_data.pb.h"
37 #include "absl/types/span.h"
38 #include "tensorflow/compiler/xla/python/numpy_bridge.h"
39 #include "tensorflow/compiler/xla/python/local_computation_builder.h"
40 
41 using namespace xla;
42 using namespace xla::swig;
43 
44 %}
45 
46 // Required to use PyArray_* functions.
47 %init %{
48 tensorflow::ImportNumpy();
49 %}
50 
51 // Computation builder types
52 
// Input typemap: converts a Python sequence of wrapped LocalOp objects into an
// absl::Span<const LocalOp>.
//
// Each element is copied into the typemap-local vector `temps`, and the span
// handed to the wrapped function is built from that vector. A Python exception
// is raised (and the wrapper fails) if the argument is not a sequence, if an
// element cannot be fetched, or if an element is not a non-null LocalOp.
%typemap(in) absl::Span<const xla::swig::LocalOp>(
      std::vector<LocalOp> temps) {
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    SWIG_fail;
  }
  const Py_ssize_t size = PySequence_Size($input);
  if (size < 0) {
    // PySequence_Size has already set a Python exception.
    SWIG_fail;
  }
  temps.reserve(size);
  for (Py_ssize_t i = 0; i < size; ++i) {
    // PySequence_GetItem returns a new reference (or NULL with an exception).
    PyObject* o = PySequence_GetItem($input, i);
    if (o == nullptr) {
      SWIG_fail;
    }
    LocalOp* op = nullptr;
    const int res = SWIG_ConvertPtr(o, (void**)&op,
                                    $descriptor(xla::swig::LocalOp*),
                                    SWIG_POINTER_EXCEPTION);
    // Release the element reference before any early exit so that it is not
    // leaked on the failure path. The wrapped C++ object is still kept alive
    // by the input sequence.
    Py_DECREF(o);
    // A successful conversion can still yield a null pointer (for a None
    // element), which must not be dereferenced below.
    if (!SWIG_IsOK(res) || op == nullptr) {
      if (!PyErr_Occurred()) {
        PyErr_SetString(PyExc_TypeError,
                        "Sequence element is not a valid LocalOp");
      }
      SWIG_fail;
    }
    temps.push_back(*op);
  }
  $1 = temps;
}
72 
73 // Computation and buffer/allocation types
74 
// Output typemap for StatusOr<LocalClient>: a non-OK status is turned into a
// Python RuntimeError carrying the status text; otherwise the contained
// LocalClient is unwrapped and handed to the regular LocalClient out typemap.
%typemap(out) StatusOr<xla::swig::LocalClient> {
  if (!$1.ok()) {
    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    SWIG_fail;
  }
  {
    xla::swig::LocalClient client = $1.ValueOrDie();
    {
      // Shadow the result variable so that the nested typemap expansion
      // operates on the unwrapped value.
      auto $1 = client;
      $typemap(out, xla::swig::LocalClient)
    }
  }
}
87 
// Output typemap for StatusOr<LocalExecutable*>: a non-OK status becomes a
// Python RuntimeError carrying the status text; otherwise the contained
// pointer is unwrapped and handed to the regular LocalExecutable* out typemap.
%typemap(out) StatusOr<xla::swig::LocalExecutable*> {
  if (!$1.ok()) {
    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    SWIG_fail;
  }
  {
    auto* executable = $1.ValueOrDie();
    {
      // Shadow the result variable so that the nested typemap expansion
      // operates on the unwrapped pointer.
      auto* $1 = executable;
      $typemap(out, xla::swig::LocalExecutable*)
    }
  }
}
100 
// Output typemap for StatusOr<LocalShapedBuffer*>: a non-OK status becomes a
// Python RuntimeError carrying the status text; otherwise the contained
// pointer is unwrapped and handed to the regular LocalShapedBuffer* out
// typemap.
%typemap(out) StatusOr<xla::swig::LocalShapedBuffer*> {
  if (!$1.ok()) {
    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    SWIG_fail;
  }
  {
    auto* shaped_buffer = $1.ValueOrDie();
    {
      // Shadow the result variable so that the nested typemap expansion
      // operates on the unwrapped pointer.
      auto* $1 = shaped_buffer;
      $typemap(out, xla::swig::LocalShapedBuffer*)
    }
  }
}
113 
// Output typemap for StatusOr<LocalShapedBufferTuple*>: a non-OK status
// becomes a Python RuntimeError carrying the status text; otherwise the
// contained pointer is unwrapped and handed to the regular
// LocalShapedBufferTuple* out typemap.
%typemap(out) StatusOr<xla::swig::LocalShapedBufferTuple*> {
  if (!$1.ok()) {
    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    SWIG_fail;
  }
  {
    auto* buffer_tuple = $1.ValueOrDie();
    {
      // Shadow the result variable so that the nested typemap expansion
      // operates on the unwrapped pointer.
      auto* $1 = buffer_tuple;
      $typemap(out, xla::swig::LocalShapedBufferTuple*)
    }
  }
}
126 
// Output typemap for StatusOr<Computation*>: a non-OK status becomes a Python
// RuntimeError carrying the status text; otherwise the contained pointer is
// unwrapped and handed to the regular Computation* out typemap.
%typemap(out) StatusOr<xla::swig::Computation*> {
  if (!$1.ok()) {
    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    SWIG_fail;
  }
  {
    auto* computation = $1.ValueOrDie();
    {
      // Shadow the result variable so that the nested typemap expansion
      // operates on the unwrapped pointer.
      auto* $1 = computation;
      $typemap(out, xla::swig::Computation*)
    }
  }
}
139 
// Input typemap: converts a Python sequence of wrapped LocalShapedBuffer
// objects into an absl::Span<LocalShapedBuffer* const>.
//
// The raw pointers are collected into the typemap-local vector `temps`, and
// the span handed to the wrapped function is built from that vector. A Python
// exception is raised (and the wrapper fails) if the argument is not a
// sequence, if an element cannot be fetched, or if an element cannot be
// converted to LocalShapedBuffer*.
%typemap(in) absl::Span<xla::swig::LocalShapedBuffer* const>
    (std::vector<LocalShapedBuffer*> temps) {
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    SWIG_fail;
  }
  const Py_ssize_t size = PySequence_Size($input);
  if (size < 0) {
    // PySequence_Size has already set a Python exception.
    SWIG_fail;
  }
  temps.reserve(size);
  for (Py_ssize_t i = 0; i < size; ++i) {
    // PySequence_GetItem returns a new reference (or NULL with an exception).
    PyObject* o = PySequence_GetItem($input, i);
    if (o == nullptr) {
      SWIG_fail;
    }
    LocalShapedBuffer* lsbp = nullptr;
    const int res = SWIG_ConvertPtr(o, (void**) &lsbp,
                                    $descriptor(xla::swig::LocalShapedBuffer*),
                                    SWIG_POINTER_EXCEPTION);
    // Release the element reference before any early exit so that it is not
    // leaked on the failure path. The wrapped C++ object is still kept alive
    // by the input sequence.
    Py_DECREF(o);
    if (!SWIG_IsOK(res)) {
      if (!PyErr_Occurred()) {
        PyErr_SetString(PyExc_TypeError,
                        "Sequence element is not a LocalShapedBuffer");
      }
      SWIG_fail;
    }
    temps.push_back(lsbp);
  }
  $1 = temps;
}
160 
// Input typemap: converts a Python sequence of sequences of wrapped
// LocalShapedBuffer objects into an
// absl::Span<const std::vector<LocalShapedBuffer*>>.
//
// Each inner sequence becomes one std::vector<LocalShapedBuffer*> stored in
// the typemap-local vector `temps`, and the span handed to the wrapped
// function is built from that vector. A Python exception is raised (and the
// wrapper fails) if the outer argument or any inner element is not a
// sequence, if an item cannot be fetched, or if an item cannot be converted
// to LocalShapedBuffer*.
%typemap(in) absl::Span<const std::vector<xla::swig::LocalShapedBuffer*> >
    (std::vector<std::vector<LocalShapedBuffer*> > temps) {
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    SWIG_fail;
  }
  const Py_ssize_t size = PySequence_Size($input);
  if (size < 0) {
    // PySequence_Size has already set a Python exception.
    SWIG_fail;
  }
  temps.reserve(size);
  for (Py_ssize_t i = 0; i < size; ++i) {
    // PySequence_GetItem returns a new reference (or NULL with an exception).
    PyObject* o = PySequence_GetItem($input, i);
    if (o == nullptr) {
      SWIG_fail;
    }
    // Without this check a non-sequence element would make PySequence_Size
    // return -1, which must never reach reserve() below.
    if (!PySequence_Check(o)) {
      PyErr_SetString(PyExc_TypeError,
                      "Argument is not a sequence of sequences");
      Py_DECREF(o);
      SWIG_fail;
    }
    const Py_ssize_t vec_size = PySequence_Size(o);
    if (vec_size < 0) {
      Py_DECREF(o);
      SWIG_fail;
    }
    std::vector<LocalShapedBuffer*> vec;
    vec.reserve(vec_size);
    for (Py_ssize_t j = 0; j < vec_size; ++j) {
      PyObject* vec_elt = PySequence_GetItem(o, j);
      if (vec_elt == nullptr) {
        Py_DECREF(o);
        SWIG_fail;
      }
      LocalShapedBuffer* lsbp = nullptr;
      const int res = SWIG_ConvertPtr(
          vec_elt, (void**) &lsbp,
          $descriptor(xla::swig::LocalShapedBuffer*), SWIG_POINTER_EXCEPTION);
      // The wrapped C++ object is still kept alive by the inner sequence.
      Py_DECREF(vec_elt);
      if (!SWIG_IsOK(res)) {
        if (!PyErr_Occurred()) {
          PyErr_SetString(PyExc_TypeError,
                          "Sequence element is not a LocalShapedBuffer");
        }
        Py_DECREF(o);
        SWIG_fail;
      }
      vec.push_back(lsbp);
    }
    // vec is not used again in this iteration, so move instead of copying.
    temps.push_back(std::move(vec));
    Py_DECREF(o);
  }
  $1 = temps;
}
191 
192 // ExecutableBuildOptions
193 
// Input typemap: builds a C++ ExecutableBuildOptions from a Python options
// object, or passes a null pointer when the Python argument is None.
//
// The following attributes are read from the Python object:
//   dump_to, dump_hlo_pass_re, dump_hlo_module_re  (string debug options)
//   dump_hlo_as_text, dump_hlo_as_proto, hlo_profile  (bool debug options)
//   result_shape  (optional; converted with numpy::XlaShapeFromPyShape and
//                  used as the result layout when it is not None)
//   num_replicas  (integer)
//
// The options live in the typemap-local `build_options`; the wrapped function
// receives a pointer to it. Every error path leaves through SWIG_fail so that
// the generated wrapper runs its common failure cleanup.
%typemap(in) const ExecutableBuildOptions*
    (ExecutableBuildOptions build_options) {
  if ($input == Py_None) {
    $1 = nullptr;
  } else {
    if (!HandleStringAttribute($input, "dump_to", [&](string s) {
      build_options.mutable_debug_options()->set_xla_dump_to(std::move(s));
    })) {
      SWIG_fail;
    }
    if (!HandleStringAttribute($input, "dump_hlo_pass_re", [&](string s) {
      build_options.mutable_debug_options()->set_xla_dump_hlo_pass_re(std::move(s));
    })) {
      SWIG_fail;
    }
    if (!HandleStringAttribute($input, "dump_hlo_module_re", [&](string s) {
      build_options.mutable_debug_options()->set_xla_dump_hlo_module_re(std::move(s));
    })) {
      SWIG_fail;
    }
    if (!HandleBoolAttribute($input, "dump_hlo_as_text", [&](bool b) {
      build_options.mutable_debug_options()->set_xla_dump_hlo_as_text(b);
    })) {
      SWIG_fail;
    }
    if (!HandleBoolAttribute($input, "dump_hlo_as_proto", [&](bool b) {
      build_options.mutable_debug_options()->set_xla_dump_hlo_as_proto(b);
    })) {
      SWIG_fail;
    }
    if (!HandleBoolAttribute($input, "hlo_profile", [&](bool b) {
      build_options.mutable_debug_options()->set_xla_hlo_profile(b);
    })) {
      SWIG_fail;
    }

    // PyObject_GetAttrString returns a new reference, or NULL with a Python
    // exception already set.
    PyObject* o = PyObject_GetAttrString($input, "result_shape");
    if (o == nullptr) {
      SWIG_fail;
    }
    if (o != Py_None) {
      StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape(o);
      if (!statusor.ok()) {
        PyErr_SetString(PyExc_TypeError, absl::StrCat("ExecutableBuildOptions.result_shape could not be created from Python shape value: ", statusor.status().ToString()).c_str());
        Py_DECREF(o);
        SWIG_fail;
      }
      build_options.set_result_layout(statusor.ValueOrDie());
    }
    Py_DECREF(o);

    // Initialized defensively; the value is only used after GetIntAttr
    // reports success.
    int64 num_replicas = 0;
    if (!GetIntAttr($input, "num_replicas", &num_replicas)) {
      SWIG_fail;
    }
    build_options.set_num_replicas(num_replicas);

    $1 = &build_options;
  }
}
254 
255 %ignoreall
256 %unignore xla;
257 %unignore xla::swig;
258 %unignore xla::swig::RegisterCpuCustomCallTarget;
259 %unignore xla::swig::LocalClient;
260 %unignore xla::swig::LocalClient::Get;
261 %unignore xla::swig::LocalClient::DeviceCount;
262 %unignore xla::swig::LocalClient::TransferToInfeed;
263 %unignore xla::swig::LocalClient::TransferFromOutfeed;
264 %unignore xla::swig::LocalShapedBuffer;
265 %unignore xla::swig::LocalShapedBuffer::FromLiteral;
266 %unignore xla::swig::LocalShapedBuffer::ToLiteral;
267 %unignore xla::swig::LocalShapedBuffer::shape;
268 %unignore xla::swig::LocalShapedBuffer::DestructureTuple;
269 %unignore xla::swig::LocalShapedBufferTuple;
270 %unignore xla::swig::LocalShapedBufferTuple::Release;
271 %unignore xla::swig::LocalShapedBufferTuple::size;
272 %unignore xla::swig::LocalExecutable;
273 %unignore xla::swig::LocalExecutable::DeviceOrdinals;
274 %unignore xla::swig::LocalExecutable::Execute;
275 %unignore xla::swig::LocalExecutable::ExecutePerReplica;
276 %unignore xla::swig::Computation;
277 %unignore xla::swig::Computation::Compile;
278 %unignore xla::swig::Computation::GetProgramShape;
279 %unignore xla::swig::Computation::GetReturnValueShape;
280 %unignore xla::swig::Computation::GetSerializedProto;
281 %unignore xla::swig::Computation::GetHloText;
282 %unignore xla::swig::Computation::GetHloDotGraph;
283 %unignore xla::swig::LocalOp;
284 %unignore xla::swig::ComputationBuilder;
285 %unignore xla::swig::ComputationBuilder::ComputationBuilder;
286 %unignore xla::swig::ComputationBuilder::Build;
287 %unignore xla::swig::ComputationBuilder::BuildWithRoot;
288 %unignore xla::swig::ComputationBuilder::SetOpMetadata;
289 %unignore xla::swig::ComputationBuilder::ClearOpMetadata;
290 %unignore xla::swig::ComputationBuilder::Parameter;
291 %unignore xla::swig::ComputationBuilder::GetShape;
292 %unignore xla::swig::ComputationBuilder::GetReturnValueShape;
293 %unignore xla::swig::ComputationBuilder::ReplicaId;
294 %unignore xla::swig::ComputationBuilder::Infeed;
295 %unignore xla::swig::ComputationBuilder::Outfeed;
296 %unignore xla::swig::ComputationBuilder::ConstantLiteral;
297 %unignore xla::swig::ComputationBuilder::ConstantR0;
298 %unignore xla::swig::ComputationBuilder::Iota;
299 %unignore xla::swig::ComputationBuilder::BroadcastedIota;
300 %unignore xla::swig::ComputationBuilder::Broadcast;
301 %unignore xla::swig::ComputationBuilder::BroadcastInDim;
302 %unignore xla::swig::ComputationBuilder::Pad;
303 %unignore xla::swig::ComputationBuilder::Reshape;
304 %unignore xla::swig::ComputationBuilder::Collapse;
305 %unignore xla::swig::ComputationBuilder::AllToAll;
306 %unignore xla::swig::ComputationBuilder::CrossReplicaSum;
307 %unignore xla::swig::ComputationBuilder::Slice;
308 %unignore xla::swig::ComputationBuilder::SliceInDim;
309 %unignore xla::swig::ComputationBuilder::DynamicSlice;
310 %unignore xla::swig::ComputationBuilder::DynamicUpdateSlice;
311 %unignore xla::swig::ComputationBuilder::ConcatInDim;
312 %unignore xla::swig::ComputationBuilder::SelectAndScatterWithGeneralPadding;
313 %unignore xla::swig::ComputationBuilder::Select;
314 %unignore xla::swig::ComputationBuilder::Tuple;
315 %unignore xla::swig::ComputationBuilder::GetTupleElement;
316 %unignore xla::swig::ComputationBuilder::ConvertElementType;
317 %unignore xla::swig::ComputationBuilder::BitcastConvertType;
318 %unignore xla::swig::ComputationBuilder::Call;
319 %unignore xla::swig::ComputationBuilder::Transpose;
320 %unignore xla::swig::ComputationBuilder::Rev;
321 %unignore xla::swig::ComputationBuilder::Clamp;
322 %unignore xla::swig::ComputationBuilder::Map;
323 %unignore xla::swig::ComputationBuilder::Reduce;
324 %unignore xla::swig::ComputationBuilder::ReduceWindowWithGeneralPadding;
325 %unignore xla::swig::ComputationBuilder::RngNormal;
326 %unignore xla::swig::ComputationBuilder::RngUniform;
327 %unignore xla::swig::ComputationBuilder::RngBernoulli;
328 %unignore xla::swig::ComputationBuilder::While;
329 %unignore xla::swig::ComputationBuilder::Conditional;
330 %unignore xla::swig::ComputationBuilder::IsConstant;
331 %unignore xla::swig::ComputationBuilder::Eq;
332 %unignore xla::swig::ComputationBuilder::Ne;
333 %unignore xla::swig::ComputationBuilder::Ge;
334 %unignore xla::swig::ComputationBuilder::Gt;
335 %unignore xla::swig::ComputationBuilder::Lt;
336 %unignore xla::swig::ComputationBuilder::Le;
337 %unignore xla::swig::ComputationBuilder::Dot;
338 %unignore xla::swig::ComputationBuilder::DotGeneral;
339 %unignore xla::swig::ComputationBuilder::ConvGeneralDilated;
340 %unignore xla::swig::ComputationBuilder::Add;
341 %unignore xla::swig::ComputationBuilder::Sub;
342 %unignore xla::swig::ComputationBuilder::Mul;
343 %unignore xla::swig::ComputationBuilder::Div;
344 %unignore xla::swig::ComputationBuilder::Rem;
345 %unignore xla::swig::ComputationBuilder::Max;
346 %unignore xla::swig::ComputationBuilder::Min;
347 %unignore xla::swig::ComputationBuilder::And;
348 %unignore xla::swig::ComputationBuilder::Or;
349 %unignore xla::swig::ComputationBuilder::Xor;
350 %unignore xla::swig::ComputationBuilder::ShiftLeft;
351 %unignore xla::swig::ComputationBuilder::ShiftRightArithmetic;
352 %unignore xla::swig::ComputationBuilder::ShiftRightLogical;
353 %unignore xla::swig::ComputationBuilder::Not;
354 %unignore xla::swig::ComputationBuilder::Clz;
355 %unignore xla::swig::ComputationBuilder::Abs;
356 %unignore xla::swig::ComputationBuilder::Exp;
357 %unignore xla::swig::ComputationBuilder::Expm1;
358 %unignore xla::swig::ComputationBuilder::Floor;
359 %unignore xla::swig::ComputationBuilder::Ceil;
360 %unignore xla::swig::ComputationBuilder::Round;
361 %unignore xla::swig::ComputationBuilder::Log;
362 %unignore xla::swig::ComputationBuilder::Log1p;
363 %unignore xla::swig::ComputationBuilder::Sign;
364 %unignore xla::swig::ComputationBuilder::Cos;
365 %unignore xla::swig::ComputationBuilder::Sin;
366 %unignore xla::swig::ComputationBuilder::Tanh;
367 %unignore xla::swig::ComputationBuilder::Atan2;
368 %unignore xla::swig::ComputationBuilder::IsFinite;
369 %unignore xla::swig::ComputationBuilder::Pow;
370 %unignore xla::swig::ComputationBuilder::Neg;
371 %unignore xla::swig::ComputationBuilder::Sort;
372 %unignore xla::swig::ComputationBuilder::SortKeyVal;
373 %unignore xla::swig::ComputationBuilder::Sqrt;
374 %unignore xla::swig::ComputationBuilder::Rsqrt;
375 %unignore xla::swig::ComputationBuilder::Square;
376 %unignore xla::swig::ComputationBuilder::Reciprocal;
377 %unignore xla::swig::ComputationBuilder::Erfc;
378 %unignore xla::swig::ComputationBuilder::Erf;
379 %unignore xla::swig::ComputationBuilder::ErfInv;
380 %unignore xla::swig::ComputationBuilder::Lgamma;
381 %unignore xla::swig::ComputationBuilder::Digamma;
382 %unignore xla::swig::ComputationBuilder::Acos;
383 %unignore xla::swig::ComputationBuilder::Asin;
384 %unignore xla::swig::ComputationBuilder::Atan;
385 %unignore xla::swig::ComputationBuilder::Tan;
386 %unignore xla::swig::ComputationBuilder::Acosh;
387 %unignore xla::swig::ComputationBuilder::Asinh;
388 %unignore xla::swig::ComputationBuilder::Atanh;
389 %unignore xla::swig::ComputationBuilder::Cosh;
390 %unignore xla::swig::ComputationBuilder::Sinh;
391 %unignore xla::swig::ComputationBuilder::Real;
392 %unignore xla::swig::ComputationBuilder::Imag;
393 %unignore xla::swig::ComputationBuilder::Conj;
394 %unignore xla::swig::ComputationBuilder::Complex;
395 %unignore xla::swig::ComputationBuilder::Cholesky;
396 %unignore xla::swig::ComputationBuilder::QR;
397 %unignore xla::swig::ComputationBuilder::Eigh;
398 %unignore xla::swig::ComputationBuilder::SVD;
399 %unignore xla::swig::ComputationBuilder::TriangularSolve;
400 %unignore xla::swig::ComputationBuilder::CustomCall;
401 %unignore xla::swig::ComputationBuilder::Gather;
402 %unignore xla::swig::ComputationBuilder::Scatter;
403 %unignore xla::swig::DeleteComputation;
404 %unignore xla::swig::DeleteLocalShapedBuffer;
405 %unignore xla::swig::DeleteLocalExecutable;
406 
407 %thread;
408 %include "tensorflow/compiler/xla/python/local_computation_builder.h"
409 %nothread;
410 
411 %unignoreall
412