1 /*
2  * Copyright 2010 Google Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package com.google.android.testing.mocking;
17 
18 import javassist.CannotCompileException;
19 import javassist.ClassClassPath;
20 import javassist.ClassPool;
21 import javassist.CtClass;
22 import javassist.CtConstructor;
23 import javassist.CtField;
24 import javassist.CtMethod;
25 import javassist.CtNewConstructor;
26 import javassist.NotFoundException;
27 
28 import java.io.IOException;
29 import java.lang.reflect.Constructor;
30 import java.lang.reflect.Method;
31 import java.lang.reflect.Modifier;
32 import java.util.ArrayList;
33 import java.util.Arrays;
34 import java.util.HashMap;
35 import java.util.List;
36 import java.util.Map;
37 
38 
39 /**
40  * AndroidMockGenerator creates the subclass and interface required for mocking
41  * a given Class.
42  *
43  * The only public method of AndroidMockGenerator is createMocksForClass. See
44  * the javadocs for this method for more information about AndroidMockGenerator.
45  *
46  * @author swoodward@google.com (Stephen Woodward)
47  */
48 class AndroidMockGenerator {
AndroidMockGenerator()49   public AndroidMockGenerator() {
50     ClassPool.doPruning = false;
51     ClassPool.getDefault().insertClassPath(new ClassClassPath(MockObject.class));
52   }
53 
54   /**
55    * Creates a List of javassist.CtClass objects representing all of the
56    * interfaces and subclasses required to meet the Mocking requests of the
57    * Class specified by {@code clazz}.
58    *
59    * A test class can request that a Class be prepared for mocking by using the
60    * {@link UsesMocks} annotation at either the Class or Method level. All
61    * classes specified by these annotations will have exactly two CtClass
62    * objects created, one for a generated interface, and one for a generated
63    * subclass. The interface and subclass both define the same methods which
64    * comprise all of the mockable methods of the provided class. At present, for
65    * a method to be mockable, it must be non-final and non-static, although this
66    * may expand in the future.
67    *
68    * The class itself must be mockable, otherwise this method will ignore the
69    * requested mock and print a warning. At present, a class is mockable if it
70    * is a non-final publicly-instantiable Java class that is assignable from the
71    * java.lang.Object class. See the javadocs for
72    * {@link java.lang.Class#isAssignableFrom(Class)} for more information about
73    * what "is assignable from the Object class" means. As a non-exhaustive
74    * example, if a given Class represents an Enum, Annotation, Primitive or
75    * Array, then it is not assignable from Object. Interfaces are also ignored
76    * since these need no modifications in order to be mocked.
77    *
78    * @param clazz the Class object to have all of its UsesMocks annotations
79    *        processed and the corresponding Mock Classes created.
80    * @return a List of CtClass objects representing the Classes and Interfaces
81    *         required for mocking the classes requested by {@code clazz}
82    * @throws ClassNotFoundException
83    * @throws CannotCompileException
84    * @throws IOException
85    */
createMocksForClass(Class<?> clazz)86   public List<GeneratedClassFile> createMocksForClass(Class<?> clazz)
87       throws ClassNotFoundException, IOException, CannotCompileException {
88     return this.createMocksForClass(clazz, SdkVersion.UNKNOWN);
89   }
90 
createMocksForClass(Class<?> clazz, SdkVersion sdkVersion)91   public List<GeneratedClassFile> createMocksForClass(Class<?> clazz, SdkVersion sdkVersion)
92       throws ClassNotFoundException, IOException, CannotCompileException {
93     if (!classIsSupportedType(clazz)) {
94       reportReasonForUnsupportedType(clazz);
95       return Arrays.asList(new GeneratedClassFile[0]);
96     }
97     CtClass newInterfaceCtClass = generateInterface(clazz, sdkVersion);
98     GeneratedClassFile newInterface = new GeneratedClassFile(newInterfaceCtClass.getName(),
99         newInterfaceCtClass.toBytecode());
100     CtClass mockDelegateCtClass = generateSubClass(clazz, newInterfaceCtClass, sdkVersion);
101     GeneratedClassFile mockDelegate = new GeneratedClassFile(mockDelegateCtClass.getName(),
102         mockDelegateCtClass.toBytecode());
103     return Arrays.asList(new GeneratedClassFile[] {newInterface, mockDelegate});
104   }
105 
reportReasonForUnsupportedType(Class<?> clazz)106   private void reportReasonForUnsupportedType(Class<?> clazz) {
107     String reason = null;
108     if (clazz.isInterface()) {
109       // do nothing to make sure none of the other conditions apply.
110     } else if (clazz.isEnum()) {
111       reason = "Cannot mock an Enum";
112     } else if (clazz.isAnnotation()) {
113       reason = "Cannot mock an Annotation";
114     } else if (clazz.isArray()) {
115       reason = "Cannot mock an Array";
116     } else if (Modifier.isFinal(clazz.getModifiers())) {
117       reason = "Cannot mock a Final class";
118     } else if (clazz.isPrimitive()) {
119       reason = "Cannot mock primitives";
120     } else if (!Object.class.isAssignableFrom(clazz)) {
121       reason = "Cannot mock non-classes";
122     } else if (!containsUsableConstructor(clazz)) {
123       reason = "Cannot mock a class with no public constructors";
124     } else {
125       // Whatever the reason is, it's not one that we care about.
126     }
127     if (reason != null) {
128       // Sometimes we want to be silent, so check 'reason' against null.
129       System.err.println(reason + ": " + clazz.getName());
130     }
131   }
132 
containsUsableConstructor(Class<?> clazz)133   private boolean containsUsableConstructor(Class<?> clazz) {
134     Constructor<?>[] constructors = clazz.getDeclaredConstructors();
135     for (Constructor<?> constructor : constructors) {
136       if (Modifier.isPublic(constructor.getModifiers()) ||
137           Modifier.isProtected(constructor.getModifiers())) {
138         return true;
139       }
140     }
141     return false;
142   }
143 
classIsSupportedType(Class<?> clazz)144   boolean classIsSupportedType(Class<?> clazz) {
145     return (containsUsableConstructor(clazz)) && Object.class.isAssignableFrom(clazz)
146         && !clazz.isInterface() && !clazz.isEnum() && !clazz.isAnnotation() && !clazz.isArray()
147         && !Modifier.isFinal(clazz.getModifiers());
148   }
149 
saveCtClass(CtClass clazz)150   void saveCtClass(CtClass clazz) throws ClassNotFoundException, IOException {
151     try {
152       clazz.writeFile();
153     } catch (NotFoundException e) {
154       throw new ClassNotFoundException("Error while saving modified class " + clazz.getName(), e);
155     } catch (CannotCompileException e) {
156       throw new RuntimeException("Internal Error: Attempt to save syntactically incorrect code "
157           + "for class " + clazz.getName(), e);
158     }
159   }
160 
generateInterface(Class<?> originalClass, SdkVersion sdkVersion)161   CtClass generateInterface(Class<?> originalClass, SdkVersion sdkVersion) {
162     ClassPool classPool = getClassPool();
163     try {
164       return classPool.getCtClass(FileUtils.getInterfaceNameFor(originalClass, sdkVersion));
165     } catch (NotFoundException e) {
166       CtClass newInterface =
167           classPool.makeInterface(FileUtils.getInterfaceNameFor(originalClass, sdkVersion));
168       addInterfaceMethods(originalClass, newInterface);
169       return newInterface;
170     }
171   }
172 
getInterfaceMethodSource(Method method)173   String getInterfaceMethodSource(Method method) throws UnsupportedOperationException {
174     StringBuilder methodBody = getMethodSignature(method);
175     methodBody.append(";");
176     return methodBody.toString();
177   }
178 
getMethodSignature(Method method)179   private StringBuilder getMethodSignature(Method method) {
180     int modifiers = method.getModifiers();
181     if (Modifier.isFinal(modifiers) || Modifier.isStatic(modifiers)) {
182       throw new UnsupportedOperationException(
183           "Cannot specify final or static methods in an interface");
184     }
185     StringBuilder methodSignature = new StringBuilder("public ");
186     methodSignature.append(getClassName(method.getReturnType()));
187     methodSignature.append(" ");
188     methodSignature.append(method.getName());
189     methodSignature.append("(");
190     int i = 0;
191     for (Class<?> arg : method.getParameterTypes()) {
192       methodSignature.append(getClassName(arg));
193       methodSignature.append(" arg");
194       methodSignature.append(i);
195       if (i < method.getParameterTypes().length - 1) {
196         methodSignature.append(",");
197       }
198       i++;
199     }
200     methodSignature.append(")");
201     if (method.getExceptionTypes().length > 0) {
202       methodSignature.append(" throws ");
203     }
204     i = 0;
205     for (Class<?> exception : method.getExceptionTypes()) {
206       methodSignature.append(getClassName(exception));
207       if (i < method.getExceptionTypes().length - 1) {
208         methodSignature.append(",");
209       }
210       i++;
211     }
212     return methodSignature;
213   }
214 
getClassName(Class<?> clazz)215   private String getClassName(Class<?> clazz) {
216     return clazz.getCanonicalName();
217   }
218 
getClassPool()219   static ClassPool getClassPool() {
220     return ClassPool.getDefault();
221   }
222 
classExists(String name)223   private boolean classExists(String name) {
224     // The following line is the ideal, but doesn't work (bug in library).
225     // return getClassPool().find(name) != null;
226     try {
227       getClassPool().get(name);
228       return true;
229     } catch (NotFoundException e) {
230       return false;
231     }
232   }
233 
generateSubClass(Class<?> superClass, CtClass newInterface, SdkVersion sdkVersion)234   CtClass generateSubClass(Class<?> superClass, CtClass newInterface, SdkVersion sdkVersion)
235       throws ClassNotFoundException {
236     if (classExists(FileUtils.getSubclassNameFor(superClass, sdkVersion))) {
237       try {
238         return getClassPool().get(FileUtils.getSubclassNameFor(superClass, sdkVersion));
239       } catch (NotFoundException e) {
240         throw new ClassNotFoundException("This should be impossible, since we just checked for "
241             + "the existence of the class being created", e);
242       }
243     }
244     CtClass newClass = generateSkeletalClass(superClass, newInterface, sdkVersion);
245     if (!newClass.isFrozen()) {
246       newClass.addInterface(newInterface);
247       try {
248         newClass.addInterface(getClassPool().get(MockObject.class.getName()));
249       } catch (NotFoundException e) {
250         throw new ClassNotFoundException("Could not find " + MockObject.class.getName(), e);
251       }
252       addMethods(superClass, newClass);
253       addGetDelegateMethod(newClass);
254       addSetDelegateMethod(newClass, newInterface);
255       addConstructors(newClass, superClass);
256     }
257     return newClass;
258   }
259 
addConstructors(CtClass clazz, Class<?> superClass)260   private void addConstructors(CtClass clazz, Class<?> superClass) throws ClassNotFoundException {
261     CtClass superCtClass = getCtClassForClass(superClass);
262 
263     CtConstructor[] constructors = superCtClass.getDeclaredConstructors();
264     for (CtConstructor constructor : constructors) {
265       int modifiers = constructor.getModifiers();
266       if (Modifier.isPublic(modifiers) || Modifier.isProtected(modifiers)) {
267          CtConstructor ctConstructor;
268         try {
269           ctConstructor = CtNewConstructor.make(constructor.getParameterTypes(),
270                constructor.getExceptionTypes(), clazz);
271           clazz.addConstructor(ctConstructor);
272         } catch (CannotCompileException e) {
273           throw new RuntimeException("Internal Error - Could not add constructors.", e);
274         } catch (NotFoundException e) {
275           throw new RuntimeException("Internal Error - Constructor suddenly could not be found", e);
276         }
277       }
278     }
279   }
280 
getCtClassForClass(Class<?> clazz)281   CtClass getCtClassForClass(Class<?> clazz) throws ClassNotFoundException {
282     ClassPool classPool = getClassPool();
283     try {
284       return classPool.get(clazz.getName());
285     } catch (NotFoundException e) {
286       throw new ClassNotFoundException("Class not found when finding the class to be mocked: "
287           + clazz.getName(), e);
288     }
289   }
290 
addSetDelegateMethod(CtClass clazz, CtClass newInterface)291   private void addSetDelegateMethod(CtClass clazz, CtClass newInterface) {
292     try {
293       clazz.addMethod(CtMethod.make(getSetDelegateMethodSource(newInterface), clazz));
294     } catch (CannotCompileException e) {
295       throw new RuntimeException("Internal error while creating the setDelegate() method", e);
296     }
297   }
298 
getSetDelegateMethodSource(CtClass newInterface)299   String getSetDelegateMethodSource(CtClass newInterface) {
300     return "public void setDelegate___AndroidMock(" + newInterface.getName() + " obj) { this."
301         + getDelegateFieldName() + " = obj;}";
302   }
303 
addGetDelegateMethod(CtClass clazz)304   private void addGetDelegateMethod(CtClass clazz) {
305     try {
306       CtMethod newMethod = CtMethod.make(getGetDelegateMethodSource(), clazz);
307       try {
308         CtMethod existingMethod = clazz.getMethod(newMethod.getName(), newMethod.getSignature());
309         clazz.removeMethod(existingMethod);
310       } catch (NotFoundException e) {
311         // expected path... sigh.
312       }
313       clazz.addMethod(newMethod);
314     } catch (CannotCompileException e) {
315       throw new RuntimeException("Internal error while creating the getDelegate() method", e);
316     }
317   }
318 
getGetDelegateMethodSource()319   private String getGetDelegateMethodSource() {
320     return "public Object getDelegate___AndroidMock() { return this." + getDelegateFieldName()
321         + "; }";
322   }
323 
getDelegateFieldName()324   String getDelegateFieldName() {
325     return "delegateMockObject";
326   }
327 
addInterfaceMethods(Class<?> originalClass, CtClass newInterface)328   void addInterfaceMethods(Class<?> originalClass, CtClass newInterface) {
329     Method[] methods = getAllMethods(originalClass);
330     for (Method method : methods) {
331       try {
332         if (isMockable(method)) {
333           CtMethod newMethod = CtMethod.make(getInterfaceMethodSource(method), newInterface);
334           newInterface.addMethod(newMethod);
335         }
336       } catch (UnsupportedOperationException e) {
337         // Can't handle finals and statics.
338       } catch (CannotCompileException e) {
339         throw new RuntimeException(
340             "Internal error while creating a new Interface method for class "
341                 + originalClass.getName() + ".  Method name: " + method.getName(), e);
342       }
343     }
344   }
345 
addMethods(Class<?> superClass, CtClass newClass)346   void addMethods(Class<?> superClass, CtClass newClass) {
347     Method[] methods = getAllMethods(superClass);
348     if (newClass.isFrozen()) {
349       newClass.defrost();
350     }
351     List<CtMethod> existingMethods = Arrays.asList(newClass.getDeclaredMethods());
352     for (Method method : methods) {
353       try {
354         if (isMockable(method)) {
355           CtMethod newMethod = CtMethod.make(getDelegateMethodSource(method), newClass);
356           if (!existingMethods.contains(newMethod)) {
357             newClass.addMethod(newMethod);
358           }
359         }
360       } catch (UnsupportedOperationException e) {
361         // Can't handle finals and statics.
362       } catch (CannotCompileException e) {
363         throw new RuntimeException("Internal Error while creating subclass methods for "
364             + newClass.getName() + " method: " + method.getName(), e);
365       }
366     }
367   }
368 
getAllMethods(Class<?> clazz)369   Method[] getAllMethods(Class<?> clazz) {
370     Map<String, Method> methodMap = getAllMethodsMap(clazz);
371     return methodMap.values().toArray(new Method[0]);
372   }
373 
getAllMethodsMap(Class<?> clazz)374   private Map<String, Method> getAllMethodsMap(Class<?> clazz) {
375     Map<String, Method> methodMap = new HashMap<String, Method>();
376     Class<?> superClass = clazz.getSuperclass();
377     if (superClass != null) {
378       methodMap.putAll(getAllMethodsMap(superClass));
379     }
380     List<Method> methods = new ArrayList<Method>(Arrays.asList(clazz.getDeclaredMethods()));
381     for (Method method : methods) {
382       String key = method.getName();
383       for (Class<?> param : method.getParameterTypes()) {
384         key += param.getCanonicalName();
385       }
386       methodMap.put(key, method);
387     }
388     return methodMap;
389   }
390 
isMockable(Method method)391   boolean isMockable(Method method) {
392     if (isForbiddenMethod(method)) {
393       return false;
394     }
395     int modifiers = method.getModifiers();
396     return !Modifier.isFinal(modifiers) && !Modifier.isStatic(modifiers) && !method.isBridge()
397         && (Modifier.isPublic(modifiers) || Modifier.isProtected(modifiers));
398   }
399 
isForbiddenMethod(Method method)400   boolean isForbiddenMethod(Method method) {
401     if (method.getName().equals("equals")) {
402       return method.getParameterTypes().length == 1
403           && method.getParameterTypes()[0].equals(Object.class);
404     } else if (method.getName().equals("toString")) {
405       return method.getParameterTypes().length == 0;
406     } else if (method.getName().equals("hashCode")) {
407       return method.getParameterTypes().length == 0;
408     }
409     return false;
410   }
411 
getReturnDefault(Method method)412   private String getReturnDefault(Method method) {
413     Class<?> returnType = method.getReturnType();
414     if (!returnType.isPrimitive()) {
415       return "null";
416     } else if (returnType == Boolean.TYPE) {
417       return "false";
418     } else if (returnType == Void.TYPE) {
419       return "";
420     } else {
421       return "(" + returnType.getName() + ")0";
422     }
423   }
424 
getDelegateMethodSource(Method method)425   String getDelegateMethodSource(Method method) {
426     StringBuilder methodBody = getMethodSignature(method);
427     methodBody.append("{");
428     methodBody.append("if(this.");
429     methodBody.append(getDelegateFieldName());
430     methodBody.append("==null){return ");
431     methodBody.append(getReturnDefault(method));
432     methodBody.append(";}");
433     if (!method.getReturnType().equals(Void.TYPE)) {
434       methodBody.append("return ");
435     }
436     methodBody.append("this.");
437     methodBody.append(getDelegateFieldName());
438     methodBody.append(".");
439     methodBody.append(method.getName());
440     methodBody.append("(");
441     for (int i = 0; i < method.getParameterTypes().length; ++i) {
442       methodBody.append("arg");
443       methodBody.append(i);
444       if (i < method.getParameterTypes().length - 1) {
445         methodBody.append(",");
446       }
447     }
448     methodBody.append(");}");
449     return methodBody.toString();
450   }
451 
generateSkeletalClass(Class<?> superClass, CtClass newInterface, SdkVersion sdkVersion)452   CtClass generateSkeletalClass(Class<?> superClass, CtClass newInterface, SdkVersion sdkVersion)
453       throws ClassNotFoundException {
454     ClassPool classPool = getClassPool();
455     CtClass superCtClass = getCtClassForClass(superClass);
456     String subclassName = FileUtils.getSubclassNameFor(superClass, sdkVersion);
457 
458     CtClass newClass;
459     try {
460       newClass = classPool.makeClass(subclassName, superCtClass);
461     } catch (RuntimeException e) {
462       if (e.getMessage().contains("frozen class")) {
463         try {
464           return classPool.get(subclassName);
465         } catch (NotFoundException ex) {
466           throw new ClassNotFoundException("Internal Error: could not find class", ex);
467         }
468       }
469       throw e;
470     }
471 
472     try {
473       newClass.addField(new CtField(newInterface, getDelegateFieldName(), newClass));
474     } catch (CannotCompileException e) {
475       throw new RuntimeException("Internal error adding the delegate field to "
476           + newClass.getName(), e);
477     }
478     return newClass;
479   }
480 }
481