1 /*
2  * Copyright (c) 2017 Mockito contributors
3  * This program is made available under the terms of the MIT License.
4  */
5 package org.mockitoutil;
6 
import java.lang.reflect.Constructor;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.Callable;
13 
/**
 * Custom classloader to load classes in a hierarchic realm.
 *
 * <p>Each class can be reloaded (defined again by this loader) in the realm if the
 * {@link ReloadClassPredicate} says so; every other class is looked up through the
 * regular {@link URLClassLoader#loadClass(String)} path.
 *
 * <p>The URLs of this loader are the classpath roots that contain this class,
 * {@code org.mockito.Mockito} and {@code net.bytebuddy.ByteBuddy}.
 */
public class SimplePerRealmReloadingClassLoader extends URLClassLoader {

    /** Classes this loader has defined itself through the reload path, keyed by qualified name. */
    private final Map<String, Class<?>> classHashMap = new HashMap<String, Class<?>>();
    private final ReloadClassPredicate reloadClassPredicate;

    /**
     * Creates a realm loader using {@link URLClassLoader}'s default delegation parent.
     *
     * @param reloadClassPredicate decides which classes get defined again by this loader
     * @throws IllegalStateException if one of the required classpath roots cannot be located
     */
    public SimplePerRealmReloadingClassLoader(ReloadClassPredicate reloadClassPredicate) {
        super(getPossibleClassPathsUrls());
        this.reloadClassPredicate = reloadClassPredicate;
    }

    /**
     * Creates a realm loader with an explicit parent loader.
     *
     * @param parentClassLoader    loader used for classes the predicate does not accept
     * @param reloadClassPredicate decides which classes get defined again by this loader
     * @throws IllegalStateException if one of the required classpath roots cannot be located
     */
    public SimplePerRealmReloadingClassLoader(ClassLoader parentClassLoader, ReloadClassPredicate reloadClassPredicate) {
        super(getPossibleClassPathsUrls(), parentClassLoader);
        this.reloadClassPredicate = reloadClassPredicate;
    }

    /** Returns the classpath roots of this class, of Mockito and of ByteBuddy. */
    private static URL[] getPossibleClassPathsUrls() {
        return new URL[]{
                obtainClassPath(),
                obtainClassPath("org.mockito.Mockito"),
                obtainClassPath("net.bytebuddy.ByteBuddy")
        };
    }

    /** Returns the classpath root that contains this very class. */
    private static URL obtainClassPath() {
        String className = SimplePerRealmReloadingClassLoader.class.getName();
        return obtainClassPath(className);
    }

    /**
     * Returns the classpath root (directory or jar root) containing the given class, as seen
     * by the loader that loaded this class.
     *
     * @param className fully qualified class name to locate
     * @return the URL of the classpath entry that holds {@code className}
     * @throws IllegalStateException if the class file cannot be found
     * @throws RuntimeException      if the derived root is not a valid URL
     */
    private static URL obtainClassPath(String className) {
        String path = className.replace('.', '/') + ".class";
        URL resource = SimplePerRealmReloadingClassLoader.class.getClassLoader().getResource(path);
        if (resource == null) {
            // Without this check a missing class would surface as a context-free NullPointerException.
            throw new IllegalStateException("Classloader couldn't find class '" + className
                    + "' (resource '" + path + "') on the classpath");
        }
        String url = resource.toExternalForm();

        try {
            // Drop the class file's relative path so only its classpath root remains.
            return new URL(url.substring(0, url.length() - path.length()));
        } catch (MalformedURLException e) {
            throw new RuntimeException("Classloader couldn't obtain a proper classpath URL", e);
        }
    }

