/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"

namespace tensorflow {
namespace {

// Local response normalization
class LRNOp : public XlaOpKernel {
 public:
  explicit LRNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("depth_radius", &depth_radius_));

    // TODO(phawkins): handle non-float types for attributes.
    OP_REQUIRES_OK(ctx, ctx->GetAttr("bias", &bias_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &alpha_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("beta", &beta_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape in_shape = ctx->InputShape(0);
    OP_REQUIRES(ctx, in_shape.dims() == 4,
                errors::InvalidArgument("in must be 4-dimensional"));

    xla::XlaBuilder* builder = ctx->builder();
    xla::XlaOp input = ctx->Input(0);

    // sqr_sum[a, b, c, d] =
    //    sum(input[a, b, c, d - depth_radius : d + depth_radius + 1] ** 2)
    // output = input / (bias + alpha * sqr_sum) ** beta

    // We use a window of depth_radius_ * 2 + 1 to account for the current
    // element and depth_radius_ elements on either side.
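    // The window sum is accumulated in a wider element type where needed
    // (e.g. F32 for half-precision inputs) and converted back afterwards, to
    // avoid losing precision in the reduction.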
    auto accumulation_type = XlaHelpers::SumAccumulationType(input_type(0));
    auto converted = XlaHelpers::ConvertElementType(input, accumulation_type);
    auto squared = xla::Mul(converted, converted);
    auto reduce = xla::ReduceWindow(
        squared, XlaHelpers::Zero(builder, accumulation_type),
        *ctx->GetOrCreateAdd(accumulation_type),
        /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1},
        /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame);
    auto sqr_sum = XlaHelpers::ConvertElementType(reduce, input_type(0));

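    // scale = (bias + alpha * sqr_sum) ** -beta, so the division in the
    // formula above becomes a single multiplication by `scale`.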
    auto scale = xla::Pow(
        xla::Add(xla::ConstantR0<float>(builder, bias_),
                 xla::Mul(xla::ConstantR0<float>(builder, alpha_), sqr_sum)),
        xla::ConstantR0<float>(builder, -beta_));

    ctx->SetOutput(0, xla::Mul(input, scale));
  }

 private:
  int64 depth_radius_;
  float bias_;
  float alpha_;
  float beta_;
};

REGISTER_XLA_OP(Name("LRN"), LRNOp);

class LRNGradOp : public XlaOpKernel {
 public:
  explicit LRNGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("depth_radius", &depth_radius_));

    // TODO(phawkins): handle non-float types for attributes.
    OP_REQUIRES_OK(ctx, ctx->GetAttr("bias", &bias_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &alpha_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("beta", &beta_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape in_grads_shape = ctx->InputShape(0);
    const TensorShape in_image_shape = ctx->InputShape(1);
    const TensorShape out_image_shape = ctx->InputShape(2);

    OP_REQUIRES(ctx, in_grads_shape.dims() == 4 && in_image_shape.dims() == 4,
                errors::InvalidArgument("inputs must be 4-dimensional"));
    const int64 batch = in_grads_shape.dim_size(0);
    const int64 rows = in_grads_shape.dim_size(1);
    const int64 cols = in_grads_shape.dim_size(2);
    const int64 depth = in_grads_shape.dim_size(3);
    OP_REQUIRES(
        ctx, in_image_shape.dim_size(0) == batch &&
                 in_image_shape.dim_size(1) == rows &&
                 in_image_shape.dim_size(2) == cols &&
                 in_image_shape.dim_size(3) == depth &&
                 out_image_shape.dim_size(0) == batch &&
                 out_image_shape.dim_size(1) == rows &&
                 out_image_shape.dim_size(2) == cols &&
                 out_image_shape.dim_size(3) == depth,
        errors::InvalidArgument(
            "input_grads, input_image, and out_image should have the same "
            "shape"));

    xla::XlaBuilder* builder = ctx->builder();
    xla::XlaOp in_grads = ctx->Input(0);
    xla::XlaOp in_image = ctx->Input(1);
    xla::XlaOp out_image = ctx->Input(2);

    // This code is ported from tensorflow/core/kernels/lrn_op.cc. In Python
    // pseudo-code, the Eigen code does this for each spatial position:
    // grads = [0.0] * depth
    // for j in range(depth):
    //   depth_begin = max(0, j - depth_radius)
    //   depth_end = min(depth, j + depth_radius + 1)
    //
    //   norm = 0
    //   for k in range(depth_begin, depth_end):
    //     norm += in_image[k] * in_image[k]
    //   norm = alpha * norm + bias
    //
    //   for k in range(depth_begin, depth_end):
    //     dyi = -2.0 * alpha * beta * in_image[k] * out_image[j] / norm
    //     if k == j:
    //       dyi += norm ** (-beta)
    //     dyi *= out_grads[j]
    //     grads[k] += dyi

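    // Below, those scalar loops are expressed as whole-tensor XLA ops: each
    // per-position sum over the depth window becomes a ReduceWindow over the
    // last (depth) dimension with kSame padding.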
    auto accumulation_type = XlaHelpers::SumAccumulationType(input_type(0));
    auto converted =
        XlaHelpers::ConvertElementType(in_image, accumulation_type);
    auto squared = xla::Mul(converted, converted);
    auto reduce = xla::ReduceWindow(
        squared, XlaHelpers::Zero(builder, accumulation_type),
        *ctx->GetOrCreateAdd(accumulation_type),
        /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1},
        /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame);
    auto sqr_sum = XlaHelpers::ConvertElementType(reduce, input_type(0));

    auto norm =
        xla::Add(xla::ConstantR0<float>(builder, bias_),
                 xla::Mul(xla::ConstantR0<float>(builder, alpha_), sqr_sum));

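    // dy[j] = -2 * alpha * beta * out_image[j] * out_grads[j] / norm[j].
    // The in_image[k] factor from the pseudo-code does not depend on j, so it
    // is applied after the window sum over j (see `gradients` below).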
    auto dy = xla::Mul(
        xla::Mul(xla::ConstantR0<float>(builder, -2.0f * alpha_ * beta_),
                 xla::Div(out_image, norm)),
        in_grads);

    auto converted_dy = XlaHelpers::ConvertElementType(dy, accumulation_type);
    auto dy_reduce = xla::ReduceWindow(
        converted_dy, XlaHelpers::Zero(builder, accumulation_type),
        *ctx->GetOrCreateAdd(accumulation_type),
        /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1},
        /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame);
    auto dy_reduced = XlaHelpers::ConvertElementType(dy_reduce, input_type(0));

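    // grads[k] = in_image[k] * sum_{j in window(k)} dy[j]
    //            + out_grads[k] * norm[k] ** -beta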
    xla::XlaOp gradients = xla::Add(
        xla::Mul(in_image, dy_reduced),
        xla::Mul(in_grads,
                 xla::Pow(norm, xla::ConstantR0<float>(builder, -beta_))));

    ctx->SetOutput(0, gradients);
  }

 private:
  int64 depth_radius_;
  float bias_;
  float alpha_;
  float beta_;
};

REGISTER_XLA_OP(Name("LRNGrad"), LRNGradOp);

}  // anonymous namespace
}  // namespace tensorflow