1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.platform.test.flag.junit.host;
18 
import android.platform.test.flag.junit.CheckFlagsRule;
import android.platform.test.flag.junit.IFlagsValueProvider;
import android.platform.test.flag.util.FlagReadException;

import com.android.tradefed.device.ITestDevice;

import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.function.Supplier;
33 
/** An {@code IFlagsValueProvider} which provides flag values from the host side. */
35 public class HostFlagsValueProvider implements IFlagsValueProvider {
36     /** The key is the device serial number. */
37     private static final LoadingCache<String, DeviceFlags> CACHED_DEVICE_FLAGS;
38 
39     /** The key is the device serial number. */
40     private static final Map<String, ITestDevice> TEST_DEVICES = new HashMap<>();
41 
42     static {
43         CacheLoader<String, DeviceFlags> cacheLoader =
44                 new CacheLoader<String, DeviceFlags>() {
45                     @Override
46                     public DeviceFlags load(String deviceSerial) throws FlagReadException {
47                         if (!TEST_DEVICES.containsKey(deviceSerial)) {
48                             throw new IllegalStateException(
49                                     String.format(
50                                             "No ITestDevice found for serial %s.", deviceSerial));
51                         }
52                         return DeviceFlags.createDeviceFlags(TEST_DEVICES.get(deviceSerial));
53                     }
54                 };
55 
56         CACHED_DEVICE_FLAGS = CacheBuilder.newBuilder().build(cacheLoader);
57     }
58 
59     private final Supplier<ITestDevice> mTestDeviceSupplier;
60     private DeviceFlags mDeviceFlags;
61 
HostFlagsValueProvider(Supplier<ITestDevice> testDeviceSupplier)62     HostFlagsValueProvider(Supplier<ITestDevice> testDeviceSupplier) {
63         mTestDeviceSupplier = testDeviceSupplier;
64     }
65 
createCheckFlagsRule(Supplier<ITestDevice> testDeviceSupplier)66     public static CheckFlagsRule createCheckFlagsRule(Supplier<ITestDevice> testDeviceSupplier) {
67         return new CheckFlagsRule(new HostFlagsValueProvider(testDeviceSupplier));
68     }
69 
70     /**
71      * Refreshes all flag values of a given device. Must be called when the device flags have been
72      * changed.
73      */
refreshFlagsCache(String serial)74     public static void refreshFlagsCache(String serial) throws FlagReadException {
75         Map<String, DeviceFlags> cachedDeviceFlagsMap = CACHED_DEVICE_FLAGS.asMap();
76         if (cachedDeviceFlagsMap.containsKey(serial)) {
77             cachedDeviceFlagsMap.get(serial).init(TEST_DEVICES.get(serial));
78         }
79     }
80 
81     @Override
setUp()82     public void setUp() throws FlagReadException {
83         ITestDevice testDevice = mTestDeviceSupplier.get();
84         TEST_DEVICES.put(testDevice.getSerialNumber(), testDevice);
85 
86         try {
87             mDeviceFlags = CACHED_DEVICE_FLAGS.get(testDevice.getSerialNumber());
88         } catch (ExecutionException e) {
89             throw new FlagReadException("ALL_FLAGS", e);
90         }
91     }
92 
93     @Override
getBoolean(String flag)94     public boolean getBoolean(String flag) throws FlagReadException {
95         String value = mDeviceFlags.getFlagValue(flag);
96 
97         if (value == null) {
98             throw new FlagReadException(flag, "Flag does not exist.");
99         }
100 
101         if (!IFlagsValueProvider.isBooleanValue(value)) {
102             throw new FlagReadException(
103                     flag, String.format("Flag value %s is not a boolean", value));
104         }
105         return Boolean.valueOf(value);
106     }
107 }
108