1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/op_kernel.h"
19 #include "tensorflow/core/framework/resource_mgr.h"
20 #include "tensorflow/core/lib/core/status.h"
21 #include "tensorflow/core/public/version.h"
22 
23 namespace tensorflow {
24 
// Op with a single scalar string output and no inputs. Several CPU kernels are
// registered for it further down (KernelLabelOp), one unlabeled and two
// distinguished by kernel label.
REGISTER_OP("KernelLabel")
    .Output("result: string")
    .SetShapeFn(shape_inference::ScalarShape);
28 
29 REGISTER_OP("KernelLabelRequired")
30     .Input("input: int32")
31     .Output("result: string")
__anonf24a52a80102(shape_inference::InferenceContext* c) 32     .SetShapeFn([](shape_inference::InferenceContext* c) {
33       shape_inference::ShapeHandle out;
34       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &out));
35       c->set_output(0, c->Scalar());
36       return Status::OK();
37     });
38 
// Stateful op with no inputs. Its CPU kernel (GraphDefVersionOp) emits a scalar
// int32 holding the GraphDef version it was constructed with.
REGISTER_OP("GraphDefVersion")
    .Output("version: int32")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);
43 
44 REGISTER_OP("RequiresOlderGraphVersion")
45     .Output("version: int32")
46     .SetIsStateful()
__anonf24a52a80202(shape_inference::InferenceContext* c) 47     .SetShapeFn([](shape_inference::InferenceContext* c) {
48       if (c->graph_def_version() != TF_GRAPH_DEF_VERSION - 1) {
49         return errors::InvalidArgument("Wrong graph version for shape");
50       }
51       return shape_inference::ScalarShape(c);
52     });
53 
// Op with no inputs, outputs or attrs, marked deprecated via
// `.Deprecated(8, ...)`. Its CPU kernel (OldOp) does nothing.
REGISTER_OP("Old")
    .SetShapeFn(shape_inference::UnknownShape)
    .Deprecated(8, "For reasons");

// Resource-handle op for the StubResource test resource, declared through the
// resource_mgr.h helper macro.
REGISTER_RESOURCE_HANDLE_OP(StubResource);

// Ops taking a StubResource by handle. Their kernels are registered later in
// this file.
REGISTER_OP("ResourceInitializedOp")
    .Input("resource: resource")
    .Output("initialized: bool")
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("ResourceCreateOp")
    .Input("resource: resource")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("ResourceUsingOp")
    .Input("resource: resource")
    .SetShapeFn(shape_inference::UnknownShape);

// Two ops with an identical signature (float in; float and string out). The
// second uses the "Namespace>" prefixed-name form. Neither has a kernel here.
REGISTER_OP("TestStringOutput")
    .Input("input: float")
    .Output("output1: float")
    .Output("output2: string")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("Namespace>TestStringOutput")
    .Input("input: float")
    .Output("output1: float")
    .Output("output2: string")
    .SetShapeFn(shape_inference::UnknownShape);

// The op allows T in {float, double}, but only a T=float kernel (TestAttrOp)
// is registered in this file.
REGISTER_OP("TestAttr")
    .Output("out: T")
    .Attr("T: {float, double}")
    .SetShapeFn(shape_inference::UnknownShape);
