/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_LINALG_EINSUM_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_LINALG_EINSUM_OP_IMPL_H_

#define EIGEN_USE_THREADS
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_split.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/linalg/einsum_op.h"
#include "tensorflow/core/kernels/matmul_op_impl.h"
#include "tensorflow/core/kernels/reduction_ops_common.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/einsum_op_util.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/reduction_ops_common_gpu.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {

using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

using ShapeVec = gtl::InlinedVector<int64, 8>;
using Labels = gtl::InlinedVector<int, 8>;
using OperandLabels = gtl::InlinedVector<Labels, 2>;
using LabelCounts = gtl::InlinedVector<int, 8>;
using OperandLabelCounts = gtl::InlinedVector<LabelCounts, 2>;
using LabelToDimSizes = gtl::InlinedVector<int64, 8>;

// Dummy axis label used to denote an ellipsis in an input or output subscript.
constexpr int kEllipsisLabel = -1;

struct EinsumHelper {
  // Each dimension is categorized into exactly one of five types based on
  // whether its corresponding label is present in the input and/or the output
  // subscripts.
  enum DimensionType {
    // Batch dimensions are those present in two inputs as well as the output.
    // They are part of the batch dimensions during Tensor contraction. Such
    // dimensions may be broadcasting dimensions (those mapping to ellipsis) or
    // explicit batch dimensions corresponding to named axis labels.
    kBroadcasting = 0,
    kBatch = 1,
    // Free dimensions are present in exactly one of the inputs, and also in
    // the output. These are the non-contracted axes in the Tensor contraction.
    kFree = 2,
    // Contract dimensions are present in two inputs, but not in the output.
    // These dimensions are contracted in the Tensor contraction.
    kContract = 3,
    // Reduce dimensions are present in exactly one input and not in the
    // output; they are summed over prior to the Tensor contraction.
    kReduce = 4,
  };
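
  // For illustration: in the (hypothetical) equation "bij,bjk->bik", label
  // 'b' appears in both inputs and the output (kBatch), 'i' and 'k' each
  // appear in exactly one input and in the output (kFree), and 'j' appears in
  // both inputs but not in the output (kContract). In "ij->i", label 'j'
  // appears in one input only and is absent from the output (kReduce).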

  // Returns the DimensionType given whether the corresponding label is present
  // in exactly one input subscript (is_unique) and whether it is absent from
  // the output subscripts (is_removed). Does not handle broadcasting
  // dimensions.
  static DimensionType GetDimensionType(bool is_removed, bool is_unique) {
    if (!is_removed && !is_unique)
      return kBatch;
    else if (!is_removed && is_unique)
      return kFree;
    else if (is_removed && !is_unique)
      return kContract;
    else  // is_removed && is_unique
      return kReduce;
  }

  // Maps the character labels to consecutive integers.
  static void MapToLabels(const string& subscript, Labels* labels,
                          absl::flat_hash_map<char, int>* label_mapping) {
    for (int i = 0; i < subscript.size(); ++i) {
      const char label_char = subscript[i];
      if (label_char == '.') {
        labels->push_back(kEllipsisLabel);
        i += 2;  // Skip next 2 characters as well.
        continue;
      }
      if (!label_mapping->contains(label_char)) {
        const int next_label = label_mapping->size();
        (*label_mapping)[label_char] = next_label;
      }
      const int mapped_label = (*label_mapping)[label_char];
      labels->push_back(mapped_label);
    }
  }
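
  // For illustration: starting from an empty label_mapping, the (hypothetical)
  // subscript "ab...c" yields labels [0, 1, kEllipsisLabel, 2]: 'a', 'b' and
  // 'c' are assigned consecutive integers, and the three '.' characters
  // collapse into a single kEllipsisLabel entry.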

  // Parses and validates the equation and the input shapes. Single character
  // labels are integerized, and we populate the input and output label
  // subscripts and the corresponding counts. Also creates the mapping from
  // (named) labels to their DimensionType.
  static Status ParseEquation(const string& equation,
                              OperandLabels* input_labels,
                              Labels* output_labels,
                              std::vector<DimensionType>* label_types,
                              OperandLabelCounts* input_label_counts,
                              LabelCounts* output_label_counts,
                              gtl::InlinedVector<bool, 2>* input_has_ellipsis,
                              bool* output_has_ellipsis) {
    gtl::InlinedVector<string, 2> input_str;
    string output_str;
    TF_RETURN_IF_ERROR(ParseEinsumEquation(equation, &input_str, &output_str));

    // Temporary map from single character labels to (consecutive) integer
    // labels.
    absl::flat_hash_map<char, int> label_mapping;
    int num_inputs = input_str.size();
    input_labels->resize(num_inputs);

    // Map from single characters to integer labels.
    for (int i = 0; i < num_inputs; ++i) {
      MapToLabels(input_str[i], &input_labels->at(i), &label_mapping);
    }
    MapToLabels(output_str, output_labels, &label_mapping);

    // Compute counts for input and output labels.
    int num_labels = label_mapping.size();
    input_label_counts->resize(num_inputs);
    input_has_ellipsis->resize(num_inputs);
    for (int i = 0; i < num_inputs; ++i) {
      input_label_counts->at(i).resize(num_labels);
      for (const int label : input_labels->at(i)) {
        if (label != kEllipsisLabel)
          input_label_counts->at(i)[label] += 1;
        else
          input_has_ellipsis->at(i) = true;
      }
    }
    output_label_counts->resize(num_labels);
    for (const int label : *output_labels) {
      if (label != kEllipsisLabel)
        output_label_counts->at(label) += 1;
      else
        *output_has_ellipsis = true;
    }

    // Map each label to a unique DimensionType.
    label_types->resize(num_labels);
    for (int label = 0; label < num_labels; ++label) {
      if (label == kEllipsisLabel) continue;
      bool removed = (*output_label_counts)[label] == 0;
      bool unique = num_inputs == 1 || (*input_label_counts)[0][label] == 0 ||
                    (*input_label_counts)[1][label] == 0;
      (*label_types)[label] = GetDimensionType(removed, unique);
    }
    return Status::OK();
  }
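
  // For illustration: parsing the (hypothetical) equation "ij,jk->ik" yields
  // input_labels = {{0, 1}, {1, 2}}, output_labels = {0, 2}, and label_types
  // = {kFree, kContract, kFree}: label 1 ('j') appears in both inputs but is
  // removed from the output, while labels 0 ('i') and 2 ('k') each appear in
  // exactly one input and survive into the output.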

  // Insert new (unnamed) broadcasting labels at the location of ellipsis.
  static void InsertBroadcastLabels(int num_bcast_dims, int num_named_labels,
                                    int ellipsis_axis, Labels* labels,
                                    LabelCounts* label_counts) {
    labels->erase(labels->begin() + ellipsis_axis);
    labels->insert(labels->begin() + ellipsis_axis, num_bcast_dims, 0);
    std::iota(labels->begin() + ellipsis_axis,
              labels->begin() + ellipsis_axis + num_bcast_dims,
              num_named_labels);
    // Increment label counts. Since these are new labels, the count is set
    // to 1.
    label_counts->resize(num_named_labels + num_bcast_dims, 1);
  }
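
  // For illustration: with labels = {0, kEllipsisLabel, 1}, num_bcast_dims =
  // 2 and num_named_labels = 2, the ellipsis entry is replaced by two fresh
  // labels numbered from num_named_labels onwards, giving {0, 2, 3, 1}; the
  // counts for the new labels are initialized to 1.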

  // Record and validate the label to dimension mapping. Must be a named
  // (non-broadcasting) label as broadcasting labels don't have a fixed
  // dimension.
  static Status RecordLabelToDimension(const int label, const int axis,
                                       const Tensor& input,
                                       LabelToDimSizes* label_to_dim_sizes) {
    const int64 input_dim = input.dim_size(axis);
    // We know that label_to_dim_sizes has the size to accommodate named labels.
    if (label_to_dim_sizes->at(label) != 0 &&
        label_to_dim_sizes->at(label) != input_dim) {
      return errors::InvalidArgument(
          "Expected dimension ", label_to_dim_sizes->at(label), " at axis ",
          axis, " of the input shaped ", input.shape().DebugString(),
          " but got dimension ", input_dim);
    }
    (*label_to_dim_sizes)[label] = input_dim;
    return Status::OK();
  }

  // Validate input dimensions and populate unnamed labels and their label
  // counts.
  static Status ProcessDimensions(
      const OpInputList& inputs,
      const gtl::InlinedVector<bool, 2>& input_has_ellipsis,
      const bool output_has_ellipsis, OperandLabels* input_labels,
      Labels* output_labels, std::vector<DimensionType>* label_types,
      OperandLabelCounts* input_label_counts, LabelCounts* output_label_counts,
      LabelToDimSizes* label_to_dim_sizes) {
    if (inputs.size() != input_labels->size()) {
      return errors::InvalidArgument("Expected ", input_labels->size(),
                                     " inputs but got: ", inputs.size());
    }
    const int num_inputs = inputs.size();

    // We infer the number of broadcasting dimensions by taking the maximum
    // rank among the broadcasting subshapes of the inputs.
    int max_bcast_dims = 0;
    const int num_named_labels = label_types->size();
    label_to_dim_sizes->resize(num_named_labels);
    for (int i = 0; i < num_inputs; ++i) {
      Labels* labels = &(*input_labels)[i];

      if (!input_has_ellipsis[i]) {
        if (inputs[i].dims() != labels->size()) {
          return errors::InvalidArgument("Expected input ", i, " to have rank ",
                                         labels->size(),
                                         " but got: ", inputs[i].dims());
        }
        for (int label_idx = 0; label_idx < labels->size(); ++label_idx) {
          const int label = (*labels)[label_idx];
          TF_RETURN_IF_ERROR(RecordLabelToDimension(label, label_idx, inputs[i],
                                                    label_to_dim_sizes));
        }
        continue;
      }

      // Input has an ellipsis.
      if (inputs[i].dims() + 1 < labels->size()) {
        return errors::InvalidArgument(
            "Expected input ", i, " to have rank at least ", labels->size() - 1,
            " but got: ", inputs[i].dims());
      }
      int ellipsis_axis = -1;
      const int num_bcast_dims = inputs[i].dims() - labels->size() + 1;
      for (int label_idx = 0; label_idx < labels->size(); ++label_idx) {
        const int label = (*labels)[label_idx];
        if (label == kEllipsisLabel) {
          ellipsis_axis = label_idx;
          continue;
        }
        // Current label is not an ellipsis.
        const int axis =
            label_idx + (ellipsis_axis == -1 ? 0 : num_bcast_dims - 1);
        TF_RETURN_IF_ERROR(
            RecordLabelToDimension(label, axis, inputs[i], label_to_dim_sizes));
      }
      // Found an ellipsis. Replace 'kEllipsisLabel' with broadcasting
      // dimensions.
      if (ellipsis_axis != -1) {
        InsertBroadcastLabels(num_bcast_dims, num_named_labels, ellipsis_axis,
                              labels, &input_label_counts->at(i));
        max_bcast_dims = std::max(max_bcast_dims, num_bcast_dims);
      }
    }
    if (!absl::c_linear_search(input_has_ellipsis, true) &&
        !output_has_ellipsis) {
      return Status::OK();
    }
    // Insert broadcasting dimensions in the output labels.
    auto it =
        std::find(output_labels->begin(), output_labels->end(), kEllipsisLabel);
    if (it != output_labels->end()) {
      const int ellipsis_axis = it - output_labels->begin();
      InsertBroadcastLabels(max_bcast_dims, num_named_labels, ellipsis_axis,
                            output_labels, output_label_counts);
    } else if (max_bcast_dims > 0) {
      return errors::InvalidArgument(
          "Output contains ", max_bcast_dims,
          " broadcasting dimension(s) but no ellipsis "
          "(...) was found in the output subscripts.");
    }
    // Populate DimensionType for the new broadcasting labels.
    label_types->resize(num_named_labels + max_bcast_dims, kBroadcasting);
    return Status::OK();
  }

  // Permutes the labels according to the given permutation.
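  // For illustration: permutation {2, 0, 1} applied to labels {5, 6, 7}
  // yields {7, 5, 6}, i.e. new position i takes the old label at
  // permutation[i].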
  static void PermuteLabels(const std::vector<int>& permutation,
                            Labels* labels) {
    Labels permuted_labels(labels->size());
    for (int i = 0; i < labels->size(); ++i) {
      permuted_labels[i] = (*labels)[permutation[i]];
    }
    labels->swap(permuted_labels);
  }

  // Returns a reshaped input Tensor. The underlying buffer is not copied.
  static Status CopyFrom(const Tensor& input, const TensorShape& shape,
                         Tensor* output) {
    if (output->CopyFrom(input, shape)) return Status::OK();
    return errors::Internal(
        "Encountered error while reshaping a Tensor of shape ",
        input.shape().DebugString(), " to shape ", shape.DebugString());
  }

  // Returns whether transposing is needed. Transposing is a no-op (and this
  // returns false) when the input has rank < 2 or the permutation is the
  // identity permutation.
  static bool ShouldTranspose(const TensorShape& input_shape,
                              const std::vector<int>& permutation) {
    if (input_shape.dims() < 2) return false;
    for (int i = 0; i < permutation.size(); ++i) {
      if (permutation[i] != i) return true;
    }
    return false;
  }

  // Transpose the input given a permutation. Returns a reference to the input
  // if transposing is not necessary.
  template <typename Device, typename T>
  static Status TransposeOperand(OpKernelContext* ctx, const Tensor& input,
                                 const std::vector<int>& permutation,
                                 Tensor* output) {
    if (!ShouldTranspose(input.shape(), permutation)) {
      return CopyFrom(input, input.shape(), output);
    }
    TensorShape transposed_shape;
    for (int i = 0; i < input.dims(); ++i) {
      transposed_shape.AddDim(input.dim_size(permutation[i]));
    }
    // For empty Tensors, just change the shape. E.g. we may need to transpose
    // from shape [1, 0, 5] to [5, 1, 0].
    if (input.NumElements() == 0) {
      return CopyFrom(input, transposed_shape, output);
    }
    TF_RETURN_IF_ERROR(
        ctx->allocate_temp(DataTypeToEnum<T>::value, transposed_shape, output));
    const Device& device = ctx->eigen_device<Device>();
    TF_RETURN_IF_ERROR(DoTranspose(device, input, permutation, output));
    return Status::OK();
  }

  // If there are repeated labels in either the input or output, then this
  // strides the input (e.g. iii->i) or inflates it (e.g. i->iii), respectively.
  template <typename Device, typename T>
  static Status StrideOrInflate(OpKernelContext* ctx, const Tensor& input,
                                const Labels& labels,
                                const LabelCounts& label_counts,
                                const bool should_inflate, Tensor* output) {
    // Return early if there are no repeated indices.
    if (absl::c_all_of(label_counts, [](int c) { return c <= 1; })) {
      return CopyFrom(input, input.shape(), output);
    }
    // We reshape so that each repeated label is compressed to one dimension.
    // E.g. for iiij -> ij, the shape [3, 3, 3, 5] would be compressed to
    // [27, 5]. Striding appropriately (in this case with strides 13 (=1+3+9)
    // and 1) recovers the generalized diagonal of shape [3, 5].
    ShapeVec reshape;
    ShapeVec strides;
    // Strided and inflated shapes correspond to the input and output shapes,
    // respectively, if should_inflate is true (and vice-versa if
    // should_inflate is false). E.g. they are [3, 5] and [3, 3, 3, 5] in the
    // above example.
    ShapeVec strided_shape;
    ShapeVec inflated_shape;
    for (int label : labels) {
      const int count = label_counts[label];
      const int current_axis =
          should_inflate ? strided_shape.size() : inflated_shape.size();
      const int64 dim = input.dim_size(current_axis);
      strided_shape.push_back(dim);
      inflated_shape.insert(inflated_shape.end(), count, dim);
      const int64 reshape_dim = MathUtil::IPow(dim, count);
      reshape.push_back(reshape_dim);
      // While taking the d-diagonal in a rank k Tensor, we take d
      // equally-spaced elements including the first and last element. Then,
      // (d - 1) * stride = d^k - 1, or, stride = (d^k - 1)/(d - 1).
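      // For illustration, with d = 3 and k = 3 (a label repeated three times
      // over dimensions of size 3): stride = (3^3 - 1) / (3 - 1) = 13, which
      // matches 1 + 3 + 9, the flat offset between consecutive diagonal
      // elements.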
      const int64 stride =
          (dim > 1 && count > 1) ? (reshape_dim - 1) / (dim - 1) : 1;
      strides.push_back(stride);
    }

    TensorShape output_shape =
        TensorShape(should_inflate ? inflated_shape : strided_shape);
    TF_RETURN_IF_ERROR(
        ctx->allocate_temp(DataTypeToEnum<T>::value, output_shape, output));
    const Device& device = ctx->eigen_device<Device>();
    switch (reshape.size()) {
#define NDIMS_CASE(N)                                                 \
  case N: {                                                           \
    if (should_inflate) {                                             \
      auto output_map = output->shaped<T, N>(reshape);                \
      auto input_map = input.shaped<T, N>(strided_shape);             \
      functor::InflateFunctor<Device, T, N>()(                        \
          device, input_map, TensorShape(strides).AsEigenDSizes<N>(), \
          output_map);                                                \
    } else {                                                          \
      auto input_map = input.shaped<T, N>(reshape);                   \
      auto output_map = output->shaped<T, N>(strided_shape);          \
      functor::StrideFunctor<Device, T, N>()(                         \
          device, input_map, TensorShape(strides).AsEigenDSizes<N>(), \
          output_map);                                                \
    }                                                                 \
  } break;
      NDIMS_CASE(1);
      NDIMS_CASE(2);
      NDIMS_CASE(3);
      NDIMS_CASE(4);
      NDIMS_CASE(5);
      NDIMS_CASE(6);
      default:
        return errors::Unimplemented(
            "Unsupported rank: ", reshape.size(),
            " while handling repeated indices. Up to rank 6 is supported.");
#undef NDIMS_CASE
    }
    return Status::OK();
  }

  // Returns true if the input dimensions are already sorted in the order
  // [batch, contract, free, reduce]. Used to implement an optimization that
  // avoids an extra transpose by instead setting the adj_x (or adj_y) flag
  // in BatchMatMul.
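  // For illustration: since the remapping below ranks the types as
  // [broadcasting, batch, contract, free, reduce], labels whose types appear
  // in the order (batch, contract, free) return true, while the order
  // (batch, free, contract) returns false and requires a real transpose.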
  static bool ShouldSwapFreeAndContract(
      const Labels& labels, const std::vector<DimensionType>& label_types) {
    // Check that ordering is according to dimension type, with the role of
    // free and contract dimensions swapped.
    gtl::InlinedVector<int, 5> remap = {0, 1, 3, 2, 4};
    for (int i = 0; i + 1 < labels.size(); ++i) {
      const int dimtype_a = remap[label_types[labels[i]]];
      const int dimtype_b = remap[label_types[labels[i + 1]]];
      if (dimtype_a > dimtype_b ||
          (dimtype_a == dimtype_b && labels[i] > labels[i + 1])) {
        return false;
      }
    }
    return true;
  }
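
  // For illustration: for the first operand of the (hypothetical) equation
  // "abc,cd->ad" with shape [2, 3, 4], 'a' is kFree, 'b' is kReduce, and 'c'
  // is kContract. ReduceOperand transposes the operand to [2, 4, 3] (free,
  // contract, reduce order), sums over the 'b' axis, and produces a Tensor
  // of shape [2, 4], i.e. [free size, contract size], ready for contraction.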
  template <typename Device, typename T>
  static Status ReduceOperand(OpKernelContext* ctx, const Tensor& input,
                              const std::vector<DimensionType>& label_types,
                              const LabelCounts& label_counts, Labels* labels,
                              Labels* free_labels, bool* swap_free_and_contract,
                              Tensor* output) {
    // Find the permutation to transpose the input dimensions in the order of
    // DimensionType; i.e. batch, free, contract and reduce dimensions. This
    // makes it more convenient to invoke Reduce/Contract operations.
    std::vector<int> permutation(input.dims());
    absl::c_iota(permutation, 0);
    Tensor input_transposed;
    // Check if we can avoid the transpose. We need to flip the adj_x (or
    // adj_y) flag during BatchMatMul. This is an extra optimization not
    // necessary for correctness.
    if (ShouldSwapFreeAndContract(*labels, label_types)) {
      *swap_free_and_contract = true;
    } else {
      absl::c_sort(permutation, [&](int i, int j) {
        int label_i = (*labels)[i];
        int label_j = (*labels)[j];
        return std::tie(label_types[label_i], label_i) <
               std::tie(label_types[label_j], label_j);
      });
    }
    // Transpose the input so that DimensionTypes are in order.
    TF_RETURN_IF_ERROR(TransposeOperand<Device, T>(ctx, input, permutation,
                                                   &input_transposed));
    PermuteLabels(permutation, labels);

    // Take the generalized diagonal for dimensions with repeated axis labels.
    Tensor input_deduped;
    labels->erase(std::unique(labels->begin(), labels->end()), labels->end());
    TF_RETURN_IF_ERROR(
        StrideOrInflate<Device, T>(ctx, input_transposed, *labels, label_counts,
                                   false /* should_inflate */, &input_deduped));

    // Reshape denotes the rank-5 shape [broadcast, batch, free, contract,
    // reduce] where we've compacted the dimensions of each DimensionType.
    gtl::InlinedVector<int64, 5> reshape(5, 1);
    // The output shape is [batch shape] + [free size, contract size].
    // That is, the batch shape is preserved (for broadcasting while
    // contracting) while the free dims and contract dims are compressed to one
    // dimension each.
    TensorShape output_shape;
    for (int label_idx = 0; label_idx < labels->size(); ++label_idx) {
      const int label = labels->at(label_idx);
      int64 dim = input_deduped.dim_size(label_idx);
      if (label_types[label] == kBroadcasting || label_types[label] == kBatch) {
        output_shape.AddDim(dim);
      } else if (label_types[label] == kFree) {
        free_labels->push_back(label);
      }
      reshape[label_types[label]] *= dim;
    }
    if (*swap_free_and_contract) std::swap(reshape[kFree], reshape[kContract]);
    output_shape.AddDim(reshape[kFree]);
    output_shape.AddDim(reshape[kContract]);

    if (reshape[kReduce] == 1) {  // No need to actually reduce.
      return CopyFrom(input_deduped, output_shape, output);
    }
    TF_RETURN_IF_ERROR(
        ctx->allocate_temp(DataTypeToEnum<T>::value, output_shape, output));
    using Reducer = Eigen::internal::SumReducer<T>;
    using Index = typename TTypes<T>::Tensor::Index;
    // Reduce along the last axis (i.e. axis 1) of the rank-2 Tensor.
    const int64 output_size = reshape[kBroadcasting] * reshape[kBatch] *
                              reshape[kFree] * reshape[kContract];
    functor::ReduceFunctor<Device, Reducer>::Reduce(
        ctx, output->shaped<T, 1>({output_size}),
        const_cast<const Tensor&>(input_deduped)
            .shaped<T, 2>({output_size, reshape[kReduce]}),
        Eigen::array<Index, 1>({1}), Reducer());
    return Status::OK();
  }

  // Reshapes a Tensor of shape [b0,b1...bk,N,M] to [prod(b0,b1...bk),N,M].
  static Status ReshapeToRank3(const Tensor& input, int batch_size,
                               Tensor* output) {
    const int rank = input.dims();
    TensorShape output_shape = {batch_size, input.dim_size(rank - 2),
                                input.dim_size(rank - 1)};
    return CopyFrom(input, output_shape, output);
  }

  // Contracts the inputs along the last axis (or the second last if the
  // corresponding value of swap_free_and_contract is true). The batch
  // dimensions are broadcast to the output shape.
  // TODO(anudhyan): BatchMatMul might devolve into a component-wise
  // multiplication when the matrix shape is [1,1]; in this case the
  // BatchMatMul functor would be very inefficient. The functor should detect
  // this case and perform componentwise multiplication instead.
  template <typename Device, typename T>
  static Status ContractOperands(OpKernelContext* ctx,
                                 absl::Span<const Tensor> inputs,
                                 absl::Span<const bool> swap_free_and_contract,
                                 Tensor* output) {
    if (inputs.size() == 1)
      return CopyFrom(inputs[0], inputs[0].shape(), output);
    MatMulBCast bcast(inputs[0].shape().dim_sizes(),
                      inputs[1].shape().dim_sizes());
    if (!bcast.IsValid()) {
      return errors::InvalidArgument(
          "Invalid broadcasting dimensions: ", inputs[0].shape().DebugString(),
          " vs. ", inputs[1].shape().DebugString());
    }
    Tensor lhs;
    TF_RETURN_IF_ERROR(ReshapeToRank3(inputs[0], bcast.x_batch_size(), &lhs));
    Tensor rhs;
    TF_RETURN_IF_ERROR(ReshapeToRank3(inputs[1], bcast.y_batch_size(), &rhs));
    TensorShape output_shape = bcast.output_batch_shape();
    for (int i = 0; i < inputs.size(); ++i) {
      const int64 free_axis =
          inputs[i].dims() - (swap_free_and_contract[i] ? 1 : 2);
      output_shape.AddDim(inputs[i].dim_size(free_axis));
    }
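    // For illustration: each reduced operand has shape [batch..., F, C], or
    // [batch..., C, F] if its swap_free_and_contract flag is set. BatchMatMul
    // contracts the last axis of the lhs with the second-to-last axis of the
    // rhs, so the lhs needs the transpose flag exactly when it was swapped
    // and the rhs exactly when it was not.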
    bool trans_x = swap_free_and_contract[0];
    bool trans_y = !swap_free_and_contract[1];
    TF_RETURN_IF_ERROR(
        ctx->allocate_temp(DataTypeToEnum<T>::value, output_shape, output));
    if (lhs.NumElements() == 0 || rhs.NumElements() == 0) {
      functor::SetZeroFunctor<Device, T> set_zero;
      set_zero(ctx->eigen_device<Device>(), output->flat<T>());
      return Status::OK();
    }
    Tensor output_reshaped;
    TF_RETURN_IF_ERROR(
        ReshapeToRank3(*output, bcast.output_batch_size(), &output_reshaped));
    LaunchBatchMatMul<Device, T>::Launch(ctx, lhs, rhs, /*adj_x=*/false,
                                         /*adj_y=*/false, trans_x, trans_y,
                                         bcast, &output_reshaped);
    return Status::OK();
  }
};

template <typename Device, typename T>
class EinsumOp : public OpKernel {
 public:
  explicit EinsumOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("equation", &equation_));
    OP_REQUIRES_OK(
        c, EinsumHelper::ParseEquation(
               equation_, &input_labels_, &output_labels_, &label_types_,
               &input_label_counts_, &output_label_counts_,
               &input_has_ellipsis_, &output_has_ellipsis_));
  }

  void Compute(OpKernelContext* ctx) override {
    OpInputList inputs;
    OP_REQUIRES_OK(ctx, ctx->input_list("inputs", &inputs));

    OperandLabels input_labels(input_labels_);
    Labels output_labels(output_labels_);
    std::vector<EinsumHelper::DimensionType> label_types(label_types_);
    OperandLabelCounts input_label_counts(input_label_counts_);
    LabelCounts output_label_counts(output_label_counts_);
    LabelToDimSizes label_to_dim_sizes;

    OP_REQUIRES_OK(ctx, EinsumHelper::ProcessDimensions(
                            inputs, input_has_ellipsis_, output_has_ellipsis_,
                            &input_labels, &output_labels, &label_types,
                            &input_label_counts, &output_label_counts,
                            &label_to_dim_sizes));

    // The reduction phase (a) sums across reduction dimensions, (b) takes
    // generalized diagonals, and (c) reshapes it into shape
    //   [(broadcasting) batch shape] + [F,C]
    // where F and C denote the total (compacted) size of free and contract
    // dimensions, respectively.
    const int num_inputs = inputs.size();
    OperandLabels free_labels(num_inputs);
    gtl::InlinedVector<Tensor, 2> inputs_reduced(num_inputs);
    gtl::InlinedVector<bool, 2> swap_free_and_contract(num_inputs);
    for (int i = 0; i < num_inputs; ++i) {
      OP_REQUIRES_OK(ctx,
                     EinsumHelper::ReduceOperand<Device, T>(
                         ctx, inputs[i], label_types, input_label_counts[i],
                         &input_labels[i], &free_labels[i],
                         &swap_free_and_contract[i], &inputs_reduced[i]));
    }

    // After reduction, the inputs should be reshaped to Tensors suitable for
    // contraction. If num_inputs is 1, the reduced input is simply forwarded to
    // the output.
    Tensor contraction_output_reshaped;
    OP_REQUIRES_OK(ctx, EinsumHelper::ContractOperands<Device, T>(
                            ctx, inputs_reduced, swap_free_and_contract,
                            &contraction_output_reshaped));

    // Copy the batch labels from the contraction output. Recover the batch
    // shape, which may have been broadcasted.
    TensorShape result_shape = contraction_output_reshaped.shape();
    result_shape.RemoveLastDims(2);

    int num_labels = label_types.size();
    Labels result_labels;
    // All batch dimensions should be present in the contracted result. First
    // the broadcasting dimensions, then the named batch dimensions.
    for (int label = 0; label < num_labels; ++label) {
      if (label_types[label] == EinsumHelper::kBroadcasting)
        result_labels.push_back(label);
    }
    for (int label = 0; label < num_labels; ++label) {
      if (label_types[label] == EinsumHelper::kBatch)
        result_labels.push_back(label);
    }
    for (int i = 0; i < num_inputs; ++i) {
      for (int label : free_labels[i]) {
        result_labels.push_back(label);
        result_shape.AddDim(label_to_dim_sizes[label]);
      }
    }

    // Reshape the contraction (or reduction) result to its expanded shape:
    // [(broadcasted) batch shape] + [free shape 0] + [free shape 1].
    Tensor contraction_output;
    OP_REQUIRES_OK(
        ctx, EinsumHelper::CopyFrom(contraction_output_reshaped, result_shape,
                                    &contraction_output));

    // Inflate the output if necessary. (E.g. for the equation 'i->iii' which
    // may arise while computing the gradient of a regular Einsum).
    // TODO(anudhyan): It's possible that Eigen's contract and inflate can be
    // chained here to avoid materializing an intermediate.
    Tensor output_inflated;
    OP_REQUIRES_OK(
        ctx, EinsumHelper::StrideOrInflate<Device, T>(
                 ctx, contraction_output, result_labels, output_label_counts,
                 true /* should_inflate */, &output_inflated));
    if (output_inflated.dims() > contraction_output.dims()) {
      // We inflated the output. Modify result labels accordingly.
      Labels inflated_labels;
      for (int label : result_labels) {
        inflated_labels.insert(inflated_labels.end(),
                               output_label_counts[label], label);
      }
      result_labels.swap(inflated_labels);
    }
    // Find the permutation to map the result labels to the output labels. Note
    // that both the result and the final output may have repeated labels, in
    // which case the permutation preserves the left-to-right ordering.
    // E.g. if result labels are [0, 0, 1] and output is [0, 1, 0] then the
    // permutation should be [0, 2, 1]. We also use the fact that repeated
    // labels in the result are adjacent to each other.
    std::vector<int> output_permutation(output_labels.size());
    std::vector<int> label_to_position(num_labels, -1);
    for (int i = 0; i < result_labels.size(); ++i) {
      // Remember the position of only the leftmost result label.
      if (label_to_position[result_labels[i]] == -1) {
        label_to_position[result_labels[i]] = i;
      }
    }
    for (int i = 0; i < output_labels.size(); ++i) {
      output_permutation[i] = label_to_position[output_labels[i]];
      // We have found the leftmost occurrence. The next one would be adjacent.
      label_to_position[output_labels[i]] += 1;
    }
    Tensor output;
    OP_REQUIRES_OK(ctx, EinsumHelper::TransposeOperand<Device, T>(
                            ctx, output_inflated, output_permutation, &output));
    ctx->set_output(0, output);
  }

  string TraceString(const OpKernelContext& ctx, bool verbose) const override {
    string op = profiler::TraceMeOp(name_view(), type_string_view());
    string equation = strings::StrCat("(", equation_, ")");
    if (verbose) {
      string shape = ShapeTraceString(ctx);
      if (!shape.empty()) {
        return profiler::TraceMeEncode(
            std::move(op), {{"equation", equation}, {"shape", shape}});
      }
    }
    return profiler::TraceMeEncode(std::move(op), {{"equation", equation}});
  }

 private:
  string equation_;
  OperandLabels input_labels_;
  Labels output_labels_;
  std::vector<EinsumHelper::DimensionType> label_types_;
  OperandLabelCounts input_label_counts_;
  LabelCounts output_label_counts_;
  gtl::InlinedVector<bool, 2> input_has_ellipsis_;
  bool output_has_ellipsis_ = false;
};
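
// A minimal registration sketch. The actual registrations live in the
// corresponding einsum_op_*.cc files, so treat the macro below as
// illustrative rather than authoritative:
//
//   #define REGISTER_EINSUM(D, TYPE)                                   \
//     REGISTER_KERNEL_BUILDER(                                         \
//         Name("Einsum").Device(DEVICE_##D).TypeConstraint<TYPE>("T"), \
//         EinsumOp<D##Device, TYPE>);
//
//   REGISTER_EINSUM(CPU, float);  // hypothetical instantiation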

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T, N)                                      \
  template <>                                                       \
  void StrideFunctor<GPUDevice, T, N>::operator()(                  \
      const GPUDevice& d, typename TTypes<T, N>::ConstTensor input, \
      const Eigen::DSizes<Eigen::DenseIndex, N>& strides,           \
      typename TTypes<T, N>::Tensor output);                        \
  extern template struct StrideFunctor<GPUDevice, T, N>;            \
  template <>                                                       \
  void InflateFunctor<GPUDevice, T, N>::operator()(                 \
      const GPUDevice& d, typename TTypes<T, N>::ConstTensor input, \
      const Eigen::DSizes<Eigen::DenseIndex, N>& strides,           \
      typename TTypes<T, N>::Tensor output);                        \
  extern template struct InflateFunctor<GPUDevice, T, N>;

#define DECLARE_GPU_SPECS(T) \
  DECLARE_GPU_SPEC(T, 1);    \
  DECLARE_GPU_SPEC(T, 2);    \
  DECLARE_GPU_SPEC(T, 3);    \
  DECLARE_GPU_SPEC(T, 4);    \
  DECLARE_GPU_SPEC(T, 5);    \
  DECLARE_GPU_SPEC(T, 6);

DECLARE_GPU_SPECS(Eigen::half);
DECLARE_GPU_SPECS(double);
DECLARE_GPU_SPECS(float);
// TODO(rocm): Enable once complex types are supported.
#if GOOGLE_CUDA
DECLARE_GPU_SPECS(complex64);
DECLARE_GPU_SPECS(complex128);
#endif
#undef DECLARE_GPU_SPEC
#undef DECLARE_GPU_SPECS
}  // namespace functor
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_LINALG_EINSUM_OP_IMPL_H_