1# (c) 2005 Ian Bicking and contributors; written for Paste
2# (http://pythonpaste.org) Licensed under the MIT license:
3# http://www.opensource.org/licenses/mit-license.php
4"""
5Gives a multi-value dictionary object (MultiDict) plus several wrappers
6"""
import binascii
import warnings

try:
    from collections.abc import MutableMapping
except ImportError:  # Python 2: the ABCs live directly in ``collections``
    from collections import MutableMapping

from webob.compat import (
    PY3,
    iteritems_,
    itervalues_,
    url_encode,
    )
18
# Names exported by a star-import of this module.
__all__ = ['MultiDict', 'NestedMultiDict', 'NoVars', 'GetDict']
20
class MultiDict(MutableMapping):
    """
    An ordered dictionary that can have multiple values for each key.

    Adds the methods getall, getone, mixed, extend and add to the normal
    dictionary interface.  The contents are kept in ``self._items``, a list
    of ``(key, value)`` tuples in insertion order, so lookups are linear
    scans over that list.
    """

    def __init__(self, *args, **kw):
        """
        Build the dict from at most one positional source plus keywords.

        The positional argument may be an object with ``iteritems()`` or
        ``items()``, or any iterable of ``(key, value)`` pairs.  Keyword
        arguments are appended after the positional data.

        Raises TypeError when more than one positional argument is given.
        """
        if len(args) > 1:
            raise TypeError("MultiDict can only be called with one positional "
                            "argument")
        if args:
            source = args[0]
            if hasattr(source, 'iteritems'):
                items = list(source.iteritems())
            elif hasattr(source, 'items'):
                items = list(source.items())
            else:
                items = list(source)
            self._items = items
        else:
            self._items = []
        if kw:
            self._items.extend(kw.items())

    @classmethod
    def view_list(cls, lst):
        """
        Create a dict that is a view on the given list

        The list object itself is used as the storage (it is not copied),
        so changes made through the dict show up in ``lst`` and vice versa.

        Raises TypeError when ``lst`` is not an actual ``list`` instance.
        """
        if not isinstance(lst, list):
            raise TypeError(
                "%s.view_list(obj) takes only actual list objects, not %r"
                % (cls.__name__, lst))
        obj = cls()
        obj._items = lst
        return obj

    @classmethod
    def from_fieldstorage(cls, fs):
        """
        Create a dict from a cgi.FieldStorage instance

        Fields that carry a filename are stored as the field object itself
        (with its filename decoded).  Every other field is stored as its
        value, with a base64 / quoted-printable transfer encoding undone
        first and the result decoded using the field's charset (``utf8``
        when the field does not declare one).
        """
        obj = cls()
        # The decoder table does not depend on the field, so build it once.
        supported_transfer_encoding = {
            'base64': binascii.a2b_base64,
            'quoted-printable': binascii.a2b_qp,
        }
        # fs.list can be None when there's nothing to parse
        for field in fs.list or ():
            charset = field.type_options.get('charset', 'utf8')
            transfer_encoding = field.headers.get(
                'Content-Transfer-Encoding', None)
            if PY3: # pragma: no cover
                if charset == 'utf8':
                    decode = lambda b: b
                else:
                    decode = lambda b: b.encode('utf8').decode(charset)
            else:
                decode = lambda b: b.decode(charset)
            if field.filename:
                field.filename = decode(field.filename)
                obj.add(field.name, field)
            else:
                value = field.value
                if transfer_encoding in supported_transfer_encoding:
                    if PY3: # pragma: no cover
                        # binascii accepts bytes
                        value = value.encode('utf8')
                    value = supported_transfer_encoding[transfer_encoding](value)
                    if PY3: # pragma: no cover
                        # binascii returns bytes
                        value = value.decode('utf8')
                obj.add(field.name, decode(value))
        return obj

    def __getitem__(self, key):
        """
        Return the most recently added value for ``key``.

        Raises KeyError when the key is not present.
        """
        for k, v in reversed(self._items):
            if k == key:
                return v
        raise KeyError(key)

    def __setitem__(self, key, value):
        """
        Replace every existing value of ``key`` with the single ``value``.
        """
        try:
            del self[key]
        except KeyError:
            pass
        self._items.append((key, value))

    def add(self, key, value):
        """
        Add the key and value, not overwriting any previous value.
        """
        self._items.append((key, value))

    def getall(self, key):
        """
        Return a list of all values matching the key (may be an empty list)
        """
        return [v for k, v in self._items if k == key]

    def getone(self, key):
        """
        Get one value matching the key, raising a KeyError if multiple
        values were found (or if none was found).
        """
        v = self.getall(key)
        if not v:
            raise KeyError('Key not found: %r' % key)
        if len(v) > 1:
            raise KeyError('Multiple values match %r: %r' % (key, v))
        return v[0]

    def mixed(self):
        """
        Returns a dictionary where the values are either single
        values, or a list of values when a key/value appears more than
        once in this dictionary.  This is similar to the kind of
        dictionary often used to represent the variables in a web
        request.
        """
        result = {}
        # Keys whose entry in ``result`` is a list built here.  Tracked
        # separately so we do not clobber any lists that are *actual*
        # values in this dictionary.
        multi = set()
        for key, value in self.items():
            if key not in result:
                result[key] = value
            elif key in multi:
                result[key].append(value)
            else:
                result[key] = [result[key], value]
                multi.add(key)
        return result

    def dict_of_lists(self):
        """
        Returns a dictionary where each key is associated with a list of values.
        """
        r = {}
        for key, val in self.items():
            r.setdefault(key, []).append(val)
        return r

    def __delitem__(self, key):
        """
        Remove every value stored under ``key``.

        Raises KeyError when the key is not present.
        """
        items = self._items
        # ``==`` is used (rather than ``!=``) so the match rule is the same
        # one the lookup methods apply.
        remaining = [item for item in items if not (item[0] == key)]
        if len(remaining) == len(items):
            raise KeyError(key)
        # Assign through a slice so the list object is mutated in place;
        # a list handed to view_list() must keep seeing the changes.
        items[:] = remaining

    def __contains__(self, key):
        """
        Return True when at least one value is stored under ``key``.
        """
        for k, v in self._items:
            if k == key:
                return True
        return False

    has_key = __contains__

    def clear(self):
        """
        Remove all items, emptying the underlying list in place.
        """
        del self._items[:]

    def copy(self):
        """
        Return a new instance of the same class holding the same items.
        """
        return self.__class__(self)

    def setdefault(self, key, default=None):
        """
        Return the first value stored under ``key``; when the key is
        missing, add ``(key, default)`` and return ``default``.
        """
        for k, v in self._items:
            if key == k:
                return v
        self._items.append((key, default))
        return default

    def pop(self, key, *args):
        """
        Remove and return the first value stored under ``key``.

        Only that one occurrence is removed.  An optional second argument
        is returned when the key is missing; without it KeyError is raised.
        Raises TypeError when more than one extra argument is given.
        """
        if len(args) > 1:
            raise TypeError("pop expected at most 2 arguments, got %s"
                             % repr(1 + len(args)))
        for i, (k, v) in enumerate(self._items):
            if k == key:
                # Safe to delete while iterating: we return immediately.
                del self._items[i]
                return v
        if args:
            return args[0]
        else:
            raise KeyError(key)

    def popitem(self):
        """
        Remove and return the last ``(key, value)`` pair.

        Raises IndexError (from ``list.pop``) when the dict is empty.
        """
        return self._items.pop()

    def update(self, *args, **kw):
        """
        Update the dict the way ``dict.update`` does: each key that is set
        replaces every value previously stored under it.

        A UserWarning is emitted when the positional argument itself
        contains duplicate keys, because those duplicates collapse to a
        single value; use extend() to keep them all.
        """
        if args:
            lst = args[0]
            if not (hasattr(lst, 'keys') or hasattr(lst, '__len__')):
                # Generators and other one-shot iterables have no len() and
                # would be exhausted by the dict() call below, so take a
                # snapshot that can be measured and then replayed.
                lst = list(lst)
                args = (lst,) + args[1:]
            if len(lst) != len(dict(lst)):
                # this does not catch the cases where we overwrite existing
                # keys, but those would produce too many warning
                msg = ("Behavior of MultiDict.update() has changed "
                    "and overwrites duplicate keys. Consider using .extend()"
                )
                warnings.warn(msg, UserWarning, stacklevel=2)
        MutableMapping.update(self, *args, **kw)

    def extend(self, other=None, **kwargs):
        """
        Append all items of ``other`` without overwriting existing values.

        ``other`` may be an object with ``items()``, an object with
        ``keys()`` and item access, or an iterable of ``(key, value)``
        pairs.  Keyword arguments are applied through update(), so they
        replace existing values of the same key instead of adding to them.
        """
        if other is None:
            pass
        elif hasattr(other, 'items'):
            self._items.extend(other.items())
        elif hasattr(other, 'keys'):
            for k in other.keys():
                self._items.append((k, other[k]))
        else:
            for k, v in other:
                self._items.append((k, v))
        if kwargs:
            self.update(kwargs)

    def __repr__(self):
        """
        Return ``ClassName([(key, value), ...])`` with the values of
        password-like keys masked by ``_hide_passwd``.
        """
        items = map('(%r, %r)'.__mod__, _hide_passwd(self.items()))
        return '%s([%s])' % (self.__class__.__name__, ', '.join(items))

    def __len__(self):
        """
        Return the number of stored pairs (duplicate keys count each time).
        """
        return len(self._items)

    ##
    ## All the iteration:
    ##
    ## On Python 3 keys/items/values are the lazy iterators; on Python 2
    ## they return lists and the iter* variants are the lazy ones.
    ##

    def iterkeys(self):
        """Yield every key, once per stored pair."""
        for k, v in self._items:
            yield k
    if PY3: # pragma: no cover
        keys = iterkeys
    else:
        def keys(self):
            """Return a list of every key, once per stored pair."""
            return [k for k, v in self._items]

    __iter__ = iterkeys

    def iteritems(self):
        """Return an iterator over the stored ``(key, value)`` pairs."""
        return iter(self._items)

    if PY3: # pragma: no cover
        items = iteritems
    else:
        def items(self):
            """Return a copy of the list of ``(key, value)`` pairs."""
            return self._items[:]

    def itervalues(self):
        """Yield every stored value."""
        for k, v in self._items:
            yield v

    if PY3: # pragma: no cover
        values = itervalues
    else:
        def values(self):
            """Return a list of every stored value."""
            return [v for k, v in self._items]
