/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"

namespace tensorflow {
namespace {

// Local response normalization
class LRNOp : public XlaOpKernel {
 public:
  explicit LRNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("depth_radius", &depth_radius_));

    // TODO(phawkins): handle non-float types for attributes.
    OP_REQUIRES_OK(ctx, ctx->GetAttr("bias", &bias_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &alpha_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("beta", &beta_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape in_shape = ctx->InputShape(0);
    OP_REQUIRES(ctx, in_shape.dims() == 4,
                errors::InvalidArgument("in must be 4-dimensional"));

    xla::XlaBuilder* builder = ctx->builder();
    xla::XlaOp input = ctx->Input(0);

    // sqr_sum[a, b, c, d] =
    //    sum(input[a, b, c, d - depth_radius : d + depth_radius + 1] ** 2)
    // output = input / (bias + alpha * sqr_sum) ** beta

    // We use a window of depth_radius_ * 2 + 1, to account for the current
    // element and a depth_radius_ on either side.
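    // The squared values are summed in the accumulation type rather than the
    // input type: for reduced-precision inputs (e.g. F16/BF16) the windowed
    // sum is accumulated in a wider float type and converted back afterwards,
    // which limits rounding error in the reduction.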
    auto accumulation_type = XlaHelpers::SumAccumulationType(input_type(0));
    auto converted = XlaHelpers::ConvertElementType(input, accumulation_type);
    auto squared = xla::Mul(converted, converted);
    auto reduce = xla::ReduceWindow(
        squared, XlaHelpers::Zero(builder, accumulation_type),
        *ctx->GetOrCreateAdd(accumulation_type),
        /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1},
        /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame);
    auto sqr_sum = XlaHelpers::ConvertElementType(reduce, input_type(0));

    auto scale = xla::Pow(
        xla::Add(xla::ConstantR0<float>(builder, bias_),
                 xla::Mul(xla::ConstantR0<float>(builder, alpha_), sqr_sum)),
        xla::ConstantR0<float>(builder, -beta_));

    ctx->SetOutput(0, xla::Mul(input, scale));
  }

 private:
  int64 depth_radius_;
  float bias_;
  float alpha_;
  float beta_;
};

REGISTER_XLA_OP(Name("LRN"), LRNOp);

class LRNGradOp : public XlaOpKernel {
 public:
  explicit LRNGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("depth_radius", &depth_radius_));

    // TODO(phawkins): handle non-float types for attributes.
    OP_REQUIRES_OK(ctx, ctx->GetAttr("bias", &bias_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &alpha_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("beta", &beta_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape in_grads_shape = ctx->InputShape(0);
    const TensorShape in_image_shape = ctx->InputShape(1);
    const TensorShape out_image_shape = ctx->InputShape(2);

    OP_REQUIRES(ctx, in_grads_shape.dims() == 4 && in_image_shape.dims() == 4,
                errors::InvalidArgument("inputs must be 4-dimensional"));
    const int64 batch = in_grads_shape.dim_size(0);
    const int64 rows = in_grads_shape.dim_size(1);
    const int64 cols = in_grads_shape.dim_size(2);
    const int64 depth = in_grads_shape.dim_size(3);
    OP_REQUIRES(
        ctx, in_image_shape.dim_size(0) == batch &&
                 in_image_shape.dim_size(1) == rows &&
                 in_image_shape.dim_size(2) == cols &&
                 in_image_shape.dim_size(3) == depth &&
                 out_image_shape.dim_size(0) == batch &&
                 out_image_shape.dim_size(1) == rows &&
                 out_image_shape.dim_size(2) == cols &&
                 out_image_shape.dim_size(3) == depth,
        errors::InvalidArgument(
            "input_grads, input_image, and out_image should have the same "
            "shape"));

    xla::XlaBuilder* builder = ctx->builder();
    xla::XlaOp in_grads = ctx->Input(0);
    xla::XlaOp in_image = ctx->Input(1);
    xla::XlaOp out_image = ctx->Input(2);

    // This code is ported from tensorflow/core/kernels/lrn_op.cc. In Python
    // pseudo-code, the Eigen code does this for each spatial position:
    //  grads = [0.0] * depth
    //  for j in range(depth):
    //    depth_begin = max(0, j - depth_radius)
    //    depth_end = min(depth, j + depth_radius + 1)
    //
    //    norm = 0
    //    for k in range(depth_begin, depth_end):
    //      norm += in_image[k] * in_image[k]
    //    norm = alpha * norm + bias
    //
    //    for k in range(depth_begin, depth_end):
    //      dyi = -2.0 * alpha * beta * in_image[k] * out_image[j] / norm
    //      if k == j:
    //        dyi += norm ** (-beta)
    //      dyi *= out_grads[j]
    //      grads[k] += dyi

    auto accumulation_type = XlaHelpers::SumAccumulationType(input_type(0));
    auto converted =
        XlaHelpers::ConvertElementType(in_image, accumulation_type);
    auto squared = xla::Mul(converted, converted);
    auto reduce = xla::ReduceWindow(
        squared, XlaHelpers::Zero(builder, accumulation_type),
        *ctx->GetOrCreateAdd(accumulation_type),
        /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1},
        /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame);
    auto sqr_sum = XlaHelpers::ConvertElementType(reduce, input_type(0));

    auto norm =
        xla::Add(xla::ConstantR0<float>(builder, bias_),
                 xla::Mul(xla::ConstantR0<float>(builder, alpha_), sqr_sum));

    auto dy = xla::Mul(
        xla::Mul(xla::ConstantR0<float>(builder, -2.0f * alpha_ * beta_),
                 xla::Div(out_image, norm)),
        in_grads);

    auto converted_dy = XlaHelpers::ConvertElementType(dy, accumulation_type);
    auto dy_reduce = xla::ReduceWindow(
        converted_dy, XlaHelpers::Zero(builder, accumulation_type),
        *ctx->GetOrCreateAdd(accumulation_type),
        /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1},
        /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame);
    auto dy_reduced = XlaHelpers::ConvertElementType(dy_reduce, input_type(0));

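    // Vectorizing the two loops above: every window element k (including
    // k == j) receives the cross term
    //   -2 * alpha * beta * in_image[k] * out_image[j] * out_grads[j] / norm[j],
    // and the depth window is symmetric, so summing dy over the same window
    // (dy_reduced) and multiplying by in_image produces all cross terms at
    // once. The k == j branch additionally contributes
    // norm ** -beta * out_grads[j], the second term below.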
    xla::XlaOp gradients = xla::Add(
        xla::Mul(in_image, dy_reduced),
        xla::Mul(in_grads,
                 xla::Pow(norm, xla::ConstantR0<float>(builder, -beta_))));

    ctx->SetOutput(0, gradients);
  }

 private:
  int64 depth_radius_;
  float bias_;
  float alpha_;
  float beta_;
};

REGISTER_XLA_OP(Name("LRNGrad"), LRNGradOp);

}  // anonymous namespace
}  // namespace tensorflow