89 
namespace {
// Compile-time selector for which message a KernelLabelOp instantiation emits.
enum KernelLabel {
  DEFAULT_LABEL,     // Paired with the unlabeled "KernelLabel" kernel.
  OVERLOAD_1_LABEL,  // Paired with kernels labeled "overload_1".
  OVERLOAD_2_LABEL,  // Paired with kernels labeled "overload_2".
};
}  // namespace
93 
94 template <KernelLabel KL>
95 class KernelLabelOp : public OpKernel {
96  public:
97   using OpKernel::OpKernel;
98 
Compute(OpKernelContext * ctx)99   void Compute(OpKernelContext* ctx) override {
100     Tensor* output;
101     OP_REQUIRES_OK(ctx,
102                    ctx->allocate_output("result", TensorShape({}), &output));
103     switch (KL) {
104       case DEFAULT_LABEL:
105         output->scalar<tstring>()() = "My label is: default";
106         break;
107       case OVERLOAD_1_LABEL:
108         output->scalar<tstring>()() = "My label is: overload_1";
109         break;
110       case OVERLOAD_2_LABEL:
111         output->scalar<tstring>()() = "My label is: overload_2";
112         break;
113     }
114   }
115 };
116 
117 REGISTER_KERNEL_BUILDER(Name("KernelLabel").Device(DEVICE_CPU),
118                         KernelLabelOp<DEFAULT_LABEL>);
119 REGISTER_KERNEL_BUILDER(Name("KernelLabel")
120                             .Device(DEVICE_CPU)
121                             .Label("overload_1"),
122                         KernelLabelOp<OVERLOAD_1_LABEL>);
123 REGISTER_KERNEL_BUILDER(Name("KernelLabel")
124                             .Device(DEVICE_CPU)
125                             .Label("overload_2"),
126                         KernelLabelOp<OVERLOAD_2_LABEL>);
127 
128 // All "KernelLabelRequired" kernels have labels
129 REGISTER_KERNEL_BUILDER(
130     Name("KernelLabelRequired").Device(DEVICE_CPU).Label("overload_1"),
131     KernelLabelOp<OVERLOAD_1_LABEL>);
132 REGISTER_KERNEL_BUILDER(
133     Name("KernelLabelRequired").Device(DEVICE_CPU).Label("overload_2"),
134     KernelLabelOp<OVERLOAD_2_LABEL>);
135 
136 class GraphDefVersionOp : public OpKernel {
137  public:
GraphDefVersionOp(OpKernelConstruction * ctx)138   explicit GraphDefVersionOp(OpKernelConstruction* ctx)
139       : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {}
140 
Compute(OpKernelContext * ctx)141   void Compute(OpKernelContext* ctx) override {
142     Tensor* output = nullptr;
143     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
144     output->scalar<int>()() = graph_def_version_;
145   }
146 
147  private:
148   const int graph_def_version_;
149 };
150 
151 REGISTER_KERNEL_BUILDER(Name("GraphDefVersion").Device(DEVICE_CPU),
152                         GraphDefVersionOp);
153 
154 class OldOp : public OpKernel {
155  public:
OldOp(OpKernelConstruction * ctx)156   explicit OldOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
157 
Compute(OpKernelContext * ctx)158   void Compute(OpKernelContext* ctx) override {}
159 };
160 
161 REGISTER_KERNEL_BUILDER(Name("Old").Device(DEVICE_CPU), OldOp);
162 
163 // Stubbed-out resource to test resource handle ops.
164 class StubResource : public ResourceBase {
165  public:
DebugString() const166   string DebugString() const override { return ""; }
167 };
168 
169 REGISTER_RESOURCE_HANDLE_KERNEL(StubResource);
170 
171 REGISTER_KERNEL_BUILDER(Name("ResourceInitializedOp").Device(DEVICE_CPU),
172                         IsResourceInitialized<StubResource>);
173 
174 class ResourceCreateOp : public OpKernel {
175  public:
ResourceCreateOp(OpKernelConstruction * c)176   explicit ResourceCreateOp(OpKernelConstruction* c) : OpKernel(c) {}
177 
Compute(OpKernelContext * c)178   void Compute(OpKernelContext* c) override {
179     OP_REQUIRES_OK(c,
180                    CreateResource(c, HandleFromInput(c, 0), new StubResource));
181   }
182 };
183 
184 REGISTER_KERNEL_BUILDER(Name("ResourceCreateOp").Device(DEVICE_CPU),
185                         ResourceCreateOp);
186 
187 // Uses a ResourceHandle to check its validity.
188 class ResourceUsingOp : public OpKernel {
189  public:
ResourceUsingOp(OpKernelConstruction * context)190   explicit ResourceUsingOp(OpKernelConstruction* context) : OpKernel(context) {}
191 
Compute(OpKernelContext * ctx)192   void Compute(OpKernelContext* ctx) override {
193     StubResource* unused;
194     OP_REQUIRES_OK(ctx, LookupResource<StubResource>(
195                             ctx, HandleFromInput(ctx, 0), &unused));
196   }
197 };
198 
199 REGISTER_KERNEL_BUILDER(Name("ResourceUsingOp").Device(DEVICE_CPU),
200                         ResourceUsingOp);
201 
202 class TestAttrOp : public OpKernel {
203  public:
TestAttrOp(OpKernelConstruction * ctx)204   explicit TestAttrOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
205 
Compute(OpKernelContext * ctx)206   void Compute(OpKernelContext* ctx) override {
207     Tensor* output;
208     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
209     output->scalar<float>()() = 1.0;
210   }
211 };
212 
213 REGISTER_KERNEL_BUILDER(
214     Name("TestAttr").Device(DEVICE_CPU).TypeConstraint<float>("T"), TestAttrOp);
215 
// Various test ops without kernels, used to test graph construction. From here
// through "ComplexStruct", every op uses UnknownShape as its shape function.
// The exact signature strings (and their declaration order) are the fixture.

