Home
last modified time | relevance | path

Searched refs:index_dim (Results 1 – 2 of 2) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/
hlo_sharding_util.cc:507 for (int64 i = 0, index_dim = 0; i < hlo->shape().rank(); ++i) { in GatherOutputSharding() local
512 index_dim >= dnums.index_vector_dim() ? index_dim + 1 : index_dim; in GatherOutputSharding()
515 ++index_dim; in GatherOutputSharding()
680 for (int64 i = 0, index_dim = 0; i < hlo->shape().rank(); ++i) { in ScatterDataSharding() local
685 index_sharding.tile_assignment().dim(index_dim)); in ScatterDataSharding()
686 index_dim++; in ScatterDataSharding()
715 int64 index_dim = 0; in ScatterEffectiveIndexSharding() local
718 num_elements *= index_sharding.tile_assignment().dim(index_dim); in ScatterEffectiveIndexSharding()
719 index_dim++; in ScatterEffectiveIndexSharding()
739 if (i < index_dim) { in ScatterEffectiveIndexSharding()
[all …]
/external/tensorflow/tensorflow/compiler/xla/service/spmd/
gather_scatter_handler.cc:345 int index_dim = dnums.index_vector_dim(); in PartitionIndexParallelDimensions() local
382 indices.base_shape().dimensions_size() > index_dim in PartitionIndexParallelDimensions()
390 if (indices.base_shape().dimensions_size() > index_dim) { in PartitionIndexParallelDimensions()
394 {indices.base_shape().dimensions(index_dim)}), in PartitionIndexParallelDimensions()
411 indices.hlo()->shape(), adjusted_indices, {index_dim})); in PartitionIndexParallelDimensions()