Searched refs:index_parallel_in_dim (Results 1 – 3 of 3) sorted by relevance
1274 std::vector<int64> index_parallel_in_dim(num_indices, -1); in GetGatherBatchParallelDims() local1288 index_parallel_in_dim[concatenated_dims + j] = in GetGatherBatchParallelDims()1302 index_parallel_in_dim.assign(num_indices_from_element, in GetGatherBatchParallelDims()1313 for (int i = 0; i < index_parallel_in_dim.size(); ++i) { in GetGatherBatchParallelDims()1314 int index_parallel_dim = index_parallel_in_dim[i]; in GetGatherBatchParallelDims()1326 index_parallel_in_dim[i] = -1; in GetGatherBatchParallelDims()1332 index_parallel_in_dim}; in GetGatherBatchParallelDims()1364 for (int i = 0; i < parallel_dims.index_parallel_in_dim.size(); ++i) { in GatherOutputAlignedOperandParallelDims()1367 const int64 index_parallel_dim = parallel_dims.index_parallel_in_dim[i]; in GatherOutputAlignedOperandParallelDims()
36 std::vector<int64> index_parallel_in_dim; member
1848 for (int idx : parallel_dims.index_parallel_in_dim) { in GatherOperandsShardedAcrossParallelDims()