1# Copyright (c) 2006-2008 Mitch Garnaat http://garnaat.org/
2#
3# Permission is hereby granted, free of charge, to any person obtaining a
4# copy of this software and associated documentation files (the
5# "Software"), to deal in the Software without restriction, including
6# without limitation the rights to use, copy, modify, merge, publish, dis-
7# tribute, sublicense, and/or sell copies of the Software, and to permit
8# persons to whom the Software is furnished to do so, subject to the fol-
9# lowing conditions:
10#
11# The above copyright notice and this permission notice shall be included
12# in all copies or substantial portions of the Software.
13#
14# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
15# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABIL-
16# ITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
17# SHALL THE AUTHOR BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
18# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
20# IN THE SOFTWARE.
21import boto
22from boto.utils import find_class, Password
23from boto.sdb.db.key import Key
24from boto.sdb.db.model import Model
25from boto.compat import six, encodebytes
26from datetime import datetime
27from xml.dom.minidom import getDOMImplementation, parse, parseString, Node
28
# strftime/strptime pattern used by XMLConverter.encode_datetime and
# XMLConverter.decode_datetime; the trailing "Z" is a literal character.
ISO8601 = '%Y-%m-%dT%H:%M:%SZ'
30
class XMLConverter(object):
    """
    Converts Python values to and from the representation used inside the
    XML documents handled by the manager.

    Encoding turns a native value into either a string or, for object
    references, a DOM element.  Decoding takes the DOM node that holds a
    stored value and returns the native Python value.

    The generic encode/decode methods look the type up in ``self.type_map``,
    which maps a Python type to an ``(encoder, decoder)`` pair.  Types that
    are not in the map are passed through unchanged when encoding and are
    returned as the node's text when decoding.
    """
    def __init__(self, manager):
        """
        :param manager: The owning manager; its ``doc`` attribute is used to
            create the DOM elements that represent object references.
        """
        self.manager = manager
        self.type_map = {
            bool: (self.encode_bool, self.decode_bool),
            int: (self.encode_int, self.decode_int),
            Model: (self.encode_reference, self.decode_reference),
            Key: (self.encode_reference, self.decode_reference),
            Password: (self.encode_password, self.decode_password),
            datetime: (self.encode_datetime, self.decode_datetime),
        }
        if six.PY2:
            # ``long`` only exists on Python 2.
            self.type_map[long] = (self.encode_long, self.decode_long)

    def _normalize_item_type(self, item_type):
        """
        Map any Model subclass to Model itself so it can be found in
        ``type_map``; every other type is returned unchanged.
        """
        if isinstance(item_type, type) and issubclass(item_type, Model):
            return Model
        return item_type

    def get_text_value(self, parent_node):
        """
        Return the concatenated data of the text nodes that are direct
        children of ``parent_node`` (an empty string if there are none).
        """
        return ''.join(node.data for node in parent_node.childNodes
                       if node.nodeType == node.TEXT_NODE)

    def encode(self, item_type, value):
        """
        Encode ``value`` with the encoder registered for ``item_type``.
        Values of unregistered types are returned unchanged.
        """
        if item_type in self.type_map:
            encoder = self.type_map[item_type][0]
            return encoder(value)
        return value

    def decode(self, item_type, value):
        """
        Decode the DOM node ``value`` with the decoder registered for
        ``item_type``.  For unregistered types the node's text is returned.
        """
        if item_type in self.type_map:
            decoder = self.type_map[item_type][1]
            return decoder(value)
        return self.get_text_value(value)

    def encode_prop(self, prop, value):
        """
        Encode ``value`` for the property ``prop``.

        A list is encoded element by element using ``prop.item_type`` when
        the property declares one and is returned untouched otherwise.  Any
        other value is encoded according to ``prop.data_type``.
        """
        if not isinstance(value, list):
            return self.encode(prop.data_type, value)
        if not hasattr(prop, 'item_type'):
            return value
        item_type = self._normalize_item_type(prop.item_type)
        return [self.encode(item_type, item) for item in value]

    def decode_prop(self, prop, value):
        """
        Decode the DOM node ``value`` for the property ``prop``.

        For list properties with an ``item_type`` every ``<item>`` element
        below the node is decoded and a list is returned; list properties
        without an ``item_type`` yield the node's text.  Any other property
        is decoded according to ``prop.data_type``.
        """
        if prop.data_type != list:
            return self.decode(prop.data_type, value)
        if not hasattr(prop, 'item_type'):
            return self.get_text_value(value)
        item_type = self._normalize_item_type(prop.item_type)
        return [self.decode(item_type, item_node)
                for item_node in value.getElementsByTagName('item')]

    def encode_int(self, value):
        """Return ``value`` coerced to int and formatted as a decimal string."""
        return '%d' % int(value)

    def decode_int(self, value):
        """Return the node's text as an int, or None when the text is empty."""
        text = self.get_text_value(value)
        if not text:
            return None
        return int(text)

    def encode_long(self, value):
        """Return ``value`` coerced to long and formatted as a decimal string."""
        return '%d' % long(value)

    def decode_long(self, value):
        """
        Return the node's text as a long, or None when the text is empty
        (mirrors decode_int instead of failing on an empty string).
        """
        text = self.get_text_value(value)
        if not text:
            return None
        return long(text)

    def encode_bool(self, value):
        """Return 'true' when ``value`` equals True, otherwise 'false'."""
        # Kept as an equality test (not a truthiness test) so that only
        # values equal to True, i.e. True and 1, are written as 'true'.
        if value == True:
            return 'true'
        return 'false'

    def decode_bool(self, value):
        """Return True when the node's text is 'true' (case-insensitive)."""
        return self.get_text_value(value).lower() == 'true'

    def encode_datetime(self, value):
        """Format the datetime ``value`` using the ISO8601 pattern."""
        return value.strftime(ISO8601)

    def decode_datetime(self, value):
        """
        Parse the node's text with the ISO8601 pattern.  Returns None when
        the text does not match the pattern.
        """
        text = self.get_text_value(value)
        try:
            return datetime.strptime(text, ISO8601)
        except (ValueError, TypeError):
            return None

    def encode_reference(self, value):
        """
        Encode a reference to another object.

        Strings are returned unchanged and None becomes an empty string.
        Anything else is turned into an ``<object>`` element carrying the
        referenced object's ``id`` and fully qualified class name.
        """
        if isinstance(value, six.string_types):
            return value
        if value is None:
            return ''
        ref_class = value.__class__
        val_node = self.manager.doc.createElement("object")
        val_node.setAttribute('id', value.id)
        val_node.setAttribute('class', '%s.%s' % (ref_class.__module__,
                                                  ref_class.__name__))
        return val_node

    def decode_reference(self, value):
        """
        Resolve the ``<object>`` element stored below the node ``value``.

        The first child *element* is used, so whitespace text surrounding it
        does not prevent the lookup.  Returns None when there is no such
        element or when the referenced object cannot be loaded.
        """
        if not value:
            return None
        try:
            ref_node = None
            for child in value.childNodes:
                if child.nodeType == child.ELEMENT_NODE:
                    ref_node = child
                    break
            if ref_node is None:
                return None
            class_name = ref_node.getAttribute("class")
            obj_id = ref_node.getAttribute("id")
            cls = find_class(class_name)
            return cls.get_by_ids(obj_id)
        except Exception:
            # Best effort: an unresolvable reference decodes to None.
            return None

    def encode_password(self, value):
        """Return ``str(value)`` for a non-empty value, otherwise None."""
        if value:
            return str(value)
        return None

    def decode_password(self, value):
        """Wrap the node's text in a Password object."""
        return Password(self.get_text_value(value))
