Searched refs:hlo_sharding_util (Results 1 – 12 of 12) sorted by relevance
/external/tensorflow/tensorflow/compiler/xla/service/ |
D | sharding_propagation.cc | 87 if (hlo_sharding_util::MergeSharding(instruction->sharding(), &sharding, in MaybeImproveInstructionSharding() 106 if (hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( in MaybeImproveInstructionSharding() 221 hlo_sharding_util::IsShardingMoreSpecific( in PickRepresentativeOperand() 362 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( in InferDotShardingFromOperands() 378 return *hlo_sharding_util::TransposeShardingWithCollapsedDims( in InferDotShardingFromOperands() 403 const hlo_sharding_util::GatherParallelDims& parallel_dims, in InferGatherParallelShardingFromOperands() 434 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( in InferGatherParallelShardingFromOperands() 455 hlo_sharding_util::GatherParallelOutputDims(*instruction, parallel_dims); in InferGatherParallelShardingFromOperands() 461 hlo_sharding_util::GatherOutputAlignedOperandParallelDims( in InferGatherParallelShardingFromOperands() 530 return hlo_sharding_util::TransposeSharding(lhs->sharding(), in InferConvolutionShardingFromOperands() [all …]
|
D | hlo_sharding_util.h | 31 namespace hlo_sharding_util {
|
D | hlo_sharding_util_test.cc | 21 namespace hlo_sharding_util { namespace
|
D | hlo_sharding_util.cc | 35 namespace hlo_sharding_util { namespace 1243 auto tgt_sharding = hlo_sharding_util::TransposeSharding(source, perm); in TransposeShardingWithCollapsedDims()
|
D | BUILD | 477 name = "hlo_sharding_util", 479 "hlo_sharding_util.cc", 482 "hlo_sharding_util.h", 505 ":hlo_sharding_util", 524 ":hlo_sharding_util",
|
/external/tensorflow/tensorflow/compiler/xla/service/spmd/ |
D | convolution_handler.cc | 97 hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices); in PartitionConvolutionWithBatchGroupCount() 99 hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices); in PartitionConvolutionWithBatchGroupCount() 119 auto aligned_output_sharding = hlo_sharding_util::TransposeSharding( in PartitionConvolutionWithBatchGroupCount() 188 hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices); in PartitionConvolutionWithFeatureGroupCount() 190 hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices); in PartitionConvolutionWithFeatureGroupCount() 212 auto aligned_output_sharding = hlo_sharding_util::TransposeSharding( in PartitionConvolutionWithFeatureGroupCount() 254 hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices); in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS() 256 hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices); in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS() 562 auto aligned_rhs_sharding = hlo_sharding_util::ReverseSharding( in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS() 563 hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices), in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS() [all …]
|
D | gather_scatter_handler.cc | 162 hlo_sharding_util::TransposeShardingWithCollapsedDims( in PartitionIndexOnlyPartition() 187 hlo_sharding_util::GatherOutputShardingFromDataOperand( in ParititonPassthroughOperand() 331 if (absl::optional<hlo_sharding_util::GatherParallelDims> parallel_dims = in PartitionIndexParallelDimensions() 332 hlo_sharding_util::GetGatherBatchParallelDims(*gather)) { in PartitionIndexParallelDimensions() 338 hlo_sharding_util::GatherParallelOutputDims(*gather, *parallel_dims); in PartitionIndexParallelDimensions() 573 hlo_sharding_util::TransposeShardingWithCollapsedDims( in HandleScatter() 617 auto maybe_passthrough = hlo_sharding_util::ScatterUpdateShardingFromOutput( in HandleScatter()
|
D | BUILD | 52 "//tensorflow/compiler/xla/service:hlo_sharding_util",
|
D | dot_handler.cc | 495 hlo_sharding_util::TransposeShardingWithCollapsedDims( in PartitionBaseCase() 499 hlo_sharding_util::TransposeShardingWithCollapsedDims( in PartitionBaseCase() 503 hlo_sharding_util::TransposeShardingWithCollapsedDims( in PartitionBaseCase() 507 hlo_sharding_util::TransposeShardingWithCollapsedDims( in PartitionBaseCase() 511 hlo_sharding_util::TransposeShardingWithCollapsedDims( in PartitionBaseCase() 515 hlo_sharding_util::TransposeShardingWithCollapsedDims( in PartitionBaseCase() 831 auto padded_slice_sharding = hlo_sharding_util::ReshapeSharding( in PartitionBaseCase() 2054 .Reshard(hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( in PartitionDotGroupOnNonContracting() 2216 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( in PartitionDotGroupOnContracting() 2386 hlo_sharding_util::TransposeShardingWithCollapsedDims( in LhsIsBestMatchForNonContractingPartitioning() [all …]
|
D | spmd_partitioner.cc | 1129 temp_target_tile = hlo_sharding_util::TransposeSharding( in ReshardWithAllToAll() 1327 if (hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( in ReshardWithCollectivePermute() 1329 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( in ReshardWithCollectivePermute() 1957 hlo_sharding_util::TransposeSharding(sharding, inverse_dimensions); in HandleTranspose() 1978 auto desired_operand_sharding = hlo_sharding_util::ReshapeSharding( in HandleReshape() 2251 auto desired_input_sharding = hlo_sharding_util::RemoveShapeDimensions( in HandleBroadcast() 2252 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(hlo->sharding(), in HandleBroadcast() 2368 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(dus_sharding, in HandleDynamicUpdateSlice() 2800 auto sharding = hlo_sharding_util::RemoveShapeDimensions( in HandleReduce() 2801 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( in HandleReduce() [all …]
|
D | spmd_partitioner_util.h | 418 const hlo_sharding_util::GatherParallelDims& parallel_dims);
|
D | spmd_partitioner_util.cc | 380 auto transpose_sharding = hlo_sharding_util::TransposeSharding( in PartialReplicateReshardCompatibleSharding() 1834 const hlo_sharding_util::GatherParallelDims& parallel_dims) { in GatherOperandsShardedAcrossParallelDims()
|