1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/framework/common_shape_fns.h" 17 #include "tensorflow/core/framework/op.h" 18 #include "tensorflow/core/framework/op_kernel.h" 19 #include "tensorflow/core/framework/resource_mgr.h" 20 #include "tensorflow/core/lib/core/status.h" 21 #include "tensorflow/core/public/version.h" 22 23 namespace tensorflow { 24 25 REGISTER_OP("KernelLabel") 26 .Output("result: string") 27 .SetShapeFn(shape_inference::ScalarShape); 28 29 REGISTER_OP("KernelLabelRequired") 30 .Input("input: int32") 31 .Output("result: string") __anonf24a52a80102(shape_inference::InferenceContext* c) 32 .SetShapeFn([](shape_inference::InferenceContext* c) { 33 shape_inference::ShapeHandle out; 34 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &out)); 35 c->set_output(0, c->Scalar()); 36 return Status::OK(); 37 }); 38 39 REGISTER_OP("GraphDefVersion") 40 .Output("version: int32") 41 .SetIsStateful() 42 .SetShapeFn(shape_inference::ScalarShape); 43 44 REGISTER_OP("RequiresOlderGraphVersion") 45 .Output("version: int32") 46 .SetIsStateful() __anonf24a52a80202(shape_inference::InferenceContext* c) 47 .SetShapeFn([](shape_inference::InferenceContext* c) { 48 if (c->graph_def_version() != TF_GRAPH_DEF_VERSION - 1) { 49 return errors::InvalidArgument("Wrong graph version for shape"); 50 } 51 return shape_inference::ScalarShape(c); 52 }); 53 54 REGISTER_OP("Old") 55 .SetShapeFn(shape_inference::UnknownShape) 56 .Deprecated(8, "For reasons"); 57 58 REGISTER_RESOURCE_HANDLE_OP(StubResource); 59 60 REGISTER_OP("ResourceInitializedOp") 61 .Input("resource: resource") 62 .Output("initialized: bool") 63 .SetShapeFn(shape_inference::ScalarShape); 64 65 REGISTER_OP("ResourceCreateOp") 66 .Input("resource: resource") 67 .SetShapeFn(shape_inference::UnknownShape); 68 69 REGISTER_OP("ResourceUsingOp") 70 .Input("resource: resource") 71 .SetShapeFn(shape_inference::UnknownShape); 72 73 REGISTER_OP("TestStringOutput") 74 .Input("input: float") 75 .Output("output1: float") 76 .Output("output2: string") 77 .SetShapeFn(shape_inference::UnknownShape); 78 79 REGISTER_OP("Namespace>TestStringOutput") 80 .Input("input: float") 81 .Output("output1: float") 82 .Output("output2: string") 83 .SetShapeFn(shape_inference::UnknownShape); 84 85 REGISTER_OP("TestAttr") 86 .Output("out: T") 87 .Attr("T: {float, double}") 88 .SetShapeFn(shape_inference::UnknownShape); 89 90 namespace { 91 enum KernelLabel { DEFAULT_LABEL, OVERLOAD_1_LABEL, OVERLOAD_2_LABEL }; 92 } // namespace 93 94 template <KernelLabel KL> 95 class KernelLabelOp : public OpKernel { 96 public: 97 using OpKernel::OpKernel; 98 Compute(OpKernelContext * ctx)99 void Compute(OpKernelContext* ctx) override { 100 Tensor* output; 101 OP_REQUIRES_OK(ctx, 102 ctx->allocate_output("result", TensorShape({}), &output)); 103 switch (KL) { 104 case DEFAULT_LABEL: 105 output->scalar<tstring>()() = "My label is: default"; 106 break; 107 case OVERLOAD_1_LABEL: 108 output->scalar<tstring>()() = "My label is: overload_1"; 109 break; 110 case OVERLOAD_2_LABEL: 111 output->scalar<tstring>()() = "My label is: overload_2"; 112 break; 113 } 114 } 115 }; 116 117 REGISTER_KERNEL_BUILDER(Name("KernelLabel").Device(DEVICE_CPU), 118 KernelLabelOp<DEFAULT_LABEL>); 119 REGISTER_KERNEL_BUILDER(Name("KernelLabel") 120 .Device(DEVICE_CPU) 121 .Label("overload_1"), 122 KernelLabelOp<OVERLOAD_1_LABEL>); 123 REGISTER_KERNEL_BUILDER(Name("KernelLabel") 124 .Device(DEVICE_CPU) 125 .Label("overload_2"), 126 KernelLabelOp<OVERLOAD_2_LABEL>); 127 128 // All "KernelLabelRequired" kernels have labels 129 REGISTER_KERNEL_BUILDER( 130 Name("KernelLabelRequired").Device(DEVICE_CPU).Label("overload_1"), 131 KernelLabelOp<OVERLOAD_1_LABEL>); 132 REGISTER_KERNEL_BUILDER( 133 Name("KernelLabelRequired").Device(DEVICE_CPU).Label("overload_2"), 134 KernelLabelOp<OVERLOAD_2_LABEL>); 135 136 class GraphDefVersionOp : public OpKernel { 137 public: GraphDefVersionOp(OpKernelConstruction * ctx)138 explicit GraphDefVersionOp(OpKernelConstruction* ctx) 139 : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {} 140 Compute(OpKernelContext * ctx)141 void Compute(OpKernelContext* ctx) override { 142 Tensor* output = nullptr; 143 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); 144 output->scalar<int>()() = graph_def_version_; 145 } 146 147 private: 148 const int graph_def_version_; 149 }; 150 151 REGISTER_KERNEL_BUILDER(Name("GraphDefVersion").Device(DEVICE_CPU), 152 GraphDefVersionOp); 153 154 class OldOp : public OpKernel { 155 public: OldOp(OpKernelConstruction * ctx)156 explicit OldOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 157 Compute(OpKernelContext * ctx)158 void Compute(OpKernelContext* ctx) override {} 159 }; 160 161 REGISTER_KERNEL_BUILDER(Name("Old").Device(DEVICE_CPU), OldOp); 162 163 // Stubbed-out resource to test resource handle ops. 164 class StubResource : public ResourceBase { 165 public: DebugString() const166 string DebugString() const override { return ""; } 167 }; 168 169 REGISTER_RESOURCE_HANDLE_KERNEL(StubResource); 170 171 REGISTER_KERNEL_BUILDER(Name("ResourceInitializedOp").Device(DEVICE_CPU), 172 IsResourceInitialized<StubResource>); 173 174 class ResourceCreateOp : public OpKernel { 175 public: ResourceCreateOp(OpKernelConstruction * c)176 explicit ResourceCreateOp(OpKernelConstruction* c) : OpKernel(c) {} 177 Compute(OpKernelContext * c)178 void Compute(OpKernelContext* c) override { 179 OP_REQUIRES_OK(c, 180 CreateResource(c, HandleFromInput(c, 0), new StubResource)); 181 } 182 }; 183 184 REGISTER_KERNEL_BUILDER(Name("ResourceCreateOp").Device(DEVICE_CPU), 185 ResourceCreateOp); 186 187 // Uses a ResourceHandle to check its validity. 188 class ResourceUsingOp : public OpKernel { 189 public: ResourceUsingOp(OpKernelConstruction * context)190 explicit ResourceUsingOp(OpKernelConstruction* context) : OpKernel(context) {} 191 Compute(OpKernelContext * ctx)192 void Compute(OpKernelContext* ctx) override { 193 StubResource* unused; 194 OP_REQUIRES_OK(ctx, LookupResource<StubResource>( 195 ctx, HandleFromInput(ctx, 0), &unused)); 196 } 197 }; 198 199 REGISTER_KERNEL_BUILDER(Name("ResourceUsingOp").Device(DEVICE_CPU), 200 ResourceUsingOp); 201 202 class TestAttrOp : public OpKernel { 203 public: TestAttrOp(OpKernelConstruction * ctx)204 explicit TestAttrOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 205 Compute(OpKernelContext * ctx)206 void Compute(OpKernelContext* ctx) override { 207 Tensor* output; 208 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); 209 output->scalar<float>()() = 1.0; 210 } 211 }; 212 213 REGISTER_KERNEL_BUILDER( 214 Name("TestAttr").Device(DEVICE_CPU).TypeConstraint<float>("T"), TestAttrOp); 215 216 // Various test ops without kernels. These are used to test graph construction. 217 218 REGISTER_OP("A") 219 .Output("out: float32") 220 .SetShapeFn(shape_inference::UnknownShape); 221 222 REGISTER_OP("B") 223 .Output("out: float32") 224 .SetShapeFn(shape_inference::UnknownShape); 225 226 REGISTER_OP("Foo1") 227 .Input("a: float32") 228 .Input("b: int32") 229 .Input("c: int32") 230 .Output("d: float32") 231 .Output("e: int32") 232 .SetShapeFn(shape_inference::UnknownShape); 233 234 REGISTER_OP("Foo2") 235 .Input("a: float32") 236 .Input("b: string") 237 .Input("c: string") 238 .Output("d: float32") 239 .Output("e: int32") 240 .SetShapeFn(shape_inference::UnknownShape); 241 242 REGISTER_OP("Foo3") 243 .Input("a: float32") 244 .Input("b: string") 245 .Input("c: float32") 246 .Output("d: float32") 247 .Output("e: int32") 248 .SetShapeFn(shape_inference::UnknownShape); 249 250 REGISTER_OP("CopyOp").Input("a: T").Output("b: T").Attr("T: type").SetShapeFn( 251 shape_inference::UnknownShape); 252 253 REGISTER_OP("None").SetShapeFn(shape_inference::UnknownShape); 254 255 REGISTER_OP("IntOutput") 256 .Output("a: int32") 257 .SetShapeFn(shape_inference::UnknownShape); 258 259 REGISTER_OP("Int64Output") 260 .Output("out: int64") 261 .SetShapeFn(shape_inference::UnknownShape); 262 263 REGISTER_OP("RefOutput") 264 .Output("a: Ref(int32)") 265 .SetShapeFn(shape_inference::UnknownShape); 266 267 REGISTER_OP("FloatOutput") 268 .Output("a: float32") 269 .SetShapeFn(shape_inference::UnknownShape); 270 271 REGISTER_OP("TwoFloatOutputs") 272 .Output("a: float32") 273 .Output("b: float32") 274 .SetShapeFn(shape_inference::UnknownShape); 275 276 REGISTER_OP("FiveFloatOutputs") 277 .Output("a: float32") 278 .Output("b: float32") 279 .Output("c: float32") 280 .Output("d: float32") 281 .Output("e: float32") 282 .SetShapeFn(shape_inference::UnknownShape); 283 284 REGISTER_OP("RefOutputFloatOutput") 285 .Output("a: Ref(float32)") 286 .Output("b: float32") 287 .SetShapeFn(shape_inference::UnknownShape); 288 289 REGISTER_OP("RefInputFloatInput") 290 .Input("a: Ref(float)") 291 .Input("b: float") 292 .SetShapeFn(shape_inference::UnknownShape); 293 294 REGISTER_OP("IntInput") 295 .Input("a: int32") 296 .SetShapeFn(shape_inference::UnknownShape); 297 298 REGISTER_OP("IntInputIntOutput") 299 .Input("a: int32") 300 .Output("b: int32") 301 .SetShapeFn(shape_inference::UnknownShape); 302 303 REGISTER_OP("FloatInput") 304 .Input("a: float32") 305 .SetShapeFn(shape_inference::UnknownShape); 306 307 REGISTER_OP("TwoIntOutputs") 308 .Output("a: int32") 309 .Output("b: int32") 310 .SetShapeFn(shape_inference::UnknownShape); 311 312 REGISTER_OP("IntOutputFloatOutput") 313 .Output("a: int32") 314 .Output("b: float32") 315 .SetShapeFn(shape_inference::UnknownShape); 316 317 REGISTER_OP("FloatOutputStringOutput") 318 .Output("a: float32") 319 .Output("b: string") 320 .SetShapeFn(shape_inference::UnknownShape); 321 322 REGISTER_OP("TwoIntInputs") 323 .Input("a: int32") 324 .Input("b: int32") 325 .SetShapeFn(shape_inference::UnknownShape); 326 327 REGISTER_OP("TwoFloatInputs") 328 .Input("a: float32") 329 .Input("b: float32") 330 .SetShapeFn(shape_inference::UnknownShape); 331 332 REGISTER_OP("IntInputFloatInput") 333 .Input("a: int32") 334 .Input("b: float32") 335 .SetShapeFn(shape_inference::UnknownShape); 336 337 REGISTER_OP("RefInputIntInput") 338 .Input("a: Ref(int32)") 339 .Input("b: int32") 340 .SetShapeFn(shape_inference::UnknownShape); 341 342 REGISTER_OP("TwoFloatInputsFloatOutput") 343 .Input("a: float32") 344 .Input("b: float32") 345 .Output("c: float32") 346 .SetShapeFn(shape_inference::UnknownShape); 347 348 REGISTER_OP("TwoFloatInputsIntOutput") 349 .Input("a: float32") 350 .Input("b: float32") 351 .Output("c: int32") 352 .SetShapeFn(shape_inference::UnknownShape); 353 354 REGISTER_OP("RefInputFloatInputIntOutput") 355 .Input("a: Ref(float32)") 356 .Input("b: float32") 357 .Output("c: int32") 358 .SetShapeFn(shape_inference::UnknownShape); 359 360 REGISTER_OP("ListInput") 361 .Input("a: N * T") 362 .Attr("N: int >= 1") 363 .Attr("T: type") 364 .SetShapeFn(shape_inference::UnknownShape); 365 366 REGISTER_OP("ListOutput") 367 .Output("a: T") 368 .Attr("T: list(type) >= 1") 369 .SetShapeFn(shape_inference::UnknownShape); 370 371 REGISTER_OP("Unary").Input("a: T").Output("b: T").Attr("T: type").SetShapeFn( 372 shape_inference::UnknownShape); 373 374 REGISTER_OP("OpWithDefaultAttr") 375 .Output("a: int32") 376 .Attr("default_float: float = 123.0") 377 .SetShapeFn(shape_inference::UnknownShape); 378 379 REGISTER_OP("OpWithFutureDefaultAttr") 380 .SetShapeFn(shape_inference::UnknownShape); 381 382 REGISTER_OP("IntAttr") 383 .Output("out: int64") 384 .Attr("foo: int = 1") 385 .SetShapeFn(shape_inference::UnknownShape); 386 387 REGISTER_OP("StringListAttr") 388 .Attr("a: list(string)") 389 .Attr("b: string") 390 .SetShapeFn(shape_inference::UnknownShape); 391 392 REGISTER_OP("DefaultAttrs") 393 .Attr("string_val: string = 'abc'") 394 .Attr("string_list_val: list(string) = ['abc', '']") 395 .Attr("int_val: int = 123") 396 .Attr("int_list_val: list(int) = [1, 2, 3]") 397 .Attr("float_val: float = 10.0") 398 .Attr("float_list_val: list(float) = [10.0]") 399 .Attr("bool_val: bool = true") 400 .Attr("bool_list_val: list(bool) = [true, false]") 401 .Attr("type_val: type = DT_INT32") 402 .Attr("type_list_val: list(type) = [DT_INT32, DT_FLOAT]") 403 .Attr("shape_val: shape = { dim { size: 2 } dim { size: 1 } }") 404 .Attr("shape_list_val: list(shape) = [{}, { dim { size: 1} }]") 405 .Attr("tensor_val: tensor = { dtype: DT_INT32 tensor_shape: {} int_val: 1}") 406 .Attr( 407 "tensor_list_val: list(tensor) = " 408 "[{ dtype: DT_INT32 tensor_shape: {} int_val: 1}]") 409 .SetShapeFn(shape_inference::UnknownShape); 410 411 REGISTER_OP("FuncAttr") 412 .Attr("f: func") 413 .SetShapeFn(shape_inference::UnknownShape); 414 415 REGISTER_OP("FuncListAttr") 416 .Attr("f: list(func)") 417 .SetShapeFn(shape_inference::UnknownShape); 418 419 REGISTER_OP("Simple") 420 .Input("a: int32") 421 .Output("out: float") 422 .SetShapeFn(shape_inference::UnknownShape); 423 424 REGISTER_OP("OutT").Output("a: T").Attr("T: type").SetShapeFn( 425 shape_inference::UnknownShape); 426 427 REGISTER_OP("ReservedInput") 428 .Input("input: int32") 429 .SetShapeFn(shape_inference::UnknownShape); 430 431 REGISTER_OP("Polymorphic") 432 .Input("a: T") 433 .Output("out: T") 434 .Attr("T: type") 435 .SetShapeFn(shape_inference::UnknownShape); 436 437 REGISTER_OP("PolymorphicOut") 438 .Output("out: T") 439 .Attr("T: type") 440 .SetShapeFn(shape_inference::UnknownShape); 441 442 REGISTER_OP("PolymorphicDefaultOut") 443 .Output("out: T") 444 .Attr("T: type = DT_STRING") 445 .SetShapeFn(shape_inference::UnknownShape); 446 447 REGISTER_OP("Binary") 448 .Input("a: T") 449 .Input("b: T") 450 .Output("out: T") 451 .Attr("T: type") 452 .SetShapeFn(shape_inference::UnknownShape); 453 454 REGISTER_OP("Restrict") 455 .Input("a: T") 456 .Output("out: T") 457 .Attr("T: {string, bool}") 458 .SetShapeFn(shape_inference::UnknownShape); 459 460 REGISTER_OP("TypeList") 461 .Input("a: T") 462 .Attr("T: list(type) >= 0") 463 .SetShapeFn(shape_inference::UnknownShape); 464 465 REGISTER_OP("TypeListTwice") 466 .Input("a: T") 467 .Input("b: T") 468 .Attr("T: list(type) >= 0") 469 .SetShapeFn(shape_inference::UnknownShape); 470 471 REGISTER_OP("OutTypeList") 472 .Output("out: T") 473 .Attr("T: list(type) >= 0") 474 .SetShapeFn(shape_inference::UnknownShape); 475 476 REGISTER_OP("TypeListRestrict") 477 .Input("a: T") 478 .Attr("T: list({string, bool})") 479 .SetShapeFn(shape_inference::UnknownShape); 480 481 REGISTER_OP("OutTypeListRestrict") 482 .Output("out: t") 483 .Attr("t: list({string, bool})") 484 .SetShapeFn(shape_inference::UnknownShape); 485 486 REGISTER_OP("Attr").Attr("a: int").SetShapeFn(shape_inference::UnknownShape); 487 488 REGISTER_OP("AttrFloat") 489 .Attr("a: float") 490 .SetShapeFn(shape_inference::UnknownShape); 491 492 REGISTER_OP("AttrBool") 493 .Attr("a: bool") 494 .SetShapeFn(shape_inference::UnknownShape); 495 496 REGISTER_OP("AttrBoolList") 497 .Attr("a: list(bool)") 498 .SetShapeFn(shape_inference::UnknownShape); 499 500 REGISTER_OP("AttrMin") 501 .Attr("a: int >= 5") 502 .SetShapeFn(shape_inference::UnknownShape); 503 504 REGISTER_OP("AttrListMin") 505 .Attr("a: list(int) >= 2") 506 .SetShapeFn(shape_inference::UnknownShape); 507 508 REGISTER_OP("AttrEnum") 509 .Attr("a: {'apples', 'oranges'}") 510 .SetShapeFn(shape_inference::UnknownShape); 511 512 REGISTER_OP("AttrEnumList") 513 .Attr("a: list({'apples', 'oranges'})") 514 .SetShapeFn(shape_inference::UnknownShape); 515 516 REGISTER_OP("AttrShape") 517 .Attr("a: shape") 518 .SetShapeFn(shape_inference::UnknownShape); 519 520 REGISTER_OP("AttrShapeList") 521 .Attr("a: list(shape)") 522 .SetShapeFn(shape_inference::UnknownShape); 523 524 REGISTER_OP("AttrPartialShape") 525 .Attr("a: shape") 526 .SetShapeFn(shape_inference::UnknownShape); 527 528 REGISTER_OP("AttrPartialShapeList") 529 .Attr("a: list(shape)") 530 .SetShapeFn(shape_inference::UnknownShape); 531 532 REGISTER_OP("AttrDefault") 533 .Attr("a: string = 'banana'") 534 .SetShapeFn(shape_inference::UnknownShape); 535 536 REGISTER_OP("AttrListDefault") 537 .Attr("a: list(int) = [5, 15]") 538 .SetShapeFn(shape_inference::UnknownShape); 539 540 REGISTER_OP("AttrEmptyListDefault") 541 .Attr("a: list(float) = []") 542 .SetShapeFn(shape_inference::UnknownShape); 543 544 REGISTER_OP("ReservedAttr") 545 .Attr("range: int") 546 .SetShapeFn(shape_inference::UnknownShape); 547 548 REGISTER_OP("AttrTypeDefault") 549 .Input("a: T") 550 .Attr("T: type = DT_INT32") 551 .SetShapeFn(shape_inference::UnknownShape); 552 553 REGISTER_OP("AttrListTypeDefault") 554 .Input("a: N * T") 555 .Input("b: N * T") 556 .Attr("T: type = DT_INT32") 557 .Attr("N: int") 558 .SetShapeFn(shape_inference::UnknownShape); 559 560 REGISTER_OP("NIntsIn") 561 .Input("a: N * int32") 562 .Attr("N: int >= 2") 563 .SetShapeFn(shape_inference::UnknownShape); 564 565 REGISTER_OP("NPolymorphicIn") 566 .Input("a: N * T") 567 .Attr("T: type") 568 .Attr("N: int >= 2") 569 .SetShapeFn(shape_inference::UnknownShape); 570 571 REGISTER_OP("NPolymorphicRestrictIn") 572 .Input("a: N * T") 573 .Attr("T: {string, bool}") 574 .Attr("N: int >= 2") 575 .SetShapeFn(shape_inference::UnknownShape); 576 577 REGISTER_OP("NInTwice") 578 .Input("a: N * int32") 579 .Input("b: N * string") 580 .Attr("N: int >= 0") 581 .SetShapeFn(shape_inference::UnknownShape); 582 583 REGISTER_OP("NInPolymorphicTwice") 584 .Input("a: N * T") 585 .Input("b: N * T") 586 .Attr("T: type") 587 .Attr("N: int >= 0") 588 .SetShapeFn(shape_inference::UnknownShape); 589 590 REGISTER_OP("NInTwoTypeVariables") 591 .Input("a: N * S") 592 .Input("b: N * T") 593 .Attr("S: type") 594 .Attr("T: type") 595 .Attr("N: int >= 0") 596 .SetShapeFn(shape_inference::UnknownShape); 597 598 REGISTER_OP("InPolymorphicTwice") 599 .Input("a: N * T") 600 .Input("b: M * T") 601 .Attr("T: type = DT_INT32") 602 .Attr("N: int >= 0") 603 .Attr("M: int >= 0") 604 .SetShapeFn(shape_inference::UnknownShape); 605 606 REGISTER_OP("NIntsOut") 607 .Output("a: N * int32") 608 .Attr("N: int >= 2") 609 .SetShapeFn(shape_inference::UnknownShape); 610 611 REGISTER_OP("NIntsOutDefault") 612 .Output("a: N * int32") 613 .Attr("N: int >= 2 = 3") 614 .SetShapeFn(shape_inference::UnknownShape); 615 616 REGISTER_OP("NPolymorphicOut") 617 .Output("a: N * T") 618 .Attr("T: type") 619 .Attr("N: int >= 2") 620 .SetShapeFn(shape_inference::UnknownShape); 621 622 REGISTER_OP("NPolymorphicOutDefault") 623 .Output("a: N * T") 624 .Attr("T: type = DT_BOOL") 625 .Attr("N: int >= 2 = 2") 626 .SetShapeFn(shape_inference::UnknownShape); 627 628 REGISTER_OP("NPolymorphicRestrictOut") 629 .Output("a: N * T") 630 .Attr("T: {string, bool}") 631 .Attr("N: int >= 2") 632 .SetShapeFn(shape_inference::UnknownShape); 633 634 REGISTER_OP("RefIn") 635 .Input("a: Ref(T)") 636 .Attr("T: type") 637 .SetShapeFn(shape_inference::UnknownShape); 638 639 REGISTER_OP("TwoRefsIn") 640 .Input("a: Ref(T)") 641 .Input("b: Ref(T)") 642 .Attr("T: type") 643 .SetShapeFn(shape_inference::UnknownShape); 644 645 REGISTER_OP("RefOut") 646 .Output("a: Ref(T)") 647 .Attr("T: type") 648 .SetShapeFn(shape_inference::UnknownShape); 649 650 REGISTER_OP("SimpleStruct") 651 .Output("a: n_a * int32") 652 .Attr("n_a: int >= 0") 653 .SetShapeFn(shape_inference::UnknownShape); 654 655 REGISTER_OP("MixedStruct") 656 .Output("a: n_a * int32") 657 .Output("b: float") 658 .Attr("n_a: int >= 0") 659 .SetShapeFn(shape_inference::UnknownShape); 660 661 REGISTER_OP("ComplexStruct") 662 .Output("a: n_a * int32") 663 .Output("b: n_b * int64") 664 .Output("c: t_c") 665 .Attr("n_a: int >= 0") 666 .Attr("n_b: int >= 0") 667 .Attr("t_c: list(type) >= 0") 668 .SetShapeFn(shape_inference::UnknownShape); 669 670 // An op which returns its own device placement as a string, useful for testing 671 // where ops get placed. 672 REGISTER_OP("DevicePlacementOp") 673 .Output("device: string") 674 .SetIsStateful() 675 .SetShapeFn(shape_inference::ScalarShape); 676 677 class DevicePlacementOp : public OpKernel { 678 public: 679 using OpKernel::OpKernel; 680 Compute(OpKernelContext * ctx)681 void Compute(OpKernelContext* ctx) override { 682 Tensor* output; 683 OP_REQUIRES_OK(ctx, 684 ctx->allocate_output("device", TensorShape({}), &output)); 685 output->scalar<tstring>()() = ctx->device()->name(); 686 } 687 }; 688 689 REGISTER_KERNEL_BUILDER(Name("DevicePlacementOp").Device(DEVICE_CPU), 690 DevicePlacementOp); 691 REGISTER_KERNEL_BUILDER(Name("DevicePlacementOp").Device(DEVICE_GPU), 692 DevicePlacementOp); 693 694 // An op which returns the dtype of the tensor it was passed in. It expects 695 // DT_UINT8. 696 REGISTER_OP("DtypeWithDefaultOp") 697 .Input("in: T") 698 .Attr("T: type = DT_UINT8") 699 .Output("dtype: string") 700 .SetIsStateful() 701 .SetShapeFn(shape_inference::ScalarShape); 702 703 class DTypeWithDefaultOp : public OpKernel { 704 public: 705 using OpKernel::OpKernel; 706 Compute(OpKernelContext * ctx)707 void Compute(OpKernelContext* ctx) override { 708 const Tensor& input = ctx->input(0); 709 Tensor* output; 710 OP_REQUIRES_OK(ctx, 711 ctx->allocate_output("dtype", TensorShape({}), &output)); 712 output->scalar<tstring>()() = tensorflow::DataTypeString(input.dtype()); 713 } 714 }; 715 716 REGISTER_KERNEL_BUILDER(Name("DtypeWithDefaultOp").Device(DEVICE_CPU), 717 DTypeWithDefaultOp); 718 } // end namespace tensorflow 719