1 #include "tensorflow/core/framework/tensor_key.h"
2 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8 http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 ==============================================================================*/
16
17 #ifndef TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_
18 #define TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_
19
20 #define EIGEN_USE_THREADS
21 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
22 #define EIGEN_USE_GPU
23 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
24
25 #include <vector>
26
27 #include "tensorflow/core/framework/tensor.h"
28 #include "tensorflow/core/framework/types.h"
29 #include "tensorflow/core/framework/variant_op_registry.h"
30 #include "tensorflow/core/framework/variant_tensor_data.h"
31 #include "tensorflow/core/kernels/cwise_ops_common.h"
32 #include "tensorflow/core/util/tensor_ops_util.h"
33
34 namespace tensorflow {
35
36 // Class used to store a RaggedTensor as a Variant scalar.
37 class RaggedTensorVariant {
38 public:
RaggedTensorVariant()39 RaggedTensorVariant() {}
RaggedTensorVariant(Tensor values,const std::vector<Tensor> & nested_splits)40 RaggedTensorVariant(Tensor values, const std::vector<Tensor>& nested_splits)
41 : values_(std::move(values)), nested_splits_(nested_splits) {}
42
43 // Variant support methods.
44 string TypeName() const;
45 string DebugString() const;
46 void Encode(VariantTensorData* data) const;
47 bool Decode(const VariantTensorData& data);
48
49 // The flat_values of the RaggedTensor.
values()50 const Tensor& values() const { return values_; }
mutable_values()51 Tensor* mutable_values() { return &values_; }
set_values(const Tensor & new_values)52 void set_values(const Tensor& new_values) { values_ = new_values; }
53
54 // The nested row_splits of the RaggedTensor.
ragged_rank()55 int ragged_rank() const { return nested_splits_.size(); }
nested_splits()56 const std::vector<Tensor>& nested_splits() const { return nested_splits_; }
mutable_nested_splits()57 std::vector<Tensor>* mutable_nested_splits() { return &nested_splits_; }
splits(int i)58 const Tensor& splits(int i) const { return nested_splits_[i]; }
mutable_splits(int i)59 Tensor* mutable_splits(int i) { return &nested_splits_[i]; }
set_nested_splits(const std::vector<Tensor> & nested_splits)60 void set_nested_splits(const std::vector<Tensor>& nested_splits) {
61 nested_splits_ = nested_splits;
62 }
append_splits(const Tensor & splits)63 void append_splits(const Tensor& splits) { nested_splits_.push_back(splits); }
64
65 private:
66 Tensor values_;
67 std::vector<Tensor> nested_splits_;
68 };
69
70 template <typename Device>
RaggedTensorVariantZerosLike(OpKernelContext * c,const RaggedTensorVariant & x,RaggedTensorVariant * y)71 Status RaggedTensorVariantZerosLike(OpKernelContext* c,
72 const RaggedTensorVariant& x,
73 RaggedTensorVariant* y) {
74 y->set_nested_splits(x.nested_splits());
75 TF_RETURN_IF_ERROR(
76 ZerosLikeTensor<Device>(c, x.values(), y->mutable_values()));
77 return Status::OK();
78 }
79
80 template <typename Device>
RaggedTensorVariantBinaryAdd(OpKernelContext * c,const RaggedTensorVariant & x,const RaggedTensorVariant & y,RaggedTensorVariant * out)81 Status RaggedTensorVariantBinaryAdd(OpKernelContext* c,
82 const RaggedTensorVariant& x,
83 const RaggedTensorVariant& y,
84 RaggedTensorVariant* out) {
85 if (x.values().dtype() != y.values().dtype()) {
86 return errors::InvalidArgument(
87 "Can't add RaggedTensorVariants of different dtypes. One is ",
88 DataTypeString(x.values().dtype()), " and the other is ",
89 DataTypeString(y.values().dtype()));
90 }
91 if (x.ragged_rank() != y.ragged_rank()) {
92 return errors::InvalidArgument(
93 "Can't add RaggedTensorVariants of different ragged rank. ", "One is ",
94 x.ragged_rank(), " and the other is ", y.ragged_rank());
95 }
96 for (int i = 0; i < x.ragged_rank(); ++i) {
97 if (TensorKey(x.splits(i)) != TensorKey(y.splits(i))) {
98 return errors::InvalidArgument(
99 "Can't add RaggedTensorVariants with different row_splits.");
100 }
101 }
102 out->set_nested_splits(x.nested_splits());
103 TF_RETURN_IF_ERROR(BinaryAddTensors<Device>(c, x.values(), y.values(),
104 out->mutable_values()));
105 return Status::OK();
106 }
107
108 } // namespace tensorflow
109
110 #endif // TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_
111