1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // These functions transform Python/Numpy data structures to XLA data
17 // structures and vice versa, performing copies where
18 // appropriate. Python tuples and Numpy ndarrays translate to XLA
19 // tuples and XLA literals, respectively, and Numpy shape/dtype
20 // information is translated to XLA shape information.
21 
22 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_NUMPY_BRIDGE_H_
23 #define TENSORFLOW_COMPILER_XLA_PYTHON_NUMPY_BRIDGE_H_
24 
25 #include <algorithm>
26 #include <memory>
27 
28 #include "absl/types/span.h"
29 #include "tensorflow/compiler/xla/literal.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 #include "tensorflow/python/lib/core/numpy.h"
32 
33 namespace xla {
34 
35 namespace swig {
36 
37 namespace numpy {
38 
39 struct PyDecrefDeleter {
operatorPyDecrefDeleter40   void operator()(PyObject* p) const { Py_DECREF(p); }
41 };
42 
43 // Safe container for an owned PyObject. On destruction, the reference count of
44 // the contained object will be decremented.
45 using Safe_PyObjectPtr = std::unique_ptr<PyObject, PyDecrefDeleter>;
46 
47 Safe_PyObjectPtr make_safe(PyObject* object);
48 
49 // Maps XLA primitive types (PRED, S8, F32, ..., and TUPLE) to numpy
50 // dtypes (NPY_BOOL, NPY_INT8, NPY_FLOAT32, ..., and NPY_OBJECT), and
51 // vice versa.
52 int PrimitiveTypeToNumpyType(PrimitiveType primitive_type);
53 PrimitiveType NumpyTypeToPrimitiveType(int np_type);
54 
55 // Determines whether an integer-encoded Numpy dtype is valid,
56 // i.e. has a supported conversion to an XLA PrimitiveType.
57 bool NumpyTypeIsValid(int np_type);
58 
59 // Converts XLA shape information into a Python pair of the form
60 // (numpy dtype, dimensions). If the XLA shape represents a tuple,
61 // then the numpy dtype is NPY_OBJECT ('O') and `dimensions` is a
62 // Python tuple of shape-description pairs, created
63 // recursively. Otherwise, `dimensions` is a Python tuple-of-integers
64 // providing the array dimensions.
65 //
// The returned Safe_PyObjectPtr owns a new reference.
67 Safe_PyObjectPtr PyShapeInfoFromXlaShape(const Shape& shape);
68 
69 // Returns a pair of (arg_shapes, result_shape), where arg_shapes is a tuple
70 // of argument shapes and result_shape is the result shape. Each shape is as
// described in PyShapeInfoFromXlaShape's comment.
72 Safe_PyObjectPtr PyProgramShapeInfoFromXlaProgramShape(
73     const ProgramShape& shape);
74 
// Converts a Python object with a method interface matching that of
// xla_client.Shape into an XLA Shape object.
//
// The Shape is returned by value as a C++ object; no Python reference is
// transferred to the caller.
79 StatusOr<Shape> XlaShapeFromPyShape(PyObject* o);
80 
81 // Converts a PyObject that represents operation metadata into protocol buffer
82 // form.
83 StatusOr<OpMetadata> OpMetadataFromPyObject(PyObject* o);
84 
85 // Converts an XLA literal to a Python object, either a Numpy ndarray
86 // or a nested Python tuple thereof.
87 //
88 // To avoid transferring ownership of the data buffers that underlie
89 // PyArrays and XLA literals, this function makes deep copies of all
90 // array data.
91 //
92 // The return value is a new reference.
93 StatusOr<Safe_PyObjectPtr> PyObjectFromXlaLiteral(const LiteralSlice& literal);
94 
95 // Converts a Numpy ndarray or a nested Python tuple thereof to a
96 // corresponding XLA literal.
97 //
98 // To avoid transferring ownership of the data buffers that underlie
99 // PyArrays and XLA literals, this function makes deep copies of all
100 // array data.
101 StatusOr<Literal> XlaLiteralFromPyObject(PyObject* o);
102 
103 // The following functions copy array data from the buffers underlying Numpy
104 // ndarrays into those underlying XLA literals, and vice versa.
105 
106 Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array,
107                                Literal* literal);
108 
109 Status CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
110                                PyArrayObject* py_array);
111 
112 template <typename NativeT>
CopyNumpyArrayToLiteral(PyArrayObject * py_array,Literal * literal)113 void CopyNumpyArrayToLiteral(PyArrayObject* py_array, Literal* literal) {
114   NativeT* source = static_cast<NativeT*>(PyArray_DATA(py_array));
115   auto dest = literal->data<NativeT>();
116   std::copy(source, source + PyArray_SIZE(py_array), dest.data());
117 }
118 
119 template <typename NativeT>
CopyLiteralToNumpyArray(const LiteralSlice & literal,PyArrayObject * py_array)120 void CopyLiteralToNumpyArray(const LiteralSlice& literal,
121                              PyArrayObject* py_array) {
122   NativeT* dest = static_cast<NativeT*>(PyArray_DATA(py_array));
123   auto source = literal.data<NativeT>();
124   std::copy(source.begin(), source.end(), dest);
125 }
126 
127 // Safely returns a repr of the given Python object o as a C++ string.
128 string PyObjectCppRepr(PyObject* o);
129 
130 // Workarounds for Python 2 and 3 interop
131 
132 PyObject* LongToPyIntOrPyLong(long x);  // NOLINT
133 long PyIntOrPyLongToLong(PyObject* o);  // NOLINT
134 bool CheckPyIntOrLong(PyObject* o);
135 PyObject* PyNumberToPyInt(PyObject* o);
136 
137 }  // namespace numpy
138 
139 // Miscellaneous swig helpers that don't have a better home.
140 
141 bool GetIntAttr(PyObject* o, const char* field, int64* result);
142 
// Returns an "ok" flag: true if there was no error, false if there was one.
144 bool HandleStringAttribute(PyObject* o, const char* attr_name,
145                            std::function<void(string s)> f);
146 bool HandleBoolAttribute(PyObject* o, const char* attr_name,
147                          std::function<void(bool b)> f);
148 
149 bool HandleRepeatedInt64Attribute(
150     PyObject* o, const char* attr_name,
151     tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>* field);
152 
153 }  // namespace swig
154 
155 }  // namespace xla
156 
157 #endif  // TENSORFLOW_COMPILER_XLA_PYTHON_NUMPY_BRIDGE_H_
158