Home
last modified time | relevance | path

Searched refs:GetPerGroupBaseShape (Results 1 – 4 of 4) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/spmd/
Ddot_handler.cc1736 lhs.hlo(), GetPerGroupBaseShape(lhs_grouped, lhs.base_shape()), in PartitionDotGroupOnBatch()
1739 rhs.hlo(), GetPerGroupBaseShape(rhs_grouped, rhs.base_shape()), in PartitionDotGroupOnBatch()
1822 GetPerGroupBaseShape(grouped, operand.base_shape()), in PartitionDotGroupOnBatch()
1869 GetPerGroupBaseShape(output_grouped, output_base_shape), in PartitionDotGroupOnBatch()
2039 GetPerGroupBaseShape(matching_grouped, matching.base_shape()), in PartitionDotGroupOnNonContracting()
2067 GetPerGroupBaseShape(output_grouped, output_base_shape), in PartitionDotGroupOnNonContracting()
2226 GetPerGroupBaseShape(lhs_grouped, lhs.base_shape()), in PartitionDotGroupOnContracting()
2229 GetPerGroupBaseShape(rhs_grouped, rhs.base_shape()), in PartitionDotGroupOnContracting()
2408 GetPerGroupBaseShape(output_grouped, output_base_shape)), in LhsIsBestMatchForNonContractingPartitioning()
Dspmd_partitioner_util.h342 Shape GetPerGroupBaseShape(const GroupedSharding& grouped_sharding,
Dgather_scatter_handler.cc435 GetPerGroupBaseShape(grouped_operand, operand.base_shape()), in PartitionIndexParallelDimensions()
439 GetPerGroupBaseShape(grouped_indices, indices.base_shape()), in PartitionIndexParallelDimensions()
Dspmd_partitioner_util.cc1556 Shape GetPerGroupBaseShape(const GroupedSharding& grouped_sharding, in GetPerGroupBaseShape() function