1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// XLA-specific ops for 2D, 3D and depthwise 2D convolutions, together with
// their gradients with respect to the input and the filter.
17 
#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/ops_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
38 
39 namespace tensorflow {
40 namespace {
41 
42 class ConvOp : public XlaOpKernel {
43  public:
ConvOp(OpKernelConstruction * ctx,int num_spatial_dims,bool depthwise)44   explicit ConvOp(OpKernelConstruction* ctx, int num_spatial_dims,
45                   bool depthwise)
46       : XlaOpKernel(ctx) {
47     xla::StatusOr<ConvOpAttrs> attrs =
48         ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx);
49     OP_REQUIRES_OK(ctx, attrs.status());
50     attrs_ = attrs.ValueOrDie();
51   }
52 
Compile(XlaOpKernelContext * ctx)53   void Compile(XlaOpKernelContext* ctx) override {
54     xla::StatusOr<xla::XlaOp> conv = MakeXlaForwardConvOp(
55         ctx->op_kernel().type_string(), ctx->Input(0), ctx->Input(1), attrs_);
56     OP_REQUIRES_OK(ctx, conv.status());
57     ctx->SetOutput(0, conv.ValueOrDie());
58   }
59 
60  protected:
61   ConvOpAttrs attrs_;
62 
63  private:
64   TF_DISALLOW_COPY_AND_ASSIGN(ConvOp);
65 };
66 
67 class Conv2DOp : public ConvOp {
68  public:
Conv2DOp(OpKernelConstruction * ctx)69   explicit Conv2DOp(OpKernelConstruction* ctx)
70       : ConvOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) {}
71 };
72 REGISTER_XLA_OP(Name("Conv2D"), Conv2DOp);
73 
74 class Conv3DOp : public ConvOp {
75  public:
Conv3DOp(OpKernelConstruction * ctx)76   explicit Conv3DOp(OpKernelConstruction* ctx)
77       : ConvOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) {}
78 };
79 REGISTER_XLA_OP(Name("Conv3D"), Conv3DOp);
80 
81 class DepthwiseConv2DOp : public ConvOp {
82  public:
DepthwiseConv2DOp(OpKernelConstruction * ctx)83   explicit DepthwiseConv2DOp(OpKernelConstruction* ctx)
84       : ConvOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {}
85 };
86 REGISTER_XLA_OP(Name("DepthwiseConv2dNative"), DepthwiseConv2DOp);
87 
88 // Backprop for input.
89 class ConvBackpropInputOp : public XlaOpKernel {
90  public:
ConvBackpropInputOp(OpKernelConstruction * ctx,int num_spatial_dims,bool depthwise)91   explicit ConvBackpropInputOp(OpKernelConstruction* ctx, int num_spatial_dims,
92                                bool depthwise)
93       : XlaOpKernel(ctx) {
94     xla::StatusOr<ConvOpAttrs> attrs =
95         ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx);
96     OP_REQUIRES_OK(ctx, attrs.status());
97     attrs_ = attrs.ValueOrDie();
98   }
99 
Compile(XlaOpKernelContext * ctx)100   void Compile(XlaOpKernelContext* ctx) override {
101     TensorShape input_tensor_shape;
102     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_tensor_shape));
103     xla::Shape input_shape =
104         TensorShapeToXLAShape(ctx->input_xla_type(1), input_tensor_shape);
105 
106     xla::StatusOr<xla::XlaOp> in_backprop =
107         MakeXlaBackpropInputConvOp(ctx->op_kernel().type_string(), input_shape,
108                                    ctx->Input(1), ctx->Input(2), attrs_);
109     OP_REQUIRES_OK(ctx, in_backprop.status());
110     ctx->SetOutput(0, in_backprop.ValueOrDie());
111   }
112 
113  protected:
114   ConvOpAttrs attrs_;
115 
116  private:
117   TF_DISALLOW_COPY_AND_ASSIGN(ConvBackpropInputOp);
118 };
119 
120 class Conv2DBackpropInputOp : public ConvBackpropInputOp {
121  public:
Conv2DBackpropInputOp(OpKernelConstruction * ctx)122   explicit Conv2DBackpropInputOp(OpKernelConstruction* ctx)
123       : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) {}
124 };
125 REGISTER_XLA_OP(
126     Name("Conv2DBackpropInput").CompileTimeConstantInput("input_sizes"),
127     Conv2DBackpropInputOp);
128 
129 class Conv3DBackpropInputOp : public ConvBackpropInputOp {
130  public:
Conv3DBackpropInputOp(OpKernelConstruction * ctx)131   explicit Conv3DBackpropInputOp(OpKernelConstruction* ctx)
132       : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) {}
133 };
134 REGISTER_XLA_OP(
135     Name("Conv3DBackpropInputV2").CompileTimeConstantInput("input_sizes"),
136     Conv3DBackpropInputOp);
137 
138 class DepthwiseConv2DBackpropInputOp : public ConvBackpropInputOp {
139  public:
DepthwiseConv2DBackpropInputOp(OpKernelConstruction * ctx)140   explicit DepthwiseConv2DBackpropInputOp(OpKernelConstruction* ctx)
141       : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {}
142 };
143 REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropInput")
144                     .CompileTimeConstantInput("input_sizes"),
145                 DepthwiseConv2DBackpropInputOp);
146 
147 class ConvBackpropFilterOp : public XlaOpKernel {
148  public:
ConvBackpropFilterOp(OpKernelConstruction * ctx,int num_spatial_dims,bool depthwise)149   explicit ConvBackpropFilterOp(OpKernelConstruction* ctx, int num_spatial_dims,
150                                 bool depthwise)
151       : XlaOpKernel(ctx) {
152     xla::StatusOr<ConvOpAttrs> attrs =
153         ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx);
154     OP_REQUIRES_OK(ctx, attrs.status());
155     attrs_ = attrs.ValueOrDie();
156   }
157 
Compile(XlaOpKernelContext * ctx)158   void Compile(XlaOpKernelContext* ctx) override {
159     TensorShape filter_tensor_shape;
160     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_tensor_shape));
161     xla::Shape filter_shape =
162         TensorShapeToXLAShape(ctx->input_xla_type(0), filter_tensor_shape);
163 
164     xla::StatusOr<xla::XlaOp> filter_backprop = MakeXlaBackpropFilterConvOp(
165         ctx->op_kernel().type_string(), ctx->Input(0), filter_shape,
166         ctx->Input(2), attrs_);
167     OP_REQUIRES_OK(ctx, filter_backprop.status());
168     ctx->SetOutput(0, filter_backprop.ValueOrDie());
169   }
170 
171  protected:
172   ConvOpAttrs attrs_;
173 
174  private:
175   TF_DISALLOW_COPY_AND_ASSIGN(ConvBackpropFilterOp);
176 };
177 
178 class Conv2DBackpropFilterOp : public ConvBackpropFilterOp {
179  public:
Conv2DBackpropFilterOp(OpKernelConstruction * ctx)180   explicit Conv2DBackpropFilterOp(OpKernelConstruction* ctx)
181       : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) {
182   }
183 };
184 REGISTER_XLA_OP(
185     Name("Conv2DBackpropFilter").CompileTimeConstantInput("filter_sizes"),
186     Conv2DBackpropFilterOp);
187 
188 class Conv3DBackpropFilterOp : public ConvBackpropFilterOp {
189  public:
Conv3DBackpropFilterOp(OpKernelConstruction * ctx)190   explicit Conv3DBackpropFilterOp(OpKernelConstruction* ctx)
191       : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) {
192   }
193 };
194 REGISTER_XLA_OP(
195     Name("Conv3DBackpropFilterV2").CompileTimeConstantInput("filter_sizes"),
196     Conv3DBackpropFilterOp);
197 
198 class DepthwiseConv2DBackpropFilterOp : public ConvBackpropFilterOp {
199  public:
DepthwiseConv2DBackpropFilterOp(OpKernelConstruction * ctx)200   explicit DepthwiseConv2DBackpropFilterOp(OpKernelConstruction* ctx)
201       : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {}
202 };
203 REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropFilter")
204                     .CompileTimeConstantInput("filter_sizes"),
205                 DepthwiseConv2DBackpropFilterOp);
206 
207 }  // namespace
208 }  // namespace tensorflow
209