1 package junitparams.internal.parameters;
2 
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;

import org.junit.runners.model.FrameworkMethod;

import junitparams.Parameters;
12 
13 class ParamsFromMethodCommon {
14     private FrameworkMethod frameworkMethod;
15 
ParamsFromMethodCommon(FrameworkMethod frameworkMethod)16     ParamsFromMethodCommon(FrameworkMethod frameworkMethod) {
17         this.frameworkMethod = frameworkMethod;
18     }
19 
paramsFromMethod(Class<?> sourceClass)20     Object[] paramsFromMethod(Class<?> sourceClass) {
21         String methodAnnotation = frameworkMethod.getAnnotation(Parameters.class).method();
22 
23         if (methodAnnotation.isEmpty()) {
24             return invokeMethodWithParams(defaultMethodName(), sourceClass);
25         }
26 
27         List<Object> result = new ArrayList<Object>();
28         for (String methodName : methodAnnotation.split(",")) {
29             for (Object param : invokeMethodWithParams(methodName.trim(), sourceClass))
30                 result.add(param);
31         }
32 
33         return result.toArray();
34     }
35 
getDataFromMethod(Method providerMethod)36     Object[] getDataFromMethod(Method providerMethod) throws IllegalAccessException, InvocationTargetException {
37         return encapsulateParamsIntoArrayIfSingleParamsetPassed((Object[]) providerMethod.invoke(null));
38     }
39 
containsDefaultParametersProvidingMethod(Class<?> sourceClass)40     boolean containsDefaultParametersProvidingMethod(Class<?> sourceClass) {
41         return findMethodInTestClassHierarchy(defaultMethodName(), sourceClass) != null;
42     }
43 
defaultMethodName()44     private String defaultMethodName() {
45         return "parametersFor" + frameworkMethod.getName().substring(0, 1).toUpperCase()
46                 + this.frameworkMethod.getName().substring(1);
47     }
48 
invokeMethodWithParams(String methodName, Class<?> sourceClass)49     private Object[] invokeMethodWithParams(String methodName, Class<?> sourceClass) {
50         Method providerMethod = findMethodInTestClassHierarchy(methodName, sourceClass);
51         if (providerMethod == null) {
52             throw new RuntimeException("Could not find method: " + methodName + " so no params were used.");
53         }
54 
55         return invokeParamsProvidingMethod(providerMethod, sourceClass);
56     }
57 
58     @SuppressWarnings("unchecked")
invokeParamsProvidingMethod(Method provideMethod, Class<?> sourceClass)59     private Object[] invokeParamsProvidingMethod(Method provideMethod, Class<?> sourceClass) {
60         try {
61             Object testObject = sourceClass.newInstance();
62             provideMethod.setAccessible(true);
63             Object result = provideMethod.invoke(testObject);
64 
65             if (Object[].class.isAssignableFrom(result.getClass())) {
66                 Object[] params = (Object[]) result;
67                 return encapsulateParamsIntoArrayIfSingleParamsetPassed(params);
68             }
69 
70             if (Iterable.class.isAssignableFrom(result.getClass())) {
71                 try {
72                     ArrayList<Object[]> res = new ArrayList<Object[]>();
73                     for (Object[] paramSet : (Iterable<Object[]>) result)
74                         res.add(paramSet);
75                     return res.toArray();
76                 } catch (ClassCastException e1) {
77                     // Iterable with consecutive paramsets, each of one param
78                     ArrayList<Object> res = new ArrayList<Object>();
79                     for (Object param : (Iterable<?>) result)
80                         res.add(new Object[]{param});
81                     return res.toArray();
82                 }
83             }
84 
85             if (Iterator.class.isAssignableFrom(result.getClass())) {
86                 Object iteratedElement = null;
87                 try {
88                     ArrayList<Object[]> res = new ArrayList<Object[]>();
89                     Iterator<Object[]> iterator = (Iterator<Object[]>) result;
90                     while (iterator.hasNext()) {
91                         iteratedElement = iterator.next();
92                         // ClassCastException will occur in the following line
93                         // if the iterator is actually Iterator<Object> in Java 7
94                         res.add((Object[]) iteratedElement);
95                     }
96                     return res.toArray();
97                 } catch (ClassCastException e1) {
98                     // Iterator with consecutive paramsets, each of one param
99                     ArrayList<Object> res = new ArrayList<Object>();
100                     Iterator<?> iterator = (Iterator<?>) result;
101                     // The first element is already stored in iteratedElement
102                     res.add(iteratedElement);
103                     while (iterator.hasNext()) {
104                         res.add(new Object[]{iterator.next()});
105                     }
106                     return res.toArray();
107                 }
108             }
109 
110             throw new ClassCastException();
111 
112         } catch (ClassCastException e) {
113             throw new RuntimeException("The return type of: " + provideMethod.getName() + " defined in class " +
114                     sourceClass + " is not Object[][] nor Iterable<Object[]>. Fix it!", e);
115         } catch (Exception e) {
116             throw new RuntimeException("Could not invoke method: " + provideMethod.getName() + " defined in class " +
117                     sourceClass + " so no params were used.", e);
118         }
119     }
120 
findMethodInTestClassHierarchy(String methodName, Class<?> sourceClass)121     private Method findMethodInTestClassHierarchy(String methodName, Class<?> sourceClass) {
122         Class<?> declaringClass = sourceClass;
123         while (declaringClass.getSuperclass() != null) {
124             try {
125                 return declaringClass.getDeclaredMethod(methodName);
126             } catch (Exception ignore) {
127             }
128             declaringClass = declaringClass.getSuperclass();
129         }
130         return null;
131     }
132 
encapsulateParamsIntoArrayIfSingleParamsetPassed(Object[] params)133     private Object[] encapsulateParamsIntoArrayIfSingleParamsetPassed(Object[] params) {
134         if (frameworkMethod.getMethod().getParameterTypes().length != params.length) {
135             return params;
136         }
137 
138         if (params.length == 0) {
139             return params;
140         }
141 
142         Object param = params[0];
143         if (param == null || !param.getClass().isArray()) {
144             return new Object[]{params};
145         }
146 
147         return params;
148     }
149 
150 }
151