1#!/usr/bin/env python
2#
3# Copyright 2018 Google Inc.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Common code for converting proto to other formats, such as JSON."""
18
19import base64
20import collections
21import datetime
22import json
23
24import six
25
26from apitools.base.protorpclite import message_types
27from apitools.base.protorpclite import messages
28from apitools.base.protorpclite import protojson
29from apitools.base.py import exceptions
30
31
# An encoder/decoder pair, as stored in the codec registries in this module.
_Codec = collections.namedtuple('_Codec', ['encoder', 'decoder'])
# Result of a field codec: "value" is the (possibly transformed) value and
# "complete" says whether it is final or should keep flowing through the
# remaining encode/decode steps.
CodecResult = collections.namedtuple('CodecResult', ['value', 'complete'])
34
35
class EdgeType(object):
    """Enumerates the kinds of step a ProtoEdge can describe."""
    # SCALAR: a plain field; REPEATED: an item of a list-valued field;
    # MAP: an entry of a map-like field.
    SCALAR, REPEATED, MAP = 1, 2, 3
41
42
class ProtoEdge(collections.namedtuple('ProtoEdge',
                                       ['type_', 'field', 'index'])):
    """One step from a message to a value nested directly inside it.

    Message fields may themselves hold messages, so a single message
    instance can contain many levels of nesting. A ProtoEdge pins down
    exactly one hop through that structure; a sequence of edges therefore
    describes a full path to a nested value.

    Properties:
      type_: EdgeType, The kind of hop this edge represents.
      field: str, The name of the field being traversed.
      index: Any, Extra data needed to complete the hop. Its meaning
          depends on "type_":
            SCALAR: ignored.
            REPEATED: a numeric index into "field"'s list.
            MAP: a key into "field"'s mapping.
    """
    __slots__ = ()

    def __str__(self):
        # Only non-scalar edges carry a meaningful index to display.
        if self.type_ != EdgeType.SCALAR:
            return '{}[{}]'.format(self.field, self.index)
        return self.field
68
69
# TODO(craigcitro): Make these non-global.
# Message class -> name of the field that collects its unrecognized fields
# (populated by MapUnrecognizedFields).
_UNRECOGNIZED_FIELD_MAPPINGS = {}
# Message class -> _Codec that replaces the default whole-message
# encoding/decoding (populated by RegisterCustomMessageCodec).
_CUSTOM_MESSAGE_CODECS = {}
# Field -> _Codec applied to that field (populated by
# RegisterCustomFieldCodec).
_CUSTOM_FIELD_CODECS = {}
# Field type -> _Codec applied to every field of that type (populated by
# RegisterFieldTypeCodec).
_FIELD_TYPE_CODECS = {}
75
76
def MapUnrecognizedFields(field_name):
    """Return a class decorator that maps unrecognized fields to field_name.

    Args:
      field_name: str, name of the field on the decorated message class that
          acts as the container for unrecognized fields.

    Returns:
      A decorator that records the mapping and returns the class unchanged.
    """
    def Register(cls):
        _UNRECOGNIZED_FIELD_MAPPINGS.update({cls: field_name})
        return cls
    return Register
83
84
def RegisterCustomMessageCodec(encoder, decoder):
    """Return a class decorator installing a whole-message codec.

    Args:
      encoder: callable invoked with a message instance when encoding.
      decoder: callable invoked with the encoded message when decoding.

    Returns:
      A decorator that registers the pair for the decorated message class and
      returns the class unchanged.
    """
    def Register(cls):
        codec = _Codec(encoder=encoder, decoder=decoder)
        _CUSTOM_MESSAGE_CODECS[cls] = codec
        return cls
    return Register
91
92
def RegisterCustomFieldCodec(encoder, decoder):
    """Return a decorator installing a codec for one specific field.

    Args:
      encoder: callable taking (field, value) and returning a CodecResult.
      decoder: callable taking (field, value) and returning a CodecResult.

    Returns:
      A decorator that registers the pair for the given field and returns the
      field unchanged.
    """
    def Register(field):
        codec = _Codec(encoder=encoder, decoder=decoder)
        _CUSTOM_FIELD_CODECS[field] = codec
        return field
    return Register
99
100
def RegisterFieldTypeCodec(encoder, decoder):
    """Return a decorator installing a codec for every field of a type.

    Args:
      encoder: callable taking (field, value) and returning a CodecResult.
      decoder: callable taking (field, value) and returning a CodecResult.

    Returns:
      A decorator that registers the pair for the given field type and returns
      the field type unchanged.
    """
    def Register(field_type):
        codec = _Codec(encoder=encoder, decoder=decoder)
        _FIELD_TYPE_CODECS[field_type] = codec
        return field_type
    return Register
