1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for io_utils."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import sys
22
23import six
24
25from tensorflow.python.keras import keras_parameterized
26from tensorflow.python.keras.utils import io_utils
27from tensorflow.python.platform import test
28
29
class TestIOUtils(keras_parameterized.TestCase):
  """Unit tests for helpers in `io_utils`."""

  def test_ask_to_proceed_with_overwrite(self):
    """Checks y/n answers and re-prompting on unrecognized input."""
    # The path is only interpolated into the prompt; `input` is mocked, so
    # nothing on disk is touched.
    filepath = '/tmp/not_exists'
    # Patch the input function so each case can script the user's answers
    # without a real terminal.
    with test.mock.patch.object(six.moves, 'input') as mock_input:
      mock_input.return_value = 'y'
      self.assertTrue(io_utils.ask_to_proceed_with_overwrite(filepath))

      mock_input.return_value = 'n'
      self.assertFalse(io_utils.ask_to_proceed_with_overwrite(filepath))

      # An unrecognized answer ('m') must lead to a second prompt. The return
      # value alone cannot prove that for the 'y' case (a helper treating any
      # non-'n' answer as consent would also return True), so the number of
      # prompts is checked as well. `reset_mock` clears `call_count` only;
      # `side_effect` is assigned afresh right after it.
      mock_input.reset_mock()
      mock_input.side_effect = ['m', 'y']
      self.assertTrue(io_utils.ask_to_proceed_with_overwrite(filepath))
      self.assertEqual(mock_input.call_count, 2)

      mock_input.reset_mock()
      mock_input.side_effect = ['m', 'n']
      self.assertFalse(io_utils.ask_to_proceed_with_overwrite(filepath))
      self.assertEqual(mock_input.call_count, 2)

  def test_path_to_string(self):
    """Checks PathLike conversion and pass-through of other objects."""

    class PathLikeDummy(object):
      # Minimal user-defined object implementing the `__fspath__` protocol.

      def __fspath__(self):
        return 'dummypath'

    dummy = object()
    # `pathlib` is part of the standard library only from Python 3.4.
    if sys.version_info >= (3, 4):
      from pathlib import Path  # pylint:disable=g-import-not-at-top
      # conversion of PathLike
      self.assertEqual(io_utils.path_to_string(Path('path')), 'path')
    # The `__fspath__` protocol (PEP 519) was introduced in Python 3.6.
    if sys.version_info >= (3, 6):
      self.assertEqual(io_utils.path_to_string(PathLikeDummy()), 'dummypath')

    # pass-through, works for all versions of python: plain strings come back
    # unchanged and arbitrary objects are returned as the very same object.
    self.assertEqual(io_utils.path_to_string('path'), 'path')
    self.assertIs(io_utils.path_to_string(dummy), dummy)
66
67
if __name__ == '__main__':
  # Entry point when the module is executed directly: hand control to the
  # TensorFlow test runner imported above as `test`.
  test.main()
70