1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 // JNI wrapper for actions.
18 
19 #include "actions/actions_jni.h"
20 
21 #include <jni.h>
22 #include <map>
23 #include <type_traits>
24 #include <vector>
25 
26 #include "actions/actions-suggestions.h"
27 #include "annotator/annotator.h"
28 #include "annotator/annotator_jni_common.h"
29 #include "utils/base/integral_types.h"
30 #include "utils/intents/intent-generator.h"
31 #include "utils/intents/jni.h"
32 #include "utils/java/jni-cache.h"
33 #include "utils/java/scoped_local_ref.h"
34 #include "utils/java/string_utils.h"
35 #include "utils/memory/mmap.h"
36 
37 using libtextclassifier3::ActionsSuggestions;
38 using libtextclassifier3::ActionsSuggestionsResponse;
39 using libtextclassifier3::ActionSuggestion;
40 using libtextclassifier3::ActionSuggestionOptions;
41 using libtextclassifier3::Annotator;
42 using libtextclassifier3::Conversation;
43 using libtextclassifier3::IntentGenerator;
44 using libtextclassifier3::ScopedLocalRef;
45 using libtextclassifier3::ToStlString;
46 
// When using Java's ICU, UniLib needs to be instantiated with a JniCache
// created from the JNI environment. When using a standard ICU the cache is not
// needed, a null UniLib is passed and the objects are instantiated implicitly.
50 #ifdef TC3_UNILIB_JAVAICU
51 using libtextclassifier3::UniLib;
52 #endif
53 
54 namespace libtextclassifier3 {
55 
56 namespace {
57 
// Cached state for model inference.
// Keeps a JNI cache, intent generator, remote action template handler and model
// instance so that they don't have to be recreated for each call.
61 class ActionsSuggestionsJniContext {
62  public:
Create(const std::shared_ptr<libtextclassifier3::JniCache> & jni_cache,std::unique_ptr<ActionsSuggestions> model)63   static ActionsSuggestionsJniContext* Create(
64       const std::shared_ptr<libtextclassifier3::JniCache>& jni_cache,
65       std::unique_ptr<ActionsSuggestions> model) {
66     if (jni_cache == nullptr || model == nullptr) {
67       return nullptr;
68     }
69     std::unique_ptr<IntentGenerator> intent_generator =
70         IntentGenerator::Create(model->model()->android_intent_options(),
71                                 model->model()->resources(), jni_cache);
72     std::unique_ptr<RemoteActionTemplatesHandler> template_handler =
73         libtextclassifier3::RemoteActionTemplatesHandler::Create(jni_cache);
74 
75     if (intent_generator == nullptr || template_handler == nullptr) {
76       return nullptr;
77     }
78 
79     return new ActionsSuggestionsJniContext(jni_cache, std::move(model),
80                                             std::move(intent_generator),
81                                             std::move(template_handler));
82   }
83 
jni_cache() const84   std::shared_ptr<libtextclassifier3::JniCache> jni_cache() const {
85     return jni_cache_;
86   }
87 
model() const88   ActionsSuggestions* model() const { return model_.get(); }
89 
intent_generator() const90   IntentGenerator* intent_generator() const { return intent_generator_.get(); }
91 
template_handler() const92   RemoteActionTemplatesHandler* template_handler() const {
93     return template_handler_.get();
94   }
95 
96  private:
ActionsSuggestionsJniContext(const std::shared_ptr<libtextclassifier3::JniCache> & jni_cache,std::unique_ptr<ActionsSuggestions> model,std::unique_ptr<IntentGenerator> intent_generator,std::unique_ptr<RemoteActionTemplatesHandler> template_handler)97   ActionsSuggestionsJniContext(
98       const std::shared_ptr<libtextclassifier3::JniCache>& jni_cache,
99       std::unique_ptr<ActionsSuggestions> model,
100       std::unique_ptr<IntentGenerator> intent_generator,
101       std::unique_ptr<RemoteActionTemplatesHandler> template_handler)
102       : jni_cache_(jni_cache),
103         model_(std::move(model)),
104         intent_generator_(std::move(intent_generator)),
105         template_handler_(std::move(template_handler)) {}
106 
107   std::shared_ptr<libtextclassifier3::JniCache> jni_cache_;
108   std::unique_ptr<ActionsSuggestions> model_;
109   std::unique_ptr<IntentGenerator> intent_generator_;
110   std::unique_ptr<RemoteActionTemplatesHandler> template_handler_;
111 };
112 
FromJavaActionSuggestionOptions(JNIEnv * env,jobject joptions)113 ActionSuggestionOptions FromJavaActionSuggestionOptions(JNIEnv* env,
114                                                         jobject joptions) {
115   ActionSuggestionOptions options = ActionSuggestionOptions::Default();
116   return options;
117 }
118 
ActionSuggestionsToJObjectArray(JNIEnv * env,const ActionsSuggestionsJniContext * context,jobject app_context,const reflection::Schema * annotations_entity_data_schema,const std::vector<ActionSuggestion> & action_result,const Conversation & conversation,const jstring device_locales,const bool generate_intents)119 jobjectArray ActionSuggestionsToJObjectArray(
120     JNIEnv* env, const ActionsSuggestionsJniContext* context,
121     jobject app_context,
122     const reflection::Schema* annotations_entity_data_schema,
123     const std::vector<ActionSuggestion>& action_result,
124     const Conversation& conversation, const jstring device_locales,
125     const bool generate_intents) {
126   const ScopedLocalRef<jclass> result_class(
127       env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
128                      "$ActionSuggestion"),
129       env);
130   if (!result_class) {
131     TC3_LOG(ERROR) << "Couldn't find ActionSuggestion class.";
132     return nullptr;
133   }
134 
135   const jmethodID result_class_constructor = env->GetMethodID(
136       result_class.get(), "<init>",
137       "(Ljava/lang/String;Ljava/lang/String;F[L" TC3_PACKAGE_PATH
138           TC3_NAMED_VARIANT_CLASS_NAME_STR
139       ";[B[L" TC3_PACKAGE_PATH TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR ";)V");
140   const jobjectArray results =
141       env->NewObjectArray(action_result.size(), result_class.get(), nullptr);
142   for (int i = 0; i < action_result.size(); i++) {
143     jobject extras = nullptr;
144 
145     const reflection::Schema* actions_entity_data_schema =
146         context->model()->entity_data_schema();
147     if (actions_entity_data_schema != nullptr &&
148         !action_result[i].serialized_entity_data.empty()) {
149       extras = context->template_handler()->EntityDataAsNamedVariantArray(
150           actions_entity_data_schema, action_result[i].serialized_entity_data);
151     }
152 
153     jbyteArray serialized_entity_data = nullptr;
154     if (!action_result[i].serialized_entity_data.empty()) {
155       serialized_entity_data =
156           env->NewByteArray(action_result[i].serialized_entity_data.size());
157       env->SetByteArrayRegion(
158           serialized_entity_data, 0,
159           action_result[i].serialized_entity_data.size(),
160           reinterpret_cast<const jbyte*>(
161               action_result[i].serialized_entity_data.data()));
162     }
163 
164     jobject remote_action_templates_result = nullptr;
165     if (generate_intents) {
166       std::vector<RemoteActionTemplate> remote_action_templates;
167       if (context->intent_generator()->GenerateIntents(
168               device_locales, action_result[i], conversation, app_context,
169               actions_entity_data_schema, annotations_entity_data_schema,
170               &remote_action_templates)) {
171         remote_action_templates_result =
172             context->template_handler()->RemoteActionTemplatesToJObjectArray(
173                 remote_action_templates);
174       }
175     }
176 
177     ScopedLocalRef<jstring> reply = context->jni_cache()->ConvertToJavaString(
178         action_result[i].response_text);
179 
180     ScopedLocalRef<jobject> result(env->NewObject(
181         result_class.get(), result_class_constructor, reply.get(),
182         env->NewStringUTF(action_result[i].type.c_str()),
183         static_cast<jfloat>(action_result[i].score), extras,
184         serialized_entity_data, remote_action_templates_result));
185     env->SetObjectArrayElement(results, i, result.get());
186   }
187   return results;
188 }
189 
FromJavaConversationMessage(JNIEnv * env,jobject jmessage)190 ConversationMessage FromJavaConversationMessage(JNIEnv* env, jobject jmessage) {
191   if (!jmessage) {
192     return {};
193   }
194 
195   const ScopedLocalRef<jclass> message_class(
196       env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
197                      "$ConversationMessage"),
198       env);
199   const std::pair<bool, jobject> status_or_text = CallJniMethod0<jobject>(
200       env, jmessage, message_class.get(), &JNIEnv::CallObjectMethod, "getText",
201       "Ljava/lang/String;");
202   const std::pair<bool, int32> status_or_user_id =
203       CallJniMethod0<int32>(env, jmessage, message_class.get(),
204                             &JNIEnv::CallIntMethod, "getUserId", "I");
205   const std::pair<bool, int64> status_or_reference_time = CallJniMethod0<int64>(
206       env, jmessage, message_class.get(), &JNIEnv::CallLongMethod,
207       "getReferenceTimeMsUtc", "J");
208   const std::pair<bool, jobject> status_or_reference_timezone =
209       CallJniMethod0<jobject>(env, jmessage, message_class.get(),
210                               &JNIEnv::CallObjectMethod, "getReferenceTimezone",
211                               "Ljava/lang/String;");
212   const std::pair<bool, jobject> status_or_detected_text_language_tags =
213       CallJniMethod0<jobject>(
214           env, jmessage, message_class.get(), &JNIEnv::CallObjectMethod,
215           "getDetectedTextLanguageTags", "Ljava/lang/String;");
216   if (!status_or_text.first || !status_or_user_id.first ||
217       !status_or_detected_text_language_tags.first ||
218       !status_or_reference_time.first || !status_or_reference_timezone.first) {
219     return {};
220   }
221 
222   ConversationMessage message;
223   message.text = ToStlString(env, static_cast<jstring>(status_or_text.second));
224   message.user_id = status_or_user_id.second;
225   message.reference_time_ms_utc = status_or_reference_time.second;
226   message.reference_timezone = ToStlString(
227       env, static_cast<jstring>(status_or_reference_timezone.second));
228   message.detected_text_language_tags = ToStlString(
229       env, static_cast<jstring>(status_or_detected_text_language_tags.second));
230   return message;
231 }
232 
FromJavaConversation(JNIEnv * env,jobject jconversation)233 Conversation FromJavaConversation(JNIEnv* env, jobject jconversation) {
234   if (!jconversation) {
235     return {};
236   }
237 
238   const ScopedLocalRef<jclass> conversation_class(
239       env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
240                      "$Conversation"),
241       env);
242 
243   const std::pair<bool, jobject> status_or_messages = CallJniMethod0<jobject>(
244       env, jconversation, conversation_class.get(), &JNIEnv::CallObjectMethod,
245       "getConversationMessages",
246       "[L" TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR "$ConversationMessage;");
247 
248   if (!status_or_messages.first) {
249     return {};
250   }
251 
252   const jobjectArray jmessages =
253       reinterpret_cast<jobjectArray>(status_or_messages.second);
254 
255   const int size = env->GetArrayLength(jmessages);
256 
257   std::vector<ConversationMessage> messages;
258   for (int i = 0; i < size; i++) {
259     jobject jmessage = env->GetObjectArrayElement(jmessages, i);
260     ConversationMessage message = FromJavaConversationMessage(env, jmessage);
261     messages.push_back(message);
262   }
263   Conversation conversation;
264   conversation.messages = messages;
265   return conversation;
266 }
267 
GetLocalesFromMmap(JNIEnv * env,libtextclassifier3::ScopedMmap * mmap)268 jstring GetLocalesFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
269   if (!mmap->handle().ok()) {
270     return env->NewStringUTF("");
271   }
272   const ActionsModel* model = libtextclassifier3::ViewActionsModel(
273       mmap->handle().start(), mmap->handle().num_bytes());
274   if (!model || !model->locales()) {
275     return env->NewStringUTF("");
276   }
277   return env->NewStringUTF(model->locales()->c_str());
278 }
279 
GetVersionFromMmap(JNIEnv * env,libtextclassifier3::ScopedMmap * mmap)280 jint GetVersionFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
281   if (!mmap->handle().ok()) {
282     return 0;
283   }
284   const ActionsModel* model = libtextclassifier3::ViewActionsModel(
285       mmap->handle().start(), mmap->handle().num_bytes());
286   if (!model) {
287     return 0;
288   }
289   return model->version();
290 }
291 
GetNameFromMmap(JNIEnv * env,libtextclassifier3::ScopedMmap * mmap)292 jstring GetNameFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
293   if (!mmap->handle().ok()) {
294     return env->NewStringUTF("");
295   }
296   const ActionsModel* model = libtextclassifier3::ViewActionsModel(
297       mmap->handle().start(), mmap->handle().num_bytes());
298   if (!model || !model->name()) {
299     return env->NewStringUTF("");
300   }
301   return env->NewStringUTF(model->name()->c_str());
302 }
303 }  // namespace
304 }  // namespace libtextclassifier3
305 
306 using libtextclassifier3::ActionsSuggestionsJniContext;
307 using libtextclassifier3::ActionSuggestionsToJObjectArray;
308 using libtextclassifier3::FromJavaActionSuggestionOptions;
309 using libtextclassifier3::FromJavaConversation;
310 
// Loads an actions model from the file descriptor `fd` and wraps it in an
// ActionsSuggestionsJniContext.
//
// `serialized_preconditions`, when non-null, is converted to a std::string and
// forwarded to ActionsSuggestions::FromFileDescriptor. Returns the context as
// an opaque jlong handle (released by nativeCloseActionsModel), or 0 on
// failure.
TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModel)
(JNIEnv* env, jobject thiz, jint fd, jbyteArray serialized_preconditions) {
  std::shared_ptr<libtextclassifier3::JniCache> jni_cache =
      libtextclassifier3::JniCache::Create(env);
  if (jni_cache == nullptr) {
    // ActionsSuggestionsJniContext::Create() rejects a null cache anyway. Bail
    // out before loading the model, so that (with Java ICU) UniLib is never
    // constructed around a null cache.
    TC3_LOG(ERROR) << "Could not create JNI cache.";
    return 0;
  }
  std::string preconditions;
  if (serialized_preconditions != nullptr &&
      !libtextclassifier3::JByteArrayToString(env, serialized_preconditions,
                                              &preconditions)) {
    TC3_LOG(ERROR) << "Could not convert serialized preconditions.";
    return 0;
  }
#ifdef TC3_UNILIB_JAVAICU
  return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
      jni_cache,
      ActionsSuggestions::FromFileDescriptor(
          fd, std::unique_ptr<UniLib>(new UniLib(jni_cache)), preconditions)));
