1 /*
2  * Copyright (C) 2015 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 
18 #include "RenderScript.h"
19 #include "rsCppInternal.h"
20 
21 using namespace android;
22 using namespace RSC;
23 
24 // ScriptIntrinsicBLAS APIS
ScriptIntrinsicBLAS(sp<RS> rs,sp<const Element> e)25 ScriptIntrinsicBLAS::ScriptIntrinsicBLAS(sp<RS> rs, sp<const Element> e)
26     : ScriptIntrinsic(rs, RS_SCRIPT_INTRINSIC_ID_BLAS, e) {
27 
28 }
29 
create(sp<RS> rs)30 sp<ScriptIntrinsicBLAS> ScriptIntrinsicBLAS::create(sp<RS> rs) {
31     return new ScriptIntrinsicBLAS(rs, Element::U32(rs));
32 }
33 
// Selects which representation of alpha/beta setUpBLASCall() stores into the
// RsBlasCall it builds.
enum RsBlasDataType {
    SINGLE = 0,          // call.alpha.f / call.beta.f
    DOUBLE = 1,          // call.alpha.d / call.beta.d
    SINGLE_COMPLEX = 2,  // call.alpha.c / call.beta.c (r, i)
    DOUBLE_COMPLEX = 3   // call.alpha.z / call.beta.z (r, i)
};
40 
41 static RsBlasCall
setUpBLASCall(RsBlasDataType dataType,RsBlasFunction func,int TransA,int TransB,int Side,int Uplo,int Diag,int M,int N,int K,int incX,int incY,int KL,int KU,float alphaF,float betaF,double alphaD,double betaD,float alphaCX,float alphaCY,float betaCX,float betaCY,double alphaZX,double alphaZY,double betaZX,double betaZY)42 setUpBLASCall(RsBlasDataType dataType, RsBlasFunction func,
43               int TransA, int TransB, int Side, int Uplo, int Diag,
44               int M, int N, int K, int incX, int incY, int KL, int KU,
45               float alphaF, float betaF, double alphaD, double betaD,
46               float alphaCX, float alphaCY, float betaCX, float betaCY,
47               double alphaZX, double alphaZY, double betaZX, double betaZY
48               ) {
49     RsBlasCall call;
50     memset(&call, 0, sizeof(call));
51     call.func = func;
52     call.transA = (RsBlasTranspose)TransA;
53     call.transB = (RsBlasTranspose)TransB;
54     call.side = (RsBlasSide)Side;
55     call.uplo = (RsBlasUplo)Uplo;
56     call.diag = (RsBlasDiag)Diag;
57     call.M = M;
58     call.N = N;
59     call.K = K;
60 
61     switch (dataType) {
62         case SINGLE:
63             // For Single-precision BLAS.
64             call.alpha.f = alphaF;
65             call.beta.f = betaF;
66             break;
67         case DOUBLE:
68             // For Double-precision BLAS.
69             call.alpha.d = alphaD;
70             call.beta.d = betaD;
71             break;
72         case SINGLE_COMPLEX:
73             // For Single-precision complex BLAS.
74             call.alpha.c.r = alphaCX;
75             call.alpha.c.i = alphaCY;
76             call.beta.c.r = betaCX;
77             call.beta.c.i = betaCY;
78             break;
79         case DOUBLE_COMPLEX:
80             // For Double-precision complex BLAS.
81             call.alpha.z.r = alphaZX;
82             call.alpha.z.i = alphaZY;
83             call.beta.z.r = betaZX;
84             call.beta.z.i = betaZY;
85             break;
86         default:
87             break;
88     }
89 
90     call.incX = incX;
91     call.incY = incY;
92     call.KL = KL;
93     call.KU = KU;
94 
95     return call;
96 }
97 
98 static void
nScriptIntrinsicBLAS_Single(RS * mRS,RsContext con,RsScript id,RsBlasFunction func,int TransA,int TransB,int Side,int Uplo,int Diag,int M,int N,int K,float alpha,RsAllocation A,RsAllocation B,float beta,RsAllocation C,int incX,int incY,int KL,int KU)99 nScriptIntrinsicBLAS_Single(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA,
100                             int TransB, int Side, int Uplo, int Diag, int M, int N, int K,
101                             float alpha, RsAllocation A, RsAllocation B,
102                             float beta, RsAllocation C, int incX, int incY, int KL, int KU) {
103     RsBlasCall call = setUpBLASCall(SINGLE, func, TransA, TransB, Side, Uplo, Diag,
104                                     M, N, K, incX, incY, KL, KU, alpha, beta, 0.0, 0.0,
105                                     0.0f, 0.0f, 0.0f, 0.0f, 0.0, 0.0, 0.0, 0.0);
106     RsAllocation in_allocs[3] = {A, B, C};
107     tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, sizeof(in_allocs), nullptr,
108                                                       &call, sizeof(call), nullptr, 0));
109 }
110 
111 
112 static void
nScriptIntrinsicBLAS_Double(RS * mRS,RsContext con,RsScript id,RsBlasFunction func,int TransA,int TransB,int Side,int Uplo,int Diag,int M,int N,int K,double alpha,RsAllocation A,RsAllocation B,double beta,RsAllocation C,int incX,int incY,int KL,int KU)113 nScriptIntrinsicBLAS_Double(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA,
114                             int TransB, int Side, int Uplo, int Diag, int M, int N, int K,
115                             double alpha, RsAllocation A, RsAllocation B,
116                             double beta, RsAllocation C, int incX, int incY, int KL, int KU) {
117     RsBlasCall call = setUpBLASCall(DOUBLE, func, TransA, TransB, Side, Uplo, Diag,
118                                     M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, alpha, beta,
119                                     0.0f, 0.0f, 0.0f, 0.0f, 0.0, 0.0, 0.0, 0.0);
120     RsAllocation in_allocs[3] = {A, B, C};
121     tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, sizeof(in_allocs), nullptr,
122                                                       &call, sizeof(call), nullptr, 0));
123 }
124 
125 static void
nScriptIntrinsicBLAS_Complex(RS * mRS,RsContext con,RsScript id,RsBlasFunction func,int TransA,int TransB,int Side,int Uplo,int Diag,int M,int N,int K,float alphaX,float alphaY,RsAllocation A,RsAllocation B,float betaX,float betaY,RsAllocation C,int incX,int incY,int KL,int KU)126 nScriptIntrinsicBLAS_Complex(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA,
127                              int TransB, int Side, int Uplo, int Diag, int M, int N, int K,
128                              float alphaX, float alphaY, RsAllocation A, RsAllocation B,
129                              float betaX, float betaY, RsAllocation C, int incX, int incY, int KL, int KU) {
130     RsBlasCall call = setUpBLASCall(SINGLE_COMPLEX, func, TransA, TransB, Side, Uplo, Diag,
131                                     M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, 0.0, 0.0,
132                                     alphaX, alphaY, betaX, betaY, 0.0, 0.0, 0.0, 0.0);
133     RsAllocation in_allocs[3] = {A, B, C};
134     tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, sizeof(in_allocs), nullptr,
135                                                       &call, sizeof(call), nullptr, 0));
136 }
137 
138 static void
nScriptIntrinsicBLAS_Z(RS * mRS,RsContext con,RsScript id,RsBlasFunction func,int TransA,int TransB,int Side,int Uplo,int Diag,int M,int N,int K,double alphaX,double alphaY,RsAllocation A,RsAllocation B,double betaX,double betaY,RsAllocation C,int incX,int incY,int KL,int KU)139 nScriptIntrinsicBLAS_Z(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA,
140                        int TransB, int Side, int Uplo, int Diag, int M, int N, int K,
141                        double alphaX, double alphaY, RsAllocation A, RsAllocation B,
142                        double betaX, double betaY, RsAllocation C, int incX, int incY, int KL, int KU) {
143     RsBlasCall call = setUpBLASCall(DOUBLE_COMPLEX, func, TransA, TransB, Side, Uplo, Diag,
144                                     M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, 0.0, 0.0,
145                                     0.0f, 0.0f, 0.0f, 0.0f, alphaX, alphaY, betaX, betaY);
146     RsAllocation in_allocs[3] = {A, B, C};
147     tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, sizeof(in_allocs), nullptr,
148                                                       &call, sizeof(call), nullptr, 0));
149 }
150 
151 
152 static void
nScriptIntrinsicBLAS_BNNM(RS * mRS,RsContext con,RsScript id,int M,int N,int K,RsAllocation A,int a_offset,RsAllocation B,int b_offset,RsAllocation C,int c_offset,int c_mult_int)153 nScriptIntrinsicBLAS_BNNM(RS* mRS, RsContext con, RsScript id, int M, int N, int K,
154                           RsAllocation A, int a_offset, RsAllocation B, int b_offset,
155                           RsAllocation C, int c_offset, int c_mult_int) {
156     RsBlasCall call;
157     memset(&call, 0, sizeof(call));
158     call.func = RsBlas_bnnm;
159     call.M = M;
160     call.N = N;
161     call.K = K;
162     call.a_offset = a_offset & 0xFF;
163     call.b_offset = b_offset & 0xFF;
164     call.c_offset = c_offset;
165     call.c_mult_int = c_mult_int;
166 
167     RsAllocation in_allocs[3] = {A, B, C};
168     tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, sizeof(in_allocs), nullptr,
169                                                       &call, sizeof(call), nullptr, 0));
170 }
171 
172 /**
173  * Level 2 BLAS
174  */
validateGEMV(RS * mRS,sp<const Element> e,RsBlasTranspose TransA,sp<Allocation> A,sp<Allocation> X,int incX,sp<Allocation> Y,int incY)175 static void validateGEMV(RS* mRS, sp<const Element> e, RsBlasTranspose TransA, sp<Allocation> A,
176                          sp<Allocation> X, int incX, sp<Allocation> Y, int incY) {
177     int M = A->getType()->getY();
178     int N = A->getType()->getX();
179     if (!A->getType()->getElement()->isCompatible(e) ||
180         !X->getType()->getElement()->isCompatible(e) ||
181         !Y->getType()->getElement()->isCompatible(e)) {
182         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
183     }
184     if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
185         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
186     }
187 
188     if (incX <= 0 || incY <= 0) {
189         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
190     }
191     int expectedXDim = -1, expectedYDim = -1;
192     if (TransA == RsBlasNoTrans) {
193         expectedXDim = 1 + (N - 1) * incX;
194         expectedYDim = 1 + (M - 1) * incY;
195     } else {
196         expectedXDim = 1 + (M - 1) * incX;
197         expectedYDim = 1 + (N - 1) * incY;
198     }
199     if ((int)X->getType()->getX() != expectedXDim ||
200         (int)Y->getType()->getX() != expectedYDim) {
201         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GEMV");
202     }
203 }
204 
SGEMV(RsBlasTranspose TransA,float alpha,sp<Allocation> A,sp<Allocation> X,int incX,float beta,sp<Allocation> Y,int incY)205 void ScriptIntrinsicBLAS::SGEMV(RsBlasTranspose TransA, float alpha, sp<Allocation> A, sp<Allocation> X,
206                                 int incX, float beta, sp<Allocation> Y, int incY) {
207     validateGEMV(mRS, Element::F32(mRS), TransA, A, X, incX, Y, incY);
208     int M = A->getType()->getY();
209     int N = A->getType()->getX();
210     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgemv,
211                                 TransA, 0, 0, 0, 0, M, N, 0,
212                                 alpha, A->getID(), X->getID(),
213                                 beta, Y->getID(), incX, incY, 0, 0);
214 }
215 
DGEMV(RsBlasTranspose TransA,double alpha,sp<Allocation> A,sp<Allocation> X,int incX,double beta,sp<Allocation> Y,int incY)216 void ScriptIntrinsicBLAS::DGEMV(RsBlasTranspose TransA, double alpha, sp<Allocation> A, sp<Allocation> X,
217                                 int incX, double beta, sp<Allocation> Y, int incY) {
218     validateGEMV(mRS, Element::F64(mRS), TransA, A, X, incX, Y, incY);
219     int M = A->getType()->getY();
220     int N = A->getType()->getX();
221     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgemv,
222                                 TransA, 0, 0, 0, 0, M, N, 0,
223                                 alpha, A->getID(), X->getID(),
224                                 beta, Y->getID(), incX, incY, 0, 0);
225 }
226 
CGEMV(RsBlasTranspose TransA,Float2 alpha,sp<Allocation> A,sp<Allocation> X,int incX,Float2 beta,sp<Allocation> Y,int incY)227 void ScriptIntrinsicBLAS::CGEMV(RsBlasTranspose TransA, Float2 alpha, sp<Allocation> A, sp<Allocation> X,
228                                 int incX, Float2 beta, sp<Allocation> Y, int incY) {
229     validateGEMV(mRS, Element::F32_2(mRS), TransA, A, X, incX, Y, incY);
230     int M = A->getType()->getY();
231     int N = A->getType()->getX();
232     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgemv,
233                                  TransA, 0, 0, 0, 0, M, N, 0,
234                                  alpha.x, alpha.y, A->getID(), X->getID(),
235                                  beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
236 }
237 
ZGEMV(RsBlasTranspose TransA,Double2 alpha,sp<Allocation> A,sp<Allocation> X,int incX,Double2 beta,sp<Allocation> Y,int incY)238 void ScriptIntrinsicBLAS::ZGEMV(RsBlasTranspose TransA, Double2 alpha, sp<Allocation> A, sp<Allocation> X,
239                                 int incX, Double2 beta, sp<Allocation> Y, int incY) {
240     validateGEMV(mRS, Element::F64_2(mRS), TransA, A, X, incX, Y, incY);
241     int M = A->getType()->getY();
242     int N = A->getType()->getX();
243     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgemv,
244                            TransA, 0, 0, 0, 0, M, N, 0,
245                            alpha.x, alpha.y, A->getID(), X->getID(),
246                            beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
247 }
248 
SGBMV(RsBlasTranspose TransA,int KL,int KU,float alpha,sp<Allocation> A,sp<Allocation> X,int incX,float beta,sp<Allocation> Y,int incY)249 void ScriptIntrinsicBLAS::SGBMV(RsBlasTranspose TransA, int KL, int KU, float alpha, sp<Allocation> A,
250                                 sp<Allocation> X, int incX, float beta, sp<Allocation> Y, int incY) {
251     // GBMV has the same validation requirements as GEMV + KL and KU >= 0
252     validateGEMV(mRS, Element::F32(mRS), TransA, A, X, incX, Y, incY);
253     if (KL < 0 || KU < 0) {
254         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0");
255     }
256     int M = A->getType()->getY();
257     int N = A->getType()->getX();
258 
259     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgbmv,
260                                 TransA, 0, 0, 0, 0, M, N, 0,
261                                 alpha, A->getID(), X->getID(),
262                                 beta, Y->getID(), incX, incY, KL, KU);
263 }
264 
DGBMV(RsBlasTranspose TransA,int KL,int KU,double alpha,sp<Allocation> A,sp<Allocation> X,int incX,double beta,sp<Allocation> Y,int incY)265 void ScriptIntrinsicBLAS::DGBMV(RsBlasTranspose TransA, int KL, int KU, double alpha, sp<Allocation> A,
266                                 sp<Allocation> X, int incX, double beta, sp<Allocation> Y, int incY) {
267     // GBMV has the same validation requirements as GEMV + KL and KU >= 0
268     validateGEMV(mRS, Element::F64(mRS), TransA, A, X, incX, Y, incY);
269     if (KL < 0 || KU < 0) {
270         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0");
271     }
272     int M = A->getType()->getY();
273     int N = A->getType()->getX();
274 
275     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgbmv,
276                                 TransA, 0, 0, 0, 0, M, N, 0,
277                                 alpha, A->getID(), X->getID(),
278                                 beta, Y->getID(), incX, incY, KL, KU);
279 }
280 
CGBMV(RsBlasTranspose TransA,int KL,int KU,Float2 alpha,sp<Allocation> A,sp<Allocation> X,int incX,Float2 beta,sp<Allocation> Y,int incY)281 void ScriptIntrinsicBLAS::CGBMV(RsBlasTranspose TransA, int KL, int KU, Float2 alpha, sp<Allocation> A,
282                                 sp<Allocation> X, int incX, Float2 beta, sp<Allocation> Y, int incY) {
283     // GBMV has the same validation requirements as GEMV + KL and KU >= 0
284     validateGEMV(mRS, Element::F32_2(mRS), TransA, A, X, incX, Y, incY);
285     if (KL < 0 || KU < 0) {
286         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0");
287     }
288     int M = A->getType()->getY();
289     int N = A->getType()->getX();
290 
291     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgbmv,
292                                  TransA, 0, 0, 0, 0, M, N, 0,
293                                  alpha.x, alpha.y, A->getID(), X->getID(),
294                                  beta.x, beta.y, Y->getID(), incX, incY, KL, KU);
295 }
296 
ZGBMV(RsBlasTranspose TransA,int KL,int KU,Double2 alpha,sp<Allocation> A,sp<Allocation> X,int incX,Double2 beta,sp<Allocation> Y,int incY)297 void ScriptIntrinsicBLAS::ZGBMV(RsBlasTranspose TransA, int KL, int KU, Double2 alpha, sp<Allocation> A,
298                                 sp<Allocation> X, int incX, Double2 beta, sp<Allocation> Y, int incY) {
299     // GBMV has the same validation requirements as GEMV + KL and KU >= 0
300     validateGEMV(mRS, Element::F64_2(mRS), TransA, A, X, incX, Y, incY);
301     if (KL < 0 || KU < 0) {
302         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0");
303     }
304     int M = A->getType()->getY();
305     int N = A->getType()->getX();
306 
307     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgbmv,
308                            TransA, 0, 0, 0, 0, M, N, 0,
309                            alpha.x, alpha.y, A->getID(), X->getID(),
310                            beta.x, beta.y, Y->getID(), incX, incY, KL, KU);
311 }
312 
validateTRMV(RS * mRS,sp<const Element> e,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,sp<Allocation> A,sp<Allocation> X,int incX)313 static void validateTRMV(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, RsBlasTranspose TransA,
314                          RsBlasDiag Diag, sp<Allocation> A, sp<Allocation> X, int incX) {
315     int N = A->getType()->getY();
316     if ((int)A->getType()->getX() != N) {
317         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a square matrix for TRMV");
318     }
319     if (!A->getType()->getElement()->isCompatible(e) ||
320         !X->getType()->getElement()->isCompatible(e)) {
321         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
322     }
323     if (X->getType()->getY() > 1) {
324         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
325     }
326 
327     if (incX <= 0) {
328         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
329     }
330     int expectedXDim = 1 + (N - 1) * incX;
331     if ((int)X->getType()->getX() != expectedXDim) {
332         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for TRMV");
333     }
334 }
335 
validateTPMV(RS * mRS,sp<const Element> e,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,sp<Allocation> Ap,sp<Allocation> X,int incX)336 static int validateTPMV(RS* mRS, sp<const Element> e,  RsBlasUplo Uplo, RsBlasTranspose TransA,
337                         RsBlasDiag Diag, sp<Allocation> Ap, sp<Allocation> X, int incX) {
338     if (!Ap->getType()->getElement()->isCompatible(e) ||
339         !X->getType()->getElement()->isCompatible(e)) {
340         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
341     }
342     if (X->getType()->getY() > 1) {
343         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
344     }
345 
346     if (Ap->getType()->getY() > 1) {
347         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1");
348     }
349 
350     int N = sqrt((double)Ap->getType()->getX() * 2);
351     if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) {
352         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap");
353     }
354     if (incX <= 0) {
355         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
356     }
357     int expectedXDim = 1 + (N - 1) * incX;
358     if ((int)X->getType()->getX() != expectedXDim) {
359         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for TPMV");
360     }
361 
362     return N;
363 }
364 
365 
STRMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,sp<Allocation> A,sp<Allocation> X,int incX)366 void ScriptIntrinsicBLAS::STRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
367                                 sp<Allocation> A, sp<Allocation> X, int incX) {
368     validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX);
369     int N = A->getType()->getY();
370     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strmv,
371                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
372                                 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
373 }
374 
DTRMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,sp<Allocation> A,sp<Allocation> X,int incX)375 void ScriptIntrinsicBLAS::DTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
376                                 sp<Allocation> A, sp<Allocation> X, int incX) {
377     validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX);
378     int N = A->getType()->getY();
379     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrmv,
380                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
381                                 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
382 }
383 
CTRMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,sp<Allocation> A,sp<Allocation> X,int incX)384 void ScriptIntrinsicBLAS::CTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
385                                 sp<Allocation> A, sp<Allocation> X, int incX) {
386     validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
387     int N = A->getType()->getY();
388     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrmv,
389                                  TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
390                                  A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
391 }
392 
ZTRMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,sp<Allocation> A,sp<Allocation> X,int incX)393 void ScriptIntrinsicBLAS::ZTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
394                                 sp<Allocation> A, sp<Allocation> X, int incX) {
395     validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
396     int N = A->getType()->getY();
397     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrmv,
398                            TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
399                            A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
400 }
401 
STBMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,int K,sp<Allocation> A,sp<Allocation> X,int incX)402 void ScriptIntrinsicBLAS::STBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
403                                 int K, sp<Allocation> A, sp<Allocation> X, int incX) {
404     // TBMV has the same requirements as TRMV + K >= 0
405     if (K < 0) {
406         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
407     }
408     validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX);
409     int N = A->getType()->getY();
410     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stbmv,
411                                 TransA, 0, 0, Uplo, Diag, 0, N, K, 0,
412                                 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
413 }
414 
DTBMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,int K,sp<Allocation> A,sp<Allocation> X,int incX)415 void ScriptIntrinsicBLAS::DTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
416                                 int K, sp<Allocation> A, sp<Allocation> X, int incX) {
417     // TBMV has the same requirements as TRMV + K >= 0
418     if (K < 0) {
419         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
420     }
421     validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX);
422     int N = A->getType()->getY();
423     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtbmv,
424                                 TransA, 0, 0, Uplo, Diag, 0, N, K, 0,
425                                 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
426 }
427 
CTBMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,int K,sp<Allocation> A,sp<Allocation> X,int incX)428 void ScriptIntrinsicBLAS::CTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
429                                 int K, sp<Allocation> A, sp<Allocation> X, int incX) {
430     // TBMV has the same requirements as TRMV + K >= 0
431     if (K < 0) {
432         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
433     }
434     validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
435     int N = A->getType()->getY();
436     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctbmv,
437                                  TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0,
438                                  A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
439 }
440 
ZTBMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,int K,sp<Allocation> A,sp<Allocation> X,int incX)441 void ScriptIntrinsicBLAS::ZTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
442                                 int K, sp<Allocation> A, sp<Allocation> X, int incX) {
443     // TBMV has the same requirements as TRMV + K >= 0
444     if (K < 0) {
445         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
446     }
447     validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
448     int N = A->getType()->getY();
449     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztbmv,
450                            TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0,
451                            A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
452 }
453 
STPMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,sp<Allocation> Ap,sp<Allocation> X,int incX)454 void ScriptIntrinsicBLAS::STPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
455                                 sp<Allocation> Ap, sp<Allocation> X, int incX) {
456     int N = validateTPMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, Ap, X, incX);
457     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stpmv,
458                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
459                                 Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
460 }
461 
DTPMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,sp<Allocation> Ap,sp<Allocation> X,int incX)462 void ScriptIntrinsicBLAS::DTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
463                                 sp<Allocation> Ap, sp<Allocation> X, int incX) {
464     int N = validateTPMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, Ap, X, incX);
465     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtpmv,
466                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
467                                 Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
468 }
469 
CTPMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,sp<Allocation> Ap,sp<Allocation> X,int incX)470 void ScriptIntrinsicBLAS::CTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
471                                 sp<Allocation> Ap,  sp<Allocation> X,  int incX) {
472     int N = validateTPMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
473     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctpmv,
474                                  TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
475                                  Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
476 }
477 
ZTPMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,sp<Allocation> Ap,sp<Allocation> X,int incX)478 void ScriptIntrinsicBLAS::ZTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
479                                 sp<Allocation> Ap, sp<Allocation> X, int incX) {
480     int N = validateTPMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
481     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztpmv,
482                            TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
483                            Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
484 }
485 
STRSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,sp<Allocation> A,sp<Allocation> X,int incX)486 void ScriptIntrinsicBLAS::STRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
487                                 sp<Allocation> A, sp<Allocation> X, int incX) {
488     // TRSV is the same as TRMV
489     validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX);
490     int N = A->getType()->getY();
491     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strsv,
492                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
493                                 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
494 }
495 
DTRSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,sp<Allocation> A,sp<Allocation> X,int incX)496 void ScriptIntrinsicBLAS::DTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
497                                 sp<Allocation> A,  sp<Allocation> X,  int incX) {
498     // TRSV is the same as TRMV
499     validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX);
500     int N = A->getType()->getY();
501     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrsv,
502                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
503                                 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
504 
505 }
506 
CTRSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,sp<Allocation> A,sp<Allocation> X,int incX)507 void ScriptIntrinsicBLAS::CTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
508                                 sp<Allocation> A, sp<Allocation> X, int incX) {
509     // TRSV is the same as TRMV
510     validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
511     int N = A->getType()->getY();
512     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrsv,
513                                  TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
514                                  A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
515 
516 }
517 
ZTRSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,sp<Allocation> A,sp<Allocation> X,int incX)518 void ScriptIntrinsicBLAS::ZTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
519                                 sp<Allocation> A, sp<Allocation> X, int incX) {
520     // TRSV is the same as TRMV
521     validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
522     int N = A->getType()->getY();
523     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrsv,
524                            TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
525                            A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
526 
527 }
528 
STBSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,int K,sp<Allocation> A,sp<Allocation> X,int incX)529 void ScriptIntrinsicBLAS::STBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
530                                 int K, sp<Allocation> A, sp<Allocation> X, int incX) {
531     // TBSV is the same as TRMV + K >= 0
532     validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX);
533     int N = A->getType()->getY();
534     if (K < 0) {
535         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive");
536     }
537     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stbsv,
538                                 TransA, 0, 0, Uplo, Diag, 0, N, K, 0,
539                                 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
540 }
541 
DTBSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,int K,sp<Allocation> A,sp<Allocation> X,int incX)542 void ScriptIntrinsicBLAS::DTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
543                                 int K, sp<Allocation> A, sp<Allocation> X, int incX) {
544     // TBSV is the same as TRMV + K >= 0
545     validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX);
546     int N = A->getType()->getY();
547     if (K < 0) {
548         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive");
549     }
550     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtbsv,
551                                 TransA, 0, 0, Uplo, Diag, 0, N, K, 0,
552                                 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
553 }
554 
CTBSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,int K,sp<Allocation> A,sp<Allocation> X,int incX)555 void ScriptIntrinsicBLAS::CTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
556                                 int K, sp<Allocation> A, sp<Allocation> X, int incX) {
557     // TBSV is the same as TRMV + K >= 0
558     validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
559     int N = A->getType()->getY();
560     if (K < 0) {
561         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive");
562     }
563     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctbsv,
564                                  TransA, 0, 0, Uplo, Diag, 0, N, K,
565                                  0, 0, A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
566 }
567 
ZTBSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,int K,sp<Allocation> A,sp<Allocation> X,int incX)568 void ScriptIntrinsicBLAS::ZTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
569                                 int K, sp<Allocation> A, sp<Allocation> X, int incX) {
570     // TBSV is the same as TRMV + K >= 0
571     validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
572     int N = A->getType()->getY();
573     if (K < 0) {
574         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive");
575     }
576     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztbsv,
577                            TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0,
578                            A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
579 }
580 
STPSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,sp<Allocation> Ap,sp<Allocation> X,int incX)581 void ScriptIntrinsicBLAS::STPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
582                                 sp<Allocation> Ap, sp<Allocation> X, int incX) {
583     // TPSV is same as TPMV
584     int N = validateTPMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, Ap, X, incX);
585     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stpsv,
586                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
587                                 Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
588 }
589 
DTPSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,sp<Allocation> Ap,sp<Allocation> X,int incX)590 void ScriptIntrinsicBLAS::DTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
591                                 sp<Allocation> Ap, sp<Allocation> X, int incX) {
592     // TPSV is same as TPMV
593     int N = validateTPMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, Ap, X, incX);
594     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtpsv,
595                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
596                                 Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
597 }
598 
CTPSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,sp<Allocation> Ap,sp<Allocation> X,int incX)599 void ScriptIntrinsicBLAS::CTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
600                                 sp<Allocation> Ap, sp<Allocation> X, int incX) {
601     // TPSV is same as TPMV
602     int N = validateTPMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
603     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctpsv,
604                                  TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
605                                  Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
606 }
607 
ZTPSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,sp<Allocation> Ap,sp<Allocation> X,int incX)608 void ScriptIntrinsicBLAS::ZTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
609                                 sp<Allocation> Ap, sp<Allocation> X, int incX) {
610     // TPSV is same as TPMV
611     int N = validateTPMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
612     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztpsv,
613                            TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
614                            Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
615 }
616 
617 /**
618  * Level 2, S and D only
619  */
validateSYMV(RS * mRS,sp<const Element> e,RsBlasUplo Uplo,sp<Allocation> A,sp<Allocation> X,sp<Allocation> Y,int incX,int incY)620 static int validateSYMV(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, sp<Allocation> A,
621                         sp<Allocation> X, sp<Allocation> Y, int incX, int incY) {
622     int N = A->getType()->getY();
623     if ((int)A->getType()->getX() != N) {
624         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a square matrix for SYMV");
625     }
626     if (!A->getType()->getElement()->isCompatible(e) ||
627         !X->getType()->getElement()->isCompatible(e) ||
628         !Y->getType()->getElement()->isCompatible(e) ) {
629         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
630     }
631     if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
632         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
633     }
634 
635     if (incX <= 0 || incY <= 0) {
636         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
637     }
638     int expectedXDim = 1 + (N - 1) * incX;
639     if ((int)X->getType()->getX() != expectedXDim) {
640         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYMV");
641     }
642     int expectedYDim = 1 + (N - 1) * incY;
643     if ((int)Y->getType()->getX() != expectedYDim) {
644         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYMV");
645     }
646     return N;
647 }
validateSPMV(RS * mRS,sp<const Element> e,RsBlasUplo Uplo,sp<Allocation> Ap,sp<Allocation> X,int incX,sp<Allocation> Y,int incY)648 static int validateSPMV(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, sp<Allocation> Ap,
649                         sp<Allocation> X, int incX, sp<Allocation> Y, int incY) {
650     if (!Ap->getType()->getElement()->isCompatible(e) ||
651         !X->getType()->getElement()->isCompatible(e) ||
652         !Y->getType()->getElement()->isCompatible(e)) {
653         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
654     }
655     if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
656         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
657     }
658 
659     if (Ap->getType()->getY() > 1) {
660         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1");
661     }
662 
663     int N = sqrt((double)Ap->getType()->getX() * 2);
664     if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) {
665         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap");
666     }
667     if (incX <= 0 || incY <= 0) {
668         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
669     }
670     int expectedXDim = 1 + (N - 1) * incX;
671     if ((int)X->getType()->getX() != expectedXDim) {
672         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPMV");
673     }
674     int expectedYDim = 1 + (N - 1) * incY;
675     if ((int)Y->getType()->getX() != expectedYDim) {
676         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPMV");
677     }
678 
679     return N;
680 }
validateGER(RS * mRS,sp<const Element> e,sp<Allocation> X,int incX,sp<Allocation> Y,int incY,sp<Allocation> A)681 static void validateGER(RS* mRS, sp<const Element> e, sp<Allocation> X, int incX,
682                         sp<Allocation> Y, int incY, sp<Allocation> A) {
683     if (!A->getType()->getElement()->isCompatible(e) ||
684         !X->getType()->getElement()->isCompatible(e) ||
685         !Y->getType()->getElement()->isCompatible(e) ) {
686         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
687     }
688 
689     if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
690         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
691     }
692 
693     int M = A->getType()->getY();
694     int N = A->getType()->getX();
695 
696     if (N < 1 || M < 1) {
697         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "M and N must be 1 or greater for GER");
698     }
699     if (incX <= 0 || incY <= 0) {
700         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
701     }
702     int expectedXDim = 1 + (M - 1) * incX;
703     if ((int)X->getType()->getX() != expectedXDim) {
704         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GER");
705     }
706     int expectedYDim = 1 + (N - 1) * incY;
707     if ((int)Y->getType()->getX() != expectedYDim) {
708         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GER");
709     }
710 
711 
712 }
validateSYR(RS * mRS,sp<const Element> e,RsBlasUplo Uplo,sp<Allocation> X,int incX,sp<Allocation> A)713 static int validateSYR(RS* mRS, sp<const Element> e, RsBlasUplo Uplo,
714                        sp<Allocation> X, int incX, sp<Allocation> A) {
715     if (!A->getType()->getElement()->isCompatible(e) ||
716         !X->getType()->getElement()->isCompatible(e)) {
717         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
718     }
719 
720     int N = A->getType()->getX();
721 
722     if (X->getType()->getY() > 1) {
723         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
724     }
725     if (N != (int)A->getType()->getY()) {
726         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a symmetric matrix");
727     }
728     if (incX <= 0) {
729         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
730     }
731     int expectedXDim = 1 + (N - 1) * incX;
732     if ((int)X->getType()->getX() != expectedXDim) {
733         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYR");
734     }
735     return N;
736 }
validateSPR(RS * mRS,sp<const Element> e,RsBlasUplo Uplo,sp<Allocation> X,int incX,sp<Allocation> Ap)737 static int validateSPR(RS* mRS, sp<const Element> e, RsBlasUplo Uplo,
738                        sp<Allocation> X, int incX, sp<Allocation> Ap) {
739     if (!Ap->getType()->getElement()->isCompatible(e) ||
740         !X->getType()->getElement()->isCompatible(e)) {
741         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
742     }
743     if (X->getType()->getY() > 1) {
744         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
745     }
746 
747     if (Ap->getType()->getY() > 1) {
748         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1");
749     }
750 
751     int N = sqrt((double)Ap->getType()->getX() * 2);
752     if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) {
753         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap");
754     }
755     if (incX <= 0) {
756         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
757     }
758     int expectedXDim = 1 + (N - 1) * incX;
759     if ((int)X->getType()->getX() != expectedXDim) {
760         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPR");
761     }
762 
763     return N;
764 }
765 
validateSYR2(RS * mRS,sp<const Element> e,RsBlasUplo Uplo,sp<Allocation> X,int incX,sp<Allocation> Y,int incY,sp<Allocation> A)766 static int validateSYR2(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, sp<Allocation> X,
767                         int incX, sp<Allocation> Y, int incY, sp<Allocation> A) {
768     if (!A->getType()->getElement()->isCompatible(e) ||
769         !X->getType()->getElement()->isCompatible(e) ||
770         !Y->getType()->getElement()->isCompatible(e)) {
771         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
772     }
773 
774     if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
775         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
776     }
777 
778     int N = A->getType()->getX();
779 
780     if (N != (int)A->getType()->getY()) {
781         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a symmetric matrix");
782     }
783     if (incX <= 0 || incY <= 0) {
784         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
785     }
786     int expectedXDim = 1 + (N - 1) * incX;
787     int expectedYDim = 1 + (N - 1) * incY;
788     if ((int)X->getType()->getX() != expectedXDim || (int)Y->getType()->getX() != expectedYDim) {
789         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYR");
790     }
791     return N;
792 
793 }
validateSPR2(RS * mRS,sp<const Element> e,RsBlasUplo Uplo,sp<Allocation> X,int incX,sp<Allocation> Y,int incY,sp<Allocation> Ap)794 static int validateSPR2(RS* mRS, sp<const Element> e, RsBlasUplo Uplo, sp<Allocation> X,
795                         int incX, sp<Allocation> Y, int incY, sp<Allocation> Ap) {
796     if (!Ap->getType()->getElement()->isCompatible(e) ||
797         !X->getType()->getElement()->isCompatible(e) ||
798         !Y->getType()->getElement()->isCompatible(e)) {
799         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
800     }
801     if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
802         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
803     }
804 
805     if (Ap->getType()->getY() > 1) {
806         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1");
807     }
808 
809     int N = sqrt((double)Ap->getType()->getX() * 2);
810     if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) {
811         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap");
812     }
813     if (incX <= 0 || incY <= 0) {
814         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
815     }
816     int expectedXDim = 1 + (N - 1) * incX;
817     int expectedYDim = 1 + (N - 1) * incY;
818     if ((int)X->getType()->getX() != expectedXDim || (int)Y->getType()->getX() != expectedYDim) {
819         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPR2");
820     }
821 
822     return N;
823 }
824 
SSYMV(RsBlasUplo Uplo,float alpha,sp<Allocation> A,sp<Allocation> X,int incX,float beta,sp<Allocation> Y,int incY)825 void ScriptIntrinsicBLAS::SSYMV(RsBlasUplo Uplo, float alpha, sp<Allocation> A, sp<Allocation> X,
826                                 int incX, float beta, sp<Allocation> Y, int incY) {
827     int N = validateSYMV(mRS, Element::F32(mRS), Uplo, A, X, Y, incX, incY);
828     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssymv,
829                                 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
830                                 A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
831 }
832 
SSBMV(RsBlasUplo Uplo,int K,float alpha,sp<Allocation> A,sp<Allocation> X,int incX,float beta,sp<Allocation> Y,int incY)833 void ScriptIntrinsicBLAS::SSBMV(RsBlasUplo Uplo, int K, float alpha, sp<Allocation> A, sp<Allocation> X,
834                                 int incX, float beta, sp<Allocation> Y, int incY) {
835     // SBMV is the same as SYMV + K >= 0
836     if (K < 0) {
837         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
838     }
839     int N = validateSYMV(mRS, Element::F32(mRS), Uplo, A, X, Y, incX, incY);
840     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssbmv,
841                                 0, 0, 0, Uplo, 0, 0, N, K, alpha,
842                                 A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
843 }
844 
SSPMV(RsBlasUplo Uplo,float alpha,sp<Allocation> Ap,sp<Allocation> X,int incX,float beta,sp<Allocation> Y,int incY)845 void ScriptIntrinsicBLAS::SSPMV(RsBlasUplo Uplo, float alpha, sp<Allocation> Ap, sp<Allocation> X,
846                                 int incX, float beta, sp<Allocation> Y, int incY) {
847     int N = validateSPMV(mRS, Element::F32(mRS), Uplo, Ap, X, incX, Y, incY);
848     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspmv,
849                                 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
850                                 Ap->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
851 }
852 
SGER(float alpha,sp<Allocation> X,int incX,sp<Allocation> Y,int incY,sp<Allocation> A)853 void ScriptIntrinsicBLAS::SGER(float alpha, sp<Allocation> X, int incX,
854                                sp<Allocation> Y, int incY, sp<Allocation> A) {
855     int M = A->getType()->getY();
856     int N = A->getType()->getX();
857     validateGER(mRS, Element::F32(mRS), X, incX, Y, incY, A);
858     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sger,
859                                 0, 0, 0, 0, 0, M, N, 0, alpha,
860                                 X->getID(), Y->getID(), 0.f, A->getID(), incX, incY, 0, 0);
861 }
862 
SSYR(RsBlasUplo Uplo,float alpha,sp<Allocation> X,int incX,sp<Allocation> A)863 void ScriptIntrinsicBLAS::SSYR(RsBlasUplo Uplo, float alpha, sp<Allocation> X,
864                                int incX, sp<Allocation> A) {
865     int N = validateSYR(mRS, Element::F32(mRS), Uplo, X, incX, A);
866     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr,
867                                 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
868                                 X->getID(), A->getID(), 0.f, 0, incX, 0, 0, 0);
869 }
870 
SSPR(RsBlasUplo Uplo,float alpha,sp<Allocation> X,int incX,sp<Allocation> Ap)871 void ScriptIntrinsicBLAS::SSPR(RsBlasUplo Uplo, float alpha, sp<Allocation> X,
872                                int incX, sp<Allocation> Ap) {
873     int N = validateSPR(mRS, Element::F32(mRS), Uplo, X, incX, Ap);
874     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspr,
875                                 0, 0, 0, Uplo, 0, 0, N, 0,
876                                 alpha, X->getID(), Ap->getID(), 0.f, 0, incX, 0, 0, 0);
877 }
878 
SSYR2(RsBlasUplo Uplo,float alpha,sp<Allocation> X,int incX,sp<Allocation> Y,int incY,sp<Allocation> A)879 void ScriptIntrinsicBLAS::SSYR2(RsBlasUplo Uplo, float alpha, sp<Allocation> X, int incX,
880                                 sp<Allocation> Y, int incY, sp<Allocation> A) {
881     int N = validateSYR2(mRS, Element::F32(mRS), Uplo, X, incX, Y, incY, A);
882     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr2,
883                                 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
884                                 X->getID(), Y->getID(), 0, A->getID(), incX, incY, 0, 0);
885 }
886 
SSPR2(RsBlasUplo Uplo,float alpha,sp<Allocation> X,int incX,sp<Allocation> Y,int incY,sp<Allocation> Ap)887 void ScriptIntrinsicBLAS::SSPR2(RsBlasUplo Uplo, float alpha, sp<Allocation> X, int incX,
888                                 sp<Allocation> Y, int incY, sp<Allocation> Ap) {
889     int N = validateSPR2(mRS, Element::F32(mRS), Uplo, X, incX, Y, incY, Ap);
890     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspr2,
891                                 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
892                                 X->getID(), Y->getID(), 0, Ap->getID(), incX, incY, 0, 0);
893 }
894 
DSYMV(RsBlasUplo Uplo,double alpha,sp<Allocation> A,sp<Allocation> X,int incX,double beta,sp<Allocation> Y,int incY)895 void ScriptIntrinsicBLAS::DSYMV(RsBlasUplo Uplo, double alpha, sp<Allocation> A, sp<Allocation> X,
896                                 int incX, double beta, sp<Allocation> Y, int incY) {
897     int N = validateSYMV(mRS, Element::F64(mRS), Uplo, A, X, Y, incX, incY);
898     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsymv,
899                                 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
900                                 A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
901 }
902 
DSBMV(RsBlasUplo Uplo,int K,double alpha,sp<Allocation> A,sp<Allocation> X,int incX,double beta,sp<Allocation> Y,int incY)903 void ScriptIntrinsicBLAS::DSBMV(RsBlasUplo Uplo, int K, double alpha, sp<Allocation> A, sp<Allocation> X,
904                                 int incX, double beta, sp<Allocation> Y, int incY) {
905     // SBMV is the same as SYMV + K >= 0
906     if (K < 0) {
907         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
908     }
909     int N = validateSYMV(mRS, Element::F64(mRS), Uplo, A, X, Y, incX, incY);
910     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsbmv,
911                                 0, 0, 0, Uplo, 0, 0, N, K, alpha,
912                                 A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
913 }
914 
DSPMV(RsBlasUplo Uplo,double alpha,sp<Allocation> Ap,sp<Allocation> X,int incX,double beta,sp<Allocation> Y,int incY)915 void ScriptIntrinsicBLAS::DSPMV(RsBlasUplo Uplo, double alpha, sp<Allocation> Ap, sp<Allocation> X,
916                                 int incX, double beta, sp<Allocation> Y, int incY) {
917     int N = validateSPMV(mRS, Element::F64(mRS), Uplo, Ap, X, incX, Y, incY);
918     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspmv,
919                                 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
920                                 Ap->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
921 }
922 
DGER(double alpha,sp<Allocation> X,int incX,sp<Allocation> Y,int incY,sp<Allocation> A)923 void ScriptIntrinsicBLAS::DGER(double alpha, sp<Allocation> X, int incX, sp<Allocation> Y,
924                                int incY, sp<Allocation> A) {
925     int M = A->getType()->getY();
926     int N = A->getType()->getX();
927     validateGER(mRS, Element::F64(mRS), X, incX, Y, incY, A);
928     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dger,
929                                 0, 0, 0, 0, 0, M, N, 0, alpha,
930                                 X->getID(), Y->getID(), 0.f, A->getID(), incX, incY, 0, 0);
931 }
932 
DSYR(RsBlasUplo Uplo,double alpha,sp<Allocation> X,int incX,sp<Allocation> A)933 void ScriptIntrinsicBLAS::DSYR(RsBlasUplo Uplo, double alpha, sp<Allocation> X,
934                                int incX, sp<Allocation> A) {
935     int N = validateSYR(mRS, Element::F64(mRS), Uplo, X, incX, A);
936     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr,
937                                 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
938                                 X->getID(), A->getID(), 0.f, 0, incX, 0, 0, 0);
939 }
940 
DSPR(RsBlasUplo Uplo,double alpha,sp<Allocation> X,int incX,sp<Allocation> Ap)941 void ScriptIntrinsicBLAS::DSPR(RsBlasUplo Uplo, double alpha, sp<Allocation> X,
942                                int incX, sp<Allocation> Ap) {
943     int N = validateSPR(mRS, Element::F64(mRS), Uplo, X, incX, Ap);
944     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspr,
945                                 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
946                                 X->getID(), Ap->getID(), 0.f, 0, incX, 0, 0, 0);
947 }
948 
DSYR2(RsBlasUplo Uplo,double alpha,sp<Allocation> X,int incX,sp<Allocation> Y,int incY,sp<Allocation> A)949 void ScriptIntrinsicBLAS::DSYR2(RsBlasUplo Uplo, double alpha, sp<Allocation> X, int incX,
950                                 sp<Allocation> Y, int incY, sp<Allocation> A) {
951     int N = validateSYR2(mRS, Element::F64(mRS), Uplo, X, incX, Y, incY, A);
952     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr2,
953                                 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
954                                 X->getID(), Y->getID(), 0, A->getID(), incX, incY, 0, 0);
955 }
956 
DSPR2(RsBlasUplo Uplo,double alpha,sp<Allocation> X,int incX,sp<Allocation> Y,int incY,sp<Allocation> Ap)957 void ScriptIntrinsicBLAS::DSPR2(RsBlasUplo Uplo, double alpha, sp<Allocation> X, int incX,
958                                 sp<Allocation> Y, int incY, sp<Allocation> Ap) {
959     int N = validateSPR2(mRS, Element::F64(mRS), Uplo, X, incX, Y, incY, Ap);
960     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspr2,
961                                 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
962                                 X->getID(), Y->getID(), 0, Ap->getID(), incX, incY, 0, 0);
963 }
964 
965 
966 /**
967  * Level 2, C and Z only
968  */
969 
validateGERU(RS * mRS,sp<const Element> e,sp<Allocation> X,int incX,sp<Allocation> Y,int incY,sp<Allocation> A)970 static void validateGERU(RS* mRS, sp<const Element> e, sp<Allocation> X, int incX,
971                          sp<Allocation> Y, int incY, sp<Allocation> A) {
972     if (!A->getType()->getElement()->isCompatible(e) ||
973         !X->getType()->getElement()->isCompatible(e) ||
974         !Y->getType()->getElement()->isCompatible(e)) {
975         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
976     }
977     if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
978         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
979     }
980 
981     int M = A->getType()->getY();
982     int N = A->getType()->getX();
983     if (incX <= 0 || incY <= 0) {
984         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
985     }
986     int expectedXDim = 1 + (M - 1) * incX;
987     if ((int)X->getType()->getX() != expectedXDim) {
988         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GERU");
989     }
990     int expectedYDim = 1 + (N - 1) * incY;
991     if ((int)Y->getType()->getX() != expectedYDim) {
992         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GERU");
993     }
994 
995 }
996 
CHEMV(RsBlasUplo Uplo,Float2 alpha,sp<Allocation> A,sp<Allocation> X,int incX,Float2 beta,sp<Allocation> Y,int incY)997 void ScriptIntrinsicBLAS::CHEMV(RsBlasUplo Uplo, Float2 alpha, sp<Allocation> A,
998                                 sp<Allocation> X, int incX, Float2 beta, sp<Allocation> Y, int incY) {
999     // HEMV is the same as SYR2 validation-wise
1000     int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A);
1001     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chemv,
1002                                  0, 0, 0, Uplo, 0, 0, N, 0,
1003                                  alpha.x, alpha.y, A->getID(), X->getID(),
1004                                  beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1005 }
1006 
CHBMV(RsBlasUplo Uplo,int K,Float2 alpha,sp<Allocation> A,sp<Allocation> X,int incX,Float2 beta,sp<Allocation> Y,int incY)1007 void ScriptIntrinsicBLAS::CHBMV(RsBlasUplo Uplo, int K, Float2 alpha, sp<Allocation> A,
1008                                 sp<Allocation> X, int incX, Float2 beta, sp<Allocation> Y, int incY) {
1009     // HBMV is the same as SYR2 validation-wise
1010     int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A);
1011     if (K < 0) {
1012         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be 0 or greater for HBMV");
1013     }
1014     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chbmv,
1015                                  0, 0, 0, Uplo, 0, 0, N, K,
1016                                  alpha.x, alpha.y, A->getID(), X->getID(),
1017                                  beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1018 }
1019 
CHPMV(RsBlasUplo Uplo,Float2 alpha,sp<Allocation> Ap,sp<Allocation> X,int incX,Float2 beta,sp<Allocation> Y,int incY)1020 void ScriptIntrinsicBLAS::CHPMV(RsBlasUplo Uplo, Float2 alpha, sp<Allocation> Ap,
1021                                 sp<Allocation> X, int incX, Float2 beta, sp<Allocation> Y, int incY) {
1022     // HPMV is the same as SPR2
1023     int N = validateSPR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, Ap);
1024     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpmv,
1025                                  0, 0, 0, Uplo, 0, 0, N, 0,
1026                                  alpha.x, alpha.y, Ap->getID(), X->getID(),
1027                                  beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1028 }
1029 
CGERU(Float2 alpha,sp<Allocation> X,int incX,sp<Allocation> Y,int incY,sp<Allocation> A)1030 void ScriptIntrinsicBLAS::CGERU(Float2 alpha, sp<Allocation> X, int incX,
1031                                 sp<Allocation> Y, int incY, sp<Allocation> A) {
1032     validateGERU(mRS, Element::F32_2(mRS), X, incX, Y, incY, A);
1033     int M = A->getType()->getY();
1034     int N = A->getType()->getX();
1035     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgeru,
1036                                  0, 0, 0, 0, 0, M, N, 0,
1037                                  alpha.x, alpha.y, X->getID(), Y->getID(),
1038                                  0, 0, A->getID(), incX, incY, 0, 0);
1039 }
1040 
CGERC(Float2 alpha,sp<Allocation> X,int incX,sp<Allocation> Y,int incY,sp<Allocation> A)1041 void ScriptIntrinsicBLAS::CGERC(Float2 alpha, sp<Allocation> X, int incX,
1042                                 sp<Allocation> Y, int incY, sp<Allocation> A) {
1043     // Same as GERU
1044     validateGERU(mRS, Element::F32_2(mRS), X, incX, Y, incY, A);
1045     int M = A->getType()->getY();
1046     int N = A->getType()->getX();
1047     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgerc,
1048                                  0, 0, 0, 0, 0, M, N, 0,
1049                                  alpha.x, alpha.y, X->getID(), Y->getID(),
1050                                  0, 0, A->getID(), incX, incY, 0, 0);
1051 }
1052 
CHER(RsBlasUplo Uplo,float alpha,sp<Allocation> X,int incX,sp<Allocation> A)1053 void ScriptIntrinsicBLAS::CHER(RsBlasUplo Uplo, float alpha, sp<Allocation> X,
1054                                int incX, sp<Allocation> A) {
1055     // Same as SYR
1056     int N = validateSYR(mRS, Element::F32_2(mRS), Uplo, X, incX, A);
1057     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher,
1058                                  0, 0, 0, Uplo, 0, 0, N, 0,
1059                                  alpha, 0, X->getID(), 0,
1060                                  0, 0, A->getID(), incX, 0, 0, 0);
1061 }
1062 
CHPR(RsBlasUplo Uplo,float alpha,sp<Allocation> X,int incX,sp<Allocation> Ap)1063 void ScriptIntrinsicBLAS::CHPR(RsBlasUplo Uplo, float alpha, sp<Allocation> X,
1064                                int incX, sp<Allocation> Ap) {
1065     // Equivalent to SPR for validation
1066     int N = validateSPR(mRS, Element::F32_2(mRS), Uplo, X, incX, Ap);
1067     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpr,
1068                                  0, 0, 0, Uplo, 0, 0, N, 0,
1069                                  alpha, 0, X->getID(), 0,
1070                                  0, 0, Ap->getID(), incX, 0, 0, 0);
1071 }
1072 
CHER2(RsBlasUplo Uplo,Float2 alpha,sp<Allocation> X,int incX,sp<Allocation> Y,int incY,sp<Allocation> A)1073 void ScriptIntrinsicBLAS::CHER2(RsBlasUplo Uplo, Float2 alpha, sp<Allocation> X, int incX,
1074                                 sp<Allocation> Y, int incY, sp<Allocation> A) {
1075     // Same as SYR2
1076     int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A);
1077     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher2,
1078                                  0, 0, 0, Uplo, 0, 0, N, 0,
1079                                  alpha.x, alpha.y, X->getID(), Y->getID(),
1080                                  0, 0, A->getID(), incX, incY, 0, 0);
1081 }
1082 
CHPR2(RsBlasUplo Uplo,Float2 alpha,sp<Allocation> X,int incX,sp<Allocation> Y,int incY,sp<Allocation> Ap)1083 void ScriptIntrinsicBLAS::CHPR2(RsBlasUplo Uplo, Float2 alpha, sp<Allocation> X, int incX,
1084                                 sp<Allocation> Y, int incY, sp<Allocation> Ap) {
1085     // Same as SPR2
1086     int N = validateSPR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, Ap);
1087     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpr2,
1088                                  0, 0, 0, Uplo, 0, 0, N, 0,
1089                                  alpha.x, alpha.y, X->getID(), Y->getID(),
1090                                  0, 0, Ap->getID(), incX, incY, 0, 0);
1091 }
1092 
ZHEMV(RsBlasUplo Uplo,Double2 alpha,sp<Allocation> A,sp<Allocation> X,int incX,Double2 beta,sp<Allocation> Y,int incY)1093 void ScriptIntrinsicBLAS::ZHEMV(RsBlasUplo Uplo, Double2 alpha, sp<Allocation> A,
1094                                 sp<Allocation> X, int incX, Double2 beta, sp<Allocation> Y, int incY) {
1095     // HEMV is the same as SYR2 validation-wise
1096     int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A);
1097     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhemv,
1098                            0, 0, 0, Uplo, 0, 0, N, 0,
1099                            alpha.x, alpha.y, A->getID(), X->getID(),
1100                            beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1101 }
1102 
ZHBMV(RsBlasUplo Uplo,int K,Double2 alpha,sp<Allocation> A,sp<Allocation> X,int incX,Double2 beta,sp<Allocation> Y,int incY)1103 void ScriptIntrinsicBLAS::ZHBMV(RsBlasUplo Uplo, int K, Double2 alpha, sp<Allocation> A, sp<Allocation> X,
1104                                 int incX, Double2 beta, sp<Allocation> Y, int incY) {
1105     // HBMV is the same as SYR2 validation-wise
1106     int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A);
1107     if (K < 0) {
1108         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be 0 or greater for HBMV");
1109     }
1110     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhbmv,
1111                            0, 0, 0, Uplo, 0, 0, N, K,
1112                            alpha.x, alpha.y, A->getID(), X->getID(),
1113                            beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1114 }
1115 
ZHPMV(RsBlasUplo Uplo,Double2 alpha,sp<Allocation> Ap,sp<Allocation> X,int incX,Double2 beta,sp<Allocation> Y,int incY)1116 void ScriptIntrinsicBLAS::ZHPMV(RsBlasUplo Uplo, Double2 alpha, sp<Allocation> Ap, sp<Allocation> X,
1117                                 int incX, Double2 beta, sp<Allocation> Y, int incY) {
1118     // HPMV is the same as SPR2
1119     int N = validateSPR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, Ap);
1120     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpmv,
1121                            0, 0, 0, Uplo, 0, 0, N, 0,
1122                            alpha.x, alpha.y, Ap->getID(), X->getID(),
1123                            beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1124 }
1125 
ZGERU(Double2 alpha,sp<Allocation> X,int incX,sp<Allocation> Y,int incY,sp<Allocation> A)1126 void ScriptIntrinsicBLAS::ZGERU(Double2 alpha, sp<Allocation> X, int incX,
1127                                 sp<Allocation> Y, int incY, sp<Allocation> A) {
1128     validateGERU(mRS, Element::F64_2(mRS), X, incX, Y, incY, A);
1129     int M = A->getType()->getY();
1130     int N = A->getType()->getX();
1131     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgeru,
1132                            0, 0, 0, 0, 0, M, N, 0,
1133                            alpha.x, alpha.y, X->getID(), Y->getID(),
1134                            0, 0, A->getID(), incX, incY, 0, 0);
1135 }
1136 
ZGERC(Double2 alpha,sp<Allocation> X,int incX,sp<Allocation> Y,int incY,sp<Allocation> A)1137 void ScriptIntrinsicBLAS::ZGERC(Double2 alpha, sp<Allocation> X, int incX,
1138                                 sp<Allocation> Y, int incY, sp<Allocation> A) {
1139     // Same as GERU
1140     validateGERU(mRS, Element::F64_2(mRS), X, incX, Y, incY, A);
1141     int M = A->getType()->getY();
1142     int N = A->getType()->getX();
1143     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgerc,
1144                            0, 0, 0, 0, 0, M, N, 0,
1145                            alpha.x, alpha.y, X->getID(), Y->getID(),
1146                            0, 0, A->getID(), incX, incY, 0, 0);
1147 }
1148 
ZHER(RsBlasUplo Uplo,double alpha,sp<Allocation> X,int incX,sp<Allocation> A)1149 void ScriptIntrinsicBLAS::ZHER(RsBlasUplo Uplo, double alpha, sp<Allocation> X,
1150                                int incX, sp<Allocation> A) {
1151     // Same as SYR
1152     int N = validateSYR(mRS, Element::F64_2(mRS), Uplo, X, incX, A);
1153     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher,
1154                            0, 0, 0, Uplo, 0, 0, N, 0,
1155                            alpha, 0, X->getID(), 0,
1156                            0, 0, A->getID(), incX, 0, 0, 0);
1157 }
1158 
ZHPR(RsBlasUplo Uplo,double alpha,sp<Allocation> X,int incX,sp<Allocation> Ap)1159 void ScriptIntrinsicBLAS::ZHPR(RsBlasUplo Uplo, double alpha, sp<Allocation> X,
1160                                int incX, sp<Allocation> Ap) {
1161     // Equivalent to SPR for validation
1162     int N = validateSPR(mRS, Element::F64_2(mRS), Uplo, X, incX, Ap);
1163     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpr,
1164                            0, 0, 0, Uplo, 0, 0, N, 0,
1165                            alpha, 0, X->getID(), 0,
1166                            0, 0, Ap->getID(), incX, 0, 0, 0);
1167 }
1168 
ZHER2(RsBlasUplo Uplo,Double2 alpha,sp<Allocation> X,int incX,sp<Allocation> Y,int incY,sp<Allocation> A)1169 void ScriptIntrinsicBLAS::ZHER2(RsBlasUplo Uplo, Double2 alpha, sp<Allocation> X, int incX,
1170                                 sp<Allocation> Y, int incY, sp<Allocation> A) {
1171     // Same as SYR2
1172     int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A);
1173     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher2,
1174                            0, 0, 0, Uplo, 0, 0, N, 0,
1175                            alpha.x, alpha.y, X->getID(), Y->getID(),
1176                            0, 0, A->getID(), incX, incY, 0, 0);
1177 }
1178 
ZHPR2(RsBlasUplo Uplo,Double2 alpha,sp<Allocation> X,int incX,sp<Allocation> Y,int incY,sp<Allocation> Ap)1179 void ScriptIntrinsicBLAS::ZHPR2(RsBlasUplo Uplo, Double2 alpha, sp<Allocation> X, int incX,
1180                                 sp<Allocation> Y, int incY, sp<Allocation> Ap) {
1181     // Same as SPR2
1182     int N = validateSPR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, Ap);
1183     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpr2,
1184                            0, 0, 0, Uplo, 0, 0, N, 0,
1185                            alpha.x, alpha.y, X->getID(), Y->getID(),
1186                            0, 0, Ap->getID(), incX, incY, 0, 0);
1187 }
1188 
1189 
1190 /**
1191  * Level 3 BLAS
1192  */
1193 
validateL3(RS * mRS,sp<const Element> e,int TransA,int TransB,int Side,sp<Allocation> A,sp<Allocation> B,sp<Allocation> C)1194 static void validateL3(RS* mRS, sp<const Element> e, int TransA, int TransB, int Side,
1195                        sp<Allocation> A, sp<Allocation> B, sp<Allocation> C) {
1196     int aM = -1, aN = -1, bM = -1, bN = -1, cM = -1, cN = -1;
1197     if ((A != nullptr && !A->getType()->getElement()->isCompatible(e)) ||
1198         (B != nullptr && !B->getType()->getElement()->isCompatible(e)) ||
1199         (C != nullptr && !C->getType()->getElement()->isCompatible(e))) {
1200         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1201     }
1202     if (C == nullptr) {
1203         // Since matrix C is used to store the result, it cannot be null.
1204         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Allocation C cannot be null");
1205     }
1206     cM = C->getType()->getY();
1207     cN = C->getType()->getX();
1208 
1209     if (Side == RsBlasRight) {
1210         if ((A == nullptr && B != nullptr) || (A != nullptr && B == nullptr)) {
1211             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Provided Matrix A without Matrix B, or vice versa");
1212         }
1213         if (B != nullptr) {
1214             bM = A->getType()->getY();
1215             bN = A->getType()->getX();
1216         }
1217         if (A != nullptr) {
1218             aM = B->getType()->getY();
1219             aN = B->getType()->getX();
1220         }
1221     } else {
1222         if (A != nullptr) {
1223             if (TransA == RsBlasTrans || TransA == RsBlasConjTrans) {
1224                 aN = A->getType()->getY();
1225                 aM = A->getType()->getX();
1226             } else {
1227                 aM = A->getType()->getY();
1228                 aN = A->getType()->getX();
1229             }
1230         }
1231         if (B != nullptr) {
1232             if (TransB == RsBlasTrans || TransB == RsBlasConjTrans) {
1233                 bN = B->getType()->getY();
1234                 bM = B->getType()->getX();
1235             } else {
1236                 bM = B->getType()->getY();
1237                 bN = B->getType()->getX();
1238             }
1239         }
1240     }
1241     if (A != nullptr && B != nullptr && C != nullptr) {
1242         if (aN != bM || aM != cM || bN != cN) {
1243             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions");
1244         }
1245     } else if (A != nullptr && C != nullptr) {
1246         // A and C only, for SYRK
1247         if (cM != cN) {
1248             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix C is not symmetric");
1249         }
1250         if (aM != cM) {
1251             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions");
1252         }
1253     } else if (A != nullptr && B != nullptr) {
1254         // A and B only
1255         if (aN != bM) {
1256             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions");
1257         }
1258     }
1259 
1260 }
1261 
SGEMM(RsBlasTranspose TransA,RsBlasTranspose TransB,float alpha,sp<Allocation> A,sp<Allocation> B,float beta,sp<Allocation> C)1262 void ScriptIntrinsicBLAS::SGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, float alpha,
1263                                 sp<Allocation> A, sp<Allocation> B, float beta, sp<Allocation> C) {
1264     validateL3(mRS, Element::F32(mRS), TransA, TransB, 0, A, B, C);
1265 
1266     int M = -1, N = -1, K = -1;
1267     if (TransA != RsBlasNoTrans) {
1268         M = A->getType()->getX();
1269         K = A->getType()->getY();
1270     } else {
1271         M = A->getType()->getY();
1272         K = A->getType()->getX();
1273     }
1274     if (TransB != RsBlasNoTrans) {
1275         N = B->getType()->getY();
1276     } else {
1277         N = B->getType()->getX();
1278     }
1279     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgemm,
1280                                 TransA, TransB, 0, 0, 0, M, N, K,
1281                                 alpha, A->getID(), B->getID(),
1282                                 beta, C->getID(), 0, 0, 0, 0);
1283 }
1284 
DGEMM(RsBlasTranspose TransA,RsBlasTranspose TransB,double alpha,sp<Allocation> A,sp<Allocation> B,double beta,sp<Allocation> C)1285 void ScriptIntrinsicBLAS::DGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, double alpha,
1286                                 sp<Allocation> A, sp<Allocation> B, double beta, sp<Allocation> C) {
1287     validateL3(mRS, Element::F64(mRS), TransA, TransB, 0, A, B, C);
1288     int M = -1, N = -1, K = -1;
1289     if (TransA != RsBlasNoTrans) {
1290         M = A->getType()->getX();
1291         K = A->getType()->getY();
1292     } else {
1293         M = A->getType()->getY();
1294         K = A->getType()->getX();
1295     }
1296     if (TransB != RsBlasNoTrans) {
1297         N = B->getType()->getY();
1298     } else {
1299         N = B->getType()->getX();
1300     }
1301     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgemm,
1302                                 TransA, TransB, 0, 0, 0, M, N, K,
1303                                 alpha, A->getID(), B->getID(),
1304                                 beta, C->getID(), 0, 0, 0, 0);
1305 }
1306 
CGEMM(RsBlasTranspose TransA,RsBlasTranspose TransB,Float2 alpha,sp<Allocation> A,sp<Allocation> B,Float2 beta,sp<Allocation> C)1307 void ScriptIntrinsicBLAS::CGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, Float2 alpha,
1308                                 sp<Allocation> A, sp<Allocation> B, Float2 beta, sp<Allocation> C) {
1309     validateL3(mRS, Element::F32_2(mRS), TransA, TransB, 0, A, B, C);
1310     int M = -1, N = -1, K = -1;
1311     if (TransA != RsBlasNoTrans) {
1312         M = A->getType()->getX();
1313         K = A->getType()->getY();
1314     } else {
1315         M = A->getType()->getY();
1316         K = A->getType()->getX();
1317     }
1318     if (TransB != RsBlasNoTrans) {
1319         N = B->getType()->getY();
1320     } else {
1321         N = B->getType()->getX();
1322     }
1323     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgemm,
1324                                  TransA, TransB, 0, 0, 0, M, N, K,
1325                                  alpha.x, alpha.y, A->getID(), B->getID(),
1326                                  beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1327 }
1328 
ZGEMM(RsBlasTranspose TransA,RsBlasTranspose TransB,Double2 alpha,sp<Allocation> A,sp<Allocation> B,Double2 beta,sp<Allocation> C)1329 void ScriptIntrinsicBLAS::ZGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, Double2 alpha,
1330                                 sp<Allocation> A, sp<Allocation> B, Double2 beta, sp<Allocation> C) {
1331     validateL3(mRS, Element::F64_2(mRS), TransA, TransB, 0, A, B, C);
1332     int M = -1, N = -1, K = -1;
1333     if (TransA != RsBlasNoTrans) {
1334         M = A->getType()->getX();
1335         K = A->getType()->getY();
1336     } else {
1337         M = A->getType()->getY();
1338         K = A->getType()->getX();
1339     }
1340     if (TransB != RsBlasNoTrans) {
1341         N = B->getType()->getY();
1342     } else {
1343         N = B->getType()->getX();
1344     }
1345     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgemm,
1346                            TransA, TransB, 0, 0, 0, M, N, K,
1347                            alpha.x, alpha.y, A->getID(), B->getID(),
1348                            beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1349 }
1350 
SSYMM(RsBlasSide Side,RsBlasUplo Uplo,float alpha,sp<Allocation> A,sp<Allocation> B,float beta,sp<Allocation> C)1351 void ScriptIntrinsicBLAS::SSYMM(RsBlasSide Side, RsBlasUplo Uplo, float alpha,
1352                                 sp<Allocation> A, sp<Allocation> B, float beta, sp<Allocation> C) {
1353     //For SYMM, Matrix A should be symmetric
1354     if (A->getType()->getX() != A->getType()->getY()) {
1355         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric");
1356     }
1357     validateL3(mRS, Element::F32(mRS), 0, 0, Side, A, B, C);
1358     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssymm,
1359                                 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0,
1360                                 alpha, A->getID(), B->getID(),
1361                                 beta, C->getID(), 0, 0, 0, 0);
1362 }
1363 
DSYMM(RsBlasSide Side,RsBlasUplo Uplo,double alpha,sp<Allocation> A,sp<Allocation> B,double beta,sp<Allocation> C)1364 void ScriptIntrinsicBLAS::DSYMM(RsBlasSide Side, RsBlasUplo Uplo, double alpha,
1365                                 sp<Allocation> A, sp<Allocation> B, double beta, sp<Allocation> C) {
1366     if (A->getType()->getX() != A->getType()->getY()) {
1367         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric");
1368     }
1369     validateL3(mRS, Element::F64(mRS), 0, 0, Side, A, B, C);
1370     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsymm,
1371                                 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0,
1372                                 alpha, A->getID(), B->getID(),
1373                                 beta, C->getID(), 0, 0, 0, 0);
1374 }
1375 
CSYMM(RsBlasSide Side,RsBlasUplo Uplo,Float2 alpha,sp<Allocation> A,sp<Allocation> B,Float2 beta,sp<Allocation> C)1376 void ScriptIntrinsicBLAS::CSYMM(RsBlasSide Side, RsBlasUplo Uplo, Float2 alpha,
1377                                 sp<Allocation> A, sp<Allocation> B, Float2 beta, sp<Allocation> C) {
1378     if (A->getType()->getX() != A->getType()->getY()) {
1379         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric");
1380     }
1381     validateL3(mRS, Element::F32_2(mRS), 0, 0, Side, A, B, C);
1382     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csymm,
1383                                  0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0,
1384                                  alpha.x, alpha.y, A->getID(), B->getID(),
1385                                  beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1386 }
1387 
ZSYMM(RsBlasSide Side,RsBlasUplo Uplo,Double2 alpha,sp<Allocation> A,sp<Allocation> B,Double2 beta,sp<Allocation> C)1388 void ScriptIntrinsicBLAS::ZSYMM(RsBlasSide Side, RsBlasUplo Uplo, Double2 alpha,
1389                                 sp<Allocation> A, sp<Allocation> B, Double2 beta, sp<Allocation> C) {
1390     if (A->getType()->getX() != A->getType()->getY()) {
1391         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric");
1392     }
1393     validateL3(mRS, Element::F64_2(mRS), 0, 0, Side, A, B, C);
1394     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsymm,
1395                            0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0,
1396                            alpha.x, alpha.y, A->getID(), B->getID(),
1397                            beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1398 }
1399 
SSYRK(RsBlasUplo Uplo,RsBlasTranspose Trans,float alpha,sp<Allocation> A,float beta,sp<Allocation> C)1400 void ScriptIntrinsicBLAS::SSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha,
1401                                 sp<Allocation> A, float beta, sp<Allocation> C) {
1402     validateL3(mRS, Element::F32(mRS), Trans, 0, 0, A, nullptr, C);
1403     int K = -1;
1404     if (Trans != RsBlasNoTrans) {
1405         K = A->getType()->getY();
1406     } else {
1407         K = A->getType()->getX();
1408     }
1409     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyrk,
1410                                 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1411                                 alpha, A->getID(), 0,
1412                                 beta, C->getID(), 0, 0, 0, 0);
1413 }
1414 
DSYRK(RsBlasUplo Uplo,RsBlasTranspose Trans,double alpha,sp<Allocation> A,double beta,sp<Allocation> C)1415 void ScriptIntrinsicBLAS::DSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha,
1416                                 sp<Allocation> A, double beta, sp<Allocation> C) {
1417     validateL3(mRS, Element::F64(mRS), Trans, 0, 0, A, nullptr, C);
1418     int K = -1;
1419     if (Trans != RsBlasNoTrans) {
1420         K = A->getType()->getY();
1421     } else {
1422         K = A->getType()->getX();
1423     }
1424     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyrk,
1425                                 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1426                                 alpha, A->getID(), 0,
1427                                 beta, C->getID(), 0, 0, 0, 0);
1428 }
1429 
CSYRK(RsBlasUplo Uplo,RsBlasTranspose Trans,Float2 alpha,sp<Allocation> A,Float2 beta,sp<Allocation> C)1430 void ScriptIntrinsicBLAS::CSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha,
1431                                 sp<Allocation> A, Float2 beta, sp<Allocation> C) {
1432     validateL3(mRS, Element::F32_2(mRS), Trans, 0, 0, A, nullptr, C);
1433     int K = -1;
1434     if (Trans != RsBlasNoTrans) {
1435         K = A->getType()->getY();
1436     } else {
1437         K = A->getType()->getX();
1438     }
1439     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csyrk,
1440                                  Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1441                                  alpha.x, alpha.y, A->getID(), 0,
1442                                  beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1443 }
1444 
ZSYRK(RsBlasUplo Uplo,RsBlasTranspose Trans,Double2 alpha,sp<Allocation> A,Double2 beta,sp<Allocation> C)1445 void ScriptIntrinsicBLAS::ZSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha,
1446                                 sp<Allocation> A, Double2 beta, sp<Allocation> C) {
1447     validateL3(mRS, Element::F64_2(mRS), Trans, 0, 0, A, nullptr, C);
1448     int K = -1;
1449     if (Trans != RsBlasNoTrans) {
1450         K = A->getType()->getY();
1451     } else {
1452         K = A->getType()->getX();
1453     }
1454     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsyrk,
1455                            Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1456                            alpha.x, alpha.y, A->getID(), 0,
1457                            beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1458 }
1459 
validateSYR2K(RS * mRS,sp<const Element> e,RsBlasTranspose Trans,sp<Allocation> A,sp<Allocation> B,sp<Allocation> C)1460 static void validateSYR2K(RS* mRS, sp<const Element> e, RsBlasTranspose Trans,
1461                           sp<Allocation> A, sp<Allocation> B, sp<Allocation> C) {
1462     if (!A->getType()->getElement()->isCompatible(e) ||
1463         !B->getType()->getElement()->isCompatible(e) ||
1464         !C->getType()->getElement()->isCompatible(e)) {
1465         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1466     }
1467     int Cdim = -1;
1468     // A is n x k if no transpose, k x n if transpose
1469     // C is n x n
1470     if (Trans == RsBlasTrans) {
1471         // check columns versus C
1472         Cdim = A->getType()->getX();
1473     } else {
1474         // check rows versus C
1475         Cdim = A->getType()->getY();
1476     }
1477     if ((int)C->getType()->getX() != Cdim || (int)C->getType()->getY() != Cdim) {
1478         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid symmetric matrix in SYR2K");
1479     }
1480     // A dims == B dims
1481     if (A->getType()->getX() != B->getType()->getX() || A->getType()->getY() != B->getType()->getY()) {
1482         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid A and B in SYR2K");
1483     }
1484 }
1485 
SSYR2K(RsBlasUplo Uplo,RsBlasTranspose Trans,float alpha,sp<Allocation> A,sp<Allocation> B,float beta,sp<Allocation> C)1486 void ScriptIntrinsicBLAS::SSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha,
1487                                  sp<Allocation> A, sp<Allocation> B, float beta, sp<Allocation> C) {
1488     validateSYR2K(mRS, Element::F32(mRS), Trans, A, B, C);
1489     int K = -1;
1490     if (Trans != RsBlasNoTrans) {
1491         K = A->getType()->getY();
1492     } else {
1493         K = A->getType()->getX();
1494     }
1495     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr2k,
1496                                 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1497                                 alpha, A->getID(), B->getID(),
1498                                 beta, C->getID(), 0, 0, 0, 0);
1499 }
1500 
DSYR2K(RsBlasUplo Uplo,RsBlasTranspose Trans,double alpha,sp<Allocation> A,sp<Allocation> B,double beta,sp<Allocation> C)1501 void ScriptIntrinsicBLAS::DSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha,
1502                                  sp<Allocation> A, sp<Allocation> B, double beta, sp<Allocation> C) {
1503     validateSYR2K(mRS, Element::F64(mRS), Trans, A, B, C);
1504     int K = -1;
1505     if (Trans != RsBlasNoTrans) {
1506         K = A->getType()->getY();
1507     } else {
1508         K = A->getType()->getX();
1509     }
1510     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr2k,
1511                                 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1512                                 alpha, A->getID(), B->getID(),
1513                                 beta, C->getID(), 0, 0, 0, 0);
1514 }
1515 
CSYR2K(RsBlasUplo Uplo,RsBlasTranspose Trans,Float2 alpha,sp<Allocation> A,sp<Allocation> B,Float2 beta,sp<Allocation> C)1516 void ScriptIntrinsicBLAS::CSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha,
1517                                  sp<Allocation> A, sp<Allocation> B, Float2 beta, sp<Allocation> C) {
1518     validateSYR2K(mRS, Element::F32_2(mRS), Trans, A, B, C);
1519     int K = -1;
1520     if (Trans != RsBlasNoTrans) {
1521         K = A->getType()->getY();
1522     } else {
1523         K = A->getType()->getX();
1524     }
1525     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csyr2k,
1526                                  Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1527                                  alpha.x, alpha.y, A->getID(), B->getID(),
1528                                  beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1529 }
1530 
ZSYR2K(RsBlasUplo Uplo,RsBlasTranspose Trans,Double2 alpha,sp<Allocation> A,sp<Allocation> B,Double2 beta,sp<Allocation> C)1531 void ScriptIntrinsicBLAS::ZSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha,
1532                                  sp<Allocation> A, sp<Allocation> B, Double2 beta, sp<Allocation> C) {
1533     validateSYR2K(mRS, Element::F64_2(mRS), Trans, A, B, C);
1534     int K = -1;
1535     if (Trans != RsBlasNoTrans) {
1536         K = A->getType()->getY();
1537     } else {
1538         K = A->getType()->getX();
1539     }
1540     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsyr2k,
1541                            Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1542                            alpha.x, alpha.y, A->getID(), B->getID(),
1543                            beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1544 }
1545 
validateTRMM(RS * mRS,sp<const Element> e,RsBlasSide Side,RsBlasTranspose TransA,sp<Allocation> A,sp<Allocation> B)1546 static void validateTRMM(RS* mRS, sp<const Element> e, RsBlasSide Side, RsBlasTranspose TransA,
1547                          sp<Allocation> A, sp<Allocation> B) {
1548     int aM = -1, aN = -1, bM = -1, bN = -1;
1549     if (!A->getType()->getElement()->isCompatible(e) ||
1550         !B->getType()->getElement()->isCompatible(e)) {
1551         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1552     }
1553 
1554     aM = A->getType()->getY();
1555     aN = A->getType()->getX();
1556     if (aM != aN) {
1557         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with a non-symmetric matrix A");
1558     }
1559 
1560     bM = B->getType()->getY();
1561     bN = B->getType()->getX();
1562     if (Side == RsBlasLeft) {
1563         if (aN != bM) {
1564             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with invalid matrices");
1565         }
1566     } else {
1567         if (bN != aM) {
1568             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with invalid matrices");
1569         }
1570     }
1571 }
1572 
STRMM(RsBlasSide Side,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,float alpha,sp<Allocation> A,sp<Allocation> B)1573 void ScriptIntrinsicBLAS::STRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1574                                 float alpha, sp<Allocation> A, sp<Allocation> B) {
1575     validateTRMM(mRS, Element::F32(mRS), Side, TransA, A, B);
1576     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strmm,
1577                                 TransA, 0, Side, Uplo, Diag,\
1578                                 B->getType()->getY(), B->getType()->getX(), 0,
1579                                 alpha, A->getID(), B->getID(), 0.f, 0, 0, 0, 0, 0);
1580 }
1581 
DTRMM(RsBlasSide Side,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,double alpha,sp<Allocation> A,sp<Allocation> B)1582 void ScriptIntrinsicBLAS::DTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1583                                 double alpha, sp<Allocation> A, sp<Allocation> B) {
1584     validateTRMM(mRS, Element::F64(mRS), Side, TransA, A, B);
1585     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrmm,
1586                                 TransA, 0, Side, Uplo, Diag,
1587                                 B->getType()->getY(), B->getType()->getX(), 0,
1588                                 alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0);
1589 }
1590 
CTRMM(RsBlasSide Side,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,Float2 alpha,sp<Allocation> A,sp<Allocation> B)1591 void ScriptIntrinsicBLAS::CTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1592                                 Float2 alpha, sp<Allocation> A, sp<Allocation> B) {
1593     validateTRMM(mRS, Element::F32_2(mRS), Side, TransA, A, B);
1594     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrmm,
1595                                  TransA, 0, Side, Uplo, Diag,
1596                                  B->getType()->getY(), B->getType()->getX(), 0,
1597                                  alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0);
1598 }
1599 
ZTRMM(RsBlasSide Side,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,Double2 alpha,sp<Allocation> A,sp<Allocation> B)1600 void ScriptIntrinsicBLAS::ZTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1601                                 Double2 alpha, sp<Allocation> A, sp<Allocation> B) {
1602     validateTRMM(mRS, Element::F64_2(mRS), Side, TransA, A, B);
1603     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrmm,
1604                            TransA, 0, Side, Uplo, Diag,
1605                            B->getType()->getY(), B->getType()->getX(), 0,
1606                            alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0);
1607 }
1608 
validateTRSM(RS * mRS,sp<const Element> e,RsBlasSide Side,RsBlasTranspose TransA,sp<Allocation> A,sp<Allocation> B)1609 static void validateTRSM(RS* mRS, sp<const Element> e, RsBlasSide Side, RsBlasTranspose TransA,
1610                          sp<Allocation> A, sp<Allocation> B) {
1611     int adim = -1, bM = -1, bN = -1;
1612     if (!A->getType()->getElement()->isCompatible(e) ||
1613         !B->getType()->getElement()->isCompatible(e)) {
1614         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1615     }
1616     adim = A->getType()->getX();
1617     if (adim != (int)A->getType()->getY()) {
1618         // This may be unnecessary, the restriction could potentially be relaxed.
1619         // Allocation A needs to contain at least that symmetric matrix but could theoretically
1620         // be larger for now we assume adapters are sufficient, will reevaluate in the future.
1621         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with a non-symmetric matrix A");
1622     }
1623     bM = B->getType()->getY();
1624     bN = B->getType()->getX();
1625     if (Side == RsBlasLeft) {
1626         // A is M*M
1627         if (adim != bM) {
1628             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with invalid matrix dimensions");
1629         }
1630     } else {
1631         // A is N*N
1632         if (adim != bN) {
1633             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with invalid matrix dimensions");
1634         }
1635     }
1636 }
1637 
STRSM(RsBlasSide Side,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,float alpha,sp<Allocation> A,sp<Allocation> B)1638 void ScriptIntrinsicBLAS::STRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1639                                 float alpha, sp<Allocation> A, sp<Allocation> B) {
1640     validateTRSM(mRS, Element::F32(mRS), Side, TransA, A, B);
1641     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strsm,
1642                                 TransA, 0, Side, Uplo, Diag,
1643                                 B->getType()->getY(), B->getType()->getX(), 0,
1644                                 alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0);
1645 }
1646 
DTRSM(RsBlasSide Side,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,double alpha,sp<Allocation> A,sp<Allocation> B)1647 void ScriptIntrinsicBLAS::DTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1648                                 double alpha, sp<Allocation> A, sp<Allocation> B) {
1649     validateTRSM(mRS, Element::F64(mRS), Side, TransA, A, B);
1650     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrsm,
1651                                 TransA, 0, Side, Uplo, Diag,
1652                                 B->getType()->getY(), B->getType()->getX(), 0,
1653                                 alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0);
1654 }
1655 
CTRSM(RsBlasSide Side,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,Float2 alpha,sp<Allocation> A,sp<Allocation> B)1656 void ScriptIntrinsicBLAS::CTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1657                                 Float2 alpha, sp<Allocation> A, sp<Allocation> B) {
1658     validateTRSM(mRS, Element::F32_2(mRS), Side, TransA, A, B);
1659     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrsm,
1660                                  TransA, 0, Side, Uplo, Diag,
1661                                  B->getType()->getY(), B->getType()->getX(), 0,
1662                                  alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0);
1663 }
1664 
ZTRSM(RsBlasSide Side,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,Double2 alpha,sp<Allocation> A,sp<Allocation> B)1665 void ScriptIntrinsicBLAS::ZTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1666                                 Double2 alpha, sp<Allocation> A, sp<Allocation> B) {
1667     validateTRSM(mRS, Element::F64_2(mRS), Side, TransA, A, B);
1668     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrsm,
1669                            TransA, 0, Side, Uplo, Diag,
1670                            B->getType()->getY(), B->getType()->getX(), 0,
1671                            alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0);
1672 }
1673 
validateHEMM(RS * mRS,sp<const Element> e,RsBlasSide Side,sp<Allocation> A,sp<Allocation> B,sp<Allocation> C)1674 static void validateHEMM(RS* mRS, sp<const Element> e, RsBlasSide Side,
1675                          sp<Allocation> A, sp<Allocation> B, sp<Allocation> C) {
1676     if (!A->getType()->getElement()->isCompatible(e) ||
1677         !B->getType()->getElement()->isCompatible(e) ||
1678         !C->getType()->getElement()->isCompatible(e)) {
1679         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1680     }
1681 
1682     // A must be square; can potentially be relaxed similar to TRSM
1683     int adim = A->getType()->getX();
1684     if (adim != (int)A->getType()->getY()) {
1685         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with non-square A");
1686     }
1687     if ((Side == RsBlasLeft && adim != (int)B->getType()->getY()) ||
1688         (Side == RsBlasRight && adim != (int)B->getType()->getX())) {
1689         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with invalid B");
1690     }
1691     if (B->getType()->getX() != C->getType()->getX() ||
1692         B->getType()->getY() != C->getType()->getY()) {
1693         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with mismatched B and C");
1694     }
1695 }
1696 
CHEMM(RsBlasSide Side,RsBlasUplo Uplo,Float2 alpha,sp<Allocation> A,sp<Allocation> B,Float2 beta,sp<Allocation> C)1697 void ScriptIntrinsicBLAS::CHEMM(RsBlasSide Side, RsBlasUplo Uplo, Float2 alpha,
1698                                 sp<Allocation> A, sp<Allocation> B, Float2 beta, sp<Allocation> C) {
1699     validateHEMM(mRS, Element::F32_2(mRS), Side, A, B, C);
1700     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chemm,
1701                                  0, 0, Side, Uplo, 0,
1702                                  C->getType()->getY(), C->getType()->getX(), 0,
1703                                  alpha.x, alpha.y, A->getID(), B->getID(),
1704                                  beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1705 }
1706 
ZHEMM(RsBlasSide Side,RsBlasUplo Uplo,Double2 alpha,sp<Allocation> A,sp<Allocation> B,Double2 beta,sp<Allocation> C)1707 void ScriptIntrinsicBLAS::ZHEMM(RsBlasSide Side, RsBlasUplo Uplo, Double2 alpha,
1708                                 sp<Allocation> A, sp<Allocation> B, Double2 beta, sp<Allocation> C) {
1709     validateHEMM(mRS, Element::F64_2(mRS), Side, A, B, C);
1710     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhemm,
1711                            0, 0, Side, Uplo, 0,
1712                            C->getType()->getY(), C->getType()->getX(), 0,
1713                            alpha.x, alpha.y, A->getID(), B->getID(),
1714                            beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1715 }
1716 
validateHERK(RS * mRS,sp<const Element> e,RsBlasTranspose Trans,sp<Allocation> A,sp<Allocation> C)1717 static void validateHERK(RS* mRS, sp<const Element> e, RsBlasTranspose Trans,
1718                          sp<Allocation> A, sp<Allocation> C) {
1719     if (!A->getType()->getElement()->isCompatible(e) ||
1720         !C->getType()->getElement()->isCompatible(e)) {
1721         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1722     }
1723     if (Trans != RsBlasNoTrans && Trans != RsBlasConjTrans) {
1724         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Call HERK with invalid Transpose");
1725     }
1726     int cdim = C->getType()->getX();
1727     if (cdim != (int)C->getType()->getY()) {
1728         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with non-square C");
1729     }
1730     if (Trans == RsBlasNoTrans) {
1731         if (cdim != (int)A->getType()->getY()) {
1732             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with invalid A");
1733         }
1734     } else {
1735         if (cdim != (int)A->getType()->getX()) {
1736             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with invalid A");
1737         }
1738     }
1739 }
1740 
CHERK(RsBlasUplo Uplo,RsBlasTranspose Trans,float alpha,sp<Allocation> A,float beta,sp<Allocation> C)1741 void ScriptIntrinsicBLAS::CHERK(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha,
1742                                 sp<Allocation> A, float beta, sp<Allocation> C) {
1743     validateHERK(mRS, Element::F32_2(mRS), Trans, A, C);
1744     int k = 0;
1745     if (Trans == RsBlasConjTrans) {
1746         k = A->getType()->getY();
1747     } else {
1748         k = A->getType()->getX();
1749     }
1750     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cherk,
1751                                  Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k,
1752                                  alpha, 0, A->getID(), 0,
1753                                  beta, 0, C->getID(), 0, 0, 0, 0);
1754 }
1755 
ZHERK(RsBlasUplo Uplo,RsBlasTranspose Trans,double alpha,sp<Allocation> A,double beta,sp<Allocation> C)1756 void ScriptIntrinsicBLAS::ZHERK(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha,
1757                                 sp<Allocation> A, double beta, sp<Allocation> C) {
1758     validateHERK(mRS, Element::F64_2(mRS), Trans, A, C);
1759     int k = 0;
1760     if (Trans == RsBlasConjTrans) {
1761         k = A->getType()->getY();
1762     } else {
1763         k = A->getType()->getX();
1764     }
1765     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zherk,
1766                            Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k,
1767                            alpha, 0, A->getID(), 0,
1768                            beta, 0, C->getID(), 0, 0, 0, 0);
1769 }
1770 
validateHER2K(RS * mRS,sp<const Element> e,RsBlasTranspose Trans,sp<Allocation> A,sp<Allocation> B,sp<Allocation> C)1771 static void validateHER2K(RS* mRS, sp<const Element> e, RsBlasTranspose Trans,
1772                           sp<Allocation> A, sp<Allocation> B, sp<Allocation> C) {
1773     if (!A->getType()->getElement()->isCompatible(e) ||
1774         !B->getType()->getElement()->isCompatible(e) ||
1775         !C->getType()->getElement()->isCompatible(e)) {
1776         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1777     }
1778     if (Trans != RsBlasNoTrans && Trans != RsBlasConjTrans) {
1779         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Call HERK with invalid Transpose");
1780     }
1781     int cdim = C->getType()->getX();
1782     if (cdim != (int)C->getType()->getY()) {
1783         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with non-square C");
1784     }
1785     if (Trans == RsBlasNoTrans) {
1786         if ((int)A->getType()->getY() != cdim) {
1787             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid matrices");
1788         }
1789     } else {
1790         if ((int)A->getType()->getX() != cdim) {
1791             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid matrices");
1792         }
1793     }
1794     if (A->getType()->getX() != B->getType()->getX() || A->getType()->getY() != B->getType()->getY()) {
1795         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid A and B matrices");
1796     }
1797 }
1798 
CHER2K(RsBlasUplo Uplo,RsBlasTranspose Trans,Float2 alpha,sp<Allocation> A,sp<Allocation> B,float beta,sp<Allocation> C)1799 void ScriptIntrinsicBLAS::CHER2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha,
1800                                  sp<Allocation> A, sp<Allocation> B, float beta, sp<Allocation> C) {
1801     validateHER2K(mRS, Element::F32_2(mRS), Trans, A, B, C);
1802     int k = 0;
1803     if (Trans == RsBlasNoTrans) {
1804         k = A->getType()->getX();
1805     } else {
1806         k = A->getType()->getY();
1807     }
1808     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher2k,
1809                                  Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k,
1810                                  alpha.x, alpha.y, A->getID(), B->getID(),
1811                                  beta, 0, C->getID(), 0, 0, 0, 0);
1812 }
1813 
ZHER2K(RsBlasUplo Uplo,RsBlasTranspose Trans,Double2 alpha,sp<Allocation> A,sp<Allocation> B,double beta,sp<Allocation> C)1814 void ScriptIntrinsicBLAS::ZHER2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha,
1815                                  sp<Allocation> A, sp<Allocation> B, double beta, sp<Allocation> C) {
1816     validateHER2K(mRS, Element::F64_2(mRS), Trans, A, B, C);
1817     int k = 0;
1818     if (Trans == RsBlasNoTrans) {
1819         k = A->getType()->getX();
1820     } else {
1821         k = A->getType()->getY();
1822     }
1823     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher2k,
1824                            Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k,
1825                            alpha.x, alpha.y, A->getID(), B->getID(),
1826                            beta, 0, C->getID(), 0, 0, 0, 0);
1827 }
1828 
1829 
1830 
BNNM(sp<Allocation> A,int a_offset,sp<Allocation> B,int b_offset,sp<Allocation> C,int c_offset,int c_mult)1831 void ScriptIntrinsicBLAS::BNNM(sp<Allocation> A, int a_offset, sp<Allocation> B, int b_offset,
1832                                sp<Allocation> C, int c_offset, int c_mult) {
1833     validateL3(mRS, Element::U8(mRS), RsBlasNoTrans, RsBlasTrans, 0, A, B, C);
1834 
1835     if (a_offset < 0 || a_offset > 255) {
1836         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid a_offset passed to BNNM");
1837     }
1838     if (b_offset < 0 || b_offset > 255) {
1839         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid b_offset passed to BNNM");
1840     }
1841     int M = -1, N = -1, K = -1;
1842     M = A->getType()->getY();
1843     N = B->getType()->getY();
1844     K = A->getType()->getX();
1845 
1846     nScriptIntrinsicBLAS_BNNM(mRS, mRS->getContext(), getID(), M, N, K, A->getID(), a_offset,
1847                               B->getID(), b_offset, C->getID(), c_offset, c_mult);
1848 }
1849