1 /*
2  * Copyright (C) 2012 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "jni_compiler.h"
18 
19 #include "base/logging.h"
20 #include "class_linker.h"
21 #include "compiled_method.h"
22 #include "dex_file-inl.h"
23 #include "driver/compiler_driver.h"
24 #include "driver/dex_compilation_unit.h"
25 #include "llvm/compiler_llvm.h"
26 #include "llvm/ir_builder.h"
27 #include "llvm/llvm_compilation_unit.h"
28 #include "llvm/runtime_support_llvm_func.h"
29 #include "llvm/utils_llvm.h"
30 #include "mirror/art_method.h"
31 #include "runtime.h"
32 #include "stack.h"
33 #include "thread.h"
34 
35 #include <llvm/ADT/SmallVector.h>
36 #include <llvm/IR/BasicBlock.h>
37 #include <llvm/IR/DerivedTypes.h>
38 #include <llvm/IR/Function.h>
39 #include <llvm/IR/Type.h>
40 
41 namespace art {
42 namespace llvm {
43 
44 using ::art::llvm::runtime_support::JniMethodEnd;
45 using ::art::llvm::runtime_support::JniMethodEndSynchronized;
46 using ::art::llvm::runtime_support::JniMethodEndWithReference;
47 using ::art::llvm::runtime_support::JniMethodEndWithReferenceSynchronized;
48 using ::art::llvm::runtime_support::JniMethodStart;
49 using ::art::llvm::runtime_support::JniMethodStartSynchronized;
50 using ::art::llvm::runtime_support::RuntimeId;
51 
JniCompiler(LlvmCompilationUnit * cunit,CompilerDriver * driver,const DexCompilationUnit * dex_compilation_unit)52 JniCompiler::JniCompiler(LlvmCompilationUnit* cunit,
53                          CompilerDriver* driver,
54                          const DexCompilationUnit* dex_compilation_unit)
55     : cunit_(cunit), driver_(driver), module_(cunit_->GetModule()),
56       context_(cunit_->GetLLVMContext()), irb_(*cunit_->GetIRBuilder()),
57       dex_compilation_unit_(dex_compilation_unit),
58       func_(NULL), elf_func_idx_(0) {
59   // Check: Ensure that JNI compiler will only get "native" method
60   CHECK(dex_compilation_unit->IsNative());
61 }
62 
Compile()63 CompiledMethod* JniCompiler::Compile() {
64   const bool is_static = dex_compilation_unit_->IsStatic();
65   const bool is_synchronized = dex_compilation_unit_->IsSynchronized();
66   const DexFile* dex_file = dex_compilation_unit_->GetDexFile();
67   DexFile::MethodId const& method_id =
68       dex_file->GetMethodId(dex_compilation_unit_->GetDexMethodIndex());
69   char const return_shorty = dex_file->GetMethodShorty(method_id)[0];
70   ::llvm::Value* this_object_or_class_object;
71 
72   uint32_t method_idx = dex_compilation_unit_->GetDexMethodIndex();
73   std::string func_name(StringPrintf("jni_%s",
74                                      MangleForJni(PrettyMethod(method_idx, *dex_file)).c_str()));
75   CreateFunction(func_name);
76 
77   // Set argument name
78   ::llvm::Function::arg_iterator arg_begin(func_->arg_begin());
79   ::llvm::Function::arg_iterator arg_end(func_->arg_end());
80   ::llvm::Function::arg_iterator arg_iter(arg_begin);
81 
82   DCHECK_NE(arg_iter, arg_end);
83   arg_iter->setName("method");
84   ::llvm::Value* method_object_addr = arg_iter++;
85 
86   if (!is_static) {
87     // Non-static, the second argument is "this object"
88     this_object_or_class_object = arg_iter++;
89   } else {
90     // Load class object
91     this_object_or_class_object =
92         irb_.LoadFromObjectOffset(method_object_addr,
93                                   mirror::ArtMethod::DeclaringClassOffset().Int32Value(),
94                                   irb_.getJObjectTy(),
95                                   kTBAAConstJObject);
96   }
97   // Actual argument (ignore method and this object)
98   arg_begin = arg_iter;
99 
100   // Count the number of Object* arguments
101   uint32_t handle_scope_size = 1;
102   // "this" object pointer for non-static
103   // "class" object pointer for static
104   for (unsigned i = 0; arg_iter != arg_end; ++i, ++arg_iter) {
105 #if !defined(NDEBUG)
106     arg_iter->setName(StringPrintf("a%u", i));
107 #endif
108     if (arg_iter->getType() == irb_.getJObjectTy()) {
109       ++handle_scope_size;
110     }
111   }
112 
113   // Shadow stack
114   ::llvm::StructType* shadow_frame_type = irb_.getShadowFrameTy(handle_scope_size);
115   ::llvm::AllocaInst* shadow_frame_ = irb_.CreateAlloca(shadow_frame_type);
116 
117   // Store the dex pc
118   irb_.StoreToObjectOffset(shadow_frame_,
119                            ShadowFrame::DexPCOffset(),
120                            irb_.getInt32(DexFile::kDexNoIndex),
121                            kTBAAShadowFrame);
122 
123   // Push the shadow frame
124   ::llvm::Value* shadow_frame_upcast = irb_.CreateConstGEP2_32(shadow_frame_, 0, 0);
125   ::llvm::Value* old_shadow_frame =
126       irb_.Runtime().EmitPushShadowFrame(shadow_frame_upcast, method_object_addr, handle_scope_size);
127 
128   // Get JNIEnv
129   ::llvm::Value* jni_env_object_addr =
130       irb_.Runtime().EmitLoadFromThreadOffset(Thread::JniEnvOffset().Int32Value(),
131                                               irb_.getJObjectTy(),
132                                               kTBAARuntimeInfo);
133 
134   // Get callee code_addr
135   ::llvm::Value* code_addr =
136       irb_.LoadFromObjectOffset(method_object_addr,
137                                 mirror::ArtMethod::NativeMethodOffset().Int32Value(),
138                                 GetFunctionType(dex_compilation_unit_->GetDexMethodIndex(),
139                                                 is_static, true)->getPointerTo(),
140                                 kTBAARuntimeInfo);
141 
142   // Load actual parameters
143   std::vector< ::llvm::Value*> args;
144 
145   // The 1st parameter: JNIEnv*
146   args.push_back(jni_env_object_addr);
147 
148   // Variables for GetElementPtr
149   ::llvm::Value* gep_index[] = {
150     irb_.getInt32(0),  // No displacement for shadow frame pointer
151     irb_.getInt32(1),  // handle scope
152     NULL,
153   };
154 
155   size_t handle_scope_member_index = 0;
156 
157   // Store the "this object or class object" to handle scope
158   gep_index[2] = irb_.getInt32(handle_scope_member_index++);
159   ::llvm::Value* handle_scope_field_addr = irb_.CreateBitCast(irb_.CreateGEP(shadow_frame_, gep_index),
160                                                     irb_.getJObjectTy()->getPointerTo());
161   irb_.CreateStore(this_object_or_class_object, handle_scope_field_addr, kTBAAShadowFrame);
162   // Push the "this object or class object" to out args
163   this_object_or_class_object = irb_.CreateBitCast(handle_scope_field_addr, irb_.getJObjectTy());
164   args.push_back(this_object_or_class_object);
165   // Store arguments to handle scope, and push back to args
166   for (arg_iter = arg_begin; arg_iter != arg_end; ++arg_iter) {
167     if (arg_iter->getType() == irb_.getJObjectTy()) {
168       // Store the reference type arguments to handle scope
169       gep_index[2] = irb_.getInt32(handle_scope_member_index++);
170       ::llvm::Value* handle_scope_field_addr = irb_.CreateBitCast(irb_.CreateGEP(shadow_frame_, gep_index),
171                                                         irb_.getJObjectTy()->getPointerTo());
172       irb_.CreateStore(arg_iter, handle_scope_field_addr, kTBAAShadowFrame);
173       // Note null is placed in the handle scope but the jobject passed to the native code must be null
174       // (not a pointer into the handle scope as with regular references).
175       ::llvm::Value* equal_null = irb_.CreateICmpEQ(arg_iter, irb_.getJNull());
176       ::llvm::Value* arg =
177           irb_.CreateSelect(equal_null,
178                             irb_.getJNull(),
179                             irb_.CreateBitCast(handle_scope_field_addr, irb_.getJObjectTy()));
180       args.push_back(arg);
181     } else {
182       args.push_back(arg_iter);
183     }
184   }
185 
186   ::llvm::Value* saved_local_ref_cookie;
187   {  // JniMethodStart
188     RuntimeId func_id = is_synchronized ? JniMethodStartSynchronized
189                                         : JniMethodStart;
190     ::llvm::SmallVector< ::llvm::Value*, 2> args;
191     if (is_synchronized) {
192       args.push_back(this_object_or_class_object);
193     }
194     args.push_back(irb_.Runtime().EmitGetCurrentThread());
195     saved_local_ref_cookie =
196         irb_.CreateCall(irb_.GetRuntime(func_id), args);
197   }
198 
199   // Call!!!
200   ::llvm::Value* retval = irb_.CreateCall(code_addr, args);
201 
202   {  // JniMethodEnd
203     bool is_return_ref = return_shorty == 'L';
204     RuntimeId func_id =
205         is_return_ref ? (is_synchronized ? JniMethodEndWithReferenceSynchronized
206                                          : JniMethodEndWithReference)
207                       : (is_synchronized ? JniMethodEndSynchronized
208                                          : JniMethodEnd);
209     ::llvm::SmallVector< ::llvm::Value*, 4> args;
210     if (is_return_ref) {
211       args.push_back(retval);
212     }
213     args.push_back(saved_local_ref_cookie);
214     if (is_synchronized) {
215       args.push_back(this_object_or_class_object);
216     }
217     args.push_back(irb_.Runtime().EmitGetCurrentThread());
218 
219     ::llvm::Value* decoded_jobject =
220         irb_.CreateCall(irb_.GetRuntime(func_id), args);
221 
222     // Return decoded jobject if return reference.
223     if (is_return_ref) {
224       retval = decoded_jobject;
225     }
226   }
227 
228   // Pop the shadow frame
229   irb_.Runtime().EmitPopShadowFrame(old_shadow_frame);
230 
231   // Return!
232   switch (return_shorty) {
233     case 'V':
234       irb_.CreateRetVoid();
235       break;
236     case 'Z':
237     case 'C':
238       irb_.CreateRet(irb_.CreateZExt(retval, irb_.getInt32Ty()));
239       break;
240     case 'B':
241     case 'S':
242       irb_.CreateRet(irb_.CreateSExt(retval, irb_.getInt32Ty()));
243       break;
244     default:
245       irb_.CreateRet(retval);
246       break;
247   }
248 
249   // Verify the generated bitcode
250   VERIFY_LLVM_FUNCTION(*func_);
251 
252   cunit_->Materialize();
253 
254   return new CompiledMethod(*driver_, cunit_->GetInstructionSet(), cunit_->GetElfObject(),
255                             func_name);
256 }
257 
258 
CreateFunction(const std::string & func_name)259 void JniCompiler::CreateFunction(const std::string& func_name) {
260   CHECK_NE(0U, func_name.size());
261 
262   const bool is_static = dex_compilation_unit_->IsStatic();
263 
264   // Get function type
265   ::llvm::FunctionType* func_type =
266     GetFunctionType(dex_compilation_unit_->GetDexMethodIndex(), is_static, false);
267 
268   // Create function
269   func_ = ::llvm::Function::Create(func_type, ::llvm::Function::InternalLinkage,
270                                    func_name, module_);
271 
272   // Create basic block
273   ::llvm::BasicBlock* basic_block = ::llvm::BasicBlock::Create(*context_, "B0", func_);
274 
275   // Set insert point
276   irb_.SetInsertPoint(basic_block);
277 }
278 
279 
GetFunctionType(uint32_t method_idx,bool is_static,bool is_native_function)280 ::llvm::FunctionType* JniCompiler::GetFunctionType(uint32_t method_idx,
281                                                    bool is_static, bool is_native_function) {
282   // Get method signature
283   uint32_t shorty_size;
284   const char* shorty = dex_compilation_unit_->GetShorty(&shorty_size);
285   CHECK_GE(shorty_size, 1u);
286 
287   // Get return type
288   ::llvm::Type* ret_type = NULL;
289   switch (shorty[0]) {
290     case 'V': ret_type =  irb_.getJVoidTy(); break;
291     case 'Z':
292     case 'B':
293     case 'C':
294     case 'S':
295     case 'I': ret_type =  irb_.getJIntTy(); break;
296     case 'F': ret_type =  irb_.getJFloatTy(); break;
297     case 'J': ret_type =  irb_.getJLongTy(); break;
298     case 'D': ret_type =  irb_.getJDoubleTy(); break;
299     case 'L': ret_type =  irb_.getJObjectTy(); break;
300     default: LOG(FATAL)  << "Unreachable: unexpected return type in shorty " << shorty;
301   }
302   // Get argument type
303   std::vector< ::llvm::Type*> args_type;
304 
305   args_type.push_back(irb_.getJObjectTy());  // method object pointer
306 
307   if (!is_static || is_native_function) {
308     // "this" object pointer for non-static
309     // "class" object pointer for static naitve
310     args_type.push_back(irb_.getJType('L'));
311   }
312 
313   for (uint32_t i = 1; i < shorty_size; ++i) {
314     args_type.push_back(irb_.getJType(shorty[i]));
315   }
316 
317   return ::llvm::FunctionType::get(ret_type, args_type, false);
318 }
319 
320 }  // namespace llvm
321 }  // namespace art
322