Searched refs:rhs_rank (Results 1 – 6 of 6) sorted by relevance
221 const int64 rhs_rank = rhs_shape.rank(); in CanonicalizeDot() local223 rhs_rank - num_batch_dims - num_contracting_dims; in CanonicalizeDot()228 for (int64 i = 0; i < rhs_rank; ++i) { in CanonicalizeDot()241 rhs_transpose.reserve(rhs_rank); in CanonicalizeDot()
1080 int64 rhs_rank = rhs->shape().rank(); in ComputeArrayForDotWithIndexedRhs() local1084 0, rhs->source_dim() == (rhs_rank - 1) ? (rhs_rank - 2) : (rhs_rank - 1)); in ComputeArrayForDotWithIndexedRhs()
1144 const int64 rhs_rank = rhs->shape().rank(); in HandleDotStrengthReduction() local1150 if (dot_rank > 2 && (lhs_rank != rhs_rank || lhs_rank != dot_rank)) { in HandleDotStrengthReduction()1166 int64 rhs_kept_dim = kept_dim(rhs_rank, rhs_collapsing_dim, in HandleDotStrengthReduction()1169 if (rhs_kept_dim == -1 && rhs_rank > 1) { in HandleDotStrengthReduction()1209 if (rhs_rank == 1 && lhs_rank == 1) { in HandleDotStrengthReduction()1225 if (rhs_rank == 2 && rhs->shape().dimensions(rhs_collapsing_dim) == 1) { in HandleDotStrengthReduction()1260 if (rhs_rank == 1 || in HandleDotStrengthReduction()1261 (rhs_rank == 2 && rhs->shape().dimensions(rhs_kept_dim) == 1)) { in HandleDotStrengthReduction()1275 CHECK_EQ(rhs_rank, lhs_rank); in HandleDotStrengthReduction()1299 (rhs_kept_dim == rhs_rank - 1 || in HandleDotStrengthReduction()[all …]
1032 const auto rhs_rank = rhs_shape.rank(); in HandleConvolution() local1035 CHECK_EQ(num_spatial_dims + 2, rhs_rank); in HandleConvolution()1224 const int64 rhs_rank = rhs->shape().rank(); in HandleDot() local1241 if (lhs_rank == 2 && rhs_rank == 2 && lhs_contracting_dimension == 1 && in HandleDot()1285 const auto rhs_rank = rhs->shape().rank(); in HandleDotSlowPath() local1297 DimensionVector rhs_index(rhs_rank); in HandleDotSlowPath()1305 (rhs_rank - dnums.rhs_contracting_dimensions_size())); in HandleDotSlowPath()1322 for (int64 i = 0; i < rhs_rank; i++) { in HandleDotSlowPath()
482 diags_rank, rhs_rank = len(diagonals.shape), len(rhs.shape)487 if rhs_rank != diags_rank and rhs_rank != diags_rank - 1:489 diags_rank - 1, diags_rank, rhs_rank))502 if rhs_rank == diags_rank - 1:
506 const int64 rhs_rank = rhs_shape.rank(); in BinaryOp() local511 if (!broadcast_dimensions.empty() && lhs_rank != rhs_rank) { in BinaryOp()512 const bool should_broadcast_lhs = lhs_rank < rhs_rank; in BinaryOp()