1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/java/src/main/native/operation_jni.h"
17
18 #include <memory>
19 #include "tensorflow/c/c_api.h"
20 #include "tensorflow/java/src/main/native/exception_jni.h"
21
22 namespace {
23 template <class T>
requireHandleImpl(JNIEnv * env,jlong handle)24 T* requireHandleImpl(JNIEnv* env, jlong handle) {
25 static_assert(sizeof(jlong) >= sizeof(T*),
26 "Cannot package C object pointers as a Java long");
27 if (handle == 0) {
28 throwException(
29 env, kNullPointerException,
30 "close() has been called on the Graph this Operation was a part of");
31 return nullptr;
32 }
33 return reinterpret_cast<T*>(handle);
34 }
35
requireHandle(JNIEnv * env,jlong handle)36 TF_Operation* requireHandle(JNIEnv* env, jlong handle) {
37 return requireHandleImpl<TF_Operation>(env, handle);
38 }
39
requireGraphHandle(JNIEnv * env,jlong handle)40 TF_Graph* requireGraphHandle(JNIEnv* env, jlong handle) {
41 return requireHandleImpl<TF_Graph>(env, handle);
42 }
43 } // namespace
44
Java_org_tensorflow_Operation_name(JNIEnv * env,jclass clazz,jlong handle)45 JNIEXPORT jstring JNICALL Java_org_tensorflow_Operation_name(JNIEnv* env,
46 jclass clazz,
47 jlong handle) {
48 TF_Operation* op = requireHandle(env, handle);
49 if (op == nullptr) return nullptr;
50 return env->NewStringUTF(TF_OperationName(op));
51 }
52
Java_org_tensorflow_Operation_type(JNIEnv * env,jclass clazz,jlong handle)53 JNIEXPORT jstring JNICALL Java_org_tensorflow_Operation_type(JNIEnv* env,
54 jclass clazz,
55 jlong handle) {
56 TF_Operation* op = requireHandle(env, handle);
57 if (op == nullptr) return nullptr;
58 return env->NewStringUTF(TF_OperationOpType(op));
59 }
60
Java_org_tensorflow_Operation_numOutputs(JNIEnv * env,jclass clazz,jlong handle)61 JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_numOutputs(JNIEnv* env,
62 jclass clazz,
63 jlong handle) {
64 TF_Operation* op = requireHandle(env, handle);
65 if (op == nullptr) return 0;
66 return TF_OperationNumOutputs(op);
67 }
68
Java_org_tensorflow_Operation_outputListLength(JNIEnv * env,jclass clazz,jlong handle,jstring name)69 JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_outputListLength(JNIEnv* env,
70 jclass clazz,
71 jlong handle,
72 jstring name) {
73 TF_Operation* op = requireHandle(env, handle);
74 if (op == nullptr) return 0;
75
76 TF_Status* status = TF_NewStatus();
77
78 const char* cname = env->GetStringUTFChars(name, nullptr);
79 int result = TF_OperationOutputListLength(op, cname, status);
80 env->ReleaseStringUTFChars(name, cname);
81
82 throwExceptionIfNotOK(env, status);
83 TF_DeleteStatus(status);
84 return result;
85 }
86
Java_org_tensorflow_Operation_shape(JNIEnv * env,jclass clazz,jlong graph_handle,jlong op_handle,jint output_index)87 JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Operation_shape(
88 JNIEnv* env, jclass clazz, jlong graph_handle, jlong op_handle,
89 jint output_index) {
90 TF_Graph* graph = requireGraphHandle(env, graph_handle);
91 if (graph == nullptr) return nullptr;
92 TF_Operation* op = requireHandle(env, op_handle);
93 if (op == nullptr) return nullptr;
94
95 int num_outputs = TF_OperationNumOutputs(op);
96 if (output_index < 0 || output_index >= num_outputs) {
97 throwException(
98 env, kIndexOutOfBoundsException,
99 "invalid output index (%d) for an operation that has %d outputs",
100 output_index, num_outputs);
101 return nullptr;
102 }
103
104 TF_Output output{op, output_index};
105 TF_Status* status = TF_NewStatus();
106 jsize num_dims = TF_GraphGetTensorNumDims(graph, output, status);
107 if (!throwExceptionIfNotOK(env, status)) {
108 TF_DeleteStatus(status);
109 return nullptr;
110 }
111 if (num_dims < 0) return nullptr;
112 static_assert(sizeof(jlong) == sizeof(int64_t),
113 "Java long is not compatible with the TensorFlow C API");
114 // One might have trivially wanted to do:
115 // TF_GraphGetTensorShape(graph, output, static_cast<int64_t*>(dims), ...)
116 // but on some platforms this fails with:
117 // static_cast from 'jlong *' (aka 'long *') to 'int64_t *' (aka 'long long
118 // *') is not allowed
119 // For now, do the expensive but safe thing of copying.
120 std::unique_ptr<int64_t[]> cdims(new int64_t[num_dims]);
121 TF_GraphGetTensorShape(graph, output, cdims.get(), static_cast<int>(num_dims),
122 status);
123 if (!throwExceptionIfNotOK(env, status)) {
124 TF_DeleteStatus(status);
125 return nullptr;
126 }
127 TF_DeleteStatus(status);
128
129 jlongArray ret = env->NewLongArray(num_dims);
130 jlong* dims = env->GetLongArrayElements(ret, nullptr);
131 for (int i = 0; i < num_dims; ++i) {
132 dims[i] = static_cast<jlong>(cdims[i]);
133 }
134 env->ReleaseLongArrayElements(ret, dims, 0);
135 return ret;
136 }
137
Java_org_tensorflow_Operation_dtype(JNIEnv * env,jclass clazz,jlong graph_handle,jlong op_handle,jint output_index)138 JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_dtype(JNIEnv* env,
139 jclass clazz,
140 jlong graph_handle,
141 jlong op_handle,
142 jint output_index) {
143 TF_Graph* graph = requireGraphHandle(env, graph_handle);
144 if (graph == nullptr) return 0;
145 TF_Operation* op = requireHandle(env, op_handle);
146 if (op == nullptr) return 0;
147
148 int num_outputs = TF_OperationNumOutputs(op);
149 if (output_index < 0 || output_index >= num_outputs) {
150 throwException(
151 env, kIndexOutOfBoundsException,
152 "invalid output index (%d) for an operation that has %d outputs",
153 output_index, num_outputs);
154 return 0;
155 }
156
157 return static_cast<jint>(TF_OperationOutputType(TF_Output{op, output_index}));
158 }
159
Java_org_tensorflow_Operation_inputListLength(JNIEnv * env,jclass clazz,jlong handle,jstring name)160 JNIEXPORT jint JNICALL Java_org_tensorflow_Operation_inputListLength(JNIEnv* env,
161 jclass clazz,
162 jlong handle,
163 jstring name) {
164 TF_Operation* op = requireHandle(env, handle);
165 if (op == nullptr) return 0;
166
167 TF_Status* status = TF_NewStatus();
168
169 const char* cname = env->GetStringUTFChars(name, nullptr);
170 int result = TF_OperationInputListLength(op, cname, status);
171 env->ReleaseStringUTFChars(name, cname);
172
173 throwExceptionIfNotOK(env, status);
174 TF_DeleteStatus(status);
175 return result;
176 }
177