1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 // Simple JNI wrapper for the SmartSelection library.
18 
19 #include "textclassifier_jni.h"
20 
21 #include <jni.h>
22 #include <vector>
23 
24 #include "lang_id/lang-id.h"
25 #include "smartselect/text-classification-model.h"
26 
27 using libtextclassifier::TextClassificationModel;
28 using libtextclassifier::ModelOptions;
29 using libtextclassifier::nlp_core::lang_id::LangId;
30 
31 namespace {
32 
JStringToUtf8String(JNIEnv * env,const jstring & jstr,std::string * result)33 bool JStringToUtf8String(JNIEnv* env, const jstring& jstr,
34                          std::string* result) {
35   if (jstr == nullptr) {
36     *result = std::string();
37     return false;
38   }
39 
40   jclass string_class = env->FindClass("java/lang/String");
41   jmethodID get_bytes_id =
42       env->GetMethodID(string_class, "getBytes", "(Ljava/lang/String;)[B");
43 
44   jstring encoding = env->NewStringUTF("UTF-8");
45   jbyteArray array = reinterpret_cast<jbyteArray>(
46       env->CallObjectMethod(jstr, get_bytes_id, encoding));
47 
48   jbyte* const array_bytes = env->GetByteArrayElements(array, JNI_FALSE);
49   int length = env->GetArrayLength(array);
50 
51   *result = std::string(reinterpret_cast<char*>(array_bytes), length);
52 
53   // Release the array.
54   env->ReleaseByteArrayElements(array, array_bytes, JNI_ABORT);
55   env->DeleteLocalRef(array);
56   env->DeleteLocalRef(string_class);
57   env->DeleteLocalRef(encoding);
58 
59   return true;
60 }
61 
ToStlString(JNIEnv * env,const jstring & str)62 std::string ToStlString(JNIEnv* env, const jstring& str) {
63   std::string result;
64   JStringToUtf8String(env, str, &result);
65   return result;
66 }
67 
ScoredStringsToJObjectArray(JNIEnv * env,const std::string & result_class_name,const std::vector<std::pair<std::string,float>> & classification_result)68 jobjectArray ScoredStringsToJObjectArray(
69     JNIEnv* env, const std::string& result_class_name,
70     const std::vector<std::pair<std::string, float>>& classification_result) {
71   jclass result_class = env->FindClass(result_class_name.c_str());
72   jmethodID result_class_constructor =
73       env->GetMethodID(result_class, "<init>", "(Ljava/lang/String;F)V");
74 
75   jobjectArray results =
76       env->NewObjectArray(classification_result.size(), result_class, nullptr);
77 
78   for (int i = 0; i < classification_result.size(); i++) {
79     jstring row_string =
80         env->NewStringUTF(classification_result[i].first.c_str());
81     jobject result =
82         env->NewObject(result_class, result_class_constructor, row_string,
83                        static_cast<jfloat>(classification_result[i].second));
84     env->SetObjectArrayElement(results, i, result);
85     env->DeleteLocalRef(result);
86   }
87   env->DeleteLocalRef(result_class);
88   return results;
89 }
90 
91 }  // namespace
92 
93 namespace libtextclassifier {
94 
95 using libtextclassifier::CodepointSpan;
96 
97 namespace {
98 
ConvertIndicesBMPUTF8(const std::string & utf8_str,CodepointSpan orig_indices,bool from_utf8)99 CodepointSpan ConvertIndicesBMPUTF8(const std::string& utf8_str,
100                                     CodepointSpan orig_indices,
101                                     bool from_utf8) {
102   const libtextclassifier::UnicodeText unicode_str =
103       libtextclassifier::UTF8ToUnicodeText(utf8_str, /*do_copy=*/false);
104 
105   int unicode_index = 0;
106   int bmp_index = 0;
107 
108   const int* source_index;
109   const int* target_index;
110   if (from_utf8) {
111     source_index = &unicode_index;
112     target_index = &bmp_index;
113   } else {
114     source_index = &bmp_index;
115     target_index = &unicode_index;
116   }
117 
118   CodepointSpan result{-1, -1};
119   std::function<void()> assign_indices_fn = [&result, &orig_indices,
120                                              &source_index, &target_index]() {
121     if (orig_indices.first == *source_index) {
122       result.first = *target_index;
123     }
124 
125     if (orig_indices.second == *source_index) {
126       result.second = *target_index;
127     }
128   };
129 
130   for (auto it = unicode_str.begin(); it != unicode_str.end();
131        ++it, ++unicode_index, ++bmp_index) {
132     assign_indices_fn();
133 
134     // There is 1 extra character in the input for each UTF8 character > 0xFFFF.
135     if (*it > 0xFFFF) {
136       ++bmp_index;
137     }
138   }
139   assign_indices_fn();
140 
141   return result;
142 }
143 
144 }  // namespace
145 
ConvertIndicesBMPToUTF8(const std::string & utf8_str,CodepointSpan orig_indices)146 CodepointSpan ConvertIndicesBMPToUTF8(const std::string& utf8_str,
147                                       CodepointSpan orig_indices) {
148   return ConvertIndicesBMPUTF8(utf8_str, orig_indices, /*from_utf8=*/false);
149 }
150 
ConvertIndicesUTF8ToBMP(const std::string & utf8_str,CodepointSpan orig_indices)151 CodepointSpan ConvertIndicesUTF8ToBMP(const std::string& utf8_str,
152                                       CodepointSpan orig_indices) {
153   return ConvertIndicesBMPUTF8(utf8_str, orig_indices, /*from_utf8=*/true);
154 }
155 
156 }  // namespace libtextclassifier
157 
158 using libtextclassifier::ConvertIndicesUTF8ToBMP;
159 using libtextclassifier::ConvertIndicesBMPToUTF8;
160 using libtextclassifier::CodepointSpan;
161 
162 JNIEXPORT jlong JNICALL
Java_android_view_textclassifier_SmartSelection_nativeNew(JNIEnv * env,jobject thiz,jint fd)163 Java_android_view_textclassifier_SmartSelection_nativeNew(JNIEnv* env,
164                                                           jobject thiz,
165                                                           jint fd) {
166   TextClassificationModel* model = new TextClassificationModel(fd);
167   return reinterpret_cast<jlong>(model);
168 }
169 
170 JNIEXPORT jintArray JNICALL
Java_android_view_textclassifier_SmartSelection_nativeSuggest(JNIEnv * env,jobject thiz,jlong ptr,jstring context,jint selection_begin,jint selection_end)171 Java_android_view_textclassifier_SmartSelection_nativeSuggest(
172     JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
173     jint selection_end) {
174   TextClassificationModel* model =
175       reinterpret_cast<TextClassificationModel*>(ptr);
176 
177   const std::string context_utf8 = ToStlString(env, context);
178   CodepointSpan input_indices =
179       ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end});
180   CodepointSpan selection =
181       model->SuggestSelection(context_utf8, input_indices);
182   selection = ConvertIndicesUTF8ToBMP(context_utf8, selection);
183 
184   jintArray result = env->NewIntArray(2);
185   env->SetIntArrayRegion(result, 0, 1, &(std::get<0>(selection)));
186   env->SetIntArrayRegion(result, 1, 1, &(std::get<1>(selection)));
187   return result;
188 }
189 
190 JNIEXPORT jobjectArray JNICALL
Java_android_view_textclassifier_SmartSelection_nativeClassifyText(JNIEnv * env,jobject thiz,jlong ptr,jstring context,jint selection_begin,jint selection_end,jint input_flags)191 Java_android_view_textclassifier_SmartSelection_nativeClassifyText(
192     JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
193     jint selection_end, jint input_flags) {
194   TextClassificationModel* ff_model =
195       reinterpret_cast<TextClassificationModel*>(ptr);
196   const std::vector<std::pair<std::string, float>> classification_result =
197       ff_model->ClassifyText(ToStlString(env, context),
198                              {selection_begin, selection_end}, input_flags);
199 
200   return ScoredStringsToJObjectArray(
201       env, "android/view/textclassifier/SmartSelection$ClassificationResult",
202       classification_result);
203 }
204 
205 JNIEXPORT void JNICALL
Java_android_view_textclassifier_SmartSelection_nativeClose(JNIEnv * env,jobject thiz,jlong ptr)206 Java_android_view_textclassifier_SmartSelection_nativeClose(JNIEnv* env,
207                                                             jobject thiz,
208                                                             jlong ptr) {
209   TextClassificationModel* model =
210       reinterpret_cast<TextClassificationModel*>(ptr);
211   delete model;
212 }
213 
Java_android_view_textclassifier_LangId_nativeNew(JNIEnv * env,jobject thiz,jint fd)214 JNIEXPORT jlong JNICALL Java_android_view_textclassifier_LangId_nativeNew(
215     JNIEnv* env, jobject thiz, jint fd) {
216   return reinterpret_cast<jlong>(new LangId(fd));
217 }
218 
219 JNIEXPORT jstring JNICALL
Java_android_view_textclassifier_SmartSelection_nativeGetLanguage(JNIEnv * env,jobject clazz,jint fd)220 Java_android_view_textclassifier_SmartSelection_nativeGetLanguage(JNIEnv* env,
221                                                                   jobject clazz,
222                                                                   jint fd) {
223   ModelOptions model_options;
224   if (ReadSelectionModelOptions(fd, &model_options)) {
225     return env->NewStringUTF(model_options.language().c_str());
226   } else {
227     return env->NewStringUTF("UNK");
228   }
229 }
230 
231 JNIEXPORT jint JNICALL
Java_android_view_textclassifier_SmartSelection_nativeGetVersion(JNIEnv * env,jobject clazz,jint fd)232 Java_android_view_textclassifier_SmartSelection_nativeGetVersion(JNIEnv* env,
233                                                                  jobject clazz,
234                                                                  jint fd) {
235   ModelOptions model_options;
236   if (ReadSelectionModelOptions(fd, &model_options)) {
237     return model_options.version();
238   } else {
239     return -1;
240   }
241 }
242 
243 JNIEXPORT jobjectArray JNICALL
Java_android_view_textclassifier_LangId_nativeFindLanguages(JNIEnv * env,jobject thiz,jlong ptr,jstring text)244 Java_android_view_textclassifier_LangId_nativeFindLanguages(JNIEnv* env,
245                                                             jobject thiz,
246                                                             jlong ptr,
247                                                             jstring text) {
248   LangId* lang_id = reinterpret_cast<LangId*>(ptr);
249   const std::vector<std::pair<std::string, float>> scored_languages =
250       lang_id->FindLanguages(ToStlString(env, text));
251 
252   return ScoredStringsToJObjectArray(
253       env, "android/view/textclassifier/LangId$ClassificationResult",
254       scored_languages);
255 }
256 
Java_android_view_textclassifier_LangId_nativeClose(JNIEnv * env,jobject thiz,jlong ptr)257 JNIEXPORT void JNICALL Java_android_view_textclassifier_LangId_nativeClose(
258     JNIEnv* env, jobject thiz, jlong ptr) {
259   LangId* lang_id = reinterpret_cast<LangId*>(ptr);
260   delete lang_id;
261 }
262 
Java_android_view_textclassifier_LangId_nativeGetVersion(JNIEnv * env,jobject clazz,jint fd)263 JNIEXPORT int JNICALL Java_android_view_textclassifier_LangId_nativeGetVersion(
264     JNIEnv* env, jobject clazz, jint fd) {
265   std::unique_ptr<LangId> lang_id(new LangId(fd));
266   return lang_id->version();
267 }
268