Searched refs:rhs_non_contracting_dims (Results 1 – 7 of 7) sorted by relevance
60 dims.rhs_non_contracting_dims.push_back( in ParseConvolutionDimsInfo()96 dims.rhs_non_contracting_dims.push_back({lhs, rhs, output, i}); in ParseConvolutionDimsInfo()130 for (const auto& dim : dot_dnums.rhs_non_contracting_dims) { in CreateShardedConvForDotGeneralConvolution()190 dnums.rhs_non_contracting_dims.emplace_back(); in ParseDotGeneralFromDot()191 dnums.rhs_non_contracting_dims.back().lhs = -1; in ParseDotGeneralFromDot()192 dnums.rhs_non_contracting_dims.back().rhs = i; in ParseDotGeneralFromDot()193 dnums.rhs_non_contracting_dims.back().output = in ParseDotGeneralFromDot()196 dnums.rhs_non_contracting_dims.size() - 1; in ParseDotGeneralFromDot()197 dnums.rhs_non_contracting_dims.back().spatial_dim = -1; in ParseDotGeneralFromDot()
100 std::vector<int64> rhs_non_contracting_dims; in CanonicalizeDot() local101 rhs_non_contracting_dims.reserve(num_rhs_non_contracting_dims); in CanonicalizeDot()109 rhs_non_contracting_dims.push_back(i); in CanonicalizeDot()125 rhs_transpose.insert(rhs_transpose.end(), rhs_non_contracting_dims.begin(), in CanonicalizeDot()126 rhs_non_contracting_dims.end()); in CanonicalizeDot()
46 std::vector<DimNums> rhs_non_contracting_dims; member
354 ? dnums.rhs_non_contracting_dims in InferDotShardingFromOperands()373 : dnums.rhs_non_contracting_dims) { in InferDotShardingFromOperands()1053 for (const auto& dim : operand_index == 0 ? dnums.rhs_non_contracting_dims in InferDotOperandSharding()1069 : dnums.rhs_non_contracting_dims) { in InferDotOperandSharding()1085 : dnums.rhs_non_contracting_dims) { in InferDotOperandSharding()
73 mapping.rhs_non_contracting_dims.emplace_back(); in HandleDot()74 mapping.rhs_non_contracting_dims.back().lhs = -1; in HandleDot()75 mapping.rhs_non_contracting_dims.back().rhs = i; in HandleDot()76 mapping.rhs_non_contracting_dims.back().output = next_output_dim++; in HandleDot()372 for (const auto& mapping : dims_mapping.rhs_non_contracting_dims) { in ComputeDimensionIndexMapping()1834 std::vector<int64> rhs_non_contracting_dims; in PartitionDotGroupOnBatch() local1837 rhs_non_contracting_dims.reserve( in PartitionDotGroupOnBatch()1838 dims_mapping.rhs_non_contracting_dims.size()); in PartitionDotGroupOnBatch()1842 for (const auto& dim : dims_mapping.rhs_non_contracting_dims) { in PartitionDotGroupOnBatch()1843 rhs_non_contracting_dims.push_back(dim.rhs); in PartitionDotGroupOnBatch()[all …]
916 for (const auto& dim : dot_dnums.rhs_non_contracting_dims) { in CreateShardedConvConvolution()1011 for (const auto& dims : dims_info.rhs_non_contracting_dims) { in HandleConvolution()1012 mapping.rhs_non_contracting_dims.emplace_back(); in HandleConvolution()1013 mapping.rhs_non_contracting_dims.back().lhs = dims.lhs; in HandleConvolution()1014 mapping.rhs_non_contracting_dims.back().rhs = dims.rhs; in HandleConvolution()1015 mapping.rhs_non_contracting_dims.back().output = dims.output; in HandleConvolution()1016 mapping.rhs_non_contracting_dims.back().spatial = dims.spatial_dim; in HandleConvolution()
396 std::vector<DimsMapping> rhs_non_contracting_dims; member