#else
  return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
      jni_cache, ActionsSuggestions::FromFileDescriptor(fd, /*unilib=*/nullptr,
                                                        preconditions)));
#endif  // TC3_UNILIB_JAVAICU
}
333 
// Loads an actions model from the file at `path` and wraps it in an
// ActionsSuggestionsJniContext.
//
// `serialized_preconditions`, when non-null, is converted to a std::string and
// forwarded to ActionsSuggestions::FromPath. Returns the context as an opaque
// jlong handle (released by nativeCloseActionsModel), or 0 on failure.
TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModelFromPath)
(JNIEnv* env, jobject thiz, jstring path, jbyteArray serialized_preconditions) {
  std::shared_ptr<libtextclassifier3::JniCache> jni_cache =
      libtextclassifier3::JniCache::Create(env);
  if (jni_cache == nullptr) {
    // ActionsSuggestionsJniContext::Create() rejects a null cache anyway. Bail
    // out before loading the model, so that (with Java ICU) UniLib is never
    // constructed around a null cache.
    TC3_LOG(ERROR) << "Could not create JNI cache.";
    return 0;
  }
  const std::string path_str = ToStlString(env, path);
  std::string preconditions;
  if (serialized_preconditions != nullptr &&
      !libtextclassifier3::JByteArrayToString(env, serialized_preconditions,
                                              &preconditions)) {
    TC3_LOG(ERROR) << "Could not convert serialized preconditions.";
    return 0;
  }
#ifdef TC3_UNILIB_JAVAICU
  return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
      jni_cache, ActionsSuggestions::FromPath(
                     path_str, std::unique_ptr<UniLib>(new UniLib(jni_cache)),
                     preconditions)));
