Searched refs:dst_shard_count (Results 1 – 1 of 1) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/spmd/
spmd_partitioner_util.cc
430 int64 dst_shard_count = dst_sharding.tile_assignment().dim(dim); in TileToPartialReplicateHaloExchange() local
432 padded_src_shape.dimensions(dim) / dst_shard_count; in TileToPartialReplicateHaloExchange()
436 padded_dst_shape.dimensions(dim) / dst_shard_count; in TileToPartialReplicateHaloExchange()
439 if (src_per_shard_size <= dst_per_shard_size || dst_shard_count == 1) { in TileToPartialReplicateHaloExchange()
486 (src_per_shard_size - dst_per_shard_size) * (dst_shard_count - 1), in TileToPartialReplicateHaloExchange()
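The hits above expose the per-dimension size check in TileToPartialReplicateHaloExchange(). What follows is a minimal standalone sketch of only that arithmetic. It is not the XLA implementation. It assumes that the declarations cut off at lines 431 and 435 (not shown in the results) are src_per_shard_size and dst_per_shard_size. The function name HaloAmountSketch and its parameters are hypothetical stand-ins for padded_src_shape.dimensions(dim), padded_dst_shape.dimensions(dim) and dst_shard_count. Because the call that receives the line-486 expression is not visible here, the sketch simply returns that value.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical helper mirroring the arithmetic visible in the search hits
// for TileToPartialReplicateHaloExchange(); NOT the XLA implementation.
// padded_src_dim / padded_dst_dim stand in for
// padded_src_shape.dimensions(dim) / padded_dst_shape.dimensions(dim).
// Returns std::nullopt when the line-439 condition says there is nothing
// to do along this dimension; otherwise returns the line-486 expression.
std::optional<int64_t> HaloAmountSketch(int64_t padded_src_dim,
                                        int64_t padded_dst_dim,
                                        int64_t dst_shard_count) {
  assert(dst_shard_count > 0);  // shard counts come from a tile assignment

  // Lines 432 and 436: both padded sizes are divided by dst_shard_count.
  // (Assumption: these initialize src_per_shard_size / dst_per_shard_size.)
  const int64_t src_per_shard_size = padded_src_dim / dst_shard_count;
  const int64_t dst_per_shard_size = padded_dst_dim / dst_shard_count;

  // Line 439: no exchange if the source shard is not larger than the
  // destination shard, or if the destination has a single shard.
  if (src_per_shard_size <= dst_per_shard_size || dst_shard_count == 1) {
    return std::nullopt;
  }

  // Line 486: the excess per shard, times (dst_shard_count - 1).
  return (src_per_shard_size - dst_per_shard_size) * (dst_shard_count - 1);
}
```

For example, padded sizes of 32 and 16 with dst_shard_count = 4 give per-shard sizes 8 and 4, so the sketch returns (8 - 4) * 3 = 12; with dst_shard_count = 1 the early-exit branch is taken and it returns std::nullopt.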