Home
last modified time | relevance | path

Searched refs:rhs_grouped (Results 1 – 1 of 1) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/spmd/
Ddot_handler.cc1709 auto rhs_grouped = GroupShardingOnDims(rhs.sharding(), rhs_dims); in PartitionDotGroupOnBatch() local
1712 rhs_grouped = AlignGroupsWith(std::move(rhs_grouped), lhs_grouped); in PartitionDotGroupOnBatch()
1713 rhs = rhs.Reshard(UngroupSharding(rhs_grouped)); in PartitionDotGroupOnBatch()
1715 lhs_grouped = AlignGroupsWith(std::move(lhs_grouped), rhs_grouped); in PartitionDotGroupOnBatch()
1732 rhs.hlo()->set_sharding(rhs_grouped.sharding); in PartitionDotGroupOnBatch()
1734 lhs_grouped.sharding == rhs_grouped.sharding); in PartitionDotGroupOnBatch()
1739 rhs.hlo(), GetPerGroupBaseShape(rhs_grouped, rhs.base_shape()), in PartitionDotGroupOnBatch()
2131 auto rhs_grouped = GroupShardingOnDims(rhs_sharding, rhs_dims); in PartitionDotGroupOnContracting() local
2134 rhs_grouped = AlignGroupsWith(rhs_grouped, lhs_grouped); in PartitionDotGroupOnContracting()
2135 rhs_sharding = UngroupSharding(rhs_grouped); in PartitionDotGroupOnContracting()
[all …]