278
# Unique sentinel passed as the ``get()`` default in
# NestedMultiDict.__getitem__, so a missing key can be told apart from any
# real stored value (including None).
_dummy = object()
280
class GetDict(MultiDict):
    """
    A MultiDict bound to an environ-like mapping.

    After every mutating operation the current items are re-encoded into a
    query string, which is written back to ``env['QUERY_STRING']`` and
    cached, together with this object, in
    ``env['webob._parsed_query_vars']``.
    """

    def __init__(self, data, env):
        self.env = env
        MultiDict.__init__(self, data)

    def on_change(self):
        """
        Re-serialize the items and store the result in ``self.env``.
        """
        pairs = []
        for name, val in self.items():
            pairs.append((name.encode('utf8'), val.encode('utf8')))
        query = url_encode(pairs)
        self.env['QUERY_STRING'] = query
        self.env['webob._parsed_query_vars'] = (self, query)

    # Each mutator below delegates to MultiDict and then syncs the environ.

    def __setitem__(self, key, value):
        MultiDict.__setitem__(self, key, value)
        self.on_change()

    def add(self, key, value):
        MultiDict.add(self, key, value)
        self.on_change()

    def __delitem__(self, key):
        MultiDict.__delitem__(self, key)
        self.on_change()

    def clear(self):
        MultiDict.clear(self)
        self.on_change()

    def setdefault(self, key, default=None):
        found = MultiDict.setdefault(self, key, default)
        self.on_change()
        return found

    def pop(self, key, *args):
        popped = MultiDict.pop(self, key, *args)
        self.on_change()
        return popped

    def popitem(self):
        pair = MultiDict.popitem(self)
        self.on_change()
        return pair

    def update(self, *args, **kwargs):
        MultiDict.update(self, *args, **kwargs)
        self.on_change()

    def extend(self, *args, **kwargs):
        MultiDict.extend(self, *args, **kwargs)
        self.on_change()

    def __repr__(self):
        """
        Return ``GET([...])`` with password-like values masked.
        """
        shown = ', '.join(
            map('(%r, %r)'.__mod__, _hide_passwd(self.items())))
        # TODO: GET -> GetDict
        return 'GET([%s])' % (shown,)

    def copy(self):
        # Copies shouldn't be tracked
        return MultiDict(self)
