1 /*
2  * Copyright (c) 2017 Mockito contributors
3  * This program is made available under the terms of the MIT License.
4  */
5 package org.mockitoutil;
6 
7 import java.io.ByteArrayInputStream;
8 import java.io.File;
9 import java.io.IOException;
10 import java.io.InputStream;
11 import java.lang.reflect.Field;
12 import java.lang.reflect.Modifier;
13 import java.net.MalformedURLException;
14 import java.net.URI;
15 import java.net.URISyntaxException;
16 import java.net.URL;
17 import java.net.URLClassLoader;
18 import java.net.URLConnection;
19 import java.net.URLStreamHandler;
20 import java.util.ArrayList;
21 import java.util.Arrays;
22 import java.util.Collections;
23 import java.util.Enumeration;
24 import java.util.HashMap;
25 import java.util.HashSet;
26 import java.util.Iterator;
27 import java.util.List;
28 import java.util.Map;
29 import java.util.Set;
30 import java.util.concurrent.ExecutionException;
31 import java.util.concurrent.ExecutorService;
32 import java.util.concurrent.Executors;
33 import java.util.concurrent.Future;
34 import java.util.concurrent.ThreadFactory;
35 import org.objenesis.Objenesis;
36 import org.objenesis.ObjenesisStd;
37 import org.objenesis.instantiator.ObjectInstantiator;
38 
39 import static java.lang.String.format;
40 import static java.util.Arrays.asList;
41 
42 public abstract class ClassLoaders {
43     protected ClassLoader parent = currentClassLoader();
44 
ClassLoaders()45     protected ClassLoaders() {
46     }
47 
isolatedClassLoader()48     public static IsolatedURLClassLoaderBuilder isolatedClassLoader() {
49         return new IsolatedURLClassLoaderBuilder();
50     }
51 
excludingClassLoader()52     public static ExcludingURLClassLoaderBuilder excludingClassLoader() {
53         return new ExcludingURLClassLoaderBuilder();
54     }
55 
inMemoryClassLoader()56     public static InMemoryClassLoaderBuilder inMemoryClassLoader() {
57         return new InMemoryClassLoaderBuilder();
58     }
59 
in(ClassLoader classLoader)60     public static ReachableClassesFinder in(ClassLoader classLoader) {
61         return new ReachableClassesFinder(classLoader);
62     }
63 
jdkClassLoader()64     public static ClassLoader jdkClassLoader() {
65         return String.class.getClassLoader();
66     }
67 
systemClassLoader()68     public static ClassLoader systemClassLoader() {
69         return ClassLoader.getSystemClassLoader();
70     }
71 
currentClassLoader()72     public static ClassLoader currentClassLoader() {
73         return ClassLoaders.class.getClassLoader();
74     }
75 
build()76     public abstract ClassLoader build();
77 
coverageTool()78     public static Class<?>[] coverageTool() {
79         HashSet<Class<?>> classes = new HashSet<Class<?>>();
80         classes.add(safeGetClass("net.sourceforge.cobertura.coveragedata.TouchCollector"));
81         classes.add(safeGetClass("org.slf4j.LoggerFactory"));
82 
83         classes.remove(null);
84         return classes.toArray(new Class<?>[classes.size()]);
85     }
86 
safeGetClass(String className)87     private static Class<?> safeGetClass(String className) {
88         try {
89             return Class.forName(className);
90         } catch (ClassNotFoundException e) {
91             return null;
92         }
93     }
94 
using(final ClassLoader classLoader)95     public static ClassLoaderExecutor using(final ClassLoader classLoader) {
96         return new ClassLoaderExecutor(classLoader);
97     }
98 
99     public static class ClassLoaderExecutor {
100         private ClassLoader classLoader;
101 
ClassLoaderExecutor(ClassLoader classLoader)102         public ClassLoaderExecutor(ClassLoader classLoader) {
103             this.classLoader = classLoader;
104         }
105 
execute(final Runnable task)106         public void execute(final Runnable task) throws Exception {
107             ExecutorService executorService = Executors.newSingleThreadExecutor(new ThreadFactory() {
108                 @Override
109                 public Thread newThread(Runnable r) {
110                     Thread thread = Executors.defaultThreadFactory().newThread(r);
111                     thread.setContextClassLoader(classLoader);
112                     return thread;
113                 }
114             });
115             try {
116                 Future<?> taskFuture = executorService.submit(new Runnable() {
117                     @Override
118                     public void run() {
119                         try {
120                             reloadTaskInClassLoader(task).run();
121                         } catch (Throwable throwable) {
122                             throw new IllegalStateException(format("Given task could not be loaded properly in the given classloader '%s', error '%s",
123                                                                    task,
124                                                                    throwable.getMessage()),
125                                                             throwable);
126                         }
127                     }
128                 });
129                 taskFuture.get();
130                 executorService.shutdownNow();
131             } catch (InterruptedException e) {
132                 Thread.currentThread().interrupt();
133             } catch (ExecutionException e) {
134                 throw this.<Exception>unwrapAndThrows(e);
135             }
136         }
137 
138         @SuppressWarnings("unchecked")
unwrapAndThrows(ExecutionException ex)139         private <T extends Throwable> T unwrapAndThrows(ExecutionException ex) throws T {
140             throw (T) ex.getCause();
141         }
142 
reloadTaskInClassLoader(Runnable task)143         Runnable reloadTaskInClassLoader(Runnable task) {
144             try {
145                 @SuppressWarnings("unchecked")
146                 Class<Runnable> taskClassReloaded = (Class<Runnable>) classLoader.loadClass(task.getClass().getName());
147 
148                 Objenesis objenesis = new ObjenesisStd();
149                 ObjectInstantiator<Runnable> thingyInstantiator = objenesis.getInstantiatorOf(taskClassReloaded);
150                 Runnable reloaded = thingyInstantiator.newInstance();
151 
152                 // lenient shallow copy of class compatible fields
153                 for (Field field : task.getClass().getDeclaredFields()) {
154                     Field declaredField = taskClassReloaded.getDeclaredField(field.getName());
155                     int modifiers = declaredField.getModifiers();
156                     if(Modifier.isStatic(modifiers) && Modifier.isFinal(modifiers)) {
157                         // Skip static final fields (e.g. jacoco fields)
158                         // otherwise IllegalAccessException (can be bypassed with Unsafe though)
159                         // We may also miss coverage data.
160                         continue;
161                     }
162                     if (declaredField.getType() == field.getType()) { // don't copy this
163                         field.setAccessible(true);
164                         declaredField.setAccessible(true);
165                         declaredField.set(reloaded, field.get(task));
166                     }
167                 }
168 
169                 return reloaded;
170             } catch (ClassNotFoundException e) {
171                 throw new IllegalStateException(e);
172             } catch (IllegalAccessException e) {
173                 throw new IllegalStateException(e);
174             } catch (NoSuchFieldException e) {
175                 throw new IllegalStateException(e);
176             }
177         }
178     }
179 
180     public static class IsolatedURLClassLoaderBuilder extends ClassLoaders {
181         private final ArrayList<String> excludedPrefixes = new ArrayList<String>();
182         private final ArrayList<String> privateCopyPrefixes = new ArrayList<String>();
183         private final ArrayList<URL> codeSourceUrls = new ArrayList<URL>();
184 
withPrivateCopyOf(String... privatePrefixes)185         public IsolatedURLClassLoaderBuilder withPrivateCopyOf(String... privatePrefixes) {
186             privateCopyPrefixes.addAll(asList(privatePrefixes));
187             return this;
188         }
189 
withCodeSourceUrls(String... urls)190         public IsolatedURLClassLoaderBuilder withCodeSourceUrls(String... urls) {
191             codeSourceUrls.addAll(pathsToURLs(urls));
192             return this;
193         }
194 
withCodeSourceUrlOf(Class<?>.... classes)195         public IsolatedURLClassLoaderBuilder withCodeSourceUrlOf(Class<?>... classes) {
196             for (Class<?> clazz : classes) {
197                 codeSourceUrls.add(obtainCurrentClassPathOf(clazz.getName()));
198             }
199             return this;
200         }
201 
withCurrentCodeSourceUrls()202         public IsolatedURLClassLoaderBuilder withCurrentCodeSourceUrls() {
203             codeSourceUrls.add(obtainCurrentClassPathOf(ClassLoaders.class.getName()));
204             return this;
205         }
206 
without(String... privatePrefixes)207         public IsolatedURLClassLoaderBuilder without(String... privatePrefixes) {
208             excludedPrefixes.addAll(asList(privatePrefixes));
209             return this;
210         }
211 
build()212         public ClassLoader build() {
213             return new LocalIsolatedURLClassLoader(
214                     jdkClassLoader(),
215                     codeSourceUrls.toArray(new URL[codeSourceUrls.size()]),
216                     privateCopyPrefixes,
217                     excludedPrefixes
218             );
219         }
220     }
221 
222     static class LocalIsolatedURLClassLoader extends URLClassLoader {
223         private final ArrayList<String> privateCopyPrefixes;
224         private final ArrayList<String> excludedPrefixes;
225 
LocalIsolatedURLClassLoader(ClassLoader classLoader, URL[] urls, ArrayList<String> privateCopyPrefixes, ArrayList<String> excludedPrefixes)226         LocalIsolatedURLClassLoader(ClassLoader classLoader,
227                                     URL[] urls,
228                                     ArrayList<String> privateCopyPrefixes,
229                                     ArrayList<String> excludedPrefixes) {
230             super(urls, classLoader);
231             this.privateCopyPrefixes = privateCopyPrefixes;
232             this.excludedPrefixes = excludedPrefixes;
233         }
234 
235         @Override
findClass(String name)236         public Class<?> findClass(String name) throws ClassNotFoundException {
237             if (!classShouldBePrivate(name) || classShouldBeExcluded(name)) {
238                 throw new ClassNotFoundException(format("Can only load classes with prefixes : %s, but not : %s",
239                                                         privateCopyPrefixes,
240                                                         excludedPrefixes));
241             }
242             try {
243                 return super.findClass(name);
244             } catch (ClassNotFoundException cnfe) {
245                 throw new ClassNotFoundException(format("%s%n%s%n",
246                                                         cnfe.getMessage(),
247                                                         "    Did you forgot to add the code source url 'withCodeSourceUrlOf' / 'withCurrentCodeSourceUrls' ?"),
248                                                  cnfe);
249             }
250         }
251 
classShouldBePrivate(String name)252         private boolean classShouldBePrivate(String name) {
253             for (String prefix : privateCopyPrefixes) {
254                 if (name.startsWith(prefix)) return true;
255             }
256             return false;
257         }
258 
classShouldBeExcluded(String name)259         private boolean classShouldBeExcluded(String name) {
260             for (String prefix : excludedPrefixes) {
261                 if (name.startsWith(prefix)) return true;
262             }
263             return false;
264         }
265     }
266 
267     public static class ExcludingURLClassLoaderBuilder extends ClassLoaders {
268         private final ArrayList<String> excludedPrefixes = new ArrayList<String>();
269         private final ArrayList<URL> codeSourceUrls = new ArrayList<URL>();
270 
without(String... privatePrefixes)271         public ExcludingURLClassLoaderBuilder without(String... privatePrefixes) {
272             excludedPrefixes.addAll(asList(privatePrefixes));
273             return this;
274         }
275 
withCodeSourceUrls(String... urls)276         public ExcludingURLClassLoaderBuilder withCodeSourceUrls(String... urls) {
277             codeSourceUrls.addAll(pathsToURLs(urls));
278             return this;
279         }
280 
withCodeSourceUrlOf(Class<?>.... classes)281         public ExcludingURLClassLoaderBuilder withCodeSourceUrlOf(Class<?>... classes) {
282             for (Class<?> clazz : classes) {
283                 codeSourceUrls.add(obtainCurrentClassPathOf(clazz.getName()));
284             }
285             return this;
286         }
287 
withCurrentCodeSourceUrls()288         public ExcludingURLClassLoaderBuilder withCurrentCodeSourceUrls() {
289             codeSourceUrls.add(obtainCurrentClassPathOf(ClassLoaders.class.getName()));
290             return this;
291         }
292 
build()293         public ClassLoader build() {
294             return new LocalExcludingURLClassLoader(
295                     jdkClassLoader(),
296                     codeSourceUrls.toArray(new URL[codeSourceUrls.size()]),
297                     excludedPrefixes
298             );
299         }
300     }
301 
302     static class LocalExcludingURLClassLoader extends URLClassLoader {
303         private final ArrayList<String> excludedPrefixes;
304 
LocalExcludingURLClassLoader(ClassLoader classLoader, URL[] urls, ArrayList<String> excludedPrefixes)305         LocalExcludingURLClassLoader(ClassLoader classLoader,
306                                      URL[] urls,
307                                      ArrayList<String> excludedPrefixes) {
308             super(urls, classLoader);
309             this.excludedPrefixes = excludedPrefixes;
310         }
311 
312         @Override
findClass(String name)313         public Class<?> findClass(String name) throws ClassNotFoundException {
314             if (classShouldBeExcluded(name))
315                 throw new ClassNotFoundException("classes with prefix : " + excludedPrefixes + " are excluded");
316             return super.findClass(name);
317         }
318 
classShouldBeExcluded(String name)319         private boolean classShouldBeExcluded(String name) {
320             for (String prefix : excludedPrefixes) {
321                 if (name.startsWith(prefix)) return true;
322             }
323             return false;
324         }
325     }
326 
327     public static class InMemoryClassLoaderBuilder extends ClassLoaders {
328         private Map<String, byte[]> inMemoryClassObjects = new HashMap<String, byte[]>();
329 
withParent(ClassLoader parent)330         public InMemoryClassLoaderBuilder withParent(ClassLoader parent) {
331             this.parent = parent;
332             return this;
333         }
334 
withClassDefinition(String name, byte[] classDefinition)335         public InMemoryClassLoaderBuilder withClassDefinition(String name, byte[] classDefinition) {
336             inMemoryClassObjects.put(name, classDefinition);
337             return this;
338         }
339 
build()340         public ClassLoader build() {
341             return new InMemoryClassLoader(parent, inMemoryClassObjects);
342         }
343     }
344 
345     static class InMemoryClassLoader extends ClassLoader {
346         public static final String SCHEME = "mem";
347         private Map<String, byte[]> inMemoryClassObjects = new HashMap<String, byte[]>();
348 
InMemoryClassLoader(ClassLoader parent, Map<String, byte[]> inMemoryClassObjects)349         public InMemoryClassLoader(ClassLoader parent, Map<String, byte[]> inMemoryClassObjects) {
350             super(parent);
351             this.inMemoryClassObjects = inMemoryClassObjects;
352         }
353 
findClass(String name)354         protected Class<?> findClass(String name) throws ClassNotFoundException {
355             byte[] classDefinition = inMemoryClassObjects.get(name);
356             if (classDefinition != null) {
357                 return defineClass(name, classDefinition, 0, classDefinition.length);
358             }
359             throw new ClassNotFoundException(name);
360         }
361 
362         @Override
getResources(String ignored)363         public Enumeration<URL> getResources(String ignored) throws IOException {
364             return inMemoryOnly();
365         }
366 
inMemoryOnly()367         private Enumeration<URL> inMemoryOnly() {
368             final Set<String> names = inMemoryClassObjects.keySet();
369             return new Enumeration<URL>() {
370                 private final MemHandler memHandler = new MemHandler(InMemoryClassLoader.this);
371                 private final Iterator<String> it = names.iterator();
372 
373                 public boolean hasMoreElements() {
374                     return it.hasNext();
375                 }
376 
377                 public URL nextElement() {
378                     try {
379                         return new URL(null, SCHEME + ":" + it.next(), memHandler);
380                     } catch (MalformedURLException rethrown) {
381                         throw new IllegalStateException(rethrown);
382                     }
383                 }
384             };
385         }
386     }
387 
388     public static class MemHandler extends URLStreamHandler {
389         private InMemoryClassLoader inMemoryClassLoader;
390 
MemHandler(InMemoryClassLoader inMemoryClassLoader)391         public MemHandler(InMemoryClassLoader inMemoryClassLoader) {
392             this.inMemoryClassLoader = inMemoryClassLoader;
393         }
394 
395         @Override
openConnection(URL url)396         protected URLConnection openConnection(URL url) throws IOException {
397             return new MemURLConnection(url, inMemoryClassLoader);
398         }
399 
400         private static class MemURLConnection extends URLConnection {
401             private final InMemoryClassLoader inMemoryClassLoader;
402             private String qualifiedName;
403 
MemURLConnection(URL url, InMemoryClassLoader inMemoryClassLoader)404             public MemURLConnection(URL url, InMemoryClassLoader inMemoryClassLoader) {
405                 super(url);
406                 this.inMemoryClassLoader = inMemoryClassLoader;
407                 qualifiedName = url.getPath();
408             }
409 
410             @Override
connect()411             public void connect() throws IOException {
412             }
413 
414             @Override
getInputStream()415             public InputStream getInputStream() throws IOException {
416                 return new ByteArrayInputStream(inMemoryClassLoader.inMemoryClassObjects.get(qualifiedName));
417             }
418         }
419     }
420 
obtainCurrentClassPathOf(String className)421     URL obtainCurrentClassPathOf(String className) {
422         String path = className.replace('.', '/') + ".class";
423         String url = ClassLoaders.class.getClassLoader().getResource(path).toExternalForm();
424 
425         try {
426             return new URL(url.substring(0, url.length() - path.length()));
427         } catch (MalformedURLException e) {
428             throw new RuntimeException("Classloader couldn't obtain a proper classpath URL", e);
429         }
430     }
431 
pathsToURLs(String... codeSourceUrls)432     List<URL> pathsToURLs(String... codeSourceUrls) {
433         return pathsToURLs(Arrays.asList(codeSourceUrls));
434     }
435 
pathsToURLs(List<String> codeSourceUrls)436     private List<URL> pathsToURLs(List<String> codeSourceUrls) {
437         ArrayList<URL> urls = new ArrayList<URL>(codeSourceUrls.size());
438         for (String codeSourceUrl : codeSourceUrls) {
439             URL url = pathToUrl(codeSourceUrl);
440             urls.add(url);
441         }
442         return urls;
443     }
444 
pathToUrl(String path)445     private URL pathToUrl(String path) {
446         try {
447             return new File(path).getAbsoluteFile().toURI().toURL();
448         } catch (MalformedURLException e) {
449             throw new IllegalArgumentException("Path is malformed", e);
450         }
451     }
452 
453     public static class ReachableClassesFinder {
454         private ClassLoader classLoader;
455         private Set<String> qualifiedNameSubstring = new HashSet<String>();
456 
ReachableClassesFinder(ClassLoader classLoader)457         ReachableClassesFinder(ClassLoader classLoader) {
458             this.classLoader = classLoader;
459         }
460 
omit(String... qualifiedNameSubstring)461         public ReachableClassesFinder omit(String... qualifiedNameSubstring) {
462             this.qualifiedNameSubstring.addAll(Arrays.asList(qualifiedNameSubstring));
463             return this;
464         }
465 
listOwnedClasses()466         public Set<String> listOwnedClasses() throws IOException, URISyntaxException {
467             Enumeration<URL> roots = classLoader.getResources("");
468 
469             Set<String> classes = new HashSet<String>();
470             while (roots.hasMoreElements()) {
471                 URI uri = roots.nextElement().toURI();
472 
473                 if (uri.getScheme().equalsIgnoreCase("file")) {
474                     addFromFileBasedClassLoader(classes, uri);
475                 } else if (uri.getScheme().equalsIgnoreCase(InMemoryClassLoader.SCHEME)) {
476                     addFromInMemoryBasedClassLoader(classes, uri);
477                 } else {
478                     throw new IllegalArgumentException(format("Given ClassLoader '%s' don't have reachable by File or vi ClassLoaders.inMemory", classLoader));
479                 }
480             }
481             return classes;
482         }
483 
addFromFileBasedClassLoader(Set<String> classes, URI uri)484         private void addFromFileBasedClassLoader(Set<String> classes, URI uri) {
485             File root = new File(uri);
486             classes.addAll(findClassQualifiedNames(root, root, qualifiedNameSubstring));
487         }
488 
addFromInMemoryBasedClassLoader(Set<String> classes, URI uri)489         private void addFromInMemoryBasedClassLoader(Set<String> classes, URI uri) {
490             String qualifiedName = uri.getSchemeSpecificPart();
491             if (excludes(qualifiedName, qualifiedNameSubstring)) {
492                 classes.add(qualifiedName);
493             }
494         }
495 
496 
findClassQualifiedNames(File root, File file, Set<String> packageFilters)497         private Set<String> findClassQualifiedNames(File root, File file, Set<String> packageFilters) {
498             if (file.isDirectory()) {
499                 File[] files = file.listFiles();
500                 Set<String> classes = new HashSet<String>();
501                 for (File children : files) {
502                     classes.addAll(findClassQualifiedNames(root, children, packageFilters));
503                 }
504                 return classes;
505             } else {
506                 if (file.getName().endsWith(".class")) {
507                     String qualifiedName = classNameFor(root, file);
508                     if (excludes(qualifiedName, packageFilters)) {
509                         return Collections.singleton(qualifiedName);
510                     }
511                 }
512             }
513             return Collections.emptySet();
514         }
515 
excludes(String qualifiedName, Set<String> packageFilters)516         private boolean excludes(String qualifiedName, Set<String> packageFilters) {
517             for (String filter : packageFilters) {
518                 if (qualifiedName.contains(filter)) return false;
519             }
520             return true;
521         }
522 
classNameFor(File root, File file)523         private String classNameFor(File root, File file) {
524             String temp = file.getAbsolutePath().substring(root.getAbsolutePath().length() + 1).
525                     replace(File.separatorChar, '.');
526             return temp.subSequence(0, temp.indexOf(".class")).toString();
527         }
528 
529     }
530 }
531