1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/java/src/main/native/tensor_jni.h"
17 
18 #include <assert.h>
19 #include <stdlib.h>
20 #include <string.h>
21 #include <algorithm>
22 #include <memory>
23 
24 #include "tensorflow/c/c_api.h"
25 #include "tensorflow/java/src/main/native/exception_jni.h"
26 
27 namespace {
28 
requireHandle(JNIEnv * env,jlong handle)29 TF_Tensor* requireHandle(JNIEnv* env, jlong handle) {
30   if (handle == 0) {
31     throwException(env, kNullPointerException,
32                    "close() was called on the Tensor");
33     return nullptr;
34   }
35   return reinterpret_cast<TF_Tensor*>(handle);
36 }
37 
elemByteSize(TF_DataType dtype)38 size_t elemByteSize(TF_DataType dtype) {
39   // The code in this file makes the assumption that the
40   // TensorFlow TF_DataTypes and the Java primitive types
41   // have the same byte sizes. Validate that:
42   switch (dtype) {
43     case TF_BOOL:
44     case TF_UINT8:
45       static_assert(sizeof(jboolean) == 1,
46                     "Java boolean not compatible with TF_BOOL");
47       static_assert(sizeof(jbyte) == 1,
48                     "Java byte not compatible with TF_UINT8");
49       return 1;
50     case TF_FLOAT:
51     case TF_INT32:
52       static_assert(sizeof(jfloat) == 4,
53                     "Java float not compatible with TF_FLOAT");
54       static_assert(sizeof(jint) == 4, "Java int not compatible with TF_INT32");
55       return 4;
56     case TF_DOUBLE:
57     case TF_INT64:
58       static_assert(sizeof(jdouble) == 8,
59                     "Java double not compatible with TF_DOUBLE");
60       static_assert(sizeof(jlong) == 8,
61                     "Java long not compatible with TF_INT64");
62       return 8;
63     default:
64       return 0;
65   }
66 }
67 
68 // Write a Java scalar object (java.lang.Integer etc.) to a TF_Tensor.
writeScalar(JNIEnv * env,jobject src,TF_DataType dtype,void * dst,size_t dst_size)69 void writeScalar(JNIEnv* env, jobject src, TF_DataType dtype, void* dst,
70                  size_t dst_size) {
71   size_t sz = elemByteSize(dtype);
72   if (sz != dst_size) {
73     throwException(
74         env, kIllegalStateException,
75         "scalar (%d bytes) not compatible with allocated tensor (%d bytes)", sz,
76         dst_size);
77     return;
78   }
79   switch (dtype) {
80 // env->FindClass and env->GetMethodID are expensive and JNI best practices
81 // suggest that they should be cached. However, until the creation of scalar
82 // valued tensors seems to become a noticeable fraction of program execution,
83 // ignore that cost.
84 #define CASE(dtype, jtype, method_name, method_signature, call_type)           \
85   case dtype: {                                                                \
86     jclass clazz = env->FindClass("java/lang/Number");                         \
87     jmethodID method = env->GetMethodID(clazz, method_name, method_signature); \
88     jtype v = env->Call##call_type##Method(src, method);                       \
89     memcpy(dst, &v, sz);                                                       \
90     return;                                                                    \
91   }
92     CASE(TF_FLOAT, jfloat, "floatValue", "()F", Float);
93     CASE(TF_DOUBLE, jdouble, "doubleValue", "()D", Double);
94     CASE(TF_INT32, jint, "intValue", "()I", Int);
95     CASE(TF_INT64, jlong, "longValue", "()J", Long);
96     CASE(TF_UINT8, jbyte, "byteValue", "()B", Byte);
97 #undef CASE
98     case TF_BOOL: {
99       jclass clazz = env->FindClass("java/lang/Boolean");
100       jmethodID method = env->GetMethodID(clazz, "booleanValue", "()Z");
101       jboolean v = env->CallBooleanMethod(src, method);
102       *(static_cast<unsigned char*>(dst)) = v ? 1 : 0;
103       return;
104     }
105     default:
106       throwException(env, kIllegalStateException, "invalid DataType(%d)",
107                      dtype);
108       return;
109   }
110 }
111 
112 // Copy a 1-D array of Java primitive types to the tensor buffer dst.
113 // Returns the number of bytes written to dst.
write1DArray(JNIEnv * env,jarray array,TF_DataType dtype,void * dst,size_t dst_size)114 size_t write1DArray(JNIEnv* env, jarray array, TF_DataType dtype, void* dst,
115                     size_t dst_size) {
116   const int nelems = env->GetArrayLength(array);
117   jboolean is_copy;
118   switch (dtype) {
119 #define CASE(dtype, jtype, get_type)                                   \
120   case dtype: {                                                        \
121     jtype##Array a = static_cast<jtype##Array>(array);                 \
122     jtype* values = env->Get##get_type##ArrayElements(a, &is_copy);    \
123     size_t to_copy = nelems * elemByteSize(dtype);                     \
124     if (to_copy > dst_size) {                                          \
125       throwException(                                                  \
126           env, kIllegalStateException,                                 \
127           "cannot write Java array of %d bytes to Tensor of %d bytes", \
128           to_copy, dst_size);                                          \
129       to_copy = 0;                                                     \
130     } else {                                                           \
131       memcpy(dst, values, to_copy);                                    \
132     }                                                                  \
133     env->Release##get_type##ArrayElements(a, values, JNI_ABORT);       \
134     return to_copy;                                                    \
135   }
136     CASE(TF_FLOAT, jfloat, Float);
137     CASE(TF_DOUBLE, jdouble, Double);
138     CASE(TF_INT32, jint, Int);
139     CASE(TF_INT64, jlong, Long);
140     CASE(TF_BOOL, jboolean, Boolean);
141     CASE(TF_UINT8, jbyte, Byte);
142 #undef CASE
143     default:
144       throwException(env, kIllegalStateException, "invalid DataType(%d)",
145                      dtype);
146       return 0;
147   }
148 }
149 
150 // Copy the elements of a 1-D array from the tensor buffer src to a 1-D array of
151 // Java primitive types. Returns the number of bytes read from src.
read1DArray(JNIEnv * env,TF_DataType dtype,const void * src,size_t src_size,jarray dst)152 size_t read1DArray(JNIEnv* env, TF_DataType dtype, const void* src,
153                    size_t src_size, jarray dst) {
154   const int len = env->GetArrayLength(dst);
155   const size_t sz = len * elemByteSize(dtype);
156   if (sz > src_size) {
157     throwException(
158         env, kIllegalStateException,
159         "cannot fill a Java array of %d bytes with a Tensor of %d bytes", sz,
160         src_size);
161     return 0;
162   }
163   switch (dtype) {
164 #define CASE(dtype, jtype, primitive_type)                                 \
165   case dtype: {                                                            \
166     jtype##Array arr = static_cast<jtype##Array>(dst);                     \
167     env->Set##primitive_type##ArrayRegion(arr, 0, len,                     \
168                                           static_cast<const jtype*>(src)); \
169     return sz;                                                             \
170   }
171     CASE(TF_FLOAT, jfloat, Float);
172     CASE(TF_DOUBLE, jdouble, Double);
173     CASE(TF_INT32, jint, Int);
174     CASE(TF_INT64, jlong, Long);
175     CASE(TF_BOOL, jboolean, Boolean);
176     CASE(TF_UINT8, jbyte, Byte);
177 #undef CASE
178     default:
179       throwException(env, kIllegalStateException, "invalid DataType(%d)",
180                      dtype);
181   }
182   return 0;
183 }
184 
writeNDArray(JNIEnv * env,jarray src,TF_DataType dtype,int dims_left,char * dst,size_t dst_size)185 size_t writeNDArray(JNIEnv* env, jarray src, TF_DataType dtype, int dims_left,
186                     char* dst, size_t dst_size) {
187   if (dims_left == 1) {
188     return write1DArray(env, src, dtype, dst, dst_size);
189   } else {
190     jobjectArray ndarray = static_cast<jobjectArray>(src);
191     int len = env->GetArrayLength(ndarray);
192     size_t sz = 0;
193     for (int i = 0; i < len; ++i) {
194       jarray row = static_cast<jarray>(env->GetObjectArrayElement(ndarray, i));
195       sz +=
196           writeNDArray(env, row, dtype, dims_left - 1, dst + sz, dst_size - sz);
197       env->DeleteLocalRef(row);
198       if (env->ExceptionCheck()) return sz;
199     }
200     return sz;
201   }
202 }
203 
readNDArray(JNIEnv * env,TF_DataType dtype,const char * src,size_t src_size,int dims_left,jarray dst)204 size_t readNDArray(JNIEnv* env, TF_DataType dtype, const char* src,
205                    size_t src_size, int dims_left, jarray dst) {
206   if (dims_left == 1) {
207     return read1DArray(env, dtype, src, src_size, dst);
208   } else {
209     jobjectArray ndarray = static_cast<jobjectArray>(dst);
210     int len = env->GetArrayLength(ndarray);
211     size_t sz = 0;
212     for (int i = 0; i < len; ++i) {
213       jarray row = static_cast<jarray>(env->GetObjectArrayElement(ndarray, i));
214       sz +=
215           readNDArray(env, dtype, src + sz, src_size - sz, dims_left - 1, row);
216       env->DeleteLocalRef(row);
217       if (env->ExceptionCheck()) return sz;
218     }
219     return sz;
220   }
221 }
222 
TF_StringDecodeTojbyteArray(JNIEnv * env,const char * src,size_t src_len,TF_Status * status)223 jbyteArray TF_StringDecodeTojbyteArray(JNIEnv* env, const char* src,
224                                        size_t src_len, TF_Status* status) {
225   const char* dst = nullptr;
226   size_t dst_len = 0;
227   TF_StringDecode(src, src_len, &dst, &dst_len, status);
228   if (TF_GetCode(status) != TF_OK) {
229     return nullptr;
230   }
231   jbyteArray ret = env->NewByteArray(dst_len);
232   jbyte* cpy = env->GetByteArrayElements(ret, nullptr);
233   memcpy(cpy, dst, dst_len);
234   env->ReleaseByteArrayElements(ret, cpy, 0);
235   return ret;
236 }
237 
238 class StringTensorWriter {
239  public:
StringTensorWriter(TF_Tensor * t,int num_elements)240   StringTensorWriter(TF_Tensor* t, int num_elements)
241       : offset_(0),
242         poffsets_(static_cast<char*>(TF_TensorData(t))),
243         pdata_(poffsets_ + 8 * num_elements),
244         plimit_(poffsets_ + TF_TensorByteSize(t)) {}
245 
Add(const char * src,size_t len,TF_Status * status)246   void Add(const char* src, size_t len, TF_Status* status) {
247     if (TF_GetCode(status) != TF_OK) return;
248     if (plimit_ - poffsets_ < sizeof(offset_)) {
249       TF_SetStatus(status, TF_OUT_OF_RANGE,
250                    "TF_STRING tensor encoding ran out of space for offsets, "
251                    "this is likely a bug, please file an issue at "
252                    "https://github.com/tensorflow/tensorflow/issues/new");
253       return;
254     }
255     memcpy(poffsets_, &offset_, sizeof(offset_));
256     size_t written =
257         TF_StringEncode(src, len, pdata_, (plimit_ - pdata_), status);
258     offset_ += written;
259     poffsets_ += 8;
260     pdata_ += written;
261   }
262 
263  private:
264   uint64_t offset_;
265   char* poffsets_;
266   char* pdata_;
267   const char* plimit_;
268 };
269 
270 class StringTensorReader {
271  public:
StringTensorReader(const TF_Tensor * t,int num_elements)272   StringTensorReader(const TF_Tensor* t, int num_elements)
273       : index_(0),
274         offsets_(static_cast<const char*>(TF_TensorData(t))),
275         data_(offsets_ + 8 * num_elements),
276         limit_(offsets_ + TF_TensorByteSize(t)) {}
277 
Next(JNIEnv * env,TF_Status * status)278   jbyteArray Next(JNIEnv* env, TF_Status* status) {
279     if (TF_GetCode(status) != TF_OK) return nullptr;
280     uint64_t offset = 0;
281     const char* poffset = offsets_ + sizeof(offset) * index_;
282     if (poffset >= limit_) {
283       TF_SetStatus(
284           status, TF_INTERNAL,
285           "Invalid TF_STRING tensor, offsets table seems to be too small");
286       return nullptr;
287     }
288     memcpy(&offset, poffset, sizeof(offset));
289     const char* pdata = data_ + offset;
290     if (pdata >= limit_) {
291       TF_SetStatus(status, TF_INTERNAL,
292                    "Invalid TF_STRING tensor, invalid entry in offset table");
293       return nullptr;
294     }
295     ++index_;
296     return TF_StringDecodeTojbyteArray(env, pdata, (limit_ - pdata), status);
297   }
298 
299  private:
300   int index_;
301   const char* offsets_;
302   const char* data_;
303   const char* limit_;
304 };
305 
readNDStringArray(JNIEnv * env,StringTensorReader * reader,int dims_left,jobjectArray dst,TF_Status * status)306 void readNDStringArray(JNIEnv* env, StringTensorReader* reader, int dims_left,
307                        jobjectArray dst, TF_Status* status) {
308   jsize len = env->GetArrayLength(dst);
309   if (dims_left == 1) {
310     for (jsize i = 0; i < len; ++i) {
311       jbyteArray elem = reader->Next(env, status);
312       if (TF_GetCode(status) != TF_OK) return;
313       env->SetObjectArrayElement(dst, i, elem);
314     }
315     return;
316   }
317   for (jsize i = 0; i < len; ++i) {
318     jobjectArray arr =
319         static_cast<jobjectArray>(env->GetObjectArrayElement(dst, i));
320     readNDStringArray(env, reader, dims_left - 1, arr, status);
321     if (TF_GetCode(status) != TF_OK) return;
322   }
323 }
324 }  // namespace
325 
Java_org_tensorflow_Tensor_allocate(JNIEnv * env,jclass clazz,jint dtype,jlongArray shape,jlong sizeInBytes)326 JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocate(JNIEnv* env,
327                                                             jclass clazz,
328                                                             jint dtype,
329                                                             jlongArray shape,
330                                                             jlong sizeInBytes) {
331   int num_dims = static_cast<int>(env->GetArrayLength(shape));
332   jlong* dims = nullptr;
333   if (num_dims > 0) {
334     jboolean is_copy;
335     dims = env->GetLongArrayElements(shape, &is_copy);
336   }
337   static_assert(sizeof(jlong) == sizeof(int64_t),
338                 "Java long is not compatible with the TensorFlow C API");
339   // On some platforms "jlong" is a "long" while "int64_t" is a "long long".
340   //
341   // Thus, static_cast<int64_t*>(dims) will trigger a compiler error:
342   // static_cast from 'jlong *' (aka 'long *') to 'int64_t *' (aka 'long long
343   // *') is not allowed
344   //
345   // Since this array is typically very small, use the guaranteed safe scheme of
346   // creating a copy.
347   int64_t* dims_copy = new int64_t[num_dims];
348   for (int i = 0; i < num_dims; ++i) {
349     dims_copy[i] = static_cast<int64_t>(dims[i]);
350   }
351   TF_Tensor* t = TF_AllocateTensor(static_cast<TF_DataType>(dtype), dims_copy,
352                                    num_dims, static_cast<size_t>(sizeInBytes));
353   delete[] dims_copy;
354   if (dims != nullptr) {
355     env->ReleaseLongArrayElements(shape, dims, JNI_ABORT);
356   }
357   if (t == nullptr) {
358     throwException(env, kNullPointerException,
359                    "unable to allocate memory for the Tensor");
360     return 0;
361   }
362   return reinterpret_cast<jlong>(t);
363 }
364 
Java_org_tensorflow_Tensor_allocateScalarBytes(JNIEnv * env,jclass clazz,jbyteArray value)365 JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocateScalarBytes(
366     JNIEnv* env, jclass clazz, jbyteArray value) {
367   // TF_STRING tensors are encoded with a table of 8-byte offsets followed by
368   // TF_StringEncode-encoded bytes.
369   size_t src_len = static_cast<int>(env->GetArrayLength(value));
370   size_t dst_len = TF_StringEncodedSize(src_len);
371   TF_Tensor* t = TF_AllocateTensor(TF_STRING, nullptr, 0, 8 + dst_len);
372   char* dst = static_cast<char*>(TF_TensorData(t));
373   memset(dst, 0, 8);  // The offset table
374 
375   TF_Status* status = TF_NewStatus();
376   jbyte* jsrc = env->GetByteArrayElements(value, nullptr);
377   // jsrc is an unsigned byte*, TF_StringEncode requires a char*.
378   // reinterpret_cast<> for this conversion should be safe.
379   TF_StringEncode(reinterpret_cast<const char*>(jsrc), src_len, dst + 8,
380                   dst_len, status);
381   env->ReleaseByteArrayElements(value, jsrc, JNI_ABORT);
382   if (!throwExceptionIfNotOK(env, status)) {
383     TF_DeleteStatus(status);
384     return 0;
385   }
386   TF_DeleteStatus(status);
387   return reinterpret_cast<jlong>(t);
388 }
389 
390 namespace {
nonScalarTF_STRINGTensorSize(JNIEnv * env,jarray value,int num_dims)391 size_t nonScalarTF_STRINGTensorSize(JNIEnv* env, jarray value, int num_dims) {
392   if (num_dims == 0) {
393     // This is the last dimension, i.e., value should correspond to a jbyteArray
394     // encoding the string.
395     return TF_StringEncodedSize(
396         static_cast<size_t>(env->GetArrayLength(value)));
397   }
398   jsize len = env->GetArrayLength(value);
399   size_t ret = 0;
400   for (jsize i = 0; i < len; ++i) {
401     jarray elem = static_cast<jarray>(
402         env->GetObjectArrayElement(static_cast<jobjectArray>(value), i));
403     if (elem == nullptr) {
404       throwException(env, kNullPointerException,
405                      "null entries in provided array");
406       return ret;
407     }
408     ret += nonScalarTF_STRINGTensorSize(env, elem, num_dims - 1);
409     if (env->ExceptionCheck()) return ret;
410   }
411   return ret;
412 }
413 
fillNonScalarTF_STRINGTensorData(JNIEnv * env,jarray value,int num_dims,StringTensorWriter * writer,TF_Status * status)414 void fillNonScalarTF_STRINGTensorData(JNIEnv* env, jarray value, int num_dims,
415                                       StringTensorWriter* writer,
416                                       TF_Status* status) {
417   if (num_dims == 0) {
418     jbyte* jsrc =
419         env->GetByteArrayElements(static_cast<jbyteArray>(value), nullptr);
420     writer->Add(reinterpret_cast<const char*>(jsrc), env->GetArrayLength(value),
421                 status);
422     env->ReleaseByteArrayElements(static_cast<jbyteArray>(value), jsrc,
423                                   JNI_ABORT);
424     return;
425   }
426   jsize len = env->GetArrayLength(value);
427   for (jsize i = 0; i < len; ++i) {
428     jarray elem = static_cast<jarray>(
429         env->GetObjectArrayElement(static_cast<jobjectArray>(value), i));
430     fillNonScalarTF_STRINGTensorData(env, elem, num_dims - 1, writer, status);
431     if (TF_GetCode(status) != TF_OK) return;
432   }
433 }
434 }  // namespace
435 
Java_org_tensorflow_Tensor_allocateNonScalarBytes(JNIEnv * env,jclass clazz,jlongArray shape,jobjectArray value)436 JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocateNonScalarBytes(
437     JNIEnv* env, jclass clazz, jlongArray shape, jobjectArray value) {
438   // TF_STRING tensors are encoded with a table of 8-byte offsets following by
439   // TF_StringEncode-encoded bytes.
440   const int num_dims = static_cast<int>(env->GetArrayLength(shape));
441   int64_t* dims = new int64_t[num_dims];
442   int64_t num_elements = 1;
443   {
444     jlong* jdims = env->GetLongArrayElements(shape, nullptr);
445     for (int i = 0; i < num_dims; ++i) {
446       dims[i] = static_cast<int64_t>(jdims[i]);
447       num_elements *= dims[i];
448     }
449     env->ReleaseLongArrayElements(shape, jdims, JNI_ABORT);
450   }
451   const size_t encoded_size =
452       nonScalarTF_STRINGTensorSize(env, value, num_dims);
453   if (env->ExceptionCheck()) return 0;
454   TF_Tensor* t = TF_AllocateTensor(TF_STRING, dims, num_dims,
455                                    8 * num_elements + encoded_size);
456   if (t == nullptr) {
457     delete[] dims;
458     throwException(env, kNullPointerException,
459                    "unable to allocate memory for the Tensor");
460     return 0;
461   }
462   TF_Status* status = TF_NewStatus();
463   StringTensorWriter writer(t, num_elements);
464   fillNonScalarTF_STRINGTensorData(env, value, num_dims, &writer, status);
465   delete[] dims;
466   jlong ret = 0;
467   if (!throwExceptionIfNotOK(env, status)) {
468     TF_DeleteTensor(t);
469   } else {
470     ret = reinterpret_cast<jlong>(t);
471   }
472   TF_DeleteStatus(status);
473   return ret;
474 }
475 
Java_org_tensorflow_Tensor_delete(JNIEnv * env,jclass clazz,jlong handle)476 JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_delete(JNIEnv* env,
477                                                          jclass clazz,
478                                                          jlong handle) {
479   if (handle == 0) return;
480   TF_DeleteTensor(reinterpret_cast<TF_Tensor*>(handle));
481 }
482 
Java_org_tensorflow_Tensor_buffer(JNIEnv * env,jclass clazz,jlong handle)483 JNIEXPORT jobject JNICALL Java_org_tensorflow_Tensor_buffer(JNIEnv* env,
484                                                             jclass clazz,
485                                                             jlong handle) {
486   TF_Tensor* t = requireHandle(env, handle);
487   if (t == nullptr) return nullptr;
488   void* data = TF_TensorData(t);
489   const size_t sz = TF_TensorByteSize(t);
490 
491   return env->NewDirectByteBuffer(data, static_cast<jlong>(sz));
492 }
493 
Java_org_tensorflow_Tensor_dtype(JNIEnv * env,jclass clazz,jlong handle)494 JNIEXPORT jint JNICALL Java_org_tensorflow_Tensor_dtype(JNIEnv* env,
495                                                         jclass clazz,
496                                                         jlong handle) {
497   static_assert(sizeof(jint) >= sizeof(TF_DataType),
498                 "TF_DataType in C cannot be represented as an int in Java");
499   TF_Tensor* t = requireHandle(env, handle);
500   if (t == nullptr) return 0;
501   return static_cast<jint>(TF_TensorType(t));
502 }
503 
Java_org_tensorflow_Tensor_shape(JNIEnv * env,jclass clazz,jlong handle)504 JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Tensor_shape(JNIEnv* env,
505                                                               jclass clazz,
506                                                               jlong handle) {
507   TF_Tensor* t = requireHandle(env, handle);
508   if (t == nullptr) return nullptr;
509   static_assert(sizeof(jlong) == sizeof(int64_t),
510                 "Java long is not compatible with the TensorFlow C API");
511   const jsize num_dims = TF_NumDims(t);
512   jlongArray ret = env->NewLongArray(num_dims);
513   jlong* dims = env->GetLongArrayElements(ret, nullptr);
514   for (int i = 0; i < num_dims; ++i) {
515     dims[i] = static_cast<jlong>(TF_Dim(t, i));
516   }
517   env->ReleaseLongArrayElements(ret, dims, 0);
518   return ret;
519 }
520 
Java_org_tensorflow_Tensor_setValue(JNIEnv * env,jclass clazz,jlong handle,jobject value)521 JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_setValue(JNIEnv* env,
522                                                            jclass clazz,
523                                                            jlong handle,
524                                                            jobject value) {
525   TF_Tensor* t = requireHandle(env, handle);
526   if (t == nullptr) return;
527   int num_dims = TF_NumDims(t);
528   TF_DataType dtype = TF_TensorType(t);
529   void* data = TF_TensorData(t);
530   const size_t sz = TF_TensorByteSize(t);
531   if (num_dims == 0) {
532     writeScalar(env, value, dtype, data, sz);
533   } else {
534     writeNDArray(env, static_cast<jarray>(value), dtype, num_dims,
535                  static_cast<char*>(data), sz);
536   }
537 }
538 
539 #define DEFINE_GET_SCALAR_METHOD(jtype, dtype, method_suffix)                  \
540   JNIEXPORT jtype JNICALL Java_org_tensorflow_Tensor_scalar##method_suffix(    \
541       JNIEnv* env, jclass clazz, jlong handle) {                               \
542     jtype ret = 0;                                                             \
543     TF_Tensor* t = requireHandle(env, handle);                                 \
544     if (t == nullptr) return ret;                                              \
545     if (TF_NumDims(t) != 0) {                                                  \
546       throwException(env, kIllegalStateException, "Tensor is not a scalar");   \
547     } else if (TF_TensorType(t) != dtype) {                                    \
548       throwException(env, kIllegalStateException, "Tensor is not a %s scalar", \
549                      #method_suffix);                                          \
550     } else {                                                                   \
551       memcpy(&ret, TF_TensorData(t), elemByteSize(dtype));                     \
552     }                                                                          \
553     return ret;                                                                \
554   }
555 DEFINE_GET_SCALAR_METHOD(jfloat, TF_FLOAT, Float);
556 DEFINE_GET_SCALAR_METHOD(jdouble, TF_DOUBLE, Double);
557 DEFINE_GET_SCALAR_METHOD(jint, TF_INT32, Int);
558 DEFINE_GET_SCALAR_METHOD(jlong, TF_INT64, Long);
559 DEFINE_GET_SCALAR_METHOD(jboolean, TF_BOOL, Boolean);
560 #undef DEFINE_GET_SCALAR_METHOD
561 
Java_org_tensorflow_Tensor_scalarBytes(JNIEnv * env,jclass clazz,jlong handle)562 JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Tensor_scalarBytes(
563     JNIEnv* env, jclass clazz, jlong handle) {
564   TF_Tensor* t = requireHandle(env, handle);
565   if (t == nullptr) return nullptr;
566   if (TF_NumDims(t) != 0) {
567     throwException(env, kIllegalStateException, "Tensor is not a scalar");
568     return nullptr;
569   }
570   if (TF_TensorType(t) != TF_STRING) {
571     throwException(env, kIllegalArgumentException,
572                    "Tensor is not a string/bytes scalar");
573     return nullptr;
574   }
575   const char* data = static_cast<const char*>(TF_TensorData(t));
576   const char* src = data + 8;
577   size_t src_len = TF_TensorByteSize(t) - 8;
578   uint64_t offset = 0;
579   memcpy(&offset, data, sizeof(offset));
580   if (offset >= src_len) {
581     throwException(env, kIllegalArgumentException,
582                    "invalid tensor encoding: bad offsets");
583     return nullptr;
584   }
585   TF_Status* status = TF_NewStatus();
586   jbyteArray ret = TF_StringDecodeTojbyteArray(env, src, src_len, status);
587   throwExceptionIfNotOK(env, status);
588   TF_DeleteStatus(status);
589   return ret;
590 }
591 
Java_org_tensorflow_Tensor_readNDArray(JNIEnv * env,jclass clazz,jlong handle,jobject value)592 JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_readNDArray(JNIEnv* env,
593                                                               jclass clazz,
594                                                               jlong handle,
595                                                               jobject value) {
596   TF_Tensor* t = requireHandle(env, handle);
597   if (t == nullptr) return;
598   int num_dims = TF_NumDims(t);
599   TF_DataType dtype = TF_TensorType(t);
600   const void* data = TF_TensorData(t);
601   const size_t sz = TF_TensorByteSize(t);
602   if (num_dims == 0) {
603     throwException(env, kIllegalArgumentException,
604                    "copyTo() is not meant for scalar Tensors, use the scalar "
605                    "accessor (floatValue(), intValue() etc.) instead");
606     return;
607   }
608   if (dtype == TF_STRING) {
609     int64_t num_elements = 1;
610     for (int i = 0; i < num_dims; ++i) {
611       num_elements *= TF_Dim(t, i);
612     }
613     StringTensorReader reader(t, num_elements);
614     TF_Status* status = TF_NewStatus();
615     readNDStringArray(env, &reader, num_dims, static_cast<jobjectArray>(value),
616                       status);
617     throwExceptionIfNotOK(env, status);
618     TF_DeleteStatus(status);
619     return;
620   }
621   readNDArray(env, dtype, static_cast<const char*>(data), sz, num_dims,
622               static_cast<jarray>(value));
623 }
624