108
109
def CopyProtoMessage(message):
    """Return a copy of message built by a round trip through JSON."""
    encoded = MessageToJson(message)
    return JsonToMessage(type(message), encoded)
113
114
def MessageToJson(message, include_fields=None):
    """Serialize message to a JSON string.

    Args:
      message: the message to encode.
      include_fields: optional iterable of dotted field paths that are forced
          into the output with a null value ([] for list-valued fields); when
          None the encoded message is returned untouched.

    Returns:
      The JSON text for message.
    """
    codec = _ProtoJsonApiTools.Get()
    encoded = codec.encode_message(message)
    return _IncludeFields(encoded, message, include_fields)
119
120
def JsonToMessage(message_type, message):
    """Decode the JSON text in message into an instance of message_type."""
    codec = _ProtoJsonApiTools.Get()
    return codec.decode_message(message_type, message)
124
125
# TODO(craigcitro): Do this directly, instead of via JSON.
def DictToMessage(d, message_type):
    """Build a message_type instance from dictionary d (via JSON)."""
    as_json = json.dumps(d)
    return JsonToMessage(message_type, as_json)
130
131
def MessageToDict(message):
    """Return message as plain Python data parsed from its JSON encoding."""
    as_json = MessageToJson(message)
    return json.loads(as_json)
135
136
def DictToAdditionalPropertyMessage(properties, additional_property_type,
                                    sort_items=False):
    """Build an additional_property_type message from a dictionary.

    Args:
      properties: dict whose items become AdditionalProperty entries.
      additional_property_type: type exposing a nested AdditionalProperty
          constructor (key=, value=) and accepting additionalProperties=.
      sort_items: bool, True to emit the entries in sorted item order.

    Returns:
      An additional_property_type instance holding one AdditionalProperty per
      item of properties.
    """
    entries = properties.items()
    if sort_items:
        entries = sorted(entries)
    make_pair = additional_property_type.AdditionalProperty
    pairs = [make_pair(key=name, value=payload) for name, payload in entries]
    return additional_property_type(additionalProperties=pairs)
148
149
def PyValueToMessage(message_type, value):
    """Build a message_type instance from a JSON-serializable python value."""
    as_json = json.dumps(value)
    return JsonToMessage(message_type, as_json)
153
154
def MessageToPyValue(message):
    """Return the python value obtained by parsing message's JSON encoding."""
    as_json = MessageToJson(message)
    return json.loads(as_json)
158
159
def MessageToRepr(msg, multiline=False, **kwargs):
    """Return a repr-style string for a protorpc message.

    protorpc.Message.__repr__ does not return anything that could be considered
    python code. Adding this function lets us print a protorpc message in such
    a way that it could be pasted into code later, and used to compare against
    other things.

    Lists and messages are rendered recursively; every other value falls back
    to the builtin repr(). Timezone-aware datetimes are rendered with a
    TimeZoneOffset(...) tzinfo, while naive datetimes (no tzinfo) are rendered
    with the plain builtin repr.

    Args:
      msg: protorpc.Message, the message to be repr'd. Lists, datetimes and
          other plain values are accepted as well.
      multiline: bool, True if the returned string should have each field
          assignment on its own line.
      **kwargs: {str:str}, Additional flags for how to format the string.

    Known **kwargs:
      shortstrings: bool, True if all string values should be
          truncated at 100 characters, since when mocking the contents
          typically don't matter except for IDs, and IDs are usually
          less than 100 characters.
      no_modules: bool, True if the long module name should not be printed with
          each type.
      indent: int, current indentation in spaces; used by the recursive calls
          when multiline is True.

    Returns:
      str, A string of valid python (assuming the right imports have been made)
      that recreates the message passed into this function.

    """

    # TODO(jasmuth): craigcitro suggests a pretty-printer from apitools/gen.

    indent = kwargs.get('indent', 0)
    # Nested values are rendered one level (four spaces) deeper.
    child_kwargs = dict(kwargs, indent=indent + 4)

    def Render(value):
        return MessageToRepr(value, multiline=multiline, **child_kwargs)

    # Separators placed before each nested item and before the closing
    # bracket; both collapse to nothing in single-line mode.
    item_prefix = '\n' + ' ' * (indent + 4) if multiline else ''
    closing_prefix = '\n' + ' ' * indent if multiline else ''

    if isinstance(msg, list):
        parts = [item_prefix + Render(item) + ',' for item in msg]
        return '[' + ''.join(parts) + closing_prefix + ']'

    if isinstance(msg, datetime.datetime):
        if msg.tzinfo is None:
            # A naive datetime has no offset to describe.
            return repr(msg)

        class SpecialTZInfo(datetime.tzinfo):
            """tzinfo stand-in whose repr spells out a TimeZoneOffset."""

            def __init__(self, offset):
                super(SpecialTZInfo, self).__init__()
                self.offset = offset

            def __repr__(self):
                s = 'TimeZoneOffset(' + repr(self.offset) + ')'
                if not kwargs.get('no_modules'):
                    s = 'apitools.base.protorpclite.util.' + s
                return s

        # datetime.utcoffset() hands the datetime itself to the tzinfo, which
        # is what tzinfo implementations expect (an int is rejected by the
        # standard library's own tzinfo classes).
        rebuilt = datetime.datetime(
            msg.year, msg.month, msg.day, msg.hour, msg.minute, msg.second,
            msg.microsecond, SpecialTZInfo(msg.utcoffset()))
        return repr(rebuilt)

    if isinstance(msg, messages.Message):
        prefix = type(msg).__name__ + '('
        if not kwargs.get('no_modules'):
            prefix = msg.__module__ + '.' + prefix
        # Fields are emitted in sorted-name order for a stable result.
        names = sorted(field.name for field in msg.all_fields())
        parts = [
            item_prefix + name + '=' + Render(getattr(msg, name)) + ','
            for name in names]
        return prefix + ''.join(parts) + closing_prefix + ')'

    if isinstance(msg, six.string_types):
        if kwargs.get('shortstrings') and len(msg) > 100:
            msg = msg[:100]

    return repr(msg)
