# (c) 2005 Ian Bicking and contributors; written for Paste (http://pythonpaste.org)
# Licensed under the MIT license: http://www.opensource.org/licenses/mit-license.php
import copy
import sys

import six

try:
    # ``cgi`` was removed from the standard library in Python 3.13; without
    # it no ``cgi.FieldStorage`` value can exist, so the special handling in
    # ``UnicodeMultiDict._decode_value`` is simply skipped.
    import cgi
except ImportError:
    cgi = None

try:
    # Python 3.3+
    from collections.abc import MutableMapping as DictMixin
except ImportError:
    try:
        # Python 2.6 - 3.2 (this alias was removed in Python 3.10)
        from collections import MutableMapping as DictMixin
    except ImportError:
        # Python < 2.6
        from UserDict import DictMixin


class MultiDict(DictMixin):

    """
    An ordered dictionary that can have multiple values for each key.
    Adds the methods getall, getone, mixed, and add to the normal
    dictionary interface.

    The contents are kept as a list of ``(key, value)`` tuples in
    ``self._items``, in insertion order.

    Note that ``__getitem__`` and ``setdefault`` match keys with plain
    ``==``, while ``getall``, ``__contains__``, ``__delitem__`` and ``pop``
    additionally require the stored key to have exactly the same type as
    the requested key.
    """

    def __init__(self, *args, **kw):
        """
        Build the dictionary from at most one positional argument (a
        mapping with ``iteritems``/``items``, or an iterable of
        ``(key, value)`` pairs) followed by any keyword arguments.

        Raises TypeError if more than one positional argument is given.
        """
        if len(args) > 1:
            raise TypeError(
                "MultiDict can only be called with one positional argument")
        if args:
            source = args[0]
            if hasattr(source, 'iteritems'):
                items = source.iteritems()
            elif hasattr(source, 'items'):
                items = source.items()
            else:
                items = source
            self._items = list(items)
        else:
            self._items = []
        self._items.extend(six.iteritems(kw))

    def __getitem__(self, key):
        """
        Return the first value stored under ``key``; raise KeyError if
        there is none.
        """
        for k, v in self._items:
            if k == key:
                return v
        raise KeyError(repr(key))

    def __setitem__(self, key, value):
        """
        Replace every existing entry for ``key`` with a single
        ``(key, value)`` pair appended at the end.
        """
        try:
            del self[key]
        except KeyError:
            # Nothing to replace; the key is simply new.
            pass
        self._items.append((key, value))

    def add(self, key, value):
        """
        Add the key and value, not overwriting any previous value.
        """
        self._items.append((key, value))

    def getall(self, key):
        """
        Return a list of all values matching the key (may be an empty list)
        """
        return [v for k, v in self._items
                if type(key) == type(k) and key == k]

    def getone(self, key):
        """
        Get one value matching the key, raising a KeyError if multiple
        values were found (or if no value was found at all).
        """
        v = self.getall(key)
        if not v:
            raise KeyError('Key not found: %r' % key)
        if len(v) > 1:
            raise KeyError('Multiple values match %r: %r' % (key, v))
        return v[0]

    def mixed(self):
        """
        Returns a dictionary where the values are either single
        values, or a list of values when a key/value appears more than
        once in this dictionary.  This is similar to the kind of
        dictionary often used to represent the variables in a web
        request.
        """
        result = {}
        # Keys whose entry in ``result`` is a list built here, as opposed
        # to a list that was stored as an *actual* value.
        multi = {}
        for key, value in self._items:
            if key in result:
                # We do this to not clobber any lists that are
                # *actual* values in this dictionary:
                if key in multi:
                    result[key].append(value)
                else:
                    result[key] = [result[key], value]
                    multi[key] = None
            else:
                result[key] = value
        return result

    def dict_of_lists(self):
        """
        Returns a dictionary where each key is associated with a
        list of values.
        """
        result = {}
        for key, value in self._items:
            if key in result:
                result[key].append(value)
            else:
                result[key] = [value]
        return result

    def __delitem__(self, key):
        """
        Remove every entry stored under ``key``; raise KeyError if there
        was none.
        """
        items = self._items
        found = False
        # Walk backwards so that deleting an entry does not shift the
        # indexes that are still to be visited.
        for i in range(len(items) - 1, -1, -1):
            if type(items[i][0]) == type(key) and items[i][0] == key:
                del items[i]
                found = True
        if not found:
            raise KeyError(repr(key))

    def __contains__(self, key):
        """
        Return True if at least one entry is stored under ``key``.
        """
        for k, v in self._items:
            if type(k) == type(key) and k == key:
                return True
        return False

    has_key = __contains__

    def clear(self):
        """
        Remove all entries.
        """
        self._items = []

    def copy(self):
        """
        Return a new MultiDict holding the same ``(key, value)`` pairs.
        """
        return MultiDict(self)

    def setdefault(self, key, default=None):
        """
        Return the first value stored under ``key``; if there is none,
        append ``(key, default)`` and return ``default``.
        """
        for k, v in self._items:
            if key == k:
                return v
        self._items.append((key, default))
        return default

    def pop(self, key, *args):
        """
        Remove the first entry stored under ``key`` and return its value.

        If the key is missing, return the optional default given as the
        second argument, or raise KeyError when no default was given.
        Raises TypeError if more than one extra argument is passed.
        """
        if len(args) > 1:
            raise TypeError("pop expected at most 2 arguments, got "
                            + repr(1 + len(args)))
        for index, (k, v) in enumerate(self._items):
            if type(k) == type(key) and k == key:
                # Safe to delete while iterating: we return immediately.
                del self._items[index]
                return v
        if args:
            return args[0]
        else:
            raise KeyError(repr(key))

    def popitem(self):
        """
        Remove and return the last ``(key, value)`` pair.
        """
        return self._items.pop()

    def update(self, other=None, **kwargs):
        """
        Append every pair from ``other`` (a mapping with ``items``, an
        object with ``keys`` and item access, or an iterable of pairs) and
        then from ``kwargs``.  Existing entries are never replaced.
        """
        if other is None:
            pass
        elif hasattr(other, 'items'):
            self._items.extend(other.items())
        elif hasattr(other, 'keys'):
            for k in other.keys():
                self._items.append((k, other[k]))
        else:
            for k, v in other:
                self._items.append((k, v))
        if kwargs:
            self.update(kwargs)

    def __repr__(self):
        items = ', '.join(['(%r, %r)' % v for v in self._items])
        return '%s([%s])' % (self.__class__.__name__, items)

    def __len__(self):
        """
        Return the number of ``(key, value)`` pairs (not of distinct keys).
        """
        return len(self._items)

    ##
    ## All the iteration:
    ##

    def keys(self):
        return [k for k, v in self._items]

    def iterkeys(self):
        for k, v in self._items:
            yield k

    __iter__ = iterkeys

    def items(self):
        # A copy, so callers cannot mutate the internal list.
        return self._items[:]

    def iteritems(self):
        return iter(self._items)

    def values(self):
        return [v for k, v in self._items]

    def itervalues(self):
        for k, v in self._items:
            yield v