331
class NestedMultiDict(MultiDict):
    """
    Wraps several MultiDict objects, treating it as one large MultiDict

    The wrapped dicts are kept in ``self.dicts`` and consulted in the order
    given.  The wrapper has no storage of its own and is read-only: every
    mutating method raises KeyError.
    """

    def __init__(self, *dicts):
        # MultiDict.__init__ is deliberately not called: there is no
        # ``_items`` list here, all reads go through ``self.dicts``.
        self.dicts = dicts

    def __getitem__(self, key):
        """
        Return the value from the first wrapped dict that has ``key``.

        Raises KeyError when no wrapped dict contains the key.
        """
        for d in self.dicts:
            # The sentinel lets a stored None count as a real value.
            value = d.get(key, _dummy)
            if value is not _dummy:
                return value
        raise KeyError(key)

    def _readonly(self, *args, **kw):
        """Shared implementation of every mutator: always raises KeyError."""
        raise KeyError("NestedMultiDict objects are read-only")
    __setitem__ = _readonly
    add = _readonly
    __delitem__ = _readonly
    clear = _readonly
    setdefault = _readonly
    pop = _readonly
    popitem = _readonly
    update = _readonly
    # The inherited extend() would reach for ``self._items``, which does
    # not exist on this class; reject it like the other mutators.
    extend = _readonly

    def getall(self, key):
        """
        Return the values for ``key`` from all wrapped dicts, in order.
        """
        result = []
        for d in self.dicts:
            result.extend(d.getall(key))
        return result

    # Inherited:
    # getone
    # mixed
    # dict_of_lists

    def copy(self):
        """
        Return a plain (writable) MultiDict holding all the items.
        """
        return MultiDict(self)

    def __contains__(self, key):
        """
        Return True when any wrapped dict contains ``key``.
        """
        for d in self.dicts:
            if key in d:
                return True
        return False

    has_key = __contains__

    def __len__(self):
        """
        Return the sum of the lengths of the wrapped dicts.
        """
        v = 0
        for d in self.dicts:
            v += len(d)
        return v

    def __nonzero__(self):
        """
        Return True when at least one wrapped dict is non-empty.
        """
        for d in self.dicts:
            if d:
                return True
        return False

    # Python 3 looks up __bool__ instead of __nonzero__.
    __bool__ = __nonzero__

    def iteritems(self):
        """Yield the ``(key, value)`` pairs of each wrapped dict in turn."""
        for d in self.dicts:
            for item in iteritems_(d):
                yield item
    if PY3: # pragma: no cover
        items = iteritems
    else:
        def items(self):
            return list(self.iteritems())

    def itervalues(self):
        """Yield the values of each wrapped dict in turn."""
        for d in self.dicts:
            for value in itervalues_(d):
                yield value
    if PY3: # pragma: no cover
        values = itervalues
    else:
        def values(self):
            return list(self.itervalues())

    def __iter__(self):
        """Yield the keys of each wrapped dict in turn."""
        for d in self.dicts:
            for key in d:
                yield key

    iterkeys = __iter__

    if PY3: # pragma: no cover
        keys = iterkeys
    else:
        def keys(self):
            return list(self.iterkeys())
