1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/kernels/reduction_ops_common.h"
17
18 #include "tensorflow/core/lib/strings/str_util.h"
19
20 namespace tensorflow {
21
out_reshape() const22 TensorShape ReductionHelper::out_reshape() const {
23 TensorShape shape;
24 for (auto size : out_reshape_) shape.AddDim(size);
25 return shape;
26 }
27
28 // The final output shape must be allocated with this shape.
out_shape() const29 TensorShape ReductionHelper::out_shape() const {
30 TensorShape shape;
31 for (auto size : out_shape_) shape.AddDim(size);
32 return shape;
33 }
34
shuffled_shape()35 TensorShape ReductionHelper::shuffled_shape() {
36 const int dims = data_reshape_.size();
37 TensorShape shape;
38 for (int i = reduce_first_axis_; i < dims; i += 2) {
39 shape.AddDim(data_reshape_[i]);
40 }
41 for (int i = !reduce_first_axis_; i < dims; i += 2) {
42 shape.AddDim(data_reshape_[i]);
43 }
44 return shape;
45 }
46
permutation()47 gtl::InlinedVector<int32, 8> ReductionHelper::permutation() {
48 const int dims = data_reshape_.size();
49 const int unreduced_dims = (dims + !reduce_first_axis_) / 2;
50 gtl::InlinedVector<int32, 8> perm(dims);
51 for (int i = 0; i < unreduced_dims; i++) {
52 perm[i] = 2 * i + reduce_first_axis_;
53 }
54 for (int i = unreduced_dims; i < dims; i++) {
55 perm[i] = 2 * (i - unreduced_dims) + !reduce_first_axis_;
56 }
57 return perm;
58 }
59
60 template <typename Tperm>
SimplifyHelper(const Tensor & data,const Tensor & axis,gtl::InlinedVector<bool,4> & bitmap)61 Status SimplifyHelper(const Tensor& data, const Tensor& axis,
62 gtl::InlinedVector<bool, 4>& bitmap) {
63 auto axis_vec = axis.flat<Tperm>();
64 for (int64 i = 0; i < axis.NumElements(); ++i) {
65 Tperm index = axis_vec(i);
66 if (index < -data.dims() || index >= data.dims()) {
67 return errors::InvalidArgument("Invalid reduction dimension (", index,
68 " for input with ", data.dims(),
69 " dimension(s)");
70 }
71 index = (index + data.dims()) % data.dims();
72 if (bitmap[index]) {
73 return errors::InvalidArgument(
74 "Invalid reduction arguments: Axes contains duplicate dimension: ",
75 index);
76 }
77 bitmap[index] = true;
78 }
79 return Status::OK();
80 }
81
Simplify(const Tensor & data,const Tensor & axis,const bool keep_dims)82 Status ReductionHelper::Simplify(const Tensor& data, const Tensor& axis,
83 const bool keep_dims) {
84 // bitmap[i] indicates whether to reduce data along i-th axis.
85 gtl::InlinedVector<bool, 4> bitmap(data.dims(), false);
86 if (axis.dtype() == DT_INT32) {
87 TF_RETURN_IF_ERROR(SimplifyHelper<int32>(data, axis, bitmap));
88 } else {
89 TF_RETURN_IF_ERROR(SimplifyHelper<int64>(data, axis, bitmap));
90 }
91 // Output tensor's dim sizes.
92 out_shape_.clear();
93 for (int i = 0; i < data.dims(); ++i) {
94 if (!bitmap[i]) {
95 // If we are not reducing along dimension i.
96 out_shape_.push_back(data.dim_size(i));
97 } else if (keep_dims) {
98 // We are reducing along dimension i, but we want to keep the
99 // same number of dimensions, so we set the dimension of i to
100 // '1'.
101 out_shape_.push_back(1);
102 }
103 }
104
105 // Depending on bitmap[i] and bitmap[i-1], we can collapse axis of
106 // the input data before doing the reduction on the resulting
107 // tensor. The shape of the reduction is a reshape of the final
108 // output.
109
110 // We'll skip the leading 1s.
111 int dim_index = 0;
112 for (; dim_index < data.dims(); ++dim_index) {
113 if (data.dim_size(dim_index) != 1) break;
114 }
115 if (dim_index >= data.dims()) {
116 // Special case. The input is essentially a scalar.
117 reduce_first_axis_ = true;
118 } else {
119 // Starting from the (dim_index)-th dimension, dimensions
120 // alternates between runs that need to be reduced and runs that
121 // don't.
122 //
123 // NOTE: If a dimension has size 1, we group it as the current
124 // run so that we can minimize the number of runs.
125 //
126 // E.g., when we want to reduce a tensor of shape [2, 1, 3, 1,
127 // 5] by axes = [1, 4], we should treat the tensor as a [6, 5]
128 // and reduce by axes = [1] (i.e., the output is shape [6]).
129 reduce_first_axis_ = bitmap[dim_index];
130 data_reshape_.push_back(data.dim_size(dim_index));
131 ++dim_index;
132 for (; dim_index < data.dims(); ++dim_index) {
133 const auto size = data.dim_size(dim_index);
134 if (size == 1) {
135 bitmap[dim_index] = bitmap[dim_index - 1];
136 }
137 if (bitmap[dim_index - 1] != bitmap[dim_index]) {
138 // Starts a new run of reduce or !reduce.
139 data_reshape_.push_back(size);
140 } else {
141 // Continue a run of reduce or !reduce.
142 data_reshape_.back() *= size;
143 }
144 }
145 // If reduce_first_axis_ is true (input's dimension 0, 2, 4, etc
146 // are reduced), data_reshape_[1, 3, 5, ...] is out_reshape_,
147 // otherwise, data_reshape_[0, 2, 4, ...] is.
148 for (size_t i = reduce_first_axis_ ? 1 : 0; i < data_reshape_.size();
149 i += 2) {
150 out_reshape_.push_back(data_reshape_[i]);
151 }
152 }
153
154 VLOG(1) << "data reshape: " << absl::StrJoin(data_reshape_, ",");
155 VLOG(1) << "out reshape: " << absl::StrJoin(out_reshape_, ",");
156 VLOG(1) << "out shape: " << absl::StrJoin(out_shape_, ",");
157 return Status::OK();
158 }
159
160 } // namespace tensorflow
161