/* * Copyright (C) 2015 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "RenderScript.h" #include "rsCppInternal.h" using namespace android; using namespace RSC; // ScriptIntrinsicBLAS APIS ScriptIntrinsicBLAS::ScriptIntrinsicBLAS(sp rs, sp e) : ScriptIntrinsic(rs, RS_SCRIPT_INTRINSIC_ID_BLAS, e) { } sp ScriptIntrinsicBLAS::create(sp rs) { return new ScriptIntrinsicBLAS(rs, Element::U32(rs)); } enum RsBlasDataType { SINGLE, DOUBLE, SINGLE_COMPLEX, DOUBLE_COMPLEX }; static RsBlasCall setUpBLASCall(RsBlasDataType dataType, RsBlasFunction func, int TransA, int TransB, int Side, int Uplo, int Diag, int M, int N, int K, int incX, int incY, int KL, int KU, float alphaF, float betaF, double alphaD, double betaD, float alphaCX, float alphaCY, float betaCX, float betaCY, double alphaZX, double alphaZY, double betaZX, double betaZY ) { RsBlasCall call; memset(&call, 0, sizeof(call)); call.func = func; call.transA = (RsBlasTranspose)TransA; call.transB = (RsBlasTranspose)TransB; call.side = (RsBlasSide)Side; call.uplo = (RsBlasUplo)Uplo; call.diag = (RsBlasDiag)Diag; call.M = M; call.N = N; call.K = K; switch (dataType) { case SINGLE: // For Single-precision BLAS. call.alpha.f = alphaF; call.beta.f = betaF; break; case DOUBLE: // For Double-precision BLAS. call.alpha.d = alphaD; call.beta.d = betaD; break; case SINGLE_COMPLEX: // For Single-precision complex BLAS. call.alpha.c.r = alphaCX; call.alpha.c.i = alphaCY; call.beta.c.r = betaCX; call.beta.c.i = betaCY; break; case DOUBLE_COMPLEX: // For Double-precision complex BLAS. call.alpha.z.r = alphaZX; call.alpha.z.i = alphaZY; call.beta.z.r = betaZX; call.beta.z.i = betaZY; break; default: break; } call.incX = incX; call.incY = incY; call.KL = KL; call.KU = KU; return call; } static void nScriptIntrinsicBLAS_Single(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA, int TransB, int Side, int Uplo, int Diag, int M, int N, int K, float alpha, RsAllocation A, RsAllocation B, float beta, RsAllocation C, int incX, int incY, int KL, int KU) { RsBlasCall call = setUpBLASCall(SINGLE, func, TransA, TransB, Side, Uplo, Diag, M, N, K, incX, incY, KL, KU, alpha, beta, 0.0, 0.0, 0.0f, 0.0f, 0.0f, 0.0f, 0.0, 0.0, 0.0, 0.0); RsAllocation in_allocs[3] = {A, B, C}; tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, sizeof(in_allocs), nullptr, &call, sizeof(call), nullptr, 0)); } static void nScriptIntrinsicBLAS_Double(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA, int TransB, int Side, int Uplo, int Diag, int M, int N, int K, double alpha, RsAllocation A, RsAllocation B, double beta, RsAllocation C, int incX, int incY, int KL, int KU) { RsBlasCall call = setUpBLASCall(DOUBLE, func, TransA, TransB, Side, Uplo, Diag, M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, alpha, beta, 0.0f, 0.0f, 0.0f, 0.0f, 0.0, 0.0, 0.0, 0.0); RsAllocation in_allocs[3] = {A, B, C}; tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, sizeof(in_allocs), nullptr, &call, sizeof(call), nullptr, 0)); } static void nScriptIntrinsicBLAS_Complex(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA, int TransB, int Side, int Uplo, int Diag, int M, int N, int K, float alphaX, float alphaY, RsAllocation A, RsAllocation B, float betaX, float betaY, RsAllocation C, int incX, int incY, int KL, int KU) { RsBlasCall call = setUpBLASCall(SINGLE_COMPLEX, func, TransA, TransB, Side, Uplo, Diag, M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, 0.0, 0.0, alphaX, alphaY, betaX, betaY, 0.0, 0.0, 0.0, 0.0); RsAllocation in_allocs[3] = {A, B, C}; tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, sizeof(in_allocs), nullptr, &call, sizeof(call), nullptr, 0)); } static void nScriptIntrinsicBLAS_Z(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA, int TransB, int Side, int Uplo, int Diag, int M, int N, int K, double alphaX, double alphaY, RsAllocation A, RsAllocation B, double betaX, double betaY, RsAllocation C, int incX, int incY, int KL, int KU) { RsBlasCall call = setUpBLASCall(DOUBLE_COMPLEX, func, TransA, TransB, Side, Uplo, Diag, M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, 0.0, 0.0, 0.0f, 0.0f, 0.0f, 0.0f, alphaX, alphaY, betaX, betaY); RsAllocation in_allocs[3] = {A, B, C}; tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, sizeof(in_allocs), nullptr, &call, sizeof(call), nullptr, 0)); } static void nScriptIntrinsicBLAS_BNNM(RS* mRS, RsContext con, RsScript id, int M, int N, int K, RsAllocation A, int a_offset, RsAllocation B, int b_offset, RsAllocation C, int c_offset, int c_mult_int) { RsBlasCall call; memset(&call, 0, sizeof(call)); call.func = RsBlas_bnnm; call.M = M; call.N = N; call.K = K; call.a_offset = a_offset & 0xFF; call.b_offset = b_offset & 0xFF; call.c_offset = c_offset; call.c_mult_int = c_mult_int; RsAllocation in_allocs[3] = {A, B, C}; tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, sizeof(in_allocs), nullptr, &call, sizeof(call), nullptr, 0)); } /** * Level 2 BLAS */ static void validateGEMV(RS* mRS, sp e, RsBlasTranspose TransA, sp A, sp X, int incX, sp Y, int incY) { int M = A->getType()->getY(); int N = A->getType()->getX(); if (!A->getType()->getElement()->isCompatible(e) || !X->getType()->getElement()->isCompatible(e) || !Y->getType()->getElement()->isCompatible(e)) { mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); } if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); } if (incX <= 0 || incY <= 0) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); } int expectedXDim = -1, expectedYDim = -1; if (TransA == RsBlasNoTrans) { expectedXDim = 1 + (N - 1) * incX; expectedYDim = 1 + (M - 1) * incY; } else { expectedXDim = 1 + (M - 1) * incX; expectedYDim = 1 + (N - 1) * incY; } if ((int)X->getType()->getX() != expectedXDim || (int)Y->getType()->getX() != expectedYDim) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GEMV"); } } void ScriptIntrinsicBLAS::SGEMV(RsBlasTranspose TransA, float alpha, sp A, sp X, int incX, float beta, sp Y, int incY) { validateGEMV(mRS, Element::F32(mRS), TransA, A, X, incX, Y, incY); int M = A->getType()->getY(); int N = A->getType()->getX(); nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::DGEMV(RsBlasTranspose TransA, double alpha, sp A, sp X, int incX, double beta, sp Y, int incY) { validateGEMV(mRS, Element::F64(mRS), TransA, A, X, incX, Y, incY); int M = A->getType()->getY(); int N = A->getType()->getX(); nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::CGEMV(RsBlasTranspose TransA, Float2 alpha, sp A, sp X, int incX, Float2 beta, sp Y, int incY) { validateGEMV(mRS, Element::F32_2(mRS), TransA, A, X, incX, Y, incY); int M = A->getType()->getY(); int N = A->getType()->getX(); nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A->getID(), X->getID(), beta.x, beta.y, Y->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::ZGEMV(RsBlasTranspose TransA, Double2 alpha, sp A, sp X, int incX, Double2 beta, sp Y, int incY) { validateGEMV(mRS, Element::F64_2(mRS), TransA, A, X, incX, Y, incY); int M = A->getType()->getY(); int N = A->getType()->getX(); nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgemv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A->getID(), X->getID(), beta.x, beta.y, Y->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::SGBMV(RsBlasTranspose TransA, int KL, int KU, float alpha, sp A, sp X, int incX, float beta, sp Y, int incY) { // GBMV has the same validation requirements as GEMV + KL and KU >= 0 validateGEMV(mRS, Element::F32(mRS), TransA, A, X, incX, Y, incY); if (KL < 0 || KU < 0) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0"); } int M = A->getType()->getY(); int N = A->getType()->getX(); nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A->getID(), X->getID(), beta, Y->getID(), incX, incY, KL, KU); } void ScriptIntrinsicBLAS::DGBMV(RsBlasTranspose TransA, int KL, int KU, double alpha, sp A, sp X, int incX, double beta, sp Y, int incY) { // GBMV has the same validation requirements as GEMV + KL and KU >= 0 validateGEMV(mRS, Element::F64(mRS), TransA, A, X, incX, Y, incY); if (KL < 0 || KU < 0) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0"); } int M = A->getType()->getY(); int N = A->getType()->getX(); nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha, A->getID(), X->getID(), beta, Y->getID(), incX, incY, KL, KU); } void ScriptIntrinsicBLAS::CGBMV(RsBlasTranspose TransA, int KL, int KU, Float2 alpha, sp A, sp X, int incX, Float2 beta, sp Y, int incY) { // GBMV has the same validation requirements as GEMV + KL and KU >= 0 validateGEMV(mRS, Element::F32_2(mRS), TransA, A, X, incX, Y, incY); if (KL < 0 || KU < 0) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0"); } int M = A->getType()->getY(); int N = A->getType()->getX(); nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A->getID(), X->getID(), beta.x, beta.y, Y->getID(), incX, incY, KL, KU); } void ScriptIntrinsicBLAS::ZGBMV(RsBlasTranspose TransA, int KL, int KU, Double2 alpha, sp A, sp X, int incX, Double2 beta, sp Y, int incY) { // GBMV has the same validation requirements as GEMV + KL and KU >= 0 validateGEMV(mRS, Element::F64_2(mRS), TransA, A, X, incX, Y, incY); if (KL < 0 || KU < 0) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0"); } int M = A->getType()->getY(); int N = A->getType()->getX(); nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgbmv, TransA, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, A->getID(), X->getID(), beta.x, beta.y, Y->getID(), incX, incY, KL, KU); } static void validateTRMV(RS* mRS, sp e, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, sp A, sp X, int incX) { int N = A->getType()->getY(); if ((int)A->getType()->getX() != N) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a square matrix for TRMV"); } if (!A->getType()->getElement()->isCompatible(e) || !X->getType()->getElement()->isCompatible(e)) { mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); } if (X->getType()->getY() > 1) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); } if (incX <= 0) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); } int expectedXDim = 1 + (N - 1) * incX; if ((int)X->getType()->getX() != expectedXDim) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for TRMV"); } } static int validateTPMV(RS* mRS, sp e, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, sp Ap, sp X, int incX) { if (!Ap->getType()->getElement()->isCompatible(e) || !X->getType()->getElement()->isCompatible(e)) { mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); } if (X->getType()->getY() > 1) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); } if (Ap->getType()->getY() > 1) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1"); } int N = sqrt((double)Ap->getType()->getX() * 2); if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap"); } if (incX <= 0) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); } int expectedXDim = 1 + (N - 1) * incX; if ((int)X->getType()->getX() != expectedXDim) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for TPMV"); } return N; } void ScriptIntrinsicBLAS::STRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, sp A, sp X, int incX) { validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX); int N = A->getType()->getY(); nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); } void ScriptIntrinsicBLAS::DTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, sp A, sp X, int incX) { validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX); int N = A->getType()->getY(); nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); } void ScriptIntrinsicBLAS::CTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, sp A, sp X, int incX) { validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX); int N = A->getType()->getY(); nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); } void ScriptIntrinsicBLAS::ZTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, sp A, sp X, int incX) { validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX); int N = A->getType()->getY(); nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); } void ScriptIntrinsicBLAS::STBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, int K, sp A, sp X, int incX) { // TBMV has the same requirements as TRMV + K >= 0 if (K < 0) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0"); } validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX); int N = A->getType()->getY(); nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); } void ScriptIntrinsicBLAS::DTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, int K, sp A, sp X, int incX) { // TBMV has the same requirements as TRMV + K >= 0 if (K < 0) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0"); } validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX); int N = A->getType()->getY(); nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); } void ScriptIntrinsicBLAS::CTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, int K, sp A, sp X, int incX) { // TBMV has the same requirements as TRMV + K >= 0 if (K < 0) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0"); } validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX); int N = A->getType()->getY(); nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); } void ScriptIntrinsicBLAS::ZTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, int K, sp A, sp X, int incX) { // TBMV has the same requirements as TRMV + K >= 0 if (K < 0) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0"); } validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX); int N = A->getType()->getY(); nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztbmv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); } void ScriptIntrinsicBLAS::STPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, sp Ap, sp X, int incX) { int N = validateTPMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, Ap, X, incX); nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0); } void ScriptIntrinsicBLAS::DTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, sp Ap, sp X, int incX) { int N = validateTPMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, Ap, X, incX); nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0); } void ScriptIntrinsicBLAS::CTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, sp Ap, sp X, int incX) { int N = validateTPMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX); nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); } void ScriptIntrinsicBLAS::ZTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, sp Ap, sp X, int incX) { int N = validateTPMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX); nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztpmv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); } void ScriptIntrinsicBLAS::STRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, sp A, sp X, int incX) { // TRSV is the same as TRMV validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX); int N = A->getType()->getY(); nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); } void ScriptIntrinsicBLAS::DTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, sp A, sp X, int incX) { // TRSV is the same as TRMV validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX); int N = A->getType()->getY(); nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); } void ScriptIntrinsicBLAS::CTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, sp A, sp X, int incX) { // TRSV is the same as TRMV validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX); int N = A->getType()->getY(); nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); } void ScriptIntrinsicBLAS::ZTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, sp A, sp X, int incX) { // TRSV is the same as TRMV validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX); int N = A->getType()->getY(); nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); } void ScriptIntrinsicBLAS::STBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, int K, sp A, sp X, int incX) { // TBSV is the same as TRMV + K >= 0 validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX); int N = A->getType()->getY(); if (K < 0) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive"); } nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); } void ScriptIntrinsicBLAS::DTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, int K, sp A, sp X, int incX) { // TBSV is the same as TRMV + K >= 0 validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX); int N = A->getType()->getY(); if (K < 0) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive"); } nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); } void ScriptIntrinsicBLAS::CTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, int K, sp A, sp X, int incX) { // TBSV is the same as TRMV + K >= 0 validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX); int N = A->getType()->getY(); if (K < 0) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive"); } nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); } void ScriptIntrinsicBLAS::ZTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, int K, sp A, sp X, int incX) { // TBSV is the same as TRMV + K >= 0 validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX); int N = A->getType()->getY(); if (K < 0) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive"); } nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztbsv, TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); } void ScriptIntrinsicBLAS::STPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, sp Ap, sp X, int incX) { // TPSV is same as TPMV int N = validateTPMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, Ap, X, incX); nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0); } void ScriptIntrinsicBLAS::DTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, sp Ap, sp X, int incX) { // TPSV is same as TPMV int N = validateTPMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, Ap, X, incX); nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0); } void ScriptIntrinsicBLAS::CTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, sp Ap, sp X, int incX) { // TPSV is same as TPMV int N = validateTPMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX); nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); } void ScriptIntrinsicBLAS::ZTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, sp Ap, sp X, int incX) { // TPSV is same as TPMV int N = validateTPMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX); nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztpsv, TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); } /** * Level 2, S and D only */ static int validateSYMV(RS* mRS, sp e, RsBlasUplo Uplo, sp A, sp X, sp Y, int incX, int incY) { int N = A->getType()->getY(); if ((int)A->getType()->getX() != N) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a square matrix for SYMV"); } if (!A->getType()->getElement()->isCompatible(e) || !X->getType()->getElement()->isCompatible(e) || !Y->getType()->getElement()->isCompatible(e) ) { mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); } if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); } if (incX <= 0 || incY <= 0) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); } int expectedXDim = 1 + (N - 1) * incX; if ((int)X->getType()->getX() != expectedXDim) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYMV"); } int expectedYDim = 1 + (N - 1) * incY; if ((int)Y->getType()->getX() != expectedYDim) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYMV"); } return N; } static int validateSPMV(RS* mRS, sp e, RsBlasUplo Uplo, sp Ap, sp X, int incX, sp Y, int incY) { if (!Ap->getType()->getElement()->isCompatible(e) || !X->getType()->getElement()->isCompatible(e) || !Y->getType()->getElement()->isCompatible(e)) { mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); } if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); } if (Ap->getType()->getY() > 1) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1"); } int N = sqrt((double)Ap->getType()->getX() * 2); if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap"); } if (incX <= 0 || incY <= 0) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); } int expectedXDim = 1 + (N - 1) * incX; if ((int)X->getType()->getX() != expectedXDim) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPMV"); } int expectedYDim = 1 + (N - 1) * incY; if ((int)Y->getType()->getX() != expectedYDim) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPMV"); } return N; } static void validateGER(RS* mRS, sp e, sp X, int incX, sp Y, int incY, sp A) { if (!A->getType()->getElement()->isCompatible(e) || !X->getType()->getElement()->isCompatible(e) || !Y->getType()->getElement()->isCompatible(e) ) { mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); } if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); } int M = A->getType()->getY(); int N = A->getType()->getX(); if (N < 1 || M < 1) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "M and N must be 1 or greater for GER"); } if (incX <= 0 || incY <= 0) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); } int expectedXDim = 1 + (M - 1) * incX; if ((int)X->getType()->getX() != expectedXDim) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GER"); } int expectedYDim = 1 + (N - 1) * incY; if ((int)Y->getType()->getX() != expectedYDim) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GER"); } } static int validateSYR(RS* mRS, sp e, RsBlasUplo Uplo, sp X, int incX, sp A) { if (!A->getType()->getElement()->isCompatible(e) || !X->getType()->getElement()->isCompatible(e)) { mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); } int N = A->getType()->getX(); if (X->getType()->getY() > 1) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); } if (N != (int)A->getType()->getY()) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a symmetric matrix"); } if (incX <= 0) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); } int expectedXDim = 1 + (N - 1) * incX; if ((int)X->getType()->getX() != expectedXDim) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYR"); } return N; } static int validateSPR(RS* mRS, sp e, RsBlasUplo Uplo, sp X, int incX, sp Ap) { if (!Ap->getType()->getElement()->isCompatible(e) || !X->getType()->getElement()->isCompatible(e)) { mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); } if (X->getType()->getY() > 1) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); } if (Ap->getType()->getY() > 1) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1"); } int N = sqrt((double)Ap->getType()->getX() * 2); if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap"); } if (incX <= 0) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); } int expectedXDim = 1 + (N - 1) * incX; if ((int)X->getType()->getX() != expectedXDim) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPR"); } return N; } static int validateSYR2(RS* mRS, sp e, RsBlasUplo Uplo, sp X, int incX, sp Y, int incY, sp A) { if (!A->getType()->getElement()->isCompatible(e) || !X->getType()->getElement()->isCompatible(e) || !Y->getType()->getElement()->isCompatible(e)) { mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); } if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); } int N = A->getType()->getX(); if (N != (int)A->getType()->getY()) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a symmetric matrix"); } if (incX <= 0 || incY <= 0) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); } int expectedXDim = 1 + (N - 1) * incX; int expectedYDim = 1 + (N - 1) * incY; if ((int)X->getType()->getX() != expectedXDim || (int)Y->getType()->getX() != expectedYDim) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYR"); } return N; } static int validateSPR2(RS* mRS, sp e, RsBlasUplo Uplo, sp X, int incX, sp Y, int incY, sp Ap) { if (!Ap->getType()->getElement()->isCompatible(e) || !X->getType()->getElement()->isCompatible(e) || !Y->getType()->getElement()->isCompatible(e)) { mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); } if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); } if (Ap->getType()->getY() > 1) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1"); } int N = sqrt((double)Ap->getType()->getX() * 2); if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap"); } if (incX <= 0 || incY <= 0) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); } int expectedXDim = 1 + (N - 1) * incX; int expectedYDim = 1 + (N - 1) * incY; if ((int)X->getType()->getX() != expectedXDim || (int)Y->getType()->getX() != expectedYDim) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPR2"); } return N; } void ScriptIntrinsicBLAS::SSYMV(RsBlasUplo Uplo, float alpha, sp A, sp X, int incX, float beta, sp Y, int incY) { int N = validateSYMV(mRS, Element::F32(mRS), Uplo, A, X, Y, incX, incY); nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssymv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::SSBMV(RsBlasUplo Uplo, int K, float alpha, sp A, sp X, int incX, float beta, sp Y, int incY) { // SBMV is the same as SYMV + K >= 0 if (K < 0) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0"); } int N = validateSYMV(mRS, Element::F32(mRS), Uplo, A, X, Y, incX, incY); nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha, A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::SSPMV(RsBlasUplo Uplo, float alpha, sp Ap, sp X, int incX, float beta, sp Y, int incY) { int N = validateSPMV(mRS, Element::F32(mRS), Uplo, Ap, X, incX, Y, incY); nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, Ap->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::SGER(float alpha, sp X, int incX, sp Y, int incY, sp A) { int M = A->getType()->getY(); int N = A->getType()->getX(); validateGER(mRS, Element::F32(mRS), X, incX, Y, incY, A); nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sger, 0, 0, 0, 0, 0, M, N, 0, alpha, X->getID(), Y->getID(), 0.f, A->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::SSYR(RsBlasUplo Uplo, float alpha, sp X, int incX, sp A) { int N = validateSYR(mRS, Element::F32(mRS), Uplo, X, incX, A); nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X->getID(), A->getID(), 0.f, 0, incX, 0, 0, 0); } void ScriptIntrinsicBLAS::SSPR(RsBlasUplo Uplo, float alpha, sp X, int incX, sp Ap) { int N = validateSPR(mRS, Element::F32(mRS), Uplo, X, incX, Ap); nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X->getID(), Ap->getID(), 0.f, 0, incX, 0, 0, 0); } void ScriptIntrinsicBLAS::SSYR2(RsBlasUplo Uplo, float alpha, sp X, int incX, sp Y, int incY, sp A) { int N = validateSYR2(mRS, Element::F32(mRS), Uplo, X, incX, Y, incY, A); nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X->getID(), Y->getID(), 0, A->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::SSPR2(RsBlasUplo Uplo, float alpha, sp X, int incX, sp Y, int incY, sp Ap) { int N = validateSPR2(mRS, Element::F32(mRS), Uplo, X, incX, Y, incY, Ap); nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X->getID(), Y->getID(), 0, Ap->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::DSYMV(RsBlasUplo Uplo, double alpha, sp A, sp X, int incX, double beta, sp Y, int incY) { int N = validateSYMV(mRS, Element::F64(mRS), Uplo, A, X, Y, incX, incY); nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsymv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::DSBMV(RsBlasUplo Uplo, int K, double alpha, sp A, sp X, int incX, double beta, sp Y, int incY) { // SBMV is the same as SYMV + K >= 0 if (K < 0) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0"); } int N = validateSYMV(mRS, Element::F64(mRS), Uplo, A, X, Y, incX, incY); nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha, A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::DSPMV(RsBlasUplo Uplo, double alpha, sp Ap, sp X, int incX, double beta, sp Y, int incY) { int N = validateSPMV(mRS, Element::F64(mRS), Uplo, Ap, X, incX, Y, incY); nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, Ap->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::DGER(double alpha, sp X, int incX, sp Y, int incY, sp A) { int M = A->getType()->getY(); int N = A->getType()->getX(); validateGER(mRS, Element::F64(mRS), X, incX, Y, incY, A); nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dger, 0, 0, 0, 0, 0, M, N, 0, alpha, X->getID(), Y->getID(), 0.f, A->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::DSYR(RsBlasUplo Uplo, double alpha, sp X, int incX, sp A) { int N = validateSYR(mRS, Element::F64(mRS), Uplo, X, incX, A); nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X->getID(), A->getID(), 0.f, 0, incX, 0, 0, 0); } void ScriptIntrinsicBLAS::DSPR(RsBlasUplo Uplo, double alpha, sp X, int incX, sp Ap) { int N = validateSPR(mRS, Element::F64(mRS), Uplo, X, incX, Ap); nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X->getID(), Ap->getID(), 0.f, 0, incX, 0, 0, 0); } void ScriptIntrinsicBLAS::DSYR2(RsBlasUplo Uplo, double alpha, sp X, int incX, sp Y, int incY, sp A) { int N = validateSYR2(mRS, Element::F64(mRS), Uplo, X, incX, Y, incY, A); nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X->getID(), Y->getID(), 0, A->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::DSPR2(RsBlasUplo Uplo, double alpha, sp X, int incX, sp Y, int incY, sp Ap) { int N = validateSPR2(mRS, Element::F64(mRS), Uplo, X, incX, Y, incY, Ap); nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X->getID(), Y->getID(), 0, Ap->getID(), incX, incY, 0, 0); } /** * Level 2, C and Z only */ static void validateGERU(RS* mRS, sp e, sp X, int incX, sp Y, int incY, sp A) { if (!A->getType()->getElement()->isCompatible(e) || !X->getType()->getElement()->isCompatible(e) || !Y->getType()->getElement()->isCompatible(e)) { mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); } if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); } int M = A->getType()->getY(); int N = A->getType()->getX(); if (incX <= 0 || incY <= 0) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); } int expectedXDim = 1 + (M - 1) * incX; if ((int)X->getType()->getX() != expectedXDim) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GERU"); } int expectedYDim = 1 + (N - 1) * incY; if ((int)Y->getType()->getX() != expectedYDim) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GERU"); } } void ScriptIntrinsicBLAS::CHEMV(RsBlasUplo Uplo, Float2 alpha, sp A, sp X, int incX, Float2 beta, sp Y, int incY) { // HEMV is the same as SYR2 validation-wise int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A); nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chemv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, A->getID(), X->getID(), beta.x, beta.y, Y->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::CHBMV(RsBlasUplo Uplo, int K, Float2 alpha, sp A, sp X, int incX, Float2 beta, sp Y, int incY) { // HBMV is the same as SYR2 validation-wise int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A); if (K < 0) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be 0 or greater for HBMV"); } nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha.x, alpha.y, A->getID(), X->getID(), beta.x, beta.y, Y->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::CHPMV(RsBlasUplo Uplo, Float2 alpha, sp Ap, sp X, int incX, Float2 beta, sp Y, int incY) { // HPMV is the same as SPR2 int N = validateSPR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, Ap); nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, Ap->getID(), X->getID(), beta.x, beta.y, Y->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::CGERU(Float2 alpha, sp X, int incX, sp Y, int incY, sp A) { validateGERU(mRS, Element::F32_2(mRS), X, incX, Y, incY, A); int M = A->getType()->getY(); int N = A->getType()->getX(); nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgeru, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X->getID(), Y->getID(), 0, 0, A->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::CGERC(Float2 alpha, sp X, int incX, sp Y, int incY, sp A) { // Same as GERU validateGERU(mRS, Element::F32_2(mRS), X, incX, Y, incY, A); int M = A->getType()->getY(); int N = A->getType()->getX(); nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgerc, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X->getID(), Y->getID(), 0, 0, A->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::CHER(RsBlasUplo Uplo, float alpha, sp X, int incX, sp A) { // Same as SYR int N = validateSYR(mRS, Element::F32_2(mRS), Uplo, X, incX, A); nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X->getID(), 0, 0, 0, A->getID(), incX, 0, 0, 0); } void ScriptIntrinsicBLAS::CHPR(RsBlasUplo Uplo, float alpha, sp X, int incX, sp Ap) { // Equivalent to SPR for validation int N = validateSPR(mRS, Element::F32_2(mRS), Uplo, X, incX, Ap); nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X->getID(), 0, 0, 0, Ap->getID(), incX, 0, 0, 0); } void ScriptIntrinsicBLAS::CHER2(RsBlasUplo Uplo, Float2 alpha, sp X, int incX, sp Y, int incY, sp A) { // Same as SYR2 int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A); nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X->getID(), Y->getID(), 0, 0, A->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::CHPR2(RsBlasUplo Uplo, Float2 alpha, sp X, int incX, sp Y, int incY, sp Ap) { // Same as SPR2 int N = validateSPR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, Ap); nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X->getID(), Y->getID(), 0, 0, Ap->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::ZHEMV(RsBlasUplo Uplo, Double2 alpha, sp A, sp X, int incX, Double2 beta, sp Y, int incY) { // HEMV is the same as SYR2 validation-wise int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A); nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhemv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, A->getID(), X->getID(), beta.x, beta.y, Y->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::ZHBMV(RsBlasUplo Uplo, int K, Double2 alpha, sp A, sp X, int incX, Double2 beta, sp Y, int incY) { // HBMV is the same as SYR2 validation-wise int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A); if (K < 0) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be 0 or greater for HBMV"); } nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha.x, alpha.y, A->getID(), X->getID(), beta.x, beta.y, Y->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::ZHPMV(RsBlasUplo Uplo, Double2 alpha, sp Ap, sp X, int incX, Double2 beta, sp Y, int incY) { // HPMV is the same as SPR2 int N = validateSPR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, Ap); nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, Ap->getID(), X->getID(), beta.x, beta.y, Y->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::ZGERU(Double2 alpha, sp X, int incX, sp Y, int incY, sp A) { validateGERU(mRS, Element::F64_2(mRS), X, incX, Y, incY, A); int M = A->getType()->getY(); int N = A->getType()->getX(); nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgeru, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X->getID(), Y->getID(), 0, 0, A->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::ZGERC(Double2 alpha, sp X, int incX, sp Y, int incY, sp A) { // Same as GERU validateGERU(mRS, Element::F64_2(mRS), X, incX, Y, incY, A); int M = A->getType()->getY(); int N = A->getType()->getX(); nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgerc, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X->getID(), Y->getID(), 0, 0, A->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::ZHER(RsBlasUplo Uplo, double alpha, sp X, int incX, sp A) { // Same as SYR int N = validateSYR(mRS, Element::F64_2(mRS), Uplo, X, incX, A); nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X->getID(), 0, 0, 0, A->getID(), incX, 0, 0, 0); } void ScriptIntrinsicBLAS::ZHPR(RsBlasUplo Uplo, double alpha, sp X, int incX, sp Ap) { // Equivalent to SPR for validation int N = validateSPR(mRS, Element::F64_2(mRS), Uplo, X, incX, Ap); nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X->getID(), 0, 0, 0, Ap->getID(), incX, 0, 0, 0); } void ScriptIntrinsicBLAS::ZHER2(RsBlasUplo Uplo, Double2 alpha, sp X, int incX, sp Y, int incY, sp A) { // Same as SYR2 int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A); nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X->getID(), Y->getID(), 0, 0, A->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::ZHPR2(RsBlasUplo Uplo, Double2 alpha, sp X, int incX, sp Y, int incY, sp Ap) { // Same as SPR2 int N = validateSPR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, Ap); nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X->getID(), Y->getID(), 0, 0, Ap->getID(), incX, incY, 0, 0); } /** * Level 3 BLAS */ static void validateL3(RS* mRS, sp e, int TransA, int TransB, int Side, sp A, sp B, sp C) { int aM = -1, aN = -1, bM = -1, bN = -1, cM = -1, cN = -1; if ((A != nullptr && !A->getType()->getElement()->isCompatible(e)) || (B != nullptr && !B->getType()->getElement()->isCompatible(e)) || (C != nullptr && !C->getType()->getElement()->isCompatible(e))) { mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); } if (C == nullptr) { // Since matrix C is used to store the result, it cannot be null. mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Allocation C cannot be null"); } cM = C->getType()->getY(); cN = C->getType()->getX(); if (Side == RsBlasRight) { if ((A == nullptr && B != nullptr) || (A != nullptr && B == nullptr)) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Provided Matrix A without Matrix B, or vice versa"); } if (B != nullptr) { bM = A->getType()->getY(); bN = A->getType()->getX(); } if (A != nullptr) { aM = B->getType()->getY(); aN = B->getType()->getX(); } } else { if (A != nullptr) { if (TransA == RsBlasTrans || TransA == RsBlasConjTrans) { aN = A->getType()->getY(); aM = A->getType()->getX(); } else { aM = A->getType()->getY(); aN = A->getType()->getX(); } } if (B != nullptr) { if (TransB == RsBlasTrans || TransB == RsBlasConjTrans) { bN = B->getType()->getY(); bM = B->getType()->getX(); } else { bM = B->getType()->getY(); bN = B->getType()->getX(); } } } if (A != nullptr && B != nullptr && C != nullptr) { if (aN != bM || aM != cM || bN != cN) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions"); } } else if (A != nullptr && C != nullptr) { // A and C only, for SYRK if (cM != cN) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix C is not symmetric"); } if (aM != cM) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions"); } } else if (A != nullptr && B != nullptr) { // A and B only if (aN != bM) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions"); } } } void ScriptIntrinsicBLAS::SGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, float alpha, sp A, sp B, float beta, sp C) { validateL3(mRS, Element::F32(mRS), TransA, TransB, 0, A, B, C); int M = -1, N = -1, K = -1; if (TransA != RsBlasNoTrans) { M = A->getType()->getX(); K = A->getType()->getY(); } else { M = A->getType()->getY(); K = A->getType()->getX(); } if (TransB != RsBlasNoTrans) { N = B->getType()->getY(); } else { N = B->getType()->getX(); } nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha, A->getID(), B->getID(), beta, C->getID(), 0, 0, 0, 0); } void ScriptIntrinsicBLAS::DGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, double alpha, sp A, sp B, double beta, sp C) { validateL3(mRS, Element::F64(mRS), TransA, TransB, 0, A, B, C); int M = -1, N = -1, K = -1; if (TransA != RsBlasNoTrans) { M = A->getType()->getX(); K = A->getType()->getY(); } else { M = A->getType()->getY(); K = A->getType()->getX(); } if (TransB != RsBlasNoTrans) { N = B->getType()->getY(); } else { N = B->getType()->getX(); } nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha, A->getID(), B->getID(), beta, C->getID(), 0, 0, 0, 0); } void ScriptIntrinsicBLAS::CGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, Float2 alpha, sp A, sp B, Float2 beta, sp C) { validateL3(mRS, Element::F32_2(mRS), TransA, TransB, 0, A, B, C); int M = -1, N = -1, K = -1; if (TransA != RsBlasNoTrans) { M = A->getType()->getX(); K = A->getType()->getY(); } else { M = A->getType()->getY(); K = A->getType()->getX(); } if (TransB != RsBlasNoTrans) { N = B->getType()->getY(); } else { N = B->getType()->getX(); } nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha.x, alpha.y, A->getID(), B->getID(), beta.x, beta.y, C->getID(), 0, 0, 0, 0); } void ScriptIntrinsicBLAS::ZGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, Double2 alpha, sp A, sp B, Double2 beta, sp C) { validateL3(mRS, Element::F64_2(mRS), TransA, TransB, 0, A, B, C); int M = -1, N = -1, K = -1; if (TransA != RsBlasNoTrans) { M = A->getType()->getX(); K = A->getType()->getY(); } else { M = A->getType()->getY(); K = A->getType()->getX(); } if (TransB != RsBlasNoTrans) { N = B->getType()->getY(); } else { N = B->getType()->getX(); } nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha.x, alpha.y, A->getID(), B->getID(), beta.x, beta.y, C->getID(), 0, 0, 0, 0); } void ScriptIntrinsicBLAS::SSYMM(RsBlasSide Side, RsBlasUplo Uplo, float alpha, sp A, sp B, float beta, sp C) { //For SYMM, Matrix A should be symmetric if (A->getType()->getX() != A->getType()->getY()) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric"); } validateL3(mRS, Element::F32(mRS), 0, 0, Side, A, B, C); nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssymm, 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0, alpha, A->getID(), B->getID(), beta, C->getID(), 0, 0, 0, 0); } void ScriptIntrinsicBLAS::DSYMM(RsBlasSide Side, RsBlasUplo Uplo, double alpha, sp A, sp B, double beta, sp C) { if (A->getType()->getX() != A->getType()->getY()) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric"); } validateL3(mRS, Element::F64(mRS), 0, 0, Side, A, B, C); nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsymm, 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0, alpha, A->getID(), B->getID(), beta, C->getID(), 0, 0, 0, 0); } void ScriptIntrinsicBLAS::CSYMM(RsBlasSide Side, RsBlasUplo Uplo, Float2 alpha, sp A, sp B, Float2 beta, sp C) { if (A->getType()->getX() != A->getType()->getY()) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric"); } validateL3(mRS, Element::F32_2(mRS), 0, 0, Side, A, B, C); nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csymm, 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0, alpha.x, alpha.y, A->getID(), B->getID(), beta.x, beta.y, C->getID(), 0, 0, 0, 0); } void ScriptIntrinsicBLAS::ZSYMM(RsBlasSide Side, RsBlasUplo Uplo, Double2 alpha, sp A, sp B, Double2 beta, sp C) { if (A->getType()->getX() != A->getType()->getY()) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric"); } validateL3(mRS, Element::F64_2(mRS), 0, 0, Side, A, B, C); nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsymm, 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0, alpha.x, alpha.y, A->getID(), B->getID(), beta.x, beta.y, C->getID(), 0, 0, 0, 0); } void ScriptIntrinsicBLAS::SSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha, sp A, float beta, sp C) { validateL3(mRS, Element::F32(mRS), Trans, 0, 0, A, nullptr, C); int K = -1; if (Trans != RsBlasNoTrans) { K = A->getType()->getY(); } else { K = A->getType()->getX(); } nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyrk, Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, alpha, A->getID(), 0, beta, C->getID(), 0, 0, 0, 0); } void ScriptIntrinsicBLAS::DSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha, sp A, double beta, sp C) { validateL3(mRS, Element::F64(mRS), Trans, 0, 0, A, nullptr, C); int K = -1; if (Trans != RsBlasNoTrans) { K = A->getType()->getY(); } else { K = A->getType()->getX(); } nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyrk, Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, alpha, A->getID(), 0, beta, C->getID(), 0, 0, 0, 0); } void ScriptIntrinsicBLAS::CSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha, sp A, Float2 beta, sp C) { validateL3(mRS, Element::F32_2(mRS), Trans, 0, 0, A, nullptr, C); int K = -1; if (Trans != RsBlasNoTrans) { K = A->getType()->getY(); } else { K = A->getType()->getX(); } nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csyrk, Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, alpha.x, alpha.y, A->getID(), 0, beta.x, beta.y, C->getID(), 0, 0, 0, 0); } void ScriptIntrinsicBLAS::ZSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha, sp A, Double2 beta, sp C) { validateL3(mRS, Element::F64_2(mRS), Trans, 0, 0, A, nullptr, C); int K = -1; if (Trans != RsBlasNoTrans) { K = A->getType()->getY(); } else { K = A->getType()->getX(); } nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsyrk, Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, alpha.x, alpha.y, A->getID(), 0, beta.x, beta.y, C->getID(), 0, 0, 0, 0); } static void validateSYR2K(RS* mRS, sp e, RsBlasTranspose Trans, sp A, sp B, sp C) { if (!A->getType()->getElement()->isCompatible(e) || !B->getType()->getElement()->isCompatible(e) || !C->getType()->getElement()->isCompatible(e)) { mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); } int Cdim = -1; // A is n x k if no transpose, k x n if transpose // C is n x n if (Trans == RsBlasTrans) { // check columns versus C Cdim = A->getType()->getX(); } else { // check rows versus C Cdim = A->getType()->getY(); } if ((int)C->getType()->getX() != Cdim || (int)C->getType()->getY() != Cdim) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid symmetric matrix in SYR2K"); } // A dims == B dims if (A->getType()->getX() != B->getType()->getX() || A->getType()->getY() != B->getType()->getY()) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid A and B in SYR2K"); } } void ScriptIntrinsicBLAS::SSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha, sp A, sp B, float beta, sp C) { validateSYR2K(mRS, Element::F32(mRS), Trans, A, B, C); int K = -1; if (Trans != RsBlasNoTrans) { K = A->getType()->getY(); } else { K = A->getType()->getX(); } nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, alpha, A->getID(), B->getID(), beta, C->getID(), 0, 0, 0, 0); } void ScriptIntrinsicBLAS::DSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha, sp A, sp B, double beta, sp C) { validateSYR2K(mRS, Element::F64(mRS), Trans, A, B, C); int K = -1; if (Trans != RsBlasNoTrans) { K = A->getType()->getY(); } else { K = A->getType()->getX(); } nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr2k, Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, alpha, A->getID(), B->getID(), beta, C->getID(), 0, 0, 0, 0); } void ScriptIntrinsicBLAS::CSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha, sp A, sp B, Float2 beta, sp C) { validateSYR2K(mRS, Element::F32_2(mRS), Trans, A, B, C); int K = -1; if (Trans != RsBlasNoTrans) { K = A->getType()->getY(); } else { K = A->getType()->getX(); } nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csyr2k, Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, alpha.x, alpha.y, A->getID(), B->getID(), beta.x, beta.y, C->getID(), 0, 0, 0, 0); } void ScriptIntrinsicBLAS::ZSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha, sp A, sp B, Double2 beta, sp C) { validateSYR2K(mRS, Element::F64_2(mRS), Trans, A, B, C); int K = -1; if (Trans != RsBlasNoTrans) { K = A->getType()->getY(); } else { K = A->getType()->getX(); } nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsyr2k, Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, alpha.x, alpha.y, A->getID(), B->getID(), beta.x, beta.y, C->getID(), 0, 0, 0, 0); } static void validateTRMM(RS* mRS, sp e, RsBlasSide Side, RsBlasTranspose TransA, sp A, sp B) { int aM = -1, aN = -1, bM = -1, bN = -1; if (!A->getType()->getElement()->isCompatible(e) || !B->getType()->getElement()->isCompatible(e)) { mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); } aM = A->getType()->getY(); aN = A->getType()->getX(); if (aM != aN) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with a non-symmetric matrix A"); } bM = B->getType()->getY(); bN = B->getType()->getX(); if (Side == RsBlasLeft) { if (aN != bM) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with invalid matrices"); } } else { if (bN != aM) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with invalid matrices"); } } } void ScriptIntrinsicBLAS::STRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, float alpha, sp A, sp B) { validateTRMM(mRS, Element::F32(mRS), Side, TransA, A, B); nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strmm, TransA, 0, Side, Uplo, Diag,\ B->getType()->getY(), B->getType()->getX(), 0, alpha, A->getID(), B->getID(), 0.f, 0, 0, 0, 0, 0); } void ScriptIntrinsicBLAS::DTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, double alpha, sp A, sp B) { validateTRMM(mRS, Element::F64(mRS), Side, TransA, A, B); nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrmm, TransA, 0, Side, Uplo, Diag, B->getType()->getY(), B->getType()->getX(), 0, alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0); } void ScriptIntrinsicBLAS::CTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, Float2 alpha, sp A, sp B) { validateTRMM(mRS, Element::F32_2(mRS), Side, TransA, A, B); nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrmm, TransA, 0, Side, Uplo, Diag, B->getType()->getY(), B->getType()->getX(), 0, alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0); } void ScriptIntrinsicBLAS::ZTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, Double2 alpha, sp A, sp B) { validateTRMM(mRS, Element::F64_2(mRS), Side, TransA, A, B); nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrmm, TransA, 0, Side, Uplo, Diag, B->getType()->getY(), B->getType()->getX(), 0, alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0); } static void validateTRSM(RS* mRS, sp e, RsBlasSide Side, RsBlasTranspose TransA, sp A, sp B) { int adim = -1, bM = -1, bN = -1; if (!A->getType()->getElement()->isCompatible(e) || !B->getType()->getElement()->isCompatible(e)) { mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); } adim = A->getType()->getX(); if (adim != (int)A->getType()->getY()) { // This may be unnecessary, the restriction could potentially be relaxed. // Allocation A needs to contain at least that symmetric matrix but could theoretically // be larger for now we assume adapters are sufficient, will reevaluate in the future. mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with a non-symmetric matrix A"); } bM = B->getType()->getY(); bN = B->getType()->getX(); if (Side == RsBlasLeft) { // A is M*M if (adim != bM) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with invalid matrix dimensions"); } } else { // A is N*N if (adim != bN) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with invalid matrix dimensions"); } } } void ScriptIntrinsicBLAS::STRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, float alpha, sp A, sp B) { validateTRSM(mRS, Element::F32(mRS), Side, TransA, A, B); nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strsm, TransA, 0, Side, Uplo, Diag, B->getType()->getY(), B->getType()->getX(), 0, alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0); } void ScriptIntrinsicBLAS::DTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, double alpha, sp A, sp B) { validateTRSM(mRS, Element::F64(mRS), Side, TransA, A, B); nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrsm, TransA, 0, Side, Uplo, Diag, B->getType()->getY(), B->getType()->getX(), 0, alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0); } void ScriptIntrinsicBLAS::CTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, Float2 alpha, sp A, sp B) { validateTRSM(mRS, Element::F32_2(mRS), Side, TransA, A, B); nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrsm, TransA, 0, Side, Uplo, Diag, B->getType()->getY(), B->getType()->getX(), 0, alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0); } void ScriptIntrinsicBLAS::ZTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, Double2 alpha, sp A, sp B) { validateTRSM(mRS, Element::F64_2(mRS), Side, TransA, A, B); nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrsm, TransA, 0, Side, Uplo, Diag, B->getType()->getY(), B->getType()->getX(), 0, alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0); } static void validateHEMM(RS* mRS, sp e, RsBlasSide Side, sp A, sp B, sp C) { if (!A->getType()->getElement()->isCompatible(e) || !B->getType()->getElement()->isCompatible(e) || !C->getType()->getElement()->isCompatible(e)) { mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); } // A must be square; can potentially be relaxed similar to TRSM int adim = A->getType()->getX(); if (adim != (int)A->getType()->getY()) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with non-square A"); } if ((Side == RsBlasLeft && adim != (int)B->getType()->getY()) || (Side == RsBlasRight && adim != (int)B->getType()->getX())) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with invalid B"); } if (B->getType()->getX() != C->getType()->getX() || B->getType()->getY() != C->getType()->getY()) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with mismatched B and C"); } } void ScriptIntrinsicBLAS::CHEMM(RsBlasSide Side, RsBlasUplo Uplo, Float2 alpha, sp A, sp B, Float2 beta, sp C) { validateHEMM(mRS, Element::F32_2(mRS), Side, A, B, C); nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chemm, 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0, alpha.x, alpha.y, A->getID(), B->getID(), beta.x, beta.y, C->getID(), 0, 0, 0, 0); } void ScriptIntrinsicBLAS::ZHEMM(RsBlasSide Side, RsBlasUplo Uplo, Double2 alpha, sp A, sp B, Double2 beta, sp C) { validateHEMM(mRS, Element::F64_2(mRS), Side, A, B, C); nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhemm, 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0, alpha.x, alpha.y, A->getID(), B->getID(), beta.x, beta.y, C->getID(), 0, 0, 0, 0); } static void validateHERK(RS* mRS, sp e, RsBlasTranspose Trans, sp A, sp C) { if (!A->getType()->getElement()->isCompatible(e) || !C->getType()->getElement()->isCompatible(e)) { mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); } if (Trans != RsBlasNoTrans && Trans != RsBlasConjTrans) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Call HERK with invalid Transpose"); } int cdim = C->getType()->getX(); if (cdim != (int)C->getType()->getY()) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with non-square C"); } if (Trans == RsBlasNoTrans) { if (cdim != (int)A->getType()->getY()) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with invalid A"); } } else { if (cdim != (int)A->getType()->getX()) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with invalid A"); } } } void ScriptIntrinsicBLAS::CHERK(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha, sp A, float beta, sp C) { validateHERK(mRS, Element::F32_2(mRS), Trans, A, C); int k = 0; if (Trans == RsBlasConjTrans) { k = A->getType()->getY(); } else { k = A->getType()->getX(); } nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cherk, Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k, alpha, 0, A->getID(), 0, beta, 0, C->getID(), 0, 0, 0, 0); } void ScriptIntrinsicBLAS::ZHERK(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha, sp A, double beta, sp C) { validateHERK(mRS, Element::F64_2(mRS), Trans, A, C); int k = 0; if (Trans == RsBlasConjTrans) { k = A->getType()->getY(); } else { k = A->getType()->getX(); } nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zherk, Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k, alpha, 0, A->getID(), 0, beta, 0, C->getID(), 0, 0, 0, 0); } static void validateHER2K(RS* mRS, sp e, RsBlasTranspose Trans, sp A, sp B, sp C) { if (!A->getType()->getElement()->isCompatible(e) || !B->getType()->getElement()->isCompatible(e) || !C->getType()->getElement()->isCompatible(e)) { mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); } if (Trans != RsBlasNoTrans && Trans != RsBlasConjTrans) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Call HERK with invalid Transpose"); } int cdim = C->getType()->getX(); if (cdim != (int)C->getType()->getY()) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with non-square C"); } if (Trans == RsBlasNoTrans) { if ((int)A->getType()->getY() != cdim) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid matrices"); } } else { if ((int)A->getType()->getX() != cdim) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid matrices"); } } if (A->getType()->getX() != B->getType()->getX() || A->getType()->getY() != B->getType()->getY()) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid A and B matrices"); } } void ScriptIntrinsicBLAS::CHER2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha, sp A, sp B, float beta, sp C) { validateHER2K(mRS, Element::F32_2(mRS), Trans, A, B, C); int k = 0; if (Trans == RsBlasNoTrans) { k = A->getType()->getX(); } else { k = A->getType()->getY(); } nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher2k, Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k, alpha.x, alpha.y, A->getID(), B->getID(), beta, 0, C->getID(), 0, 0, 0, 0); } void ScriptIntrinsicBLAS::ZHER2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha, sp A, sp B, double beta, sp C) { validateHER2K(mRS, Element::F64_2(mRS), Trans, A, B, C); int k = 0; if (Trans == RsBlasNoTrans) { k = A->getType()->getX(); } else { k = A->getType()->getY(); } nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher2k, Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k, alpha.x, alpha.y, A->getID(), B->getID(), beta, 0, C->getID(), 0, 0, 0, 0); } void ScriptIntrinsicBLAS::BNNM(sp A, int a_offset, sp B, int b_offset, sp C, int c_offset, int c_mult) { validateL3(mRS, Element::U8(mRS), RsBlasNoTrans, RsBlasTrans, 0, A, B, C); if (a_offset < 0 || a_offset > 255) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid a_offset passed to BNNM"); } if (b_offset < 0 || b_offset > 255) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid b_offset passed to BNNM"); } int M = -1, N = -1, K = -1; M = A->getType()->getY(); N = B->getType()->getY(); K = A->getType()->getX(); nScriptIntrinsicBLAS_BNNM(mRS, mRS->getContext(), getID(), M, N, K, A->getID(), a_offset, B->getID(), b_offset, C->getID(), c_offset, c_mult); }