Searched refs:index_dim (Results 1 – 2 of 2) sorted by relevance
507 for (int64 i = 0, index_dim = 0; i < hlo->shape().rank(); ++i) { in GatherOutputSharding() local512 index_dim >= dnums.index_vector_dim() ? index_dim + 1 : index_dim; in GatherOutputSharding()515 ++index_dim; in GatherOutputSharding()680 for (int64 i = 0, index_dim = 0; i < hlo->shape().rank(); ++i) { in ScatterDataSharding() local685 index_sharding.tile_assignment().dim(index_dim)); in ScatterDataSharding()686 index_dim++; in ScatterDataSharding()715 int64 index_dim = 0; in ScatterEffectiveIndexSharding() local718 num_elements *= index_sharding.tile_assignment().dim(index_dim); in ScatterEffectiveIndexSharding()719 index_dim++; in ScatterEffectiveIndexSharding()739 if (i < index_dim) { in ScatterEffectiveIndexSharding()[all …]
345 int index_dim = dnums.index_vector_dim(); in PartitionIndexParallelDimensions() local382 indices.base_shape().dimensions_size() > index_dim in PartitionIndexParallelDimensions()390 if (indices.base_shape().dimensions_size() > index_dim) { in PartitionIndexParallelDimensions()394 {indices.base_shape().dimensions(index_dim)}), in PartitionIndexParallelDimensions()411 indices.hlo()->shape(), adjusted_indices, {index_dim})); in PartitionIndexParallelDimensions()