Home
last modified time | relevance | path

Searched refs:parallel_dims (Results 1 – 7 of 7) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/
Dhlo_sharding_util.cc896 const GatherParallelDims& parallel_dims) { in GatherParallelDataOperandSharding() argument
900 auto output_parallel_dims = GatherParallelOutputDims(gather, parallel_dims); in GatherParallelDataOperandSharding()
902 GatherOutputAlignedOperandParallelDims(gather, parallel_dims); in GatherParallelDataOperandSharding()
972 auto parallel_dims = GetGatherBatchParallelDims(hlo); in GatherDataOperandShardingFromOutput() local
974 if (parallel_dims) { in GatherDataOperandShardingFromOutput()
978 GatherParallelDataOperandSharding(hlo.sharding(), hlo, *parallel_dims); in GatherDataOperandShardingFromOutput()
979 operand_parallel_dims = parallel_dims->operand_parallel_dims; in GatherDataOperandShardingFromOutput()
1358 const HloInstruction& gather, const GatherParallelDims& parallel_dims) { in GatherOutputAlignedOperandParallelDims() argument
1360 parallel_dims.operand_parallel_dims.size(), -1); in GatherOutputAlignedOperandParallelDims()
1362 CHECK_LE(parallel_dims.indices_parallel_dims.size(), in GatherOutputAlignedOperandParallelDims()
[all …]
Dhlo_sharding_util.h218 const HloInstruction& gather, const GatherParallelDims& parallel_dims);
Dsharding_propagation.cc403 const hlo_sharding_util::GatherParallelDims& parallel_dims, in InferGatherParallelShardingFromOperands() argument
455 hlo_sharding_util::GatherParallelOutputDims(*instruction, parallel_dims); in InferGatherParallelShardingFromOperands()
462 *instruction, parallel_dims)), in InferGatherParallelShardingFromOperands()
469 absl::MakeConstSpan(parallel_dims.indices_parallel_dims), in InferGatherParallelShardingFromOperands()
/external/tensorflow/tensorflow/compiler/xla/service/spmd/
Dgather_scatter_handler.cc331 if (absl::optional<hlo_sharding_util::GatherParallelDims> parallel_dims = in PartitionIndexParallelDimensions() local
334 *operand.hlo(), *indices.hlo(), *parallel_dims)) { in PartitionIndexParallelDimensions()
335 auto indices_parallel_dims = parallel_dims->indices_parallel_dims; in PartitionIndexParallelDimensions()
336 auto operand_parallel_dims = parallel_dims->operand_parallel_dims; in PartitionIndexParallelDimensions()
338 hlo_sharding_util::GatherParallelOutputDims(*gather, *parallel_dims); in PartitionIndexParallelDimensions()
Dspmd_partitioner_util.h418 const hlo_sharding_util::GatherParallelDims& parallel_dims);
Dspmd_partitioner_util.cc1834 const hlo_sharding_util::GatherParallelDims& parallel_dims) { in GatherOperandsShardedAcrossParallelDims() argument
1835 auto& indices_parallel_dims = parallel_dims.indices_parallel_dims; in GatherOperandsShardedAcrossParallelDims()
1836 auto& operand_parallel_dims = parallel_dims.operand_parallel_dims; in GatherOperandsShardedAcrossParallelDims()
1848 for (int idx : parallel_dims.index_parallel_in_dim) { in GatherOperandsShardedAcrossParallelDims()
/external/tensorflow/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/
Dlegalize_to_linalg.cc1277 SmallVector<unsigned, 4> parallel_dims; in GetReduceOpInitTensorDynSizes() local