1"""Utilities for collecting objects based on "is" comparison."""
2# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import weakref

try:
  import collections.abc as collections_abc  # Python 3.3+
except ImportError:  # Python 2: the container ABCs live on `collections`.
  collections_abc = collections
22
23
24class _ObjectIdentityWrapper(object):
25  """Wraps an object, mapping __eq__ on wrapper to "is" on wrapped.
26
27  Since __eq__ is based on object identity, it's safe to also define __hash__
28  based on object ids. This lets us add unhashable types like trackable
29  _ListWrapper objects to object-identity collections.
30  """
31
32  def __init__(self, wrapped):
33    self._wrapped = wrapped
34
35  @property
36  def unwrapped(self):
37    return self._wrapped
38
39  def __eq__(self, other):
40    if isinstance(other, _ObjectIdentityWrapper):
41      return self._wrapped is other._wrapped  # pylint: disable=protected-access
42    return self._wrapped is other
43
44  def __hash__(self):
45    # Wrapper id() is also fine for weakrefs. In fact, we rely on
46    # id(weakref.ref(a)) == id(weakref.ref(a)) and weakref.ref(a) is
47    # weakref.ref(a) in _WeakObjectIdentityWrapper.
48    return id(self._wrapped)
49
50
class _WeakObjectIdentityWrapper(_ObjectIdentityWrapper):
  """Identity wrapper that holds only a weak reference to its target.

  The value stored on the base class is a `weakref.ref`, so equality and
  hashing operate on that reference object. `unwrapped` dereferences it, which
  yields None once the target is gone.
  """

  def __init__(self, wrapped):
    reference = weakref.ref(wrapped)
    super(_WeakObjectIdentityWrapper, self).__init__(reference)

  @property
  def unwrapped(self):
    """The referenced object, or None if it no longer exists."""
    reference = self._wrapped
    return reference()
59
60
class ObjectIdentityDictionary(collections_abc.MutableMapping):
  """A mutable mapping data structure which compares using "is".

  This is necessary because we have trackable objects (_ListWrapper) which
  have behavior identical to built-in Python lists (including being unhashable
  and comparing based on the equality of their contents by default).

  Keys are stored inside identity wrappers, so two keys collide only when they
  are the very same object. Callers only ever see the unwrapped keys.
  """

  def __init__(self):
    # Maps wrapped keys -> values.
    self._storage = {}

  def _wrap_key(self, key):
    """Returns the wrapper under which `key` is stored; subclasses override."""
    return _ObjectIdentityWrapper(key)

  def __getitem__(self, key):
    return self._storage[self._wrap_key(key)]

  def __setitem__(self, key, value):
    self._storage[self._wrap_key(key)] = value

  def __delitem__(self, key):
    del self._storage[self._wrap_key(key)]

  def __len__(self):
    return len(self._storage)

  def __iter__(self):
    # Yield the original key objects rather than their wrappers.
    for key in self._storage:
      yield key.unwrapped
90
91
class ObjectIdentityWeakKeyDictionary(ObjectIdentityDictionary):
  """Like weakref.WeakKeyDictionary, but compares objects with "is".

  Keys are held through weak references. Entries whose key no longer exists
  are pruned lazily whenever the dictionary is iterated; `len()` iterates too,
  so it counts live entries only.
  """

  def _wrap_key(self, key):
    return _WeakObjectIdentityWrapper(key)

  def __len__(self):
    # Iterate (instead of measuring _storage) so dead entries are discarded
    # and excluded from the count.
    return sum(1 for _ in self)

  def __iter__(self):
    # Snapshot the stored wrappers: dead entries are deleted below, and
    # mutating a dict while iterating over it raises RuntimeError.
    for wrapped_key in list(self._storage):
      unwrapped = wrapped_key.unwrapped
      if unwrapped is None:
        # Delete through the stored wrapper itself. `del self[wrapped_key]`
        # would wrap the wrapper again and never match the stored entry.
        del self._storage[wrapped_key]
      else:
        yield unwrapped
110
111
class ObjectIdentitySet(collections_abc.MutableSet):
  """Like the built-in set, but compares objects with "is".

  Elements are stored inside identity wrappers, so two elements are considered
  the same only when they are the very same object.
  """

  def __init__(self, *args):
    # Mirrors set()'s signature: at most one iterable of initial elements
    # (list() raises TypeError if more are given).
    self._storage = set(self._wrap_key(obj) for obj in list(*args))

  def _wrap_key(self, key):
    """Returns the wrapper under which `key` is stored; subclasses override."""
    return _ObjectIdentityWrapper(key)

  def __contains__(self, key):
    return self._wrap_key(key) in self._storage

  def discard(self, key):
    self._storage.discard(self._wrap_key(key))

  def add(self, key):
    self._storage.add(self._wrap_key(key))

  def __len__(self):
    return len(self._storage)

  def __iter__(self):
    # Iterate over a snapshot so callers may add/discard while iterating.
    keys = list(self._storage)
    for key in keys:
      yield key.unwrapped
137
138
class ObjectIdentityWeakSet(ObjectIdentitySet):
  """Like weakref.WeakSet, but compares objects with "is".

  Elements are held through weak references. Elements that no longer exist are
  removed lazily whenever the set is iterated; `len()` iterates too, so it
  counts live elements only.
  """

  def _wrap_key(self, key):
    return _WeakObjectIdentityWrapper(key)

  def __len__(self):
    # Iterate, discarding old weak refs, so only live elements are counted.
    return sum(1 for _ in self)

  def __iter__(self):
    # Iterate over a snapshot because dead entries are removed as we go.
    for wrapped_key in list(self._storage):
      unwrapped = wrapped_key.unwrapped
      if unwrapped is None:
        # Remove the stored wrapper directly. `self.discard(wrapped_key)`
        # would wrap it again, match nothing, and silently leak the entry.
        self._storage.discard(wrapped_key)
      else:
        yield unwrapped
157