1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7 http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/util/tensor_bundle/byte_swap.h"

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <limits>

#include "tensorflow/core/lib/core/status.h"
19 
20 namespace tensorflow {
21 
ByteSwapArray(char * array,size_t bytes_per_elem,int array_len)22 Status ByteSwapArray(char* array, size_t bytes_per_elem, int array_len) {
23   if (bytes_per_elem == 1) {
24     // No-op
25     return Status::OK();
26   } else if (bytes_per_elem == 2) {
27     auto array_16 = reinterpret_cast<uint16_t*>(array);
28     for (int i = 0; i < array_len; i++) {
29       array_16[i] = BYTE_SWAP_16(array_16[i]);
30     }
31     return Status::OK();
32   } else if (bytes_per_elem == 4) {
33     auto array_32 = reinterpret_cast<uint32_t*>(array);
34     for (int i = 0; i < array_len; i++) {
35       array_32[i] = BYTE_SWAP_32(array_32[i]);
36     }
37     return Status::OK();
38   } else if (bytes_per_elem == 8) {
39     auto array_64 = reinterpret_cast<uint64_t*>(array);
40     for (int i = 0; i < array_len; i++) {
41       array_64[i] = BYTE_SWAP_64(array_64[i]);
42     }
43     return Status::OK();
44   } else {
45     return errors::Unimplemented("Byte-swapping of ", bytes_per_elem,
46                                  "-byte values not supported.");
47   }
48 }
49 
ByteSwapTensor(Tensor * t)50 Status ByteSwapTensor(Tensor* t) {
51   size_t bytes_per_elem = 0;
52   int array_len = t->NumElements();
53 
54   switch (t->dtype()) {
55     // Types that don't need byte-swapping
56     case DT_STRING:
57     case DT_QINT8:
58     case DT_QUINT8:
59     case DT_BOOL:
60     case DT_UINT8:
61     case DT_INT8:
62       return Status::OK();
63 
64     // 16-bit types
65     case DT_BFLOAT16:
66     case DT_HALF:
67     case DT_QINT16:
68     case DT_QUINT16:
69     case DT_UINT16:
70     case DT_INT16:
71       bytes_per_elem = 2;
72       break;
73 
74     // 32-bit types
75     case DT_FLOAT:
76     case DT_INT32:
77     case DT_QINT32:
78     case DT_UINT32:
79       bytes_per_elem = 4;
80       break;
81 
82     // 64-bit types
83     case DT_INT64:
84     case DT_DOUBLE:
85     case DT_UINT64:
86       bytes_per_elem = 8;
87       break;
88 
89     // Complex types need special handling
90     case DT_COMPLEX64:
91       bytes_per_elem = 4;
92       array_len *= 2;
93       break;
94 
95     case DT_COMPLEX128:
96       bytes_per_elem = 8;
97       array_len *= 2;
98       break;
99 
100     // Types that ought to be supported in the future
101     case DT_RESOURCE:
102     case DT_VARIANT:
103       return errors::Unimplemented(
104           "Byte-swapping not yet implemented for tensors with dtype ",
105           t->dtype());
106 
107     // Byte-swapping shouldn't make sense for other dtypes.
108     default:
109       return errors::Unimplemented(
110           "Byte-swapping not supported for tensors with dtype ", t->dtype());
111   }
112 
113   char* backing_buffer = const_cast<char*>((t->tensor_data().data()));
114   TF_RETURN_IF_ERROR(ByteSwapArray(backing_buffer, bytes_per_elem, array_len));
115   return Status::OK();
116 }
117 
118 }  // namespace tensorflow
119