Home
last modified time | relevance | path

Searched refs:MakePartitionedShape (Results 1 – 6 of 6) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/spmd/
spmd_partitioner.cc
470 auto shard_shape = MakePartitionedShape(shape, target); in ReshardNoCache()
774 auto original_shard_shape = MakePartitionedShape(base_shape_, target); in ReshardAsWindowedInput()
961 auto base_shape = MakePartitionedShape(base_shape_, target); in ReshardToPartialReplicateWithAllGather()
1024 auto shard_shape = MakePartitionedShape(base_shape_, temp_target_sharding); in ReshardFromPartialReplicateWithDynamicSlice()
1214 const Shape result_shape = MakePartitionedShape(base_shape_, temp_target); in ReshardWithAllToAll()
1480 MakePartitionedShape(hlo->shape(), hlo->sharding()), new_operands)); in HandleElementwise()
1491 const Shape shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding()); in HandleConcatenate()
1512 auto temp_output_shape = MakePartitionedShape(hlo->shape(), sharding); in HandleConcatenate()
1514 MakePartitionedShape(hlo->operands().back()->shape(), sharding); in HandleConcatenate()
1627 auto shard_shape = MakePartitionedShape(hlo->shape(), sharding); in HandleSlice()
[all …]
spmd_partitioner_util.h
89 Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding);
gather_scatter_handler.cc
191 auto pshape = MakePartitionedShape(output_shape, *maybe_passthrough); in ParititonPassthroughOperand()
370 Shape pshape = MakePartitionedShape(output_shape, gather_output_sharding); in PartitionIndexParallelDimensions()
spmd_partitioner_util.cc
148 Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding) { in MakePartitionedShape() function
153 MakePartitionedShape(ShapeUtil::GetTupleElementShape(shape, i), in MakePartitionedShape()
214 auto shard_shape = MakePartitionedShape(shape, sharding); in MakePartitionOffsets()
284 auto shard_shape = MakePartitionedShape(base_shape, sharding); in GetPaddedShapeForUnevenPartitioning()
1727 MakePartitionedShape(replicated->shape(), group_level_sharding); in PerGroupSliceFromReplicated()
convolution_handler.cc
803 auto shard_shape = MakePartitionedShape(output_base_shape, output_sharding); in PartitionConvolutionTiledOutput()
dot_handler.cc
603 MakePartitionedShape(output_base_shape, output_sharding); in PartitionBaseCase()
1948 ShapeUtil::ByteSizeOf(MakePartitionedShape( in GetNonContractingPartitionGroupedShardingForOtherOperand()