1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/window_util.h"
17 
18 #include <vector>
19 
20 #include "absl/algorithm/container.h"
21 #include "absl/strings/str_cat.h"
22 #include "tensorflow/compiler/xla/types.h"
23 #include "tensorflow/compiler/xla/xla_data.pb.h"
24 #include "tensorflow/core/platform/logging.h"
25 
26 namespace xla {
27 namespace window_util {
28 
MakeWindow(absl::Span<const int64> sizes)29 Window MakeWindow(absl::Span<const int64> sizes) {
30   Window window;
31   for (int64 size : sizes) {
32     auto* dimension = window.add_dimensions();
33     dimension->set_size(size);
34     dimension->set_stride(1);
35     dimension->set_base_dilation(1);
36     dimension->set_window_dilation(1);
37   }
38   return window;
39 }
40 
MakeSymmetricPadding(absl::Span<const int64> sizes)41 PaddingConfig MakeSymmetricPadding(absl::Span<const int64> sizes) {
42   PaddingConfig config;
43   for (int64 size : sizes) {
44     auto* dimension = config.add_dimensions();
45     dimension->set_edge_padding_low(size);
46     dimension->set_edge_padding_high(size);
47   }
48   return config;
49 }
50 
ToString(const WindowDimension & dim)51 /* static */ string ToString(const WindowDimension& dim) {
52   using absl::StrAppend;
53   using absl::StrCat;
54   string str = StrCat("(size=", dim.size());
55   if (dim.stride() != 1) {
56     StrAppend(&str, ",stride=", dim.stride());
57   }
58   if (dim.padding_low() != 0) {
59     StrAppend(&str, ",padding_low=", dim.padding_low());
60   }
61   if (dim.padding_high() != 0) {
62     StrAppend(&str, ",padding_high=", dim.padding_high());
63   }
64   if (dim.base_dilation() != 1) {
65     StrAppend(&str, ",base_dilation=", dim.base_dilation());
66   }
67   if (dim.window_dilation() != 1) {
68     StrAppend(&str, ",window_dilation=", dim.window_dilation());
69   }
70   if (dim.window_reversal()) {
71     StrAppend(&str, ",window_reversal");
72   }
73   StrAppend(&str, ")");
74   return str;
75 }
76 
ToString(const Window & window)77 string ToString(const Window& window) {
78   using absl::StrAppend;
79   using absl::StrCat;
80 
81   string str;
82   const auto add_field =
83       [&](const char* heading,
84           std::function<string(const WindowDimension&)> format) {
85         StrAppend(&str, heading, "=");
86         const char* prefix = "";
87         for (const auto& window_dimension : window.dimensions()) {
88           StrAppend(&str, prefix, format(window_dimension));
89           prefix = "x";
90         }
91       };
92 
93   add_field("size",
94             [](const WindowDimension& dim) { return StrCat(dim.size()); });
95   if (HasStride(window)) {
96     add_field(" stride",
97               [](const WindowDimension& dim) { return StrCat(dim.stride()); });
98   }
99   if (HasPadding(window)) {
100     add_field(" pad", [](const WindowDimension& dim) {
101       return StrCat(dim.padding_low(), "_", dim.padding_high());
102     });
103   }
104   if (HasBaseDilation(window)) {
105     add_field(" lhs_dilate", [](const WindowDimension& dim) {
106       return StrCat(dim.base_dilation());
107     });
108   }
109   if (HasWindowDilation(window)) {
110     add_field(" rhs_dilate", [](const WindowDimension& dim) {
111       return StrCat(dim.window_dilation());
112     });
113   }
114   if (HasWindowReversal(window)) {
115     add_field(" rhs_reversal", [](const WindowDimension& dim) {
116       return StrCat(dim.window_reversal() ? 1 : 0);
117     });
118   }
119   return str;
120 }
121 
HasStride(const Window & window)122 bool HasStride(const Window& window) {
123   for (const auto& dim : window.dimensions()) {
124     if (dim.stride() != 1) {
125       return true;
126     }
127   }
128   return false;
129 }
130 
HasPadding(const Window & window)131 bool HasPadding(const Window& window) {
132   for (const auto& dim : window.dimensions()) {
133     if (dim.padding_low() != 0 || dim.padding_high() != 0) {
134       return true;
135     }
136   }
137   return false;
138 }
139 
HasSymmetricPadding(const Window & window)140 bool HasSymmetricPadding(const Window& window) {
141   return absl::c_all_of(window.dimensions(), [](const WindowDimension& dim) {
142     return dim.padding_low() == dim.padding_high();
143   });
144 }
145 
HasSymmetricPadding(const PaddingConfig & padding_config)146 bool HasSymmetricPadding(const PaddingConfig& padding_config) {
147   return absl::c_all_of(padding_config.dimensions(),
148                         [](const PaddingConfig::PaddingConfigDimension& dim) {
149                           return dim.edge_padding_low() ==
150                                  dim.edge_padding_high();
151                         });
152 }
153 
HasNegativePadding(const Window & window)154 bool HasNegativePadding(const Window& window) {
155   return absl::c_any_of(window.dimensions(), [](const WindowDimension& dim) {
156     return dim.padding_low() < 0 || dim.padding_high() < 0;
157   });
158 }
159 
HasBaseDilation(const Window & window)160 bool HasBaseDilation(const Window& window) {
161   for (const auto& dim : window.dimensions()) {
162     if (dim.base_dilation() != 1) {
163       return true;
164     }
165   }
166   return false;
167 }
168 
HasWindowDilation(const Window & window)169 bool HasWindowDilation(const Window& window) {
170   for (const auto& dim : window.dimensions()) {
171     if (dim.window_dilation() != 1) {
172       return true;
173     }
174   }
175   return false;
176 }
177 
HasWindowReversal(const Window & window)178 bool HasWindowReversal(const Window& window) {
179   for (const auto& dim : window.dimensions()) {
180     if (dim.window_reversal()) {
181       return true;
182     }
183   }
184   return false;
185 }
186 
AllOrNoneReversed(const Window & window)187 bool AllOrNoneReversed(const Window& window) {
188   if (window.dimensions().empty()) {
189     return true;
190   }
191   bool reversed = window.dimensions()[0].window_reversal();
192   return absl::c_all_of(window.dimensions(), [&](const WindowDimension& dim) {
193     return dim.window_reversal() == reversed;
194   });
195 }
196 
HasDilation(const Window & window)197 bool HasDilation(const Window& window) {
198   return HasBaseDilation(window) || HasWindowDilation(window);
199 }
200 
IsInactiveWindowDimension(const Window & window,int64 logical_dim)201 bool IsInactiveWindowDimension(const Window& window, int64 logical_dim) {
202   const WindowDimension& window_dim = window.dimensions(logical_dim);
203   return window_dim.size() == 1 && window_dim.stride() == 1 &&
204          window_dim.padding_low() == 0 && window_dim.padding_high() == 0;
205 }
206 
IsTrivialWindowDimension(const WindowDimension & window_dimension)207 bool IsTrivialWindowDimension(const WindowDimension& window_dimension) {
208   return window_dimension.size() == 1 && window_dimension.stride() == 1 &&
209          window_dimension.padding_low() == 0 &&
210          window_dimension.padding_high() == 0 &&
211          window_dimension.window_dilation() == 1 &&
212          window_dimension.base_dilation() == 1;
213 }
214 
DilatedBound(int64 bound,int64 dilation)215 int64 DilatedBound(int64 bound, int64 dilation) {
216   CHECK_GE(bound, 0);
217   CHECK_GE(dilation, 1);
218   if (bound == 0) {
219     return 0;
220   }
221 
222   // Suppose the array has three entries 123 and the dilation factor is 4. Then
223   // the dilated array has 9 entries 1xxx2xxx3. Here, each original entry except
224   // the last expands into 4 entries, so that is (bound - 1) * dilation. Then we
225   // add 1 to account for the final input element.
226   return (bound - 1) * dilation + 1;
227 }
228 
StridedBound(int64 bound,int64 window_size,int64 stride)229 int64 StridedBound(int64 bound, int64 window_size, int64 stride) {
230   CHECK_GE(window_size, 0);
231   CHECK_GE(bound, 0);
232   CHECK_GE(stride, 1);
233 
234   if (bound == 0 || window_size > bound) {
235     return 0;
236   }
237 
238   // Without considering stride, the maximum valid offset is bound -
239   // window_size. Taking stride into account, the valid offsets then have the
240   // form q * stride for q = 0, ..., Q such that q * stride <= bound -
241   // window_size. This implies that Q equals floor(bound - window_size /
242   // stride). There are Q + 1 valid values of q, yielding the formula below.
243   return (bound - window_size) / stride + 1;
244 }
245 
246 }  // namespace window_util
247 }  // namespace xla
248