# Copyright (c) 2006-2008 Mitch Garnaat http://garnaat.org/
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish, dis-
# tribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the fol-
# lowing conditions:
#
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABIL-
# ITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
# SHALL THE AUTHOR BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
import boto
from boto.utils import find_class, Password
from boto.sdb.db.key import Key
from boto.sdb.db.model import Model
from boto.compat import six, encodebytes
from datetime import datetime
from xml.dom.minidom import getDOMImplementation, parse, parseString, Node

ISO8601 = '%Y-%m-%dT%H:%M:%SZ'


class XMLConverter(object):
    """
    Converts Python values to and from the representation used in the XML
    documents exchanged with the backing store.

    To convert a value, pass it to the encode or decode method.  The encode
    method takes a Python native value and returns either a string or a DOM
    node.  The decode method takes a DOM node and returns a Python native
    value.  The type-specific converter pair is looked up in ``type_map``,
    keyed by the Python type; types without an entry are passed through
    unchanged on encode and read back as plain text on decode.
    """

    def __init__(self, manager):
        """
        :param manager: the owning manager; its ``doc`` attribute is used to
            create the ``<object>`` elements for references.
        """
        self.manager = manager
        self.type_map = {
            bool: (self.encode_bool, self.decode_bool),
            int: (self.encode_int, self.decode_int),
            Model: (self.encode_reference, self.decode_reference),
            Key: (self.encode_reference, self.decode_reference),
            Password: (self.encode_password, self.decode_password),
            datetime: (self.encode_datetime, self.decode_datetime),
        }
        if six.PY2:
            # six.integer_types is (int, long) on Python 2, so the last
            # entry is ``long`` without naming it (it is undefined on Py3).
            self.type_map[six.integer_types[-1]] = (self.encode_long,
                                                    self.decode_long)

    def get_text_value(self, parent_node):
        """Return the concatenated data of the direct text children."""
        return ''.join(node.data for node in parent_node.childNodes
                       if node.nodeType == node.TEXT_NODE)

    def encode(self, item_type, value):
        """Encode ``value`` with the encoder registered for ``item_type``.

        Values whose type has no registered encoder are returned unchanged.
        """
        if item_type in self.type_map:
            encoder = self.type_map[item_type][0]
            return encoder(value)
        return value

    def decode(self, item_type, value):
        """Decode DOM node ``value`` with the decoder for ``item_type``.

        Nodes whose type has no registered decoder are returned as their
        text content.
        """
        if item_type in self.type_map:
            decoder = self.type_map[item_type][1]
            return decoder(value)
        return self.get_text_value(value)

    def encode_prop(self, prop, value):
        """Encode ``value`` for property ``prop``.

        Lists are encoded element by element using ``prop.item_type`` (any
        Model subclass is treated as ``Model``); a list on a property with
        no ``item_type`` is returned unchanged.  Non-list values are encoded
        according to ``prop.data_type``.
        """
        if not isinstance(value, list):
            return self.encode(prop.data_type, value)
        if not hasattr(prop, 'item_type'):
            return value
        item_type = prop.item_type
        if issubclass(item_type, Model):
            item_type = Model
        return [self.encode(item_type, item) for item in value]

    def decode_prop(self, prop, value):
        """Decode DOM node ``value`` for property ``prop``.

        For list properties every ``<item>`` descendant is decoded with
        ``prop.item_type`` (any Model subclass is treated as ``Model``).
        """
        if prop.data_type != list:
            return self.decode(prop.data_type, value)
        if not hasattr(prop, 'item_type'):
            return self.get_text_value(value)
        item_type = prop.item_type
        if issubclass(item_type, Model):
            item_type = Model
        return [self.decode(item_type, item_node)
                for item_node in value.getElementsByTagName('item')]

    def encode_int(self, value):
        """Return ``value`` as a decimal string."""
        return '%d' % int(value)

    def decode_int(self, value):
        """Return the node's text as an int, or None when it is empty."""
        text = self.get_text_value(value)
        if text:
            return int(text)
        return None

    def encode_long(self, value):
        """Return ``value`` as a decimal string (``long`` on Python 2)."""
        return '%d' % six.integer_types[-1](value)

    def decode_long(self, value):
        """Return the node's text as a long integer, or None when empty.

        The empty-text case mirrors ``decode_int`` instead of raising.
        """
        text = self.get_text_value(value)
        if text:
            return six.integer_types[-1](text)
        return None

    def encode_bool(self, value):
        """Return 'true' when ``value`` equals True, otherwise 'false'."""
        # Deliberately an equality test rather than truthiness so that only
        # True/1 map to 'true'; other truthy values keep encoding as 'false'.
        if value == True:  # noqa: E712
            return 'true'
        return 'false'

    def decode_bool(self, value):
        """Return True when the node's text is 'true' (case-insensitive)."""
        return self.get_text_value(value).lower() == 'true'

    def encode_datetime(self, value):
        """Format ``value`` using the module's ISO8601 pattern."""
        return value.strftime(ISO8601)

    def decode_datetime(self, value):
        """Parse the node's text as ISO8601; return None if it is invalid."""
        text = self.get_text_value(value)
        try:
            return datetime.strptime(text, ISO8601)
        except (ValueError, TypeError):
            return None

    def encode_reference(self, value):
        """Encode a reference to another object.

        Strings are returned unchanged, None becomes '', and anything else
        becomes an ``<object>`` element carrying the target's ``id`` and
        fully qualified class name.
        """
        if isinstance(value, six.string_types):
            return value
        if value is None:
            return ''
        val_node = self.manager.doc.createElement("object")
        val_node.setAttribute('id', value.id)
        val_node.setAttribute('class', '%s.%s' % (value.__class__.__module__,
                                                  value.__class__.__name__))
        return val_node

    def decode_reference(self, value):
        """Resolve the ``<object>`` child of ``value`` to an instance.

        Returns None when there is no such child or when the class lookup
        or the fetch fails for any reason.
        """
        if not value:
            return None
        try:
            # Use the first *element* child so that leading whitespace text
            # nodes (pretty-printed XML) do not hide the reference.
            ref_node = next((child for child in value.childNodes
                             if child.nodeType == Node.ELEMENT_NODE), None)
            if ref_node is None:
                return None
            class_name = ref_node.getAttribute("class")
            obj_id = ref_node.getAttribute("id")
            cls = find_class(class_name)
            return cls.get_by_ids(obj_id)
        except Exception:
            # Best effort: an unresolvable reference decodes to None.
            return None

    def encode_password(self, value):
        """Return ``str(value)`` for a non-empty password, else None."""
        if value and len(value) > 0:
            return str(value)
        return None

    def decode_password(self, value):
        """Wrap the node's text in a ``Password``."""
        return Password(self.get_text_value(value))


