1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3# https://developers.google.com/protocol-buffers/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9#     * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11#     * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15#     * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31"""Contains _ExtensionDict class to represent extensions.
32"""
33
34from google.protobuf.internal import type_checkers
35from google.protobuf.descriptor import FieldDescriptor
36
37
def _VerifyExtensionHandle(message, extension_handle):
  """Verify that the given extension handle is valid for the given message.

  Args:
    message: Message instance; its DESCRIPTOR must be the type the handle
      extends.
    extension_handle: Candidate extension handle, expected to be a
      FieldDescriptor.

  Raises:
    KeyError: If extension_handle is not a FieldDescriptor, is not an
      extension, has no containing_type, or its containing_type is not
      message.DESCRIPTOR.
  """

  if not isinstance(extension_handle, FieldDescriptor):
    # Wrap the handle in a 1-tuple: a bare tuple on the right-hand side of %
    # is unpacked as the argument list, which would raise TypeError here
    # instead of the KeyError this function is meant to raise.
    raise KeyError('HasExtension() expects an extension handle, got: %s' %
                   (extension_handle,))

  if not extension_handle.is_extension:
    raise KeyError('"%s" is not an extension.' % extension_handle.full_name)

  if not extension_handle.containing_type:
    raise KeyError('"%s" is missing a containing_type.'
                   % extension_handle.full_name)

  # Identity (not equality) comparison: the handle must extend exactly this
  # message's descriptor object.
  if extension_handle.containing_type is not message.DESCRIPTOR:
    raise KeyError('Extension "%s" extends message type "%s", but this '
                   'message is of type "%s".' %
                   (extension_handle.full_name,
                    extension_handle.containing_type.full_name,
                    message.DESCRIPTOR.full_name))
58
59
# TODO(robinson): Unify error handling for the "unknown extension" case.
61# TODO(robinson): Support iteritems()-style iteration over all
62# extensions with the "has" bits turned on?
class _ExtensionDict(object):

  """Dict-like container for Extension fields on proto instances.

  Note that in all cases we expect extension handles to be
  FieldDescriptors.
  """

  def __init__(self, extended_message):
    """
    Args:
      extended_message: Message instance for which we are the Extensions dict.
    """
    self._extended_message = extended_message

  def __getitem__(self, extension_handle):
    """Returns the current value of the given extension handle.

    If no value is stored yet, a repeated or message-typed extension gets a
    freshly created container/sub-message which is inserted into the
    message's field dict and returned; a singular scalar extension just
    returns its default value without inserting anything.

    Args:
      extension_handle: FieldDescriptor of the extension to read.

    Returns:
      The stored value, a newly created container/sub-message, or the
      scalar default.

    Raises:
      KeyError: If _VerifyExtensionHandle rejects extension_handle.
    """

    _VerifyExtensionHandle(self._extended_message, extension_handle)

    result = self._extended_message._fields.get(extension_handle)
    if result is not None:
      return result

    if extension_handle.label == FieldDescriptor.LABEL_REPEATED:
      result = extension_handle._default_constructor(self._extended_message)
    elif extension_handle.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE:
      assert getattr(extension_handle.message_type, '_concrete_class', None), (
          'Uninitialized concrete class found for field %r (message type %r)'
          % (extension_handle.full_name,
             extension_handle.message_type.full_name))
      result = extension_handle.message_type._concrete_class()
      try:
        result._SetListener(self._extended_message._listener_for_children)
      except ReferenceError:
        # Deliberately best-effort.  NOTE(review): presumably the listener is
        # a weak proxy whose referent may already be gone -- confirm.
        pass
    else:
      # Singular scalar -- just return the default without inserting into the
      # dict.
      return extension_handle.default_value

    # Atomically check if another thread has preempted us and, if not, swap
    # in the new object we just created.  If someone has preempted us, we
    # take that object and discard ours.
    # WARNING:  We are relying on setdefault() being atomic.  This is true
    #   in CPython but we haven't investigated others.  This warning appears
    #   in several other locations in this file.
    result = self._extended_message._fields.setdefault(
        extension_handle, result)

    return result

  def __eq__(self, other):
    """Returns True if both dicts hold the same set extension fields.

    Only instances of this class compare equal; the comparison is made on
    the (descriptor, value) pairs reported by ListFields() whose descriptor
    is an extension.
    """
    if not isinstance(other, self.__class__):
      return False

    my_fields = self._extended_message.ListFields()
    other_fields = other._extended_message.ListFields()

    # Get rid of non-extension fields.  ListFields() yields
    # (descriptor, value) pairs, so the descriptor is element 0.
    my_fields = [field for field in my_fields if field[0].is_extension]
    other_fields = [field for field in other_fields if field[0].is_extension]

    return my_fields == other_fields

  def __ne__(self, other):
    """Returns the negation of __eq__."""
    return not self == other

  def __len__(self):
    """Returns the number of extension fields reported by ListFields()."""
    fields = self._extended_message.ListFields()
    # Get rid of non-extension fields.
    extension_fields = [field for field in fields if field[0].is_extension]
    return len(extension_fields)

  def __hash__(self):
    """Always raises TypeError: instances are not hashable."""
    raise TypeError('unhashable object')

  # Note that this is only meaningful for non-repeated, scalar extension
  # fields.  Note also that we may have to call _Modified() when we do
  # successfully set a field this way, to set any necessary "has" bits in the
  # ancestors of the extended message.
  def __setitem__(self, extension_handle, value):
    """If extension_handle specifies a non-repeated, scalar extension
    field, sets the value of that field.

    Args:
      extension_handle: FieldDescriptor of the extension to assign.
      value: New value; it is passed through the field's type checker and
        the checker's return value is what gets stored.

    Raises:
      KeyError: If _VerifyExtensionHandle rejects extension_handle.
      TypeError: If the extension is repeated or message-typed.
    """

    _VerifyExtensionHandle(self._extended_message, extension_handle)

    if (extension_handle.label == FieldDescriptor.LABEL_REPEATED or
        extension_handle.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE):
      raise TypeError(
          'Cannot assign to extension "%s" because it is a repeated or '
          'composite type.' % extension_handle.full_name)

    # It's slightly wasteful to lookup the type checker each time,
    # but we expect this to be a vanishingly uncommon case anyway.
    type_checker = type_checkers.GetTypeChecker(extension_handle)
    # pylint: disable=protected-access
    self._extended_message._fields[extension_handle] = (
        type_checker.CheckValue(value))
    self._extended_message._Modified()

  def _FindExtensionByName(self, name):
    """Tries to find a known extension with the specified name.

    Args:
      name: Extension full name.

    Returns:
      Extension field descriptor, or None if the name is not registered.
    """
    return self._extended_message._extensions_by_name.get(name, None)

  def _FindExtensionByNumber(self, number):
    """Tries to find a known extension with the field number.

    Args:
      number: Extension field number.

    Returns:
      Extension field descriptor, or None if the number is not registered.
    """
    return self._extended_message._extensions_by_number.get(number, None)

  def __iter__(self):
    """Iterates over the descriptors of the populated extension fields."""
    # Return a generator over the populated extension fields
    return (f[0] for f in self._extended_message.ListFields()
            if f[0].is_extension)

  def __contains__(self, extension_handle):
    """Returns whether the given extension is considered present.

    A handle with no entry in the message's field dict is absent.  A
    repeated extension is present when its stored value is truthy, a
    message-typed extension when its stored value is not None and has
    _is_present_in_parent set, and any other stored extension is present.

    Raises:
      KeyError: If _VerifyExtensionHandle rejects extension_handle.
    """
    _VerifyExtensionHandle(self._extended_message, extension_handle)

    if extension_handle not in self._extended_message._fields:
      return False

    if extension_handle.label == FieldDescriptor.LABEL_REPEATED:
      return bool(self._extended_message._fields.get(extension_handle))

    if extension_handle.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE:
      value = self._extended_message._fields.get(extension_handle)
      # pylint: disable=protected-access
      return value is not None and value._is_present_in_parent

    return True
207