181
182
class XMLManager(object):
    """
    Persistence manager that marshals model objects into XML documents and
    exchanges them with a remote service through HTTP(S) GET, PUT and DELETE
    requests.
    """

    def __init__(self, cls, db_name, db_user, db_passwd,
                 db_host, db_port, db_table, ddl_dir, enable_ssl):
        """
        :param cls: Model class served by this manager; its lower-cased name
            is used as ``db_name`` when none is given.
        :param db_name: First path segment of every request URL.
        :param db_user: User name for HTTP basic authentication (optional).
        :param db_passwd: Password for HTTP basic authentication.
        :param db_host: Host to connect to; without it no connection is made.
        :param db_port: Port to connect to.
        :param db_table: Stored on the instance; not used by this class.
        :param ddl_dir: Stored on the instance; not used by this class.
        :param enable_ssl: Use HTTPS instead of HTTP when true.
        """
        self.cls = cls
        if not db_name:
            db_name = cls.__name__.lower()
        self.db_name = db_name
        self.db_user = db_user
        self.db_passwd = db_passwd
        self.db_host = db_host
        self.db_port = db_port
        self.db_table = db_table
        self.ddl_dir = ddl_dir
        self.s3 = None
        self.converter = XMLConverter(self)
        self.impl = getDOMImplementation()
        self.doc = self.impl.createDocument(None, 'objects', None)

        self.connection = None
        self.enable_ssl = enable_ssl
        self.auth_header = None
        if self.db_user:
            credentials = '%s:%s' % (self.db_user, self.db_passwd)
            # The base64 encoder works on bytes; text must be encoded first.
            if not isinstance(credentials, bytes):
                credentials = credentials.encode('utf-8')
            # The encoder inserts line breaks; none may end up in a header.
            token = encodebytes(credentials).replace(b'\n', b'')
            if not isinstance(token, str):
                token = token.decode('ascii')
            self.auth_header = "Basic %s" % token

    def _connect(self):
        """
        Create (but do not open) the HTTP or HTTPS connection object for the
        configured host.  Does nothing when no host is configured.
        """
        if not self.db_host:
            return
        try:
            import httplib as http_client
        except ImportError:
            # Python 3 renamed the module.
            import http.client as http_client
        if self.enable_ssl:
            connection_cls = http_client.HTTPSConnection
        else:
            connection_cls = http_client.HTTPConnection
        self.connection = connection_cls(self.db_host, self.db_port)

    def _make_request(self, method, url, post_data=None, body=None):
        """
        Make a request on this connection and return the response.

        The connection is closed and reopened before every request.

        :param method: HTTP method name.
        :param url: Request path.
        :param post_data: Accepted for compatibility; not used.
        :param body: Request body, if any.
        :raises NotImplementedError: when no connection can be created
            because no host is configured.
        """
        if not self.connection:
            self._connect()
        if not self.connection:
            raise NotImplementedError(
                "Can't make a request without a database connection")
        try:
            self.connection.close()
        except Exception:
            # Best effort: a failing close must not prevent the reconnect.
            pass
        self.connection.connect()
        headers = {}
        if self.auth_header:
            headers["Authorization"] = self.auth_header
        self.connection.request(method, url, body, headers)
        return self.connection.getresponse()

    def new_doc(self):
        """Return a fresh DOM document whose root element is ``<objects>``."""
        return self.impl.createDocument(None, 'objects', None)

    def _is_reference_node(self, obj_node):
        """
        Tell whether an ``<object>`` element is a reference stored inside a
        ``<property>`` or ``<item>`` element rather than an object of its own.
        """
        parent = obj_node.parentNode
        return (parent is not None
                and parent.nodeType == Node.ELEMENT_NODE
                and parent.tagName in ('property', 'item'))

    def _object_lister(self, cls, doc):
        """
        Generate one object per ``<object>`` element found in ``doc``.

        ``<object>`` elements that merely encode a reference inside a
        property are skipped.  When ``cls`` is not given the class is
        resolved from each element's own ``class`` attribute.  Properties
        unknown to the object are ignored.
        """
        for obj_node in doc.getElementsByTagName('object'):
            if self._is_reference_node(obj_node):
                continue
            # Resolve per node so a missing cls never leaks the class of the
            # first element onto the following ones.
            obj_cls = cls
            if not obj_cls:
                obj_cls = find_class(obj_node.getAttribute('class'))
            obj = obj_cls(obj_node.getAttribute('id'))
            for prop_node in obj_node.getElementsByTagName('property'):
                prop = obj.find_property(prop_node.getAttribute('name'))
                if not prop:
                    continue
                if hasattr(prop, 'item_type'):
                    value = self.get_list(prop_node, prop.item_type)
                else:
                    value = self.decode_value(prop, prop_node)
                    value = prop.make_value_from_datastore(value)
                setattr(obj, prop.name, value)
            yield obj

    def reset(self):
        """Recreate the connection object."""
        self._connect()

    def get_doc(self):
        """Return the manager's shared DOM document."""
        return self.doc

    def encode_value(self, prop, value):
        """Encode ``value`` for ``prop`` using the converter."""
        return self.converter.encode_prop(prop, value)

    def decode_value(self, prop, value):
        """Decode the DOM node ``value`` for ``prop`` using the converter."""
        return self.converter.decode_prop(prop, value)

    def get_s3_connection(self):
        """
        Return the cached S3 connection, creating it on first use.

        Explicit credentials are only passed when the instance carries
        ``aws_access_key_id`` / ``aws_secret_access_key`` attributes; this
        class never sets them, so None is passed otherwise.
        """
        if not self.s3:
            self.s3 = boto.connect_s3(
                getattr(self, 'aws_access_key_id', None),
                getattr(self, 'aws_secret_access_key', None))
        return self.s3

    def get_list(self, prop_node, item_type):
        """
        Decode every ``<item>`` of the first ``<items>`` element below
        ``prop_node``.  Returns an empty list when there is no ``<items>``.
        """
        try:
            items_node = prop_node.getElementsByTagName('items')[0]
        except IndexError:
            return []
        return [self.converter.decode(item_type, item_node)
                for item_node in items_node.getElementsByTagName('item')]

    def get_object_from_doc(self, cls, id, doc):
        """
        Build an object from the first ``<object>`` element of ``doc``.

        ``cls`` and ``id`` are taken from the element's attributes when not
        supplied.  Properties unknown to the object and properties whose
        decoded value is None are skipped.
        """
        obj_node = doc.getElementsByTagName('object')[0]
        if not cls:
            cls = find_class(obj_node.getAttribute('class'))
        if not id:
            id = obj_node.getAttribute('id')
        obj = cls(id)
        for prop_node in obj_node.getElementsByTagName('property'):
            prop = obj.find_property(prop_node.getAttribute('name'))
            if not prop:
                # Same policy as _object_lister: ignore unknown properties.
                continue
            value = self.decode_value(prop, prop_node)
            value = prop.make_value_from_datastore(value)
            if value is None:
                continue
            try:
                setattr(obj, prop.name, value)
            except Exception:
                # Best effort: a value the object rejects is left unset.
                pass
        return obj

    def get_props_from_doc(self, cls, id, doc):
        """
        Pull out the properties from this document
        Returns the class, the properties in a hash, and the id if provided as a tuple

        Properties unknown to the class and properties whose decoded value
        is None are left out of the hash.

        :return: (cls, props, id)
        """
        obj_node = doc.getElementsByTagName('object')[0]
        if not cls:
            cls = find_class(obj_node.getAttribute('class'))
        if not id:
            id = obj_node.getAttribute('id')
        props = {}
        for prop_node in obj_node.getElementsByTagName('property'):
            prop = cls.find_property(prop_node.getAttribute('name'))
            if not prop:
                continue
            value = self.decode_value(prop, prop_node)
            value = prop.make_value_from_datastore(value)
            if value is not None:
                props[prop.name] = value
        return (cls, props, id)

    def get_object(self, cls, id):
        """
        Fetch the object with the given id via GET ``/<db_name>/<id>``.

        :raises NotImplementedError: when no connection is available.
        :raises Exception: when the response status is not 200.
        """
        if not self.connection:
            self._connect()

        if not self.connection:
            raise NotImplementedError("Can't query without a database connection")
        url = "/%s/%s" % (self.db_name, id)
        resp = self._make_request('GET', url)
        if resp.status != 200:
            raise Exception("Error: %s" % resp.status)
        return self.get_object_from_doc(cls, id, parse(resp))

    def query(self, cls, filters, limit=None, order_by=None):
        """
        Run a query via GET ``/<db_name>?query=...`` and return a generator
        of the matching objects.

        :raises NotImplementedError: when no connection is available.
        :raises Exception: when the response status is not 200.
        """
        if not self.connection:
            self._connect()

        if not self.connection:
            raise NotImplementedError("Can't query without a database connection")

        try:
            from urllib import urlencode
        except ImportError:
            # Python 3 moved urlencode into urllib.parse.
            from urllib.parse import urlencode

        query = str(self._build_query(cls, filters, limit, order_by))
        if query:
            url = "/%s?%s" % (self.db_name, urlencode({"query": query}))
        else:
            url = "/%s" % self.db_name
        resp = self._make_request('GET', url)
        if resp.status != 200:
            raise Exception("Error: %s" % resp.status)
        return self._object_lister(cls, parse(resp))

    def _build_query(self, cls, filters, limit, order_by):
        """
        Build the query expression for ``filters`` and ``order_by``.

        Every filter is a ``("<name> <op>", value)`` pair; a list value
        produces one OR-ed clause per element.  ``order_by`` may be prefixed
        with "-" for descending order.  ``limit`` is accepted but not used.

        :raises Exception: for more than four filters or an unknown field.
        """
        if len(filters) > 4:
            raise Exception('Too many filters, max is 4')
        parts = []
        properties = cls.properties(hidden=False)
        for filter_spec, value in filters:
            name, op = filter_spec.strip().split()
            found = False
            for prop in properties:
                if prop.name != name:
                    continue
                found = True
                if isinstance(value, list):
                    filter_parts = []
                    for val in value:
                        val = self.encode_value(prop, val)
                        filter_parts.append("'%s' %s '%s'" % (name, op, val))
                    parts.append("[%s]" % " OR ".join(filter_parts))
                else:
                    encoded = self.encode_value(prop, value)
                    parts.append("['%s' %s '%s']" % (name, op, encoded))
            if not found:
                raise Exception('%s is not a valid field' % name)
        if order_by:
            if order_by.startswith("-"):
                key = order_by[1:]
                direction = "desc"
            else:
                key = order_by
                direction = "asc"
            parts.append("['%s' starts-with ''] sort '%s' %s"
                         % (key, key, direction))
        return ' intersection '.join(parts)

    def query_gql(self, query_string, *args, **kwds):
        """GQL is not available for this manager; always raises."""
        raise NotImplementedError("GQL queries not supported in XML")

    def save_list(self, doc, items, prop_node):
        """
        Append an ``<items>`` element to ``prop_node`` holding one ``<item>``
        per entry of ``items``.  DOM nodes are attached as they are, any
        other entry is stored as text.
        """
        items_node = doc.createElement('items')
        prop_node.appendChild(items_node)
        for item in items:
            item_node = doc.createElement('item')
            items_node.appendChild(item_node)
            if isinstance(item, Node):
                child = item
            else:
                child = doc.createTextNode(item)
            item_node.appendChild(child)

    def save_object(self, obj, expected_value=None):
        """
        Marshal the object and do a PUT

        The response is parsed as an object document; its id and every
        truthy property value are copied back onto ``obj``, which is
        returned.  ``expected_value`` is accepted but not used.
        """
        doc = self.marshal_object(obj)
        if obj.id:
            url = "/%s/%s" % (self.db_name, obj.id)
        else:
            url = "/%s" % (self.db_name)
        resp = self._make_request("PUT", url, body=doc.toxml())
        new_obj = self.get_object_from_doc(obj.__class__, None, parse(resp))
        obj.id = new_obj.id
        for prop in obj.properties():
            propname = getattr(prop, 'name', None)
            if not propname:
                continue
            value = getattr(new_obj, propname)
            if value:
                setattr(obj, propname, value)
        return obj

    def marshal_object(self, obj, doc=None):
        """
        Serialize ``obj`` as an ``<object>`` element appended to the root of
        ``doc`` and return the document.  A new document is created when
        none is passed in.

        Each non-hidden property becomes a ``<property>`` element with
        ``name`` and ``type`` attributes; it stays empty when the property
        value is None.
        """
        if doc is None:
            doc = self.new_doc()
        obj_node = doc.createElement('object')

        if obj.id:
            obj_node.setAttribute('id', obj.id)

        obj_node.setAttribute('class', '%s.%s' % (obj.__class__.__module__,
                                                  obj.__class__.__name__))
        doc.documentElement.appendChild(obj_node)
        for prop in obj.properties(hidden=False):
            prop_node = doc.createElement('property')
            prop_node.setAttribute('name', prop.name)
            prop_node.setAttribute('type', prop.type_name)
            value = prop.get_value_for_datastore(obj)
            if value is not None:
                value = self.encode_value(prop, value)
                if isinstance(value, list):
                    self.save_list(doc, value, prop_node)
                elif isinstance(value, Node):
                    prop_node.appendChild(value)
                else:
                    # Non-ASCII characters are dropped from the stored text.
                    text = six.text_type(value).encode("ascii", "ignore")
                    if not isinstance(text, str):
                        # createTextNode needs text, not bytes, on Python 3.
                        text = text.decode("ascii")
                    prop_node.appendChild(doc.createTextNode(text))
            obj_node.appendChild(prop_node)

        return doc

    def _parse_doc(self, fp):
        """Parse XML given either as a string/bytes value or a file object."""
        if isinstance(fp, six.string_types + (bytes,)):
            return parseString(fp)
        return parse(fp)

    def unmarshal_object(self, fp, cls=None, id=None):
        """
        Parse ``fp`` (XML string or file object) and return the object it
        describes; see get_object_from_doc.
        """
        return self.get_object_from_doc(cls, id, self._parse_doc(fp))

    def unmarshal_props(self, fp, cls=None, id=None):
        """
        Same as unmarshalling an object, except it returns
        from "get_props_from_doc"
        """
        return self.get_props_from_doc(cls, id, self._parse_doc(fp))

    def delete_object(self, obj):
        """Issue DELETE ``/<db_name>/<obj.id>`` and return the response."""
        url = "/%s/%s" % (self.db_name, obj.id)
        return self._make_request("DELETE", url)

    # NOTE(review): the four methods below use ``self.domain``, which is not
    # assigned anywhere in this class; unless it is set from outside they
    # raise AttributeError.  Confirm whether they are meant to be supported.

    def set_key_value(self, obj, name, value):
        """Replace the attribute ``name`` of ``obj`` through ``self.domain``."""
        self.domain.put_attributes(obj.id, {name: value}, replace=True)

    def delete_key_value(self, obj, name):
        """Delete the attribute ``name`` of ``obj`` through ``self.domain``."""
        self.domain.delete_attributes(obj.id, name)

    def get_key_value(self, obj, name):
        """
        Return the attribute ``name`` of ``obj`` read through
        ``self.domain``, or None when it is not present.
        """
        attrs = self.domain.get_attributes(obj.id, name)
        if name in attrs:
            return attrs[name]
        return None

    def get_raw_item(self, obj):
        """Return the item for ``obj.id`` read through ``self.domain``."""
        return self.domain.get_item(obj.id)

    def set_property(self, prop, obj, name, value):
        """No-op in this manager."""
        pass

    def get_property(self, prop, obj, name):
        """No-op in this manager; returns None."""
        pass

    def load_object(self, obj):
        """
        Return ``obj`` itself when it is already loaded; otherwise fetch it
        again by id, mark the fetched object as loaded and return that.
        """
        if not obj._loaded:
            obj = obj.get_by_id(obj.id)
            obj._loaded = True
        return obj
518