/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
namespace {

// Returns the length of the diagonal at offset `diag_index` of a matrix with
// `num_rows` rows and `num_cols` columns.
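// Here and below, diag_index > 0 refers to a superdiagonal, diag_index < 0 to
// a subdiagonal, and diag_index == 0 to the main diagonal. For example, in a
// 3x4 matrix, diag_index = 1 has length min(3, 3) = 3 and diag_index = -2 has
// length min(1, 4) = 1.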
static inline int ComputeDiagLen(int diag_index, int num_rows, int num_cols) {
  return std::min(num_rows + std::min(0, diag_index),
                  num_cols - std::max(0, diag_index));
}

// Checks if a diagonal is to be aligned left or right.
static inline bool IsLeftAligned(int diag_index, bool left_align_superdiagonal,
                                 bool left_align_subdiagonal) {
  return (diag_index >= 0 && left_align_superdiagonal) ||
         (diag_index <= 0 && left_align_subdiagonal);
}

// Reads the diagonal packing alignment from the "align" attribute.
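// The attribute value is one of "LEFT_RIGHT", "RIGHT_LEFT", "LEFT_LEFT", or
// "RIGHT_RIGHT"; the first half refers to superdiagonals and the second half
// to subdiagonals, e.g., "RIGHT_LEFT" right-aligns superdiagonals and
// left-aligns subdiagonals.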
void ReadAlignment(OpKernelConstruction* context,
                   bool* left_align_superdiagonal,
                   bool* left_align_subdiagonal) {
  string align;
  OP_REQUIRES_OK(context, context->GetAttr("align", &align));

  *left_align_superdiagonal = align == "LEFT_LEFT" || align == "LEFT_RIGHT";
  *left_align_subdiagonal = align == "LEFT_LEFT" || align == "RIGHT_LEFT";
}

// Reads or infers lower_diag_index and upper_diag_index from the kernel's
// input parameter "k". Also validates their values.
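// "k" is either a scalar (a single diagonal) or a vector of one or two
// elements, [k] or [k_lower, k_upper].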
std::pair<int64, int64> ProcessDiagIndex(XlaOpKernelContext* context) {
  int64 lower_diag_index = 0;
  int64 upper_diag_index = 0;
  TensorShape diag_index_shape = context->InputShape("k");

  // The OP_REQUIRES* macros are wrapped in a lambda because they can
  // "return;" early (without a value), which conflicts with
  // ProcessDiagIndex's return type.
  auto validate_diag_indices = [&]() {
    if (diag_index_shape.dims() == 0) {
      OP_REQUIRES_OK(context,
                     context->ConstantInputAsIntScalar("k", &lower_diag_index));
      upper_diag_index = lower_diag_index;
    } else {
      std::vector<int64> diag_index;
      OP_REQUIRES_OK(context,
                     context->ConstantInputAsIntVector("k", &diag_index));
      OP_REQUIRES(
          context, !diag_index.empty() && diag_index.size() <= 2,
          errors::InvalidArgument(
              "diag_index must have only one or two elements, received ",
              diag_index.size(), " elements."));
      lower_diag_index = diag_index[0];
      upper_diag_index =
          (diag_index.size() > 1) ? diag_index[1] : lower_diag_index;
    }
    OP_REQUIRES(
        context, lower_diag_index <= upper_diag_index,
        errors::InvalidArgument(
            "lower_diag_index must not be larger than upper_diag_index: ",
            lower_diag_index, " > ", upper_diag_index));
  };
  validate_diag_indices();
  return {lower_diag_index, upper_diag_index};
}

// Makes sure lower_diag_index and upper_diag_index are consistent with the
// input matrix size.
void ValidateDiagIndexWithOutputMatrixSize(XlaOpKernelContext* context,
                                           const int64 lower_diag_index,
                                           const int64 upper_diag_index,
                                           const int64 num_rows,
                                           const int64 num_cols) {
  // The `lower_diag_index == 0` case is allowed so that empty matrices (with
  // zero rows or columns) pass this check.
  OP_REQUIRES(context,
              (-num_rows < lower_diag_index && lower_diag_index < num_cols) ||
                  lower_diag_index == 0,
              errors::InvalidArgument(
                  "lower_diag_index is out of bounds: ", lower_diag_index,
                  ". It must be between ", -num_rows, " and ", num_cols));
  OP_REQUIRES(context,
              (-num_rows < upper_diag_index && upper_diag_index < num_cols) ||
                  upper_diag_index == 0,
              errors::InvalidArgument(
                  "upper_diag_index is out of bounds: ", upper_diag_index,
                  ". It must be between ", -num_rows, " and ", num_cols));
  OP_REQUIRES(context, lower_diag_index <= upper_diag_index,
              errors::InvalidArgument(
                  "lower_diag_index must not be larger than upper_diag_index: ",
                  lower_diag_index, " > ", upper_diag_index));
}

// Writes the diagonals from `diag` into the matrix (or batch of matrices)
// `input` and returns the result. Used by both MatrixDiag and MatrixSetDiag.
xla::XlaOp SetMatrixDiag(const xla::XlaOp input, const xla::XlaOp diag,
                         const TensorShape& input_shape, const int64 diag_rank,
                         const int64 num_diags, const int64 lower_diag_index,
                         const int64 upper_diag_index, const int64 max_diag_len,
                         const int64 num_rows, const int64 num_cols,
                         const bool left_align_superdiagonal,
                         const bool left_align_subdiagonal) {
  // Creates a padding config.
  const int input_rank = input_shape.dims();
  xla::PaddingConfig padding_config;
  padding_config = xla::MakeNoPaddingConfig(input_rank - 1);

  // Processes one diagonal at a time:
  // 1) Extracts a single diagonal (diag_slice).
  // 2) Broadcasts its contents to fill the whole matrix (diag_broadcast).
  // 3) Masks diag_broadcast to get the right diagonal shape.
  //
  // XLA can fuse multiple Broadcasts and Selects so this shouldn't be slow.
  //
  // For example,
  //   diag = [[0, 2, 3],  k = (0, 2), num_cols = 4, and align="RIGHT_LEFT".
  //           [4, 5, 6],
  //           [7, 8, 9]]
  // The expected output is [[7, 4, 2, 0],
  //                         [0, 8, 5, 3],
  //                         [0, 0, 9, 6]].
  // The k = 2 diagonal (the first row of diag) is created by:
  // 1) Extracting diag_slice = [0, 2, 3], which is right-aligned.
  // 2) Padding the vector (in the same direction) to be as long as num_cols,
  //      diag_slice = [0, 0, 2, 3],
  //    then broadcasting diag_slice column-wise to a full matrix,
  //      diag_broadcast = [[0, 0, 2, 3],
  //                        [0, 0, 2, 3],
  //                        [0, 0, 2, 3]].
  //    The padding value can be anything because it will not appear in the
  //    results after masking. Here, we use zero.
  // 3) Masking diag_broadcast with a mask of the shape of the k = 2 diagonal.
  //      mask = [[0, 0, 1, 0],  -->  output = [[x, x, 2, x],
  //              [0, 0, 0, 1],                 [x, x, x, 3],
  //              [0, 0, 0, 0]]                 [x, x, x, x]],
  //    where x denotes the existing input contents.
  std::vector<int64> broadcast_dimensions(input_rank - 1);
  absl::c_iota(broadcast_dimensions, 0);
  auto output = input;
  for (int64 diag_index = lower_diag_index; diag_index <= upper_diag_index;
       ++diag_index) {
    // Extracts a single diagonal.
    auto diag_slice = diag;
    if (num_diags > 1) {
      // The result of SliceInDim has dims: [<batch_dim>, 1, max_diag_len].
      // We call Collapse to make the dims: [<batch_dim>, max_diag_len].
      const int64 mapped_diag_index = upper_diag_index - diag_index;
      diag_slice = xla::Collapse(
          xla::SliceInDim(diag, mapped_diag_index, mapped_diag_index + 1, 1,
                          diag_rank - 2),
          {diag_rank - 2, diag_rank - 1});
    }

    // Pad if necessary.
    // - If the diagonal has the longest length, i.e., min(num_rows, num_cols),
    //   no padding is necessary. It is broadcast column-wise if it is a sub-
    //   diagonal, row-wise if superdiagonal.
    // - Otherwise, pad and keep the old alignment (shorter diagonals in the
    //   input come pre-padded). max_len in the table refers to max_diag_len.
    //   -------------------------------------------------------------------
    //   | Diag  | Align | Broadcast   |   padding_low   |   padding_high  |
    //   -------------------------------------------------------------------
    //   | Super | Left  | Row-wise    |        0        | #rows - max_len |
    //   |       | Right | Column-wise | #cols - max_len |        0        |
    //   -------------------------------------------------------------------
    //   | Sub   | Left  | Column-wise |        0        | #cols - max_len |
    //   |       | Right | Row-wise    | #rows - max_len |        0        |
    //   -------------------------------------------------------------------
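    // For example, with num_rows = 3, num_cols = 4, max_len = 3, and
    // align="RIGHT_LEFT", the k = 2 diagonal is a right-aligned superdiagonal:
    // it is broadcast column-wise with padding_low = #cols - max_len = 1 and
    // padding_high = 0, turning [0, 2, 3] into [0, 0, 2, 3] as in step 2 of
    // the example above.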
    if (num_cols - num_rows <= diag_index && diag_index <= 0) {
      broadcast_dimensions.back() = input_rank - 1;  // Column-wise.
    } else if (0 <= diag_index && diag_index <= num_cols - num_rows) {
      broadcast_dimensions.back() = input_rank - 2;  // Row-wise.
    } else {
      int length_to_pad_to;
      if ((diag_index > 0 && left_align_superdiagonal) ||
          (diag_index < 0 && !left_align_subdiagonal)) {
        length_to_pad_to = num_rows;
        broadcast_dimensions.back() = input_rank - 2;  // Row-wise.
      } else {
        length_to_pad_to = num_cols;
        broadcast_dimensions.back() = input_rank - 1;  // Column-wise.
      }
      int padding_low = length_to_pad_to - max_diag_len;
      int padding_high = 0;
      if (IsLeftAligned(diag_index, left_align_superdiagonal,
                        left_align_subdiagonal)) {
        std::swap(padding_low, padding_high);
      }
      padding_config.mutable_dimensions(input_rank - 2)
          ->set_edge_padding_low(padding_low);
      padding_config.mutable_dimensions(input_rank - 2)
          ->set_edge_padding_high(padding_high);

      const xla::XlaOp zero = xla::ScalarLike(input, 0);
      diag_slice = xla::Pad(diag_slice, zero, padding_config);
    }

    // Broadcast and mask.
    xla::XlaOp diag_broadcast = xla::BroadcastInDim(
        diag_slice, input_shape.dim_sizes(), broadcast_dimensions);
    const auto mask = xla::GetDiagonalMask(output, diag_index);
    output = xla::Select(mask, diag_broadcast, output);
  }
  return output;
}

}  // namespace

class MatrixDiagOp : public XlaOpKernel {
 public:
  explicit MatrixDiagOp(OpKernelConstruction* context) : XlaOpKernel(context) {
    // MatrixDiagV3-specific.
    if (context->HasAttr("align")) {
      ReadAlignment(context, &left_align_superdiagonal_,
                    &left_align_subdiagonal_);
    }
  }

  void Compile(XlaOpKernelContext* context) override {
    OP_REQUIRES(
        context, context->num_inputs() >= kNumV1Inputs,
        errors::InvalidArgument("MatrixDiag op must have at least one input"));
    const TensorShape diag_shape = context->InputShape(0);
    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(diag_shape),
                errors::InvalidArgument(
                    "diagonal must be at least 1-dim, received shape: ",
                    diag_shape.DebugString()));

    const DataType dtype = context->expected_output_dtype(0);
    const xla::XlaOp zero = XlaHelpers::Zero(context->builder(), dtype);

    // Initializes MatrixDiagV2-specific variables.
    // Input arguments providing the values of num_rows and num_cols can be
    // absent (-1) and will be inferred later.
    int64 lower_diag_index = 0;
    int64 upper_diag_index = 0;
    int64 num_rows = -1;
    int64 num_cols = -1;
    xla::XlaOp padding_value = zero;

    // MatrixDiag and MatrixDiagV2 both use this OpKernel. MatrixDiag only has
    // one input, so we have to check the number of inputs before reading
    // additional parameters for MatrixDiagV2.
    if (context->num_inputs() > kNumV1Inputs) {
      std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context);
      OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &num_rows));
      OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(3, &num_cols));
      padding_value = context->Input(4);
    }

    // More size validations.
    const int64 diag_rank = diag_shape.dims();
    const int64 max_diag_len = diag_shape.dim_size(diag_rank - 1);
    const int64 num_diags = upper_diag_index - lower_diag_index + 1;
    OP_REQUIRES(
        context,
        num_diags == 1 || num_diags == diag_shape.dim_size(diag_rank - 2),
        errors::InvalidArgument(
            "The number of diagonals provided in the input does not "
            "match the lower_diag_index and upper_diag_index range."));
    const int64 min_num_rows =
        max_diag_len - std::min(upper_diag_index, int64{0});
    const int64 min_num_cols =
        max_diag_len + std::max(lower_diag_index, int64{0});
    OP_REQUIRES(context, num_rows == -1 || num_rows >= min_num_rows,
                errors::InvalidArgument("The number of rows is too small."));
    OP_REQUIRES(context, num_cols == -1 || num_cols >= min_num_cols,
                errors::InvalidArgument("The number of columns is too small."));

    // Infers num_rows and num_cols. If both are unknown, assume that the
    // output is square. Otherwise, use smallest possible values.
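    // For example, a diagonal of length 3 with k = 1 and no explicit sizes
    // gives min_num_rows = 3 and min_num_cols = 4, so the output is inferred
    // to be a square 4x4 matrix.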
    if (num_rows == -1 && num_cols == -1) {
      num_rows = std::max(min_num_rows, min_num_cols);
      num_cols = num_rows;
    } else if (num_rows == -1) {
      num_rows = min_num_rows;
    } else if (num_cols == -1) {
      num_cols = min_num_cols;
    }

    // At least one of num_rows and num_cols must match its minimum length.
    // Otherwise, we'll have some incomplete diagonals.
    OP_REQUIRES(context, num_rows == min_num_rows || num_cols == min_num_cols,
                errors::InvalidArgument(
                    "The number of rows or columns is not consistent with "
                    "the specified d_lower, d_upper, and diagonal."));

    // Actual processing.
    // Initializes the output tensor with padding_value.
    TensorShape output_shape = diag_shape;
    output_shape.RemoveLastDims((num_diags == 1) ? 1 : 2);
    output_shape.AddDim(num_rows);
    output_shape.AddDim(num_cols);
    xla::XlaOp output = xla::Broadcast(padding_value, output_shape.dim_sizes());
    xla::XlaOp diag = context->Input(0);
    context->SetOutput(
        0, SetMatrixDiag(output, diag, output_shape, diag_rank, num_diags,
                         lower_diag_index, upper_diag_index, max_diag_len,
                         num_rows, num_cols, left_align_superdiagonal_,
                         left_align_subdiagonal_));
  }

 private:
  bool left_align_superdiagonal_ = true;
  bool left_align_subdiagonal_ = true;
  static constexpr int kNumV1Inputs = 1;
};

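// MatrixDiag, MatrixDiagV2, and MatrixDiagV3 share the MatrixDiagOp kernel
// above. The V2/V3 ops additionally take "k", "num_rows", "num_cols", and
// "padding_value" as compile-time constant inputs, and V3 adds the optional
// "align" attribute read in the constructor.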
REGISTER_XLA_OP(Name("MatrixDiag"), MatrixDiagOp);
REGISTER_XLA_OP(Name("MatrixDiagV2")
                    .CompileTimeConstantInput("k")
                    .CompileTimeConstantInput("num_rows")
                    .CompileTimeConstantInput("num_cols")
                    .CompileTimeConstantInput("padding_value"),
                MatrixDiagOp);
REGISTER_XLA_OP(Name("MatrixDiagV3")
                    .CompileTimeConstantInput("k")
                    .CompileTimeConstantInput("num_rows")
                    .CompileTimeConstantInput("num_cols")
                    .CompileTimeConstantInput("padding_value"),
                MatrixDiagOp);

class MatrixDiagPartOp : public XlaOpKernel {
 public:
  explicit MatrixDiagPartOp(OpKernelConstruction* context)
      : XlaOpKernel(context),
        is_gpu_(context->device_type().type_string() == DEVICE_GPU_XLA_JIT) {
    // MatrixDiagPartV3-specific.
    if (context->HasAttr("align")) {
      ReadAlignment(context, &left_align_superdiagonal_,
                    &left_align_subdiagonal_);
    }
  }

  void Compile(XlaOpKernelContext* context) override {
    const TensorShape input_shape = context->InputShape(0);
    const int input_rank = input_shape.dims();

    OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
                errors::InvalidArgument(
                    "input must be at least 2-dim, received shape: ",
                    input_shape.DebugString()));

    const DataType dtype = context->expected_output_dtype(0);
    const xla::XlaOp zero = XlaHelpers::Zero(context->builder(), dtype);

    // Initializes MatrixDiagPartV2-specific variables.
    int64 lower_diag_index = 0;
    int64 upper_diag_index = 0;
    xla::XlaOp padding_value = zero;

    // MatrixDiagPart and MatrixDiagPartV2 both use this OpKernel.
    // MatrixDiagPart only has one input, so we have to check the number of
    // inputs before reading additional parameters for MatrixDiagPartV2.
    if (context->num_inputs() > kNumV1Inputs) {
      std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context);
      padding_value = context->Input(2);
    }

    // Checks if diag sizes are consistent with input.
    const int64 num_rows = input_shape.dim_size(input_rank - 2);
    const int64 num_cols = input_shape.dim_size(input_rank - 1);
    ValidateDiagIndexWithOutputMatrixSize(context, lower_diag_index,
                                          upper_diag_index, num_rows, num_cols);

    // Creates output shape.
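    // For example, a 3x4 input with k = (0, 2) has max_diag_len =
    // min(3, 4) = 3 and an output of shape [<batch_dims>, 3, 3]: one row per
    // diagonal, each padded to the length of the longest requested diagonal.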
    TensorShape output_shape = input_shape;
    output_shape.RemoveLastDims(2);
    const int num_diags = upper_diag_index - lower_diag_index + 1;
    if (num_diags > 1) output_shape.AddDim(num_diags);
    const int32 max_diag_len =
        std::min(num_rows + std::min(upper_diag_index, int64{0}),
                 num_cols - std::max(lower_diag_index, int64{0}));
    output_shape.AddDim(max_diag_len);

    // Computes output.
    xla::XlaOp input = context->Input(0);
    std::vector<xla::XlaOp> diag_list;
    xla::PaddingConfig padding_config =
        xla::MakeNoPaddingConfig(input_rank - 1);
    if (num_diags == 1) {
      context->SetOutput(
          0, is_gpu_ ? xla::GetMatrixDiagonalViaGather(input, upper_diag_index)
                     : xla::GetMatrixDiagonal(input, upper_diag_index));
      return;
    }
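    // General case: extract each requested diagonal from the uppermost down
    // to the lowermost, pad it to max_diag_len according to the alignment,
    // and concatenate the results along a new num_diags dimension.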
    for (int diag_index = upper_diag_index; diag_index >= lower_diag_index;
         --diag_index) {
      xla::XlaOp single_diag =
          is_gpu_ ? xla::GetMatrixDiagonalViaGather(input, diag_index)
                  : xla::GetMatrixDiagonal(input, diag_index);
      const int64 diag_len = ComputeDiagLen(diag_index, num_rows, num_cols);
      const int64 padding_len = max_diag_len - diag_len;
      if (padding_len > 0) {
        if (IsLeftAligned(diag_index, left_align_superdiagonal_,
                          left_align_subdiagonal_)) {
          padding_config.mutable_dimensions(input_rank - 2)
              ->set_edge_padding_low(0);
          padding_config.mutable_dimensions(input_rank - 2)
              ->set_edge_padding_high(padding_len);
        } else {
          padding_config.mutable_dimensions(input_rank - 2)
              ->set_edge_padding_low(padding_len);
          padding_config.mutable_dimensions(input_rank - 2)
              ->set_edge_padding_high(0);
        }
        single_diag = xla::Pad(single_diag, padding_value, padding_config);
      }
      diag_list.emplace_back(single_diag);
    }
    auto concat =
        xla::ConcatInDim(context->builder(), diag_list, input_rank - 2);
    context->SetOutput(0, xla::Reshape(concat, output_shape.dim_sizes()));
  }

 private:
  const bool is_gpu_;
  bool left_align_superdiagonal_ = true;
  bool left_align_subdiagonal_ = true;
  static constexpr int kNumV1Inputs = 1;
};

REGISTER_XLA_OP(Name("MatrixDiagPart"), MatrixDiagPartOp);
REGISTER_XLA_OP(Name("MatrixDiagPartV2")
                    .CompileTimeConstantInput("k")
                    .CompileTimeConstantInput("padding_value"),
                MatrixDiagPartOp);
REGISTER_XLA_OP(Name("MatrixDiagPartV3")
                    .CompileTimeConstantInput("k")
                    .CompileTimeConstantInput("padding_value"),
                MatrixDiagPartOp);

class MatrixSetDiagOp : public XlaOpKernel {
 public:
  explicit MatrixSetDiagOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {
    // MatrixSetDiagV3-specific.
    if (context->HasAttr("align")) {
      ReadAlignment(context, &left_align_superdiagonal_,
                    &left_align_subdiagonal_);
    }
  }

  void Compile(XlaOpKernelContext* context) override {
    const TensorShape input_shape = context->InputShape(0);
    const TensorShape diag_shape = context->InputShape(1);
    const int input_rank = input_shape.dims();
    const int diag_rank = diag_shape.dims();

    // Preliminary validation of sizes.
    OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
                errors::InvalidArgument(
                    "input must be at least 2-dim, received shape: ",
                    input_shape.DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(diag_shape),
                errors::InvalidArgument(
                    "diagonal must be at least 1-dim, received shape: ",
                    diag_shape.DebugString()));

    // MatrixSetDiag and MatrixSetDiagV2 both use this OpKernel. MatrixSetDiag
    // only has two inputs, so we have to check the number of inputs before
    // reading additional parameters in MatrixSetDiagV2.
    int64 lower_diag_index = 0;
    int64 upper_diag_index = 0;
    if (context->num_inputs() > kNumV1Inputs) {
      std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context);
    }

    // Checks if diag sizes are consistent with input.
    const int64 num_rows = input_shape.dim_size(input_rank - 2);
    const int64 num_cols = input_shape.dim_size(input_rank - 1);
    ValidateDiagIndexWithOutputMatrixSize(context, lower_diag_index,
                                          upper_diag_index, num_rows, num_cols);
    const Eigen::Index num_diags = upper_diag_index - lower_diag_index + 1;
    OP_REQUIRES(
        context,
        lower_diag_index == upper_diag_index ||
            (diag_shape.dim_size(input_rank - 2) == num_diags),
        errors::InvalidArgument("The number of diagonals provided in `diag` "
                                "is not consistent with `lower_diag_index` and "
                                "`upper_diag_index`"));

    TensorShape expected_diag_shape = input_shape;
    expected_diag_shape.RemoveLastDims(2);
    if (num_diags > 1) expected_diag_shape.AddDim(num_diags);
    const int32 max_diag_len =
        std::min(num_rows + std::min(upper_diag_index, int64{0}),
                 num_cols - std::max(lower_diag_index, int64{0}));
    expected_diag_shape.AddDim(max_diag_len);
    OP_REQUIRES(
        context, expected_diag_shape == diag_shape,
        errors::InvalidArgument(
            "Either the first dimensions of diagonal don't match "
            "input.shape[:-2], or diagonal.shape[-1] is not equal to the "
            "length of the longest diagonal in range "
            "[lower_diag_index:upper_diag_index].\nInput shape: ",
            input_shape.DebugString(),
            "\nDiagonal shape: ", diag_shape.DebugString(),
            "\nExpected diagonal shape: ", expected_diag_shape.DebugString()));

    // Actual processing.
    xla::XlaOp input = context->Input(0);
    xla::XlaOp diag = context->Input(1);
    context->SetOutput(
        0, SetMatrixDiag(input, diag, input_shape, diag_rank, num_diags,
                         lower_diag_index, upper_diag_index, max_diag_len,
                         num_rows, num_cols, left_align_superdiagonal_,
                         left_align_subdiagonal_));
  }

 private:
  bool left_align_superdiagonal_ = true;
  bool left_align_subdiagonal_ = true;
  static constexpr int kNumV1Inputs = 2;
  TF_DISALLOW_COPY_AND_ASSIGN(MatrixSetDiagOp);
};

REGISTER_XLA_OP(Name("MatrixSetDiag"), MatrixSetDiagOp);
REGISTER_XLA_OP(Name("MatrixSetDiagV2").CompileTimeConstantInput("k"),
                MatrixSetDiagOp);
REGISTER_XLA_OP(Name("MatrixSetDiagV3").CompileTimeConstantInput("k"),
                MatrixSetDiagOp);

}  // namespace tensorflow