1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Utilities for using the TensorFlow C API."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22from tensorflow.core.framework import api_def_pb2
23from tensorflow.core.framework import op_def_pb2
24from tensorflow.python import pywrap_tensorflow as c_api
25from tensorflow.python.util import compat
26from tensorflow.python.util import tf_contextlib
27
28
class ScopedTFStatus(object):
  """Owns a TF_Status handle and deletes it when the wrapper is collected."""

  def __init__(self):
    # Allocate the underlying TF_Status; exposed as the `status` attribute.
    self.status = c_api.TF_NewStatus()

  def __del__(self):
    # At interpreter shutdown the c_api module (or its functions) may already
    # have been torn down, in which case deletion is skipped.
    if c_api is None or c_api.TF_DeleteStatus is None:
      return
    c_api.TF_DeleteStatus(self.status)
40
41
class ScopedTFGraph(object):
  """Owns a TF_Graph handle and deletes it when the wrapper is collected."""

  def __init__(self):
    # Allocate the underlying TF_Graph; exposed as the `graph` attribute.
    self.graph = c_api.TF_NewGraph()

  def __del__(self):
    # At interpreter shutdown the c_api module (or its functions) may already
    # have been torn down, in which case deletion is skipped.
    if c_api is None or c_api.TF_DeleteGraph is None:
      return
    c_api.TF_DeleteGraph(self.graph)
53
54
class ScopedTFImportGraphDefOptions(object):
  """Owns a TF_ImportGraphDefOptions handle and deletes it on collection."""

  def __init__(self):
    # Allocate the underlying options object; exposed as `options`.
    self.options = c_api.TF_NewImportGraphDefOptions()

  def __del__(self):
    # At interpreter shutdown the c_api module (or its functions) may already
    # have been torn down, in which case deletion is skipped.
    if c_api is None or c_api.TF_DeleteImportGraphDefOptions is None:
      return
    c_api.TF_DeleteImportGraphDefOptions(self.options)
66
67
class ScopedTFImportGraphDefResults(object):
  """Wrapper around TF_ImportGraphDefResults that handles deletion.

  Unlike the other scoped wrappers, this class does not allocate the handle
  itself: it takes ownership of an existing results handle and deletes it
  with `TF_DeleteImportGraphDefResults` when the wrapper is collected.
  """

  def __init__(self, results):
    """Takes ownership of `results`.

    Args:
      results: A TF_ImportGraphDefResults handle to be deleted by this
        wrapper; exposed as the `results` attribute.
    """
    self.results = results

  def __del__(self):
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules.
    if c_api is not None and c_api.TF_DeleteImportGraphDefResults is not None:
      c_api.TF_DeleteImportGraphDefResults(self.results)
79
80
class ScopedTFFunction(object):
  """Takes ownership of a TF_Function handle and deletes it on collection."""

  def __init__(self, func):
    # Store the caller-supplied handle; exposed as the `func` attribute.
    self.func = func

  def __del__(self):
    # At interpreter shutdown the c_api module (or its functions) may already
    # have been torn down, in which case deletion is skipped.
    if c_api is None or c_api.TF_DeleteFunction is None:
      return
    c_api.TF_DeleteFunction(self.func)
