1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 // JNI wrapper for the TextClassifier.
18 
19 #include "textclassifier_jni.h"
20 
21 #include <jni.h>
22 #include <type_traits>
23 #include <vector>
24 
25 #include "text-classifier.h"
26 #include "util/base/integral_types.h"
27 #include "util/java/scoped_local_ref.h"
28 #include "util/java/string_utils.h"
29 #include "util/memory/mmap.h"
30 #include "util/utf8/unilib.h"
31 
32 using libtextclassifier2::AnnotatedSpan;
33 using libtextclassifier2::AnnotationOptions;
34 using libtextclassifier2::ClassificationOptions;
35 using libtextclassifier2::ClassificationResult;
36 using libtextclassifier2::CodepointSpan;
37 using libtextclassifier2::JStringToUtf8String;
38 using libtextclassifier2::Model;
39 using libtextclassifier2::ScopedLocalRef;
40 using libtextclassifier2::SelectionOptions;
41 using libtextclassifier2::TextClassifier;
42 #ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU
43 using libtextclassifier2::UniLib;
44 #endif
45 
46 namespace libtextclassifier2 {
47 
48 using libtextclassifier2::CodepointSpan;
49 
50 namespace {
51 
ToStlString(JNIEnv * env,const jstring & str)52 std::string ToStlString(JNIEnv* env, const jstring& str) {
53   std::string result;
54   JStringToUtf8String(env, str, &result);
55   return result;
56 }
57 
ClassificationResultsToJObjectArray(JNIEnv * env,const std::vector<ClassificationResult> & classification_result)58 jobjectArray ClassificationResultsToJObjectArray(
59     JNIEnv* env,
60     const std::vector<ClassificationResult>& classification_result) {
61   const ScopedLocalRef<jclass> result_class(
62       env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$ClassificationResult"),
63       env);
64   if (!result_class) {
65     TC_LOG(ERROR) << "Couldn't find ClassificationResult class.";
66     return nullptr;
67   }
68   const ScopedLocalRef<jclass> datetime_parse_class(
69       env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$DatetimeResult"), env);
70   if (!datetime_parse_class) {
71     TC_LOG(ERROR) << "Couldn't find DatetimeResult class.";
72     return nullptr;
73   }
74 
75   const jmethodID result_class_constructor =
76       env->GetMethodID(result_class.get(), "<init>",
77                        "(Ljava/lang/String;FL" TC_PACKAGE_PATH TC_CLASS_NAME_STR
78                        "$DatetimeResult;)V");
79   const jmethodID datetime_parse_class_constructor =
80       env->GetMethodID(datetime_parse_class.get(), "<init>", "(JI)V");
81 
82   const jobjectArray results = env->NewObjectArray(classification_result.size(),
83                                                    result_class.get(), nullptr);
84   for (int i = 0; i < classification_result.size(); i++) {
85     jstring row_string =
86         env->NewStringUTF(classification_result[i].collection.c_str());
87     jobject row_datetime_parse = nullptr;
88     if (classification_result[i].datetime_parse_result.IsSet()) {
89       row_datetime_parse = env->NewObject(
90           datetime_parse_class.get(), datetime_parse_class_constructor,
91           classification_result[i].datetime_parse_result.time_ms_utc,
92           classification_result[i].datetime_parse_result.granularity);
93     }
94     jobject result =
95         env->NewObject(result_class.get(), result_class_constructor, row_string,
96                        static_cast<jfloat>(classification_result[i].score),
97                        row_datetime_parse);
98     env->SetObjectArrayElement(results, i, result);
99     env->DeleteLocalRef(result);
100   }
101   return results;
102 }
103 
104 template <typename T, typename F>
CallJniMethod0(JNIEnv * env,jobject object,jclass class_object,F function,const std::string & method_name,const std::string & return_java_type)105 std::pair<bool, T> CallJniMethod0(JNIEnv* env, jobject object,
106                                   jclass class_object, F function,
107                                   const std::string& method_name,
108                                   const std::string& return_java_type) {
109   const jmethodID method = env->GetMethodID(class_object, method_name.c_str(),
110                                             ("()" + return_java_type).c_str());
111   if (!method) {
112     return std::make_pair(false, T());
113   }
114   return std::make_pair(true, (env->*function)(object, method));
115 }
116 
FromJavaSelectionOptions(JNIEnv * env,jobject joptions)117 SelectionOptions FromJavaSelectionOptions(JNIEnv* env, jobject joptions) {
118   if (!joptions) {
119     return {};
120   }
121 
122   const ScopedLocalRef<jclass> options_class(
123       env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$SelectionOptions"),
124       env);
125   const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>(
126       env, joptions, options_class.get(), &JNIEnv::CallObjectMethod,
127       "getLocales", "Ljava/lang/String;");
128   if (!status_or_locales.first) {
129     return {};
130   }
131 
132   SelectionOptions options;
133   options.locales =
134       ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second));
135 
136   return options;
137 }
138 
139 template <typename T>
FromJavaOptionsInternal(JNIEnv * env,jobject joptions,const std::string & class_name)140 T FromJavaOptionsInternal(JNIEnv* env, jobject joptions,
141                           const std::string& class_name) {
142   if (!joptions) {
143     return {};
144   }
145 
146   const ScopedLocalRef<jclass> options_class(env->FindClass(class_name.c_str()),
147                                              env);
148   if (!options_class) {
149     return {};
150   }
151 
152   const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>(
153       env, joptions, options_class.get(), &JNIEnv::CallObjectMethod,
154       "getLocale", "Ljava/lang/String;");
155   const std::pair<bool, jobject> status_or_reference_timezone =
156       CallJniMethod0<jobject>(env, joptions, options_class.get(),
157                               &JNIEnv::CallObjectMethod, "getReferenceTimezone",
158                               "Ljava/lang/String;");
159   const std::pair<bool, int64> status_or_reference_time_ms_utc =
160       CallJniMethod0<int64>(env, joptions, options_class.get(),
161                             &JNIEnv::CallLongMethod, "getReferenceTimeMsUtc",
162                             "J");
163 
164   if (!status_or_locales.first || !status_or_reference_timezone.first ||
165       !status_or_reference_time_ms_utc.first) {
166     return {};
167   }
168 
169   T options;
170   options.locales =
171       ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second));
172   options.reference_timezone = ToStlString(
173       env, reinterpret_cast<jstring>(status_or_reference_timezone.second));
174   options.reference_time_ms_utc = status_or_reference_time_ms_utc.second;
175   return options;
176 }
177 
FromJavaClassificationOptions(JNIEnv * env,jobject joptions)178 ClassificationOptions FromJavaClassificationOptions(JNIEnv* env,
179                                                     jobject joptions) {
180   return FromJavaOptionsInternal<ClassificationOptions>(
181       env, joptions,
182       TC_PACKAGE_PATH TC_CLASS_NAME_STR "$ClassificationOptions");
183 }
184 
FromJavaAnnotationOptions(JNIEnv * env,jobject joptions)185 AnnotationOptions FromJavaAnnotationOptions(JNIEnv* env, jobject joptions) {
186   return FromJavaOptionsInternal<AnnotationOptions>(
187       env, joptions, TC_PACKAGE_PATH TC_CLASS_NAME_STR "$AnnotationOptions");
188 }
189 
ConvertIndicesBMPUTF8(const std::string & utf8_str,CodepointSpan orig_indices,bool from_utf8)190 CodepointSpan ConvertIndicesBMPUTF8(const std::string& utf8_str,
191                                     CodepointSpan orig_indices,
192                                     bool from_utf8) {
193   const libtextclassifier2::UnicodeText unicode_str =
194       libtextclassifier2::UTF8ToUnicodeText(utf8_str, /*do_copy=*/false);
195 
196   int unicode_index = 0;
197   int bmp_index = 0;
198 
199   const int* source_index;
200   const int* target_index;
201   if (from_utf8) {
202     source_index = &unicode_index;
203     target_index = &bmp_index;
204   } else {
205     source_index = &bmp_index;
206     target_index = &unicode_index;
207   }
208 
209   CodepointSpan result{-1, -1};
210   std::function<void()> assign_indices_fn = [&result, &orig_indices,
211                                              &source_index, &target_index]() {
212     if (orig_indices.first == *source_index) {
213       result.first = *target_index;
214     }
215 
216     if (orig_indices.second == *source_index) {
217       result.second = *target_index;
218     }
219   };
220 
221   for (auto it = unicode_str.begin(); it != unicode_str.end();
222        ++it, ++unicode_index, ++bmp_index) {
223     assign_indices_fn();
224 
225     // There is 1 extra character in the input for each UTF8 character > 0xFFFF.
226     if (*it > 0xFFFF) {
227       ++bmp_index;
228     }
229   }
230   assign_indices_fn();
231 
232   return result;
233 }
234 
235 }  // namespace
236 
ConvertIndicesBMPToUTF8(const std::string & utf8_str,CodepointSpan bmp_indices)237 CodepointSpan ConvertIndicesBMPToUTF8(const std::string& utf8_str,
238                                       CodepointSpan bmp_indices) {
239   return ConvertIndicesBMPUTF8(utf8_str, bmp_indices, /*from_utf8=*/false);
240 }
241 
ConvertIndicesUTF8ToBMP(const std::string & utf8_str,CodepointSpan utf8_indices)242 CodepointSpan ConvertIndicesUTF8ToBMP(const std::string& utf8_str,
243                                       CodepointSpan utf8_indices) {
244   return ConvertIndicesBMPUTF8(utf8_str, utf8_indices, /*from_utf8=*/true);
245 }
246 
GetFdFromAssetFileDescriptor(JNIEnv * env,jobject afd)247 jint GetFdFromAssetFileDescriptor(JNIEnv* env, jobject afd) {
248   // Get system-level file descriptor from AssetFileDescriptor.
249   ScopedLocalRef<jclass> afd_class(
250       env->FindClass("android/content/res/AssetFileDescriptor"), env);
251   if (afd_class == nullptr) {
252     TC_LOG(ERROR) << "Couldn't find AssetFileDescriptor.";
253     return reinterpret_cast<jlong>(nullptr);
254   }
255   jmethodID afd_class_getFileDescriptor = env->GetMethodID(
256       afd_class.get(), "getFileDescriptor", "()Ljava/io/FileDescriptor;");
257   if (afd_class_getFileDescriptor == nullptr) {
258     TC_LOG(ERROR) << "Couldn't find getFileDescriptor.";
259     return reinterpret_cast<jlong>(nullptr);
260   }
261 
262   ScopedLocalRef<jclass> fd_class(env->FindClass("java/io/FileDescriptor"),
263                                   env);
264   if (fd_class == nullptr) {
265     TC_LOG(ERROR) << "Couldn't find FileDescriptor.";
266     return reinterpret_cast<jlong>(nullptr);
267   }
268   jfieldID fd_class_descriptor =
269       env->GetFieldID(fd_class.get(), "descriptor", "I");
270   if (fd_class_descriptor == nullptr) {
271     TC_LOG(ERROR) << "Couldn't find descriptor.";
272     return reinterpret_cast<jlong>(nullptr);
273   }
274 
275   jobject bundle_jfd = env->CallObjectMethod(afd, afd_class_getFileDescriptor);
276   return env->GetIntField(bundle_jfd, fd_class_descriptor);
277 }
278 
GetLocalesFromMmap(JNIEnv * env,libtextclassifier2::ScopedMmap * mmap)279 jstring GetLocalesFromMmap(JNIEnv* env, libtextclassifier2::ScopedMmap* mmap) {
280   if (!mmap->handle().ok()) {
281     return env->NewStringUTF("");
282   }
283   const Model* model = libtextclassifier2::ViewModel(
284       mmap->handle().start(), mmap->handle().num_bytes());
285   if (!model || !model->locales()) {
286     return env->NewStringUTF("");
287   }
288   return env->NewStringUTF(model->locales()->c_str());
289 }
290 
GetVersionFromMmap(JNIEnv * env,libtextclassifier2::ScopedMmap * mmap)291 jint GetVersionFromMmap(JNIEnv* env, libtextclassifier2::ScopedMmap* mmap) {
292   if (!mmap->handle().ok()) {
293     return 0;
294   }
295   const Model* model = libtextclassifier2::ViewModel(
296       mmap->handle().start(), mmap->handle().num_bytes());
297   if (!model) {
298     return 0;
299   }
300   return model->version();
301 }
302 
GetNameFromMmap(JNIEnv * env,libtextclassifier2::ScopedMmap * mmap)303 jstring GetNameFromMmap(JNIEnv* env, libtextclassifier2::ScopedMmap* mmap) {
304   if (!mmap->handle().ok()) {
305     return env->NewStringUTF("");
306   }
307   const Model* model = libtextclassifier2::ViewModel(
308       mmap->handle().start(), mmap->handle().num_bytes());
309   if (!model || !model->name()) {
310     return env->NewStringUTF("");
311   }
312   return env->NewStringUTF(model->name()->c_str());
313 }
314 
315 }  // namespace libtextclassifier2
316 
317 using libtextclassifier2::ClassificationResultsToJObjectArray;
318 using libtextclassifier2::ConvertIndicesBMPToUTF8;
319 using libtextclassifier2::ConvertIndicesUTF8ToBMP;
320 using libtextclassifier2::FromJavaAnnotationOptions;
321 using libtextclassifier2::FromJavaClassificationOptions;
322 using libtextclassifier2::FromJavaSelectionOptions;
323 using libtextclassifier2::ToStlString;
324 
// Creates a native TextClassifier from the file descriptor `fd` and returns
// the released pointer as a jlong handle for the Java side.
//
// When LIBTEXTCLASSIFIER_UNILIB_JAVAICU is defined, a UniLib built from `env`
// is handed to the factory as its UniLib argument, matching the other
// nativeNew* entry points. It must be a factory argument, not a second operand
// of the cast expression: a comma expression there would discard the
// classifier pointer and return the UniLib pointer as the handle.
JNI_METHOD(jlong, TC_CLASS_NAME, nativeNew)
(JNIEnv* env, jobject thiz, jint fd) {
#ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU
  return reinterpret_cast<jlong>(
      TextClassifier::FromFileDescriptor(fd, new UniLib(env)).release());
#else
  return reinterpret_cast<jlong>(
      TextClassifier::FromFileDescriptor(fd).release());
#endif
}
335 
// Creates a native TextClassifier from the model file at the Java string
// `path` and returns the released pointer as a jlong handle. With
// LIBTEXTCLASSIFIER_UNILIB_JAVAICU defined, a UniLib built from `env` is
// passed to the factory.
JNI_METHOD(jlong, TC_CLASS_NAME, nativeNewFromPath)
(JNIEnv* env, jobject thiz, jstring path) {
  const std::string model_path = ToStlString(env, path);
#ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU
  auto* classifier =
      TextClassifier::FromPath(model_path, new UniLib(env)).release();
#else
  auto* classifier = TextClassifier::FromPath(model_path).release();
#endif
  return reinterpret_cast<jlong>(classifier);
}
346 
// Creates a native TextClassifier from the region [`offset`, `offset` + `size`)
// of the file behind the Java AssetFileDescriptor `afd` and returns the
// released pointer as a jlong handle. With LIBTEXTCLASSIFIER_UNILIB_JAVAICU
// defined, a UniLib built from `env` is passed to the factory.
JNI_METHOD(jlong, TC_CLASS_NAME, nativeNewFromAssetFileDescriptor)
(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
  const jint asset_fd =
      libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd);
#ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU
  auto* classifier =
      TextClassifier::FromFileDescriptor(asset_fd, offset, size,
                                         new UniLib(env))
          .release();
#else
  auto* classifier =
      TextClassifier::FromFileDescriptor(asset_fd, offset, size).release();
#endif
  return reinterpret_cast<jlong>(classifier);
}
359 
// Runs SuggestSelection on the classifier behind the handle `ptr`.
//
// `selection_begin`/`selection_end` are indices into the Java string
// `context`; they are converted with ConvertIndicesBMPToUTF8 before the call
// and the suggested span is converted back with ConvertIndicesUTF8ToBMP.
//
// Returns a two-element int array {begin, end}, or nullptr when `ptr` is 0 or
// the result array cannot be allocated.
JNI_METHOD(jintArray, TC_CLASS_NAME, nativeSuggestSelection)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
 jint selection_end, jobject options) {
  if (!ptr) {
    return nullptr;
  }

  TextClassifier* model = reinterpret_cast<TextClassifier*>(ptr);

  const std::string context_utf8 = ToStlString(env, context);
  const CodepointSpan input_indices =
      ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end});
  CodepointSpan selection = model->SuggestSelection(
      context_utf8, input_indices, FromJavaSelectionOptions(env, options));
  selection = ConvertIndicesUTF8ToBMP(context_utf8, selection);

  jintArray result = env->NewIntArray(2);
  if (result == nullptr) {
    // Allocation failed; writing into a null array is not allowed.
    return nullptr;
  }
  // Copy through a jint buffer so the region write does not depend on the
  // native index type being layout-identical to jint.
  const jint span[2] = {static_cast<jint>(selection.first),
                        static_cast<jint>(selection.second)};
  env->SetIntArrayRegion(result, 0, 2, span);
  return result;
}
381 
// Runs ClassifyText on the classifier behind the handle `ptr` for the given
// selection of the Java string `context` and returns the results as a Java
// array built by ClassificationResultsToJObjectArray. Returns nullptr when
// `ptr` is 0.
JNI_METHOD(jobjectArray, TC_CLASS_NAME, nativeClassifyText)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
 jint selection_end, jobject options) {
  if (!ptr) {
    return nullptr;
  }
  TextClassifier* const classifier = reinterpret_cast<TextClassifier*>(ptr);

  const std::string text_utf8 = ToStlString(env, context);
  // The incoming selection is converted before it is handed to the model.
  const CodepointSpan selection_utf8 =
      ConvertIndicesBMPToUTF8(text_utf8, {selection_begin, selection_end});
  const std::vector<ClassificationResult> classifications =
      classifier->ClassifyText(text_utf8, selection_utf8,
                               FromJavaClassificationOptions(env, options));

  return ClassificationResultsToJObjectArray(env, classifications);
}
399 
// Runs Annotate on the classifier behind the handle `ptr` for the Java string
// `context` and returns a Java array of `$AnnotatedSpan` objects.
//
// Each AnnotatedSpan is built from the span converted with
// ConvertIndicesUTF8ToBMP and the span's classification results.
//
// Returns nullptr when `ptr` is 0, when the AnnotatedSpan class or its
// constructor cannot be resolved, or when the result array cannot be
// allocated. Per-annotation local references are released inside the loop so
// that texts with many annotations cannot exhaust the local reference table.
JNI_METHOD(jobjectArray, TC_CLASS_NAME, nativeAnnotate)
(JNIEnv* env, jobject thiz, jlong ptr, jstring context, jobject options) {
  if (!ptr) {
    return nullptr;
  }
  TextClassifier* model = reinterpret_cast<TextClassifier*>(ptr);
  const std::string context_utf8 = ToStlString(env, context);
  const std::vector<AnnotatedSpan> annotations =
      model->Annotate(context_utf8, FromJavaAnnotationOptions(env, options));

  const ScopedLocalRef<jclass> result_class(
      env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$AnnotatedSpan"), env);
  if (!result_class) {
    TC_LOG(ERROR) << "Couldn't find result class: "
                  << TC_PACKAGE_PATH TC_CLASS_NAME_STR "$AnnotatedSpan";
    return nullptr;
  }

  const jmethodID result_class_constructor = env->GetMethodID(
      result_class.get(), "<init>",
      "(II[L" TC_PACKAGE_PATH TC_CLASS_NAME_STR "$ClassificationResult;)V");
  if (!result_class_constructor) {
    // Invoking NewObject with a null method id is not allowed.
    TC_LOG(ERROR) << "Couldn't find AnnotatedSpan constructor.";
    return nullptr;
  }

  const jsize annotation_count = static_cast<jsize>(annotations.size());
  jobjectArray results =
      env->NewObjectArray(annotation_count, result_class.get(), nullptr);
  if (!results) {
    TC_LOG(ERROR) << "Couldn't allocate AnnotatedSpan array.";
    return nullptr;
  }

  for (jsize i = 0; i < annotation_count; ++i) {
    const CodepointSpan span_bmp =
        ConvertIndicesUTF8ToBMP(context_utf8, annotations[i].span);
    const ScopedLocalRef<jobjectArray> classifications(
        ClassificationResultsToJObjectArray(env,
                                            annotations[i].classification),
        env);
    const ScopedLocalRef<jobject> result(
        env->NewObject(result_class.get(), result_class_constructor,
                       static_cast<jint>(span_bmp.first),
                       static_cast<jint>(span_bmp.second),
                       classifications.get()),
        env);
    env->SetObjectArrayElement(results, i, result.get());
  }
  return results;
}
440 
JNI_METHOD(void,TC_CLASS_NAME,nativeClose)441 JNI_METHOD(void, TC_CLASS_NAME, nativeClose)
442 (JNIEnv* env, jobject thiz, jlong ptr) {
443   TextClassifier* model = reinterpret_cast<TextClassifier*>(ptr);
444   delete model;
445 }
446 
// Deprecated alias of nativeGetLocales: logs a warning and forwards all
// arguments to it.
JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLanguage)
(JNIEnv* env, jobject clazz, jint fd) {
  TC_LOG(WARNING) << "Using deprecated getLanguage().";
  const jstring locales =
      JNI_METHOD_NAME(TC_CLASS_NAME, nativeGetLocales)(env, clazz, fd);
  return locales;
}
452 
// Maps the file behind `fd` for the duration of the call and returns the
// result of GetLocalesFromMmap for that mapping.
JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLocales)
(JNIEnv* env, jobject clazz, jint fd) {
  libtextclassifier2::ScopedMmap model_mmap(fd);
  return GetLocalesFromMmap(env, &model_mmap);
}
459 
// Maps the (`offset`, `size`) region of the file behind the Java
// AssetFileDescriptor `afd` for the duration of the call and returns the
// result of GetLocalesFromMmap for that mapping.
JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLocalesFromAssetFileDescriptor)
(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
  const jint asset_fd =
      libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd);
  libtextclassifier2::ScopedMmap model_mmap(asset_fd, offset, size);
  return GetLocalesFromMmap(env, &model_mmap);
}
467 
// Maps the file behind `fd` for the duration of the call and returns the
// result of GetVersionFromMmap for that mapping.
JNI_METHOD(jint, TC_CLASS_NAME, nativeGetVersion)
(JNIEnv* env, jobject clazz, jint fd) {
  libtextclassifier2::ScopedMmap model_mmap(fd);
  return GetVersionFromMmap(env, &model_mmap);
}
474 
// Maps the (`offset`, `size`) region of the file behind the Java
// AssetFileDescriptor `afd` for the duration of the call and returns the
// result of GetVersionFromMmap for that mapping.
JNI_METHOD(jint, TC_CLASS_NAME, nativeGetVersionFromAssetFileDescriptor)
(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
  const jint asset_fd =
      libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd);
  libtextclassifier2::ScopedMmap model_mmap(asset_fd, offset, size);
  return GetVersionFromMmap(env, &model_mmap);
}
482 
// Maps the file behind `fd` for the duration of the call and returns the
// result of GetNameFromMmap for that mapping.
JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetName)
(JNIEnv* env, jobject clazz, jint fd) {
  libtextclassifier2::ScopedMmap model_mmap(fd);
  return GetNameFromMmap(env, &model_mmap);
}
489 
// Maps the (`offset`, `size`) region of the file behind the Java
// AssetFileDescriptor `afd` for the duration of the call and returns the
// result of GetNameFromMmap for that mapping.
JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetNameFromAssetFileDescriptor)
(JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
  const jint asset_fd =
      libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd);
  libtextclassifier2::ScopedMmap model_mmap(asset_fd, offset, size);
  return GetNameFromMmap(env, &model_mmap);
}
497