Searched refs:get_partitions_for_dims (Results 1 – 2 of 2) sorted by relevance
/external/tensorflow/tensorflow/compiler/xla/service/spmd/ |
D | dot_handler.cc | 2452 auto get_partitions_for_dims = in PartitionDot() local 2473 get_partitions_for_dims(lhs.sharding(), dims_mapping.batch_dims, 0); in PartitionDot() 2475 get_partitions_for_dims(rhs.sharding(), dims_mapping.batch_dims, 1); in PartitionDot() 2477 get_partitions_for_dims(output_sharding, dims_mapping.batch_dims, 2); in PartitionDot() 2479 get_partitions_for_dims(lhs.sharding(), dims_mapping.contracting_dims, 0); in PartitionDot() 2481 get_partitions_for_dims(rhs.sharding(), dims_mapping.contracting_dims, 1); in PartitionDot() 2482 const int64 lhs_non_contracting_partitions = get_partitions_for_dims( in PartitionDot() 2484 const int64 rhs_non_contracting_partitions = get_partitions_for_dims( in PartitionDot() 2486 const int64 output_lhs_non_contracting_partitions = get_partitions_for_dims( in PartitionDot() 2488 const int64 output_rhs_non_contracting_partitions = get_partitions_for_dims( in PartitionDot() [all …]
|
/external/tensorflow/tensorflow/compiler/xla/service/ |
D | sharding_propagation.cc | 480 auto get_partitions_for_dims = in InferConvolutionShardingFromOperands() local 506 const int64 lhs_conv_spatial_partitions = get_partitions_for_dims( in InferConvolutionShardingFromOperands() 508 const int64 rhs_conv_spatial_partitions = get_partitions_for_dims( in InferConvolutionShardingFromOperands()
|