1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests that the system configuration methods work properly."""
16
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
31
32
def reset_eager(fn):
  """Decorator that installs a fresh eager context once `fn` has finished.

  Args:
    fn: The callable to wrap; in this file, the individual test methods.

  Returns:
    A wrapper that forwards all positional and keyword arguments to `fn` and
    returns its result. The wrapper keeps `fn`'s name and docstring so that
    decorated tests remain identifiable in reports and tracebacks.
  """

  @functools.wraps(fn)
  def wrapper(*args, **kwargs):
    try:
      return fn(*args, **kwargs)
    finally:
      # Runs whether `fn` returned or raised: drop the module-level context,
      # replace it with a newly constructed one, and turn eager execution back
      # on so the next test does not see configuration left over from `fn`.
      del context._context
      context._context = context.Context()
      ops.enable_eager_execution()

  return wrapper
43
44
class ConfigTest(test.TestCase):
  """Checks the `config` getters and setters against the eager context.

  Every test method is decorated with `reset_eager`, which swaps in a newly
  constructed context after the test body finishes.
  """

  def _expect_device_policy(self, policy, placement):
    """Switches to `policy` and checks that both accessors report it."""
    config.set_device_policy(policy)
    self.assertEqual(config.get_device_policy(), policy)
    self.assertEqual(placement, context.context().device_policy)

  def _expect_soft_placement(self, enabled):
    """Sets soft placement to `enabled` and checks both accessors agree."""
    config.set_soft_device_placement(enabled)
    self.assertEqual(config.get_soft_device_placement(), enabled)
    self.assertEqual(
        config.get_soft_device_placement(),
        context.context().soft_device_placement)

  def _expect_locked_after_init(self, setter, *values):
    """Runs one op, then checks `setter` raises RuntimeError for each value."""
    constant_op.constant(1)
    for value in values:
      with self.assertRaises(RuntimeError):
        setter(value)

  @test_util.run_gpu_only
  @reset_eager
  def testDevicePolicy(self):
    """Cycles through the device policies and adds a CPU and a GPU tensor."""
    self.assertEqual(context.DEVICE_PLACEMENT_SILENT,
                     context.context().device_policy)

    # Nothing has been executed yet, so the device policy can still be changed
    # in between init-time settings such as the intra-op thread count.
    config.set_intra_op_parallelism_threads(1)
    config.set_device_policy('silent')
    config.set_intra_op_parallelism_threads(2)

    # Execute a dummy op to ensure that the context has been initialized.
    constant_op.constant(1)

    def add_cpu_and_gpu(dtype=dtypes.int32):
      on_cpu = constant_op.constant(1, dtype=dtype)
      on_gpu = on_cpu.gpu()
      self.assertAllEqual(on_cpu + on_gpu, 2.0)

    self._expect_device_policy('silent', context.DEVICE_PLACEMENT_SILENT)
    add_cpu_and_gpu()

    self._expect_device_policy('silent_for_int32',
                               context.DEVICE_PLACEMENT_SILENT_FOR_INT32)
    with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                 'Tensors on conflicting devices'):
      add_cpu_and_gpu(dtypes.float32)
    add_cpu_and_gpu()

    self._expect_device_policy('warn', context.DEVICE_PLACEMENT_WARN)
    add_cpu_and_gpu()

    self._expect_device_policy('explicit', context.DEVICE_PLACEMENT_EXPLICIT)
    with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                 'Tensors on conflicting devices'):
      add_cpu_and_gpu()

    # After passing None the getter is expected to report 'silent' again.
    config.set_device_policy(None)
    self.assertEqual(config.get_device_policy(), 'silent')

  @reset_eager
  def testExecutionMode(self):
    """Toggles synchronous execution and checks the context's mode."""
    self.assertTrue(config.get_synchronous_execution())
    self.assertEqual(context.SYNC, context.context().execution_mode)

    # Nothing has been executed yet, so the execution mode can still be changed
    # in between init-time settings such as the intra-op thread count.
    config.set_intra_op_parallelism_threads(1)
    config.set_synchronous_execution(False)
    config.set_intra_op_parallelism_threads(2)

    for synchronous, mode in ((True, context.SYNC), (False, context.ASYNC)):
      config.set_synchronous_execution(synchronous)
      if synchronous:
        self.assertTrue(config.get_synchronous_execution())
      else:
        self.assertFalse(config.get_synchronous_execution())
      self.assertEqual(mode, context.context().execution_mode)

  @reset_eager
  def testGpuPerProcessMemoryFraction(self):
    """Sets the GPU memory fraction; expects the setter to raise after an op."""
    config.set_gpu_per_process_memory_fraction(0.5)
    self.assertEqual(
        config.get_gpu_per_process_memory_fraction(),
        context.context().gpu_per_process_memory_fraction)

    self._expect_locked_after_init(
        config.set_gpu_per_process_memory_fraction, 0.5)

  @reset_eager
  def testGpuPerProcessMemoryGrowth(self):
    """Toggles GPU memory growth; expects the setter to raise after an op."""
    self.assertFalse(config.get_gpu_per_process_memory_growth())

    for growth in (True, False):
      config.set_gpu_per_process_memory_growth(growth)
      if growth:
        self.assertTrue(config.get_gpu_per_process_memory_growth())
      else:
        self.assertFalse(config.get_gpu_per_process_memory_growth())
      self.assertEqual(
          config.get_gpu_per_process_memory_growth(),
          context.context().gpu_per_process_memory_growth)

    self._expect_locked_after_init(
        config.set_gpu_per_process_memory_growth, True)

  @reset_eager
  def testIntraOpParallelismThreads(self):
    """Sets the intra-op thread count; expects the setter to raise after an op."""
    config.set_intra_op_parallelism_threads(10)
    self.assertEqual(
        config.get_intra_op_parallelism_threads(),
        context.context().intra_op_parallelism_threads)

    self._expect_locked_after_init(config.set_intra_op_parallelism_threads, 1)

  @reset_eager
  def testInterOpParallelismThreads(self):
    """Sets the inter-op thread count; expects the setter to raise after an op."""
    config.set_inter_op_parallelism_threads(10)
    self.assertEqual(
        config.get_inter_op_parallelism_threads(),
        context.context().inter_op_parallelism_threads)

    self._expect_locked_after_init(config.set_inter_op_parallelism_threads, 1)

  @test_util.run_gpu_only
  @reset_eager
  def testSoftPlacement(self):
    """Runs a GPU-pinned mod with soft placement on, off, and on again."""
    self.assertEqual(config.get_soft_device_placement(), True)

    @def_function.function
    def mod():
      with ops.device('/device:GPU:0'):
        lhs = constant_op.constant(1.0)
        rhs = constant_op.constant(1.0)
        return math_ops.mod(lhs, rhs)

    # Since soft placement is enabled, the mod operation should work with CPU.
    mod()

    self._expect_soft_placement(False)

    # Since soft placement is disabled, the mod operation should fail on GPU.
    with self.assertRaises(errors.InvalidArgumentError):
      mod()

    self._expect_soft_placement(True)

    # Since soft placement is re-enabled, the mod operation should work with
    # CPU.
    mod()

  @reset_eager
  def testLogDevicePlacement(self):
    """Toggles device-placement logging; expects the setter to raise after an op."""
    self.assertEqual(context.get_log_device_placement(), False)

    for enabled in (True, False):
      context.set_log_device_placement(enabled)
      self.assertEqual(context.get_log_device_placement(), enabled)
      self.assertEqual(
          context.get_log_device_placement(),
          context.context().log_device_placement)

    self._expect_locked_after_init(
        context.set_log_device_placement, True, False)
225
226
if __name__ == '__main__':
  # Turn eager execution on before handing control to the test runner.
  ops.enable_eager_execution()
  test.main()
230