Home
last modified time | relevance | path

Searched refs:operand_parallel_dims (Results 1 – 5 of 5) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/
Dhlo_sharding_util.cc973 absl::Span<const int64> operand_parallel_dims; in GatherDataOperandShardingFromOutput() local
979 operand_parallel_dims = parallel_dims->operand_parallel_dims; in GatherDataOperandShardingFromOutput()
982 output_sharding, operand_parallel_dims); in GatherDataOperandShardingFromOutput()
1307 absl::InlinedVector<int64, 1> operand_parallel_dims; in GetGatherBatchParallelDims() local
1324 operand_parallel_dims.push_back(dnums.start_index_map(i)); in GetGatherBatchParallelDims()
1331 return GatherParallelDims{indices_parallel_dims, operand_parallel_dims, in GetGatherBatchParallelDims()
1360 parallel_dims.operand_parallel_dims.size(), -1); in GatherOutputAlignedOperandParallelDims()
1363 parallel_dims.operand_parallel_dims.size()); in GatherOutputAlignedOperandParallelDims()
Dhlo_sharding_util.h35 absl::InlinedVector<int64, 1> operand_parallel_dims; member
Dsharding_propagation.cc965 absl::Span<const int64> operand_parallel_dims; in InferShardingFromOperands() local
967 operand_parallel_dims = absl::MakeConstSpan( in InferShardingFromOperands()
968 gather_parallel_dims->operand_parallel_dims); in InferShardingFromOperands()
972 instruction->operand(0)->sharding(), operand_parallel_dims); in InferShardingFromOperands()
/external/tensorflow/tensorflow/compiler/xla/service/spmd/
Dgather_scatter_handler.cc336 auto operand_parallel_dims = parallel_dims->operand_parallel_dims; in PartitionIndexParallelDimensions() local
344 GroupShardingOnDims(operand_sharding, operand_parallel_dims); in PartitionIndexParallelDimensions()
377 b, operand_parallel_dims); in PartitionIndexParallelDimensions()
426 operand_sharding.NumTiles(operand_parallel_dims) && in PartitionIndexParallelDimensions()
Dspmd_partitioner_util.cc1836 auto& operand_parallel_dims = parallel_dims.operand_parallel_dims; in GatherOperandsShardedAcrossParallelDims() local
1837 if (indices_parallel_dims.size() != operand_parallel_dims.size()) { in GatherOperandsShardedAcrossParallelDims()
1843 int op_parallel_tiles_num = new_operand_shard.NumTiles(operand_parallel_dims); in GatherOperandsShardedAcrossParallelDims()
1857 operand_parallel_dims), in GatherOperandsShardedAcrossParallelDims()
1864 operand_parallel_dims, in GatherOperandsShardedAcrossParallelDims()
1871 auto to_adjust_dims = operand_parallel_dims; in GatherOperandsShardedAcrossParallelDims()
1921 operand_shard_tile_dims[operand_parallel_dims[i]] = in GatherOperandsShardedAcrossParallelDims()
1931 operand_parallel_dims, new_index_shard, in GatherOperandsShardedAcrossParallelDims()