1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_PYTHON_LIB_CORE_PY_EXCEPTION_REGISTRY_H_ 16 #define TENSORFLOW_PYTHON_LIB_CORE_PY_EXCEPTION_REGISTRY_H_ 17 18 #include <Python.h> 19 20 #include <map> 21 22 #include "tensorflow/c/tf_status.h" 23 #include "tensorflow/core/lib/core/error_codes.pb.h" 24 25 namespace tensorflow { 26 27 // Global registry mapping C API error codes to the corresponding custom Python 28 // exception type. This is used to expose the exception types to C extension 29 // code (i.e. so we can raise custom exceptions via SWIG). 30 // 31 // Init() must be called exactly once at the beginning of the process before 32 // Lookup() can be used. 33 // 34 // Example usage: 35 // TF_Status* status = TF_NewStatus(); 36 // TF_Foo(..., status); 37 // 38 // if (TF_GetCode(status) != TF_OK) { 39 // PyObject* exc_type = PyExceptionRegistry::Lookup(TF_GetCode(status)); 40 // // Arguments to OpError base class. Set `node_def` and `op` to None. 41 // PyObject* args = 42 // Py_BuildValue("sss", nullptr, nullptr, TF_Message(status)); 43 // PyErr_SetObject(exc_type, args); 44 // Py_DECREF(args); 45 // TF_DeleteStatus(status); 46 // return NULL; 47 // } 48 class PyExceptionRegistry { 49 public: 50 // Initializes the process-wide registry. Should be called exactly once near 51 // the beginning of the process. The arguments are the various Python 52 // exception types (e.g. `cancelled_exc` corresponds to 53 // errors.CancelledError). 54 static void Init(PyObject* code_to_exc_type_map); 55 56 // Returns the Python exception type corresponding to `code`. Init() must be 57 // called before using this function. `code` should not be TF_OK. 58 static PyObject* Lookup(TF_Code code); 59 Lookup(error::Code code)60 static inline PyObject* Lookup(error::Code code) { 61 return Lookup(static_cast<TF_Code>(code)); 62 } 63 64 private: 65 static PyExceptionRegistry* singleton_; 66 PyExceptionRegistry() = default; 67 68 // Maps error codes to the corresponding Python exception type. 69 std::map<TF_Code, PyObject*> exc_types_; 70 }; 71 72 } // namespace tensorflow 73 74 #endif // TENSORFLOW_PYTHON_LIB_CORE_PY_EXCEPTION_REGISTRY_H_ 75