1import operator, unittest
2import json
3from django.test import client
4from autotest_lib.frontend.afe import frontend_test_utils, models as afe_models
5
class ResourceTestCase(unittest.TestCase,
                       frontend_test_utils.FrontendTestMixin):
    """Base class for tests that exercise resources through Django's client.

    setUp() runs the common frontend fixture setup, creates a 'debug_user'
    that is a member of the 'my_acl' ACL group, and creates a Django test
    client. The helpers below issue requests (JSON-encoding request data by
    default), check the response status, and compare collection contents.
    """
    URI_PREFIX = None # subclasses may override this to use partial URIs

    def setUp(self):
        super(ResourceTestCase, self).setUp()
        self._frontend_common_setup()
        self._setup_debug_user()
        self.client = client.Client()


    def tearDown(self):
        # Release our own fixtures first (reverse of setUp), and make sure the
        # base-class teardown still runs if that fails.
        try:
            self._frontend_common_teardown()
        finally:
            super(ResourceTestCase, self).tearDown()


    def _setup_debug_user(self):
        """Create the 'debug_user' account and add it to the 'my_acl' group.

        'my_acl' is fetched with .get(), so it must already exist when this
        runs; presumably it is created by _frontend_common_setup() -- confirm
        in frontend_test_utils.
        """
        user = afe_models.User.objects.create(login='debug_user')
        acl = afe_models.AclGroup.objects.get(name='my_acl')
        user.aclgroup_set.add(acl)


    def _expected_status(self, method):
        """Return the status code a successful request with `method` yields.

        @param method: HTTP method name, case-insensitive (raw_request()
                accepts any case, so this must too).
        @returns 201 for POST, 204 for DELETE, 200 for anything else.
        """
        method = method.lower()
        if method == 'post':
            return 201
        if method == 'delete':
            return 204
        return 200


    def raw_request(self, method, uri, **kwargs):
        """Send a request through the test client and return the response.

        No status checking or body decoding is done here.

        @param method: HTTP method name, case-insensitive.
        @param uri: URI passed unchanged to the test client.
        @param kwargs: forwarded to the Django test client method.
        @returns the test client's response object.
        """
        method = method.lower()
        if method == 'put':
            # the put() implementation in Django's test client is poorly
            # implemented and only supports url-encoded keyvals for the data.
            # the post() implementation is correct, though, so use that, with a
            # trick to override the method.
            method = 'post'
            kwargs['REQUEST_METHOD'] = 'PUT'

        client_method = getattr(self.client, method)
        return client_method(uri, **kwargs)


    def request(self, method, uri, encode_body=True, **kwargs):
        """Send a request, check its status and decode a JSON response body.

        @param method: HTTP method name, case-insensitive.
        @param uri: an absolute URI (http:// or https://) or a path relative
                to URI_PREFIX.
        @param encode_body: currently ignored -- request data is JSON-encoded
                whenever the content type is application/json. Kept so
                existing callers that pass it keep working.
        @param kwargs: forwarded to raw_request(). If 'data' is given, the
                content type defaults to application/json, and data sent with
                that content type is serialized with json.dumps().
        @returns the decoded JSON body when the response media type is
                application/json, otherwise the raw response content.
        @raises AssertionError (via self.fail) if the status code differs
                from the one expected for `method`, if a relative URI is used
                while URI_PREFIX is unset, or if a JSON body cannot be parsed.
        """
        expected_status = self._expected_status(method)

        if 'data' in kwargs:
            kwargs.setdefault('content_type', 'application/json')
            if kwargs['content_type'] == 'application/json':
                kwargs['data'] = json.dumps(kwargs['data'])

        if uri.startswith(('http://', 'https://')):
            full_uri = uri
        else:
            # Explicit check instead of `assert`, which is stripped under -O.
            if not self.URI_PREFIX:
                self.fail('URI_PREFIX must be set to request relative URI %r'
                          % uri)
            full_uri = self.URI_PREFIX + '/' + uri

        response = self.raw_request(method, full_uri, **kwargs)
        # Build the diagnostic only on failure, using the public items() API
        # rather than the private (and, in newer Django, removed) _headers.
        if response.status_code != expected_status:
            self.fail('Requesting %s\nExpected %s, got %s: %s (headers: %s)'
                      % (full_uri, expected_status, response.status_code,
                         response.content, list(response.items())))

        # Compare only the media type so parameters such as
        # "; charset=utf-8" don't prevent decoding; a missing header means
        # the body is returned undecoded.
        content_type = response.get('Content-Type', '') or ''
        media_type = content_type.split(';', 1)[0].strip().lower()
        if media_type != 'application/json':
            return response.content

        try:
            return json.loads(response.content)
        except ValueError:
            self.fail('Invalid response body: %s' % response.content)


    def sorted_by(self, collection, attribute):
        """Return `collection` sorted by each item's `attribute` key."""
        return sorted(collection, key=operator.itemgetter(attribute))


    def _read_attribute(self, item, attribute_or_list):
        """Read item[a1][a2]... for a single key or a list of keys.

        @param item: the object to index into.
        @param attribute_or_list: a single key (string) or an iterable of keys
                applied in order.
        @returns the value reached after applying every key.
        """
        try:
            string_types = basestring  # Python 2: str and unicode
        except NameError:
            string_types = str  # Python 3
        if isinstance(attribute_or_list, string_types):
            attribute_or_list = [attribute_or_list]
        for attribute in attribute_or_list:
            item = item[attribute]
        return item


    def check_collection(self, collection, attribute_or_list, expected_list,
                         length=None, check_number=None):
        """Check the members of a collection of dicts.

        @param collection: a dict whose 'members' entry is an iterable of dicts
        @param attribute_or_list: an attribute or list of attributes to read.
                the results will be sorted and compared with expected_list. if
                a list of attributes is given, the attributes will be read
                hierarchically, i.e. item[attribute1][attribute2]...
        @param expected_list: list of expected values, in sorted order
        @param length: expected length of list, only necessary if check_number
                is given
        @param check_number: if given, only check this number of entries
                (the first ones after sorting)
        """
        actual_list = sorted(self._read_attribute(item, attribute_or_list)
                             for item in collection['members'])
        if length is None and check_number is None:
            length = len(expected_list)
        if length is not None:
            self.assertEqual(len(actual_list), length,
                             'Expected %s, got %s: %s'
                             % (length, len(actual_list),
                                ', '.join(str(item) for item in actual_list)))
        if check_number:
            actual_list = actual_list[:check_number]
        self.assertEqual(actual_list, expected_list)


    def check_relationship(self, resource_uri, relationship_name,
                           other_entry_name, field, expected_values,
                           length=None, check_number=None):
        """Check the members of a relationship collection.

        Fetches the base resource, follows the 'href' of its
        `relationship_name` entry, and checks the resulting collection with
        check_collection().

        @param resource_uri: URI of base resource
        @param relationship_name: name of relationship attribute on base
                resource
        @param other_entry_name: name of other entry in relationship
        @param field: name of field to grab on other entry
        @param expected_values: list of expected values for the given field
        @param length: passed through to check_collection()
        @param check_number: passed through to check_collection()
        """
        response = self.request('get', resource_uri)
        relationship_uri = response[relationship_name]['href']
        relationships = self.request('get', relationship_uri)
        self.check_collection(relationships, [other_entry_name, field],
                              expected_values, length, check_number)
