Searched refs:lhs_to_rhs_indices (Results 1 – 2 of 2) sorted by relevance
78 std::vector<int64> lhs_to_rhs_indices(output_base_shape.rank()); in PartitionConvolutionWithBatchGroupCount() local80 lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i; in PartitionConvolutionWithBatchGroupCount()99 hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices); in PartitionConvolutionWithBatchGroupCount()169 std::vector<int64> lhs_to_rhs_indices(output_base_shape.rank()); in PartitionConvolutionWithFeatureGroupCount() local171 lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i; in PartitionConvolutionWithFeatureGroupCount()190 hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices); in PartitionConvolutionWithFeatureGroupCount()249 std::vector<int64> lhs_to_rhs_indices(output_base_shape.rank()); in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS() local251 lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i; in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()256 hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices); in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()539 std::vector<int64> lhs_to_rhs_indices(output_base_shape.rank()); in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS() local[all …]
105 std::vector<int64> lhs_to_rhs_indices; member342 std::vector<int64> lhs_to_rhs_indices(lhs_rank, -1); in ComputeDimensionIndexMapping() local351 lhs_to_rhs_indices[mapping.lhs] = mapping.rhs; in ComputeDimensionIndexMapping()378 return DotDimensionIndexMapping{lhs_to_rhs_indices, lhs_to_output_indices, in ComputeDimensionIndexMapping()496 lhs_sharding, indices_map.lhs_to_rhs_indices, in PartitionBaseCase()501 indices_map.lhs_to_rhs_indices); in PartitionBaseCase()743 : indices_map.lhs_to_rhs_indices[slice_sharding_dim]; in PartitionBaseCase()