1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "lang_id/lang-id_jni.h"
18
19 #include <jni.h>
20 #include <type_traits>
21 #include <vector>
22
23 #include "utils/base/logging.h"
24 #include "utils/java/scoped_local_ref.h"
25 #include "lang_id/fb_model/lang-id-from-fb.h"
26 #include "lang_id/lang-id.h"
27
28 using libtextclassifier3::ScopedLocalRef;
29 using libtextclassifier3::ToStlString;
30 using libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFile;
31 using libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFileDescriptor;
32 using libtextclassifier3::mobile::lang_id::LangId;
33 using libtextclassifier3::mobile::lang_id::LangIdResult;
34
35 namespace {
LangIdResultToJObjectArray(JNIEnv * env,const LangIdResult & lang_id_result)36 jobjectArray LangIdResultToJObjectArray(JNIEnv* env,
37 const LangIdResult& lang_id_result) {
38 const ScopedLocalRef<jclass> result_class(
39 env->FindClass(TC3_PACKAGE_PATH TC3_LANG_ID_CLASS_NAME_STR
40 "$LanguageResult"),
41 env);
42 if (!result_class) {
43 TC3_LOG(ERROR) << "Couldn't find LanguageResult class.";
44 return nullptr;
45 }
46
47 // clang-format off
48 const std::vector<std::pair<std::string, float>>& predictions =
49 lang_id_result.predictions;
50 // clang-format on
51 const jmethodID result_class_constructor =
52 env->GetMethodID(result_class.get(), "<init>", "(Ljava/lang/String;F)V");
53 const jobjectArray results =
54 env->NewObjectArray(predictions.size(), result_class.get(), nullptr);
55 for (int i = 0; i < predictions.size(); i++) {
56 ScopedLocalRef<jobject> result(
57 env->NewObject(result_class.get(), result_class_constructor,
58 env->NewStringUTF(predictions[i].first.c_str()),
59 static_cast<jfloat>(predictions[i].second)));
60 env->SetObjectArrayElement(results, i, result.get());
61 }
62 return results;
63 }
64 } // namespace
65
// Loads a LangId via GetLangIdFromFlatbufferFileDescriptor(|fd|) and returns
// it as a jlong handle. Returns the null handle when the loaded LangId reports
// that it is not valid.
TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNew)
(JNIEnv* env, jobject thiz, jint fd) {
  std::unique_ptr<LangId> model = GetLangIdFromFlatbufferFileDescriptor(fd);
  if (model->is_valid()) {
    // The unique_ptr gives up ownership; the raw pointer becomes the handle.
    return reinterpret_cast<jlong>(model.release());
  }
  // An invalid model is destroyed when |model| goes out of scope.
  return reinterpret_cast<jlong>(nullptr);
}
74
// Loads a LangId via GetLangIdFromFlatbufferFile from the file named by the
// Java string |path| and returns it as a jlong handle. Returns the null handle
// when the loaded LangId reports that it is not valid.
TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNewFromPath)
(JNIEnv* env, jobject thiz, jstring path) {
  const std::string model_path = ToStlString(env, path);
  std::unique_ptr<LangId> model = GetLangIdFromFlatbufferFile(model_path);
  if (model->is_valid()) {
    // The unique_ptr gives up ownership; the raw pointer becomes the handle.
    return reinterpret_cast<jlong>(model.release());
  }
  // An invalid model is destroyed when |model| goes out of scope.
  return reinterpret_cast<jlong>(nullptr);
}
84
// Runs FindLanguages on the LangId referenced by the handle |ptr| for the Java
// string |text| and returns the predictions as a Java LanguageResult array.
// Returns nullptr when |ptr| is the null handle.
TC3_JNI_METHOD(jobjectArray, TC3_LANG_ID_CLASS_NAME, nativeDetectLanguages)
(JNIEnv* env, jobject clazz, jlong ptr, jstring text) {
  LangId* const lang_id = reinterpret_cast<LangId*>(ptr);
  if (lang_id == nullptr) {
    return nullptr;
  }

  const std::string input = ToStlString(env, text);
  LangIdResult detected;
  lang_id->FindLanguages(input, &detected);
  return LangIdResultToJObjectArray(env, detected);
}
98
TC3_JNI_METHOD(void,TC3_LANG_ID_CLASS_NAME,nativeClose)99 TC3_JNI_METHOD(void, TC3_LANG_ID_CLASS_NAME, nativeClose)
100 (JNIEnv* env, jobject clazz, jlong ptr) {
101 if (!ptr) {
102 TC3_LOG(ERROR) << "Trying to close null LangId.";
103 return;
104 }
105 LangId* model = reinterpret_cast<LangId*>(ptr);
106 delete model;
107 }
108
// Returns GetModelVersion() of the LangId referenced by the handle |ptr|, or
// -1 when |ptr| is the null handle.
TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersion)
(JNIEnv* env, jobject clazz, jlong ptr) {
  if (ptr) {
    LangId* const lang_id = reinterpret_cast<LangId*>(ptr);
    return lang_id->GetModelVersion();
  }
  return -1;
}
117
// Loads a temporary LangId via GetLangIdFromFlatbufferFileDescriptor(|fd|) and
// returns its GetModelVersion(), or -1 when the loaded LangId reports that it
// is not valid. The temporary LangId is destroyed before returning.
TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersionFromFd)
(JNIEnv* env, jobject clazz, jint fd) {
  std::unique_ptr<LangId> model = GetLangIdFromFlatbufferFileDescriptor(fd);
  if (model->is_valid()) {
    return model->GetModelVersion();
  }
  return -1;
}
126
// Returns the "text_classifier_langid_threshold" float property of the LangId
// referenced by the handle |ptr|, or -1.0 when |ptr| is the null handle.
// -1.0 is also passed as the second argument to GetFloatProperty (presumably
// the value used when the property is absent -- confirm against LangId).
TC3_JNI_METHOD(jfloat, TC3_LANG_ID_CLASS_NAME, nativeGetLangIdThreshold)
(JNIEnv* env, jobject thizz, jlong ptr) {
  if (ptr) {
    LangId* const lang_id = reinterpret_cast<LangId*>(ptr);
    return lang_id->GetFloatProperty("text_classifier_langid_threshold", -1.0);
  }
  return -1.0;
}
135