249
250
251def _GetField(message, field_path):
252    for field in field_path:
253        if field not in dir(message):
254            raise KeyError('no field "%s"' % field)
255        message = getattr(message, field)
256    return message
257
258
259def _SetField(dictblob, field_path, value):
260    for field in field_path[:-1]:
261        dictblob = dictblob.setdefault(field, {})
262    dictblob[field_path[-1]] = value
263
264
265def _IncludeFields(encoded_message, message, include_fields):
266    """Add the requested fields to the encoded message."""
267    if include_fields is None:
268        return encoded_message
269    result = json.loads(encoded_message)
270    for field_name in include_fields:
271        try:
272            value = _GetField(message, field_name.split('.'))
273            nullvalue = None
274            if isinstance(value, list):
275                nullvalue = []
276        except KeyError:
277            raise exceptions.InvalidDataError(
278                'No field named %s in message of type %s' % (
279                    field_name, type(message)))
280        _SetField(result, field_name.split('.'), nullvalue)
281    return json.dumps(result)
282
283
def _GetFieldCodecs(field, attr):
    """Collect the registered codec callables that apply to field.

    Args:
      field: the field being encoded or decoded.
      attr: str, which side of each codec to fetch ('encoder' or 'decoder').

    Returns:
      A list with the field-specific callable first (if any), followed by the
      field-type callable (if any).
    """
    candidates = (
        _CUSTOM_FIELD_CODECS.get(field),
        _FIELD_TYPE_CODECS.get(type(field)),
    )
    found = []
    for codec in candidates:
        handler = getattr(codec, attr, None)
        if handler is not None:
            found.append(handler)
    return found