424
class NoVars(object):
    """
    Represents no variables; used when no variables
    are applicable.

    This is read-only: it always looks empty, and every attempt to read a
    specific key or to modify it raises a KeyError that includes the
    ``reason`` given at construction.
    """

    def __init__(self, reason=None):
        # A falsy reason (None, '') is replaced by the 'N/A' placeholder.
        self.reason = reason if reason else 'N/A'

    def __getitem__(self, key):
        """Always raises KeyError: there are no keys."""
        message = "No key %r: %s" % (key, self.reason)
        raise KeyError(message)

    def __setitem__(self, *args, **kw):
        """Always raises KeyError: nothing can be added."""
        message = "Cannot add variables: %s" % self.reason
        raise KeyError(message)

    # Every "adding" method shares the same refusal.
    add = __setitem__
    setdefault = __setitem__
    update = __setitem__

    def __delitem__(self, *args, **kw):
        """Always raises KeyError: there is nothing to delete."""
        message = "No keys to delete: %s" % self.reason
        raise KeyError(message)

    # Every "removing" method shares the same refusal.
    clear = __delitem__
    pop = __delitem__
    popitem = __delitem__

    def get(self, key, default=None):
        """Return ``default``; no key is ever present."""
        return default

    def getall(self, key):
        """Return an empty list; no key has any value."""
        return []

    def getone(self, key):
        """Raise KeyError via item lookup; no key is ever present."""
        return self[key]

    def mixed(self):
        """Return an empty dict."""
        return {}
    dict_of_lists = mixed

    def __contains__(self, key):
        """Return False for every key."""
        return False
    has_key = __contains__

    def copy(self):
        """Return this same object (it holds no mutable data to copy)."""
        return self

    def __repr__(self):
        class_name = self.__class__.__name__
        return '<%s: %s>' % (class_name, self.reason)

    def __len__(self):
        return 0

    def __cmp__(self, other):
        # Python 2 only: compare as if this were an empty dict.
        return cmp({}, other)

    def iterkeys(self):
        """Return an iterator that yields nothing."""
        return iter([])

    if PY3: # pragma: no cover
        keys = iterkeys
        items = iterkeys
        values = iterkeys
    else:
        def keys(self):
            """Return an empty list."""
            return []
        items = keys
        values = keys
        itervalues = iterkeys
        iteritems = iterkeys

    __iter__ = iterkeys
498
499def _hide_passwd(items):
500    for k, v in items:
501        if ('password' in k
502            or 'passwd' in k
503            or 'pwd' in k
504        ):
505            yield k, '******'
506        else:
507            yield k, v
508