1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // These functions transform Python/Numpy data structures to XLA data
17 // structures and vice versa, performing copies where
18 // appropriate. Python tuples and Numpy ndarrays translate to XLA
19 // tuples and XLA literals, respectively, and Numpy shape/dtype
20 // information is translated to XLA shape information.
21
22 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_NUMPY_BRIDGE_H_
23 #define TENSORFLOW_COMPILER_XLA_PYTHON_NUMPY_BRIDGE_H_
24
#include <algorithm>
#include <functional>
#include <memory>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/python/lib/core/numpy.h"
32
33 namespace xla {
34
35 namespace swig {
36
37 namespace numpy {
38
39 struct PyDecrefDeleter {
operatorPyDecrefDeleter40 void operator()(PyObject* p) const { Py_DECREF(p); }
41 };
42
// Safe container for an owned PyObject. On destruction, the reference count of
// the contained object will be decremented (via PyDecrefDeleter). An empty
// (null) container releases nothing.
using Safe_PyObjectPtr = std::unique_ptr<PyObject, PyDecrefDeleter>;

// Wraps `object` in a Safe_PyObjectPtr. Because the result will Py_DECREF the
// object when destroyed, this presumably takes over ("steals") one reference
// the caller already owns rather than adding a new one -- confirm against the
// definition before passing a borrowed reference.
Safe_PyObjectPtr make_safe(PyObject* object);
48
// Maps XLA primitive types (PRED, S8, F32, ..., and TUPLE) to numpy
// dtypes (NPY_BOOL, NPY_INT8, NPY_FLOAT32, ..., and NPY_OBJECT), and
// vice versa.
//
// Numpy dtypes are passed as their integer type numbers (the NPY_*
// enumerators). NOTE(review): the result for a type with no counterpart is not
// visible from this header; callers converting arbitrary dtypes should
// presumably check NumpyTypeIsValid first -- confirm against the definitions.
int PrimitiveTypeToNumpyType(PrimitiveType primitive_type);
PrimitiveType NumpyTypeToPrimitiveType(int np_type);

// Determines whether an integer-encoded Numpy dtype is valid,
// i.e. has a supported conversion to an XLA PrimitiveType.
bool NumpyTypeIsValid(int np_type);
58
// Converts XLA shape information into a Python pair of the form
// (numpy dtype, dimensions). If the XLA shape represents a tuple,
// then the numpy dtype is NPY_OBJECT ('O') and `dimensions` is a
// Python tuple of shape-description pairs, created
// recursively. Otherwise, `dimensions` is a Python tuple-of-integers
// providing the array dimensions.
//
// The return value is a new reference, owned by the returned
// Safe_PyObjectPtr (released when that pointer is destroyed).
Safe_PyObjectPtr PyShapeInfoFromXlaShape(const Shape& shape);

// Returns a pair of (arg_shapes, result_shape), where arg_shapes is a tuple
// of argument shapes and result_shape is the result shape. Each shape is as
// described in PyShapeInfoFromXlaShape's comment.
//
// The returned Safe_PyObjectPtr owns the reference to the resulting pair.
Safe_PyObjectPtr PyProgramShapeInfoFromXlaProgramShape(
    const ProgramShape& shape);
74
// Converts a Python object with a method interface matching that of
// xla_client.Shape into an XLA Shape object.
//
// The Shape is returned by value inside a StatusOr; no Python reference is
// handed back to the caller. A non-OK status is returned if the conversion
// fails.
StatusOr<Shape> XlaShapeFromPyShape(PyObject* o);

// Converts a PyObject that represents operation metadata into protocol buffer
// form. A non-OK status is returned if the conversion fails.
StatusOr<OpMetadata> OpMetadataFromPyObject(PyObject* o);
84
// Converts an XLA literal to a Python object, either a Numpy ndarray
// or a nested Python tuple thereof.
//
// To avoid transferring ownership of the data buffers that underlie
// PyArrays and XLA literals, this function makes deep copies of all
// array data.
//
// The return value is a new reference; on success it is owned by the
// Safe_PyObjectPtr held in the returned StatusOr.
StatusOr<Safe_PyObjectPtr> PyObjectFromXlaLiteral(const LiteralSlice& literal);

// Converts a Numpy ndarray or a nested Python tuple thereof to a
// corresponding XLA literal.
//
// To avoid transferring ownership of the data buffers that underlie
// PyArrays and XLA literals, this function makes deep copies of all
// array data. A non-OK status is returned if the conversion fails
// (NOTE(review): presumably e.g. for unsupported dtypes -- confirm).
StatusOr<Literal> XlaLiteralFromPyObject(PyObject* o);
102
// The following functions copy array data from the buffers underlying Numpy
// ndarrays into those underlying XLA literals, and vice versa.
//
// These overloads take the element type as a runtime integer Numpy dtype
// (`np_type`) and report failure through the returned Status.
// NOTE(review): they presumably dispatch on `np_type` to the NativeT-templated
// overloads below, which perform no size or layout checks -- confirm against
// the definitions.

Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array,
                               Literal* literal);

Status CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
                               PyArrayObject* py_array);
111
112 template <typename NativeT>
CopyNumpyArrayToLiteral(PyArrayObject * py_array,Literal * literal)113 void CopyNumpyArrayToLiteral(PyArrayObject* py_array, Literal* literal) {
114 NativeT* source = static_cast<NativeT*>(PyArray_DATA(py_array));
115 auto dest = literal->data<NativeT>();
116 std::copy(source, source + PyArray_SIZE(py_array), dest.data());
117 }
118
119 template <typename NativeT>
CopyLiteralToNumpyArray(const LiteralSlice & literal,PyArrayObject * py_array)120 void CopyLiteralToNumpyArray(const LiteralSlice& literal,
121 PyArrayObject* py_array) {
122 NativeT* dest = static_cast<NativeT*>(PyArray_DATA(py_array));
123 auto source = literal.data<NativeT>();
124 std::copy(source.begin(), source.end(), dest);
125 }
126
// Safely returns a repr of the given Python object o as a C++ string.
string PyObjectCppRepr(PyObject* o);

// Workarounds for Python 2 and 3 interop: Python 2 has separate `int` and
// `long` types whereas Python 3 has only `int`, so these helpers let callers
// handle integers without caring which interpreter they run under.
// NOTE(review): judging by their names, they convert a C long to a Python
// integer, convert a Python integer back to a C long, test whether an object
// is a Python integer, and coerce a number to a Python integer; reference
// ownership of the returned PyObject* is not visible here -- confirm against
// the definitions.

PyObject* LongToPyIntOrPyLong(long x);  // NOLINT
long PyIntOrPyLongToLong(PyObject* o);  // NOLINT
bool CheckPyIntOrLong(PyObject* o);
PyObject* PyNumberToPyInt(PyObject* o);
136
137 } // namespace numpy
138
139 // Miscellaneous swig helpers that don't have a better home.
140
141 bool GetIntAttr(PyObject* o, const char* field, int64* result);
142
143 // Returns "ok"; true if there is no error, false if there was an error.
144 bool HandleStringAttribute(PyObject* o, const char* attr_name,
145 std::function<void(string s)> f);
146 bool HandleBoolAttribute(PyObject* o, const char* attr_name,
147 std::function<void(bool b)> f);
148
149 bool HandleRepeatedInt64Attribute(
150 PyObject* o, const char* attr_name,
151 tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>* field);
152
153 } // namespace swig
154
155 } // namespace xla
156
157 #endif // TENSORFLOW_COMPILER_XLA_PYTHON_NUMPY_BRIDGE_H_
158