290
291
class _ProtoJsonApiTools(protojson.ProtoJson):

    """ProtoJson codec extended with this module's codecs and remappings."""
    _INSTANCE = None

    @classmethod
    def Get(cls):
        """Return the shared instance, creating it on first use."""
        instance = cls._INSTANCE
        if instance is None:
            instance = cls._INSTANCE = cls()
        return instance

    def decode_message(self, message_type, encoded_message):
        """Decode encoded_message (JSON text) into a message_type instance.

        A registered custom message codec takes over completely; otherwise
        custom field names are mapped back, the base decoder runs, and unknown
        enums, unknown messages and unrecognized fields are post-processed.
        """
        custom_codec = _CUSTOM_MESSAGE_CODECS.get(message_type)
        if custom_codec is not None:
            return custom_codec.decoder(encoded_message)
        renamed = _DecodeCustomFieldNames(message_type, encoded_message)
        decoded = super(_ProtoJsonApiTools, self).decode_message(
            message_type, renamed)
        decoded = _ProcessUnknownEnums(decoded, encoded_message)
        decoded = _ProcessUnknownMessages(decoded, encoded_message)
        return _DecodeUnknownFields(decoded, encoded_message)

    def decode_field(self, field, value):
        """Decode the given JSON value.

        Args:
          field: a messages.Field for the field we're decoding.
          value: a python value we'd like to decode.

        Returns:
          A value suitable for assignment to field.
        """
        # Registered codecs run first; one that reports "complete" wins.
        for decode in _GetFieldCodecs(field, 'decoder'):
            outcome = decode(field, value)
            value = outcome.value
            if outcome.complete:
                return value
        if isinstance(field, messages.MessageField):
            return self.decode_message(field.message_type, json.dumps(value))
        if not isinstance(field, messages.EnumField):
            return super(_ProtoJsonApiTools, self).decode_field(field, value)
        # Enum: translate a custom JSON name back to its python name, if any.
        value = GetCustomJsonEnumMapping(field.type, json_name=value) or value
        try:
            return super(_ProtoJsonApiTools, self).decode_field(field, value)
        except messages.DecodeError:
            if not isinstance(value, six.string_types):
                raise
            # A string that fails to decode as an enum value yields None.
            return None

    def encode_message(self, message):
        """Encode message (or a FieldList of messages) as JSON text."""
        if isinstance(message, messages.FieldList):
            encoded_items = [self.encode_message(item) for item in message]
            return '[%s]' % ', '.join(encoded_items)

        # pylint: disable=unidiomatic-typecheck
        custom_codec = _CUSTOM_MESSAGE_CODECS.get(type(message))
        if custom_codec is not None:
            return custom_codec.encoder(message)

        remapped = _EncodeUnknownFields(message)
        encoded = super(_ProtoJsonApiTools, self).encode_message(remapped)
        encoded = _EncodeCustomFieldNames(remapped, encoded)
        # Re-dump with sorted keys so the output is deterministic.
        return json.dumps(json.loads(encoded), sort_keys=True)

    def encode_field(self, field, value):
        """Encode the given value as JSON.

        Args:
          field: a messages.Field for the field we're encoding.
          value: a value for field.

        Returns:
          A python value suitable for json.dumps.
        """
        # Registered codecs run first; one that reports "complete" wins.
        for encode in _GetFieldCodecs(field, 'encoder'):
            outcome = encode(field, value)
            value = outcome.value
            if outcome.complete:
                return value
        if isinstance(field, messages.EnumField):
            if field.repeated:
                remapped_value = [
                    GetCustomJsonEnumMapping(
                        field.type, python_name=item.name) or item.name
                    for item in value]
            else:
                remapped_value = GetCustomJsonEnumMapping(
                    field.type, python_name=value.name)
            if remapped_value:
                return remapped_value
        # DateTimeField is a MessageField too, but is left to the base class.
        is_nested_message = (
            isinstance(field, messages.MessageField) and
            not isinstance(field, message_types.DateTimeField))
        if is_nested_message:
            value = json.loads(self.encode_message(value))
        return super(_ProtoJsonApiTools, self).encode_field(field, value)
389
390
# TODO(craigcitro): Fold this and _IncludeFields in as codecs.
def _DecodeUnknownFields(message, encoded_message):
    """Move unknown fields of message into its registered container field.

    Args:
      message: the decoded message.
      encoded_message: the JSON text message was decoded from.

    Returns:
      message; when its type has a container registered through
      MapUnrecognizedFields, that field is filled with the unknown values and
      the message's unrecognized-field store is reset.

    Raises:
      exceptions.InvalidDataFromServerError: if the container field is not a
          message field.
    """
    destination = _UNRECOGNIZED_FIELD_MAPPINGS.get(type(message))
    if destination is None:
        return message
    pair_field = message.field_by_name(destination)
    if not isinstance(pair_field, messages.MessageField):
        raise exceptions.InvalidDataFromServerError(
            'Unrecognized fields must be mapped to a compound '
            'message type.')
    pair_type = pair_field.message_type
    # TODO(craigcitro): Add more error checking around the pair
    # type being exactly what we suspect (field names, etc).
    if not isinstance(pair_type.value, messages.MessageField):
        new_values = _DecodeUnrecognizedFields(message, pair_type)
    else:
        new_values = _DecodeUnknownMessages(
            message, json.loads(encoded_message), pair_type)
    setattr(message, destination, new_values)
    # Not strictly required, but clear the now-redundant raw store.
    setattr(message, '_Message__unrecognized_fields', {})
    return message
