Home
last modified time | relevance | path

Searched defs:parallel_dims (Results 1 – 5 of 5) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/
Dhlo_sharding_util.cc896 const GatherParallelDims& parallel_dims) { in GatherParallelDataOperandSharding()
972 auto parallel_dims = GetGatherBatchParallelDims(hlo); in GatherDataOperandShardingFromOutput() local
1358 const HloInstruction& gather, const GatherParallelDims& parallel_dims) { in GatherOutputAlignedOperandParallelDims()
Dsharding_propagation.cc403 const hlo_sharding_util::GatherParallelDims& parallel_dims, in InferGatherParallelShardingFromOperands()
/external/tensorflow/tensorflow/compiler/xla/service/spmd/
Dgather_scatter_handler.cc331 if (absl::optional<hlo_sharding_util::GatherParallelDims> parallel_dims = in PartitionIndexParallelDimensions() local
Dspmd_partitioner_util.cc1834 const hlo_sharding_util::GatherParallelDims& parallel_dims) { in GatherOperandsShardedAcrossParallelDims()
/external/tensorflow/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/
Dlegalize_to_linalg.cc1277 SmallVector<unsigned, 4> parallel_dims; in GetReduceOpInitTensorDynSizes() local