1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // XLA-specific Ops for 2D convolution. 17 18 #include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h" 19 #include "tensorflow/compiler/tf2xla/shape_util.h" 20 #include "tensorflow/compiler/tf2xla/type_util.h" 21 #include "tensorflow/compiler/tf2xla/xla_helpers.h" 22 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 23 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 24 #include "tensorflow/compiler/xla/client/lib/constants.h" 25 #include "tensorflow/compiler/xla/client/lib/matrix.h" 26 #include "tensorflow/compiler/xla/client/xla_builder.h" 27 #include "tensorflow/compiler/xla/literal_util.h" 28 #include "tensorflow/core/framework/bounds_check.h" 29 #include "tensorflow/core/framework/node_def_util.h" 30 #include "tensorflow/core/framework/numeric_op.h" 31 #include "tensorflow/core/framework/op_kernel.h" 32 #include "tensorflow/core/framework/ops_util.h" 33 #include "tensorflow/core/framework/tensor.h" 34 #include "tensorflow/core/framework/tensor_shape.h" 35 #include "tensorflow/core/framework/tensor_slice.h" 36 #include "tensorflow/core/util/padding.h" 37 #include "tensorflow/core/util/tensor_format.h" 38 39 namespace tensorflow { 40 namespace { 41 42 class ConvOp : public XlaOpKernel { 43 public: ConvOp(OpKernelConstruction * ctx,int num_spatial_dims,bool depthwise)44 explicit ConvOp(OpKernelConstruction* ctx, int num_spatial_dims, 45 bool depthwise) 46 : XlaOpKernel(ctx) { 47 xla::StatusOr<ConvOpAttrs> attrs = 48 ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx); 49 OP_REQUIRES_OK(ctx, attrs.status()); 50 attrs_ = attrs.ValueOrDie(); 51 } 52 Compile(XlaOpKernelContext * ctx)53 void Compile(XlaOpKernelContext* ctx) override { 54 xla::StatusOr<xla::XlaOp> conv = MakeXlaForwardConvOp( 55 ctx->op_kernel().type_string(), ctx->Input(0), ctx->Input(1), attrs_); 56 OP_REQUIRES_OK(ctx, conv.status()); 57 ctx->SetOutput(0, conv.ValueOrDie()); 58 } 59 60 protected: 61 ConvOpAttrs attrs_; 62 63 private: 64 TF_DISALLOW_COPY_AND_ASSIGN(ConvOp); 65 }; 66 67 class Conv2DOp : public ConvOp { 68 public: Conv2DOp(OpKernelConstruction * ctx)69 explicit Conv2DOp(OpKernelConstruction* ctx) 70 : ConvOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) {} 71 }; 72 REGISTER_XLA_OP(Name("Conv2D"), Conv2DOp); 73 74 class Conv3DOp : public ConvOp { 75 public: Conv3DOp(OpKernelConstruction * ctx)76 explicit Conv3DOp(OpKernelConstruction* ctx) 77 : ConvOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) {} 78 }; 79 REGISTER_XLA_OP(Name("Conv3D"), Conv3DOp); 80 81 class DepthwiseConv2DOp : public ConvOp { 82 public: DepthwiseConv2DOp(OpKernelConstruction * ctx)83 explicit DepthwiseConv2DOp(OpKernelConstruction* ctx) 84 : ConvOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {} 85 }; 86 REGISTER_XLA_OP(Name("DepthwiseConv2dNative"), DepthwiseConv2DOp); 87 88 // Backprop for input. 89 class ConvBackpropInputOp : public XlaOpKernel { 90 public: ConvBackpropInputOp(OpKernelConstruction * ctx,int num_spatial_dims,bool depthwise)91 explicit ConvBackpropInputOp(OpKernelConstruction* ctx, int num_spatial_dims, 92 bool depthwise) 93 : XlaOpKernel(ctx) { 94 xla::StatusOr<ConvOpAttrs> attrs = 95 ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx); 96 OP_REQUIRES_OK(ctx, attrs.status()); 97 attrs_ = attrs.ValueOrDie(); 98 } 99 Compile(XlaOpKernelContext * ctx)100 void Compile(XlaOpKernelContext* ctx) override { 101 TensorShape input_tensor_shape; 102 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_tensor_shape)); 103 xla::Shape input_shape = 104 TensorShapeToXLAShape(ctx->input_xla_type(1), input_tensor_shape); 105 106 xla::StatusOr<xla::XlaOp> in_backprop = 107 MakeXlaBackpropInputConvOp(ctx->op_kernel().type_string(), input_shape, 108 ctx->Input(1), ctx->Input(2), attrs_); 109 OP_REQUIRES_OK(ctx, in_backprop.status()); 110 ctx->SetOutput(0, in_backprop.ValueOrDie()); 111 } 112 113 protected: 114 ConvOpAttrs attrs_; 115 116 private: 117 TF_DISALLOW_COPY_AND_ASSIGN(ConvBackpropInputOp); 118 }; 119 120 class Conv2DBackpropInputOp : public ConvBackpropInputOp { 121 public: Conv2DBackpropInputOp(OpKernelConstruction * ctx)122 explicit Conv2DBackpropInputOp(OpKernelConstruction* ctx) 123 : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) {} 124 }; 125 REGISTER_XLA_OP( 126 Name("Conv2DBackpropInput").CompileTimeConstantInput("input_sizes"), 127 Conv2DBackpropInputOp); 128 129 class Conv3DBackpropInputOp : public ConvBackpropInputOp { 130 public: Conv3DBackpropInputOp(OpKernelConstruction * ctx)131 explicit Conv3DBackpropInputOp(OpKernelConstruction* ctx) 132 : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) {} 133 }; 134 REGISTER_XLA_OP( 135 Name("Conv3DBackpropInputV2").CompileTimeConstantInput("input_sizes"), 136 Conv3DBackpropInputOp); 137 138 class DepthwiseConv2DBackpropInputOp : public ConvBackpropInputOp { 139 public: DepthwiseConv2DBackpropInputOp(OpKernelConstruction * ctx)140 explicit DepthwiseConv2DBackpropInputOp(OpKernelConstruction* ctx) 141 : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {} 142 }; 143 REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropInput") 144 .CompileTimeConstantInput("input_sizes"), 145 DepthwiseConv2DBackpropInputOp); 146 147 class ConvBackpropFilterOp : public XlaOpKernel { 148 public: ConvBackpropFilterOp(OpKernelConstruction * ctx,int num_spatial_dims,bool depthwise)149 explicit ConvBackpropFilterOp(OpKernelConstruction* ctx, int num_spatial_dims, 150 bool depthwise) 151 : XlaOpKernel(ctx) { 152 xla::StatusOr<ConvOpAttrs> attrs = 153 ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx); 154 OP_REQUIRES_OK(ctx, attrs.status()); 155 attrs_ = attrs.ValueOrDie(); 156 } 157 Compile(XlaOpKernelContext * ctx)158 void Compile(XlaOpKernelContext* ctx) override { 159 TensorShape filter_tensor_shape; 160 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_tensor_shape)); 161 xla::Shape filter_shape = 162 TensorShapeToXLAShape(ctx->input_xla_type(0), filter_tensor_shape); 163 164 xla::StatusOr<xla::XlaOp> filter_backprop = MakeXlaBackpropFilterConvOp( 165 ctx->op_kernel().type_string(), ctx->Input(0), filter_shape, 166 ctx->Input(2), attrs_); 167 OP_REQUIRES_OK(ctx, filter_backprop.status()); 168 ctx->SetOutput(0, filter_backprop.ValueOrDie()); 169 } 170 171 protected: 172 ConvOpAttrs attrs_; 173 174 private: 175 TF_DISALLOW_COPY_AND_ASSIGN(ConvBackpropFilterOp); 176 }; 177 178 class Conv2DBackpropFilterOp : public ConvBackpropFilterOp { 179 public: Conv2DBackpropFilterOp(OpKernelConstruction * ctx)180 explicit Conv2DBackpropFilterOp(OpKernelConstruction* ctx) 181 : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) { 182 } 183 }; 184 REGISTER_XLA_OP( 185 Name("Conv2DBackpropFilter").CompileTimeConstantInput("filter_sizes"), 186 Conv2DBackpropFilterOp); 187 188 class Conv3DBackpropFilterOp : public ConvBackpropFilterOp { 189 public: Conv3DBackpropFilterOp(OpKernelConstruction * ctx)190 explicit Conv3DBackpropFilterOp(OpKernelConstruction* ctx) 191 : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) { 192 } 193 }; 194 REGISTER_XLA_OP( 195 Name("Conv3DBackpropFilterV2").CompileTimeConstantInput("filter_sizes"), 196 Conv3DBackpropFilterOp); 197 198 class DepthwiseConv2DBackpropFilterOp : public ConvBackpropFilterOp { 199 public: DepthwiseConv2DBackpropFilterOp(OpKernelConstruction * ctx)200 explicit DepthwiseConv2DBackpropFilterOp(OpKernelConstruction* ctx) 201 : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {} 202 }; 203 REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropFilter") 204 .CompileTimeConstantInput("filter_sizes"), 205 DepthwiseConv2DBackpropFilterOp); 206 207 } // namespace 208 } // namespace tensorflow 209