1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "lang_id/lang-id_jni.h"
18 
19 #include <jni.h>
20 #include <type_traits>
21 #include <vector>
22 
23 #include "utils/base/logging.h"
24 #include "utils/java/scoped_local_ref.h"
25 #include "lang_id/fb_model/lang-id-from-fb.h"
26 #include "lang_id/lang-id.h"
27 
28 using libtextclassifier3::ScopedLocalRef;
29 using libtextclassifier3::ToStlString;
30 using libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFile;
31 using libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFileDescriptor;
32 using libtextclassifier3::mobile::lang_id::LangId;
33 using libtextclassifier3::mobile::lang_id::LangIdResult;
34 
35 namespace {
LangIdResultToJObjectArray(JNIEnv * env,const LangIdResult & lang_id_result)36 jobjectArray LangIdResultToJObjectArray(JNIEnv* env,
37                                         const LangIdResult& lang_id_result) {
38   const ScopedLocalRef<jclass> result_class(
39       env->FindClass(TC3_PACKAGE_PATH TC3_LANG_ID_CLASS_NAME_STR
40                      "$LanguageResult"),
41       env);
42   if (!result_class) {
43     TC3_LOG(ERROR) << "Couldn't find LanguageResult class.";
44     return nullptr;
45   }
46 
47   // clang-format off
48   const std::vector<std::pair<std::string, float>>& predictions =
49       lang_id_result.predictions;
50   // clang-format on
51   const jmethodID result_class_constructor =
52       env->GetMethodID(result_class.get(), "<init>", "(Ljava/lang/String;F)V");
53   const jobjectArray results =
54       env->NewObjectArray(predictions.size(), result_class.get(), nullptr);
55   for (int i = 0; i < predictions.size(); i++) {
56     ScopedLocalRef<jobject> result(
57         env->NewObject(result_class.get(), result_class_constructor,
58                        env->NewStringUTF(predictions[i].first.c_str()),
59                        static_cast<jfloat>(predictions[i].second)));
60     env->SetObjectArrayElement(results, i, result.get());
61   }
62   return results;
63 }
64 }  // namespace
65 
// Creates a LangId model from the flatbuffer file referred to by the file
// descriptor `fd`.
//
// Returns an opaque handle: the heap-allocated LangId* reinterpreted as jlong.
// The handle is later freed by nativeClose. Returns 0 if no model was created
// or the model reports itself as not valid; in that case the rejected instance
// is destroyed by the unique_ptr.
TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNew)
(JNIEnv* env, jobject thiz, jint fd) {
  std::unique_ptr<LangId> lang_id = GetLangIdFromFlatbufferFileDescriptor(fd);
  // Guard against a null factory result before calling is_valid() on it.
  if (lang_id == nullptr || !lang_id->is_valid()) {
    return reinterpret_cast<jlong>(nullptr);
  }
  return reinterpret_cast<jlong>(lang_id.release());
}
74 
// Creates a LangId model from the flatbuffer file at the Java string `path`.
//
// Returns an opaque handle: the heap-allocated LangId* reinterpreted as jlong.
// The handle is later freed by nativeClose. Returns 0 if no model was created
// or the model reports itself as not valid; in that case the rejected instance
// is destroyed by the unique_ptr.
TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNewFromPath)
(JNIEnv* env, jobject thiz, jstring path) {
  const std::string path_str = ToStlString(env, path);
  std::unique_ptr<LangId> lang_id = GetLangIdFromFlatbufferFile(path_str);
  // Guard against a null factory result before calling is_valid() on it.
  if (lang_id == nullptr || !lang_id->is_valid()) {
    return reinterpret_cast<jlong>(nullptr);
  }
  return reinterpret_cast<jlong>(lang_id.release());
}
84 
// Identifies the language(s) of the Java string `text` using the LangId model
// behind the handle `ptr`.
//
// Returns nullptr when `ptr` is 0. Otherwise it returns the array that
// LangIdResultToJObjectArray builds from the model's predictions.
TC3_JNI_METHOD(jobjectArray, TC3_LANG_ID_CLASS_NAME, nativeDetectLanguages)
(JNIEnv* env, jobject clazz, jlong ptr, jstring text) {
  LangId* const lang_id = reinterpret_cast<LangId*>(ptr);
  if (lang_id == nullptr) {
    return nullptr;
  }

  const std::string input = ToStlString(env, text);
  LangIdResult lang_id_result;
  lang_id->FindLanguages(input, &lang_id_result);
  return LangIdResultToJObjectArray(env, lang_id_result);
}
98 
TC3_JNI_METHOD(void,TC3_LANG_ID_CLASS_NAME,nativeClose)99 TC3_JNI_METHOD(void, TC3_LANG_ID_CLASS_NAME, nativeClose)
100 (JNIEnv* env, jobject clazz, jlong ptr) {
101   if (!ptr) {
102     TC3_LOG(ERROR) << "Trying to close null LangId.";
103     return;
104   }
105   LangId* model = reinterpret_cast<LangId*>(ptr);
106   delete model;
107 }
108 
// Returns the model version reported by the LangId instance behind `ptr`, or
// -1 when `ptr` is 0.
TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersion)
(JNIEnv* env, jobject clazz, jlong ptr) {
  if (ptr == 0) return -1;
  return reinterpret_cast<LangId*>(ptr)->GetModelVersion();
}
117 
// Loads a temporary LangId model from the flatbuffer file descriptor `fd` and
// returns its model version.
//
// Returns -1 if no model was created or the model is not valid. The temporary
// model is destroyed when this function returns, in every case.
TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersionFromFd)
(JNIEnv* env, jobject clazz, jint fd) {
  std::unique_ptr<LangId> lang_id = GetLangIdFromFlatbufferFileDescriptor(fd);
  // Guard against a null factory result before calling is_valid() on it.
  if (lang_id == nullptr || !lang_id->is_valid()) {
    return -1;
  }
  return lang_id->GetModelVersion();
}
126 
// Returns the model's "text_classifier_langid_threshold" float property for
// the LangId instance behind `ptr`. It returns -1 when `ptr` is 0.
TC3_JNI_METHOD(jfloat, TC3_LANG_ID_CLASS_NAME, nativeGetLangIdThreshold)
(JNIEnv* env, jobject thizz, jlong ptr) {
  // Sentinel returned for a 0 handle. It is also passed as the fallback value
  // to GetFloatProperty (presumably returned when the property is missing).
  constexpr float kNoThreshold = -1.0f;
  if (ptr == 0) {
    return kNoThreshold;
  }
  LangId* const lang_id = reinterpret_cast<LangId*>(ptr);
  return lang_id->GetFloatProperty("text_classifier_langid_threshold",
                                   kNoThreshold);
}
135