Home
last modified time | relevance | path

Searched refs:shard_counts (Results 1 – 1 of 1) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/spmd/
convolution_handler.cc
318 std::vector<int64> shard_counts(dnums.input_spatial_dimensions_size()); in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS() local
335 shard_counts[i] = shard_count; in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
365 if (shard_counts[i] == 1) { in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
395 left_halo_size_functions[rhs_dimension].MaxInRange(1, shard_counts[i]) + in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
397 0, shard_counts[i] - 1); in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
420 new_padding_low_function.MaxInRange(0, shard_counts[i]); in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
483 offset_on_padded_shape.Calculate(shard_counts[i] - 1) + in PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS()
627 std::vector<int64> shard_counts(dnums.input_spatial_dimensions_size()); in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS() local
644 shard_counts[i] = shard_count; in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS()
663 if (shard_counts[i] == 1) { in PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS()