1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/parsing_ops.cc.
17 
18 #include <numeric>
19 #include <unordered_set>
20 #include <vector>
21 
22 #include "absl/base/call_once.h"
23 #include "tensorflow/core/common_runtime/metrics.h"
24 #include "tensorflow/core/example/example.pb.h"
25 #include "tensorflow/core/example/feature.pb.h"
26 #include "tensorflow/core/framework/common_shape_fns.h"
27 #include "tensorflow/core/framework/numeric_op.h"
28 #include "tensorflow/core/framework/register_types.h"
29 #include "tensorflow/core/lib/core/errors.h"
30 #include "tensorflow/core/lib/gtl/array_slice.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/core/platform/protobuf.h"
33 #include "tensorflow/core/util/example_proto_fast_parsing.h"
34 #include "tensorflow/core/util/example_proto_helper.h"
35 #include "tensorflow/core/util/sparse/sparse_tensor.h"
36 #include "tensorflow/core/util/work_sharder.h"
37 
38 namespace tensorflow {
39 
namespace {
// Op names of the V2 variants.  The kernel constructors in this file compare
// the node's op name (`ctx->def().op()`) against these to pick an op version.
constexpr char kParseExampleV2[] = "ParseExampleV2";
constexpr char kParseSequenceExampleV2[] = "ParseSequenceExampleV2";
}  // namespace
44 
// Note: this kernel is used by both the ParseExample op and the ParseExampleV2
// op.  It determines which op was used by comparing the node's op name against
// kParseExampleV2.
48 class ParseExampleOp : public OpKernel {
49  public:
ParseExampleOp(OpKernelConstruction * ctx)50   explicit ParseExampleOp(OpKernelConstruction* ctx)
51       : OpKernel(ctx), op_version_(ctx->def().op() == kParseExampleV2 ? 2 : 1) {
52     OP_REQUIRES_OK(ctx, attrs_.Init(ctx, op_version_));
53   }
54 
Compute(OpKernelContext * ctx)55   void Compute(OpKernelContext* ctx) override {
56     const Tensor* names;
57     const Tensor* serialized;
58     std::vector<StringPiece> dense_keys_t;
59     std::vector<StringPiece> sparse_keys_t;
60     std::vector<StringPiece> ragged_keys_t;
61     OpInputList dense_defaults;
62 
63     // Grab the inputs.
64     OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized));
65     OP_REQUIRES_OK(ctx, ctx->input("names", &names));
66     if (op_version_ == 2) {
67       OP_REQUIRES_OK(ctx, GetTensorKeys(ctx, "dense_keys", &dense_keys_t));
68       OP_REQUIRES_OK(ctx, GetTensorKeys(ctx, "sparse_keys", &sparse_keys_t));
69       OP_REQUIRES_OK(ctx, GetTensorKeys(ctx, "ragged_keys", &ragged_keys_t));
70     } else {
71       OP_REQUIRES_OK(ctx, GetInputListKeys(ctx, "dense_keys", &dense_keys_t));
72       OP_REQUIRES_OK(ctx, GetInputListKeys(ctx, "sparse_keys", &sparse_keys_t));
73     }
74     absl::call_once(flag_, [&dense_keys_t, &sparse_keys_t, &ragged_keys_t]() {
75       metrics::RecordParseDenseFeature(dense_keys_t.size());
76       metrics::RecordParseSparseFeature(sparse_keys_t.size());
77       metrics::RecordParseRaggedFeature(ragged_keys_t.size());
78     });
79     OP_REQUIRES_OK(ctx, ctx->input_list("dense_defaults", &dense_defaults));
80 
81     // Validate input tensor shapes.
82     OP_REQUIRES_OK(
83         ctx, CheckInputShapes(serialized, names, dense_defaults, dense_keys_t,
84                               sparse_keys_t, ragged_keys_t));
85 
86     example::FastParseExampleConfig config =
87         MakeConfig(dense_keys_t, sparse_keys_t, ragged_keys_t, dense_defaults);
88 
89     example::Result result;
90     if (TensorShapeUtils::IsVector(serialized->shape())) {
91       OP_REQUIRES_OK(
92           ctx, ParseExampleVector(config, serialized, names, ctx, &result));
93     } else {
94       OP_REQUIRES_OK(ctx, ParseExampleScalar(config, serialized, ctx, &result));
95     }
96     OP_REQUIRES_OK(ctx, WriteOutput(result, ctx));
97   }
98 
99  protected:
100   // Copies keys from tensor to std::vector<string>.
GetTensorKeys(OpKernelContext * ctx,StringPiece input_name,std::vector<StringPiece> * keys) const101   Status GetTensorKeys(OpKernelContext* ctx, StringPiece input_name,
102                        std::vector<StringPiece>* keys) const {
103     const Tensor* key_t;
104     TF_RETURN_IF_ERROR(ctx->input(input_name, &key_t));
105     keys->reserve(key_t->NumElements());
106     auto keys_flat = key_t->flat<tstring>();
107     for (int i = 0; i < keys_flat.size(); ++i) {
108       keys->push_back(keys_flat(i));
109     }
110     return Status::OK();
111   }
112 
113   // Copies keys from OpInputList of scalar to std::vector<string>.
GetInputListKeys(OpKernelContext * ctx,StringPiece input_name,std::vector<StringPiece> * keys) const114   Status GetInputListKeys(OpKernelContext* ctx, StringPiece input_name,
115                           std::vector<StringPiece>* keys) const {
116     OpInputList key_list;
117     TF_RETURN_IF_ERROR(ctx->input_list(input_name, &key_list));
118     keys->reserve(key_list.size());
119     for (const auto& key : key_list) {
120       keys->push_back(key.scalar<tstring>()());
121     }
122     return Status::OK();
123   }
124 
125   // Validates the shapes of input tensors.
CheckInputShapes(const Tensor * serialized,const Tensor * names,const OpInputList & dense_defaults,const std::vector<StringPiece> & dense_keys_t,const std::vector<StringPiece> & sparse_keys_t,const std::vector<StringPiece> & ragged_keys_t) const126   Status CheckInputShapes(const Tensor* serialized, const Tensor* names,
127                           const OpInputList& dense_defaults,
128                           const std::vector<StringPiece>& dense_keys_t,
129                           const std::vector<StringPiece>& sparse_keys_t,
130                           const std::vector<StringPiece>& ragged_keys_t) const {
131     if (op_version_ == 2) {
132       if (TensorShapeUtils::IsMatrixOrHigher(serialized->shape())) {
133         return errors::InvalidArgument(
134             "Expected serialized to be a scalar or vector, got shape: ",
135             serialized->shape().DebugString());
136       }
137     } else {
138       if (!TensorShapeUtils::IsVector(serialized->shape())) {
139         return errors::InvalidArgument(
140             "Expected serialized to be a vector, got shape: ",
141             serialized->shape().DebugString());
142       }
143     }
144     if (names->NumElements() > 0 && names->shape() != serialized->shape()) {
145       return errors::InvalidArgument(
146           "Expected names have the same shape as serialized: name.shape=",
147           names->shape().DebugString(),
148           ", serialized.shape=", serialized->shape().DebugString());
149     }
150     if (op_version_ == 2) {
151       if (dense_keys_t.size() != attrs_.num_dense) {
152         return errors::InvalidArgument(
153             "Expected len(dense_keys) == len(dense_types) but got: ",
154             dense_keys_t.size(), " vs. ", attrs_.num_dense);
155       }
156       if (sparse_keys_t.size() != attrs_.num_sparse) {
157         return errors::InvalidArgument(
158             "Expected len(sparse_keys) == num_sparse but got: ",
159             sparse_keys_t.size(), " vs. ", attrs_.num_sparse);
160       }
161       if (ragged_keys_t.size() != attrs_.num_ragged) {
162         return errors::InvalidArgument(
163             "Expected len(ragged_keys) == len(ragged_value_types) but got: ",
164             ragged_keys_t.size(), " vs. ", attrs_.num_ragged);
165       }
166     }
167 
168     if (dense_defaults.size() != attrs_.num_dense) {
169       return errors::InvalidArgument(
170           "Expected len(dense_defaults) == len(dense_keys) but got: ",
171           dense_defaults.size(), " vs. ", attrs_.num_dense);
172     }
173 
174     for (int d = 0; d < static_cast<int>(attrs_.num_dense); ++d) {
175       const Tensor& def_value = dense_defaults[d];
176       if (attrs_.variable_length[d]) {
177         if (def_value.NumElements() != 1) {
178           return errors::InvalidArgument(
179               "dense_shape[", d, "] is a variable length shape: ",
180               attrs_.dense_shapes[d].DebugString(),
181               ", therefore "
182               "def_value[",
183               d,
184               "] must contain a single element ("
185               "the padding element).  But its shape is: ",
186               def_value.shape().DebugString());
187         }
188       } else if (def_value.NumElements() > 0) {
189         if (!attrs_.dense_shapes[d].IsCompatibleWith(def_value.shape())) {
190           return errors::InvalidArgument(
191               "def_value[", d, "].shape() == ", def_value.shape().DebugString(),
192               " is not compatible with dense_shapes_[", d,
193               "] == ", attrs_.dense_shapes[d].DebugString());
194         }
195       }
196       if (def_value.dtype() != attrs_.dense_types[d]) {
197         return errors::InvalidArgument(
198             "dense_defaults[", d,
199             "].dtype() == ", DataTypeString(def_value.dtype()),
200             " != dense_types_[", d,
201             "] == ", DataTypeString(attrs_.dense_types[d]));
202       }
203     }
204     return Status::OK();
205   }
206 
207   // Populates the FastParseExampleConfig from keys & defaults.
MakeConfig(const std::vector<StringPiece> & dense_keys_t,const std::vector<StringPiece> & sparse_keys_t,const std::vector<StringPiece> & ragged_keys_t,const OpInputList & dense_defaults) const208   example::FastParseExampleConfig MakeConfig(
209       const std::vector<StringPiece>& dense_keys_t,
210       const std::vector<StringPiece>& sparse_keys_t,
211       const std::vector<StringPiece>& ragged_keys_t,
212       const OpInputList& dense_defaults) const {
213     example::FastParseExampleConfig config;
214     config.dense.reserve(attrs_.num_dense);
215     for (int d = 0; d < attrs_.num_dense; ++d) {
216       config.dense.emplace_back(dense_keys_t[d], attrs_.dense_types[d],
217                                 attrs_.dense_shapes[d], dense_defaults[d],
218                                 attrs_.variable_length[d],
219                                 attrs_.elements_per_stride[d]);
220     }
221     config.sparse.reserve(attrs_.num_sparse);
222     for (int d = 0; d < attrs_.num_sparse; ++d) {
223       config.sparse.emplace_back(sparse_keys_t[d], attrs_.sparse_types[d]);
224     }
225     config.ragged.reserve(attrs_.num_ragged);
226     for (int d = 0; d < attrs_.num_ragged; ++d) {
227       config.ragged.emplace_back(ragged_keys_t[d], attrs_.ragged_value_types[d],
228                                  attrs_.ragged_split_types[d]);
229     }
230     return config;
231   }
232 
233   // Parses a single example.
ParseExampleScalar(const example::FastParseExampleConfig & config,const Tensor * serialized,OpKernelContext * ctx,example::Result * result) const234   Status ParseExampleScalar(const example::FastParseExampleConfig& config,
235                             const Tensor* serialized, OpKernelContext* ctx,
236                             example::Result* result) const {
237     const tstring& serialized_proto = serialized->scalar<tstring>()();
238     return FastParseSingleExample(config, serialized_proto, result);
239   }
240 
241   // Parses a vector of examples.
ParseExampleVector(const example::FastParseExampleConfig & config,const Tensor * serialized,const Tensor * names,OpKernelContext * ctx,example::Result * result) const242   Status ParseExampleVector(const example::FastParseExampleConfig& config,
243                             const Tensor* serialized, const Tensor* names,
244                             OpKernelContext* ctx,
245                             example::Result* result) const {
246     auto serialized_t = serialized->flat<tstring>();
247     auto names_t = names->flat<tstring>();
248     gtl::ArraySlice<tstring> slice(serialized_t.data(), serialized_t.size());
249     gtl::ArraySlice<tstring> names_slice(names_t.data(), names_t.size());
250     return FastParseExample(
251         config, slice, names_slice,
252         ctx->device()->tensorflow_cpu_worker_threads()->workers, result);
253   }
254 
WriteOutput(const example::Result & result,OpKernelContext * ctx) const255   Status WriteOutput(const example::Result& result,
256                      OpKernelContext* ctx) const {
257     OpOutputList dense_values;
258     OpOutputList sparse_indices;
259     OpOutputList sparse_values;
260     OpOutputList sparse_shapes;
261     TF_RETURN_IF_ERROR(ctx->output_list("dense_values", &dense_values));
262     TF_RETURN_IF_ERROR(ctx->output_list("sparse_indices", &sparse_indices));
263     TF_RETURN_IF_ERROR(ctx->output_list("sparse_values", &sparse_values));
264     TF_RETURN_IF_ERROR(ctx->output_list("sparse_shapes", &sparse_shapes));
265     for (int d = 0; d < attrs_.num_dense; ++d) {
266       dense_values.set(d, result.dense_values[d]);
267     }
268     for (int d = 0; d < attrs_.num_sparse; ++d) {
269       sparse_indices.set(d, result.sparse_indices[d]);
270       sparse_values.set(d, result.sparse_values[d]);
271       sparse_shapes.set(d, result.sparse_shapes[d]);
272     }
273     if (op_version_ == 2) {
274       OpOutputList ragged_values;
275       OpOutputList ragged_splits;
276       TF_RETURN_IF_ERROR(ctx->output_list("ragged_values", &ragged_values));
277       TF_RETURN_IF_ERROR(ctx->output_list("ragged_row_splits", &ragged_splits));
278       for (int d = 0; d < attrs_.num_ragged; ++d) {
279         ragged_values.set(d, result.ragged_values[d]);
280         ragged_splits.set(d, result.ragged_splits[d]);
281       }
282     }
283     return Status::OK();
284   }
285 
286   ParseExampleAttrs attrs_;
287   int op_version_;
288   absl::once_flag flag_;
289 };
290 
// Both the ParseExample and ParseExampleV2 ops are served on CPU by the same
// kernel class.
REGISTER_KERNEL_BUILDER(Name("ParseExample").Device(DEVICE_CPU),
                        ParseExampleOp);