92
93
class ApiDefMap(object):
  """Wrapper around TF_ApiDefMap that handles querying and deletion.

  The OpDef protos are also stored in this class so that they can be
  queried by op name.
  """

  def __init__(self):
    op_def_proto = op_def_pb2.OpList()
    buf = c_api.TF_GetAllOpList()
    try:
      op_def_proto.ParseFromString(c_api.TF_GetBuffer(buf))
      self._api_def_map = c_api.TF_NewApiDefMap(buf)
    finally:
      # The serialized op list is only needed during construction.
      c_api.TF_DeleteBuffer(buf)

    # Index the OpDefs by name; a later duplicate name overwrites an earlier
    # one, as with the previous explicit loop.
    self._op_per_name = {op.name: op for op in op_def_proto.op}

  def __del__(self):
    # If __init__ raised before the C map was created (e.g. the op list
    # failed to parse), there is nothing to free. Bail out instead of raising
    # AttributeError from the finalizer.
    api_def_map = getattr(self, "_api_def_map", None)
    if api_def_map is None:
      return
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules.
    if c_api is not None and c_api.TF_DeleteApiDefMap is not None:
      c_api.TF_DeleteApiDefMap(api_def_map)

  def put_api_def(self, text):
    """Adds ApiDef text to the underlying TF_ApiDefMap.

    Args:
      text: ApiDef definitions passed to `TF_ApiDefMapPut`.
    """
    # NOTE(review): len(text) is a character count for a unicode str; for
    # non-ASCII input this may differ from the encoded byte length -- confirm
    # how the wrapper converts `text`.
    c_api.TF_ApiDefMapPut(self._api_def_map, text, len(text))

  def get_api_def(self, op_name):
    """Returns the ApiDef proto for `op_name`.

    Args:
      op_name: Name of the op to look up in the TF_ApiDefMap.

    Returns:
      An `api_def_pb2.ApiDef` parsed from the buffer returned by
      `TF_ApiDefMapGet`.
    """
    api_def_proto = api_def_pb2.ApiDef()
    buf = c_api.TF_ApiDefMapGet(self._api_def_map, op_name, len(op_name))
    try:
      api_def_proto.ParseFromString(c_api.TF_GetBuffer(buf))
    finally:
      c_api.TF_DeleteBuffer(buf)
    return api_def_proto

  def get_op_def(self, op_name):
    """Returns the OpDef proto registered under `op_name`.

    Args:
      op_name: Name of the op.

    Returns:
      The OpDef stored for `op_name`.

    Raises:
      ValueError: If no op named `op_name` was loaded.
    """
    op_def = self._op_per_name.get(op_name)
    if op_def is None:
      raise ValueError("No entry found for " + op_name + ".")
    return op_def

  def op_names(self):
    """Returns a view of the names of all loaded ops."""
    return self._op_per_name.keys()
139
140
@tf_contextlib.contextmanager
def tf_buffer(data=None):
  """Context manager yielding a TF_Buffer that is deleted on exit.

  Example usage:
    with tf_buffer() as buf:
      # get serialized graph def into buf
      ...
      proto_data = c_api.TF_GetBuffer(buf)
      graph_def.ParseFromString(compat.as_bytes(proto_data))
    # buf has been deleted

    with tf_buffer(some_string) as buf:
      c_api.TF_SomeFunction(buf)
    # buf has been deleted

  Args:
    data: An optional `bytes`, `str`, or `unicode` object. When truthy it is
      converted with `compat.as_bytes` and used to populate the buffer via
      `TF_NewBufferFromString`; otherwise `TF_NewBuffer()` is used.

  Yields:
    The newly created TF_Buffer.
  """
  if not data:
    buf = c_api.TF_NewBuffer()
  else:
    buf = c_api.TF_NewBufferFromString(compat.as_bytes(data))
  try:
    yield buf
  finally:
    # Release the buffer even if the body of the `with` statement raised.
    c_api.TF_DeleteBuffer(buf)
172
173
def tf_output(c_op, index):
  """Builds a wrapped TF_Output pointing at output `index` of `c_op`.

  Args:
    c_op: wrapped TF_Operation
    index: integer

  Returns:
    Wrapped TF_Output
  """
  output = c_api.TF_Output()
  # Assigned left to right: `oper` first, then `index`.
  output.oper, output.index = c_op, index
  return output
188
189
def tf_operations(graph):
  """Iterates over every TF_Operation in `graph`.

  Args:
    graph: Graph whose `_c_graph` handle is walked with
      `TF_GraphNextOperation`.

  Yields:
    wrapped TF_Operation
  """
  # pylint: disable=protected-access
  cursor = 0
  while True:
    # Each call returns the next op and an updated cursor; a None op marks
    # the end of the iteration.
    c_op, cursor = c_api.TF_GraphNextOperation(graph._c_graph, cursor)
    if c_op is None:
      break
    yield c_op
  # pylint: enable=protected-access
206
207
def new_tf_operations(graph):
  """Iterates over TF_Operations in `graph` that lack a Python Operation.

  An op counts as "new" when `graph._get_operation_by_tf_operation` raises
  KeyError for it. This is useful for processing nodes added by the C API.

  Args:
    graph: Graph

  Yields:
    wrapped TF_Operation
  """

  def _has_python_op(c_op):
    # True when the Graph already maps `c_op` to an Operation.
    try:
      graph._get_operation_by_tf_operation(c_op)  # pylint: disable=protected-access
    except KeyError:
      return False
    return True

  # TODO(b/69679162): do this more efficiently
  for c_op in tf_operations(graph):
    if not _has_python_op(c_op):
      yield c_op
226