Home
last modified time | relevance | path

Searched refs:indices_sharding (Results 1 – 2 of 2) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/spmd/
D gather_scatter_handler.cc
339 HloSharding indices_sharding = gather_sharding->indices_sharding; in PartitionIndexParallelDimensions() local
342 GroupShardingOnDims(indices_sharding, indices_parallel_dims); in PartitionIndexParallelDimensions()
354 indices_sharding.tile_assignment().dim(indices_idx); in PartitionIndexParallelDimensions()
357 indices = indices.Reshard(indices_sharding); in PartitionIndexParallelDimensions()
358 if (indices_sharding.ReplicateOnLastTileDim()) { in PartitionIndexParallelDimensions()
360 indices_sharding.tile_assignment().dimensions().back()); in PartitionIndexParallelDimensions()
362 Array<int64> output_tile_assignment = indices_sharding.tile_assignment(); in PartitionIndexParallelDimensions()
366 indices_sharding.ReplicateOnLastTileDim() in PartitionIndexParallelDimensions()
427 indices_sharding.NumTiles() == in PartitionIndexParallelDimensions()
428 indices_sharding.NumTiles(indices_parallel_dims)) { in PartitionIndexParallelDimensions()
D spmd_partitioner_util.h
35 HloSharding indices_sharding; member