1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities related to layer/model functionality."""
16
17# TODO(b/110718070): Move these functions back to tensorflow/python/keras/utils
18# once __init__ files no longer require all of tf.keras to be imported together.
19
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24import collections
25import functools
26import weakref
27
28from tensorflow.python.util import object_identity
29
30try:
31  # typing module is only used for comment type annotations.
32  import typing  # pylint: disable=g-import-not-at-top, unused-import
33except ImportError:
34  pass
35
36
def is_layer(obj):
  """Duck-typed check for Layer-like instances.

  Returns True when `obj` carries an `_is_layer` attribute and is not itself a
  class; classes that carry the attribute are rejected.
  """
  # TODO(b/110718070): Replace with isinstance(obj, base_layer.Layer).
  if not hasattr(obj, "_is_layer"):
    return False
  return not isinstance(obj, type)
41
42
def has_weights(obj):
  """Duck-typed check for instances whose class exposes weight attributes.

  Both `trainable_weights` and `non_trainable_weights` must be defined on
  `type(obj)`; attributes set only on the instance do not count. Classes
  themselves are rejected.
  """
  # TODO(b/110718070): Replace with isinstance(obj, base_layer.Layer).
  cls = type(obj)
  if not hasattr(cls, "trainable_weights"):
    return False
  if not hasattr(cls, "non_trainable_weights"):
    return False
  return not isinstance(obj, type)
50
51
52def invalidate_recursive_cache(key):
53  """Convenience decorator to invalidate the cache when setting attributes."""
54  def outer(f):
55    @functools.wraps(f)
56    def wrapped(self, value):
57      sentinel = getattr(self, "_attribute_sentinel")  # type: AttributeSentinel
58      sentinel.invalidate(key)
59      return f(self, value)
60    return wrapped
61  return outer
62
63
class MutationSentinel(object):
  """Boolean flag recording whether one cached property is currently valid.

  The flag defaults to False (not cached) via a class attribute; the first
  call to `mark_as` stores the state on the instance.
  """
  _in_cached_state = False

  def mark_as(self, value):  # type: (MutationSentinel, bool) -> bool
    """Sets the cached state to `value`.

    Args:
      value: The new cached state.

    Returns:
      True if `value` differs from the previous state, i.e. the change may need
      to be propagated to dependents; False otherwise.
    """
    state_changed = (value != self._in_cached_state)
    self._in_cached_state = value
    return state_changed

  @property
  def in_cached_state(self):
    """The current cached state (read-only view)."""
    return self._in_cached_state
76
77
class AttributeSentinel(object):
  """Tracks per-attribute cache validity for a Layer-like object.

  Each attribute key maps to a `MutationSentinel` (created on first access).
  Invalidation can target a single key, for instance when that attribute is
  mutated, or every key at once, such as when a new dependency is added.
  Parents are held in a `weakref.WeakSet` and receive propagated
  invalidations.
  """

  def __init__(self, always_propagate=False):
    """Creates a sentinel with no parents and no tracked attributes.

    Args:
      always_propagate: If True, every `mark_cached`/`invalidate` call forwards
        an invalidation to the parents, even when this sentinel's own state
        for the key did not change.
    """
    self._parents = weakref.WeakSet()
    self.attributes = collections.defaultdict(MutationSentinel)

    # Trackable data structure containers are simple pass-throughs that don't
    # know or care about particular attributes. They will therefore consider
    # themselves to be in a cached state, so terminating propagation is left
    # to the Layer that contains them.
    self.always_propagate = always_propagate

  def __repr__(self):
    base = super(AttributeSentinel, self).__repr__()
    states = dict(
        (name, sentinel.in_cached_state)
        for name, sentinel in self.attributes.items())
    return "{}\n  {}".format(base, states)

  def add_parent(self, node):
    # type: (AttributeSentinel, AttributeSentinel) -> None
    """Registers `node` as a parent and invalidates the parent's whole cache.

    Removal is never tracked. Because this only drives cache invalidation,
    being overly conservative is acceptable. `node` has implicitly gained a
    child, so its cache is invalidated. `self` is left alone because
    attributes should not depend on parent Layers.
    """
    self._parents.add(node)
    node.invalidate_all()

  def get(self, key):
    # type: (AttributeSentinel, str) -> bool
    """Returns whether `key` is marked as cached.

    Looking up an unseen key starts tracking it (in the uncached state).
    """
    return self.attributes[key].in_cached_state

  def _set(self, key, value):
    # type: (AttributeSentinel, str, bool) -> None
    """Sets the state of `key` and, if needed, invalidates `key` on parents."""
    state_changed = self.attributes[key].mark_as(value)
    if not (state_changed or self.always_propagate):
      return
    for parent in self._parents:  # type: AttributeSentinel
      parent.invalidate(key)

  def mark_cached(self, key):
    # type: (AttributeSentinel, str) -> None
    """Marks `key` as cached."""
    self._set(key, True)

  def invalidate(self, key):
    # type: (AttributeSentinel, str) -> None
    """Marks `key` as not cached."""
    self._set(key, False)

  def invalidate_all(self):
    """Marks every tracked key as not cached, here and in all ancestors.

    Parents may track different keys than their children. Local keys are
    therefore invalidated directly, and parents are asked to run their own
    `invalidate_all` rather than receiving per-key invalidations.
    """
    for sentinel in self.attributes.values():
      sentinel.mark_as(False)

    for parent in self._parents:
      parent.invalidate_all()
139
140
def filter_empty_layer_containers(layer_list):
  """Yields the unique Layer-like objects reachable from `layer_list`.

  Objects for which `is_layer` is true are yielded directly. Any other object
  is treated as a container, and its `layers` attribute (if present and
  non-empty) is expanded in its place. This means layers held inside trackable
  data structures are still found, while the containers themselves, including
  empty ones, are never yielded. Each object is visited at most once; the
  de-duplication uses `object_identity.ObjectIdentitySet` (presumably
  identity-based). Traversal is depth-first and preserves left-to-right
  order.

  Args:
    layer_list: An iterable of layers and/or layer containers. Any iterable
      (list, tuple, generator, ...) is accepted.

  Yields:
    Unique Layer-like objects in depth-first order.
  """
  # TODO(b/130381733): Make this an attribute in base_layer.Layer.
  existing = object_identity.ObjectIdentitySet()
  # Copy into a plain list before reversing. Slicing a tuple yields a tuple,
  # which has no `pop`/`extend`, and generators cannot be sliced at all.
  to_visit = list(layer_list)[::-1]
  while to_visit:
    obj = to_visit.pop()
    if obj in existing:
      continue
    existing.add(obj)
    if is_layer(obj):
      yield obj
    else:
      sub_layers = getattr(obj, "layers", None) or []

      # Trackable data structures will not show up in ".layers" lists, but
      # the layers they contain will. Push them reversed so that popping from
      # the end of the stack visits them in their original order.
      to_visit.extend(list(sub_layers)[::-1])
159