Home
last modified time | relevance | path

Searched refs:NumTiles (Results 1 – 8 of 8) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/
Dhlo_sharding.h282 int64 NumTiles() const;
285 int64 NumTiles(absl::Span<const int64> dims) const;
Dhlo_sharding.cc598 int64 HloSharding::NumTiles() const { in NumTiles() function in xla::HloSharding
610 int64 HloSharding::NumTiles(absl::Span<const int64> dims) const { in NumTiles() function in xla::HloSharding
Dhlo_sharding_util.cc58 return lhs.NumTiles() > rhs.NumTiles(); in IsShardingMoreSpecific()
1157 if (group_count == sharding.NumTiles()) { in PartiallyReplicateTiledShardingOnDims()
Dsharding_propagation.cc86 int64 sharding_tiles = sharding.NumTiles(); in MaybeImproveInstructionSharding()
94 sharding.NumTiles() == sharding_tiles) { in MaybeImproveInstructionSharding()
/external/tensorflow/tensorflow/compiler/xla/service/spmd/
Dgather_scatter_handler.cc47 return trivial_slice_dims_partitions == operand.sharding().NumTiles(); in GatherScatterOperandPartitionedOnlyOnTrivialSliceDims()
425 if (operand_sharding.NumTiles() == in PartitionIndexParallelDimensions()
426 operand_sharding.NumTiles(operand_parallel_dims) && in PartitionIndexParallelDimensions()
427 indices_sharding.NumTiles() == in PartitionIndexParallelDimensions()
428 indices_sharding.NumTiles(indices_parallel_dims)) { in PartitionIndexParallelDimensions()
Dspmd_partitioner_util.cc1289 source.NumTiles() != target.NumTiles()) { in GetReshardAllToAllSourceTargetDims()
1756 if (sharding.NumTiles() < device_groups.size() || device_groups.size() < 2 || in FindMatchingPartitionedDimsForGrouping()
1842 int idx_parallel_tiles_num = new_index_shard.NumTiles(indices_parallel_dims); in GatherOperandsShardedAcrossParallelDims()
1843 int op_parallel_tiles_num = new_operand_shard.NumTiles(operand_parallel_dims); in GatherOperandsShardedAcrossParallelDims()
Dspmd_partitioner.cc351 hlo_->shape().IsArray() && target.NumTiles() < sharding().NumTiles(); in Reshard()
Ddot_handler.cc2770 output_sharding.NumTiles(), create_sharded_dot, in PartitionDot()