Searched refs:replicate_dims (Results 1 – 3 of 3) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/spmd/
spmd_partitioner_util.h
  392  const std::vector<int64>& replicate_dims,

spmd_partitioner.cc
  933  std::vector<int64> replicate_dims;  [local, in ReshardToPartialReplicateWithAllGather()]
  939  replicate_dims.emplace_back(dim);  [in ReshardToPartialReplicateWithAllGather()]
  947  partitioned_hlo.hlo_, base_shape_, temp_sharding, target, replicate_dims,  [in ReshardToPartialReplicateWithAllGather()]
  957  GroupShardingOnDims(temp_sharding, replicate_dims, replicate_factors);  [in ReshardToPartialReplicateWithAllGather()]
  969  partial_replicate_hlo.ReplicatePartial(replicate_dims);  [in ReshardToPartialReplicateWithAllGather()]

spmd_partitioner_util.cc
  411  const std::vector<int64>& replicate_dims,  [argument, in TileToPartialReplicateHaloExchange()]
  429  for (auto dim : replicate_dims) {  [in TileToPartialReplicateHaloExchange()]
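In spmd_partitioner.cc, ReshardToPartialReplicateWithAllGather() builds replicate_dims in a loop (line 939) and later passes it, together with replicate_factors, to GroupShardingOnDims (line 957). The loop body is not shown in these results, so the following is only a minimal sketch under an assumption: each entry of replicate_dims is a dimension whose tile count must shrink when moving from the current sharding to the partially replicated target, and its replicate factor is the ratio of the two tile counts. ComputeReplicateDims and ReplicatePlan are hypothetical names, not part of XLA; the real code works on HloSharding objects rather than plain vectors and uses XLA's int64 alias instead of int64_t.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical result type (not an XLA API): dimensions to gather and the
// per-dimension replication factor for each of them.
struct ReplicatePlan {
  std::vector<int64_t> replicate_dims;
  std::vector<int64_t> replicate_factors;
};

// Hypothetical helper. Assumption: current_tiles[d] and target_tiles[d] are the
// number of tiles along data dimension d, and the target count divides the
// current count. A dimension is collected when its tile count shrinks.
ReplicatePlan ComputeReplicateDims(const std::vector<int64_t>& current_tiles,
                                   const std::vector<int64_t>& target_tiles) {
  assert(current_tiles.size() == target_tiles.size());
  ReplicatePlan plan;
  for (std::size_t dim = 0; dim < current_tiles.size(); ++dim) {
    assert(target_tiles[dim] > 0 && current_tiles[dim] % target_tiles[dim] == 0);
    int64_t factor = current_tiles[dim] / target_tiles[dim];
    if (factor > 1) {
      // This dimension must be (partially) replicated by a factor of `factor`.
      plan.replicate_dims.emplace_back(static_cast<int64_t>(dim));
      plan.replicate_factors.emplace_back(factor);
    }
  }
  return plan;
}
```

For example, under these assumptions, going from 4×2 tiles to 2×2 tiles yields replicate_dims = {0} with a factor of 2, meaning each group of two devices along dimension 0 would combine its shards.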