1#!/usr/bin/env python
2#
3# Copyright (C) 2018 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#      http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18import endpoints
19import json
20import unittest
21
22try:
23    from unittest import mock
24except ImportError:
25    import mock
26
27from webapp.src import vtslab_status as Status
28from webapp.src.endpoint import endpoint_base
29from webapp.src.proto import model
30from webapp.src.testing import unittest_base
31
32
class EndpointBaseTest(unittest_base.UnitTestBase):
    """A class to test endpoint_base.EndpointBase class.

    Attributes:
        eb: An EndpointBase class instance.
    """

    # Attributes explicitly assigned on a job before calling GetAttributes.
    _ASSIGNED_ATTRS = ["hostname", "priority", "test_branch"]

    # Attributes expected to be reported for a model.JobMessage when
    # assigned_only=False.
    _JOB_MESSAGE_ATTRS = [
        "test_type", "hostname", "priority", "test_name",
        "require_signed_device_build", "has_bootloader_img",
        "has_radio_img", "device", "serial", "build_storage_type",
        "manifest_branch", "build_target", "build_id", "pab_account_id",
        "shards", "param", "status", "period", "gsi_storage_type",
        "gsi_branch", "gsi_build_target", "gsi_build_id",
        "gsi_pab_account_id", "gsi_vendor_version", "test_storage_type",
        "test_branch", "test_build_target", "test_build_id",
        "test_pab_account_id", "retry_count", "infra_log_url",
        "image_package_repo_base", "report_bucket",
        "report_spreadsheet_id", "report_persistent_url",
        "report_reference_url"
    ]

    # model.JobModel is expected to report every JobMessage attribute plus
    # these model-only ones.
    _JOB_MODEL_ATTRS = _JOB_MESSAGE_ATTRS + [
        "timestamp", "heartbeat_stamp", "parent_schedule"
    ]

    def setUp(self):
        """Initializes test"""
        super(EndpointBaseTest, self).setUp()
        self.eb = endpoint_base.EndpointBase()

    def _PutDevices(self, count, **kwargs):
        """Generates and stores `count` DeviceModel entities.

        Args:
            count: int, number of entities to generate and put().
            **kwargs: forwarded to GenerateDeviceModel (e.g. product=...).
        """
        # range (not xrange) so the test also runs under Python 3.
        for _ in range(count):
            self.GenerateDeviceModel(**kwargs).put()

    def _GetDevices(self, size=0, offset=0, filter_string="", sort="",
                    direction=""):
        """Builds a GetRequestMessage and calls EndpointBase.Get on devices.

        Args:
            size: int, page size for the request (the tests here use 0 when
                  they expect every stored entity back).
            offset: int, number of entities to skip.
            filter_string: string, JSON-encoded filter list ("" for none).
            sort: string, attribute name to sort by ("" for none).
            direction: string, sort direction such as "asc" or "desc".

        Returns:
            The (result, more) tuple returned by EndpointBase.Get.
        """
        request_body = endpoints.ResourceContainer(
            model.GetRequestMessage).combined_message_class(
                size=size,
                offset=offset,
                filter=filter_string,
                sort=sort,
                direction=direction,
            )
        return self.eb.Get(
            request=request_body,
            metaclass=model.DeviceModel,
            message=model.DeviceInfoMessage)

    def testGetAssignedMessagesAttributes(self):
        """Asserts only assigned JobMessage attributes are returned."""
        job_message = model.JobMessage()
        for attr in self._ASSIGNED_ATTRS:
            setattr(job_message, attr, attr)
        result = self.eb.GetAttributes(job_message, assigned_only=True)
        self.assertEqual(set(self._ASSIGNED_ATTRS), set(result))

    def testGetAssignedModelAttributes(self):
        """Asserts only assigned JobModel attributes are returned."""
        job = model.JobModel()
        for attr in self._ASSIGNED_ATTRS:
            setattr(job, attr, attr)
        result = self.eb.GetAttributes(job, assigned_only=True)
        self.assertEqual(set(self._ASSIGNED_ATTRS), set(result))

    def testGetAllMessagesAttributes(self):
        """Asserts all known JobMessage attributes are returned."""
        job_message = model.JobMessage()
        for attr in self._ASSIGNED_ATTRS:
            setattr(job_message, attr, attr)
        result = self.eb.GetAttributes(job_message, assigned_only=False)
        # Report exactly which attributes are absent instead of "False is
        # not true".
        missing = set(self._JOB_MESSAGE_ATTRS) - set(result)
        self.assertFalse(missing, "Missing attributes: %s" % sorted(missing))

    def testGetAllModelAttributes(self):
        """Asserts all known JobModel attributes are returned."""
        job = model.JobModel()
        for attr in self._ASSIGNED_ATTRS:
            setattr(job, attr, attr)
        result = self.eb.GetAttributes(job, assigned_only=False)
        missing = set(self._JOB_MODEL_ATTRS) - set(result)
        self.assertFalse(missing, "Missing attributes: %s" % sorted(missing))

    def testGetSingleEntity(self):
        """Asserts to get a single entity."""
        self._PutDevices(1)

        result, more = self._GetDevices()
        self.assertEqual(len(result), 1)
        self.assertFalse(more)

    def testGetHundredEntities(self):
        """Asserts to get hundred entities."""
        self._PutDevices(100)

        result, more = self._GetDevices()
        self.assertEqual(len(result), 100)
        self.assertFalse(more)

    def testGetEntitiesWithPagination(self):
        """Asserts to get entities with pagination."""
        self._PutDevices(100)

        result, more = self._GetDevices(size=60, offset=0)
        self.assertEqual(len(result), 60)
        self.assertTrue(more)

        result, more = self._GetDevices(size=100, offset=60)
        self.assertEqual(len(result), 40)
        self.assertFalse(more)

    def testGetWithFilter(self):
        """Asserts to get entities with filter."""
        self._PutDevices(50)
        self._PutDevices(50, product="product")

        # Named `filters` to avoid shadowing the builtin filter().
        filters = [{
            "key": "product",
            "method": Status.FILTER_METHOD[Status.FILTER_EqualTo],
            "value": "product"
        }]
        result, more = self._GetDevices(filter_string=json.dumps(filters))
        self.assertEqual(len(result), 50)
        self.assertFalse(more)

    def testGetWithSort(self):
        """Asserts to get entities with sort."""
        self._PutDevices(100)

        result, more = self._GetDevices(sort="serial", direction="asc")
        self.assertEqual(len(result), 100)
        self.assertFalse(more)
        # Compare each adjacent pair of results.
        for prev, cur in zip(result, result[1:]):
            self.assertLess(prev["serial"], cur["serial"])

        result, more = self._GetDevices(sort="serial", direction="desc")
        self.assertEqual(len(result), 100)
        self.assertFalse(more)
        for prev, cur in zip(result, result[1:]):
            self.assertGreater(prev["serial"], cur["serial"])
253
254
if __name__ == "__main__":
    # Executed directly (not imported): hand control to the unittest runner,
    # which collects and runs every TestCase defined in this module.
    unittest.main()
257