Searched refs:indices_sharding (Results 1 – 2 of 2) sorted by relevance
339 HloSharding indices_sharding = gather_sharding->indices_sharding; in PartitionIndexParallelDimensions() local342 GroupShardingOnDims(indices_sharding, indices_parallel_dims); in PartitionIndexParallelDimensions()354 indices_sharding.tile_assignment().dim(indices_idx); in PartitionIndexParallelDimensions()357 indices = indices.Reshard(indices_sharding); in PartitionIndexParallelDimensions()358 if (indices_sharding.ReplicateOnLastTileDim()) { in PartitionIndexParallelDimensions()360 indices_sharding.tile_assignment().dimensions().back()); in PartitionIndexParallelDimensions()362 Array<int64> output_tile_assignment = indices_sharding.tile_assignment(); in PartitionIndexParallelDimensions()366 indices_sharding.ReplicateOnLastTileDim() in PartitionIndexParallelDimensions()427 indices_sharding.NumTiles() == in PartitionIndexParallelDimensions()428 indices_sharding.NumTiles(indices_parallel_dims)) { in PartitionIndexParallelDimensions()
35 HloSharding indices_sharding; member