class XMLManager(object):
    """
    Persistence manager that marshals objects to XML documents and, when a
    host is configured, exchanges them with a server over HTTP(S).
    """

    def __init__(self, cls, db_name, db_user, db_passwd,
                 db_host, db_port, db_table, ddl_dir, enable_ssl):
        """
        :param cls: the model class this manager serves.
        :param db_name: first path segment of request URLs; defaults to the
            lower-cased class name when empty.
        :param db_user: user name for the ``Authorization: Basic`` header;
            no header is sent when empty.
        :param db_passwd: password for the ``Authorization: Basic`` header.
        :param db_host: host to connect to; no connection is made without it.
        :param db_port: port to connect to.
        :param db_table: stored on the instance.
        :param ddl_dir: stored on the instance.
        :param enable_ssl: use HTTPS instead of HTTP when true.
        """
        self.cls = cls
        if not db_name:
            db_name = cls.__name__.lower()
        self.db_name = db_name
        self.db_user = db_user
        self.db_passwd = db_passwd
        self.db_host = db_host
        self.db_port = db_port
        self.db_table = db_table
        self.ddl_dir = ddl_dir
        self.s3 = None
        self.converter = XMLConverter(self)
        self.impl = getDOMImplementation()
        self.doc = self.impl.createDocument(None, 'objects', None)

        self.connection = None
        self.enable_ssl = enable_ssl
        self.auth_header = None
        if self.db_user:
            credentials = '%s:%s' % (self.db_user, self.db_passwd)
            if isinstance(credentials, six.text_type):
                # The base64 encoder needs bytes on Python 3.
                credentials = credentials.encode('utf-8')
            # encodebytes() inserts a newline every 76 output characters;
            # strip all of them, not just the trailing one, so long
            # credentials still produce a single-line header value.
            token = encodebytes(credentials).replace(b'\n', b'')
            if not isinstance(token, str):
                token = token.decode('ascii')
            self.auth_header = "Basic %s" % token

    def _connect(self):
        """Create the HTTP(S) connection object if a host is configured."""
        if self.db_host:
            # six.moves resolves to httplib on Python 2 and http.client on 3.
            http_client = six.moves.http_client
            if self.enable_ssl:
                connection_cls = http_client.HTTPSConnection
            else:
                connection_cls = http_client.HTTPConnection
            self.connection = connection_cls(self.db_host, self.db_port)

    def _require_connection(self):
        """Connect lazily; raise NotImplementedError if that is impossible."""
        if not self.connection:
            self._connect()
        if not self.connection:
            raise NotImplementedError(
                "Can't query without a database connection")

    def _make_request(self, method, url, post_data=None, body=None):
        """
        Make a request on this connection

        :param method: HTTP method name.
        :param url: request path.
        :param post_data: accepted for compatibility; not used.
        :param body: request body, if any.
        :return: the HTTP response object.
        :raises NotImplementedError: when no host is configured.
        """
        if not self.connection:
            self._connect()
        if not self.connection:
            raise NotImplementedError(
                "Can't make a request without a database connection")
        # Every request starts from a freshly opened connection.
        try:
            self.connection.close()
        except Exception:
            pass
        self.connection.connect()
        headers = {}
        if self.auth_header:
            headers["Authorization"] = self.auth_header
        self.connection.request(method, url, body, headers)
        return self.connection.getresponse()

    def new_doc(self):
        """Return a new, empty document with an ``<objects>`` root."""
        return self.impl.createDocument(None, 'objects', None)

    def _object_lister(self, cls, doc):
        """Yield one object per ``<object>`` element found in ``doc``.

        When ``cls`` is not given, each object's class is resolved from its
        own ``class`` attribute.
        """
        # NOTE(review): getElementsByTagName also matches nested <object>
        # elements (encoded references) -- confirm against the server's
        # response format.
        for obj_node in doc.getElementsByTagName('object'):
            # Resolve per node: rebinding ``cls`` here would force the first
            # object's class onto every later object in the document.
            obj_cls = cls
            if not obj_cls:
                obj_cls = find_class(obj_node.getAttribute('class'))
            obj = obj_cls(obj_node.getAttribute('id'))
            for prop_node in obj_node.getElementsByTagName('property'):
                prop_name = prop_node.getAttribute('name')
                prop = obj.find_property(prop_name)
                if prop:
                    if hasattr(prop, 'item_type'):
                        value = self.get_list(prop_node, prop.item_type)
                    else:
                        value = self.decode_value(prop, prop_node)
                        value = prop.make_value_from_datastore(value)
                    setattr(obj, prop.name, value)
            yield obj

    def reset(self):
        """Re-create the connection object."""
        self._connect()

    def get_doc(self):
        """Return the manager's shared document."""
        return self.doc

    def encode_value(self, prop, value):
        """Encode ``value`` for ``prop`` using the converter."""
        return self.converter.encode_prop(prop, value)

    def decode_value(self, prop, value):
        """Decode DOM node ``value`` for ``prop`` using the converter."""
        return self.converter.decode_prop(prop, value)

    def get_s3_connection(self):
        """Return a cached S3 connection, creating it on first use."""
        if not self.s3:
            # These attributes are not set by __init__; fall back to None
            # rather than raising AttributeError when they were never
            # assigned on the instance.
            self.s3 = boto.connect_s3(
                getattr(self, 'aws_access_key_id', None),
                getattr(self, 'aws_secret_access_key', None))
        return self.s3

    def get_list(self, prop_node, item_type):
        """Decode the ``<item>`` entries of the first ``<items>`` element.

        Returns an empty list when ``prop_node`` has no ``<items>`` element.
        """
        try:
            items_node = prop_node.getElementsByTagName('items')[0]
        except IndexError:
            return []
        return [self.converter.decode(item_type, item_node)
                for item_node in items_node.getElementsByTagName('item')]

    def get_object_from_doc(self, cls, id, doc):
        """Build an object from the first ``<object>`` element of ``doc``.

        ``cls`` and ``id`` are read from the element when not supplied.
        Properties the class does not define are ignored, as are values
        that decode to None or that the object refuses on assignment.
        """
        obj_node = doc.getElementsByTagName('object')[0]
        if not cls:
            class_name = obj_node.getAttribute('class')
            cls = find_class(class_name)
        if not id:
            id = obj_node.getAttribute('id')
        obj = cls(id)
        for prop_node in obj_node.getElementsByTagName('property'):
            prop_name = prop_node.getAttribute('name')
            prop = obj.find_property(prop_name)
            if not prop:
                # Unknown property: skip it, as _object_lister does.
                continue
            value = self.decode_value(prop, prop_node)
            value = prop.make_value_from_datastore(value)
            if value is not None:
                try:
                    setattr(obj, prop.name, value)
                except Exception:
                    # Best effort: a value the object rejects is dropped.
                    pass
        return obj

    def get_props_from_doc(self, cls, id, doc):
        """
        Pull out the properties from this document
        Returns the class, the properties in a hash, and the id if provided
        as a tuple.  Properties the class does not define are ignored.

        :return: (cls, props, id)
        """
        obj_node = doc.getElementsByTagName('object')[0]
        if not cls:
            class_name = obj_node.getAttribute('class')
            cls = find_class(class_name)
        if not id:
            id = obj_node.getAttribute('id')
        props = {}
        for prop_node in obj_node.getElementsByTagName('property'):
            prop_name = prop_node.getAttribute('name')
            prop = cls.find_property(prop_name)
            if not prop:
                # Unknown property: skip it, as _object_lister does.
                continue
            value = self.decode_value(prop, prop_node)
            value = prop.make_value_from_datastore(value)
            if value is not None:
                props[prop.name] = value
        return (cls, props, id)

    def get_object(self, cls, id):
        """Fetch the object with ``id`` via GET and unmarshal it.

        :raises NotImplementedError: when no host is configured.
        :raises Exception: when the response status is not 200.
        """
        self._require_connection()
        url = "/%s/%s" % (self.db_name, id)
        resp = self._make_request('GET', url)
        if resp.status != 200:
            raise Exception("Error: %s" % resp.status)
        return self.get_object_from_doc(cls, id, parse(resp))

    def query(self, cls, filters, limit=None, order_by=None):
        """Run a query via GET and return a generator of matching objects.

        :raises NotImplementedError: when no host is configured.
        :raises Exception: when the response status is not 200.
        """
        self._require_connection()

        # six.moves resolves to urllib on Python 2 and urllib.parse on 3.
        urlencode = six.moves.urllib.parse.urlencode

        query = str(self._build_query(cls, filters, limit, order_by))
        if query:
            url = "/%s?%s" % (self.db_name, urlencode({"query": query}))
        else:
            url = "/%s" % self.db_name
        resp = self._make_request('GET', url)
        if resp.status != 200:
            raise Exception("Error: %s" % resp.status)
        return self._object_lister(cls, parse(resp))

    def _build_query(self, cls, filters, limit, order_by):
        """Build the query string from ``filters`` and ``order_by``.

        :param filters: iterable of ``("<name> <op>", value)`` pairs, at
            most four; a list value is expanded into OR-ed comparisons.
        :param limit: accepted for compatibility; not used.
        :param order_by: property name, prefixed with ``-`` for descending.
        :raises Exception: on too many filters or an unknown property name.
        """
        if len(filters) > 4:
            raise Exception('Too many filters, max is 4')
        parts = []
        properties = cls.properties(hidden=False)
        for filter_spec, value in filters:
            name, op = filter_spec.strip().split()
            prop = next((p for p in properties if p.name == name), None)
            if prop is None:
                raise Exception('%s is not a valid field' % name)
            if isinstance(value, list):
                filter_parts = [
                    "'%s' %s '%s'" % (name, op, self.encode_value(prop, val))
                    for val in value
                ]
                parts.append("[%s]" % " OR ".join(filter_parts))
            else:
                encoded = self.encode_value(prop, value)
                parts.append("['%s' %s '%s']" % (name, op, encoded))
        if order_by:
            if order_by.startswith("-"):
                key = order_by[1:]
                direction = "desc"
            else:
                key = order_by
                direction = "asc"
            parts.append("['%s' starts-with ''] sort '%s' %s"
                         % (key, key, direction))
        return ' intersection '.join(parts)

    def query_gql(self, query_string, *args, **kwds):
        """GQL is not supported by this manager."""
        raise NotImplementedError("GQL queries not supported in XML")

    def save_list(self, doc, items, prop_node):
        """Append an ``<items>`` element holding one ``<item>`` per entry.

        DOM nodes are attached as-is; anything else becomes a text node.
        """
        items_node = doc.createElement('items')
        prop_node.appendChild(items_node)
        for item in items:
            item_node = doc.createElement('item')
            items_node.appendChild(item_node)
            if isinstance(item, Node):
                item_node.appendChild(item)
            else:
                text_node = doc.createTextNode(item)
                item_node.appendChild(text_node)

    def save_object(self, obj, expected_value=None):
        """
        Marshal the object and do a PUT

        The response is unmarshalled and its id and truthy property values
        are copied back onto ``obj``.

        :param expected_value: accepted for compatibility; not used.
        :return: ``obj``.
        """
        doc = self.marshal_object(obj)
        if obj.id:
            url = "/%s/%s" % (self.db_name, obj.id)
        else:
            url = "/%s" % (self.db_name)
        resp = self._make_request("PUT", url, body=doc.toxml())
        new_obj = self.get_object_from_doc(obj.__class__, None, parse(resp))
        obj.id = new_obj.id
        for prop in obj.properties():
            propname = getattr(prop, 'name', None)
            if propname:
                value = getattr(new_obj, propname)
                if value:
                    setattr(obj, propname, value)
        return obj

    def marshal_object(self, obj, doc=None):
        """Serialise ``obj`` into ``doc`` (a new document by default).

        :return: the document, with an ``<object>`` element appended to its
            root containing one ``<property>`` per non-hidden property.
        """
        if not doc:
            doc = self.new_doc() or self.doc
        obj_node = doc.createElement('object')

        if obj.id:
            obj_node.setAttribute('id', obj.id)

        obj_node.setAttribute('class', '%s.%s' % (obj.__class__.__module__,
                                                  obj.__class__.__name__))
        root = doc.documentElement
        root.appendChild(obj_node)
        for prop in obj.properties(hidden=False):
            prop_node = doc.createElement('property')
            prop_node.setAttribute('name', prop.name)
            prop_node.setAttribute('type', prop.type_name)
            value = prop.get_value_for_datastore(obj)
            if value is not None:
                value = self.encode_value(prop, value)
                if isinstance(value, list):
                    self.save_list(doc, value, prop_node)
                elif isinstance(value, Node):
                    prop_node.appendChild(value)
                else:
                    # Non-ASCII characters are dropped, as before.  On
                    # Python 3 the encoded bytes must be turned back into
                    # str because createTextNode() rejects bytes.
                    text = six.text_type(value).encode("ascii", "ignore")
                    if not isinstance(text, str):
                        text = text.decode("ascii")
                    prop_node.appendChild(doc.createTextNode(text))
            obj_node.appendChild(prop_node)

        return doc

    def _parse_doc(self, fp):
        """Parse ``fp``, which is either an XML string or a file object."""
        if isinstance(fp, six.string_types):
            return parseString(fp)
        return parse(fp)

    def unmarshal_object(self, fp, cls=None, id=None):
        """Parse ``fp`` (XML string or file object) and build the object."""
        return self.get_object_from_doc(cls, id, self._parse_doc(fp))

    def unmarshal_props(self, fp, cls=None, id=None):
        """
        Same as unmarshalling an object, except it returns
        from "get_props_from_doc"
        """
        return self.get_props_from_doc(cls, id, self._parse_doc(fp))

    def delete_object(self, obj):
        """Issue a DELETE for ``obj`` and return the HTTP response."""
        url = "/%s/%s" % (self.db_name, obj.id)
        return self._make_request("DELETE", url)

    # NOTE(review): the four methods below use ``self.domain``, which this
    # class never assigns; they raise AttributeError unless a caller sets it.

    def set_key_value(self, obj, name, value):
        """Replace attribute ``name`` of ``obj`` via ``self.domain``."""
        self.domain.put_attributes(obj.id, {name: value}, replace=True)

    def delete_key_value(self, obj, name):
        """Delete attribute ``name`` of ``obj`` via ``self.domain``."""
        self.domain.delete_attributes(obj.id, name)

    def get_key_value(self, obj, name):
        """Return attribute ``name`` of ``obj``, or None if it is absent."""
        attrs = self.domain.get_attributes(obj.id, name)
        if name in attrs:
            return attrs[name]
        return None

    def get_raw_item(self, obj):
        """Return the raw item for ``obj`` from ``self.domain``."""
        return self.domain.get_item(obj.id)

    def set_property(self, prop, obj, name, value):
        """No-op in this manager."""
        pass

    def get_property(self, prop, obj, name):
        """No-op in this manager; returns None."""
        pass

    def load_object(self, obj):
        """Return a loaded version of ``obj``.

        When ``obj`` is not yet loaded it is re-fetched by id and the
        fetched instance, marked as loaded, is returned instead.
        """
        if not obj._loaded:
            obj = obj.get_by_id(obj.id)
            obj._loaded = True
        return obj