1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/optional_ops.h"
16
17 #include "tensorflow/core/common_runtime/dma_helper.h"
18 #include "tensorflow/core/framework/op_kernel.h"
19 #include "tensorflow/core/framework/variant_encode_decode.h"
20 #include "tensorflow/core/framework/variant_op_registry.h"
21
22 namespace tensorflow {
23 namespace data {
24 namespace {
25
OptionalDeviceCopy(const OptionalVariant & from,OptionalVariant * to,const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn & copy)26 static Status OptionalDeviceCopy(
27 const OptionalVariant& from, OptionalVariant* to,
28 const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
29 if (from.has_value()) {
30 const std::vector<Tensor>& from_values = from.get_values();
31 std::vector<Tensor> to_values;
32 to_values.reserve(from_values.size());
33 for (const Tensor& t : from_values) {
34 if (DMAHelper::CanUseDMA(&t) || t.dtype() == DT_VARIANT) {
35 // NOTE(skyewm): we're careful to make sure the lifetime of the 'to'
36 // Tensor passed to `copy` (i.e. to_values.back()) is the same as the
37 // returned 'to' OptionalVariant. This is because `copy` may spawn async
38 // callbacks that don't run until after this function returns and access
39 // the 'to' Tensor (e.g. BaseGPUDevice::MaybeCopyTensorToGPU).
40 to_values.emplace_back(t.dtype());
41 TF_RETURN_IF_ERROR(copy(t, &to_values.back()));
42 } else {
43 to_values.push_back(t);
44 }
45 }
46 *to = OptionalVariant(std::move(to_values));
47 } else {
48 *to = from;
49 }
50 return Status::OK();
51 }
52
53 #define REGISTER_OPTIONAL_COPY(DIRECTION) \
54 INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
55 OptionalVariant, DIRECTION, OptionalDeviceCopy)
56
57 REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
58 REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
59 REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);
60
61 REGISTER_UNARY_VARIANT_DECODE_FUNCTION(OptionalVariant,
62 kOptionalVariantTypeName);
63
64 } // namespace
65
Compute(OpKernelContext * ctx)66 void OptionalNoneOp::Compute(OpKernelContext* ctx) {
67 OP_REQUIRES_OK(ctx, WriteOptionalNoneToOutput(ctx, 0));
68 }
69
Compute(OpKernelContext * ctx)70 void OptionalFromValueOp::Compute(OpKernelContext* ctx) {
71 OpInputList components_input;
72 OP_REQUIRES_OK(ctx, ctx->input_list("components", &components_input));
73 std::vector<Tensor> components(components_input.begin(),
74 components_input.end());
75 OP_REQUIRES_OK(ctx,
76 WriteOptionalWithValueToOutput(ctx, 0, std::move(components)));
77 }
78
Compute(OpKernelContext * ctx)79 void OptionalHasValueOp::Compute(OpKernelContext* ctx) {
80 const Tensor* optional_input;
81 OP_REQUIRES_OK(ctx, ctx->input("optional", &optional_input));
82 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(optional_input->shape()),
83 errors::InvalidArgument(
84 "Input to OptionalHasValue must be a scalar tensor "
85 "containing an OptionalVariant object."));
86 const OptionalVariant* optional =
87 optional_input->scalar<Variant>()().get<OptionalVariant>();
88 OP_REQUIRES(
89 ctx, optional != nullptr,
90 errors::InvalidArgument(
91 "Input to OptionalHasValue must be an OptionalVariant object."));
92 Tensor* result;
93 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &result));
94 result->scalar<bool>()() = optional->has_value();
95 }
96
Compute(OpKernelContext * ctx)97 void OptionalGetValueOp::Compute(OpKernelContext* ctx) {
98 const Tensor* optional_input;
99 OP_REQUIRES_OK(ctx, ctx->input("optional", &optional_input));
100 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(optional_input->shape()),
101 errors::InvalidArgument(
102 "Input to OptionalHasValue must be a scalar tensor "
103 "containing an OptionalVariant object."));
104 const OptionalVariant* optional =
105 optional_input->scalar<Variant>()().get<OptionalVariant>();
106 OP_REQUIRES(
107 ctx, optional != nullptr,
108 errors::InvalidArgument(
109 "Input to OptionalHasValue must be an OptionalVariant object."));
110 OP_REQUIRES(
111 ctx, optional->has_value(),
112 errors::InvalidArgument("The given optional does not have a value."));
113 const auto& components = optional->get_values();
114 OP_REQUIRES(
115 ctx, components.size() == output_types_.size(),
116 errors::InvalidArgument("The given optional has ", components.size(),
117 " components, expected ", output_types_.size()));
118 for (int i = 0; i < components.size(); ++i) {
119 OP_REQUIRES(ctx, components[i].dtype() == output_types_[i],
120 errors::InvalidArgument(
121 "The given optional does not match the expected type for "
122 "component ",
123 i, ". Expected: ", DataTypeString(output_types_[i]),
124 ". Actual: ", DataTypeString(components[i].dtype()), "."));
125 OP_REQUIRES(ctx, output_shapes_[i].IsCompatibleWith(components[i].shape()),
126 errors::InvalidArgument(
127 "The given optional does not match the expected shape "
128 "for component ",
129 i, ". Expected: ", output_shapes_[i].DebugString(),
130 ". Actual: ", components[i].shape().DebugString(), "."));
131 ctx->set_output(i, components[i]);
132 }
133 }
134
WriteOptionalWithValueToOutput(OpKernelContext * ctx,int output_index,std::vector<Tensor> value)135 Status WriteOptionalWithValueToOutput(OpKernelContext* ctx, int output_index,
136 std::vector<Tensor> value) {
137 OptionalVariant v(std::move(value));
138 Tensor* variant_t;
139 AllocatorAttributes cpu_alloc;
140 cpu_alloc.set_on_host(true);
141 TF_RETURN_IF_ERROR(ctx->allocate_output(output_index, TensorShape({}),
142 &variant_t, cpu_alloc));
143 variant_t->scalar<Variant>()() = v;
144 return Status::OK();
145 }
146
WriteOptionalNoneToOutput(OpKernelContext * ctx,int output_index)147 Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index) {
148 OptionalVariant v;
149 Tensor* variant_t;
150 AllocatorAttributes cpu_alloc;
151 cpu_alloc.set_on_host(true);
152 TF_RETURN_IF_ERROR(ctx->allocate_output(output_index, TensorShape({}),
153 &variant_t, cpu_alloc));
154 variant_t->scalar<Variant>()() = v;
155 return Status::OK();
156 }
157
158 namespace {
159
160 REGISTER_KERNEL_BUILDER(Name("OptionalNone").Device(DEVICE_CPU).Priority(2),
161 OptionalNoneOp);
162 REGISTER_KERNEL_BUILDER(Name("OptionalNone").Device(DEVICE_GPU).Priority(1),
163 OptionalNoneOp);
164 REGISTER_KERNEL_BUILDER(
165 Name("OptionalFromValue").Device(DEVICE_CPU).Priority(2),
166 OptionalFromValueOp);
167 REGISTER_KERNEL_BUILDER(
168 Name("OptionalFromValue").Device(DEVICE_GPU).Priority(1),
169 OptionalFromValueOp);
170
171 REGISTER_KERNEL_BUILDER(Name("OptionalHasValue").Device(DEVICE_CPU).Priority(2),
172 OptionalHasValueOp);
173 REGISTER_KERNEL_BUILDER(Name("OptionalHasValue")
174 .Device(DEVICE_GPU)
175 .HostMemory("has_value")
176 .Priority(1),
177 OptionalHasValueOp);
178 REGISTER_KERNEL_BUILDER(Name("OptionalGetValue").Device(DEVICE_CPU).Priority(2),
179 OptionalGetValueOp);
180 REGISTER_KERNEL_BUILDER(Name("OptionalGetValue").Device(DEVICE_GPU).Priority(1),
181 OptionalGetValueOp);
182
183 } // namespace
184
185 REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
186 DEVICE_CPU, OptionalVariant,
187 OptionalZerosLike<CPUDevice>);
188
189 REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
190 OptionalVariant,
191 OptionalBinaryAdd<CPUDevice>);
192
193 } // namespace data
194 } // namespace tensorflow
195