1# (c) 2005 Ian Bicking and contributors; written for Paste (http://pythonpaste.org)
2# Licensed under the MIT license: http://www.opensource.org/licenses/mit-license.php
import cgi
import copy
import sys

import six

try:
    # Python 3.3+: the collection ABCs live in collections.abc (the
    # deprecated aliases in ``collections`` were removed in Python 3.10).
    from collections.abc import MutableMapping as DictMixin
except ImportError:
    try:
        # Python 2.6+ / Python 3.0-3.2
        from collections import MutableMapping as DictMixin
    except ImportError:
        # Older Python 2
        from UserDict import DictMixin
14
class MultiDict(DictMixin):

    """
    An ordered dictionary that can have multiple values for each key.
    Adds the methods getall, getone, mixed, and add to the normal
    dictionary interface.

    Items are kept in a single list of ``(key, value)`` pairs in insertion
    order, so duplicate keys are preserved and every lookup is a linear
    scan.  Note that ``__getitem__`` and ``setdefault`` match keys by
    equality only, whereas ``getall``, ``__contains__``, ``__delitem__``
    and ``pop`` additionally require the stored key to have exactly the
    same type as the requested key.
    """

    def __init__(self, *args, **kw):
        """
        Create a MultiDict from at most one positional argument plus any
        keyword arguments.

        The positional argument may be an object with ``iteritems`` or
        ``items`` (e.g. a dict or another MultiDict), or an iterable of
        ``(key, value)`` pairs.  Keyword items are appended after it.

        Raises TypeError if more than one positional argument is given.
        """
        if len(args) > 1:
            raise TypeError(
                "MultiDict can only be called with one positional argument")
        if args:
            source = args[0]
            if hasattr(source, 'iteritems'):
                items = source.iteritems()
            elif hasattr(source, 'items'):
                items = source.items()
            else:
                items = source
            self._items = list(items)
        else:
            self._items = []
        # Equivalent to six.iteritems(kw): extend() accepts any iterable.
        self._items.extend(kw.items())

    def __getitem__(self, key):
        """
        Return the first value whose key compares equal to ``key``.

        Raises KeyError if no stored key compares equal.
        """
        for k, v in self._items:
            if k == key:
                return v
        raise KeyError(repr(key))

    def __setitem__(self, key, value):
        """
        Remove the existing items for ``key`` (as ``__delitem__`` does) and
        append ``(key, value)`` at the end.
        """
        try:
            del self[key]
        except KeyError:
            pass
        self._items.append((key, value))

    def add(self, key, value):
        """
        Add the key and value, not overwriting any previous value.
        """
        self._items.append((key, value))

    def getall(self, key):
        """
        Return a list of all values matching the key (may be an empty list)
        """
        return [v for k, v in self._items
                if type(key) == type(k) and key == k]

    def getone(self, key):
        """
        Get one value matching the key, raising a KeyError if multiple
        values were found, or if no value was found.
        """
        v = self.getall(key)
        if not v:
            # Wrap key in a 1-tuple: a bare tuple key would otherwise be
            # unpacked by ``%`` and raise TypeError instead of KeyError.
            raise KeyError('Key not found: %r' % (key,))
        if len(v) > 1:
            raise KeyError('Multiple values match %r: %r' % (key, v))
        return v[0]

    def mixed(self):
        """
        Returns a dictionary where the values are either single
        values, or a list of values when a key/value appears more than
        once in this dictionary.  This is similar to the kind of
        dictionary often used to represent the variables in a web
        request.
        """
        result = {}
        multi = {}
        for key, value in self._items:
            if key in result:
                # We do this to not clobber any lists that are
                # *actual* values in this dictionary:
                if key in multi:
                    result[key].append(value)
                else:
                    result[key] = [result[key], value]
                    multi[key] = None
            else:
                result[key] = value
        return result

    def dict_of_lists(self):
        """
        Returns a dictionary where each key is associated with a
        list of values.
        """
        result = {}
        for key, value in self._items:
            if key in result:
                result[key].append(value)
            else:
                result[key] = [value]
        return result

    def __delitem__(self, key):
        """
        Remove every item whose key has the same type as, and compares
        equal to, ``key``.

        Raises KeyError if no item was removed.
        """
        items = self._items
        kept = [item for item in items
                if not (type(item[0]) == type(key) and item[0] == key)]
        if len(kept) == len(items):
            raise KeyError(repr(key))
        # Filter in a single pass, then update the list in place rather
        # than rebinding ``self._items``.
        items[:] = kept

    def __contains__(self, key):
        """Return True if a stored key has the same type as and equals ``key``."""
        for k, v in self._items:
            if type(k) == type(key) and k == key:
                return True
        return False

    has_key = __contains__

    def clear(self):
        """Remove all items."""
        self._items = []

    def copy(self):
        """Return a new MultiDict holding the same (key, value) pairs."""
        return MultiDict(self)

    def setdefault(self, key, default=None):
        """
        Return the first value whose key equals ``key``; if there is none,
        append ``(key, default)`` and return ``default``.
        """
        for k, v in self._items:
            if key == k:
                return v
        self._items.append((key, default))
        return default

    def pop(self, key, *args):
        """
        Remove and return the first value whose key has the same type as,
        and equals, ``key``.  If there is none, return the optional default,
        or raise KeyError when no default was given.

        Raises TypeError if more than one default is passed.
        """
        if len(args) > 1:
            raise TypeError("pop expected at most 2 arguments, got "
                              + repr(1 + len(args)))
        for index, (k, v) in enumerate(self._items):
            if type(k) == type(key) and k == key:
                del self._items[index]
                return v
        if args:
            return args[0]
        raise KeyError(repr(key))

    def popitem(self):
        """
        Remove and return the last ``(key, value)`` pair.

        On an empty MultiDict this raises IndexError (from ``list.pop``),
        not the KeyError that ``dict.popitem`` raises.
        """
        return self._items.pop()

    def update(self, other=None, **kwargs):
        """
        Append items from ``other`` and ``kwargs``.  Existing values are
        kept: items are added, as with ``add``, not replaced.

        ``other`` may be an object with ``items``, an object with ``keys``
        and item access, or an iterable of ``(key, value)`` pairs.
        """
        if other is None:
            pass
        elif hasattr(other, 'items'):
            self._items.extend(other.items())
        elif hasattr(other, 'keys'):
            for k in other.keys():
                self._items.append((k, other[k]))
        else:
            for k, v in other:
                self._items.append((k, v))
        if kwargs:
            self.update(kwargs)

    def __repr__(self):
        items = ', '.join(['(%r, %r)' % v for v in self._items])
        return '%s([%s])' % (self.__class__.__name__, items)

    def __len__(self):
        """Return the number of (key, value) pairs, counting duplicates."""
        return len(self._items)

    ##
    ## All the iteration:
    ##

    def keys(self):
        return [k for k, v in self._items]

    def iterkeys(self):
        for k, v in self._items:
            yield k

    __iter__ = iterkeys

    def items(self):
        # Return a copy so callers cannot mutate the internal list.
        return self._items[:]

    def iteritems(self):
        return iter(self._items)

    def values(self):
        return [v for k, v in self._items]

    def itervalues(self):
        for k, v in self._items:
            yield v
