1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Utilities for using the TensorFlow C API."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22from tensorflow.python import pywrap_tensorflow as c_api
23from tensorflow.python.util import compat
24from tensorflow.python.util import tf_contextlib
25
26
class ScopedTFStatus(object):
  """Wrapper around TF_Status that handles deletion.

  Attributes:
    status: The handle returned by `c_api.TF_NewStatus()`. It is handed to
      `c_api.TF_DeleteStatus()` when this wrapper is destroyed.
  """

  def __init__(self):
    """Creates a new TF_Status and stores it in `self.status`."""
    self.status = c_api.TF_NewStatus()

  def __del__(self):
    """Deletes the wrapped TF_Status if the C API is still reachable."""
    # Note: when we're destructing the global context (i.e., when the process
    # is terminating) we can have already deleted other modules. The `c_api`
    # module global itself may have been reset to None by then, so check it
    # before looking up the delete function on it.
    if c_api is not None and c_api.TF_DeleteStatus is not None:
      c_api.TF_DeleteStatus(self.status)
38
39
class ScopedTFGraph(object):
  """Wrapper around TF_Graph that handles deletion.

  Attributes:
    graph: The handle returned by `c_api.TF_NewGraph()`. It is handed to
      `c_api.TF_DeleteGraph()` when this wrapper is destroyed.
  """

  def __init__(self):
    """Creates a new TF_Graph and stores it in `self.graph`."""
    self.graph = c_api.TF_NewGraph()

  def __del__(self):
    """Deletes the wrapped TF_Graph if the C API is still reachable."""
    # Note: when we're destructing the global context (i.e., when the process
    # is terminating) we can have already deleted other modules. The `c_api`
    # module global itself may have been reset to None by then, so check it
    # before looking up the delete function on it.
    if c_api is not None and c_api.TF_DeleteGraph is not None:
      c_api.TF_DeleteGraph(self.graph)
51
52
class ScopedTFImportGraphDefOptions(object):
  """Wrapper around TF_ImportGraphDefOptions that handles deletion.

  Attributes:
    options: The handle returned by `c_api.TF_NewImportGraphDefOptions()`. It
      is handed to `c_api.TF_DeleteImportGraphDefOptions()` when this wrapper
      is destroyed.
  """

  def __init__(self):
    """Creates new TF_ImportGraphDefOptions and stores it in `self.options`."""
    self.options = c_api.TF_NewImportGraphDefOptions()

  def __del__(self):
    """Deletes the wrapped options if the C API is still reachable."""
    # Note: when we're destructing the global context (i.e., when the process
    # is terminating) we can have already deleted other modules. The `c_api`
    # module global itself may have been reset to None by then, so check it
    # before looking up the delete function on it.
    if (c_api is not None and
        c_api.TF_DeleteImportGraphDefOptions is not None):
      c_api.TF_DeleteImportGraphDefOptions(self.options)
64
65
@tf_contextlib.contextmanager
def tf_buffer(data=None):
  """Context manager yielding a TF_Buffer that is deleted on exit.

  Example usage:
    with tf_buffer() as buf:
      # get serialized graph def into buf
      ...
      proto_data = c_api.TF_GetBuffer(buf)
      graph_def.ParseFromString(compat.as_bytes(proto_data))
    # buf has been deleted

    with tf_buffer(some_string) as buf:
      c_api.TF_SomeFunction(buf)
    # buf has been deleted

  Args:
    data: An optional `bytes`, `str`, or `unicode` object. When truthy, it is
      converted with `compat.as_bytes` and used to initialize the buffer.
      When None (or otherwise falsy, e.g. an empty string), a buffer is
      created with `c_api.TF_NewBuffer()` instead.

  Yields:
    Created TF_Buffer
  """
  if not data:
    handle = c_api.TF_NewBuffer()
  else:
    payload = compat.as_bytes(data)
    handle = c_api.TF_NewBufferFromString(payload)
  try:
    yield handle
  finally:
    # Runs whether the `with` body finished normally or raised.
    c_api.TF_DeleteBuffer(handle)
97
98
def tf_output(c_op, index):
  """Builds a wrapped TF_Output referring to one output of an operation.

  A fresh `c_api.TF_Output` is created and its `oper` and `index` fields are
  set from the arguments.

  Args:
    c_op: wrapped TF_Operation
    index: integer

  Returns:
    Wrapped TF_Output
  """
  output = c_api.TF_Output()
  output.oper, output.index = c_op, index
  return output
113
114
def tf_operations(graph):
  """Generator over all TF_Operations in `graph`.

  Repeatedly calls `c_api.TF_GraphNextOperation` on `graph._c_graph`, feeding
  the returned position back in, until it returns None for the operation.

  Args:
    graph: Graph

  Yields:
    wrapped TF_Operation
  """
  # pylint: disable=protected-access
  cursor = 0
  while True:
    next_op, cursor = c_api.TF_GraphNextOperation(graph._c_graph, cursor)
    if next_op is None:
      # The C API signals the end of iteration with a None operation.
      break
    yield next_op
  # pylint: enable=protected-access
131
132
def new_tf_operations(graph):
  """Generator over TF_Operations in `graph` that have no Python Operation.

  Every TF_Operation produced by `tf_operations(graph)` is looked up with
  `graph._get_operation_by_tf_operation`; those for which the lookup raises
  KeyError are yielded. This is useful for processing nodes added by the
  C API.

  Args:
    graph: Graph

  Yields:
    wrapped TF_Operation
  """
  # TODO(b/69679162): do this more efficiently
  for candidate in tf_operations(graph):
    try:
      graph._get_operation_by_tf_operation(candidate)  # pylint: disable=protected-access
    except KeyError:
      # No Operation is registered for this TF_Operation yet.
      is_new = True
    else:
      is_new = False
    if is_new:
      yield candidate
151