Searched refs:rhs_to_lhs_indices (Results 1 – 2 of 2) sorted by relevance
67 std::vector<int64> rhs_to_lhs_indices(output_base_shape.rank()); in PartitionConvolutionWithBatchGroupCount() local68 rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] = in PartitionConvolutionWithBatchGroupCount()70 rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] = in PartitionConvolutionWithBatchGroupCount()73 rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] = in PartitionConvolutionWithBatchGroupCount()79 for (int64 i = 0; i < rhs_to_lhs_indices.size(); ++i) { in PartitionConvolutionWithBatchGroupCount()80 lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i; in PartitionConvolutionWithBatchGroupCount()97 hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices); in PartitionConvolutionWithBatchGroupCount()158 std::vector<int64> rhs_to_lhs_indices(output_base_shape.rank()); in PartitionConvolutionWithFeatureGroupCount() local159 rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] = in PartitionConvolutionWithFeatureGroupCount()161 rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] = in PartitionConvolutionWithFeatureGroupCount()[all …]
107 std::vector<int64> rhs_to_lhs_indices; member344 std::vector<int64> rhs_to_lhs_indices(rhs_rank, -1); in ComputeDimensionIndexMapping() local355 rhs_to_lhs_indices[mapping.rhs] = mapping.lhs; in ComputeDimensionIndexMapping()379 rhs_to_lhs_indices, rhs_to_output_indices, in ComputeDimensionIndexMapping()497 indices_map.rhs_to_lhs_indices); in PartitionBaseCase()500 rhs_sharding, indices_map.rhs_to_lhs_indices, in PartitionBaseCase()738 ? indices_map.rhs_to_lhs_indices[slice_sharding_dim] in PartitionBaseCase()