1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
17 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
18 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
19 #include "tensorflow/compiler/xla/client/lib/constants.h"
20 #include "tensorflow/compiler/xla/client/lib/matrix.h"
21 #include "tensorflow/compiler/xla/client/xla_builder.h"
22 #include "tensorflow/compiler/xla/primitive_util.h"
23 #include "tensorflow/core/framework/op_kernel.h"
24 #include "tensorflow/core/lib/core/errors.h"
25
26 namespace tensorflow {
27 namespace {
28
29 // Calculates the diagonal length of a diagonal.
// Calculates the diagonal length of a diagonal.
static inline int ComputeDiagLen(int diag_index, int num_rows, int num_cols) {
  // A subdiagonal (diag_index < 0) loses |diag_index| rows; a superdiagonal
  // (diag_index > 0) loses diag_index columns. The diagonal length is the
  // smaller of the two remaining extents.
  const int rows_left = (diag_index < 0) ? num_rows + diag_index : num_rows;
  const int cols_left = (diag_index > 0) ? num_cols - diag_index : num_cols;
  return rows_left < cols_left ? rows_left : cols_left;
}
34
35 // Checks if a diagonal is to be aligned left or right.
// Checks if a diagonal is to be aligned left or right.
static inline bool IsLeftAligned(int diag_index, bool left_align_superdiagonal,
                                 bool left_align_subdiagonal) {
  // The main diagonal (diag_index == 0) counts as both a superdiagonal and a
  // subdiagonal, so either flag left-aligns it.
  if (diag_index > 0) return left_align_superdiagonal;
  if (diag_index < 0) return left_align_subdiagonal;
  return left_align_superdiagonal || left_align_subdiagonal;
}
41
42 // Reads the diagonal packing alignment.
ReadAlignment(OpKernelConstruction * context,bool * left_align_superdiagonal,bool * left_align_subdiagonal)43 void ReadAlignment(OpKernelConstruction* context,
44 bool* left_align_superdiagonal,
45 bool* left_align_subdiagonal) {
46 string align;
47 OP_REQUIRES_OK(context, context->GetAttr("align", &align));
48
49 *left_align_superdiagonal = align == "LEFT_LEFT" || align == "LEFT_RIGHT";
50 *left_align_subdiagonal = align == "LEFT_LEFT" || align == "RIGHT_LEFT";
51 }
52
53 // Reads or infers lower_diag_index and upper_diag_index from kernel's input
54 // parameter "k". Also validates their values.
ProcessDiagIndex(XlaOpKernelContext * context)55 std::pair<int64, int64> ProcessDiagIndex(XlaOpKernelContext* context) {
56 int64 lower_diag_index = 0;
57 int64 upper_diag_index = 0;
58 TensorShape diag_index_shape = context->InputShape("k");
59
60 // Wrapping OP_REQUIRES* macros with a function because they can "return;"
61 // early (without values) which contradicts ProcessDiagIndex's signature.
62 auto validate_diag_indices = [&]() {
63 if (diag_index_shape.dims() == 0) {
64 OP_REQUIRES_OK(context,
65 context->ConstantInputAsIntScalar("k", &lower_diag_index));
66 upper_diag_index = lower_diag_index;
67 } else {
68 std::vector<int64> diag_index;
69 OP_REQUIRES_OK(context,
70 context->ConstantInputAsIntVector("k", &diag_index));
71 OP_REQUIRES(
72 context, !diag_index.empty() && diag_index.size() <= 2,
73 errors::InvalidArgument(
74 "diag_index must have only one or two elements, received ",
75 diag_index.size(), " elements."));
76 lower_diag_index = diag_index[0];
77 upper_diag_index =
78 (diag_index.size() > 1) ? diag_index[1] : lower_diag_index;
79 }
80 OP_REQUIRES(
81 context, lower_diag_index <= upper_diag_index,
82 errors::InvalidArgument(
83 "lower_diag_index must not be larger than upper_diag_index: ",
84 lower_diag_index, " > ", upper_diag_index));
85 };
86 validate_diag_indices();
87 return {lower_diag_index, upper_diag_index};
88 }
89
90 // Makes sure lower_diag_index and upper_diag_index are consistent with the
91 // input matrix size.
// Makes sure lower_diag_index and upper_diag_index are consistent with the
// input matrix size. Raises InvalidArgument via OP_REQUIRES (which can
// "return;" early) when a bound is violated.
void ValidateDiagIndexWithOutputMatrixSize(XlaOpKernelContext* context,
                                           const int64 lower_diag_index,
                                           const int64 upper_diag_index,
                                           const int64 num_rows,
                                           const int64 num_cols) {
  // `lower_diag_index == 0` condition is added to handle matrix shape = 0.
  OP_REQUIRES(context,
              (-num_rows < lower_diag_index && lower_diag_index < num_cols) ||
                  lower_diag_index == 0,
              errors::InvalidArgument(
                  "lower_diag_index is out of bound: ", lower_diag_index,
                  " It must be between ", -num_rows, " and ", num_cols));
  // Same exclusive bound check for the upper index; `upper_diag_index == 0`
  // likewise covers empty matrices.
  OP_REQUIRES(context,
              (-num_rows < upper_diag_index && upper_diag_index < num_cols) ||
                  upper_diag_index == 0,
              errors::InvalidArgument(
                  "upper_diag_index is out of bound: ", upper_diag_index,
                  " It must be between ", -num_rows, " and ", num_cols));
  // The diagonal band [lower_diag_index, upper_diag_index] must be non-empty.
  OP_REQUIRES(context, lower_diag_index <= upper_diag_index,
              errors::InvalidArgument(
                  "lower_diag_index must not be larger than upper_diag_index: ",
                  lower_diag_index, " > ", upper_diag_index));
}
115
116 // Kernel to set matrix diagonals.
// Kernel to set matrix diagonals.
//
// Overwrites the diagonals in the band [lower_diag_index, upper_diag_index]
// of `input` with the values packed in `diag` and returns the resulting op.
// `diag` is expected to have dims [<batch>, num_diags, max_diag_len] (the
// num_diags dimension is absent when num_diags == 1), matching the layout
// produced by the MatrixDiagPart ops with the same alignment flags.
xla::XlaOp SetMatrixDiag(const xla::XlaOp input, const xla::XlaOp diag,
                         const TensorShape& input_shape, const int64 diag_rank,
                         const int64 num_diags, const int64 lower_diag_index,
                         const int64 upper_diag_index, const int64 max_diag_len,
                         const int64 num_rows, const int64 num_cols,
                         const bool left_align_superdiagonal,
                         const bool left_align_subdiagonal) {
  // Creates a padding config (reused and mutated per diagonal below).
  const int input_rank = input_shape.dims();
  xla::PaddingConfig padding_config;
  padding_config = xla::MakeNoPaddingConfig(input_rank - 1);

  // Processes one diagonal at a time:
  // 1) Extracts a single diagonal (diag_slice).
  // 2) Broadcasts its contents to fill the whole matrix (diag_broadcast).
  // 3) Masks diag_broadcast to get the right diagonal shape.
  //
  // XLA can fuse multiple Broadcasts and Selects so this shouldn't be slow.
  //
  // For example,
  //   diag = [[0, 2, 3], k = (-1, 1), num_cols = 4, and align="RIGHT_LEFT".
  //           [4, 5, 6],
  //           [7, 8, 9]]
  // The expected output is [[7, 4, 2, 0],
  //                         [0, 8, 5, 3],
  //                         [0, 0, 9, 6]].
  // The 1st diagonal is created by:
  // 1) Extracting diag_slice = [0, 2, 3] which is right-aligned.
  // 2) Padding the vector (in the same direction) to be as long as num_cols,
  //      diag_slice = [0, 0, 2, 3],
  //    then broadcasting diag_slice column-wise to a full matrix,
  //      diag_broadcast = [[0, 0, 2, 3],
  //                        [0, 0, 2, 3],
  //                        [0, 0, 2, 3]].
  //    The padding value can be anything because it will not appear in the
  //    results after masking. Here, we use zero.
  // 3) Masking diag_broadcast with a mask of the shape of the 1st diagonal.
  //      mask = [[0, 0, 1, 0], --> output = [[x, x, 2, x],
  //              [0, 0, 0, 1],              [x, x, x, 3],
  //              [0, 0, 0, 0]]              [x, x, x, x]],
  //    where x denotes the existing input contents.
  std::vector<int64> broadcast_dimensions(input_rank - 1);
  absl::c_iota(broadcast_dimensions, 0);
  auto output = input;
  for (int64 diag_index = lower_diag_index; diag_index <= upper_diag_index;
       ++diag_index) {
    // Extracts a single diagonal. Diagonals are stored from the uppermost
    // down, so upper_diag_index maps to row 0 of the diagonals dimension.
    auto diag_slice = diag;
    if (num_diags > 1) {
      // The result of SliceInDim has dims: [<batch_dim>, 1, max_diag_len].
      // We call Collapse to make the dims: [<batch_dim>, max_diag_len].
      const int64 mapped_diag_index = upper_diag_index - diag_index;
      diag_slice = xla::Collapse(
          xla::SliceInDim(diag, mapped_diag_index, mapped_diag_index + 1, 1,
                          diag_rank - 2),
          {diag_rank - 2, diag_rank - 1});
    }

    // Pad if necessary.
    // - If the diagonal has the longest length, i.e., min(num_rows, num_cols),
    //   no padding is necessary. It is broadcast column-wise if it is a sub-
    //   diagonal, row-wise if superdiagonal.
    // - Otherwise, pad and keep the old alignment (shorter diagonals in the
    //   input come pre-padded). max_len in the table refers to max_diag_len.
    //   -------------------------------------------------------------------
    //   | Diag  | Align | Broadcast   | padding_low     | padding_high    |
    //   -------------------------------------------------------------------
    //   | Super | Left  | Row-wise    | 0               | #rows - max_len |
    //   |       | Right | Column-wise | #cols - max_len | 0               |
    //   -------------------------------------------------------------------
    //   | Sub   | Left  | Column-wise | 0               | #cols - max_len |
    //   |       | Right | Row-wise    | #rows - max_len | 0               |
    //   -------------------------------------------------------------------
    if (num_cols - num_rows <= diag_index && diag_index <= 0) {
      broadcast_dimensions.back() = input_rank - 1;  // Column-wise.
    } else if (0 <= diag_index && diag_index <= num_cols - num_rows) {
      broadcast_dimensions.back() = input_rank - 2;  // Row-wise.
    } else {
      // Short diagonal: pick pad length and broadcast direction per the table
      // above, then place the pad on the side opposite the alignment.
      int length_to_pad_to;
      if ((diag_index > 0 && left_align_superdiagonal) ||
          (diag_index < 0 && !left_align_subdiagonal)) {
        length_to_pad_to = num_rows;
        broadcast_dimensions.back() = input_rank - 2;  // Row-wise.
      } else {
        length_to_pad_to = num_cols;
        broadcast_dimensions.back() = input_rank - 1;  // Column-wise.
      }
      int padding_low = length_to_pad_to - max_diag_len;
      int padding_high = 0;
      if (IsLeftAligned(diag_index, left_align_superdiagonal,
                        left_align_subdiagonal)) {
        std::swap(padding_low, padding_high);
      }
      padding_config.mutable_dimensions(input_rank - 2)
          ->set_edge_padding_low(padding_low);
      padding_config.mutable_dimensions(input_rank - 2)
          ->set_edge_padding_high(padding_high);

      // Pad value is arbitrary (masked out below); zero is used.
      const xla::XlaOp zero = xla::ScalarLike(input, 0);
      diag_slice = xla::Pad(diag_slice, zero, padding_config);
    }

    // Broadcast and mask: only positions on this diagonal take values from
    // diag_broadcast; all other positions keep the current output.
    xla::XlaOp diag_broadcast = xla::BroadcastInDim(
        diag_slice, input_shape.dim_sizes(), broadcast_dimensions);
    const auto mask = xla::GetDiagonalMask(output, diag_index);
    output = xla::Select(mask, diag_broadcast, output);
  }
  return output;
}
227
228 } // namespace
229
// XLA kernel shared by MatrixDiag, MatrixDiagV2, and MatrixDiagV3. Builds a
// (possibly batched) matrix whose diagonals in the band
// [lower_diag_index, upper_diag_index] come from the input and whose
// remaining entries are filled with padding_value.
class MatrixDiagOp : public XlaOpKernel {
 public:
  explicit MatrixDiagOp(OpKernelConstruction* context) : XlaOpKernel(context) {
    // MatrixDiagV3-specific.
    if (context->HasAttr("align")) {
      ReadAlignment(context, &left_align_superdiagonal_,
                    &left_align_subdiagonal_);
    }
  }

  void Compile(XlaOpKernelContext* context) override {
    OP_REQUIRES(
        context, context->num_inputs() >= kNumV1Inputs,
        errors::InvalidArgument("MatrixDiag op must have at least one input"));
    const TensorShape diag_shape = context->InputShape(0);
    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(diag_shape),
                errors::InvalidArgument(
                    "diagonal must be at least 1-dim, received shape: ",
                    diag_shape.DebugString()));

    const DataType dtype = context->expected_output_dtype(0);
    const xla::XlaOp zero = XlaHelpers::Zero(context->builder(), dtype);

    // Initializes MatrixDiagV2-specific variables.
    // Input arguments providing the values of num_rows and num_cols can be
    // absent (-1) and will be inferred later.
    int64 lower_diag_index = 0;
    int64 upper_diag_index = 0;
    int64 num_rows = -1;
    int64 num_cols = -1;
    xla::XlaOp padding_value = zero;

    // MatrixDiag and MatrixDiagV2 both use this OpKernel. MatrixDiag only has
    // one input, so we have to check the number of inputs before reading
    // additional parameters for MatrixDiagV2.
    if (context->num_inputs() > kNumV1Inputs) {
      std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context);
      OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &num_rows));
      OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(3, &num_cols));
      padding_value = context->Input(4);
    }

    // More size validations.
    const int64 diag_rank = diag_shape.dims();
    const int64 max_diag_len = diag_shape.dim_size(diag_rank - 1);
    const int64 num_diags = upper_diag_index - lower_diag_index + 1;
    OP_REQUIRES(
        context,
        num_diags == 1 || num_diags == diag_shape.dim_size(diag_rank - 2),
        errors::InvalidArgument(
            "The number of diagonals provided in the input does not "
            "match the lower_diag_index and upper_diag_index range."));
    // Lower bounds on the output dims implied by requiring the longest
    // provided diagonal (max_diag_len) to fit in the output matrix.
    const int64 min_num_rows =
        max_diag_len - std::min(upper_diag_index, int64{0});
    const int64 min_num_cols =
        max_diag_len + std::max(lower_diag_index, int64{0});
    OP_REQUIRES(context, num_rows == -1 || num_rows >= min_num_rows,
                errors::InvalidArgument("The number of rows is too small."));
    OP_REQUIRES(context, num_cols == -1 || num_cols >= min_num_cols,
                errors::InvalidArgument("The number of columns is too small."));

    // Infers num_rows and num_cols. If both are unknown, assume that the output
    // is square. Otherwise, use smallest possible values.
    if (num_rows == -1 && num_cols == -1) {
      num_rows = std::max(min_num_rows, min_num_cols);
      num_cols = num_rows;
    } else if (num_rows == -1) {
      num_rows = min_num_rows;
    } else if (num_cols == -1) {
      num_cols = min_num_cols;
    }

    // At least one of num_rows and num_cols must match its minimum length.
    // Otherwise, we'll have some incomplete diagonals.
    OP_REQUIRES(context, num_rows == min_num_rows || num_cols == min_num_cols,
                errors::InvalidArgument(
                    "The number of rows or columns is not consistent with "
                    "the specified d_lower, d_upper, and diagonal."));

    // Actual processing.
    // Initializes the output tensor with padding_value, then writes the
    // diagonals into it. Output shape: [<batch>, num_rows, num_cols].
    TensorShape output_shape = diag_shape;
    output_shape.RemoveLastDims((num_diags == 1) ? 1 : 2);
    output_shape.AddDim(num_rows);
    output_shape.AddDim(num_cols);
    xla::XlaOp output = xla::Broadcast(padding_value, output_shape.dim_sizes());
    xla::XlaOp diag = context->Input(0);
    context->SetOutput(
        0, SetMatrixDiag(output, diag, output_shape, diag_rank, num_diags,
                         lower_diag_index, upper_diag_index, max_diag_len,
                         num_rows, num_cols, left_align_superdiagonal_,
                         left_align_subdiagonal_));
  }

 private:
  bool left_align_superdiagonal_ = true;
  bool left_align_subdiagonal_ = true;
  static constexpr int kNumV1Inputs = 1;
};
329
REGISTER_XLA_OP(Name("MatrixDiag"), MatrixDiagOp);
// V2/V3 take the diagonal band "k", the output dimensions, and the padding
// value as additional inputs; they are declared compile-time constants
// because Compile reads them via ConstantInputAs* to build the static
// output shape.
REGISTER_XLA_OP(Name("MatrixDiagV2")
                    .CompileTimeConstantInput("k")
                    .CompileTimeConstantInput("num_rows")
                    .CompileTimeConstantInput("num_cols")
                    .CompileTimeConstantInput("padding_value"),
                MatrixDiagOp);
