Home
last modified time | relevance | path

Searched refs:hlo_sharding_util (Results 1 – 12 of 12) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/
Dsharding_propagation.cc87 if (hlo_sharding_util::MergeSharding(instruction->sharding(), &sharding, in MaybeImproveInstructionSharding()
106 if (hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( in MaybeImproveInstructionSharding()
221 hlo_sharding_util::IsShardingMoreSpecific( in PickRepresentativeOperand()
362 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( in InferDotShardingFromOperands()
378 return *hlo_sharding_util::TransposeShardingWithCollapsedDims( in InferDotShardingFromOperands()
403 const hlo_sharding_util::GatherParallelDims& parallel_dims, in InferGatherParallelShardingFromOperands()
434 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( in InferGatherParallelShardingFromOperands()
455 hlo_sharding_util::GatherParallelOutputDims(*instruction, parallel_dims); in InferGatherParallelShardingFromOperands()
461 hlo_sharding_util::GatherOutputAlignedOperandParallelDims( in InferGatherParallelShardingFromOperands()
530 return hlo_sharding_util::TransposeSharding(lhs->sharding(), in InferConvolutionShardingFromOperands()
[all …]
Dhlo_sharding_util.h31 namespace hlo_sharding_util {
Dhlo_sharding_util_test.cc21 namespace hlo_sharding_util { namespace
Dhlo_sharding_util.cc35 namespace hlo_sharding_util { namespace
1243 auto tgt_sharding = hlo_sharding_util::TransposeSharding(source, perm); in TransposeShardingWithCollapsedDims()
DBUILD477 name = "hlo_sharding_util",
479 "hlo_sharding_util.cc",
482 "hlo_sharding_util.h",
505 ":hlo_sharding_util",
524 ":hlo_sharding_util",
/external/tensorflow/tensorflow/compiler/xla/service/spmd/
Dconvolution_handler.cc97 hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices); in PartitionConvolutionWithBatchGroupCount()
99 hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices); in PartitionConvolutionWithBatchGroupCount()
119 auto aligned_output_sharding = hlo_sharding_util::TransposeSharding( in PartitionConvolutionWithBatchGroupCount()
188 hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices); in PartitionConvolutionWithFeatureGroupCount()
190 hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices); in PartitionConvolutionWithFeatureGroupCount()
212 auto aligned_output_sharding = hlo_sharding_util::TransposeSharding( in PartitionConvolutionWithFeatureGroupCount()
254 hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices); in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
256 hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices); in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
562 auto aligned_rhs_sharding = hlo_sharding_util::ReverseSharding( in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS()
563 hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices), in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS()
[all …]
Dgather_scatter_handler.cc162 hlo_sharding_util::TransposeShardingWithCollapsedDims( in PartitionIndexOnlyPartition()
187 hlo_sharding_util::GatherOutputShardingFromDataOperand( in ParititonPassthroughOperand()
331 if (absl::optional<hlo_sharding_util::GatherParallelDims> parallel_dims = in PartitionIndexParallelDimensions()
332 hlo_sharding_util::GetGatherBatchParallelDims(*gather)) { in PartitionIndexParallelDimensions()
338 hlo_sharding_util::GatherParallelOutputDims(*gather, *parallel_dims); in PartitionIndexParallelDimensions()
573 hlo_sharding_util::TransposeShardingWithCollapsedDims( in HandleScatter()
617 auto maybe_passthrough = hlo_sharding_util::ScatterUpdateShardingFromOutput( in HandleScatter()
DBUILD52 "//tensorflow/compiler/xla/service:hlo_sharding_util",
Ddot_handler.cc495 hlo_sharding_util::TransposeShardingWithCollapsedDims( in PartitionBaseCase()
499 hlo_sharding_util::TransposeShardingWithCollapsedDims( in PartitionBaseCase()
503 hlo_sharding_util::TransposeShardingWithCollapsedDims( in PartitionBaseCase()
507 hlo_sharding_util::TransposeShardingWithCollapsedDims( in PartitionBaseCase()
511 hlo_sharding_util::TransposeShardingWithCollapsedDims( in PartitionBaseCase()
515 hlo_sharding_util::TransposeShardingWithCollapsedDims( in PartitionBaseCase()
831 auto padded_slice_sharding = hlo_sharding_util::ReshapeSharding( in PartitionBaseCase()
2054 .Reshard(hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( in PartitionDotGroupOnNonContracting()
2216 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( in PartitionDotGroupOnContracting()
2386 hlo_sharding_util::TransposeShardingWithCollapsedDims( in LhsIsBestMatchForNonContractingPartitioning()
[all …]
Dspmd_partitioner.cc1129 temp_target_tile = hlo_sharding_util::TransposeSharding( in ReshardWithAllToAll()
1327 if (hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( in ReshardWithCollectivePermute()
1329 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( in ReshardWithCollectivePermute()
1957 hlo_sharding_util::TransposeSharding(sharding, inverse_dimensions); in HandleTranspose()
1978 auto desired_operand_sharding = hlo_sharding_util::ReshapeSharding( in HandleReshape()
2251 auto desired_input_sharding = hlo_sharding_util::RemoveShapeDimensions( in HandleBroadcast()
2252 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(hlo->sharding(), in HandleBroadcast()
2368 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(dus_sharding, in HandleDynamicUpdateSlice()
2800 auto sharding = hlo_sharding_util::RemoveShapeDimensions( in HandleReduce()
2801 hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( in HandleReduce()
[all …]
Dspmd_partitioner_util.h418 const hlo_sharding_util::GatherParallelDims& parallel_dims);
Dspmd_partitioner_util.cc380 auto transpose_sharding = hlo_sharding_util::TransposeSharding( in PartialReplicateReshardCompatibleSharding()
1834 const hlo_sharding_util::GatherParallelDims& parallel_dims) { in GatherOperandsShardedAcrossParallelDims()