Searched refs:output_grouped (Results 1 – 1 of 1) sorted by relevance
/external/tensorflow/tensorflow/compiler/xla/service/spmd/
  dot_handler.cc (all hits are in PartitionDotGroupOnBatch(); output_grouped is a local):
    1704  auto output_grouped = GroupShardingOnDims(output_sharding, output_dims);
    1720  output_grouped = AlignGroupsWith(
    1743  lhs.state(), output_grouped.device_groups, b);
    1756  output_grouped.device_groups, batch_dims,
    1757  output_grouped.group_dim_sizes, b);
    1812  output_grouped);
    1869  GetPerGroupBaseShape(output_grouped, output_base_shape),
    1870  output_grouped.sharding, dims_mapping,
    1871  num_partitions / output_grouped.device_groups.size(),
    1874  dot->set_sharding(UngroupSharding(output_grouped));
    [all …]
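The hit at line 1871 divides num_partitions by output_grouped.device_groups.size(), which gives the number of partitions available inside each device group. The Python sketch below illustrates what that grouping might look like. It is not XLA code. It assumes a simplified model in which a sharding is a row-major tile assignment of device ids, and devices that share the same tile index on every grouped dimension land in the same group. The function group_devices_on_dims and its parameters are hypothetical.

```python
from itertools import product
from math import prod


def group_devices_on_dims(tile_shape, tile_assignment, group_dims):
    """Hypothetical, simplified model of grouping a tiled sharding on group_dims.

    tile_shape:      dimensions of the device tile assignment, e.g. (2, 4).
    tile_assignment: flat row-major list of device ids, len == prod(tile_shape).
    group_dims:      tile dimensions to group on (e.g. the batch dims).

    Devices whose tile indices agree on every dim in group_dims share a group.
    Returns (device_groups, group_dim_sizes).
    """
    group_dim_sizes = [tile_shape[d] for d in group_dims]
    # One group per combination of indices along the grouped dims.
    device_groups = [[] for _ in range(prod(group_dim_sizes))]
    # itertools.product advances the last index fastest, matching row-major order.
    for flat, indices in enumerate(product(*(range(n) for n in tile_shape))):
        group_id = 0
        for d in group_dims:
            # Mixed-radix encoding of the indices along the grouped dims.
            group_id = group_id * tile_shape[d] + indices[d]
        device_groups[group_id].append(tile_assignment[flat])
    return device_groups, group_dim_sizes


if __name__ == "__main__":
    num_partitions = 8
    groups, sizes = group_devices_on_dims((2, 4), list(range(num_partitions)), [0])
    print(groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
    print(num_partitions // len(groups))  # 4 partitions per group
```

Grouping the same 2x4 tile assignment on dim 1 instead would give four groups of two devices each ([0, 4], [1, 5], [2, 6], [3, 7]).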