1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.adservices.utils;
18 
19 import static android.adservices.utils.Scenarios.FAKE_ADDRESS_1;
20 import static android.adservices.utils.Scenarios.FAKE_ADDRESS_2;
21 import static android.adservices.utils.Scenarios.TIMEOUT_SEC;
22 
23 import com.android.adservices.LoggerFactory;
24 
25 import com.google.common.collect.ImmutableMap;
26 import com.google.common.collect.ImmutableSet;
27 import com.google.mockwebserver.Dispatcher;
28 import com.google.mockwebserver.MockResponse;
29 import com.google.mockwebserver.MockWebServer;
30 import com.google.mockwebserver.RecordedRequest;
31 
32 import org.json.JSONException;
33 import org.testng.util.Strings;
34 
35 import java.io.IOException;
36 import java.net.MalformedURLException;
37 import java.net.URL;
38 import java.util.Map;
39 import java.util.Objects;
40 import java.util.concurrent.ConcurrentHashMap;
41 import java.util.concurrent.CountDownLatch;
42 import java.util.concurrent.TimeUnit;
43 
44 /**
45  * Helper class for running test scenarios.
46  *
47  * <p>The scenario files are stored in assets/data/scenarios/*.json as well as any supporting files
48  * in the data folder. Each scenario defines a set of request / response pairings which are used to
49  * configure a {@link MockWebServer} instance. This class is thread-local safe.
50  */
51 public class ScenarioDispatcher extends Dispatcher {
52 
53     private static final LoggerFactory.Logger sLogger = LoggerFactory.getFledgeLogger();
54     public static final String X_FLEDGE_BUYER_BIDDING_LOGIC_VERSION =
55             "x_fledge_buyer_bidding_logic_version";
56 
57     private final ImmutableMap<Scenario.Request, Scenario.MockResponse> mRequestToMockMap;
58     private final String mPrefix;
59     private final ImmutableMap<String, String> mSubstitutionVariables;
60 
61     // The value is not used and we always insert a fixed value of 1.
62     private final ConcurrentHashMap<String, Integer> mCalledPaths;
63 
64     private final CountDownLatch mUniqueCallCount;
65     private final URL mServerBaseURL;
66 
    /**
     * Get all paths of calls to this server that are expected to be made.
     *
     * <p>These are the scenario requests whose mock response is flagged as "should verify
     * called" (the `verify_called` field in the test scenario JSON files). Each path is
     * returned with a leading "/".
     *
     * @return set of paths.
     */
getVerifyCalledPaths()75     public ImmutableSet<String> getVerifyCalledPaths() {
76         ImmutableSet.Builder<String> builder = ImmutableSet.builder();
77         mRequestToMockMap.forEach(
78                 (s, mock) -> {
79                     if (mock.getShouldVerifyCalled()) {
80                         builder.add("/" + s.getRelativePath());
81                     }
82                 });
83         return builder.build();
84     }
85 
    /**
     * Get all paths of calls to this server that are expected NOT to be made.
     *
     * <p>Each path is returned with a leading "/".
     *
     * @return set of paths.
     */
getVerifyNotCalledPaths()91     public ImmutableSet<String> getVerifyNotCalledPaths() {
92         ImmutableSet.Builder<String> builder = ImmutableSet.builder();
93         mRequestToMockMap.forEach(
94                 (s, mock) -> {
95                     if (mock.getShouldVerifyNotCalled()) {
96                         builder.add("/" + s.getRelativePath());
97                     }
98                 });
99         return builder.build();
100     }
101 
    /**
     * Return the base URL for the server with the scenario prefix appended.
     *
     * @return base URL for the server, followed by the prefix.
     */
getBaseAddressWithPrefix()107     public URL getBaseAddressWithPrefix() {
108         try {
109             return new URL(mServerBaseURL + mPrefix);
110         } catch (MalformedURLException e) {
111             throw new RuntimeException(e);
112         }
113     }
114 
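    /**
     * Get all paths that have been requested from this server so far.
     *
     * <p>Waits up to {@code TIMEOUT_SEC} seconds for the unique-call latch to reach zero before
     * returning. If the timeout is reached, a warning is logged and the paths recorded so far
     * are returned.
     *
     * @return set of recorded paths.
     * @throws InterruptedException if the calling thread is interrupted while waiting.
     */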
getCalledPaths()123     public ImmutableSet<String> getCalledPaths() throws InterruptedException {
124         sLogger.w("getCalledPaths() called");
125         if (!mUniqueCallCount.await(TIMEOUT_SEC, TimeUnit.SECONDS)) {
126             sLogger.w("Timeout reached in getCalledPaths()");
127         }
128         sLogger.w("getCalledPaths() returning with  size: %s", mCalledPaths.size());
129         return ImmutableSet.copyOf(mCalledPaths.keySet());
130     }
131 
ScenarioDispatcher(String scenarioPath, String prefix, URL serverBaseAddress)132     ScenarioDispatcher(String scenarioPath, String prefix, URL serverBaseAddress)
133             throws JSONException, IOException {
134         mPrefix = prefix;
135         sLogger.v(String.format("Setting up scenario with file: %s", scenarioPath));
136         mCalledPaths = new ConcurrentHashMap<>();
137         mServerBaseURL = serverBaseAddress;
138         // Needs HTTPS for real tests and HTTP for ScenarioDispatcher tests.
139         mSubstitutionVariables =
140                 ImmutableMap.of(
141                         "{base_url_with_prefix}",
142                         getBaseAddressWithPrefix().toString(),
143                         "{adtech1_url}",
144                         getBaseAddressWithPrefix().toString(),
145                         "{adtech2_url}",
146                         FAKE_ADDRESS_1 + mPrefix,
147                         "{adtech3_url}",
148                         FAKE_ADDRESS_2 + mPrefix);
149         mRequestToMockMap =
150                 ScenarioLoader.load(scenarioPath, mSubstitutionVariables).getScenarioMap();
151         mUniqueCallCount = new CountDownLatch(mRequestToMockMap.size());
152     }
153 
154     @Override
dispatch(RecordedRequest request)155     public MockResponse dispatch(RecordedRequest request) throws InterruptedException {
156         boolean hasSetBaseAddress = mServerBaseURL != null;
157         if (!hasSetBaseAddress) {
158             throw new IllegalStateException(
159                     "Cannot serve request as setServerBaseAddress() has not been called.");
160         }
161 
162         String path = pathWithoutPrefix(request.getPath());
163 
164         for (Scenario.Request mockRequest : mRequestToMockMap.keySet()) {
165             String mockPath = mockRequest.getRelativePath();
166             if (isMatchingPath(request.getPath(), mockPath)) {
167                 Scenario.MockResponse mock = mRequestToMockMap.get(mockRequest);
168                 Scenario.Response mockResponse = Objects.requireNonNull(mock).getDefaultResponse();
169                 String body = mockResponse.getBody();
170                 for (Map.Entry<String, String> keyValuePair : mSubstitutionVariables.entrySet()) {
171                     body = body.replace(keyValuePair.getKey(), keyValuePair.getValue());
172                 }
173 
174                 // Sleep if necessary. Will default to 0 if not provided.
175                 Thread.sleep(mockResponse.getDelaySeconds() * 1000L);
176 
177                 // If the mock path specifically has query params, then add that, otherwise strip
178                 // them before adding them to the log.
179                 // This behaviour matches the existing test server functionality.
180                 recordCalledPath(
181                         String.format(
182                                 "/%s",
183                                 hasQueryParams(mockPath)
184                                         ? mockPath
185                                         : pathWithoutQueryParams(path)));
186                 MockResponse response = new MockResponse().setBody(body).setResponseCode(200);
187                 for (Map.Entry<String, String> mockHeader : mockResponse.getHeaders().entrySet()) {
188                     sLogger.v(
189                             "Adding header %s with value %s",
190                             mockHeader.getKey(), mockHeader.getValue());
191                     response.addHeader(mockHeader.getKey(), mockHeader.getValue());
192                 }
193                 sLogger.v("serving path at %s with response %s", path, response.toString());
194                 return response;
195             }
196         }
197 
198         // For any requests that weren't specifically overloaded with query params to be handled,
199         // always strip them when adding them to the log.
200         // This behaviour matches the existing test server functionality.
201         recordCalledPath("/" + pathWithoutQueryParams(path));
202         sLogger.v("serving path at %s (404)", path);
203         return new MockResponse().setResponseCode(404);
204     }
205 
206 
recordCalledPath(String path)207     private synchronized void recordCalledPath(String path) {
208         if (mCalledPaths.containsKey(path)) {
209             sLogger.v(
210                     "Not recording path called at %s as already hit, latch count is %d/%d.",
211                     path, mUniqueCallCount.getCount(), mRequestToMockMap.size());
212         } else {
213             mCalledPaths.put(path, 1);
214             mUniqueCallCount.countDown();
215             sLogger.v(
216                     "Recorded path called at %s, latch count is %d/%d.",
217                     path, mUniqueCallCount.getCount(), mRequestToMockMap.size());
218         }
219     }
220 
isMatchingPath(String path, String mockPath)221     private boolean isMatchingPath(String path, String mockPath) {
222         return pathWithoutQueryParams(pathWithoutPrefix(path)).equals(mockPath)
223                 || pathWithoutPrefix(path).equals(mockPath);
224     }
225 
hasQueryParams(String path)226     private boolean hasQueryParams(String path) {
227         return path.contains("?");
228     }
229 
pathWithoutQueryParams(String path)230     private String pathWithoutQueryParams(String path) {
231         return path.split("\\?")[0];
232     }
233 
pathWithoutPrefix(String path)234     private String pathWithoutPrefix(String path) {
235         if (Strings.isNullOrEmpty(mPrefix)) {
236             // Only remove the first redundant "/" if no prefix is explicitly defined.
237             return path.substring(1);
238         }
239         return path.replaceFirst(mPrefix + "/", "");
240     }
241 }
242