Home
last modified time | relevance | path

Searched refs:rhs_dimension (Results 1 – 1 of 1) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/spmd/
Dconvolution_handler.cc324 int64 rhs_dimension = dnums.kernel_spatial_dimensions(i); in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS() local
325 int64 shard_count = rhs.sharding().tile_assignment().dim(rhs_dimension); in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
334 CeilOfRatio(rhs.base_shape().dimensions(rhs_dimension), shard_count); in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
361 int64 rhs_dimension = dnums.kernel_spatial_dimensions(i); in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS() local
379 left_halo_size_functions[rhs_dimension] = in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
385 right_halo_size_functions[rhs_dimension] = in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
395 left_halo_size_functions[rhs_dimension].MaxInRange(1, shard_counts[i]) + in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
396 right_halo_size_functions[rhs_dimension].MaxInRange( in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
399 rhs.hlo()->shape().dimensions(rhs_dimension) + halo_size; in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
412 left_halo_size_functions[rhs_dimension], in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
[all …]