1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/java/src/main/native/operation_jni.h"
17 
18 #include <memory>
19 #include "tensorflow/c/c_api.h"
20 #include "tensorflow/java/src/main/native/exception_jni.h"
21 
22 namespace {
23 template <class T>
requireHandleImpl(JNIEnv * env,jlong handle)24 T* requireHandleImpl(JNIEnv* env, jlong handle) {
25   static_assert(sizeof(jlong) >= sizeof(T*),
26                 "Cannot package C object pointers as a Java long");
27   if (handle == 0) {
28     throwException(
29         env, kNullPointerException,
30         "close() has been called on the Graph this Operation was a part of");
31     return nullptr;
32   }
33   return reinterpret_cast<T*>(handle);
34 }
35 
requireHandle(JNIEnv * env,jlong handle)36 TF_Operation* requireHandle(JNIEnv* env, jlong handle) {
37   return requireHandleImpl<TF_Operation>(env, handle);
38 }
39 
requireGraphHandle(JNIEnv * env,jlong handle)40 TF_Graph* requireGraphHandle(JNIEnv* env, jlong handle) {
41   return requireHandleImpl<TF_Graph>(env, handle);
42 }
43 }  // namespace
44 
Java_org_tensorflow_Operation_name(JNIEnv * env,jclass clazz,jlong handle)45 JNIEXPORT jstring JNICALL Java_org_tensorflow_Operation_name(JNIEnv* env,
46                                                              jclass clazz,
47                                                              jlong handle) {
48   TF_Operation* op = requireHandle(env, handle);
49   if (op == nullptr) return nullptr;
50   return env->NewStringUTF(TF_OperationName(op));
51 }
52 
Java_org_tensorflow_Operation_type(JNIEnv * env,jclass clazz,jlong handle)53 JNIEXPORT jstring JNICALL Java_org_tensorflow_Operation_type(JNIEnv* env,
54                                                              jclass clazz,
55                                                              jlong handle) {
56   TF_Operation* op = requireHandle(env, handle);
57   if (op == nullptr) return nullptr;
58   return env->NewStringUTF(TF_OperationOpType(op));
59 }
60 
Java_org_tensorflow_Operation_numOutputs(JNIEnv * env,jclass clazz,jlong handle)61 JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_numOutputs(JNIEnv* env,
62                                                                 jclass clazz,
63                                                                 jlong handle) {
64   TF_Operation* op = requireHandle(env, handle);
65   if (op == nullptr) return 0;
66   return TF_OperationNumOutputs(op);
67 }
68 
Java_org_tensorflow_Operation_outputListLength(JNIEnv * env,jclass clazz,jlong handle,jstring name)69 JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_outputListLength(JNIEnv* env,
70                                                                       jclass clazz,
71                                                                       jlong handle,
72                                                                       jstring name) {
73   TF_Operation* op = requireHandle(env, handle);
74   if (op == nullptr) return 0;
75 
76   TF_Status* status = TF_NewStatus();
77 
78   const char* cname = env->GetStringUTFChars(name, nullptr);
79   int result = TF_OperationOutputListLength(op, cname, status);
80   env->ReleaseStringUTFChars(name, cname);
81 
82   throwExceptionIfNotOK(env, status);
83   TF_DeleteStatus(status);
84   return result;
85 }
86 
Java_org_tensorflow_Operation_shape(JNIEnv * env,jclass clazz,jlong graph_handle,jlong op_handle,jint output_index)87 JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Operation_shape(
88     JNIEnv* env, jclass clazz, jlong graph_handle, jlong op_handle,
89     jint output_index) {
90   TF_Graph* graph = requireGraphHandle(env, graph_handle);
91   if (graph == nullptr) return nullptr;
92   TF_Operation* op = requireHandle(env, op_handle);
93   if (op == nullptr) return nullptr;
94 
95   int num_outputs = TF_OperationNumOutputs(op);
96   if (output_index < 0 || output_index >= num_outputs) {
97     throwException(
98         env, kIndexOutOfBoundsException,
99         "invalid output index (%d) for an operation that has %d outputs",
100         output_index, num_outputs);
101     return nullptr;
102   }
103 
104   TF_Output output{op, output_index};
105   TF_Status* status = TF_NewStatus();
106   jsize num_dims = TF_GraphGetTensorNumDims(graph, output, status);
107   if (!throwExceptionIfNotOK(env, status)) {
108     TF_DeleteStatus(status);
109     return nullptr;
110   }
111   if (num_dims < 0) return nullptr;
112   static_assert(sizeof(jlong) == sizeof(int64_t),
113                 "Java long is not compatible with the TensorFlow C API");
114   // One might have trivially wanted to do:
115   // TF_GraphGetTensorShape(graph, output, static_cast<int64_t*>(dims), ...)
116   // but on some platforms this fails with:
117   // static_cast from 'jlong *' (aka 'long *') to 'int64_t *' (aka 'long long
118   // *') is not allowed
119   // For now, do the expensive but safe thing of copying.
120   std::unique_ptr<int64_t[]> cdims(new int64_t[num_dims]);
121   TF_GraphGetTensorShape(graph, output, cdims.get(), static_cast<int>(num_dims),
122                          status);
123   if (!throwExceptionIfNotOK(env, status)) {
124     TF_DeleteStatus(status);
125     return nullptr;
126   }
127   TF_DeleteStatus(status);
128 
129   jlongArray ret = env->NewLongArray(num_dims);
130   jlong* dims = env->GetLongArrayElements(ret, nullptr);
131   for (int i = 0; i < num_dims; ++i) {
132     dims[i] = static_cast<jlong>(cdims[i]);
133   }
134   env->ReleaseLongArrayElements(ret, dims, 0);
135   return ret;
136 }
137 
Java_org_tensorflow_Operation_dtype(JNIEnv * env,jclass clazz,jlong graph_handle,jlong op_handle,jint output_index)138 JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_dtype(JNIEnv* env,
139                                                            jclass clazz,
140                                                            jlong graph_handle,
141                                                            jlong op_handle,
142                                                            jint output_index) {
143   TF_Graph* graph = requireGraphHandle(env, graph_handle);
144   if (graph == nullptr) return 0;
145   TF_Operation* op = requireHandle(env, op_handle);
146   if (op == nullptr) return 0;
147 
148   int num_outputs = TF_OperationNumOutputs(op);
149   if (output_index < 0 || output_index >= num_outputs) {
150     throwException(
151         env, kIndexOutOfBoundsException,
152         "invalid output index (%d) for an operation that has %d outputs",
153         output_index, num_outputs);
154     return 0;
155   }
156 
157   return static_cast<jint>(TF_OperationOutputType(TF_Output{op, output_index}));
158 }
159 
Java_org_tensorflow_Operation_inputListLength(JNIEnv * env,jclass clazz,jlong handle,jstring name)160 JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_inputListLength(JNIEnv* env,
161                                                                       jclass clazz,
162                                                                       jlong handle,
163                                                                       jstring name) {
164   TF_Operation* op = requireHandle(env, handle);
165   if (op == nullptr) return 0;
166 
167   TF_Status* status = TF_NewStatus();
168 
169   const char* cname = env->GetStringUTFChars(name, nullptr);
170   int result = TF_OperationInputListLength(op, cname, status);
171   env->ReleaseStringUTFChars(name, cname);
172 
173   throwExceptionIfNotOK(env, status);
174   TF_DeleteStatus(status);
175   return result;
176 }
177