REGISTER_XLA_OP(Name("MatrixDiagV3")
                    .CompileTimeConstantInput("k")
                    .CompileTimeConstantInput("num_rows")
                    .CompileTimeConstantInput("num_cols")
                    .CompileTimeConstantInput("padding_value"),
                MatrixDiagOp);
343
// XLA kernel shared by MatrixDiagPart, MatrixDiagPartV2, and
// MatrixDiagPartV3. Extracts the diagonals in the band
// [lower_diag_index, upper_diag_index] from a (possibly batched) matrix,
// padding shorter diagonals with padding_value.
class MatrixDiagPartOp : public XlaOpKernel {
 public:
  explicit MatrixDiagPartOp(OpKernelConstruction* context)
      : XlaOpKernel(context),
        is_gpu_(context->device_type().type_string() == DEVICE_GPU_XLA_JIT) {
    // MatrixDiagPartV3-specific.
    if (context->HasAttr("align")) {
      ReadAlignment(context, &left_align_superdiagonal_,
                    &left_align_subdiagonal_);
    }
  }

  void Compile(XlaOpKernelContext* context) override {
    const TensorShape input_shape = context->InputShape(0);
    const int input_rank = input_shape.dims();

    OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
                errors::InvalidArgument(
                    "input must be at least 2-dim, received shape: ",
                    input_shape.DebugString()));

