1 /* 2 * Copyright (C) 2023 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package android.adservices.utils; 18 19 import static android.adservices.utils.Scenarios.FAKE_ADDRESS_1; 20 import static android.adservices.utils.Scenarios.FAKE_ADDRESS_2; 21 import static android.adservices.utils.Scenarios.TIMEOUT_SEC; 22 23 import com.android.adservices.LoggerFactory; 24 25 import com.google.common.collect.ImmutableMap; 26 import com.google.common.collect.ImmutableSet; 27 import com.google.mockwebserver.Dispatcher; 28 import com.google.mockwebserver.MockResponse; 29 import com.google.mockwebserver.MockWebServer; 30 import com.google.mockwebserver.RecordedRequest; 31 32 import org.json.JSONException; 33 import org.testng.util.Strings; 34 35 import java.io.IOException; 36 import java.net.MalformedURLException; 37 import java.net.URL; 38 import java.util.Map; 39 import java.util.Objects; 40 import java.util.concurrent.ConcurrentHashMap; 41 import java.util.concurrent.CountDownLatch; 42 import java.util.concurrent.TimeUnit; 43 44 /** 45 * Helper class for running test scenarios. 46 * 47 * <p>The scenario files are stored in assets/data/scenarios/*.json as well as any supporting files 48 * in the data folder. Each scenario defines a set of request / response pairings which are used to 49 * configure a {@link MockWebServer} instance. This class is thread-local safe. 50 */ 51 public class ScenarioDispatcher extends Dispatcher { 52 53 private static final LoggerFactory.Logger sLogger = LoggerFactory.getFledgeLogger(); 54 public static final String X_FLEDGE_BUYER_BIDDING_LOGIC_VERSION = 55 "x_fledge_buyer_bidding_logic_version"; 56 57 private final ImmutableMap<Scenario.Request, Scenario.MockResponse> mRequestToMockMap; 58 private final String mPrefix; 59 private final ImmutableMap<String, String> mSubstitutionVariables; 60 61 // The value is not used and we always insert a fixed value of 1. 62 private final ConcurrentHashMap<String, Integer> mCalledPaths; 63 64 private final CountDownLatch mUniqueCallCount; 65 private final URL mServerBaseURL; 66 67 /** 68 * Get all paths of calls to this server that were expected. 69 * 70 * <p>These are defined by the `verify_called` and `verify_not_called` fields in the test 71 * scenario JSON files. 72 * 73 * @return String list of paths. 74 */ getVerifyCalledPaths()75 public ImmutableSet<String> getVerifyCalledPaths() { 76 ImmutableSet.Builder<String> builder = ImmutableSet.builder(); 77 mRequestToMockMap.forEach( 78 (s, mock) -> { 79 if (mock.getShouldVerifyCalled()) { 80 builder.add("/" + s.getRelativePath()); 81 } 82 }); 83 return builder.build(); 84 } 85 86 /** 87 * Get all paths of calls to this server that were NOT expected. 88 * 89 * @return String list of paths. 90 */ getVerifyNotCalledPaths()91 public ImmutableSet<String> getVerifyNotCalledPaths() { 92 ImmutableSet.Builder<String> builder = ImmutableSet.builder(); 93 mRequestToMockMap.forEach( 94 (s, mock) -> { 95 if (mock.getShouldVerifyNotCalled()) { 96 builder.add("/" + s.getRelativePath()); 97 } 98 }); 99 return builder.build(); 100 } 101 102 /** 103 * Return the base URL for the server. 104 * 105 * @return base URL for the server. 106 */ getBaseAddressWithPrefix()107 public URL getBaseAddressWithPrefix() { 108 try { 109 return new URL(mServerBaseURL + mPrefix); 110 } catch (MalformedURLException e) { 111 throw new RuntimeException(e); 112 } 113 } 114 115 /** 116 * Get all paths of calls to this server that were expected. 117 * 118 * <p>These are defined by the `verify_called` and `verify_not_called` fields in the test 119 * scenario JSON files. 120 * 121 * @return String list of paths. 122 */ getCalledPaths()123 public ImmutableSet<String> getCalledPaths() throws InterruptedException { 124 sLogger.w("getCalledPaths() called"); 125 if (!mUniqueCallCount.await(TIMEOUT_SEC, TimeUnit.SECONDS)) { 126 sLogger.w("Timeout reached in getCalledPaths()"); 127 } 128 sLogger.w("getCalledPaths() returning with size: %s", mCalledPaths.size()); 129 return ImmutableSet.copyOf(mCalledPaths.keySet()); 130 } 131 ScenarioDispatcher(String scenarioPath, String prefix, URL serverBaseAddress)132 ScenarioDispatcher(String scenarioPath, String prefix, URL serverBaseAddress) 133 throws JSONException, IOException { 134 mPrefix = prefix; 135 sLogger.v(String.format("Setting up scenario with file: %s", scenarioPath)); 136 mCalledPaths = new ConcurrentHashMap<>(); 137 mServerBaseURL = serverBaseAddress; 138 // Needs HTTPS for real tests and HTTP for ScenarioDispatcher tests. 139 mSubstitutionVariables = 140 ImmutableMap.of( 141 "{base_url_with_prefix}", 142 getBaseAddressWithPrefix().toString(), 143 "{adtech1_url}", 144 getBaseAddressWithPrefix().toString(), 145 "{adtech2_url}", 146 FAKE_ADDRESS_1 + mPrefix, 147 "{adtech3_url}", 148 FAKE_ADDRESS_2 + mPrefix); 149 mRequestToMockMap = 150 ScenarioLoader.load(scenarioPath, mSubstitutionVariables).getScenarioMap(); 151 mUniqueCallCount = new CountDownLatch(mRequestToMockMap.size()); 152 } 153 154 @Override dispatch(RecordedRequest request)155 public MockResponse dispatch(RecordedRequest request) throws InterruptedException { 156 boolean hasSetBaseAddress = mServerBaseURL != null; 157 if (!hasSetBaseAddress) { 158 throw new IllegalStateException( 159 "Cannot serve request as setServerBaseAddress() has not been called."); 160 } 161 162 String path = pathWithoutPrefix(request.getPath()); 163 164 for (Scenario.Request mockRequest : mRequestToMockMap.keySet()) { 165 String mockPath = mockRequest.getRelativePath(); 166 if (isMatchingPath(request.getPath(), mockPath)) { 167 Scenario.MockResponse mock = mRequestToMockMap.get(mockRequest); 168 Scenario.Response mockResponse = Objects.requireNonNull(mock).getDefaultResponse(); 169 String body = mockResponse.getBody(); 170 for (Map.Entry<String, String> keyValuePair : mSubstitutionVariables.entrySet()) { 171 body = body.replace(keyValuePair.getKey(), keyValuePair.getValue()); 172 } 173 174 // Sleep if necessary. Will default to 0 if not provided. 175 Thread.sleep(mockResponse.getDelaySeconds() * 1000L); 176 177 // If the mock path specifically has query params, then add that, otherwise strip 178 // them before adding them to the log. 179 // This behaviour matches the existing test server functionality. 180 recordCalledPath( 181 String.format( 182 "/%s", 183 hasQueryParams(mockPath) 184 ? mockPath 185 : pathWithoutQueryParams(path))); 186 MockResponse response = new MockResponse().setBody(body).setResponseCode(200); 187 for (Map.Entry<String, String> mockHeader : mockResponse.getHeaders().entrySet()) { 188 sLogger.v( 189 "Adding header %s with value %s", 190 mockHeader.getKey(), mockHeader.getValue()); 191 response.addHeader(mockHeader.getKey(), mockHeader.getValue()); 192 } 193 sLogger.v("serving path at %s with response %s", path, response.toString()); 194 return response; 195 } 196 } 197 198 // For any requests that weren't specifically overloaded with query params to be handled, 199 // always strip them when adding them to the log. 200 // This behaviour matches the existing test server functionality. 201 recordCalledPath("/" + pathWithoutQueryParams(path)); 202 sLogger.v("serving path at %s (404)", path); 203 return new MockResponse().setResponseCode(404); 204 } 205 206 recordCalledPath(String path)207 private synchronized void recordCalledPath(String path) { 208 if (mCalledPaths.containsKey(path)) { 209 sLogger.v( 210 "Not recording path called at %s as already hit, latch count is %d/%d.", 211 path, mUniqueCallCount.getCount(), mRequestToMockMap.size()); 212 } else { 213 mCalledPaths.put(path, 1); 214 mUniqueCallCount.countDown(); 215 sLogger.v( 216 "Recorded path called at %s, latch count is %d/%d.", 217 path, mUniqueCallCount.getCount(), mRequestToMockMap.size()); 218 } 219 } 220 isMatchingPath(String path, String mockPath)221 private boolean isMatchingPath(String path, String mockPath) { 222 return pathWithoutQueryParams(pathWithoutPrefix(path)).equals(mockPath) 223 || pathWithoutPrefix(path).equals(mockPath); 224 } 225 hasQueryParams(String path)226 private boolean hasQueryParams(String path) { 227 return path.contains("?"); 228 } 229 pathWithoutQueryParams(String path)230 private String pathWithoutQueryParams(String path) { 231 return path.split("\\?")[0]; 232 } 233 pathWithoutPrefix(String path)234 private String pathWithoutPrefix(String path) { 235 if (Strings.isNullOrEmpty(mPrefix)) { 236 // Only remove the first redundant "/" if no prefix is explicitly defined. 237 return path.substring(1); 238 } 239 return path.replaceFirst(mPrefix + "/", ""); 240 } 241 } 242