1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Utilities for saving/restoring tensor slice checkpoints. 17 18 #ifndef TENSORFLOW_CORE_UTIL_SAVED_TENSOR_SLICE_UTIL_H_ 19 #define TENSORFLOW_CORE_UTIL_SAVED_TENSOR_SLICE_UTIL_H_ 20 21 #include <string> // for string 22 #include "tensorflow/core/framework/tensor.pb.h" 23 #include "tensorflow/core/framework/tensor_slice.h" 24 #include "tensorflow/core/framework/types.h" 25 #include "tensorflow/core/lib/core/status.h" // for Status 26 #include "tensorflow/core/platform/protobuf.h" 27 28 namespace tensorflow { 29 30 namespace checkpoint { 31 32 // The key for the metadata in the tensor slice checkpoint files. It is "" so 33 // that the metadata is always at the beginning of a checkpoint file. 34 extern const char kSavedTensorSlicesKey[]; 35 36 // Encode a tensor name + a tensor slice into an ordered code and outputs it as 37 // a string. 38 // The format is 39 // <0> 40 // <tensor_name> 41 // <rank> 42 // <dim-0-start><dim-0-length> 43 // <dim-1-start><dim-1-length> 44 // ... 45 46 string EncodeTensorNameSlice(const string& name, 47 const tensorflow::TensorSlice& slice); 48 49 // Parse out the name and the slice from string encoded as an ordered code. 50 Status DecodeTensorNameSlice(const string& code, string* name, 51 tensorflow::TensorSlice* slice); 52 53 // Extracts the full shape, slice spec, and shape of the slice from 54 // "shape_and_slice". On non-OK return, caller must clear the out-arguments 55 // before reusing. 56 Status ParseShapeAndSlice(const string& shape_and_slice, TensorShape* shape, 57 TensorSlice* slice, TensorShape* shape_slice); 58 59 template <typename T> 60 struct SaveTypeTraits; 61 62 template <typename T> 63 const typename SaveTypeTraits<T>::SavedType* TensorProtoData( 64 const TensorProto& t); 65 66 template <typename T> 67 typename SaveTypeTraits<T>::RepeatedField* MutableTensorProtoData( 68 TensorProto* t); 69 70 template <typename T> 71 void Fill(T* data, size_t n, TensorProto* t); 72 73 #define TENSOR_PROTO_EXTRACT_TYPE_HELPER(TYPE, FIELD, FTYPE, STYPE) \ 74 template <> \ 75 struct SaveTypeTraits<TYPE> { \ 76 static constexpr bool supported = true; \ 77 typedef STYPE SavedType; \ 78 typedef protobuf::RepeatedField<FTYPE> RepeatedField; \ 79 }; \ 80 template <> \ 81 inline const STYPE* TensorProtoData<TYPE>(const TensorProto& t) { \ 82 static_assert(SaveTypeTraits<TYPE>::supported, \ 83 "Specified type " #TYPE " not supported for Restore"); \ 84 return reinterpret_cast<const STYPE*>(t.FIELD##_val().data()); \ 85 } \ 86 template <> \ 87 inline protobuf::RepeatedField<FTYPE>* MutableTensorProtoData<TYPE>( \ 88 TensorProto * t) { \ 89 static_assert(SaveTypeTraits<TYPE>::supported, \ 90 "Specified type " #TYPE " not supported for Save"); \ 91 return reinterpret_cast<protobuf::RepeatedField<FTYPE>*>( \ 92 t->mutable_##FIELD##_val()); \ 93 } 94 95 #define TENSOR_PROTO_EXTRACT_TYPE(TYPE, FIELD, FTYPE) \ 96 TENSOR_PROTO_EXTRACT_TYPE_HELPER(TYPE, FIELD, FTYPE, FTYPE) \ 97 template <> \ 98 inline void Fill(const TYPE* data, size_t n, TensorProto* t) { \ 99 typename protobuf::RepeatedField<FTYPE> copy(data, data + n); \ 100 t->mutable_##FIELD##_val()->Swap(©); \ 101 } 102 103 // Complex needs special treatment since proto doesn't have native complex 104 #define TENSOR_PROTO_EXTRACT_TYPE_COMPLEX(TYPE, FIELD, FTYPE) \ 105 TENSOR_PROTO_EXTRACT_TYPE_HELPER(TYPE, FIELD, FTYPE, TYPE) \ 106 template <> \ 107 inline void Fill(const TYPE* data, size_t n, TensorProto* t) { \ 108 const FTYPE* sub = reinterpret_cast<const FTYPE*>(data); \ 109 typename protobuf::RepeatedField<FTYPE> copy(sub, sub + 2 * n); \ 110 t->mutable_##FIELD##_val()->Swap(©); \ 111 } 112 113 TENSOR_PROTO_EXTRACT_TYPE(bool, bool, bool); 114 TENSOR_PROTO_EXTRACT_TYPE(float, float, float); 115 TENSOR_PROTO_EXTRACT_TYPE(double, double, double); 116 TENSOR_PROTO_EXTRACT_TYPE_COMPLEX(complex64, scomplex, float); 117 TENSOR_PROTO_EXTRACT_TYPE_COMPLEX(complex128, dcomplex, double); 118 TENSOR_PROTO_EXTRACT_TYPE(int32, int, int32); 119 TENSOR_PROTO_EXTRACT_TYPE(int64, int64, protobuf_int64); 120 TENSOR_PROTO_EXTRACT_TYPE(uint16, int, int32); 121 TENSOR_PROTO_EXTRACT_TYPE(uint8, int, int32); 122 TENSOR_PROTO_EXTRACT_TYPE(int8, int, int32); 123 TENSOR_PROTO_EXTRACT_TYPE(int16, int, int32); 124 TENSOR_PROTO_EXTRACT_TYPE(qint8, int, int32); 125 TENSOR_PROTO_EXTRACT_TYPE(quint8, int, int32); 126 TENSOR_PROTO_EXTRACT_TYPE(quint16, int, int32); 127 128 #undef TENSOR_PROTO_EXTRACT_TYPE_COMPLEX 129 #undef TENSOR_PROTO_EXTRACT_TYPE_HELPER 130 #undef TENSOR_PROTO_EXTRACT_TYPE 131 132 // Custom implementation for qint32, based on the one for int32. 133 134 template <> 135 struct SaveTypeTraits<qint32> : SaveTypeTraits<int32> {}; 136 137 template <> 138 inline const int32* TensorProtoData<qint32>(const TensorProto& t) { 139 static_assert(SaveTypeTraits<qint32>::supported, 140 "Specified type qint32 not supported for Restore"); 141 return reinterpret_cast<const int32*>(t.int_val().data()); 142 } 143 144 inline void Fill(const qint32* data, size_t n, TensorProto* t) { 145 const int32* p = reinterpret_cast<const int32*>(data); 146 typename protobuf::RepeatedField<int32> copy(p, p + n); 147 t->mutable_int_val()->Swap(©); 148 } 149 150 // Custom implementation for Eigen::half. 151 152 template <> 153 struct SaveTypeTraits<Eigen::half> { 154 static constexpr bool supported = true; 155 typedef int SavedType; 156 typedef protobuf::RepeatedField<int32> RepeatedField; 157 }; 158 159 template <> 160 inline const int* TensorProtoData<Eigen::half>(const TensorProto& t) { 161 return t.half_val().data(); 162 } 163 164 template <> 165 inline protobuf::RepeatedField<int32>* MutableTensorProtoData<Eigen::half>( 166 TensorProto* t) { 167 return t->mutable_half_val(); 168 } 169 170 template <> 171 inline void Fill(const Eigen::half* data, size_t n, TensorProto* t) { 172 typename protobuf::RepeatedField<int32>* val = t->mutable_half_val(); 173 val->Resize(n, 0); 174 for (size_t i = 0; i < n; ++i) { 175 val->Set(i, data[i].x); 176 } 177 } 178 179 // Custom implementation for string. 180 181 template <> 182 struct SaveTypeTraits<string> { 183 static constexpr bool supported = true; 184 typedef const string* SavedType; 185 typedef protobuf::RepeatedPtrField<string> RepeatedField; 186 }; 187 188 template <> 189 inline const string* const* TensorProtoData<string>(const TensorProto& t) { 190 static_assert(SaveTypeTraits<string>::supported, 191 "Specified type string not supported for Restore"); 192 return t.string_val().data(); 193 } 194 195 template <> 196 inline protobuf::RepeatedPtrField<string>* MutableTensorProtoData<string>( 197 TensorProto* t) { 198 static_assert(SaveTypeTraits<string>::supported, 199 "Specified type string not supported for Save"); 200 return t->mutable_string_val(); 201 } 202 203 template <> 204 inline void Fill(const string* data, size_t n, TensorProto* t) { 205 typename protobuf::RepeatedPtrField<string> copy(data, data + n); 206 t->mutable_string_val()->Swap(©); 207 } 208 209 } // namespace checkpoint 210 211 } // namespace tensorflow 212 213 #endif // TENSORFLOW_CORE_UTIL_SAVED_TENSOR_SLICE_UTIL_H_ 214