Home
last modified time | relevance | path

Searched refs:lhs_grouped (Results 1 – 1 of 1) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/spmd/
dot_handler.cc 1708 auto lhs_grouped = GroupShardingOnDims(lhs.sharding(), lhs_dims); in PartitionDotGroupOnBatch() local
1712 rhs_grouped = AlignGroupsWith(std::move(rhs_grouped), lhs_grouped); in PartitionDotGroupOnBatch()
1715 lhs_grouped = AlignGroupsWith(std::move(lhs_grouped), rhs_grouped); in PartitionDotGroupOnBatch()
1716 lhs = lhs.Reshard(UngroupSharding(lhs_grouped)); in PartitionDotGroupOnBatch()
1726 lhs_grouped); in PartitionDotGroupOnBatch()
1728 lhs.state(), lhs_grouped.device_groups, b); in PartitionDotGroupOnBatch()
1730 lhs.hlo()->set_sharding(lhs_grouped.sharding); in PartitionDotGroupOnBatch()
1734 lhs_grouped.sharding == rhs_grouped.sharding); in PartitionDotGroupOnBatch()
1736 lhs.hlo(), GetPerGroupBaseShape(lhs_grouped, lhs.base_shape()), in PartitionDotGroupOnBatch()
2130 auto lhs_grouped = GroupShardingOnDims(lhs_sharding, lhs_dims); in PartitionDotGroupOnContracting() local
[all …]