1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <algorithm>
#include <limits>

#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/sorting.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/macros.h"
29 
30 namespace tensorflow {
31 namespace {
32 
33 class InTopKOp : public XlaOpKernel {
34  public:
InTopKOp(OpKernelConstruction * context)35   explicit InTopKOp(OpKernelConstruction* context) : XlaOpKernel(context) {
36     OP_REQUIRES_OK(context, context->GetAttr("T", &targets_dtype_));
37     OP_REQUIRES_OK(context,
38                    DataTypeToPrimitiveType(targets_dtype_, &targets_type_));
39   }
40 
Compile(XlaOpKernelContext * context)41   void Compile(XlaOpKernelContext* context) override {
42     int64 k;
43     OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &k));
44     OP_REQUIRES(context, k >= 0,
45                 errors::InvalidArgument("Need k >= 0, got ", k));
46     const TensorShape predictions_shape = context->InputShape(0);
47     OP_REQUIRES(
48         context, predictions_shape.dims() == 2,
49         errors::InvalidArgument("predictions must be == 2-D, got shape ",
50                                 predictions_shape.DebugString()));
51     const TensorShape targets_shape = context->InputShape(1);
52     OP_REQUIRES(context, targets_shape.dims() == 1,
53                 errors::InvalidArgument("targets must be == 1-D, got shape ",
54                                         targets_shape.DebugString()));
55 
56     int64 batch_size = predictions_shape.dim_size(0);
57     OP_REQUIRES(context, batch_size == targets_shape.dim_size(0),
58                 errors::InvalidArgument(
59                     "targets must have same elements as predictions rows. Had ",
60                     targets_shape.dim_size(0), ", needed ", batch_size));
61 
62     // Given `predictions` with shape batch_size*num_classes and `target` with
63     // shape num_classes, we generate `targets_values_r1` with shape num_classes
64     // which the elements are the corresponding values of `targets` in
65     // `predictions` for each example. This step can be done using xla::Gather
66     // as well.
67     xla::XlaOp predictions_r2 = context->Input(0);
68     xla::XlaOp targets_r1 = context->Input(1);
69 
70     xla::XlaBuilder* xla_builder = context->builder();
71     xla::XlaOp iota_r1 =
72         xla::Iota(xla_builder, targets_type_, predictions_shape.dim_size(1));
73     xla::XlaOp iota_r2 = xla::Broadcast(iota_r1, {batch_size});
74 
75     xla::XlaOp eq_r2 = xla::Eq(targets_r1, iota_r2, {0});
76     xla::XlaOp zero_r0_f32 = xla::Zero(xla_builder, xla::F32);
77     xla::XlaOp zero_r2_f32 = xla::ZerosLike(predictions_r2);
78     xla::XlaOp select_r2 = xla::Select(eq_r2, predictions_r2, zero_r2_f32);
79     xla::XlaOp targets_values_r1 = xla::Reduce(
80         select_r2, zero_r0_f32,
81         xla::CreateScalarAddComputation(xla::F32, xla_builder), {1});
82 
83     // Calculate in each row of `predictions`, how many values are larger than
84     // the value of target class. Then return the result whether the count < k,
85     // which indicates the target is in topk.
86     xla::XlaOp gt_r2 = xla::Gt(predictions_r2, targets_values_r1, {0});
87     xla::XlaOp zero_r0 = xla::Zero(xla_builder, xla::S32);
88     xla::XlaOp zero_r2 = xla::Broadcast(zero_r0, predictions_shape.dim_sizes());
89     xla::XlaOp one_r0 = xla::One(xla_builder, xla::S32);
90     xla::XlaOp one_r2 = xla::Broadcast(one_r0, predictions_shape.dim_sizes());
91     xla::XlaOp one_hot_r2 = xla::Select(gt_r2, one_r2, zero_r2);
92     xla::XlaOp num_gt_r1 = xla::Reduce(
93         one_hot_r2, zero_r0,
94         xla::CreateScalarAddComputation(xla::S32, xla_builder), {1});
95 
96     xla::XlaOp result =
97         xla::And(xla::Lt(num_gt_r1, xla::ConstantR0<int32>(xla_builder, k)),
98                  xla::IsFinite(targets_values_r1));
99 
100     context->SetOutput(0, result);
101   }
102 
103  protected:
104   DataType targets_dtype_;
105   xla::PrimitiveType targets_type_;
106 
107   TF_DISALLOW_COPY_AND_ASSIGN(InTopKOp);
108 };
109 
110 REGISTER_XLA_OP(Name("InTopKV2")
111                     .CompileTimeConstantInput("k")
112                     .TypeConstraint("T", {DT_INT32, DT_INT64}),
113                 InTopKOp);
114 
115 }  // namespace
116 }  // namespace tensorflow
117