// Source ops with a single float32 output.
REGISTER_OP("A")
    .Output("out: float32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("B")
    .Output("out: float32")
    .SetShapeFn(shape_inference::UnknownShape);

// Foo1/Foo2/Foo3 share outputs (d: float32, e: int32) and input a: float32.
// They differ only in the types of inputs b and c.
REGISTER_OP("Foo1")
    .Input("a: float32")
    .Input("b: int32")
    .Input("c: int32")
    .Output("d: float32")
    .Output("e: int32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("Foo2")
    .Input("a: float32")
    .Input("b: string")
    .Input("c: string")
    .Output("d: float32")
    .Output("e: int32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("Foo3")
    .Input("a: float32")
    .Input("b: string")
    .Input("c: float32")
    .Output("d: float32")
    .Output("e: int32")
    .SetShapeFn(shape_inference::UnknownShape);

// One type-polymorphic input mirrored by an output of the same type.
REGISTER_OP("CopyOp").Input("a: T").Output("b: T").Attr("T: type").SetShapeFn(
    shape_inference::UnknownShape);

// No inputs, outputs or attrs.
REGISTER_OP("None").SetShapeFn(shape_inference::UnknownShape);
254 
// Fixed-signature ops whose names spell out their inputs and outputs, e.g.
// "IntInputFloatInput" takes an int32 followed by a float32. "Ref" variants
// use reference-typed inputs/outputs.
REGISTER_OP("IntOutput")
    .Output("a: int32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("Int64Output")
    .Output("out: int64")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("RefOutput")
    .Output("a: Ref(int32)")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("FloatOutput")
    .Output("a: float32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("TwoFloatOutputs")
    .Output("a: float32")
    .Output("b: float32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("FiveFloatOutputs")
    .Output("a: float32")
    .Output("b: float32")
    .Output("c: float32")
    .Output("d: float32")
    .Output("e: float32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("RefOutputFloatOutput")
    .Output("a: Ref(float32)")
    .Output("b: float32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("RefInputFloatInput")
    .Input("a: Ref(float)")
    .Input("b: float")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("IntInput")
    .Input("a: int32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("IntInputIntOutput")
    .Input("a: int32")
    .Output("b: int32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("FloatInput")
    .Input("a: float32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("TwoIntOutputs")
    .Output("a: int32")
    .Output("b: int32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("IntOutputFloatOutput")
    .Output("a: int32")
    .Output("b: float32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("FloatOutputStringOutput")
    .Output("a: float32")
    .Output("b: string")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("TwoIntInputs")
    .Input("a: int32")
    .Input("b: int32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("TwoFloatInputs")
    .Input("a: float32")
    .Input("b: float32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("IntInputFloatInput")
    .Input("a: int32")
    .Input("b: float32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("RefInputIntInput")
    .Input("a: Ref(int32)")
    .Input("b: int32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("TwoFloatInputsFloatOutput")
    .Input("a: float32")
    .Input("b: float32")
    .Output("c: float32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("TwoFloatInputsIntOutput")
    .Input("a: float32")
    .Input("b: float32")
    .Output("c: int32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("RefInputFloatInputIntOutput")
    .Input("a: Ref(float32)")
    .Input("b: float32")
    .Output("c: int32")
    .SetShapeFn(shape_inference::UnknownShape);

// List-valued signatures: N same-typed inputs (N >= 1), or a non-empty list of
// possibly different output types.
REGISTER_OP("ListInput")
    .Input("a: N * T")
    .Attr("N: int >= 1")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("ListOutput")
    .Output("a: T")
    .Attr("T: list(type) >= 1")
    .SetShapeFn(shape_inference::UnknownShape);

// Same signature shape as "CopyOp": one polymorphic input, same-typed output.
REGISTER_OP("Unary").Input("a: T").Output("b: T").Attr("T: type").SetShapeFn(
    shape_inference::UnknownShape);
373 
// Ops exercising attrs with (and without) default values.
REGISTER_OP("OpWithDefaultAttr")
    .Output("a: int32")
    .Attr("default_float: float = 123.0")
    .SetShapeFn(shape_inference::UnknownShape);

// Declares no inputs, outputs or attrs.
// NOTE(review): presumably paired in tests with a newer definition that adds a
// defaulted attr (hence the name); confirm against the callers.
REGISTER_OP("OpWithFutureDefaultAttr")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("IntAttr")
    .Output("out: int64")
    .Attr("foo: int = 1")
    .SetShapeFn(shape_inference::UnknownShape);

// Two required attrs with no defaults: a string list and a string.
REGISTER_OP("StringListAttr")
    .Attr("a: list(string)")
    .Attr("b: string")
    .SetShapeFn(shape_inference::UnknownShape);

// One defaulted attr for each of the string, int, float, bool, type, shape and
// tensor kinds, each in scalar and list form.
REGISTER_OP("DefaultAttrs")
    .Attr("string_val: string = 'abc'")
    .Attr("string_list_val: list(string) = ['abc', '']")
    .Attr("int_val: int = 123")
    .Attr("int_list_val: list(int) = [1, 2, 3]")
    .Attr("float_val: float = 10.0")
    .Attr("float_list_val: list(float) = [10.0]")
    .Attr("bool_val: bool = true")
    .Attr("bool_list_val: list(bool) = [true, false]")
    .Attr("type_val: type = DT_INT32")
    .Attr("type_list_val: list(type) = [DT_INT32, DT_FLOAT]")
    .Attr("shape_val: shape = { dim { size: 2 } dim { size: 1 } }")
    .Attr("shape_list_val: list(shape) = [{}, { dim { size: 1} }]")
    .Attr("tensor_val: tensor = { dtype: DT_INT32 tensor_shape: {} int_val: 1}")
    .Attr(
        "tensor_list_val: list(tensor) = "
        "[{ dtype: DT_INT32 tensor_shape: {} int_val: 1}]")
    .SetShapeFn(shape_inference::UnknownShape);

// Function-valued attrs, scalar and list.
REGISTER_OP("FuncAttr")
    .Attr("f: func")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("FuncListAttr")
    .Attr("f: list(func)")
    .SetShapeFn(shape_inference::UnknownShape);

// int32 in, float out.
REGISTER_OP("Simple")
    .Input("a: int32")
    .Output("out: float")
    .SetShapeFn(shape_inference::UnknownShape);

// Single output whose type comes only from the required attr T.
REGISTER_OP("OutT").Output("a: T").Attr("T: type").SetShapeFn(
    shape_inference::UnknownShape);
426 
// The input is named "input".
// NOTE(review): presumably checks handling of a name that is a Python builtin
// in generated wrappers; confirm.
REGISTER_OP("ReservedInput")
    .Input("input: int32")
    .SetShapeFn(shape_inference::UnknownShape);

// Type-polymorphic ops driven by a single "T: type" attr.
REGISTER_OP("Polymorphic")
    .Input("a: T")
    .Output("out: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("PolymorphicOut")
    .Output("out: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("PolymorphicDefaultOut")
    .Output("out: T")
    .Attr("T: type = DT_STRING")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("Binary")
    .Input("a: T")
    .Input("b: T")
    .Output("out: T")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnknownShape);

// T restricted to string or bool.
REGISTER_OP("Restrict")
    .Input("a: T")
    .Output("out: T")
    .Attr("T: {string, bool}")
    .SetShapeFn(shape_inference::UnknownShape);

// Type-list inputs/outputs; each list may be empty (">= 0").
REGISTER_OP("TypeList")
    .Input("a: T")
    .Attr("T: list(type) >= 0")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("TypeListTwice")
    .Input("a: T")
    .Input("b: T")
    .Attr("T: list(type) >= 0")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("OutTypeList")
    .Output("out: T")
    .Attr("T: list(type) >= 0")
    .SetShapeFn(shape_inference::UnknownShape);

// Type lists whose element types are restricted to string or bool. The output
// variant uses the lower-case attr name "t".
REGISTER_OP("TypeListRestrict")
    .Input("a: T")
    .Attr("T: list({string, bool})")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("OutTypeListRestrict")
    .Output("out: t")
    .Attr("t: list({string, bool})")
    .SetShapeFn(shape_inference::UnknownShape);
485 
// Attr-only ops (no inputs or outputs unless noted), one per attr
// kind/constraint.
REGISTER_OP("Attr").Attr("a: int").SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("AttrFloat")
    .Attr("a: float")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("AttrBool")
    .Attr("a: bool")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("AttrBoolList")
    .Attr("a: list(bool)")
    .SetShapeFn(shape_inference::UnknownShape);

// Minimum constraints: an int of at least 5, and a list of at least 2 ints.
REGISTER_OP("AttrMin")
    .Attr("a: int >= 5")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("AttrListMin")
    .Attr("a: list(int) >= 2")
    .SetShapeFn(shape_inference::UnknownShape);

// String attrs restricted to an enumerated set of values.
REGISTER_OP("AttrEnum")
    .Attr("a: {'apples', 'oranges'}")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("AttrEnumList")
    .Attr("a: list({'apples', 'oranges'})")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("AttrShape")
    .Attr("a: shape")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("AttrShapeList")
    .Attr("a: list(shape)")
    .SetShapeFn(shape_inference::UnknownShape);

// Declared identically to AttrShape / AttrShapeList. Only the op name differs.
REGISTER_OP("AttrPartialShape")
    .Attr("a: shape")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("AttrPartialShapeList")
    .Attr("a: list(shape)")
    .SetShapeFn(shape_inference::UnknownShape);

// Defaulted attrs, including an empty-list default.
REGISTER_OP("AttrDefault")
    .Attr("a: string = 'banana'")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("AttrListDefault")
    .Attr("a: list(int) = [5, 15]")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("AttrEmptyListDefault")
    .Attr("a: list(float) = []")
    .SetShapeFn(shape_inference::UnknownShape);

// The attr is named "range".
// NOTE(review): presumably exercises a name that is a Python builtin; confirm.
REGISTER_OP("ReservedAttr")
    .Attr("range: int")
    .SetShapeFn(shape_inference::UnknownShape);

// Inputs whose type attr defaults to DT_INT32.
REGISTER_OP("AttrTypeDefault")
    .Input("a: T")
    .Attr("T: type = DT_INT32")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("AttrListTypeDefault")
    .Input("a: N * T")
    .Input("b: N * T")
    .Attr("T: type = DT_INT32")
    .Attr("N: int")
    .SetShapeFn(shape_inference::UnknownShape);
559 
// Ops with "N * <type>" list inputs: N ints, N values of a polymorphic T, or N
// values of a restricted T. Some take two lists sharing the same length attr.
REGISTER_OP("NIntsIn")
    .Input("a: N * int32")
    .Attr("N: int >= 2")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("NPolymorphicIn")
    .Input("a: N * T")
    .Attr("T: type")
    .Attr("N: int >= 2")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("NPolymorphicRestrictIn")
    .Input("a: N * T")
    .Attr("T: {string, bool}")
    .Attr("N: int >= 2")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("NInTwice")
    .Input("a: N * int32")
    .Input("b: N * string")
    .Attr("N: int >= 0")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("NInPolymorphicTwice")
    .Input("a: N * T")
    .Input("b: N * T")
    .Attr("T: type")
    .Attr("N: int >= 0")
    .SetShapeFn(shape_inference::UnknownShape);

// Two lists of equal length N but independent element types S and T.
REGISTER_OP("NInTwoTypeVariables")
    .Input("a: N * S")
    .Input("b: N * T")
    .Attr("S: type")
    .Attr("T: type")
    .Attr("N: int >= 0")
    .SetShapeFn(shape_inference::UnknownShape);

// Two lists of a shared element type T with independent lengths N and M.
REGISTER_OP("InPolymorphicTwice")
    .Input("a: N * T")
    .Input("b: M * T")
    .Attr("T: type = DT_INT32")
    .Attr("N: int >= 0")
    .Attr("M: int >= 0")
    .SetShapeFn(shape_inference::UnknownShape);

// Output-side counterparts of the list inputs above, some with defaults for N
// and/or T.
REGISTER_OP("NIntsOut")
    .Output("a: N * int32")
    .Attr("N: int >= 2")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("NIntsOutDefault")
    .Output("a: N * int32")
    .Attr("N: int >= 2 = 3")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("NPolymorphicOut")
    .Output("a: N * T")
    .Attr("T: type")
    .Attr("N: int >= 2")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("NPolymorphicOutDefault")
    .Output("a: N * T")
    .Attr("T: type = DT_BOOL")
    .Attr("N: int >= 2 = 2")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("NPolymorphicRestrictOut")
    .Output("a: N * T")
    .Attr("T: {string, bool}")
    .Attr("N: int >= 2")
    .SetShapeFn(shape_inference::UnknownShape);
633 
// Reference-typed (Ref(T)) inputs and outputs of a polymorphic type T.
REGISTER_OP("RefIn")
    .Input("a: Ref(T)")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("TwoRefsIn")
    .Input("a: Ref(T)")
    .Input("b: Ref(T)")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("RefOut")
    .Output("a: Ref(T)")
    .Attr("T: type")
    .SetShapeFn(shape_inference::UnknownShape);

// Ops whose outputs mix fixed, length-attr and type-list forms.
REGISTER_OP("SimpleStruct")
    .Output("a: n_a * int32")
    .Attr("n_a: int >= 0")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("MixedStruct")
    .Output("a: n_a * int32")
    .Output("b: float")
    .Attr("n_a: int >= 0")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("ComplexStruct")
    .Output("a: n_a * int32")
    .Output("b: n_b * int64")
    .Output("c: t_c")
    .Attr("n_a: int >= 0")
    .Attr("n_b: int >= 0")
    .Attr("t_c: list(type) >= 0")
    .SetShapeFn(shape_inference::UnknownShape);
669 
670 // An op which returns its own device placement as a string, useful for testing
671 // where ops get placed.
672 REGISTER_OP("DevicePlacementOp")
673     .Output("device: string")
674     .SetIsStateful()
675     .SetShapeFn(shape_inference::ScalarShape);
676 
677 class DevicePlacementOp : public OpKernel {
678  public:
679   using OpKernel::OpKernel;
680 
Compute(OpKernelContext * ctx)681   void Compute(OpKernelContext* ctx) override {
682     Tensor* output;
683     OP_REQUIRES_OK(ctx,
684                    ctx->allocate_output("device", TensorShape({}), &output));
685     output->scalar<tstring>()() = ctx->device()->name();
686   }
687 };
688 
689 REGISTER_KERNEL_BUILDER(Name("DevicePlacementOp").Device(DEVICE_CPU),
690                         DevicePlacementOp);
691 REGISTER_KERNEL_BUILDER(Name("DevicePlacementOp").Device(DEVICE_GPU),
692                         DevicePlacementOp);
693 
694 // An op which returns the dtype of the tensor it was passed in. It expects
695 // DT_UINT8.
696 REGISTER_OP("DtypeWithDefaultOp")
697     .Input("in: T")
698     .Attr("T: type = DT_UINT8")
699     .Output("dtype: string")
700     .SetIsStateful()
701     .SetShapeFn(shape_inference::ScalarShape);
702 
703 class DTypeWithDefaultOp : public OpKernel {
704  public:
705   using OpKernel::OpKernel;
706 
Compute(OpKernelContext * ctx)707   void Compute(OpKernelContext* ctx) override {
708     const Tensor& input = ctx->input(0);
709     Tensor* output;
710     OP_REQUIRES_OK(ctx,
711                    ctx->allocate_output("dtype", TensorShape({}), &output));
712     output->scalar<tstring>()() = tensorflow::DataTypeString(input.dtype());
713   }
714 };
715 
716 REGISTER_KERNEL_BUILDER(Name("DtypeWithDefaultOp").Device(DEVICE_CPU),
717                         DTypeWithDefaultOp);
718 }  // end namespace tensorflow
719