# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utilities for using the TensorFlow C API."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python.util import compat
from tensorflow.python.util import tf_contextlib


class ScopedTFStatus(object):
  """Wrapper around TF_Status that handles deletion.

  Attributes:
    status: the wrapped TF_Status, created by `c_api.TF_NewStatus()` and
      released by `c_api.TF_DeleteStatus()` when this wrapper is finalized.
  """

  def __init__(self):
    self.status = c_api.TF_NewStatus()

  def __del__(self):
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules. At that point
    # the module-level name `c_api` itself may already be None, so check it
    # before dereferencing any of its attributes; otherwise this finalizer
    # would raise AttributeError during interpreter shutdown.
    if c_api is not None and c_api.TF_DeleteStatus is not None:
      c_api.TF_DeleteStatus(self.status)


class ScopedTFGraph(object):
  """Wrapper around TF_Graph that handles deletion.

  Attributes:
    graph: the wrapped TF_Graph, created by `c_api.TF_NewGraph()` and
      released by `c_api.TF_DeleteGraph()` when this wrapper is finalized.
  """

  def __init__(self):
    self.graph = c_api.TF_NewGraph()

  def __del__(self):
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules. `c_api` itself
    # may already be None then, so guard it before touching its attributes.
    if c_api is not None and c_api.TF_DeleteGraph is not None:
      c_api.TF_DeleteGraph(self.graph)


class ScopedTFImportGraphDefOptions(object):
  """Wrapper around TF_ImportGraphDefOptions that handles deletion.

  Attributes:
    options: the wrapped TF_ImportGraphDefOptions, created by
      `c_api.TF_NewImportGraphDefOptions()` and released by
      `c_api.TF_DeleteImportGraphDefOptions()` when this wrapper is finalized.
  """

  def __init__(self):
    self.options = c_api.TF_NewImportGraphDefOptions()

  def __del__(self):
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules. `c_api` itself
    # may already be None then, so guard it before touching its attributes.
    if (c_api is not None and
        c_api.TF_DeleteImportGraphDefOptions is not None):
      c_api.TF_DeleteImportGraphDefOptions(self.options)


@tf_contextlib.contextmanager
def tf_buffer(data=None):
  """Context manager that creates and deletes TF_Buffer.

  Example usage:
    with tf_buffer() as buf:
      # get serialized graph def into buf
      ...
      proto_data = c_api.TF_GetBuffer(buf)
      graph_def.ParseFromString(compat.as_bytes(proto_data))
    # buf has been deleted

    with tf_buffer(some_string) as buf:
      c_api.TF_SomeFunction(buf)
    # buf has been deleted

  Args:
    data: An optional `bytes`, `str`, or `unicode` object. If it is non-empty,
      the yielded buffer will contain this data (converted with
      `compat.as_bytes`); if it is None or empty, a fresh buffer from
      `c_api.TF_NewBuffer()` is yielded instead.

  Yields:
    Created TF_Buffer. It is deleted when the `with` block exits, including
    when the block raises.
  """
  if data:
    buf = c_api.TF_NewBufferFromString(compat.as_bytes(data))
  else:
    buf = c_api.TF_NewBuffer()
  try:
    yield buf
  finally:
    # Always release the C-side buffer, even if the with-body raised.
    c_api.TF_DeleteBuffer(buf)


def tf_output(c_op, index):
  """Returns a wrapped TF_Output with specified operation and index.

  Args:
    c_op: wrapped TF_Operation
    index: integer

  Returns:
    Wrapped TF_Output whose `oper` is `c_op` and whose `index` is `index`.
  """
  ret = c_api.TF_Output()
  ret.oper = c_op
  ret.index = index
  return ret


def tf_operations(graph):
  """Generator that yields every TF_Operation in `graph`.

  Args:
    graph: Graph. Only its `_c_graph` attribute is read.

  Yields:
    wrapped TF_Operation
  """
  # pylint: disable=protected-access
  # TF_GraphNextOperation returns the next operation together with the
  # updated cursor position; a None operation marks the end of the graph.
  pos = 0
  c_op, pos = c_api.TF_GraphNextOperation(graph._c_graph, pos)
  while c_op is not None:
    yield c_op
    c_op, pos = c_api.TF_GraphNextOperation(graph._c_graph, pos)
  # pylint: enable=protected-access


def new_tf_operations(graph):
  """Generator that yields newly-added TF_Operations in `graph`.

  Specifically, yields TF_Operations that don't have associated Operations in
  `graph`, i.e. those for which `graph._get_operation_by_tf_operation` raises
  KeyError. This is useful for processing nodes added by the C API.

  Args:
    graph: Graph

  Yields:
    wrapped TF_Operation
  """
  # TODO(b/69679162): do this more efficiently
  # Every operation in the graph is visited and looked up once.
  for c_op in tf_operations(graph):
    try:
      graph._get_operation_by_tf_operation(c_op)  # pylint: disable=protected-access
    except KeyError:
      yield c_op