    const DataType dtype = context->expected_output_dtype(0);
    const xla::XlaOp zero = XlaHelpers::Zero(context->builder(), dtype);

    // Initializes MatrixDiagPartV2-specific variables.
    int64 lower_diag_index = 0;
    int64 upper_diag_index = 0;
    xla::XlaOp padding_value = zero;

    // MatrixDiagPart and MatrixDiagPartV2 both use this OpKernel.
    // MatrixDiagPart only has one input, so we have to check the number of
    // inputs before reading additional parameters in MatrixDiagV2.
    if (context->num_inputs() > kNumV1Inputs) {
      std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context);
      padding_value = context->Input(2);
    }

    // Checks if diag sizes are consistent with input.
    const int64 num_rows = input_shape.dim_size(input_rank - 2);
    const int64 num_cols = input_shape.dim_size(input_rank - 1);
    ValidateDiagIndexWithOutputMatrixSize(context, lower_diag_index,
                                          upper_diag_index, num_rows, num_cols);

    // Creates output shape: [<batch>, num_diags (if > 1), max_diag_len].
    TensorShape output_shape = input_shape;
    output_shape.RemoveLastDims(2);
    const int num_diags = upper_diag_index - lower_diag_index + 1;
    if (num_diags > 1) output_shape.AddDim(num_diags);
    const int32 max_diag_len =
        std::min(num_rows + std::min(upper_diag_index, int64{0}),
                 num_cols - std::max(lower_diag_index, int64{0}));
    output_shape.AddDim(max_diag_len);

    // Computes output.
    xla::XlaOp input = context->Input(0);
    std::vector<xla::XlaOp> diag_list;
    xla::PaddingConfig padding_config =
        xla::MakeNoPaddingConfig(input_rank - 1);
    // Fast path: a single diagonal needs no padding or concatenation.
    if (num_diags == 1) {
      context->SetOutput(
          0, is_gpu_ ? xla::GetMatrixDiagonalViaGather(input, upper_diag_index)
                     : xla::GetMatrixDiagonal(input, upper_diag_index));
      return;
    }
    // Extracts diagonals from upper_diag_index down to lower_diag_index so
    // they appear in that order in the concatenated output.
    for (int diag_index = upper_diag_index; diag_index >= lower_diag_index;
         --diag_index) {
      // The gather-based extraction path is used on GPU, the default
      // GetMatrixDiagonal elsewhere.
      xla::XlaOp single_diag =
          is_gpu_ ? xla::GetMatrixDiagonalViaGather(input, diag_index)
                  : xla::GetMatrixDiagonal(input, diag_index);
      const int64 diag_len = ComputeDiagLen(diag_index, num_rows, num_cols);
      const int64 padding_len = max_diag_len - diag_len;
      // Shorter diagonals are padded with padding_value to max_diag_len: the
      // pad goes on the right for left-aligned diagonals, on the left
      // otherwise.
      if (padding_len > 0) {
        if (IsLeftAligned(diag_index, left_align_superdiagonal_,
                          left_align_subdiagonal_)) {
          padding_config.mutable_dimensions(input_rank - 2)
              ->set_edge_padding_low(0);
          padding_config.mutable_dimensions(input_rank - 2)
              ->set_edge_padding_high(padding_len);
        } else {
          padding_config.mutable_dimensions(input_rank - 2)
              ->set_edge_padding_low(padding_len);
          padding_config.mutable_dimensions(input_rank - 2)
              ->set_edge_padding_high(0);
        }
        single_diag = xla::Pad(single_diag, padding_value, padding_config);
      }
      diag_list.emplace_back(single_diag);
    }
    auto concat =
        xla::ConcatInDim(context->builder(), diag_list, input_rank - 2);
    context->SetOutput(0, xla::Reshape(concat, output_shape.dim_sizes()));
  }

 private:
  const bool is_gpu_;
  bool left_align_superdiagonal_ = true;
  bool left_align_subdiagonal_ = true;
  static constexpr int kNumV1Inputs = 1;
};
443
REGISTER_XLA_OP(Name("MatrixDiagPart"), MatrixDiagPartOp);
// V2/V3 additionally take "k" and "padding_value"; "k" must be a compile-time
// constant because the static output shape depends on the diagonal band.
REGISTER_XLA_OP(Name("MatrixDiagPartV2")
                    .CompileTimeConstantInput("k")
                    .CompileTimeConstantInput("padding_value"),
                MatrixDiagPartOp);
