Home
last modified time | relevance | path

Searched refs:lhs_to_output_indices (Results 1 – 2 of 2) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/service/spmd/
Dconvolution_handler.cc84 std::vector<int64> lhs_to_output_indices(lhs.base_shape().rank(), -1); in PartitionConvolutionWithBatchGroupCount() local
85 lhs_to_output_indices[dnums.input_batch_dimension()] = in PartitionConvolutionWithBatchGroupCount()
87 lhs_to_output_indices[dnums.input_feature_dimension()] = in PartitionConvolutionWithBatchGroupCount()
90 lhs_to_output_indices[dnums.input_spatial_dimensions(i)] = in PartitionConvolutionWithBatchGroupCount()
120 lhs.sharding(), lhs_to_output_indices); in PartitionConvolutionWithBatchGroupCount()
175 std::vector<int64> lhs_to_output_indices(output_base_shape.rank()); in PartitionConvolutionWithFeatureGroupCount() local
176 lhs_to_output_indices[dnums.input_feature_dimension()] = in PartitionConvolutionWithFeatureGroupCount()
178 lhs_to_output_indices[dnums.input_batch_dimension()] = in PartitionConvolutionWithFeatureGroupCount()
181 lhs_to_output_indices[dnums.input_spatial_dimensions(i)] = in PartitionConvolutionWithFeatureGroupCount()
213 lhs.sharding(), lhs_to_output_indices); in PartitionConvolutionWithFeatureGroupCount()
Ddot_handler.cc106 std::vector<int64> lhs_to_output_indices; member
225 const std::vector<int64>& lhs_to_output_indices, in GenNewConvDNums() argument
288 ? lhs_to_output_indices[lhs_concat_dim] in GenNewConvDNums()
343 std::vector<int64> lhs_to_output_indices(lhs_rank, -1); in ComputeDimensionIndexMapping() local
352 lhs_to_output_indices[mapping.lhs] = mapping.output; in ComputeDimensionIndexMapping()
378 return DotDimensionIndexMapping{lhs_to_rhs_indices, lhs_to_output_indices, in ComputeDimensionIndexMapping()
504 lhs_sharding, indices_map.lhs_to_output_indices, in PartitionBaseCase()
513 indices_map.lhs_to_output_indices); in PartitionBaseCase()
944 ? indices_map.lhs_to_output_indices[lhs_concat_dim] in PartitionBaseCase()
989 indices_map.lhs_to_output_indices, in PartitionBaseCase()
[all …]