415
416
def _DecodeUnknownMessages(message, encoded_message, pair_type):
    """Build key/value pairs for message-typed unknown fields.

    Args:
      message: the decoded message; its declared field names are skipped.
      encoded_message: dict parsed from the JSON being decoded.
      pair_type: the key/value message type used to hold each unknown entry.

    Returns:
      A list of pair_type instances, one per key of encoded_message that is
      not a declared field of message.
    """
    value_field = pair_type.value
    field_type = value_field.type
    known_names = frozenset(field.name for field in message.all_fields())
    pairs = []
    for name, raw_value in six.iteritems(encoded_message):
        if name in known_names:
            continue
        decoded = PyValueToMessage(field_type, raw_value)
        if value_field.repeated:
            decoded = _AsMessageList(decoded)
        pairs.append(pair_type(key=name, value=decoded))
    return pairs
431
432
def _DecodeUnrecognizedFields(message, pair_type):
    """Build key/value pairs from the unrecognized fields stored on message.

    Args:
      message: the decoded message holding unrecognized fields.
      pair_type: the key/value message type used to hold each entry.

    Returns:
      A list of pair_type instances, one per unrecognized field.
    """
    codec = _ProtoJsonApiTools.Get()
    pairs = []
    for field_key in message.all_unrecognized_fields():
        # TODO(craigcitro): Consider validating the variant if
        # the assignment below doesn't take care of it. It may
        # also be necessary to check it in the case that the
        # type has multiple encodings.
        raw_value, _ = message.get_unrecognized_field_info(field_key)
        if isinstance(pair_type.field_by_name('value'),
                      messages.MessageField):
            decoded = DictToMessage(raw_value, pair_type.value.message_type)
        else:
            decoded = codec.decode_field(pair_type.value, raw_value)
        try:
            pair_key = str(field_key)
        except UnicodeEncodeError:
            # The key cannot be converted with str(); let the plain protojson
            # codec decode it for the pair's key field instead.
            pair_key = protojson.ProtoJson().decode_field(
                pair_type.key, field_key)
        pairs.append(pair_type(key=pair_key, value=decoded))
    return pairs
457
458
def _CopyProtoMessageVanillaProtoJson(message):
    """Copy message with the plain ProtoJson codec (no apitools hooks)."""
    vanilla = protojson.ProtoJson()
    encoded = vanilla.encode_message(message)
    return vanilla.decode_message(type(message), encoded)
462
463
def _EncodeUnknownFields(message):
    """Turn the registered container field back into unrecognized fields.

    Args:
      message: the message about to be encoded.

    Returns:
      message itself when its type has no container registered; otherwise a
      copy whose container entries are stored as unrecognized fields and whose
      container field is emptied.

    Raises:
      exceptions.InvalidUserInputError: if the container field is not a
          message field.
    """
    source = _UNRECOGNIZED_FIELD_MAPPINGS.get(type(message))
    if source is None:
        return message
    # CopyProtoMessage goes through _ProtoJsonApiTools, which calls back into
    # this function; the vanilla protojson copy avoids infinite recursion.
    remapped = _CopyProtoMessageVanillaProtoJson(message)
    pairs_field = message.field_by_name(source)
    if not isinstance(pairs_field, messages.MessageField):
        raise exceptions.InvalidUserInputError(
            'Invalid pairs field %s' % pairs_field)
    value_field = pairs_field.message_type.field_by_name('value')
    variant = value_field.variant
    pairs = getattr(message, source)
    codec = _ProtoJsonApiTools.Get()
    for pair in pairs:
        encoded = codec.encode_field(value_field, pair.value)
        remapped.set_unrecognized_field(pair.key, encoded, variant)
    setattr(remapped, source, [])
    return remapped
486
487
def _SafeEncodeBytes(field, value):
    """Encode value as urlsafe base64, item by item for repeated fields.

    Returns:
      CodecResult with the encoded value and complete=True, or the original
      value and complete=False when encoding raises TypeError.
    """
    try:
        if not field.repeated:
            encoded = base64.urlsafe_b64encode(value)
        else:
            encoded = [base64.urlsafe_b64encode(item) for item in value]
    except TypeError:
        return CodecResult(value=value, complete=False)
    return CodecResult(value=encoded, complete=True)
500
501
def _SafeDecodeBytes(unused_field, value):
    """Decode the urlsafe base64 value into bytes.

    Args:
      unused_field: the field being decoded (not consulted).
      value: the urlsafe base64 text to decode.

    Returns:
      CodecResult with the decoded bytes and complete=True, or the original
      value and complete=False when it cannot be decoded, so that later
      decoding steps still get a chance to handle it.
    """
    try:
        result = base64.urlsafe_b64decode(str(value))
        complete = True
    except (TypeError, ValueError):
        # Python 2 reports malformed base64 as TypeError, while Python 3
        # raises binascii.Error, a ValueError subclass; both must trigger
        # the fallback.
        result = value
        complete = False
    return CodecResult(value=result, complete=complete)
511
512
513def _ProcessUnknownEnums(message, encoded_message):
514    """Add unknown enum values from encoded_message as unknown fields.
515
516    ProtoRPC diverges from the usual protocol buffer behavior here and
517    doesn't allow unknown fields. Throwing on unknown fields makes it
518    impossible to let servers add new enum values and stay compatible
519    with older clients, which isn't reasonable for us. We simply store
520    unrecognized enum values as unknown fields, and all is well.
521
522    Args:
523      message: Proto message we've decoded thus far.
524      encoded_message: JSON string we're decoding.
525
526    Returns:
527      message, with any unknown enums stored as unrecognized fields.
528    """
529    if not encoded_message:
530        return message
531    decoded_message = json.loads(six.ensure_str(encoded_message))
532    for field in message.all_fields():
533        if (isinstance(field, messages.EnumField) and
534                field.name in decoded_message and
535                message.get_assigned_value(field.name) is None):
536            message.set_unrecognized_field(
537                field.name, decoded_message[field.name], messages.Variant.ENUM)
538    return message
539
540
541def _ProcessUnknownMessages(message, encoded_message):
542    """Store any remaining unknown fields as strings.
543
544    ProtoRPC currently ignores unknown values for which no type can be
545    determined (and logs a "No variant found" message). For the purposes
546    of reserializing, this is quite harmful (since it throws away
547    information). Here we simply add those as unknown fields of type
548    string (so that they can easily be reserialized).
549
550    Args:
551      message: Proto message we've decoded thus far.
552      encoded_message: JSON string we're decoding.
553
554    Returns:
555      message, with any remaining unrecognized fields saved.
556    """
557    if not encoded_message:
558        return message
559    decoded_message = json.loads(six.ensure_str(encoded_message))
560    message_fields = [x.name for x in message.all_fields()] + list(
561        message.all_unrecognized_fields())
562    missing_fields = [x for x in decoded_message.keys()
563                      if x not in message_fields]
564    for field_name in missing_fields:
565        message.set_unrecognized_field(field_name, decoded_message[field_name],
566                                       messages.Variant.STRING)
567    return message
568
569
# Route every BytesField through the urlsafe-base64 codecs defined above.
RegisterFieldTypeCodec(_SafeEncodeBytes, _SafeDecodeBytes)(messages.BytesField)
571
572
# Note that these could share a dictionary, since they're keyed by
# distinct types, but it's not really worth it.
# Enum type -> {python name: JSON name} (see AddCustomJsonEnumMapping).
_JSON_ENUM_MAPPINGS = {}
# Message type -> {python name: JSON name} (see AddCustomJsonFieldMapping).
_JSON_FIELD_MAPPINGS = {}
577
578
def AddCustomJsonEnumMapping(enum_type, python_name, json_name,
                             package=None):  # pylint: disable=unused-argument
    """Register the wire (JSON) name to use for one enum value.

    This is primarily used in generated code, to handle enum values
    which happen to be Python keywords.

    Args:
      enum_type: (messages.Enum) An enum type
      python_name: (basestring) Python name for this value.
      json_name: (basestring) JSON name to be used on the wire.
      package: (NoneType, optional) No effect, exists for legacy compatibility.

    Raises:
      exceptions.TypecheckError: if enum_type is not a messages.Enum subclass.
      exceptions.InvalidDataError: if python_name is not a value of enum_type,
          or if either name is already mapped to something else.
    """
    is_enum = issubclass(enum_type, messages.Enum)
    if not is_enum:
        raise exceptions.TypecheckError(
            'Cannot set JSON enum mapping for non-enum "%s"' % enum_type)
    if python_name not in enum_type.names():
        raise exceptions.InvalidDataError(
            'Enum value %s not a value for type %s' % (python_name, enum_type))
    enum_mappings = _JSON_ENUM_MAPPINGS.setdefault(enum_type, {})
    _CheckForExistingMappings('enum', enum_type, python_name, json_name)
    enum_mappings[python_name] = json_name
