Searched refs:MakePartitionedShape (Results 1 – 6 of 6) sorted by relevance
/external/tensorflow/tensorflow/compiler/xla/service/spmd/ |
D | spmd_partitioner.cc | 470 auto shard_shape = MakePartitionedShape(shape, target); in ReshardNoCache() 774 auto original_shard_shape = MakePartitionedShape(base_shape_, target); in ReshardAsWindowedInput() 961 auto base_shape = MakePartitionedShape(base_shape_, target); in ReshardToPartialReplicateWithAllGather() 1024 auto shard_shape = MakePartitionedShape(base_shape_, temp_target_sharding); in ReshardFromPartialReplicateWithDynamicSlice() 1214 const Shape result_shape = MakePartitionedShape(base_shape_, temp_target); in ReshardWithAllToAll() 1480 MakePartitionedShape(hlo->shape(), hlo->sharding()), new_operands)); in HandleElementwise() 1491 const Shape shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding()); in HandleConcatenate() 1512 auto temp_output_shape = MakePartitionedShape(hlo->shape(), sharding); in HandleConcatenate() 1514 MakePartitionedShape(hlo->operands().back()->shape(), sharding); in HandleConcatenate() 1627 auto shard_shape = MakePartitionedShape(hlo->shape(), sharding); in HandleSlice() [all …]
|
D | spmd_partitioner_util.h | 89 Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding);
|
D | gather_scatter_handler.cc | 191 auto pshape = MakePartitionedShape(output_shape, *maybe_passthrough); in ParititonPassthroughOperand() 370 Shape pshape = MakePartitionedShape(output_shape, gather_output_sharding); in PartitionIndexParallelDimensions()
|
D | spmd_partitioner_util.cc | 148 Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding) { in MakePartitionedShape() function 153 MakePartitionedShape(ShapeUtil::GetTupleElementShape(shape, i), in MakePartitionedShape() 214 auto shard_shape = MakePartitionedShape(shape, sharding); in MakePartitionOffsets() 284 auto shard_shape = MakePartitionedShape(base_shape, sharding); in GetPaddedShapeForUnevenPartitioning() 1727 MakePartitionedShape(replicated->shape(), group_level_sharding); in PerGroupSliceFromReplicated()
|
D | convolution_handler.cc | 803 auto shard_shape = MakePartitionedShape(output_base_shape, output_sharding); in PartitionConvolutionTiledOutput()
|
D | dot_handler.cc | 603 MakePartitionedShape(output_base_shape, output_sharding); in PartitionBaseCase() 1948 ShapeUtil::ByteSizeOf(MakePartitionedShape( in GetNonContractingPartitionGroupedShardingForOtherOperand()
|