1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "example_iterator_wrapper_impl.h"
18 
19 #include <jni.h>
20 
21 #include <string>
22 
23 #include "fcp/jni/jni_util.h"
24 #include "fcp/protos/federatedcompute/common.pb.h"
25 #include "more_jni_util.h"
26 #include "nativehelper/scoped_local_ref.h"
27 
28 namespace fcp {
29 namespace client {
30 namespace engine {
31 namespace jni {
32 
33 using ::fcp::jni::JavaMethodSig;
34 using ::fcp::jni::ParseProtoFromJByteArray;
35 using ::fcp::jni::ScopedJniEnv;
36 using ::fcp::jni::SerializeProtoToJByteArray;
37 
38 struct JavaExampleIteratorClassDesc {
39   static constexpr JavaMethodSig kNext = {"next", "()[B"};
40   static constexpr JavaMethodSig kClose = {"close", "()V"};
41 };
42 
ExampleIteratorWrapperImpl(JavaVM * jvm,jobject example_iterator)43 ExampleIteratorWrapperImpl::ExampleIteratorWrapperImpl(JavaVM *jvm,
44                                                        jobject example_iterator)
45     : jvm_(jvm) {
46   {
47     std::lock_guard<std::mutex> lock(close_mu_);
48     closed_ = false;
49   }
50   ScopedJniEnv scoped_env(jvm_);
51   JNIEnv *env = scoped_env.env();
52   jthis_ = env->NewGlobalRef(example_iterator);
53   FCP_CHECK(jthis_ != nullptr);
54 
55   ScopedLocalRef<jclass> example_iterator_class(
56       env, env->GetObjectClass(example_iterator));
57   FCP_CHECK(!MoreJniUtil::CheckForJniException(env));
58   FCP_CHECK(example_iterator_class.get() != nullptr);
59 
60   next_id_ = MoreJniUtil::GetMethodIdOrAbort(
61       env, example_iterator_class.get(), JavaExampleIteratorClassDesc::kNext);
62   close_id_ = MoreJniUtil::GetMethodIdOrAbort(
63       env, example_iterator_class.get(), JavaExampleIteratorClassDesc::kClose);
64 }
65 
~ExampleIteratorWrapperImpl()66 ExampleIteratorWrapperImpl::~ExampleIteratorWrapperImpl() {
67   // This ensures that the java iterator instance is released when the
68   // TensorFlow session is freed.
69   Close();
70   ScopedJniEnv scoped_env(jvm_);
71   JNIEnv *env = scoped_env.env();
72   env->DeleteGlobalRef(jthis_);
73 }
74 
Next()75 absl::StatusOr<std::string> ExampleIteratorWrapperImpl::Next() {
76   {
77     std::lock_guard<std::mutex> lock(close_mu_);
78     if (closed_) {
79       return absl::InternalError("Next() called on closed iterator.");
80     }
81   }
82   ScopedJniEnv scoped_env(jvm_);
83   JNIEnv *env = scoped_env.env();
84 
85   ScopedLocalRef<jbyteArray> example(
86       env, (jbyteArray)env->CallObjectMethod(jthis_, next_id_));
87   FCP_RETURN_IF_ERROR(
88       MoreJniUtil::GetExceptionStatus(env, "call JavaExampleIterator.Next()"));
89   FCP_CHECK(example.get() != nullptr);
90 
91   int result_size = env->GetArrayLength(example.get());
92   if (result_size == 0) {
93     return absl::OutOfRangeError("end of iterator reached");
94   }
95   std::string example_string =
96       MoreJniUtil::JByteArrayToString(env, example.get());
97   return example_string;
98 }
99 
100 // Close the iterator to release associated resources.
Close()101 void ExampleIteratorWrapperImpl::Close() {
102   std::lock_guard<std::mutex> lock(close_mu_);
103   if (closed_) {
104     return;
105   }
106   ScopedJniEnv scoped_env(jvm_);
107   JNIEnv *env = scoped_env.env();
108 
109   env->CallVoidMethod(jthis_, close_id_);
110   FCP_CHECK(!MoreJniUtil::CheckForJniException(env));
111   closed_ = true;
112 }
113 
114 }  // namespace jni
115 }  // namespace engine
116 }  // namespace client
117 }  // namespace fcp
118