Searched refs:diag_rank (Results 1 – 2 of 2) sorted by relevance
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/ |
D | matrix_diag_ops.cc | 118 const TensorShape& input_shape, const int64 diag_rank, in SetMatrixDiag() argument 171 diag_rank - 2), in SetMatrixDiag() 172 {diag_rank - 2, diag_rank - 1}); in SetMatrixDiag() 273 const int64 diag_rank = diag_shape.dims(); in Compile() local 274 const int64 max_diag_len = diag_shape.dim_size(diag_rank - 1); in Compile() 278 num_diags == 1 || num_diags == diag_shape.dim_size(diag_rank - 2), in Compile() 318 0, SetMatrixDiag(output, diag, output_shape, diag_rank, num_diags, in Compile() 469 const int diag_rank = diag_shape.dims(); in Compile() local 525 0, SetMatrixDiag(input, diag, input_shape, diag_rank, num_diags, in Compile()
|
/external/tensorflow/tensorflow/core/kernels/linalg/ |
D | matrix_diag_op.cc | 202 const int diag_rank = diagonal_shape.dims(); in Compute() local 215 diagonal_shape.dim_size(diag_rank - 2) == num_diags, in Compute() 220 const Eigen::Index max_diag_len = diagonal_shape.dim_size(diag_rank - 1); in Compute() 245 output_shape.set_dim(diag_rank - 1, num_rows); in Compute() 248 output_shape.set_dim(diag_rank - 2, num_rows); in Compute() 249 output_shape.set_dim(diag_rank - 1, num_cols); in Compute()
|