/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/numpy_bridge.h"

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

namespace swig {

namespace numpy {

// Takes ownership of a (new) reference to `object`. The returned
// Safe_PyObjectPtr drops that reference when destroyed unless it is
// release()d first.
Safe_PyObjectPtr make_safe(PyObject* object) {
  return Safe_PyObjectPtr(object);
}

// Maps an XLA primitive type to the corresponding numpy type number.
// TUPLE maps to NPY_OBJECT; any other unsupported type is a fatal error.
int PrimitiveTypeToNumpyType(PrimitiveType primitive_type) {
  switch (primitive_type) {
    case PRED:
      return NPY_BOOL;
    case S8:
      return NPY_INT8;
    case S16:
      return NPY_INT16;
    case S32:
      return NPY_INT32;
    case S64:
      return NPY_INT64;
    case U8:
      return NPY_UINT8;
    case U16:
      return NPY_UINT16;
    case U32:
      return NPY_UINT32;
    case U64:
      return NPY_UINT64;
    case F16:
      return NPY_FLOAT16;
    case F32:
      return NPY_FLOAT32;
    case F64:
      return NPY_FLOAT64;
    case C64:
      return NPY_COMPLEX64;
    case C128:
      return NPY_COMPLEX128;
    case TUPLE:
      return NPY_OBJECT;
    default:
      LOG(FATAL) << "No Numpy type for XLA primitive type " << primitive_type;
  }
}

// Inverse of PrimitiveTypeToNumpyType. NPY_OBJECT maps to TUPLE; any type
// number rejected by NumpyTypeIsValid is a fatal error.
PrimitiveType NumpyTypeToPrimitiveType(int np_type) {
  switch (np_type) {
    case NPY_BOOL:
      return PRED;
    case NPY_INT8:
      return S8;
    case NPY_INT16:
      return S16;
    case NPY_INT32:
      return S32;
    case NPY_INT64:
      return S64;
    case NPY_UINT8:
      return U8;
    case NPY_UINT16:
      return U16;
    case NPY_UINT32:
      return U32;
    case NPY_UINT64:
      return U64;
    case NPY_FLOAT16:
      return F16;
    case NPY_FLOAT32:
      return F32;
    case NPY_FLOAT64:
      return F64;
    case NPY_COMPLEX64:
      return C64;
    case NPY_COMPLEX128:
      return C128;
    case NPY_OBJECT:
      return TUPLE;
    default:
      LOG(FATAL) << "No XLA primitive type for Numpy type " << np_type;
  }
}

// Returns true iff NumpyTypeToPrimitiveType accepts `np_type`.
bool NumpyTypeIsValid(int np_type) {
  switch (np_type) {
    case NPY_BOOL:
    case NPY_INT8:
    case NPY_INT16:
    case NPY_INT32:
    case NPY_INT64:
    case NPY_UINT8:
    case NPY_UINT16:
    case NPY_UINT32:
    case NPY_UINT64:
    case NPY_FLOAT16:
    case NPY_FLOAT32:
    case NPY_FLOAT64:
    case NPY_COMPLEX64:
    case NPY_COMPLEX128:
    case NPY_OBJECT:
      return true;
    default:
      return false;
  }
}

// Builds the Python pair (numpy dtype, dimensions) describing `shape`. For a
// tuple shape, `dimensions` is a tuple of the element shapes' pairs;
// otherwise it is a tuple of Python ints, one per dimension.
Safe_PyObjectPtr PyShapeInfoFromXlaShape(const Shape& shape) {
  const int np_typenum = PrimitiveTypeToNumpyType(shape.element_type());
  // PyArray_DescrFromType returns a new reference. Own it here so it is
  // dropped once PyTuple_Pack (which takes its own reference) has run.
  Safe_PyObjectPtr np_dtype = make_safe(
      reinterpret_cast<PyObject*>(PyArray_DescrFromType(np_typenum)));
  Safe_PyObjectPtr dimensions;
  if (shape.IsTuple()) {
    const int num_elements = ShapeUtil::TupleElementCount(shape);
    dimensions = make_safe(PyTuple_New(num_elements));
    for (int i = 0; i < num_elements; ++i) {
      // PyTuple_SET_ITEM steals the reference, hence release().
      PyTuple_SET_ITEM(
          dimensions.get(), i,
          PyShapeInfoFromXlaShape(ShapeUtil::GetTupleElementShape(shape, i))
              .release());
    }
  } else {
    const int rank = shape.rank();
    dimensions = make_safe(PyTuple_New(rank));
    for (int i = 0; i < rank; ++i) {
      PyTuple_SET_ITEM(dimensions.get(), i,
                       LongToPyIntOrPyLong(ShapeUtil::GetDimension(shape, i)));
    }
  }
  // PyTuple_Pack does NOT steal its arguments; it increments each one. Pass
  // borrowed pointers and let the owning locals drop their references on
  // return. Releasing them here would leak one reference each.
  return make_safe(PyTuple_Pack(2, np_dtype.get(), dimensions.get()));
}

// Builds the Python pair (parameter shape infos, result shape info) for a
// ProgramShape, using PyShapeInfoFromXlaShape for every component.
Safe_PyObjectPtr PyProgramShapeInfoFromXlaProgramShape(
    const ProgramShape& shape) {
  Safe_PyObjectPtr arg_shapes = make_safe(PyTuple_New(shape.parameters_size()));
  for (int i = 0; i < shape.parameters_size(); ++i) {
    // PyTuple_SET_ITEM steals the reference, hence release().
    PyTuple_SET_ITEM(arg_shapes.get(), i,
                     PyShapeInfoFromXlaShape(shape.parameters(i)).release());
  }

  Safe_PyObjectPtr result_shape = PyShapeInfoFromXlaShape(shape.result());
  // PyTuple_Pack increments its arguments, so keep ownership in the locals.
  return make_safe(
      PyTuple_Pack(2, arg_shapes.get(), result_shape.get()));
}

// Precondition: o->ob_type == &PyArrayDescr_Type
static int NumpyTypenum(PyObject* o) {
  return reinterpret_cast<PyArray_Descr*>(o)->type_num;
}

// Extracts the string held inside r and returns it as a C++ string. If r is
// null or cannot be encoded, returns a placeholder naming r's address.
//
// NOTE: this is an internal helper for conversion to a C++, and so decrefs r
// (on every path where r is non-null).
static string ExtractStringAndDecref(PyObject* r) {
  auto error = [r] {
    return absl::StrFormat("<failed conversion of %p>", r);
  };
  if (r == nullptr) {
    return error();
  }
#if PY_MAJOR_VERSION < 3
  string result = PyString_AsString(r);
#else
  PyObject* bytes = PyUnicode_AsEncodedString(r, 0, 0);
  if (bytes == nullptr) {
    // Format the message before dropping r: this function owns r, and the
    // original early return here leaked it.
    string message = error();
    Py_DECREF(r);
    return message;
  }
  CHECK(PyBytes_Check(bytes));
  string result = PyBytes_AsString(bytes);
  Py_DECREF(bytes);
#endif
  Py_DECREF(r);
  return result;
}

// Safely returns a str of the given Python object o as a C++ string.
static string PyObjectCppStr(PyObject* o) {
  PyObject* s = PyObject_Str(o);
  return ExtractStringAndDecref(s);
}

// Safely returns the repr of the given Python object o as a C++ string.
string PyObjectCppRepr(PyObject* o) {
  PyObject* r = PyObject_Repr(o);
  return ExtractStringAndDecref(r);
}

// Converts a Python shape object into an XLA Shape by calling its methods:
// numpy_dtype(), then tuple_shapes() for tuples, or dimensions() and
// minor_to_major() (the latter may return None) for arrays. Returns
// InvalidArgument, including repr(o), on any malformed result.
StatusOr<Shape> XlaShapeFromPyShape(PyObject* o) {
  auto error = [o](const string& prefix) {
    return InvalidArgument("%s; got %s", prefix.c_str(),
                           PyObjectCppRepr(o).c_str());
  };
  // Calls the named method with no arguments and takes ownership of the new
  // reference it returns. Every early return below therefore releases it.
  auto call_method =
      [o, &error](const string& method) -> StatusOr<Safe_PyObjectPtr> {
    PyObject* result =
        PyObject_CallMethod(o, const_cast<char*>(method.c_str()), nullptr);
    if (result == nullptr) {
      return error(
          absl::StrCat("Failed to call method of shape object:", method));
    }
    return make_safe(result);
  };

  TF_ASSIGN_OR_RETURN(Safe_PyObjectPtr np_type, call_method("numpy_dtype"));
  if (np_type->ob_type != &PyArrayDescr_Type) {
    return error(
        "Return value of shape method numpy_dtype "
        "is not an integer numpy dtype");
  }
  const int np_typenum = NumpyTypenum(np_type.get());
  if (!NumpyTypeIsValid(np_typenum)) {
    return error(
        "Return value of shape method numpy_dtype "
        "is not a valid integer numpy dtype");
  }
  const PrimitiveType element_type = NumpyTypeToPrimitiveType(np_typenum);
  np_type.reset();

  if (element_type == TUPLE) {
    TF_ASSIGN_OR_RETURN(Safe_PyObjectPtr py_subshapes,
                        call_method("tuple_shapes"));
    if (!PyTuple_Check(py_subshapes.get())) {
      return error(
          "Return value of Shape method tuple_shapes() is not a tuple");
    }
    const int length = PyTuple_Size(py_subshapes.get());
    std::vector<Shape> subshapes;
    subshapes.reserve(length);
    for (int i = 0; i < length; i++) {
      // PyTuple_GetItem returns a borrowed reference; py_subshapes keeps the
      // item alive for the duration of the recursive call.
      TF_ASSIGN_OR_RETURN(
          Shape subshape,
          XlaShapeFromPyShape(PyTuple_GetItem(py_subshapes.get(), i)));
      subshapes.push_back(std::move(subshape));
    }
    return ShapeUtil::MakeTupleShape(subshapes);
  }

  TF_ASSIGN_OR_RETURN(Safe_PyObjectPtr py_dimensions,
                      call_method("dimensions"));
  TF_ASSIGN_OR_RETURN(Safe_PyObjectPtr py_minor_to_major,
                      call_method("minor_to_major"));
  if (!PyTuple_Check(py_dimensions.get())) {
    return error("Return value of Shape method dimensions() is not a tuple");
  }
  // A None minor_to_major means "no explicit layout".
  const bool with_layout = py_minor_to_major.get() != Py_None;
  if (with_layout && !PyTuple_Check(py_minor_to_major.get())) {
    return error(
        "Return value of Shape method minor_to_major() is neither a tuple "
        "nor None");
  }
  const int length = PyTuple_Size(py_dimensions.get());
  if (with_layout && length != PyTuple_Size(py_minor_to_major.get())) {
    return error(
        "Shape methods dimensions() and minor_to_major() return "
        "different-length tuples");
  }
  std::vector<int64> dimensions(length);
  std::vector<int64> minor_to_major(length);
  for (int i = 0; i < length; i++) {
    dimensions[i] =
        PyIntOrPyLongToLong(PyTuple_GetItem(py_dimensions.get(), i));
    // -1 is a legal return value; only a pending Python error signals failure.
    if (dimensions[i] == -1 && PyErr_Occurred()) {
      return error("Dimension is not an int");
    }

    if (with_layout) {
      minor_to_major[i] =
          PyIntOrPyLongToLong(PyTuple_GetItem(py_minor_to_major.get(), i));
      if (minor_to_major[i] == -1 && PyErr_Occurred()) {
        return error("Minor-to-major value is not an int");
      }
    }
  }

  if (with_layout) {
    return ShapeUtil::MakeShapeWithLayout(element_type, dimensions,
                                          minor_to_major);
  }
  return ShapeUtil::MakeShape(element_type, dimensions);
}

// Helper that retrieves the member with attr_name, stringifies it if it is not
// None, and returns it as a C++ string. Returns nullopt when the attribute is
// absent or None.
static absl::optional GetAttrAsString(PyObject* o, const string& attr_name) { if (!PyObject_HasAttrString(o, attr_name.c_str())) { return absl::nullopt; } PyObject* attr = PyObject_GetAttrString(o, attr_name.c_str()); if (attr == Py_None) { Py_DECREF(attr); return absl::nullopt; } string result = PyObjectCppStr(attr); Py_DECREF(attr); return result; } // Helper that retrieves the member with attr_name, checks that it is an integer // if it is not None, and returns it as an int32 value. static absl::optional GetAttrAsInt32(PyObject* o, const string& attr_name) { if (!PyObject_HasAttrString(o, attr_name.c_str())) { return absl::nullopt; } PyObject* attr = PyObject_GetAttrString(o, attr_name.c_str()); if (attr == Py_None) { Py_DECREF(attr); return absl::nullopt; } if (!CheckPyIntOrLong(attr)) { Py_DECREF(attr); return absl::nullopt; } long value = PyIntOrPyLongToLong(attr); // NOLINT Py_DECREF(attr); if (value == -1 && PyErr_Occurred() != nullptr) { return absl::nullopt; } if (static_cast(value) != value) { return absl::nullopt; } return value; } StatusOr OpMetadataFromPyObject(PyObject* o) { OpMetadata result; absl::optional op_type = GetAttrAsString(o, "op_type"); if (op_type.has_value()) { result.set_op_type(op_type.value()); } absl::optional op_name = GetAttrAsString(o, "op_name"); if (op_name.has_value()) { result.set_op_name(op_name.value()); } absl::optional source_file = GetAttrAsString(o, "source_file"); if (source_file.has_value()) { result.set_source_file(source_file.value()); } absl::optional source_line = GetAttrAsInt32(o, "source_line"); if (source_line.has_value()) { result.set_source_line(source_line.value()); } return result; } StatusOr PyObjectFromXlaLiteral(const LiteralSlice& literal) { if (literal.shape().IsTuple()) { int num_elements = ShapeUtil::TupleElementCount(literal.shape()); std::vector elems(num_elements); for (int i = 0; i < num_elements; i++) { TF_ASSIGN_OR_RETURN(elems[i], PyObjectFromXlaLiteral(LiteralSlice(literal, {i}))); } 
Safe_PyObjectPtr tuple = make_safe(PyTuple_New(num_elements)); for (int i = 0; i < num_elements; i++) { PyTuple_SET_ITEM(tuple.get(), i, elems[i].release()); } return tuple; } else { int rank = literal.shape().rank(); std::vector dimensions(rank); // NOLINT - PyArray requires a long* for (int i = 0; i < rank; i++) { dimensions[i] = ShapeUtil::GetDimension(literal.shape(), i); } int np_type = PrimitiveTypeToNumpyType(literal.shape().element_type()); Safe_PyObjectPtr array = make_safe( PyArray_EMPTY(rank, dimensions.data(), np_type, /*fortran=*/0)); TF_RETURN_IF_ERROR(CopyLiteralToNumpyArray( np_type, literal, reinterpret_cast(array.get()))); return array; } } StatusOr XlaLiteralFromPyObject(PyObject* o) { if (PyTuple_Check(o)) { int num_elements = PyTuple_Size(o); std::vector elements; elements.reserve(num_elements); for (int i = 0; i < num_elements; i++) { PyObject* element = PyTuple_GetItem(o, i); TF_ASSIGN_OR_RETURN(auto literal, XlaLiteralFromPyObject(element)); elements.push_back(std::move(literal)); } return LiteralUtil::MakeTupleOwned(std::move(elements)); } else if (PyArray_Check(o)) { PyArrayObject* py_array = reinterpret_cast(o); int rank = PyArray_NDIM(py_array); std::vector dimensions(rank); for (int i = 0; i < rank; i++) { dimensions[i] = PyArray_DIM(py_array, i); } int np_type = PyArray_TYPE(py_array); auto literal = LiteralUtil::CreateFromDimensions( NumpyTypeToPrimitiveType(np_type), dimensions); TF_RETURN_IF_ERROR(CopyNumpyArrayToLiteral(np_type, py_array, &literal)); return std::move(literal); } else { return InvalidArgument( "Non-tuple or Numpy array encountered in conversion to XLA literal."); } } Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, Literal* literal) { switch (np_type) { case NPY_BOOL: CopyNumpyArrayToLiteral(py_array, literal); break; case NPY_INT8: CopyNumpyArrayToLiteral(py_array, literal); break; case NPY_INT16: CopyNumpyArrayToLiteral(py_array, literal); break; case NPY_INT32: 
CopyNumpyArrayToLiteral(py_array, literal); break; case NPY_INT64: CopyNumpyArrayToLiteral(py_array, literal); break; case NPY_UINT8: CopyNumpyArrayToLiteral(py_array, literal); break; case NPY_UINT16: CopyNumpyArrayToLiteral(py_array, literal); break; case NPY_UINT32: CopyNumpyArrayToLiteral(py_array, literal); break; case NPY_UINT64: CopyNumpyArrayToLiteral(py_array, literal); break; case NPY_FLOAT16: CopyNumpyArrayToLiteral(py_array, literal); break; case NPY_FLOAT32: CopyNumpyArrayToLiteral(py_array, literal); break; case NPY_FLOAT64: CopyNumpyArrayToLiteral(py_array, literal); break; case NPY_COMPLEX64: CopyNumpyArrayToLiteral(py_array, literal); break; case NPY_COMPLEX128: CopyNumpyArrayToLiteral(py_array, literal); break; default: return InvalidArgument( "No XLA literal container for Numpy type number: %d", np_type); } return Status::OK(); } Status CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, PyArrayObject* py_array) { switch (np_type) { case NPY_BOOL: CopyLiteralToNumpyArray(literal, py_array); break; case NPY_INT8: CopyLiteralToNumpyArray(literal, py_array); break; case NPY_INT16: CopyLiteralToNumpyArray(literal, py_array); break; case NPY_INT32: CopyLiteralToNumpyArray(literal, py_array); break; case NPY_INT64: CopyLiteralToNumpyArray(literal, py_array); break; case NPY_UINT8: CopyLiteralToNumpyArray(literal, py_array); break; case NPY_UINT16: CopyLiteralToNumpyArray(literal, py_array); break; case NPY_UINT32: CopyLiteralToNumpyArray(literal, py_array); break; case NPY_UINT64: CopyLiteralToNumpyArray(literal, py_array); break; case NPY_FLOAT16: CopyLiteralToNumpyArray(literal, py_array); break; case NPY_FLOAT32: CopyLiteralToNumpyArray(literal, py_array); break; case NPY_FLOAT64: CopyLiteralToNumpyArray(literal, py_array); break; case NPY_COMPLEX64: CopyLiteralToNumpyArray(literal, py_array); break; case NPY_COMPLEX128: CopyLiteralToNumpyArray(literal, py_array); break; default: return InvalidArgument( "No XLA literal container for 
Numpy type number: %d", np_type); } return Status::OK(); } PyObject* LongToPyIntOrPyLong(long x) { // NOLINT #if PY_MAJOR_VERSION < 3 return PyInt_FromLong(x); #else return PyLong_FromLong(x); #endif } long PyIntOrPyLongToLong(PyObject* o) { // NOLINT #if PY_MAJOR_VERSION < 3 return PyInt_AsLong(o); #else return PyLong_AsLong(o); #endif } bool CheckPyIntOrLong(PyObject* o) { #if PY_MAJOR_VERSION < 3 return PyInt_Check(o); #else if (!PyLong_Check(o)) { return false; } int overflow = 0; PyLong_AsLongAndOverflow(o, &overflow); return (overflow == 0); #endif } PyObject* PyNumberToPyInt(PyObject* o) { #if PY_MAJOR_VERSION < 3 return PyNumber_Int(o); #else return PyNumber_Long(o); #endif } } // namespace numpy bool GetIntAttr(PyObject* o, const char* field, int64* result) { PyObject* fo = PyObject_GetAttrString(o, field); if (!fo) { return false; } const int64 value = numpy::PyIntOrPyLongToLong(fo); if (value == -1 && PyErr_Occurred()) { Py_DECREF(fo); return false; } Py_DECREF(fo); *result = value; return true; } // Returns "ok"; true if there is no error, false if there was an error. bool HandleStringAttribute(PyObject* o, const char* attr_name, std::function f) { if (!PyObject_HasAttrString(o, attr_name)) { return true; // It's ok for the object to not have the attribute. } PyObject* attr = PyObject_GetAttrString(o, attr_name); if (attr == nullptr) { return false; // An error occurred getting the attribute. } if (attr == Py_None) { Py_DECREF(attr); return true; // The attribute is None, which we consider ok. } #if PY_MAJOR_VERSION < 3 if (!PyString_Check(attr)) { string message = absl::StrFormat("%s must be a string or none; got %s", attr_name, numpy::PyObjectCppRepr(attr)); PyErr_SetString(PyExc_TypeError, message.c_str()); Py_DECREF(attr); return false; // Type error, not ok. 
} f(PyString_AsString(attr)); #else if (!PyBytes_Check(attr)) { string message = absl::StrFormat("%s must be a string or none; got %s", attr_name, numpy::PyObjectCppRepr(attr)); PyErr_SetString(PyExc_TypeError, message.c_str()); Py_DECREF(attr); return false; // Type error, not ok. } f(PyBytes_AsString(attr)); #endif Py_DECREF(attr); return true; // Handled string attribute, ok! } // Returns "ok"; true if there is no error, false if there was an error. bool HandleBoolAttribute(PyObject* o, const char* attr_name, std::function f) { if (!PyObject_HasAttrString(o, attr_name)) { return true; // It's ok for the object to not have the attribute. } PyObject* attr = PyObject_GetAttrString(o, attr_name); if (attr == nullptr) { return false; // An error occurred getting the attribute. } if (attr == Py_None) { Py_DECREF(attr); return true; // The attribute is None, which we consider ok. } if (!PyBool_Check(attr)) { string message = absl::StrFormat("%s must be a boolean or none; got %s", attr_name, numpy::PyObjectCppRepr(attr)); PyErr_SetString(PyExc_TypeError, message.c_str()); Py_DECREF(attr); return false; // Type error, not ok. } f(PyObject_IsTrue(attr)); Py_DECREF(attr); return true; // Handled boolean attribute, ok! } bool HandleRepeatedInt64Attribute( PyObject* o, const char* attr_name, tensorflow::protobuf::RepeatedField* field) { PyObject* seq = PyObject_GetAttrString(o, attr_name); if (!seq) { return false; } int length = PySequence_Size(seq); if (length == -1) { Py_DECREF(seq); return false; } for (int i = 0; i < length; ++i) { PyObject* item = PySequence_GetItem(seq, i); if (!item) { Py_DECREF(seq); return false; } const int64 dimension = numpy::PyIntOrPyLongToLong(item); if (dimension == -1 && PyErr_Occurred()) { Py_DECREF(item); Py_DECREF(seq); return false; } *field->Add() = dimension; Py_DECREF(item); } Py_DECREF(seq); return true; } } // namespace swig } // namespace xla