1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python module for Session ops, vars, and functions exported by pybind11."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21# pylint: disable=invalid-import-order,g-bad-import-order, wildcard-import, unused-import
22from tensorflow.python import pywrap_tensorflow
23from tensorflow.python._pywrap_tf_session import *
24from tensorflow.python._pywrap_tf_session import _TF_SetTarget
25from tensorflow.python._pywrap_tf_session import _TF_SetConfig
26from tensorflow.python._pywrap_tf_session import _TF_NewSessionOptions
27
# Convert versions to strings for Python2 and keep api_compatibility_test green.
# We can remove this hack once we remove Python2 presubmits. pybind11 can only
# return unicode for Python2 even with py::str.
# https://pybind11.readthedocs.io/en/stable/advanced/cast/strings.html#returning-c-strings-to-python
# pylint: disable=undefined-variable
# The get_* functions used below are not defined in this file; they are
# expected to come from the wildcard `_pywrap_tf_session` import above (hence
# the undefined-variable suppression).
__version__ = str(get_version())
__git_version__ = str(get_git_version())
__compiler_version__ = str(get_compiler_version())
# NOTE(review): these two are stored exactly as the binding returns them (no
# str() conversion); presumably bool/int flags -- confirm against the pybind11
# module.
__cxx11_abi_flag__ = get_cxx11_abi_flag()
__monolithic_build__ = get_monolithic_build()

# Use getters to hold these attributes rather than pybind11's m.attr, due to
# b/145559202. Each value is read once, at import time.
GRAPH_DEF_VERSION = get_graph_def_version()
GRAPH_DEF_VERSION_MIN_CONSUMER = get_graph_def_version_min_consumer()
GRAPH_DEF_VERSION_MIN_PRODUCER = get_graph_def_version_min_producer()
TENSOR_HANDLE_KEY = get_tensor_handle_key()

# pylint: enable=undefined-variable
47
48
49# Disable pylint invalid name warnings for legacy functions.
50# pylint: disable=invalid-name
def TF_NewSessionOptions(target=None, config=None):
  """Creates a native session-options handle, optionally setting target/config.

  The returned handle must be released by the caller with
  `TF_DeleteSessionOptions` (as `TF_Reset` does). If configuring the handle
  fails, it is released here before the exception propagates, so no handle is
  leaked on the error path.

  Args:
    target: Optional target forwarded to `_TF_SetTarget`. When None, no target
      is set on the handle.
    config: Optional message exposing `SerializeToString()` (presumably a
      `ConfigProto` -- confirm at call sites). When given, its serialized bytes
      are forwarded to `_TF_SetConfig`.

  Returns:
    The options handle produced by `_TF_NewSessionOptions()`.

  Raises:
    Whatever `config.SerializeToString()`, `_TF_SetTarget` or `_TF_SetConfig`
    raise; the partially-configured handle is deleted first.
  """
  # NOTE: target and config are validated in the session constructor.
  # Serialize before allocating the native handle so a bad `config` (e.g. an
  # object without SerializeToString) cannot leak a freshly created handle.
  config_str = None if config is None else config.SerializeToString()
  opts = _TF_NewSessionOptions()
  try:
    if target is not None:
      _TF_SetTarget(opts, target)
    if config_str is not None:
      _TF_SetConfig(opts, config_str)
  except BaseException:
    # The handle is manually managed; free it before re-raising.
    TF_DeleteSessionOptions(opts)  # pylint: disable=undefined-variable
    raise
  return opts
60
61
# Disable pylint undefined-variable because the names used below are exported
# from the shared object via pybind11.
64# pylint: disable=undefined-variable
def TF_Reset(target, containers=None, config=None):
  """Calls `TF_Reset_wrapper` with temporary session options for `target`.

  A session-options handle is built from `target` and `config`, passed to
  `TF_Reset_wrapper` together with `containers`, and always released with
  `TF_DeleteSessionOptions` afterwards, including when the wrapper raises.

  Args:
    target: Target forwarded to `TF_NewSessionOptions`.
    containers: Forwarded unchanged to `TF_Reset_wrapper`; may be None.
      (Presumably a list of container names -- confirm against the binding.)
    config: Optional config message forwarded to `TF_NewSessionOptions`.

  Returns:
    None.
  """
  session_options = TF_NewSessionOptions(target=target, config=config)
  try:
    TF_Reset_wrapper(session_options, containers)
  finally:
    # The options handle is only needed for this one call; free it on every
    # exit path.
    TF_DeleteSessionOptions(session_options)
71