1 /*
2  * Copyright (c) 2007 Mockito contributors
3  * This program is made available under the terms of the MIT License.
4  */
5 package org.mockito.internal.stubbing.answers;
6 
7 import static org.mockito.internal.exceptions.Reporter.invalidArgumentPositionRangeAtInvocationTime;
8 import static org.mockito.internal.exceptions.Reporter.invalidArgumentRangeAtIdentityAnswerCreationTime;
9 import static org.mockito.internal.exceptions.Reporter.wrongTypeOfArgumentToReturn;
10 
11 import java.io.Serializable;
12 import java.lang.reflect.Method;
13 import org.mockito.invocation.Invocation;
14 import org.mockito.invocation.InvocationOnMock;
15 import org.mockito.stubbing.Answer;
16 import org.mockito.stubbing.ValidableAnswer;
17 
18 /**
19  * Returns the passed parameter identity at specified index.
20  * <p>
21  * <p>
22  * The <code>argumentIndex</code> represents the index in the argument array of the invocation.
23  * </p>
24  * <p>
25  * If this number equals -1 then the last argument is returned.
26  * </p>
27  *
28  * @see org.mockito.AdditionalAnswers
29  * @since 1.9.5
30  */
31 public class ReturnsArgumentAt implements Answer<Object>, ValidableAnswer, Serializable {
32 
33     private static final long serialVersionUID = -589315085166295101L;
34 
35     public static final int LAST_ARGUMENT = -1;
36 
37     private final int wantedArgumentPosition;
38 
39     /**
40      * Build the identity answer to return the argument at the given position in the argument array.
41      *
42      * @param wantedArgumentPosition
43      *            The position of the argument identity to return in the invocation. Using <code>-1</code> indicates the last argument ({@link #LAST_ARGUMENT}).
44      */
ReturnsArgumentAt(int wantedArgumentPosition)45     public ReturnsArgumentAt(int wantedArgumentPosition) {
46         if (wantedArgumentPosition != LAST_ARGUMENT && wantedArgumentPosition < 0) {
47             throw invalidArgumentRangeAtIdentityAnswerCreationTime();
48         }
49         this.wantedArgumentPosition = wantedArgumentPosition;
50     }
51 
52     @Override
answer(InvocationOnMock invocation)53     public Object answer(InvocationOnMock invocation) throws Throwable {
54         int argumentPosition = inferWantedArgumentPosition(invocation);
55         validateIndexWithinInvocationRange(invocation, argumentPosition);
56 
57         if (wantedArgIndexIsVarargAndSameTypeAsReturnType(invocation.getMethod(), argumentPosition)) {
58             // answer raw vararg array argument
59             return ((Invocation) invocation).getRawArguments()[argumentPosition];
60         }
61 
62         // answer expanded argument at wanted position
63         return invocation.getArgument(argumentPosition);
64 
65     }
66 
67     @Override
validateFor(InvocationOnMock invocation)68     public void validateFor(InvocationOnMock invocation) {
69         int argumentPosition = inferWantedArgumentPosition(invocation);
70         validateIndexWithinInvocationRange(invocation, argumentPosition);
71         validateArgumentTypeCompatibility((Invocation) invocation, argumentPosition);
72     }
73 
inferWantedArgumentPosition(InvocationOnMock invocation)74     private int inferWantedArgumentPosition(InvocationOnMock invocation) {
75         if (wantedArgumentPosition == LAST_ARGUMENT)
76             return invocation.getArguments().length - 1;
77 
78         return wantedArgumentPosition;
79     }
80 
validateIndexWithinInvocationRange(InvocationOnMock invocation, int argumentPosition)81     private void validateIndexWithinInvocationRange(InvocationOnMock invocation, int argumentPosition) {
82         if (!wantedArgumentPositionIsValidForInvocation(invocation, argumentPosition)) {
83             throw invalidArgumentPositionRangeAtInvocationTime(invocation,
84                                                                wantedArgumentPosition == LAST_ARGUMENT,
85                                                                wantedArgumentPosition);
86         }
87     }
88 
validateArgumentTypeCompatibility(Invocation invocation, int argumentPosition)89     private void validateArgumentTypeCompatibility(Invocation invocation, int argumentPosition) {
90         InvocationInfo invocationInfo = new InvocationInfo(invocation);
91 
92         Class<?> inferredArgumentType = inferArgumentType(invocation, argumentPosition);
93 
94         if (!invocationInfo.isValidReturnType(inferredArgumentType)){
95             throw wrongTypeOfArgumentToReturn(invocation,
96                                               invocationInfo.printMethodReturnType(),
97                                               inferredArgumentType,
98                                               wantedArgumentPosition);
99         }
100     }
101 
wantedArgIndexIsVarargAndSameTypeAsReturnType(Method method, int argumentPosition)102     private boolean wantedArgIndexIsVarargAndSameTypeAsReturnType(Method method, int argumentPosition) {
103         Class<?>[] parameterTypes = method.getParameterTypes();
104         return method.isVarArgs() &&
105               argumentPosition == /* vararg index */ parameterTypes.length - 1 &&
106               method.getReturnType().isAssignableFrom(parameterTypes[argumentPosition]);
107     }
108 
wantedArgumentPositionIsValidForInvocation(InvocationOnMock invocation, int argumentPosition)109     private boolean wantedArgumentPositionIsValidForInvocation(InvocationOnMock invocation, int argumentPosition) {
110         if (argumentPosition < 0) {
111             return false;
112         }
113         if (!invocation.getMethod().isVarArgs()) {
114             return invocation.getArguments().length > argumentPosition;
115         }
116         // for all varargs accepts positive ranges
117         return true;
118     }
119 
inferArgumentType(Invocation invocation, int argumentIndex)120     private Class<?> inferArgumentType(Invocation invocation, int argumentIndex) {
121         Class<?>[] parameterTypes = invocation.getMethod().getParameterTypes();
122 
123         // Easy when the method is not a vararg
124         if (!invocation.getMethod().isVarArgs()) {
125             Class<?> argumentType = parameterTypes[argumentIndex];
126             Object argumentValue = invocation.getArgument(argumentIndex);
127             // we don't want to return primitive wrapper types
128             if (argumentType.isPrimitive() || argumentValue == null) {
129                 return argumentType;
130             }
131 
132             return argumentValue.getClass();
133         }
134 
135         // Now for varargs
136         int varargIndex = parameterTypes.length - 1; // vararg always last
137 
138         if (argumentIndex < varargIndex) {
139             // Same for non vararg arguments
140             return parameterTypes[argumentIndex];
141         }
142         // if wanted argument is vararg
143         if (wantedArgIndexIsVarargAndSameTypeAsReturnType(invocation.getMethod(), argumentIndex)) {
144             // return the vararg array if return type is compatible
145             // because the user probably want to return the array itself if the return type is compatible
146             return parameterTypes[argumentIndex]; // move to MethodInfo ?
147         }
148         // return the type in this vararg array
149         return parameterTypes[varargIndex].getComponentType();
150 
151     }
152 }
153 
154