1#!/usr/bin/env python
2#
3# Copyright (C) 2018 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#      http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18import unittest
19
20try:
21    from unittest import mock
22except ImportError:
23    import mock
24
25from webapp.src import vtslab_status as Status
26from webapp.src.endpoint import host_info
27from webapp.src.proto import model
28from webapp.src.testing import unittest_base
29
30
class HostInfoTest(unittest_base.UnitTestBase):
    """A class to test host_info endpoint API."""

    def _SetDevices(self, api, hostname, devices):
        """Sends one host_info update reporting `devices` for `hostname`.

        Args:
            api: host_info.HostInfoApi instance whose `set` is called.
            hostname: string, the host name placed in the request message.
            devices: list of dicts with "serial" and "product" keys, passed
                as the message's `devices` field.
        """
        container = host_info.HOST_INFO_RESOURCE.combined_message_class(
            hostname=hostname,
            devices=devices,
        )
        api.set(container)

    def _GetOnlyDevice(self):
        """Asserts exactly one DeviceModel entity is stored and returns it."""
        devices = model.DeviceModel.query().fetch()
        self.assertEqual(len(devices), 1)
        return devices[0]

    def testUpdateExistingDevice(self):
        """Asserts that device update does not create a duplicate.

        The same serial is reported three times: first with product "error",
        then with a real product name, then with "error" again. Exactly one
        DeviceModel entity must exist after each update, and a later "error"
        report must not replace a previously stored real product name.
        """
        hostname = self.GetRandomString()
        serial = self.GetRandomString()
        product = self.GetRandomString()
        error_device = {
            "serial": serial,
            "product": "error",
        }
        correct_device = {
            "serial": serial,
            "product": product,
        }
        api = host_info.HostInfoApi()

        # "error" is accepted as the initial product name.
        self._SetDevices(api, hostname, [error_device])
        self.assertEqual(self._GetOnlyDevice().product, "error")

        # A real product name replaces the stored "error" value.
        self._SetDevices(api, hostname, [correct_device])
        self.assertEqual(self._GetOnlyDevice().product, product)

        # A later "error" report is ignored; the real name is kept.
        self._SetDevices(api, hostname, [error_device])
        self.assertEqual(self._GetOnlyDevice().product, product)
90
91
if __name__ == "__main__":
    # Allow running this test module directly as a standalone script.
    unittest.main()
94