1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""Testing.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22 23# pylint: disable=g-bad-import-order 24from tensorflow.python.framework import test_util as _test_util 25from tensorflow.python.platform import googletest as _googletest 26 27# pylint: disable=unused-import 28from tensorflow.python.framework.test_util import assert_equal_graph_def 29from tensorflow.python.framework.test_util import create_local_cluster 30from tensorflow.python.framework.test_util import TensorFlowTestCase as TestCase 31from tensorflow.python.framework.test_util import gpu_device_name 32from tensorflow.python.framework.test_util import is_gpu_available 33 34from tensorflow.python.ops.gradient_checker import compute_gradient_error 35from tensorflow.python.ops.gradient_checker import compute_gradient 36# pylint: enable=unused-import,g-bad-import-order 37 38import sys 39from tensorflow.python.util.tf_export import tf_export 40if sys.version_info.major == 2: 41 import mock # pylint: disable=g-import-not-at-top,unused-import 42else: 43 from unittest import mock # pylint: disable=g-import-not-at-top,g-importing-member 44 45tf_export(v1=['test.mock'])(mock) 46 47# Import Benchmark class 48Benchmark = _googletest.Benchmark # pylint: disable=invalid-name 49 50# Import StubOutForTesting class 51StubOutForTesting = _googletest.StubOutForTesting # pylint: disable=invalid-name 52 53 54@tf_export('test.main') 55def main(argv=None): 56 """Runs all unit tests.""" 57 _test_util.InstallStackTraceHandler() 58 return _googletest.main(argv) 59 60 61@tf_export(v1=['test.get_temp_dir']) 62def get_temp_dir(): 63 """Returns a temporary directory for use during tests. 64 65 There is no need to delete the directory after the test. 66 67 Returns: 68 The temporary directory. 69 """ 70 return _googletest.GetTempDir() 71 72 73@tf_export(v1=['test.test_src_dir_path']) 74def test_src_dir_path(relative_path): 75 """Creates an absolute test srcdir path given a relative path. 76 77 Args: 78 relative_path: a path relative to tensorflow root. 79 e.g. "core/platform". 80 81 Returns: 82 An absolute path to the linked in runfiles. 83 """ 84 return _googletest.test_src_dir_path(relative_path) 85 86 87@tf_export('test.is_built_with_cuda') 88def is_built_with_cuda(): 89 """Returns whether TensorFlow was built with CUDA (GPU) support.""" 90 return _test_util.IsGoogleCudaEnabled() 91 92 93@tf_export('test.is_built_with_rocm') 94def is_built_with_rocm(): 95 """Returns whether TensorFlow was built with ROCm (GPU) support.""" 96 return _test_util.IsBuiltWithROCm() 97 98 99@tf_export('test.is_built_with_gpu_support') 100def is_built_with_gpu_support(): 101 """Returns whether TensorFlow was built with GPU (i.e. CUDA or ROCm) support.""" 102 return is_built_with_cuda() or is_built_with_rocm() 103 104 105@tf_export('test.is_built_with_xla') 106def is_built_with_xla(): 107 """Returns whether TensorFlow was built with XLA support.""" 108 return _test_util.IsBuiltWithXLA() 109