/external/tensorflow/tensorflow/compiler/xla/service/spmd/ |
D | fft_handler.cc | 54 const SPMDCollectiveOpsCreator& collective_ops_creator, in PadEachPartitionWithHaloExchange() argument 81 hlo->shape().rank() - 1, sharding, collective_ops_creator, in PadEachPartitionWithHaloExchange() 161 const SPMDCollectiveOpsCreator& collective_ops_creator, in ShuffleDataWithAllToAll() argument 167 auto all_to_all = collective_ops_creator.create_cross_partition_all_to_all( in ShuffleDataWithAllToAll() 229 const SPMDCollectiveOpsCreator& collective_ops_creator, in GetFinalFftUsingCollectivePermute() argument 294 collective_ops_creator.create_cross_partition_collective_permute( in GetFinalFftUsingCollectivePermute() 298 collective_ops_creator.create_cross_partition_collective_permute( in GetFinalFftUsingCollectivePermute() 385 partitioned_input.state().collective_ops_creator, in HandleFft() 403 result, num_partitions_, partitioned_input.state().collective_ops_creator, in HandleFft() 423 result, hlo->sharding(), partitioned_input.state().collective_ops_creator, in HandleFft()
|
D | spmd_partitioner_util.h | 221 const SPMDCollectiveOpsCreator& collective_ops_creator, 231 const SPMDCollectiveOpsCreator& collective_ops_creator, 260 const SPMDCollectiveOpsCreator& collective_ops_creator, 370 const SPMDCollectiveOpsCreator& collective_ops_creator, 393 const SPMDCollectiveOpsCreator& collective_ops_creator,
|
D | spmd_partitioner.h | 188 SPMDCollectiveOpsCreator collective_ops_creator) in SpmdPartitioner() argument 192 collective_ops_creator_(std::move(collective_ops_creator)) {} in SpmdPartitioner() 228 const SPMDCollectiveOpsCreator& collective_ops_creator, 288 SPMDCollectiveOpsCreator collective_ops_creator; member 404 const SPMDCollectiveOpsCreator& collective_ops_creator, 488 state.collective_ops_creator = collective_ops_creator_; in MakePartitioningState()
|
D | spmd_partitioner_util.cc | 412 const SPMDCollectiveOpsCreator& collective_ops_creator, in TileToPartialReplicateHaloExchange() argument 468 src_sharding, collective_ops_creator, next_channel_id, b); in TileToPartialReplicateHaloExchange() 506 const SPMDCollectiveOpsCreator& collective_ops_creator, in PadFromPartialReplicateShape() argument 568 src_sharding, collective_ops_creator, next_channel_id, b); in PadFromPartialReplicateShape() 802 const SPMDCollectiveOpsCreator& collective_ops_creator, in ExchangeHalo() argument 845 collective_ops_creator.create_cross_partition_collective_permute( in ExchangeHalo() 878 collective_ops_creator.create_cross_partition_collective_permute( in ExchangeHalo() 904 const SPMDCollectiveOpsCreator& collective_ops_creator, in ExchangeHalo() argument 913 collective_ops_creator, next_channel_id, b); in ExchangeHalo() 930 const SPMDCollectiveOpsCreator& collective_ops_creator, in ExchangeHaloAndGetValidData() argument [all …]
|
D | convolution_handler.cc | 317 const auto& collective_ops_creator = lhs.state().collective_ops_creator; in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS() local 490 partition_ordinals[dim], collective_ops_creator, in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS() 502 auto ar = collective_ops_creator.create_cross_partition_all_reduce( in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS() 717 partition_ordinals[dim], lhs.state().collective_ops_creator, in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS() 729 lhs.state().collective_ops_creator.create_cross_partition_all_reduce( in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS()
|
D | dot_handler.cc | 648 .collective_ops_creator.create_cross_partition_collective_permute( in PartitionBaseCase() 667 lhs.state().collective_ops_creator.create_partition_id(&body_b); in PartitionBaseCase() 1059 lhs.state().collective_ops_creator.create_partition_id(&body_b); in PartitionBaseCase() 1166 .collective_ops_creator.create_cross_partition_collective_permute( in PartitionBaseCase() 1179 .collective_ops_creator.create_cross_partition_collective_permute( in PartitionBaseCase() 1201 .collective_ops_creator.create_cross_partition_collective_permute( in PartitionBaseCase() 1214 .collective_ops_creator.create_cross_partition_collective_permute( in PartitionBaseCase() 1241 .collective_ops_creator in PartitionBaseCase() 1250 .collective_ops_creator in PartitionBaseCase() 1281 .collective_ops_creator in PartitionBaseCase() [all …]
|
D | spmd_partitioner.cc | 809 state_.collective_ops_creator, state_.next_channel_id, state_.b, in ReshardAsWindowedInput() 878 if (state_.collective_ops_creator.create_cross_partition_all_gather) { in ReplicatePartial() 881 state_.collective_ops_creator); in ReplicatePartial() 897 state_.collective_ops_creator, reduction); in ReplicatePartial() 948 partitioned_hlo.state().collective_ops_creator, in ReshardToPartialReplicateWithAllGather() 1018 state_.collective_ops_creator, state_.next_channel_id, in ReshardFromPartialReplicateWithDynamicSlice() 1080 auto result = state_.collective_ops_creator.create_cross_partition_all_reduce( in Broadcast() 1185 state_.collective_ops_creator.create_cross_partition_all_to_all( in ReshardWithAllToAll() 1345 state_.collective_ops_creator.create_cross_partition_collective_permute( in ReshardWithCollectivePermute() 1353 const SPMDCollectiveOpsCreator& collective_ops_creator, in SpmdPartitioningVisitor() argument [all …]
|
D | gather_scatter_handler.cc | 297 all_dims, operand.state().collective_ops_creator, in ParititonTrivialIndexedOperandDimension()
|
D | spmd_partitioner_test.cc | 49 auto collective_ops_creator = in PartitionComputation() local 53 collective_ops_creator.create_cross_partition_all_gather = nullptr; in PartitionComputation() 61 collective_ops_creator); in PartitionComputation()
|