1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // SWIG typemaps for building, compiling, and executing XLA computations.
17 //
18 // The typemaps below implement/assert the following correspondences
19 // (with elaborations below):
20 //
21 //    C++                                  Python
22 // -------------------------------------+---------------------------------------
23 //  Span<int64>                        <-  sequence of int
24 //  vector<int>                        ->  sequence of int
25 //  Span<LocalOp>                      <-  sequence of LocalOp
26 //  Literal                            <-> (nested tuple of) numpy ndarray
27 //  std::vector<Literal>               <-  sequence of (nested tuple of) ndarray
28 //  Shape                               -> pair holding (dtype, dimensions)
29 //                                     <-  object duck-typed as xla_client.Shape
30 //  ProgramShape                       ->  pair of ([arg_shapes], ret_shape)
31 //  std::vector<Shape>                 <-  sequence of xla_client.Shape objects
32 //  PrimitiveType                      <-  int
//  Span<pair<int64, int64>>           <-  sequence of int pairs
34 //  PaddingConfig proto                <-  ducktyped Python proto
35 //  ConvolutionDimensionNumbers proto  <-  ducktyped Python proto
36 //  DotDimensionNumbers proto          <-  ducktyped Python proto
37 //  GatherDimensionNumbers proto       <-  ducktyped Python proto
38 //  ScatterDimensionNumbers proto      <-  ducktyped Python proto
39 //  Span<ReplicaGroup proto>           <-  sequence of ReplicaGroup Python proto
40 //
41 // Arrows indicate whether a conversion only ever occurs in one
42 // direction, or whether it is maintained bidirectionally.
43 //
44 // The Python objects corresponding to C++ Literals have the type:
45 //
46 //   T = ndarray | (T, ...)
47 //
48 // where a terminal numpy ndarray translates to a Literal with a
// non-tuple Shape and an XLA primitive element type corresponding to the
50 // ndarray's dtype. Meanwhile, a non-terminal "tuple of T" translates
51 // to a tuple-shaped Literal whose tuple components are translated
52 // recursively. For example, if x is a numpy ndarray in Python, with
53 // shape (2, 3) and dtype of dtype('float32'), then x translates to a
// Literal with rank 2, dimensions 2 and 3, and XLA primitive type
55 // F32. Meanwhile,
56 //
57 //   (x, (x, x), (x,)),
58 //
59 // translates to a tuple-shaped XLA Literal, whose component subshapes
60 // are a 2x3 F32-shaped literal followed by two tuple-shaped literals.
61 //
62 // Shapes output by C++ become Python objects with the type:
63 //
64 //   T            = (dtype, S)
65 //   S            = DIMENSIONS | TUPLE_SHAPES
66 //   DIMENSIONS   = (int, ...)
67 //   TUPLE_SHAPES = (T, ...)
68 //
69 // In the pair described by the T rule, the terminal dtype determines
70 // whether S expands as DIMENSIONS or TUPLE_SHAPES. Namely if it is
71 // dtype('O'), numpy's object dtype, the structure represents a tuple
72 // shape and the expansion of the non-terminal S is
73 // TUPLE_SHAPES. Otherwise, dtype describes a primitive element type
74 // and S expands into DIMENSIONS giving dimension sizes. For example:
75 //
76 //   (dtype('float32'), (3, 5, 7))
77 //
78 // describes a 3x5x7 array of F32s, and
79 //
80 //   (dtype('O'), ((dtype('float32'), (2, 3)),
81 //                 (dtype('float64'), (4, 5))))
82 //
83 // describes a tuple shape with two subshapes: the first a 2x3 F32,
84 // and the other a 4x5 F64.
85 //
86 // The Python int corresponding to a PrimitiveType enum must be valid
87 // per xla_data.proto (e.g. xla_data.PRED, xla_data.F32).
88 //
89 // The SWIG object wrappers generated by this file are not intended
90 // for end use, but rather for internal use in the Python XLA client,
91 // xla_client.py.
92 //
93 // One central reason for the Python-side indirection is that the
94 // Python-side objects produced by the typemaps in this file are
95 // further packaged up by xla_client before being passed on. For
96 // instance, the Python pair produced for a C++ Shape is further
97 // wrapped in a Python class (xla_client.Shape) so as not to expose
98 // the raw pair externally.
99 //
100 // Other SWIG object wrappers (e.g. of Computation) are further
101 // wrapped by xla_client in order to set up a custom destructor that
102 // triggers memory deallocation on the C++ side.
103 //
104 
105 
106 %module(threads="1") xla_data
107 
108 // Keep the GIL except where explicitly specified.
109 %nothread;
110 
111 %include "tensorflow/python/platform/base.i"
112 
113 %{
114 // Must be included first
115 #include "tensorflow/python/lib/core/numpy.h"
116 
117 #include "absl/strings/str_cat.h"
118 #include "absl/strings/str_format.h"
119 #include "tensorflow/compiler/xla/literal.h"
120 #include "tensorflow/compiler/xla/shape_util.h"
121 #include "tensorflow/compiler/xla/xla_data.pb.h"
122 #include "absl/types/span.h"
123 #include "tensorflow/compiler/xla/python/numpy_bridge.h"
124 
125 using namespace xla;
126 using namespace xla::swig;
127 
128 %}
129 
130 // Basic types
131 
132 
%typemap(out) std::vector<int> {
  // Returns the vector as a Python list of ints. If any allocation fails, the
  // Python error set by the failing call is propagated rather than returning
  // a list that contains NULL slots.
  PyObject* out = PyList_New($1.size());
  if (!out) {
    SWIG_fail;
  }
  for (size_t i = 0; i < $1.size(); ++i) {
    PyObject* item = PyInt_FromLong($1[i]);
    if (!item) {
      Py_DECREF(out);
      SWIG_fail;
    }
    // PyList_SET_ITEM steals the reference to `item`.
    PyList_SET_ITEM(out, i, item);
  }
  $result = out;
}
140 
%typemap(out) StatusOr<bool> {
  // A failed status is raised as RuntimeError carrying the status text;
  // otherwise the contained value is returned as a Python bool.
  if (!$1.ok()) {
    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    SWIG_fail;
  }
  $result = PyBool_FromLong($1.ConsumeValueOrDie());
}
149 
%typemap(out) StatusOr<string> {
  // A failed status is raised as RuntimeError carrying the status text;
  // otherwise the contained string is returned via PyString_FromString.
  if (!$1.ok()) {
    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    SWIG_fail;
  }
  const auto value = $1.ConsumeValueOrDie();
  $result = PyString_FromString(value.c_str());
}
158 
%typemap(out) Status {
  // An OK status maps to Python None; any other status is raised as
  // RuntimeError carrying the status text.
  if ($1.ok()) {
    Py_INCREF(Py_None);
    $result = Py_None;
  } else {
    PyErr_SetString(PyExc_RuntimeError, $1.ToString().c_str());
    SWIG_fail;
  }
}
168 
%typemap(in) absl::Span<const int64>
    (std::vector<int64> temps) {
  // Converts a Python sequence of int-like objects into int64 values held in
  // the typemap local `temps`; $1 is a view over that storage.
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    SWIG_fail;
  }
  const int size = PySequence_Size($input);
  if (size < 0) {
    // PySequence_Size failed and set a Python error. A negative size must not
    // reach resize(), where it would convert to a huge size_t and throw.
    SWIG_fail;
  }
  temps.resize(size);
  for (int i = 0; i < size; ++i) {
    // PySequence_GetItem returns a new reference (or NULL with an error set).
    PyObject* o = PySequence_GetItem($input, i);
    if (!o) {
      SWIG_fail;
    }
    PyObject* py_int = numpy::PyNumberToPyInt(o);
    if (!py_int) {
      PyErr_SetString(
          PyExc_TypeError,
          "Argument sequence element cannot be converted to int");
      Py_DECREF(o);
      SWIG_fail;
    }
    temps[i] = numpy::PyIntOrPyLongToLong(py_int);
    if (temps[i] == -1 && PyErr_Occurred()) {
      Py_DECREF(py_int);
      Py_DECREF(o);
      SWIG_fail;
    }
    Py_DECREF(py_int);
    Py_DECREF(o);
  }
  $1 = temps;
}
198 
199 // Literal
200 
%typemap(in) const Literal& (StatusOr<Literal> literal_status) {
  // Converts a (nested tuple of) numpy ndarray into a Literal. The converted
  // value is owned by the typemap local `literal_status`; $1 points into it.
  literal_status = numpy::XlaLiteralFromPyObject($input);
  if (literal_status.ok()) {
    $1 = &literal_status.ValueOrDie();
  } else {
    PyErr_SetString(PyExc_RuntimeError,
                    literal_status.status().ToString().c_str());
    SWIG_fail;
  }
}
209 
// Converts a Literal returned by value into its Python (nested tuple of)
// numpy ndarray form; conversion failures are raised as RuntimeError.
// NOTE(review): the body dereferences `$1` (`*$1`), which does not type-check
// if `$1` is a plain `Literal` value; presumably no wrapped function currently
// returns `Literal` by value (StatusOr<Literal> has its own typemap below) --
// confirm against the generated wrapper before relying on this typemap.
%typemap(out) Literal (StatusOr<numpy::Safe_PyObjectPtr> obj_status) {
  obj_status = numpy::PyObjectFromXlaLiteral(*$1);
  if (!obj_status.ok()) {
    PyErr_SetString(PyExc_RuntimeError, obj_status.status().ToString().c_str());
    SWIG_fail;
  }
  // Hand the new reference held by the smart pointer over to the wrapper.
  $result = obj_status.ValueOrDie().release();
}
218 
%typemap(out) StatusOr<Literal> (StatusOr<numpy::Safe_PyObjectPtr> obj_status) {
  // Both failure modes -- the returned status and the numpy conversion -- are
  // raised as RuntimeError with the corresponding status text.
  if (!$1.ok()) {
    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    SWIG_fail;
  }
  const Literal& literal = $1.ValueOrDie();
  obj_status = numpy::PyObjectFromXlaLiteral(literal);
  if (obj_status.ok()) {
    // Transfer the owned reference out of the smart pointer.
    $result = obj_status.ValueOrDie().release();
  } else {
    PyErr_SetString(PyExc_RuntimeError, obj_status.status().ToString().c_str());
    SWIG_fail;
  }
}
231 
%typemap(in) const std::vector<Literal>& (std::vector<Literal> temps) {
  // Converts each element of a Python sequence (each a nested tuple of numpy
  // ndarrays) into a Literal held in the typemap local `temps`.
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    SWIG_fail;
  }
  const int size = PySequence_Size($input);
  if (size < 0) {
    // PySequence_Size set a Python error; don't silently return an empty vector.
    SWIG_fail;
  }
  temps.reserve(size);
  for (int i = 0; i < size; ++i) {
    // New reference (or NULL with an error set).
    PyObject* o = PySequence_GetItem($input, i);
    if (!o) {
      SWIG_fail;
    }
    StatusOr<Literal> literal_status = numpy::XlaLiteralFromPyObject(o);
    Py_DECREF(o);
    if (!literal_status.ok()) {
      PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str());
      SWIG_fail;
    }
    temps.push_back(literal_status.ConsumeValueOrDie());
  }
  $1 = &temps;
}
251 
252 // OpMetadata
253 
%typemap(in) const OpMetadata& (OpMetadata temp) {
  // Builds an OpMetadata from the Python object into the typemap local
  // `temp`; conversion errors are raised as RuntimeError.
  StatusOr<OpMetadata> metadata_or = numpy::OpMetadataFromPyObject($input);
  if (metadata_or.ok()) {
    temp = std::move(metadata_or).ValueOrDie();
    $1 = &temp;
  } else {
    PyErr_SetString(PyExc_RuntimeError, metadata_or.status().ToString().c_str());
    SWIG_fail;
  }
}
263 
264 // Shape
265 
%typemap(out) const Shape& {
  // Converts the referenced Shape into its Python shape-info form and hands
  // the new reference to the wrapper.
  const Shape& shape = *$1;
  $result = numpy::PyShapeInfoFromXlaShape(shape).release();
}
269 
%typemap(out) StatusOr<Shape> {
  // A failed status is raised as RuntimeError; otherwise the Shape is
  // converted to its Python shape-info form.
  if (!$1.ok()) {
    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    SWIG_fail;
  }
  $result = numpy::PyShapeInfoFromXlaShape($1.ConsumeValueOrDie()).release();
}
278 
279 
%typemap(out) StatusOr<ProgramShape> {
  // A failed status is raised as RuntimeError; otherwise the ProgramShape is
  // converted to its Python ([arg_shapes], ret_shape) form.
  if (!$1.ok()) {
    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    SWIG_fail;
  }
  $result = numpy::PyProgramShapeInfoFromXlaProgramShape(
      $1.ConsumeValueOrDie()).release();
}
289 
290 
%typemap(in) const Shape& (Shape temp) {
  // Builds a Shape from an object duck-typed as xla_client.Shape into the
  // typemap local `temp`; conversion errors are raised as RuntimeError.
  StatusOr<Shape> shape_or = numpy::XlaShapeFromPyShape($input);
  if (shape_or.ok()) {
    temp = std::move(shape_or).ValueOrDie();
    $1 = &temp;
  } else {
    PyErr_SetString(PyExc_RuntimeError, shape_or.status().ToString().c_str());
    SWIG_fail;
  }
}
300 
%typemap(in) const absl::optional<Shape>& (
    absl::optional<Shape> temp) {
  // Python None maps to an empty optional; anything else is converted as a
  // shape object, with conversion errors raised as RuntimeError.
  if ($input != Py_None) {
    StatusOr<Shape> shape_or = numpy::XlaShapeFromPyShape($input);
    if (!shape_or.ok()) {
      PyErr_SetString(PyExc_RuntimeError, shape_or.status().ToString().c_str());
      SWIG_fail;
    }
    temp = std::move(shape_or).ValueOrDie();
  } else {
    temp = absl::nullopt;
  }
  $1 = &temp;
}
316 
%typemap(out) std::unique_ptr<Shape> {
  // Converts the owned Shape into its Python shape-info form. A null pointer
  // is raised as RuntimeError instead of being dereferenced (undefined
  // behavior).
  if (!$1) {
    PyErr_SetString(PyExc_RuntimeError, "Unexpected null Shape");
    SWIG_fail;
  }
  $result = numpy::PyShapeInfoFromXlaShape(*$1).release();
}
320 
%typemap(in) const std::vector<Shape>& (std::vector<Shape> temps) {
  // Converts a Python sequence of shape objects into the typemap local
  // `temps`; conversion errors are raised as RuntimeError.
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    SWIG_fail;
  }
  const int size = PySequence_Size($input);
  if (size < 0) {
    // PySequence_Size set a Python error; don't silently return an empty vector.
    SWIG_fail;
  }
  temps.reserve(size);
  for (int i = 0; i < size; ++i) {
    // New reference (or NULL with an error set).
    PyObject* o = PySequence_GetItem($input, i);
    if (!o) {
      SWIG_fail;
    }
    StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape(o);
    Py_DECREF(o);
    if (!statusor.ok()) {
      PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
      SWIG_fail;
    }
    temps.push_back(statusor.ConsumeValueOrDie());
  }
  $1 = &temps;
}
339 
%typemap(in) const std::vector<absl::optional<Shape> >& (
    std::vector<absl::optional<Shape> > temps) {
  // Converts a Python sequence whose items are None or shape objects into a
  // vector of optional Shapes (None -> absl::nullopt).
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    SWIG_fail;
  }
  const int size = PySequence_Size($input);
  if (size < 0) {
    // PySequence_Size set a Python error; don't silently return an empty vector.
    SWIG_fail;
  }
  temps.reserve(size);
  for (int i = 0; i < size; ++i) {
    // PySequence_GetItem returns a new reference that must be released on
    // every path, including when the item is None.
    PyObject* o = PySequence_GetItem($input, i);
    if (!o) {
      SWIG_fail;
    }
    if (o == Py_None) {
      Py_DECREF(o);
      temps.push_back(absl::nullopt);
      continue;
    }
    StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape(o);
    Py_DECREF(o);
    if (!statusor.ok()) {
      PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
      SWIG_fail;
    }
    temps.push_back(statusor.ConsumeValueOrDie());
  }
  $1 = &temps;
}
363 
364 // PrimitiveType
365 
%typemap(in) PrimitiveType {
  // Accepts any Python number convertible to int whose value is a valid
  // PrimitiveType enum value per xla_data.proto.
  PyObject* py_int = numpy::PyNumberToPyInt($input);
  if (!py_int) {
    PyErr_SetString(PyExc_TypeError, "Argument cannot be converted to int");
    SWIG_fail;
  }
  const long value = numpy::PyIntOrPyLongToLong(py_int);
  // `py_int` is owned by this typemap; release it on every path, including
  // success.
  Py_DECREF(py_int);
  if (value == -1 && PyErr_Occurred()) {
    SWIG_fail;
  }
  // PrimitiveType_IsValid takes an int: reject values outside int range first
  // so truncation cannot turn an invalid value into a valid one.
  if (static_cast<long>(static_cast<int>(value)) != value ||
      !PrimitiveType_IsValid(static_cast<int>(value))) {
    PyErr_SetString(
        PyExc_TypeError, "Argument not valid for PrimitiveType enum");
    SWIG_fail;
  }
  $1 = static_cast<PrimitiveType>(value);
}
385 
// Span<pair<int64, int64>>
387 
%typemap(in) absl::Span<const std::pair<int64, int64> >
    (std::vector<std::pair<int64, int64> > temps) {
  // Converts a Python sequence of 2-tuples of int-like objects into int64
  // pairs held in the typemap local `temps`; $1 is a view over that storage.
  // Elements are read with PyTuple_GetItem, so each item must be a tuple.
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    SWIG_fail;
  }
  const int size = PySequence_Size($input);
  if (size < 0) {
    // PySequence_Size set a Python error; a negative size must not reach
    // reserve(), where it would convert to a huge size_t and throw.
    SWIG_fail;
  }
  temps.reserve(size);
  for (int i = 0; i < size; ++i) {
    // New reference: released on every path below.
    PyObject* o = PySequence_GetItem($input, i);
    if (!o) {
      SWIG_fail;
    }
    // PyTuple_GetItem returns borrowed references, which must not be DECREF'd.
    PyObject* first = PyTuple_GetItem(o, 0);
    if (!first) {
      Py_DECREF(o);
      SWIG_fail;
    }
    PyObject* first_pyint = numpy::PyNumberToPyInt(first);
    if (!first_pyint) {
      PyErr_SetString(
          PyExc_TypeError,
          "First pair item cannot be converted to int");
      Py_DECREF(o);
      SWIG_fail;
    }
    PyObject* second = PyTuple_GetItem(o, 1);
    if (!second) {
      Py_DECREF(o);
      Py_DECREF(first_pyint);
      SWIG_fail;
    }
    PyObject* second_pyint = numpy::PyNumberToPyInt(second);
    if (!second_pyint) {
      PyErr_SetString(
          PyExc_TypeError,
          "Second pair item cannot be converted to int");
      Py_DECREF(o);
      Py_DECREF(first_pyint);
      SWIG_fail;
    }
    const int64 first_value = numpy::PyIntOrPyLongToLong(first_pyint);
    if (first_value == -1 && PyErr_Occurred()) {
      Py_DECREF(o);
      Py_DECREF(first_pyint);
      Py_DECREF(second_pyint);
      SWIG_fail;
    }
    const int64 second_value = numpy::PyIntOrPyLongToLong(second_pyint);
    if (second_value == -1 && PyErr_Occurred()) {
      Py_DECREF(o);
      Py_DECREF(first_pyint);
      Py_DECREF(second_pyint);
      SWIG_fail;
    }
    // The converted ints are owned here and must also be released on success.
    Py_DECREF(first_pyint);
    Py_DECREF(second_pyint);
    temps.push_back(std::make_pair(first_value, second_value));
    Py_DECREF(o);
  }
  $1 = temps;
}
448 
449 // DotDimensionNumbers
450 
%typemap(in) const DotDimensionNumbers&
    (DotDimensionNumbers dimension_numbers) {
  // Reads four repeated-int attributes from the duck-typed Python proto, in
  // order, stopping at the first one that fails.
  const bool all_read =
      HandleRepeatedInt64Attribute(
          $input, "lhs_contracting_dimensions",
          dimension_numbers.mutable_lhs_contracting_dimensions()) &&
      HandleRepeatedInt64Attribute(
          $input, "rhs_contracting_dimensions",
          dimension_numbers.mutable_rhs_contracting_dimensions()) &&
      HandleRepeatedInt64Attribute(
          $input, "lhs_batch_dimensions",
          dimension_numbers.mutable_lhs_batch_dimensions()) &&
      HandleRepeatedInt64Attribute(
          $input, "rhs_batch_dimensions",
          dimension_numbers.mutable_rhs_batch_dimensions());
  if (!all_read) {
    SWIG_fail;
  }
  $1 = &dimension_numbers;
}
476 
477 // PaddingConfig
478 
%typemap(in) const PaddingConfig&
    (PaddingConfig padding_config) {
  // Reads a duck-typed Python PaddingConfig: a `dimensions` sequence whose
  // items expose edge_padding_low, edge_padding_high and interior_padding.
  PyObject* py_dims = PyObject_GetAttrString($input, "dimensions");
  if (py_dims == nullptr) {
    SWIG_fail;
  }

  const int num_dims = PySequence_Size(py_dims);
  if (num_dims == -1) {
    Py_DECREF(py_dims);
    SWIG_fail;
  }

  for (int d = 0; d < num_dims; ++d) {
    PyObject* py_dim = PySequence_GetItem(py_dims, d);
    if (py_dim == nullptr) {
      Py_DECREF(py_dims);
      SWIG_fail;
    }
    int64 low, high, interior;
    const bool read_ok = GetIntAttr(py_dim, "edge_padding_low", &low) &&
                         GetIntAttr(py_dim, "edge_padding_high", &high) &&
                         GetIntAttr(py_dim, "interior_padding", &interior);
    Py_DECREF(py_dim);
    if (!read_ok) {
      Py_DECREF(py_dims);
      SWIG_fail;
    }

    PaddingConfig::PaddingConfigDimension* dim =
        padding_config.add_dimensions();
    dim->set_edge_padding_low(low);
    dim->set_edge_padding_high(high);
    dim->set_interior_padding(interior);
  }
  Py_DECREF(py_dims);

  $1 = &padding_config;
}
518 
519 // ConvolutionDimensionNumbers
520 
%typemap(in) const ConvolutionDimensionNumbers&
    (ConvolutionDimensionNumbers dimension_numbers) {
  // Reads the duck-typed Python proto: six scalar int attributes, then three
  // repeated-int attributes, failing on the first one that cannot be read.
  int64 scalar;
  auto read_scalar = [&](const char* attr_name) {
    return GetIntAttr($input, attr_name, &scalar);
  };

  if (!read_scalar("input_batch_dimension")) {
    SWIG_fail;
  }
  dimension_numbers.set_input_batch_dimension(scalar);

  if (!read_scalar("input_feature_dimension")) {
    SWIG_fail;
  }
  dimension_numbers.set_input_feature_dimension(scalar);

  if (!read_scalar("output_batch_dimension")) {
    SWIG_fail;
  }
  dimension_numbers.set_output_batch_dimension(scalar);

  if (!read_scalar("output_feature_dimension")) {
    SWIG_fail;
  }
  dimension_numbers.set_output_feature_dimension(scalar);

  if (!read_scalar("kernel_output_feature_dimension")) {
    SWIG_fail;
  }
  dimension_numbers.set_kernel_output_feature_dimension(scalar);

  if (!read_scalar("kernel_input_feature_dimension")) {
    SWIG_fail;
  }
  dimension_numbers.set_kernel_input_feature_dimension(scalar);

  const bool spatial_read =
      HandleRepeatedInt64Attribute(
          $input, "input_spatial_dimensions",
          dimension_numbers.mutable_input_spatial_dimensions()) &&
      HandleRepeatedInt64Attribute(
          $input, "kernel_spatial_dimensions",
          dimension_numbers.mutable_kernel_spatial_dimensions()) &&
      HandleRepeatedInt64Attribute(
          $input, "output_spatial_dimensions",
          dimension_numbers.mutable_output_spatial_dimensions());
  if (!spatial_read) {
    SWIG_fail;
  }

  $1 = &dimension_numbers;
}
573 
574 // GatherDimensionNumbers
575 
%typemap(in) const GatherDimensionNumbers&
    (GatherDimensionNumbers dimension_numbers) {
  // Reads three repeated-int attributes and the scalar index_vector_dim from
  // the duck-typed Python proto, stopping at the first failure.
  const bool repeated_read =
      HandleRepeatedInt64Attribute(
          $input, "offset_dims",
          dimension_numbers.mutable_offset_dims()) &&
      HandleRepeatedInt64Attribute(
          $input, "collapsed_slice_dims",
          dimension_numbers.mutable_collapsed_slice_dims()) &&
      HandleRepeatedInt64Attribute(
          $input, "start_index_map",
          dimension_numbers.mutable_start_index_map());
  if (!repeated_read) {
    SWIG_fail;
  }

  int64 index_vector_dim;
  if (!GetIntAttr($input, "index_vector_dim", &index_vector_dim)) {
    SWIG_fail;
  }
  dimension_numbers.set_index_vector_dim(index_vector_dim);

  $1 = &dimension_numbers;
}
602 
603 // ScatterDimensionNumbers
604 
%typemap(in) const ScatterDimensionNumbers&
    (ScatterDimensionNumbers dimension_numbers) {
  // Reads three repeated-int attributes and the scalar index_vector_dim from
  // the duck-typed Python proto, stopping at the first failure.
  const bool repeated_read =
      HandleRepeatedInt64Attribute(
          $input, "update_window_dims",
          dimension_numbers.mutable_update_window_dims()) &&
      HandleRepeatedInt64Attribute(
          $input, "inserted_window_dims",
          dimension_numbers.mutable_inserted_window_dims()) &&
      HandleRepeatedInt64Attribute(
          $input, "scatter_dims_to_operand_dims",
          dimension_numbers.mutable_scatter_dims_to_operand_dims());
  if (!repeated_read) {
    SWIG_fail;
  }

  int64 index_vector_dim;
  if (!GetIntAttr($input, "index_vector_dim", &index_vector_dim)) {
    SWIG_fail;
  }
  dimension_numbers.set_index_vector_dim(index_vector_dim);

  $1 = &dimension_numbers;
}
631 
632 // Span<const ReplicaGroup>
633 
%typemap(in) absl::Span<const ReplicaGroup >
    (std::vector<ReplicaGroup > temps) {
  // Converts a Python sequence of duck-typed ReplicaGroup protos (each with a
  // repeated `replica_ids` attribute) into the typemap local `temps`; $1 is a
  // view over that storage.
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    SWIG_fail;
  }
  const int size = PySequence_Size($input);
  if (size < 0) {
    // PySequence_Size set a Python error; a negative size must not reach
    // reserve(), where it would convert to a huge size_t and throw.
    SWIG_fail;
  }
  temps.reserve(size);
  for (int i = 0; i < size; ++i) {
    // New reference (or NULL with an error set); released before any exit.
    PyObject* o = PySequence_GetItem($input, i);
    if (!o) {
      SWIG_fail;
    }
    ReplicaGroup rgrp;
    const bool ok = HandleRepeatedInt64Attribute(
        o, "replica_ids", rgrp.mutable_replica_ids());
    Py_DECREF(o);
    if (!ok) {
      SWIG_fail;
    }
    temps.push_back(std::move(rgrp));
  }
  $1 = temps;
}
655