1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tensorflow.tools.common.public_api."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.platform import googletest
22from tensorflow.tools.common import public_api
23
24
class PublicApiTest(googletest.TestCase):
  """Unit tests for `public_api.PublicAPIVisitor`."""

  class TestVisitor(object):
    """Recording visitor that remembers what `PublicAPIVisitor` forwards.

    Attributes (updated on every call):
      symbols: set of every `path` this visitor has been called with.
      last_parent: the `parent` argument from the most recent call.
      last_children: a list copy of the `children` from the most recent call.
    """

    def __init__(self):
      self.symbols = set()
      self.last_parent = None
      self.last_children = None

    def __call__(self, path, parent, children):
      self.symbols.add(path)
      self.last_parent = parent
      # Take a snapshot rather than keeping a reference: the tests below show
      # the caller's `children` list being modified after this call returns.
      self.last_children = list(children)

  def _visit(self, children):
    """Runs a fresh TestVisitor wrapped in PublicAPIVisitor over `children`.

    PublicAPIVisitor is invoked with path 'test' and parent 'dummy'.

    Args:
      children: list of (name, value) pairs, passed through as-is.

    Returns:
      The TestVisitor instance, so the caller can inspect what it recorded.
    """
    recorder = self.TestVisitor()
    public_api.PublicAPIVisitor(recorder)('test', 'dummy', children)
    return recorder

  def test_call_forward(self):
    # Path, parent and public children are all handed to the wrapped visitor.
    recorder = self._visit([('name1', 'thing1'), ('name2', 'thing2')])
    self.assertEqual({'test'}, recorder.symbols)
    self.assertEqual('dummy', recorder.last_parent)
    self.assertEqual([('name1', 'thing1'), ('name2', 'thing2')],
                     recorder.last_children)

  def test_private_child_removal(self):
    children = [('name1', 'thing1'), ('_name2', 'thing2')]
    recorder = self._visit(children)
    # The underscore-prefixed entry must already be gone when the visitor
    # runs, and it is removed from the caller's own list as well.
    expected = [('name1', 'thing1')]
    self.assertEqual(expected, recorder.last_children)
    self.assertEqual(expected, children)

  def test_no_descent_child_removal(self):
    children = [('name1', 'thing1'), ('mock', 'thing2')]
    recorder = self._visit(children)
    # 'mock' is a not-to-be-descended-into symbol: the visitor still sees it,
    # but it is pruned from the caller's list only after the visitor returns.
    self.assertEqual([('name1', 'thing1'), ('mock', 'thing2')],
                     recorder.last_children)
    self.assertEqual([('name1', 'thing1')], children)
65
66
if __name__ == '__main__':
  # Script entry point: hand control to googletest's test runner.
  googletest.main()
69