1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
7     http://www.apache.org/licenses/LICENSE-2.0
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
20 #define EIGEN_USE_GPU
21 #endif  // GOOGLE_CUDA
23 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
24 #include "tensorflow/core/framework/op_kernel.h"
25 #include "tensorflow/core/framework/register_types.h"
26 #include "tensorflow/core/framework/tensor.h"
27 #include "tensorflow/core/framework/tensor_types.h"
28 #include "tensorflow/core/framework/variant.h"
29 #include "tensorflow/core/framework/variant_op_registry.h"
30 #include "tensorflow/core/kernels/concat_lib.h"
31 #include "tensorflow/core/kernels/fill_functor.h"
32 #include "tensorflow/core/lib/core/coding.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/lib/gtl/array_slice.h"
35 #include "tensorflow/core/util/tensor_ops_util.h"
36 #include "tensorflow/core/util/util.h"
38 namespace tensorflow {
40 typedef Eigen::ThreadPoolDevice CPUDevice;
42 // Variant compatible type for a list of tensors. This is mutable but instances
43 // should never be mutated after stored in a variant tensor.
44 struct TensorList {
45  public:
TensorListTensorList46   TensorList() {}
47   TensorList(const TensorList& other);
49   static const char kTypeName[];
TypeNameTensorList50   string TypeName() const { return kTypeName; }
52   void Encode(VariantTensorData* data) const;
54   bool Decode(const VariantTensorData& data);
56   // TODO(apassos) fill this out
DebugStringTensorList57   string DebugString() const { return "TensorList"; }
59   std::vector<Tensor> tensors;
60   PartialTensorShape element_shape;
61   DataType element_dtype;
62   // The maximum allowed size of `tensors`. Defaults to -1 meaning that the size
63   // of `tensors` is unbounded.
64   int max_num_elements = -1;
65 };
67 Status TensorShapeFromTensor(const Tensor& t, PartialTensorShape* out);
69 Status GetElementShapeFromInput(OpKernelContext* c,
70                                 const TensorList& tensor_list, int index,
71                                 PartialTensorShape* element_shape);
73 Status GetInputList(OpKernelContext* c, int index, const TensorList** list);
75 Status ForwardInputOrCreateNewList(OpKernelContext* c, int32 input_index,
76                                    int32 output_index,
77                                    const TensorList& input_list,
78                                    TensorList** output_list);
80 template <typename Device, typename T>
81 class TensorListStack : public OpKernel {
82  public:
83   typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
84       ConstMatrixVector;
TensorListStack(OpKernelConstruction * c)85   explicit TensorListStack(OpKernelConstruction* c) : OpKernel(c) {
86     OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
87     OP_REQUIRES_OK(c, c->GetAttr("num_elements", &num_elements_));
88   }
Compute(OpKernelContext * c)90   void Compute(OpKernelContext* c) override {
91     const TensorList* tensor_list = nullptr;
92     OP_REQUIRES_OK(c, GetInputList(c, 0, &tensor_list));
94         c, element_dtype_ == tensor_list->element_dtype,
95         errors::InvalidArgument(
96             "Invalid data types; op elements ", DataTypeString(element_dtype_),
97             " but list elements ", DataTypeString(tensor_list->element_dtype)));
98     if (num_elements_ != -1) {
99       OP_REQUIRES(c, tensor_list->tensors.size() == num_elements_,
100                   errors::InvalidArgument(
101                       "Operation expected a list with ", num_elements_,
102                       " elements but got a list with ",
103                       tensor_list->tensors.size(), " elements."));
104     }
105     PartialTensorShape partial_element_shape;
106     OP_REQUIRES_OK(c, GetElementShapeFromInput(c, *tensor_list, 1,
107                                                &partial_element_shape));
108     OP_REQUIRES(
109         c,
110         partial_element_shape.IsFullyDefined() || !tensor_list->tensors.empty(),
111         errors::InvalidArgument("Tried to stack elements of an empty ",
112                                 "list with non-fully-defined element_shape: ",
113                                 partial_element_shape.DebugString()));
115     // Check that `element_shape` input tensor is compatible with the shapes of
116     // element tensors.
117     if (!tensor_list->element_shape.IsFullyDefined()) {
118       for (int i = 0; i < tensor_list->tensors.size(); ++i) {
119         const Tensor& t = tensor_list->tensors[i];
120         if (t.dtype() != DT_INVALID) {
121           PartialTensorShape tmp = partial_element_shape;
122           OP_REQUIRES_OK(c, tmp.MergeWith(t.shape(), &partial_element_shape));
123         }
124       }
125     }
127     // Compute the shape of the output tensor by pre-pending the leading dim to
128     // the element_shape.
129     TensorShape element_shape;
130     OP_REQUIRES(c, partial_element_shape.AsTensorShape(&element_shape),
131                 errors::InvalidArgument(
132                     "Tried to stack list which only contains uninitialized ",
133                     "tensors and has a non-fully-defined element_shape: ",
134                     partial_element_shape.DebugString()));
135     TensorShape output_shape = element_shape;
136     output_shape.InsertDim(0, tensor_list->tensors.size());
137     Tensor* output;
138     OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
139     if (output->NumElements() == 0) {
140       return;
141     }
143     ConstMatrixVector inputs_flat;
144     inputs_flat.reserve(tensor_list->tensors.size());
145     Tensor zeros;
146     for (const auto& t : tensor_list->tensors) {
147       if (t.dtype() != DT_INVALID) {
148         inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
149             t.shaped<T, 2>({1, t.NumElements()})));
150       } else {
151         if (!zeros.NumElements()) {
152           AllocatorAttributes attr;
153           if (element_dtype_ == DT_VARIANT) {
154             attr.set_on_host(true);
155           }
156           OP_REQUIRES_OK(
157               c, c->allocate_temp(element_dtype_, element_shape, &zeros, attr));
158           functor::SetZeroFunctor<Device, T>()(c->eigen_device<Device>(),
159                                                zeros.flat<T>());
160         }
161         inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
162             const_cast<const Tensor&>(zeros).shaped<T, 2>(
163                 {1, zeros.NumElements()})));
164       }
165     }
166     auto output_flat = output->shaped<T, 2>({1, output->NumElements()});
169     if (std::is_same<Device, Eigen::GpuDevice>::value) {
170       ConcatGPU<T>(c, inputs_flat, output, &output_flat);
171       return;
172     }
173 #endif  // GOOGLE_CUDA
174     ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
175   }
177  private:
178   int num_elements_;
179   DataType element_dtype_;
180 };
182 template <typename Device, typename T>
183 class TensorListGetItem : public OpKernel {
184  public:
TensorListGetItem(OpKernelConstruction * c)185   explicit TensorListGetItem(OpKernelConstruction* c) : OpKernel(c) {
186     OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
187   }
Compute(OpKernelContext * c)189   void Compute(OpKernelContext* c) override {
190     const TensorList* l = nullptr;
191     OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
192     OP_REQUIRES(c, element_dtype_ == l->element_dtype,
193                 errors::InvalidArgument("Invalid data types; op elements ",
194                                         DataTypeString(element_dtype_),
195                                         " but list elements ",
196                                         DataTypeString(l->element_dtype)));
197     int32 index = c->input(1).scalar<int32>()();
198     OP_REQUIRES(c, index < l->tensors.size(),
199                 errors::InvalidArgument("Trying to access element ", index,
200                                         " in a list with ", l->tensors.size(),
201                                         " elements."));
202     if (l->tensors[index].dtype() != DT_INVALID) {
203       c->set_output(0, l->tensors[index]);
204     } else {
205       PartialTensorShape partial_element_shape;
206       OP_REQUIRES_OK(
207           c, GetElementShapeFromInput(c, *l, 2, &partial_element_shape));
208       TensorShape element_shape;
209       // If l->element_shape and the element_shape input are both not fully
210       // defined, try to infer the shape from other list elements. This requires
211       // that all initialized list elements have the same shape.
212       // NOTE(srbs): This might be a performance bottleneck since we are
213       // iterating over the entire list here. This is necessary for feature
214       // parity with TensorArray.read. TensorArray has a mode in which all
215       // elements are required to be of the same shape, TensorList does not.
216       // In that mode TensorArray sets the array's element_shape on the first
217       // write call. We could do something similar here if needed.
218       if (!partial_element_shape.IsFullyDefined()) {
219         for (const Tensor& t : l->tensors) {
220           if (t.dtype() != DT_INVALID) {
221             PartialTensorShape tmp = partial_element_shape;
222             OP_REQUIRES_OK(c, tmp.MergeWith(t.shape(), &partial_element_shape));
223           }
224         }
225       }
226       OP_REQUIRES(
227           c, partial_element_shape.AsTensorShape(&element_shape),
228           errors::InvalidArgument("Trying to read an uninitialized tensor but ",
229                                   "element_shape is not fully defined: ",
230                                   partial_element_shape.DebugString(),
231                                   " and no list element is set."));
232       Tensor* result;
233       AllocatorAttributes attr;
234       if (element_dtype_ == DT_VARIANT) {
235         attr.set_on_host(true);
236       }
237       OP_REQUIRES_OK(c, c->allocate_output(0, element_shape, &result, attr));
238       functor::SetZeroFunctor<Device, T>()(c->eigen_device<Device>(),
239                                            result->flat<T>());
240     }
241   }
243  private:
244   DataType element_dtype_;
245 };
247 template <typename Device, typename T>
248 class TensorListPopBack : public OpKernel {
249  public:
TensorListPopBack(OpKernelConstruction * c)250   explicit TensorListPopBack(OpKernelConstruction* c) : OpKernel(c) {
251     OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
252   }
Compute(OpKernelContext * c)254   void Compute(OpKernelContext* c) override {
255     const TensorList* l = nullptr;
256     OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
257     OP_REQUIRES(c, element_dtype_ == l->element_dtype,
258                 errors::InvalidArgument("Invalid data types; op elements ",
259                                         DataTypeString(element_dtype_),
260                                         " but list elements ",
261                                         DataTypeString(l->element_dtype)));
263     OP_REQUIRES(c, !l->tensors.empty(),
264                 errors::InvalidArgument("Trying to pop from an empty list."));
266     const Tensor& t = l->tensors.back();
267     if (t.dtype() != DT_INVALID) {
268       c->set_output(1, t);
269     } else {
270       PartialTensorShape partial_element_shape;
271       OP_REQUIRES_OK(
272           c, GetElementShapeFromInput(c, *l, 1, &partial_element_shape));
273       TensorShape element_shape;
274       OP_REQUIRES(
275           c, partial_element_shape.AsTensorShape(&element_shape),
276           errors::InvalidArgument("Trying to read an uninitialized tensor but ",
277                                   "element_shape is not fully defined.",
278                                   partial_element_shape.DebugString()));
279       Tensor* result;
280       AllocatorAttributes attr;
281       if (element_dtype_ == DT_VARIANT) {
282         attr.set_on_host(true);
283       }
284       OP_REQUIRES_OK(c, c->allocate_output(1, element_shape, &result, attr));
285       functor::SetZeroFunctor<Device, T>()(c->eigen_device<Device>(),
286                                            result->flat<T>());
287     }
289     TensorList* output_list = nullptr;
290     OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
291     output_list->tensors.pop_back();
292   }
294  private:
295   DataType element_dtype_;
296 };
298 template <typename Device, typename T>
299 class TensorListConcat : public OpKernel {
300  public:
301   using ConstMatrixVector =
302       std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>;
TensorListConcat(OpKernelConstruction * c)303   explicit TensorListConcat(OpKernelConstruction* c) : OpKernel(c) {
304     OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
305     // TODO(skyewm): the HasAttr check can be removed once the
306     // element_shape_except_first_dim attr has been checked in for 2 weeks
307     // (around 1/14/2019).
308     if (c->HasAttr("element_shape")) {
309       PartialTensorShape element_shape;
310       OP_REQUIRES_OK(c, c->GetAttr("element_shape", &element_shape));
311       if (!element_shape.unknown_rank()) {
312         element_shape_except_first_dim_ = PartialTensorShape(
313             gtl::ArraySlice<int64>(element_shape.dim_sizes()).subspan(1));
314       }
315     }
316   }
Compute(OpKernelContext * c)318   void Compute(OpKernelContext* c) override {
319     // Check that the input Variant tensor is indeed a TensorList and has the
320     // correct element type.
321     const TensorList* tensor_list = nullptr;
322     OP_REQUIRES_OK(c, GetInputList(c, 0, &tensor_list));
323     OP_REQUIRES(
324         c, element_dtype_ == tensor_list->element_dtype,
325         errors::InvalidArgument(
326             "Invalid data types; op elements ", DataTypeString(element_dtype_),
327             " but list elements ", DataTypeString(tensor_list->element_dtype)));
328     // The leading dimension of all list elements if they are all the same.
329     // This is used as the leading dim of uninitialized tensors in the list
330     // if leading_dims is not provided.
331     int64 first_dim = -1;
332     if (c->num_inputs() > 1) {
333       // TensorListConcatV2
334       PartialTensorShape element_shape;
335       OP_REQUIRES_OK(
336           c, GetElementShapeFromInput(c, *tensor_list, 1, &element_shape));
337       OP_REQUIRES(c, element_shape.unknown_rank() || element_shape.dims() >= 1,
338                   errors::InvalidArgument(
339                       "Concat requires elements to be at least vectors, ",
340                       "found scalars instead."));
341       // Split `element_shape` into `first_dim` and
342       // `element_shape_except_first_dim_`.
343       first_dim = element_shape.dim_size(0);
344       element_shape_except_first_dim_ = element_shape;
345       element_shape_except_first_dim_.RemoveDim(0);
346     }
347     // If the TensorList is empty, element_shape_except_first_dim_ must be fully
348     // defined.
349     OP_REQUIRES(c,
350                 !tensor_list->tensors.empty() ||
351                     element_shape_except_first_dim_.IsFullyDefined(),
352                 errors::InvalidArgument(
353                     "All except the first dimension must be fully defined ",
354                     "when concating an empty tensor list. element_shape: ",
355                     element_shape_except_first_dim_.DebugString()));
356     // 1. Check that `element_shape_except_first_dim_` input tensor is
357     //    compatible with the shapes of element tensors.
358     // 2. Check that the elements have the same shape except the first dim.
359     // 3. If `first_dim` is known, check that it is compatible with the leading
360     //    dims of all elements.
361     // 4. If `first_dim` is unknown (-1), check whether all initialized
362     //    elements have the same leading dim and if so set `first_dim` to that
363     //    value.
364     if (!tensor_list->element_shape.IsFullyDefined()) {
365       bool check_dim = (first_dim == -1);
366       int64 inferred_first_dim = first_dim;
367       for (int i = 0; i < tensor_list->tensors.size(); ++i) {
368         const Tensor& t = tensor_list->tensors[i];
369         if (t.dtype() != DT_INVALID) {
370           PartialTensorShape tmp = element_shape_except_first_dim_;
371           OP_REQUIRES(
372               c, TensorShapeUtils::IsVectorOrHigher(t.shape()),
373               errors::InvalidArgument("Concat saw a scalar shape at index ", i,
374                                       " but requires at least vectors."));
375           TensorShape shape_except_first_dim = TensorShape(
376               gtl::ArraySlice<int64>(t.shape().dim_sizes()).subspan(1));
377           OP_REQUIRES_OK(c, tmp.MergeWith(shape_except_first_dim,
378                                           &element_shape_except_first_dim_));
379           OP_REQUIRES(c, first_dim == -1 || first_dim == t.shape().dim_size(0),
380                       errors::InvalidArgument(
381                           "First entry of element_shape input does not match ",
382                           "the first dim of list element at index: ", i,
383                           " Expected: ", first_dim,
384                           " Actual: ", t.shape().dim_size(0)));
385           if (check_dim) {
386             if (inferred_first_dim == -1) {
387               inferred_first_dim = t.shape().dim_size(0);
388             } else if (inferred_first_dim != t.shape().dim_size(0)) {
389               inferred_first_dim = -1;
390               check_dim = false;
391             }
392           }
393         }
394       }
395       first_dim = inferred_first_dim;
396     }
397     TensorShape output_shape;
398     OP_REQUIRES(
399         c, element_shape_except_first_dim_.AsTensorShape(&output_shape),
400         errors::InvalidArgument(
401             "Trying to concat list with only uninitialized tensors ",
402             "but element_shape_except_first_dim_ is not fully defined: ",
403             element_shape_except_first_dim_.DebugString()));
404     // Build the lengths_tensor and leading dim of the output tensor by
405     // iterating over all element tensors.
406     Tensor* lengths_tensor = nullptr;
408         c,
409         c->allocate_output(
410             1, TensorShape({static_cast<int64>(tensor_list->tensors.size())}),
411             &lengths_tensor));
412     auto lengths_tensor_vec = lengths_tensor->vec<int64>();
413     int64 leading_dim = 0;
414     for (size_t i = 0; i < tensor_list->tensors.size(); i++) {
415       int64 dim;
416       if (tensor_list->tensors[i].dtype() != DT_INVALID) {
417         dim = tensor_list->tensors[i].shape().dim_size(0);
418       } else {
419         // If leading_dims is not provided or does not contain an entry for
420         // index i use the inferred `first_dim` if set.
421         if ((c->num_inputs() <= 2 || i >= c->input(2).NumElements()) &&
422             first_dim != -1) {
423           dim = first_dim;
424         } else {
425           OP_REQUIRES(c, c->num_inputs() > 2,
426                       errors::InvalidArgument(
427                           "Concating lists with uninitialized tensors is not ",
428                           "supported in this version of TensorListConcat. ",
429                           "Consider updating your GraphDef to run the newer ",
430                           "version."));
431           OP_REQUIRES(c, i < c->input(2).NumElements(),
432                       errors::InvalidArgument(
433                           "List contains uninitialized tensor at index ", i,
434                           " but leading_dims has only ",
435                           c->input(2).NumElements(), " elements."));
436           dim = c->input(2).vec<int64>()(i);
437         }
438       }
439       leading_dim += dim;
440       lengths_tensor_vec(i) = dim;
441     }
442     output_shape.InsertDim(0, leading_dim);
443     Tensor* output;
444     // Allocate the output tensor and fill it up with the concated element
445     // tensors.
446     OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
447     if (output->NumElements() == 0) {
448       return;
449     }
451     ConstMatrixVector inputs_flat;
452     inputs_flat.reserve(tensor_list->tensors.size());
453     // Store the zeros tensors in a vector to prevent them from being GC'ed till
454     // concat is complete.
455     std::vector<Tensor> zeros_vec;
456     for (int i = 0; i < tensor_list->tensors.size(); i++) {
457       const Tensor& element_tensor = tensor_list->tensors[i];
458       if (element_tensor.dtype() != DT_INVALID) {
459         inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
460             element_tensor.shaped<T, 2>({1, element_tensor.NumElements()})));
461       } else {
462         AllocatorAttributes attr;
463         if (element_dtype_ == DT_VARIANT) {
464           attr.set_on_host(true);
465         }
466         TensorShape element_shape = output_shape;
467         element_shape.set_dim(0, lengths_tensor_vec(i));
468         zeros_vec.emplace_back();
469         Tensor& zeros = zeros_vec.back();
470         OP_REQUIRES_OK(
471             c, c->allocate_temp(element_dtype_, element_shape, &zeros, attr));
472         functor::SetZeroFunctor<Device, T>()(c->eigen_device<Device>(),
473                                              zeros.flat<T>());
474         inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
475             const_cast<const Tensor&>(zeros).shaped<T, 2>(
476                 {1, zeros.NumElements()})));
477       }
478     }
479     auto output_flat = output->shaped<T, 2>({1, output->NumElements()});
482     if (std::is_same<Device, Eigen::GpuDevice>::value) {
483       ConcatGPU<T>(c, inputs_flat, output, &output_flat);
484       return;
485     }
486 #endif  // GOOGLE_CUDA
487     ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
488   }
490  private:
491   DataType element_dtype_;
492   PartialTensorShape element_shape_except_first_dim_;
493 };
495 template <typename Device, typename T>
496 class TensorListSplit : public OpKernel {
497  public:
TensorListSplit(OpKernelConstruction * c)498   TensorListSplit(OpKernelConstruction* c) : OpKernel(c) {}
Compute(OpKernelContext * c)500   void Compute(OpKernelContext* c) override {
501     Tensor* output_tensor;
502     AllocatorAttributes attr;
503     attr.set_on_host(true);
504     OP_REQUIRES_OK(c, c->allocate_output(0, {}, &output_tensor, attr));
505     PartialTensorShape element_shape;
506     OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(1), &element_shape));
507     OP_REQUIRES(c, element_shape.unknown_rank() || element_shape.dims() >= 1,
508                 errors::InvalidArgument(
509                     "TensorListSplit requires element_shape to be at least of ",
510                     "rank 1, but saw: ", element_shape.DebugString()));
511     TensorList output_list;
512     const Tensor& input_tensor = c->input(0);
513     output_list.element_dtype = input_tensor.dtype();
514     OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input_tensor.shape()),
515                 errors::InvalidArgument(
516                     "Tensor must be at least a vector, but saw shape: ",
517                     input_tensor.shape().DebugString()));
518     TensorShape tensor_shape_without_first_dim(input_tensor.shape());
519     tensor_shape_without_first_dim.RemoveDim(0);
520     PartialTensorShape element_shape_without_first_dim;
521     if (!element_shape.unknown_rank()) {
522       element_shape_without_first_dim =
523           PartialTensorShape(element_shape.dim_sizes());
524       element_shape_without_first_dim.RemoveDim(0);
525     }
526     OP_REQUIRES(c,
527                 element_shape_without_first_dim.IsCompatibleWith(
528                     tensor_shape_without_first_dim),
529                 errors::InvalidArgument(
530                     "tensor shape ", input_tensor.shape().DebugString(),
531                     " is not compatible with element_shape ",
532                     element_shape.DebugString()));
533     output_list.element_shape = element_shape;
534     const Tensor& lengths = c->input(2);
535     OP_REQUIRES(c, TensorShapeUtils::IsVector(lengths.shape()),
536                 errors::InvalidArgument(
537                     "Expected lengths to be a vector, received shape: ",
538                     lengths.shape().DebugString()));
539     output_list.tensors.reserve(lengths.shape().dim_size(0));
540     int64 start = 0;
541     int64 end = 0;
542     for (int i = 0; i < lengths.shape().dim_size(0); ++i) {
543       int64 length = lengths.vec<int64>()(i);
544       OP_REQUIRES(
545           c, length >= 0,
546           errors::InvalidArgument("Invalid value in lengths: ", length));
547       end = start + length;
548       OP_REQUIRES(c, end <= input_tensor.shape().dim_size(0),
549                   errors::InvalidArgument("Attempting to slice [", start, ", ",
550                                           end, "] from tensor with length ",
551                                           input_tensor.shape().dim_size(0)));
552       Tensor tmp = input_tensor.Slice(start, end);
553       start = end;
554       // TODO(apassos) maybe not always align; but weird compiler bugs seem to
555       // prevent this.
556       Tensor aligned;
557       OP_REQUIRES_OK(c, c->allocate_temp(tmp.dtype(), tmp.shape(), &aligned));
558       aligned.flat<T>().device(c->eigen_device<Device>()) =
559           tmp.unaligned_flat<T>();
560       output_list.tensors.emplace_back(aligned);
561     }
562     OP_REQUIRES(c, end == input_tensor.shape().dim_size(0),
563                 errors::InvalidArgument(
564                     "Unused values in tensor. Length of tensor: ",
565                     input_tensor.shape().dim_size(0), " Values used: ", end));
566     output_tensor->scalar<Variant>()() = std::move(output_list);
567   }
568 };
570 template <typename Device, typename T>
571 class TensorListGather : public OpKernel {
572  public:
573   typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
574       ConstMatrixVector;
TensorListGather(OpKernelConstruction * c)575   explicit TensorListGather(OpKernelConstruction* c) : OpKernel(c) {
576     OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
577   }
Compute(OpKernelContext * c)579   void Compute(OpKernelContext* c) override {
580     const TensorList* tensor_list = nullptr;
581     OP_REQUIRES_OK(c, GetInputList(c, 0, &tensor_list));
582     OP_REQUIRES(
583         c, element_dtype_ == tensor_list->element_dtype,
584         errors::InvalidArgument(
585             "Invalid data types; op elements ", DataTypeString(element_dtype_),
586             " but list elements ", DataTypeString(tensor_list->element_dtype)));
587     const Tensor& indices = c->input(1);
588     PartialTensorShape partial_element_shape;
589     OP_REQUIRES_OK(c, GetElementShapeFromInput(c, *tensor_list, 2,
590                                                &partial_element_shape));
591     OP_REQUIRES(
592         c, partial_element_shape.IsFullyDefined() || indices.NumElements() > 0,
593         errors::InvalidArgument("Tried to gather 0-elements from "
594                                 "a list with non-fully-defined shape: ",
595                                 partial_element_shape.DebugString()));
597     // Check that `element_shape` input tensor is compatible with the shapes of
598     // element tensors.
599     if (!tensor_list->element_shape.IsFullyDefined()) {
600       for (int index = 0; index < indices.NumElements(); ++index) {
601         const int i = indices.flat<int32>()(index);
602         const Tensor& t = tensor_list->tensors[i];
603         if (t.dtype() != DT_INVALID) {
604           PartialTensorShape tmp = partial_element_shape;
605           OP_REQUIRES_OK(c, tmp.MergeWith(t.shape(), &partial_element_shape));
606         }
607       }
608     }
610     // Compute the shape of the output tensor by pre-pending the leading dim to
611     // the element_shape.
612     TensorShape element_shape;
613     OP_REQUIRES(
614         c, partial_element_shape.AsTensorShape(&element_shape),
615         errors::InvalidArgument("Tried to gather uninitialized tensors from a ",
616                                 "list with non-fully-defined element_shape: ",
617                                 partial_element_shape.DebugString()));
618     TensorShape output_shape = element_shape;
619     output_shape.InsertDim(0, indices.NumElements());
620     Tensor* output;
621     OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
622     if (output->NumElements() == 0) {
623       return;
624     }
626     ConstMatrixVector inputs_flat;
627     inputs_flat.reserve(indices.NumElements());
628     Tensor zeros;
629     for (int index = 0; index < indices.NumElements(); ++index) {
630       const int i = indices.flat<int32>()(index);
631       OP_REQUIRES(
632           c, i < tensor_list->tensors.size(),
633           errors::InvalidArgument("Index ", i, " out o range; list only has ",
634                                   tensor_list->tensors.size(), " elements."));
635       const Tensor& t = tensor_list->tensors[i];
636       if (t.dtype() != DT_INVALID) {
637         inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
638             t.shaped<T, 2>({1, t.NumElements()})));
639       } else {
640         if (!zeros.NumElements()) {
641           AllocatorAttributes attr;
642           if (element_dtype_ == DT_VARIANT) {
643             attr.set_on_host(true);
644           }
645           OP_REQUIRES_OK(
646               c, c->allocate_temp(element_dtype_, element_shape, &zeros, attr));
647           functor::SetZeroFunctor<Device, T>()(c->eigen_device<Device>(),
648                                                zeros.flat<T>());
649         }
650         inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
651             const_cast<const Tensor&>(zeros).shaped<T, 2>(
652                 {1, zeros.NumElements()})));
653       }
654     }
655     auto output_flat = output->shaped<T, 2>({1, output->NumElements()});
658     if (std::is_same<Device, Eigen::GpuDevice>::value) {
659       ConcatGPU<T>(c, inputs_flat, output, &output_flat);
660       return;
661     }
662 #endif  // GOOGLE_CUDA
663     ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
664   }
666  private:
667   DataType element_dtype_;
668 };
670 template <typename Device, typename T>
671 class TensorListFromTensor : public OpKernel {
672  public:
TensorListFromTensor(OpKernelConstruction * c)673   TensorListFromTensor(OpKernelConstruction* c) : OpKernel(c) {}
Compute(OpKernelContext * c)675   void Compute(OpKernelContext* c) override {
676     Tensor* output_tensor;
677     AllocatorAttributes attr;
678     attr.set_on_host(true);
679     OP_REQUIRES_OK(c, c->allocate_output(0, {}, &output_tensor, attr));
680     PartialTensorShape element_shape;
681     OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(1), &element_shape));
682     TensorList output_list;
683     const Tensor& t = c->input(0);
684     output_list.element_dtype = t.dtype();
685     OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(t.shape()),
686                 errors::InvalidArgument(
687                     "Tensor must be at least a vector, but saw shape: ",
688                     t.shape().DebugString()));
689     TensorShape output_shape(t.shape());
690     output_shape.RemoveDim(0);
691     OP_REQUIRES(c, element_shape.IsCompatibleWith(output_shape),
692                 errors::InvalidArgument(
693                     "Specified a list with shape ", element_shape.DebugString(),
694                     " from a tensor with shape ", output_shape.DebugString()));
695     output_list.element_shape = element_shape;
696     output_list.tensors.reserve(t.shape().dim_size(0));
697     for (int i = 0; i < t.shape().dim_size(0); ++i) {
698       Tensor tmp = t.Slice(i, i + 1);
699       TensorShape tmp_shape = tmp.shape();
700       tmp_shape.RemoveDim(0);
701       OP_REQUIRES(c, tmp.CopyFrom(tmp, tmp_shape),
702                   errors::Unknown("Unexpected shape error."));
703       // TODO(apassos) maybe not always align; but weird compiler bugs seem to
704       // prevent this.
705       Tensor aligned;
706       OP_REQUIRES_OK(c, c->allocate_temp(tmp.dtype(), tmp.shape(), &aligned));
707       aligned.flat<T>().device(c->eigen_device<Device>()) =
708           tmp.unaligned_flat<T>();
709       output_list.tensors.push_back(aligned);
710     }
711     output_tensor->scalar<Variant>()() = std::move(output_list);
712   }
713 };
715 // Scatters values in `value` into `list`. Assumes that `indices` are valid.
716 template <typename Device, typename T>
Scatter(OpKernelContext * c,const Tensor & value,const Tensor & indices,TensorList * list)717 Status Scatter(OpKernelContext* c, const Tensor& value, const Tensor& indices,
718                TensorList* list) {
719   for (int index = 0; index < indices.NumElements(); ++index) {
720     const int i = indices.flat<int32>()(index);
721     Tensor tmp = value.Slice(index, index + 1);
722     TensorShape tmp_shape = tmp.shape();
723     tmp_shape.RemoveDim(0);
724     if (!tmp.CopyFrom(tmp, tmp_shape)) {
725       return errors::Unknown("Unexpected shape error.");
726     }
727     // TODO(apassos) maybe not always align; but weird compiler bugs seem to
728     // prevent this.
729     Tensor aligned;
730     TF_RETURN_IF_ERROR(c->allocate_temp(tmp.dtype(), tmp.shape(), &aligned));
731     // TODO(apassos) do all slices in a single kernel invocation instead of
732     // many small ones.
733     aligned.flat<T>().device(c->eigen_device<Device>()) =
734         tmp.unaligned_flat<T>();
735     std::swap(list->tensors[i], aligned);
736   }
737   return Status::OK();
738 }
740 template <typename Device, typename T>
741 class TensorListScatterIntoExistingList : public OpKernel {
742  public:
TensorListScatterIntoExistingList(OpKernelConstruction * c)743   TensorListScatterIntoExistingList(OpKernelConstruction* c) : OpKernel(c) {}
Compute(OpKernelContext * c)745   void Compute(OpKernelContext* c) override {
746     const TensorList* l = nullptr;
747     OP_REQUIRES_OK(c, GetInputList(c, 0, &l));
748     const Tensor& input_tensor = c->input(1);
749     const Tensor& indices = c->input(2);
751     // Check that inputs are valid.
752     OP_REQUIRES(c, input_tensor.dtype() == l->element_dtype,
753                 errors::InvalidArgument(
754                     "Invalid data types; input tensor type: ",
755                     DataTypeString(input_tensor.dtype()),
756                     " list element_type: ", DataTypeString(l->element_dtype)));
757     OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input_tensor.shape()),
758                 errors::InvalidArgument(
759                     "Tensor must be at least a vector, but saw shape: ",
760                     input_tensor.shape().DebugString()));
761     OP_REQUIRES(c, TensorShapeUtils::IsVector(indices.shape()),
762                 errors::InvalidArgument(
763                     "Expected indices to be a vector, but received shape: ",
764                     indices.shape().DebugString()));
765     OP_REQUIRES(
766         c, indices.NumElements() == input_tensor.shape().dim_size(0),
767         errors::InvalidArgument(
768             "Expected len(indices) == tensor.shape[0], but saw: ",
769             indices.NumElements(), " vs. ", input_tensor.shape().dim_size(0)));
771     // Resize the list if needed to accommodate all indices.
772     TensorList* output_list = nullptr;
773     OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list));
774     const auto indices_vec = indices.vec<int32>();
775     int32 max_index =
776         (indices.NumElements() == 0)
777             ? -1
778             : *std::max_element(indices_vec.data(),
779                                 indices_vec.data() + indices.NumElements());
780     if (max_index + 1 > output_list->tensors.size()) {
781       output_list->tensors.resize(max_index + 1);
782     }
784     // Scatter the values.
785     OP_REQUIRES_OK(c,
786                    Scatter<Device, T>(c, input_tensor, indices, output_list));
787   }
788 };
790 template <typename Device, typename T>
791 class TensorListScatter : public OpKernel {
792  public:
TensorListScatter(OpKernelConstruction * c)793   TensorListScatter(OpKernelConstruction* c) : OpKernel(c) {}
Compute(OpKernelContext * c)795   void Compute(OpKernelContext* c) override {
796     Tensor* output_tensor;
797     AllocatorAttributes attr;
798     attr.set_on_host(true);
799     OP_REQUIRES_OK(c, c->allocate_output(0, {}, &output_tensor, attr));
800     Tensor indices = c->input(1);
801     PartialTensorShape element_shape;
802     OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(2), &element_shape));
803     // TensorListScatterV2 passes the num_elements input, TensorListScatter does
804     // not.
805     int num_elements = c->num_inputs() >= 4 ? c->input(3).scalar<int>()() : -1;
806     OP_REQUIRES(c, num_elements >= -1,
807                 errors::InvalidArgument(
808                     "TensorListScatter expects num_elements >= -1, found: ",
809                     num_elements));
810     TensorList output_list;
811     const Tensor& input_tensor = c->input(0);
812     output_list.element_dtype = input_tensor.dtype();
813     OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input_tensor.shape()),
814                 errors::InvalidArgument(
815                     "Tensor must be at least a vector, but saw shape: ",
816                     input_tensor.shape().DebugString()));
817     TensorShape output_shape(input_tensor.shape());
818     output_shape.RemoveDim(0);
819     OP_REQUIRES(c, element_shape.IsCompatibleWith(output_shape),
820                 errors::InvalidArgument(
821                     "Specified a list with shape ", element_shape.DebugString(),
822                     " from a tensor with shape ", output_shape.DebugString()));
823     output_list.element_shape = element_shape;
825     OP_REQUIRES(c, indices.NumElements() == input_tensor.shape().dim_size(0),
826                 errors::InvalidArgument(
827                     "Invalid number of rows in input tensor. Expected: ",
828                     indices.NumElements(),
829                     " Actual: ", input_tensor.shape().dim_size(0)));
831     // Validate indices and resize output_list.tensors to fit the highest index.
832     {
833       int highest_index = -1;
834       for (int index = 0; index < indices.NumElements(); ++index) {
835         const int i = indices.flat<int32>()(index);
836         OP_REQUIRES(
837             c, i >= 0,
838             errors::InvalidArgument(
839                 "Indices in TensorListScatter must all be non-negative."));
840         OP_REQUIRES(c, num_elements == -1 || i < num_elements,
841                     errors::InvalidArgument(
842                         "TensorListScatter: Trying to scatter at index ", i,
843                         " in list with size ", num_elements));
844         if (i > highest_index) {
845           highest_index = i;
846         }
847       }
848       output_list.tensors.resize(std::max(highest_index + 1, num_elements),
849                                  Tensor(DT_INVALID));
850     }
852     OP_REQUIRES_OK(c,
853                    Scatter<Device, T>(c, input_tensor, indices, &output_list));
854     output_tensor->scalar<Variant>()() = std::move(output_list);
855   }
856 };
858 template <typename Device>
TensorListBinaryAdd(OpKernelContext * c,const TensorList & a,const TensorList & b,TensorList * out)859 Status TensorListBinaryAdd(OpKernelContext* c, const TensorList& a,
860                            const TensorList& b, TensorList* out) {
861   if (a.element_dtype != b.element_dtype) {
862     return errors::InvalidArgument(
863         "Trying to add two lists of tensors of different dtypes. One is ",
864         DataTypeString(a.element_dtype), " and the other is ",
865         DataTypeString(b.element_dtype));
866   }
867   out->element_dtype = a.element_dtype;
868   if (!a.element_shape.IsCompatibleWith(b.element_shape)) {
869     return errors::InvalidArgument(
870         "Trying to add two lists of tensors with incompatible element shapes. "
871         "One is ",
872         a.element_shape.DebugString(), " and the other is ",
873         b.element_shape.DebugString());
874   }
877       a.element_shape.MergeWith(b.element_shape, &out->element_shape));
878   if (a.tensors.size() != b.tensors.size()) {
879     return errors::InvalidArgument(
880         "Trying to add two lists of tensors with different lengths. One is ",
881         a.tensors.size(), " and the other is ", b.tensors.size());
882   }
883   out->tensors.reserve(a.tensors.size());
884   for (int i = 0; i < a.tensors.size(); ++i) {
885     const Tensor& a_tensor = a.tensors[i];
886     const Tensor& b_tensor = b.tensors[i];
887     Tensor out_tensor;
889         BinaryAddTensors<Device>(c, a_tensor, b_tensor, &out_tensor));
890     out->tensors.push_back(out_tensor);
891   }
892   return Status::OK();
893 }
895 template <typename Device>
TensorListZerosLike(OpKernelContext * c,const TensorList & x,TensorList * y)896 Status TensorListZerosLike(OpKernelContext* c, const TensorList& x,
897                            TensorList* y) {
898   y->element_dtype = x.element_dtype;
899   y->element_shape = x.element_shape;
900   y->tensors.reserve(x.tensors.size());
901   for (const Tensor& t : x.tensors) {
902     Tensor out_tensor;
903     TF_RETURN_IF_ERROR(ZerosLikeTensor<Device>(c, t, &out_tensor));
904     y->tensors.emplace_back(out_tensor);
905   }
906   return Status::OK();
907 }
909 template <typename Device, typename T>
910 class TensorListPushBackBatch : public OpKernel {
911  public:
TensorListPushBackBatch(OpKernelConstruction * c)912   explicit TensorListPushBackBatch(OpKernelConstruction* c) : OpKernel(c) {
913     OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
914   }
Compute(OpKernelContext * c)916   void Compute(OpKernelContext* c) override {
917     const Tensor& input = c->input(1);
918     OP_REQUIRES(c, element_dtype_ == input.dtype(),
919                 errors::InvalidArgument("Invalid data types; list elements ",
920                                         DataTypeString(element_dtype_),
921                                         " but tried to append ",
922                                         DataTypeString(input.dtype())));
923     OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input.shape()),
924                 errors::InvalidArgument(
925                     "Expected tensor to be at least a vector, but saw shape: ",
926                     input.shape().DebugString()));
928     const TensorShape& tls_shape = c->input(0).shape();
930     // For purposes of input forwarding, we want the least restrictive
931     // AllocatorAttributes possible.  If we need to allocate later,
932     // we'll request the DT_VARIANT be allocated on host.
933     AllocatorAttributes attr;
935     std::unique_ptr<Tensor> tls_alias = c->forward_input(
936         0 /*input_index*/, 0 /*output_index*/, DT_VARIANT, tls_shape,
937         DEVICE_MEMORY /* input is always on DEVICE_MEMORY */, attr);
939     const Tensor& tls = tls_alias ? *tls_alias : c->input(0);
941     OP_REQUIRES(c, tls.dtype() == DT_VARIANT,
942                 errors::InvalidArgument(
943                     "Expected input_handles dtype to be Variant, but saw: ",
944                     DataTypeString(tls.dtype())));
945     OP_REQUIRES(c, TensorShapeUtils::IsVector(tls_shape),
946                 errors::InvalidArgument(
947                     "Expected input_handles to be a vector, but saw shape: ",
948                     tls_shape.DebugString()));
949     const int64 batch_size = tls.NumElements();
950     OP_REQUIRES(c, input.dim_size(0) == batch_size,
951                 errors::InvalidArgument(
952                     "Expected tensor.shape[0] == input_handles.size, but saw ",
953                     input.dim_size(0), " vs. ", batch_size));
954     auto tls_t = tls.vec<Variant>();
956     TensorShape input_element_shape = input.shape();
957     input_element_shape.RemoveDim(0);
958     std::vector<const TensorList*> tl_batch;
959     for (int64 b = 0; b < batch_size; ++b) {
960       const TensorList* l = tls_t(b).get<TensorList>();
961       OP_REQUIRES(c, l != nullptr,
962                   errors::InvalidArgument("Input handle at index ", b,
963                                           " is not a list. Saw: '",
964                                           tls_t(b).DebugString(), "'"));
965       OP_REQUIRES(
966           c, l->element_shape.IsCompatibleWith(input_element_shape),
967           errors::InvalidArgument(
968               "Tried to append a tensor with incompatible shape to a "
969               "list at index ",
970               b, ". Op element shape: ", input_element_shape.DebugString(),
971               " list shape: ", l->element_shape.DebugString()));
972       OP_REQUIRES(c, element_dtype_ == l->element_dtype,
973                   errors::InvalidArgument(
974                       "Invalid data type at index ", b, "; op elements ",
975                       DataTypeString(element_dtype_), " but list elements ",
976                       DataTypeString(l->element_dtype)));
977       tl_batch.push_back(l);
978     }
980     Tensor* result;
982     if (tls_alias) {
983       result = tls_alias.get();
984       c->set_output(0, *result);
985     } else {
986       // DT_VARIANT tensors always allocated on host.
987       AllocatorAttributes attr;
988       attr.set_on_host(true);
989       OP_REQUIRES_OK(
990           c, c->allocate_output(0, TensorShape{batch_size}, &result, attr));
991     }
993     if (batch_size == 0) {
994       return;
995     }
997     auto input_t = input.flat_outer_dims<T, 2>();
998     auto result_t = result->vec<Variant>();
1000     for (int64 b = 0; b < batch_size; ++b) {
1001       if (!tls_alias) {
1002         result_t(b) = *tl_batch[b];
1003       }
1004       TensorList* output = result_t(b).get<TensorList>();
1005       DCHECK(output != nullptr);
1006       Tensor* frame;
1007       PersistentTensor tmp;
1008       OP_REQUIRES_OK(c, c->allocate_persistent(
1009                             element_dtype_, input_element_shape, &tmp, &frame));
1010       if (input_element_shape.num_elements() > 0) {
1011         auto frame_t = frame->flat<T>();
1012         frame_t.device(c->eigen_device<Device>()) = input_t.template chip<0>(b);
1013       }
1014       output->tensors.push_back(std::move(*frame));
1015     }
1016   }
1018  private:
1019   DataType element_dtype_;
1020 };
1022 }  // namespace tensorflow