REGISTER_XLA_OP(Name("MatrixDiagPartV3")
                    .CompileTimeConstantInput("k")
                    .CompileTimeConstantInput("padding_value"),
                MatrixDiagPartOp);
453
// XLA kernel shared by MatrixSetDiag, MatrixSetDiagV2, and MatrixSetDiagV3.
// Returns a copy of the input matrix with the diagonals in the band
// [lower_diag_index, upper_diag_index] replaced by the values in `diag`.
class MatrixSetDiagOp : public XlaOpKernel {
 public:
  explicit MatrixSetDiagOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {
    // MatrixSetDiagV3-specific.
    if (context->HasAttr("align")) {
      ReadAlignment(context, &left_align_superdiagonal_,
                    &left_align_subdiagonal_);
    }
  }

  void Compile(XlaOpKernelContext* context) override {
    const TensorShape input_shape = context->InputShape(0);
    const TensorShape diag_shape = context->InputShape(1);
    const int input_rank = input_shape.dims();
    const int diag_rank = diag_shape.dims();

    // Preliminary validation of sizes.
    OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
                errors::InvalidArgument(
                    "input must be at least 2-dim, received shape: ",
                    input_shape.DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(diag_shape),
                errors::InvalidArgument(
                    "diagonal must be at least 1-dim, received shape: ",
                    diag_shape.DebugString()));

    // MatrixSetDiag and MatrixSetDiagV2 both use this OpKernel. MatrixSetDiag
    // only has two inputs, so we have to check the number of inputs before
    // reading additional parameters in MatrixSetDiagV2.
    int64 lower_diag_index = 0;
    int64 upper_diag_index = 0;
    if (context->num_inputs() > kNumV1Inputs) {
      std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context);
    }

