1 /* 2 * Copyright (c) 2017 Mockito contributors 3 * This program is made available under the terms of the MIT License. 4 */ 5 package org.mockitoutil; 6 7 import java.io.ByteArrayInputStream; 8 import java.io.File; 9 import java.io.IOException; 10 import java.io.InputStream; 11 import java.lang.reflect.Field; 12 import java.lang.reflect.Modifier; 13 import java.net.MalformedURLException; 14 import java.net.URI; 15 import java.net.URISyntaxException; 16 import java.net.URL; 17 import java.net.URLClassLoader; 18 import java.net.URLConnection; 19 import java.net.URLStreamHandler; 20 import java.util.ArrayList; 21 import java.util.Arrays; 22 import java.util.Collections; 23 import java.util.Enumeration; 24 import java.util.HashMap; 25 import java.util.HashSet; 26 import java.util.Iterator; 27 import java.util.List; 28 import java.util.Map; 29 import java.util.Set; 30 import java.util.concurrent.ExecutionException; 31 import java.util.concurrent.ExecutorService; 32 import java.util.concurrent.Executors; 33 import java.util.concurrent.Future; 34 import java.util.concurrent.ThreadFactory; 35 import org.objenesis.Objenesis; 36 import org.objenesis.ObjenesisStd; 37 import org.objenesis.instantiator.ObjectInstantiator; 38 39 import static java.lang.String.format; 40 import static java.util.Arrays.asList; 41 42 public abstract class ClassLoaders { 43 protected ClassLoader parent = currentClassLoader(); 44 ClassLoaders()45 protected ClassLoaders() { 46 } 47 isolatedClassLoader()48 public static IsolatedURLClassLoaderBuilder isolatedClassLoader() { 49 return new IsolatedURLClassLoaderBuilder(); 50 } 51 excludingClassLoader()52 public static ExcludingURLClassLoaderBuilder excludingClassLoader() { 53 return new ExcludingURLClassLoaderBuilder(); 54 } 55 inMemoryClassLoader()56 public static InMemoryClassLoaderBuilder inMemoryClassLoader() { 57 return new InMemoryClassLoaderBuilder(); 58 } 59 in(ClassLoader classLoader)60 public static ReachableClassesFinder in(ClassLoader classLoader) { 61 return new ReachableClassesFinder(classLoader); 62 } 63 jdkClassLoader()64 public static ClassLoader jdkClassLoader() { 65 return String.class.getClassLoader(); 66 } 67 systemClassLoader()68 public static ClassLoader systemClassLoader() { 69 return ClassLoader.getSystemClassLoader(); 70 } 71 currentClassLoader()72 public static ClassLoader currentClassLoader() { 73 return ClassLoaders.class.getClassLoader(); 74 } 75 build()76 public abstract ClassLoader build(); 77 coverageTool()78 public static Class<?>[] coverageTool() { 79 HashSet<Class<?>> classes = new HashSet<Class<?>>(); 80 classes.add(safeGetClass("net.sourceforge.cobertura.coveragedata.TouchCollector")); 81 classes.add(safeGetClass("org.slf4j.LoggerFactory")); 82 83 classes.remove(null); 84 return classes.toArray(new Class<?>[classes.size()]); 85 } 86 safeGetClass(String className)87 private static Class<?> safeGetClass(String className) { 88 try { 89 return Class.forName(className); 90 } catch (ClassNotFoundException e) { 91 return null; 92 } 93 } 94 using(final ClassLoader classLoader)95 public static ClassLoaderExecutor using(final ClassLoader classLoader) { 96 return new ClassLoaderExecutor(classLoader); 97 } 98 99 public static class ClassLoaderExecutor { 100 private ClassLoader classLoader; 101 ClassLoaderExecutor(ClassLoader classLoader)102 public ClassLoaderExecutor(ClassLoader classLoader) { 103 this.classLoader = classLoader; 104 } 105 execute(final Runnable task)106 public void execute(final Runnable task) throws Exception { 107 ExecutorService executorService = Executors.newSingleThreadExecutor(new ThreadFactory() { 108 @Override 109 public Thread newThread(Runnable r) { 110 Thread thread = Executors.defaultThreadFactory().newThread(r); 111 thread.setContextClassLoader(classLoader); 112 return thread; 113 } 114 }); 115 try { 116 Future<?> taskFuture = executorService.submit(new Runnable() { 117 @Override 118 public void run() { 119 try { 120 reloadTaskInClassLoader(task).run(); 121 } catch (Throwable throwable) { 122 throw new IllegalStateException(format("Given task could not be loaded properly in the given classloader '%s', error '%s", 123 task, 124 throwable.getMessage()), 125 throwable); 126 } 127 } 128 }); 129 taskFuture.get(); 130 executorService.shutdownNow(); 131 } catch (InterruptedException e) { 132 Thread.currentThread().interrupt(); 133 } catch (ExecutionException e) { 134 throw this.<Exception>unwrapAndThrows(e); 135 } 136 } 137 138 @SuppressWarnings("unchecked") unwrapAndThrows(ExecutionException ex)139 private <T extends Throwable> T unwrapAndThrows(ExecutionException ex) throws T { 140 throw (T) ex.getCause(); 141 } 142 reloadTaskInClassLoader(Runnable task)143 Runnable reloadTaskInClassLoader(Runnable task) { 144 try { 145 @SuppressWarnings("unchecked") 146 Class<Runnable> taskClassReloaded = (Class<Runnable>) classLoader.loadClass(task.getClass().getName()); 147 148 Objenesis objenesis = new ObjenesisStd(); 149 ObjectInstantiator<Runnable> thingyInstantiator = objenesis.getInstantiatorOf(taskClassReloaded); 150 Runnable reloaded = thingyInstantiator.newInstance(); 151 152 // lenient shallow copy of class compatible fields 153 for (Field field : task.getClass().getDeclaredFields()) { 154 Field declaredField = taskClassReloaded.getDeclaredField(field.getName()); 155 int modifiers = declaredField.getModifiers(); 156 if(Modifier.isStatic(modifiers) && Modifier.isFinal(modifiers)) { 157 // Skip static final fields (e.g. jacoco fields) 158 // otherwise IllegalAccessException (can be bypassed with Unsafe though) 159 // We may also miss coverage data. 160 continue; 161 } 162 if (declaredField.getType() == field.getType()) { // don't copy this 163 field.setAccessible(true); 164 declaredField.setAccessible(true); 165 declaredField.set(reloaded, field.get(task)); 166 } 167 } 168 169 return reloaded; 170 } catch (ClassNotFoundException e) { 171 throw new IllegalStateException(e); 172 } catch (IllegalAccessException e) { 173 throw new IllegalStateException(e); 174 } catch (NoSuchFieldException e) { 175 throw new IllegalStateException(e); 176 } 177 } 178 } 179 180 public static class IsolatedURLClassLoaderBuilder extends ClassLoaders { 181 private final ArrayList<String> excludedPrefixes = new ArrayList<String>(); 182 private final ArrayList<String> privateCopyPrefixes = new ArrayList<String>(); 183 private final ArrayList<URL> codeSourceUrls = new ArrayList<URL>(); 184 withPrivateCopyOf(String... privatePrefixes)185 public IsolatedURLClassLoaderBuilder withPrivateCopyOf(String... privatePrefixes) { 186 privateCopyPrefixes.addAll(asList(privatePrefixes)); 187 return this; 188 } 189 withCodeSourceUrls(String... urls)190 public IsolatedURLClassLoaderBuilder withCodeSourceUrls(String... urls) { 191 codeSourceUrls.addAll(pathsToURLs(urls)); 192 return this; 193 } 194 withCodeSourceUrlOf(Class<?>.... classes)195 public IsolatedURLClassLoaderBuilder withCodeSourceUrlOf(Class<?>... classes) { 196 for (Class<?> clazz : classes) { 197 codeSourceUrls.add(obtainCurrentClassPathOf(clazz.getName())); 198 } 199 return this; 200 } 201 withCurrentCodeSourceUrls()202 public IsolatedURLClassLoaderBuilder withCurrentCodeSourceUrls() { 203 codeSourceUrls.add(obtainCurrentClassPathOf(ClassLoaders.class.getName())); 204 return this; 205 } 206 without(String... privatePrefixes)207 public IsolatedURLClassLoaderBuilder without(String... privatePrefixes) { 208 excludedPrefixes.addAll(asList(privatePrefixes)); 209 return this; 210 } 211 build()212 public ClassLoader build() { 213 return new LocalIsolatedURLClassLoader( 214 jdkClassLoader(), 215 codeSourceUrls.toArray(new URL[codeSourceUrls.size()]), 216 privateCopyPrefixes, 217 excludedPrefixes 218 ); 219 } 220 } 221 222 static class LocalIsolatedURLClassLoader extends URLClassLoader { 223 private final ArrayList<String> privateCopyPrefixes; 224 private final ArrayList<String> excludedPrefixes; 225 LocalIsolatedURLClassLoader(ClassLoader classLoader, URL[] urls, ArrayList<String> privateCopyPrefixes, ArrayList<String> excludedPrefixes)226 LocalIsolatedURLClassLoader(ClassLoader classLoader, 227 URL[] urls, 228 ArrayList<String> privateCopyPrefixes, 229 ArrayList<String> excludedPrefixes) { 230 super(urls, classLoader); 231 this.privateCopyPrefixes = privateCopyPrefixes; 232 this.excludedPrefixes = excludedPrefixes; 233 } 234 235 @Override findClass(String name)236 public Class<?> findClass(String name) throws ClassNotFoundException { 237 if (!classShouldBePrivate(name) || classShouldBeExcluded(name)) { 238 throw new ClassNotFoundException(format("Can only load classes with prefixes : %s, but not : %s", 239 privateCopyPrefixes, 240 excludedPrefixes)); 241 } 242 try { 243 return super.findClass(name); 244 } catch (ClassNotFoundException cnfe) { 245 throw new ClassNotFoundException(format("%s%n%s%n", 246 cnfe.getMessage(), 247 " Did you forgot to add the code source url 'withCodeSourceUrlOf' / 'withCurrentCodeSourceUrls' ?"), 248 cnfe); 249 } 250 } 251 classShouldBePrivate(String name)252 private boolean classShouldBePrivate(String name) { 253 for (String prefix : privateCopyPrefixes) { 254 if (name.startsWith(prefix)) return true; 255 } 256 return false; 257 } 258 classShouldBeExcluded(String name)259 private boolean classShouldBeExcluded(String name) { 260 for (String prefix : excludedPrefixes) { 261 if (name.startsWith(prefix)) return true; 262 } 263 return false; 264 } 265 } 266 267 public static class ExcludingURLClassLoaderBuilder extends ClassLoaders { 268 private final ArrayList<String> excludedPrefixes = new ArrayList<String>(); 269 private final ArrayList<URL> codeSourceUrls = new ArrayList<URL>(); 270 without(String... privatePrefixes)271 public ExcludingURLClassLoaderBuilder without(String... privatePrefixes) { 272 excludedPrefixes.addAll(asList(privatePrefixes)); 273 return this; 274 } 275 withCodeSourceUrls(String... urls)276 public ExcludingURLClassLoaderBuilder withCodeSourceUrls(String... urls) { 277 codeSourceUrls.addAll(pathsToURLs(urls)); 278 return this; 279 } 280 withCodeSourceUrlOf(Class<?>.... classes)281 public ExcludingURLClassLoaderBuilder withCodeSourceUrlOf(Class<?>... classes) { 282 for (Class<?> clazz : classes) { 283 codeSourceUrls.add(obtainCurrentClassPathOf(clazz.getName())); 284 } 285 return this; 286 } 287 withCurrentCodeSourceUrls()288 public ExcludingURLClassLoaderBuilder withCurrentCodeSourceUrls() { 289 codeSourceUrls.add(obtainCurrentClassPathOf(ClassLoaders.class.getName())); 290 return this; 291 } 292 build()293 public ClassLoader build() { 294 return new LocalExcludingURLClassLoader( 295 jdkClassLoader(), 296 codeSourceUrls.toArray(new URL[codeSourceUrls.size()]), 297 excludedPrefixes 298 ); 299 } 300 } 301 302 static class LocalExcludingURLClassLoader extends URLClassLoader { 303 private final ArrayList<String> excludedPrefixes; 304 LocalExcludingURLClassLoader(ClassLoader classLoader, URL[] urls, ArrayList<String> excludedPrefixes)305 LocalExcludingURLClassLoader(ClassLoader classLoader, 306 URL[] urls, 307 ArrayList<String> excludedPrefixes) { 308 super(urls, classLoader); 309 this.excludedPrefixes = excludedPrefixes; 310 } 311 312 @Override findClass(String name)313 public Class<?> findClass(String name) throws ClassNotFoundException { 314 if (classShouldBeExcluded(name)) 315 throw new ClassNotFoundException("classes with prefix : " + excludedPrefixes + " are excluded"); 316 return super.findClass(name); 317 } 318 classShouldBeExcluded(String name)319 private boolean classShouldBeExcluded(String name) { 320 for (String prefix : excludedPrefixes) { 321 if (name.startsWith(prefix)) return true; 322 } 323 return false; 324 } 325 } 326 327 public static class InMemoryClassLoaderBuilder extends ClassLoaders { 328 private Map<String, byte[]> inMemoryClassObjects = new HashMap<String, byte[]>(); 329 withParent(ClassLoader parent)330 public InMemoryClassLoaderBuilder withParent(ClassLoader parent) { 331 this.parent = parent; 332 return this; 333 } 334 withClassDefinition(String name, byte[] classDefinition)335 public InMemoryClassLoaderBuilder withClassDefinition(String name, byte[] classDefinition) { 336 inMemoryClassObjects.put(name, classDefinition); 337 return this; 338 } 339 build()340 public ClassLoader build() { 341 return new InMemoryClassLoader(parent, inMemoryClassObjects); 342 } 343 } 344 345 static class InMemoryClassLoader extends ClassLoader { 346 public static final String SCHEME = "mem"; 347 private Map<String, byte[]> inMemoryClassObjects = new HashMap<String, byte[]>(); 348 InMemoryClassLoader(ClassLoader parent, Map<String, byte[]> inMemoryClassObjects)349 public InMemoryClassLoader(ClassLoader parent, Map<String, byte[]> inMemoryClassObjects) { 350 super(parent); 351 this.inMemoryClassObjects = inMemoryClassObjects; 352 } 353 findClass(String name)354 protected Class<?> findClass(String name) throws ClassNotFoundException { 355 byte[] classDefinition = inMemoryClassObjects.get(name); 356 if (classDefinition != null) { 357 return defineClass(name, classDefinition, 0, classDefinition.length); 358 } 359 throw new ClassNotFoundException(name); 360 } 361 362 @Override getResources(String ignored)363 public Enumeration<URL> getResources(String ignored) throws IOException { 364 return inMemoryOnly(); 365 } 366 inMemoryOnly()367 private Enumeration<URL> inMemoryOnly() { 368 final Set<String> names = inMemoryClassObjects.keySet(); 369 return new Enumeration<URL>() { 370 private final MemHandler memHandler = new MemHandler(InMemoryClassLoader.this); 371 private final Iterator<String> it = names.iterator(); 372 373 public boolean hasMoreElements() { 374 return it.hasNext(); 375 } 376 377 public URL nextElement() { 378 try { 379 return new URL(null, SCHEME + ":" + it.next(), memHandler); 380 } catch (MalformedURLException rethrown) { 381 throw new IllegalStateException(rethrown); 382 } 383 } 384 }; 385 } 386 } 387 388 public static class MemHandler extends URLStreamHandler { 389 private InMemoryClassLoader inMemoryClassLoader; 390 MemHandler(InMemoryClassLoader inMemoryClassLoader)391 public MemHandler(InMemoryClassLoader inMemoryClassLoader) { 392 this.inMemoryClassLoader = inMemoryClassLoader; 393 } 394 395 @Override openConnection(URL url)396 protected URLConnection openConnection(URL url) throws IOException { 397 return new MemURLConnection(url, inMemoryClassLoader); 398 } 399 400 private static class MemURLConnection extends URLConnection { 401 private final InMemoryClassLoader inMemoryClassLoader; 402 private String qualifiedName; 403 MemURLConnection(URL url, InMemoryClassLoader inMemoryClassLoader)404 public MemURLConnection(URL url, InMemoryClassLoader inMemoryClassLoader) { 405 super(url); 406 this.inMemoryClassLoader = inMemoryClassLoader; 407 qualifiedName = url.getPath(); 408 } 409 410 @Override connect()411 public void connect() throws IOException { 412 } 413 414 @Override getInputStream()415 public InputStream getInputStream() throws IOException { 416 return new ByteArrayInputStream(inMemoryClassLoader.inMemoryClassObjects.get(qualifiedName)); 417 } 418 } 419 } 420 obtainCurrentClassPathOf(String className)421 URL obtainCurrentClassPathOf(String className) { 422 String path = className.replace('.', '/') + ".class"; 423 String url = ClassLoaders.class.getClassLoader().getResource(path).toExternalForm(); 424 425 try { 426 return new URL(url.substring(0, url.length() - path.length())); 427 } catch (MalformedURLException e) { 428 throw new RuntimeException("Classloader couldn't obtain a proper classpath URL", e); 429 } 430 } 431 pathsToURLs(String... codeSourceUrls)432 List<URL> pathsToURLs(String... codeSourceUrls) { 433 return pathsToURLs(Arrays.asList(codeSourceUrls)); 434 } 435 pathsToURLs(List<String> codeSourceUrls)436 private List<URL> pathsToURLs(List<String> codeSourceUrls) { 437 ArrayList<URL> urls = new ArrayList<URL>(codeSourceUrls.size()); 438 for (String codeSourceUrl : codeSourceUrls) { 439 URL url = pathToUrl(codeSourceUrl); 440 urls.add(url); 441 } 442 return urls; 443 } 444 pathToUrl(String path)445 private URL pathToUrl(String path) { 446 try { 447 return new File(path).getAbsoluteFile().toURI().toURL(); 448 } catch (MalformedURLException e) { 449 throw new IllegalArgumentException("Path is malformed", e); 450 } 451 } 452 453 public static class ReachableClassesFinder { 454 private ClassLoader classLoader; 455 private Set<String> qualifiedNameSubstring = new HashSet<String>(); 456 ReachableClassesFinder(ClassLoader classLoader)457 ReachableClassesFinder(ClassLoader classLoader) { 458 this.classLoader = classLoader; 459 } 460 omit(String... qualifiedNameSubstring)461 public ReachableClassesFinder omit(String... qualifiedNameSubstring) { 462 this.qualifiedNameSubstring.addAll(Arrays.asList(qualifiedNameSubstring)); 463 return this; 464 } 465 listOwnedClasses()466 public Set<String> listOwnedClasses() throws IOException, URISyntaxException { 467 Enumeration<URL> roots = classLoader.getResources(""); 468 469 Set<String> classes = new HashSet<String>(); 470 while (roots.hasMoreElements()) { 471 URI uri = roots.nextElement().toURI(); 472 473 if (uri.getScheme().equalsIgnoreCase("file")) { 474 addFromFileBasedClassLoader(classes, uri); 475 } else if (uri.getScheme().equalsIgnoreCase(InMemoryClassLoader.SCHEME)) { 476 addFromInMemoryBasedClassLoader(classes, uri); 477 } else { 478 throw new IllegalArgumentException(format("Given ClassLoader '%s' don't have reachable by File or vi ClassLoaders.inMemory", classLoader)); 479 } 480 } 481 return classes; 482 } 483 addFromFileBasedClassLoader(Set<String> classes, URI uri)484 private void addFromFileBasedClassLoader(Set<String> classes, URI uri) { 485 File root = new File(uri); 486 classes.addAll(findClassQualifiedNames(root, root, qualifiedNameSubstring)); 487 } 488 addFromInMemoryBasedClassLoader(Set<String> classes, URI uri)489 private void addFromInMemoryBasedClassLoader(Set<String> classes, URI uri) { 490 String qualifiedName = uri.getSchemeSpecificPart(); 491 if (excludes(qualifiedName, qualifiedNameSubstring)) { 492 classes.add(qualifiedName); 493 } 494 } 495 496 findClassQualifiedNames(File root, File file, Set<String> packageFilters)497 private Set<String> findClassQualifiedNames(File root, File file, Set<String> packageFilters) { 498 if (file.isDirectory()) { 499 File[] files = file.listFiles(); 500 Set<String> classes = new HashSet<String>(); 501 for (File children : files) { 502 classes.addAll(findClassQualifiedNames(root, children, packageFilters)); 503 } 504 return classes; 505 } else { 506 if (file.getName().endsWith(".class")) { 507 String qualifiedName = classNameFor(root, file); 508 if (excludes(qualifiedName, packageFilters)) { 509 return Collections.singleton(qualifiedName); 510 } 511 } 512 } 513 return Collections.emptySet(); 514 } 515 excludes(String qualifiedName, Set<String> packageFilters)516 private boolean excludes(String qualifiedName, Set<String> packageFilters) { 517 for (String filter : packageFilters) { 518 if (qualifiedName.contains(filter)) return false; 519 } 520 return true; 521 } 522 classNameFor(File root, File file)523 private String classNameFor(File root, File file) { 524 String temp = file.getAbsolutePath().substring(root.getAbsolutePath().length() + 1). 525 replace(File.separatorChar, '.'); 526 return temp.subSequence(0, temp.indexOf(".class")).toString(); 527 } 528 529 } 530 } 531