601
602
def AddCustomJsonFieldMapping(message_type, python_name, json_name,
                              package=None):  # pylint: disable=unused-argument
    """Register the wire (JSON) name to use for one message field.

    This is primarily used in generated code, to handle field names
    which happen to be Python keywords.

    Args:
      message_type: (messages.Message) A message type
      python_name: (basestring) Python name for this field.
      json_name: (basestring) JSON name to be used on the wire.
      package: (NoneType, optional) No effect, exists for legacy compatibility.

    Raises:
      exceptions.TypecheckError: if message_type is not a messages.Message
          subclass.
      exceptions.InvalidDataError: if python_name is not a field of
          message_type, or if either name is already mapped to something else.
    """
    is_message = issubclass(message_type, messages.Message)
    if not is_message:
        raise exceptions.TypecheckError(
            'Cannot set JSON field mapping for '
            'non-message "%s"' % message_type)
    try:
        # Only the lookup's success matters; the field itself is not used.
        message_type.field_by_name(python_name)
    except KeyError:
        raise exceptions.InvalidDataError(
            'Field %s not recognized for type %s' % (
                python_name, message_type))
    message_mappings = _JSON_FIELD_MAPPINGS.setdefault(message_type, {})
    _CheckForExistingMappings('field', message_type, python_name, json_name)
    message_mappings[python_name] = json_name
629
630
def GetCustomJsonEnumMapping(enum_type, python_name=None, json_name=None):
    """Look up the custom enum name mapping in either direction, or None.

    Pass python_name to get the JSON name, or json_name to get the python
    name; None is returned when no mapping is registered.
    """
    return _FetchRemapping(
        enum_type, 'enum', mappings=_JSON_ENUM_MAPPINGS,
        python_name=python_name, json_name=json_name)
636
637
def GetCustomJsonFieldMapping(message_type, python_name=None, json_name=None):
    """Look up the custom field name mapping in either direction, or None.

    Pass python_name to get the JSON name, or json_name to get the python
    name; None is returned when no mapping is registered.
    """
    return _FetchRemapping(
        message_type, 'field', mappings=_JSON_FIELD_MAPPINGS,
        python_name=python_name, json_name=json_name)
643
644
645def _FetchRemapping(type_name, mapping_type, python_name=None, json_name=None,
646                    mappings=None):
647    """Common code for fetching a key or value from a remapping dict."""
648    if python_name and json_name:
649        raise exceptions.InvalidDataError(
650            'Cannot specify both python_name and json_name '
651            'for %s remapping' % mapping_type)
652    if not (python_name or json_name):
653        raise exceptions.InvalidDataError(
654            'Must specify either python_name or json_name for %s remapping' % (
655                mapping_type,))
656    field_remappings = mappings.get(type_name, {})
657    if field_remappings:
658        if python_name:
659            return field_remappings.get(python_name)
660        elif json_name:
661            if json_name in list(field_remappings.values()):
662                return [k for k in field_remappings
663                        if field_remappings[k] == json_name][0]
664    return None
665
666
def _CheckForExistingMappings(mapping_type, message_type,
                              python_name, json_name):
    """Reject a new mapping that conflicts with an already registered one.

    Args:
      mapping_type: str, 'field' or 'enum'; selects the registry to consult.
      message_type: the message or enum type the mapping belongs to.
      python_name: python-side name of the proposed mapping.
      json_name: JSON-side name of the proposed mapping.

    Raises:
      exceptions.InvalidDataError: if python_name is already mapped to a
          different JSON name, or json_name to a different python name.
    """
    if mapping_type == 'field':
        getter = GetCustomJsonFieldMapping
    elif mapping_type == 'enum':
        getter = GetCustomJsonEnumMapping
    conflict_template = (
        'Cannot add mapping for %s "%s", already mapped to "%s"')
    existing_json = getter(message_type, python_name=python_name)
    if not (existing_json is None or existing_json == json_name):
        raise exceptions.InvalidDataError(
            conflict_template % (mapping_type, python_name, existing_json))
    existing_python = getter(message_type, json_name=json_name)
    if not (existing_python is None or existing_python == python_name):
        raise exceptions.InvalidDataError(
            conflict_template % (mapping_type, json_name, existing_python))
