Searched refs:GetGatherBatchParallelDims (Results 1 – 4 of 4) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/
  hlo_sharding_util.h:206     absl::optional<GatherParallelDims> GetGatherBatchParallelDims(
  hlo_sharding_util.cc:972    auto parallel_dims = GetGatherBatchParallelDims(hlo);  (in GatherDataOperandShardingFromOutput())
  hlo_sharding_util.cc:1257   absl::optional<GatherParallelDims> GetGatherBatchParallelDims(  (in GetGatherBatchParallelDims() function)
  sharding_propagation.cc:959 hlo_sharding_util::GetGatherBatchParallelDims(*instruction);  (in InferShardingFromOperands())
/external/tensorflow/tensorflow/compiler/xla/service/spmd/
  gather_scatter_handler.cc:332  hlo_sharding_util::GetGatherBatchParallelDims(*gather)) {  (in PartitionIndexParallelDimensions())