REGISTER_KERNEL_BUILDER(Name("ParseExampleV2").Device(DEVICE_CPU),
                        ParseExampleOp);
295 
296 class ParseSingleExampleOp : public OpKernel {
297  public:
ParseSingleExampleOp(OpKernelConstruction * ctx)298   explicit ParseSingleExampleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
299     OP_REQUIRES_OK(ctx, attrs_.Init(ctx));
300     metrics::RecordParseDenseFeature(attrs_.dense_keys.size());
301     metrics::RecordParseSparseFeature(attrs_.sparse_keys.size());
302   }
303 
Compute(OpKernelContext * ctx)304   void Compute(OpKernelContext* ctx) override {
305     const Tensor* serialized;
306     OpInputList dense_defaults;
307 
308     // Grab the input list arguments.
309     OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized));
310     OP_REQUIRES_OK(ctx, ctx->input_list("dense_defaults", &dense_defaults));
311 
312     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(serialized->shape()),
313                 errors::InvalidArgument(
314                     "Expected serialized to be a scalar, got shape: ",
315                     serialized->shape().DebugString()));
316     OP_REQUIRES(ctx, dense_defaults.size() == attrs_.dense_keys.size(),
317                 errors::InvalidArgument(
318                     "Expected len(dense_defaults) == len(dense_keys) but got: ",
319                     dense_defaults.size(), " vs. ", attrs_.dense_keys.size()));
320 
321     for (size_t d = 0; d < attrs_.dense_keys.size(); ++d) {
322       const Tensor& def_value = dense_defaults[d];
323       if (attrs_.variable_length[d]) {
324         OP_REQUIRES(ctx, def_value.NumElements() == 1,
325                     errors::InvalidArgument(
326                         "dense_shape[", d, "] is a variable length shape: ",
327                         attrs_.dense_shapes[d].DebugString(),
328                         ", therefore "
329                         "def_value[",
330                         d,
331                         "] must contain a single element ("
332                         "the padding element).  But its shape is: ",
333                         def_value.shape().DebugString()));
334       } else if (def_value.NumElements() > 0) {
335         OP_REQUIRES(ctx,
336                     attrs_.dense_shapes[d].IsCompatibleWith(def_value.shape()),
337                     errors::InvalidArgument(
338                         "def_value[", d,
339                         "].shape() == ", def_value.shape().DebugString(),
340                         " is not compatible with dense_shapes_[", d,
341                         "] == ", attrs_.dense_shapes[d].DebugString()));
342       }
343       OP_REQUIRES(ctx, def_value.dtype() == attrs_.dense_types[d],
344                   errors::InvalidArgument(
345                       "dense_defaults[", d, "].dtype() == ",
346                       DataTypeString(def_value.dtype()), " != dense_types_[", d,
347                       "] == ", DataTypeString(attrs_.dense_types[d])));
348     }
349 
350     example::Result result;
351 
352     // TODO(mrry): Build the configuration once and cache it.
353     example::FastParseExampleConfig config;
354     for (int d = 0; d < attrs_.dense_keys.size(); ++d) {
355       config.dense.push_back({attrs_.dense_keys[d], attrs_.dense_types[d],
356                               attrs_.dense_shapes[d], dense_defaults[d],
357                               attrs_.variable_length[d],
358                               attrs_.elements_per_stride[d]});
359     }
360     for (int d = 0; d < attrs_.sparse_keys.size(); ++d) {
361       config.sparse.push_back({attrs_.sparse_keys[d], attrs_.sparse_types[d]});
362     }
363 
364     const tstring& serialized_proto = serialized->scalar<tstring>()();
365 
366     OP_REQUIRES_OK(ctx,
367                    FastParseSingleExample(config, serialized_proto, &result));
368 
369     OpOutputList dense_values;
370     OpOutputList sparse_indices;
371     OpOutputList sparse_values;
372     OpOutputList sparse_shapes;
373     OP_REQUIRES_OK(ctx, ctx->output_list("dense_values", &dense_values));
374     OP_REQUIRES_OK(ctx, ctx->output_list("sparse_indices", &sparse_indices));
375     OP_REQUIRES_OK(ctx, ctx->output_list("sparse_values", &sparse_values));
376     OP_REQUIRES_OK(ctx, ctx->output_list("sparse_shapes", &sparse_shapes));
377     for (int d = 0; d < attrs_.dense_keys.size(); ++d) {
378       dense_values.set(d, result.dense_values[d]);
379     }
380     for (int d = 0; d < attrs_.sparse_keys.size(); ++d) {
381       sparse_indices.set(d, result.sparse_indices[d]);
382       sparse_values.set(d, result.sparse_values[d]);
383       sparse_shapes.set(d, result.sparse_shapes[d]);
384     }
385   }
386 
387  protected:
388   ParseSingleExampleAttrs attrs_;
389 };
390 
// Registers the CPU kernel for the ParseSingleExample op.
REGISTER_KERNEL_BUILDER(Name("ParseSingleExample").Device(DEVICE_CPU),
                        ParseSingleExampleOp);
