1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for backend_config."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.keras import backend
21from tensorflow.python.keras import backend_config
22from tensorflow.python.keras import combinations
23from tensorflow.python.platform import test
24
25
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class BackendConfigTest(test.TestCase):
  """Round-trip tests for the getters/setters in `backend_config`.

  The setters exercised here modify module-level configuration that other
  tests read. Each test therefore registers a cleanup *before* mutating
  anything, so that the value observed at test start is restored even if an
  assertion fails midway. Otherwise a failure would leak a non-default
  epsilon/floatx/data format into later tests.
  """

  def test_backend(self):
    """The Keras backend reports itself as 'tensorflow'."""
    self.assertEqual(backend.backend(), 'tensorflow')

  def test_epsilon(self):
    """`set_epsilon` changes the value returned by `epsilon`."""
    # Restore the pre-test epsilon regardless of how this test exits.
    self.addCleanup(backend_config.set_epsilon, backend_config.epsilon())
    epsilon = 1e-2
    backend_config.set_epsilon(epsilon)
    self.assertEqual(backend_config.epsilon(), epsilon)
    backend_config.set_epsilon(1e-7)
    self.assertEqual(backend_config.epsilon(), 1e-7)

  def test_floatx(self):
    """`set_floatx` changes the value returned by `floatx`."""
    # Restore the pre-test floatx regardless of how this test exits.
    self.addCleanup(backend_config.set_floatx, backend_config.floatx())
    floatx = 'float64'
    backend_config.set_floatx(floatx)
    self.assertEqual(backend_config.floatx(), floatx)
    backend_config.set_floatx('float32')
    self.assertEqual(backend_config.floatx(), 'float32')

  def test_image_data_format(self):
    """`set_image_data_format` changes the value of `image_data_format`."""
    # Restore the pre-test data format regardless of how this test exits.
    self.addCleanup(backend_config.set_image_data_format,
                    backend_config.image_data_format())
    image_data_format = 'channels_first'
    backend_config.set_image_data_format(image_data_format)
    self.assertEqual(backend_config.image_data_format(), image_data_format)
    backend_config.set_image_data_format('channels_last')
    self.assertEqual(backend_config.image_data_format(), 'channels_last')
52
53
# Run this module's tests through tensorflow.python.platform.test.main().
if __name__ == '__main__':
  test.main()
56