1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/op_kernel.h"
17 
18 #include <memory>
19 #include <utility>
20 #include <vector>
21 #include "tensorflow/core/framework/allocator.h"
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/framework/attr_value_util.h"
24 #include "tensorflow/core/framework/fake_input.h"
25 #include "tensorflow/core/framework/node_def_builder.h"
26 #include "tensorflow/core/framework/op.h"
27 #include "tensorflow/core/framework/tensor_shape.pb.h"
28 #include "tensorflow/core/framework/types.pb.h"
29 #include "tensorflow/core/lib/core/errors.h"
30 #include "tensorflow/core/lib/core/status_test_util.h"
31 #include "tensorflow/core/lib/strings/str_util.h"
32 #include "tensorflow/core/lib/strings/strcat.h"
33 #include "tensorflow/core/platform/logging.h"
34 #include "tensorflow/core/platform/protobuf.h"
35 #include "tensorflow/core/platform/test.h"
36 #include "tensorflow/core/platform/test_benchmark.h"
37 #include "tensorflow/core/public/version.h"
38 
39 class DummyKernel : public tensorflow::OpKernel {
40  public:
DummyKernel(tensorflow::OpKernelConstruction * context)41   explicit DummyKernel(tensorflow::OpKernelConstruction* context)
42       : OpKernel(context) {}
Compute(tensorflow::OpKernelContext * context)43   void Compute(tensorflow::OpKernelContext* context) override {}
44 };
45 
46 // Test that registration works outside a namespace.
47 REGISTER_OP("Test1").Input("a: float").Input("b: int32").Output("o: uint8");
48 REGISTER_KERNEL_BUILDER(Name("Test1").Device(tensorflow::DEVICE_CPU),
49                         DummyKernel);
50 
51 namespace foo {
52 bool match_signature_ = false;
53 
54 // Test that registration works inside a different namespace.
55 class TestOp2 : public ::tensorflow::OpKernel {
56  public:
TestOp2(::tensorflow::OpKernelConstruction * context)57   explicit TestOp2(::tensorflow::OpKernelConstruction* context)
58       : OpKernel(context) {
59     ::tensorflow::Status status = context->MatchSignature(
60         {::tensorflow::DT_INT32}, {::tensorflow::DT_INT32});
61     match_signature_ = status.ok();
62     context->SetStatus(status);
63   }
Compute(::tensorflow::OpKernelContext * context)64   void Compute(::tensorflow::OpKernelContext* context) override {}
65 };
66 
67 REGISTER_OP("Test2").Input("i: T").Output("o: T").Attr("T: type");
68 REGISTER_KERNEL_BUILDER(Name("Test2")
69                             .Device(::tensorflow::DEVICE_GPU)
70                             .HostMemory("i")
71                             .HostMemory("o"),
72                         TestOp2);
73 }  // namespace foo
74 
75 namespace tensorflow {
76 
77 // Two operations with the same name but different devices.
78 REGISTER_OP("Test3").Input("a: T").Input("b: T").Attr("T: type");
79 
80 class TestOp3Cpu : public tensorflow::OpKernel {
81  public:
TestOp3Cpu(OpKernelConstruction * context)82   explicit TestOp3Cpu(OpKernelConstruction* context) : OpKernel(context) {}
Compute(OpKernelContext * context)83   void Compute(OpKernelContext* context) override {}
84 };
85 
86 REGISTER_KERNEL_BUILDER(
87     Name("Test3").Device(DEVICE_CPU).TypeConstraint<int8>("T"), TestOp3Cpu);
88 
89 namespace {
90 
91 class TestOp3Gpu : public tensorflow::OpKernel {
92  public:
TestOp3Gpu(OpKernelConstruction * context)93   explicit TestOp3Gpu(OpKernelConstruction* context) : OpKernel(context) {}
Compute(OpKernelContext * context)94   void Compute(OpKernelContext* context) override {}
95 };
96 
97 REGISTER_KERNEL_BUILDER(
98     Name("Test3").Device(DEVICE_GPU).TypeConstraint<float>("T"), TestOp3Cpu);
99 
100 // An Op registered for both
101 REGISTER_OP("Test4").Input("i: float").Output("o: float");
102 REGISTER_KERNEL_BUILDER(Name("Test4").Device(DEVICE_CPU), DummyKernel);
103 REGISTER_KERNEL_BUILDER(Name("Test4").Device(DEVICE_GPU), DummyKernel);
104 
105 // Kernels with different priorities.
106 REGISTER_OP("Test5").Input("a: T").Input("b: T").Attr("T: type");
107 
108 class TestOp5Cpu : public tensorflow::OpKernel {
109  public:
TestOp5Cpu(OpKernelConstruction * context)110   explicit TestOp5Cpu(OpKernelConstruction* context) : OpKernel(context) {}
Compute(OpKernelContext * context)111   void Compute(OpKernelContext* context) override {}
112 };
113 
114 REGISTER_KERNEL_BUILDER(Name("Test5").Device(DEVICE_CPU).Priority(2),
115                         TestOp5Cpu);
116 
117 class TestOp5Gpu : public tensorflow::OpKernel {
118  public:
TestOp5Gpu(OpKernelConstruction * context)119   explicit TestOp5Gpu(OpKernelConstruction* context) : OpKernel(context) {}
Compute(OpKernelContext * context)120   void Compute(OpKernelContext* context) override {}
121 };
122 
123 REGISTER_KERNEL_BUILDER(Name("Test5").Device(DEVICE_GPU).Priority(1),
124                         TestOp5Gpu);
125 
DeviceTypes()126 static std::vector<DeviceType> DeviceTypes() {
127   return {DeviceType(DEVICE_GPU), DeviceType(DEVICE_CPU)};
128 }
129 
130 class OpKernelTest : public ::testing::Test {
131  public:
OpKernelTest()132   OpKernelTest() : device_(Env::Default()) {}
133 
134  protected:
CreateNodeDef(const string & op_type,const DataTypeVector & inputs)135   NodeDef CreateNodeDef(const string& op_type, const DataTypeVector& inputs) {
136     NodeDefBuilder builder(op_type + "-op", op_type);
137     for (DataType dt : inputs) {
138       builder.Input(FakeInput(dt));
139     }
140     NodeDef node_def;
141     TF_CHECK_OK(builder.Finalize(&node_def));
142     return node_def;
143   }
144 
ExpectEqual(const string & what,const DataTypeVector & expected,const DataTypeVector & observed)145   void ExpectEqual(const string& what, const DataTypeVector& expected,
146                    const DataTypeVector& observed) {
147     EXPECT_EQ(expected.size(), observed.size()) << what;
148     const size_t size = std::min(expected.size(), observed.size());
149     for (size_t i = 0; i < size; ++i) {
150       bool match = TypesCompatible(expected[i], observed[i]);
151       EXPECT_TRUE(match) << what << " i:" << i << ", expected: " << expected[i]
152                          << ", observed: " << observed[i];
153     }
154   }
155 
ExpectSuccess(const string & op_type,DeviceType device_type,const DataTypeVector & inputs,const DataTypeVector & outputs)156   void ExpectSuccess(const string& op_type, DeviceType device_type,
157                      const DataTypeVector& inputs,
158                      const DataTypeVector& outputs) {
159     Status status;
160     std::unique_ptr<OpKernel> op(CreateOpKernel(
161         std::move(device_type), &device_, cpu_allocator(),
162         CreateNodeDef(op_type, inputs), TF_GRAPH_DEF_VERSION, &status));
163     EXPECT_TRUE(status.ok()) << status;
164     EXPECT_TRUE(op != nullptr);
165     if (op != nullptr) {
166       ExpectEqual("inputs", op->input_types(), inputs);
167       ExpectEqual("outputs", op->output_types(), outputs);
168     }
169   }
170 
ExpectFailure(const string & ascii_node_def,DeviceType device_type,error::Code code)171   void ExpectFailure(const string& ascii_node_def, DeviceType device_type,
172                      error::Code code) {
173     NodeDef node_def;
174     protobuf::TextFormat::ParseFromString(ascii_node_def, &node_def);
175     Status status;
176     std::unique_ptr<OpKernel> op(
177         CreateOpKernel(std::move(device_type), &device_, cpu_allocator(),
178                        node_def, TF_GRAPH_DEF_VERSION, &status));
179     EXPECT_TRUE(op == nullptr);
180     EXPECT_FALSE(status.ok());
181     if (!status.ok()) {
182       LOG(INFO) << "Status message: " << status.error_message();
183       EXPECT_EQ(code, status.code());
184     }
185   }
186 
187  private:
188   DeviceBase device_;
189 };
190 
// The CPU kernel for "Test1" can be created with plain inputs and with a
// ref-typed first input.
TEST_F(OpKernelTest, SuccessCpu) {
  const DataTypeVector expected_outputs = {DT_UINT8};
  ExpectSuccess("Test1", DEVICE_CPU, {DT_FLOAT, DT_INT32}, expected_outputs);
  ExpectSuccess("Test1", DEVICE_CPU, {DT_FLOAT_REF, DT_INT32},
                expected_outputs);
}
195 
// Creating the GPU kernel for "Test2" with an int32 input succeeds, and its
// constructor reports a matching signature through foo::match_signature_.
TEST_F(OpKernelTest, SuccessGpu) {
  // Clear the flag first so that a true value can only come from this run.
  foo::match_signature_ = false;

  const DataTypeVector int32_only = {DT_INT32};
  ExpectSuccess("Test2", DEVICE_GPU, int32_only, int32_only);

  EXPECT_TRUE(foo::match_signature_);
}
201 
// "Test3" resolves to a kernel on either device when the input types match
// that device's type constraint. The op declares no outputs.
TEST_F(OpKernelTest, SuccessBothCpuAndGpu) {
  const DataTypeVector no_outputs = {};
  ExpectSuccess("Test3", DEVICE_CPU, {DT_INT8, DT_INT8}, no_outputs);
  ExpectSuccess("Test3", DEVICE_GPU, {DT_FLOAT, DT_FLOAT}, no_outputs);
}
206 
// "Test1" only has a CPU kernel, so CPU is the single supported device type.
TEST_F(OpKernelTest, CpuTypeRegistered) {
  const NodeDef node = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32});

  PrioritizedDeviceTypeVector supported;
  TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), node, &supported));

  EXPECT_EQ(1, supported.size());
  EXPECT_EQ(DeviceType(DEVICE_CPU), supported[0].first);
}
214 
// Checks which device types SupportedDeviceTypesForNode() reports for ops
// whose kernels differ by type constraint, device, or priority.
TEST_F(OpKernelTest, CpuAndGpuTypeRegistered) {
  {
    // int8 inputs: "Test3" only has a CPU kernel for this type.
    const NodeDef node = CreateNodeDef("Test3", {DT_INT8, DT_INT8});
    PrioritizedDeviceTypeVector supported;
    TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), node, &supported));
    EXPECT_EQ(1, supported.size());
    EXPECT_EQ(DeviceType(DEVICE_CPU), supported[0].first);
  }

  {
    // float inputs: "Test3" only has a GPU kernel for this type.
    const NodeDef node = CreateNodeDef("Test3", {DT_FLOAT, DT_FLOAT});
    PrioritizedDeviceTypeVector supported;
    TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), node, &supported));
    EXPECT_EQ(1, supported.size());
    EXPECT_EQ(DeviceType(DEVICE_GPU), supported[0].first);
  }

  {
    // string inputs: "Test3" has no kernel for this type on any device.
    const NodeDef node = CreateNodeDef("Test3", {DT_STRING, DT_STRING});
    PrioritizedDeviceTypeVector supported;
    TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), node, &supported));
    EXPECT_EQ(0, supported.size());
  }

  {
    // "Test4" is registered on both devices; the result is expected in the
    // order GPU, CPU.
    const NodeDef node = CreateNodeDef("Test4", {DT_FLOAT});
    PrioritizedDeviceTypeVector supported;
    TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), node, &supported));
    EXPECT_EQ(2, supported.size());
    EXPECT_EQ(DeviceType(DEVICE_GPU), supported[0].first);
    EXPECT_EQ(DeviceType(DEVICE_CPU), supported[1].first);
  }

  {
    // "Test5" kernels carry priorities: CPU (2) is expected ahead of GPU (1).
    const NodeDef node = CreateNodeDef("Test5", {DT_STRING, DT_STRING});
    PrioritizedDeviceTypeVector supported;
    TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), node, &supported));
    EXPECT_EQ(2, supported.size());
    EXPECT_EQ(DeviceType(DEVICE_CPU), supported[0].first);
    EXPECT_EQ(2, supported[0].second);
    EXPECT_EQ(DeviceType(DEVICE_GPU), supported[1].first);
    EXPECT_EQ(1, supported[1].second);
  }
}
264 
// Kernel creation reports NOT_FOUND when no registered kernel fits the
// requested device, input types, or op type name.
TEST_F(OpKernelTest, NotFound) {
  const error::Code expected_code = error::NOT_FOUND;

  // The op type has a kernel, but only for a different DeviceType.
  const NodeDef test1 = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32});
  ExpectFailure(test1.DebugString(), DEVICE_GPU, expected_code);
  const NodeDef test3_int8 = CreateNodeDef("Test3", {DT_INT8, DT_INT8});
  ExpectFailure(test3_int8.DebugString(), DEVICE_GPU, expected_code);
  const NodeDef test3_float = CreateNodeDef("Test3", {DT_FLOAT, DT_FLOAT});
  ExpectFailure(test3_float.DebugString(), DEVICE_CPU, expected_code);

  // No kernel is registered for this signature.
  const NodeDef test3_int32 = CreateNodeDef("Test3", {DT_INT32, DT_INT32});
  ExpectFailure(test3_int32.DebugString(), DEVICE_GPU, expected_code);

  // The op type name itself is unknown.
  const string unknown_op = "name: 'NF' op: 'Testnotfound'";
  ExpectFailure(unknown_op, DEVICE_CPU, expected_code);
  ExpectFailure(unknown_op, DEVICE_GPU, expected_code);
}
284 
// "Test1" takes two inputs; a node with zero or one input is rejected with
// INVALID_ARGUMENT.
TEST_F(OpKernelTest, TooFewInputs) {
  const error::Code expected_code = error::INVALID_ARGUMENT;
  NodeDef node = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32});

  // No inputs at all.
  node.clear_input();
  ExpectFailure(node.DebugString(), DEVICE_CPU, expected_code);

  // A single input.
  node.add_input("a");
  ExpectFailure(node.DebugString(), DEVICE_CPU, expected_code);
}
293 
// "Test1" takes two inputs; a node with a third input is rejected with
// INVALID_ARGUMENT.
TEST_F(OpKernelTest, TooManyInputs) {
  NodeDef node = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32});
  node.add_input("c");

  const error::Code expected_code = error::INVALID_ARGUMENT;
  ExpectFailure(node.DebugString(), DEVICE_CPU, expected_code);
}
300 
// With a float input, TestOp2's signature check fails: kernel creation
// returns INVALID_ARGUMENT and foo::match_signature_ ends up false.
TEST_F(OpKernelTest, MatchSignatureFailes) {
  // Set the flag first so that a false value can only come from this run.
  foo::match_signature_ = true;

  const NodeDef float_node = CreateNodeDef("Test2", {DT_FLOAT});
  const error::Code expected_code = error::INVALID_ARGUMENT;
  ExpectFailure(float_node.DebugString(), DEVICE_GPU, expected_code);

  EXPECT_FALSE(foo::match_signature_);
}
308 
309 class DummyDevice : public DeviceBase {
310  public:
DummyDevice(Env * env,bool save)311   DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {}
RequiresRecordingAccessedTensors() const312   bool RequiresRecordingAccessedTensors() const override { return save_; }
GetAllocator(AllocatorAttributes)313   Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
314     return cpu_allocator();
315   }
316 
317  private:
318   bool save_;
319 };
320 
// With record_tensor_accesses disabled, a temporary allocated through the
// context does not show up among the accessed tensors.
TEST_F(OpKernelTest, SaveTempFalse) {
  Env* env = Env::Default();
  OpKernelContext::Params params;
  params.record_tensor_accesses = false;
  // The device is declared before the kernel and the context so that it is
  // destroyed after both of them.
  std::unique_ptr<DummyDevice> device(
      new DummyDevice(env, params.record_tensor_accesses));
  params.device = device.get();
  Status status;
  std::unique_ptr<OpKernel> op(
      CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(),
                     CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}),
                     TF_GRAPH_DEF_VERSION, &status));
  // Stop here on failure: a context must not be built around a null kernel.
  ASSERT_TRUE(status.ok()) << status;
  params.op_kernel = op.get();
  std::unique_ptr<OpKernelContext> ctx(new OpKernelContext(&params));

  Tensor t;
  TF_EXPECT_OK(ctx->allocate_temp(DT_FLOAT, TensorShape(), &t));

  TensorReferenceVector referenced_tensors;
  ctx->retrieve_accessed_tensors(&referenced_tensors);
  EXPECT_EQ(0, referenced_tensors.size());
}
345 
// With record_tensor_accesses enabled, a temporary allocated through the
// context is reported as one accessed tensor.
TEST_F(OpKernelTest, SaveTempTrue) {
  Env* env = Env::Default();
  OpKernelContext::Params params;
  params.record_tensor_accesses = true;
  // The device is declared before the kernel and the context so that it is
  // destroyed after both of them.
  std::unique_ptr<DummyDevice> device(
      new DummyDevice(env, params.record_tensor_accesses));
  params.device = device.get();
  Status status;
  std::unique_ptr<OpKernel> op(
      CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(),
                     CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}),
                     TF_GRAPH_DEF_VERSION, &status));
  // Stop here on failure: a context must not be built around a null kernel.
  ASSERT_TRUE(status.ok()) << status;
  params.op_kernel = op.get();
  std::unique_ptr<OpKernelContext> ctx(new OpKernelContext(&params));

  Tensor t;
  TF_EXPECT_OK(ctx->allocate_temp(DT_FLOAT, TensorShape(), &t));

  TensorReferenceVector referenced_tensors;
  ctx->retrieve_accessed_tensors(&referenced_tensors);
  EXPECT_EQ(1, referenced_tensors.size());
  // Release the references that were handed over to this test.
  for (auto& ref : referenced_tensors) {
    ref.Unref();
  }
}
373 
// OpKernelContext::input_dtype() returns the dtype of a named input and fails
// for a name the op does not have.
TEST_F(OpKernelTest, InputDtype) {
  Env* env = Env::Default();
  OpKernelContext::Params params;
  params.record_tensor_accesses = false;
  // Owned by smart pointers so that an early return from one of the ASSERT_*
  // macros below cannot leak the device or the context. The device is declared
  // first so that it is destroyed last.
  std::unique_ptr<DummyDevice> device(
      new DummyDevice(env, params.record_tensor_accesses));
  params.device = device.get();
  Status status;
  std::unique_ptr<OpKernel> op(
      CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(),
                     CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}),
                     TF_GRAPH_DEF_VERSION, &status));
  // Stop here on failure: a context must not be built around a null kernel.
  ASSERT_TRUE(status.ok()) << status;
  params.op_kernel = op.get();
  Tensor a(DT_FLOAT, TensorShape({}));
  Tensor b(DT_INT32, TensorShape({}));
  Tensor c(DT_UINT8, TensorShape({}));
  gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&a), TensorValue(&b),
                                            TensorValue(&c)};
  params.inputs = &inputs;
  std::unique_ptr<OpKernelContext> ctx(new OpKernelContext(&params));

  DataType dtype;
  EXPECT_FALSE(ctx->input_dtype("non_existent_input", &dtype).ok());
  ASSERT_TRUE(ctx->input_dtype("a", &dtype).ok());
  EXPECT_EQ(dtype, DT_FLOAT);
  ASSERT_TRUE(ctx->input_dtype("b", &dtype).ok());
  EXPECT_EQ(dtype, DT_INT32);
}
403 
404 class OpKernelBuilderTest : public ::testing::Test {
405  protected:
406   // Each attr is described by a "name|type|value".
CreateNodeDef(const string & op_type,const std::vector<string> & attrs)407   NodeDef CreateNodeDef(const string& op_type,
408                         const std::vector<string>& attrs) {
409     NodeDef node_def;
410     node_def.set_name(op_type + "-op");
411     node_def.set_op(op_type);
412     for (const string& attr_desc : attrs) {
413       std::vector<string> parts = str_util::Split(attr_desc, '|');
414       CHECK_EQ(parts.size(), 3);
415       AttrValue attr_value;
416       CHECK(ParseAttrValue(parts[1], parts[2], &attr_value)) << attr_desc;
417       node_def.mutable_attr()->insert(
418           AttrValueMap::value_type(parts[0], attr_value));
419     }
420     return node_def;
421   }
422 
ExpectSuccess(const string & op_type,const DeviceType & device_type,const std::vector<string> & attrs,DataTypeSlice input_types={})423   std::unique_ptr<OpKernel> ExpectSuccess(const string& op_type,
424                                           const DeviceType& device_type,
425                                           const std::vector<string>& attrs,
426                                           DataTypeSlice input_types = {}) {
427     Status status;
428     NodeDef def = CreateNodeDef(op_type, attrs);
429     for (size_t i = 0; i < input_types.size(); ++i) {
430       def.add_input("a:0");
431     }
432 
433     Env* env = Env::Default();
434     DeviceBase device(env);
435 
436     // Test CreateOpKernel()
437     std::unique_ptr<OpKernel> op(CreateOpKernel(device_type, &device,
438                                                 cpu_allocator(), def,
439                                                 TF_GRAPH_DEF_VERSION, &status));
440     EXPECT_TRUE(status.ok()) << status;
441     EXPECT_TRUE(op != nullptr);
442     if (op != nullptr) {
443       EXPECT_EQ(input_types.size(), op->num_inputs());
444       EXPECT_EQ(0, op->num_outputs());
445     }
446 
447     // Test SupportedDeviceTypesForNode()
448     PrioritizedDeviceTypeVector devices;
449     TF_EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices));
450     bool found = false;
451     for (const auto& dt : devices) {
452       if (dt.first == device_type) {
453         found = true;
454       }
455     }
456     EXPECT_TRUE(found) << "Missing " << device_type << " from "
457                        << devices.size() << " devices.";
458 
459     // In case the caller wants to use the OpKernel
460     return op;
461   }
462 
ExpectFailure(const string & op_type,const DeviceType & device_type,const std::vector<string> & attrs,error::Code code)463   void ExpectFailure(const string& op_type, const DeviceType& device_type,
464                      const std::vector<string>& attrs, error::Code code) {
465     Status status;
466     const NodeDef def = CreateNodeDef(op_type, attrs);
467     Env* env = Env::Default();
468     DeviceBase device(env);
469 
470     // Test CreateOpKernel().
471     std::unique_ptr<OpKernel> op(CreateOpKernel(device_type, &device,
472                                                 cpu_allocator(), def,
473                                                 TF_GRAPH_DEF_VERSION, &status));
474     EXPECT_TRUE(op == nullptr);
475     EXPECT_FALSE(status.ok());
476     if (!status.ok()) {
477       LOG(INFO) << "Status message: " << status.error_message();
478       EXPECT_EQ(code, status.code());
479 
480       // Test SupportedDeviceTypesForNode().
481       PrioritizedDeviceTypeVector devices;
482       if (errors::IsNotFound(status)) {
483         TF_EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices));
484         for (const auto& dt : devices) {
485           EXPECT_NE(dt.first, device_type);
486         }
487       } else {
488         Status status2 =
489             SupportedDeviceTypesForNode(DeviceTypes(), def, &devices);
490         EXPECT_EQ(status.code(), status2.code());
491       }
492     }
493   }
494 
GetKernelClassName(const string & op_type,const DeviceType & device_type,const std::vector<string> & attrs,DataTypeSlice input_types={})495   string GetKernelClassName(const string& op_type,
496                             const DeviceType& device_type,
497                             const std::vector<string>& attrs,
498                             DataTypeSlice input_types = {}) {
499     NodeDef def = CreateNodeDef(op_type, attrs);
500     for (size_t i = 0; i < input_types.size(); ++i) {
501       def.add_input("a:0");
502     }
503 
504     const KernelDef* kernel_def = nullptr;
505     string kernel_class_name;
506     const Status status =
507         FindKernelDef(device_type, def, &kernel_def, &kernel_class_name);
508     if (status.ok()) {
509       return kernel_class_name;
510     } else if (errors::IsNotFound(status)) {
511       return "not found";
512     } else {
513       return status.ToString();
514     }
515   }
516 };
517 
518 REGISTER_OP("BuildCPU");
519 REGISTER_KERNEL_BUILDER(Name("BuildCPU").Device(DEVICE_CPU), DummyKernel);
520 
TEST_F(OpKernelBuilderTest,BuilderCPU)521 TEST_F(OpKernelBuilderTest, BuilderCPU) {
522   ExpectSuccess("BuildCPU", DEVICE_CPU, {});
523   EXPECT_EQ("DummyKernel", GetKernelClassName("BuildCPU", DEVICE_CPU, {}));
524   ExpectFailure("BuildCPU", DEVICE_GPU, {}, error::NOT_FOUND);
525   EXPECT_EQ("not found", GetKernelClassName("BuildCPU", DEVICE_GPU, {}));
526 }
527 
528 REGISTER_OP("BuildGPU");
529 REGISTER_KERNEL_BUILDER(Name("BuildGPU").Device(DEVICE_GPU), DummyKernel);
530 
TEST_F(OpKernelBuilderTest,BuilderGPU)531 TEST_F(OpKernelBuilderTest, BuilderGPU) {
532   ExpectFailure("BuildGPU", DEVICE_CPU, {}, error::NOT_FOUND);
533   ExpectSuccess("BuildGPU", DEVICE_GPU, {});
534 }
535 
536 REGISTER_OP("BuildBoth");
537 REGISTER_KERNEL_BUILDER(Name("BuildBoth").Device(DEVICE_CPU), DummyKernel);
538 REGISTER_KERNEL_BUILDER(Name("BuildBoth").Device(DEVICE_GPU), DummyKernel);
539 
TEST_F(OpKernelBuilderTest,BuilderBoth)540 TEST_F(OpKernelBuilderTest, BuilderBoth) {
541   ExpectSuccess("BuildBoth", DEVICE_CPU, {});
542   ExpectSuccess("BuildBoth", DEVICE_GPU, {});
543 }
544 
545 REGISTER_OP("BuildTypeAttr").Attr("T: type");
546 REGISTER_KERNEL_BUILDER(
547     Name("BuildTypeAttr").Device(DEVICE_CPU).TypeConstraint<float>("T"),
548     DummyKernel);
549 
TEST_F(OpKernelBuilderTest,BuilderTypeAttr)550 TEST_F(OpKernelBuilderTest, BuilderTypeAttr) {
551   ExpectSuccess("BuildTypeAttr", DEVICE_CPU, {"T|type|DT_FLOAT"});
552   ExpectFailure("BuildTypeAttr", DEVICE_CPU, {"T|type|DT_BOOL"},
553                 error::NOT_FOUND);
554   ExpectFailure("BuildTypeAttr", DEVICE_CPU, {}, error::INVALID_ARGUMENT);
555   ExpectFailure("BuildTypeAttr", DEVICE_CPU, {"T|int|7"},
556                 error::INVALID_ARGUMENT);
557 }
558 
559 REGISTER_OP("BuildTypeListAttr").Attr("T: list(type)");
560 REGISTER_KERNEL_BUILDER(
561     Name("BuildTypeListAttr").Device(DEVICE_CPU).TypeConstraint<bool>("T"),
562     DummyKernel);
563 
TEST_F(OpKernelBuilderTest,BuilderTypeListAttr)564 TEST_F(OpKernelBuilderTest, BuilderTypeListAttr) {
565   ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[]"});
566   EXPECT_EQ("DummyKernel", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU,
567                                               {"T|list(type)|[]"}));
568 
569   ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_BOOL]"});
570   EXPECT_EQ("DummyKernel", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU,
571                                               {"T|list(type)|[]"}));
572 
573   ExpectSuccess("BuildTypeListAttr", DEVICE_CPU,
574                 {"T|list(type)|[DT_BOOL, DT_BOOL]"});
575 
576   ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_FLOAT]"},
577                 error::NOT_FOUND);
578   EXPECT_EQ("not found", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU,
579                                             {"T|list(type)|[DT_FLOAT]"}));
580 
581   ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {}, error::INVALID_ARGUMENT);
582   EXPECT_TRUE(str_util::StrContains(
583       GetKernelClassName("BuildTypeListAttr", DEVICE_CPU, {}),
584       "Invalid argument: "));
585 
586   ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|int|7"},
587                 error::INVALID_ARGUMENT);
588 }
589 
590 REGISTER_OP("DuplicateKernel");
591 REGISTER_KERNEL_BUILDER(Name("DuplicateKernel").Device(DEVICE_CPU),
592                         DummyKernel);
593 REGISTER_KERNEL_BUILDER(Name("DuplicateKernel").Device(DEVICE_CPU),
594                         DummyKernel);
595 
TEST_F(OpKernelBuilderTest,DuplicateKernel)596 TEST_F(OpKernelBuilderTest, DuplicateKernel) {
597   const NodeDef ndef = CreateNodeDef("DuplicateKernel", {});
598   PrioritizedDeviceTypeVector devs;
599   Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
600   ASSERT_FALSE(status.ok());
601   EXPECT_TRUE(str_util::StrContains(
602       status.error_message(), "Multiple OpKernel registrations match NodeDef"));
603 
604   ExpectFailure("DuplicateKernel", DEVICE_CPU, {}, error::INVALID_ARGUMENT);
605 }
606 
607 REGISTER_OP("DuplicateKernelForT").Attr("T: type");
608 REGISTER_KERNEL_BUILDER(
609     Name("DuplicateKernelForT").Device(DEVICE_CPU).TypeConstraint<float>("T"),
610     DummyKernel);
611 REGISTER_KERNEL_BUILDER(
612     Name("DuplicateKernelForT").Device(DEVICE_CPU).TypeConstraint<float>("T"),
613     DummyKernel);
614 
TEST_F(OpKernelBuilderTest,DuplicateKernelForT)615 TEST_F(OpKernelBuilderTest, DuplicateKernelForT) {
616   const NodeDef ndef =
617       CreateNodeDef("DuplicateKernelForT", {"T|type|DT_FLOAT"});
618   PrioritizedDeviceTypeVector devs;
619   Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
620   ASSERT_FALSE(status.ok());
621   EXPECT_TRUE(str_util::StrContains(
622       status.error_message(), "Multiple OpKernel registrations match NodeDef"));
623 
624   ExpectFailure("DuplicateKernelForT", DEVICE_CPU, {"T|type|DT_FLOAT"},
625                 error::INVALID_ARGUMENT);
626   ExpectFailure("DuplicateKernelForT", DEVICE_CPU, {"T|type|DT_BOOL"},
627                 error::NOT_FOUND);
628 }
629 
630 REGISTER_OP("BadConstraint").Attr("dtype: type");
631 REGISTER_KERNEL_BUILDER(Name("BadConstraint")
632                             .Device(DEVICE_CPU)
633                             // Mistake: "T" should be "dtype".
634                             .TypeConstraint<float>("T"),
635                         DummyKernel);
636 
TEST_F(OpKernelBuilderTest,BadConstraint)637 TEST_F(OpKernelBuilderTest, BadConstraint) {
638   const NodeDef ndef = CreateNodeDef("BadConstraint", {});
639   PrioritizedDeviceTypeVector devs;
640   Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
641   ASSERT_FALSE(status.ok());
642   EXPECT_TRUE(
643       str_util::StrContains(status.error_message(),
644                             "OpKernel 'BadConstraint' has constraint on attr "
645                             "'T' not in NodeDef"));
646 
647   ExpectFailure("BadConstraint", DEVICE_CPU, {"dtype|type|DT_FLOAT"},
648                 error::INVALID_ARGUMENT);
649 }
650 
651 REGISTER_OP("ListOut").Output("a: int32").Output("b: T").Attr("T: list(type)");
652 REGISTER_KERNEL_BUILDER(Name("ListOut").Device(tensorflow::DEVICE_CPU),
653                         DummyKernel);
654 
TEST_F(OpKernelBuilderTest,OpOutputList)655 TEST_F(OpKernelBuilderTest, OpOutputList) {
656   Env* env = Env::Default();
657   OpKernelContext::Params params;
658   params.record_tensor_accesses = false;
659   std::unique_ptr<DummyDevice> device(
660       new DummyDevice(env, params.record_tensor_accesses));
661   params.device = device.get();
662   Status status;
663   std::unique_ptr<OpKernel> op(CreateOpKernel(
664       DEVICE_CPU, params.device, cpu_allocator(),
665       CreateNodeDef("ListOut", {"T|list(type)|[DT_FLOAT, DT_INT32]"}),
666       TF_GRAPH_DEF_VERSION, &status));
667   EXPECT_TRUE(status.ok()) << status.ToString();
668   params.op_kernel = op.get();
669   gtl::InlinedVector<TensorValue, 4> inputs{};
670   params.inputs = &inputs;
671   std::unique_ptr<OpKernelContext> ctx(new OpKernelContext(&params));
672 
673   EXPECT_EQ(DT_INT32, ctx->expected_output_dtype(0));
674   OpOutputList out_list;
675   EXPECT_FALSE(ctx->output_list("non_existent_output", &out_list).ok());
676   ASSERT_TRUE(ctx->output_list("b", &out_list).ok());
677   EXPECT_EQ(DT_FLOAT, out_list.expected_output_dtype(0));
678   EXPECT_EQ(DT_INT32, out_list.expected_output_dtype(1));
679 }
680 
681 class GetAttrKernel : public ::tensorflow::OpKernel {
682  public:
GetAttrKernel(OpKernelConstruction * context)683   explicit GetAttrKernel(OpKernelConstruction* context) : OpKernel(context) {
684     string attr_name;
685     OP_REQUIRES_OK(context, context->GetAttr("attr_name", &attr_name));
686 
687     status.emplace_back("s", context->GetAttr(attr_name, &s));
688     status.emplace_back("s_list", context->GetAttr(attr_name, &s_list));
689     status.emplace_back("i", context->GetAttr(attr_name, &i));
690     status.emplace_back("i_list", context->GetAttr(attr_name, &i_list));
691     status.emplace_back("i32", context->GetAttr(attr_name, &i32));
692     status.emplace_back("i32_list", context->GetAttr(attr_name, &i32_list));
693     status.emplace_back("f", context->GetAttr(attr_name, &f));
694     status.emplace_back("f_list", context->GetAttr(attr_name, &f_list));
695     status.emplace_back("b", context->GetAttr(attr_name, &b));
696     status.emplace_back("b_list", context->GetAttr(attr_name, &b_list));
697     status.emplace_back("type", context->GetAttr(attr_name, &type));
698     status.emplace_back("type_list", context->GetAttr(attr_name, &type_list));
699     status.emplace_back("type_vector",
700                         context->GetAttr(attr_name, &type_vector));
701     status.emplace_back("shape_proto",
702                         context->GetAttr(attr_name, &shape_proto));
703     status.emplace_back("shape_proto_list",
704                         context->GetAttr(attr_name, &shape_proto_list));
705     status.emplace_back("shape", context->GetAttr(attr_name, &shape));
706     status.emplace_back("shape_list", context->GetAttr(attr_name, &shape_list));
707   }
Compute(::tensorflow::OpKernelContext * context)708   void Compute(::tensorflow::OpKernelContext* context) override {}
709 
ExpectOk(std::initializer_list<string> keys)710   void ExpectOk(std::initializer_list<string> keys) {
711     for (const auto& key_status : status) {
712       // Only the status for keys in "keys" should be ok().
713       bool in_keys = false;
714       for (const string& key : keys) {
715         if (key_status.first == key) {
716           in_keys = true;
717         }
718       }
719       EXPECT_EQ(in_keys, key_status.second.ok())
720           << "key_status: " << key_status.first << ", " << key_status.second;
721     }
722   }
723 
724   string s;
725   std::vector<string> s_list;
726   int64 i;
727   std::vector<int64> i_list;
728   int32 i32;
729   std::vector<int32> i32_list;
730   float f;
731   std::vector<float> f_list;
732   bool b;
733   std::vector<bool> b_list;
734   DataType type;
735   std::vector<DataType> type_list;
736   DataTypeVector type_vector;
737   TensorShapeProto shape_proto;
738   std::vector<TensorShapeProto> shape_proto_list;
739   TensorShape shape;
740   std::vector<TensorShape> shape_list;
741   std::vector<std::pair<string, Status>> status;
742 };
743 
744 class GetAttrTest : public OpKernelBuilderTest {};
745 
746 REGISTER_OP("GetAttrStringList")
747     .Attr("attr_name: string")
748     .Attr("a: list(string)");
749 REGISTER_KERNEL_BUILDER(Name("GetAttrStringList").Device(DEVICE_CPU),
750                         GetAttrKernel);
751 
TEST_F(GetAttrTest,StringList)752 TEST_F(GetAttrTest, StringList) {
753   std::unique_ptr<OpKernel> op_kernel =
754       ExpectSuccess("GetAttrStringList", DEVICE_CPU,
755                     {"attr_name|string|'a'", "a|list(string)|['foo', 'bar']"});
756   auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
757   get_attr_kernel->ExpectOk({"s_list"});
758   EXPECT_EQ(std::vector<string>({"foo", "bar"}), get_attr_kernel->s_list);
759 
760   op_kernel = ExpectSuccess("GetAttrStringList", DEVICE_CPU,
761                             {"attr_name|string|'b'", "a|list(string)|['baz']"});
762   get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
763   get_attr_kernel->ExpectOk({});
764   EXPECT_TRUE(get_attr_kernel->s_list.empty());
765 }
766 
767 REGISTER_OP("GetAttrInt")
768     .Attr("attr_name: string")
769     .Attr("a: int")
770     .Attr("b: list(int)");
771 REGISTER_KERNEL_BUILDER(Name("GetAttrInt").Device(DEVICE_CPU), GetAttrKernel);
772 
TEST_F(GetAttrTest,Int)773 TEST_F(GetAttrTest, Int) {
774   std::unique_ptr<OpKernel> op_kernel = ExpectSuccess(
775       "GetAttrInt", DEVICE_CPU,
776       {"attr_name|string|'a'", "a|int|35", "b|list(int)|[-1, 2, -4]"});
777   auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
778   get_attr_kernel->ExpectOk({"i", "i32"});
779   EXPECT_EQ(35, get_attr_kernel->i);
780   EXPECT_EQ(35, get_attr_kernel->i32);
781 
782   op_kernel = ExpectSuccess(
783       "GetAttrInt", DEVICE_CPU,
784       {"attr_name|string|'b'", "a|int|35", "b|list(int)|[-1, 2, -4]"});
785   get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
786   get_attr_kernel->ExpectOk({"i_list", "i32_list"});
787   EXPECT_EQ(std::vector<int64>({-1, 2, -4}), get_attr_kernel->i_list);
788   EXPECT_EQ(std::vector<int32>({-1, 2, -4}), get_attr_kernel->i32_list);
789 
790   // 8589934592 == 2^33, too big to fit in an int32
791   op_kernel = ExpectSuccess("GetAttrInt", DEVICE_CPU,
792                             {"attr_name|string|'a'", "a|int|8589934592",
793                              "b|list(int)|[-8589934592]"});
794   get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
795   get_attr_kernel->ExpectOk({"i"});  // no i32
796   EXPECT_EQ(8589934592ll, get_attr_kernel->i);
797   for (const auto& key_status : get_attr_kernel->status) {
798     if (key_status.first == "i32") {
799       EXPECT_EQ(error::INVALID_ARGUMENT, key_status.second.code());
800       EXPECT_EQ("Attr a has value 8589934592 out of range for an int32",
801                 key_status.second.error_message());
802     }
803   }
804 
805   op_kernel = ExpectSuccess("GetAttrInt", DEVICE_CPU,
806                             {"attr_name|string|'b'", "a|int|8589934592",
807                              "b|list(int)|[-8589934592]"});
808   get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
809   get_attr_kernel->ExpectOk({"i_list"});  // no i32_list
810   EXPECT_EQ(std::vector<int64>({-8589934592ll}), get_attr_kernel->i_list);
811   for (const auto& key_status : get_attr_kernel->status) {
812     if (key_status.first == "i32_list") {
813       EXPECT_EQ(error::INVALID_ARGUMENT, key_status.second.code());
814       EXPECT_EQ("Attr b has value -8589934592 out of range for an int32",
815                 key_status.second.error_message());
816     }
817   }
818 }
819 
820 REGISTER_OP("GetAttrShape")
821     .Attr("attr_name: string")
822     .Attr("a: shape")
823     .Attr("b: list(shape)");
824 REGISTER_KERNEL_BUILDER(Name("GetAttrShape").Device(DEVICE_CPU), GetAttrKernel);
825 
TEST_F(GetAttrTest,Shape)826 TEST_F(GetAttrTest, Shape) {
827   std::unique_ptr<OpKernel> op_kernel = ExpectSuccess(
828       "GetAttrShape", DEVICE_CPU,
829       {"attr_name|string|'a'", "a|shape|{ dim { size: 3 } }",
830        "b|list(shape)|[{ dim { size:2 } }, { dim { size: 4 } }]"});
831   auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
832   get_attr_kernel->ExpectOk({"shape", "shape_proto"});
833   EXPECT_EQ(get_attr_kernel->shape_proto.ShortDebugString(), "dim { size: 3 }");
834   EXPECT_EQ("[3]", get_attr_kernel->shape.DebugString());
835 
836   op_kernel = ExpectSuccess(
837       "GetAttrShape", DEVICE_CPU,
838       {"attr_name|string|'b'", "a|shape|{ dim { size: 3 } }",
839        "b|list(shape)|[{ dim { size:2 } }, { dim { size: 4 } }]"});
840   get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
841   get_attr_kernel->ExpectOk({"shape_list", "shape_proto_list"});
842   ASSERT_EQ(2, get_attr_kernel->shape_proto_list.size());
843   EXPECT_EQ(get_attr_kernel->shape_proto_list[0].ShortDebugString(),
844             "dim { size: 2 }");
845   EXPECT_EQ(get_attr_kernel->shape_proto_list[1].ShortDebugString(),
846             "dim { size: 4 }");
847   ASSERT_EQ(2, get_attr_kernel->shape_list.size());
848   EXPECT_EQ("[2]", get_attr_kernel->shape_list[0].DebugString());
849   EXPECT_EQ("[4]", get_attr_kernel->shape_list[1].DebugString());
850 }
851 
852 REGISTER_OP("GetAttrType").Attr("attr_name: string").Attr("a: type");
853 REGISTER_KERNEL_BUILDER(Name("GetAttrType").Device(DEVICE_CPU), GetAttrKernel);
854 
TEST_F(GetAttrTest,Type)855 TEST_F(GetAttrTest, Type) {
856   std::unique_ptr<OpKernel> op_kernel = ExpectSuccess(
857       "GetAttrType", DEVICE_CPU, {"attr_name|string|'a'", "a|type|DT_FLOAT"});
858   auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
859   get_attr_kernel->ExpectOk({"type"});
860   EXPECT_EQ(DT_FLOAT, get_attr_kernel->type);
861 }
862 
863 REGISTER_OP("GetAttrTypeList").Attr("attr_name: string").Attr("a: list(type)");
864 REGISTER_KERNEL_BUILDER(Name("GetAttrTypeList").Device(DEVICE_CPU),
865                         GetAttrKernel);
866 
TEST_F(GetAttrTest,TypeList)867 TEST_F(GetAttrTest, TypeList) {
868   std::unique_ptr<OpKernel> op_kernel = ExpectSuccess(
869       "GetAttrTypeList", DEVICE_CPU,
870       {"attr_name|string|'a'", "a|list(type)|[DT_INT32, DT_BOOL]"});
871   auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get());
872 
873   get_attr_kernel->ExpectOk({"type_list", "type_vector"});
874   ASSERT_EQ(2, get_attr_kernel->type_list.size());
875   EXPECT_EQ(DT_INT32, get_attr_kernel->type_list[0]);
876   EXPECT_EQ(DT_BOOL, get_attr_kernel->type_list[1]);
877   ASSERT_EQ(2, get_attr_kernel->type_vector.size());
878   EXPECT_EQ(DT_INT32, get_attr_kernel->type_vector[0]);
879   EXPECT_EQ(DT_BOOL, get_attr_kernel->type_vector[1]);
880 }
881 
882 class BaseKernel : public ::tensorflow::OpKernel {
883  public:
BaseKernel(OpKernelConstruction * context)884   explicit BaseKernel(OpKernelConstruction* context) : OpKernel(context) {}
Compute(::tensorflow::OpKernelContext * context)885   void Compute(::tensorflow::OpKernelContext* context) override {}
886   virtual int Which() const = 0;
887 };
888 
889 template <int WHICH>
890 class LabeledKernel : public BaseKernel {
891  public:
892   using BaseKernel::BaseKernel;
Which() const893   int Which() const override { return WHICH; }
894 };
895 
896 class LabelTest : public OpKernelBuilderTest {};
897 
898 REGISTER_OP("LabeledKernel");
899 REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU),
900                         LabeledKernel<0>);
901 REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU).Label("one"),
902                         LabeledKernel<1>);
903 REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU).Label("dupe"),
904                         LabeledKernel<2>);
905 REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU).Label("dupe"),
906                         LabeledKernel<3>);
907 
TEST_F(LabelTest,Default)908 TEST_F(LabelTest, Default) {
909   std::unique_ptr<OpKernel> op_kernel =
910       ExpectSuccess("LabeledKernel", DEVICE_CPU, {});
911   auto* get_labeled_kernel = static_cast<BaseKernel*>(op_kernel.get());
912   EXPECT_EQ(0, get_labeled_kernel->Which());
913 
914   EXPECT_EQ("LabeledKernel<0>",
915             GetKernelClassName("LabeledKernel", DEVICE_CPU, {}));
916 }
917 
TEST_F(LabelTest, Specified) {
  // Setting "_kernel" to 'one' is expected to select LabeledKernel<1>.
  std::unique_ptr<OpKernel> kernel_owner =
      ExpectSuccess("LabeledKernel", DEVICE_CPU, {"_kernel|string|'one'"});
  const auto* labeled = static_cast<BaseKernel*>(kernel_owner.get());
  EXPECT_EQ(1, labeled->Which());

  const auto class_name =
      GetKernelClassName("LabeledKernel", DEVICE_CPU, {"_kernel|string|'one'"});
  EXPECT_EQ("LabeledKernel<1>", class_name);
}
926 
TEST_F(LabelTest, Duplicate) {
  // Requesting the label "dupe" is expected to fail with INVALID_ARGUMENT.
  const auto expected_code = error::INVALID_ARGUMENT;
  ExpectFailure("LabeledKernel", DEVICE_CPU, {"_kernel|string|'dupe'"},
                expected_code);
}
931 
BM_InputRangeHelper(int iters,const NodeDef & node_def,const char * input_name,int expected_start,int expected_stop)932 void BM_InputRangeHelper(int iters, const NodeDef& node_def,
933                          const char* input_name, int expected_start,
934                          int expected_stop) {
935   Status status;
936   std::unique_ptr<DummyDevice> device(new DummyDevice(Env::Default(), false));
937 
938   std::unique_ptr<OpKernel> op(CreateOpKernel(DEVICE_CPU, device.get(),
939                                               cpu_allocator(), node_def,
940                                               TF_GRAPH_DEF_VERSION, &status));
941   TF_CHECK_OK(status);
942 
943   testing::StartTiming();
944   for (int i = 0; i < iters; ++i) {
945     int start;
946     int stop;
947     TF_CHECK_OK(op->InputRange(input_name, &start, &stop));
948     EXPECT_EQ(expected_start, start);
949     EXPECT_EQ(expected_stop, stop);
950   }
951   testing::StopTiming();
952 }
953 
954 REGISTER_KERNEL_BUILDER(Name("ConcatV2").Device(DEVICE_CPU), DummyKernel);
955 REGISTER_KERNEL_BUILDER(Name("Select").Device(DEVICE_CPU), DummyKernel);
956 
BM_ConcatInputRange(int iters)957 void BM_ConcatInputRange(int iters) {
958   testing::StopTiming();
959 
960   // Create a ConcatV2 NodeDef with 4 inputs (plus the axis).
961   NodeDef node_def;
962   node_def.set_name("concat-op");
963   node_def.set_op("ConcatV2");
964   AttrValue attr_N;
965   attr_N.set_i(4);
966   AttrValue attr_T;
967   attr_T.set_type(DT_FLOAT);
968   AttrValue attr_Tidx;
969   attr_Tidx.set_type(DT_INT32);
970   node_def.mutable_attr()->insert({"N", attr_N});
971   node_def.mutable_attr()->insert({"T", attr_T});
972   node_def.mutable_attr()->insert({"Tidx", attr_Tidx});
973   for (size_t i = 0; i < 5; ++i) {
974     node_def.add_input(strings::StrCat("a:", i));
975   }
976 
977   BM_InputRangeHelper(iters, node_def, "values", 0, 4);
978 }
979 
BM_SelectInputRange(int iters)980 void BM_SelectInputRange(int iters) {
981   testing::StopTiming();
982 
983   // Create a Select NodeDef with 3 inputs.
984   NodeDef node_def;
985   node_def.set_name("select-op");
986   node_def.set_op("Select");
987   AttrValue attr_T;
988   attr_T.set_type(DT_FLOAT);
989   node_def.mutable_attr()->insert({"T", attr_T});
990   for (size_t i = 0; i < 3; ++i) {
991     node_def.add_input(strings::StrCat("a:", i));
992   }
993 
994   BM_InputRangeHelper(iters, node_def, "condition", 0, 1);
995 }
996 
997 BENCHMARK(BM_ConcatInputRange);
998 BENCHMARK(BM_SelectInputRange);
999 
// Checks that GetAllRegisteredKernels() reports exactly one kernel for the
// "Test1" op and that it is a CPU kernel.
TEST(RegisteredKernels, CanCallGetAllRegisteredKernels) {
  auto kernel_list = GetAllRegisteredKernels();
  // Bind by reference: plain `auto` would deep-copy the whole repeated field.
  // `kernel_list` outlives every use of this reference and of the iterators
  // taken from it below.
  const auto& all_registered_kernels = kernel_list.kernel();
  auto has_name_test1 = [](const KernelDef& k) { return k.op() == "Test1"; };

  // Verify we can find the "Test1" op registered above
  auto test1_it = std::find_if(all_registered_kernels.begin(),
                               all_registered_kernels.end(), has_name_test1);
  ASSERT_NE(test1_it, all_registered_kernels.end());
  EXPECT_EQ(test1_it->device_type(), "CPU");

  // Verify there was just one kernel
  ++test1_it;
  EXPECT_EQ(
      std::find_if(test1_it, all_registered_kernels.end(), has_name_test1),
      all_registered_kernels.end());
}
1017 
1018 // Simple test just to check we can call LogAllRegisteredKernels
TEST(RegisteredKernels,CanLogAllRegisteredKernels)1019 TEST(RegisteredKernels, CanLogAllRegisteredKernels) {
1020   tensorflow::LogAllRegisteredKernels();
1021 }
1022 
// Filtering on op name "Test1" must yield exactly one kernel, on CPU.
TEST(RegisteredKernels, GetFilteredRegisteredKernels) {
  const auto is_test1 = [](const KernelDef& def) {
    return def.op() == "Test1";
  };
  const auto filtered = GetFilteredRegisteredKernels(is_test1);
  ASSERT_EQ(filtered.kernel_size(), 1);

  const KernelDef& only_kernel = filtered.kernel(0);
  EXPECT_EQ(only_kernel.op(), "Test1");
  EXPECT_EQ(only_kernel.device_type(), "CPU");
}
1030 
// Looking up op "Test1" by name must yield exactly one kernel, on CPU.
TEST(RegisteredKernels, GetRegisteredKernelsForOp) {
  const auto kernels_for_op = GetRegisteredKernelsForOp("Test1");
  ASSERT_EQ(kernels_for_op.kernel_size(), 1);

  const KernelDef& only_kernel = kernels_for_op.kernel(0);
  EXPECT_EQ(only_kernel.op(), "Test1");
  EXPECT_EQ(only_kernel.device_type(), "CPU");
}
1037 
1038 }  // namespace
1039 }  // namespace tensorflow
1040