    // Checks if diag sizes are consistent with input.
    const int64 num_rows = input_shape.dim_size(input_rank - 2);
    const int64 num_cols = input_shape.dim_size(input_rank - 1);
    ValidateDiagIndexWithOutputMatrixSize(context, lower_diag_index,
                                          upper_diag_index, num_rows, num_cols);
    const Eigen::Index num_diags = upper_diag_index - lower_diag_index + 1;
    // NOTE(review): this reads diag_shape at index input_rank - 2, which
    // presumes diag_rank == input_rank whenever num_diags > 1; the full
    // expected-shape check below enforces that, but this access runs first —
    // confirm a too-low-rank `diag` cannot reach this line.
    OP_REQUIRES(
        context,
        lower_diag_index == upper_diag_index ||
            (diag_shape.dim_size(input_rank - 2) == num_diags),
        errors::InvalidArgument("The number of diagonals provided in `diag` "
                                "is not consistent with `lower_diag_index` and "
                                "`upper_diag_index`"));

    // Expected diag shape: [<batch>, num_diags (if > 1), max_diag_len].
    TensorShape expected_diag_shape = input_shape;
    expected_diag_shape.RemoveLastDims(2);
    if (num_diags > 1) expected_diag_shape.AddDim(num_diags);
    const int32 max_diag_len =
        std::min(num_rows + std::min(upper_diag_index, int64{0}),
                 num_cols - std::max(lower_diag_index, int64{0}));
    expected_diag_shape.AddDim(max_diag_len);
    OP_REQUIRES(
        context, expected_diag_shape == diag_shape,
        errors::InvalidArgument(
            "Either first dimensions of diagonal don't match input.shape[:-2], "
            "or diagonal.shape[:-1] is not equal to the longests diagonal in "
            "range [lower_diag_index:upper_diag_index].\nInput shape: ",
            input_shape.DebugString(),
            "\nDiagonal shape: ", diag_shape.DebugString(),
            "\nExpected diagonal shape: ", expected_diag_shape.DebugString()));

