1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include <algorithm>
#include <cassert>
#include <cstdarg>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <sstream>
23 
24 #include <jni.h>
25 
26 #include "jvmti.h"
27 
28 #include <slicer/dex_ir.h>
29 #include <slicer/code_ir.h>
30 #include <slicer/dex_ir_builder.h>
31 #include <slicer/dex_utf8.h>
32 #include <slicer/writer.h>
33 #include <slicer/reader.h>
34 #include <slicer/instrumentation.h>
35 
36 using namespace dex;
37 using namespace lir;
38 
39 namespace com_android_dx_mockito_inline {
40 static jvmtiEnv* localJvmtiEnv;
41 
42 static jobject sTransformer;
43 
44 // Converts a class name to a type descriptor
45 // (ex. "java.lang.String" to "Ljava/lang/String;")
46 static std::string
ClassNameToDescriptor(const char * class_name)47 ClassNameToDescriptor(const char* class_name) {
48     std::stringstream ss;
49     ss << "L";
50     for (auto p = class_name; *p != '\0'; ++p) {
51          ss << (*p == '.' ? '/' : *p);
52     }
53     ss << ";";
54     return ss.str();
55 }
56 
57 // Takes the full dex file for class 'classBeingRedefined'
58 // - isolates the dex code for the class out of the dex file
59 // - calls sTransformer.runTransformers on the isolated dex code
60 // - send the transformed code back to the runtime
61 static void
Transform(jvmtiEnv * jvmti_env,JNIEnv * env,jclass classBeingRedefined,jobject loader,const char * name,jobject protectionDomain,jint classDataLen,const unsigned char * classData,jint * newClassDataLen,unsigned char ** newClassData)62 Transform(jvmtiEnv* jvmti_env,
63           JNIEnv* env,
64           jclass classBeingRedefined,
65           jobject loader,
66           const char* name,
67           jobject protectionDomain,
68           jint classDataLen,
69           const unsigned char* classData,
70           jint* newClassDataLen,
71           unsigned char** newClassData) {
72     if (sTransformer != NULL) {
73         // Even reading the classData array is expensive as the data is only generated when the
74         // memory is touched. Hence call JvmtiAgent#shouldTransform to check if we need to transform
75         // the class.
76         jclass cls = env->GetObjectClass(sTransformer);
77         jmethodID shouldTransformMethod = env->GetMethodID(cls, "shouldTransform",
78                                                            "(Ljava/lang/Class;)Z");
79 
80         jboolean shouldTransform = env->CallBooleanMethod(sTransformer, shouldTransformMethod,
81                                                           classBeingRedefined);
82         if (!shouldTransform) {
83             return;
84         }
85 
86         // Isolate byte code of class class. This is needed as Android usually gives us more
87         // than the class we need.
88         Reader reader(classData, classDataLen);
89 
90         u4 index = reader.FindClassIndex(ClassNameToDescriptor(name).c_str());
91         reader.CreateClassIr(index);
92         std::shared_ptr<ir::DexFile> ir = reader.GetIr();
93 
94         struct Allocator : public Writer::Allocator {
95             virtual void* Allocate(size_t size) {return ::malloc(size);}
96             virtual void Free(void* ptr) {::free(ptr);}
97         };
98 
99         Allocator allocator;
100         Writer writer(ir);
101         size_t isolatedClassLen = 0;
102         std::shared_ptr<jbyte> isolatedClass((jbyte*)writer.CreateImage(&allocator,
103                                                                         &isolatedClassLen));
104 
105         // Create jbyteArray with isolated byte code of class
106         jbyteArray isolatedClassArr = env->NewByteArray(isolatedClassLen);
107         env->SetByteArrayRegion(isolatedClassArr, 0, isolatedClassLen,
108                                 isolatedClass.get());
109 
110         jstring nameStr = env->NewStringUTF(name);
111 
112         // Call JvmtiAgent#runTransformers
113         jmethodID runTransformersMethod = env->GetMethodID(cls, "runTransformers",
114                                                            "(Ljava/lang/ClassLoader;"
115                                                            "Ljava/lang/String;"
116                                                            "Ljava/lang/Class;"
117                                                            "Ljava/security/ProtectionDomain;"
118                                                            "[B)[B");
119 
120         jbyteArray transformedArr = (jbyteArray) env->CallObjectMethod(sTransformer,
121                                                                        runTransformersMethod,
122                                                                        loader, nameStr,
123                                                                        classBeingRedefined,
124                                                                        protectionDomain,
125                                                                        isolatedClassArr);
126 
127         // Set transformed byte code
128         if (!env->ExceptionOccurred() && transformedArr != NULL) {
129             *newClassDataLen = env->GetArrayLength(transformedArr);
130 
131             jbyte* transformed = env->GetByteArrayElements(transformedArr, 0);
132 
133             jvmti_env->Allocate(*newClassDataLen, newClassData);
134             std::memcpy(*newClassData, transformed, *newClassDataLen);
135 
136             env->ReleaseByteArrayElements(transformedArr, transformed, 0);
137         }
138     }
139 }
140 
141 // Add a label before instructionAfter
142 static void
addLabel(CodeIr & c,lir::Instruction * instructionAfter,Label * returnTrueLabel)143 addLabel(CodeIr& c,
144          lir::Instruction* instructionAfter,
145          Label* returnTrueLabel) {
146     c.instructions.InsertBefore(instructionAfter, returnTrueLabel);
147 }
148 
149 // Add a byte code before instructionAfter
150 static void
addInstr(CodeIr & c,lir::Instruction * instructionAfter,Opcode opcode,const std::list<Operand * > & operands)151 addInstr(CodeIr& c,
152          lir::Instruction* instructionAfter,
153          Opcode opcode,
154          const std::list<Operand*>& operands) {
155     auto instruction = c.Alloc<Bytecode>();
156 
157     instruction->opcode = opcode;
158 
159     for (auto it = operands.begin(); it != operands.end(); it++) {
160         instruction->operands.push_back(*it);
161     }
162 
163     c.instructions.InsertBefore(instructionAfter, instruction);
164 }
165 
166 // Add a method call byte code before instructionAfter
167 static void
addCall(ir::Builder & b,CodeIr & c,lir::Instruction * instructionAfter,Opcode opcode,ir::Type * type,const char * methodName,ir::Type * returnType,const std::vector<ir::Type * > & types,const std::list<int> & regs)168 addCall(ir::Builder& b,
169         CodeIr& c,
170         lir::Instruction* instructionAfter,
171         Opcode opcode,
172         ir::Type* type,
173         const char* methodName,
174         ir::Type* returnType,
175         const std::vector<ir::Type*>& types,
176         const std::list<int>& regs) {
177     auto proto = b.GetProto(returnType, b.GetTypeList(types));
178     auto method = b.GetMethodDecl(b.GetAsciiString(methodName), proto, type);
179 
180     VRegList* param_regs = c.Alloc<VRegList>();
181     for (auto it = regs.begin(); it != regs.end(); it++) {
182         param_regs->registers.push_back(*it);
183     }
184 
185     addInstr(c, instructionAfter, opcode, {param_regs, c.Alloc<Method>(method,
186                                                                        method->orig_index)});
187 }
188 
189 typedef struct {
190     ir::Type* boxedType;
191     ir::Type* scalarType;
192     std::string unboxMethod;
193 } BoxingInfo;
194 
195 // Get boxing / unboxing info for a type
196 static BoxingInfo
getBoxingInfo(ir::Builder & b,char typeCode)197 getBoxingInfo(ir::Builder &b,
198               char typeCode) {
199     BoxingInfo boxingInfo;
200 
201     if (typeCode != 'L' && typeCode !=  '[') {
202         std::stringstream tmp;
203         tmp << typeCode;
204         boxingInfo.scalarType = b.GetType(tmp.str().c_str());
205     }
206 
207     switch (typeCode) {
208         case 'B':
209             boxingInfo.boxedType = b.GetType("Ljava/lang/Byte;");
210             boxingInfo.unboxMethod = "byteValue";
211             break;
212         case 'S':
213             boxingInfo.boxedType = b.GetType("Ljava/lang/Short;");
214             boxingInfo.unboxMethod = "shortValue";
215             break;
216         case 'I':
217             boxingInfo.boxedType = b.GetType("Ljava/lang/Integer;");
218             boxingInfo.unboxMethod = "intValue";
219             break;
220         case 'C':
221             boxingInfo.boxedType = b.GetType("Ljava/lang/Character;");
222             boxingInfo.unboxMethod = "charValue";
223             break;
224         case 'F':
225             boxingInfo.boxedType = b.GetType("Ljava/lang/Float;");
226             boxingInfo.unboxMethod = "floatValue";
227             break;
228         case 'Z':
229             boxingInfo.boxedType = b.GetType("Ljava/lang/Boolean;");
230             boxingInfo.unboxMethod = "booleanValue";
231             break;
232         case 'J':
233             boxingInfo.boxedType = b.GetType("Ljava/lang/Long;");
234             boxingInfo.unboxMethod = "longValue";
235             break;
236         case 'D':
237             boxingInfo.boxedType = b.GetType("Ljava/lang/Double;");
238             boxingInfo.unboxMethod = "doubleValue";
239             break;
240         default:
241             // real object
242             break;
243     }
244 
245     return boxingInfo;
246 }
247 
248 static size_t
getNumParams(ir::EncodedMethod * method)249 getNumParams(ir::EncodedMethod *method) {
250     if (method->decl->prototype->param_types == NULL) {
251         return 0;
252     }
253 
254     return method->decl->prototype->param_types->types.size();
255 }
256 
257 static bool
canBeTransformed(ir::EncodedMethod * method)258 canBeTransformed(ir::EncodedMethod *method) {
259     std::string type = method->decl->parent->Decl();
260     ir::String* methodName = method->decl->name;
261 
262     return !(((method->access_flags & (kAccAbstract | kAccPrivate | kAccBridge | kAccNative
263                                        | kAccStatic)) != 0)
264              || (Utf8Cmp(methodName->c_str(), "<init>") == 0)
265              || (Utf8Cmp(methodName->c_str(), "<clinit>") == 0)
266              || (Utf8Cmp(type.c_str(), "java.lang.Object") == 0
267                  && Utf8Cmp(methodName->c_str(), "finalize") == 0
268                  && getNumParams(method) == 0)
269              || (strncmp(type.c_str(), "java.", 5) == 0
270                  && (method->access_flags & (kAccPrivate | kAccPublic | kAccProtected)) == 0)
271              // getClass is used by MockMethodAdvice.isOverridden
272              || (Utf8Cmp(methodName->c_str(), "getClass") == 0));
273 }
274 
275 static bool
isHashCode(ir::EncodedMethod * method)276 isHashCode(ir::EncodedMethod *method) {
277     return Utf8Cmp(method->decl->name->c_str(), "hashCode") == 0
278            && getNumParams(method) == 0;
279 }
280 
281 static bool
isEquals(ir::EncodedMethod * method)282 isEquals(ir::EncodedMethod *method) {
283     return Utf8Cmp(method->decl->name->c_str(), "equals") == 0
284            && getNumParams(method) == 1
285            && Utf8Cmp(method->decl->prototype->param_types->types[0]->Decl().c_str(),
286                       "java.lang.Object") == 0;
287 }
288 
289 // Transforms the classes to add the mockito hooks
290 // - equals and hashcode are handled in a special way
291 extern "C" JNIEXPORT jbyteArray JNICALL
Java_com_android_dx_mockito_inline_ClassTransformer_nativeRedefine(JNIEnv * env,jobject generator,jstring idStr,jbyteArray originalArr)292 Java_com_android_dx_mockito_inline_ClassTransformer_nativeRedefine(JNIEnv* env,
293                                                                    jobject generator,
294                                                                    jstring idStr,
295                                                                    jbyteArray originalArr) {
296     unsigned char* original = (unsigned char*)env->GetByteArrayElements(originalArr, 0);
297 
298     Reader reader(original, env->GetArrayLength(originalArr));
299     reader.CreateClassIr(0);
300     std::shared_ptr<ir::DexFile> dex_ir = reader.GetIr();
301     ir::Builder b(dex_ir);
302 
303     ir::Type* booleanScalarT = b.GetType("Z");
304     ir::Type* intScalarT = b.GetType("I");
305     ir::Type* objectT = b.GetType("Ljava/lang/Object;");
306     ir::Type* objectArrayT = b.GetType("[Ljava/lang/Object;");
307     ir::Type* stringT = b.GetType("Ljava/lang/String;");
308     ir::Type* methodT = b.GetType("Ljava/lang/reflect/Method;");
309     ir::Type* systemT = b.GetType("Ljava/lang/System;");
310     ir::Type* callableT = b.GetType("Ljava/util/concurrent/Callable;");
311     ir::Type* dispatcherT = b.GetType("Lcom/android/dx/mockito/inline/MockMethodDispatcher;");
312 
313     // Add id to dex file
314     const char* idNative = env->GetStringUTFChars(idStr, 0);
315     ir::String* id = b.GetAsciiString(idNative);
316     env->ReleaseStringUTFChars(idStr, idNative);
317 
318     for (auto& method : dex_ir->encoded_methods) {
319         if (!canBeTransformed(method.get())) {
320             continue;
321         }
322 
323         if (isEquals(method.get())) {
324             /*
325             equals_original(Object other) {
326                 T t = foo(other);
327                 return bar(t);
328             }
329 
330             equals_transformed(params) {
331                 // MockMethodDispatcher dispatcher = MockMethodDispatcher.get(idStr, this);
332                 const-string v0, "65463hg34t"
333                 move-objectfrom16 v1, THIS
334                 invoke-static {v0, v1}, MockMethodDispatcher.get(String, Object):MockMethodDispatcher
335                 move-result-object v2
336 
337                 // if (dispatcher == null || ) {
338                 //     goto original_method;
339                 // }
340                 if-eqz v2, original_method
341 
342                 // if (!dispatcher.isMock(this)) {
343                 //     goto original_method;
344                 // }
345                 invoke-virtual {v2, v1}, MockMethodDispatcher.isMock(Object):Method
346                 move-result v2
347                 if-eqz v2, original_method
348 
349                 // return self == other
350                 move-objectfrom16 v0, ARG1
351                 if-eq v0, v1, return_true
352 
353                 const v0, 0
354                 return v0
355 
356             return true:
357                 const v0, 1
358                 return v0
359 
360             original_method:
361                 // Move all method arguments down so that they match what the original code expects.
362                 move-object16 v4, v5      # THIS
363                 move-object16 v5, v6      # ARG1
364 
365                 T t = foo(other);
366                 return bar(t);
367             }
368             */
369 
370             CodeIr c(method.get(), dex_ir);
371 
372             // Make sure there are at least 5 local registers to use
373             int originalNumRegisters = method->code->registers - method->code->ins_count;
374             int numAdditionalRegs = std::max(0, 3 - originalNumRegisters);
375             int thisReg = numAdditionalRegs + method->code->registers
376                           - method->code->ins_count;
377 
378             if (numAdditionalRegs > 0) {
379                 c.ir_method->code->registers += numAdditionalRegs;
380             }
381 
382             lir::Instruction* fi = *(c.instructions.begin());
383 
384             Label* originalMethodLabel = c.Alloc<Label>(0);
385             Label* returnTrueLabel = c.Alloc<Label>(0);
386             CodeLocation* originalMethod = c.Alloc<CodeLocation>(originalMethodLabel);
387             VReg* v0 = c.Alloc<VReg>(0);
388             VReg* v1 = c.Alloc<VReg>(1);
389             VReg* v2 = c.Alloc<VReg>(2);
390             VReg* thiz = c.Alloc<VReg>(thisReg);
391 
392             addInstr(c, fi, OP_CONST_STRING, {v0, c.Alloc<String>(id, id->orig_index)});
393             addInstr(c, fi, OP_MOVE_OBJECT_FROM16, {v1, thiz});
394             addCall(b, c, fi, OP_INVOKE_STATIC, dispatcherT, "get", dispatcherT,
395                     {stringT, objectT}, {0, 1});
396             addInstr(c, fi, OP_MOVE_RESULT_OBJECT, {v2});
397             addInstr(c, fi, OP_IF_EQZ, {v2, originalMethod});
398             addCall(b, c, fi, OP_INVOKE_VIRTUAL, dispatcherT, "isMock", booleanScalarT, {objectT},
399                     {2, 1});
400             addInstr(c, fi, OP_MOVE_RESULT, {v2});
401             addInstr(c, fi, OP_IF_EQZ, {v2, originalMethod});
402             addInstr(c, fi, OP_MOVE_OBJECT_FROM16, {v0, c.Alloc<VReg>(thisReg + 1)});
403             addInstr(c, fi, OP_IF_EQ, {v0, v1, c.Alloc<CodeLocation>(returnTrueLabel)});
404             addInstr(c, fi, OP_CONST, {v0, c.Alloc<Const32>(0)});
405             addInstr(c, fi, OP_RETURN, {v0});
406             addLabel(c, fi, returnTrueLabel);
407             addInstr(c, fi, OP_CONST, {v0, c.Alloc<Const32>(1)});
408             addInstr(c, fi, OP_RETURN, {v0});
409             addLabel(c, fi, originalMethodLabel);
410             addInstr(c, fi, OP_MOVE_OBJECT_16, {c.Alloc<VReg>(thisReg - numAdditionalRegs), thiz});
411             addInstr(c, fi, OP_MOVE_OBJECT_16, {c.Alloc<VReg>(thisReg - numAdditionalRegs + 1),
412                      c.Alloc<VReg>(thisReg + 1)});
413 
414             c.Assemble();
415         } else if (isHashCode(method.get())) {
416             /*
417             hashCode_original(Object other) {
418                 T t = foo(other);
419                 return bar(t);
420             }
421 
422             hashCode_transformed(params) {
423                 // MockMethodDispatcher dispatcher = MockMethodDispatcher.get(idStr, this);
424                 const-string v0, "65463hg34t"
425                 move-objectfrom16 v1, THIS
426                 invoke-static {v0, v1}, MockMethodDispatcher.get(String, Object):MockMethodDispatcher
427                 move-result-object v2
428 
429                 // if (dispatcher == null || ) {
430                 //     goto original_method;
431                 // }
432                 if-eqz v2, original_method
433 
434                 // if (!dispatcher.isMock(this)) {
435                 //     goto original_method;
436                 // }
437                 invoke-interface {v2, v1}, MockMethodDispatcher.isMock(Object):Method
438                 move-result v2
439                 if-eqz v2, original_method
440 
441                 // return System.identityHashCode(this);
442                 invoke-static {v1}, System.identityHashCode(Object):int
443                 move-result v2
444                 return v2
445 
446             original_method:
447                 // Move all method arguments down so that they match what the original code expects.
448                 move-object16 v4, v5      # THIS
449 
450                 T t = foo(other);
451                 return bar(t);
452             }
453             */
454 
455             CodeIr c(method.get(), dex_ir);
456 
457             // Make sure there are at least 5 local registers to use
458             int originalNumRegisters = method->code->registers - method->code->ins_count;
459             int numAdditionalRegs = std::max(0, 3 - originalNumRegisters);
460             int thisReg = numAdditionalRegs + method->code->registers - method->code->ins_count;
461 
462             if (numAdditionalRegs > 0) {
463                 c.ir_method->code->registers += numAdditionalRegs;
464             }
465 
466             lir::Instruction* fi = *(c.instructions.begin());
467 
468             Label* originalMethodLabel = c.Alloc<Label>(0);
469             CodeLocation* originalMethod = c.Alloc<CodeLocation>(originalMethodLabel);
470             VReg* v0 = c.Alloc<VReg>(0);
471             VReg* v1 = c.Alloc<VReg>(1);
472             VReg* v2 = c.Alloc<VReg>(2);
473             VReg* thiz = c.Alloc<VReg>(thisReg);
474 
475             addInstr(c, fi, OP_CONST_STRING, {v0, c.Alloc<String>(id, id->orig_index)});
476             addInstr(c, fi, OP_MOVE_OBJECT_FROM16, {v1, thiz});
477             addCall(b, c, fi, OP_INVOKE_STATIC, dispatcherT, "get", dispatcherT,
478                     {stringT, objectT}, {0, 1});
479             addInstr(c, fi, OP_MOVE_RESULT_OBJECT, {v2});
480             addInstr(c, fi, OP_IF_EQZ, {v2, originalMethod});
481             addCall(b, c, fi, OP_INVOKE_VIRTUAL, dispatcherT, "isMock", booleanScalarT, {objectT},
482                     {2, 1});
483             addInstr(c, fi, OP_MOVE_RESULT, {v2});
484             addInstr(c, fi, OP_IF_EQZ, {v2, originalMethod});
485             addCall(b, c, fi, OP_INVOKE_STATIC, systemT, "identityHashCode", intScalarT, {objectT},
486                     {1});
487             addInstr(c, fi, OP_MOVE_RESULT, {v2});
488             addInstr(c, fi, OP_RETURN, {v2});
489             addLabel(c, fi, originalMethodLabel);
490             addInstr(c, fi, OP_MOVE_OBJECT_16, {c.Alloc<VReg>(thisReg - numAdditionalRegs), thiz});
491 
492             c.Assemble();
493         } else {
494             /*
495             long method_original(int param1, long param2, String param3) {
496                 foo();
497                 return bar();
498             }
499 
500             long method_transformed(int param1, long param2, String param3) {
501                 // MockMethodDispatcher dispatcher = MockMethodDispatcher.get(idStr, this);
502                 const-string v0, "65463hg34t"
503                 move-objectfrom16 v1, THIS       # this is necessary as invoke-static cannot deal
504                                                  # with medium or high registers and THIS might not
505                                                  # be low
506                 invoke-static {v0, v1}, MockMethodDispatcher.get(String, Object):MockMethodDispatcher
507                 move-result-object v0
508 
509                 // if (dispatcher == null) {
510                 //    goto original_method;
511                 // }
512                 if-eqz v0, original_method
513 
514                 // Method origin = dispatcher.getOrigin(this, methodDesc);
515                 const-string v1 "fully.qualified.ClassName#original_method(int, long, String)"
516                 move-objectfrom16 v2, THIS       # this is necessary as invoke-static cannot deal
517                                                  # with medium or high registers and THIS might not
518                                                  # be low
519                 invoke-virtual {v0, v2, v1}, MockMethodDispatcher.getOrigin(Object, String):Method
520                 move-result-object v1
521 
522                 // if (origin == null) {
523                 //     goto original_method;
524                 // }
525                 if-eqz v1, original_method
526 
527                 // Create an array with Objects of all parameters.
528 
529                 //     Object[] arguments = new Object[3]
530                 const v3, 3
531                 new-array v2, v3, Object[]
532 
533                 //     Integer param1Integer = Integer.valueOf(param1)
534                 move-from16 v3, ARG1     # this is necessary as invoke-static cannot deal with high
535                                          # registers and ARG1 might be high
536                 invoke-static {v3}, Integer.valueOf(int):Integer
537                 move-result-object v3
538 
539                 //     arguments[0] = param1Integer
540                 const v4, 0
541                 aput-object v3, v2, v4
542 
543                 //     Long param2Long = Long.valueOf(param2)
544                 move-widefrom16 v3:v4, ARG2.1:ARG2.2 # this is necessary as invoke-static cannot
545                                                      # deal with high registers and ARG2 might be
546                                                      # high
547                 invoke-static {v3, v4}, Long.valueOf(long):Long
548                 move-result-object v3
549 
550                 //     arguments[1] = param2Long
551                 const v4, 1
552                 aput-object v3, v2, v4
553 
554                 //     arguments[2] = param3
555                 const v4, 2
556                 move-objectfrom16 v3, ARG3     # this is necessary as aput-object cannot deal with
557                                                # high registers and ARG3 might be high
558                 aput-object v3, v2, v4
559 
560                 // Callable<?> mocked = dispatcher.handle(this, origin, arguments);
561                 move-objectfrom16 v3, THIS       # this is necessary as invoke-virtual cannot deal
562                                                  # with medium or high registers and THIS might not
563                                                  # be low
564                 invoke-virtual {v0,v3,v1,v2}, MockMethodDispatcher.handle(Object, Method,
565                                                                           Object[]):Callable
566                 move-result-object v0
567 
568                 //  if (mocked != null) {
569                 if-eqz v0, original_method
570 
571                 //      Object ret = mocked.call();
572                 invoke-interface {v0}, Callable.call():Object
573                 move-result-object v0
574 
575                 //      Long retLong = (Long)ret
576                 check-cast v0, Long
577 
578                 //      long retlong = retLong.longValue();
579                 invoke-virtual {v0}, Long.longValue():long
580                 move-result-wide v0:v1
581 
582                 //      return retlong;
583                 return-wide v0:v1
584 
585                 //  }
586 
587             original_method:
588                 // Move all method arguments down so that they match what the original code expects.
589                 // Let's assume three arguments, one int, one long, one String and the and used to
590                 // use 4 registers
591                 move-object16 v4, v5      # THIS
592                 move16 v5, v6             # ARG1
593                 move-wide16 v6:v7, v7:v8  # ARG2 (overlapping moves are allowed)
594                 move-object16 v8, v9      # ARG3
595 
596                 // foo();
597                 // return bar();
598                 unmodified original byte code
599             }
600             */
601 
602             CodeIr c(method.get(), dex_ir);
603 
604             // Make sure there are at least 5 local registers to use
605             int originalNumRegisters = method->code->registers - method->code->ins_count;
606             int numAdditionalRegs = std::max(0, 5 - originalNumRegisters);
607             int thisReg = originalNumRegisters + numAdditionalRegs;
608 
609             if (numAdditionalRegs > 0) {
610                 c.ir_method->code->registers += numAdditionalRegs;
611             }
612 
613             lir::Instruction* fi = *(c.instructions.begin());
614 
615             // Add methodDesc to dex file
616             std::stringstream ss;
617             ss << method->decl->parent->Decl() << "#" << method->decl->name->c_str() << "(" ;
618             bool first = true;
619             if (method->decl->prototype->param_types != NULL) {
620                  for (const auto& type : method->decl->prototype->param_types->types) {
621                      if (first) {
622                          first = false;
623                      } else {
624                          ss << ",";
625                      }
626 
627                      ss << type->Decl().c_str();
628                  }
629             }
630             ss << ")";
631             std::string methodDescStr = ss.str();
632             ir::String* methodDesc = b.GetAsciiString(methodDescStr.c_str());
633 
634             size_t numParams = getNumParams(method.get());
635 
636             Label* originalMethodLabel = c.Alloc<Label>(0);
637             CodeLocation* originalMethod = c.Alloc<CodeLocation>(originalMethodLabel);
638             VReg* v0 = c.Alloc<VReg>(0);
639             VReg* v1 = c.Alloc<VReg>(1);
640             VReg* v2 = c.Alloc<VReg>(2);
641             VReg* v3 = c.Alloc<VReg>(3);
642             VReg* v4 = c.Alloc<VReg>(4);
643             VReg* thiz = c.Alloc<VReg>(thisReg);
644 
645             addInstr(c, fi, OP_CONST_STRING, {v0, c.Alloc<String>(id, id->orig_index)});
646             addInstr(c, fi, OP_MOVE_OBJECT_FROM16, {v1, thiz});
647             addCall(b, c, fi, OP_INVOKE_STATIC, dispatcherT, "get", dispatcherT, {stringT, objectT},
648                     {0, 1});
649             addInstr(c, fi, OP_MOVE_RESULT_OBJECT, {v0});
650             addInstr(c, fi, OP_IF_EQZ, {v0, originalMethod});
651             addInstr(c, fi, OP_CONST_STRING,
652                      {v1, c.Alloc<String>(methodDesc, methodDesc->orig_index)});
653             addInstr(c, fi, OP_MOVE_OBJECT_FROM16, {v2, thiz});
654             addCall(b, c, fi, OP_INVOKE_VIRTUAL, dispatcherT, "getOrigin", methodT,
655                     {objectT, stringT}, {0, 2, 1});
656             addInstr(c, fi, OP_MOVE_RESULT_OBJECT, {v1});
657             addInstr(c, fi, OP_IF_EQZ, {v1, originalMethod});
658             addInstr(c, fi, OP_CONST, {v3, c.Alloc<Const32>(numParams)});
659             addInstr(c, fi, OP_NEW_ARRAY, {v2, v3, c.Alloc<Type>(objectArrayT,
660                                                                  objectArrayT->orig_index)});
661 
662             if (numParams > 0) {
663                 int argReg = thisReg + 1;
664 
665                 for (int argNum = 0; argNum < numParams; argNum++) {
666                     const auto& type = method->decl->prototype->param_types->types[argNum];
667                     BoxingInfo boxingInfo = getBoxingInfo(b, type->descriptor->c_str()[0]);
668 
669                     switch (type->GetCategory()) {
670                         case ir::Type::Category::Scalar:
671                             addInstr(c, fi, OP_MOVE_FROM16, {v3, c.Alloc<VReg>(argReg)});
672                             addCall(b, c, fi, OP_INVOKE_STATIC, boxingInfo.boxedType, "valueOf",
673                                     boxingInfo.boxedType, {type}, {3});
674                             addInstr(c, fi, OP_MOVE_RESULT_OBJECT, {v3});
675 
676                             argReg++;
677                             break;
678                         case ir::Type::Category::WideScalar: {
679                             VRegPair* v3v4 = c.Alloc<VRegPair>(3);
680                             VRegPair* argRegPair = c.Alloc<VRegPair>(argReg);
681 
682                             addInstr(c, fi, OP_MOVE_WIDE_FROM16, {v3v4, argRegPair});
683                             addCall(b, c, fi, OP_INVOKE_STATIC, boxingInfo.boxedType, "valueOf",
684                                     boxingInfo.boxedType, {type}, {3, 4});
685                             addInstr(c, fi, OP_MOVE_RESULT_OBJECT, {v3});
686 
687                             argReg += 2;
688                             break;
689                         }
690                         case ir::Type::Category::Reference:
691                             addInstr(c, fi, OP_MOVE_OBJECT_FROM16, {v3, c.Alloc<VReg>(argReg)});
692 
693                             argReg++;
694                             break;
695                         case ir::Type::Category::Void:
696                             assert(false);
697                     }
698 
699                     addInstr(c, fi, OP_CONST, {v4, c.Alloc<Const32>(argNum)});
700                     addInstr(c, fi, OP_APUT_OBJECT, {v3, v2, v4});
701                 }
702             }
703 
704             addInstr(c, fi, OP_MOVE_OBJECT_FROM16, {v3, thiz});
705             addCall(b, c, fi, OP_INVOKE_VIRTUAL, dispatcherT, "handle", callableT,
706                     {objectT, methodT, objectArrayT}, {0, 3, 1, 2});
707             addInstr(c, fi, OP_MOVE_RESULT_OBJECT, {v0});
708             addInstr(c, fi, OP_IF_EQZ, {v0, originalMethod});
709             addCall(b, c, fi, OP_INVOKE_INTERFACE, callableT, "call", objectT, {}, {0});
710             addInstr(c, fi, OP_MOVE_RESULT_OBJECT, {v0});
711 
712             ir::Type *returnType = method->decl->prototype->return_type;
713             BoxingInfo boxingInfo = getBoxingInfo(b, returnType->descriptor->c_str()[0]);
714 
715             switch (returnType->GetCategory()) {
716                 case ir::Type::Category::Scalar:
717                     addInstr(c, fi, OP_CHECK_CAST, {v0,
718                             c.Alloc<Type>(boxingInfo.boxedType, boxingInfo.boxedType->orig_index)});
719                     addCall(b, c, fi, OP_INVOKE_VIRTUAL, boxingInfo.boxedType,
720                             boxingInfo.unboxMethod.c_str(), returnType, {}, {0});
721                     addInstr(c, fi, OP_MOVE_RESULT, {v0});
722                     addInstr(c, fi, OP_RETURN, {v0});
723                     break;
724                 case ir::Type::Category::WideScalar: {
725                     VRegPair* v0v1 = c.Alloc<VRegPair>(0);
726 
727                     addInstr(c, fi, OP_CHECK_CAST, {v0,
728                             c.Alloc<Type>(boxingInfo.boxedType, boxingInfo.boxedType->orig_index)});
729                     addCall(b, c, fi, OP_INVOKE_VIRTUAL, boxingInfo.boxedType,
730                             boxingInfo.unboxMethod.c_str(), returnType, {}, {0});
731                     addInstr(c, fi, OP_MOVE_RESULT_WIDE, {v0v1});
732                     addInstr(c, fi, OP_RETURN_WIDE, {v0v1});
733                     break;
734                 }
735                 case ir::Type::Category::Reference:
736                     addInstr(c, fi, OP_CHECK_CAST, {v0, c.Alloc<Type>(returnType,
737                                                                       returnType->orig_index)});
738                     addInstr(c, fi, OP_RETURN_OBJECT, {v0});
739                     break;
740                 case ir::Type::Category::Void:
741                     addInstr(c, fi, OP_RETURN_VOID, {});
742                     break;
743             }
744 
745             addLabel(c, fi, originalMethodLabel);
746             addInstr(c, fi, OP_MOVE_OBJECT_16, {c.Alloc<VReg>(thisReg - numAdditionalRegs), thiz});
747 
748             if (numParams > 0) {
749                 int argReg = thisReg + 1;
750 
751                 for (int argNum = 0; argNum < numParams; argNum++) {
752                     const auto& type = method->decl->prototype->param_types->types[argNum];
753                     int origReg = argReg - numAdditionalRegs;
754                     switch (type->GetCategory()) {
755                         case ir::Type::Category::Scalar:
756                             addInstr(c, fi, OP_MOVE_16, {c.Alloc<VReg>(origReg),
757                                      c.Alloc<VReg>(argReg)});
758                             argReg++;
759                             break;
760                         case ir::Type::Category::WideScalar:
761                             addInstr(c, fi, OP_MOVE_WIDE_16,{c.Alloc<VRegPair>(origReg),
762                                      c.Alloc<VRegPair>(argReg)});
763                             argReg +=2;
764                             break;
765                         case ir::Type::Category::Reference:
766                             addInstr(c, fi, OP_MOVE_OBJECT_16, {c.Alloc<VReg>(origReg),
767                                      c.Alloc<VReg>(argReg)});
768                             argReg++;
769                             break;
770                     }
771                 }
772             }
773 
774             c.Assemble();
775         }
776     }
777 
778     struct Allocator : public Writer::Allocator {
779         virtual void* Allocate(size_t size) {return ::malloc(size);}
780         virtual void Free(void* ptr) {::free(ptr);}
781     };
782 
783     Allocator allocator;
784     Writer writer(dex_ir);
785     size_t transformedLen = 0;
786     std::shared_ptr<jbyte> transformed((jbyte*)writer.CreateImage(&allocator, &transformedLen));
787 
788     jbyteArray transformedArr = env->NewByteArray(transformedLen);
789     env->SetByteArrayRegion(transformedArr, 0, transformedLen, transformed.get());
790 
791     return transformedArr;
792 }
793 
794 // Initializes the agent
Agent_OnAttach(JavaVM * vm,char * options,void * reserved)795 extern "C" jint Agent_OnAttach(JavaVM* vm,
796                                char* options,
797                                void* reserved) {
798     jint jvmError = vm->GetEnv(reinterpret_cast<void**>(&localJvmtiEnv), JVMTI_VERSION_1_2);
799     if (jvmError != JNI_OK) {
800         return jvmError;
801     }
802 
803     jvmtiCapabilities caps;
804     memset(&caps, 0, sizeof(caps));
805     caps.can_retransform_classes = 1;
806 
807     jvmtiError error = localJvmtiEnv->AddCapabilities(&caps);
808     if (error != JVMTI_ERROR_NONE) {
809         return error;
810     }
811 
812     jvmtiEventCallbacks cb;
813     memset(&cb, 0, sizeof(cb));
814     cb.ClassFileLoadHook = Transform;
815 
816     error = localJvmtiEnv->SetEventCallbacks(&cb, sizeof(cb));
817     if (error != JVMTI_ERROR_NONE) {
818         return error;
819     }
820 
821     error = localJvmtiEnv->SetEventNotificationMode(JVMTI_ENABLE, JVMTI_EVENT_CLASS_FILE_LOAD_HOOK,
822                                                     NULL);
823     if (error != JVMTI_ERROR_NONE) {
824         return error;
825     }
826 
827     return JVMTI_ERROR_NONE;
828 }
829 
830 // Throw runtime exception
throwRuntimeExpection(JNIEnv * env,const char * fmt,...)831 static void throwRuntimeExpection(JNIEnv* env, const char* fmt, ...) {
832     char msgBuf[512];
833 
834     va_list args;
835     va_start (args, fmt);
836     vsnprintf(msgBuf, sizeof(msgBuf), fmt, args);
837     va_end (args);
838 
839     jclass exceptionClass = env->FindClass("java/lang/RuntimeException");
840     env->ThrowNew(exceptionClass, msgBuf);
841 }
842 
843 // Register transformer hook
844 extern "C" JNIEXPORT void JNICALL
Java_com_android_dx_mockito_inline_JvmtiAgent_nativeRegisterTransformerHook(JNIEnv * env,jobject thiz)845 Java_com_android_dx_mockito_inline_JvmtiAgent_nativeRegisterTransformerHook(JNIEnv* env,
846                                                                             jobject thiz) {
847     sTransformer = env->NewGlobalRef(thiz);
848 }
849 
850 // Unregister transformer hook
851 extern "C" JNIEXPORT void JNICALL
Java_com_android_dx_mockito_inline_JvmtiAgent_nativeUnregisterTransformerHook(JNIEnv * env,jobject thiz)852 Java_com_android_dx_mockito_inline_JvmtiAgent_nativeUnregisterTransformerHook(JNIEnv* env,
853                                                                               jobject thiz) {
854     env->DeleteGlobalRef(sTransformer);
855     sTransformer = NULL;
856 }
857 
858 // Triggers retransformation of classes via this file's Transform method
859 extern "C" JNIEXPORT void JNICALL
Java_com_android_dx_mockito_inline_JvmtiAgent_nativeRetransformClasses(JNIEnv * env,jobject thiz,jobjectArray classes)860 Java_com_android_dx_mockito_inline_JvmtiAgent_nativeRetransformClasses(JNIEnv* env,
861                                                                        jobject thiz,
862                                                                        jobjectArray classes) {
863     jsize numTransformedClasses = env->GetArrayLength(classes);
864     jclass *transformedClasses = (jclass*) malloc(numTransformedClasses * sizeof(jclass));
865     for (int i = 0; i < numTransformedClasses; i++) {
866         transformedClasses[i] = (jclass) env->NewGlobalRef(env->GetObjectArrayElement(classes, i));
867     }
868 
869     jvmtiError error = localJvmtiEnv->RetransformClasses(numTransformedClasses,
870                                                          transformedClasses);
871 
872     for (int i = 0; i < numTransformedClasses; i++) {
873         env->DeleteGlobalRef(transformedClasses[i]);
874     }
875     free(transformedClasses);
876 
877     if (error != JVMTI_ERROR_NONE) {
878         throwRuntimeExpection(env, "Could not retransform classes: %d", error);
879     }
880 }
881 
882 // Adds a jar file to the bootstrap class loader
883 extern "C" JNIEXPORT void JNICALL
Java_com_android_dx_mockito_inline_JvmtiAgent_nativeAppendToBootstrapClassLoaderSearch(JNIEnv * env,jclass klass,jstring jarFile)884 Java_com_android_dx_mockito_inline_JvmtiAgent_nativeAppendToBootstrapClassLoaderSearch(JNIEnv* env,
885                                                                                   jclass klass,
886                                                                                   jstring jarFile) {
887     const char *jarFileNative = env->GetStringUTFChars(jarFile, 0);
888     jvmtiError error = localJvmtiEnv->AddToBootstrapClassLoaderSearch(jarFileNative);
889 
890     if (error != JVMTI_ERROR_NONE) {
891         throwRuntimeExpection(env, "Could not add %s to bootstrap class path: %d", jarFileNative,
892                               error);
893     }
894 
895     env->ReleaseStringUTFChars(jarFile, jarFileNative);
896 }
897 }  // namespace com_android_dx_mockito_inline
898 
899