1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests that the system configuration methods work properly.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from tensorflow.python.eager import context 22from tensorflow.python.eager import def_function 23from tensorflow.python.framework import config 24from tensorflow.python.framework import constant_op 25from tensorflow.python.framework import dtypes 26from tensorflow.python.framework import errors 27from tensorflow.python.framework import ops 28from tensorflow.python.framework import test_util 29from tensorflow.python.ops import math_ops 30from tensorflow.python.platform import test 31 32 33def reset_eager(fn): 34 def wrapper(*args, **kwargs): 35 try: 36 return fn(*args, **kwargs) 37 finally: 38 del context._context 39 context._context = context.Context() 40 ops.enable_eager_execution() 41 42 return wrapper 43 44 45class ConfigTest(test.TestCase): 46 47 @test_util.run_gpu_only 48 @reset_eager 49 def testDevicePolicy(self): 50 self.assertEqual(context.DEVICE_PLACEMENT_SILENT, 51 context.context().device_policy) 52 53 # If no op has been executed we should be able to set the device policy as 54 # well as any init-time configs. 55 config.set_intra_op_parallelism_threads(1) 56 config.set_device_policy('silent') 57 config.set_intra_op_parallelism_threads(2) 58 59 # Excute a dummy op to ensure that the context has been initialized 60 constant_op.constant(1) 61 62 def copy_tensor(dtype=dtypes.int32): 63 cpu_tensor = constant_op.constant(1, dtype=dtype) 64 gpu_tensor = cpu_tensor.gpu() 65 self.assertAllEqual(cpu_tensor + gpu_tensor, 2.0) 66 67 config.set_device_policy('silent') 68 self.assertEqual(config.get_device_policy(), 'silent') 69 self.assertEqual(context.DEVICE_PLACEMENT_SILENT, 70 context.context().device_policy) 71 copy_tensor() 72 73 config.set_device_policy('silent_for_int32') 74 self.assertEqual(config.get_device_policy(), 'silent_for_int32') 75 self.assertEqual(context.DEVICE_PLACEMENT_SILENT_FOR_INT32, 76 context.context().device_policy) 77 with self.assertRaisesRegexp(errors.InvalidArgumentError, 78 'Tensors on conflicting devices'): 79 copy_tensor(dtypes.float32) 80 copy_tensor() 81 82 config.set_device_policy('warn') 83 self.assertEqual(config.get_device_policy(), 'warn') 84 self.assertEqual(context.DEVICE_PLACEMENT_WARN, 85 context.context().device_policy) 86 copy_tensor() 87 88 config.set_device_policy('explicit') 89 self.assertEqual(config.get_device_policy(), 'explicit') 90 self.assertEqual(context.DEVICE_PLACEMENT_EXPLICIT, 91 context.context().device_policy) 92 with self.assertRaisesRegexp(errors.InvalidArgumentError, 93 'Tensors on conflicting devices'): 94 copy_tensor() 95 96 config.set_device_policy(None) 97 self.assertEqual(config.get_device_policy(), 'silent') 98 99 @reset_eager 100 def testExecutionMode(self): 101 self.assertTrue(config.get_synchronous_execution()) 102 self.assertEqual(context.SYNC, context.context().execution_mode) 103 104 # If no op has been executed we should be able to set the execution mode as 105 # well as any init-time configs. 106 config.set_intra_op_parallelism_threads(1) 107 config.set_synchronous_execution(False) 108 config.set_intra_op_parallelism_threads(2) 109 110 config.set_synchronous_execution(True) 111 self.assertTrue(config.get_synchronous_execution()) 112 self.assertEqual(context.SYNC, context.context().execution_mode) 113 config.set_synchronous_execution(False) 114 self.assertFalse(config.get_synchronous_execution()) 115 self.assertEqual(context.ASYNC, context.context().execution_mode) 116 117 @reset_eager 118 def testGpuPerProcessMemoryFraction(self): 119 config.set_gpu_per_process_memory_fraction(0.5) 120 self.assertEqual( 121 config.get_gpu_per_process_memory_fraction(), 122 context.context().gpu_per_process_memory_fraction) 123 124 constant_op.constant(1) 125 with self.assertRaises(RuntimeError): 126 config.set_gpu_per_process_memory_fraction(0.5) 127 128 @reset_eager 129 def testGpuPerProcessMemoryGrowth(self): 130 self.assertFalse(config.get_gpu_per_process_memory_growth()) 131 132 config.set_gpu_per_process_memory_growth(True) 133 self.assertTrue(config.get_gpu_per_process_memory_growth()) 134 self.assertEqual( 135 config.get_gpu_per_process_memory_growth(), 136 context.context().gpu_per_process_memory_growth) 137 138 config.set_gpu_per_process_memory_growth(False) 139 self.assertFalse(config.get_gpu_per_process_memory_growth()) 140 self.assertEqual( 141 config.get_gpu_per_process_memory_growth(), 142 context.context().gpu_per_process_memory_growth) 143 144 constant_op.constant(1) 145 with self.assertRaises(RuntimeError): 146 config.set_gpu_per_process_memory_growth(True) 147 148 @reset_eager 149 def testIntraOpParallelismThreads(self): 150 config.set_intra_op_parallelism_threads(10) 151 self.assertEqual( 152 config.get_intra_op_parallelism_threads(), 153 context.context().intra_op_parallelism_threads) 154 155 constant_op.constant(1) 156 with self.assertRaises(RuntimeError): 157 config.set_intra_op_parallelism_threads(1) 158 159 @reset_eager 160 def testInterOpParallelismThreads(self): 161 config.set_inter_op_parallelism_threads(10) 162 self.assertEqual( 163 config.get_inter_op_parallelism_threads(), 164 context.context().inter_op_parallelism_threads) 165 166 constant_op.constant(1) 167 with self.assertRaises(RuntimeError): 168 config.set_inter_op_parallelism_threads(1) 169 170 @test_util.run_gpu_only 171 @reset_eager 172 def testSoftPlacement(self): 173 self.assertEqual(config.get_soft_device_placement(), True) 174 175 @def_function.function 176 def mod(): 177 with ops.device('/device:GPU:0'): 178 a = constant_op.constant(1.0) 179 b = constant_op.constant(1.0) 180 return math_ops.mod(a, b) 181 182 # Since soft placement is enabled, the mod operation should work with CPU 183 mod() 184 185 config.set_soft_device_placement(False) 186 self.assertEqual(config.get_soft_device_placement(), False) 187 self.assertEqual( 188 config.get_soft_device_placement(), 189 context.context().soft_device_placement) 190 191 # Since soft placement is disabled, the mod operation should fail on GPU 192 with self.assertRaises(errors.InvalidArgumentError): 193 mod() 194 195 config.set_soft_device_placement(True) 196 self.assertEqual(config.get_soft_device_placement(), True) 197 self.assertEqual( 198 config.get_soft_device_placement(), 199 context.context().soft_device_placement) 200 201 # Since soft placement is re-enabled, the mod operation should work with CPU 202 mod() 203 204 @reset_eager 205 def testLogDevicePlacement(self): 206 self.assertEqual(context.get_log_device_placement(), False) 207 208 context.set_log_device_placement(True) 209 self.assertEqual(context.get_log_device_placement(), True) 210 self.assertEqual( 211 context.get_log_device_placement(), 212 context.context().log_device_placement) 213 214 context.set_log_device_placement(False) 215 self.assertEqual(context.get_log_device_placement(), False) 216 self.assertEqual( 217 context.get_log_device_placement(), 218 context.context().log_device_placement) 219 220 constant_op.constant(1) 221 with self.assertRaises(RuntimeError): 222 context.set_log_device_placement(True) 223 with self.assertRaises(RuntimeError): 224 context.set_log_device_placement(False) 225 226 227if __name__ == '__main__': 228 ops.enable_eager_execution() 229 test.main() 230