1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/kernels/reduction_ops_common.h"
17 
18 #include "tensorflow/core/lib/strings/str_util.h"
19 
20 namespace tensorflow {
21 
out_reshape() const22 TensorShape ReductionHelper::out_reshape() const {
23   TensorShape shape;
24   for (auto size : out_reshape_) shape.AddDim(size);
25   return shape;
26 }
27 
28 // The final output shape must be allocated with this shape.
out_shape() const29 TensorShape ReductionHelper::out_shape() const {
30   TensorShape shape;
31   for (auto size : out_shape_) shape.AddDim(size);
32   return shape;
33 }
34 
shuffled_shape()35 TensorShape ReductionHelper::shuffled_shape() {
36   const int dims = data_reshape_.size();
37   TensorShape shape;
38   for (int i = reduce_first_axis_; i < dims; i += 2) {
39     shape.AddDim(data_reshape_[i]);
40   }
41   for (int i = !reduce_first_axis_; i < dims; i += 2) {
42     shape.AddDim(data_reshape_[i]);
43   }
44   return shape;
45 }
46 
permutation()47 gtl::InlinedVector<int32, 8> ReductionHelper::permutation() {
48   const int dims = data_reshape_.size();
49   const int unreduced_dims = (dims + !reduce_first_axis_) / 2;
50   gtl::InlinedVector<int32, 8> perm(dims);
51   for (int i = 0; i < unreduced_dims; i++) {
52     perm[i] = 2 * i + reduce_first_axis_;
53   }
54   for (int i = unreduced_dims; i < dims; i++) {
55     perm[i] = 2 * (i - unreduced_dims) + !reduce_first_axis_;
56   }
57   return perm;
58 }
59 
60 template <typename Tperm>
SimplifyHelper(const Tensor & data,const Tensor & axis,gtl::InlinedVector<bool,4> & bitmap)61 Status SimplifyHelper(const Tensor& data, const Tensor& axis,
62                       gtl::InlinedVector<bool, 4>& bitmap) {
63   auto axis_vec = axis.flat<Tperm>();
64   for (int64 i = 0; i < axis.NumElements(); ++i) {
65     Tperm index = axis_vec(i);
66     if (index < -data.dims() || index >= data.dims()) {
67       return errors::InvalidArgument("Invalid reduction dimension (", index,
68                                      " for input with ", data.dims(),
69                                      " dimension(s)");
70     }
71     index = (index + data.dims()) % data.dims();
72     if (bitmap[index]) {
73       return errors::InvalidArgument(
74           "Invalid reduction arguments: Axes contains duplicate dimension: ",
75           index);
76     }
77     bitmap[index] = true;
78   }
79   return Status::OK();
80 }
81 
Simplify(const Tensor & data,const Tensor & axis,const bool keep_dims)82 Status ReductionHelper::Simplify(const Tensor& data, const Tensor& axis,
83                                  const bool keep_dims) {
84   // bitmap[i] indicates whether to reduce data along i-th axis.
85   gtl::InlinedVector<bool, 4> bitmap(data.dims(), false);
86   if (axis.dtype() == DT_INT32) {
87     TF_RETURN_IF_ERROR(SimplifyHelper<int32>(data, axis, bitmap));
88   } else {
89     TF_RETURN_IF_ERROR(SimplifyHelper<int64>(data, axis, bitmap));
90   }
91   // Output tensor's dim sizes.
92   out_shape_.clear();
93   for (int i = 0; i < data.dims(); ++i) {
94     if (!bitmap[i]) {
95       // If we are not reducing along dimension i.
96       out_shape_.push_back(data.dim_size(i));
97     } else if (keep_dims) {
98       // We are reducing along dimension i, but we want to keep the
99       // same number of dimensions, so we set the dimension of i to
100       // '1'.
101       out_shape_.push_back(1);
102     }
103   }
104 
105   // Depending on bitmap[i] and bitmap[i-1], we can collapse axis of
106   // the input data before doing the reduction on the resulting
107   // tensor.  The shape of the reduction is a reshape of the final
108   // output.
109 
110   // We'll skip the leading 1s.
111   int dim_index = 0;
112   for (; dim_index < data.dims(); ++dim_index) {
113     if (data.dim_size(dim_index) != 1) break;
114   }
115   if (dim_index >= data.dims()) {
116     // Special case. The input is essentially a scalar.
117     reduce_first_axis_ = true;
118   } else {
119     // Starting from the (dim_index)-th dimension, dimensions
120     // alternates between runs that need to be reduced and runs that
121     // don't.
122     //
123     // NOTE: If a dimension has size 1, we group it as the current
124     // run so that we can minimize the number of runs.
125     //
126     // E.g., when we want to reduce a tensor of shape [2, 1, 3, 1,
127     // 5] by axes = [1, 4], we should treat the tensor as a [6, 5]
128     // and reduce by axes = [1] (i.e., the output is shape [6]).
129     reduce_first_axis_ = bitmap[dim_index];
130     data_reshape_.push_back(data.dim_size(dim_index));
131     ++dim_index;
132     for (; dim_index < data.dims(); ++dim_index) {
133       const auto size = data.dim_size(dim_index);
134       if (size == 1) {
135         bitmap[dim_index] = bitmap[dim_index - 1];
136       }
137       if (bitmap[dim_index - 1] != bitmap[dim_index]) {
138         // Starts a new run of reduce or !reduce.
139         data_reshape_.push_back(size);
140       } else {
141         // Continue a run of reduce or !reduce.
142         data_reshape_.back() *= size;
143       }
144     }
145     // If reduce_first_axis_ is true (input's dimension 0, 2, 4, etc
146     // are reduced), data_reshape_[1, 3, 5, ...]  is out_reshape_,
147     // otherwise, data_reshape_[0, 2, 4, ...] is.
148     for (size_t i = reduce_first_axis_ ? 1 : 0; i < data_reshape_.size();
149          i += 2) {
150       out_reshape_.push_back(data_reshape_[i]);
151     }
152   }
153 
154   VLOG(1) << "data reshape: " << absl::StrJoin(data_reshape_, ",");
155   VLOG(1) << "out  reshape: " << absl::StrJoin(out_reshape_, ",");
156   VLOG(1) << "out    shape: " << absl::StrJoin(out_shape_, ",");
157   return Status::OK();
158 }
159 
160 }  // namespace tensorflow
161