684
685
def _EncodeCustomFieldNames(message, encoded_value):
    """Rename python field names to their custom JSON names.

    Args:
      message: the message that was encoded; its type selects the remappings.
      encoded_value: JSON text produced for message.

    Returns:
      encoded_value unchanged when the message type has no remappings;
      otherwise JSON text in which every remapped top-level key that is
      present carries its JSON name.
    """
    field_remappings = _JSON_FIELD_MAPPINGS.get(type(message), {})
    if not field_remappings:
        return encoded_value
    decoded_value = json.loads(encoded_value)
    for python_name, json_name in list(field_remappings.items()):
        # Test the decoded keys, not the raw JSON text: a substring match on
        # the text also fires when the name merely occurs inside a value or
        # another key, and the pop() below would then raise KeyError.
        if python_name in decoded_value:
            decoded_value[json_name] = decoded_value.pop(python_name)
    return json.dumps(decoded_value)
696
697
def _DecodeCustomFieldNames(message_type, encoded_message):
    """Rename custom JSON field names back to their python names.

    Returns encoded_message unchanged when message_type has no remappings;
    otherwise the JSON text re-serialized with the top-level keys renamed.
    """
    field_remappings = _JSON_FIELD_MAPPINGS.get(message_type, {})
    if not field_remappings:
        return encoded_message
    decoded_message = json.loads(encoded_message)
    for python_name, json_name in list(field_remappings.items()):
        if json_name in decoded_message:
            decoded_message[python_name] = decoded_message.pop(json_name)
    return json.dumps(decoded_message)
707
708
def _AsMessageList(msg):
    """Unwrap a JsonValue/JsonArray holding a list into its entries.

    Raises:
      ValueError: if msg is neither a JsonArray nor a JsonValue with a
          non-empty array_value.
    """
    # This really needs to live in extra_types, but extra_types needs
    # to import this file to be able to register codecs.
    # TODO(craigcitro): Split out a codecs module and fix this ugly
    # import.
    from apitools.base.py import extra_types

    is_repeated = isinstance(msg, extra_types.JsonArray) or bool(
        isinstance(msg, extra_types.JsonValue) and msg.array_value)
    if not is_repeated:
        raise ValueError('invalid argument to _AsMessageList')
    # Peel the wrappers in order: JsonValue -> JsonArray -> entries.
    if isinstance(msg, extra_types.JsonValue):
        msg = msg.array_value
    if isinstance(msg, extra_types.JsonArray):
        msg = msg.entries
    return msg
732
733
def _IsMap(message, field):
    """Tell whether the value assigned to "field" is a map-style message.

    A value counts as a map when it is a message that declares a repeated
    'additionalProperties' field.
    """
    candidate = message.get_assigned_value(field.name)
    if not isinstance(candidate, messages.Message):
        return False
    try:
        properties_field = candidate.field_by_name('additionalProperties')
    except KeyError:
        return False
    return properties_field.repeated
745
746
def _MapItems(message, field):
    """Yield a (key, value) tuple for each entry of the map-typed "field"."""
    assert _IsMap(message, field)
    container = message.get_assigned_value(field.name)
    for pair in container.get_assigned_value('additionalProperties'):
        yield pair.key, pair.value
755
756
def UnrecognizedFieldIter(message, _edges=()):  # pylint: disable=invalid-name
    """Yields the locations of unrecognized fields within "message".

    The search is pruned at any sub-message that itself has unrecognized
    fields: such a message is assumed to be malformed, so looking further
    inside it is not expected to produce useful errors.

    Args:
      message: The Message instance to search.
      _edges: Internal arg for passing state.

    Yields:
      (edges_to_message, field_names):
        edges_to_message: List[ProtoEdge], The edges (relative to "message")
            describing the path to the sub-message where the unrecognized
            fields were found.
        field_names: List[Str], The names of the field(s) that were
            unrecognized in the sub-message.
    """
    if not isinstance(message, messages.Message):
        # Primitive leaf: nothing can be unrecognized below this point.
        return

    field_names = message.all_unrecognized_fields()
    if field_names:
        # Malformed message: report it and do not descend any further.
        yield _edges, field_names
        return

    def _Children(field):
        """Yield an (edge, value) tuple per value reachable via field."""
        value = message.get_assigned_value(field.name)
        if field.repeated:
            for position, element in enumerate(value):
                yield (ProtoEdge(EdgeType.REPEATED, field.name, position),
                       element)
        elif _IsMap(message, field):
            for key, element in _MapItems(message, field):
                yield ProtoEdge(EdgeType.MAP, field.name, key), element
        else:
            yield ProtoEdge(EdgeType.SCALAR, field.name, None), value

    # Walk every field and recurse into each value it leads to.
    for field in message.all_fields():
        for edge, child in _Children(field):
            for found in UnrecognizedFieldIter(child, _edges + (edge,)):
                yield found
807