1# Lint as: python2, python3 2# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16"""Visitor restricting traversal to only the public tensorflow API.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import re 23 24import six 25 26from tensorflow.python.util import tf_inspect 27 28 29class PublicAPIVisitor(object): 30 """Visitor to use with `traverse` to visit exactly the public TF API.""" 31 32 def __init__(self, visitor): 33 """Constructor. 34 35 `visitor` should be a callable suitable as a visitor for `traverse`. It will 36 be called only for members of the public TensorFlow API. 37 38 Args: 39 visitor: A visitor to call for the public API. 40 """ 41 self._visitor = visitor 42 self._root_name = 'tf' 43 44 # Modules/classes we want to suppress entirely. 45 self._private_map = { 46 'tf': [ 47 'compiler', 48 'core', 49 'python', 50 ], 51 # Some implementations have this internal module that we shouldn't 52 # expose. 53 'tf.flags': ['cpp_flags'], 54 } 55 56 # Modules/classes we do not want to descend into if we hit them. Usually, 57 # system modules exposed through platforms for compatibility reasons. 58 # Each entry maps a module path to a name to ignore in traversal. 59 self._do_not_descend_map = { 60 'tf': [ 61 'examples', 62 'flags', # Don't add flags 63 # TODO(drpng): This can be removed once sealed off. 64 'platform', 65 # TODO(drpng): This can be removed once sealed. 66 'pywrap_tensorflow', 67 # TODO(drpng): This can be removed once sealed. 68 'user_ops', 69 'tools', 70 'tensorboard', 71 ], 72 73 ## Everything below here is legitimate. 74 # It'll stay, but it's not officially part of the API. 75 'tf.app': ['flags'], 76 # Imported for compatibility between py2/3. 77 'tf.test': ['mock'], 78 } 79 80 @property 81 def private_map(self): 82 """A map from parents to symbols that should not be included at all. 83 84 This map can be edited, but it should not be edited once traversal has 85 begun. 86 87 Returns: 88 The map marking symbols to not include. 89 """ 90 return self._private_map 91 92 @property 93 def do_not_descend_map(self): 94 """A map from parents to symbols that should not be descended into. 95 96 This map can be edited, but it should not be edited once traversal has 97 begun. 98 99 Returns: 100 The map marking symbols to not explore. 101 """ 102 return self._do_not_descend_map 103 104 def set_root_name(self, root_name): 105 """Override the default root name of 'tf'.""" 106 self._root_name = root_name 107 108 def _is_private(self, path, name, obj=None): 109 """Return whether a name is private.""" 110 # TODO(wicke): Find out what names to exclude. 111 del obj # Unused. 112 return ((path in self._private_map and name in self._private_map[path]) or 113 (six.ensure_str(name).startswith('_') and 114 not re.match('__.*__$', six.ensure_str(name)) or 115 name in ['__base__', '__class__', '__next_in_mro__'])) 116 117 def _do_not_descend(self, path, name): 118 """Safely queries if a specific fully qualified name should be excluded.""" 119 return (path in self._do_not_descend_map and 120 name in self._do_not_descend_map[path]) 121 122 def __call__(self, path, parent, children): 123 """Visitor interface, see `traverse` for details.""" 124 125 # Avoid long waits in cases of pretty unambiguous failure. 126 if tf_inspect.ismodule(parent) and len( 127 six.ensure_str(path).split('.')) > 10: 128 raise RuntimeError('Modules nested too deep:\n%s.%s\n\nThis is likely a ' 129 'problem with an accidental public import.' % 130 (self._root_name, path)) 131 132 # Includes self._root_name 133 full_path = '.'.join([self._root_name, path]) if path else self._root_name 134 135 # Remove things that are not visible. 136 for name, child in list(children): 137 if self._is_private(full_path, name, child): 138 children.remove((name, child)) 139 140 self._visitor(path, parent, children) 141 142 # Remove things that are visible, but which should not be descended into. 143 for name, child in list(children): 144 if self._do_not_descend(full_path, name): 145 children.remove((name, child)) 146