1#!/usr/bin/env python
2#
3# Copyright (C) 2018 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#      http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18import unittest
19
20try:
21    from unittest import mock
22except ImportError:
23    import mock
24
25from webapp.src.endpoint import lab_info
26from webapp.src.proto import model
27from webapp.src.testing import unittest_base
28
29
class LabInfoTest(unittest_base.UnitTestBase):
    """A class to test lab_info endpoint API."""

    def _BuildLabInfoContainer(self, device_serial, product, device_equipment,
                               hostname):
        """Builds a lab info request holding one host with one device.

        Args:
            device_serial: string, serial of the single reported device.
            product: string, product of the single reported device.
            device_equipment: list of strings, equipment of the device.
            hostname: string, hostname of the single reported host.

        Returns:
            a lab_info.LAB_INFO_RESOURCE combined message. The lab name,
            owner, admin, and the host's ip, script and vtslab_version are
            random strings; the host reports no host equipment.
        """
        device_info = {
            "serial": device_serial,
            "product": product,
            "device_equipment": device_equipment,
        }
        host_info = {
            "hostname": hostname,
            "ip": self.GetRandomString(),
            "script": self.GetRandomString(),
            "device": [device_info],
            "vtslab_version": self.GetRandomString(),
            "host_equipment": [],
        }
        return lab_info.LAB_INFO_RESOURCE.combined_message_class(
            name=self.GetRandomString(),
            owner=self.GetRandomString(),
            admin=[self.GetRandomString()],
            host=[host_info],
        )

    def testUpdateErrorDevice(self):
        """Asserts that re-sending the same lab info does not duplicate a device.

        The stored device's product is overwritten with "error" between two
        identical updates. The second update must neither create another
        device entity nor overwrite the stored product.
        """
        product = self.GetRandomString()
        container = self._BuildLabInfoContainer(
            device_serial=self.GetRandomString(),
            product=product,
            device_equipment=[self.GetRandomString()],
            hostname=self.GetRandomString())

        api = lab_info.LabInfoApi()
        api.set(container)

        devices = model.DeviceModel.query().fetch()
        self.assertEqual(len(devices), 1)
        self.assertEqual(devices[0].product, product)

        # Change the stored device's product.
        devices[0].product = "error"
        devices[0].put()

        api.set(container)

        devices = model.DeviceModel.query().fetch()
        # There should not be duplicates.
        self.assertEqual(len(devices), 1)
        # The stored device product should be kept.
        self.assertEqual(devices[0].product, "error")

    def testUpdateExistingDevice(self):
        """Asserts that updating a pre-stored device does not duplicate it.

        A device with the reported serial and hostname but product "error" is
        stored before the update. Afterwards there must still be exactly one
        device: its stored product is kept and its equipment matches the
        reported equipment.
        """
        device_serial = self.GetRandomString()
        device_equipment = [self.GetRandomString()]
        hostname = self.GetRandomString()
        container = self._BuildLabInfoContainer(
            device_serial=device_serial,
            product=self.GetRandomString(),
            device_equipment=device_equipment,
            hostname=hostname)

        device = self.GenerateDeviceModel(product="error",
                                          serial=device_serial,
                                          hostname=hostname)
        device.put()

        api = lab_info.LabInfoApi()
        api.set(container)

        devices = model.DeviceModel.query().fetch()
        self.assertEqual(len(devices), 1)

        # The stored device product should be kept.
        self.assertEqual(devices[0].product, "error")

        # Device equipment should be updated.
        self.assertEqual(set(devices[0].device_equipment),
                         set(device_equipment))
134
135
if __name__ == "__main__":
    # Load and run every TestCase defined in this module when executed
    # directly (inside this guard, __name__ is "__main__", the default).
    unittest.main(module=__name__)
138