210
class UnicodeMultiDict(DictMixin):
    """
    A MultiDict wrapper that decodes returned values to unicode on the
    fly. Decoding is not applied to assigned values.

    The key/value contents are assumed to be ``str``/``strs`` or
    ``str``/``FieldStorages`` (as is returned by the ``paste.request.parse_``
    functions).

    Can optionally also decode keys when the ``decode_keys`` argument is
    True.

    ``FieldStorage`` instances are cloned, and the clone's ``filename``
    variable is decoded. Its ``name`` variable is decoded when ``decode_keys``
    is enabled.

    The wrapper shares storage with the wrapped ``multi``: when
    ``decode_keys`` is enabled, the constructor re-encodes the keys stored
    in ``multi`` in place.
    """
    def __init__(self, multi=None, encoding=None, errors='strict',
                 decode_keys=False):
        """
        Wrap ``multi``; an empty MultiDict is used when it is None.

        ``encoding`` defaults to ``sys.getdefaultencoding()``. ``errors``
        is passed to every encode/decode call.
        """
        if multi is None:
            # A None backing store would make every operation fail with
            # AttributeError; start from an empty MultiDict instead.
            multi = MultiDict()
        self.multi = multi
        if encoding is None:
            encoding = sys.getdefaultencoding()
        self.encoding = encoding
        self.errors = errors
        self.decode_keys = decode_keys
        if self.decode_keys:
            # Store keys in encoded form, since every lookup encodes the
            # requested key before consulting ``multi``.
            items = self.multi._items
            for index, (key, value) in enumerate(items):
                items[index] = (self._encode_key(key), value)

    def _encode_key(self, key):
        """Encode ``key`` when decode_keys is on; non-encodable keys pass through."""
        if self.decode_keys:
            try:
                key = key.encode(self.encoding, self.errors)
            except AttributeError:
                pass
        return key

    def _decode_key(self, key):
        """Decode ``key`` when decode_keys is on; non-decodable keys pass through."""
        if self.decode_keys:
            try:
                key = key.decode(self.encoding, self.errors)
            except AttributeError:
                pass
        return key

    def _decode_value(self, value):
        """
        Decode the specified value to unicode. Assumes value is a ``str`` or
        ``FieldStorage`` object.

        ``FieldStorage`` objects are shallow-copied; the copy's ``filename``
        (and, with ``decode_keys``, its ``name``) is decoded when it is a
        byte string.  Values without a ``decode`` method are returned
        unchanged.
        """
        if isinstance(value, cgi.FieldStorage):
            # Work on a shallow copy so the stored FieldStorage is untouched.
            value = copy.copy(value)
            if self.decode_keys and isinstance(value.name, six.binary_type):
                value.name = value.name.decode(self.encoding, self.errors)
            # Only byte-string filenames are decoded; a missing (None) or
            # already-decoded filename is left as is.
            if isinstance(value.filename, six.binary_type):
                value.filename = value.filename.decode(self.encoding,
                                                       self.errors)
        else:
            try:
                value = value.decode(self.encoding, self.errors)
            except AttributeError:
                pass
        return value

    def __getitem__(self, key):
        key = self._encode_key(key)
        return self._decode_value(self.multi.__getitem__(key))

    def __setitem__(self, key, value):
        key = self._encode_key(key)
        self.multi.__setitem__(key, value)

    def add(self, key, value):
        """
        Add the key and value, not overwriting any previous value.
        """
        key = self._encode_key(key)
        self.multi.add(key, value)

    def getall(self, key):
        """
        Return a list of all values matching the key (may be an empty list)
        """
        key = self._encode_key(key)
        return [self._decode_value(v) for v in self.multi.getall(key)]

    def getone(self, key):
        """
        Get one value matching the key, raising a KeyError if multiple
        values were found.
        """
        key = self._encode_key(key)
        return self._decode_value(self.multi.getone(key))

    def mixed(self):
        """
        Returns a dictionary where the values are either single
        values, or a list of values when a key/value appears more than
        once in this dictionary.  This is similar to the kind of
        dictionary often used to represent the variables in a web
        request.
        """
        unicode_mixed = {}
        for key, value in six.iteritems(self.multi.mixed()):
            if isinstance(value, list):
                value = [self._decode_value(item) for item in value]
            else:
                value = self._decode_value(value)
            unicode_mixed[self._decode_key(key)] = value
        return unicode_mixed

    def dict_of_lists(self):
        """
        Returns a dictionary where each key is associated with a
        list of values.
        """
        unicode_dict = {}
        for key, values in six.iteritems(self.multi.dict_of_lists()):
            unicode_dict[self._decode_key(key)] = [
                self._decode_value(item) for item in values]
        return unicode_dict

    def __delitem__(self, key):
        key = self._encode_key(key)
        self.multi.__delitem__(key)

    def __contains__(self, key):
        key = self._encode_key(key)
        return self.multi.__contains__(key)

    has_key = __contains__

    def clear(self):
        self.multi.clear()

    def copy(self):
        """Return a new wrapper around a copy of the underlying MultiDict."""
        return UnicodeMultiDict(self.multi.copy(), self.encoding, self.errors,
                                decode_keys=self.decode_keys)

    def setdefault(self, key, default=None):
        key = self._encode_key(key)
        return self._decode_value(self.multi.setdefault(key, default))

    def pop(self, key, *args):
        # Note: a default passed in *args is also run through _decode_value.
        key = self._encode_key(key)
        return self._decode_value(self.multi.pop(key, *args))

    def popitem(self):
        k, v = self.multi.popitem()
        return (self._decode_key(k), self._decode_value(v))

    def __repr__(self):
        items = ', '.join(['(%r, %r)' % v for v in self.items()])
        return '%s([%s])' % (self.__class__.__name__, items)

    def __len__(self):
        return self.multi.__len__()

    ##
    ## All the iteration:
    ##

    def keys(self):
        return [self._decode_key(k) for k in self.multi.iterkeys()]

    def iterkeys(self):
        for k in self.multi.iterkeys():
            yield self._decode_key(k)

    __iter__ = iterkeys

    def items(self):
        return [(self._decode_key(k), self._decode_value(v))
                for k, v in six.iteritems(self.multi)]

    def iteritems(self):
        for k, v in six.iteritems(self.multi):
            yield (self._decode_key(k), self._decode_value(v))

    def values(self):
        return [self._decode_value(v) for v in self.multi.itervalues()]

    def itervalues(self):
        for v in self.multi.itervalues():
            yield self._decode_value(v)
401
# Doctest examples for MultiDict.  doctest.testmod() also runs examples
# found in a module-level ``__test__`` mapping; the ``__main__`` block at
# the end of this file calls testmod().
# NOTE(review): the d.mixed() expectation relies on dict repr and **kwargs
# preserving insertion order (guaranteed from Python 3.7) -- confirm for
# any older interpreter this module still has to support.
__test__ = {
    'general': """
    >>> d = MultiDict(a=1, b=2)
    >>> d['a']
    1
    >>> d.getall('c')
    []
    >>> d.add('a', 2)
    >>> d['a']
    1
    >>> d.getall('a')
    [1, 2]
    >>> d['b'] = 4
    >>> d.getall('b')
    [4]
    >>> d.keys()
    ['a', 'a', 'b']
    >>> d.items()
    [('a', 1), ('a', 2), ('b', 4)]
    >>> d.mixed()
    {'a': [1, 2], 'b': 4}
    >>> MultiDict([('a', 'b')], c=2)
    MultiDict([('a', 'b'), ('c', 2)])
    """}
426
if __name__ == '__main__':
    # Executed as a script: run this module's doctests, including the
    # examples registered in ``__test__``.
    from doctest import testmod
    testmod()
430