    /**
     * Loads a class, defining it in this realm when the predicate accepts it.
     *
     * <p>Accepted classes are defined by this loader via {@link #findClass(String)} once and
     * cached; later requests return the cached definition. Other classes go through the
     * regular {@link URLClassLoader} lookup.
     *
     * @param qualifiedClassName fully qualified class name
     * @return the loaded class
     * @throws ClassNotFoundException if the class cannot be found
     */
    @Override
    public Class<?> loadClass(String qualifiedClassName) throws ClassNotFoundException {
        if (!reloadClassPredicate.acceptReloadOf(qualifiedClassName)) {
            return useParentClassLoaderFor(qualifiedClassName);
        }
        Class<?> reloaded = classHashMap.get(qualifiedClassName);
        if (reloaded == null) {
            // findClass never returns null (it throws instead), so null means "not defined yet".
            reloaded = findClass(qualifiedClassName);
            saveFoundClass(qualifiedClassName, reloaded);
        }
        return reloaded;
    }

    /** Remembers a class this loader defined so it is not defined twice. */
    private void saveFoundClass(String qualifiedClassName, Class<?> foundClass) {
        classHashMap.put(qualifiedClassName, foundClass);
    }

    /** Delegates to the standard {@link URLClassLoader} lookup. */
    private Class<?> useParentClassLoaderFor(String qualifiedName) throws ClassNotFoundException {
        return super.loadClass(qualifiedName);
    }

    /**
     * Instantiates the named {@link Callable} through this loader with its public no-arg
     * constructor and calls it, with this loader installed as the thread context class loader.
     *
     * @param callableCalledInClassLoaderRealm fully qualified name of a class implementing Callable
     * @return the value returned by {@code call()}
     * @throws IllegalArgumentException if the class does not implement Callable
     * @throws Exception                anything thrown while loading, constructing or calling it
     */
    public Object doInRealm(String callableCalledInClassLoaderRealm) throws Exception {
        return callInRealm(callableCalledInClassLoaderRealm, new Class<?>[0], new Object[0]);
    }

    /**
     * Instantiates the named {@link Callable} through this loader with the public constructor
     * matching {@code argTypes} and calls it, with this loader installed as the thread context
     * class loader.
     *
     * @param callableCalledInClassLoaderRealm fully qualified name of a class implementing Callable
     * @param argTypes                         parameter types of the constructor to use
     * @param args                             arguments passed to that constructor
     * @return the value returned by {@code call()}
     * @throws IllegalArgumentException if the class does not implement Callable
     * @throws Exception                anything thrown while loading, constructing or calling it
     */
    public Object doInRealm(String callableCalledInClassLoaderRealm, Class<?>[] argTypes, Object[] args) throws Exception {
        return callInRealm(callableCalledInClassLoaderRealm, argTypes, args);
    }

    /** Shared implementation of both {@code doInRealm} overloads. */
    private Object callInRealm(String callableCalledInClassLoaderRealm, Class<?>[] argTypes, Object[] args) throws Exception {
        ClassLoader current = Thread.currentThread().getContextClassLoader();
        try {
            Thread.currentThread().setContextClassLoader(this);
            Class<?> realmClass = this.loadClass(callableCalledInClassLoaderRealm);
            // Resolve the constructor first so a missing constructor still reports NoSuchMethodException.
            Constructor<?> constructor = realmClass.getConstructor(argTypes);
            if (!Callable.class.isAssignableFrom(realmClass)) {
                // Reject before instantiating, so constructors of unrelated classes never run.
                throw new IllegalArgumentException("qualified name '" + callableCalledInClassLoaderRealm + "' should represent a class implementing Callable");
            }
            Callable<?> callableInRealm = (Callable<?>) constructor.newInstance(args);
            return callableInRealm.call();
        } finally {
            Thread.currentThread().setContextClassLoader(current);
        }
    }

    /** Decides which classes are defined again inside the realm instead of being delegated. */
    public interface ReloadClassPredicate {
        /**
         * @param qualifiedName fully qualified name of the class being loaded
         * @return {@code true} to define the class in this realm, {@code false} to delegate
         */
        boolean acceptReloadOf(String qualifiedName);
    }
}
123