1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_UTIL_BCAST_H_
17 #define TENSORFLOW_CORE_UTIL_BCAST_H_
18 
19 #include <algorithm>
20 
21 #include "tensorflow/core/framework/tensor_shape.h"
22 #include "tensorflow/core/lib/gtl/inlined_vector.h"
23 #include "tensorflow/core/platform/macros.h"
24 #include "tensorflow/core/platform/types.h"
25 
26 namespace tensorflow {
27 
28 // Returns the mapping from the output batch indices to the corresponding
29 // input's batch indices, given the input's "reshape" and "bcast" shapes as
30 // returned by the BCastList helper class. The i'th element denotes the
31 // (flattened) batch index of the input that must be used to compute the i'th
32 // batch output.
33 //
ComputeBatchIndices(const int64 output_batch_size,const gtl::InlinedVector<int64,4> & reshape,const gtl::InlinedVector<int64,4> & bcast,std::vector<int64> * out_indices)34 inline void ComputeBatchIndices(const int64 output_batch_size,
35                                 const gtl::InlinedVector<int64, 4>& reshape,
36                                 const gtl::InlinedVector<int64, 4>& bcast,
37                                 std::vector<int64>* out_indices) {
38   // Populates the mapping in out_indices. This algorithm is identical to
39   // the following steps:
40   //  - Reshape {0, 1, ..., input_batch_size - 1} to the input shape.
41   //  - Broadcast to the output shape.
42   //  - Reshape back to a flat 1D vector.
43   out_indices->resize(output_batch_size);
44   int64 num_output_elements = 1;
45   int64 num_input_elements = 1;
46   for (int64 i = reshape.size() - 1; i >= 0; --i) {
47     // Replicate the already populated mapping an additional (dim - 1) times.
48     // If we are broadcasting, just copy the existing mapping.
49     // Otherwise, add another dimension from the input shape.
50     const int64 dim = std::max(reshape[i], bcast[i]);
51     const int64 incr = bcast[i] > 1 ? 0 : num_input_elements;
52     for (int64 k = 0; k < (dim - 1) * num_output_elements; ++k) {
53       (*out_indices)[num_output_elements + k] = (*out_indices)[k] + incr;
54     }
55     num_output_elements *= dim;
56     num_input_elements *= reshape[i];
57   }
58 }
59 
60 template <int N>
61 class BCastList {
62  public:
63   // A vector of int64 representing the shape of tensor. The 0-th
64   // element is the outer-most dimension and the last element is the
65   // inner-most dimension. Note that we do not use TensorShape since
66   // it's more convenient to manipulate Vec directly for this module.
67   typedef gtl::InlinedVector<int64, 4> Vec;
68 
69   // Constructs all helper shapes, following the aforementioned rules.
70   //
71   // If "fewer_dims_optimization" is set to true (the default), the
72   // implementation tries to reduce intermediate dimensions needed to be more
73   // efficient.  This is transparent to the caller.
74   //
75   // If false, all intermediate shapes (except for grad_{x,y}_reduce_idx()) have
76   // the same number of dimensions as the larger of the two inputs.
77   //
78   // If return_flattened_batch_indices is true, the implementation will compute
79   // for each output member of the flattened output, which batch indices of
80   // each input correspond to it. This is disabled by default.
81   explicit BCastList(const Vec (&x)[N],
82                      const bool fewer_dims_optimization = true,
83                      const bool return_flattened_batch_indices = false);
~BCastList()84   ~BCastList() {}
85 
86   // Returns true iff two operands are compatible according to the
87   // broadcasting rule.
IsValid()88   bool IsValid() const { return valid_; }
IsBroadcastingRequired()89   bool IsBroadcastingRequired() const { return broadcasting_required_; }
90 
91   // If and only if IsValid(), the following fields can be used in
92   // implementing a broadcasted binary tensor operation according to
93   // the broadcasting rule.
reshape(int i)94   const Vec& reshape(int i) const { return reshape_[i]; }
bcast(int i)95   const Vec& bcast(int i) const { return bcast_[i]; }
result_shape()96   const Vec& result_shape() const { return result_; }
output_shape()97   const Vec& output_shape() const { return output_; }
grad_reduce_idx(int i)98   const Vec& grad_reduce_idx(int i) const { return grad_reduce_idx_[i]; }
output_batch_size()99   const int64 output_batch_size() const { return output_batch_size_; }
100 
101   // Returns the mapping from the flattened output batch indices to x's
102   // flattened batch indices. The result is a vector of length
103   // output_batch_size(). To compute the i'th batch output, a binary matmul-like
104   // operation should use the `x_batch_indices()[i]`th batch index of `x`.
105   // Note: Returns an empty vector if broadcasting is not required. Callers
106   // should only use this when IsBroadcastingRequired() returns true.
batch_indices(int i)107   const std::vector<int64>& batch_indices(int i) const {
108     return batch_indices_[i];
109   }
110 
111  protected:
112   bool valid_ = true;
113   bool broadcasting_required_ = true;
114   Vec reshape_[N];
115   Vec bcast_[N];
116   Vec result_;
117   Vec output_;
118   Vec grad_reduce_idx_[N];
119 
120   int64 output_batch_size_;
121   std::vector<int64> batch_indices_[N];
122 
Reverse(Vec * shape)123   static void Reverse(Vec* shape) {
124     std::reverse(shape->begin(), shape->end());
125   }
126 
127   TF_DISALLOW_COPY_AND_ASSIGN(BCastList);
128 };
129 
130 template <int N>
BCastList(const BCastList::Vec (& x)[N],const bool fewer_dims_optimization,const bool return_flattened_batch_indices)131 BCastList<N>::BCastList(const BCastList::Vec (&x)[N],
132                         const bool fewer_dims_optimization,
133                         const bool return_flattened_batch_indices) {
134   typedef BCastList::Vec Vec;
135 
136   // Safely multiplies dimensions taking into account symbolic shapes.
137   auto mul_dims = [](int64 dim1, int64 dim2) -> int64 {
138     return dim1 != 0 && dim2 != 0 && (dim1 < 0 || dim2 < 0) ? -1 : dim1 * dim2;
139   };
140 
141   bool all_equal = true;
142   size_t largest_rank = 0;
143   output_batch_size_ = 1;
144   for (int i = 0; i < N; ++i) {
145     if (x[i] != x[0]) {
146       all_equal = false;
147     }
148     if (x[i].size() > largest_rank) {
149       largest_rank = x[i].size();
150     }
151   }
152   if (all_equal) {
153     broadcasting_required_ = false;
154   }
155   if (all_equal && TF_PREDICT_TRUE(fewer_dims_optimization)) {
156     // Fast path for common case of identical shapes.
157     int64 elements = 1;
158     const int rank = x[0].size();
159     output_.resize(rank);
160     for (int i = 0; i < rank; i++) {
161       const int64 dim = x[0][i];
162       elements = mul_dims(elements, dim);
163       output_[i] = dim;
164     }
165     result_.push_back(elements);
166     output_batch_size_ = elements;
167     for (int i = 0; i < N; ++i) {
168       reshape_[i].push_back(elements);
169       bcast_[i].push_back(1);
170     }
171     // grad_reduce_ is left as empty
172     return;
173   }
174 
175   // Reverse all the shapes for convenience
176   // After the reverse, 0-th is the inner-most dimension.
177   Vec copy[N];
178   for (int i = 0; i < N; ++i) {
179     copy[i] = x[i];
180     Reverse(&copy[i]);
181   }
182 
183   // 1-extend and align all vectors.
184   for (int i = 0; i < N; ++i) {
185     if (copy[i].size() < largest_rank) {
186       copy[i].resize(largest_rank, 1);
187     }
188   }
189   // Going through each dimension starting from the inner-most
190   // dimension, compares dimension of x and y. They are compatible if
191   // they are equal or either is 1.
192 
193   // indices of j-th component of each input.
194   bool prev_is_one[N];
195   bool current_is_one[N];
196   for (int i = 0; i < N; ++i) {
197     prev_is_one[i] = false;
198     current_is_one[i] = false;
199   }
200   Vec output;
201   bool output_dim_set = false;
202   int output_dim = -1;
203   bool none_is_one = true;
204   bool set_one = false;
205   for (int j = 0; j < largest_rank; ++j) {
206     output_dim = -1;
207     output_dim_set = false;
208     none_is_one = true;
209     // Find which indices are 1.
210     for (int i = 0; i < N; ++i) {
211       // Keep track of which indices are 1.
212       if (copy[i][j] == 1) {
213         current_is_one[i] = true;
214         none_is_one = false;
215       } else {
216         current_is_one[i] = false;
217         if (!output_dim_set || copy[i][j] == output_dim) {
218           output_dim = copy[i][j];
219           output_dim_set = true;
220         } else {
221           valid_ = false;
222           return;
223         }
224       }
225     }
226     output_.push_back(output_dim_set ? output_dim : 1);
227     output_batch_size_ = mul_dims(output_batch_size_, output_.back());
228     // All dimensions are 1.
229     if (!output_dim_set) {
230       if (!TF_PREDICT_TRUE(fewer_dims_optimization)) {
231         for (int i = 0; i < N; ++i) {
232           bcast_[i].push_back(1);
233           reshape_[i].push_back(1);
234         }
235         result_.push_back(1);
236       }
237       for (int i = 0; i < N; ++i) {
238         grad_reduce_idx_[i].push_back(largest_rank - 1 - j);
239       }
240       // This will skip updating the previous state to the current one. We'll
241       // explain why this is safe below.
242       // Consider the previous state P, current state C and the next state N.
243       // In the case where N also is all ones (N == C), we'll do the same
244       // optimization here (push back one dimensions if we need to), which is
245       // safe and is expected.
246       //
247       // When N != C, we'll continue as usual. However, we might trigger the
248       // next block if N == P (because we didn't update the previous state).
249       // We trigger the next block if `fewer_dims_optimization` is true.
250       // This means that we did not modify and broadcast / reshapes in this
251       // block (we skipped updating, since the one dimensions can be ignored).
252       // In essence, we only need to check whether the previous non-one state is
253       // equal to the current non-one state.
254 
255       continue;
256     } else if (TF_PREDICT_TRUE(fewer_dims_optimization) &&
257                std::equal(current_is_one, current_is_one + N, prev_is_one) &&
258                set_one) {
259       // It is a run of the same broadcasting case as last time.
260       // We can reshape the input so that fewer dimensions
261       // are involved in the intermediate computation.
262       result_.back() = mul_dims(result_.back(), output_dim);
263       for (int i = 0; i < N; ++i) {
264         reshape_[i].back() = mul_dims(reshape_[i].back(), copy[i][j]);
265         bcast_[i].back() =
266             mul_dims(bcast_[i].back(), current_is_one[i] ? output_dim : 1);
267         if (current_is_one[i] && !none_is_one) {
268           grad_reduce_idx_[i].push_back(largest_rank - 1 - j);
269         }
270       }
271     } else {
272       result_.push_back(output_dim);
273       for (int i = 0; i < N; ++i) {
274         reshape_[i].push_back(copy[i][j]);
275         bcast_[i].push_back(current_is_one[i] ? output_dim : 1);
276         if (current_is_one[i] && !none_is_one) {
277           grad_reduce_idx_[i].push_back(largest_rank - 1 - j);
278         }
279       }
280     }
281     set_one = true;
282     for (int i = 0; i < N; ++i) {
283       prev_is_one[i] = current_is_one[i];
284     }
285   }
286   if (result_.empty()) {
287     result_.push_back(1);
288     for (int i = 0; i < N; ++i) {
289       reshape_[i].push_back(1);
290       bcast_[i].push_back(1);
291     }
292   }
293   // Do something about batches.
294   for (int i = 0; i < N; ++i) {
295     Reverse(&reshape_[i]);
296     Reverse(&bcast_[i]);
297     Reverse(&grad_reduce_idx_[i]);
298   }
299   Reverse(&result_);
300   Reverse(&output_);
301   // Only compute batch indices when we need broadcasting, and we aren't doing
302   // needless work (when the output size is 0 or the
303   // return_flattened_batch_indices isn't enabled).
304   if (return_flattened_batch_indices && broadcasting_required_ &&
305       output_batch_size_ > 0) {
306     for (int i = 0; i < N; ++i) {
307       ComputeBatchIndices(output_batch_size_, reshape_[i], bcast_[i],
308                           &batch_indices_[i]);
309     }
310   }
311 }
312 
313 // BCast is a helper for broadcasting binary tensor operation.
314 // TensorFlow's broadcasting rule follows that of numpy (See
315 // http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
316 //
317 // The rule has the following properties:
318 //
319 //   1. suffix matching: the rule starts with the right-most
320 //      dimension, and works towards the left-most dimension. Since
321 //      TensorFlow is row-major, the right-most dimension (the last
322 //      element in the shape of a tensor) is the inner-most, a.k.a.
323 //      the fastest changing, dimension.
324 //
325 //   2. Two dimensions are compatible for broadcasting if both are the
326 //      same or either is 1.
327 //
328 // BCast takes the shape of two tensors and computes a few vectors of
// int64 that are useful for the caller to reshape the tensors, apply
330 // the right broadcasts to them, compute the broadcasted operation,
331 // and possibly the gradients. In a nutshell, the caller is expected
332 // to compute the broadcasted operation as following:
333 //
334 //   BCast b(x.shape(), y.shape());
335 //   output = x.reshape(b.x_reshape()).broadcast(b.x_bcast())
336 //            _op_
337 //            y.reshape(b.y_reshape()).broadcast(b.y_bcast())
338 //
339 // For the gradient computation,
340 //   grad_x = sum(grad * backprop_x(x, y), grad_x_reduce_idx)
341 //            .reshape(x.shape())
342 //   grad_y = sum(grad * backprop_y(x, y), grad_y_reduce_idx)
343 //            .reshape(y.shape())
344 // backprop_x and backprop_y are functionals of the binary function "op",
345 // e.g.,
346 //   for +, backprop_x(x, y) = backprop_y(x, y) = 1;
347 //   for *, backprop_x(x, y) =  y, backprop_y(x, y) = x;
348 //   for /, backprop_x(x, y) = 1/y, backprop_y(x, y) = -x/y^2;
349 //
350 // The multiplication in the grad * backprop_x itself is also
351 // broadcasting following the same rule.
352 class BCast : public BCastList<2> {
353  public:
354   // Constructs all helper shapes, following the aforementioned rules.
355   //
356   // If "fewer_dims_optimization" is set to true (the default), the
357   // implementation tries to reduce intermediate dimensions needed to be more
358   // efficient.  This is transparent to the caller.
359   //
360   // If false, all intermediate shapes (except for grad_{x,y}_reduce_idx()) have
361   // the same number of dimensions as the larger of the two inputs.
362   typedef gtl::InlinedVector<int64, 4> Vec;
363 
364   BCast(const Vec& x, const Vec& y, const bool fewer_dims_optimization = true,
365         const bool return_flattened_batch_indices = false)
366       : BCastList<2>({x, y}, fewer_dims_optimization,
367                      return_flattened_batch_indices) {}
368 
~BCast()369   ~BCast() {}
370 
371   // If and only if IsValid(), the following fields can be used in
372   // implementing a broadcasted binary tensor operation according to
373   // the broadcasting rule.
x_reshape()374   const Vec& x_reshape() const { return reshape_[0]; }
x_bcast()375   const Vec& x_bcast() const { return bcast_[0]; }
y_reshape()376   const Vec& y_reshape() const { return reshape_[1]; }
y_bcast()377   const Vec& y_bcast() const { return bcast_[1]; }
result_shape()378   const Vec& result_shape() const { return result_; }
output_shape()379   const Vec& output_shape() const { return output_; }
grad_x_reduce_idx()380   const Vec& grad_x_reduce_idx() const { return grad_reduce_idx_[0]; }
grad_y_reduce_idx()381   const Vec& grad_y_reduce_idx() const { return grad_reduce_idx_[1]; }
382 
383   // Returns the mapping from the flattened output batch indices to x's
384   // flattened batch indices. The result is a vector of length
385   // output_batch_size(). To compute the i'th batch output, a binary matmul-like
386   // operation should use the `x_batch_indices()[i]`th batch index of `x`.
387   // Note: Returns an empty vector if broadcasting is not required. Callers
388   // should only use this when IsBroadcastingRequired() returns true.
x_batch_indices()389   const std::vector<int64>& x_batch_indices() const {
390     return batch_indices_[0];
391   }
392   // Returns the mapping from the flattened output batch indices to y's
393   // flattened batch indices. Similar to x_batch_indices().
394   // Note: Returns an empty vector if broadcasting is not required. Callers
395   // should only use this when IsBroadcastingRequired() returns true.
y_batch_indices()396   const std::vector<int64>& y_batch_indices() const {
397     return batch_indices_[1];
398   }
399 
400   template <typename IndexType, int NDIMS>
ToIndexArrayType(const BCast::Vec & vec)401   static Eigen::array<IndexType, NDIMS> ToIndexArrayType(
402       const BCast::Vec& vec) {
403     CHECK_EQ(vec.size(), NDIMS);
404     Eigen::array<IndexType, NDIMS> ret;
405     for (int i = 0; i < NDIMS; ++i) ret[i] = vec[i];
406     return ret;
407   }
408 
409   template <int NDIMS>
ToIndexArray(const BCast::Vec & vec)410   static Eigen::array<Eigen::DenseIndex, NDIMS> ToIndexArray(
411       const BCast::Vec& vec) {
412     return ToIndexArrayType<Eigen::DenseIndex, NDIMS>(vec);
413   }
414 
415   // Static helpers.
416   static Vec FromShape(const TensorShape& shape);
417   static TensorShape ToShape(const Vec& vec);
418 
419  private:
420   TF_DISALLOW_COPY_AND_ASSIGN(BCast);
421 };
422 
423 }  // end namespace tensorflow
424 
425 #endif  // TENSORFLOW_CORE_UTIL_BCAST_H_
426