1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_PADDING_FIFO_QUEUE_H_
17 #define TENSORFLOW_CORE_KERNELS_PADDING_FIFO_QUEUE_H_
18 
19 #include <deque>
20 #include <vector>
21 
22 #include "tensorflow/core/framework/op_kernel.h"
23 #include "tensorflow/core/framework/partial_tensor_shape.h"
24 #include "tensorflow/core/framework/tensor.h"
25 #include "tensorflow/core/framework/tensor_shape.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/kernels/fifo_queue.h"
28 #include "tensorflow/core/kernels/typed_queue.h"
29 #include "tensorflow/core/platform/macros.h"
30 #include "tensorflow/core/platform/mutex.h"
31 #include "tensorflow/core/platform/types.h"
32 
33 namespace tensorflow {
34 
35 class PaddingFIFOQueue : public FIFOQueue {
36  public:
37   PaddingFIFOQueue(int32 capacity, const DataTypeVector& component_dtypes,
38                    const std::vector<PartialTensorShape>& component_shapes,
39                    const string& name);
40 
41   Status Initialize() override;
42 
43   // Implementations of QueueInterface methods --------------------------------
44 
45   void TryDequeueMany(int num_elements, OpKernelContext* ctx,
46                       bool allow_small_batch,
47                       CallbackWithTuple callback) override;
48   Status MatchesNodeDef(const NodeDef& node_def) override;
49 
50  protected:
51   Status ValidateManyTuple(const Tuple& tuple) override;
52   Status ValidateTuple(const Tuple& tuple) override;
53   Status CompatibleNodeDefShapes(const NodeDef& node_def) const;
54 
55   // Convert a list of PartialTensorShape to a list of
56   // TensorShape.
57   // Any unknown dimension sizes are converted to 0.
58   // REQUIRED: All the input shapes have well defined rank.
59   static std::vector<TensorShape> ConvertShapesPartialDimensionsToZero(
60       const gtl::ArraySlice<PartialTensorShape>& partial_shapes);
61 
62   // Sets the values in the given element to zero.
63   static Status SetElementZero(Tensor* element);
64 
65   // Copies element into the index^th slice (in the first dimension)
66   // of parent.  Allows for the parent's slice to have a larger size
67   // than the element, and copies the element into the upper left hand
68   // corner of the slice.
69   static Status CopyElementToLargerSlice(const Tensor& element, Tensor* parent,
70                                          int index);
71 
72   std::vector<PartialTensorShape> partial_shapes_;
73 
74  private:
~PaddingFIFOQueue()75   ~PaddingFIFOQueue() override {}
76 
77   static Status GetElementComponent(const PaddingFIFOQueue::Tuple& tuple,
78                                     int component, OpKernelContext* ctx,
79                                     PersistentTensor* out_tensor);
80 
81   static Status IsSameSizeExceptZerosInFirst(const TensorShape& first,
82                                              const TensorShape& second);
83 
84   TF_DISALLOW_COPY_AND_ASSIGN(PaddingFIFOQueue);
85 };
86 
87 }  // namespace tensorflow
88 
89 #endif  // TENSORFLOW_CORE_KERNELS_PADDING_FIFO_QUEUE_H_
90