#else
  return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
      jni_cache, ActionsSuggestions::FromPath(path_str, /*unilib=*/nullptr,
                                              preconditions)));
#endif  // TC3_UNILIB_JAVAICU
}
357 
// Suggests actions for `jconversation` using the context behind `ptr`.
//
// `annotatorPtr` may be 0; when set, the annotator is passed to the model and
// its entity data schema is used for intent generation. Returns a Java
// ActionSuggestion[], or nullptr when `ptr` is 0 or the array cannot be built.
TC3_JNI_METHOD(jobjectArray, TC3_ACTIONS_CLASS_NAME, nativeSuggestActions)
(JNIEnv* env, jobject clazz, jlong ptr, jobject jconversation, jobject joptions,
 jlong annotatorPtr, jobject app_context, jstring device_locales,
 jboolean generate_intents) {
  if (ptr == 0) {
    return nullptr;
  }

  const Conversation conversation = FromJavaConversation(env, jconversation);
  const ActionSuggestionOptions options =
      FromJavaActionSuggestionOptions(env, joptions);

  const auto* context =
      reinterpret_cast<const ActionsSuggestionsJniContext*>(ptr);
  const auto* annotator = reinterpret_cast<const Annotator*>(annotatorPtr);

  const ActionsSuggestionsResponse response =
      context->model()->SuggestActions(conversation, annotator, options);

  const reflection::Schema* annotations_entity_data_schema = nullptr;
  if (annotator != nullptr) {
    annotations_entity_data_schema = annotator->entity_data_schema();
  }

  return ActionSuggestionsToJObjectArray(
      env, context, app_context, annotations_entity_data_schema,
      response.actions, conversation, device_locales, generate_intents);
}
381 
TC3_JNI_METHOD(void,TC3_ACTIONS_CLASS_NAME,nativeCloseActionsModel)382 TC3_JNI_METHOD(void, TC3_ACTIONS_CLASS_NAME, nativeCloseActionsModel)
383 (JNIEnv* env, jobject clazz, jlong model_ptr) {
384   const ActionsSuggestionsJniContext* context =
385       reinterpret_cast<ActionsSuggestionsJniContext*>(model_ptr);
386   delete context;
387 }
388 
// Returns the locales declared by the actions model file open at `fd`.
// The mapping is released when this function returns.
TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetLocales)
(JNIEnv* env, jobject clazz, jint fd) {
  libtextclassifier3::ScopedMmap mmap(fd);
  return libtextclassifier3::GetLocalesFromMmap(env, &mmap);
}
395 
// Returns the name declared by the actions model file open at `fd`.
// The mapping is released when this function returns.
TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetName)
(JNIEnv* env, jobject clazz, jint fd) {
  libtextclassifier3::ScopedMmap mmap(fd);
  return libtextclassifier3::GetNameFromMmap(env, &mmap);
}
402 
// Returns the version declared by the actions model file open at `fd`.
// The mapping is released when this function returns.
TC3_JNI_METHOD(jint, TC3_ACTIONS_CLASS_NAME, nativeGetVersion)
(JNIEnv* env, jobject clazz, jint fd) {
  libtextclassifier3::ScopedMmap mmap(fd);
  return libtextclassifier3::GetVersionFromMmap(env, &mmap);
}
409