# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for mixed precision Policies."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras.mixed_precision.experimental import policy as mp_policy
from tensorflow.python.platform import test


@test_util.run_all_in_graph_and_eager_modes
class PolicyTest(test.TestCase):
  """Tests Policies."""

  def test_infer(self):
    """The 'infer' policy keeps its name and has no default variable dtype."""
    policy = mp_policy.Policy('infer')
    self.assertEqual(policy.name, 'infer')
    self.assertIsNone(policy.default_variable_dtype)

  def test_infer_float32_vars(self):
    """The 'infer_float32_vars' policy defaults variables to float32."""
    policy = mp_policy.Policy('infer_float32_vars')
    self.assertEqual(policy.name, 'infer_float32_vars')
    self.assertEqual(policy.default_variable_dtype, 'float32')

  def test_global_policy(self):
    """set_policy accepts names or Policy objects and is graph-independent.

    The original global policy is restored in `finally` so that a failing
    assertion does not leak a modified global policy into other tests.
    """
    # Capture the prior policy first so it can always be restored below.
    default_policy = mp_policy.global_policy()
    self.assertEqual(default_policy.name, 'infer')
    try:
      mp_policy.set_policy('infer_float32_vars')
      self.assertEqual(mp_policy.global_policy().name, 'infer_float32_vars')
      self.assertEqual(mp_policy.global_policy().default_variable_dtype,
                       'float32')
      with ops.Graph().as_default():  # Policies are not associated with a graph
        self.assertEqual(mp_policy.global_policy().name, 'infer_float32_vars')
      mp_policy.set_policy('infer')
      self.assertEqual(mp_policy.global_policy().name, 'infer')
      self.assertIsNone(mp_policy.global_policy().default_variable_dtype)
      # Passing a Policy instance installs that exact object.
      policy = mp_policy.Policy('infer_float32_vars')
      mp_policy.set_policy(policy)
      self.assertIs(mp_policy.global_policy(), policy)
    finally:
      mp_policy.set_policy(default_policy)

  def test_policy_scope(self):
    """policy_scope sets the global policy inside the scope; scopes nest."""
    with mp_policy.policy_scope('infer_float32_vars'):
      self.assertEqual(mp_policy.global_policy().name, 'infer_float32_vars')
      with mp_policy.policy_scope('infer'):
        self.assertEqual(mp_policy.global_policy().name, 'infer')
      # Leaving the inner scope returns to the outer scope's policy.
      self.assertEqual(mp_policy.global_policy().name, 'infer_float32_vars')
    self.assertEqual(mp_policy.global_policy().name, 'infer')


if __name__ == '__main__':
  test.main()