1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <algorithm>
17 #include <string>
18 #include <unordered_map>
19 #include <vector>
20 
21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
22 #include "tensorflow/core/framework/kernel_def_builder.h"
23 #include "tensorflow/core/framework/op_kernel.h"
24 #include "tensorflow/core/framework/tensor.h"
25 #include "tensorflow/core/framework/tensor_types.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/lib/core/errors.h"
28 #include "tensorflow/core/lib/gtl/map_util.h"
29 #include "tensorflow/core/platform/logging.h"
30 #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
31 
32 namespace tensorflow {
33 
34 namespace {
35 // Returning a Status instead of using OP_REQUIRES directly since that doesn't
36 // seem to work outside the main OpKernel functions.
RemapVectorToMap(const TTypes<const int64>::Vec & remapping,std::vector<bool> * id_present,std::unordered_map<int64,int64> * old_id_to_new_id)37 Status RemapVectorToMap(const TTypes<const int64>::Vec& remapping,
38                         std::vector<bool>* id_present,
39                         std::unordered_map<int64, int64>* old_id_to_new_id) {
40   id_present->clear();
41   id_present->resize(remapping.size(), false);
42   for (int i = 0; i < remapping.size(); ++i) {
43     const int64 old_id = remapping(i);
44     if (old_id < 0) continue;
45     (*id_present)[i] = true;
46     if (!gtl::InsertIfNotPresent(old_id_to_new_id, old_id, i)) {
47       return errors::Unimplemented(
48           strings::StrCat("Old ID ", old_id, " is mapped to both new ID ",
49                           old_id_to_new_id->at(old_id), " and ", i,
50                           ", which is not supported."));
51     }
52   }
53   return Status::OK();
54 }
55 }  // anonymous namespace
56 
57 // This op loads a rank-2 Tensor (matrix) from a TensorFlow checkpoint (V2) and
58 // swaps around the rows/columns according to row_remapping/col_remapping.
59 // "Missing" cells are initialized with values from initializing_values.
60 class LoadAndRemapMatrixOp : public OpKernel {
61  public:
LoadAndRemapMatrixOp(OpKernelConstruction * context)62   explicit LoadAndRemapMatrixOp(OpKernelConstruction* context)
63       : OpKernel(context) {
64     OP_REQUIRES_OK(context, context->GetAttr("num_rows", &num_rows_));
65     OP_REQUIRES_OK(context, context->GetAttr("num_cols", &num_cols_));
66     OP_REQUIRES_OK(
67         context, context->GetAttr("max_rows_in_memory", &max_rows_in_memory_));
68   }
69 
Compute(OpKernelContext * context)70   void Compute(OpKernelContext* context) override {
71     // Checks what we're remapping and inverts the relevant remapping Tensors to
72     // be maps with key = old ID, value = new ID.
73     std::unordered_map<int64, int64> old_row_to_new_row_map;
74     std::vector<bool> row_id_present;
75     const Tensor* row_remapping_t;
76     OP_REQUIRES_OK(context, context->input("row_remapping", &row_remapping_t));
77     const auto row_remapping = row_remapping_t->vec<int64>();
78     OP_REQUIRES(context, row_remapping.size() == num_rows_,
79                 errors::InvalidArgument(strings::StrCat(
80                     "Size of row_remapping is ", row_remapping.size(),
81                     " instead of being equal to num_rows=", num_rows_)));
82     OP_REQUIRES_OK(context, RemapVectorToMap(row_remapping, &row_id_present,
83                                              &old_row_to_new_row_map));
84 
85     // Calculates the min/max old row ID that we need to read, to save us from
86     // reading some unnecessary slices of the old tensor.
87     int64 min_old_row = -1;
88     int64 max_old_row = -1;
89     for (int i = 0; i < row_remapping.size(); ++i) {
90       if (min_old_row < 0 ||
91           (row_remapping(i) >= 0 && row_remapping(i) < min_old_row)) {
92         min_old_row = row_remapping(i);
93       }
94       if (max_old_row < 0 ||
95           (row_remapping(i) >= 0 && row_remapping(i) > max_old_row)) {
96         max_old_row = row_remapping(i);
97       }
98     }
99 
100     // Processes the remapping for columns.
101     std::unordered_map<int64, int64> old_col_to_new_col_map;
102     std::vector<bool> col_id_present;
103     const Tensor* col_remapping_t;
104     OP_REQUIRES_OK(context, context->input("col_remapping", &col_remapping_t));
105     const auto col_remapping = col_remapping_t->vec<int64>();
106     // Note that we always "remap rows", even when the row vocabulary does
107     // not change, because partitioning requires a mapping from partitioned
108     // Variables to the full checkpoints we load.
109     const bool remap_cols = col_remapping.size() > 0;
110     if (remap_cols) {
111       OP_REQUIRES(
112           context, col_remapping.size() == num_cols_,
113           errors::InvalidArgument(strings::StrCat(
114               "Provided col_remapping, but its size is ", col_remapping.size(),
115               " instead of being equal to num_cols=", num_cols_)));
116       OP_REQUIRES_OK(context, RemapVectorToMap(col_remapping, &col_id_present,
117                                                &old_col_to_new_col_map));
118     } else {
119       col_id_present.clear();
120       col_id_present.resize(num_cols_, true);
121     }
122 
123     // Processes the checkpoint source and the provided Tensor name.
124     const Tensor* ckpt_path_t;
125     OP_REQUIRES_OK(context, context->input("ckpt_path", &ckpt_path_t));
126     const string ckpt_path = *(ckpt_path_t->scalar<string>().data());
127     const Tensor* old_tensor_name_t;
128     OP_REQUIRES_OK(context,
129                    context->input("old_tensor_name", &old_tensor_name_t));
130     const string old_tensor_name =
131         *(old_tensor_name_t->scalar<string>().data());
132 
133     LOG(INFO) << "Processing checkpoint : " << ckpt_path;
134     BundleReader reader(context->env(), ckpt_path);
135     OP_REQUIRES_OK(context, reader.status());
136 
137     DataType tensor_type;
138     TensorShape tensor_shape;
139     OP_REQUIRES_OK(context, reader.LookupDtypeAndShape(
140                                 old_tensor_name, &tensor_type, &tensor_shape));
141     OP_REQUIRES(context, tensor_type == DT_FLOAT,
142                 errors::InvalidArgument(strings::StrCat(
143                     "Tensor ", old_tensor_name, " has invalid type ",
144                     DataTypeString(tensor_type), " instead of expected type ",
145                     DataTypeString(DT_FLOAT))));
146     // This op is limited to loading Tensors of rank 2 (matrices).
147     OP_REQUIRES(
148         context, tensor_shape.dims() == 2,
149         errors::InvalidArgument(strings::StrCat(
150             "Tensor ", old_tensor_name, " has shape ",
151             tensor_shape.DebugString(), " of invalid rank ",
152             tensor_shape.dims(), " instead of expected shape of rank 2.")));
153 
154     if (!remap_cols) {
155       // TODO(weiho): Consider relaxing this restriction to allow partial column
156       // loading (even when no column remapping is specified) if there turns out
157       // to be a use case for it.
158       OP_REQUIRES(context, num_cols_ == tensor_shape.dim_size(1),
159                   errors::InvalidArgument(strings::StrCat(
160                       "Tensor ", old_tensor_name, " has shape ",
161                       tensor_shape.DebugString(),
162                       ", where the size of its 2nd dimension is ",
163                       tensor_shape.dim_size(1),
164                       " instead of being equal to num_cols=", num_cols_)));
165     }
166 
167     // Uses TensorSlice to potentially load the old tensor in chunks in case
168     // memory usage is a concern.
169     std::vector<TensorSlice> tensor_slices;
170     TensorSlice slice(tensor_shape.dims());
171     if (min_old_row >= 0 && max_old_row >= 0) {
172       int64 row_start = min_old_row;
173       // TODO(weiho): Given the list of old row IDs of interest (the keys of
174       // old_row_to_new_row_map), we could also try something smarter to
175       // find some minimal set of covering ranges for the list of old row IDs
176       // such that the size of each range is less than max_rows_in_memory_.
177       while (row_start <= max_old_row) {
178         const int64 slice_length =
179             max_rows_in_memory_ <= 0
180                 // If max_rows_in_memory_ <= 0, we just load the entire chunk.
181                 ? max_old_row - row_start + 1
182                 : std::min(max_rows_in_memory_, max_old_row - row_start + 1);
183         slice.set_start(0, row_start);
184         slice.set_length(0, slice_length);
185         tensor_slices.push_back(slice);
186         row_start += slice_length;
187       }
188     }
189 
190     // Allocates the output matrix.
191     Tensor* output_matrix_t = nullptr;
192     OP_REQUIRES_OK(context,
193                    context->allocate_output("output_matrix",
194                                             TensorShape({num_rows_, num_cols_}),
195                                             &output_matrix_t));
196     auto output_matrix = output_matrix_t->matrix<float>();
197 
198     // Iterates through tensor slices and copies over values from the old tensor
199     // to the output matrix.
200     int64 row_index = min_old_row;
201     int64 rows_copied = 0;
202     Tensor loaded_tensor_t;
203     for (const TensorSlice& tensor_slice : tensor_slices) {
204       LOG(INFO) << "Loading slice " << tensor_slice.DebugString();
205       TensorShape slice_shape;
206       OP_REQUIRES_OK(context,
207                      tensor_slice.SliceTensorShape(tensor_shape, &slice_shape));
208       // Potentially re-allocates the tensor buffer since the last slice may
209       // have fewer rows than the other slices.
210       if (loaded_tensor_t.shape() != slice_shape) {
211         loaded_tensor_t = Tensor(DT_FLOAT, slice_shape);
212       }
213       OP_REQUIRES_OK(context, reader.LookupSlice(old_tensor_name, tensor_slice,
214                                                  &loaded_tensor_t));
215 
216       // Iterates through the old loaded tensor slice row-by-row.
217       for (int row = 0; row < loaded_tensor_t.dim_size(0); ++row, ++row_index) {
218         if (row_index % 500000 == min_old_row) {
219           LOG(INFO) << "Processing old row " << row_index;
220         }
221 
222         // If the old row ID is not found in old_row_to_new_row_map, continue
223         // to the next row; otherwise, copy it to the output matrix.
224         const int64* new_row_ptr =
225             gtl::FindOrNull(old_row_to_new_row_map, row_index);
226         if (new_row_ptr == nullptr) {
227           continue;
228         }
229         ++rows_copied;
230         const int64 new_row = *new_row_ptr;
231 
232         // Copies over the row element-by-element, in case remapping is needed
233         // along the column axis.
234         const auto& loaded_tensor = loaded_tensor_t.matrix<float>();
235         for (int old_col = 0; old_col < loaded_tensor_t.dim_size(1);
236              ++old_col) {
237           int64 new_col = old_col;
238           if (remap_cols) {
239             const int64* new_col_ptr =
240                 gtl::FindOrNull(old_col_to_new_col_map, old_col);
241             if (new_col_ptr == nullptr) {
242               // Column remapping is specified, but this column is not found in
243               // old_col_to_new_col_map, so we leave it uninitialized, to be
244               // filled in with initializing_values later.
245               continue;
246             }
247             new_col = *new_col_ptr;
248           }
249 
250           OP_REQUIRES(context,
251                       new_row < num_rows_ && new_col < num_cols_ &&
252                           new_row >= 0 && new_col >= 0,
253                       errors::Internal(strings::StrCat(
254                           "new_row=", new_row, " and new_col=", new_col,
255                           " should have been less than num_rows_=", num_rows_,
256                           " and num_cols_=", num_cols_,
257                           " and non-negative. This should never have happened "
258                           "if the code were correct. Please file a bug.")));
259           output_matrix(new_row, new_col) = loaded_tensor(row, old_col);
260         }
261       }
262     }
263     LOG(INFO) << "Copied " << rows_copied << " rows from old matrix (with "
264               << tensor_shape.dim_size(0) << " rows) to new matrix (with "
265               << num_rows_ << " rows).";
266 
267     // At this point, there are potentially whole rows/columns uninitialized
268     // (corresponding to the indices where row_id_present/col_id_present are
269     // false). We fill this in cell-by-cell using row_id_present and
270     // col_id_present while dequeuing from the initializing_values vector.
271     const Tensor* initializing_values_t;
272     OP_REQUIRES_OK(
273         context, context->input("initializing_values", &initializing_values_t));
274     const auto initializing_values = initializing_values_t->flat<float>();
275     int64 initializing_values_index = 0;
276     for (int i = 0; i < num_rows_; ++i) {
277       for (int j = 0; j < num_cols_; ++j) {
278         if (row_id_present[i] && col_id_present[j]) continue;
279         OP_REQUIRES(
280             context, initializing_values_index < initializing_values.size(),
281             errors::InvalidArgument(
282                 "initializing_values contained ", initializing_values.size(),
283                 " elements, but more missing values remain."));
284         output_matrix(i, j) = initializing_values(initializing_values_index);
285         ++initializing_values_index;
286       }
287     }
288 
289     // Checks that we used all the given initializing values.
290     OP_REQUIRES(
291         context, initializing_values_index == initializing_values.size(),
292         errors::InvalidArgument(
293             "initializing_values contained ", initializing_values.size(),
294             " elements, but only ", initializing_values_index,
295             " elements were used to fill in missing values."));
296   }
297 
298  private:
299   int64 num_rows_;
300   int64 num_cols_;
301   int64 max_rows_in_memory_;
302 };
303 
// Registers LoadAndRemapMatrixOp as the kernel for the "LoadAndRemapMatrix" op
// on CPU devices (the only device registration visible here).
REGISTER_KERNEL_BUILDER(Name("LoadAndRemapMatrix").Device(DEVICE_CPU),
                        LoadAndRemapMatrixOp);
306 
307 }  // namespace tensorflow
308