    // Actual processing: delegate the per-diagonal writes to SetMatrixDiag.
    xla::XlaOp input = context->Input(0);
    xla::XlaOp diag = context->Input(1);
    context->SetOutput(
        0, SetMatrixDiag(input, diag, input_shape, diag_rank, num_diags,
                         lower_diag_index, upper_diag_index, max_diag_len,
                         num_rows, num_cols, left_align_superdiagonal_,
                         left_align_subdiagonal_));
  }

 private:
  bool left_align_superdiagonal_ = true;
  bool left_align_subdiagonal_ = true;
  static constexpr int kNumV1Inputs = 2;
  TF_DISALLOW_COPY_AND_ASSIGN(MatrixSetDiagOp);
};
537
REGISTER_XLA_OP(Name("MatrixSetDiag"), MatrixSetDiagOp);
// V2/V3 take the diagonal band "k" as an extra input; it must be a
// compile-time constant because validation and diagonal slicing depend on
// its value at compile time.
REGISTER_XLA_OP(Name("MatrixSetDiagV2").CompileTimeConstantInput("k"),
                MatrixSetDiagOp);
REGISTER_XLA_OP(Name("MatrixSetDiagV3").CompileTimeConstantInput("k"),
                MatrixSetDiagOp);
543
544 } // namespace tensorflow
545