393 
394 class ParseSequenceExampleOp : public OpKernel {
395  public:
ParseSequenceExampleOp(OpKernelConstruction * ctx)396   explicit ParseSequenceExampleOp(OpKernelConstruction* ctx)
397       : OpKernel(ctx),
398         op_version_(ctx->def().op() == kParseSequenceExampleV2 ? 2 : 1) {
399     OP_REQUIRES_OK(ctx, attrs_.Init(ctx, op_version_));
400     metrics::RecordParseDenseFeature(attrs_.context_dense_keys.size() +
401                                      attrs_.feature_list_dense_keys.size());
402     metrics::RecordParseSparseFeature(attrs_.context_sparse_keys.size() +
403                                       attrs_.feature_list_sparse_keys.size());
404   }
405 
Compute(OpKernelContext * ctx)406   void Compute(OpKernelContext* ctx) override {
407     const Tensor* debug_name;
408     const Tensor* serialized;
409     OpInputList context_dense_defaults;
410 
411     OP_REQUIRES_OK(ctx, ctx->input("debug_name", &debug_name));
412     OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized));
413     OP_REQUIRES_OK(ctx, ctx->input_list("context_dense_defaults",
414                                         &context_dense_defaults));
415     const Tensor* context_dense_keys = nullptr;
416     const Tensor* context_sparse_keys = nullptr;
417     const Tensor* context_ragged_keys = nullptr;
418     const Tensor* feature_list_dense_keys = nullptr;
419     const Tensor* feature_list_sparse_keys = nullptr;
420     const Tensor* feature_list_ragged_keys = nullptr;
421     const Tensor* feature_list_dense_missing_assumed_empty = nullptr;
422     if (op_version_ == 2) {
423       OP_REQUIRES_OK(ctx,
424                      ctx->input("feature_list_dense_missing_assumed_empty",
425                                 &feature_list_dense_missing_assumed_empty));
426       OP_REQUIRES_OK(ctx,
427                      ctx->input("context_dense_keys", &context_dense_keys));
428       OP_REQUIRES_OK(ctx,
429                      ctx->input("context_sparse_keys", &context_sparse_keys));
430       OP_REQUIRES_OK(ctx,
431                      ctx->input("context_ragged_keys", &context_ragged_keys));
432       OP_REQUIRES_OK(
433           ctx, ctx->input("feature_list_dense_keys", &feature_list_dense_keys));
434       OP_REQUIRES_OK(ctx, ctx->input("feature_list_sparse_keys",
435                                      &feature_list_sparse_keys));
436       OP_REQUIRES_OK(ctx, ctx->input("feature_list_ragged_keys",
437                                      &feature_list_ragged_keys));
438       absl::call_once(flag_, [&]() {
439         metrics::RecordParseDenseFeature(
440             context_dense_keys->NumElements() +
441             feature_list_dense_keys->NumElements());
442         metrics::RecordParseSparseFeature(
443             context_sparse_keys->NumElements() +
444             feature_list_sparse_keys->NumElements());
445         metrics::RecordParseRaggedFeature(
446             context_ragged_keys->NumElements() +
447             feature_list_ragged_keys->NumElements());
448       });
449     }
450 
451     // Validate input tensor shapes.
452     OP_REQUIRES_OK(ctx, CheckInputShapes(
453                             serialized, debug_name, context_dense_defaults,
454                             context_dense_keys, context_sparse_keys,
455                             context_ragged_keys, feature_list_dense_keys,
456                             feature_list_sparse_keys, feature_list_ragged_keys,
457                             feature_list_dense_missing_assumed_empty));
458 
459     example::FastParseExampleConfig context_config =
460         MakeContextConfig(context_dense_keys, context_sparse_keys,
461                           context_ragged_keys, context_dense_defaults);
462     example::FastParseExampleConfig feature_list_config = MakeFeatureListConfig(
463         feature_list_dense_keys, feature_list_sparse_keys,
464         feature_list_ragged_keys, feature_list_dense_missing_assumed_empty);
465 
466     bool is_batch = TensorShapeUtils::IsVector(serialized->shape());
467     auto serialized_t = serialized->flat<tstring>();
468     auto debug_name_t = debug_name->flat<tstring>();
469     gtl::ArraySlice<tstring> slice(serialized_t.data(), serialized_t.size());
470     gtl::ArraySlice<tstring> names_slice(debug_name_t.data(),
471                                          debug_name_t.size());
472 
473     example::Result context_result, feature_list_result;
474     std::vector<Tensor> dense_feature_lengths;
475     OP_REQUIRES_OK(
476         ctx, FastParseSequenceExample(
477                  context_config, feature_list_config, slice, names_slice,
478                  ctx->device()->tensorflow_cpu_worker_threads()->workers,
479                  &context_result, &feature_list_result, &dense_feature_lengths,
480                  is_batch));
481 
482     OP_REQUIRES_OK(ctx, WriteOutput(context_result, feature_list_result,
483                                     dense_feature_lengths, ctx));
484   }
485 
486  protected:
CheckInputShapes(const Tensor * serialized,const Tensor * names,const OpInputList & context_dense_defaults,const Tensor * context_dense_keys,const Tensor * context_sparse_keys,const Tensor * context_ragged_keys,const Tensor * feature_list_dense_keys,const Tensor * feature_list_sparse_keys,const Tensor * feature_list_ragged_keys,const Tensor * feature_list_dense_missing_assumed_empty) const487   Status CheckInputShapes(
488       const Tensor* serialized, const Tensor* names,
489       const OpInputList& context_dense_defaults,
490 
491       const Tensor* context_dense_keys, const Tensor* context_sparse_keys,
492       const Tensor* context_ragged_keys, const Tensor* feature_list_dense_keys,
493       const Tensor* feature_list_sparse_keys,
494       const Tensor* feature_list_ragged_keys,
495       const Tensor* feature_list_dense_missing_assumed_empty) const {
496     if (TensorShapeUtils::IsMatrixOrHigher(serialized->shape())) {
497       return errors::InvalidArgument(
498           "Expected serialized to be a scalar or vector, got shape: ",
499           serialized->shape().DebugString());
500     }
501     if (op_version_ > 1) {
502       if (context_dense_keys->NumElements() != attrs_.num_context_dense) {
503         return errors::InvalidArgument(
504             "Expected len(context_dense_keys) to match len(Tcontext_dense)");
505       }
506       if (context_sparse_keys->NumElements() != attrs_.num_context_sparse) {
507         return errors::InvalidArgument(
508             "Expected len(context_sparse_keys) to match Ncontext_sparse");
509       }
510       if (context_ragged_keys->NumElements() != attrs_.num_context_ragged) {
511         return errors::InvalidArgument(
512             "Expected len(context_ragged_keys) to match "
513             "len(context_ragged_value_types)");
514       }
515       if (feature_list_dense_keys->NumElements() !=
516           attrs_.num_feature_list_dense) {
517         return errors::InvalidArgument(
518             "Expected len(feature_list_dense_keys) to match "
519             "Nfeature_list_dense");
520       }
521       if (feature_list_dense_missing_assumed_empty->NumElements() !=
522           attrs_.num_feature_list_dense) {
523         return errors::InvalidArgument(
524             "Expected len(feature_list_dense_missing_assumed_empty to match "
525             "Nfeature_list_dense");
526       }
527       if (feature_list_sparse_keys->NumElements() !=
528           attrs_.num_feature_list_sparse) {
529         return errors::InvalidArgument(
530             "Expected len(feature_list_sparse_keys) to match "
531             "Nfeature_list_sparse");
532       }
533       if (feature_list_ragged_keys->NumElements() !=
534           attrs_.num_feature_list_ragged) {
535         return errors::InvalidArgument(
536             "Expected len(feature_list_ragged_keys) to match "
537             "len(feature_list_ragged_value_types)");
538       }
539     }
540     if (context_dense_defaults.size() != attrs_.num_context_dense) {
541       return errors::InvalidArgument(
542           "Expected len(context_dense_defaults) "
543           "== len(context_dense_keys) but got: ",
544           context_dense_defaults.size(), " vs. ", attrs_.num_context_dense);
545     }
546     for (int d = 0; d < attrs_.num_context_dense; ++d) {
547       const Tensor& def_value = context_dense_defaults[d];
548       if (def_value.NumElements() > 0) {
549         if (def_value.shape() != attrs_.context_dense_shapes[d]) {
550           return errors::InvalidArgument(
551               "default_value[", d,
552               "].shape() == ", def_value.shape().DebugString(),
553               " != context_dense_shapes[", d,
554               "] == ", attrs_.context_dense_shapes[d].DebugString());
555         }
556         if (def_value.dtype() != attrs_.context_dense_types[d]) {
557           return errors::InvalidArgument(
558               "context_dense_defaults[", d,
559               "].dtype() == ", DataTypeString(def_value.dtype()),
560               " != context_dense_types[", d,
561               "] == ", DataTypeString(attrs_.context_dense_types[d]));
562         }
563       }
564     }
565     return Status::OK();
566   }
567 
MakeContextConfig(const Tensor * dense_keys,const Tensor * sparse_keys,const Tensor * ragged_keys,const OpInputList & context_dense_defaults) const568   example::FastParseExampleConfig MakeContextConfig(
569       const Tensor* dense_keys, const Tensor* sparse_keys,
570       const Tensor* ragged_keys,
571       const OpInputList& context_dense_defaults) const {
572     // Convert the tensors/attrs to ArraySlices once, instead of re-evaluating
573     // them in each loop iteration.
574     gtl::ArraySlice<tstring> dense_keys_slice =
575         dense_keys
576             ? gtl::ArraySlice<tstring>(dense_keys->flat<tstring>().data(),
577                                        attrs_.num_context_dense)
578             : attrs_.context_dense_keys;
579     gtl::ArraySlice<tstring> sparse_keys_slice =
580         sparse_keys
581             ? gtl::ArraySlice<tstring>(sparse_keys->flat<tstring>().data(),
582                                        attrs_.num_context_sparse)
583             : attrs_.context_sparse_keys;
584     gtl::ArraySlice<tstring> ragged_keys_slice =
585         ragged_keys
586             ? gtl::ArraySlice<tstring>(ragged_keys->flat<tstring>().data(),
587                                        attrs_.num_context_ragged)
588             : gtl::ArraySlice<tstring>(nullptr, 0);
589 
590     example::FastParseExampleConfig config;
591     config.dense.reserve(attrs_.num_context_dense);
592     for (int d = 0; d < attrs_.num_context_dense; ++d) {
593       const tstring& key = dense_keys_slice[d];
594       config.dense.emplace_back(key, attrs_.context_dense_types[d],
595                                 attrs_.context_dense_shapes[d],
596                                 context_dense_defaults[d],
597                                 false /* attrs_.context_variable_length[d] */,
598                                 0 /*attrs_.context_elements_per_stride[d] */);
599     }
600     config.sparse.reserve(attrs_.num_context_sparse);
601     for (int d = 0; d < attrs_.num_context_sparse; ++d) {
602       const tstring& key = sparse_keys_slice[d];
603       config.sparse.emplace_back(key, attrs_.context_sparse_types[d]);
604     }
605     config.ragged.reserve(attrs_.num_context_ragged);
606     for (int d = 0; d < attrs_.num_context_ragged; ++d) {
607       config.ragged.emplace_back(ragged_keys_slice[d],
608                                  attrs_.context_ragged_value_types[d],
609                                  attrs_.context_ragged_split_types[d]);
610     }
611     return config;
612   }
613 
ConstructDefaultScalar(DataType dtype)614   static Tensor ConstructDefaultScalar(DataType dtype) {
615     switch (dtype) {
616       case DT_INT64:
617         return Tensor(static_cast<int64>(0));
618       case DT_FLOAT:
619         return Tensor(static_cast<float>(0.0));
620       case DT_STRING:
621         return Tensor("");
622       default:
623         return Tensor(DT_INVALID);
624     }
625   }
626 
MakeFeatureListConfig(const Tensor * dense_keys,const Tensor * sparse_keys,const Tensor * ragged_keys,const Tensor * feature_list_dense_missing_assumed_empty) const627   example::FastParseExampleConfig MakeFeatureListConfig(
628       const Tensor* dense_keys, const Tensor* sparse_keys,
629       const Tensor* ragged_keys,
630       const Tensor* feature_list_dense_missing_assumed_empty) const {
631     // Convert the tensors/attrs to ArraySlices once, instead of re-evaluating
632     // them in each loop iteration.
633     gtl::ArraySlice<tstring> dense_keys_slice =
634         dense_keys
635             ? gtl::ArraySlice<tstring>(dense_keys->flat<tstring>().data(),
636                                        attrs_.num_feature_list_dense)
637             : attrs_.feature_list_dense_keys;
638     gtl::ArraySlice<tstring> sparse_keys_slice =
639         sparse_keys
640             ? gtl::ArraySlice<tstring>(sparse_keys->flat<tstring>().data(),
641                                        attrs_.num_feature_list_sparse)
642             : attrs_.feature_list_sparse_keys;
643     gtl::ArraySlice<tstring> ragged_keys_slice =
644         ragged_keys
645             ? gtl::ArraySlice<tstring>(ragged_keys->flat<tstring>().data(),
646                                        attrs_.num_feature_list_ragged)
647             : gtl::ArraySlice<tstring>(nullptr, 0);
648     // Use an empty slice to indicate that the map in attrs_ should be used
649     // instead.
650     gtl::ArraySlice<bool> feature_list_dense_missing_assumed_empty_slice =
651         feature_list_dense_missing_assumed_empty
652             ? gtl::ArraySlice<bool>(
653                   feature_list_dense_missing_assumed_empty->flat<bool>().data(),
654                   attrs_.num_feature_list_dense)
655             : gtl::ArraySlice<bool>(nullptr, 0);
656 
657     example::FastParseExampleConfig config;
658     config.dense.reserve(attrs_.num_feature_list_dense);
659     for (int d = 0; d < attrs_.num_feature_list_dense; ++d) {
660       const tstring& key = dense_keys_slice[d];
661       bool missing_assumed_empty =
662           !feature_list_dense_missing_assumed_empty_slice.empty()
663               ? feature_list_dense_missing_assumed_empty_slice[d]
664               : attrs_.feature_list_dense_missing_assumed_empty.count(key) > 0;
665       DataType dtype = attrs_.feature_list_dense_types[d];
666       config.dense.emplace_back(
667           key, dtype, attrs_.feature_list_dense_shapes[d],
668           ConstructDefaultScalar(dtype), missing_assumed_empty,
669           0 /*attrs_.feature_list_elements_per_stride[d] */);
670     }
671     config.sparse.reserve(attrs_.num_feature_list_sparse);
672     for (int d = 0; d < attrs_.num_feature_list_sparse; ++d) {
673       const tstring& key = sparse_keys_slice[d];
674       config.sparse.emplace_back(key, attrs_.feature_list_sparse_types[d]);
675     }
676     config.ragged.reserve(attrs_.num_feature_list_ragged);
677     for (int d = 0; d < attrs_.num_feature_list_ragged; ++d) {
678       config.ragged.emplace_back(ragged_keys_slice[d],
679                                  attrs_.feature_list_ragged_value_types[d],
680                                  attrs_.feature_list_ragged_split_types[d]);
681     }
682     return config;
683   }
684 
WriteOutput(const example::Result & context_result,const example::Result & feature_list_result,const std::vector<Tensor> & dense_feature_lengths,OpKernelContext * ctx) const685   Status WriteOutput(const example::Result& context_result,
686                      const example::Result& feature_list_result,
687                      const std::vector<Tensor>& dense_feature_lengths,
688                      OpKernelContext* ctx) const {
689     OpOutputList context_sparse_indices;
690     OpOutputList context_sparse_values;
691     OpOutputList context_sparse_shapes;
692     OpOutputList context_dense_values;
693     OpOutputList feature_list_sparse_indices;
694     OpOutputList feature_list_sparse_values;
695     OpOutputList feature_list_sparse_shapes;
696     OpOutputList feature_list_dense_values;
697     OpOutputList feature_list_dense_lengths;
698 
699     TF_RETURN_IF_ERROR(
700         ctx->output_list("context_sparse_indices", &context_sparse_indices));
701     TF_RETURN_IF_ERROR(
702         ctx->output_list("context_sparse_values", &context_sparse_values));
703     TF_RETURN_IF_ERROR(
704         ctx->output_list("context_sparse_shapes", &context_sparse_shapes));
705     TF_RETURN_IF_ERROR(
706         ctx->output_list("context_dense_values", &context_dense_values));
707     TF_RETURN_IF_ERROR(
708         ctx->output_list("context_sparse_indices", &context_sparse_indices));
709     TF_RETURN_IF_ERROR(ctx->output_list("feature_list_sparse_indices",
710                                         &feature_list_sparse_indices));
711     TF_RETURN_IF_ERROR(ctx->output_list("feature_list_sparse_values",
712                                         &feature_list_sparse_values));
713     TF_RETURN_IF_ERROR(ctx->output_list("feature_list_sparse_shapes",
714                                         &feature_list_sparse_shapes));
715     TF_RETURN_IF_ERROR(ctx->output_list("feature_list_dense_values",
716                                         &feature_list_dense_values));
717     TF_RETURN_IF_ERROR(ctx->output_list("feature_list_dense_lengths",
718                                         &feature_list_dense_lengths));
719     for (int d = 0; d < attrs_.num_context_dense; ++d) {
720       context_dense_values.set(d, context_result.dense_values[d]);
721     }
722     for (int d = 0; d < attrs_.num_feature_list_dense; ++d) {
723       feature_list_dense_values.set(d, feature_list_result.dense_values[d]);
724       feature_list_dense_lengths.set(d, dense_feature_lengths[d]);
725     }
726     for (int d = 0; d < attrs_.num_context_sparse; ++d) {
727       context_sparse_indices.set(d, context_result.sparse_indices[d]);
728       context_sparse_values.set(d, context_result.sparse_values[d]);
729       context_sparse_shapes.set(d, context_result.sparse_shapes[d]);
730     }
731     for (int d = 0; d < attrs_.num_feature_list_sparse; ++d) {
732       feature_list_sparse_indices.set(d, feature_list_result.sparse_indices[d]);
733       feature_list_sparse_values.set(d, feature_list_result.sparse_values[d]);
734       feature_list_sparse_shapes.set(d, feature_list_result.sparse_shapes[d]);
735     }
736     if (op_version_ == 2) {
737       OpOutputList context_ragged_values;
738       OpOutputList context_ragged_splits;
739       OpOutputList feature_list_ragged_values;
740       OpOutputList feature_list_ragged_inner_splits;
741       OpOutputList feature_list_ragged_outer_splits;
742       TF_RETURN_IF_ERROR(
743           ctx->output_list("context_ragged_values", &context_ragged_values));
744       TF_RETURN_IF_ERROR(ctx->output_list("context_ragged_row_splits",
745                                           &context_ragged_splits));
746       TF_RETURN_IF_ERROR(ctx->output_list("feature_list_ragged_values",
747                                           &feature_list_ragged_values));
748       TF_RETURN_IF_ERROR(ctx->output_list("feature_list_ragged_inner_splits",
749                                           &feature_list_ragged_inner_splits));
750       TF_RETURN_IF_ERROR(ctx->output_list("feature_list_ragged_outer_splits",
751                                           &feature_list_ragged_outer_splits));
752       for (int d = 0; d < attrs_.num_context_ragged; ++d) {
753         context_ragged_values.set(d, context_result.ragged_values[d]);
754         context_ragged_splits.set(d, context_result.ragged_splits[d]);
755       }
756       for (int d = 0; d < attrs_.num_feature_list_ragged; ++d) {
757         feature_list_ragged_values.set(d, feature_list_result.ragged_values[d]);
758         feature_list_ragged_outer_splits.set(
759             d, feature_list_result.ragged_outer_splits[d]);
760         feature_list_ragged_inner_splits.set(
761             d, feature_list_result.ragged_splits[d]);
762       }
763     }
764     return Status::OK();
765   }
766 
767   ParseSequenceExampleAttrs attrs_;
768   int op_version_;
769   absl::once_flag flag_;
770 };
771 
// Register ParseSequenceExampleOp as the CPU kernel for both the
// "ParseSequenceExample" op and the "ParseSequenceExampleV2" op.
772 REGISTER_KERNEL_BUILDER(Name("ParseSequenceExample").Device(DEVICE_CPU),
773                         ParseSequenceExampleOp);
774 REGISTER_KERNEL_BUILDER(Name("ParseSequenceExampleV2").Device(DEVICE_CPU),
775                         ParseSequenceExampleOp);
776 
777 class ParseSingleSequenceExampleOp : public OpKernel {
778  public:
ParseSingleSequenceExampleOp(OpKernelConstruction * ctx)779   explicit ParseSingleSequenceExampleOp(OpKernelConstruction* ctx)
780       : OpKernel(ctx) {
781     OP_REQUIRES_OK(ctx, attrs_.Init(ctx));
782   }
783 
Compute(OpKernelContext * ctx)784   void Compute(OpKernelContext* ctx) override {
785     const Tensor* debug_name;
786     const Tensor* serialized;
787     OpInputList context_dense_keys;
788     OpInputList context_sparse_keys;
789     OpInputList context_dense_defaults;
790     OpInputList feature_list_dense_keys;
791     OpInputList feature_list_sparse_keys;
792     const Tensor* feature_list_dense_missing_assumed_empty;
793 
794     OP_REQUIRES_OK(ctx, ctx->input("debug_name", &debug_name));
795     OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized));
796     OP_REQUIRES_OK(ctx, ctx->input("feature_list_dense_missing_assumed_empty",
797                                    &feature_list_dense_missing_assumed_empty));
798     OP_REQUIRES_OK(ctx,
799                    ctx->input_list("context_dense_keys", &context_dense_keys));
800     OP_REQUIRES_OK(ctx, ctx->input_list("feature_list_dense_keys",
801                                         &feature_list_dense_keys));
802     OP_REQUIRES_OK(
803         ctx, ctx->input_list("context_sparse_keys", &context_sparse_keys));
804     OP_REQUIRES_OK(ctx, ctx->input_list("feature_list_sparse_keys",
805                                         &feature_list_sparse_keys));
806     OP_REQUIRES_OK(ctx, ctx->input_list("context_dense_defaults",
807                                         &context_dense_defaults));
808 
809     std::vector<string> context_dense_keys_t(attrs_.num_context_dense);
810     std::vector<string> context_sparse_keys_t(attrs_.num_context_sparse);
811     std::vector<string> feature_list_dense_keys_t(
812         attrs_.num_feature_list_dense);
813     std::vector<string> feature_list_sparse_keys_t(
814         attrs_.num_feature_list_sparse);
815     absl::call_once(
816         flag_, [&context_dense_keys_t, &context_sparse_keys_t,
817                 &feature_list_dense_keys_t, &feature_list_sparse_keys_t]() {
818           metrics::RecordParseDenseFeature(context_dense_keys_t.size() +
819                                            feature_list_dense_keys_t.size());
820           metrics::RecordParseSparseFeature(context_sparse_keys_t.size() +
821                                             feature_list_sparse_keys_t.size());
822         });
823     std::unordered_set<string> feature_list_dense_missing_assumed_empty_set;
824     CHECK_EQ(context_dense_keys.size(), attrs_.num_context_dense);
825     CHECK_EQ(context_sparse_keys.size(), attrs_.num_context_sparse);
826     CHECK_EQ(feature_list_dense_keys.size(), attrs_.num_feature_list_dense);
827     CHECK_EQ(feature_list_sparse_keys.size(), attrs_.num_feature_list_sparse);
828     for (int di = 0; di < attrs_.num_context_dense; ++di) {
829       OP_REQUIRES(ctx,
830                   TensorShapeUtils::IsScalar(context_dense_keys[di].shape()),
831                   errors::InvalidArgument(
832                       "Expected context_dense_keys[", di,
833                       "] to be a scalar, got shape: ",
834                       context_dense_keys[di].shape().DebugString()));
835       context_dense_keys_t[di] = context_dense_keys[di].scalar<tstring>()();
836     }
837     for (int di = 0; di < attrs_.num_context_sparse; ++di) {
838       OP_REQUIRES(ctx,
839                   TensorShapeUtils::IsScalar(context_sparse_keys[di].shape()),
840                   errors::InvalidArgument(
841                       "Expected context_sparse_keys[", di,
842                       "] to be a scalar, got shape: ",
843                       context_sparse_keys[di].shape().DebugString()));
844       context_sparse_keys_t[di] = context_sparse_keys[di].scalar<tstring>()();
845     }
846     for (int di = 0; di < attrs_.num_feature_list_dense; ++di) {
847       OP_REQUIRES(
848           ctx, TensorShapeUtils::IsScalar(feature_list_dense_keys[di].shape()),
849           errors::InvalidArgument(
850               "Expected feature_list_dense_keys[", di,
851               "] to be a scalar, got shape: ",
852               feature_list_dense_keys[di].shape().DebugString()));
853       feature_list_dense_keys_t[di] =
854           feature_list_dense_keys[di].scalar<tstring>()();
855     }
856     for (int di = 0; di < attrs_.num_feature_list_sparse; ++di) {
857       OP_REQUIRES(
858           ctx, TensorShapeUtils::IsScalar(feature_list_sparse_keys[di].shape()),
859           errors::InvalidArgument(
860               "Expected feature_list_sparse_keys[", di,
861               "] to be a scalar, got shape: ",
862               feature_list_sparse_keys[di].shape().DebugString()));
863       feature_list_sparse_keys_t[di] =
864           feature_list_sparse_keys[di].scalar<tstring>()();
865     }
866     OP_REQUIRES(
867         ctx,
868         TensorShapeUtils::IsVector(
869             feature_list_dense_missing_assumed_empty->shape()),
870         errors::InvalidArgument(
871             "Expected feature_list_dense_missing_assumed_empty ",
872             "to be a vector, got shape: ",
873             feature_list_dense_missing_assumed_empty->shape().DebugString()));
874     auto feature_list_dense_missing_assumped_empty_t =
875         feature_list_dense_missing_assumed_empty->vec<tstring>();
876     for (int de = 0;
877          de < feature_list_dense_missing_assumed_empty->NumElements(); ++de) {
878       feature_list_dense_missing_assumed_empty_set.insert(
879           feature_list_dense_missing_assumped_empty_t(de));
880     }
881 
882     bool has_debug_name = (debug_name->NumElements() > 0);
883     if (has_debug_name) {
884       OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(debug_name->shape()),
885                   errors::InvalidArgument(
886                       "Expected debug_name to be a scalar, got shape: ",
887                       debug_name->shape().DebugString()));
888     }
889     auto debug_name_t = debug_name->scalar<tstring>();
890 
891     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(serialized->shape()),
892                 errors::InvalidArgument(
893                     "Expected serialized to be a scalar, got shape: ",
894                     serialized->shape().DebugString()));
895 
896     OP_REQUIRES(ctx, context_dense_defaults.size() == attrs_.num_context_dense,
897                 errors::InvalidArgument("Expected len(context_dense_defaults) "
898                                         "== len(context_dense_keys) but got: ",
899                                         context_dense_defaults.size(), " vs. ",
900                                         attrs_.num_context_dense));
901 
902     std::vector<bool> required(attrs_.num_context_dense);
903     for (int d = 0; d < attrs_.num_context_dense; ++d) {
904       const Tensor& def_value = context_dense_defaults[d];
905       required[d] = (def_value.NumElements() == 0);  // No default provided.
906 
907       if (def_value.NumElements() > 0) {
908         OP_REQUIRES(ctx, def_value.shape() == attrs_.context_dense_shapes[d],
909                     errors::InvalidArgument(
910                         "def_value[", d,
911                         "].shape() == ", def_value.shape().DebugString(),
912                         " != context_dense_shapes_[", d,
913                         "] == ", attrs_.context_dense_shapes[d].DebugString()));
914         OP_REQUIRES(
915             ctx, def_value.dtype() == attrs_.context_dense_types[d],
916             errors::InvalidArgument(
917                 "context_dense_defaults[", d, "].dtype() == ",
918                 DataTypeString(def_value.dtype()), " != context_dense_types_[",
919                 d, "] == ", DataTypeString(attrs_.context_dense_types[d])));
920       }
921     }
922 
923     auto serialized_t = serialized->scalar<tstring>();
924 
925     OpOutputList context_sparse_indices;
926     OpOutputList context_sparse_values;
927     OpOutputList context_sparse_shapes;
928     OpOutputList context_dense_values;
929     OpOutputList feature_list_sparse_indices;
930     OpOutputList feature_list_sparse_values;
931     OpOutputList feature_list_sparse_shapes;
932     OpOutputList feature_list_dense_values;
933 
934     OP_REQUIRES_OK(ctx, ctx->output_list("context_sparse_indices",
935                                          &context_sparse_indices));
936     OP_REQUIRES_OK(
937         ctx, ctx->output_list("context_sparse_values", &context_sparse_values));
938     OP_REQUIRES_OK(
939         ctx, ctx->output_list("context_sparse_shapes", &context_sparse_shapes));
940     OP_REQUIRES_OK(
941         ctx, ctx->output_list("context_dense_values", &context_dense_values));
942     OP_REQUIRES_OK(ctx, ctx->output_list("context_sparse_indices",
943                                          &context_sparse_indices));
944     OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_indices",
945                                          &feature_list_sparse_indices));
946     OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_values",
947                                          &feature_list_sparse_values));
948     OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_shapes",
949                                          &feature_list_sparse_shapes));
950     OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_dense_values",
951                                          &feature_list_dense_values));
952 
953     // Allocate the SequenceExample on an arena. Provides better memory locality
954     // and greatly speeds up destruction.
955     protobuf::ArenaOptions options;
956     // We have some hint of what the final proto size will be based on the size
957     // of the serialized bytes- use this to set a custom allocation strategy.
958     // Note that the default allocation strategy is quite conservative (min
959     // block size of 256 bytes, and a max of 8 kilobytes).
960     const size_t block_size = serialized_t().size() * 1.1;
961     options.start_block_size = std::max(options.start_block_size, block_size);
962     options.max_block_size = std::max(options.max_block_size, block_size);
963     protobuf::Arena arena(options);
964     auto& ex = *protobuf::Arena::CreateMessage<SequenceExample>(&arena);
965 
966     OP_REQUIRES(
967         ctx, ParseProtoUnlimited(&ex, serialized_t()),
968         errors::InvalidArgument("Could not parse example input, value: '",
969                                 serialized_t(), "'"));
970 
971     const tstring& name = (has_debug_name) ? debug_name_t() : "<unknown>";
972     const Features& context = ex.context();
973     const auto& context_dict = context.feature();
974 
975     // Context Dense -----------------------------------------------------------
976 
977     // Preallocate context_dense_values, since we know their sizes
978     for (int d = 0; d < attrs_.num_context_dense; ++d) {
979       TensorShape out_shape;
980       for (const int dim : attrs_.context_dense_shapes[d].dim_sizes())
981         out_shape.AddDim(dim);
982       Tensor* out = nullptr;
983       OP_REQUIRES_OK(ctx, context_dense_values.allocate(d, out_shape, &out));
984     }
985 
986     for (int d = 0; d < attrs_.num_context_dense; ++d) {
987       const tstring& key = context_dense_keys_t[d];
988       const DataType& dtype = attrs_.context_dense_types[d];
989       const TensorShape& shape = attrs_.context_dense_shapes[d];
990 
991       const auto& feature_found = context_dict.find(key);
992       OP_REQUIRES(
993           ctx, (feature_found != context_dict.end()) || !required[d],
994           errors::InvalidArgument("Name: ", name, ", Context feature '", key,
995                                   "' is required but could not be found."));
996       if (feature_found != context_dict.end()) {
997         const Feature& f = feature_found->second;
998         bool types_match;
999         OP_REQUIRES_OK(ctx, CheckTypesMatch(f, dtype, &types_match));
1000         OP_REQUIRES(
1001             ctx, types_match,
1002             errors::InvalidArgument("Name: ", name, ", Context feature: ", key,
1003                                     ".  Data types don't match. ",
1004                                     "Expected type: ", DataTypeString(dtype),
1005                                     "  Feature is: ", f.DebugString()));
1006 
1007         OP_REQUIRES_OK(ctx, FeatureDenseCopy(0, name, key, dtype, shape, f,
1008                                              context_dense_values[d]));
1009       } else {
1010         RowDenseCopy(0, dtype, context_dense_defaults[d],
1011                      context_dense_values[d]);
1012       }
1013     }
1014 
1015     // Context Sparse ----------------------------------------------------------
1016     for (int d = 0; d < attrs_.num_context_sparse; ++d) {
1017       const tstring& key = context_sparse_keys_t[d];
1018       const DataType& dtype = attrs_.context_sparse_types[d];
1019 
1020       const auto& feature_found = context_dict.find(key);
1021       bool feature_has_data =  // Found key & data type is set
1022           (feature_found != context_dict.end() &&
1023            (feature_found->second.kind_case() != Feature::KIND_NOT_SET));
1024 
1025       if (feature_has_data) {
1026         const Feature& f = feature_found->second;
1027         bool types_match;
1028         OP_REQUIRES_OK(ctx, CheckTypesMatch(f, dtype, &types_match));
1029         OP_REQUIRES(
1030             ctx, types_match,
1031             errors::InvalidArgument("Name: ", name, ", Context feature: ", key,
1032                                     ".  Data types don't match. ",
1033                                     "Expected type: ", DataTypeString(dtype),
1034                                     "  Feature is: ", f.DebugString()));
1035 
1036         Tensor feature_values = FeatureSparseCopy(0, key, dtype, f);
1037         const int64 num_elements = feature_values.NumElements();
1038         TensorShape indices_shape({num_elements, 1});
1039         Tensor* sp_indices_d = nullptr;
1040         Tensor* sp_shape_d = nullptr;
1041         OP_REQUIRES_OK(ctx, context_sparse_indices.allocate(d, indices_shape,
1042                                                             &sp_indices_d));
1043         context_sparse_values.set(d, feature_values);
1044         OP_REQUIRES_OK(ctx, context_sparse_shapes.allocate(d, TensorShape({1}),
1045                                                            &sp_shape_d));
1046         auto shape_t = sp_shape_d->vec<int64>();
1047         shape_t(0) = num_elements;
1048         auto indices_t = sp_indices_d->matrix<int64>();
1049         std::iota(indices_t.data(), indices_t.data() + num_elements, 0);
1050       } else {
1051         TensorShape indices_shape({0, 1});
1052         TensorShape values_shape({0});
1053         Tensor* sp_indices_d = nullptr;
1054         Tensor* sp_values_d = nullptr;
1055         Tensor* sp_shape_d = nullptr;
1056         OP_REQUIRES_OK(ctx, context_sparse_indices.allocate(d, indices_shape,
1057                                                             &sp_indices_d));
1058         OP_REQUIRES_OK(
1059             ctx, context_sparse_values.allocate(d, values_shape, &sp_values_d));
1060         OP_REQUIRES_OK(ctx, context_sparse_shapes.allocate(d, TensorShape({1}),
1061                                                            &sp_shape_d));
1062         auto shape_t = sp_shape_d->vec<int64>();
1063         shape_t(0) = 0;
1064       }
1065     }
1066 
1067     // Feature List Dense ------------------------------------------------------
1068 
1069     // Preallocate context_dense_values, since we can infer their
1070     // sizes
1071     const FeatureLists& feature_lists = ex.feature_lists();
1072     const auto& feature_list_dict = feature_lists.feature_list();
1073     FeatureList empty_feature_list;  // Placeholder for missing FLs
1074 
1075     for (int d = 0; d < attrs_.num_feature_list_dense; ++d) {
1076       const tstring& key = feature_list_dense_keys_t[d];
1077       const DataType& dtype = attrs_.feature_list_dense_types[d];
1078       const TensorShape& shape = attrs_.feature_list_dense_shapes[d];
1079 
1080       const auto& feature_list_found = feature_list_dict.find(key);
1081       bool feature_list_missing =
1082           (feature_list_found == feature_list_dict.end());
1083       bool feature_list_allowed_missing =
1084           (feature_list_dense_missing_assumed_empty_set.count(key) > 0);
1085 
1086       OP_REQUIRES(
1087           ctx, !feature_list_missing || feature_list_allowed_missing,
1088           errors::InvalidArgument("Name: ", name, ", Feature list '", key,
1089                                   "' is required but could not be found.  "
1090                                   "Did you mean to include it in "
1091                                   "feature_list_dense_missing_assumed_empty or "
1092                                   "feature_list_dense_defaults?"));
1093 
1094       TensorShape out_shape;
1095       const FeatureList& fl = (feature_list_missing)
1096                                   ? empty_feature_list
1097                                   : feature_list_found->second;
1098       out_shape.AddDim(fl.feature_size());
1099       for (const int dim : attrs_.feature_list_dense_shapes[d].dim_sizes()) {
1100         out_shape.AddDim(dim);
1101       }
1102       Tensor* out = nullptr;
1103       OP_REQUIRES_OK(ctx,
1104                      feature_list_dense_values.allocate(d, out_shape, &out));
1105 
1106       for (int64 t = 0; t < fl.feature_size(); ++t) {
1107         const Feature& f = fl.feature(t);
1108         bool types_match;
1109         OP_REQUIRES_OK(ctx, CheckTypesMatch(f, dtype, &types_match));
1110         OP_REQUIRES(ctx, types_match,
1111                     errors::InvalidArgument(
1112                         "Name: ", name, ", Feature list: ", key, ", Index: ", t,
1113                         ".  Data types don't match. ",
1114                         "Expected type: ", DataTypeString(dtype),
1115                         "  Feature is: ", f.DebugString()));
1116         OP_REQUIRES_OK(ctx, FeatureDenseCopy(t, name, key, dtype, shape, f,
1117                                              feature_list_dense_values[d]));
1118       }
1119     }
1120 
1121     // Feature List Sparse -----------------------------------------------------
1122     for (int d = 0; d < attrs_.num_feature_list_sparse; ++d) {
1123       const tstring& key = feature_list_sparse_keys_t[d];
1124       const DataType& dtype = attrs_.feature_list_sparse_types[d];
1125 
1126       const auto& feature_list_found = feature_list_dict.find(key);
1127       bool feature_list_has_data =  // Found key
1128           (feature_list_found != feature_list_dict.end());
1129 
1130       std::vector<Tensor> sparse_values_tmp;
1131       int64 feature_list_size = 0;
1132       if (feature_list_has_data) {
1133         const FeatureList& fl = feature_list_found->second;
1134         feature_list_size = fl.feature_size();
1135         for (int64 t = 0; t < feature_list_size; ++t) {
1136           const Feature& f = fl.feature(t);
1137           bool types_match;
1138           OP_REQUIRES_OK(ctx, CheckTypesMatch(f, dtype, &types_match));
1139           OP_REQUIRES(
1140               ctx, f.kind_case() == Feature::KIND_NOT_SET || types_match,
1141               errors::InvalidArgument("Name: ", name, ", Feature List: ", key,
1142                                       ", Index: ", t,
1143                                       ".  Data types don't match. ",
1144                                       "Expected type: ", DataTypeString(dtype),
1145                                       "  Feature is: ", f.DebugString()));
1146           sparse_values_tmp.push_back(FeatureSparseCopy(t, key, dtype, f));
1147         }
1148       } else {
1149         sparse_values_tmp.push_back(Tensor(dtype, TensorShape({0})));
1150       }
1151 
1152       int64 total_num_features = 0;
1153       int64 max_num_features = 0;
1154       for (int t = 0; t < feature_list_size; ++t) {
1155         const Tensor& v = sparse_values_tmp[t];
1156         const int64 num_elements = v.shape().num_elements();
1157         total_num_features += num_elements;
1158         max_num_features = std::max(max_num_features, num_elements);
1159       }
1160 
1161       TensorShape indices_shape({total_num_features, 2});
1162       TensorShape values_shape({total_num_features});
1163       Tensor* sp_indices_d = nullptr;
1164       Tensor* sp_values_d = nullptr;
1165       Tensor* sp_shape_d = nullptr;
1166       OP_REQUIRES_OK(ctx, feature_list_sparse_indices.allocate(d, indices_shape,
1167                                                                &sp_indices_d));
1168       OP_REQUIRES_OK(ctx, feature_list_sparse_values.allocate(d, values_shape,
1169                                                               &sp_values_d));
1170       OP_REQUIRES_OK(ctx, feature_list_sparse_shapes.allocate(
1171                               d, TensorShape({2}), &sp_shape_d));
1172       auto shape_t = sp_shape_d->vec<int64>();
1173       shape_t(0) = feature_list_size;
1174       shape_t(1) = max_num_features;
1175 
1176       int64 offset = 0;
1177 
1178       for (int t = 0; t < feature_list_size; ++t) {
1179         const int64 num_elements = CopyIntoSparseTensor(
1180             sparse_values_tmp[t], t, offset, sp_indices_d, sp_values_d);
1181         offset += num_elements;
1182       }
1183     }
1184   }
1185 
1186  protected:
1187   ParseSingleSequenceExampleAttrs attrs_;
1188   absl::once_flag flag_;
1189 };
1190 
1191 REGISTER_KERNEL_BUILDER(Name("ParseSingleSequenceExample").Device(DEVICE_CPU),
1192                         ParseSingleSequenceExampleOp);
1193 
1194 #ifndef IS_MOBILE_PLATFORM
1195 // when using lite protos on mobile, decoding JSON is not available.
1196 
1197 class DecodeJSONExampleOp : public OpKernel {
1198  public:
DecodeJSONExampleOp(OpKernelConstruction * ctx)1199   explicit DecodeJSONExampleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
1200     resolver_.reset(protobuf::util::NewTypeResolverForDescriptorPool(
1201         "type.googleapis.com", protobuf::DescriptorPool::generated_pool()));
1202   }
1203 
Compute(OpKernelContext * ctx)1204   void Compute(OpKernelContext* ctx) override {
1205     const Tensor* json_examples;
1206     OP_REQUIRES_OK(ctx, ctx->input("json_examples", &json_examples));
1207     Tensor* binary_examples;
1208     OP_REQUIRES_OK(
1209         ctx, ctx->allocate_output("binary_examples", json_examples->shape(),
1210                                   &binary_examples));
1211 
1212     for (int i = 0; i < json_examples->NumElements(); ++i) {
1213       const tstring& json_example = json_examples->flat<tstring>()(i);
1214       protobuf::io::ArrayInputStream in(json_example.data(),
1215                                         json_example.size());
1216       TStringOutputStream out(&binary_examples->flat<tstring>()(i));
1217       auto status = protobuf::util::JsonToBinaryStream(
1218           resolver_.get(), "type.googleapis.com/tensorflow.Example", &in, &out);
1219       OP_REQUIRES(ctx, status.ok(),
1220                   errors::InvalidArgument("Error while parsing JSON: ",
1221                                           string(status.error_message())));
1222     }
1223   }
1224 
1225  private:
1226   std::unique_ptr<protobuf::util::TypeResolver> resolver_;
1227 };
1228 
1229 REGISTER_KERNEL_BUILDER(Name("DecodeJSONExample").Device(DEVICE_CPU),
1230                         DecodeJSONExampleOp);
1231 #endif
1232 
1233 }  // namespace tensorflow
1234