1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_PYTHON_LIB_CORE_PYBIND11_STATUS_H_
17 #define TENSORFLOW_PYTHON_LIB_CORE_PYBIND11_STATUS_H_
18 
19 #include <Python.h>
20 
21 #include "pybind11/pybind11.h"
22 #include "tensorflow/c/tf_status.h"
23 #include "tensorflow/core/lib/core/status.h"
24 #include "tensorflow/core/protobuf/error_codes.pb.h"
25 #include "tensorflow/python/lib/core/py_exception_registry.h"
26 
27 namespace tensorflow {
28 
29 namespace internal {
30 
CodeToPyExc(const int code)31 inline PyObject* CodeToPyExc(const int code) {
32   switch (code) {
33     case error::Code::INVALID_ARGUMENT:
34       return PyExc_ValueError;
35     case error::Code::OUT_OF_RANGE:
36       return PyExc_IndexError;
37     case error::Code::UNIMPLEMENTED:
38       return PyExc_NotImplementedError;
39     default:
40       return PyExc_RuntimeError;
41   }
42 }
43 
StatusToPyExc(const Status & status)44 inline PyObject* StatusToPyExc(const Status& status) {
45   return CodeToPyExc(status.code());
46 }
47 
TFStatusToPyExc(const TF_Status * status)48 inline PyObject* TFStatusToPyExc(const TF_Status* status) {
49   return CodeToPyExc(TF_GetCode(status));
50 }
51 
52 }  // namespace internal
53 
MaybeRaiseFromStatus(const Status & status)54 inline void MaybeRaiseFromStatus(const Status& status) {
55   if (!status.ok()) {
56     PyErr_SetString(internal::StatusToPyExc(status),
57                     status.error_message().c_str());
58     throw pybind11::error_already_set();
59   }
60 }
61 
MaybeRaiseRegisteredFromStatus(const tensorflow::Status & status)62 inline void MaybeRaiseRegisteredFromStatus(const tensorflow::Status& status) {
63   if (!status.ok()) {
64     PyErr_SetObject(PyExceptionRegistry::Lookup(status.code()),
65                     pybind11::make_tuple(pybind11::none(), pybind11::none(),
66                                          status.error_message())
67                         .ptr());
68     throw pybind11::error_already_set();
69   }
70 }
71 
MaybeRaiseRegisteredFromStatusWithGIL(const tensorflow::Status & status)72 inline void MaybeRaiseRegisteredFromStatusWithGIL(
73     const tensorflow::Status& status) {
74   if (!status.ok()) {
75     // Acquire GIL for throwing exception.
76     pybind11::gil_scoped_acquire acquire;
77 
78     PyErr_SetObject(PyExceptionRegistry::Lookup(status.code()),
79                     pybind11::make_tuple(pybind11::none(), pybind11::none(),
80                                          status.error_message())
81                         .ptr());
82     throw pybind11::error_already_set();
83   }
84 }
85 
MaybeRaiseFromTFStatus(TF_Status * status)86 inline void MaybeRaiseFromTFStatus(TF_Status* status) {
87   TF_Code code = TF_GetCode(status);
88   if (code != TF_OK) {
89     PyErr_SetString(internal::TFStatusToPyExc(status), TF_Message(status));
90     throw pybind11::error_already_set();
91   }
92 }
93 
MaybeRaiseRegisteredFromTFStatus(TF_Status * status)94 inline void MaybeRaiseRegisteredFromTFStatus(TF_Status* status) {
95   TF_Code code = TF_GetCode(status);
96   if (code != TF_OK) {
97     PyErr_SetObject(PyExceptionRegistry::Lookup(code),
98                     pybind11::make_tuple(pybind11::none(), pybind11::none(),
99                                          TF_Message(status))
100                         .ptr());
101     throw pybind11::error_already_set();
102   }
103 }
104 
MaybeRaiseRegisteredFromTFStatusWithGIL(TF_Status * status)105 inline void MaybeRaiseRegisteredFromTFStatusWithGIL(TF_Status* status) {
106   TF_Code code = TF_GetCode(status);
107   if (code != TF_OK) {
108     // Acquire GIL for throwing exception.
109     pybind11::gil_scoped_acquire acquire;
110 
111     PyErr_SetObject(PyExceptionRegistry::Lookup(code),
112                     pybind11::make_tuple(pybind11::none(), pybind11::none(),
113                                          TF_Message(status))
114                         .ptr());
115     throw pybind11::error_already_set();
116   }
117 }
118 
119 }  // namespace tensorflow
120 
121 namespace pybind11 {
122 namespace detail {
123 
124 // Raise an exception if a given status is not OK, otherwise return None.
125 //
126 // The correspondence between status codes and exception classes is given
127 // by PyExceptionRegistry. Note that the registry should be initialized
128 // in order to be used, see PyExceptionRegistry::Init.
129 template <>
130 struct type_caster<tensorflow::Status> {
131  public:
132   PYBIND11_TYPE_CASTER(tensorflow::Status, _("Status"));
133   static handle cast(tensorflow::Status status, return_value_policy, handle) {
134     tensorflow::MaybeRaiseFromStatus(status);
135     return none().inc_ref();
136   }
137 };
138 
139 }  // namespace detail
140 }  // namespace pybind11
141 
142 #endif  // TENSORFLOW_PYTHON_LIB_CORE_PYBIND11_STATUS_H_
143