/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/contrib/seq2seq/kernels/beam_search_ops.h"

#include <memory>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

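// Kernel for the GatherTree op: reconstructs full beam trajectories from
// per-step token ids and parent beam indices.
//
// Inputs (validated in Compute below):
//   step_ids:             [max_time, batch_size, beam_width] token ids
//                         emitted at each decoding step.
//   parent_ids:           [max_time, batch_size, beam_width] index of the
//                         parent beam each token was expanded from.
//   max_sequence_lengths: [batch_size] number of valid time steps per batch
//                         entry.
//   end_token:            scalar id marking the end of a sequence.
//
// The output has the same shape as step_ids and contains, for each beam, the
// token sequence obtained by following parent_ids backwards from the last
// valid time step.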
template <typename Device, typename T>
class GatherTreeOp : public OpKernel {
 public:
  explicit GatherTreeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Device& device = ctx->eigen_device<Device>();
    const Tensor& step_ids = ctx->input(0);
    const Tensor& parent_ids = ctx->input(1);
    const Tensor& max_sequence_lengths = ctx->input(2);
    const Tensor& end_token = ctx->input(3);
    const TensorShape& step_ids_shape = step_ids.shape();
    OP_REQUIRES(
        ctx, step_ids_shape.dims() == 3,
        errors::InvalidArgument("step_ids must be a 3-tensor, saw shape: ",
                                step_ids_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(max_sequence_lengths.shape()),
                errors::InvalidArgument(
                    "max_sequence_lengths must be a vector, saw shape: ",
                    max_sequence_lengths.shape().DebugString()));
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(end_token.shape()),
        errors::InvalidArgument("end_token must be a scalar, saw shape: ",
                                end_token.shape().DebugString()));
    OP_REQUIRES(
        ctx, step_ids_shape == parent_ids.shape(),
        errors::InvalidArgument(
            "step_ids.shape must match parent_ids.shape, but shapes are: ",
            step_ids_shape.DebugString(), " and ",
            parent_ids.shape().DebugString()));
    OP_REQUIRES(
        ctx,
        step_ids_shape.dim_size(1) == max_sequence_lengths.shape().dim_size(0),
        errors::InvalidArgument("batch size dimensions step_ids.shape[1] and "
                                "max_sequence_lengths.shape[0] must match, "
                                "but shapes are: ",
                                step_ids_shape.DebugString(), " and ",
                                max_sequence_lengths.shape().DebugString()));
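    // Allocate the output; beams has the same shape as step_ids.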
    Tensor* beams;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, step_ids_shape, &beams));
    typename TTypes<T, 3>::ConstTensor step_ids_t(step_ids.tensor<T, 3>());
    typename TTypes<T, 3>::ConstTensor parent_ids_t(parent_ids.tensor<T, 3>());
    typename TTypes<int32>::ConstVec max_seq_lens_t =
        max_sequence_lengths.vec<int32>();
    typename TTypes<T>::ConstScalar end_token_t(end_token.scalar<T>());
    typename TTypes<T, 3>::Tensor beams_t(beams->tensor<T, 3>());
    const T end_token_value = end_token_t();
    functor::GatherTree<Device, T>()(ctx, device, step_ids_t, parent_ids_t,
                                     max_seq_lens_t, end_token_value, beams_t);
  }
};

#define REGISTER_KERNEL(T)                                          \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("GatherTree").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      GatherTreeOp<CPUDevice, T>);
REGISTER_KERNEL(int32);
#undef REGISTER_KERNEL

namespace functor {

// CPU specialization
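//
// Illustrative example (values chosen here for exposition, not taken from
// the op's documentation): with max_time = 3, batch_size = 1,
// beam_width = 2, and
//
//   step_ids(t, 0, :)   = {1, 2}, {3, 4}, {5, 6}   for t = 0, 1, 2
//   parent_ids(t, 0, :) = {0, 0}, {1, 0}, {0, 1}   for t = 0, 1, 2
//
// backtracking beam 0 from t = 2 visits parents 0 then 1, yielding the
// sequence [2, 3, 5]; beam 1 visits parents 1 then 0, yielding [1, 4, 6].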
template <>
struct GatherTree<CPUDevice, int32> {
  void operator()(OpKernelContext* ctx, const CPUDevice& d,
                  TTypes<int32, 3>::ConstTensor step_ids,
                  TTypes<int32, 3>::ConstTensor parent_ids,
                  TTypes<int32>::ConstVec max_sequence_lengths,
                  const int32 end_token, TTypes<int32, 3>::Tensor beams) {
    const int32 max_time = parent_ids.dimension(0);
    const int32 batch_size = parent_ids.dimension(1);
    const int32 beam_width = parent_ids.dimension(2);
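    // Initialize the whole output to end_token; entries beyond each
    // sequence's max_sequence_length keep this value.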
    beams.setConstant(end_token);

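    // Each work item reconstructs the trajectory of one flattened
    // (batch, beam) pair.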
    auto DoWork = [&, ctx, end_token](int start_batch_beam,
                                      int limit_batch_beam) {
      for (int32 i = start_batch_beam; i < limit_batch_beam; ++i) {
        const int32 batch = i / beam_width;
        const int32 beam = i % beam_width;
        const int32 max_seq_len_b =
            Eigen::numext::mini(max_time, max_sequence_lengths(batch));
        if (max_seq_len_b <= 0) {
          continue;
        }
        beams(max_seq_len_b - 1, batch, beam) =
            step_ids(max_seq_len_b - 1, batch, beam);
        int32 parent = parent_ids(max_seq_len_b - 1, batch, beam);
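        // Walk backwards through time: at each level, emit the token the
        // current parent beam produced, then hop to that parent's parent.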
        for (int32 level = max_seq_len_b - 2; level >= 0; --level) {
          if (parent < 0 || parent >= beam_width) {
            ctx->SetStatus(
                errors::InvalidArgument("Saw invalid parent id ", parent,
                                        " at (batch, time, beam) == (", batch,
                                        ", ", level, ", ", beam, ")"));
            return;
          }
          beams(level, batch, beam) = step_ids(level, batch, parent);
          parent = parent_ids(level, batch, parent);
        }
        // Not necessary when using a BeamSearchDecoder, but necessary
        // when a user feeds in a possibly broken trajectory (i.e., non-eos
        // entries in a beam following eos entries).
        bool finished = false;
        for (int32 time = 0; time < max_seq_len_b; ++time) {
          if (finished) {
            beams(time, batch, beam) = end_token;
          } else if (beams(time, batch, beam) == end_token) {
            finished = true;
          }
        }
      }
    };
    // Guesstimate of cost; ~5 lookup/store/compare per inner beam
    // traversal time step.
    const int64 batch_beam_cost =
        Eigen::TensorOpCost::DivCost<int32>() +
        6 * Eigen::TensorOpCost::AddCost<int32>() +
        2 * max_time * (5 * Eigen::TensorOpCost::AddCost<int32>());
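    // Shard splits the batch_size * beam_width work items across the
    // intra-op threadpool, using batch_beam_cost to size each shard.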
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers,
          batch_size * beam_width, batch_beam_cost, DoWork);
  }
};

}  // namespace functor

#if GOOGLE_CUDA
namespace functor {
#define DECLARE_GPU_SPEC(T)                                            \
  template <>                                                          \
  void GatherTree<GPUDevice, T>::operator()(                           \
      OpKernelContext* ctx, const GPUDevice& d,                        \
      typename TTypes<T, 3>::ConstTensor step_ids,                     \
      typename TTypes<T, 3>::ConstTensor parent_ids,                   \
      TTypes<int32>::ConstVec max_sequence_lengths, const T end_token, \
      typename TTypes<T, 3>::Tensor beams);                            \
  extern template struct GatherTree<GPUDevice, T>;

DECLARE_GPU_SPEC(int32);
#undef DECLARE_GPU_SPEC
}  // end namespace functor

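// end_token is pinned to host memory: Compute() reads its scalar value on
// the CPU before handing it to the device functor.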
#define REGISTER_GPU_KERNEL(T)                          \
  REGISTER_KERNEL_BUILDER(Name("GatherTree")            \
                              .Device(DEVICE_GPU)       \
                              .TypeConstraint<T>("T")   \
                              .HostMemory("end_token"), \
                          GatherTreeOp<GPUDevice, T>);

REGISTER_GPU_KERNEL(int32);
#undef REGISTER_GPU_KERNEL
#endif  // GOOGLE_CUDA

}  // end namespace tensorflow