1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Testing."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22
23# pylint: disable=g-bad-import-order
24from tensorflow.python.framework import test_util as _test_util
25from tensorflow.python.platform import googletest as _googletest
26
27# pylint: disable=unused-import
28from tensorflow.python.framework.test_util import assert_equal_graph_def
29from tensorflow.python.framework.test_util import create_local_cluster
30from tensorflow.python.framework.test_util import TensorFlowTestCase as TestCase
31from tensorflow.python.framework.test_util import gpu_device_name
32from tensorflow.python.framework.test_util import is_gpu_available
33
34from tensorflow.python.ops.gradient_checker import compute_gradient_error
35from tensorflow.python.ops.gradient_checker import compute_gradient
36# pylint: enable=unused-import,g-bad-import-order
37
38import sys
39from tensorflow.python.util.tf_export import tf_export
40if sys.version_info.major == 2:
41  import mock                # pylint: disable=g-import-not-at-top,unused-import
42else:
43  from unittest import mock  # pylint: disable=g-import-not-at-top,g-importing-member
44
# Publish `mock` (the standalone `mock` package on Python 2, `unittest.mock`
# otherwise) under the v1 API name `test.mock`.
tf_export(v1=['test.mock'])(mock)

# Module-level aliases that re-expose googletest helpers from this module.
# pylint: disable=invalid-name
Benchmark = _googletest.Benchmark
StubOutForTesting = _googletest.StubOutForTesting
# pylint: enable=invalid-name
52
53
@tf_export('test.main')
def main(argv=None):
  """Runs every unit test defined in the calling program.

  A stack-trace handler is installed first (through
  `_test_util.InstallStackTraceHandler`), after which control is handed to
  the googletest runner.

  Args:
    argv: Optional argument list, forwarded unchanged to `googletest.main`.

  Returns:
    Whatever `googletest.main` returns.
  """
  _test_util.InstallStackTraceHandler()
  outcome = _googletest.main(argv)
  return outcome
59
60
@tf_export(v1=['test.get_temp_dir'])
def get_temp_dir():
  """Provides a temporary directory that tests may write into.

  Callers do not have to remove the directory once the test finishes.

  Returns:
    The temporary directory, as reported by `googletest.GetTempDir`.
  """
  temp_dir = _googletest.GetTempDir()
  return temp_dir
71
72
@tf_export(v1=['test.test_src_dir_path'])
def test_src_dir_path(relative_path):
  """Turns a path relative to the TensorFlow root into an absolute one.

  Args:
    relative_path: Location relative to the tensorflow root, for example
      "core/platform".

  Returns:
    The absolute path of that location inside the linked-in runfiles.
  """
  absolute_path = _googletest.test_src_dir_path(relative_path)
  return absolute_path
85
86
@tf_export('test.is_built_with_cuda')
def is_built_with_cuda():
  """Reports whether this TensorFlow build includes CUDA (GPU) support.

  Returns:
    The value of `_test_util.IsGoogleCudaEnabled()`.
  """
  cuda_enabled = _test_util.IsGoogleCudaEnabled()
  return cuda_enabled
91
92
@tf_export('test.is_built_with_rocm')
def is_built_with_rocm():
  """Reports whether this TensorFlow build includes ROCm (GPU) support.

  Returns:
    The value of `_test_util.IsBuiltWithROCm()`.
  """
  rocm_enabled = _test_util.IsBuiltWithROCm()
  return rocm_enabled
97
98
@tf_export('test.is_built_with_gpu_support')
def is_built_with_gpu_support():
  """Reports whether this TensorFlow build has any GPU backend.

  A build counts as GPU-capable when it was compiled with either CUDA or
  ROCm. The ROCm check is skipped as soon as the CUDA check is truthy.

  Returns:
    The CUDA check result if it is truthy, otherwise the ROCm check result.
  """
  has_cuda = is_built_with_cuda()
  if has_cuda:
    return has_cuda
  return is_built_with_rocm()
103
104
@tf_export('test.is_built_with_xla')
def is_built_with_xla():
  """Reports whether this TensorFlow build includes XLA support.

  Returns:
    The value of `_test_util.IsBuiltWithXLA()`.
  """
  xla_enabled = _test_util.IsBuiltWithXLA()
  return xla_enabled
109