1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for python.util.protobuf.compare."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import copy
22import re
23import textwrap
24
25import six
26
27from google.protobuf import text_format
28
29from tensorflow.python.platform import googletest
30from tensorflow.python.util.protobuf import compare
31from tensorflow.python.util.protobuf import compare_test_pb2
32
33
34def LargePbs(*args):
35  """Converts ASCII string Large PBs to messages."""
36  pbs = []
37  for arg in args:
38    pb = compare_test_pb2.Large()
39    text_format.Merge(arg, pb)
40    pbs.append(pb)
41
42  return pbs
43
44
45class ProtoEqTest(googletest.TestCase):
46
47  def assertNotEquals(self, a, b):
48    """Asserts that ProtoEq says a != b."""
49    a, b = LargePbs(a, b)
50    googletest.TestCase.assertEqual(self, compare.ProtoEq(a, b), False)
51
52  def assertEqual(self, a, b):
53    """Asserts that ProtoEq says a == b."""
54    a, b = LargePbs(a, b)
55    googletest.TestCase.assertEqual(self, compare.ProtoEq(a, b), True)
56
57  def testPrimitives(self):
58    googletest.TestCase.assertEqual(self, True, compare.ProtoEq('a', 'a'))
59    googletest.TestCase.assertEqual(self, False, compare.ProtoEq('b', 'a'))
60
61  def testEmpty(self):
62    self.assertEqual('', '')
63
64  def testPrimitiveFields(self):
65    self.assertNotEquals('string_: "a"', '')
66    self.assertEqual('string_: "a"', 'string_: "a"')
67    self.assertNotEquals('string_: "b"', 'string_: "a"')
68    self.assertNotEquals('string_: "ab"', 'string_: "aa"')
69
70    self.assertNotEquals('int64_: 0', '')
71    self.assertEqual('int64_: 0', 'int64_: 0')
72    self.assertNotEquals('int64_: -1', '')
73    self.assertNotEquals('int64_: 1', 'int64_: 0')
74    self.assertNotEquals('int64_: 0', 'int64_: -1')
75
76    self.assertNotEquals('float_: 0.0', '')
77    self.assertEqual('float_: 0.0', 'float_: 0.0')
78    self.assertNotEquals('float_: -0.1', '')
79    self.assertNotEquals('float_: 3.14', 'float_: 0')
80    self.assertNotEquals('float_: 0', 'float_: -0.1')
81    self.assertEqual('float_: -0.1', 'float_: -0.1')
82
83    self.assertNotEquals('bool_: true', '')
84    self.assertNotEquals('bool_: false', '')
85    self.assertNotEquals('bool_: true', 'bool_: false')
86    self.assertEqual('bool_: false', 'bool_: false')
87    self.assertEqual('bool_: true', 'bool_: true')
88
89    self.assertNotEquals('enum_: A', '')
90    self.assertNotEquals('enum_: B', 'enum_: A')
91    self.assertNotEquals('enum_: C', 'enum_: B')
92    self.assertEqual('enum_: C', 'enum_: C')
93
94  def testRepeatedPrimitives(self):
95    self.assertNotEquals('int64s: 0', '')
96    self.assertEqual('int64s: 0', 'int64s: 0')
97    self.assertNotEquals('int64s: 1', 'int64s: 0')
98    self.assertNotEquals('int64s: 0 int64s: 0', '')
99    self.assertNotEquals('int64s: 0 int64s: 0', 'int64s: 0')
100    self.assertNotEquals('int64s: 1 int64s: 0', 'int64s: 0')
101    self.assertNotEquals('int64s: 0 int64s: 1', 'int64s: 0')
102    self.assertNotEquals('int64s: 1', 'int64s: 0 int64s: 2')
103    self.assertNotEquals('int64s: 2 int64s: 0', 'int64s: 1')
104    self.assertEqual('int64s: 0 int64s: 0', 'int64s: 0 int64s: 0')
105    self.assertEqual('int64s: 0 int64s: 1', 'int64s: 0 int64s: 1')
106    self.assertNotEquals('int64s: 1 int64s: 0', 'int64s: 0 int64s: 0')
107    self.assertNotEquals('int64s: 1 int64s: 0', 'int64s: 0 int64s: 1')
108    self.assertNotEquals('int64s: 1 int64s: 0', 'int64s: 0 int64s: 2')
109    self.assertNotEquals('int64s: 1 int64s: 1', 'int64s: 1 int64s: 0')
110    self.assertNotEquals('int64s: 1 int64s: 1', 'int64s: 1 int64s: 0 int64s: 2')
111
112  def testMessage(self):
113    self.assertNotEquals('small <>', '')
114    self.assertEqual('small <>', 'small <>')
115    self.assertNotEquals('small < strings: "a" >', '')
116    self.assertNotEquals('small < strings: "a" >', 'small <>')
117    self.assertEqual('small < strings: "a" >', 'small < strings: "a" >')
118    self.assertNotEquals('small < strings: "b" >', 'small < strings: "a" >')
119    self.assertNotEquals('small < strings: "a" strings: "b" >',
120                         'small < strings: "a" >')
121
122    self.assertNotEquals('string_: "a"', 'small <>')
123    self.assertNotEquals('string_: "a"', 'small < strings: "b" >')
124    self.assertNotEquals('string_: "a"', 'small < strings: "b" strings: "c" >')
125    self.assertNotEquals('string_: "a" small <>', 'small <>')
126    self.assertNotEquals('string_: "a" small <>', 'small < strings: "b" >')
127    self.assertEqual('string_: "a" small <>', 'string_: "a" small <>')
128    self.assertNotEquals('string_: "a" small < strings: "a" >',
129                         'string_: "a" small <>')
130    self.assertEqual('string_: "a" small < strings: "a" >',
131                     'string_: "a" small < strings: "a" >')
132    self.assertNotEquals('string_: "a" small < strings: "a" >',
133                         'int64_: 1 small < strings: "a" >')
134    self.assertNotEquals('string_: "a" small < strings: "a" >', 'int64_: 1')
135    self.assertNotEquals('string_: "a"', 'int64_: 1 small < strings: "a" >')
136    self.assertNotEquals('string_: "a" int64_: 0 small < strings: "a" >',
137                         'int64_: 1 small < strings: "a" >')
138    self.assertNotEquals('string_: "a" int64_: 1 small < strings: "a" >',
139                         'string_: "a" int64_: 0 small < strings: "a" >')
140    self.assertEqual('string_: "a" int64_: 0 small < strings: "a" >',
141                     'string_: "a" int64_: 0 small < strings: "a" >')
142
143  def testNestedMessage(self):
144    self.assertNotEquals('medium <>', '')
145    self.assertEqual('medium <>', 'medium <>')
146    self.assertNotEquals('medium < smalls <> >', 'medium <>')
147    self.assertEqual('medium < smalls <> >', 'medium < smalls <> >')
148    self.assertNotEquals('medium < smalls <> smalls <> >',
149                         'medium < smalls <> >')
150    self.assertEqual('medium < smalls <> smalls <> >',
151                     'medium < smalls <> smalls <> >')
152
153    self.assertNotEquals('medium < int32s: 0 >', 'medium < smalls <> >')
154
155    self.assertNotEquals('medium < smalls < strings: "a"> >',
156                         'medium < smalls <> >')
157
158  def testTagOrder(self):
159    """Tests that different fields are ordered by tag number.
160
161    For reference, here are the relevant tag numbers from compare_test.proto:
162      optional string string_ = 1;
163      optional int64 int64_ = 2;
164      optional float float_ = 3;
165      optional Small small = 8;
166      optional Medium medium = 7;
167      optional Small small = 8;
168    """
169    self.assertNotEquals('string_: "a"                      ',
170                         '             int64_: 1            ')
171    self.assertNotEquals('string_: "a" int64_: 2            ',
172                         '             int64_: 1            ')
173    self.assertNotEquals('string_: "b" int64_: 1            ',
174                         'string_: "a" int64_: 2            ')
175    self.assertEqual('string_: "a" int64_: 1            ',
176                     'string_: "a" int64_: 1            ')
177    self.assertNotEquals('string_: "a" int64_: 1 float_: 0.0',
178                         'string_: "a" int64_: 1            ')
179    self.assertEqual('string_: "a" int64_: 1 float_: 0.0',
180                     'string_: "a" int64_: 1 float_: 0.0')
181    self.assertNotEquals('string_: "a" int64_: 1 float_: 0.1',
182                         'string_: "a" int64_: 1 float_: 0.0')
183    self.assertNotEquals('string_: "a" int64_: 2 float_: 0.0',
184                         'string_: "a" int64_: 1 float_: 0.1')
185    self.assertNotEquals('string_: "a"                      ',
186                         '             int64_: 1 float_: 0.1')
187    self.assertNotEquals('string_: "a"           float_: 0.0',
188                         '             int64_: 1            ')
189    self.assertNotEquals('string_: "b"           float_: 0.0',
190                         'string_: "a" int64_: 1            ')
191
192    self.assertNotEquals('string_: "a"', 'small < strings: "a" >')
193    self.assertNotEquals('string_: "a" small < strings: "a" >',
194                         'small < strings: "b" >')
195    self.assertNotEquals('string_: "a" small < strings: "b" >',
196                         'string_: "a" small < strings: "a" >')
197    self.assertEqual('string_: "a" small < strings: "a" >',
198                     'string_: "a" small < strings: "a" >')
199
200    self.assertNotEquals('string_: "a" medium <>',
201                         'string_: "a" small < strings: "a" >')
202    self.assertNotEquals('string_: "a" medium < smalls <> >',
203                         'string_: "a" small < strings: "a" >')
204    self.assertNotEquals('medium <>', 'small < strings: "a" >')
205    self.assertNotEquals('medium <> small <>', 'small < strings: "a" >')
206    self.assertNotEquals('medium < smalls <> >', 'small < strings: "a" >')
207    self.assertNotEquals('medium < smalls < strings: "a" > >',
208                         'small < strings: "b" >')
209
210
211class NormalizeNumbersTest(googletest.TestCase):
212  """Tests for NormalizeNumberFields()."""
213
214  def testNormalizesInts(self):
215    pb = compare_test_pb2.Large()
216    pb.int64_ = 4
217    compare.NormalizeNumberFields(pb)
218    self.assertTrue(isinstance(pb.int64_, six.integer_types))
219
220    pb.int64_ = 4
221    compare.NormalizeNumberFields(pb)
222    self.assertTrue(isinstance(pb.int64_, six.integer_types))
223
224    pb.int64_ = 9999999999999999
225    compare.NormalizeNumberFields(pb)
226    self.assertTrue(isinstance(pb.int64_, six.integer_types))
227
228  def testNormalizesRepeatedInts(self):
229    pb = compare_test_pb2.Large()
230    pb.int64s.extend([1, 400, 999999999999999])
231    compare.NormalizeNumberFields(pb)
232    self.assertTrue(isinstance(pb.int64s[0], six.integer_types))
233    self.assertTrue(isinstance(pb.int64s[1], six.integer_types))
234    self.assertTrue(isinstance(pb.int64s[2], six.integer_types))
235
236  def testNormalizesFloats(self):
237    pb1 = compare_test_pb2.Large()
238    pb1.float_ = 1.2314352351231
239    pb2 = compare_test_pb2.Large()
240    pb2.float_ = 1.231435
241    self.assertNotEqual(pb1.float_, pb2.float_)
242    compare.NormalizeNumberFields(pb1)
243    compare.NormalizeNumberFields(pb2)
244    self.assertEqual(pb1.float_, pb2.float_)
245
246  def testNormalizesRepeatedFloats(self):
247    pb = compare_test_pb2.Large()
248    pb.medium.floats.extend([0.111111111, 0.111111])
249    compare.NormalizeNumberFields(pb)
250    for value in pb.medium.floats:
251      self.assertAlmostEqual(0.111111, value)
252
253  def testNormalizesDoubles(self):
254    pb1 = compare_test_pb2.Large()
255    pb1.double_ = 1.2314352351231
256    pb2 = compare_test_pb2.Large()
257    pb2.double_ = 1.2314352
258    self.assertNotEqual(pb1.double_, pb2.double_)
259    compare.NormalizeNumberFields(pb1)
260    compare.NormalizeNumberFields(pb2)
261    self.assertEqual(pb1.double_, pb2.double_)
262
263  def testNormalizesMaps(self):
264    pb = compare_test_pb2.WithMap()
265    pb.value_message[4].strings.extend(['a', 'b', 'c'])
266    pb.value_string['d'] = 'e'
267    compare.NormalizeNumberFields(pb)
268
269
270class AssertTest(googletest.TestCase):
271  """Tests assertProtoEqual()."""
272
273  def assertProtoEqual(self, a, b, **kwargs):
274    if isinstance(a, six.string_types) and isinstance(b, six.string_types):
275      a, b = LargePbs(a, b)
276    compare.assertProtoEqual(self, a, b, **kwargs)
277
278  def assertAll(self, a, **kwargs):
279    """Checks that all possible asserts pass."""
280    self.assertProtoEqual(a, a, **kwargs)
281
282  def assertSameNotEqual(self, a, b):
283    """Checks that assertProtoEqual() fails."""
284    self.assertRaises(AssertionError, self.assertProtoEqual, a, b)
285
286  def assertNone(self, a, b, message, **kwargs):
287    """Checks that all possible asserts fail with the given message."""
288    message = re.escape(textwrap.dedent(message))
289    self.assertRaisesRegex(AssertionError, message, self.assertProtoEqual, a, b,
290                           **kwargs)
291
292  def testCheckInitialized(self):
293    # neither is initialized
294    a = compare_test_pb2.Labeled()
295    a.optional = 1
296    self.assertNone(a, a, 'Initialization errors: ', check_initialized=True)
297    self.assertAll(a, check_initialized=False)
298
299    # a is initialized, b isn't
300    b = copy.deepcopy(a)
301    a.required = 2
302    self.assertNone(a, b, 'Initialization errors: ', check_initialized=True)
303    self.assertNone(
304        a,
305        b,
306        """
307                    - required: 2
308                      optional: 1
309                    """,
310        check_initialized=False)
311
312    # both are initialized
313    a = compare_test_pb2.Labeled()
314    a.required = 2
315    self.assertAll(a, check_initialized=True)
316    self.assertAll(a, check_initialized=False)
317
318    b = copy.deepcopy(a)
319    b.required = 3
320    message = """
321              - required: 2
322              ?           ^
323              + required: 3
324              ?           ^
325              """
326    self.assertNone(a, b, message, check_initialized=True)
327    self.assertNone(a, b, message, check_initialized=False)
328
329  def testAssertEqualWithStringArg(self):
330    pb = compare_test_pb2.Large()
331    pb.string_ = 'abc'
332    pb.float_ = 1.234
333    compare.assertProtoEqual(self, """
334          string_: 'abc'
335          float_: 1.234
336        """, pb)
337
338  def testNormalizesNumbers(self):
339    pb1 = compare_test_pb2.Large()
340    pb1.int64_ = 4
341    pb2 = compare_test_pb2.Large()
342    pb2.int64_ = 4
343    compare.assertProtoEqual(self, pb1, pb2)
344
345  def testNormalizesFloat(self):
346    pb1 = compare_test_pb2.Large()
347    pb1.double_ = 4.0
348    pb2 = compare_test_pb2.Large()
349    pb2.double_ = 4
350    compare.assertProtoEqual(self, pb1, pb2, normalize_numbers=True)
351
352  def testPrimitives(self):
353    self.assertAll('string_: "x"')
354    self.assertNone('string_: "x"', 'string_: "y"', """
355                    - string_: "x"
356                    ?           ^
357                    + string_: "y"
358                    ?           ^
359                    """)
360
361  def testRepeatedPrimitives(self):
362    self.assertAll('int64s: 0 int64s: 1')
363
364    self.assertSameNotEqual('int64s: 0 int64s: 1', 'int64s: 1 int64s: 0')
365    self.assertSameNotEqual('int64s: 0 int64s: 1 int64s: 2',
366                            'int64s: 2 int64s: 1 int64s: 0')
367
368    self.assertSameNotEqual('int64s: 0', 'int64s: 0 int64s: 0')
369    self.assertSameNotEqual('int64s: 0 int64s: 1',
370                            'int64s: 1 int64s: 0 int64s: 1')
371
372    self.assertNone('int64s: 0', 'int64s: 0 int64s: 2', """
373                      int64s: 0
374                    + int64s: 2
375                    """)
376    self.assertNone('int64s: 0 int64s: 1', 'int64s: 0 int64s: 2', """
377                      int64s: 0
378                    - int64s: 1
379                    ?         ^
380                    + int64s: 2
381                    ?         ^
382                    """)
383
384  def testMessage(self):
385    self.assertAll('medium: {}')
386    self.assertAll('medium: { smalls: {} }')
387    self.assertAll('medium: { int32s: 1 smalls: {} }')
388    self.assertAll('medium: { smalls: { strings: "x" } }')
389    self.assertAll(
390        'medium: { smalls: { strings: "x" } } small: { strings: "y" }')
391
392    self.assertSameNotEqual('medium: { smalls: { strings: "x" strings: "y" } }',
393                            'medium: { smalls: { strings: "y" strings: "x" } }')
394    self.assertSameNotEqual(
395        'medium: { smalls: { strings: "x" } smalls: { strings: "y" } }',
396        'medium: { smalls: { strings: "y" } smalls: { strings: "x" } }')
397
398    self.assertSameNotEqual(
399        'medium: { smalls: { strings: "x" strings: "y" strings: "x" } }',
400        'medium: { smalls: { strings: "y" strings: "x" } }')
401    self.assertSameNotEqual(
402        'medium: { smalls: { strings: "x" } int32s: 0 }',
403        'medium: { int32s: 0 smalls: { strings: "x" } int32s: 0 }')
404
405    self.assertNone('medium: {}', 'medium: { smalls: { strings: "x" } }', """
406                      medium {
407                    +   smalls {
408                    +     strings: "x"
409                    +   }
410                      }
411                    """)
412    self.assertNone('medium: { smalls: { strings: "x" } }',
413                    'medium: { smalls: {} }', """
414                      medium {
415                        smalls {
416                    -     strings: "x"
417                        }
418                      }
419                    """)
420    self.assertNone('medium: { int32s: 0 }', 'medium: { int32s: 1 }', """
421                      medium {
422                    -   int32s: 0
423                    ?           ^
424                    +   int32s: 1
425                    ?           ^
426                      }
427                    """)
428
429  def testMsgPassdown(self):
430    self.assertRaisesRegex(
431        AssertionError,
432        'test message passed down',
433        self.assertProtoEqual,
434        'medium: {}',
435        'medium: { smalls: { strings: "x" } }',
436        msg='test message passed down')
437
438  def testRepeatedMessage(self):
439    self.assertAll('medium: { smalls: {} smalls: {} }')
440    self.assertAll('medium: { smalls: { strings: "x" } } medium: {}')
441    self.assertAll('medium: { smalls: { strings: "x" } } medium: { int32s: 0 }')
442    self.assertAll('medium: { smalls: {} smalls: { strings: "x" } } small: {}')
443
444    self.assertSameNotEqual('medium: { smalls: { strings: "x" } smalls: {} }',
445                            'medium: { smalls: {} smalls: { strings: "x" } }')
446
447    self.assertSameNotEqual('medium: { smalls: {} }',
448                            'medium: { smalls: {} smalls: {} }')
449    self.assertSameNotEqual('medium: { smalls: {} smalls: {} } medium: {}',
450                            'medium: {} medium: {} medium: { smalls: {} }')
451    self.assertSameNotEqual(
452        'medium: { smalls: { strings: "x" } smalls: {} }',
453        'medium: { smalls: {} smalls: { strings: "x" } smalls: {} }')
454
455    self.assertNone('medium: {}', 'medium: {} medium { smalls: {} }', """
456                      medium {
457                    +   smalls {
458                    +   }
459                      }
460                    """)
461    self.assertNone('medium: { smalls: {} smalls: { strings: "x" } }',
462                    'medium: { smalls: {} smalls: { strings: "y" } }', """
463                      medium {
464                        smalls {
465                        }
466                        smalls {
467                    -     strings: "x"
468                    ?               ^
469                    +     strings: "y"
470                    ?               ^
471                        }
472                      }
473                    """)
474
475
476class MixinTests(compare.ProtoAssertions, googletest.TestCase):
477
478  def testAssertEqualWithStringArg(self):
479    pb = compare_test_pb2.Large()
480    pb.string_ = 'abc'
481    pb.float_ = 1.234
482    self.assertProtoEqual("""
483          string_: 'abc'
484          float_: 1.234
485        """, pb)
486
487
488if __name__ == '__main__':
489  googletest.main()
490