class UnicodeMultiDict(DictMixin):
    """
    A MultiDict wrapper that decodes returned values to unicode on the
    fly. Decoding is not applied to assigned values.

    The key/value contents are assumed to be ``str``/``strs`` or
    ``str``/``FieldStorages`` (as is returned by the ``paste.request.parse_``
    functions).

    Can optionally also decode keys when the ``decode_keys`` argument is
    True.

    ``FieldStorage`` instances are cloned, and the clone's ``filename``
    variable is decoded. Its ``name`` variable is decoded when ``decode_keys``
    is enabled.

    """
    def __init__(self, multi=None, encoding=None, errors='strict',
                 decode_keys=False):
        """
        Wrap ``multi`` (a new empty ``MultiDict`` when omitted).

        ``encoding`` defaults to ``sys.getdefaultencoding()``; ``errors``
        is passed to every encode/decode call.  When ``decode_keys`` is
        true, the keys already stored in ``multi`` are encoded in place.
        """
        if multi is None:
            # Previously ``None`` was stored as-is, which made the wrapper
            # fail on first use (or right here when decode_keys was set).
            multi = MultiDict()
        self.multi = multi
        if encoding is None:
            encoding = sys.getdefaultencoding()
        self.encoding = encoding
        self.errors = errors
        self.decode_keys = decode_keys
        if self.decode_keys:
            items = self.multi._items
            for index, item in enumerate(items):
                key, value = item
                key = self._encode_key(key)
                items[index] = (key, value)

    def _encode_key(self, key):
        """
        Encode ``key`` with the configured encoding when ``decode_keys``
        is enabled; keys without an ``encode`` method pass through.
        """
        if self.decode_keys:
            try:
                key = key.encode(self.encoding, self.errors)
            except AttributeError:
                pass
        return key

    def _decode_key(self, key):
        """
        Decode ``key`` with the configured encoding when ``decode_keys``
        is enabled; keys without a ``decode`` method pass through.
        """
        if self.decode_keys:
            try:
                key = key.decode(self.encoding, self.errors)
            except AttributeError:
                pass
        return key

    def _decode_value(self, value):
        """
        Decode the specified value to unicode. Assumes value is a ``str`` or
        ``FieldStorage`` object.

        ``FieldStorage`` objects are specially handled.
        """
        if cgi is not None and isinstance(value, cgi.FieldStorage):
            # decode FieldStorage's field name and filename
            value = copy.copy(value)
            if self.decode_keys and isinstance(value.name, six.binary_type):
                value.name = value.name.decode(self.encoding, self.errors)
            if six.PY2:
                value.filename = value.filename.decode(self.encoding,
                                                       self.errors)
        else:
            try:
                value = value.decode(self.encoding, self.errors)
            except AttributeError:
                # Not a byte string (no ``decode``); return it unchanged.
                pass
        return value

    def __getitem__(self, key):
        key = self._encode_key(key)
        return self._decode_value(self.multi.__getitem__(key))

    def __setitem__(self, key, value):
        key = self._encode_key(key)
        self.multi.__setitem__(key, value)

    def add(self, key, value):
        """
        Add the key and value, not overwriting any previous value.
        """
        key = self._encode_key(key)
        self.multi.add(key, value)

    def getall(self, key):
        """
        Return a list of all values matching the key (may be an empty list)
        """
        key = self._encode_key(key)
        return [self._decode_value(v) for v in self.multi.getall(key)]

    def getone(self, key):
        """
        Get one value matching the key, raising a KeyError if multiple
        values were found.
        """
        key = self._encode_key(key)
        return self._decode_value(self.multi.getone(key))

    def mixed(self):
        """
        Returns a dictionary where the values are either single
        values, or a list of values when a key/value appears more than
        once in this dictionary.  This is similar to the kind of
        dictionary often used to represent the variables in a web
        request.
        """
        unicode_mixed = {}
        for key, value in six.iteritems(self.multi.mixed()):
            if isinstance(value, list):
                value = [self._decode_value(item) for item in value]
            else:
                value = self._decode_value(value)
            unicode_mixed[self._decode_key(key)] = value
        return unicode_mixed

    def dict_of_lists(self):
        """
        Returns a dictionary where each key is associated with a
        list of values.
        """
        unicode_dict = {}
        for key, values in six.iteritems(self.multi.dict_of_lists()):
            decoded = [self._decode_value(item) for item in values]
            unicode_dict[self._decode_key(key)] = decoded
        return unicode_dict

    def __delitem__(self, key):
        key = self._encode_key(key)
        self.multi.__delitem__(key)

    def __contains__(self, key):
        key = self._encode_key(key)
        return self.multi.__contains__(key)

    has_key = __contains__

    def clear(self):
        self.multi.clear()

    def copy(self):
        """
        Return a new wrapper, with the same settings, around a copy of the
        wrapped dictionary.
        """
        return UnicodeMultiDict(self.multi.copy(), self.encoding, self.errors,
                                decode_keys=self.decode_keys)

    def setdefault(self, key, default=None):
        key = self._encode_key(key)
        return self._decode_value(self.multi.setdefault(key, default))

    def pop(self, key, *args):
        key = self._encode_key(key)
        return self._decode_value(self.multi.pop(key, *args))

    def popitem(self):
        k, v = self.multi.popitem()
        return (self._decode_key(k), self._decode_value(v))

    def __repr__(self):
        items = ', '.join(['(%r, %r)' % v for v in self.items()])
        return '%s([%s])' % (self.__class__.__name__, items)

    def __len__(self):
        return self.multi.__len__()

    ##
    ## All the iteration:
    ##

    def keys(self):
        return [self._decode_key(k) for k in self.multi.iterkeys()]

    def iterkeys(self):
        for k in self.multi.iterkeys():
            yield self._decode_key(k)

    __iter__ = iterkeys

    def items(self):
        return [(self._decode_key(k), self._decode_value(v))
                for k, v in six.iteritems(self.multi)]

    def iteritems(self):
        for k, v in six.iteritems(self.multi):
            yield (self._decode_key(k), self._decode_value(v))

    def values(self):
        return [self._decode_value(v) for v in self.multi.itervalues()]

    def itervalues(self):
        for v in self.multi.itervalues():
            yield self._decode_value(v)


__test__ = {
    'general': """
    >>> d = MultiDict(a=1, b=2)
    >>> d['a']
    1
    >>> d.getall('c')
    []
    >>> d.add('a', 2)
    >>> d['a']
    1
    >>> d.getall('a')
    [1, 2]
    >>> d['b'] = 4
    >>> d.getall('b')
    [4]
    >>> d.keys()
    ['a', 'a', 'b']
    >>> d.items()
    [('a', 1), ('a', 2), ('b', 4)]
    >>> d.mixed()
    {'a': [1, 2], 'b': 4}
    >>> MultiDict([('a', 'b')], c=2)
    MultiDict([('a', 'b'), ('c', 2)])
    """}

if __name__ == '__main__':
    import doctest
    doctest.testmod()