/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op_kernel.h"

#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/version.h"

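// A no-op kernel used by the registration and builder tests below.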
class DummyKernel : public tensorflow::OpKernel {
 public:
  explicit DummyKernel(tensorflow::OpKernelConstruction* context)
      : OpKernel(context) {}
  void Compute(tensorflow::OpKernelContext* context) override {}
};

// Test that registration works outside a namespace.
REGISTER_OP("Test1").Input("a: float").Input("b: int32").Output("o: uint8");
REGISTER_KERNEL_BUILDER(Name("Test1").Device(tensorflow::DEVICE_CPU),
                        DummyKernel);

namespace foo {
bool match_signature_ = false;

// Test that registration works inside a different namespace.
class TestOp2 : public ::tensorflow::OpKernel {
 public:
  explicit TestOp2(::tensorflow::OpKernelConstruction* context)
      : OpKernel(context) {
    ::tensorflow::Status status = context->MatchSignature(
        {::tensorflow::DT_INT32}, {::tensorflow::DT_INT32});
    match_signature_ = status.ok();
    context->SetStatus(status);
  }
  void Compute(::tensorflow::OpKernelContext* context) override {}
};

REGISTER_OP("Test2").Input("i: T").Output("o: T").Attr("T: type");
REGISTER_KERNEL_BUILDER(Name("Test2")
                            .Device(::tensorflow::DEVICE_GPU)
                            .HostMemory("i")
                            .HostMemory("o"),
                        TestOp2);
}  // namespace foo

namespace tensorflow {

// Kernels for the same op registered on different device types.
REGISTER_OP("Test3").Input("a: T").Input("b: T").Attr("T: type");

class TestOp3Cpu : public tensorflow::OpKernel {
 public:
  explicit TestOp3Cpu(OpKernelConstruction* context) : OpKernel(context) {}
  void Compute(OpKernelContext* context) override {}
};

REGISTER_KERNEL_BUILDER(
    Name("Test3").Device(DEVICE_CPU).TypeConstraint<int8>("T"), TestOp3Cpu);

namespace {

class TestOp3Gpu : public tensorflow::OpKernel {
 public:
  explicit TestOp3Gpu(OpKernelConstruction* context) : OpKernel(context) {}
  void Compute(OpKernelContext* context) override {}
};

REGISTER_KERNEL_BUILDER(
    Name("Test3").Device(DEVICE_GPU).TypeConstraint<float>("T"), TestOp3Gpu);

// An op registered for both CPU and GPU.
REGISTER_OP("Test4").Input("i: float").Output("o: float");
REGISTER_KERNEL_BUILDER(Name("Test4").Device(DEVICE_CPU), DummyKernel);
REGISTER_KERNEL_BUILDER(Name("Test4").Device(DEVICE_GPU), DummyKernel);

// Kernels with different priorities.
REGISTER_OP("Test5").Input("a: T").Input("b: T").Attr("T: type");

class TestOp5Cpu : public tensorflow::OpKernel {
 public:
  explicit TestOp5Cpu(OpKernelConstruction* context) : OpKernel(context) {}
  void Compute(OpKernelContext* context) override {}
};

REGISTER_KERNEL_BUILDER(Name("Test5").Device(DEVICE_CPU).Priority(2),
                        TestOp5Cpu);

class TestOp5Gpu : public tensorflow::OpKernel {
 public:
  explicit TestOp5Gpu(OpKernelConstruction* context) : OpKernel(context) {}
  void Compute(OpKernelContext* context) override {}
};

REGISTER_KERNEL_BUILDER(Name("Test5").Device(DEVICE_GPU).Priority(1),
                        TestOp5Gpu);

static std::vector<DeviceType> DeviceTypes() {
  return {DeviceType(DEVICE_GPU), DeviceType(DEVICE_CPU)};
}

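// Test fixture whose helpers build a NodeDef for an op and then check that
// CreateOpKernel() either succeeds with the expected input/output types or
// fails with the expected error code.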
class OpKernelTest : public ::testing::Test {
 public:
  OpKernelTest() : device_(Env::Default()) {}

 protected:
  NodeDef CreateNodeDef(const string& op_type, const DataTypeVector& inputs) {
    NodeDefBuilder builder(op_type + "-op", op_type);
    for (DataType dt : inputs) {
      builder.Input(FakeInput(dt));
    }
    NodeDef node_def;
    TF_CHECK_OK(builder.Finalize(&node_def));
    return node_def;
  }

  void ExpectEqual(const string& what, const DataTypeVector& expected,
                   const DataTypeVector& observed) {
    EXPECT_EQ(expected.size(), observed.size()) << what;
    const size_t size = std::min(expected.size(), observed.size());
    for (size_t i = 0; i < size; ++i) {
      bool match = TypesCompatible(expected[i], observed[i]);
      EXPECT_TRUE(match) << what << " i:" << i << ", expected: " << expected[i]
                         << ", observed: " << observed[i];
    }
  }

  void ExpectSuccess(const string& op_type, DeviceType device_type,
                     const DataTypeVector& inputs,
                     const DataTypeVector& outputs) {
    Status status;
    std::unique_ptr<OpKernel> op(CreateOpKernel(
        std::move(device_type), &device_, cpu_allocator(),
        CreateNodeDef(op_type, inputs), TF_GRAPH_DEF_VERSION, &status));
    EXPECT_TRUE(status.ok()) << status;
    EXPECT_TRUE(op != nullptr);
    if (op != nullptr) {
      ExpectEqual("inputs", op->input_types(), inputs);
      ExpectEqual("outputs", op->output_types(), outputs);
    }
  }

  void ExpectFailure(const string& ascii_node_def, DeviceType device_type,
                     error::Code code) {
    NodeDef node_def;
    protobuf::TextFormat::ParseFromString(ascii_node_def, &node_def);
    Status status;
    std::unique_ptr<OpKernel> op(
        CreateOpKernel(std::move(device_type), &device_, cpu_allocator(),
                       node_def, TF_GRAPH_DEF_VERSION, &status));
    EXPECT_TRUE(op == nullptr);
    EXPECT_FALSE(status.ok());
    if (!status.ok()) {
      LOG(INFO) << "Status message: " << status.error_message();
      EXPECT_EQ(code, status.code());
    }
  }

 private:
  DeviceBase device_;
};

TEST_F(OpKernelTest, SuccessCpu) {
  ExpectSuccess("Test1", DEVICE_CPU, {DT_FLOAT, DT_INT32}, {DT_UINT8});
  ExpectSuccess("Test1", DEVICE_CPU, {DT_FLOAT_REF, DT_INT32}, {DT_UINT8});
}

TEST_F(OpKernelTest, SuccessGpu) {
  foo::match_signature_ = false;
  ExpectSuccess("Test2", DEVICE_GPU, {DT_INT32}, {DT_INT32});
  EXPECT_TRUE(foo::match_signature_);
}

TEST_F(OpKernelTest, SuccessBothCpuAndGpu) {
  ExpectSuccess("Test3", DEVICE_CPU, {DT_INT8, DT_INT8}, {});
  ExpectSuccess("Test3", DEVICE_GPU, {DT_FLOAT, DT_FLOAT}, {});
}

TEST_F(OpKernelTest, CpuTypeRegistered) {
  NodeDef ndef = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32});
  PrioritizedDeviceTypeVector devs;
  TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
  EXPECT_EQ(1, devs.size());
  EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0].first);
}

TEST_F(OpKernelTest, CpuAndGpuTypeRegistered) {
  {
    // Try a node def of an op that is registered for a specific type
    // only on CPU.
    NodeDef ndef = CreateNodeDef("Test3", {DT_INT8, DT_INT8});
    PrioritizedDeviceTypeVector devs;
    TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
    EXPECT_EQ(1, devs.size());
    EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0].first);
  }
  {
    // Try a node def of an op that is registered for a specific type
    // only on GPU.
    NodeDef ndef = CreateNodeDef("Test3", {DT_FLOAT, DT_FLOAT});
    PrioritizedDeviceTypeVector devs;
    TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
    EXPECT_EQ(1, devs.size());
    EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0].first);
  }
  {
    // Try a node def of an op that is only registered for other types.
    NodeDef ndef = CreateNodeDef("Test3", {DT_STRING, DT_STRING});
    PrioritizedDeviceTypeVector devs;
    TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
    EXPECT_EQ(0, devs.size());
  }

  {
    // Try a node def of an op that is registered for both.
    NodeDef ndef = CreateNodeDef("Test4", {DT_FLOAT});
    PrioritizedDeviceTypeVector devs;
    TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
    EXPECT_EQ(2, devs.size());
    EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0].first);
    EXPECT_EQ(DeviceType(DEVICE_CPU), devs[1].first);
  }

  {
    // Try a node def of an op where kernels have priorities.
    NodeDef ndef = CreateNodeDef("Test5", {DT_STRING, DT_STRING});
    PrioritizedDeviceTypeVector devs;
    TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
    EXPECT_EQ(2, devs.size());
    EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0].first);
    EXPECT_EQ(2, devs[0].second);
    EXPECT_EQ(DeviceType(DEVICE_GPU), devs[1].first);
    EXPECT_EQ(1, devs[1].second);
  }
}

TEST_F(OpKernelTest, NotFound) {
  const auto not_found = error::NOT_FOUND;
  // Something with that op type name exists, but only with a
  // different DeviceType.
  ExpectFailure(CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}).DebugString(),
                DEVICE_GPU, not_found);
  ExpectFailure(CreateNodeDef("Test3", {DT_INT8, DT_INT8}).DebugString(),
                DEVICE_GPU, not_found);
  ExpectFailure(CreateNodeDef("Test3", {DT_FLOAT, DT_FLOAT}).DebugString(),
                DEVICE_CPU, not_found);

  // No kernel with that signature registered.
  ExpectFailure(CreateNodeDef("Test3", {DT_INT32, DT_INT32}).DebugString(),
                DEVICE_GPU, not_found);

  // Nothing with that op type name exists.
  ExpectFailure("name: 'NF' op: 'Testnotfound'", DEVICE_CPU, not_found);
  ExpectFailure("name: 'NF' op: 'Testnotfound'", DEVICE_GPU, not_found);
}

TEST_F(OpKernelTest, TooFewInputs) {
  const auto invalid = error::INVALID_ARGUMENT;
  NodeDef node_def = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32});
  node_def.clear_input();
  ExpectFailure(node_def.DebugString(), DEVICE_CPU, invalid);
  node_def.add_input("a");
  ExpectFailure(node_def.DebugString(), DEVICE_CPU, invalid);
}

TEST_F(OpKernelTest, TooManyInputs) {
  const auto invalid = error::INVALID_ARGUMENT;
  NodeDef node_def = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32});
  node_def.add_input("c");
  ExpectFailure(node_def.DebugString(), DEVICE_CPU, invalid);
}

TEST_F(OpKernelTest, MatchSignatureFails) {
  const auto invalid = error::INVALID_ARGUMENT;
  foo::match_signature_ = true;
  ExpectFailure(CreateNodeDef("Test2", {DT_FLOAT}).DebugString(), DEVICE_GPU,
                invalid);
  EXPECT_FALSE(foo::match_signature_);
}

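// Minimal DeviceBase backed by the CPU allocator; `save` controls whether the
// device reports that accessed tensors must be recorded.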
class DummyDevice : public DeviceBase {
 public:
  DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {}
  bool RequiresRecordingAccessedTensors() const override { return save_; }
  Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
    return cpu_allocator();
  }

 private:
  bool save_;
};

TEST_F(OpKernelTest, SaveTempFalse) {
  Env* env = Env::Default();
  OpKernelContext::Params params;
  params.record_tensor_accesses = false;
  params.device = new DummyDevice(env, params.record_tensor_accesses);
  Status status;
  std::unique_ptr<OpKernel> op(
      CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(),
                     CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}),
                     TF_GRAPH_DEF_VERSION, &status));
  EXPECT_TRUE(status.ok());
  params.op_kernel = op.get();
  OpKernelContext* ctx = new OpKernelContext(&params);

  Tensor t;
  TF_EXPECT_OK(ctx->allocate_temp(DT_FLOAT, TensorShape(), &t));

  TensorReferenceVector referenced_tensors;
  ctx->retrieve_accessed_tensors(&referenced_tensors);
  EXPECT_EQ(0, referenced_tensors.size());

  delete ctx;
  delete params.device;
}

TEST_F(OpKernelTest, SaveTempTrue) {
  Env* env = Env::Default();
  OpKernelContext::Params params;
  params.record_tensor_accesses = true;
  params.device = new DummyDevice(env, params.record_tensor_accesses);
  Status status;
  std::unique_ptr<OpKernel> op(
      CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(),
                     CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}),
                     TF_GRAPH_DEF_VERSION, &status));
  EXPECT_TRUE(status.ok());
  params.op_kernel = op.get();
  OpKernelContext* ctx = new OpKernelContext(&params);

  Tensor t;
  TF_EXPECT_OK(ctx->allocate_temp(DT_FLOAT, TensorShape(), &t));

  TensorReferenceVector referenced_tensors;
  ctx->retrieve_accessed_tensors(&referenced_tensors);
  EXPECT_EQ(1, referenced_tensors.size());
  for (auto& ref : referenced_tensors) {
    ref.Unref();
  }

  delete ctx;
  delete params.device;
}

TEST_F(OpKernelTest, InputDtype) {
  Env* env = Env::Default();
  OpKernelContext::Params params;
  params.record_tensor_accesses = false;
  params.device = new DummyDevice(env, params.record_tensor_accesses);
  Status status;
  std::unique_ptr<OpKernel> op(
      CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(),
                     CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}),
                     TF_GRAPH_DEF_VERSION, &status));
  EXPECT_TRUE(status.ok());
  params.op_kernel = op.get();
  Tensor a(DT_FLOAT, TensorShape({}));
  Tensor b(DT_INT32, TensorShape({}));
  Tensor c(DT_UINT8, TensorShape({}));
  gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&a), TensorValue(&b),
                                            TensorValue(&c)};
  params.inputs = &inputs;
  OpKernelContext* ctx = new OpKernelContext(&params);

  DataType dtype;
  EXPECT_FALSE(ctx->input_dtype("non_existent_input", &dtype).ok());
  ASSERT_TRUE(ctx->input_dtype("a", &dtype).ok());
  EXPECT_EQ(dtype, DT_FLOAT);
  ASSERT_TRUE(ctx->input_dtype("b", &dtype).ok());
  EXPECT_EQ(dtype, DT_INT32);
  delete ctx;
  delete params.device;
}

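// Test fixture whose helpers build NodeDefs from attr description strings and
// exercise kernel lookup via CreateOpKernel(), SupportedDeviceTypesForNode(),
// and FindKernelDef().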
class OpKernelBuilderTest : public ::testing::Test {
 protected:
  // Each attr is described by a "name|type|value".
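  // e.g. "T|type|DT_FLOAT" or "a|int|35".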
  NodeDef CreateNodeDef(const string& op_type,
                        const std::vector<string>& attrs) {
    NodeDef node_def;
    node_def.set_name(op_type + "-op");
    node_def.set_op(op_type);
    for (const string& attr_desc : attrs) {
      std::vector<string> parts = str_util::Split(attr_desc, '|');
      CHECK_EQ(parts.size(), 3);
      AttrValue attr_value;
      CHECK(ParseAttrValue(parts[1], parts[2], &attr_value)) << attr_desc;
      node_def.mutable_attr()->insert(
          AttrValueMap::value_type(parts[0], attr_value));
    }
    return node_def;
  }

  std::unique_ptr<OpKernel> ExpectSuccess(const string& op_type,
                                          const DeviceType& device_type,
                                          const std::vector<string>& attrs,
                                          DataTypeSlice input_types = {}) {
    Status status;
    NodeDef def = CreateNodeDef(op_type, attrs);
    for (size_t i = 0; i < input_types.size(); ++i) {
      def.add_input("a:0");
    }

    Env* env = Env::Default();
    DeviceBase device(env);

    // Test CreateOpKernel().
    std::unique_ptr<OpKernel> op(CreateOpKernel(device_type, &device,
                                                cpu_allocator(), def,
                                                TF_GRAPH_DEF_VERSION, &status));
    EXPECT_TRUE(status.ok()) << status;
    EXPECT_TRUE(op != nullptr);
    if (op != nullptr) {
      EXPECT_EQ(input_types.size(), op->num_inputs());
      EXPECT_EQ(0, op->num_outputs());
    }

    // Test SupportedDeviceTypesForNode().
    PrioritizedDeviceTypeVector devices;
    TF_EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices));
    bool found = false;
    for (const auto& dt : devices) {
      if (dt.first == device_type) {
        found = true;
      }
    }
    EXPECT_TRUE(found) << "Missing " << device_type << " from "
                       << devices.size() << " devices.";

    // In case the caller wants to use the OpKernel.
    return op;
  }

  void ExpectFailure(const string& op_type, const DeviceType& device_type,
                     const std::vector<string>& attrs, error::Code code) {
    Status status;
    const NodeDef def = CreateNodeDef(op_type, attrs);
    Env* env = Env::Default();
    DeviceBase device(env);

    // Test CreateOpKernel().
    std::unique_ptr<OpKernel> op(CreateOpKernel(device_type, &device,
                                                cpu_allocator(), def,
                                                TF_GRAPH_DEF_VERSION, &status));
    EXPECT_TRUE(op == nullptr);
    EXPECT_FALSE(status.ok());
    if (!status.ok()) {
      LOG(INFO) << "Status message: " << status.error_message();
      EXPECT_EQ(code, status.code());

      // Test SupportedDeviceTypesForNode().
      PrioritizedDeviceTypeVector devices;
      if (errors::IsNotFound(status)) {
        TF_EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices));
        for (const auto& dt : devices) {
          EXPECT_NE(dt.first, device_type);
        }
      } else {
        Status status2 =
            SupportedDeviceTypesForNode(DeviceTypes(), def, &devices);
        EXPECT_EQ(status.code(), status2.code());
      }
    }
  }

  string GetKernelClassName(const string& op_type,
                            const DeviceType& device_type,
                            const std::vector<string>& attrs,
                            DataTypeSlice input_types = {}) {
    NodeDef def = CreateNodeDef(op_type, attrs);
    for (size_t i = 0; i < input_types.size(); ++i) {
      def.add_input("a:0");
    }

    const KernelDef* kernel_def = nullptr;
    string kernel_class_name;
    const Status status =
        FindKernelDef(device_type, def, &kernel_def, &kernel_class_name);
    if (status.ok()) {
      return kernel_class_name;
    } else if (errors::IsNotFound(status)) {
      return "not found";
    } else {
      return status.ToString();
    }
  }
};

REGISTER_OP("BuildCPU");
REGISTER_KERNEL_BUILDER(Name("BuildCPU").Device(DEVICE_CPU), DummyKernel);

TEST_F(OpKernelBuilderTest, BuilderCPU) {
  ExpectSuccess("BuildCPU", DEVICE_CPU, {});
  EXPECT_EQ("DummyKernel", GetKernelClassName("BuildCPU", DEVICE_CPU, {}));
  ExpectFailure("BuildCPU", DEVICE_GPU, {}, error::NOT_FOUND);
  EXPECT_EQ("not found", GetKernelClassName("BuildCPU", DEVICE_GPU, {}));
}

REGISTER_OP("BuildGPU");
REGISTER_KERNEL_BUILDER(Name("BuildGPU").Device(DEVICE_GPU), DummyKernel);

TEST_F(OpKernelBuilderTest, BuilderGPU) {
  ExpectFailure("BuildGPU", DEVICE_CPU, {}, error::NOT_FOUND);
  ExpectSuccess("BuildGPU", DEVICE_GPU, {});
}

REGISTER_OP("BuildBoth");
REGISTER_KERNEL_BUILDER(Name("BuildBoth").Device(DEVICE_CPU), DummyKernel);
REGISTER_KERNEL_BUILDER(Name("BuildBoth").Device(DEVICE_GPU), DummyKernel);

TEST_F(OpKernelBuilderTest, BuilderBoth) {
  ExpectSuccess("BuildBoth", DEVICE_CPU, {});
  ExpectSuccess("BuildBoth", DEVICE_GPU, {});
}

REGISTER_OP("BuildTypeAttr").Attr("T: type");
REGISTER_KERNEL_BUILDER(
    Name("BuildTypeAttr").Device(DEVICE_CPU).TypeConstraint<float>("T"),
    DummyKernel);

TEST_F(OpKernelBuilderTest, BuilderTypeAttr) {
  ExpectSuccess("BuildTypeAttr", DEVICE_CPU, {"T|type|DT_FLOAT"});
  ExpectFailure("BuildTypeAttr", DEVICE_CPU, {"T|type|DT_BOOL"},
                error::NOT_FOUND);
  ExpectFailure("BuildTypeAttr", DEVICE_CPU, {}, error::INVALID_ARGUMENT);
  ExpectFailure("BuildTypeAttr", DEVICE_CPU, {"T|int|7"},
                error::INVALID_ARGUMENT);
}

REGISTER_OP("BuildTypeListAttr").Attr("T: list(type)");
REGISTER_KERNEL_BUILDER(
    Name("BuildTypeListAttr").Device(DEVICE_CPU).TypeConstraint<bool>("T"),
    DummyKernel);

TEST_F(OpKernelBuilderTest, BuilderTypeListAttr) {
  ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[]"});
  EXPECT_EQ("DummyKernel", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU,
                                              {"T|list(type)|[]"}));

  ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_BOOL]"});
  EXPECT_EQ("DummyKernel", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU,
                                              {"T|list(type)|[DT_BOOL]"}));

  ExpectSuccess("BuildTypeListAttr", DEVICE_CPU,
                {"T|list(type)|[DT_BOOL, DT_BOOL]"});

  ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_FLOAT]"},
                error::NOT_FOUND);
  EXPECT_EQ("not found", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU,
                                            {"T|list(type)|[DT_FLOAT]"}));

  ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {}, error::INVALID_ARGUMENT);
  EXPECT_TRUE(str_util::StrContains(
      GetKernelClassName("BuildTypeListAttr", DEVICE_CPU, {}),
      "Invalid argument: "));

  ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|int|7"},
                error::INVALID_ARGUMENT);
}

REGISTER_OP("DuplicateKernel");
REGISTER_KERNEL_BUILDER(Name("DuplicateKernel").Device(DEVICE_CPU),
                        DummyKernel);
REGISTER_KERNEL_BUILDER(Name("DuplicateKernel").Device(DEVICE_CPU),
                        DummyKernel);

TEST_F(OpKernelBuilderTest, DuplicateKernel) {
  const NodeDef ndef = CreateNodeDef("DuplicateKernel", {});
  PrioritizedDeviceTypeVector devs;
  Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
  ASSERT_FALSE(status.ok());
  EXPECT_TRUE(str_util::StrContains(
      status.error_message(), "Multiple OpKernel registrations match NodeDef"));

  ExpectFailure("DuplicateKernel", DEVICE_CPU, {}, error::INVALID_ARGUMENT);
}

REGISTER_OP("DuplicateKernelForT").Attr("T: type");
REGISTER_KERNEL_BUILDER(
    Name("DuplicateKernelForT").Device(DEVICE_CPU).TypeConstraint<float>("T"),
    DummyKernel);
REGISTER_KERNEL_BUILDER(
    Name("DuplicateKernelForT").Device(DEVICE_CPU).TypeConstraint<float>("T"),
    DummyKernel);

TEST_F(OpKernelBuilderTest, DuplicateKernelForT) {
  const NodeDef ndef =
      CreateNodeDef("DuplicateKernelForT", {"T|type|DT_FLOAT"});
  PrioritizedDeviceTypeVector devs;
  Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
  ASSERT_FALSE(status.ok());
  EXPECT_TRUE(str_util::StrContains(
      status.error_message(), "Multiple OpKernel registrations match NodeDef"));

  ExpectFailure("DuplicateKernelForT", DEVICE_CPU, {"T|type|DT_FLOAT"},
                error::INVALID_ARGUMENT);
  ExpectFailure("DuplicateKernelForT", DEVICE_CPU, {"T|type|DT_BOOL"},
                error::NOT_FOUND);
}

REGISTER_OP("BadConstraint").Attr("dtype: type");
REGISTER_KERNEL_BUILDER(Name("BadConstraint")
                            .Device(DEVICE_CPU)
                            // Mistake: "T" should be "dtype".
                            .TypeConstraint<float>("T"),
                        DummyKernel);

TEST_F(OpKernelBuilderTest, BadConstraint) {
  const NodeDef ndef = CreateNodeDef("BadConstraint", {});
  PrioritizedDeviceTypeVector devs;
  Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
  ASSERT_FALSE(status.ok());
  EXPECT_TRUE(
      str_util::StrContains(status.error_message(),
                            "OpKernel 'BadConstraint' has constraint on attr "
                            "'T' not in NodeDef"));

  ExpectFailure("BadConstraint", DEVICE_CPU, {"dtype|type|DT_FLOAT"},
                error::INVALID_ARGUMENT);
}

REGISTER_OP("ListOut").Output("a: int32").Output("b: T").Attr("T: list(type)");
REGISTER_KERNEL_BUILDER(Name("ListOut").Device(tensorflow::DEVICE_CPU),
                        DummyKernel);

TEST_F(OpKernelBuilderTest, OpOutputList) {
  Env* env = Env::Default();
  OpKernelContext::Params params;
  params.record_tensor_accesses = false;
  std::unique_ptr<DummyDevice> device(
      new DummyDevice(env, params.record_tensor_accesses));
  params.device = device.get();
  Status status;
  std::unique_ptr<OpKernel> op(CreateOpKernel(
      DEVICE_CPU, params.device, cpu_allocator(),
      CreateNodeDef("ListOut", {"T|list(type)|[DT_FLOAT, DT_INT32]"}),
      TF_GRAPH_DEF_VERSION, &status));
  EXPECT_TRUE(status.ok()) << status.ToString();
  params.op_kernel = op.get();
  gtl::InlinedVector<TensorValue, 4> inputs{};
  params.inputs = &inputs;
  std::unique_ptr<OpKernelContext> ctx(new OpKernelContext(&params));

  EXPECT_EQ(DT_INT32, ctx->expected_output_dtype(0));
  OpOutputList out_list;
  EXPECT_FALSE(ctx->output_list("non_existent_output", &out_list).ok());
  ASSERT_TRUE(ctx->output_list("b", &out_list).ok());
  EXPECT_EQ(DT_FLOAT, out_list.expected_output_dtype(0));
  EXPECT_EQ(DT_INT32, out_list.expected_output_dtype(1));
}

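// Kernel whose constructor calls GetAttr() on the attr named by "attr_name"
// once for every supported C++ destination type and records each resulting
// Status, so tests can check which conversions succeed.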
class GetAttrKernel : public ::tensorflow::OpKernel {
 public:
  explicit GetAttrKernel(OpKernelConstruction* context) : OpKernel(context) {
    string attr_name;
    OP_REQUIRES_OK(context, context->GetAttr("attr_name", &attr_name));

    status.emplace_back("s", context->GetAttr(attr_name, &s));
    status.emplace_back("s_list", context->GetAttr(attr_name, &s_list));
    status.emplace_back("i", context->GetAttr(attr_name, &i));
    status.emplace_back("i_list", context->GetAttr(attr_name, &i_list));
    status.emplace_back("i32", context->GetAttr(attr_name, &i32));
    status.emplace_back("i32_list", context->GetAttr(attr_name, &i32_list));
    status.emplace_back("f", context->GetAttr(attr_name, &f));
    status.emplace_back("f_list", context->GetAttr(attr_name, &f_list));
    status.emplace_back("b", context->GetAttr(attr_name, &b));
    status.emplace_back("b_list", context->GetAttr(attr_name, &b_list));
    status.emplace_back("type", context->GetAttr(attr_name, &type));
    status.emplace_back("type_list", context->GetAttr(attr_name, &type_list));
    status.emplace_back("type_vector",
                        context->GetAttr(attr_name, &type_vector));
    status.emplace_back("shape_proto",
                        context->GetAttr(attr_name, &shape_proto));
    status.emplace_back("shape_proto_list",
                        context->GetAttr(attr_name, &shape_proto_list));
    status.emplace_back("shape", context->GetAttr(attr_name, &shape));
    status.emplace_back("shape_list", context->GetAttr(attr_name, &shape_list));
  }
  void Compute(::tensorflow::OpKernelContext* context) override {}

  void ExpectOk(std::initializer_list<string> keys) {
    for (const auto& key_status : status) {
      // Only the status for keys in "keys" should be ok().
      bool in_keys = false;
      for (const string& key : keys) {
        if (key_status.first == key) {
          in_keys = true;
        }
      }
      EXPECT_EQ(in_keys, key_status.second.ok())
          << "key_status: " << key_status.first << ", " << key_status.second;
    }
  }

  string s;
  std::vector<string> s_list;
  int64 i;
  std::vector<int64> i_list;
  int32 i32;
  std::vector<int32> i32_list;
  float f;
  std::vector<float> f_list;
  bool b;
  std::vector<bool> b_list;
  DataType type;
  std::vector<DataType> type_list;
  DataTypeVector type_vector;
  TensorShapeProto shape_proto;
  std::vector<TensorShapeProto> shape_proto_list;
  TensorShape shape;
  std::vector<TensorShape> shape_list;
  std::vector<std::pair<string, Status>> status;
};

class GetAttrTest : public OpKernelBuilderTest {};

REGISTER_OP("GetAttrStringList")
    .Attr("attr_name: string")
    .Attr("a: list(string)");
REGISTER_KERNEL_BUILDER(Name("GetAttrStringList").Device(DEVICE_CPU),
                        GetAttrKernel);

TEST_F(GetAttrTest, StringList) {
  std::unique_ptr<OpKernel> op_kernel =
      ExpectSuccess("GetAttrStringList", DEVICE_CPU,
                    {"attr_name|string|'a'", "a|list(string)|['foo', 'bar']"});
  auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
  get_attr_kernel->ExpectOk({"s_list"});
  EXPECT_EQ(std::vector<string>({"foo", "bar"}), get_attr_kernel->s_list);

  op_kernel = ExpectSuccess("GetAttrStringList", DEVICE_CPU,
                            {"attr_name|string|'b'", "a|list(string)|['baz']"});
  get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
  get_attr_kernel->ExpectOk({});
  EXPECT_TRUE(get_attr_kernel->s_list.empty());
}

REGISTER_OP("GetAttrInt")
    .Attr("attr_name: string")
    .Attr("a: int")
    .Attr("b: list(int)");
REGISTER_KERNEL_BUILDER(Name("GetAttrInt").Device(DEVICE_CPU), GetAttrKernel);

TEST_F(GetAttrTest, Int) {
  std::unique_ptr<OpKernel> op_kernel = ExpectSuccess(
      "GetAttrInt", DEVICE_CPU,
      {"attr_name|string|'a'", "a|int|35", "b|list(int)|[-1, 2, -4]"});
  auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
  get_attr_kernel->ExpectOk({"i", "i32"});
  EXPECT_EQ(35, get_attr_kernel->i);
  EXPECT_EQ(35, get_attr_kernel->i32);

  op_kernel = ExpectSuccess(
      "GetAttrInt", DEVICE_CPU,
      {"attr_name|string|'b'", "a|int|35", "b|list(int)|[-1, 2, -4]"});
  get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
  get_attr_kernel->ExpectOk({"i_list", "i32_list"});
  EXPECT_EQ(std::vector<int64>({-1, 2, -4}), get_attr_kernel->i_list);
  EXPECT_EQ(std::vector<int32>({-1, 2, -4}), get_attr_kernel->i32_list);

  // 8589934592 == 2^33, too big to fit in an int32
  op_kernel = ExpectSuccess("GetAttrInt", DEVICE_CPU,
                            {"attr_name|string|'a'", "a|int|8589934592",
                             "b|list(int)|[-8589934592]"});
  get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
  get_attr_kernel->ExpectOk({"i"});  // no i32
  EXPECT_EQ(8589934592ll, get_attr_kernel->i);
  for (const auto& key_status : get_attr_kernel->status) {
    if (key_status.first == "i32") {
      EXPECT_EQ(error::INVALID_ARGUMENT, key_status.second.code());
      EXPECT_EQ("Attr a has value 8589934592 out of range for an int32",
                key_status.second.error_message());
    }
  }

  op_kernel = ExpectSuccess("GetAttrInt", DEVICE_CPU,
                            {"attr_name|string|'b'", "a|int|8589934592",
                             "b|list(int)|[-8589934592]"});
  get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
  get_attr_kernel->ExpectOk({"i_list"});  // no i32_list
  EXPECT_EQ(std::vector<int64>({-8589934592ll}), get_attr_kernel->i_list);
  for (const auto& key_status : get_attr_kernel->status) {
    if (key_status.first == "i32_list") {
      EXPECT_EQ(error::INVALID_ARGUMENT, key_status.second.code());
      EXPECT_EQ("Attr b has value -8589934592 out of range for an int32",
                key_status.second.error_message());
    }
  }
}

REGISTER_OP("GetAttrShape")
    .Attr("attr_name: string")
    .Attr("a: shape")
    .Attr("b: list(shape)");
REGISTER_KERNEL_BUILDER(Name("GetAttrShape").Device(DEVICE_CPU), GetAttrKernel);

TEST_F(GetAttrTest, Shape) {
  std::unique_ptr<OpKernel> op_kernel = ExpectSuccess(
      "GetAttrShape", DEVICE_CPU,
      {"attr_name|string|'a'", "a|shape|{ dim { size: 3 } }",
       "b|list(shape)|[{ dim { size:2 } }, { dim { size: 4 } }]"});
  auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
  get_attr_kernel->ExpectOk({"shape", "shape_proto"});
  EXPECT_EQ(get_attr_kernel->shape_proto.ShortDebugString(), "dim { size: 3 }");
  EXPECT_EQ("[3]", get_attr_kernel->shape.DebugString());

  op_kernel = ExpectSuccess(
      "GetAttrShape", DEVICE_CPU,
      {"attr_name|string|'b'", "a|shape|{ dim { size: 3 } }",
       "b|list(shape)|[{ dim { size:2 } }, { dim { size: 4 } }]"});
  get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
  get_attr_kernel->ExpectOk({"shape_list", "shape_proto_list"});
  ASSERT_EQ(2, get_attr_kernel->shape_proto_list.size());
  EXPECT_EQ(get_attr_kernel->shape_proto_list[0].ShortDebugString(),
            "dim { size: 2 }");
  EXPECT_EQ(get_attr_kernel->shape_proto_list[1].ShortDebugString(),
            "dim { size: 4 }");
  ASSERT_EQ(2, get_attr_kernel->shape_list.size());
  EXPECT_EQ("[2]", get_attr_kernel->shape_list[0].DebugString());
  EXPECT_EQ("[4]", get_attr_kernel->shape_list[1].DebugString());
}

REGISTER_OP("GetAttrType").Attr("attr_name: string").Attr("a: type");
REGISTER_KERNEL_BUILDER(Name("GetAttrType").Device(DEVICE_CPU), GetAttrKernel);

TEST_F(GetAttrTest, Type) {
  std::unique_ptr<OpKernel> op_kernel = ExpectSuccess(
      "GetAttrType", DEVICE_CPU, {"attr_name|string|'a'", "a|type|DT_FLOAT"});
  auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
  get_attr_kernel->ExpectOk({"type"});
  EXPECT_EQ(DT_FLOAT, get_attr_kernel->type);
}

REGISTER_OP("GetAttrTypeList").Attr("attr_name: string").Attr("a: list(type)");
REGISTER_KERNEL_BUILDER(Name("GetAttrTypeList").Device(DEVICE_CPU),
                        GetAttrKernel);

TEST_F(GetAttrTest, TypeList) {
  std::unique_ptr<OpKernel> op_kernel = ExpectSuccess(
      "GetAttrTypeList", DEVICE_CPU,
      {"attr_name|string|'a'", "a|list(type)|[DT_INT32, DT_BOOL]"});
  auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());

  get_attr_kernel->ExpectOk({"type_list", "type_vector"});
  ASSERT_EQ(2, get_attr_kernel->type_list.size());
  EXPECT_EQ(DT_INT32, get_attr_kernel->type_list[0]);
  EXPECT_EQ(DT_BOOL, get_attr_kernel->type_list[1]);
  ASSERT_EQ(2, get_attr_kernel->type_vector.size());
  EXPECT_EQ(DT_INT32, get_attr_kernel->type_vector[0]);
  EXPECT_EQ(DT_BOOL, get_attr_kernel->type_vector[1]);
}

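// Kernels distinguished by an integer template parameter, used to verify that
// a Label() on the registration (selected via the "_kernel" attr) picks the
// intended kernel class.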
class BaseKernel : public ::tensorflow::OpKernel {
 public:
  explicit BaseKernel(OpKernelConstruction* context) : OpKernel(context) {}
  void Compute(::tensorflow::OpKernelContext* context) override {}
  virtual int Which() const = 0;
};

template <int WHICH>
class LabeledKernel : public BaseKernel {
 public:
  using BaseKernel::BaseKernel;
  int Which() const override { return WHICH; }
};

class LabelTest : public OpKernelBuilderTest {};

REGISTER_OP("LabeledKernel");
REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU),
                        LabeledKernel<0>);
REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU).Label("one"),
                        LabeledKernel<1>);
REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU).Label("dupe"),
                        LabeledKernel<2>);
REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU).Label("dupe"),
                        LabeledKernel<3>);

TEST_F(LabelTest, Default) {
  std::unique_ptr<OpKernel> op_kernel =
      ExpectSuccess("LabeledKernel", DEVICE_CPU, {});
  auto* get_labeled_kernel = static_cast<BaseKernel*>(op_kernel.get());
  EXPECT_EQ(0, get_labeled_kernel->Which());

  EXPECT_EQ("LabeledKernel<0>",
            GetKernelClassName("LabeledKernel", DEVICE_CPU, {}));
}

TEST_F(LabelTest, Specified) {
  std::unique_ptr<OpKernel> op_kernel =
      ExpectSuccess("LabeledKernel", DEVICE_CPU, {"_kernel|string|'one'"});
  auto* get_labeled_kernel = static_cast<BaseKernel*>(op_kernel.get());
  EXPECT_EQ(1, get_labeled_kernel->Which());
  EXPECT_EQ("LabeledKernel<1>", GetKernelClassName("LabeledKernel", DEVICE_CPU,
                                                   {"_kernel|string|'one'"}));
}

TEST_F(LabelTest, Duplicate) {
  ExpectFailure("LabeledKernel", DEVICE_CPU, {"_kernel|string|'dupe'"},
                error::INVALID_ARGUMENT);
}

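// Shared benchmark body: builds a kernel from `node_def` and times repeated
// InputRange() lookups for `input_name`, checking the returned range.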
void BM_InputRangeHelper(int iters, const NodeDef& node_def,
                         const char* input_name, int expected_start,
                         int expected_stop) {
  Status status;
  std::unique_ptr<DummyDevice> device(new DummyDevice(Env::Default(), false));

  std::unique_ptr<OpKernel> op(CreateOpKernel(DEVICE_CPU, device.get(),
                                              cpu_allocator(), node_def,
                                              TF_GRAPH_DEF_VERSION, &status));
  TF_CHECK_OK(status);

  testing::StartTiming();
  for (int i = 0; i < iters; ++i) {
    int start;
    int stop;
    TF_CHECK_OK(op->InputRange(input_name, &start, &stop));
    EXPECT_EQ(expected_start, start);
    EXPECT_EQ(expected_stop, stop);
  }
  testing::StopTiming();
}

REGISTER_KERNEL_BUILDER(Name("ConcatV2").Device(DEVICE_CPU), DummyKernel);
REGISTER_KERNEL_BUILDER(Name("Select").Device(DEVICE_CPU), DummyKernel);

void BM_ConcatInputRange(int iters) {
  testing::StopTiming();

  // Create a ConcatV2 NodeDef with 4 inputs (plus the axis).
  NodeDef node_def;
  node_def.set_name("concat-op");
  node_def.set_op("ConcatV2");
  AttrValue attr_N;
  attr_N.set_i(4);
  AttrValue attr_T;
  attr_T.set_type(DT_FLOAT);
  AttrValue attr_Tidx;
  attr_Tidx.set_type(DT_INT32);
  node_def.mutable_attr()->insert({"N", attr_N});
  node_def.mutable_attr()->insert({"T", attr_T});
  node_def.mutable_attr()->insert({"Tidx", attr_Tidx});
  for (size_t i = 0; i < 5; ++i) {
    node_def.add_input(strings::StrCat("a:", i));
  }

  BM_InputRangeHelper(iters, node_def, "values", 0, 4);
}

void BM_SelectInputRange(int iters) {
  testing::StopTiming();

  // Create a Select NodeDef with 3 inputs.
  NodeDef node_def;
  node_def.set_name("select-op");
  node_def.set_op("Select");
  AttrValue attr_T;
  attr_T.set_type(DT_FLOAT);
  node_def.mutable_attr()->insert({"T", attr_T});
  for (size_t i = 0; i < 3; ++i) {
    node_def.add_input(strings::StrCat("a:", i));
  }

  BM_InputRangeHelper(iters, node_def, "condition", 0, 1);
}

BENCHMARK(BM_ConcatInputRange);
BENCHMARK(BM_SelectInputRange);

TEST(RegisteredKernels, CanCallGetAllRegisteredKernels) {
  auto kernel_list = GetAllRegisteredKernels();
  auto all_registered_kernels = kernel_list.kernel();
  auto has_name_test1 = [](const KernelDef& k) { return k.op() == "Test1"; };

  // Verify we can find the "Test1" op registered above.
  auto test1_it = std::find_if(all_registered_kernels.begin(),
                               all_registered_kernels.end(), has_name_test1);
  ASSERT_NE(test1_it, all_registered_kernels.end());
  EXPECT_EQ(test1_it->device_type(), "CPU");

  // Verify there was just one kernel.
  ++test1_it;
  EXPECT_EQ(
      std::find_if(test1_it, all_registered_kernels.end(), has_name_test1),
      all_registered_kernels.end());
}

// Simple test just to check we can call LogAllRegisteredKernels.
TEST(RegisteredKernels, CanLogAllRegisteredKernels) {
  tensorflow::LogAllRegisteredKernels();
}

TEST(RegisteredKernels, GetFilteredRegisteredKernels) {
  auto has_name_test1 = [](const KernelDef& k) { return k.op() == "Test1"; };
  auto kernel_list = GetFilteredRegisteredKernels(has_name_test1);
  ASSERT_EQ(kernel_list.kernel_size(), 1);
  EXPECT_EQ(kernel_list.kernel(0).op(), "Test1");
  EXPECT_EQ(kernel_list.kernel(0).device_type(), "CPU");
}

TEST(RegisteredKernels, GetRegisteredKernelsForOp) {
  auto kernel_list = GetRegisteredKernelsForOp("Test1");
  ASSERT_EQ(kernel_list.kernel_size(), 1);
  EXPECT_EQ(kernel_list.kernel(0).op(), "Test1");
  EXPECT_EQ(kernel_list.kernel(0).device_type(), "CPU");
}

}  // namespace
}  // namespace tensorflow