Searched refs:operand_parallel_dims (Results 1 – 5 of 5) sorted by relevance
/external/tensorflow/tensorflow/compiler/xla/service/ |
D | hlo_sharding_util.cc | 973 absl::Span<const int64> operand_parallel_dims; in GatherDataOperandShardingFromOutput() local 979 operand_parallel_dims = parallel_dims->operand_parallel_dims; in GatherDataOperandShardingFromOutput() 982 output_sharding, operand_parallel_dims); in GatherDataOperandShardingFromOutput() 1307 absl::InlinedVector<int64, 1> operand_parallel_dims; in GetGatherBatchParallelDims() local 1324 operand_parallel_dims.push_back(dnums.start_index_map(i)); in GetGatherBatchParallelDims() 1331 return GatherParallelDims{indices_parallel_dims, operand_parallel_dims, in GetGatherBatchParallelDims() 1360 parallel_dims.operand_parallel_dims.size(), -1); in GatherOutputAlignedOperandParallelDims() 1363 parallel_dims.operand_parallel_dims.size()); in GatherOutputAlignedOperandParallelDims()
|
D | hlo_sharding_util.h | 35 absl::InlinedVector<int64, 1> operand_parallel_dims; member
|
D | sharding_propagation.cc | 965 absl::Span<const int64> operand_parallel_dims; in InferShardingFromOperands() local 967 operand_parallel_dims = absl::MakeConstSpan( in InferShardingFromOperands() 968 gather_parallel_dims->operand_parallel_dims); in InferShardingFromOperands() 972 instruction->operand(0)->sharding(), operand_parallel_dims); in InferShardingFromOperands()
|
/external/tensorflow/tensorflow/compiler/xla/service/spmd/ |
D | gather_scatter_handler.cc | 336 auto operand_parallel_dims = parallel_dims->operand_parallel_dims; in PartitionIndexParallelDimensions() local 344 GroupShardingOnDims(operand_sharding, operand_parallel_dims); in PartitionIndexParallelDimensions() 377 b, operand_parallel_dims); in PartitionIndexParallelDimensions() 426 operand_sharding.NumTiles(operand_parallel_dims) && in PartitionIndexParallelDimensions()
|
D | spmd_partitioner_util.cc | 1836 auto& operand_parallel_dims = parallel_dims.operand_parallel_dims; in GatherOperandsShardedAcrossParallelDims() local 1837 if (indices_parallel_dims.size() != operand_parallel_dims.size()) { in GatherOperandsShardedAcrossParallelDims() 1843 int op_parallel_tiles_num = new_operand_shard.NumTiles(operand_parallel_dims); in GatherOperandsShardedAcrossParallelDims() 1857 operand_parallel_dims), in GatherOperandsShardedAcrossParallelDims() 1864 operand_parallel_dims, in GatherOperandsShardedAcrossParallelDims() 1871 auto to_adjust_dims = operand_parallel_dims; in GatherOperandsShardedAcrossParallelDims() 1921 operand_shard_tile_dims[operand_parallel_dims[i]] = in GatherOperandsShardedAcrossParallelDims() 1931 operand_parallel_dims, new_index_shard, in GatherOperandsShardedAcrossParallelDims()
|