1 /*
2  * Copyright (C) 2015 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 
18 #include "RenderScript.h"
19 #include "rsCppInternal.h"
20 
21 #define NELEM(m) (sizeof(m) / sizeof((m)[0]))
22 
23 using android::RSC::Allocation;
24 using android::RSC::Element;
25 using android::RSC::RS;
26 using android::RSC::RS_ERROR_INVALID_ELEMENT;
27 using android::RSC::RS_ERROR_INVALID_PARAMETER;
28 using android::RSC::RS_SUCCESS;
29 using android::RSC::ScriptIntrinsicBLAS;
30 using android::RSC::sp;
31 
32 // ScriptIntrinsicBLAS APIS
ScriptIntrinsicBLAS(sp<RS> rs,sp<const Element> e)33 ScriptIntrinsicBLAS::ScriptIntrinsicBLAS(sp<RS> rs, sp<const Element> e)
34     : ScriptIntrinsic(rs, RS_SCRIPT_INTRINSIC_ID_BLAS, e) {
35 
36 }
37 
/**
 * Factory for the BLAS intrinsic.
 *
 * @param rs RenderScript context the intrinsic is created in.
 * @return a new ScriptIntrinsicBLAS constructed with a U32 element.
 */
sp<ScriptIntrinsicBLAS> ScriptIntrinsicBLAS::create(const sp<RS>& rs) {
    const sp<const Element> u32Element = Element::U32(rs);
    return new ScriptIntrinsicBLAS(rs, u32Element);
}
41 
// Selects which alpha/beta representation setUpBLASCall() stores in the call.
enum RsBlasDataType {
    SINGLE         = 0,  // float scalars
    DOUBLE         = 1,  // double scalars
    SINGLE_COMPLEX = 2,  // float (real, imaginary) pairs
    DOUBLE_COMPLEX = 3   // double (real, imaginary) pairs
};
48 
49 static RsBlasCall
setUpBLASCall(RsBlasDataType dataType,RsBlasFunction func,int TransA,int TransB,int Side,int Uplo,int Diag,int M,int N,int K,int incX,int incY,int KL,int KU,float alphaF,float betaF,double alphaD,double betaD,float alphaCX,float alphaCY,float betaCX,float betaCY,double alphaZX,double alphaZY,double betaZX,double betaZY)50 setUpBLASCall(RsBlasDataType dataType, RsBlasFunction func,
51               int TransA, int TransB, int Side, int Uplo, int Diag,
52               int M, int N, int K, int incX, int incY, int KL, int KU,
53               float alphaF, float betaF, double alphaD, double betaD,
54               float alphaCX, float alphaCY, float betaCX, float betaCY,
55               double alphaZX, double alphaZY, double betaZX, double betaZY
56               ) {
57     RsBlasCall call;
58     memset(&call, 0, sizeof(call));
59     call.func = func;
60     call.transA = (RsBlasTranspose)TransA;
61     call.transB = (RsBlasTranspose)TransB;
62     call.side = (RsBlasSide)Side;
63     call.uplo = (RsBlasUplo)Uplo;
64     call.diag = (RsBlasDiag)Diag;
65     call.M = M;
66     call.N = N;
67     call.K = K;
68 
69     switch (dataType) {
70         case SINGLE:
71             // For Single-precision BLAS.
72             call.alpha.f = alphaF;
73             call.beta.f = betaF;
74             break;
75         case DOUBLE:
76             // For Double-precision BLAS.
77             call.alpha.d = alphaD;
78             call.beta.d = betaD;
79             break;
80         case SINGLE_COMPLEX:
81             // For Single-precision complex BLAS.
82             call.alpha.c.r = alphaCX;
83             call.alpha.c.i = alphaCY;
84             call.beta.c.r = betaCX;
85             call.beta.c.i = betaCY;
86             break;
87         case DOUBLE_COMPLEX:
88             // For Double-precision complex BLAS.
89             call.alpha.z.r = alphaZX;
90             call.alpha.z.i = alphaZY;
91             call.beta.z.r = betaZX;
92             call.beta.z.i = betaZY;
93             break;
94         default:
95             break;
96     }
97 
98     call.incX = incX;
99     call.incY = incY;
100     call.KL = KL;
101     call.KU = KU;
102 
103     return call;
104 }
105 
106 static void
nScriptIntrinsicBLAS_Single(RS * mRS,RsContext con,RsScript id,RsBlasFunction func,int TransA,int TransB,int Side,int Uplo,int Diag,int M,int N,int K,float alpha,RsAllocation A,RsAllocation B,float beta,RsAllocation C,int incX,int incY,int KL,int KU)107 nScriptIntrinsicBLAS_Single(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA,
108                             int TransB, int Side, int Uplo, int Diag, int M, int N, int K,
109                             float alpha, RsAllocation A, RsAllocation B,
110                             float beta, RsAllocation C, int incX, int incY, int KL, int KU) {
111     RsBlasCall call = setUpBLASCall(SINGLE, func, TransA, TransB, Side, Uplo, Diag,
112                                     M, N, K, incX, incY, KL, KU, alpha, beta, 0.0, 0.0,
113                                     0.0f, 0.0f, 0.0f, 0.0f, 0.0, 0.0, 0.0, 0.0);
114     RsAllocation in_allocs[3] = {A, B, C};
115     tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, NELEM(in_allocs), nullptr,
116                                                       &call, sizeof(call), nullptr, 0));
117 }
118 
119 
120 static void
nScriptIntrinsicBLAS_Double(RS * mRS,RsContext con,RsScript id,RsBlasFunction func,int TransA,int TransB,int Side,int Uplo,int Diag,int M,int N,int K,double alpha,RsAllocation A,RsAllocation B,double beta,RsAllocation C,int incX,int incY,int KL,int KU)121 nScriptIntrinsicBLAS_Double(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA,
122                             int TransB, int Side, int Uplo, int Diag, int M, int N, int K,
123                             double alpha, RsAllocation A, RsAllocation B,
124                             double beta, RsAllocation C, int incX, int incY, int KL, int KU) {
125     RsBlasCall call = setUpBLASCall(DOUBLE, func, TransA, TransB, Side, Uplo, Diag,
126                                     M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, alpha, beta,
127                                     0.0f, 0.0f, 0.0f, 0.0f, 0.0, 0.0, 0.0, 0.0);
128     RsAllocation in_allocs[3] = {A, B, C};
129     tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, NELEM(in_allocs), nullptr,
130                                                       &call, sizeof(call), nullptr, 0));
131 }
132 
133 static void
nScriptIntrinsicBLAS_Complex(RS * mRS,RsContext con,RsScript id,RsBlasFunction func,int TransA,int TransB,int Side,int Uplo,int Diag,int M,int N,int K,float alphaX,float alphaY,RsAllocation A,RsAllocation B,float betaX,float betaY,RsAllocation C,int incX,int incY,int KL,int KU)134 nScriptIntrinsicBLAS_Complex(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA,
135                              int TransB, int Side, int Uplo, int Diag, int M, int N, int K,
136                              float alphaX, float alphaY, RsAllocation A, RsAllocation B,
137                              float betaX, float betaY, RsAllocation C, int incX, int incY, int KL, int KU) {
138     RsBlasCall call = setUpBLASCall(SINGLE_COMPLEX, func, TransA, TransB, Side, Uplo, Diag,
139                                     M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, 0.0, 0.0,
140                                     alphaX, alphaY, betaX, betaY, 0.0, 0.0, 0.0, 0.0);
141     RsAllocation in_allocs[3] = {A, B, C};
142     tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, NELEM(in_allocs), nullptr,
143                                                       &call, sizeof(call), nullptr, 0));
144 }
145 
146 static void
nScriptIntrinsicBLAS_Z(RS * mRS,RsContext con,RsScript id,RsBlasFunction func,int TransA,int TransB,int Side,int Uplo,int Diag,int M,int N,int K,double alphaX,double alphaY,RsAllocation A,RsAllocation B,double betaX,double betaY,RsAllocation C,int incX,int incY,int KL,int KU)147 nScriptIntrinsicBLAS_Z(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA,
148                        int TransB, int Side, int Uplo, int Diag, int M, int N, int K,
149                        double alphaX, double alphaY, RsAllocation A, RsAllocation B,
150                        double betaX, double betaY, RsAllocation C, int incX, int incY, int KL, int KU) {
151     RsBlasCall call = setUpBLASCall(DOUBLE_COMPLEX, func, TransA, TransB, Side, Uplo, Diag,
152                                     M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, 0.0, 0.0,
153                                     0.0f, 0.0f, 0.0f, 0.0f, alphaX, alphaY, betaX, betaY);
154     RsAllocation in_allocs[3] = {A, B, C};
155     tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, NELEM(in_allocs), nullptr,
156                                                       &call, sizeof(call), nullptr, 0));
157 }
158 
159 
160 static void
nScriptIntrinsicBLAS_BNNM(RS * mRS,RsContext con,RsScript id,int M,int N,int K,RsAllocation A,int a_offset,RsAllocation B,int b_offset,RsAllocation C,int c_offset,int c_mult_int)161 nScriptIntrinsicBLAS_BNNM(RS* mRS, RsContext con, RsScript id, int M, int N, int K,
162                           RsAllocation A, int a_offset, RsAllocation B, int b_offset,
163                           RsAllocation C, int c_offset, int c_mult_int) {
164     RsBlasCall call;
165     memset(&call, 0, sizeof(call));
166     call.func = RsBlas_bnnm;
167     call.M = M;
168     call.N = N;
169     call.K = K;
170     call.a_offset = a_offset & 0xFF;
171     call.b_offset = b_offset & 0xFF;
172     call.c_offset = c_offset;
173     call.c_mult_int = c_mult_int;
174 
175     RsAllocation in_allocs[3] = {A, B, C};
176     tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, NELEM(in_allocs), nullptr,
177                                                       &call, sizeof(call), nullptr, 0));
178 }
179 
180 /**
181  * Level 2 BLAS
182  */
validateGEMV(RS * mRS,const sp<const Element> & e,RsBlasTranspose TransA,const sp<Allocation> & A,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY)183 static void validateGEMV(RS* mRS, const sp<const Element>& e, RsBlasTranspose TransA, const sp<Allocation>& A,
184                          const sp<Allocation>& X, int incX, const sp<Allocation>& Y, int incY) {
185     int M = A->getType()->getY();
186     int N = A->getType()->getX();
187     if (!A->getType()->getElement()->isCompatible(e) ||
188         !X->getType()->getElement()->isCompatible(e) ||
189         !Y->getType()->getElement()->isCompatible(e)) {
190         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
191     }
192     if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
193         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
194     }
195 
196     if (incX <= 0 || incY <= 0) {
197         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
198     }
199     int expectedXDim = -1, expectedYDim = -1;
200     if (TransA == RsBlasNoTrans) {
201         expectedXDim = 1 + (N - 1) * incX;
202         expectedYDim = 1 + (M - 1) * incY;
203     } else {
204         expectedXDim = 1 + (M - 1) * incX;
205         expectedYDim = 1 + (N - 1) * incY;
206     }
207     if ((int)X->getType()->getX() != expectedXDim ||
208         (int)Y->getType()->getX() != expectedYDim) {
209         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GEMV");
210     }
211 }
212 
SGEMV(RsBlasTranspose TransA,float alpha,const sp<Allocation> & A,const sp<Allocation> & X,int incX,float beta,const sp<Allocation> & Y,int incY)213 void ScriptIntrinsicBLAS::SGEMV(RsBlasTranspose TransA, float alpha, const sp<Allocation>& A, const sp<Allocation>& X,
214                                 int incX, float beta, const sp<Allocation>& Y, int incY) {
215     validateGEMV(mRS, Element::F32(mRS), TransA, A, X, incX, Y, incY);
216     int M = A->getType()->getY();
217     int N = A->getType()->getX();
218     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgemv,
219                                 TransA, 0, 0, 0, 0, M, N, 0,
220                                 alpha, A->getID(), X->getID(),
221                                 beta, Y->getID(), incX, incY, 0, 0);
222 }
223 
DGEMV(RsBlasTranspose TransA,double alpha,const sp<Allocation> & A,const sp<Allocation> & X,int incX,double beta,const sp<Allocation> & Y,int incY)224 void ScriptIntrinsicBLAS::DGEMV(RsBlasTranspose TransA, double alpha, const sp<Allocation>& A, const sp<Allocation>& X,
225                                 int incX, double beta, const sp<Allocation>& Y, int incY) {
226     validateGEMV(mRS, Element::F64(mRS), TransA, A, X, incX, Y, incY);
227     int M = A->getType()->getY();
228     int N = A->getType()->getX();
229     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgemv,
230                                 TransA, 0, 0, 0, 0, M, N, 0,
231                                 alpha, A->getID(), X->getID(),
232                                 beta, Y->getID(), incX, incY, 0, 0);
233 }
234 
CGEMV(RsBlasTranspose TransA,Float2 alpha,const sp<Allocation> & A,const sp<Allocation> & X,int incX,Float2 beta,const sp<Allocation> & Y,int incY)235 void ScriptIntrinsicBLAS::CGEMV(RsBlasTranspose TransA, Float2 alpha, const sp<Allocation>& A, const sp<Allocation>& X,
236                                 int incX, Float2 beta, const sp<Allocation>& Y, int incY) {
237     validateGEMV(mRS, Element::F32_2(mRS), TransA, A, X, incX, Y, incY);
238     int M = A->getType()->getY();
239     int N = A->getType()->getX();
240     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgemv,
241                                  TransA, 0, 0, 0, 0, M, N, 0,
242                                  alpha.x, alpha.y, A->getID(), X->getID(),
243                                  beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
244 }
245 
ZGEMV(RsBlasTranspose TransA,Double2 alpha,const sp<Allocation> & A,const sp<Allocation> & X,int incX,Double2 beta,const sp<Allocation> & Y,int incY)246 void ScriptIntrinsicBLAS::ZGEMV(RsBlasTranspose TransA, Double2 alpha, const sp<Allocation>& A, const sp<Allocation>& X,
247                                 int incX, Double2 beta, const sp<Allocation>& Y, int incY) {
248     validateGEMV(mRS, Element::F64_2(mRS), TransA, A, X, incX, Y, incY);
249     int M = A->getType()->getY();
250     int N = A->getType()->getX();
251     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgemv,
252                            TransA, 0, 0, 0, 0, M, N, 0,
253                            alpha.x, alpha.y, A->getID(), X->getID(),
254                            beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
255 }
256 
SGBMV(RsBlasTranspose TransA,int KL,int KU,float alpha,const sp<Allocation> & A,const sp<Allocation> & X,int incX,float beta,const sp<Allocation> & Y,int incY)257 void ScriptIntrinsicBLAS::SGBMV(RsBlasTranspose TransA, int KL, int KU, float alpha, const sp<Allocation>& A,
258                                 const sp<Allocation>& X, int incX, float beta, const sp<Allocation>& Y, int incY) {
259     // GBMV has the same validation requirements as GEMV + KL and KU >= 0
260     validateGEMV(mRS, Element::F32(mRS), TransA, A, X, incX, Y, incY);
261     if (KL < 0 || KU < 0) {
262         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0");
263     }
264     int M = A->getType()->getY();
265     int N = A->getType()->getX();
266 
267     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgbmv,
268                                 TransA, 0, 0, 0, 0, M, N, 0,
269                                 alpha, A->getID(), X->getID(),
270                                 beta, Y->getID(), incX, incY, KL, KU);
271 }
272 
DGBMV(RsBlasTranspose TransA,int KL,int KU,double alpha,const sp<Allocation> & A,const sp<Allocation> & X,int incX,double beta,const sp<Allocation> & Y,int incY)273 void ScriptIntrinsicBLAS::DGBMV(RsBlasTranspose TransA, int KL, int KU, double alpha, const sp<Allocation>& A,
274                                 const sp<Allocation>& X, int incX, double beta, const sp<Allocation>& Y, int incY) {
275     // GBMV has the same validation requirements as GEMV + KL and KU >= 0
276     validateGEMV(mRS, Element::F64(mRS), TransA, A, X, incX, Y, incY);
277     if (KL < 0 || KU < 0) {
278         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0");
279     }
280     int M = A->getType()->getY();
281     int N = A->getType()->getX();
282 
283     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgbmv,
284                                 TransA, 0, 0, 0, 0, M, N, 0,
285                                 alpha, A->getID(), X->getID(),
286                                 beta, Y->getID(), incX, incY, KL, KU);
287 }
288 
CGBMV(RsBlasTranspose TransA,int KL,int KU,Float2 alpha,const sp<Allocation> & A,const sp<Allocation> & X,int incX,Float2 beta,const sp<Allocation> & Y,int incY)289 void ScriptIntrinsicBLAS::CGBMV(RsBlasTranspose TransA, int KL, int KU, Float2 alpha, const sp<Allocation>& A,
290                                 const sp<Allocation>& X, int incX, Float2 beta, const sp<Allocation>& Y, int incY) {
291     // GBMV has the same validation requirements as GEMV + KL and KU >= 0
292     validateGEMV(mRS, Element::F32_2(mRS), TransA, A, X, incX, Y, incY);
293     if (KL < 0 || KU < 0) {
294         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0");
295     }
296     int M = A->getType()->getY();
297     int N = A->getType()->getX();
298 
299     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgbmv,
300                                  TransA, 0, 0, 0, 0, M, N, 0,
301                                  alpha.x, alpha.y, A->getID(), X->getID(),
302                                  beta.x, beta.y, Y->getID(), incX, incY, KL, KU);
303 }
304 
ZGBMV(RsBlasTranspose TransA,int KL,int KU,Double2 alpha,const sp<Allocation> & A,const sp<Allocation> & X,int incX,Double2 beta,const sp<Allocation> & Y,int incY)305 void ScriptIntrinsicBLAS::ZGBMV(RsBlasTranspose TransA, int KL, int KU, Double2 alpha, const sp<Allocation>& A,
306                                 const sp<Allocation>& X, int incX, Double2 beta, const sp<Allocation>& Y, int incY) {
307     // GBMV has the same validation requirements as GEMV + KL and KU >= 0
308     validateGEMV(mRS, Element::F64_2(mRS), TransA, A, X, incX, Y, incY);
309     if (KL < 0 || KU < 0) {
310         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0");
311     }
312     int M = A->getType()->getY();
313     int N = A->getType()->getX();
314 
315     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgbmv,
316                            TransA, 0, 0, 0, 0, M, N, 0,
317                            alpha.x, alpha.y, A->getID(), X->getID(),
318                            beta.x, beta.y, Y->getID(), incX, incY, KL, KU);
319 }
320 
validateTRMV(RS * mRS,const sp<const Element> & e,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,const sp<Allocation> & A,const sp<Allocation> & X,int incX)321 static void validateTRMV(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo, RsBlasTranspose TransA,
322                          RsBlasDiag Diag, const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
323     int N = A->getType()->getY();
324     if ((int)A->getType()->getX() != N) {
325         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a square matrix for TRMV");
326     }
327     if (!A->getType()->getElement()->isCompatible(e) ||
328         !X->getType()->getElement()->isCompatible(e)) {
329         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
330     }
331     if (X->getType()->getY() > 1) {
332         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
333     }
334 
335     if (incX <= 0) {
336         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
337     }
338     int expectedXDim = 1 + (N - 1) * incX;
339     if ((int)X->getType()->getX() != expectedXDim) {
340         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for TRMV");
341     }
342 }
343 
validateTPMV(RS * mRS,const sp<const Element> & e,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,const sp<Allocation> & Ap,const sp<Allocation> & X,int incX)344 static int validateTPMV(RS* mRS, const sp<const Element>& e,  RsBlasUplo Uplo, RsBlasTranspose TransA,
345                         RsBlasDiag Diag, const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) {
346     if (!Ap->getType()->getElement()->isCompatible(e) ||
347         !X->getType()->getElement()->isCompatible(e)) {
348         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
349     }
350     if (X->getType()->getY() > 1) {
351         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
352     }
353 
354     if (Ap->getType()->getY() > 1) {
355         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1");
356     }
357 
358     int N = sqrt((double)Ap->getType()->getX() * 2);
359     if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) {
360         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap");
361     }
362     if (incX <= 0) {
363         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
364     }
365     int expectedXDim = 1 + (N - 1) * incX;
366     if ((int)X->getType()->getX() != expectedXDim) {
367         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for TPMV");
368     }
369 
370     return N;
371 }
372 
373 
STRMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,const sp<Allocation> & A,const sp<Allocation> & X,int incX)374 void ScriptIntrinsicBLAS::STRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
375                                 const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
376     validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX);
377     int N = A->getType()->getY();
378     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strmv,
379                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
380                                 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
381 }
382 
DTRMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,const sp<Allocation> & A,const sp<Allocation> & X,int incX)383 void ScriptIntrinsicBLAS::DTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
384                                 const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
385     validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX);
386     int N = A->getType()->getY();
387     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrmv,
388                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
389                                 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
390 }
391 
CTRMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,const sp<Allocation> & A,const sp<Allocation> & X,int incX)392 void ScriptIntrinsicBLAS::CTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
393                                 const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
394     validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
395     int N = A->getType()->getY();
396     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrmv,
397                                  TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
398                                  A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
399 }
400 
ZTRMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,const sp<Allocation> & A,const sp<Allocation> & X,int incX)401 void ScriptIntrinsicBLAS::ZTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
402                                 const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
403     validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
404     int N = A->getType()->getY();
405     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrmv,
406                            TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
407                            A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
408 }
409 
STBMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,int K,const sp<Allocation> & A,const sp<Allocation> & X,int incX)410 void ScriptIntrinsicBLAS::STBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
411                                 int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
412     // TBMV has the same requirements as TRMV + K >= 0
413     if (K < 0) {
414         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
415     }
416     validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX);
417     int N = A->getType()->getY();
418     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stbmv,
419                                 TransA, 0, 0, Uplo, Diag, 0, N, K, 0,
420                                 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
421 }
422 
DTBMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,int K,const sp<Allocation> & A,const sp<Allocation> & X,int incX)423 void ScriptIntrinsicBLAS::DTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
424                                 int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
425     // TBMV has the same requirements as TRMV + K >= 0
426     if (K < 0) {
427         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
428     }
429     validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX);
430     int N = A->getType()->getY();
431     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtbmv,
432                                 TransA, 0, 0, Uplo, Diag, 0, N, K, 0,
433                                 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
434 }
435 
CTBMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,int K,const sp<Allocation> & A,const sp<Allocation> & X,int incX)436 void ScriptIntrinsicBLAS::CTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
437                                 int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
438     // TBMV has the same requirements as TRMV + K >= 0
439     if (K < 0) {
440         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
441     }
442     validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
443     int N = A->getType()->getY();
444     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctbmv,
445                                  TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0,
446                                  A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
447 }
448 
ZTBMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,int K,const sp<Allocation> & A,const sp<Allocation> & X,int incX)449 void ScriptIntrinsicBLAS::ZTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
450                                 int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
451     // TBMV has the same requirements as TRMV + K >= 0
452     if (K < 0) {
453         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
454     }
455     validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
456     int N = A->getType()->getY();
457     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztbmv,
458                            TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0,
459                            A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
460 }
461 
STPMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,const sp<Allocation> & Ap,const sp<Allocation> & X,int incX)462 void ScriptIntrinsicBLAS::STPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
463                                 const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) {
464     int N = validateTPMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, Ap, X, incX);
465     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stpmv,
466                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
467                                 Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
468 }
469 
DTPMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,const sp<Allocation> & Ap,const sp<Allocation> & X,int incX)470 void ScriptIntrinsicBLAS::DTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
471                                 const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) {
472     int N = validateTPMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, Ap, X, incX);
473     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtpmv,
474                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
475                                 Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
476 }
477 
CTPMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,const sp<Allocation> & Ap,const sp<Allocation> & X,int incX)478 void ScriptIntrinsicBLAS::CTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
479                                 const sp<Allocation>& Ap,  const sp<Allocation>& X,  int incX) {
480     int N = validateTPMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
481     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctpmv,
482                                  TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
483                                  Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
484 }
485 
ZTPMV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,const sp<Allocation> & Ap,const sp<Allocation> & X,int incX)486 void ScriptIntrinsicBLAS::ZTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
487                                 const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) {
488     int N = validateTPMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
489     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztpmv,
490                            TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
491                            Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
492 }
493 
STRSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,const sp<Allocation> & A,const sp<Allocation> & X,int incX)494 void ScriptIntrinsicBLAS::STRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
495                                 const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
496     // TRSV is the same as TRMV
497     validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX);
498     int N = A->getType()->getY();
499     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strsv,
500                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
501                                 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
502 }
503 
DTRSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,const sp<Allocation> & A,const sp<Allocation> & X,int incX)504 void ScriptIntrinsicBLAS::DTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
505                                 const sp<Allocation>& A,  const sp<Allocation>& X,  int incX) {
506     // TRSV is the same as TRMV
507     validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX);
508     int N = A->getType()->getY();
509     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrsv,
510                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
511                                 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
512 
513 }
514 
CTRSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,const sp<Allocation> & A,const sp<Allocation> & X,int incX)515 void ScriptIntrinsicBLAS::CTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
516                                 const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
517     // TRSV is the same as TRMV
518     validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
519     int N = A->getType()->getY();
520     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrsv,
521                                  TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
522                                  A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
523 
524 }
525 
ZTRSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,const sp<Allocation> & A,const sp<Allocation> & X,int incX)526 void ScriptIntrinsicBLAS::ZTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
527                                 const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
528     // TRSV is the same as TRMV
529     validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
530     int N = A->getType()->getY();
531     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrsv,
532                            TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
533                            A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
534 
535 }
536 
STBSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,int K,const sp<Allocation> & A,const sp<Allocation> & X,int incX)537 void ScriptIntrinsicBLAS::STBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
538                                 int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
539     // TBSV is the same as TRMV + K >= 0
540     validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX);
541     int N = A->getType()->getY();
542     if (K < 0) {
543         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive");
544     }
545     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stbsv,
546                                 TransA, 0, 0, Uplo, Diag, 0, N, K, 0,
547                                 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
548 }
549 
DTBSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,int K,const sp<Allocation> & A,const sp<Allocation> & X,int incX)550 void ScriptIntrinsicBLAS::DTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
551                                 int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
552     // TBSV is the same as TRMV + K >= 0
553     validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX);
554     int N = A->getType()->getY();
555     if (K < 0) {
556         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive");
557     }
558     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtbsv,
559                                 TransA, 0, 0, Uplo, Diag, 0, N, K, 0,
560                                 A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
561 }
562 
CTBSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,int K,const sp<Allocation> & A,const sp<Allocation> & X,int incX)563 void ScriptIntrinsicBLAS::CTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
564                                 int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
565     // TBSV is the same as TRMV + K >= 0
566     validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
567     int N = A->getType()->getY();
568     if (K < 0) {
569         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive");
570     }
571     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctbsv,
572                                  TransA, 0, 0, Uplo, Diag, 0, N, K,
573                                  0, 0, A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
574 }
575 
ZTBSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,int K,const sp<Allocation> & A,const sp<Allocation> & X,int incX)576 void ScriptIntrinsicBLAS::ZTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
577                                 int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
578     // TBSV is the same as TRMV + K >= 0
579     validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
580     int N = A->getType()->getY();
581     if (K < 0) {
582         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive");
583     }
584     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztbsv,
585                            TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0,
586                            A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
587 }
588 
STPSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,const sp<Allocation> & Ap,const sp<Allocation> & X,int incX)589 void ScriptIntrinsicBLAS::STPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
590                                 const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) {
591     // TPSV is same as TPMV
592     int N = validateTPMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, Ap, X, incX);
593     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stpsv,
594                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
595                                 Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
596 }
597 
DTPSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,const sp<Allocation> & Ap,const sp<Allocation> & X,int incX)598 void ScriptIntrinsicBLAS::DTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
599                                 const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) {
600     // TPSV is same as TPMV
601     int N = validateTPMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, Ap, X, incX);
602     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtpsv,
603                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
604                                 Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
605 }
606 
CTPSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,const sp<Allocation> & Ap,const sp<Allocation> & X,int incX)607 void ScriptIntrinsicBLAS::CTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
608                                 const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) {
609     // TPSV is same as TPMV
610     int N = validateTPMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
611     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctpsv,
612                                  TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
613                                  Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
614 }
615 
ZTPSV(RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,const sp<Allocation> & Ap,const sp<Allocation> & X,int incX)616 void ScriptIntrinsicBLAS::ZTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
617                                 const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) {
618     // TPSV is same as TPMV
619     int N = validateTPMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
620     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztpsv,
621                            TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
622                            Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
623 }
624 
625 /**
626  * Level 2, S and D only
627  */
validateSYMV(RS * mRS,const sp<const Element> & e,RsBlasUplo Uplo,const sp<Allocation> & A,const sp<Allocation> & X,const sp<Allocation> & Y,int incX,int incY)628 static int validateSYMV(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo, const sp<Allocation>& A,
629                         const sp<Allocation>& X, const sp<Allocation>& Y, int incX, int incY) {
630     int N = A->getType()->getY();
631     if ((int)A->getType()->getX() != N) {
632         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a square matrix for SYMV");
633     }
634     if (!A->getType()->getElement()->isCompatible(e) ||
635         !X->getType()->getElement()->isCompatible(e) ||
636         !Y->getType()->getElement()->isCompatible(e) ) {
637         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
638     }
639     if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
640         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
641     }
642 
643     if (incX <= 0 || incY <= 0) {
644         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
645     }
646     int expectedXDim = 1 + (N - 1) * incX;
647     if ((int)X->getType()->getX() != expectedXDim) {
648         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYMV");
649     }
650     int expectedYDim = 1 + (N - 1) * incY;
651     if ((int)Y->getType()->getX() != expectedYDim) {
652         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYMV");
653     }
654     return N;
655 }
validateSPMV(RS * mRS,const sp<const Element> & e,RsBlasUplo Uplo,const sp<Allocation> & Ap,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY)656 static int validateSPMV(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo, const sp<Allocation>& Ap,
657                         const sp<Allocation>& X, int incX, const sp<Allocation>& Y, int incY) {
658     if (!Ap->getType()->getElement()->isCompatible(e) ||
659         !X->getType()->getElement()->isCompatible(e) ||
660         !Y->getType()->getElement()->isCompatible(e)) {
661         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
662     }
663     if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
664         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
665     }
666 
667     if (Ap->getType()->getY() > 1) {
668         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1");
669     }
670 
671     int N = sqrt((double)Ap->getType()->getX() * 2);
672     if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) {
673         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap");
674     }
675     if (incX <= 0 || incY <= 0) {
676         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
677     }
678     int expectedXDim = 1 + (N - 1) * incX;
679     if ((int)X->getType()->getX() != expectedXDim) {
680         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPMV");
681     }
682     int expectedYDim = 1 + (N - 1) * incY;
683     if ((int)Y->getType()->getX() != expectedYDim) {
684         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPMV");
685     }
686 
687     return N;
688 }
validateGER(RS * mRS,const sp<const Element> & e,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY,const sp<Allocation> & A)689 static void validateGER(RS* mRS, const sp<const Element>& e, const sp<Allocation>& X, int incX,
690                         const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
691     if (!A->getType()->getElement()->isCompatible(e) ||
692         !X->getType()->getElement()->isCompatible(e) ||
693         !Y->getType()->getElement()->isCompatible(e) ) {
694         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
695     }
696 
697     if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
698         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
699     }
700 
701     int M = A->getType()->getY();
702     int N = A->getType()->getX();
703 
704     if (N < 1 || M < 1) {
705         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "M and N must be 1 or greater for GER");
706     }
707     if (incX <= 0 || incY <= 0) {
708         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
709     }
710     int expectedXDim = 1 + (M - 1) * incX;
711     if ((int)X->getType()->getX() != expectedXDim) {
712         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GER");
713     }
714     int expectedYDim = 1 + (N - 1) * incY;
715     if ((int)Y->getType()->getX() != expectedYDim) {
716         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GER");
717     }
718 
719 
720 }
validateSYR(RS * mRS,const sp<const Element> & e,RsBlasUplo Uplo,const sp<Allocation> & X,int incX,const sp<Allocation> & A)721 static int validateSYR(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo,
722                        const sp<Allocation>& X, int incX, const sp<Allocation>& A) {
723     if (!A->getType()->getElement()->isCompatible(e) ||
724         !X->getType()->getElement()->isCompatible(e)) {
725         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
726     }
727 
728     int N = A->getType()->getX();
729 
730     if (X->getType()->getY() > 1) {
731         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
732     }
733     if (N != (int)A->getType()->getY()) {
734         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a symmetric matrix");
735     }
736     if (incX <= 0) {
737         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
738     }
739     int expectedXDim = 1 + (N - 1) * incX;
740     if ((int)X->getType()->getX() != expectedXDim) {
741         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYR");
742     }
743     return N;
744 }
validateSPR(RS * mRS,const sp<const Element> & e,RsBlasUplo Uplo,const sp<Allocation> & X,int incX,const sp<Allocation> & Ap)745 static int validateSPR(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo,
746                        const sp<Allocation>& X, int incX, const sp<Allocation>& Ap) {
747     if (!Ap->getType()->getElement()->isCompatible(e) ||
748         !X->getType()->getElement()->isCompatible(e)) {
749         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
750     }
751     if (X->getType()->getY() > 1) {
752         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
753     }
754 
755     if (Ap->getType()->getY() > 1) {
756         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1");
757     }
758 
759     int N = sqrt((double)Ap->getType()->getX() * 2);
760     if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) {
761         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap");
762     }
763     if (incX <= 0) {
764         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
765     }
766     int expectedXDim = 1 + (N - 1) * incX;
767     if ((int)X->getType()->getX() != expectedXDim) {
768         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPR");
769     }
770 
771     return N;
772 }
773 
validateSYR2(RS * mRS,const sp<const Element> & e,RsBlasUplo Uplo,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY,const sp<Allocation> & A)774 static int validateSYR2(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo, const sp<Allocation>& X,
775                         int incX, const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
776     if (!A->getType()->getElement()->isCompatible(e) ||
777         !X->getType()->getElement()->isCompatible(e) ||
778         !Y->getType()->getElement()->isCompatible(e)) {
779         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
780     }
781 
782     if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
783         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
784     }
785 
786     int N = A->getType()->getX();
787 
788     if (N != (int)A->getType()->getY()) {
789         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a symmetric matrix");
790     }
791     if (incX <= 0 || incY <= 0) {
792         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
793     }
794     int expectedXDim = 1 + (N - 1) * incX;
795     int expectedYDim = 1 + (N - 1) * incY;
796     if ((int)X->getType()->getX() != expectedXDim || (int)Y->getType()->getX() != expectedYDim) {
797         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYR");
798     }
799     return N;
800 
801 }
validateSPR2(RS * mRS,const sp<const Element> & e,RsBlasUplo Uplo,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY,const sp<Allocation> & Ap)802 static int validateSPR2(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo, const sp<Allocation>& X,
803                         int incX, const sp<Allocation>& Y, int incY, const sp<Allocation>& Ap) {
804     if (!Ap->getType()->getElement()->isCompatible(e) ||
805         !X->getType()->getElement()->isCompatible(e) ||
806         !Y->getType()->getElement()->isCompatible(e)) {
807         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
808     }
809     if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
810         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
811     }
812 
813     if (Ap->getType()->getY() > 1) {
814         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1");
815     }
816 
817     int N = sqrt((double)Ap->getType()->getX() * 2);
818     if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) {
819         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap");
820     }
821     if (incX <= 0 || incY <= 0) {
822         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
823     }
824     int expectedXDim = 1 + (N - 1) * incX;
825     int expectedYDim = 1 + (N - 1) * incY;
826     if ((int)X->getType()->getX() != expectedXDim || (int)Y->getType()->getX() != expectedYDim) {
827         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPR2");
828     }
829 
830     return N;
831 }
832 
SSYMV(RsBlasUplo Uplo,float alpha,const sp<Allocation> & A,const sp<Allocation> & X,int incX,float beta,const sp<Allocation> & Y,int incY)833 void ScriptIntrinsicBLAS::SSYMV(RsBlasUplo Uplo, float alpha, const sp<Allocation>& A, const sp<Allocation>& X,
834                                 int incX, float beta, const sp<Allocation>& Y, int incY) {
835     int N = validateSYMV(mRS, Element::F32(mRS), Uplo, A, X, Y, incX, incY);
836     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssymv,
837                                 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
838                                 A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
839 }
840 
SSBMV(RsBlasUplo Uplo,int K,float alpha,const sp<Allocation> & A,const sp<Allocation> & X,int incX,float beta,const sp<Allocation> & Y,int incY)841 void ScriptIntrinsicBLAS::SSBMV(RsBlasUplo Uplo, int K, float alpha, const sp<Allocation>& A, const sp<Allocation>& X,
842                                 int incX, float beta, const sp<Allocation>& Y, int incY) {
843     // SBMV is the same as SYMV + K >= 0
844     if (K < 0) {
845         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
846     }
847     int N = validateSYMV(mRS, Element::F32(mRS), Uplo, A, X, Y, incX, incY);
848     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssbmv,
849                                 0, 0, 0, Uplo, 0, 0, N, K, alpha,
850                                 A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
851 }
852 
SSPMV(RsBlasUplo Uplo,float alpha,const sp<Allocation> & Ap,const sp<Allocation> & X,int incX,float beta,const sp<Allocation> & Y,int incY)853 void ScriptIntrinsicBLAS::SSPMV(RsBlasUplo Uplo, float alpha, const sp<Allocation>& Ap, const sp<Allocation>& X,
854                                 int incX, float beta, const sp<Allocation>& Y, int incY) {
855     int N = validateSPMV(mRS, Element::F32(mRS), Uplo, Ap, X, incX, Y, incY);
856     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspmv,
857                                 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
858                                 Ap->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
859 }
860 
SGER(float alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY,const sp<Allocation> & A)861 void ScriptIntrinsicBLAS::SGER(float alpha, const sp<Allocation>& X, int incX,
862                                const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
863     int M = A->getType()->getY();
864     int N = A->getType()->getX();
865     validateGER(mRS, Element::F32(mRS), X, incX, Y, incY, A);
866     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sger,
867                                 0, 0, 0, 0, 0, M, N, 0, alpha,
868                                 X->getID(), Y->getID(), 0.f, A->getID(), incX, incY, 0, 0);
869 }
870 
SSYR(RsBlasUplo Uplo,float alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & A)871 void ScriptIntrinsicBLAS::SSYR(RsBlasUplo Uplo, float alpha, const sp<Allocation>& X,
872                                int incX, const sp<Allocation>& A) {
873     int N = validateSYR(mRS, Element::F32(mRS), Uplo, X, incX, A);
874     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr,
875                                 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
876                                 X->getID(), A->getID(), 0.f, 0, incX, 0, 0, 0);
877 }
878 
SSPR(RsBlasUplo Uplo,float alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & Ap)879 void ScriptIntrinsicBLAS::SSPR(RsBlasUplo Uplo, float alpha, const sp<Allocation>& X,
880                                int incX, const sp<Allocation>& Ap) {
881     int N = validateSPR(mRS, Element::F32(mRS), Uplo, X, incX, Ap);
882     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspr,
883                                 0, 0, 0, Uplo, 0, 0, N, 0,
884                                 alpha, X->getID(), Ap->getID(), 0.f, 0, incX, 0, 0, 0);
885 }
886 
SSYR2(RsBlasUplo Uplo,float alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY,const sp<Allocation> & A)887 void ScriptIntrinsicBLAS::SSYR2(RsBlasUplo Uplo, float alpha, const sp<Allocation>& X, int incX,
888                                 const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
889     int N = validateSYR2(mRS, Element::F32(mRS), Uplo, X, incX, Y, incY, A);
890     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr2,
891                                 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
892                                 X->getID(), Y->getID(), 0, A->getID(), incX, incY, 0, 0);
893 }
894 
SSPR2(RsBlasUplo Uplo,float alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY,const sp<Allocation> & Ap)895 void ScriptIntrinsicBLAS::SSPR2(RsBlasUplo Uplo, float alpha, const sp<Allocation>& X, int incX,
896                                 const sp<Allocation>& Y, int incY, const sp<Allocation>& Ap) {
897     int N = validateSPR2(mRS, Element::F32(mRS), Uplo, X, incX, Y, incY, Ap);
898     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspr2,
899                                 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
900                                 X->getID(), Y->getID(), 0, Ap->getID(), incX, incY, 0, 0);
901 }
902 
DSYMV(RsBlasUplo Uplo,double alpha,const sp<Allocation> & A,const sp<Allocation> & X,int incX,double beta,const sp<Allocation> & Y,int incY)903 void ScriptIntrinsicBLAS::DSYMV(RsBlasUplo Uplo, double alpha, const sp<Allocation>& A, const sp<Allocation>& X,
904                                 int incX, double beta, const sp<Allocation>& Y, int incY) {
905     int N = validateSYMV(mRS, Element::F64(mRS), Uplo, A, X, Y, incX, incY);
906     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsymv,
907                                 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
908                                 A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
909 }
910 
DSBMV(RsBlasUplo Uplo,int K,double alpha,const sp<Allocation> & A,const sp<Allocation> & X,int incX,double beta,const sp<Allocation> & Y,int incY)911 void ScriptIntrinsicBLAS::DSBMV(RsBlasUplo Uplo, int K, double alpha, const sp<Allocation>& A, const sp<Allocation>& X,
912                                 int incX, double beta, const sp<Allocation>& Y, int incY) {
913     // SBMV is the same as SYMV + K >= 0
914     if (K < 0) {
915         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
916     }
917     int N = validateSYMV(mRS, Element::F64(mRS), Uplo, A, X, Y, incX, incY);
918     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsbmv,
919                                 0, 0, 0, Uplo, 0, 0, N, K, alpha,
920                                 A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
921 }
922 
DSPMV(RsBlasUplo Uplo,double alpha,const sp<Allocation> & Ap,const sp<Allocation> & X,int incX,double beta,const sp<Allocation> & Y,int incY)923 void ScriptIntrinsicBLAS::DSPMV(RsBlasUplo Uplo, double alpha, const sp<Allocation>& Ap, const sp<Allocation>& X,
924                                 int incX, double beta, const sp<Allocation>& Y, int incY) {
925     int N = validateSPMV(mRS, Element::F64(mRS), Uplo, Ap, X, incX, Y, incY);
926     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspmv,
927                                 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
928                                 Ap->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
929 }
930 
DGER(double alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY,const sp<Allocation> & A)931 void ScriptIntrinsicBLAS::DGER(double alpha, const sp<Allocation>& X, int incX, const sp<Allocation>& Y,
932                                int incY, const sp<Allocation>& A) {
933     int M = A->getType()->getY();
934     int N = A->getType()->getX();
935     validateGER(mRS, Element::F64(mRS), X, incX, Y, incY, A);
936     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dger,
937                                 0, 0, 0, 0, 0, M, N, 0, alpha,
938                                 X->getID(), Y->getID(), 0.f, A->getID(), incX, incY, 0, 0);
939 }
940 
DSYR(RsBlasUplo Uplo,double alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & A)941 void ScriptIntrinsicBLAS::DSYR(RsBlasUplo Uplo, double alpha, const sp<Allocation>& X,
942                                int incX, const sp<Allocation>& A) {
943     int N = validateSYR(mRS, Element::F64(mRS), Uplo, X, incX, A);
944     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr,
945                                 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
946                                 X->getID(), A->getID(), 0.f, 0, incX, 0, 0, 0);
947 }
948 
DSPR(RsBlasUplo Uplo,double alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & Ap)949 void ScriptIntrinsicBLAS::DSPR(RsBlasUplo Uplo, double alpha, const sp<Allocation>& X,
950                                int incX, const sp<Allocation>& Ap) {
951     int N = validateSPR(mRS, Element::F64(mRS), Uplo, X, incX, Ap);
952     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspr,
953                                 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
954                                 X->getID(), Ap->getID(), 0.f, 0, incX, 0, 0, 0);
955 }
956 
DSYR2(RsBlasUplo Uplo,double alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY,const sp<Allocation> & A)957 void ScriptIntrinsicBLAS::DSYR2(RsBlasUplo Uplo, double alpha, const sp<Allocation>& X, int incX,
958                                 const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
959     int N = validateSYR2(mRS, Element::F64(mRS), Uplo, X, incX, Y, incY, A);
960     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr2,
961                                 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
962                                 X->getID(), Y->getID(), 0, A->getID(), incX, incY, 0, 0);
963 }
964 
DSPR2(RsBlasUplo Uplo,double alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY,const sp<Allocation> & Ap)965 void ScriptIntrinsicBLAS::DSPR2(RsBlasUplo Uplo, double alpha, const sp<Allocation>& X, int incX,
966                                 const sp<Allocation>& Y, int incY, const sp<Allocation>& Ap) {
967     int N = validateSPR2(mRS, Element::F64(mRS), Uplo, X, incX, Y, incY, Ap);
968     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspr2,
969                                 0, 0, 0, Uplo, 0, 0, N, 0, alpha,
970                                 X->getID(), Y->getID(), 0, Ap->getID(), incX, incY, 0, 0);
971 }
972 
973 
974 /**
975  * Level 2, C and Z only
976  */
977 
validateGERU(RS * mRS,const sp<const Element> & e,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY,const sp<Allocation> & A)978 static void validateGERU(RS* mRS, const sp<const Element>& e, const sp<Allocation>& X, int incX,
979                          const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
980     if (!A->getType()->getElement()->isCompatible(e) ||
981         !X->getType()->getElement()->isCompatible(e) ||
982         !Y->getType()->getElement()->isCompatible(e)) {
983         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
984     }
985     if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
986         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
987     }
988 
989     int M = A->getType()->getY();
990     int N = A->getType()->getX();
991     if (incX <= 0 || incY <= 0) {
992         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
993     }
994     int expectedXDim = 1 + (M - 1) * incX;
995     if ((int)X->getType()->getX() != expectedXDim) {
996         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GERU");
997     }
998     int expectedYDim = 1 + (N - 1) * incY;
999     if ((int)Y->getType()->getX() != expectedYDim) {
1000         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GERU");
1001     }
1002 
1003 }
1004 
CHEMV(RsBlasUplo Uplo,Float2 alpha,const sp<Allocation> & A,const sp<Allocation> & X,int incX,Float2 beta,const sp<Allocation> & Y,int incY)1005 void ScriptIntrinsicBLAS::CHEMV(RsBlasUplo Uplo, Float2 alpha, const sp<Allocation>& A,
1006                                 const sp<Allocation>& X, int incX, Float2 beta, const sp<Allocation>& Y, int incY) {
1007     // HEMV is the same as SYR2 validation-wise
1008     int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A);
1009     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chemv,
1010                                  0, 0, 0, Uplo, 0, 0, N, 0,
1011                                  alpha.x, alpha.y, A->getID(), X->getID(),
1012                                  beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1013 }
1014 
CHBMV(RsBlasUplo Uplo,int K,Float2 alpha,const sp<Allocation> & A,const sp<Allocation> & X,int incX,Float2 beta,const sp<Allocation> & Y,int incY)1015 void ScriptIntrinsicBLAS::CHBMV(RsBlasUplo Uplo, int K, Float2 alpha, const sp<Allocation>& A,
1016                                 const sp<Allocation>& X, int incX, Float2 beta, const sp<Allocation>& Y, int incY) {
1017     // HBMV is the same as SYR2 validation-wise
1018     int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A);
1019     if (K < 0) {
1020         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be 0 or greater for HBMV");
1021     }
1022     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chbmv,
1023                                  0, 0, 0, Uplo, 0, 0, N, K,
1024                                  alpha.x, alpha.y, A->getID(), X->getID(),
1025                                  beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1026 }
1027 
CHPMV(RsBlasUplo Uplo,Float2 alpha,const sp<Allocation> & Ap,const sp<Allocation> & X,int incX,Float2 beta,const sp<Allocation> & Y,int incY)1028 void ScriptIntrinsicBLAS::CHPMV(RsBlasUplo Uplo, Float2 alpha, const sp<Allocation>& Ap,
1029                                 const sp<Allocation>& X, int incX, Float2 beta, const sp<Allocation>& Y, int incY) {
1030     // HPMV is the same as SPR2
1031     int N = validateSPR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, Ap);
1032     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpmv,
1033                                  0, 0, 0, Uplo, 0, 0, N, 0,
1034                                  alpha.x, alpha.y, Ap->getID(), X->getID(),
1035                                  beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1036 }
1037 
CGERU(Float2 alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY,const sp<Allocation> & A)1038 void ScriptIntrinsicBLAS::CGERU(Float2 alpha, const sp<Allocation>& X, int incX,
1039                                 const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
1040     validateGERU(mRS, Element::F32_2(mRS), X, incX, Y, incY, A);
1041     int M = A->getType()->getY();
1042     int N = A->getType()->getX();
1043     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgeru,
1044                                  0, 0, 0, 0, 0, M, N, 0,
1045                                  alpha.x, alpha.y, X->getID(), Y->getID(),
1046                                  0, 0, A->getID(), incX, incY, 0, 0);
1047 }
1048 
CGERC(Float2 alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY,const sp<Allocation> & A)1049 void ScriptIntrinsicBLAS::CGERC(Float2 alpha, const sp<Allocation>& X, int incX,
1050                                 const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
1051     // Same as GERU
1052     validateGERU(mRS, Element::F32_2(mRS), X, incX, Y, incY, A);
1053     int M = A->getType()->getY();
1054     int N = A->getType()->getX();
1055     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgerc,
1056                                  0, 0, 0, 0, 0, M, N, 0,
1057                                  alpha.x, alpha.y, X->getID(), Y->getID(),
1058                                  0, 0, A->getID(), incX, incY, 0, 0);
1059 }
1060 
CHER(RsBlasUplo Uplo,float alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & A)1061 void ScriptIntrinsicBLAS::CHER(RsBlasUplo Uplo, float alpha, const sp<Allocation>& X,
1062                                int incX, const sp<Allocation>& A) {
1063     // Same as SYR
1064     int N = validateSYR(mRS, Element::F32_2(mRS), Uplo, X, incX, A);
1065     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher,
1066                                  0, 0, 0, Uplo, 0, 0, N, 0,
1067                                  alpha, 0, X->getID(), 0,
1068                                  0, 0, A->getID(), incX, 0, 0, 0);
1069 }
1070 
CHPR(RsBlasUplo Uplo,float alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & Ap)1071 void ScriptIntrinsicBLAS::CHPR(RsBlasUplo Uplo, float alpha, const sp<Allocation>& X,
1072                                int incX, const sp<Allocation>& Ap) {
1073     // Equivalent to SPR for validation
1074     int N = validateSPR(mRS, Element::F32_2(mRS), Uplo, X, incX, Ap);
1075     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpr,
1076                                  0, 0, 0, Uplo, 0, 0, N, 0,
1077                                  alpha, 0, X->getID(), 0,
1078                                  0, 0, Ap->getID(), incX, 0, 0, 0);
1079 }
1080 
CHER2(RsBlasUplo Uplo,Float2 alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY,const sp<Allocation> & A)1081 void ScriptIntrinsicBLAS::CHER2(RsBlasUplo Uplo, Float2 alpha, const sp<Allocation>& X, int incX,
1082                                 const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
1083     // Same as SYR2
1084     int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A);
1085     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher2,
1086                                  0, 0, 0, Uplo, 0, 0, N, 0,
1087                                  alpha.x, alpha.y, X->getID(), Y->getID(),
1088                                  0, 0, A->getID(), incX, incY, 0, 0);
1089 }
1090 
CHPR2(RsBlasUplo Uplo,Float2 alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY,const sp<Allocation> & Ap)1091 void ScriptIntrinsicBLAS::CHPR2(RsBlasUplo Uplo, Float2 alpha, const sp<Allocation>& X, int incX,
1092                                 const sp<Allocation>& Y, int incY, const sp<Allocation>& Ap) {
1093     // Same as SPR2
1094     int N = validateSPR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, Ap);
1095     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpr2,
1096                                  0, 0, 0, Uplo, 0, 0, N, 0,
1097                                  alpha.x, alpha.y, X->getID(), Y->getID(),
1098                                  0, 0, Ap->getID(), incX, incY, 0, 0);
1099 }
1100 
ZHEMV(RsBlasUplo Uplo,Double2 alpha,const sp<Allocation> & A,const sp<Allocation> & X,int incX,Double2 beta,const sp<Allocation> & Y,int incY)1101 void ScriptIntrinsicBLAS::ZHEMV(RsBlasUplo Uplo, Double2 alpha, const sp<Allocation>& A,
1102                                 const sp<Allocation>& X, int incX, Double2 beta, const sp<Allocation>& Y, int incY) {
1103     // HEMV is the same as SYR2 validation-wise
1104     int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A);
1105     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhemv,
1106                            0, 0, 0, Uplo, 0, 0, N, 0,
1107                            alpha.x, alpha.y, A->getID(), X->getID(),
1108                            beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1109 }
1110 
ZHBMV(RsBlasUplo Uplo,int K,Double2 alpha,const sp<Allocation> & A,const sp<Allocation> & X,int incX,Double2 beta,const sp<Allocation> & Y,int incY)1111 void ScriptIntrinsicBLAS::ZHBMV(RsBlasUplo Uplo, int K, Double2 alpha, const sp<Allocation>& A, const sp<Allocation>& X,
1112                                 int incX, Double2 beta, const sp<Allocation>& Y, int incY) {
1113     // HBMV is the same as SYR2 validation-wise
1114     int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A);
1115     if (K < 0) {
1116         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be 0 or greater for HBMV");
1117     }
1118     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhbmv,
1119                            0, 0, 0, Uplo, 0, 0, N, K,
1120                            alpha.x, alpha.y, A->getID(), X->getID(),
1121                            beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1122 }
1123 
ZHPMV(RsBlasUplo Uplo,Double2 alpha,const sp<Allocation> & Ap,const sp<Allocation> & X,int incX,Double2 beta,const sp<Allocation> & Y,int incY)1124 void ScriptIntrinsicBLAS::ZHPMV(RsBlasUplo Uplo, Double2 alpha, const sp<Allocation>& Ap, const sp<Allocation>& X,
1125                                 int incX, Double2 beta, const sp<Allocation>& Y, int incY) {
1126     // HPMV is the same as SPR2
1127     int N = validateSPR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, Ap);
1128     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpmv,
1129                            0, 0, 0, Uplo, 0, 0, N, 0,
1130                            alpha.x, alpha.y, Ap->getID(), X->getID(),
1131                            beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
1132 }
1133 
ZGERU(Double2 alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY,const sp<Allocation> & A)1134 void ScriptIntrinsicBLAS::ZGERU(Double2 alpha, const sp<Allocation>& X, int incX,
1135                                 const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
1136     validateGERU(mRS, Element::F64_2(mRS), X, incX, Y, incY, A);
1137     int M = A->getType()->getY();
1138     int N = A->getType()->getX();
1139     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgeru,
1140                            0, 0, 0, 0, 0, M, N, 0,
1141                            alpha.x, alpha.y, X->getID(), Y->getID(),
1142                            0, 0, A->getID(), incX, incY, 0, 0);
1143 }
1144 
ZGERC(Double2 alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY,const sp<Allocation> & A)1145 void ScriptIntrinsicBLAS::ZGERC(Double2 alpha, const sp<Allocation>& X, int incX,
1146                                 const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
1147     // Same as GERU
1148     validateGERU(mRS, Element::F64_2(mRS), X, incX, Y, incY, A);
1149     int M = A->getType()->getY();
1150     int N = A->getType()->getX();
1151     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgerc,
1152                            0, 0, 0, 0, 0, M, N, 0,
1153                            alpha.x, alpha.y, X->getID(), Y->getID(),
1154                            0, 0, A->getID(), incX, incY, 0, 0);
1155 }
1156 
ZHER(RsBlasUplo Uplo,double alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & A)1157 void ScriptIntrinsicBLAS::ZHER(RsBlasUplo Uplo, double alpha, const sp<Allocation>& X,
1158                                int incX, const sp<Allocation>& A) {
1159     // Same as SYR
1160     int N = validateSYR(mRS, Element::F64_2(mRS), Uplo, X, incX, A);
1161     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher,
1162                            0, 0, 0, Uplo, 0, 0, N, 0,
1163                            alpha, 0, X->getID(), 0,
1164                            0, 0, A->getID(), incX, 0, 0, 0);
1165 }
1166 
ZHPR(RsBlasUplo Uplo,double alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & Ap)1167 void ScriptIntrinsicBLAS::ZHPR(RsBlasUplo Uplo, double alpha, const sp<Allocation>& X,
1168                                int incX, const sp<Allocation>& Ap) {
1169     // Equivalent to SPR for validation
1170     int N = validateSPR(mRS, Element::F64_2(mRS), Uplo, X, incX, Ap);
1171     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpr,
1172                            0, 0, 0, Uplo, 0, 0, N, 0,
1173                            alpha, 0, X->getID(), 0,
1174                            0, 0, Ap->getID(), incX, 0, 0, 0);
1175 }
1176 
ZHER2(RsBlasUplo Uplo,Double2 alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY,const sp<Allocation> & A)1177 void ScriptIntrinsicBLAS::ZHER2(RsBlasUplo Uplo, Double2 alpha, const sp<Allocation>& X, int incX,
1178                                 const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
1179     // Same as SYR2
1180     int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A);
1181     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher2,
1182                            0, 0, 0, Uplo, 0, 0, N, 0,
1183                            alpha.x, alpha.y, X->getID(), Y->getID(),
1184                            0, 0, A->getID(), incX, incY, 0, 0);
1185 }
1186 
ZHPR2(RsBlasUplo Uplo,Double2 alpha,const sp<Allocation> & X,int incX,const sp<Allocation> & Y,int incY,const sp<Allocation> & Ap)1187 void ScriptIntrinsicBLAS::ZHPR2(RsBlasUplo Uplo, Double2 alpha, const sp<Allocation>& X, int incX,
1188                                 const sp<Allocation>& Y, int incY, const sp<Allocation>& Ap) {
1189     // Same as SPR2
1190     int N = validateSPR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, Ap);
1191     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpr2,
1192                            0, 0, 0, Uplo, 0, 0, N, 0,
1193                            alpha.x, alpha.y, X->getID(), Y->getID(),
1194                            0, 0, Ap->getID(), incX, incY, 0, 0);
1195 }
1196 
1197 
1198 /**
1199  * Level 3 BLAS
1200  */
1201 
validateL3(RS * mRS,const sp<const Element> & e,int TransA,int TransB,int Side,const sp<Allocation> & A,const sp<Allocation> & B,const sp<Allocation> & C)1202 static void validateL3(RS* mRS, const sp<const Element>& e, int TransA, int TransB, int Side,
1203                        const sp<Allocation>& A, const sp<Allocation>& B, const sp<Allocation>& C) {
1204     int aM = -1, aN = -1, bM = -1, bN = -1, cM = -1, cN = -1;
1205     if ((A != nullptr && !A->getType()->getElement()->isCompatible(e)) ||
1206         (B != nullptr && !B->getType()->getElement()->isCompatible(e)) ||
1207         (C != nullptr && !C->getType()->getElement()->isCompatible(e))) {
1208         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1209     }
1210     if (C == nullptr) {
1211         // Since matrix C is used to store the result, it cannot be null.
1212         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Allocation C cannot be null");
1213     }
1214     cM = C->getType()->getY();
1215     cN = C->getType()->getX();
1216 
1217     if (Side == RsBlasRight) {
1218         if ((A == nullptr && B != nullptr) || (A != nullptr && B == nullptr)) {
1219             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Provided Matrix A without Matrix B, or vice versa");
1220         }
1221         if (B != nullptr) {
1222             bM = A->getType()->getY();
1223             bN = A->getType()->getX();
1224         }
1225         if (A != nullptr) {
1226             aM = B->getType()->getY();
1227             aN = B->getType()->getX();
1228         }
1229     } else {
1230         if (A != nullptr) {
1231             if (TransA == RsBlasTrans || TransA == RsBlasConjTrans) {
1232                 aN = A->getType()->getY();
1233                 aM = A->getType()->getX();
1234             } else {
1235                 aM = A->getType()->getY();
1236                 aN = A->getType()->getX();
1237             }
1238         }
1239         if (B != nullptr) {
1240             if (TransB == RsBlasTrans || TransB == RsBlasConjTrans) {
1241                 bN = B->getType()->getY();
1242                 bM = B->getType()->getX();
1243             } else {
1244                 bM = B->getType()->getY();
1245                 bN = B->getType()->getX();
1246             }
1247         }
1248     }
1249     if (A != nullptr && B != nullptr && C != nullptr) {
1250         if (aN != bM || aM != cM || bN != cN) {
1251             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions");
1252         }
1253     } else if (A != nullptr && C != nullptr) {
1254         // A and C only, for SYRK
1255         if (cM != cN) {
1256             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix C is not symmetric");
1257         }
1258         if (aM != cM) {
1259             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions");
1260         }
1261     } else if (A != nullptr && B != nullptr) {
1262         // A and B only
1263         if (aN != bM) {
1264             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions");
1265         }
1266     }
1267 
1268 }
1269 
SGEMM(RsBlasTranspose TransA,RsBlasTranspose TransB,float alpha,const sp<Allocation> & A,const sp<Allocation> & B,float beta,const sp<Allocation> & C)1270 void ScriptIntrinsicBLAS::SGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, float alpha,
1271                                 const sp<Allocation>& A, const sp<Allocation>& B, float beta, const sp<Allocation>& C) {
1272     validateL3(mRS, Element::F32(mRS), TransA, TransB, 0, A, B, C);
1273 
1274     int M = -1, N = -1, K = -1;
1275     if (TransA != RsBlasNoTrans) {
1276         M = A->getType()->getX();
1277         K = A->getType()->getY();
1278     } else {
1279         M = A->getType()->getY();
1280         K = A->getType()->getX();
1281     }
1282     if (TransB != RsBlasNoTrans) {
1283         N = B->getType()->getY();
1284     } else {
1285         N = B->getType()->getX();
1286     }
1287     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgemm,
1288                                 TransA, TransB, 0, 0, 0, M, N, K,
1289                                 alpha, A->getID(), B->getID(),
1290                                 beta, C->getID(), 0, 0, 0, 0);
1291 }
1292 
DGEMM(RsBlasTranspose TransA,RsBlasTranspose TransB,double alpha,const sp<Allocation> & A,const sp<Allocation> & B,double beta,const sp<Allocation> & C)1293 void ScriptIntrinsicBLAS::DGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, double alpha,
1294                                 const sp<Allocation>& A, const sp<Allocation>& B, double beta, const sp<Allocation>& C) {
1295     validateL3(mRS, Element::F64(mRS), TransA, TransB, 0, A, B, C);
1296     int M = -1, N = -1, K = -1;
1297     if (TransA != RsBlasNoTrans) {
1298         M = A->getType()->getX();
1299         K = A->getType()->getY();
1300     } else {
1301         M = A->getType()->getY();
1302         K = A->getType()->getX();
1303     }
1304     if (TransB != RsBlasNoTrans) {
1305         N = B->getType()->getY();
1306     } else {
1307         N = B->getType()->getX();
1308     }
1309     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgemm,
1310                                 TransA, TransB, 0, 0, 0, M, N, K,
1311                                 alpha, A->getID(), B->getID(),
1312                                 beta, C->getID(), 0, 0, 0, 0);
1313 }
1314 
CGEMM(RsBlasTranspose TransA,RsBlasTranspose TransB,Float2 alpha,const sp<Allocation> & A,const sp<Allocation> & B,Float2 beta,const sp<Allocation> & C)1315 void ScriptIntrinsicBLAS::CGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, Float2 alpha,
1316                                 const sp<Allocation>& A, const sp<Allocation>& B, Float2 beta, const sp<Allocation>& C) {
1317     validateL3(mRS, Element::F32_2(mRS), TransA, TransB, 0, A, B, C);
1318     int M = -1, N = -1, K = -1;
1319     if (TransA != RsBlasNoTrans) {
1320         M = A->getType()->getX();
1321         K = A->getType()->getY();
1322     } else {
1323         M = A->getType()->getY();
1324         K = A->getType()->getX();
1325     }
1326     if (TransB != RsBlasNoTrans) {
1327         N = B->getType()->getY();
1328     } else {
1329         N = B->getType()->getX();
1330     }
1331     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgemm,
1332                                  TransA, TransB, 0, 0, 0, M, N, K,
1333                                  alpha.x, alpha.y, A->getID(), B->getID(),
1334                                  beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1335 }
1336 
ZGEMM(RsBlasTranspose TransA,RsBlasTranspose TransB,Double2 alpha,const sp<Allocation> & A,const sp<Allocation> & B,Double2 beta,const sp<Allocation> & C)1337 void ScriptIntrinsicBLAS::ZGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, Double2 alpha,
1338                                 const sp<Allocation>& A, const sp<Allocation>& B, Double2 beta, const sp<Allocation>& C) {
1339     validateL3(mRS, Element::F64_2(mRS), TransA, TransB, 0, A, B, C);
1340     int M = -1, N = -1, K = -1;
1341     if (TransA != RsBlasNoTrans) {
1342         M = A->getType()->getX();
1343         K = A->getType()->getY();
1344     } else {
1345         M = A->getType()->getY();
1346         K = A->getType()->getX();
1347     }
1348     if (TransB != RsBlasNoTrans) {
1349         N = B->getType()->getY();
1350     } else {
1351         N = B->getType()->getX();
1352     }
1353     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgemm,
1354                            TransA, TransB, 0, 0, 0, M, N, K,
1355                            alpha.x, alpha.y, A->getID(), B->getID(),
1356                            beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1357 }
1358 
SSYMM(RsBlasSide Side,RsBlasUplo Uplo,float alpha,const sp<Allocation> & A,const sp<Allocation> & B,float beta,const sp<Allocation> & C)1359 void ScriptIntrinsicBLAS::SSYMM(RsBlasSide Side, RsBlasUplo Uplo, float alpha,
1360                                 const sp<Allocation>& A, const sp<Allocation>& B, float beta, const sp<Allocation>& C) {
1361     //For SYMM, Matrix A should be symmetric
1362     if (A->getType()->getX() != A->getType()->getY()) {
1363         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric");
1364     }
1365     validateL3(mRS, Element::F32(mRS), 0, 0, Side, A, B, C);
1366     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssymm,
1367                                 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0,
1368                                 alpha, A->getID(), B->getID(),
1369                                 beta, C->getID(), 0, 0, 0, 0);
1370 }
1371 
DSYMM(RsBlasSide Side,RsBlasUplo Uplo,double alpha,const sp<Allocation> & A,const sp<Allocation> & B,double beta,const sp<Allocation> & C)1372 void ScriptIntrinsicBLAS::DSYMM(RsBlasSide Side, RsBlasUplo Uplo, double alpha,
1373                                 const sp<Allocation>& A, const sp<Allocation>& B, double beta, const sp<Allocation>& C) {
1374     if (A->getType()->getX() != A->getType()->getY()) {
1375         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric");
1376     }
1377     validateL3(mRS, Element::F64(mRS), 0, 0, Side, A, B, C);
1378     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsymm,
1379                                 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0,
1380                                 alpha, A->getID(), B->getID(),
1381                                 beta, C->getID(), 0, 0, 0, 0);
1382 }
1383 
CSYMM(RsBlasSide Side,RsBlasUplo Uplo,Float2 alpha,const sp<Allocation> & A,const sp<Allocation> & B,Float2 beta,const sp<Allocation> & C)1384 void ScriptIntrinsicBLAS::CSYMM(RsBlasSide Side, RsBlasUplo Uplo, Float2 alpha,
1385                                 const sp<Allocation>& A, const sp<Allocation>& B, Float2 beta, const sp<Allocation>& C) {
1386     if (A->getType()->getX() != A->getType()->getY()) {
1387         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric");
1388     }
1389     validateL3(mRS, Element::F32_2(mRS), 0, 0, Side, A, B, C);
1390     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csymm,
1391                                  0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0,
1392                                  alpha.x, alpha.y, A->getID(), B->getID(),
1393                                  beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1394 }
1395 
ZSYMM(RsBlasSide Side,RsBlasUplo Uplo,Double2 alpha,const sp<Allocation> & A,const sp<Allocation> & B,Double2 beta,const sp<Allocation> & C)1396 void ScriptIntrinsicBLAS::ZSYMM(RsBlasSide Side, RsBlasUplo Uplo, Double2 alpha,
1397                                 const sp<Allocation>& A, const sp<Allocation>& B, Double2 beta, const sp<Allocation>& C) {
1398     if (A->getType()->getX() != A->getType()->getY()) {
1399         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric");
1400     }
1401     validateL3(mRS, Element::F64_2(mRS), 0, 0, Side, A, B, C);
1402     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsymm,
1403                            0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0,
1404                            alpha.x, alpha.y, A->getID(), B->getID(),
1405                            beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1406 }
1407 
SSYRK(RsBlasUplo Uplo,RsBlasTranspose Trans,float alpha,const sp<Allocation> & A,float beta,const sp<Allocation> & C)1408 void ScriptIntrinsicBLAS::SSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha,
1409                                 const sp<Allocation>& A, float beta, const sp<Allocation>& C) {
1410     validateL3(mRS, Element::F32(mRS), Trans, 0, 0, A, nullptr, C);
1411     int K = -1;
1412     if (Trans != RsBlasNoTrans) {
1413         K = A->getType()->getY();
1414     } else {
1415         K = A->getType()->getX();
1416     }
1417     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyrk,
1418                                 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1419                                 alpha, A->getID(), 0,
1420                                 beta, C->getID(), 0, 0, 0, 0);
1421 }
1422 
DSYRK(RsBlasUplo Uplo,RsBlasTranspose Trans,double alpha,const sp<Allocation> & A,double beta,const sp<Allocation> & C)1423 void ScriptIntrinsicBLAS::DSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha,
1424                                 const sp<Allocation>& A, double beta, const sp<Allocation>& C) {
1425     validateL3(mRS, Element::F64(mRS), Trans, 0, 0, A, nullptr, C);
1426     int K = -1;
1427     if (Trans != RsBlasNoTrans) {
1428         K = A->getType()->getY();
1429     } else {
1430         K = A->getType()->getX();
1431     }
1432     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyrk,
1433                                 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1434                                 alpha, A->getID(), 0,
1435                                 beta, C->getID(), 0, 0, 0, 0);
1436 }
1437 
CSYRK(RsBlasUplo Uplo,RsBlasTranspose Trans,Float2 alpha,const sp<Allocation> & A,Float2 beta,const sp<Allocation> & C)1438 void ScriptIntrinsicBLAS::CSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha,
1439                                 const sp<Allocation>& A, Float2 beta, const sp<Allocation>& C) {
1440     validateL3(mRS, Element::F32_2(mRS), Trans, 0, 0, A, nullptr, C);
1441     int K = -1;
1442     if (Trans != RsBlasNoTrans) {
1443         K = A->getType()->getY();
1444     } else {
1445         K = A->getType()->getX();
1446     }
1447     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csyrk,
1448                                  Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1449                                  alpha.x, alpha.y, A->getID(), 0,
1450                                  beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1451 }
1452 
ZSYRK(RsBlasUplo Uplo,RsBlasTranspose Trans,Double2 alpha,const sp<Allocation> & A,Double2 beta,const sp<Allocation> & C)1453 void ScriptIntrinsicBLAS::ZSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha,
1454                                 const sp<Allocation>& A, Double2 beta, const sp<Allocation>& C) {
1455     validateL3(mRS, Element::F64_2(mRS), Trans, 0, 0, A, nullptr, C);
1456     int K = -1;
1457     if (Trans != RsBlasNoTrans) {
1458         K = A->getType()->getY();
1459     } else {
1460         K = A->getType()->getX();
1461     }
1462     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsyrk,
1463                            Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1464                            alpha.x, alpha.y, A->getID(), 0,
1465                            beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1466 }
1467 
validateSYR2K(RS * mRS,const sp<const Element> & e,RsBlasTranspose Trans,const sp<Allocation> & A,const sp<Allocation> & B,const sp<Allocation> & C)1468 static void validateSYR2K(RS* mRS, const sp<const Element>& e, RsBlasTranspose Trans,
1469                           const sp<Allocation>& A, const sp<Allocation>& B, const sp<Allocation>& C) {
1470     if (!A->getType()->getElement()->isCompatible(e) ||
1471         !B->getType()->getElement()->isCompatible(e) ||
1472         !C->getType()->getElement()->isCompatible(e)) {
1473         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1474     }
1475     int Cdim = -1;
1476     // A is n x k if no transpose, k x n if transpose
1477     // C is n x n
1478     if (Trans == RsBlasTrans) {
1479         // check columns versus C
1480         Cdim = A->getType()->getX();
1481     } else {
1482         // check rows versus C
1483         Cdim = A->getType()->getY();
1484     }
1485     if ((int)C->getType()->getX() != Cdim || (int)C->getType()->getY() != Cdim) {
1486         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid symmetric matrix in SYR2K");
1487     }
1488     // A dims == B dims
1489     if (A->getType()->getX() != B->getType()->getX() || A->getType()->getY() != B->getType()->getY()) {
1490         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid A and B in SYR2K");
1491     }
1492 }
1493 
SSYR2K(RsBlasUplo Uplo,RsBlasTranspose Trans,float alpha,const sp<Allocation> & A,const sp<Allocation> & B,float beta,const sp<Allocation> & C)1494 void ScriptIntrinsicBLAS::SSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha,
1495                                  const sp<Allocation>& A, const sp<Allocation>& B, float beta, const sp<Allocation>& C) {
1496     validateSYR2K(mRS, Element::F32(mRS), Trans, A, B, C);
1497     int K = -1;
1498     if (Trans != RsBlasNoTrans) {
1499         K = A->getType()->getY();
1500     } else {
1501         K = A->getType()->getX();
1502     }
1503     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr2k,
1504                                 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1505                                 alpha, A->getID(), B->getID(),
1506                                 beta, C->getID(), 0, 0, 0, 0);
1507 }
1508 
DSYR2K(RsBlasUplo Uplo,RsBlasTranspose Trans,double alpha,const sp<Allocation> & A,const sp<Allocation> & B,double beta,const sp<Allocation> & C)1509 void ScriptIntrinsicBLAS::DSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha,
1510                                  const sp<Allocation>& A, const sp<Allocation>& B, double beta, const sp<Allocation>& C) {
1511     validateSYR2K(mRS, Element::F64(mRS), Trans, A, B, C);
1512     int K = -1;
1513     if (Trans != RsBlasNoTrans) {
1514         K = A->getType()->getY();
1515     } else {
1516         K = A->getType()->getX();
1517     }
1518     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr2k,
1519                                 Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1520                                 alpha, A->getID(), B->getID(),
1521                                 beta, C->getID(), 0, 0, 0, 0);
1522 }
1523 
CSYR2K(RsBlasUplo Uplo,RsBlasTranspose Trans,Float2 alpha,const sp<Allocation> & A,const sp<Allocation> & B,Float2 beta,const sp<Allocation> & C)1524 void ScriptIntrinsicBLAS::CSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha,
1525                                  const sp<Allocation>& A, const sp<Allocation>& B, Float2 beta, const sp<Allocation>& C) {
1526     validateSYR2K(mRS, Element::F32_2(mRS), Trans, A, B, C);
1527     int K = -1;
1528     if (Trans != RsBlasNoTrans) {
1529         K = A->getType()->getY();
1530     } else {
1531         K = A->getType()->getX();
1532     }
1533     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csyr2k,
1534                                  Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1535                                  alpha.x, alpha.y, A->getID(), B->getID(),
1536                                  beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1537 }
1538 
ZSYR2K(RsBlasUplo Uplo,RsBlasTranspose Trans,Double2 alpha,const sp<Allocation> & A,const sp<Allocation> & B,Double2 beta,const sp<Allocation> & C)1539 void ScriptIntrinsicBLAS::ZSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha,
1540                                  const sp<Allocation>& A, const sp<Allocation>& B, Double2 beta, const sp<Allocation>& C) {
1541     validateSYR2K(mRS, Element::F64_2(mRS), Trans, A, B, C);
1542     int K = -1;
1543     if (Trans != RsBlasNoTrans) {
1544         K = A->getType()->getY();
1545     } else {
1546         K = A->getType()->getX();
1547     }
1548     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsyr2k,
1549                            Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K,
1550                            alpha.x, alpha.y, A->getID(), B->getID(),
1551                            beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1552 }
1553 
validateTRMM(RS * mRS,const sp<const Element> & e,RsBlasSide Side,RsBlasTranspose TransA,const sp<Allocation> & A,const sp<Allocation> & B)1554 static void validateTRMM(RS* mRS, const sp<const Element>& e, RsBlasSide Side, RsBlasTranspose TransA,
1555                          const sp<Allocation>& A, const sp<Allocation>& B) {
1556     int aM = -1, aN = -1, bM = -1, bN = -1;
1557     if (!A->getType()->getElement()->isCompatible(e) ||
1558         !B->getType()->getElement()->isCompatible(e)) {
1559         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1560     }
1561 
1562     aM = A->getType()->getY();
1563     aN = A->getType()->getX();
1564     if (aM != aN) {
1565         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with a non-symmetric matrix A");
1566     }
1567 
1568     bM = B->getType()->getY();
1569     bN = B->getType()->getX();
1570     if (Side == RsBlasLeft) {
1571         if (aN != bM) {
1572             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with invalid matrices");
1573         }
1574     } else {
1575         if (bN != aM) {
1576             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with invalid matrices");
1577         }
1578     }
1579 }
1580 
STRMM(RsBlasSide Side,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,float alpha,const sp<Allocation> & A,const sp<Allocation> & B)1581 void ScriptIntrinsicBLAS::STRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1582                                 float alpha, const sp<Allocation>& A, const sp<Allocation>& B) {
1583     validateTRMM(mRS, Element::F32(mRS), Side, TransA, A, B);
1584     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strmm,
1585                                 TransA, 0, Side, Uplo, Diag,\
1586                                 B->getType()->getY(), B->getType()->getX(), 0,
1587                                 alpha, A->getID(), B->getID(), 0.f, 0, 0, 0, 0, 0);
1588 }
1589 
DTRMM(RsBlasSide Side,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,double alpha,const sp<Allocation> & A,const sp<Allocation> & B)1590 void ScriptIntrinsicBLAS::DTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1591                                 double alpha, const sp<Allocation>& A, const sp<Allocation>& B) {
1592     validateTRMM(mRS, Element::F64(mRS), Side, TransA, A, B);
1593     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrmm,
1594                                 TransA, 0, Side, Uplo, Diag,
1595                                 B->getType()->getY(), B->getType()->getX(), 0,
1596                                 alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0);
1597 }
1598 
CTRMM(RsBlasSide Side,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,Float2 alpha,const sp<Allocation> & A,const sp<Allocation> & B)1599 void ScriptIntrinsicBLAS::CTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1600                                 Float2 alpha, const sp<Allocation>& A, const sp<Allocation>& B) {
1601     validateTRMM(mRS, Element::F32_2(mRS), Side, TransA, A, B);
1602     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrmm,
1603                                  TransA, 0, Side, Uplo, Diag,
1604                                  B->getType()->getY(), B->getType()->getX(), 0,
1605                                  alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0);
1606 }
1607 
ZTRMM(RsBlasSide Side,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,Double2 alpha,const sp<Allocation> & A,const sp<Allocation> & B)1608 void ScriptIntrinsicBLAS::ZTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1609                                 Double2 alpha, const sp<Allocation>& A, const sp<Allocation>& B) {
1610     validateTRMM(mRS, Element::F64_2(mRS), Side, TransA, A, B);
1611     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrmm,
1612                            TransA, 0, Side, Uplo, Diag,
1613                            B->getType()->getY(), B->getType()->getX(), 0,
1614                            alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0);
1615 }
1616 
validateTRSM(RS * mRS,const sp<const Element> & e,RsBlasSide Side,RsBlasTranspose TransA,const sp<Allocation> & A,const sp<Allocation> & B)1617 static void validateTRSM(RS* mRS, const sp<const Element>& e, RsBlasSide Side, RsBlasTranspose TransA,
1618                          const sp<Allocation>& A, const sp<Allocation>& B) {
1619     int adim = -1, bM = -1, bN = -1;
1620     if (!A->getType()->getElement()->isCompatible(e) ||
1621         !B->getType()->getElement()->isCompatible(e)) {
1622         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1623     }
1624     adim = A->getType()->getX();
1625     if (adim != (int)A->getType()->getY()) {
1626         // This may be unnecessary, the restriction could potentially be relaxed.
1627         // Allocation A needs to contain at least that symmetric matrix but could theoretically
1628         // be larger for now we assume adapters are sufficient, will reevaluate in the future.
1629         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with a non-symmetric matrix A");
1630     }
1631     bM = B->getType()->getY();
1632     bN = B->getType()->getX();
1633     if (Side == RsBlasLeft) {
1634         // A is M*M
1635         if (adim != bM) {
1636             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with invalid matrix dimensions");
1637         }
1638     } else {
1639         // A is N*N
1640         if (adim != bN) {
1641             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with invalid matrix dimensions");
1642         }
1643     }
1644 }
1645 
STRSM(RsBlasSide Side,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,float alpha,const sp<Allocation> & A,const sp<Allocation> & B)1646 void ScriptIntrinsicBLAS::STRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1647                                 float alpha, const sp<Allocation>& A, const sp<Allocation>& B) {
1648     validateTRSM(mRS, Element::F32(mRS), Side, TransA, A, B);
1649     nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strsm,
1650                                 TransA, 0, Side, Uplo, Diag,
1651                                 B->getType()->getY(), B->getType()->getX(), 0,
1652                                 alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0);
1653 }
1654 
DTRSM(RsBlasSide Side,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,double alpha,const sp<Allocation> & A,const sp<Allocation> & B)1655 void ScriptIntrinsicBLAS::DTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1656                                 double alpha, const sp<Allocation>& A, const sp<Allocation>& B) {
1657     validateTRSM(mRS, Element::F64(mRS), Side, TransA, A, B);
1658     nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrsm,
1659                                 TransA, 0, Side, Uplo, Diag,
1660                                 B->getType()->getY(), B->getType()->getX(), 0,
1661                                 alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0);
1662 }
1663 
CTRSM(RsBlasSide Side,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,Float2 alpha,const sp<Allocation> & A,const sp<Allocation> & B)1664 void ScriptIntrinsicBLAS::CTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1665                                 Float2 alpha, const sp<Allocation>& A, const sp<Allocation>& B) {
1666     validateTRSM(mRS, Element::F32_2(mRS), Side, TransA, A, B);
1667     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrsm,
1668                                  TransA, 0, Side, Uplo, Diag,
1669                                  B->getType()->getY(), B->getType()->getX(), 0,
1670                                  alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0);
1671 }
1672 
ZTRSM(RsBlasSide Side,RsBlasUplo Uplo,RsBlasTranspose TransA,RsBlasDiag Diag,Double2 alpha,const sp<Allocation> & A,const sp<Allocation> & B)1673 void ScriptIntrinsicBLAS::ZTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
1674                                 Double2 alpha, const sp<Allocation>& A, const sp<Allocation>& B) {
1675     validateTRSM(mRS, Element::F64_2(mRS), Side, TransA, A, B);
1676     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrsm,
1677                            TransA, 0, Side, Uplo, Diag,
1678                            B->getType()->getY(), B->getType()->getX(), 0,
1679                            alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0);
1680 }
1681 
validateHEMM(RS * mRS,const sp<const Element> & e,RsBlasSide Side,const sp<Allocation> & A,const sp<Allocation> & B,const sp<Allocation> & C)1682 static void validateHEMM(RS* mRS, const sp<const Element>& e, RsBlasSide Side,
1683                          const sp<Allocation>& A, const sp<Allocation>& B, const sp<Allocation>& C) {
1684     if (!A->getType()->getElement()->isCompatible(e) ||
1685         !B->getType()->getElement()->isCompatible(e) ||
1686         !C->getType()->getElement()->isCompatible(e)) {
1687         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1688     }
1689 
1690     // A must be square; can potentially be relaxed similar to TRSM
1691     int adim = A->getType()->getX();
1692     if (adim != (int)A->getType()->getY()) {
1693         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with non-square A");
1694     }
1695     if ((Side == RsBlasLeft && adim != (int)B->getType()->getY()) ||
1696         (Side == RsBlasRight && adim != (int)B->getType()->getX())) {
1697         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with invalid B");
1698     }
1699     if (B->getType()->getX() != C->getType()->getX() ||
1700         B->getType()->getY() != C->getType()->getY()) {
1701         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with mismatched B and C");
1702     }
1703 }
1704 
CHEMM(RsBlasSide Side,RsBlasUplo Uplo,Float2 alpha,const sp<Allocation> & A,const sp<Allocation> & B,Float2 beta,const sp<Allocation> & C)1705 void ScriptIntrinsicBLAS::CHEMM(RsBlasSide Side, RsBlasUplo Uplo, Float2 alpha,
1706                                 const sp<Allocation>& A, const sp<Allocation>& B, Float2 beta, const sp<Allocation>& C) {
1707     validateHEMM(mRS, Element::F32_2(mRS), Side, A, B, C);
1708     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chemm,
1709                                  0, 0, Side, Uplo, 0,
1710                                  C->getType()->getY(), C->getType()->getX(), 0,
1711                                  alpha.x, alpha.y, A->getID(), B->getID(),
1712                                  beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1713 }
1714 
ZHEMM(RsBlasSide Side,RsBlasUplo Uplo,Double2 alpha,const sp<Allocation> & A,const sp<Allocation> & B,Double2 beta,const sp<Allocation> & C)1715 void ScriptIntrinsicBLAS::ZHEMM(RsBlasSide Side, RsBlasUplo Uplo, Double2 alpha,
1716                                 const sp<Allocation>& A, const sp<Allocation>& B, Double2 beta, const sp<Allocation>& C) {
1717     validateHEMM(mRS, Element::F64_2(mRS), Side, A, B, C);
1718     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhemm,
1719                            0, 0, Side, Uplo, 0,
1720                            C->getType()->getY(), C->getType()->getX(), 0,
1721                            alpha.x, alpha.y, A->getID(), B->getID(),
1722                            beta.x, beta.y, C->getID(), 0, 0, 0, 0);
1723 }
1724 
validateHERK(RS * mRS,const sp<const Element> & e,RsBlasTranspose Trans,const sp<Allocation> & A,const sp<Allocation> & C)1725 static void validateHERK(RS* mRS, const sp<const Element>& e, RsBlasTranspose Trans,
1726                          const sp<Allocation>& A, const sp<Allocation>& C) {
1727     if (!A->getType()->getElement()->isCompatible(e) ||
1728         !C->getType()->getElement()->isCompatible(e)) {
1729         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1730     }
1731     if (Trans != RsBlasNoTrans && Trans != RsBlasConjTrans) {
1732         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Call HERK with invalid Transpose");
1733     }
1734     int cdim = C->getType()->getX();
1735     if (cdim != (int)C->getType()->getY()) {
1736         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with non-square C");
1737     }
1738     if (Trans == RsBlasNoTrans) {
1739         if (cdim != (int)A->getType()->getY()) {
1740             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with invalid A");
1741         }
1742     } else {
1743         if (cdim != (int)A->getType()->getX()) {
1744             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with invalid A");
1745         }
1746     }
1747 }
1748 
CHERK(RsBlasUplo Uplo,RsBlasTranspose Trans,float alpha,const sp<Allocation> & A,float beta,const sp<Allocation> & C)1749 void ScriptIntrinsicBLAS::CHERK(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha,
1750                                 const sp<Allocation>& A, float beta, const sp<Allocation>& C) {
1751     validateHERK(mRS, Element::F32_2(mRS), Trans, A, C);
1752     int k = 0;
1753     if (Trans == RsBlasConjTrans) {
1754         k = A->getType()->getY();
1755     } else {
1756         k = A->getType()->getX();
1757     }
1758     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cherk,
1759                                  Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k,
1760                                  alpha, 0, A->getID(), 0,
1761                                  beta, 0, C->getID(), 0, 0, 0, 0);
1762 }
1763 
ZHERK(RsBlasUplo Uplo,RsBlasTranspose Trans,double alpha,const sp<Allocation> & A,double beta,const sp<Allocation> & C)1764 void ScriptIntrinsicBLAS::ZHERK(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha,
1765                                 const sp<Allocation>& A, double beta, const sp<Allocation>& C) {
1766     validateHERK(mRS, Element::F64_2(mRS), Trans, A, C);
1767     int k = 0;
1768     if (Trans == RsBlasConjTrans) {
1769         k = A->getType()->getY();
1770     } else {
1771         k = A->getType()->getX();
1772     }
1773     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zherk,
1774                            Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k,
1775                            alpha, 0, A->getID(), 0,
1776                            beta, 0, C->getID(), 0, 0, 0, 0);
1777 }
1778 
validateHER2K(RS * mRS,const sp<const Element> & e,RsBlasTranspose Trans,const sp<Allocation> & A,const sp<Allocation> & B,const sp<Allocation> & C)1779 static void validateHER2K(RS* mRS, const sp<const Element>& e, RsBlasTranspose Trans,
1780                           const sp<Allocation>& A, const sp<Allocation>& B, const sp<Allocation>& C) {
1781     if (!A->getType()->getElement()->isCompatible(e) ||
1782         !B->getType()->getElement()->isCompatible(e) ||
1783         !C->getType()->getElement()->isCompatible(e)) {
1784         mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
1785     }
1786     if (Trans != RsBlasNoTrans && Trans != RsBlasConjTrans) {
1787         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Call HERK with invalid Transpose");
1788     }
1789     int cdim = C->getType()->getX();
1790     if (cdim != (int)C->getType()->getY()) {
1791         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with non-square C");
1792     }
1793     if (Trans == RsBlasNoTrans) {
1794         if ((int)A->getType()->getY() != cdim) {
1795             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid matrices");
1796         }
1797     } else {
1798         if ((int)A->getType()->getX() != cdim) {
1799             mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid matrices");
1800         }
1801     }
1802     if (A->getType()->getX() != B->getType()->getX() || A->getType()->getY() != B->getType()->getY()) {
1803         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid A and B matrices");
1804     }
1805 }
1806 
CHER2K(RsBlasUplo Uplo,RsBlasTranspose Trans,Float2 alpha,const sp<Allocation> & A,const sp<Allocation> & B,float beta,const sp<Allocation> & C)1807 void ScriptIntrinsicBLAS::CHER2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha,
1808                                  const sp<Allocation>& A, const sp<Allocation>& B, float beta, const sp<Allocation>& C) {
1809     validateHER2K(mRS, Element::F32_2(mRS), Trans, A, B, C);
1810     int k = 0;
1811     if (Trans == RsBlasNoTrans) {
1812         k = A->getType()->getX();
1813     } else {
1814         k = A->getType()->getY();
1815     }
1816     nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher2k,
1817                                  Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k,
1818                                  alpha.x, alpha.y, A->getID(), B->getID(),
1819                                  beta, 0, C->getID(), 0, 0, 0, 0);
1820 }
1821 
ZHER2K(RsBlasUplo Uplo,RsBlasTranspose Trans,Double2 alpha,const sp<Allocation> & A,const sp<Allocation> & B,double beta,const sp<Allocation> & C)1822 void ScriptIntrinsicBLAS::ZHER2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha,
1823                                  const sp<Allocation>& A, const sp<Allocation>& B, double beta, const sp<Allocation>& C) {
1824     validateHER2K(mRS, Element::F64_2(mRS), Trans, A, B, C);
1825     int k = 0;
1826     if (Trans == RsBlasNoTrans) {
1827         k = A->getType()->getX();
1828     } else {
1829         k = A->getType()->getY();
1830     }
1831     nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher2k,
1832                            Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k,
1833                            alpha.x, alpha.y, A->getID(), B->getID(),
1834                            beta, 0, C->getID(), 0, 0, 0, 0);
1835 }
1836 
1837 
1838 
BNNM(const sp<Allocation> & A,int a_offset,const sp<Allocation> & B,int b_offset,const sp<Allocation> & C,int c_offset,int c_mult)1839 void ScriptIntrinsicBLAS::BNNM(const sp<Allocation>& A, int a_offset, const sp<Allocation>& B, int b_offset,
1840                                const sp<Allocation>& C, int c_offset, int c_mult) {
1841     validateL3(mRS, Element::U8(mRS), RsBlasNoTrans, RsBlasTrans, 0, A, B, C);
1842 
1843     if (a_offset < 0 || a_offset > 255) {
1844         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid a_offset passed to BNNM");
1845     }
1846     if (b_offset < 0 || b_offset > 255) {
1847         mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid b_offset passed to BNNM");
1848     }
1849     int M = -1, N = -1, K = -1;
1850     M = A->getType()->getY();
1851     N = B->getType()->getY();
1852     K = A->getType()->getX();
1853 
1854     nScriptIntrinsicBLAS_BNNM(mRS, mRS->getContext(), getID(), M, N, K, A->getID(), a_offset,
1855                               B->getID(), b_offset, C->getID(), c_offset, c_mult);
1856 }
1857