Searched refs:GetPerGroupBaseShape (Results 1 – 4 of 4) sorted by relevance
/external/tensorflow/tensorflow/compiler/xla/service/spmd/ |
D | dot_handler.cc | 1736 lhs.hlo(), GetPerGroupBaseShape(lhs_grouped, lhs.base_shape()), in PartitionDotGroupOnBatch() 1739 rhs.hlo(), GetPerGroupBaseShape(rhs_grouped, rhs.base_shape()), in PartitionDotGroupOnBatch() 1822 GetPerGroupBaseShape(grouped, operand.base_shape()), in PartitionDotGroupOnBatch() 1869 GetPerGroupBaseShape(output_grouped, output_base_shape), in PartitionDotGroupOnBatch() 2039 GetPerGroupBaseShape(matching_grouped, matching.base_shape()), in PartitionDotGroupOnNonContracting() 2067 GetPerGroupBaseShape(output_grouped, output_base_shape), in PartitionDotGroupOnNonContracting() 2226 GetPerGroupBaseShape(lhs_grouped, lhs.base_shape()), in PartitionDotGroupOnContracting() 2229 GetPerGroupBaseShape(rhs_grouped, rhs.base_shape()), in PartitionDotGroupOnContracting() 2408 GetPerGroupBaseShape(output_grouped, output_base_shape)), in LhsIsBestMatchForNonContractingPartitioning()
|
D | spmd_partitioner_util.h | 342 Shape GetPerGroupBaseShape(const GroupedSharding& grouped_sharding,
|
D | gather_scatter_handler.cc | 435 GetPerGroupBaseShape(grouped_operand, operand.base_shape()), in PartitionIndexParallelDimensions() 439 GetPerGroupBaseShape(grouped_indices, indices.base_shape()), in PartitionIndexParallelDimensions()
|
D | spmd_partitioner_util.cc | 1556 Shape GetPerGroupBaseShape(const GroupedSharding& grouped_sharding, in GetPerGroupBaseShape() function
|