1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Utility functions for comparing proto2 messages in Python.
17
18ProtoEq() compares two proto2 messages for equality.
19
20ClearDefaultValuedFields() recursively clears the fields that are set to their
21default values. This is useful for comparing protocol buffers where the
22semantics of unset fields and default valued fields are the same.
23
24assertProtoEqual() is useful for unit tests.  It produces much more helpful
25output than assertEqual() for proto2 messages, e.g. this:
26
27  outer {
28    inner {
29-     strings: "x"
30?               ^
31+     strings: "y"
32?               ^
33    }
34  }
35
36...compared to the default output from assertEqual() that looks like this:
37
38AssertionError: <my.Msg object at 0x9fb353c> != <my.Msg object at 0x9fb35cc>
39
40Call it inside your unit test's googletest.TestCase subclasses like this:
41
42  from tensorflow.python.util.protobuf import compare
43
44  class MyTest(googletest.TestCase):
45    ...
46    def testXXX(self):
47      ...
48      compare.assertProtoEqual(self, a, b)
49
50Alternatively:
51
52  from tensorflow.python.util.protobuf import compare
53
54  class MyTest(compare.ProtoAssertions, googletest.TestCase):
55    ...
56    def testXXX(self):
57      ...
58      self.assertProtoEqual(a, b)
59"""
60
61from __future__ import absolute_import
62from __future__ import division
63from __future__ import print_function
64
65import difflib
66
67import six
68
69from google.protobuf import descriptor
70from google.protobuf import descriptor_pool
71from google.protobuf import message
72from google.protobuf import text_format
73
74from ..compat import collections_abc
75
76
def assertProtoEqual(self, a, b, check_initialized=True,  # pylint: disable=invalid-name
                     normalize_numbers=False, msg=None):
  """Fails with a useful error if a and b aren't equal.

  Both messages are rendered to text format and the resulting strings are
  compared, so a failure shows a line-oriented diff of the two text protos.

  Comparison of repeated fields matches the semantics of
  unittest.TestCase.assertEqual(), i.e. order and duplicate fields matter.

  Args:
    self: googletest.TestCase
    a: proto2 PB instance, or text string representing one. A string is parsed
      into a new instance of b's class before the comparison.
    b: proto2 PB instance -- message.Message or subclass thereof.
    check_initialized: boolean, whether to fail if either a or b isn't
      initialized.
    normalize_numbers: boolean, whether to normalize types and precision of
      numbers before comparison. Normalization is applied to a and b in place.
    msg: if specified, is used as the error message on failure.
  """
  pool = descriptor_pool.Default()
  if isinstance(a, six.string_types):
    a = text_format.Merge(a, b.__class__(), descriptor_pool=pool)

  for pb in a, b:
    if check_initialized:
      errors = pb.FindInitializationErrors()
      if errors:
        self.fail('Initialization errors: %s\n%s' % (errors, pb))
    if normalize_numbers:
      NormalizeNumberFields(pb)

  a_str = text_format.MessageToString(a, descriptor_pool=pool)
  b_str = text_format.MessageToString(b, descriptor_pool=pool)

  # Some Python versions would perform regular diff instead of multi-line
  # diff if string is longer than 2**16. We substitute this behavior
  # with a call to unified_diff instead to have easier-to-read diffs.
  # For context, see: https://bugs.python.org/issue11763.
  if len(a_str) < 2**16 and len(b_str) < 2**16:
    self.assertMultiLineEqual(a_str, b_str, msg=msg)
  else:
    diff = ''.join(difflib.unified_diff(a_str.splitlines(True),
                                        b_str.splitlines(True)))
    # unified_diff produces no output for identical inputs, so an empty diff
    # means the (large) messages are equal and the assertion must pass.
    if diff:
      self.fail('%s : %s' % (msg, '\n' + diff))
119
120
def NormalizeNumberFields(pb):
  """Normalizes types and precisions of number fields in a protocol buffer.

  Due to subtleties in the python protocol buffer implementation, it is possible
  for values to have different types and precision depending on whether they
  were set and retrieved directly or deserialized from a protobuf. This function
  converts integer and enum values to int, rounds 32-bit floats to six decimal
  places to account for python always storing them as 64-bit, and converts
  doubles to float (for when they're set to integers) rounded to seven decimal
  places.

  Only the fields returned by pb.ListFields() are visited. For map fields,
  scalar values are left untouched; message-typed values are normalized
  recursively.

  Modifies pb in place. Recurses into nested objects.

  Args:
    pb: proto2 message.

  Returns:
    the given pb, modified in place.
  """
  for desc, values in pb.ListFields():
    # Compare the label by value, not identity: `is` on a numeric constant only
    # works when the interpreter happens to share the integer objects.
    is_repeated = desc.label == descriptor.FieldDescriptor.LABEL_REPEATED
    if not is_repeated:
      # Wrap singular values so both cases can be handled as a sequence.
      values = [values]

    normalized_values = None

    # All integer widths and enums normalize to a plain int.
    if desc.type in (descriptor.FieldDescriptor.TYPE_INT64,
                     descriptor.FieldDescriptor.TYPE_UINT64,
                     descriptor.FieldDescriptor.TYPE_SINT64,
                     descriptor.FieldDescriptor.TYPE_INT32,
                     descriptor.FieldDescriptor.TYPE_UINT32,
                     descriptor.FieldDescriptor.TYPE_SINT32,
                     descriptor.FieldDescriptor.TYPE_ENUM):
      normalized_values = [int(x) for x in values]
    elif desc.type == descriptor.FieldDescriptor.TYPE_FLOAT:
      normalized_values = [round(x, 6) for x in values]
    elif desc.type == descriptor.FieldDescriptor.TYPE_DOUBLE:
      normalized_values = [round(float(x), 7) for x in values]

    if normalized_values is not None:
      if is_repeated:
        # Repeated fields cannot be assigned directly; clear and refill.
        pb.ClearField(desc.name)
        getattr(pb, desc.name).extend(normalized_values)
      else:
        setattr(pb, desc.name, normalized_values[0])

    if desc.type in (descriptor.FieldDescriptor.TYPE_MESSAGE,
                     descriptor.FieldDescriptor.TYPE_GROUP):
      if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
          desc.message_type.has_options and
          desc.message_type.GetOptions().map_entry):
        # This is a map, only recurse if the values have a message type.
        # Field number 2 of the map entry message is inspected as the value
        # field.
        if (desc.message_type.fields_by_number[2].type ==
            descriptor.FieldDescriptor.TYPE_MESSAGE):
          for v in six.itervalues(values):
            NormalizeNumberFields(v)
      else:
        for v in values:
          # recursive step
          NormalizeNumberFields(v)

  return pb
187
188
def _IsMap(value):
  """Returns True if `value` is an instance of collections_abc.Mapping."""
  mapping_type = collections_abc.Mapping
  return isinstance(value, mapping_type)
191
192
def _IsRepeatedContainer(value):
  """Returns True if `value` is iterable but is not a text or byte string.

  Text and byte strings are iterable, yet they are treated here as scalar
  values rather than containers; without the bytes exclusion a Python 3
  `bytes` value would be expanded element by element by callers.

  Args:
    value: Any object.

  Returns:
    `True` if iter(value) succeeds and value is not a string/bytes instance.
  """
  scalar_string_types = six.string_types + (six.binary_type,)
  if isinstance(value, scalar_string_types):
    return False
  try:
    iter(value)
  except TypeError:
    return False
  return True
201
202
def ProtoEq(a, b):
  """Compares two proto2 objects for equality.

  Recurses into nested messages. Uses list (not set) semantics for comparing
  repeated fields, i.e. duplicates and order matter.

  Args:
    a: A proto2 message or a primitive.
    b: A proto2 message or a primitive.

  Returns:
    `True` if the messages are equal.
  """
  def Format(pb):
    """Returns a dictionary or unchanged pb based on its type.

    Specifically, this function returns a dictionary that maps tag
    number (for messages), map key (for map fields) or element index (for
    repeated fields) to value, or just pb unchanged if it's none of those.

    Args:
      pb: A proto2 message or a primitive.
    Returns:
      A dict or unchanged pb.
    """
    if isinstance(pb, message.Message):
      return dict((desc.number, value) for desc, value in pb.ListFields())
    if _IsMap(pb):
      return dict(pb.items())
    if _IsRepeatedContainer(pb):
      return dict(enumerate(pb))
    return pb

  a, b = Format(a), Format(b)

  # Base case: at least one side is a primitive.
  if not isinstance(a, dict) or not isinstance(b, dict):
    return a == b

  # This comparison performs double duty: it compares two messages by tag
  # value *or* two repeated fields by element index. The magic is in the
  # Format() function, which converts them both to the same easily comparable
  # format. The key sets are compared directly (no sorting), so keys of
  # different, mutually unorderable types simply compare unequal.
  if set(a) != set(b):
    return False

  # Recursive step: every shared key must map to equal values.
  return all(ProtoEq(a[key], b[key]) for key in a)
256
257
class ProtoAssertions(object):
  """Mixin that adds proto2 assertions to a googletest.TestCase class.

  Example:

  class SomeTestCase(compare.ProtoAssertions, googletest.TestCase):
    ...
    def testSomething(self):
      ...
      self.assertProtoEqual(a, b)

  Each method forwards to the module-level function of the same name; refer to
  that function for the documentation of its arguments.
  """

  # pylint: disable=invalid-name
  def assertProtoEqual(self, *positional, **keywords):
    """Calls the module-level assertProtoEqual with self as the test case."""
    outcome = assertProtoEqual(self, *positional, **keywords)
    return outcome
275